LCOV - code coverage report
Current view: top level - include/libtorch/utils - LibtorchDataset.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: 8601ad Lines: 5 5 100.0 %
Date: 2025-07-18 13:27:08 Functions: 3 3 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #ifdef LIBTORCH_ENABLED
      11             : 
      12             : #pragma once
      13             : 
#include <torch/torch.h>
#include "MooseError.h"

#include <cstdint>
#include <utility>
      16             : 
      17             : namespace Moose
      18             : {
      19             : 
      20             : /**
      21             :  * This class is a wrapper around a libtorch dataset which can be used by the
      22             :  * data loaders in the neural net training process.
      23             :  */
      24             : class LibtorchDataset : public torch::data::datasets::Dataset<LibtorchDataset>
      25             : {
      26             : public:
      27             :   /// Construct using the input and output tensors
      28          25 :   LibtorchDataset(torch::Tensor dt, torch::Tensor rt) : _data_tensor(dt), _response_tensor(rt) {}
      29             : 
      30             :   /// Get a sample pair from the input and output tensors
      31     2040000 :   torch::data::Example<> get(size_t index) override
      32             :   {
      33             :     mooseAssert(index < size(), "Index is out of range!");
      34     2040000 :     return {_data_tensor[index], _response_tensor[index]};
      35             :   }
      36             : 
      37             :   /// Return the number of samples this data set contains
      38          24 :   torch::optional<size_t> size() const override
      39             :   {
      40             :     mooseAssert(_response_tensor.sizes().size(), "The tensors are empty!");
      41          24 :     return _response_tensor.sizes()[0];
      42             :   }
      43             : 
      44             : private:
      45             :   /// Tensor containing the data (inputs)
      46             :   torch::Tensor _data_tensor;
      47             :   /// Tensor containing the responses (outputs) for the data
      48             :   torch::Tensor _response_tensor;
      49             : };
      50             : 
      51             : }
      52             : 
      53             : #endif

Generated by: LCOV version 1.14