Skip to content

Commit

Permalink
Update ZipIterator:
Browse files Browse the repository at this point in the history
- Fix operator[]
- Change constexpr to ALPAKA_FN_HOST_ACC
  • Loading branch information
victorjunaidy committed Jan 31, 2022
1 parent f6625f8 commit 53c7f10
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 93 deletions.
32 changes: 15 additions & 17 deletions example/zipIterator/src/zipIterator-main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ inline typename std::enable_if<I < sizeof...(Tp), void>::type forEach(std::tuple
forEach<I + 1, FuncT, Tp...>(t, f);
}

template<typename IteratorTupleVal>
void printTuple(IteratorTupleVal tuple)
template<typename TIteratorTupleVal>
void printTuple(TIteratorTupleVal tuple)
{
std::cout << "(";
int index = 0;
int tupleSize = std::tuple_size<IteratorTupleVal>{};
int tupleSize = std::tuple_size<TIteratorTupleVal>{};
forEach(tuple, [&index, tupleSize](auto &x) { std::cout << x << (++index < tupleSize ? ", " : ""); });
std::cout << ")";
}
Expand Down Expand Up @@ -113,10 +113,10 @@ int main()

std::cout << "\nTesting zip iterator in host with tuple<uint64_t, char, double>\n\n";

using IteratorTuplePtr = std::tuple<uint64_t*, char*, double*>;
using IteratorTupleVal = std::tuple<uint64_t, char, double>;
IteratorTuplePtr zipTuple = std::make_tuple(hostNative, hostNativeChar, hostNativeDouble);
vikunja::mem::iterator::ZipIterator<IteratorTuplePtr, IteratorTupleVal> zipIter(zipTuple);
using TIteratorTuplePtr = std::tuple<uint64_t*, char*, double*>;
using TIteratorTupleVal = std::tuple<uint64_t, char, double>;
TIteratorTuplePtr zipTuple = std::make_tuple(hostNative, hostNativeChar, hostNativeDouble);
vikunja::mem::iterator::ZipIterator<TIteratorTuplePtr, TIteratorTupleVal> zipIter(zipTuple);

std::cout << "*zipIter: ";
printTuple(*zipIter);
Expand Down Expand Up @@ -201,19 +201,17 @@ int main()
std::cout << "\n\n"
<< "-----\n\n";

IteratorTuplePtr deviceZipTuple = std::make_tuple(deviceNative, deviceNativeChar, deviceNativeDouble);
vikunja::mem::iterator::ZipIterator<IteratorTuplePtr, IteratorTupleVal> deviceZipIter(deviceZipTuple);
TIteratorTuplePtr deviceZipTuple = std::make_tuple(deviceNative, deviceNativeChar, deviceNativeDouble);
vikunja::mem::iterator::ZipIterator<TIteratorTuplePtr, TIteratorTupleVal> deviceZipIter(deviceZipTuple);

auto deviceMemResult(alpaka::allocBuf<IteratorTupleVal, Idx>(devAcc, extent));
auto hostMemResult(alpaka::allocBuf<IteratorTupleVal, Idx>(devHost, extent));
IteratorTupleVal* hostNativeResultPtr = alpaka::getPtrNative(hostMemResult);
IteratorTupleVal* deviceNativeResultPtr = alpaka::getPtrNative(deviceMemResult);
auto deviceMemResult(alpaka::allocBuf<TIteratorTupleVal, Idx>(devAcc, extent));
auto hostMemResult(alpaka::allocBuf<TIteratorTupleVal, Idx>(devHost, extent));
TIteratorTupleVal* hostNativeResultPtr = alpaka::getPtrNative(hostMemResult);
TIteratorTupleVal* deviceNativeResultPtr = alpaka::getPtrNative(deviceMemResult);

