A time-dependent, neural-network-based controller that is associated with Proximal Policy Optimization (PPO). More...
#include <LibtorchDRLControl.h>
Public Types | |
typedef DataFileName | DataFileParameterType |
Public Member Functions | |
LibtorchDRLControl (const InputParameters &parameters) | |
Construct using input parameters. More... | |
virtual void | execute () override |
We compute the actions in this function together with the corresponding logarithmic probabilities. More... | |
Real | getSignalLogProbability (const unsigned int signal_index) const |
Get the logarithmic probability of the (signal_index)-th signal of the control neural net. More... | |
Real | getSignal (const unsigned int signal_index) const |
unsigned int | numberOfControlSignals () const |
void | loadControlNeuralNet (const Moose::LibtorchArtificialNeuralNet &input_nn) |
const Moose::LibtorchNeuralNetBase & | controlNeuralNet () const |
bool | hasControlNeuralNet () const |
std::vector< std::string > & | getDependencies () |
virtual bool | enabled () const |
std::shared_ptr< MooseObject > | getSharedPtr () |
std::shared_ptr< const MooseObject > | getSharedPtr () const |
MooseApp & | getMooseApp () const |
const std::string & | type () const |
virtual const std::string & | name () const |
std::string | typeAndName () const |
std::string | errorPrefix (const std::string &error_type) const |
void | callMooseError (std::string msg, const bool with_prefix) const |
MooseObjectParameterName | uniqueParameterName (const std::string &parameter_name) const |
const InputParameters & | parameters () const |
MooseObjectName | uniqueName () const |
const T & | getParam (const std::string &name) const |
std::vector< std::pair< T1, T2 > > | getParam (const std::string &param1, const std::string &param2) const |
const T * | queryParam (const std::string &name) const |
const T & | getRenamedParam (const std::string &old_name, const std::string &new_name) const |
T | getCheckedPointerParam (const std::string &name, const std::string &error_string="") const |
bool | isParamValid (const std::string &name) const |
bool | isParamSetByUser (const std::string &nm) const |
void | paramError (const std::string &param, Args... args) const |
void | paramWarning (const std::string &param, Args... args) const |
void | paramInfo (const std::string &param, Args... args) const |
void | connectControllableParams (const std::string &parameter, const std::string &object_type, const std::string &object_name, const std::string &object_parameter) const |
void | mooseError (Args &&... args) const |
void | mooseErrorNonPrefixed (Args &&... args) const |
void | mooseDocumentedError (const std::string &repo_name, const unsigned int issue_num, Args &&... args) const |
void | mooseWarning (Args &&... args) const |
void | mooseWarningNonPrefixed (Args &&... args) const |
void | mooseDeprecated (Args &&... args) const |
void | mooseInfo (Args &&... args) const |
std::string | getDataFileName (const std::string &param) const |
std::string | getDataFileNameByName (const std::string &relative_path) const |
std::string | getDataFilePath (const std::string &relative_path) const |
bool | isImplicit () |
Moose::StateArg | determineState () const |
virtual void | initialSetup () |
virtual void | timestepSetup () |
virtual void | jacobianSetup () |
virtual void | residualSetup () |
virtual void | subdomainSetup () |
virtual void | customSetup (const ExecFlagType &) |
const ExecFlagEnum & | getExecuteOnEnum () const |
const Function & | getFunction (const std::string &name) const |
const Function & | getFunctionByName (const FunctionName &name) const |
bool | hasFunction (const std::string &param_name) const |
bool | hasFunctionByName (const FunctionName &name) const |
UserObjectName | getUserObjectName (const std::string &param_name) const |
const T & | getUserObject (const std::string &param_name, bool is_dependency=true) const |
const T & | getUserObjectByName (const UserObjectName &object_name, bool is_dependency=true) const |
const UserObject & | getUserObjectBase (const std::string &param_name, bool is_dependency=true) const |
const UserObject & | getUserObjectBaseByName (const UserObjectName &object_name, bool is_dependency=true) const |
bool | hasUserObject (const std::string &param_name) const |
bool | hasUserObject (const std::string &param_name) const |
bool | hasUserObject (const std::string &param_name) const |
bool | hasUserObject (const std::string &param_name) const |
bool | hasUserObjectByName (const UserObjectName &object_name) const |
bool | hasUserObjectByName (const UserObjectName &object_name) const |
bool | hasUserObjectByName (const UserObjectName &object_name) const |
bool | hasUserObjectByName (const UserObjectName &object_name) const |
PerfGraph & | perfGraph () |
bool | isDefaultPostprocessorValue (const std::string &param_name, const unsigned int index=0) const |
bool | hasPostprocessor (const std::string &param_name, const unsigned int index=0) const |
bool | hasPostprocessorByName (const PostprocessorName &name) const |
std::size_t | coupledPostprocessors (const std::string &param_name) const |
const PostprocessorName & | getPostprocessorName (const std::string &param_name, const unsigned int index=0) const |
const VectorPostprocessorValue & | getVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name) const |
const VectorPostprocessorValue & | getVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name, bool needs_broadcast) const |
const VectorPostprocessorValue & | getVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name) const |
const VectorPostprocessorValue & | getVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name, bool needs_broadcast) const |
const VectorPostprocessorValue & | getVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name) const |
const VectorPostprocessorValue & | getVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name, bool needs_broadcast) const |
const VectorPostprocessorValue & | getVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name) const |
const VectorPostprocessorValue & | getVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name, bool needs_broadcast) const |
const ScatterVectorPostprocessorValue & | getScatterVectorPostprocessorValue (const std::string &param_name, const std::string &vector_name) const |
const ScatterVectorPostprocessorValue & | getScatterVectorPostprocessorValueByName (const VectorPostprocessorName &name, const std::string &vector_name) const |
const ScatterVectorPostprocessorValue & | getScatterVectorPostprocessorValueOld (const std::string &param_name, const std::string &vector_name) const |
const ScatterVectorPostprocessorValue & | getScatterVectorPostprocessorValueOldByName (const VectorPostprocessorName &name, const std::string &vector_name) const |
bool | hasVectorPostprocessor (const std::string &param_name, const std::string &vector_name) const |
bool | hasVectorPostprocessor (const std::string &param_name) const |
bool | hasVectorPostprocessorByName (const VectorPostprocessorName &name, const std::string &vector_name) const |
bool | hasVectorPostprocessorByName (const VectorPostprocessorName &name) const |
const VectorPostprocessorName & | getVectorPostprocessorName (const std::string &param_name) const |
const PostprocessorValue & | getPostprocessorValue (const std::string &param_name, const unsigned int index=0) const |
const PostprocessorValue & | getPostprocessorValue (const std::string &param_name, const unsigned int index=0) const |
const PostprocessorValue & | getPostprocessorValueOld (const std::string &param_name, const unsigned int index=0) const |
const PostprocessorValue & | getPostprocessorValueOld (const std::string &param_name, const unsigned int index=0) const |
const PostprocessorValue & | getPostprocessorValueOlder (const std::string &param_name, const unsigned int index=0) const |
const PostprocessorValue & | getPostprocessorValueOlder (const std::string &param_name, const unsigned int index=0) const |
virtual const PostprocessorValue & | getPostprocessorValueByName (const PostprocessorName &name) const |
virtual const PostprocessorValue & | getPostprocessorValueByName (const PostprocessorName &name) const |
const PostprocessorValue & | getPostprocessorValueOldByName (const PostprocessorName &name) const |
const PostprocessorValue & | getPostprocessorValueOldByName (const PostprocessorName &name) const |
const PostprocessorValue & | getPostprocessorValueOlderByName (const PostprocessorName &name) const |
const PostprocessorValue & | getPostprocessorValueOlderByName (const PostprocessorName &name) const |
bool | isVectorPostprocessorDistributed (const std::string &param_name) const |
bool | isVectorPostprocessorDistributed (const std::string &param_name) const |
bool | isVectorPostprocessorDistributedByName (const VectorPostprocessorName &name) const |
bool | isVectorPostprocessorDistributedByName (const VectorPostprocessorName &name) const |
const Parallel::Communicator & | comm () const |
processor_id_type | n_processors () const |
processor_id_type | processor_id () const |
Static Public Member Functions | |
static InputParameters | validParams () |
static MultiMooseEnum | getExecuteOptions () |
Public Attributes | |
const ConsoleStream | _console |
Protected Member Functions | |
torch::Tensor | computeLogProbability (const torch::Tensor &action, const torch::Tensor &output_tensor) |
Function which computes the logarithmic probability of given actions. More... | |
void | conditionalParameterError (const std::string &param_name, const std::vector< std::string > &conditional_param, bool should_be_defined=true) |
void | updateCurrentResponse () |
torch::Tensor | prepareInputTensor () |
bool | hasControllableParameterByName (const std::string &name) const |
PerfID | registerTimedSection (const std::string &section_name, const unsigned int level) const |
PerfID | registerTimedSection (const std::string &section_name, const unsigned int level, const std::string &live_message, const bool print_dots=true) const |
std::string | timedSectionName (const std::string &section_name) const |
virtual void | addUserObjectDependencyHelper (const UserObject &) const |
T & | declareRestartableData (const std::string &data_name, Args &&... args) |
ManagedValue< T > | declareManagedRestartableDataWithContext (const std::string &data_name, void *context, Args &&... args) |
const T & | getRestartableData (const std::string &data_name) const |
T & | declareRestartableDataWithContext (const std::string &data_name, void *context, Args &&... args) |
T & | declareRecoverableData (const std::string &data_name, Args &&... args) |
T & | declareRestartableDataWithObjectName (const std::string &data_name, const std::string &object_name, Args &&... args) |
T & | declareRestartableDataWithObjectNameWithContext (const std::string &data_name, const std::string &object_name, void *context, Args &&... args) |
std::string | restartableName (const std::string &data_name) const |
virtual void | addPostprocessorDependencyHelper (const PostprocessorName &) const |
virtual void | addVectorPostprocessorDependencyHelper (const VectorPostprocessorName &) const |
ControllableParameter | getControllableParameter (const std::string &param_name) |
ControllableParameter | getControllableParameter (const std::string &param_name) |
ControllableParameter | getControllableParameterByName (const std::string &param_name) |
ControllableParameter | getControllableParameterByName (const std::string &tag, const std::string &object_name, const std::string &param_name) |
ControllableParameter | getControllableParameterByName (const MooseObjectName &object_name, const std::string &param_name) |
ControllableParameter | getControllableParameterByName (const MooseObjectParameterName &param_name) |
ControllableParameter | getControllableParameterByName (const std::string &param_name) |
ControllableParameter | getControllableParameterByName (const std::string &tag, const std::string &object_name, const std::string &param_name) |
ControllableParameter | getControllableParameterByName (const MooseObjectName &object_name, const std::string &param_name) |
ControllableParameter | getControllableParameterByName (const MooseObjectParameterName &param_name) |
T | getControllableValue (const std::string &name, bool warn_when_values_differ=true) |
T | getControllableValue (const std::string &name, bool warn_when_values_differ=true) |
T | getControllableValueByName (const std::string &name, bool warn_when_values_differ=true) |
T | getControllableValueByName (const std::string &object_name, const std::string &param_name, bool warn_when_values_differ=true) |
T | getControllableValueByName (const MooseObjectName &object_name, const std::string &param_name, bool warn_when_values_differ=true) |
T | getControllableValueByName (const std::string &tag, const std::string &object_name, const std::string &param_name, bool warn_when_values_differ=true) |
T | getControllableValueByName (const MooseObjectParameterName &desired, bool warn_when_values_differ=true) |
T | getControllableValueByName (const std::string &name, bool warn_when_values_differ=true) |
T | getControllableValueByName (const std::string &object_name, const std::string &param_name, bool warn_when_values_differ=true) |
T | getControllableValueByName (const MooseObjectName &object_name, const std::string &param_name, bool warn_when_values_differ=true) |
T | getControllableValueByName (const std::string &tag, const std::string &object_name, const std::string &param_name, bool warn_when_values_differ=true) |
T | getControllableValueByName (const MooseObjectParameterName &desired, bool warn_when_values_differ=true) |
void | setControllableValue (const std::string &name, const T &value) |
void | setControllableValue (const std::string &name, const T &value) |
void | setControllableValueByName (const std::string &name, const T &value) |
void | setControllableValueByName (const std::string &object_name, const std::string &param_name, const T &value) |
void | setControllableValueByName (const MooseObjectName &object_name, const std::string &param_name, const T &value) |
void | setControllableValueByName (const std::string &tag, const std::string &object_name, const std::string &param_name, const T &value) |
void | setControllableValueByName (const MooseObjectParameterName &name, const T &value) |
void | setControllableValueByName (const std::string &name, const T &value) |
void | setControllableValueByName (const std::string &object_name, const std::string &param_name, const T &value) |
void | setControllableValueByName (const MooseObjectName &object_name, const std::string &param_name, const T &value) |
void | setControllableValueByName (const std::string &tag, const std::string &object_name, const std::string &param_name, const T &value) |
void | setControllableValueByName (const MooseObjectParameterName &name, const T &value) |
Protected Attributes | |
std::vector< Real > | _current_control_signal_log_probabilities |
The log probability of control signals from the last evaluation of the controller. More... | |
const std::vector< Real > | _action_std |
Standard deviation for the actions, supplied by the user. More... | |
torch::Tensor | _std |
Standard deviations converted to a 2D diagonal tensor that can be used by Libtorch routines. More... | |
std::vector< Real > | _current_response |
std::vector< std::vector< Real > > & | _old_responses |
const std::vector< std::string > & | _control_names |
std::vector< Real > | _current_control_signals |
const std::vector< PostprocessorName > & | _response_names |
std::vector< const Real *> | _response_values |
const unsigned int | _input_timesteps |
const std::vector< Real > | _response_shift_factors |
const std::vector< Real > | _response_scaling_factors |
const std::vector< Real > | _action_scaling_factors |
std::shared_ptr< Moose::LibtorchNeuralNetBase > | _nn |
FEProblemBase & | _fe_problem |
std::vector< std::string > | _depends_on |
const bool & | _enabled |
MooseApp & | _app |
const std::string | _type |
const std::string | _name |
const InputParameters & | _pars |
Factory & | _factory |
ActionFactory & | _action_factory |
MooseApp & | _pg_moose_app |
const std::string | _prefix |
const InputParameters & | _ti_params |
FEProblemBase & | _ti_feproblem |
bool | _is_implicit |
Real & | _t |
const Real & | _t_old |
int & | _t_step |
Real & | _dt |
Real & | _dt_old |
bool | _is_transient |
const ExecFlagEnum & | _execute_enum |
const ExecFlagType & | _current_execute_flag |
MooseApp & | _restartable_app |
const std::string | _restartable_system_name |
const THREAD_ID | _restartable_tid |
const bool | _restartable_read_only |
const Parallel::Communicator & | _communicator |
A time-dependent, neural-network-based controller that is associated with Proximal Policy Optimization (PPO).
This neural net is used to train a controller. Compared to its base class, LibtorchNeuralNetControl, this controller adds variability to the control signals (drawn from an assumed Gaussian distribution) to avoid overfitting. This controller is intended to be used in conjunction with LibtorchDRLControlTrainer.
Definition at line 24 of file LibtorchDRLControl.h.
LibtorchDRLControl::LibtorchDRLControl | ( | const InputParameters & | parameters | ) |
Construct using input parameters.
Definition at line 33 of file LibtorchDRLControl.C.
|
protected |
Function which computes the logarithmic probability of given actions.
action | The tensor containing the perturbed control signals (also known as the action of the controller) |
output_tensor | The expected value of the signals predicted by the neural net |
Definition at line 105 of file LibtorchDRLControl.C.
Referenced by execute().
|
overridevirtual |
We compute the actions in this function together with the corresponding logarithmic probabilities.
Reimplemented from LibtorchNeuralNetControl.
Definition at line 55 of file LibtorchDRLControl.C.
Get the logarithmic probability of the (signal_index)-th signal of the control neural net.
signal_index | The index of the signal |
Definition at line 116 of file LibtorchDRLControl.C.
Referenced by LibtorchDRLLogProbabilityPostprocessor::getValue().
|
static |
Definition at line 20 of file LibtorchDRLControl.C.
|
protected |
Standard deviation for the actions, supplied by the user.
Definition at line 57 of file LibtorchDRLControl.h.
Referenced by LibtorchDRLControl().
|
protected |
The log probability of control signals from the last evaluation of the controller.
Definition at line 54 of file LibtorchDRLControl.h.
Referenced by execute(), and getSignalLogProbability().
|
protected |
Standard deviations converted to a 2D diagonal tensor that can be used by Libtorch routines.
Definition at line 60 of file LibtorchDRLControl.h.
Referenced by computeLogProbability(), execute(), and LibtorchDRLControl().