LCOV - code coverage report
Current view: top level - include/utils - SwiftUtils.h (source / functions) Hit Total Coverage
Test: idaholab/swift: #92 (25e020) with base b3cd84 Lines: 22 23 95.7 %
Date: 2025-09-10 17:10:32 Functions: 4 6 66.7 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /**********************************************************************/
       2             : /*                    DO NOT MODIFY THIS HEADER                       */
       3             : /*             Swift, a Fourier spectral solver for MOOSE             */
       4             : /*                                                                    */
       5             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       6             : /*                        ALL RIGHTS RESERVED                         */
       7             : /**********************************************************************/
       8             : 
       9             : #pragma once
      10             : 
      11             : #include <torch/torch.h>
      12             : #include "libmesh/int_range.h"
      13             : 
      14             : #define pti(tensor) MooseTensor::printTensorInfo(#tensor, tensor)
      15             : 
      16             : #define pez(tensor) MooseTensor::printElementZero(#tensor, tensor)
      17             : 
      18             : namespace MooseTensor
      19             : {
      20             : 
      21             : /// Passkey pattern key template
      22             : template <typename T>
      23             : class Key
      24             : {
      25             :   friend T;
      26             :   Key() {}
      27             :   Key(Key const &) {}
      28             : };
      29             : 
      30             : void printTensorInfo(const torch::Tensor & x);
      31             : void printTensorInfo(const std::string & name, const torch::Tensor & x);
      32             : 
      33             : void printElementZero(const torch::Tensor & tensor);
      34             : void printElementZero(const std::string & name, const torch::Tensor & tensor);
      35             : 
      36             : void printBuffer(const torch::Tensor & t,
      37             :                  const unsigned int & precision = 5,
      38             :                  const unsigned int & index = 0);
      39             : 
      40             : const torch::TensorOptions floatTensorOptions();
      41             : const torch::TensorOptions complexFloatTensorOptions();
      42             : const torch::TensorOptions intTensorOptions();
      43             : 
      44             : /// unsqueeze(0) ndim times
      45             : torch::Tensor unsqueeze0(const torch::Tensor & t, unsigned int ndim);
      46             : 
      47             : torch::Tensor trans2(const torch::Tensor & A2);
      48             : torch::Tensor ddot42(const torch::Tensor & A4, const torch::Tensor & B2);
      49             : torch::Tensor ddot44(const torch::Tensor & A4, const torch::Tensor & B4);
      50             : torch::Tensor dot22(const torch::Tensor & A2, const torch::Tensor & B2);
      51             : torch::Tensor dot24(const torch::Tensor & A2, const torch::Tensor & B4);
      52             : torch::Tensor dot42(const torch::Tensor & A4, const torch::Tensor & B2);
      53             : torch::Tensor dyad22(const torch::Tensor & A2, const torch::Tensor & B2);
      54             : 
      55             : template <typename T1, typename T2>
      56             : std::tuple<torch::Tensor, unsigned int, double>
      57           4 : conjugateGradientSolve(T1 A, torch::Tensor b, torch::Tensor x0, double tol, int64_t maxiter, T2 M)
      58             : {
      59             :   // initialize solution guess
      60           4 :   torch::Tensor x = x0.defined() ? x0.clone() : torch::zeros_like(b);
      61             : 
      62             :   // norm of b (for relative tolerance)
      63           4 :   const double b_norm = torch::norm(b).cpu().template item<double>();
      64           4 :   if (b_norm == 0.0)
      65             :     // solution is zero if b is zero
      66             :     return {x, 0u, 0.0};
      67             : 
      68             :   // default max iterations
      69           4 :   if (!maxiter)
      70             :     maxiter = b.numel();
      71             : 
      72             :   // initial residual
      73           8 :   torch::Tensor r = b - A(x);
      74             : 
      75             :   // Apply preconditioner (or identity)
      76           4 :   torch::Tensor z = M(r); // z = M^{-1} r
      77             : 
      78             :   // initial search direction p
      79           4 :   torch::Tensor p = z.clone();
      80             : 
      81             :   // dot product (r, z)
      82           4 :   double rz_old = torch::sum(r * z).cpu().template item<double>();
      83             : 
      84             :   // CG iteration
      85             :   double res_norm;
      86          12 :   for (const auto k : libMesh::make_range(maxiter))
      87             :   {
      88             :     // compute matrix-vector product
      89           0 :     const auto Ap = A(p);
      90             : 
      91             :     // step size alpha
      92          12 :     double alpha = rz_old / torch::sum(p * Ap).cpu().template item<double>();
      93             : 
      94             :     // update solution
      95          24 :     x = x + alpha * p;
      96             : 
      97             :     // update residual
      98          24 :     r = r - alpha * Ap;
      99          12 :     res_norm = torch::norm(r).cpu().template item<double>(); // ||r||
     100             : 
     101             :     // std::cout << res_norm << '\n';
     102             : 
     103             :     // Converged to desired tolerance
     104          12 :     if (res_norm <= tol * b_norm)
     105           4 :       return {x, k + 1, res_norm};
     106             : 
     107             :     // apply preconditioner to new residual
     108           8 :     z = M(r);
     109           8 :     const auto rz_new = torch::sum(r * z).cpu().template item<double>();
     110             : 
     111             :     // update scalar beta
     112           8 :     double beta = rz_new / rz_old;
     113             : 
     114             :     // update search direction
     115          16 :     p = z + beta * p;
     116             : 
     117             :     // prepare for next iteration
     118             :     rz_old = rz_new;
     119             :   }
     120             : 
     121             :   // Reached max iterations without full convergence
     122             :   return {x, maxiter, res_norm};
     123             : }
     124             : 
     125             : template <typename T>
     126             : std::tuple<torch::Tensor, unsigned int, double>
     127           4 : conjugateGradientSolve(
     128             :     T A, torch::Tensor b, torch::Tensor x0 = {}, double tol = 1e-6, int64_t maxiter = 0)
     129             : {
     130           8 :   return conjugateGradientSolve(A, b, x0, tol, maxiter, [](const torch::Tensor r) { return r; });
     131             : }
     132             : 
     133             : } // namespace MooseTensor

Generated by: LCOV version 1.14