Line data Source code
1 : /**********************************************************************/
2 : /* DO NOT MODIFY THIS HEADER */
3 : /* Swift, a Fourier spectral solver for MOOSE */
4 : /* */
5 : /* Copyright 2024 Battelle Energy Alliance, LLC */
6 : /* ALL RIGHTS RESERVED */
7 : /**********************************************************************/
8 :
9 : #pragma once
10 :
11 : #include "ParsedJITTensor.h"
12 : #include "Conversion.h"
13 : #include "SwiftUtils.h"
14 : #include "libmesh/extrasrc/fptypes.hh"
15 :
16 : #include <torch/csrc/autograd/generated/variable_factories.h>
17 : #include <torch/csrc/jit/frontend/ir_emitter.h>
18 : #include <torch/csrc/jit/ir/alias_analysis.h>
19 : #include <torch/csrc/jit/ir/irparser.h>
20 : #include <torch/csrc/jit/ir/type_hashing.h>
21 : #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
22 : #include <torch/csrc/jit/runtime/custom_operator.h>
23 : #include <torch/csrc/jit/runtime/graph_iterator.h>
24 :
25 : #include <torch/csrc/jit/passes/graph_fuser.h>
26 : #include <torch/csrc/jit/passes/constant_propagation.h>
27 : #include <torch/csrc/jit/passes/dead_code_elimination.h>
28 : #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
29 :
30 : #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
31 :
32 : #include <torch/csrc/autograd/grad_mode.h>
33 :
34 : #include <ATen/TensorOperators.h>
35 :
36 360 : ParsedJITTensor::ParsedJITTensor()
37 360 : : FunctionParserAD(), _graph_executor(nullptr), _execution_plan(nullptr), _data(*getParserData())
38 : {
39 360 : }
40 :
41 : namespace
42 : {
43 : const std::string
44 0 : FP_GetOpcodeName(int opcode)
45 : {
46 : using namespace FUNCTIONPARSERTYPES;
47 :
48 : /* Symbolic meanings for the opcodes? */
49 : const char * p = 0;
50 0 : switch (opcode)
51 : {
52 : case cAbs:
53 : p = "cAbs";
54 : break;
55 0 : case cAcos:
56 : p = "cAcos";
57 0 : break;
58 0 : case cAcosh:
59 : p = "cAcosh";
60 0 : break;
61 0 : case cArg:
62 : p = "cArg";
63 0 : break;
64 0 : case cAsin:
65 : p = "cAsin";
66 0 : break;
67 0 : case cAsinh:
68 : p = "cAsinh";
69 0 : break;
70 0 : case cAtan:
71 : p = "cAtan";
72 0 : break;
73 0 : case cAtan2:
74 : p = "cAtan2";
75 0 : break;
76 0 : case cAtanh:
77 : p = "cAtanh";
78 0 : break;
79 0 : case cCbrt:
80 : p = "cCbrt";
81 0 : break;
82 0 : case cCeil:
83 : p = "cCeil";
84 0 : break;
85 0 : case cConj:
86 : p = "cConj";
87 0 : break;
88 0 : case cCos:
89 : p = "cCos";
90 0 : break;
91 0 : case cCosh:
92 : p = "cCosh";
93 0 : break;
94 0 : case cCot:
95 : p = "cCot";
96 0 : break;
97 0 : case cCsc:
98 : p = "cCsc";
99 0 : break;
100 0 : case cExp:
101 : p = "cExp";
102 0 : break;
103 0 : case cExp2:
104 : p = "cExp2";
105 0 : break;
106 0 : case cFloor:
107 : p = "cFloor";
108 0 : break;
109 0 : case cHypot:
110 : p = "cHypot";
111 0 : break;
112 0 : case cIf:
113 : p = "cIf";
114 0 : break;
115 0 : case cImag:
116 : p = "cImag";
117 0 : break;
118 0 : case cInt:
119 : p = "cInt";
120 0 : break;
121 0 : case cLog:
122 : p = "cLog";
123 0 : break;
124 0 : case cLog2:
125 : p = "cLog2";
126 0 : break;
127 0 : case cLog10:
128 : p = "cLog10";
129 0 : break;
130 0 : case cMax:
131 : p = "cMax";
132 0 : break;
133 0 : case cMin:
134 : p = "cMin";
135 0 : break;
136 0 : case cPolar:
137 : p = "cPolar";
138 0 : break;
139 0 : case cPow:
140 : p = "cPow";
141 0 : break;
142 0 : case cReal:
143 : p = "cReal";
144 0 : break;
145 0 : case cSec:
146 : p = "cSec";
147 0 : break;
148 0 : case cSin:
149 : p = "cSin";
150 0 : break;
151 0 : case cSinh:
152 : p = "cSinh";
153 0 : break;
154 0 : case cSqrt:
155 : p = "cSqrt";
156 0 : break;
157 0 : case cTan:
158 : p = "cTan";
159 0 : break;
160 0 : case cTanh:
161 : p = "cTanh";
162 0 : break;
163 0 : case cTrunc:
164 : p = "cTrunc";
165 0 : break;
166 0 : case cImmed:
167 : p = "cImmed";
168 0 : break;
169 0 : case cJump:
170 : p = "cJump";
171 0 : break;
172 0 : case cNeg:
173 : p = "cNeg";
174 0 : break;
175 0 : case cAdd:
176 : p = "cAdd";
177 0 : break;
178 0 : case cSub:
179 : p = "cSub";
180 0 : break;
181 0 : case cMul:
182 : p = "cMul";
183 0 : break;
184 0 : case cDiv:
185 : p = "cDiv";
186 0 : break;
187 0 : case cMod:
188 : p = "cMod";
189 0 : break;
190 0 : case cEqual:
191 : p = "cEqual";
192 0 : break;
193 0 : case cNEqual:
194 : p = "cNEqual";
195 0 : break;
196 0 : case cLess:
197 : p = "cLess";
198 0 : break;
199 0 : case cLessOrEq:
200 : p = "cLessOrEq";
201 0 : break;
202 0 : case cGreater:
203 : p = "cGreater";
204 0 : break;
205 0 : case cGreaterOrEq:
206 : p = "cGreaterOrEq";
207 0 : break;
208 0 : case cNot:
209 : p = "cNot";
210 0 : break;
211 0 : case cAnd:
212 : p = "cAnd";
213 0 : break;
214 0 : case cOr:
215 : p = "cOr";
216 0 : break;
217 0 : case cDeg:
218 : p = "cDeg";
219 0 : break;
220 0 : case cRad:
221 : p = "cRad";
222 0 : break;
223 0 : case cFCall:
224 : p = "cFCall";
225 0 : break;
226 0 : case cPCall:
227 : p = "cPCall";
228 0 : break;
229 : #ifdef FP_SUPPORT_OPTIMIZER
230 0 : case cFetch:
231 : p = "cFetch";
232 0 : break;
233 0 : case cPopNMov:
234 : p = "cPopNMov";
235 0 : break;
236 0 : case cLog2by:
237 : p = "cLog2by";
238 0 : break;
239 0 : case cNop:
240 : p = "cNop";
241 0 : break;
242 : #endif
243 0 : case cSinCos:
244 : p = "cSinCos";
245 0 : break;
246 0 : case cSinhCosh:
247 : p = "cSinhCosh";
248 0 : break;
249 0 : case cAbsNot:
250 : p = "cAbsNot";
251 0 : break;
252 0 : case cAbsNotNot:
253 : p = "cAbsNotNot";
254 0 : break;
255 0 : case cAbsAnd:
256 : p = "cAbsAnd";
257 0 : break;
258 0 : case cAbsOr:
259 : p = "cAbsOr";
260 0 : break;
261 0 : case cAbsIf:
262 : p = "cAbsIf";
263 0 : break;
264 0 : case cDup:
265 : p = "cDup";
266 0 : break;
267 0 : case cInv:
268 : p = "cInv";
269 0 : break;
270 0 : case cSqr:
271 : p = "cSqr";
272 0 : break;
273 0 : case cRDiv:
274 : p = "cRDiv";
275 0 : break;
276 0 : case cRSub:
277 : p = "cRSub";
278 0 : break;
279 0 : case cNotNot:
280 : p = "cNotNot";
281 0 : break;
282 0 : case cRSqrt:
283 : p = "cRSqrt";
284 0 : break;
285 0 : case VarBegin:
286 : p = "VarBegin";
287 0 : break;
288 0 : default:
289 0 : throw std::runtime_error("Unknown opcode.");
290 : }
291 0 : std::ostringstream tmp;
292 : // if(!p) std::cerr << "o=" << opcode << "\n";
293 : assert(p);
294 0 : tmp << p;
295 0 : return tmp.str();
296 : }
297 : }
298 :
299 : void
300 398 : ParsedJITTensor::setupTensors()
301 : {
302 : using namespace torch::jit;
303 :
304 : // allocate node stack
305 398 : std::vector<Value *> s(_data.mStackSize);
306 :
307 : // create graph
308 398 : _graph = std::make_shared<Graph>();
309 :
310 : // convert immediate data
311 398 : _constant_immed.clear();
312 788 : for (const auto & immed : _data.mImmed)
313 780 : _constant_immed.push_back(_graph->insertConstant(immed));
314 :
315 : // math constants
316 796 : const auto const_one_third = _graph->insertConstant(1.0 / 3.0);
317 :
318 : // create input nodes
319 398 : _input.clear();
320 2078 : for (unsigned i = 0; i < _data.mVariablesAmount; ++i)
321 5040 : _input.push_back(_graph->addInput());
322 :
323 : // build graph
324 : using namespace FUNCTIONPARSERTYPES;
325 :
326 : // get a reference to the stored bytecode
327 : const auto & ByteCode = _data.mByteCode;
328 :
329 : int nImmed = 0, sp = -1, op;
330 3648 : for (unsigned int i = 0; i < ByteCode.size(); ++i)
331 : {
332 : // execute bytecode
333 3250 : switch (op = ByteCode[i])
334 : {
335 390 : case cImmed:
336 390 : ++sp;
337 390 : s[sp] = _constant_immed[nImmed++];
338 390 : break;
339 :
340 446 : case cAdd:
341 446 : --sp;
342 1338 : s[sp] = _graph->insert(aten::add, {s[sp], s[sp + 1]});
343 446 : break;
344 68 : case cSub:
345 68 : --sp;
346 204 : s[sp] = _graph->insert(aten::sub, {s[sp], s[sp + 1]});
347 68 : break;
348 88 : case cRSub:
349 88 : --sp;
350 264 : s[sp] = _graph->insert(aten::sub, {s[sp + 1], s[sp]});
351 88 : break;
352 :
353 504 : case cMul:
354 504 : --sp;
355 1512 : s[sp] = _graph->insert(aten::mul, {s[sp], s[sp + 1]});
356 504 : break;
357 16 : case cDiv:
358 16 : --sp;
359 48 : s[sp] = _graph->insert(aten::div, {s[sp], s[sp + 1]});
360 16 : break;
361 0 : case cInv:
362 0 : s[sp] = _graph->insert(aten::reciprocal, {s[sp]});
363 0 : break;
364 0 : case cMod:
365 0 : --sp;
366 0 : s[sp] = _graph->insert(aten::fmod, {s[sp], s[sp + 1]});
367 0 : break;
368 0 : case cRDiv:
369 0 : --sp;
370 0 : s[sp] = _graph->insert(aten::div, {s[sp + 1], s[sp]});
371 0 : break;
372 :
373 180 : case cSin:
374 360 : s[sp] = _graph->insert(aten::sin, {s[sp]});
375 180 : break;
376 30 : case cCos:
377 60 : s[sp] = _graph->insert(aten::cos, {s[sp]});
378 30 : break;
379 0 : case cTan:
380 0 : s[sp] = _graph->insert(aten::tan, {s[sp]});
381 0 : break;
382 :
383 0 : case cTanh:
384 0 : s[sp] = _graph->insert(aten::tanh, {s[sp]});
385 0 : break;
386 4 : case cSinh:
387 8 : s[sp] = _graph->insert(aten::sinh, {s[sp]});
388 4 : break;
389 4 : case cCosh:
390 8 : s[sp] = _graph->insert(aten::cosh, {s[sp]});
391 4 : break;
392 :
393 2 : case cSinCos:
394 4 : s[sp + 1] = _graph->insert(aten::cos, {s[sp]});
395 4 : s[sp] = _graph->insert(aten::sin, {s[sp]});
396 : ++sp;
397 2 : break;
398 :
399 80 : case cAbs:
400 160 : s[sp] = _graph->insert(aten::abs, {s[sp]});
401 80 : break;
402 4 : case cMax:
403 4 : --sp;
404 12 : s[sp] = _graph->insert(aten::maximum, {s[sp], s[sp + 1]});
405 4 : break;
406 4 : case cMin:
407 4 : --sp;
408 12 : s[sp] = _graph->insert(aten::minimum, {s[sp], s[sp + 1]});
409 4 : break;
410 :
411 0 : case cInt:
412 0 : s[sp] = _graph->insert(aten::round, {s[sp]});
413 0 : break;
414 :
415 4 : case cLog:
416 8 : s[sp] = _graph->insert(aten::log, {s[sp]});
417 4 : break;
418 0 : case cLog2:
419 0 : s[sp] = _graph->insert(aten::log2, {s[sp]});
420 0 : break;
421 0 : case cLog10:
422 0 : s[sp] = _graph->insert(aten::log10, {s[sp]});
423 0 : break;
424 :
425 4 : case cNeg:
426 8 : s[sp] = _graph->insert(aten::neg, {s[sp]});
427 4 : break;
428 :
429 116 : case cSqr:
430 348 : s[sp] = _graph->insert(aten::mul, {s[sp], s[sp]});
431 116 : break;
432 0 : case cSqrt:
433 0 : s[sp] = _graph->insert(aten::sqrt, {s[sp]});
434 0 : break;
435 4 : case cRSqrt:
436 8 : s[sp] = _graph->insert(aten::rsqrt, {s[sp]});
437 4 : break;
438 4 : case cPow:
439 4 : --sp;
440 12 : s[sp] = _graph->insert(aten::pow, {s[sp], s[sp + 1]});
441 4 : break;
442 4 : case cExp:
443 8 : s[sp] = _graph->insert(aten::exp, {s[sp]});
444 4 : break;
445 0 : case cExp2:
446 0 : s[sp] = _graph->insert(aten::exp2, {s[sp]});
447 0 : break;
448 4 : case cCbrt:
449 12 : s[sp] = _graph->insert(aten::pow, {s[sp], const_one_third});
450 4 : break;
451 :
452 4 : case cHypot:
453 4 : --sp;
454 12 : s[sp] = _graph->insert(aten::hypot, {s[sp], s[sp + 1]});
455 4 : break;
456 :
457 4 : case cAcos:
458 8 : s[sp] = _graph->insert(aten::acos, {s[sp]});
459 4 : break;
460 4 : case cAcosh:
461 8 : s[sp] = _graph->insert(aten::acosh, {s[sp]});
462 4 : break;
463 4 : case cAsin:
464 8 : s[sp] = _graph->insert(aten::asin, {s[sp]});
465 4 : break;
466 4 : case cAsinh:
467 8 : s[sp] = _graph->insert(aten::asinh, {s[sp]});
468 4 : break;
469 4 : case cAtan:
470 8 : s[sp] = _graph->insert(aten::atan, {s[sp]});
471 4 : break;
472 4 : case cAtan2:
473 4 : --sp;
474 12 : s[sp] = _graph->insert(aten::atan2, {s[sp], s[sp + 1]});
475 4 : break;
476 0 : case cAtanh:
477 0 : s[sp] = _graph->insert(aten::atanh, {s[sp]});
478 0 : break;
479 :
480 118 : case cFetch:
481 118 : ++sp;
482 118 : s[sp] = s[ByteCode[++i]];
483 118 : break;
484 112 : case cDup:
485 112 : ++sp;
486 112 : s[sp] = s[sp - 1];
487 112 : break;
488 :
489 : #ifdef FP_SUPPORT_OPTIMIZER
490 8 : case cPopNMov:
491 : {
492 8 : int dst = ByteCode[++i], src = ByteCode[++i];
493 8 : s[dst] = s[src];
494 : sp = dst;
495 8 : break;
496 : }
497 0 : case cLog2by:
498 0 : --sp;
499 0 : s[sp] = _graph->insert(aten::mul, {_graph->insert(aten::log2, {s[sp]}), s[sp + 1]});
500 0 : break;
501 : case cNop:
502 : break;
503 : #endif
504 :
505 1024 : default:
506 1024 : if (op >= VarBegin)
507 : {
508 : // // load variable
509 1024 : ++sp;
510 1024 : s[sp] = _input[op - VarBegin];
511 : }
512 : else
513 : {
514 0 : throw std::runtime_error("JIT Opcode " + FP_GetOpcodeName(op) +
515 0 : " not supported for libtorch tensors.");
516 : }
517 : }
518 : }
519 :
520 398 : auto outputs = s[sp]->node()->outputs();
521 796 : for (auto output : outputs)
522 : _graph->registerOutput(output);
523 :
524 : // make sure graph is well formed
525 398 : _graph->lint();
526 :
527 : // optimization
528 398 : EliminateDeadCode(_graph); // Tracing of some ops depends on the DCE trick
529 398 : ConstantPropagation(_graph);
530 398 : EliminateCommonSubexpression(_graph);
531 398 : FuseGraph(_graph, true);
532 1998 : }
533 :
534 : namespace
535 : {
536 : template <class... Inputs>
537 : inline std::vector<c10::IValue>
538 : makeStack(Inputs &&... inputs)
539 : {
540 : return {std::forward<Inputs>(inputs)...};
541 : }
542 : }
543 :
544 : torch::Tensor
545 205598 : ParsedJITTensor::Eval(const std::vector<const torch::Tensor *> & params)
546 : {
547 : using namespace torch::jit;
548 :
549 : // build stack
550 : Stack stack;
551 692400 : for (const auto & p : params)
552 973604 : stack.push_back(*p);
553 :
554 205598 : if (_input.size() != params.size())
555 0 : throw std::runtime_error("Unexpected number of inputs in ParsedJITTensor::Eval.");
556 :
557 : // disable autograd
558 205598 : torch::NoGradGuard no_grad;
559 :
560 205598 : if (!_graph_executor)
561 672 : _graph_executor = std::make_shared<GraphExecutor>(_graph, "F");
562 :
563 205598 : _graph_executor->run(stack);
564 :
565 205598 : if (stack.size() != 1)
566 0 : throw std::runtime_error("Unexpected number vof outputs in ParsedJITTensor::Eval.");
567 :
568 205598 : return stack[0].toTensor();
569 205598 : }
|