LCOV - code coverage report
Current view: top level - include/libtorch/userobjects - TorchScriptUserObject.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: 8601ad Lines: 2 2 100.0 %
Date: 2025-07-18 13:27:08 Functions: 2 2 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef LIBTORCH_ENABLED
      11             : 
      12             : #pragma once
      13             : 
      14             : // MOOSE includes
      15             : #include "GeneralUserObject.h"
      16             : #include "TorchScriptModule.h"
      17             : 
      18             : /**
      19             :  * A user object that loads a torch module using the
      20             :  * torch script format and just-in-time compilation.
      21             :  */
      22             : class TorchScriptUserObject : public GeneralUserObject
      23             : {
      24             : public:
      25             :   static InputParameters validParams();
      26             : 
      27             :   TorchScriptUserObject(const InputParameters & parameters);
      28             : 
      29           1 :   virtual void initialize() override {}
      30             :   virtual void execute() override;
      31           1 :   virtual void finalize() override {}
      32             : 
      33             :   ///@{
      34             :   /// Get const access to the unique pointer that holds the module.
      35             :   const std::unique_ptr<Moose::TorchScriptModule> & modulePtr() const
      36             :   {
      37             :     return _torchscript_module;
      38             :   }
      39             :   /// Get const access to the module itself (dereferences the stored pointer).
      40             :   const Moose::TorchScriptModule & module() const { return *_torchscript_module; }
      41             :   /// Get non-const access to the unique pointer that holds the module. Could be used for further training within MOOSE.
      42             :   std::unique_ptr<Moose::TorchScriptModule> & modulePtr() { return _torchscript_module; }
      43             :   /// Get non-const access to the module itself (dereferences the stored pointer). Could be used for further training within MOOSE.
      44             :   Moose::TorchScriptModule & module() { return *_torchscript_module; }
      45             :   ///@}
      46             : 
      47             :   /**
      48             :    * Evaluate the torch script module at the given input and return the result as a tensor.
      49             :    * @param input The input tensor.
      50             :    */
      51             :   torch::Tensor evaluate(const torch::Tensor & input) const;
      52             : 
      53             : protected:
      54             :   /// Reference to the file name that specifies the torch script model.
      55             :   const FileName & _filename;
      56             : 
      57             :   /// Unique pointer to the torch script module that is currently stored here.
      58             :   std::unique_ptr<Moose::TorchScriptModule> _torchscript_module;
      59             : };
      60             : 
      61             : #endif

Generated by: LCOV version 1.14