Line data Source code
//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#ifdef MOOSE_LIBTORCH_ENABLED

#include "DRLRewardReporter.h"

registerMooseObject("StochasticToolsApp", DRLRewardReporter);

/**
 * Input parameters for DRLRewardReporter.
 *
 * Combines the GeneralReporter and SurrogateModelInterface parameters and adds the
 * required `drl_trainer_name` parameter, which names the trainer whose rewards are
 * reported.
 */
InputParameters
DRLRewardReporter::validParams()
{
  InputParameters params = GeneralReporter::validParams();
  params += SurrogateModelInterface::validParams();

  params.addClassDescription("Reporter containing the reward values of a DRL controller trainer.");
  params.addRequiredParam<UserObjectName>(
      "drl_trainer_name", "The name of the DRL controller trainer which computes the rewards.");

  return params;
}

/**
 * Declares the "average_reward" reporter value and resolves the trainer named by
 * `drl_trainer_name` once, at construction time.
 *
 * @param parameters the validated input parameters for this object
 */
DRLRewardReporter::DRLRewardReporter(const InputParameters & parameters)
  : GeneralReporter(parameters),
    SurrogateModelInterface(this),
    // NOTE(review): REPORTER_MODE_ROOT presumably means only the root rank holds a
    // meaningful value -- confirm against the MOOSE Reporter documentation.
    _reward(declareValueByName<Real>("average_reward", REPORTER_MODE_ROOT)),
    _trainer(getSurrogateTrainer<LibtorchDRLControlTrainer>("drl_trainer_name"))
{
}

/**
 * Copies the trainer's current average episode reward into the declared
 * "average_reward" reporter value.
 */
void
DRLRewardReporter::execute()
{
  _reward = _trainer.averageEpisodeReward();
}

#endif