Line data Source code
1 : //* This file is part of the MOOSE framework 2 : //* https://www.mooseframework.org 3 : //* 4 : //* All rights reserved, see COPYRIGHT for full restrictions 5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT 6 : //* 7 : //* Licensed under LGPL 2.1, please see LICENSE for details 8 : //* https://www.gnu.org/licenses/lgpl-2.1.html 9 : 10 : #pragma once 11 : 12 : #ifdef NEML2_ENABLED 13 : #include "neml2/models/Model.h" 14 : #include "neml2/misc/parser_utils.h" 15 : #include "Material.h" 16 : #include "UserObject.h" 17 : #endif 18 : 19 : #include "NEML2Utils.h" 20 : 21 : /** 22 : * Interface class to provide common input parameters, members, and methods for MOOSEObjects that 23 : * use NEML2 models. 24 : */ 25 : template <class T> 26 : class NEML2ModelInterface : public T 27 : { 28 : public: 29 : static InputParameters validParams(); 30 : 31 : template <typename... P> 32 : NEML2ModelInterface(const InputParameters & params, P &&... args); 33 : 34 : #ifdef NEML2_ENABLED 35 : 36 : protected: 37 : /** 38 : * Validate the NEML2 material model. This method should throw a moose error with the first 39 : * encountered problem. Note that the developer is responsible for calling this method at the 40 : * appropriate times, for example, at initialSetup(). 41 : */ 42 : virtual void validateModel() const; 43 : 44 : /// Get the NEML2 model 45 : neml2::Model & model() const { return _model; } 46 : 47 : /// Get the target compute device 48 : const torch::Device & device() const { return _device; } 49 : 50 : /** 51 : * @brief Convert a raw string to a LabeledAxisAccessor 52 : * 53 : * @param raw_str 54 : * @return neml2::LabeledAxisAccessor 55 : */ 56 : neml2::LabeledAxisAccessor getLabeledAxisAccessor(const std::string & raw_str) const; 57 : 58 : private: 59 : /// The NEML2 material model 60 : neml2::Model & _model; 61 : 62 : /// The device on which to evaluate the NEML2 model 63 : const torch::Device _device; 64 : 65 : #endif // NEML2_ENABLED 66 : }; 67 : 68 : template <class T> 69 : InputParameters 70 0 : NEML2ModelInterface<T>::validParams() 71 : { 72 0 : InputParameters params = T::validParams(); 73 0 : params.addRequiredParam<std::string>( 74 : "model", 75 : "Name of the NEML2 model, i.e., the string inside the brackets [] in the NEML2 input file " 76 : "that corresponds to the model you want to use."); 77 0 : params.addParam<std::string>( 78 : "device", 79 : "cpu", 80 : "Device on which to evaluate the NEML2 model. The string supplied must follow the following " 81 : "schema: (cpu|cuda)[:<device-index>] where cpu or cuda specifies the device type, and " 82 : ":<device-index> optionally specifies a device index. For example, device='cpu' sets the " 83 : "target compute device to be CPU, and device='cuda:1' sets the target compute device to be " 84 : "CUDA with device ID 1."); 85 0 : return params; 86 0 : } 87 : 88 : #ifndef NEML2_ENABLED 89 : 90 : template <class T> 91 : template <typename... P> 92 0 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args) 93 0 : : T(params, args...) 94 : { 95 0 : NEML2Utils::libraryNotEnabledError(params); 96 0 : } 97 : 98 : #else 99 : 100 : template <class T> 101 : template <typename... P> 102 : NEML2ModelInterface<T>::NEML2ModelInterface(const InputParameters & params, P &&... args) 103 : : T(params, args...), 104 : _model(neml2::Factory::get_object<neml2::Model>("Models", params.get<std::string>("model"))), 105 : _device(params.get<std::string>("device")) 106 : { 107 : // Send the model to the compute device 108 : _model.to(_device); 109 : } 110 : 111 : template <class T> 112 : void 113 : NEML2ModelInterface<T>::validateModel() const 114 : { 115 : // Forces and old forces on the input axis must match, i.e. all the variables on the old_forces 116 : // subaxis must also exist on the forces subaxis: 117 : if (_model.input().has_subaxis("old_forces")) 118 : for (auto var : _model.input().subaxis("old_forces").variable_accessors(/*recursive=*/true)) 119 : if (!_model.input().subaxis("forces").has_variable(var)) 120 : mooseError("The NEML2 model has old force variable ", 121 : var, 122 : " as input, but does not have the corresponding force variable as input."); 123 : 124 : // Similarly, state (on the output axis) and old state (on the input axis) must match, i.e. all 125 : // the variables on the input's old_state subaxis must also exist on the output's state subaxis: 126 : if (_model.input().has_subaxis("old_state")) 127 : for (auto var : _model.input().subaxis("old_state").variable_accessors(/*recursive=*/true)) 128 : if (!_model.output().subaxis("state").has_variable(var)) 129 : mooseError("The NEML2 model has old state variable ", 130 : var, 131 : " as input, but does not have the corresponding state variable as output."); 132 : } 133 : 134 : template <class T> 135 : neml2::LabeledAxisAccessor 136 : NEML2ModelInterface<T>::getLabeledAxisAccessor(const std::string & raw_str) const 137 : { 138 : return neml2::utils::parse<neml2::LabeledAxisAccessor>(raw_str); 139 : } 140 : 141 : #endif // NEML2_ENABLED