LCOV - code coverage report
Current view: top level - src/libtorch/reporters - DRLRewardReporter.C (source / functions) Hit Total Coverage
Test: idaholab/moose stochastic_tools: f45d79 Lines: 13 14 92.9 %
Date: 2025-07-25 05:00:46 Functions: 3 3 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef MOOSE_LIBTORCH_ENABLED
      11             : 
      12             : #include "DRLRewardReporter.h"
      13             : 
// Register DRLRewardReporter with the "StochasticToolsApp" application's object registry.
registerMooseObject("StochasticToolsApp", DRLRewardReporter);
      15             : 
      16             : InputParameters
      17          16 : DRLRewardReporter::validParams()
      18             : {
      19          16 :   InputParameters params = GeneralReporter::validParams();
      20          16 :   params += SurrogateModelInterface::validParams();
      21             : 
      22          16 :   params.addClassDescription("Reporter containing the reward values of a DRL controller trainer.");
      23          32 :   params.addRequiredParam<UserObjectName>(
      24             :       "drl_trainer_name", "The name of the RDL controller trainer which computes the rewards.");
      25             : 
      26          16 :   return params;
      27           0 : }
      28             : 
/**
 * Construct the reporter.
 *
 * Declares the reporter value "average_reward" (in REPORTER_MODE_ROOT). It also
 * stores the LibtorchDRLControlTrainer named by the `drl_trainer_name`
 * parameter, as returned by getSurrogateTrainer().
 *
 * @param parameters the validated input parameters for this object
 */
DRLRewardReporter::DRLRewardReporter(const InputParameters & parameters)
  : GeneralReporter(parameters),
    SurrogateModelInterface(this),
    // Value overwritten on every execute() with the trainer's average episode reward.
    _reward(declareValueByName<Real>("average_reward", REPORTER_MODE_ROOT)),
    // NOTE(review): members are initialized in header declaration order (not visible here);
    // keep this list consistent with DRLRewardReporter.h.
    _trainer(getSurrogateTrainer<LibtorchDRLControlTrainer>("drl_trainer_name"))
{
}
      36             : 
      37             : void
      38           8 : DRLRewardReporter::execute()
      39             : {
      40           8 :   _reward = _trainer.averageEpisodeReward();
      41           8 : }
      42             : 
      43             : #endif

Generated by: LCOV version 1.14