https://mooseframework.inl.gov
LibtorchDRLControlTrainer Class Reference

This trainer is responsible for training neural networks that efficiently control different processes.

#include <LibtorchDRLControlTrainer.h>

[Inheritance diagram for LibtorchDRLControlTrainer]

Public Types

typedef DataFileName DataFileParameterType
 

Public Member Functions

 LibtorchDRLControlTrainer (const InputParameters &parameters)
 Construct using input parameters.
 
virtual void execute () override
 
Real averageEpisodeReward ()
 Returns the current average episodic reward.
 
void trainController ()
 The condensed training function.
 
const Moose::LibtorchArtificialNeuralNet & controlNeuralNet () const
 
virtual void initialize ()
 
virtual void finalize ()
 
virtual void threadJoin (const UserObject &) final
 
SubProblem & getSubProblem () const
 
bool shouldDuplicateInitialExecution () const
 
virtual Real spatialValue (const Point &) const
 
virtual const std::vector< Point > spatialPoints () const
 
void gatherSum (T &value)
 
void gatherMax (T &value)
 
void gatherMin (T &value)
 
void gatherProxyValueMax (T1 &proxy, T2 &value)
 
void gatherProxyValueMin (T1 &proxy, T2 &value)
 
void setPrimaryThreadCopy (UserObject *primary)
 
UserObject * primaryThreadCopy ()
 
std::set< UserObjectName > getDependObjects () const
 
virtual bool needThreadedCopy () const
 
const std::set< std::string > & getRequestedItems () override
 
const std::set< std::string > & getSuppliedItems () override
 
unsigned int systemNumber () const
 
virtual bool enabled () const
 
std::shared_ptr< MooseObject > getSharedPtr ()

std::shared_ptr< const MooseObject > getSharedPtr () const

MooseApp & getMooseApp () const
 
const std::string & type () const
 
virtual const std::string & name () const
 
std::string typeAndName () const
 
std::string errorPrefix (const std::string &error_type) const
 
void callMooseError (std::string msg, const bool with_prefix) const
 
MooseObjectParameterName uniqueParameterName (const std::string &parameter_name) const
 
const InputParameters & parameters () const
 
MooseObjectName uniqueName () const
 
const T & getParam (const std::string &name) const
 
std::vector< std::pair< T1, T2 > > getParam (const std::string &param1, const std::string &param2) const
 
const T * queryParam (const std::string &name) const
 
const T & getRenamedParam (const std::string &old_name, const std::string &new_name) const
 
getCheckedPointerParam (const std::string &name, const std::string &error_string="") const
 
bool isParamValid (const std::string &name) const
 
bool isParamSetByUser (const std::string &nm) const
 
void paramError (const std::string &param, Args... args) const
 
void paramWarning (const std::string &param, Args... args) const
 
void paramInfo (const std::string &param, Args... args) const
 
void connectControllableParams (const std::string &parameter, const std::string &object_type, const std::string &object_name, const std::string &object_parameter) const
 
void mooseError (Args &&... args) const
 
void mooseErrorNonPrefixed (Args &&... args) const
 
void mooseDocumentedError (const std::string &repo_name, const unsigned int issue_num, Args &&... args) const
 
void mooseWarning (Args &&... args) const
 
void mooseWarningNonPrefixed (Args &&... args) const
 
void mooseDeprecated (Args &&... args) const
 
void mooseInfo (Args &&... args) const
 
std::string getDataFileName (const std::string &param) const
 
std::string getDataFileNameByName (const std::string &relative_path) const
 
std::string getDataFilePath (const std::string &relative_path) const
 
virtual void initialSetup ()
 
virtual void timestepSetup ()
 
virtual void jacobianSetup ()
 
virtual void residualSetup ()
 
virtual void customSetup (const ExecFlagType &)
 
const ExecFlagEnum & getExecuteOnEnum () const
 
UserObjectName getUserObjectName (const std::string &param_name) const
 
const T & getUserObject (const std::string &param_name, bool is_dependency=true) const
 
const T & getUserObjectByName (const UserObjectName &object_name, bool is_dependency=true) const
 
const UserObject & getUserObjectBase (const std::string &param_name, bool is_dependency=true) const

const UserObject & getUserObjectBaseByName (const UserObjectName &object_name, bool is_dependency=true) const
 
const std::vector< MooseVariableScalar *> & getCoupledMooseScalarVars ()
 
const std::set< TagID > & getScalarVariableCoupleableVectorTags () const
 
const std::set< TagID > & getScalarVariableCoupleableMatrixTags () const
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialProperty (const std::string &name, MaterialData &material_data, const unsigned int state=0)
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialProperty (const std::string &name, const unsigned int state=0)
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialProperty (const std::string &name, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialProperty (const std::string &name, MaterialData &material_data, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialProperty (const std::string &name, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialProperty (const std::string &name, const unsigned int state=0)
 
const ADMaterialProperty< T > & getADMaterialProperty (const std::string &name, MaterialData &material_data)
 
const ADMaterialProperty< T > & getADMaterialProperty (const std::string &name)
 
const ADMaterialProperty< T > & getADMaterialProperty (const std::string &name)
 
const MaterialProperty< T > & getMaterialPropertyOld (const std::string &name, MaterialData &material_data)
 
const MaterialProperty< T > & getMaterialPropertyOld (const std::string &name)
 
const MaterialProperty< T > & getMaterialPropertyOld (const std::string &name)
 
const MaterialProperty< T > & getMaterialPropertyOlder (const std::string &name, MaterialData &material_data)
 
const MaterialProperty< T > & getMaterialPropertyOlder (const std::string &name)
 
const MaterialProperty< T > & getMaterialPropertyOlder (const std::string &name)
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialPropertyByName (const MaterialPropertyName &name, MaterialData &material_data, const unsigned int state)
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialPropertyByName (const MaterialPropertyName &name, const unsigned int state=0)
 
const GenericMaterialProperty< T, is_ad > & getGenericMaterialPropertyByName (const MaterialPropertyName &name, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialPropertyByName (const MaterialPropertyName &name, MaterialData &material_data, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialPropertyByName (const MaterialPropertyName &name, const unsigned int state=0)
 
const MaterialProperty< T > & getMaterialPropertyByName (const MaterialPropertyName &name, const unsigned int state=0)
 
const ADMaterialProperty< T > & getADMaterialPropertyByName (const MaterialPropertyName &name, MaterialData &material_data)
 
const ADMaterialProperty< T > & getADMaterialPropertyByName (const MaterialPropertyName &name)
 
const ADMaterialProperty< T > & getADMaterialPropertyByName (const MaterialPropertyName &name)
 
const MaterialProperty< T > & getMaterialPropertyOldByName (const MaterialPropertyName &name, MaterialData &material_data)
 
const MaterialProperty< T > & getMaterialPropertyOldByName (const MaterialPropertyName &name)
 
const MaterialProperty< T > & getMaterialPropertyOldByName (const MaterialPropertyName &name)
 
const MaterialProperty< T > & getMaterialPropertyOlderByName (const MaterialPropertyName &name, MaterialData &material_data)
 
const MaterialProperty< T > & getMaterialPropertyOlderByName (const MaterialPropertyName &name)
 
const MaterialProperty< T > & getMaterialPropertyOlderByName (const MaterialPropertyName &name)
 
std::pair< const MaterialProperty< T > *, std::set< SubdomainID > > getBlockMaterialProperty (const MaterialPropertyName &name)
 
const GenericMaterialProperty< T, is_ad > & getGenericZeroMaterialProperty (const std::string &name)
 
const GenericMaterialProperty< T, is_ad > & getGenericZeroMaterialProperty ()
 
const GenericMaterialProperty< T, is_ad > & getGenericZeroMaterialPropertyByName (const std::string &prop_name)
 
const MaterialProperty< T > & getZeroMaterialProperty (Ts... args)
 
std::set< SubdomainID > getMaterialPropertyBlocks (const std::string &name)

std::vector< SubdomainName > getMaterialPropertyBlockNames (const std::string &name)

std::set< BoundaryID > getMaterialPropertyBoundaryIDs (const std::string &name)
 
std::vector< BoundaryName > getMaterialPropertyBoundaryNames (const std::string &name)
 
void checkBlockAndBoundaryCompatibility (std::shared_ptr< MaterialBase > discrete)
 
std::unordered_map< SubdomainID, std::vector< MaterialBase *> > buildRequiredMaterials (bool allow_stateful=true)
 
void statefulPropertiesAllowed (bool)
 
bool getMaterialPropertyCalled () const
 
virtual const std::unordered_set< unsigned int > & getMatPropDependencies () const
 
virtual void resolveOptionalProperties ()
 
const GenericMaterialProperty< T, is_ad > & getPossiblyConstantGenericMaterialPropertyByName (const MaterialPropertyName &prop_name, MaterialData &material_data, const unsigned int state)
 
bool isImplicit ()
 
Moose::StateArg determineState () const
 
virtual void subdomainSetup () override
 
virtual void subdomainSetup () override
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
const GenericOptionalMaterialProperty< T, is_ad > & getGenericOptionalMaterialProperty (const std::string &name, const unsigned int state=0)
 
const GenericOptionalMaterialProperty< T, is_ad > & getGenericOptionalMaterialProperty (const std::string &name, const unsigned int state=0)
 
const OptionalMaterialProperty< T > & getOptionalMaterialProperty (const std::string &name, const unsigned int state=0)
 
const OptionalMaterialProperty< T > & getOptionalMaterialProperty (const std::string &name, const unsigned int state=0)
 
const OptionalADMaterialProperty< T > & getOptionalADMaterialProperty (const std::string &name)
 
const OptionalADMaterialProperty< T > & getOptionalADMaterialProperty (const std::string &name)
 
const OptionalMaterialProperty< T > & getOptionalMaterialPropertyOld (const std::string &name)
 
const OptionalMaterialProperty< T > & getOptionalMaterialPropertyOld (const std::string &name)
 
const OptionalMaterialProperty< T > & getOptionalMaterialPropertyOlder (const std::string &name)
 
const OptionalMaterialProperty< T > & getOptionalMaterialPropertyOlder (const std::string &name)
 
MaterialBase & getMaterial (const std::string &name)

MaterialBase & getMaterial (const std::string &name)

MaterialBase & getMaterialByName (const std::string &name, bool no_warn=false)

MaterialBase & getMaterialByName (const std::string &name, bool no_warn=false)
 
bool hasMaterialProperty (const std::string &name)
 
bool hasMaterialProperty (const std::string &name)
 
bool hasMaterialPropertyByName (const std::string &name)
 
bool hasMaterialPropertyByName (const std::string &name)
 
bool hasADMaterialProperty (const std::string &name)
 
bool hasADMaterialProperty (const std::string &name)
 
bool hasADMaterialPropertyByName (const std::string &name)
 
bool hasADMaterialPropertyByName (const std::string &name)
 
bool hasGenericMaterialProperty (const std::string &name)
 
bool hasGenericMaterialProperty (const std::string &name)
 
bool hasGenericMaterialPropertyByName (const std::string &name)
 
bool hasGenericMaterialPropertyByName (const std::string &name)
 
const Function & getFunction (const std::string &name) const

const Function & getFunctionByName (const FunctionName &name) const
 
bool hasFunction (const std::string &param_name) const
 
bool hasFunctionByName (const FunctionName &name) const
 
bool isDefaultPostprocessorValue (const std::string &param_name, const unsigned int index=0) const
 
bool hasPostprocessor (const std::string &param_name, const unsigned int index=0) const
 
bool hasPostprocessorByName (const PostprocessorName &name) const
 
std::size_t coupledPostprocessors (const std::string &param_name) const
 
const PostprocessorName & getPostprocessorName (const std::string &param_name, const unsigned int index=0) const
 
const VectorPostprocessorValue & getVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name) const

const VectorPostprocessorValue & getVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name, bool needs_broadcast) const

const VectorPostprocessorValue & getVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name) const

const VectorPostprocessorValue & getVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name, bool needs_broadcast) const

const VectorPostprocessorValue & getVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name) const

const VectorPostprocessorValue & getVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name, bool needs_broadcast) const

const VectorPostprocessorValue & getVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name) const

const VectorPostprocessorValue & getVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name, bool needs_broadcast) const

const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name) const

const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name) const

const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name) const

const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
bool hasVectorPostprocessor (const std::string &param_name, const std::string &vector_name) const
 
bool hasVectorPostprocessor (const std::string &param_name) const
 
bool hasVectorPostprocessorByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
bool hasVectorPostprocessorByName (const VectorPostprocessorName &name) const
 
const VectorPostprocessorName & getVectorPostprocessorName (const std::string &param_name) const
 
T & getSampler (const std::string &name)
 
Sampler & getSampler (const std::string &name)

T & getSamplerByName (const SamplerName &name)

Sampler & getSamplerByName (const SamplerName &name)
 
virtual void meshChanged ()
 
virtual void meshDisplaced ()
 
PerfGraph & perfGraph ()

const PostprocessorValue & getPostprocessorValue (const std::string &param_name, const unsigned int index=0) const

const PostprocessorValue & getPostprocessorValue (const std::string &param_name, const unsigned int index=0) const

const PostprocessorValue & getPostprocessorValueOld (const std::string &param_name, const unsigned int index=0) const

const PostprocessorValue & getPostprocessorValueOld (const std::string &param_name, const unsigned int index=0) const

const PostprocessorValue & getPostprocessorValueOlder (const std::string &param_name, const unsigned int index=0) const

const PostprocessorValue & getPostprocessorValueOlder (const std::string &param_name, const unsigned int index=0) const

virtual const PostprocessorValue & getPostprocessorValueByName (const PostprocessorName &name) const

virtual const PostprocessorValue & getPostprocessorValueByName (const PostprocessorName &name) const

const PostprocessorValue & getPostprocessorValueOldByName (const PostprocessorName &name) const

const PostprocessorValue & getPostprocessorValueOldByName (const PostprocessorName &name) const

const PostprocessorValue & getPostprocessorValueOlderByName (const PostprocessorName &name) const

const PostprocessorValue & getPostprocessorValueOlderByName (const PostprocessorName &name) const
 
bool isVectorPostprocessorDistributed (const std::string &param_name) const
 
bool isVectorPostprocessorDistributed (const std::string &param_name) const
 
bool isVectorPostprocessorDistributedByName (const VectorPostprocessorName &name) const
 
bool isVectorPostprocessorDistributedByName (const VectorPostprocessorName &name) const
 
const Distribution & getDistribution (const std::string &name) const

const T & getDistribution (const std::string &name) const

const Distribution & getDistribution (const std::string &name) const

const T & getDistribution (const std::string &name) const

const Distribution & getDistributionByName (const DistributionName &name) const

const T & getDistributionByName (const std::string &name) const

const Distribution & getDistributionByName (const DistributionName &name) const
 
const T & getDistributionByName (const std::string &name) const
 
const Parallel::Communicator & comm () const
 
processor_id_type n_processors () const
 
processor_id_type processor_id () const
 
const std::string & modelMetaDataName () const
 Accessor for the name of the model meta data.

const FileName & getModelDataFileName () const
 Get the associated filename.

bool hasModelData () const
 Check if we need to load model data (if the filename parameter is used).

template<typename T , typename... Args>
T & declareModelData (const std::string &data_name, Args &&... args)
 Declare model data for loading from file as well as restart.

template<typename T , typename... Args>
const T & getModelData (const std::string &data_name, Args &&... args) const
 Retrieve model data from the interface.
 

Static Public Member Functions

static InputParameters validParams ()
 
static void sort (typename std::vector< T > &vector)
 
static void sortDFS (typename std::vector< T > &vector)
 
static void cyclicDependencyError (CyclicDependencyException< T2 > &e, const std::string &header)
 

Public Attributes

const ConsoleStream _console
 

Static Public Attributes

static constexpr PropertyValue::id_type default_property_id
 
static constexpr PropertyValue::id_type zero_property_id
 
static constexpr auto SYSTEM
 
static constexpr auto NAME
 

Protected Member Functions

void computeAverageEpisodeReward ()
 Compute the average episodic reward.

void convertDataToTensor (std::vector< std::vector< Real >> &vector_data, torch::Tensor &tensor_data, const bool detach=false)
 Convert input/output data from std::vector<std::vector> to torch::Tensor.

torch::Tensor evaluateValue (torch::Tensor &input)
 Evaluate the critic to get the value (discounted reward).

torch::Tensor evaluateAction (torch::Tensor &input, torch::Tensor &output)
 Evaluate the control net and then compute the logarithmic probability of the action.

void computeRewardToGo ()
 Compute the return value by discounting the rewards and summing them.

void resetData ()
 Reset data after updating the neural network.
 
virtual void addPostprocessorDependencyHelper (const PostprocessorName &name) const override
 
virtual void addVectorPostprocessorDependencyHelper (const VectorPostprocessorName &name) const override
 
virtual void addUserObjectDependencyHelper (const UserObject &uo) const override
 
void addReporterDependencyHelper (const ReporterName &reporter_name) override
 
const ReporterName & getReporterName (const std::string &param_name) const
 
T & declareRestartableData (const std::string &data_name, Args &&... args)
 
ManagedValue< T > declareManagedRestartableDataWithContext (const std::string &data_name, void *context, Args &&... args)
 
const T & getRestartableData (const std::string &data_name) const
 
T & declareRestartableDataWithContext (const std::string &data_name, void *context, Args &&... args)
 
T & declareRecoverableData (const std::string &data_name, Args &&... args)
 
T & declareRestartableDataWithObjectName (const std::string &data_name, const std::string &object_name, Args &&... args)
 
T & declareRestartableDataWithObjectNameWithContext (const std::string &data_name, const std::string &object_name, void *context, Args &&... args)
 
std::string restartableName (const std::string &data_name) const
 
const T & getMeshProperty (const std::string &data_name, const std::string &prefix)
 
const T & getMeshProperty (const std::string &data_name)
 
bool hasMeshProperty (const std::string &data_name, const std::string &prefix) const
 
bool hasMeshProperty (const std::string &data_name, const std::string &prefix) const
 
bool hasMeshProperty (const std::string &data_name) const
 
bool hasMeshProperty (const std::string &data_name) const
 
std::string meshPropertyName (const std::string &data_name) const
 
PerfID registerTimedSection (const std::string &section_name, const unsigned int level) const
 
PerfID registerTimedSection (const std::string &section_name, const unsigned int level, const std::string &live_message, const bool print_dots=true) const
 
std::string timedSectionName (const std::string &section_name) const
 
bool isCoupledScalar (const std::string &var_name, unsigned int i=0) const
 
unsigned int coupledScalarComponents (const std::string &var_name) const
 
unsigned int coupledScalar (const std::string &var_name, unsigned int comp=0) const
 
libMesh::Order coupledScalarOrder (const std::string &var_name, unsigned int comp=0) const
 
const VariableValue & coupledScalarValue (const std::string &var_name, unsigned int comp=0) const

const ADVariableValue & adCoupledScalarValue (const std::string &var_name, unsigned int comp=0) const

const GenericVariableValue< is_ad > & coupledGenericScalarValue (const std::string &var_name, unsigned int comp=0) const

const GenericVariableValue< false > & coupledGenericScalarValue (const std::string &var_name, const unsigned int comp) const

const GenericVariableValue< true > & coupledGenericScalarValue (const std::string &var_name, const unsigned int comp) const

const VariableValue & coupledVectorTagScalarValue (const std::string &var_name, TagID tag, unsigned int comp=0) const

const VariableValue & coupledMatrixTagScalarValue (const std::string &var_name, TagID tag, unsigned int comp=0) const

const VariableValue & coupledScalarValueOld (const std::string &var_name, unsigned int comp=0) const

const VariableValue & coupledScalarValueOlder (const std::string &var_name, unsigned int comp=0) const

const VariableValue & coupledScalarDot (const std::string &var_name, unsigned int comp=0) const

const ADVariableValue & adCoupledScalarDot (const std::string &var_name, unsigned int comp=0) const

const VariableValue & coupledScalarDotDot (const std::string &var_name, unsigned int comp=0) const

const VariableValue & coupledScalarDotOld (const std::string &var_name, unsigned int comp=0) const

const VariableValue & coupledScalarDotDotOld (const std::string &var_name, unsigned int comp=0) const

const VariableValue & coupledScalarDotDu (const std::string &var_name, unsigned int comp=0) const

const VariableValue & coupledScalarDotDotDu (const std::string &var_name, unsigned int comp=0) const

const MooseVariableScalar * getScalarVar (const std::string &var_name, unsigned int comp) const
 
virtual void checkMaterialProperty (const std::string &name, const unsigned int state)
 
void markMatPropRequested (const std::string &)
 
MaterialPropertyName getMaterialPropertyName (const std::string &name) const
 
void checkExecutionStage ()
 
const T & getReporterValue (const std::string &param_name, const std::size_t time_index=0)
 
const T & getReporterValue (const std::string &param_name, ReporterMode mode, const std::size_t time_index=0)
 
const T & getReporterValue (const std::string &param_name, const std::size_t time_index=0)
 
const T & getReporterValue (const std::string &param_name, ReporterMode mode, const std::size_t time_index=0)
 
const T & getReporterValueByName (const ReporterName &reporter_name, const std::size_t time_index=0)
 
const T & getReporterValueByName (const ReporterName &reporter_name, ReporterMode mode, const std::size_t time_index=0)
 
const T & getReporterValueByName (const ReporterName &reporter_name, const std::size_t time_index=0)
 
const T & getReporterValueByName (const ReporterName &reporter_name, ReporterMode mode, const std::size_t time_index=0)
 
bool hasReporterValue (const std::string &param_name) const
 
bool hasReporterValue (const std::string &param_name) const
 
bool hasReporterValue (const std::string &param_name) const
 
bool hasReporterValue (const std::string &param_name) const
 
bool hasReporterValueByName (const ReporterName &reporter_name) const
 
bool hasReporterValueByName (const ReporterName &reporter_name) const
 
bool hasReporterValueByName (const ReporterName &reporter_name) const
 
bool hasReporterValueByName (const ReporterName &reporter_name) const
 
const GenericMaterialProperty< T, is_ad > * defaultGenericMaterialProperty (const std::string &name)
 
const GenericMaterialProperty< T, is_ad > * defaultGenericMaterialProperty (const std::string &name)
 
const MaterialProperty< T > * defaultMaterialProperty (const std::string &name)
 
const MaterialProperty< T > * defaultMaterialProperty (const std::string &name)
 
const ADMaterialProperty< T > * defaultADMaterialProperty (const std::string &name)
 
const ADMaterialProperty< T > * defaultADMaterialProperty (const std::string &name)
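
The brief for computeRewardToGo() above says the return is obtained by discounting the rewards and summing them. A minimal standalone sketch of that idea follows, assuming a single episode and plain std::vector storage rather than the class's own data members. The function name discountedReturns and its signature are hypothetical and are not part of the MOOSE API; gamma stands in for the decay factor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Discounted reward-to-go for one episode:
//   G_t = r_t + gamma * G_{t+1}, with G = 0 past the last step.
// Hypothetical helper for illustration only; not part of the MOOSE API.
std::vector<double>
discountedReturns(const std::vector<double> & rewards, const double gamma)
{
  assert(gamma >= 0.0 && gamma <= 1.0);

  std::vector<double> returns(rewards.size(), 0.0);
  double running = 0.0;

  // Walk backwards so that each return reuses the one that follows it.
  for (std::size_t i = rewards.size(); i-- > 0;)
  {
    running = rewards[i] + gamma * running;
    returns[i] = running;
  }
  return returns;
}
```

Walking the rewards backwards keeps the cost linear in the episode length, since each entry needs only the return of the step after it.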
 

Static Protected Member Functions

static std::string meshPropertyName (const std::string &data_name, const std::string &prefix)
 

Protected Attributes

const std::vector< ReporterName > _response_names
 Response reporter names.

std::vector< const std::vector< Real > * > _response_value_pointers
 Pointers to the current values of the responses.

const std::vector< Real > _response_shift_factors
 Shifting constants for the responses.

const std::vector< Real > _response_scaling_factors
 Scaling constants for the responses.

const std::vector< ReporterName > _control_names
 Control reporter names.

std::vector< const std::vector< Real > * > _control_value_pointers
 Pointers to the current values of the control signals.

const std::vector< ReporterName > _log_probability_names
 Log probability reporter names.

std::vector< const std::vector< Real > * > _log_probability_value_pointers
 Pointers to the current values of the control log probabilities.

const ReporterName _reward_name
 Reward reporter name.

const std::vector< Real > * _reward_value_pointer
 Pointer to the current values of the reward.
 
const unsigned int _input_timesteps
 Number of timesteps to fetch from the reporters to be the input of the neural nets.

unsigned int _num_inputs
 Number of inputs for the control and critic neural nets.

unsigned int _num_outputs
 Number of outputs for the control neural network.

const unsigned int _num_epochs
 Number of epochs for the training of the emulator.

const std::vector< unsigned int > _num_critic_neurons_per_layer
 Number of neurons within the hidden layers in the critic neural net.

const Real _critic_learning_rate
 The learning rate for the optimization algorithm for the critic.

const std::vector< unsigned int > _num_control_neurons_per_layer
 Number of neurons within the hidden layers in the control neural net.

const Real _control_learning_rate
 The learning rate for the optimization algorithm for the control.

const unsigned int _update_frequency
 Number of transients to run and collect data from before updating the controller neural net.

const Real _clip_param
 The clip parameter used while clamping the advantage value.

const Real _decay_factor
 Decaying factor that is used when calculating the return from the reward.

const std::vector< Real > _action_std
 Standard deviation for the actions.

const std::string _filename_base
 Name of the pytorch output file.

const bool _read_from_file
 Switch indicating if an already existing neural net should be read from a file or not.

const bool _shift_outputs
 Currently, the controls are executed after the user objects at initial in MOOSE.

Real _average_episode_reward
 Storage for the current average episode reward.

const bool _standardize_advantage
 Switch to enable the standardization of the advantages.

const unsigned int _loss_print_frequency
 The frequency at which the loss should be printed.

std::shared_ptr< Moose::LibtorchArtificialNeuralNet > _control_nn
 Pointer to the control (or actor) neural net object.

std::shared_ptr< Moose::LibtorchArtificialNeuralNet > _critic_nn
 Pointer to the critic neural net object.

torch::Tensor _std
 Standard deviation in tensor format, used for sampling the actual control value.

torch::Tensor _input_tensor
 torch::Tensor version of the input and action data.
 
torch::Tensor _output_tensor
 
torch::Tensor _return_tensor
 
torch::Tensor _log_probability_tensor
 
SubProblem & _subproblem

FEProblemBase & _fe_problem

SystemBase & _sys

const THREAD_ID _tid

Assembly & _assembly

const Moose::CoordinateSystemType & _coord_sys
 
const bool _duplicate_initial_execution
 
std::set< std::string > _depend_uo
 
const bool & _enabled
 
MooseApp & _app
 
const std::string _type
 
const std::string _name
 
const InputParameters & _pars

Factory & _factory

ActionFactory & _action_factory

const ExecFlagEnum & _execute_enum

const ExecFlagType & _current_execute_flag

MooseApp & _restartable_app
 
const std::string _restartable_system_name
 
const THREAD_ID _restartable_tid
 
const bool _restartable_read_only
 
FEProblemBase & _mci_feproblem

FEProblemBase & _mdi_feproblem

MooseApp & _pg_moose_app
 
const std::string _prefix
 
FEProblemBase & _sc_fe_problem
 
const THREAD_ID _sc_tid
 
const Real & _real_zero

const VariableValue & _scalar_zero
 
const Point & _point_zero
 
const InputParameters & _mi_params
 
const std::string _mi_name
 
const MooseObjectName _mi_moose_object_name
 
FEProblemBase & _mi_feproblem

SubProblem & _mi_subproblem
 
const THREAD_ID _mi_tid
 
const Moose::MaterialDataType _material_data_type
 
MaterialData & _material_data
 
bool _stateful_allowed
 
bool _get_material_property_called
 
std::vector< std::unique_ptr< PropertyValue > > _default_properties
 
std::unordered_set< unsigned int > _material_property_dependencies
 
const MaterialPropertyName _get_suffix
 
const bool _use_interpolated_state
 
const InputParameters & _ti_params

FEProblemBase & _ti_feproblem
 
bool _is_implicit
 
Real & _t

const Real & _t_old

int & _t_step

Real & _dt

Real & _dt_old
 
bool _is_transient
 
const Parallel::Communicator & _communicator
 
std::vector< std::vector< Real > > _input_data
 
std::vector< std::vector< Real > > _output_data
 
std::vector< std::vector< Real > > _log_probability_data
 
std::vector< Real > _reward_data

std::vector< Real > _return_data
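
The _standardize_advantage switch listed above enables standardization of the advantages. As a hedged illustration of what that usually means in PPO implementations, the sketch below forms the advantage as return minus critic value and then rescales it to zero mean and unit spread. Several things here are assumptions rather than the class's confirmed behavior: the function name standardizedAdvantages, the use of the sample (n - 1) standard deviation, and the small eps guard against division by zero.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Advantage A_i = return_i - value_i, then standardized as
//   (A_i - mean(A)) / (std(A) + eps).
// Hypothetical helper for illustration only; not part of the MOOSE API.
// Assumption: sample standard deviation (divides by n - 1).
std::vector<double>
standardizedAdvantages(const std::vector<double> & returns,
                       const std::vector<double> & values,
                       const double eps = 1e-10)
{
  assert(returns.size() == values.size());
  assert(returns.size() > 1);

  const std::size_t n = returns.size();
  std::vector<double> advantage(n);

  double mean = 0.0;
  for (std::size_t i = 0; i < n; ++i)
  {
    advantage[i] = returns[i] - values[i];
    mean += advantage[i];
  }
  mean /= static_cast<double>(n);

  double variance = 0.0;
  for (const double a : advantage)
    variance += (a - mean) * (a - mean);
  variance /= static_cast<double>(n - 1);

  const double deviation = std::sqrt(variance);
  for (double & a : advantage)
    a = (a - mean) / (deviation + eps);

  return advantage;
}
```

Standardizing keeps the scale of the policy update roughly independent of the reward magnitude, which is why it is commonly offered as an optional switch.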
 

Static Protected Attributes

static const std::string _interpolated_old
 
static const std::string _interpolated_older
 

Private Member Functions

void getInputDataFromReporter (std::vector< std::vector< Real >> &data, const std::vector< const std::vector< Real > *> &reporter_links, const unsigned int num_timesteps)
 Extract the response values from the postprocessors of the controlled system. More...
 
void getOutputDataFromReporter (std::vector< std::vector< Real >> &data, const std::vector< const std::vector< Real > *> &reporter_links)
 Extract the output (actions, logarithmic probabilities) values from the postprocessors of the controlled system. More...
 
void getRewardDataFromReporter (std::vector< Real > &data, const std::vector< Real > *const reporter_link)
 Extract the reward values from the postprocessors of the controlled system. This assumes that they are stored in an AccumulateReporter. More...
 
void getReporterPointers (const std::vector< ReporterName > &reporter_names, std::vector< const std::vector< Real > *> &pointer_storage)
 Getting reporter pointers with given names. More...
 

Private Attributes

unsigned int _update_counter
 Counter for number of transient simulations that have been run before updating the controller. More...
 

Detailed Description

This trainer is responsible for training neural networks that efficiently control different processes.

It utilizes the Proximal Policy Optimization (PPO) algorithm. For more information on the algorithm, see the following resources:
Schulman, John, et al. "Proximal policy optimization algorithms." arXiv preprint arXiv:1707.06347 (2017).
https://medium.com/analytics-vidhya/coding-ppo-from-scratch-with-pytorch-part-1-4-613dfc1b14c8
https://stable-baselines.readthedocs.io/en/master/modules/ppo2.html

Definition at line 28 of file LibtorchDRLControlTrainer.h.

Constructor & Destructor Documentation

◆ LibtorchDRLControlTrainer()

LibtorchDRLControlTrainer::LibtorchDRLControlTrainer ( const InputParameters & parameters)

construct using input parameters

Definition at line 128 of file LibtorchDRLControlTrainer.C.

130  _response_names(getParam<std::vector<ReporterName>>("response")),
131  _response_shift_factors(isParamValid("response_shift_factors")
132  ? getParam<std::vector<Real>>("response_shift_factors")
133  : std::vector<Real>(_response_names.size(), 0.0)),
134  _response_scaling_factors(isParamValid("response_scaling_factors")
135  ? getParam<std::vector<Real>>("response_scaling_factors")
136  : std::vector<Real>(_response_names.size(), 1.0)),
137  _control_names(getParam<std::vector<ReporterName>>("control")),
138  _log_probability_names(getParam<std::vector<ReporterName>>("log_probability")),
139  _reward_name(getParam<ReporterName>("reward")),
141  _input_timesteps(getParam<unsigned int>("input_timesteps")),
144  _input_data(std::vector<std::vector<Real>>(_num_inputs)),
145  _output_data(std::vector<std::vector<Real>>(_num_outputs)),
146  _log_probability_data(std::vector<std::vector<Real>>(_num_outputs)),
147  _num_epochs(getParam<unsigned int>("num_epochs")),
149  getParam<std::vector<unsigned int>>("num_critic_neurons_per_layer")),
150  _critic_learning_rate(getParam<Real>("critic_learning_rate")),
152  getParam<std::vector<unsigned int>>("num_control_neurons_per_layer")),
153  _control_learning_rate(getParam<Real>("control_learning_rate")),
154  _update_frequency(getParam<unsigned int>("update_frequency")),
155  _clip_param(getParam<Real>("clip_parameter")),
156  _decay_factor(getParam<Real>("decay_factor")),
157  _action_std(getParam<std::vector<Real>>("action_standard_deviations")),
158  _filename_base(isParamValid("filename_base") ? getParam<std::string>("filename_base") : ""),
159  _read_from_file(getParam<bool>("read_from_file")),
160  _shift_outputs(getParam<bool>("shift_outputs")),
161  _standardize_advantage(getParam<bool>("standardize_advantage")),
162  _loss_print_frequency(getParam<unsigned int>("loss_print_frequency")),
164 {
165  if (_response_names.size() != _response_shift_factors.size())
166  paramError("response_shift_factors",
167  "The number of shift factors is not the same as the number of responses!");
168 
169  if (_response_names.size() != _response_scaling_factors.size())
170  paramError(
171  "response_scaling_factors",
172  "The number of normalization coefficients is not the same as the number of responses!");
173 
174  // We establish the links with the chosen reporters
178 
179  // Fixing the RNG seed to make sure every experiment is the same.
180  // Otherwise sampling / stochastic gradient descent would be different.
181  torch::manual_seed(getParam<unsigned int>("seed"));
182 
183  // Convert the user input standard deviations to a diagonal tensor
184  _std = torch::eye(_control_names.size());
185  for (unsigned int i = 0; i < _control_names.size(); ++i)
186  _std[i][i] = _action_std[i];
187 
188  bool filename_valid = isParamValid("filename_base");
189 
190  // Initializing the control neural net so that the control can grab it right away
191  _control_nn = std::make_shared<Moose::LibtorchArtificialNeuralNet>(
192  filename_valid ? _filename_base + "_control.net" : "control.net",
193  _num_inputs,
194  _num_outputs,
196  getParam<std::vector<std::string>>("control_activation_functions"));
197 
198  // We read parameters for the control neural net if it is requested
199  if (_read_from_file)
200  {
201  try
202  {
203  torch::load(_control_nn, _control_nn->name());
204  _console << "Loaded requested .pt file." << std::endl;
205  }
206  catch (const c10::Error & e)
207  {
208  mooseError("The requested pytorch file could not be loaded for the control neural net.\n",
209  e.msg());
210  }
211  }
212  else if (filename_valid)
213  torch::save(_control_nn, _control_nn->name());
214 
215  // Initialize the critic neural net
216  _critic_nn = std::make_shared<Moose::LibtorchArtificialNeuralNet>(
217  filename_valid ? _filename_base + "_ctiric.net" : "ctiric.net",
218  _num_inputs,
219  1,
221  getParam<std::vector<std::string>>("critic_activation_functions"));
222 
223  // We read parameters for the critic neural net if it is requested
224  if (_read_from_file)
225  {
226  try
227  {
228  torch::load(_critic_nn, _critic_nn->name());
229  _console << "Loaded requested .pt file." << std::endl;
230  }
231  catch (const c10::Error & e)
232  {
233  mooseError("The requested pytorch file could not be loaded for the critic neural net.\n",
234  e.msg());
235  }
236  }
237  else if (filename_valid)
238  torch::save(_critic_nn, _critic_nn->name());
239 }
SurrogateTrainerBase(const InputParameters &parameters)
std::vector< std::vector< Real > > _input_data
const std::vector< ReporterName > _control_names
Control reporter names.
const std::vector< Real > _action_std
Standard deviation for the actions.
const bool _shift_outputs
Currently, the controls are executed after the user objects at initial in moose.
const T & getReporterValueByName(const ReporterName &reporter_name, const std::size_t time_index=0)
void getReporterPointers(const std::vector< ReporterName > &reporter_names, std::vector< const std::vector< Real > *> &pointer_storage)
Getting reporter pointers with given names.
const std::vector< unsigned int > _num_control_neurons_per_layer
Number of neurons within the hidden layers in the control neural net.
unsigned int _num_inputs
Number of inputs for the control and critic neural nets.
std::vector< const std::vector< Real > * > _response_value_pointers
Pointers to the current values of the responses.
const ReporterName _reward_name
Reward reporter name.
const std::vector< Real > _response_shift_factors
Shifting constants for the responses.
const Real _clip_param
The clip parameter used while clamping the probability ratio in the surrogate objective.
const std::vector< ReporterName > _response_names
Response reporter names.
bool isParamValid(const std::string &name) const
std::vector< const std::vector< Real > * > _log_probability_value_pointers
Pointers to the current values of the control log probabilities.
const unsigned int _input_timesteps
Number of timesteps to fetch from the reporters to be the input of the neural nets.
std::vector< std::vector< Real > > _output_data
torch::Tensor _std
standard deviation in a tensor format for sampling the actual control value
const Real _control_learning_rate
The learning rate for the optimization algorithm for the control.
unsigned int _num_outputs
Number of outputs for the control neural network.
const Real _decay_factor
Decaying factor that is used when calculating the return from the reward.
const std::vector< unsigned int > _num_critic_neurons_per_layer
Number of neurons within the hidden layers in the critic neural net.
const Real _critic_learning_rate
The learning rate for the optimization algorithm for the critic.
const T & getParam(const std::string &name) const
std::vector< std::vector< Real > > _log_probability_data
const unsigned int _loss_print_frequency
The frequency the loss should be printed.
void paramError(const std::string &param, Args... args) const
const unsigned int _num_epochs
Number of epochs for the training of the emulator.
const std::vector< Real > _response_scaling_factors
Scaling constants for the responses.
std::shared_ptr< Moose::LibtorchArtificialNeuralNet > _critic_nn
Pointer to the critic neural net object.
const bool _read_from_file
Switch indicating if an already existing neural net should be read from a file or not...
const unsigned int _update_frequency
Number of transients to run and collect data from before updating the controller neural net...
const std::string _filename_base
Name of the pytorch output file.
void mooseError(Args &&... args) const
const InputParameters & parameters() const
const ConsoleStream _console
unsigned int _update_counter
Counter for number of transient simulations that have been run before updating the controller...
const bool _standardize_advantage
Switch to enable the standardization of the advantages.
std::shared_ptr< Moose::LibtorchArtificialNeuralNet > _control_nn
Pointer to the control (or actor) neural net object.
const std::vector< ReporterName > _log_probability_names
Log probability reporter names.
std::vector< const std::vector< Real > * > _control_value_pointers
Pointers to the current values of the control signals.
const std::vector< Real > * _reward_value_pointer
Pointer to the current values of the reward.

Member Function Documentation

◆ averageEpisodeReward()

Real LibtorchDRLControlTrainer::averageEpisodeReward ( )
inline

Function which returns the current average episodic reward.

It is only updated at the end of every episode.

Definition at line 42 of file LibtorchDRLControlTrainer.h.

Referenced by DRLRewardReporter::execute().

42 { return _average_episode_reward; }
Real _average_episode_reward
Storage for the current average episode reward.

◆ computeAverageEpisodeReward()

void LibtorchDRLControlTrainer::computeAverageEpisodeReward ( )
protected

Compute the average episodic reward.

Definition at line 278 of file LibtorchDRLControlTrainer.C.

Referenced by execute().

279 {
280  if (_reward_data.size())
282  std::accumulate(_reward_data.begin(), _reward_data.end(), 0.0) / _reward_data.size();
283  else
285 }
Real _average_episode_reward
Storage for the current average episode reward.

◆ computeRewardToGo()

void LibtorchDRLControlTrainer::computeRewardToGo ( )
protected

Compute the return value by discounting the rewards and summing them.

Definition at line 288 of file LibtorchDRLControlTrainer.C.

Referenced by execute().

289 {
290  // Get reward data from one simulation
291  std::vector<Real> reward_data_per_sim;
292  std::vector<Real> return_data_per_sim;
293  getRewardDataFromReporter(reward_data_per_sim, _reward_value_pointer);
294 
295  // Discount the reward to get the return value, we need this to be able to anticipate
296  // rewards based on the current behavior.
297  Real discounted_reward(0.0);
298  for (int i = reward_data_per_sim.size() - 1; i >= 0; --i)
299  {
300  discounted_reward = reward_data_per_sim[i] + discounted_reward * _decay_factor;
301 
302  // We are inserting to the front of the vector and push the rest back, this will
303  // ensure that the first element of the vector is the discounted reward for the whole transient
304  return_data_per_sim.insert(return_data_per_sim.begin(), discounted_reward);
305  }
306 
307  // Save and accumulate the return values
308  _return_data.insert(_return_data.end(), return_data_per_sim.begin(), return_data_per_sim.end());
309 }
void getRewardDataFromReporter(std::vector< Real > &data, const std::vector< Real > *const reporter_link)
Extract the reward values from the postprocessors of the controlled system. This assumes that they are...
const Real _decay_factor
Decaying factor that is used when calculating the return from the reward.
typedef LIBMESH_DEFAULT_SCALAR_TYPE Real
const std::vector< Real > * _reward_value_pointer
Pointer to the current values of the reward.

◆ controlNeuralNet()

const Moose::LibtorchArtificialNeuralNet& LibtorchDRLControlTrainer::controlNeuralNet ( ) const
inline

Definition at line 47 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchNeuralNetControlTransfer::execute().

47 { return *_control_nn; }
std::shared_ptr< Moose::LibtorchArtificialNeuralNet > _control_nn
Pointer to the control (or actor) neural net object.

◆ convertDataToTensor()

void LibtorchDRLControlTrainer::convertDataToTensor ( std::vector< std::vector< Real >> &  vector_data,
torch::Tensor &  tensor_data,
const bool  detach = false 
)
protected

Function to convert input/output data from std::vector<std::vector> to torch::tensor.

Parameters
vector_data  The input data in vector-of-vectors format
tensor_data  The tensor where we would like to save the results
detach       If the gradient info needs to be detached from the tensor

Definition at line 373 of file LibtorchDRLControlTrainer.C.

Referenced by execute().

376 {
377  for (unsigned int i = 0; i < vector_data.size(); ++i)
378  {
379  torch::Tensor input_row;
380  LibtorchUtils::vectorToTensor(vector_data[i], input_row, detach);
381 
382  if (i == 0)
383  tensor_data = input_row;
384  else
385  tensor_data = torch::cat({tensor_data, input_row}, 1);
386  }
387 
388  if (detach)
389  tensor_data.detach();
390 }
void vectorToTensor(std::vector< DataType > &vector, torch::Tensor &tensor, const bool detach=false)
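
The layout produced by concatenating one block per variable along dimension 1 can be illustrated with plain containers. The sketch below assumes that each inner vector is turned into a single column (one sample per row), which is an assumption about `LibtorchUtils::vectorToTensor` and is not stated in this section; the helper name `stackColumns` is hypothetical.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of the MOOSE API).
// Interleaves per-variable sample vectors into one row-major buffer of shape
// [n_samples x n_variables], i.e. entry (r, c) holds sample r of variable c.
std::vector<double> stackColumns(const std::vector<std::vector<double>> & columns)
{
  if (columns.empty())
    return {};

  const std::size_t n_rows = columns.front().size();
  const std::size_t n_cols = columns.size();
  std::vector<double> matrix(n_rows * n_cols);

  for (std::size_t c = 0; c < n_cols; ++c)
  {
    // Concatenation along the column dimension needs equal sample counts
    if (columns[c].size() != n_rows)
      throw std::invalid_argument("all columns must have the same length");

    for (std::size_t r = 0; r < n_rows; ++r)
      matrix[r * n_cols + c] = columns[c][r];
  }
  return matrix;
}
```

With two variables holding samples {1, 2} and {3, 4}, the buffer reads 1, 3, 2, 4: each row carries one sample of every variable, which is the shape a network expecting one observation per row would consume.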

◆ declareModelData()

template<typename T , typename... Args>
T & RestartableModelInterface::declareModelData ( const std::string &  data_name,
Args &&...  args 
)
inherited

Declare model data for loading from file as well as restart.

Definition at line 78 of file RestartableModelInterface.h.

79 {
80  return _model_restartable.declareRestartableData<T>(data_name, std::forward<Args>(args)...);
81 }
T & declareRestartableData(const std::string &data_name, Args &&... args)
Declare a piece of data as "restartable" and initialize it.
PublicRestartable _model_restartable
Member for interfacing with the framework's restartable system.

◆ evaluateAction()

torch::Tensor LibtorchDRLControlTrainer::evaluateAction ( torch::Tensor &  input,
torch::Tensor &  output 
)
protected

Function which evaluates the control net and then computes the logarithmic probability of the action.

Parameters
input   The observation values (responses)
output  The actions corresponding to the observations
Returns
The estimated value for the logarithmic probability

Definition at line 399 of file LibtorchDRLControlTrainer.C.

Referenced by trainController().

400 {
401  torch::Tensor var = torch::matmul(_std, _std);
402 
403  // Compute an action and get its logarithmic probability based on an assumed Gaussian distribution
404  torch::Tensor action = _control_nn->forward(input);
405  return -((action - output) * (action - output)) / (2 * var) - torch::log(_std) -
406  std::log(std::sqrt(2 * M_PI));
407 }
torch::Tensor _std
standard deviation in a tensor format for sampling the actual control value
std::shared_ptr< Moose::LibtorchArtificialNeuralNet > _control_nn
Pointer to the control (or actor) neural net object.
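
For a single control signal, the expression in the listing reduces to the log-density of a normal distribution, log p(a) = -(a - mu)^2 / (2 sigma^2) - log(sigma) - log(sqrt(2 pi)). The scalar sketch below is only meant to make the three terms explicit; the function name and arguments are hypothetical, with `mean` standing in for the network output and `stddev` for one diagonal entry of `_std`.

```cpp
#include <cmath>

// Hypothetical scalar helper (not part of the MOOSE API).
// Log-density of a Gaussian with the given mean and standard deviation,
// evaluated at `action`. Assumes stddev > 0.
double gaussianLogProbability(double action, double mean, double stddev)
{
  const double pi = std::acos(-1.0); // portable alternative to M_PI
  const double diff = action - mean;

  return -(diff * diff) / (2.0 * stddev * stddev) // quadratic penalty
         - std::log(stddev)                       // normalization by sigma
         - std::log(std::sqrt(2.0 * pi));         // constant normalization
}
```

At `action == mean` with unit standard deviation only the constant term remains, which is roughly -0.919.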

◆ evaluateValue()

torch::Tensor LibtorchDRLControlTrainer::evaluateValue ( torch::Tensor &  input)
protected

Function which evaluates the critic to get the value (discounted reward)

Parameters
input  The observation values (responses)
Returns
The estimated value

Definition at line 393 of file LibtorchDRLControlTrainer.C.

Referenced by trainController().

394 {
395  return _critic_nn->forward(input);
396 }
std::shared_ptr< Moose::LibtorchArtificialNeuralNet > _critic_nn
Pointer to the critic neural net object.

◆ execute()

void LibtorchDRLControlTrainer::execute ( )
overridevirtual

Implements GeneralUserObject.

Definition at line 242 of file LibtorchDRLControlTrainer.C.

243 {
244  // Extract data from the reporters
249 
250  // Calculate return from the reward (discounting the reward)
252 
253  _update_counter--;
254 
255  // Only update the NNs when the update counter reaches zero
256  if (_update_counter == 0)
257  {
258  // We compute the average reward first
260 
261  // Transform input/output/return data to torch::Tensor
265 
266  // Discard (detach) the gradient info for return data
268 
269  // We train the controller using the emulator to get a good control strategy
270  trainController();
271 
272  // We clean the training data after controller update and reset the counter
273  resetData();
274  }
275 }
void computeAverageEpisodeReward()
Compute the average episodic reward.
std::vector< std::vector< Real > > _input_data
torch::Tensor _input_tensor
Torch::tensor version of the input and action data.
void getRewardDataFromReporter(std::vector< Real > &data, const std::vector< Real > *const reporter_link)
Extract the reward values from the postprocessors of the controlled system. This assumes that they are...
void convertDataToTensor(std::vector< std::vector< Real >> &vector_data, torch::Tensor &tensor_data, const bool detach=false)
Function to convert input/output data from std::vector<std::vector> to torch::tensor.
void trainController()
The condensed training function.
std::vector< const std::vector< Real > * > _response_value_pointers
Pointers to the current values of the responses.
void getInputDataFromReporter(std::vector< std::vector< Real >> &data, const std::vector< const std::vector< Real > *> &reporter_links, const unsigned int num_timesteps)
Extract the response values from the postprocessors of the controlled system.
std::vector< const std::vector< Real > * > _log_probability_value_pointers
Pointers to the current values of the control log probabilities.
const unsigned int _input_timesteps
Number of timesteps to fetch from the reporters to be the input of the neural nets.
std::vector< std::vector< Real > > _output_data
std::vector< std::vector< Real > > _log_probability_data
void resetData()
Reset data after updating the neural network.
template void vectorToTensor< Real >(std::vector< Real > &vector, torch::Tensor &tensor, const bool detach)
void getOutputDataFromReporter(std::vector< std::vector< Real >> &data, const std::vector< const std::vector< Real > *> &reporter_links)
Extract the output (actions, logarithmic probabilities) values from the postprocessors of the control...
void computeRewardToGo()
Compute the return value by discounting the rewards and summing them.
unsigned int _update_counter
Counter for number of transient simulations that have been run before updating the controller...
std::vector< const std::vector< Real > * > _control_value_pointers
Pointers to the current values of the control signals.
const std::vector< Real > * _reward_value_pointer
Pointer to the current values of the reward.
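
The control flow of `execute()` amounts to a countdown: data is collected on every call, and training happens only on the call that brings the counter to zero. The following is a minimal sketch of that pattern. It assumes that the counter starts at the update frequency and is restored to it after each update; the listing does not show those lines, so both points are assumptions, and the type `UpdateSchedule` is hypothetical.

```cpp
// Hypothetical helper (not part of the MOOSE API).
// Tracks how many episodes remain before the networks should be retrained.
struct UpdateSchedule
{
  unsigned int frequency; // episodes to collect between updates (assumed >= 1)
  unsigned int counter;   // episodes still missing before the next update

  explicit UpdateSchedule(unsigned int update_frequency)
    : frequency(update_frequency), counter(update_frequency)
  {
  }

  // Call once per finished episode. Returns true when enough episodes have
  // been collected, and re-arms the countdown in that case.
  bool recordEpisode()
  {
    if (--counter != 0)
      return false;

    counter = frequency; // assumed reset, mirroring what resetData() is described to do
    return true;
  }
};
```

With a frequency of 3, the first two calls return false and the third returns true, after which the cycle repeats.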

◆ finalize()

virtual void SurrogateTrainerBase::finalize ( )
inline virtual inherited

Implements GeneralUserObject.

Reimplemented in SurrogateTrainer, and PODReducedBasisTrainer.

Definition at line 39 of file SurrogateTrainer.h.

39 {} // not required, but available

◆ getInputDataFromReporter()

void LibtorchDRLControlTrainer::getInputDataFromReporter ( std::vector< std::vector< Real >> &  data,
const std::vector< const std::vector< Real > *> &  reporter_links,
const unsigned int  num_timesteps 
)
private

Extract the response values from the postprocessors of the controlled system.

This assumes that they are stored in an AccumulateReporter

Parameters
data            The data where we would like to store the response values
reporter_links  Pointers to the reporter values which need to be extracted
num_timesteps   The number of timesteps we want to use for training

Definition at line 426 of file LibtorchDRLControlTrainer.C.

Referenced by execute().

430 {
431  for (const auto & rep_i : index_range(reporter_links))
432  {
433  std::vector<Real> reporter_data = *reporter_links[rep_i];
434 
435  // We shift and scale the inputs to get better training efficiency
436  std::transform(
437  reporter_data.begin(),
438  reporter_data.end(),
439  reporter_data.begin(),
440  [this, &rep_i](Real value) -> Real
441  { return (value - _response_shift_factors[rep_i]) * _response_scaling_factors[rep_i]; });
442 
443  // Fill the corresponding containers
444  for (const auto & start_step : make_range(num_timesteps))
445  {
446  unsigned int row = reporter_links.size() * start_step + rep_i;
447  for (unsigned int fill_i = 1; fill_i < num_timesteps - start_step; ++fill_i)
448  data[row].push_back(reporter_data[0]);
449 
450  data[row].insert(data[row].end(),
451  reporter_data.begin(),
452  reporter_data.begin() + start_step + reporter_data.size() -
453  (num_timesteps - 1) - _shift_outputs);
454  }
455  }
456 }
const bool _shift_outputs
Currently, the controls are executed after the user objects at initial in moose.
const std::vector< Real > _response_shift_factors
Shifting constants for the responses.
const std::vector< Real > _response_scaling_factors
Scaling constants for the responses.
typedef LIBMESH_DEFAULT_SCALAR_TYPE Real
IntRange< T > make_range(T beg, T end)
auto index_range(const T &sizable)
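
The two steps in the listing, shifting and scaling each response and then building one row per time lag, can be sketched for a single reporter as follows. The helper names are hypothetical and the sketch is not the MOOSE implementation. Row `k` holds the series delayed by `num_timesteps - 1 - k` steps, with the missing leading entries filled by the first sample; in the listing, rows of several reporters are interleaved as `n_reporters * k + reporter_index`.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper: apply (value - shift) * scale to every entry.
std::vector<double> shiftAndScale(std::vector<double> values, double shift, double scale)
{
  for (double & v : values)
    v = (v - shift) * scale;
  return values;
}

// Hypothetical helper: build the lagged copies of one time series.
// Every row ends up with series.size() - shift_outputs entries.
std::vector<std::vector<double>>
laggedRows(const std::vector<double> & series, unsigned int num_timesteps, bool shift_outputs)
{
  const std::size_t dropped = shift_outputs ? 1 : 0;
  if (series.size() < num_timesteps + dropped)
    throw std::invalid_argument("series too short for the requested number of timesteps");

  std::vector<std::vector<double>> rows(num_timesteps);
  for (unsigned int k = 0; k < num_timesteps; ++k)
  {
    const unsigned int lag = num_timesteps - 1 - k;

    // Pad the start with the first sample, since no earlier history exists
    rows[k].assign(lag, series.front());

    // Copy the remaining samples; the tail is cut so all rows have equal length
    const std::size_t count = series.size() - lag - dropped;
    rows[k].insert(rows[k].end(),
                   series.begin(),
                   series.begin() + static_cast<std::ptrdiff_t>(count));
  }
  return rows;
}
```

For the series {1, 2, 3, 4} with two timesteps and no output shift, the rows are {1, 1, 2, 3} (previous step) and {1, 2, 3, 4} (current step), so column `j` of the stacked data pairs each sample with its predecessor.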

◆ getModelData()

template<typename T , typename... Args>
const T & RestartableModelInterface::getModelData ( const std::string &  data_name,
Args &&...  args 
) const
inherited

Retrieve model data from the interface.

Definition at line 85 of file RestartableModelInterface.h.

86 {
87  return _model_restartable.getRestartableData<T>(data_name, std::forward<Args>(args)...);
88 }
const T & getRestartableData(const std::string &data_name) const
Declare a piece of data as "restartable" and initialize it. Similar to declareRestartableData but retu...
PublicRestartable _model_restartable
Member for interfacing with the framework's restartable system.

◆ getModelDataFileName()

const FileName & RestartableModelInterface::getModelDataFileName ( ) const
inherited

Get the associated filename.

Definition at line 33 of file RestartableModelInterface.C.

34 {
35  return _model_object.getParam<FileName>("filename");
36 }
const T & getParam(const std::string &name) const
const MooseObject & _model_object
Reference to the MooseObject that uses this interface.

◆ getOutputDataFromReporter()

void LibtorchDRLControlTrainer::getOutputDataFromReporter ( std::vector< std::vector< Real >> &  data,
const std::vector< const std::vector< Real > *> &  reporter_links 
)
private

Extract the output (actions, logarithmic probabilities) values from the postprocessors of the controlled system.

This assumes that they are stored in an AccumulateReporter

Parameters
data            The data where we would like to store the output values
reporter_links  Pointers to the reporter values which need to be extracted

Definition at line 459 of file LibtorchDRLControlTrainer.C.

Referenced by execute().

462 {
463  for (const auto & rep_i : index_range(reporter_links))
464  // Fill the corresponding containers
465  data[rep_i].insert(data[rep_i].end(),
466  reporter_links[rep_i]->begin() + _shift_outputs,
467  reporter_links[rep_i]->end());
468 }
const bool _shift_outputs
Currently, the controls are executed after the user objects at initial in moose.
auto index_range(const T &sizable)
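
The output and reward extraction both append a reporter history to the training storage, optionally skipping the first entry when `_shift_outputs` is set. A minimal sketch of that append-with-offset pattern is shown below; the function name is hypothetical, and the guard against an empty history is an addition of this sketch that the listing does not contain.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of the MOOSE API).
// Appends `history` to `storage`, dropping the first entry when
// `shift_outputs` is true so that actions line up with the observations
// that preceded them.
void appendHistory(std::vector<double> & storage,
                   const std::vector<double> & history,
                   bool shift_outputs)
{
  // Guard added in this sketch: never step past the end of an empty history
  const std::ptrdiff_t offset = (shift_outputs && !history.empty()) ? 1 : 0;
  storage.insert(storage.end(), history.begin() + offset, history.end());
}
```

Because the data is appended rather than overwritten, several episodes accumulate in the same container until the storage is cleared after a network update.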

◆ getReporterPointers()

void LibtorchDRLControlTrainer::getReporterPointers ( const std::vector< ReporterName > &  reporter_names,
std::vector< const std::vector< Real > *> &  pointer_storage 
)
private

Getting reporter pointers with given names.

Definition at line 479 of file LibtorchDRLControlTrainer.C.

Referenced by LibtorchDRLControlTrainer().

482 {
483  pointer_storage.clear();
484  for (const auto & name : reporter_names)
485  pointer_storage.push_back(&getReporterValueByName<std::vector<Real>>(name));
486 }
const T & getReporterValueByName(const ReporterName &reporter_name, const std::size_t time_index=0)
virtual const std::string & name() const

◆ getRewardDataFromReporter()

void LibtorchDRLControlTrainer::getRewardDataFromReporter ( std::vector< Real > &  data,
const std::vector< Real > *const  reporter_link 
)
private

Extract the reward values from the postprocessors of the controlled system. This assumes that they are stored in an AccumulateReporter.

Parameters
data           The data where we would like to store the reward values
reporter_link  Pointer to the reporter values which need to be extracted

Definition at line 471 of file LibtorchDRLControlTrainer.C.

Referenced by computeRewardToGo(), and execute().

473 {
474  // Fill the corresponding container
475  data.insert(data.end(), reporter_link->begin() + _shift_outputs, reporter_link->end());
476 }
const bool _shift_outputs
Currently, the controls are executed after the user objects at initial in moose.

◆ hasModelData()

bool RestartableModelInterface::hasModelData ( ) const
inherited

Check if we need to load model data (if the filename parameter is used)

Definition at line 39 of file RestartableModelInterface.C.

40 {
41  return _model_object.isParamValid("filename");
42 }
bool isParamValid(const std::string &name) const
const MooseObject & _model_object
Reference to the MooseObject that uses this interface.

◆ initialize()

virtual void SurrogateTrainerBase::initialize ( )
inline virtual inherited

Implements GeneralUserObject.

Reimplemented in SurrogateTrainer, ActiveLearningGaussianProcess, and PODReducedBasisTrainer.

Definition at line 38 of file SurrogateTrainer.h.

38 {} // not required, but available

◆ modelMetaDataName()

const std::string& RestartableModelInterface::modelMetaDataName ( ) const
inline inherited

Accessor for the name of the model meta data.

Definition at line 47 of file RestartableModelInterface.h.

Referenced by SurrogateTrainerOutput::output(), and MappingOutput::output().

47 { return _model_meta_data_name; }
const std::string _model_meta_data_name
The model meta data name.

◆ resetData()

void LibtorchDRLControlTrainer::resetData ( )
protected

Reset data after updating the neural network.

Definition at line 410 of file LibtorchDRLControlTrainer.C.

Referenced by execute().

411 {
412  for (auto & data : _input_data)
413  data.clear();
414  for (auto & data : _output_data)
415  data.clear();
416  for (auto & data : _log_probability_data)
417  data.clear();
418 
419  _reward_data.clear();
420  _return_data.clear();
421 
423 }
std::vector< std::vector< Real > > _input_data
std::vector< std::vector< Real > > _output_data
std::vector< std::vector< Real > > _log_probability_data
const unsigned int _update_frequency
Number of transients to run and collect data from before updating the controller neural net...
unsigned int _update_counter
Counter for number of transient simulations that have been run before updating the controller...

◆ threadJoin()

virtual void SurrogateTrainerBase::threadJoin ( const UserObject & )
inline final virtual inherited

Reimplemented from GeneralUserObject.

Definition at line 40 of file SurrogateTrainer.h.

40 {} // GeneralUserObjects are not threaded

◆ trainController()

void LibtorchDRLControlTrainer::trainController ( )

The condensed training function.

Definition at line 312 of file LibtorchDRLControlTrainer.C.

Referenced by execute().

313 {
314  // Define the optimizers for the training
315  torch::optim::Adam actor_optimizer(_control_nn->parameters(),
316  torch::optim::AdamOptions(_control_learning_rate));
317 
318  torch::optim::Adam critic_optimizer(_critic_nn->parameters(),
319  torch::optim::AdamOptions(_critic_learning_rate));
320 
321  // Compute the approximate value (return) from the critic neural net and use it to compute an
322  // advantage
323  auto value = evaluateValue(_input_tensor).detach();
324  auto advantage = _return_tensor - value;
325 
326  // If requested, standardize the advantage
328  advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-10);
329 
330  for (unsigned int epoch = 0; epoch < _num_epochs; ++epoch)
331  {
332  // Get the approximate return from the neural net again (this one does have an associated
333  // gradient)
335  // Get the approximate logarithmic action probability using the control neural net
336  auto curr_log_probability = evaluateAction(_input_tensor, _output_tensor);
337 
338  // Prepare the ratio by using the e^(logx-logy)=x/y expression
339  auto ratio = (curr_log_probability - _log_probability_tensor).exp();
340 
341  // Use clamping for limiting
342  auto surr1 = ratio * advantage;
343  auto surr2 = torch::clamp(ratio, 1.0 - _clip_param, 1.0 + _clip_param) * advantage;
344 
345  // Compute loss values for the critic and the control neural net
346  auto actor_loss = -torch::min(surr1, surr2).mean();
347  auto critic_loss = torch::mse_loss(value, _return_tensor);
348 
349  // Update the weights in the neural nets
350  actor_optimizer.zero_grad();
351  actor_loss.backward();
352  actor_optimizer.step();
353 
354  critic_optimizer.zero_grad();
355  critic_loss.backward();
356  critic_optimizer.step();
357 
358  // print loss per epoch
360  if (epoch % _loss_print_frequency == 0)
361  _console << "Epoch: " << epoch << " | Actor Loss: " << COLOR_GREEN
362  << actor_loss.item<double>() << COLOR_DEFAULT << " | Critic Loss: " << COLOR_GREEN
363  << critic_loss.item<double>() << COLOR_DEFAULT << std::endl;
364  }
365 
366  // Save the controller neural net so our controller can read it, we also save the critic if we
367  // want to continue training
368  torch::save(_control_nn, _control_nn->name());
369  torch::save(_critic_nn, _critic_nn->name());
370 }
torch::Tensor _input_tensor
Torch::tensor version of the input and action data.
auto exp(const T &)
torch::Tensor evaluateAction(torch::Tensor &input, torch::Tensor &output)
Function which evaluates the control net and then computes the logarithmic probability of the action...
const Real _clip_param
The clip parameter used while clamping the probability ratio in the surrogate objective.
const Real _control_learning_rate
The learning rate for the optimization algorithm for the control.
const Real _critic_learning_rate
The learning rate for the optimization algorithm for the critic.
const unsigned int _loss_print_frequency
The frequency the loss should be printed.
const unsigned int _num_epochs
Number of epochs for the training of the emulator.
std::shared_ptr< Moose::LibtorchArtificialNeuralNet > _critic_nn
Pointer to the critic neural net object.
torch::Tensor evaluateValue(torch::Tensor &input)
Function which evaluates the critic to get the value (discounted reward)
const ConsoleStream _console
const bool _standardize_advantage
Switch to enable the standardization of the advantages.
std::shared_ptr< Moose::LibtorchArtificialNeuralNet > _control_nn
Pointer to the control (or actor) neural net object.
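
The actor loss in the listing is the PPO clipped surrogate: for each sample, the smaller of `ratio * advantage` and `clamp(ratio, 1 - clip, 1 + clip) * advantage`, averaged and negated. The sketch below reproduces that arithmetic on plain vectors so it can be checked by hand. The function names are hypothetical, and the three input vectors are assumed to have equal length.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: clipped surrogate objective for a single sample.
double clippedSurrogate(double ratio, double advantage, double clip)
{
  // Limit how far the new policy may move away from the old one
  const double clamped = std::min(std::max(ratio, 1.0 - clip), 1.0 + clip);

  // Taking the minimum keeps the pessimistic estimate of the improvement
  return std::min(ratio * advantage, clamped * advantage);
}

// Hypothetical helper: actor loss over a batch, i.e. the negated mean
// surrogate. The ratio is recovered from log-probabilities via
// exp(log p_new - log p_old) = p_new / p_old.
double actorLoss(const std::vector<double> & new_log_prob,
                 const std::vector<double> & old_log_prob,
                 const std::vector<double> & advantage,
                 double clip)
{
  if (advantage.empty())
    return 0.0;

  double sum = 0.0;
  for (std::size_t i = 0; i < advantage.size(); ++i)
  {
    const double ratio = std::exp(new_log_prob[i] - old_log_prob[i]);
    sum += clippedSurrogate(ratio, advantage[i], clip);
  }
  return -sum / static_cast<double>(advantage.size());
}
```

With a clip value of 0.2, a ratio of 1.5 and a positive advantage of 1 yields 1.2 (the gain is capped), whereas the same ratio with an advantage of -1 yields -1.5 (the penalty is not capped).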

◆ validParams()

InputParameters LibtorchDRLControlTrainer::validParams ( )
static

Definition at line 22 of file LibtorchDRLControlTrainer.C.

23 {
25 
26  params.addClassDescription(
27  "Trains a neural network controller using the Proximal Policy Optimization (PPO) algorithm.");
28 
29  params.addRequiredParam<std::vector<ReporterName>>(
30  "response", "Reporter values containing the response values from the model.");
31  params.addParam<std::vector<Real>>(
32  "response_shift_factors",
33  "A shift constant which will be used to shift the response values. This is used for the "
34  "manipulation of the neural net inputs for better training efficiency.");
35  params.addParam<std::vector<Real>>(
36  "response_scaling_factors",
37  "A normalization constant which will be used to divide the response values. This is used for "
38  "the manipulation of the neural net inputs for better training efficiency.");
39  params.addRequiredParam<std::vector<ReporterName>>(
40  "control",
41  "Reporters containing the values of the controlled quantities (control signals) from the "
42  "model simulations.");
43  params.addRequiredParam<std::vector<ReporterName>>(
44  "log_probability",
45  "Reporters containing the log probabilities of the actions taken during the simulations.");
46  params.addRequiredParam<ReporterName>(
47  "reward", "Reporter containing the earned time-dependent rewards from the simulation.");
48  params.addRangeCheckedParam<unsigned int>(
49  "input_timesteps",
50  1,
51  "1<=input_timesteps",
52  "Number of time steps to use in the input data, if larger than 1, "
53  "data from the previous timesteps will be used as inputs in the training.");
54  params.addParam<unsigned int>("skip_num_rows",
55  1,
56  "Number of rows to ignore from training. We usually skip the 1st "
57  "row from the reporter since it contains only initial values.");
58 
59  params.addRequiredParam<unsigned int>("num_epochs", "Number of epochs for the training.");
60 
61  params.addRequiredRangeCheckedParam<Real>(
62  "critic_learning_rate",
63  "0<critic_learning_rate",
64  "Learning rate (relaxation) for the emulator training.");
65  params.addRequiredParam<std::vector<unsigned int>>(
66  "num_critic_neurons_per_layer", "Number of neurons per layer in the emulator neural net.");
67  params.addParam<std::vector<std::string>>(
68  "critic_activation_functions",
69  std::vector<std::string>({"relu"}),
70  "The type of activation functions to use in the emulator neural net. It is either one value "
71  "or one value per hidden layer.");
72 
73  params.addRequiredRangeCheckedParam<Real>(
74  "control_learning_rate",
75  "0<control_learning_rate",
76  "Learning rate (relaxation) for the control neural net training.");
77  params.addRequiredParam<std::vector<unsigned int>>(
78  "num_control_neurons_per_layer",
79  "Number of neurons per layer for the control neural network.");
80  params.addParam<std::vector<std::string>>(
81  "control_activation_functions",
82  std::vector<std::string>({"relu"}),
83  "The type of activation functions to use in the control neural net. It "
84  "is either one value "
85  "or one value per hidden layer.");
86 
87  params.addParam<std::string>("filename_base",
88  "Filename used to output the neural net parameters.");
89 
90  params.addParam<unsigned int>(
91  "seed", 11, "Random number generator seed for stochastic optimizers.");
92 
93  params.addRequiredParam<std::vector<Real>>(
94  "action_standard_deviations", "Standard deviation value used while sampling the actions.");
95 
96  params.addParam<Real>(
97  "clip_parameter", 0.2, "Clip parameter used while clamping the advantage value.");
98  params.addRangeCheckedParam<unsigned int>(
99  "update_frequency",
100  1,
101  "1<=update_frequency",
102  "Number of transient simulation data to collect for updating the controller neural network.");
103 
104  params.addRangeCheckedParam<Real>(
105  "decay_factor",
106  1.0,
107  "0.0<=decay_factor<=1.0",
108  "Decay factor for calculating the return. This accounts for decreased "
109  "reward values from the later steps.");
110 
111  params.addParam<bool>(
112  "read_from_file", false, "Switch to read the neural network parameters from a file.");
113  params.addParam<bool>(
114  "shift_outputs",
115  true,
116  "If we would like to shift the outputs to realign the input-output pairs.");
117  params.addParam<bool>(
118  "standardize_advantage",
119  true,
120  "Switch to enable the shifting and normalization of the advantages in the PPO algorithm.");
121  params.addParam<unsigned int>("loss_print_frequency",
122  0,
123  "The frequency which is used to print the loss values. If 0, the "
124  "loss values are not printed.");
125  return params;
126 }

Member Data Documentation

◆ _action_std

const std::vector<Real> LibtorchDRLControlTrainer::_action_std
protected

Standard deviation for the actions.

Definition at line 161 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchDRLControlTrainer().

◆ _average_episode_reward

Real LibtorchDRLControlTrainer::_average_episode_reward
protected

Storage for the current average episode reward.

Definition at line 178 of file LibtorchDRLControlTrainer.h.

Referenced by averageEpisodeReward(), and computeAverageEpisodeReward().

◆ _clip_param

const Real LibtorchDRLControlTrainer::_clip_param
protected

The clip parameter of the PPO algorithm, used while clamping the probability ratio that multiplies the advantage value.

Definition at line 155 of file LibtorchDRLControlTrainer.h.

Referenced by trainController().
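
In the standard PPO formulation, a clip parameter bounds the ratio between the new and old action probabilities before that ratio is multiplied by the advantage. The following is a minimal sketch of that clipped objective, assuming plain std::vector inputs instead of torch tensors. The helper clippedSurrogateLoss is hypothetical and is not taken from LibtorchDRLControlTrainer.C.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of MOOSE): clipped PPO surrogate loss.
// ratio[i]     = new action probability / old action probability
// advantage[i] = estimated advantage of the action taken at sample i
double clippedSurrogateLoss(const std::vector<double> & ratio,
                            const std::vector<double> & advantage,
                            double clip_param)
{
  double sum = 0.0;
  for (std::size_t i = 0; i < ratio.size(); ++i)
  {
    // Keep the ratio inside [1 - clip_param, 1 + clip_param]
    const double clipped = std::clamp(ratio[i], 1.0 - clip_param, 1.0 + clip_param);
    // Take the more pessimistic of the unclipped and clipped terms
    sum += std::min(ratio[i] * advantage[i], clipped * advantage[i]);
  }
  // Negative mean, because optimizers minimize the loss
  return ratio.empty() ? 0.0 : -sum / static_cast<double>(ratio.size());
}
```

With the default clip_parameter of 0.2, ratios outside the interval [0.8, 1.2] stop contributing additional gradient in the direction that would enlarge the policy update.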

◆ _control_learning_rate

const Real LibtorchDRLControlTrainer::_control_learning_rate
protected

The learning rate for the optimization algorithm for the control.

Definition at line 149 of file LibtorchDRLControlTrainer.h.

Referenced by trainController().

◆ _control_names

const std::vector<ReporterName> LibtorchDRLControlTrainer::_control_names
protected

Control reporter names.

Definition at line 98 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchDRLControlTrainer().

◆ _control_nn

std::shared_ptr<Moose::LibtorchArtificialNeuralNet> LibtorchDRLControlTrainer::_control_nn
protected

Pointer to the control (or actor) neural net object.

Definition at line 187 of file LibtorchDRLControlTrainer.h.

Referenced by controlNeuralNet(), evaluateAction(), LibtorchDRLControlTrainer(), and trainController().

◆ _control_value_pointers

std::vector<const std::vector<Real> *> LibtorchDRLControlTrainer::_control_value_pointers
protected

Pointers to the current values of the control signals.

Definition at line 101 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and LibtorchDRLControlTrainer().

◆ _critic_learning_rate

const Real LibtorchDRLControlTrainer::_critic_learning_rate
protected

The learning rate for the optimization algorithm for the critic.

Definition at line 143 of file LibtorchDRLControlTrainer.h.

Referenced by trainController().

◆ _critic_nn

std::shared_ptr<Moose::LibtorchArtificialNeuralNet> LibtorchDRLControlTrainer::_critic_nn
protected

Pointer to the critic neural net object.

Definition at line 189 of file LibtorchDRLControlTrainer.h.

Referenced by evaluateValue(), LibtorchDRLControlTrainer(), and trainController().

◆ _decay_factor

const Real LibtorchDRLControlTrainer::_decay_factor
protected

Decaying factor that is used when calculating the return from the reward.

Definition at line 158 of file LibtorchDRLControlTrainer.h.

Referenced by computeRewardToGo().
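
As an illustration of how a decay factor turns rewards into returns, the sketch below accumulates a discounted reward-to-go by walking backwards through one episode. It assumes a single episode stored in a std::vector; the function name rewardToGo is hypothetical and the actual computeRewardToGo() implementation may differ.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of MOOSE): discounted reward-to-go.
// returns[i] = rewards[i] + decay * rewards[i+1] + decay^2 * rewards[i+2] + ...
std::vector<double> rewardToGo(const std::vector<double> & rewards, double decay_factor)
{
  std::vector<double> returns(rewards.size(), 0.0);
  double running = 0.0;
  // Walk from the last time step to the first so each return reuses the next one
  for (std::size_t i = rewards.size(); i-- > 0;)
  {
    running = rewards[i] + decay_factor * running;
    returns[i] = running;
  }
  return returns;
}
```

With decay_factor equal to 1.0 (the default shown in validParams()), every return is simply the sum of all remaining rewards in the episode.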

◆ _filename_base

const std::string LibtorchDRLControlTrainer::_filename_base
protected

Name of the PyTorch output file.

This is used for loading and storing already existing data.

Definition at line 165 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchDRLControlTrainer().

◆ _input_data

std::vector<std::vector<Real> > LibtorchDRLControlTrainer::_input_data
protected

The data gathered from the reporters. Each row represents one QoI and each column represents one time step.

Definition at line 125 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and resetData().

◆ _input_tensor

torch::Tensor LibtorchDRLControlTrainer::_input_tensor
protected

torch::Tensor version of the input and action data.

Definition at line 195 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and trainController().

◆ _input_timesteps

const unsigned int LibtorchDRLControlTrainer::_input_timesteps
protected

Number of timesteps to fetch from the reporters to be the input of the neural nets.

Definition at line 116 of file LibtorchDRLControlTrainer.h.

Referenced by execute().

◆ _log_probability_data

std::vector<std::vector<Real> > LibtorchDRLControlTrainer::_log_probability_data
protected

Definition at line 127 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and resetData().

◆ _log_probability_names

const std::vector<ReporterName> LibtorchDRLControlTrainer::_log_probability_names
protected

Log probability reporter names.

Definition at line 104 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchDRLControlTrainer().

◆ _log_probability_tensor

torch::Tensor LibtorchDRLControlTrainer::_log_probability_tensor
protected

Definition at line 198 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and trainController().

◆ _log_probability_value_pointers

std::vector<const std::vector<Real> *> LibtorchDRLControlTrainer::_log_probability_value_pointers
protected

Pointers to the current values of the control log probabilities.

Definition at line 107 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and LibtorchDRLControlTrainer().

◆ _loss_print_frequency

const unsigned int LibtorchDRLControlTrainer::_loss_print_frequency
protected

The frequency the loss should be printed.

Definition at line 184 of file LibtorchDRLControlTrainer.h.

Referenced by trainController().

◆ _num_control_neurons_per_layer

const std::vector<unsigned int> LibtorchDRLControlTrainer::_num_control_neurons_per_layer
protected

Number of neurons within the hidden layers in the control neural net.

Definition at line 146 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchDRLControlTrainer().

◆ _num_critic_neurons_per_layer

const std::vector<unsigned int> LibtorchDRLControlTrainer::_num_critic_neurons_per_layer
protected

Number of neurons within the hidden layers in the critic neural net.

Definition at line 140 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchDRLControlTrainer().

◆ _num_epochs

const unsigned int LibtorchDRLControlTrainer::_num_epochs
protected

Number of epochs for the training of the neural nets.

Definition at line 137 of file LibtorchDRLControlTrainer.h.

Referenced by trainController().

◆ _num_inputs

unsigned int LibtorchDRLControlTrainer::_num_inputs
protected

Number of inputs for the control and critic neural nets.

Definition at line 119 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchDRLControlTrainer().

◆ _num_outputs

unsigned int LibtorchDRLControlTrainer::_num_outputs
protected

Number of outputs for the control neural network.

Definition at line 121 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchDRLControlTrainer().

◆ _output_data

std::vector<std::vector<Real> > LibtorchDRLControlTrainer::_output_data
protected

Definition at line 126 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and resetData().

◆ _output_tensor

torch::Tensor LibtorchDRLControlTrainer::_output_tensor
protected

Definition at line 196 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and trainController().

◆ _read_from_file

const bool LibtorchDRLControlTrainer::_read_from_file
protected

Switch indicating if an already existing neural net should be read from a file or not.

This can be used to load existing torch files from previous MOOSE runs for retraining and further manipulation.

Definition at line 170 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchDRLControlTrainer().

◆ _response_names

const std::vector<ReporterName> LibtorchDRLControlTrainer::_response_names
protected

Response reporter names.

Definition at line 86 of file LibtorchDRLControlTrainer.h.

Referenced by LibtorchDRLControlTrainer().

◆ _response_scaling_factors

const std::vector<Real> LibtorchDRLControlTrainer::_response_scaling_factors
protected

Scaling constants for the responses.

Definition at line 95 of file LibtorchDRLControlTrainer.h.

Referenced by getInputDataFromReporter(), and LibtorchDRLControlTrainer().

◆ _response_shift_factors

const std::vector<Real> LibtorchDRLControlTrainer::_response_shift_factors
protected

Shifting constants for the responses.

Definition at line 92 of file LibtorchDRLControlTrainer.h.

Referenced by getInputDataFromReporter(), and LibtorchDRLControlTrainer().
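
The shifting and scaling constants are described as a shift applied to the response values and a constant that divides them. A minimal sketch of such a transformation follows, assuming the shift is subtracted first and the result is then divided by the scaling factor. Both the order and the sign convention are assumptions, and shiftAndScale is a hypothetical helper rather than the code in getInputDataFromReporter().

```cpp
#include <vector>

// Hypothetical helper (not part of MOOSE): normalize response values.
// ASSUMPTION: normalized = (value - shift) / scale
std::vector<double> shiftAndScale(const std::vector<double> & values, double shift, double scale)
{
  std::vector<double> out;
  out.reserve(values.size());
  for (const double v : values)
    out.push_back((v - shift) / scale);
  return out;
}
```

For example, temperatures near 300 with a shift of 300 and a scaling factor of 10 map to values of order one, which is typically easier for a neural net to train on.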

◆ _response_value_pointers

std::vector<const std::vector<Real> *> LibtorchDRLControlTrainer::_response_value_pointers
protected

Pointers to the current values of the responses.

Definition at line 89 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and LibtorchDRLControlTrainer().

◆ _return_data

std::vector<Real> LibtorchDRLControlTrainer::_return_data
protected

Definition at line 133 of file LibtorchDRLControlTrainer.h.

Referenced by computeRewardToGo(), execute(), and resetData().

◆ _return_tensor

torch::Tensor LibtorchDRLControlTrainer::_return_tensor
protected

Definition at line 197 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and trainController().

◆ _reward_data

std::vector<Real> LibtorchDRLControlTrainer::_reward_data
protected

The reward and return data. The return is calculated using the _reward_data.

Definition at line 132 of file LibtorchDRLControlTrainer.h.

Referenced by computeAverageEpisodeReward(), execute(), and resetData().

◆ _reward_name

const ReporterName LibtorchDRLControlTrainer::_reward_name
protected

Reward reporter name.

Definition at line 110 of file LibtorchDRLControlTrainer.h.

◆ _reward_value_pointer

const std::vector<Real>* LibtorchDRLControlTrainer::_reward_value_pointer
protected

Pointer to the current values of the reward.

Definition at line 113 of file LibtorchDRLControlTrainer.h.

Referenced by computeRewardToGo(), and execute().

◆ _shift_outputs

const bool LibtorchDRLControlTrainer::_shift_outputs
protected

Currently, the controls are executed after the user objects at the initial execution step in MOOSE.

Using a shift therefore realigns the corresponding input-output values while reading the reporters.

Definition at line 175 of file LibtorchDRLControlTrainer.h.

Referenced by getInputDataFromReporter(), getOutputDataFromReporter(), and getRewardDataFromReporter().

◆ _standardize_advantage

const bool LibtorchDRLControlTrainer::_standardize_advantage
protected

Switch to enable the standardization of the advantages.

Definition at line 181 of file LibtorchDRLControlTrainer.h.

Referenced by trainController().
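
Standardization here means shifting the advantages to zero mean and scaling them to unit standard deviation. The sketch below shows one way to do this on a std::vector, assuming the population variance and a small epsilon to avoid division by zero. Both choices are assumptions, and standardizeAdvantage is a hypothetical helper, not the code in trainController().

```cpp
#include <cmath>
#include <vector>

// Hypothetical helper (not part of MOOSE): in-place standardization.
// ASSUMPTION: population variance and a small epsilon in the denominator.
void standardizeAdvantage(std::vector<double> & advantage, double eps = 1e-10)
{
  if (advantage.empty())
    return;

  const double n = static_cast<double>(advantage.size());

  double mean = 0.0;
  for (const double a : advantage)
    mean += a;
  mean /= n;

  double variance = 0.0;
  for (const double a : advantage)
    variance += (a - mean) * (a - mean);
  variance /= n;

  const double stddev = std::sqrt(variance);
  for (double & a : advantage)
    a = (a - mean) / (stddev + eps);
}
```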

◆ _std

torch::Tensor LibtorchDRLControlTrainer::_std
protected

Standard deviation in tensor format, used for sampling the actual control value.

Definition at line 192 of file LibtorchDRLControlTrainer.h.

Referenced by evaluateAction(), and LibtorchDRLControlTrainer().
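
To illustrate how a standard deviation enters the log probability of an action, the sketch below evaluates the log density of a normal distribution. It assumes that each action is sampled from a Gaussian centered on the control net output, which is an assumption about the sampling model. The function gaussianLogProbability is hypothetical and is not the evaluateAction() implementation.

```cpp
#include <cmath>

// Hypothetical helper (not part of MOOSE): log density of N(mean, stddev^2) at `action`.
double gaussianLogProbability(double action, double mean, double stddev)
{
  const double pi = 3.14159265358979323846;
  const double z = (action - mean) / stddev;
  // log( 1 / (stddev * sqrt(2 pi)) * exp(-z^2 / 2) )
  return -0.5 * z * z - std::log(stddev) - 0.5 * std::log(2.0 * pi);
}
```

A smaller standard deviation concentrates the sampled actions around the network output, which reduces exploration during training.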

◆ _update_counter

unsigned int LibtorchDRLControlTrainer::_update_counter
private

Counter for number of transient simulations that have been run before updating the controller.

Definition at line 234 of file LibtorchDRLControlTrainer.h.

Referenced by execute(), and resetData().

◆ _update_frequency

const unsigned int LibtorchDRLControlTrainer::_update_frequency
protected

Number of transients to run and collect data from before updating the controller neural net.

Definition at line 152 of file LibtorchDRLControlTrainer.h.

Referenced by resetData().
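
The interplay between an update frequency and an update counter can be sketched as follows. This is a hypothetical illustration that counts collected transients upwards and resets after each update. The actual bookkeeping in execute() and resetData() may be organized differently, and UpdateScheduler is not a MOOSE class.

```cpp
// Hypothetical helper (not part of MOOSE): decide when to update the controller.
struct UpdateScheduler
{
  unsigned int update_frequency;   // transients to collect before an update
  unsigned int counter = 0;        // transients collected since the last update

  // Call once per finished transient; returns true when training should run.
  bool recordTransient()
  {
    ++counter;
    if (counter < update_frequency)
      return false;
    counter = 0; // start collecting data for the next update
    return true;
  }
};
```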


The documentation for this class was generated from the following files: