diff --git a/util/set/set.go b/util/set/set.go index 987747892..202262a23 100644 --- a/util/set/set.go +++ b/util/set/set.go @@ -4,6 +4,10 @@ // Package set contains set types. package set +import ( + "maps" +) + // Set is a set of T. type Set[T comparable] map[T]struct{} @@ -14,16 +18,28 @@ func SetOf[T comparable](slice []T) Set[T] { return s } -// Add adds e to the set. +// Clone returns a new set cloned from the elements in s. +func Clone[T comparable](s Set[T]) Set[T] { + return maps.Clone(s) +} + +// Add adds e to s. func (s Set[T]) Add(e T) { s[e] = struct{}{} } -// AddSlice adds each element of es to the set. +// AddSlice adds each element of es to s. func (s Set[T]) AddSlice(es []T) { for _, e := range es { s.Add(e) } } +// AddSet adds each element of es to s. +func (s Set[T]) AddSet(es Set[T]) { + for e := range es { + s.Add(e) + } +} + // Slice returns the elements of the set as a slice. The elements will not be // in any particular order. func (s Set[T]) Slice() []T { @@ -45,3 +61,8 @@ func (s Set[T]) Contains(e T) bool { // Len reports the number of items in s. func (s Set[T]) Len() int { return len(s) } + +// Equal reports whether s is equal to other. +func (s Set[T]) Equal(other Set[T]) bool { + return maps.Equal(s, other) +} diff --git a/util/set/set_test.go b/util/set/set_test.go index e898f4f69..1c98631df 100644 --- a/util/set/set_test.go +++ b/util/set/set_test.go @@ -54,7 +54,7 @@ func TestSet(t *testing.T) { func TestSetOf(t *testing.T) { s := SetOf[int]([]int{1, 2, 3, 4, 4, 1}) if s.Len() != 4 { - t.Errorf("wrong len %d; want 2", s.Len()) + t.Errorf("wrong len %d; want 4", s.Len()) } for _, n := range []int{1, 2, 3, 4} { if !s.Contains(n) { @@ -62,3 +62,53 @@ func TestSetOf(t *testing.T) { } } } + +func TestEqual(t *testing.T) { + type test struct { + name string + a Set[int] + b Set[int] + expected bool + } + tests := []test{ + { + "equal", + SetOf([]int{1, 2, 3, 4}), + SetOf([]int{1, 2, 3, 4}), + true, + }, + { + "not equal", + SetOf([]int{1, 2, 3, 4}), + SetOf([]int{1, 2, 3, 5}), + false, + }, + { + "different lengths", + SetOf([]int{1, 2, 3, 4, 5}), + SetOf([]int{1, 2, 3, 5}), + false, + }, + } + + for _, tt := range tests { + if tt.a.Equal(tt.b) != tt.expected { + t.Errorf("%s: failed", tt.name) + } + } +} + +func TestClone(t *testing.T) { + s := SetOf[int]([]int{1, 2, 3, 4, 4, 1}) + if s.Len() != 4 { + t.Errorf("wrong len %d; want 4", s.Len()) + } + s2 := Clone(s) + if !s.Equal(s2) { + t.Error("clone not equal to original") + } + s.Add(100) + if s.Equal(s2) { + t.Error("clone is not distinct from original") + } +}