Loading [MathJax]/extensions/tex2jax.js
https://mooseframework.inl.gov
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends
Public Types | Public Member Functions | Static Public Member Functions | Public Attributes | Static Public Attributes | Protected Member Functions | Static Protected Member Functions | Protected Attributes | Static Protected Attributes | Private Attributes | List of all members
LibtorchANNTrainer Class Reference

Trainer responsible for fitting a neural network on predefined data. More...

#include <LibtorchANNTrainer.h>

Inheritance diagram for LibtorchANNTrainer:
[legend]

Public Types

typedef DataFileName DataFileParameterType
 

Public Member Functions

 LibtorchANNTrainer (const InputParameters &parameters)
 Construct using input parameters. More...
 
virtual void preTrain () override
 Contains processes which are executed before the training loop. More...
 
virtual void train () override
 Contains processes which are executed for every sample in the training loop. More...
 
virtual void postTrain () override
 Contains processes which are executed after the training loop. More...
 
virtual void initialize () final
 
virtual void execute () final
 
virtual void finalize () final
 
virtual void threadJoin (const UserObject &) final
 
SubProblem & getSubProblem () const
 
bool shouldDuplicateInitialExecution () const
 
virtual Real spatialValue (const Point &) const
 
virtual const std::vector< Point > spatialPoints () const
 
void gatherSum (T &value)
 
void gatherMax (T &value)
 
void gatherMin (T &value)
 
void gatherProxyValueMax (T1 &proxy, T2 &value)
 
void gatherProxyValueMin (T1 &proxy, T2 &value)
 
void setPrimaryThreadCopy (UserObject *primary)
 
UserObject * primaryThreadCopy ()
 
std::set< UserObjectName > getDependObjects () const
 
virtual bool needThreadedCopy () const
 
const std::set< std::string > & getRequestedItems () override
 
const std::set< std::string > & getSuppliedItems () override
 
unsigned int systemNumber () const
 
virtual bool enabled () const
 
std::shared_ptr< MooseObject > getSharedPtr ()
 
std::shared_ptr< const MooseObject > getSharedPtr () const
 
MooseApp & getMooseApp () const
 
const std::string & type () const
 
virtual const std::string & name () const
 
std::string typeAndName () const
 
std::string errorPrefix (const std::string &error_type) const
 
void callMooseError (std::string msg, const bool with_prefix) const
 
MooseObjectParameterName uniqueParameterName (const std::string &parameter_name) const
 
const InputParameters & parameters () const
 
MooseObjectName uniqueName () const
 
const T & getParam (const std::string &name) const
 
std::vector< std::pair< T1, T2 > > getParam (const std::string &param1, const std::string &param2) const
 
const T & getRenamedParam (const std::string &old_name, const std::string &new_name) const
 
T getCheckedPointerParam (const std::string &name, const std::string &error_string="") const
 
bool isParamValid (const std::string &name) const
 
bool isParamSetByUser (const std::string &nm) const
 
void paramError (const std::string &param, Args... args) const
 
void paramWarning (const std::string &param, Args... args) const
 
void paramInfo (const std::string &param, Args... args) const
 
void connectControllableParams (const std::string &parameter, const std::string &object_type, const std::string &object_name, const std::string &object_parameter) const
 
void mooseError (Args &&... args) const
 
void mooseErrorNonPrefixed (Args &&... args) const
 
void mooseDocumentedError (const std::string &repo_name, const unsigned int issue_num, Args &&... args) const
 
void mooseWarning (Args &&... args) const
 
void mooseWarningNonPrefixed (Args &&... args) const
 
void mooseDeprecated (Args &&... args) const
 
void mooseInfo (Args &&... args) const
 
std::string getDataFileName (const std::string &param) const
 
std::string getDataFileNameByName (const std::string &relative_path) const
 
std::string getDataFilePath (const std::string &relative_path) const
 
virtual void initialSetup ()
 
virtual void timestepSetup ()
 
virtual void jacobianSetup ()
 
virtual void residualSetup ()
 
virtual void customSetup (const ExecFlagType &)
 
const ExecFlagEnum & getExecuteOnEnum () const
 
UserObjectName getUserObjectName (const std::string &param_name) const
 
const T & getUserObject (const std::string &param_name, bool is_dependency=true) const
 
const T & getUserObjectByName (const UserObjectName &object_name, bool is_dependency=true) const
 
const UserObject & getUserObjectBase (const std::string &param_name, bool is_dependency=true) const
 
const UserObject & getUserObjectBaseByName (const UserObjectName &object_name, bool is_dependency=true) const
 
const std::vector< MooseVariableScalar *> & getCoupledMooseScalarVars ()
 
const std::set< TagID > & getScalarVariableCoupleableVectorTags () const
 
