Line data Source code
1 : //* This file is part of the MOOSE framework
2 : //* https://mooseframework.inl.gov
3 : //*
4 : //* All rights reserved, see COPYRIGHT for full restrictions
5 : //* https://github.com/idaholab/moose/blob/master/COPYRIGHT
6 : //*
7 : //* Licensed under LGPL 2.1, please see LICENSE for details
8 : //* https://www.gnu.org/licenses/lgpl-2.1.html
9 :
10 : #pragma once
11 :
12 : #include "libmesh/utility.h"
13 : #include "libmesh/compare_types.h"
14 : #include "Conversion.h"
15 : #include <limits>
16 :
17 : namespace CompileTimeDerivatives
18 : {
19 :
20 : /**
21 : * All compile time derivative system objects derive from this (empty) base class.
22 : * This allows to restrict templated operators to only act on compile time derivative
23 : * system objects (using std::enable_if).
24 : */
25 : class CTBase
26 : {
27 : public:
28 : /// precedence should reflect C++ operator precedence exactly (higher is binding tighter)
29 10 : constexpr static int precedence() { return 0; }
30 : /// left/right associative property should reflect C++ operator properties exactly
31 0 : constexpr static bool leftAssociative() { return false; }
32 : };
33 : class CTNullBase : public CTBase
34 : {
35 : };
36 : class CTOneBase : public CTBase
37 : {
38 : };
39 :
40 : template <typename T>
41 : using CTCleanType = typename std::remove_const<typename std::remove_reference<T>::type>::type;
42 :
43 : template <typename... Ts>
44 : struct CTSuperType;
45 :
46 : template <typename T>
47 : struct CTSuperType<T>
48 : {
49 : typedef T type;
50 : };
51 :
52 : template <typename T1, typename T2, typename... Ts>
53 : struct CTSuperType<T1, T2, Ts...>
54 : {
55 : typedef typename std::conditional<
56 : (sizeof...(Ts) > 0),
57 : typename CTSuperType<typename libMesh::CompareTypes<T1, T2>::supertype, Ts...>::type,
58 : typename libMesh::CompareTypes<T1, T2>::supertype>::type type;
59 : };
60 :
61 : /**
62 : * Operators representing variable values need to be tagged. The tag identifies
63 : * the object when a compile time derivative is taken. We use a type that can be
64 : * supplied as a template argument. int does the trick.
65 : */
66 : using CTTag = int;
67 : constexpr CTTag CTNoTag = std::numeric_limits<CTTag>::max();
68 :
69 : template <CTTag tag>
70 : std::string
71 8 : printTag()
72 : {
73 : if constexpr (tag == CTNoTag)
74 8 : return "";
75 : else
76 : return Moose::stringify(tag);
77 : }
78 :
79 : template <typename B, typename E>
80 : auto pow(const B & base, const E & exp);
81 : template <int E, typename B>
82 : auto pow(const B &);
83 :
84 : template <typename T>
85 : auto exp(const T &);
86 : template <typename T>
87 : auto log(const T &);
88 :
89 : /**
90 : * Template class to represent a "zero" value. Having zeroness associated with a type
91 : * enables compile time optimizations in some operators.
92 : */
93 : template <typename T>
94 : class CTNull : public CTNullBase
95 : {
96 : public:
97 1375149 : CTNull() {}
98 : typedef CTCleanType<T> ResultType;
99 :
100 6 : ResultType operator()() const { return ResultType(0); }
101 : std::string print() const { return "0"; }
102 :
103 : template <CTTag dtag>
104 367400 : auto D() const
105 : {
106 367400 : return CTNull<ResultType>();
107 : }
108 : };
109 :
110 : /**
111 : * Template class to represent a "one" value. Having oneness associated with a type
112 : * enables compile time optimizations in some operators.
113 : */
114 : template <typename T>
115 : class CTOne : public CTOneBase
116 : {
117 : public:
118 237881 : CTOne() {}
119 : typedef CTCleanType<T> ResultType;
120 :
121 114617 : ResultType operator()() const { return ResultType(1); }
122 : std::string print() const { return "1"; }
123 :
124 : template <CTTag dtag>
125 112720 : auto D() const
126 : {
127 112720 : return CTNull<ResultType>();
128 : }
129 : };
130 :
131 : /**
132 : * Base class for a unary operator/function
133 : */
134 : template <typename T>
135 : class CTUnary : public CTBase
136 : {
137 : public:
138 567691 : CTUnary(T arg) : _arg(arg) {}
139 :
140 : template <typename Self>
141 : std::string printParens(const Self *, const std::string & op) const
142 : {
143 : std::string out = op;
144 : if constexpr (T::precedence() > Self::precedence())
145 : out += "(" + _arg.print() + ")";
146 : else
147 : out += _arg.print();
148 :
149 : return out;
150 : }
151 :
152 : typedef typename T::ResultType ResultType;
153 :
154 : protected:
155 : const T _arg;
156 : };
157 :
158 : /**
159 : * Unary minus
160 : */
161 : template <typename T>
162 : class CTUnaryMinus : public CTUnary<T>
163 : {
164 : public:
165 169 : CTUnaryMinus(T arg) : CTUnary<T>(arg) {}
166 : using typename CTUnary<T>::ResultType;
167 :
168 254 : ResultType operator()() const { return -_arg(); }
169 : std::string print() const { return this->printParens(this, "-"); }
170 : constexpr static int precedence() { return 3; }
171 :
172 : template <CTTag dtag>
173 58 : auto D() const
174 : {
175 58 : return -_arg.template D<dtag>();
176 : }
177 :
178 : using CTUnary<T>::_arg;
179 : };
180 :
181 : template <typename T, class = std::enable_if_t<std::is_base_of<CTBase, T>::value>>
182 : auto
183 169 : operator-(const T & arg)
184 : {
185 169 : return CTUnaryMinus<T>(arg);
186 : }
187 :
188 : /**
189 : * Base class for a binary operator/function
190 : */
191 : template <typename L, typename R>
192 : class CTBinary : public CTBase
193 : {
194 : public:
195 4817655 : CTBinary(L left, R right) : _left(left), _right(right) {}
196 :
197 : typedef typename libMesh::CompareTypes<typename L::ResultType, typename R::ResultType>::supertype
198 : ResultType;
199 :
200 : template <typename Self>
201 10 : std::string printParens(const Self *, const std::string & op) const
202 : {
203 10 : std::string out;
204 : if constexpr (L::precedence() > Self::precedence())
205 1 : out = "(" + _left.print() + ")";
206 : else
207 9 : out = _left.print();
208 :
209 10 : out += op;
210 :
211 20 : if (R::precedence() > Self::precedence() ||
212 10 : (R::precedence() == Self::precedence() && Self::leftAssociative()))
213 2 : out += "(" + _right.print() + ")";
214 : else
215 8 : out += _right.print();
216 :
217 10 : return out;
218 0 : }
219 :
220 : protected:
221 : const L _left;
222 : const R _right;
223 : };
224 :
225 : template <typename C, typename L, typename R>
226 : auto conditional(const C &, const L &, const R &);
227 :
228 : /**
229 : * Base class for a ternary functions
230 : */
231 : template <typename C, typename L, typename R>
232 : class CTConditional : public CTBinary<L, R>
233 : {
234 : public:
235 31 : CTConditional(C condition, L left, R right) : CTBinary<L, R>(left, right), _condition(condition)
236 : {
237 31 : }
238 : using typename CTBinary<L, R>::ResultType;
239 :
240 50 : auto operator()() const { return _condition() ? _left() : _right(); }
241 : template <CTTag dtag>
242 : auto D() const
243 : {
244 : return conditional(_condition, _left.template D<dtag>(), _right.template D<dtag>());
245 : }
246 :
247 : template <typename Self>
248 : std::string print() const
249 : {
250 : return "conditional(" + _condition.print() + ", " + _left.print() + ", " + _right.print() + ")";
251 : }
252 :
253 : protected:
254 : const C _condition;
255 :
256 : using CTBinary<L, R>::_left;
257 : using CTBinary<L, R>::_right;
258 : };
259 :
260 : template <typename C, typename L, typename R>
261 : auto
262 1 : conditional(const C & condition, const L & left, const R & right)
263 : {
264 1 : return CTConditional<C, L, R>(condition, left, right);
265 : }
266 :
267 : template <typename L, typename R>
268 : auto
269 15 : min(const L & left, const R & right)
270 : {
271 15 : return CTConditional<decltype(left < right), L, R>(left < right, left, right);
272 : }
273 :
274 : template <typename L, typename R>
275 : auto
276 15 : max(const L & left, const R & right)
277 : {
278 15 : return CTConditional<decltype(left > right), L, R>(left > right, left, right);
279 : }
280 :
281 : /**
282 : * Constant value
283 : */
284 : template <CTTag tag, typename T>
285 : class CTValue : public CTBase
286 : {
287 : public:
288 387698 : CTValue(const T value) : _value(value) {}
289 : typedef T ResultType;
290 :
291 743092 : auto operator()() const { return _value; }
292 : template <CTTag dtag>
293 381984 : auto D() const
294 : {
295 381984 : return CTNull<ResultType>();
296 : }
297 :
298 4 : std::string print() const { return Moose::stringify(_value); }
299 :
300 : protected:
301 : T _value;
302 : };
303 :
304 : /**
305 : * Helper function to build a (potentially tagged) value
306 : */
307 : template <CTTag tag = CTNoTag, typename T>
308 : auto
309 387698 : makeValue(T value)
310 : {
311 387698 : return CTValue<tag, T>(value);
312 : }
313 :
314 : template <CTTag start_tag, typename... Values, CTTag... Tags>
315 : auto
316 : makeValuesHelper(const std::tuple<Values...> & values, std::integer_sequence<CTTag, Tags...>)
317 : {
318 : if constexpr (start_tag == CTNoTag)
319 : return std::make_tuple(CTValue<CTNoTag, Values>(std::get<Tags>(values))...);
320 : else
321 : return std::make_tuple(CTValue<Tags + start_tag, Values>(std::get<Tags>(values))...);
322 : }
323 :
324 : /**
325 : * Helper function to build a list of (potentially tagged) values
326 : */
327 : template <CTTag start_tag = CTNoTag, typename... Ts>
328 : auto
329 : makeValues(Ts... values)
330 : {
331 : return makeValuesHelper<start_tag>(std::tuple(values...),
332 : std::make_integer_sequence<CTTag, sizeof...(values)>{});
333 : }
334 :
335 : /**
336 : * Variable value, referencing a variable of type T. This object is tagged with a
337 : * CTTag to enable taking symbolic derivatives.
338 : */
339 : template <CTTag tag, typename T>
340 : class CTRef : public CTBase
341 : {
342 : public:
343 1268 : CTRef(const T & ref) : _ref(ref) {}
344 7181 : const T & operator()() const { return _ref; }
345 6 : std::string print() const { return "[v" + printTag<tag>() + "]"; }
346 :
347 : template <CTTag dtag>
348 2880 : auto D() const
349 : {
350 : if constexpr (tag == dtag)
351 2859 : return CTOne<ResultType>();
352 : else
353 21 : return CTNull<ResultType>();
354 : }
355 :
356 : typedef CTCleanType<T> ResultType;
357 :
358 : protected:
359 : const T & _ref;
360 : };
361 :
362 : /**
363 : * Helper function to build a tagged reference to a variable
364 : */
365 : template <CTTag tag = CTNoTag, typename T>
366 : auto
367 1262 : makeRef(const T & ref)
368 : {
369 1262 : return CTRef<tag, T>(ref);
370 : }
371 :
372 : template <CTTag start_tag, typename... Refs, CTTag... Tags>
373 : auto
374 2 : makeRefsHelper(const std::tuple<Refs...> & refs, std::integer_sequence<CTTag, Tags...>)
375 : {
376 : if constexpr (start_tag == CTNoTag)
377 : return std::make_tuple(CTRef<CTNoTag, Refs>(std::get<Tags>(refs))...);
378 : else
379 4 : return std::make_tuple(CTRef<Tags + start_tag, Refs>(std::get<Tags>(refs))...);
380 : }
381 :
382 : /**
383 : * Helper function to build a list of tagged references to variables
384 : */
385 : template <CTTag start_tag = CTNoTag, typename... Ts>
386 : auto
387 2 : makeRefs(const Ts &... refs)
388 : {
389 : return makeRefsHelper<start_tag>(std::tie(refs...),
390 2 : std::make_integer_sequence<CTTag, sizeof...(refs)>{});
391 : }
392 :
393 : /**
394 : * Array variable value, referencing an entry in an indexable container of T types.
395 : * The index of type I is also stored as a reference.
396 : * This object is tagged with a CTTag to enable taking symbolic derivatives.
397 : */
398 : template <CTTag tag, typename T, typename I>
399 : class CTArrayRef : public CTBase
400 : {
401 : public:
402 119 : CTArrayRef(const T & arr, const I & idx) : _arr(arr), _idx(idx) {}
403 1129994 : auto operator()() const { return _arr[_idx]; }
404 2 : std::string print() const { return "[a" + printTag<tag>() + "[" + Moose::stringify(_idx) + "]]"; }
405 :
406 : template <CTTag dtag>
407 748046 : auto D() const
408 : {
409 : if constexpr (tag == dtag)
410 235022 : return CTOne<ResultType>();
411 : else
412 513024 : return CTNull<ResultType>();
413 : }
414 :
415 : // get the value type returned by operator[]
416 : typedef CTCleanType<decltype((static_cast<T>(0))[0])> ResultType;
417 : static_assert(!std::is_same_v<ResultType, void>,
418 : "Instantiation of CTArrayRef was attempted for a non-subscriptable type.");
419 :
420 : protected:
421 : const T & _arr;
422 : const I & _idx;
423 : };
424 :
425 : /**
426 : * Helper function to build a tagged reference to a vector/array entry
427 : */
428 : template <CTTag tag = CTNoTag, typename T, typename I>
429 : auto
430 119 : makeRef(const T & ref, const I & idx)
431 : {
432 119 : return CTArrayRef<tag, T, I>(ref, idx);
433 : }
434 :
435 : /**
436 : * Addition operator node
437 : */
438 : template <typename L, typename R>
439 : class CTAdd : public CTBinary<L, R>
440 : {
441 : public:
442 1475375 : CTAdd(L left, R right) : CTBinary<L, R>(left, right) {}
443 : using typename CTBinary<L, R>::ResultType;
444 :
445 1950728 : ResultType operator()() const
446 : {
447 : // compile time optimization to skip null terms
448 : if constexpr (std::is_base_of<CTNullBase, L>::value && std::is_base_of<CTNullBase, R>::value)
449 303616 : return ResultType(0);
450 :
451 : if constexpr (std::is_base_of<CTNullBase, L>::value)
452 2 : return _right();
453 :
454 : if constexpr (std::is_base_of<CTNullBase, R>::value)
455 113156 : return _left();
456 :
457 : else
458 1533954 : return _left() + _right();
459 : }
460 1 : std::string print() const { return this->printParens(this, "+"); }
461 4 : constexpr static int precedence() { return 6; }
462 :
463 : template <CTTag dtag>
464 531034 : auto D() const
465 : {
466 531034 : return _left.template D<dtag>() + _right.template D<dtag>();
467 : }
468 :
469 : using CTBinary<L, R>::_left;
470 : using CTBinary<L, R>::_right;
471 : };
472 :
473 : /**
474 : * Subtraction operator node
475 : */
476 : template <typename L, typename R>
477 : class CTSub : public CTBinary<L, R>
478 : {
479 : public:
480 116169 : CTSub(L left, R right) : CTBinary<L, R>(left, right) {}
481 : using typename CTBinary<L, R>::ResultType;
482 :
483 99199 : ResultType operator()() const
484 : {
485 : if constexpr (std::is_base_of<CTNullBase, L>::value && std::is_base_of<CTNullBase, R>::value)
486 294 : return ResultType(0);
487 :
488 : if constexpr (std::is_base_of<CTNullBase, L>::value)
489 1078 : return -_right();
490 :
491 : if constexpr (std::is_base_of<CTNullBase, R>::value)
492 : return _left();
493 :
494 : else
495 97827 : return _left() - _right();
496 : }
497 4 : std::string print() const { return this->printParens(this, "-"); }
498 9 : constexpr static int precedence() { return 6; }
499 1 : constexpr static bool leftAssociative() { return true; }
500 :
501 : template <CTTag dtag>
502 30240 : auto D() const
503 : {
504 30240 : return _left.template D<dtag>() - _right.template D<dtag>();
505 : }
506 :
507 : using CTBinary<L, R>::_left;
508 : using CTBinary<L, R>::_right;
509 : };
510 :
511 : /**
512 : * Multiplication operator node
513 : */
514 : template <typename L, typename R>
515 : class CTMul : public CTBinary<L, R>
516 : {
517 : public:
518 3006357 : CTMul(L left, R right) : CTBinary<L, R>(left, right) {}
519 : using typename CTBinary<L, R>::ResultType;
520 :
521 2521786 : ResultType operator()() const
522 : {
523 : if constexpr (std::is_base_of<CTNullBase, L>::value || std::is_base_of<CTNullBase, R>::value)
524 903969 : return ResultType(0);
525 :
526 : if constexpr (std::is_base_of<CTOneBase, L>::value && std::is_base_of<CTOneBase, R>::value)
527 98 : return ResultType(1);
528 :
529 : if constexpr (std::is_base_of<CTOneBase, L>::value)
530 1021 : return _right();
531 :
532 : if constexpr (std::is_base_of<CTOneBase, R>::value)
533 281985 : return _left();
534 :
535 : else
536 1334713 : return _left() * _right();
537 : }
538 5 : std::string print() const { return this->printParens(this, "*"); }
539 11 : constexpr static int precedence() { return 5; }
540 :
541 : template <CTTag dtag>
542 941779 : auto D() const
543 : {
544 1883558 : return _left.template D<dtag>() * _right + _right.template D<dtag>() * _left;
545 : }
546 :
547 : using CTBinary<L, R>::_left;
548 : using CTBinary<L, R>::_right;
549 : };
550 :
551 : /**
552 : * Division operator node
553 : */
554 : template <typename L, typename R>
555 : class CTDiv : public CTBinary<L, R>
556 : {
557 : public:
558 219426 : CTDiv(L left, R right) : CTBinary<L, R>(left, right) {}
559 : using typename CTBinary<L, R>::ResultType;
560 :
561 221474 : ResultType operator()() const
562 : {
563 : if constexpr (std::is_base_of<CTOneBase, R>::value)
564 : return _left();
565 :
566 : if constexpr (std::is_base_of<CTNullBase, L>::value && !std::is_base_of<CTNullBase, R>::value)
567 98 : return ResultType(0);
568 :
569 221376 : return _left() / _right();
570 : }
571 : std::string print() const { return this->printParens(this, "/"); }
572 : constexpr static int precedence() { return 5; }
573 : constexpr static bool leftAssociative() { return true; }
574 :
575 : template <CTTag dtag>
576 85700 : auto D() const
577 : {
578 85700 : return _left.template D<dtag>() / _right -
579 171400 : _left * _right.template D<dtag>() / (_right * _right);
580 : }
581 :
582 : using CTBinary<L, R>::_left;
583 : using CTBinary<L, R>::_right;
584 : };
585 :
586 : enum class CTComparisonEnum
587 : {
588 : Less,
589 : LessEqual,
590 : Greater,
591 : GreaterEqual,
592 : Equal,
593 : Unequal
594 : };
595 :
596 : /**
597 : * Binary comparison operator node
598 : */
599 : template <CTComparisonEnum C, typename L, typename R>
600 : class CTCompare : public CTBinary<L, R>
601 : {
602 : public:
603 155 : CTCompare(L left, R right) : CTBinary<L, R>(left, right) {}
604 : typedef bool ResultType;
605 :
606 174 : ResultType operator()() const
607 : {
608 : if constexpr (C == CTComparisonEnum::Less)
609 70 : return _left() < _right();
610 : if constexpr (C == CTComparisonEnum::LessEqual)
611 24 : return _left() <= _right();
612 : if constexpr (C == CTComparisonEnum::Greater)
613 50 : return _left() > _right();
614 : if constexpr (C == CTComparisonEnum::GreaterEqual)
615 24 : return _left() >= _right();
616 : if constexpr (C == CTComparisonEnum::Equal)
617 3 : return _left() == _right();
618 : if constexpr (C == CTComparisonEnum::Unequal)
619 3 : return _left() != _right();
620 : }
621 : std::string print() const
622 : {
623 : if constexpr (C == CTComparisonEnum::Less)
624 : return this->printParens(this, "<");
625 : if constexpr (C == CTComparisonEnum::LessEqual)
626 : return this->printParens(this, "<=");
627 : if constexpr (C == CTComparisonEnum::Greater)
628 : return this->printParens(this, ">");
629 : if constexpr (C == CTComparisonEnum::GreaterEqual)
630 : return this->printParens(this, ">=");
631 : if constexpr (C == CTComparisonEnum::Equal)
632 : return this->printParens(this, "==");
633 : if constexpr (C == CTComparisonEnum::Unequal)
634 : return this->printParens(this, "!=");
635 : }
636 : constexpr static int precedence() { return 9; }
637 : constexpr static bool leftAssociative() { return true; }
638 :
639 : template <CTTag dtag>
640 : auto D() const
641 : {
642 : return CTNull<ResultType>();
643 : }
644 :
645 : using CTBinary<L, R>::_left;
646 : using CTBinary<L, R>::_right;
647 : };
648 :
649 : /// template aliases for the comparison operator nodes
650 : template <typename L, typename R>
651 : using CTCompareLess = CTCompare<CTComparisonEnum::Less, L, R>;
652 : template <typename L, typename R>
653 : using CTCompareLessEqual = CTCompare<CTComparisonEnum::LessEqual, L, R>;
654 : template <typename L, typename R>
655 : using CTCompareGreater = CTCompare<CTComparisonEnum::Greater, L, R>;
656 : template <typename L, typename R>
657 : using CTCompareGreaterEqual = CTCompare<CTComparisonEnum::GreaterEqual, L, R>;
658 : template <typename L, typename R>
659 : using CTCompareEqual = CTCompare<CTComparisonEnum::Equal, L, R>;
660 : template <typename L, typename R>
661 : using CTCompareUnequal = CTCompare<CTComparisonEnum::Unequal, L, R>;
662 :
663 : /**
664 : * Power operator where both base and exponent can be arbitrary operators.
665 : */
666 : template <typename L, typename R>
667 : class CTPow : public CTBinary<L, R>
668 : {
669 : public:
670 135 : CTPow(L left, R right) : CTBinary<L, R>(left, right) {}
671 : using typename CTBinary<L, R>::ResultType;
672 :
673 390 : ResultType operator()() const
674 : {
675 : if constexpr (std::is_base_of<CTNullBase, L>::value)
676 : return ResultType(0);
677 :
678 : if constexpr (std::is_base_of<CTOneBase, L>::value || std::is_base_of<CTNullBase, R>::value)
679 : return ResultType(1);
680 :
681 : if constexpr (std::is_base_of<CTOneBase, R>::value)
682 : return _left();
683 :
684 390 : return std::pow(_left(), _right());
685 : }
686 : std::string print() const { return "pow(" + _left.print() + "," + _right.print() + ")"; }
687 :
688 : template <CTTag dtag>
689 130 : auto D() const
690 : {
691 : if constexpr (std::is_base_of<CTNullBase, decltype(_left.template D<dtag>())>::value &&
692 : std::is_base_of<CTNullBase, decltype(_right.template D<dtag>())>::value)
693 : return CTNull<ResultType>();
694 :
695 : else if constexpr (std::is_base_of<CTNullBase, decltype(_left.template D<dtag>())>::value)
696 58 : return pow(_left, _right) * _right.template D<dtag>() * log(_left);
697 :
698 : else if constexpr (std::is_base_of<CTNullBase, decltype(_right.template D<dtag>())>::value)
699 116 : return pow(_left, _right) * _right * _left.template D<dtag>() / _left;
700 :
701 : else
702 14 : return pow(_left, _right) *
703 28 : (_right.template D<dtag>() * log(_left) + _right * _left.template D<dtag>() / _left);
704 : }
705 :
706 : using CTBinary<L, R>::_left;
707 : using CTBinary<L, R>::_right;
708 : };
709 :
710 : /**
711 : * pow(base, exponent) function overload.
712 : */
713 : template <typename B, typename E>
714 : auto
715 135 : pow(const B & base, const E & exp)
716 : {
717 : if constexpr (std::is_base_of<CTBase, B>::value && std::is_base_of<CTBase, E>::value)
718 131 : return CTPow(base, exp);
719 : else if constexpr (std::is_base_of<CTBase, E>::value)
720 2 : return CTPow(makeValue(base), exp);
721 : else if constexpr (std::is_base_of<CTBase, B>::value)
722 2 : return CTPow(base, makeValue(exp));
723 : else
724 : return CTPow(makeValue(base), makeValue(exp));
725 : }
726 :
727 : /**
728 : * Integer exponent power operator.
729 : */
730 : template <typename B, int E>
731 : class CTIPow : public CTUnary<B>
732 : {
733 : public:
734 234972 : CTIPow(B base) : CTUnary<B>(base) {}
735 : using typename CTUnary<B>::ResultType;
736 :
737 197658 : ResultType operator()() const
738 : {
739 : if constexpr (std::is_base_of<CTNullBase, B>::value)
740 : return ResultType(0);
741 :
742 : else if constexpr (std::is_base_of<CTOneBase, B>::value || E == 0)
743 : return ResultType(1);
744 :
745 : else if constexpr (E == 1)
746 6144 : return _arg();
747 :
748 : else if constexpr (E < 0)
749 : return 1.0 / libMesh::Utility::pow<-E>(_arg());
750 :
751 : else
752 191514 : return libMesh::Utility::pow<E>(_arg());
753 : }
754 : std::string print() const { return "pow<" + Moose::stringify(E) + ">(" + _arg.print() + ")"; }
755 :
756 : template <CTTag dtag>
757 238115 : auto D() const
758 : {
759 : if constexpr (E == 1)
760 5120 : return _arg.template D<dtag>();
761 :
762 : else if constexpr (E == 0)
763 : return CTNull<ResultType>();
764 :
765 : else
766 232995 : return pow<E - 1>(_arg) * E * _arg.template D<dtag>();
767 : }
768 :
769 : using CTUnary<B>::_arg;
770 : };
771 :
772 : /**
773 : * pow<exponent>(base) template for integer powers.
774 : */
775 : template <int E, typename B>
776 : auto
777 234972 : pow(const B & base)
778 : {
779 : if constexpr (std::is_base_of<CTBase, B>::value)
780 234971 : return CTIPow<B, E>(base);
781 : else
782 1 : return CTIPow<CTValue<CTNoTag, B>, E>(makeValue(base));
783 : }
784 :
785 : /**
786 : * Macro for implementing a binary math operator overload that works with a mix of CT system
787 : * objects, C variables, and number literals.
788 : */
789 : #define CT_OPERATOR_BINARY(op, OP) \
790 : template <typename L, \
791 : typename R, \
792 : class = std::enable_if_t<std::is_base_of<CTBase, L>::value || \
793 : std::is_base_of<CTBase, R>::value>> \
794 : auto operator op(const L & left, const R & right) \
795 : { \
796 : /* We need a template arguments here because: */ \
797 : /* alias template deduction is only available with '-std=c++2a' or '-std=gnu++2a' */ \
798 : if constexpr (std::is_base_of<CTBase, L>::value && std::is_base_of<CTBase, R>::value) \
799 : return OP<L, R>(left, right); \
800 : else if constexpr (std::is_base_of<CTBase, L>::value) \
801 : return OP<L, decltype(makeValue(right))>(left, makeValue(right)); \
802 : else if constexpr (std::is_base_of<CTBase, R>::value) \
803 : return OP<decltype(makeValue(left)), R>(makeValue(left), right); \
804 : else \
805 : static_assert(libMesh::always_false<L>, "This should not be instantiated."); \
806 : }
807 :
808 1475375 : CT_OPERATOR_BINARY(+, CTAdd)
809 116169 : CT_OPERATOR_BINARY(-, CTSub)
810 3006357 : CT_OPERATOR_BINARY(*, CTMul)
811 219426 : CT_OPERATOR_BINARY(/, CTDiv)
812 51 : CT_OPERATOR_BINARY(<, CTCompareLess)
813 24 : CT_OPERATOR_BINARY(<=, CTCompareLessEqual)
814 50 : CT_OPERATOR_BINARY(>, CTCompareGreater)
815 24 : CT_OPERATOR_BINARY(>=, CTCompareGreaterEqual)
816 3 : CT_OPERATOR_BINARY(==, CTCompareEqual)
817 3 : CT_OPERATOR_BINARY(!=, CTCompareUnequal)
818 :
819 : /**
820 : * Macro for implementing a simple unary function overload. No function specific optimizations are
821 : * possible. The parameters are the function name and the expression that returns the derivative
822 : * of the function.
823 : */
824 : #define CT_SIMPLE_UNARY_FUNCTION(name, derivative) \
825 : template <typename T> \
826 : class CTF##name : public CTUnary<T> \
827 : { \
828 : public: \
829 : CTF##name(T arg) : CTUnary<T>(arg) {} \
830 : auto operator()() const { return std::name(_arg()); } \
831 : template <CTTag dtag> \
832 : auto D() const \
833 : { \
834 : return derivative; \
835 : } \
836 : std::string print() const { return #name "(" + _arg.print() + ")"; } \
837 : constexpr static int precedence() { return 2; } \
838 : using typename CTUnary<T>::ResultType; \
839 : using CTUnary<T>::_arg; \
840 : }; \
841 : template <typename T> \
842 : auto name(const T & v) \
843 : { \
844 : using namespace CompileTimeDerivatives; \
845 : if constexpr (std::is_base_of<CTBase, T>::value) \
846 : return CTF##name(v); \
847 : else \
848 : return CTF##name(makeValue(v)); \
849 : }
850 :
851 311 : CT_SIMPLE_UNARY_FUNCTION(exp, exp(_arg) * _arg.template D<dtag>())
852 153354 : CT_SIMPLE_UNARY_FUNCTION(log, _arg.template D<dtag>() / _arg)
853 686671 : CT_SIMPLE_UNARY_FUNCTION(sin, cos(_arg) * _arg.template D<dtag>())
854 678477 : CT_SIMPLE_UNARY_FUNCTION(cos, -1.0 * sin(_arg) * _arg.template D<dtag>())
855 692 : CT_SIMPLE_UNARY_FUNCTION(tan, (pow<2>(tan(_arg)) + 1.0) * _arg.template D<dtag>())
856 63 : CT_SIMPLE_UNARY_FUNCTION(sqrt, 1.0 / (2.0 * sqrt(_arg)) * _arg.template D<dtag>())
857 1211 : CT_SIMPLE_UNARY_FUNCTION(tanh, (1.0 - pow<2>(tanh(_arg))) * _arg.template D<dtag>())
858 482 : CT_SIMPLE_UNARY_FUNCTION(sinh, cosh(_arg) * _arg.template D<dtag>())
859 482 : CT_SIMPLE_UNARY_FUNCTION(cosh, sinh(_arg) * _arg.template D<dtag>())
860 185 : CT_SIMPLE_UNARY_FUNCTION(erf,
861 : 2.0 * exp(-pow<2>(_arg)) / sqrt(libMesh::pi) * _arg.template D<dtag>())
862 362 : CT_SIMPLE_UNARY_FUNCTION(atan, 1.0 / (pow<2>(_arg) + 1.0) * _arg.template D<dtag>())
863 :
864 : /**
865 : * Macro for implementing a simple binary function overload. No function specific optimizations are
866 : * possible. The parameters are the function name and the expression that returns the derivative
867 : * of the function.
868 : */
869 : #define CT_SIMPLE_BINARY_FUNCTION_CLASS(name, derivative) \
870 : template <typename L, typename R> \
871 : class CTF##name : public CTBinary<L, R> \
872 : { \
873 : public: \
874 : CTF##name(L left, R right) : CTBinary<L, R>(left, right) {} \
875 : auto operator()() const { return std::name(_left(), _right()); } \
876 : template <CTTag dtag> \
877 : auto D() const \
878 : { \
879 : return derivative; \
880 : } \
881 : std::string print() const { return #name "(" + _left.print() + ", " + _right.print() + ")"; } \
882 : constexpr static int precedence() { return 2; } \
883 : using typename CTBinary<L, R>::ResultType; \
884 : using CTBinary<L, R>::_left; \
885 : using CTBinary<L, R>::_right; \
886 : };
887 : #define CT_SIMPLE_BINARY_FUNCTION_FUNC(name) \
888 : template <typename L, typename R> \
889 : auto name(const L & l, const R & r) \
890 : { \
891 : using namespace CompileTimeDerivatives; \
892 : if constexpr (std::is_base_of<CTBase, L>::value && std::is_base_of<CTBase, R>::value) \
893 : return CTF##name(l, r); \
894 : else if constexpr (std::is_base_of<CTBase, L>::value) \
895 : return CTF##name(l, makeValue(r)); \
896 : else if constexpr (std::is_base_of<CTBase, R>::value) \
897 : return CTF##name(makeValue(l), r); \
898 : else \
899 : return CTF##name(makeValue(l), makeValue(r)); \
900 : }
901 :
902 275 : CT_SIMPLE_BINARY_FUNCTION_CLASS(atan2,
903 : (-_left * _right.template D<dtag>() +
904 : _left.template D<dtag>() * _right) /
905 : (pow<2>(_left) + pow<2>(_right)))
906 7 : CT_SIMPLE_BINARY_FUNCTION_FUNC(atan2)
907 :
908 : template <typename T, int N, int M>
909 : class CTMatrix
910 : {
911 : public:
912 : template <typename... Ts>
913 1 : CTMatrix(Ts... a) : _data({a...})
914 : {
915 : static_assert(sizeof...(a) == N * M, "Invalid number of matrix entries");
916 1 : }
917 : T & operator()(std::size_t n, std::size_t m) { return _data[M * n + m]; }
918 9 : const T & operator()(std::size_t n, std::size_t m) const { return _data[M * n + m]; }
919 :
920 : protected:
921 : std::array<T, N * M> _data;
922 : };
923 :
924 : template <typename... Ds>
925 : class CTStandardDeviation : public CTBase
926 : {
927 : public:
928 : static constexpr auto N = sizeof...(Ds);
929 :
930 1 : CTStandardDeviation(std::tuple<Ds...> derivatives, CTMatrix<Real, N, N> covariance)
931 1 : : _derivatives(derivatives), _covariance(covariance)
932 : {
933 1 : }
934 1 : auto operator()() const { return std::sqrt(evalHelper(std::make_index_sequence<N>{})); }
935 :
936 : typedef typename CTSuperType<typename Ds::ResultType...>::type ResultType;
937 :
938 : protected:
939 : template <int R, std::size_t... Is>
940 3 : ResultType rowMul(std::index_sequence<Is...>, const std::array<ResultType, N> & d) const
941 : {
942 3 : return ((_covariance(R, Is) * d[Is]) + ...);
943 : }
944 :
945 : template <std::size_t... Is>
946 1 : auto evalHelper(const std::index_sequence<Is...> & is) const
947 : {
948 1 : const std::array<ResultType, N> d{std::get<Is>(_derivatives)()...};
949 1 : return ((rowMul<Is>(is, d) * d[Is]) + ...);
950 : }
951 :
952 : const std::tuple<Ds...> _derivatives;
953 : const CTMatrix<Real, N, N> _covariance;
954 : };
955 :
956 : template <CTTag start_tag, typename T, CTTag... Tags>
957 : auto
958 1 : makeStandardDeviationHelper(const T & f, std::integer_sequence<CTTag, Tags...>)
959 : {
960 2 : return std::make_tuple(f.template D<Tags + start_tag>()...);
961 : }
962 :
963 : /**
964 : * Helper function to build a standard deviation object for a function with N parameters with
965 : * consecutive tags starting at start_tag, and an NxN covariance matrix for said parameters.
966 : */
967 : template <CTTag start_tag, typename T, int N>
968 : auto
969 1 : makeStandardDeviation(const T & f, const CTMatrix<Real, N, N> covariance)
970 : {
971 : return CTStandardDeviation(
972 : makeStandardDeviationHelper<start_tag>(f, std::make_integer_sequence<CTTag, N>{}),
973 1 : covariance);
974 : }
975 :
976 : } // namespace CompileTimeDerivatives
|