Line data Source code
1 : /**********************************************************************/
2 : /* DO NOT MODIFY THIS HEADER */
3 : /* Swift, a Fourier spectral solver for MOOSE */
4 : /* */
5 : /* Copyright 2024 Battelle Energy Alliance, LLC */
6 : /* ALL RIGHTS RESERVED */
7 : /**********************************************************************/
8 :
9 : #pragma once
10 :
11 : #include "Action.h"
12 : #include <vector>
13 : #include <string>
14 : #include <array>
15 :
16 : #include <torch/torch.h>
17 :
/**
 * The DomainAction manages the [Domain] syntax and simulation domain parameters.
 *
 * It holds the global and local (per-rank) grid sizes in real and reciprocal space, the domain
 * bounds, the coordinate axes, and the partitioning data used by the serial / slab / pencil
 * FFT code paths.
 */
class DomainAction : public Action
{
public:
  static InputParameters validParams();

  DomainAction(const InputParameters & parameters);

  virtual void act() override;

  /// dimension of the mesh
  const unsigned int & getDim() const { return _dim; }
  /// global number of grid points in real space, per direction
  const std::array<int64_t, 3> & getGridSize() const { return _n_global; }
  /// global number of grid points in reciprocal space, per direction
  const std::array<int64_t, 3> & getReciprocalGridSize() const { return _n_reciprocal_global; }
  /// number of grid points in real space owned by this rank, per direction
  const std::array<int64_t, 3> & getLocalGridSize() const { return _n_local; }
  /// number of grid points in reciprocal space owned by this rank, per direction
  const std::array<int64_t, 3> & getLocalReciprocalGridSize() const { return _n_reciprocal_local; }
  /// volume of the global simulation domain in real space
  const Real & getVolume() const { return _volume_global; }
  /// indices of the active domain dimensions ({0}, {0,1}, or {0,1,2})
  const torch::IntArrayRef & getDimIndices() const { return _domain_dimensions; }
  /// lower corner of the global domain
  const RealVectorValue & getDomainMin() const { return _min_global; }
  /// upper corner of the global domain
  const RealVectorValue & getDomainMax() const { return _max_global; }
  /// grid spacing in each direction
  const RealVectorValue & getGridSpacing() const { return _grid_spacing; }
  /// real space axis tensor for the given component (presumably a 1D coordinate tensor - confirm)
  const torch::Tensor & getAxis(std::size_t component) const;
  /// reciprocal space axis tensor for the given component (presumably 1D - confirm)
  const torch::Tensor & getReciprocalAxis(std::size_t component) const;

  /// number of grid cells (see implementation for whether this is global or local)
  int64_t getNumberOfCells() const;

  /// return X-vector (coordinate) tensor for the local real space domain
  const torch::Tensor & getXGrid() const;

  /// return k-vector tensor for the local reciprocal domain
  const torch::Tensor & getKGrid() const;

  /// return k-square tensor for the local reciprocal domain
  const torch::Tensor & getKSquare() const;

  /// get the maximum spatial frequency
  const RealVectorValue & getMaxK() const { return _max_k; }

  ///@{ get the shape of the local domain in real / reciprocal space
  /// NOTE(review): torch::IntArrayRef is a non-owning view; the returned reference is only valid
  /// while the storage _shape / _reciprocal_shape point to is alive - confirm in the .cpp.
  const torch::IntArrayRef & getShape() const { return _shape; }
  const torch::IntArrayRef & getReciprocalShape() const { return _reciprocal_shape; }
  ///@}

  ///@{ shape of a (real / reciprocal space) value tensor with additional dimensions extra_dims
  /// (how extra_dims are combined with the grid shape is defined in the implementation)
  std::vector<int64_t> getValueShape(std::vector<int64_t> extra_dims) const;
  std::vector<int64_t> getReciprocalValueShape(std::initializer_list<int64_t> extra_dims) const;
  ///@}

  ///@{ forward / inverse FFT of a tensor on the local domain
  /// (presumably dispatching on _parallel_mode to fftSerial / fftSlab / fftPencil - confirm)
  torch::Tensor fft(const torch::Tensor & t) const;
  torch::Tensor ifft(const torch::Tensor & t) const;
  ///@}

  /// compute the sum of a tensor, reduced over the spatial dimensions
  torch::Tensor sum(const torch::Tensor & t) const;
  /// compute the average of a tensor, reduced over the spatial dimensions
  torch::Tensor average(const torch::Tensor & t) const;

  /// align a 1d tensor in a specific dimension
  torch::Tensor align(torch::Tensor t, unsigned int dim) const;

  /// check if debugging is enabled
  bool debug() const { return _debug; }

protected:
  /// update derived data after the grid changed (see implementation)
  void gridChanged();

  ///@{ set up the domain decomposition for the respective parallel mode
  void partitionSerial();
  void partitionSlabs();
  void partitionPencils();
  ///@}

  ///@{ FFT implementations for the respective parallel mode
  torch::Tensor fftSerial(const torch::Tensor & t) const;
  torch::Tensor fftSlab(const torch::Tensor & t) const;
  torch::Tensor fftPencil(const torch::Tensor & t) const;
  ///@}

  /// cosine transform of t along a single axis (currently unfinished: ends in mooseError)
  template <bool is_real>
  torch::Tensor cosineTransform(const torch::Tensor & t, int64_t axis) const;

  /// split `total` layers into one part per weight, proportional to the weights, with at least
  /// one layer per part (the spelling of the name is kept as-is for existing callers)
  template <typename T>
  std::vector<int64_t> partitionHepler(int64_t total, const std::vector<T> & weights);

  ///@{ (re)build the lazily constructed _x_grid / _k_grid / _k_square tensors
  /// (presumably called from the corresponding getters - confirm)
  void updateXGrid() const;
  void updateKGrid() const;
  void updateKSquare() const;
  ///@}

  /// device names to be used on the nodes
  const std::vector<std::string> _device_names;

  /// device weights to be used on the nodes
  std::vector<unsigned int> _device_weights;

  /// device floating precision
  enum class FloatingPrecision
  {
    DEVICE_DEFAULT,
    SINGLE,
    DOUBLE
  } _floating_precision;

  /// parallelization mode
  const enum class ParallelMode { NONE, FFT_SLAB, FFT_PENCIL } _parallel_mode;

  /// host local ranks of all procs
  std::vector<unsigned int> _local_ranks;
  /// presumably the device weights associated with _local_ranks - confirm in the .cpp
  std::vector<unsigned int> _local_weights;

  /// The dimension of the mesh
  const unsigned int _dim;

  /// global number of grid points in real space
  const std::array<int64_t, 3> _n_global;

  /// global number of grid points in reciprocal space
  std::array<int64_t, 3> _n_reciprocal_global;

  /// local number of grid points in real space
  std::array<int64_t, 3> _n_local;

  /// local number of grid points in reciprocal space
  std::array<int64_t, 3> _n_reciprocal_local;

  /// local begin/end indices along each direction for slabs/pencils
  std::array<std::vector<int64_t>, 3> _local_begin;
  std::array<std::vector<int64_t>, 3> _local_end;
  /// presumably the local grid point counts of all ranks along each direction - confirm
  std::array<std::vector<int64_t>, 3> _n_local_all;

  ///@{ global domain bounds (lower / upper corner) in each dimension
  const RealVectorValue _min_global;
  const RealVectorValue _max_global;
  ///@}

  /// Volume of the simulation domain in real space
  Real _volume_global;

  /// how the mesh for the domain is set up
  const enum class MeshMode { SWIFT_DUMMY, SWIFT_DOMAIN, SWIFT_MANUAL } _mesh_mode;

  /// grid spacing
  RealVectorValue _grid_spacing;

