From 395cb588b62bf4d2004fb7218c70174554eb4fd3 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Mon, 9 May 2022 19:31:45 -0700 Subject: [PATCH] types/views: make SliceOf/MapOf panic if they see a pointer Signed-off-by: Maisem Ali --- types/views/views.go | 80 ++++++++++++++++++++++++++++++++++----- types/views/views_test.go | 33 ++++++++++++++++ 2 files changed, 104 insertions(+), 9 deletions(-) diff --git a/types/views/views.go b/types/views/views.go index 466462460..1f0f8b2aa 100644 --- a/types/views/views.go +++ b/types/views/views.go @@ -9,6 +9,8 @@ package views import ( "encoding/json" "errors" + "fmt" + "reflect" "inet.af/netaddr" "tailscale.com/net/tsaddr" @@ -97,8 +99,14 @@ type Slice[T any] struct { ж []T } -// SliceOf returns a Slice for the provided slice. -func SliceOf[T any](x []T) Slice[T] { return Slice[T]{x} } +// SliceOf returns a Slice for the provided slice for immutable values. +// It panics if the value type contains pointers. +func SliceOf[T any](x []T) Slice[T] { + if ev := reflect.TypeOf(x).Elem(); containsMutable(ev) { + panic(fmt.Sprintf("slice value type %q has pointers", ev.Name())) + } + return Slice[T]{x} +} // MarshalJSON implements json.Marshaler. func (v Slice[T]) MarshalJSON() ([]byte, error) { @@ -186,8 +194,52 @@ func (v *IPPrefixSlice) UnmarshalJSON(b []byte) error { return v.ж.UnmarshalJSON(b) } -// MapOf returns a read-only view over m. +// containsMutable reports whether the provided type has anything mutable. +func containsMutable(t reflect.Type) bool { + switch x := fmt.Sprintf("%v.%v", t.PkgPath(), t.Name()); x { + case "time.Time", + "inet.af/netaddr.IP": + return false + } + k := t.Kind() + switch k { + case reflect.Bool, + reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64, + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Float32, + reflect.Float64, + reflect.Complex64, + reflect.Complex128, + reflect.String: + return false + case reflect.Array: // Not a slice. + return containsMutable(t.Elem()) && t.Len() > 0 + case reflect.Struct: + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if containsMutable(f.Type) { + return true + } + } + return false + } + return true +} + +// MapOf returns a read-only view over m for immutable values. +// It panics if the value type contains pointers. func MapOf[K comparable, V comparable](m map[K]V) Map[K, V] { + if ev := reflect.TypeOf(m).Elem(); containsMutable(ev) { + panic(fmt.Sprintf("map value type %q has pointers", ev.Name())) + } return Map[K, V]{m} } @@ -226,10 +278,17 @@ func (m Map[K, V]) GetOk(k K) (V, bool) { return v, ok } -// ForEach calls f for every k,v pair in the underlying map. -func (m Map[K, V]) ForEach(f func(k K, v V)) { +// MapRangeFn is the func called from a Map.Range call. +// Implementations should return false to stop range. +type MapRangeFn[K comparable, V any] func(k K, v V) (cont bool) + +// Range calls f for every k,v pair in the underlying map. +// It stops iteration immediately if f returns false. +func (m Map[K, V]) Range(f MapRangeFn[K, V]) { for k, v := range m.ж { - f(k, v) + if !f(k, v) { + return + } } } @@ -278,9 +337,12 @@ func (m MapFn[K, T, V]) GetOk(k K) (V, bool) { return m.wrapv(v), ok } -// ForEach calls f for every k,v pair in the underlying map. -func (m MapFn[K, T, V]) ForEach(f func(k K, v V)) { +// Range calls f for every k,v pair in the underlying map. +// It stops iteration immediately if f returns false. +func (m MapFn[K, T, V]) Range(f MapRangeFn[K, V]) { for k, v := range m.ж { - f(k, m.wrapv(v)) + if !f(k, m.wrapv(v)) { + return + } } } diff --git a/types/views/views_test.go b/types/views/views_test.go index d9785c9a1..a064e9c42 100644 --- a/types/views/views_test.go +++ b/types/views/views_test.go @@ -10,10 +10,43 @@ import ( "reflect" "strings" "testing" + "time" + "go4.org/mem" "inet.af/netaddr" + "tailscale.com/types/structs" ) +func TestContainsPointers(t *testing.T) { + tests := []struct { + name string + in any + want bool + }{ + {name: "string", in: "foo", want: false}, + {name: "int", in: 42, want: false}, + {name: "struct", in: struct{ string }{"foo"}, want: false}, + {name: "mem.RO", in: mem.B([]byte{1}), want: false}, + {name: "time.Time", in: time.Now(), want: false}, + {name: "netaddr.IP", in: netaddr.MustParseIP("1.1.1.1"), want: false}, + {name: "netaddr.IPPrefix", in: netaddr.MustParseIP("1.1.1.1"), want: false}, + {name: "structs.Incomparable", in: structs.Incomparable{}, want: false}, + + {name: "*int", in: (*int)(nil), want: true}, + {name: "*string", in: (*string)(nil), want: true}, + {name: "struct-with-pointer", in: struct{ X *string }{}, want: true}, + {name: "slice-with-pointer", in: []struct{ X *string }{}, want: true}, + {name: "slice-of-struct", in: []struct{ string }{}, want: true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if containsMutable(reflect.TypeOf(tc.in)) != tc.want { + t.Errorf("containsPointers %T; want %v", tc.in, tc.want) + } + }) + } +} + func TestViewsJSON(t *testing.T) { mustCIDR := func(cidrs ...string) (out []netaddr.IPPrefix) { for _, cidr := range cidrs {