Line data Source code
1 : //* This file is part of the MOOSE framework
2 : //* https://mooseframework.inl.gov
3 : //*
4 : //* All rights reserved, see COPYRIGHT for full restrictions
5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 : //*
7 : //* Licensed under LGPL 2.1, please see LICENSE for details
8 : //* https://www.gnu.org/licenses/lgpl-2.1.html
9 :
10 : #pragma once
11 :
#include <functional>
#include <memory>
#include <mutex>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <utility>
#include "NEML2Utils.h"
#include "InputParameters.h"
17 :
18 : #ifdef NEML2_ENABLED
19 : #include <ATen/Parallel.h>
20 : #include "neml2/models/Model.h"
21 : #include "neml2/dispatchers/WorkScheduler.h"
22 : #include "neml2/dispatchers/WorkDispatcher.h"
23 : #include "neml2/dispatchers/valuemap_helpers.h"
24 : #include "neml2/dispatchers/derivmap_helpers.h"
25 : #endif
26 :
27 : /**
28 : * Interface class to provide common input parameters, members, and methods for MOOSEObjects that
29 : * use NEML2 models.
30 : */
31 : template <class T>
32 : class NEML2ModelInterface : public T
33 : {
34 : public:
35 : static InputParameters validParams();
36 :
37 : template <typename... P>
38 : NEML2ModelInterface(const InputParameters & params, P &&... args);
39 :
40 : #ifdef NEML2_ENABLED
41 :
42 : protected:
43 : /**
44 : * Validate the NEML2 material model. Note that the developer is responsible for calling this
45 : * method at the appropriate times, for example, at initialSetup().
46 : */
47 : virtual void validateModel() const;
48 :
49 : /// Get the NEML2 model
50 5272 : neml2::Model & model() const { return *_model; }
51 :
52 : /// Get the target compute device
53 1731 : const neml2::Device & device() const { return _device; }
54 :
55 : using RJType = std::tuple<neml2::ValueMap, neml2::DerivMap>;
56 : using DispatcherType =
57 : neml2::WorkDispatcher<neml2::ValueMap, RJType, RJType, neml2::ValueMap, RJType>;
58 :
59 : /// Get the work scheduler
60 1650 : neml2::WorkScheduler * scheduler() { return _scheduler.get(); }
61 : /// Get the work dispatcher
62 4 : const std::unique_ptr<DispatcherType> & dispatcher() const { return _dispatcher; }
63 :
64 : private:
65 : /// The device on which to evaluate the NEML2 model
66 : const neml2::Device _device;
67 : /// The NEML2 factory
68 : std::unique_ptr<neml2::Factory> _factory;
69 : /// The NEML2 material model
70 : std::shared_ptr<neml2::Model> _model;
71 :
72 : /// The work scheduler to use
73 : std::shared_ptr<neml2::WorkScheduler> _scheduler;
74 : /// Work dispatcher
75 : std::unique_ptr<DispatcherType> _dispatcher;
76 : /// Whether to dispatch work asynchronously
77 : const bool _async_dispatch;
78 : /// Models for each thread
79 : std::unordered_map<std::thread::id, std::shared_ptr<neml2::Model>> _model_pool;
80 :
81 : #endif // NEML2_ENABLED
82 : };
83 :
84 : template <class T>
85 : InputParameters
86 14759 : NEML2ModelInterface<T>::validParams()
87 : {
88 14759 : InputParameters params = T::validParams();
89 14759 : params.addParam<DataFileName>(
90 : "input",
91 : NEML2Utils::docstring("Path to the NEML2 input file containing the NEML2 model(s)."));
92 14759 : params.addParam<std::vector<std::string>>(
93 : "cli_args",
94 : {},
95 : NEML2Utils::docstring(
96 : "Additional command line arguments to use when parsing the NEML2 input file."));
97 14759 : params.addParam<std::string>(
98 : "model",
99 : "",
100 : NEML2Utils::docstring("Name of the NEML2 model, i.e., the string inside the brackets [] in "
101 : "the NEML2 input file that corresponds to the model you want to use."));
102 14759 : params.addParam<std::string>(
103 : "device",
104 : "cpu",
105 : NEML2Utils::docstring(
106 : "Device on which to evaluate the NEML2 model. The string supplied must follow the "
107 : "following schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device "
108 : "type, and :<device-index> optionally specifies a device index. For example, "
109 : "device='cpu' sets the target compute device to be CPU, and device='cuda:1' sets the "
110 : "target compute device to be CUDA with device ID 1."));
111 :
112 14759 : params.addParam<std::string>(
113 : "scheduler",
114 : NEML2Utils::docstring(
115 : "NEML2 scheduler to use to run the model. If not specified no scheduler is used and "
116 : "MOOSE will pass all the constitutive updates to the provided device at once."));
117 :
118 44277 : params.addParam<bool>(
119 29518 : "async_dispatch", true, NEML2Utils::docstring("Whether to use asynchronous dispatch."));
120 :
121 14759 : return params;
122 0 : }
123 :
124 : #ifndef NEML2_ENABLED
125 :
126 : template <class T>
127 : template <typename... P>
128 0 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args)
129 0 : : T(params, args...)
130 : {
131 0 : }
132 :
133 : #else
134 :
135 : template <class T>
136 : template <typename... P>
137 22 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args)
138 : : T(params, args...),
139 22 : _device(params.get<std::string>("device")),
140 22 : _scheduler(nullptr),
141 66 : _async_dispatch(params.get<bool>("async_dispatch"))
142 : {
143 : // Load model
144 22 : const auto & fname = params.get<DataFileName>("input");
145 22 : const auto & cli_args = params.get<std::vector<std::string>>("cli_args");
146 22 : _factory = neml2::load_input(std::string(fname), neml2::utils::join(cli_args, " "));
147 22 : _model = NEML2Utils::getModel(*_factory, params.get<std::string>("model"));
148 22 : _model->to(_device);
149 :
150 : // Load scheduler if specified
151 22 : if (params.isParamValid("scheduler"))
152 2 : _scheduler = _factory->get_scheduler(params.get<std::string>("scheduler"));
153 :
154 22 : if (_scheduler)
155 : {
156 4 : auto red = [](std::vector<RJType> && results) -> RJType
157 : {
158 : // Split into two separate vectors
159 4 : std::vector<neml2::ValueMap> vms;
160 4 : std::vector<neml2::DerivMap> dms;
161 44 : for (auto && [vm, dm] : results)
162 : {
163 40 : vms.push_back(std::move(vm));
164 40 : dms.push_back(std::move(dm));
165 : }
166 4 : return std::make_tuple(neml2::valuemap_cat_reduce(std::move(vms), 0),
167 12 : neml2::derivmap_cat_reduce(std::move(dms), 0));
168 4 : };
169 :
170 82 : auto post = [this](RJType && x) -> RJType
171 : {
172 40 : return std::make_tuple(neml2::valuemap_move_device(std::move(std::get<0>(x)), _device),
173 80 : neml2::derivmap_move_device(std::move(std::get<1>(x)), _device));
174 : };
175 :
176 1 : auto thread_init = [this](neml2::Device device) -> void
177 : {
178 1 : at::set_num_threads(libMesh::n_threads());
179 1 : at::set_num_interop_threads(libMesh::n_threads());
180 1 : auto model = NEML2Utils::getModel(*_factory, _model->name());
181 1 : model->to(device);
182 1 : _model_pool[std::this_thread::get_id()] = std::move(model);
183 1 : };
184 :
185 2 : _dispatcher = std::make_unique<DispatcherType>(
186 2 : *_scheduler,
187 2 : _async_dispatch,
188 0 : [&](neml2::ValueMap && x, neml2::Device device) -> RJType
189 : {
190 40 : auto & model =
191 40 : _async_dispatch ? libmesh_map_find(_model_pool, std::this_thread::get_id()) : _model;
192 :
193 : // If this is not an async dispatch, we need to move the model to the target device
194 : // _every_ time before evaluation
195 40 : if (!_async_dispatch)
196 20 : model->to(device);
197 :
198 40 : return model->value_and_dvalue(std::move(x));
199 : },
200 : red,
201 2 : &neml2::valuemap_move_device,
202 : post,
203 4 : _async_dispatch ? thread_init : std::function<void(neml2::Device)>());
204 : }
205 22 : }
206 :
207 : template <class T>
208 : void
209 22 : NEML2ModelInterface<T>::validateModel() const
210 : {
211 : mooseAssert(_model != nullptr, "_model must be initialized");
212 22 : neml2::diagnose(*_model);
213 22 : }
214 :
215 : #endif // NEML2_ENABLED
|