https://mooseframework.inl.gov
LibtorchDRLControl.h
Go to the documentation of this file.
1 //* This file is part of the MOOSE framework
2 //* https://mooseframework.inl.gov
3 //*
4 //* All rights reserved, see COPYRIGHT for full restrictions
5 //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 //*
7 //* Licensed under LGPL 2.1, please see LICENSE for details
8 //* https://www.gnu.org/licenses/lgpl-2.1.html
9 
10 #ifdef LIBTORCH_ENABLED
11 
12 #pragma once
13 
16 
25 {
26 public:
28 
31 
33  virtual void execute() override;
34 
40  Real getSignalLogProbability(const unsigned int signal_index) const;
41 
42 protected:
50  torch::Tensor computeLogProbability(const torch::Tensor & action,
51  const torch::Tensor & output_tensor);
52 
55 
57  const std::vector<Real> _action_std;
58 
60  torch::Tensor _std;
61 };
62 
63 #endif
virtual void execute() override
We compute the actions in this function together with the corresponding logarithmic probabilities...
const std::vector< Real > _action_std
Standard deviation for the actions, supplied by the user.
Real getSignalLogProbability(const unsigned int signal_index) const
Get the logarithmic probability of (signal_index)-th signal of the control neural net...
static InputParameters validParams()
LibtorchDRLControl(const InputParameters &parameters)
Construct using input parameters.
typedef LIBMESH_DEFAULT_SCALAR_TYPE Real
const InputParameters & parameters() const
torch::Tensor _std
Standard deviations converted to a 2D diagonal tensor that can be used by Libtorch routines...
A time-dependent, neural-network-based controller which is associated with a Proximal Policy Optimiza...
std::vector< Real > _current_control_signal_log_probabilities
The log probability of control signals from the last evaluation of the controller.
torch::Tensor computeLogProbability(const torch::Tensor &action, const torch::Tensor &output_tensor)
Function which computes the logarithmic probability of given actions.