This class is a wrapper around a libtorch dataset which can be used by the data loaders in the neural net training process.
#include <LibtorchDataset.h>
Public Member Functions
LibtorchDataset (torch::Tensor dt, torch::Tensor rt)
    Construct using the input and output tensors.
torch::data::Example get (size_t index) override
    Get a sample pair from the input and output tensors.
torch::optional<size_t> size () const override
    Return the number of samples this dataset contains.
Private Attributes
torch::Tensor _data_tensor
    Tensor containing the data (inputs).
torch::Tensor _response_tensor
    Tensor containing the responses (outputs) for the data.
Class definition at line 24 of file LibtorchDataset.h.
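The `override` specifiers show that get() and size() implement virtual functions of a LibTorch dataset base class. As a dependency-free sketch of that pattern (it compiles without LibTorch), the hypothetical `ToyDataset` below stores one input row and one response row per sample and exposes the same three members. `DatasetBase`, `Sample`, and the nested `std::vector<double>` rows standing in for `torch::Tensor` are assumptions for illustration, not MOOSE or LibTorch types; the sample-count check in the constructor is likewise an assumption, not documented behavior of LibtorchDataset.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-in for a 2-D torch::Tensor: one inner vector per sample.
using Rows = std::vector<std::vector<double>>;

// Hypothetical analogue of torch::data::Example: an input row and its response row.
struct Sample
{
  std::vector<double> data;
  std::vector<double> target;
};

// Hypothetical base mirroring the two virtual functions a dataset overrides.
class DatasetBase
{
public:
  virtual ~DatasetBase() = default;
  virtual Sample get(std::size_t index) = 0;
  virtual std::optional<std::size_t> size() const = 0;
};

class ToyDataset : public DatasetBase
{
public:
  // Store inputs (dt) and responses (rt); this sketch requires equal sample counts.
  ToyDataset(Rows dt, Rows rt) : _data_rows(std::move(dt)), _response_rows(std::move(rt))
  {
    if (_data_rows.size() != _response_rows.size())
      throw std::invalid_argument("inputs and responses differ in sample count");
  }

  // Return sample `index` as an (input, response) pair.
  Sample get(std::size_t index) override
  {
    assert(index < _data_rows.size() && "sample index out of range");
    return {_data_rows[index], _response_rows[index]};
  }

  // Number of samples, i.e. the extent of the first dimension.
  std::optional<std::size_t> size() const override { return _data_rows.size(); }

private:
  Rows _data_rows;     // inputs, one row per sample
  Rows _response_rows; // responses, one row per sample
};
```

For example, a `ToyDataset` built from three input rows and three response rows reports `size()` as 3, and `get(1)` returns the second row of each.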
LibtorchDataset (torch::Tensor dt, torch::Tensor rt)  [inline]
Construct using the input tensor dt and the response (output) tensor rt.
Definition at line 28 of file LibtorchDataset.h.
torch::data::Example get (size_t index)  [inline, override]
Return the sample at position index as an (input, output) pair drawn from the data and response tensors.
Definition at line 31 of file LibtorchDataset.h.
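Assuming both tensors store samples along their first dimension (the usual convention for LibTorch data loaders), a sample pair is the slice at position index of each tensor. For a contiguous row-major buffer with `cols` entries per sample, that slice starts at offset `index * cols`. The hypothetical helpers `row_slice` and `sample_pair` below sketch this arithmetic in plain C++; they are not how LibTorch implements indexing.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: copy row `index` out of a row-major matrix stored
// contiguously in `flat` with `cols` entries per row (samples along dim 0).
std::vector<double> row_slice(const std::vector<double> & flat, std::size_t cols, std::size_t index)
{
  assert(cols > 0 && flat.size() % cols == 0);
  assert(index < flat.size() / cols);
  const auto first = flat.begin() + static_cast<std::ptrdiff_t>(index * cols);
  return std::vector<double>(first, first + static_cast<std::ptrdiff_t>(cols));
}

// Hypothetical analogue of get(index): the same row from inputs and responses.
std::pair<std::vector<double>, std::vector<double>>
sample_pair(const std::vector<double> & inputs, std::size_t input_cols,
            const std::vector<double> & responses, std::size_t response_cols,
            std::size_t index)
{
  return {row_slice(inputs, input_cols, index), row_slice(responses, response_cols, index)};
}
```

One design difference: indexing a `torch::Tensor` with `operator[]` returns a view that shares storage with the original tensor, whereas these helpers copy the row.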
torch::optional<size_t> size () const  [inline, override]
Return the number of samples this dataset contains.
Definition at line 38 of file LibtorchDataset.h.
Referenced by get() and Moose::LibtorchArtificialNeuralNetTrainer<SamplerType>::train().
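Because size() returns an optional count, a consumer such as the trainer has to unwrap it before using it. As an illustration of one possible use, not the trainer's actual code, the hypothetical `plan_batches` below splits n samples into mini-batches of size b: there are ⌈n/b⌉ batches, and the last holds n − (⌈n/b⌉ − 1)·b samples. Rejecting an empty optional and a zero batch size are assumptions of this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <stdexcept>

// Hypothetical summary of how n samples split into mini-batches of size b.
struct BatchPlan
{
  std::size_t num_batches; // ceil(n / b)
  std::size_t last_batch;  // samples in the final (possibly partial) batch
};

// Plan batches from the optional sample count a dataset's size() returns.
BatchPlan plan_batches(std::optional<std::size_t> n, std::size_t batch_size)
{
  if (!n)
    throw std::invalid_argument("dataset size is unknown");
  if (batch_size == 0)
    throw std::invalid_argument("batch size must be positive");
  if (*n == 0)
    return {0, 0};
  // Integer ceiling without the overflow risk of (n + b - 1) / b.
  const std::size_t batches = *n / batch_size + (*n % batch_size != 0 ? 1 : 0);
  const std::size_t last = *n - (batches - 1) * batch_size;
  assert(last > 0 && last <= batch_size);
  return {batches, last};
}
```

With 10 samples and a batch size of 4, this gives 3 batches, the last holding 2 samples.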
torch::Tensor _data_tensor  [private]
Tensor containing the data (inputs).
Definition at line 46 of file LibtorchDataset.h.
Referenced by get().
torch::Tensor _response_tensor  [private]
Tensor containing the responses (outputs) for the data.
Definition at line 48 of file LibtorchDataset.h.