Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://www.mooseframework.org 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #pragma once 11 : 12 : #ifdef MOOSE_KOKKOS_SCOPE 13 : #include "KokkosHeader.h" 14 : #endif 15 : 16 : #include "MooseTypes.h" 17 : 18 : namespace Moose 19 : { 20 : namespace Kokkos 21 : { 22 : 23 : /** 24 : * The Kokkos object that can hold the reference of a variable. 25 : * Reference of a host variable is not accessible on device, so if there is a variable that should 26 : * be stored as a reference but still needs to be accessed on device, define an instance of this 27 : * class and construct it with the reference of the variable. 28 : * This class holds the device copy as well as the host reference of the variable. 29 : * The copy constructor of this object that copies the host reference to the device copy is invoked 30 : * whenever a Kokkos functor containing this object is dispatched to device, so it is guaranteed 31 : * that the device copy is always up-to-date with the host reference when it is used on device 32 : * Therefore, the variable must be copy constructible. 33 : */ 34 : template <typename T> 35 : class ReferenceWrapper 36 : { 37 : public: 38 : /** 39 : * Constructor 40 : * @param reference The writeable reference of the variable to store 41 : */ 42 21553 : ReferenceWrapper(T & reference) : _reference(reference), _copy(reference) {} 43 : /** 44 : * Copy constructor 45 : */ 46 1084321 : ReferenceWrapper(const ReferenceWrapper<T> & object) 47 883040 : : _reference(object._reference), _copy(object._reference) 48 : { 49 1084321 : } 50 : 51 : #ifdef MOOSE_KOKKOS_SCOPE 52 : /** 53 : * Get the const reference of the stored variable 54 : * @returns The const reference of the stored variable depending on the architecture this function 55 : * is being called on 56 : */ 57 26532990 : KOKKOS_FUNCTION operator const T &() const 58 : { 59 26532990 : KOKKOS_IF_ON_HOST(return _reference;) 60 : 61 26532990 : return _copy; 62 : } 63 : /** 64 : * Get the const reference of the stored variable 65 : * @returns The const reference of the stored variable depending on the architecture this function 66 : * is being called on 67 : */ 68 : KOKKOS_FUNCTION const T & operator*() const 69 : { 70 : KOKKOS_IF_ON_HOST(return _reference;) 71 : 72 : return _copy; 73 : } 74 : /** 75 : * Get the const pointer to the stored variable 76 : * @returns The const pointer to the stored variable depending on the architecture this function 77 : * is being called on 78 : */ 79 : KOKKOS_FUNCTION const T * operator->() const 80 : { 81 : KOKKOS_IF_ON_HOST(return &_reference;) 82 : 83 : return &_copy; 84 : } 85 : /** 86 : * Forward arguments to the stored variable's const operator() depending on the architecture this 87 : * function is being called on 88 : * @param args The variadic arguments to be forwarded 89 : */ 90 : template <typename... Args> 91 1985600 : KOKKOS_FUNCTION auto operator()(Args &&... args) const -> decltype(auto) 92 : { 93 1985600 : KOKKOS_IF_ON_HOST(return _reference(std::forward<Args>(args)...);) 94 : 95 1985600 : return _copy(std::forward<Args>(args)...); 96 : } 97 : #else 98 : /** 99 : * Get the const reference of the stored host reference 100 : * @returns The const reference of the stored host reference 101 : */ 102 : operator const T &() const { return _reference; } 103 : /** 104 : * Get the const reference of the stored host reference 105 : * @returns The const reference of the stored host reference 106 : */ 107 : const T & operator*() const { return _reference; } 108 : /** 109 : * Get the const pointer of the stored host reference 110 : * @returns The const pointer to the stored host reference 111 : */ 112 : const T * operator->() const { return &_reference; } 113 : /** 114 : * Forward arguments to the stored host reference's const operator() 115 : * @param args The variadic arguments to be forwarded 116 : */ 117 : template <typename... Args> 118 : auto operator()(Args &&... args) const -> decltype(auto) 119 : { 120 : return _reference(std::forward<Args>(args)...); 121 : } 122 : #endif 123 : /** 124 : * Get the writeable reference of the stored host reference 125 : * @returns The writeable reference of the stored host reference 126 : */ 127 502 : operator T &() { return _reference; } 128 : /** 129 : * Get the writeable reference of the stored host reference 130 : * @returns The writeable reference of the stored host reference 131 : */ 132 24 : T & operator*() { return _reference; } 133 : /** 134 : * Get the writeable pointer of the stored host reference 135 : * @returns The writeable pointer to the stored host reference 136 : */ 137 100 : T * operator->() { return &_reference; } 138 : /** 139 : * Forward arguments to the stored host reference's operator() 140 : * @param args The variadic arguments to be forwarded 141 : */ 142 : template <typename... Args> 143 200 : auto operator()(Args &&... args) -> decltype(auto) 144 : { 145 200 : return _reference(std::forward<Args>(args)...); 146 : } 147 : 148 : protected: 149 : /** 150 : * Writeable host reference of the variable 151 : */ 152 : T & _reference; 153 : /** 154 : * Device copy of the variable 155 : */ 156 : const T _copy; 157 : }; 158 : 159 : } // namespace Kokkos 160 : } // namespace Moose