https://mooseframework.inl.gov
Public Types | Public Member Functions | Static Public Member Functions | Public Attributes | Protected Member Functions | Protected Attributes | List of all members
LibtorchDRLControl Class Reference

A time-dependent, neural-network-based controller that is trained with the Proximal Policy Optimization (PPO) algorithm. More...

#include <LibtorchDRLControl.h>

Inheritance diagram for LibtorchDRLControl:
[legend]

Public Types

typedef DataFileName DataFileParameterType
 

Public Member Functions

 LibtorchDRLControl (const InputParameters &parameters)
 Construct using input parameters. More...
 
virtual void execute () override
 We compute the actions in this function together with the corresponding logarithmic probabilities. More...
 
Real getSignalLogProbability (const unsigned int signal_index) const
 Get the logarithmic probability of the (signal_index)-th signal of the control neural net. More...
 
Real getSignal (const unsigned int signal_index) const
 
unsigned int numberOfControlSignals () const
 
void loadControlNeuralNet (const Moose::LibtorchArtificialNeuralNet &input_nn)
 
const Moose::LibtorchNeuralNetBase & controlNeuralNet () const
 
bool hasControlNeuralNet () const
 
std::vector< std::string > & getDependencies ()
 
virtual bool enabled () const
 
std::shared_ptr< MooseObject > getSharedPtr ()
 
std::shared_ptr< const MooseObject > getSharedPtr () const
 
MooseApp & getMooseApp () const
 
const std::string & type () const
 
virtual const std::string & name () const
 
std::string typeAndName () const
 
std::string errorPrefix (const std::string &error_type) const
 
void callMooseError (std::string msg, const bool with_prefix) const
 
MooseObjectParameterName uniqueParameterName (const std::string &parameter_name) const
 
const InputParameters & parameters () const
 
MooseObjectName uniqueName () const
 
const T & getParam (const std::string &name) const
 
std::vector< std::pair< T1, T2 > > getParam (const std::string &param1, const std::string &param2) const
 
const T * queryParam (const std::string &name) const
 
const T & getRenamedParam (const std::string &old_name, const std::string &new_name) const
 
T getCheckedPointerParam (const std::string &name, const std::string &error_string="") const
 
bool isParamValid (const std::string &name) const
 
bool isParamSetByUser (const std::string &nm) const
 
void paramError (const std::string &param, Args... args) const
 
void paramWarning (const std::string &param, Args... args) const
 
void paramInfo (const std::string &param, Args... args) const
 
void connectControllableParams (const std::string &parameter, const std::string &object_type, const std::string &object_name, const std::string &object_parameter) const
 
void mooseError (Args &&... args) const
 
void mooseErrorNonPrefixed (Args &&... args) const
 
void mooseDocumentedError (const std::string &repo_name, const unsigned int issue_num, Args &&... args) const
 
void mooseWarning (Args &&... args) const
 
void mooseWarningNonPrefixed (Args &&... args) const
 
void mooseDeprecated (Args &&... args) const
 
void mooseInfo (Args &&... args) const
 
std::string getDataFileName (const std::string &param) const
 
std::string getDataFileNameByName (const std::string &relative_path) const
 
std::string getDataFilePath (const std::string &relative_path) const
 
bool isImplicit ()
 
Moose::StateArg determineState () const
 
virtual void initialSetup ()
 
virtual void timestepSetup ()
 
virtual void jacobianSetup ()
 
virtual void residualSetup ()
 
virtual void subdomainSetup ()
 
virtual void customSetup (const ExecFlagType &)
 
const ExecFlagEnum & getExecuteOnEnum () const
 
const Function & getFunction (const std::string &name) const
 
const Function & getFunctionByName (const FunctionName &name) const
 
bool hasFunction (const std::string &param_name) const
 
bool hasFunctionByName (const FunctionName &name) const
 
UserObjectName getUserObjectName (const std::string &param_name) const
 
const T & getUserObject (const std::string &param_name, bool is_dependency=true) const
 
const T & getUserObjectByName (const UserObjectName &object_name, bool is_dependency=true) const
 
const UserObject & getUserObjectBase (const std::string &param_name, bool is_dependency=true) const
 
const UserObject & getUserObjectBaseByName (const UserObjectName &object_name, bool is_dependency=true) const
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObject (const std::string &param_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
bool hasUserObjectByName (const UserObjectName &object_name) const
 
PerfGraph & perfGraph ()
 
bool isDefaultPostprocessorValue (const std::string &param_name, const unsigned int index=0) const
 
bool hasPostprocessor (const std::string &param_name, const unsigned int index=0) const
 
bool hasPostprocessorByName (const PostprocessorName &name) const
 
std::size_t coupledPostprocessors (const std::string &param_name) const
 
const PostprocessorName & getPostprocessorName (const std::string &param_name, const unsigned int index=0) const
 
const VectorPostprocessorValue & getVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name) const
 
