Eigen: Modifiable Custom Expression

Question

I'm trying to implement a modifiable custom expression using Eigen, similar to this question. Basically, what I want is something like the indexing example in the tutorial, but with the ability to assign new values to the selected coefficients.

As suggested in the accepted answer in the question mentioned above, I have looked into the Transpose implementation and tried many things, yet without success. Basically, my attempts are failing with errors like 'Eigen::internal::evaluator<SrcXprType>::evaluator(const Eigen::internal::evaluator<SrcXprType> &)': cannot convert argument 1 from 'const Eigen::Indexing<Derived>' to 'Eigen::Indexing<Derived> &'. Probably, the problem lies in my evaluator struct which seems to be read-only.

namespace Eigen {
namespace internal {
    template<typename ArgType>
    struct evaluator<Indexing<ArgType> >
        : evaluator_base<Indexing<ArgType> >
    {
        typedef Indexing<ArgType> XprType;
        typedef typename nested_eval<ArgType, XprType::ColsAtCompileTime>::type ArgTypeNested;
        typedef typename remove_all<ArgTypeNested>::type ArgTypeNestedCleaned;
        typedef typename XprType::CoeffReturnType CoeffReturnType;
        typedef typename traits<ArgType>::Scalar Scalar;
        enum {
            CoeffReadCost = evaluator<ArgTypeNestedCleaned>::CoeffReadCost,
            Flags = Eigen::ColMajor
        };

        evaluator(XprType& xpr)
            : m_argImpl(xpr.m_arg), m_rows(xpr.rows())
        { }
        const Scalar& coeffRef(Index row, Index col) const
        {
             return m_argImpl.coeffRef(... very clever stuff ...)
        }

        Scalar& coeffRef(Index row, Index col)
        {
             return m_argImpl.coeffRef(... very clever stuff ...)
        }

        evaluator<ArgTypeNestedCleaned> m_argImpl;
        const Index m_rows;
    };
}
}
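A note on the error message above: it says the evaluator constructor received a const Eigen::Indexing<Derived>, while the struct declares evaluator(XprType& xpr). As far as I can tell, Eigen constructs evaluators from a const reference to the expression (the unary_evaluator in the answer below also takes const XprType&). Writability therefore has to come from what the expression refers to, not from the constness of the expression object. The following Eigen-free sketch illustrates that idea. Plain, ReversedRows, evaluator and make_evaluator are hypothetical stand-ins, not Eigen types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins (not Eigen types): a plain column-major matrix
// and a lightweight view expression that refers to it.
struct Plain {
    std::vector<double> data;
    std::size_t rows, cols;
};

struct ReversedRows {      // expression: holds a non-const reference to its argument
    Plain& arg;
};

template <typename Xpr> struct evaluator;   // primary template, only specialized

template <> struct evaluator<ReversedRows> {
    // Takes const&: a constructor taking ReversedRows& could not accept the
    // const object that make_evaluator passes in (the error in the question).
    explicit evaluator(const ReversedRows& xpr)
        // xpr is const, but xpr.arg is still Plain&, so data() yields double*.
        : m_data(xpr.arg.data.data()), m_rows(xpr.arg.rows) {}

    // Writable coefficient access through the stored non-const pointer.
    double& coeffRef(std::size_t r, std::size_t c) {
        assert(r < m_rows);
        return m_data[(m_rows - 1 - r) + c * m_rows];   // rows read bottom-up
    }

    double* m_data;
    std::size_t m_rows;
};

// The "library" side only ever sees the expression as const.
template <typename Xpr>
evaluator<Xpr> make_evaluator(const Xpr& xpr) { return evaluator<Xpr>(xpr); }
```

Writing through ev.coeffRef(0, 0) changes the last row of the first column of the underlying Plain object, even though the evaluator was built from a const expression.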

Also, I've changed all occurrences of typedef typename Eigen::internal::ref_selector<ArgType>::type to ...::non_const_type, but this had no effect.
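For context on why that change may not help: as far as I recall, Eigen's internal::ref_selector only decides how a nested argument is stored. Types whose traits carry NestByRefBit (plain matrices/arrays) are held by reference, and lightweight expressions are held by value. It does not change the signature of the evaluator constructor, which is what the compile error is about. The sketch below is a simplified, hypothetical imitation of that selection. ref_selector_sketch, nest_by_ref, BigMatrix and CheapExpr are illustrative names, not Eigen's definitions.

```cpp
#include <type_traits>

// Hypothetical flag: true for heavy objects that should be nested by reference.
template <typename T> struct nest_by_ref : std::false_type {};

// Simplified imitation of a ref_selector-like trait (not Eigen's code).
template <typename T>
struct ref_selector_sketch {
    // read-only nesting
    using type = std::conditional_t<nest_by_ref<T>::value, const T&, const T>;
    // writable nesting
    using non_const_type = std::conditional_t<nest_by_ref<T>::value, T&, T>;
};

struct BigMatrix {};   // would own its storage, so it is nested by reference
struct CheapExpr {};   // lightweight expression, copied by value
template <> struct nest_by_ref<BigMatrix> : std::true_type {};
```

Either way, only the member type changes. Code that builds the evaluator from a const expression still hands it a const object.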

Due to the complexity of the Eigen library, I can't figure out how to fit the expression and the evaluator together correctly. I don't understand why my evaluator is read-only, or how to get a write-enabled evaluator. It would be great if someone could provide a minimal example of a modifiable custom expression.

Answer

With the help of ggael's hint, I've been able to successfully add my own modifiable expression. I've basically adapted the IndexedView of the Eigen development branch.

As the originally requested functionality is covered by IndexedView, I've written a modifiable circular shift function as a simple example of a modifiable custom expression. Most of the code is taken directly from IndexedView, so credit goes to its authors.
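The core of the circular shift is the index remapping done by getRowIdx/getColIdx in the header. Element (row, col) of the view reads from row (row - shift) of the original, wrapped into [0, rows). Because the constructor asserts |shift| < rows, a single correction by plus or minus rows is enough. Here is a standalone sketch of that arithmetic. The function name circ_source_index is hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Index of the source element shown at position `pos` after a circular shift
// by `shift` along an axis of length `n`. Same single-wrap logic as
// getRowIdx/getColIdx: assumes |shift| < n, so one correction suffices.
inline std::ptrdiff_t circ_source_index(std::ptrdiff_t pos, std::ptrdiff_t shift,
                                        std::ptrdiff_t n) {
    assert(n > 0 && pos >= 0 && pos < n);
    assert(shift > -n && shift < n);
    std::ptrdiff_t i = pos - shift;
    if (i >= n) return i - n;   // wrapped past the end
    if (i < 0) return i + n;    // wrapped before the start
    return i;
}
```

For example, with 4 rows and a shift of 3, position 0 reads source row 1. With a shift of -3, position 0 reads source row 3.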

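// circ_shift.h  (needs C++17: std::bool_constant, auto return type deduction)
#pragma once
#include <Eigen/Core>
#include <cassert>      // assert
#include <cstdlib>      // std::abs for integer types
#include <type_traits>  // std::true_type, std::enable_if, std::bool_constant
#include <utility>      // std::declval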

namespace helper
{
    namespace detail
    {
        // The first overload is viable only for types derived from MatrixBase/ArrayBase.
        // Taking the base by const reference avoids copy-constructing the base class.
        // Declarations only: these are used solely inside decltype.
        template <typename T>
        constexpr std::true_type is_matrix(const Eigen::MatrixBase<T>&);
        constexpr std::false_type is_matrix(...);

        template <typename T>
        constexpr std::true_type is_array(const Eigen::ArrayBase<T>&);
        constexpr std::false_type is_array(...);
    }

    template <typename T>
    struct is_matrix : decltype(detail::is_matrix(std::declval<std::remove_cv_t<T>>()))
    {
    };

    template <typename T>
    struct is_array : decltype(detail::is_array(std::declval<std::remove_cv_t<T>>()))
    {
    };

    template <typename T>
    using is_matrix_or_array = std::bool_constant<is_array<T>::value || is_matrix<T>::value>;



    /*
     * Index into the shift if it is a matrix or array (one shift per row/column).
     */
    template <typename T, typename std::enable_if<is_matrix_or_array<T>::value, int>::type = 0>
    auto index_if_necessary(T&& thing, Eigen::Index idx)
    {
        return thing(idx);
    }

    /*
     * Overload for a scalar: the same shift applies everywhere, so the index is ignored.
     */
    template <typename T, typename std::enable_if<std::is_scalar<std::decay_t<T>>::value, int>::type = 0>
    auto index_if_necessary(T&& thing, Eigen::Index)
    {
        return thing;
    }
}

namespace Eigen
{
    template <typename XprType, typename RowIndices, typename ColIndices>
    class CircShiftedView;

    namespace internal
    {
        template <typename XprType, typename RowIndices, typename ColIndices>
        struct traits<CircShiftedView<XprType, RowIndices, ColIndices>>
            : traits<XprType>
        {
            enum
            {
                RowsAtCompileTime = traits<XprType>::RowsAtCompileTime,
                ColsAtCompileTime = traits<XprType>::ColsAtCompileTime,
                MaxRowsAtCompileTime = RowsAtCompileTime != Dynamic ? int(RowsAtCompileTime) : int(traits<XprType>::MaxRowsAtCompileTime),
                MaxColsAtCompileTime = ColsAtCompileTime != Dynamic ? int(ColsAtCompileTime) : int(traits<XprType>::MaxColsAtCompileTime),

                XprTypeIsRowMajor = (int(traits<XprType>::Flags) & RowMajorBit) != 0,
                IsRowMajor = (MaxRowsAtCompileTime == 1 && MaxColsAtCompileTime != 1) ? 1
                                 : (MaxColsAtCompileTime == 1 && MaxRowsAtCompileTime != 1) ? 0
                                 : XprTypeIsRowMajor,


                FlagsRowMajorBit = IsRowMajor ? RowMajorBit : 0,
                FlagsLvalueBit = is_lvalue<XprType>::value ? LvalueBit : 0,
                Flags = (traits<XprType>::Flags & HereditaryBits) | FlagsLvalueBit | FlagsRowMajorBit
            };
        };
    }

    template <typename XprType, typename RowShift, typename ColShift, typename StorageKind>
    class CircShiftedViewImpl;


    template <typename XprType, typename RowShift, typename ColShift>
    class CircShiftedView : public CircShiftedViewImpl<XprType, RowShift, ColShift, typename internal::traits<XprType>::StorageKind>
    {
    public:
        typedef typename CircShiftedViewImpl<XprType, RowShift, ColShift, typename internal::traits<XprType>::StorageKind>::Base Base;
        EIGEN_GENERIC_PUBLIC_INTERFACE(CircShiftedView)
        EIGEN_INHERIT_ASSIGNMENT_OPERATORS(CircShiftedView)

        typedef typename internal::ref_selector<XprType>::non_const_type MatrixTypeNested;
        typedef typename internal::remove_all<XprType>::type NestedExpression;

        template <typename T0, typename T1>
        CircShiftedView(XprType& xpr, const T0& rowShift, const T1& colShift)
            : m_xpr(xpr), m_rowShift(rowShift), m_colShift(colShift)
        {
            for (Index c = 0; c < xpr.cols(); ++c)
            {
                // row shift must be within +-(rows()-1)
                assert(std::abs(helper::index_if_necessary(m_rowShift, c)) < m_xpr.rows());
            }
            for (Index r = 0; r < xpr.rows(); ++r)
            {
                // col shift must be within +-(cols()-1)
                assert(std::abs(helper::index_if_necessary(m_colShift, r)) < m_xpr.cols());
            }
        }

        /** \returns number of rows */
        Index rows() const { return m_xpr.rows(); }

        /** \returns number of columns */
        Index cols() const { return m_xpr.cols(); }

        /** \returns the nested expression */
        const typename internal::remove_all<XprType>::type&
        nestedExpression() const { return m_xpr; }

        /** \returns the nested expression */
        typename internal::remove_reference<XprType>::type&
        nestedExpression() { return m_xpr.const_cast_derived(); }

        EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
        Index getRowIdx(Index row, Index col) const
        {
            Index R = m_xpr.rows();
            assert(row >= 0 && row < R && col >= 0 && col < m_xpr.cols());
            Index r = row - helper::index_if_necessary(m_rowShift, col);
            if (r >= R)
                return r - R;
            if (r < 0)
                return r + R;
            return r;
        }

        EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
        Index getColIdx(Index row, Index col) const
        {
            Index C = m_xpr.cols();
            assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < C);
            Index c = col - helper::index_if_necessary(m_colShift, row);
            if (c >= C)
                return c - C;
            if (c < 0)
                return c + C;
            return c;
        }

    protected:
        MatrixTypeNested m_xpr;
        RowShift m_rowShift;
        ColShift m_colShift;
    };


    // Generic API dispatcher
    template <typename XprType, typename RowIndices, typename ColIndices, typename StorageKind>
    class CircShiftedViewImpl
        : public internal::generic_xpr_base<CircShiftedView<XprType, RowIndices, ColIndices>>::type
    {
    public:
        typedef typename internal::generic_xpr_base<CircShiftedView<XprType, RowIndices, ColIndices>>::type Base;
    };

    namespace internal
    {
        template <typename ArgType, typename RowIndices, typename ColIndices>
        struct unary_evaluator<CircShiftedView<ArgType, RowIndices, ColIndices>, IndexBased>
            : evaluator_base<CircShiftedView<ArgType, RowIndices, ColIndices>>
        {
            typedef CircShiftedView<ArgType, RowIndices, ColIndices> XprType;

            enum
            {
                CoeffReadCost = evaluator<ArgType>::CoeffReadCost + NumTraits<Index>::AddCost /* for comparison */ + NumTraits<Index>::AddCost /*for addition*/,

                Flags = (evaluator<ArgType>::Flags & HereditaryBits),

                Alignment = 0
            };

            EIGEN_DEVICE_FUNC explicit unary_evaluator(const XprType& xpr) : m_argImpl(xpr.nestedExpression()), m_xpr(xpr)
            {
                EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
            }

            typedef typename XprType::Scalar Scalar;
            typedef typename XprType::CoeffReturnType CoeffReturnType;


            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
            CoeffReturnType coeff(Index row, Index col) const
            {
                return m_argImpl.coeff(m_xpr.getRowIdx(row, col), m_xpr.getColIdx(row, col));
            }

            EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
            Scalar& coeffRef(Index row, Index col)
            {
                assert(row >= 0 && row < m_xpr.rows() && col >= 0 && col < m_xpr.cols());

                return m_argImpl.coeffRef(m_xpr.getRowIdx(row, col), m_xpr.getColIdx(row, col));
            }

        protected:

            evaluator<ArgType> m_argImpl;
            const XprType& m_xpr;
        };
    } // end namespace internal
} // end namespace Eigen


template <typename XprType, typename RowShift, typename ColShift>
auto circShift(Eigen::DenseBase<XprType>& x, RowShift r, ColShift c)
{
    return Eigen::CircShiftedView<XprType, RowShift, ColShift>(x.derived(), r, c);
}
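The two helper::index_if_necessary overloads in the header let each shift be either a single scalar (the same shift for every row/column) or a vector with one shift per column/row. std::enable_if selects between them. The same dispatch can be shown without Eigen. Here std::vector plays the role of the Eigen vector, and the is_indexable trait is a hypothetical stand-in for is_matrix_or_array.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical trait: treat std::vector as "indexable", standing in for
// the answer's is_matrix_or_array check on Eigen types.
template <typename T> struct is_indexable : std::false_type {};
template <typename U, typename A> struct is_indexable<std::vector<U, A>> : std::true_type {};

// Per-column/per-row shifts: pick entry `idx`.
template <typename T,
          typename std::enable_if<is_indexable<std::decay_t<T>>::value, int>::type = 0>
auto index_if_necessary(T&& thing, std::ptrdiff_t idx) {
    assert(idx >= 0 && static_cast<std::size_t>(idx) < thing.size());
    return thing[static_cast<std::size_t>(idx)];
}

// Uniform shift: the scalar applies everywhere, so the index is ignored.
template <typename T,
          typename std::enable_if<std::is_scalar<std::decay_t<T>>::value, int>::type = 0>
auto index_if_necessary(T&& thing, std::ptrdiff_t) {
    return thing;
}
```

Because the two enable_if conditions are mutually exclusive for these types, exactly one overload is viable per call, just as in the header.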

And a usage example:

// main.cpp
#include <Eigen/Core>
#include <iostream>
#include "circ_shift.h"

using namespace Eigen;


int main()
{

    ArrayXXf x(4, 2);
    x.transpose() << 1, 2, 3, 4, 10, 20, 30, 40;


    Vector2i rowShift;
    rowShift << 3, -3; // shift the view's first column by 3 and its second column by -3

    Index colShift = 1; // shift columns by 1; with two columns this swaps them

    auto shifted = circShift(x, rowShift, colShift);

    std::cout << "shifted: " << std::endl << shifted << std::endl;

    shifted.block(2,0,2,1) << -1, -2; // lands in rows 3 and 0 of the original's column 1
    shifted.col(1) << 2,4,6,8;  // shifted col 1 is col 0 of the original

    std::cout << "modified original:" << std::endl << x << std::endl;

    return 0;
}
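To check what main.cpp should do without building it, the sketch below re-implements the same writable index mapping over a plain column-major std::vector<float>. CircShiftSketch is a hypothetical, Eigen-free stand-in for CircShiftedView and is not part of the answer. By its arithmetic, shifted should initially read (row by row) 20 4 / 30 1 / 40 2 / 10 3. After the two writes, the original x should read 4 -2 / 6 20 / 8 30 / 2 -1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical, Eigen-free stand-in for CircShiftedView over a rows x cols
// float array stored column-major in a std::vector. As in main.cpp, rowShift
// holds one shift per view column and colShift is a single scalar.
struct CircShiftSketch {
    std::vector<float>& data;              // original storage; writes go through to it
    std::ptrdiff_t rows;
    std::ptrdiff_t cols;
    std::vector<std::ptrdiff_t> rowShift;  // one entry per view column
    std::ptrdiff_t colShift;

    // Single wrap into [0, n), valid because |shift| < n.
    static std::ptrdiff_t wrap(std::ptrdiff_t i, std::ptrdiff_t n) {
        if (i >= n) return i - n;
        if (i < 0) return i + n;
        return i;
    }

    // Same mapping as getRowIdx/getColIdx, returning a writable reference.
    float& operator()(std::ptrdiff_t r, std::ptrdiff_t c) {
        assert(r >= 0 && r < rows && c >= 0 && c < cols);
        const std::ptrdiff_t srcRow = wrap(r - rowShift[static_cast<std::size_t>(c)], rows);
        const std::ptrdiff_t srcCol = wrap(c - colShift, cols);
        return data[static_cast<std::size_t>(srcRow + srcCol * rows)];
    }
};
```

Storing a reference to the original data (rather than a copy) is what makes assignments through the view modify x, which is the same property the Eigen version gets from nesting x by reference.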
