Line data Source code
//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#ifdef LIBTORCH_ENABLED

#pragma once

#include <utility>

#include <torch/torch.h>
#include "MooseError.h"

namespace Moose
{

/**
 * This class is a wrapper around a libtorch dataset which can be used by the
 * data loaders in the neural net training process.
 *
 * Sample i is the pair (data[i], response[i]), i.e. both tensors are indexed
 * along their first dimension. Both tensors are therefore expected to have at
 * least one dimension and the same extent along dimension 0.
 */
class LibtorchDataset : public torch::data::datasets::Dataset<LibtorchDataset>
{
public:
  /**
   * Construct using the input and output tensors.
   * @param dt Tensor containing the data (inputs); samples along dimension 0
   * @param rt Tensor containing the responses (outputs); samples along dimension 0
   */
  LibtorchDataset(torch::Tensor dt, torch::Tensor rt)
    : _data_tensor(std::move(dt)), _response_tensor(std::move(rt))
  {
    // Both tensors are indexed along dimension 0 in get(), so they must have a
    // sample dimension and agree on its length; otherwise samples would either
    // be silently dropped or indexed out of range later on.
    mooseAssert(_data_tensor.dim() > 0 && _response_tensor.dim() > 0,
                "The data and response tensors must have at least one dimension!");
    mooseAssert(_data_tensor.size(0) == _response_tensor.size(0),
                "The data and response tensors must contain the same number of samples!");
  }

  /**
   * Get a sample pair from the input and output tensors.
   * @param index Sample index along dimension 0; must be smaller than size()
   * @return The (data, response) pair for the requested sample
   */
  torch::data::Example<> get(size_t index) override
  {
    mooseAssert(index < size(), "Index is out of range!");
    return {_data_tensor[index], _response_tensor[index]};
  }

  /**
   * Return the number of samples this data set contains.
   * @return The extent of the response tensor along dimension 0
   */
  torch::optional<size_t> size() const override
  {
    mooseAssert(_response_tensor.sizes().size(), "The tensors are empty!");
    // Tensor::size(0) is range-checked, unlike indexing into sizes(), so a
    // zero-dimensional tensor raises an error in optimized builds (where
    // mooseAssert is compiled out) instead of reading out of bounds.
    return static_cast<size_t>(_response_tensor.size(0));
  }

private:
  /// Tensor containing the data (inputs)
  torch::Tensor _data_tensor;
  /// Tensor containing the responses (outputs) for the data
  torch::Tensor _response_tensor;
};

} // namespace Moose

#endif