LCOV - code coverage report
Current view: top level - src/tensor_computes - ComputeDisplacements.C (source / functions)
Test: idaholab/swift: #92 (25e020) with base b3cd84
Date: 2025-09-10 17:10:32

              Hit   Total   Coverage
  Lines:        0      50      0.0 %
  Functions:    0       7      0.0 %

Legend: Lines: hit / not hit

          Line data    Source code
       1             : /**********************************************************************/
       2             : /*                    DO NOT MODIFY THIS HEADER                       */
       3             : /*             Swift, a Fourier spectral solver for MOOSE             */
       4             : /*                                                                    */
       5             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       6             : /*                        ALL RIGHTS RESERVED                         */
       7             : /**********************************************************************/
       8             : 
       9             : #include "ComputeDisplacements.h"
      10             : #include "MooseError.h"
      11             : #include "DomainAction.h"
      12             : #include "SwiftUtils.h"
      13             : #include <ATen/core/TensorBody.h>
      14             : 
      15             : registerMooseObject("SwiftApp", ComputeDisplacements);
      16             : 
      17             : InputParameters
      18           0 : ComputeDisplacements::validParams()
      19             : {
      20           0 :   InputParameters params = TensorOperator<>::validParams();
      21           0 :   params.addClassDescription("Compute updated displacements from the deformation gradient tensor.");
      22           0 :   params.addRequiredParam<TensorInputBufferName>("F", "Deformation gradient tensor.");
      23           0 :   return params;
      24           0 : }
      25             : 
      26           0 : ComputeDisplacements::ComputeDisplacements(const InputParameters & parameters)
      27           0 :   : TensorOperator<>(parameters), _deformation_gradient_tensor(getInputBuffer("F"))
      28             : {
      29           0 : }
      30             : 
      31             : void
      32           0 : saveDebug(const torch::Tensor & debug)
      33             : {
      34             :   // dump Ghat
      35           0 :   MooseTensor::printTensorInfo("debug", debug);
      36             : 
      37           0 :   std::size_t raw_size = debug.numel();
      38             :   char * raw_ptr = static_cast<char *>(debug.data_ptr());
      39             : 
      40           0 :   if (debug.dtype() == torch::kFloat32)
      41           0 :     raw_size *= 4;
      42           0 :   else if (debug.dtype() == torch::kFloat64)
      43           0 :     raw_size *= 8;
      44             :   else
      45           0 :     mooseError("Unsupported output type");
      46             : 
      47           0 :   auto file = std::fstream("debug.bin", std::ios::out | std::ios::binary);
      48           0 :   file.write(raw_ptr, raw_size);
      49           0 :   file.close();
      50           0 : }
      51             : 
      52             : void
      53           0 : ComputeDisplacements::computeBuffer()
      54             : {
      55           0 :   const auto & F = _deformation_gradient_tensor;
      56           0 :   if (!F.defined())
      57           0 :     return;
      58             : 
      59             :   // compute strain gradient tensor H
      60             :   mooseAssert(
      61             :       F.size(-1) == _dim && F.size(-2) == _dim,
      62             :       "Value dimensions of the deformation gradient tensor to not match the problem dimension");
      63             : 
      64           0 :   const auto I3 = torch::eye(_dim, F.options());
      65             : 
      66           0 :   const auto Fbox = _domain.average(F);
      67             : 
      68             :   // const auto Hbar = _domain.fft(F - MooseTensor::unsqueeze0(Fbox, _dim));
      69           0 :   const auto Hbar = _domain.fft(F - Fbox);
      70             : 
      71           0 :   const auto q = _domain.getKGrid() * (-_imaginary);
      72           0 :   const auto Q = _domain.getKSquare();
      73             : 
      74           0 :   const auto numer = torch::einsum("...ij,...j->...i", {Hbar, q});
      75             :   const auto denom = Q.unsqueeze(-1);
      76             : 
      77           0 :   const auto u_periodic_bar = torch::where(denom == 0, 0.0, numer / denom);
      78             : 
      79             :   torch::Tensor u_periodic;
      80             :   torch::Tensor u_aff;
      81             : 
      82           0 :   const auto & X = _domain.getXGrid();
      83           0 :   u_aff = torch::einsum("ij,...j->...i", {Fbox - I3, X});
      84           0 :   u_periodic = _domain.ifft(u_periodic_bar);
      85             : 
      86           0 :   std::vector<int64_t> shape(_domain.getShape().begin(), _domain.getShape().end());
      87           0 :   for (auto & n : shape)
      88           0 :     n++;
      89             : 
      90             :   namespace tf = torch::nn::functional;
      91           0 :   auto interpolate = [&](auto mode)
      92             :   {
      93           0 :     _u = tf::interpolate((u_aff + u_periodic).movedim(-1, 0).unsqueeze(1),
      94           0 :                          tf::InterpolateFuncOptions().size(shape).mode(mode).align_corners(true))
      95             :              .squeeze(1)
      96             :              .movedim(0, -1);
      97           0 :   };
      98             : 
      99           0 :   if (_dim == 3)
     100           0 :     interpolate(torch::kTrilinear);
     101           0 :   else if (_dim == 2)
     102           0 :     interpolate(torch::kBilinear);
     103           0 :   else if (_dim == 1)
     104           0 :     interpolate(torch::kLinear);
     105             :   else
     106           0 :     mooseError("Unsupported problem dimension");
     107           0 : }

Generated by: LCOV version 1.14