https://mooseframework.inl.gov
TorchScriptModule.h
Go to the documentation of this file.
//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#ifdef LIBTORCH_ENABLED

#pragma once

#include <torch/torch.h>
#include <torch/script.h>
#include "LibtorchNeuralNetBase.h"
#include "DataIO.h"
#include "MultiMooseEnum.h"

namespace Moose
{

/**
 * A neural network described by a TorchScript module that can be evaluated from C++.
 *
 * The class is both a torch::jit::script::Module (the scripted network itself) and a
 * LibtorchNeuralNetBase (the base meant to gather the functions and members common to
 * every libtorch-based neural network).
 */
class TorchScriptModule : public torch::jit::script::Module, public LibtorchNeuralNetBase
{
public:
  /**
   * Construct using a filename which contains the source code in TorchScript format.
   * @param filename Path to the file holding the TorchScript module
   */
  TorchScriptModule(const std::string & filename);

  /**
   * Construct (load) the neural network.
   * @param filename Path to the file holding the network; presumably the same TorchScript
   *                 format the constructor accepts -- TODO confirm against the implementation
   */
  void loadNeuralNetwork(const std::string & filename);

  /**
   * Overriding the forward substitution function for the neural network.
   *
   * NOTE(review): this is intentionally not const -- the original documentation states it
   * cannot be made const (the full rationale is truncated in this view; restore it).
   *
   * `override` alone is enough here: it already implies `virtual`, and it makes the
   * compiler reject the declaration if no base class has a matching virtual `forward`.
   *
   * @param x Input tensor for the evaluation
   * @return Output tensor produced by the network
   */
  torch::Tensor forward(const torch::Tensor & x) override;
};

} // namespace Moose

#endif
virtual torch::Tensor forward(const torch::Tensor &x) override
Overriding the forward substitution function for the neural network, unfortunately this cannot be con...
void loadNeuralNetwork(const std::string &filename)
Construct the neural network.
This base class is meant to gather the functions and members common in every neural network based on ...
TorchScriptModule()
Construct using a filename which contains the source code in torchscript format.
MOOSE now contains C++17 code, so give a reasonable error message stating what the user can do to add...