util/deephash: don't track cycles on non-recursive types

name              old time/op    new time/op    delta
Hash-8              67.3µs ±20%    76.5µs ±16%     ~     (p=0.143 n=10+10)
HashMapAcyclic-8    63.0µs ± 2%    56.3µs ± 1%  -10.65%  (p=0.000 n=10+8)
TailcfgNode-8       9.18µs ± 2%    6.52µs ± 3%  -28.96%  (p=0.000 n=9+10)
HashArray-8          732ns ± 3%     709ns ± 1%   -3.21%  (p=0.000 n=10+10)

name              old alloc/op   new alloc/op   delta
Hash-8               24.0B ± 0%     24.0B ± 0%     ~     (all equal)
HashMapAcyclic-8     0.00B          0.00B          ~     (all equal)
TailcfgNode-8        0.00B          0.00B          ~     (all equal)
HashArray-8          0.00B          0.00B          ~     (all equal)

name              old allocs/op  new allocs/op  delta
Hash-8                1.00 ± 0%      1.00 ± 0%     ~     (all equal)
HashMapAcyclic-8      0.00           0.00          ~     (all equal)
TailcfgNode-8         0.00           0.00          ~     (all equal)
HashArray-8           0.00           0.00          ~     (all equal)

Change-Id: I28642050d837dff66b2db54b2b0e6d272a930be8
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-06-14 22:49:11 -07:00 committed by Brad Fitzpatrick
parent 36ea837736
commit f31588786f
2 changed files with 211 additions and 43 deletions

View File

