2024-02-03 04:31:17 +00:00
|
|
|
// Copyright (c) Tailscale Inc & AUTHORS
|
|
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
|
|
|
|
package topk
|
|
|
|
|
|
|
|
import (
|
|
|
|
"encoding/binary"
|
|
|
|
"fmt"
|
|
|
|
"slices"
|
|
|
|
"testing"
|
|
|
|
)
|
|
|
|
|
|
|
|
func TestCountMinSketch(t *testing.T) {
|
|
|
|
cms := NewCountMinSketch(4, 10)
|
|
|
|
items := []string{"foo", "bar", "baz", "asdf", "quux"}
|
|
|
|
for _, item := range items {
|
|
|
|
cms.Add([]byte(item))
|
|
|
|
}
|
|
|
|
for _, item := range items {
|
|
|
|
count := cms.Get([]byte(item))
|
|
|
|
if count < 1 {
|
|
|
|
t.Errorf("item %q should have count >= 1", item)
|
|
|
|
} else if count > 1 {
|
|
|
|
t.Logf("item %q has count > 1: %d", item, count)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test that an item that's *not* in the set has a value lower than the
|
|
|
|
// total number of items we inserted (in the case that all items
|
|
|
|
// collided).
|
|
|
|
noItemCount := cms.Get([]byte("doesn't exist"))
|
|
|
|
if noItemCount > uint64(len(items)) {
|
|
|
|
t.Errorf("expected nonexistent item to have value < %d; got %d", len(items), noItemCount)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestTopK(t *testing.T) {
|
|
|
|
// This is probabilistic, so we're going to try 10 times to get the
|
|
|
|
// "right" value; the likelihood that we fail on all attempts is
|
|
|
|
// vanishingly small since the number of hash buckets is drastically
|
|
|
|
// larger than the number of items we're inserting.
|
|
|
|
var (
|
|
|
|
got []int
|
|
|
|
want = []int{5, 6, 7, 8, 9}
|
|
|
|
)
|
|
|
|
for try := 0; try < 10; try++ {
|
|
|
|
topk := NewWithParams[int](5, func(in []byte, val int) []byte {
|
|
|
|
return binary.LittleEndian.AppendUint64(in, uint64(val))
|
|
|
|
}, 4, 1000)
|
|
|
|
|
|
|
|
// Add the first 10 integers with counts equal to 2x their value
|
2024-04-16 21:15:13 +01:00
|
|
|
for i := range 10 {
|
2024-02-03 04:31:17 +00:00
|
|
|
topk.AddN(i, uint64(i*2))
|
|
|
|
}
|
|
|
|
|
|
|
|
got = topk.Top()
|
|
|
|
t.Logf("top K items: %+v", got)
|
|
|
|
slices.Sort(got)
|
|
|
|
|
|
|
|
if slices.Equal(got, want) {
|
|
|
|
// All good!
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// continue and retry or fail
|
|
|
|
}
|
|
|
|
|
|
|
|
t.Errorf("top K mismatch\ngot: %v\nwant: %v", got, want)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestPickParams(t *testing.T) {
|
|
|
|
hashes, buckets := PickParams(
|
|
|
|
0.001, // 0.1% error rate
|
|
|
|
0.001, // 0.1% chance of having an error, or 99.9% chance of not having an error
|
|
|
|
)
|
|
|
|
t.Logf("hashes = %d, buckets = %d", hashes, buckets)
|
|
|
|
}
|
|
|
|
|
|
|
|
func BenchmarkCountMinSketch(b *testing.B) {
|
|
|
|
cms := NewCountMinSketch(PickParams(0.001, 0.001))
|
|
|
|
b.ResetTimer()
|
|
|
|
b.ReportAllocs()
|
|
|
|
|
|
|
|
var enc [8]byte
|
2024-04-16 21:15:13 +01:00
|
|
|
for i := range b.N {
|
2024-02-03 04:31:17 +00:00
|
|
|
binary.LittleEndian.PutUint64(enc[:], uint64(i))
|
|
|
|
cms.Add(enc[:])
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func BenchmarkTopK(b *testing.B) {
|
|
|
|
for _, n := range []int{
|
|
|
|
10,
|
|
|
|
128,
|
|
|
|
256,
|
|
|
|
1024,
|
|
|
|
8192,
|
|
|
|
} {
|
|
|
|
b.Run(fmt.Sprintf("Top%d", n), func(b *testing.B) {
|
|
|
|
out := make([]int, 0, n)
|
|
|
|
topk := New[int](n, func(in []byte, val int) []byte {
|
|
|
|
return binary.LittleEndian.AppendUint64(in, uint64(val))
|
|
|
|
})
|
|
|
|
b.ResetTimer()
|
|
|
|
b.ReportAllocs()
|
|
|
|
|
2024-04-16 21:15:13 +01:00
|
|
|
for i := range b.N {
|
2024-02-03 04:31:17 +00:00
|
|
|
topk.Add(i)
|
|
|
|
}
|
|
|
|
out = topk.AppendTop(out[:0]) // should not allocate
|
|
|
|
_ = out // appease linter
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestMultiplyHigh64(t *testing.T) {
|
|
|
|
testCases := []struct {
|
|
|
|
x, y uint64
|
|
|
|
want uint64
|
|
|
|
}{
|
|
|
|
{0, 0, 0},
|
|
|
|
{0xffffffff, 0xffffffff, 0},
|
|
|
|
{0x2, 0xf000000000000000, 1},
|
|
|
|
{0x3, 0xf000000000000000, 2},
|
|
|
|
{0x3, 0xf000000000000001, 2},
|
|
|
|
{0x3, 0xffffffffffffffff, 2},
|
|
|
|
{0xffffffffffffffff, 0xffffffffffffffff, 0xfffffffffffffffe},
|
|
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
|
|
got := multiplyHigh64(tc.x, tc.y)
|
|
|
|
if got != tc.want {
|
|
|
|
t.Errorf("got multiplyHigh64(%x, %x) = %x, want %x", tc.x, tc.y, got, tc.want)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|