Evaluate a pretrained libtorch model in .pt format, such as a neural network.
#include <LibtorchModel.h>
Public Member Functions
  LibtorchModel (const OptionSet &options)
  virtual void to (const torch::TensorOptions &options) override
    Override the base implementation to additionally move the model loaded from TorchScript to the requested device and dtype.
  virtual void request_AD () override
Static Public Member Functions
  static OptionSet expected_options ()
Protected Member Functions
  virtual void set_value (bool out, bool dout_din, bool d2out_din2) override
Protected Attributes
  std::vector<const Variable<Scalar> *> _inputs
  std::vector<Variable<Scalar> *> _outputs
  Moose::DataFileUtils::Path _file_path
  std::unique_ptr<torch::jit::script::Module> _surrogate
    We need to use a pointer here because forward is not const qualified.
Detailed Description
Evaluates models that map an arbitrary number of inputs to an arbitrary number of outputs.
Definition at line 31 of file LibtorchModel.h.
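The input-to-output mapping described above can be illustrated with a minimal standard-C++ sketch. This is not the NEML2 implementation. It assumes plain double values in place of Variable<Scalar> and a std::function in place of the TorchScript module. The names Surrogate and evaluate are hypothetical. The two pointer vectors play the roles of _inputs and _outputs.

```cpp
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the TorchScript module: takes a flat vector of
// input values and returns a flat vector of output values.
using Surrogate = std::function<std::vector<double>(const std::vector<double> &)>;

// Gather the current values of an arbitrary number of inputs, evaluate the
// surrogate once, and scatter the results to an arbitrary number of outputs.
void evaluate(const std::vector<const double *> & inputs,
              const std::vector<double *> & outputs,
              const Surrogate & surrogate)
{
  std::vector<double> x;
  x.reserve(inputs.size());
  for (const double * in : inputs)
    x.push_back(*in);

  const std::vector<double> y = surrogate(x);

  // The number of outputs is arbitrary, but it must match what the model returns.
  if (y.size() != outputs.size())
    throw std::runtime_error("surrogate returned the wrong number of outputs");

  for (std::size_t i = 0; i < outputs.size(); ++i)
    *outputs[i] = y[i];
}
```

Read-only pointers for inputs and writable pointers for outputs mirror the const/non-const split between the _inputs and _outputs attributes.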
neml2::LibtorchModel::LibtorchModel (const OptionSet &options)
Definition at line 39 of file LibtorchModel.C.
static
Definition at line 26 of file LibtorchModel.C.
override virtual
Definition at line 64 of file LibtorchModel.C.
override protected virtual
Definition at line 75 of file LibtorchModel.C.
override virtual
Override the base implementation to additionally move the model loaded from TorchScript to the requested device and dtype.
Definition at line 52 of file LibtorchModel.C.
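The override pattern described in this entry can be shown with a minimal sketch built from hypothetical stand-ins. Options replaces torch::TensorOptions and tracks only a device and a dtype string. ScriptModule replaces torch::jit::script::Module. BaseModel and SurrogateModel are illustrative names, not NEML2 classes. The override first delegates to the base class and then moves the owned module, which the base class does not know about.

```cpp
#include <memory>
#include <string>

// Hypothetical stand-in for torch::TensorOptions: only device and dtype.
struct Options
{
  std::string device = "cpu";
  std::string dtype = "float64";
};

// Hypothetical stand-in for the loaded TorchScript module.
struct ScriptModule
{
  Options current;
  void to(const Options & options) { current = options; }
};

// Hypothetical base model: moves only the state it owns.
class BaseModel
{
public:
  virtual ~BaseModel() = default;
  virtual void to(const Options & options) { _options = options; }
  const Options & options() const { return _options; }

private:
  Options _options;
};

class SurrogateModel : public BaseModel
{
public:
  SurrogateModel() : _surrogate(std::make_unique<ScriptModule>()) {}

  // Let the base class move its own state first, then additionally move
  // the module that only this derived class owns.
  void to(const Options & options) override
  {
    BaseModel::to(options);
    _surrogate->to(options);
  }

  const ScriptModule & surrogate() const { return *_surrogate; }

private:
  std::unique_ptr<ScriptModule> _surrogate;
};
```

Calling to() through a BaseModel reference still reaches the override. As a result, the owned module ends up on the same device and dtype as the rest of the model.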
protected
Definition at line 53 of file LibtorchModel.h.
protected
Definition at line 50 of file LibtorchModel.h.
Referenced by LibtorchModel(), request_AD(), and set_value().
protected
Definition at line 52 of file LibtorchModel.h.
Referenced by LibtorchModel(), request_AD(), and set_value().
protected
We need to use a pointer here because forward is not const qualified.
Definition at line 55 of file LibtorchModel.h.
Referenced by set_value(), and to().
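The note above relies on a C++ rule: inside a const member function, a pointer member is itself const, but the object it points to is not. The following is a minimal sketch, assuming forward() is called from a const member function. Module and Holder are hypothetical stand-ins, not libtorch or NEML2 types.

```cpp
#include <memory>

// Hypothetical stand-in for a module whose forward() is not const-qualified.
struct Module
{
  int calls = 0;
  double forward(double x)
  {
    ++calls;
    return 2.0 * x;
  }
};

class Holder
{
public:
  Holder() : _surrogate(std::make_unique<Module>()) {}

  // Allowed: in a const method the unique_ptr is const, but the Module it
  // points to is not, so its non-const forward() can still be called.
  double evaluate(double x) const { return _surrogate->forward(x); }

  int calls() const { return _surrogate->calls; }

private:
  std::unique_ptr<Module> _surrogate;
  // Holding "Module _module;" by value instead would make
  // "_module.forward(x)" ill-formed inside evaluate() const.
};
```

Declaring the module as a mutable member would have the same effect. A pointer also lets the module be created or replaced after construction.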