Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #ifdef LIBTORCH_ENABLED 11 : 12 : #include <torch/torch.h> 13 : #include "TorchScriptModule.h" 14 : #include "MooseError.h" 15 : 16 : namespace Moose 17 : { 18 : 19 1 : TorchScriptModule::TorchScriptModule() {} 20 : 21 6 : TorchScriptModule::TorchScriptModule(const std::string & filename) { loadNeuralNetwork(filename); } 22 : 23 : void 24 7 : TorchScriptModule::loadNeuralNetwork(const std::string & filename) 25 : { 26 : try 27 : { 28 7 : torch::jit::script::Module * base = this; 29 7 : *base = torch::jit::load(filename); 30 : } 31 0 : catch (const c10::Error & e) 32 : { 33 0 : mooseError("Error while loading torchscript file ", filename, "!\n", e.msg()); 34 0 : } 35 7 : } 36 : 37 : torch::Tensor 38 249 : TorchScriptModule::forward(const torch::Tensor & x) 39 : { 40 249 : std::vector<torch::jit::IValue> inputs(1, x); 41 747 : return torch::jit::script::Module::forward(inputs).toTensor(); 42 249 : } 43 : 44 : } 45 : 46 : #endif