https://mooseframework.inl.gov
Functions
LibtorchUtils Namespace Reference

Functions

template<typename DataType >
void vectorToTensor (std::vector< DataType > &vector, torch::Tensor &tensor, const bool detach=false)
 Utility function that converts a standard vector to a torch::Tensor. More...
 
template<typename DataType >
void tensorToVector (torch::Tensor &tensor, std::vector< DataType > &vector)
 Utility function that converts a torch::Tensor to a standard vector. More...
 
template void vectorToTensor< Real > (std::vector< Real > &vector, torch::Tensor &tensor, const bool detach)
 
template void tensorToVector< Real > (torch::Tensor &tensor, std::vector< Real > &vector)
 

Function Documentation

◆ tensorToVector()

template<typename DataType >
void LibtorchUtils::tensorToVector ( torch::Tensor &  tensor,
std::vector< DataType > &  vector 
)

Utility function that converts a torch::Tensor to a standard vector.

Template Parameters
DataType: The type of data (float, double, etc.) which the vector is filled with
Parameters
tensor: The tensor which needs to be converted
vector: The output vector

Definition at line 44 of file LibtorchUtils.C.

45 {
46  try
47  {
48  tensor.data_ptr<DataType>();
49  }
50  catch (const c10::Error & e)
51  {
52  mooseError(
53  "Cannot cast tensor values to", MooseUtils::prettyCppType<DataType>(), "!\n", e.msg());
54  }
55 
56  const auto & sizes = tensor.sizes();
57 
58  long int max_size = 0;
59  for (const auto & dim_size : sizes)
60  // We do this comparison because XCode complains if we use std::max
61  max_size = dim_size > max_size ? dim_size : max_size;
62 
63  mooseAssert(max_size == tensor.numel(), "The given tensor should be one-dimensional!");
64  vector = {tensor.data_ptr<DataType>(), tensor.data_ptr<DataType>() + tensor.numel()};
65 }
void mooseError(Args &&... args)
Emit an error message with the given stringified, concatenated args and terminate the application...
Definition: MooseError.h:302

◆ tensorToVector< Real >()

template void LibtorchUtils::tensorToVector< Real > ( torch::Tensor &  tensor,
std::vector< Real > &  vector 
)

◆ vectorToTensor()

template<typename DataType >
void LibtorchUtils::vectorToTensor ( std::vector< DataType > &  vector,
torch::Tensor &  tensor,
const bool  detach = false 
)

Utility function that converts a standard vector to a torch::Tensor.

Template Parameters
DataType: The type of data (float, double, etc.) which the vector is filled with
Parameters
vector: The vector that needs to be converted
tensor: The output tensor
detach: Whether the gradient information needs to be detached during the conversion

Definition at line 19 of file LibtorchUtils.C.

Referenced by LibtorchNeuralNetControl::prepareInputTensor().

20 {
21  auto options = torch::TensorOptions();
22  if constexpr (std::is_same<DataType, double>::value)
23  options = torch::TensorOptions().dtype(at::kDouble);
24  else if constexpr (std::is_same<DataType, float>::value)
25  options = torch::TensorOptions().dtype(at::kFloat);
26  else
27  static_assert(Moose::always_false<DataType>,
28  "vectorToTensor is not implemented for the given data type!");
29 
30  // We need to clone here because from_blob() doesn't take ownership of the pointer so if it
31  // vector goes out of scope before tensor, we get unwanted behavior
32  tensor = torch::from_blob(vector.data(), {long(vector.size()), 1}, options).clone();
33 
34  if (detach)
35  tensor.detach();
36 }

◆ vectorToTensor< Real >()

template void LibtorchUtils::vectorToTensor< Real > ( std::vector< Real > &  vector,
torch::Tensor &  tensor,
const bool  detach 
)