https://mooseframework.inl.gov
Moose::TorchScriptModule Class Reference

#include <TorchScriptModule.h>

Public Member Functions

 TorchScriptModule ()
 Default constructor. Creates an empty module and does not load a network.
 
 TorchScriptModule (const std::string &filename)
 Construct from a file that contains the neural network in TorchScript format.
 
void loadNeuralNetwork (const std::string &filename)
 Load the neural network from a TorchScript file.
 
virtual torch::Tensor forward (const torch::Tensor &x) override
 Override of the forward evaluation function for the neural network. It cannot be const because it creates a graph in the background.
 

Detailed Description

Definition at line 24 of file TorchScriptModule.h.

Constructor & Destructor Documentation

◆ TorchScriptModule() [1/2]

Moose::TorchScriptModule::TorchScriptModule ( )

Default constructor. Creates an empty module and does not load a network.

Definition at line 19 of file TorchScriptModule.C.

{}

◆ TorchScriptModule() [2/2]

Moose::TorchScriptModule::TorchScriptModule ( const std::string &  filename)

Construct from a file that contains the neural network in TorchScript format.

Parameters
filename  The name of the file that contains the neural net

Definition at line 21 of file TorchScriptModule.C.

{ loadNeuralNetwork(filename); }

Member Function Documentation

◆ forward()

torch::Tensor Moose::TorchScriptModule::forward ( const torch::Tensor &  x)
override virtual

Override of the forward evaluation function for the neural network. It cannot be const because it creates a graph in the background.

Parameters
x  Input tensor for the evaluation

Implements Moose::LibtorchNeuralNetBase.

Definition at line 38 of file TorchScriptModule.C.

{
  // Wrap the single input tensor in the argument list passed to the TorchScript module.
  std::vector<torch::jit::IValue> inputs(1, x);
  return torch::jit::script::Module::forward(inputs).toTensor();
}
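
The body of forward() follows a three-step pattern: box the single input into a one-element argument list, call the base-class forward by its qualified name (the override hides the inherited function by name, so an unqualified call would not find it), and unbox the result. The sketch below mirrors that pattern with plain C++ stand-ins so that it compiles without libtorch. `Value`, `ScriptModuleStandIn`, `NetBaseStandIn`, `ForwardWrapper`, `evaluate`, and `checkForwardWrapper` are hypothetical names, not MOOSE or libtorch types, and a `double` stands in for a tensor.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for torch::jit::IValue: a box around a double.
struct Value
{
  double payload;
  Value(double v) : payload(v) {} // implicit on purpose, so a raw input converts to a boxed one
  double toDouble() const { return payload; }
};

// Hypothetical stand-in for torch::jit::script::Module: forward takes a list of boxed inputs.
struct ScriptModuleStandIn
{
  Value forward(std::vector<Value> inputs)
  {
    return Value(2.0 * inputs.at(0).toDouble()); // placeholder "network": doubles its input
  }
};

// Hypothetical stand-in for Moose::LibtorchNeuralNetBase: the typed, virtual interface.
struct NetBaseStandIn
{
  virtual ~NetBaseStandIn() = default;
  virtual double forward(const double & x) = 0;
};

// Hypothetical wrapper that inherits from both stand-ins.
struct ForwardWrapper : public ScriptModuleStandIn, public NetBaseStandIn
{
  double forward(const double & x) override
  {
    std::vector<Value> inputs(1, x); // one-element argument list
    // Qualified call: this override hides ScriptModuleStandIn::forward by name.
    return ScriptModuleStandIn::forward(inputs).toDouble();
  }
};

// Evaluates any network through the typed interface.
double evaluate(NetBaseStandIn & net, double x) { return net.forward(x); }

// Small self-check of the stand-ins above.
void checkForwardWrapper()
{
  ForwardWrapper net;
  assert(evaluate(net, 3.0) == 6.0);
}
```

Calling `evaluate` on a `ForwardWrapper` goes through the virtual interface and ends in the boxed base-class call, which is the same route a caller holding a Moose::LibtorchNeuralNetBase reference would take with the real class.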

◆ loadNeuralNetwork()

void Moose::TorchScriptModule::loadNeuralNetwork ( const std::string &  filename)

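Load the neural network from a TorchScript file. The loaded module replaces the base-class part of this object, and a loading failure is reported through mooseError().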

Definition at line 24 of file TorchScriptModule.C.

Referenced by TorchScriptModule().

{
  try
  {
    // Assign the loaded module to the base-class part of this object.
    torch::jit::script::Module * base = this;
    *base = torch::jit::load(filename);
  }
  catch (const c10::Error & e)
  {
    mooseError("Error while loading torchscript file ", filename, "!\n", e.msg());
  }
}
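
The listing assigns the loaded module to the base-class part of the object through a base pointer, and wraps the load in a try/catch so that a failure produces a readable error. A minimal sketch of that idiom follows, using plain C++ stand-ins so that it compiles without libtorch or MOOSE. `ModuleStandIn`, `loadStandIn`, `LoadableWrapper`, `loadFrom`, `load_count`, and the file name `net.pt` are all hypothetical. The sketch returns `false` on failure, whereas the documented function calls mooseError(), which terminates the application.

```cpp
#include <cassert>
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for torch::jit::script::Module: only records where it came from.
struct ModuleStandIn
{
  std::string source = "<empty>";
};

// Hypothetical stand-in for torch::jit::load: returns a module by value or throws.
ModuleStandIn loadStandIn(const std::string & filename)
{
  if (filename.empty())
    throw std::runtime_error("empty filename");
  ModuleStandIn m;
  m.source = filename;
  return m;
}

// Hypothetical wrapper: inherits from the module type and adds state of its own.
class LoadableWrapper : public ModuleStandIn
{
public:
  LoadableWrapper() = default; // empty module, nothing loaded
  explicit LoadableWrapper(const std::string & filename) { loadFrom(filename); }

  // Returns true on success. This is an assumption of the sketch, not the documented behavior.
  bool loadFrom(const std::string & filename)
  {
    try
    {
      ModuleStandIn * base = this;    // view this object as its base class
      *base = loadStandIn(filename);  // overwrite only the base subobject
      ++load_count;                   // derived-class state is left to the wrapper
      return true;
    }
    catch (const std::exception & e)
    {
      std::cerr << "Error while loading file " << filename << "!\n" << e.what() << "\n";
      return false;
    }
  }

  int load_count = 0;
};

// Default-construct first, then load explicitly.
bool exampleUsage()
{
  LoadableWrapper w;
  assert(w.source == "<empty>");
  const bool ok = w.loadFrom("net.pt");
  return ok && w.source == "net.pt" && w.load_count == 1;
}
```

Because the assignment goes through the base pointer, only the inherited module state is replaced. If the load throws, the assignment never happens and the object keeps its previous contents.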

The documentation for this class was generated from the following files: