LCOV - code coverage report
Current view: top level - include/tensor_buffers - TensorBuffer.h (source / functions) Hit Total Coverage
Test: idaholab/swift: #92 (25e020) with base b3cd84 Lines: 19 26 73.1 %
Date: 2025-09-10 17:10:32 Functions: 4 8 50.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /**********************************************************************/
       2             : /*                    DO NOT MODIFY THIS HEADER                       */
       3             : /*             Swift, a Fourier spectral solver for MOOSE             */
       4             : /*                                                                    */
       5             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       6             : /*                        ALL RIGHTS RESERVED                         */
       7             : /**********************************************************************/
       8             : 
       9             : #pragma once
      10             : 
      11             : #include "TensorBufferBase.h"
      12             : 
      13             : /**
      14             :  * Tensor wrapper arbitrary tensor value dimensions
      15             :  */
      16             : template <typename T>
      17             : class TensorBuffer : public TensorBufferBase
      18             : {
      19             : public:
      20             :   static InputParameters validParams();
      21             : 
      22             :   TensorBuffer(const InputParameters & parameters);
      23             : 
      24             :   virtual std::size_t advanceState() override;
      25             :   virtual void clearStates() override;
      26             : 
      27             :   T & getTensor();
      28             :   const std::vector<T> & getOldTensor(std::size_t states_requested);
      29             : 
      30             :   virtual const torch::Tensor & getRawTensor() const override;
      31             :   virtual const torch::Tensor & getRawCPUTensor() override;
      32             : 
      33             : protected:
      34             :   /// current state of the tensor
      35             :   T _u;
      36             : 
      37             :   /// potential CPU copy of the tensor (if requested)
      38             :   T _u_cpu;
      39             : 
      40             :   /// was a CPU copy requested?
      41             :   bool _cpu_copy_requested;
      42             : 
      43             :   /// old states of the tensor
      44             :   std::vector<T> _u_old;
      45             :   std::size_t _max_states;
      46             : };
      47             : 
      48             : template <typename T>
      49             : InputParameters
      50           0 : TensorBuffer<T>::validParams()
      51             : {
      52        1214 :   InputParameters params = TensorBufferBase::validParams();
      53           0 :   return params;
      54             : }
      55             : 
      56             : template <typename T>
      57         874 : TensorBuffer<T>::TensorBuffer(const InputParameters & parameters)
      58         874 :   : TensorBufferBase(parameters), _cpu_copy_requested(false), _max_states(0)
      59             : {
      60         874 : }
      61             : 
      62             : template <typename T>
      63             : std::size_t
      64      826944 : TensorBuffer<T>::advanceState()
      65             : {
      66             :   // make room to push state one step further back
      67      826944 :   if (_u_old.size() < _max_states)
      68         134 :     _u_old.resize(_u_old.size() + 1);
      69             : 
      70             :   // push state further back
      71      826944 :   if (!_u_old.empty())
      72             :   {
      73      166154 :     for (std::size_t i = _u_old.size() - 1; i > 0; --i)
      74       81560 :       _u_old[i] = _u_old[i - 1];
      75             :     _u_old[0] = _u;
      76             :   }
      77             : 
      78      826944 :   return _u_old.size();
      79             : }
      80             : 
      81             : template <typename T>
      82             : void
      83           0 : TensorBuffer<T>::clearStates()
      84             : {
      85           0 :   _u_old.clear();
      86           0 : }
      87             : 
      88             : template <typename T>
      89             : const torch::Tensor &
      90         588 : TensorBuffer<T>::getRawTensor() const
      91             : {
      92         588 :   return _u;
      93             : }
      94             : 
      95             : template <typename T>
      96             : const torch::Tensor &
      97        1402 : TensorBuffer<T>::getRawCPUTensor()
      98             : {
      99        1402 :   _cpu_copy_requested = true;
     100        1402 :   return _u_cpu;
     101             : }
     102             : 
     103             : template <typename T>
     104             : T &
     105           0 : TensorBuffer<T>::getTensor()
     106             : {
     107        1422 :   return _u;
     108             : }
     109             : 
     110             : template <typename T>
     111             : const std::vector<T> &
     112           0 : TensorBuffer<T>::getOldTensor(std::size_t states_requested)
     113             : {
     114         130 :   _max_states = std::max(_max_states, states_requested);
     115         130 :   return _u_old;
     116             : }
     117             : 
     118             : /**
     119             :  * Specialization of this helper struct can be used to force the use of derived
     120             :  * classes for implicit TensorBuffer construction (i.e. tensors that are not explicitly
     121             :  * listed under [TensorBuffers]).
     122             :  */
     123             : template <typename T>
     124             : struct TensorBufferSpecialization;
     125             : 
     126             : #define registerTensorType(derived_class, tensor_type)                                             \
     127             :   template <>                                                                                      \
     128             :   struct TensorBufferSpecialization<tensor_type>                                                   \
     129             :   {                                                                                                \
     130             :     using type = derived_class;                                                                    \
     131             :   }

Generated by: LCOV version 1.14