5 #ifndef LODESTAR_BUTCHERTABLEAU_HPP
6 #define LODESTAR_BUTCHERTABLEAU_HPP
12 #include <string_view>
14 #include "Lodestar/aux/Indices.hpp"
15 #include "Lodestar/aux/Conjunction.hpp"
17 #include <Eigen/Dense>
20 namespace primitives {
34 template<
typename TScalarType,
size_t TStages,
size_t TStage,
bool TIsWeights>
47 template<
typename TScalarType,
size_t TStages,
size_t TStage>
49 static_assert(TStage >= 0,
"Butcher tableau row number must be non-negative.");
50 static_assert(TStage < (TStages - 1),
51 "Butcher tableau row number must be smaller than the number of stages minus one.");
54 std::array<TScalarType, TStage + 1> rkCoefficients;
66 template<
typename TScalarType,
size_t TStages,
size_t TStage>
68 std::array<TScalarType, TStages> weights;
88 template<
size_t TStages,
bool TExtended = true,
typename TScalarType =
double>
90 static_assert(TStages > 1,
"Butcher tableau must have more than one stage.");
93 template<
size_t TStages,
typename TScalarType>
96 static_assert(TStages > 1,
"Butcher tableau must have more than one stage.");
98 static const size_t stages = TStages;
107 template<
size_t TStage,
bool TIsWeights = (TStage >= (TStages - 1))>
118 template<
size_t TRow,
size_t TCoeffIdx>
119 typename std::enable_if<(TRow < (TStages - 1)) && (TCoeffIdx < (TRow + 1)), TScalarType>
::type
122 return std::get<TCoeffIdx>(std::get<TRow>(rows_).rkCoefficients);
133 template<
size_t TRow,
size_t TCoeffIdx>
134 typename std::enable_if<!((TRow < (TStages - 1)) && (TCoeffIdx < (TRow + 1))), TScalarType>::type
137 static_assert(TRow < (TStages - 1),
138 "Row index must be smaller than the number of stages minus one to access coefficients.");
139 static_assert(TCoeffIdx < (TRow + 1),
"Coefficient index must be less than or equal to the row index.");
141 return TScalarType{};
152 template<
size_t TRow,
size_t TCoeffIdx>
153 typename std::enable_if<(TRow < (TStages - 1)) && (TCoeffIdx < (TRow + 1)),
void>::type
156 std::get<TCoeffIdx>(std::get<TRow>(rows_).rkCoefficients) = coeff;
167 template<
size_t TRow,
size_t TCoeffIdx>
168 typename std::enable_if<!((TRow < (TStages - 1)) && (TCoeffIdx < (TRow + 1))),
void>::type
171 static_assert(TRow < (TStages - 1),
172 "Row index must be smaller than the number of stages minus one to access coefficients.");
173 static_assert(TCoeffIdx < (TRow + 1),
"Coefficient index must be less than or equal to the row index.");
183 template<
size_t TRow>
184 typename std::enable_if<TRow < (TStages - 1), TScalarType>::type
187 return std::get<TRow>(rows_).node;
197 template<
size_t TRow>
198 typename std::enable_if<TRow >= (TStages - 1), TScalarType>::type
201 static_assert(TRow < (TStages - 1),
202 "Row index must be smaller than the number of stages minus one to access coefficients.");
204 return TScalarType{};
214 template<
size_t TRow>
215 typename std::enable_if<TRow < (TStages - 1),
void>::type
216 inline setNode(TScalarType node)
218 std::get<TRow>(rows_).node = node;
228 template<
size_t TRow>
229 typename std::enable_if<TRow >= (TStages - 1),
void>::type
232 static_assert(TRow < (TStages - 1),
233 "Row index must be smaller than the number of stages minus one to access coefficients.");
243 template<
size_t TIdx>
244 typename std::enable_if<TIdx < TStages, TScalarType>::type
247 return std::get<TIdx>(std::get<TStages - 1>(rows_).weights);
257 template<
size_t TIdx>
258 typename std::enable_if<TIdx >= TStages, TScalarType>::type
261 static_assert(TIdx < TStages,
"Weight index must be smaller than the number of stages.");
263 return TScalarType{};
273 template<
size_t TIdx>
274 typename std::enable_if<TIdx < TStages, void>::type
275 inline setWeight(TScalarType weight)
277 std::get<TIdx>(std::get<TStages - 1>(rows_).weights) = weight;
287 template<
size_t TIdx>
288 typename std::enable_if<TIdx >= TStages,
void>::type
291 static_assert(TIdx < TStages,
"Weight index must be smaller than the number of stages.");
312 template<
typename TType,
size_t TStage = 0>
313 typename std::enable_if<TStage == 0, TType>::type
314 inline execute(
const std::function<TType(TScalarType, TType)> &f,
const TType &y,
const TScalarType t,
322 return execute<TType, TStage + 1>(f, y, t, h, kCurr);
342 template<
typename TType,
size_t TStage = 0,
typename... TArgs>
343 typename std::enable_if<(TStage > 0) && (TStage < TStages) &&
345 inline execute(
const std::function<TType(TScalarType, TType)> &f,
const TType &y,
const TScalarType t,
349 TType kCurr{}, yCurr = y;
351 yCurr += h * sumCoefficients<TType, TStage>(vars...);
352 TScalarType tCurr = t + getNode<TStage - 1>() * h;
354 kCurr = f(tCurr, yCurr);
356 return execute<TType, TStage + 1>(f, y, t, h, vars..., kCurr);
377 template<
typename TType,
size_t TStage = 0,
typename... TArgs>
378 typename std::enable_if<(TStage > 0) && (TStage == TStages) &&
380 inline execute(
const std::function<TType(TScalarType, TType)> &f,
const TType &y,
const TScalarType t,
385 yFinal += h * sumWeights<TType>(vars...);
400 template<int TTimes, typename TIndices = typename Indices<TTimes>::type>
409 template<
int TTimes,
int... TIndices>
411 using type = std::tuple<ButcherRow<TIndices>...>;
414 typename Rows<TStages>::type rows_;
430 template<
typename TType,
size_t TStage = 0,
typename... TArgs,
typename TIndex =
typename Indices<
sizeof...(TArgs)>::
type>
431 typename std::enable_if<
432 (TStage > 0) && (TStage < TStages) &&
436 return sumCoefficientsImpl<TType, TStage - 1>(TIndex{}, vars...);
453 template<
typename TType,
size_t TStage = 0,
typename... TArgs,
int... TIndices>
457 return sum<TType>((getCoefficient<TStage, (
size_t) TIndices>() * vars)...);
471 template<
typename TType,
typename... TArgs,
typename TIndex =
typename Indices<
sizeof...(TArgs)>::type>
472 typename std::enable_if<(
Conjunction<std::is_convertible<TArgs, TType>...>::value), TType>::type
475 return sumWeightsImpl<TType>(TIndex{}, vars...);
491 template<
typename TType,
typename... TArgs,
int... TIndices>
495 return sum<TType>((getWeight<(
size_t) TIndices>() * vars)...);
510 template<
typename TType,
typename TArg>
511 typename std::enable_if<std::is_convertible<TArg, TType>::value, TType>::type
527 template<
typename TType,
typename TArg>
528 typename std::enable_if<!std::is_convertible<TArg, TType>::value, TType>::type
531 static_assert(std::is_convertible<TArg, TType>::value,
"Summed values must be convertible.");
547 template<
typename TType,
typename TArg,
typename... TArgs>
548 typename std::enable_if<std::is_convertible<TArg, TType>::value &&
550 inline sum(TArg var, TArgs... vars)
552 return var + sum<TType>(vars...);
565 template<
typename TType,
typename TArg,
typename... TArgs>
566 typename std::enable_if<!(std::is_convertible<TArg, TType>::value &&
568 inline sum(TArg var, TArgs... vars)
570 static_assert(std::is_convertible<TArg, TType>::value &&
571 (
Conjunction<std::is_convertible<TArgs, TType>...>::value),
572 "Summed values must be convertible.");
574 return var + sum<TType>(vars...);
595 template<
size_t TStages,
typename TScalarType>
598 static_assert(TStages > 1,
"Butcher tableau must have more than one stage.");
600 static const size_t stages = TStages;
609 template<
size_t TStage,
bool TIsWeights = (TStage >= (TStages - 1))>
620 template<
size_t TRow,
size_t TCoeffIdx>
621 typename std::enable_if<(TRow < (TStages - 1)) && (TCoeffIdx < (TRow + 1)), TScalarType>
::type
624 return std::get<TCoeffIdx>(std::get<TRow>(rows_).rkCoefficients);
635 template<
size_t TRow,
size_t TCoeffIdx>
636 typename std::enable_if<!((TRow < (TStages - 1)) && (TCoeffIdx < (TRow + 1))), TScalarType>::type
639 static_assert(TRow < (TStages - 1),
640 "Row index must be smaller than the number of stages minus one to access coefficients.");
641 static_assert(TCoeffIdx < (TRow + 1),
"Coefficient index must be less than or equal to the row index.");
643 return TScalarType{};
654 template<
size_t TRow,
size_t TCoeffIdx>
655 typename std::enable_if<(TRow < (TStages - 1)) && (TCoeffIdx < (TRow + 1)),
void>::type
658 std::get<TCoeffIdx>(std::get<TRow>(rows_).rkCoefficients) = coeff;
669 template<
size_t TRow,
size_t TCoeffIdx>
670 typename std::enable_if<!((TRow < (TStages - 1)) && (TCoeffIdx < (TRow + 1))),
void>::type
673 static_assert(TRow < (TStages - 1),
674 "Row index must be smaller than the number of stages minus one to access coefficients.");
675 static_assert(TCoeffIdx < (TRow + 1),
"Coefficient index must be less than or equal to the row index.");
685 template<
size_t TRow>
686 typename std::enable_if<TRow < (TStages - 1), TScalarType>::type
689 return std::get<TRow>(rows_).node;
699 template<
size_t TRow>
700 typename std::enable_if<TRow >= (TStages - 1), TScalarType>::type
703 static_assert(TRow < (TStages - 1),
704 "Row index must be smaller than the number of stages minus one to access coefficients.");
706 return TScalarType{};
716 template<
size_t TRow>
717 typename std::enable_if<TRow < (TStages - 1),
void>::type
718 inline setNode(TScalarType node)
720 std::get<TRow>(rows_).node = node;
730 template<
size_t TRow>
731 typename std::enable_if<TRow >= (TStages - 1),
void>::type
734 static_assert(TRow < (TStages - 1),
735 "Row index must be smaller than the number of stages minus one to access coefficients.");
748 template<
size_t TIdx,
bool THigherOrder = true>
749 typename std::enable_if<(TIdx < TStages) && THigherOrder, TScalarType>::type
752 return std::get<TIdx>(std::get<TStages - 1>(rows_).weights);
765 template<
size_t TIdx,
bool THigherOrder = true>
766 typename std::enable_if<(TIdx < TStages) && !THigherOrder, TScalarType>::type
769 return std::get<TIdx>(std::get<TStages>(rows_).weights);
780 template<
size_t TIdx,
bool THigherOrder = true>
781 typename std::enable_if<TIdx >= TStages, TScalarType>::type
784 static_assert(TIdx < TStages,
"Weight index must be smaller than the number of stages.");
786 return TScalarType{};
799 template<
size_t TIdx,
bool THigherOrder = true>
800 typename std::enable_if<(TIdx < TStages) && THigherOrder, void>::type
803 std::get<TIdx>(std::get<TStages - 1>(rows_).weights) = weight;
816 template<
size_t TIdx,
bool THigherOrder = true>
817 typename std::enable_if<(TIdx < TStages) && !THigherOrder, void>::type
820 std::get<TIdx>(std::get<TStages>(rows_).weights) = weight;
831 template<
size_t TIdx,
bool THigherOrder = true>
832 typename std::enable_if<TIdx >= TStages,
void>::type
833 inline setWeight(TScalarType weight)
835 static_assert(TIdx < TStages,
"Weight index must be smaller than the number of stages.");
857 template<
typename TType,
size_t TStage = 0,
bool TIsEmbedded = false>
858 typename std::enable_if<(TStage == 0) && !TIsEmbedded, TType>::type
859 inline execute(
const std::function<TType(TScalarType, TType)> &f,
const TType &y,
const TScalarType t,
862 TType kCurr = f(t, y);
864 return execute<TType, TStage + 1, TIsEmbedded>(f, y, t, h, kCurr);
886 template<
typename TType,
size_t TStage = 0,
bool TIsEmbedded = false>
887 typename std::enable_if<(TStage == 0) && TIsEmbedded, std::pair<TType, TType>>::type
888 inline execute(
const std::function<TType(TScalarType, TType)> &f,
const TType &y,
const TScalarType t,
891 TType kCurr = f(t, y);
893 return execute<TType, TStage + 1, TIsEmbedded>(f, y, t, h, kCurr);
914 template<
typename TType,
size_t TStage = 0,
bool TIsEmbedded =
false,
typename... TArgs>
915 typename std::enable_if<(TStage > 0) && (TStage < TStages) && !TIsEmbedded &&
917 inline execute(
const std::function<TType(TScalarType, TType)> &f,
const TType &y,
const TScalarType t,
921 TType kCurr{}, yCurr = y;
923 yCurr += h * sumCoefficients<TType, TStage>(vars...);
924 TScalarType tCurr = t + getNode<TStage - 1>() * h;
926 kCurr = f(tCurr, yCurr);
928 return execute<TType, TStage + 1, TIsEmbedded>(f, y, t, h, vars..., kCurr);
949 template<
typename TType,
size_t TStage = 0,
bool TIsEmbedded =
false,
typename... TArgs>
950 typename std::enable_if<(TStage > 0) && (TStage < TStages) && TIsEmbedded &&
952 inline execute(
const std::function<TType(TScalarType, TType)> &f,
const TType &y,
const TScalarType t,
956 TType kCurr{}, yCurr = y;
958 yCurr += h * sumCoefficients<TType, TStage>(vars...);
959 TScalarType tCurr = t + getNode<TStage - 1>() * h;
961 kCurr = f(tCurr, yCurr);
963 return execute<TType, TStage + 1, TIsEmbedded>(f, y, t, h, vars..., kCurr);
984 template<
typename TType,
size_t TStage = 0,
bool TIsEmbedded =
false,
typename... TArgs>
985 typename std::enable_if<(TStage > 0) && (TStage == TStages) && !TIsEmbedded &&
987 inline execute(
const std::function<TType(TScalarType, TType)> &f,
const TType &y,
const TScalarType t,
992 yFinal += h * sumWeights<TType>(vars...);
1015 template<
typename TType,
size_t TStage = 0,
bool TIsEmbedded =
false,
typename... TArgs>
1016 typename std::enable_if<(TStage > 0) && (TStage == TStages) && TIsEmbedded &&
1018 inline execute(
const std::function<TType(TScalarType, TType)> &f,
const TType &y,
const TScalarType t,
1019 const TScalarType h,
1023 TType summedWeightsHigher = sumWeights<TType, true>(vars...);
1024 TType summedWeightsLower = sumWeights<TType, false>(vars...);
1025 yFinal += h * summedWeightsHigher;
1027 std::pair<TType, TType> pair{yFinal, h * (summedWeightsHigher - summedWeightsLower)};
1042 template<int TTimes, typename TIndices = typename Indices<TTimes>::type>
1051 template<
int TTimes,
int... TIndices>
1053 using type = std::tuple<ButcherRow<TIndices>...>;
1056 typename Rows<TStages + 1>::type rows_;
1072 template<
typename TType,
size_t TStage = 0,
typename... TArgs,
typename TIndex =
typename Indices<
sizeof...(TArgs)>::
type>
1073 typename std::enable_if<
1074 (TStage > 0) && (TStage < TStages) &&
1078 return sumCoefficientsImpl<TType, TStage - 1>(TIndex{}, vars...);
1095 template<
typename TType,
size_t TStage = 0,
typename... TArgs,
int... TIndices>
1099 return sum<TType>((getCoefficient<TStage, (
size_t) TIndices>() * vars)...);
1114 template<
typename TType,
bool THigherOrder =
true,
typename... TArgs,
typename TIndex =
typename Indices<
sizeof...(TArgs)>::type>
1115 typename std::enable_if<(
Conjunction<std::is_convertible<TArgs, TType>...>::value), TType>::type
1118 return sumWeightsImpl<TType, THigherOrder>(TIndex{}, vars...);
1135 template<
typename TType,
bool THigherOrder,
typename... TArgs,
int... TIndices>
1136 typename std::enable_if<THigherOrder, TType>::type
1139 return sum<TType>((getWeight<(
size_t) TIndices,
true>() * vars)...);
1156 template<
typename TType,
bool THigherOrder,
typename... TArgs,
int... TIndices>
1157 typename std::enable_if<!THigherOrder, TType>::type
1160 return sum<TType>((getWeight<(
size_t) TIndices,
false>() * vars)...);
1175 template<
typename TType,
typename TArg>
1176 typename std::enable_if<std::is_convertible<TArg, TType>::value, TType>::type
1192 template<
typename TType,
typename TArg>
1193 typename std::enable_if<!std::is_convertible<TArg, TType>::value, TType>::type
1196 static_assert(std::is_convertible<TArg, TType>::value,
"Summed values must be convertible.");
1212 template<
typename TType,
typename TArg,
typename... TArgs>
1213 typename std::enable_if<std::is_convertible<TArg, TType>::value &&
1215 inline sum(TArg var, TArgs... vars)
1217 return var + sum<TType>(vars...);
1230 template<
typename TType,
typename TArg,
typename... TArgs>
1231 typename std::enable_if<!(std::is_convertible<TArg, TType>::value &&
1233 inline sum(TArg var, TArgs... vars)
1235 static_assert(std::is_convertible<TArg, TType>::value &&
1236 (
Conjunction<std::is_convertible<TArgs, TType>...>::value),
1237 "Summed values must be convertible.");
1239 return var + sum<TType>(vars...);
1245 #endif //LODESTAR_BUTCHERTABLEAU_HPP