Line data Source code
1 : /**********************************************************************/
2 : /* DO NOT MODIFY THIS HEADER */
3 : /* Swift, a Fourier spectral solver for MOOSE */
4 : /* */
5 : /* Copyright 2024 Battelle Energy Alliance, LLC */
6 : /* ALL RIGHTS RESERVED */
7 : /**********************************************************************/
8 :
#include "SwiftUtils.h"
#include "SwiftApp.h"
#include "MooseUtils.h"
#include "Moose.h"

#include <iomanip>
#include <sstream>
13 :
14 : namespace MooseTensor
15 : {
16 :
17 : struct TorchDeviceSingleton
18 : {
19 774 : static bool isSupported(torch::Dtype dtype, torch::Device device)
20 : {
21 : try
22 : {
23 774 : auto tensor = torch::zeros({1}, torch::dtype(dtype).device(device));
24 : return true;
25 : }
26 0 : catch (const std::exception &)
27 : {
28 : return false;
29 0 : }
30 : }
31 :
32 258 : TorchDeviceSingleton()
33 260 : : _device_string(torchDevice().empty() ? (torch::cuda::is_available()
34 : ? "cuda"
35 1 : : (torch::mps::is_available() ? "mps" : "cpu"))
36 : : torchDevice()),
37 258 : _device(_device_string),
38 258 : _floating_precision(precision().empty() ? "DEVICE_DEFAULT" : precision()),
39 0 : _float_dtype(_floating_precision == "DEVICE_DEFAULT" || _floating_precision == "DOUBLE"
40 258 : ? (isSupported(torch::kFloat64, _device) ? torch::kFloat64 : torch::kFloat32)
41 : : torch::kFloat32),
42 258 : _complex_float_dtype(isSupported(torch::kComplexDouble, _device) ? torch::kComplexDouble
43 : : torch::kComplexFloat),
44 258 : _int_dtype(isSupported(torch::kInt64, _device) ? torch::kInt64 : torch::kInt32)
45 : {
46 : mooseInfo("Running on '", _device_string, "'.");
47 258 : if (_float_dtype == torch::kFloat64)
48 : mooseInfo("Device supports double precision floating point numbers.");
49 : else
50 : mooseWarning("Running with single precision floating point numbers");
51 258 : }
52 :
53 : const std::string _device_string;
54 : const torch::Device _device;
55 : const std::string _floating_precision;
56 : const torch::Dtype _float_dtype;
57 : const torch::Dtype _complex_float_dtype;
58 : const torch::Dtype _int_dtype;
59 : };
60 :
61 : void
62 0 : printTensorInfo(const torch::Tensor & x)
63 : {
64 : Moose::out << " dimension: " << x.dim() << std::endl;
65 0 : Moose::out << " shape: " << x.sizes() << std::endl;
66 0 : Moose::out << " dtype: " << x.dtype() << std::endl;
67 0 : Moose::out << " device: " << x.device() << std::endl;
68 0 : Moose::out << " requires grad: " << (x.requires_grad() ? "true" : "false") << std::endl;
69 : Moose::out << std::endl;
70 0 : }
71 :
72 : void
73 0 : printTensorInfo(const std::string & name, const torch::Tensor & x)
74 : {
75 : Moose::out << "============== " << name << " ==============\n";
76 0 : printTensorInfo(x);
77 : Moose::out << std::endl;
78 0 : }
79 :
80 : void
81 0 : printElementZero(const torch::Tensor & tensor)
82 : {
83 : // Access the element at all zero indices
84 0 : auto element = tensor[0][0];
85 : // for (int i = 1; i < tensor.dim(); ++i)
86 : // element = element[0];
87 :
88 : Moose::out << element << std::endl;
89 0 : }
90 :
91 : void
92 0 : printElementZero(const std::string & name, const torch::Tensor & x)
93 : {
94 : Moose::out << "============== " << name << " ==============\n";
95 0 : printElementZero(x);
96 : Moose::out << std::endl;
97 0 : }
98 :
99 : const torch::TensorOptions
100 2452 : floatTensorOptions()
101 : {
102 2452 : const static TorchDeviceSingleton ts;
103 : return torch::TensorOptions()
104 : .dtype(ts._float_dtype)
105 : .layout(torch::kStrided)
106 : .memory_format(torch::MemoryFormat::Contiguous)
107 2452 : .pinned_memory(false)
108 2452 : .device(ts._device)
109 2452 : .requires_grad(false);
110 : }
111 :
112 : const torch::TensorOptions
113 1338 : complexFloatTensorOptions()
114 : {
115 1338 : const static TorchDeviceSingleton ts;
116 : return torch::TensorOptions()
117 : .dtype(ts._complex_float_dtype)
118 : .layout(torch::kStrided)
119 : .memory_format(torch::MemoryFormat::Contiguous)
120 1338 : .pinned_memory(false)
121 1338 : .device(ts._device)
122 1338 : .requires_grad(false);
123 : }
124 :
125 : const torch::TensorOptions
126 0 : intTensorOptions()
127 : {
128 0 : const static TorchDeviceSingleton ts;
129 : return torch::TensorOptions()
130 : .dtype(ts._int_dtype)
131 : .layout(torch::kStrided)
132 : .memory_format(torch::MemoryFormat::Contiguous)
133 0 : .pinned_memory(false)
134 0 : .device(ts._device)
135 0 : .requires_grad(false);
136 : }
137 :
138 : torch::Tensor
139 0 : unsqueeze0(const torch::Tensor & t, unsigned int ndim)
140 : {
141 : torch::Tensor u = t;
142 0 : for (unsigned int i = 0; i < ndim; ++i)
143 0 : u = u.unsqueeze(0);
144 0 : return u;
145 : }
146 :
147 : torch::Tensor
148 0 : trans2(const torch::Tensor & A2)
149 : {
150 0 : return torch::einsum("...ij ->...ji ", {A2});
151 : }
152 :
153 : torch::Tensor
154 0 : ddot42(const torch::Tensor & A4, const torch::Tensor & B2)
155 : {
156 0 : return torch::einsum("...ijkl,...lk ->...ij ", {A4, B2});
157 : }
158 :
159 : torch::Tensor
160 0 : ddot44(const torch::Tensor & A4, const torch::Tensor & B4)
161 : {
162 0 : return torch::einsum("...ijkl,...lkmn->...ijmn", {A4, B4});
163 : }
164 :
165 : torch::Tensor
166 0 : dot22(const torch::Tensor & A2, const torch::Tensor & B2)
167 : {
168 0 : return torch::einsum("...ij ,...jk ->...ik ", {A2, B2});
169 : }
170 :
171 : torch::Tensor
172 0 : dot24(const torch::Tensor & A2, const torch::Tensor & B4)
173 : {
174 0 : return torch::einsum("...ij ,...jkmn->...ikmn", {A2, B4});
175 : }
176 :
177 : torch::Tensor
178 0 : dot42(const torch::Tensor & A4, const torch::Tensor & B2)
179 : {
180 0 : return torch::einsum("...ijkl,...lm ->...ijkm", {A4, B2});
181 : }
182 :
183 : torch::Tensor
184 0 : dyad22(const torch::Tensor & A2, const torch::Tensor & B2)
185 : {
186 0 : return torch::einsum("...ij ,...kl ->...ijkl", {A2, B2});
187 : }
188 :
189 : void
190 0 : printBuffer(const torch::Tensor & t, const unsigned int & precision, const unsigned int & index)
191 : {
192 : /**
193 : * Print the entire field for debugging
194 : */
195 : torch::Tensor field = t;
196 : // for buffers higher than 3 dimensions, such as distribution functions
197 : // pass an index to print or call this method repeatedly to print all directions
198 : // higher than 4 dimensions is not supported
199 :
200 0 : if (t.dim() == 4)
201 0 : field = t.select(3, index);
202 :
203 0 : if (t.dim() > 4)
204 0 : mooseError("Higher than 4 dimensional tensor buffers are not supported.");
205 :
206 0 : if (t.dim() == 2)
207 : {
208 0 : for (int64_t j = 0; j < field.size(1); j++)
209 : {
210 0 : for (int64_t k = 0; k < field.size(0); k++)
211 0 : std::cout << std::fixed << std::setprecision(precision) << field[k][j].item<Real>() << " ";
212 : std::cout << std::endl;
213 : }
214 : }
215 :
216 0 : else if (t.dim() >= 3)
217 : {
218 0 : for (int64_t i = 0; i < field.size(2); i++)
219 : {
220 0 : for (int64_t j = 0; j < field.size(1); j++)
221 : {
222 0 : for (int64_t k = 0; k < field.size(0); k++)
223 0 : std::cout << std::fixed << std::setprecision(precision) << field[k][j][i].item<Real>()
224 0 : << " ";
225 : std::cout << std::endl;
226 : }
227 : std::cout << std::endl;
228 : }
229 : }
230 :
231 0 : else if (t.dim() == 1)
232 : {
233 0 : for (int64_t k = 0; k < field.size(0); k++)
234 0 : std::cout << std::fixed << std::setprecision(precision) << field[k].item<Real>() << " ";
235 : std::cout << std::endl;
236 : }
237 :
238 : else
239 0 : mooseError("Unsupported output dimension");
240 0 : }
241 :
242 : } // namespace MooseTensor
|