Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #ifdef LIBTORCH_ENABLED 11 : 12 : #pragma once 13 : 14 : // MOOSE includes 15 : #include "GeneralUserObject.h" 16 : #include "TorchScriptModule.h" 17 : 18 : /** 19 : * A user object the loads a torch module using the 20 : * torch script format and just-in-time compilation. 21 : */ 22 : class TorchScriptUserObject : public GeneralUserObject 23 : { 24 : public: 25 : static InputParameters validParams(); 26 : 27 : TorchScriptUserObject(const InputParameters & parameters); 28 : 29 1 : virtual void initialize() override {} 30 : virtual void execute() override; 31 1 : virtual void finalize() override {} 32 : 33 : ///@{ 34 : /// Get const access to the module pointer. 35 : const std::unique_ptr<Moose::TorchScriptModule> & modulePtr() const 36 : { 37 : return _torchscript_module; 38 : } 39 : /// Get const access to the module. 40 : const Moose::TorchScriptModule & module() const { return *_torchscript_module; } 41 : /// Get non-const access to the module pointer. Could be used for further training within MOOSE. 42 : std::unique_ptr<Moose::TorchScriptModule> & modulePtr() { return _torchscript_module; } 43 : /// Get non-const access to the module. Could be used for further training within MOOSE. 44 : Moose::TorchScriptModule & module() { return *_torchscript_module; } 45 : ///@} 46 : 47 : /** 48 : * Function to evaluate the torch script module at certain input. 49 : * @param input The input tensor. 50 : */ 51 : torch::Tensor evaluate(const torch::Tensor & input) const; 52 : 53 : protected: 54 : /// The file name that specifies the torch script model. 55 : const FileName & _filename; 56 : 57 : /// The libtorch neural network that is currently stored here. 58 : std::unique_ptr<Moose::TorchScriptModule> _torchscript_module; 59 : }; 60 : 61 : #endif