LCOV - code coverage report
Current view: top level - src/neml2/utils - NEML2Utils.C (source / functions) Hit Total Coverage
Test: idaholab/moose tensor_mechanics: d6b47a Lines: 5 5 100.0 %
Date: 2024-02-27 11:53:14 Functions: 2 2 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://www.mooseframework.org
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #include "NEML2Utils.h"
      11             : 
      12             : #ifdef NEML2_ENABLED
      13             : 
      14             : #include "VariadicTable.h"
      15             : 
      16             : namespace neml2
      17             : {
      18             : 
      19             : std::ostream &
      20             : operator<<(std::ostream & os, const Model & model)
      21             : {
      22             :   auto print_axis = [](std::ostream & os, const LabeledAxis & axis)
      23             :   {
      24             :     VariadicTable<std::string, TorchSize> table({"Variable", "Storage size"});
      25             :     for (const auto & var : axis.variable_accessors(/*recursive=*/true))
      26             :       table.addRow(utils::stringify(var), axis.storage_size(var));
      27             :     table.print(os);
      28             :   };
      29             : 
      30             :   os << "Input:" << std::endl;
      31             :   print_axis(os, model.input());
      32             : 
      33             :   os << std::endl;
      34             : 
      35             :   os << "Output:" << std::endl;
      36             :   print_axis(os, model.output());
      37             : 
      38             :   os << std::endl;
      39             : 
      40             :   os << "Parameters: " << std::endl;
      41             :   VariadicTable<std::string, std::string> table({"Parameter", "Requires grad"});
      42             :   for (auto && [name, value] : model.named_parameters(/*recursive=*/true))
      43             :     table.addRow(name, value.requires_grad() ? "True" : "False");
      44             :   table.print(os);
      45             : 
      46             :   return os;
      47             : }
      48             : } // namespace neml2
      49             : 
      50             : #endif // NEML2_ENABLED
      51             : 
      52             : namespace NEML2Utils
      53             : {
      54             : 
      55             : #ifdef NEML2_ENABLED
      56             : 
      57             : template <>
      58             : neml2::BatchTensor
      59             : toNEML2(const Real & v)
      60             : {
      61             :   return neml2::Scalar(v, neml2::default_tensor_options);
      62             : }
      63             : 
      64             : // FIXME: This is an unfortunately specialization because the models I included for testing use
      65             : // symmetric tensors everywhere. Once I tested all the models with full tensors (i.e. not in Mandel
      66             : // notation), I should be able to "fix" this specialization.
      67             : template <>
      68             : neml2::BatchTensor
      69             : toNEML2(const RankTwoTensor & r2t)
      70             : {
      71             :   return neml2::SR2::fill(r2t(0, 0), r2t(1, 1), r2t(2, 2), r2t(1, 2), r2t(0, 2), r2t(0, 1));
      72             : }
      73             : 
      74             : template <>
      75             : neml2::BatchTensor
      76             : toNEML2(const SymmetricRankTwoTensor & r2t)
      77             : {
      78             :   return neml2::SR2::fill(r2t(0, 0), r2t(1, 1), r2t(2, 2), r2t(1, 2), r2t(0, 2), r2t(0, 1));
      79             : }
      80             : 
      81             : template <>
      82             : neml2::BatchTensor
      83             : toNEML2(const std::vector<Real> & v)
      84             : {
      85             :   return neml2::BatchTensor(torch::tensor(v, neml2::default_tensor_options), 0);
      86             : }
      87             : 
      88             : template <>
      89             : SymmetricRankTwoTensor
      90             : toMOOSE(const neml2::BatchTensor & t)
      91             : {
      92             :   using symr2t = SymmetricRankTwoTensor;
      93             :   return symr2t(t.base_index({0}).item<neml2::Real>() / symr2t::mandelFactor(0),
      94             :                 t.base_index({1}).item<neml2::Real>() / symr2t::mandelFactor(1),
      95             :                 t.base_index({2}).item<neml2::Real>() / symr2t::mandelFactor(2),
      96             :                 t.base_index({3}).item<neml2::Real>() / symr2t::mandelFactor(3),
      97             :                 t.base_index({4}).item<neml2::Real>() / symr2t::mandelFactor(4),
      98             :                 t.base_index({5}).item<neml2::Real>() / symr2t::mandelFactor(5));
      99             : }
     100             : 
     101             : template <>
     102             : std::vector<Real>
     103             : toMOOSE(const neml2::BatchTensor & t)
     104             : {
     105             :   auto tc = t.contiguous();
     106             :   return std::vector<Real>(tc.data_ptr<neml2::Real>(), tc.data_ptr<neml2::Real>() + tc.numel());
     107             : }
     108             : 
     109             : template <>
     110             : SymmetricRankFourTensor
     111             : toMOOSE(const neml2::BatchTensor & t)
     112             : {
     113             :   // Well I don't see a good constructor for this, so let me fill out all the components.
     114             :   SymmetricRankFourTensor symsymr4t;
     115             :   for (const auto a : make_range(6))
     116             :     for (const auto b : make_range(6))
     117             :       symsymr4t(a, b) = t.base_index({a, b}).item<neml2::Real>();
     118             : 
     119             :   return symsymr4t;
     120             : }
     121             : 
     122             : #endif // NEML2_ENABLED
     123             : 
     124             : static const std::string message_all =
     125             :     "To use this object, you need to have the `NEML2` library installed. Refer to the "
     126             :     "documentation for guidance on how to enable it.";
     127             : #ifdef LIBTORCH_ENABLED
     128             : static const std::string message = message_all;
     129             : #else
     130             : static const std::string message =
     131             :     message_all + " To build this library MOOSE must be configured with `LIBTORCH` support!";
     132             : #endif
     133             : 
     134             : void
     135        4398 : addClassDescription(InputParameters & params, const std::string & desc)
     136             : {
     137             : #ifdef NEML2_ENABLED
     138             :   params.addClassDescription(desc);
     139             : #else
     140        8796 :   params.addClassDescription(message + " (Original description: " + desc + ")");
     141             : #endif
     142        4398 : }
     143             : 
     144             : void
     145           1 : libraryNotEnabledError(const InputParameters & params)
     146             : {
     147             : #ifndef NEML2_ENABLED
     148           1 :   mooseError(params.blockLocation() + ": " + message);
     149             : #else
     150             :   libmesh_ignore(params);
     151             :   static_assert(
     152             :       "Only place libraryNotEnabledError() in a branch that is compiled if NEML2 is not enabled!");
     153             : #endif
     154             : }
     155             : 
     156             : } // namespace NEML2Utils

Generated by: LCOV version 1.14