LCOV - code coverage report
Current view: src/libtorch/utils/TorchScriptModule.C
Test: idaholab/moose framework: 8601ad
Date: 2025-07-18 13:27:08

                 Hit   Total   Coverage
Lines:            10      13     76.9 %
Functions:         4       4    100.0 %
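The summary figures come straight from the per-line counts in the listing below. Thirteen source lines carry an execution count and ten of them are nonzero. The three misses (lines 31, 33 and 34) make up the `catch (const c10::Error &)` handler, which runs only when `torch::jit::load` throws. The following is a minimal C++ sketch that reproduces this arithmetic from counts transcribed by hand from the listing. `LineHit`, `kLineHits`, `countHit`, `missedLines` and `percent` are illustrative names, not part of LCOV or MOOSE. The sketch also assumes plain one-decimal rounding, while genhtml may treat values near 0 % and 100 % specially.

```cpp
#include <cassert>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// (source line, execution count) for every instrumented line in the listing.
// Lines that show no count in the report were not instrumented.
using LineHit = std::pair<int, long>;

const std::vector<LineHit> kLineHits = {
    {19, 1}, {21, 6}, {24, 7}, {28, 7}, {29, 7}, {31, 0}, {33, 0},
    {34, 0}, {35, 7}, {38, 249}, {40, 249}, {41, 747}, {42, 249}};

// Number of instrumented lines executed at least once.
int countHit(const std::vector<LineHit> & lines)
{
  int hit = 0;
  for (const auto & line : lines)
    if (line.second > 0)
      ++hit;
  return hit;
}

// Instrumented lines that were never executed.
std::vector<int> missedLines(const std::vector<LineHit> & lines)
{
  std::vector<int> missed;
  for (const auto & line : lines)
    if (line.second == 0)
      missed.push_back(line.first);
  return missed;
}

// Coverage as a percentage with one decimal place, e.g. "76.9".
// Plain rounding only; no special handling near 0 % or 100 %.
std::string percent(int hit, int total)
{
  assert(total > 0);
  char buf[16];
  std::snprintf(buf, sizeof buf, "%.1f", 100.0 * hit / total);
  return std::string(buf);
}
```

For example, `percent(countHit(kLineHits), 13)` yields `"76.9"` and `percent(4, 4)` yields `"100.0"`, which match the Lines and Functions rows above.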

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef LIBTORCH_ENABLED
      11             : 
      12             : #include <torch/torch.h>
      13             : #include "TorchScriptModule.h"
      14             : #include "MooseError.h"
      15             : 
      16             : namespace Moose
      17             : {
      18             : 
      19           1 : TorchScriptModule::TorchScriptModule() {}
      20             : 
      21           6 : TorchScriptModule::TorchScriptModule(const std::string & filename) { loadNeuralNetwork(filename); }
      22             : 
      23             : void
      24           7 : TorchScriptModule::loadNeuralNetwork(const std::string & filename)
      25             : {
      26             :   try
      27             :   {
      28           7 :     torch::jit::script::Module * base = this;
      29           7 :     *base = torch::jit::load(filename);
      30             :   }
      31           0 :   catch (const c10::Error & e)
      32             :   {
      33           0 :     mooseError("Error while loading torchscript file ", filename, "!\n", e.msg());
      34           0 :   }
      35           7 : }
      36             : 
      37             : torch::Tensor
      38         249 : TorchScriptModule::forward(const torch::Tensor & x)
      39             : {
      40         249 :   std::vector<torch::jit::IValue> inputs(1, x);
      41         747 :   return torch::jit::script::Module::forward(inputs).toTensor();
      42         249 : }
      43             : 
      44             : }
      45             : 
      46             : #endif
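In `loadNeuralNetwork`, the statement `*base = torch::jit::load(filename)` assigns the freshly loaded module into the `torch::jit::script::Module` base subobject of the object. This works because `TorchScriptModule` derives from that class, as the pointer conversion on line 28 and the qualified call in `forward` both show. The following is a self-contained sketch of the same idiom, without LibTorch. `Module`, `load`, `ScriptWrapper` and `loadNetwork` are hypothetical stand-ins. The sketch also rethrows `std::runtime_error` where the original calls the MOOSE-specific `mooseError`.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for torch::jit::script::Module; the real class holds
// a deserialized TorchScript graph, here we only record where it came from.
struct Module
{
  std::string source;
};

// Hypothetical stand-in for torch::jit::load: returns a Module by value and
// throws on failure (the real function throws c10::Error).
Module load(const std::string & filename)
{
  if (filename.empty())
    throw std::runtime_error("cannot open file");
  return Module{filename};
}

// Mirrors TorchScriptModule: the wrapper derives from the module type, so it
// can be used wherever a Module is expected.
struct ScriptWrapper : Module
{
  ScriptWrapper() = default;
  explicit ScriptWrapper(const std::string & filename) { loadNetwork(filename); }

  void loadNetwork(const std::string & filename)
  {
    try
    {
      // Assign the loaded Module into the base-class subobject of *this;
      // members declared only in ScriptWrapper are left untouched.
      Module * base = this;
      *base = load(filename);
    }
    catch (const std::runtime_error & e)
    {
      // Stand-in for mooseError: rethrow with the file name attached.
      throw std::runtime_error("Error while loading torchscript file " + filename + "!\n" +
                               e.what());
    }
    // The base subobject now carries the loaded state.
    assert(source == filename);
  }
};
```

In this sketch, writing `static_cast<Module &>(*this) = load(filename);` would be equivalent. A plain `*this = load(filename);` would not compile, because the implicitly declared assignment operators of `ScriptWrapper` accept only a `ScriptWrapper` and hide `Module::operator=`.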

Generated by: LCOV version 1.14