19 #include <ATen/Parallel.h> 20 #include "neml2/models/Model.h" 21 #include "neml2/dispatchers/WorkScheduler.h" 22 #include "neml2/dispatchers/WorkDispatcher.h" 23 #include "neml2/dispatchers/valuemap_helpers.h" 24 #include "neml2/dispatchers/derivmap_helpers.h" 37 template <
typename... P>
55 using RJType = std::tuple<neml2::ValueMap, neml2::DerivMap>;
57 neml2::WorkDispatcher<neml2::ValueMap, RJType, RJType, neml2::ValueMap, RJType>;
70 std::shared_ptr<neml2::Model>
_model;
79 std::unordered_map<std::thread::id, std::shared_ptr<neml2::Model>>
_model_pool;
81 #endif // NEML2_ENABLED 92 params.
addParam<std::vector<std::string>>(
96 "Additional command line arguments to use when parsing the NEML2 input file."));
101 "the NEML2 input file that corresponds to the model you want to use."));
106 "Device on which to evaluate the NEML2 model. The string supplied must follow the " 107 "following schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device " 108 "type, and :<device-index> optionally specifies a device index. For example, " 109 "device='cpu' sets the target compute device to be CPU, and device='cuda:1' sets the " 110 "target compute device to be CUDA with device ID 1."));
115 "NEML2 scheduler to use to run the model. If not specified no scheduler is used and " 116 "MOOSE will pass all the constitutive updates to the provided device at once."));
124 #ifndef NEML2_ENABLED 127 template <
typename... P>
136 template <
typename... P>
138 : T(params, args...),
139 _device(params.
get<
std::string>(
"device")),
141 _async_dispatch(params.
get<bool>(
"async_dispatch"))
144 const auto & fname = params.
get<DataFileName>(
"input");
145 const auto & cli_args = params.
get<std::vector<std::string>>(
"cli_args");
156 auto red = [](std::vector<RJType> && results) ->
RJType 159 std::vector<neml2::ValueMap> vms;
160 std::vector<neml2::DerivMap> dms;
161 for (
auto && [vm,
dm] : results)
163 vms.push_back(std::move(vm));
164 dms.push_back(std::move(
dm));
166 return std::make_tuple(neml2::valuemap_cat_reduce(std::move(vms), 0),
167 neml2::derivmap_cat_reduce(std::move(dms), 0));
172 return std::make_tuple(neml2::valuemap_move_device(std::move(std::get<0>(x)),
_device),
173 neml2::derivmap_move_device(std::move(std::get<1>(x)),
_device));
176 auto thread_init = [
this](neml2::Device
device) ->
void 188 [&](neml2::ValueMap && x, neml2::Device
device) ->
RJType 198 return model->value_and_dvalue(std::move(x));
201 &neml2::valuemap_move_device,
211 mooseAssert(_model !=
nullptr,
"_model must be initialized");
212 neml2::diagnose(*_model);
215 #endif // NEML2_ENABLED std::shared_ptr< neml2::WorkScheduler > _scheduler
The work scheduler to use.
neml2::WorkScheduler * scheduler()
Get the work scheduler.
std::string join(Iterator begin, Iterator end, const std::string &delimiter)
Python-like join function for strings over an iterator range.
static InputParameters validParams()
std::unordered_map< std::thread::id, std::shared_ptr< neml2::Model > > _model_pool
Models for each thread.
T * get(const std::unique_ptr< T > &u)
The MooseUtils::get() specializations are used to support making forwards-compatible code changes fro...
Interface class to provide common input parameters, members, and methods for MOOSEObjects that use NE...
PetscErrorCode PetscOptionItems *PetscErrorCode DM dm
neml2::Model & model() const
Get the NEML2 model.
std::unique_ptr< DispatcherType > _dispatcher
Work dispatcher.
const neml2::Device & device() const
Get the target compute device.
const std::unique_ptr< DispatcherType > & dispatcher() const
Get the work dispatcher.
std::tuple< neml2::ValueMap, neml2::DerivMap > RJType
InputParameters validParams()
std::shared_ptr< neml2::Model > _model
The NEML2 material model.
NEML2ModelInterface(const InputParameters &params, P &&... args)
std::string docstring(const std::string &desc)
Augment docstring if NEML2 is not enabled.
std::shared_ptr< neml2::Model > getModel(neml2::Factory &factory, const std::string &name, neml2::Dtype dtype=neml2::kFloat64)
Get the NEML2 Model.
neml2::WorkDispatcher< neml2::ValueMap, RJType, RJType, neml2::ValueMap, RJType > DispatcherType
std::unique_ptr< neml2::Factory > _factory
The NEML2 factory.
const bool _async_dispatch
Whether to dispatch work asynchronously.
virtual void validateModel() const
Validate the NEML2 material model.
const neml2::Device _device
The device on which to evaluate the NEML2 model.