449 lines
11 KiB
Go
449 lines
11 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package source
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
xmaps "golang.org/x/exp/maps"
|
|
"tailscale.com/util/mak"
|
|
"tailscale.com/util/set"
|
|
"tailscale.com/util/syspolicy/internal"
|
|
"tailscale.com/util/syspolicy/setting"
|
|
)
|
|
|
|
var (
|
|
_ Store = (*TestStore)(nil)
|
|
_ Lockable = (*TestStore)(nil)
|
|
_ Changeable = (*TestStore)(nil)
|
|
_ Expirable = (*TestStore)(nil)
|
|
)
|
|
|
|
// TestValueType constrains the value types a [TestStore] can hold:
// strings, string slices, booleans and unsigned 64-bit integers.
type TestValueType interface {
	string | []string | bool | uint64
}
|
|
|
|
// TestSetting is a policy setting in a [TestStore].
// When Error is non-nil it takes precedence: the store keeps the error
// and returns it from reads instead of Value.
type TestSetting[T TestValueType] struct {
	// Key is the setting's unique identifier.
	// The TestStore setters reject an empty Key.
	Key setting.Key
	// Error is the error to be returned by the [TestStore] when reading
	// a policy setting with the specified key.
	Error error
	// Value is the value to be returned by the [TestStore] when reading
	// a policy setting with the specified key.
	// It is only used if the Error is nil.
	Value T
}
|
|
|
|
// TestSettingOf returns a [TestSetting] representing a policy setting
|
|
// configured with the specified key and value.
|
|
func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] {
|
|
return TestSetting[T]{Key: key, Value: value}
|
|
}
|
|
|
|
// TestSettingWithError returns a [TestSetting] representing a policy setting
|
|
// with the specified key and error.
|
|
func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] {
|
|
return TestSetting[T]{Key: key, Error: err}
|
|
}
|
|
|
|
// testReadOperation describes a single policy setting read operation.
// It is the key of [TestStore]'s per-operation read counter map, so two
// reads of the same key with different value types are counted separately.
type testReadOperation struct {
	// Key is the setting's unique identifier.
	Key setting.Key
	// Type is the value type of a read operation: one of
	// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue].
	Type setting.Type
}
|
|
|
|
// TestExpectedReads describes how many read operations with the specified
// key and value type a test expects a [TestStore] to have served.
type TestExpectedReads struct {
	// Key is the setting's unique identifier.
	Key setting.Key
	// Type is the value type of a read operation: one of
	// [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue].
	Type setting.Type
	// NumTimes is how many times a setting with the specified key and type should have been read.
	NumTimes int
}
|
|
|
|
func (r TestExpectedReads) operation() testReadOperation {
|
|
return testReadOperation{r.Key, r.Type}
|
|
}
|
|
|
|
// TestStore is a [Store] that can be used in tests.
// It also implements [Lockable], [Changeable] and [Expirable], and counts
// how many times each setting was read so tests can assert on it.
type TestStore struct {
	// tb is used to report coding errors detected by the store.
	tb internal.TB

	// done is closed by [TestStore.Close] and returned by [TestStore.Done].
	done chan struct{}

	// storeLock's RLock is exposed via [TestStore.Lock]/[TestStore.Unlock];
	// the mutating methods take its write lock.
	storeLock sync.RWMutex
	// storeLockCount counts Lock calls not yet matched by Unlock,
	// so unbalanced Unlock calls can be detected.
	storeLockCount atomic.Int32

	// mu guards suspendCount, mr, mw, cbs and closed.
	mu           sync.RWMutex
	suspendCount int                 // change callbacks are suspended while > 0
	mr, mw       map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended.
	cbs          set.HandleSet[func()]
	closed       bool

	// readsMu guards reads.
	readsMu sync.Mutex
	reads   map[testReadOperation]int // how many times a policy setting was read
}
|
|
|
|
// NewTestStore returns a new [TestStore].
|
|
// The tb will be used to report coding errors detected by the [TestStore].
|
|
func NewTestStore(tb internal.TB) *TestStore {
|
|
m := make(map[setting.Key]any)
|
|
store := &TestStore{
|
|
tb: tb,
|
|
done: make(chan struct{}),
|
|
mr: m,
|
|
mw: m,
|
|
}
|
|
tb.Cleanup(store.Close)
|
|
return store
|
|
}
|
|
|
|
// NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans],
|
|
// [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists].
|
|
func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore {
|
|
store := NewTestStore(tb)
|
|
switch settings := any(settings).(type) {
|
|
case []TestSetting[bool]:
|
|
store.SetBooleans(settings...)
|
|
case []TestSetting[uint64]:
|
|
store.SetUInt64s(settings...)
|
|
case []TestSetting[string]:
|
|
store.SetStrings(settings...)
|
|
case []TestSetting[[]string]:
|
|
store.SetStringLists(settings...)
|
|
}
|
|
return store
|
|
}
|
|
|
|
// Lock implements [Lockable].
// It takes a read lock on storeLock, which keeps the Set*, Delete, Clear
// and Resume methods (they take the write lock) from mutating the store
// until the matching [TestStore.Unlock]. It never returns an error.
func (s *TestStore) Lock() error {
	s.storeLock.RLock()
	s.storeLockCount.Add(1) // lets Unlock detect unbalanced calls
	return nil
}
|
|
|
|
// Unlock implements [Lockable].
|
|
func (s *TestStore) Unlock() {
|
|
if s.storeLockCount.Add(-1) < 0 {
|
|
s.tb.Fatal("negative storeLockCount")
|
|
}
|
|
s.storeLock.RUnlock()
|
|
}
|
|
|
|
// RegisterChangeCallback implements [Changeable].
|
|
func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), err error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
handle := s.cbs.Add(callback)
|
|
return func() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
delete(s.cbs, handle)
|
|
}, nil
|
|
}
|
|
|
|
// ReadString implements [Store].
|
|
func (s *TestStore) ReadString(key setting.Key) (string, error) {
|
|
defer s.recordRead(key, setting.StringValue)
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
v, ok := s.mr[key]
|
|
if !ok {
|
|
return "", setting.ErrNotConfigured
|
|
}
|
|
if err, ok := v.(error); ok {
|
|
return "", err
|
|
}
|
|
str, ok := v.(string)
|
|
if !ok {
|
|
return "", fmt.Errorf("%w in ReadString: got %T", setting.ErrTypeMismatch, v)
|
|
}
|
|
return str, nil
|
|
}
|
|
|
|
// ReadUInt64 implements [Store].
|
|
func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) {
|
|
defer s.recordRead(key, setting.IntegerValue)
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
v, ok := s.mr[key]
|
|
if !ok {
|
|
return 0, setting.ErrNotConfigured
|
|
}
|
|
if err, ok := v.(error); ok {
|
|
return 0, err
|
|
}
|
|
u64, ok := v.(uint64)
|
|
if !ok {
|
|
return 0, fmt.Errorf("%w in ReadUInt64: got %T", setting.ErrTypeMismatch, v)
|
|
}
|
|
return u64, nil
|
|
}
|
|
|
|
// ReadBoolean implements [Store].
|
|
func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) {
|
|
defer s.recordRead(key, setting.BooleanValue)
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
v, ok := s.mr[key]
|
|
if !ok {
|
|
return false, setting.ErrNotConfigured
|
|
}
|
|
if err, ok := v.(error); ok {
|
|
return false, err
|
|
}
|
|
b, ok := v.(bool)
|
|
if !ok {
|
|
return false, fmt.Errorf("%w in ReadBoolean: got %T", setting.ErrTypeMismatch, v)
|
|
}
|
|
return b, nil
|
|
}
|
|
|
|
// ReadStringArray implements [Store].
|
|
func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) {
|
|
defer s.recordRead(key, setting.StringListValue)
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
v, ok := s.mr[key]
|
|
if !ok {
|
|
return nil, setting.ErrNotConfigured
|
|
}
|
|
if err, ok := v.(error); ok {
|
|
return nil, err
|
|
}
|
|
slice, ok := v.([]string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("%w in ReadStringArray: got %T", setting.ErrTypeMismatch, v)
|
|
}
|
|
return slice, nil
|
|
}
|
|
|
|
func (s *TestStore) recordRead(key setting.Key, typ setting.Type) {
|
|
s.readsMu.Lock()
|
|
op := testReadOperation{key, typ}
|
|
num := s.reads[op]
|
|
num++
|
|
mak.Set(&s.reads, op, num)
|
|
s.readsMu.Unlock()
|
|
}
|
|
|
|
func (s *TestStore) ResetCounters() {
|
|
s.readsMu.Lock()
|
|
clear(s.reads)
|
|
s.readsMu.Unlock()
|
|
}
|
|
|
|
// ReadsMustEqual fails the test if the actual reads differ from the specified reads.
// It reports (via tb.Errorf, without stopping the test) every expected read
// whose count does not match, and every recorded read that is not listed.
func (s *TestStore) ReadsMustEqual(reads ...TestExpectedReads) {
	s.tb.Helper()
	s.readsMu.Lock()
	defer s.readsMu.Unlock()
	s.readsMustContainLocked(reads...)
	s.readMustNoExtraLocked(reads...)
}
|
|
|
|
// ReadsMustContain fails the test if the specified reads have not been made,
// or have been made a different number of times. It permits other values to be
// read in addition to the ones being tested. Mismatches are reported via
// tb.Errorf, so the test keeps running.
func (s *TestStore) ReadsMustContain(reads ...TestExpectedReads) {
	s.tb.Helper()
	s.readsMu.Lock()
	defer s.readsMu.Unlock()
	s.readsMustContainLocked(reads...)
}
|
|
|
|
func (s *TestStore) readsMustContainLocked(reads ...TestExpectedReads) {
|
|
s.tb.Helper()
|
|
for _, r := range reads {
|
|
if numTimes := s.reads[r.operation()]; numTimes != r.NumTimes {
|
|
s.tb.Errorf("%q (%v) reads: got %v, want %v", r.Key, r.Type, numTimes, r.NumTimes)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *TestStore) readMustNoExtraLocked(reads ...TestExpectedReads) {
|
|
s.tb.Helper()
|
|
rs := make(set.Set[testReadOperation])
|
|
for i := range reads {
|
|
rs.Add(reads[i].operation())
|
|
}
|
|
for ro, num := range s.reads {
|
|
if !rs.Contains(ro) {
|
|
s.tb.Errorf("%q (%v) reads: got %v, want 0", ro.Key, ro.Type, num)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Suspend suspends the store, batching changes and notifications
|
|
// until [TestStore.Resume] is called the same number of times as Suspend.
|
|
func (s *TestStore) Suspend() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.suspendCount++; s.suspendCount == 1 {
|
|
s.mw = xmaps.Clone(s.mr)
|
|
}
|
|
}
|
|
|
|
// Resume resumes the store, applying the changes and invoking
|
|
// the change callbacks.
|
|
func (s *TestStore) Resume() {
|
|
s.storeLock.Lock()
|
|
s.mu.Lock()
|
|
switch s.suspendCount--; {
|
|
case s.suspendCount == 0:
|
|
s.mr = s.mw
|
|
s.mu.Unlock()
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
case s.suspendCount < 0:
|
|
s.tb.Fatal("negative suspendCount")
|
|
default:
|
|
s.mu.Unlock()
|
|
s.storeLock.Unlock()
|
|
}
|
|
}
|
|
|
|
// SetBooleans sets the specified boolean settings in s.
|
|
func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) {
|
|
s.storeLock.Lock()
|
|
for _, setting := range settings {
|
|
if setting.Key == "" {
|
|
s.tb.Fatal("empty keys disallowed")
|
|
}
|
|
s.mu.Lock()
|
|
if setting.Error != nil {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
|
} else {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
// SetUInt64s sets the specified integer settings in s.
|
|
func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) {
|
|
s.storeLock.Lock()
|
|
for _, setting := range settings {
|
|
if setting.Key == "" {
|
|
s.tb.Fatal("empty keys disallowed")
|
|
}
|
|
s.mu.Lock()
|
|
if setting.Error != nil {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
|
} else {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
// SetStrings sets the specified string settings in s.
|
|
func (s *TestStore) SetStrings(settings ...TestSetting[string]) {
|
|
s.storeLock.Lock()
|
|
for _, setting := range settings {
|
|
if setting.Key == "" {
|
|
s.tb.Fatal("empty keys disallowed")
|
|
}
|
|
s.mu.Lock()
|
|
if setting.Error != nil {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
|
} else {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
// SetStrings sets the specified string list settings in s.
|
|
func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) {
|
|
s.storeLock.Lock()
|
|
for _, setting := range settings {
|
|
if setting.Key == "" {
|
|
s.tb.Fatal("empty keys disallowed")
|
|
}
|
|
s.mu.Lock()
|
|
if setting.Error != nil {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Error))
|
|
} else {
|
|
mak.Set(&s.mw, setting.Key, any(setting.Value))
|
|
}
|
|
s.mu.Unlock()
|
|
}
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
// Delete deletes the specified settings from s.
|
|
func (s *TestStore) Delete(keys ...setting.Key) {
|
|
s.storeLock.Lock()
|
|
for _, key := range keys {
|
|
s.mu.Lock()
|
|
delete(s.mw, key)
|
|
s.mu.Unlock()
|
|
}
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
// Clear deletes all settings from s.
|
|
func (s *TestStore) Clear() {
|
|
s.storeLock.Lock()
|
|
s.mu.Lock()
|
|
clear(s.mw)
|
|
s.mu.Unlock()
|
|
s.storeLock.Unlock()
|
|
s.NotifyPolicyChanged()
|
|
}
|
|
|
|
func (s *TestStore) NotifyPolicyChanged() {
|
|
s.mu.RLock()
|
|
if s.suspendCount != 0 {
|
|
s.mu.RUnlock()
|
|
return
|
|
}
|
|
cbs := xmaps.Values(s.cbs)
|
|
s.mu.RUnlock()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(cbs))
|
|
for _, cb := range cbs {
|
|
go func() {
|
|
defer wg.Done()
|
|
cb()
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
// Close closes s, notifying its users that it has expired.
|
|
func (s *TestStore) Close() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if !s.closed {
|
|
close(s.done)
|
|
s.closed = true
|
|
}
|
|
}
|
|
|
|
// Done implements [Expirable].
// The returned channel is closed by [TestStore.Close], which NewTestStore
// registers as a tb cleanup.
func (s *TestStore) Done() <-chan struct{} {
	return s.done
}
|