LCOV - code coverage report
Current view: top level - src/tensor_computes - FFTQuasistaticElasticity.C (source / functions) Hit Total Coverage
Test: idaholab/swift: #92 (25e020) with base b3cd84 Lines: 0 45 0.0 %
Date: 2025-09-10 17:10:32 Functions: 0 3 0.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : 
       2             : /**********************************************************************/
       3             : /*                    DO NOT MODIFY THIS HEADER                       */
       4             : /*             Swift, a Fourier spectral solver for MOOSE             */
       5             : /*                                                                    */
       6             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       7             : /*                        ALL RIGHTS RESERVED                         */
       8             : /**********************************************************************/
       9             : 
      10             : #include "FFTQuasistaticElasticity.h"
      11             : #include "DomainAction.h"
      12             : 
// Register FFTQuasistaticElasticity under the "SwiftApp" application label so it can be
// created by name from input files (standard MOOSE object-registration macro).
registerMooseObject("SwiftApp", FFTQuasistaticElasticity);
      14             : 
      15             : InputParameters
      16           0 : FFTQuasistaticElasticity::validParams()
      17             : {
      18           0 :   InputParameters params = TensorOperatorBase::validParams();
      19           0 :   params.addClassDescription("FFT based monolithic homogeneous quasistatic elasticity solve.");
      20           0 :   params.addParam<std::vector<TensorOutputBufferName>>("displacements", "Displacements");
      21           0 :   params.addParam<TensorInputBufferName>("cbar", "FFT of concentration buffer");
      22           0 :   params.addRequiredParam<Real>("mu", "Lame mu");
      23           0 :   params.addRequiredParam<Real>("lambda", "Lame lambda");
      24           0 :   params.addRequiredParam<Real>("e0", "volumetric eigenstrain");
      25           0 :   return params;
      26           0 : }
      27             : 
      28           0 : FFTQuasistaticElasticity::FFTQuasistaticElasticity(const InputParameters & parameters)
      29             :   : TensorOperatorBase(parameters),
      30           0 :     _two_pi_i(torch::tensor(c10::complex<double>(0.0, 2.0 * pi),
      31           0 :                             MooseTensor::complexFloatTensorOptions())),
      32           0 :     _mu(getParam<Real>("mu")),
      33           0 :     _lambda(getParam<Real>("lambda")),
      34           0 :     _e0(getParam<Real>("e0")),
      35           0 :     _cbar(getInputBuffer("cbar"))
      36             : {
      37           0 :   for (const auto & name : getParam<std::vector<TensorOutputBufferName>>("displacements"))
      38           0 :     _displacements.push_back(&getOutputBufferByName(name));
      39             : 
      40           0 :   if (_domain.getDim() != _displacements.size())
      41           0 :     paramError("displacements", "Need one displacement variable per mesh dimension");
      42           0 : }
      43             : 
      44             : void
      45           0 : FFTQuasistaticElasticity::computeBuffer()
      46             : {
      47             :   // const auto & ux = *_displacements[0];
      48             :   // const auto & uy = *_displacements[1];
      49             :   // const auto & uz = *_displacements[2];
      50             : 
      51             :   // // FFT displacements
      52             :   // auto uxbar = _domain.fft(ux);
      53             :   // auto uybar = _domain.fft(uy);
      54             :   // auto uzbar = _domain.fft(uz);
      55             : 
      56             :   // strain tensor (in reciprocal space)
      57             :   // const auto exx = uxbar * _two_pi_i * _i;
      58             :   // const auto eyy = uybar * _two_pi_i * _j;
      59             :   // const auto ezz = uzbar * _two_pi_i * _k;
      60             :   // const auto exy = 0.5 * (uxbar * _two_pi_i * _j + uybar * _two_pi_i * _i);
      61             :   // const auto exz = 0.5 * (uxbar * _two_pi_i * _k + uzbar * _two_pi_i * _i);
      62             :   // const auto eyz = 0.5 * (uybar * _two_pi_i * _k + uzbar * _two_pi_i * _j);
      63             : 
      64             :   // precalculate these!
      65           0 :   const auto ul = 2.0 * _mu + _lambda;
      66           0 :   const auto kx = _two_pi_i * _i;
      67           0 :   const auto ky = _two_pi_i * _j;
      68           0 :   const auto kz = _two_pi_i * _k;
      69             : 
      70             :   // system matrix ()
      71           0 :   const auto Axx = ul * kx * kx + _mu * ky * ky + _mu * kz * kz;
      72             :   const auto s = Axx.sizes();
      73           0 :   const auto Axy = ((_lambda + _mu) * kx * ky).expand(s);
      74           0 :   const auto Axz = ((_lambda + _mu) * kx * kz).expand(s);
      75           0 :   const auto Ayy = ul * ky * ky + _mu * kx * kx + _mu * kz * kz;
      76           0 :   const auto Ayz = ((_lambda + _mu) * ky * kz).expand(s);
      77           0 :   const auto Azz = ul * kz * kz + _mu * kx * kx + _mu * ky * ky;
      78             : 
      79             :   // override Axx, Ayy, Azz for |k|=0
      80           0 :   Axx.index({0, 0, 0}) = 1.0;
      81           0 :   Ayy.index({0, 0, 0}) = 1.0;
      82           0 :   Azz.index({0, 0, 0}) = 1.0;
      83             : 
      84             :   // RHS (eigenstrain)
      85             :   using torch::stack;
      86           0 :   const auto e = 2.0 * _e0 * _cbar * (3.0 * _lambda + _mu);
      87           0 :   e.index({0, 0, 0}) = 0.0;
      88             : 
      89           0 :   const auto b = stack({kx * e, ky * e, kz * e}, -1);
      90             : 
      91           0 :   const auto A = stack(
      92           0 :       {stack({Axx, Axy, Axz}, -1), stack({Axy, Ayy, Ayz}, -1), stack({Axz, Ayz, Azz}, -1)}, -1);
      93             : 
      94             :   // solve
      95             :   const auto x = at::linalg_solve(A, b, true);
      96             : 
      97             :   // inverse transform the solution
      98             :   using torch::indexing::Slice;
      99           0 :   for (const auto i : make_range(3))
     100             :   {
     101           0 :     const auto slice = torch::squeeze(x.index({Slice(), Slice(), Slice(), i}), -1);
     102           0 :     *_displacements[i] = _domain.ifft(slice);
     103             :   }
     104           0 : }

Generated by: LCOV version 1.14