Line data Source code
//* This file is part of the MOOSE framework
//* https://mooseframework.inl.gov
//*
//* All rights reserved, see COPYRIGHT for full restrictions
//* https://github.com/idaholab/moose/blob/master/COPYRIGHT
//*
//* Licensed under LGPL 2.1, please see LICENSE for details
//* https://www.gnu.org/licenses/lgpl-2.1.html

#ifdef NEML2_ENABLED

// Torch includes
#include <ATen/ops/from_blob.h>

// MOOSE includes
#include "NEML2Assembly.h"

using namespace libMesh;

registerMooseObject("MooseApp", NEML2Assembly);

/**
 * Declare the input parameters of NEML2Assembly.
 *
 * Starts from ElementUserObject::validParams(), adds the class description, and fixes
 * "execute_on" to {EXEC_INITIAL, EXEC_LINEAR}. The parameter is then suppressed, i.e. it is
 * not meant to be set from the input file.
 */
InputParameters
NEML2Assembly::validParams()
{
  InputParameters params = ElementUserObject::validParams();

  params.addClassDescription(
      "This user object gathers the JxWxT values from all elements in the assembly and "
      "provides them as a neml2 tensor. This is useful for assembling NEML2 models that "
      "require the JxWxT values for each element.");

  // The execution schedule is fixed by this object rather than chosen by the user
  ExecFlagEnum execute_options = MooseUtils::getDefaultExecFlagEnum();
  execute_options = {EXEC_INITIAL, EXEC_LINEAR};
  params.set<ExecFlagEnum>("execute_on") = execute_options;
  params.suppressParameter<ExecFlagEnum>("execute_on");

  return params;
}

/// Construct the user object; all work is deferred to the initialize/execute/finalize cycle.
NEML2Assembly::NEML2Assembly(const InputParameters & parameters) : ElementUserObject(parameters) {}

/**
 * Mark the cached assembly data as stale.
 *
 * While the data is flagged up to date, initialize(), execute(), threadJoin() and finalize()
 * all return early; after this call the next cycle gathers everything again.
 */
void
NEML2Assembly::invalidate()
{
  _up_to_date = false;
}

/**
 * Reset the element counter, the quadrature point count and the gathered JxWxT buffer.
 * Does nothing when the cached data is still up to date.
 */
void
NEML2Assembly::initialize()
{
  if (_up_to_date)
    return;

  _nelem = 0;
  _nqp = 0;
  _moose_JxWxT.clear();
}

/**
 * Merge the data gathered by another thread's copy of this object into this one.
 *
 * @param y The other thread's user object; it is cast to NEML2Assembly.
 *
 * A copy that never ran execute() still has _nelem == 0 and _nqp == 0 from initialize().
 * Such a copy is legitimate (a thread may receive no elements), so it must neither trip the
 * quadrature point consistency check nor leave the merged _nqp at zero, which would make the
 * size check in finalize() fail.
 */
void
NEML2Assembly::threadJoin(const UserObject & y)
{
  const auto & other = static_cast<const NEML2Assembly &>(y);
  mooseAssert(_up_to_date == other._up_to_date,
              "NEML2Assembly becomes out of sync with other thread");

  if (_up_to_date)
    return;

  // The other copy visited no elements: it has nothing to contribute and its _nqp is still 0
  if (other._nelem == 0)
    return;

  // This copy visited no elements yet: adopt the quadrature point count of the other copy
  if (_nqp == 0)
    _nqp = other._nqp;

  _nelem += other._nelem;
  mooseAssert(_nqp == other._nqp,
              "The number of quadrature points per element must be the same in all threads.");

  _moose_JxWxT.insert(_moose_JxWxT.end(), other._moose_JxWxT.begin(), other._moose_JxWxT.end());
}

/**
 * Gather the data of the current element.
 *
 * Increments the element counter, checks that the element has the same number of quadrature
 * points as the elements seen before (erroring out otherwise), and appends _JxW[qp] * _coord[qp]
 * for every quadrature point to the gathered buffer. Does nothing when the cached data is
 * still up to date.
 */
void
NEML2Assembly::execute()
{
  if (_up_to_date)
    return;

  _nelem++;

  // number of quadrature points (_nqp == 0 means no element has been seen yet)
  if (_nqp != 0 && std::size_t(_nqp) != _q_point.size())
    mooseError("All elements must have the same number of quadrature points per element for all "
               "elements");
  _nqp = _q_point.size();

  // JxWxT
  for (auto qp : index_range(_q_point))
    _moose_JxWxT.push_back(_JxW[qp] * _coord[qp]);
}

/**
 * Convert the gathered JxWxT values into a neml2 tensor of shape {_nelem, _nqp} and mark the
 * cached data as up to date. Errors out if the number of gathered values is not
 * _nelem * _nqp. Does nothing (apart from the timing section) when the data is already up to
 * date.
 */
void
NEML2Assembly::finalize()
{
  TIME_SECTION("finalize", 1, "Updating FEM assembly for NEML2");

  if (_up_to_date)
    return;

  // sanity checks on sizes
  if (_moose_JxWxT.size() != std::size_t(_nelem * _nqp))
    mooseError("JxWxT size mismatch, expected ", _nelem * _nqp, " but got ", _moose_JxWxT.size());

  // convert gathered data to neml2 tensors (and send to device)
  // NOTE(review): at::from_blob wraps the buffer without copying or owning it. If .to(device)
  // does not copy (presumably the case when the device is the CPU), _neml2_JxWxT aliases
  // _moose_JxWxT and must not be read while the vector is being refilled -- confirm.
  // The trailing 2 is presumably the number of batch dimensions -- confirm against neml2::Tensor.
  auto device = _app.getLibtorchDevice();
  _neml2_JxWxT =
      neml2::Tensor(at::from_blob(_moose_JxWxT.data(), {_nelem, _nqp}, torch::kFloat64), 2)
          .to(device);

  // done
  _up_to_date = true;
}

#endif