LCOV - code coverage report
Current view: top level - include/kokkos/functions - KokkosFunctionWrapper.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: #32971 (54bef8) with base c6cf66 Lines: 29 42 69.0 %
Date: 2026-05-29 20:35:17 Functions: 38 73 52.1 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : #include "KokkosTypes.h"
      13             : 
      14             : namespace Moose::Kokkos
      15             : {
      16             : 
      17             : template <typename Object>
      18             : class FunctionWrapperHost;
      19             : 
      20             : /**
      21             :  * Base class for device function wrapper
      22             :  */
      23             : class FunctionWrapperDeviceBase
      24             : {
      25             : public:
      26             :   /**
      27             :    * Constructor
      28             :    */
      29         118 :   KOKKOS_FUNCTION FunctionWrapperDeviceBase() {}
      30             :   /**
      31             :    * Virtual destructor
      32             :    */
      33           0 :   KOKKOS_FUNCTION virtual ~FunctionWrapperDeviceBase() {}
      34             : 
      35             :   /**
      36             :    * Virtual shims that calls the corresponding methods of the actual stored function
      37             :    */
      38             :   ///@{
      39             :   KOKKOS_FUNCTION virtual Real value(Real t, Real3 p) const = 0;
      40             :   KOKKOS_FUNCTION virtual Real3 vectorValue(Real t, Real3 p) const = 0;
      41             :   KOKKOS_FUNCTION virtual Real3 gradient(Real t, Real3 p) const = 0;
      42             :   KOKKOS_FUNCTION virtual Real3 curl(Real t, Real3 p) const = 0;
      43             :   KOKKOS_FUNCTION virtual Real div(Real t, Real3 p) const = 0;
      44             :   KOKKOS_FUNCTION virtual Real timeDerivative(Real t, Real3 p) const = 0;
      45             :   KOKKOS_FUNCTION virtual Real timeIntegral(Real t1, Real t2, Real3 p) const = 0;
      46             :   KOKKOS_FUNCTION virtual Real integral() const = 0;
      47             :   KOKKOS_FUNCTION virtual Real average() const = 0;
      48             :   ///@}
      49             : };
      50             : 
      51             : /**
      52             :  * Device function wrapper class that provides polymorphic interfaces for a function. The function
      53             :  * itself is a static object and does not have any virtual method. Instead, the device wrapper
      54             :  * defines the virtual shims and forwards the calls to the static methods of the stored function.
      55             :  * @tparam Object The function class type
      56             :  */
      57             : template <typename Object>
      58             : class FunctionWrapperDevice : public FunctionWrapperDeviceBase
      59             : {
      60             :   friend class FunctionWrapperHost<Object>;
      61             : 
      62             : public:
      63             :   /**
      64             :    * Constructor
      65             :    */
      66         118 :   KOKKOS_FUNCTION FunctionWrapperDevice() {}
      67             : 
      68    13123379 :   KOKKOS_FUNCTION Real value(Real t, Real3 p) const override final
      69             :   {
      70    13123379 :     return _function->value(t, p);
      71             :   }
      72           0 :   KOKKOS_FUNCTION Real3 vectorValue(Real t, Real3 p) const override final
      73             :   {
      74           0 :     return _function->vectorValue(t, p);
      75             :   }
      76           0 :   KOKKOS_FUNCTION Real3 gradient(Real t, Real3 p) const override final
      77             :   {
      78           0 :     return _function->gradient(t, p);
      79             :   }
      80           0 :   KOKKOS_FUNCTION Real3 curl(Real t, Real3 p) const override final { return _function->curl(t, p); }
      81           0 :   KOKKOS_FUNCTION Real div(Real t, Real3 p) const override final { return _function->div(t, p); }
      82           0 :   KOKKOS_FUNCTION Real timeDerivative(Real t, Real3 p) const override final
      83             :   {
      84           0 :     return _function->timeDerivative(t, p);
      85             :   }
      86           0 :   KOKKOS_FUNCTION Real timeIntegral(Real t1, Real t2, Real3 p) const override final
      87             :   {
      88           0 :     return _function->timeIntegral(t1, t2, p);
      89             :   }
      90           0 :   KOKKOS_FUNCTION Real integral() const override final { return _function->integral(); }
      91           0 :   KOKKOS_FUNCTION Real average() const override final { return _function->average(); }
      92             : 
      93             : protected:
      94             :   /**
      95             :    * Pointer to the function on device
      96             :    */
      97             :   Object * _function = nullptr;
      98             : };
      99             : 
     100             : /**
     101             :  * Base class for host function wrapper
     102             :  */
     103             : class FunctionWrapperHostBase
     104             : {
     105             : public:
     106             :   /**
     107             :    * Virtual destructor
     108             :    */
     109         119 :   virtual ~FunctionWrapperHostBase() {}
     110             : 
     111             :   /**
     112             :    * Allocate device function and wrapper
     113             :    * @returns The pointer to the device wrapper
     114             :    */
     115             :   virtual FunctionWrapperDeviceBase * allocate() = 0;
     116             :   /**
     117             :    * Copy function to device
     118             :    */
     119             :   virtual void copyFunction() = 0;
     120             :   /**
     121             :    * Free host and device copies of function
     122             :    */
     123             :   virtual void freeFunction() = 0;
     124             : };
     125             : 
     126             : /**
     127             :  * Host function wrapper class that allocates a function on device and creates its device wrapper.
     128             :  * This class holds the actual device instance of the function and manages its allocation and
     129             :  * deallocation, and the device wrapper simply keeps a pointer to it.
     130             :  * @tparam Object The function class type
     131             :  */
     132             : template <typename Object>
     133             : class FunctionWrapperHost : public FunctionWrapperHostBase
     134             : {
     135             : public:
     136             :   /**
     137             :    * Constructor
     138             :    * @param function Pointer to the function
     139             :    */
     140         119 :   FunctionWrapperHost(const void * function)
     141         118 :     : _function_host(*static_cast<const Object *>(function))
     142             :   {
     143         119 :   }
     144             :   /**
     145             :    * Destructor
     146             :    */
     147             :   ~FunctionWrapperHost();
     148             : 
     149             :   FunctionWrapperDeviceBase * allocate() override final;
     150             :   void copyFunction() override final;
     151             :   void freeFunction() override final;
     152             : 
     153             : private:
     154             :   /**
     155             :    * Reference of the function on host
     156             :    */
     157             :   const Object & _function_host;
     158             :   /**
     159             :    * Copy of the function on host
     160             :    */
     161             :   std::unique_ptr<Object> _function_copy;
     162             :   /**
     163             :    * Copy of the function on device
     164             :    */
     165             :   Object * _function_device = nullptr;
     166             : };
     167             : 
     168             : template <typename Object>
     169             : FunctionWrapperDeviceBase *
     170         119 : FunctionWrapperHost<Object>::allocate()
     171             : {
     172             :   // Allocate storage for device wrapper on device
     173           1 :   auto wrapper_device = static_cast<FunctionWrapperDevice<Object> *>(
     174         118 :       ::Kokkos::kokkos_malloc<ExecSpace::memory_space>(sizeof(FunctionWrapperDevice<Object>)));
     175             : 
     176             :   // Allocate device wrapper on device using placement new to populate vtable with device pointers
     177         119 :   ::Kokkos::parallel_for(
     178         236 :       1, KOKKOS_LAMBDA(const int) { new (wrapper_device) FunctionWrapperDevice<Object>(); });
     179             : 
     180             :   // Allocate storage for function on device
     181         119 :   _function_device =
     182         118 :       static_cast<Object *>(::Kokkos::kokkos_malloc<ExecSpace::memory_space>(sizeof(Object)));
     183             : 
     184             :   // Let device wrapper point to the copy
     185         237 :   ::Kokkos::Impl::DeepCopy<MemSpace, ::Kokkos::HostSpace>(
     186         118 :       &(wrapper_device->_function), &_function_device, sizeof(Object *));
     187             : 
     188         119 :   return wrapper_device;
     189             : }
     190             : 
     191             : template <typename Object>
     192             : void
     193       20949 : FunctionWrapperHost<Object>::copyFunction()
     194             : {
     195             :   // Make a copy of function on host to trigger copy constructor
     196       20949 :   _function_copy = std::make_unique<Object>(_function_host);
     197             : 
     198             :   // Copy function to device
     199       41894 :   ::Kokkos::Impl::DeepCopy<MemSpace, ::Kokkos::HostSpace>(
     200       20945 :       _function_device, _function_copy.get(), sizeof(Object));
     201       20949 : }
     202             : 
     203             : template <typename Object>
     204             : void
     205       21068 : FunctionWrapperHost<Object>::freeFunction()
     206             : {
     207       21068 :   _function_copy.reset();
     208       21068 : }
     209             : 
     210             : template <typename Object>
     211         238 : FunctionWrapperHost<Object>::~FunctionWrapperHost()
     212             : {
     213         119 :   ::Kokkos::kokkos_free<ExecSpace::memory_space>(_function_device);
     214         238 : }
     215             : 
     216             : } // namespace Moose::Kokkos

Generated by: LCOV version 1.14