Line data Source code
1 : //* This file is part of the MOOSE framework
2 : //* https://www.mooseframework.org
3 : //*
4 : //* All rights reserved, see COPYRIGHT for full restrictions
5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 : //*
7 : //* Licensed under LGPL 2.1, please see LICENSE for details
8 : //* https://www.gnu.org/licenses/lgpl-2.1.html
9 :
10 : #pragma once
11 :
12 : #include "KokkosDatum.h"
13 :
14 : #include "MooseVariableFieldBase.h"
15 :
16 : namespace Moose::Kokkos
17 : {
18 :
19 : /**
20 : * The Kokkos wrapper classes for MOOSE-like shape function access
21 : */
22 : ///@{
23 : class VariablePhiValue
24 : {
25 : public:
26 : /**
27 : * Get the current shape function
28 : * @param datum The AssemblyDatum object of the current thread
29 : * @param i The element-local DOF index
30 : * @param qp The local quadrature point index
31 : * @returns The shape function
32 : */
33 4906888 : KOKKOS_FUNCTION Real operator()(AssemblyDatum & datum, unsigned int i, unsigned int qp) const
34 : {
35 4906888 : auto & elem = datum.elem();
36 4906888 : auto side = datum.side();
37 4906888 : auto fe = datum.jfe();
38 :
39 0 : return side == libMesh::invalid_uint
40 4906888 : ? datum.assembly().getPhi(elem.subdomain, elem.type, fe)(i, qp)
41 4906888 : : datum.assembly().getPhiFace(elem.subdomain, elem.type, fe)(side)(i, qp);
42 : }
43 : };
44 :
45 : class VariablePhiGradient
46 : {
47 : public:
48 : /**
49 : * Get the gradient of the current shape function
50 : * @param datum The AssemblyDatum object of the current thread
51 : * @param i The element-local DOF index
52 : * @param qp The local quadrature point index
53 : * @returns The gradient of the shape function
54 : */
55 25385984 : KOKKOS_FUNCTION Real3 operator()(AssemblyDatum & datum, unsigned int i, unsigned int qp) const
56 : {
57 25385984 : auto & elem = datum.elem();
58 25385984 : auto side = datum.side();
59 25385984 : auto fe = datum.jfe();
60 :
61 0 : return datum.J(qp) *
62 : (side == libMesh::invalid_uint
63 25385984 : ? datum.assembly().getGradPhi(elem.subdomain, elem.type, fe)(i, qp)
64 50771968 : : datum.assembly().getGradPhiFace(elem.subdomain, elem.type, fe)(side)(i, qp));
65 : }
66 : };
67 :
68 : class VariableTestValue
69 : {
70 : public:
71 : /**
72 : * Get the current test function
73 : * @param datum The AssemblyDatum object of the current thread
74 : * @param i The element-local DOF index
75 : * @param qp The local quadrature point index
76 : * @returns The test function
77 : */
78 113200760 : KOKKOS_FUNCTION Real operator()(AssemblyDatum & datum, unsigned int i, unsigned int qp) const
79 : {
80 113200760 : auto & elem = datum.elem();
81 113200760 : auto side = datum.side();
82 113200760 : auto fe = datum.ife();
83 :
84 0 : return side == libMesh::invalid_uint
85 113200760 : ? datum.assembly().getPhi(elem.subdomain, elem.type, fe)(i, qp)
86 113200760 : : datum.assembly().getPhiFace(elem.subdomain, elem.type, fe)(side)(i, qp);
87 : }
88 : };
89 :
90 : class VariableTestGradient
91 : {
92 : public:
93 : /**
94 : * Get the gradient of the current test function
95 : * @param datum The AssemblyDatum object of the current thread
96 : * @param i The element-local DOF index
97 : * @param qp The local quadrature point index
98 : * @returns The gradient of the test function
99 : */
100 68215168 : KOKKOS_FUNCTION Real3 operator()(AssemblyDatum & datum, unsigned int i, unsigned int qp) const
101 : {
102 68215168 : auto & elem = datum.elem();
103 68215168 : auto side = datum.side();
104 68215168 : auto fe = datum.ife();
105 :
106 0 : return datum.J(qp) *
107 : (side == libMesh::invalid_uint
108 68215168 : ? datum.assembly().getGradPhi(elem.subdomain, elem.type, fe)(i, qp)
109 136430336 : : datum.assembly().getGradPhiFace(elem.subdomain, elem.type, fe)(side)(i, qp));
110 : }
111 : };
112 :
113 : using ADVariablePhiValue = VariablePhiValue;
114 : using ADVariablePhiGradient = VariablePhiGradient;
115 : using ADVariableTestValue = VariableTestValue;
116 : using ADVariableTestGradient = VariableTestGradient;
117 :
118 : ///@}
119 :
120 : /**
121 : * The Kokkos wrapper classes for MOOSE-like variable value access
122 : */
123 : ///@{
124 : template <bool is_ad>
125 : class VariableValueTempl
126 : {
127 : using real_type = std::conditional_t<is_ad, ADReal, Real>;
128 :
129 : public:
130 : /**
131 : * Default constructor
132 : */
133 8867 : VariableValueTempl() = default;
134 : /**
135 : * Constructor
136 : * @param var The Kokkos variable
137 : * @param dof Whether to get DOF values
138 : */
139 2868 : VariableValueTempl(Variable var, bool dof = false) : _var(var), _dof(dof) {}
140 : /**
141 : * Constructor
142 : * @param var The MOOSE variable
143 : * @param tag The vector tag name
144 : * @param dof Whether to get DOF values
145 : */
146 8812 : VariableValueTempl(const MooseVariableFieldBase & var,
147 : const TagName & tag = Moose::SOLUTION_TAG,
148 4012 : bool dof = false)
149 4800 : : _var(var, tag), _dof(dof)
150 : {
151 8812 : }
152 : /**
153 : * Constructor
154 : * @param vars The MOOSE variables
155 : * @param tag The vector tag name
156 : * @param dof Whether to get DOF values
157 : */
158 : ///@{
159 : VariableValueTempl(const std::vector<const MooseVariableFieldBase *> & vars,
160 : const TagName & tag = Moose::SOLUTION_TAG,
161 : bool dof = false)
162 : : _var(vars, tag), _dof(dof)
163 : {
164 : }
165 : VariableValueTempl(const std::vector<MooseVariableFieldBase *> & vars,
166 : const TagName & tag = Moose::SOLUTION_TAG,
167 : bool dof = false)
168 : : _var(vars, tag), _dof(dof)
169 : {
170 : }
171 : ///@}
172 :
173 : /**
174 : * Copy constructor for parallel dispatch
175 : */
176 : VariableValueTempl(const VariableValueTempl<is_ad> & object);
177 : /**
178 : * Copy assignment operator
179 : */
180 : VariableValueTempl<is_ad> & operator=(const VariableValueTempl<is_ad> & object);
181 :
182 : /**
183 : * Get whether the variable was coupled
184 : * @returns Whether the variable was coupled
185 : */
186 28032 : KOKKOS_FUNCTION operator bool() const { return _var.coupled(); }
187 :
188 : /**
189 : * Get the current variable value
190 : * @param datum The Datum object of the current thread
191 : * @param idx The local quadrature point or DOF index
192 : * @param comp The variable component
193 : * @returns The variable value
194 : */
195 2592699 : KOKKOS_FUNCTION auto operator()(Datum & datum, unsigned int idx, unsigned int comp = 0) const
196 : {
197 2592699 : return get(datum, idx, comp);
198 : }
199 :
200 : /**
201 : * Get the current variable value
202 : * @param datum The AssemblyDatum object of the current thread
203 : * @param idx The local quadrature point or DOF index
204 : * @param comp The variable component
205 : * @returns The variable value
206 : */
207 : KOKKOS_FUNCTION auto
208 : operator()(AssemblyDatum & datum, unsigned int idx, unsigned int comp = 0) const;
209 :
210 : /**
211 : * Get the Kokkos variable
212 : * @returns The Kokkos variable
213 : */
214 460146 : KOKKOS_FUNCTION const Variable & variable() const { return _var; }
215 :
216 : private:
217 : /**
218 : * Get the current variable value
219 : * @param datum The Datum object of the current thread
220 : * @param idx The local quadrature point or DOF index
221 : * @param comp The variable component
222 : * @param seed The derivative seed (only meaningful for AD)
223 : * @returns The variable value
224 : */
225 : KOKKOS_FUNCTION auto
226 : get(Datum & datum, unsigned int idx, unsigned int comp = 0, Real seed = 0) const;
227 :
228 : /**
229 : * Coupled Kokkos variable
230 : */
231 : Variable _var;
232 : /**
233 : * Derivative seed of each component for AD
234 : */
235 : Array<Real> _seed;
236 : /**
237 : * Flag whether DOF values are requested
238 : */
239 : bool _dof = false;
240 : };
241 :
242 : template <bool is_ad>
243 587081 : VariableValueTempl<is_ad>::VariableValueTempl(const VariableValueTempl<is_ad> & object)
244 334813 : : _var(object._var), _seed(object._seed), _dof(object._dof)
245 : {
246 : if constexpr (is_ad)
247 59425 : if (_var.coupled())
248 : {
249 57533 : if (!_seed.isAlloc())
250 57533 : _seed.create(_var.components());
251 :
252 115066 : for (unsigned int comp = 0; comp < _var.components(); ++comp)
253 57533 : _seed[comp] =
254 29058 : _var.dot() ? _var.mooseVar(comp)->sys().duDotDu(_var.var(comp)) : (_var.old() ? 0 : 1);
255 :
256 57533 : _seed.copyToDevice();
257 : }
258 587081 : }
259 :
260 : template <bool is_ad>
261 : VariableValueTempl<is_ad> &
262 8867 : VariableValueTempl<is_ad>::operator=(const VariableValueTempl<is_ad> & object)
263 : {
264 8867 : _var = object._var;
265 8867 : _dof = object._dof;
266 :
267 8867 : return *this;
268 : }
269 :
270 : template <bool is_ad>
271 : KOKKOS_FUNCTION auto
272 38051287 : VariableValueTempl<is_ad>::operator()(AssemblyDatum & datum,
273 : unsigned int idx,
274 : unsigned int comp) const
275 : {
276 : if constexpr (is_ad)
277 : {
278 2369564 : Real seed =
279 2369564 : datum.do_derivatives() && _var.coupled() && _var.sys(comp) == datum.sys() ? _seed[comp] : 0;
280 :
281 2369564 : return get(datum, idx, comp, seed);
282 : }
283 : else
284 35681723 : return get(datum, idx, comp);
285 : }
286 :
287 : template <bool is_ad>
288 : KOKKOS_FUNCTION auto
289 40643986 : VariableValueTempl<is_ad>::get(Datum & datum,
290 : unsigned int idx,
291 : unsigned int comp,
292 : [[maybe_unused]] Real seed) const
293 : {
294 : KOKKOS_ASSERT(_var.initialized());
295 :
296 2369564 : real_type value;
297 :
298 40643986 : if (_var.coupled())
299 : {
300 40612718 : auto & sys = datum.system(_var.sys(comp));
301 40612718 : auto var = _var.var(comp);
302 40612718 : auto tag = _var.tag();
303 :
304 40612718 : if (_dof)
305 : {
306 : unsigned int dof;
307 :
308 4355358 : if (datum.isNodal())
309 : {
310 4354758 : auto node = datum.node();
311 4354758 : dof = sys.getNodeLocalDofIndex(node, 0, var);
312 : }
313 : else
314 : {
315 600 : auto elem = datum.elem().id;
316 600 : dof = sys.getElemLocalDofIndex(elem, idx, var);
317 : }
318 :
319 : if constexpr (is_ad)
320 68852 : value = sys.getVectorDofADValue(dof, tag, seed);
321 : else
322 4286506 : value = sys.getVectorDofValue(dof, tag);
323 : }
324 : else
325 : {
326 36257360 : auto & elem = datum.elem();
327 36257360 : auto side = datum.side();
328 :
329 : if constexpr (is_ad)
330 2285556 : value = side == libMesh::invalid_uint
331 4571112 : ? sys.getVectorQpADValue(elem, datum.qpOffset(), idx, var, tag, seed)
332 : : sys.getVectorQpADValueFace(elem, side, idx, var, tag, seed);
333 : else
334 33971804 : value = side == libMesh::invalid_uint
335 33971804 : ? sys.getVectorQpValue(elem, datum.qpOffset() + idx, var, tag)
336 265208 : : sys.getVectorQpValueFace(elem, side, idx, var, tag);
337 : }
338 : }
339 : else
340 31268 : value = _var.value(comp);
341 :
342 40643986 : return value;
343 0 : }
344 :
345 : template <bool is_ad>
346 : class VariableGradientTempl
347 : {
348 : using real3_type = std::conditional_t<is_ad, ADReal3, Real3>;
349 :
350 : public:
351 : /**
352 : * Default constructor
353 : */
354 : VariableGradientTempl() = default;
355 : /**
356 : * Constructor
357 : * @param var The Kokkos variable
358 : */
359 665 : VariableGradientTempl(Variable var) : _var(var) {}
360 : /**
361 : * Constructor
362 : * @param var The MOOSE variable
363 : * @param tag The vector tag name
364 : */
365 3690 : VariableGradientTempl(const MooseVariableFieldBase & var,
366 1650 : const TagName & tag = Moose::SOLUTION_TAG)
367 2040 : : _var(var, tag)
368 : {
369 3690 : }
370 : /**
371 : * Constructor
372 : * @param vars The MOOSE variables
373 : * @param tag The vector tag name
374 : */
375 : ///@{
376 : VariableGradientTempl(const std::vector<const MooseVariableFieldBase *> & vars,
377 : const TagName & tag = Moose::SOLUTION_TAG)
378 : : _var(vars, tag)
379 : {
380 : }
381 : VariableGradientTempl(const std::vector<MooseVariableFieldBase *> & vars,
382 : const TagName & tag = Moose::SOLUTION_TAG)
383 : : _var(vars, tag)
384 : {
385 : }
386 : ///@}
387 :
388 : /**
389 : * Copy constructor for parallel dispatch
390 : */
391 : VariableGradientTempl(const VariableGradientTempl<is_ad> & object);
392 : /**
393 : * Copy assignment operator
394 : */
395 : VariableGradientTempl<is_ad> & operator=(const VariableGradientTempl<is_ad> & object);
396 :
397 : /**
398 : * Get whether the variable was coupled
399 : * @returns Whether the variable was coupled
400 : */
401 : KOKKOS_FUNCTION operator bool() const { return _var.coupled(); }
402 :
403 : /**
404 : * Get the current variable gradient
405 : * @param datum The Datum object of the current thread
406 : * @param qp The local quadrature point index
407 : * @param comp The variable component
408 : * @returns The variable gradient
409 : */
410 : KOKKOS_FUNCTION auto operator()(Datum & datum, unsigned int qp, unsigned int comp = 0) const
411 : {
412 : return get(datum, qp, comp);
413 : }
414 :
415 : /**
416 : * Get the current variable gradient
417 : * @param datum The AssemblyDatum object of the current thread
418 : * @param qp The local quadrature point index
419 : * @param comp The variable component
420 : * @returns The variable gradient
421 : */
422 : KOKKOS_FUNCTION auto
423 : operator()(AssemblyDatum & datum, unsigned int qp, unsigned int comp = 0) const;
424 :
425 : /**
426 : * Get the Kokkos variable
427 : * @returns The Kokkos variable
428 : */
429 : KOKKOS_FUNCTION const Variable & variable() const { return _var; }
430 :
431 : private:
432 : /**
433 : * Get the current variable gradient
434 : * @param datum The Datum object of the current thread
435 : * @param qp The local quadrature point index
436 : * @param comp The variable component
437 : * @param seed The derivative seed (only meaningful for AD)
438 : * @returns The variable gradient
439 : */
440 : KOKKOS_FUNCTION auto
441 : get(Datum & datum, unsigned int qp, unsigned int comp = 0, Real seed = 0) const;
442 :
443 : /**
444 : * Coupled Kokkos variable
445 : */
446 : Variable _var;
447 : /**
448 : * Derivative seed of each component for AD
449 : */
450 : Array<Real> _seed;
451 : };
452 :
453 : template <bool is_ad>
454 186869 : VariableGradientTempl<is_ad>::VariableGradientTempl(const VariableGradientTempl<is_ad> & object)
455 110929 : : _var(object._var), _seed(object._seed)
456 : {
457 : if constexpr (is_ad)
458 17931 : if (_var.coupled())
459 : {
460 17931 : if (!_seed.isAlloc())
461 17931 : _seed.create(_var.components());
462 :
463 35862 : for (unsigned int comp = 0; comp < _var.components(); ++comp)
464 17931 : _seed[comp] =
465 9068 : _var.dot() ? _var.mooseVar(comp)->sys().duDotDu(_var.var(comp)) : (_var.old() ? 0 : 1);
466 :
467 17931 : _seed.copyToDevice();
468 : }
469 186869 : }
470 :
471 : template <bool is_ad>
472 : VariableGradientTempl<is_ad> &
473 : VariableGradientTempl<is_ad>::operator=(const VariableGradientTempl<is_ad> & object)
474 : {
475 : _var = object._var;
476 :
477 : return *this;
478 : }
479 :
480 : template <bool is_ad>
481 : KOKKOS_FUNCTION auto
482 42829952 : VariableGradientTempl<is_ad>::operator()(AssemblyDatum & datum,
483 : unsigned int qp,
484 : unsigned int comp) const
485 : {
486 : if constexpr (is_ad)
487 : {
488 2369344 : Real seed =
489 2369344 : datum.do_derivatives() && _var.coupled() && _var.sys(comp) == datum.sys() ? _seed[comp] : 0;
490 :
491 2369344 : return get(datum, qp, comp, seed);
492 : }
493 : else
494 40460608 : return get(datum, qp, comp);
495 : }
496 :
497 : template <bool is_ad>
498 : KOKKOS_FUNCTION auto
499 42829952 : VariableGradientTempl<is_ad>::get(Datum & datum,
500 : unsigned int qp,
501 : unsigned int comp,
502 : [[maybe_unused]] Real seed) const
503 : {
504 : KOKKOS_ASSERT(_var.initialized());
505 :
506 42829952 : real3_type grad;
507 :
508 42829952 : if (_var.coupled())
509 : {
510 : KOKKOS_ASSERT(!datum.isNodal());
511 :
512 42829952 : auto & elem = datum.elem();
513 42829952 : auto side = datum.side();
514 :
515 : if constexpr (is_ad)
516 2369344 : grad =
517 : side == libMesh::invalid_uint
518 4738688 : ? datum.system(_var.sys(comp))
519 : .getVectorQpADGrad(
520 : elem, datum.J(qp), datum.qpOffset(), qp, _var.var(comp), _var.tag(), seed)
521 0 : : datum.system(_var.sys(comp))
522 : .getVectorQpADGradFace(
523 : elem, side, datum.J(qp), qp, _var.var(comp), _var.tag(), seed);
524 : else
525 40460608 : grad =
526 : side == libMesh::invalid_uint
527 80921216 : ? datum.system(_var.sys(comp))
528 40460608 : .getVectorQpGrad(elem, datum.qpOffset() + qp, _var.var(comp), _var.tag())
529 0 : : datum.system(_var.sys(comp))
530 : .getVectorQpGradFace(elem, side, datum.J(qp), qp, _var.var(comp), _var.tag());
531 : }
532 :
533 42829952 : return grad;
534 0 : }
535 :
536 : using VariableValue = VariableValueTempl<false>;
537 : using ADVariableValue = VariableValueTempl<true>;
538 : using VariableGradient = VariableGradientTempl<false>;
539 : using ADVariableGradient = VariableGradientTempl<true>;
540 :
541 : template <>
542 : struct ArrayDeepCopy<ADVariableValue>
543 : {
544 : static constexpr bool value = true;
545 : };
546 :
547 : template <>
548 : struct ArrayDeepCopy<ADVariableGradient>
549 : {
550 : static constexpr bool value = true;
551 : };
552 : ///@}
553 :
554 : } // namespace Moose::Kokkos
|