LCOV - code coverage report
Current view: top level - include/neml2/interfaces - NEML2ModelInterface.h (source / functions) Hit Total Coverage
Test: idaholab/moose framework: #32971 (54bef8) with base c6cf66 Lines: 32 66 48.5 %
Date: 2026-05-29 20:35:17 Functions: 8 13 61.5 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
      12             : #include <thread>
      13             : #include <utility>
      14             : #include <tuple>
      15             : #include "NEML2Utils.h"
      16             : #include "InputParameters.h"
      17             : 
      18             : #ifdef NEML2_ENABLED
      19             : #include <ATen/Parallel.h>
      20             : #include "neml2/neml2.h"
      21             : #include "neml2/models/Model.h"
      22             : #include "neml2/dispatchers/WorkScheduler.h"
      23             : #include "neml2/dispatchers/WorkDispatcher.h"
      24             : #include "neml2/dispatchers/valuemap_helpers.h"
      25             : #include "neml2/dispatchers/derivmap_helpers.h"
      26             : #endif
      27             : 
      28             : /**
      29             :  * Interface class to provide common input parameters, members, and methods for MOOSEObjects that
      30             :  * use NEML2 models. Derives from T and passes the constructor arguments on to it. The protected accessors and private members are compiled only when NEML2_ENABLED is defined.
      31             :  */
      32             : template <class T>
      33             : class NEML2ModelInterface : public T
      34             : {
      35             : public:
      36             :   static InputParameters validParams();
      37             : 
      38             :   template <typename... P>
      39             :   NEML2ModelInterface(const InputParameters & params, P &&... args);
      40             : 
      41             : #ifdef NEML2_ENABLED
      42             : 
      43             : protected:
      44             :   /**
      45             :    * Validate the NEML2 material model. Note that the developer is responsible for calling this
      46             :    * method at the appropriate times, for example, at initialSetup().
      47             :    */
      48             :   virtual void validateModel() const;
      49             : 
      50             :   /// Get the NEML2 model (dereferences _model, which the constructor loads from the 'input' file)
      51        2928 :   neml2::Model & model() const { return *_model; }
      52             : 
      53             :   /// Get the target compute device ('device' parameter, or the application's libtorch device when it is not set)
      54        1415 :   const neml2::Device & device() const { return _device; }
      55             : 
      56             :   /// Get the target output device ('output_device' parameter, or the compute device when it is not set)
      57        2816 :   const neml2::Device & output_device() const { return _output_device; }
      58             : 
      59             :   using RJType = std::tuple<neml2::ValueMap, neml2::DerivMap>; // what the dispatched evaluation (value_and_dvalue) returns
      60             :   using DispatcherType =
      61             :       neml2::WorkDispatcher<neml2::ValueMap, RJType, RJType, neml2::ValueMap, RJType>;
      62             : 
      63             :   /// Get the work scheduler (non-owning pointer; null when the 'scheduler' parameter is not set)
      64        1408 :   neml2::WorkScheduler * scheduler() { return _scheduler.get(); }
      65             :   /// Get the work dispatcher (the pointer is null unless the constructor loaded a scheduler)
      66           0 :   const std::unique_ptr<DispatcherType> & dispatcher() const { return _dispatcher; }
      67             : 
      68             : private:
      69             :   /// The device on which to evaluate the NEML2 model
      70             :   const neml2::Device _device;
      71             :   /// The device on which to store the outputs
      72             :   const neml2::Device _output_device;
      73             :   /// The NEML2 factory (created by loading the 'input' file; also used to create the scheduler and per-thread models)
      74             :   std::unique_ptr<neml2::Factory> _factory;
      75             :   /// The NEML2 material model
      76             :   std::shared_ptr<neml2::Model> _model;
      77             : 
      78             :   /// The work scheduler to use (null when none was requested)
      79             :   std::shared_ptr<neml2::WorkScheduler> _scheduler;
      80             :   /// Work dispatcher, created by the constructor only when a scheduler is in use
      81             :   std::unique_ptr<DispatcherType> _dispatcher;
      82             :   /// Whether to dispatch work asynchronously
      83             :   const bool _async_dispatch;
      84             :   /// Models for each thread, keyed by thread id and filled in by the dispatcher's thread-initialization callback. NOTE(review): no lock guards this map in this file -- confirm the scheduler serializes thread initialization.
      85             :   std::unordered_map<std::thread::id, std::shared_ptr<neml2::Model>> _model_pool;
      86             : 
      87             : #endif // NEML2_ENABLED
      88             : };
      89             : 
      90             : template <class T>
      91             : InputParameters
      92        3182 : NEML2ModelInterface<T>::validParams()
      93             : { // Start from T's parameters and append the NEML2-specific options declared below.
      94        3182 :   InputParameters params = T::validParams();
      95       12728 :   params.addParam<DataFileName>("input",
      96             :                                 "Path to the NEML2 input file containing the NEML2 model(s)."); // no default here; the NEML2-enabled constructor reads it unconditionally
      97       12728 :   params.addParam<std::vector<std::string>>(
      98             :       "cli_args",
      99             :       {},
     100             :       "Additional command line arguments to use when parsing the NEML2 input file.");
     101       12728 :   params.addParam<std::string>(
     102             :       "model",
     103             :       "",
     104             :       "Name of the NEML2 model, i.e., the string inside the brackets [] in the NEML2 input file "
     105             :       "that corresponds to the model you want to use.");
     106       12728 :   params.addParam<std::string>(
     107             :       "device",
     108             :       "Device on which to evaluate the NEML2 model. The string supplied must follow the following "
     109             :       "schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device type, and "
     110             :       ":<device-index> optionally specifies a device index. For example, device='cpu' sets the "
     111             :       "target compute device to be CPU, and device='cuda:1' sets the target compute device to be "
     112             :       "CUDA with device ID 1. If not specified, default to the compute device specified via the "
     113             :       "command line argument --compute-device.");
     114       12728 :   params.addParam<std::string>(
     115             :       "output_device",
     116             :       "Similar to the 'device' parameter, this parameter specifies the device on which to store "
     117             :       "the outputs. Default to be the same as 'device'.");
     118             :   // The remaining two parameters control how evaluations are scheduled and dispatched.
     119       12728 :   params.addParam<std::string>(
     120             :       "scheduler",
     121             :       "NEML2 scheduler to use to run the model.  If not specified no scheduler is used and MOOSE "
     122             :       "will pass all the constitutive updates to the provided device at once.");
     123             : 
     124        9546 :   params.addParam<bool>("async_dispatch", true, "Whether to use asynchronous dispatch.");
     125             : 
     126        3182 :   return params;
     127           0 : }
     128             : 
     129             : #ifndef NEML2_ENABLED
     130             : 
     131             : template <class T>
     132             : template <typename... P>
     133           0 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args)
     134           0 :   : T(params, args...) // NOTE(review): args are passed on as lvalues; std::forward<P>(args)... would preserve their value category
     135             : { // Built without NEML2: only the base class T is constructed, there is nothing else to set up.
     136           0 : }
     137             : 
     138             : #else
     139             : 
     140             : template <class T>
     141             : template <typename... P>
     142           8 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args)
     143             :   : T(params, args...), // NOTE(review): args are passed on as lvalues; std::forward<P>(args)... would preserve their value category
     144          24 :     _device(params.isParamValid("device") ? neml2::Device(params.get<std::string>("device"))
     145           8 :                                           : this->getMooseApp().getLibtorchDevice()),
     146          16 :     _output_device(params.isParamValid("output_device")
     147          16 :                        ? neml2::Device(params.get<std::string>("output_device"))
     148             :                        : _device),
     149           8 :     _scheduler(nullptr),
     150          24 :     _async_dispatch(params.get<bool>("async_dispatch"))
     151             : { // Loads the model and, when a scheduler is requested, builds the dispatcher and its callbacks.
     152             :   // Load the model named by 'model' from the NEML2 input file (parsed with 'cli_args' joined by spaces) and move it to the compute device
     153           8 :   const auto & fname = params.get<DataFileName>("input");
     154           8 :   const auto & cli_args = params.get<std::vector<std::string>>("cli_args");
     155           8 :   _factory = neml2::load_input(std::string(fname), neml2::utils::join(cli_args, " "));
     156           8 :   _model = NEML2Utils::getModel(*_factory, params.get<std::string>("model"));
     157           8 :   _model->to(_device);
     158             : 
     159             :   // Load scheduler if specified; otherwise _scheduler stays null and no dispatcher is created below
     160          16 :   if (params.isParamValid("scheduler"))
     161           0 :     _scheduler = _factory->get_scheduler(params.get<std::string>("scheduler"));
     162             : 
     163           8 :   if (_scheduler)
     164             :   {
     165           0 :     auto red = [](std::vector<RJType> && results) -> RJType
     166             :     {
     167             :       // Split the (values, derivatives) tuples into two separate vectors, one per *_cat_reduce helper
     168           0 :       std::vector<neml2::ValueMap> vms;
     169           0 :       std::vector<neml2::DerivMap> dms;
     170           0 :       for (auto && [vm, dm] : results)
     171             :       {
     172           0 :         vms.push_back(std::move(vm));
     173           0 :         dms.push_back(std::move(dm));
     174             :       }
     175           0 :       return std::make_tuple(neml2::valuemap_cat_reduce(std::move(vms), 0),
     176           0 :                              neml2::derivmap_cat_reduce(std::move(dms), 0));
     177           0 :     };
     178             :     // Post-processing moves both maps to _device. NOTE(review): _output_device is not used here -- confirm that is intended.
     179           0 :     auto post = [this](RJType && x) -> RJType
     180             :     {
     181           0 :       return std::make_tuple(neml2::valuemap_move_device(std::move(std::get<0>(x)), _device),
     182           0 :                              neml2::derivmap_move_device(std::move(std::get<1>(x)), _device));
     183             :     };
     184             :     // Thread initialization (handed to the dispatcher only for async dispatch): get a model of the same name from the factory, move it to the given device and register it under the calling thread's id.
     185           0 :     auto thread_init = [this](neml2::Device device) -> void
     186             :     {
     187             :       mooseAssert(libMesh::cast_int<unsigned int>(at::get_num_threads()) == libMesh::n_threads(),
     188             :                   "Inconsistent number of threads");
     189             :       mooseAssert(libMesh::cast_int<unsigned int>(at::get_num_interop_threads()) ==
     190             :                       libMesh::n_threads(),
     191             :                   "Inconsistent number of interop threads");
     192           0 :       auto model = NEML2Utils::getModel(*_factory, _model->name());
     193           0 :       model->to(device);
     194           0 :       _model_pool[std::this_thread::get_id()] = std::move(model); // NOTE(review): unsynchronized write to a shared map -- confirm these calls are serialized
     195           0 :     };
     196             :     // Assemble the dispatcher; the inline lambda below is the callable that actually evaluates the model.
     197           0 :     _dispatcher = std::make_unique<DispatcherType>(
     198           0 :         *_scheduler,
     199           0 :         _async_dispatch,
     200           0 :         [&](neml2::ValueMap && x, neml2::Device device) -> RJType
     201             :         {
     202           0 :           auto & model =
     203           0 :               _async_dispatch ? libmesh_map_find(_model_pool, std::this_thread::get_id()) : _model;
     204             : 
     205             :           // If this is not an async dispatch, we need to move the model to the target device
     206             :           // _every_ time before evaluation
     207           0 :           if (!_async_dispatch)
     208           0 :             model->to(device);
     209             : 
     210           0 :           return model->value_and_dvalue(std::move(x));
     211             :         },
     212             :         red,
     213           0 :         &neml2::valuemap_move_device,
     214             :         post,
     215           0 :         _async_dispatch ? thread_init : std::function<void(neml2::Device)>());
     216             :   }
     217           8 : }
     218             : 
     219             : template <class T>
     220             : void
     221           8 : NEML2ModelInterface<T>::validateModel() const
     222             : { // Asserts (via mooseAssert) that the model has been loaded, then runs neml2::diagnose on it.
     223             :   mooseAssert(_model != nullptr, "_model must be initialized");
     224           8 :   neml2::diagnose(*_model);
     225           8 : }
     226             : 
     227             : #endif // NEML2_ENABLED

Generated by: LCOV version 1.14