const VectorPostprocessorValue & getVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name, bool needs_broadcast) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name, bool needs_broadcast) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name, bool needs_broadcast) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
const VectorPostprocessorValue & getVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name, bool needs_broadcast) const
 
const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name) const
 
const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name) const
 
const ScatterVectorPostprocessorValue & getScatterVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
bool hasVectorPostprocessor (const std::string &param_name, const std::string &vector_name) const
 
bool hasVectorPostprocessor (const std::string &param_name) const
 
bool hasVectorPostprocessorByName (const VectorPostprocessorName &name, const std::string &vector_name) const
 
bool hasVectorPostprocessorByName (const VectorPostprocessorName &name) const
 
const VectorPostprocessorName & getVectorPostprocessorName (const std::string &param_name) const
 
const PostprocessorValue & getPostprocessorValue (const std::string &param_name, const unsigned int index=0) const
 
const PostprocessorValue & getPostprocessorValue (const std::string &param_name, const unsigned int index=0) const
 
const PostprocessorValue & getPostprocessorValueOld (const std::string &param_name, const unsigned int index=0) const
 
const PostprocessorValue & getPostprocessorValueOld (const std::string &param_name, const unsigned int index=0) const
 
const PostprocessorValue & getPostprocessorValueOlder (const std::string &param_name, const unsigned int index=0) const
 
const PostprocessorValue & getPostprocessorValueOlder (const std::string &param_name, const unsigned int index=0) const
 
virtual const PostprocessorValue & getPostprocessorValueByName (const PostprocessorName &name) const
 
virtual const PostprocessorValue & getPostprocessorValueByName (const PostprocessorName &name) const
 
const PostprocessorValue & getPostprocessorValueOldByName (const PostprocessorName &name) const
 
const PostprocessorValue & getPostprocessorValueOldByName (const PostprocessorName &name) const
 
const PostprocessorValue & getPostprocessorValueOlderByName (const PostprocessorName &name) const
 
const PostprocessorValue & getPostprocessorValueOlderByName (const PostprocessorName &name) const
 
bool isVectorPostprocessorDistributed (const std::string &param_name) const
 
bool isVectorPostprocessorDistributed (const std::string &param_name) const
 
bool isVectorPostprocessorDistributedByName (const VectorPostprocessorName &name) const
 
bool isVectorPostprocessorDistributedByName (const VectorPostprocessorName &name) const
 
const Parallel::Communicator & comm () const
 
processor_id_type n_processors () const
 
processor_id_type processor_id () const
 

Static Public Member Functions

static InputParameters validParams ()
 
static MultiMooseEnum getExecuteOptions ()
 

Public Attributes

const ConsoleStream _console
 

Protected Member Functions

torch::Tensor computeLogProbability (const torch::Tensor &action, const torch::Tensor &output_tensor)
 Function which computes the logarithmic probability of given actions. More...
 
void conditionalParameterError (const std::string &param_name, const std::vector< std::string > &conditional_param, bool should_be_defined=true)
 
void updateCurrentResponse ()
 
torch::Tensor prepareInputTensor ()
 
bool hasControllableParameterByName (const std::string &name) const
 
PerfID registerTimedSection (const std::string &section_name, const unsigned int level) const
 
PerfID registerTimedSection (const std::string &section_name, const unsigned int level, const std::string &live_message, const bool print_dots=true) const
 
std::string timedSectionName (const std::string &section_name) const
 
virtual void addUserObjectDependencyHelper (const UserObject &) const
 
T & declareRestartableData (const std::string &data_name, Args &&... args)
 
ManagedValue< T > declareManagedRestartableDataWithContext (const std::string &data_name, void *context, Args &&... args)
 
const T & getRestartableData (const std::string &data_name) const
 
T & declareRestartableDataWithContext (const std::string &data_name, void *context, Args &&... args)
 
T & declareRecoverableData (const std::string &data_name, Args &&... args)
 
T & declareRestartableDataWithObjectName (const std::string &data_name, const std::string &object_name, Args &&... args)
 
T & declareRestartableDataWithObjectNameWithContext (const std::string &data_name, const std::string &object_name, void *context, Args &&... args)
 
std::string restartableName (const std::string &data_name) const
 
virtual void addPostprocessorDependencyHelper (const PostprocessorName &) const
 
