https://mooseframework.inl.gov
NEML2ModelInterface.h
Go to the documentation of this file.
1 //* This file is part of the MOOSE framework
2 //* https://mooseframework.inl.gov
3 //*
4 //* All rights reserved, see COPYRIGHT for full restrictions
5 //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 //*
7 //* Licensed under LGPL 2.1, please see LICENSE for details
8 //* https://www.gnu.org/licenses/lgpl-2.1.html
9 
10 #pragma once
11 
12 #include <thread>
13 #include <utility>
14 #include <tuple>
15 #include "NEML2Utils.h"
16 #include "InputParameters.h"
17 
18 #ifdef NEML2_ENABLED
19 #include <ATen/Parallel.h>
20 #include "neml2/neml2.h"
21 #include "neml2/models/Model.h"
22 #include "neml2/dispatchers/WorkScheduler.h"
23 #include "neml2/dispatchers/WorkDispatcher.h"
24 #include "neml2/dispatchers/valuemap_helpers.h"
25 #include "neml2/dispatchers/derivmap_helpers.h"
26 #endif
27 
/**
 * Interface class to provide common input parameters, members, and methods for MOOSEObjects
 * that use NEML2 models.
 *
 * The interface is a mixin: it derives from its template argument T and forwards the
 * constructor arguments to it. Everything except the constructor is only compiled when
 * NEML2_ENABLED is defined.
 *
 * NOTE(review): listing line 36 is absent from this extract. The symbol index of this page
 * lists `static InputParameters validParams()`, so that declaration presumably belongs in the
 * public section -- confirm against the real header.
 */
32 template <class T>
33 class NEML2ModelInterface : public T
34 {
35 public:
37 
  /// Construct the interface; `params` and `args` are passed on to the base class T.
38  template <typename... P>
39  NEML2ModelInterface(const InputParameters & params, P &&... args);
40 
41 #ifdef NEML2_ENABLED
42 
43 protected:
  /// Validate the NEML2 material model. Virtual so that derived classes can extend the checks.
48  virtual void validateModel() const;
49 
  /// Get the NEML2 model. Dereferences _model, so the model must already be loaded.
51  neml2::Model & model() const { return *_model; }
52 
  /// Get the target compute device.
54  const neml2::Device & device() const { return _device; }
55 
  /// Get the target output device.
57  const neml2::Device & output_device() const { return _output_device; }
58 
  /// Result of one model evaluation: the value map paired with the derivative map.
59  using RJType = std::tuple<neml2::ValueMap, neml2::DerivMap>;
  /// Concrete work dispatcher type used to evaluate the model through a scheduler.
60  using DispatcherType =
61  neml2::WorkDispatcher<neml2::ValueMap, RJType, RJType, neml2::ValueMap, RJType>;
62 
  /// Get the work scheduler (non-owning pointer; null when no scheduler was requested).
64  neml2::WorkScheduler * scheduler() { return _scheduler.get(); }
  /// Get the work dispatcher (the pointer is null when no scheduler is in use).
66  const std::unique_ptr<DispatcherType> & dispatcher() const { return _dispatcher; }
67 
68 private:
  /// The device on which to evaluate the NEML2 model.
70  const neml2::Device _device;
  /// The device on which to store the outputs.
72  const neml2::Device _output_device;
  /// The NEML2 factory.
74  std::unique_ptr<neml2::Factory> _factory;
  /// The NEML2 material model.
76  std::shared_ptr<neml2::Model> _model;
77 
  /// The work scheduler to use.
79  std::shared_ptr<neml2::WorkScheduler> _scheduler;
  /// Work dispatcher.
81  std::unique_ptr<DispatcherType> _dispatcher;
  /// Whether to dispatch work asynchronously.
83  const bool _async_dispatch;
  /// Models for each thread, keyed by the id of the thread that created them.
85  std::unordered_map<std::thread::id, std::shared_ptr<neml2::Model>> _model_pool;
86 
87 #endif // NEML2_ENABLED
88 };
89 
// Declares the input parameters shared by every object that uses this interface:
// "input", "cli_args", "model", "device", "output_device", "scheduler" and "async_dispatch".
// Returns the populated parameter object.
//
// NOTE(review): listing lines 91-92 and 94 are absent from this extract. They presumably hold
// the declarator (the symbol index lists `static InputParameters validParams()`) and the
// declaration of the local `params` used below -- restore them from the real header.
90 template <class T>
93 {
  // Path of the NEML2 input file; no default value is supplied here.
95  params.addParam<DataFileName>("input",
96  "Path to the NEML2 input file containing the NEML2 model(s).");
  // Extra arguments for the NEML2 parser; defaults to an empty list.
97  params.addParam<std::vector<std::string>>(
98  "cli_args",
99  {},
100  "Additional command line arguments to use when parsing the NEML2 input file.");
  // Name of the model to pick from the input file; defaults to the empty string.
101  params.addParam<std::string>(
102  "model",
103  "",
104  "Name of the NEML2 model, i.e., the string inside the brackets [] in the NEML2 input file "
105  "that corresponds to the model you want to use.");
  // Compute device; declared without a default so the constructor can detect "not set".
106  params.addParam<std::string>(
107  "device",
108  "Device on which to evaluate the NEML2 model. The string supplied must follow the following "
109  "schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device type, and "
110  ":<device-index> optionally specifies a device index. For example, device='cpu' sets the "
111  "target compute device to be CPU, and device='cuda:1' sets the target compute device to be "
112  "CUDA with device ID 1. If not specified, default to the compute device specified via the "
113  "command line argument --compute-device.");
  // Output device; also declared without a default (the constructor falls back to the compute device).
114  params.addParam<std::string>(
115  "output_device",
116  "Similar to the 'device' parameter, this parameter specifies the device on which to store "
117  "the outputs. Default to be the same as 'device'.");
118 
  // Optional scheduler name; when unset the constructor creates neither scheduler nor dispatcher.
119  params.addParam<std::string>(
120  "scheduler",
121  "NEML2 scheduler to use to run the model. If not specified no scheduler is used and MOOSE "
122  "will pass all the constitutive updates to the provided device at once.");
123 
  // Asynchronous dispatch is on by default.
124  params.addParam<bool>("async_dispatch", true, "Whether to use asynchronous dispatch.");
125 
126  return params;
127 }
128 
129 #ifndef NEML2_ENABLED
130 
// Constructor used when NEML2 is not enabled: it only constructs the base class T from
// `params` and `args`.
//
// NOTE(review): listing line 133 (the constructor declarator) is absent from this extract.
// NOTE(review): `args` is a forwarding-reference pack but is passed on as lvalues;
// std::forward<P>(args)... would preserve value categories -- confirm whether that is wanted.
131 template <class T>
132 template <typename... P>
134  : T(params, args...)
135 {
136 }
137 
138 #else
139 
// Constructor used when NEML2 is enabled. In order, it:
//   1. constructs the base class T from `params` and `args`;
//   2. resolves the compute device and the output device;
//   3. loads the NEML2 input file, fetches the requested model and moves it to the compute device;
//   4. if a "scheduler" parameter was given, fetches that scheduler and builds the work dispatcher.
//
// NOTE(review): listing line 142 (the constructor declarator) is absent from this extract.
// NOTE(review): `args` is a forwarding-reference pack but is passed on as lvalues;
// std::forward<P>(args)... would preserve value categories -- confirm whether that is wanted.
140 template <class T>
141 template <typename... P>
143  : T(params, args...),
  // Use the "device" parameter when it is set, otherwise the application's libtorch device.
