From 24bac2763292c9a9d83a24a78a6adc10a2f56093 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Tue, 6 Feb 2024 14:31:31 -0800 Subject: [PATCH] util/rands: add Shuffle and Perm functions with on-stack RNG state The new math/rand/v2 package includes an m-local global random number generator that cannot be reseeded by the user, which is suitable for most uses without the RNG pools we have in a number of areas of the code base. The new API still does not have an allocation-free way of performing seeded operations, due to the long-term compiler bug around interface parameter escapes, and the Source interface. This change introduces the two APIs that math/rand/v2 cannot yet replace efficiently: seeded Perm() and Shuffle() operations. This implementation chooses to use the PCG random source from math/rand/v2, as with sufficient compiler optimization, this source should boil down to only two on-stack registers for random state under ideal conditions. Updates #17243 Signed-off-by: James Tucker --- cmd/tailscaled/depaware.txt | 1 + util/rands/cheap.go | 82 +++++++++++++++++++++++++++++++ util/rands/cheap_test.go | 96 +++++++++++++++++++++++++++++++++++++ 3 files changed, 179 insertions(+) create mode 100644 util/rands/cheap.go create mode 100644 util/rands/cheap_test.go diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index f38e5e5fa..e45d20b75 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -519,6 +519,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de math/big from crypto/dsa+ math/bits from compress/flate+ math/rand from github.com/mdlayher/netlink+ + math/rand/v2 from tailscale.com/util/rands mime from github.com/tailscale/xnet/webdav+ mime/multipart from net/http mime/quotedprintable from mime/multipart diff --git a/util/rands/cheap.go b/util/rands/cheap.go new file mode 100644 index 000000000..06e46a1b0 --- /dev/null +++ b/util/rands/cheap.go @@ -0,0 +1,82 @@ +// Copyright (c) 
Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rands + +import ( + "math/bits" + + randv2 "math/rand/v2" +) + +// Shuffle is like rand.Shuffle, but it does not allocate or lock any RNG state. +func Shuffle[T any](seed uint64, data []T) { + var pcg randv2.PCG + pcg.Seed(seed, seed) + for i := len(data) - 1; i > 0; i-- { + j := int(uint64n(&pcg, uint64(i+1))) + data[i], data[j] = data[j], data[i] + } +} + +// Perm is like rand.Perm, but it is seeded on the stack and does not allocate +// or lock any RNG state. +func Perm(seed uint64, n int) []int { + p := make([]int, n) + for i := range p { + p[i] = i + } + Shuffle(seed, p) + return p +} + +// uint64n is the no-bounds-checks version of rand.Uint64N from the standard +// library. 32-bit optimizations have been elided. +func uint64n(pcg *randv2.PCG, n uint64) uint64 { + if n&(n-1) == 0 { // n is power of two, can mask + return pcg.Uint64() & (n - 1) + } + + // Suppose we have a uint64 x uniform in the range [0,2⁶⁴) + // and want to reduce it to the range [0,n) preserving exact uniformity. + // We can simulate a scaling arbitrary precision x * (n/2⁶⁴) by + // the high bits of a double-width multiply of x*n, meaning (x*n)/2⁶⁴. + // Since there are 2⁶⁴ possible inputs x and only n possible outputs, + // the output is necessarily biased if n does not divide 2⁶⁴. + // In general (x*n)/2⁶⁴ = k for x*n in [k*2⁶⁴,(k+1)*2⁶⁴). + // There are either floor(2⁶⁴/n) or ceil(2⁶⁴/n) possible products + // in that range, depending on k. + // But suppose we reject the sample and try again when + // x*n is in [k*2⁶⁴, k*2⁶⁴+(2⁶⁴%n)), meaning rejecting fewer than n possible + // outcomes out of the 2⁶⁴. + // Now there are exactly floor(2⁶⁴/n) possible ways to produce + // each output value k, so we've restored uniformity. 
+ // To get valid uint64 math, 2⁶⁴ % n = (2⁶⁴ - n) % n = -n % n, + // so the direct implementation of this algorithm would be: + // + // hi, lo := bits.Mul64(r.Uint64(), n) + // thresh := -n % n + // for lo < thresh { + // hi, lo = bits.Mul64(r.Uint64(), n) + // } + // + // That still leaves an expensive 64-bit division that we would rather avoid. + // We know that thresh < n, and n is usually much less than 2⁶⁴, so we can + // avoid the last four lines unless lo < n. + // + // See also: + // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction + // https://lemire.me/blog/2016/06/30/fast-random-shuffling + hi, lo := bits.Mul64(pcg.Uint64(), n) + if lo < n { + thresh := -n % n + for lo < thresh { + hi, lo = bits.Mul64(pcg.Uint64(), n) + } + } + return hi +} diff --git a/util/rands/cheap_test.go b/util/rands/cheap_test.go new file mode 100644 index 000000000..756b55b4e --- /dev/null +++ b/util/rands/cheap_test.go @@ -0,0 +1,96 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rands + +import ( + "slices" + "testing" + + randv2 "math/rand/v2" +) + +func TestShuffleNoAllocs(t *testing.T) { + seed := randv2.Uint64() + data := make([]int, 100) + for i := range data { + data[i] = i + } + if n := testing.AllocsPerRun(1000, func() { + Shuffle(seed, data) + }); n > 0 { + t.Errorf("Rand got %v allocs per run", n) + } +} + +func BenchmarkStdRandV2Shuffle(b *testing.B) { + seed := randv2.Uint64() + data := make([]int, 100) + for i := range data { + data[i] = i + } + b.ReportAllocs() + for range b.N { + // PCG is the lightest source, taking just two uint64s, the chacha8 + // source has much larger state. 
+ rng := randv2.New(randv2.NewPCG(seed, seed)) + rng.Shuffle(len(data), func(i, j int) { data[i], data[j] = data[j], data[i] }) + } +} + +func BenchmarkLocalShuffle(b *testing.B) { + seed := randv2.Uint64() + data := make([]int, 100) + for i := range data { + data[i] = i + } + b.ReportAllocs() + for range b.N { + Shuffle(seed, data) + } +} + +func TestPerm(t *testing.T) { + seed := uint64(12345) + p := Perm(seed, 100) + if len(p) != 100 { + t.Errorf("got %v; want 100", len(p)) + } + expect := [][]int{ + {5, 7, 1, 4, 0, 9, 2, 3, 6, 8}, + {0, 5, 9, 8, 1, 6, 2, 4, 3, 7}, + {5, 2, 3, 1, 9, 7, 6, 8, 4, 0}, + {4, 5, 7, 1, 6, 3, 8, 2, 0, 9}, + {5, 7, 0, 9, 2, 1, 8, 4, 6, 3}, + } + for i := range 5 { + got := Perm(seed+uint64(i), 10) + want := expect[i] + if !slices.Equal(got, want) { + t.Errorf("got %v; want %v", got, want) + } + } +} + +func TestShuffle(t *testing.T) { + seed := uint64(12345) + p := Perm(seed, 10) + if len(p) != 10 { + t.Errorf("got %v; want 10", len(p)) + } + + expect := [][]int{ + {9, 3, 7, 0, 5, 8, 1, 4, 2, 6}, + {9, 8, 6, 2, 3, 1, 7, 5, 0, 4}, + {1, 6, 2, 8, 4, 5, 7, 0, 3, 9}, + {4, 5, 0, 6, 7, 8, 3, 2, 1, 9}, + {8, 2, 4, 9, 0, 5, 1, 7, 3, 6}, + } + for i := range 5 { + Shuffle(seed+uint64(i), p) + want := expect[i] + if !slices.Equal(p, want) { + t.Errorf("got %v; want %v", p, want) + } + } +}