Line data Source code
//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#ifdef MOOSE_LIBTORCH_ENABLED

#include "LibtorchDRLLogProbabilityPostprocessor.h"

registerMooseObject("StochasticToolsApp", LibtorchDRLLogProbabilityPostprocessor);

// Input parameters: the name of the DRL control to query and the index of the
// control signal whose log-probability is reported.
InputParameters
LibtorchDRLLogProbabilityPostprocessor::validParams()
{
  InputParameters params = GeneralPostprocessor::validParams();

  params.addClassDescription(
      "Computes the logarithmic probability of the action in a given LibtorchDRLControl.");

  // The referenced control must be a LibtorchDRLControl (enforced in initialSetup()).
  params.addRequiredParam<std::string>("control_name",
                                       "The name of the LibtorchDRLControl object.");
  params.addParam<unsigned int>("signal_index",
                                0,
                                "The index of the signal from the LibtorchDRLControl object. "
                                "This assumes indexing between [0,num_signals).");
  return params;
}

LibtorchDRLLogProbabilityPostprocessor::LibtorchDRLLogProbabilityPostprocessor(
    const InputParameters & params)
  : GeneralPostprocessor(params), _signal_index(getParam<unsigned int>("signal_index"))
{
}

// Resolve the named control and validate the requested signal index. Done here rather than in
// the constructor so that the control warehouse is populated before the lookup.
void
LibtorchDRLLogProbabilityPostprocessor::initialSetup()
{
  GeneralPostprocessor::initialSetup();

  const auto & control_name = getParam<std::string>("control_name");

  // Non-owning pointer; assumes the control warehouse keeps the object alive for as long as
  // this postprocessor is evaluated. A failed cast (wrong control type) yields nullptr.
  _libtorch_nn_control = dynamic_cast<LibtorchDRLControl *>(
      _fe_problem.getControlWarehouse().getActiveObject(control_name).get());
  if (!_libtorch_nn_control)
    paramError("control_name",
               "The supplied control object is not derived from LibtorchDRLControl!");

  // Valid indices are [0, num_signals).
  const auto num_signals = _libtorch_nn_control->numberOfControlSignals();
  if (num_signals <= _signal_index)
    paramError("signal_index",
               "The given control object only has ",
               num_signals,
               " control signals!");
}

Real
LibtorchDRLLogProbabilityPostprocessor::getValue() const
{
  // Return the logarithmic probability the DRL control reports for the selected signal
  // (this is not the value of the control signal itself).
  return _libtorch_nn_control->getSignalLogProbability(_signal_index);
}

#endif