virtual void addVectorPostprocessorDependencyHelper (const VectorPostprocessorName &) const
 
ControllableParameter getControllableParameter (const std::string &param_name)
 
ControllableParameter getControllableParameter (const std::string &param_name)
 
ControllableParameter getControllableParameterByName (const std::string &param_name)
 
ControllableParameter getControllableParameterByName (const std::string &tag, const std::string &object_name, const std::string &param_name)
 
ControllableParameter getControllableParameterByName (const MooseObjectName &object_name, const std::string &param_name)
 
ControllableParameter getControllableParameterByName (const MooseObjectParameterName &param_name)
 
ControllableParameter getControllableParameterByName (const std::string &param_name)
 
ControllableParameter getControllableParameterByName (const std::string &tag, const std::string &object_name, const std::string &param_name)
 
ControllableParameter getControllableParameterByName (const MooseObjectName &object_name, const std::string &param_name)
 
ControllableParameter getControllableParameterByName (const MooseObjectParameterName &param_name)
 
T getControllableValue (const std::string &name, bool warn_when_values_differ=true)
 
T getControllableValue (const std::string &name, bool warn_when_values_differ=true)
 
T getControllableValueByName (const std::string &name, bool warn_when_values_differ=true)
 
T getControllableValueByName (const std::string &object_name, const std::string &param_name, bool warn_when_values_differ=true)
 
T getControllableValueByName (const MooseObjectName &object_name, const std::string &param_name, bool warn_when_values_differ=true)
 
T getControllableValueByName (const std::string &tag, const std::string &object_name, const std::string &param_name, bool warn_when_values_differ=true)
 
T getControllableValueByName (const MooseObjectParameterName &desired, bool warn_when_values_differ=true)
 
T getControllableValueByName (const std::string &name, bool warn_when_values_differ=true)
 
T getControllableValueByName (const std::string &object_name, const std::string &param_name, bool warn_when_values_differ=true)
 
T getControllableValueByName (const MooseObjectName &object_name, const std::string &param_name, bool warn_when_values_differ=true)
 
T getControllableValueByName (const std::string &tag, const std::string &object_name, const std::string &param_name, bool warn_when_values_differ=true)
 
T getControllableValueByName (const MooseObjectParameterName &desired, bool warn_when_values_differ=true)
 
void setControllableValue (const std::string &name, const T &value)
 
void setControllableValue (const std::string &name, const T &value)
 
void setControllableValueByName (const std::string &name, const T &value)
 
void setControllableValueByName (const std::string &object_name, const std::string &param_name, const T &value)
 
void setControllableValueByName (const MooseObjectName &object_name, const std::string &param_name, const T &value)
 
void setControllableValueByName (const std::string &tag, const std::string &object_name, const std::string &param_name, const T &value)
 
void setControllableValueByName (const MooseObjectParameterName &name, const T &value)
 
void setControllableValueByName (const std::string &name, const T &value)
 
void setControllableValueByName (const std::string &object_name, const std::string &param_name, const T &value)
 
void setControllableValueByName (const MooseObjectName &object_name, const std::string &param_name, const T &value)
 
void setControllableValueByName (const std::string &tag, const std::string &object_name, const std::string &param_name, const T &value)
 
void setControllableValueByName (const MooseObjectParameterName &name, const T &value)
 

Protected Attributes

std::vector< Real > _current_control_signal_log_probabilities
 The log probability of control signals from the last evaluation of the controller. More...
 
const std::vector< Real > _action_std
 Standard deviation for the actions, supplied by the user. More...
 
torch::Tensor _std
 Standard deviations converted to a 2D diagonal tensor that can be used by Libtorch routines. More...
 
std::vector< Real > _current_response
 
std::vector< std::vector< Real > > & _old_responses
 
const std::vector< std::string > & _control_names
 
std::vector< Real > _current_control_signals
 
const std::vector< PostprocessorName > & _response_names
 
std::vector< const Real * > _response_values
 
const unsigned int _input_timesteps
 
const std::vector< Real > _response_shift_factors
 
const std::vector< Real > _response_scaling_factors
 
const std::vector< Real > _action_scaling_factors
 
std::shared_ptr< Moose::LibtorchNeuralNetBase > _nn
 
FEProblemBase & _fe_problem
 
std::vector< std::string > _depends_on
 
const bool & _enabled
 
MooseApp & _app
 
const std::string _type
 
const std::string _name
 
const InputParameters & _pars
 
Factory & _factory
 
ActionFactory & _action_factory
 
MooseApp & _pg_moose_app
 
const std::string _prefix
 
