From b7104cde4af5bffa9bfae0b719813e85f72c9601 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Fri, 2 Feb 2024 23:31:17 -0500 Subject: [PATCH] util/topk: add package containing a probabilistic top-K tracker This package uses a count-min sketch and a heap to track the top K items in a stream of data. Tracking a new item and adding a count to an existing item both require no memory allocations and is at worst O(log(k)) complexity. Change-Id: I0553381be3fef2470897e2bd806d43396f2dbb36 Signed-off-by: Andrew Dunham --- util/topk/topk.go | 261 +++++++++++++++++++++++++++++++++++++++++ util/topk/topk_test.go | 135 +++++++++++++++++++++ 2 files changed, 396 insertions(+) create mode 100644 util/topk/topk.go create mode 100644 util/topk/topk_test.go diff --git a/util/topk/topk.go b/util/topk/topk.go new file mode 100644 index 000000000..aef779102 --- /dev/null +++ b/util/topk/topk.go @@ -0,0 +1,261 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package topk defines a count-min sketch and a cheap probabilistic top-K data +// structure that uses the count-min sketch to track the top K items in +// constant memory and O(log(k)) time. +package topk + +import ( + "container/heap" + "hash/maphash" + "math" + "slices" + "sync" +) + +// TopK is a probabilistic counter of the top K items, using a count-min sketch +// to keep track of item counts and a heap to track the top K of them. +type TopK[T any] struct { + heap minHeap[T] + k int + sf SerializeFunc[T] + cms CountMinSketch +} + +// HashFunc is responsible for providing a []byte serialization of a value, +// appended to the provided byte slice. This is used for hashing the value when +// adding to a CountMinSketch. +type SerializeFunc[T any] func([]byte, T) []byte + +// New creates a new TopK that stores k values. Parameters for the underlying +// count-min sketch are chosen for a 0.1% error rate and a 0.1% probability of +// error. +func New[T any](k int, sf SerializeFunc[T]) *TopK[T] { + hashes, buckets := PickParams(0.001, 0.001) + return NewWithParams(k, sf, hashes, buckets) +} + +// NewWithParams creates a new TopK that stores k values, and additionally +// allows customizing the parameters for the underlying count-min sketch. +func NewWithParams[T any](k int, sf SerializeFunc[T], numHashes, numCols int) *TopK[T] { + ret := &TopK[T]{ + heap: make(minHeap[T], 0, k), + k: k, + sf: sf, + } + ret.cms.init(numHashes, numCols) + return ret +} + +// Add calls AddN(val, 1). +func (tk *TopK[T]) Add(val T) uint64 { + return tk.AddN(val, 1) +} + +var hashPool = &sync.Pool{ + New: func() any { + buf := make([]byte, 0, 128) + return &buf + }, +} + +// AddN adds the given item to the set with the provided count, returning the +// new estimated count. +func (tk *TopK[T]) AddN(val T, count uint64) uint64 { + buf := hashPool.Get().(*[]byte) + defer hashPool.Put(buf) + ser := tk.sf((*buf)[:0], val) + + vcount := tk.cms.AddN(ser, count) + + // If we don't have a full heap, just push it. + if len(tk.heap) < tk.k { + heap.Push(&tk.heap, mhValue[T]{ + count: vcount, + val: val, + }) + return vcount + } + + // If this item's count surpasses the heap's minimum, update the heap. + if vcount > tk.heap[0].count { + tk.heap[0] = mhValue[T]{ + count: vcount, + val: val, + } + heap.Fix(&tk.heap, 0) + } + return vcount +} + +// Top returns the estimated top K items as stored by this TopK. +func (tk *TopK[T]) Top() []T { + ret := make([]T, 0, tk.k) + for _, item := range tk.heap { + ret = append(ret, item.val) + } + return ret +} + +// AppendTop appends the estimated top K items as stored by this TopK to the +// provided slice, allocating only if the slice does not have enough capacity +// to store all items. The provided slice can be nil. +func (tk *TopK[T]) AppendTop(sl []T) []T { + sl = slices.Grow(sl, tk.k) + for _, item := range tk.heap { + sl = append(sl, item.val) + } + return sl +} + +// CountMinSketch implements a count-min sketch, a probabilistic data structure +// that tracks the frequency of events in a stream of data. +// +// See: https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch +type CountMinSketch struct { + hashes []maphash.Seed + nbuckets int + matrix []uint64 +} + +// NewCountMinSketch creates a new CountMinSketch with the provided number of +// hashes and buckets. Hashes and buckets are often called "depth" and "width", +// or "d" and "w", respectively. +func NewCountMinSketch(hashes, buckets int) *CountMinSketch { + ret := &CountMinSketch{} + ret.init(hashes, buckets) + return ret +} + +// PickParams provides good parameters for 'hashes' and 'buckets' when +// constructing a CountMinSketch, given an estimated total number of counts +// (i.e. the sum of all counts ever stored), the error factor ϵ as a float +// (e.g. 1% is 0.001), and the probability factor δ. +// +// Parameters are chosen such that with a probability of 1−δ, the error is at +// most ϵ∗totalCount. Or, in other words: if N is the true count of an event, +// E is the estimate given by a sketch and T the total count of items in the +// sketch, E ≤ N + T*ϵ with probability (1 - δ). +func PickParams(err, probability float64) (hashes, buckets int) { + d := math.Ceil(math.Log(1 / probability)) + w := math.Ceil(math.E / err) + + return int(d), int(w) +} + +func (cms *CountMinSketch) init(hashes, buckets int) { + for i := 0; i < hashes; i++ { + cms.hashes = append(cms.hashes, maphash.MakeSeed()) + } + + // Need a matrix of hashes * buckets to store counts + cms.nbuckets = buckets + cms.matrix = make([]uint64, hashes*buckets) +} + +// Add calls AddN(val, 1). +func (cms *CountMinSketch) Add(val []byte) uint64 { + return cms.AddN(val, 1) +} + +// AddN increments the count for the given value by the provided count, +// returning the new count. +func (cms *CountMinSketch) AddN(val []byte, count uint64) uint64 { + var ( + mh maphash.Hash + ret uint64 = math.MaxUint64 + ) + for i, seed := range cms.hashes { + mh.SetSeed(seed) + + // Generate a hash for this value using Lemire's alternative to modular reduction: + // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ + mh.Write(val) + hash := mh.Sum64() + hash = multiplyHigh64(hash, uint64(cms.nbuckets)) + + // The index in our matrix is (i * buckets) to move "down" i + // rows in our matrix to the row for this hash, plus 'hash' to + // move inside this row. + idx := (i * cms.nbuckets) + int(hash) + + // Add to this row + cms.matrix[idx] += count + ret = min(ret, cms.matrix[idx]) + } + return ret +} + +// Get returns the count for the provided value. +func (cms *CountMinSketch) Get(val []byte) uint64 { + var ( + mh maphash.Hash + ret uint64 = math.MaxUint64 + ) + for i, seed := range cms.hashes { + mh.SetSeed(seed) + + // Generate a hash for this value using Lemire's alternative to modular reduction: + // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ + mh.Write(val) + hash := mh.Sum64() + hash = multiplyHigh64(hash, uint64(cms.nbuckets)) + + // The index in our matrix is (i * buckets) to move "down" i + // rows in our matrix to the row for this hash, plus 'hash' to + // move inside this row. + idx := (i * cms.nbuckets) + int(hash) + + // Select the minimal value among all rows + ret = min(ret, cms.matrix[idx]) + } + return ret +} + +// multiplyHigh64 implements (x * y) >> 64 "the long way" without access to a +// 128-bit type. This function is adapted from something similar in Tensorflow: +// +// https://github.com/tensorflow/tensorflow/commit/a47a300185026fe7829990def9113bf3a5109fed +// +// TODO(andrew-d): this could be replaced with a single "MULX" instruction on +// x86_64 platforms, which we can do if this ever turns out to be a performance +// bottleneck. +func multiplyHigh64(x, y uint64) uint64 { + x_lo := x & 0xffffffff + x_hi := x >> 32 + buckets_lo := y & 0xffffffff + buckets_hi := y >> 32 + prod_hi := x_hi * buckets_hi + prod_lo := x_lo * buckets_lo + prod_mid1 := x_hi * buckets_lo + prod_mid2 := x_lo * buckets_hi + carry := ((prod_mid1 & 0xffffffff) + (prod_mid2 & 0xffffffff) + (prod_lo >> 32)) >> 32 + return prod_hi + (prod_mid1 >> 32) + (prod_mid2 >> 32) + carry +} + +type mhValue[T any] struct { + count uint64 + val T +} + +// An minHeap is a min-heap of ints and associated values. +type minHeap[T any] []mhValue[T] + +func (h minHeap[T]) Len() int { return len(h) } +func (h minHeap[T]) Less(i, j int) bool { return h[i].count < h[j].count } +func (h minHeap[T]) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *minHeap[T]) Push(x any) { + // Push and Pop use pointer receivers because they modify the slice's length, + // not just its contents. + *h = append(*h, x.(mhValue[T])) +} + +func (h *minHeap[T]) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/util/topk/topk_test.go b/util/topk/topk_test.go new file mode 100644 index 000000000..499948b5f --- /dev/null +++ b/util/topk/topk_test.go @@ -0,0 +1,135 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package topk + +import ( + "encoding/binary" + "fmt" + "slices" + "testing" +) + +func TestCountMinSketch(t *testing.T) { + cms := NewCountMinSketch(4, 10) + items := []string{"foo", "bar", "baz", "asdf", "quux"} + for _, item := range items { + cms.Add([]byte(item)) + } + for _, item := range items { + count := cms.Get([]byte(item)) + if count < 1 { + t.Errorf("item %q should have count >= 1", item) + } else if count > 1 { + t.Logf("item %q has count > 1: %d", item, count) + } + } + + // Test that an item that's *not* in the set has a value lower than the + // total number of items we inserted (in the case that all items + // collided). + noItemCount := cms.Get([]byte("doesn't exist")) + if noItemCount > uint64(len(items)) { + t.Errorf("expected nonexistent item to have value < %d; got %d", len(items), noItemCount) + } +} + +func TestTopK(t *testing.T) { + // This is probabilistic, so we're going to try 10 times to get the + // "right" value; the likelihood that we fail on all attempts is + // vanishingly small since the number of hash buckets is drastically + // larger than the number of items we're inserting. + var ( + got []int + want = []int{5, 6, 7, 8, 9} + ) + for try := 0; try < 10; try++ { + topk := NewWithParams[int](5, func(in []byte, val int) []byte { + return binary.LittleEndian.AppendUint64(in, uint64(val)) + }, 4, 1000) + + // Add the first 10 integers with counts equal to 2x their value + for i := 0; i < 10; i++ { + topk.AddN(i, uint64(i*2)) + } + + got = topk.Top() + t.Logf("top K items: %+v", got) + slices.Sort(got) + + if slices.Equal(got, want) { + // All good! + return + } + + // continue and retry or fail + } + + t.Errorf("top K mismatch\ngot: %v\nwant: %v", got, want) +} + +func TestPickParams(t *testing.T) { + hashes, buckets := PickParams( + 0.001, // 0.1% error rate + 0.001, // 0.1% chance of having an error, or 99.9% chance of not having an error + ) + t.Logf("hashes = %d, buckets = %d", hashes, buckets) +} + +func BenchmarkCountMinSketch(b *testing.B) { + cms := NewCountMinSketch(PickParams(0.001, 0.001)) + b.ResetTimer() + b.ReportAllocs() + + var enc [8]byte + for i := 0; i < b.N; i++ { + binary.LittleEndian.PutUint64(enc[:], uint64(i)) + cms.Add(enc[:]) + } +} + +func BenchmarkTopK(b *testing.B) { + for _, n := range []int{ + 10, + 128, + 256, + 1024, + 8192, + } { + b.Run(fmt.Sprintf("Top%d", n), func(b *testing.B) { + out := make([]int, 0, n) + topk := New[int](n, func(in []byte, val int) []byte { + return binary.LittleEndian.AppendUint64(in, uint64(val)) + }) + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + topk.Add(i) + } + out = topk.AppendTop(out[:0]) // should not allocate + _ = out // appease linter + }) + } +} + +func TestMultiplyHigh64(t *testing.T) { + testCases := []struct { + x, y uint64 + want uint64 + }{ + {0, 0, 0}, + {0xffffffff, 0xffffffff, 0}, + {0x2, 0xf000000000000000, 1}, + {0x3, 0xf000000000000000, 2}, + {0x3, 0xf000000000000001, 2}, + {0x3, 0xffffffffffffffff, 2}, + {0xffffffffffffffff, 0xffffffffffffffff, 0xfffffffffffffffe}, + } + for _, tc := range testCases { + got := multiplyHigh64(tc.x, tc.y) + if got != tc.want { + t.Errorf("got multiplyHigh64(%x, %x) = %x, want %x", tc.x, tc.y, got, tc.want) + } + } +}