LCOV - code coverage report
Current view: top level - include/actions - DomainAction.h (source / functions) Hit Total Coverage
Test: idaholab/swift: #92 (25e020) with base b3cd84 Lines: 7 24 29.2 %
Date: 2025-09-10 17:10:32 Functions: 0 1 0.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /**********************************************************************/
       2             : /*                    DO NOT MODIFY THIS HEADER                       */
       3             : /*             Swift, a Fourier spectral solver for MOOSE             */
       4             : /*                                                                    */
       5             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       6             : /*                        ALL RIGHTS RESERVED                         */
       7             : /**********************************************************************/
       8             : 
       9             : #pragma once
      10             : 
      11             : #include "Action.h"
      12             : #include <vector>
      13             : #include <string>
      14             : #include <array>
      15             : 
      16             : #include <torch/torch.h>
      17             : 
      18             : /**
      19             :  * The DomainAction manages the [Domain] syntax and simulation domain parameters.
      20             :  */
      21             : class DomainAction : public Action
      22             : {
      23             : public:
      24             :   static InputParameters validParams();
      25             : 
      26             :   DomainAction(const InputParameters & parameters);
      27             : 
      28             :   virtual void act() override;
      29             : 
      30        1486 :   const unsigned int & getDim() const { return _dim; }
      31         152 :   const std::array<int64_t, 3> & getGridSize() const { return _n_global; }
      32             :   const std::array<int64_t, 3> & getReciprocalGridSize() const { return _n_reciprocal_global; }
      33             :   const std::array<int64_t, 3> & getLocalGridSize() const { return _n_local; }
      34             :   const std::array<int64_t, 3> & getLocalReciprocalGridSize() const { return _n_reciprocal_local; }
      35           4 :   const Real & getVolume() const { return _volume_global; }
      36             :   const torch::IntArrayRef & getDimIndices() const { return _domain_dimensions; }
      37             :   const RealVectorValue & getDomainMin() const { return _min_global; }
      38             :   const RealVectorValue & getDomainMax() const { return _max_global; }
      39         140 :   const RealVectorValue & getGridSpacing() const { return _grid_spacing; }
      40             :   const torch::Tensor & getAxis(std::size_t component) const;
      41             :   const torch::Tensor & getReciprocalAxis(std::size_t component) const;
      42             : 
      43             :   int64_t getNumberOfCells() const;
      44             : 
      45             :   /// return X-vector (coordinate) tensor for the local real space domain
      46             :   const torch::Tensor & getXGrid() const;
      47             : 
      48             :   /// return k-vector tensor for the local reciprocal domain
      49             :   const torch::Tensor & getKGrid() const;
      50             : 
      51             :   /// return k-square tensor for the local reciprocal domain
      52             :   const torch::Tensor & getKSquare() const;
      53             : 
      54             :   /// get the maximum spatial frequency
      55             :   const RealVectorValue & getMaxK() const { return _max_k; }
      56             : 
      57             :   /// get the shape of the local domain
      58        1184 :   const torch::IntArrayRef & getShape() const { return _shape; }
      59          84 :   const torch::IntArrayRef & getReciprocalShape() const { return _reciprocal_shape; }
      60             : 
      61             :   std::vector<int64_t> getValueShape(std::vector<int64_t> extra_dims) const;
      62             :   std::vector<int64_t> getReciprocalValueShape(std::initializer_list<int64_t> extra_dims) const;
      63             : 
      64             :   torch::Tensor fft(const torch::Tensor & t) const;
      65             :   torch::Tensor ifft(const torch::Tensor & t) const;
      66             : 
      67             :   /// compute the sum of a tensor, reduced over the spatial dimensions
      68             :   torch::Tensor sum(const torch::Tensor & t) const;
      69             :   /// compute the average of a tensor, reduced over the spatial dimensions
      70             :   torch::Tensor average(const torch::Tensor & t) const;
      71             : 
      72             :   /// align a 1d tensor in a specific dimension
      73             :   torch::Tensor align(torch::Tensor t, unsigned int dim) const;
      74             : 
      75             :   /// check if debugging is enabled
      76      516848 :   bool debug() const { return _debug; }
      77             : 
      78             : protected:
      79             :   void gridChanged();
      80             : 
      81             :   void partitionSerial();
      82             :   void partitionSlabs();
      83             :   void partitionPencils();
      84             : 
      85             :   torch::Tensor fftSerial(const torch::Tensor & t) const;
      86             :   torch::Tensor fftSlab(const torch::Tensor & t) const;
      87             :   torch::Tensor fftPencil(const torch::Tensor & t) const;
      88             : 
      89             :   template <bool is_real>
      90             :   torch::Tensor cosineTransform(const torch::Tensor & t, int64_t axis) const;
      91             : 
      92             :   template <typename T>
      93             :   std::vector<int64_t> partitionHepler(int64_t total, const std::vector<T> & weights);
      94             : 
      95             :   void updateXGrid() const;
      96             :   void updateKGrid() const;
      97             :   void updateKSquare() const;
      98             : 
      99             :   /// device names to be used on the nodes
     100             :   const std::vector<std::string> _device_names;
     101             : 
     102             :   /// device weights to be used on the nodes
     103             :   std::vector<unsigned int> _device_weights;
     104             : 
     105             :   /// device floating precision
     106             :   enum class FloatingPrecision
     107             :   {
     108             :     DEVICE_DEFAULT,
     109             :     SINGLE,
     110             :     DOUBLE
     111             :   } _floating_precision;
     112             : 
     113             :   /// parallelization mode
     114             :   const enum class ParallelMode { NONE, FFT_SLAB, FFT_PENCIL } _parallel_mode;
     115             : 
     116             :   /// host local ranks of all procs
     117             :   std::vector<unsigned int> _local_ranks;
     118             :   std::vector<unsigned int> _local_weights;
     119             : 
     120             :   /// The dimension of the mesh
     121             :   const unsigned int _dim;
     122             : 
     123             :   /// global number of grid points in real space
     124             :   const std::array<int64_t, 3> _n_global;
     125             : 
     126             :   /// global number of grid points in real space
     127             :   std::array<int64_t, 3> _n_reciprocal_global;
     128             : 
     129             :   /// local number of grid points in real space
     130             :   std::array<int64_t, 3> _n_local;
     131             : 
     132             :   /// local number of grid points in real space
     133             :   std::array<int64_t, 3> _n_reciprocal_local;
     134             : 
     135             :   /// local begin/end indixes along each direction for slabs/pencils
     136             :   std::array<std::vector<int64_t>, 3> _local_begin;
     137             :   std::array<std::vector<int64_t>, 3> _local_end;
     138             :   std::array<std::vector<int64_t>, 3> _n_local_all;
     139             : 
     140             :   ///@{ global domain length in each dimension
     141             :   const RealVectorValue _min_global;
     142             :   const RealVectorValue _max_global;
     143             :   ///@}
     144             : 
     145             :   /// Volume of the simulation domain in real space
     146             :   Real _volume_global;
     147             : 
     148             :   const enum class MeshMode { SWIFT_DUMMY, SWIFT_DOMAIN, SWIFT_MANUAL } _mesh_mode;
     149             : 
     150             :   /// grid spacing
     151             :   RealVectorValue _grid_spacing;
     152             : 
     153             :   /// real space axes
     154             :   std::array<torch::Tensor, 3> _global_axis;
     155             :   std::array<torch::Tensor, 3> _local_axis;
     156             : 
     157             :   /// reciprocal space axes
     158             :   std::array<torch::Tensor, 3> _global_reciprocal_axis;
     159             :   std::array<torch::Tensor, 3> _local_reciprocal_axis;
     160             : 
     161             :   /// X-grid (cordinate vectors - built only if requested)
     162             :   mutable torch::Tensor _x_grid;
     163             : 
     164             :   /// k-grid (built only if requested)
     165             :   mutable torch::Tensor _k_grid;
     166             : 
     167             :   /// k-square (built only if requested)
     168             :   mutable torch::Tensor _k_square;
     169             : 
     170             :   /// largest frequency along each axis
     171             :   RealVectorValue _max_k;
     172             : 
     173             :   /// domain shape
     174             :   torch::IntArrayRef _shape;
     175             :   torch::IntArrayRef _reciprocal_shape;
     176             : 
     177             :   /// domain dimensions ({0},{0,1},or {0,1,2})
     178             :   const std::array<int64_t, 3> _domain_dimensions_buffer;
     179             :   const torch::IntArrayRef _domain_dimensions;
     180             : 
     181             :   /// MPI rank
     182             :   unsigned int _rank;
     183             : 
     184             :   /// number of MPI ranks
     185             :   unsigned int _n_rank;
     186             : 
     187             :   /// send tensors
     188             :   mutable std::vector<torch::Tensor> _send_tensor;
     189             :   /// receive buffer
     190             :   mutable std::vector<std::vector<double>> _recv_data;
     191             :   /// receive tensors
     192             :   mutable std::vector<torch::Tensor> _recv_tensor;
     193             : 
     194             :   /// enable debugging
     195             :   const bool _debug;
     196             : };
     197             : 
     198             : template <typename T>
     199             : std::vector<int64_t>
     200           0 : DomainAction::partitionHepler(int64_t total, const std::vector<T> & weights)
     201             : {
     202             :   std::vector<int64_t> ns;
     203             : 
     204           0 :   T remaining_total_weight = 0;
     205           0 :   for (const auto w : weights)
     206           0 :     remaining_total_weight += w;
     207             : 
     208           0 :   for (const auto w : weights)
     209             :   {
     210           0 :     if (remaining_total_weight == 0)
     211           0 :       mooseError("Internal partitioning error. remaining_total_weight ",
     212             :                  remaining_total_weight,
     213             :                  " == 0 ",
     214           0 :                  _rank);
     215             : 
     216             :     // assign at least one layer
     217           0 :     const auto n = std::max((total * w) / remaining_total_weight, int64_t(1));
     218           0 :     ns.push_back(n);
     219             : 
     220           0 :     remaining_total_weight -= w;
     221             : 
     222           0 :     if (total < n)
     223           0 :       mooseError("Internal partitioning error.");
     224             : 
     225           0 :     total -= n;
     226             :   }
     227             : 
     228             :   // add remainsder to last slice
     229           0 :   ns.back() += total;
     230           0 :   return ns;
     231           0 : }
     232             : 
     233             : // See Makhoul 2003 (DOI: 10.1109/TASSP.1980.1163351)
     234             : template <bool is_real>
     235             : torch::Tensor
     236             : DomainAction::cosineTransform(const torch::Tensor & t, int64_t axis) const
     237             : {
     238             :   // size along the axis
     239             :   // const auto l = t.sizes()[axis];
     240             : 
     241             :   // mirror tensor and stack onto itself (with one layer removed)
     242             :   auto t_flip = torch::flip(t, {axis});
     243             : 
     244             :   // stack tensor along axis
     245             :   auto t_stacked = torch::stack({t, t_flip}, axis);
     246             : 
     247             :   // perform 1D FFT along the selected axis and slice in the reciprocal domain
     248             :   torch::Tensor t_bar;
     249             :   if constexpr (is_real)
     250             :     t_bar = torch::fft::rfft(t_stacked, -1, axis);
     251             :   else
     252             :     t_bar = torch::fft::fft(t_stacked, -1, axis);
     253             : 
     254             :   mooseError("Not implemented!");
     255             :   // return t_bar;
     256             : }

Generated by: LCOV version 1.14