From 512fc0b5020a53851c36c1fd1713ca97b0f25fcb Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Fri, 15 Mar 2024 12:14:29 -0400 Subject: [PATCH] util/reload: add new package to handle periodic value loading This can be used to reload a value periodically, whether from disk or another source, while handling jitter and graceful shutdown. Updates tailscale/corp#1297 Signed-off-by: Andrew Dunham Change-Id: Iee2b4385c9abae59805f642a7308837877cb5b3f --- util/reload/reload.go | 189 +++++++++++++++++++++++++++++++++++++ util/reload/reload_test.go | 171 +++++++++++++++++++++++++++++++++ 2 files changed, 360 insertions(+) create mode 100644 util/reload/reload.go create mode 100644 util/reload/reload_test.go diff --git a/util/reload/reload.go b/util/reload/reload.go new file mode 100644 index 000000000..84e7ba2bf --- /dev/null +++ b/util/reload/reload.go @@ -0,0 +1,189 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package reload contains functions that allow periodically reloading a value +// (e.g. a config file) from various sources. +package reload + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "os" + "reflect" + "time" + + "tailscale.com/syncs" + "tailscale.com/types/logger" +) + +// DefaultInterval is the default value for ReloadOpts.Interval if none is +// provided. +const DefaultInterval = 5 * time.Minute + +// ReloadOpts specifies options for reloading a value. Various helper functions +// in this package can be used to create one of these specialized for a given +// use-case. +type ReloadOpts[T any] struct { + // Read is called to obtain the data to be unmarshaled; e.g. by reading + // from a file, or making a network request, etc. + // + // An error from this function is fatal when calling New, but only a + // warning during reload. + // + // This value is required. + Read func(context.Context) ([]byte, error) + + // Unmarshal is called with the data that the Read function returns and + // should return a parsed form of the given value, or an error. + // + // An error from this function is fatal when calling New, but only a + // warning during reload. + // + // This value is required. + Unmarshal func([]byte) (T, error) + + // Logf is a logger used to print errors that occur on reload. If nil, + // no messages are printed. + Logf logger.Logf + + // Interval is the interval at which to reload the given data from the + // source; if zero, DefaultInterval will be used. + Interval time.Duration + + // IntervalJitter is the jitter to be added to the given Interval; if + // provided, a duration between 0 and this value will be added to each + // Interval when waiting. + IntervalJitter time.Duration +} + +func (r *ReloadOpts[T]) logf(format string, args ...any) { + if r.Logf != nil { + r.Logf(format, args...) + } +} + +func (r *ReloadOpts[T]) intervalWithJitter() time.Duration { + tt := r.Interval + if tt == 0 { + tt = DefaultInterval + } + if r.IntervalJitter == 0 { + return tt + } + + jitter := time.Duration(rand.Intn(int(r.IntervalJitter))) + return tt + jitter +} + +// New creates and starts reloading the provided value as per opts. It returns +// a function that, when called, returns the current stored value, or an error +// that indicates something went wrong. +// +// The value will be present immediately upon return. +func New[T any](ctx context.Context, opts ReloadOpts[T]) (func() T, error) { + // Create our reloader, which hasn't started. + reloader, err := newUnstarted(ctx, opts) + if err != nil { + return nil, err + } + + // Start it + go reloader.run() + + // Return the load function now that we're all set up. + return reloader.store.Load, nil +} + +type reloader[T any] struct { + ctx context.Context + store syncs.AtomicValue[T] + opts ReloadOpts[T] +} + +// newUnstarted creates a reloader that hasn't yet been started. +func newUnstarted[T any](ctx context.Context, opts ReloadOpts[T]) (*reloader[T], error) { + if opts.Read == nil { + return nil, fmt.Errorf("the Read function is required") + } + if opts.Unmarshal == nil { + return nil, fmt.Errorf("the Unmarshal function is required") + } + + // Start by reading and unmarshaling the value. + data, err := opts.Read(ctx) + if err != nil { + return nil, fmt.Errorf("reading initial value: %w", err) + } + + initial, err := opts.Unmarshal(data) + if err != nil { + return nil, fmt.Errorf("unmarshaling initial value: %v", err) + } + + reloader := &reloader[T]{ + ctx: ctx, + opts: opts, + } + reloader.store.Store(initial) + return reloader, nil +} + +func (r *reloader[T]) run() { + // Create a timer that we re-set each time we fire. + timer := time.NewTimer(r.opts.intervalWithJitter()) + defer timer.Stop() + + for { + select { + case <-r.ctx.Done(): + r.opts.logf("run context is done") + return + case <-timer.C: + } + + if err := r.updateOnce(); err != nil { + r.opts.logf("error refreshing data: %v", err) + } + + // Re-arm the timer after we're done; this is safe + // since the only way this loop woke up was by reading + // from timer.C + timer.Reset(r.opts.intervalWithJitter()) + } +} + +func (r *reloader[T]) updateOnce() error { + data, err := r.opts.Read(r.ctx) + if err != nil { + return fmt.Errorf("reading data: %w", err) + } + next, err := r.opts.Unmarshal(data) + if err != nil { + return fmt.Errorf("unmarshaling data: %w", err) + } + + oldValue := r.store.Swap(next) + if !reflect.DeepEqual(oldValue, next) { + r.opts.logf("stored new value: %+v", next) + } + return nil +} + +// FromJSONFile creates a ReloadOpts describing reloading a value of type T +// from the given JSON file on-disk. +func FromJSONFile[T any](path string) ReloadOpts[T] { + return ReloadOpts[T]{ + Read: func(_ context.Context) ([]byte, error) { + return os.ReadFile(path) + }, + Unmarshal: func(b []byte) (T, error) { + var ret, zero T + if err := json.Unmarshal(b, &ret); err != nil { + return zero, err + } + return ret, nil + }, + } +} diff --git a/util/reload/reload_test.go b/util/reload/reload_test.go new file mode 100644 index 000000000..f6a381686 --- /dev/null +++ b/util/reload/reload_test.go @@ -0,0 +1,171 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package reload + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "reflect" + "sync/atomic" + "testing" + "time" + + "tailscale.com/tstest" +) + +func TestReloader(t *testing.T) { + buf := []byte("hello world") + + ctx := context.Background() + r, err := newUnstarted[string](ctx, ReloadOpts[string]{ + Logf: t.Logf, + Read: func(context.Context) ([]byte, error) { + return buf, nil + }, + Unmarshal: func(b []byte) (string, error) { + return "The value is: " + string(b), nil + }, + }) + if err != nil { + t.Fatal(err) + } + + // We should have an initial value. + const wantInitial = "The value is: hello world" + if v := r.store.Load(); v != wantInitial { + t.Errorf("got initial value %q, want %q", v, wantInitial) + } + + // Reloading should result in a new value + buf = []byte("new value") + if err := r.updateOnce(); err != nil { + t.Fatal(err) + } + + const wantReload = "The value is: new value" + if v := r.store.Load(); v != wantReload { + t.Errorf("got reloaded value %q, want %q", v, wantReload) + } +} + +func TestReloader_InitialError(t *testing.T) { + fakeErr := errors.New("fake error") + + ctx := context.Background() + _, err := newUnstarted[string](ctx, ReloadOpts[string]{ + Logf: t.Logf, + Read: func(context.Context) ([]byte, error) { return nil, fakeErr }, + Unmarshal: func(b []byte) (string, error) { panic("unused because Read fails") }, + }) + if err == nil { + t.Fatal("expected non-nil error") + } + if !errors.Is(err, fakeErr) { + t.Errorf("wanted errors.Is(%v, fakeErr)=true", err) + } +} + +func TestReloader_ReloadError(t *testing.T) { + fakeErr := errors.New("fake error") + shouldError := false + + ctx := context.Background() + r, err := newUnstarted[string](ctx, ReloadOpts[string]{ + Logf: t.Logf, + Read: func(context.Context) ([]byte, error) { + return []byte("hello"), nil + }, + Unmarshal: func(b []byte) (string, error) { + if shouldError { + return "", fakeErr + } + return string(b), nil + }, + }) + if err != nil { + t.Fatal(err) + } + if got := r.store.Load(); got != "hello" { + t.Fatalf("got value %q, want \"hello\"", got) + } + + shouldError = true + + if err := r.updateOnce(); err == nil { + t.Errorf("expected error from updateOnce") + } + if got := r.store.Load(); got != "hello" { + t.Fatalf("got value %q, want \"hello\"", got) + } +} + +func TestReloader_Run(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var ncalls atomic.Int64 + load, err := New[string](ctx, ReloadOpts[string]{ + Logf: tstest.WhileTestRunningLogger(t), + Interval: 10 * time.Millisecond, + Read: func(context.Context) ([]byte, error) { + return []byte("hello"), nil + }, + Unmarshal: func(b []byte) (string, error) { + callNum := ncalls.Add(1) + if callNum == 3 { + cancel() + } + return fmt.Sprintf("call %d: %s", callNum, b), nil + }, + }) + if err != nil { + t.Fatal(err) + } + want := "call 1: hello" + if got := load(); got != want { + t.Fatalf("got value %q, want %q", got, want) + } + + // Wait for the periodic refresh to cancel our context + select { + case <-ctx.Done(): + case <-time.After(10 * time.Second): + t.Fatal("test timed out") + } + + // Depending on how goroutines get scheduled, we can either read call 2 + // (if we woke up before the run goroutine stores call 3), or call 3 + // (if we woke up after the run goroutine stores the next value). Check + // for both. + want1, want2 := "call 2: hello", "call 3: hello" + if got := load(); got != want1 && got != want2 { + t.Fatalf("got value %q, want %q or %q", got, want1, want2) + } +} + +func TestFromJSONFile(t *testing.T) { + type testStruct struct { + Value string + Number int + } + fpath := filepath.Join(t.TempDir(), "test.json") + if err := os.WriteFile(fpath, []byte(`{"Value": "hello", "Number": 1234}`), 0600); err != nil { + t.Fatal(err) + } + + ctx := context.Background() + r, err := newUnstarted(ctx, FromJSONFile[*testStruct](fpath)) + if err != nil { + t.Fatal(err) + } + + got := r.store.Load() + want := &testStruct{Value: "hello", Number: 1234} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %+v, want %+v", got, want) + } +}