diff --git a/util/expvarx/expvarx.go b/util/expvarx/expvarx.go
new file mode 100644
index 000000000..762f65d06
--- /dev/null
+++ b/util/expvarx/expvarx.go
@@ -0,0 +1,95 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package expvarx provides some extensions to the [expvar] package.
+package expvarx
+
+import (
+	"encoding/json"
+	"expvar"
+	"sync"
+	"time"
+
+	"tailscale.com/types/lazy"
+)
+
+// SafeFunc is a wrapper around [expvar.Func] that guards against unbounded call
+// time and ensures that only a single call is in progress at any given time.
+type SafeFunc struct {
+	f      expvar.Func
+	limit  time.Duration
+	onSlow func(time.Duration, any)
+
+	mu       sync.Mutex
+	inflight *lazy.SyncValue[any] // guarded by mu; nil when no call is in progress
+}
+
+// NewSafeFunc returns a new SafeFunc that wraps f.
+// If f takes longer than limit to execute then Value calls return nil.
+// If onSlow is non-nil, it is called when f takes longer than limit to execute.
+// onSlow is called with the duration of the slow call and the final computed
+// value.
+func NewSafeFunc(f expvar.Func, limit time.Duration, onSlow func(time.Duration, any)) *SafeFunc {
+	return &SafeFunc{f: f, limit: limit, onSlow: onSlow}
+}
+
+// Value acts similarly to [expvar.Func.Value], but if the underlying function
+// takes longer than the configured limit, all callers will receive nil until
+// the underlying operation completes. On completion of the underlying
+// operation, the onSlow callback is called if set.
+func (s *SafeFunc) Value() any {
+	s.mu.Lock()
+	if s.inflight == nil {
+		s.inflight = new(lazy.SyncValue[any])
+	}
+	inflight := s.inflight
+	s.mu.Unlock()
+
+	// inflight ensures that only a single work routine is spawned at any given
+	// time, but if the routine takes too long inflight is populated with a nil
+	// result. The value eventually computed by the slow call is never returned
+	// from Value; it is only handed to onSlow, if set.
+	return inflight.Get(func() any {
+		start := time.Now()
+		result := make(chan any, 1)
+
+		// work is spawned in routine so that the caller can timeout.
+		go func() {
+			v := s.f.Value()
+
+			// Allow new work to be started after this work completes. This
+			// must happen before the result is published: otherwise a caller
+			// could receive the result and call Value again while inflight
+			// still refers to the populated value, getting a stale result.
+			s.mu.Lock()
+			s.inflight = nil
+			s.mu.Unlock()
+
+			result <- v
+		}()
+
+		// Use an explicit timer so that it can be stopped as soon as the fast
+		// path returns rather than staying armed until limit elapses.
+		timer := time.NewTimer(s.limit)
+		defer timer.Stop()
+
+		select {
+		case v := <-result:
+			return v
+		case <-timer.C:
+			if s.onSlow != nil {
+				go func() {
+					s.onSlow(time.Since(start), <-result)
+				}()
+			}
+			return nil
+		}
+	})
+}
+
+// String implements stringer in the same pattern as [expvar.Func], calling
+// Value and serializing the result as JSON, ignoring errors.
+func (s *SafeFunc) String() string {
+	v, _ := json.Marshal(s.Value())
+	return string(v)
+}
diff --git a/util/expvarx/expvarx_test.go b/util/expvarx/expvarx_test.go
new file mode 100644
index 000000000..16a989928
--- /dev/null
+++ b/util/expvarx/expvarx_test.go
@@ -0,0 +1,143 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package expvarx
+
+import (
+	"expvar"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+func ExampleNewSafeFunc() {
+	// An artificial blocker to emulate a slow operation.
+	blocker := make(chan struct{})
+
+	// limit is the amount of time a call can take before Value returns nil. No
+	// new calls to the unsafe func will be started until the slow call
+	// completes, at which point onSlow will be called.
+	limit := time.Millisecond
+
+	// onSlow is called with the final call duration and the final value in the
+	// event of a slow call.
+	onSlow := func(d time.Duration, v any) {
+		_ = d // d contains the time the call took
+		_ = v // v contains the final value computed by the slow call
+		fmt.Println("slow call!")
+	}
+
+	// An unsafe expvar.Func that blocks on the blocker channel.
+	unsafeFunc := expvar.Func(func() any {
+		for range blocker {
+		}
+		return "hello world"
+	})
+
+	// f implements the same interface as expvar.Func, but returns nil values
+	// when the unsafe func is too slow.
+	f := NewSafeFunc(unsafeFunc, limit, onSlow)
+
+	fmt.Println(f.Value())
+	fmt.Println(f.Value())
+	close(blocker)
+	time.Sleep(time.Millisecond)
+	fmt.Println(f.Value())
+	// Output: <nil>
+	// <nil>
+	// slow call!
+	// hello world
+}
+
+// TestSafeFuncHappyPath verifies that fast calls return fresh values, invoking
+// the wrapped func once per Value call.
+func TestSafeFuncHappyPath(t *testing.T) {
+	var count int
+	f := NewSafeFunc(expvar.Func(func() any {
+		count++
+		return count
+	}), time.Millisecond, nil)
+
+	if got, want := f.Value(), 1; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	if got, want := f.Value(), 2; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
+
+// TestSafeFuncSlow verifies that Value returns nil while a slow call is in
+// flight, and that the wrapped func is only invoked once during that time.
+func TestSafeFuncSlow(t *testing.T) {
+	var count int
+	blocker := make(chan struct{})
+	var wg sync.WaitGroup
+	wg.Add(1)
+	f := NewSafeFunc(expvar.Func(func() any {
+		defer wg.Done()
+		count++
+		<-blocker
+		return count
+	}), time.Millisecond, nil)
+
+	if got := f.Value(); got != nil {
+		t.Errorf("got %v; want nil", got)
+	}
+	if got := f.Value(); got != nil {
+		t.Errorf("got %v; want nil", got)
+	}
+
+	close(blocker)
+	wg.Wait()
+
+	if count != 1 {
+		t.Errorf("got count=%d; want 1", count)
+	}
+}
+
+// TestSafeFuncSlowOnSlow verifies that onSlow is called exactly once, with the
+// duration and final value of the slow call.
+func TestSafeFuncSlowOnSlow(t *testing.T) {
+	var count int
+	blocker := make(chan struct{})
+	var wg sync.WaitGroup
+	wg.Add(2)
+	var slowDuration atomic.Pointer[time.Duration]
+	var slowCallCount atomic.Int32
+	var slowValue atomic.Value
+	f := NewSafeFunc(expvar.Func(func() any {
+		defer wg.Done()
+		count++
+		<-blocker
+		return count
+	}), time.Millisecond, func(d time.Duration, v any) {
+		defer wg.Done()
+		slowDuration.Store(&d)
+		slowCallCount.Add(1)
+		slowValue.Store(v)
+	})
+
+	for i := 0; i < 10; i++ {
+		if got := f.Value(); got != nil {
+			t.Fatalf("got value=%v; want nil", got)
+		}
+	}
+
+	close(blocker)
+	wg.Wait()
+
+	if count != 1 {
+		t.Errorf("got count=%d; want 1", count)
+	}
+	if got, want := *slowDuration.Load(), 1*time.Millisecond; got < want {
+		t.Errorf("got slowDuration=%v; want at least %v", got, want)
+	}
+	if got, want := slowCallCount.Load(), int32(1); got != want {
+		t.Errorf("got slowCallCount=%d; want %d", got, want)
+	}
+	if got, want := slowValue.Load().(int), 1; got != want {
+		t.Errorf("got slowValue=%d, want %d", got, want)
+	}
+}