Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #pragma once 11 : 12 : #include "StochasticToolsApp.h" 13 : #include "GeneralUserObject.h" 14 : #include "LoadSurrogateDataAction.h" 15 : #include "RestartableModelInterface.h" 16 : 17 : #include "Sampler.h" 18 : #include "StochasticToolsApp.h" 19 : #include "SurrogateModelInterface.h" 20 : #include "MooseRandom.h" 21 : 22 : class TrainingDataBase; 23 : template <typename T> 24 : class TrainingData; 25 : 26 : /** 27 : * This is the base trainer class whose main functionality is the API for declaring 28 : * model data. All trainer must at least derive from this. Unless a trainer needs 29 : * to perform its own loop through data, it is highly recommended to derive from 30 : * SurrogateTrainer. 31 : */ 32 : class SurrogateTrainerBase : public GeneralUserObject, public RestartableModelInterface 33 : { 34 : public: 35 : static InputParameters validParams(); 36 : SurrogateTrainerBase(const InputParameters & parameters); 37 : 38 16 : virtual void initialize() {} // not required, but available 39 2388 : virtual void finalize() {} // not required, but available 40 0 : virtual void threadJoin(const UserObject &) final {} // GeneralUserObjects are not threaded 41 : }; 42 : 43 : /** 44 : * This is the main trainer base class. The main purpose is to avoid a lot of code 45 : * duplication from performing sampler loops and dealing with distributed data. There 46 : * three functions that derived trainer should override: preTrain, train, and postTrain. 47 : * Derived class should also use the getTrainingData functionality, which provides a 48 : * refernce to vector reporter data in its current state within the sampler loop. 49 : * 50 : * The idea behind this is to emulate the element loop behaiviour in other MOOSE objects. 51 : * For instance, in a kernel, the value of _u corresponds to the solution in an element. 52 : * Here data referenced with getTrainingData will correspond to the the value of the 53 : * data in a sampler row. 54 : */ 55 : class SurrogateTrainer : public SurrogateTrainerBase, public SurrogateModelInterface 56 : { 57 : public: 58 : static InputParameters validParams(); 59 : SurrogateTrainer(const InputParameters & parameters); 60 : 61 : virtual void initialize() final; 62 : virtual void execute() final; 63 1190 : virtual void finalize() final {} 64 : 65 : protected: 66 : /* 67 : * Setup function called before sampler loop 68 : */ 69 0 : virtual void preTrain() {} 70 : 71 : /* 72 : * Function needed to be overried, called during sampler loop 73 : */ 74 0 : virtual void train() {} 75 : 76 : /* 77 : * Function called after sampler loop, used for mpi communication mainly 78 : */ 79 0 : virtual void postTrain() {} 80 : 81 : // TRAINING_DATA_BEGIN 82 : 83 : /* 84 : * Get a reference to training data given a reporter name 85 : */ 86 : template <typename T> 87 : const T & getTrainingData(const ReporterName & rname); 88 : 89 : /* 90 : * Get a reference to the sampler row data 91 : */ 92 244 : const std::vector<Real> & getSamplerData() const { return _row_data; }; 93 : 94 : /* 95 : * Get a reference to the predictor row data 96 : */ 97 1198 : const std::vector<Real> & getPredictorData() const { return _predictor_data; }; 98 : 99 : /* 100 : * Get current sample size (this is recalculated to reflect the number of skipped samples) 101 : */ 102 46352 : unsigned int getCurrentSampleSize() const { return _current_sample_size; }; 103 : 104 : /* 105 : * Get current local sample size (recalculated to reflect number of skipped samples) 106 : */ 107 1488 : unsigned int getLocalSampleSize() const { return _local_sample_size; }; 108 : 109 : // TRAINING_DATA_END 110 : 111 : /* 112 : * Evaluate CV error using _cv_surrogate and appropriate predictor row. 113 : */ 114 : virtual std::vector<Real> evaluateModelError(const SurrogateModel & surr); 115 : 116 : // TRAINING_DATA_MEMBERS 117 : ///@{ 118 : /// Sampler being used for training 119 : Sampler & _sampler; 120 : /// During training loop, this is the row index of the data 121 : dof_id_type _row; 122 : /// During training loop, this is the local row index of the data 123 : dof_id_type _local_row; 124 : /// Response value 125 : const Real * _rval; 126 : /// Vector response value 127 : const std::vector<Real> * _rvecval; 128 : /// Predictor values from reporters 129 : std::vector<const Real *> _pvals; 130 : /// Columns from sampler for predictors 131 : std::vector<unsigned int> _pcols; 132 : /// Dimension of predictor data - either _sampler.getNumberOfCols() or _pvals.size() + _pcols.size(). 133 : unsigned int _n_dims; 134 : /// The number of outputs 135 : unsigned int & _n_outputs; 136 : ///@} 137 : // TRAINING_DATA_MEMBERS_END 138 : 139 : private: 140 : /* 141 : * Called at the beginning of execute() to make sure values are set properly 142 : */ 143 : void checkIntegrity() const; 144 : 145 : /* 146 : * Main model training method - called during crossValidate() and for final model training. 147 : */ 148 : void executeTraining(); 149 : 150 : /* 151 : * Call if cross-validation is turned on. 152 : */ 153 : std::vector<Real> crossValidate(); 154 : 155 : /* 156 : * Update predictor row (uses both Sampler and Reporter values, according to _pvals and _pcols) 157 : */ 158 : void updatePredictorRow(); 159 : 160 : /// Sampler data for the current row 161 : std::vector<Real> _row_data; 162 : 163 : /// Predictor data for current row - can be combination of Sampler and Reporter values. 164 : std::vector<Real> _predictor_data; 165 : 166 : /// Whether or not we are skipping samples that have unconverged solutions 167 : const bool _skip_unconverged; 168 : 169 : /// Whether or not the current sample has a converged solution 170 : const bool * _converged; 171 : 172 : /// Number of samples used to train the model. 173 : unsigned int _current_sample_size; 174 : 175 : /// Number of samples (locally) used to train the model. 176 : unsigned int _local_sample_size; 177 : 178 : /// Vector of reporter names and their corresponding values (to be filled by getTrainingData) 179 : std::unordered_map<ReporterName, std::shared_ptr<TrainingDataBase>> _training_data; 180 : 181 : /* 182 : * Variables related to cross validation. 183 : */ 184 : ///@{ 185 : /// Vector of indices to skip during executeTraining() 186 : std::vector<dof_id_type> _skip_indices; 187 : /// Type of cross validation to perform - for now, just 'none' (no CV) or 'k_fold' 188 : const MooseEnum & _cv_type; 189 : /// Number of splits (k) to split sampler data into. 190 : const unsigned int & _n_splits; 191 : /// Number of repeated trials of cross validation to perform. 192 : const unsigned int & _cv_n_trials; 193 : /// Seed used for _cv_generator. 194 : const unsigned int & _cv_seed; 195 : /// Random number generator used for shuffling sampler rows during splitting. 196 : MooseRandom _cv_generator; 197 : /// SurrogateModel used to evaluate model error relative to test points. 198 : const SurrogateModel * _cv_surrogate; 199 : /// Set to true if cross validation is being performed, controls behavior in execute(). 200 : const bool _doing_cv; 201 : /// RMSE scores from each CV trial - can be grabbed by VPP or Reporter. 202 : std::vector<std::vector<Real>> & _cv_trial_scores; 203 : ///@} 204 : }; 205 : 206 : template <typename T> 207 : const T & 208 1230 : SurrogateTrainer::getTrainingData(const ReporterName & rname) 209 : { 210 : auto it = _training_data.find(rname); 211 1230 : if (it != _training_data.end()) 212 : { 213 0 : auto data = std::dynamic_pointer_cast<TrainingData<T>>(it->second); 214 0 : if (!data) 215 0 : mooseError("Reporter value ", rname, " already exists but is of different type."); 216 : return data->get(); 217 : } 218 : else 219 : { 220 1230 : const std::vector<T> & rval = getReporterValueByName<std::vector<T>>(rname); 221 1230 : _training_data[rname] = std::make_shared<TrainingData<T>>(rval); 222 2460 : return std::dynamic_pointer_cast<TrainingData<T>>(_training_data[rname])->get(); 223 : } 224 : } 225 : 226 : class TrainingDataBase 227 : { 228 : public: 229 1230 : TrainingDataBase() : _is_distributed(false) {} 230 : 231 : virtual ~TrainingDataBase() = default; 232 : 233 : virtual dof_id_type size() const = 0; 234 : virtual void setCurrentIndex(dof_id_type index) = 0; 235 : bool & isDistributed() { return _is_distributed; } 236 : 237 : protected: 238 : bool _is_distributed; 239 : }; 240 : 241 : template <typename T> 242 : class TrainingData : public TrainingDataBase 243 : { 244 : public: 245 1230 : TrainingData(const std::vector<T> & vector) : _vector(vector) {} 246 : 247 1850 : virtual dof_id_type size() const override { return _vector.size(); } 248 246015 : virtual void setCurrentIndex(dof_id_type index) override { _value = _vector[index]; } 249 : 250 1230 : const T & get() const { return _value; } 251 : 252 : private: 253 : const std::vector<T> & _vector; 254 : T _value; 255 : };