From 2c1f14d9e6e785eae8f82e88fe2651cd512d9f67 Mon Sep 17 00:00:00 2001 From: Andrew Lytvynov Date: Mon, 20 Nov 2023 09:00:31 -0700 Subject: [PATCH] util/set: implement json.Marshaler/Unmarshaler (#10308) Marshal as a JSON list instead of a map. Because set elements are `comparable` and not `cmp.Ordered`, we cannot easily sort the items before marshaling. Updates #cleanup Signed-off-by: Andrew Lytvynov --- util/set/set.go | 14 ++++++++++++++ util/set/set_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/util/set/set.go b/util/set/set.go index 6a0111f9c..78c929818 100644 --- a/util/set/set.go +++ b/util/set/set.go @@ -5,6 +5,7 @@ package set import ( + "encoding/json" "maps" ) @@ -66,3 +67,16 @@ func (s Set[T]) Len() int { return len(s) } func (s Set[T]) Equal(other Set[T]) bool { return maps.Equal(s, other) } + +func (s Set[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(s.Slice()) +} + +func (s *Set[T]) UnmarshalJSON(buf []byte) error { + var ss []T + if err := json.Unmarshal(buf, &ss); err != nil { + return err + } + *s = SetOf(ss) + return nil +} diff --git a/util/set/set_test.go b/util/set/set_test.go index c0c7826ec..cff81c776 100644 --- a/util/set/set_test.go +++ b/util/set/set_test.go @@ -4,6 +4,7 @@ package set import ( + "encoding/json" "slices" "testing" ) @@ -112,3 +113,48 @@ func TestClone(t *testing.T) { t.Error("clone is not distinct from original") } } + +func TestSetJSONRoundTrip(t *testing.T) { + tests := []struct { + desc string + strings Set[string] + ints Set[int] + }{ + {"empty", make(Set[string]), make(Set[int])}, + {"nil", nil, nil}, + {"one-item", SetOf([]string{"one"}), SetOf([]int{1})}, + {"multiple-items", SetOf([]string{"one", "two", "three"}), SetOf([]int{1, 2, 3})}, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + t.Run("strings", func(t *testing.T) { + buf, err := json.Marshal(tt.strings) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + t.Logf("marshaled: %s", buf) + var s Set[string] + if err := json.Unmarshal(buf, &s); err != nil { + t.Fatalf("json.Unmarshal: %v", err) + } + if !s.Equal(tt.strings) { + t.Errorf("set changed after JSON marshal/unmarshal, before: %v, after: %v", tt.strings, s) + } + }) + t.Run("ints", func(t *testing.T) { + buf, err := json.Marshal(tt.ints) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + t.Logf("marshaled: %s", buf) + var s Set[int] + if err := json.Unmarshal(buf, &s); err != nil { + t.Fatalf("json.Unmarshal: %v", err) + } + if !s.Equal(tt.ints) { + t.Errorf("set changed after JSON marshal/unmarshal, before: %v, after: %v", tt.ints, s) + } + }) + }) + } +}