client: add tests
This commit is contained in:
parent
702467f7ca
commit
045b838823
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/container"
|
"github.com/AdguardTeam/golibs/container"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -98,6 +99,39 @@ type Persistent struct {
|
||||||
SafeSearchConf filtering.SafeSearchConfig
|
SafeSearchConf filtering.SafeSearchConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate returns an error if persistent client information contains errors.
|
||||||
|
func (c *Persistent) Validate(allTags *container.MapSet[string]) (err error) {
|
||||||
|
switch {
|
||||||
|
case c.Name == "":
|
||||||
|
return errors.Error("empty name")
|
||||||
|
case c.IDsLen() == 0:
|
||||||
|
return errors.Error("id required")
|
||||||
|
case c.UID == UID{}:
|
||||||
|
return errors.Error("uid required")
|
||||||
|
}
|
||||||
|
|
||||||
|
conf, err := proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid upstream servers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conf.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("client: closing upstream config: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range c.Tags {
|
||||||
|
if !allTags.Has(t) {
|
||||||
|
return fmt.Errorf("invalid tag: %q", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(s.chzhen): Move to the constructor.
|
||||||
|
slices.Sort(c.Tags)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
|
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
|
||||||
func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) {
|
func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) {
|
||||||
for _, t := range tags {
|
for _, t := range tags {
|
||||||
|
|
|
@ -1,13 +1,15 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPersistentClient_EqualIDs(t *testing.T) {
|
func TestPersistent_EqualIDs(t *testing.T) {
|
||||||
const (
|
const (
|
||||||
ip = "0.0.0.0"
|
ip = "0.0.0.0"
|
||||||
ip1 = "1.1.1.1"
|
ip1 = "1.1.1.1"
|
||||||
|
@ -122,3 +124,50 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPersistent_Validate(t *testing.T) {
|
||||||
|
// TODO(s.chzhen): Add test cases.
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
cli *Persistent
|
||||||
|
wantErrMsg string
|
||||||
|
}{{
|
||||||
|
name: "basic",
|
||||||
|
cli: &Persistent{
|
||||||
|
Name: "basic",
|
||||||
|
IPs: []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.2.3.4"),
|
||||||
|
},
|
||||||
|
UID: MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}, {
|
||||||
|
name: "empty_name",
|
||||||
|
cli: &Persistent{
|
||||||
|
Name: "",
|
||||||
|
},
|
||||||
|
wantErrMsg: "empty name",
|
||||||
|
}, {
|
||||||
|
name: "no_id",
|
||||||
|
cli: &Persistent{
|
||||||
|
Name: "no_id",
|
||||||
|
},
|
||||||
|
wantErrMsg: "id required",
|
||||||
|
}, {
|
||||||
|
name: "no_uid",
|
||||||
|
cli: &Persistent{
|
||||||
|
Name: "no_uid",
|
||||||
|
IPs: []netip.Addr{
|
||||||
|
netip.MustParseAddr("1.2.3.4"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErrMsg: "uid required",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
err := tc.cli.Validate(nil)
|
||||||
|
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,23 +1,14 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
|
||||||
"github.com/AdguardTeam/golibs/container"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Storage contains information about persistent and runtime clients.
|
// Storage contains information about persistent and runtime clients.
|
||||||
type Storage struct {
|
type Storage struct {
|
||||||
// allTags is a set of all client tags.
|
|
||||||
allTags *container.MapSet[string]
|
|
||||||
|
|
||||||
// mu protects index of persistent clients.
|
// mu protects index of persistent clients.
|
||||||
mu *sync.Mutex
|
mu *sync.Mutex
|
||||||
|
|
||||||
|
@ -29,11 +20,8 @@ type Storage struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStorage returns initialized client storage.
|
// NewStorage returns initialized client storage.
|
||||||
func NewStorage(clientTags []string) (s *Storage) {
|
func NewStorage() (s *Storage) {
|
||||||
allTags := container.NewMapSet(clientTags...)
|
|
||||||
|
|
||||||
return &Storage{
|
return &Storage{
|
||||||
allTags: allTags,
|
|
||||||
mu: &sync.Mutex{},
|
mu: &sync.Mutex{},
|
||||||
index: NewIndex(),
|
index: NewIndex(),
|
||||||
runtimeIndex: map[netip.Addr]*Runtime{},
|
runtimeIndex: map[netip.Addr]*Runtime{},
|
||||||
|
@ -41,60 +29,26 @@ func NewStorage(clientTags []string) (s *Storage) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add stores persistent client information or returns an error. p must be
|
// Add stores persistent client information or returns an error. p must be
|
||||||
// valid persistent client.
|
// valid persistent client. See [Persistent.Validate].
|
||||||
func (s *Storage) Add(p *Persistent) (err error) {
|
func (s *Storage) Add(p *Persistent) (err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "adding client: %w") }()
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
err = s.check(p)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("adding client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.index.Add(p)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// check returns an error if persistent client information contains errors.
|
|
||||||
//
|
|
||||||
// TODO(s.chzhen): Remove persistent client information validation.
|
|
||||||
func (s *Storage) check(p *Persistent) (err error) {
|
|
||||||
switch {
|
|
||||||
case p == nil:
|
|
||||||
return errors.Error("client is nil")
|
|
||||||
case p.Name == "":
|
|
||||||
return errors.Error("empty name")
|
|
||||||
case p.IDsLen() == 0:
|
|
||||||
return errors.Error("id required")
|
|
||||||
case p.UID == UID{}:
|
|
||||||
return errors.Error("uid required")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.index.ClashesUID(p)
|
err = s.index.ClashesUID(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error since there is already an annotation deferred.
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
conf, err := proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{})
|
err = s.index.Clashes(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid upstream servers: %w", err)
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = conf.Close()
|
s.index.Add(p)
|
||||||
if err != nil {
|
|
||||||
log.Error("client: closing upstream config: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, t := range p.Tags {
|
|
||||||
if !s.allTags.Has(t) {
|
|
||||||
return fmt.Errorf("invalid tag: %q", t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(s.chzhen): Move to the constructor.
|
|
||||||
slices.Sort(p.Tags)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -117,7 +71,6 @@ func (s *Storage) RemoveByName(name string) (ok bool) {
|
||||||
func (s *Storage) Update(p, n *Persistent) (err error) {
|
func (s *Storage) Update(p, n *Persistent) (err error) {
|
||||||
defer func() { err = errors.Annotate(err, "updating client: %w") }()
|
defer func() { err = errors.Annotate(err, "updating client: %w") }()
|
||||||
|
|
||||||
err = s.check(n)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error since there is already an annotation deferred.
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -6,10 +6,34 @@ import (
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStorage_Add(t *testing.T) {
|
func TestStorage_Add(t *testing.T) {
|
||||||
|
const (
|
||||||
|
existingName = "existing_name"
|
||||||
|
existingClientID = "existing_client_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
existingClientUID = client.MustNewUID()
|
||||||
|
existingIP = netip.MustParseAddr("1.2.3.4")
|
||||||
|
existingSubnet = netip.MustParsePrefix("1.2.3.0/24")
|
||||||
|
)
|
||||||
|
|
||||||
|
existingClient := &client.Persistent{
|
||||||
|
Name: existingName,
|
||||||
|
IPs: []netip.Addr{existingIP},
|
||||||
|
Subnets: []netip.Prefix{existingSubnet},
|
||||||
|
ClientIDs: []string{existingClientID},
|
||||||
|
UID: existingClientUID,
|
||||||
|
}
|
||||||
|
|
||||||
|
s := client.NewStorage()
|
||||||
|
err := s.Add(existingClient)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
cli *client.Persistent
|
cli *client.Persistent
|
||||||
|
@ -18,68 +42,104 @@ func TestStorage_Add(t *testing.T) {
|
||||||
name: "basic",
|
name: "basic",
|
||||||
cli: &client.Persistent{
|
cli: &client.Persistent{
|
||||||
Name: "basic",
|
Name: "basic",
|
||||||
IPs: []netip.Addr{
|
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
|
||||||
netip.MustParseAddr("1.2.3.4"),
|
|
||||||
},
|
|
||||||
UID: client.MustNewUID(),
|
UID: client.MustNewUID(),
|
||||||
},
|
},
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
}, {
|
}, {
|
||||||
name: "nil",
|
name: "duplicate_uid",
|
||||||
cli: nil,
|
|
||||||
wantErrMsg: "adding client: client is nil",
|
|
||||||
}, {
|
|
||||||
name: "empty_name",
|
|
||||||
cli: &client.Persistent{
|
|
||||||
Name: "",
|
|
||||||
},
|
|
||||||
wantErrMsg: "adding client: empty name",
|
|
||||||
}, {
|
|
||||||
name: "no_id",
|
|
||||||
cli: &client.Persistent{
|
|
||||||
Name: "no_id",
|
|
||||||
},
|
|
||||||
wantErrMsg: "adding client: id required",
|
|
||||||
}, {
|
|
||||||
name: "no_uid",
|
|
||||||
cli: &client.Persistent{
|
cli: &client.Persistent{
|
||||||
Name: "no_uid",
|
Name: "no_uid",
|
||||||
IPs: []netip.Addr{
|
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||||
netip.MustParseAddr("1.2.3.4"),
|
UID: existingClientUID,
|
||||||
},
|
},
|
||||||
|
wantErrMsg: `adding client: another client "existing_name" uses the same uid`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_name",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: existingName,
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
},
|
},
|
||||||
wantErrMsg: "adding client: uid required",
|
wantErrMsg: `adding client: another client uses the same name "existing_name"`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_ip",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "duplicate_ip",
|
||||||
|
IPs: []netip.Addr{existingIP},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `adding client: another client "existing_name" uses the same IP "1.2.3.4"`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_subnet",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "duplicate_subnet",
|
||||||
|
Subnets: []netip.Prefix{existingSubnet},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `adding client: another client "existing_name" ` +
|
||||||
|
`uses the same subnet "1.2.3.0/24"`,
|
||||||
|
}, {
|
||||||
|
name: "duplicate_client_id",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "duplicate_client_id",
|
||||||
|
ClientIDs: []string{existingClientID},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: `adding client: another client "existing_name" ` +
|
||||||
|
`uses the same ClientID "existing_client_id"`,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
s := client.NewStorage(nil)
|
err = s.Add(tc.cli)
|
||||||
err := s.Add(tc.cli)
|
|
||||||
|
|
||||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
t.Run("duplicate_uid", func(t *testing.T) {
|
func TestStorage_RemoveByName(t *testing.T) {
|
||||||
sameUID := client.MustNewUID()
|
const (
|
||||||
s := client.NewStorage(nil)
|
existingName = "existing_name"
|
||||||
|
)
|
||||||
|
|
||||||
cli1 := &client.Persistent{
|
existingClient := &client.Persistent{
|
||||||
Name: "cli1",
|
Name: existingName,
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
|
||||||
UID: sameUID,
|
UID: client.MustNewUID(),
|
||||||
}
|
}
|
||||||
|
|
||||||
cli2 := &client.Persistent{
|
s := client.NewStorage()
|
||||||
Name: "cli2",
|
err := s.Add(existingClient)
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("4.3.2.1")},
|
|
||||||
UID: sameUID,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.Add(cli1)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
err = s.Add(cli2)
|
testCases := []struct {
|
||||||
testutil.AssertErrorMsg(t, `adding client: another client "cli1" uses the same uid`, err)
|
want assert.BoolAssertionFunc
|
||||||
|
name string
|
||||||
|
cliName string
|
||||||
|
}{{
|
||||||
|
name: "existing_client",
|
||||||
|
cliName: existingName,
|
||||||
|
want: assert.True,
|
||||||
|
}, {
|
||||||
|
name: "non_existing_client",
|
||||||
|
cliName: "non_existing_client",
|
||||||
|
want: assert.False,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
tc.want(t, s.RemoveByName(tc.cliName))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("duplicate_remove", func(t *testing.T) {
|
||||||
|
s = client.NewStorage()
|
||||||
|
err = s.Add(existingClient)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, s.RemoveByName(existingName))
|
||||||
|
assert.False(t, s.RemoveByName(existingName))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue