diff --git a/util/deephash/deephash.go b/util/deephash/deephash.go index 955f83e55..e9f9c9a1b 100644 --- a/util/deephash/deephash.go +++ b/util/deephash/deephash.go @@ -125,7 +125,7 @@ func Hash(v any) (s Sum) { seed = uint64(time.Now().UnixNano()) }) h.hashUint64(seed) - h.hashValue(reflect.ValueOf(v)) + h.hashValue(reflect.ValueOf(v), false) return h.sum() } @@ -164,26 +164,151 @@ func (h *hasher) hashUint64(i uint64) { var uint8Type = reflect.TypeOf(byte(0)) -func (h *hasher) hashValue(v reflect.Value) { +// typeInfo describes properties of a type. +type typeInfo struct { + rtype reflect.Type + isRecursive bool + + // elemTypeInfo is the element type's typeInfo. + // It's set when rtype is of Kind Ptr, Slice, Array, Map. + elemTypeInfo *typeInfo + + // keyTypeInfo is the map key type's typeInfo. + // It's set when rtype is of Kind Map. + keyTypeInfo *typeInfo +} + +var typeInfoMap sync.Map // map[reflect.Type]*typeInfo +var typeInfoMapPopulate sync.Mutex // just for adding to typeInfoMap + +func getTypeInfo(t reflect.Type) *typeInfo { + if f, ok := typeInfoMap.Load(t); ok { + return f.(*typeInfo) + } + typeInfoMapPopulate.Lock() + defer typeInfoMapPopulate.Unlock() + newTypes := map[reflect.Type]*typeInfo{} + ti := getTypeInfoLocked(t, newTypes) + for t, ti := range newTypes { + typeInfoMap.Store(t, ti) + } + return ti +} + +func getTypeInfoLocked(t reflect.Type, incomplete map[reflect.Type]*typeInfo) *typeInfo { + if v, ok := typeInfoMap.Load(t); ok { + return v.(*typeInfo) + } + if ti, ok := incomplete[t]; ok { + return ti + } + ti := &typeInfo{ + rtype: t, + isRecursive: typeIsRecursive(t), + } + incomplete[t] = ti + + switch t.Kind() { + case reflect.Map: + ti.keyTypeInfo = getTypeInfoLocked(t.Key(), incomplete) + fallthrough + case reflect.Ptr, reflect.Slice, reflect.Array: + ti.elemTypeInfo = getTypeInfoLocked(t.Elem(), incomplete) + } + + return ti +} + +// typeIsRecursive reports whether t has a path back to itself. +// +// For interfaces, it currently always reports true. +func typeIsRecursive(t reflect.Type) bool { + inStack := map[reflect.Type]bool{} + + var stack []reflect.Type + + var visitType func(t reflect.Type) (isRecursiveSoFar bool) + visitType = func(t reflect.Type) (isRecursiveSoFar bool) { + switch t.Kind() { + case reflect.Bool, + reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64, + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64, + reflect.Uintptr, + reflect.Float32, + reflect.Float64, + reflect.Complex64, + reflect.Complex128, + reflect.String, + reflect.UnsafePointer, + reflect.Func: + return false + } + if t.Size() == 0 { + return false + } + if inStack[t] { + return true + } + stack = append(stack, t) + inStack[t] = true + defer func() { + delete(inStack, t) + stack = stack[:len(stack)-1] + }() + + switch t.Kind() { + default: + panic("unhandled kind " + t.Kind().String()) + case reflect.Interface: + // Assume the worst for now. TODO(bradfitz): in some cases + // we should be able to prove that it's not recursive. Not worth + // it for now. + return true + case reflect.Array, reflect.Chan, reflect.Pointer, reflect.Slice: + return visitType(t.Elem()) + case reflect.Map: + if visitType(t.Key()) { + return true + } + if visitType(t.Elem()) { + return true + } + case reflect.Struct: + if t.String() == "intern.Value" { + // Otherwise its interface{} makes this return true. + return false + } + for i, numField := 0, t.NumField(); i < numField; i++ { + if visitType(t.Field(i).Type) { + return true + } + } + return false + } + return false + } + return visitType(t) +} + +func (h *hasher) hashValue(v reflect.Value, forceCycleChecking bool) { if !v.IsValid() { return } + ti := getTypeInfo(v.Type()) + h.hashValueWithType(v, ti, forceCycleChecking) +} +func (h *hasher) hashValueWithType(v reflect.Value, ti *typeInfo, forceCycleChecking bool) { w := h.bw - - if v.CanInterface() { - // Use AppendTo methods, if available and cheap. - if v.CanAddr() && v.Type().Implements(appenderToType) { - a := v.Addr().Interface().(appenderTo) - size := h.scratch[:8] - record := a.AppendTo(size) - binary.LittleEndian.PutUint64(record, uint64(len(record)-len(size))) - w.Write(record) - return - } - } - - // TODO(dsnet): Avoid cycle detection for types that cannot have cycles. + doCheckCycles := forceCycleChecking || ti.isRecursive // Generic handling. switch v.Kind() { @@ -195,21 +320,22 @@ func (h *hasher) hashValue(v reflect.Value) { return } - // Check for cycle. - ptr := pointerOf(v) - if idx, ok := h.visitStack.seen(ptr); ok { - h.hashUint8(2) // indicates cycle - h.hashUint64(uint64(idx)) - return + if doCheckCycles { + ptr := pointerOf(v) + if idx, ok := h.visitStack.seen(ptr); ok { + h.hashUint8(2) // indicates cycle + h.hashUint64(uint64(idx)) + return + } + h.visitStack.push(ptr) + defer h.visitStack.pop(ptr) } - h.visitStack.push(ptr) - defer h.visitStack.pop(ptr) h.hashUint8(1) // indicates visiting a pointer - h.hashValue(v.Elem()) + h.hashValueWithType(v.Elem(), ti.elemTypeInfo, doCheckCycles) case reflect.Struct: for i, n := 0, v.NumField(); i < n; i++ { - h.hashValue(v.Field(i)) + h.hashValue(v.Field(i), doCheckCycles) } case reflect.Slice, reflect.Array: vLen := v.Len() @@ -233,7 +359,7 @@ func (h *hasher) hashValue(v reflect.Value) { // TODO(dsnet): Perform cycle detection for slices, // which is functionally a list of pointers. // See https://github.com/google/go-cmp/blob/402949e8139bb890c71a707b6faf6dd05c92f4e5/cmp/compare.go#L438-L450 - h.hashValue(v.Index(i)) + h.hashValueWithType(v.Index(i), ti.elemTypeInfo, doCheckCycles) } case reflect.Interface: if v.IsNil() { @@ -244,20 +370,21 @@ func (h *hasher) hashValue(v reflect.Value) { h.hashUint8(1) // indicates visiting interface value h.hashType(v.Type()) - h.hashValue(v) + h.hashValue(v, doCheckCycles) case reflect.Map: // Check for cycle. - ptr := pointerOf(v) - if idx, ok := h.visitStack.seen(ptr); ok { - h.hashUint8(2) // indicates cycle - h.hashUint64(uint64(idx)) - return + if doCheckCycles { + ptr := pointerOf(v) + if idx, ok := h.visitStack.seen(ptr); ok { + h.hashUint8(2) // indicates cycle + h.hashUint64(uint64(idx)) + return + } + h.visitStack.push(ptr) + defer h.visitStack.pop(ptr) } - h.visitStack.push(ptr) - defer h.visitStack.pop(ptr) - h.hashUint8(1) // indicates visiting a map - h.hashMap(v) + h.hashMap(v, ti, doCheckCycles) case reflect.String: s := v.String() h.hashUint64(uint64(len(s))) @@ -325,7 +452,7 @@ func (c *valueCache) get(t reflect.Type) reflect.Value { // It relies on a map being a functionally an unordered set of KV entries. // So long as we hash each KV entry together, we can XOR all // of the individual hashes to produce a unique hash for the entire map. -func (h *hasher) hashMap(v reflect.Value) { +func (h *hasher) hashMap(v reflect.Value, ti *typeInfo, checkCycles bool) { mh := mapHasherPool.Get().(*mapHasher) defer mapHasherPool.Put(mh) @@ -341,8 +468,8 @@ func (h *hasher) hashMap(v reflect.Value) { k.SetIterKey(iter) e.SetIterValue(iter) mh.h.reset() - mh.h.hashValue(k) - mh.h.hashValue(e) + mh.h.hashValueWithType(k, ti.keyTypeInfo, checkCycles) + mh.h.hashValueWithType(e, ti.elemTypeInfo, checkCycles) sum.xor(mh.h.sum()) } h.bw.Write(append(h.scratch[:0], sum.sum[:]...)) // append into scratch to avoid heap allocation diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go index 3650beade..93b1b8314 100644 --- a/util/deephash/deephash_test.go +++ b/util/deephash/deephash_test.go @@ -14,6 +14,8 @@ import ( "reflect" "runtime" "testing" + "time" + "unsafe" "go4.org/mem" "inet.af/netaddr" @@ -21,6 +23,7 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/ipproto" "tailscale.com/types/key" + "tailscale.com/types/structs" "tailscale.com/util/dnsname" "tailscale.com/version" "tailscale.com/wgengine/filter" @@ -235,6 +238,41 @@ func getVal() []any { } } +func TestTypeIsRecursive(t *testing.T) { + type RecursiveStruct struct { + v *RecursiveStruct + } + type RecursiveChan chan *RecursiveChan + + tests := []struct { + val any + want bool + }{ + {val: 42, want: false}, + {val: "string", want: false}, + {val: 1 + 2i, want: false}, + {val: struct{}{}, want: false}, + {val: (*RecursiveStruct)(nil), want: true}, + {val: RecursiveStruct{}, want: true}, + {val: time.Unix(0, 0), want: false}, + {val: structs.Incomparable{}, want: false}, // ignore its [0]func() + {val: tailcfg.NetPortRange{}, want: false}, // uses structs.Incomparable + {val: (*tailcfg.Node)(nil), want: false}, + {val: map[string]bool{}, want: false}, + {val: func() {}, want: false}, + {val: make(chan int), want: false}, + {val: unsafe.Pointer(nil), want: false}, + {val: make(RecursiveChan), want: true}, + {val: make(chan int), want: false}, + } + for _, tt := range tests { + got := typeIsRecursive(reflect.TypeOf(tt.val)) + if got != tt.want { + t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want) + } + } +} + var sink = Hash("foo") func BenchmarkHash(b *testing.B) { @@ -255,12 +293,14 @@ func TestHashMapAcyclic(t *testing.T) { var buf bytes.Buffer bw := bufio.NewWriter(&buf) + ti := getTypeInfo(reflect.TypeOf(m)) + for i := 0; i < 20; i++ { v := reflect.ValueOf(m) buf.Reset() bw.Reset(&buf) h := &hasher{bw: bw} - h.hashMap(v) + h.hashMap(v, ti, false) if got[string(buf.Bytes())] { continue } @@ -279,7 +319,7 @@ func TestPrintArray(t *testing.T) { var got bytes.Buffer bw := bufio.NewWriter(&got) h := &hasher{bw: bw} - h.hashValue(reflect.ValueOf(x)) + h.hashValue(reflect.ValueOf(x), false) bw.Flush() const want = "\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f" if got := got.Bytes(); string(got) != want { @@ -297,13 +337,14 @@ func BenchmarkHashMapAcyclic(b *testing.B) { var buf bytes.Buffer bw := bufio.NewWriter(&buf) v := reflect.ValueOf(m) + ti := getTypeInfo(v.Type()) h := &hasher{bw: bw} for i := 0; i < b.N; i++ { buf.Reset() bw.Reset(&buf) - h.hashMap(v) + h.hashMap(v, ti, false) } }