Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://www.mooseframework.org 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #pragma once 11 : 12 : #include "KokkosArray.h" 13 : 14 : #include "MooseTypes.h" 15 : 16 : namespace Moose::Kokkos 17 : { 18 : 19 : /** 20 : * The Kokkos object that can hold the reference of a variable. 21 : * Reference of a host variable is not accessible on device, so if there is a variable that should 22 : * be stored as a reference but still needs to be accessed on device, define an instance of this 23 : * class and construct it with the reference of the variable. 24 : * This class holds the device copy as well as the host reference of the variable. 25 : * The copy constructor of this object that copies the host reference to the device copy is invoked 26 : * whenever a Kokkos functor containing this object is dispatched to device, so it is guaranteed 27 : * that the device copy is always up-to-date with the host reference when it is used on device 28 : * Therefore, the variable must be copy constructible. 29 : */ 30 : template <typename T> 31 : class ReferenceWrapper 32 : { 33 : public: 34 : /** 35 : * Constructor 36 : * @param reference The writeable reference of the variable to store 37 : */ 38 59578 : ReferenceWrapper(T & reference) : _reference(reference), _copy(reference) {} 39 : /** 40 : * Copy constructor 41 : */ 42 2792534 : ReferenceWrapper(const ReferenceWrapper<T> & object) 43 1589814 : : _reference(object._reference), _copy(object._reference) 44 : { 45 2792534 : } 46 : 47 : #ifdef MOOSE_KOKKOS_SCOPE 48 : /** 49 : * Get the const reference of the stored variable 50 : * @returns The const reference of the stored variable depending on the architecture this function 51 : * is being called on 52 : */ 53 57499908 : KOKKOS_FUNCTION operator const T &() const 54 : { 55 57499908 : KOKKOS_IF_ON_HOST(return _reference;) 56 : 57 57499908 : return _copy; 58 : } 59 : /** 60 : * Get the const reference of the stored variable 61 : * @returns The const reference of the stored variable depending on the architecture this function 62 : * is being called on 63 : */ 64 : KOKKOS_FUNCTION const T & operator*() const 65 : { 66 : KOKKOS_IF_ON_HOST(return _reference;) 67 : 68 : return _copy; 69 : } 70 : /** 71 : * Get the const pointer to the stored variable 72 : * @returns The const pointer to the stored variable depending on the architecture this function 73 : * is being called on 74 : */ 75 : KOKKOS_FUNCTION const T * operator->() const 76 : { 77 : KOKKOS_IF_ON_HOST(return &_reference;) 78 : 79 : return &_copy; 80 : } 81 : /** 82 : * Forward arguments to the stored variable's const operator() depending on the architecture this 83 : * function is being called on 84 : * @param args The variadic arguments to be forwarded 85 : */ 86 : template <typename... Args> 87 1985600 : KOKKOS_FUNCTION auto operator()(Args &&... args) const -> decltype(auto) 88 : { 89 1985600 : KOKKOS_IF_ON_HOST(return _reference(std::forward<Args>(args)...);) 90 : 91 1985600 : return _copy(std::forward<Args>(args)...); 92 : } 93 : #else 94 : /** 95 : * Get the const reference of the stored host reference 96 : * @returns The const reference of the stored host reference 97 : */ 98 : operator const T &() const { return _reference; } 99 : /** 100 : * Get the const reference of the stored host reference 101 : * @returns The const reference of the stored host reference 102 : */ 103 : const T & operator*() const { return _reference; } 104 : /** 105 : * Get the const pointer of the stored host reference 106 : * @returns The const pointer to the stored host reference 107 : */ 108 : const T * operator->() const { return &_reference; } 109 : /** 110 : * Forward arguments to the stored host reference's const operator() 111 : * @param args The variadic arguments to be forwarded 112 : */ 113 : template <typename... Args> 114 : auto operator()(Args &&... args) const -> decltype(auto) 115 : { 116 : return _reference(std::forward<Args>(args)...); 117 : } 118 : #endif 119 : /** 120 : * Get the writeable reference of the stored host reference 121 : * @returns The writeable reference of the stored host reference 122 : */ 123 752 : operator T &() { return _reference; } 124 : /** 125 : * Get the writeable reference of the stored host reference 126 : * @returns The writeable reference of the stored host reference 127 : */ 128 34 : T & operator*() { return _reference; } 129 : /** 130 : * Get the writeable pointer of the stored host reference 131 : * @returns The writeable pointer to the stored host reference 132 : */ 133 150 : T * operator->() { return &_reference; } 134 : /** 135 : * Forward arguments to the stored host reference's operator() 136 : * @param args The variadic arguments to be forwarded 137 : */ 138 : template <typename... Args> 139 300 : auto operator()(Args &&... args) -> decltype(auto) 140 : { 141 300 : return _reference(std::forward<Args>(args)...); 142 : } 143 : 144 : protected: 145 : /** 146 : * Writeable host reference of the variable 147 : */ 148 : T & _reference; 149 : /** 150 : * Device copy of the variable 151 : */ 152 : const T _copy; 153 : }; 154 : 155 : template <typename T> 156 : struct ArrayDeepCopy<ReferenceWrapper<T>> 157 : { 158 : static constexpr bool value = true; 159 : }; 160 : 161 : } // namespace Moose::Kokkos