LCOV - code coverage report
Current view: top level - src/utils - ParsedTensor.C (source / functions) Hit Total Coverage
Test: idaholab/swift: #92 (25e020) with base b3cd84 Lines: 120 208 57.7 %
Date: 2025-09-10 17:10:32 Functions: 3 4 75.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /**********************************************************************/
       2             : /*                    DO NOT MODIFY THIS HEADER                       */
       3             : /*             Swift, a Fourier spectral solver for MOOSE             */
       4             : /*                                                                    */
       5             : /*            Copyright 2024 Battelle Energy Alliance, LLC            */
       6             : /*                        ALL RIGHTS RESERVED                         */
       7             : /**********************************************************************/
       8             : 
       9             : #pragma once
      10             : 
      11             : #include "ParsedTensor.h"
      12             : #include "libmesh/extrasrc/fptypes.hh"
      13             : 
      14         360 : ParsedTensor::ParsedTensor() : FunctionParserAD(), _data(*getParserData())
      15             : {
      16         360 :   _mFFT = _data.mFuncPtrs.size();
      17         360 :   this->AddFunction("FFT", fp_dummy, 1);
      18         360 :   _miFFT = _data.mFuncPtrs.size();
      19         360 :   this->AddFunction("iFFT", fp_dummy, 1);
      20         360 : }
      21             : 
      22             : Real
      23           0 : ParsedTensor::fp_dummy(const Real *)
      24             : {
      25           0 :   throw std::runtime_error("This function is only implemented for torch tensors");
      26             : }
      27             : 
      28             : void
      29         100 : ParsedTensor::setupTensors()
      30             : {
      31             :   // allocate stack
      32         100 :   s.resize(_data.mStackSize);
      33             : 
      34             :   // convert immediate data
      35         100 :   tensor_immed.clear();
      36         144 :   for (const auto & i : _data.mImmed)
      37          88 :     tensor_immed.push_back(torch::tensor(i, MooseTensor::floatTensorOptions()));
      38         100 : }
      39             : 
      40             : torch::Tensor
      41         100 : ParsedTensor::Eval(const std::vector<const torch::Tensor *> & params)
      42             : {
      43             :   using namespace FUNCTIONPARSERTYPES;
      44             : 
      45             :   // get a reference to the stored bytecode
      46         100 :   const auto & ByteCode = _data.mByteCode;
      47             : 
      48             :   int nImmed = 0, sp = -1, op;
      49         522 :   for (unsigned int i = 0; i < ByteCode.size(); ++i)
      50             :   {
      51             :     // execute bytecode
      52         422 :     switch (op = ByteCode[i])
      53             :     {
      54          44 :       case cImmed:
      55          44 :         ++sp;
      56          44 :         s[sp] = tensor_immed[nImmed++];
      57             :         break;
      58          28 :       case cAdd:
      59          28 :         --sp;
      60          28 :         s[sp] = s[sp] + s[sp + 1];
      61          28 :         break;
      62           8 :       case cSub:
      63           8 :         --sp;
      64           8 :         s[sp] = s[sp] - s[sp + 1];
      65           8 :         break;
      66          12 :       case cRSub:
      67          12 :         --sp;
      68          12 :         s[sp] = s[sp + 1] - s[sp];
      69          12 :         break;
      70          32 :       case cMul:
      71          32 :         --sp;
      72          32 :         s[sp] = s[sp] * s[sp + 1];
      73          32 :         break;
      74          16 :       case cDiv:
      75          16 :         --sp;
      76          16 :         s[sp] = s[sp] / s[sp + 1];
      77          16 :         break;
      78           0 :       case cMod:
      79           0 :         --sp;
      80           0 :         s[sp] = fmod(s[sp], s[sp + 1]);
      81           0 :         break;
      82           0 :       case cRDiv:
      83           0 :         --sp;
      84           0 :         s[sp] = s[sp + 1] / s[sp];
      85           0 :         break;
      86             : 
      87           6 :       case cSin:
      88           6 :         s[sp] = sin(s[sp]);
      89           6 :         break;
      90           6 :       case cCos:
      91           6 :         s[sp] = cos(s[sp]);
      92           6 :         break;
      93           0 :       case cTan:
      94           0 :         s[sp] = tan(s[sp]);
      95           0 :         break;
      96           4 :       case cSinh:
      97           4 :         s[sp] = sinh(s[sp]);
      98           4 :         break;
      99           4 :       case cCosh:
     100           4 :         s[sp] = cosh(s[sp]);
     101           4 :         break;
     102           8 :       case cTanh:
     103           8 :         s[sp] = tanh(s[sp]);
     104           8 :         break;
     105           0 :       case cCsc:
     106           0 :         s[sp] = 1.0 / sin(s[sp]);
     107           0 :         break;
     108           0 :       case cSec:
     109           0 :         s[sp] = 1.0 / cos(s[sp]);
     110           0 :         break;
     111           0 :       case cCot:
     112           0 :         s[sp] = 1.0 / tan(s[sp]);
     113           0 :         break;
     114           2 :       case cSinCos:
     115           4 :         s[sp + 1] = cos(s[sp]);
     116           2 :         s[sp] = sin(s[sp]);
     117             :         ++sp;
     118           2 :         break;
     119           0 :       case cSinhCosh:
     120           0 :         s[sp + 1] = cosh(s[sp]);
     121           0 :         s[sp] = sinh(s[sp]);
     122             :         ++sp;
     123           0 :         break;
     124           4 :       case cAsin:
     125           4 :         s[sp] = asin(s[sp]);
     126           4 :         break;
     127           4 :       case cAcos:
     128           4 :         s[sp] = acos(s[sp]);
     129           4 :         break;
     130           4 :       case cAsinh:
     131           4 :         s[sp] = asinh(s[sp]);
     132           4 :         break;
     133           4 :       case cAcosh:
     134           4 :         s[sp] = acosh(s[sp]);
     135           4 :         break;
     136           4 :       case cAtan:
     137           4 :         s[sp] = atan(s[sp]);
     138           4 :         break;
     139           0 :       case cAtanh:
     140           0 :         s[sp] = atanh(s[sp]);
     141           0 :         break;
     142           4 :       case cAtan2:
     143           4 :         --sp;
     144           4 :         s[sp] = atan2(s[sp], s[sp + 1]);
     145           4 :         break;
     146          12 :       case cHypot:
     147          12 :         --sp;
     148          12 :         s[sp] = torch::hypot(s[sp], s[sp + 1]);
     149          12 :         break;
     150             : 
     151           4 :       case cAbs:
     152           4 :         s[sp] = abs(s[sp]);
     153           4 :         break;
     154           4 :       case cMax:
     155           4 :         --sp;
     156           4 :         s[sp] = torch::maximum(s[sp], s[sp + 1]);
     157           4 :         break;
     158           4 :       case cMin:
     159           4 :         --sp;
     160           4 :         s[sp] = torch::minimum(s[sp], s[sp + 1]);
     161           4 :         break;
     162           0 :       case cTrunc:
     163           0 :         s[sp] = torch::trunc(s[sp]);
     164           0 :         break;
     165           0 :       case cCeil:
     166           0 :         s[sp] = torch::ceil(s[sp]);
     167           0 :         break;
     168           0 :       case cFloor:
     169           0 :         s[sp] = torch::floor(s[sp]);
     170           0 :         break;
     171           0 :       case cInt:
     172           0 :         s[sp] = torch::round(s[sp]);
     173           0 :         break;
     174             : 
     175             :         // case cEqual:
     176             :         //   //--sp; s[sp] = s[sp] == s[sp+1]; break;
     177             :         //   --sp;
     178             :         //   s[sp] = abs(s[sp] - s[sp + 1]) <= eps;
     179             :         //   break;
     180             :         // case cNEqual:
     181             :         //   //--sp; s[sp] = s[sp] != s[sp+1]; break;
     182             :         //   --sp;
     183             :         //   s[sp] = abs(s[sp] - s[sp + 1]) > eps;
     184             :         //   break;
     185             :         // case cLess:
     186             :         //   --sp;
     187             :         //   s[sp] = s[sp] < (s[sp + 1] - eps);
     188             :         //   break;
     189             :         // case cLessOrEq:
     190             :         //   --sp;
     191             :         //   s[sp] = s[sp] <= (s[sp + 1] + eps);
     192             :         //   break;
     193             :         // case cGreater:
     194             :         //   --sp;
     195             :         //   s[sp] = (s[sp] - eps) > s[sp + 1];
     196             :         //   break;
     197             :         // case cGreaterOrEq:
     198             :         //   --sp;
     199             :         //   s[sp] = (s[sp] + eps) >= s[sp + 1];
     200             :         //   break;
     201             :         // case cNot:
     202             :         //   s[sp] = abs(s[sp]) < 0.5;
     203             :         //   break;
     204             :         // case cNotNot:
     205             :         //   s[sp] = abs(s[sp]) >= 0.5;
     206             :         //   break;
     207             :         // case cAbsNot:
     208             :         //   s[sp] = s[sp] < 0.5;
     209             :         //   break;
     210             :         // case cAbsNotNot:
     211             :         //   s[sp] = s[sp] >= 0.5;
     212             :         //   break;
     213             :         // case cOr:
     214             :         //   --sp;
     215             :         //   s[sp] = (abs(s[sp]) >= 0.5) || (abs(s[sp + 1]) >= 0.5);
     216             :         //   break;
     217             :         // case cAbsOr:
     218             :         //   --sp;
     219             :         //   s[sp] = (s[sp] >= 0.5) || (s[sp + 1] >= 0.5);
     220             :         //   break;
     221             :         // case cAnd:
     222             :         //   --sp;
     223             :         //   s[sp] = (abs(s[sp]) >= 0.5) && (abs(s[sp + 1]) >= 0.5);
     224             :         //   break;
     225             :         // case cAbsAnd:
     226             :         //   --sp;
     227             :         //   s[sp] = (s[sp] >= 0.5) && (s[sp + 1] >= 0.5);
     228             :         //   break;
     229             : 
     230           4 :       case cLog:
     231           4 :         s[sp] = torch::log(s[sp]);
     232           4 :         break;
     233             :       case cLog2:
     234             : #ifdef FP_SUPPORT_CPLUSPLUS11_MATH_FUNCS
     235             :         s[sp - 1] = torch::log2(s[sp - 1]);
     236             : #else
     237           0 :         s[sp] = torch::log(s[sp]) / log(2.0);
     238             : #endif
     239           0 :         break;
     240           0 :       case cLog10:
     241           0 :         s[sp] = torch::log10(s[sp]);
     242           0 :         break;
     243             : 
     244           4 :       case cNeg:
     245           4 :         s[sp] = -s[sp];
     246           4 :         break;
     247           0 :       case cInv:
     248           0 :         s[sp] = 1.0 / s[sp];
     249           0 :         break;
     250             :       case cDeg:
     251           0 :         s[sp] = s[sp] * (180.0 / libMesh::pi);
     252           0 :         break;
     253             :       case cRad:
     254           0 :         s[sp] = s[sp] / (180.0 / libMesh::pi);
     255           0 :         break;
     256             : 
     257           2 :       case cFetch:
     258           2 :         ++sp;
     259           2 :         s[sp] = s[ByteCode[++i]];
     260             :         break;
     261           2 :       case cDup:
     262           2 :         ++sp;
     263           2 :         s[sp] = s[sp - 1];
     264             :         break;
     265             : 
     266           0 :       case cFCall:
     267             :       {
     268           0 :         auto function = ByteCode[++i];
     269           0 :         if (function == _mFFT)
     270             :         {
     271           0 :           if (s[sp].dim() == 1)
     272           0 :             s[sp] = torch::fft::rfft(s[sp]);
     273           0 :           else if (s[sp].dim() == 2)
     274           0 :             s[sp] = torch::fft::rfft2(s[sp]);
     275             :           else
     276           0 :             throw std::domain_error("3D not implemented yet");
     277             :         }
     278           0 :         else if (function == _miFFT)
     279             :         {
     280           0 :           if (s[sp].dim() == 1)
     281           0 :             s[sp] = torch::fft::irfft(s[sp]);
     282           0 :           else if (s[sp].dim() == 2)
     283           0 :             s[sp] = torch::fft::irfft2(s[sp]);
     284             :           else
     285           0 :             throw std::domain_error("3D not implemented yet");
     286             :         }
     287             :         else
     288           0 :           throw std::runtime_error("Function call not supported for libtorch tensors.");
     289             :       }
     290             :       break;
     291             : 
     292             : #ifdef FP_SUPPORT_OPTIMIZER
     293           0 :       case cPopNMov:
     294             :       {
     295           0 :         int dst = ByteCode[++i], src = ByteCode[++i];
     296           0 :         s[dst] = s[src];
     297             :         sp = dst;
     298           0 :         break;
     299             :       }
     300           0 :       case cLog2by:
     301           0 :         --sp;
     302           0 :         s[sp] = (torch::log(s[sp]) / std::log(2.0)) * s[sp + 1];
     303           0 :         break;
     304             :       case cNop:
     305             :         break;
     306             : #endif
     307             : 
     308           0 :       case cSqr:
     309           0 :         s[sp] = s[sp] * s[sp];
     310           0 :         break;
     311           0 :       case cSqrt:
     312           0 :         s[sp] = torch::sqrt(s[sp]);
     313           0 :         break;
     314             :       case cRSqrt:
     315           4 :         s[sp] = torch::pow(s[sp], -0.5);
     316           4 :         break;
     317           4 :       case cPow:
     318           4 :         --sp;
     319           4 :         s[sp] = torch::pow(s[sp], s[sp + 1]);
     320           4 :         break;
     321           4 :       case cExp:
     322           4 :         s[sp] = torch::exp(s[sp]);
     323           4 :         break;
     324           0 :       case cExp2:
     325           0 :         s[sp] = torch::pow(2.0, s[sp]);
     326           0 :         break;
     327             :       case cCbrt:
     328           4 :         s[sp] = torch::pow(s[sp], 1.0 / 3.0);
     329           4 :         break;
     330             : 
     331           0 :       case cJump:
     332             :       case cIf:
     333             :       case cAbsIf:
     334             :       {
     335           0 :         throw std::domain_error("Conditionals not implemented yet");
     336             :         // unsigned long ip = ByteCode[++i] + 1;
     337             : 
     338             :         // if (op == cIf)
     339             :         //   ccout << "if (abs(s[sp--]) < 0.5);
     340             :         // if (op == cAbsIf)
     341             :         //   ccout << "if (s[" << sp-- << "] < 0.5) ";
     342             : 
     343             :         // if (ip >= ByteCode.size())
     344             :         //   ccout << "*ret = s[sp]; return;
     345             :         // else
     346             :         // {
     347             :         //   ccout << "goto l" << ip << ";
     348             :         //   stackAtTarget[ip] = sp;
     349             :         // }
     350             : 
     351             :         // ++i;
     352             :         // break;
     353             :       }
     354             : 
     355         176 :       default:
     356         176 :         if (op >= VarBegin)
     357             :         {
     358             :           // load variable
     359         176 :           ++sp;
     360         176 :           s[sp] = *params[op - VarBegin];
     361             :         }
     362             :         else
     363             :         {
     364           0 :           throw std::runtime_error("Opcode not supported for libtorch tensors.");
     365             :         }
     366             :     }
     367             :   }
     368             : 
     369         100 :   return s[sp];
     370             : }

Generated by: LCOV version 1.14