LCOV - code coverage report
Current view: top level - include/kokkos/base - KokkosDispatcher.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: 863ef6 Lines: 35 37 94.6 %
Date: 2025-10-15 18:16:15 Functions: 692 1241 55.8 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://www.mooseframework.org
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : #include "KokkosHeader.h"
      13             : #include "KokkosThread.h"
      14             : 
      15             : #include <typeindex>
      16             : 
      17             : namespace Moose
      18             : {
      19             : namespace Kokkos
      20             : {
      21             : 
/**
 * Kokkos range execution policy used to launch the dispatchers: runs on ExecSpace and uses
 * ThreadID as the loop index type.
 */
using Policy = ::Kokkos::RangePolicy<ExecSpace, ::Kokkos::IndexType<ThreadID>>;
      23             : 
      24             : /**
      25             :  * Base class for Kokkos functor dispatcher.
      26             :  * Used for type erasure so that the base class of functors can hold the dispatcher without knowing
      27             :  * the actual type of functors.
      28             :  */
      29             : class DispatcherBase
      30             : {
      31             : public:
      32      196710 :   virtual ~DispatcherBase() {}
      33             :   /**
      34             :    * Dispatch this functor with Kokkos parallel_for() given a Kokkos execution policy
      35             :    * @param policy The Kokkos execution policy
      36             :    */
      37             :   virtual void parallelFor(const Policy & policy) = 0;
      38             : };
      39             : 
      40             : /**
      41             :  * Class that dispatches an operation of a Kokkos functor.
      42             :  * Calls operator() of the functor with a specified function tag.
      43             :  * @tparam Operation The function tag of operator() to be dispatched
      44             :  * @tparam Object The functor class type
      45             :  */
      46             : template <typename Operation, typename Object>
      47             : class Dispatcher : public DispatcherBase
      48             : {
      49             : public:
      50             :   /**
      51             :    * Constructor
      52             :    * @param object The pointer to the functor. This dispatcher is constructed by the base class of
      53             :    * functors, and the actual type of functors is unknown by the base class. Therefore, it is passed
      54             :    * as a void pointer and cast to the actual type here.
      55             :    */
      56        6232 :   Dispatcher(const void * object)
      57        4889 :     : _functor_host(*static_cast<const Object *>(object)), _functor_device(_functor_host)
      58             :   {
      59        6232 :   }
      60             :   /**
      61             :    * Copy constructor for parallel dispatch
      62             :    */
      63      190516 :   Dispatcher(const Dispatcher & functor)
      64      155189 :     : _functor_host(functor._functor_host), _functor_device(functor._functor_host)
      65             :   {
      66      190516 :   }
      67             : 
      68      190515 :   void parallelFor(const Policy & policy) override final
      69             :   {
      70      190515 :     ::Kokkos::parallel_for(policy, *this);
      71      190515 :     ::Kokkos::fence();
      72      190515 :   }
      73             : 
      74             :   /**
      75             :    * The parallel computation entry function called by Kokkos
      76             :    */
      77     5647140 :   KOKKOS_FUNCTION void operator()(const ThreadID tid) const
      78             :   {
      79     5647140 :     _functor_device(Operation{}, tid, _functor_device);
      80     5647140 :   }
      81             : 
      82             : private:
      83             :   /**
      84             :    * Reference of the functor on host
      85             :    */
      86             :   const Object & _functor_host;
      87             :   /**
      88             :    * Copy of the functor on device
      89             :    */
      90             :   const Object _functor_device;
      91             : };
      92             : 
      93             : /**
      94             :  * Base class for dispatcher registry entry.
      95             :  * Used for type erasure so that the registry can hold dispatchers for different functor types in a
      96             :  * single container.
      97             :  */
      98             : class DispatcherRegistryEntryBase
      99             : {
     100             : public:
     101           0 :   virtual ~DispatcherRegistryEntryBase() {}
     102             :   /**
     103             :    * Build a dispatcher for this operation and functor
     104             :    * @param object The pointer to the functor
     105             :    */
     106             :   virtual std::unique_ptr<DispatcherBase> build(const void * object) const = 0;
     107             : 
     108             :   /**
     109             :    * Set whether the user has overriden the hook method associated with this operation
     110             :    * @param flag Whether the user has overriden the hook method
     111             :    */
     112     3737125 :   void hasUserMethod(bool flag) { _has_user_method = flag; }
     113             :   /**
     114             :    * Get whether the user has overriden the hook method associated with this operation
     115             :    * @returns Whether the user has overriden the hook method
     116             :    */
     117       40908 :   bool hasUserMethod() const { return _has_user_method; }
     118             : 
     119             : private:
     120             :   /**
     121             :    * Flag whether the user has overriden the hook method associated with this operation
     122             :    */
     123             :   bool _has_user_method = false;
     124             : };
     125             : 
     126             : /**
     127             :  * Class that stores the information of a dispatcher and builds it.
     128             :  * This shell class is the entry of the dispatcher registry instead of the dispatcher itself.
     129             :  * The reason this class does not dispatch the functor directly is to let the dispatcher hold
     130             :  * the reference of the functor so that the functor does not need to be copied twice at each
     131             :  * dispatch. Namely, dispatchers are to be built and held by the functors, not the registry.
     132             :  * @tparam Operation The function tag of operator() to be dispatched
     133             :  * @tparam Object The functor class type
     134             :  */
     135             : template <typename Operation, typename Object>
     136             : class DispatcherRegistryEntry : public DispatcherRegistryEntryBase
     137             : {
     138             : public:
     139        6232 :   std::unique_ptr<DispatcherBase> build(const void * object) const override final
     140             :   {
     141        6232 :     return std::make_unique<Dispatcher<Operation, Object>>(object);
     142             :   }
     143             : };
     144             : 
     145             : /**
     146             :  * Class that registers dispatchers of all Kokkos functors
     147             :  */
     148             : class DispatcherRegistry
     149             : {
     150             : public:
     151       37386 :   DispatcherRegistry() = default;
     152             : 
     153             :   DispatcherRegistry(DispatcherRegistry const &) = delete;
     154             :   DispatcherRegistry & operator=(DispatcherRegistry const &) = delete;
     155             : 
     156             :   DispatcherRegistry(DispatcherRegistry &&) = delete;
     157             :   DispatcherRegistry & operator=(DispatcherRegistry &&) = delete;
     158             : 
     159             :   /**
     160             :    * Register a dispatcher of an operation of a functor
     161             :    * @tparam Operation The function tag of operator() to be dispatched
     162             :    * @tparam Object The functor class type
     163             :    * @param name The registered object type name
     164             :    */
     165             :   template <typename Operation, typename Object>
     166     6390308 :   static void add(const std::string & name)
     167             :   {
     168     6390308 :     auto operation = std::type_index(typeid(Operation));
     169             : 
     170     6390308 :     getRegistry()._dispatchers[std::make_pair(operation, name)] =
     171             :         std::make_unique<DispatcherRegistryEntry<Operation, Object>>();
     172     6390308 :   }
     173             : 
     174             :   /**
     175             :    * Set whether the user has overriden the hook method associated with an operation of a functor
     176             :    * @tparam Operation The function tag of operator()
     177             :    * @param name The registered object type name
     178             :    * @param flag Whether the user has overriden the hook method
     179             :    */
     180             :   template <typename Operation>
     181     3737125 :   static void hasUserMethod(const std::string & name, bool flag)
     182             :   {
     183     3737125 :     getDispatcher<Operation>(name)->hasUserMethod(flag);
     184     3737125 :   }
     185             : 
     186             :   /**
     187             :    * Get whether the user has overriden the hook method associated with an operation of a functor
     188             :    * @tparam Operation The function tag of operator()
     189             :    * @param name The registered object type name
     190             :    * @returns Whether the user has overriden the hook method
     191             :    */
     192             :   template <typename Operation>
     193       40908 :   static bool hasUserMethod(const std::string & name)
     194             :   {
     195       40908 :     return getDispatcher<Operation>(name)->hasUserMethod();
     196             :   }
     197             : 
     198             :   /**
     199             :    * Build and get a dispatcher of an operation of a functor
     200             :    * @tparam Operation The function tag of operator()
     201             :    * @param object The pointer to the functor
     202             :    * @param name The registered object type name
     203             :    * @returns The dispatcher
     204             :    */
     205             :   template <typename Operation>
     206        6232 :   static std::unique_ptr<DispatcherBase> build(const void * object, const std::string & name)
     207             :   {
     208        6232 :     return getDispatcher<Operation>(name)->build(object);
     209             :   }
     210             : 
     211             : private:
     212             :   /**
     213             :    * Get the registry singleton
     214             :    * @returns The registry singleton
     215             :    */
     216             :   static DispatcherRegistry & getRegistry();
     217             : 
     218             :   /**
     219             :    * Get the dispatcher shell of an operation of a functor
     220             :    * @tparam Operation The function tag of operator()
     221             :    * @param name The registered object type name
     222             :    * @returns The dispatcher shell
     223             :    */
     224             :   template <typename Operation>
     225     3784265 :   static auto & getDispatcher(const std::string & name)
     226             :   {
     227     3784265 :     auto operation = std::type_index(typeid(Operation));
     228             : 
     229     3784265 :     auto it = getRegistry()._dispatchers.find(std::make_pair(operation, name));
     230     3784265 :     if (it == getRegistry()._dispatchers.end())
     231           0 :       mooseError("Kokkos functor dispatcher not registered for object type '",
     232             :                  name,
     233             :                  "'. Double check that you used Kokkos-specific registration macro.");
     234             : 
     235     7568530 :     return it->second;
     236             :   }
     237             : 
     238             :   /**
     239             :    * Map containing the dispatcher shells with the key being the pair of function tag type index and
     240             :    * registered object type name
     241             :    */
     242             :   std::map<std::pair<std::type_index, std::string>, std::unique_ptr<DispatcherRegistryEntryBase>>
     243             :       _dispatchers;
     244             : };
     245             : 
     246             : } // namespace Kokkos
     247             : } // namespace Moose
     248             : 
     249             : #define callRegisterKokkosResidualObjectFunction(classname, objectname)                            \
     250             :   static char registerKokkosResidualObject##classname()                                            \
     251             :   {                                                                                                \
     252             :     using namespace Moose::Kokkos;                                                                 \
     253             :                                                                                                    \
     254             :     DispatcherRegistry::add<classname::ResidualLoop, classname>(objectname);                       \
     255             :     DispatcherRegistry::add<classname::JacobianLoop, classname>(objectname);                       \
     256             :     DispatcherRegistry::add<classname::OffDiagJacobianLoop, classname>(objectname);                \
     257             :     DispatcherRegistry::hasUserMethod<classname::JacobianLoop>(                                    \
     258             :         objectname, &classname::computeQpJacobian != classname::defaultJacobian());                \
     259             :     DispatcherRegistry::hasUserMethod<classname::OffDiagJacobianLoop>(                             \
     260             :         objectname, &classname::computeQpOffDiagJacobian != classname::defaultOffDiagJacobian());  \
     261             :                                                                                                    \
     262             :     return 0;                                                                                      \
     263             :   }                                                                                                \
     264             :                                                                                                    \
     265             :   static char combineNames(kokkos_dispatcher_residual_object_##classname, __COUNTER__) =           \
     266             :       registerKokkosResidualObject##classname()
     267             : 
     268             : #define registerKokkosResidualObject(app, classname)                                               \
     269             :   registerMooseObject(app, classname);                                                             \
     270             :   callRegisterKokkosResidualObjectFunction(classname, #classname)
     271             : 
     272             : #define registerKokkosResidualObjectAliased(app, classname, alias)                                 \
     273             :   registerMooseObjectAliased(app, classname, alias);                                               \
     274             :   callRegisterKokkosResidualObjectFunction(classname, alias)
     275             : 
     276             : #define callRegisterKokkosMaterialFunction(classname, objectname)                                  \
     277             :   static char registerKokkosMaterial##classname()                                                  \
     278             :   {                                                                                                \
     279             :     using namespace Moose::Kokkos;                                                                 \
     280             :                                                                                                    \
     281             :     DispatcherRegistry::add<classname::ElementInit, classname>(objectname);                        \
     282             :     DispatcherRegistry::add<classname::SideInit, classname>(objectname);                           \
     283             :     DispatcherRegistry::add<classname::NeighborInit, classname>(objectname);                       \
     284             :     DispatcherRegistry::add<classname::ElementCompute, classname>(objectname);                     \
     285             :     DispatcherRegistry::add<classname::SideCompute, classname>(objectname);                        \
     286             :     DispatcherRegistry::add<classname::NeighborCompute, classname>(objectname);                    \
     287             :     DispatcherRegistry::hasUserMethod<classname::ElementInit>(                                     \
     288             :         objectname, &classname::initQpStatefulProperties != classname::defaultInitStateful());     \
     289             :     DispatcherRegistry::hasUserMethod<classname::SideInit>(                                        \
     290             :         objectname, &classname::initQpStatefulProperties != classname::defaultInitStateful());     \
     291             :     DispatcherRegistry::hasUserMethod<classname::NeighborInit>(                                    \
     292             :         objectname, &classname::initQpStatefulProperties != classname::defaultInitStateful());     \
     293             :                                                                                                    \
     294             :     return 0;                                                                                      \
     295             :   }                                                                                                \
     296             :                                                                                                    \
     297             :   static char combineNames(kokkos_dispatcher_material_##classname, __COUNTER__) =                  \
     298             :       registerKokkosMaterial##classname()
     299             : 
     300             : #define registerKokkosMaterial(app, classname)                                                     \
     301             :   registerMooseObject(app, classname);                                                             \
     302             :   callRegisterKokkosMaterialFunction(classname, #classname)
     303             : 
     304             : #define registerKokkosMaterialAliased(app, classname, alias)                                       \
     305             :   registerMooseObjectAliased(app, classname, alias);                                               \
     306             :   callRegisterKokkosMaterialFunction(classname, alias)

Generated by: LCOV version 1.14