Line data Source code
1 : /**********************************************************************/ 2 : /* DO NOT MODIFY THIS HEADER */ 3 : /* Swift, a Fourier spectral solver for MOOSE */ 4 : /* */ 5 : /* Copyright 2024 Battelle Energy Alliance, LLC */ 6 : /* ALL RIGHTS RESERVED */ 7 : /**********************************************************************/ 8 : 9 : #pragma once 10 : 11 : #include "ParsedTensor.h" 12 : #include "libmesh/extrasrc/fptypes.hh" 13 : 14 360 : ParsedTensor::ParsedTensor() : FunctionParserAD(), _data(*getParserData()) 15 : { 16 360 : _mFFT = _data.mFuncPtrs.size(); 17 360 : this->AddFunction("FFT", fp_dummy, 1); 18 360 : _miFFT = _data.mFuncPtrs.size(); 19 360 : this->AddFunction("iFFT", fp_dummy, 1); 20 360 : } 21 : 22 : Real 23 0 : ParsedTensor::fp_dummy(const Real *) 24 : { 25 0 : throw std::runtime_error("This function is only implemented for torch tensors"); 26 : } 27 : 28 : void 29 100 : ParsedTensor::setupTensors() 30 : { 31 : // allocate stack 32 100 : s.resize(_data.mStackSize); 33 : 34 : // convert immediate data 35 100 : tensor_immed.clear(); 36 144 : for (const auto & i : _data.mImmed) 37 88 : tensor_immed.push_back(torch::tensor(i, MooseTensor::floatTensorOptions())); 38 100 : } 39 : 40 : torch::Tensor 41 100 : ParsedTensor::Eval(const std::vector<const torch::Tensor *> & params) 42 : { 43 : using namespace FUNCTIONPARSERTYPES; 44 : 45 : // get a reference to the stored bytecode 46 100 : const auto & ByteCode = _data.mByteCode; 47 : 48 : int nImmed = 0, sp = -1, op; 49 522 : for (unsigned int i = 0; i < ByteCode.size(); ++i) 50 : { 51 : // execute bytecode 52 422 : switch (op = ByteCode[i]) 53 : { 54 44 : case cImmed: 55 44 : ++sp; 56 44 : s[sp] = tensor_immed[nImmed++]; 57 : break; 58 28 : case cAdd: 59 28 : --sp; 60 28 : s[sp] = s[sp] + s[sp + 1]; 61 28 : break; 62 8 : case cSub: 63 8 : --sp; 64 8 : s[sp] = s[sp] - s[sp + 1]; 65 8 : break; 66 12 : case cRSub: 67 12 : --sp; 68 12 : s[sp] = s[sp + 1] - s[sp]; 69 12 : break; 70 32 : case cMul: 71 32 : --sp; 72 32 : s[sp] = s[sp] * s[sp + 1]; 73 32 : break; 74 16 : case cDiv: 75 16 : --sp; 76 16 : s[sp] = s[sp] / s[sp + 1]; 77 16 : break; 78 0 : case cMod: 79 0 : --sp; 80 0 : s[sp] = fmod(s[sp], s[sp + 1]); 81 0 : break; 82 0 : case cRDiv: 83 0 : --sp; 84 0 : s[sp] = s[sp + 1] / s[sp]; 85 0 : break; 86 : 87 6 : case cSin: 88 6 : s[sp] = sin(s[sp]); 89 6 : break; 90 6 : case cCos: 91 6 : s[sp] = cos(s[sp]); 92 6 : break; 93 0 : case cTan: 94 0 : s[sp] = tan(s[sp]); 95 0 : break; 96 4 : case cSinh: 97 4 : s[sp] = sinh(s[sp]); 98 4 : break; 99 4 : case cCosh: 100 4 : s[sp] = cosh(s[sp]); 101 4 : break; 102 8 : case cTanh: 103 8 : s[sp] = tanh(s[sp]); 104 8 : break; 105 0 : case cCsc: 106 0 : s[sp] = 1.0 / sin(s[sp]); 107 0 : break; 108 0 : case cSec: 109 0 : s[sp] = 1.0 / cos(s[sp]); 110 0 : break; 111 0 : case cCot: 112 0 : s[sp] = 1.0 / tan(s[sp]); 113 0 : break; 114 2 : case cSinCos: 115 4 : s[sp + 1] = cos(s[sp]); 116 2 : s[sp] = sin(s[sp]); 117 : ++sp; 118 2 : break; 119 0 : case cSinhCosh: 120 0 : s[sp + 1] = cosh(s[sp]); 121 0 : s[sp] = sinh(s[sp]); 122 : ++sp; 123 0 : break; 124 4 : case cAsin: 125 4 : s[sp] = asin(s[sp]); 126 4 : break; 127 4 : case cAcos: 128 4 : s[sp] = acos(s[sp]); 129 4 : break; 130 4 : case cAsinh: 131 4 : s[sp] = asinh(s[sp]); 132 4 : break; 133 4 : case cAcosh: 134 4 : s[sp] = acosh(s[sp]); 135 4 : break; 136 4 : case cAtan: 137 4 : s[sp] = atan(s[sp]); 138 4 : break; 139 0 : case cAtanh: 140 0 : s[sp] = atanh(s[sp]); 141 0 : break; 142 4 : case cAtan2: 143 4 : --sp; 144 4 : s[sp] = atan2(s[sp], s[sp + 1]); 145 4 : break; 146 12 : case cHypot: 147 12 : --sp; 148 12 : s[sp] = torch::hypot(s[sp], s[sp + 1]); 149 12 : break; 150 : 151 4 : case cAbs: 152 4 : s[sp] = abs(s[sp]); 153 4 : break; 154 4 : case cMax: 155 4 : --sp; 156 4 : s[sp] = torch::maximum(s[sp], s[sp + 1]); 157 4 : break; 158 4 : case cMin: 159 4 : --sp; 160 4 : s[sp] = torch::minimum(s[sp], s[sp + 1]); 161 4 : break; 162 0 : case cTrunc: 163 0 : s[sp] = torch::trunc(s[sp]); 164 0 : break; 165 0 : case cCeil: 166 0 : s[sp] = torch::ceil(s[sp]); 167 0 : break; 168 0 : case cFloor: 169 0 : s[sp] = torch::floor(s[sp]); 170 0 : break; 171 0 : case cInt: 172 0 : s[sp] = torch::round(s[sp]); 173 0 : break; 174 : 175 : // case cEqual: 176 : // //--sp; s[sp] = s[sp] == s[sp+1]; break; 177 : // --sp; 178 : // s[sp] = abs(s[sp] - s[sp + 1]) <= eps; 179 : // break; 180 : // case cNEqual: 181 : // //--sp; s[sp] = s[sp] != s[sp+1]; break; 182 : // --sp; 183 : // s[sp] = abs(s[sp] - s[sp + 1]) > eps; 184 : // break; 185 : // case cLess: 186 : // --sp; 187 : // s[sp] = s[sp] < (s[sp + 1] - eps); 188 : // break; 189 : // case cLessOrEq: 190 : // --sp; 191 : // s[sp] = s[sp] <= (s[sp + 1] + eps); 192 : // break; 193 : // case cGreater: 194 : // --sp; 195 : // s[sp] = (s[sp] - eps) > s[sp + 1]; 196 : // break; 197 : // case cGreaterOrEq: 198 : // --sp; 199 : // s[sp] = (s[sp] + eps) >= s[sp + 1]; 200 : // break; 201 : // case cNot: 202 : // s[sp] = abs(s[sp]) < 0.5; 203 : // break; 204 : // case cNotNot: 205 : // s[sp] = abs(s[sp]) >= 0.5; 206 : // break; 207 : // case cAbsNot: 208 : // s[sp] = s[sp] < 0.5; 209 : // break; 210 : // case cAbsNotNot: 211 : // s[sp] = s[sp] >= 0.5; 212 : // break; 213 : // case cOr: 214 : // --sp; 215 : // s[sp] = (abs(s[sp]) >= 0.5) || (abs(s[sp + 1]) >= 0.5); 216 : // break; 217 : // case cAbsOr: 218 : // --sp; 219 : // s[sp] = (s[sp] >= 0.5) || (s[sp + 1] >= 0.5); 220 : // break; 221 : // case cAnd: 222 : // --sp; 223 : // s[sp] = (abs(s[sp]) >= 0.5) && (abs(s[sp + 1]) >= 0.5); 224 : // break; 225 : // case cAbsAnd: 226 : // --sp; 227 : // s[sp] = (s[sp] >= 0.5) && (s[sp + 1] >= 0.5); 228 : // break; 229 : 230 4 : case cLog: 231 4 : s[sp] = torch::log(s[sp]); 232 4 : break; 233 : case cLog2: 234 : #ifdef FP_SUPPORT_CPLUSPLUS11_MATH_FUNCS 235 : s[sp - 1] = torch::log2(s[sp - 1]); 236 : #else 237 0 : s[sp] = torch::log(s[sp]) / log(2.0); 238 : #endif 239 0 : break; 240 0 : case cLog10: 241 0 : s[sp] = torch::log10(s[sp]); 242 0 : break; 243 : 244 4 : case cNeg: 245 4 : s[sp] = -s[sp]; 246 4 : break; 247 0 : case cInv: 248 0 : s[sp] = 1.0 / s[sp]; 249 0 : break; 250 : case cDeg: 251 0 : s[sp] = s[sp] * (180.0 / libMesh::pi); 252 0 : break; 253 : case cRad: 254 0 : s[sp] = s[sp] / (180.0 / libMesh::pi); 255 0 : break; 256 : 257 2 : case cFetch: 258 2 : ++sp; 259 2 : s[sp] = s[ByteCode[++i]]; 260 : break; 261 2 : case cDup: 262 2 : ++sp; 263 2 : s[sp] = s[sp - 1]; 264 : break; 265 : 266 0 : case cFCall: 267 : { 268 0 : auto function = ByteCode[++i]; 269 0 : if (function == _mFFT) 270 : { 271 0 : if (s[sp].dim() == 1) 272 0 : s[sp] = torch::fft::rfft(s[sp]); 273 0 : else if (s[sp].dim() == 2) 274 0 : s[sp] = torch::fft::rfft2(s[sp]); 275 : else 276 0 : throw std::domain_error("3D not implemented yet"); 277 : } 278 0 : else if (function == _miFFT) 279 : { 280 0 : if (s[sp].dim() == 1) 281 0 : s[sp] = torch::fft::irfft(s[sp]); 282 0 : else if (s[sp].dim() == 2) 283 0 : s[sp] = torch::fft::irfft2(s[sp]); 284 : else 285 0 : throw std::domain_error("3D not implemented yet"); 286 : } 287 : else 288 0 : throw std::runtime_error("Function call not supported for libtorch tensors."); 289 : } 290 : break; 291 : 292 : #ifdef FP_SUPPORT_OPTIMIZER 293 0 : case cPopNMov: 294 : { 295 0 : int dst = ByteCode[++i], src = ByteCode[++i]; 296 0 : s[dst] = s[src]; 297 : sp = dst; 298 0 : break; 299 : } 300 0 : case cLog2by: 301 0 : --sp; 302 0 : s[sp] = (torch::log(s[sp]) / std::log(2.0)) * s[sp + 1]; 303 0 : break; 304 : case cNop: 305 : break; 306 : #endif 307 : 308 0 : case cSqr: 309 0 : s[sp] = s[sp] * s[sp]; 310 0 : break; 311 0 : case cSqrt: 312 0 : s[sp] = torch::sqrt(s[sp]); 313 0 : break; 314 : case cRSqrt: 315 4 : s[sp] = torch::pow(s[sp], -0.5); 316 4 : break; 317 4 : case cPow: 318 4 : --sp; 319 4 : s[sp] = torch::pow(s[sp], s[sp + 1]); 320 4 : break; 321 4 : case cExp: 322 4 : s[sp] = torch::exp(s[sp]); 323 4 : break; 324 0 : case cExp2: 325 0 : s[sp] = torch::pow(2.0, s[sp]); 326 0 : break; 327 : case cCbrt: 328 4 : s[sp] = torch::pow(s[sp], 1.0 / 3.0); 329 4 : break; 330 : 331 0 : case cJump: 332 : case cIf: 333 : case cAbsIf: 334 : { 335 0 : throw std::domain_error("Conditionals not implemented yet"); 336 : // unsigned long ip = ByteCode[++i] + 1; 337 : 338 : // if (op == cIf) 339 : // ccout << "if (abs(s[sp--]) < 0.5); 340 : // if (op == cAbsIf) 341 : // ccout << "if (s[" << sp-- << "] < 0.5) "; 342 : 343 : // if (ip >= ByteCode.size()) 344 : // ccout << "*ret = s[sp]; return; 345 : // else 346 : // { 347 : // ccout << "goto l" << ip << "; 348 : // stackAtTarget[ip] = sp; 349 : // } 350 : 351 : // ++i; 352 : // break; 353 : } 354 : 355 176 : default: 356 176 : if (op >= VarBegin) 357 : { 358 : // load variable 359 176 : ++sp; 360 176 : s[sp] = *params[op - VarBegin]; 361 : } 362 : else 363 : { 364 0 : throw std::runtime_error("Opcode not supported for libtorch tensors."); 365 : } 366 : } 367 : } 368 : 369 100 : return s[sp]; 370 : }