LCOV - code coverage report
Current view: top level - src/actions - DomainAction.C (source / functions) Hit Total Coverage
Test: idaholab/swift: #92 (25e020) with base b3cd84 Lines: 195 353 55.2 %
Date: 2025-09-10 17:10:32 Functions: 14 26 53.8 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /**********************************************************************/
       2             : /*                    DO NOT MODIFY THIS HEADER                       */
       3             : /*             Swift, a Fourier spectral solver for MOOSE             */
       4             : /*                                                                    */
       5             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       6             : /*                        ALL RIGHTS RESERVED                         */
       7             : /**********************************************************************/
       8             : 
       9             : #include "DomainAction.h"
      10             : #include "MooseError.h"
      11             : #include "TensorProblem.h"
      12             : #include "MooseEnum.h"
      13             : #include "SetupMeshAction.h"
      14             : #include "SwiftApp.h"
      15             : #include "CreateProblemAction.h"
      16             : 
      17             : #include <initializer_list>
      18             : #include <util/Optional.h>
      19             : 
      20             : // run this early, before any objects are constructed
      21             : registerMooseAction("SwiftApp", DomainAction, "meta_action");
      22             : registerMooseAction("SwiftApp", DomainAction, "add_mesh_generator");
      23             : registerMooseAction("SwiftApp", DomainAction, "create_problem_custom");
      24             : 
      25             : InputParameters
      26         128 : DomainAction::validParams()
      27             : {
      28         128 :   InputParameters params = Action::validParams();
      29         128 :   params.addClassDescription("Set up the domain and compute devices.");
      30             : 
      31         256 :   MooseEnum dims("1=1 2 3");
      32         256 :   params.addRequiredParam<MooseEnum>("dim", dims, "Problem dimension");
      33             : 
      34         256 :   MooseEnum parmode("NONE FFT_SLAB FFT_PENCIL", "NONE");
      35         256 :   parmode.addDocumentation("NONE", "Serial execution without domain decomposition.");
      36         256 :   parmode.addDocumentation("FFT_SLAB",
      37             :                            "Slab decomposition with X-Z slabs stacked along the Y direction in "
      38             :                            "real space and Y-Z slabs stacked along the X direction in Fourier "
      39             :                            "space. This requires one all-to-all communication per FFT.");
      40         256 :   parmode.addDocumentation(
      41             :       "FFT_PENCIL",
      42             :       "Pencil decomposition (3D only). Three 1D FFTs in pencil arrays along the X, Y, and lastly Z "
      43             :       "direction. Thie requires two many-to-many communications per FFT.");
      44             : 
      45         256 :   params.addParam<MooseEnum>("parallel_mode", parmode, "Parallelization mode.");
      46             : 
      47         256 :   params.addParam<unsigned int>("nx", 1, "Number of elements in the X direction");
      48         256 :   params.addParam<unsigned int>("ny", 1, "Number of elements in the Y direction");
      49         256 :   params.addParam<unsigned int>("nz", 1, "Number of elements in the Z direction");
      50         256 :   params.addParam<Real>("xmax", 1.0, "Upper X Coordinate of the generated mesh");
      51         256 :   params.addParam<Real>("ymax", 1.0, "Upper Y Coordinate of the generated mesh");
      52         256 :   params.addParam<Real>("zmax", 1.0, "Upper Z Coordinate of the generated mesh");
      53         256 :   params.addParam<Real>("xmin", 0.0, "Lower X Coordinate of the generated mesh");
      54         256 :   params.addParam<Real>("ymin", 0.0, "Lower Y Coordinate of the generated mesh");
      55         256 :   params.addParam<Real>("zmin", 0.0, "Lower Z Coordinate of the generated mesh");
      56             : 
      57         256 :   MooseEnum meshmode("DUMMY DOMAIN MANUAL", "DUMMY");
      58         256 :   meshmode.addDocumentation("DUMMY",
      59             :                             "Create a single element mesh the size of the simulation domain");
      60         256 :   meshmode.addDocumentation("DOMAIN", "Create a mesh with one element per grid cell");
      61         256 :   meshmode.addDocumentation("MANUAL",
      62             :                             "Do not auto-generate a mesh. User must add a Mesh block themselves.");
      63             : 
      64         256 :   params.addParam<MooseEnum>("mesh_mode", meshmode, "Mesh generation mode.");
      65             : 
      66         256 :   params.addParam<std::vector<std::string>>("device_names", {}, "Compute devices to run on.");
      67         256 :   params.addParam<std::vector<unsigned int>>(
      68             :       "device_weights", {}, "Device weights (or speeds) to influence the partitioning.");
      69             : 
      70         256 :   MooseEnum floatingPrecision("DEVICE_DEFAULT SINGLE DOUBLE", "DEVICE_DEFAULT");
      71         256 :   params.addParam<MooseEnum>("floating_precision", floatingPrecision, "Floating point precision.");
      72             : 
      73         256 :   params.addParam<bool>(
      74             :       "debug",
      75         256 :       false,
      76             :       "Enable additional debugging and diagnostics, such a checking for initialized tensors.");
      77         128 :   return params;
      78         128 : }
      79             : 
      80         128 : DomainAction::DomainAction(const InputParameters & parameters)
      81             :   : Action(parameters),
      82         128 :     _device_names(getParam<std::vector<std::string>>("device_names")),
      83         256 :     _device_weights(getParam<std::vector<unsigned int>>("device_weights")),
      84         256 :     _floating_precision(getParam<MooseEnum>("floating_precision").getEnum<FloatingPrecision>()),
      85         256 :     _parallel_mode(getParam<MooseEnum>("parallel_mode").getEnum<ParallelMode>()),
      86         256 :     _dim(getParam<MooseEnum>("dim")),
      87         384 :     _n_global(
      88         512 :         {getParam<unsigned int>("nx"), getParam<unsigned int>("ny"), getParam<unsigned int>("nz")}),
      89         512 :     _min_global({getParam<Real>("xmin"), getParam<Real>("ymin"), getParam<Real>("zmin")}),
      90         512 :     _max_global({getParam<Real>("xmax"), getParam<Real>("ymax"), getParam<Real>("zmax")}),
      91         256 :     _mesh_mode(getParam<MooseEnum>("mesh_mode").getEnum<MeshMode>()),
      92         128 :     _shape(torch::IntArrayRef(_n_local.data(), _dim)),
      93             :     _reciprocal_shape(torch::IntArrayRef(_n_reciprocal_local.data(), _dim)),
      94         128 :     _domain_dimensions_buffer({0, 1, 2}),
      95             :     _domain_dimensions(torch::IntArrayRef(_domain_dimensions_buffer.data(), _dim)),
      96         128 :     _rank(_communicator.rank()),
      97         128 :     _n_rank(_communicator.size()),
      98         128 :     _send_tensor(_n_rank),
      99         128 :     _recv_tensor(_n_rank),
     100         384 :     _debug(getParam<bool>("debug"))
     101             : {
     102         128 :   if (_parallel_mode == ParallelMode::NONE && comm().size() > 1)
     103           0 :     paramError("parallel_mode", "NONE requires the application to run in serial.");
     104             : 
     105         128 :   if (_device_names.empty())
     106             :   {
     107         108 :     if (comm().size() > 1)
     108           0 :       mooseError("Specify Domain/device_names for parallel operation.");
     109             : 
     110             :     // set local weights and ranks for serial
     111         108 :     _local_ranks = {0};
     112         108 :     _local_weights = {1};
     113             :   }
     114             :   else
     115             :   {
     116             :     // process weights
     117          20 :     if (_device_weights.empty())
     118          20 :       _device_weights.assign(1, _device_names.size());
     119             : 
     120          20 :     if (_device_weights.size() != _device_names.size())
     121           0 :       mooseError("Specify one weight per device or none at all");
     122             : 
     123             :     // determine the processor name
     124             :     char name[MPI_MAX_PROCESSOR_NAME + 1];
     125             :     int len;
     126          20 :     MPI_Get_processor_name(name, &len);
     127          20 :     name[len] = 0;
     128             : 
     129             :     // gather all processor names
     130             :     std::vector<std::string> host_names;
     131          40 :     _communicator.allgather(std::string(name), host_names);
     132             : 
     133             :     // get the local rank on the current processor (used for compute device assignment)
     134             :     std::map<std::string, unsigned int> host_rank_count;
     135             : 
     136          40 :     for (const auto & host_name : host_names)
     137             :     {
     138          40 :       if (host_rank_count.find(name) == host_rank_count.end())
     139          20 :         host_rank_count[host_name] = 0;
     140             : 
     141          20 :       auto & local_rank = host_rank_count[host_name];
     142          20 :       _local_ranks.push_back(local_rank);
     143          20 :       _local_weights.push_back(_device_weights[local_rank % _device_weights.size()]);
     144             : 
     145             :       // std::cout << "Process on " << host_name << ' ' << local_rank << ' '
     146             :       //           << _device_weights[local_rank % _device_weights.size()] << '\n';
     147             : 
     148          20 :       local_rank++;
     149             :     }
     150             : 
     151             :     // for (const auto i : index_range(host_names))
     152             :     //   std::cout << host_names[i] << '\t' << _local_ranks[i] << '\n';
     153             : 
     154             :     // pick a compute device for a list of available devices
     155          20 :     auto swift_app = dynamic_cast<SwiftApp *>(&_app);
     156          20 :     if (!swift_app)
     157           0 :       mooseError("This action requires a SwftApp object to be present.");
     158          20 :     swift_app->setTorchDevice(_device_names[_local_ranks[_rank] % _device_names.size()], {});
     159             : 
     160          20 :     switch (_floating_precision)
     161             :     {
     162          20 :       case FloatingPrecision::DEVICE_DEFAULT:
     163             :       {
     164          20 :         swift_app->setTorchPrecision("DEVICE_DEFAULT", {});
     165          20 :         break;
     166             :       }
     167           0 :       case FloatingPrecision::DOUBLE:
     168             :       {
     169           0 :         swift_app->setTorchPrecision("DOUBLE", {});
     170           0 :         break;
     171             :       }
     172           0 :       case FloatingPrecision::SINGLE:
     173             :       {
     174           0 :         swift_app->setTorchPrecision("SINGLE", {});
     175           0 :         break;
     176             :       }
     177           0 :       default:
     178           0 :         mooseError("Invalid floating precision.");
     179             :     };
     180          20 :   }
     181             : 
     182             :   // domain partitioning
     183         128 :   gridChanged();
     184         128 : }
     185             : 
     186             : void
     187         128 : DomainAction::gridChanged()
     188             : {
     189         128 :   auto options = MooseTensor::floatTensorOptions();
     190             : 
     191             :   // build real space axes
     192         128 :   _volume_global = 1.0;
     193         512 :   for (const unsigned int dim : {0, 1, 2})
     194             :   {
     195             :     // error check
     196         384 :     if (_max_global(dim) <= _min_global(dim))
     197           0 :       mooseError("Max coordinate must be larger than the min coordinate in every dimension");
     198             : 
     199             :     // get grid geometry
     200         384 :     _grid_spacing(dim) = (_max_global(dim) - _min_global(dim)) / _n_global[dim];
     201             : 
     202             :     // real space axis
     203         384 :     if (dim < _dim)
     204             :     {
     205             :       _global_axis[dim] =
     206         532 :           align(torch::linspace(c10::Scalar(_min_global(dim) + _grid_spacing(dim) / 2.0),
     207         266 :                                 c10::Scalar(_max_global(dim) - _grid_spacing(dim) / 2.0),
     208             :                                 _n_global[dim],
     209             :                                 options),
     210             :                 dim);
     211         266 :       _volume_global *= _max_global(dim) - _min_global(dim);
     212             :     }
     213             :     else
     214         354 :       _global_axis[dim] = torch::tensor({0.0}, options);
     215             :   }
     216             : 
     217             :   // build reciprocal space axes
     218         512 :   for (const unsigned int dim : {0, 1, 2})
     219             :   {
     220         384 :     if (dim < _dim)
     221             :     {
     222         266 :       const auto freq = (dim == _dim - 1)
     223         266 :                             ? torch::fft::rfftfreq(_n_global[dim], _grid_spacing(dim), options)
     224         266 :                             : torch::fft::fftfreq(_n_global[dim], _grid_spacing(dim), options);
     225             : 
     226             :       // zero out nyquist frequency
     227             :       // if (_n_global[dim] % 2 == 0)
     228             :       //   freq[_n_global[dim] / 2] = 0.0;
     229             : 
     230         532 :       _global_reciprocal_axis[dim] = align(freq * 2.0 * libMesh::pi, dim);
     231             :     }
     232             :     else
     233         354 :       _global_reciprocal_axis[dim] = torch::tensor({0.0}, options);
     234             : 
     235             :     // compute max frequency along each axis
     236         384 :     _max_k(dim) = libMesh::pi / _grid_spacing(dim);
     237             : 
     238             :     // get global reciprocal axis size
     239         384 :     _n_reciprocal_global[dim] = _global_reciprocal_axis[dim].sizes()[dim];
     240             :   }
     241             : 
     242         128 :   switch (_parallel_mode)
     243             :   {
     244         128 :     case ParallelMode::NONE:
     245         128 :       partitionSerial();
     246             :       break;
     247             : 
     248           0 :     case ParallelMode::FFT_SLAB:
     249           0 :       partitionSlabs();
     250             :       break;
     251             : 
     252           0 :     case ParallelMode::FFT_PENCIL:
     253           0 :       partitionPencils();
     254             :       break;
     255             :   }
     256             : 
     257             :   // get local reciprocal axis size
     258         512 :   for (const auto dim : {0, 1, 2})
     259         384 :     _n_reciprocal_local[dim] = _local_reciprocal_axis[dim].sizes()[dim];
     260             : 
     261             :   // update on-demand grids
     262         128 :   if (_x_grid.defined())
     263           0 :     updateXGrid();
     264         128 :   if (_k_grid.defined())
     265           0 :     updateKGrid();
     266         128 :   if (_k_square.defined())
     267           0 :     updateKSquare();
     268         128 : }
     269             : 
     270             : void
     271         128 : DomainAction::partitionSerial()
     272             : {
     273             :   // goes along the full dimension for each rank
     274         512 :   for (const auto d : make_range(3u))
     275             :   {
     276         384 :     _local_begin[d].resize(_n_rank);
     277         384 :     _local_end[d].resize(_n_rank);
     278         768 :     for (const auto i : make_range(_communicator.size()))
     279             :     {
     280         384 :       _local_begin[d][i] = 0;
     281         384 :       _local_end[d][i] = _n_global[d];
     282             :     }
     283             :   }
     284             : 
     285             :   // to do, make those slices dependent on local begin/end
     286         128 :   _local_axis = _global_axis;
     287         128 :   _n_local = _n_global;
     288         128 :   _local_reciprocal_axis = _global_reciprocal_axis;
     289         128 : }
     290             : 
     291             : void
     292           0 : DomainAction::partitionSlabs()
     293             : {
     294           0 :   if (_dim < 2)
     295           0 :     paramError("dim", "Dimension must be 2 or 3 for slab decomposition.");
     296             : 
     297             :   // x is partitioned along a halved dimension due to the use of rfft
     298           0 :   _n_local_all[0] = partitionHepler(_global_reciprocal_axis[0].sizes()[0], _device_weights);
     299             : 
     300             :   // y is partitioned along the y realspace axis
     301           0 :   _n_local_all[1] = partitionHepler(_global_axis[1].sizes()[1], _device_weights);
     302             : 
     303             :   // set begin/end for x and y
     304           0 :   for (const auto d : {0, 1})
     305             :   {
     306             :     int64_t b = 0;
     307           0 :     for (const auto r : index_range(_n_local_all[d]))
     308             :     {
     309           0 :       _local_begin[d][r] = b;
     310           0 :       b += _n_local_all[d][r];
     311           0 :       _local_end[d][r] = b;
     312             :     }
     313             :   }
     314             : 
     315             :   // z is not partitioned at all
     316           0 :   _n_local_all[2].assign(_n_rank, _n_global[2]);
     317           0 :   _local_begin[2].assign(_n_rank, 0);
     318           0 :   _local_end[2].assign(_n_rank, _n_global[2]);
     319             : 
     320             :   // slice the real space into x-z slabs stacked in y direction
     321           0 :   _local_axis[0] = _global_axis[0].slice(0, 0, _n_global[0]);
     322           0 :   _local_axis[1] = _global_axis[1].slice(1, _local_begin[1][_rank], _local_end[1][_rank]);
     323           0 :   _n_local[0] = _n_global[0];
     324           0 :   _n_local[1] = _local_end[1][_rank] - _local_begin[1][_rank];
     325             : 
     326             :   // slice the reciprocal space into y-z slices stacked in x direction
     327             :   _local_reciprocal_axis[0] =
     328           0 :       _global_reciprocal_axis[0].slice(0, 0, _local_begin[0][_rank], _local_end[0][_rank]);
     329           0 :   _local_reciprocal_axis[1] = _global_reciprocal_axis[1].slice(1, 0, _n_reciprocal_global[1]);
     330             : 
     331           0 :   _n_local[2] = _n_global[2];
     332             : 
     333             :   // special casing this should not be neccessary
     334           0 :   if (_dim == 3)
     335             :   {
     336           0 :     _local_axis[2] = _global_axis[2].slice(2, 0, _n_global[2]);
     337           0 :     _local_reciprocal_axis[2] = _global_reciprocal_axis[2].slice(2, 0, _n_reciprocal_global[2]);
     338             :   }
     339             :   else
     340             :   {
     341             :     _local_axis[2] = _global_axis[2];
     342             :     _local_reciprocal_axis[2] = _global_reciprocal_axis[2];
     343             :   }
     344             : 
     345             :   // allocate receive buffer
     346           0 :   for (const auto i : make_range(_communicator.size()))
     347           0 :     if (i != _rank)
     348           0 :       _recv_data[i].resize(_n_local_all[0][_rank] * _n_local_all[1][i] * _n_local_all[2][i]);
     349           0 : }
     350             : 
     351             : void
     352           0 : DomainAction::partitionPencils()
     353             : {
     354           0 :   if (_dim < 3)
     355           0 :     paramError("dim", "Dimension must be 3 for pencil decomposition.");
     356           0 :   paramError("parallel_mode", "Not implemented yet!");
     357             : }
     358             : 
     359             : void
     360         384 : DomainAction::act()
     361             : {
     362         384 :   if (_current_task == "meta_action" && _mesh_mode != MeshMode::SWIFT_MANUAL)
     363             :   {
     364             :     // check if a SetupMesh action exists
     365         128 :     auto mesh_actions = _awh.getActions<SetupMeshAction>();
     366         128 :     if (mesh_actions.size() > 0)
     367           0 :       paramError("mesh_mode", "Do not specify a [Mesh] block unless mesh_mode is set to MANUAL");
     368             : 
     369             :     // otherwise create one
     370         128 :     auto & af = _app.getActionFactory();
     371         128 :     InputParameters action_params = af.getValidParams("SetupMeshAction");
     372             :     auto action = std::static_pointer_cast<MooseObjectAction>(
     373         256 :         af.create("SetupMeshAction", "Mesh", action_params));
     374         384 :     _app.actionWarehouse().addActionBlock(action);
     375         128 :   }
     376             : 
     377             :   // add a DomainMeshGenerator
     378         384 :   if (_current_task == "add_mesh_generator" && _mesh_mode != MeshMode::SWIFT_MANUAL)
     379             :   {
     380             :     // Don't do mesh generators when recovering or when the user has requested for us not to
     381         128 :     if ((_app.isRecovering() && _app.isUltimateMaster()) || _app.masterMesh())
     382           0 :       return;
     383             : 
     384             :     const MeshGeneratorName name = "domain_mesh_generator";
     385         128 :     auto params = _factory.getValidParams("DomainMeshGenerator");
     386             : 
     387         128 :     params.set<MooseEnum>("dim") = _dim;
     388         128 :     params.set<Real>("xmax") = _max_global(0);
     389         128 :     params.set<Real>("ymax") = _max_global(1);
     390         128 :     params.set<Real>("zmax") = _max_global(2);
     391         128 :     params.set<Real>("xmin") = _min_global(0);
     392         128 :     params.set<Real>("ymin") = _min_global(1);
     393         128 :     params.set<Real>("zmin") = _min_global(2);
     394             : 
     395         128 :     if (_mesh_mode == MeshMode::SWIFT_DOMAIN)
     396             :     {
     397          30 :       params.set<unsigned int>("nx") = _n_global[0];
     398          30 :       params.set<unsigned int>("ny") = _n_global[1];
     399          30 :       params.set<unsigned int>("nz") = _n_global[2];
     400             :     }
     401          98 :     else if (_mesh_mode == MeshMode::SWIFT_DUMMY)
     402             :     {
     403          98 :       params.set<unsigned int>("nx") = 1;
     404          98 :       params.set<unsigned int>("ny") = 1;
     405          98 :       params.set<unsigned int>("nz") = 1;
     406             :     }
     407             :     else
     408           0 :       mooseError("Internal error");
     409             : 
     410         128 :     _app.addMeshGenerator("DomainMeshGenerator", name, params);
     411         128 :   }
     412             : 
     413         384 :   if (_current_task == "create_problem_custom")
     414             :   {
     415         128 :     if (!_problem)
     416             :     {
     417           0 :       const std::string type = "TensorProblem";
     418           0 :       auto params = _factory.getValidParams(type);
     419             : 
     420             :       // apply common parameters of the object held by CreateProblemAction
     421             :       // to honor user inputs in [Problem]
     422           0 :       auto p = _awh.getActionByTask<CreateProblemAction>("create_problem");
     423           0 :       if (p)
     424           0 :         params.applyParameters(p->getObjectParams());
     425             : 
     426           0 :       params.set<MooseMesh *>("mesh") = _mesh.get();
     427           0 :       _problem = _factory.create<FEProblemBase>(type, "MOOSE Problem", params);
     428           0 :     }
     429             :   }
     430             : }
     431             : 
     432             : const torch::Tensor &
     433        3918 : DomainAction::getAxis(std::size_t component) const
     434             : {
     435        3918 :   if (component < 3)
     436        3918 :     return _local_axis[component];
     437           0 :   mooseError("Invalid component");
     438             : }
     439             : 
     440             : const torch::Tensor &
     441        4014 : DomainAction::getReciprocalAxis(std::size_t component) const
     442             : {
     443        4014 :   if (component < 3)
     444        4014 :     return _local_reciprocal_axis[component];
     445           0 :   mooseError("Invalid component");
     446             : }
     447             : 
     448             : torch::Tensor
     449      791596 : DomainAction::fft(const torch::Tensor & t) const
     450             : {
     451      791596 :   switch (_parallel_mode)
     452             :   {
     453      791596 :     case ParallelMode::NONE:
     454      791596 :       return fftSerial(t);
     455             : 
     456           0 :     case ParallelMode::FFT_SLAB:
     457           0 :       return fftSlab(t);
     458             : 
     459           0 :     case ParallelMode::FFT_PENCIL:
     460           0 :       return fftPencil(t);
     461             :   }
     462           0 :   mooseError("Not implemented");
     463             : }
     464             : 
     465             : torch::Tensor
     466      791596 : DomainAction::fftSerial(const torch::Tensor & t) const
     467             : {
     468      791596 :   switch (_dim)
     469             :   {
     470             :     case 1:
     471          80 :       return torch::fft::rfft(t, c10::nullopt, 0);
     472             :     case 2:
     473      790708 :       return torch::fft::rfft2(t, c10::nullopt, {0, 1});
     474             :     case 3:
     475         848 :       return torch::fft::rfftn(t, c10::nullopt, {0, 1, 2});
     476           0 :     default:
     477           0 :       mooseError("Unsupported mesh dimension");
     478             :   }
     479             : }
     480             : 
     481             : torch::Tensor
     482           0 : DomainAction::fftSlab(const torch::Tensor & t) const
     483             : {
     484             :   mooseInfoRepeated("fftSlab");
     485           0 :   if (_dim == 1)
     486           0 :     mooseError("Unsupported mesh dimension");
     487             : 
     488           0 :   MooseTensor::printTensorInfo(t);
     489             : 
     490             :   // 2D transform the local slab
     491             :   auto slab =
     492           0 :       _dim == 3 ? torch::fft::fft2(t, c10::nullopt, {0, 2}) : torch::fft::fft(t, c10::nullopt, 0);
     493           0 :   MooseTensor::printTensorInfo(slab);
     494             : 
     495             :   // send
     496           0 :   std::vector<MPI_Request> send_requests(_n_rank, MPI_REQUEST_NULL);
     497           0 :   for (const auto & i : make_range(_n_rank))
     498           0 :     if (i != _rank)
     499             :     {
     500           0 :       _send_tensor[i] = slab.slice(0, _local_begin[0][i], _local_end[0][i]).contiguous().cpu();
     501           0 :       MooseTensor::printTensorInfo(_send_tensor[i]);
     502             : 
     503           0 :       auto data_ptr = _send_tensor[i].data_ptr<double>();
     504           0 :       MPI_Isend(
     505             :           data_ptr, _send_tensor[i].numel(), MPI_DOUBLE, i, 0, MPI_COMM_WORLD, &send_requests[i]);
     506             :     }
     507             :     else
     508             :       // keep the local slice on device
     509           0 :       _recv_tensor[i] = slab.slice(0, _local_begin[0][i], _local_end[0][i]);
     510             : 
     511             :   // receive
     512             :   MPI_Status recv_status;
     513           0 :   for (const auto & i : make_range(_n_rank))
     514           0 :     if (i != _rank)
     515           0 :       MPI_Recv(_recv_data[i].data(), 1, MPI_DOUBLE, i, 0, MPI_COMM_WORLD, &recv_status);
     516             : 
     517             :   // Wait for all non-blocking sends to complete
     518           0 :   for (const auto & i : make_range(_n_rank))
     519           0 :     if (i != _rank)
     520             :     {
     521             :       // 2d _n_local_all[0][_rank] * _n_local_all[1][i] * _n_local_all[2][i]
     522           0 :       _recv_tensor[i] = torch::from_blob(_recv_data[i].data(),
     523           0 :                                          {_n_local_all[0][_rank], _n_local_all[1][i]},
     524             :                                          torch::kFloat64)
     525           0 :                             .to(MooseTensor::floatTensorOptions()); // todo: take care of 32 but
     526             :                                                                     // floats as well!
     527             :     }
     528             : 
     529             :   // stack
     530           0 :   auto t2 = torch::vstack(_recv_tensor);
     531             : 
     532             :   // Wait for all non-blocking sends to complete
     533           0 :   MPI_Waitall(_n_rank, send_requests.data(), MPI_STATUSES_IGNORE);
     534             : 
     535             :   // transfor along y direction
     536           0 :   return torch::fft::rfft(t2, c10::nullopt, 1);
     537           0 : }
     538             : 
     539             : torch::Tensor
     540           0 : DomainAction::fftPencil(const torch::Tensor & /*t*/) const
     541             : {
     542           0 :   if (_dim != 3)
     543           0 :     mooseError("Unsupported mesh dimension");
     544           0 :   paramError("parallel_mode", "Not implemented yet!");
     545             : }
     546             : 
     547             : torch::Tensor
     548      370008 : DomainAction::ifft(const torch::Tensor & t) const
     549             : {
     550      370008 :   switch (_dim)
     551             :   {
     552             :     case 1:
     553         160 :       return torch::fft::irfft(t, getShape()[0], 0);
     554             :     case 2:
     555      369440 :       return torch::fft::irfft2(t, getShape(), {0, 1});
     556             :     case 3:
     557         488 :       return torch::fft::irfftn(t, getShape(), {0, 1, 2});
     558           0 :     default:
     559           0 :       mooseError("Unsupported mesh dimension");
     560             :   }
     561             : }
     562             : 
     563             : torch::Tensor
     564         532 : DomainAction::align(torch::Tensor t, unsigned int dim) const
     565             : {
     566         532 :   if (dim >= _dim)
     567           0 :     mooseError("Unsupported alignment dimension requested dimension");
     568             : 
     569         532 :   switch (_dim)
     570             :   {
     571             :     case 1:
     572             :       return t;
     573             : 
     574         360 :     case 2:
     575         360 :       if (dim == 0)
     576             :         return torch::unsqueeze(t, 1);
     577             :       else
     578             :         return torch::unsqueeze(t, 0);
     579             : 
     580         144 :     case 3:
     581         144 :       if (dim == 0)
     582          48 :         return t.unsqueeze(1).unsqueeze(2);
     583          96 :       else if (dim == 1)
     584          48 :         return t.unsqueeze(0).unsqueeze(2);
     585             :       else
     586          48 :         return t.unsqueeze(0).unsqueeze(0);
     587             : 
     588           0 :     default:
     589           0 :       mooseError("Unsupported mesh dimension");
     590             :   }
     591             : }
     592             : 
     593             : std::vector<int64_t>
     594         860 : DomainAction::getValueShape(std::vector<int64_t> extra_dims) const
     595             : {
     596         860 :   std::vector<int64_t> dims(_dim);
     597        2624 :   for (const auto i : make_range(_dim))
     598        1764 :     dims[i] = _n_local[i];
     599         860 :   dims.insert(dims.end(), extra_dims.begin(), extra_dims.end());
     600         860 :   return dims;
     601           0 : }
     602             : 
     603             : std::vector<int64_t>
     604           0 : DomainAction::getReciprocalValueShape(std::initializer_list<int64_t> extra_dims) const
     605             : {
     606           0 :   std::vector<int64_t> dims(_dim);
     607           0 :   for (const auto i : make_range(_dim))
     608           0 :     dims[i] = _n_reciprocal_local[i];
     609           0 :   dims.insert(dims.end(), extra_dims.begin(), extra_dims.end());
     610           0 :   return dims;
     611           0 : }
     612             : 
     613             : void
     614           0 : DomainAction::updateXGrid() const
     615             : {
     616             :   // TODO: add mutex to avoid thread race
     617           0 :   switch (_dim)
     618             :   {
     619           0 :     case 1:
     620             :       _x_grid = _local_axis[0];
     621             :       break;
     622           0 :     case 2:
     623           0 :       _x_grid = torch::stack({_local_axis[0].expand(_shape), _local_axis[1].expand(_shape)}, -1);
     624           0 :       break;
     625           0 :     case 3:
     626           0 :       _x_grid = torch::stack({_local_axis[0].expand(_shape),
     627             :                               _local_axis[1].expand(_shape),
     628             :                               _local_axis[2].expand(_shape)},
     629           0 :                              -1);
     630           0 :       break;
     631           0 :     default:
     632           0 :       mooseError("Unsupported problem dimension ", _dim);
     633             :   }
     634           0 : }
     635             : 
     636             : void
     637           0 : DomainAction::updateKGrid() const
     638             : {
     639           0 :   switch (_dim)
     640             :   {
     641           0 :     case 1:
     642             :       _k_grid = _local_reciprocal_axis[0];
     643             :       break;
     644           0 :     case 2:
     645           0 :       _k_grid = torch::stack({_local_reciprocal_axis[0].expand(_reciprocal_shape),
     646             :                               _local_reciprocal_axis[1].expand(_reciprocal_shape)},
     647           0 :                              -1);
     648           0 :       break;
     649           0 :     case 3:
     650           0 :       _k_grid = torch::stack({_local_reciprocal_axis[0].expand(_reciprocal_shape),
     651             :                               _local_reciprocal_axis[1].expand(_reciprocal_shape),
     652             :                               _local_reciprocal_axis[2].expand(_reciprocal_shape)},
     653           0 :                              -1);
     654           0 :       break;
     655           0 :     default:
     656           0 :       mooseError("Unsupported problem dimension ", _dim);
     657             :   }
     658           0 : }
     659             : 
     660             : void
     661         128 : DomainAction::updateKSquare() const
     662             : {
     663         256 :   _k_square = _local_reciprocal_axis[0] * _local_reciprocal_axis[0] +
     664         256 :               _local_reciprocal_axis[1] * _local_reciprocal_axis[1] +
     665         128 :               _local_reciprocal_axis[2] * _local_reciprocal_axis[2];
     666         128 : }
     667             : 
     668             : const torch::Tensor &
     669           0 : DomainAction::getXGrid() const
     670             : {
     671             : 
     672             :   // build on demand
     673           0 :   if (!_x_grid.defined())
     674           0 :     updateXGrid();
     675             : 
     676           0 :   return _x_grid;
     677             : }
     678             : 
     679             : const torch::Tensor &
     680           0 : DomainAction::getKGrid() const
     681             : {
     682             : 
     683             :   // build on demand
     684           0 :   if (!_k_grid.defined())
     685           0 :     updateKGrid();
     686             : 
     687           0 :   return _k_grid;
     688             : }
     689             : 
     690             : const torch::Tensor &
     691         234 : DomainAction::getKSquare() const
     692             : {
     693             :   // build on demand
     694         234 :   if (!_k_square.defined())
     695         128 :     updateKSquare();
     696             : 
     697         234 :   return _k_square;
     698             : }
     699             : 
     700             : torch::Tensor
     701           0 : DomainAction::sum(const torch::Tensor & t) const
     702             : {
     703           0 :   torch::Tensor local_sum = t.sum(_domain_dimensions, false, c10::nullopt);
     704             : 
     705             :   // TODO: parallel implementation
     706           0 :   if (comm().size() == 1)
     707           0 :     return local_sum;
     708             :   else
     709           0 :     mooseError("Sum is not implemented in parallel, yet.");
     710             : }
     711             : 
     712             : torch::Tensor
     713           0 : DomainAction::average(const torch::Tensor & t) const
     714             : {
     715           0 :   return sum(t) / Real(_n_global[0] * _n_global[1] * _n_global[2]);
     716             : }
     717             : 
     718             : int64_t
     719           0 : DomainAction::getNumberOfCells() const
     720             : {
     721           0 :   return _n_global[0] * _n_global[1] * _n_global[2];
     722             : }

Generated by: LCOV version 1.14