Line data Source code
//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#pragma once

#include "ActiveLearningReporterBase.h"
#include "ActiveLearningGaussianProcess.h"
#include "GaussianProcessSurrogate.h"
#include "SurrogateModelInterface.h"

/**
 * Active learning reporter that uses a Gaussian process (GP) to decide, for each sample,
 * whether the GP prediction can be used or whether a full order model evaluation is required.
 *
 * The class accumulates training data, (re-)trains the GP through an
 * ActiveLearningGaussianProcess trainer, evaluates it through a SurrogateModel, and applies
 * a learning (acquisition) function to the GP mean and standard deviation to make the decision.
 */
class ActiveLearningGPDecision : public ActiveLearningReporterTempl<Real>,
                                 public SurrogateModelInterface
{
public:
  /// Returns the InputParameters object describing the parameters of this class
  static InputParameters validParams();

  /**
   * Constructor.
   *
   * @param parameters The input parameters used to build this object
   */
  ActiveLearningGPDecision(const InputParameters & parameters);

  /**
   * Access the number of training samples.
   *
   * @return Const reference to _n_train; the reference is to a member of this object, so it
   *         is only valid while this object is alive
   */
  const int & getTrainingSamples() const { return _n_train; }

protected:
  /**
   * This is where most of the computations happen:
   *   - Data is accumulated for training
   *   - GP models are trained
   *   - Decision is made whether more data is needed for GP training
   */
  virtual void preNeedSample() override;

  /**
   * Based on the computations in preNeedSample, the decision to get more data is passed on
   * and the result is written to \p val.
   *
   * @param row Input parameters to the model
   * @param local_ind Current processor row index
   * @param global_ind All processors row index
   * @param val Output predicted by either the LF model + GP correction or the HF model
   * @return bool Whether a full order model evaluation is required
   */
  virtual bool needSample(const std::vector<Real> & row,
                          dof_id_type local_ind,
                          dof_id_type global_ind,
                          Real & val) override;

  /**
   * Make decisions whether to call the full model or not based on
   * GP prediction and uncertainty.
   *
   * @return bool Whether a full order model evaluation is required
   */
  virtual bool facilitateDecision();

  /**
   * This sets up data for re-training the GP.
   *
   * @param inputs Matrix of inputs for the current step
   * @param outputs Vector of outputs for the current step
   */
  virtual void setupData(const std::vector<std::vector<Real>> & inputs,
                         const std::vector<Real> & outputs);

  /**
   * This method evaluates the active learning acquisition function and returns a bool
   * that indicates whether the GP model failed.
   *
   * @param gp_mean Mean of the Gaussian process model
   * @param gp_std Standard deviation of the Gaussian process model
   * @return bool If the GP model failed
   */
  bool learningFunction(const Real & gp_mean, const Real & gp_std) const;

  /// The learning function for active learning
  const MooseEnum & _learning_function;
  /// The learning function threshold
  const Real & _learning_function_threshold;
  /// The learning function parameter
  const Real & _learning_function_parameter;

  /// Store all the input vectors used for training
  std::vector<std::vector<Real>> _inputs_batch;
  /// Store all the outputs used for training
  std::vector<Real> _outputs_batch;

  /// The active learning GP trainer that permits re-training
  const ActiveLearningGaussianProcess & _al_gp;
  /// The GP evaluator object that permits re-evaluations
  const SurrogateModel & _gp_eval;

  /// Flag samples when the GP fails
  std::vector<bool> & _flag_sample;

  /// Number of initial training points for GP (exposed through getTrainingSamples())
  const int _n_train;

  /// Storage for the input vectors to be transferred to the output file
  std::vector<std::vector<Real>> & _inputs;

  /// Broadcast the GP mean prediction to JSON
  std::vector<Real> & _gp_mean;
  /// Broadcast the GP standard deviation to JSON
  std::vector<Real> & _gp_std;

  /// GP pass/fail decision
  /// NOTE(review): which value means "pass" is set in the implementation file - confirm there
  bool _decision;

  /// Reference to global input data requested from base class
  const std::vector<std::vector<Real>> & _inputs_global;
  /// Reference to global output data requested from base class
  const std::vector<Real> & _outputs_global;
};