Line data Source code
1 : //* This file is part of the MOOSE framework
2 : //* https://mooseframework.inl.gov
3 : //*
4 : //* All rights reserved, see COPYRIGHT for full restrictions
5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 : //*
7 : //* Licensed under LGPL 2.1, please see LICENSE for details
8 : //* https://www.gnu.org/licenses/lgpl-2.1.html
9 :
10 : #pragma once
11 :
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "NEML2Utils.h"
#include "InputParameters.h"

#ifdef NEML2_ENABLED
#include <ATen/Parallel.h>
#include "neml2/neml2.h"
#include "neml2/models/Model.h"
#include "neml2/dispatchers/WorkScheduler.h"
#include "neml2/dispatchers/WorkDispatcher.h"
#include "neml2/dispatchers/valuemap_helpers.h"
#include "neml2/dispatchers/derivmap_helpers.h"
#endif
27 :
28 : /**
29 : * Interface class to provide common input parameters, members, and methods for MOOSEObjects that
30 : * use NEML2 models.
31 : */
32 : template <class T>
33 : class NEML2ModelInterface : public T
34 : {
35 : public:
36 : static InputParameters validParams();
37 :
38 : template <typename... P>
39 : NEML2ModelInterface(const InputParameters & params, P &&... args);
40 :
41 : #ifdef NEML2_ENABLED
42 :
43 : protected:
44 : /**
45 : * Validate the NEML2 material model. Note that the developer is responsible for calling this
46 : * method at the appropriate times, for example, at initialSetup().
47 : */
48 : virtual void validateModel() const;
49 :
50 : /// Get the NEML2 model
51 2928 : neml2::Model & model() const { return *_model; }
52 :
53 : /// Get the target compute device
54 1415 : const neml2::Device & device() const { return _device; }
55 :
56 : /// Get the target output device
57 2816 : const neml2::Device & output_device() const { return _output_device; }
58 :
59 : using RJType = std::tuple<neml2::ValueMap, neml2::DerivMap>;
60 : using DispatcherType =
61 : neml2::WorkDispatcher<neml2::ValueMap, RJType, RJType, neml2::ValueMap, RJType>;
62 :
63 : /// Get the work scheduler
64 1408 : neml2::WorkScheduler * scheduler() { return _scheduler.get(); }
65 : /// Get the work dispatcher
66 0 : const std::unique_ptr<DispatcherType> & dispatcher() const { return _dispatcher; }
67 :
68 : private:
69 : /// The device on which to evaluate the NEML2 model
70 : const neml2::Device _device;
71 : /// The device on which to store the outputs
72 : const neml2::Device _output_device;
73 : /// The NEML2 factory
74 : std::unique_ptr<neml2::Factory> _factory;
75 : /// The NEML2 material model
76 : std::shared_ptr<neml2::Model> _model;
77 :
78 : /// The work scheduler to use
79 : std::shared_ptr<neml2::WorkScheduler> _scheduler;
80 : /// Work dispatcher
81 : std::unique_ptr<DispatcherType> _dispatcher;
82 : /// Whether to dispatch work asynchronously
83 : const bool _async_dispatch;
84 : /// Models for each thread
85 : std::unordered_map<std::thread::id, std::shared_ptr<neml2::Model>> _model_pool;
86 :
87 : #endif // NEML2_ENABLED
88 : };
89 :
90 : template <class T>
91 : InputParameters
92 3182 : NEML2ModelInterface<T>::validParams()
93 : {
94 3182 : InputParameters params = T::validParams();
95 12728 : params.addParam<DataFileName>("input",
96 : "Path to the NEML2 input file containing the NEML2 model(s).");
97 12728 : params.addParam<std::vector<std::string>>(
98 : "cli_args",
99 : {},
100 : "Additional command line arguments to use when parsing the NEML2 input file.");
101 12728 : params.addParam<std::string>(
102 : "model",
103 : "",
104 : "Name of the NEML2 model, i.e., the string inside the brackets [] in the NEML2 input file "
105 : "that corresponds to the model you want to use.");
106 12728 : params.addParam<std::string>(
107 : "device",
108 : "Device on which to evaluate the NEML2 model. The string supplied must follow the following "
109 : "schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device type, and "
110 : ":<device-index> optionally specifies a device index. For example, device='cpu' sets the "
111 : "target compute device to be CPU, and device='cuda:1' sets the target compute device to be "
112 : "CUDA with device ID 1. If not specified, default to the compute device specified via the "
113 : "command line argument --compute-device.");
114 12728 : params.addParam<std::string>(
115 : "output_device",
116 : "Similar to the 'device' parameter, this parameter specifies the device on which to store "
117 : "the outputs. Default to be the same as 'device'.");
118 :
119 12728 : params.addParam<std::string>(
120 : "scheduler",
121 : "NEML2 scheduler to use to run the model. If not specified no scheduler is used and MOOSE "
122 : "will pass all the constitutive updates to the provided device at once.");
123 :
124 9546 : params.addParam<bool>("async_dispatch", true, "Whether to use asynchronous dispatch.");
125 :
126 3182 : return params;
127 0 : }
128 :
129 : #ifndef NEML2_ENABLED
130 :
131 : template <class T>
132 : template <typename... P>
133 0 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args)
134 0 : : T(params, args...)
135 : {
136 0 : }
137 :
138 : #else
139 :
140 : template <class T>
141 : template <typename... P>
142 8 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args)
143 : : T(params, args...),
144 24 : _device(params.isParamValid("device") ? neml2::Device(params.get<std::string>("device"))
145 8 : : this->getMooseApp().getLibtorchDevice()),
146 16 : _output_device(params.isParamValid("output_device")
147 16 : ? neml2::Device(params.get<std::string>("output_device"))
148 : : _device),
149 8 : _scheduler(nullptr),
150 24 : _async_dispatch(params.get<bool>("async_dispatch"))
151 : {
152 : // Load model
153 8 : const auto & fname = params.get<DataFileName>("input");
154 8 : const auto & cli_args = params.get<std::vector<std::string>>("cli_args");
155 8 : _factory = neml2::load_input(std::string(fname), neml2::utils::join(cli_args, " "));
156 8 : _model = NEML2Utils::getModel(*_factory, params.get<std::string>("model"));
157 8 : _model->to(_device);
158 :
159 : // Load scheduler if specified
160 16 : if (params.isParamValid("scheduler"))
161 0 : _scheduler = _factory->get_scheduler(params.get<std::string>("scheduler"));
162 :
163 8 : if (_scheduler)
164 : {
165 0 : auto red = [](std::vector<RJType> && results) -> RJType
166 : {
167 : // Split into two separate vectors
168 0 : std::vector<neml2::ValueMap> vms;
169 0 : std::vector<neml2::DerivMap> dms;
170 0 : for (auto && [vm, dm] : results)
171 : {
172 0 : vms.push_back(std::move(vm));
173 0 : dms.push_back(std::move(dm));
174 : }
175 0 : return std::make_tuple(neml2::valuemap_cat_reduce(std::move(vms), 0),
176 0 : neml2::derivmap_cat_reduce(std::move(dms), 0));
177 0 : };
178 :
179 0 : auto post = [this](RJType && x) -> RJType
180 : {
181 0 : return std::make_tuple(neml2::valuemap_move_device(std::move(std::get<0>(x)), _device),
182 0 : neml2::derivmap_move_device(std::move(std::get<1>(x)), _device));
183 : };
184 :
185 0 : auto thread_init = [this](neml2::Device device) -> void
186 : {
187 : mooseAssert(libMesh::cast_int<unsigned int>(at::get_num_threads()) == libMesh::n_threads(),
188 : "Inconsistent number of threads");
189 : mooseAssert(libMesh::cast_int<unsigned int>(at::get_num_interop_threads()) ==
190 : libMesh::n_threads(),
191 : "Inconsistent number of interop threads");
192 0 : auto model = NEML2Utils::getModel(*_factory, _model->name());
193 0 : model->to(device);
194 0 : _model_pool[std::this_thread::get_id()] = std::move(model);
195 0 : };
196 :
197 0 : _dispatcher = std::make_unique<DispatcherType>(
198 0 : *_scheduler,
199 0 : _async_dispatch,
200 0 : [&](neml2::ValueMap && x, neml2::Device device) -> RJType
201 : {
202 0 : auto & model =
203 0 : _async_dispatch ? libmesh_map_find(_model_pool, std::this_thread::get_id()) : _model;
204 :
205 : // If this is not an async dispatch, we need to move the model to the target device
206 : // _every_ time before evaluation
207 0 : if (!_async_dispatch)
208 0 : model->to(device);
209 :
210 0 : return model->value_and_dvalue(std::move(x));
211 : },
212 : red,
213 0 : &neml2::valuemap_move_device,
214 : post,
215 0 : _async_dispatch ? thread_init : std::function<void(neml2::Device)>());
216 : }
217 8 : }
218 :
219 : template <class T>
220 : void
221 8 : NEML2ModelInterface<T>::validateModel() const
222 : {
223 : mooseAssert(_model != nullptr, "_model must be initialized");
224 8 : neml2::diagnose(*_model);
225 8 : }
226 :
227 : #endif // NEML2_ENABLED
|