//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#ifdef MOOSE_LIBTORCH_ENABLED

#include "LibtorchDataset.h"
#include "LibtorchArtificialNeuralNetTrainer.h"
#include "LibtorchUtils.h"
#include "LibtorchDRLControlTrainer.h"
#include "Sampler.h"
#include "Function.h"

registerMooseObject("StochasticToolsApp", LibtorchDRLControlTrainer);

InputParameters
LibtorchDRLControlTrainer::validParams()
{
  InputParameters params = SurrogateTrainerBase::validParams();

  params.addClassDescription(
      "Trains a neural network controller using the Proximal Policy Optimization (PPO) algorithm.");

  params.addRequiredParam<std::vector<ReporterName>>(
      "response", "Reporter values containing the response values from the model.");
  params.addParam<std::vector<Real>>(
      "response_shift_factors",
      "A shift constant which will be used to shift the response values. This is used for the "
      "manipulation of the neural net inputs for better training efficiency.");
  params.addParam<std::vector<Real>>(
      "response_scaling_factors",
      "A normalization constant which will be used to divide the response values. This is used for "
      "the manipulation of the neural net inputs for better training efficiency.");
  params.addRequiredParam<std::vector<ReporterName>>(
      "control",
      "Reporters containing the values of the controlled quantities (control signals) from the "
      "model simulations.");
  params.addRequiredParam<std::vector<ReporterName>>(
      "log_probability",
      "Reporters containing the log probabilities of the actions taken during the simulations.");
  params.addRequiredParam<ReporterName>(
      "reward", "Reporter containing the earned time-dependent rewards from the simulation.");
  params.addRangeCheckedParam<unsigned int>(
      "input_timesteps",
      1,
      "1<=input_timesteps",
      "Number of time steps to use in the input data. If larger than 1, "
      "data from the previous time steps will be used as inputs in the training.");
  params.addParam<unsigned int>("skip_num_rows",
                                1,
                                "Number of rows to ignore from training. We usually skip the 1st "
                                "row from the reporter since it contains only initial values.");

  params.addRequiredParam<unsigned int>("num_epochs", "Number of epochs for the training.");

  params.addRequiredRangeCheckedParam<Real>(
      "critic_learning_rate",
      "0<critic_learning_rate",
      "Learning rate (relaxation) for the emulator training.");
  params.addRequiredParam<std::vector<unsigned int>>(
      "num_critic_neurons_per_layer", "Number of neurons per layer in the emulator neural net.");
  params.addParam<std::vector<std::string>>(
      "critic_activation_functions",
      std::vector<std::string>({"relu"}),
      "The type of activation functions to use in the emulator neural net. It is either one value "
      "or one value per hidden layer.");

  params.addRequiredRangeCheckedParam<Real>(
      "control_learning_rate",
      "0<control_learning_rate",
      "Learning rate (relaxation) for the control neural net training.");
  params.addRequiredParam<std::vector<unsigned int>>(
      "num_control_neurons_per_layer",
      "Number of neurons per layer for the control neural network.");
  params.addParam<std::vector<std::string>>(
      "control_activation_functions",
      std::vector<std::string>({"relu"}),
      "The type of activation functions to use in the control neural net. It is either one value "
      "or one value per hidden layer.");

  params.addParam<std::string>("filename_base",
                               "Filename used to output the neural net parameters.");

  params.addParam<unsigned int>(
      "seed", 11, "Random number generator seed for stochastic optimizers.");

  params.addRequiredParam<std::vector<Real>>(
      "action_standard_deviations", "Standard deviation value used while sampling the actions.");

  params.addParam<Real>(
      "clip_parameter", 0.2, "Clip parameter used while clamping the advantage value.");
  params.addRangeCheckedParam<unsigned int>(
      "update_frequency",
      1,
      "1<=update_frequency",
      "Number of transient simulation data to collect for updating the controller neural network.");

  params.addRangeCheckedParam<Real>(
      "decay_factor",
      1.0,
      "0.0<=decay_factor<=1.0",
      "Decay factor for calculating the return. This accounts for decreased "
      "reward values from the later steps.");

  params.addParam<bool>(
      "read_from_file", false, "Switch to read the neural network parameters from a file.");
  params.addParam<bool>(
      "shift_outputs",
      true,
      "If we would like to shift the outputs to realign the input-output pairs.");
  params.addParam<bool>(
      "standardize_advantage",
      true,
      "Switch to enable the shifting and normalization of the advantages in the PPO algorithm.");
  params.addParam<unsigned int>("loss_print_frequency",
                                0,
                                "The frequency which is used to print the loss values. If 0, the "
                                "loss values are not printed.");
  return params;
}
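
// For reference, a minimal input-file sketch exercising the parameters declared
// above. This is an illustrative assumption only: the block/object names,
// reporter paths, and values are hypothetical and not taken from a test case.
//
//   [Trainers]
//     [drl_trainer]
//       type = LibtorchDRLControlTrainer
//       response = 'response_reporter/value'
//       control = 'control_reporter/signal'
//       log_probability = 'control_reporter/log_prob'
//       reward = 'reward_reporter/value'
//       num_epochs = 1000
//       critic_learning_rate = 1e-4
//       num_critic_neurons_per_layer = '32 32'
//       control_learning_rate = 5e-4
//       num_control_neurons_per_layer = '32 32'
//       action_standard_deviations = '0.1'
//     []
//   []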

LibtorchDRLControlTrainer::LibtorchDRLControlTrainer(const InputParameters & parameters)
  : SurrogateTrainerBase(parameters),
    _response_names(getParam<std::vector<ReporterName>>("response")),
    _response_shift_factors(isParamValid("response_shift_factors")
                                ? getParam<std::vector<Real>>("response_shift_factors")
                                : std::vector<Real>(_response_names.size(), 0.0)),
    _response_scaling_factors(isParamValid("response_scaling_factors")
                                  ? getParam<std::vector<Real>>("response_scaling_factors")
                                  : std::vector<Real>(_response_names.size(), 1.0)),
    _control_names(getParam<std::vector<ReporterName>>("control")),
    _log_probability_names(getParam<std::vector<ReporterName>>("log_probability")),
    _reward_name(getParam<ReporterName>("reward")),
    _reward_value_pointer(&getReporterValueByName<std::vector<Real>>(_reward_name)),
    _input_timesteps(getParam<unsigned int>("input_timesteps")),
    _num_inputs(_input_timesteps * _response_names.size()),
    _num_outputs(_control_names.size()),
    _input_data(std::vector<std::vector<Real>>(_num_inputs)),
    _output_data(std::vector<std::vector<Real>>(_num_outputs)),
    _log_probability_data(std::vector<std::vector<Real>>(_num_outputs)),
    _num_epochs(getParam<unsigned int>("num_epochs")),
    _num_critic_neurons_per_layer(
        getParam<std::vector<unsigned int>>("num_critic_neurons_per_layer")),
    _critic_learning_rate(getParam<Real>("critic_learning_rate")),
    _num_control_neurons_per_layer(
        getParam<std::vector<unsigned int>>("num_control_neurons_per_layer")),
    _control_learning_rate(getParam<Real>("control_learning_rate")),
    _update_frequency(getParam<unsigned int>("update_frequency")),
    _clip_param(getParam<Real>("clip_parameter")),
    _decay_factor(getParam<Real>("decay_factor")),
    _action_std(getParam<std::vector<Real>>("action_standard_deviations")),
    _filename_base(isParamValid("filename_base") ? getParam<std::string>("filename_base") : ""),
    _read_from_file(getParam<bool>("read_from_file")),
    _shift_outputs(getParam<bool>("shift_outputs")),
    _standardize_advantage(getParam<bool>("standardize_advantage")),
    _loss_print_frequency(getParam<unsigned int>("loss_print_frequency")),
    _update_counter(_update_frequency)
{
  if (_response_names.size() != _response_shift_factors.size())
    paramError("response_shift_factors",
               "The number of shift factors is not the same as the number of responses!");

  if (_response_names.size() != _response_scaling_factors.size())
    paramError(
        "response_scaling_factors",
        "The number of normalization coefficients is not the same as the number of responses!");

  // We establish the links with the chosen reporters
  getReporterPointers(_response_names, _response_value_pointers);
  getReporterPointers(_control_names, _control_value_pointers);
  getReporterPointers(_log_probability_names, _log_probability_value_pointers);

  // Fixing the RNG seed to make sure every experiment is the same.
  // Otherwise sampling / stochastic gradient descent would be different.
  torch::manual_seed(getParam<unsigned int>("seed"));

  // Convert the user input standard deviations to a diagonal tensor
  _std = torch::eye(_control_names.size());
  for (unsigned int i = 0; i < _control_names.size(); ++i)
    _std[i][i] = _action_std[i];

  bool filename_valid = isParamValid("filename_base");

  // Initializing the control neural net so that the control can grab it right away
  _control_nn = std::make_shared<Moose::LibtorchArtificialNeuralNet>(
      filename_valid ? _filename_base + "_control.net" : "control.net",
      _num_inputs,
      _num_outputs,
      _num_control_neurons_per_layer,
      getParam<std::vector<std::string>>("control_activation_functions"));

  // We read parameters for the control neural net if it is requested
  if (_read_from_file)
  {
    try
    {
      torch::load(_control_nn, _control_nn->name());
      _console << "Loaded requested .pt file." << std::endl;
    }
    catch (const c10::Error & e)
    {
      mooseError("The requested pytorch file could not be loaded for the control neural net.\n",
                 e.msg());
    }
  }
  else if (filename_valid)
    torch::save(_control_nn, _control_nn->name());

  // Initialize the critic neural net
  _critic_nn = std::make_shared<Moose::LibtorchArtificialNeuralNet>(
      filename_valid ? _filename_base + "_critic.net" : "critic.net",
      _num_inputs,
      1,
      _num_critic_neurons_per_layer,
      getParam<std::vector<std::string>>("critic_activation_functions"));

  // We read parameters for the critic neural net if it is requested
  if (_read_from_file)
  {
    try
    {
      torch::load(_critic_nn, _critic_nn->name());
      _console << "Loaded requested .pt file." << std::endl;
    }
    catch (const c10::Error & e)
    {
      mooseError("The requested pytorch file could not be loaded for the critic neural net.\n",
                 e.msg());
    }
  }
  else if (filename_valid)
    torch::save(_critic_nn, _critic_nn->name());
}

void
LibtorchDRLControlTrainer::execute()
{
  // Extract data from the reporters
  getInputDataFromReporter(_input_data, _response_value_pointers, _input_timesteps);
  getOutputDataFromReporter(_output_data, _control_value_pointers);
  getOutputDataFromReporter(_log_probability_data, _log_probability_value_pointers);
  getRewardDataFromReporter(_reward_data, _reward_value_pointer);

  // Calculate return from the reward (discounting the reward)
  computeRewardToGo();

  _update_counter--;

  // Only update the NNs when the update counter reaches zero, i.e. when data
  // from `update_frequency` simulations has been collected
  if (_update_counter == 0)
  {
    // We compute the average reward first
    computeAverageEpisodeReward();

    // Transform input/output/return data to torch::Tensor
    convertDataToTensor(_input_data, _input_tensor);
    convertDataToTensor(_output_data, _output_tensor);
    convertDataToTensor(_log_probability_data, _log_probability_tensor);

    // Discard (detach) the gradient info for return data
    LibtorchUtils::vectorToTensor<Real>(_return_data, _return_tensor, true);

    // We train the controller using the emulator to get a good control strategy
    trainController();

    // We clean the training data after the controller update and reset the counter
    resetData();
  }
}

void
LibtorchDRLControlTrainer::computeAverageEpisodeReward()
{
  if (_reward_data.size())
    _average_episode_reward =
        std::accumulate(_reward_data.begin(), _reward_data.end(), 0.0) / _reward_data.size();
  else
    _average_episode_reward = 0.0;
}

void
LibtorchDRLControlTrainer::computeRewardToGo()
{
  // Get reward data from one simulation
  std::vector<Real> reward_data_per_sim;
  std::vector<Real> return_data_per_sim;
  getRewardDataFromReporter(reward_data_per_sim, _reward_value_pointer);

  // Discount the reward to get the return value; we need this to be able to
  // anticipate rewards based on the current behavior.
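  // The reward-to-go follows the recursion G_i = r_i + gamma * G_(i+1), where
  // gamma is the decay factor. For example, with decay_factor = 0.9 and rewards
  // {1, 1, 1}, the returns are {2.71, 1.9, 1}.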
  Real discounted_reward(0.0);
  for (int i = reward_data_per_sim.size() - 1; i >= 0; --i)
  {
    discounted_reward = reward_data_per_sim[i] + discounted_reward * _decay_factor;

    // We insert at the front of the vector and push the rest back; this ensures
    // that the first element of the vector is the discounted reward for the
    // whole transient
    return_data_per_sim.insert(return_data_per_sim.begin(), discounted_reward);
  }

  // Save and accumulate the return values
  _return_data.insert(_return_data.end(), return_data_per_sim.begin(), return_data_per_sim.end());
}

void
LibtorchDRLControlTrainer::trainController()
{
  // Define the optimizers for the training
  torch::optim::Adam actor_optimizer(_control_nn->parameters(),
                                     torch::optim::AdamOptions(_control_learning_rate));

  torch::optim::Adam critic_optimizer(_critic_nn->parameters(),
                                      torch::optim::AdamOptions(_critic_learning_rate));

  // Compute the approximate value (return) from the critic neural net and use it to compute an
  // advantage
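  // Advantage: A = G - V(s), i.e. how much better the observed return was than
  // the critic's estimate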
  auto value = evaluateValue(_input_tensor).detach();
  auto advantage = _return_tensor - value;

  // If requested, standardize the advantage
  if (_standardize_advantage)
    advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-10);

  for (unsigned int epoch = 0; epoch < _num_epochs; ++epoch)
  {
    // Get the approximate return from the neural net again (this one does have an associated
    // gradient)
    value = evaluateValue(_input_tensor);
    // Get the approximate logarithmic action probability using the control neural net
    auto curr_log_probability = evaluateAction(_input_tensor, _output_tensor);

    // Prepare the probability ratio r = pi_new(a|s) / pi_old(a|s) using the
    // e^(log x - log y) = x/y identity
    auto ratio = (curr_log_probability - _log_probability_tensor).exp();

    // Clamp the ratio to form the PPO clipped surrogate objective:
    // min(r * A, clamp(r, 1 - eps, 1 + eps) * A)
    auto surr1 = ratio * advantage;
    auto surr2 = torch::clamp(ratio, 1.0 - _clip_param, 1.0 + _clip_param) * advantage;

    // Compute loss values for the critic and the control neural net
    auto actor_loss = -torch::min(surr1, surr2).mean();
    auto critic_loss = torch::mse_loss(value, _return_tensor);

    // Update the weights in the neural nets
    actor_optimizer.zero_grad();
    actor_loss.backward();
    actor_optimizer.step();

    critic_optimizer.zero_grad();
    critic_loss.backward();
    critic_optimizer.step();

    // Print the loss per epoch
    if (_loss_print_frequency)
      if (epoch % _loss_print_frequency == 0)
        _console << "Epoch: " << epoch << " | Actor Loss: " << COLOR_GREEN
                 << actor_loss.item<double>() << COLOR_DEFAULT << " | Critic Loss: " << COLOR_GREEN
                 << critic_loss.item<double>() << COLOR_DEFAULT << std::endl;
  }

  // Save the controller neural net so our controller can read it; we also save
  // the critic in case we want to continue training
  torch::save(_control_nn, _control_nn->name());
  torch::save(_critic_nn, _critic_nn->name());
}

void
LibtorchDRLControlTrainer::convertDataToTensor(std::vector<std::vector<Real>> & vector_data,
                                               torch::Tensor & tensor_data,
                                               const bool detach)
{
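  // Each inner vector is converted to a column tensor; the columns are then
  // concatenated along dimension 1 to form the final 2D tensor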
  for (unsigned int i = 0; i < vector_data.size(); ++i)
  {
    torch::Tensor input_row;
    LibtorchUtils::vectorToTensor(vector_data[i], input_row, detach);

    if (i == 0)
      tensor_data = input_row;
    else
      tensor_data = torch::cat({tensor_data, input_row}, 1);
  }

  if (detach)
    tensor_data = tensor_data.detach();
}

torch::Tensor
LibtorchDRLControlTrainer::evaluateValue(torch::Tensor & input)
{
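  // The critic estimates the value of the input state(s), i.e. the expected
  // return, which is compared against the computed reward-to-go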
  return _critic_nn->forward(input);
}

torch::Tensor
LibtorchDRLControlTrainer::evaluateAction(torch::Tensor & input, torch::Tensor & output)
{
  torch::Tensor var = torch::matmul(_std, _std);

  // Compute an action and get its logarithmic probability based on an assumed Gaussian distribution
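  // For each component this is the Gaussian log-density with mean `action`:
  // log p(a) = -(a - mu)^2 / (2 * sigma^2) - log(sigma) - log(sqrt(2 * pi))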
  torch::Tensor action = _control_nn->forward(input);
  return -((action - output) * (action - output)) / (2 * var) - torch::log(_std) -
         std::log(std::sqrt(2 * M_PI));
}

void
LibtorchDRLControlTrainer::resetData()
{
  for (auto & data : _input_data)
    data.clear();
  for (auto & data : _output_data)
    data.clear();
  for (auto & data : _log_probability_data)
    data.clear();

  _reward_data.clear();
  _return_data.clear();

  _update_counter = _update_frequency;
}

void
LibtorchDRLControlTrainer::getInputDataFromReporter(
    std::vector<std::vector<Real>> & data,
    const std::vector<const std::vector<Real> *> & reporter_links,
    const unsigned int num_timesteps)
{
  for (const auto & rep_i : index_range(reporter_links))
  {
    std::vector<Real> reporter_data = *reporter_links[rep_i];

    // We shift and scale the inputs to get better training efficiency
    std::transform(
        reporter_data.begin(),
        reporter_data.end(),
        reporter_data.begin(),
        [this, &rep_i](Real value) -> Real
        { return (value - _response_shift_factors[rep_i]) * _response_scaling_factors[rep_i]; });

    // Fill the corresponding containers
    for (const auto & start_step : make_range(num_timesteps))
    {
      unsigned int row = reporter_links.size() * start_step + rep_i;
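      // Time steps before the beginning of the transient are padded with the
      // initial value of the response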
      for (unsigned int fill_i = 1; fill_i < num_timesteps - start_step; ++fill_i)
        data[row].push_back(reporter_data[0]);

      data[row].insert(data[row].end(),
                       reporter_data.begin(),
                       reporter_data.begin() + start_step + reporter_data.size() -
                           (num_timesteps - 1) - _shift_outputs);
    }
  }
}

void
LibtorchDRLControlTrainer::getOutputDataFromReporter(
    std::vector<std::vector<Real>> & data,
    const std::vector<const std::vector<Real> *> & reporter_links)
{
  for (const auto & rep_i : index_range(reporter_links))
    // Fill the corresponding containers
    data[rep_i].insert(data[rep_i].end(),
                       reporter_links[rep_i]->begin() + _shift_outputs,
                       reporter_links[rep_i]->end());
}

void
LibtorchDRLControlTrainer::getRewardDataFromReporter(std::vector<Real> & data,
                                                     const std::vector<Real> * const reporter_link)
{
  // Fill the corresponding container
  data.insert(data.end(), reporter_link->begin() + _shift_outputs, reporter_link->end());
}

void
LibtorchDRLControlTrainer::getReporterPointers(
    const std::vector<ReporterName> & reporter_names,
    std::vector<const std::vector<Real> *> & pointer_storage)
{
  pointer_storage.clear();
  for (const auto & name : reporter_names)
    pointer_storage.push_back(&getReporterValueByName<std::vector<Real>>(name));
}

#endif
