util/syspolicy: add caching handler (#10288)
Fixes tailscale/corp#15850 Co-authored-by: Adrian Dewhurst <adrian@tailscale.com> Signed-off-by: Claire Wang <claire@tailscale.com>
This commit is contained in:
parent
719ee4415e
commit
b8a2aedccd
|
@ -0,0 +1,98 @@
|
|||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested
|
||||
// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached,
|
||||
// otherwise the actual error is returned and the next read for that key will retry using the handler.
|
||||
type CachingHandler struct {
|
||||
mu sync.Mutex
|
||||
strings map[string]string
|
||||
uint64s map[string]uint64
|
||||
bools map[string]bool
|
||||
notFound map[string]bool
|
||||
handler Handler
|
||||
}
|
||||
|
||||
// NewCachingHandler creates a CachingHandler given a handler.
|
||||
func NewCachingHandler(handler Handler) *CachingHandler {
|
||||
return &CachingHandler{
|
||||
handler: handler,
|
||||
strings: make(map[string]string),
|
||||
uint64s: make(map[string]uint64),
|
||||
bools: make(map[string]bool),
|
||||
notFound: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// ReadString reads the policy settings value string given the key.
|
||||
// ReadString first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadString(key string) (string, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.strings[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return "", ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadString(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return "", err
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ch.strings[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadUInt64 reads the policy settings uint64 value given the key.
|
||||
// ReadUInt64 first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.uint64s[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return 0, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadUInt64(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return 0, err
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ch.uint64s[key] = val
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ReadBoolean reads the policy settings boolean value given the key.
|
||||
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadBoolean(key string) (bool, error) {
|
||||
ch.mu.Lock()
|
||||
defer ch.mu.Unlock()
|
||||
if val, ok := ch.bools[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
if notFound := ch.notFound[key]; notFound {
|
||||
return false, ErrNoSuchKey
|
||||
}
|
||||
val, err := ch.handler.ReadBoolean(key)
|
||||
if errors.Is(err, ErrNoSuchKey) {
|
||||
ch.notFound[key] = true
|
||||
return false, err
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
ch.bools[key] = val
|
||||
return val, nil
|
||||
}
|
|
@ -0,0 +1,262 @@
|
|||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHandlerReadString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue string
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue string
|
||||
wantErr error
|
||||
strings map[string]string
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
strings: map[string]string{"test": "foo"},
|
||||
wantValue: "foo",
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: "foo",
|
||||
wantValue: "foo",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
s: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.strings != nil {
|
||||
cache.strings = tt.strings
|
||||
}
|
||||
got, err := cache.ReadString(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadString(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerReadUint64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue uint64
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue uint64
|
||||
wantErr error
|
||||
uint64s map[string]uint64
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
uint64s: map[string]uint64{"test": 1},
|
||||
wantValue: 1,
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: 1,
|
||||
wantValue: 1,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
u64: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.uint64s != nil {
|
||||
cache.uint64s = tt.uint64s
|
||||
}
|
||||
got, err := cache.ReadUInt64(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadUInt64(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestHandlerReadBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
handlerKey Key
|
||||
handlerValue bool
|
||||
handlerError error
|
||||
preserveHandler bool
|
||||
wantValue bool
|
||||
wantErr error
|
||||
bools map[string]bool
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "read existing cached values",
|
||||
key: "test",
|
||||
handlerKey: "do not read",
|
||||
bools: map[string]bool{"test": true},
|
||||
wantValue: true,
|
||||
expectedCalls: 0,
|
||||
},
|
||||
{
|
||||
name: "read existing values not cached",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerValue: true,
|
||||
wantValue: true,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "error no such key",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: ErrNoSuchKey,
|
||||
wantErr: ErrNoSuchKey,
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
key: "test",
|
||||
handlerKey: "test",
|
||||
handlerError: someOtherError,
|
||||
wantErr: someOtherError,
|
||||
preserveHandler: true,
|
||||
expectedCalls: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testHandler := &testHandler{
|
||||
t: t,
|
||||
key: tt.handlerKey,
|
||||
b: tt.handlerValue,
|
||||
err: tt.handlerError,
|
||||
}
|
||||
cache := NewCachingHandler(testHandler)
|
||||
if tt.bools != nil {
|
||||
cache.bools = tt.bools
|
||||
}
|
||||
got, err := cache.ReadBoolean(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if !tt.preserveHandler {
|
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil
|
||||
}
|
||||
got, err = cache.ReadBoolean(tt.key)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr)
|
||||
}
|
||||
if got != tt.wantValue {
|
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key])
|
||||
}
|
||||
if testHandler.calls != tt.expectedCalls {
|
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
|
@ -12,7 +12,7 @@ import (
|
|||
type windowsHandler struct{}
|
||||
|
||||
func init() {
|
||||
RegisterHandler(windowsHandler{})
|
||||
RegisterHandler(NewCachingHandler(windowsHandler{}))
|
||||
}
|
||||
|
||||
func (windowsHandler) ReadString(key string) (string, error) {
|
||||
|
|
|
@ -13,12 +13,13 @@ import (
|
|||
// methods that involve getting a policy value.
|
||||
// For keys and the corresponding values, check policy_keys.go.
|
||||
type testHandler struct {
|
||||
t *testing.T
|
||||
key Key
|
||||
s string
|
||||
u64 uint64
|
||||
b bool
|
||||
err error
|
||||
t *testing.T
|
||||
key Key
|
||||
s string
|
||||
u64 uint64
|
||||
b bool
|
||||
err error
|
||||
calls int // used for testing reads from cache vs. handler
|
||||
}
|
||||
|
||||
var someOtherError = errors.New("error other than not found")
|
||||
|
@ -34,6 +35,7 @@ func (th *testHandler) ReadString(key string) (string, error) {
|
|||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadString(%q) want %q", key, th.key)
|
||||
}
|
||||
th.calls++
|
||||
return th.s, th.err
|
||||
}
|
||||
|
||||
|
@ -41,6 +43,7 @@ func (th *testHandler) ReadUInt64(key string) (uint64, error) {
|
|||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadUint64(%q) want %q", key, th.key)
|
||||
}
|
||||
th.calls++
|
||||
return th.u64, th.err
|
||||
}
|
||||
|
||||
|
@ -48,6 +51,7 @@ func (th *testHandler) ReadBoolean(key string) (bool, error) {
|
|||
if key != string(th.key) {
|
||||
th.t.Errorf("ReadBool(%q) want %q", key, th.key)
|
||||
}
|
||||
th.calls++
|
||||
return th.b, th.err
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue