Line data Source code
1 : /**********************************************************************/
2 : /* DO NOT MODIFY THIS HEADER */
3 : /* Swift, a Fourier spectral solver for MOOSE */
4 : /* */
5 : /* Copyright 2024 Battelle Energy Alliance, LLC */
6 : /* ALL RIGHTS RESERVED */
7 : /**********************************************************************/
8 :
9 : #include "DomainAction.h"
10 : #include "MooseError.h"
11 : #include "TensorProblem.h"
12 : #include "MooseEnum.h"
13 : #include "SetupMeshAction.h"
14 : #include "SwiftApp.h"
15 : #include "CreateProblemAction.h"
16 :
17 : #include <initializer_list>
18 : #include <util/Optional.h>
19 :
20 : // run this early, before any objects are constructed
21 : registerMooseAction("SwiftApp", DomainAction, "meta_action");
22 : registerMooseAction("SwiftApp", DomainAction, "add_mesh_generator");
23 : registerMooseAction("SwiftApp", DomainAction, "create_problem_custom");
24 :
25 : InputParameters
26 128 : DomainAction::validParams()
27 : {
28 128 : InputParameters params = Action::validParams();
29 128 : params.addClassDescription("Set up the domain and compute devices.");
30 :
31 256 : MooseEnum dims("1=1 2 3");
32 256 : params.addRequiredParam<MooseEnum>("dim", dims, "Problem dimension");
33 :
34 256 : MooseEnum parmode("NONE FFT_SLAB FFT_PENCIL", "NONE");
35 256 : parmode.addDocumentation("NONE", "Serial execution without domain decomposition.");
36 256 : parmode.addDocumentation("FFT_SLAB",
37 : "Slab decomposition with X-Z slabs stacked along the Y direction in "
38 : "real space and Y-Z slabs stacked along the X direction in Fourier "
39 : "space. This requires one all-to-all communication per FFT.");
40 256 : parmode.addDocumentation(
41 : "FFT_PENCIL",
42 : "Pencil decomposition (3D only). Three 1D FFTs in pencil arrays along the X, Y, and lastly Z "
43 : "direction. Thie requires two many-to-many communications per FFT.");
44 :
45 256 : params.addParam<MooseEnum>("parallel_mode", parmode, "Parallelization mode.");
46 :
47 256 : params.addParam<unsigned int>("nx", 1, "Number of elements in the X direction");
48 256 : params.addParam<unsigned int>("ny", 1, "Number of elements in the Y direction");
49 256 : params.addParam<unsigned int>("nz", 1, "Number of elements in the Z direction");
50 256 : params.addParam<Real>("xmax", 1.0, "Upper X Coordinate of the generated mesh");
51 256 : params.addParam<Real>("ymax", 1.0, "Upper Y Coordinate of the generated mesh");
52 256 : params.addParam<Real>("zmax", 1.0, "Upper Z Coordinate of the generated mesh");
53 256 : params.addParam<Real>("xmin", 0.0, "Lower X Coordinate of the generated mesh");
54 256 : params.addParam<Real>("ymin", 0.0, "Lower Y Coordinate of the generated mesh");
55 256 : params.addParam<Real>("zmin", 0.0, "Lower Z Coordinate of the generated mesh");
56 :
57 256 : MooseEnum meshmode("DUMMY DOMAIN MANUAL", "DUMMY");
58 256 : meshmode.addDocumentation("DUMMY",
59 : "Create a single element mesh the size of the simulation domain");
60 256 : meshmode.addDocumentation("DOMAIN", "Create a mesh with one element per grid cell");
61 256 : meshmode.addDocumentation("MANUAL",
62 : "Do not auto-generate a mesh. User must add a Mesh block themselves.");
63 :
64 256 : params.addParam<MooseEnum>("mesh_mode", meshmode, "Mesh generation mode.");
65 :
66 256 : params.addParam<std::vector<std::string>>("device_names", {}, "Compute devices to run on.");
67 256 : params.addParam<std::vector<unsigned int>>(
68 : "device_weights", {}, "Device weights (or speeds) to influence the partitioning.");
69 :
70 256 : MooseEnum floatingPrecision("DEVICE_DEFAULT SINGLE DOUBLE", "DEVICE_DEFAULT");
71 256 : params.addParam<MooseEnum>("floating_precision", floatingPrecision, "Floating point precision.");
72 :
73 256 : params.addParam<bool>(
74 : "debug",
75 256 : false,
76 : "Enable additional debugging and diagnostics, such a checking for initialized tensors.");
77 128 : return params;
78 128 : }
79 :
80 128 : DomainAction::DomainAction(const InputParameters & parameters)
81 : : Action(parameters),
82 128 : _device_names(getParam<std::vector<std::string>>("device_names")),
83 256 : _device_weights(getParam<std::vector<unsigned int>>("device_weights")),
84 256 : _floating_precision(getParam<MooseEnum>("floating_precision").getEnum<FloatingPrecision>()),
85 256 : _parallel_mode(getParam<MooseEnum>("parallel_mode").getEnum<ParallelMode>()),
86 256 : _dim(getParam<MooseEnum>("dim")),
87 384 : _n_global(
88 512 : {getParam<unsigned int>("nx"), getParam<unsigned int>("ny"), getParam<unsigned int>("nz")}),
89 512 : _min_global({getParam<Real>("xmin"), getParam<Real>("ymin"), getParam<Real>("zmin")}),
90 512 : _max_global({getParam<Real>("xmax"), getParam<Real>("ymax"), getParam<Real>("zmax")}),
91 256 : _mesh_mode(getParam<MooseEnum>("mesh_mode").getEnum<MeshMode>()),
92 128 : _shape(torch::IntArrayRef(_n_local.data(), _dim)),
93 : _reciprocal_shape(torch::IntArrayRef(_n_reciprocal_local.data(), _dim)),
94 128 : _domain_dimensions_buffer({0, 1, 2}),
95 : _domain_dimensions(torch::IntArrayRef(_domain_dimensions_buffer.data(), _dim)),
96 128 : _rank(_communicator.rank()),
97 128 : _n_rank(_communicator.size()),
98 128 : _send_tensor(_n_rank),
99 128 : _recv_tensor(_n_rank),
100 384 : _debug(getParam<bool>("debug"))
101 : {
102 128 : if (_parallel_mode == ParallelMode::NONE && comm().size() > 1)
103 0 : paramError("parallel_mode", "NONE requires the application to run in serial.");
104 :
105 128 : if (_device_names.empty())
106 : {
107 108 : if (comm().size() > 1)
108 0 : mooseError("Specify Domain/device_names for parallel operation.");
109 :
110 : // set local weights and ranks for serial
111 108 : _local_ranks = {0};
112 108 : _local_weights = {1};
113 : }
114 : else
115 : {
116 : // process weights
117 20 : if (_device_weights.empty())
118 20 : _device_weights.assign(1, _device_names.size());
119 :
120 20 : if (_device_weights.size() != _device_names.size())
121 0 : mooseError("Specify one weight per device or none at all");
122 :
123 : // determine the processor name
124 : char name[MPI_MAX_PROCESSOR_NAME + 1];
125 : int len;
126 20 : MPI_Get_processor_name(name, &len);
127 20 : name[len] = 0;
128 :
129 : // gather all processor names
130 : std::vector<std::string> host_names;
131 40 : _communicator.allgather(std::string(name), host_names);
132 :
133 : // get the local rank on the current processor (used for compute device assignment)
134 : std::map<std::string, unsigned int> host_rank_count;
135 :
136 40 : for (const auto & host_name : host_names)
137 : {
138 40 : if (host_rank_count.find(name) == host_rank_count.end())
139 20 : host_rank_count[host_name] = 0;
140 :
141 20 : auto & local_rank = host_rank_count[host_name];
142 20 : _local_ranks.push_back(local_rank);
143 20 : _local_weights.push_back(_device_weights[local_rank % _device_weights.size()]);
144 :
145 : // std::cout << "Process on " << host_name << ' ' << local_rank << ' '
146 : // << _device_weights[local_rank % _device_weights.size()] << '\n';
147 :
148 20 : local_rank++;
149 : }
150 :
151 : // for (const auto i : index_range(host_names))
152 : // std::cout << host_names[i] << '\t' << _local_ranks[i] << '\n';
153 :
154 : // pick a compute device for a list of available devices
155 20 : auto swift_app = dynamic_cast<SwiftApp *>(&_app);
156 20 : if (!swift_app)
157 0 : mooseError("This action requires a SwftApp object to be present.");
158 20 : swift_app->setTorchDevice(_device_names[_local_ranks[_rank] % _device_names.size()], {});
159 :
160 20 : switch (_floating_precision)
161 : {
162 20 : case FloatingPrecision::DEVICE_DEFAULT:
163 : {
164 20 : swift_app->setTorchPrecision("DEVICE_DEFAULT", {});
165 20 : break;
166 : }
167 0 : case FloatingPrecision::DOUBLE:
168 : {
169 0 : swift_app->setTorchPrecision("DOUBLE", {});
170 0 : break;
171 : }
172 0 : case FloatingPrecision::SINGLE:
173 : {
174 0 : swift_app->setTorchPrecision("SINGLE", {});
175 0 : break;
176 : }
177 0 : default:
178 0 : mooseError("Invalid floating precision.");
179 : };
180 20 : }
181 :
182 : // domain partitioning
183 128 : gridChanged();
184 128 : }
185 :
186 : void
187 128 : DomainAction::gridChanged()
188 : {
189 128 : auto options = MooseTensor::floatTensorOptions();
190 :
191 : // build real space axes
192 128 : _volume_global = 1.0;
193 512 : for (const unsigned int dim : {0, 1, 2})
194 : {
195 : // error check
196 384 : if (_max_global(dim) <= _min_global(dim))
197 0 : mooseError("Max coordinate must be larger than the min coordinate in every dimension");
198 :
199 : // get grid geometry
200 384 : _grid_spacing(dim) = (_max_global(dim) - _min_global(dim)) / _n_global[dim];
201 :
202 : // real space axis
203 384 : if (dim < _dim)
204 : {
205 : _global_axis[dim] =
206 532 : align(torch::linspace(c10::Scalar(_min_global(dim) + _grid_spacing(dim) / 2.0),
207 266 : c10::Scalar(_max_global(dim) - _grid_spacing(dim) / 2.0),
208 : _n_global[dim],
209 : options),
210 : dim);
211 266 : _volume_global *= _max_global(dim) - _min_global(dim);
212 : }
213 : else
214 354 : _global_axis[dim] = torch::tensor({0.0}, options);
215 : }
216 :
217 : // build reciprocal space axes
218 512 : for (const unsigned int dim : {0, 1, 2})
219 : {
220 384 : if (dim < _dim)
221 : {
222 266 : const auto freq = (dim == _dim - 1)
223 266 : ? torch::fft::rfftfreq(_n_global[dim], _grid_spacing(dim), options)
224 266 : : torch::fft::fftfreq(_n_global[dim], _grid_spacing(dim), options);
225 :
226 : // zero out nyquist frequency
227 : // if (_n_global[dim] % 2 == 0)
228 : // freq[_n_global[dim] / 2] = 0.0;
229 :
230 532 : _global_reciprocal_axis[dim] = align(freq * 2.0 * libMesh::pi, dim);
231 : }
232 : else
233 354 : _global_reciprocal_axis[dim] = torch::tensor({0.0}, options);
234 :
235 : // compute max frequency along each axis
236 384 : _max_k(dim) = libMesh::pi / _grid_spacing(dim);
237 :
238 : // get global reciprocal axis size
239 384 : _n_reciprocal_global[dim] = _global_reciprocal_axis[dim].sizes()[dim];
240 : }
241 :
242 128 : switch (_parallel_mode)
243 : {
244 128 : case ParallelMode::NONE:
245 128 : partitionSerial();
246 : break;
247 :
248 0 : case ParallelMode::FFT_SLAB:
249 0 : partitionSlabs();
250 : break;
251 :
252 0 : case ParallelMode::FFT_PENCIL:
253 0 : partitionPencils();
254 : break;
255 : }
256 :
257 : // get local reciprocal axis size
258 512 : for (const auto dim : {0, 1, 2})
259 384 : _n_reciprocal_local[dim] = _local_reciprocal_axis[dim].sizes()[dim];
260 :
261 : // update on-demand grids
262 128 : if (_x_grid.defined())
263 0 : updateXGrid();
264 128 : if (_k_grid.defined())
265 0 : updateKGrid();
266 128 : if (_k_square.defined())
267 0 : updateKSquare();
268 128 : }
269 :
270 : void
271 128 : DomainAction::partitionSerial()
272 : {
273 : // goes along the full dimension for each rank
274 512 : for (const auto d : make_range(3u))
275 : {
276 384 : _local_begin[d].resize(_n_rank);
277 384 : _local_end[d].resize(_n_rank);
278 768 : for (const auto i : make_range(_communicator.size()))
279 : {
280 384 : _local_begin[d][i] = 0;
281 384 : _local_end[d][i] = _n_global[d];
282 : }
283 : }
284 :
285 : // to do, make those slices dependent on local begin/end
286 128 : _local_axis = _global_axis;
287 128 : _n_local = _n_global;
288 128 : _local_reciprocal_axis = _global_reciprocal_axis;
289 128 : }
290 :
291 : void
292 0 : DomainAction::partitionSlabs()
293 : {
294 0 : if (_dim < 2)
295 0 : paramError("dim", "Dimension must be 2 or 3 for slab decomposition.");
296 :
297 : // x is partitioned along a halved dimension due to the use of rfft
298 0 : _n_local_all[0] = partitionHepler(_global_reciprocal_axis[0].sizes()[0], _device_weights);
299 :
300 : // y is partitioned along the y realspace axis
301 0 : _n_local_all[1] = partitionHepler(_global_axis[1].sizes()[1], _device_weights);
302 :
303 : // set begin/end for x and y
304 0 : for (const auto d : {0, 1})
305 : {
306 : int64_t b = 0;
307 0 : for (const auto r : index_range(_n_local_all[d]))
308 : {
309 0 : _local_begin[d][r] = b;
310 0 : b += _n_local_all[d][r];
311 0 : _local_end[d][r] = b;
312 : }
313 : }
314 :
315 : // z is not partitioned at all
316 0 : _n_local_all[2].assign(_n_rank, _n_global[2]);
317 0 : _local_begin[2].assign(_n_rank, 0);
318 0 : _local_end[2].assign(_n_rank, _n_global[2]);
319 :
320 : // slice the real space into x-z slabs stacked in y direction
321 0 : _local_axis[0] = _global_axis[0].slice(0, 0, _n_global[0]);
322 0 : _local_axis[1] = _global_axis[1].slice(1, _local_begin[1][_rank], _local_end[1][_rank]);
323 0 : _n_local[0] = _n_global[0];
324 0 : _n_local[1] = _local_end[1][_rank] - _local_begin[1][_rank];
325 :
326 : // slice the reciprocal space into y-z slices stacked in x direction
327 : _local_reciprocal_axis[0] =
328 0 : _global_reciprocal_axis[0].slice(0, 0, _local_begin[0][_rank], _local_end[0][_rank]);
329 0 : _local_reciprocal_axis[1] = _global_reciprocal_axis[1].slice(1, 0, _n_reciprocal_global[1]);
330 :
331 0 : _n_local[2] = _n_global[2];
332 :
333 : // special casing this should not be neccessary
334 0 : if (_dim == 3)
335 : {
336 0 : _local_axis[2] = _global_axis[2].slice(2, 0, _n_global[2]);
337 0 : _local_reciprocal_axis[2] = _global_reciprocal_axis[2].slice(2, 0, _n_reciprocal_global[2]);
338 : }
339 : else
340 : {
341 : _local_axis[2] = _global_axis[2];
342 : _local_reciprocal_axis[2] = _global_reciprocal_axis[2];
343 : }
344 :
345 : // allocate receive buffer
346 0 : for (const auto i : make_range(_communicator.size()))
347 0 : if (i != _rank)
348 0 : _recv_data[i].resize(_n_local_all[0][_rank] * _n_local_all[1][i] * _n_local_all[2][i]);
349 0 : }
350 :
351 : void
352 0 : DomainAction::partitionPencils()
353 : {
354 0 : if (_dim < 3)
355 0 : paramError("dim", "Dimension must be 3 for pencil decomposition.");
356 0 : paramError("parallel_mode", "Not implemented yet!");
357 : }
358 :
359 : void
360 384 : DomainAction::act()
361 : {
362 384 : if (_current_task == "meta_action" && _mesh_mode != MeshMode::SWIFT_MANUAL)
363 : {
364 : // check if a SetupMesh action exists
365 128 : auto mesh_actions = _awh.getActions<SetupMeshAction>();
366 128 : if (mesh_actions.size() > 0)
367 0 : paramError("mesh_mode", "Do not specify a [Mesh] block unless mesh_mode is set to MANUAL");
368 :
369 : // otherwise create one
370 128 : auto & af = _app.getActionFactory();
371 128 : InputParameters action_params = af.getValidParams("SetupMeshAction");
372 : auto action = std::static_pointer_cast<MooseObjectAction>(
373 256 : af.create("SetupMeshAction", "Mesh", action_params));
374 384 : _app.actionWarehouse().addActionBlock(action);
375 128 : }
376 :
377 : // add a DomainMeshGenerator
378 384 : if (_current_task == "add_mesh_generator" && _mesh_mode != MeshMode::SWIFT_MANUAL)
379 : {
380 : // Don't do mesh generators when recovering or when the user has requested for us not to
381 128 : if ((_app.isRecovering() && _app.isUltimateMaster()) || _app.masterMesh())
382 0 : return;
383 :
384 : const MeshGeneratorName name = "domain_mesh_generator";
385 128 : auto params = _factory.getValidParams("DomainMeshGenerator");
386 :
387 128 : params.set<MooseEnum>("dim") = _dim;
388 128 : params.set<Real>("xmax") = _max_global(0);
389 128 : params.set<Real>("ymax") = _max_global(1);
390 128 : params.set<Real>("zmax") = _max_global(2);
391 128 : params.set<Real>("xmin") = _min_global(0);
392 128 : params.set<Real>("ymin") = _min_global(1);
393 128 : params.set<Real>("zmin") = _min_global(2);
394 :
395 128 : if (_mesh_mode == MeshMode::SWIFT_DOMAIN)
396 : {
397 30 : params.set<unsigned int>("nx") = _n_global[0];
398 30 : params.set<unsigned int>("ny") = _n_global[1];
399 30 : params.set<unsigned int>("nz") = _n_global[2];
400 : }
401 98 : else if (_mesh_mode == MeshMode::SWIFT_DUMMY)
402 : {
403 98 : params.set<unsigned int>("nx") = 1;
404 98 : params.set<unsigned int>("ny") = 1;
405 98 : params.set<unsigned int>("nz") = 1;
406 : }
407 : else
408 0 : mooseError("Internal error");
409 :
410 128 : _app.addMeshGenerator("DomainMeshGenerator", name, params);
411 128 : }
412 :
413 384 : if (_current_task == "create_problem_custom")
414 : {
415 128 : if (!_problem)
416 : {
417 0 : const std::string type = "TensorProblem";
418 0 : auto params = _factory.getValidParams(type);
419 :
420 : // apply common parameters of the object held by CreateProblemAction
421 : // to honor user inputs in [Problem]
422 0 : auto p = _awh.getActionByTask<CreateProblemAction>("create_problem");
423 0 : if (p)
424 0 : params.applyParameters(p->getObjectParams());
425 :
426 0 : params.set<MooseMesh *>("mesh") = _mesh.get();
427 0 : _problem = _factory.create<FEProblemBase>(type, "MOOSE Problem", params);
428 0 : }
429 : }
430 : }
431 :
432 : const torch::Tensor &
433 3918 : DomainAction::getAxis(std::size_t component) const
434 : {
435 3918 : if (component < 3)
436 3918 : return _local_axis[component];
437 0 : mooseError("Invalid component");
438 : }
439 :
440 : const torch::Tensor &
441 4014 : DomainAction::getReciprocalAxis(std::size_t component) const
442 : {
443 4014 : if (component < 3)
444 4014 : return _local_reciprocal_axis[component];
445 0 : mooseError("Invalid component");
446 : }
447 :
448 : torch::Tensor
449 791596 : DomainAction::fft(const torch::Tensor & t) const
450 : {
451 791596 : switch (_parallel_mode)
452 : {
453 791596 : case ParallelMode::NONE:
454 791596 : return fftSerial(t);
455 :
456 0 : case ParallelMode::FFT_SLAB:
457 0 : return fftSlab(t);
458 :
459 0 : case ParallelMode::FFT_PENCIL:
460 0 : return fftPencil(t);
461 : }
462 0 : mooseError("Not implemented");
463 : }
464 :
465 : torch::Tensor
466 791596 : DomainAction::fftSerial(const torch::Tensor & t) const
467 : {
468 791596 : switch (_dim)
469 : {
470 : case 1:
471 80 : return torch::fft::rfft(t, c10::nullopt, 0);
472 : case 2:
473 790708 : return torch::fft::rfft2(t, c10::nullopt, {0, 1});
474 : case 3:
475 848 : return torch::fft::rfftn(t, c10::nullopt, {0, 1, 2});
476 0 : default:
477 0 : mooseError("Unsupported mesh dimension");
478 : }
479 : }
480 :
481 : torch::Tensor
482 0 : DomainAction::fftSlab(const torch::Tensor & t) const
483 : {
484 : mooseInfoRepeated("fftSlab");
485 0 : if (_dim == 1)
486 0 : mooseError("Unsupported mesh dimension");
487 :
488 0 : MooseTensor::printTensorInfo(t);
489 :
490 : // 2D transform the local slab
491 : auto slab =
492 0 : _dim == 3 ? torch::fft::fft2(t, c10::nullopt, {0, 2}) : torch::fft::fft(t, c10::nullopt, 0);
493 0 : MooseTensor::printTensorInfo(slab);
494 :
495 : // send
496 0 : std::vector<MPI_Request> send_requests(_n_rank, MPI_REQUEST_NULL);
497 0 : for (const auto & i : make_range(_n_rank))
498 0 : if (i != _rank)
499 : {
500 0 : _send_tensor[i] = slab.slice(0, _local_begin[0][i], _local_end[0][i]).contiguous().cpu();
501 0 : MooseTensor::printTensorInfo(_send_tensor[i]);
502 :
503 0 : auto data_ptr = _send_tensor[i].data_ptr<double>();
504 0 : MPI_Isend(
505 : data_ptr, _send_tensor[i].numel(), MPI_DOUBLE, i, 0, MPI_COMM_WORLD, &send_requests[i]);
506 : }
507 : else
508 : // keep the local slice on device
509 0 : _recv_tensor[i] = slab.slice(0, _local_begin[0][i], _local_end[0][i]);
510 :
511 : // receive
512 : MPI_Status recv_status;
513 0 : for (const auto & i : make_range(_n_rank))
514 0 : if (i != _rank)
515 0 : MPI_Recv(_recv_data[i].data(), 1, MPI_DOUBLE, i, 0, MPI_COMM_WORLD, &recv_status);
516 :
517 : // Wait for all non-blocking sends to complete
518 0 : for (const auto & i : make_range(_n_rank))
519 0 : if (i != _rank)
520 : {
521 : // 2d _n_local_all[0][_rank] * _n_local_all[1][i] * _n_local_all[2][i]
522 0 : _recv_tensor[i] = torch::from_blob(_recv_data[i].data(),
523 0 : {_n_local_all[0][_rank], _n_local_all[1][i]},
524 : torch::kFloat64)
525 0 : .to(MooseTensor::floatTensorOptions()); // todo: take care of 32 but
526 : // floats as well!
527 : }
528 :
529 : // stack
530 0 : auto t2 = torch::vstack(_recv_tensor);
531 :
532 : // Wait for all non-blocking sends to complete
533 0 : MPI_Waitall(_n_rank, send_requests.data(), MPI_STATUSES_IGNORE);
534 :
535 : // transfor along y direction
536 0 : return torch::fft::rfft(t2, c10::nullopt, 1);
537 0 : }
538 :
539 : torch::Tensor
540 0 : DomainAction::fftPencil(const torch::Tensor & /*t*/) const
541 : {
542 0 : if (_dim != 3)
543 0 : mooseError("Unsupported mesh dimension");
544 0 : paramError("parallel_mode", "Not implemented yet!");
545 : }
546 :
547 : torch::Tensor
548 370008 : DomainAction::ifft(const torch::Tensor & t) const
549 : {
550 370008 : switch (_dim)
551 : {
552 : case 1:
553 160 : return torch::fft::irfft(t, getShape()[0], 0);
554 : case 2:
555 369440 : return torch::fft::irfft2(t, getShape(), {0, 1});
556 : case 3:
557 488 : return torch::fft::irfftn(t, getShape(), {0, 1, 2});
558 0 : default:
559 0 : mooseError("Unsupported mesh dimension");
560 : }
561 : }
562 :
563 : torch::Tensor
564 532 : DomainAction::align(torch::Tensor t, unsigned int dim) const
565 : {
566 532 : if (dim >= _dim)
567 0 : mooseError("Unsupported alignment dimension requested dimension");
568 :
569 532 : switch (_dim)
570 : {
571 : case 1:
572 : return t;
573 :
574 360 : case 2:
575 360 : if (dim == 0)
576 : return torch::unsqueeze(t, 1);
577 : else
578 : return torch::unsqueeze(t, 0);
579 :
580 144 : case 3:
581 144 : if (dim == 0)
582 48 : return t.unsqueeze(1).unsqueeze(2);
583 96 : else if (dim == 1)
584 48 : return t.unsqueeze(0).unsqueeze(2);
585 : else
586 48 : return t.unsqueeze(0).unsqueeze(0);
587 :
588 0 : default:
589 0 : mooseError("Unsupported mesh dimension");
590 : }
591 : }
592 :
593 : std::vector<int64_t>
594 860 : DomainAction::getValueShape(std::vector<int64_t> extra_dims) const
595 : {
596 860 : std::vector<int64_t> dims(_dim);
597 2624 : for (const auto i : make_range(_dim))
598 1764 : dims[i] = _n_local[i];
599 860 : dims.insert(dims.end(), extra_dims.begin(), extra_dims.end());
600 860 : return dims;
601 0 : }
602 :
603 : std::vector<int64_t>
604 0 : DomainAction::getReciprocalValueShape(std::initializer_list<int64_t> extra_dims) const
605 : {
606 0 : std::vector<int64_t> dims(_dim);
607 0 : for (const auto i : make_range(_dim))
608 0 : dims[i] = _n_reciprocal_local[i];
609 0 : dims.insert(dims.end(), extra_dims.begin(), extra_dims.end());
610 0 : return dims;
611 0 : }
612 :
613 : void
614 0 : DomainAction::updateXGrid() const
615 : {
616 : // TODO: add mutex to avoid thread race
617 0 : switch (_dim)
618 : {
619 0 : case 1:
620 : _x_grid = _local_axis[0];
621 : break;
622 0 : case 2:
623 0 : _x_grid = torch::stack({_local_axis[0].expand(_shape), _local_axis[1].expand(_shape)}, -1);
624 0 : break;
625 0 : case 3:
626 0 : _x_grid = torch::stack({_local_axis[0].expand(_shape),
627 : _local_axis[1].expand(_shape),
628 : _local_axis[2].expand(_shape)},
629 0 : -1);
630 0 : break;
631 0 : default:
632 0 : mooseError("Unsupported problem dimension ", _dim);
633 : }
634 0 : }
635 :
636 : void
637 0 : DomainAction::updateKGrid() const
638 : {
639 0 : switch (_dim)
640 : {
641 0 : case 1:
642 : _k_grid = _local_reciprocal_axis[0];
643 : break;
644 0 : case 2:
645 0 : _k_grid = torch::stack({_local_reciprocal_axis[0].expand(_reciprocal_shape),
646 : _local_reciprocal_axis[1].expand(_reciprocal_shape)},
647 0 : -1);
648 0 : break;
649 0 : case 3:
650 0 : _k_grid = torch::stack({_local_reciprocal_axis[0].expand(_reciprocal_shape),
651 : _local_reciprocal_axis[1].expand(_reciprocal_shape),
652 : _local_reciprocal_axis[2].expand(_reciprocal_shape)},
653 0 : -1);
654 0 : break;
655 0 : default:
656 0 : mooseError("Unsupported problem dimension ", _dim);
657 : }
658 0 : }
659 :
660 : void
661 128 : DomainAction::updateKSquare() const
662 : {
663 256 : _k_square = _local_reciprocal_axis[0] * _local_reciprocal_axis[0] +
664 256 : _local_reciprocal_axis[1] * _local_reciprocal_axis[1] +
665 128 : _local_reciprocal_axis[2] * _local_reciprocal_axis[2];
666 128 : }
667 :
668 : const torch::Tensor &
669 0 : DomainAction::getXGrid() const
670 : {
671 :
672 : // build on demand
673 0 : if (!_x_grid.defined())
674 0 : updateXGrid();
675 :
676 0 : return _x_grid;
677 : }
678 :
679 : const torch::Tensor &
680 0 : DomainAction::getKGrid() const
681 : {
682 :
683 : // build on demand
684 0 : if (!_k_grid.defined())
685 0 : updateKGrid();
686 :
687 0 : return _k_grid;
688 : }
689 :
690 : const torch::Tensor &
691 234 : DomainAction::getKSquare() const
692 : {
693 : // build on demand
694 234 : if (!_k_square.defined())
695 128 : updateKSquare();
696 :
697 234 : return _k_square;
698 : }
699 :
700 : torch::Tensor
701 0 : DomainAction::sum(const torch::Tensor & t) const
702 : {
703 0 : torch::Tensor local_sum = t.sum(_domain_dimensions, false, c10::nullopt);
704 :
705 : // TODO: parallel implementation
706 0 : if (comm().size() == 1)
707 0 : return local_sum;
708 : else
709 0 : mooseError("Sum is not implemented in parallel, yet.");
710 : }
711 :
712 : torch::Tensor
713 0 : DomainAction::average(const torch::Tensor & t) const
714 : {
715 0 : return sum(t) / Real(_n_global[0] * _n_global[1] * _n_global[2]);
716 : }
717 :
718 : int64_t
719 0 : DomainAction::getNumberOfCells() const
720 : {
721 0 : return _n_global[0] * _n_global[1] * _n_global[2];
722 : }
|