https://mooseframework.inl.gov
LibtorchModel.h
Go to the documentation of this file.
1 //* This file is part of the MOOSE framework
2 //* https://mooseframework.inl.gov
3 //*
4 //* All rights reserved, see COPYRIGHT for full restrictions
5 //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 //*
7 //* Licensed under LGPL 2.1, please see LICENSE for details
8 //* https://www.gnu.org/licenses/lgpl-2.1.html
9 
10 #pragma once
11 
12 #ifdef NEML2_ENABLED
13 
14 // libtorch headers
15 #include <torch/script.h>
16 
17 // neml2 headers
18 #include "neml2/models/Model.h"
19 
20 // moose headers
21 #include "DataFileUtils.h"
22 
23 namespace neml2
24 {
25 
31 class LibtorchModel : public Model
32 {
33 public:
34  static OptionSet expected_options();
35 
36  LibtorchModel(const OptionSet & options);
37 
42  virtual void to(const torch::TensorOptions & options) override;
43 
44  virtual void request_AD() override;
45 
46 protected:
47  virtual void set_value(bool out, bool dout_din, bool d2out_din2) override;
48 
49  // Input variable vector
50  std::vector<const Variable<Scalar> *> _inputs;
51  // Output vector
52  std::vector<Variable<Scalar> *> _outputs;
55  std::unique_ptr<torch::jit::script::Module> _surrogate;
56 };
57 
58 }
59 
60 #endif
std::unique_ptr< torch::jit::script::Module > _surrogate
We need to use a pointer here because forward is not const qualified.
Definition: LibtorchModel.h:55
virtual void to(const torch::TensorOptions &options) override
Override the base implementation to additionally send the model loaded from torch script to a different device and/or dtype.
Definition: LibtorchModel.C:52
static OptionSet expected_options()
Definition: LibtorchModel.C:26
virtual void request_AD() override
Definition: LibtorchModel.C:64
Evaluate a pretrained libtorch model in .pt format, such as a neural network.
Definition: LibtorchModel.h:31
Representation of a data file path.
Definition: DataFileUtils.h:36
LibtorchModel(const OptionSet &options)
Definition: LibtorchModel.C:39
virtual void set_value(bool out, bool dout_din, bool d2out_din2) override
Definition: LibtorchModel.C:75
std::vector< Variable< Scalar > * > _outputs
Definition: LibtorchModel.h:52
OStreamProxy out
std::vector< const Variable< Scalar > * > _inputs
Definition: LibtorchModel.h:50
Moose::DataFileUtils::Path _file_path
Definition: LibtorchModel.h:53