https://mooseframework.inl.gov
LibtorchUtils.C
Go to the documentation of this file.
1 //* This file is part of the MOOSE framework
2 //* https://mooseframework.inl.gov
3 //*
4 //* All rights reserved, see COPYRIGHT for full restrictions
5 //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 //*
7 //* Licensed under LGPL 2.1, please see LICENSE for details
8 //* https://www.gnu.org/licenses/lgpl-2.1.html
9 
10 #ifdef LIBTORCH_ENABLED
11 
12 #include "LibtorchUtils.h"
13 
14 namespace LibtorchUtils
15 {
16 
17 template <typename DataType>
18 void
19 vectorToTensor(std::vector<DataType> & vector, torch::Tensor & tensor, const bool detach)
20 {
21  auto options = torch::TensorOptions();
22  if constexpr (std::is_same<DataType, double>::value)
23  options = torch::TensorOptions().dtype(at::kDouble);
24  else if constexpr (std::is_same<DataType, float>::value)
25  options = torch::TensorOptions().dtype(at::kFloat);
26  else
27  static_assert(Moose::always_false<DataType>,
28  "vectorToTensor is not implemented for the given data type!");
29 
30  // We need to clone here because from_blob() doesn't take ownership of the pointer so if it
31  // vector goes out of scope before tensor, we get unwanted behavior
32  tensor = torch::from_blob(vector.data(), {long(vector.size()), 1}, options).clone();
33 
34  if (detach)
35  tensor.detach();
36 }
37 
38 // Explicitly instantiate for DataType=Real
39 template void
40 vectorToTensor<Real>(std::vector<Real> & vector, torch::Tensor & tensor, const bool detach);
41 
42 template <typename DataType>
43 void
44 tensorToVector(torch::Tensor & tensor, std::vector<DataType> & vector)
45 {
46  try
47  {
48  tensor.data_ptr<DataType>();
49  }
50  catch (const c10::Error & e)
51  {
52  mooseError(
53  "Cannot cast tensor values to", MooseUtils::prettyCppType<DataType>(), "!\n", e.msg());
54  }
55 
56  const auto & sizes = tensor.sizes();
57 
58  long int max_size = 0;
59  for (const auto & dim_size : sizes)
60  // We do this comparison because XCode complains if we use std::max
61  max_size = dim_size > max_size ? dim_size : max_size;
62 
63  mooseAssert(max_size == tensor.numel(), "The given tensor should be one-dimensional!");
64  vector = {tensor.data_ptr<DataType>(), tensor.data_ptr<DataType>() + tensor.numel()};
65 }
66 
67 // Explicitly instantiate for DataType=Real
68 template void tensorToVector<Real>(torch::Tensor & tensor, std::vector<Real> & vector);
69 
70 } // LibtorchUtils namespace
71 
72 #endif
void mooseError(Args &&... args)
Emit an error message with the given stringified, concatenated args and terminate the application...
Definition: MooseError.h:302
void tensorToVector(torch::Tensor &tensor, std::vector< DataType > &vector)
Utility function that converts a torch::Tensor to a standard vector.
Definition: LibtorchUtils.C:44
void vectorToTensor(std::vector< DataType > &vector, torch::Tensor &tensor, const bool detach=false)
Utility function that converts a standard vector to a torch::Tensor.
Definition: LibtorchUtils.C:19
template void tensorToVector< Real >(torch::Tensor &tensor, std::vector< Real > &vector)
template void vectorToTensor< Real >(std::vector< Real > &vector, torch::Tensor &tensor, const bool detach)