const std::set< TagID > & getScalarVariableCoupleableMatrixTags () const
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialProperty (const std::string &name, MaterialData &material_data, const unsigned int state=0)
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialProperty (const std::string &name, const unsigned int state=0)
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialProperty (const std::string &name, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialProperty (const std::string &name, MaterialData &material_data, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialProperty (const std::string &name, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialProperty (const std::string &name, const unsigned int state=0)
 
const ADMaterialProperty< T > & getADMaterialProperty (const std::string &name, MaterialData &material_data)
 
const ADMaterialProperty< T > & getADMaterialProperty (const std::string &name)
 
const ADMaterialProperty< T > & getADMaterialProperty (const std::string &name)
 
const MaterialProperty< T > & getMaterialPropertyOld (const std::string &name, MaterialData &material_data)
 
const MaterialProperty< T > & getMaterialPropertyOld (const std::string &name)
 
const MaterialProperty< T > & getMaterialPropertyOld (const std::string &name)
 
const MaterialProperty< T > & getMaterialPropertyOlder (const std::string &name, MaterialData &material_data)
 
const MaterialProperty< T > & getMaterialPropertyOlder (const std::string &name)
 
const MaterialProperty< T > & getMaterialPropertyOlder (const std::string &name)
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialPropertyByName (const MaterialPropertyName &name, MaterialData &material_data, const unsigned int state)
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialPropertyByName (const MaterialPropertyName &name, const unsigned int state=0)
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialPropertyByName (const MaterialPropertyName &name, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialPropertyByName (const MaterialPropertyName &name, MaterialData &material_data, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialPropertyByName (const MaterialPropertyName &name, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialPropertyByName (const MaterialPropertyName &name, const unsigned int state=0)
 
const ADMaterialProperty< T > & getADMaterialPropertyByName (const MaterialPropertyName &name, MaterialData &material_data)
 
const ADMaterialProperty< T > & getADMaterialPropertyByName (const MaterialPropertyName &name)
 
const ADMaterialProperty< T > & getADMaterialPropertyByName (const MaterialPropertyName &name)
 
const MaterialProperty< T > & getMaterialPropertyOldByName (const MaterialPropertyName &name, MaterialData &material_data)
 
const MaterialProperty< T > & getMaterialPropertyOldByName (const MaterialPropertyName &name)
 
const MaterialProperty< T > & getMaterialPropertyOldByName (const MaterialPropertyName &name)
 
const MaterialProperty< T > & getMaterialPropertyOlderByName (const MaterialPropertyName &name, MaterialData &material_data)
 
const MaterialProperty< T > & getMaterialPropertyOlderByName (const MaterialPropertyName &name)
 
const MaterialProperty< T > & getMaterialPropertyOlderByName (const MaterialPropertyName &name)
 
std::pair< const MaterialProperty< T > *, std::set< SubdomainID > > getBlockMaterialProperty (const MaterialPropertyName &name)
 
const GenericMaterialProperty< T, is_ad > & getGenericZeroMaterialProperty (const std::string &name)
 
const GenericMaterialProperty< T, is_ad > & getGenericZeroMaterialProperty ()
 
const GenericMaterialProperty< T, is_ad > & getGenericZeroMaterialPropertyByName (const std::string &prop_name)
 
const MaterialProperty< T > & getZeroMaterialProperty (Ts... args)
 
std::set< SubdomainID > getMaterialPropertyBlocks (const std::string &name)
 
std::vector< SubdomainName > getMaterialPropertyBlockNames (const std::string &name)
 
std::set< BoundaryID > getMaterialPropertyBoundaryIDs (const std::string &name)
 
std::vector< BoundaryName > getMaterialPropertyBoundaryNames (const std::string &name)
 
void checkBlockAndBoundaryCompatibility (std::shared_ptr< MaterialBase > discrete)
 
std::unordered_map< SubdomainID, std::vector< MaterialBase *> > buildRequiredMaterials (bool allow_stateful=true)
 
void statefulPropertiesAllowed (bool)
 
bool getMaterialPropertyCalled () const
 
const std::unordered_set< unsigned int > & getMatPropDependencies () const
 
virtual void resolveOptionalProperties ()
 
const GenericMaterialProperty< T, is_ad > & getPossiblyConstantGenericMaterialPropertyByName (const MaterialPropertyName &prop_name, MaterialData &material_data, const unsigned int state)
 
bool isImplicit ()
 
Moose::StateArg determineState () const
 
virtual void subdomainSetup () override
 
virtual void subdomainSetup () override
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
const GenericOptionalMaterialProperty< T, is_ad > & getGenericOptionalMaterialProperty (const std::string &name, const unsigned int state=0)
 
const GenericOptionalMaterialProperty< T, is_ad > & getGenericOptionalMaterialProperty (const std::string &name, const unsigned int state=0)
 
const OptionalMaterialProperty< T > & getOptionalMaterialProperty (const std::string &name, const unsigned int state=0)
 
const OptionalMaterialProperty< T > & getOptionalMaterialProperty (const std::string &name, const unsigned int state=0)
 
const OptionalADMaterialProperty< T > & getOptionalADMaterialProperty (const std::string &name)
 
const OptionalADMaterialProperty< T > & getOptionalADMaterialProperty (const std::string &name)
 
const OptionalMaterialProperty< T > & getOptionalMaterialPropertyOld (const std::string &name)
 
const OptionalMaterialProperty< T > & getOptionalMaterialPropertyOld (const std::string &name)
 
const OptionalMaterialProperty< T > & getOptionalMaterialPropertyOlder (const std::string &name)
 
const OptionalMaterialProperty< T > & getOptionalMaterialPropertyOlder (const std::string &name)
 
MaterialBase & getMaterial (const std::string &name)
 
MaterialBase & getMaterial (const std::string &name)
 
MaterialBase & getMaterialByName (const std::string &name, bool no_warn=false)
 
MaterialBase & getMaterialByName (const std::string &name, bool no_warn=false)
 
bool hasMaterialProperty (const std::string &name)
 
bool hasMaterialProperty (const std::string &name)
 
bool hasMaterialPropertyByName (const std::string &name)
 
bool hasMaterialPropertyByName (const std::string &name)
 
bool hasADMaterialProperty (const std::string &name)
 
bool hasADMaterialProperty (const std::string &name)
 
bool hasADMaterialPropertyByName (const std::string &name)
 
bool hasADMaterialPropertyByName (const std::string &name)
 
bool hasGenericMaterialProperty (const std::string &name)
 
bool hasGenericMaterialProperty (const std::string &name)
 
bool hasGenericMaterialPropertyByName (const std::string &name)
 
bool hasGenericMaterialPropertyByName (const std::string &name)
 
const Function & getFunction (const std::string &name) const
 
const Function & getFunctionByName (const FunctionName &name) const
 
bool hasFunction (const std::string &param_name) const
 
bool hasFunctionByName (const FunctionName &name) const
 
bool isDefaultPostprocessorValue (const std::string &param_name, const unsigned int index=0) const
 
bool hasPostprocessor (const std::string &param_name, const unsigned int index=0) const
 
bool hasPostprocessorByName (const PostprocessorName &name) const
 
std::size_t coupledPostprocessors (const std::string &param_name) const
 
const PostprocessorName & getPostprocessorName (const std::string &param_name, const unsigned int index=0) const
 
const VectorPostprocessorValue & getVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name) const
 
const VectorPostprocessorValue & getVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name, bool needs_broadcast) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name, bool needs_broadcast) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name, bool needs_broadcast) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name, bool needs_broadcast) const
 
const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name) const
 
const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name) const
 
const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
bool hasVectorPostprocessor (const std::string &param_name, const std::string &vector_name) const
 
bool hasVectorPostprocessor (const std::string &param_name) const
 
bool hasVectorPostprocessorByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
bool hasVectorPostprocessorByName (const VectorPostprocessorName &name) const
 
const VectorPostprocessorName & getVectorPostprocessorName (const std::string &param_name) const
 
T & getSampler (const std::string &name)
 
Sampler & getSampler (const std::string &name)
 
T & getSamplerByName (const SamplerName &name)
 
Sampler & getSamplerByName (const SamplerName &name)
 
virtual void meshChanged ()
 
PerfGraph & perfGraph ()
 
const PostprocessorValue & getPostprocessorValue (const std::string &param_name, const unsigned int index=0) const
 
const PostprocessorValue & getPostprocessorValue (const std::string &param_name, const unsigned int index=0) const
 
const PostprocessorValue & getPostprocessorValueOld (const std::string &param_name, const unsigned int index=0) const
 
const PostprocessorValue & getPostprocessorValueOld (const std::string &param_name, const unsigned int index=0) const
 
const PostprocessorValue & getPostprocessorValueOlder (const std::string &param_name, const unsigned int index=0) const
 
const PostprocessorValue & getPostprocessorValueOlder (const std::string &param_name, const unsigned int index=0) const
 
virtual const PostprocessorValue & getPostprocessorValueByName (const PostprocessorName &name) const
 
virtual const PostprocessorValue & getPostprocessorValueByName (const PostprocessorName &name) const
 
const PostprocessorValue & getPostprocessorValueOldByName (const PostprocessorName &name) const
 
const PostprocessorValue & getPostprocessorValueOldByName (const PostprocessorName &name) const
 
const PostprocessorValue & getPostprocessorValueOlderByName (const PostprocessorName &name) const
 
const PostprocessorValue & getPostprocessorValueOlderByName (const PostprocessorName &name) const
 
bool isVectorPostprocessorDistributed (const std::string &param_name) const
 
bool isVectorPostprocessorDistributed (const std::string &param_name) const
 
bool isVectorPostprocessorDistributedByName (const VectorPostprocessorName &name) const
 
bool isVectorPostprocessorDistributedByName (const VectorPostprocessorName &name) const
 
const Distribution & getDistribution (const std::string &name) const
 
const T & getDistribution (const std::string &name) const
 
const Distribution & getDistribution (const std::string &name) const
 
const T & getDistribution (const std::string &name) const
 
const Distribution & getDistributionByName (const DistributionName &name) const
 
const T & getDistributionByName (const std::string &name) const
 
const Distribution & getDistributionByName (const DistributionName &name) const
 
const T & getDistributionByName (const std::string &name) const
 
const Parallel::Communicator & comm () const
 
processor_id_type n_processors () const
 
processor_id_type processor_id () const
 
const std::string & modelMetaDataName () const
 Accessor for the name of the model meta data. More...
 
