LCOV - code coverage report
Current view: top level - include/trainers - SurrogateTrainer.h (source / functions) Hit Total Coverage
Test: idaholab/moose stochastic_tools: f45d79 Lines: 17 24 70.8 %
Date: 2025-07-25 05:00:46 Functions: 9 16 56.2 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : #include "StochasticToolsApp.h"
      13             : #include "GeneralUserObject.h"
      14             : #include "LoadSurrogateDataAction.h"
      15             : #include "RestartableModelInterface.h"
      16             : 
      17             : #include "Sampler.h"
      18             : #include "StochasticToolsApp.h"
      19             : #include "SurrogateModelInterface.h"
      20             : #include "MooseRandom.h"
      21             : 
      22             : class TrainingDataBase;
      23             : template <typename T>
      24             : class TrainingData;
      25             : 
      26             : /**
      27             :  * This is the base trainer class whose main functionality is the API for declaring
      28             :  * model data. All trainer must at least derive from this. Unless a trainer needs
      29             :  * to perform its own loop through data, it is highly recommended to derive from
      30             :  * SurrogateTrainer.
      31             :  */
      32             : class SurrogateTrainerBase : public GeneralUserObject, public RestartableModelInterface
      33             : {
      34             : public:
      35             :   static InputParameters validParams();
      36             :   SurrogateTrainerBase(const InputParameters & parameters);
      37             : 
      38          16 :   virtual void initialize() {}                         // not required, but available
      39        2388 :   virtual void finalize() {}                           // not required, but available
      40           0 :   virtual void threadJoin(const UserObject &) final {} // GeneralUserObjects are not threaded
      41             : };
      42             : 
      43             : /**
      44             :  * This is the main trainer base class. The main purpose is to avoid a lot of code
      45             :  * duplication from performing sampler loops and dealing with distributed data. There
      46             :  * three functions that derived trainer should override: preTrain, train, and postTrain.
      47             :  * Derived class should also use the getTrainingData functionality, which provides a
      48             :  * refernce to vector reporter data in its current state within the sampler loop.
      49             :  *
      50             :  * The idea behind this is to emulate the element loop behaiviour in other MOOSE objects.
      51             :  * For instance, in a kernel, the value of _u corresponds to the solution in an element.
      52             :  * Here data referenced with getTrainingData will correspond to the the value of the
      53             :  * data in a sampler row.
      54             :  */
      55             : class SurrogateTrainer : public SurrogateTrainerBase, public SurrogateModelInterface
      56             : {
      57             : public:
      58             :   static InputParameters validParams();
      59             :   SurrogateTrainer(const InputParameters & parameters);
      60             : 
      61             :   virtual void initialize() final;
      62             :   virtual void execute() final;
      63        1190 :   virtual void finalize() final {}
      64             : 
      65             : protected:
      66             :   /*
      67             :    * Setup function called before sampler loop
      68             :    */
      69           0 :   virtual void preTrain() {}
      70             : 
      71             :   /*
      72             :    * Function needed to be overried, called during sampler loop
      73             :    */
      74           0 :   virtual void train() {}
      75             : 
      76             :   /*
      77             :    * Function called after sampler loop, used for mpi communication mainly
      78             :    */
      79           0 :   virtual void postTrain() {}
      80             : 
      81             :   // TRAINING_DATA_BEGIN
      82             : 
      83             :   /*
      84             :    * Get a reference to training data given a reporter name
      85             :    */
      86             :   template <typename T>
      87             :   const T & getTrainingData(const ReporterName & rname);
      88             : 
      89             :   /*
      90             :    * Get a reference to the sampler row data
      91             :    */
      92         244 :   const std::vector<Real> & getSamplerData() const { return _row_data; };
      93             : 
      94             :   /*
      95             :    * Get a reference to the predictor row data
      96             :    */
      97        1198 :   const std::vector<Real> & getPredictorData() const { return _predictor_data; };
      98             : 
      99             :   /*
     100             :    * Get current sample size (this is recalculated to reflect the number of skipped samples)
     101             :    */
     102       46352 :   unsigned int getCurrentSampleSize() const { return _current_sample_size; };
     103             : 
     104             :   /*
     105             :    * Get current local sample size (recalculated to reflect number of skipped samples)
     106             :    */
     107        1488 :   unsigned int getLocalSampleSize() const { return _local_sample_size; };
     108             : 
     109             :   // TRAINING_DATA_END
     110             : 
     111             :   /*
     112             :    * Evaluate CV error using _cv_surrogate and appropriate predictor row.
     113             :    */
     114             :   virtual std::vector<Real> evaluateModelError(const SurrogateModel & surr);
     115             : 
     116             :   // TRAINING_DATA_MEMBERS
     117             :   ///@{
     118             :   /// Sampler being used for training
     119             :   Sampler & _sampler;
     120             :   /// During training loop, this is the row index of the data
     121             :   dof_id_type _row;
     122             :   /// During training loop, this is the local row index of the data
     123             :   dof_id_type _local_row;
     124             :   /// Response value
     125             :   const Real * _rval;
     126             :   /// Vector response value
     127             :   const std::vector<Real> * _rvecval;
     128             :   /// Predictor values from reporters
     129             :   std::vector<const Real *> _pvals;
     130             :   /// Columns from sampler for predictors
     131             :   std::vector<unsigned int> _pcols;
     132             :   /// Dimension of predictor data - either _sampler.getNumberOfCols() or _pvals.size() + _pcols.size().
     133             :   unsigned int _n_dims;
     134             :   /// The number of outputs
     135             :   unsigned int & _n_outputs;
     136             :   ///@}
     137             :   // TRAINING_DATA_MEMBERS_END
     138             : 
     139             : private:
     140             :   /*
     141             :    * Called at the beginning of execute() to make sure values are set properly
     142             :    */
     143             :   void checkIntegrity() const;
     144             : 
     145             :   /*
     146             :    * Main model training method - called during crossValidate() and for final model training.
     147             :    */
     148             :   void executeTraining();
     149             : 
     150             :   /*
     151             :    * Call if cross-validation is turned on.
     152             :    */
     153             :   std::vector<Real> crossValidate();
     154             : 
     155             :   /*
     156             :    * Update predictor row (uses both Sampler and Reporter values, according to _pvals and _pcols)
     157             :    */
     158             :   void updatePredictorRow();
     159             : 
     160             :   /// Sampler data for the current row
     161             :   std::vector<Real> _row_data;
     162             : 
     163             :   /// Predictor data for current row - can be combination of Sampler and Reporter values.
     164             :   std::vector<Real> _predictor_data;
     165             : 
     166             :   /// Whether or not we are skipping samples that have unconverged solutions
     167             :   const bool _skip_unconverged;
     168             : 
     169             :   /// Whether or not the current sample has a converged solution
     170             :   const bool * _converged;
     171             : 
     172             :   /// Number of samples used to train the model.
     173             :   unsigned int _current_sample_size;
     174             : 
     175             :   /// Number of samples (locally) used to train the model.
     176             :   unsigned int _local_sample_size;
     177             : 
     178             :   /// Vector of reporter names and their corresponding values (to be filled by getTrainingData)
     179             :   std::unordered_map<ReporterName, std::shared_ptr<TrainingDataBase>> _training_data;
     180             : 
     181             :   /*
     182             :    * Variables related to cross validation.
     183             :    */
     184             :   ///@{
     185             :   /// Vector of indices to skip during executeTraining()
     186             :   std::vector<dof_id_type> _skip_indices;
     187             :   /// Type of cross validation to perform - for now, just 'none' (no CV) or 'k_fold'
     188             :   const MooseEnum & _cv_type;
     189             :   /// Number of splits (k) to split sampler data into.
     190             :   const unsigned int & _n_splits;
     191             :   /// Number of repeated trials of cross validation to perform.
     192             :   const unsigned int & _cv_n_trials;
     193             :   /// Seed used for _cv_generator.
     194             :   const unsigned int & _cv_seed;
     195             :   /// Random number generator used for shuffling sampler rows during splitting.
     196             :   MooseRandom _cv_generator;
     197             :   /// SurrogateModel used to evaluate model error relative to test points.
     198             :   const SurrogateModel * _cv_surrogate;
     199             :   /// Set to true if cross validation is being performed, controls behavior in execute().
     200             :   const bool _doing_cv;
     201             :   /// RMSE scores from each CV trial - can be grabbed by VPP or Reporter.
     202             :   std::vector<std::vector<Real>> & _cv_trial_scores;
     203             :   ///@}
     204             : };
     205             : 
     206             : template <typename T>
     207             : const T &
     208        1230 : SurrogateTrainer::getTrainingData(const ReporterName & rname)
     209             : {
     210             :   auto it = _training_data.find(rname);
     211        1230 :   if (it != _training_data.end())
     212             :   {
     213           0 :     auto data = std::dynamic_pointer_cast<TrainingData<T>>(it->second);
     214           0 :     if (!data)
     215           0 :       mooseError("Reporter value ", rname, " already exists but is of different type.");
     216             :     return data->get();
     217             :   }
     218             :   else
     219             :   {
     220        1230 :     const std::vector<T> & rval = getReporterValueByName<std::vector<T>>(rname);
     221        1230 :     _training_data[rname] = std::make_shared<TrainingData<T>>(rval);
     222        2460 :     return std::dynamic_pointer_cast<TrainingData<T>>(_training_data[rname])->get();
     223             :   }
     224             : }
     225             : 
     226             : class TrainingDataBase
     227             : {
     228             : public:
     229        1230 :   TrainingDataBase() : _is_distributed(false) {}
     230             : 
     231             :   virtual ~TrainingDataBase() = default;
     232             : 
     233             :   virtual dof_id_type size() const = 0;
     234             :   virtual void setCurrentIndex(dof_id_type index) = 0;
     235             :   bool & isDistributed() { return _is_distributed; }
     236             : 
     237             : protected:
     238             :   bool _is_distributed;
     239             : };
     240             : 
     241             : template <typename T>
     242             : class TrainingData : public TrainingDataBase
     243             : {
     244             : public:
     245        1230 :   TrainingData(const std::vector<T> & vector) : _vector(vector) {}
     246             : 
     247        1850 :   virtual dof_id_type size() const override { return _vector.size(); }
     248      246015 :   virtual void setCurrentIndex(dof_id_type index) override { _value = _vector[index]; }
     249             : 
     250        1230 :   const T & get() const { return _value; }
     251             : 
     252             : private:
     253             :   const std::vector<T> & _vector;
     254             :   T _value;
     255             : };

Generated by: LCOV version 1.14