LCOV - code coverage report
Current view: top level - include/neml2/interfaces - NEML2ModelInterface.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: 2bf808 Lines: 59 64 92.2 %
Date: 2025-07-17 01:28:37 Functions: 12 12 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : #include <thread>
      13             : #include <utility>
      14             : #include <tuple>
      15             : #include "NEML2Utils.h"
      16             : #include "InputParameters.h"
      17             : 
      18             : #ifdef NEML2_ENABLED
      19             : #include <ATen/Parallel.h>
      20             : #include "neml2/models/Model.h"
      21             : #include "neml2/dispatchers/WorkScheduler.h"
      22             : #include "neml2/dispatchers/WorkDispatcher.h"
      23             : #include "neml2/dispatchers/valuemap_helpers.h"
      24             : #include "neml2/dispatchers/derivmap_helpers.h"
      25             : #endif
      26             : 
      27             : /**
      28             :  * Interface class to provide common input parameters, members, and methods for MOOSEObjects that
      29             :  * use NEML2 models.
      30             :  */
      31             : template <class T>
      32             : class NEML2ModelInterface : public T
      33             : {
      34             : public:
      35             :   static InputParameters validParams();
      36             : 
      37             :   template <typename... P>
      38             :   NEML2ModelInterface(const InputParameters & params, P &&... args);
      39             : 
      40             : #ifdef NEML2_ENABLED
      41             : 
      42             : protected:
      43             :   /**
      44             :    * Validate the NEML2 material model. Note that the developer is responsible for calling this
      45             :    * method at the appropriate times, for example, at initialSetup().
      46             :    */
      47             :   virtual void validateModel() const;
      48             : 
      49             :   /// Get the NEML2 model
      50        5272 :   neml2::Model & model() const { return *_model; }
      51             : 
      52             :   /// Get the target compute device
      53        1731 :   const neml2::Device & device() const { return _device; }
      54             : 
      55             :   using RJType = std::tuple<neml2::ValueMap, neml2::DerivMap>;
      56             :   using DispatcherType =
      57             :       neml2::WorkDispatcher<neml2::ValueMap, RJType, RJType, neml2::ValueMap, RJType>;
      58             : 
      59             :   /// Get the work scheduler
      60        1650 :   neml2::WorkScheduler * scheduler() { return _scheduler.get(); }
      61             :   /// Get the work dispatcher
      62           4 :   const std::unique_ptr<DispatcherType> & dispatcher() const { return _dispatcher; }
      63             : 
      64             : private:
      65             :   /// The device on which to evaluate the NEML2 model
      66             :   const neml2::Device _device;
      67             :   /// The NEML2 factory
      68             :   std::unique_ptr<neml2::Factory> _factory;
      69             :   /// The NEML2 material model
      70             :   std::shared_ptr<neml2::Model> _model;
      71             : 
      72             :   /// The work scheduler to use
      73             :   std::shared_ptr<neml2::WorkScheduler> _scheduler;
      74             :   /// Work dispatcher
      75             :   std::unique_ptr<DispatcherType> _dispatcher;
      76             :   /// Whether to dispatch work asynchronously
      77             :   const bool _async_dispatch;
      78             :   /// Models for each thread
      79             :   std::unordered_map<std::thread::id, std::shared_ptr<neml2::Model>> _model_pool;
      80             : 
      81             : #endif // NEML2_ENABLED
      82             : };
      83             : 
      84             : template <class T>
      85             : InputParameters
      86       14759 : NEML2ModelInterface<T>::validParams()
      87             : {
      88       14759 :   InputParameters params = T::validParams();
      89       14759 :   params.addParam<DataFileName>(
      90             :       "input",
      91             :       NEML2Utils::docstring("Path to the NEML2 input file containing the NEML2 model(s)."));
      92       14759 :   params.addParam<std::vector<std::string>>(
      93             :       "cli_args",
      94             :       {},
      95             :       NEML2Utils::docstring(
      96             :           "Additional command line arguments to use when parsing the NEML2 input file."));
      97       14759 :   params.addParam<std::string>(
      98             :       "model",
      99             :       "",
     100             :       NEML2Utils::docstring("Name of the NEML2 model, i.e., the string inside the brackets [] in "
     101             :                             "the NEML2 input file that corresponds to the model you want to use."));
     102       14759 :   params.addParam<std::string>(
     103             :       "device",
     104             :       "cpu",
     105             :       NEML2Utils::docstring(
     106             :           "Device on which to evaluate the NEML2 model. The string supplied must follow the "
     107             :           "following schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device "
     108             :           "type, and :<device-index> optionally specifies a device index. For example, "
     109             :           "device='cpu' sets the target compute device to be CPU, and device='cuda:1' sets the "
     110             :           "target compute device to be CUDA with device ID 1."));
     111             : 
     112       14759 :   params.addParam<std::string>(
     113             :       "scheduler",
     114             :       NEML2Utils::docstring(
     115             :           "NEML2 scheduler to use to run the model.  If not specified no scheduler is used and "
     116             :           "MOOSE will pass all the constitutive updates to the provided device at once."));
     117             : 
     118       44277 :   params.addParam<bool>(
     119       29518 :       "async_dispatch", true, NEML2Utils::docstring("Whether to use asynchronous dispatch."));
     120             : 
     121       14759 :   return params;
     122           0 : }
     123             : 
     124             : #ifndef NEML2_ENABLED
     125             : 
     126             : template <class T>
     127             : template <typename... P>
     128           0 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args)
     129           0 :   : T(params, args...)
     130             : {
     131           0 : }
     132             : 
     133             : #else
     134             : 
     135             : template <class T>
     136             : template <typename... P>
     137          22 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args)
     138             :   : T(params, args...),
     139          22 :     _device(params.get<std::string>("device")),
     140          22 :     _scheduler(nullptr),
     141          66 :     _async_dispatch(params.get<bool>("async_dispatch"))
     142             : {
     143             :   // Load model
     144          22 :   const auto & fname = params.get<DataFileName>("input");
     145          22 :   const auto & cli_args = params.get<std::vector<std::string>>("cli_args");
     146          22 :   _factory = neml2::load_input(std::string(fname), neml2::utils::join(cli_args, " "));
     147          22 :   _model = NEML2Utils::getModel(*_factory, params.get<std::string>("model"));
     148          22 :   _model->to(_device);
     149             : 
     150             :   // Load scheduler if specified
     151          22 :   if (params.isParamValid("scheduler"))
     152           2 :     _scheduler = _factory->get_scheduler(params.get<std::string>("scheduler"));
     153             : 
     154          22 :   if (_scheduler)
     155             :   {
     156           4 :     auto red = [](std::vector<RJType> && results) -> RJType
     157             :     {
     158             :       // Split into two separate vectors
     159           4 :       std::vector<neml2::ValueMap> vms;
     160           4 :       std::vector<neml2::DerivMap> dms;
     161          44 :       for (auto && [vm, dm] : results)
     162             :       {
     163          40 :         vms.push_back(std::move(vm));
     164          40 :         dms.push_back(std::move(dm));
     165             :       }
     166           4 :       return std::make_tuple(neml2::valuemap_cat_reduce(std::move(vms), 0),
     167          12 :                              neml2::derivmap_cat_reduce(std::move(dms), 0));
     168           4 :     };
     169             : 
     170          82 :     auto post = [this](RJType && x) -> RJType
     171             :     {
     172          40 :       return std::make_tuple(neml2::valuemap_move_device(std::move(std::get<0>(x)), _device),
     173          80 :                              neml2::derivmap_move_device(std::move(std::get<1>(x)), _device));
     174             :     };
     175             : 
     176           1 :     auto thread_init = [this](neml2::Device device) -> void
     177             :     {
     178           1 :       at::set_num_threads(libMesh::n_threads());
     179           1 :       at::set_num_interop_threads(libMesh::n_threads());
     180           1 :       auto model = NEML2Utils::getModel(*_factory, _model->name());
     181           1 :       model->to(device);
     182           1 :       _model_pool[std::this_thread::get_id()] = std::move(model);
     183           1 :     };
     184             : 
     185           2 :     _dispatcher = std::make_unique<DispatcherType>(
     186           2 :         *_scheduler,
     187           2 :         _async_dispatch,
     188           0 :         [&](neml2::ValueMap && x, neml2::Device device) -> RJType
     189             :         {
     190          40 :           auto & model =
     191          40 :               _async_dispatch ? libmesh_map_find(_model_pool, std::this_thread::get_id()) : _model;
     192             : 
     193             :           // If this is not an async dispatch, we need to move the model to the target device
     194             :           // _every_ time before evaluation
     195          40 :           if (!_async_dispatch)
     196          20 :             model->to(device);
     197             : 
     198          40 :           return model->value_and_dvalue(std::move(x));
     199             :         },
     200             :         red,
     201           2 :         &neml2::valuemap_move_device,
     202             :         post,
     203           4 :         _async_dispatch ? thread_init : std::function<void(neml2::Device)>());
     204             :   }
     205          22 : }
     206             : 
     207             : template <class T>
     208             : void
     209          22 : NEML2ModelInterface<T>::validateModel() const
     210             : {
     211             :   mooseAssert(_model != nullptr, "_model must be initialized");
     212          22 :   neml2::diagnose(*_model);
     213          22 : }
     214             : 
     215             : #endif // NEML2_ENABLED

Generated by: LCOV version 1.14