  /// real space axes
  std::array<torch::Tensor, 3> _global_axis;
  std::array<torch::Tensor, 3> _local_axis;

  /// reciprocal space axes
  std::array<torch::Tensor, 3> _global_reciprocal_axis;
  std::array<torch::Tensor, 3> _local_reciprocal_axis;

  /// X-grid (coordinate vectors - built only if requested)
  mutable torch::Tensor _x_grid;

  /// k-grid (built only if requested)
  mutable torch::Tensor _k_grid;

  /// k-square (built only if requested)
  mutable torch::Tensor _k_square;

  /// largest frequency along each axis
  RealVectorValue _max_k;

  /// domain shape
  /// NOTE(review): non-owning views (torch::IntArrayRef) - the referenced storage must outlive
  /// this object; confirm what they point to in the .cpp.
  torch::IntArrayRef _shape;
  torch::IntArrayRef _reciprocal_shape;

  /// domain dimensions ({0},{0,1},or {0,1,2})
  /// The buffer is declared before the view so it is initialized first; presumably
  /// _domain_dimensions is a view into _domain_dimensions_buffer - confirm in the constructor.
  const std::array<int64_t, 3> _domain_dimensions_buffer;
  const torch::IntArrayRef _domain_dimensions;

  /// MPI rank
  unsigned int _rank;

  /// number of MPI ranks
  unsigned int _n_rank;

  /// send tensors (presumably used by the parallel FFT data exchange - confirm)
  mutable std::vector<torch::Tensor> _send_tensor;
  /// receive buffer
  mutable std::vector<std::vector<double>> _recv_data;
  /// receive tensors
  mutable std::vector<torch::Tensor> _recv_tensor;

  /// enable debugging
  const bool _debug;
};
197 :
198 : template <typename T>
199 : std::vector<int64_t>
200 0 : DomainAction::partitionHepler(int64_t total, const std::vector<T> & weights)
201 : {
202 : std::vector<int64_t> ns;
203 :
204 0 : T remaining_total_weight = 0;
205 0 : for (const auto w : weights)
206 0 : remaining_total_weight += w;
207 :
208 0 : for (const auto w : weights)
209 : {
210 0 : if (remaining_total_weight == 0)
211 0 : mooseError("Internal partitioning error. remaining_total_weight ",
212 : remaining_total_weight,
213 : " == 0 ",
214 0 : _rank);
215 :
216 : // assign at least one layer
217 0 : const auto n = std::max((total * w) / remaining_total_weight, int64_t(1));
218 0 : ns.push_back(n);
219 :
220 0 : remaining_total_weight -= w;
221 :
222 0 : if (total < n)
223 0 : mooseError("Internal partitioning error.");
224 :
225 0 : total -= n;
226 : }
227 :
228 : // add remainsder to last slice
229 0 : ns.back() += total;
230 0 : return ns;
231 0 : }
232 :
233 : // See Makhoul 2003 (DOI: 10.1109/TASSP.1980.1163351)
234 : template <bool is_real>
235 : torch::Tensor
236 : DomainAction::cosineTransform(const torch::Tensor & t, int64_t axis) const
237 : {
238 : // size along the axis
239 : // const auto l = t.sizes()[axis];
240 :
241 : // mirror tensor and stack onto itself (with one layer removed)
242 : auto t_flip = torch::flip(t, {axis});
243 :
244 : // stack tensor along axis
245 : auto t_stacked = torch::stack({t, t_flip}, axis);
246 :
247 : // perform 1D FFT along the selected axis and slice in the reciprocal domain
248 : torch::Tensor t_bar;
249 : if constexpr (is_real)
250 : t_bar = torch::fft::rfft(t_stacked, -1, axis);
251 : else
252 : t_bar = torch::fft::fft(t_stacked, -1, axis);
253 :
254 : mooseError("Not implemented!");
255 : // return t_bar;
256 : }
|