5 #ifndef LODESTAR_GAINBLOCK_HPP
6 #define LODESTAR_GAINBLOCK_HPP
8 #include "Lodestar/blocks/Block.hpp"
9 #include "Lodestar/aux/TemplateTools.hpp"
47 #define OUTPUT_GAINBLOCK() \
48 typename ::std::conditional< \
49 TOps == GainBlockOperator::Convolution, \
51 typename ::std::conditional< \
52 ::std::is_arithmetic<TGain>::value, \
54 typename ls::aux::TemplateTraits::BinaryOperators::sanitizeTypeMultiplicable<typename ::std::conditional< \
55 TOps != GainBlockOperator::Right, \
56 ls::aux::TemplateTraits::BinaryOperators::isMultiplicable<TGain, TInput>, \
57 ls::aux::TemplateTraits::BinaryOperators::isMultiplicable<TInput, TGain> \
58 >::type::returnType>::returnType \
93 template<
typename TInput,
typename TGain = TInput, GainBlockOperator TOps = GainBlockOperator::Left, GainBlockConvolutionMode TConv = GainBlockConvolutionMode::Reflect>
97 ::std::tuple<OUTPUT_GAINBLOCK()>,
98 ::std::tuple<TGain, double>
108 "Gain is not multiplicable with input.");
120 "For convolution, TInput and TGain must be matrix-like.");
125 ::std::tuple<TInput>,
126 ::std::tuple<OutputType>,
127 ::std::tuple<TGain, double>
186 this->
template p<0>() =
gain;
197 this->
template p<0>() =
gain;
207 return this->
template p<0>();
225 typename ::std::tuple_element<0, typename Base::Params>::type &
228 return this->
template p<0>();
236 const typename ::std::tuple_element<0, typename Base::Params>::type &
239 return this->
template p<0>();
249 typename ::std::tuple_element<1, typename Base::Params>::type &
252 return this->
template p<1>();
260 const typename ::std::tuple_element<1, typename Base::Params>::type &
263 return this->
template p<1>();
268 const ::std::array<GiNaC::ex, Base::kIns> &inputSymbols()
270 if (!this->isInitInput_) {
271 for (
int i = 0; i < Base::kIns; i++) {
274 for (
int ii = 0; ii <
277 for (
int jj = 0; jj <
280 "blk" + ::std::to_string(this->
id) +
"_i_" + ::std::to_string(i) +
281 "_r_" + ::std::to_string(ii) +
"_c_" + ::std::to_string(jj),
282 "\\text{BLK}^{i, " + ::std::to_string(i) +
", " + ::std::to_string(ii) +
283 ", " + ::std::to_string(jj) +
"}_{" + ::std::to_string(this->
id) +
"}"};
290 this->inputSymbols_[i] = GiNaC::lst_to_matrix(input);
292 this->inputSymbols_[i] = GiNaC::symbol{
293 "blk" + ::std::to_string(this->
id) +
"_i_" + ::std::to_string(i),
294 "\\text{BLK}^{i, " + ::std::to_string(i) +
"}_{" +
295 ::std::to_string(this->
id) +
300 this->isInitInput_ =
true;
303 return this->inputSymbols_;
306 const ::std::array<GiNaC::ex, Base::kOuts> &outputSymbols()
308 if (!this->isInitOutput_) {
309 for (
int i = 0; i < Base::kOuts; i++) {
312 for (
int ii = 0; ii <
315 for (
int jj = 0; jj <
318 "blk" + ::std::to_string(this->
id) +
"_o_" + ::std::to_string(i) +
319 "_r_" + ::std::to_string(ii) +
"_c_" + ::std::to_string(jj),
320 "\\text{BLK}^{i, " + ::std::to_string(i) +
", " + ::std::to_string(ii) +
321 ", " + ::std::to_string(jj) +
"}_{" + ::std::to_string(this->
id) +
"}"};
328 this->outputSymbols_[i] = GiNaC::lst_to_matrix(output);
330 this->outputSymbols_[i] = GiNaC::symbol{
331 "blk" + ::std::to_string(this->
id) +
"_o_" + ::std::to_string(i),
332 "\\text{BLK}^{o, " + ::std::to_string(i) +
"}_{" +
333 ::std::to_string(this->
id) +
338 this->isInitOutput_ =
true;
341 return this->outputSymbols_;
344 const ::std::array<GiNaC::ex, Base::kPars> ¶meterSymbols()
346 if (!this->isInitParameter_) {
351 for (
int ii = 0; ii <
354 for (
int jj = 0; jj <
357 "blk" + ::std::to_string(this->
id) +
"_p_" + ::std::to_string(i) +
358 "_r_" + ::std::to_string(ii) +
"_c_" + ::std::to_string(jj),
359 "\\text{BLK}^{i, " + ::std::to_string(i) +
", " + ::std::to_string(ii) +
360 ", " + ::std::to_string(jj) +
"}_{" + ::std::to_string(this->
id) +
"}"};
367 this->parameterSymbols_[i] = GiNaC::lst_to_matrix(par);
369 this->parameterSymbols_[i] = GiNaC::symbol{
370 "blk" + ::std::to_string(this->
id) +
"_p_" + ::std::to_string(i),
371 "\\text{BLK}^{p, " + ::std::to_string(i) +
"}_{" +
372 ::std::to_string(this->
id) +
379 this->parameterSymbols_[i] = GiNaC::symbol{
380 "blk" + ::std::to_string(this->
id) +
"_p_" + ::std::to_string(i),
381 "\\text{BLK}^{p, " + ::std::to_string(i) +
"}_{" +
382 ::std::to_string(this->
id) +
385 this->isInitParameter_ =
true;
388 return this->parameterSymbols_;
401 template<GainBlockOperator TTOps = TOps, typename ::std::enable_if<TTOps == Left>::type * =
nullptr>
404 this->equation = [](
Base &b) ->
void {
405 b.template o<0>().object =
406 b.template p<0>() * b.template i<0>().object;
407 b.template o<0>().propagate();
411 GiNaC::function_options fops(
"blkf" + ::std::to_string(this->
id) +
"__", this->blkFunc_NPARAMS);
413 ls::blocks::symbolicEvalFunctionMap[this->
id] = [&](
414 const ::std::vector<GiNaC::ex> &exvec) -> GiNaC::ex {
418 for (
auto &ex: exvec) {
419 res += this->parameterSymbols()[0] * ex;
427 fops.eval_func(ls::blocks::symbolicEval);
430 this->serial = GiNaC::function::register_new(
442 template<GainBlockOperator TTOps = TOps, typename ::std::enable_if<TTOps == Right>::type * =
nullptr>
445 this->equation = [](
Base &b) ->
void {
446 b.template o<0>().object =
447 b.template i<0>().object * b.template p<0>();
448 b.template o<0>().propagate();
452 GiNaC::function_options fops(
"blkf" + ::std::to_string(this->
id) +
"__", this->blkFunc_NPARAMS);
454 ls::blocks::symbolicEvalFunctionMap[this->
id] = [&](
455 const ::std::vector<GiNaC::ex> &exvec) -> GiNaC::ex {
459 for (
auto &ex: exvec) {
460 res += ex * this->parameterSymbols()[0];
468 fops.eval_func(ls::blocks::symbolicEval);
471 this->serial = GiNaC::function::register_new(
485 typename ::std::enable_if<((TTOps ==
Convolution) && (TTConv ==
Reflect))>::type * =
nullptr>
488 static const auto NRowsKernel = GainTrait::rows;
489 static const auto NColsKernel = GainTrait::cols;
490 static const auto NRowsInput = InputTrait::rows;
491 static const auto NColsInput = InputTrait::cols;
493 auto getInput = [&](
int i,
int j) {
495 i = (-i - 1) % NRowsInput;
497 i = (2*NRowsInput - i - 1) % NRowsInput;
500 j = (-j - 1) % NColsInput;
502 j = (2*NColsInput - j - 1) % NColsInput;
504 return b.template i<0>().object(i, j);
507 b.template o<0>().object.setZero();
509 for (
int i=0; i < NRowsInput; i++) {
510 for (
int j=0; j < NColsInput; j++) {
511 for (
int ii=0; ii < NRowsKernel; ii++) {
512 for (
int jj=0; jj < NColsKernel; jj++) {
513 b.template o<0>().object(i, j) -= b.template p<0>()(ii, jj) * getInput(i - (NRowsKernel-1)/2 + ii, j - (NColsKernel-1)/2 + jj);
519 b.template o<0>().propagate();
530 typename ::std::enable_if<((TTOps ==
Convolution) && (TTConv ==
Constant))>::type * =
nullptr>
533 static const auto NRowsKernel = GainTrait::rows;
534 static const auto NColsKernel = GainTrait::cols;
535 static const auto NRowsInput = InputTrait::rows;
536 static const auto NColsInput = InputTrait::cols;
538 auto getInput = [&](
int i,
int j) {
540 return b.template p<1>();
542 return b.template p<1>();
545 return b.template p<1>();
547 return b.template p<1>();
549 return b.template i<0>().object(i, j);
552 b.template o<0>().object.setZero();
554 for (
int i=0; i < NRowsInput; i++) {
555 for (
int j=0; j < NColsInput; j++) {
556 for (
int ii=0; ii < NRowsKernel; ii++) {
557 for (
int jj=0; jj < NColsKernel; jj++) {
558 b.template o<0>().object(i, j) -= b.template p<0>()(ii, jj) * getInput(i - (NRowsKernel-1)/2 + ii, j - (NColsKernel-1)/2 + jj);
564 b.template o<0>().propagate();
575 typename ::std::enable_if<((TTOps ==
Convolution) && (TTConv ==
Nearest))>::type * =
nullptr>
578 static const auto NRowsKernel = GainTrait::rows;
579 static const auto NColsKernel = GainTrait::cols;
580 static const auto NRowsInput = InputTrait::rows;
581 static const auto NColsInput = InputTrait::cols;
583 auto getInput = [&](
int i,
int j) {
594 return b.template i<0>().object(i, j);
597 b.template o<0>().object.setZero();
599 for (
int i=0; i < NRowsInput; i++) {
600 for (
int j=0; j < NColsInput; j++) {
601 for (
int ii=0; ii < NRowsKernel; ii++) {
602 for (
int jj=0; jj < NColsKernel; jj++) {
603 b.template o<0>().object(i, j) -= b.template p<0>()(ii, jj) * getInput(i - (NRowsKernel-1)/2 + ii, j - (NColsKernel-1)/2 + jj);
609 b.template o<0>().propagate();
620 typename ::std::enable_if<((TTOps ==
Convolution) && (TTConv ==
Mirror))>::type * =
nullptr>
623 static const auto NRowsKernel = GainTrait::rows;
624 static const auto NColsKernel = GainTrait::cols;
625 static const auto NRowsInput = InputTrait::rows;
626 static const auto NColsInput = InputTrait::cols;
628 auto getInput = [&](
int i,
int j) {
630 i = (-i) % NRowsInput;
632 i = (2*NRowsInput - i - 2) % NRowsInput;
635 j = (-j) % NColsInput;
637 j = (2*NColsInput - j - 2) % NColsInput;
639 return b.template i<0>().object(i, j);
642 b.template o<0>().object.setZero();
644 for (
int i=0; i < NRowsInput; i++) {
645 for (
int j=0; j < NColsInput; j++) {
646 for (
int ii=0; ii < NRowsKernel; ii++) {
647 for (
int jj=0; jj < NColsKernel; jj++) {
648 b.template o<0>().object(i, j) -= b.template p<0>()(ii, jj) * getInput(i - (NRowsKernel-1)/2 + ii, j - (NColsKernel-1)/2 + jj);
654 b.template o<0>().propagate();
665 typename ::std::enable_if<((TTOps ==
Convolution) && (TTConv ==
Wrap))>::type * =
nullptr>
668 static const auto NRowsKernel = GainTrait::rows;
669 static const auto NColsKernel = GainTrait::cols;
670 static const auto NRowsInput = InputTrait::rows;
671 static const auto NColsInput = InputTrait::cols;
673 auto getInput = [&](
int i,
int j) {
675 i = (NRowsInput + i) % NRowsInput;
680 j = (NColsInput + j) % NColsInput;
684 return b.template i<0>().object(i, j);
687 b.template o<0>().object.setZero();
689 for (
int i=0; i < NRowsInput; i++) {
690 for (
int j=0; j < NColsInput; j++) {
691 for (
int ii=0; ii < NRowsKernel; ii++) {
692 for (
int jj=0; jj < NColsKernel; jj++) {
693 b.template o<0>().object(i, j) -= b.template p<0>()(ii, jj) * getInput(i - (NRowsKernel-1)/2 + ii, j - (NColsKernel-1)/2 + jj);
699 b.template o<0>().propagate();
712 this->equation = [](
Base &b) ->
void {
752 template<
typename TInput,
typename TGain, std::GainBlockOperator TOps, std::GainBlockConvolutionMode TConv>
769 static const ::std::array<::std::string, kIns>
inTypes;
770 static const ::std::array<::std::string, kOuts>
outTypes;
771 static const ::std::array<::std::string, kPars>
parTypes;
776 template<
typename TInput,
typename TGain, std::GainBlockOperator TOps, std::GainBlockConvolutionMode TConv>
778 {demangle(
typeid(TInput).name())};
780 template<
typename TInput,
typename TGain, std::GainBlockOperator TOps, std::GainBlockConvolutionMode TConv>
781 const ::std::array<::std::string, BlockTraits<std::GainBlock<TInput, TGain, TOps, TConv>>::kOuts> BlockTraits<std::GainBlock<TInput, TGain, TOps, TConv>>::outTypes =
// Out-of-class definition of BlockTraits<GainBlock<...>>::parTypes: one
// human-readable type name per block parameter, in parameter order, matching
// the block's parameter tuple <TGain, double>.
//   [0] TGain run through demangle(typeid(...).name()): the gain, or the
//       convolution kernel (p<0>).
//   [1] "double": p<1>, which the Convolution/Constant equation appears to
//       return as the fill value for out-of-range input samples.
// NOTE(review): demangle() is defined elsewhere; presumably it turns the
// implementation-specific typeid name into a readable one. Verify.
784 template<
typename TInput,
typename TGain, std::GainBlockOperator TOps, std::GainBlockConvolutionMode TConv>
785 const ::std::array<::std::string, BlockTraits<std::GainBlock<TInput, TGain, TOps, TConv>>::kPars> BlockTraits<std::GainBlock<TInput, TGain, TOps, TConv>>::parTypes =
786 {demangle(
typeid(TGain).name()),
"double"};
// Out-of-class definition of BlockTraits<GainBlock<...>>::templateTypes: one
// string per GainBlock template argument, in declaration order
// (TInput, TGain, TOps, TConv), each built as demangle(typeid(X).name()).
// NOTE(review): TOps and TConv are non-type template parameters (values of
// GainBlockOperator / GainBlockConvolutionMode). typeid applied to such a
// value yields the type_info of its enum *type*. Entries [2] and [3]
// therefore name the enum types, not the selected enumerators (e.g. Left vs
// Right, Reflect vs Wrap). If consumers need the actual operator or
// convolution mode, an enum-to-string mapping is required. Confirm the
// intended use.
788 template<
typename TInput,
typename TGain, std::GainBlockOperator TOps, std::GainBlockConvolutionMode TConv>
789 const ::std::array<::std::string, 4> BlockTraits<std::GainBlock<TInput, TGain, TOps, TConv>>::templateTypes =
790 {demangle(
typeid(TInput).name()), demangle(
typeid(TGain).name()), demangle(
typeid(TOps).name()),
791 demangle(
typeid(TConv).name())};
795 #undef OUTPUT_GAINBLOCK
797 #endif //LODESTAR_GAINBLOCK_HPP