LCOV - code coverage report
Current view: top level - include/kokkos/base - KokkosVariableValue.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: #32971 (54bef8) with base c6cf66 Lines: 107 115 93.0 %
Date: 2026-05-29 20:35:17 Functions: 28 29 96.6 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://www.mooseframework.org
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : #include "KokkosDatum.h"
      13             : 
      14             : #include "MooseVariableFieldBase.h"
      15             : 
      16             : namespace Moose::Kokkos
      17             : {
      18             : 
      19             : /**
      20             :  * The Kokkos wrapper classes for MOOSE-like shape function access
      21             :  */
      22             : ///@{
      23             : class VariablePhiValue
      24             : {
      25             : public:
      26             :   /**
      27             :    * Get the current shape function
      28             :    * @param datum The AssemblyDatum object of the current thread
      29             :    * @param i The element-local DOF index
      30             :    * @param qp The local quadrature point index
      31             :    * @returns The shape function
      32             :    */
      33     4906888 :   KOKKOS_FUNCTION Real operator()(AssemblyDatum & datum, unsigned int i, unsigned int qp) const
      34             :   {
      35     4906888 :     auto & elem = datum.elem();
      36     4906888 :     auto side = datum.side();
      37     4906888 :     auto fe = datum.jfe();
      38             : 
      39           0 :     return side == libMesh::invalid_uint
      40     4906888 :                ? datum.assembly().getPhi(elem.subdomain, elem.type, fe)(i, qp)
      41     4906888 :                : datum.assembly().getPhiFace(elem.subdomain, elem.type, fe)(side)(i, qp);
      42             :   }
      43             : };
      44             : 
      45             : class VariablePhiGradient
      46             : {
      47             : public:
      48             :   /**
      49             :    * Get the gradient of the current shape function
      50             :    * @param datum The AssemblyDatum object of the current thread
      51             :    * @param i The element-local DOF index
      52             :    * @param qp The local quadrature point index
      53             :    * @returns The gradient of the shape function
      54             :    */
      55    25385984 :   KOKKOS_FUNCTION Real3 operator()(AssemblyDatum & datum, unsigned int i, unsigned int qp) const
      56             :   {
      57    25385984 :     auto & elem = datum.elem();
      58    25385984 :     auto side = datum.side();
      59    25385984 :     auto fe = datum.jfe();
      60             : 
      61           0 :     return datum.J(qp) *
      62             :            (side == libMesh::invalid_uint
      63    25385984 :                 ? datum.assembly().getGradPhi(elem.subdomain, elem.type, fe)(i, qp)
      64    50771968 :                 : datum.assembly().getGradPhiFace(elem.subdomain, elem.type, fe)(side)(i, qp));
      65             :   }
      66             : };
      67             : 
      68             : class VariableTestValue
      69             : {
      70             : public:
      71             :   /**
      72             :    * Get the current test function
      73             :    * @param datum The AssemblyDatum object of the current thread
      74             :    * @param i The element-local DOF index
      75             :    * @param qp The local quadrature point index
      76             :    * @returns The test function
      77             :    */
      78   113200760 :   KOKKOS_FUNCTION Real operator()(AssemblyDatum & datum, unsigned int i, unsigned int qp) const
      79             :   {
      80   113200760 :     auto & elem = datum.elem();
      81   113200760 :     auto side = datum.side();
      82   113200760 :     auto fe = datum.ife();
      83             : 
      84           0 :     return side == libMesh::invalid_uint
      85   113200760 :                ? datum.assembly().getPhi(elem.subdomain, elem.type, fe)(i, qp)
      86   113200760 :                : datum.assembly().getPhiFace(elem.subdomain, elem.type, fe)(side)(i, qp);
      87             :   }
      88             : };
      89             : 
      90             : class VariableTestGradient
      91             : {
      92             : public:
      93             :   /**
      94             :    * Get the gradient of the current test function
      95             :    * @param datum The AssemblyDatum object of the current thread
      96             :    * @param i The element-local DOF index
      97             :    * @param qp The local quadrature point index
      98             :    * @returns The gradient of the test function
      99             :    */
     100    68215168 :   KOKKOS_FUNCTION Real3 operator()(AssemblyDatum & datum, unsigned int i, unsigned int qp) const
     101             :   {
     102    68215168 :     auto & elem = datum.elem();
     103    68215168 :     auto side = datum.side();
     104    68215168 :     auto fe = datum.ife();
     105             : 
     106           0 :     return datum.J(qp) *
     107             :            (side == libMesh::invalid_uint
     108    68215168 :                 ? datum.assembly().getGradPhi(elem.subdomain, elem.type, fe)(i, qp)
     109   136430336 :                 : datum.assembly().getGradPhiFace(elem.subdomain, elem.type, fe)(side)(i, qp));
     110             :   }
     111             : };
     112             : 
     113             : using ADVariablePhiValue = VariablePhiValue;
     114             : using ADVariablePhiGradient = VariablePhiGradient;
     115             : using ADVariableTestValue = VariableTestValue;
     116             : using ADVariableTestGradient = VariableTestGradient;
     117             : 
     118             : ///@}
     119             : 
     120             : /**
     121             :  * The Kokkos wrapper classes for MOOSE-like variable value access
     122             :  */
     123             : ///@{
     124             : template <bool is_ad>
     125             : class VariableValueTempl
     126             : {
     127             :   using real_type = std::conditional_t<is_ad, ADReal, Real>;
     128             : 
     129             : public:
     130             :   /**
     131             :    * Default constructor
     132             :    */
     133        8867 :   VariableValueTempl() = default;
     134             :   /**
     135             :    * Constructor
     136             :    * @param var The Kokkos variable
     137             :    * @param dof Whether to get DOF values
     138             :    */
     139        2868 :   VariableValueTempl(Variable var, bool dof = false) : _var(var), _dof(dof) {}
     140             :   /**
     141             :    * Constructor
     142             :    * @param var The MOOSE variable
     143             :    * @param tag The vector tag name
     144             :    * @param dof Whether to get DOF values
     145             :    */
     146        8812 :   VariableValueTempl(const MooseVariableFieldBase & var,
     147             :                      const TagName & tag = Moose::SOLUTION_TAG,
     148        4012 :                      bool dof = false)
     149        4800 :     : _var(var, tag), _dof(dof)
     150             :   {
     151        8812 :   }
     152             :   /**
     153             :    * Constructor
     154             :    * @param vars The MOOSE variables
     155             :    * @param tag The vector tag name
     156             :    * @param dof Whether to get DOF values
     157             :    */
     158             :   ///@{
     159             :   VariableValueTempl(const std::vector<const MooseVariableFieldBase *> & vars,
     160             :                      const TagName & tag = Moose::SOLUTION_TAG,
     161             :                      bool dof = false)
     162             :     : _var(vars, tag), _dof(dof)
     163             :   {
     164             :   }
     165             :   VariableValueTempl(const std::vector<MooseVariableFieldBase *> & vars,
     166             :                      const TagName & tag = Moose::SOLUTION_TAG,
     167             :                      bool dof = false)
     168             :     : _var(vars, tag), _dof(dof)
     169             :   {
     170             :   }
     171             :   ///@}
     172             : 
     173             :   /**
     174             :    * Copy constructor for parallel dispatch
     175             :    */
     176             :   VariableValueTempl(const VariableValueTempl<is_ad> & object);
     177             :   /**
     178             :    * Copy assignment operator
     179             :    */
     180             :   VariableValueTempl<is_ad> & operator=(const VariableValueTempl<is_ad> & object);
     181             : 
     182             :   /**
     183             :    * Get whether the variable was coupled
     184             :    * @returns Whether the variable was coupled
     185             :    */
     186       28032 :   KOKKOS_FUNCTION operator bool() const { return _var.coupled(); }
     187             : 
     188             :   /**
     189             :    * Get the current variable value
     190             :    * @param datum The Datum object of the current thread
     191             :    * @param idx The local quadrature point or DOF index
     192             :    * @param comp The variable component
     193             :    * @returns The variable value
     194             :    */
     195     2592699 :   KOKKOS_FUNCTION auto operator()(Datum & datum, unsigned int idx, unsigned int comp = 0) const
     196             :   {
     197     2592699 :     return get(datum, idx, comp);
     198             :   }
     199             : 
     200             :   /**
     201             :    * Get the current variable value
     202             :    * @param datum The AssemblyDatum object of the current thread
     203             :    * @param idx The local quadrature point or DOF index
     204             :    * @param comp The variable component
     205             :    * @returns The variable value
     206             :    */
     207             :   KOKKOS_FUNCTION auto
     208             :   operator()(AssemblyDatum & datum, unsigned int idx, unsigned int comp = 0) const;
     209             : 
     210             :   /**
     211             :    * Get the Kokkos variable
     212             :    * @returns The Kokkos variable
     213             :    */
     214      460146 :   KOKKOS_FUNCTION const Variable & variable() const { return _var; }
     215             : 
     216             : private:
     217             :   /**
     218             :    * Get the current variable value
     219             :    * @param datum The Datum object of the current thread
     220             :    * @param idx The local quadrature point or DOF index
     221             :    * @param comp The variable component
     222             :    * @param seed The derivative seed (only meaningful for AD)
     223             :    * @returns The variable value
     224             :    */
     225             :   KOKKOS_FUNCTION auto
     226             :   get(Datum & datum, unsigned int idx, unsigned int comp = 0, Real seed = 0) const;
     227             : 
     228             :   /**
     229             :    * Coupled Kokkos variable
     230             :    */
     231             :   Variable _var;
     232             :   /**
     233             :    * Derivative seed of each component for AD
     234             :    */
     235             :   Array<Real> _seed;
     236             :   /**
     237             :    * Flag whether DOF values are requested
     238             :    */
     239             :   bool _dof = false;
     240             : };
     241             : 
     242             : template <bool is_ad>
     243      587081 : VariableValueTempl<is_ad>::VariableValueTempl(const VariableValueTempl<is_ad> & object)
     244      334813 :   : _var(object._var), _seed(object._seed), _dof(object._dof)
     245             : {
     246             :   if constexpr (is_ad)
     247       59425 :     if (_var.coupled())
     248             :     {
     249       57533 :       if (!_seed.isAlloc())
     250       57533 :         _seed.create(_var.components());
     251             : 
     252      115066 :       for (unsigned int comp = 0; comp < _var.components(); ++comp)
     253       57533 :         _seed[comp] =
     254       29058 :             _var.dot() ? _var.mooseVar(comp)->sys().duDotDu(_var.var(comp)) : (_var.old() ? 0 : 1);
     255             : 
     256       57533 :       _seed.copyToDevice();
     257             :     }
     258      587081 : }
     259             : 
     260             : template <bool is_ad>
     261             : VariableValueTempl<is_ad> &
     262        8867 : VariableValueTempl<is_ad>::operator=(const VariableValueTempl<is_ad> & object)
     263             : {
     264        8867 :   _var = object._var;
     265        8867 :   _dof = object._dof;
     266             : 
     267        8867 :   return *this;
     268             : }
     269             : 
     270             : template <bool is_ad>
     271             : KOKKOS_FUNCTION auto
     272    38051287 : VariableValueTempl<is_ad>::operator()(AssemblyDatum & datum,
     273             :                                       unsigned int idx,
     274             :                                       unsigned int comp) const
     275             : {
     276             :   if constexpr (is_ad)
     277             :   {
     278     2369564 :     Real seed =
     279     2369564 :         datum.do_derivatives() && _var.coupled() && _var.sys(comp) == datum.sys() ? _seed[comp] : 0;
     280             : 
     281     2369564 :     return get(datum, idx, comp, seed);
     282             :   }
     283             :   else
     284    35681723 :     return get(datum, idx, comp);
     285             : }
     286             : 
     287             : template <bool is_ad>
     288             : KOKKOS_FUNCTION auto
     289    40643986 : VariableValueTempl<is_ad>::get(Datum & datum,
     290             :                                unsigned int idx,
     291             :                                unsigned int comp,
     292             :                                [[maybe_unused]] Real seed) const
     293             : {
     294             :   KOKKOS_ASSERT(_var.initialized());
     295             : 
     296     2369564 :   real_type value;
     297             : 
     298    40643986 :   if (_var.coupled())
     299             :   {
     300    40612718 :     auto & sys = datum.system(_var.sys(comp));
     301    40612718 :     auto var = _var.var(comp);
     302    40612718 :     auto tag = _var.tag();
     303             : 
     304    40612718 :     if (_dof)
     305             :     {
     306             :       unsigned int dof;
     307             : 
     308     4355358 :       if (datum.isNodal())
     309             :       {
     310     4354758 :         auto node = datum.node();
     311     4354758 :         dof = sys.getNodeLocalDofIndex(node, 0, var);
     312             :       }
     313             :       else
     314             :       {
     315         600 :         auto elem = datum.elem().id;
     316         600 :         dof = sys.getElemLocalDofIndex(elem, idx, var);
     317             :       }
     318             : 
     319             :       if constexpr (is_ad)
     320       68852 :         value = sys.getVectorDofADValue(dof, tag, seed);
     321             :       else
     322     4286506 :         value = sys.getVectorDofValue(dof, tag);
     323             :     }
     324             :     else
     325             :     {
     326    36257360 :       auto & elem = datum.elem();
     327    36257360 :       auto side = datum.side();
     328             : 
     329             :       if constexpr (is_ad)
     330     2285556 :         value = side == libMesh::invalid_uint
     331     4571112 :                     ? sys.getVectorQpADValue(elem, datum.qpOffset(), idx, var, tag, seed)
     332             :                     : sys.getVectorQpADValueFace(elem, side, idx, var, tag, seed);
     333             :       else
     334    33971804 :         value = side == libMesh::invalid_uint
     335    33971804 :                     ? sys.getVectorQpValue(elem, datum.qpOffset() + idx, var, tag)
     336      265208 :                     : sys.getVectorQpValueFace(elem, side, idx, var, tag);
     337             :     }
     338             :   }
     339             :   else
     340       31268 :     value = _var.value(comp);
     341             : 
     342    40643986 :   return value;
     343           0 : }
     344             : 
     345             : template <bool is_ad>
     346             : class VariableGradientTempl
     347             : {
     348             :   using real3_type = std::conditional_t<is_ad, ADReal3, Real3>;
     349             : 
     350             : public:
     351             :   /**
     352             :    * Default constructor
     353             :    */
     354             :   VariableGradientTempl() = default;
     355             :   /**
     356             :    * Constructor
     357             :    * @param var The Kokkos variable
     358             :    */
     359         665 :   VariableGradientTempl(Variable var) : _var(var) {}
     360             :   /**
     361             :    * Constructor
     362             :    * @param var The MOOSE variable
     363             :    * @param tag The vector tag name
     364             :    */
     365        3690 :   VariableGradientTempl(const MooseVariableFieldBase & var,
     366        1650 :                         const TagName & tag = Moose::SOLUTION_TAG)
     367        2040 :     : _var(var, tag)
     368             :   {
     369        3690 :   }
     370             :   /**
     371             :    * Constructor
     372             :    * @param vars The MOOSE variables
     373             :    * @param tag The vector tag name
     374             :    */
     375             :   ///@{
     376             :   VariableGradientTempl(const std::vector<const MooseVariableFieldBase *> & vars,
     377             :                         const TagName & tag = Moose::SOLUTION_TAG)
     378             :     : _var(vars, tag)
     379             :   {
     380             :   }
     381             :   VariableGradientTempl(const std::vector<MooseVariableFieldBase *> & vars,
     382             :                         const TagName & tag = Moose::SOLUTION_TAG)
     383             :     : _var(vars, tag)
     384             :   {
     385             :   }
     386             :   ///@}
     387             : 
     388             :   /**
     389             :    * Copy constructor for parallel dispatch
     390             :    */
     391             :   VariableGradientTempl(const VariableGradientTempl<is_ad> & object);
     392             :   /**
     393             :    * Copy assignment operator
     394             :    */
     395             :   VariableGradientTempl<is_ad> & operator=(const VariableGradientTempl<is_ad> & object);
     396             : 
     397             :   /**
     398             :    * Get whether the variable was coupled
     399             :    * @returns Whether the variable was coupled
     400             :    */
     401             :   KOKKOS_FUNCTION operator bool() const { return _var.coupled(); }
     402             : 
     403             :   /**
     404             :    * Get the current variable gradient
     405             :    * @param datum The Datum object of the current thread
     406             :    * @param qp The local quadrature point index
     407             :    * @param comp The variable component
     408             :    * @returns The variable gradient
     409             :    */
     410             :   KOKKOS_FUNCTION auto operator()(Datum & datum, unsigned int qp, unsigned int comp = 0) const
     411             :   {
     412             :     return get(datum, qp, comp);
     413             :   }
     414             : 
     415             :   /**
     416             :    * Get the current variable gradient
     417             :    * @param datum The AssemblyDatum object of the current thread
     418             :    * @param qp The local quadrature point index
     419             :    * @param comp The variable component
     420             :    * @returns The variable gradient
     421             :    */
     422             :   KOKKOS_FUNCTION auto
     423             :   operator()(AssemblyDatum & datum, unsigned int qp, unsigned int comp = 0) const;
     424             : 
     425             :   /**
     426             :    * Get the Kokkos variable
     427             :    * @returns The Kokkos variable
     428             :    */
     429             :   KOKKOS_FUNCTION const Variable & variable() const { return _var; }
     430             : 
     431             : private:
     432             :   /**
     433             :    * Get the current variable gradient
     434             :    * @param datum The Datum object of the current thread
     435             :    * @param qp The local quadrature point index
     436             :    * @param comp The variable component
     437             :    * @param seed The derivative seed (only meaningful for AD)
     438             :    * @returns The variable gradient
     439             :    */
     440             :   KOKKOS_FUNCTION auto
     441             :   get(Datum & datum, unsigned int qp, unsigned int comp = 0, Real seed = 0) const;
     442             : 
     443             :   /**
     444             :    * Coupled Kokkos variable
     445             :    */
     446             :   Variable _var;
     447             :   /**
     448             :    * Derivative seed of each component for AD
     449             :    */
     450             :   Array<Real> _seed;
     451             : };
     452             : 
     453             : template <bool is_ad>
     454      186869 : VariableGradientTempl<is_ad>::VariableGradientTempl(const VariableGradientTempl<is_ad> & object)
     455      110929 :   : _var(object._var), _seed(object._seed)
     456             : {
     457             :   if constexpr (is_ad)
     458       17931 :     if (_var.coupled())
     459             :     {
     460       17931 :       if (!_seed.isAlloc())
     461       17931 :         _seed.create(_var.components());
     462             : 
     463       35862 :       for (unsigned int comp = 0; comp < _var.components(); ++comp)
     464       17931 :         _seed[comp] =
     465        9068 :             _var.dot() ? _var.mooseVar(comp)->sys().duDotDu(_var.var(comp)) : (_var.old() ? 0 : 1);
     466             : 
     467       17931 :       _seed.copyToDevice();
     468             :     }
     469      186869 : }
     470             : 
     471             : template <bool is_ad>
     472             : VariableGradientTempl<is_ad> &
     473             : VariableGradientTempl<is_ad>::operator=(const VariableGradientTempl<is_ad> & object)
     474             : {
     475             :   _var = object._var;
     476             : 
     477             :   return *this;
     478             : }
     479             : 
     480             : template <bool is_ad>
     481             : KOKKOS_FUNCTION auto
     482    42829952 : VariableGradientTempl<is_ad>::operator()(AssemblyDatum & datum,
     483             :                                          unsigned int qp,
     484             :                                          unsigned int comp) const
     485             : {
     486             :   if constexpr (is_ad)
     487             :   {
     488     2369344 :     Real seed =
     489     2369344 :         datum.do_derivatives() && _var.coupled() && _var.sys(comp) == datum.sys() ? _seed[comp] : 0;
     490             : 
     491     2369344 :     return get(datum, qp, comp, seed);
     492             :   }
     493             :   else
     494    40460608 :     return get(datum, qp, comp);
     495             : }
     496             : 
     497             : template <bool is_ad>
     498             : KOKKOS_FUNCTION auto
     499    42829952 : VariableGradientTempl<is_ad>::get(Datum & datum,
     500             :                                   unsigned int qp,
     501             :                                   unsigned int comp,
     502             :                                   [[maybe_unused]] Real seed) const
     503             : {
     504             :   KOKKOS_ASSERT(_var.initialized());
     505             : 
     506    42829952 :   real3_type grad;
     507             : 
     508    42829952 :   if (_var.coupled())
     509             :   {
     510             :     KOKKOS_ASSERT(!datum.isNodal());
     511             : 
     512    42829952 :     auto & elem = datum.elem();
     513    42829952 :     auto side = datum.side();
     514             : 
     515             :     if constexpr (is_ad)
     516     2369344 :       grad =
     517             :           side == libMesh::invalid_uint
     518     4738688 :               ? datum.system(_var.sys(comp))
     519             :                     .getVectorQpADGrad(
     520             :                         elem, datum.J(qp), datum.qpOffset(), qp, _var.var(comp), _var.tag(), seed)
     521           0 :               : datum.system(_var.sys(comp))
     522             :                     .getVectorQpADGradFace(
     523             :                         elem, side, datum.J(qp), qp, _var.var(comp), _var.tag(), seed);
     524             :     else
     525    40460608 :       grad =
     526             :           side == libMesh::invalid_uint
     527    80921216 :               ? datum.system(_var.sys(comp))
     528    40460608 :                     .getVectorQpGrad(elem, datum.qpOffset() + qp, _var.var(comp), _var.tag())
     529           0 :               : datum.system(_var.sys(comp))
     530             :                     .getVectorQpGradFace(elem, side, datum.J(qp), qp, _var.var(comp), _var.tag());
     531             :   }
     532             : 
     533    42829952 :   return grad;
     534           0 : }
     535             : 
     536             : using VariableValue = VariableValueTempl<false>;
     537             : using ADVariableValue = VariableValueTempl<true>;
     538             : using VariableGradient = VariableGradientTempl<false>;
     539             : using ADVariableGradient = VariableGradientTempl<true>;
     540             : 
     541             : template <>
     542             : struct ArrayDeepCopy<ADVariableValue>
     543             : {
     544             :   static constexpr bool value = true;
     545             : };
     546             : 
     547             : template <>
     548             : struct ArrayDeepCopy<ADVariableGradient>
     549             : {
     550             :   static constexpr bool value = true;
     551             : };
     552             : ///@}
     553             : 
     554             : } // namespace Moose::Kokkos

Generated by: LCOV version 1.14