auto doubleNum = [] ALPAKA_FN_HOST_ACC(IteratorTupleVal const& t)
auto doubleNum = [] ALPAKA_FN_HOST_ACC(TIteratorTupleVal const& t)
{
// return std::make_tuple(2 * std::get<0>(t), std::get<1>(t), 2 * std::get<2>(t));
// return std::make_tuple(static_cast<uint64_t>(5), 'e', static_cast<double>(14.12));
return t;
return std::make_tuple(2 * std::get<0>(t), std::get<1>(t), 2 * std::get<2>(t));
};

vikunja::transform::deviceTransform<TAcc>(
Expand Down
150 changes: 74 additions & 76 deletions include/vikunja/mem/iterator/ZipIterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,117 +40,115 @@ namespace vikunja
{
/**
* @brief A zip iterator that takes multiple input sequences and yields a sequence of tuples
* @tparam IteratorTuplePtr The type of the data
* @tparam IteratorTupleVal The type of the data
* @tparam IdxType The type of the index
* @tparam TIteratorTuplePtr The type of the data
* @tparam TIteratorTupleVal The type of the data
* @tparam TIdx The type of the index
*/
template<typename IteratorTuplePtr, typename IteratorTupleVal, typename IdxType = int64_t>
template<typename TIteratorTuplePtr, typename TIteratorTupleVal, typename TIdx = int64_t>
class ZipIterator
{
public:
// Need all 5 of these types for iterator_traits
using reference = IteratorTupleVal&;
using value_type = IteratorTupleVal;
using pointer = IteratorTupleVal*;
using difference_type = IdxType;
using reference = TIteratorTupleVal&;
using value_type = TIteratorTupleVal;
using pointer = TIteratorTupleVal*;
using difference_type = TIdx;
using iterator_category = std::random_access_iterator_tag;

/**
* @brief Constructor for the ZipIterator
* @param iteratorTuplePtr The tuple to initialize the iterator with
* @param idx The index for the iterator, default 0
*/
constexpr ZipIterator(IteratorTuplePtr iteratorTuplePtr, const IdxType& idx = static_cast<IdxType>(0))
: mIndex(idx)
, mIteratorTuplePtr(iteratorTuplePtr)
, mIteratorTupleVal(makeValueTuple(mIteratorTuplePtr))
ALPAKA_FN_HOST_ACC ZipIterator(TIteratorTuplePtr iteratorTuplePtr, const TIdx& idx = static_cast<TIdx>(0))
: m_index(idx)
, m_iteratorTuplePtr(iteratorTuplePtr)
, m_iteratorTupleVal(makeValueTuple(m_iteratorTuplePtr))
{
if (idx != 0)
{
forEach(mIteratorTuplePtr, [idx](auto &x) { x += idx; });
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
forEach(m_iteratorTuplePtr, [idx](auto &x) { x += idx; });
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
}
}

/**
* @brief Dereference operator to receive the stored value
*/
NODISCARD constexpr ALPAKA_FN_INLINE IteratorTupleVal& operator*()
NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE TIteratorTupleVal& operator*()
{
return mIteratorTupleVal;
return m_iteratorTupleVal;
}

/**
* @brief Index operator to get stored value at some given offset from this iterator
*/
NODISCARD constexpr ALPAKA_FN_INLINE const IteratorTupleVal operator[](const IdxType idx)
NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE TIteratorTupleVal operator[](const TIdx idx) const
{
IteratorTuplePtr tmp = mIteratorTuplePtr;
IdxType indexDiff = idx - mIndex;
forEach(tmp, [indexDiff](auto &x) { x += indexDiff; });
return makeValueTuple(tmp);
TIdx indexDiff = idx - m_index;
return (*this + indexDiff).operator*();
}

// NODISCARD constexpr ALPAKA_FN_INLINE IteratorTupleVal& operator=(IteratorTupleVal iteratorTupleVal)
// NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE TIteratorTupleVal& operator=(TIteratorTupleVal iteratorTupleVal)
// {
// updateIteratorTupleValue(iteratorTupleVal);
// mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
// return mIteratorTupleVal;
// m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
// return m_iteratorTupleVal;
// }

#pragma region arithmeticoperators
/**
* @brief Prefix increment operator
*/
constexpr ALPAKA_FN_INLINE ZipIterator& operator++()
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator++()
{
++mIndex;
forEach(mIteratorTuplePtr, [](auto &x) { ++x; });
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
++m_index;
forEach(m_iteratorTuplePtr, [](auto &x) { ++x; });
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
return *this;
}

/**
* @brief Postfix increment operator
* @note Use prefix increment operator instead if possible to avoid copies
*/
constexpr ZipIterator operator++(int)
ALPAKA_FN_HOST_ACC ZipIterator operator++(int)
{
ZipIterator tmp = *this;
++mIndex;
forEach(mIteratorTuplePtr, [](auto &x) { ++x; });
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
++m_index;
forEach(m_iteratorTuplePtr, [](auto &x) { ++x; });
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
return tmp;
}

/**
* @brief Prefix decrement operator
*/
constexpr ALPAKA_FN_INLINE ZipIterator& operator--()
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator--()
{
--mIndex;
forEach(mIteratorTuplePtr, [](auto &x) { --x; });
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
--m_index;
forEach(m_iteratorTuplePtr, [](auto &x) { --x; });
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
return *this;
}

/**
* @brief Postfix decrement operator
* @note Use prefix decrement operator instead if possible to avoid copies
*/
constexpr ALPAKA_FN_INLINE ZipIterator operator--(int)
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator operator--(int)
{
ZipIterator tmp = *this;
--mIndex;
forEach(mIteratorTuplePtr, [](auto &x) { --x; });
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
--m_index;
forEach(m_iteratorTuplePtr, [](auto &x) { --x; });
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
return tmp;
}

/**
* @brief Add an index to this iterator
*/
NODISCARD constexpr friend ALPAKA_FN_INLINE ZipIterator operator+(ZipIterator zipIter, const IdxType idx)
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE ZipIterator operator+(ZipIterator zipIter, const TIdx idx)
{
zipIter += idx;
return zipIter;
Expand All @@ -159,7 +157,7 @@ namespace vikunja
/**
* @brief Subtract an index from this iterator
*/
NODISCARD constexpr friend ALPAKA_FN_INLINE ZipIterator operator-(ZipIterator zipIter, const IdxType idx)
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE ZipIterator operator-(ZipIterator zipIter, const TIdx idx)
{
zipIter -= idx;
return zipIter;
Expand All @@ -168,22 +166,22 @@ namespace vikunja
/**
* @brief Add an index to this iterator
*/
constexpr ALPAKA_FN_INLINE ZipIterator& operator+=(const IdxType idx)
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator+=(const TIdx idx)
{
mIndex += idx;
forEach(mIteratorTuplePtr, [idx](auto &x) { x += idx; });
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
m_index += idx;
forEach(m_iteratorTuplePtr, [idx](auto &x) { x += idx; });
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
return *this;
}

/**
* @brief Subtract an index from this iterator
*/
constexpr ALPAKA_FN_INLINE ZipIterator& operator-=(const IdxType idx)
ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE ZipIterator& operator-=(const TIdx idx)
{
mIndex -= idx;
forEach(mIteratorTuplePtr, [idx](auto &x) { x -= idx; });
mIteratorTupleVal = makeValueTuple(mIteratorTuplePtr);
m_index -= idx;
forEach(m_iteratorTuplePtr, [idx](auto &x) { x -= idx; });
m_iteratorTupleVal = makeValueTuple(m_iteratorTuplePtr);
return *this;
}

Expand All @@ -196,68 +194,68 @@ namespace vikunja
/**
* @brief Spaceship operator for comparisons
*/
NODISCARD constexpr ALPAKA_FN_INLINE auto operator<=>(const ZipIterator& other) const noexcept
NODISCARD ALPAKA_FN_HOST_ACC ALPAKA_FN_INLINE auto operator<=>(const ZipIterator& other) const noexcept
{
return mIteratorTuplePtr.operator<=>(other.mIteratorTuplePtr);
return m_iteratorTuplePtr.operator<=>(other.m_iteratorTuplePtr);
}

#else

/**
* @brief Equality comparison, returns true if the index are the same
*/
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator==(const ZipIterator& zipIter, const ZipIterator& other) noexcept
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator==(const ZipIterator& zipIter, const ZipIterator& other) noexcept
{
return zipIter.mIndex == other.mIndex;
return zipIter.m_index == other.m_index;
}

/**
* @brief Inequality comparison, negated equality operator
*/
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator!=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator!=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
{
return !operator==(zipIter, other);
}

/**
* @brief Less than comparison, index is checked
*/
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator<(const ZipIterator& zipIter, const ZipIterator& other) noexcept
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator<(const ZipIterator& zipIter, const ZipIterator& other) noexcept
{
return zipIter.mIndex < other.mIndex;
return zipIter.m_index < other.m_index;
}

/**
* @brief Greater than comparison, index is checked
*/
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator>(const ZipIterator& zipIter, const ZipIterator& other) noexcept
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator>(const ZipIterator& zipIter, const ZipIterator& other) noexcept
{
return zipIter.mIndex > other.mIndex;
return zipIter.m_index > other.m_index;
}

/**
* @brief Less than or equal comparison, index is checked
*/
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator<=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator<=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
{
return zipIter.mIndex <= other.mIndex;
return zipIter.m_index <= other.m_index;
}

/**
* @brief Greater than or equal comparison, index is checked
*/
NODISCARD constexpr friend ALPAKA_FN_INLINE bool operator>=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
NODISCARD ALPAKA_FN_HOST_ACC friend ALPAKA_FN_INLINE bool operator>=(const ZipIterator& zipIter, const ZipIterator& other) noexcept
{
return zipIter.mIndex >= other.mIndex;
return zipIter.m_index >= other.m_index;
}
#endif

#pragma endregion comparisonoperators

private:
IdxType mIndex;
IteratorTuplePtr mIteratorTuplePtr;
IteratorTupleVal mIteratorTupleVal;
TIdx m_index;
TIteratorTuplePtr m_iteratorTuplePtr;
TIteratorTupleVal m_iteratorTupleVal;

template<int... Is>
struct seq { };
Expand Down Expand Up @@ -294,17 +292,17 @@ namespace vikunja
forEach<I + 1, FuncT, Tp...>(t, f);
}

template<std::size_t I = 0, typename... Tp>
inline typename std::enable_if<I == sizeof...(Tp), void>::type updateIteratorTupleValue(std::tuple<Tp...> &) // Unused arguments are given no names
{
}
// template<std::size_t I = 0, typename... Tp>
// inline typename std::enable_if<I == sizeof...(Tp), void>::type updateIteratorTupleValue(std::tuple<Tp...> &) // Unused arguments are given no names
// {
// }

template<std::size_t I = 0, typename... Tp>
inline typename std::enable_if<I < sizeof...(Tp), void>::type updateIteratorTupleValue(std::tuple<Tp...>& t)
{
*std::get<I>(mIteratorTuplePtr) = std::get<I>(t);
updateIteratorTupleValue<I + 1, Tp...>(t);
}
// template<std::size_t I = 0, typename... Tp>
// inline typename std::enable_if<I < sizeof...(Tp), void>::type updateIteratorTupleValue(std::tuple<Tp...>& t)
// {
// *std::get<I>(m_iteratorTuplePtr) = std::get<I>(t);
// updateIteratorTupleValue<I + 1, Tp...>(t);
// }
};

} // namespace iterator
Expand Down

0 comments on commit 53c7f10

Please sign in to comment.