Skip to content

Commit

Permalink
added iterator and complex type support
Browse files Browse the repository at this point in the history
  • Loading branch information
gk646 committed Apr 5, 2024
1 parent 150fea9 commit 708303f
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/cxstructs/StackArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
#ifndef CXSTRUCTS_SRC_CXSTRUCTS_STACKARRAY_H_
#define CXSTRUCTS_SRC_CXSTRUCTS_STACKARRAY_H_

#include <iterator> // For std::forward_iterator_tag
#include "../cxconfig.h"
#include <iterator> // For std::forward_iterator_tag

template <typename T, size_t N>
class StackArray {
Expand Down
253 changes: 233 additions & 20 deletions src/cxstructs/StackHashMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
#define CXSTRUCTS_SRC_CXSTRUCTS_STACKHASHMAP_H_

#include "../cxconfig.h"
#include <bitset> //For std::bitset<N> and std::hash<K>
#include <bitset> //For std::bitset<N> and std::hash<K>
#include <iterator> // For std::forward_iterator_tag

//Memory footprint is still quite big
//Using std::bitset<> saves memory but is a bit slower

/**
* StackHashMap is a hash map implemented entirely on the stack with an STL-like interface.<br>
Expand All @@ -36,7 +40,7 @@
*
* Performance Characteristics:
* - Build-up is generally slower compared to other hash maps at equal load factors.<br>
* - Lookup operations are approximately twice as fast as those in std::unordered_map.<br>
* - Lookup operations are approximately 2-3 times as fast as those in std::unordered_map.<br>
* - Erase operations are significantly faster due to the nature of the implementation.<br>
*
* @tparam K Type of the keys.
Expand All @@ -50,8 +54,6 @@
template <typename K, typename V, size_t N, typename HashFunc = std::hash<K>,
typename size_type = uint32_t>
class StackHashMap {
static_assert(std::is_trivial_v<K> && std::is_trivial_v<V>,
"StackHashMap only supports trivial types for now");
using const_key_ref = const K&;
using const_val_ref = const V&;
using move_key_ref = K&&;
Expand All @@ -66,7 +68,7 @@ class StackHashMap {

HashFunc hash_func_;
Node data_[N];
bool register_ [N] = {0};
bool register_[N] = {0};
size_type size_ = 0;
int rand_ = rand();

Expand All @@ -77,19 +79,27 @@ class StackHashMap {
return (hash_func_(key) * rand_) % N;
}

inline void add_node(size_t hash, const_key_ref key, auto val) {
auto& isSet = register_[hash];
size_ += isSet;
isSet = true;
data_[hash].val = val;
data_[hash].key = key;
inline void add_node(size_t hash, const_key_ref key, auto val) noexcept {
if (!register_[hash]) {
if constexpr (std::is_trivially_copyable_v<K> && std::is_trivially_copyable_v<V>) {
data_[hash].key = key;
data_[hash].val = val;
} else {
new (&data_[hash].key) K(key);
new (&data_[hash].val) V(val);
}
register_[hash] = true;
size_++;
} else {
data_[hash].val = val;
}
}

inline void insert_non_empty(const K& key, auto val) {
inline void insert_non_empty(const K& key, auto val) noexcept {
Node org_data[N];
bool org_register [N];
bool org_register[N];
std::memcpy(org_data, data_, N * sizeof(Node));
std::memcpy(org_register, register_, N );
std::memcpy(org_register, register_, N);

while (true) {
std::memset(register_, false, N);
Expand Down Expand Up @@ -141,10 +151,35 @@ class StackHashMap {

public:
StackHashMap() = default;
StackHashMap(const StackHashMap& other) {
if constexpr (std::is_trivial_v<K> && std::is_trivial_v<V>) {
std::memcpy(data_, other.data_, N * sizeof(Node));
} else {
for (size_t i = 0; i < N; ++i) {
if (other.register_[i]) {
new (&data_[i].key) K(other.data_[i].key);
new (&data_[i].val) V(other.data_[i].val);
}
}
}
std::memcpy(register_, other.register_, N);
size_ = other.size_;
rand_ = other.rand_;
}
explicit StackHashMap(size_type elems) : size_(elems) {}
~StackHashMap() {
if constexpr (!std::is_trivially_destructible_v<K> || !std::is_trivially_destructible_v<V>) {
for (size_t i = 0; i < N; ++i) {
if (register_[i]) {
data_[i].key.~K();
data_[i].val.~V();
}
}
}
}

//Call can result in endless loop at high load factor
inline void insert(const K& key, V&& val) {
inline void insert(const K& key, V&& val) noexcept {
CX_ASSERT(size_ < N, "Trying to add to full StackHashMap");
CX_STACK_ABORT_IMPL();

Expand All @@ -158,7 +193,7 @@ class StackHashMap {
}

//Call can result in endless loop at high load factor
inline void insert(const K& key, const V& val) {
inline void insert(const K& key, const V& val) noexcept {
CX_ASSERT(size_ < N, "Trying to add to full StackHashMap");
CX_STACK_ABORT_IMPL();

Expand All @@ -171,7 +206,7 @@ class StackHashMap {
add_node(hash, key, val);
}

inline void insert(const std::pair<const K, V>& keyValue) {
inline void insert(const std::pair<const K, V>& keyValue) noexcept {
CX_ASSERT(size_ < N, "Trying to add to full StackHashMap");
CX_STACK_ABORT_IMPL();

Expand Down Expand Up @@ -206,7 +241,14 @@ class StackHashMap {
inline bool erase(const K& key) noexcept {
const size_t hash = impl_hash_func(key);
if (register_[hash]) {
if constexpr (!std::is_trivial_v<K>) {
data_[hash].key.~K();
}
if constexpr (!std::is_trivial_v<V>) {
data_[hash].val.~V();
}
register_[hash] = false;
size_--;
return true;
}
return false;
Expand All @@ -217,15 +259,166 @@ class StackHashMap {
return register_[hash] && data_[hash].key == key;
}

inline void set_rand(int rand) noexcept { rand_ = rand; }

inline void clear() noexcept {
register_.reset();
std::memset(register_, false, N);
size_ = 0;
}

[[nodiscard]] inline bool empty() const noexcept { return register_.any(); }
[[nodiscard]] inline bool empty() const noexcept { return !register_.any(); }

[[nodiscard]] inline float load_factor() const noexcept { return (float)size_ / (float)N; }

[[nodiscard]] inline size_t get_hash(const K& key) const noexcept { return impl_hash_func(key); }

// Iterator support
class KeyIterator {
public:
using iterator_category = std::forward_iterator_tag;
using value_type = K;
using difference_type = std::ptrdiff_t;
using pointer = K*;
using reference = K&;

explicit KeyIterator(Node* ptr, bool* reg, size_t size)
: ptr_(ptr), reg_(reg), size_(size), index_(0) {
if (!reg_[index_]) {
++(*this);
}
}

reference operator*() const { return ptr_[index_].key; }
pointer operator->() { return &(ptr_[index_].key); }

// Prefix increment
KeyIterator& operator++() {
do {
index_++;
} while (index_ < size_ && !reg_[index_]);
return *this;
}

// Postfix increment
KeyIterator operator++(int) {
KeyIterator tmp = *this;
++(*this);
return tmp;
}

friend bool operator==(const KeyIterator& a, const KeyIterator& b) {
return a.ptr_ + a.index_ == b.ptr_ + b.index_;
}
friend bool operator!=(const KeyIterator& a, const KeyIterator& b) { return !(a == b); }

private:
Node* ptr_;
bool* reg_;
size_t size_;
size_t index_;
};

class ValueIterator {
public:
using iterator_category = std::forward_iterator_tag;
using value_type = V;
using difference_type = std::ptrdiff_t;
using pointer = V*;
using reference = V&;

explicit ValueIterator(Node* ptr, bool* reg, size_t size)
: ptr_(ptr), reg_(reg), size_(size), index_(0) {
if (!reg_[index_]) {
++(*this);
}
}

reference operator*() const { return ptr_[index_].key; }
pointer operator->() { return &(ptr_[index_].key); }

// Prefix increment
ValueIterator& operator++() {
do {
index_++;
} while (index_ < size_ && !reg_[index_]);
return *this;
}

// Postfix increment
ValueIterator operator++(int) {
ValueIterator tmp = *this;
++(*this);
return tmp;
}

friend bool operator==(const ValueIterator& a, const ValueIterator& b) {
return a.ptr_ + a.index_ == b.ptr_ + b.index_;
}
friend bool operator!=(const ValueIterator& a, const ValueIterator& b) { return !(a == b); }

private:
Node* ptr_;
bool* reg_;
size_t size_;
size_t index_;
};

class PairIterator {
public:
using iterator_category = std::forward_iterator_tag;
using value_type = Node;
using difference_type = std::ptrdiff_t;
using pointer = Node*;
using reference = Node&;

explicit PairIterator(Node* ptr, bool* reg, size_t size)
: ptr_(ptr), reg_(reg), size_(size), index_(0) {
if (!reg_[index_]) {
++(*this);
}
}

reference operator*() const { return ptr_[index_]; }
pointer operator->() { return &(ptr_[index_]); }

// Prefix increment
PairIterator& operator++() {
do {
index_++;
} while (index_ < size_ && !reg_[index_]);
return *this;
}

// Postfix increment
PairIterator operator++(int) {
PairIterator tmp = *this;
++(*this);
return tmp;
}

friend bool operator==(const PairIterator& a, const PairIterator& b) {
return a.ptr_ + a.index_ == b.ptr_ + b.index_;
}
friend bool operator!=(const PairIterator& a, const PairIterator& b) {
return a.ptr_ + a.index_ != b.ptr_ + b.index_;
}

private:
Node* ptr_;
bool* reg_;
size_t size_;
size_t index_;
};

KeyIterator key_begin() { return KeyIterator(data_, register_, N); }
KeyIterator key_end() { return KeyIterator(data_ + N, register_, N); }

ValueIterator value_begin() { return ValueIterator(data_, register_, N); }
ValueIterator value_end() { return ValueIterator(data_ + N, register_, N); }

PairIterator begin() { return PairIterator(data_, register_, N); }
PairIterator end() { return PairIterator(data_ + N, register_, N); }

#ifdef CX_INCLUDE_TESTS
static void TEST() {
StackHashMap<int, int, 100> map1;
Expand All @@ -250,14 +443,34 @@ class StackHashMap {
map1.clear();

for (int i = 0; i < 50; i++) {
printf("%f\n", map1.load_factor());
map1.insert(i, i * 10);
CX_ASSERT(map1[i] == i * 10, "");
}

for (int i = 0; i < 50; i++) {
CX_ASSERT(map1[i] == i * 10, "");
}

StackHashMap<int, int, 64> map2;
for (int i = 0; i < 50; i++) {
map2.insert(i, i * 10);
CX_ASSERT(map2[i] == i * 10, "");
}

for (auto& pair : map2) {
pair.val = 10;
}

for (auto& pair : map2) {
CX_ASSERT(pair.val == 10, "");
}

StackHashMap<std::string, int, 128> myMap2;

myMap2.insert("hey", 100);
myMap2.insert("blabla", 100);

printf("%d", myMap2["hey"]);
}
#endif
};
Expand Down

0 comments on commit 708303f

Please sign in to comment.