client: imp code

This commit is contained in:
Stanislav Chzhen 2024-09-26 18:04:36 +03:00
parent 79272b299a
commit 6cc4ed53a2
4 changed files with 41 additions and 14 deletions

View File

@ -13,7 +13,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@ -136,7 +135,8 @@ type Persistent struct {
} }
// validate returns an error if persistent client information contains errors. // validate returns an error if persistent client information contains errors.
func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) { // allTags must be sorted.
func (c *Persistent) validate(allTags []string) (err error) {
switch { switch {
case c.Name == "": case c.Name == "":
return errors.Error("empty name") return errors.Error("empty name")
@ -157,7 +157,8 @@ func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
} }
for _, t := range c.Tags { for _, t := range c.Tags {
if !allTags.Has(t) { _, ok := slices.BinarySearch(allTags, t)
if !ok {
return fmt.Errorf("invalid tag: %q", t) return fmt.Errorf("invalid tag: %q", t)
} }
} }

View File

@ -4,7 +4,6 @@ import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -132,7 +131,7 @@ func TestPersistent_Validate(t *testing.T) {
notAllowedTag = "not_allowed_tag" notAllowedTag = "not_allowed_tag"
) )
allowedTags := container.NewMapSet(allowedTag) allowedTags := []string{allowedTag}
testCases := []struct { testCases := []struct {
name string name string

View File

@ -5,13 +5,13 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"slices"
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -108,9 +108,6 @@ type StorageConfig struct {
// Storage contains information about persistent and runtime clients. // Storage contains information about persistent and runtime clients.
type Storage struct { type Storage struct {
// allowedTags is a set of all allowed tags.
allowedTags *container.MapSet[string]
// mu protects indexes of persistent and runtime clients. // mu protects indexes of persistent and runtime clients.
mu *sync.Mutex mu *sync.Mutex
@ -132,6 +129,12 @@ type Storage struct {
// done is the shutdown signaling channel. // done is the shutdown signaling channel.
done chan struct{} done chan struct{}
// allowedTags is a sorted list of all allowed tags. It must not be
// modified after initialization.
//
// TODO(s.chzhen): Use custom type.
allowedTags []string
// arpClientsUpdatePeriod defines how often [SourceARP] runtime client // arpClientsUpdatePeriod defines how often [SourceARP] runtime client
// information is updated. It must be greater than zero. // information is updated. It must be greater than zero.
arpClientsUpdatePeriod time.Duration arpClientsUpdatePeriod time.Duration
@ -143,8 +146,11 @@ type Storage struct {
// NewStorage returns initialized client storage. conf must not be nil. // NewStorage returns initialized client storage. conf must not be nil.
func NewStorage(conf *StorageConfig) (s *Storage, err error) { func NewStorage(conf *StorageConfig) (s *Storage, err error) {
tags := slices.Clone(allowedTags)
slices.Sort(tags)
s = &Storage{ s = &Storage{
allowedTags: container.NewMapSet(allowedTags...), allowedTags: tags,
mu: &sync.Mutex{}, mu: &sync.Mutex{},
index: newIndex(), index: newIndex(),
runtimeIndex: newRuntimeIndex(), runtimeIndex: newRuntimeIndex(),
@ -576,7 +582,8 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.runtimeIndex.rangeClients(f) s.runtimeIndex.rangeClients(f)
} }
// AllowedTags returns the list of available client tags. // AllowedTags returns the list of available client tags. tags must not be
// modified.
func (s *Storage) AllowedTags() (tags []string) { func (s *Storage) AllowedTags() (tags []string) {
return allowedTags return s.allowedTags
} }

View File

@ -4,6 +4,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
"slices"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -536,7 +537,7 @@ func TestStorage_Add(t *testing.T) {
existingName = "existing_name" existingName = "existing_name"
existingClientID = "existing_client_id" existingClientID = "existing_client_id"
allowedTag = "tag" allowedTag = "user_admin"
notAllowedTag = "not_allowed_tag" notAllowedTag = "not_allowed_tag"
) )
@ -557,6 +558,16 @@ func TestStorage_Add(t *testing.T) {
s, err := client.NewStorage(&client.StorageConfig{}) s, err := client.NewStorage(&client.StorageConfig{})
require.NoError(t, err) require.NoError(t, err)
tags := s.AllowedTags()
require.NotZero(t, len(tags))
require.True(t, slices.IsSorted(tags))
_, ok := slices.BinarySearch(tags, allowedTag)
require.True(t, ok)
_, ok = slices.BinarySearch(tags, notAllowedTag)
require.False(t, ok)
err = s.Add(existingClient) err = s.Add(existingClient)
require.NoError(t, err) require.NoError(t, err)
@ -617,12 +628,21 @@ func TestStorage_Add(t *testing.T) {
}, { }, {
name: "not_allowed_tag", name: "not_allowed_tag",
cli: &client.Persistent{ cli: &client.Persistent{
Name: "nont_allowed_tag", Name: "not_allowed_tag",
Tags: []string{notAllowedTag}, Tags: []string{notAllowedTag},
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")}, IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
UID: client.MustNewUID(), UID: client.MustNewUID(),
}, },
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`, wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
}, {
name: "allowed_tag",
cli: &client.Persistent{
Name: "allowed_tag",
Tags: []string{allowedTag},
IPs: []netip.Addr{netip.MustParseAddr("5.5.5.5")},
UID: client.MustNewUID(),
},
wantErrMsg: "",
}} }}
for _, tc := range testCases { for _, tc := range testCases {