LCOV - code coverage report
Current view: top level - include/kokkos/auxkernels - KokkosAuxKernel.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: 7edd10 Lines: 54 55 98.2 %
Date: 2025-11-11 08:32:45 Functions: 23 36 63.9 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : #include "KokkosDispatcher.h"
      13             : #include "KokkosVariableValue.h"
      14             : #include "KokkosMaterialPropertyValue.h"
      15             : #include "KokkosAssembly.h"
      16             : #include "KokkosSystem.h"
      17             : 
      18             : #include "AuxKernelBase.h"
      19             : 
      20             : namespace Moose
      21             : {
      22             : namespace Kokkos
      23             : {
      24             : 
      25             : /**
      26             :  * The base class for a user to derive their own Kokkos auxiliary kernels.
      27             :  *
      28             :  * The user should define computeValue() as inlined public method in their derived class (not
      29             :  * virtual override). The signature of computeValue() expected to be defined in the derived
      30             :  * class is as follows:
      31             :  *
      32             :  * @param qp The local quadrature point index
      33             :  * @param datum The AssemblyDatum object of the current thread
      34             :  * @returns The value at the quadrature point
      35             :  *
      36             :  * KOKKOS_FUNCTION Real computeValue(const unsigned int qp, AssemblyDatum & datum) const;
      37             :  */
      38             : class AuxKernel : public ::AuxKernelBase,
      39             :                   public MeshHolder,
      40             :                   public AssemblyHolder,
      41             :                   public SystemHolder
      42             : {
      43             : public:
      44             :   static InputParameters validParams();
      45             : 
      46             :   /**
      47             :    * Constructor
      48             :    */
      49             :   AuxKernel(const InputParameters & parameters);
      50             : 
      51             :   /**
      52             :    * Copy constructor for parallel dispatch
      53             :    */
      54             :   AuxKernel(const AuxKernel & object);
      55             : 
      56             :   // Unused for Kokkos auxiliary kernels because all elements are computed in parallel
      57           0 :   virtual void subdomainSetup() override final {}
      58             : 
      59             :   /**
      60             :    * Dispatch calculation
      61             :    */
      62             :   virtual void compute() override;
      63             : 
      64             :   /**
      65             :    * Get whether this auxiliary kernel is nodal
      66             :    * @returns Whether this auxiliary kernel is nodal
      67             :    */
      68         524 :   KOKKOS_FUNCTION bool isNodal() const { return _nodal; }
      69             : 
      70             :   /**
      71             :    * Kokkos function tags
      72             :    */
      73             :   ///@{
      74             :   struct ElementLoop
      75             :   {
      76             :   };
      77             :   struct NodeLoop
      78             :   {
      79             :   };
      80             :   ///@}
      81             : 
      82             :   /**
      83             :    * Shim for hook method that can be leveraged to implement static polymorphism
      84             :    */
      85             :   template <typename Derived>
      86     2801998 :   KOKKOS_FUNCTION Real computeValueShim(const Derived & auxkernel,
      87             :                                         const unsigned int qp,
      88             :                                         AssemblyDatum & datum) const
      89             :   {
      90     2801998 :     return auxkernel.computeValue(qp, datum);
      91             :   }
      92             : 
      93             :   /**
      94             :    * The parallel computation entry functions called by Kokkos
      95             :    */
      96             :   ///@{
      97             :   template <typename Derived>
      98             :   KOKKOS_FUNCTION void operator()(ElementLoop, const ThreadID tid, const Derived & auxkernel) const;
      99             :   template <typename Derived>
     100             :   KOKKOS_FUNCTION void operator()(NodeLoop, const ThreadID tid, const Derived & auxkernel) const;
     101             :   ///@}
     102             : 
     103             :   /**
     104             :    * The parallel computation bodies that can be customized in the derived class by defining
     105             :    * them in the derived class with the same signature.
     106             :    * Make sure to define them as inlined public methods if to be defined in the derived class.
     107             :    */
     108             :   ///@{
     109             :   /**
     110             :    * Compute an element
     111             :    * @param kernel The auxiliary kernel object of the final derived type
     112             :    * @param datum The AssemblyDatum object of the current thread
     113             :    */
     114             :   template <typename Derived>
     115             :   KOKKOS_FUNCTION void computeElementInternal(const Derived & auxkernel,
     116             :                                               AssemblyDatum & datum) const;
     117             :   /**
     118             :    * Compute a node
     119             :    * @param kernel The auxiliary kernel object of the final derived type
     120             :    * @param datum The AssemblyDatum object of the current thread
     121             :    */
     122             :   template <typename Derived>
     123             :   KOKKOS_FUNCTION void computeNodeInternal(const Derived & auxkernel, AssemblyDatum & datum) const;
     124             :   ///@}
     125             : 
     126             : protected:
     127             :   /**
     128             :    * Retrieve the old value of the variable that this kernel operates on
     129             :    * @returns The old variable value object
     130             :    */
     131             :   VariableValue uOld() const;
     132             :   /**
     133             :    * Retrieve the older value of the variable that this kernel operates on
     134             :    * @returns The older variable value object
     135             :    */
     136             :   VariableValue uOlder() const;
     137             : 
     138             :   /**
     139             :    * Set element values to the auxiliary solution vector
     140             :    * @param values The array containing the solution values of the element
     141             :    * @param datum The AssemblyDatum object of the current thread
     142             :    * @param comp The variable component
     143             :    */
     144             :   KOKKOS_FUNCTION void setElementSolution(const Real * const values,
     145             :                                           const AssemblyDatum & datum,
     146             :                                           const unsigned int comp = 0) const;
     147             :   /**
     148             :    * Set node value to the auxiliary solution vector
     149             :    * @param values The node solution value
     150             :    * @param datum The AssemblyDatum object of the current thread
     151             :    * @param comp The variable component
     152             :    */
     153             :   KOKKOS_FUNCTION void
     154             :   setNodeSolution(const Real value, const AssemblyDatum & datum, const unsigned int comp = 0) const;
     155             : 
     156             :   /**
     157             :    * Flag whether this kernel is nodal
     158             :    */
     159             :   const bool _nodal;
     160             : 
     161             :   /**
     162             :    * Kokkos variable
     163             :    */
     164             :   Variable _kokkos_var;
     165             :   /**
     166             :    * Kokkos functor dispatchers
     167             :    */
     168             :   ///@{
     169             :   std::unique_ptr<DispatcherBase> _element_dispatcher;
     170             :   std::unique_ptr<DispatcherBase> _node_dispatcher;
     171             :   ///@}
     172             : 
     173             :   /**
     174             :    * Current test function
     175             :    */
     176             :   const VariableTestValue _test;
     177             :   /**
     178             :    * Current solution
     179             :    */
     180             :   const VariableValue _u;
     181             : 
     182             :   /**
     183             :    * TODO: Move to TransientInterface
     184             :    */
     185             :   ///@{
     186             :   /**
     187             :    * Time
     188             :    */
     189             :   Scalar<Real> _t;
     190             :   /**
     191             :    * Old time
     192             :    */
     193             :   Scalar<const Real> _t_old;
     194             :   /**
     195             :    * The number of the time step
     196             :    */
     197             :   Scalar<int> _t_step;
     198             :   /**
     199             :    * Time step size
     200             :    */
     201             :   Scalar<Real> _dt;
     202             :   /**
     203             :    * Size of the old time step
     204             :    */
     205             :   Scalar<Real> _dt_old;
     206             :   ///@}
     207             : 
     208             : private:
     209             :   /**
     210             :    * Override of the MaterialPropertyInterface function to throw on material property request for
     211             :    * nodal kernels
     212             :    */
     213             :   void getKokkosMaterialPropertyHook(const std::string & prop_name_in,
     214             :                                      const unsigned int state) override final;
     215             : };
     216             : 
     217             : template <typename Derived>
     218             : KOKKOS_FUNCTION void
     219      210176 : AuxKernel::operator()(ElementLoop, const ThreadID tid, const Derived & auxkernel) const
     220             : {
     221      210176 :   auto elem = kokkosBlockElementID(tid);
     222             : 
     223      420352 :   AssemblyDatum datum(elem,
     224             :                       libMesh::invalid_uint,
     225             :                       kokkosAssembly(),
     226             :                       kokkosSystems(),
     227      210176 :                       _kokkos_var,
     228             :                       _kokkos_var.var());
     229             : 
     230      210176 :   auxkernel.computeElementInternal(auxkernel, datum);
     231      210176 : }
     232             : 
     233             : template <typename Derived>
     234             : KOKKOS_FUNCTION void
     235     2494596 : AuxKernel::operator()(NodeLoop, const ThreadID tid, const Derived & auxkernel) const
     236             : {
     237     2494596 :   auto node = _bnd ? kokkosBoundaryNodeID(tid) : kokkosBlockNodeID(tid);
     238     2494596 :   auto & sys = kokkosSystem(_kokkos_var.sys());
     239             : 
     240     2494596 :   if (!sys.isNodalDefined(node, _kokkos_var.var()))
     241     1505280 :     return;
     242             : 
     243      989316 :   AssemblyDatum datum(node, kokkosAssembly(), kokkosSystems(), _kokkos_var, _kokkos_var.var());
     244             : 
     245      989316 :   auxkernel.computeNodeInternal(auxkernel, datum);
     246             : }
     247             : 
     248             : template <typename Derived>
     249             : KOKKOS_FUNCTION void
     250      209576 : AuxKernel::computeElementInternal(const Derived & auxkernel, AssemblyDatum & datum) const
     251             : {
     252             :   Real x[MAX_CACHED_DOF];
     253             :   Real b[MAX_CACHED_DOF];
     254             :   Real A[MAX_CACHED_DOF * MAX_CACHED_DOF];
     255             : 
     256     1104152 :   for (unsigned int i = 0; i < datum.n_dofs(); ++i)
     257             :   {
     258      894576 :     x[i] = 0;
     259      894576 :     b[i] = 0;
     260             : 
     261     5899152 :     for (unsigned int j = 0; j < datum.n_dofs(); ++j)
     262     5004576 :       A[j + datum.n_dofs() * i] = 0;
     263             :   }
     264             : 
     265     2023184 :   for (unsigned int qp = 0; qp < datum.n_qps(); ++qp)
     266             :   {
     267     1813608 :     const auto value = auxkernel.computeValueShim(auxkernel, qp, datum);
     268             : 
     269     1813608 :     datum.reinit();
     270             : 
     271     9792216 :     for (unsigned int i = 0; i < datum.n_dofs(); ++i)
     272             :     {
     273     7978608 :       const auto t = datum.JxW(qp) * _test(datum, i, qp);
     274             : 
     275     7978608 :       b[i] += t * value;
     276             : 
     277    52947216 :       for (unsigned int j = 0; j < datum.n_dofs(); ++j)
     278    44968608 :         A[j + datum.n_dofs() * i] += t * _test(datum, j, qp);
     279             :     }
     280             :   }
     281             : 
     282      209576 :   if (datum.n_dofs() == 1)
     283       72576 :     x[0] = b[0] / A[0];
     284             :   else
     285      137000 :     Utils::choleskySolve(A, x, b, datum.n_dofs());
     286             : 
     287      209576 :   setElementSolution(x, datum);
     288      209576 : }
     289             : 
     290             : template <typename Derived>
     291             : KOKKOS_FUNCTION void
     292      988390 : AuxKernel::computeNodeInternal(const Derived & auxkernel, AssemblyDatum & datum) const
     293             : {
     294      988390 :   auto value = auxkernel.computeValueShim(auxkernel, 0, datum);
     295             : 
     296      988390 :   setNodeSolution(value, datum);
     297      988390 : }
     298             : 
     299             : KOKKOS_FUNCTION inline void
     300      209576 : AuxKernel::setElementSolution(const Real * const values,
     301             :                               const AssemblyDatum & datum,
     302             :                               const unsigned int comp) const
     303             : {
     304      209576 :   auto & sys = kokkosSystem(_kokkos_var.sys(comp));
     305      209576 :   auto var_num = _kokkos_var.var(comp);
     306      209576 :   auto tag = _kokkos_var.tag();
     307      209576 :   auto elem = datum.elem().id;
     308             : 
     309     1104152 :   for (unsigned int i = 0; i < datum.n_dofs(); ++i)
     310      894576 :     sys.getVectorDofValue(sys.getElemLocalDofIndex(elem, i, var_num), tag) = values[i];
     311      209576 : }
     312             : 
     313             : KOKKOS_FUNCTION inline void
     314      988390 : AuxKernel::setNodeSolution(const Real value,
     315             :                            const AssemblyDatum & datum,
     316             :                            const unsigned int comp) const
     317             : {
     318      988390 :   auto & sys = kokkosSystem(_kokkos_var.sys(comp));
     319      988390 :   auto var_num = _kokkos_var.var(comp);
     320      988390 :   auto tag = _kokkos_var.tag();
     321      988390 :   auto node = datum.node();
     322             : 
     323      988390 :   sys.getVectorDofValue(sys.getNodeLocalDofIndex(node, 0, var_num), tag) = value;
     324      988390 : }
     325             : 
     326             : } // namespace Kokkos
     327             : } // namespace Moose

Generated by: LCOV version 1.14