https://mooseframework.inl.gov
TorchScriptUserObject.h
Go to the documentation of this file.
1 //* This file is part of the MOOSE framework
2 //* https://mooseframework.inl.gov
3 //*
4 //* All rights reserved, see COPYRIGHT for full restrictions
5 //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 //*
7 //* Licensed under LGPL 2.1, please see LICENSE for details
8 //* https://www.gnu.org/licenses/lgpl-2.1.html
9 
10 #ifdef LIBTORCH_ENABLED
11 
12 #pragma once
13 
14 // MOOSE includes
15 #include "GeneralUserObject.h"
16 #include "TorchScriptModule.h"
17 
23 {
24 public:
26 
28 
/// Called before execute() is ever called so that data can be cleared. The body is empty here.
29  virtual void initialize() override {}
/// Execute method. Only the declaration appears in this header; no body is given here.
30  virtual void execute() override;
/// Finalize. The body is empty, so this override performs no work.
31  virtual void finalize() override {}
32 
/// Get const access to the module pointer.
/// Returns a const reference to the owning std::unique_ptr member, so the caller cannot
/// reseat or release the pointer. Note that the const applies to the pointer only: the
/// pointed-to Moose::TorchScriptModule is still reachable as a non-const object through it.
35  const std::unique_ptr<Moose::TorchScriptModule> & modulePtr() const
36  {
37  return _torchscript_module;
38  }
/// Get non-const access to the module pointer. Could be used for further training within MOOSE.
/// Returns a mutable reference to the owning std::unique_ptr member itself.
42  std::unique_ptr<Moose::TorchScriptModule> & modulePtr() { return _torchscript_module; }
46 
/// Function to evaluate the torch script module at certain input.
/// @param input The tensor passed to the module
/// @return The resulting tensor, returned by value
/// Only the declaration appears in this header; no body is given here.
51  torch::Tensor evaluate(const torch::Tensor & input) const;
52 
53 protected:
/// The file name that specifies the torch script model.
/// Stored as a const reference, so the referenced FileName must outlive this object
/// (NOTE(review): presumably bound to an input parameter in the constructor -- confirm).
55  const FileName & _filename;
56 
/// The libtorch neural network that is currently stored here (owned through std::unique_ptr).
58  std::unique_ptr<Moose::TorchScriptModule> _torchscript_module;
59 };
60 
61 #endif
std::unique_ptr< Moose::TorchScriptModule > _torchscript_module
The libtorch neural network that is currently stored here.
const Moose::TorchScriptModule & module() const
Get const access to the module.
static InputParameters validParams()
The main MOOSE class responsible for handling user-defined parameters in almost every MOOSE system...
A user object that loads a torch module using the torch script format and just-in-time compilation...
virtual void execute() override
Execute method.
TorchScriptUserObject(const InputParameters &parameters)
Moose::TorchScriptModule & module()
Get non-const access to the module. Could be used for further training within MOOSE.
std::unique_ptr< Moose::TorchScriptModule > & modulePtr()
Get non-const access to the module pointer. Could be used for further training within MOOSE...
const FileName & _filename
The file name that specifies the torch script model.
virtual void finalize() override
Finalize.
const std::unique_ptr< Moose::TorchScriptModule > & modulePtr() const
const InputParameters & parameters() const
Get the parameters of the object.
virtual void initialize() override
Called before execute() is ever called so that data can be cleared.
torch::Tensor evaluate(const torch::Tensor &input) const
Function to evaluate the torch script module at certain input.