tailscale/tstest/tstest.go

96 lines
2.3 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package tstest provides utilities for use in unit tests.
package tstest
import (
"context"
"os"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"tailscale.com/envknob"
"tailscale.com/logtail/backoff"
"tailscale.com/types/logger"
"tailscale.com/util/cibuild"
)
// Replace stores val in *target for the remainder of the test.
// A cleanup registered on t writes the previous value back when the test
// (and all its subtests) finish.
//
// Replace fails the test immediately if target is nil.
func Replace[T any](t testing.TB, target *T, val T) {
	t.Helper()
	if target == nil {
		t.Fatalf("Replace: nil pointer")
		panic("unreachable") // pacify staticcheck
	}

	// Capture the current value before overwriting it so the cleanup
	// can put it back.
	prev := *target
	*target = val
	t.Cleanup(func() { *target = prev })
}
// WaitFor retries try for up to maxWait.
// It returns nil once try returns nil the first time.
// If maxWait passes without success, it returns try's last error.
//
// try is always called at least once, even if maxWait is zero or negative,
// so a nil result always means that try actually succeeded.
func WaitFor(maxWait time.Duration, try func() error) error {
	// NOTE(review): the third argument is presumably the cap on the delay
	// between attempts; confirm against backoff.NewBackoff.
	bo := backoff.NewBackoff("wait-for", logger.Discard, maxWait/4)
	deadline := time.Now().Add(maxWait)
	for {
		err := try()
		if err == nil {
			return nil
		}
		bo.BackOff(context.Background(), err)
		// The deadline is checked after the attempt rather than before it,
		// so that the first attempt is never skipped.
		if !time.Now().Before(deadline) {
			return err
		}
	}
}
// testNum counts the calls to Shard made in this process while a usable
// TS_TEST_SHARD value was set. Shard uses it to hand out shards round-robin.
var testNum atomic.Int32

// Shard skips t unless it belongs to the shard selected by the TS_TEST_SHARD
// environment variable.
//
// TS_TEST_SHARD has the form "n/m", where m is the total number of shards and
// n (1 through m) is the shard this process should run. Calls to Shard are
// numbered in the order they happen within the process, and call number i
// (counting from zero) belongs to shard (i mod m) + 1.
// That is, to run with 4 shards, set TS_TEST_SHARD=1/4, ..., TS_TEST_SHARD=4/4
// for the four jobs.
//
// If TS_TEST_SHARD is unset, cannot be parsed, or names a shard outside 1..m,
// Shard does nothing and t runs.
func Shard(t testing.TB) {
	t.Helper() // attribute the skip message to the calling test
	e := os.Getenv("TS_TEST_SHARD")
	a, b, ok := strings.Cut(e, "/")
	if !ok {
		return
	}
	// A bit size of 32 keeps both numbers within int32, the type of testNum.
	wantShard, err := strconv.ParseInt(a, 10, 32)
	if err != nil {
		return
	}
	shards, err := strconv.ParseInt(b, 10, 32)
	if err != nil {
		return
	}
	if wantShard < 1 || shards < 1 || wantShard > shards {
		// No call could ever match such a shard, so honoring it would
		// silently skip every test. Treat it like any other malformed value.
		return
	}
	shard := (testNum.Add(1)-1)%int32(shards) + 1
	if shard != int32(wantShard) {
		t.Skipf("skipping shard %d/%d (process has TS_TEST_SHARD=%q)", shard, shards, e)
	}
}
// SkipOnUnshardedCI skips t if we're in CI and the TS_TEST_SHARD
// environment variable isn't set.
//
// A TS_TEST_SHARD that is set to the empty string counts as not set.
func SkipOnUnshardedCI(t testing.TB) {
	t.Helper() // attribute the skip message to the calling test
	if cibuild.On() && os.Getenv("TS_TEST_SHARD") == "" {
		t.Skip("skipping on CI without TS_TEST_SHARD")
	}
}
// serializeParallel is the accessor registered for the TS_SERIAL_TESTS
// boolean envknob.
var serializeParallel = envknob.RegisterBool("TS_SERIAL_TESTS")

// Parallel marks t as a parallel test by calling t.Parallel.
// When TS_SERIAL_TESTS is set true it does nothing instead, leaving t to
// run serially.
func Parallel(t *testing.T) {
	if serializeParallel() {
		return
	}
	t.Parallel()
}