Line data Source code
//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#ifdef LIBTORCH_ENABLED

#include "TorchScriptUserObject.h"

registerMooseObject("MooseApp", TorchScriptUserObject);

// Declares the input parameters of this user object:
//  - "filename" (required, controllable): path of the torch script module to load.
//  - "load_during_construction" (optional, default false).
//  - "execute_on" defaulted to EXEC_NONE so that the object does not reload on its own.
InputParameters
TorchScriptUserObject::validParams()
{
  InputParameters params = GeneralUserObject::validParams();
  params.addClassDescription("User-facing object which loads a torch script module.");
  params.addRequiredParam<FileName>("filename",
                                    "The file name which contains the torch script module.");
  params.declareControllable("filename");
  params.addParam<bool>(
      "load_during_construction",
      false,
      "If we want to load this neural network while we are constructing this object.");

  // By default we don't execute this user object. Depending on the desired reload frequency,
  // the user can override this in the input file.
  params.set<ExecFlagEnum>("execute_on", true) = {EXEC_NONE};

  return params;
}

// Reads the "filename" parameter and eagerly loads the torch script module from it.
// Loading here guarantees that _torchscript_module is non-null for evaluate(), even when
// execute() never runs (the default execute_on is EXEC_NONE).
// NOTE(review): "load_during_construction" is declared in validParams() but is not read
// anywhere in this file; the module is loaded here unconditionally. Confirm whether the
// flag is consumed elsewhere or is meant to gate this load. Gating it as-is would leave the
// module null by default, because the flag defaults to false.
TorchScriptUserObject::TorchScriptUserObject(const InputParameters & parameters)
  : GeneralUserObject(parameters),
    _filename(getParam<FileName>("filename")),
    _torchscript_module(std::make_unique<Moose::TorchScriptModule>(_filename))
{
}

// Reloads the module from _filename, replacing the previously held module. The new module is
// fully constructed before the assignment, so the old one is kept if construction throws.
// NOTE(review): whether a controlled change to "filename" is picked up here depends on how
// _filename is declared in the header (reference vs. copy). Confirm.
void
TorchScriptUserObject::execute()
{
  // We load when the user executes this user object
  _torchscript_module = std::make_unique<Moose::TorchScriptModule>(_filename);
}

// Runs the loaded module's forward pass on `input` and returns the resulting tensor.
torch::Tensor
TorchScriptUserObject::evaluate(const torch::Tensor & input) const
{
  return _torchscript_module->forward(input);
}

#endif