const InputParameters & _ti_params
 
FEProblemBase & _ti_feproblem
 
bool _is_implicit
 
Real & _t
 
const Real & _t_old
 
int & _t_step
 
Real & _dt
 
Real & _dt_old
 
bool _is_transient
 
const ExecFlagEnum & _execute_enum
 
const ExecFlagType & _current_execute_flag
 
MooseApp & _restartable_app
 
const std::string _restartable_system_name
 
const THREAD_ID _restartable_tid
 
const bool _restartable_read_only
 
const Parallel::Communicator & _communicator
 

Detailed Description

A time-dependent, neural-network-based controller that is trained with the Proximal Policy Optimization (PPO) algorithm.

This controller is used while training a neural-network-based controller. Its additional functionality is that it adds variability to the control actions by sampling them from an assumed Gaussian distribution, which helps avoid overfitting. It is meant to be used in conjunction with LibtorchDRLControlTrainer.

Definition at line 24 of file LibtorchDRLControl.h.

Constructor & Destructor Documentation

◆ LibtorchDRLControl()

LibtorchDRLControl::LibtorchDRLControl ( const InputParameters & parameters )

Construct using input parameters.

Definition at line 33 of file LibtorchDRLControl.C.

35  _current_control_signal_log_probabilities(std::vector<Real>(_control_names.size(), 0.0)),
36  _action_std(getParam<std::vector<Real>>("action_standard_deviations"))
37 {
38  if (_control_names.size() != _action_std.size())
39  paramError("action_standard_deviations",
40  "Number of action_standard_deviations does not match the number of controlled "
41  "parameters.");
42 
43  // Fixing the RNG seed to make sure every experiment is the same.
44  if (isParamValid("seed"))
45  torch::manual_seed(getParam<unsigned int>("seed"));
46 
47  // We convert and store the user-supplied standard deviations into a tensor which can be easily
48  // used by routines in libtorch
49  _std = torch::eye(_control_names.size());
50  for (unsigned int i = 0; i < _control_names.size(); ++i)
51  _std[i][i] = _action_std[i];
52 }
LibtorchNeuralNetControl(const InputParameters &parameters)
const std::vector< std::string > & _control_names
bool isParamValid(const std::string &name) const
const std::vector< Real > _action_std
Standard deviation for the actions, supplied by the user.
const T & getParam(const std::string &name) const
void paramError(const std::string &param, Args... args) const
const InputParameters & parameters() const
torch::Tensor _std
Standard deviations converted to a 2D diagonal tensor that can be used by Libtorch routines...
std::vector< Real > _current_control_signal_log_probabilities
The log probability of control signals from the last evaluation of the controller.

Member Function Documentation

◆ computeLogProbability()

torch::Tensor LibtorchDRLControl::computeLogProbability ( const torch::Tensor & action,
const torch::Tensor & output_tensor 
)
protected

Function which computes the logarithmic probability of given actions.

Parameters
action: The tensor containing the perturbed control signals (also known as the action of the controller)
output_tensor: The expected values of the signals, as predicted by the neural net
Returns
The logarithmic probability of the action with respect to the neural net prediction

Definition at line 105 of file LibtorchDRLControl.C.

Referenced by execute().

107 {
108  // Logarithmic probability of taken action, given the current distribution.
109  torch::Tensor var = torch::matmul(_std, _std);
110 
111  return -((action - output_tensor) * (action - output_tensor)) / (2.0 * var) - torch::log(_std) -
112  std::log(std::sqrt(2.0 * M_PI));
113 }
torch::Tensor _std
Standard deviations converted to a 2D diagonal tensor that can be used by Libtorch routines...

◆ execute()

void LibtorchDRLControl::execute ( )
overridevirtual

We compute the actions in this function together with the corresponding logarithmic probabilities.

Reimplemented from LibtorchNeuralNetControl.

Definition at line 55 of file LibtorchDRLControl.C.

