https://mooseframework.inl.gov
TorchScriptModule.C
Go to the documentation of this file.
1 //* This file is part of the MOOSE framework
2 //* https://mooseframework.inl.gov
3 //*
4 //* All rights reserved, see COPYRIGHT for full restrictions
5 //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 //*
7 //* Licensed under LGPL 2.1, please see LICENSE for details
8 //* https://www.gnu.org/licenses/lgpl-2.1.html
9 
10 #ifdef LIBTORCH_ENABLED
11 
12 #include <torch/torch.h>
13 #include "TorchScriptModule.h"
14 #include "MooseError.h"
15 
16 namespace Moose
17 {
18 
20 
21 TorchScriptModule::TorchScriptModule(const std::string & filename) { loadNeuralNetwork(filename); }
22 
23 void
24 TorchScriptModule::loadNeuralNetwork(const std::string & filename)
25 {
26  try
27  {
28  torch::jit::script::Module * base = this;
29  *base = torch::jit::load(filename);
30  }
31  catch (const c10::Error & e)
32  {
33  mooseError("Error while loading torchscript file ", filename, "!\n", e.msg());
34  }
35 }
36 
37 torch::Tensor
38 TorchScriptModule::forward(const torch::Tensor & x)
39 {
40  std::vector<torch::jit::IValue> inputs(1, x);
41  return torch::jit::script::Module::forward(inputs).toTensor();
42 }
43 
44 }
45 
46 #endif
virtual torch::Tensor forward(const torch::Tensor &x) override
Overriding the forward substitution function for the neural network, unfortunately this cannot be con...
void mooseError(Args &&... args)
Emit an error message with the given stringified, concatenated args and terminate the application...
Definition: MooseError.h:302
void loadNeuralNetwork(const std::string &filename)
Construct the neural network.
TorchScriptModule()
Construct using a filename which contains the source code in torchscript format.
MOOSE now contains C++17 code, so give a reasonable error message stating what the user can do to add...