https://mooseframework.inl.gov
LibtorchModel.C
Go to the documentation of this file.
1 //* This file is part of the MOOSE framework
2 //* https://mooseframework.inl.gov
3 //*
4 //* All rights reserved, see COPYRIGHT for full restrictions
5 //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 //*
7 //* Licensed under LGPL 2.1, please see LICENSE for details
8 //* https://www.gnu.org/licenses/lgpl-2.1.html
9 
10 #ifdef NEML2_ENABLED
11 
12 // libtorch headers
13 #include <ATen/ops/ones_like.h>
14 
15 // neml2 headers
16 #include "neml2/misc/assertions.h"
17 
18 // moose headers
19 #include "LibtorchModel.h"
20 
21 namespace neml2
22 {
23 register_NEML2_object(LibtorchModel);
24 
25 OptionSet
27 {
28  auto options = Model::expected_options();
29  options.set<std::vector<VariableName>>("inputs");
30  options.set<std::vector<VariableName>>("outputs");
31  options.set("outputs").doc() = "The scaled neural network output";
32  options.set<std::string>("file_path");
33  // No jitting :/
34  options.set<bool>("jit") = false;
35  options.set("jit").suppressed() = true;
36  return options;
37 }
38 
39 LibtorchModel::LibtorchModel(const OptionSet & options)
40  : Model(options),
41  _file_path(Moose::DataFileUtils::getPath(options.get<std::string>("file_path"))),
42  _surrogate(std::make_unique<torch::jit::script::Module>(torch::jit::load(_file_path.path)))
43 {
44  // inputs
45  for (const auto & fv : options.get<std::vector<VariableName>>("inputs"))
46  _inputs.push_back(&declare_input_variable<Scalar>(fv));
47  for (const auto & fv : options.get<std::vector<VariableName>>("outputs"))
48  _outputs.push_back(&declare_output_variable<Scalar>(fv));
49 }
50 
51 void
52 LibtorchModel::to(const torch::TensorOptions & options)
53 {
54  Model::to(options);
55 
56  if (options.has_device())
57  _surrogate->to(options.device());
58 
59  if (options.has_dtype())
60  _surrogate->to(torch::Dtype(caffe2::typeMetaToScalarType(options.dtype())));
61 }
62 
63 void
65 {
66  std::vector<const VariableBase *> inputs;
67  for (size_t i = 0; i < _inputs.size(); ++i)
68  inputs.push_back(_inputs[i]);
69 
70  for (size_t i = 0; i < _outputs.size(); ++i)
71  _outputs[i]->request_AD(inputs);
72 }
73 
74 void
75 LibtorchModel::set_value(bool out, bool /*dout_din*/, bool /*d2out_din2*/)
76 {
77  if (out)
78  {
79  std::vector<at::Tensor> values;
80  auto first_batch_dim = _inputs[0]->batch_dim();
81  for (size_t i = 0; i < _inputs.size(); ++i)
82  {
83  // assert that all inputs have the same batch dimension
84  neml_assert(_inputs[i]->batch_dim() == first_batch_dim);
85  values.push_back(_inputs[i]->value());
86  }
87 
88  auto x = Tensor(torch::transpose(torch::vstack(at::ArrayRef<at::Tensor>(
89  values.data(), static_cast<int64_t>(values.size()))),
90  0,
91  1),
92  _inputs[0]->batch_dim());
93 
94  // Feed forward the neural network and process the output
95  auto temp = _surrogate->forward({x}).toTensor().squeeze();
96  auto y0 =
97  (temp.dim() == 1) ? temp.view({temp.size(0), 1}).transpose(0, 1) : temp.transpose(0, 1);
98 
99  for (size_t i = 0; i < _outputs.size(); ++i)
100  *_outputs[i] = Scalar(y0[i], _inputs[0]->batch_dim());
101  }
102 }
103 
104 }
105 
106 #endif
std::unique_ptr< torch::jit::script::Module > _surrogate
We need to use a pointer here because forward is not const qualified.
Definition: LibtorchModel.h:55
virtual void to(const torch::TensorOptions &options) override
Override the base implementation to additionally send the model loaded from torch script to different devices and/or dtypes.
Definition: LibtorchModel.C:52
T * get(const std::unique_ptr< T > &u)
The MooseUtils::get() specializations are used to support making forwards-compatible code changes fro...
Definition: MooseUtils.h:1155
register_NEML2_object(LibtorchModel)
static OptionSet expected_options()
Definition: LibtorchModel.C:26
Real value(unsigned n, unsigned alpha, unsigned beta, Real x)
virtual void request_AD() override
Definition: LibtorchModel.C:64
LibtorchModel(const OptionSet &options)
Definition: LibtorchModel.C:39
virtual void set_value(bool out, bool dout_din, bool d2out_din2) override
Definition: LibtorchModel.C:75
std::vector< Variable< Scalar > * > _outputs
Definition: LibtorchModel.h:52
NumberTensorValue Tensor
OStreamProxy out
Path getPath(std::string path, const std::optional< std::string > &base=std::optional< std::string >())
Get the data path for a given path, searching the registered data.
Definition: DataFileUtils.C:22
MOOSE now contains C++17 code, so give a reasonable error message stating what the user can do to add...
std::vector< const Variable< Scalar > * > _inputs
Definition: LibtorchModel.h:50