Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #pragma once 11 : 12 : #include "KokkosTypes.h" 13 : 14 : namespace Moose::Kokkos 15 : { 16 : 17 : template <typename Object> 18 : class FunctionWrapperHost; 19 : 20 : /** 21 : * Base class for device function wrapper 22 : */ 23 : class FunctionWrapperDeviceBase 24 : { 25 : public: 26 : /** 27 : * Constructor 28 : */ 29 118 : KOKKOS_FUNCTION FunctionWrapperDeviceBase() {} 30 : /** 31 : * Virtual destructor 32 : */ 33 0 : KOKKOS_FUNCTION virtual ~FunctionWrapperDeviceBase() {} 34 : 35 : /** 36 : * Virtual shims that calls the corresponding methods of the actual stored function 37 : */ 38 : ///@{ 39 : KOKKOS_FUNCTION virtual Real value(Real t, Real3 p) const = 0; 40 : KOKKOS_FUNCTION virtual Real3 vectorValue(Real t, Real3 p) const = 0; 41 : KOKKOS_FUNCTION virtual Real3 gradient(Real t, Real3 p) const = 0; 42 : KOKKOS_FUNCTION virtual Real3 curl(Real t, Real3 p) const = 0; 43 : KOKKOS_FUNCTION virtual Real div(Real t, Real3 p) const = 0; 44 : KOKKOS_FUNCTION virtual Real timeDerivative(Real t, Real3 p) const = 0; 45 : KOKKOS_FUNCTION virtual Real timeIntegral(Real t1, Real t2, Real3 p) const = 0; 46 : KOKKOS_FUNCTION virtual Real integral() const = 0; 47 : KOKKOS_FUNCTION virtual Real average() const = 0; 48 : ///@} 49 : }; 50 : 51 : /** 52 : * Device function wrapper class that provides polymorphic interfaces for a function. The function 53 : * itself is a static object and does not have any virtual method. Instead, the device wrapper 54 : * defines the virtual shims and forwards the calls to the static methods of the stored function. 55 : * @tparam Object The function class type 56 : */ 57 : template <typename Object> 58 : class FunctionWrapperDevice : public FunctionWrapperDeviceBase 59 : { 60 : friend class FunctionWrapperHost<Object>; 61 : 62 : public: 63 : /** 64 : * Constructor 65 : */ 66 118 : KOKKOS_FUNCTION FunctionWrapperDevice() {} 67 : 68 13123379 : KOKKOS_FUNCTION Real value(Real t, Real3 p) const override final 69 : { 70 13123379 : return _function->value(t, p); 71 : } 72 0 : KOKKOS_FUNCTION Real3 vectorValue(Real t, Real3 p) const override final 73 : { 74 0 : return _function->vectorValue(t, p); 75 : } 76 0 : KOKKOS_FUNCTION Real3 gradient(Real t, Real3 p) const override final 77 : { 78 0 : return _function->gradient(t, p); 79 : } 80 0 : KOKKOS_FUNCTION Real3 curl(Real t, Real3 p) const override final { return _function->curl(t, p); } 81 0 : KOKKOS_FUNCTION Real div(Real t, Real3 p) const override final { return _function->div(t, p); } 82 0 : KOKKOS_FUNCTION Real timeDerivative(Real t, Real3 p) const override final 83 : { 84 0 : return _function->timeDerivative(t, p); 85 : } 86 0 : KOKKOS_FUNCTION Real timeIntegral(Real t1, Real t2, Real3 p) const override final 87 : { 88 0 : return _function->timeIntegral(t1, t2, p); 89 : } 90 0 : KOKKOS_FUNCTION Real integral() const override final { return _function->integral(); } 91 0 : KOKKOS_FUNCTION Real average() const override final { return _function->average(); } 92 : 93 : protected: 94 : /** 95 : * Pointer to the function on device 96 : */ 97 : Object * _function = nullptr; 98 : }; 99 : 100 : /** 101 : * Base class for host function wrapper 102 : */ 103 : class FunctionWrapperHostBase 104 : { 105 : public: 106 : /** 107 : * Virtual destructor 108 : */ 109 119 : virtual ~FunctionWrapperHostBase() {} 110 : 111 : /** 112 : * Allocate device function and wrapper 113 : * @returns The pointer to the device wrapper 114 : */ 115 : virtual FunctionWrapperDeviceBase * allocate() = 0; 116 : /** 117 : * Copy function to device 118 : */ 119 : virtual void copyFunction() = 0; 120 : /** 121 : * Free host and device copies of function 122 : */ 123 : virtual void freeFunction() = 0; 124 : }; 125 : 126 : /** 127 : * Host function wrapper class that allocates a function on device and creates its device wrapper. 128 : * This class holds the actual device instance of the function and manages its allocation and 129 : * deallocation, and the device wrapper simply keeps a pointer to it. 130 : * @tparam Object The function class type 131 : */ 132 : template <typename Object> 133 : class FunctionWrapperHost : public FunctionWrapperHostBase 134 : { 135 : public: 136 : /** 137 : * Constructor 138 : * @param function Pointer to the function 139 : */ 140 119 : FunctionWrapperHost(const void * function) 141 118 : : _function_host(*static_cast<const Object *>(function)) 142 : { 143 119 : } 144 : /** 145 : * Destructor 146 : */ 147 : ~FunctionWrapperHost(); 148 : 149 : FunctionWrapperDeviceBase * allocate() override final; 150 : void copyFunction() override final; 151 : void freeFunction() override final; 152 : 153 : private: 154 : /** 155 : * Reference of the function on host 156 : */ 157 : const Object & _function_host; 158 : /** 159 : * Copy of the function on host 160 : */ 161 : std::unique_ptr<Object> _function_copy; 162 : /** 163 : * Copy of the function on device 164 : */ 165 : Object * _function_device = nullptr; 166 : }; 167 : 168 : template <typename Object> 169 : FunctionWrapperDeviceBase * 170 119 : FunctionWrapperHost<Object>::allocate() 171 : { 172 : // Allocate storage for device wrapper on device 173 1 : auto wrapper_device = static_cast<FunctionWrapperDevice<Object> *>( 174 118 : ::Kokkos::kokkos_malloc<ExecSpace::memory_space>(sizeof(FunctionWrapperDevice<Object>))); 175 : 176 : // Allocate device wrapper on device using placement new to populate vtable with device pointers 177 119 : ::Kokkos::parallel_for( 178 236 : 1, KOKKOS_LAMBDA(const int) { new (wrapper_device) FunctionWrapperDevice<Object>(); }); 179 : 180 : // Allocate storage for function on device 181 119 : _function_device = 182 118 : static_cast<Object *>(::Kokkos::kokkos_malloc<ExecSpace::memory_space>(sizeof(Object))); 183 : 184 : // Let device wrapper point to the copy 185 237 : ::Kokkos::Impl::DeepCopy<MemSpace, ::Kokkos::HostSpace>( 186 118 : &(wrapper_device->_function), &_function_device, sizeof(Object *)); 187 : 188 119 : return wrapper_device; 189 : } 190 : 191 : template <typename Object> 192 : void 193 20949 : FunctionWrapperHost<Object>::copyFunction() 194 : { 195 : // Make a copy of function on host to trigger copy constructor 196 20949 : _function_copy = std::make_unique<Object>(_function_host); 197 : 198 : // Copy function to device 199 41894 : ::Kokkos::Impl::DeepCopy<MemSpace, ::Kokkos::HostSpace>( 200 20945 : _function_device, _function_copy.get(), sizeof(Object)); 201 20949 : } 202 : 203 : template <typename Object> 204 : void 205 21068 : FunctionWrapperHost<Object>::freeFunction() 206 : { 207 21068 : _function_copy.reset(); 208 21068 : } 209 : 210 : template <typename Object> 211 238 : FunctionWrapperHost<Object>::~FunctionWrapperHost() 212 : { 213 119 : ::Kokkos::kokkos_free<ExecSpace::memory_space>(_function_device); 214 238 : } 215 : 216 : } // namespace Moose::Kokkos