https://mooseframework.inl.gov
LibtorchDataset.h
Go to the documentation of this file.
1 //* This file is part of the MOOSE framework
2 //* https://mooseframework.inl.gov
3 //*
4 //* All rights reserved, see COPYRIGHT for full restrictions
5 //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 //*
7 //* Licensed under LGPL 2.1, please see LICENSE for details
8 //* https://www.gnu.org/licenses/lgpl-2.1.html
9 
10 #ifdef LIBTORCH_ENABLED
11 
12 #pragma once
13 
14 #include <torch/torch.h>
15 #include "MooseError.h"
16 
17 namespace Moose
18 {
19 
24 class LibtorchDataset : public torch::data::datasets::Dataset<LibtorchDataset>
25 {
26 public:
28  LibtorchDataset(torch::Tensor dt, torch::Tensor rt) : _data_tensor(dt), _response_tensor(rt) {}
29 
31  torch::data::Example<> get(size_t index) override
32  {
33  mooseAssert(index < size(), "Index is out of range!");
34  return {_data_tensor[index], _response_tensor[index]};
35  }
36 
38  torch::optional<size_t> size() const override
39  {
40  mooseAssert(_response_tensor.sizes().size(), "The tensors are empty!");
41  return _response_tensor.sizes()[0];
42  }
43 
44 private:
46  torch::Tensor _data_tensor;
48  torch::Tensor _response_tensor;
49 };
50 
51 }
52 
53 #endif
torch::optional< size_t > size() const override
Return the number of samples this data set contains.
This class is a wrapper around a libtorch dataset which can be used by the data loaders in the neural...
LibtorchDataset(torch::Tensor dt, torch::Tensor rt)
Construct using the input and output tensors.
torch::Tensor _data_tensor
Tensor containing the data (inputs)
MOOSE now contains C++17 code, so give a reasonable error message stating what the user can do to add...
torch::Tensor _response_tensor
Tensor containing the responses (outputs) for the data.