LCOV - code coverage report
Current view: top level - src/libtorch/trainers - LibtorchANNTrainer.C (source / functions) Hit Total Coverage
Test: idaholab/moose stochastic_tools: f45d79 Lines: 91 94 96.8 %
Date: 2025-07-25 05:00:46 Functions: 5 5 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef MOOSE_LIBTORCH_ENABLED
      11             : 
      12             : #include "LibtorchANNTrainer.h"
      13             : #include "LibtorchDataset.h"
      14             : #include "Sampler.h"
      15             : 
      16             : registerMooseObject("StochasticToolsApp", LibtorchANNTrainer);
      17             : 
      18             : InputParameters
      19          96 : LibtorchANNTrainer::validParams()
      20             : {
      21          96 :   InputParameters params = SurrogateTrainer::validParams();
      22             : 
      23          96 :   params.addClassDescription("Trains a simple neural network using libtorch.");
      24             : 
      25         288 :   params.addRangeCheckedParam<unsigned int>(
      26         192 :       "num_batches", 1, "1<=num_batches", "Number of batches.");
      27         288 :   params.addRangeCheckedParam<unsigned int>(
      28         192 :       "num_epochs", 1, "0<num_epochs", "Number of training epochs.");
      29         192 :   params.addRangeCheckedParam<Real>(
      30             :       "rel_loss_tol",
      31             :       0,
      32             :       "0<=rel_loss_tol<=1",
      33             :       "The relative loss where we stop the training of the neural net.");
      34         192 :   params.addParam<std::vector<unsigned int>>(
      35          96 :       "num_neurons_per_layer", std::vector<unsigned int>(), "Number of neurons per layer.");
      36         192 :   params.addParam<std::vector<std::string>>(
      37             :       "activation_function",
      38         288 :       std::vector<std::string>({"relu"}),
      39             :       "The type of activation functions to use. It is either one value "
      40             :       "or one value per hidden layer.");
      41         192 :   params.addParam<std::string>(
      42             :       "nn_filename", "net.pt", "Filename used to output the neural net parameters.");
      43         192 :   params.addParam<bool>("read_from_file",
      44         192 :                         false,
      45             :                         "Switch to allow reading old trained neural nets for further training.");
      46         192 :   params.addParam<Real>("learning_rate", 0.001, "Learning rate (relaxation).");
      47         288 :   params.addRangeCheckedParam<unsigned int>(
      48             :       "print_epoch_loss",
      49         192 :       0,
      50             :       "0<=print_epoch_loss",
      51             :       "Epoch training loss printing. 0 - no printing, 1 - every epoch, 10 - every 10th epoch.");
      52         192 :   params.addParam<unsigned int>(
      53         192 :       "seed", 11, "Random number generator seed for stochastic optimizers.");
      54         192 :   params.addParam<unsigned int>(
      55         192 :       "max_processes", 1, "The maximum number of parallel processes that the trainer will use.");
      56             : 
      57         192 :   params.addParam<bool>(
      58         192 :       "standardize_input", true, "Standardize (center and scale) training inputs (x values)");
      59         192 :   params.addParam<bool>(
      60         192 :       "standardize_output", true, "Standardize (center and scale) training outputs (y values)");
      61             : 
      62          96 :   params.suppressParameter<MooseEnum>("response_type");
      63          96 :   return params;
      64         192 : }
      65             : 
      66          48 : LibtorchANNTrainer::LibtorchANNTrainer(const InputParameters & parameters)
      67             :   : SurrogateTrainer(parameters),
      68          48 :     _predictor_row(getPredictorData()),
      69         144 :     _num_neurons_per_layer(declareModelData<std::vector<unsigned int>>(
      70             :         "num_neurons_per_layer", getParam<std::vector<unsigned int>>("num_neurons_per_layer"))),
      71         144 :     _activation_function(declareModelData<std::vector<std::string>>(
      72             :         "activation_function", getParam<std::vector<std::string>>("activation_function"))),
      73          96 :     _nn_filename(getParam<std::string>("nn_filename")),
      74          96 :     _read_from_file(getParam<bool>("read_from_file")),
      75          96 :     _nn(declareModelData<std::shared_ptr<Moose::LibtorchArtificialNeuralNet>>("nn")),
      76          96 :     _standardize_input(getParam<bool>("standardize_input")),
      77          96 :     _standardize_output(getParam<bool>("standardize_output")),
      78          96 :     _input_standardizer(declareModelData<StochasticTools::Standardizer>("input_standardizer")),
      79         192 :     _output_standardizer(declareModelData<StochasticTools::Standardizer>("output_standardizer"))
      80             : {
      81             :   // Fixing the RNG seed to make sure every experiment is the same.
      82             :   // Otherwise sampling / stochastic gradient descent would be different.
      83          96 :   torch::manual_seed(getParam<unsigned int>("seed"));
      84             : 
      85          48 :   _optim_options.optimizer_type = "adam";
      86          96 :   _optim_options.learning_rate = getParam<Real>("learning_rate");
      87          96 :   _optim_options.num_epochs = getParam<unsigned int>("num_epochs");
      88          96 :   _optim_options.num_batches = getParam<unsigned int>("num_batches");
      89          96 :   _optim_options.rel_loss_tol = getParam<Real>("rel_loss_tol");
      90          96 :   _optim_options.print_loss = getParam<unsigned int>("print_epoch_loss") > 0;
      91          96 :   _optim_options.print_epoch_loss = getParam<unsigned int>("print_epoch_loss");
      92          96 :   _optim_options.parallel_processes = getParam<unsigned int>("max_processes");
      93          48 : }
      94             : 
      95             : void
      96          96 : LibtorchANNTrainer::preTrain()
      97             : {
      98             :   // Resize to number of sample points
      99          96 :   _flattened_data.clear();
     100          96 :   _flattened_response.clear();
     101          96 :   _flattened_data.reserve(getLocalSampleSize() * _n_dims);
     102          96 :   _flattened_response.reserve(getLocalSampleSize());
     103          96 : }
     104             : 
     105             : void
     106        3325 : LibtorchANNTrainer::train()
     107             : {
     108       13500 :   for (auto & p : _predictor_row)
     109       10175 :     _flattened_data.push_back(p);
     110             : 
     111        3325 :   _flattened_response.push_back(*_rval);
     112        3325 : }
     113             : 
     114             : void
     115          96 : LibtorchANNTrainer::postTrain()
     116             : {
     117          96 :   _communicator.allgather(_flattened_data);
     118          96 :   _communicator.allgather(_flattened_response);
     119             : 
     120             :   // Then, we create and load our Tensors
     121             :   unsigned int num_samples = _flattened_response.size();
     122          96 :   unsigned int num_inputs = _n_dims;
     123             : 
     124             :   // We create a neural net (for the definition of the net see the header file)
     125          96 :   _nn = std::make_shared<Moose::LibtorchArtificialNeuralNet>(
     126          96 :       _nn_filename, num_inputs, 1, _num_neurons_per_layer, _activation_function);
     127             : 
     128          96 :   if (_read_from_file)
     129             :     try
     130             :     {
     131           8 :       torch::load(_nn, _nn_filename);
     132           8 :       _console << "Loaded requested .pt file." << std::endl;
     133             :     }
     134           0 :     catch (const c10::Error & e)
     135             :     {
     136           0 :       mooseError("The requested pytorch file could not be loaded.\n", e.msg());
     137           0 :     }
     138             : 
     139             :   // The default data type in pytorch is float, while we use double in MOOSE.
     140             :   // Therefore, in some cases we have to convert Tensors to double.
     141          96 :   auto options = torch::TensorOptions().dtype(at::kDouble);
     142             :   torch::Tensor data_tensor =
     143         192 :       torch::from_blob(_flattened_data.data(), {num_samples, num_inputs}, options).to(at::kDouble);
     144             :   torch::Tensor response_tensor =
     145         192 :       torch::from_blob(_flattened_response.data(), {num_samples, 1}, options).to(at::kDouble);
     146             : 
     147             :   // We standardize the input/output pairs if the user requested it
     148          96 :   if (_standardize_input)
     149             :   {
     150           8 :     auto data_std_mean = torch::std_mean(data_tensor, 0);
     151             :     auto & data_std = std::get<0>(data_std_mean);
     152             :     auto & data_mean = std::get<1>(data_std_mean);
     153             : 
     154          16 :     data_tensor = (data_tensor - data_mean) / data_std;
     155             : 
     156             :     std::vector<Real> converted_data_mean;
     157           8 :     LibtorchUtils::tensorToVector(data_mean, converted_data_mean);
     158             :     std::vector<Real> converted_data_std;
     159           8 :     LibtorchUtils::tensorToVector(data_std, converted_data_std);
     160           8 :     _input_standardizer.set(converted_data_mean, converted_data_std);
     161             :   }
     162             :   else
     163          88 :     _input_standardizer.set(_n_dims);
     164             : 
     165          96 :   if (_standardize_output)
     166             :   {
     167           8 :     auto response_std_mean = torch::std_mean(response_tensor, 0);
     168             :     auto & response_std = std::get<0>(response_std_mean);
     169             :     auto & response_mean = std::get<1>(response_std_mean);
     170             : 
     171          16 :     response_tensor = (response_tensor - response_mean) / response_std;
     172             : 
     173             :     std::vector<Real> converted_response_mean;
     174           8 :     LibtorchUtils::tensorToVector(response_mean, converted_response_mean);
     175             :     std::vector<Real> converted_response_std;
     176           8 :     LibtorchUtils::tensorToVector(response_std, converted_response_std);
     177           8 :     _output_standardizer.set(converted_response_mean, converted_response_std);
     178             :   }
     179             :   else
     180          88 :     _output_standardizer.set(1);
     181             : 
     182             :   // We create a custom data set from our converted data
     183         192 :   Moose::LibtorchDataset my_data(data_tensor, response_tensor);
     184             : 
     185             :   // We create atrainer for our neral net and train it with the dataset
     186          96 :   Moose::LibtorchArtificialNeuralNetTrainer<> trainer(*_nn, comm());
     187          96 :   trainer.train(my_data, _optim_options);
     188         192 : }
     189             : 
     190             : #endif

Generated by: LCOV version 1.14