Line data Source code
1 : /**********************************************************************/ 2 : /* DO NOT MODIFY THIS HEADER */ 3 : /* Swift, a Fourier spectral solver for MOOSE */ 4 : /* */ 5 : /* Copyright 2024 Battelle Energy Alliance, LLC */ 6 : /* ALL RIGHTS RESERVED */ 7 : /**********************************************************************/ 8 : 9 : #include "LBMStackTensors.h" 10 : 11 : registerMooseObject("SwiftApp", LBMStackTensors); 12 : 13 : InputParameters 14 0 : LBMStackTensors::validParams() 15 : { 16 0 : InputParameters params = LatticeBoltzmannOperator::validParams(); 17 : 18 0 : params.addRequiredParam<std::vector<TensorInputBufferName>>( 19 : "inputs", "Names of input tensor buffers to stack."); 20 0 : params.addClassDescription("Stack given scalar tensor buffers and output vectorial tensor."); 21 : 22 0 : return params; 23 0 : } 24 : 25 0 : LBMStackTensors::LBMStackTensors(const InputParameters & parameters) 26 : : LatticeBoltzmannOperator(parameters), 27 0 : _buffer_names(getParam<std::vector<TensorInputBufferName>>("inputs")) 28 : { 29 : // check for duplicates 30 0 : auto hasDuplicates = [](const std::vector<std::string> & values) 31 : { 32 0 : std::set<std::string> s(values.begin(), values.end()); 33 0 : return values.size() != s.size(); 34 : }; 35 : 36 0 : if (hasDuplicates(_buffer_names)) 37 0 : paramError("inputs", "Duplicate buffer name."); 38 0 : } 39 : 40 : void 41 0 : LBMStackTensors::init() 42 : { 43 : // make sure output buffer has the same dimensions 44 0 : if (_u.dim() < 4) 45 0 : mooseError("Output buffer must be vectorial tensor."); 46 0 : } 47 : 48 : void 49 0 : LBMStackTensors::computeBuffer() 50 : { 51 : using torch::indexing::Slice; 52 : 53 : std::vector<torch::Tensor> tensor_vector; 54 0 : for (const auto & name : _buffer_names) 55 : { 56 0 : auto tensor_buffer = getInputBufferByName(name); 57 : 58 0 : if (tensor_buffer.dim() < 3) 59 0 : tensor_buffer = tensor_buffer.unsqueeze(2); 60 0 : if (tensor_buffer.dim() > 3) 61 : { 62 0 : std::string error_msg = "Input buffer "; 63 : error_msg.append(name); 64 : error_msg += " must be scalar"; 65 0 : mooseError(error_msg); 66 : } 67 0 : tensor_vector.push_back(tensor_buffer); 68 : } 69 : 70 : // Stack the tensors along a new dimension 71 0 : _u = torch::stack(tensor_vector, 3); 72 0 : }