Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for Concurrent Map Access in HashSet & LinkedHashSet #265

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
28 changes: 19 additions & 9 deletions sets/hashset/hashset.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package hashset
import (
"fmt"
"strings"
"sync"

"github.com/emirpasic/gods/v2/sets"
)
Expand All @@ -22,13 +23,14 @@ var _ sets.Set[int] = (*Set[int])(nil)
// Set holds elements in go's native map
type Set[T comparable] struct {
items map[T]struct{}
mux *sync.RWMutex
}

var itemExists = struct{}{}

// New instantiates a new empty set and adds the passed values, if any, to the set
func New[T comparable](values ...T) *Set[T] {
set := &Set[T]{items: make(map[T]struct{})}
set := &Set[T]{items: make(map[T]struct{}), mux: &sync.RWMutex{}}
if len(values) > 0 {
set.Add(values...)
}
Expand All @@ -37,13 +39,17 @@ func New[T comparable](values ...T) *Set[T] {

// Add adds the items (one or more) to the set.
func (set *Set[T]) Add(items ...T) {
set.mux.Lock()
defer set.mux.Unlock()
for _, item := range items {
set.items[item] = itemExists
}
}

// Remove removes the items (one or more) from the set.
func (set *Set[T]) Remove(items ...T) {
set.mux.Lock()
defer set.mux.Unlock()
for _, item := range items {
delete(set.items, item)
}
Expand All @@ -53,6 +59,8 @@ func (set *Set[T]) Remove(items ...T) {
// All items have to be present in the set for the method to return true.
// Returns true if no arguments are passed at all, i.e. set is always superset of empty set.
func (set *Set[T]) Contains(items ...T) bool {
set.mux.RLock()
defer set.mux.RUnlock()
for _, item := range items {
if _, contains := set.items[item]; !contains {
return false
Expand All @@ -78,6 +86,8 @@ func (set *Set[T]) Clear() {

// Values returns all items in the set.
func (set *Set[T]) Values() []T {
set.mux.RLock()
defer set.mux.RUnlock()
values := make([]T, set.Size())
count := 0
for item := range set.items {
Expand All @@ -91,7 +101,7 @@ func (set *Set[T]) Values() []T {
func (set *Set[T]) String() string {
str := "HashSet\n"
items := []string{}
for k := range set.items {
for _, k := range set.Values() {
items = append(items, fmt.Sprintf("%v", k))
}
str += strings.Join(items, ", ")
Expand All @@ -106,14 +116,14 @@ func (set *Set[T]) Intersection(another *Set[T]) *Set[T] {

// Iterate over smaller set (optimization)
if set.Size() <= another.Size() {
for item := range set.items {
if _, contains := another.items[item]; contains {
for _, item := range set.Values() {
if another.Contains(item) {
result.Add(item)
}
}
} else {
for item := range another.items {
if _, contains := set.items[item]; contains {
for _, item := range another.Values() {
if set.Contains(item) {
result.Add(item)
}
}
Expand All @@ -128,7 +138,7 @@ func (set *Set[T]) Intersection(another *Set[T]) *Set[T] {
func (set *Set[T]) Union(another *Set[T]) *Set[T] {
result := New[T]()

for item := range set.items {
for _, item := range set.Values() {
result.Add(item)
}
for item := range another.items {
Expand All @@ -144,8 +154,8 @@ func (set *Set[T]) Union(another *Set[T]) *Set[T] {
func (set *Set[T]) Difference(another *Set[T]) *Set[T] {
result := New[T]()

for item := range set.items {
if _, contains := another.items[item]; !contains {
for _, item := range set.Values() {
if !another.Contains(item) {
result.Add(item)
}
}
Expand Down
95 changes: 95 additions & 0 deletions sets/hashset/hashset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ package hashset

import (
"encoding/json"
"fmt"
"strings"
"sync"
"testing"
)

Expand Down Expand Up @@ -344,3 +346,96 @@ func BenchmarkHashSetRemove100000(b *testing.B) {
b.StartTimer()
benchmarkRemove(b, set, size)
}

func TestConcurrentAdd(t *testing.T) {
s := New(1)

size := 100000
wg := &sync.WaitGroup{}
wg.Add(2)

go func() {
defer wg.Done()
for i := 0; i < size; i += 2 {
s.Add(i)

}
}()
go func() {
defer wg.Done()
for i := 1; i < size; i += 2 {
s.Add(i)
}
}()

wg.Wait()
fmt.Println(s.Size())
}

func TestConcurrentRemove(t *testing.T) {
s := New(1)
size := 100000
for i := 0; i < size; i++ {
s.Add(i)
}
fmt.Println(s.Size())
wg := &sync.WaitGroup{}
wg.Add(2)

go func() {
defer wg.Done()
for i := 0; i < size; i += 2 {
s.Remove(i)
}
}()

go func() {
defer wg.Done()
for i := 1; i < size; i += 2 {
s.Remove(i)
}
}()

wg.Wait()

fmt.Println(s.Size())
}

func TestConcurrentRW(t *testing.T) {
s := New(1)
size := 1000
wg := &sync.WaitGroup{}
wg.Add(3)

go func() {
defer wg.Done()
for i := 0; i < size; i++ {
s.Add(i)
}
}()

go func() {
defer wg.Done()
for i := 0; i < size; i += 2 {
// _ = s.Contains(i)
// _ = s.Values()
// _ = s.String()
// s.Intersection(New(-1))
// s.Union(New(-1))
s.Difference(New(1))
}
}()

go func() {
defer wg.Done()
for i := 1; i < size; i += 2 {
// _ = s.Contains(i)
// _ = s.Values()
// _ = s.String()
// s.Intersection(New(-2))
// s.Union(New(-2))
s.Difference(New(2))
}
}()
wg.Wait()
}
27 changes: 19 additions & 8 deletions sets/linkedhashset/linkedhashset.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package linkedhashset
import (
"fmt"
"strings"
"sync"

"github.com/emirpasic/gods/v2/lists/doublylinkedlist"
"github.com/emirpasic/gods/v2/sets"
Expand All @@ -28,6 +29,7 @@ var _ sets.Set[int] = (*Set[int])(nil)
type Set[T comparable] struct {
table map[T]struct{}
ordering *doublylinkedlist.List[T]
mux *sync.RWMutex
}

var itemExists = struct{}{}
Expand All @@ -37,6 +39,7 @@ func New[T comparable](values ...T) *Set[T] {
set := &Set[T]{
table: make(map[T]struct{}),
ordering: doublylinkedlist.New[T](),
mux: &sync.RWMutex{},
}
if len(values) > 0 {
set.Add(values...)
Expand All @@ -47,6 +50,8 @@ func New[T comparable](values ...T) *Set[T] {
// Add adds the items (one or more) to the set.
// Note that insertion-order is not affected if an element is re-inserted into the set.
func (set *Set[T]) Add(items ...T) {
set.mux.Lock()
defer set.mux.Unlock()
for _, item := range items {
if _, contains := set.table[item]; !contains {
set.table[item] = itemExists
Expand All @@ -58,6 +63,8 @@ func (set *Set[T]) Add(items ...T) {
// Remove removes the items (one or more) from the set.
// Slow operation, worst-case O(n^2).
func (set *Set[T]) Remove(items ...T) {
set.mux.Lock()
defer set.mux.Unlock()
for _, item := range items {
if _, contains := set.table[item]; contains {
delete(set.table, item)
Expand All @@ -71,6 +78,8 @@ func (set *Set[T]) Remove(items ...T) {
// All items have to be present in the set for the method to return true.
// Returns true if no arguments are passed at all, i.e. set is always superset of empty set.
func (set *Set[T]) Contains(items ...T) bool {
set.mux.RLock()
defer set.mux.RUnlock()
for _, item := range items {
if _, contains := set.table[item]; !contains {
return false
Expand All @@ -97,6 +106,8 @@ func (set *Set[T]) Clear() {

// Values returns all items in the set.
func (set *Set[T]) Values() []T {
set.mux.RLock()
defer set.mux.RUnlock()
values := make([]T, set.Size())
it := set.Iterator()
for it.Next() {
Expand Down Expand Up @@ -125,14 +136,14 @@ func (set *Set[T]) Intersection(another *Set[T]) *Set[T] {

// Iterate over smaller set (optimization)
if set.Size() <= another.Size() {
for item := range set.table {
if _, contains := another.table[item]; contains {
for _, item := range set.Values() {
if another.Contains(item) {
result.Add(item)
}
}
} else {
for item := range another.table {
if _, contains := set.table[item]; contains {
for _, item := range another.Values() {
if set.Contains(item) {
result.Add(item)
}
}
Expand All @@ -147,10 +158,10 @@ func (set *Set[T]) Intersection(another *Set[T]) *Set[T] {
func (set *Set[T]) Union(another *Set[T]) *Set[T] {
result := New[T]()

for item := range set.table {
for _, item := range set.Values() {
result.Add(item)
}
for item := range another.table {
for _, item := range another.Values() {
result.Add(item)
}

Expand All @@ -163,8 +174,8 @@ func (set *Set[T]) Union(another *Set[T]) *Set[T] {
func (set *Set[T]) Difference(another *Set[T]) *Set[T] {
result := New[T]()

for item := range set.table {
if _, contains := another.table[item]; !contains {
for _, item := range set.Values() {
if !another.Contains(item) {
result.Add(item)
}
}
Expand Down
Loading