Line data Source code
//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#ifdef LIBTORCH_ENABLED

#include "LibtorchUtils.h"

namespace LibtorchUtils
{

/**
 * Copy the contents of a standard vector into a libtorch tensor of shape (vector.size(), 1).
 * Only double and float element types are supported; any other type fails at compile time.
 * @param vector The vector whose entries are copied
 * @param tensor The tensor that is overwritten with the copied entries
 * @param detach If true, the resulting tensor is detached from the autograd graph
 */
template <typename DataType>
void
vectorToTensor(std::vector<DataType> & vector, torch::Tensor & tensor, const bool detach)
{
  auto options = torch::TensorOptions();
  if constexpr (std::is_same<DataType, double>::value)
    options = torch::TensorOptions().dtype(at::kDouble);
  else if constexpr (std::is_same<DataType, float>::value)
    options = torch::TensorOptions().dtype(at::kFloat);
  else
    static_assert(Moose::always_false<DataType>,
                  "vectorToTensor is not implemented for the given data type!");

  // We need to clone here because from_blob() doesn't take ownership of the pointer, so if the
  // vector goes out of scope before the tensor, we get unwanted behavior
  tensor =
      torch::from_blob(vector.data(), {static_cast<long>(vector.size()), 1}, options).clone();

  // detach() returns a new tensor and leaves the original untouched, so calling it and
  // discarding the result has no effect. The in-place variant is what is intended here.
  if (detach)
    tensor.detach_();
}

// Explicitly instantiate for DataType=Real
template void
vectorToTensor<Real>(std::vector<Real> & vector, torch::Tensor & tensor, const bool detach);

/**
 * Copy the entries of a libtorch tensor into a standard vector.
 * The tensor is expected to have at most one dimension with extent larger than one
 * (checked with an assertion), and its element type must match DataType.
 * @param tensor The tensor whose entries are copied
 * @param vector The vector that is overwritten with the tensor entries
 */
template <typename DataType>
void
tensorToVector(torch::Tensor & tensor, std::vector<DataType> & vector)
{
  // Reading numel() entries straight from data_ptr() is only valid for densely stored data.
  // contiguous() hands back the same tensor when it is already dense and a dense copy otherwise.
  const torch::Tensor dense_tensor = tensor.contiguous();

  const DataType * data = nullptr;
  try
  {
    // data_ptr() throws if the tensor's element type does not match DataType
    data = dense_tensor.data_ptr<DataType>();
  }
  catch (const c10::Error & e)
  {
    mooseError(
        "Cannot cast tensor values to ", MooseUtils::prettyCppType<DataType>(), "!\n", e.msg());
  }

  const auto & sizes = tensor.sizes();

  long int max_size = 0;
  for (const auto & dim_size : sizes)
    // We do this comparison because XCode complains if we use std::max
    max_size = dim_size > max_size ? dim_size : max_size;

  // If the largest extent equals the number of elements, every other extent must be one
  mooseAssert(max_size == tensor.numel(), "The given tensor should be one-dimensional!");
  vector.assign(data, data + dense_tensor.numel());
}

// Explicitly instantiate for DataType=Real
template void tensorToVector<Real>(torch::Tensor & tensor, std::vector<Real> & vector);

} // LibtorchUtils namespace

#endif