Line data Source code
1 : /**********************************************************************/ 2 : /* DO NOT MODIFY THIS HEADER */ 3 : /* Swift, a Fourier spectral solver for MOOSE */ 4 : /* */ 5 : /* Copyright 2024 Battelle Energy Alliance, LLC */ 6 : /* ALL RIGHTS RESERVED */ 7 : /**********************************************************************/ 8 : 9 : #pragma once 10 : 11 : #include "MooseObject.h" 12 : #include "SwiftTypes.h" 13 : #include "TensorProblem.h" 14 : #include "DependencyResolverInterface.h" 15 : #include "SwiftConstantInterface.h" 16 : 17 : #include <torch/torch.h> 18 : 19 : class DomainAction; 20 : 21 : /** 22 : * TensorOperatorBase object 23 : */ 24 : class TensorOperatorBase : public MooseObject, 25 : public DependencyResolverInterface, 26 : public SwiftConstantInterface 27 : { 28 : public: 29 : static InputParameters validParams(); 30 : 31 : TensorOperatorBase(const InputParameters & parameters); 32 : 33 1344 : virtual const std::set<std::string> & getRequestedItems() override { return _requested_buffers; } 34 2372 : virtual const std::set<std::string> & getSuppliedItems() override { return _supplied_buffers; } 35 : 36 : /// Helper to recursively update dependencies for grouped operators 37 0 : virtual void updateDependencies() {} 38 : 39 : /// perform the computation 40 : virtual void computeBuffer() = 0; 41 : 42 : /// called after all objects have been constructed (before dependency resolution) 43 900 : virtual void init() {} 44 : 45 : /// called after all objects have been constructed (after dependency resolution) 46 526 : virtual void check() {} 47 : 48 : /// called if the simulation cell dimensions change 49 0 : virtual void gridChanged() {} 50 : 51 : protected: 52 : template <typename T = torch::Tensor> 53 : const T & getInputBuffer(const std::string & param); 54 : 55 : template <typename T = torch::Tensor> 56 : const T & getInputBufferByName(const TensorInputBufferName & buffer_name); 57 : 58 : template <typename T = torch::Tensor> 59 : T & getOutputBuffer(const std::string & param); 60 : 61 : template <typename T = torch::Tensor> 62 : T & getOutputBufferByName(const TensorOutputBufferName & buffer_name); 63 : 64 : TensorOperatorBase & getCompute(const std::string & param_name); 65 : 66 : std::set<std::string> _requested_buffers; 67 : std::set<std::string> _supplied_buffers; 68 : 69 : TensorProblem & _tensor_problem; 70 : const DomainAction & _domain; 71 : 72 : /// axes 73 : const torch::Tensor &_x, &_y, &_z; 74 : 75 : /// reciprocal axes 76 : const torch::Tensor &_i, &_j, &_k; 77 : 78 : /// Imaginary unit i 79 : const torch::Tensor _imaginary; 80 : 81 : /// substep time 82 : const Real & _time; 83 : 84 : /// problem dimension 85 : const unsigned int & _dim; 86 : }; 87 : 88 : template <typename T> 89 : const T & 90 332 : TensorOperatorBase::getInputBuffer(const std::string & param) 91 : { 92 332 : return getInputBufferByName<T>(getParam<TensorInputBufferName>(param)); 93 : } 94 : 95 : template <typename T> 96 : const T & 97 1140 : TensorOperatorBase::getInputBufferByName(const TensorInputBufferName & buffer_name) 98 : { 99 1140 : _requested_buffers.insert(buffer_name); 100 1140 : return _tensor_problem.getBuffer<T>(buffer_name); 101 : } 102 : 103 : template <typename T> 104 : T & 105 0 : TensorOperatorBase::getOutputBuffer(const std::string & param) 106 : { 107 0 : return getOutputBufferByName<T>(getParam<TensorOutputBufferName>(param)); 108 : } 109 : 110 : template <typename T> 111 : T & 112 0 : TensorOperatorBase::getOutputBufferByName(const TensorOutputBufferName & buffer_name) 113 : { 114 0 : _supplied_buffers.insert(buffer_name); 115 0 : return _tensor_problem.getBuffer<T>(buffer_name); 116 : }