LCOV - code coverage report
Current view: top level - src/utils - ParsedJITTensor.C (source / functions) Hit Total Coverage
Test: idaholab/swift: #92 (25e020) with base b3cd84 Lines: 146 350 41.7 %
Date: 2025-09-10 17:10:32 Functions: 3 4 75.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /**********************************************************************/
       2             : /*                    DO NOT MODIFY THIS HEADER                       */
       3             : /*             Swift, a Fourier spectral solver for MOOSE             */
       4             : /*                                                                    */
       5             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       6             : /*                        ALL RIGHTS RESERVED                         */
       7             : /**********************************************************************/
       8             : 
       9             : #pragma once
      10             : 
      11             : #include "ParsedJITTensor.h"
      12             : #include "Conversion.h"
      13             : #include "SwiftUtils.h"
      14             : #include "libmesh/extrasrc/fptypes.hh"
      15             : 
      16             : #include <torch/csrc/autograd/generated/variable_factories.h>
      17             : #include <torch/csrc/jit/frontend/ir_emitter.h>
      18             : #include <torch/csrc/jit/ir/alias_analysis.h>
      19             : #include <torch/csrc/jit/ir/irparser.h>
      20             : #include <torch/csrc/jit/ir/type_hashing.h>
      21             : #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
      22             : #include <torch/csrc/jit/runtime/custom_operator.h>
      23             : #include <torch/csrc/jit/runtime/graph_iterator.h>
      24             : 
      25             : #include <torch/csrc/jit/passes/graph_fuser.h>
      26             : #include <torch/csrc/jit/passes/constant_propagation.h>
      27             : #include <torch/csrc/jit/passes/dead_code_elimination.h>
      28             : #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
      29             : 
      30             : #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
      31             : 
      32             : #include <torch/csrc/autograd/grad_mode.h>
      33             : 
      34             : #include <ATen/TensorOperators.h>
      35             : 
      36         360 : ParsedJITTensor::ParsedJITTensor()
      37         360 :   : FunctionParserAD(), _graph_executor(nullptr), _execution_plan(nullptr), _data(*getParserData())
      38             : {
      39         360 : }
      40             : 
      41             : namespace
      42             : {
      43             : const std::string
      44           0 : FP_GetOpcodeName(int opcode)
      45             : {
      46             :   using namespace FUNCTIONPARSERTYPES;
      47             : 
      48             :   /* Symbolic meanings for the opcodes? */
      49             :   const char * p = 0;
      50           0 :   switch (opcode)
      51             :   {
      52             :     case cAbs:
      53             :       p = "cAbs";
      54             :       break;
      55           0 :     case cAcos:
      56             :       p = "cAcos";
      57           0 :       break;
      58           0 :     case cAcosh:
      59             :       p = "cAcosh";
      60           0 :       break;
      61           0 :     case cArg:
      62             :       p = "cArg";
      63           0 :       break;
      64           0 :     case cAsin:
      65             :       p = "cAsin";
      66           0 :       break;
      67           0 :     case cAsinh:
      68             :       p = "cAsinh";
      69           0 :       break;
      70           0 :     case cAtan:
      71             :       p = "cAtan";
      72           0 :       break;
      73           0 :     case cAtan2:
      74             :       p = "cAtan2";
      75           0 :       break;
      76           0 :     case cAtanh:
      77             :       p = "cAtanh";
      78           0 :       break;
      79           0 :     case cCbrt:
      80             :       p = "cCbrt";
      81           0 :       break;
      82           0 :     case cCeil:
      83             :       p = "cCeil";
      84           0 :       break;
      85           0 :     case cConj:
      86             :       p = "cConj";
      87           0 :       break;
      88           0 :     case cCos:
      89             :       p = "cCos";
      90           0 :       break;
      91           0 :     case cCosh:
      92             :       p = "cCosh";
      93           0 :       break;
      94           0 :     case cCot:
      95             :       p = "cCot";
      96           0 :       break;
      97           0 :     case cCsc:
      98             :       p = "cCsc";
      99           0 :       break;
     100           0 :     case cExp:
     101             :       p = "cExp";
     102           0 :       break;
     103           0 :     case cExp2:
     104             :       p = "cExp2";
     105           0 :       break;
     106           0 :     case cFloor:
     107             :       p = "cFloor";
     108           0 :       break;
     109           0 :     case cHypot:
     110             :       p = "cHypot";
     111           0 :       break;
     112           0 :     case cIf:
     113             :       p = "cIf";
     114           0 :       break;
     115           0 :     case cImag:
     116             :       p = "cImag";
     117           0 :       break;
     118           0 :     case cInt:
     119             :       p = "cInt";
     120           0 :       break;
     121           0 :     case cLog:
     122             :       p = "cLog";
     123           0 :       break;
     124           0 :     case cLog2:
     125             :       p = "cLog2";
     126           0 :       break;
     127           0 :     case cLog10:
     128             :       p = "cLog10";
     129           0 :       break;
     130           0 :     case cMax:
     131             :       p = "cMax";
     132           0 :       break;
     133           0 :     case cMin:
     134             :       p = "cMin";
     135           0 :       break;
     136           0 :     case cPolar:
     137             :       p = "cPolar";
     138           0 :       break;
     139           0 :     case cPow:
     140             :       p = "cPow";
     141           0 :       break;
     142           0 :     case cReal:
     143             :       p = "cReal";
     144           0 :       break;
     145           0 :     case cSec:
     146             :       p = "cSec";
     147           0 :       break;
     148           0 :     case cSin:
     149             :       p = "cSin";
     150           0 :       break;
     151           0 :     case cSinh:
     152             :       p = "cSinh";
     153           0 :       break;
     154           0 :     case cSqrt:
     155             :       p = "cSqrt";
     156           0 :       break;
     157           0 :     case cTan:
     158             :       p = "cTan";
     159           0 :       break;
     160           0 :     case cTanh:
     161             :       p = "cTanh";
     162           0 :       break;
     163           0 :     case cTrunc:
     164             :       p = "cTrunc";
     165           0 :       break;
     166           0 :     case cImmed:
     167             :       p = "cImmed";
     168           0 :       break;
     169           0 :     case cJump:
     170             :       p = "cJump";
     171           0 :       break;
     172           0 :     case cNeg:
     173             :       p = "cNeg";
     174           0 :       break;
     175           0 :     case cAdd:
     176             :       p = "cAdd";
     177           0 :       break;
     178           0 :     case cSub:
     179             :       p = "cSub";
     180           0 :       break;
     181           0 :     case cMul:
     182             :       p = "cMul";
     183           0 :       break;
     184           0 :     case cDiv:
     185             :       p = "cDiv";
     186           0 :       break;
     187           0 :     case cMod:
     188             :       p = "cMod";
     189           0 :       break;
     190           0 :     case cEqual:
     191             :       p = "cEqual";
     192           0 :       break;
     193           0 :     case cNEqual:
     194             :       p = "cNEqual";
     195           0 :       break;
     196           0 :     case cLess:
     197             :       p = "cLess";
     198           0 :       break;
     199           0 :     case cLessOrEq:
     200             :       p = "cLessOrEq";
     201           0 :       break;
     202           0 :     case cGreater:
     203             :       p = "cGreater";
     204           0 :       break;
     205           0 :     case cGreaterOrEq:
     206             :       p = "cGreaterOrEq";
     207           0 :       break;
     208           0 :     case cNot:
     209             :       p = "cNot";
     210           0 :       break;
     211           0 :     case cAnd:
     212             :       p = "cAnd";
     213           0 :       break;
     214           0 :     case cOr:
     215             :       p = "cOr";
     216           0 :       break;
     217           0 :     case cDeg:
     218             :       p = "cDeg";
     219           0 :       break;
     220           0 :     case cRad:
     221             :       p = "cRad";
     222           0 :       break;
     223           0 :     case cFCall:
     224             :       p = "cFCall";
     225           0 :       break;
     226           0 :     case cPCall:
     227             :       p = "cPCall";
     228           0 :       break;
     229             : #ifdef FP_SUPPORT_OPTIMIZER
     230           0 :     case cFetch:
     231             :       p = "cFetch";
     232           0 :       break;
     233           0 :     case cPopNMov:
     234             :       p = "cPopNMov";
     235           0 :       break;
     236           0 :     case cLog2by:
     237             :       p = "cLog2by";
     238           0 :       break;
     239           0 :     case cNop:
     240             :       p = "cNop";
     241           0 :       break;
     242             : #endif
     243           0 :     case cSinCos:
     244             :       p = "cSinCos";
     245           0 :       break;
     246           0 :     case cSinhCosh:
     247             :       p = "cSinhCosh";
     248           0 :       break;
     249           0 :     case cAbsNot:
     250             :       p = "cAbsNot";
     251           0 :       break;
     252           0 :     case cAbsNotNot:
     253             :       p = "cAbsNotNot";
     254           0 :       break;
     255           0 :     case cAbsAnd:
     256             :       p = "cAbsAnd";
     257           0 :       break;
     258           0 :     case cAbsOr:
     259             :       p = "cAbsOr";
     260           0 :       break;
     261           0 :     case cAbsIf:
     262             :       p = "cAbsIf";
     263           0 :       break;
     264           0 :     case cDup:
     265             :       p = "cDup";
     266           0 :       break;
     267           0 :     case cInv:
     268             :       p = "cInv";
     269           0 :       break;
     270           0 :     case cSqr:
     271             :       p = "cSqr";
     272           0 :       break;
     273           0 :     case cRDiv:
     274             :       p = "cRDiv";
     275           0 :       break;
     276           0 :     case cRSub:
     277             :       p = "cRSub";
     278           0 :       break;
     279           0 :     case cNotNot:
     280             :       p = "cNotNot";
     281           0 :       break;
     282           0 :     case cRSqrt:
     283             :       p = "cRSqrt";
     284           0 :       break;
     285           0 :     case VarBegin:
     286             :       p = "VarBegin";
     287           0 :       break;
     288           0 :     default:
     289           0 :       throw std::runtime_error("Unknown opcode.");
     290             :   }
     291           0 :   std::ostringstream tmp;
     292             :   // if(!p) std::cerr << "o=" << opcode << "\n";
     293             :   assert(p);
     294           0 :   tmp << p;
     295           0 :   return tmp.str();
     296             : }
     297             : }
     298             : 
     299             : void
     300         398 : ParsedJITTensor::setupTensors()
     301             : {
     302             :   using namespace torch::jit;
     303             : 
     304             :   // allocate node stack
     305         398 :   std::vector<Value *> s(_data.mStackSize);
     306             : 
     307             :   // create graph
     308         398 :   _graph = std::make_shared<Graph>();
     309             : 
     310             :   // convert immediate data
     311         398 :   _constant_immed.clear();
     312         788 :   for (const auto & immed : _data.mImmed)
     313         780 :     _constant_immed.push_back(_graph->insertConstant(immed));
     314             : 
     315             :   // math constants
     316         796 :   const auto const_one_third = _graph->insertConstant(1.0 / 3.0);
     317             : 
     318             :   // create input nodes
     319         398 :   _input.clear();
     320        2078 :   for (unsigned i = 0; i < _data.mVariablesAmount; ++i)
     321        5040 :     _input.push_back(_graph->addInput());
     322             : 
     323             :   // build graph
     324             :   using namespace FUNCTIONPARSERTYPES;
     325             : 
     326             :   // get a reference to the stored bytecode
     327             :   const auto & ByteCode = _data.mByteCode;
     328             : 
     329             :   int nImmed = 0, sp = -1, op;
     330        3648 :   for (unsigned int i = 0; i < ByteCode.size(); ++i)
     331             :   {
     332             :     // execute bytecode
     333        3250 :     switch (op = ByteCode[i])
     334             :     {
     335         390 :       case cImmed:
     336         390 :         ++sp;
     337         390 :         s[sp] = _constant_immed[nImmed++];
     338         390 :         break;
     339             : 
     340         446 :       case cAdd:
     341         446 :         --sp;
     342        1338 :         s[sp] = _graph->insert(aten::add, {s[sp], s[sp + 1]});
     343         446 :         break;
     344          68 :       case cSub:
     345          68 :         --sp;
     346         204 :         s[sp] = _graph->insert(aten::sub, {s[sp], s[sp + 1]});
     347          68 :         break;
     348          88 :       case cRSub:
     349          88 :         --sp;
     350         264 :         s[sp] = _graph->insert(aten::sub, {s[sp + 1], s[sp]});
     351          88 :         break;
     352             : 
     353         504 :       case cMul:
     354         504 :         --sp;
     355        1512 :         s[sp] = _graph->insert(aten::mul, {s[sp], s[sp + 1]});
     356         504 :         break;
     357          16 :       case cDiv:
     358          16 :         --sp;
     359          48 :         s[sp] = _graph->insert(aten::div, {s[sp], s[sp + 1]});
     360          16 :         break;
     361           0 :       case cInv:
     362           0 :         s[sp] = _graph->insert(aten::reciprocal, {s[sp]});
     363           0 :         break;
     364           0 :       case cMod:
     365           0 :         --sp;
     366           0 :         s[sp] = _graph->insert(aten::fmod, {s[sp], s[sp + 1]});
     367           0 :         break;
     368           0 :       case cRDiv:
     369           0 :         --sp;
     370           0 :         s[sp] = _graph->insert(aten::div, {s[sp + 1], s[sp]});
     371           0 :         break;
     372             : 
     373         180 :       case cSin:
     374         360 :         s[sp] = _graph->insert(aten::sin, {s[sp]});
     375         180 :         break;
     376          30 :       case cCos:
     377          60 :         s[sp] = _graph->insert(aten::cos, {s[sp]});
     378          30 :         break;
     379           0 :       case cTan:
     380           0 :         s[sp] = _graph->insert(aten::tan, {s[sp]});
     381           0 :         break;
     382             : 
     383           0 :       case cTanh:
     384           0 :         s[sp] = _graph->insert(aten::tanh, {s[sp]});
     385           0 :         break;
     386           4 :       case cSinh:
     387           8 :         s[sp] = _graph->insert(aten::sinh, {s[sp]});
     388           4 :         break;
     389           4 :       case cCosh:
     390           8 :         s[sp] = _graph->insert(aten::cosh, {s[sp]});
     391           4 :         break;
     392             : 
     393           2 :       case cSinCos:
     394           4 :         s[sp + 1] = _graph->insert(aten::cos, {s[sp]});
     395           4 :         s[sp] = _graph->insert(aten::sin, {s[sp]});
     396             :         ++sp;
     397           2 :         break;
     398             : 
     399          80 :       case cAbs:
     400         160 :         s[sp] = _graph->insert(aten::abs, {s[sp]});
     401          80 :         break;
     402           4 :       case cMax:
     403           4 :         --sp;
     404          12 :         s[sp] = _graph->insert(aten::maximum, {s[sp], s[sp + 1]});
     405           4 :         break;
     406           4 :       case cMin:
     407           4 :         --sp;
     408          12 :         s[sp] = _graph->insert(aten::minimum, {s[sp], s[sp + 1]});
     409           4 :         break;
     410             : 
     411           0 :       case cInt:
     412           0 :         s[sp] = _graph->insert(aten::round, {s[sp]});
     413           0 :         break;
     414             : 
     415           4 :       case cLog:
     416           8 :         s[sp] = _graph->insert(aten::log, {s[sp]});
     417           4 :         break;
     418           0 :       case cLog2:
     419           0 :         s[sp] = _graph->insert(aten::log2, {s[sp]});
     420           0 :         break;
     421           0 :       case cLog10:
     422           0 :         s[sp] = _graph->insert(aten::log10, {s[sp]});
     423           0 :         break;
     424             : 
     425           4 :       case cNeg:
     426           8 :         s[sp] = _graph->insert(aten::neg, {s[sp]});
     427           4 :         break;
     428             : 
     429         116 :       case cSqr:
     430         348 :         s[sp] = _graph->insert(aten::mul, {s[sp], s[sp]});
     431         116 :         break;
     432           0 :       case cSqrt:
     433           0 :         s[sp] = _graph->insert(aten::sqrt, {s[sp]});
     434           0 :         break;
     435           4 :       case cRSqrt:
     436           8 :         s[sp] = _graph->insert(aten::rsqrt, {s[sp]});
     437           4 :         break;
     438           4 :       case cPow:
     439           4 :         --sp;
     440          12 :         s[sp] = _graph->insert(aten::pow, {s[sp], s[sp + 1]});
     441           4 :         break;
     442           4 :       case cExp:
     443           8 :         s[sp] = _graph->insert(aten::exp, {s[sp]});
     444           4 :         break;
     445           0 :       case cExp2:
     446           0 :         s[sp] = _graph->insert(aten::exp2, {s[sp]});
     447           0 :         break;
     448           4 :       case cCbrt:
     449          12 :         s[sp] = _graph->insert(aten::pow, {s[sp], const_one_third});
     450           4 :         break;
     451             : 
     452           4 :       case cHypot:
     453           4 :         --sp;
     454          12 :         s[sp] = _graph->insert(aten::hypot, {s[sp], s[sp + 1]});
     455           4 :         break;
     456             : 
     457           4 :       case cAcos:
     458           8 :         s[sp] = _graph->insert(aten::acos, {s[sp]});
     459           4 :         break;
     460           4 :       case cAcosh:
     461           8 :         s[sp] = _graph->insert(aten::acosh, {s[sp]});
     462           4 :         break;
     463           4 :       case cAsin:
     464           8 :         s[sp] = _graph->insert(aten::asin, {s[sp]});
     465           4 :         break;
     466           4 :       case cAsinh:
     467           8 :         s[sp] = _graph->insert(aten::asinh, {s[sp]});
     468           4 :         break;
     469           4 :       case cAtan:
     470           8 :         s[sp] = _graph->insert(aten::atan, {s[sp]});
     471           4 :         break;
     472           4 :       case cAtan2:
     473           4 :         --sp;
     474          12 :         s[sp] = _graph->insert(aten::atan2, {s[sp], s[sp + 1]});
     475           4 :         break;
     476           0 :       case cAtanh:
     477           0 :         s[sp] = _graph->insert(aten::atanh, {s[sp]});
     478           0 :         break;
     479             : 
     480         118 :       case cFetch:
     481         118 :         ++sp;
     482         118 :         s[sp] = s[ByteCode[++i]];
     483         118 :         break;
     484         112 :       case cDup:
     485         112 :         ++sp;
     486         112 :         s[sp] = s[sp - 1];
     487         112 :         break;
     488             : 
     489             : #ifdef FP_SUPPORT_OPTIMIZER
     490           8 :       case cPopNMov:
     491             :       {
     492           8 :         int dst = ByteCode[++i], src = ByteCode[++i];
     493           8 :         s[dst] = s[src];
     494             :         sp = dst;
     495           8 :         break;
     496             :       }
     497           0 :       case cLog2by:
     498           0 :         --sp;
     499           0 :         s[sp] = _graph->insert(aten::mul, {_graph->insert(aten::log2, {s[sp]}), s[sp + 1]});
     500           0 :         break;
     501             :       case cNop:
     502             :         break;
     503             : #endif
     504             : 
     505        1024 :       default:
     506        1024 :         if (op >= VarBegin)
     507             :         {
     508             :           // // load variable
     509        1024 :           ++sp;
     510        1024 :           s[sp] = _input[op - VarBegin];
     511             :         }
     512             :         else
     513             :         {
     514           0 :           throw std::runtime_error("JIT Opcode " + FP_GetOpcodeName(op) +
     515           0 :                                    " not supported for libtorch tensors.");
     516             :         }
     517             :     }
     518             :   }
     519             : 
     520         398 :   auto outputs = s[sp]->node()->outputs();
     521         796 :   for (auto output : outputs)
     522             :     _graph->registerOutput(output);
     523             : 
     524             :   // make sure graph is well formed
     525         398 :   _graph->lint();
     526             : 
     527             :   // optimization
     528         398 :   EliminateDeadCode(_graph); // Tracing of some ops depends on the DCE trick
     529         398 :   ConstantPropagation(_graph);
     530         398 :   EliminateCommonSubexpression(_graph);
     531         398 :   FuseGraph(_graph, true);
     532        1998 : }
     533             : 
     534             : namespace
     535             : {
     536             : template <class... Inputs>
     537             : inline std::vector<c10::IValue>
     538             : makeStack(Inputs &&... inputs)
     539             : {
     540             :   return {std::forward<Inputs>(inputs)...};
     541             : }
     542             : }
     543             : 
     544             : torch::Tensor
     545      205598 : ParsedJITTensor::Eval(const std::vector<const torch::Tensor *> & params)
     546             : {
     547             :   using namespace torch::jit;
     548             : 
     549             :   // build stack
     550             :   Stack stack;
     551      692400 :   for (const auto & p : params)
     552      973604 :     stack.push_back(*p);
     553             : 
     554      205598 :   if (_input.size() != params.size())
     555           0 :     throw std::runtime_error("Unexpected number of inputs in ParsedJITTensor::Eval.");
     556             : 
     557             :   // disable autograd
     558      205598 :   torch::NoGradGuard no_grad;
     559             : 
     560      205598 :   if (!_graph_executor)
     561         672 :     _graph_executor = std::make_shared<GraphExecutor>(_graph, "F");
     562             : 
     563      205598 :   _graph_executor->run(stack);
     564             : 
     565      205598 :   if (stack.size() != 1)
     566           0 :     throw std::runtime_error("Unexpected number vof outputs in ParsedJITTensor::Eval.");
     567             : 
     568      205598 :   return stack[0].toTensor();
     569      205598 : }

Generated by: LCOV version 1.14