LCOV - code coverage report
Current view: top level - include/tensor_computes - TensorOperatorBase.h (source / functions) Hit Total Coverage
Test: idaholab/swift: #92 (25e020) with base b3cd84 Lines: 9 16 56.2 %
Date: 2025-09-10 17:10:32 Functions: 6 10 60.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /**********************************************************************/
       2             : /*                    DO NOT MODIFY THIS HEADER                       */
       3             : /*             Swift, a Fourier spectral solver for MOOSE             */
       4             : /*                                                                    */
       5             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       6             : /*                        ALL RIGHTS RESERVED                         */
       7             : /**********************************************************************/
       8             : 
       9             : #pragma once
      10             : 
      11             : #include "MooseObject.h"
      12             : #include "SwiftTypes.h"
      13             : #include "TensorProblem.h"
      14             : #include "DependencyResolverInterface.h"
      15             : #include "SwiftConstantInterface.h"
      16             : 
      17             : #include <torch/torch.h>
      18             : 
      19             : class DomainAction;
      20             : 
      21             : /**
      22             :  * TensorOperatorBase object
      23             :  */
      24             : class TensorOperatorBase : public MooseObject,
      25             :                            public DependencyResolverInterface,
      26             :                            public SwiftConstantInterface
      27             : {
      28             : public:
      29             :   static InputParameters validParams();
      30             : 
      31             :   TensorOperatorBase(const InputParameters & parameters);
      32             : 
      33        1344 :   virtual const std::set<std::string> & getRequestedItems() override { return _requested_buffers; }
      34        2372 :   virtual const std::set<std::string> & getSuppliedItems() override { return _supplied_buffers; }
      35             : 
      36             :   /// Helper to recursively update dependencies for grouped operators
      37           0 :   virtual void updateDependencies() {}
      38             : 
      39             :   /// perform the computation
      40             :   virtual void computeBuffer() = 0;
      41             : 
      42             :   /// called  after all objects have been constructed (before dependency resolution)
      43         900 :   virtual void init() {}
      44             : 
      45             :   /// called  after all objects have been constructed (after dependency resolution)
      46         526 :   virtual void check() {}
      47             : 
      48             :   /// called if the simulation cell dimensions change
      49           0 :   virtual void gridChanged() {}
      50             : 
      51             : protected:
      52             :   template <typename T = torch::Tensor>
      53             :   const T & getInputBuffer(const std::string & param);
      54             : 
      55             :   template <typename T = torch::Tensor>
      56             :   const T & getInputBufferByName(const TensorInputBufferName & buffer_name);
      57             : 
      58             :   template <typename T = torch::Tensor>
      59             :   T & getOutputBuffer(const std::string & param);
      60             : 
      61             :   template <typename T = torch::Tensor>
      62             :   T & getOutputBufferByName(const TensorOutputBufferName & buffer_name);
      63             : 
      64             :   TensorOperatorBase & getCompute(const std::string & param_name);
      65             : 
      66             :   std::set<std::string> _requested_buffers;
      67             :   std::set<std::string> _supplied_buffers;
      68             : 
      69             :   TensorProblem & _tensor_problem;
      70             :   const DomainAction & _domain;
      71             : 
      72             :   /// axes
      73             :   const torch::Tensor &_x, &_y, &_z;
      74             : 
      75             :   /// reciprocal axes
      76             :   const torch::Tensor &_i, &_j, &_k;
      77             : 
      78             :   /// Imaginary unit i
      79             :   const torch::Tensor _imaginary;
      80             : 
      81             :   /// substep time
      82             :   const Real & _time;
      83             : 
      84             :   /// problem dimension
      85             :   const unsigned int & _dim;
      86             : };
      87             : 
      88             : template <typename T>
      89             : const T &
      90         332 : TensorOperatorBase::getInputBuffer(const std::string & param)
      91             : {
      92         332 :   return getInputBufferByName<T>(getParam<TensorInputBufferName>(param));
      93             : }
      94             : 
      95             : template <typename T>
      96             : const T &
      97        1140 : TensorOperatorBase::getInputBufferByName(const TensorInputBufferName & buffer_name)
      98             : {
      99        1140 :   _requested_buffers.insert(buffer_name);
     100        1140 :   return _tensor_problem.getBuffer<T>(buffer_name);
     101             : }
     102             : 
     103             : template <typename T>
     104             : T &
     105           0 : TensorOperatorBase::getOutputBuffer(const std::string & param)
     106             : {
     107           0 :   return getOutputBufferByName<T>(getParam<TensorOutputBufferName>(param));
     108             : }
     109             : 
     110             : template <typename T>
     111             : T &
     112           0 : TensorOperatorBase::getOutputBufferByName(const TensorOutputBufferName & buffer_name)
     113             : {
     114           0 :   _supplied_buffers.insert(buffer_name);
     115           0 :   return _tensor_problem.getBuffer<T>(buffer_name);
     116             : }

Generated by: LCOV version 1.14