LCOV - code coverage report
Current view: top level - include/interfaces - SurrogateModelInterface.h (source / functions) Hit Total Coverage
Test: idaholab/moose stochastic_tools: f45d79 Lines: 15 17 88.2 %
Date: 2025-07-25 05:00:46 Functions: 10 12 83.3 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : #include "ParallelUniqueId.h"
      13             : #include "InputParameters.h"
      14             : #include "FEProblemBase.h"
      15             : 
      16             : // Forward declarations
      17             : class SurrogateModel;
      18             : class SurrogateTrainerBase;
      19             : 
      20             : /**
      21             :  * Interface for objects that need to use surrogate models or surrogate trainers.
      22             :  *
      23             :  * This practically adds two ways of getting SurrogateModel objects:
      24             :  *
      25             :  *  1. Call `getSurrogateModel` or `getSurrogateModelByName` without a template parameter and you
      26             :  * will get a `SurrogateModel` base object, which is the default template argument (see also
      27             :  * SurrogateModelInterface.C).
      28             :  *  2. Call `getSurrogateModel<MySurrogateModel>` or `getSurrogateModelByName<MySurrogateModel>` to
      29             :  * perform a cast to the desired type, as done for UserObjects.
      30             :  */
      31             : class SurrogateModelInterface
      32             : {
      33             : public:
      34             :   static InputParameters validParams();
      35             : 
      36             :   /**
      37             :    * @param moose_object The object being instantiated that uses this interface.
      38             :    *        NOTE(review): the constructor body is not in this header; presumably the
      39             :    *        object's parameters, FEProblemBase and thread ID are taken from it to
      40             :    *        initialize the private members of this class - confirm in the .C file.
      41             :    */
      42             :   SurrogateModelInterface(const MooseObject * moose_object);
      43             : 
      44             :   ///@{
      45             :   /**
      46             :    * Get a SurrogateModel/Trainer through the parameter with a given name
      47             :    * @param name The name of the parameter key of the SurrogateModel/Trainer to retrieve
      48             :    * @return The SurrogateModel/Trainer whose object name is stored in the parameter 'name'
      49             :    */
      50             :   template <typename T = SurrogateModel>
      51             :   T & getSurrogateModel(const std::string & name) const;
      52             :   template <typename T = SurrogateTrainerBase>
      53             :   T & getSurrogateTrainer(const std::string & name) const;
      54             :   ///@}
      55             : 
      56             :   ///@{
      57             :   /**
      58             :    * Get a SurrogateModel/Trainer with a given object name
      59             :    * @param name The object name of the SurrogateModel/Trainer to retrieve
      60             :    * @return The SurrogateModel/Trainer with name 'name'
      61             :    */
      62             :   template <typename T = SurrogateModel>
      63             :   T & getSurrogateModelByName(const UserObjectName & name) const;
      64             :   template <typename T = SurrogateTrainerBase>
      65             :   T & getSurrogateTrainerByName(const UserObjectName & name) const;
      66             : 
      67             :   ///@}
      68             : private:
      69             :   /// Parameters of the object with this interface
      70             :   const InputParameters & _smi_params;
      71             : 
      72             :   /// Reference to FEProblemBase instance
      73             :   FEProblemBase & _smi_feproblem;
      74             : 
      75             :   /// Thread ID
      76             :   const THREAD_ID _smi_tid;
      77             : };
      78             : 
      79             : template <typename T> // T is the type of the returned surrogate model reference
      80             : T &
      81         470 : SurrogateModelInterface::getSurrogateModel(const std::string & name) const
      82             : { // 'name' is a parameter key: read the UserObjectName stored under it, then look up by name
      83         470 :   return getSurrogateModelByName<T>(_smi_params.get<UserObjectName>(name));
      84             : }
      85             : 
      86             : template <typename T> // T is the pointer element type the warehouse query is collected into
      87             : T &
      88         722 : SurrogateModelInterface::getSurrogateModelByName(const UserObjectName & name) const
      89             : {
      90             :   std::vector<T *> models; // filled by the warehouse query below
      91         722 :   _smi_feproblem.theWarehouse()
      92             :       .query()
      93        1444 :       .condition<AttribName>(name)
      94         722 :       .condition<AttribSystem>("SurrogateModel")
      95             :       .queryInto(models);
      96         722 :   if (models.empty()) // nothing matched both the name and the "SurrogateModel" system
      97           0 :     mooseError("Unable to find a SurrogateModel object of type " + std::string(typeid(T).name()) +
      98             :                " with the name '" + name + "'");
      99         722 :   return *(models[0]); // only the first match is returned; any further matches are ignored
     100             : }
     101             : 
     102             : template <typename T> // T is the type of the returned surrogate trainer reference
     103             : T &
     104        1198 : SurrogateModelInterface::getSurrogateTrainer(const std::string & name) const
     105             : { // 'name' is a parameter key: read the UserObjectName stored under it, then look up by name
     106        1198 :   return getSurrogateTrainerByName<T>(_smi_params.get<UserObjectName>(name));
     107             : }
     108             : 
     109             : template <typename T> // T is the trainer type the retrieved object is cast to
     110             : T &
     111         400 : SurrogateModelInterface::getSurrogateTrainerByName(const UserObjectName & name) const
     112             : {
     113             :   SurrogateTrainerBase * base_ptr = // fetch as the trainer base type first, using this object's thread ID
     114         400 :       &_smi_feproblem.getUserObject<SurrogateTrainerBase>(name, _smi_tid);
     115         392 :   T * obj_ptr = dynamic_cast<T *>(base_ptr); // checked downcast to the requested type
     116         392 :   if (!obj_ptr) // dynamic_cast yields nullptr when the object is not a T
     117           0 :     mooseError("Failed to find a SurrogateTrainer object of type " + std::string(typeid(T).name()) +
     118             :                    " with the name '",
     119             :                name,
     120             :                "' for the desired type.");
     121         392 :   return *obj_ptr;
     122             : }

Generated by: LCOV version 1.14