Line data Source code
1 : /**********************************************************************/
2 : /* DO NOT MODIFY THIS HEADER */
3 : /* Swift, a Fourier spectral solver for MOOSE */
4 : /* */
5 : /* Copyright 2024 Battelle Energy Alliance, LLC */
6 : /* ALL RIGHTS RESERVED */
7 : /**********************************************************************/
8 :
9 : #pragma once
10 :
11 : #include <torch/torch.h>
12 : #include "libmesh/int_range.h"
13 :
14 : #define pti(tensor) MooseTensor::printTensorInfo(#tensor, tensor)
15 :
16 : #define pez(tensor) MooseTensor::printElementZero(#tensor, tensor)
17 :
18 : namespace MooseTensor
19 : {
20 :
21 : /// Passkey pattern key template
22 : template <typename T>
23 : class Key
24 : {
25 : friend T;
26 : Key() {}
27 : Key(Key const &) {}
28 : };
29 :
30 : void printTensorInfo(const torch::Tensor & x);
31 : void printTensorInfo(const std::string & name, const torch::Tensor & x);
32 :
33 : void printElementZero(const torch::Tensor & tensor);
34 : void printElementZero(const std::string & name, const torch::Tensor & tensor);
35 :
36 : void printBuffer(const torch::Tensor & t,
37 : const unsigned int & precision = 5,
38 : const unsigned int & index = 0);
39 :
40 : const torch::TensorOptions floatTensorOptions();
41 : const torch::TensorOptions complexFloatTensorOptions();
42 : const torch::TensorOptions intTensorOptions();
43 :
44 : /// unsqueeze(0) ndim times
45 : torch::Tensor unsqueeze0(const torch::Tensor & t, unsigned int ndim);
46 :
47 : torch::Tensor trans2(const torch::Tensor & A2);
48 : torch::Tensor ddot42(const torch::Tensor & A4, const torch::Tensor & B2);
49 : torch::Tensor ddot44(const torch::Tensor & A4, const torch::Tensor & B4);
50 : torch::Tensor dot22(const torch::Tensor & A2, const torch::Tensor & B2);
51 : torch::Tensor dot24(const torch::Tensor & A2, const torch::Tensor & B4);
52 : torch::Tensor dot42(const torch::Tensor & A4, const torch::Tensor & B2);
53 : torch::Tensor dyad22(const torch::Tensor & A2, const torch::Tensor & B2);
54 :
55 : template <typename T1, typename T2>
56 : std::tuple<torch::Tensor, unsigned int, double>
57 4 : conjugateGradientSolve(T1 A, torch::Tensor b, torch::Tensor x0, double tol, int64_t maxiter, T2 M)
58 : {
59 : // initialize solution guess
60 4 : torch::Tensor x = x0.defined() ? x0.clone() : torch::zeros_like(b);
61 :
62 : // norm of b (for relative tolerance)
63 4 : const double b_norm = torch::norm(b).cpu().template item<double>();
64 4 : if (b_norm == 0.0)
65 : // solution is zero if b is zero
66 : return {x, 0u, 0.0};
67 :
68 : // default max iterations
69 4 : if (!maxiter)
70 : maxiter = b.numel();
71 :
72 : // initial residual
73 8 : torch::Tensor r = b - A(x);
74 :
75 : // Apply preconditioner (or identity)
76 4 : torch::Tensor z = M(r); // z = M^{-1} r
77 :
78 : // initial search direction p
79 4 : torch::Tensor p = z.clone();
80 :
81 : // dot product (r, z)
82 4 : double rz_old = torch::sum(r * z).cpu().template item<double>();
83 :
84 : // CG iteration
85 : double res_norm;
86 12 : for (const auto k : libMesh::make_range(maxiter))
87 : {
88 : // compute matrix-vector product
89 0 : const auto Ap = A(p);
90 :
91 : // step size alpha
92 12 : double alpha = rz_old / torch::sum(p * Ap).cpu().template item<double>();
93 :
94 : // update solution
95 24 : x = x + alpha * p;
96 :
97 : // update residual
98 24 : r = r - alpha * Ap;
99 12 : res_norm = torch::norm(r).cpu().template item<double>(); // ||r||
100 :
101 : // std::cout << res_norm << '\n';
102 :
103 : // Converged to desired tolerance
104 12 : if (res_norm <= tol * b_norm)
105 4 : return {x, k + 1, res_norm};
106 :
107 : // apply preconditioner to new residual
108 8 : z = M(r);
109 8 : const auto rz_new = torch::sum(r * z).cpu().template item<double>();
110 :
111 : // update scalar beta
112 8 : double beta = rz_new / rz_old;
113 :
114 : // update search direction
115 16 : p = z + beta * p;
116 :
117 : // prepare for next iteration
118 : rz_old = rz_new;
119 : }
120 :
121 : // Reached max iterations without full convergence
122 : return {x, maxiter, res_norm};
123 : }
124 :
125 : template <typename T>
126 : std::tuple<torch::Tensor, unsigned int, double>
127 4 : conjugateGradientSolve(
128 : T A, torch::Tensor b, torch::Tensor x0 = {}, double tol = 1e-6, int64_t maxiter = 0)
129 : {
130 8 : return conjugateGradientSolve(A, b, x0, tol, maxiter, [](const torch::Tensor r) { return r; });
131 : }
132 :
133 : } // namespace MooseTensor
|