Line data Source code
1 : //* This file is part of the MOOSE framework
2 : //* https://www.mooseframework.org
3 : //*
4 : //* All rights reserved, see COPYRIGHT for full restrictions
5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 : //*
7 : //* Licensed under LGPL 2.1, please see LICENSE for details
8 : //* https://www.gnu.org/licenses/lgpl-2.1.html
9 :
10 : #pragma once
11 :
12 : #include "KokkosHeader.h"
13 : #include "KokkosThread.h"
14 :
15 : #include <typeindex>
16 :
17 : namespace Moose
18 : {
19 : namespace Kokkos
20 : {
21 :
22 : using Policy = ::Kokkos::RangePolicy<ExecSpace, ::Kokkos::IndexType<ThreadID>>;
23 :
24 : /**
25 : * Base class for Kokkos functor dispatcher.
26 : * Used for type erasure so that the base class of functors can hold the dispatcher without knowing
27 : * the actual type of functors.
28 : */
29 : class DispatcherBase
30 : {
31 : public:
32 196710 : virtual ~DispatcherBase() {}
33 : /**
34 : * Dispatch this functor with Kokkos parallel_for() given a Kokkos execution policy
35 : * @param policy The Kokkos execution policy
36 : */
37 : virtual void parallelFor(const Policy & policy) = 0;
38 : };
39 :
40 : /**
41 : * Class that dispatches an operation of a Kokkos functor.
42 : * Calls operator() of the functor with a specified function tag.
43 : * @tparam Operation The function tag of operator() to be dispatched
44 : * @tparam Object The functor class type
45 : */
46 : template <typename Operation, typename Object>
47 : class Dispatcher : public DispatcherBase
48 : {
49 : public:
50 : /**
51 : * Constructor
52 : * @param object The pointer to the functor. This dispatcher is constructed by the base class of
53 : * functors, and the actual type of functors is unknown by the base class. Therefore, it is passed
54 : * as a void pointer and cast to the actual type here.
55 : */
56 6232 : Dispatcher(const void * object)
57 4889 : : _functor_host(*static_cast<const Object *>(object)), _functor_device(_functor_host)
58 : {
59 6232 : }
60 : /**
61 : * Copy constructor for parallel dispatch
62 : */
63 190516 : Dispatcher(const Dispatcher & functor)
64 155189 : : _functor_host(functor._functor_host), _functor_device(functor._functor_host)
65 : {
66 190516 : }
67 :
68 190515 : void parallelFor(const Policy & policy) override final
69 : {
70 190515 : ::Kokkos::parallel_for(policy, *this);
71 190515 : ::Kokkos::fence();
72 190515 : }
73 :
74 : /**
75 : * The parallel computation entry function called by Kokkos
76 : */
77 5647140 : KOKKOS_FUNCTION void operator()(const ThreadID tid) const
78 : {
79 5647140 : _functor_device(Operation{}, tid, _functor_device);
80 5647140 : }
81 :
82 : private:
83 : /**
84 : * Reference of the functor on host
85 : */
86 : const Object & _functor_host;
87 : /**
88 : * Copy of the functor on device
89 : */
90 : const Object _functor_device;
91 : };
92 :
93 : /**
94 : * Base class for dispatcher registry entry.
95 : * Used for type erasure so that the registry can hold dispatchers for different functor types in a
96 : * single container.
97 : */
98 : class DispatcherRegistryEntryBase
99 : {
100 : public:
101 0 : virtual ~DispatcherRegistryEntryBase() {}
102 : /**
103 : * Build a dispatcher for this operation and functor
104 : * @param object The pointer to the functor
105 : */
106 : virtual std::unique_ptr<DispatcherBase> build(const void * object) const = 0;
107 :
108 : /**
109 : * Set whether the user has overriden the hook method associated with this operation
110 : * @param flag Whether the user has overriden the hook method
111 : */
112 3737125 : void hasUserMethod(bool flag) { _has_user_method = flag; }
113 : /**
114 : * Get whether the user has overriden the hook method associated with this operation
115 : * @returns Whether the user has overriden the hook method
116 : */
117 40908 : bool hasUserMethod() const { return _has_user_method; }
118 :
119 : private:
120 : /**
121 : * Flag whether the user has overriden the hook method associated with this operation
122 : */
123 : bool _has_user_method = false;
124 : };
125 :
126 : /**
127 : * Class that stores the information of a dispatcher and builds it.
128 : * This shell class is the entry of the dispatcher registry instead of the dispatcher itself.
129 : * The reason this class does not dispatch the functor directly is to let the dispatcher hold
130 : * the reference of the functor so that the functor does not need to be copied twice at each
131 : * dispatch. Namely, dispatchers are to be built and held by the functors, not the registry.
132 : * @tparam Operation The function tag of operator() to be dispatched
133 : * @tparam Object The functor class type
134 : */
135 : template <typename Operation, typename Object>
136 : class DispatcherRegistryEntry : public DispatcherRegistryEntryBase
137 : {
138 : public:
139 6232 : std::unique_ptr<DispatcherBase> build(const void * object) const override final
140 : {
141 6232 : return std::make_unique<Dispatcher<Operation, Object>>(object);
142 : }
143 : };
144 :
145 : /**
146 : * Class that registers dispatchers of all Kokkos functors
147 : */
148 : class DispatcherRegistry
149 : {
150 : public:
151 37386 : DispatcherRegistry() = default;
152 :
153 : DispatcherRegistry(DispatcherRegistry const &) = delete;
154 : DispatcherRegistry & operator=(DispatcherRegistry const &) = delete;
155 :
156 : DispatcherRegistry(DispatcherRegistry &&) = delete;
157 : DispatcherRegistry & operator=(DispatcherRegistry &&) = delete;
158 :
159 : /**
160 : * Register a dispatcher of an operation of a functor
161 : * @tparam Operation The function tag of operator() to be dispatched
162 : * @tparam Object The functor class type
163 : * @param name The registered object type name
164 : */
165 : template <typename Operation, typename Object>
166 6390308 : static void add(const std::string & name)
167 : {
168 6390308 : auto operation = std::type_index(typeid(Operation));
169 :
170 6390308 : getRegistry()._dispatchers[std::make_pair(operation, name)] =
171 : std::make_unique<DispatcherRegistryEntry<Operation, Object>>();
172 6390308 : }
173 :
174 : /**
175 : * Set whether the user has overriden the hook method associated with an operation of a functor
176 : * @tparam Operation The function tag of operator()
177 : * @param name The registered object type name
178 : * @param flag Whether the user has overriden the hook method
179 : */
180 : template <typename Operation>
181 3737125 : static void hasUserMethod(const std::string & name, bool flag)
182 : {
183 3737125 : getDispatcher<Operation>(name)->hasUserMethod(flag);
184 3737125 : }
185 :
186 : /**
187 : * Get whether the user has overriden the hook method associated with an operation of a functor
188 : * @tparam Operation The function tag of operator()
189 : * @param name The registered object type name
190 : * @returns Whether the user has overriden the hook method
191 : */
192 : template <typename Operation>
193 40908 : static bool hasUserMethod(const std::string & name)
194 : {
195 40908 : return getDispatcher<Operation>(name)->hasUserMethod();
196 : }
197 :
198 : /**
199 : * Build and get a dispatcher of an operation of a functor
200 : * @tparam Operation The function tag of operator()
201 : * @param object The pointer to the functor
202 : * @param name The registered object type name
203 : * @returns The dispatcher
204 : */
205 : template <typename Operation>
206 6232 : static std::unique_ptr<DispatcherBase> build(const void * object, const std::string & name)
207 : {
208 6232 : return getDispatcher<Operation>(name)->build(object);
209 : }
210 :
211 : private:
212 : /**
213 : * Get the registry singleton
214 : * @returns The registry singleton
215 : */
216 : static DispatcherRegistry & getRegistry();
217 :
218 : /**
219 : * Get the dispatcher shell of an operation of a functor
220 : * @tparam Operation The function tag of operator()
221 : * @param name The registered object type name
222 : * @returns The dispatcher shell
223 : */
224 : template <typename Operation>
225 3784265 : static auto & getDispatcher(const std::string & name)
226 : {
227 3784265 : auto operation = std::type_index(typeid(Operation));
228 :
229 3784265 : auto it = getRegistry()._dispatchers.find(std::make_pair(operation, name));
230 3784265 : if (it == getRegistry()._dispatchers.end())
231 0 : mooseError("Kokkos functor dispatcher not registered for object type '",
232 : name,
233 : "'. Double check that you used Kokkos-specific registration macro.");
234 :
235 7568530 : return it->second;
236 : }
237 :
238 : /**
239 : * Map containing the dispatcher shells with the key being the pair of function tag type index and
240 : * registered object type name
241 : */
242 : std::map<std::pair<std::type_index, std::string>, std::unique_ptr<DispatcherRegistryEntryBase>>
243 : _dispatchers;
244 : };
245 :
246 : } // namespace Kokkos
247 : } // namespace Moose
248 :
// Register the Kokkos dispatchers of a Kokkos residual object class under the registered type name
// `objectname`: ResidualLoop, JacobianLoop and OffDiagJacobianLoop. It also records whether the
// class overrides computeQpJacobian / computeQpOffDiagJacobian, by comparing against
// classname::defaultJacobian() / classname::defaultOffDiagJacobian().
// The work runs in the initializer of a static variable whose name includes __COUNTER__. An
// immediately invoked lambda is used instead of a named helper function. A fixed helper name would
// be redefined if the same class were registered twice in one translation unit (e.g. under its own
// name and under an alias).
#define callRegisterKokkosResidualObjectFunction(classname, objectname)                          \
  static char combineNames(kokkos_dispatcher_residual_object_##classname, __COUNTER__) =         \
      []() -> char                                                                                \
  {                                                                                               \
    using namespace Moose::Kokkos;                                                                \
                                                                                                  \
    DispatcherRegistry::add<classname::ResidualLoop, classname>(objectname);                      \
    DispatcherRegistry::add<classname::JacobianLoop, classname>(objectname);                      \
    DispatcherRegistry::add<classname::OffDiagJacobianLoop, classname>(objectname);               \
    DispatcherRegistry::hasUserMethod<classname::JacobianLoop>(                                   \
        objectname, &classname::computeQpJacobian != classname::defaultJacobian());               \
    DispatcherRegistry::hasUserMethod<classname::OffDiagJacobianLoop>(                            \
        objectname, &classname::computeQpOffDiagJacobian != classname::defaultOffDiagJacobian()); \
                                                                                                  \
    return 0;                                                                                     \
  }()

// Register a Kokkos residual object with registerMooseObject and register its Kokkos dispatchers
// keyed by the class name.
#define registerKokkosResidualObject(app, classname)                                             \
  registerMooseObject(app, classname);                                                           \
  callRegisterKokkosResidualObjectFunction(classname, #classname)

// Same as registerKokkosResidualObject, but passes `alias` to registerMooseObjectAliased and keys
// the Kokkos dispatchers by `alias`.
#define registerKokkosResidualObjectAliased(app, classname, alias)                               \
  registerMooseObjectAliased(app, classname, alias);                                             \
  callRegisterKokkosResidualObjectFunction(classname, alias)
275 :
// Register the Kokkos dispatchers of a Kokkos material class under the registered type name
// `objectname`: ElementInit, SideInit, NeighborInit, ElementCompute, SideCompute and
// NeighborCompute. For the three *Init operations, it also records whether the class overrides
// initQpStatefulProperties, by comparing against classname::defaultInitStateful().
// The work runs in the initializer of a static variable whose name includes __COUNTER__. An
// immediately invoked lambda is used instead of a named helper function. A fixed helper name would
// be redefined if the same class were registered twice in one translation unit (e.g. under its own
// name and under an alias).
#define callRegisterKokkosMaterialFunction(classname, objectname)                                \
  static char combineNames(kokkos_dispatcher_material_##classname, __COUNTER__) = []() -> char   \
  {                                                                                               \
    using namespace Moose::Kokkos;                                                                \
                                                                                                  \
    DispatcherRegistry::add<classname::ElementInit, classname>(objectname);                       \
    DispatcherRegistry::add<classname::SideInit, classname>(objectname);                          \
    DispatcherRegistry::add<classname::NeighborInit, classname>(objectname);                      \
    DispatcherRegistry::add<classname::ElementCompute, classname>(objectname);                    \
    DispatcherRegistry::add<classname::SideCompute, classname>(objectname);                       \
    DispatcherRegistry::add<classname::NeighborCompute, classname>(objectname);                   \
                                                                                                  \
    const bool has_user_init_stateful =                                                           \
        &classname::initQpStatefulProperties != classname::defaultInitStateful();                \
    DispatcherRegistry::hasUserMethod<classname::ElementInit>(objectname,                         \
                                                              has_user_init_stateful);            \
    DispatcherRegistry::hasUserMethod<classname::SideInit>(objectname, has_user_init_stateful);   \
    DispatcherRegistry::hasUserMethod<classname::NeighborInit>(objectname,                        \
                                                               has_user_init_stateful);           \
                                                                                                  \
    return 0;                                                                                     \
  }()

// Register a Kokkos material with registerMooseObject and register its Kokkos dispatchers keyed by
// the class name.
#define registerKokkosMaterial(app, classname)                                                   \
  registerMooseObject(app, classname);                                                           \
  callRegisterKokkosMaterialFunction(classname, #classname)

// Same as registerKokkosMaterial, but passes `alias` to registerMooseObjectAliased and keys the
// Kokkos dispatchers by `alias`.
#define registerKokkosMaterialAliased(app, classname, alias)                                     \
  registerMooseObjectAliased(app, classname, alias);                                             \
  callRegisterKokkosMaterialFunction(classname, alias)
|