From 9e6b4d7ad8b9045ebadf77fae578e0d9e9f7d439 Mon Sep 17 00:00:00 2001 From: David Anderson Date: Fri, 10 Feb 2023 13:54:07 -0800 Subject: [PATCH] types/lazy: helpers for lazily computed values Co-authored-by: Maisem Ali Co-authored-by: Brad Fitzpatrick Signed-off-by: David Anderson --- types/lazy/lazy.go | 88 ++++++++++++++++++++++ types/lazy/sync_test.go | 150 ++++++++++++++++++++++++++++++++++++++ types/lazy/unsync.go | 99 +++++++++++++++++++++++++ types/lazy/unsync_test.go | 140 +++++++++++++++++++++++++++++++++++ 4 files changed, 477 insertions(+) create mode 100644 types/lazy/lazy.go create mode 100644 types/lazy/sync_test.go create mode 100644 types/lazy/unsync.go create mode 100644 types/lazy/unsync_test.go diff --git a/types/lazy/lazy.go b/types/lazy/lazy.go new file mode 100644 index 000000000..04629bcbe --- /dev/null +++ b/types/lazy/lazy.go @@ -0,0 +1,88 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lazy provides types for lazily initialized values. +package lazy + +import "sync" + +// SyncValue is a lazily computed value. +// +// Use either Get or GetErr, depending on whether your fill function returns an +// error. +// +// Recursive use of a SyncValue from its own fill function will deadlock. +// +// SyncValue is safe for concurrent use. +type SyncValue[T any] struct { + once sync.Once + v T + err error +} + +// Set attempts to set z's value to val, and reports whether it succeeded. +// Set only succeeds if none of Get/GetErr/Set have been called before. +func (z *SyncValue[T]) Set(val T) bool { + var wasSet bool + z.once.Do(func() { + z.v = val + wasSet = true + }) + return wasSet +} + +// MustSet sets z's value to val, or panics if z already has a value. +func (z *SyncValue[T]) MustSet(val T) { + if !z.Set(val) { + panic("Set after already filled") + } +} + +// Get returns z's value, calling fill to compute it if necessary. +// f is called at most once. +func (z *SyncValue[T]) Get(fill func() T) T { + z.once.Do(func() { z.v = fill() }) + return z.v +} + +// GetErr returns z's value, calling fill to compute it if necessary. +// f is called at most once, and z remembers both of fill's outputs. +func (z *SyncValue[T]) GetErr(fill func() (T, error)) (T, error) { + z.once.Do(func() { z.v, z.err = fill() }) + return z.v, z.err +} + +// SyncFunc wraps a function to make it lazy. +// +// The returned function calls fill the first time it's called, and returns +// fill's result on every subsequent call. +// +// The returned function is safe for concurrent use. +func SyncFunc[T any](fill func() T) func() T { + var ( + once sync.Once + v T + ) + return func() T { + once.Do(func() { v = fill() }) + return v + } +} + +// SyncFuncErr wraps a function to make it lazy. +// +// The returned function calls fill the first time it's called, and returns +// fill's results on every subsequent call. +// +// The returned function is safe for concurrent use. +func SyncFuncErr[T any](fill func() (T, error)) func() (T, error) { + var ( + once sync.Once + v T + err error + ) + return func() (T, error) { + once.Do(func() { v, err = fill() }) + return v, err + } +} diff --git a/types/lazy/sync_test.go b/types/lazy/sync_test.go new file mode 100644 index 000000000..ac92c4914 --- /dev/null +++ b/types/lazy/sync_test.go @@ -0,0 +1,150 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +import ( + "errors" + "sync" + "testing" +) + +func TestSyncValue(t *testing.T) { + var lt SyncValue[int] + n := int(testing.AllocsPerRun(1000, func() { + got := lt.Get(fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestSyncValueErr(t *testing.T) { + var lt SyncValue[int] + n := int(testing.AllocsPerRun(1000, func() { + got, err := lt.GetErr(func() (int, error) { + return 42, nil + }) + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + var lterr SyncValue[int] + wantErr := errors.New("test error") + n = int(testing.AllocsPerRun(1000, func() { + got, err := lterr.GetErr(func() (int, error) { + return 0, wantErr + }) + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestSyncValueSet(t *testing.T) { + var lt SyncValue[int] + if !lt.Set(42) { + t.Fatalf("Set failed") + } + if lt.Set(43) { + t.Fatalf("Set succeeded after first Set") + } + n := int(testing.AllocsPerRun(1000, func() { + got := lt.Get(fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestSyncValueMustSet(t *testing.T) { + var lt SyncValue[int] + lt.MustSet(42) + defer func() { + if e := recover(); e == nil { + t.Errorf("unexpected success; want panic") + } + }() + lt.MustSet(43) +} + +func TestSyncValueConcurrent(t *testing.T) { + var ( + lt SyncValue[int] + wg sync.WaitGroup + start = make(chan struct{}) + routines = 10000 + ) + wg.Add(routines) + for i := 0; i < routines; i++ { + go func() { + defer wg.Done() + // Every goroutine waits for the go signal, so that more of them + // have a chance to race on the initial Get than with sequential + // goroutine starts. + <-start + got := lt.Get(fortyTwo) + if got != 42 { + t.Errorf("got %v; want 42", got) + } + }() + } + close(start) + wg.Wait() +} + +func TestSyncFunc(t *testing.T) { + f := SyncFunc(fortyTwo) + + n := int(testing.AllocsPerRun(1000, func() { + got := f() + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestSyncFuncErr(t *testing.T) { + f := SyncFuncErr(func() (int, error) { + return 42, nil + }) + n := int(testing.AllocsPerRun(1000, func() { + got, err := f() + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + wantErr := errors.New("test error") + f = SyncFuncErr(func() (int, error) { + return 0, wantErr + }) + n = int(testing.AllocsPerRun(1000, func() { + got, err := f() + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} diff --git a/types/lazy/unsync.go b/types/lazy/unsync.go new file mode 100644 index 000000000..0f89ce4f6 --- /dev/null +++ b/types/lazy/unsync.go @@ -0,0 +1,99 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +// GValue is a lazily computed value. +// +// Use either Get or GetErr, depending on whether your fill function returns an +// error. +// +// Recursive use of a GValue from its own fill function will panic. +// +// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine, +// which isn't strictly true if you provide your own synchronization between +// goroutines, but in practice most of our callers have been using it within +// a single goroutine.) +type GValue[T any] struct { + done bool + calling bool + V T + err error +} + +// Set attempts to set z's value to val, and reports whether it succeeded. +// Set only succeeds if none of Get/GetErr/Set have been called before. +func (z *GValue[T]) Set(v T) bool { + if z.done { + return false + } + if z.calling { + panic("Set while Get fill is running") + } + z.V = v + z.done = true + return true +} + +// MustSet sets z's value to val, or panics if z already has a value. +func (z *GValue[T]) MustSet(val T) { + if !z.Set(val) { + panic("Set after already filled") + } +} + +// Get returns z's value, calling fill to compute it if necessary. +// f is called at most once. +func (z *GValue[T]) Get(fill func() T) T { + if !z.done { + if z.calling { + panic("recursive lazy fill") + } + z.calling = true + z.V = fill() + z.done = true + z.calling = false + } + return z.V +} + +// GetErr returns z's value, calling fill to compute it if necessary. +// f is called at most once, and z remembers both of fill's outputs. +func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) { + if !z.done { + if z.calling { + panic("recursive lazy fill") + } + z.calling = true + z.V, z.err = fill() + z.done = true + z.calling = false + } + return z.V, z.err +} + +// GFunc wraps a function to make it lazy. +// +// The returned function calls fill the first time it's called, and returns +// fill's result on every subsequent call. +// +// The returned function is not safe for concurrent use. +func GFunc[T any](fill func() T) func() T { + var v GValue[T] + return func() T { + return v.Get(fill) + } +} + +// SyncFuncErr wraps a function to make it lazy. +// +// The returned function calls fill the first time it's called, and returns +// fill's results on every subsequent call. +// +// The returned function is not safe for concurrent use. +func GFuncErr[T any](fill func() (T, error)) func() (T, error) { + var v GValue[T] + return func() (T, error) { + return v.GetErr(fill) + } +} diff --git a/types/lazy/unsync_test.go b/types/lazy/unsync_test.go new file mode 100644 index 000000000..f0d2494d1 --- /dev/null +++ b/types/lazy/unsync_test.go @@ -0,0 +1,140 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +import ( + "errors" + "testing" +) + +func fortyTwo() int { return 42 } + +func TestGValue(t *testing.T) { + var lt GValue[int] + n := int(testing.AllocsPerRun(1000, func() { + got := lt.Get(fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueErr(t *testing.T) { + var lt GValue[int] + n := int(testing.AllocsPerRun(1000, func() { + got, err := lt.GetErr(func() (int, error) { + return 42, nil + }) + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + var lterr GValue[int] + wantErr := errors.New("test error") + n = int(testing.AllocsPerRun(1000, func() { + got, err := lterr.GetErr(func() (int, error) { + return 0, wantErr + }) + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueSet(t *testing.T) { + var lt GValue[int] + if !lt.Set(42) { + t.Fatalf("Set failed") + } + if lt.Set(43) { + t.Fatalf("Set succeeded after first Set") + } + n := int(testing.AllocsPerRun(1000, func() { + got := lt.Get(fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGValueMustSet(t *testing.T) { + var lt GValue[int] + lt.MustSet(42) + defer func() { + if e := recover(); e == nil { + t.Errorf("unexpected success; want panic") + } + }() + lt.MustSet(43) +} + +func TestGValueRecursivePanic(t *testing.T) { + defer func() { + if e := recover(); e != nil { + t.Logf("got panic, as expected") + } else { + t.Errorf("unexpected success; want panic") + } + }() + v := GValue[int]{} + v.Get(func() int { + return v.Get(func() int { return 42 }) + }) +} + +func TestGFunc(t *testing.T) { + f := GFunc(fortyTwo) + + n := int(testing.AllocsPerRun(1000, func() { + got := f() + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGFuncErr(t *testing.T) { + f := GFuncErr(func() (int, error) { + return 42, nil + }) + n := int(testing.AllocsPerRun(1000, func() { + got, err := f() + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + wantErr := errors.New("test error") + f = GFuncErr(func() (int, error) { + return 0, wantErr + }) + n = int(testing.AllocsPerRun(1000, func() { + got, err := f() + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +}