Line data Source code
1 : //* This file is part of the MOOSE framework
2 : //* https://mooseframework.inl.gov
3 : //*
4 : //* All rights reserved, see COPYRIGHT for full restrictions
5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 : //*
7 : //* Licensed under LGPL 2.1, please see LICENSE for details
8 : //* https://www.gnu.org/licenses/lgpl-2.1.html
9 :
10 : #include "SurrogateTrainer.h"
11 : #include "SurrogateModel.h"
12 : #include "Sampler.h"
13 : #include "StochasticToolsApp.h"
14 : #include "MooseRandom.h"
15 : #include "Shuffle.h"
16 : #include <algorithm>
17 :
18 : InputParameters
19 2876 : SurrogateTrainerBase::validParams()
20 : {
21 2876 : InputParameters params = GeneralUserObject::validParams();
22 2876 : params += RestartableModelInterface::validParams();
23 2876 : params.registerBase("SurrogateTrainer");
24 2876 : return params;
25 0 : }
26 :
27 1448 : SurrogateTrainerBase::SurrogateTrainerBase(const InputParameters & parameters)
28 : : GeneralUserObject(parameters),
29 1448 : RestartableModelInterface(*this, /*read_only=*/false, _type + "_" + name())
30 : {
31 1448 : }
32 :
33 : InputParameters
34 2396 : SurrogateTrainer::validParams()
35 : {
36 2396 : InputParameters params = SurrogateTrainerBase::validParams();
37 4792 : params.addRequiredParam<SamplerName>("sampler",
38 : "Sampler used to create predictor and response data.");
39 4792 : params.addParam<ReporterName>(
40 : "converged_reporter",
41 : "Reporter value used to determine if a sample's multiapp solve converged.");
42 4792 : params.addParam<bool>("skip_unconverged_samples",
43 4792 : false,
44 : "True to skip samples where the multiapp did not converge, "
45 : "'stochastic_reporter' is required to do this.");
46 :
47 : // Common Training Data
48 4792 : MooseEnum data_type("real=0 vector_real=1", "real");
49 4792 : params.addRequiredParam<ReporterName>(
50 : "response",
51 : "Reporter value of response results, can be vpp with <vpp_name>/<vector_name> or sampler "
52 : "column with 'sampler/col_<index>'.");
53 4792 : params.addParam<MooseEnum>("response_type", data_type, "Response data type.");
54 2396 : params.addParam<std::vector<ReporterName>>(
55 : "predictors",
56 2396 : std::vector<ReporterName>(),
57 : "Reporter values used as the independent random variables, If 'predictors' and "
58 : "'predictor_cols' are both empty, all sampler columns are used.");
59 4792 : params.addParam<std::vector<unsigned int>>(
60 : "predictor_cols",
61 2396 : std::vector<unsigned int>(),
62 : "Sampler columns used as the independent random variables, If 'predictors' and "
63 : "'predictor_cols' are both empty, all sampler columns are used.");
64 : // End Common Training Data
65 :
66 4792 : MooseEnum cv_type("none=0 k_fold=1", "none");
67 4792 : params.addParam<MooseEnum>(
68 : "cv_type",
69 : cv_type,
70 : "Cross-validation method to use for dataset. Options are 'none' or 'k_fold'.");
71 7188 : params.addRangeCheckedParam<unsigned int>(
72 4792 : "cv_splits", 10, "cv_splits > 1", "Number of splits (k) to use in k-fold cross-validation.");
73 4792 : params.addParam<UserObjectName>("cv_surrogate",
74 : "Name of Surrogate object used for model cross-validation.");
75 4792 : params.addParam<unsigned int>(
76 4792 : "cv_n_trials", 1, "Number of repeated trials of cross-validation to perform.");
77 4792 : params.addParam<unsigned int>("cv_seed",
78 4792 : std::numeric_limits<unsigned int>::max(),
79 : "Seed used to initialize random number generator for data "
80 : "splitting during cross validation.");
81 :
82 2396 : return params;
83 2396 : }
84 :
85 1202 : SurrogateTrainer::SurrogateTrainer(const InputParameters & parameters)
86 : : SurrogateTrainerBase(parameters),
87 : SurrogateModelInterface(this),
88 1202 : _sampler(getSampler("sampler")),
89 1202 : _rval(nullptr),
90 1202 : _rvecval(nullptr),
91 2404 : _pvals(getParam<std::vector<ReporterName>>("predictors").size()),
92 2404 : _pcols(getParam<std::vector<unsigned int>>("predictor_cols")),
93 2404 : _n_outputs(declareModelData<unsigned int>("_n_outputs", 1)),
94 1202 : _row_data(_sampler.getNumberOfCols()),
95 2404 : _skip_unconverged(getParam<bool>("skip_unconverged_samples")),
96 1202 : _converged(nullptr),
97 2404 : _cv_type(getParam<MooseEnum>("cv_type")),
98 2404 : _n_splits(getParam<unsigned int>("cv_splits")),
99 2404 : _cv_n_trials(getParam<unsigned int>("cv_n_trials")),
100 2404 : _cv_seed(getParam<unsigned int>("cv_seed")),
101 1202 : _doing_cv(_cv_type != "none"),
102 4808 : _cv_trial_scores(declareModelData<std::vector<std::vector<Real>>>("cv_scores"))
103 : {
104 1202 : if (_skip_unconverged)
105 : {
106 0 : if (!isParamValid("converged_reporter"))
107 0 : paramError("skip_unconverged_samples",
108 : "'converged_reporter' needs to be specified to skip unconverged sample.");
109 0 : _converged = &getTrainingData<bool>(getParam<ReporterName>("converged_reporter"));
110 : }
111 :
112 1202 : if (_doing_cv)
113 : {
114 216 : if (!isParamValid("cv_surrogate"))
115 4 : paramError("cv_type",
116 : "To perform cross-validation, the option cv_surrogate needs to be specified",
117 : " to provide a Surrogate object for training and evaluation.");
118 :
119 104 : if (_n_splits > _sampler.getNumberOfRows())
120 0 : paramError("cv_splits",
121 : "The specified number of splits (cv_splits = ",
122 0 : _n_splits,
123 : ")",
124 : " exceeds the number of rows in Sampler '",
125 : getParam<SamplerName>("sampler"),
126 : "'");
127 :
128 104 : _cv_generator.seed(0, _cv_seed);
129 : }
130 :
131 : // Get TrainingData for responses and predictors
132 3594 : if (getParam<MooseEnum>("response_type") == 0)
133 2156 : _rval = &getTrainingData<Real>(getParam<ReporterName>("response"));
134 360 : else if (getParam<MooseEnum>("response_type") == 1)
135 240 : _rvecval = &getTrainingData<std::vector<Real>>(getParam<ReporterName>("response"));
136 :
137 1198 : const auto & pnames = getParam<std::vector<ReporterName>>("predictors");
138 1230 : for (unsigned int i = 0; i < pnames.size(); ++i)
139 32 : _pvals[i] = &getTrainingData<Real>(pnames[i]);
140 :
141 : // If predictors and predictor_cols are empty, use all sampler columns
142 1198 : if (_pvals.empty() && _pcols.empty())
143 : {
144 1166 : _pcols.resize(_sampler.getNumberOfCols());
145 : std::iota(_pcols.begin(), _pcols.end(), 0);
146 : }
147 1198 : _n_dims = _pvals.size() + _pcols.size();
148 :
149 1198 : _predictor_data.resize(_n_dims);
150 1198 : }
151 :
152 : void
153 1194 : SurrogateTrainer::initialize()
154 : {
155 : // Figure out if data is distributed
156 2420 : for (auto & pair : _training_data)
157 : {
158 1226 : const ReporterName & name = pair.first;
159 : TrainingDataBase & data = *pair.second;
160 :
161 1226 : const auto & mode = _fe_problem.getReporterData().getReporterMode(name);
162 1226 : if (mode == REPORTER_MODE_DISTRIBUTED || (mode == REPORTER_MODE_ROOT && processor_id() != 0))
163 911 : data.isDistributed() = true;
164 534 : else if (mode == REPORTER_MODE_REPLICATED ||
165 438 : (mode == REPORTER_MODE_ROOT && processor_id() == 0))
166 315 : data.isDistributed() = false;
167 : else
168 0 : mooseError("Predictor reporter value ", name, " is not of supported mode.");
169 : }
170 :
171 1194 : if (_doing_cv)
172 104 : _cv_surrogate = &getSurrogateModel("cv_surrogate");
173 1194 : }
174 :
175 : void
176 1194 : SurrogateTrainer::execute()
177 : {
178 1194 : if (_doing_cv)
179 416 : for (const auto & trial : make_range(_cv_n_trials))
180 : {
181 312 : std::vector<Real> trial_score = crossValidate();
182 :
183 : // Expand _cv_trial_scores with more columns if necessary, then insert values.
184 560 : for (unsigned int r = _cv_trial_scores.size(); r < trial_score.size(); ++r)
185 496 : _cv_trial_scores.push_back(std::vector<Real>(_cv_n_trials, 0.0));
186 1056 : for (auto r : make_range(trial_score.size()))
187 744 : _cv_trial_scores[r][trial] = trial_score[r];
188 : }
189 :
190 1194 : _current_sample_size = _sampler.getNumberOfRows();
191 1194 : _local_sample_size = _sampler.getNumberOfLocalRows();
192 1194 : executeTraining();
193 1190 : }
194 :
195 : void
196 1818 : SurrogateTrainer::checkIntegrity() const
197 : {
198 : // Check that the number of sampler columns hasn't changed
199 1818 : if (_row_data.size() != _sampler.getNumberOfCols())
200 0 : mooseError("Number of sampler columns has changed.");
201 :
202 : // Check that training data is correctly sized
203 3664 : for (auto & pair : _training_data)
204 : {
205 1850 : dof_id_type rsize = pair.second->size();
206 : dof_id_type nrow =
207 1850 : pair.second->isDistributed() ? _sampler.getNumberOfLocalRows() : _sampler.getNumberOfRows();
208 1850 : if (rsize != nrow)
209 4 : mooseError("Reporter value ",
210 4 : pair.first,
211 : " of size ",
212 : rsize,
213 : " does not match sampler size (",
214 : nrow,
215 : ").");
216 : }
217 1814 : }
218 :
219 : void
220 1818 : SurrogateTrainer::executeTraining()
221 : {
222 1818 : checkIntegrity();
223 1814 : _row = _sampler.getLocalRowBegin();
224 1814 : _local_row = 0;
225 :
226 1814 : preTrain();
227 :
228 232369 : for (_row = _sampler.getLocalRowBegin(); _row < _sampler.getLocalRowEnd(); ++_row)
229 : {
230 : // Need to do this manually in order to keep the iterators valid
231 230555 : const std::vector<Real> data = _sampler.getNextLocalRow();
232 1416090 : for (unsigned int i = 0; i < _row_data.size(); ++i)
233 1185535 : _row_data[i] = data[i];
234 :
235 : // Set training data
236 471920 : for (auto & pair : _training_data)
237 241365 : pair.second->setCurrentIndex((pair.second->isDistributed() ? _local_row : _row));
238 :
239 230555 : updatePredictorRow();
240 :
241 230555 : if ((!_skip_unconverged || *_converged) &&
242 230555 : std::find(_skip_indices.begin(), _skip_indices.end(), _row) == _skip_indices.end())
243 225905 : train();
244 :
245 230555 : _local_row++;
246 : }
247 :
248 1814 : postTrain();
249 1814 : }
250 :
251 : std::vector<Real>
252 312 : SurrogateTrainer::crossValidate()
253 : {
254 312 : std::vector<Real> cv_score(1, 0.0);
255 :
256 : // Get skipped indices for each split
257 312 : dof_id_type n_rows = _sampler.getNumberOfRows();
258 : std::vector<std::vector<dof_id_type>> split_indices;
259 312 : if (processor_id() == 0)
260 : {
261 195 : std::vector<dof_id_type> indices_flat(n_rows);
262 : std::iota(indices_flat.begin(), indices_flat.end(), 0);
263 195 : MooseUtils::shuffle(indices_flat, _cv_generator, 0);
264 :
265 195 : split_indices.resize(_n_splits);
266 585 : for (const auto & k : make_range(_n_splits))
267 : {
268 780 : const dof_id_type num_ind = n_rows / _n_splits + (k < (n_rows % _n_splits) ? 1 : 0);
269 390 : split_indices[k].insert(split_indices[k].begin(),
270 : std::make_move_iterator(indices_flat.begin()),
271 : std::make_move_iterator(indices_flat.begin() + num_ind));
272 390 : std::sort(split_indices[k].begin(), split_indices[k].end());
273 : indices_flat.erase(indices_flat.begin(), indices_flat.begin() + num_ind);
274 : }
275 : }
276 :
277 : std::vector<dof_id_type> split_ids_buffer;
278 936 : for (const auto & k : make_range(_n_splits))
279 : {
280 624 : if (processor_id() == 0)
281 390 : split_ids_buffer = split_indices[k];
282 624 : _communicator.broadcast(split_ids_buffer, 0);
283 :
284 624 : _current_sample_size = _sampler.getNumberOfRows() - split_ids_buffer.size();
285 :
286 624 : auto first = std::lower_bound(
287 624 : split_ids_buffer.begin(), split_ids_buffer.end(), _sampler.getLocalRowBegin());
288 624 : auto last = std::upper_bound(
289 624 : split_ids_buffer.begin(), split_ids_buffer.end(), _sampler.getLocalRowEnd());
290 624 : _skip_indices.insert(_skip_indices.begin(), first, last);
291 :
292 624 : _local_sample_size = _sampler.getNumberOfLocalRows() - _skip_indices.size();
293 :
294 : // Train the model
295 624 : executeTraining();
296 :
297 : // Evaluate the model
298 624 : std::vector<Real> split_mse(1, 0.0);
299 624 : std::vector<Real> row_mse(1, 0.0);
300 :
301 : auto skipped_row = _skip_indices.begin();
302 :
303 9924 : for (dof_id_type p = _sampler.getLocalRowBegin(); p < _sampler.getLocalRowEnd(); ++p)
304 : {
305 9300 : const std::vector<Real> row = _sampler.getNextLocalRow();
306 9300 : if (skipped_row != _skip_indices.end() && p == *skipped_row)
307 : {
308 13650 : for (unsigned int i = 0; i < _row_data.size(); ++i)
309 9000 : _row_data[i] = row[i];
310 :
311 9300 : for (auto & pair : _training_data)
312 4650 : pair.second->setCurrentIndex(
313 4650 : (pair.second->isDistributed() ? p - _sampler.getLocalRowBegin() : p));
314 :
315 4650 : updatePredictorRow();
316 :
317 4650 : row_mse = evaluateModelError(*_cv_surrogate);
318 :
319 : // Expand split_mse if needed.
320 4650 : split_mse.resize(row_mse.size(), 0.0);
321 :
322 : // Increment errors
323 12000 : for (unsigned int r = 0; r < split_mse.size(); ++r)
324 7350 : split_mse[r] += row_mse[r];
325 :
326 : skipped_row++;
327 : }
328 : }
329 : gatherSum(split_mse);
330 :
331 : // Expand cv_score if necessary.
332 624 : cv_score.resize(split_mse.size(), 0.0);
333 :
334 2112 : for (auto r : make_range(split_mse.size()))
335 1488 : cv_score[r] += split_mse[r] / n_rows;
336 :
337 : _skip_indices.clear();
338 : }
339 :
340 1056 : for (auto r : make_range(cv_score.size()))
341 744 : cv_score[r] = std::sqrt(cv_score[r]);
342 :
343 312 : return cv_score;
344 312 : }
345 :
346 : std::vector<Real>
347 4650 : SurrogateTrainer::evaluateModelError(const SurrogateModel & surr)
348 : {
349 4650 : std::vector<Real> error(1, 0.0);
350 :
351 4650 : if (_rval)
352 : {
353 4350 : Real model_eval = surr.evaluate(_predictor_data);
354 4350 : error[0] = MathUtils::pow(model_eval - (*_rval), 2);
355 : }
356 300 : else if (_rvecval)
357 : {
358 300 : error.resize(_rvecval->size());
359 :
360 : // Evaluate for vector response.
361 300 : std::vector<Real> model_eval(error.size());
362 300 : surr.evaluate(_predictor_data, model_eval);
363 3300 : for (auto r : make_range(_rvecval->size()))
364 3000 : error[r] = MathUtils::pow(model_eval[r] - (*_rvecval)[r], 2);
365 : }
366 :
367 4650 : return error;
368 : }
369 :
370 : void
371 235205 : SurrogateTrainer::updatePredictorRow()
372 : {
373 : unsigned int d = 0;
374 246015 : for (const auto & val : _pvals)
375 10810 : _predictor_data[d++] = *val;
376 1418930 : for (const auto & col : _pcols)
377 1183725 : _predictor_data[d++] = _row_data[col];
378 235205 : }
|