Line data Source code
1 : //* This file is part of the MOOSE framework
2 : //* https://mooseframework.inl.gov
3 : //*
4 : //* All rights reserved, see COPYRIGHT for full restrictions
5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 : //*
7 : //* Licensed under LGPL 2.1, please see LICENSE for details
8 : //* https://www.gnu.org/licenses/lgpl-2.1.html
9 :
10 : #include "NEML2ModelExecutor.h"
11 : #include "MOOSEToNEML2.h"
12 : #include "NEML2Utils.h"
13 : #include <string>
14 : #include <sstream>
15 :
16 : #ifdef NEML2_ENABLED
17 : #include <ATen/ATen.h>
18 : #include "libmesh/id_types.h"
19 : #include "neml2/tensors/functions/jacrev.h"
20 : #include "neml2/dispatchers/ValueMapLoader.h"
21 : #include "neml2/misc/string_utils.h"
22 : #include "neml2/base/Settings.h"
23 : #endif
24 :
25 : registerMooseObject("MooseApp", NEML2ModelExecutor);
26 :
27 : InputParameters
28 3244 : NEML2ModelExecutor::actionParams()
29 : {
30 3244 : auto params = emptyInputParameters();
31 9732 : params.addParam<bool>(
32 : "manage_state_advance",
33 6488 : false,
34 : "Keep state and forces on the device and advance it to old_state and old_forces without a "
35 : "roundtrip through MOOSE materials. This is only recommended for explicit time integration "
36 : "or when absolutely no restepping occurs (e.g. failed timesteps).");
37 6488 : params.addParam<bool>(
38 : "debug_inputs_on_failure",
39 6488 : false,
40 : "When a NEML2 solve fails, append a detailed dump of input tensors (defined/missing, "
41 : "shapes, and devices) to the error message.");
42 3244 : return params;
43 0 : }
44 :
45 : InputParameters
46 3109 : NEML2ModelExecutor::validParams()
47 : {
48 3109 : auto params = NEML2ModelInterface<GeneralUserObject>::validParams();
49 3109 : params += NEML2ModelExecutor::actionParams();
50 6218 : params.addClassDescription("Execute the specified NEML2 model");
51 :
52 12436 : params.addRequiredParam<UserObjectName>(
53 : "batch_index_generator",
54 : "The NEML2BatchIndexGenerator used to generate the element-to-batch-index map.");
55 12436 : params.addParam<std::vector<UserObjectName>>(
56 : "gatherers",
57 : {},
58 : "List of MOOSE*ToNEML2 user objects gathering MOOSE data as NEML2 input variables");
59 9327 : params.addParam<std::vector<UserObjectName>>(
60 : "param_gatherers",
61 : {},
62 : "List of MOOSE*ToNEML2 user objects gathering MOOSE data as NEML2 model parameters");
63 :
64 : // Since we use the NEML2 model to evaluate the residual AND the Jacobian at the same time, we
65 : // want to execute this user object only at execute_on = LINEAR (i.e. during residual evaluation).
66 : // The NONLINEAR exec flag below is for computing Jacobian during automatic scaling.
67 3109 : ExecFlagEnum execute_options = MooseUtils::getDefaultExecFlagEnum();
68 15545 : execute_options = {EXEC_INITIAL, EXEC_LINEAR, EXEC_NONLINEAR, EXEC_TIMESTEP_END};
69 3109 : params.set<ExecFlagEnum>("execute_on") = execute_options;
70 :
71 6218 : return params;
72 6218 : }
73 :
74 24 : NEML2ModelExecutor::NEML2ModelExecutor(const InputParameters & params)
75 0 : : NEML2ModelInterface<GeneralUserObject>(params)
76 : #ifdef NEML2_ENABLED
77 : ,
78 24 : _batch_index_generator(getUserObject<NEML2BatchIndexGenerator>("batch_index_generator")),
79 48 : _manage_state_advance(getParam<bool>("manage_state_advance")),
80 48 : _debug_inputs_on_failure(getParam<bool>("debug_inputs_on_failure")),
81 24 : _output_ready(false),
82 72 : _error_message("")
83 : #endif
84 : {
85 : #ifdef NEML2_ENABLED
86 24 : validateModel();
87 :
88 : // add user object dependencies by name (the UOs do not need to exist yet for this)
89 111 : for (const auto & gatherer_name : getParam<std::vector<UserObjectName>>("gatherers"))
90 39 : _depend_uo.insert(gatherer_name);
91 74 : for (const auto & gatherer_name : getParam<std::vector<UserObjectName>>("param_gatherers"))
92 2 : _depend_uo.insert(gatherer_name);
93 : #endif
94 24 : }
95 :
96 : #ifdef NEML2_ENABLED
97 : void
98 24 : NEML2ModelExecutor::initialSetup()
99 : {
100 : // deal with user object provided inputs
101 111 : for (const auto & gatherer_name : getParam<std::vector<UserObjectName>>("gatherers"))
102 : {
103 : // gather coupled user objects late to ensure they are constructed. Do not add them as
104 : // dependencies (that's already done in the constructor).
105 39 : const auto & uo = getUserObjectByName<MOOSEToNEML2>(gatherer_name, /*is_dependency=*/false);
106 :
107 : // there's no need to gather old/older variables if we're managing state advance
108 39 : auto sep = model().settings().history_separator();
109 39 : auto [base_name, history_order] = neml2::parse_history(uo.NEML2Name(), sep);
110 39 : if (_manage_state_advance && history_order > 0)
111 0 : paramError("gatherers",
112 : "The gatherer for history variable `",
113 0 : uo.NEML2Name(),
114 : "` is not needed when `manage_state_advance = true`.");
115 :
116 39 : addGatheredVariable(gatherer_name, uo.NEML2Name());
117 39 : _gatherers.push_back(&uo);
118 39 : }
119 :
120 : // deal with user object provided model parameters
121 74 : for (const auto & gatherer_name : getParam<std::vector<UserObjectName>>("param_gatherers"))
122 : {
123 : // gather coupled user objects late to ensure they are constructed. Do not add them as
124 : // dependencies (that's already done in the constructor).
125 2 : const auto & uo = getUserObjectByName<MOOSEToNEML2>(gatherer_name, /*is_dependency=*/false);
126 2 : addGatheredParameter(gatherer_name, uo.NEML2Name());
127 2 : _param_gatherers.push_back(&uo);
128 : }
129 :
130 : // iterate over set of required inputs and error out if we find one that is not provided
131 63 : for (const auto & [iname, ivar] : model().input_variables())
132 : {
133 : // if tensors are kept on device, we are not going to gather old values from moose
134 39 : if (_manage_state_advance && ivar->history_order() > 0)
135 0 : continue;
136 39 : if (!_gathered_variable_names.count(iname))
137 0 : paramError("gatherers", "The required model input `", iname, "` is not gathered");
138 : }
139 :
140 : // keep track of stateful variables if manage_state_advance is true
141 24 : if (_manage_state_advance)
142 0 : for (const auto & [iname, ivar] : model().input_variables())
143 0 : if (ivar->history_order() > 0)
144 0 : _state_vars[iname] = neml2::Tensor();
145 24 : }
146 :
147 : std::size_t
148 69290 : NEML2ModelExecutor::getBatchIndex(dof_id_type elem_id) const
149 : {
150 69290 : return _batch_index_generator.getBatchIndex(elem_id);
151 : }
152 :
153 : void
154 39 : NEML2ModelExecutor::addGatheredVariable(const UserObjectName & gatherer_name,
155 : const neml2::VariableName & var)
156 : {
157 39 : if (_gathered_variable_names.count(var))
158 0 : paramError("gatherers",
159 : "The NEML2 input variable `",
160 : var,
161 : "` gathered by UO '",
162 : gatherer_name,
163 : "' is already gathered by another gatherer.");
164 39 : _gathered_variable_names.insert(var);
165 39 : }
166 :
167 : void
168 2 : NEML2ModelExecutor::addGatheredParameter(const UserObjectName & gatherer_name,
169 : const std::string & param)
170 : {
171 2 : if (_gathered_parameter_names.count(param))
172 0 : paramError("gatherers",
173 : "The NEML2 model parameter `",
174 : param,
175 : "` gathered by UO '",
176 : gatherer_name,
177 : "' is already gathered by another gatherer.");
178 2 : _gathered_parameter_names.insert(param);
179 2 : }
180 :
181 : void
182 3369 : NEML2ModelExecutor::initialize()
183 : {
184 3369 : if (!NEML2Utils::shouldCompute(_fe_problem))
185 1214 : return;
186 :
187 2155 : _output_ready = false;
188 2155 : _error = false;
189 2155 : _error_message.clear();
190 : }
191 :
192 : void
193 8 : NEML2ModelExecutor::meshChanged()
194 : {
195 8 : if (!NEML2Utils::shouldCompute(_fe_problem))
196 0 : return;
197 :
198 8 : _output_ready = false;
199 8 : if (_manage_state_advance)
200 0 : mooseError("The mesh changed while `manage_state_advance = true` for NEML2 model executor '",
201 0 : name(),
202 : "'. This mode requires a fixed mesh because state history is cached on the device.");
203 : }
204 :
205 : void
206 3369 : NEML2ModelExecutor::execute()
207 : {
208 3369 : if (!NEML2Utils::shouldCompute(_fe_problem))
209 1214 : return;
210 :
211 2155 : if (_current_execute_flag == EXEC_TIMESTEP_END)
212 : {
213 434 : if (_manage_state_advance && _fe_problem.solverSystemConverged(/*sys_num=*/0))
214 0 : advanceState();
215 434 : return;
216 : }
217 :
218 : // If the batch is empty, we do not need to do anything
219 1721 : if (_batch_index_generator.isEmpty())
220 5 : return;
221 :
222 1716 : fillInputs();
223 :
224 1716 : if (_t_step > 0)
225 : {
226 1694 : auto success = solve();
227 1694 : if (success)
228 1653 : extractOutputs();
229 : }
230 : }
231 :
232 : void
233 1716 : NEML2ModelExecutor::fillInputs()
234 : {
235 : try
236 : {
237 3531 : for (const auto & uo : _gatherers)
238 1815 : uo->insertInto(_in);
239 1722 : for (const auto & uo : _param_gatherers)
240 6 : uo->insertInto(_model_params);
241 :
242 1716 : if (_manage_state_advance && _t_step > 0)
243 0 : for (const auto & [name, val] : _state_vars)
244 0 : if (val.defined())
245 0 : _in[name] = val;
246 :
247 : // Send input variables and parameters to device
248 3531 : for (auto & [var, val] : _in)
249 1815 : val = val.to(device());
250 1722 : for (auto & [param, pval] : _model_params)
251 6 : pval = pval.to(device());
252 :
253 : // Update model parameters
254 1716 : model().set_parameters(_model_params);
255 1716 : _model_params.clear();
256 :
257 : // Request gradient for the model parameters that we request AD for
258 1716 : for (const auto & [y, dy] : _retrieved_parameter_derivatives)
259 0 : for (const auto & [p, tensor] : dy)
260 0 : model().get_parameter(p).requires_grad_(true);
261 : }
262 0 : catch (std::exception & e)
263 : {
264 0 : mooseError("An error occurred while filling inputs for the NEML2 model. Error message:\n",
265 0 : e.what(),
266 : NEML2Utils::NEML2_help_message);
267 0 : }
268 1716 : }
269 :
270 : void
271 4 : NEML2ModelExecutor::expandInputs()
272 : {
273 : // Figure out what our batch size is
274 4 : std::vector<neml2::Tensor> defined;
275 12 : for (const auto & [key, value] : _in)
276 8 : defined.push_back(value);
277 4 : const auto s = neml2::utils::broadcast_dynamic_sizes(defined);
278 :
279 : // Make all inputs conformal
280 12 : for (auto & [key, value] : _in)
281 8 : if (value.dynamic_sizes() != s)
282 0 : _in[key] = value.dynamic_unsqueeze(0).dynamic_expand(s);
283 4 : }
284 :
285 : void
286 0 : NEML2ModelExecutor::advanceState()
287 : {
288 0 : if (!_manage_state_advance || _t_step == 0)
289 0 : return;
290 :
291 0 : for (const auto & [name, val] : _state_vars)
292 : {
293 0 : auto sep = model().settings().history_separator();
294 0 : auto [base_name, order] = neml2::parse_history(name, sep);
295 : mooseAssert(order > 0, "Invalid history order");
296 : // cache value from the current step
297 : // favor output over input
298 0 : auto curr_name = order == 1 ? base_name : base_name + sep + std::to_string(order - 1);
299 0 : if (_out.count(curr_name))
300 0 : _state_vars[name] = _out.at(curr_name);
301 0 : else if (_in.count(curr_name))
302 0 : _state_vars[name] = _in.at(curr_name);
303 : else
304 0 : mooseError("Failed to find cached value for history variable: ", name);
305 0 : }
306 : }
307 :
308 : bool
309 1694 : NEML2ModelExecutor::solve()
310 : {
311 : try
312 : {
313 : // Evaluate the NEML2 material model
314 8470 : TIME_SECTION("NEML2 solve", 3, "Solving NEML2 material model");
315 :
316 : // NEML2 requires double precision
317 1694 : auto prev_dtype = neml2::get_default_dtype();
318 1694 : neml2::set_default_dtype(neml2::kFloat64);
319 :
320 1694 : if (scheduler())
321 : {
322 : // We only need consistent batch sizes if we are using the dispatcher
323 4 : expandInputs();
324 4 : neml2::ValueMapLoader loader(_in, 0);
325 4 : std::tie(_out, _dout_din) = dispatcher()->run(loader);
326 4 : }
327 : else
328 1690 : std::tie(_out, _dout_din) = model().value_and_dvalue(_in);
329 1653 : if (!_manage_state_advance)
330 1653 : _in.clear();
331 :
332 : // Restore the default dtype
333 1653 : neml2::set_default_dtype(prev_dtype);
334 1694 : }
335 41 : catch (std::exception & e)
336 : {
337 41 : _error_message = e.what();
338 41 : _error = true;
339 41 : if (_debug_inputs_on_failure)
340 : {
341 0 : auto shape_to_string = [](const neml2::TensorShapeRef & shape) -> std::string
342 : {
343 0 : std::ostringstream os;
344 0 : os << "(";
345 0 : for (std::size_t i = 0; i < shape.size(); ++i)
346 : {
347 0 : if (i)
348 0 : os << ", ";
349 0 : os << shape[i];
350 : }
351 0 : os << ")";
352 0 : return os.str();
353 0 : };
354 :
355 0 : std::ostringstream os;
356 0 : os << "\nNEML2 input variables:\n";
357 0 : for (const auto & [var, val] : model().input_variables())
358 : {
359 0 : os << " - " << var << ": ";
360 0 : const auto it = _in.find(var);
361 0 : if (it == _in.end())
362 0 : os << "missing\n";
363 0 : else if (!it->second.defined())
364 0 : os << "undefined\n";
365 : else
366 : {
367 0 : const auto & val = it->second;
368 0 : const auto & v = model().input_variable(var);
369 0 : neml2::TensorShape expected;
370 0 : const auto & intmd_sizes = v.intmd_sizes();
371 0 : expected.insert(expected.end(), intmd_sizes.begin(), intmd_sizes.end());
372 0 : const auto & base_sizes = v.base_sizes();
373 0 : expected.insert(expected.end(), base_sizes.begin(), base_sizes.end());
374 :
375 0 : os << "device=" << val.device() << " dtype=" << val.scalar_type()
376 0 : << " sizes=" << shape_to_string(val.sizes())
377 0 : << " batch=" << shape_to_string(val.batch_sizes().concrete())
378 0 : << " expected_base=" << shape_to_string(expected);
379 :
380 0 : if (val.numel() > 0)
381 : {
382 0 : auto cpu = val.detach().to(val.options().device(at::kCPU));
383 0 : auto flat = cpu.reshape({-1});
384 0 : auto min = flat.min().item<double>();
385 0 : auto max = flat.max().item<double>();
386 0 : auto mean = flat.mean().item<double>();
387 0 : auto has_nan = at::isnan(flat).any().item<bool>();
388 0 : auto has_inf = at::isinf(flat).any().item<bool>();
389 0 : os << " min=" << min << " max=" << max << " mean=" << mean
390 : << " nan=" << (has_nan ? "true" : "false")
391 0 : << " inf=" << (has_inf ? "true" : "false");
392 0 : }
393 :
394 0 : os << "\n";
395 0 : }
396 : }
397 :
398 0 : if (_manage_state_advance)
399 : {
400 0 : os << "NEML2 stateful variables:\n";
401 0 : for (const auto & [var, cached_val] : _state_vars)
402 : {
403 0 : os << " - " << var << ": ";
404 0 : const auto it_out = _out.find(var);
405 0 : const auto it_in = _in.find(var);
406 0 : if (it_out == _out.end() || it_in == _in.end())
407 0 : os << "missing\n";
408 : else
409 : {
410 0 : const auto it = it_out != _out.end() ? it_out : it_in;
411 0 : const auto & val = it->second;
412 0 : os << "device=" << val.device() << " dtype=" << val.scalar_type()
413 0 : << " sizes=" << shape_to_string(val.sizes())
414 0 : << " batch=" << shape_to_string(val.batch_sizes().concrete());
415 :
416 0 : if (val.numel() > 0)
417 : {
418 0 : auto cpu = val.detach().to(val.options().device(at::kCPU));
419 0 : auto flat = cpu.reshape({-1});
420 0 : auto min = flat.min().item<double>();
421 0 : auto max = flat.max().item<double>();
422 0 : auto mean = flat.mean().item<double>();
423 0 : auto has_nan = at::isnan(flat).any().item<bool>();
424 0 : auto has_inf = at::isinf(flat).any().item<bool>();
425 0 : os << " min=" << min << " max=" << max << " mean=" << mean
426 : << " nan=" << (has_nan ? "true" : "false")
427 0 : << " inf=" << (has_inf ? "true" : "false");
428 0 : }
429 :
430 0 : os << "\n";
431 : }
432 : }
433 : }
434 0 : _error_message += os.str();
435 0 : }
436 41 : }
437 :
438 1694 : return !_error;
439 : }
440 :
441 : void
442 1653 : NEML2ModelExecutor::extractOutputs()
443 : {
444 : try
445 : {
446 1653 : const auto N = _batch_index_generator.getBatchIndex();
447 :
448 : // retrieve outputs
449 3350 : for (auto & [y, target] : _retrieved_outputs)
450 1697 : target = _out[y].to(output_device());
451 :
452 : // retrieve parameter derivatives
453 1653 : for (auto & [y, dy] : _retrieved_parameter_derivatives)
454 0 : for (auto & [p, target] : dy)
455 0 : target = neml2::jacrev(_out[y],
456 0 : model().get_parameter(p),
457 : /*retain_graph=*/true,
458 : /*create_graph=*/false,
459 : /*allow_unused=*/false)
460 0 : .to(output_device());
461 :
462 : // clear output unless we need it for on-device state advance
463 1653 : if (!_manage_state_advance)
464 1653 : _out.clear();
465 :
466 : // retrieve derivatives
467 3306 : for (auto & [y, dy] : _retrieved_derivatives)
468 3306 : for (auto & [x, target] : dy)
469 : {
470 1653 : const auto & source = _dout_din[y][x];
471 1653 : if (source.defined())
472 4959 : target = source.to(output_device()).dynamic_expand({neml2::Size(N)});
473 : }
474 :
475 : // clear derivatives
476 1653 : _dout_din.clear();
477 : }
478 0 : catch (std::exception & e)
479 : {
480 0 : mooseError("An error occurred while retrieving outputs from the NEML2 model. Error message:\n",
481 0 : e.what(),
482 : NEML2Utils::NEML2_help_message);
483 0 : }
484 3306 : }
485 :
486 : void
487 3369 : NEML2ModelExecutor::finalize()
488 : {
489 3369 : if (!NEML2Utils::shouldCompute(_fe_problem))
490 1214 : return;
491 :
492 : // See if any rank failed
493 : processor_id_type pid;
494 2155 : _communicator.maxloc(_error, pid);
495 :
496 : // Fail the next nonlinear convergence check if any rank failed
497 2155 : if (_error)
498 : {
499 41 : _communicator.broadcast(_error_message, pid);
500 41 : if (_communicator.rank() == 0)
501 : {
502 82 : std::string msg = "NEML2 model execution failed on at least one processor with ID " +
503 123 : std::to_string(pid) + ". Error message:\n";
504 41 : msg += _error_message;
505 41 : if (_fe_problem.isTransient())
506 : msg += "\nTo recover, the solution will fail and then be re-attempted with a reduced time "
507 41 : "step.";
508 41 : _console << COLOR_YELLOW << msg << COLOR_DEFAULT << std::endl;
509 41 : }
510 41 : _fe_problem.setFailNextNonlinearConvergenceCheck();
511 : }
512 2114 : else if (_t_step > 0)
513 2091 : _output_ready = true;
514 : }
515 :
516 : void
517 195 : NEML2ModelExecutor::checkExecutionStage() const
518 : {
519 195 : if (_fe_problem.startedInitialSetup())
520 0 : mooseError("NEML2 output variables and derivatives must be retrieved during object "
521 : "construction. This is a code problem.");
522 195 : }
523 :
524 : const neml2::Tensor &
525 120 : NEML2ModelExecutor::getOutput(const neml2::VariableName & output_name) const
526 : {
527 120 : checkExecutionStage();
528 :
529 120 : if (!model().output_variables().count(output_name))
530 0 : mooseError("Trying to retrieve a non-existent NEML2 output variable '", output_name, "'.");
531 :
532 120 : return _retrieved_outputs[output_name];
533 : }
534 :
535 : const neml2::Tensor &
536 75 : NEML2ModelExecutor::getOutputDerivative(const neml2::VariableName & output_name,
537 : const neml2::VariableName & input_name) const
538 : {
539 75 : checkExecutionStage();
540 :
541 75 : if (!model().output_variables().count(output_name))
542 0 : mooseError("Trying to retrieve the derivative of NEML2 output variable '",
543 : output_name,
544 : "' with respect to NEML2 input variable '",
545 : input_name,
546 : "', but the NEML2 output variable does not exist.");
547 :
548 75 : if (!model().input_variables().count(input_name))
549 0 : mooseError("Trying to retrieve the derivative of NEML2 output variable '",
550 : output_name,
551 : "' with respect to NEML2 input variable '",
552 : input_name,
553 : "', but the NEML2 input variable does not exist.");
554 :
555 75 : return _retrieved_derivatives[output_name][input_name];
556 : }
557 :
558 : const neml2::Tensor &
559 0 : NEML2ModelExecutor::getOutputParameterDerivative(const neml2::VariableName & output_name,
560 : const std::string & parameter_name) const
561 : {
562 0 : checkExecutionStage();
563 :
564 0 : if (!model().output_variables().count(output_name))
565 0 : mooseError("Trying to retrieve the derivative of NEML2 output variable '",
566 : output_name,
567 : "' with respect to NEML2 model parameter '",
568 : parameter_name,
569 : "', but the NEML2 output variable does not exist.");
570 :
571 0 : if (model().named_parameters().count(parameter_name) != 1)
572 0 : mooseError("Trying to retrieve the derivative of NEML2 output variable '",
573 : output_name,
574 : "' with respect to NEML2 model parameter '",
575 : parameter_name,
576 : "', but the NEML2 model parameter does not exist.");
577 :
578 0 : return _retrieved_parameter_derivatives[output_name][parameter_name];
579 : }
580 :
581 : #endif
|