LCOV - code coverage report
Current view: top level - src/neml2 - LAROMANCE6DInterpolation.C (source / functions) Hit Total Coverage
Test: idaholab/moose solid_mechanics: #33187 (5aa0b2) with base d7c4bd Lines: 213 222 95.9 %
Date: 2026-06-30 12:24:09 Functions: 19 19 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://mooseframework.inl.gov
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #include "LAROMANCE6DInterpolation.h"
      11             : 
      12             : #ifdef NEML2_ENABLED
      13             : 
      14             : #include <fstream>
      15             : #include <initializer_list>
      16             : #include <torch/torch.h>
      17             : 
      18             : #include "neml2/tensors/functions/sign.h"
      19             : #include "neml2/tensors/functions/abs.h"
      20             : #include "neml2/tensors/functions/pow.h"
      21             : #include "neml2/tensors/functions/log10.h"
      22             : #include "neml2/tensors/functions/clamp.h"
      23             : #include "neml2/tensors/functions/stack.h"
      24             : 
      25             : namespace neml2
      26             : {
      27             : register_NEML2_object(LAROMANCE6DInterpolation);
      28             : 
      29             : OptionSet
      30         173 : LAROMANCE6DInterpolation::expected_options()
      31             : {
      32         173 :   auto options = Model::expected_options();
      33             :   options.doc() =
      34             :       "Multilinear interpolation over six dimensions (von_mises_stress, temperature, "
      35             :       "equivalent_plastic_strain, cell_dislocation_density, wall_dislocation_density, env_factor)";
      36             : 
      37             :   // Model inputs
      38         346 :   options.add_input("equivalent_plastic_strain", "The equivalent plastic strain");
      39         346 :   options.add_input("von_mises_stress", "The von Mises stress");
      40         346 :   options.add_input("cell_dislocation_density", "The cell dislocation density");
      41         346 :   options.add_input("wall_dislocation_density", "The wall dislocation density");
      42         346 :   options.add_input("temperature", "The temperature");
      43         346 :   options.add_input("env_factor", "The environment factor");
      44             : 
      45             :   // Model Outputs
      46         346 :   options.add_output("output_rate", "The output rate");
      47             : 
      48             :   // JSON
      49         346 :   options.add<std::string>("model_file_name", "The name of the model file");
      50         346 :   options.add<std::string>("model_file_variable_name",
      51             :                            "The name of the variable in the model file");
      52             : 
      53             :   // jit does not currently work with this
      54         173 :   options.set<bool>("jit", false);
      55         173 :   options.suppress("jit");
      56             : 
      57         173 :   return options;
      58           0 : }
      59             : 
      60          60 : LAROMANCE6DInterpolation::LAROMANCE6DInterpolation(const OptionSet & options)
      61             :   : Model(options),
      62          60 :     _vm_stress(declare_input_variable<Scalar>("von_mises_stress")),
      63          60 :     _temperature(declare_input_variable<Scalar>("temperature")),
      64          60 :     _ep_strain(declare_input_variable<Scalar>("equivalent_plastic_strain")),
      65          60 :     _cell_dd(declare_input_variable<Scalar>("cell_dislocation_density")),
      66          60 :     _wall_dd(declare_input_variable<Scalar>("wall_dislocation_density")),
      67          60 :     _env_fac(declare_input_variable<Scalar>("env_factor")),
      68         120 :     _output_rate(declare_output_variable<Scalar>("output_rate"))
      69             : {
      70         120 :   std::string filename = options.get<std::string>("model_file_name");
      71          60 :   std::ifstream model_file(filename.c_str());
      72          60 :   model_file >> _json;
      73             : 
      74             :   // storing grid points for indexing.
      75             :   // these should be stored differently so that they are all read in at once.  The order of this can
      76             :   // get messed up easily
      77          60 :   _stress_grid = json_vector_to_torch("in_stress");
      78          60 :   _temperature_grid = json_vector_to_torch("in_temperature");
      79          60 :   _plastic_strain_grid = json_vector_to_torch("in_plastic_strain");
      80          60 :   _cell_grid = json_vector_to_torch("in_cell");
      81          60 :   _wall_grid = json_vector_to_torch("in_wall");
      82          60 :   _env_grid = json_vector_to_torch("in_env");
      83             : 
      84             :   // Read in grid axes transform enums
      85          60 :   _stress_transform_enum = get_transform_enum(json_to_string("in_stress_transform_type"));
      86          60 :   _temperature_transform_enum = get_transform_enum(json_to_string("in_temperature_transform_type"));
      87          60 :   _plastic_strain_transform_enum =
      88          60 :       get_transform_enum(json_to_string("in_plastic_strain_transform_type"));
      89          60 :   _cell_transform_enum = get_transform_enum(json_to_string("in_cell_transform_type"));
      90          60 :   _wall_transform_enum = get_transform_enum(json_to_string("in_wall_transform_type"));
      91          60 :   _env_transform_enum = get_transform_enum(json_to_string("in_env_transform_type"));
      92             : 
      93             :   // Read in grid axes transform values
      94          60 :   _stress_transform_values = json_to_vector("in_stress_transform_values");
      95          60 :   _temperature_transform_values = json_to_vector("in_temperature_transform_values");
      96          60 :   _plastic_strain_transform_values = json_to_vector("in_plastic_strain_transform_values");
      97          60 :   _cell_transform_values = json_to_vector("in_cell_transform_values");
      98          60 :   _wall_transform_values = json_to_vector("in_wall_transform_values");
      99          60 :   _env_transform_values = json_to_vector("in_env_transform_values");
     100             : 
     101             :   // Storing values for interpolation
     102          60 :   _output_rate_name = options.get<std::string>("model_file_variable_name");
     103          60 :   _grid_values = json_6Dvector_to_torch(_output_rate_name);
     104             : 
     105             :   // set up output transforms
     106          60 :   if (_output_rate_name == "out_ep")
     107             :   {
     108          20 :     _output_transform_enum = get_transform_enum(json_to_string("out_strain_rate_transform_type"));
     109          40 :     _output_transform_values = json_to_vector("out_strain_rate_transform_values");
     110             :   }
     111          40 :   else if (_output_rate_name == "out_cell")
     112             :   {
     113          20 :     _output_transform_enum = get_transform_enum(json_to_string("out_cell_rate_transform_type"));
     114          40 :     _output_transform_values = json_to_vector("out_cell_rate_transform_values");
     115             :   }
     116          20 :   else if (_output_rate_name == "out_wall")
     117             :   {
     118          20 :     _output_transform_enum = get_transform_enum(json_to_string("out_wall_rate_transform_type"));
     119          40 :     _output_transform_values = json_to_vector("out_wall_rate_transform_values");
     120             :   }
     121             :   else
     122             :   {
     123           0 :     throw NEMLException("This ouput variable is not implemented, model_file_variable_name: " +
     124           0 :                         std::string(_output_rate_name));
     125             :   }
     126         120 : }
     127             : 
     128             : void
     129          60 : LAROMANCE6DInterpolation::request_AD()
     130             : {
     131             :   // only using first derivatives of out_ep, not out_cell and out_wall
     132          60 :   if (_output_rate_name == "out_ep")
     133             :   {
     134          20 :     std::vector<const VariableBase *> inputs = {&_vm_stress};
     135          20 :     _output_rate.request_AD(inputs);
     136          20 :   }
     137          60 : }
     138             : 
     139             : void
     140         360 : LAROMANCE6DInterpolation::set_value(bool out, bool /*dout_din*/, bool /*d2out_din2*/)
     141             : {
     142         360 :   if (out)
     143         720 :     _output_rate = interpolate_and_transform();
     144         360 : }
     145             : 
     146             : LAROMANCE6DInterpolation::TransformEnum
     147         420 : LAROMANCE6DInterpolation::get_transform_enum(const std::string & name) const
     148             : {
     149         420 :   if (name == "COMPRESS")
     150             :     return TransformEnum::COMPRESS;
     151         360 :   else if (name == "DECOMPRESS")
     152             :     return TransformEnum::DECOMPRESS;
     153         320 :   else if (name == "LOG10BOUNDED")
     154             :     return TransformEnum::LOG10BOUNDED;
     155         260 :   else if (name == "EXP10BOUNDED")
     156             :     return TransformEnum::EXP10BOUNDED;
     157         240 :   else if (name == "MINMAX")
     158             :     return TransformEnum::MINMAX;
     159             : 
     160           0 :   throw NEMLException("Unrecognized transform: " + std::string(name));
     161             : }
     162             : 
     163             : std::pair<Scalar, Scalar>
     164        2160 : LAROMANCE6DInterpolation::findLeftIndexAndFraction(const Scalar & grid,
     165             :                                                    const Scalar & interp_points) const
     166             : {
     167             :   // idx is for the left grid point.
     168             :   // searchsorted returns the right idx so -1 makes it the left
     169        4320 :   auto left_idx = Scalar(torch::searchsorted(grid, interp_points) - 1, 0);
     170             : 
     171             :   // this allows us to extrapolate
     172        4320 :   left_idx = Scalar(torch::clamp(left_idx, 0, grid.sizes()[0] - 2), 0);
     173             : 
     174        8640 :   auto left_coord = grid.dynamic_index({left_idx});
     175             :   auto right_coord =
     176       10800 :       grid.dynamic_index({left_idx + torch::tensor(1, default_integer_tensor_options())});
     177        4320 :   auto left_fraction = (right_coord - interp_points) / (right_coord - left_coord);
     178             : 
     179       10800 :   return {left_idx, neml2::dynamic_stack({left_fraction, 1 - left_fraction}, -1)};
     180       10800 : }
     181             : 
     182             : Scalar
     183         360 : LAROMANCE6DInterpolation::compute_interpolation(
     184             :     const std::vector<std::pair<Scalar, Scalar>> index_and_fraction, const Scalar grid_values) const
     185             : {
     186         360 :   Scalar result = Scalar::zeros_like(_temperature());
     187        1080 :   for (const auto i : {0, 1})
     188        2160 :     for (const auto j : {0, 1})
     189        4320 :       for (const auto k : {0, 1})
     190        8640 :         for (const auto l : {0, 1})
     191       17280 :           for (const auto m : {0, 1})
     192       34560 :             for (const auto n : {0, 1})
     193             :             {
     194             :               auto vertex_value =
     195      161280 :                   grid_values.index({(index_and_fraction[0].first +
     196       46080 :                                       torch::tensor(i, default_integer_tensor_options())),
     197       23040 :                                      (index_and_fraction[1].first +
     198       46080 :                                       torch::tensor(j, default_integer_tensor_options())),
     199       23040 :                                      (index_and_fraction[2].first +
     200       46080 :                                       torch::tensor(k, default_integer_tensor_options())),
     201       23040 :                                      (index_and_fraction[3].first +
     202       46080 :                                       torch::tensor(l, default_integer_tensor_options())),
     203       23040 :                                      (index_and_fraction[4].first +
     204       46080 :                                       torch::tensor(m, default_integer_tensor_options())),
     205       23040 :                                      (index_and_fraction[5].first +
     206      276480 :                                       torch::tensor(n, default_integer_tensor_options()))});
     207       46080 :               auto weight = index_and_fraction[0].second.select(-1, i) *
     208       46080 :                             index_and_fraction[1].second.select(-1, j) *
     209       46080 :                             index_and_fraction[2].second.select(-1, k) *
     210       46080 :                             index_and_fraction[3].second.select(-1, l) *
     211       46080 :                             index_and_fraction[4].second.select(-1, m) *
     212       46080 :                             index_and_fraction[5].second.select(-1, n);
     213       46080 :               result += vertex_value * weight;
     214             :             }
     215         360 :   return result;
     216      138240 : }
     217             : 
     218             : /// compute interpolated value
     219             : Scalar
     220         360 : LAROMANCE6DInterpolation::interpolate_and_transform() const
     221             : {
     222             :   // These transform constants should be given in the json file.
     223             :   const auto cell_dd_transformed =
     224         360 :       transform_data(_cell_dd(), _cell_transform_values, _cell_transform_enum);
     225             :   const auto wall_dd_transformed =
     226         360 :       transform_data(_wall_dd(), _wall_transform_values, _wall_transform_enum);
     227             :   const auto vm_stress_transformed =
     228         360 :       transform_data(_vm_stress(), _stress_transform_values, _stress_transform_enum);
     229             :   const auto ep_strain_transformed = transform_data(
     230         360 :       _ep_strain(), _plastic_strain_transform_values, _plastic_strain_transform_enum);
     231             :   const auto temperature_transformed =
     232         360 :       transform_data(_temperature(), _temperature_transform_values, _temperature_transform_enum);
     233             :   const auto env_fac_transformed =
     234         360 :       transform_data(_env_fac(), _env_transform_values, _env_transform_enum);
     235             : 
     236             :   std::vector<std::pair<Scalar, Scalar>> left_index_weight;
     237         360 :   left_index_weight.push_back(findLeftIndexAndFraction(_stress_grid, vm_stress_transformed));
     238         360 :   left_index_weight.push_back(findLeftIndexAndFraction(_temperature_grid, temperature_transformed));
     239             :   left_index_weight.push_back(
     240         360 :       findLeftIndexAndFraction(_plastic_strain_grid, ep_strain_transformed));
     241         360 :   left_index_weight.push_back(findLeftIndexAndFraction(_cell_grid, cell_dd_transformed));
     242         360 :   left_index_weight.push_back(findLeftIndexAndFraction(_wall_grid, wall_dd_transformed));
     243         720 :   left_index_weight.push_back(findLeftIndexAndFraction(_env_grid, env_fac_transformed));
     244         360 :   Scalar interpolated_result = compute_interpolation(left_index_weight, _grid_values);
     245             :   Scalar transformed_result =
     246         360 :       transform_data(interpolated_result, _output_transform_values, _output_transform_enum);
     247         360 :   return transformed_result;
     248         360 : }
     249             : 
     250             : Scalar
     251        2520 : LAROMANCE6DInterpolation::transform_data(const Scalar & data,
     252             :                                          const std::vector<double> & param,
     253             :                                          TransformEnum transform_type) const
     254             : {
     255        2520 :   switch (transform_type)
     256             :   {
     257         360 :     case TransformEnum::COMPRESS:
     258         360 :       return transform_compress(data, param);
     259             : 
     260         240 :     case TransformEnum::DECOMPRESS:
     261         240 :       return transform_decompress(data, param);
     262             : 
     263         360 :     case TransformEnum::LOG10BOUNDED:
     264         360 :       return transform_log10_bounded(data, param);
     265             : 
     266         120 :     case TransformEnum::EXP10BOUNDED:
     267         120 :       return transform_exp10_bounded(data, param);
     268             : 
     269        1440 :     case TransformEnum::MINMAX:
     270        1440 :       return transform_min_max(data, param);
     271             : 
     272             :     default:
     273             :       return data;
     274             :   }
     275             : }
     276             : 
     277             : Scalar
     278         360 : LAROMANCE6DInterpolation::transform_compress(const Scalar & data,
     279             :                                              const std::vector<double> & param) const
     280             : {
     281         360 :   double factor = param[0];
     282         360 :   double compressor = param[1];
     283         360 :   double original_min = param[2];
     284        1080 :   auto d1 = neml2::sign(data) * neml2::pow(neml2::abs(data * factor), compressor);
     285         720 :   auto transformed_data = neml2::log10(1.0 + d1 - original_min);
     286         360 :   return transformed_data;
     287             : }
     288             : 
     289             : Scalar
     290         240 : LAROMANCE6DInterpolation::transform_decompress(const Scalar & data,
     291             :                                                const std::vector<double> & param) const
     292             : {
     293         240 :   double factor = param[0];
     294         240 :   double compressor = param[1];
     295         240 :   double original_min = param[2];
     296         480 :   auto d1 = neml2::pow(10.0, data) - 1.0 + original_min;
     297         720 :   auto transformed_data = neml2::sign(d1) * neml2::pow(neml2::abs(d1), 1.0 / compressor) / factor;
     298         240 :   return transformed_data;
     299             : }
     300             : 
     301             : Scalar
     302         360 : LAROMANCE6DInterpolation::transform_log10_bounded(const Scalar & data,
     303             :                                                   const std::vector<double> & param) const
     304             : 
     305             : {
     306         360 :   double factor = param[0];
     307         360 :   double lowerbound = param[1];
     308         360 :   double upperbound = param[2];
     309         360 :   double logmin = param[3];
     310         360 :   double logmax = param[4];
     311         360 :   double range = upperbound - lowerbound;
     312             :   auto transformed_data =
     313        1080 :       range * (neml2::log10(data + factor) - logmin) / (logmax - logmin) + lowerbound;
     314         360 :   return transformed_data;
     315             : }
     316             : 
     317             : Scalar
     318         120 : LAROMANCE6DInterpolation::transform_exp10_bounded(const Scalar & data,
     319             :                                                   const std::vector<double> & param) const
     320             : {
     321         120 :   double factor = param[0];
     322         120 :   double lowerbound = param[1];
     323         120 :   double upperbound = param[2];
     324         120 :   double logmin = param[3];
     325         120 :   double logmax = param[4];
     326         120 :   double range = upperbound - lowerbound;
     327             :   auto transformed_data =
     328         360 :       (neml2::pow(10.0, ((data - lowerbound) * (logmax - logmin) / range) + logmin) - factor);
     329         120 :   return transformed_data;
     330             : }
     331             : 
     332             : Scalar
     333        1440 : LAROMANCE6DInterpolation::transform_min_max(const Scalar & data,
     334             :                                             const std::vector<double> & param) const
     335             : {
     336        1440 :   double data_min = param[0];
     337        1440 :   double data_max = param[1];
     338        1440 :   double scaled_min = param[2];
     339        1440 :   double scaled_max = param[3];
     340             :   auto transformed_data =
     341        2880 :       ((data - data_min) / (data_max - data_min)) * (scaled_max - scaled_min) + scaled_min;
     342        1440 :   return transformed_data;
     343             : }
     344             : 
     345             : std::string
     346         420 : LAROMANCE6DInterpolation::json_to_string(const std::string & key) const
     347             : {
     348         420 :   if (!_json.contains(key))
     349           0 :     throw NEMLException("The key '" + std::string(key) + "' is missing from the JSON data file.");
     350             : 
     351         420 :   std::string name = _json[key].get<std::string>();
     352         420 :   return name;
     353             : }
     354             : 
     355             : std::vector<double>
     356         420 : LAROMANCE6DInterpolation::json_to_vector(const std::string & key) const
     357             : {
     358         420 :   if (!_json.contains(key))
     359           0 :     throw NEMLException("The key '" + std::string(key) + "' is missing from the JSON data file.");
     360             : 
     361         420 :   std::vector<double> data_vec = _json[key].get<std::vector<double>>();
     362         420 :   return data_vec;
     363             : }
     364             : 
     365             : Scalar
     366         360 : LAROMANCE6DInterpolation::json_vector_to_torch(const std::string & key) const
     367             : {
     368         360 :   if (!_json.contains(key))
     369           0 :     throw NEMLException("The key '" + std::string(key) + "' is missing from the JSON data file.");
     370             : 
     371         360 :   std::vector<double> in_data = _json[key].get<std::vector<double>>();
     372         720 :   return Scalar::create(in_data).clone();
     373         360 : }
     374             : 
     375             : Scalar
     376          60 : LAROMANCE6DInterpolation::json_6Dvector_to_torch(const std::string & key) const
     377             : {
     378             :   using std::vector;
     379          60 :   if (!_json.contains(key))
     380           0 :     throw NEMLException("The key '" + std::string(key) + "' is missing from the JSON data file.");
     381             : 
     382             :   vector<vector<vector<vector<vector<vector<double>>>>>> out_data =
     383          60 :       _json[key].get<vector<vector<vector<vector<vector<vector<double>>>>>>>();
     384             : 
     385             :   const int64_t sz_l0 = out_data.size();
     386             :   const int64_t sz_l1 = out_data[0].size();
     387             :   const int64_t sz_l2 = out_data[0][0].size();
     388             :   const int64_t sz_l3 = out_data[0][0][0].size();
     389             :   const int64_t sz_l4 = out_data[0][0][0][0].size();
     390             :   const int64_t sz_l5 = out_data[0][0][0][0][0].size();
     391             : 
     392             :   auto check_level_size =
     393      786720 :       [](const int64_t current_vec_size, const int64_t sz_level, const std::string & key)
     394             :   {
     395      786720 :     if (current_vec_size != sz_level)
     396           0 :       throw NEMLException("Incorrect JSON interpolation grid size for '" + key + "'.");
     397      786720 :   };
     398             : 
     399             :   std::vector<double> linearize_values;
     400          60 :   check_level_size(out_data.size(), sz_l0, key);
     401         480 :   for (auto && level1 : out_data)
     402             :   {
     403         420 :     check_level_size(level1.size(), sz_l1, key);
     404        7140 :     for (auto && level2 : level1)
     405             :     {
     406        6720 :       check_level_size(level2.size(), sz_l2, key);
     407       33600 :       for (auto && level3 : level2)
     408             :       {
     409       26880 :         check_level_size(level3.size(), sz_l3, key);
     410      134400 :         for (auto && level4 : level3)
     411             :         {
     412      107520 :           check_level_size(level4.size(), sz_l4, key);
     413      752640 :           for (auto && level5 : level4)
     414             :           {
     415      645120 :             check_level_size(level5.size(), sz_l5, key);
     416     3225600 :             for (auto && value : level5)
     417     2580480 :               linearize_values.push_back(value);
     418             :           }
     419             :         }
     420             :       }
     421             :     }
     422             :   }
     423             : 
     424         120 :   return Scalar::create(linearize_values)
     425         480 :       .dynamic_reshape({sz_l0, sz_l1, sz_l2, sz_l3, sz_l4, sz_l5})
     426         120 :       .clone();
     427          60 : }
     428             : 
     429             : } // namespace neml2
     430             : 
     431             : #endif

Generated by: LCOV version 1.14