From 5be0e84719c15b92f0c15975af4a422155370914 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Thu, 20 Jul 2023 14:26:35 +0300 Subject: [PATCH] Pull request 1933: upd-golibs Squashed commit of the following: commit 081d10e6909def3a075707e75dbd0c5f63f91903 Author: Ainar Garipov Date: Thu Jul 20 14:17:01 2023 +0300 aghnet: fix docs commit 7433b72c0653cb33fe5ff810ae8a1346a6994f95 Author: Ainar Garipov Date: Thu Jul 20 14:03:16 2023 +0300 all: imp tests; upd golibs --- go.mod | 2 +- go.sum | 4 +- internal/aghio/limitedreader_test.go | 11 +- internal/aghnet/hostscontainer.go | 6 +- .../aghnet/hostscontainer_internal_test.go | 144 ++++++++ internal/aghnet/hostscontainer_test.go | 162 +-------- internal/aghnet/net.go | 5 + internal/aghnet/net_darwin_test.go | 4 +- internal/aghnet/net_internal_test.go | 334 ++++++++++++++++++ internal/aghnet/net_test.go | 324 +---------------- internal/aghtest/aghtest.go | 9 + internal/aghtest/interface.go | 76 +--- internal/aghtest/interface_test.go | 8 + internal/aghtest/resolver.go | 57 --- internal/client/addrproc.go | 5 +- internal/dnsforward/dialcontext.go | 2 +- internal/dnsforward/dnsforward_test.go | 14 +- internal/filtering/rulelist/parser_test.go | 4 +- .../safesearch/safesearch_internal_test.go | 25 +- .../filtering/safesearch/safesearch_test.go | 13 +- internal/next/websvc/websvc_test.go | 4 +- internal/whois/whois.go | 12 +- 22 files changed, 587 insertions(+), 638 deletions(-) create mode 100644 internal/aghnet/hostscontainer_internal_test.go create mode 100644 internal/aghnet/net_internal_test.go delete mode 100644 internal/aghtest/resolver.go diff --git a/go.mod b/go.mod index 60cada3e..7567553b 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.19 require ( github.com/AdguardTeam/dnsproxy v0.52.0 - github.com/AdguardTeam/golibs v0.13.4 + github.com/AdguardTeam/golibs v0.13.5 github.com/AdguardTeam/urlfilter v0.16.1 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.2.7 diff --git a/go.sum b/go.sum index 13431f68..1cef6a07 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/AdguardTeam/dnsproxy v0.52.0 h1:uZxCXflHSAwtJ7uTYXP6qgWcxaBsH0pJvldpw github.com/AdguardTeam/dnsproxy v0.52.0/go.mod h1:Jo2zeRe97Rxt3yikXc+fn0LdLtqCj0Xlyh1PNBj6bpM= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= -github.com/AdguardTeam/golibs v0.13.4 h1:ACTwIR1pEENBijHcEWtiMbSh4wWQOlIHRxmUB8oBHf8= -github.com/AdguardTeam/golibs v0.13.4/go.mod h1:wkJ6EUsN4np/9Gp7+9QeooY9E2U2WCLJYAioLCzkHsI= +github.com/AdguardTeam/golibs v0.13.5 h1:fpa30Yr9Rcn4vJ88nE4XHSompY7/qMOq2aNS/4PGymA= +github.com/AdguardTeam/golibs v0.13.5/go.mod h1:wkJ6EUsN4np/9Gp7+9QeooY9E2U2WCLJYAioLCzkHsI= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw= github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI= diff --git a/internal/aghio/limitedreader_test.go b/internal/aghio/limitedreader_test.go index b3cef08d..a85fd051 100644 --- a/internal/aghio/limitedreader_test.go +++ b/internal/aghio/limitedreader_test.go @@ -1,10 +1,11 @@ -package aghio +package aghio_test import ( "io" "strings" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -31,7 +32,7 @@ func TestLimitReader(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - _, err := LimitReader(nil, tc.n) + _, err := aghio.LimitReader(nil, tc.n) testutil.AssertErrorMsg(t, tc.wantErrMsg, err) }) } @@ -57,7 +58,7 @@ func TestLimitedReader_Read(t *testing.T) { limit: 3, want: 0, }, { - err: &LimitReachedError{ + err: &aghio.LimitReachedError{ Limit: 0, }, name: "limit_reached", @@ -74,7 +75,7 @@ func TestLimitedReader_Read(t *testing.T) { for _, tc := range testCases { readCloser := io.NopCloser(strings.NewReader(tc.rStr)) - lreader, err := LimitReader(readCloser, tc.limit) + lreader, err := aghio.LimitReader(readCloser, tc.limit) require.NoError(t, err) require.NotNil(t, lreader) @@ -89,7 +90,7 @@ func TestLimitedReader_Read(t *testing.T) { } func TestLimitedReader_LimitReachedError(t *testing.T) { - testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &LimitReachedError{ + testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &aghio.LimitReachedError{ Limit: 0, }) } diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index 2fecbc6f..f2e57c46 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -141,9 +141,9 @@ type HostsRecord struct { Canonical string } -// equal returns true if all fields of rec are equal to field in other or they +// Equal returns true if all fields of rec are equal to field in other or they // both are nil. -func (rec *HostsRecord) equal(other *HostsRecord) (ok bool) { +func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) { if rec == nil { return other == nil } else if other == nil { @@ -495,7 +495,7 @@ func (hc *HostsContainer) refresh() (err error) { } // hc.last is nil on the first refresh, so let that one through. - if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) { + if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).Equal) { log.Debug("%s: no changes detected", hostsContainerPrefix) return nil diff --git a/internal/aghnet/hostscontainer_internal_test.go b/internal/aghnet/hostscontainer_internal_test.go new file mode 100644 index 00000000..e3855f39 --- /dev/null +++ b/internal/aghnet/hostscontainer_internal_test.go @@ -0,0 +1,144 @@ +package aghnet + +import ( + "io/fs" + "net/netip" + "path" + "testing" + "testing/fstest" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil/fakefs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const nl = "\n" + +func TestHostsContainer_PathsToPatterns(t *testing.T) { + gsfs := fstest.MapFS{ + "dir_0/file_1": &fstest.MapFile{Data: []byte{1}}, + "dir_0/file_2": &fstest.MapFile{Data: []byte{2}}, + "dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}}, + } + + testCases := []struct { + name string + paths []string + want []string + }{{ + name: "no_paths", + paths: nil, + want: nil, + }, { + name: "single_file", + paths: []string{"dir_0/file_1"}, + want: []string{"dir_0/file_1"}, + }, { + name: "several_files", + paths: []string{"dir_0/file_1", "dir_0/file_2"}, + want: []string{"dir_0/file_1", "dir_0/file_2"}, + }, { + name: "whole_dir", + paths: []string{"dir_0"}, + want: []string{"dir_0/*"}, + }, { + name: "file_and_dir", + paths: []string{"dir_0/file_1", "dir_0/dir_1"}, + want: []string{"dir_0/file_1", "dir_0/dir_1/*"}, + }, { + name: "non-existing", + paths: []string{path.Join("dir_0", "file_3")}, + want: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + patterns, err := pathsToPatterns(gsfs, tc.paths) + require.NoError(t, err) + + assert.Equal(t, tc.want, patterns) + }) + } + + t.Run("bad_file", func(t *testing.T) { + const errStat errors.Error = "bad file" + + badFS := &fakefs.StatFS{ + OnOpen: func(_ string) (f fs.File, err error) { panic("not implemented") }, + OnStat: func(name string) (fi fs.FileInfo, err error) { + return nil, errStat + }, + } + + _, err := pathsToPatterns(badFS, []string{""}) + assert.ErrorIs(t, err, errStat) + }) +} + +func TestUniqueRules_ParseLine(t *testing.T) { + ip := netutil.IPv4Localhost() + ipStr := ip.String() + + testCases := []struct { + name string + line string + wantIP netip.Addr + wantHosts []string + }{{ + name: "simple", + line: ipStr + ` hostname`, + wantIP: ip, + wantHosts: []string{"hostname"}, + }, { + name: "aliases", + line: ipStr + ` hostname alias`, + wantIP: ip, + wantHosts: []string{"hostname", "alias"}, + }, { + name: "invalid_line", + line: ipStr, + wantIP: netip.Addr{}, + wantHosts: nil, + }, { + name: "invalid_line_hostname", + line: ipStr + ` # hostname`, + wantIP: ip, + wantHosts: nil, + }, { + name: "commented_aliases", + line: ipStr + ` hostname # alias`, + wantIP: ip, + wantHosts: []string{"hostname"}, + }, { + name: "whole_comment", + line: `# ` + ipStr + ` hostname`, + wantIP: netip.Addr{}, + wantHosts: nil, + }, { + name: "partial_comment", + line: ipStr + ` host#name`, + wantIP: ip, + wantHosts: []string{"host"}, + }, { + name: "empty", + line: ``, + wantIP: netip.Addr{}, + wantHosts: nil, + }, { + name: "bad_hosts", + line: ipStr + ` bad..host bad._tld empty.tld. ok.host`, + wantIP: ip, + wantHosts: []string{"ok.host"}, + }} + + for _, tc := range testCases { + hp := hostsParser{} + t.Run(tc.name, func(t *testing.T) { + got, hosts := hp.parseLine(tc.line) + assert.Equal(t, tc.wantIP, got) + assert.Equal(t, tc.wantHosts, hosts) + }) + } +} diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index d145a7b4..00c2aeed 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -1,9 +1,7 @@ -package aghnet +package aghnet_test import ( - "io/fs" "net" - "net/netip" "path" "strings" "sync/atomic" @@ -12,6 +10,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghchan" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" @@ -24,10 +23,7 @@ import ( "github.com/stretchr/testify/require" ) -const ( - nl = "\n" - sp = " " -) +const nl = "\n" func TestNewHostsContainer(t *testing.T) { const dirname = "dir" @@ -48,11 +44,11 @@ func TestNewHostsContainer(t *testing.T) { name: "one_file", paths: []string{p}, }, { - wantErr: ErrNoHostsPaths, + wantErr: aghnet.ErrNoHostsPaths, name: "no_files", paths: []string{}, }, { - wantErr: ErrNoHostsPaths, + wantErr: aghnet.ErrNoHostsPaths, name: "non-existent_file", paths: []string{path.Join(dirname, filename+"2")}, }, { @@ -77,7 +73,7 @@ func TestNewHostsContainer(t *testing.T) { return eventsCh } - hc, err := NewHostsContainer(0, testFS, &aghtest.FSWatcher{ + hc, err := aghnet.NewHostsContainer(0, testFS, &aghtest.FSWatcher{ OnEvents: onEvents, OnAdd: onAdd, OnClose: func() (err error) { return nil }, @@ -103,7 +99,7 @@ func TestNewHostsContainer(t *testing.T) { t.Run("nil_fs", func(t *testing.T) { require.Panics(t, func() { - _, _ = NewHostsContainer(0, nil, &aghtest.FSWatcher{ + _, _ = aghnet.NewHostsContainer(0, nil, &aghtest.FSWatcher{ // Those shouldn't panic. OnEvents: func() (e <-chan struct{}) { return nil }, OnAdd: func(name string) (err error) { return nil }, @@ -114,7 +110,7 @@ func TestNewHostsContainer(t *testing.T) { t.Run("nil_watcher", func(t *testing.T) { require.Panics(t, func() { - _, _ = NewHostsContainer(0, testFS, nil, p) + _, _ = aghnet.NewHostsContainer(0, testFS, nil, p) }) }) @@ -127,7 +123,7 @@ func TestNewHostsContainer(t *testing.T) { OnClose: func() (err error) { return nil }, } - hc, err := NewHostsContainer(0, testFS, errWatcher, p) + hc, err := aghnet.NewHostsContainer(0, testFS, errWatcher, p) require.ErrorIs(t, err, errOnAdd) assert.Nil(t, hc) @@ -158,11 +154,11 @@ func TestHostsContainer_refresh(t *testing.T) { OnClose: func() (err error) { return nil }, } - hc, err := NewHostsContainer(0, testFS, w, "dir") + hc, err := aghnet.NewHostsContainer(0, testFS, w, "dir") require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, hc.Close) - checkRefresh := func(t *testing.T, want *HostsRecord) { + checkRefresh := func(t *testing.T, want *aghnet.HostsRecord) { t.Helper() upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second) @@ -175,11 +171,11 @@ func TestHostsContainer_refresh(t *testing.T) { require.True(t, ok) require.NotNil(t, rec) - assert.Truef(t, rec.equal(want), "%+v != %+v", rec, want) + assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want) } t.Run("initial_refresh", func(t *testing.T) { - checkRefresh(t, &HostsRecord{ + checkRefresh(t, &aghnet.HostsRecord{ Aliases: stringutil.NewSet(), Canonical: "hostname", }) @@ -189,7 +185,7 @@ func TestHostsContainer_refresh(t *testing.T) { testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)} eventsCh <- event{} - checkRefresh(t, &HostsRecord{ + checkRefresh(t, &aghnet.HostsRecord{ Aliases: stringutil.NewSet("alias"), Canonical: "hostname", }) @@ -228,66 +224,6 @@ func TestHostsContainer_refresh(t *testing.T) { }) } -func TestHostsContainer_PathsToPatterns(t *testing.T) { - gsfs := fstest.MapFS{ - "dir_0/file_1": &fstest.MapFile{Data: []byte{1}}, - "dir_0/file_2": &fstest.MapFile{Data: []byte{2}}, - "dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}}, - } - - testCases := []struct { - name string - paths []string - want []string - }{{ - name: "no_paths", - paths: nil, - want: nil, - }, { - name: "single_file", - paths: []string{"dir_0/file_1"}, - want: []string{"dir_0/file_1"}, - }, { - name: "several_files", - paths: []string{"dir_0/file_1", "dir_0/file_2"}, - want: []string{"dir_0/file_1", "dir_0/file_2"}, - }, { - name: "whole_dir", - paths: []string{"dir_0"}, - want: []string{"dir_0/*"}, - }, { - name: "file_and_dir", - paths: []string{"dir_0/file_1", "dir_0/dir_1"}, - want: []string{"dir_0/file_1", "dir_0/dir_1/*"}, - }, { - name: "non-existing", - paths: []string{path.Join("dir_0", "file_3")}, - want: nil, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - patterns, err := pathsToPatterns(gsfs, tc.paths) - require.NoError(t, err) - - assert.Equal(t, tc.want, patterns) - }) - } - - t.Run("bad_file", func(t *testing.T) { - const errStat errors.Error = "bad file" - - badFS := &aghtest.StatFS{ - OnStat: func(name string) (fs.FileInfo, error) { - return nil, errStat - }, - } - - _, err := pathsToPatterns(badFS, []string{""}) - assert.ErrorIs(t, err, errStat) - }) -} - func TestHostsContainer_Translate(t *testing.T) { stubWatcher := aghtest.FSWatcher{ OnEvents: func() (e <-chan struct{}) { return nil }, @@ -297,7 +233,7 @@ func TestHostsContainer_Translate(t *testing.T) { require.NoError(t, fstest.TestFS(testdata, "etc_hosts")) - hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts") + hc, err := aghnet.NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts") require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, hc.Close) @@ -527,7 +463,7 @@ func TestHostsContainer(t *testing.T) { OnClose: func() (err error) { return nil }, } - hc, err := NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts") + hc, err := aghnet.NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts") require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, hc.Close) @@ -558,69 +494,3 @@ func TestHostsContainer(t *testing.T) { }) } } - -func TestUniqueRules_ParseLine(t *testing.T) { - ip := netutil.IPv4Localhost() - ipStr := ip.String() - - testCases := []struct { - name string - line string - wantIP netip.Addr - wantHosts []string - }{{ - name: "simple", - line: ipStr + ` hostname`, - wantIP: ip, - wantHosts: []string{"hostname"}, - }, { - name: "aliases", - line: ipStr + ` hostname alias`, - wantIP: ip, - wantHosts: []string{"hostname", "alias"}, - }, { - name: "invalid_line", - line: ipStr, - wantIP: netip.Addr{}, - wantHosts: nil, - }, { - name: "invalid_line_hostname", - line: ipStr + ` # hostname`, - wantIP: ip, - wantHosts: nil, - }, { - name: "commented_aliases", - line: ipStr + ` hostname # alias`, - wantIP: ip, - wantHosts: []string{"hostname"}, - }, { - name: "whole_comment", - line: `# ` + ipStr + ` hostname`, - wantIP: netip.Addr{}, - wantHosts: nil, - }, { - name: "partial_comment", - line: ipStr + ` host#name`, - wantIP: ip, - wantHosts: []string{"host"}, - }, { - name: "empty", - line: ``, - wantIP: netip.Addr{}, - wantHosts: nil, - }, { - name: "bad_hosts", - line: ipStr + ` bad..host bad._tld empty.tld. ok.host`, - wantIP: ip, - wantHosts: []string{"ok.host"}, - }} - - for _, tc := range testCases { - hp := hostsParser{} - t.Run(tc.name, func(t *testing.T) { - got, hosts := hp.parseLine(tc.line) - assert.Equal(t, tc.wantIP, got) - assert.Equal(t, tc.wantHosts, hosts) - }) - } -} diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index b8c8c05e..919b03d6 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -3,6 +3,7 @@ package aghnet import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -15,6 +16,10 @@ import ( "github.com/AdguardTeam/golibs/log" ) +// DialContextFunc is the semantic alias for dialing functions, such as +// [http.Transport.DialContext]. +type DialContextFunc = func(ctx context.Context, network, addr string) (conn net.Conn, err error) + // Variables and functions to substitute in tests. var ( // aghosRunCommand is the function to run shell commands. diff --git a/internal/aghnet/net_darwin_test.go b/internal/aghnet/net_darwin_test.go index 905600d5..06e7eeaf 100644 --- a/internal/aghnet/net_darwin_test.go +++ b/internal/aghnet/net_darwin_test.go @@ -5,9 +5,9 @@ import ( "testing" "testing/fstest" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/golibs/testutil/fakefs" "github.com/stretchr/testify/assert" ) @@ -118,7 +118,7 @@ func TestIfaceSetStaticIP(t *testing.T) { Data: []byte(`nameserver 1.1.1.1`), }, } - panicFsys := &aghtest.FS{ + panicFsys := &fakefs.FS{ OnOpen: func(name string) (fs.File, error) { panic("not implemented") }, } diff --git a/internal/aghnet/net_internal_test.go b/internal/aghnet/net_internal_test.go new file mode 100644 index 00000000..9c4cff8c --- /dev/null +++ b/internal/aghnet/net_internal_test.go @@ -0,0 +1,334 @@ +package aghnet + +import ( + "bytes" + "encoding/json" + "fmt" + "io/fs" + "net" + "net/netip" + "os" + "strings" + "testing" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testdata is the filesystem containing data for testing the package. +var testdata fs.FS = os.DirFS("./testdata") + +// substRootDirFS replaces the aghos.RootDirFS function used throughout the +// package with fsys for tests ran under t. +func substRootDirFS(t testing.TB, fsys fs.FS) { + t.Helper() + + prev := rootDirFS + t.Cleanup(func() { rootDirFS = prev }) + rootDirFS = fsys +} + +// RunCmdFunc is the signature of aghos.RunCommand function. +type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error) + +// substShell replaces the the aghos.RunCommand function used throughout the +// package with rc for tests ran under t. +func substShell(t testing.TB, rc RunCmdFunc) { + t.Helper() + + prev := aghosRunCommand + t.Cleanup(func() { aghosRunCommand = prev }) + aghosRunCommand = rc +} + +// mapShell is a substitution of aghos.RunCommand that maps the command to it's +// execution result. It's only needed to simplify testing. +// +// TODO(e.burkov): Perhaps put all the shell interactions behind an interface. +type mapShell map[string]struct { + err error + out string + code int +} + +// theOnlyCmd returns mapShell that only handles a single command and arguments +// combination from cmd. +func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) { + return mapShell{cmd: {code: code, out: out, err: err}} +} + +// RunCmd is a RunCmdFunc handled by s. +func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) { + key := strings.Join(append([]string{cmd}, args...), " ") + ret, ok := s[key] + if !ok { + return 0, nil, fmt.Errorf("unexpected shell command %q", key) + } + + return ret.code, []byte(ret.out), ret.err +} + +// ifaceAddrsFunc is the signature of net.InterfaceAddrs function. +type ifaceAddrsFunc func() (ifaces []net.Addr, err error) + +// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used +// throughout the package with f for tests ran under t. +func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) { + t.Helper() + + prev := netInterfaceAddrs + t.Cleanup(func() { netInterfaceAddrs = prev }) + netInterfaceAddrs = f +} + +func TestGatewayIP(t *testing.T) { + const ifaceName = "ifaceName" + const cmd = "ip route show dev " + ifaceName + + testCases := []struct { + shell mapShell + want netip.Addr + name string + }{{ + shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil), + want: netip.MustParseAddr("1.2.3.4"), + name: "success_v4", + }, { + shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil), + want: netip.MustParseAddr("::ffff"), + name: "success_v6", + }, { + shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil), + want: netip.Addr{}, + name: "bad_output", + }, { + shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")), + want: netip.Addr{}, + name: "err_runcmd", + }, { + shell: theOnlyCmd(cmd, 1, "", nil), + want: netip.Addr{}, + name: "bad_code", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + substShell(t, tc.shell.RunCmd) + + assert.Equal(t, tc.want, GatewayIP(ifaceName)) + }) + } +} + +func TestInterfaceByIP(t *testing.T) { + ifaces, err := GetValidNetInterfacesForWeb() + require.NoError(t, err) + require.NotEmpty(t, ifaces) + + for _, iface := range ifaces { + t.Run(iface.Name, func(t *testing.T) { + require.NotEmpty(t, iface.Addresses) + + for _, ip := range iface.Addresses { + ifaceName := InterfaceByIP(ip) + require.Equal(t, iface.Name, ifaceName) + } + }) + } +} + +func TestBroadcastFromIPNet(t *testing.T) { + known4 := netip.MustParseAddr("192.168.0.1") + fullBroadcast4 := netip.MustParseAddr("255.255.255.255") + + known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10") + + testCases := []struct { + pref netip.Prefix + want netip.Addr + name string + }{{ + pref: netip.PrefixFrom(known4, 0), + want: fullBroadcast4, + name: "full", + }, { + pref: netip.PrefixFrom(known4, 20), + want: netip.MustParseAddr("192.168.15.255"), + name: "full", + }, { + pref: netip.PrefixFrom(known6, netutil.IPv6BitLen), + want: known6, + name: "ipv6_no_mask", + }, { + pref: netip.PrefixFrom(known4, netutil.IPv4BitLen), + want: known4, + name: "ipv4_no_mask", + }, { + pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0), + want: fullBroadcast4, + name: "unspecified", + }, { + pref: netip.Prefix{}, + want: netip.Addr{}, + name: "invalid", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, BroadcastFromPref(tc.pref)) + }) + } +} + +func TestCheckPort(t *testing.T) { + laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0) + + t.Run("tcp_bound", func(t *testing.T) { + l, err := net.Listen("tcp", laddr.String()) + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, l.Close) + + ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort() + require.Equal(t, laddr.Addr(), ipp.Addr()) + require.NotZero(t, ipp.Port()) + + err = CheckPort("tcp", ipp) + target := &net.OpError{} + require.ErrorAs(t, err, &target) + + assert.Equal(t, "listen", target.Op) + }) + + t.Run("udp_bound", func(t *testing.T) { + conn, err := net.ListenPacket("udp", laddr.String()) + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, conn.Close) + + ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort() + require.Equal(t, laddr.Addr(), ipp.Addr()) + require.NotZero(t, ipp.Port()) + + err = CheckPort("udp", ipp) + target := &net.OpError{} + require.ErrorAs(t, err, &target) + + assert.Equal(t, "listen", target.Op) + }) + + t.Run("bad_network", func(t *testing.T) { + err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0)) + assert.NoError(t, err) + }) + + t.Run("can_bind", func(t *testing.T) { + err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0)) + assert.NoError(t, err) + }) +} + +func TestCollectAllIfacesAddrs(t *testing.T) { + testCases := []struct { + name string + wantErrMsg string + addrs []net.Addr + wantAddrs []string + }{{ + name: "success", + wantErrMsg: ``, + addrs: []net.Addr{&net.IPNet{ + IP: net.IP{1, 2, 3, 4}, + Mask: net.CIDRMask(24, netutil.IPv4BitLen), + }, &net.IPNet{ + IP: net.IP{4, 3, 2, 1}, + Mask: net.CIDRMask(16, netutil.IPv4BitLen), + }}, + wantAddrs: []string{"1.2.3.4", "4.3.2.1"}, + }, { + name: "not_cidr", + wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`, + addrs: []net.Addr{&net.IPAddr{ + IP: net.IP{1, 2, 3, 4}, + }}, + wantAddrs: nil, + }, { + name: "empty", + wantErrMsg: ``, + addrs: []net.Addr{}, + wantAddrs: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil }) + + addrs, err := CollectAllIfacesAddrs() + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + + assert.Equal(t, tc.wantAddrs, addrs) + }) + } + + t.Run("internal_error", func(t *testing.T) { + const errAddrs errors.Error = "can't get addresses" + const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs) + + substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs }) + + _, err := CollectAllIfacesAddrs() + testutil.AssertErrorMsg(t, wantErrMsg, err) + }) +} + +func TestIsAddrInUse(t *testing.T) { + t.Run("addr_in_use", func(t *testing.T) { + l, err := net.Listen("tcp", "0.0.0.0:0") + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, l.Close) + + _, err = net.Listen(l.Addr().Network(), l.Addr().String()) + assert.True(t, IsAddrInUse(err)) + }) + + t.Run("another", func(t *testing.T) { + const anotherErr errors.Error = "not addr in use" + + assert.False(t, IsAddrInUse(anotherErr)) + }) +} + +func TestNetInterface_MarshalJSON(t *testing.T) { + const want = `{` + + `"hardware_address":"aa:bb:cc:dd:ee:ff",` + + `"flags":"up|multicast",` + + `"ip_addresses":["1.2.3.4","aaaa::1"],` + + `"name":"iface0",` + + `"mtu":1500` + + `}` + "\n" + + ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4}) + require.True(t, ok) + + ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) + require.True(t, ok) + + net4 := netip.PrefixFrom(ip4, 24) + net6 := netip.PrefixFrom(ip6, 8) + + iface := &NetInterface{ + Addresses: []netip.Addr{ip4, ip6}, + Subnets: []netip.Prefix{net4, net6}, + Name: "iface0", + HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}, + Flags: net.FlagUp | net.FlagMulticast, + MTU: 1500, + } + + b := &bytes.Buffer{} + err := json.NewEncoder(b).Encode(iface) + require.NoError(t, err) + + assert.Equal(t, want, b.String()) +} diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index 6e9e612e..8615eed9 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -1,21 +1,11 @@ -package aghnet +package aghnet_test import ( - "bytes" - "encoding/json" - "fmt" "io/fs" - "net" - "net/netip" "os" - "strings" "testing" - "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { @@ -24,315 +14,3 @@ func TestMain(m *testing.M) { // testdata is the filesystem containing data for testing the package. var testdata fs.FS = os.DirFS("./testdata") - -// substRootDirFS replaces the aghos.RootDirFS function used throughout the -// package with fsys for tests ran under t. -func substRootDirFS(t testing.TB, fsys fs.FS) { - t.Helper() - - prev := rootDirFS - t.Cleanup(func() { rootDirFS = prev }) - rootDirFS = fsys -} - -// RunCmdFunc is the signature of aghos.RunCommand function. -type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error) - -// substShell replaces the the aghos.RunCommand function used throughout the -// package with rc for tests ran under t. -func substShell(t testing.TB, rc RunCmdFunc) { - t.Helper() - - prev := aghosRunCommand - t.Cleanup(func() { aghosRunCommand = prev }) - aghosRunCommand = rc -} - -// mapShell is a substitution of aghos.RunCommand that maps the command to it's -// execution result. It's only needed to simplify testing. -// -// TODO(e.burkov): Perhaps put all the shell interactions behind an interface. -type mapShell map[string]struct { - err error - out string - code int -} - -// theOnlyCmd returns mapShell that only handles a single command and arguments -// combination from cmd. -func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) { - return mapShell{cmd: {code: code, out: out, err: err}} -} - -// RunCmd is a RunCmdFunc handled by s. -func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) { - key := strings.Join(append([]string{cmd}, args...), " ") - ret, ok := s[key] - if !ok { - return 0, nil, fmt.Errorf("unexpected shell command %q", key) - } - - return ret.code, []byte(ret.out), ret.err -} - -// ifaceAddrsFunc is the signature of net.InterfaceAddrs function. -type ifaceAddrsFunc func() (ifaces []net.Addr, err error) - -// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used -// throughout the package with f for tests ran under t. -func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) { - t.Helper() - - prev := netInterfaceAddrs - t.Cleanup(func() { netInterfaceAddrs = prev }) - netInterfaceAddrs = f -} - -func TestGatewayIP(t *testing.T) { - const ifaceName = "ifaceName" - const cmd = "ip route show dev " + ifaceName - - testCases := []struct { - shell mapShell - want netip.Addr - name string - }{{ - shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil), - want: netip.MustParseAddr("1.2.3.4"), - name: "success_v4", - }, { - shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil), - want: netip.MustParseAddr("::ffff"), - name: "success_v6", - }, { - shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil), - want: netip.Addr{}, - name: "bad_output", - }, { - shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")), - want: netip.Addr{}, - name: "err_runcmd", - }, { - shell: theOnlyCmd(cmd, 1, "", nil), - want: netip.Addr{}, - name: "bad_code", - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - substShell(t, tc.shell.RunCmd) - - assert.Equal(t, tc.want, GatewayIP(ifaceName)) - }) - } -} - -func TestInterfaceByIP(t *testing.T) { - ifaces, err := GetValidNetInterfacesForWeb() - require.NoError(t, err) - require.NotEmpty(t, ifaces) - - for _, iface := range ifaces { - t.Run(iface.Name, func(t *testing.T) { - require.NotEmpty(t, iface.Addresses) - - for _, ip := range iface.Addresses { - ifaceName := InterfaceByIP(ip) - require.Equal(t, iface.Name, ifaceName) - } - }) - } -} - -func TestBroadcastFromIPNet(t *testing.T) { - known4 := netip.MustParseAddr("192.168.0.1") - fullBroadcast4 := netip.MustParseAddr("255.255.255.255") - - known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10") - - testCases := []struct { - pref netip.Prefix - want netip.Addr - name string - }{{ - pref: netip.PrefixFrom(known4, 0), - want: fullBroadcast4, - name: "full", - }, { - pref: netip.PrefixFrom(known4, 20), - want: netip.MustParseAddr("192.168.15.255"), - name: "full", - }, { - pref: netip.PrefixFrom(known6, netutil.IPv6BitLen), - want: known6, - name: "ipv6_no_mask", - }, { - pref: netip.PrefixFrom(known4, netutil.IPv4BitLen), - want: known4, - name: "ipv4_no_mask", - }, { - pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0), - want: fullBroadcast4, - name: "unspecified", - }, { - pref: netip.Prefix{}, - want: netip.Addr{}, - name: "invalid", - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, BroadcastFromPref(tc.pref)) - }) - } -} - -func TestCheckPort(t *testing.T) { - laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0) - - t.Run("tcp_bound", func(t *testing.T) { - l, err := net.Listen("tcp", laddr.String()) - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, l.Close) - - ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort() - require.Equal(t, laddr.Addr(), ipp.Addr()) - require.NotZero(t, ipp.Port()) - - err = CheckPort("tcp", ipp) - target := &net.OpError{} - require.ErrorAs(t, err, &target) - - assert.Equal(t, "listen", target.Op) - }) - - t.Run("udp_bound", func(t *testing.T) { - conn, err := net.ListenPacket("udp", laddr.String()) - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, conn.Close) - - ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort() - require.Equal(t, laddr.Addr(), ipp.Addr()) - require.NotZero(t, ipp.Port()) - - err = CheckPort("udp", ipp) - target := &net.OpError{} - require.ErrorAs(t, err, &target) - - assert.Equal(t, "listen", target.Op) - }) - - t.Run("bad_network", func(t *testing.T) { - err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0)) - assert.NoError(t, err) - }) - - t.Run("can_bind", func(t *testing.T) { - err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0)) - assert.NoError(t, err) - }) -} - -func TestCollectAllIfacesAddrs(t *testing.T) { - testCases := []struct { - name string - wantErrMsg string - addrs []net.Addr - wantAddrs []string - }{{ - name: "success", - wantErrMsg: ``, - addrs: []net.Addr{&net.IPNet{ - IP: net.IP{1, 2, 3, 4}, - Mask: net.CIDRMask(24, netutil.IPv4BitLen), - }, &net.IPNet{ - IP: net.IP{4, 3, 2, 1}, - Mask: net.CIDRMask(16, netutil.IPv4BitLen), - }}, - wantAddrs: []string{"1.2.3.4", "4.3.2.1"}, - }, { - name: "not_cidr", - wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`, - addrs: []net.Addr{&net.IPAddr{ - IP: net.IP{1, 2, 3, 4}, - }}, - wantAddrs: nil, - }, { - name: "empty", - wantErrMsg: ``, - addrs: []net.Addr{}, - wantAddrs: nil, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil }) - - addrs, err := CollectAllIfacesAddrs() - testutil.AssertErrorMsg(t, tc.wantErrMsg, err) - - assert.Equal(t, tc.wantAddrs, addrs) - }) - } - - t.Run("internal_error", func(t *testing.T) { - const errAddrs errors.Error = "can't get addresses" - const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs) - - substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs }) - - _, err := CollectAllIfacesAddrs() - testutil.AssertErrorMsg(t, wantErrMsg, err) - }) -} - -func TestIsAddrInUse(t *testing.T) { - t.Run("addr_in_use", func(t *testing.T) { - l, err := net.Listen("tcp", "0.0.0.0:0") - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, l.Close) - - _, err = net.Listen(l.Addr().Network(), l.Addr().String()) - assert.True(t, IsAddrInUse(err)) - }) - - t.Run("another", func(t *testing.T) { - const anotherErr errors.Error = "not addr in use" - - assert.False(t, IsAddrInUse(anotherErr)) - }) -} - -func TestNetInterface_MarshalJSON(t *testing.T) { - const want = `{` + - `"hardware_address":"aa:bb:cc:dd:ee:ff",` + - `"flags":"up|multicast",` + - `"ip_addresses":["1.2.3.4","aaaa::1"],` + - `"name":"iface0",` + - `"mtu":1500` + - `}` + "\n" - - ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4}) - require.True(t, ok) - - ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) - require.True(t, ok) - - net4 := netip.PrefixFrom(ip4, 24) - net6 := netip.PrefixFrom(ip6, 8) - - iface := &NetInterface{ - Addresses: []netip.Addr{ip4, ip6}, - Subnets: []netip.Prefix{net4, net6}, - Name: "iface0", - HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF}, - Flags: net.FlagUp | net.FlagMulticast, - MTU: 1500, - } - - b := &bytes.Buffer{} - err := json.NewEncoder(b).Encode(iface) - require.NoError(t, err) - - assert.Equal(t, want, b.String()) -} diff --git a/internal/aghtest/aghtest.go b/internal/aghtest/aghtest.go index 850446e0..dfe0551d 100644 --- a/internal/aghtest/aghtest.go +++ b/internal/aghtest/aghtest.go @@ -2,7 +2,9 @@ package aghtest import ( + "crypto/sha256" "io" + "net" "testing" "github.com/AdguardTeam/golibs/log" @@ -34,3 +36,10 @@ func ReplaceLogLevel(t testing.TB, l log.Level) { t.Cleanup(func() { log.SetLevel(prev) }) log.SetLevel(l) } + +// HostToIPs is a helper that generates one IPv4 and one IPv6 address from host. +func HostToIPs(host string) (ipv4, ipv6 net.IP) { + hash := sha256.Sum256([]byte(host)) + + return net.IP(hash[:4]), net.IP(hash[4:20]) +} diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index 779e2fe5..a84e9af6 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -2,8 +2,7 @@ package aghtest import ( "context" - "io" - "io/fs" + "net" "net/netip" "github.com/AdguardTeam/AdGuardHome/internal/aghos" @@ -19,67 +18,6 @@ import ( // // Keep entities in this file in alphabetic order. -// Standard Library - -// Package fs - -// FS is a fake [fs.FS] implementation for tests. -type FS struct { - OnOpen func(name string) (fs.File, error) -} - -// type check -var _ fs.FS = (*FS)(nil) - -// Open implements the [fs.FS] interface for *FS. -func (fsys *FS) Open(name string) (fs.File, error) { - return fsys.OnOpen(name) -} - -// type check -var _ fs.GlobFS = (*GlobFS)(nil) - -// GlobFS is a fake [fs.GlobFS] implementation for tests. -type GlobFS struct { - // FS is embedded here to avoid implementing all it's methods. - FS - OnGlob func(pattern string) ([]string, error) -} - -// Glob implements the [fs.GlobFS] interface for *GlobFS. -func (fsys *GlobFS) Glob(pattern string) ([]string, error) { - return fsys.OnGlob(pattern) -} - -// type check -var _ fs.StatFS = (*StatFS)(nil) - -// StatFS is a fake [fs.StatFS] implementation for tests. -type StatFS struct { - // FS is embedded here to avoid implementing all it's methods. - FS - OnStat func(name string) (fs.FileInfo, error) -} - -// Stat implements the [fs.StatFS] interface for *StatFS. -func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) { - return fsys.OnStat(name) -} - -// Package io - -// Writer is a fake [io.Writer] implementation for tests. -type Writer struct { - OnWrite func(b []byte) (n int, err error) -} - -var _ io.Writer = (*Writer)(nil) - -// Write implements the [io.Writer] interface for *Writer. -func (w *Writer) Write(b []byte) (n int, err error) { - return w.OnWrite(b) -} - // Module adguard-home // Package aghos @@ -177,6 +115,18 @@ func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.I p.OnUpdateAddress(ip, host, info) } +// Package filtering + +// Resolver is a fake [filtering.Resolver] implementation for tests. +type Resolver struct { + OnLookupIP func(ctx context.Context, network, host string) (ips []net.IP, err error) +} + +// LookupIP implements the [filtering.Resolver] interface for *Resolver. +func (r *Resolver) LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error) { + return r.OnLookupIP(ctx, network, host) +} + // Package rdns // Exchanger is a fake [rdns.Exchanger] implementation for tests. diff --git a/internal/aghtest/interface_test.go b/internal/aghtest/interface_test.go index 9141d132..a17c5e67 100644 --- a/internal/aghtest/interface_test.go +++ b/internal/aghtest/interface_test.go @@ -1,3 +1,11 @@ package aghtest_test +import ( + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" +) + // Put interface checks that cause import cycles here. + +// type check +var _ filtering.Resolver = (*aghtest.Resolver)(nil) diff --git a/internal/aghtest/resolver.go b/internal/aghtest/resolver.go deleted file mode 100644 index 3c2df964..00000000 --- a/internal/aghtest/resolver.go +++ /dev/null @@ -1,57 +0,0 @@ -package aghtest - -import ( - "context" - "crypto/sha256" - "net" - "sync" -) - -// TestResolver is a Resolver for tests. -type TestResolver struct { - counter int - counterLock sync.Mutex -} - -// HostToIPs generates IPv4 and IPv6 from host. -func (r *TestResolver) HostToIPs(host string) (ipv4, ipv6 net.IP) { - hash := sha256.Sum256([]byte(host)) - - return net.IP(hash[:4]), net.IP(hash[4:20]) -} - -// LookupIP implements Resolver interface for *testResolver. It returns the -// slice of net.IP with IPv4 and IPv6 instances. -func (r *TestResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) { - ipv4, ipv6 := r.HostToIPs(host) - addrs := []net.IP{ipv4, ipv6} - - r.counterLock.Lock() - defer r.counterLock.Unlock() - r.counter++ - - return addrs, nil -} - -// LookupHost implements Resolver interface for *testResolver. It returns the -// slice of IPv4 and IPv6 instances converted to strings. -func (r *TestResolver) LookupHost(host string) (addrs []string, err error) { - ipv4, ipv6 := r.HostToIPs(host) - - r.counterLock.Lock() - defer r.counterLock.Unlock() - r.counter++ - - return []string{ - ipv4.String(), - ipv6.String(), - }, nil -} - -// Counter returns the number of requests handled. -func (r *TestResolver) Counter() int { - r.counterLock.Lock() - defer r.counterLock.Unlock() - - return r.counter -} diff --git a/internal/client/addrproc.go b/internal/client/addrproc.go index 72969ced..04ee50d5 100644 --- a/internal/client/addrproc.go +++ b/internal/client/addrproc.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/errors" @@ -39,7 +40,7 @@ func (EmptyAddrProc) Close() (_ error) { return nil } type DefaultAddrProcConfig struct { // DialContext is used to create TCP connections to WHOIS servers. // DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true. - DialContext whois.DialContextFunc + DialContext aghnet.DialContextFunc // Exchanger is used to perform rDNS queries. Exchanger must not be nil if // [DefaultAddrProcConfig.UseRDNS] is true. @@ -161,7 +162,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) { // newWHOIS returns a whois.Interface instance using the given function for // dialing. -func newWHOIS(dialFunc whois.DialContextFunc) (w whois.Interface) { +func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) { // TODO(s.chzhen): Consider making configurable. const ( // defaultTimeout is the timeout for WHOIS requests. diff --git a/internal/dnsforward/dialcontext.go b/internal/dnsforward/dialcontext.go index db32dd3d..f917f54c 100644 --- a/internal/dnsforward/dialcontext.go +++ b/internal/dnsforward/dialcontext.go @@ -10,7 +10,7 @@ import ( "github.com/AdguardTeam/golibs/log" ) -// DialContext is a [whois.DialContextFunc] that uses s to resolve hostnames. +// DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames. func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { log.Debug("dnsforward: dialing %q for network %q", addr, network) diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index bbb42116..eb077aff 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -1,6 +1,7 @@ package dnsforward import ( + "context" "crypto/ecdsa" "crypto/rand" "crypto/rsa" @@ -467,7 +468,14 @@ func TestServerRace(t *testing.T) { } func TestSafeSearch(t *testing.T) { - resolver := &aghtest.TestResolver{} + resolver := &aghtest.Resolver{ + OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) { + ip4, ip6 := aghtest.HostToIPs(host) + + return []net.IP{ip4, ip6}, nil + }, + } + safeSearchConf := filtering.SafeSearchConfig{ Enabled: true, Google: true, @@ -506,7 +514,7 @@ func TestSafeSearch(t *testing.T) { client := &dns.Client{} yandexIP := net.IP{213, 180, 193, 56} - googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com") + googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com") testCases := []struct { host string @@ -954,7 +962,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { Upstream: aghtest.NewBlockUpstream(hostname, true), }) - ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname) + ans4, _ := aghtest.HostToIPs(hostname) filterConf := &filtering.Config{ SafeBrowsingEnabled: true, diff --git a/internal/filtering/rulelist/parser_test.go b/internal/filtering/rulelist/parser_test.go index 3ca3565d..5554458d 100644 --- a/internal/filtering/rulelist/parser_test.go +++ b/internal/filtering/rulelist/parser_test.go @@ -6,10 +6,10 @@ import ( "strings" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/golibs/testutil/fakeio" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -159,7 +159,7 @@ func TestParser_Parse(t *testing.T) { func TestParser_Parse_writeError(t *testing.T) { t.Parallel() - dst := &aghtest.Writer{ + dst := &fakeio.Writer{ OnWrite: func(b []byte) (n int, err error) { return 1, errors.Error("test error") }, diff --git a/internal/filtering/safesearch/safesearch_internal_test.go b/internal/filtering/safesearch/safesearch_internal_test.go index c87a9ad5..909265ee 100644 --- a/internal/filtering/safesearch/safesearch_internal_test.go +++ b/internal/filtering/safesearch/safesearch_internal_test.go @@ -89,37 +89,34 @@ func TestSafeSearchCacheGoogle(t *testing.T) { assert.False(t, res.IsFiltered) assert.Empty(t, res.Rules) - resolver := &aghtest.TestResolver{} + resolver := &aghtest.Resolver{ + OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) { + ip4, ip6 := aghtest.HostToIPs(host) + + return []net.IP{ip4, ip6}, nil + }, + } + ss = newForTest(t, defaultSafeSearchConf) ss.resolver = resolver // Lookup for safesearch domain. rewrite := ss.searchHost(domain, testQType) - ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME) - require.NoError(t, err) - - var foundIP net.IP - for _, ip := range ips { - if ip.To4() != nil { - foundIP = ip - - break - } - } + wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME) res, err = ss.CheckHost(domain, testQType) require.NoError(t, err) require.Len(t, res.Rules, 1) - assert.True(t, res.Rules[0].IP.Equal(foundIP)) + assert.True(t, res.Rules[0].IP.Equal(wantIP)) // Check cache. cachedValue, isFound := ss.getCachedResult(domain, testQType) require.True(t, isFound) require.Len(t, cachedValue.Rules, 1) - assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP)) + assert.True(t, cachedValue.Rules[0].IP.Equal(wantIP)) } const googleHost = "www.google.com" diff --git a/internal/filtering/safesearch/safesearch_test.go b/internal/filtering/safesearch/safesearch_test.go index 12860c5d..c62dd6e4 100644 --- a/internal/filtering/safesearch/safesearch_test.go +++ b/internal/filtering/safesearch/safesearch_test.go @@ -92,8 +92,15 @@ func TestDefault_CheckHost_yandexAAAA(t *testing.T) { } func TestDefault_CheckHost_google(t *testing.T) { - resolver := &aghtest.TestResolver{} - ip, _ := resolver.HostToIPs("forcesafesearch.google.com") + resolver := &aghtest.Resolver{ + OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) { + ip4, ip6 := aghtest.HostToIPs(host) + + return []net.IP{ip4, ip6}, nil + }, + } + + wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com") conf := testConf conf.CustomResolver = resolver @@ -119,7 +126,7 @@ func TestDefault_CheckHost_google(t *testing.T) { require.Len(t, res.Rules, 1) - assert.Equal(t, ip, res.Rules[0].IP) + assert.Equal(t, wantIP, res.Rules[0].IP) assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID) }) } diff --git a/internal/next/websvc/websvc_test.go b/internal/next/websvc/websvc_test.go index b0d32902..ab8e485d 100644 --- a/internal/next/websvc/websvc_test.go +++ b/internal/next/websvc/websvc_test.go @@ -12,11 +12,11 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/golibs/testutil" + "github.com/AdguardTeam/golibs/testutil/fakefs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -90,7 +90,7 @@ func newTestServer( c := &websvc.Config{ ConfigManager: confMgr, - Frontend: &aghtest.FS{ + Frontend: &fakefs.FS{ OnOpen: func(_ string) (_ fs.File, _ error) { return nil, fs.ErrNotExist }, }, TLS: nil, diff --git a/internal/whois/whois.go b/internal/whois/whois.go index 49c179b8..ae01304b 100644 --- a/internal/whois/whois.go +++ b/internal/whois/whois.go @@ -13,6 +13,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghio" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" @@ -49,7 +50,7 @@ func (Empty) Process(_ context.Context, _ netip.Addr) (info *Info, changed bool) // Config is the configuration structure for Default. type Config struct { // DialContext is used to create TCP connections to WHOIS servers. - DialContext DialContextFunc + DialContext aghnet.DialContextFunc // ServerAddr is the address of the WHOIS server. ServerAddr string @@ -77,13 +78,6 @@ type Config struct { Port uint16 } -// DialContextFunc is the semantic alias for dialing functions, such as -// [http.Transport.DialContext]. -// -// TODO(a.garipov): Move to aghnet once it stops importing aghtest, because -// otherwise there is an import cycle. -type DialContextFunc = func(ctx context.Context, network, addr string) (conn net.Conn, err error) - // Default is the default WHOIS information processor. type Default struct { // cache is the cache containing IP addresses of clients. An active IP @@ -93,7 +87,7 @@ type Default struct { cache gcache.Cache // dialContext is used to create TCP connections to WHOIS servers. - dialContext DialContextFunc + dialContext aghnet.DialContextFunc // serverAddr is the address of the WHOIS server. serverAddr string