Pull request 1933: upd-golibs
Squashed commit of the following: commit 081d10e6909def3a075707e75dbd0c5f63f91903 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Jul 20 14:17:01 2023 +0300 aghnet: fix docs commit 7433b72c0653cb33fe5ff810ae8a1346a6994f95 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Jul 20 14:03:16 2023 +0300 all: imp tests; upd golibs
This commit is contained in:
parent
4e8d3d7628
commit
5be0e84719
2
go.mod
2
go.mod
|
@ -4,7 +4,7 @@ go 1.19
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/dnsproxy v0.52.0
|
github.com/AdguardTeam/dnsproxy v0.52.0
|
||||||
github.com/AdguardTeam/golibs v0.13.4
|
github.com/AdguardTeam/golibs v0.13.5
|
||||||
github.com/AdguardTeam/urlfilter v0.16.1
|
github.com/AdguardTeam/urlfilter v0.16.1
|
||||||
github.com/NYTimes/gziphandler v1.1.1
|
github.com/NYTimes/gziphandler v1.1.1
|
||||||
github.com/ameshkov/dnscrypt/v2 v2.2.7
|
github.com/ameshkov/dnscrypt/v2 v2.2.7
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -2,8 +2,8 @@ github.com/AdguardTeam/dnsproxy v0.52.0 h1:uZxCXflHSAwtJ7uTYXP6qgWcxaBsH0pJvldpw
|
||||||
github.com/AdguardTeam/dnsproxy v0.52.0/go.mod h1:Jo2zeRe97Rxt3yikXc+fn0LdLtqCj0Xlyh1PNBj6bpM=
|
github.com/AdguardTeam/dnsproxy v0.52.0/go.mod h1:Jo2zeRe97Rxt3yikXc+fn0LdLtqCj0Xlyh1PNBj6bpM=
|
||||||
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||||
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
|
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
|
||||||
github.com/AdguardTeam/golibs v0.13.4 h1:ACTwIR1pEENBijHcEWtiMbSh4wWQOlIHRxmUB8oBHf8=
|
github.com/AdguardTeam/golibs v0.13.5 h1:fpa30Yr9Rcn4vJ88nE4XHSompY7/qMOq2aNS/4PGymA=
|
||||||
github.com/AdguardTeam/golibs v0.13.4/go.mod h1:wkJ6EUsN4np/9Gp7+9QeooY9E2U2WCLJYAioLCzkHsI=
|
github.com/AdguardTeam/golibs v0.13.5/go.mod h1:wkJ6EUsN4np/9Gp7+9QeooY9E2U2WCLJYAioLCzkHsI=
|
||||||
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
||||||
github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw=
|
github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw=
|
||||||
github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=
|
github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
package aghio
|
package aghio_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -31,7 +32,7 @@ func TestLimitReader(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
_, err := LimitReader(nil, tc.n)
|
_, err := aghio.LimitReader(nil, tc.n)
|
||||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -57,7 +58,7 @@ func TestLimitedReader_Read(t *testing.T) {
|
||||||
limit: 3,
|
limit: 3,
|
||||||
want: 0,
|
want: 0,
|
||||||
}, {
|
}, {
|
||||||
err: &LimitReachedError{
|
err: &aghio.LimitReachedError{
|
||||||
Limit: 0,
|
Limit: 0,
|
||||||
},
|
},
|
||||||
name: "limit_reached",
|
name: "limit_reached",
|
||||||
|
@ -74,7 +75,7 @@ func TestLimitedReader_Read(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
|
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
|
||||||
lreader, err := LimitReader(readCloser, tc.limit)
|
lreader, err := aghio.LimitReader(readCloser, tc.limit)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, lreader)
|
require.NotNil(t, lreader)
|
||||||
|
|
||||||
|
@ -89,7 +90,7 @@ func TestLimitedReader_Read(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLimitedReader_LimitReachedError(t *testing.T) {
|
func TestLimitedReader_LimitReachedError(t *testing.T) {
|
||||||
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &LimitReachedError{
|
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &aghio.LimitReachedError{
|
||||||
Limit: 0,
|
Limit: 0,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -141,9 +141,9 @@ type HostsRecord struct {
|
||||||
Canonical string
|
Canonical string
|
||||||
}
|
}
|
||||||
|
|
||||||
// equal returns true if all fields of rec are equal to field in other or they
|
// Equal returns true if all fields of rec are equal to field in other or they
|
||||||
// both are nil.
|
// both are nil.
|
||||||
func (rec *HostsRecord) equal(other *HostsRecord) (ok bool) {
|
func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) {
|
||||||
if rec == nil {
|
if rec == nil {
|
||||||
return other == nil
|
return other == nil
|
||||||
} else if other == nil {
|
} else if other == nil {
|
||||||
|
@ -495,7 +495,7 @@ func (hc *HostsContainer) refresh() (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// hc.last is nil on the first refresh, so let that one through.
|
// hc.last is nil on the first refresh, so let that one through.
|
||||||
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) {
|
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).Equal) {
|
||||||
log.Debug("%s: no changes detected", hostsContainerPrefix)
|
log.Debug("%s: no changes detected", hostsContainerPrefix)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/fs"
|
||||||
|
"net/netip"
|
||||||
|
"path"
|
||||||
|
"testing"
|
||||||
|
"testing/fstest"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil/fakefs"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const nl = "\n"
|
||||||
|
|
||||||
|
func TestHostsContainer_PathsToPatterns(t *testing.T) {
|
||||||
|
gsfs := fstest.MapFS{
|
||||||
|
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
|
||||||
|
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
|
||||||
|
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
paths []string
|
||||||
|
want []string
|
||||||
|
}{{
|
||||||
|
name: "no_paths",
|
||||||
|
paths: nil,
|
||||||
|
want: nil,
|
||||||
|
}, {
|
||||||
|
name: "single_file",
|
||||||
|
paths: []string{"dir_0/file_1"},
|
||||||
|
want: []string{"dir_0/file_1"},
|
||||||
|
}, {
|
||||||
|
name: "several_files",
|
||||||
|
paths: []string{"dir_0/file_1", "dir_0/file_2"},
|
||||||
|
want: []string{"dir_0/file_1", "dir_0/file_2"},
|
||||||
|
}, {
|
||||||
|
name: "whole_dir",
|
||||||
|
paths: []string{"dir_0"},
|
||||||
|
want: []string{"dir_0/*"},
|
||||||
|
}, {
|
||||||
|
name: "file_and_dir",
|
||||||
|
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
|
||||||
|
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
|
||||||
|
}, {
|
||||||
|
name: "non-existing",
|
||||||
|
paths: []string{path.Join("dir_0", "file_3")},
|
||||||
|
want: nil,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
patterns, err := pathsToPatterns(gsfs, tc.paths)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.want, patterns)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("bad_file", func(t *testing.T) {
|
||||||
|
const errStat errors.Error = "bad file"
|
||||||
|
|
||||||
|
badFS := &fakefs.StatFS{
|
||||||
|
OnOpen: func(_ string) (f fs.File, err error) { panic("not implemented") },
|
||||||
|
OnStat: func(name string) (fi fs.FileInfo, err error) {
|
||||||
|
return nil, errStat
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := pathsToPatterns(badFS, []string{""})
|
||||||
|
assert.ErrorIs(t, err, errStat)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUniqueRules_ParseLine(t *testing.T) {
|
||||||
|
ip := netutil.IPv4Localhost()
|
||||||
|
ipStr := ip.String()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
line string
|
||||||
|
wantIP netip.Addr
|
||||||
|
wantHosts []string
|
||||||
|
}{{
|
||||||
|
name: "simple",
|
||||||
|
line: ipStr + ` hostname`,
|
||||||
|
wantIP: ip,
|
||||||
|
wantHosts: []string{"hostname"},
|
||||||
|
}, {
|
||||||
|
name: "aliases",
|
||||||
|
line: ipStr + ` hostname alias`,
|
||||||
|
wantIP: ip,
|
||||||
|
wantHosts: []string{"hostname", "alias"},
|
||||||
|
}, {
|
||||||
|
name: "invalid_line",
|
||||||
|
line: ipStr,
|
||||||
|
wantIP: netip.Addr{},
|
||||||
|
wantHosts: nil,
|
||||||
|
}, {
|
||||||
|
name: "invalid_line_hostname",
|
||||||
|
line: ipStr + ` # hostname`,
|
||||||
|
wantIP: ip,
|
||||||
|
wantHosts: nil,
|
||||||
|
}, {
|
||||||
|
name: "commented_aliases",
|
||||||
|
line: ipStr + ` hostname # alias`,
|
||||||
|
wantIP: ip,
|
||||||
|
wantHosts: []string{"hostname"},
|
||||||
|
}, {
|
||||||
|
name: "whole_comment",
|
||||||
|
line: `# ` + ipStr + ` hostname`,
|
||||||
|
wantIP: netip.Addr{},
|
||||||
|
wantHosts: nil,
|
||||||
|
}, {
|
||||||
|
name: "partial_comment",
|
||||||
|
line: ipStr + ` host#name`,
|
||||||
|
wantIP: ip,
|
||||||
|
wantHosts: []string{"host"},
|
||||||
|
}, {
|
||||||
|
name: "empty",
|
||||||
|
line: ``,
|
||||||
|
wantIP: netip.Addr{},
|
||||||
|
wantHosts: nil,
|
||||||
|
}, {
|
||||||
|
name: "bad_hosts",
|
||||||
|
line: ipStr + ` bad..host bad._tld empty.tld. ok.host`,
|
||||||
|
wantIP: ip,
|
||||||
|
wantHosts: []string{"ok.host"},
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
hp := hostsParser{}
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
got, hosts := hp.parseLine(tc.line)
|
||||||
|
assert.Equal(t, tc.wantIP, got)
|
||||||
|
assert.Equal(t, tc.wantHosts, hosts)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,9 +1,7 @@
|
||||||
package aghnet
|
package aghnet_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/fs"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -12,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
@ -24,10 +23,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const nl = "\n"
|
||||||
nl = "\n"
|
|
||||||
sp = " "
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewHostsContainer(t *testing.T) {
|
func TestNewHostsContainer(t *testing.T) {
|
||||||
const dirname = "dir"
|
const dirname = "dir"
|
||||||
|
@ -48,11 +44,11 @@ func TestNewHostsContainer(t *testing.T) {
|
||||||
name: "one_file",
|
name: "one_file",
|
||||||
paths: []string{p},
|
paths: []string{p},
|
||||||
}, {
|
}, {
|
||||||
wantErr: ErrNoHostsPaths,
|
wantErr: aghnet.ErrNoHostsPaths,
|
||||||
name: "no_files",
|
name: "no_files",
|
||||||
paths: []string{},
|
paths: []string{},
|
||||||
}, {
|
}, {
|
||||||
wantErr: ErrNoHostsPaths,
|
wantErr: aghnet.ErrNoHostsPaths,
|
||||||
name: "non-existent_file",
|
name: "non-existent_file",
|
||||||
paths: []string{path.Join(dirname, filename+"2")},
|
paths: []string{path.Join(dirname, filename+"2")},
|
||||||
}, {
|
}, {
|
||||||
|
@ -77,7 +73,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||||
return eventsCh
|
return eventsCh
|
||||||
}
|
}
|
||||||
|
|
||||||
hc, err := NewHostsContainer(0, testFS, &aghtest.FSWatcher{
|
hc, err := aghnet.NewHostsContainer(0, testFS, &aghtest.FSWatcher{
|
||||||
OnEvents: onEvents,
|
OnEvents: onEvents,
|
||||||
OnAdd: onAdd,
|
OnAdd: onAdd,
|
||||||
OnClose: func() (err error) { return nil },
|
OnClose: func() (err error) { return nil },
|
||||||
|
@ -103,7 +99,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||||
|
|
||||||
t.Run("nil_fs", func(t *testing.T) {
|
t.Run("nil_fs", func(t *testing.T) {
|
||||||
require.Panics(t, func() {
|
require.Panics(t, func() {
|
||||||
_, _ = NewHostsContainer(0, nil, &aghtest.FSWatcher{
|
_, _ = aghnet.NewHostsContainer(0, nil, &aghtest.FSWatcher{
|
||||||
// Those shouldn't panic.
|
// Those shouldn't panic.
|
||||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||||
OnAdd: func(name string) (err error) { return nil },
|
OnAdd: func(name string) (err error) { return nil },
|
||||||
|
@ -114,7 +110,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||||
|
|
||||||
t.Run("nil_watcher", func(t *testing.T) {
|
t.Run("nil_watcher", func(t *testing.T) {
|
||||||
require.Panics(t, func() {
|
require.Panics(t, func() {
|
||||||
_, _ = NewHostsContainer(0, testFS, nil, p)
|
_, _ = aghnet.NewHostsContainer(0, testFS, nil, p)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -127,7 +123,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||||
OnClose: func() (err error) { return nil },
|
OnClose: func() (err error) { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
hc, err := NewHostsContainer(0, testFS, errWatcher, p)
|
hc, err := aghnet.NewHostsContainer(0, testFS, errWatcher, p)
|
||||||
require.ErrorIs(t, err, errOnAdd)
|
require.ErrorIs(t, err, errOnAdd)
|
||||||
|
|
||||||
assert.Nil(t, hc)
|
assert.Nil(t, hc)
|
||||||
|
@ -158,11 +154,11 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||||
OnClose: func() (err error) { return nil },
|
OnClose: func() (err error) { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
hc, err := NewHostsContainer(0, testFS, w, "dir")
|
hc, err := aghnet.NewHostsContainer(0, testFS, w, "dir")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||||
|
|
||||||
checkRefresh := func(t *testing.T, want *HostsRecord) {
|
checkRefresh := func(t *testing.T, want *aghnet.HostsRecord) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
|
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
|
||||||
|
@ -175,11 +171,11 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.NotNil(t, rec)
|
require.NotNil(t, rec)
|
||||||
|
|
||||||
assert.Truef(t, rec.equal(want), "%+v != %+v", rec, want)
|
assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("initial_refresh", func(t *testing.T) {
|
t.Run("initial_refresh", func(t *testing.T) {
|
||||||
checkRefresh(t, &HostsRecord{
|
checkRefresh(t, &aghnet.HostsRecord{
|
||||||
Aliases: stringutil.NewSet(),
|
Aliases: stringutil.NewSet(),
|
||||||
Canonical: "hostname",
|
Canonical: "hostname",
|
||||||
})
|
})
|
||||||
|
@ -189,7 +185,7 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||||
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
|
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
|
||||||
eventsCh <- event{}
|
eventsCh <- event{}
|
||||||
|
|
||||||
checkRefresh(t, &HostsRecord{
|
checkRefresh(t, &aghnet.HostsRecord{
|
||||||
Aliases: stringutil.NewSet("alias"),
|
Aliases: stringutil.NewSet("alias"),
|
||||||
Canonical: "hostname",
|
Canonical: "hostname",
|
||||||
})
|
})
|
||||||
|
@ -228,66 +224,6 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHostsContainer_PathsToPatterns(t *testing.T) {
|
|
||||||
gsfs := fstest.MapFS{
|
|
||||||
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
|
|
||||||
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
|
|
||||||
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
paths []string
|
|
||||||
want []string
|
|
||||||
}{{
|
|
||||||
name: "no_paths",
|
|
||||||
paths: nil,
|
|
||||||
want: nil,
|
|
||||||
}, {
|
|
||||||
name: "single_file",
|
|
||||||
paths: []string{"dir_0/file_1"},
|
|
||||||
want: []string{"dir_0/file_1"},
|
|
||||||
}, {
|
|
||||||
name: "several_files",
|
|
||||||
paths: []string{"dir_0/file_1", "dir_0/file_2"},
|
|
||||||
want: []string{"dir_0/file_1", "dir_0/file_2"},
|
|
||||||
}, {
|
|
||||||
name: "whole_dir",
|
|
||||||
paths: []string{"dir_0"},
|
|
||||||
want: []string{"dir_0/*"},
|
|
||||||
}, {
|
|
||||||
name: "file_and_dir",
|
|
||||||
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
|
|
||||||
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
|
|
||||||
}, {
|
|
||||||
name: "non-existing",
|
|
||||||
paths: []string{path.Join("dir_0", "file_3")},
|
|
||||||
want: nil,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
patterns, err := pathsToPatterns(gsfs, tc.paths)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.want, patterns)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("bad_file", func(t *testing.T) {
|
|
||||||
const errStat errors.Error = "bad file"
|
|
||||||
|
|
||||||
badFS := &aghtest.StatFS{
|
|
||||||
OnStat: func(name string) (fs.FileInfo, error) {
|
|
||||||
return nil, errStat
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := pathsToPatterns(badFS, []string{""})
|
|
||||||
assert.ErrorIs(t, err, errStat)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHostsContainer_Translate(t *testing.T) {
|
func TestHostsContainer_Translate(t *testing.T) {
|
||||||
stubWatcher := aghtest.FSWatcher{
|
stubWatcher := aghtest.FSWatcher{
|
||||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||||
|
@ -297,7 +233,7 @@ func TestHostsContainer_Translate(t *testing.T) {
|
||||||
|
|
||||||
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
|
require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
|
||||||
|
|
||||||
hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
|
hc, err := aghnet.NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||||
|
|
||||||
|
@ -527,7 +463,7 @@ func TestHostsContainer(t *testing.T) {
|
||||||
OnClose: func() (err error) { return nil },
|
OnClose: func() (err error) { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
hc, err := NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
|
hc, err := aghnet.NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
testutil.CleanupAndRequireSuccess(t, hc.Close)
|
||||||
|
|
||||||
|
@ -558,69 +494,3 @@ func TestHostsContainer(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUniqueRules_ParseLine(t *testing.T) {
|
|
||||||
ip := netutil.IPv4Localhost()
|
|
||||||
ipStr := ip.String()
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
line string
|
|
||||||
wantIP netip.Addr
|
|
||||||
wantHosts []string
|
|
||||||
}{{
|
|
||||||
name: "simple",
|
|
||||||
line: ipStr + ` hostname`,
|
|
||||||
wantIP: ip,
|
|
||||||
wantHosts: []string{"hostname"},
|
|
||||||
}, {
|
|
||||||
name: "aliases",
|
|
||||||
line: ipStr + ` hostname alias`,
|
|
||||||
wantIP: ip,
|
|
||||||
wantHosts: []string{"hostname", "alias"},
|
|
||||||
}, {
|
|
||||||
name: "invalid_line",
|
|
||||||
line: ipStr,
|
|
||||||
wantIP: netip.Addr{},
|
|
||||||
wantHosts: nil,
|
|
||||||
}, {
|
|
||||||
name: "invalid_line_hostname",
|
|
||||||
line: ipStr + ` # hostname`,
|
|
||||||
wantIP: ip,
|
|
||||||
wantHosts: nil,
|
|
||||||
}, {
|
|
||||||
name: "commented_aliases",
|
|
||||||
line: ipStr + ` hostname # alias`,
|
|
||||||
wantIP: ip,
|
|
||||||
wantHosts: []string{"hostname"},
|
|
||||||
}, {
|
|
||||||
name: "whole_comment",
|
|
||||||
line: `# ` + ipStr + ` hostname`,
|
|
||||||
wantIP: netip.Addr{},
|
|
||||||
wantHosts: nil,
|
|
||||||
}, {
|
|
||||||
name: "partial_comment",
|
|
||||||
line: ipStr + ` host#name`,
|
|
||||||
wantIP: ip,
|
|
||||||
wantHosts: []string{"host"},
|
|
||||||
}, {
|
|
||||||
name: "empty",
|
|
||||||
line: ``,
|
|
||||||
wantIP: netip.Addr{},
|
|
||||||
wantHosts: nil,
|
|
||||||
}, {
|
|
||||||
name: "bad_hosts",
|
|
||||||
line: ipStr + ` bad..host bad._tld empty.tld. ok.host`,
|
|
||||||
wantIP: ip,
|
|
||||||
wantHosts: []string{"ok.host"},
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
hp := hostsParser{}
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
got, hosts := hp.parseLine(tc.line)
|
|
||||||
assert.Equal(t, tc.wantIP, got)
|
|
||||||
assert.Equal(t, tc.wantHosts, hosts)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package aghnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -15,6 +16,10 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DialContextFunc is the semantic alias for dialing functions, such as
|
||||||
|
// [http.Transport.DialContext].
|
||||||
|
type DialContextFunc = func(ctx context.Context, network, addr string) (conn net.Conn, err error)
|
||||||
|
|
||||||
// Variables and functions to substitute in tests.
|
// Variables and functions to substitute in tests.
|
||||||
var (
|
var (
|
||||||
// aghosRunCommand is the function to run shell commands.
|
// aghosRunCommand is the function to run shell commands.
|
||||||
|
|
|
@ -5,9 +5,9 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"testing/fstest"
|
"testing/fstest"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil/fakefs"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -118,7 +118,7 @@ func TestIfaceSetStaticIP(t *testing.T) {
|
||||||
Data: []byte(`nameserver 1.1.1.1`),
|
Data: []byte(`nameserver 1.1.1.1`),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
panicFsys := &aghtest.FS{
|
panicFsys := &fakefs.FS{
|
||||||
OnOpen: func(name string) (fs.File, error) { panic("not implemented") },
|
OnOpen: func(name string) (fs.File, error) { panic("not implemented") },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,334 @@
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io/fs"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testdata is the filesystem containing data for testing the package.
|
||||||
|
var testdata fs.FS = os.DirFS("./testdata")
|
||||||
|
|
||||||
|
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
|
||||||
|
// package with fsys for tests ran under t.
|
||||||
|
func substRootDirFS(t testing.TB, fsys fs.FS) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
prev := rootDirFS
|
||||||
|
t.Cleanup(func() { rootDirFS = prev })
|
||||||
|
rootDirFS = fsys
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunCmdFunc is the signature of aghos.RunCommand function.
|
||||||
|
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
|
||||||
|
|
||||||
|
// substShell replaces the the aghos.RunCommand function used throughout the
|
||||||
|
// package with rc for tests ran under t.
|
||||||
|
func substShell(t testing.TB, rc RunCmdFunc) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
prev := aghosRunCommand
|
||||||
|
t.Cleanup(func() { aghosRunCommand = prev })
|
||||||
|
aghosRunCommand = rc
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
|
||||||
|
// execution result. It's only needed to simplify testing.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
|
||||||
|
type mapShell map[string]struct {
|
||||||
|
err error
|
||||||
|
out string
|
||||||
|
code int
|
||||||
|
}
|
||||||
|
|
||||||
|
// theOnlyCmd returns mapShell that only handles a single command and arguments
|
||||||
|
// combination from cmd.
|
||||||
|
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
|
||||||
|
return mapShell{cmd: {code: code, out: out, err: err}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunCmd is a RunCmdFunc handled by s.
|
||||||
|
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
|
||||||
|
key := strings.Join(append([]string{cmd}, args...), " ")
|
||||||
|
ret, ok := s[key]
|
||||||
|
if !ok {
|
||||||
|
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret.code, []byte(ret.out), ret.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
|
||||||
|
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
|
||||||
|
|
||||||
|
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
|
||||||
|
// throughout the package with f for tests ran under t.
|
||||||
|
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
prev := netInterfaceAddrs
|
||||||
|
t.Cleanup(func() { netInterfaceAddrs = prev })
|
||||||
|
netInterfaceAddrs = f
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatewayIP(t *testing.T) {
|
||||||
|
const ifaceName = "ifaceName"
|
||||||
|
const cmd = "ip route show dev " + ifaceName
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
shell mapShell
|
||||||
|
want netip.Addr
|
||||||
|
name string
|
||||||
|
}{{
|
||||||
|
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
|
||||||
|
want: netip.MustParseAddr("1.2.3.4"),
|
||||||
|
name: "success_v4",
|
||||||
|
}, {
|
||||||
|
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
|
||||||
|
want: netip.MustParseAddr("::ffff"),
|
||||||
|
name: "success_v6",
|
||||||
|
}, {
|
||||||
|
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
|
||||||
|
want: netip.Addr{},
|
||||||
|
name: "bad_output",
|
||||||
|
}, {
|
||||||
|
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
|
||||||
|
want: netip.Addr{},
|
||||||
|
name: "err_runcmd",
|
||||||
|
}, {
|
||||||
|
shell: theOnlyCmd(cmd, 1, "", nil),
|
||||||
|
want: netip.Addr{},
|
||||||
|
name: "bad_code",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
substShell(t, tc.shell.RunCmd)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.want, GatewayIP(ifaceName))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInterfaceByIP(t *testing.T) {
|
||||||
|
ifaces, err := GetValidNetInterfacesForWeb()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, ifaces)
|
||||||
|
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
t.Run(iface.Name, func(t *testing.T) {
|
||||||
|
require.NotEmpty(t, iface.Addresses)
|
||||||
|
|
||||||
|
for _, ip := range iface.Addresses {
|
||||||
|
ifaceName := InterfaceByIP(ip)
|
||||||
|
require.Equal(t, iface.Name, ifaceName)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBroadcastFromIPNet(t *testing.T) {
|
||||||
|
known4 := netip.MustParseAddr("192.168.0.1")
|
||||||
|
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
|
||||||
|
|
||||||
|
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
pref netip.Prefix
|
||||||
|
want netip.Addr
|
||||||
|
name string
|
||||||
|
}{{
|
||||||
|
pref: netip.PrefixFrom(known4, 0),
|
||||||
|
want: fullBroadcast4,
|
||||||
|
name: "full",
|
||||||
|
}, {
|
||||||
|
pref: netip.PrefixFrom(known4, 20),
|
||||||
|
want: netip.MustParseAddr("192.168.15.255"),
|
||||||
|
name: "full",
|
||||||
|
}, {
|
||||||
|
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
|
||||||
|
want: known6,
|
||||||
|
name: "ipv6_no_mask",
|
||||||
|
}, {
|
||||||
|
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
|
||||||
|
want: known4,
|
||||||
|
name: "ipv4_no_mask",
|
||||||
|
}, {
|
||||||
|
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
||||||
|
want: fullBroadcast4,
|
||||||
|
name: "unspecified",
|
||||||
|
}, {
|
||||||
|
pref: netip.Prefix{},
|
||||||
|
want: netip.Addr{},
|
||||||
|
name: "invalid",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPort(t *testing.T) {
|
||||||
|
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
|
||||||
|
|
||||||
|
t.Run("tcp_bound", func(t *testing.T) {
|
||||||
|
l, err := net.Listen("tcp", laddr.String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||||
|
|
||||||
|
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
|
||||||
|
require.Equal(t, laddr.Addr(), ipp.Addr())
|
||||||
|
require.NotZero(t, ipp.Port())
|
||||||
|
|
||||||
|
err = CheckPort("tcp", ipp)
|
||||||
|
target := &net.OpError{}
|
||||||
|
require.ErrorAs(t, err, &target)
|
||||||
|
|
||||||
|
assert.Equal(t, "listen", target.Op)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("udp_bound", func(t *testing.T) {
|
||||||
|
conn, err := net.ListenPacket("udp", laddr.String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
testutil.CleanupAndRequireSuccess(t, conn.Close)
|
||||||
|
|
||||||
|
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
|
||||||
|
require.Equal(t, laddr.Addr(), ipp.Addr())
|
||||||
|
require.NotZero(t, ipp.Port())
|
||||||
|
|
||||||
|
err = CheckPort("udp", ipp)
|
||||||
|
target := &net.OpError{}
|
||||||
|
require.ErrorAs(t, err, &target)
|
||||||
|
|
||||||
|
assert.Equal(t, "listen", target.Op)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad_network", func(t *testing.T) {
|
||||||
|
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("can_bind", func(t *testing.T) {
|
||||||
|
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCollectAllIfacesAddrs(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
wantErrMsg string
|
||||||
|
addrs []net.Addr
|
||||||
|
wantAddrs []string
|
||||||
|
}{{
|
||||||
|
name: "success",
|
||||||
|
wantErrMsg: ``,
|
||||||
|
addrs: []net.Addr{&net.IPNet{
|
||||||
|
IP: net.IP{1, 2, 3, 4},
|
||||||
|
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
|
||||||
|
}, &net.IPNet{
|
||||||
|
IP: net.IP{4, 3, 2, 1},
|
||||||
|
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
|
||||||
|
}},
|
||||||
|
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
|
||||||
|
}, {
|
||||||
|
name: "not_cidr",
|
||||||
|
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
|
||||||
|
addrs: []net.Addr{&net.IPAddr{
|
||||||
|
IP: net.IP{1, 2, 3, 4},
|
||||||
|
}},
|
||||||
|
wantAddrs: nil,
|
||||||
|
}, {
|
||||||
|
name: "empty",
|
||||||
|
wantErrMsg: ``,
|
||||||
|
addrs: []net.Addr{},
|
||||||
|
wantAddrs: nil,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
|
||||||
|
|
||||||
|
addrs, err := CollectAllIfacesAddrs()
|
||||||
|
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantAddrs, addrs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("internal_error", func(t *testing.T) {
|
||||||
|
const errAddrs errors.Error = "can't get addresses"
|
||||||
|
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
|
||||||
|
|
||||||
|
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
|
||||||
|
|
||||||
|
_, err := CollectAllIfacesAddrs()
|
||||||
|
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAddrInUse(t *testing.T) {
|
||||||
|
t.Run("addr_in_use", func(t *testing.T) {
|
||||||
|
l, err := net.Listen("tcp", "0.0.0.0:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
testutil.CleanupAndRequireSuccess(t, l.Close)
|
||||||
|
|
||||||
|
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
|
||||||
|
assert.True(t, IsAddrInUse(err))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("another", func(t *testing.T) {
|
||||||
|
const anotherErr errors.Error = "not addr in use"
|
||||||
|
|
||||||
|
assert.False(t, IsAddrInUse(anotherErr))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetInterface_MarshalJSON(t *testing.T) {
|
||||||
|
const want = `{` +
|
||||||
|
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
|
||||||
|
`"flags":"up|multicast",` +
|
||||||
|
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
|
||||||
|
`"name":"iface0",` +
|
||||||
|
`"mtu":1500` +
|
||||||
|
`}` + "\n"
|
||||||
|
|
||||||
|
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
net4 := netip.PrefixFrom(ip4, 24)
|
||||||
|
net6 := netip.PrefixFrom(ip6, 8)
|
||||||
|
|
||||||
|
iface := &NetInterface{
|
||||||
|
Addresses: []netip.Addr{ip4, ip6},
|
||||||
|
Subnets: []netip.Prefix{net4, net6},
|
||||||
|
Name: "iface0",
|
||||||
|
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
|
||||||
|
Flags: net.FlagUp | net.FlagMulticast,
|
||||||
|
MTU: 1500,
|
||||||
|
}
|
||||||
|
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
err := json.NewEncoder(b).Encode(iface)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, want, b.String())
|
||||||
|
}
|
|
@ -1,21 +1,11 @@
|
||||||
package aghnet
|
package aghnet_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
@ -24,315 +14,3 @@ func TestMain(m *testing.M) {
|
||||||
|
|
||||||
// testdata is the filesystem containing data for testing the package.
|
// testdata is the filesystem containing data for testing the package.
|
||||||
var testdata fs.FS = os.DirFS("./testdata")
|
var testdata fs.FS = os.DirFS("./testdata")
|
||||||
|
|
||||||
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
|
|
||||||
// package with fsys for tests ran under t.
|
|
||||||
func substRootDirFS(t testing.TB, fsys fs.FS) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
prev := rootDirFS
|
|
||||||
t.Cleanup(func() { rootDirFS = prev })
|
|
||||||
rootDirFS = fsys
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunCmdFunc is the signature of aghos.RunCommand function.
|
|
||||||
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
|
|
||||||
|
|
||||||
// substShell replaces the the aghos.RunCommand function used throughout the
|
|
||||||
// package with rc for tests ran under t.
|
|
||||||
func substShell(t testing.TB, rc RunCmdFunc) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
prev := aghosRunCommand
|
|
||||||
t.Cleanup(func() { aghosRunCommand = prev })
|
|
||||||
aghosRunCommand = rc
|
|
||||||
}
|
|
||||||
|
|
||||||
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
|
|
||||||
// execution result. It's only needed to simplify testing.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
|
|
||||||
type mapShell map[string]struct {
|
|
||||||
err error
|
|
||||||
out string
|
|
||||||
code int
|
|
||||||
}
|
|
||||||
|
|
||||||
// theOnlyCmd returns mapShell that only handles a single command and arguments
|
|
||||||
// combination from cmd.
|
|
||||||
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
|
|
||||||
return mapShell{cmd: {code: code, out: out, err: err}}
|
|
||||||
}
|
|
||||||
|
|
||||||
// RunCmd is a RunCmdFunc handled by s.
|
|
||||||
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
|
|
||||||
key := strings.Join(append([]string{cmd}, args...), " ")
|
|
||||||
ret, ok := s[key]
|
|
||||||
if !ok {
|
|
||||||
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret.code, []byte(ret.out), ret.err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
|
|
||||||
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
|
|
||||||
|
|
||||||
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
|
|
||||||
// throughout the package with f for tests ran under t.
|
|
||||||
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
prev := netInterfaceAddrs
|
|
||||||
t.Cleanup(func() { netInterfaceAddrs = prev })
|
|
||||||
netInterfaceAddrs = f
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGatewayIP(t *testing.T) {
|
|
||||||
const ifaceName = "ifaceName"
|
|
||||||
const cmd = "ip route show dev " + ifaceName
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
shell mapShell
|
|
||||||
want netip.Addr
|
|
||||||
name string
|
|
||||||
}{{
|
|
||||||
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
|
|
||||||
want: netip.MustParseAddr("1.2.3.4"),
|
|
||||||
name: "success_v4",
|
|
||||||
}, {
|
|
||||||
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
|
|
||||||
want: netip.MustParseAddr("::ffff"),
|
|
||||||
name: "success_v6",
|
|
||||||
}, {
|
|
||||||
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
|
|
||||||
want: netip.Addr{},
|
|
||||||
name: "bad_output",
|
|
||||||
}, {
|
|
||||||
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
|
|
||||||
want: netip.Addr{},
|
|
||||||
name: "err_runcmd",
|
|
||||||
}, {
|
|
||||||
shell: theOnlyCmd(cmd, 1, "", nil),
|
|
||||||
want: netip.Addr{},
|
|
||||||
name: "bad_code",
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
substShell(t, tc.shell.RunCmd)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.want, GatewayIP(ifaceName))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInterfaceByIP(t *testing.T) {
|
|
||||||
ifaces, err := GetValidNetInterfacesForWeb()
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NotEmpty(t, ifaces)
|
|
||||||
|
|
||||||
for _, iface := range ifaces {
|
|
||||||
t.Run(iface.Name, func(t *testing.T) {
|
|
||||||
require.NotEmpty(t, iface.Addresses)
|
|
||||||
|
|
||||||
for _, ip := range iface.Addresses {
|
|
||||||
ifaceName := InterfaceByIP(ip)
|
|
||||||
require.Equal(t, iface.Name, ifaceName)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBroadcastFromIPNet(t *testing.T) {
|
|
||||||
known4 := netip.MustParseAddr("192.168.0.1")
|
|
||||||
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
|
|
||||||
|
|
||||||
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
pref netip.Prefix
|
|
||||||
want netip.Addr
|
|
||||||
name string
|
|
||||||
}{{
|
|
||||||
pref: netip.PrefixFrom(known4, 0),
|
|
||||||
want: fullBroadcast4,
|
|
||||||
name: "full",
|
|
||||||
}, {
|
|
||||||
pref: netip.PrefixFrom(known4, 20),
|
|
||||||
want: netip.MustParseAddr("192.168.15.255"),
|
|
||||||
name: "full",
|
|
||||||
}, {
|
|
||||||
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
|
|
||||||
want: known6,
|
|
||||||
name: "ipv6_no_mask",
|
|
||||||
}, {
|
|
||||||
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
|
|
||||||
want: known4,
|
|
||||||
name: "ipv4_no_mask",
|
|
||||||
}, {
|
|
||||||
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
|
||||||
want: fullBroadcast4,
|
|
||||||
name: "unspecified",
|
|
||||||
}, {
|
|
||||||
pref: netip.Prefix{},
|
|
||||||
want: netip.Addr{},
|
|
||||||
name: "invalid",
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckPort(t *testing.T) {
|
|
||||||
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
|
|
||||||
|
|
||||||
t.Run("tcp_bound", func(t *testing.T) {
|
|
||||||
l, err := net.Listen("tcp", laddr.String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
|
||||||
|
|
||||||
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
|
|
||||||
require.Equal(t, laddr.Addr(), ipp.Addr())
|
|
||||||
require.NotZero(t, ipp.Port())
|
|
||||||
|
|
||||||
err = CheckPort("tcp", ipp)
|
|
||||||
target := &net.OpError{}
|
|
||||||
require.ErrorAs(t, err, &target)
|
|
||||||
|
|
||||||
assert.Equal(t, "listen", target.Op)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("udp_bound", func(t *testing.T) {
|
|
||||||
conn, err := net.ListenPacket("udp", laddr.String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
testutil.CleanupAndRequireSuccess(t, conn.Close)
|
|
||||||
|
|
||||||
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
|
|
||||||
require.Equal(t, laddr.Addr(), ipp.Addr())
|
|
||||||
require.NotZero(t, ipp.Port())
|
|
||||||
|
|
||||||
err = CheckPort("udp", ipp)
|
|
||||||
target := &net.OpError{}
|
|
||||||
require.ErrorAs(t, err, &target)
|
|
||||||
|
|
||||||
assert.Equal(t, "listen", target.Op)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("bad_network", func(t *testing.T) {
|
|
||||||
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("can_bind", func(t *testing.T) {
|
|
||||||
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCollectAllIfacesAddrs(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
wantErrMsg string
|
|
||||||
addrs []net.Addr
|
|
||||||
wantAddrs []string
|
|
||||||
}{{
|
|
||||||
name: "success",
|
|
||||||
wantErrMsg: ``,
|
|
||||||
addrs: []net.Addr{&net.IPNet{
|
|
||||||
IP: net.IP{1, 2, 3, 4},
|
|
||||||
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
|
|
||||||
}, &net.IPNet{
|
|
||||||
IP: net.IP{4, 3, 2, 1},
|
|
||||||
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
|
|
||||||
}},
|
|
||||||
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
|
|
||||||
}, {
|
|
||||||
name: "not_cidr",
|
|
||||||
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
|
|
||||||
addrs: []net.Addr{&net.IPAddr{
|
|
||||||
IP: net.IP{1, 2, 3, 4},
|
|
||||||
}},
|
|
||||||
wantAddrs: nil,
|
|
||||||
}, {
|
|
||||||
name: "empty",
|
|
||||||
wantErrMsg: ``,
|
|
||||||
addrs: []net.Addr{},
|
|
||||||
wantAddrs: nil,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
|
|
||||||
|
|
||||||
addrs, err := CollectAllIfacesAddrs()
|
|
||||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
|
||||||
|
|
||||||
assert.Equal(t, tc.wantAddrs, addrs)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("internal_error", func(t *testing.T) {
|
|
||||||
const errAddrs errors.Error = "can't get addresses"
|
|
||||||
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
|
|
||||||
|
|
||||||
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
|
|
||||||
|
|
||||||
_, err := CollectAllIfacesAddrs()
|
|
||||||
testutil.AssertErrorMsg(t, wantErrMsg, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsAddrInUse(t *testing.T) {
|
|
||||||
t.Run("addr_in_use", func(t *testing.T) {
|
|
||||||
l, err := net.Listen("tcp", "0.0.0.0:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
testutil.CleanupAndRequireSuccess(t, l.Close)
|
|
||||||
|
|
||||||
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
|
|
||||||
assert.True(t, IsAddrInUse(err))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("another", func(t *testing.T) {
|
|
||||||
const anotherErr errors.Error = "not addr in use"
|
|
||||||
|
|
||||||
assert.False(t, IsAddrInUse(anotherErr))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNetInterface_MarshalJSON(t *testing.T) {
|
|
||||||
const want = `{` +
|
|
||||||
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
|
|
||||||
`"flags":"up|multicast",` +
|
|
||||||
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
|
|
||||||
`"name":"iface0",` +
|
|
||||||
`"mtu":1500` +
|
|
||||||
`}` + "\n"
|
|
||||||
|
|
||||||
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
net4 := netip.PrefixFrom(ip4, 24)
|
|
||||||
net6 := netip.PrefixFrom(ip6, 8)
|
|
||||||
|
|
||||||
iface := &NetInterface{
|
|
||||||
Addresses: []netip.Addr{ip4, ip6},
|
|
||||||
Subnets: []netip.Prefix{net4, net6},
|
|
||||||
Name: "iface0",
|
|
||||||
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
|
|
||||||
Flags: net.FlagUp | net.FlagMulticast,
|
|
||||||
MTU: 1500,
|
|
||||||
}
|
|
||||||
|
|
||||||
b := &bytes.Buffer{}
|
|
||||||
err := json.NewEncoder(b).Encode(iface)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, want, b.String())
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,7 +2,9 @@
|
||||||
package aghtest
|
package aghtest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
@ -34,3 +36,10 @@ func ReplaceLogLevel(t testing.TB, l log.Level) {
|
||||||
t.Cleanup(func() { log.SetLevel(prev) })
|
t.Cleanup(func() { log.SetLevel(prev) })
|
||||||
log.SetLevel(l)
|
log.SetLevel(l)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HostToIPs is a helper that generates one IPv4 and one IPv6 address from host.
|
||||||
|
func HostToIPs(host string) (ipv4, ipv6 net.IP) {
|
||||||
|
hash := sha256.Sum256([]byte(host))
|
||||||
|
|
||||||
|
return net.IP(hash[:4]), net.IP(hash[4:20])
|
||||||
|
}
|
||||||
|
|
|
@ -2,8 +2,7 @@ package aghtest
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"net"
|
||||||
"io/fs"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
|
@ -19,67 +18,6 @@ import (
|
||||||
//
|
//
|
||||||
// Keep entities in this file in alphabetic order.
|
// Keep entities in this file in alphabetic order.
|
||||||
|
|
||||||
// Standard Library
|
|
||||||
|
|
||||||
// Package fs
|
|
||||||
|
|
||||||
// FS is a fake [fs.FS] implementation for tests.
|
|
||||||
type FS struct {
|
|
||||||
OnOpen func(name string) (fs.File, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ fs.FS = (*FS)(nil)
|
|
||||||
|
|
||||||
// Open implements the [fs.FS] interface for *FS.
|
|
||||||
func (fsys *FS) Open(name string) (fs.File, error) {
|
|
||||||
return fsys.OnOpen(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ fs.GlobFS = (*GlobFS)(nil)
|
|
||||||
|
|
||||||
// GlobFS is a fake [fs.GlobFS] implementation for tests.
|
|
||||||
type GlobFS struct {
|
|
||||||
// FS is embedded here to avoid implementing all it's methods.
|
|
||||||
FS
|
|
||||||
OnGlob func(pattern string) ([]string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Glob implements the [fs.GlobFS] interface for *GlobFS.
|
|
||||||
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
|
|
||||||
return fsys.OnGlob(pattern)
|
|
||||||
}
|
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ fs.StatFS = (*StatFS)(nil)
|
|
||||||
|
|
||||||
// StatFS is a fake [fs.StatFS] implementation for tests.
|
|
||||||
type StatFS struct {
|
|
||||||
// FS is embedded here to avoid implementing all it's methods.
|
|
||||||
FS
|
|
||||||
OnStat func(name string) (fs.FileInfo, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stat implements the [fs.StatFS] interface for *StatFS.
|
|
||||||
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
|
|
||||||
return fsys.OnStat(name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Package io
|
|
||||||
|
|
||||||
// Writer is a fake [io.Writer] implementation for tests.
|
|
||||||
type Writer struct {
|
|
||||||
OnWrite func(b []byte) (n int, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ io.Writer = (*Writer)(nil)
|
|
||||||
|
|
||||||
// Write implements the [io.Writer] interface for *Writer.
|
|
||||||
func (w *Writer) Write(b []byte) (n int, err error) {
|
|
||||||
return w.OnWrite(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Module adguard-home
|
// Module adguard-home
|
||||||
|
|
||||||
// Package aghos
|
// Package aghos
|
||||||
|
@ -177,6 +115,18 @@ func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.I
|
||||||
p.OnUpdateAddress(ip, host, info)
|
p.OnUpdateAddress(ip, host, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Package filtering
|
||||||
|
|
||||||
|
// Resolver is a fake [filtering.Resolver] implementation for tests.
|
||||||
|
type Resolver struct {
|
||||||
|
OnLookupIP func(ctx context.Context, network, host string) (ips []net.IP, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LookupIP implements the [filtering.Resolver] interface for *Resolver.
|
||||||
|
func (r *Resolver) LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error) {
|
||||||
|
return r.OnLookupIP(ctx, network, host)
|
||||||
|
}
|
||||||
|
|
||||||
// Package rdns
|
// Package rdns
|
||||||
|
|
||||||
// Exchanger is a fake [rdns.Exchanger] implementation for tests.
|
// Exchanger is a fake [rdns.Exchanger] implementation for tests.
|
||||||
|
|
|
@ -1,3 +1,11 @@
|
||||||
package aghtest_test
|
package aghtest_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
|
)
|
||||||
|
|
||||||
// Put interface checks that cause import cycles here.
|
// Put interface checks that cause import cycles here.
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ filtering.Resolver = (*aghtest.Resolver)(nil)
|
||||||
|
|
|
@ -1,57 +0,0 @@
|
||||||
package aghtest
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/sha256"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestResolver is a Resolver for tests.
|
|
||||||
type TestResolver struct {
|
|
||||||
counter int
|
|
||||||
counterLock sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// HostToIPs generates IPv4 and IPv6 from host.
|
|
||||||
func (r *TestResolver) HostToIPs(host string) (ipv4, ipv6 net.IP) {
|
|
||||||
hash := sha256.Sum256([]byte(host))
|
|
||||||
|
|
||||||
return net.IP(hash[:4]), net.IP(hash[4:20])
|
|
||||||
}
|
|
||||||
|
|
||||||
// LookupIP implements Resolver interface for *testResolver. It returns the
|
|
||||||
// slice of net.IP with IPv4 and IPv6 instances.
|
|
||||||
func (r *TestResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) {
|
|
||||||
ipv4, ipv6 := r.HostToIPs(host)
|
|
||||||
addrs := []net.IP{ipv4, ipv6}
|
|
||||||
|
|
||||||
r.counterLock.Lock()
|
|
||||||
defer r.counterLock.Unlock()
|
|
||||||
r.counter++
|
|
||||||
|
|
||||||
return addrs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LookupHost implements Resolver interface for *testResolver. It returns the
|
|
||||||
// slice of IPv4 and IPv6 instances converted to strings.
|
|
||||||
func (r *TestResolver) LookupHost(host string) (addrs []string, err error) {
|
|
||||||
ipv4, ipv6 := r.HostToIPs(host)
|
|
||||||
|
|
||||||
r.counterLock.Lock()
|
|
||||||
defer r.counterLock.Unlock()
|
|
||||||
r.counter++
|
|
||||||
|
|
||||||
return []string{
|
|
||||||
ipv4.String(),
|
|
||||||
ipv6.String(),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Counter returns the number of requests handled.
|
|
||||||
func (r *TestResolver) Counter() int {
|
|
||||||
r.counterLock.Lock()
|
|
||||||
defer r.counterLock.Unlock()
|
|
||||||
|
|
||||||
return r.counter
|
|
||||||
}
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
@ -39,7 +40,7 @@ func (EmptyAddrProc) Close() (_ error) { return nil }
|
||||||
type DefaultAddrProcConfig struct {
|
type DefaultAddrProcConfig struct {
|
||||||
// DialContext is used to create TCP connections to WHOIS servers.
|
// DialContext is used to create TCP connections to WHOIS servers.
|
||||||
// DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true.
|
// DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true.
|
||||||
DialContext whois.DialContextFunc
|
DialContext aghnet.DialContextFunc
|
||||||
|
|
||||||
// Exchanger is used to perform rDNS queries. Exchanger must not be nil if
|
// Exchanger is used to perform rDNS queries. Exchanger must not be nil if
|
||||||
// [DefaultAddrProcConfig.UseRDNS] is true.
|
// [DefaultAddrProcConfig.UseRDNS] is true.
|
||||||
|
@ -161,7 +162,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
|
||||||
|
|
||||||
// newWHOIS returns a whois.Interface instance using the given function for
|
// newWHOIS returns a whois.Interface instance using the given function for
|
||||||
// dialing.
|
// dialing.
|
||||||
func newWHOIS(dialFunc whois.DialContextFunc) (w whois.Interface) {
|
func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) {
|
||||||
// TODO(s.chzhen): Consider making configurable.
|
// TODO(s.chzhen): Consider making configurable.
|
||||||
const (
|
const (
|
||||||
// defaultTimeout is the timeout for WHOIS requests.
|
// defaultTimeout is the timeout for WHOIS requests.
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DialContext is a [whois.DialContextFunc] that uses s to resolve hostnames.
|
// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
|
||||||
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||||
log.Debug("dnsforward: dialing %q for network %q", addr, network)
|
log.Debug("dnsforward: dialing %q for network %q", addr, network)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package dnsforward
|
package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
@ -467,7 +468,14 @@ func TestServerRace(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSafeSearch(t *testing.T) {
|
func TestSafeSearch(t *testing.T) {
|
||||||
resolver := &aghtest.TestResolver{}
|
resolver := &aghtest.Resolver{
|
||||||
|
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||||
|
ip4, ip6 := aghtest.HostToIPs(host)
|
||||||
|
|
||||||
|
return []net.IP{ip4, ip6}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
safeSearchConf := filtering.SafeSearchConfig{
|
safeSearchConf := filtering.SafeSearchConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
Google: true,
|
Google: true,
|
||||||
|
@ -506,7 +514,7 @@ func TestSafeSearch(t *testing.T) {
|
||||||
client := &dns.Client{}
|
client := &dns.Client{}
|
||||||
|
|
||||||
yandexIP := net.IP{213, 180, 193, 56}
|
yandexIP := net.IP{213, 180, 193, 56}
|
||||||
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
host string
|
host string
|
||||||
|
@ -954,7 +962,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||||
Upstream: aghtest.NewBlockUpstream(hostname, true),
|
Upstream: aghtest.NewBlockUpstream(hostname, true),
|
||||||
})
|
})
|
||||||
|
|
||||||
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
|
ans4, _ := aghtest.HostToIPs(hostname)
|
||||||
|
|
||||||
filterConf := &filtering.Config{
|
filterConf := &filtering.Config{
|
||||||
SafeBrowsingEnabled: true,
|
SafeBrowsingEnabled: true,
|
||||||
|
|
|
@ -6,10 +6,10 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil/fakeio"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -159,7 +159,7 @@ func TestParser_Parse(t *testing.T) {
|
||||||
func TestParser_Parse_writeError(t *testing.T) {
|
func TestParser_Parse_writeError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
dst := &aghtest.Writer{
|
dst := &fakeio.Writer{
|
||||||
OnWrite: func(b []byte) (n int, err error) {
|
OnWrite: func(b []byte) (n int, err error) {
|
||||||
return 1, errors.Error("test error")
|
return 1, errors.Error("test error")
|
||||||
},
|
},
|
||||||
|
|
|
@ -89,37 +89,34 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
|
||||||
assert.False(t, res.IsFiltered)
|
assert.False(t, res.IsFiltered)
|
||||||
assert.Empty(t, res.Rules)
|
assert.Empty(t, res.Rules)
|
||||||
|
|
||||||
resolver := &aghtest.TestResolver{}
|
resolver := &aghtest.Resolver{
|
||||||
|
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||||
|
ip4, ip6 := aghtest.HostToIPs(host)
|
||||||
|
|
||||||
|
return []net.IP{ip4, ip6}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
ss = newForTest(t, defaultSafeSearchConf)
|
ss = newForTest(t, defaultSafeSearchConf)
|
||||||
ss.resolver = resolver
|
ss.resolver = resolver
|
||||||
|
|
||||||
// Lookup for safesearch domain.
|
// Lookup for safesearch domain.
|
||||||
rewrite := ss.searchHost(domain, testQType)
|
rewrite := ss.searchHost(domain, testQType)
|
||||||
|
|
||||||
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
|
wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME)
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
var foundIP net.IP
|
|
||||||
for _, ip := range ips {
|
|
||||||
if ip.To4() != nil {
|
|
||||||
foundIP = ip
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err = ss.CheckHost(domain, testQType)
|
res, err = ss.CheckHost(domain, testQType)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
assert.True(t, res.Rules[0].IP.Equal(foundIP))
|
assert.True(t, res.Rules[0].IP.Equal(wantIP))
|
||||||
|
|
||||||
// Check cache.
|
// Check cache.
|
||||||
cachedValue, isFound := ss.getCachedResult(domain, testQType)
|
cachedValue, isFound := ss.getCachedResult(domain, testQType)
|
||||||
require.True(t, isFound)
|
require.True(t, isFound)
|
||||||
require.Len(t, cachedValue.Rules, 1)
|
require.Len(t, cachedValue.Rules, 1)
|
||||||
|
|
||||||
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
|
assert.True(t, cachedValue.Rules[0].IP.Equal(wantIP))
|
||||||
}
|
}
|
||||||
|
|
||||||
const googleHost = "www.google.com"
|
const googleHost = "www.google.com"
|
||||||
|
|
|
@ -92,8 +92,15 @@ func TestDefault_CheckHost_yandexAAAA(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefault_CheckHost_google(t *testing.T) {
|
func TestDefault_CheckHost_google(t *testing.T) {
|
||||||
resolver := &aghtest.TestResolver{}
|
resolver := &aghtest.Resolver{
|
||||||
ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
|
||||||
|
ip4, ip6 := aghtest.HostToIPs(host)
|
||||||
|
|
||||||
|
return []net.IP{ip4, ip6}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
|
||||||
|
|
||||||
conf := testConf
|
conf := testConf
|
||||||
conf.CustomResolver = resolver
|
conf.CustomResolver = resolver
|
||||||
|
@ -119,7 +126,7 @@ func TestDefault_CheckHost_google(t *testing.T) {
|
||||||
|
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
||||||
assert.Equal(t, ip, res.Rules[0].IP)
|
assert.Equal(t, wantIP, res.Rules[0].IP)
|
||||||
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
|
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,11 +12,11 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil/fakefs"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -90,7 +90,7 @@ func newTestServer(
|
||||||
|
|
||||||
c := &websvc.Config{
|
c := &websvc.Config{
|
||||||
ConfigManager: confMgr,
|
ConfigManager: confMgr,
|
||||||
Frontend: &aghtest.FS{
|
Frontend: &fakefs.FS{
|
||||||
OnOpen: func(_ string) (_ fs.File, _ error) { return nil, fs.ErrNotExist },
|
OnOpen: func(_ string) (_ fs.File, _ error) { return nil, fs.ErrNotExist },
|
||||||
},
|
},
|
||||||
TLS: nil,
|
TLS: nil,
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
@ -49,7 +50,7 @@ func (Empty) Process(_ context.Context, _ netip.Addr) (info *Info, changed bool)
|
||||||
// Config is the configuration structure for Default.
|
// Config is the configuration structure for Default.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// DialContext is used to create TCP connections to WHOIS servers.
|
// DialContext is used to create TCP connections to WHOIS servers.
|
||||||
DialContext DialContextFunc
|
DialContext aghnet.DialContextFunc
|
||||||
|
|
||||||
// ServerAddr is the address of the WHOIS server.
|
// ServerAddr is the address of the WHOIS server.
|
||||||
ServerAddr string
|
ServerAddr string
|
||||||
|
@ -77,13 +78,6 @@ type Config struct {
|
||||||
Port uint16
|
Port uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialContextFunc is the semantic alias for dialing functions, such as
|
|
||||||
// [http.Transport.DialContext].
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Move to aghnet once it stops importing aghtest, because
|
|
||||||
// otherwise there is an import cycle.
|
|
||||||
type DialContextFunc = func(ctx context.Context, network, addr string) (conn net.Conn, err error)
|
|
||||||
|
|
||||||
// Default is the default WHOIS information processor.
|
// Default is the default WHOIS information processor.
|
||||||
type Default struct {
|
type Default struct {
|
||||||
// cache is the cache containing IP addresses of clients. An active IP
|
// cache is the cache containing IP addresses of clients. An active IP
|
||||||
|
@ -93,7 +87,7 @@ type Default struct {
|
||||||
cache gcache.Cache
|
cache gcache.Cache
|
||||||
|
|
||||||
// dialContext is used to create TCP connections to WHOIS servers.
|
// dialContext is used to create TCP connections to WHOIS servers.
|
||||||
dialContext DialContextFunc
|
dialContext aghnet.DialContextFunc
|
||||||
|
|
||||||
// serverAddr is the address of the WHOIS server.
|
// serverAddr is the address of the WHOIS server.
|
||||||
serverAddr string
|
serverAddr string
|
||||||
|
|
Loading…
Reference in New Issue