LCOV - code coverage report
Current view: top level - src/libtorch/controls - LibtorchDRLControl.C (source / functions) Hit Total Coverage
Test: idaholab/moose stochastic_tools: f45d79 Lines: 42 44 95.5 %
Date: 2025-07-25 05:00:46 Functions: 5 5 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef MOOSE_LIBTORCH_ENABLED
      11             : 
      12             : #include "LibtorchDRLControl.h"
      13             : #include "TorchScriptModule.h"
      14             : #include "LibtorchArtificialNeuralNet.h"
      15             : #include "Transient.h"
      16             : 
// MOOSE object registration: makes LibtorchDRLControl available under the "StochasticToolsApp"
// application name.
registerMooseObject("StochasticToolsApp", LibtorchDRLControl);
      18             : 
      19             : InputParameters
      20          64 : LibtorchDRLControl::validParams()
      21             : {
      22          64 :   InputParameters params = LibtorchNeuralNetControl::validParams();
      23          64 :   params.addClassDescription(
      24             :       "Sets the value of multiple 'Real' input parameters and postprocessors based on a Deep "
      25             :       "Reinforcement Learning (DRL) neural network trained using a PPO algorithm.");
      26         128 :   params.addRequiredParam<std::vector<Real>>(
      27             :       "action_standard_deviations", "Standard deviation value used while sampling the actions.");
      28         128 :   params.addParam<unsigned int>("seed", "Seed for the random number generator.");
      29             : 
      30          64 :   return params;
      31           0 : }
      32             : 
      33          32 : LibtorchDRLControl::LibtorchDRLControl(const InputParameters & parameters)
      34             :   : LibtorchNeuralNetControl(parameters),
      35          64 :     _current_control_signal_log_probabilities(std::vector<Real>(_control_names.size(), 0.0)),
      36          96 :     _action_std(getParam<std::vector<Real>>("action_standard_deviations"))
      37             : {
      38          32 :   if (_control_names.size() != _action_std.size())
      39           0 :     paramError("action_standard_deviations",
      40             :                "Number of action_standard_deviations does not match the number of controlled "
      41             :                "parameters.");
      42             : 
      43             :   // Fixing the RNG seed to make sure every experiment is the same.
      44          64 :   if (isParamValid("seed"))
      45          24 :     torch::manual_seed(getParam<unsigned int>("seed"));
      46             : 
      47             :   // We convert and store the user-supplied standard deviations into a tensor which can be easily
      48             :   // used by routines in libtorch
      49          32 :   _std = torch::eye(_control_names.size());
      50          64 :   for (unsigned int i = 0; i < _control_names.size(); ++i)
      51          64 :     _std[i][i] = _action_std[i];
      52          32 : }
      53             : 
      54             : void
      55         992 : LibtorchDRLControl::execute()
      56             : {
      57         992 :   if (_nn)
      58             :   {
      59         912 :     unsigned int n_controls = _control_names.size();
      60         912 :     unsigned int num_old_timesteps = _input_timesteps - 1;
      61             : 
      62             :     // Fill a vector with the current values of the responses
      63         912 :     updateCurrentResponse();
      64             : 
      65             :     // If this is the first time this control is called and we need to use older values, fill up the
      66             :     // needed old values using the initial values
      67         912 :     if (_old_responses.empty())
      68          24 :       _old_responses.assign(num_old_timesteps, _current_response);
      69             : 
      70             :     // Organize the old an current solution into a tensor so we can evaluate the neural net
      71         912 :     torch::Tensor input_tensor = prepareInputTensor();
      72             : 
      73             :     // Evaluate the neural network to get the expected control value
      74         912 :     torch::Tensor output_tensor = _nn->forward(input_tensor);
      75             : 
      76             :     // Sample control value (action) from Gaussian distribution
      77         912 :     torch::Tensor action = at::normal(output_tensor, _std);
      78             : 
      79             :     // Compute log probability
      80         912 :     torch::Tensor log_probability = computeLogProbability(action, output_tensor);
      81             : 
      82             :     // Convert data
      83        1824 :     _current_control_signals = {action.data_ptr<Real>(), action.data_ptr<Real>() + action.size(1)};
      84             : 
      85             :     _current_control_signal_log_probabilities = {log_probability.data_ptr<Real>(),
      86         912 :                                                  log_probability.data_ptr<Real>() +
      87        1824 :                                                      log_probability.size(1)};
      88             : 
      89        1824 :     for (unsigned int control_i = 0; control_i < n_controls; ++control_i)
      90             :     {
      91             :       // We scale the controllable value for physically meaningful control action
      92         912 :       setControllableValueByName<Real>(_control_names[control_i],
      93         912 :                                        _current_control_signals[control_i] *
      94             :                                            _action_scaling_factors[control_i]);
      95             :     }
      96             : 
      97             :     // We add the curent solution to the old solutions and move everything in there one step
      98             :     // backward
      99         912 :     std::rotate(_old_responses.rbegin(), _old_responses.rbegin() + 1, _old_responses.rend());
     100         912 :     _old_responses[0] = _current_response;
     101             :   }
     102         992 : }
     103             : 
     104             : torch::Tensor
     105         912 : LibtorchDRLControl::computeLogProbability(const torch::Tensor & action,
     106             :                                           const torch::Tensor & output_tensor)
     107             : {
     108             :   // Logarithmic probability of taken action, given the current distribution.
     109         912 :   torch::Tensor var = torch::matmul(_std, _std);
     110             : 
     111        3648 :   return -((action - output_tensor) * (action - output_tensor)) / (2.0 * var) - torch::log(_std) -
     112        1824 :          std::log(std::sqrt(2.0 * M_PI));
     113             : }
     114             : 
     115             : Real
     116        1000 : LibtorchDRLControl::getSignalLogProbability(const unsigned int signal_index) const
     117             : {
     118             :   mooseAssert(signal_index < _control_names.size(),
     119             :               "The index of the requested control signal is not in the [0," +
     120             :                   std::to_string(_control_names.size()) + ") range!");
     121        1000 :   return _current_control_signal_log_probabilities[signal_index];
     122             : }
     123             : 
     124             : #endif

Generated by: LCOV version 1.14