From 53c7f1019d86f491cad6f8761840125f08c81146 Mon Sep 17 00:00:00 2001 From: Victor Date: Mon, 31 Jan 2022 16:11:56 +0100 Subject: [PATCH] Update ZipIterator: - Fix operator[] - Change constexpr to ALPAKA_FN_HOST_ACC --- example/zipIterator/src/zipIterator-main.cpp | 32 ++-- include/vikunja/mem/iterator/ZipIterator.hpp | 150 +++++++++---------- 2 files changed, 89 insertions(+), 93 deletions(-) diff --git a/example/zipIterator/src/zipIterator-main.cpp b/example/zipIterator/src/zipIterator-main.cpp index d93f850..c7e6caa 100644 --- a/example/zipIterator/src/zipIterator-main.cpp +++ b/example/zipIterator/src/zipIterator-main.cpp @@ -27,12 +27,12 @@ inline typename std::enable_if::type forEach(std::tuple forEach(t, f); } -template -void printTuple(IteratorTupleVal tuple) +template +void printTuple(TIteratorTupleVal tuple) { std::cout << "("; int index = 0; - int tupleSize = std::tuple_size{}; + int tupleSize = std::tuple_size{}; forEach(tuple, [&index, tupleSize](auto &x) { std::cout << x << (++index < tupleSize ? ", " : ""); }); std::cout << ")"; } @@ -113,10 +113,10 @@ int main() std::cout << "\nTesting zip iterator in host with tuple\n\n"; - using IteratorTuplePtr = std::tuple; - using IteratorTupleVal = std::tuple; - IteratorTuplePtr zipTuple = std::make_tuple(hostNative, hostNativeChar, hostNativeDouble); - vikunja::mem::iterator::ZipIterator zipIter(zipTuple); + using TIteratorTuplePtr = std::tuple; + using TIteratorTupleVal = std::tuple; + TIteratorTuplePtr zipTuple = std::make_tuple(hostNative, hostNativeChar, hostNativeDouble); + vikunja::mem::iterator::ZipIterator zipIter(zipTuple); std::cout << "*zipIter: "; printTuple(*zipIter); @@ -201,19 +201,17 @@ int main() std::cout << "\n\n" << "-----\n\n"; - IteratorTuplePtr deviceZipTuple = std::make_tuple(deviceNative, deviceNativeChar, deviceNativeDouble); - vikunja::mem::iterator::ZipIterator deviceZipIter(deviceZipTuple); + TIteratorTuplePtr deviceZipTuple = std::make_tuple(deviceNative, deviceNativeChar, deviceNativeDouble); + vikunja::mem::iterator::ZipIterator deviceZipIter(deviceZipTuple); - auto deviceMemResult(alpaka::allocBuf(devAcc, extent)); - auto hostMemResult(alpaka::allocBuf(devHost, extent)); - IteratorTupleVal* hostNativeResultPtr = alpaka::getPtrNative(hostMemResult); - IteratorTupleVal* deviceNativeResultPtr = alpaka::getPtrNative(deviceMemResult); + auto deviceMemResult(alpaka::allocBuf(devAcc, extent)); + auto hostMemResult(alpaka::allocBuf(devHost, extent)); + TIteratorTupleVal* hostNativeResultPtr = alpaka::getPtrNative(hostMemResult); + TIteratorTupleVal* deviceNativeResultPtr = alpaka::getPtrNative(deviceMemResult); - auto doubleNum = [] ALPAKA_FN_HOST_ACC(IteratorTupleVal const& t) + auto doubleNum = [] ALPAKA_FN_HOST_ACC(TIteratorTupleVal const& t) { - // return std::make_tuple(2 * std::get<0>(t), std::get<1>(t), 2 * std::get<2>(t)); - // return std::make_tuple(static_cast(5), 'e', static_cast(14.12)); - return t; + return std::make_tuple(2 * std::get<0>(t), std::get<1>(t), 2 * std::get<2>(t)); }; vikunja::transform::deviceTransform( diff --git a/include/vikunja/mem/iterator/ZipIterator.hpp b/include/vikunja/mem/iterator/ZipIterator.hpp index ac1e6d0..e918ad9 100644 --- a/include/vikunja/mem/iterator/ZipIterator.hpp +++ b/include/vikunja/mem/iterator/ZipIterator.hpp @@ -40,19 +40,19 @@ namespace vikunja { /** * @brief A zip iterator that takes multiple input sequences and yields a sequence of tuples - * @tparam IteratorTuplePtr The type of the data - * @tparam IteratorTupleVal The type of the data - * @tparam IdxType The type of the index + * @tparam TIteratorTuplePtr The type of the data + * @tparam TIteratorTupleVal The type of the data + * @tparam TIdx The type of the index */ - template + template class ZipIterator { public: // Need all 5 of these types for iterator_traits - using reference = IteratorTupleVal&; - using value_type = IteratorTupleVal; - using pointer = IteratorTupleVal*; - using difference_type = IdxType; + using reference = TIteratorTupleVal&; + using value_type = TIteratorTupleVal; + using pointer = TIteratorTupleVal*; + using difference_type = TIdx; using iterator_category = std::random_access_iterator_tag; /** @@ -60,53 +60,51 @@ namespace vikunja * @param iteratorTuplePtr The tuple to initialize the iterator with * @param idx The index for the iterator, default 0 */ - constexpr ZipIterator(IteratorTuplePtr iteratorTuplePtr, const IdxType& idx = static_cast(0)) - : mIndex(idx) - , mIteratorTuplePtr(iteratorTuplePtr) - , mIteratorTupleVal(makeValueTuple(mIteratorTuplePtr)) + ALPAKA_FN_HOST_ACC ZipIterator(TIteratorTuplePtr iteratorTuplePtr, const TIdx& idx = static_cast(0)) + : m_index(idx) + , m_iteratorTuplePtr(iteratorTuplePtr) + , m_iteratorTupleVal(makeValueTuple(m_iteratorTuplePtr)) { if (idx != 0) { - forEach(mIteratorTuplePtr, [idx](auto &x) { x += idx; }); - mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr); + forEach(m_iteratorTuplePtr, [idx](auto &x) { x += idx; }); + m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr); } } /** * @brief Dereference operator to receive the stored value */ - NODISCARD constexpr ALPAKA_FN_INLINE IteratorTupleVal& operator*() + NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE TIteratorTupleVal& operator*() { - return mIteratorTupleVal; + return m_iteratorTupleVal; } /** * @brief Index operator to get stored value at some given offset from this iterator */ - NODISCARD constexpr ALPAKA_FN_INLINE const IteratorTupleVal operator[](const IdxType idx) + NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE TIteratorTupleVal operator[](const TIdx idx) const { - IteratorTuplePtr tmp = mIteratorTuplePtr; - IdxType indexDiff = idx - mIndex; - forEach(tmp, [indexDiff](auto &x) { x += indexDiff; }); - return makeValueTuple(tmp); + TIdx indexDiff = idx - m_index; + return (*this + indexDiff).operator*(); } - // NODISCARD constexpr ALPAKA_FN_INLINE IteratorTupleVal& operator=(IteratorTupleVal iteratorTupleVal) + // NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE TIteratorTupleVal& operator=(TIteratorTupleVal iteratorTupleVal) // { // updateIteratorTupleValue(iteratorTupleVal); - // mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr); - // return mIteratorTupleVal; + // m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr); + // return m_iteratorTupleVal; // } #pragma region arithmeticoperators /** * @brief Prefix increment operator */ - constexpr ALPAKA_FN_INLINE ZipIterator& operator++() + ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator++() { - ++mIndex; - forEach(mIteratorTuplePtr, [](auto &x) { ++x; }); - mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr); + ++m_index; + forEach(m_iteratorTuplePtr, [](auto &x) { ++x; }); + m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr); return *this; } @@ -114,23 +112,23 @@ namespace vikunja * @brief Postfix increment operator * @note Use prefix increment operator instead if possible to avoid copies */ - constexpr ZipIterator operator++(int) + ALPAKA_FN_HOST_ACC ZipIterator operator++(int) { ZipIterator tmp = *this; - ++mIndex; - forEach(mIteratorTuplePtr, [](auto &x) { ++x; }); - mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr); + ++m_index; + forEach(m_iteratorTuplePtr, [](auto &x) { ++x; }); + m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr); return tmp; } /** * @brief Prefix decrement operator */ - constexpr ALPAKA_FN_INLINE ZipIterator& operator--() + ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator--() { - --mIndex; - forEach(mIteratorTuplePtr, [](auto &x) { --x; }); - mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr); + --m_index; + forEach(m_iteratorTuplePtr, [](auto &x) { --x; }); + m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr); return *this; } @@ -138,19 +136,19 @@ namespace vikunja * @brief Postfix decrement operator * @note Use prefix decrement operator instead if possible to avoid copies */ - constexpr ALPAKA_FN_INLINE ZipIterator operator--(int) + ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator operator--(int) { ZipIterator tmp = *this; - --mIndex; - forEach(mIteratorTuplePtr, [](auto &x) { --x; }); - mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr); + --m_index; + forEach(m_iteratorTuplePtr, [](auto &x) { --x; }); + m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr); return tmp; } /** * @brief Add an index to this iterator */ - NODISCARD constexpr friend ALPAKA_FN_INLINE ZipIterator operator+(ZipIterator zipIter, const IdxType idx) + NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE ZipIterator operator+(ZipIterator zipIter, const TIdx idx) { zipIter += idx; return zipIter; @@ -159,7 +157,7 @@ namespace vikunja /** * @brief Subtract an index from this iterator */ - NODISCARD constexpr friend ALPAKA_FN_INLINE ZipIterator operator-(ZipIterator zipIter, const IdxType idx) + NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE ZipIterator operator-(ZipIterator zipIter, const TIdx idx) { zipIter -= idx; return zipIter; @@ -168,22 +166,22 @@ namespace vikunja /** * @brief Add an index to this iterator */ - constexpr ALPAKA_FN_INLINE ZipIterator& operator+=(const IdxType idx) + ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator+=(const TIdx idx) { - mIndex += idx; - forEach(mIteratorTuplePtr, [idx](auto &x) { x += idx; }); - mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr); + m_index += idx; + forEach(m_iteratorTuplePtr, [idx](auto &x) { x += idx; }); + m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr); return *this; } /** * @brief Subtract an index from this iterator */ - constexpr ALPAKA_FN_INLINE ZipIterator& operator-=(const IdxType idx) + ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator-=(const TIdx idx) { - mIndex -= idx; - forEach(mIteratorTuplePtr, [idx](auto &x) { x -= idx; }); - mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr); + m_index -= idx; + forEach(m_iteratorTuplePtr, [idx](auto &x) { x -= idx; }); + m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr); return *this; } @@ -196,9 +194,9 @@ namespace vikunja /** * @brief Spaceship operator for comparisons */ - NODISCARD constexpr ALPAKA_FN_INLINE auto operator<=>(const ZipIterator& other) const noexcept + NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE auto operator<=>(const ZipIterator& other) const noexcept { - return mIteratorTuplePtr.operator<=>(other.mIteratorTuplePtr); + return m_iteratorTuplePtr.operator<=>(other.m_iteratorTuplePtr); } #else @@ -206,15 +204,15 @@ namespace vikunja /** * @brief Equality comparison, returns true if the index are the same */ - NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator==(const ZipIterator& zipIter, const ZipIterator& other) noexcept + NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator==(const ZipIterator& zipIter, const ZipIterator& other) noexcept { - return zipIter.mIndex == other.mIndex; + return zipIter.m_index == other.m_index; } /** * @brief Inequality comparison, negated equality operator */ - NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator!=(const ZipIterator& zipIter, const ZipIterator& other) noexcept + NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator!=(const ZipIterator& zipIter, const ZipIterator& other) noexcept { return !operator==(zipIter, other); } @@ -222,42 +220,42 @@ namespace vikunja /** * @brief Less than comparison, index is checked */ - NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator<(const ZipIterator& zipIter, const ZipIterator& other) noexcept + NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator<(const ZipIterator& zipIter, const ZipIterator& other) noexcept { - return zipIter.mIndex < other.mIndex; + return zipIter.m_index < other.m_index; } /** * @brief Greater than comparison, index is checked */ - NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator>(const ZipIterator& zipIter, const ZipIterator& other) noexcept + NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator>(const ZipIterator& zipIter, const ZipIterator& other) noexcept { - return zipIter.mIndex > other.mIndex; + return zipIter.m_index > other.m_index; } /** * @brief Less than or equal comparison, index is checked */ - NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator<=(const ZipIterator& zipIter, const ZipIterator& other) noexcept + NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator<=(const ZipIterator& zipIter, const ZipIterator& other) noexcept { - return zipIter.mIndex <= other.mIndex; + return zipIter.m_index <= other.m_index; } /** * @brief Greater than or equal comparison, index is checked */ - NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator>=(const ZipIterator& zipIter, const ZipIterator& other) noexcept + NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator>=(const ZipIterator& zipIter, const ZipIterator& other) noexcept { - return zipIter.mIndex >= other.mIndex; + return zipIter.m_index >= other.m_index; } #endif #pragma endregion comparisonoperators private: - IdxType mIndex; - IteratorTuplePtr mIteratorTuplePtr; - IteratorTupleVal mIteratorTupleVal; + TIdx m_index; + TIteratorTuplePtr m_iteratorTuplePtr; + TIteratorTupleVal m_iteratorTupleVal; template struct seq { }; @@ -294,17 +292,17 @@ namespace vikunja forEach(t, f); } - template - inline typename std::enable_if::type updateIteratorTupleValue(std::tuple &) // Unused arguments are given no names - { - } + // template + // inline typename std::enable_if::type updateIteratorTupleValue(std::tuple &) // Unused arguments are given no names + // { + // } - template - inline typename std::enable_if::type updateIteratorTupleValue(std::tuple& t) - { - *std::get(mIteratorTuplePtr) = std::get(t); - updateIteratorTupleValue(t); - } + // template + // inline typename std::enable_if::type updateIteratorTupleValue(std::tuple& t) + // { + // *std::get(m_iteratorTuplePtr) = std::get(t); + // updateIteratorTupleValue(t); + // } }; } // namespace iterator