13 #include <ATen/ops/ones_like.h> 16 #include "neml2/misc/assertions.h" 28 auto options = Model::expected_options();
29 options.set<std::vector<VariableName>>(
"inputs");
30 options.set<std::vector<VariableName>>(
"outputs");
31 options.set(
"outputs").doc() =
"The scaled neural network output";
32 options.set<std::string>(
"file_path");
34 options.set<
bool>(
"jit") =
false;
35 options.set(
"jit").suppressed() =
true;
42 _surrogate(
std::make_unique<torch::jit::script::Module>(torch::jit::load(_file_path.path)))
45 for (
const auto & fv : options.get<std::vector<VariableName>>(
"inputs"))
46 _inputs.push_back(&declare_input_variable<Scalar>(fv));
47 for (
const auto & fv : options.get<std::vector<VariableName>>(
"outputs"))
48 _outputs.push_back(&declare_output_variable<Scalar>(fv));
56 if (options.has_device())
59 if (options.has_dtype())
60 _surrogate->to(torch::Dtype(caffe2::typeMetaToScalarType(options.dtype())));
66 std::vector<const VariableBase *> inputs;
67 for (
size_t i = 0; i <
_inputs.size(); ++i)
70 for (
size_t i = 0; i <
_outputs.size(); ++i)
79 std::vector<at::Tensor> values;
80 auto first_batch_dim =
_inputs[0]->batch_dim();
81 for (
size_t i = 0; i <
_inputs.size(); ++i)
84 neml_assert(
_inputs[i]->batch_dim() == first_batch_dim);
88 auto x =
Tensor(torch::transpose(torch::vstack(at::ArrayRef<at::Tensor>(
89 values.data(),
static_cast<int64_t
>(values.size()))),
95 auto temp =
_surrogate->forward({x}).toTensor().squeeze();
97 (temp.dim() == 1) ? temp.view({temp.size(0), 1}).transpose(0, 1) : temp.transpose(0, 1);
99 for (
size_t i = 0; i <
_outputs.size(); ++i)
std::unique_ptr< torch::jit::script::Module > _surrogate
We need to use a pointer here because forward is not const qualified.
virtual void to(const torch::TensorOptions &options) override
Override the base implementation to additionally send the model loaded from torch script to different...
T * get(const std::unique_ptr< T > &u)
The MooseUtils::get() specializations are used to support making forwards-compatible code changes fro...
register_NEML2_object(LibtorchModel)
static OptionSet expected_options()
Real value(unsigned n, unsigned alpha, unsigned beta, Real x)
virtual void request_AD() override
LibtorchModel(const OptionSet &options)
virtual void set_value(bool out, bool dout_din, bool d2out_din2) override
std::vector< Variable< Scalar > * > _outputs
Path getPath(std::string path, const std::optional< std::string > &base=std::optional< std::string >())
Get the data path for a given path, searching the registered data.
MOOSE now contains C++17 code, so give a reasonable error message stating what the user can do to add...
std::vector< const Variable< Scalar > * > _inputs