Line data Source code
1 : //* This file is part of the MOOSE framework
2 : //* https://www.mooseframework.org
3 : //*
4 : //* All rights reserved, see COPYRIGHT for full restrictions
5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 : //*
7 : //* Licensed under LGPL 2.1, please see LICENSE for details
8 : //* https://www.gnu.org/licenses/lgpl-2.1.html
9 :
10 : #pragma once
11 :
12 : #include "KokkosHeader.h"
13 : #include "KokkosThread.h"
14 :
15 : #include <typeindex>
16 :
17 : namespace Moose::Kokkos
18 : {
19 :
20 : using Policy = ::Kokkos::RangePolicy<ExecSpace, ::Kokkos::IndexType<ThreadID>>;
21 :
22 : /**
23 : * Base class for Kokkos functor dispatcher.
24 : * Used for type erasure so that the base class of functors can hold the dispatcher without knowing
25 : * the actual type of functors.
26 : */
27 : class DispatcherBase
28 : {
29 : public:
30 508889 : virtual ~DispatcherBase() {}
31 : /**
32 : * Dispatch this functor with Kokkos parallel_for() given a Kokkos execution policy
33 : * @param policy The Kokkos execution policy
34 : */
35 0 : virtual void parallelFor(const Policy & /* policy */)
36 : {
37 0 : mooseError("parallelFor() called for an instance that is not a dispatcher.");
38 : }
39 : /**
40 : * Dispatch this functor with Kokkos parallel_reduce() given a Kokkos execution policy and result
41 : * buffer
42 : * @param policy The Kokkos execution policy
43 : * @param result The result buffer
44 : */
45 0 : virtual void parallelReduce(const Policy & /* policy */,
46 : ::Kokkos::View<Real *, ::Kokkos::HostSpace> & /* result */)
47 : {
48 0 : mooseError("parallelReduce() called for an instance that is not a reducer.");
49 : }
50 : };
51 :
52 : /**
53 : * Class that dispatches a parallel loop operation of a Kokkos functor.
54 : * Calls operator() of the functor with a specified function tag.
55 : * @tparam Operation The function tag of operator() to be dispatched
56 : * @tparam Object The functor class type
57 : */
58 : template <typename Operation, typename Object>
59 : class Dispatcher : public DispatcherBase
60 : {
61 : public:
62 : /**
63 : * Constructor
64 : * @param object The pointer to the functor. This dispatcher is constructed by the base class of
65 : * functors, and the actual type of functors is unknown by the base class. Therefore, it is passed
66 : * as a void pointer and cast to the actual type here.
67 : */
68 22923 : Dispatcher(const void * object)
69 12609 : : _functor_host(*static_cast<const Object *>(object)), _functor_device(_functor_host)
70 : {
71 22923 : }
72 : /**
73 : * Copy constructor for parallel dispatch
74 : */
75 478284 : Dispatcher(const Dispatcher & functor)
76 272454 : : _functor_host(functor._functor_host), _functor_device(functor._functor_host)
77 : {
78 478284 : }
79 :
80 478283 : void parallelFor(const Policy & policy) override final
81 : {
82 478283 : ::Kokkos::parallel_for(policy, *this);
83 478283 : ::Kokkos::fence();
84 478283 : }
85 :
86 : /**
87 : * The parallel computation entry function called by Kokkos::parallel_for
88 : */
89 27318554 : KOKKOS_FUNCTION void operator()(const ThreadID tid) const
90 : {
91 27318554 : _functor_device(Operation{}, tid, _functor_device);
92 27318554 : }
93 :
94 : private:
95 : /**
96 : * Reference of the functor on host
97 : */
98 : const Object & _functor_host;
99 : /**
100 : * Copy of the functor on device
101 : */
102 : const Object _functor_device;
103 : };
104 :
105 : /**
106 : * Class that dispatches a parallel reduction operation of a Kokkos functor.
107 : * Calls operator() of the functor with a specified function tag.
108 : * @tparam Operation The function tag of operator() to be dispatched
109 : * @tparam Object The functor class type
110 : */
111 : template <typename Operation, typename Object>
112 : class Reducer : public DispatcherBase
113 : {
114 : public:
115 : /**
116 : * Constructor
117 : * @param object The pointer to the functor. This reducer is constructed by the base class of
118 : * functors, and the actual type of functors is unknown by the base class. Therefore, it is passed
119 : * as a void pointer and cast to the actual type here.
120 : */
121 966 : Reducer(const void * object)
122 538 : : _functor_host(*static_cast<const Object *>(object)), _functor_device(_functor_host)
123 : {
124 966 : }
125 : /**
126 : * Copy constructor for parallel dispatch
127 : */
128 6754 : Reducer(const Reducer & functor)
129 4092 : : value_count(functor.value_count),
130 4092 : _functor_host(functor._functor_host),
131 4092 : _functor_device(functor._functor_host)
132 : {
133 6754 : }
134 :
135 2251 : void parallelReduce(const Policy & policy,
136 : ::Kokkos::View<Real *, ::Kokkos::HostSpace> & result) override final
137 : {
138 2251 : value_count = result.size();
139 :
140 2251 : ::Kokkos::parallel_reduce(policy, *this, result);
141 2251 : ::Kokkos::fence();
142 2251 : }
143 :
144 : using value_type = Real[];
145 : using size_type = ::Kokkos::View<Real *>::size_type;
146 :
147 : size_type value_count;
148 :
149 : /**
150 : * The parallel computation entry function called by Kokkos::parallel_reduce
151 : */
152 382782 : KOKKOS_FUNCTION void operator()(const ThreadID tid, value_type result) const
153 : {
154 382782 : _functor_device(Operation{}, tid, _functor_device, result);
155 382782 : }
156 :
157 : /**
158 : * Functions required by the reducer concept of Kokkos
159 : */
160 : ///@{
161 169 : KOKKOS_FUNCTION void join(value_type result, const value_type source) const
162 : {
163 169 : _functor_device.template join<Object>(result, source);
164 169 : }
165 1533 : KOKKOS_FUNCTION void init(value_type result) const
166 : {
167 1533 : _functor_device.template init<Object>(result);
168 1533 : }
169 : ///@}
170 :
171 : private:
172 : /**
173 : * Reference of the functor on host
174 : */
175 : const Object & _functor_host;
176 : /**
177 : * Copy of the functor on device
178 : */
179 : const Object _functor_device;
180 : };
181 :
182 : /**
183 : * Base class for dispatcher registry entry.
184 : * Used for type erasure so that the registry can hold dispatchers for different functor types in a
185 : * single container.
186 : */
187 : class DispatcherRegistryEntryBase
188 : {
189 : public:
190 4 : virtual ~DispatcherRegistryEntryBase() {}
191 :
192 : /**
193 : * Build a dispatcher for this operation and functor
194 : * @param object The pointer to the functor
195 : */
196 : virtual std::unique_ptr<DispatcherBase> build(const void * object) const = 0;
197 :
198 : /**
199 : * Set whether the user has overriden the hook method associated with this operation
200 : * @param flag Whether the user has overriden the hook method
201 : */
202 6656600 : void hasUserMethod(bool flag) { _has_user_method = flag; }
203 : /**
204 : * Get whether the user has overriden the hook method associated with this operation
205 : * @returns Whether the user has overriden the hook method
206 : */
207 90531 : bool hasUserMethod() const { return _has_user_method; }
208 :
209 : private:
210 : /**
211 : * Flag whether the user has overriden the hook method associated with this operation
212 : */
213 : bool _has_user_method = false;
214 : };
215 :
216 : /**
217 : * Class that stores the information of a dispatcher and builds it.
218 : * This shell class is the entry of the dispatcher registry instead of the dispatcher itself.
219 : * The reason this class does not dispatch the functor directly is to let the dispatcher hold
220 : * the reference of the functor so that the functor does not need to be copied twice at each
221 : * dispatch. Namely, dispatchers are to be built and held by the functors, not the registry.
222 : * @tparam Operation The function tag of operator() to be dispatched
223 : * @tparam Object The functor class type
224 : */
225 : ///@{
226 : template <typename Operation, typename Object>
227 : class DispatcherRegistryEntry : public DispatcherRegistryEntryBase
228 : {
229 : public:
230 22923 : std::unique_ptr<DispatcherBase> build(const void * object) const override final
231 : {
232 22923 : return std::make_unique<Dispatcher<Operation, Object>>(object);
233 : }
234 : };
235 :
236 : template <typename Operation, typename Object>
237 : class ReducerRegistryEntry : public DispatcherRegistryEntryBase
238 : {
239 : public:
240 966 : std::unique_ptr<DispatcherBase> build(const void * object) const override final
241 : {
242 966 : return std::make_unique<Reducer<Operation, Object>>(object);
243 : }
244 : };
245 : ///@}
246 :
247 : /**
248 : * Class that registers dispatchers of all Kokkos functors
249 : */
250 : class DispatcherRegistry
251 : {
252 : public:
253 40355 : DispatcherRegistry() = default;
254 :
255 : DispatcherRegistry(DispatcherRegistry const &) = delete;
256 : DispatcherRegistry & operator=(DispatcherRegistry const &) = delete;
257 :
258 : DispatcherRegistry(DispatcherRegistry &&) = delete;
259 : DispatcherRegistry & operator=(DispatcherRegistry &&) = delete;
260 :
261 : /**
262 : * Register a dispatcher of an operation of a functor
263 : * @tparam Operation The function tag of operator() to be dispatched
264 : * @tparam Object The functor class type
265 : * @param name The registered object type name
266 : */
267 : template <typename Operation, typename Object>
268 10488804 : static void addDispatcher(const std::string & name)
269 : {
270 10488804 : auto operation = std::type_index(typeid(Operation));
271 :
272 10488804 : getRegistry()._dispatchers[std::make_pair(operation, name)] =
273 : std::make_unique<DispatcherRegistryEntry<Operation, Object>>();
274 10488804 : }
275 :
276 : /**
277 : * Register a reducer of an operation of a functor
278 : * @tparam Operation The function tag of operator() to be dispatched
279 : * @tparam Object The functor class type
280 : * @param name The registered object type name
281 : */
282 : template <typename Operation, typename Object>
283 847384 : static void addReducer(const std::string & name)
284 : {
285 847384 : auto operation = std::type_index(typeid(Operation));
286 :
287 847384 : getRegistry()._dispatchers[std::make_pair(operation, name)] =
288 : std::make_unique<ReducerRegistryEntry<Operation, Object>>();
289 847384 : }
290 :
291 : /**
292 : * Set whether the user has overriden the hook method associated with an operation of a functor
293 : * @tparam Operation The function tag of operator()
294 : * @param name The registered object type name
295 : * @param flag Whether the user has overriden the hook method
296 : */
297 : template <typename Operation>
298 6656600 : static void hasUserMethod(const std::string & name, const bool flag)
299 : {
300 6656600 : getDispatcher<Operation>(name)->hasUserMethod(flag);
301 6656600 : }
302 :
303 : /**
304 : * Get whether the user has overriden the hook method associated with an operation of a functor
305 : * @tparam Operation The function tag of operator()
306 : * @param name The registered object type name
307 : * @returns Whether the user has overriden the hook method
308 : */
309 : template <typename Operation>
310 90531 : static bool hasUserMethod(const std::string & name)
311 : {
312 90531 : return getDispatcher<Operation>(name)->hasUserMethod();
313 : }
314 :
315 : /**
316 : * Build and get a dispatcher of an operation of a functor
317 : * @tparam Operation The function tag of operator()
318 : * @param object The pointer to the functor
319 : * @param name The registered object type name
320 : * @returns The dispatcher
321 : */
322 : template <typename Operation>
323 23889 : static std::unique_ptr<DispatcherBase> build(const void * object, const std::string & name)
324 : {
325 23889 : return getDispatcher<Operation>(name)->build(object);
326 : }
327 :
328 : private:
329 : /**
330 : * Get the registry singleton
331 : * @returns The registry singleton
332 : */
333 : static DispatcherRegistry & getRegistry();
334 :
335 : /**
336 : * Get the dispatcher shell of an operation of a functor
337 : * @tparam Operation The function tag of operator()
338 : * @param name The registered object type name
339 : * @returns The dispatcher shell
340 : */
341 : template <typename Operation>
342 6771020 : static auto & getDispatcher(const std::string & name)
343 : {
344 6771020 : auto operation = std::type_index(typeid(Operation));
345 :
346 6771020 : auto it = getRegistry()._dispatchers.find(std::make_pair(operation, name));
347 6771020 : if (it == getRegistry()._dispatchers.end())
348 0 : mooseError("Kokkos functor dispatcher not registered for object type '",
349 : name,
350 : "'. Double check that you used Kokkos-specific registration macro.");
351 :
352 13542040 : return it->second;
353 : }
354 :
355 : /**
356 : * Map containing the dispatcher shells with the key being the pair of function tag type index and
357 : * registered object type name
358 : */
359 : std::map<std::pair<std::type_index, std::string>, std::unique_ptr<DispatcherRegistryEntryBase>>
360 : _dispatchers;
361 : };
362 :
363 : } // namespace Moose::Kokkos
364 :
365 : // Kernel, NodalKernel, BC
366 :
/**
 * Registers the Kokkos dispatchers of a residual object (kernel, nodal kernel, BC) under the
 * given object name: ResidualLoop, JacobianLoop and OffDiagJacobianLoop. It also records whether
 * the class overrides computeQpJacobian() and computeQpOffDiagJacobian() by comparing the member
 * function pointers with the defaults reported by the class.
 *
 * The registration runs in the initializer of a uniquely named static variable. An immediately
 * invoked lambda is used instead of a named helper function so that the same class can be
 * registered more than once in a translation unit (e.g. plain and aliased) without a
 * redefinition error.
 */
