LCOV - code coverage report
Current view:  top level - src/neml2/models - LibtorchModel.C (source / functions)
Test:          idaholab/moose framework: 2bf808
Date:          2025-07-17 01:28:37

                 Hit    Total    Coverage
  Lines:          47       48      97.9 %
  Functions:       5        5     100.0 %

Legend:        Lines: hit / not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef NEML2_ENABLED
      11             : 
      12             : // libtorch headers
      13             : #include <ATen/ops/ones_like.h>
      14             : 
      15             : // neml2 headers
      16             : #include "neml2/misc/assertions.h"
      17             : 
      18             : // moose headers
      19             : #include "LibtorchModel.h"
      20             : 
      21             : namespace neml2
      22             : {
      23             : register_NEML2_object(LibtorchModel);
      24             : 
      25             : OptionSet
      26        5628 : LibtorchModel::expected_options()
      27             : {
      28        5628 :   auto options = Model::expected_options();
      29        5628 :   options.set<std::vector<VariableName>>("inputs");
      30        5628 :   options.set<std::vector<VariableName>>("outputs");
      31        5628 :   options.set("outputs").doc() = "The scaled neural network output";
      32        5628 :   options.set<std::string>("file_path");
      33             :   // JIT compilation is not supported for this model, so force "jit" off and hide the option
      34        5628 :   options.set<bool>("jit") = false;
      35        5628 :   options.set("jit").suppressed() = true;
      36        5628 :   return options;
      37           0 : }
      38             : 
      39          18 : LibtorchModel::LibtorchModel(const OptionSet & options)
      40             :   : Model(options),
      41          18 :     _file_path(Moose::DataFileUtils::getPath(options.get<std::string>("file_path"))),
      42          36 :     _surrogate(std::make_unique<torch::jit::script::Module>(torch::jit::load(_file_path.path)))
      43             : {
      44             :   // inputs
      45          36 :   for (const auto & fv : options.get<std::vector<VariableName>>("inputs"))
      46          36 :     _inputs.push_back(&declare_input_variable<Scalar>(fv));
      47          36 :   for (const auto & fv : options.get<std::vector<VariableName>>("outputs"))
      48          36 :     _outputs.push_back(&declare_output_variable<Scalar>(fv));
      49          18 : }
      50             : 
      51             : void
      52          27 : LibtorchModel::to(const torch::TensorOptions & options)
      53             : {
      54          27 :   Model::to(options);
      55             : 
      56          27 :   if (options.has_device())
      57           9 :     _surrogate->to(options.device());
      58             : 
      59          27 :   if (options.has_dtype())
      60          18 :     _surrogate->to(torch::Dtype(caffe2::typeMetaToScalarType(options.dtype())));
      61          27 : }
      62             : 
      63             : void
      64          18 : LibtorchModel::request_AD()
      65             : {
      66          18 :   std::vector<const VariableBase *> inputs;
      67          36 :   for (size_t i = 0; i < _inputs.size(); ++i)
      68          18 :     inputs.push_back(_inputs[i]);
      69             : 
      70          36 :   for (size_t i = 0; i < _outputs.size(); ++i)
      71          18 :     _outputs[i]->request_AD(inputs);
      72          18 : }
      73             : 
      74             : void
      75        1609 : LibtorchModel::set_value(bool out, bool /*dout_din*/, bool /*d2out_din2*/)
      76             : {
      77        1609 :   if (out)
      78             :   {
      79        1609 :     std::vector<at::Tensor> values;
      80        1609 :     auto first_batch_dim = _inputs[0]->batch_dim();
      81        3218 :     for (size_t i = 0; i < _inputs.size(); ++i)
      82             :     {
      83             :       // assert that all inputs have the same batch dimension
      84        1609 :       neml_assert(_inputs[i]->batch_dim() == first_batch_dim);
      85        1609 :       values.push_back(_inputs[i]->value());
      86             :     }
      87             : 
      88        3218 :     auto x = Tensor(torch::transpose(torch::vstack(at::ArrayRef<at::Tensor>(
      89        1609 :                                          values.data(), static_cast<int64_t>(values.size()))),
      90             :                                      0,
      91             :                                      1),
      92        3218 :                     _inputs[0]->batch_dim());
      93             : 
      94             :     // Feed forward the neural network and process the output
      95        4827 :     auto temp = _surrogate->forward({x}).toTensor().squeeze();
      96             :     auto y0 =
      97        1609 :         (temp.dim() == 1) ? temp.view({temp.size(0), 1}).transpose(0, 1) : temp.transpose(0, 1);
      98             : 
      99        3218 :     for (size_t i = 0; i < _outputs.size(); ++i)
     100        1609 :       *_outputs[i] = Scalar(y0[i], _inputs[0]->batch_dim());
     101        1609 :   }
     102        4827 : }
     103             : 
     104             : } // namespace neml2
     105             : 
     106             : #endif
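
For reference, the workflow exercised by LibtorchModel above is a small, self-contained libtorch pattern: load a TorchScript archive from "file_path", move it to the requested device and dtype in to(), and, in set_value(), stack the batched scalar inputs into a [batch, n_inputs] tensor, run forward(), and transpose the result back into one column per output variable. The sketch below reproduces that pattern with plain libtorch. It is illustrative only: the archive name "surrogate.pt", the batch size of 4, the two-input/one-output shape, the use of double precision, and the main() driver are assumptions made for this example and are not part of the MOOSE source.

// Minimal standalone sketch of the LibtorchModel workflow (assumptions noted above).
// Build against libtorch; the TorchScript archive "surrogate.pt" is hypothetical and
// is assumed to map a [batch, 2] double tensor to a [batch, 1] (or [batch]) tensor.
#include <torch/script.h>

#include <iostream>
#include <vector>

int main()
{
  // Constructor analogue: load the TorchScript module from disk.
  torch::jit::script::Module surrogate = torch::jit::load("surrogate.pt");

  // to() analogue: move the module to the requested device and scalar type.
  surrogate.to(torch::kCPU);
  surrogate.to(torch::kFloat64);

  // set_value() analogue: gather the batched scalar input variables ...
  std::vector<at::Tensor> values = {torch::rand({4}, torch::kFloat64),
                                    torch::rand({4}, torch::kFloat64)};

  // ... vstack them into shape [n_inputs, batch] and transpose to
  // [batch, n_inputs], the layout the network consumes.
  at::Tensor x = torch::transpose(torch::vstack(values), 0, 1);

  // Feed forward and drop trailing singleton dimensions.
  at::Tensor temp = surrogate.forward({x}).toTensor().squeeze();

  // A single output variable comes back 1-D; reshape so that y0[i] is the
  // batched value of the i-th output variable, mirroring the source above.
  at::Tensor y0 =
      (temp.dim() == 1) ? temp.view({temp.size(0), 1}).transpose(0, 1) : temp.transpose(0, 1);

  std::cout << "output shape: " << y0.sizes() << std::endl;
  return 0;
}

In the actual model the path is first resolved through Moose::DataFileUtils::getPath, and each row of the transposed output is written back into the corresponding declared Scalar output variable using the batch dimension of the first input.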

Generated by: LCOV version 1.14