https://mooseframework.inl.gov
SurrogateTrainer.h
Go to the documentation of this file.
1 //* This file is part of the MOOSE framework
2 //* https://mooseframework.inl.gov
3 //*
4 //* All rights reserved, see COPYRIGHT for full restrictions
5 //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 //*
7 //* Licensed under LGPL 2.1, please see LICENSE for details
8 //* https://www.gnu.org/licenses/lgpl-2.1.html
9 
10 #pragma once
11 
12 #include "StochasticToolsApp.h"
13 #include "GeneralUserObject.h"
16 
17 #include "Sampler.h"
18 #include "StochasticToolsApp.h"
20 #include "MooseRandom.h"
21 
22 class TrainingDataBase;
23 template <typename T>
25 
33 {
34 public:
  // NOTE(review): this extracted copy is missing the class header and original lines 35-36;
  // the symbol index lists a SurrogateTrainerBase(const InputParameters &) constructor and a
  // static validParams(), which presumably belong in that gap - restore from the real header.
37 
  /// Optional hook; the default implementation does nothing.
38  virtual void initialize() {} // not required, but available
  /// Optional hook; the default implementation does nothing.
39  virtual void finalize() {} // not required, but available
  /// Sealed no-op, since (per the trailing note) GeneralUserObjects are not threaded.
40  virtual void threadJoin(const UserObject &) final {} // GeneralUserObjects are not threaded
41 };
42 
56 {
57 public:
  // NOTE(review): this extracted copy is missing the class header and original lines 58-59;
  // the symbol index lists a SurrogateTrainer(const InputParameters &) constructor and a
  // static validParams(), which presumably belong in that gap.
60 
  /// Sealed (final): derived trainers cannot override it.
61  virtual void initialize() final;
  /// Sealed (final). Per the private-member docs below, it begins with checkIntegrity()
  /// and its behavior is controlled by _doing_cv (whether cross validation is performed).
62  virtual void execute() final;
  /// Sealed (final) no-op.
63  virtual void finalize() final {}
64 
65 protected:
66  /*
67  * Setup function called before the sampler loop (default: does nothing)
68  */
69  virtual void preTrain() {}
70 
71  /*
72  * Function intended to be overridden by derived trainers; called during the sampler loop
   * (default: does nothing)
73  */
74  virtual void train() {}
75 
76  /*
77  * Function called after the sampler loop, mainly used for MPI communication
   * (default: does nothing)
78  */
79  virtual void postTrain() {}
80 
81  // TRAINING_DATA_BEGIN
82 
83  /*
84  * Get a reference to training data given a reporter name.
   * The reporter must hold a std::vector<T>; the returned reference is the wrapper's
   * current value, i.e. the element selected by the most recent setCurrentIndex().
   * Calls mooseError if the name was already requested with a different T.
85  */
86  template <typename T>
87  const T & getTrainingData(const ReporterName & rname);
88 
89  /*
90  * Get a reference to the sampler row data
91  */
92  const std::vector<Real> & getSamplerData() const { return _row_data; };
93 
94  /*
95  * Get a reference to the predictor row data
96  */
97  const std::vector<Real> & getPredictorData() const { return _predictor_data; };
98 
99  /*
100  * Get current sample size (this is recalculated to reflect the number of skipped samples)
101  */
102  unsigned int getCurrentSampleSize() const { return _current_sample_size; };
103 
104  /*
105  * Get current local sample size (recalculated to reflect number of skipped samples)
106  */
107  unsigned int getLocalSampleSize() const { return _local_sample_size; };
108 
109  // TRAINING_DATA_END
110 
111  /*
112  * Evaluate CV error using _cv_surrogate and appropriate predictor row.
113  */
114  virtual std::vector<Real> evaluateModelError(const SurrogateModel & surr);
115 
116  // TRAINING_DATA_MEMBERS
  // NOTE(review): original lines 117-124 are missing from this copy; the symbol index lists
  // dof_id_type _row and _local_row (the global / local row index of the data during the
  // training loop), which presumably live there.
  /// Response value.
125  const Real * _rval;
  /// Vector response value.
127  const std::vector<Real> * _rvecval;
  /// Predictor values from reporters.
129  std::vector<const Real *> _pvals;
  /// Columns from sampler for predictors.
131  std::vector<unsigned int> _pcols;
  /// Dimension of predictor data - either _sampler.getNumberOfCols() or _pvals.size() + _pcols.size().
133  unsigned int _n_dims;
  /// The number of outputs.
135  unsigned int & _n_outputs;
137  // TRAINING_DATA_MEMBERS_END
138 
139 private:
140  /*
141  * Called at the beginning of execute() to make sure values are set properly
142  */
143  void checkIntegrity() const;
144 
145  /*
146  * Main model training method - called during crossValidate() and for final model training.
147  */
148  void executeTraining();
149 
150  /*
151  * Call if cross-validation is turned on.
   * Presumably returns the RMSE scores of one trial (cf. _cv_trial_scores) - TODO confirm.
152  */
153  std::vector<Real> crossValidate();
154 
155  /*
156  * Update predictor row (uses both Sampler and Reporter values, according to _pvals and _pcols)
157  */
158  void updatePredictorRow();
159 
  /// Sampler data for the current row.
161  std::vector<Real> _row_data;
162 
  /// Predictor data for current row - can be combination of Sampler and Reporter values.
164  std::vector<Real> _predictor_data;
165 
  /// Whether or not we are skipping samples that have unconverged solutions.
167  const bool _skip_unconverged;
168 
  /// Whether or not the current sample has a converged solution.
170  const bool * _converged;
171 
  /// Number of samples used to train the model.
173  unsigned int _current_sample_size;
174 
  /// Number of samples (locally) used to train the model.
176  unsigned int _local_sample_size;
177 
  /// Reporter name -> type-erased training data, populated lazily by getTrainingData().
179  std::unordered_map<ReporterName, std::shared_ptr<TrainingDataBase>> _training_data;
180 
181  /*
182  * Variables related to cross validation.
183  */
  // NOTE(review): several original lines in this section are missing from this copy (gaps in
  // the embedded numbering); the symbol index lists _cv_type, _cv_generator and _cv_surrogate
  // members that presumably sit in those gaps.
  /// Presumably the sample indices excluded from the current training pass during CV - TODO confirm.
186  std::vector<dof_id_type> _skip_indices;
  /// Number of splits (k) to split sampler data into.
190  const unsigned int & _n_splits;
  /// Number of repeated trials of cross validation to perform.
192  const unsigned int & _cv_n_trials;
  /// Seed used for _cv_generator.
194  const unsigned int & _cv_seed;
  /// Set to true if cross validation is being performed, controls behavior in execute().
200  const bool _doing_cv;
  /// RMSE scores from each CV trial - can be grabbed by VPP or Reporter.
202  std::vector<std::vector<Real>> & _cv_trial_scores;
204 };
205 
206 template <typename T>
207 const T &
209 {
210  auto it = _training_data.find(rname);
211  if (it != _training_data.end())
212  {
213  auto data = std::dynamic_pointer_cast<TrainingData<T>>(it->second);
214  if (!data)
215  mooseError("Reporter value ", rname, " already exists but is of different type.");
216  return data->get();
217  }
218  else
219  {
220  const std::vector<T> & rval = getReporterValueByName<std::vector<T>>(rname);
221  _training_data[rname] = std::make_shared<TrainingData<T>>(rval);
222  return std::dynamic_pointer_cast<TrainingData<T>>(_training_data[rname])->get();
223  }
224 }
225 
227 {
228 public:
  // Type-erased interface for one reporter's training data; SurrogateTrainer keeps these in
  // _training_data keyed by reporter name and recovers the typed TrainingData<T> via
  // std::dynamic_pointer_cast.
  // NOTE(review): this copy is missing the class header (line 226), line 229 and the protected
  // member at line 238; isDistributed() implies a `bool _is_distributed` member there.
230 
  /// Polymorphic base: defaulted virtual destructor.
231  virtual ~TrainingDataBase() = default;
232 
  /// Number of entries in the underlying data.
233  virtual dof_id_type size() const = 0;
  /// Selects which entry subsequent reads of the data refer to.
234  virtual void setCurrentIndex(dof_id_type index) = 0;
  /// Mutable access to the distributed flag (callers may read or assign through the reference).
  /// NOTE(review): the meaning of "distributed" is not visible here - presumably whether the
  /// underlying vector holds only this rank's local samples; confirm.
235  bool & isDistributed() { return _is_distributed; }
236 
237 protected:
239 };
240 
241 template <typename T>
242 class TrainingData : public TrainingDataBase
243 {
244 public:
245  TrainingData(const std::vector<T> & vector) : _vector(vector) {}
246 
247  virtual dof_id_type size() const override { return _vector.size(); }
248  virtual void setCurrentIndex(dof_id_type index) override { _value = _vector[index]; }
249 
250  const T & get() const { return _value; }
251 
252 private:
253  const std::vector<T> & _vector;
255 };
SurrogateTrainerBase(const InputParameters &parameters)
virtual void initialize() final
unsigned int getCurrentSampleSize() const
const bool _doing_cv
Set to true if cross validation is being performed, controls behavior in execute().
virtual void finalize() final
const Real * _rval
Response value.
unsigned int _n_dims
Dimension of predictor data - either _sampler.getNumberOfCols() or _pvals.size() + _pcols.size().
virtual void initialize()
TrainingData(const std::vector< T > &vector)
const unsigned int & _cv_n_trials
Number of repeated trials of cross validation to perform.
const std::vector< Real > * _rvecval
Vector response value.
std::vector< unsigned int > _pcols
Columns from sampler for predictors.
std::vector< const Real * > _pvals
Predictor values from reporters.
MooseRandom _cv_generator
Random number generator used for shuffling sampler rows during splitting.
std::vector< std::vector< Real > > & _cv_trial_scores
RMSE scores from each CV trial - can be grabbed by VPP or Reporter.
virtual void setCurrentIndex(dof_id_type index)=0
unsigned int & _n_outputs
The number of outputs.
static InputParameters validParams()
void checkIntegrity() const
std::vector< Real > crossValidate()
dof_id_type _row
During training loop, this is the row index of the data.
virtual void train()
dof_id_type _local_row
During training loop, this is the local row index of the data.
virtual void postTrain()
std::vector< Real > _predictor_data
Predictor data for current row - can be combination of Sampler and Reporter values.
const std::vector< Real > & getSamplerData() const
std::vector< dof_id_type > _skip_indices
std::unordered_map< ReporterName, std::shared_ptr< TrainingDataBase > > _training_data
Map of reporter names to their corresponding training data (filled by getTrainingData).
virtual void finalize()
virtual void threadJoin(const UserObject &) final
const T & getTrainingData(const ReporterName &rname)
virtual dof_id_type size() const =0
virtual dof_id_type size() const override
virtual void preTrain()
virtual std::vector< Real > evaluateModelError(const SurrogateModel &surr)
const std::vector< T > & _vector
unsigned int getLocalSampleSize() const
DIE A HORRIBLE DEATH HERE typedef LIBMESH_DEFAULT_SCALAR_TYPE Real
This is the main trainer base class.
Interface for objects that need to use samplers.
const std::vector< Real > & getPredictorData() const
std::vector< Real > _row_data
Sampler data for the current row.
unsigned int _local_sample_size
Number of samples (locally) used to train the model.
void mooseError(Args &&... args) const
const InputParameters & parameters() const
unsigned int _current_sample_size
Number of samples used to train the model.
const unsigned int & _n_splits
Number of splits (k) to split sampler data into.
const bool * _converged
Whether or not the current sample has a converged solution.
SurrogateTrainer(const InputParameters &parameters)
const SurrogateModel * _cv_surrogate
SurrogateModel used to evaluate model error relative to test points.
static InputParameters validParams()
This is the base trainer class whose main functionality is the API for declaring model data...
virtual void setCurrentIndex(dof_id_type index) override
const bool _skip_unconverged
Whether or not we are skipping samples that have unconverged solutions.
const unsigned int & _cv_seed
Seed used for _cv_generator.
virtual ~TrainingDataBase()=default
const MooseEnum & _cv_type
Type of cross validation to perform - for now, just 'none' (no CV) or 'k_fold'.
uint8_t dof_id_type
An interface class which manages the model data save and load functionalities from moose objects (suc...
virtual void execute() final