#define callRegisterKokkosResidualObjectFunction(classname, objectname)                            \
  static char combineNames(kokkos_dispatcher_residual_object_##classname, __COUNTER__) =           \
      []() -> char                                                                                 \
  {                                                                                                \
    using namespace Moose::Kokkos;                                                                 \
                                                                                                   \
    DispatcherRegistry::addDispatcher<classname::ResidualLoop, classname>(objectname);             \
    DispatcherRegistry::addDispatcher<classname::JacobianLoop, classname>(objectname);             \
    DispatcherRegistry::addDispatcher<classname::OffDiagJacobianLoop, classname>(objectname);      \
    DispatcherRegistry::hasUserMethod<classname::JacobianLoop>(                                    \
        objectname,                                                                                \
        &classname::computeQpJacobian<classname> != classname::defaultJacobian<classname>());      \
    DispatcherRegistry::hasUserMethod<classname::OffDiagJacobianLoop>(                             \
        objectname,                                                                                \
        &classname::computeQpOffDiagJacobian<classname> !=                                         \
            classname::defaultOffDiagJacobian<classname>());                                       \
                                                                                                   \
    return 0;                                                                                      \
  }()

/// Registers a Kokkos residual object with MOOSE and with the dispatcher registry under its class name
#define registerKokkosResidualObject(app, classname)                                               \
  registerMooseObject(app, classname);                                                             \
  callRegisterKokkosResidualObjectFunction(classname, #classname)

/// Registers a Kokkos residual object with MOOSE and with the dispatcher registry under an alias
#define registerKokkosResidualObjectAliased(app, classname, alias)                                 \
  registerMooseObjectAliased(app, classname, alias);                                               \
  callRegisterKokkosResidualObjectFunction(classname, alias)
396 :
397 : // AD Kernel, NodalKernel, BC
398 :
/**
 * Registers the Kokkos dispatcher of an AD residual object (kernel, nodal kernel, BC) under the
 * given object name. Only ResidualLoop is registered.
 *
 * The registration runs in the initializer of a uniquely named static variable. An immediately
 * invoked lambda is used instead of a named helper function so that the same class can be
 * registered more than once in a translation unit (e.g. plain and aliased) without a
 * redefinition error.
 */
#define callRegisterKokkosADResidualObjectFunction(classname, objectname)                          \
  static char combineNames(kokkos_dispatcher_ad_residual_object_##classname, __COUNTER__) =        \
      []() -> char                                                                                 \
  {                                                                                                \
    using namespace Moose::Kokkos;                                                                 \
                                                                                                   \
    DispatcherRegistry::addDispatcher<classname::ResidualLoop, classname>(objectname);             \
                                                                                                   \
    return 0;                                                                                      \
  }()

/// Registers a Kokkos AD residual object with MOOSE and with the dispatcher registry under its class name
#define registerKokkosADResidualObject(app, classname)                                             \
  registerMooseObject(app, classname);                                                             \
  callRegisterKokkosADResidualObjectFunction(classname, #classname)

/// Registers a Kokkos AD residual object with MOOSE and with the dispatcher registry under an alias
#define registerKokkosADResidualObjectAliased(app, classname, alias)                               \
  registerMooseObjectAliased(app, classname, alias);                                               \
  callRegisterKokkosADResidualObjectFunction(classname, alias)
419 :
420 : // Material
421 :
/**
 * Registers the Kokkos dispatchers of a material under the given object name: ElementInit,
 * SideInit, NeighborInit, ElementCompute, SideCompute and NeighborCompute. For the three init
 * operations it also records whether the class overrides initQpStatefulProperties() by comparing
 * the member function pointer with the default reported by the class.
 *
 * The registration runs in the initializer of a uniquely named static variable. An immediately
 * invoked lambda is used instead of a named helper function so that the same class can be
 * registered more than once in a translation unit (e.g. plain and aliased) without a
 * redefinition error.
 */
#define callRegisterKokkosMaterialFunction(classname, objectname)                                  \
  static char combineNames(kokkos_dispatcher_material_##classname, __COUNTER__) = []() -> char     \
  {                                                                                                \
    using namespace Moose::Kokkos;                                                                 \
                                                                                                   \
    DispatcherRegistry::addDispatcher<classname::ElementInit, classname>(objectname);              \
    DispatcherRegistry::addDispatcher<classname::SideInit, classname>(objectname);                 \
    DispatcherRegistry::addDispatcher<classname::NeighborInit, classname>(objectname);             \
    DispatcherRegistry::addDispatcher<classname::ElementCompute, classname>(objectname);           \
    DispatcherRegistry::addDispatcher<classname::SideCompute, classname>(objectname);              \
    DispatcherRegistry::addDispatcher<classname::NeighborCompute, classname>(objectname);          \
    DispatcherRegistry::hasUserMethod<classname::ElementInit>(                                     \
        objectname,                                                                                \
        &classname::initQpStatefulProperties<classname> !=                                         \
            classname::defaultInitStateful<classname>());                                          \
    DispatcherRegistry::hasUserMethod<classname::SideInit>(                                        \
        objectname,                                                                                \
        &classname::initQpStatefulProperties<classname> !=                                         \
            classname::defaultInitStateful<classname>());                                          \
    DispatcherRegistry::hasUserMethod<classname::NeighborInit>(                                    \
        objectname,                                                                                \
        &classname::initQpStatefulProperties<classname> !=                                         \
            classname::defaultInitStateful<classname>());                                          \
                                                                                                   \
    return 0;                                                                                      \
  }()

/// Registers a Kokkos material with MOOSE and with the dispatcher registry under its class name
#define registerKokkosMaterial(app, classname)                                                     \
  registerMooseObject(app, classname);                                                             \
  callRegisterKokkosMaterialFunction(classname, #classname)

/// Registers a Kokkos material with MOOSE and with the dispatcher registry under an alias
#define registerKokkosMaterialAliased(app, classname, alias)                                       \
  registerMooseObjectAliased(app, classname, alias);                                               \
  callRegisterKokkosMaterialFunction(classname, alias)
459 :
460 : // AuxKernel
461 :
/**
 * Registers the Kokkos dispatchers of an auxiliary kernel under the given object name:
 * ElementLoop and NodeLoop.
 *
 * The registration runs in the initializer of a uniquely named static variable. An immediately
 * invoked lambda is used instead of a named helper function so that the same class can be
 * registered more than once in a translation unit (e.g. plain and aliased) without a
 * redefinition error.
 */
#define callRegisterKokkosAuxKernelFunction(classname, objectname)                                 \
  static char combineNames(kokkos_dispatcher_auxkernel_##classname, __COUNTER__) = []() -> char    \
  {                                                                                                \
    using namespace Moose::Kokkos;                                                                 \
                                                                                                   \
    DispatcherRegistry::addDispatcher<classname::ElementLoop, classname>(objectname);              \
    DispatcherRegistry::addDispatcher<classname::NodeLoop, classname>(objectname);                 \
                                                                                                   \
    return 0;                                                                                      \
  }()

/// Registers a Kokkos auxiliary kernel with MOOSE and with the dispatcher registry under its class name
#define registerKokkosAuxKernel(app, classname)                                                    \
  registerMooseObject(app, classname);                                                             \
  callRegisterKokkosAuxKernelFunction(classname, #classname)

/// Registers a Kokkos auxiliary kernel with MOOSE and with the dispatcher registry under an alias
#define registerKokkosAuxKernelAliased(app, classname, alias)                                      \
  registerMooseObjectAliased(app, classname, alias);                                               \
  callRegisterKokkosAuxKernelFunction(classname, alias)
483 :
484 : // UserObject
485 :
/**
 * Registers the Kokkos dispatcher (DefaultLoop) and reducer (ReducerLoop) of a user object under
 * the given object name. It also records whether the class overrides execute() and reduce() by
 * comparing the member function pointers with the defaults reported by the class.
 *
 * The registration runs in the initializer of a uniquely named static variable. An immediately
 * invoked lambda is used instead of a named helper function so that the same class can be
 * registered more than once in a translation unit (e.g. plain and aliased) without a
 * redefinition error.
 */
#define callRegisterKokkosUserObjectFunction(classname, objectname)                                \
  static char combineNames(kokkos_dispatcher_userobject_##classname, __COUNTER__) = []() -> char   \
  {                                                                                                \
    using namespace Moose::Kokkos;                                                                 \
                                                                                                   \
    DispatcherRegistry::addDispatcher<classname::DefaultLoop, classname>(objectname);              \
    DispatcherRegistry::addReducer<classname::ReducerLoop, classname>(objectname);                 \
    DispatcherRegistry::hasUserMethod<classname::DefaultLoop>(                                     \
        objectname, &classname::execute<classname> != classname::defaultExecute<classname>());     \
    DispatcherRegistry::hasUserMethod<classname::ReducerLoop>(                                     \
        objectname, &classname::reduce<classname> != classname::defaultReduce<classname>());       \
                                                                                                   \
    return 0;                                                                                      \
  }()

/// Registers a Kokkos user object with MOOSE and with the dispatcher registry under its class name
#define registerKokkosUserObject(app, classname)                                                   \
  registerMooseObject(app, classname);                                                             \
  callRegisterKokkosUserObjectFunction(classname, #classname)

/// Registers a Kokkos user object with MOOSE and with the dispatcher registry under an alias
#define registerKokkosUserObjectAliased(app, classname, alias)                                     \
  registerMooseObjectAliased(app, classname, alias);                                               \
  callRegisterKokkosUserObjectFunction(classname, alias)
511 :
512 : // User-defined parallel operation registry
513 :
/**
 * Registers the dispatcher of an additional, user-defined parallel operation of a class. The
 * dispatcher is stored under the class name (stringified), with classname::operation as the
 * function tag.
 *
 * The registration runs in the initializer of a uniquely named static variable. An immediately
 * invoked lambda is used instead of a named helper function so that repeating the registration
 * in one translation unit does not cause a redefinition error.
 */
#define registerKokkosAdditionalOperation(classname, operation)                                    \
  static char combineNames(kokkos_##classname##_##operation, __COUNTER__) = []() -> char           \
  {                                                                                                \
    using namespace Moose::Kokkos;                                                                 \
                                                                                                   \
    DispatcherRegistry::addDispatcher<classname::operation, classname>(#classname);                \
                                                                                                   \
    return 0;                                                                                      \
  }()
|