Line data Source code
1 : //* This file is part of the MOOSE framework
2 : //* https://mooseframework.inl.gov
3 : //*
4 : //* All rights reserved, see COPYRIGHT for full restrictions
5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 : //*
7 : //* Licensed under LGPL 2.1, please see LICENSE for details
8 : //* https://www.gnu.org/licenses/lgpl-2.1.html
9 :
#include "SurrogateTrainer.h"
#include "SurrogateModel.h"
#include "Sampler.h"
#include "StochasticToolsApp.h"
#include "MooseRandom.h"
#include "Shuffle.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
17 :
18 : InputParameters
19 3052 : SurrogateTrainerBase::validParams()
20 : {
21 3052 : InputParameters params = GeneralUserObject::validParams();
22 3052 : params += RestartableModelInterface::validParams();
23 3052 : params.registerBase("SurrogateTrainer");
24 3052 : return params;
25 0 : }
26 :
27 1536 : SurrogateTrainerBase::SurrogateTrainerBase(const InputParameters & parameters)
28 : : GeneralUserObject(parameters),
29 1536 : RestartableModelInterface(*this, /*read_only=*/false, _type + "_" + name())
30 : {
31 1536 : }
32 :
33 : InputParameters
34 2544 : SurrogateTrainer::validParams()
35 : {
36 2544 : InputParameters params = SurrogateTrainerBase::validParams();
37 5088 : params.addRequiredParam<SamplerName>("sampler",
38 : "Sampler used to create predictor and response data.");
39 5088 : params.addParam<ReporterName>(
40 : "converged_reporter",
41 : "Reporter value used to determine if a sample's multiapp solve converged.");
42 5088 : params.addParam<bool>("skip_unconverged_samples",
43 5088 : false,
44 : "True to skip samples where the multiapp did not converge, "
45 : "'stochastic_reporter' is required to do this.");
46 :
47 : // Common Training Data
48 5088 : MooseEnum data_type("real=0 vector_real=1", "real");
49 5088 : params.addRequiredParam<ReporterName>(
50 : "response",
51 : "Reporter value of response results, can be vpp with <vpp_name>/<vector_name> or sampler "
52 : "column with 'sampler/col_<index>'.");
53 5088 : params.addParam<MooseEnum>("response_type", data_type, "Response data type.");
54 2544 : params.addParam<std::vector<ReporterName>>(
55 : "predictors",
56 2544 : std::vector<ReporterName>(),
57 : "Reporter values used as the independent random variables, If 'predictors' and "
58 : "'predictor_cols' are both empty, all sampler columns are used.");
59 2544 : params.addParam<std::vector<unsigned int>>(
60 : "predictor_cols",
61 2544 : std::vector<unsigned int>(),
62 : "Sampler columns used as the independent random variables, If 'predictors' and "
63 : "'predictor_cols' are both empty, all sampler columns are used.");
64 : // End Common Training Data
65 :
66 5088 : MooseEnum cv_type("none=0 k_fold=1", "none");
67 5088 : params.addParam<MooseEnum>(
68 : "cv_type",
69 : cv_type,
70 : "Cross-validation method to use for dataset. Options are 'none' or 'k_fold'.");
71 7632 : params.addRangeCheckedParam<unsigned int>(
72 5088 : "cv_splits", 10, "cv_splits > 1", "Number of splits (k) to use in k-fold cross-validation.");
73 5088 : params.addParam<UserObjectName>("cv_surrogate",
74 : "Name of Surrogate object used for model cross-validation.");
75 5088 : params.addParam<unsigned int>(
76 5088 : "cv_n_trials", 1, "Number of repeated trials of cross-validation to perform.");
77 5088 : params.addParam<unsigned int>("cv_seed",
78 5088 : std::numeric_limits<unsigned int>::max(),
79 : "Seed used to initialize random number generator for data "
80 : "splitting during cross validation.");
81 :
82 2544 : return params;
83 2544 : }
84 :
85 1276 : SurrogateTrainer::SurrogateTrainer(const InputParameters & parameters)
86 : : SurrogateTrainerBase(parameters),
87 : SurrogateModelInterface(this),
88 1276 : _sampler(getSampler("sampler")),
89 1276 : _rval(nullptr),
90 1276 : _rvecval(nullptr),
91 2552 : _pvals(getParam<std::vector<ReporterName>>("predictors").size()),
92 2552 : _pcols(getParam<std::vector<unsigned int>>("predictor_cols")),
93 2552 : _n_outputs(declareModelData<unsigned int>("_n_outputs", 1)),
94 1276 : _row_data(_sampler.getNumberOfCols()),
95 2552 : _skip_unconverged(getParam<bool>("skip_unconverged_samples")),
96 1276 : _converged(nullptr),
97 2552 : _cv_type(getParam<MooseEnum>("cv_type")),
98 2552 : _n_splits(getParam<unsigned int>("cv_splits")),
99 2552 : _cv_n_trials(getParam<unsigned int>("cv_n_trials")),
100 2552 : _cv_seed(getParam<unsigned int>("cv_seed")),
101 1276 : _doing_cv(_cv_type != "none"),
102 3828 : _cv_trial_scores(declareModelData<std::vector<std::vector<Real>>>("cv_scores"))
103 : {
104 1276 : if (_skip_unconverged)
105 : {
106 0 : if (!isParamValid("converged_reporter"))
107 0 : paramError("skip_unconverged_samples",
108 : "'converged_reporter' needs to be specified to skip unconverged sample.");
109 0 : _converged = &getTrainingData<bool>(getParam<ReporterName>("converged_reporter"));
110 : }
111 :
112 1276 : if (_doing_cv)
113 : {
114 228 : if (!isParamValid("cv_surrogate"))
115 4 : paramError("cv_type",
116 : "To perform cross-validation, the option cv_surrogate needs to be specified",
117 : " to provide a Surrogate object for training and evaluation.");
118 :
119 110 : if (_n_splits > _sampler.getNumberOfRows())
120 0 : paramError("cv_splits",
121 : "The specified number of splits (cv_splits = ",
122 0 : _n_splits,
123 : ")",
124 : " exceeds the number of rows in Sampler '",
125 : getParam<SamplerName>("sampler"),
126 : "'");
127 :
128 110 : _cv_generator.seed(0, _cv_seed);
129 : }
130 :
131 : // Get TrainingData for responses and predictors
132 3816 : if (getParam<MooseEnum>("response_type") == 0)
133 2286 : _rval = &getTrainingData<Real>(getParam<ReporterName>("response"));
134 387 : else if (getParam<MooseEnum>("response_type") == 1)
135 258 : _rvecval = &getTrainingData<std::vector<Real>>(getParam<ReporterName>("response"));
136 :
137 1272 : const auto & pnames = getParam<std::vector<ReporterName>>("predictors");
138 1306 : for (unsigned int i = 0; i < pnames.size(); ++i)
139 34 : _pvals[i] = &getTrainingData<Real>(pnames[i]);
140 :
141 : // If predictors and predictor_cols are empty, use all sampler columns
142 1272 : if (_pvals.empty() && _pcols.empty())
143 : {
144 1238 : _pcols.resize(_sampler.getNumberOfCols());
145 : std::iota(_pcols.begin(), _pcols.end(), 0);
146 : }
147 1272 : _n_dims = _pvals.size() + _pcols.size();
148 :
149 1272 : _predictor_data.resize(_n_dims);
150 1272 : }
151 :
152 : void
153 1268 : SurrogateTrainer::initialize()
154 : {
155 : // Figure out if data is distributed
156 2570 : for (auto & pair : _training_data)
157 : {
158 1302 : const ReporterName & name = pair.first;
159 : TrainingDataBase & data = *pair.second;
160 :
161 1302 : const auto & mode = _fe_problem.getReporterData().getReporterMode(name);
162 1302 : if (mode == REPORTER_MODE_DISTRIBUTED || (mode == REPORTER_MODE_ROOT && processor_id() != 0))
163 962 : data.isDistributed() = true;
164 578 : else if (mode == REPORTER_MODE_REPLICATED ||
165 476 : (mode == REPORTER_MODE_ROOT && processor_id() == 0))
166 340 : data.isDistributed() = false;
167 : else
168 0 : mooseError("Predictor reporter value ", name, " is not of supported mode.");
169 : }
170 :
171 1268 : if (_doing_cv)
172 110 : _cv_surrogate = &getSurrogateModel("cv_surrogate");
173 1268 : }
174 :
175 : void
176 1268 : SurrogateTrainer::execute()
177 : {
178 1268 : if (_doing_cv)
179 440 : for (const auto & trial : make_range(_cv_n_trials))
180 : {
181 330 : std::vector<Real> trial_score = crossValidate();
182 :
183 : // Expand _cv_trial_scores with more columns if necessary, then insert values.
184 593 : for (unsigned int r = _cv_trial_scores.size(); r < trial_score.size(); ++r)
185 526 : _cv_trial_scores.push_back(std::vector<Real>(_cv_n_trials, 0.0));
186 1119 : for (auto r : make_range(trial_score.size()))
187 789 : _cv_trial_scores[r][trial] = trial_score[r];
188 330 : }
189 :
190 1268 : _current_sample_size = _sampler.getNumberOfRows();
191 1268 : _local_sample_size = _sampler.getNumberOfLocalRows();
192 1268 : executeTraining();
193 1264 : }
194 :
195 : void
196 1928 : SurrogateTrainer::checkIntegrity() const
197 : {
198 : // Check that the number of sampler columns hasn't changed
199 1928 : if (_row_data.size() != _sampler.getNumberOfCols())
200 0 : mooseError("Number of sampler columns has changed.");
201 :
202 : // Check that training data is correctly sized
203 3886 : for (auto & pair : _training_data)
204 : {
205 1962 : dof_id_type rsize = pair.second->size();
206 : dof_id_type nrow =
207 1962 : pair.second->isDistributed() ? _sampler.getNumberOfLocalRows() : _sampler.getNumberOfRows();
208 1962 : if (rsize != nrow)
209 4 : mooseError("Reporter value ",
210 4 : pair.first,
211 : " of size ",
212 : rsize,
213 : " does not match sampler size (",
214 : nrow,
215 : ").");
216 : }
217 1924 : }
218 :
219 : void
220 1928 : SurrogateTrainer::executeTraining()
221 : {
222 1928 : checkIntegrity();
223 1924 : _row = _sampler.getLocalRowBegin();
224 1924 : _local_row = 0;
225 :
226 1924 : preTrain();
227 :
228 255187 : for (_row = _sampler.getLocalRowBegin(); _row < _sampler.getLocalRowEnd(); ++_row)
229 : {
230 : // Need to do this manually in order to keep the iterators valid
231 253263 : const std::vector<Real> data = _sampler.getNextLocalRow();
232 1556274 : for (unsigned int i = 0; i < _row_data.size(); ++i)
233 1303011 : _row_data[i] = data[i];
234 :
235 : // Set training data
236 518417 : for (auto & pair : _training_data)
237 265154 : pair.second->setCurrentIndex((pair.second->isDistributed() ? _local_row : _row));
238 :
239 253263 : updatePredictorRow();
240 :
241 253263 : if ((!_skip_unconverged || *_converged) &&
242 253263 : std::find(_skip_indices.begin(), _skip_indices.end(), _row) == _skip_indices.end())
243 248163 : train();
244 :
245 253263 : _local_row++;
246 253263 : }
247 :
248 1924 : postTrain();
249 1924 : }
250 :
251 : std::vector<Real>
252 330 : SurrogateTrainer::crossValidate()
253 : {
254 330 : std::vector<Real> cv_score(1, 0.0);
255 :
256 : // Get skipped indices for each split
257 330 : dof_id_type n_rows = _sampler.getNumberOfRows();
258 : std::vector<std::vector<dof_id_type>> split_indices;
259 330 : if (processor_id() == 0)
260 : {
261 213 : std::vector<dof_id_type> indices_flat(n_rows);
262 : std::iota(indices_flat.begin(), indices_flat.end(), 0);
263 213 : MooseUtils::shuffle(indices_flat, _cv_generator, 0);
264 :
265 213 : split_indices.resize(_n_splits);
266 639 : for (const auto & k : make_range(_n_splits))
267 : {
268 852 : const dof_id_type num_ind = n_rows / _n_splits + (k < (n_rows % _n_splits) ? 1 : 0);
269 426 : split_indices[k].insert(split_indices[k].begin(),
270 : std::make_move_iterator(indices_flat.begin()),
271 : std::make_move_iterator(indices_flat.begin() + num_ind));
272 426 : std::sort(split_indices[k].begin(), split_indices[k].end());
273 : indices_flat.erase(indices_flat.begin(), indices_flat.begin() + num_ind);
274 : }
275 213 : }
276 :
277 : std::vector<dof_id_type> split_ids_buffer;
278 990 : for (const auto & k : make_range(_n_splits))
279 : {
280 660 : if (processor_id() == 0)
281 426 : split_ids_buffer = split_indices[k];
282 660 : _communicator.broadcast(split_ids_buffer, 0);
283 :
284 660 : _current_sample_size = _sampler.getNumberOfRows() - split_ids_buffer.size();
285 :
286 660 : auto first = std::lower_bound(
287 660 : split_ids_buffer.begin(), split_ids_buffer.end(), _sampler.getLocalRowBegin());
288 660 : auto last = std::upper_bound(
289 660 : split_ids_buffer.begin(), split_ids_buffer.end(), _sampler.getLocalRowEnd());
290 660 : _skip_indices.insert(_skip_indices.begin(), first, last);
291 :
292 660 : _local_sample_size = _sampler.getNumberOfLocalRows() - _skip_indices.size();
293 :
294 : // Train the model
295 660 : executeTraining();
296 :
297 : // Evaluate the model
298 660 : std::vector<Real> split_mse(1, 0.0);
299 660 : std::vector<Real> row_mse(1, 0.0);
300 :
301 : auto skipped_row = _skip_indices.begin();
302 :
303 11520 : for (dof_id_type p = _sampler.getLocalRowBegin(); p < _sampler.getLocalRowEnd(); ++p)
304 : {
305 10200 : const std::vector<Real> row = _sampler.getNextLocalRow();
306 10200 : if (skipped_row != _skip_indices.end() && p == *skipped_row)
307 : {
308 14940 : for (unsigned int i = 0; i < _row_data.size(); ++i)
309 9840 : _row_data[i] = row[i];
310 :
311 10200 : for (auto & pair : _training_data)
312 5100 : pair.second->setCurrentIndex(
313 5100 : (pair.second->isDistributed() ? p - _sampler.getLocalRowBegin() : p));
314 :
315 5100 : updatePredictorRow();
316 :
317 10200 : row_mse = evaluateModelError(*_cv_surrogate);
318 :
319 : // Expand split_mse if needed.
320 5100 : split_mse.resize(row_mse.size(), 0.0);
321 :
322 : // Increment errors
323 13170 : for (unsigned int r = 0; r < split_mse.size(); ++r)
324 8070 : split_mse[r] += row_mse[r];
325 :
326 : skipped_row++;
327 : }
328 10200 : }
329 : gatherSum(split_mse);
330 :
331 : // Expand cv_score if necessary.
332 660 : cv_score.resize(split_mse.size(), 0.0);
333 :
334 2238 : for (auto r : make_range(split_mse.size()))
335 1578 : cv_score[r] += split_mse[r] / n_rows;
336 :
337 660 : _skip_indices.clear();
338 660 : }
339 :
340 1119 : for (auto r : make_range(cv_score.size()))
341 789 : cv_score[r] = std::sqrt(cv_score[r]);
342 :
343 330 : return cv_score;
344 330 : }
345 :
346 : std::vector<Real>
347 5100 : SurrogateTrainer::evaluateModelError(const SurrogateModel & surr)
348 : {
349 5100 : std::vector<Real> error(1, 0.0);
350 :
351 5100 : if (_rval)
352 : {
353 4770 : Real model_eval = surr.evaluate(_predictor_data);
354 4770 : error[0] = MathUtils::pow(model_eval - (*_rval), 2);
355 : }
356 330 : else if (_rvecval)
357 : {
358 330 : error.resize(_rvecval->size());
359 :
360 : // Evaluate for vector response.
361 330 : std::vector<Real> model_eval(error.size());
362 330 : surr.evaluate(_predictor_data, model_eval);
363 3630 : for (auto r : make_range(_rvecval->size()))
364 3300 : error[r] = MathUtils::pow(model_eval[r] - (*_rvecval)[r], 2);
365 330 : }
366 :
367 5100 : return error;
368 0 : }
369 :
370 : void
371 258363 : SurrogateTrainer::updatePredictorRow()
372 : {
373 : unsigned int d = 0;
374 270254 : for (const auto & val : _pvals)
375 11891 : _predictor_data[d++] = *val;
376 1559323 : for (const auto & col : _pcols)
377 1300960 : _predictor_data[d++] = _row_data[col];
378 258363 : }
|