Line data Source code
1 : /**********************************************************************/
2 : /* DO NOT MODIFY THIS HEADER */
3 : /* Swift, a Fourier spectral solver for MOOSE */
4 : /* */
5 : /* Copyright 2024 Battelle Energy Alliance, LLC */
6 : /* ALL RIGHTS RESERVED */
7 : /**********************************************************************/
8 :
9 : #include "ParsedCompute.h"
10 :
11 : #include "MooseUtils.h"
12 : #include "SwiftUtils.h"
13 : #include "MultiMooseEnum.h"
14 : #include "DomainAction.h"
15 :
16 : registerMooseObject("SwiftApp", ParsedCompute);
17 :
18 : InputParameters
19 628 : ParsedCompute::validParams()
20 : {
21 628 : InputParameters params = TensorOperator<>::validParams();
22 628 : params.addClassDescription("ParsedCompute object.");
23 1256 : params.addRequiredParam<std::string>("expression", "Parsed expression");
24 1256 : params.addParam<std::vector<TensorInputBufferName>>(
25 : "inputs", {}, "Buffer names used in the expression");
26 1256 : params.addParam<std::vector<TensorInputBufferName>>(
27 : "derivatives", {}, "List of inputs to take the derivative w.r.t. (or none)");
28 1256 : params.addParam<bool>(
29 1256 : "enable_jit", true, "Use operator fusion and just in time compilation (recommended on GPU)");
30 1256 : params.addParam<bool>("enable_fpoptimizer", true, "Use algebraic optimizer");
31 1256 : params.addParam<bool>("extra_symbols",
32 1256 : false,
33 : "Provide i (imaginary unit), kx,ky,kz (reciprocal space frequency), k2 "
34 : "(square of the k-vector), x,y,z "
35 : "(real space coordinates), time t, pi, and e.");
36 : // Constants and their values
37 628 : params.addParam<std::vector<std::string>>(
38 : "constant_names",
39 628 : std::vector<std::string>(),
40 : "Vector of constants used in the parsed function (use this for kB etc.)");
41 628 : params.addParam<std::vector<std::string>>(
42 : "constant_expressions",
43 628 : std::vector<std::string>(),
44 : "Vector of values for the constants in constant_names (can be an FParser expression)");
45 1256 : MooseEnum expandEnum("REAL RECIPROCAL NONE", "NONE");
46 1256 : params.addParam<MooseEnum>("expand", expandEnum, "Expand the tensor to full size.");
47 628 : return params;
48 628 : }
49 :
50 314 : ParsedCompute::ParsedCompute(const InputParameters & parameters)
51 : : TensorOperator<>(parameters),
52 314 : _use_jit(getParam<bool>("enable_jit")),
53 628 : _extra_symbols(getParam<bool>("extra_symbols")),
54 942 : _expand(getParam<MooseEnum>("expand").getEnum<ExpandEnum>())
55 : {
56 314 : const auto & expression = getParam<std::string>("expression");
57 314 : const auto & names = getParam<std::vector<TensorInputBufferName>>("inputs");
58 :
59 : // check for duplicates
60 628 : auto hasDuplicates = [](const std::vector<std::string> & values)
61 : {
62 628 : std::set<std::string> s(values.begin(), values.end());
63 628 : return values.size() != s.size();
64 : };
65 :
66 314 : if (hasDuplicates(names))
67 0 : paramError("inputs", "Duplicate buffer name.");
68 :
69 : // get all input buffers
70 728 : for (const auto & name : names)
71 414 : _params.push_back(&getInputBufferByName(name));
72 :
73 : static const std::vector<std::string> reserved_symbols = {
74 314 : "i", "x", "kx", "y", "ky", "z", "kz", "k2", "t"};
75 :
76 : // helper function to check if the name given is one of the reserved_names
77 582 : auto isReservedName = [this](const auto & name)
78 582 : { return _extra_symbols && std::count(reserved_symbols.begin(), reserved_symbols.end(), name); };
79 :
80 314 : const auto & constant_names = getParam<std::vector<std::string>>("constant_names");
81 314 : const auto & constant_expressions = getParam<std::vector<std::string>>("constant_expressions");
82 :
83 314 : if (hasDuplicates(constant_names))
84 0 : paramError("constant_names", "Duplicate constant name.");
85 :
86 482 : for (const auto & name : constant_names)
87 168 : if (isReservedName(name))
88 0 : paramError("constant_names", "Cannot use reserved name '", name, "' for constant.");
89 728 : for (const auto & name : names)
90 414 : if (isReservedName(name))
91 0 : paramError("inputs", "Cannot use reserved name '", name, "' for coupled fields.");
92 :
93 : // check constant vectors
94 314 : unsigned int nconst = constant_expressions.size();
95 314 : if (nconst != constant_names.size())
96 0 : paramError("constant_names",
97 : "The parameter vectors constant_names (size ",
98 : constant_names.size(),
99 : ") and constant_values (size ",
100 : nconst,
101 : ") must have equal length.");
102 :
103 314 : auto setup = [&](auto & fp)
104 : {
105 314 : std::vector variables_vec = names;
106 :
107 : // add extra symbols
108 314 : if (_extra_symbols)
109 : {
110 : // append extra symbols
111 118 : variables_vec.insert(variables_vec.end(), reserved_symbols.begin(), reserved_symbols.end());
112 :
113 236 : _constant_tensors.push_back(
114 118 : torch::tensor(c10::complex<double>(0.0, 1.0), MooseTensor::complexFloatTensorOptions()));
115 118 : _params.push_back(&_constant_tensors[0]);
116 :
117 472 : for (const auto dim : make_range(3u))
118 : {
119 354 : _params.push_back(&_domain.getAxis(dim));
120 354 : _params.push_back(&_domain.getReciprocalAxis(dim));
121 : }
122 :
123 118 : _params.push_back(&_domain.getKSquare());
124 118 : _params.push_back(&_time_tensor);
125 :
126 118 : fp.AddConstant("pi", libMesh::pi);
127 236 : fp.AddConstant("e", std::exp(Real(1.0)));
128 : }
129 :
130 : // previously evaluated constant_expressions may be used in following constant_expressions
131 314 : std::vector<Real> constant_values(nconst);
132 482 : for (unsigned int i = 0; i < nconst; ++i)
133 : {
134 : // no need to use dual numbers for the constant expressions
135 : auto expression = std::make_shared<FunctionParserADBase<Real>>();
136 :
137 : // add previously evaluated constants
138 252 : for (unsigned int j = 0; j < i; ++j)
139 84 : if (!expression->AddConstant(constant_names[j], constant_values[j]))
140 0 : paramError("constant_names", "Invalid constant name '", constant_names[j], "'");
141 :
142 : // build the temporary constant expression function
143 336 : if (expression->Parse(constant_expressions[i], "") >= 0)
144 0 : mooseError("Invalid constant expression\n",
145 : constant_expressions[i],
146 : "\n in parsed function object.\n",
147 0 : expression->ErrorMsg());
148 :
149 168 : constant_values[i] = expression->Eval(nullptr);
150 :
151 168 : if (!fp.AddConstant(constant_names[i], constant_values[i]))
152 0 : mooseError("Invalid constant name in parsed function object");
153 : }
154 :
155 : // build variables string
156 314 : const auto variables = MooseUtils::join(variables_vec, ",");
157 :
158 : // parse
159 314 : if (fp.Parse(expression, variables) >= 0)
160 0 : paramError("expression", "Invalid function: ", fp.ErrorMsg());
161 :
162 : // take derivatives
163 666 : for (const auto & d : getParam<std::vector<TensorInputBufferName>>("derivatives"))
164 38 : if (std::find(names.begin(), names.end(), d) != names.end())
165 : {
166 38 : if (fp.AutoDiff(d) != -1)
167 0 : paramError("expression", "Failed to take derivative w.r.t. `", d, "`.");
168 : }
169 : else
170 0 : paramError("derivatives",
171 : "Derivative w.r.t `",
172 : d,
173 : "` was requested, but it is not listed in `inputs`.");
174 :
175 628 : if (getParam<bool>("enable_fpoptimizer"))
176 314 : fp.Optimize();
177 :
178 314 : fp.setupTensors();
179 628 : };
180 :
181 314 : if (_use_jit)
182 306 : setup(_jit);
183 : else
184 8 : setup(_no_jit);
185 314 : }
186 :
187 : void
188 205514 : ParsedCompute::computeBuffer()
189 : {
190 205514 : if (_extra_symbols)
191 300 : _time_tensor = torch::tensor(_time, MooseTensor::floatTensorOptions());
192 :
193 : // use local shape if we add parallel support, and add option for reciprocal shape
194 205514 : if (_use_jit)
195 411012 : _u = _jit.Eval(_params);
196 : else
197 16 : _u = _no_jit.Eval(_params);
198 :
199 : // optionally expand the tensor
200 205514 : switch (_expand)
201 : {
202 68 : case ExpandEnum::REAL:
203 68 : _u = _u.expand(_domain.getShape());
204 68 : break;
205 :
206 0 : case ExpandEnum::RECIPROCAL:
207 0 : _u = _u.expand(_domain.getReciprocalShape());
208 0 : break;
209 :
210 : case ExpandEnum::NONE:
211 : break;
212 : }
213 205514 : }
|