172 lines
4.0 KiB
Go
172 lines
4.0 KiB
Go
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||
|
|
||
|
package reload
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"os"
|
||
|
"path/filepath"
|
||
|
"reflect"
|
||
|
"sync/atomic"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"tailscale.com/tstest"
|
||
|
)
|
||
|
|
||
|
func TestReloader(t *testing.T) {
|
||
|
buf := []byte("hello world")
|
||
|
|
||
|
ctx := context.Background()
|
||
|
r, err := newUnstarted[string](ctx, ReloadOpts[string]{
|
||
|
Logf: t.Logf,
|
||
|
Read: func(context.Context) ([]byte, error) {
|
||
|
return buf, nil
|
||
|
},
|
||
|
Unmarshal: func(b []byte) (string, error) {
|
||
|
return "The value is: " + string(b), nil
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
// We should have an initial value.
|
||
|
const wantInitial = "The value is: hello world"
|
||
|
if v := r.store.Load(); v != wantInitial {
|
||
|
t.Errorf("got initial value %q, want %q", v, wantInitial)
|
||
|
}
|
||
|
|
||
|
// Reloading should result in a new value
|
||
|
buf = []byte("new value")
|
||
|
if err := r.updateOnce(); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
const wantReload = "The value is: new value"
|
||
|
if v := r.store.Load(); v != wantReload {
|
||
|
t.Errorf("got reloaded value %q, want %q", v, wantReload)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestReloader_InitialError(t *testing.T) {
|
||
|
fakeErr := errors.New("fake error")
|
||
|
|
||
|
ctx := context.Background()
|
||
|
_, err := newUnstarted[string](ctx, ReloadOpts[string]{
|
||
|
Logf: t.Logf,
|
||
|
Read: func(context.Context) ([]byte, error) { return nil, fakeErr },
|
||
|
Unmarshal: func(b []byte) (string, error) { panic("unused because Read fails") },
|
||
|
})
|
||
|
if err == nil {
|
||
|
t.Fatal("expected non-nil error")
|
||
|
}
|
||
|
if !errors.Is(err, fakeErr) {
|
||
|
t.Errorf("wanted errors.Is(%v, fakeErr)=true", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestReloader_ReloadError(t *testing.T) {
|
||
|
fakeErr := errors.New("fake error")
|
||
|
shouldError := false
|
||
|
|
||
|
ctx := context.Background()
|
||
|
r, err := newUnstarted[string](ctx, ReloadOpts[string]{
|
||
|
Logf: t.Logf,
|
||
|
Read: func(context.Context) ([]byte, error) {
|
||
|
return []byte("hello"), nil
|
||
|
},
|
||
|
Unmarshal: func(b []byte) (string, error) {
|
||
|
if shouldError {
|
||
|
return "", fakeErr
|
||
|
}
|
||
|
return string(b), nil
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
if got := r.store.Load(); got != "hello" {
|
||
|
t.Fatalf("got value %q, want \"hello\"", got)
|
||
|
}
|
||
|
|
||
|
shouldError = true
|
||
|
|
||
|
if err := r.updateOnce(); err == nil {
|
||
|
t.Errorf("expected error from updateOnce")
|
||
|
}
|
||
|
if got := r.store.Load(); got != "hello" {
|
||
|
t.Fatalf("got value %q, want \"hello\"", got)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestReloader_Run(t *testing.T) {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
defer cancel()
|
||
|
|
||
|
var ncalls atomic.Int64
|
||
|
load, err := New[string](ctx, ReloadOpts[string]{
|
||
|
Logf: tstest.WhileTestRunningLogger(t),
|
||
|
Interval: 10 * time.Millisecond,
|
||
|
Read: func(context.Context) ([]byte, error) {
|
||
|
return []byte("hello"), nil
|
||
|
},
|
||
|
Unmarshal: func(b []byte) (string, error) {
|
||
|
callNum := ncalls.Add(1)
|
||
|
if callNum == 3 {
|
||
|
cancel()
|
||
|
}
|
||
|
return fmt.Sprintf("call %d: %s", callNum, b), nil
|
||
|
},
|
||
|
})
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
want := "call 1: hello"
|
||
|
if got := load(); got != want {
|
||
|
t.Fatalf("got value %q, want %q", got, want)
|
||
|
}
|
||
|
|
||
|
// Wait for the periodic refresh to cancel our context
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
case <-time.After(10 * time.Second):
|
||
|
t.Fatal("test timed out")
|
||
|
}
|
||
|
|
||
|
// Depending on how goroutines get scheduled, we can either read call 2
|
||
|
// (if we woke up before the run goroutine stores call 3), or call 3
|
||
|
// (if we woke up after the run goroutine stores the next value). Check
|
||
|
// for both.
|
||
|
want1, want2 := "call 2: hello", "call 3: hello"
|
||
|
if got := load(); got != want1 && got != want2 {
|
||
|
t.Fatalf("got value %q, want %q or %q", got, want1, want2)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestFromJSONFile(t *testing.T) {
|
||
|
type testStruct struct {
|
||
|
Value string
|
||
|
Number int
|
||
|
}
|
||
|
fpath := filepath.Join(t.TempDir(), "test.json")
|
||
|
if err := os.WriteFile(fpath, []byte(`{"Value": "hello", "Number": 1234}`), 0600); err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
ctx := context.Background()
|
||
|
r, err := newUnstarted(ctx, FromJSONFile[*testStruct](fpath))
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
|
||
|
got := r.store.Load()
|
||
|
want := &testStruct{Value: "hello", Number: 1234}
|
||
|
if !reflect.DeepEqual(got, want) {
|
||
|
t.Errorf("got %+v, want %+v", got, want)
|
||
|
}
|
||
|
}
|