Line data Source code
1 : /**********************************************************************/ 2 : /* DO NOT MODIFY THIS HEADER */ 3 : /* Swift, a Fourier spectral solver for MOOSE */ 4 : /* */ 5 : /* Copyright 2024 Battelle Energy Alliance, LLC */ 6 : /* ALL RIGHTS RESERVED */ 7 : /**********************************************************************/ 8 : 9 : #include "TensorSolver.h" 10 : #include "TensorProblem.h" 11 : #include "SwiftTypes.h" 12 : 13 : InputParameters 14 222 : TensorSolver::validParams() 15 : { 16 222 : InputParameters params = TensorOperatorBase::validParams(); 17 222 : params.registerBase("TensorSolver"); 18 444 : params.addParam<TensorComputeName>( 19 : "root_compute", 20 : "Primary compute object that updates the buffers. This is usually a " 21 : "ComputeGroup object. A ComputeGroup encompassing all computes will be generated " 22 : "automatically if the user does not provide this parameter."); 23 222 : params.addClassDescription("TensorSolver object."); 24 : 25 444 : params.addParam<std::vector<TensorOutputBufferName>>( 26 : "forward_buffer", 27 : {}, 28 : "These buffers are updated with the corresponding buffers from forward_buffer_old. No " 29 : "integration is performed. Buffer forwarding is used only to resolve cyclic dependencies."); 30 444 : params.addParam<std::vector<TensorInputBufferName>>( 31 : "forward_buffer_new", {}, "New values to update `forward_buffer` with."); 32 : 33 222 : return params; 34 0 : } 35 : 36 110 : TensorSolver::TensorSolver(const InputParameters & parameters) 37 110 : : TensorOperatorBase(parameters), _dt(_tensor_problem.dt()), _dt_old(_tensor_problem.dtOld()) 38 : { 39 : const auto & forward_buffer_names = getParam<TensorOutputBufferName, TensorOutputBufferName>( 40 220 : "forward_buffer", "forward_buffer_new"); 41 110 : for (const auto & [forward_buffer, forward_buffer_new] : forward_buffer_names) 42 0 : _forwarded_buffers.emplace_back(getOutputBufferByName(forward_buffer), 43 : getInputBufferByName(forward_buffer_new)); 44 110 : } 45 : 46 : const std::vector<torch::Tensor> & 47 0 : TensorSolver::getBufferOld(const std::string & param, unsigned int max_states) 48 : { 49 0 : return getBufferOldByName(getParam<TensorInputBufferName>(param), max_states); 50 : } 51 : 52 : const std::vector<torch::Tensor> & 53 110 : TensorSolver::getBufferOldByName(const TensorInputBufferName & buffer_name, unsigned int max_states) 54 : { 55 110 : return _tensor_problem.getBufferOld(buffer_name, max_states); 56 : } 57 : 58 : void 59 110 : TensorSolver::updateDependencies() 60 : { 61 : // the compute that's being solved for (usually a ComputeGroup) 62 110 : const auto & root_name = getParam<TensorComputeName>("root_compute"); 63 410 : for (const auto & cmp : _tensor_problem.getComputes()) 64 410 : if (cmp->name() == root_name) 65 : { 66 : _compute = cmp; 67 110 : _compute->updateDependencies(); 68 110 : return; 69 : } 70 : 71 0 : paramError("root_compute", "Compute object not found."); 72 : } 73 : 74 : void 75 82784 : TensorSolver::forwardBuffers() 76 : { 77 82784 : for (const auto & [forward_buffer, forward_buffer_new] : _forwarded_buffers) 78 : forward_buffer = forward_buffer_new; 79 82784 : }