https://mooseframework.inl.gov
Moose::LibtorchDataset Class Reference

This class is a wrapper around a libtorch dataset that can be used by the data loaders in the neural net training process.

#include <LibtorchDataset.h>


Public Member Functions

LibtorchDataset (torch::Tensor dt, torch::Tensor rt)
  Construct using the input and output tensors.

torch::data::Example<> get (size_t index) override
  Get a sample pair from the input and output tensors.

torch::optional<size_t> size () const override
  Return the number of samples this dataset contains.

Private Attributes

torch::Tensor _data_tensor
  Tensor containing the data (inputs).

torch::Tensor _response_tensor
  Tensor containing the responses (outputs) for the data.

Detailed Description

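The class holds an input (data) tensor and a response (output) tensor and serves them to the data loaders as (input, response) sample pairs: sample i is slice i along the first dimension of each tensor, and the number of samples is the extent of the response tensor's first dimension.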

Definition at line 24 of file LibtorchDataset.h.
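To make the role of the class concrete, here is a minimal, libtorch-free sketch of the dataset contract and of how a sequential data loader consumes it. ToyDataset, Row, Example and makeBatches are hypothetical names: std::vector rows stand in for torch::Tensor slices, std::pair stands in for torch::data::Example<>, and makeBatches is a simplified model of batching, not libtorch's DataLoader (in libtorch, loaders are typically created with torch::data::make_data_loader).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-ins (no libtorch): a Row plays the role of one tensor
// slice, and Example the role of torch::data::Example<> (input, response).
using Row = std::vector<double>;
using Example = std::pair<Row, Row>;

// Hypothetical model of LibtorchDataset: sample i is row i of the inputs
// paired with row i of the responses.
class ToyDataset
{
public:
  ToyDataset(std::vector<Row> data, std::vector<Row> response)
    : _data(std::move(data)), _response(std::move(response))
  {
  }

  Example get(std::size_t index) const
  {
    assert(index < _response.size() && "Index is out of range!");
    return {_data[index], _response[index]};
  }

  // Like the original, the sample count comes from the responses.
  std::optional<std::size_t> size() const { return _response.size(); }

private:
  std::vector<Row> _data;
  std::vector<Row> _response;
};

// Simplified sequential loader (not libtorch's DataLoader): query size(),
// then group consecutive get() results into batches of at most batch_size.
std::vector<std::vector<Example>> makeBatches(const ToyDataset & ds, std::size_t batch_size)
{
  assert(batch_size > 0);
  std::vector<std::vector<Example>> batches;
  const std::size_t n = ds.size().value();
  for (std::size_t start = 0; start < n; start += batch_size)
  {
    std::vector<Example> batch;
    const std::size_t stop = std::min(start + batch_size, n);
    for (std::size_t i = start; i < stop; ++i)
      batch.push_back(ds.get(i));
    batches.push_back(std::move(batch));
  }
  return batches;
}
```

With three samples and a batch size of 2, the loop yields one full batch and one partial batch holding only the last sample.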

Constructor & Destructor Documentation

◆ LibtorchDataset()

Moose::LibtorchDataset::LibtorchDataset(torch::Tensor dt, torch::Tensor rt)
inline

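Construct using the input and output tensors: dt is stored as the data (input) tensor and rt as the response (output) tensor. Both are stored as passed, without any shape check.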

Definition at line 28 of file LibtorchDataset.h.

28 : _data_tensor(dt), _response_tensor(rt) {}
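The constructor stores dt and rt as given, while size() counts samples from the response tensor alone and get() indexes both tensors with the same index. If the data tensor held fewer samples than the response tensor, the last indices reported by size() would be out of range for the data tensor. The sketch below shows a hypothetical pre-construction check, checkSamplePairing, that is not part of MOOSE; shapes are modelled as std::vector<std::int64_t>, mirroring torch::Tensor::sizes(), and requiring equal first extents is an assumption about the intended pairing. With real tensors the same comparison could use dt.size(0) and rt.size(0).

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Shapes modelled as plain vectors, mirroring torch::Tensor::sizes().
using Shape = std::vector<std::int64_t>;

// Hypothetical check (not part of MOOSE) to run before constructing the
// dataset: the constructor does not compare dt and rt, size() reads only the
// response tensor, and get() indexes both tensors with the same index.
// Requiring equal first extents is an assumption about the intended pairing.
void checkSamplePairing(const Shape & data_shape, const Shape & response_shape)
{
  if (data_shape.empty() || response_shape.empty())
    throw std::invalid_argument("both tensors need at least one dimension");
  if (data_shape[0] != response_shape[0])
    throw std::invalid_argument("data and response tensors hold different numbers of samples");
}
```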

Member Function Documentation

◆ get()

torch::data::Example<> Moose::LibtorchDataset::get(size_t index)
inline, override

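Get a sample pair from the input and output tensors: the slice at position index along the first dimension of the data tensor, paired with the slice at the same position of the response tensor. The index is checked against size() with mooseAssert.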

Definition at line 31 of file LibtorchDataset.h.

32  {
33  mooseAssert(index < size(), "Index is out of range!");
34  return {_data_tensor[index], _response_tensor[index]};
35  }
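In get(), `_data_tensor[index]` selects one slice along the first dimension and drops that dimension, so a tensor of shape {N, a, b} yields a sample of shape {a, b}, and a 1-D response tensor of shape {N} yields a 0-dimensional (scalar) entry. The sketch below models this with a hypothetical row-major ToyTensor and a hypothetical getExample function; it illustrates the indexing semantics only, is not libtorch code, and throws std::out_of_range where the original relies on mooseAssert.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical row-major stand-in for torch::Tensor (illustration only).
struct ToyTensor
{
  std::vector<std::size_t> shape;
  std::vector<double> values; // row-major storage

  // Mirrors tensor[index]: take slice `index` along the first dimension and
  // drop that dimension, so shape {N, a, b} becomes {a, b} and {N} becomes {}.
  ToyTensor operator[](std::size_t index) const
  {
    if (shape.empty() || index >= shape[0])
      throw std::out_of_range("Index is out of range!");
    const std::vector<std::size_t> rest(shape.begin() + 1, shape.end());
    // Number of values in one slice (1 for an empty shape, i.e. a scalar).
    const std::size_t stride =
        std::accumulate(rest.begin(), rest.end(), std::size_t{1}, std::multiplies<std::size_t>());
    const auto first = values.begin() + static_cast<std::ptrdiff_t>(index * stride);
    return {rest, std::vector<double>(first, first + static_cast<std::ptrdiff_t>(stride))};
  }
};

// Sketch of what get() does: the same index slices the data and the response
// tensors, giving an (input, response) pair like torch::data::Example<>.
std::pair<ToyTensor, ToyTensor>
getExample(const ToyTensor & data, const ToyTensor & response, std::size_t index)
{
  return {data[index], response[index]};
}
```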

◆ size()

torch::optional<size_t> Moose::LibtorchDataset::size() const
inline, override

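Return the number of samples this dataset contains, taken as the extent of the first dimension of the response tensor. mooseAssert requires the response tensor to have at least one dimension.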

Definition at line 38 of file LibtorchDataset.h.

Referenced by get(), and Moose::LibtorchArtificialNeuralNetTrainer< SamplerType >::train().

39  {
40  mooseAssert(_response_tensor.sizes().size(), "The tensors are empty!");
41  return _response_tensor.sizes()[0];
42  }
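The guard in size() tests `_response_tensor.sizes().size()`, that is, the number of dimensions rather than the number of elements: a 0-dimensional tensor trips it, while a tensor of shape {0, 4} passes and reports zero samples. The hypothetical sampleCount function below models this on a plain shape vector standing in for sizes(), returning std::nullopt where the original would assert, so both cases can be exercised without libtorch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of LibtorchDataset::size(); `sizes` stands in for
// _response_tensor.sizes(). The original guards with
// mooseAssert(_response_tensor.sizes().size(), ...), which only requires at
// least one dimension; this sketch returns std::nullopt instead of asserting.
std::optional<std::int64_t> sampleCount(const std::vector<std::int64_t> & sizes)
{
  if (sizes.empty()) // 0-dimensional (scalar) tensor: no sample dimension
    return std::nullopt;
  return sizes[0]; // extent of the first dimension = number of samples
}
```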

Member Data Documentation

◆ _data_tensor

torch::Tensor Moose::LibtorchDataset::_data_tensor
private

Tensor containing the data (inputs).

Definition at line 46 of file LibtorchDataset.h.

Referenced by get().

◆ _response_tensor

torch::Tensor Moose::LibtorchDataset::_response_tensor
private

Tensor containing the responses (outputs) for the data.

Definition at line 48 of file LibtorchDataset.h.

Referenced by get(), and size().

