client: add tests

This commit is contained in:
Stanislav Chzhen 2024-06-13 15:19:36 +03:00
parent 5601cfce39
commit e268abf926
2 changed files with 105 additions and 1 deletions

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"slices" "slices"
"sync"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
@ -17,6 +18,9 @@ type Storage struct {
// allTags is a set of all client tags. // allTags is a set of all client tags.
allTags *container.MapSet[string] allTags *container.MapSet[string]
// mu protects index of persistent clients.
mu *sync.Mutex
// index contains information about persistent clients. // index contains information about persistent clients.
index *Index index *Index
@ -30,13 +34,18 @@ func NewStorage(clientTags []string) (s *Storage) {
return &Storage{ return &Storage{
allTags: allTags, allTags: allTags,
mu: &sync.Mutex{},
index: NewIndex(), index: NewIndex(),
runtimeIndex: map[netip.Addr]*Runtime{}, runtimeIndex: map[netip.Addr]*Runtime{},
} }
} }
// Add stores persistent client information or returns an error. // Add stores persistent client information or returns an error. p must be
// valid persistent client.
func (s *Storage) Add(p *Persistent) (err error) { func (s *Storage) Add(p *Persistent) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
err = s.check(p) err = s.check(p)
if err != nil { if err != nil {
return fmt.Errorf("adding client: %w", err) return fmt.Errorf("adding client: %w", err)
@ -48,6 +57,8 @@ func (s *Storage) Add(p *Persistent) (err error) {
} }
// check returns an error if persistent client information contains errors. // check returns an error if persistent client information contains errors.
//
// TODO(s.chzhen): Remove persistent client information validation.
func (s *Storage) check(p *Persistent) (err error) { func (s *Storage) check(p *Persistent) (err error) {
switch { switch {
case p == nil: case p == nil:
@ -56,6 +67,14 @@ func (s *Storage) check(p *Persistent) (err error) {
return errors.Error("empty name") return errors.Error("empty name")
case p.IDsLen() == 0: case p.IDsLen() == 0:
return errors.Error("id required") return errors.Error("id required")
case p.UID == UID{}:
return errors.Error("uid required")
}
err = s.index.ClashesUID(p)
if err != nil {
// Don't wrap the error since there is already an annotation deferred.
return err
} }
conf, err := proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{}) conf, err := proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{})

View File

@ -0,0 +1,85 @@
package client_test
import (
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/require"
)
func TestStorage_Add(t *testing.T) {
testCases := []struct {
name string
cli *client.Persistent
wantErrMsg string
}{{
name: "basic",
cli: &client.Persistent{
Name: "basic",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
UID: client.MustNewUID(),
},
wantErrMsg: "",
}, {
name: "nil",
cli: nil,
wantErrMsg: "adding client: client is nil",
}, {
name: "empty_name",
cli: &client.Persistent{
Name: "",
},
wantErrMsg: "adding client: empty name",
}, {
name: "no_id",
cli: &client.Persistent{
Name: "no_id",
},
wantErrMsg: "adding client: id required",
}, {
name: "no_uid",
cli: &client.Persistent{
Name: "no_uid",
IPs: []netip.Addr{
netip.MustParseAddr("1.2.3.4"),
},
},
wantErrMsg: "adding client: uid required",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := client.NewStorage(nil)
err := s.Add(tc.cli)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
t.Run("duplicate_uid", func(t *testing.T) {
sameUID := client.MustNewUID()
s := client.NewStorage(nil)
cli1 := &client.Persistent{
Name: "cli1",
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
UID: sameUID,
}
cli2 := &client.Persistent{
Name: "cli2",
IPs: []netip.Addr{netip.MustParseAddr("4.3.2.1")},
UID: sameUID,
}
err := s.Add(cli1)
require.NoError(t, err)
err = s.Add(cli2)
testutil.AssertErrorMsg(t, `adding client: another client "cli1" uses the same uid`, err)
})
}