LCOV - code coverage report
Current view: top level - include/kokkos/functions - KokkosFunctionWrapper.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: #31782 (615931) with base 7edd10 Lines: 29 42 69.0 %
Date: 2025-11-11 23:21:15 Functions: 29 56 51.8 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : #include "KokkosTypes.h"
      13             : 
      14             : namespace Moose
      15             : {
      16             : namespace Kokkos
      17             : {
      18             : 
      19             : template <typename Object>
      20             : class FunctionWrapperHost; // Forward declaration: FunctionWrapperDevice names this class as a friend before it is defined
      21             : 
      22             : /**
      23             :  * Abstract base class for device function wrappers: declares the pure virtual interface that FunctionWrapperDevice<Object> implements
      24             :  */
      25             : class FunctionWrapperDeviceBase
      26             : {
      27             : public:
      28             :   /**
      29             :    * Default constructor; the body is empty
      30             :    */
      31          91 :   KOKKOS_FUNCTION FunctionWrapperDeviceBase() {}
      32             :   /**
      33             :    * Virtual destructor; the body is empty
      34             :    */
      35           0 :   KOKKOS_FUNCTION virtual ~FunctionWrapperDeviceBase() {}
      36             : 
      37             :   /**
      38             :    * Pure virtual shims; each override calls the corresponding method of the actual stored function. NOTE(review): the names suggest t is a time and p a point -- confirm against the function classes
      39             :    */
      40             :   ///@{
      41             :   KOKKOS_FUNCTION virtual Real value(Real t, Real3 p) const = 0;
      42             :   KOKKOS_FUNCTION virtual Real3 vectorValue(Real t, Real3 p) const = 0;
      43             :   KOKKOS_FUNCTION virtual Real3 gradient(Real t, Real3 p) const = 0;
      44             :   KOKKOS_FUNCTION virtual Real3 curl(Real t, Real3 p) const = 0;
      45             :   KOKKOS_FUNCTION virtual Real div(Real t, Real3 p) const = 0;
      46             :   KOKKOS_FUNCTION virtual Real timeDerivative(Real t, Real3 p) const = 0;
      47             :   KOKKOS_FUNCTION virtual Real timeIntegral(Real t1, Real t2, Real3 p) const = 0;
      48             :   KOKKOS_FUNCTION virtual Real integral() const = 0;
      49             :   KOKKOS_FUNCTION virtual Real average() const = 0;
      50             :   ///@}
      51             : };
      52             : 
      53             : /**
      54             :  * Device function wrapper class that provides polymorphic interfaces for a function. The function
      55             :  * itself is used through its concrete type Object and is not expected to have virtual methods. Instead, the device wrapper
      56             :  * overrides the virtual shims and forwards each call to the same-named method of the stored function through _function, without a null check.
      57             :  * @tparam Object The function class type
      58             :  */
      59             : template <typename Object>
      60             : class FunctionWrapperDevice : public FunctionWrapperDeviceBase
      61             : {
      62             :   friend class FunctionWrapperHost<Object>;
      63             : 
      64             : public:
      65             :   /**
      66             :    * Default constructor; leaves _function at its nullptr default until FunctionWrapperHost<Object>::allocate() writes it
      67             :    */
      68          91 :   KOKKOS_FUNCTION FunctionWrapperDevice() {}
      69             : 
      70    12656539 :   KOKKOS_FUNCTION Real value(Real t, Real3 p) const override final
      71             :   {
      72    12656539 :     return _function->value(t, p);
      73             :   }
      74           0 :   KOKKOS_FUNCTION Real3 vectorValue(Real t, Real3 p) const override final
      75             :   {
      76           0 :     return _function->vectorValue(t, p);
      77             :   }
      78           0 :   KOKKOS_FUNCTION Real3 gradient(Real t, Real3 p) const override final
      79             :   {
      80           0 :     return _function->gradient(t, p);
      81             :   }
      82           0 :   KOKKOS_FUNCTION Real3 curl(Real t, Real3 p) const override final { return _function->curl(t, p); }
      83           0 :   KOKKOS_FUNCTION Real div(Real t, Real3 p) const override final { return _function->div(t, p); }
      84           0 :   KOKKOS_FUNCTION Real timeDerivative(Real t, Real3 p) const override final
      85             :   {
      86           0 :     return _function->timeDerivative(t, p);
      87             :   }
      88           0 :   KOKKOS_FUNCTION Real timeIntegral(Real t1, Real t2, Real3 p) const override final
      89             :   {
      90           0 :     return _function->timeIntegral(t1, t2, p);
      91             :   }
      92           0 :   KOKKOS_FUNCTION Real integral() const override final { return _function->integral(); }
      93           0 :   KOKKOS_FUNCTION Real average() const override final { return _function->average(); }
      94             : 
      95             : protected:
      96             :   /**
      97             :    * Pointer to the function on device. Written by the friend class FunctionWrapperHost<Object> in allocate(); never freed by this class
      98             :    */
      99             :   Object * _function = nullptr;
     100             : };
     101             : 
     102             : /**
     103             :  * Abstract base class for host function wrappers: declares the allocate / copy / free operations that FunctionWrapperHost<Object> implements
     104             :  */
     105             : class FunctionWrapperHostBase
     106             : {
     107             : public:
     108             :   /**
     109             :    * Virtual destructor; the body is empty
     110             :    */
     111          92 :   virtual ~FunctionWrapperHostBase() {}
     112             : 
     113             :   /**
     114             :    * Allocate device function and wrapper. NOTE(review): nothing in this file frees the returned wrapper -- confirm who releases it
     115             :    * @returns The pointer to the device wrapper
     116             :    */
     117             :   virtual FunctionWrapperDeviceBase * allocate() = 0;
     118             :   /**
     119             :    * Copy function to device (FunctionWrapperHost makes a fresh host copy first and transfers that copy)
     120             :    */
     121             :   virtual void copyFunction() = 0;
     122             :   /**
     123             :    * Free the copy of the function made by copyFunction(). NOTE(review): the FunctionWrapperHost implementation resets only the host copy; the device storage is freed in its destructor
     124             :    */
     125             :   virtual void freeFunction() = 0;
     126             : };
     127             : 
     128             : /**
     129             :  * Host function wrapper class that allocates a function on device and creates its device wrapper.
     130             :  * This class holds the actual device instance of the function and manages its allocation and
     131             :  * deallocation, and the device wrapper simply keeps a pointer to it. It also stores a reference to the host function, so that object must outlive this wrapper.
     132             :  * @tparam Object The function class type
     133             :  */
     134             : template <typename Object>
     135             : class FunctionWrapperHost : public FunctionWrapperHostBase
     136             : {
     137             : public:
     138             :   /**
     139             :    * Constructor. The pointer is cast to const Object * and dereferenced to bind _function_host, so it must be non-null and point to an Object
     140             :    * @param function Type-erased pointer to the function on host
     141             :    */
     142          92 :   FunctionWrapperHost(const void * function)
     143          91 :     : _function_host(*static_cast<const Object *>(function))
     144             :   {
     145          92 :   }
     146             :   /**
     147             :    * Destructor; frees the device storage held in _function_device
     148             :    */
     149             :   ~FunctionWrapperHost();
     150             : 
     151             :   FunctionWrapperDeviceBase * allocate() override final;
     152             :   void copyFunction() override final;
     153             :   void freeFunction() override final;
     154             : 
     155             : private:
     156             :   /**
     157             :    * Reference of the function on host; read by copyFunction() as the source of each copy
     158             :    */
     159             :   const Object & _function_host;
     160             :   /**
     161             :    * Copy of the function on host; created by copyFunction() and released by freeFunction(). NOTE(review): std::unique_ptr needs <memory>, which this header does not include directly -- presumably it arrives via KokkosTypes.h
     162             :    */
     163             :   std::unique_ptr<Object> _function_copy;
     164             :   /**
     165             :    * Copy of the function on device; storage is allocated in allocate(), filled by copyFunction(), and freed in the destructor
     166             :    */
     167             :   Object * _function_device = nullptr;
     168             : };
     169             : 
     170             : template <typename Object>
     171             : FunctionWrapperDeviceBase *
     172          92 : FunctionWrapperHost<Object>::allocate()
     173             : {
     174             :   // Overview: build the device wrapper plus the storage for the function, link them, and return the wrapper. First, allocate raw storage for the device wrapper on device
     175           1 :   auto wrapper_device = static_cast<FunctionWrapperDevice<Object> *>(
     176          91 :       ::Kokkos::kokkos_malloc<ExecSpace::memory_space>(sizeof(FunctionWrapperDevice<Object>)));
     177             : 
     178             :   // Construct the device wrapper in place inside a single-iteration kernel, using placement new to populate vtable with device pointers. NOTE(review): no fence is visible between this kernel and the DeepCopy below that writes into the same object -- confirm the ordering is guaranteed
     179          92 :   ::Kokkos::parallel_for(
     180         182 :       1, KOKKOS_LAMBDA(const int) { new (wrapper_device) FunctionWrapperDevice<Object>(); });
     181             : 
     182             :   // Allocate raw storage for the function on device; copyFunction() fills it and the destructor frees it. NOTE(review): a second call to allocate() would overwrite this pointer without freeing the earlier storage
     183          92 :   _function_device =
     184          91 :       static_cast<Object *>(::Kokkos::kokkos_malloc<ExecSpace::memory_space>(sizeof(Object)));
     185             : 
     186             :   // Let device wrapper point to the copy: transfer the pointer value itself (sizeof(Object *) bytes) into the wrapper's _function member. NOTE(review): ::Kokkos::Impl looks like an internal Kokkos namespace -- confirm this usage is supported
     187         183 :   ::Kokkos::Impl::DeepCopy<MemSpace, ::Kokkos::HostSpace>(
     188          91 :       &(wrapper_device->_function), &_function_device, sizeof(Object *));
     189             : 
     190          92 :   return wrapper_device;
     191             : }
     192             : 
     193             : template <typename Object>
     194             : void
     195       11943 : FunctionWrapperHost<Object>::copyFunction()
     196             : {
     197             :   // Make a copy of function on host to trigger copy constructor; assigning to the unique_ptr also destroys any copy left from a previous call
     198       11943 :   _function_copy = std::make_unique<Object>(_function_host);
     199             : 
     200             :   // Copy function to device: sizeof(Object) bytes of the host copy go into the storage from allocate(). NOTE(review): _function_device is still nullptr if allocate() has not run -- confirm callers always allocate first
     201       23882 :   ::Kokkos::Impl::DeepCopy<MemSpace, ::Kokkos::HostSpace>(
     202       11939 :       _function_device, _function_copy.get(), sizeof(Object));
     203       11943 : }
     204             : 
     205             : template <typename Object>
     206             : void
     207       12035 : FunctionWrapperHost<Object>::freeFunction()
     208             : {
     209       12035 :   _function_copy.reset(); // Destroys the host copy made by copyFunction(), if any; the device storage is not touched here
     210       12035 : }
     211             : 
     212             : template <typename Object>
     213         184 : FunctionWrapperHost<Object>::~FunctionWrapperHost()
     214             : {
     215          92 :   ::Kokkos::kokkos_free<ExecSpace::memory_space>(_function_device); // Frees only the function storage; the wrapper returned by allocate() is not freed here. NOTE(review): pointer is nullptr if allocate() never ran -- confirm kokkos_free accepts that
     216         184 : }
     217             : 
     218             : } // namespace Kokkos
     219             : } // namespace Moose

Generated by: LCOV version 1.14