From a1a31cd916773848626aa097c7a45432f717c004 Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Wed, 26 Jun 2024 14:30:02 +0300 Subject: [PATCH] Pull request 2221: AG-27492-client-persistent-runtime-storage Squashed commit of the following: commit a2b1e829f57fa7411354d882ec67d0c8736efbac Merge: 5fde76bb2 65b7d232a Author: Stanislav Chzhen Date: Tue Jun 25 16:12:17 2024 +0300 Merge branch 'master' into AG-27492-client-persistent-runtime-storage commit 5fde76bb20f818f052fe89dc90c2b3ea790da4d2 Author: Stanislav Chzhen Date: Fri Jun 21 16:58:17 2024 +0300 all: imp code commit eae49f91bc1b5eedae3d03b0b6c782afa11896d8 Author: Stanislav Chzhen Date: Wed Jun 19 20:10:55 2024 +0300 all: use storage commit 2c7efa46099d9b8ffe297ce247aff0aa8f45dff7 Author: Stanislav Chzhen Date: Tue Jun 18 20:14:34 2024 +0300 client: add tests commit d59bd7a24e273e58737c3efa832adabc57495bed Author: Stanislav Chzhen Date: Tue Jun 18 18:31:23 2024 +0300 client: add tests commit 045b83882380a8e181f6892cc3245944e4c9fd52 Author: Stanislav Chzhen Date: Tue Jun 18 15:18:08 2024 +0300 client: add tests commit 702467f7cadf3574c4a1b7b441ac02e26581bfcf Merge: 4abc23bf8 1c82be295 Author: Stanislav Chzhen Date: Mon Jun 17 18:40:43 2024 +0300 Merge branch 'master' into AG-27492-client-persistent-runtime-storage commit 4abc23bf84dd8de02a1b805afba4d5a724b39d0c Merge: e268abf92 bed86d57f Author: Stanislav Chzhen Date: Thu Jun 13 15:19:47 2024 +0300 Merge branch 'master' into AG-27492-client-persistent-runtime-storage commit e268abf9268aef7a5386b5e126b01b249c590f49 Author: Stanislav Chzhen Date: Thu Jun 13 15:19:36 2024 +0300 client: add tests commit 5601cfce39599337aaf04688ffe2b14b49f856e5 Author: Stanislav Chzhen Date: Mon May 27 14:27:53 2024 +0300 client: runtime index commit bde3baa5da85dd5404f78bd79a6a3e85c55cf7fc Author: Stanislav Chzhen Date: Mon May 20 14:39:35 2024 +0300 all: persistent client storage --- internal/client/index.go | 2 +- internal/client/index_internal_test.go | 1 + internal/client/persistent.go | 35 ++ internal/client/persistent_internal_test.go | 51 ++- internal/client/storage.go | 255 +++++++++++ internal/client/storage_test.go | 473 ++++++++++++++++++++ internal/home/clients.go | 11 +- 7 files changed, 816 insertions(+), 12 deletions(-) create mode 100644 internal/client/storage.go create mode 100644 internal/client/storage_test.go diff --git a/internal/client/index.go b/internal/client/index.go index 63ae690e..8cdbad13 100644 --- a/internal/client/index.go +++ b/internal/client/index.go @@ -64,7 +64,7 @@ func NewIndex() (ci *Index) { } // Add stores information about a persistent client in the index. c must be -// non-nil and contain UID. +// non-nil, have a UID, and contain at least one identifier. func (ci *Index) Add(c *Persistent) { if (c.UID == UID{}) { panic("client must contain uid") diff --git a/internal/client/index_internal_test.go b/internal/client/index_internal_test.go index 38c0df15..f51f461c 100644 --- a/internal/client/index_internal_test.go +++ b/internal/client/index_internal_test.go @@ -22,6 +22,7 @@ func newIDIndex(m []*Persistent) (ci *Index) { return ci } +// TODO(s.chzhen): Remove. func TestClientIndex_Find(t *testing.T) { const ( cliIPNone = "1.2.3.4" diff --git a/internal/client/persistent.go b/internal/client/persistent.go index 52f3aacc..b573b0fe 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" @@ -70,6 +71,7 @@ type Persistent struct { // must not be nil after initialization. BlockedServices *filtering.BlockedServices + // Name of the persistent client. Must not be empty. Name string Tags []string @@ -99,6 +101,39 @@ type Persistent struct { SafeSearchConf filtering.SafeSearchConfig } +// validate returns an error if persistent client information contains errors. +func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) { + switch { + case c.Name == "": + return errors.Error("empty name") + case c.IDsLen() == 0: + return errors.Error("id required") + case c.UID == UID{}: + return errors.Error("uid required") + } + + conf, err := proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{}) + if err != nil { + return fmt.Errorf("invalid upstream servers: %w", err) + } + + err = conf.Close() + if err != nil { + log.Error("client: closing upstream config: %s", err) + } + + for _, t := range c.Tags { + if !allTags.Has(t) { + return fmt.Errorf("invalid tag: %q", t) + } + } + + // TODO(s.chzhen): Move to the constructor. + slices.Sort(c.Tags) + + return nil +} + // SetTags sets the tags if they are known, otherwise logs an unknown tag. func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) { for _, t := range tags { diff --git a/internal/client/persistent_internal_test.go b/internal/client/persistent_internal_test.go index 76da1e4b..89190285 100644 --- a/internal/client/persistent_internal_test.go +++ b/internal/client/persistent_internal_test.go @@ -1,13 +1,15 @@ package client import ( + "net/netip" "testing" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestPersistentClient_EqualIDs(t *testing.T) { +func TestPersistent_EqualIDs(t *testing.T) { const ( ip = "0.0.0.0" ip1 = "1.1.1.1" @@ -122,3 +124,50 @@ func TestPersistentClient_EqualIDs(t *testing.T) { }) } } + +func TestPersistent_Validate(t *testing.T) { + // TODO(s.chzhen): Add test cases. + testCases := []struct { + name string + cli *Persistent + wantErrMsg string + }{{ + name: "basic", + cli: &Persistent{ + Name: "basic", + IPs: []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + }, + UID: MustNewUID(), + }, + wantErrMsg: "", + }, { + name: "empty_name", + cli: &Persistent{ + Name: "", + }, + wantErrMsg: "empty name", + }, { + name: "no_id", + cli: &Persistent{ + Name: "no_id", + }, + wantErrMsg: "id required", + }, { + name: "no_uid", + cli: &Persistent{ + Name: "no_uid", + IPs: []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + }, + }, + wantErrMsg: "uid required", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.cli.validate(nil) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } +} diff --git a/internal/client/storage.go b/internal/client/storage.go new file mode 100644 index 00000000..d9abc529 --- /dev/null +++ b/internal/client/storage.go @@ -0,0 +1,255 @@ +package client + +import ( + "fmt" + "net" + "net/netip" + "sync" + + "github.com/AdguardTeam/golibs/container" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" +) + +// Storage contains information about persistent and runtime clients. +type Storage struct { + // allowedTags is a set of all allowed tags. + allowedTags *container.MapSet[string] + + // mu protects indexes of persistent and runtime clients. + mu *sync.Mutex + + // index contains information about persistent clients. + index *Index + + // runtimeIndex contains information about runtime clients. + runtimeIndex *RuntimeIndex +} + +// NewStorage returns initialized client storage. +func NewStorage(allowedTags *container.MapSet[string]) (s *Storage) { + return &Storage{ + allowedTags: allowedTags, + mu: &sync.Mutex{}, + index: NewIndex(), + runtimeIndex: NewRuntimeIndex(), + } +} + +// Add stores persistent client information or returns an error. +func (s *Storage) Add(p *Persistent) (err error) { + defer func() { err = errors.Annotate(err, "adding client: %w") }() + + err = p.validate(s.allowedTags) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + err = s.index.ClashesUID(p) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + + err = s.index.Clashes(p) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + + s.index.Add(p) + + log.Debug("client storage: added %q: IDs: %q [%d]", p.Name, p.IDs(), s.index.Size()) + + return nil +} + +// FindByName finds persistent client by name. +func (s *Storage) FindByName(name string) (c *Persistent, found bool) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.index.FindByName(name) +} + +// Find finds persistent client by string representation of the client ID, IP +// address, or MAC. And returns it shallow copy. +func (s *Storage) Find(id string) (p *Persistent, ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + + p, ok = s.index.Find(id) + if ok { + return p.ShallowClone(), ok + } + + return nil, false +} + +// FindLoose is like [Storage.Find] but it also tries to find a persistent +// client by IP address without zone. It strips the IPv6 zone index from the +// stored IP addresses before comparing, because querylog entries don't have it. +// See TODO on [querylog.logEntry.IP]. +// +// Note that multiple clients can have the same IP address with different zones. +// Therefore, the result of this method is indeterminate. +func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + + p, ok = s.index.Find(id) + if ok { + return p.ShallowClone(), ok + } + + p = s.index.FindByIPWithoutZone(ip) + if p != nil { + return p.ShallowClone(), true + } + + return nil, false +} + +// FindByMAC finds persistent client by MAC. +func (s *Storage) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.index.FindByMAC(mac) +} + +// RemoveByName removes persistent client information. ok is false if no such +// client exists by that name. +func (s *Storage) RemoveByName(name string) (ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + + p, ok := s.index.FindByName(name) + if !ok { + return false + } + + if err := p.CloseUpstreams(); err != nil { + log.Error("client storage: removing client %q: %s", p.Name, err) + } + + s.index.Delete(p) + + return true +} + +// Update finds the stored persistent client by its name and updates its +// information from p. +func (s *Storage) Update(name string, p *Persistent) (err error) { + defer func() { err = errors.Annotate(err, "updating client: %w") }() + + err = p.validate(s.allowedTags) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + stored, ok := s.index.FindByName(name) + if !ok { + return fmt.Errorf("client %q is not found", name) + } + + // Client p has a newly generated UID, so replace it with the stored one. + // + // TODO(s.chzhen): Remove when frontend starts handling UIDs. + p.UID = stored.UID + + err = s.index.Clashes(p) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + + s.index.Delete(stored) + s.index.Add(p) + + return nil +} + +// RangeByName calls f for each persistent client sorted by name, unless cont is +// false. +func (s *Storage) RangeByName(f func(c *Persistent) (cont bool)) { + s.mu.Lock() + defer s.mu.Unlock() + + s.index.RangeByName(f) +} + +// Size returns the number of persistent clients. +func (s *Storage) Size() (n int) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.index.Size() +} + +// CloseUpstreams closes upstream configurations of persistent clients. +func (s *Storage) CloseUpstreams() (err error) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.index.CloseUpstreams() +} + +// ClientRuntime returns a copy of the saved runtime client by ip. If no such +// client exists, returns nil. +func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.runtimeIndex.Client(ip) +} + +// AddRuntime saves the runtime client information in the storage. IP address +// of a client must be unique. rc must not be nil. +func (s *Storage) AddRuntime(rc *Runtime) { + s.mu.Lock() + defer s.mu.Unlock() + + s.runtimeIndex.Add(rc) +} + +// SizeRuntime returns the number of the runtime clients. +func (s *Storage) SizeRuntime() (n int) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.runtimeIndex.Size() +} + +// RangeRuntime calls f for each runtime client in an undefined order. +func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { + s.mu.Lock() + defer s.mu.Unlock() + + s.runtimeIndex.Range(f) +} + +// DeleteRuntime removes the runtime client by ip. +func (s *Storage) DeleteRuntime(ip netip.Addr) { + s.mu.Lock() + defer s.mu.Unlock() + + s.runtimeIndex.Delete(ip) +} + +// DeleteBySource removes all runtime clients that have information only from +// the specified source and returns the number of removed clients. +func (s *Storage) DeleteBySource(src Source) (n int) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.runtimeIndex.DeleteBySource(src) +} diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go new file mode 100644 index 00000000..fef02108 --- /dev/null +++ b/internal/client/storage_test.go @@ -0,0 +1,473 @@ +package client_test + +import ( + "net" + "net/netip" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newStorage is a helper function that returns a client storage filled with +// persistent clients from the m. It also generates a UID for each client. +func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { + tb.Helper() + + s = client.NewStorage(nil) + + for _, c := range m { + c.UID = client.MustNewUID() + require.NoError(tb, s.Add(c)) + } + + return s +} + +// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an +// error. +func mustParseMAC(s string) (mac net.HardwareAddr) { + mac, err := net.ParseMAC(s) + if err != nil { + panic(err) + } + + return mac +} + +func TestStorage_Add(t *testing.T) { + const ( + existingName = "existing_name" + existingClientID = "existing_client_id" + ) + + var ( + existingClientUID = client.MustNewUID() + existingIP = netip.MustParseAddr("1.2.3.4") + existingSubnet = netip.MustParsePrefix("1.2.3.0/24") + ) + + existingClient := &client.Persistent{ + Name: existingName, + IPs: []netip.Addr{existingIP}, + Subnets: []netip.Prefix{existingSubnet}, + ClientIDs: []string{existingClientID}, + UID: existingClientUID, + } + + s := client.NewStorage(nil) + err := s.Add(existingClient) + require.NoError(t, err) + + testCases := []struct { + name string + cli *client.Persistent + wantErrMsg string + }{{ + name: "basic", + cli: &client.Persistent{ + Name: "basic", + IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + UID: client.MustNewUID(), + }, + wantErrMsg: "", + }, { + name: "duplicate_uid", + cli: &client.Persistent{ + Name: "no_uid", + IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, + UID: existingClientUID, + }, + wantErrMsg: `adding client: another client "existing_name" uses the same uid`, + }, { + name: "duplicate_name", + cli: &client.Persistent{ + Name: existingName, + IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")}, + UID: client.MustNewUID(), + }, + wantErrMsg: `adding client: another client uses the same name "existing_name"`, + }, { + name: "duplicate_ip", + cli: &client.Persistent{ + Name: "duplicate_ip", + IPs: []netip.Addr{existingIP}, + UID: client.MustNewUID(), + }, + wantErrMsg: `adding client: another client "existing_name" uses the same IP "1.2.3.4"`, + }, { + name: "duplicate_subnet", + cli: &client.Persistent{ + Name: "duplicate_subnet", + Subnets: []netip.Prefix{existingSubnet}, + UID: client.MustNewUID(), + }, + wantErrMsg: `adding client: another client "existing_name" ` + + `uses the same subnet "1.2.3.0/24"`, + }, { + name: "duplicate_client_id", + cli: &client.Persistent{ + Name: "duplicate_client_id", + ClientIDs: []string{existingClientID}, + UID: client.MustNewUID(), + }, + wantErrMsg: `adding client: another client "existing_name" ` + + `uses the same ClientID "existing_client_id"`, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err = s.Add(tc.cli) + + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } +} + +func TestStorage_RemoveByName(t *testing.T) { + const ( + existingName = "existing_name" + ) + + existingClient := &client.Persistent{ + Name: existingName, + IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, + UID: client.MustNewUID(), + } + + s := client.NewStorage(nil) + err := s.Add(existingClient) + require.NoError(t, err) + + testCases := []struct { + want assert.BoolAssertionFunc + name string + cliName string + }{{ + name: "existing_client", + cliName: existingName, + want: assert.True, + }, { + name: "non_existing_client", + cliName: "non_existing_client", + want: assert.False, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.want(t, s.RemoveByName(tc.cliName)) + }) + } + + t.Run("duplicate_remove", func(t *testing.T) { + s = client.NewStorage(nil) + err = s.Add(existingClient) + require.NoError(t, err) + + assert.True(t, s.RemoveByName(existingName)) + assert.False(t, s.RemoveByName(existingName)) + }) +} + +func TestStorage_Find(t *testing.T) { + const ( + cliIPNone = "1.2.3.4" + cliIP1 = "1.1.1.1" + cliIP2 = "2.2.2.2" + + cliIPv6 = "1:2:3::4" + + cliSubnet = "2.2.2.0/24" + cliSubnetIP = "2.2.2.222" + + cliID = "client-id" + cliMAC = "11:11:11:11:11:11" + + linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0" + linkLocalSubnet = "fe80::/16" + ) + + var ( + clientWithBothFams = &client.Persistent{ + Name: "client1", + IPs: []netip.Addr{ + netip.MustParseAddr(cliIP1), + netip.MustParseAddr(cliIPv6), + }, + } + + clientWithSubnet = &client.Persistent{ + Name: "client2", + IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, + Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, + } + + clientWithMAC = &client.Persistent{ + Name: "client_with_mac", + MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, + } + + clientWithID = &client.Persistent{ + Name: "client_with_id", + ClientIDs: []string{cliID}, + } + + clientLinkLocal = &client.Persistent{ + Name: "client_link_local", + Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)}, + } + ) + + clients := []*client.Persistent{ + clientWithBothFams, + clientWithSubnet, + clientWithMAC, + clientWithID, + clientLinkLocal, + } + s := newStorage(t, clients) + + testCases := []struct { + want *client.Persistent + name string + ids []string + }{{ + name: "ipv4_ipv6", + ids: []string{cliIP1, cliIPv6}, + want: clientWithBothFams, + }, { + name: "ipv4_subnet", + ids: []string{cliIP2, cliSubnetIP}, + want: clientWithSubnet, + }, { + name: "mac", + ids: []string{cliMAC}, + want: clientWithMAC, + }, { + name: "client_id", + ids: []string{cliID}, + want: clientWithID, + }, { + name: "client_link_local_subnet", + ids: []string{linkLocalIP}, + want: clientLinkLocal, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, id := range tc.ids { + c, ok := s.Find(id) + require.True(t, ok) + + assert.Equal(t, tc.want, c) + } + }) + } + + t.Run("not_found", func(t *testing.T) { + _, ok := s.Find(cliIPNone) + assert.False(t, ok) + }) +} + +func TestStorage_FindLoose(t *testing.T) { + const ( + nonExistingClientID = "client_id" + ) + + var ( + ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1") + ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2") + ) + + var ( + clientNoZone = &client.Persistent{ + Name: "client", + IPs: []netip.Addr{ip}, + } + + clientWithZone = &client.Persistent{ + Name: "client_with_zone", + IPs: []netip.Addr{ipWithZone}, + } + ) + + s := newStorage( + t, + []*client.Persistent{ + clientNoZone, + clientWithZone, + }, + ) + + testCases := []struct { + ip netip.Addr + want assert.BoolAssertionFunc + wantCli *client.Persistent + name string + }{{ + name: "without_zone", + ip: ip, + wantCli: clientNoZone, + want: assert.True, + }, { + name: "with_zone", + ip: ipWithZone, + wantCli: clientWithZone, + want: assert.True, + }, { + name: "zero_address", + ip: netip.Addr{}, + wantCli: nil, + want: assert.False, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c, ok := s.FindLoose(tc.ip.WithZone(""), nonExistingClientID) + assert.Equal(t, tc.wantCli, c) + tc.want(t, ok) + }) + } +} + +func TestStorage_Update(t *testing.T) { + const ( + clientName = "client_name" + obstructingName = "obstructing_name" + obstructingClientID = "obstructing_client_id" + ) + + var ( + obstructingIP = netip.MustParseAddr("1.2.3.4") + obstructingSubnet = netip.MustParsePrefix("1.2.3.0/24") + ) + + obstructingClient := &client.Persistent{ + Name: obstructingName, + IPs: []netip.Addr{obstructingIP}, + Subnets: []netip.Prefix{obstructingSubnet}, + ClientIDs: []string{obstructingClientID}, + } + + clientToUpdate := &client.Persistent{ + Name: clientName, + IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + } + + testCases := []struct { + name string + cli *client.Persistent + wantErrMsg string + }{{ + name: "basic", + cli: &client.Persistent{ + Name: "basic", + IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, + UID: client.MustNewUID(), + }, + wantErrMsg: "", + }, { + name: "duplicate_name", + cli: &client.Persistent{ + Name: obstructingName, + IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")}, + UID: client.MustNewUID(), + }, + wantErrMsg: `updating client: another client uses the same name "obstructing_name"`, + }, { + name: "duplicate_ip", + cli: &client.Persistent{ + Name: "duplicate_ip", + IPs: []netip.Addr{obstructingIP}, + UID: client.MustNewUID(), + }, + wantErrMsg: `updating client: another client "obstructing_name" uses the same IP "1.2.3.4"`, + }, { + name: "duplicate_subnet", + cli: &client.Persistent{ + Name: "duplicate_subnet", + Subnets: []netip.Prefix{obstructingSubnet}, + UID: client.MustNewUID(), + }, + wantErrMsg: `updating client: another client "obstructing_name" ` + + `uses the same subnet "1.2.3.0/24"`, + }, { + name: "duplicate_client_id", + cli: &client.Persistent{ + Name: "duplicate_client_id", + ClientIDs: []string{obstructingClientID}, + UID: client.MustNewUID(), + }, + wantErrMsg: `updating client: another client "obstructing_name" ` + + `uses the same ClientID "obstructing_client_id"`, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := newStorage( + t, + []*client.Persistent{ + clientToUpdate, + obstructingClient, + }, + ) + + err := s.Update(clientName, tc.cli) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } +} + +func TestStorage_RangeByName(t *testing.T) { + sortedClients := []*client.Persistent{{ + Name: "clientA", + ClientIDs: []string{"A"}, + }, { + Name: "clientB", + ClientIDs: []string{"B"}, + }, { + Name: "clientC", + ClientIDs: []string{"C"}, + }, { + Name: "clientD", + ClientIDs: []string{"D"}, + }, { + Name: "clientE", + ClientIDs: []string{"E"}, + }} + + testCases := []struct { + name string + want []*client.Persistent + }{{ + name: "basic", + want: sortedClients, + }, { + name: "nil", + want: nil, + }, { + name: "one_element", + want: sortedClients[:1], + }, { + name: "two_elements", + want: sortedClients[:2], + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := newStorage(t, tc.want) + + var got []*client.Persistent + s.RangeByName(func(c *client.Persistent) (cont bool) { + got = append(got, c) + + return true + }) + + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/internal/home/clients.go b/internal/home/clients.go index 3616cb0b..9d39451d 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -5,7 +5,6 @@ import ( "net" "net/netip" "slices" - "strings" "sync" "time" @@ -317,7 +316,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { defer clients.lock.Unlock() objs = make([]*clientObject, 0, clients.clientIndex.Size()) - clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) { + clients.clientIndex.RangeByName(func(cli *client.Persistent) (cont bool) { objs = append(objs, &clientObject{ Name: cli.Name, @@ -344,14 +343,6 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { return true }) - // Maps aren't guaranteed to iterate in the same order each time, so the - // above loop can generate different orderings when writing to the config - // file: this produces lots of diffs in config files, so sort objects by - // name before writing. - slices.SortStableFunc(objs, func(a, b *clientObject) (res int) { - return strings.Compare(a.Name, b.Name) - }) - return objs }