Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #include "NearestPointTrainer.h" 11 : #include "Sampler.h" 12 : 13 : registerMooseObject("StochasticToolsApp", NearestPointTrainer); 14 : 15 : InputParameters 16 288 : NearestPointTrainer::validParams() 17 : { 18 288 : InputParameters params = SurrogateTrainer::validParams(); 19 288 : params.addClassDescription("Loops over and saves sample values for [NearestPointSurrogate.md]."); 20 : 21 288 : return params; 22 0 : } 23 : 24 144 : NearestPointTrainer::NearestPointTrainer(const InputParameters & parameters) 25 : : SurrogateTrainer(parameters), 26 144 : _sample_points(declareModelData<std::vector<std::vector<Real>>>("_sample_points")), 27 288 : _sample_results(declareModelData<std::vector<std::vector<Real>>>("_sample_results")), 28 144 : _predictor_row(getPredictorData()) 29 : { 30 144 : _sample_points.resize(_n_dims); 31 144 : _sample_results.resize(1); 32 144 : } 33 : 34 : void 35 240 : NearestPointTrainer::preTrain() 36 : { 37 1056 : for (auto & it : _sample_points) 38 : { 39 : it.clear(); 40 816 : it.reserve(getLocalSampleSize()); 41 : } 42 : 43 480 : for (auto & it : _sample_results) 44 : { 45 : it.clear(); 46 240 : it.reserve(getLocalSampleSize()); 47 : } 48 240 : } 49 : 50 : void 51 22380 : NearestPointTrainer::train() 52 : { 53 22380 : if (_rvecval && (_sample_results.size() != _rvecval->size())) 54 16 : _sample_results.resize(_rvecval->size()); 55 : 56 : // Get predictors from reporter values 57 91240 : for (auto d : make_range(_n_dims)) 58 68860 : _sample_points[d].push_back(_predictor_row[d]); 59 : 60 : // Get responses 61 22380 : if (_rval) 62 22280 : _sample_results[0].push_back(*_rval); 63 100 : else if (_rvecval) 64 1100 : for (auto r : make_range(_rvecval->size())) 65 1000 : _sample_results[r].push_back((*_rvecval)[r]); 66 22380 : } 67 : 68 : void 69 240 : NearestPointTrainer::postTrain() 70 : { 71 1056 : for (auto & it : _sample_points) 72 816 : _communicator.allgather(it); 73 : 74 624 : for (auto & it : _sample_results) 75 384 : _communicator.allgather(it); 76 240 : }