const FileName & getModelDataFileName () const
 Get the associated filename. More...
 
bool hasModelData () const
 Check if we need to load model data (if the filename parameter is used) More...
 
template<>
SurrogateModel & getSurrogateModel (const std::string &name) const
 
template<>
SurrogateTrainerBase & getSurrogateTrainer (const std::string &name) const
 
template<>
SurrogateModel & getSurrogateModelByName (const UserObjectName &name) const
 
template<>
SurrogateTrainerBase & getSurrogateTrainerByName (const UserObjectName &name) const
 
template<typename T , typename... Args>
T & declareModelData (const std::string &data_name, Args &&... args)
 Declare model data for loading from file as well as restart. More...
 
template<typename T , typename... Args>
const T & getModelData (const std::string &data_name, Args &&... args) const
 Retrieve model data from the interface. More...
 
template<typename T = SurrogateModel>
T & getSurrogateModel (const std::string &name) const
 Get a SurrogateModel/Trainer with a given name. More...
 
template<typename T = SurrogateTrainerBase>
T & getSurrogateTrainer (const std::string &name) const
 
template<typename T = SurrogateModel>
T & getSurrogateModelByName (const UserObjectName &name) const
 Get a SurrogateModel with a given name. More...
 
template<typename T = SurrogateTrainerBase>
T & getSurrogateTrainerByName (const UserObjectName &name) const
 

Static Public Member Functions

static InputParameters validParams ()
 
static void sort (typename std::vector< T > &vector)
 
static void sortDFS (typename std::vector< T > &vector)
 
static void cyclicDependencyError (CyclicDependencyException< T2 > &e, const std::string &header)
 

Public Attributes

const ConsoleStream _console
 

Static Public Attributes

static constexpr PropertyValue::id_type default_property_id
 
static constexpr PropertyValue::id_type zero_property_id
 
static constexpr auto SYSTEM
 
static constexpr auto NAME
 

Protected Member Functions

template<typename T >
const T & getTrainingData (const ReporterName &rname)
 
const std::vector< Real > & getSamplerData () const
 
const std::vector< Real > & getPredictorData () const
 
unsigned int getCurrentSampleSize () const
 
unsigned int getLocalSampleSize () const
 
virtual std::vector< Real > evaluateModelError (const SurrogateModel &surr)
 
virtual void addPostprocessorDependencyHelper (const PostprocessorName &name) const override
 
virtual void addVectorPostprocessorDependencyHelper (const VectorPostprocessorName &name) const override
 
virtual void addUserObjectDependencyHelper (const UserObject &uo) const override
 
void addReporterDependencyHelper (const ReporterName &reporter_name) override
 
const ReporterName & getReporterName (const std::string &param_name) const
 
T & declareRestartableData (const std::string &data_name, Args &&... args)
 
ManagedValue< T > declareManagedRestartableDataWithContext (const std::string &data_name, void *context, Args &&... args)
 
const T & getRestartableData (const std::string &data_name) const
 
T & declareRestartableDataWithContext (const std::string &data_name, void *context, Args &&... args)
 
T & declareRecoverableData (const std::string &data_name, Args &&... args)
 
T & declareRestartableDataWithObjectName (const std::string &data_name, const std::string &object_name, Args &&... args)
 
T & declareRestartableDataWithObjectNameWithContext (const std::string &data_name, const std::string &object_name, void *context, Args &&... args)
 
std::string restartableName (const std::string &data_name) const
 
const T & getMeshProperty (const std::string &data_name, const std::string &prefix)
 
const T & getMeshProperty (const std::string &data_name)
 
bool hasMeshProperty (const std::string &data_name, const std::string &prefix) const
 
bool hasMeshProperty (const std::string &data_name, const std::string &prefix) const
 
bool hasMeshProperty (const std::string &data_name) const
 
bool hasMeshProperty (const std::string &data_name) const
 
std::string meshPropertyName (const std::string &data_name) const
 
PerfID registerTimedSection (const std::string &section_name, const unsigned int level) const
 
PerfID registerTimedSection (const std::string &section_name, const unsigned int level, const std::string &live_message, const bool print_dots=true) const
 
std::string timedSectionName (const std::string &section_name) const
 
bool isCoupledScalar (const std::string &var_name, unsigned int i=0) const
 
unsigned int coupledScalarComponents (const std::string &var_name) const
 
unsigned int coupledScalar (const std::string &var_name, unsigned int comp=0) const
 
libMesh::Order coupledScalarOrder (const std::string &var_name, unsigned int comp=0) const
 
const VariableValue & coupledScalarValue (const std::string &var_name, unsigned int comp=0) const
 
const ADVariableValue & adCoupledScalarValue (const std::string &var_name, unsigned int comp=0) const
 
const GenericVariableValue< is_ad > & coupledGenericScalarValue (const std::string &var_name, unsigned int comp=0) const
 
const GenericVariableValue< false > & coupledGenericScalarValue (const std::string &var_name, const unsigned int comp) const
 
const GenericVariableValue< true > & coupledGenericScalarValue (const std::string &var_name, const unsigned int comp) const
 
const VariableValue & coupledVectorTagScalarValue (const std::string &var_name, TagID tag, unsigned int comp=0) const
 
const VariableValue & coupledMatrixTagScalarValue (const std::string &var_name, TagID tag, unsigned int comp=0) const
 
const VariableValue & coupledScalarValueOld (const std::string &var_name, unsigned int comp=0) const
 
const VariableValue & coupledScalarValueOlder (const std::string &var_name, unsigned int comp=0) const
 
const VariableValue & coupledScalarDot (const std::string &var_name, unsigned int comp=0) const
 
const ADVariableValue & adCoupledScalarDot (const std::string &var_name, unsigned int comp=0) const
 
const VariableValue & coupledScalarDotDot (const std::string &var_name, unsigned int comp=0) const
 
const VariableValue & coupledScalarDotOld (const std::string &var_name, unsigned int comp=0) const
 
const VariableValue & coupledScalarDotDotOld (const std::string &var_name, unsigned int comp=0) const
 
const VariableValue & coupledScalarDotDu (const std::string &var_name, unsigned int comp=0) const
 
const VariableValue & coupledScalarDotDotDu (const std::string &var_name, unsigned int comp=0) const
 
const MooseVariableScalar * getScalarVar (const std::string &var_name, unsigned int comp) const
 
virtual void checkMaterialProperty (const std::string &name, const unsigned int state)
 
void markMatPropRequested (const std::string &)
 
MaterialPropertyName getMaterialPropertyName (const std::string &name) const
 
void checkExecutionStage ()
 
const T & getReporterValue (const std::string &param_name, const std::size_t time_index=0)
 
const T & getReporterValue (const std::string &param_name, ReporterMode mode, const std::size_t time_index=0)
 
const T & getReporterValue (const std::string &param_name, const std::size_t time_index=0)
 
const T & getReporterValue (const std::string &param_name, ReporterMode mode, const std::size_t time_index=0)
 
const T & getReporterValueByName (const ReporterName &reporter_name, const std::size_t time_index=0)
 
const T & getReporterValueByName (const ReporterName &reporter_name, ReporterMode mode, const std::size_t time_index=0)
 
const T & getReporterValueByName (const ReporterName &reporter_name, const std::size_t time_index=0)
 
const T & getReporterValueByName (const ReporterName &reporter_name, ReporterMode mode, const std::size_t time_index=0)
 
bool hasReporterValue (const std::string &param_name) const
 
bool hasReporterValue (const std::string &param_name) const
 
bool hasReporterValue (const std::string &param_name) const
 
