diff --git a/internal/client/persistent.go b/internal/client/persistent.go index ce68986f..4e09c5b8 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -13,7 +13,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" @@ -136,7 +135,8 @@ type Persistent struct { } // validate returns an error if persistent client information contains errors. -func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) { +// allTags must be sorted. +func (c *Persistent) validate(allTags []string) (err error) { switch { case c.Name == "": return errors.Error("empty name") @@ -157,7 +157,8 @@ func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) { } for _, t := range c.Tags { - if !allTags.Has(t) { + _, ok := slices.BinarySearch(allTags, t) + if !ok { return fmt.Errorf("invalid tag: %q", t) } } diff --git a/internal/client/persistent_internal_test.go b/internal/client/persistent_internal_test.go index a96c3778..395ffa72 100644 --- a/internal/client/persistent_internal_test.go +++ b/internal/client/persistent_internal_test.go @@ -4,7 +4,6 @@ import ( "net/netip" "testing" - "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -132,7 +131,7 @@ func TestPersistent_Validate(t *testing.T) { notAllowedTag = "not_allowed_tag" ) - allowedTags := container.NewMapSet(allowedTag) + allowedTags := []string{allowedTag} testCases := []struct { name string diff --git a/internal/client/storage.go b/internal/client/storage.go index ccd6b15f..da6dda5c 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -5,13 +5,13 @@ import ( "fmt" "net" "net/netip" + "slices" "sync" "time" "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/whois" - "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" @@ -108,9 +108,6 @@ type StorageConfig struct { // Storage contains information about persistent and runtime clients. type Storage struct { - // allowedTags is a set of all allowed tags. - allowedTags *container.MapSet[string] - // mu protects indexes of persistent and runtime clients. mu *sync.Mutex @@ -132,6 +129,12 @@ type Storage struct { // done is the shutdown signaling channel. done chan struct{} + // allowedTags is a sorted list of all allowed tags. It must not be + // modified after initialization. + // + // TODO(s.chzhen): Use custom type. + allowedTags []string + // arpClientsUpdatePeriod defines how often [SourceARP] runtime client // information is updated. It must be greater than zero. arpClientsUpdatePeriod time.Duration @@ -143,8 +146,11 @@ type Storage struct { // NewStorage returns initialized client storage. conf must not be nil. func NewStorage(conf *StorageConfig) (s *Storage, err error) { + tags := slices.Clone(allowedTags) + slices.Sort(tags) + s = &Storage{ - allowedTags: container.NewMapSet(allowedTags...), + allowedTags: tags, mu: &sync.Mutex{}, index: newIndex(), runtimeIndex: newRuntimeIndex(), @@ -576,7 +582,8 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { s.runtimeIndex.rangeClients(f) } -// AllowedTags returns the list of available client tags. +// AllowedTags returns the list of available client tags. tags must not be +// modified. func (s *Storage) AllowedTags() (tags []string) { - return allowedTags + return s.allowedTags } diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index f4914d03..58311f4c 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -4,6 +4,7 @@ import ( "net" "net/netip" "runtime" + "slices" "sync" "testing" "time" @@ -536,7 +537,7 @@ func TestStorage_Add(t *testing.T) { existingName = "existing_name" existingClientID = "existing_client_id" - allowedTag = "tag" + allowedTag = "user_admin" notAllowedTag = "not_allowed_tag" ) @@ -557,6 +558,16 @@ func TestStorage_Add(t *testing.T) { s, err := client.NewStorage(&client.StorageConfig{}) require.NoError(t, err) + tags := s.AllowedTags() + require.NotZero(t, len(tags)) + require.True(t, slices.IsSorted(tags)) + + _, ok := slices.BinarySearch(tags, allowedTag) + require.True(t, ok) + + _, ok = slices.BinarySearch(tags, notAllowedTag) + require.False(t, ok) + err = s.Add(existingClient) require.NoError(t, err) @@ -617,12 +628,21 @@ func TestStorage_Add(t *testing.T) { }, { name: "not_allowed_tag", cli: &client.Persistent{ - Name: "nont_allowed_tag", + Name: "not_allowed_tag", Tags: []string{notAllowedTag}, IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")}, UID: client.MustNewUID(), }, wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`, + }, { + name: "allowed_tag", + cli: &client.Persistent{ + Name: "allowed_tag", + Tags: []string{allowedTag}, + IPs: []netip.Addr{netip.MustParseAddr("5.5.5.5")}, + UID: client.MustNewUID(), + }, + wantErrMsg: "", }} for _, tc := range testCases {