LCOV - code coverage report
Current view: top level - src/tensor_computes - ParsedCompute.C (source / functions) Hit Total Coverage
Test: idaholab/swift: #92 (25e020) with base b3cd84 Lines: 88 103 85.4 %
Date: 2025-09-10 17:10:32 Functions: 7 7 100.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /**********************************************************************/
       2             : /*                    DO NOT MODIFY THIS HEADER                       */
       3             : /*             Swift, a Fourier spectral solver for MOOSE             */
       4             : /*                                                                    */
       5             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       6             : /*                        ALL RIGHTS RESERVED                         */
       7             : /**********************************************************************/
       8             : 
       9             : #include "ParsedCompute.h"
      10             : 
      11             : #include "MooseUtils.h"
      12             : #include "SwiftUtils.h"
      13             : #include "MultiMooseEnum.h"
      14             : #include "DomainAction.h"
      15             : 
      16             : registerMooseObject("SwiftApp", ParsedCompute);
      17             : 
      18             : InputParameters // Declares every input-file parameter understood by ParsedCompute.
      19         628 : ParsedCompute::validParams()
      20             : {
      21         628 :   InputParameters params = TensorOperator<>::validParams(); // start from the base operator's parameters
      22         628 :   params.addClassDescription("ParsedCompute object.");
      23        1256 :   params.addRequiredParam<std::string>("expression", "Parsed expression"); // the only required parameter
      24        1256 :   params.addParam<std::vector<TensorInputBufferName>>(
      25             :       "inputs", {}, "Buffer names used in the expression"); // defaults to no inputs
      26        1256 :   params.addParam<std::vector<TensorInputBufferName>>(
      27             :       "derivatives", {}, "List of inputs to take the derivative w.r.t. (or none)"); // defaults to none
      28        1256 :   params.addParam<bool>(
      29        1256 :       "enable_jit", true, "Use operator fusion and just in time compilation (recommended on GPU)");
      30        1256 :   params.addParam<bool>("enable_fpoptimizer", true, "Use algebraic optimizer");
      31        1256 :   params.addParam<bool>("extra_symbols", // off by default; the constructor refuses these names for inputs/constants when on
      32        1256 :                         false,
      33             :                         "Provide i (imaginary unit), kx,ky,kz (reciprocal space frequency), k2 "
      34             :                         "(square of the k-vector), x,y,z "
      35             :                         "(real space coordinates), time t, pi, and e.");
      36             :   // Constants and their values
      37         628 :   params.addParam<std::vector<std::string>>(
      38             :       "constant_names",
      39         628 :       std::vector<std::string>(),
      40             :       "Vector of constants used in the parsed function (use this for kB etc.)");
      41         628 :   params.addParam<std::vector<std::string>>( // must have the same length as constant_names (checked in the constructor)
      42             :       "constant_expressions",
      43         628 :       std::vector<std::string>(),
      44             :       "Vector of values for the constants in constant_names (can be an FParser expression)");
      45        1256 :   MooseEnum expandEnum("REAL RECIPROCAL NONE", "NONE"); // allowed values; NONE is the default
      46        1256 :   params.addParam<MooseEnum>("expand", expandEnum, "Expand the tensor to full size.");
      47         628 :   return params;
      48         628 : }
      49             : 
      50         314 : ParsedCompute::ParsedCompute(const InputParameters & parameters) // Validates the parameters and builds the parsed function.
      51             :   : TensorOperator<>(parameters),
      52         314 :     _use_jit(getParam<bool>("enable_jit")),
      53         628 :     _extra_symbols(getParam<bool>("extra_symbols")),
      54         942 :     _expand(getParam<MooseEnum>("expand").getEnum<ExpandEnum>())
      55             : {
      56         314 :   const auto & expression = getParam<std::string>("expression");
      57         314 :   const auto & names = getParam<std::vector<TensorInputBufferName>>("inputs");
      58             : 
      59             :   // check for duplicates
      60         628 :   auto hasDuplicates = [](const std::vector<std::string> & values)
      61             :   {
      62         628 :     std::set<std::string> s(values.begin(), values.end()); // a set drops repeats, so a smaller size means duplicates
      63         628 :     return values.size() != s.size();
      64             :   };
      65             : 
      66         314 :   if (hasDuplicates(names))
      67           0 :     paramError("inputs", "Duplicate buffer name.");
      68             : 
      69             :   // get all input buffers
      70         728 :   for (const auto & name : names)
      71         414 :     _params.push_back(&getInputBufferByName(name)); // _params order must match the parser's variable order
      72             : 
      73             :   static const std::vector<std::string> reserved_symbols = { // order must match the _params pushes in setup() below
      74         314 :       "i", "x", "kx", "y", "ky", "z", "kz", "k2", "t"};
      75             : 
      76             :   // helper function to check if the name given is one of the reserved_names
      77         582 :   auto isReservedName = [this](const auto & name) // names are only reserved when extra_symbols is enabled
      78         582 :   { return _extra_symbols && std::count(reserved_symbols.begin(), reserved_symbols.end(), name); };
      79             : 
      80         314 :   const auto & constant_names = getParam<std::vector<std::string>>("constant_names");
      81         314 :   const auto & constant_expressions = getParam<std::vector<std::string>>("constant_expressions");
      82             : 
      83         314 :   if (hasDuplicates(constant_names))
      84           0 :     paramError("constant_names", "Duplicate constant name.");
      85             : 
      86         482 :   for (const auto & name : constant_names)
      87         168 :     if (isReservedName(name))
      88           0 :       paramError("constant_names", "Cannot use reserved name '", name, "' for constant.");
      89         728 :   for (const auto & name : names)
      90         414 :     if (isReservedName(name))
      91           0 :       paramError("inputs", "Cannot use reserved name '", name, "' for coupled fields.");
      92             : 
      93             :   // check constant vectors
      94         314 :   unsigned int nconst = constant_expressions.size();
      95         314 :   if (nconst != constant_names.size())
      96           0 :     paramError("constant_names",
      97             :                "The parameter vectors constant_names (size ",
      98             :                constant_names.size(),
      99             :                ") and constant_expressions (size ", // report the actual parameter name
     100             :                nconst,
     101             :                ") must have equal length.");
     102             : 
     103         314 :   auto setup = [&](auto & fp) // generic so the same steps configure either _jit or _no_jit
     104             :   {
     105         314 :     std::vector variables_vec = names;
     106             : 
     107             :     // add extra symbols
     108         314 :     if (_extra_symbols)
     109             :     {
     110             :       // append extra symbols
     111         118 :       variables_vec.insert(variables_vec.end(), reserved_symbols.begin(), reserved_symbols.end());
     112             : 
     113         236 :       _constant_tensors.push_back( // tensor bound to the symbol `i`
     114         118 :           torch::tensor(c10::complex<double>(0.0, 1.0), MooseTensor::complexFloatTensorOptions()));
     115         118 :       _params.push_back(&_constant_tensors.back()); // the tensor just appended, whatever the vector held before
     116             : 
     117         472 :       for (const auto dim : make_range(3u))
     118             :       {
     119         354 :         _params.push_back(&_domain.getAxis(dim)); // x, y, z
     120         354 :         _params.push_back(&_domain.getReciprocalAxis(dim)); // kx, ky, kz
     121             :       }
     122             : 
     123         118 :       _params.push_back(&_domain.getKSquare()); // k2
     124         118 :       _params.push_back(&_time_tensor); // t; computeBuffer() reassigns this tensor before each evaluation
     125             : 
     126         118 :       fp.AddConstant("pi", libMesh::pi);
     127         236 :       fp.AddConstant("e", std::exp(Real(1.0)));
     128             :     }
     129             : 
     130             :     // previously evaluated constant_expressions may be used in following constant_expressions
     131         314 :     std::vector<Real> constant_values(nconst);
     132         482 :     for (unsigned int i = 0; i < nconst; ++i)
     133             :     {
     134             :       // no need to use dual numbers for the constant expressions
     135             :       auto constant_fp = std::make_shared<FunctionParserADBase<Real>>(); // named so it does not shadow the outer `expression` string
     136             : 
     137             :       // add previously evaluated constants
     138         252 :       for (unsigned int j = 0; j < i; ++j)
     139          84 :         if (!constant_fp->AddConstant(constant_names[j], constant_values[j]))
     140           0 :           paramError("constant_names", "Invalid constant name '", constant_names[j], "'");
     141             : 
     142             :       // build the temporary constant expression function
     143         336 :       if (constant_fp->Parse(constant_expressions[i], "") >= 0) // a non-negative return is treated as a parse failure
     144           0 :         mooseError("Invalid constant expression\n",
     145             :                    constant_expressions[i],
     146             :                    "\n in parsed function object.\n",
     147           0 :                    constant_fp->ErrorMsg());
     148             : 
     149         168 :       constant_values[i] = constant_fp->Eval(nullptr); // no variables were declared, so no argument array is passed
     150             : 
     151         168 :       if (!fp.AddConstant(constant_names[i], constant_values[i]))
     152           0 :         mooseError("Invalid constant name in parsed function object");
     153             :     }
     154             : 
     155             :     // build variables string
     156         314 :     const auto variables = MooseUtils::join(variables_vec, ",");
     157             : 
     158             :     // parse
     159         314 :     if (fp.Parse(expression, variables) >= 0) // `expression` is the user-supplied "expression" parameter
     160           0 :       paramError("expression", "Invalid function: ", fp.ErrorMsg());
     161             : 
     162             :     // take derivatives
     163         666 :     for (const auto & d : getParam<std::vector<TensorInputBufferName>>("derivatives"))
     164          38 :       if (std::find(names.begin(), names.end(), d) != names.end()) // only listed inputs may be differentiated
     165             :       {
     166          38 :         if (fp.AutoDiff(d) != -1) // any return other than -1 is treated as failure
     167           0 :           paramError("expression", "Failed to take derivative w.r.t. `", d, "`.");
     168             :       }
     169             :       else
     170           0 :         paramError("derivatives",
     171             :                    "Derivative w.r.t `",
     172             :                    d,
     173             :                    "` was requested, but it is not listed in `inputs`.");
     174             : 
     175         628 :     if (getParam<bool>("enable_fpoptimizer"))
     176         314 :       fp.Optimize();
     177             : 
     178         314 :     fp.setupTensors();
     179         628 :   };
     180             : 
     181         314 :   if (_use_jit) // only the evaluator that computeBuffer() will use is set up
     182         306 :     setup(_jit);
     183             :   else
     184           8 :     setup(_no_jit);
     185         314 : }
     186             : 
     187             : void // Evaluates the parsed expression into the output buffer _u.
     188      205514 : ParsedCompute::computeBuffer()
     189             : {
     190      205514 :   if (_extra_symbols) // the constructor stored &_time_tensor in _params, so assigning here updates the `t` input
     191         300 :     _time_tensor = torch::tensor(_time, MooseTensor::floatTensorOptions());
     192             : 
     193             :   // use local shape if we add parallel support, and add option for reciprocal shape
     194      205514 :   if (_use_jit) // use whichever evaluator the constructor set up
     195      411012 :     _u = _jit.Eval(_params);
     196             :   else
     197          16 :     _u = _no_jit.Eval(_params);
     198             : 
     199             :   // optionally expand the tensor
     200      205514 :   switch (_expand)
     201             :   {
     202          68 :     case ExpandEnum::REAL: // expand to the real-space domain shape
     203          68 :       _u = _u.expand(_domain.getShape());
     204          68 :       break;
     205             : 
     206           0 :     case ExpandEnum::RECIPROCAL: // expand to the reciprocal-space shape; 0 hits in this coverage run
     207           0 :       _u = _u.expand(_domain.getReciprocalShape());
     208           0 :       break;
     209             : 
     210             :     case ExpandEnum::NONE: // leave the evaluated result as is
     211             :       break;
     212             :   }
     213      205514 : }

Generated by: LCOV version 1.14