399 lines
11 KiB
Go
399 lines
11 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package source
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/sys/windows"
|
|
"golang.org/x/sys/windows/registry"
|
|
"tailscale.com/tstest"
|
|
"tailscale.com/util/cibuild"
|
|
"tailscale.com/util/mak"
|
|
"tailscale.com/util/syspolicy/setting"
|
|
"tailscale.com/util/winutil"
|
|
"tailscale.com/util/winutil/gp"
|
|
)
|
|
|
|
// subkeyStrings is a test type indicating that a string slice should be written
|
|
// to the registry as multiple REG_SZ values under the setting's key,
|
|
// rather than as a single REG_MULTI_SZ value under the group key.
|
|
// This is the same format as ADMX use for string lists.
|
|
type subkeyStrings []string
|
|
|
|
type testPolicyValue struct {
|
|
name setting.Key
|
|
value any
|
|
}
|
|
|
|
func TestLockUnlockPolicyStore(t *testing.T) {
|
|
// Make sure we don't leak goroutines
|
|
tstest.ResourceCheck(t)
|
|
|
|
store, err := NewMachinePlatformPolicyStore()
|
|
if err != nil {
|
|
t.Fatalf("NewMachinePolicyStore failed: %v", err)
|
|
}
|
|
|
|
t.Run("One-Goroutine", func(t *testing.T) {
|
|
if err := store.Lock(); err != nil {
|
|
t.Errorf("store.Lock(): got %v; want nil", err)
|
|
return
|
|
}
|
|
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
|
|
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
|
|
}
|
|
store.Unlock()
|
|
})
|
|
|
|
// Lock the store N times from different goroutines.
|
|
const N = 100
|
|
var unlocked atomic.Int32
|
|
t.Run("N-Goroutines", func(t *testing.T) {
|
|
var wg sync.WaitGroup
|
|
wg.Add(N)
|
|
for range N {
|
|
go func() {
|
|
if err := store.Lock(); err != nil {
|
|
t.Errorf("store.Lock(): got %v; want nil", err)
|
|
return
|
|
}
|
|
if v, err := store.ReadString("NonExistingPolicySetting"); err == nil || !errors.Is(err, setting.ErrNotConfigured) {
|
|
t.Errorf(`ReadString: got %v, %v; want "", %v`, v, err, setting.ErrNotConfigured)
|
|
}
|
|
wg.Done()
|
|
time.Sleep(10 * time.Millisecond)
|
|
unlocked.Add(1)
|
|
store.Unlock()
|
|
}()
|
|
}
|
|
|
|
// Wait until the store is locked N times.
|
|
wg.Wait()
|
|
})
|
|
|
|
// Close the store. The call should wait for all held locks to be released.
|
|
if err := store.Close(); err != nil {
|
|
t.Fatalf("(*PolicyStore).Close failed: %v", err)
|
|
}
|
|
if locked := unlocked.Load(); locked != N {
|
|
t.Errorf("locked.Load(): got %v; want %v", locked, N)
|
|
}
|
|
|
|
// Any further attempts to lock it should fail.
|
|
if err = store.Lock(); err == nil || !errors.Is(err, ErrStoreClosed) {
|
|
t.Errorf("store.Lock(): got %v; want %v", err, ErrStoreClosed)
|
|
}
|
|
}
|
|
|
|
func TestReadPolicyStore(t *testing.T) {
|
|
if !winutil.IsCurrentProcessElevated() {
|
|
t.Skipf("test requires running as elevated user")
|
|
}
|
|
tests := []struct {
|
|
name setting.Key
|
|
newValue any
|
|
legacyValue any
|
|
want any
|
|
}{
|
|
{name: "LegacyPolicy", legacyValue: "LegacyValue", want: "LegacyValue"},
|
|
{name: "StringPolicy", legacyValue: "LegacyValue", newValue: "Value", want: "Value"},
|
|
{name: "StringPolicy_Empty", legacyValue: "LegacyValue", newValue: "", want: ""},
|
|
{name: "BoolPolicy_True", newValue: true, want: true},
|
|
{name: "BoolPolicy_False", newValue: false, want: false},
|
|
{name: "UIntPolicy_1", newValue: uint32(10), want: uint64(10)}, // uint32 values should be returned as uint64
|
|
{name: "UIntPolicy_2", newValue: uint64(1 << 37), want: uint64(1 << 37)},
|
|
{name: "StringListPolicy", newValue: []string{"Value1", "Value2"}, want: []string{"Value1", "Value2"}},
|
|
{name: "StringListPolicy_Empty", newValue: []string{}, want: []string{}},
|
|
{name: "StringListPolicy_SubKey", newValue: subkeyStrings{"Value1", "Value2"}, want: []string{"Value1", "Value2"}},
|
|
{name: "StringListPolicy_SubKey_Empty", newValue: subkeyStrings{}, want: []string{}},
|
|
}
|
|
|
|
runTests := func(t *testing.T, userStore bool, token windows.Token) {
|
|
var hive registry.Key
|
|
if userStore {
|
|
hive = registry.CURRENT_USER
|
|
} else {
|
|
hive = registry.LOCAL_MACHINE
|
|
}
|
|
|
|
// Write policy values to the registry.
|
|
newValues := make([]testPolicyValue, 0, len(tests))
|
|
for _, tt := range tests {
|
|
if tt.newValue != nil {
|
|
newValues = append(newValues, testPolicyValue{name: tt.name, value: tt.newValue})
|
|
}
|
|
}
|
|
policiesKeyName := softwareKeyName + `\` + tsPoliciesSubkey
|
|
cleanup, err := createTestPolicyValues(hive, policiesKeyName, newValues)
|
|
if err != nil {
|
|
t.Fatalf("createTestPolicyValues failed: %v", err)
|
|
}
|
|
t.Cleanup(cleanup)
|
|
|
|
// Write legacy policy values to the registry.
|
|
legacyValues := make([]testPolicyValue, 0, len(tests))
|
|
for _, tt := range tests {
|
|
if tt.legacyValue != nil {
|
|
legacyValues = append(legacyValues, testPolicyValue{name: tt.name, value: tt.legacyValue})
|
|
}
|
|
}
|
|
legacyKeyName := softwareKeyName + `\` + tsIPNSubkey
|
|
cleanup, err = createTestPolicyValues(hive, legacyKeyName, legacyValues)
|
|
if err != nil {
|
|
t.Fatalf("createTestPolicyValues failed: %v", err)
|
|
}
|
|
t.Cleanup(cleanup)
|
|
|
|
var store *PlatformPolicyStore
|
|
if userStore {
|
|
store, err = NewUserPlatformPolicyStore(token)
|
|
} else {
|
|
store, err = NewMachinePlatformPolicyStore()
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("NewXPolicyStore failed: %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
if err := store.Close(); err != nil {
|
|
t.Errorf("(*PolicyStore).Close failed: %v", err)
|
|
}
|
|
})
|
|
|
|
// testReadValues checks that [PolicyStore] returns the same values we wrote directly to the registry.
|
|
testReadValues := func(t *testing.T, withLocks bool) {
|
|
for _, tt := range tests {
|
|
t.Run(string(tt.name), func(t *testing.T) {
|
|
if userStore && tt.newValue == nil {
|
|
t.Skip("there is no legacy policies for users")
|
|
}
|
|
|
|
t.Parallel()
|
|
|
|
if withLocks {
|
|
if err := store.Lock(); err != nil {
|
|
t.Errorf("failed to acquire the lock: %v", err)
|
|
}
|
|
defer store.Unlock()
|
|
}
|
|
|
|
var got any
|
|
var err error
|
|
switch tt.want.(type) {
|
|
case string:
|
|
got, err = store.ReadString(tt.name)
|
|
case uint64:
|
|
got, err = store.ReadUInt64(tt.name)
|
|
case bool:
|
|
got, err = store.ReadBoolean(tt.name)
|
|
case []string:
|
|
got, err = store.ReadStringArray(tt.name)
|
|
}
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("got %v; want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
t.Run("NoLock", func(t *testing.T) {
|
|
testReadValues(t, false)
|
|
})
|
|
|
|
t.Run("WithLock", func(t *testing.T) {
|
|
testReadValues(t, true)
|
|
})
|
|
}
|
|
|
|
t.Run("MachineStore", func(t *testing.T) {
|
|
runTests(t, false, 0)
|
|
})
|
|
|
|
t.Run("CurrentUserStore", func(t *testing.T) {
|
|
runTests(t, true, 0)
|
|
})
|
|
|
|
t.Run("UserStoreWithToken", func(t *testing.T) {
|
|
var token windows.Token
|
|
if err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token); err != nil {
|
|
t.Fatalf("OpenProcessToken: %v", err)
|
|
}
|
|
defer token.Close()
|
|
runTests(t, true, token)
|
|
})
|
|
}
|
|
|
|
func TestPolicyStoreChangeNotifications(t *testing.T) {
|
|
if cibuild.On() {
|
|
t.Skipf("test requires running on a real Windows environment")
|
|
}
|
|
store, err := NewMachinePlatformPolicyStore()
|
|
if err != nil {
|
|
t.Fatalf("NewMachinePolicyStore failed: %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
if err := store.Close(); err != nil {
|
|
t.Errorf("(*PolicyStore).Close failed: %v", err)
|
|
}
|
|
})
|
|
|
|
done := make(chan struct{})
|
|
unregister, err := store.RegisterChangeCallback(func() { close(done) })
|
|
if err != nil {
|
|
t.Fatalf("RegisterChangeCallback failed: %v", err)
|
|
}
|
|
t.Cleanup(unregister)
|
|
|
|
// RefreshMachinePolicy is a non-blocking call.
|
|
if err := gp.RefreshMachinePolicy(true); err != nil {
|
|
t.Fatalf("RefreshMachinePolicy failed: %v", err)
|
|
}
|
|
|
|
// We should receive a policy change notification when
|
|
// the Group Policy service completes policy processing.
|
|
// Otherwise, the test will eventually time out.
|
|
<-done
|
|
}
|
|
|
|
func TestSplitSettingKey(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
key setting.Key
|
|
wantPath string
|
|
wantValue string
|
|
}{
|
|
{
|
|
name: "empty",
|
|
key: "",
|
|
wantPath: ``,
|
|
wantValue: "",
|
|
},
|
|
{
|
|
name: "explicit-empty-path",
|
|
key: "/ValueName",
|
|
wantPath: ``,
|
|
wantValue: "ValueName",
|
|
},
|
|
{
|
|
name: "empty-value",
|
|
key: "Root/Sub/",
|
|
wantPath: `Root\Sub`,
|
|
wantValue: "",
|
|
},
|
|
{
|
|
name: "with-path",
|
|
key: "Root/Sub/ValueName",
|
|
wantPath: `Root\Sub`,
|
|
wantValue: "ValueName",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
gotPath, gotValue := splitSettingKey(tt.key)
|
|
if gotPath != tt.wantPath {
|
|
t.Errorf("Path: got %q, want %q", gotPath, tt.wantPath)
|
|
}
|
|
if gotValue != tt.wantValue {
|
|
t.Errorf("Value: got %q, want %q", gotValue, tt.wantPath)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func createTestPolicyValues(hive registry.Key, keyName string, values []testPolicyValue) (cleanup func(), err error) {
|
|
key, existing, err := registry.CreateKey(hive, keyName, registry.ALL_ACCESS)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var valuesToDelete map[string][]string
|
|
doCleanup := func() {
|
|
for path, values := range valuesToDelete {
|
|
if len(values) == 0 {
|
|
registry.DeleteKey(key, path)
|
|
continue
|
|
}
|
|
key, err := registry.OpenKey(key, path, windows.KEY_ALL_ACCESS)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
defer key.Close()
|
|
for _, value := range values {
|
|
key.DeleteValue(value)
|
|
}
|
|
}
|
|
|
|
key.Close()
|
|
if !existing {
|
|
registry.DeleteKey(hive, keyName)
|
|
}
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
doCleanup()
|
|
}
|
|
}()
|
|
|
|
for _, v := range values {
|
|
key, existing := key, existing
|
|
path, valueName := splitSettingKey(v.name)
|
|
if path != "" {
|
|
if key, existing, err = registry.CreateKey(key, valueName, windows.KEY_ALL_ACCESS); err != nil {
|
|
return nil, err
|
|
}
|
|
defer key.Close()
|
|
}
|
|
if values, ok := valuesToDelete[path]; len(values) > 0 || (!ok && existing) {
|
|
values = append(values, valueName)
|
|
mak.Set(&valuesToDelete, path, values)
|
|
} else if !ok {
|
|
mak.Set(&valuesToDelete, path, nil)
|
|
}
|
|
|
|
switch value := v.value.(type) {
|
|
case string:
|
|
err = key.SetStringValue(valueName, value)
|
|
case uint32:
|
|
err = key.SetDWordValue(valueName, value)
|
|
case uint64:
|
|
err = key.SetQWordValue(valueName, value)
|
|
case bool:
|
|
if value {
|
|
err = key.SetDWordValue(valueName, 1)
|
|
} else {
|
|
err = key.SetDWordValue(valueName, 0)
|
|
}
|
|
case []string:
|
|
err = key.SetStringsValue(valueName, value)
|
|
case subkeyStrings:
|
|
key, _, err := registry.CreateKey(key, valueName, windows.KEY_ALL_ACCESS)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer key.Close()
|
|
mak.Set(&valuesToDelete, strings.Trim(path+`\`+valueName, `\`), nil)
|
|
for i, value := range value {
|
|
if err := key.SetStringValue(strconv.Itoa(i), value); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
default:
|
|
err = fmt.Errorf("unsupported value: %v (%T), name: %q", value, value, v.name)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return doCleanup, nil
|
|
}
|