Pull request 2221: AG-27492-client-persistent-runtime-storage
Squashed commit of the following: commita2b1e829f5
Merge:5fde76bb2
65b7d232a
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jun 25 16:12:17 2024 +0300 Merge branch 'master' into AG-27492-client-persistent-runtime-storage commit5fde76bb20
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Jun 21 16:58:17 2024 +0300 all: imp code commiteae49f91bc
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Jun 19 20:10:55 2024 +0300 all: use storage commit2c7efa4609
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jun 18 20:14:34 2024 +0300 client: add tests commitd59bd7a24e
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jun 18 18:31:23 2024 +0300 client: add tests commit045b838823
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jun 18 15:18:08 2024 +0300 client: add tests commit702467f7ca
Merge:4abc23bf8
1c82be295
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Jun 17 18:40:43 2024 +0300 Merge branch 'master' into AG-27492-client-persistent-runtime-storage commit4abc23bf84
Merge:e268abf92
bed86d57f
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Jun 13 15:19:47 2024 +0300 Merge branch 'master' into AG-27492-client-persistent-runtime-storage commite268abf926
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Jun 13 15:19:36 2024 +0300 client: add tests commit5601cfce39
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon May 27 14:27:53 2024 +0300 client: runtime index commitbde3baa5da
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon May 20 14:39:35 2024 +0300 all: persistent client storage
This commit is contained in:
parent
65b7d232ab
commit
a1a31cd916
|
@ -64,7 +64,7 @@ func NewIndex() (ci *Index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add stores information about a persistent client in the index. c must be
|
// Add stores information about a persistent client in the index. c must be
|
||||||
// non-nil and contain UID.
|
// non-nil, have a UID, and contain at least one identifier.
|
||||||
func (ci *Index) Add(c *Persistent) {
|
func (ci *Index) Add(c *Persistent) {
|
||||||
if (c.UID == UID{}) {
|
if (c.UID == UID{}) {
|
||||||
panic("client must contain uid")
|
panic("client must contain uid")
|
||||||
|
|
|
@ -22,6 +22,7 @@ func newIDIndex(m []*Persistent) (ci *Index) {
|
||||||
return ci
|
return ci
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(s.chzhen): Remove.
|
||||||
func TestClientIndex_Find(t *testing.T) {
|
func TestClientIndex_Find(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
cliIPNone = "1.2.3.4"
|
cliIPNone = "1.2.3.4"
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/container"
|
"github.com/AdguardTeam/golibs/container"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -70,6 +71,7 @@ type Persistent struct {
|
||||||
// must not be nil after initialization.
|
// must not be nil after initialization.
|
||||||
BlockedServices *filtering.BlockedServices
|
BlockedServices *filtering.BlockedServices
|
||||||
|
|
||||||
|
// Name of the persistent client. Must not be empty.
|
||||||
Name string
|
Name string
|
||||||
|
|
||||||
Tags []string
|
Tags []string
|
||||||
|
@ -99,6 +101,39 @@ type Persistent struct {
|
||||||
SafeSearchConf filtering.SafeSearchConfig
|
SafeSearchConf filtering.SafeSearchConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validate returns an error if persistent client information contains errors.
|
||||||
|
func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
|
||||||
|
switch {
|
||||||
|
case c.Name == "":
|
||||||
|
return errors.Error("empty name")
|
||||||
|
case c.IDsLen() == 0:
|
||||||
|
return errors.Error("id required")
|
||||||
|
case c.UID == UID{}:
|
||||||
|
return errors.Error("uid required")
|
||||||
|
}
|
||||||
|
|
||||||
|
conf, err := proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid upstream servers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conf.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("client: closing upstream config: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range c.Tags {
|
||||||
|
if !allTags.Has(t) {
|
||||||
|
return fmt.Errorf("invalid tag: %q", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(s.chzhen): Move to the constructor.
|
||||||
|
slices.Sort(c.Tags)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
|
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
|
||||||
func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) {
|
func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) {
|
||||||
for _, t := range tags {
|
for _, t := range tags {
|
||||||
|
|
|
@ -1,13 +1,15 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPersistentClient_EqualIDs(t *testing.T) {
|
func TestPersistent_EqualIDs(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
ip = "0.0.0.0"
|
ip = "0.0.0.0"
|
||||||
ip1 = "1.1.1.1"
|
ip1 = "1.1.1.1"
|
||||||
|
@ -122,3 +124,50 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPersistent_Validate(t *testing.T) {
|
||||||
|
// TODO(s.chzhen): Add test cases.
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
cli *Persistent
|
||||||
|
wantErrMsg string
|
||||||
|
}{{
|
||||||
|
name: "basic",
|
||||||
|
cli: &Persistent{
|
||||||
|
Name: "basic",
|
||||||
|
IPs: []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.2.3.4"),
|
||||||
|
},
|
||||||
|
UID: MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}, {
|
||||||
|
name: "empty_name",
|
||||||
|
cli: &Persistent{
|
||||||
|
Name: "",
|
||||||
|
},
|
||||||
|
wantErrMsg: "empty name",
|
||||||
|
}, {
|
||||||
|
name: "no_id",
|
||||||
|
cli: &Persistent{
|
||||||
|
Name: "no_id",
|
||||||
|
},
|
||||||
|
wantErrMsg: "id required",
|
||||||
|
}, {
|
||||||
|
name: "no_uid",
|
||||||
|
cli: &Persistent{
|
||||||
|
Name: "no_uid",
|
||||||
|
IPs: []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.2.3.4"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErrMsg: "uid required",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := tc.cli.validate(nil)
|
||||||
|
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,255 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/container"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Storage contains information about persistent and runtime clients.
|
||||||
|
type Storage struct {
|
||||||
|
// allowedTags is a set of all allowed tags.
|
||||||
|
allowedTags *container.MapSet[string]
|
||||||
|
|
||||||
|
// mu protects indexes of persistent and runtime clients.
|
||||||
|
mu *sync.Mutex
|
||||||
|
|
||||||
|
// index contains information about persistent clients.
|
||||||
|
index *Index
|
||||||
|
|
||||||
|
// runtimeIndex contains information about runtime clients.
|
||||||
|
runtimeIndex *RuntimeIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStorage returns initialized client storage.
|
||||||
|
func NewStorage(allowedTags *container.MapSet[string]) (s *Storage) {
|
||||||
|
return &Storage{
|
||||||
|
allowedTags: allowedTags,
|
||||||
|
mu: &sync.Mutex{},
|
||||||
|
index: NewIndex(),
|
||||||
|
runtimeIndex: NewRuntimeIndex(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add stores persistent client information or returns an error.
|
||||||
|
func (s *Storage) Add(p *Persistent) (err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "adding client: %w") }()
|
||||||
|
|
||||||
|
err = p.validate(s.allowedTags)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
err = s.index.ClashesUID(p)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.index.Clashes(p)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.index.Add(p)
|
||||||
|
|
||||||
|
log.Debug("client storage: added %q: IDs: %q [%d]", p.Name, p.IDs(), s.index.Size())
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindByName finds persistent client by name.
|
||||||
|
func (s *Storage) FindByName(name string) (c *Persistent, found bool) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.index.FindByName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find finds persistent client by string representation of the client ID, IP
|
||||||
|
// address, or MAC. And returns it shallow copy.
|
||||||
|
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
p, ok = s.index.Find(id)
|
||||||
|
if ok {
|
||||||
|
return p.ShallowClone(), ok
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindLoose is like [Storage.Find] but it also tries to find a persistent
|
||||||
|
// client by IP address without zone. It strips the IPv6 zone index from the
|
||||||
|
// stored IP addresses before comparing, because querylog entries don't have it.
|
||||||
|
// See TODO on [querylog.logEntry.IP].
|
||||||
|
//
|
||||||
|
// Note that multiple clients can have the same IP address with different zones.
|
||||||
|
// Therefore, the result of this method is indeterminate.
|
||||||
|
func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
p, ok = s.index.Find(id)
|
||||||
|
if ok {
|
||||||
|
return p.ShallowClone(), ok
|
||||||
|
}
|
||||||
|
|
||||||
|
p = s.index.FindByIPWithoutZone(ip)
|
||||||
|
if p != nil {
|
||||||
|
return p.ShallowClone(), true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindByMAC finds persistent client by MAC.
|
||||||
|
func (s *Storage) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.index.FindByMAC(mac)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveByName removes persistent client information. ok is false if no such
|
||||||
|
// client exists by that name.
|
||||||
|
func (s *Storage) RemoveByName(name string) (ok bool) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
p, ok := s.index.FindByName(name)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.CloseUpstreams(); err != nil {
|
||||||
|
log.Error("client storage: removing client %q: %s", p.Name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.index.Delete(p)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update finds the stored persistent client by its name and updates its
|
||||||
|
// information from p.
|
||||||
|
func (s *Storage) Update(name string, p *Persistent) (err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "updating client: %w") }()
|
||||||
|
|
||||||
|
err = p.validate(s.allowedTags)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
stored, ok := s.index.FindByName(name)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("client %q is not found", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client p has a newly generated UID, so replace it with the stored one.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Remove when frontend starts handling UIDs.
|
||||||
|
p.UID = stored.UID
|
||||||
|
|
||||||
|
err = s.index.Clashes(p)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.index.Delete(stored)
|
||||||
|
s.index.Add(p)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RangeByName calls f for each persistent client sorted by name, unless cont is
|
||||||
|
// false.
|
||||||
|
func (s *Storage) RangeByName(f func(c *Persistent) (cont bool)) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.index.RangeByName(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the number of persistent clients.
|
||||||
|
func (s *Storage) Size() (n int) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.index.Size()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseUpstreams closes upstream configurations of persistent clients.
|
||||||
|
func (s *Storage) CloseUpstreams() (err error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.index.CloseUpstreams()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
|
||||||
|
// client exists, returns nil.
|
||||||
|
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.runtimeIndex.Client(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRuntime saves the runtime client information in the storage. IP address
|
||||||
|
// of a client must be unique. rc must not be nil.
|
||||||
|
func (s *Storage) AddRuntime(rc *Runtime) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.runtimeIndex.Add(rc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SizeRuntime returns the number of the runtime clients.
|
||||||
|
func (s *Storage) SizeRuntime() (n int) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.runtimeIndex.Size()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RangeRuntime calls f for each runtime client in an undefined order.
|
||||||
|
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.runtimeIndex.Range(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRuntime removes the runtime client by ip.
|
||||||
|
func (s *Storage) DeleteRuntime(ip netip.Addr) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.runtimeIndex.Delete(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteBySource removes all runtime clients that have information only from
|
||||||
|
// the specified source and returns the number of removed clients.
|
||||||
|
func (s *Storage) DeleteBySource(src Source) (n int) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
return s.runtimeIndex.DeleteBySource(src)
|
||||||
|
}
|
|
@ -0,0 +1,473 @@
|
||||||
|
package client_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newStorage is a helper function that returns a client storage filled with
|
||||||
|
// persistent clients from the m. It also generates a UID for each client.
|
||||||
|
func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
s = client.NewStorage(nil)
|
||||||
|
|
||||||
|
for _, c := range m {
|
||||||
|
c.UID = client.MustNewUID()
|
||||||
|
require.NoError(tb, s.Add(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
||||||
|
// error.
|
||||||
|
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
||||||
|
mac, err := net.ParseMAC(s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return mac
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_Add(t *testing.T) {
|
||||||
|
const (
|
||||||
|
existingName = "existing_name"
|
||||||
|
existingClientID = "existing_client_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
existingClientUID = client.MustNewUID()
|
||||||
|
existingIP = netip.MustParseAddr("1.2.3.4")
|
||||||
|
existingSubnet = netip.MustParsePrefix("1.2.3.0/24")
|
||||||
|
)
|
||||||
|
|
||||||
|
existingClient := &client.Persistent{
|
||||||
|
Name: existingName,
|
||||||
|
IPs: []netip.Addr{existingIP},
|
||||||
|
Subnets: []netip.Prefix{existingSubnet},
|
||||||
|
ClientIDs: []string{existingClientID},
|
||||||
|
UID: existingClientUID,
|
||||||
|
}
|
||||||
|
|
||||||
|
s := client.NewStorage(nil)
|
||||||
|
err := s.Add(existingClient)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
cli *client.Persistent
|
||||||
|
wantErrMsg string
|
||||||
|
}{{
|
||||||
|
name: "basic",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "basic",
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}, {
|
||||||
|
name: "duplicate_uid",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "no_uid",
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||||
|
UID: existingClientUID,
|
||||||
|
},
|
||||||
|
wantErrMsg: `adding client: another client "existing_name" uses the same uid`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_name",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: existingName,
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `adding client: another client uses the same name "existing_name"`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_ip",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "duplicate_ip",
|
||||||
|
IPs: []netip.Addr{existingIP},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `adding client: another client "existing_name" uses the same IP "1.2.3.4"`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_subnet",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "duplicate_subnet",
|
||||||
|
Subnets: []netip.Prefix{existingSubnet},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `adding client: another client "existing_name" ` +
|
||||||
|
`uses the same subnet "1.2.3.0/24"`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_client_id",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "duplicate_client_id",
|
||||||
|
ClientIDs: []string{existingClientID},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `adding client: another client "existing_name" ` +
|
||||||
|
`uses the same ClientID "existing_client_id"`,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err = s.Add(tc.cli)
|
||||||
|
|
||||||
|
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_RemoveByName(t *testing.T) {
|
||||||
|
const (
|
||||||
|
existingName = "existing_name"
|
||||||
|
)
|
||||||
|
|
||||||
|
existingClient := &client.Persistent{
|
||||||
|
Name: existingName,
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
}
|
||||||
|
|
||||||
|
s := client.NewStorage(nil)
|
||||||
|
err := s.Add(existingClient)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
want assert.BoolAssertionFunc
|
||||||
|
name string
|
||||||
|
cliName string
|
||||||
|
}{{
|
||||||
|
name: "existing_client",
|
||||||
|
cliName: existingName,
|
||||||
|
want: assert.True,
|
||||||
|
}, {
|
||||||
|
name: "non_existing_client",
|
||||||
|
cliName: "non_existing_client",
|
||||||
|
want: assert.False,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
tc.want(t, s.RemoveByName(tc.cliName))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("duplicate_remove", func(t *testing.T) {
|
||||||
|
s = client.NewStorage(nil)
|
||||||
|
err = s.Add(existingClient)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, s.RemoveByName(existingName))
|
||||||
|
assert.False(t, s.RemoveByName(existingName))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_Find(t *testing.T) {
|
||||||
|
const (
|
||||||
|
cliIPNone = "1.2.3.4"
|
||||||
|
cliIP1 = "1.1.1.1"
|
||||||
|
cliIP2 = "2.2.2.2"
|
||||||
|
|
||||||
|
cliIPv6 = "1:2:3::4"
|
||||||
|
|
||||||
|
cliSubnet = "2.2.2.0/24"
|
||||||
|
cliSubnetIP = "2.2.2.222"
|
||||||
|
|
||||||
|
cliID = "client-id"
|
||||||
|
cliMAC = "11:11:11:11:11:11"
|
||||||
|
|
||||||
|
linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0"
|
||||||
|
linkLocalSubnet = "fe80::/16"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
clientWithBothFams = &client.Persistent{
|
||||||
|
Name: "client1",
|
||||||
|
IPs: []netip.Addr{
|
||||||
|
netip.MustParseAddr(cliIP1),
|
||||||
|
netip.MustParseAddr(cliIPv6),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientWithSubnet = &client.Persistent{
|
||||||
|
Name: "client2",
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
||||||
|
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientWithMAC = &client.Persistent{
|
||||||
|
Name: "client_with_mac",
|
||||||
|
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientWithID = &client.Persistent{
|
||||||
|
Name: "client_with_id",
|
||||||
|
ClientIDs: []string{cliID},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientLinkLocal = &client.Persistent{
|
||||||
|
Name: "client_link_local",
|
||||||
|
Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
clients := []*client.Persistent{
|
||||||
|
clientWithBothFams,
|
||||||
|
clientWithSubnet,
|
||||||
|
clientWithMAC,
|
||||||
|
clientWithID,
|
||||||
|
clientLinkLocal,
|
||||||
|
}
|
||||||
|
s := newStorage(t, clients)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
want *client.Persistent
|
||||||
|
name string
|
||||||
|
ids []string
|
||||||
|
}{{
|
||||||
|
name: "ipv4_ipv6",
|
||||||
|
ids: []string{cliIP1, cliIPv6},
|
||||||
|
want: clientWithBothFams,
|
||||||
|
}, {
|
||||||
|
name: "ipv4_subnet",
|
||||||
|
ids: []string{cliIP2, cliSubnetIP},
|
||||||
|
want: clientWithSubnet,
|
||||||
|
}, {
|
||||||
|
name: "mac",
|
||||||
|
ids: []string{cliMAC},
|
||||||
|
want: clientWithMAC,
|
||||||
|
}, {
|
||||||
|
name: "client_id",
|
||||||
|
ids: []string{cliID},
|
||||||
|
want: clientWithID,
|
||||||
|
}, {
|
||||||
|
name: "client_link_local_subnet",
|
||||||
|
ids: []string{linkLocalIP},
|
||||||
|
want: clientLinkLocal,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
for _, id := range tc.ids {
|
||||||
|
c, ok := s.Find(id)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.want, c)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("not_found", func(t *testing.T) {
|
||||||
|
_, ok := s.Find(cliIPNone)
|
||||||
|
assert.False(t, ok)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_FindLoose(t *testing.T) {
|
||||||
|
const (
|
||||||
|
nonExistingClientID = "client_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1")
|
||||||
|
ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2")
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
clientNoZone = &client.Persistent{
|
||||||
|
Name: "client",
|
||||||
|
IPs: []netip.Addr{ip},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientWithZone = &client.Persistent{
|
||||||
|
Name: "client_with_zone",
|
||||||
|
IPs: []netip.Addr{ipWithZone},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
s := newStorage(
|
||||||
|
t,
|
||||||
|
[]*client.Persistent{
|
||||||
|
clientNoZone,
|
||||||
|
clientWithZone,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
ip netip.Addr
|
||||||
|
want assert.BoolAssertionFunc
|
||||||
|
wantCli *client.Persistent
|
||||||
|
name string
|
||||||
|
}{{
|
||||||
|
name: "without_zone",
|
||||||
|
ip: ip,
|
||||||
|
wantCli: clientNoZone,
|
||||||
|
want: assert.True,
|
||||||
|
}, {
|
||||||
|
name: "with_zone",
|
||||||
|
ip: ipWithZone,
|
||||||
|
wantCli: clientWithZone,
|
||||||
|
want: assert.True,
|
||||||
|
}, {
|
||||||
|
name: "zero_address",
|
||||||
|
ip: netip.Addr{},
|
||||||
|
wantCli: nil,
|
||||||
|
want: assert.False,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
c, ok := s.FindLoose(tc.ip.WithZone(""), nonExistingClientID)
|
||||||
|
assert.Equal(t, tc.wantCli, c)
|
||||||
|
tc.want(t, ok)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_Update(t *testing.T) {
|
||||||
|
const (
|
||||||
|
clientName = "client_name"
|
||||||
|
obstructingName = "obstructing_name"
|
||||||
|
obstructingClientID = "obstructing_client_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
obstructingIP = netip.MustParseAddr("1.2.3.4")
|
||||||
|
obstructingSubnet = netip.MustParsePrefix("1.2.3.0/24")
|
||||||
|
)
|
||||||
|
|
||||||
|
obstructingClient := &client.Persistent{
|
||||||
|
Name: obstructingName,
|
||||||
|
IPs: []netip.Addr{obstructingIP},
|
||||||
|
Subnets: []netip.Prefix{obstructingSubnet},
|
||||||
|
ClientIDs: []string{obstructingClientID},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientToUpdate := &client.Persistent{
|
||||||
|
Name: clientName,
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
cli *client.Persistent
|
||||||
|
wantErrMsg string
|
||||||
|
}{{
|
||||||
|
name: "basic",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "basic",
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}, {
|
||||||
|
name: "duplicate_name",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: obstructingName,
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `updating client: another client uses the same name "obstructing_name"`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_ip",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "duplicate_ip",
|
||||||
|
IPs: []netip.Addr{obstructingIP},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `updating client: another client "obstructing_name" uses the same IP "1.2.3.4"`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_subnet",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "duplicate_subnet",
|
||||||
|
Subnets: []netip.Prefix{obstructingSubnet},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `updating client: another client "obstructing_name" ` +
|
||||||
|
`uses the same subnet "1.2.3.0/24"`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_client_id",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "duplicate_client_id",
|
||||||
|
ClientIDs: []string{obstructingClientID},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `updating client: another client "obstructing_name" ` +
|
||||||
|
`uses the same ClientID "obstructing_client_id"`,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
s := newStorage(
|
||||||
|
t,
|
||||||
|
[]*client.Persistent{
|
||||||
|
clientToUpdate,
|
||||||
|
obstructingClient,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
err := s.Update(clientName, tc.cli)
|
||||||
|
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_RangeByName(t *testing.T) {
|
||||||
|
sortedClients := []*client.Persistent{{
|
||||||
|
Name: "clientA",
|
||||||
|
ClientIDs: []string{"A"},
|
||||||
|
}, {
|
||||||
|
Name: "clientB",
|
||||||
|
ClientIDs: []string{"B"},
|
||||||
|
}, {
|
||||||
|
Name: "clientC",
|
||||||
|
ClientIDs: []string{"C"},
|
||||||
|
}, {
|
||||||
|
Name: "clientD",
|
||||||
|
ClientIDs: []string{"D"},
|
||||||
|
}, {
|
||||||
|
Name: "clientE",
|
||||||
|
ClientIDs: []string{"E"},
|
||||||
|
}}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
want []*client.Persistent
|
||||||
|
}{{
|
||||||
|
name: "basic",
|
||||||
|
want: sortedClients,
|
||||||
|
}, {
|
||||||
|
name: "nil",
|
||||||
|
want: nil,
|
||||||
|
}, {
|
||||||
|
name: "one_element",
|
||||||
|
want: sortedClients[:1],
|
||||||
|
}, {
|
||||||
|
name: "two_elements",
|
||||||
|
want: sortedClients[:2],
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
s := newStorage(t, tc.want)
|
||||||
|
|
||||||
|
var got []*client.Persistent
|
||||||
|
s.RangeByName(func(c *client.Persistent) (cont bool) {
|
||||||
|
got = append(got, c)
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, tc.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -317,7 +316,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
objs = make([]*clientObject, 0, clients.clientIndex.Size())
|
objs = make([]*clientObject, 0, clients.clientIndex.Size())
|
||||||
clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) {
|
clients.clientIndex.RangeByName(func(cli *client.Persistent) (cont bool) {
|
||||||
objs = append(objs, &clientObject{
|
objs = append(objs, &clientObject{
|
||||||
Name: cli.Name,
|
Name: cli.Name,
|
||||||
|
|
||||||
|
@ -344,14 +343,6 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
// Maps aren't guaranteed to iterate in the same order each time, so the
|
|
||||||
// above loop can generate different orderings when writing to the config
|
|
||||||
// file: this produces lots of diffs in config files, so sort objects by
|
|
||||||
// name before writing.
|
|
||||||
slices.SortStableFunc(objs, func(a, b *clientObject) (res int) {
|
|
||||||
return strings.Compare(a.Name, b.Name)
|
|
||||||
})
|
|
||||||
|
|
||||||
return objs
|
return objs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue