https://mooseframework.inl.gov
neml2::LibtorchModel Class Reference

Evaluate a pretrained libtorch model in .pt format, such as a neural network.

#include <LibtorchModel.h>


Public Member Functions

 LibtorchModel (const OptionSet &options)
 
virtual void to (const torch::TensorOptions &options) override
 Overrides the base implementation to additionally move the model loaded from TorchScript to the requested device and/or dtype.
 
virtual void request_AD () override
 

Static Public Member Functions

static OptionSet expected_options ()
 

Protected Member Functions

virtual void set_value (bool out, bool dout_din, bool d2out_din2) override
 

Protected Attributes

std::vector< const Variable< Scalar > * > _inputs
 
std::vector< Variable< Scalar > * > _outputs
 
Moose::DataFileUtils::Path _file_path
 
std::unique_ptr< torch::jit::script::Module > _surrogate
 A pointer is needed here because forward() is not const-qualified.
 

Detailed Description

Evaluate a pretrained libtorch model in .pt format, such as a neural network.

Evaluates a model that maps an arbitrary number of scalar inputs to an arbitrary number of scalar outputs.

Definition at line 31 of file LibtorchModel.h.

Constructor & Destructor Documentation

◆ LibtorchModel()

neml2::LibtorchModel::LibtorchModel ( const OptionSet &  options)

Definition at line 39 of file LibtorchModel.C.

LibtorchModel::LibtorchModel(const OptionSet & options)
  : Model(options),
    // _file_path is declared before _surrogate in the header, so it is initialized first
    _file_path(Moose::DataFileUtils::getPath(options.get<std::string>("file_path"))),
    _surrogate(std::make_unique<torch::jit::script::Module>(torch::jit::load(_file_path.path)))
{
  // inputs
  for (const auto & fv : options.get<std::vector<VariableName>>("inputs"))
    _inputs.push_back(&declare_input_variable<Scalar>(fv));
  // outputs
  for (const auto & fv : options.get<std::vector<VariableName>>("outputs"))
    _outputs.push_back(&declare_output_variable<Scalar>(fv));
}
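The constructor stores non-owning pointers to the declared variables in the order their names appear in the `inputs` and `outputs` options, and `set_value()` stacks the inputs in that same order, so the i-th name in `inputs` becomes the i-th column of the network input. The sketch below mimics only that bookkeeping; `ToyVariable` and `ToyModel` are hypothetical stand-ins, not NEML2 classes, and the `std::deque` storage is an assumption chosen because appending to it never invalidates references to existing elements.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <string>
#include <vector>

// Hypothetical stand-in for a NEML2 variable: just a name and a value.
struct ToyVariable
{
  std::string name;
  double value = 0.0;
};

// Hypothetical model: the "base" part owns the variables, while the "derived"
// part keeps non-owning pointers in declaration order, like
// LibtorchModel::_inputs and LibtorchModel::_outputs.
class ToyModel
{
public:
  ToyModel(const std::vector<std::string> & inputs, const std::vector<std::string> & outputs)
  {
    for (const auto & name : inputs)
    {
      assert(column_of(name) == -1); // input names are expected to be unique
      _inputs.push_back(&declare(name));
    }
    for (const auto & name : outputs)
      _outputs.push_back(&declare(name));
  }

  // Column of an input in the stacked network input, or -1 if it is not an input.
  int column_of(const std::string & name) const
  {
    for (std::size_t i = 0; i < _inputs.size(); ++i)
      if (_inputs[i]->name == name)
        return static_cast<int>(i);
    return -1;
  }

  std::size_t n_inputs() const { return _inputs.size(); }
  std::size_t n_outputs() const { return _outputs.size(); }

private:
  // Appending to a std::deque never invalidates references to existing
  // elements, so the stored pointers stay valid.
  ToyVariable & declare(const std::string & name)
  {
    _storage.push_back(ToyVariable{name});
    return _storage.back();
  }

  std::deque<ToyVariable> _storage;
  std::vector<const ToyVariable *> _inputs;
  std::vector<ToyVariable *> _outputs;
};
```

For example, `ToyModel model({"strain", "temperature"}, {"stress"})` gives `model.column_of("temperature") == 1`, so a network paired with this (hypothetical) configuration would need to expect temperature in its second input column.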

Member Function Documentation

◆ expected_options()

OptionSet neml2::LibtorchModel::expected_options ( )
static

Definition at line 26 of file LibtorchModel.C.

OptionSet LibtorchModel::expected_options()
{
  auto options = Model::expected_options();
  options.set<std::vector<VariableName>>("inputs");
  options.set<std::vector<VariableName>>("outputs");
  options.set("outputs").doc() = "The scaled neural network output";
  options.set<std::string>("file_path");
  // No jitting: force "jit" off and suppress the option
  options.set<bool>("jit") = false;
  options.set("jit").suppressed() = true;
  return options;
}
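`expected_options()` extends the base `Model` options with `inputs`, `outputs` and `file_path`, then pins `jit` to `false` and suppresses it. As a rough illustration, the hypothetical `ToyOptionSet` below (not NEML2's `OptionSet`) stores string values and rejects user input for suppressed options; reading "suppressed" as "not settable from input" is an assumption based on the option's name.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical, heavily simplified option entry: a value, a doc string and a
// flag marking the option as not user-settable.
struct ToyOption
{
  std::string value;
  std::string doc;
  bool suppressed = false;
};

class ToyOptionSet
{
public:
  // Declare (or fetch) an option by name.
  ToyOption & set(const std::string & name) { return _options[name]; }

  // Apply a value coming from user input; suppressed options are rejected.
  void set_from_input(const std::string & name, const std::string & value)
  {
    auto it = _options.find(name);
    if (it == _options.end())
      throw std::invalid_argument("unknown option: " + name);
    if (it->second.suppressed)
      throw std::invalid_argument("option cannot be set from input: " + name);
    it->second.value = value;
  }

  const std::string & get(const std::string & name) const { return _options.at(name).value; }

private:
  std::map<std::string, ToyOption> _options;
};

// Same shape as LibtorchModel::expected_options(): declare the user-facing
// options, then force "jit" off and hide it.
ToyOptionSet toy_expected_options()
{
  ToyOptionSet options;
  options.set("inputs");
  options.set("outputs").doc = "The scaled neural network output";
  options.set("file_path");
  options.set("jit").value = "false";
  options.set("jit").suppressed = true;
  return options;
}
```

Suppressing the option rather than deleting it keeps `jit` present with a known value for any code that queries it, while preventing users from re-enabling it.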

◆ request_AD()

void neml2::LibtorchModel::request_AD ( )
override virtual

Definition at line 64 of file LibtorchModel.C.

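void LibtorchModel::request_AD()
{
  // Collect all inputs as type-erased variables
  std::vector<const VariableBase *> inputs;
  for (size_t i = 0; i < _inputs.size(); ++i)
    inputs.push_back(_inputs[i]);

  // Request derivatives of every output with respect to every input via AD
  for (size_t i = 0; i < _outputs.size(); ++i)
    _outputs[i]->request_AD(inputs);
}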
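`set_value()` only fills in output values, so derivatives are left to automatic differentiation: `request_AD()` asks every output for its derivative with respect to every input, i.e. a dense Jacobian of size (number of outputs) x (number of inputs). libtorch computes this with its autograd engine; the sketch below swaps in forward-mode dual numbers purely to show what "every output with respect to every input" produces. `Dual`, `jacobian` and `toy_network` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Forward-mode dual number: val + eps * der, with eps^2 = 0.
struct Dual
{
  double val;
  double der;
};

Dual operator+(Dual a, Dual b) { return {a.val + b.val, a.der + b.der}; }
Dual operator*(Dual a, Dual b) { return {a.val * b.val, a.der * b.val + a.val * b.der}; }
Dual tanh(Dual a)
{
  double t = std::tanh(a.val);
  return {t, (1.0 - t * t) * a.der};
}

using ToyFunction = std::function<std::vector<Dual>(const std::vector<Dual> &)>;

// Dense Jacobian J[i][j] = d out_i / d in_j, using one forward pass per input
// (seed in_j with derivative 1 and all other inputs with 0).
std::vector<std::vector<double>> jacobian(const ToyFunction & f, const std::vector<double> & x)
{
  std::vector<std::vector<double>> J;
  for (std::size_t j = 0; j < x.size(); ++j)
  {
    std::vector<Dual> seeded(x.size());
    for (std::size_t k = 0; k < x.size(); ++k)
      seeded[k] = {x[k], k == j ? 1.0 : 0.0};
    auto y = f(seeded);
    if (J.empty())
      J.assign(y.size(), std::vector<double>(x.size(), 0.0));
    assert(y.size() == J.size()); // the output count must not change between passes
    for (std::size_t i = 0; i < y.size(); ++i)
      J[i][j] = y[i].der;
  }
  return J;
}

// Hypothetical two-input, two-output "network".
std::vector<Dual> toy_network(const std::vector<Dual> & in)
{
  Dual h = tanh(in[0] * Dual{0.5, 0.0} + in[1]);
  return {h, h * in[0]};
}
```

At x = (2, -1) the pre-activation is zero, so `jacobian(toy_network, {2.0, -1.0})` returns {{0.5, 1}, {1, 2}}. Reverse mode, as used by libtorch, would instead need one backward pass per output, which is cheaper when outputs are fewer than inputs.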

◆ set_value()

void neml2::LibtorchModel::set_value ( bool out, bool dout_din, bool d2out_din2 )
override protected virtual

Definition at line 75 of file LibtorchModel.C.

void LibtorchModel::set_value(bool out, bool dout_din, bool d2out_din2)
{
  // Only values are computed here; derivatives are obtained through AD (see request_AD())
  if (out)
  {
    std::vector<at::Tensor> values;
    auto first_batch_dim = _inputs[0]->batch_dim();
    for (size_t i = 0; i < _inputs.size(); ++i)
    {
      // assert that all inputs have the same batch dimension
      neml_assert(_inputs[i]->batch_dim() == first_batch_dim);
      values.push_back(_inputs[i]->value());
    }

    // Stack the inputs and transpose so that, for a single batch dimension,
    // each row is one batch entry and each column is one input variable
    auto x = Tensor(torch::transpose(torch::vstack(at::ArrayRef<at::Tensor>(
                                         values.data(), static_cast<int64_t>(values.size()))),
                                     0,
                                     1),
                    _inputs[0]->batch_dim());

    // Feed forward the neural network and process the output
    auto temp = _surrogate->forward({x}).toTensor().squeeze();
    auto y0 =
        (temp.dim() == 1) ? temp.view({temp.size(0), 1}).transpose(0, 1) : temp.transpose(0, 1);

    // Row i of y0 holds output variable i over the batch
    for (size_t i = 0; i < _outputs.size(); ++i)
      *_outputs[i] = Scalar(y0[i], _inputs[0]->batch_dim());
  }
}
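The tensor manipulation in `set_value()` amounts to: check that all inputs share a batch shape, stack them into an (n_inputs x N) array, transpose to (N x n_inputs) so that each row is one sample, run the network, and transpose its (N x n_outputs) result back so that row i becomes output variable i. The plain-C++ sketch below reproduces that data flow with nested `std::vector`s for a single batch dimension; `Matrix`, `evaluate` and the per-sample `ToyNetwork` callback are hypothetical, and it assumes the network returns one row of n_outputs values per sample (the `squeeze()` special case is not modeled).

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Matrix = std::vector<std::vector<double>>; // row-major, rows of equal length

Matrix transpose(const Matrix & m)
{
  if (m.empty())
    return {};
  Matrix t(m[0].size(), std::vector<double>(m.size()));
  for (std::size_t r = 0; r < m.size(); ++r)
    for (std::size_t c = 0; c < m[r].size(); ++c)
      t[c][r] = m[r][c];
  return t;
}

// Per-sample network: maps one row of n_in values to one row of n_out values.
using ToyNetwork = std::function<std::vector<double>(const std::vector<double> &)>;

// inputs[i][b] is input variable i at batch entry b (shape n_in x N).
// Returns outputs[k][b] (shape n_out x N), mirroring the data flow of set_value().
Matrix evaluate(const Matrix & inputs, const ToyNetwork & net)
{
  assert(!inputs.empty());
  const std::size_t batch = inputs[0].size();
  for (const auto & v : inputs)
    assert(v.size() == batch); // all inputs share the same batch size

  // "vstack + transpose": one row per batch entry, one column per input.
  Matrix x = transpose(inputs); // N x n_in

  Matrix y; // N x n_out
  for (const auto & row : x)
    y.push_back(net(row));

  // Transpose back so that row k holds output variable k over the batch.
  return transpose(y); // n_out x N
}
```

With two inputs over a batch of three and a network returning {sum, product}, `evaluate` yields the sums in row 0 and the products in row 1, matching how `set_value()` assigns `y0[i]` to `_outputs[i]`.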

◆ to()

void neml2::LibtorchModel::to ( const torch::TensorOptions &  options)
override virtual

Overrides the base implementation to additionally move the model loaded from TorchScript to the requested device and/or dtype.

Definition at line 52 of file LibtorchModel.C.

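void LibtorchModel::to(const torch::TensorOptions & options)
{
  // Let the base class handle its own state
  Model::to(options);

  // Additionally move the TorchScript module, if a device and/or dtype is requested
  if (options.has_device())
    _surrogate->to(options.device());

  if (options.has_dtype())
    _surrogate->to(torch::Dtype(caffe2::typeMetaToScalarType(options.dtype())));
}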
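The override matters because the TorchScript module is state the base `Model` does not manage: if only the base class moved its tensors, the module's parameters would stay on the old device or dtype, and a later `forward()` call would likely fail on mismatched tensors. The sketch below shows the same call-the-base-then-move-the-extras pattern, with `std::optional` fields standing in for `has_device()`/`has_dtype()`; every type in it is a hypothetical stand-in for the corresponding torch type.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for torch::Device, torch::Dtype and torch::TensorOptions.
enum class ToyDevice { CPU, CUDA };
enum class ToyDtype { Float32, Float64 };

struct ToyOptions
{
  std::optional<ToyDevice> device;
  std::optional<ToyDtype> dtype;
};

// Where a tensor-like object currently lives and in what precision.
struct Placement
{
  ToyDevice device = ToyDevice::CPU;
  ToyDtype dtype = ToyDtype::Float64;

  // Only change what was actually requested.
  void apply(const ToyOptions & opt)
  {
    if (opt.device)
      device = *opt.device;
    if (opt.dtype)
      dtype = *opt.dtype;
  }
};

class ToyBaseModel
{
public:
  virtual ~ToyBaseModel() = default;

  // Base implementation: move the state the base class owns.
  virtual void to(const ToyOptions & opt) { variables.apply(opt); }

  Placement variables;
};

class ToySurrogateModel : public ToyBaseModel
{
public:
  // Same pattern as LibtorchModel::to(): call the base first, then move the
  // extra state the base class does not know about (the loaded module).
  void to(const ToyOptions & opt) override
  {
    ToyBaseModel::to(opt);
    surrogate.apply(opt);
  }

  Placement surrogate;
};
```

Because `to()` is virtual, calling it through a base-class reference still reaches the override, which is what keeps the module and the variables on the same device.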

Member Data Documentation

◆ _file_path

Moose::DataFileUtils::Path neml2::LibtorchModel::_file_path
protected

Definition at line 53 of file LibtorchModel.h.

◆ _inputs

std::vector<const Variable<Scalar> *> neml2::LibtorchModel::_inputs
protected

Definition at line 50 of file LibtorchModel.h.

Referenced by LibtorchModel(), request_AD(), and set_value().

◆ _outputs

std::vector<Variable<Scalar> *> neml2::LibtorchModel::_outputs
protected

Definition at line 52 of file LibtorchModel.h.

Referenced by LibtorchModel(), request_AD(), and set_value().

◆ _surrogate

std::unique_ptr<torch::jit::script::Module> neml2::LibtorchModel::_surrogate
protected

A pointer is needed here because forward() is not const-qualified.

Definition at line 55 of file LibtorchModel.h.

Referenced by set_value(), and to().
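One plausible reading of this comment: C++ constness does not propagate through pointers, so code in a `const` context can still call a non-const method such as `forward()` through a `std::unique_ptr`, whereas a module stored by value would then only be accessible as `const`. The sketch below demonstrates that language rule with a hypothetical `ToyModule` whose `forward()` mutates a call counter; which const context the original comment refers to is not stated in the source.

```cpp
#include <cassert>
#include <memory>

// Hypothetical module whose forward() is not const-qualified,
// like torch::jit::script::Module::forward().
class ToyModule
{
public:
  double forward(double x)
  {
    ++_calls; // non-const: forward may touch internal state
    return 2.0 * x;
  }
  int calls() const { return _calls; }

private:
  int _calls = 0;
};

class ToySurrogateHolder
{
public:
  ToySurrogateHolder() : _surrogate(std::make_unique<ToyModule>()) {}

  // A const member function: here _surrogate is a const std::unique_ptr<ToyModule>,
  // but operator-> still yields a non-const ToyModule*, so forward() is callable.
  // Had the module been stored by value, calling forward() here would not compile.
  double evaluate(double x) const { return _surrogate->forward(x); }

  int calls() const { return _surrogate->calls(); }

private:
  std::unique_ptr<ToyModule> _surrogate;
};
```

The trade-off is that the compiler no longer guards the module's state in const methods, so any mutation inside `forward()` is the caller's responsibility.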


The documentation for this class was generated from the following files: