10 #ifdef LIBTORCH_ENABLED 24 "Sets the value of multiple 'Real' input parameters and postprocessors based on a Deep " 25 "Reinforcement Learning (DRL) neural network trained using a PPO algorithm.");
27 "action_standard_deviations",
"Standard deviation value used while sampling the actions.");
28 params.
addParam<
unsigned int>(
"seed",
"Seed for the random number generator.");
35 _current_control_signal_log_probabilities(
std::vector<
Real>(_control_names.size(), 0.0)),
36 _action_std(getParam<
std::vector<
Real>>(
"action_standard_deviations"))
40 "Number of action_standard_deviations does not match the number of controlled " 45 torch::manual_seed(getParam<unsigned int>(
"seed"));
74 torch::Tensor output_tensor =
_nn->forward(input_tensor);
77 torch::Tensor action = at::normal(output_tensor,
_std);
86 log_probability.data_ptr<
Real>() +
87 log_probability.size(1)};
89 for (
unsigned int control_i = 0; control_i < n_controls; ++control_i)
106 const torch::Tensor & output_tensor)
109 torch::Tensor var = torch::matmul(
_std,
_std);
111 return -((action - output_tensor) * (action - output_tensor)) / (2.0 * var) - torch::log(
_std) -
112 std::log(std::sqrt(2.0 * M_PI));
119 "The index of the requested control signal is not in the [0," +
const std::vector< Real > _action_scaling_factors
std::vector< Real > _current_response
static InputParameters validParams()
std::shared_ptr< Moose::LibtorchNeuralNetBase > _nn
registerMooseObject("StochasticToolsApp", LibtorchDRLControl)
torch::Tensor prepareInputTensor()
const std::vector< std::string > & _control_names
bool isParamValid(const std::string &name) const
virtual void execute() override
We compute the actions in this function together with the corresponding logarithmic probabilities...
const std::vector< Real > _action_std
Standard deviation for the actions, supplied by the user.
Real getSignalLogProbability(const unsigned int signal_index) const
Get the logarithmic probability of (signal_index)-th signal of the control neural net...
static InputParameters validParams()
void paramError(const std::string &param, Args... args) const
std::vector< std::vector< Real > > & _old_responses
LibtorchDRLControl(const InputParameters &parameters)
Construct using input parameters.
DIE A HORRIBLE DEATH HERE typedef LIBMESH_DEFAULT_SCALAR_TYPE Real
const unsigned int _input_timesteps
torch::Tensor _std
Standard deviations converted to a 2D diagonal tensor that can be used by Libtorch routines...
A time-dependent, neural-network-based controller which is associated with a Proximal Policy Optimiza...
void updateCurrentResponse()
std::vector< Real > _current_control_signal_log_probabilities
The log probability of control signals from the last evaluation of the controller.
std::vector< Real > _current_control_signals
torch::Tensor computeLogProbability(const torch::Tensor &action, const torch::Tensor &output_tensor)
Function which computes the logarithmic probability of given actions.