Line data Source code
1 : /**********************************************************************/ 2 : /* DO NOT MODIFY THIS HEADER */ 3 : /* Swift, a Fourier spectral solver for MOOSE */ 4 : /* */ 5 : /* Copyright 2024 Battelle Energy Alliance, LLC */ 6 : /* ALL RIGHTS RESERVED */ 7 : /**********************************************************************/ 8 : 9 : #pragma once 10 : 11 : #include "TensorBufferBase.h" 12 : 13 : /** 14 : * Tensor wrapper arbitrary tensor value dimensions 15 : */ 16 : template <typename T> 17 : class TensorBuffer : public TensorBufferBase 18 : { 19 : public: 20 : static InputParameters validParams(); 21 : 22 : TensorBuffer(const InputParameters & parameters); 23 : 24 : virtual std::size_t advanceState() override; 25 : virtual void clearStates() override; 26 : 27 : T & getTensor(); 28 : const std::vector<T> & getOldTensor(std::size_t states_requested); 29 : 30 : virtual const torch::Tensor & getRawTensor() const override; 31 : virtual const torch::Tensor & getRawCPUTensor() override; 32 : 33 : protected: 34 : /// current state of the tensor 35 : T _u; 36 : 37 : /// potential CPU copy of the tensor (if requested) 38 : T _u_cpu; 39 : 40 : /// was a CPU copy requested? 41 : bool _cpu_copy_requested; 42 : 43 : /// old states of the tensor 44 : std::vector<T> _u_old; 45 : std::size_t _max_states; 46 : }; 47 : 48 : template <typename T> 49 : InputParameters 50 0 : TensorBuffer<T>::validParams() 51 : { 52 1214 : InputParameters params = TensorBufferBase::validParams(); 53 0 : return params; 54 : } 55 : 56 : template <typename T> 57 874 : TensorBuffer<T>::TensorBuffer(const InputParameters & parameters) 58 874 : : TensorBufferBase(parameters), _cpu_copy_requested(false), _max_states(0) 59 : { 60 874 : } 61 : 62 : template <typename T> 63 : std::size_t 64 826944 : TensorBuffer<T>::advanceState() 65 : { 66 : // make room to push state one step further back 67 826944 : if (_u_old.size() < _max_states) 68 134 : _u_old.resize(_u_old.size() + 1); 69 : 70 : // push state further back 71 826944 : if (!_u_old.empty()) 72 : { 73 166154 : for (std::size_t i = _u_old.size() - 1; i > 0; --i) 74 81560 : _u_old[i] = _u_old[i - 1]; 75 : _u_old[0] = _u; 76 : } 77 : 78 826944 : return _u_old.size(); 79 : } 80 : 81 : template <typename T> 82 : void 83 0 : TensorBuffer<T>::clearStates() 84 : { 85 0 : _u_old.clear(); 86 0 : } 87 : 88 : template <typename T> 89 : const torch::Tensor & 90 588 : TensorBuffer<T>::getRawTensor() const 91 : { 92 588 : return _u; 93 : } 94 : 95 : template <typename T> 96 : const torch::Tensor & 97 1402 : TensorBuffer<T>::getRawCPUTensor() 98 : { 99 1402 : _cpu_copy_requested = true; 100 1402 : return _u_cpu; 101 : } 102 : 103 : template <typename T> 104 : T & 105 0 : TensorBuffer<T>::getTensor() 106 : { 107 1422 : return _u; 108 : } 109 : 110 : template <typename T> 111 : const std::vector<T> & 112 0 : TensorBuffer<T>::getOldTensor(std::size_t states_requested) 113 : { 114 130 : _max_states = std::max(_max_states, states_requested); 115 130 : return _u_old; 116 : } 117 : 118 : /** 119 : * Specialization of this helper struct can be used to force the use of derived 120 : * classes for implicit TensorBuffer construction (i.e. tensors that are not explicitly 121 : * listed under [TensorBuffers]). 122 : */ 123 : template <typename T> 124 : struct TensorBufferSpecialization; 125 : 126 : #define registerTensorType(derived_class, tensor_type) \ 127 : template <> \ 128 : struct TensorBufferSpecialization<tensor_type> \ 129 : { \ 130 : using type = derived_class; \ 131 : }