bool hasReporterValue (const std::string &param_name) const
 
bool hasReporterValueByName (const ReporterName &reporter_name) const
 
bool hasReporterValueByName (const ReporterName &reporter_name) const
 
bool hasReporterValueByName (const ReporterName &reporter_name) const
 
bool hasReporterValueByName (const ReporterName &reporter_name) const
 
const GenericMaterialProperty< T, is_ad > * defaultGenericMaterialProperty (const std::string &name)
 
const GenericMaterialProperty< T, is_ad > * defaultGenericMaterialProperty (const std::string &name)
 
const MaterialProperty< T > * defaultMaterialProperty (const std::string &name)
 
const MaterialProperty< T > * defaultMaterialProperty (const std::string &name)
 
const ADMaterialProperty< T > * defaultADMaterialProperty (const std::string &name)
 
const ADMaterialProperty< T > * defaultADMaterialProperty (const std::string &name)
 

Static Protected Member Functions

static std::string meshPropertyName (const std::string &data_name, const std::string &prefix)
 

Protected Attributes

SubProblem & _subproblem
 
FEProblemBase & _fe_problem
 
SystemBase & _sys
 
const THREAD_ID _tid
 
Assembly & _assembly
 
const Moose::CoordinateSystemType & _coord_sys
 
const bool _duplicate_initial_execution
 
std::set< std::string > _depend_uo
 
const bool & _enabled
 
MooseApp & _app
 
const std::string _type
 
const std::string _name
 
const InputParameters & _pars
 
Factory & _factory
 
ActionFactory & _action_factory
 
const ExecFlagEnum & _execute_enum
 
const ExecFlagType & _current_execute_flag
 
MooseApp & _restartable_app
 
const std::string _restartable_system_name
 
const THREAD_ID _restartable_tid
 
const bool _restartable_read_only
 
FEProblemBase & _mci_feproblem
 
MooseApp & _pg_moose_app
 
const std::string _prefix
 
FEProblemBase & _sc_fe_problem
 
const THREAD_ID _sc_tid
 
const Real & _real_zero
 
const VariableValue & _scalar_zero
 
const Point & _point_zero
 
const InputParameters & _mi_params
 
const std::string _mi_name
 
const MooseObjectName _mi_moose_object_name
 
FEProblemBase & _mi_feproblem
 
SubProblem & _mi_subproblem
 
const THREAD_ID _mi_tid
 
const Moose::MaterialDataType _material_data_type
 
MaterialData & _material_data
 
bool _stateful_allowed
 
bool _get_material_property_called
 
std::vector< std::unique_ptr< PropertyValue > > _default_properties
 
std::unordered_set< unsigned int > _material_property_dependencies
 
const MaterialPropertyName _get_suffix
 
const bool _use_interpolated_state
 
const InputParameters & _ti_params
 
FEProblemBase & _ti_feproblem
 
bool _is_implicit
 
Real & _t
 
const Real & _t_old
 
int & _t_step
 
Real & _dt
 
Real & _dt_old
 
bool _is_transient
 
const Parallel::Communicator & _communicator
 
Sampler & _sampler
 
dof_id_type _row
 During training loop, this is the row index of the data. More...
 
dof_id_type _local_row
 During training loop, this is the local row index of the data. More...
 
const Real * _rval
 Response value. More...
 
const std::vector< Real > * _rvecval
 Vector response value. More...
 
std::vector< const Real * > _pvals
 Predictor values from reporters. More...
 
std::vector< unsigned int > _pcols
 Columns from sampler for predictors. More...
 
unsigned int _n_dims
 Dimension of predictor data - either _sampler.getNumberOfCols() or _pvals.size() + _pcols.size(). More...
 
unsigned int & _n_outputs
 The number of outputs. More...
 

Static Protected Attributes

static const std::string _interpolated_old
 
static const std::string _interpolated_older
 

Private Attributes

const std::vector< Real > & _predictor_row
 Data from the current predictor row. More...
 
std::vector< Real > _flattened_data
 The gathered data in a flattened form to be able to convert easily to torch::Tensor. More...
 
std::vector< Real > _flattened_response
 The gathered response in a flattened form to be able to convert easily to torch::Tensor. More...
 
std::vector< unsigned int > & _num_neurons_per_layer
 Number of neurons within the hidden layers (the length of this vector should be the same as _num_hidden_layers) More...
 
std::vector< std::string > & _activation_function
 Activation functions for each hidden layer. More...
 
const std::string _nn_filename
 Name of the pytorch output file. More...
 
const bool _read_from_file
 Switch indicating if an already existing neural net should be read from a file or not. More...
 
Moose::LibtorchTrainingOptions _optim_options
 The struct which contains the information for the training of the neural net. More...
 
std::shared_ptr< Moose::LibtorchArtificialNeuralNet > & _nn
 Pointer to the neural net object (initialized as null) More...
 
const bool _standardize_input
 If the training input should be standardized (scaled and shifted) More...
 
const bool _standardize_output
 If the training output should be standardized (scaled and shifted) More...
 
StochasticTools::Standardizer & _input_standardizer
 Standardizer for use with input (x) More...
 
StochasticTools::Standardizer & _output_standardizer
 Standardizer for use with output response (y) More...
 

Detailed Description

Trainer responsible for fitting a neural network on predefined data.

Definition at line 23 of file LibtorchANNTrainer.h.

Constructor & Destructor Documentation

◆ LibtorchANNTrainer()

LibtorchANNTrainer::LibtorchANNTrainer ( const InputParameters & parameters)

Construct using input parameters.

Definition at line 66 of file LibtorchANNTrainer.C.