@ -125,7 +125,7 @@ func Hash(v any) (s Sum) {
seed = uint64(time.Now().UnixNano())
})
h.hashUint64(seed)
h.hashValue(reflect.ValueOf(v))
h.hashValue(reflect.ValueOf(v), false)
return h.sum()
}
@ -164,26 +164,151 @@ func (h *hasher) hashUint64(i uint64) {
var uint8Type = reflect.TypeOf(byte(0))
func (h *hasher) hashValue(v reflect.Value) {
// typeInfo describes properties of a type.
type typeInfo struct {
rtype reflect.Type
isRecursive bool
// elemTypeInfo is the element type's typeInfo.
// It's set when rtype is of Kind Ptr, Slice, Array, Map.
elemTypeInfo *typeInfo
// keyTypeInfo is the map key type's typeInfo.
// It's set when rtype is of Kind Map.
keyTypeInfo *typeInfo
}
var typeInfoMap sync.Map // map[reflect.Type]*typeInfo
var typeInfoMapPopulate sync.Mutex // just for adding to typeInfoMap
func getTypeInfo(t reflect.Type) *typeInfo {
if f, ok := typeInfoMap.Load(t); ok {
return f.(*typeInfo)
}
typeInfoMapPopulate.Lock()
defer typeInfoMapPopulate.Unlock()
newTypes := map[reflect.Type]*typeInfo{}
ti := getTypeInfoLocked(t, newTypes)
for t, ti := range newTypes {
typeInfoMap.Store(t, ti)
}
return ti
}
func getTypeInfoLocked(t reflect.Type, incomplete map[reflect.Type]*typeInfo) *typeInfo {
if v, ok := typeInfoMap.Load(t); ok {
return v.(*typeInfo)
}
if ti, ok := incomplete[t]; ok {
return ti
}
ti := &typeInfo{
rtype: t,
isRecursive: typeIsRecursive(t),
}
incomplete[t] = ti
switch t.Kind() {
case reflect.Map:
ti.keyTypeInfo = getTypeInfoLocked(t.Key(), incomplete)
fallthrough
case reflect.Ptr, reflect.Slice, reflect.Array:
ti.elemTypeInfo = getTypeInfoLocked(t.Elem(), incomplete)
}
return ti
}
// typeIsRecursive reports whether t has a path back to itself.
//
// For interfaces, it currently always reports true.
func typeIsRecursive(t reflect.Type) bool {
inStack := map[reflect.Type]bool{}
var stack []reflect.Type
var visitType func(t reflect.Type) (isRecursiveSoFar bool)
visitType = func(t reflect.Type) (isRecursiveSoFar bool) {
switch t.Kind() {
case reflect.Bool,
reflect.Int,
reflect.Int8,
reflect.Int16,
reflect.Int32,
reflect.Int64,
reflect.Uint,
reflect.Uint8,
reflect.Uint16,
reflect.Uint32,
reflect.Uint64,
reflect.Uintptr,
reflect.Float32,
reflect.Float64,
reflect.Complex64,
reflect.Complex128,
reflect.String,
reflect.UnsafePointer,
reflect.Func:
return false
}
if t.Size() == 0 {
return false
}
if inStack[t] {
return true
}
stack = append(stack, t)
inStack[t] = true
defer func() {
delete(inStack, t)
stack = stack[:len(stack)-1]
}()
switch t.Kind() {
default:
panic("unhandled kind " + t.Kind().String())
case reflect.Interface:
// Assume the worst for now. TODO(bradfitz): in some cases
// we should be able to prove that it's not recursive. Not worth
// it for now.
return true
case reflect.Array, reflect.Chan, reflect.Pointer, reflect.Slice:
return visitType(t.Elem())
case reflect.Map:
if visitType(t.Key()) {
return true
}
if visitType(t.Elem()) {
return true
}
case reflect.Struct:
if t.String() == "intern.Value" {
// Otherwise its interface{} makes this return true.
return false
}
for i, numField := 0, t.NumField(); i < numField; i++ {
if visitType(t.Field(i).Type) {
return true
}
}
return false
}
return false
}
return visitType(t)
}
func (h *hasher) hashValue(v reflect.Value, forceCycleChecking bool) {
if !v.IsValid() {
return
}
ti := getTypeInfo(v.Type())
h.hashValueWithType(v, ti, forceCycleChecking)
}
func (h *hasher) hashValueWithType(v reflect.Value, ti *typeInfo, forceCycleChecking bool) {
w := h.bw
if v.CanInterface() {
// Use AppendTo methods, if available and cheap.
if v.CanAddr() && v.Type().Implements(appenderToType) {
a := v.Addr().Interface().(appenderTo)
size := h.scratch[:8]
record := a.AppendTo(size)
binary.LittleEndian.PutUint64(record, uint64(len(record)-len(size)))
w.Write(record)
return
}
}
// TODO(dsnet): Avoid cycle detection for types that cannot have cycles.
doCheckCycles := forceCycleChecking || ti.isRecursive
// Generic handling.
switch v.Kind() {
@ -195,21 +320,22 @@ func (h *hasher) hashValue(v reflect.Value) {
return
}
// Check for cycle.
ptr := pointerOf(v)
if idx, ok := h.visitStack.seen(ptr); ok {
h.hashUint8(2) // indicates cycle
h.hashUint64(uint64(idx))
return
if doCheckCycles {
ptr := pointerOf(v)
if idx, ok := h.visitStack.seen(ptr); ok {
h.hashUint8(2) // indicates cycle
h.hashUint64(uint64(idx))
return
}
h.visitStack.push(ptr)
defer h.visitStack.pop(ptr)
}
h.visitStack.push(ptr)
defer h.visitStack.pop(ptr)
h.hashUint8(1) // indicates visiting a pointer
h.hashValue(v.Elem())
h.hashValueWithType(v.Elem(), ti.elemTypeInfo, doCheckCycles)
case reflect.Struct:
for i, n := 0, v.NumField(); i < n; i++ {
h.hashValue(v.Field(i))
h.hashValue(v.Field(i), doCheckCycles)
}
case reflect.Slice, reflect.Array:
vLen := v.Len()
@ -233,7 +359,7 @@ func (h *hasher) hashValue(v reflect.Value) {
// TODO(dsnet): Perform cycle detection for slices,
// which is functionally a list of pointers.
// See https://github.com/google/go-cmp/blob/402949e8139bb890c71a707b6faf6dd05c92f4e5/cmp/compare.go#L438-L450
h.hashValue(v.Index(i))
h.hashValueWithType(v.Index(i), ti.elemTypeInfo, doCheckCycles)
}
case reflect.Interface:
if v.IsNil() {
@ -244,20 +370,21 @@ func (h *hasher) hashValue(v reflect.Value) {
h.hashUint8(1) // indicates visiting interface value
h.hashType(v.Type())
h.hashValue(v)
h.hashValue(v, doCheckCycles)
case reflect.Map:
// Check for cycle.
ptr := pointerOf(v)
if idx, ok := h.visitStack.seen(ptr); ok {
h.hashUint8(2) // indicates cycle
h.hashUint64(uint64(idx))
return
if doCheckCycles {
ptr := pointerOf(v)
if idx, ok := h.visitStack.seen(ptr); ok {
h.hashUint8(2) // indicates cycle
h.hashUint64(uint64(idx))
return
}
h.visitStack.push(ptr)
defer h.visitStack.pop(ptr)
}
h.visitStack.push(ptr)
defer h.visitStack.pop(ptr)
h.hashUint8(1) // indicates visiting a map
h.hashMap(v)
h.hashMap(v, ti, doCheckCycles)
case reflect.String:
s := v.String()
h.hashUint64(uint64(len(s)))
@ -325,7 +452,7 @@ func (c *valueCache) get(t reflect.Type) reflect.Value {
// It relies on a map being a functionally an unordered set of KV entries.
// So long as we hash each KV entry together, we can XOR all
// of the individual hashes to produce a unique hash for the entire map.
func (h *hasher) hashMap(v reflect.Value) {
func (h *hasher) hashMap(v reflect.Value, ti *typeInfo, checkCycles bool) {
mh := mapHasherPool.Get().(*mapHasher)
defer mapHasherPool.Put(mh)
@ -341,8 +468,8 @@ func (h *hasher) hashMap(v reflect.Value) {
k.SetIterKey(iter)
e.SetIterValue(iter)
mh.h.reset()
mh.h.hashValue(k)
mh.h.hashValue(e)
mh.h.hashValueWithType(k, ti.keyTypeInfo, checkCycles)
mh.h.hashValueWithType(e, ti.elemTypeInfo, checkCycles)
sum.xor(mh.h.sum())
}
h.bw.Write(append(h.scratch[:0], sum.sum[:]...)) // append into scratch to avoid heap allocation

View File

@ -14,6 +14,8 @@ import (
"reflect"
"runtime"
"testing"
"time"
"unsafe"
"go4.org/mem"
"inet.af/netaddr"
@ -21,6 +23,7 @@ import (
"tailscale.com/types/dnstype"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
"tailscale.com/types/structs"
"tailscale.com/util/dnsname"
"tailscale.com/version"
"tailscale.com/wgengine/filter"
@ -235,6 +238,41 @@ func getVal() []any {
}
}
func TestTypeIsRecursive(t *testing.T) {
type RecursiveStruct struct {
v *RecursiveStruct
}
type RecursiveChan chan *RecursiveChan
tests := []struct {
val any
want bool
}{
{val: 42, want: false},
{val: "string", want: false},
{val: 1 + 2i, want: false},
{val: struct{}{}, want: false},
{val: (*RecursiveStruct)(nil), want: true},
{val: RecursiveStruct{}, want: true},
{val: time.Unix(0, 0), want: false},
{val: structs.Incomparable{}, want: false}, // ignore its [0]func()
{val: tailcfg.NetPortRange{}, want: false}, // uses structs.Incomparable
{val: (*tailcfg.Node)(nil), want: false},
{val: map[string]bool{}, want: false},
{val: func() {}, want: false},
{val: make(chan int), want: false},
{val: unsafe.Pointer(nil), want: false},
{val: make(RecursiveChan), want: true},
{val: make(chan int), want: false},
}
for _, tt := range tests {
got := typeIsRecursive(reflect.TypeOf(tt.val))
if got != tt.want {
t.Errorf("for type %T: got %v, want %v", tt.val, got, tt.want)
}
}
}
var sink = Hash("foo")
func BenchmarkHash(b *testing.B) {
@ -255,12 +293,14 @@ func TestHashMapAcyclic(t *testing.T) {
var buf bytes.Buffer
bw := bufio.NewWriter(&buf)
ti := getTypeInfo(reflect.TypeOf(m))
for i := 0; i < 20; i++ {
v := reflect.ValueOf(m)
buf.Reset()
bw.Reset(&buf)
h := &hasher{bw: bw}
h.hashMap(v)
h.hashMap(v, ti, false)
if got[string(buf.Bytes())] {
continue
}
@ -279,7 +319,7 @@ func TestPrintArray(t *testing.T) {
var got bytes.Buffer
bw := bufio.NewWriter(&got)
h := &hasher{bw: bw}
h.hashValue(reflect.ValueOf(x))
h.hashValue(reflect.ValueOf(x), false)
bw.Flush()
const want = "\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f"
if got := got.Bytes(); string(got) != want {
@ -297,13 +337,14 @@ func BenchmarkHashMapAcyclic(b *testing.B) {
var buf bytes.Buffer
bw := bufio.NewWriter(&buf)
v := reflect.ValueOf(m)
ti := getTypeInfo(v.Type())
h := &hasher{bw: bw}
for i := 0; i < b.N; i++ {
buf.Reset()
bw.Reset(&buf)
h.hashMap(v)
h.hashMap(v, ti, false)
}
}