Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #include "NearestPointTrainer.h" 11 : #include "Sampler.h" 12 : 13 : registerMooseObject("StochasticToolsApp", NearestPointTrainer); 14 : 15 : InputParameters 16 306 : NearestPointTrainer::validParams() 17 : { 18 306 : InputParameters params = SurrogateTrainer::validParams(); 19 306 : params.addClassDescription("Loops over and saves sample values for [NearestPointSurrogate.md]."); 20 : 21 306 : return params; 22 0 : } 23 : 24 153 : NearestPointTrainer::NearestPointTrainer(const InputParameters & parameters) 25 : : SurrogateTrainer(parameters), 26 153 : _sample_points(declareModelData<std::vector<std::vector<Real>>>("_sample_points")), 27 306 : _sample_results(declareModelData<std::vector<std::vector<Real>>>("_sample_results")), 28 153 : _predictor_row(getPredictorData()) 29 : { 30 153 : _sample_points.resize(_n_dims); 31 153 : _sample_results.resize(1); 32 153 : } 33 : 34 : void 35 255 : NearestPointTrainer::preTrain() 36 : { 37 1122 : for (auto & it : _sample_points) 38 : { 39 867 : it.clear(); 40 867 : it.reserve(getLocalSampleSize()); 41 : } 42 : 43 510 : for (auto & it : _sample_results) 44 : { 45 255 : it.clear(); 46 255 : it.reserve(getLocalSampleSize()); 47 : } 48 255 : } 49 : 50 : void 51 24618 : NearestPointTrainer::train() 52 : { 53 24618 : if (_rvecval && (_sample_results.size() != _rvecval->size())) 54 17 : _sample_results.resize(_rvecval->size()); 55 : 56 : // Get predictors from reporter values 57 100364 : for (auto d : make_range(_n_dims)) 58 75746 : _sample_points[d].push_back(_predictor_row[d]); 59 : 60 : // Get responses 61 24618 : if (_rval) 62 24508 : _sample_results[0].push_back(*_rval); 63 110 : else if (_rvecval) 64 1210 : for (auto r : make_range(_rvecval->size())) 65 1100 : _sample_results[r].push_back((*_rvecval)[r]); 66 24618 : } 67 : 68 : void 69 255 : NearestPointTrainer::postTrain() 70 : { 71 1122 : for (auto & it : _sample_points) 72 867 : _communicator.allgather(it); 73 : 74 663 : for (auto & it : _sample_results) 75 408 : _communicator.allgather(it); 76 255 : }