Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://mooseframework.inl.gov 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #pragma once 11 : 12 : #include "KokkosTypes.h" 13 : 14 : namespace Moose 15 : { 16 : namespace Kokkos 17 : { 18 : 19 : template <typename Object> 20 : class FunctionWrapperHost; 21 : 22 : /** 23 : * Base class for device function wrapper 24 : */ 25 : class FunctionWrapperDeviceBase 26 : { 27 : public: 28 : /** 29 : * Constructor 30 : */ 31 91 : KOKKOS_FUNCTION FunctionWrapperDeviceBase() {} 32 : /** 33 : * Virtual destructor 34 : */ 35 0 : KOKKOS_FUNCTION virtual ~FunctionWrapperDeviceBase() {} 36 : 37 : /** 38 : * Virtual shims that calls the corresponding methods of the actual stored function 39 : */ 40 : ///{@ 41 : KOKKOS_FUNCTION virtual Real value(Real t, Real3 p) const = 0; 42 : KOKKOS_FUNCTION virtual Real3 vectorValue(Real t, Real3 p) const = 0; 43 : KOKKOS_FUNCTION virtual Real3 gradient(Real t, Real3 p) const = 0; 44 : KOKKOS_FUNCTION virtual Real3 curl(Real t, Real3 p) const = 0; 45 : KOKKOS_FUNCTION virtual Real div(Real t, Real3 p) const = 0; 46 : KOKKOS_FUNCTION virtual Real timeDerivative(Real t, Real3 p) const = 0; 47 : KOKKOS_FUNCTION virtual Real timeIntegral(Real t1, Real t2, Real3 p) const = 0; 48 : KOKKOS_FUNCTION virtual Real integral() const = 0; 49 : KOKKOS_FUNCTION virtual Real average() const = 0; 50 : ///@} 51 : }; 52 : 53 : /** 54 : * Device function wrapper class that provides polymorphic interfaces for a function. The function 55 : * itself is a static object and does not have any virtual method. Instead, the device wrapper 56 : * defines the virtual shims and forwards the calls to the static methods of the stored function. 57 : * @tparam Object The function class type 58 : */ 59 : template <typename Object> 60 : class FunctionWrapperDevice : public FunctionWrapperDeviceBase 61 : { 62 : friend class FunctionWrapperHost<Object>; 63 : 64 : public: 65 : /** 66 : * Constructor 67 : */ 68 91 : KOKKOS_FUNCTION FunctionWrapperDevice() {} 69 : 70 12656539 : KOKKOS_FUNCTION Real value(Real t, Real3 p) const override final 71 : { 72 12656539 : return _function->value(t, p); 73 : } 74 0 : KOKKOS_FUNCTION Real3 vectorValue(Real t, Real3 p) const override final 75 : { 76 0 : return _function->vectorValue(t, p); 77 : } 78 0 : KOKKOS_FUNCTION Real3 gradient(Real t, Real3 p) const override final 79 : { 80 0 : return _function->gradient(t, p); 81 : } 82 0 : KOKKOS_FUNCTION Real3 curl(Real t, Real3 p) const override final { return _function->curl(t, p); } 83 0 : KOKKOS_FUNCTION Real div(Real t, Real3 p) const override final { return _function->div(t, p); } 84 0 : KOKKOS_FUNCTION Real timeDerivative(Real t, Real3 p) const override final 85 : { 86 0 : return _function->timeDerivative(t, p); 87 : } 88 0 : KOKKOS_FUNCTION Real timeIntegral(Real t1, Real t2, Real3 p) const override final 89 : { 90 0 : return _function->timeIntegral(t1, t2, p); 91 : } 92 0 : KOKKOS_FUNCTION Real integral() const override final { return _function->integral(); } 93 0 : KOKKOS_FUNCTION Real average() const override final { return _function->average(); } 94 : 95 : protected: 96 : /** 97 : * Pointer to the function on device 98 : */ 99 : Object * _function = nullptr; 100 : }; 101 : 102 : /** 103 : * Base class for host function wrapper 104 : */ 105 : class FunctionWrapperHostBase 106 : { 107 : public: 108 : /** 109 : * Virtual destructor 110 : */ 111 92 : virtual ~FunctionWrapperHostBase() {} 112 : 113 : /** 114 : * Allocate device function and wrapper 115 : * @returns The pointer to the device wrapper 116 : */ 117 : virtual FunctionWrapperDeviceBase * allocate() = 0; 118 : /** 119 : * Copy function to device 120 : */ 121 : virtual void copyFunction() = 0; 122 : /** 123 : * Free host and device copies of function 124 : */ 125 : virtual void freeFunction() = 0; 126 : }; 127 : 128 : /** 129 : * Host function wrapper class that allocates a function on device and creates its device wrapper. 130 : * This class holds the actual device instance of the function and manages its allocation and 131 : * deallocation, and the device wrapper simply keeps a pointer to it. 132 : * @tparam Object The function class type 133 : */ 134 : template <typename Object> 135 : class FunctionWrapperHost : public FunctionWrapperHostBase 136 : { 137 : public: 138 : /** 139 : * Constructor 140 : * @param function Pointer to the function 141 : */ 142 92 : FunctionWrapperHost(const void * function) 143 91 : : _function_host(*static_cast<const Object *>(function)) 144 : { 145 92 : } 146 : /** 147 : * Destructor 148 : */ 149 : ~FunctionWrapperHost(); 150 : 151 : FunctionWrapperDeviceBase * allocate() override final; 152 : void copyFunction() override final; 153 : void freeFunction() override final; 154 : 155 : private: 156 : /** 157 : * Reference of the function on host 158 : */ 159 : const Object & _function_host; 160 : /** 161 : * Copy of the function on host 162 : */ 163 : std::unique_ptr<Object> _function_copy; 164 : /** 165 : * Copy of the function on device 166 : */ 167 : Object * _function_device = nullptr; 168 : }; 169 : 170 : template <typename Object> 171 : FunctionWrapperDeviceBase * 172 92 : FunctionWrapperHost<Object>::allocate() 173 : { 174 : // Allocate storage for device wrapper on device 175 1 : auto wrapper_device = static_cast<FunctionWrapperDevice<Object> *>( 176 91 : ::Kokkos::kokkos_malloc<ExecSpace::memory_space>(sizeof(FunctionWrapperDevice<Object>))); 177 : 178 : // Allocate device wrapper on device using placement new to populate vtable with device pointers 179 92 : ::Kokkos::parallel_for( 180 182 : 1, KOKKOS_LAMBDA(const int) { new (wrapper_device) FunctionWrapperDevice<Object>(); }); 181 : 182 : // Allocate storage for function on device 183 92 : _function_device = 184 91 : static_cast<Object *>(::Kokkos::kokkos_malloc<ExecSpace::memory_space>(sizeof(Object))); 185 : 186 : // Let device wrapper point to the copy 187 183 : ::Kokkos::Impl::DeepCopy<MemSpace, ::Kokkos::HostSpace>( 188 91 : &(wrapper_device->_function), &_function_device, sizeof(Object *)); 189 : 190 92 : return wrapper_device; 191 : } 192 : 193 : template <typename Object> 194 : void 195 11943 : FunctionWrapperHost<Object>::copyFunction() 196 : { 197 : // Make a copy of function on host to trigger copy constructor 198 11943 : _function_copy = std::make_unique<Object>(_function_host); 199 : 200 : // Copy function to device 201 23882 : ::Kokkos::Impl::DeepCopy<MemSpace, ::Kokkos::HostSpace>( 202 11939 : _function_device, _function_copy.get(), sizeof(Object)); 203 11943 : } 204 : 205 : template <typename Object> 206 : void 207 12035 : FunctionWrapperHost<Object>::freeFunction() 208 : { 209 12035 : _function_copy.reset(); 210 12035 : } 211 : 212 : template <typename Object> 213 184 : FunctionWrapperHost<Object>::~FunctionWrapperHost() 214 : { 215 92 : ::Kokkos::kokkos_free<ExecSpace::memory_space>(_function_device); 216 184 : } 217 : 218 : } // namespace Kokkos 219 : } // namespace Moose