69  _num_neurons_per_layer(declareModelData<std::vector<unsigned int>>(
70  "num_neurons_per_layer", getParam<std::vector<unsigned int>>("num_neurons_per_layer"))),
71  _activation_function(declareModelData<std::vector<std::string>>(
72  "activation_function", getParam<std::vector<std::string>>("activation_function"))),
73  _nn_filename(getParam<std::string>("nn_filename")),
74  _read_from_file(getParam<bool>("read_from_file")),
75  _nn(declareModelData<std::shared_ptr<Moose::LibtorchArtificialNeuralNet>>("nn")),
76  _standardize_input(getParam<bool>("standardize_input")),
77  _standardize_output(getParam<bool>("standardize_output")),
78  _input_standardizer(declareModelData<StochasticTools::Standardizer>("input_standardizer")),
79  _output_standardizer(declareModelData<StochasticTools::Standardizer>("output_standardizer"))
80 {
81  // Fixing the RNG seed to make sure every experiment is the same.
82  // Otherwise sampling / stochastic gradient descent would be different.
83  torch::manual_seed(getParam<unsigned int>("seed"));
84 
86  _optim_options.learning_rate = getParam<Real>("learning_rate");
87  _optim_options.num_epochs = getParam<unsigned int>("num_epochs");
88  _optim_options.num_batches = getParam<unsigned int>("num_batches");
89  _optim_options.rel_loss_tol = getParam<Real>("rel_loss_tol");
90  _optim_options.print_loss = getParam<unsigned int>("print_epoch_loss") > 0;
91  _optim_options.print_epoch_loss = getParam<unsigned int>("print_epoch_loss");
92  _optim_options.parallel_processes = getParam<unsigned int>("max_processes");
93 }
std::vector< unsigned int > & _num_neurons_per_layer
Number of neurons within the hidden layers (the length of this vector should be the same as _num_hidd...
const bool _standardize_output
If the training output should be standardized (scaled and shifted)
const std::vector< Real > & _predictor_row
Data from the current predictor row.
const bool _standardize_input
If the training input should be standardized (scaled and shifted)
const T & getParam(const std::string &name) const
T & declareModelData(const std::string &data_name, Args &&... args)
Declare model data for loading from file as well as restart.
const std::string _nn_filename
Name of the pytorch output file.
Moose::LibtorchTrainingOptions _optim_options
The struct which contains the information for the training of the neural net.
const bool _read_from_file
Switch indicating if an already existing neural net should be read from a file or not...
std::vector< std::string > & _activation_function
Activation functions for each hidden layer.
StochasticTools::Standardizer & _output_standardizer
Standardizer for use with output response (y)
const std::vector< Real > & getPredictorData() const
const InputParameters & parameters() const
std::shared_ptr< Moose::LibtorchArtificialNeuralNet > & _nn
Pointer to the neural net object (initialized as null)
SurrogateTrainer(const InputParameters &parameters)
StochasticTools::Standardizer & _input_standardizer
Standardizer for use with input (x)

Member Function Documentation

◆ declareModelData()

template<typename T , typename... Args>
T & RestartableModelInterface::declareModelData ( const std::string &  data_name,
Args &&...  args 
)
inherited

Declare model data for loading from file as well as restart.

Definition at line 78 of file RestartableModelInterface.h.

79 {
80  return _model_restartable.declareRestartableData<T>(data_name, std::forward<Args>(args)...);
81 }
T & declareRestartableData(const std::string &data_name, Args &&... args)
Declare a piece of data as "restartable" and initialize it.
PublicRestartable _model_restartable
Member for interfacing with the framework's restartable system.

◆ evaluateModelError()

std::vector< Real > SurrogateTrainer::evaluateModelError ( const SurrogateModel surr)
protectedvirtualinherited

Definition at line 347 of file SurrogateTrainer.C.

Referenced by SurrogateTrainer::crossValidate().

348 {
349  std::vector<Real> error(1, 0.0);
350 
351  if (_rval)
352  {
353  Real model_eval = surr.evaluate(_predictor_data);
354  error[0] = MathUtils::pow(model_eval - (*_rval), 2);
355  }
356  else if (_rvecval)
357  {
358  error.resize(_rvecval->size());
359 
360  // Evaluate for vector response.
361  std::vector<Real> model_eval(error.size());
362  surr.evaluate(_predictor_data, model_eval);
363  for (auto r : make_range(_rvecval->size()))
364  error[r] = MathUtils::pow(model_eval[r] - (*_rvecval)[r], 2);
365  }
366 
367  return error;
368 }
const Real * _rval
Response value.
const std::vector< Real > * _rvecval
Vector response value.
std::vector< Real > _predictor_data
Predictor data for current row - can be combination of Sampler and Reporter values.
virtual Real evaluate(const std::vector< Real > &x) const
Evaluate surrogate model given a row of parameters.
DIE A HORRIBLE DEATH HERE typedef LIBMESH_DEFAULT_SCALAR_TYPE Real
IntRange< T > make_range(T beg, T end)
T pow(T x, int e)

◆ execute()

void SurrogateTrainer::execute ( )
finalvirtualinherited

Implements GeneralUserObject.

Definition at line 176 of file SurrogateTrainer.C.

177 {
178  if (_doing_cv)
179  for (const auto & trial : make_range(_cv_n_trials))
180  {
181  std::vector<Real> trial_score = crossValidate();
182 
183  // Expand _cv_trial_scores with more columns if necessary, then insert values.
184  for (unsigned int r = _cv_trial_scores.size(); r < trial_score.size(); ++r)
185  _cv_trial_scores.push_back(std::vector<Real>(_cv_n_trials, 0.0));
186  for (auto r : make_range(trial_score.size()))
187  _cv_trial_scores[r][trial] = trial_score[r];
188  }
189 
192  executeTraining();
193 }
const bool _doing_cv
Set to true if cross validation is being performed, controls behavior in execute().
const unsigned int & _cv_n_trials
Number of repeated trials of cross validation to perform.
std::vector< std::vector< Real > > & _cv_trial_scores
RMSE scores from each CV trial - can be grabbed by VPP or Reporter.
dof_id_type getNumberOfLocalRows() const
std::vector< Real > crossValidate()
dof_id_type getNumberOfRows() const
IntRange< T > make_range(T beg, T end)
unsigned int _local_sample_size
Number of samples (locally) used to train the model.
unsigned int _current_sample_size
Number of samples used to train the model.

◆ finalize()

virtual void SurrogateTrainer::finalize ( )
inlinefinalvirtualinherited

Reimplemented from SurrogateTrainerBase.

Definition at line 63 of file SurrogateTrainer.h.

63 {}

◆ getCurrentSampleSize()

unsigned int SurrogateTrainer::getCurrentSampleSize ( ) const
inlineprotectedinherited

Definition at line 102 of file SurrogateTrainer.h.

Referenced by PolynomialChaosTrainer::postTrain(), and PolynomialChaosTrainer::preTrain().

102 { return _current_sample_size; };
unsigned int _current_sample_size
Number of samples used to train the model.

◆ getLocalSampleSize()

unsigned int SurrogateTrainer::getLocalSampleSize ( ) const
inlineprotectedinherited

Definition at line 107 of file SurrogateTrainer.h.

Referenced by NearestPointTrainer::preTrain(), GaussianProcessTrainer::preTrain(), and preTrain().

107 { return _local_sample_size; };
unsigned int _local_sample_size
Number of samples (locally) used to train the model.

◆ getModelData()

template<typename T , typename... Args>
const T & RestartableModelInterface::getModelData ( const std::string &  data_name,
Args &&...  args 
) const
inherited

Retrieve model data from the interface.

Definition at line 85 of file RestartableModelInterface.h.

86 {
87  return _model_restartable.getRestartableData<T>(data_name, std::forward<Args>(args)...);
88 }
const T & getRestartableData(const std::string &data_name) const
Declare a piece of data as "restartable" and initialize it. Similar to declareRestartableData but retu...
PublicRestartable _model_restartable
Member for interfacing with the framework's restartable system.

◆ getModelDataFileName()

const FileName & RestartableModelInterface::getModelDataFileName ( ) const
inherited

Get the associated filename.

Definition at line 33 of file RestartableModelInterface.C.

34 {
35  return _model_object.getParam<FileName>("filename");
36 }
const T & getParam(const std::string &name) const
const MooseObject & _model_object
Reference to the MooseObject that uses this interface.

◆ getPredictorData()

const std::vector<Real>& SurrogateTrainer::getPredictorData ( ) const
inlineprotectedinherited

Definition at line 97 of file SurrogateTrainer.h.

97 { return _predictor_data; };
std::vector< Real > _predictor_data
Predictor data for current row - can be combination of Sampler and Reporter values.

◆ getSamplerData()

const std::vector<Real>& SurrogateTrainer::getSamplerData ( ) const
inlineprotectedinherited

Definition at line 92 of file SurrogateTrainer.h.

92 { return _row_data; };
std::vector< Real > _row_data
Sampler data for the current row.

◆ getSurrogateModel() [1/2]

template<>
SurrogateModel& SurrogateModelInterface::getSurrogateModel ( const std::string &  name) const
inherited

Definition at line 46 of file SurrogateModelInterface.C.

47 {
48  return getSurrogateModelByName<SurrogateModel>(_smi_params.get<UserObjectName>(name));
49 }
const InputParameters & _smi_params
Parameters of the object with this interface.
std::vector< std::pair< R1, R2 > > get(const std::string &param1, const std::string &param2) const
const std::string name
Definition: Setup.h:20

◆ getSurrogateModel() [2/2]

template<typename T >
T & SurrogateModelInterface::getSurrogateModel ( const std::string &  name) const
inherited

Get a SurrogateModel/Trainer with a given name.

Parameters
name: The name of the parameter key of the surrogate model to retrieve
Returns
The surrogate model whose name is given by the parameter 'name'

Definition at line 81 of file SurrogateModelInterface.h.

Referenced by SurrogateTrainer::initialize().

82 {
83  return getSurrogateModelByName<T>(_smi_params.get<UserObjectName>(name));
84 }
const InputParameters & _smi_params
Parameters of the object with this interface.
std::vector< std::pair< R1, R2 > > get(const std::string &param1, const std::string &param2) const
const std::string name
Definition: Setup.h:20

◆ getSurrogateModelByName() [1/2]

template<>
SurrogateModel& SurrogateModelInterface::getSurrogateModelByName ( const UserObjectName &  name) const
inherited

Definition at line 31 of file SurrogateModelInterface.C.

32 {
33  std::vector<SurrogateModel *> models;
35  .query()
36  .condition<AttribName>(name)
37  .condition<AttribSystem>("SurrogateModel")
38  .queryInto(models);
39  if (models.empty())
40  mooseError("Unable to find a SurrogateModel object with the name '" + name + "'");
41  return *(models[0]);
42 }
void mooseError(Args &&... args)
FEProblemBase & _smi_feproblem
Reference to FEProblemBase instance.
TheWarehouse & theWarehouse() const
const std::string name
Definition: Setup.h:20
Query query()

◆ getSurrogateModelByName() [2/2]

template<typename T >
T & SurrogateModelInterface::getSurrogateModelByName ( const UserObjectName &  name) const
inherited

Get a surrogate model with a given name.

Parameters
name: The name of the surrogate model to retrieve
Returns
The surrogate model with name 'name'

Definition at line 88 of file SurrogateModelInterface.h.

Referenced by CrossValidationScores::CrossValidationScores(), EvaluateSurrogate::EvaluateSurrogate(), and InverseMapping::initialSetup().

89 {
90  std::vector<T *> models;
92  .query()
93  .condition<AttribName>(name)
94  .condition<AttribSystem>("SurrogateModel")
95  .queryInto(models);
96  if (models.empty())
97  mooseError("Unable to find a SurrogateModel object of type " + std::string(typeid(T).name()) +
98  " with the name '" + name + "'");
99  return *(models[0]);
100 }
void mooseError(Args &&... args)
FEProblemBase & _smi_feproblem
Reference to FEProblemBase instance.
TheWarehouse & theWarehouse() const
const std::string name
Definition: Setup.h:20
Query query()

◆ getSurrogateTrainer() [1/2]

template<typename T >
T & SurrogateModelInterface::getSurrogateTrainer ( const std::string &  name) const
inherited

Definition at line 104 of file SurrogateModelInterface.h.

105 {
106  return getSurrogateTrainerByName<T>(_smi_params.get<UserObjectName>(name));
107 }
const InputParameters & _smi_params
Parameters of the object with this interface.
std::vector< std::pair< R1, R2 > > get(const std::string &param1, const std::string &param2) const
const std::string name
Definition: Setup.h:20

◆ getSurrogateTrainer() [2/2]

template<>
SurrogateTrainerBase& SurrogateModelInterface::getSurrogateTrainer ( const std::string &  name) const
inherited

Definition at line 60 of file SurrogateModelInterface.C.

61 {
62  return getSurrogateTrainerByName<SurrogateTrainerBase>(_smi_params.get<UserObjectName>(name));
63 }
const InputParameters & _smi_params
Parameters of the object with this interface.
std::vector< std::pair< R1, R2 > > get(const std::string &param1, const std::string &param2) const
const std::string name
Definition: Setup.h:20

◆ getSurrogateTrainerByName() [1/2]

template<>
SurrogateTrainerBase& SurrogateModelInterface::getSurrogateTrainerByName ( const UserObjectName &  name) const
inherited

Definition at line 53 of file SurrogateModelInterface.C.

54 {
56 }
T & getUserObject(const std::string &name, unsigned int tid=0) const
FEProblemBase & _smi_feproblem
Reference to FEProblemBase instance.
const std::string name
Definition: Setup.h:20
This is the base trainer class whose main functionality is the API for declaring model data...

◆ getSurrogateTrainerByName() [2/2]

template<typename T >
T & SurrogateModelInterface::getSurrogateTrainerByName ( const UserObjectName &  name) const
inherited

Definition at line 111 of file SurrogateModelInterface.h.

Referenced by SurrogateTrainerOutput::output().

112 {
113  SurrogateTrainerBase * base_ptr =
115  T * obj_ptr = dynamic_cast<T *>(base_ptr);
116  if (!obj_ptr)
117  mooseError("Failed to find a SurrogateTrainer object of type " + std::string(typeid(T).name()) +
118  " with the name '",
119  name,
120  "' for the desired type.");
121  return *obj_ptr;
122 }
T & getUserObject(const std::string &name, unsigned int tid=0) const
void mooseError(Args &&... args)
FEProblemBase & _smi_feproblem
Reference to FEProblemBase instance.
const std::string name
Definition: Setup.h:20
const THREAD_ID _smi_tid
Thread ID.
This is the base trainer class whose main functionality is the API for declaring model data...

◆ getTrainingData()

template<typename T >
const T & SurrogateTrainer::getTrainingData ( const ReporterName rname)
protectedinherited

Definition at line 208 of file SurrogateTrainer.h.

209 {
210  auto it = _training_data.find(rname);
211  if (it != _training_data.end())
212  {
213  auto data = std::dynamic_pointer_cast<TrainingData<T>>(it->second);
214  if (!data)
215  mooseError("Reporter value ", rname, " already exists but is of different type.");
216  return data->get();
217  }
218  else
219  {
220  const std::vector<T> & rval = getReporterValueByName<std::vector<T>>(rname);
221  _training_data[rname] = std::make_shared<TrainingData<T>>(rval);
222  return std::dynamic_pointer_cast<TrainingData<T>>(_training_data[rname])->get();
223  }
224 }
std::unordered_map< ReporterName, std::shared_ptr< TrainingDataBase > > _training_data
Map of reporter names to their corresponding values (to be filled by getTrainingData) ...
void mooseError(Args &&... args) const

◆ hasModelData()

bool RestartableModelInterface::hasModelData ( ) const
inherited

Check if we need to load model data (if the filename parameter is used)

Definition at line 39 of file RestartableModelInterface.C.

40 {
41  return _model_object.isParamValid("filename");
42 }
bool isParamValid(const std::string &name) const
const MooseObject & _model_object
Reference to the MooseObject that uses this interface.

◆ initialize()

void SurrogateTrainer::initialize ( )
finalvirtualinherited

Reimplemented from SurrogateTrainerBase.

Definition at line 153 of file SurrogateTrainer.C.

154 {
155  // Figure out if data is distributed
156  for (auto & pair : _training_data)
157  {
158  const ReporterName & name = pair.first;
159  TrainingDataBase & data = *pair.second;
160 
161  const auto & mode = _fe_problem.getReporterData().getReporterMode(name);
162  if (mode == REPORTER_MODE_DISTRIBUTED || (mode == REPORTER_MODE_ROOT && processor_id() != 0))
163  data.isDistributed() = true;
164  else if (mode == REPORTER_MODE_REPLICATED ||
165  (mode == REPORTER_MODE_ROOT && processor_id() == 0))
166  data.isDistributed() = false;
167  else
168  mooseError("Predictor reporter value ", name, " is not of supported mode.");
169  }
170 
171  if (_doing_cv)
172  _cv_surrogate = &getSurrogateModel("cv_surrogate");
173 }
const bool _doing_cv
Set to true if cross validation is being performed, controls behavior in execute().
T & getSurrogateModel(const std::string &name) const
Get a SurrogateModel/Trainer with a given name.
virtual const std::string & name() const
const ReporterData & getReporterData() const
std::unordered_map< ReporterName, std::shared_ptr< TrainingDataBase > > _training_data
Map of reporter names to their corresponding values (to be filled by getTrainingData) ...
const ReporterProducerEnum & getReporterMode(const ReporterName &reporter_name) const
FEProblemBase & _fe_problem
void mooseError(Args &&... args) const
const SurrogateModel * _cv_surrogate
SurrogateModel used to evaluate model error relative to test points.
processor_id_type processor_id() const

◆ modelMetaDataName()

const std::string& RestartableModelInterface::modelMetaDataName ( ) const
inlineinherited

Accessor for the name of the model meta data.

Definition at line 47 of file RestartableModelInterface.h.

Referenced by SurrogateTrainerOutput::output(), and MappingOutput::output().

47 { return _model_meta_data_name; }
const std::string _model_meta_data_name
The model meta data name.

◆ postTrain()

void LibtorchANNTrainer::postTrain ( )
overridevirtual

Contains processes which are executed after the training loop.

Reimplemented from SurrogateTrainer.

Definition at line 115 of file LibtorchANNTrainer.C.

116 {
119 
120  // Then, we create and load our Tensors
121  unsigned int num_samples = _flattened_response.size();
122  unsigned int num_inputs = _n_dims;
123 
124  // We create a neural net (for the definition of the net see the header file)
125  _nn = std::make_shared<Moose::LibtorchArtificialNeuralNet>(
127 
128  if (_read_from_file)
129  try
130  {
131  torch::load(_nn, _nn_filename);
132  _console << "Loaded requested .pt file." << std::endl;
133  }
134  catch (const c10::Error & e)
135  {
136  mooseError("The requested pytorch file could not be loaded.\n", e.msg());
137  }
138 
139  // The default data type in pytorch is float, while we use double in MOOSE.
140  // Therefore, in some cases we have to convert Tensors to double.
141  auto options = torch::TensorOptions().dtype(at::kDouble);
142  torch::Tensor data_tensor =
143  torch::from_blob(_flattened_data.data(), {num_samples, num_inputs}, options).to(at::kDouble);
144  torch::Tensor response_tensor =
145  torch::from_blob(_flattened_response.data(), {num_samples, 1}, options).to(at::kDouble);
146 
147  // We standardize the input/output pairs if the user requested it
148  if (_standardize_input)
149  {
150  auto data_std_mean = torch::std_mean(data_tensor, 0);
151  auto & data_std = std::get<0>(data_std_mean);
152  auto & data_mean = std::get<1>(data_std_mean);
153 
154  data_tensor = (data_tensor - data_mean) / data_std;
155 
156  std::vector<Real> converted_data_mean;
157  LibtorchUtils::tensorToVector(data_mean, converted_data_mean);
158  std::vector<Real> converted_data_std;
159  LibtorchUtils::tensorToVector(data_std, converted_data_std);
160  _input_standardizer.set(converted_data_mean, converted_data_std);
161  }
162  else
164 
166  {
167  auto response_std_mean = torch::std_mean(response_tensor, 0);
168  auto & response_std = std::get<0>(response_std_mean);
169  auto & response_mean = std::get<1>(response_std_mean);
170 
171  response_tensor = (response_tensor - response_mean) / response_std;
172 
173  std::vector<Real> converted_response_mean;
174  LibtorchUtils::tensorToVector(response_mean, converted_response_mean);
175  std::vector<Real> converted_response_std;
176  LibtorchUtils::tensorToVector(response_std, converted_response_std);
177  _output_standardizer.set(converted_response_mean, converted_response_std);
178  }
179  else
181 
182  // We create a custom data set from our converted data
183  Moose::LibtorchDataset my_data(data_tensor, response_tensor);
184 
185  // We create atrainer for our neral net and train it with the dataset
187  trainer.train(my_data, _optim_options);
188 }
std::vector< unsigned int > & _num_neurons_per_layer
Number of neurons within the hidden layers (the length of this vector should be the same as _num_hidd...
void allgather(const T &send_data, std::vector< T, A > &recv_data) const
unsigned int _n_dims
Dimension of predictor data - either _sampler.getNumberOfCols() or _pvals.size() + _pcols...
const bool _standardize_output
If the training output should be standardized (scaled and shifted)
void tensorToVector(torch::Tensor &tensor, std::vector< DataType > &vector)
const Parallel::Communicator & comm() const
const Parallel::Communicator & _communicator
std::vector< Real > _flattened_response
The gathered response in a flattened form to be able to convert easily to torch::Tensor.
const bool _standardize_input
If the training input should be standardized (scaled and shifted)
const std::string _nn_filename
Name of the pytorch output file.
Moose::LibtorchTrainingOptions _optim_options
The struct which contains the information for the training of the neural net.
const bool _read_from_file
Switch indicating if an already existing neural net should be read from a file or not...
std::vector< std::string > & _activation_function
Activation functions for each hidden layer.
StochasticTools::Standardizer & _output_standardizer
Standardizer for use with output response (y)
void set(const Real &n)
Methods for setting mean and standard deviation directly. Sets mean=0, std=1 for n variables...
Definition: Standardizer.C:16
void mooseError(Args &&... args) const
std::shared_ptr< Moose::LibtorchArtificialNeuralNet > & _nn
Pointer to the neural net object (initialized as null)
const ConsoleStream _console
StochasticTools::Standardizer & _input_standardizer
Standardizer for use with input (x)
std::vector< Real > _flattened_data
The gathered data in a flattened form to be able to convert easily to torch::Tensor.

◆ preTrain()

void LibtorchANNTrainer::preTrain ( )
overridevirtual

Contains processes which are executed before the training loop.

Reimplemented from SurrogateTrainer.

Definition at line 96 of file LibtorchANNTrainer.C.

97 {
98  // Resize to number of sample points
99  _flattened_data.clear();
100  _flattened_response.clear();
103 }
unsigned int _n_dims
Dimension of predictor data - either _sampler.getNumberOfCols() or _pvals.size() + _pcols...
std::vector< Real > _flattened_response
The gathered response in a flattened form to be able to convert easily to torch::Tensor.
unsigned int getLocalSampleSize() const
std::vector< Real > _flattened_data
The gathered data in a flattened form to be able to convert easily to torch::Tensor.

◆ threadJoin()

virtual void SurrogateTrainerBase::threadJoin ( const UserObject )
inlinefinalvirtualinherited

Reimplemented from GeneralUserObject.

Definition at line 40 of file SurrogateTrainer.h.

40 {} // GeneralUserObjects are not threaded

◆ train()

void LibtorchANNTrainer::train ( )
overridevirtual

Contains processes which are executed for every sample in the training loop.

Reimplemented from SurrogateTrainer.

Definition at line 106 of file LibtorchANNTrainer.C.

107 {
108  for (auto & p : _predictor_row)
109  _flattened_data.push_back(p);
110 
111  _flattened_response.push_back(*_rval);
112 }
const Real * _rval
Response value.
std::vector< Real > _flattened_response
The gathered response in a flattened form to be able to convert easily to torch::Tensor.
const std::vector< Real > & _predictor_row
Data from the current predictor row.
std::vector< Real > _flattened_data
The gathered data in a flattened form to be able to convert easily to torch::Tensor.

◆ validParams()

InputParameters LibtorchANNTrainer::validParams ( )
static

Definition at line 19 of file LibtorchANNTrainer.C.

20 {
22 
23  params.addClassDescription("Trains a simple neural network using libtorch.");
24 
25  params.addRangeCheckedParam<unsigned int>(
26  "num_batches", 1, "1<=num_batches", "Number of batches.");
27  params.addRangeCheckedParam<unsigned int>(
28  "num_epochs", 1, "0<num_epochs", "Number of training epochs.");
29  params.addRangeCheckedParam<Real>(
30  "rel_loss_tol",
31  0,
32  "0<=rel_loss_tol<=1",
33  "The relative loss where we stop the training of the neural net.");
34  params.addParam<std::vector<unsigned int>>(
35  "num_neurons_per_layer", std::vector<unsigned int>(), "Number of neurons per layer.");
36  params.addParam<std::vector<std::string>>(
37  "activation_function",
38  std::vector<std::string>({"relu"}),
39  "The type of activation functions to use. It is either one value "
40  "or one value per hidden layer.");
41  params.addParam<std::string>(
42  "nn_filename", "net.pt", "Filename used to output the neural net parameters.");
43  params.addParam<bool>("read_from_file",
44  false,
45  "Switch to allow reading old trained neural nets for further training.");
46  params.addParam<Real>("learning_rate", 0.001, "Learning rate (relaxation).");
47  params.addRangeCheckedParam<unsigned int>(
48  "print_epoch_loss",
49  0,
50  "0<=print_epoch_loss",
51  "Epoch training loss printing. 0 - no printing, 1 - every epoch, 10 - every 10th epoch.");
52  params.addParam<unsigned int>(
53  "seed", 11, "Random number generator seed for stochastic optimizers.");
54  params.addParam<unsigned int>(
55  "max_processes", 1, "The maximum number of parallel processes that the trainer will use.");
56 
57  params.addParam<bool>(
58  "standardize_input", true, "Standardize (center and scale) training inputs (x values)");
59  params.addParam<bool>(
60  "standardize_output", true, "Standardize (center and scale) training outputs (y values)");
61 
62  params.suppressParameter<MooseEnum>("response_type");
63  return params;
64 }
void addParam(const std::string &name, const std::initializer_list< typename T::value_type > &value, const std::string &doc_string)
DIE A HORRIBLE DEATH HERE typedef LIBMESH_DEFAULT_SCALAR_TYPE Real
void addClassDescription(const std::string &doc_string)
void addRangeCheckedParam(const std::string &name, const T &value, const std::string &parsed_function, const std::string &doc_string)
static InputParameters validParams()

Member Data Documentation

◆ _activation_function

std::vector<std::string>& LibtorchANNTrainer::_activation_function
private

Activation functions for each hidden layer.

Definition at line 55 of file LibtorchANNTrainer.h.

Referenced by postTrain().

◆ _flattened_data

std::vector<Real> LibtorchANNTrainer::_flattened_data
private

The gathered data in a flattened form to be able to convert easily to torch::Tensor.

Definition at line 45 of file LibtorchANNTrainer.h.

Referenced by postTrain(), preTrain(), and train().

◆ _flattened_response

std::vector<Real> LibtorchANNTrainer::_flattened_response
private

The gathered response in a flattened form to be able to convert easily to torch::Tensor.

Definition at line 48 of file LibtorchANNTrainer.h.

Referenced by postTrain(), preTrain(), and train().

◆ _input_standardizer

StochasticTools::Standardizer& LibtorchANNTrainer::_input_standardizer
private

Standardizer for use with input (x)

Definition at line 79 of file LibtorchANNTrainer.h.

Referenced by postTrain().

◆ _local_row

dof_id_type SurrogateTrainer::_local_row
protectedinherited

During training loop, this is the local row index of the data.

Definition at line 123 of file SurrogateTrainer.h.

Referenced by SurrogateTrainer::executeTraining().

◆ _n_dims

unsigned int SurrogateTrainer::_n_dims
protectedinherited

Dimension of predictor data - either _sampler.getNumberOfCols() or _pvals.size() + _pcols.size().

Definition at line 133 of file SurrogateTrainer.h.

Referenced by NearestPointTrainer::NearestPointTrainer(), GaussianProcessTrainer::postTrain(), postTrain(), preTrain(), SurrogateTrainer::SurrogateTrainer(), NearestPointTrainer::train(), and PolynomialRegressionTrainer::train().

◆ _n_outputs

unsigned int& SurrogateTrainer::_n_outputs
protectedinherited

◆ _nn

std::shared_ptr<Moose::LibtorchArtificialNeuralNet>& LibtorchANNTrainer::_nn
private

Pointer to the neural net object (initialized as null)

Definition at line 70 of file LibtorchANNTrainer.h.

Referenced by postTrain().

◆ _nn_filename

const std::string LibtorchANNTrainer::_nn_filename
private

Name of the pytorch output file.

This is used for loading and storing already existing data.

Definition at line 59 of file LibtorchANNTrainer.h.

Referenced by postTrain().

◆ _num_neurons_per_layer

std::vector<unsigned int>& LibtorchANNTrainer::_num_neurons_per_layer
private

Number of neurons within the hidden layers (the length of this vector should be the same as _num_hidden_layers)

Definition at line 52 of file LibtorchANNTrainer.h.

Referenced by postTrain().

◆ _optim_options

Moose::LibtorchTrainingOptions LibtorchANNTrainer::_optim_options
private

The struct which contains the information for the training of the neural net.

Definition at line 67 of file LibtorchANNTrainer.h.

Referenced by LibtorchANNTrainer(), and postTrain().

◆ _output_standardizer

StochasticTools::Standardizer& LibtorchANNTrainer::_output_standardizer
private

Standardizer for use with output response (y)

Definition at line 82 of file LibtorchANNTrainer.h.

Referenced by postTrain().

◆ _pcols

std::vector<unsigned int> SurrogateTrainer::_pcols
protectedinherited

◆ _predictor_row

const std::vector<Real>& LibtorchANNTrainer::_predictor_row
private

Data from the current predictor row.

Definition at line 42 of file LibtorchANNTrainer.h.

Referenced by train().

◆ _pvals

std::vector<const Real *> SurrogateTrainer::_pvals
protectedinherited

◆ _read_from_file

const bool LibtorchANNTrainer::_read_from_file
private

Switch indicating if an already existing neural net should be read from a file or not.

This can be used to load existing torch files (from previous MOOSE or python runs for retraining and further manipulation)

Definition at line 64 of file LibtorchANNTrainer.h.

Referenced by postTrain().

◆ _row

dof_id_type SurrogateTrainer::_row
protectedinherited

During training loop, this is the row index of the data.

Definition at line 121 of file SurrogateTrainer.h.

Referenced by SurrogateTrainer::executeTraining(), and PolynomialChaosTrainer::train().

◆ _rval

const Real* SurrogateTrainer::_rval
protectedinherited

◆ _rvecval

const std::vector<Real>* SurrogateTrainer::_rvecval
protectedinherited

◆ _sampler

Sampler& SurrogateTrainer::_sampler
protectedinherited

◆ _standardize_input

const bool LibtorchANNTrainer::_standardize_input
private

If the training input should be standardized (scaled and shifted)

Definition at line 73 of file LibtorchANNTrainer.h.

Referenced by postTrain().

◆ _standardize_output

const bool LibtorchANNTrainer::_standardize_output
private

If the training output should be standardized (scaled and shifted)

Definition at line 76 of file LibtorchANNTrainer.h.

Referenced by postTrain().


The documentation for this class was generated from the following files: