LCOV - code coverage report
Current view: top level - src/libtorch/postprocessors - LibtorchDRLLogProbabilityPostprocessor.C (source / functions) Hit Total Coverage
Test: idaholab/moose stochastic_tools: f45d79 Lines: 21 24 87.5 %
Date: 2025-07-25 05:00:46 Functions: 4 4 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef MOOSE_LIBTORCH_ENABLED
      11             : 
      12             : #include "LibtorchDRLLogProbabilityPostprocessor.h"
      13             : 
      14             : registerMooseObject("StochasticToolsApp", LibtorchDRLLogProbabilityPostprocessor);
      15             : 
      16             : InputParameters
      17          64 : LibtorchDRLLogProbabilityPostprocessor::validParams()
      18             : {
      19          64 :   InputParameters params = GeneralPostprocessor::validParams();
      20             : 
      21          64 :   params.addClassDescription(
      22             :       "Computes the logarithmic probability of the action in a given LibtorchDRLController.");
      23             : 
      24         128 :   params.addRequiredParam<std::string>("control_name",
      25             :                                        "The name of the LibtorchNeuralNetControl object.");
      26         128 :   params.addParam<unsigned int>("signal_index",
      27         128 :                                 0,
      28             :                                 "The index of the signal from the LibtorchNeuralNetControl object. "
      29             :                                 "This assumes indexing between [0,num_signals).");
      30          64 :   return params;
      31           0 : }
      32             : 
      33          32 : LibtorchDRLLogProbabilityPostprocessor::LibtorchDRLLogProbabilityPostprocessor(
      34          32 :     const InputParameters & params)
      35          64 :   : GeneralPostprocessor(params), _signal_index(getParam<unsigned int>("signal_index"))
      36             : 
      37             : {
      38          32 : }
      39             : 
      40             : void
      41          32 : LibtorchDRLLogProbabilityPostprocessor::initialSetup()
      42             : {
      43          32 :   GeneralPostprocessor::initialSetup();
      44             : 
      45          32 :   _libtorch_nn_control =
      46          32 :       dynamic_cast<LibtorchDRLControl *>(_fe_problem.getControlWarehouse()
      47          64 :                                              .getActiveObject(getParam<std::string>("control_name"))
      48             :                                              .get());
      49          32 :   if (!_libtorch_nn_control)
      50           0 :     paramError("control_name",
      51             :                "The supplied control object is not derived from LibtorchDRLControl!");
      52             : 
      53          32 :   if (_libtorch_nn_control->numberOfControlSignals() <= _signal_index)
      54           0 :     paramError("signal_index",
      55             :                "The given control object only has ",
      56             :                _libtorch_nn_control->numberOfControlSignals(),
      57             :                " control signals!");
      58          32 : }
      59             : 
      60             : Real
      61        1000 : LibtorchDRLLogProbabilityPostprocessor::getValue() const
      62             : {
      63             :   // Return the value of the control signal
      64        1000 :   return _libtorch_nn_control->getSignalLogProbability(_signal_index);
      65             : }
      66             : 
      67             : #endif

Generated by: LCOV version 1.14