LCOV - code coverage report
Current view: top level - src/neml2/userobjects - NEML2ModelExecutor.C (source / functions) Hit Total Coverage
Test: idaholab/moose framework: #32971 (54bef8) with base c6cf66 Lines: 140 346 40.5 %
Date: 2026-05-29 20:35:17 Functions: 16 22 72.7 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #include "NEML2ModelExecutor.h"
      11             : #include "MOOSEToNEML2.h"
      12             : #include "NEML2Utils.h"
      13             : #include <string>
      14             : #include <sstream>
      15             : 
      16             : #ifdef NEML2_ENABLED
      17             : #include <ATen/ATen.h>
      18             : #include "libmesh/id_types.h"
      19             : #include "neml2/tensors/functions/jacrev.h"
      20             : #include "neml2/dispatchers/ValueMapLoader.h"
      21             : #include "neml2/misc/string_utils.h"
      22             : #endif
      23             : 
      24             : registerMooseObject("MooseApp", NEML2ModelExecutor);
      25             : 
      26             : InputParameters
      27        3182 : NEML2ModelExecutor::actionParams()
      28             : {
      29        3182 :   auto params = emptyInputParameters();
      30             :   // allow user to explicit skip required input variables
      31       12728 :   params.addParam<std::vector<std::string>>(
      32             :       "skip_inputs",
      33             :       {},
      34             :       "List of NEML2 variables to skip error checking when setting up the model input. If an "
      35             :       "input variable is skipped, its value will stay zero. If a required input variable is "
      36             :       "not skipped, an error will be raised.");
      37        9546 :   params.addParam<bool>(
      38             :       "keep_tensors_on_device",
      39        6364 :       false,
      40             :       "Keep state and forces on the device and advance it to old_state and old_forces without a "
      41             :       "roundtrip through MOOSE materials. This is only recommended for explicit time integration "
      42             :       "or when absolutely no restepping occurs (e.g. failed timesteps).");
      43        6364 :   params.addParam<bool>(
      44             :       "debug_inputs_on_failure",
      45        6364 :       false,
      46             :       "When a NEML2 solve fails, append a detailed dump of input tensors (defined/missing, "
      47             :       "shapes, and devices) to the error message.");
      48        3182 :   return params;
      49           0 : }
      50             : 
      51             : InputParameters
      52        3077 : NEML2ModelExecutor::validParams()
      53             : {
      54        3077 :   auto params = NEML2ModelInterface<GeneralUserObject>::validParams();
      55        3077 :   params += NEML2ModelExecutor::actionParams();
      56        6154 :   params.addClassDescription("Execute the specified NEML2 model");
      57             : 
      58       12308 :   params.addRequiredParam<UserObjectName>(
      59             :       "batch_index_generator",
      60             :       "The NEML2BatchIndexGenerator used to generate the element-to-batch-index map.");
      61       12308 :   params.addParam<std::vector<UserObjectName>>(
      62             :       "gatherers",
      63             :       {},
      64             :       "List of MOOSE*ToNEML2 user objects gathering MOOSE data as NEML2 input variables");
      65        9231 :   params.addParam<std::vector<UserObjectName>>(
      66             :       "param_gatherers",
      67             :       {},
      68             :       "List of MOOSE*ToNEML2 user objects gathering MOOSE data as NEML2 model parameters");
      69             : 
      70             :   // Since we use the NEML2 model to evaluate the residual AND the Jacobian at the same time, we
      71             :   // want to execute this user object only at execute_on = LINEAR (i.e. during residual evaluation).
      72             :   // The NONLINEAR exec flag below is for computing Jacobian during automatic scaling.
      73        3077 :   ExecFlagEnum execute_options = MooseUtils::getDefaultExecFlagEnum();
      74       15385 :   execute_options = {EXEC_INITIAL, EXEC_LINEAR, EXEC_NONLINEAR, EXEC_TIMESTEP_END};
      75        3077 :   params.set<ExecFlagEnum>("execute_on") = execute_options;
      76             : 
      77        6154 :   return params;
      78        6154 : }
      79             : 
      80           8 : NEML2ModelExecutor::NEML2ModelExecutor(const InputParameters & params)
      81           0 :   : NEML2ModelInterface<GeneralUserObject>(params)
      82             : #ifdef NEML2_ENABLED
      83             :     ,
      84           8 :     _batch_index_generator(getUserObject<NEML2BatchIndexGenerator>("batch_index_generator")),
      85          16 :     _keep_tensors_on_device(getParam<bool>("keep_tensors_on_device")),
      86          16 :     _debug_inputs_on_failure(getParam<bool>("debug_inputs_on_failure")),
      87           8 :     _output_ready(false),
      88          24 :     _error_message("")
      89             : #endif
      90             : {
      91             : #ifdef NEML2_ENABLED
      92           8 :   validateModel();
      93             : 
      94             :   // add user object dependencies by name (the UOs do not need to exist yet for this)
      95          32 :   for (const auto & gatherer_name : getParam<std::vector<UserObjectName>>("gatherers"))
      96           8 :     _depend_uo.insert(gatherer_name);
      97          24 :   for (const auto & gatherer_name : getParam<std::vector<UserObjectName>>("param_gatherers"))
      98           0 :     _depend_uo.insert(gatherer_name);
      99             : 
     100             :   // variables to skip error checking (converting vector to set to prevent duplicate checks)
     101          24 :   for (const auto & var_name : getParam<std::vector<std::string>>("skip_inputs"))
     102           0 :     _skip_vars.insert(NEML2Utils::parseVariableName(var_name));
     103             : #endif
     104           8 : }
     105             : 
     106             : #ifdef NEML2_ENABLED
     107             : void
     108           8 : NEML2ModelExecutor::initialSetup()
     109             : {
     110             :   // deal with user object provided inputs
     111          32 :   for (const auto & gatherer_name : getParam<std::vector<UserObjectName>>("gatherers"))
     112             :   {
     113             :     // gather coupled user objects late to ensure they are constructed. Do not add them as
     114             :     // dependencies (that's already done in the constructor).
     115           8 :     const auto & uo = getUserObjectByName<MOOSEToNEML2>(gatherer_name, /*is_dependency=*/false);
     116             : 
     117             :     // the target neml2 variable must exist on the input axis
     118           8 :     if (!model().input_axis().has_variable(NEML2Utils::parseVariableName(uo.NEML2Name())))
     119           0 :       mooseError("The MOOSEToNEML2 gatherer named '",
     120             :                  gatherer_name,
     121             :                  "' is gathering MOOSE data for a non-existent NEML2 input variable named '",
     122           0 :                  uo.NEML2Name(),
     123             :                  "'.");
     124             : 
     125             :     // tell the gatherer to gather for a model input variable
     126           8 :     const auto varname = NEML2Utils::parseVariableName(uo.NEML2Name());
     127           8 :     if (varname.is_old_force() || varname.is_old_state())
     128           0 :       uo.setMode(MOOSEToNEML2::Mode::OLD_VARIABLE);
     129             :     else
     130           8 :       uo.setMode(MOOSEToNEML2::Mode::VARIABLE);
     131             : 
     132           8 :     addGatheredVariable(gatherer_name, uo.NEML2VariableName());
     133           8 :     _gatherers.push_back(&uo);
     134           8 :   }
     135             : 
     136             :   // deal with user object provided model parameters
     137          24 :   for (const auto & gatherer_name : getParam<std::vector<UserObjectName>>("param_gatherers"))
     138             :   {
     139             :     // gather coupled user objects late to ensure they are constructed. Do not add them as
     140             :     // dependencies (that's already done in the constructor).
     141           0 :     const auto & uo = getUserObjectByName<MOOSEToNEML2>(gatherer_name, /*is_dependency=*/false);
     142             : 
     143             :     // introspect the NEML2 model to figure out if the gatherer UO is gathering for a NEML2 input
     144             :     // variable or for a NEML2 model parameter
     145           0 :     if (model().named_parameters().count(uo.NEML2Name()) != 1)
     146           0 :       mooseError("The MOOSEToNEML2 gatherer named '",
     147             :                  gatherer_name,
     148             :                  "' is gathering MOOSE data for a non-existent NEML2 model parameter named '",
     149           0 :                  uo.NEML2Name(),
     150             :                  "'.");
     151             : 
     152             :     // tell the gatherer to gather for a model parameter
     153           0 :     uo.setMode(MOOSEToNEML2::Mode::PARAMETER);
     154             : 
     155           0 :     addGatheredParameter(gatherer_name, uo.NEML2ParameterName());
     156           0 :     _gatherers.push_back(&uo);
     157             :   }
     158             : 
     159             :   // iterate over set of required inputs and error out if we find one that is not provided
     160           8 :   std::vector<neml2::VariableName> required_inputs = model().input_axis().variable_names();
     161          16 :   for (const auto & input : required_inputs)
     162             :   {
     163           8 :     if (_skip_vars.count(input))
     164           0 :       continue;
     165             :     // skip input state variables because they are "initial guesses" to the nonlinear system
     166          16 :     if (input.is_state() ||
     167           8 :         (_keep_tensors_on_device && (input.is_old_state() || input.is_old_force())))
     168           0 :       continue;
     169           8 :     if (!_gathered_variable_names.count(input))
     170           0 :       paramError("gatherers", "The required model input `", input, "` is not gathered");
     171             :   }
     172             : 
     173             :   // If a variable is stateful, then it'd better been retrieved by someone! In theory that's not
     174             :   // sufficient for stateful data management, but that's the best we can do here without being
     175             :   // overly restrictive.
     176          16 :   for (const auto & input : required_inputs)
     177           8 :     if (!_keep_tensors_on_device && input.is_old_state() &&
     178           8 :         !_retrieved_outputs.count(input.current()))
     179           0 :       mooseError(
     180             :           "The NEML2 model requires a stateful input variable `",
     181             :           input,
     182             :           "`, but its state counterpart on the output axis has not been retrieved by any object. "
     183             :           "Therefore, there is no way to properly propagate the corresponding stateful data in "
     184             :           "time. The common solution to this problem is to add a NEML2ToMOOSE retriever such as "
     185             :           "those called `NEML2To*MOOSEMaterialProperty`.");
     186             : 
     187             :   // check if the model has state/old_state
     188          16 :   for (const auto & [vname, var] : model().input_variables())
     189             :   {
     190           8 :     if (vname.is_state())
     191           0 :       _has_state = true;
     192           8 :     if (vname.is_old_state())
     193           0 :       _has_old_state = true;
     194             :   }
     195           8 : }
     196             : 
     197             : std::size_t
     198       40120 : NEML2ModelExecutor::getBatchIndex(dof_id_type elem_id) const
     199             : {
     200       40120 :   return _batch_index_generator.getBatchIndex(elem_id);
     201             : }
     202             : 
     203             : void
     204           8 : NEML2ModelExecutor::addGatheredVariable(const UserObjectName & gatherer_name,
     205             :                                         const neml2::VariableName & var)
     206             : {
     207           8 :   if (_gathered_variable_names.count(var))
     208           0 :     paramError("gatherers",
     209             :                "The NEML2 input variable `",
     210             :                var,
     211             :                "` gathered by UO '",
     212             :                gatherer_name,
     213             :                "' is already gathered by another gatherer.");
     214           8 :   _gathered_variable_names.insert(var);
     215           8 : }
     216             : 
     217             : void
     218           0 : NEML2ModelExecutor::addGatheredParameter(const UserObjectName & gatherer_name,
     219             :                                          const std::string & param)
     220             : {
     221           0 :   if (_gathered_parameter_names.count(param))
     222           0 :     paramError("gatherers",
     223             :                "The NEML2 model parameter `",
     224             :                param,
     225             :                "` gathered by UO '",
     226             :                gatherer_name,
     227             :                "' is already gathered by another gatherer.");
     228           0 :   _gathered_parameter_names.insert(param);
     229           0 : }
     230             : 
     231             : void
     232        2815 : NEML2ModelExecutor::initialize()
     233             : {
     234        2815 :   if (!NEML2Utils::shouldCompute(_fe_problem))
     235        1050 :     return;
     236             : 
     237        1765 :   _output_ready = false;
     238        1765 :   _error = false;
     239        1765 :   _error_message.clear();
     240             : }
     241             : 
     242             : void
     243           0 : NEML2ModelExecutor::meshChanged()
     244             : {
     245           0 :   if (!NEML2Utils::shouldCompute(_fe_problem))
     246           0 :     return;
     247             : 
     248           0 :   _output_ready = false;
     249           0 :   if (_keep_tensors_on_device)
     250           0 :     mooseError("The mesh changed while `keep_tensors_on_device = true` for NEML2 model executor '",
     251           0 :                name(),
     252             :                "'. This mode requires a fixed mesh because state history is cached on the device.");
     253             : }
     254             : 
     255             : void
     256        2815 : NEML2ModelExecutor::execute()
     257             : {
     258        2815 :   if (!NEML2Utils::shouldCompute(_fe_problem))
     259        1050 :     return;
     260             : 
     261        1765 :   if (_current_execute_flag == EXEC_TIMESTEP_END)
     262             :   {
     263         350 :     if (_keep_tensors_on_device && _fe_problem.solverSystemConverged(/*sys_num=*/0))
     264           0 :       advanceDeviceCaches();
     265         350 :     return;
     266             :   }
     267             : 
     268             :   // If the batch is empty, we do not need to do anything
     269        1415 :   if (_batch_index_generator.isEmpty())
     270           0 :     return;
     271             : 
     272        1415 :   fillInputs();
     273             : 
     274        1415 :   if (_t_step > 0)
     275             :   {
     276        1408 :     applyPredictor();
     277        1408 :     auto success = solve();
     278        1408 :     if (success)
     279        1408 :       extractOutputs();
     280             :   }
     281             : }
     282             : 
     283             : void
     284        1415 : NEML2ModelExecutor::fillInputs()
     285             : {
     286             :   try
     287             :   {
     288        2830 :     for (const auto & uo : _gatherers)
     289        1415 :       uo->insertInto(_in, _model_params);
     290             : 
     291        1415 :     if (_keep_tensors_on_device && _t_step > 0)
     292             :     {
     293           0 :       for (const auto & [name, val] : _device_state_cache)
     294           0 :         if (val.defined() && model().input_axis().has_variable(name.old()))
     295           0 :           _in[name.old()] = val;
     296           0 :       for (const auto & [name, val] : _device_forces_cache)
     297           0 :         if (val.defined() && model().input_axis().has_variable(name.old()))
     298           0 :           _in[name.old()] = val;
     299             :     }
     300             : 
     301             :     // Initialize missing inputs that are allowed to be absent
     302        1415 :     if (_keep_tensors_on_device || !_skip_vars.empty())
     303             :     {
     304           0 :       const auto options = neml2::default_tensor_options().dtype(neml2::kFloat64).device(device());
     305           0 :       const auto shape = neml2::TensorShape{neml2::Size(_batch_index_generator.getBatchIndex())};
     306             : 
     307           0 :       for (const auto & [vname, var] : model().input_variables())
     308             :       {
     309           0 :         const auto it = _in.find(vname);
     310           0 :         if (it != _in.end() && it->second.defined())
     311           0 :           continue;
     312             : 
     313           0 :         if (!_skip_vars.count(vname) && !vname.is_state() &&
     314           0 :             !(_keep_tensors_on_device && (vname.is_old_state() || vname.is_old_force())))
     315           0 :           continue;
     316             : 
     317           0 :         _in[vname] = var->zeros(options).dynamic_expand({shape});
     318             :       }
     319           0 :     }
     320             : 
     321             :     // Send input variables and parameters to device
     322        2830 :     for (auto & [var, val] : _in)
     323        1415 :       val = val.to(device());
     324        1415 :     for (auto & [param, pval] : _model_params)
     325           0 :       pval = pval.to(device());
     326             : 
     327             :     // Update model parameters
     328        1415 :     model().set_parameters(_model_params);
     329        1415 :     _model_params.clear();
     330             : 
     331             :     // Request gradient for the model parameters that we request AD for
     332        1415 :     for (const auto & [y, dy] : _retrieved_parameter_derivatives)
     333           0 :       for (const auto & [p, tensor] : dy)
     334           0 :         model().get_parameter(p).requires_grad_(true);
     335             :   }
     336           0 :   catch (std::exception & e)
     337             :   {
     338           0 :     mooseError("An error occurred while filling inputs for the NEML2 model. Error message:\n",
     339           0 :                e.what(),
     340             :                NEML2Utils::NEML2_help_message);
     341           0 :   }
     342        1415 : }
     343             : 
     344             : void
     345        1408 : NEML2ModelExecutor::applyPredictor()
     346             : {
     347             :   try
     348             :   {
     349        1408 :     if (!_has_state || !_has_old_state)
     350        1408 :       return;
     351             : 
     352             :     // Set trial state variables (i.e., initial guesses).
     353             :     // Right now we hard-code to use the old state as the trial state.
     354             :     // TODO: implement other predictors
     355           0 :     const auto & input_state = model().input_axis().subaxis(neml2::STATE);
     356           0 :     const auto & input_old_state = model().input_axis().subaxis(neml2::OLD_STATE);
     357           0 :     for (const auto & var : input_state.variable_names())
     358           0 :       if (input_old_state.has_variable(var))
     359             :       {
     360           0 :         const auto old_name = var.prepend(neml2::OLD_STATE);
     361           0 :         const auto it = _in.find(old_name);
     362           0 :         if (it != _in.end() && it->second.defined())
     363           0 :           _in[var.prepend(neml2::STATE)] = it->second;
     364           0 :       }
     365             :   }
     366           0 :   catch (std::exception & e)
     367             :   {
     368           0 :     mooseError("An error occurred while applying predictor for the NEML2 model. Error message:\n",
     369           0 :                e.what(),
     370             :                NEML2Utils::NEML2_help_message);
     371           0 :   }
     372             : }
     373             : 
     374             : void
     375           0 : NEML2ModelExecutor::expandInputs()
     376             : {
     377             :   // Figure out what our batch size is
     378           0 :   std::vector<neml2::Tensor> defined;
     379           0 :   for (const auto & [key, value] : _in)
     380           0 :     defined.push_back(value);
     381           0 :   const auto s = neml2::utils::broadcast_dynamic_sizes(defined);
     382             : 
     383             :   // Make all inputs conformal
     384           0 :   for (auto & [key, value] : _in)
     385           0 :     if (value.dynamic_sizes() != s)
     386           0 :       _in[key] = value.dynamic_unsqueeze(0).dynamic_expand(s);
     387           0 : }
     388             : 
     389             : void
     390           0 : NEML2ModelExecutor::advanceDeviceCaches()
     391             : {
     392           0 :   if (!_keep_tensors_on_device || _t_step == 0)
     393           0 :     return;
     394             : 
     395           0 :   _device_state_cache.clear();
     396           0 :   for (const auto & [name, val] : _out)
     397           0 :     if (name.is_state() && val.defined())
     398           0 :       _device_state_cache[name] = val;
     399             : 
     400           0 :   _device_forces_cache.clear();
     401           0 :   for (const auto & [name, val] : _in)
     402           0 :     if (name.is_force() && !name.is_old_force() && val.defined())
     403           0 :       _device_forces_cache[name] = val;
     404             : }
     405             : 
     406             : bool
     407        1408 : NEML2ModelExecutor::solve()
     408             : {
     409             :   try
     410             :   {
     411             :     // Evaluate the NEML2 material model
     412        7040 :     TIME_SECTION("NEML2 solve", 3, "Solving NEML2 material model");
     413             : 
     414             :     // NEML2 requires double precision
     415        1408 :     auto prev_dtype = neml2::get_default_dtype();
     416        1408 :     neml2::set_default_dtype(neml2::kFloat64);
     417             : 
     418        1408 :     if (scheduler())
     419             :     {
     420             :       // We only need consistent batch sizes if we are using the dispatcher
     421           0 :       expandInputs();
     422           0 :       neml2::ValueMapLoader loader(_in, 0);
     423           0 :       std::tie(_out, _dout_din) = dispatcher()->run(loader);
     424           0 :     }
     425             :     else
     426        1408 :       std::tie(_out, _dout_din) = model().value_and_dvalue(_in);
     427        1408 :     if (!_keep_tensors_on_device)
     428        1408 :       _in.clear();
     429             : 
     430             :     // Restore the default dtype
     431        1408 :     neml2::set_default_dtype(prev_dtype);
     432        1408 :   }
     433           0 :   catch (std::exception & e)
     434             :   {
     435           0 :     _error_message = e.what();
     436           0 :     _error = true;
     437           0 :     if (_debug_inputs_on_failure)
     438             :     {
     439           0 :       auto shape_to_string = [](const neml2::TensorShapeRef & shape) -> std::string
     440             :       {
     441           0 :         std::ostringstream os;
     442           0 :         os << "[";
     443           0 :         for (std::size_t i = 0; i < shape.size(); ++i)
     444             :         {
     445           0 :           if (i)
     446           0 :             os << ", ";
     447           0 :           os << shape[i];
     448             :         }
     449           0 :         os << "]";
     450           0 :         return os.str();
     451           0 :       };
     452             : 
     453           0 :       std::ostringstream os;
     454           0 :       os << "\nNEML2 input debug (input map + expected shapes):\n";
     455           0 :       for (const auto & var : model().input_axis().variable_names())
     456             :       {
     457           0 :         os << "  - " << neml2::utils::stringify(var) << ": ";
     458           0 :         const auto it = _in.find(var);
     459           0 :         if (it == _in.end())
     460           0 :           os << "missing\n";
     461           0 :         else if (!it->second.defined())
     462           0 :           os << "undefined\n";
     463             :         else
     464             :         {
     465           0 :           const auto & val = it->second;
     466           0 :           const auto & v = model().input_variable(var);
     467           0 :           neml2::TensorShape expected;
     468           0 :           const auto & intmd_sizes = v.intmd_sizes();
     469           0 :           expected.insert(expected.end(), intmd_sizes.begin(), intmd_sizes.end());
     470           0 :           const auto & base_sizes = v.base_sizes();
     471           0 :           expected.insert(expected.end(), base_sizes.begin(), base_sizes.end());
     472             : 
     473           0 :           os << "device=" << neml2::utils::stringify(val.device())
     474           0 :              << " dtype=" << neml2::utils::stringify(val.scalar_type())
     475           0 :              << " sizes=" << shape_to_string(val.sizes())
     476           0 :              << " batch=" << shape_to_string(val.batch_sizes().concrete())
     477           0 :              << " expected_base=" << shape_to_string(expected);
     478             : 
     479           0 :           if (val.numel() > 0)
     480             :           {
     481           0 :             auto cpu = val.detach().to(val.options().device(at::kCPU));
     482           0 :             auto flat = cpu.reshape({-1});
     483           0 :             auto min = flat.min().item<double>();
     484           0 :             auto max = flat.max().item<double>();
     485           0 :             auto mean = flat.mean().item<double>();
     486           0 :             auto has_nan = at::isnan(flat).any().item<bool>();
     487           0 :             auto has_inf = at::isinf(flat).any().item<bool>();
     488           0 :             os << " min=" << min << " max=" << max << " mean=" << mean
     489             :                << " nan=" << (has_nan ? "true" : "false")
     490           0 :                << " inf=" << (has_inf ? "true" : "false");
     491           0 :           }
     492             : 
     493           0 :           os << "\n";
     494           0 :         }
     495             :       }
     496             : 
     497           0 :       if (_keep_tensors_on_device && model().input_axis().has_subaxis(neml2::OLD_STATE) &&
     498           0 :           model().output_axis().has_subaxis(neml2::STATE))
     499             :       {
     500           0 :         os << "NEML2 cached outputs (state for old_state inputs):\n";
     501           0 :         const auto & input_old_state = model().input_axis().subaxis(neml2::OLD_STATE);
     502           0 :         for (const auto & var : input_old_state.variable_names())
     503             :         {
     504           0 :           const auto state_var = var.prepend(neml2::STATE);
     505           0 :           os << "  - " << neml2::utils::stringify(state_var) << ": ";
     506           0 :           const auto it = _out.find(state_var);
     507           0 :           if (it == _out.end())
     508           0 :             os << "missing\n";
     509           0 :           else if (!it->second.defined())
     510           0 :             os << "undefined\n";
     511             :           else
     512             :           {
     513           0 :             const auto & val = it->second;
     514           0 :             os << "device=" << neml2::utils::stringify(val.device())
     515           0 :                << " dtype=" << neml2::utils::stringify(val.scalar_type())
     516           0 :                << " sizes=" << shape_to_string(val.sizes())
     517           0 :                << " batch=" << shape_to_string(val.batch_sizes().concrete());
     518             : 
     519           0 :             if (val.numel() > 0)
     520             :             {
     521           0 :               auto cpu = val.detach().to(val.options().device(at::kCPU));
     522           0 :               auto flat = cpu.reshape({-1});
     523           0 :               auto min = flat.min().item<double>();
     524           0 :               auto max = flat.max().item<double>();
     525           0 :               auto mean = flat.mean().item<double>();
     526           0 :               auto has_nan = at::isnan(flat).any().item<bool>();
     527           0 :               auto has_inf = at::isinf(flat).any().item<bool>();
     528           0 :               os << " min=" << min << " max=" << max << " mean=" << mean
     529             :                  << " nan=" << (has_nan ? "true" : "false")
     530           0 :                  << " inf=" << (has_inf ? "true" : "false");
     531           0 :             }
     532             : 
     533           0 :             os << "\n";
     534             :           }
     535           0 :         }
     536             :       }
     537           0 :       _error_message += os.str();
     538           0 :     }
     539           0 :   }
     540             : 
     541        1408 :   return !_error;
     542             : }
     543             : 
     544             : void
     545        1408 : NEML2ModelExecutor::extractOutputs()
     546             : {
     547             :   try
     548             :   {
     549        1408 :     const auto N = _batch_index_generator.getBatchIndex();
     550             : 
     551             :     // retrieve outputs
     552        2816 :     for (auto & [y, target] : _retrieved_outputs)
     553        1408 :       target = _out[y].to(output_device());
     554             : 
     555             :     // retrieve parameter derivatives
     556        1408 :     for (auto & [y, dy] : _retrieved_parameter_derivatives)
     557           0 :       for (auto & [p, target] : dy)
     558           0 :         target = neml2::jacrev(_out[y],
     559           0 :                                model().get_parameter(p),
     560             :                                /*retain_graph=*/true,
     561             :                                /*create_graph=*/false,
     562             :                                /*allow_unused=*/false)
     563           0 :                      .to(output_device());
     564             : 
     565             :     // clear output unless we need it for on-device state advance
     566        1408 :     if (!_keep_tensors_on_device)
     567        1408 :       _out.clear();
     568             : 
     569             :     // retrieve derivatives
     570        2816 :     for (auto & [y, dy] : _retrieved_derivatives)
     571        2816 :       for (auto & [x, target] : dy)
     572             :       {
     573        1408 :         const auto & source = _dout_din[y][x];
     574        1408 :         if (source.defined())
     575        4224 :           target = source.to(output_device()).dynamic_expand({neml2::Size(N)});
     576             :       }
     577             : 
     578             :     // clear derivatives
     579        1408 :     _dout_din.clear();
     580             :   }
     581           0 :   catch (std::exception & e)
     582             :   {
     583           0 :     mooseError("An error occurred while retrieving outputs from the NEML2 model. Error message:\n",
     584           0 :                e.what(),
     585             :                NEML2Utils::NEML2_help_message);
     586           0 :   }
     587        2816 : }
     588             : 
     589             : void
     590        2815 : NEML2ModelExecutor::finalize()
     591             : {
     592        2815 :   if (!NEML2Utils::shouldCompute(_fe_problem))
     593        1050 :     return;
     594             : 
     595             :   // See if any rank failed
     596             :   processor_id_type pid;
     597        1765 :   _communicator.maxloc(_error, pid);
     598             : 
     599             :   // Fail the next nonlinear convergence check if any rank failed
     600        1765 :   if (_error)
     601             :   {
     602           0 :     _communicator.broadcast(_error_message, pid);
     603           0 :     if (_communicator.rank() == 0)
     604             :     {
     605           0 :       std::string msg = "NEML2 model execution failed on at least one processor with ID " +
     606           0 :                         std::to_string(pid) + ". Error message:\n";
     607           0 :       msg += _error_message;
     608           0 :       if (_fe_problem.isTransient())
     609             :         msg += "\nTo recover, the solution will fail and then be re-attempted with a reduced time "
     610           0 :                "step.";
     611           0 :       _console << COLOR_YELLOW << msg << COLOR_DEFAULT << std::endl;
     612           0 :     }
     613           0 :     _fe_problem.setFailNextNonlinearConvergenceCheck();
     614             :   }
     615        1765 :   else if (_t_step > 0)
     616        1758 :     _output_ready = true;
     617             : }
     618             : 
     619             : void
     620          54 : NEML2ModelExecutor::checkExecutionStage() const
     621             : {
     622          54 :   if (_fe_problem.startedInitialSetup())
     623           0 :     mooseError("NEML2 output variables and derivatives must be retrieved during object "
     624             :                "construction. This is a code problem.");
     625          54 : }
     626             : 
     627             : const neml2::Tensor &
     628          27 : NEML2ModelExecutor::getOutput(const neml2::VariableName & output_name) const
     629             : {
     630          27 :   checkExecutionStage();
     631             : 
     632          27 :   if (!model().output_axis().has_variable(output_name))
     633           0 :     mooseError("Trying to retrieve a non-existent NEML2 output variable '", output_name, "'.");
     634             : 
     635          27 :   return _retrieved_outputs[output_name];
     636             : }
     637             : 
     638             : const neml2::Tensor &
     639          27 : NEML2ModelExecutor::getOutputDerivative(const neml2::VariableName & output_name,
     640             :                                         const neml2::VariableName & input_name) const
     641             : {
     642          27 :   checkExecutionStage();
     643             : 
     644          27 :   if (!model().output_axis().has_variable(output_name))
     645           0 :     mooseError("Trying to retrieve the derivative of NEML2 output variable '",
     646             :                output_name,
     647             :                "' with respect to NEML2 input variable '",
     648             :                input_name,
     649             :                "', but the NEML2 output variable does not exist.");
     650             : 
     651          27 :   if (!model().input_axis().has_variable(input_name))
     652           0 :     mooseError("Trying to retrieve the derivative of NEML2 output variable '",
     653             :                output_name,
     654             :                "' with respect to NEML2 input variable '",
     655             :                input_name,
     656             :                "', but the NEML2 input variable does not exist.");
     657             : 
     658          27 :   return _retrieved_derivatives[output_name][input_name];
     659             : }
     660             : 
     661             : const neml2::Tensor &
     662           0 : NEML2ModelExecutor::getOutputParameterDerivative(const neml2::VariableName & output_name,
     663             :                                                  const std::string & parameter_name) const
     664             : {
     665           0 :   checkExecutionStage();
     666             : 
     667           0 :   if (!model().output_axis().has_variable(output_name))
     668           0 :     mooseError("Trying to retrieve the derivative of NEML2 output variable '",
     669             :                output_name,
     670             :                "' with respect to NEML2 model parameter '",
     671             :                parameter_name,
     672             :                "', but the NEML2 output variable does not exist.");
     673             : 
     674           0 :   if (model().named_parameters().count(parameter_name) != 1)
     675           0 :     mooseError("Trying to retrieve the derivative of NEML2 output variable '",
     676             :                output_name,
     677             :                "' with respect to NEML2 model parameter '",
     678             :                parameter_name,
     679             :                "', but the NEML2 model parameter does not exist.");
     680             : 
     681           0 :   return _retrieved_parameter_derivatives[output_name][parameter_name];
     682             : }
     683             : 
     684             : #endif

Generated by: LCOV version 1.14