https://mooseframework.inl.gov
TorchScriptUserObject.C
Go to the documentation of this file.
1 //* This file is part of the MOOSE framework
2 //* https://mooseframework.inl.gov
3 //*
4 //* All rights reserved, see COPYRIGHT for full restrictions
5 //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 //*
7 //* Licensed under LGPL 2.1, please see LICENSE for details
8 //* https://www.gnu.org/licenses/lgpl-2.1.html
9 
10 #ifdef LIBTORCH_ENABLED
11 
12 #include "TorchScriptUserObject.h"
13 
15 
18 {
20  params.addClassDescription("User-facing object which loads a torch script module.");
21  params.addRequiredParam<FileName>("filename",
22  "The file name which contains the torch script module.");
23  params.declareControllable("filename");
24  params.addParam<bool>(
25  "load_during_construction",
26  false,
27  "If we want to load this neural network while we are constructing this object.");
28 
29  // By default we don't execute this user object, depending on the desired reload frequency,
30  // the user can override this in the input file.
31  params.set<ExecFlagEnum>("execute_on", true) = {EXEC_NONE};
32 
33  return params;
34 }
35 
37  : GeneralUserObject(parameters),
38  _filename(getParam<FileName>("filename")),
39  _torchscript_module(std::make_unique<Moose::TorchScriptModule>(_filename))
40 {
41 }
42 
43 void
45 {
46  // We load when the user executes this user object
47  _torchscript_module = std::make_unique<Moose::TorchScriptModule>(_filename);
48 }
49 
50 torch::Tensor
51 TorchScriptUserObject::evaluate(const torch::Tensor & input) const
52 {
53  return _torchscript_module->forward(input);
54 }
55 
56 #endif
std::unique_ptr< Moose::TorchScriptModule > _torchscript_module
The libtorch neural network that is currently stored here.
A MultiMooseEnum object to hold "execute_on" flags.
Definition: ExecFlagEnum.h:21
static InputParameters validParams()
static InputParameters validParams()
const ExecFlagType EXEC_NONE
Definition: Moose.C:27
T & set(const std::string &name, bool quiet_mode=false)
Returns a writable reference to the named parameters.
The main MOOSE class responsible for handling user-defined parameters in almost every MOOSE system...
A user object that loads a torch module using the torch script format and just-in-time compilation...
void addRequiredParam(const std::string &name, const std::string &doc_string)
This method adds a parameter and documentation string to the InputParameters object that will be extr...
virtual void execute() override
Execute method.
TorchScriptUserObject(const InputParameters &parameters)
registerMooseObject("MooseApp", TorchScriptUserObject)
const FileName & _filename
The file name that specifies the torch script model.
void addClassDescription(const std::string &doc_string)
This method adds a description of the class that will be displayed in the input file syntax dump...
void addParam(const std::string &name, const S &value, const std::string &doc_string)
These methods add an optional parameter and a documentation string to the InputParameters object...
MOOSE now contains C++17 code, so give a reasonable error message stating what the user can do to add...
void declareControllable(const std::string &name, std::set< ExecFlagType > execute_flags={})
Declare the given parameters as controllable.
torch::Tensor evaluate(const torch::Tensor &input) const
Function to evaluate the torch script module at certain input.