144  _device(params.isParamValid("device") ? neml2::Device(params.get<std::string>("device"))
145  : this->getMooseApp().getLibtorchDevice()),
  // Falls back to _device, which is declared (and therefore initialized) before _output_device.
146  _output_device(params.isParamValid("output_device")
147  ? neml2::Device(params.get<std::string>("output_device"))
148  : _device),
149  _scheduler(nullptr),
150  _async_dispatch(params.get<bool>("async_dispatch"))
151 {
152  // Load model
153  const auto & fname = params.get<DataFileName>("input");
154  const auto & cli_args = params.get<std::vector<std::string>>("cli_args");
  // The extra arguments are joined with single spaces into one string for the NEML2 loader.
155  _factory = neml2::load_input(std::string(fname), neml2::utils::join(cli_args, " "));
156  _model = NEML2Utils::getModel(*_factory, params.get<std::string>("model"));
157  _model->to(_device);
158 
159  // Load scheduler if specified
160  if (params.isParamValid("scheduler"))
161  _scheduler = _factory->get_scheduler(params.get<std::string>("scheduler"));
162 
  // Without a scheduler no dispatcher is created and _dispatcher stays null.
163  if (_scheduler)
164  {
  // Combines the per-batch results into a single (values, derivatives) pair.
  // NOTE(review): the trailing 0 is presumably the dimension to concatenate along -- confirm.
165  auto red = [](std::vector<RJType> && results) -> RJType
166  {
167  // Split into two separate vectors
168  std::vector<neml2::ValueMap> vms;
169  std::vector<neml2::DerivMap> dms;
170  for (auto && [vm, dm] : results)
171  {
172  vms.push_back(std::move(vm));
173  dms.push_back(std::move(dm));
174  }
175  return std::make_tuple(neml2::valuemap_cat_reduce(std::move(vms), 0),
176  neml2::derivmap_cat_reduce(std::move(dms), 0));
177  };
178 
  // Moves both maps of a result onto _device.
  // NOTE(review): this targets _device rather than _output_device -- confirm that is intended.
179  auto post = [this](RJType && x) -> RJType
180  {
181  return std::make_tuple(neml2::valuemap_move_device(std::move(std::get<0>(x)), _device),
182  neml2::derivmap_move_device(std::move(std::get<1>(x)), _device));
183  };
184 
  // Per-thread setup: creates a separate model instance (same name as _model) from the factory,
  // moves it to the given device and stores it in _model_pool under the calling thread's id.
  // NOTE(review): listing line 190 (right-hand side of the second comparison) is absent here.
  // NOTE(review): the write to _model_pool is not guarded by any lock visible in this file; if
  // the dispatcher can run this callback on several threads at once, that is a data race -- confirm.
185  auto thread_init = [this](neml2::Device device) -> void
186  {
187  mooseAssert(libMesh::cast_int<unsigned int>(at::get_num_threads()) == libMesh::n_threads(),
188  "Inconsistent number of threads");
189  mooseAssert(libMesh::cast_int<unsigned int>(at::get_num_interop_threads()) ==
191  "Inconsistent number of interop threads");
192  auto model = NEML2Utils::getModel(*_factory, _model->name());
193  model->to(device);
194  _model_pool[std::this_thread::get_id()] = std::move(model);
195  };
196 
  // NOTE(review): listing line 199 is absent; an argument between *_scheduler and the work
  // lambda appears to be missing from this extract.
  // The work lambda below only touches members, so its [&] capture amounts to capturing `this`.
  // `red`, &neml2::valuemap_move_device and `post` are presumably the reduction, input
  // pre-processing and output post-processing callbacks -- confirm against WorkDispatcher.
197  _dispatcher = std::make_unique<DispatcherType>(
198  *_scheduler,
200  [&](neml2::ValueMap && x, neml2::Device device) -> RJType
201  {
  // Async mode uses the model registered for the calling thread; otherwise the shared _model.
202  auto & model =
203  _async_dispatch ? libmesh_map_find(_model_pool, std::this_thread::get_id()) : _model;
204 
205  // If this is not an async dispatch, we need to move the model to the target device
206  // _every_ time before evaluation
207  if (!_async_dispatch)
208  model->to(device);
209 
210  return model->value_and_dvalue(std::move(x));
211  },
212  red,
213  &neml2::valuemap_move_device,
214  post,
  // The thread initializer is only supplied in async mode; otherwise an empty std::function.
215  _async_dispatch ? thread_init : std::function<void(neml2::Device)>());
216  }
217 }
218 
// Validate the NEML2 material model: asserts that _model has been set, then runs
// neml2::diagnose on it.
//
// NOTE(review): listing line 221 (the declarator) is absent from this extract; the class
// declares this member as `virtual void validateModel() const`.
219 template <class T>
220 void
222 {
223  mooseAssert(_model != nullptr, "_model must be initialized");
224  neml2::diagnose(*_model);
225 }
226 
227 #endif // NEML2_ENABLED
std::shared_ptr< neml2::WorkScheduler > _scheduler
The work scheduler to use.
neml2::WorkScheduler * scheduler()
Get the work scheduler.
unsigned int n_threads()
static InputParameters validParams()
std::unordered_map< std::thread::id, std::shared_ptr< neml2::Model > > _model_pool
Models for each thread.
std::vector< std::pair< R1, R2 > > get(const std::string &param1, const std::string &param2) const
Combine two vector parameters into a single vector of pairs.
Interface class to provide common input parameters, members, and methods for MOOSEObjects that use NEML2 models.
PetscErrorCode PetscOptionItems *PetscErrorCode DM dm
neml2::Model & model() const
Get the NEML2 model.
std::unique_ptr< DispatcherType > _dispatcher
Work dispatcher.
const neml2::Device & device() const
Get the target compute device.
The main MOOSE class responsible for handling user-defined parameters in almost every MOOSE system...
const std::unique_ptr< DispatcherType > & dispatcher() const
Get the work dispatcher.
std::tuple< neml2::ValueMap, neml2::DerivMap > RJType
std::shared_ptr< neml2::Model > _model
The NEML2 material model.
NEML2ModelInterface(const InputParameters &params, P &&... args)
const neml2::Device _output_device
The device on which to store the outputs.
std::shared_ptr< neml2::Model > getModel(neml2::Factory &factory, const std::string &name, neml2::Dtype dtype=neml2::kFloat64)
Get the NEML2 Model.
Definition: NEML2Utils.C:38
neml2::WorkDispatcher< neml2::ValueMap, RJType, RJType, neml2::ValueMap, RJType > DispatcherType
const neml2::Device & output_device() const
Get the target output device.
std::unique_ptr< neml2::Factory > _factory
The NEML2 factory.
void addParam(const std::string &name, const S &value, const std::string &doc_string)
These methods add an optional parameter and a documentation string to the InputParameters object...
const bool _async_dispatch
Whether to dispatch work asynchronously.
virtual void validateModel() const
Validate the NEML2 material model.
const Elem & get(const ElemType type_in)
InputParameters validParams()
const neml2::Device _device
The device on which to evaluate the NEML2 model.
bool isParamValid(const std::string &name) const
This method returns parameters that have been initialized in one fashion or another, i.e.