LCOV - code coverage report
Current view: top level - include/kokkos/base - KokkosDispatcher.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: #32971 (54bef8) with base c6cf66 Lines: 64 69 92.8 %
Date: 2026-05-29 20:35:17 Functions: 1241 2084 59.5 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://www.mooseframework.org
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : #include "KokkosHeader.h"
      13             : #include "KokkosThread.h"
      14             : 
      15             : #include <typeindex>
      16             : 
      17             : namespace Moose::Kokkos
      18             : {
      19             : 
/// Kokkos range execution policy accepted by the dispatchers' parallelFor() and parallelReduce():
/// loops run on ExecSpace and the loop index type is ThreadID.
using Policy = ::Kokkos::RangePolicy<ExecSpace, ::Kokkos::IndexType<ThreadID>>;
      21             : 
      22             : /**
      23             :  * Base class for Kokkos functor dispatcher.
      24             :  * Used for type erasure so that the base class of functors can hold the dispatcher without knowing
      25             :  * the actual type of functors.
      26             :  */
      27             : class DispatcherBase
      28             : {
      29             : public:
      30      508889 :   virtual ~DispatcherBase() {}
      31             :   /**
      32             :    * Dispatch this functor with Kokkos parallel_for() given a Kokkos execution policy
      33             :    * @param policy The Kokkos execution policy
      34             :    */
      35           0 :   virtual void parallelFor(const Policy & /* policy */)
      36             :   {
      37           0 :     mooseError("parallelFor() called for an instance that is not a dispatcher.");
      38             :   }
      39             :   /**
      40             :    * Dispatch this functor with Kokkos parallel_reduce() given a Kokkos execution policy and result
      41             :    * buffer
      42             :    * @param policy The Kokkos execution policy
      43             :    * @param result The result buffer
      44             :    */
      45           0 :   virtual void parallelReduce(const Policy & /* policy */,
      46             :                               ::Kokkos::View<Real *, ::Kokkos::HostSpace> & /* result */)
      47             :   {
      48           0 :     mooseError("parallelReduce() called for an instance that is not a reducer.");
      49             :   }
      50             : };
      51             : 
      52             : /**
      53             :  * Class that dispatches a parallel loop operation of a Kokkos functor.
      54             :  * Calls operator() of the functor with a specified function tag.
      55             :  * @tparam Operation The function tag of operator() to be dispatched
      56             :  * @tparam Object The functor class type
      57             :  */
      58             : template <typename Operation, typename Object>
      59             : class Dispatcher : public DispatcherBase
      60             : {
      61             : public:
      62             :   /**
      63             :    * Constructor
      64             :    * @param object The pointer to the functor. This dispatcher is constructed by the base class of
      65             :    * functors, and the actual type of functors is unknown by the base class. Therefore, it is passed
      66             :    * as a void pointer and cast to the actual type here.
      67             :    */
      68       22923 :   Dispatcher(const void * object)
      69       12609 :     : _functor_host(*static_cast<const Object *>(object)), _functor_device(_functor_host)
      70             :   {
      71       22923 :   }
      72             :   /**
      73             :    * Copy constructor for parallel dispatch
      74             :    */
      75      478284 :   Dispatcher(const Dispatcher & functor)
      76      272454 :     : _functor_host(functor._functor_host), _functor_device(functor._functor_host)
      77             :   {
      78      478284 :   }
      79             : 
      80      478283 :   void parallelFor(const Policy & policy) override final
      81             :   {
      82      478283 :     ::Kokkos::parallel_for(policy, *this);
      83      478283 :     ::Kokkos::fence();
      84      478283 :   }
      85             : 
      86             :   /**
      87             :    * The parallel computation entry function called by Kokkos::parallel_for
      88             :    */
      89    27318554 :   KOKKOS_FUNCTION void operator()(const ThreadID tid) const
      90             :   {
      91    27318554 :     _functor_device(Operation{}, tid, _functor_device);
      92    27318554 :   }
      93             : 
      94             : private:
      95             :   /**
      96             :    * Reference of the functor on host
      97             :    */
      98             :   const Object & _functor_host;
      99             :   /**
     100             :    * Copy of the functor on device
     101             :    */
     102             :   const Object _functor_device;
     103             : };
     104             : 
     105             : /**
     106             :  * Class that dispatches a parallel reduction operation of a Kokkos functor.
     107             :  * Calls operator() of the functor with a specified function tag.
     108             :  * @tparam Operation The function tag of operator() to be dispatched
     109             :  * @tparam Object The functor class type
     110             :  */
     111             : template <typename Operation, typename Object>
     112             : class Reducer : public DispatcherBase
     113             : {
     114             : public:
     115             :   /**
     116             :    * Constructor
     117             :    * @param object The pointer to the functor. This reducer is constructed by the base class of
     118             :    * functors, and the actual type of functors is unknown by the base class. Therefore, it is passed
     119             :    * as a void pointer and cast to the actual type here.
     120             :    */
     121         966 :   Reducer(const void * object)
     122         538 :     : _functor_host(*static_cast<const Object *>(object)), _functor_device(_functor_host)
     123             :   {
     124         966 :   }
     125             :   /**
     126             :    * Copy constructor for parallel dispatch
     127             :    */
     128        6754 :   Reducer(const Reducer & functor)
     129        4092 :     : value_count(functor.value_count),
     130        4092 :       _functor_host(functor._functor_host),
     131        4092 :       _functor_device(functor._functor_host)
     132             :   {
     133        6754 :   }
     134             : 
     135        2251 :   void parallelReduce(const Policy & policy,
     136             :                       ::Kokkos::View<Real *, ::Kokkos::HostSpace> & result) override final
     137             :   {
     138        2251 :     value_count = result.size();
     139             : 
     140        2251 :     ::Kokkos::parallel_reduce(policy, *this, result);
     141        2251 :     ::Kokkos::fence();
     142        2251 :   }
     143             : 
     144             :   using value_type = Real[];
     145             :   using size_type = ::Kokkos::View<Real *>::size_type;
     146             : 
     147             :   size_type value_count;
     148             : 
     149             :   /**
     150             :    * The parallel computation entry function called by Kokkos::parallel_reduce
     151             :    */
     152      382782 :   KOKKOS_FUNCTION void operator()(const ThreadID tid, value_type result) const
     153             :   {
     154      382782 :     _functor_device(Operation{}, tid, _functor_device, result);
     155      382782 :   }
     156             : 
     157             :   /**
     158             :    * Functions required by the reducer concept of Kokkos
     159             :    */
     160             :   ///@{
     161         169 :   KOKKOS_FUNCTION void join(value_type result, const value_type source) const
     162             :   {
     163         169 :     _functor_device.template join<Object>(result, source);
     164         169 :   }
     165        1533 :   KOKKOS_FUNCTION void init(value_type result) const
     166             :   {
     167        1533 :     _functor_device.template init<Object>(result);
     168        1533 :   }
     169             :   ///@}
     170             : 
     171             : private:
     172             :   /**
     173             :    * Reference of the functor on host
     174             :    */
     175             :   const Object & _functor_host;
     176             :   /**
     177             :    * Copy of the functor on device
     178             :    */
     179             :   const Object _functor_device;
     180             : };
     181             : 
     182             : /**
     183             :  * Base class for dispatcher registry entry.
     184             :  * Used for type erasure so that the registry can hold dispatchers for different functor types in a
     185             :  * single container.
     186             :  */
     187             : class DispatcherRegistryEntryBase
     188             : {
     189             : public:
     190           4 :   virtual ~DispatcherRegistryEntryBase() {}
     191             : 
     192             :   /**
     193             :    * Build a dispatcher for this operation and functor
     194             :    * @param object The pointer to the functor
     195             :    */
     196             :   virtual std::unique_ptr<DispatcherBase> build(const void * object) const = 0;
     197             : 
     198             :   /**
     199             :    * Set whether the user has overriden the hook method associated with this operation
     200             :    * @param flag Whether the user has overriden the hook method
     201             :    */
     202     6656600 :   void hasUserMethod(bool flag) { _has_user_method = flag; }
     203             :   /**
     204             :    * Get whether the user has overriden the hook method associated with this operation
     205             :    * @returns Whether the user has overriden the hook method
     206             :    */
     207       90531 :   bool hasUserMethod() const { return _has_user_method; }
     208             : 
     209             : private:
     210             :   /**
     211             :    * Flag whether the user has overriden the hook method associated with this operation
     212             :    */
     213             :   bool _has_user_method = false;
     214             : };
     215             : 
     216             : /**
     217             :  * Class that stores the information of a dispatcher and builds it.
     218             :  * This shell class is the entry of the dispatcher registry instead of the dispatcher itself.
     219             :  * The reason this class does not dispatch the functor directly is to let the dispatcher hold
     220             :  * the reference of the functor so that the functor does not need to be copied twice at each
     221             :  * dispatch. Namely, dispatchers are to be built and held by the functors, not the registry.
     222             :  * @tparam Operation The function tag of operator() to be dispatched
     223             :  * @tparam Object The functor class type
     224             :  */
     225             : ///@{
     226             : template <typename Operation, typename Object>
     227             : class DispatcherRegistryEntry : public DispatcherRegistryEntryBase
     228             : {
     229             : public:
     230       22923 :   std::unique_ptr<DispatcherBase> build(const void * object) const override final
     231             :   {
     232       22923 :     return std::make_unique<Dispatcher<Operation, Object>>(object);
     233             :   }
     234             : };
     235             : 
     236             : template <typename Operation, typename Object>
     237             : class ReducerRegistryEntry : public DispatcherRegistryEntryBase
     238             : {
     239             : public:
     240         966 :   std::unique_ptr<DispatcherBase> build(const void * object) const override final
     241             :   {
     242         966 :     return std::make_unique<Reducer<Operation, Object>>(object);
     243             :   }
     244             : };
     245             : ///@}
     246             : 
     247             : /**
     248             :  * Class that registers dispatchers of all Kokkos functors
     249             :  */
     250             : class DispatcherRegistry
     251             : {
     252             : public:
     253       40355 :   DispatcherRegistry() = default;
     254             : 
     255             :   DispatcherRegistry(DispatcherRegistry const &) = delete;
     256             :   DispatcherRegistry & operator=(DispatcherRegistry const &) = delete;
     257             : 
     258             :   DispatcherRegistry(DispatcherRegistry &&) = delete;
     259             :   DispatcherRegistry & operator=(DispatcherRegistry &&) = delete;
     260             : 
     261             :   /**
     262             :    * Register a dispatcher of an operation of a functor
     263             :    * @tparam Operation The function tag of operator() to be dispatched
     264             :    * @tparam Object The functor class type
     265             :    * @param name The registered object type name
     266             :    */
     267             :   template <typename Operation, typename Object>
     268    10488804 :   static void addDispatcher(const std::string & name)
     269             :   {
     270    10488804 :     auto operation = std::type_index(typeid(Operation));
     271             : 
     272    10488804 :     getRegistry()._dispatchers[std::make_pair(operation, name)] =
     273             :         std::make_unique<DispatcherRegistryEntry<Operation, Object>>();
     274    10488804 :   }
     275             : 
     276             :   /**
     277             :    * Register a reducer of an operation of a functor
     278             :    * @tparam Operation The function tag of operator() to be dispatched
     279             :    * @tparam Object The functor class type
     280             :    * @param name The registered object type name
     281             :    */
     282             :   template <typename Operation, typename Object>
     283      847384 :   static void addReducer(const std::string & name)
     284             :   {
     285      847384 :     auto operation = std::type_index(typeid(Operation));
     286             : 
     287      847384 :     getRegistry()._dispatchers[std::make_pair(operation, name)] =
     288             :         std::make_unique<ReducerRegistryEntry<Operation, Object>>();
     289      847384 :   }
     290             : 
     291             :   /**
     292             :    * Set whether the user has overriden the hook method associated with an operation of a functor
     293             :    * @tparam Operation The function tag of operator()
     294             :    * @param name The registered object type name
     295             :    * @param flag Whether the user has overriden the hook method
     296             :    */
     297             :   template <typename Operation>
     298     6656600 :   static void hasUserMethod(const std::string & name, const bool flag)
     299             :   {
     300     6656600 :     getDispatcher<Operation>(name)->hasUserMethod(flag);
     301     6656600 :   }
     302             : 
     303             :   /**
     304             :    * Get whether the user has overriden the hook method associated with an operation of a functor
     305             :    * @tparam Operation The function tag of operator()
     306             :    * @param name The registered object type name
     307             :    * @returns Whether the user has overriden the hook method
     308             :    */
     309             :   template <typename Operation>
     310       90531 :   static bool hasUserMethod(const std::string & name)
     311             :   {
     312       90531 :     return getDispatcher<Operation>(name)->hasUserMethod();
     313             :   }
     314             : 
     315             :   /**
     316             :    * Build and get a dispatcher of an operation of a functor
     317             :    * @tparam Operation The function tag of operator()
     318             :    * @param object The pointer to the functor
     319             :    * @param name The registered object type name
     320             :    * @returns The dispatcher
     321             :    */
     322             :   template <typename Operation>
     323       23889 :   static std::unique_ptr<DispatcherBase> build(const void * object, const std::string & name)
     324             :   {
     325       23889 :     return getDispatcher<Operation>(name)->build(object);
     326             :   }
     327             : 
     328             : private:
     329             :   /**
     330             :    * Get the registry singleton
     331             :    * @returns The registry singleton
     332             :    */
     333             :   static DispatcherRegistry & getRegistry();
     334             : 
     335             :   /**
     336             :    * Get the dispatcher shell of an operation of a functor
     337             :    * @tparam Operation The function tag of operator()
     338             :    * @param name The registered object type name
     339             :    * @returns The dispatcher shell
     340             :    */
     341             :   template <typename Operation>
     342     6771020 :   static auto & getDispatcher(const std::string & name)
     343             :   {
     344     6771020 :     auto operation = std::type_index(typeid(Operation));
     345             : 
     346     6771020 :     auto it = getRegistry()._dispatchers.find(std::make_pair(operation, name));
     347     6771020 :     if (it == getRegistry()._dispatchers.end())
     348           0 :       mooseError("Kokkos functor dispatcher not registered for object type '",
     349             :                  name,
     350             :                  "'. Double check that you used Kokkos-specific registration macro.");
     351             : 
     352    13542040 :     return it->second;
     353             :   }
     354             : 
     355             :   /**
     356             :    * Map containing the dispatcher shells with the key being the pair of function tag type index and
     357             :    * registered object type name
     358             :    */
     359             :   std::map<std::pair<std::type_index, std::string>, std::unique_ptr<DispatcherRegistryEntryBase>>
     360             :       _dispatchers;
     361             : };
     362             : 
     363             : } // namespace Moose::Kokkos
     364             : 
     365             : // Kernel, NodalKernel, BC
     366             : 
     367             : #define callRegisterKokkosResidualObjectFunction(classname, objectname)                            \
     368             :   static char registerKokkosResidualObject##classname()                                            \
     369             :   {                                                                                                \
     370             :     using namespace Moose::Kokkos;                                                                 \
     371             :                                                                                                    \
     372             :     DispatcherRegistry::addDispatcher<classname::ResidualLoop, classname>(objectname);             \
     373             :     DispatcherRegistry::addDispatcher<classname::JacobianLoop, classname>(objectname);             \
     374             :     DispatcherRegistry::addDispatcher<classname::OffDiagJacobianLoop, classname>(objectname);      \
     375             :     DispatcherRegistry::hasUserMethod<classname::JacobianLoop>(                                    \
     376             :         objectname,                                                                                \
     377             :         &classname::computeQpJacobian<classname> != classname::defaultJacobian<classname>());      \
     378             :     DispatcherRegistry::hasUserMethod<classname::OffDiagJacobianLoop>(                             \
     379             :         objectname,                                                                                \
     380             :         &classname::computeQpOffDiagJacobian<classname> !=                                         \
     381             :             classname::defaultOffDiagJacobian<classname>());                                       \
     382             :                                                                                                    \
     383             :     return 0;                                                                                      \
     384             :   }                                                                                                \
     385             :                                                                                                    \
     386             :   static char combineNames(kokkos_dispatcher_residual_object_##classname, __COUNTER__) =           \
     387             :       registerKokkosResidualObject##classname()
     388             : 
     389             : #define registerKokkosResidualObject(app, classname)                                               \
     390             :   registerMooseObject(app, classname);                                                             \
     391             :   callRegisterKokkosResidualObjectFunction(classname, #classname)
     392             : 
     393             : #define registerKokkosResidualObjectAliased(app, classname, alias)                                 \
     394             :   registerMooseObjectAliased(app, classname, alias);                                               \
     395             :   callRegisterKokkosResidualObjectFunction(classname, alias)
     396             : 
     397             : // AD Kernel, NodalKernel, BC
     398             : 
     399             : #define callRegisterKokkosADResidualObjectFunction(classname, objectname)                          \
     400             :   static char registerKokkosADResidualObject##classname()                                          \
     401             :   {                                                                                                \
     402             :     using namespace Moose::Kokkos;                                                                 \
     403             :                                                                                                    \
     404             :     DispatcherRegistry::addDispatcher<classname::ResidualLoop, classname>(objectname);             \
     405             :                                                                                                    \
     406             :     return 0;                                                                                      \
     407             :   }                                                                                                \
     408             :                                                                                                    \
     409             :   static char combineNames(kokkos_dispatcher_ad_residual_object_##classname, __COUNTER__) =        \
     410             :       registerKokkosADResidualObject##classname()
     411             : 
     412             : #define registerKokkosADResidualObject(app, classname)                                             \
     413             :   registerMooseObject(app, classname);                                                             \
     414             :   callRegisterKokkosADResidualObjectFunction(classname, #classname)
     415             : 
     416             : #define registerKokkosADResidualObjectAliased(app, classname, alias)                               \
     417             :   registerMooseObjectAliased(app, classname, alias);                                               \
     418             :   callRegisterKokkosADResidualObjectFunction(classname, alias)
     419             : 
     420             : // Material
     421             : 
     422             : #define callRegisterKokkosMaterialFunction(classname, objectname)                                  \
     423             :   static char registerKokkosMaterial##classname()                                                  \
     424             :   {                                                                                                \
     425             :     using namespace Moose::Kokkos;                                                                 \
     426             :                                                                                                    \
     427             :     DispatcherRegistry::addDispatcher<classname::ElementInit, classname>(objectname);              \
     428             :     DispatcherRegistry::addDispatcher<classname::SideInit, classname>(objectname);                 \
     429             :     DispatcherRegistry::addDispatcher<classname::NeighborInit, classname>(objectname);             \
     430             :     DispatcherRegistry::addDispatcher<classname::ElementCompute, classname>(objectname);           \
     431             :     DispatcherRegistry::addDispatcher<classname::SideCompute, classname>(objectname);              \
     432             :     DispatcherRegistry::addDispatcher<classname::NeighborCompute, classname>(objectname);          \
     433             :     DispatcherRegistry::hasUserMethod<classname::ElementInit>(                                     \
     434             :         objectname,                                                                                \
     435             :         &classname::initQpStatefulProperties<classname> !=                                         \
     436             :             classname::defaultInitStateful<classname>());                                          \
     437             :     DispatcherRegistry::hasUserMethod<classname::SideInit>(                                        \
     438             :         objectname,                                                                                \
     439             :         &classname::initQpStatefulProperties<classname> !=                                         \
     440             :             classname::defaultInitStateful<classname>());                                          \
     441             :     DispatcherRegistry::hasUserMethod<classname::NeighborInit>(                                    \
     442             :         objectname,                                                                                \
     443             :         &classname::initQpStatefulProperties<classname> !=                                         \
     444             :             classname::defaultInitStateful<classname>());                                          \
     445             :                                                                                                    \
     446             :     return 0;                                                                                      \
     447             :   }                                                                                                \
     448             :                                                                                                    \
     449             :   static char combineNames(kokkos_dispatcher_material_##classname, __COUNTER__) =                  \
     450             :       registerKokkosMaterial##classname()
     451             : 
     452             : #define registerKokkosMaterial(app, classname)                                                     \
     453             :   registerMooseObject(app, classname);                                                             \
     454             :   callRegisterKokkosMaterialFunction(classname, #classname)
     455             : 
     456             : #define registerKokkosMaterialAliased(app, classname, alias)                                       \
     457             :   registerMooseObjectAliased(app, classname, alias);                                               \
     458             :   callRegisterKokkosMaterialFunction(classname, alias)
     459             : 
     460             : // AuxKernel
     461             : 
     462             : #define callRegisterKokkosAuxKernelFunction(classname, objectname)                                 \
     463             :   static char registerKokkosAuxKernel##classname()                                                 \
     464             :   {                                                                                                \
     465             :     using namespace Moose::Kokkos;                                                                 \
     466             :                                                                                                    \
     467             :     DispatcherRegistry::addDispatcher<classname::ElementLoop, classname>(objectname);              \
     468             :     DispatcherRegistry::addDispatcher<classname::NodeLoop, classname>(objectname);                 \
     469             :                                                                                                    \
     470             :     return 0;                                                                                      \
     471             :   }                                                                                                \
     472             :                                                                                                    \
     473             :   static char combineNames(kokkos_dispatcher_auxkernel_##classname, __COUNTER__) =                 \
     474             :       registerKokkosAuxKernel##classname()
     475             : 
// Registers `classname` as a MOOSE object (via registerMooseObject) and registers its Kokkos
// auxiliary-kernel dispatchers under the stringized class name.
#define registerKokkosAuxKernel(app, classname)                                                    \
  registerMooseObject(app, classname);                                                             \
  callRegisterKokkosAuxKernelFunction(classname, #classname)

// Same as registerKokkosAuxKernel, except that `alias` is passed both to
// registerMooseObjectAliased and as the object name for the Kokkos dispatchers.
#define registerKokkosAuxKernelAliased(app, classname, alias)                                      \
  registerMooseObjectAliased(app, classname, alias);                                               \
  callRegisterKokkosAuxKernelFunction(classname, alias)
     483             : 
     484             : // UserObject
     485             : 
     486             : #define callRegisterKokkosUserObjectFunction(classname, objectname)                                \
     487             :   static char registerKokkosUserObject##classname()                                                \
     488             :   {                                                                                                \
     489             :     using namespace Moose::Kokkos;                                                                 \
     490             :                                                                                                    \
     491             :     DispatcherRegistry::addDispatcher<classname::DefaultLoop, classname>(objectname);              \
     492             :     DispatcherRegistry::addReducer<classname::ReducerLoop, classname>(objectname);                 \
     493             :     DispatcherRegistry::hasUserMethod<classname::DefaultLoop>(                                     \
     494             :         objectname, &classname::execute<classname> != classname::defaultExecute<classname>());     \
     495             :     DispatcherRegistry::hasUserMethod<classname::ReducerLoop>(                                     \
     496             :         objectname, &classname::reduce<classname> != classname::defaultReduce<classname>());       \
     497             :                                                                                                    \
     498             :     return 0;                                                                                      \
     499             :   }                                                                                                \
     500             :                                                                                                    \
     501             :   static char combineNames(kokkos_dispatcher_userobject_##classname, __COUNTER__) =                \
     502             :       registerKokkosUserObject##classname()
     503             : 
// Registers `classname` as a MOOSE object (via registerMooseObject) and registers its Kokkos
// user-object loops under the stringized class name.
#define registerKokkosUserObject(app, classname)                                                   \
  registerMooseObject(app, classname);                                                             \
  callRegisterKokkosUserObjectFunction(classname, #classname)

// Same as registerKokkosUserObject, except that `alias` is passed both to
// registerMooseObjectAliased and as the object name for the Kokkos loops.
#define registerKokkosUserObjectAliased(app, classname, alias)                                     \
  registerMooseObjectAliased(app, classname, alias);                                               \
  callRegisterKokkosUserObjectFunction(classname, alias)
     511             : 
     512             : // User-defined parallel operation registry
     513             : 
     514             : #define registerKokkosAdditionalOperation(classname, operation)                                    \
     515             :   static char registerKokkos##classname##operation()                                               \
     516             :   {                                                                                                \
     517             :     using namespace Moose::Kokkos;                                                                 \
     518             :                                                                                                    \
     519             :     DispatcherRegistry::addDispatcher<classname::operation, classname>(#classname);                \
     520             :                                                                                                    \
     521             :     return 0;                                                                                      \
     522             :   }                                                                                                \
     523             :                                                                                                    \
     524             :   static char combineNames(kokkos_##classname##_##operation, __COUNTER__) =                        \
     525             :       registerKokkos##classname##operation()

Generated by: LCOV version 1.14