types/lazy: helpers for lazily computed values
Co-authored-by: Maisem Ali <maisem@tailscale.com> Co-authored-by: Brad Fitzpatrick <bradfitz@tailscale.com> Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
parent
5bca44d572
commit
9e6b4d7ad8
|
@ -0,0 +1,88 @@
|
||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
// Package lazy provides types for lazily initialized values.
|
||||||
|
package lazy
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// SyncValue is a lazily computed value.
|
||||||
|
//
|
||||||
|
// Use either Get or GetErr, depending on whether your fill function returns an
|
||||||
|
// error.
|
||||||
|
//
|
||||||
|
// Recursive use of a SyncValue from its own fill function will deadlock.
|
||||||
|
//
|
||||||
|
// SyncValue is safe for concurrent use.
|
||||||
|
type SyncValue[T any] struct {
|
||||||
|
once sync.Once
|
||||||
|
v T
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set attempts to set z's value to val, and reports whether it succeeded.
|
||||||
|
// Set only succeeds if none of Get/GetErr/Set have been called before.
|
||||||
|
func (z *SyncValue[T]) Set(val T) bool {
|
||||||
|
var wasSet bool
|
||||||
|
z.once.Do(func() {
|
||||||
|
z.v = val
|
||||||
|
wasSet = true
|
||||||
|
})
|
||||||
|
return wasSet
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustSet sets z's value to val, or panics if z already has a value.
|
||||||
|
func (z *SyncValue[T]) MustSet(val T) {
|
||||||
|
if !z.Set(val) {
|
||||||
|
panic("Set after already filled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns z's value, calling fill to compute it if necessary.
|
||||||
|
// f is called at most once.
|
||||||
|
func (z *SyncValue[T]) Get(fill func() T) T {
|
||||||
|
z.once.Do(func() { z.v = fill() })
|
||||||
|
return z.v
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetErr returns z's value, calling fill to compute it if necessary.
|
||||||
|
// f is called at most once, and z remembers both of fill's outputs.
|
||||||
|
func (z *SyncValue[T]) GetErr(fill func() (T, error)) (T, error) {
|
||||||
|
z.once.Do(func() { z.v, z.err = fill() })
|
||||||
|
return z.v, z.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncFunc wraps a function to make it lazy.
|
||||||
|
//
|
||||||
|
// The returned function calls fill the first time it's called, and returns
|
||||||
|
// fill's result on every subsequent call.
|
||||||
|
//
|
||||||
|
// The returned function is safe for concurrent use.
|
||||||
|
func SyncFunc[T any](fill func() T) func() T {
|
||||||
|
var (
|
||||||
|
once sync.Once
|
||||||
|
v T
|
||||||
|
)
|
||||||
|
return func() T {
|
||||||
|
once.Do(func() { v = fill() })
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncFuncErr wraps a function to make it lazy.
|
||||||
|
//
|
||||||
|
// The returned function calls fill the first time it's called, and returns
|
||||||
|
// fill's results on every subsequent call.
|
||||||
|
//
|
||||||
|
// The returned function is safe for concurrent use.
|
||||||
|
func SyncFuncErr[T any](fill func() (T, error)) func() (T, error) {
|
||||||
|
var (
|
||||||
|
once sync.Once
|
||||||
|
v T
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
return func() (T, error) {
|
||||||
|
once.Do(func() { v, err = fill() })
|
||||||
|
return v, err
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,150 @@
|
||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package lazy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSyncValue(t *testing.T) {
|
||||||
|
var lt SyncValue[int]
|
||||||
|
n := int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got := lt.Get(fortyTwo)
|
||||||
|
if got != 42 {
|
||||||
|
t.Fatalf("got %v; want 42", got)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncValueErr(t *testing.T) {
|
||||||
|
var lt SyncValue[int]
|
||||||
|
n := int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got, err := lt.GetErr(func() (int, error) {
|
||||||
|
return 42, nil
|
||||||
|
})
|
||||||
|
if got != 42 || err != nil {
|
||||||
|
t.Fatalf("got %v, %v; want 42, nil", got, err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
var lterr SyncValue[int]
|
||||||
|
wantErr := errors.New("test error")
|
||||||
|
n = int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got, err := lterr.GetErr(func() (int, error) {
|
||||||
|
return 0, wantErr
|
||||||
|
})
|
||||||
|
if got != 0 || err != wantErr {
|
||||||
|
t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncValueSet(t *testing.T) {
|
||||||
|
var lt SyncValue[int]
|
||||||
|
if !lt.Set(42) {
|
||||||
|
t.Fatalf("Set failed")
|
||||||
|
}
|
||||||
|
if lt.Set(43) {
|
||||||
|
t.Fatalf("Set succeeded after first Set")
|
||||||
|
}
|
||||||
|
n := int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got := lt.Get(fortyTwo)
|
||||||
|
if got != 42 {
|
||||||
|
t.Fatalf("got %v; want 42", got)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncValueMustSet(t *testing.T) {
|
||||||
|
var lt SyncValue[int]
|
||||||
|
lt.MustSet(42)
|
||||||
|
defer func() {
|
||||||
|
if e := recover(); e == nil {
|
||||||
|
t.Errorf("unexpected success; want panic")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
lt.MustSet(43)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncValueConcurrent(t *testing.T) {
|
||||||
|
var (
|
||||||
|
lt SyncValue[int]
|
||||||
|
wg sync.WaitGroup
|
||||||
|
start = make(chan struct{})
|
||||||
|
routines = 10000
|
||||||
|
)
|
||||||
|
wg.Add(routines)
|
||||||
|
for i := 0; i < routines; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
// Every goroutine waits for the go signal, so that more of them
|
||||||
|
// have a chance to race on the initial Get than with sequential
|
||||||
|
// goroutine starts.
|
||||||
|
<-start
|
||||||
|
got := lt.Get(fortyTwo)
|
||||||
|
if got != 42 {
|
||||||
|
t.Errorf("got %v; want 42", got)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
close(start)
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncFunc(t *testing.T) {
|
||||||
|
f := SyncFunc(fortyTwo)
|
||||||
|
|
||||||
|
n := int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got := f()
|
||||||
|
if got != 42 {
|
||||||
|
t.Fatalf("got %v; want 42", got)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncFuncErr(t *testing.T) {
|
||||||
|
f := SyncFuncErr(func() (int, error) {
|
||||||
|
return 42, nil
|
||||||
|
})
|
||||||
|
n := int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got, err := f()
|
||||||
|
if got != 42 || err != nil {
|
||||||
|
t.Fatalf("got %v, %v; want 42, nil", got, err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantErr := errors.New("test error")
|
||||||
|
f = SyncFuncErr(func() (int, error) {
|
||||||
|
return 0, wantErr
|
||||||
|
})
|
||||||
|
n = int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got, err := f()
|
||||||
|
if got != 0 || err != wantErr {
|
||||||
|
t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,99 @@
|
||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package lazy
|
||||||
|
|
||||||
|
// GValue is a lazily computed value.
|
||||||
|
//
|
||||||
|
// Use either Get or GetErr, depending on whether your fill function returns an
|
||||||
|
// error.
|
||||||
|
//
|
||||||
|
// Recursive use of a GValue from its own fill function will panic.
|
||||||
|
//
|
||||||
|
// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine,
|
||||||
|
// which isn't strictly true if you provide your own synchronization between
|
||||||
|
// goroutines, but in practice most of our callers have been using it within
|
||||||
|
// a single goroutine.)
|
||||||
|
type GValue[T any] struct {
|
||||||
|
done bool
|
||||||
|
calling bool
|
||||||
|
V T
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set attempts to set z's value to val, and reports whether it succeeded.
|
||||||
|
// Set only succeeds if none of Get/GetErr/Set have been called before.
|
||||||
|
func (z *GValue[T]) Set(v T) bool {
|
||||||
|
if z.done {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if z.calling {
|
||||||
|
panic("Set while Get fill is running")
|
||||||
|
}
|
||||||
|
z.V = v
|
||||||
|
z.done = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustSet sets z's value to val, or panics if z already has a value.
|
||||||
|
func (z *GValue[T]) MustSet(val T) {
|
||||||
|
if !z.Set(val) {
|
||||||
|
panic("Set after already filled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns z's value, calling fill to compute it if necessary.
|
||||||
|
// f is called at most once.
|
||||||
|
func (z *GValue[T]) Get(fill func() T) T {
|
||||||
|
if !z.done {
|
||||||
|
if z.calling {
|
||||||
|
panic("recursive lazy fill")
|
||||||
|
}
|
||||||
|
z.calling = true
|
||||||
|
z.V = fill()
|
||||||
|
z.done = true
|
||||||
|
z.calling = false
|
||||||
|
}
|
||||||
|
return z.V
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetErr returns z's value, calling fill to compute it if necessary.
|
||||||
|
// f is called at most once, and z remembers both of fill's outputs.
|
||||||
|
func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) {
|
||||||
|
if !z.done {
|
||||||
|
if z.calling {
|
||||||
|
panic("recursive lazy fill")
|
||||||
|
}
|
||||||
|
z.calling = true
|
||||||
|
z.V, z.err = fill()
|
||||||
|
z.done = true
|
||||||
|
z.calling = false
|
||||||
|
}
|
||||||
|
return z.V, z.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GFunc wraps a function to make it lazy.
|
||||||
|
//
|
||||||
|
// The returned function calls fill the first time it's called, and returns
|
||||||
|
// fill's result on every subsequent call.
|
||||||
|
//
|
||||||
|
// The returned function is not safe for concurrent use.
|
||||||
|
func GFunc[T any](fill func() T) func() T {
|
||||||
|
var v GValue[T]
|
||||||
|
return func() T {
|
||||||
|
return v.Get(fill)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncFuncErr wraps a function to make it lazy.
|
||||||
|
//
|
||||||
|
// The returned function calls fill the first time it's called, and returns
|
||||||
|
// fill's results on every subsequent call.
|
||||||
|
//
|
||||||
|
// The returned function is not safe for concurrent use.
|
||||||
|
func GFuncErr[T any](fill func() (T, error)) func() (T, error) {
|
||||||
|
var v GValue[T]
|
||||||
|
return func() (T, error) {
|
||||||
|
return v.GetErr(fill)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,140 @@
|
||||||
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||||||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||||||
|
|
||||||
|
package lazy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func fortyTwo() int { return 42 }
|
||||||
|
|
||||||
|
func TestGValue(t *testing.T) {
|
||||||
|
var lt GValue[int]
|
||||||
|
n := int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got := lt.Get(fortyTwo)
|
||||||
|
if got != 42 {
|
||||||
|
t.Fatalf("got %v; want 42", got)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGValueErr(t *testing.T) {
|
||||||
|
var lt GValue[int]
|
||||||
|
n := int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got, err := lt.GetErr(func() (int, error) {
|
||||||
|
return 42, nil
|
||||||
|
})
|
||||||
|
if got != 42 || err != nil {
|
||||||
|
t.Fatalf("got %v, %v; want 42, nil", got, err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
var lterr GValue[int]
|
||||||
|
wantErr := errors.New("test error")
|
||||||
|
n = int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got, err := lterr.GetErr(func() (int, error) {
|
||||||
|
return 0, wantErr
|
||||||
|
})
|
||||||
|
if got != 0 || err != wantErr {
|
||||||
|
t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGValueSet(t *testing.T) {
|
||||||
|
var lt GValue[int]
|
||||||
|
if !lt.Set(42) {
|
||||||
|
t.Fatalf("Set failed")
|
||||||
|
}
|
||||||
|
if lt.Set(43) {
|
||||||
|
t.Fatalf("Set succeeded after first Set")
|
||||||
|
}
|
||||||
|
n := int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got := lt.Get(fortyTwo)
|
||||||
|
if got != 42 {
|
||||||
|
t.Fatalf("got %v; want 42", got)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGValueMustSet(t *testing.T) {
|
||||||
|
var lt GValue[int]
|
||||||
|
lt.MustSet(42)
|
||||||
|
defer func() {
|
||||||
|
if e := recover(); e == nil {
|
||||||
|
t.Errorf("unexpected success; want panic")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
lt.MustSet(43)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGValueRecursivePanic(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if e := recover(); e != nil {
|
||||||
|
t.Logf("got panic, as expected")
|
||||||
|
} else {
|
||||||
|
t.Errorf("unexpected success; want panic")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
v := GValue[int]{}
|
||||||
|
v.Get(func() int {
|
||||||
|
return v.Get(func() int { return 42 })
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGFunc(t *testing.T) {
|
||||||
|
f := GFunc(fortyTwo)
|
||||||
|
|
||||||
|
n := int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got := f()
|
||||||
|
if got != 42 {
|
||||||
|
t.Fatalf("got %v; want 42", got)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGFuncErr(t *testing.T) {
|
||||||
|
f := GFuncErr(func() (int, error) {
|
||||||
|
return 42, nil
|
||||||
|
})
|
||||||
|
n := int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got, err := f()
|
||||||
|
if got != 42 || err != nil {
|
||||||
|
t.Fatalf("got %v, %v; want 42, nil", got, err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantErr := errors.New("test error")
|
||||||
|
f = GFuncErr(func() (int, error) {
|
||||||
|
return 0, wantErr
|
||||||
|
})
|
||||||
|
n = int(testing.AllocsPerRun(1000, func() {
|
||||||
|
got, err := f()
|
||||||
|
if got != 0 || err != wantErr {
|
||||||
|
t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("allocs = %v; want 0", n)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue