Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #pragma once 11 : 12 : #include "ParallelUniqueId.h" 13 : #include "InputParameters.h" 14 : #include "FEProblemBase.h" 15 : 16 : // Forward declarations 17 : class SurrogateModel; 18 : class SurrogateTrainerBase; 19 : 20 : /** 21 : * Interface for objects that need to use samplers. 22 : * 23 : * This practically adds two methods for getting SurrogateModel objects: 24 : * 25 : * 1. Call `getSurrogateModel` or `getSurrogateModelByName` without a template parameter and you 26 : * will get a `SurrogateModel` base object (see SurrogateModelInterface.C for the template 27 : * specialization). 28 : * 2. Call `getSurrogateModel<MySurrogateModel>` or `getSurrogateModelByName<MySurrogateModel>` to 29 : * perform a cast to the desired type, as done for UserObjects. 30 : */ 31 : class SurrogateModelInterface 32 : { 33 : public: 34 : static InputParameters validParams(); 35 : 36 : /** 37 : * @param params The parameters used by the object being instantiated. This 38 : * class needs them so it can get the sampler named in the input file, 39 : * but the object calling getSurrogateModel only needs to use the name on the 40 : * left hand side of the statement "sampler = sampler_name" 41 : */ 42 : SurrogateModelInterface(const MooseObject * moose_object); 43 : 44 : ///@{ 45 : /** 46 : * Get a SurrogateModel/Trainer with a given name 47 : * @param name The name of the parameter key of the sampler to retrieve 48 : * @return The sampler with name associated with the parameter 'name' 49 : */ 50 : template <typename T = SurrogateModel> 51 : T & getSurrogateModel(const std::string & name) const; 52 : template <typename T = SurrogateTrainerBase> 53 : T & getSurrogateTrainer(const std::string & name) const; 54 : ///@} 55 : 56 : ///@{ 57 : /** 58 : * Get a sampler with a given name 59 : * @param name The name of the sampler to retrieve 60 : * @return The sampler with name 'name' 61 : */ 62 : template <typename T = SurrogateModel> 63 : T & getSurrogateModelByName(const UserObjectName & name) const; 64 : template <typename T = SurrogateTrainerBase> 65 : T & getSurrogateTrainerByName(const UserObjectName & name) const; 66 : 67 : ///@} 68 : private: 69 : /// Parameters of the object with this interface 70 : const InputParameters & _smi_params; 71 : 72 : /// Reference to FEProblemBase instance 73 : FEProblemBase & _smi_feproblem; 74 : 75 : /// Thread ID 76 : const THREAD_ID _smi_tid; 77 : }; 78 : 79 : template <typename T> 80 : T & 81 470 : SurrogateModelInterface::getSurrogateModel(const std::string & name) const 82 : { 83 470 : return getSurrogateModelByName<T>(_smi_params.get<UserObjectName>(name)); 84 : } 85 : 86 : template <typename T> 87 : T & 88 722 : SurrogateModelInterface::getSurrogateModelByName(const UserObjectName & name) const 89 : { 90 : std::vector<T *> models; 91 722 : _smi_feproblem.theWarehouse() 92 : .query() 93 1444 : .condition<AttribName>(name) 94 722 : .condition<AttribSystem>("SurrogateModel") 95 : .queryInto(models); 96 722 : if (models.empty()) 97 0 : mooseError("Unable to find a SurrogateModel object of type " + std::string(typeid(T).name()) + 98 : " with the name '" + name + "'"); 99 722 : return *(models[0]); 100 : } 101 : 102 : template <typename T> 103 : T & 104 1198 : SurrogateModelInterface::getSurrogateTrainer(const std::string & name) const 105 : { 106 1198 : return getSurrogateTrainerByName<T>(_smi_params.get<UserObjectName>(name)); 107 : } 108 : 109 : template <typename T> 110 : T & 111 400 : SurrogateModelInterface::getSurrogateTrainerByName(const UserObjectName & name) const 112 : { 113 : SurrogateTrainerBase * base_ptr = 114 400 : &_smi_feproblem.getUserObject<SurrogateTrainerBase>(name, _smi_tid); 115 392 : T * obj_ptr = dynamic_cast<T *>(base_ptr); 116 392 : if (!obj_ptr) 117 0 : mooseError("Failed to find a SurrogateTrainer object of type " + std::string(typeid(T).name()) + 118 : " with the name '", 119 : name, 120 : "' for the desired type."); 121 392 : return *obj_ptr; 122 : }