LCOV - code coverage report
Current view: top level - src/libtorch/utils - LibtorchUtils.C (source / functions) Hit Total Coverage
Test: idaholab/moose framework: 8601ad Lines: 6 18 33.3 %
Date: 2025-07-18 13:27:08 Functions: 1 2 50.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef LIBTORCH_ENABLED
      11             : 
      12             : #include "LibtorchUtils.h"
      13             : 
      14             : namespace LibtorchUtils
      15             : {
      16             : 
      17             : template <typename DataType>
      18             : void
      19          50 : vectorToTensor(std::vector<DataType> & vector, torch::Tensor & tensor, const bool detach)
      20             : {
      21          50 :   auto options = torch::TensorOptions();
      22             :   if constexpr (std::is_same<DataType, double>::value)
      23          50 :     options = torch::TensorOptions().dtype(at::kDouble);
      24             :   else if constexpr (std::is_same<DataType, float>::value)
      25             :     options = torch::TensorOptions().dtype(at::kFloat);
      26             :   else
      27             :     static_assert(Moose::always_false<DataType>,
      28             :                   "vectorToTensor is not implemented for the given data type!");
      29             : 
      30             :   // We need to clone here because from_blob() doesn't take ownership of the pointer so if it
      31             :   // vector goes out of scope before tensor, we get unwanted behavior
      32          50 :   tensor = torch::from_blob(vector.data(), {long(vector.size()), 1}, options).clone();
      33             : 
      34          50 :   if (detach)
      35           0 :     tensor.detach();
      36          50 : }
      37             : 
      38             : // Explicitly instantiate for DataType=Real
      39             : template void
      40             : vectorToTensor<Real>(std::vector<Real> & vector, torch::Tensor & tensor, const bool detach);
      41             : 
      42             : template <typename DataType>
      43             : void
      44           0 : tensorToVector(torch::Tensor & tensor, std::vector<DataType> & vector)
      45             : {
      46             :   try
      47             :   {
      48           0 :     tensor.data_ptr<DataType>();
      49             :   }
      50           0 :   catch (const c10::Error & e)
      51             :   {
      52           0 :     mooseError(
      53           0 :         "Cannot cast tensor values to", MooseUtils::prettyCppType<DataType>(), "!\n", e.msg());
      54             :   }
      55             : 
      56           0 :   const auto & sizes = tensor.sizes();
      57             : 
      58           0 :   long int max_size = 0;
      59           0 :   for (const auto & dim_size : sizes)
      60             :     // We do this comparison because XCode complains if we use std::max
      61           0 :     max_size = dim_size > max_size ? dim_size : max_size;
      62             : 
      63             :   mooseAssert(max_size == tensor.numel(), "The given tensor should be one-dimensional!");
      64           0 :   vector = {tensor.data_ptr<DataType>(), tensor.data_ptr<DataType>() + tensor.numel()};
      65           0 : }
      66             : 
      67             : // Explicitly instantiate for DataType=Real
      68             : template void tensorToVector<Real>(torch::Tensor & tensor, std::vector<Real> & vector);
      69             : 
      70             : } // LibtorchUtils namespace
      71             : 
      72             : #endif

Generated by: LCOV version 1.14