Line data Source code
1 : /**********************************************************************/ 2 : /* DO NOT MODIFY THIS HEADER */ 3 : /* Swift, a Fourier spectral solver for MOOSE */ 4 : /* */ 5 : /* Copyright 2024 Battelle Energy Alliance, LLC */ 6 : /* ALL RIGHTS RESERVED */ 7 : /**********************************************************************/ 8 : 9 : #include "SecantSolver.h" 10 : #include "TensorProblem.h" 11 : #include "DomainAction.h" 12 : 13 : registerMooseObject("SwiftApp", SecantSolver); 14 : 15 : InputParameters 16 0 : SecantSolver::validParams() 17 : { 18 0 : InputParameters params = SplitOperatorBase::validParams(); 19 0 : params.addClassDescription("Implicit secant solver time integration."); 20 0 : params.addParam<unsigned int>("substeps", 1, "secant solver substeps per time step."); 21 0 : params.addParam<unsigned int>("max_iterations", 30, "Maximum number of secant solver iteration."); 22 0 : params.addParam<Real>("relative_tolerance", 1e-9, "Convergence tolerance."); 23 0 : params.addParam<Real>("absolute_tolerance", 1e-9, "Convergence tolerance."); 24 0 : params.addParam<Real>("damping", 1.0, "Damping factor for the update step."); 25 0 : params.addParam<Real>( 26 0 : "dt_epsilon", 1e-4, "Semi-implicit stable timestep to bootstrap secant solve."); 27 0 : params.set<unsigned int>("substeps") = 1; 28 0 : params.addParam<bool>("verbose", false, "Show convergence history."); 29 0 : return params; 30 0 : } 31 : 32 0 : SecantSolver::SecantSolver(const InputParameters & parameters) 33 : : SplitOperatorBase(parameters), 34 : IterativeTensorSolverInterface(), 35 0 : _substeps(getParam<unsigned int>("substeps")), 36 0 : _max_iterations(getParam<unsigned int>("max_iterations")), 37 0 : _relative_tolerance(getParam<Real>("relative_tolerance")), 38 0 : _absolute_tolerance(getParam<Real>("absolute_tolerance")), 39 0 : _verbose(getParam<bool>("verbose")), 40 0 : _damping(getParam<Real>("damping")) 41 : { 42 : // no history required 43 0 : getVariables(0); 44 : 45 : const auto n = _variables.size(); 46 0 : if (n > 1) 47 0 : paramWarning("buffer", 48 : "The secant solver only work well for uncoupled variables. Use the BroydenSolver " 49 : "for solves with multiple coupled variables."); 50 0 : } 51 : 52 : void 53 0 : SecantSolver::computeBuffer() 54 : { 55 0 : for (_substep = 0; _substep < _substeps; ++_substep) 56 0 : secantSolve(); 57 0 : } 58 : 59 : void 60 0 : SecantSolver::secantSolve() 61 : { 62 : const auto n = _variables.size(); 63 0 : const auto dt = _dt / _substeps; 64 0 : std::vector<torch::Tensor> u_old(n); 65 0 : std::vector<torch::Tensor> Rprev(n); 66 0 : std::vector<torch::Tensor> uprev(n); 67 0 : std::vector<Real> R0norm(n); 68 : 69 0 : if (_verbose) 70 0 : _console << "Substep " << _substep << '\n'; 71 : 72 : // initial guess computed using semi-implicit Euler 73 0 : _compute->computeBuffer(); 74 0 : forwardBuffers(); 75 : 76 0 : for (const auto i : make_range(n)) 77 : { 78 0 : auto & u_out = _variables[i]._buffer; 79 0 : const auto & u = _variables[i]._reciprocal_buffer; 80 0 : const auto & N = _variables[i]._nonlinear_reciprocal; 81 0 : const auto * L = _variables[i]._linear_reciprocal; 82 : 83 0 : if (L) 84 0 : Rprev[i] = (N + *L * u) * dt; // u = u_old at this point! 85 : else 86 0 : Rprev[i] = N * dt; // u = u_old at this point! 87 : uprev[i] = u; 88 : 89 0 : R0norm[i] = torch::norm(Rprev[i]).item<double>(); 90 : 91 : // previous timestep solution 92 0 : if (_variables[i]._reciprocal_buffer.defined()) 93 : u_old[i] = _variables[i]._reciprocal_buffer; 94 : else 95 0 : u_old[i] = _domain.fft(_variables[i]._buffer); 96 : 97 : // now modify u_out 98 0 : const auto dt_epsilon = getParam<Real>("dt_epsilon"); 99 0 : if (L) 100 0 : u_out = _domain.ifft((u + dt_epsilon * N) / (1.0 - dt_epsilon * *L)); 101 : else 102 0 : u_out = _domain.ifft(u + dt_epsilon * N); 103 : 104 0 : if (_verbose) 105 0 : _console << "|R0|=" << R0norm[i] << std::endl; 106 : } 107 : 108 : // forward predict (on solver outputs) 109 0 : applyPredictors(); 110 : 111 : // Jacobian 112 : torch::Tensor J; 113 : // Residual 114 : torch::Tensor R; 115 : 116 : // secant iterations 117 : bool all_converged; 118 0 : for (_iterations = 0; _iterations < _max_iterations; ++_iterations) 119 : { 120 : // re-evaluate the solve compute 121 0 : _compute->computeBuffer(); 122 0 : forwardBuffers(); 123 : 124 : all_converged = true; 125 : 126 : // integrate all variables 127 0 : for (const auto i : make_range(n)) 128 : { 129 0 : auto & u_out = _variables[i]._buffer; 130 0 : const auto & u = _variables[i]._reciprocal_buffer; 131 0 : const auto & N = _variables[i]._nonlinear_reciprocal; 132 0 : const auto * L = _variables[i]._linear_reciprocal; 133 : 134 : // residual in reciprocal space 135 0 : if (L) 136 0 : R = (N + *L * u) * dt + u_old[i] - u; 137 : else 138 0 : R = N * dt + u_old[i] - u; 139 : 140 : // avoid NaN 141 0 : const auto dx = u - uprev[i]; 142 0 : const auto dy = R - Rprev[i]; 143 0 : auto du = torch::where(dy != 0, -R * dx / dy, 0.0); 144 : 145 : uprev[i] = u; 146 : Rprev[i] = R; 147 : 148 0 : if (_damping == 1.0) 149 0 : u_out = _domain.ifft(u + du); 150 : else 151 0 : u_out = _domain.ifft(u + du * _damping); 152 : 153 0 : const auto Rnorm = torch::norm(R).item<double>(); 154 : 155 0 : if (_verbose) 156 : { 157 0 : const auto unorm = torch::norm(du).item<double>(); 158 0 : _console << _iterations << " |du| = " << unorm << " |R|=" << Rnorm << std::endl; 159 : } 160 : 161 : // nan check 162 0 : if (std::isnan(Rnorm)) 163 : { 164 : all_converged = false; 165 0 : _iterations = _max_iterations; 166 0 : _console << "NaN detected, aborting solve.\n"; 167 : break; 168 : } 169 : 170 : // relative convergence check 171 : all_converged = 172 0 : all_converged && (Rnorm < _absolute_tolerance || Rnorm / R0norm[i] < _relative_tolerance); 173 : } 174 : 175 0 : if (all_converged) 176 : { 177 : // std::cout << "Secant solve converged after " << _iterations << " iterations. |R|=" <<Rnorm 178 : // << " |R|/|R0|=" << Rnorm / R0norm << '\n'; 179 0 : _is_converged = true; 180 0 : break; 181 : } 182 : } 183 : 184 0 : if (!all_converged) 185 : { 186 0 : _console << "Solve not converged.\n"; 187 : 188 : // restore old solution (TODO: fix time, etc) 189 0 : for (const auto i : make_range(n)) 190 0 : _variables[i]._buffer = _domain.ifft(u_old[i]); 191 : 192 0 : _is_converged = false; 193 : } 194 0 : }