client: add tests

This commit is contained in:
Stanislav Chzhen 2024-06-18 15:18:08 +03:00
parent 702467f7ca
commit 045b838823
4 changed files with 195 additions and 99 deletions

View File

@ -12,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -98,6 +99,39 @@ type Persistent struct {
SafeSearchConf filtering.SafeSearchConfig SafeSearchConf filtering.SafeSearchConfig
} }
// Validate returns an error if persistent client information contains errors.
func (c *Persistent) Validate(allTags *container.MapSet[string]) (err error) {
switch {
case c.Name == "":
return errors.Error("empty name")
case c.IDsLen() == 0:
return errors.Error("id required")
case c.UID == UID{}:
return errors.Error("uid required")
}
conf, err := proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("invalid upstream servers: %w", err)
}
err = conf.Close()
if err != nil {
log.Error("client: closing upstream config: %s", err)
}
for _, t := range c.Tags {
if !allTags.Has(t) {
return fmt.Errorf("invalid tag: %q", t)
}
}
// TODO(s.chzhen): Move to the constructor.
slices.Sort(c.Tags)
return nil
}
// SetTags sets the tags if they are known, otherwise logs an unknown tag. // SetTags sets the tags if they are known, otherwise logs an unknown tag.
func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) { func (c *Persistent) SetTags(tags []string, known *container.MapSet[string]) {
for _, t := range tags { for _, t := range tags {

View File

@ -1,13 +1,15 @@
package client package client
import ( import (
"net/netip"
"testing" "testing"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestPersistentClient_EqualIDs(t *testing.T) { func TestPersistent_EqualIDs(t *testing.T) {
const ( const (
ip = "0.0.0.0" ip = "0.0.0.0"
ip1 = "1.1.1.1" ip1 = "1.1.1.1"
@ -122,3 +124,50 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
}) })
} }
} }
func TestPersistent_Validate(t *testing.T) {
// TODO(s.chzhen): Add test cases.
testCases := []struct {
name string
cli *Persistent
wantErrMsg string
}{{
name: "basic",
cli: &Persistent{
Name: "basic",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
UID: MustNewUID(),
},
wantErrMsg: "",
}, {
name: "empty_name",
cli: &Persistent{
Name: "",
},
wantErrMsg: "empty name",
}, {
name: "no_id",
cli: &Persistent{
Name: "no_id",
},
wantErrMsg: "id required",
}, {
name: "no_uid",
cli: &Persistent{
Name: "no_uid",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
},
wantErrMsg: "uid required",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.cli.Validate(nil)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}

View File

@ -1,23 +1,14 @@
package client package client
import ( import (
"fmt"
"net/netip" "net/netip"
"slices"
"sync" "sync"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
) )
// Storage contains information about persistent and runtime clients. // Storage contains information about persistent and runtime clients.
type Storage struct { type Storage struct {
// allTags is a set of all client tags.
allTags *container.MapSet[string]
// mu protects index of persistent clients. // mu protects index of persistent clients.
mu *sync.Mutex mu *sync.Mutex
@ -29,11 +20,8 @@ type Storage struct {
} }
// NewStorage returns initialized client storage. // NewStorage returns initialized client storage.
func NewStorage(clientTags []string) (s *Storage) { func NewStorage() (s *Storage) {
allTags := container.NewMapSet(clientTags...)
return &Storage{ return &Storage{
allTags: allTags,
mu: &sync.Mutex{}, mu: &sync.Mutex{},
index: NewIndex(), index: NewIndex(),
runtimeIndex: map[netip.Addr]*Runtime{}, runtimeIndex: map[netip.Addr]*Runtime{},
@ -41,60 +29,26 @@ func NewStorage(clientTags []string) (s *Storage) {
} }
// Add stores persistent client information or returns an error. p must be // Add stores persistent client information or returns an error. p must be
// valid persistent client. // valid persistent client. See [Persistent.Validate].
func (s *Storage) Add(p *Persistent) (err error) { func (s *Storage) Add(p *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "adding client: %w") }()
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
err = s.check(p)
if err != nil {
return fmt.Errorf("adding client: %w", err)
}
s.index.Add(p)
return nil
}
// check returns an error if persistent client information contains errors.
//
// TODO(s.chzhen): Remove persistent client information validation.
func (s *Storage) check(p *Persistent) (err error) {
switch {
case p == nil:
return errors.Error("client is nil")
case p.Name == "":
return errors.Error("empty name")
case p.IDsLen() == 0:
return errors.Error("id required")
case p.UID == UID{}:
return errors.Error("uid required")
}
err = s.index.ClashesUID(p) err = s.index.ClashesUID(p)
if err != nil { if err != nil {
// Don't wrap the error since there is already an annotation deferred. // Don't wrap the error since there is already an annotation deferred.
return err return err
} }
conf, err := proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{}) err = s.index.Clashes(p)
if err != nil { if err != nil {
return fmt.Errorf("invalid upstream servers: %w", err) // Don't wrap the error since there is already an annotation deferred.
return err
} }
err = conf.Close() s.index.Add(p)
if err != nil {
log.Error("client: closing upstream config: %s", err)
}
for _, t := range p.Tags {
if !s.allTags.Has(t) {
return fmt.Errorf("invalid tag: %q", t)
}
}
// TODO(s.chzhen): Move to the constructor.
slices.Sort(p.Tags)
return nil return nil
} }
@ -117,7 +71,6 @@ func (s *Storage) RemoveByName(name string) (ok bool) {
func (s *Storage) Update(p, n *Persistent) (err error) { func (s *Storage) Update(p, n *Persistent) (err error) {
defer func() { err = errors.Annotate(err, "updating client: %w") }() defer func() { err = errors.Annotate(err, "updating client: %w") }()
err = s.check(n)
if err != nil { if err != nil {
// Don't wrap the error since there is already an annotation deferred. // Don't wrap the error since there is already an annotation deferred.
return err return err

View File

@ -6,10 +6,34 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestStorage_Add(t *testing.T) { func TestStorage_Add(t *testing.T) {
const (
existingName = "existing_name"
existingClientID = "existing_client_id"
)
var (
existingClientUID = client.MustNewUID()
existingIP = netip.MustParseAddr("1.2.3.4")
existingSubnet = netip.MustParsePrefix("1.2.3.0/24")
)
existingClient := &client.Persistent{
Name: existingName,
IPs: []netip.Addr{existingIP},
Subnets: []netip.Prefix{existingSubnet},
ClientIDs: []string{existingClientID},
UID: existingClientUID,
}
s := client.NewStorage()
err := s.Add(existingClient)
require.NoError(t, err)
testCases := []struct { testCases := []struct {
name string name string
cli *client.Persistent cli *client.Persistent
@ -18,68 +42,104 @@ func TestStorage_Add(t *testing.T) {
name: "basic", name: "basic",
cli: &client.Persistent{ cli: &client.Persistent{
Name: "basic", Name: "basic",
IPs: []netip.Addr{ IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")},
netip.MustParseAddr("1.2.3.4"),
},
UID: client.MustNewUID(), UID: client.MustNewUID(),
}, },
wantErrMsg: "", wantErrMsg: "",
}, { }, {
name: "nil", name: "duplicate_uid",
cli: nil,
wantErrMsg: "adding client: client is nil",
}, {
name: "empty_name",
cli: &client.Persistent{
Name: "",
},
wantErrMsg: "adding client: empty name",
}, {
name: "no_id",
cli: &client.Persistent{
Name: "no_id",
},
wantErrMsg: "adding client: id required",
}, {
name: "no_uid",
cli: &client.Persistent{ cli: &client.Persistent{
Name: "no_uid", Name: "no_uid",
IPs: []netip.Addr{ IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
netip.MustParseAddr("1.2.3.4"), UID: existingClientUID,
}, },
wantErrMsg: `adding client: another client "existing_name" uses the same uid`,
}, {
name: "duplicate_name",
cli: &client.Persistent{
Name: existingName,
IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")},
UID: client.MustNewUID(),
}, },
wantErrMsg: "adding client: uid required", wantErrMsg: `adding client: another client uses the same name "existing_name"`,
}, {
name: "duplicate_ip",
cli: &client.Persistent{
Name: "duplicate_ip",
IPs: []netip.Addr{existingIP},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client "existing_name" uses the same IP "1.2.3.4"`,
}, {
name: "duplicate_subnet",
cli: &client.Persistent{
Name: "duplicate_subnet",
Subnets: []netip.Prefix{existingSubnet},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client "existing_name" ` +
`uses the same subnet "1.2.3.0/24"`,
}, {
name: "duplicate_client_id",
cli: &client.Persistent{
Name: "duplicate_client_id",
ClientIDs: []string{existingClientID},
UID: client.MustNewUID(),
},
wantErrMsg: `adding client: another client "existing_name" ` +
`uses the same ClientID "existing_client_id"`,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := client.NewStorage(nil) err = s.Add(tc.cli)
err := s.Add(tc.cli)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err) testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
}) })
} }
}
t.Run("duplicate_uid", func(t *testing.T) { func TestStorage_RemoveByName(t *testing.T) {
sameUID := client.MustNewUID() const (
s := client.NewStorage(nil) existingName = "existing_name"
)
cli1 := &client.Persistent{ existingClient := &client.Persistent{
Name: "cli1", Name: existingName,
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
UID: sameUID, UID: client.MustNewUID(),
} }
cli2 := &client.Persistent{ s := client.NewStorage()
Name: "cli2", err := s.Add(existingClient)
IPs: []netip.Addr{netip.MustParseAddr("4.3.2.1")},
UID: sameUID,
}
err := s.Add(cli1)
require.NoError(t, err) require.NoError(t, err)
err = s.Add(cli2) testCases := []struct {
testutil.AssertErrorMsg(t, `adding client: another client "cli1" uses the same uid`, err) want assert.BoolAssertionFunc
name string
cliName string
}{{
name: "existing_client",
cliName: existingName,
want: assert.True,
}, {
name: "non_existing_client",
cliName: "non_existing_client",
want: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.want(t, s.RemoveByName(tc.cliName))
})
}
t.Run("duplicate_remove", func(t *testing.T) {
s = client.NewStorage()
err = s.Add(existingClient)
require.NoError(t, err)
assert.True(t, s.RemoveByName(existingName))
assert.False(t, s.RemoveByName(existingName))
}) })
} }