LCOV - code coverage report
Current view: top level - include/neml2/interfaces - NEML2ModelInterface.h (source / functions) Hit Total Coverage
Test: idaholab/moose tensor_mechanics: d6b47a Lines: 0 10 0.0 %
Date: 2024-02-27 11:53:14 Functions: 0 4 0.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : //* This file is part of the MOOSE framework
       2             : //* https://www.mooseframework.org
       3             : //*
       4             : //* All rights reserved, see COPYRIGHT for full restrictions
       5             : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
       6             : //*
       7             : //* Licensed under LGPL 2.1, please see LICENSE for details
       8             : //* https://www.gnu.org/licenses/lgpl-2.1.html
       9             : 
      10             : #pragma once
      11             : 
#ifdef NEML2_ENABLED
#include "neml2/models/Model.h"
#include "neml2/misc/parser_utils.h"
#include "Material.h"
#include "UserObject.h"
#endif

#include "NEML2Utils.h"

#include <utility>
      20             : 
/**
 * Interface class to provide common input parameters, members, and methods for MOOSEObjects that
 * use NEML2 models.
 *
 * @tparam T The MOOSE base class being extended (presumably Material or UserObject, whose headers
 *           are included above -- confirm against the instantiations). T must provide a static
 *           validParams() and a constructor whose first argument is the InputParameters.
 */
template <class T>
class NEML2ModelInterface : public T
{
public:
  /// Parameters of T extended with the required "model" and optional "device" parameters
  static InputParameters validParams();

  /**
   * Construct the interface. The input parameters and any extra arguments are passed on to the
   * constructor of T. When NEML2 is not enabled, NEML2Utils::libraryNotEnabledError() is called
   * with the parameters after T has been constructed.
   */
  template <typename... P>
  NEML2ModelInterface(const InputParameters & params, P &&... args);

#ifdef NEML2_ENABLED

protected:
  /**
   * Validate the NEML2 material model. This method should throw a moose error with the first
   * encountered problem. Note that the developer is responsible for calling this method at the
   * appropriate times, for example, at initialSetup().
   */
  virtual void validateModel() const;

  /// Get the NEML2 model (non-const access even from const methods, since _model is a reference)
  neml2::Model & model() const { return _model; }

  /// Get the target compute device
  const torch::Device & device() const { return _device; }

  /**
   * @brief Convert a raw string to a LabeledAxisAccessor
   *
   * @param raw_str The string to convert; it is handed to neml2::utils::parse, so the accepted
   *                syntax is whatever that parser accepts
   * @return neml2::LabeledAxisAccessor The accessor parsed from raw_str
   */
  neml2::LabeledAxisAccessor getLabeledAxisAccessor(const std::string & raw_str) const;

private:
  /// The NEML2 material model, looked up in the NEML2 factory's "Models" section by the name given
  /// in the "model" parameter. Declared before _device; the constructor initializes it first.
  neml2::Model & _model;

  /// The device on which to evaluate the NEML2 model, built from the "device" parameter string
  const torch::Device _device;

#endif // NEML2_ENABLED
};
      67             : 
      68             : template <class T>
      69             : InputParameters
      70           0 : NEML2ModelInterface<T>::validParams()
      71             : {
      72           0 :   InputParameters params = T::validParams();
      73           0 :   params.addRequiredParam<std::string>(
      74             :       "model",
      75             :       "Name of the NEML2 model, i.e., the string inside the brackets [] in the NEML2 input file "
      76             :       "that corresponds to the model you want to use.");
      77           0 :   params.addParam<std::string>(
      78             :       "device",
      79             :       "cpu",
      80             :       "Device on which to evaluate the NEML2 model. The string supplied must follow the following "
      81             :       "schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device type, and "
      82             :       ":<device-index> optionally specifies a device index. For example, device='cpu' sets the "
      83             :       "target compute device to be CPU, and device='cuda:1' sets the target compute device to be "
      84             :       "CUDA with device ID 1.");
      85           0 :   return params;
      86           0 : }
      87             : 
      88             : #ifndef NEML2_ENABLED
      89             : 
      90             : template <class T>
      91             : template <typename... P>
      92           0 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args)
      93           0 :   : T(params, args...)
      94             : {
      95           0 :   NEML2Utils::libraryNotEnabledError(params);
      96           0 : }
      97             : 
      98             : #else
      99             : 
     100             : template <class T>
     101             : template <typename... P>
     102             : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args)
     103             :   : T(params, args...),
     104             :     _model(neml2::Factory::get_object<neml2::Model>("Models", params.get<std::string>("model"))),
     105             :     _device(params.get<std::string>("device"))
     106             : {
     107             :   // Send the model to the compute device
     108             :   _model.to(_device);
     109             : }
     110             : 
     111             : template <class T>
     112             : void
     113             : NEML2ModelInterface<T>::validateModel() const
     114             : {
     115             :   // Forces and old forces on the input axis must match, i.e. all the variables on the old_forces
     116             :   // subaxis must also exist on the forces subaxis:
     117             :   if (_model.input().has_subaxis("old_forces"))
     118             :     for (auto var : _model.input().subaxis("old_forces").variable_accessors(/*recursive=*/true))
     119             :       if (!_model.input().subaxis("forces").has_variable(var))
     120             :         mooseError("The NEML2 model has old force variable ",
     121             :                    var,
     122             :                    " as input, but does not have the corresponding force variable as input.");
     123             : 
     124             :   // Similarly, state (on the output axis) and old state (on the input axis) must match, i.e. all
     125             :   // the variables on the input's old_state subaxis must also exist on the output's state subaxis:
     126             :   if (_model.input().has_subaxis("old_state"))
     127             :     for (auto var : _model.input().subaxis("old_state").variable_accessors(/*recursive=*/true))
     128             :       if (!_model.output().subaxis("state").has_variable(var))
     129             :         mooseError("The NEML2 model has old state variable ",
     130             :                    var,
     131             :                    " as input, but does not have the corresponding state variable as output.");
     132             : }
     133             : 
     134             : template <class T>
     135             : neml2::LabeledAxisAccessor
     136             : NEML2ModelInterface<T>::getLabeledAxisAccessor(const std::string & raw_str) const
     137             : {
     138             :   return neml2::utils::parse<neml2::LabeledAxisAccessor>(raw_str);
     139             : }
     140             : 
     141             : #endif // NEML2_ENABLED

Generated by: LCOV version 1.14