56 {
57  if (_nn)
58  {
59  unsigned int n_controls = _control_names.size();
60  unsigned int num_old_timesteps = _input_timesteps - 1;
61 
62  // Fill a vector with the current values of the responses
64 
65  // If this is the first time this control is called and we need to use older values, fill up the
66  // needed old values using the initial values
67  if (_old_responses.empty())
68  _old_responses.assign(num_old_timesteps, _current_response);
69 
70  // Organize the old an current solution into a tensor so we can evaluate the neural net
71  torch::Tensor input_tensor = prepareInputTensor();
72 
73  // Evaluate the neural network to get the expected control value
74  torch::Tensor output_tensor = _nn->forward(input_tensor);
75 
76  // Sample control value (action) from Gaussian distribution
77  torch::Tensor action = at::normal(output_tensor, _std);
78 
79  // Compute log probability
80  torch::Tensor log_probability = computeLogProbability(action, output_tensor);
81 
82  // Convert data
83  _current_control_signals = {action.data_ptr<Real>(), action.data_ptr<Real>() + action.size(1)};
84 
85  _current_control_signal_log_probabilities = {log_probability.data_ptr<Real>(),
86  log_probability.data_ptr<Real>() +
87  log_probability.size(1)};
88 
89  for (unsigned int control_i = 0; control_i < n_controls; ++control_i)
90  {
91  // We scale the controllable value for physically meaningful control action
92  setControllableValueByName<Real>(_control_names[control_i],
93  _current_control_signals[control_i] *
94  _action_scaling_factors[control_i]);
95  }
96 
97  // We add the curent solution to the old solutions and move everything in there one step
98  // backward
99  std::rotate(_old_responses.rbegin(), _old_responses.rbegin() + 1, _old_responses.rend());
101  }
102 }
const std::vector< Real > _action_scaling_factors
std::vector< Real > _current_response
std::shared_ptr< Moose::LibtorchNeuralNetBase > _nn
torch::Tensor prepareInputTensor()
const std::vector< std::string > & _control_names
std::vector< std::vector< Real > > & _old_responses
DIE A HORRIBLE DEATH HERE typedef LIBMESH_DEFAULT_SCALAR_TYPE Real
const unsigned int _input_timesteps
torch::Tensor _std
Standard deviations converted to a 2D diagonal tensor that can be used by Libtorch routines...
std::vector< Real > _current_control_signal_log_probabilities
The log probability of control signals from the last evaluation of the controller.
std::vector< Real > _current_control_signals
torch::Tensor computeLogProbability(const torch::Tensor &action, const torch::Tensor &output_tensor)
Function which computes the logarithmic probability of given actions.

◆ getSignalLogProbability()

Real LibtorchDRLControl::getSignalLogProbability ( const unsigned int  signal_index) const

Get the logarithmic probability of the (signal_index)-th signal of the control neural net.

Parameters
signal_index: The index of the signal
Returns
The logarithmic probability of the (signal_index)-th signal

Definition at line 116 of file LibtorchDRLControl.C.

Referenced by LibtorchDRLLogProbabilityPostprocessor::getValue().

117 {
118  mooseAssert(signal_index < _control_names.size(),
119  "The index of the requested control signal is not in the [0," +
120  std::to_string(_control_names.size()) + ") range!");
121  return _current_control_signal_log_probabilities[signal_index];
122 }
const std::vector< std::string > & _control_names
std::vector< Real > _current_control_signal_log_probabilities
The log probability of control signals from the last evaluation of the controller.

◆ validParams()

InputParameters LibtorchDRLControl::validParams ( )
static

Definition at line 20 of file LibtorchDRLControl.C.

21 {
23  params.addClassDescription(
24  "Sets the value of multiple 'Real' input parameters and postprocessors based on a Deep "
25  "Reinforcement Learning (DRL) neural network trained using a PPO algorithm.");
26  params.addRequiredParam<std::vector<Real>>(
27  "action_standard_deviations", "Standard deviation value used while sampling the actions.");
28  params.addParam<unsigned int>("seed", "Seed for the random number generator.");
29 
30  return params;
31 }
void addParam(const std::string &name, const std::initializer_list< typename T::value_type > &value, const std::string &doc_string)
static InputParameters validParams()
void addRequiredParam(const std::string &name, const std::string &doc_string)
void addClassDescription(const std::string &doc_string)

Member Data Documentation

◆ _action_std

const std::vector<Real> LibtorchDRLControl::_action_std
protected

Standard deviation for the actions, supplied by the user.

Definition at line 57 of file LibtorchDRLControl.h.

Referenced by LibtorchDRLControl().

◆ _current_control_signal_log_probabilities

std::vector<Real> LibtorchDRLControl::_current_control_signal_log_probabilities
protected

The log probability of control signals from the last evaluation of the controller.

Definition at line 54 of file LibtorchDRLControl.h.

Referenced by execute(), and getSignalLogProbability().

◆ _std

torch::Tensor LibtorchDRLControl::_std
protected

Standard deviations converted to a 2D diagonal tensor that can be used by Libtorch routines.

Definition at line 60 of file LibtorchDRLControl.h.

Referenced by computeLogProbability(), execute(), and LibtorchDRLControl().


The documentation for this class was generated from the following files: