Pull request 1933: upd-golibs

Squashed commit of the following:

commit 081d10e6909def3a075707e75dbd0c5f63f91903
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Jul 20 14:17:01 2023 +0300

    aghnet: fix docs

commit 7433b72c0653cb33fe5ff810ae8a1346a6994f95
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Jul 20 14:03:16 2023 +0300

    all: imp tests; upd golibs
This commit is contained in:
Ainar Garipov 2023-07-20 14:26:35 +03:00
parent 4e8d3d7628
commit 5be0e84719
22 changed files with 587 additions and 638 deletions

2
go.mod
View File

@ -4,7 +4,7 @@ go 1.19
require ( require (
github.com/AdguardTeam/dnsproxy v0.52.0 github.com/AdguardTeam/dnsproxy v0.52.0
github.com/AdguardTeam/golibs v0.13.4 github.com/AdguardTeam/golibs v0.13.5
github.com/AdguardTeam/urlfilter v0.16.1 github.com/AdguardTeam/urlfilter v0.16.1
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.7 github.com/ameshkov/dnscrypt/v2 v2.2.7

4
go.sum
View File

@ -2,8 +2,8 @@ github.com/AdguardTeam/dnsproxy v0.52.0 h1:uZxCXflHSAwtJ7uTYXP6qgWcxaBsH0pJvldpw
github.com/AdguardTeam/dnsproxy v0.52.0/go.mod h1:Jo2zeRe97Rxt3yikXc+fn0LdLtqCj0Xlyh1PNBj6bpM= github.com/AdguardTeam/dnsproxy v0.52.0/go.mod h1:Jo2zeRe97Rxt3yikXc+fn0LdLtqCj0Xlyh1PNBj6bpM=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/golibs v0.13.4 h1:ACTwIR1pEENBijHcEWtiMbSh4wWQOlIHRxmUB8oBHf8= github.com/AdguardTeam/golibs v0.13.5 h1:fpa30Yr9Rcn4vJ88nE4XHSompY7/qMOq2aNS/4PGymA=
github.com/AdguardTeam/golibs v0.13.4/go.mod h1:wkJ6EUsN4np/9Gp7+9QeooY9E2U2WCLJYAioLCzkHsI= github.com/AdguardTeam/golibs v0.13.5/go.mod h1:wkJ6EUsN4np/9Gp7+9QeooY9E2U2WCLJYAioLCzkHsI=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw= github.com/AdguardTeam/urlfilter v0.16.1 h1:ZPi0rjqo8cQf2FVdzo6cqumNoHZx2KPXj2yZa1A5BBw=
github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI= github.com/AdguardTeam/urlfilter v0.16.1/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=

View File

@ -1,10 +1,11 @@
package aghio package aghio_test
import ( import (
"io" "io"
"strings" "strings"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -31,7 +32,7 @@ func TestLimitReader(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
_, err := LimitReader(nil, tc.n) _, err := aghio.LimitReader(nil, tc.n)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err) testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
}) })
} }
@ -57,7 +58,7 @@ func TestLimitedReader_Read(t *testing.T) {
limit: 3, limit: 3,
want: 0, want: 0,
}, { }, {
err: &LimitReachedError{ err: &aghio.LimitReachedError{
Limit: 0, Limit: 0,
}, },
name: "limit_reached", name: "limit_reached",
@ -74,7 +75,7 @@ func TestLimitedReader_Read(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
readCloser := io.NopCloser(strings.NewReader(tc.rStr)) readCloser := io.NopCloser(strings.NewReader(tc.rStr))
lreader, err := LimitReader(readCloser, tc.limit) lreader, err := aghio.LimitReader(readCloser, tc.limit)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, lreader) require.NotNil(t, lreader)
@ -89,7 +90,7 @@ func TestLimitedReader_Read(t *testing.T) {
} }
func TestLimitedReader_LimitReachedError(t *testing.T) { func TestLimitedReader_LimitReachedError(t *testing.T) {
testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &LimitReachedError{ testutil.AssertErrorMsg(t, "attempted to read more than 0 bytes", &aghio.LimitReachedError{
Limit: 0, Limit: 0,
}) })
} }

View File

@ -141,9 +141,9 @@ type HostsRecord struct {
Canonical string Canonical string
} }
// equal returns true if all fields of rec are equal to field in other or they // Equal returns true if all fields of rec are equal to field in other or they
// both are nil. // both are nil.
func (rec *HostsRecord) equal(other *HostsRecord) (ok bool) { func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) {
if rec == nil { if rec == nil {
return other == nil return other == nil
} else if other == nil { } else if other == nil {
@ -495,7 +495,7 @@ func (hc *HostsContainer) refresh() (err error) {
} }
// hc.last is nil on the first refresh, so let that one through. // hc.last is nil on the first refresh, so let that one through.
if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).equal) { if hc.last != nil && maps.EqualFunc(hp.table, hc.last, (*HostsRecord).Equal) {
log.Debug("%s: no changes detected", hostsContainerPrefix) log.Debug("%s: no changes detected", hostsContainerPrefix)
return nil return nil

View File

@ -0,0 +1,144 @@
package aghnet
import (
"io/fs"
"net/netip"
"path"
"testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil/fakefs"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const nl = "\n"
func TestHostsContainer_PathsToPatterns(t *testing.T) {
gsfs := fstest.MapFS{
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
}
testCases := []struct {
name string
paths []string
want []string
}{{
name: "no_paths",
paths: nil,
want: nil,
}, {
name: "single_file",
paths: []string{"dir_0/file_1"},
want: []string{"dir_0/file_1"},
}, {
name: "several_files",
paths: []string{"dir_0/file_1", "dir_0/file_2"},
want: []string{"dir_0/file_1", "dir_0/file_2"},
}, {
name: "whole_dir",
paths: []string{"dir_0"},
want: []string{"dir_0/*"},
}, {
name: "file_and_dir",
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
}, {
name: "non-existing",
paths: []string{path.Join("dir_0", "file_3")},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
patterns, err := pathsToPatterns(gsfs, tc.paths)
require.NoError(t, err)
assert.Equal(t, tc.want, patterns)
})
}
t.Run("bad_file", func(t *testing.T) {
const errStat errors.Error = "bad file"
badFS := &fakefs.StatFS{
OnOpen: func(_ string) (f fs.File, err error) { panic("not implemented") },
OnStat: func(name string) (fi fs.FileInfo, err error) {
return nil, errStat
},
}
_, err := pathsToPatterns(badFS, []string{""})
assert.ErrorIs(t, err, errStat)
})
}
func TestUniqueRules_ParseLine(t *testing.T) {
ip := netutil.IPv4Localhost()
ipStr := ip.String()
testCases := []struct {
name string
line string
wantIP netip.Addr
wantHosts []string
}{{
name: "simple",
line: ipStr + ` hostname`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "aliases",
line: ipStr + ` hostname alias`,
wantIP: ip,
wantHosts: []string{"hostname", "alias"},
}, {
name: "invalid_line",
line: ipStr,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "invalid_line_hostname",
line: ipStr + ` # hostname`,
wantIP: ip,
wantHosts: nil,
}, {
name: "commented_aliases",
line: ipStr + ` hostname # alias`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "whole_comment",
line: `# ` + ipStr + ` hostname`,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "partial_comment",
line: ipStr + ` host#name`,
wantIP: ip,
wantHosts: []string{"host"},
}, {
name: "empty",
line: ``,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "bad_hosts",
line: ipStr + ` bad..host bad._tld empty.tld. ok.host`,
wantIP: ip,
wantHosts: []string{"ok.host"},
}}
for _, tc := range testCases {
hp := hostsParser{}
t.Run(tc.name, func(t *testing.T) {
got, hosts := hp.parseLine(tc.line)
assert.Equal(t, tc.wantIP, got)
assert.Equal(t, tc.wantHosts, hosts)
})
}
}

View File

@ -1,9 +1,7 @@
package aghnet package aghnet_test
import ( import (
"io/fs"
"net" "net"
"net/netip"
"path" "path"
"strings" "strings"
"sync/atomic" "sync/atomic"
@ -12,6 +10,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghchan" "github.com/AdguardTeam/AdGuardHome/internal/aghchan"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@ -24,10 +23,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
const ( const nl = "\n"
nl = "\n"
sp = " "
)
func TestNewHostsContainer(t *testing.T) { func TestNewHostsContainer(t *testing.T) {
const dirname = "dir" const dirname = "dir"
@ -48,11 +44,11 @@ func TestNewHostsContainer(t *testing.T) {
name: "one_file", name: "one_file",
paths: []string{p}, paths: []string{p},
}, { }, {
wantErr: ErrNoHostsPaths, wantErr: aghnet.ErrNoHostsPaths,
name: "no_files", name: "no_files",
paths: []string{}, paths: []string{},
}, { }, {
wantErr: ErrNoHostsPaths, wantErr: aghnet.ErrNoHostsPaths,
name: "non-existent_file", name: "non-existent_file",
paths: []string{path.Join(dirname, filename+"2")}, paths: []string{path.Join(dirname, filename+"2")},
}, { }, {
@ -77,7 +73,7 @@ func TestNewHostsContainer(t *testing.T) {
return eventsCh return eventsCh
} }
hc, err := NewHostsContainer(0, testFS, &aghtest.FSWatcher{ hc, err := aghnet.NewHostsContainer(0, testFS, &aghtest.FSWatcher{
OnEvents: onEvents, OnEvents: onEvents,
OnAdd: onAdd, OnAdd: onAdd,
OnClose: func() (err error) { return nil }, OnClose: func() (err error) { return nil },
@ -103,7 +99,7 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_fs", func(t *testing.T) { t.Run("nil_fs", func(t *testing.T) {
require.Panics(t, func() { require.Panics(t, func() {
_, _ = NewHostsContainer(0, nil, &aghtest.FSWatcher{ _, _ = aghnet.NewHostsContainer(0, nil, &aghtest.FSWatcher{
// Those shouldn't panic. // Those shouldn't panic.
OnEvents: func() (e <-chan struct{}) { return nil }, OnEvents: func() (e <-chan struct{}) { return nil },
OnAdd: func(name string) (err error) { return nil }, OnAdd: func(name string) (err error) { return nil },
@ -114,7 +110,7 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_watcher", func(t *testing.T) { t.Run("nil_watcher", func(t *testing.T) {
require.Panics(t, func() { require.Panics(t, func() {
_, _ = NewHostsContainer(0, testFS, nil, p) _, _ = aghnet.NewHostsContainer(0, testFS, nil, p)
}) })
}) })
@ -127,7 +123,7 @@ func TestNewHostsContainer(t *testing.T) {
OnClose: func() (err error) { return nil }, OnClose: func() (err error) { return nil },
} }
hc, err := NewHostsContainer(0, testFS, errWatcher, p) hc, err := aghnet.NewHostsContainer(0, testFS, errWatcher, p)
require.ErrorIs(t, err, errOnAdd) require.ErrorIs(t, err, errOnAdd)
assert.Nil(t, hc) assert.Nil(t, hc)
@ -158,11 +154,11 @@ func TestHostsContainer_refresh(t *testing.T) {
OnClose: func() (err error) { return nil }, OnClose: func() (err error) { return nil },
} }
hc, err := NewHostsContainer(0, testFS, w, "dir") hc, err := aghnet.NewHostsContainer(0, testFS, w, "dir")
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close) testutil.CleanupAndRequireSuccess(t, hc.Close)
checkRefresh := func(t *testing.T, want *HostsRecord) { checkRefresh := func(t *testing.T, want *aghnet.HostsRecord) {
t.Helper() t.Helper()
upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second) upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
@ -175,11 +171,11 @@ func TestHostsContainer_refresh(t *testing.T) {
require.True(t, ok) require.True(t, ok)
require.NotNil(t, rec) require.NotNil(t, rec)
assert.Truef(t, rec.equal(want), "%+v != %+v", rec, want) assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want)
} }
t.Run("initial_refresh", func(t *testing.T) { t.Run("initial_refresh", func(t *testing.T) {
checkRefresh(t, &HostsRecord{ checkRefresh(t, &aghnet.HostsRecord{
Aliases: stringutil.NewSet(), Aliases: stringutil.NewSet(),
Canonical: "hostname", Canonical: "hostname",
}) })
@ -189,7 +185,7 @@ func TestHostsContainer_refresh(t *testing.T) {
testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)} testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)}
eventsCh <- event{} eventsCh <- event{}
checkRefresh(t, &HostsRecord{ checkRefresh(t, &aghnet.HostsRecord{
Aliases: stringutil.NewSet("alias"), Aliases: stringutil.NewSet("alias"),
Canonical: "hostname", Canonical: "hostname",
}) })
@ -228,66 +224,6 @@ func TestHostsContainer_refresh(t *testing.T) {
}) })
} }
func TestHostsContainer_PathsToPatterns(t *testing.T) {
gsfs := fstest.MapFS{
"dir_0/file_1": &fstest.MapFile{Data: []byte{1}},
"dir_0/file_2": &fstest.MapFile{Data: []byte{2}},
"dir_0/dir_1/file_3": &fstest.MapFile{Data: []byte{3}},
}
testCases := []struct {
name string
paths []string
want []string
}{{
name: "no_paths",
paths: nil,
want: nil,
}, {
name: "single_file",
paths: []string{"dir_0/file_1"},
want: []string{"dir_0/file_1"},
}, {
name: "several_files",
paths: []string{"dir_0/file_1", "dir_0/file_2"},
want: []string{"dir_0/file_1", "dir_0/file_2"},
}, {
name: "whole_dir",
paths: []string{"dir_0"},
want: []string{"dir_0/*"},
}, {
name: "file_and_dir",
paths: []string{"dir_0/file_1", "dir_0/dir_1"},
want: []string{"dir_0/file_1", "dir_0/dir_1/*"},
}, {
name: "non-existing",
paths: []string{path.Join("dir_0", "file_3")},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
patterns, err := pathsToPatterns(gsfs, tc.paths)
require.NoError(t, err)
assert.Equal(t, tc.want, patterns)
})
}
t.Run("bad_file", func(t *testing.T) {
const errStat errors.Error = "bad file"
badFS := &aghtest.StatFS{
OnStat: func(name string) (fs.FileInfo, error) {
return nil, errStat
},
}
_, err := pathsToPatterns(badFS, []string{""})
assert.ErrorIs(t, err, errStat)
})
}
func TestHostsContainer_Translate(t *testing.T) { func TestHostsContainer_Translate(t *testing.T) {
stubWatcher := aghtest.FSWatcher{ stubWatcher := aghtest.FSWatcher{
OnEvents: func() (e <-chan struct{}) { return nil }, OnEvents: func() (e <-chan struct{}) { return nil },
@ -297,7 +233,7 @@ func TestHostsContainer_Translate(t *testing.T) {
require.NoError(t, fstest.TestFS(testdata, "etc_hosts")) require.NoError(t, fstest.TestFS(testdata, "etc_hosts"))
hc, err := NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts") hc, err := aghnet.NewHostsContainer(0, testdata, &stubWatcher, "etc_hosts")
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close) testutil.CleanupAndRequireSuccess(t, hc.Close)
@ -527,7 +463,7 @@ func TestHostsContainer(t *testing.T) {
OnClose: func() (err error) { return nil }, OnClose: func() (err error) { return nil },
} }
hc, err := NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts") hc, err := aghnet.NewHostsContainer(listID, testdata, &stubWatcher, "etc_hosts")
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close) testutil.CleanupAndRequireSuccess(t, hc.Close)
@ -558,69 +494,3 @@ func TestHostsContainer(t *testing.T) {
}) })
} }
} }
func TestUniqueRules_ParseLine(t *testing.T) {
ip := netutil.IPv4Localhost()
ipStr := ip.String()
testCases := []struct {
name string
line string
wantIP netip.Addr
wantHosts []string
}{{
name: "simple",
line: ipStr + ` hostname`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "aliases",
line: ipStr + ` hostname alias`,
wantIP: ip,
wantHosts: []string{"hostname", "alias"},
}, {
name: "invalid_line",
line: ipStr,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "invalid_line_hostname",
line: ipStr + ` # hostname`,
wantIP: ip,
wantHosts: nil,
}, {
name: "commented_aliases",
line: ipStr + ` hostname # alias`,
wantIP: ip,
wantHosts: []string{"hostname"},
}, {
name: "whole_comment",
line: `# ` + ipStr + ` hostname`,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "partial_comment",
line: ipStr + ` host#name`,
wantIP: ip,
wantHosts: []string{"host"},
}, {
name: "empty",
line: ``,
wantIP: netip.Addr{},
wantHosts: nil,
}, {
name: "bad_hosts",
line: ipStr + ` bad..host bad._tld empty.tld. ok.host`,
wantIP: ip,
wantHosts: []string{"ok.host"},
}}
for _, tc := range testCases {
hp := hostsParser{}
t.Run(tc.name, func(t *testing.T) {
got, hosts := hp.parseLine(tc.line)
assert.Equal(t, tc.wantIP, got)
assert.Equal(t, tc.wantHosts, hosts)
})
}
}

View File

@ -3,6 +3,7 @@ package aghnet
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -15,6 +16,10 @@ import (
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
// DialContextFunc is the semantic alias for dialing functions, such as
// [http.Transport.DialContext].
type DialContextFunc = func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// Variables and functions to substitute in tests. // Variables and functions to substitute in tests.
var ( var (
// aghosRunCommand is the function to run shell commands. // aghosRunCommand is the function to run shell commands.

View File

@ -5,9 +5,9 @@ import (
"testing" "testing"
"testing/fstest" "testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakefs"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -118,7 +118,7 @@ func TestIfaceSetStaticIP(t *testing.T) {
Data: []byte(`nameserver 1.1.1.1`), Data: []byte(`nameserver 1.1.1.1`),
}, },
} }
panicFsys := &aghtest.FS{ panicFsys := &fakefs.FS{
OnOpen: func(name string) (fs.File, error) { panic("not implemented") }, OnOpen: func(name string) (fs.File, error) { panic("not implemented") },
} }

View File

@ -0,0 +1,334 @@
package aghnet
import (
"bytes"
"encoding/json"
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"strings"
"testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata")
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
// package with fsys for tests ran under t.
func substRootDirFS(t testing.TB, fsys fs.FS) {
t.Helper()
prev := rootDirFS
t.Cleanup(func() { rootDirFS = prev })
rootDirFS = fsys
}
// RunCmdFunc is the signature of aghos.RunCommand function.
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
// substShell replaces the the aghos.RunCommand function used throughout the
// package with rc for tests ran under t.
func substShell(t testing.TB, rc RunCmdFunc) {
t.Helper()
prev := aghosRunCommand
t.Cleanup(func() { aghosRunCommand = prev })
aghosRunCommand = rc
}
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
// execution result. It's only needed to simplify testing.
//
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
type mapShell map[string]struct {
err error
out string
code int
}
// theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}}
}
// RunCmd is a RunCmdFunc handled by s.
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
key := strings.Join(append([]string{cmd}, args...), " ")
ret, ok := s[key]
if !ok {
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
}
return ret.code, []byte(ret.out), ret.err
}
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
// throughout the package with f for tests ran under t.
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
t.Helper()
prev := netInterfaceAddrs
t.Cleanup(func() { netInterfaceAddrs = prev })
netInterfaceAddrs = f
}
func TestGatewayIP(t *testing.T) {
const ifaceName = "ifaceName"
const cmd = "ip route show dev " + ifaceName
testCases := []struct {
shell mapShell
want netip.Addr
name string
}{{
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: netip.MustParseAddr("1.2.3.4"),
name: "success_v4",
}, {
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: netip.MustParseAddr("::ffff"),
name: "success_v6",
}, {
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: netip.Addr{},
name: "bad_output",
}, {
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: netip.Addr{},
name: "err_runcmd",
}, {
shell: theOnlyCmd(cmd, 1, "", nil),
want: netip.Addr{},
name: "bad_code",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
assert.Equal(t, tc.want, GatewayIP(ifaceName))
})
}
}
func TestInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb()
require.NoError(t, err)
require.NotEmpty(t, ifaces)
for _, iface := range ifaces {
t.Run(iface.Name, func(t *testing.T) {
require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses {
ifaceName := InterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName)
}
})
}
}
func TestBroadcastFromIPNet(t *testing.T) {
known4 := netip.MustParseAddr("192.168.0.1")
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
testCases := []struct {
pref netip.Prefix
want netip.Addr
name string
}{{
pref: netip.PrefixFrom(known4, 0),
want: fullBroadcast4,
name: "full",
}, {
pref: netip.PrefixFrom(known4, 20),
want: netip.MustParseAddr("192.168.15.255"),
name: "full",
}, {
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
want: known6,
name: "ipv6_no_mask",
}, {
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
want: known4,
name: "ipv4_no_mask",
}, {
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
want: fullBroadcast4,
name: "unspecified",
}, {
pref: netip.Prefix{},
want: netip.Addr{},
name: "invalid",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
})
}
}
func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("tcp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("udp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
assert.NoError(t, err)
})
t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
assert.NoError(t, err)
})
}
func TestCollectAllIfacesAddrs(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
}{{
name: "success",
wantErrMsg: ``,
addrs: []net.Addr{&net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
}, &net.IPNet{
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
wantAddrs: nil,
}, {
name: "empty",
wantErrMsg: ``,
addrs: []net.Addr{},
wantAddrs: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
addrs, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantAddrs, addrs)
})
}
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
}
func TestIsAddrInUse(t *testing.T) {
t.Run("addr_in_use", func(t *testing.T) {
l, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
assert.True(t, IsAddrInUse(err))
})
t.Run("another", func(t *testing.T) {
const anotherErr errors.Error = "not addr in use"
assert.False(t, IsAddrInUse(anotherErr))
})
}
func TestNetInterface_MarshalJSON(t *testing.T) {
const want = `{` +
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
`"flags":"up|multicast",` +
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
`"name":"iface0",` +
`"mtu":1500` +
`}` + "\n"
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
require.True(t, ok)
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
require.True(t, ok)
net4 := netip.PrefixFrom(ip4, 24)
net6 := netip.PrefixFrom(ip6, 8)
iface := &NetInterface{
Addresses: []netip.Addr{ip4, ip6},
Subnets: []netip.Prefix{net4, net6},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,
MTU: 1500,
}
b := &bytes.Buffer{}
err := json.NewEncoder(b).Encode(iface)
require.NoError(t, err)
assert.Equal(t, want, b.String())
}

View File

@ -1,21 +1,11 @@
package aghnet package aghnet_test
import ( import (
"bytes"
"encoding/json"
"fmt"
"io/fs" "io/fs"
"net"
"net/netip"
"os" "os"
"strings"
"testing" "testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -24,315 +14,3 @@ func TestMain(m *testing.M) {
// testdata is the filesystem containing data for testing the package. // testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata") var testdata fs.FS = os.DirFS("./testdata")
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
// package with fsys for tests ran under t.
func substRootDirFS(t testing.TB, fsys fs.FS) {
t.Helper()
prev := rootDirFS
t.Cleanup(func() { rootDirFS = prev })
rootDirFS = fsys
}
// RunCmdFunc is the signature of aghos.RunCommand function.
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
// substShell replaces the the aghos.RunCommand function used throughout the
// package with rc for tests ran under t.
func substShell(t testing.TB, rc RunCmdFunc) {
t.Helper()
prev := aghosRunCommand
t.Cleanup(func() { aghosRunCommand = prev })
aghosRunCommand = rc
}
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
// execution result. It's only needed to simplify testing.
//
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
type mapShell map[string]struct {
err error
out string
code int
}
// theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}}
}
// RunCmd is a RunCmdFunc handled by s.
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
key := strings.Join(append([]string{cmd}, args...), " ")
ret, ok := s[key]
if !ok {
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
}
return ret.code, []byte(ret.out), ret.err
}
// ifaceAddrsFunc is the signature of net.InterfaceAddrs function.
type ifaceAddrsFunc func() (ifaces []net.Addr, err error)
// substNetInterfaceAddrs replaces the the net.InterfaceAddrs function used
// throughout the package with f for tests ran under t.
func substNetInterfaceAddrs(t *testing.T, f ifaceAddrsFunc) {
t.Helper()
prev := netInterfaceAddrs
t.Cleanup(func() { netInterfaceAddrs = prev })
netInterfaceAddrs = f
}
func TestGatewayIP(t *testing.T) {
const ifaceName = "ifaceName"
const cmd = "ip route show dev " + ifaceName
testCases := []struct {
shell mapShell
want netip.Addr
name string
}{{
shell: theOnlyCmd(cmd, 0, `default via 1.2.3.4 onlink`, nil),
want: netip.MustParseAddr("1.2.3.4"),
name: "success_v4",
}, {
shell: theOnlyCmd(cmd, 0, `default via ::ffff onlink`, nil),
want: netip.MustParseAddr("::ffff"),
name: "success_v6",
}, {
shell: theOnlyCmd(cmd, 0, `non-default via 1.2.3.4 onlink`, nil),
want: netip.Addr{},
name: "bad_output",
}, {
shell: theOnlyCmd(cmd, 0, "", errors.Error("can't run command")),
want: netip.Addr{},
name: "err_runcmd",
}, {
shell: theOnlyCmd(cmd, 1, "", nil),
want: netip.Addr{},
name: "bad_code",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substShell(t, tc.shell.RunCmd)
assert.Equal(t, tc.want, GatewayIP(ifaceName))
})
}
}
func TestInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb()
require.NoError(t, err)
require.NotEmpty(t, ifaces)
for _, iface := range ifaces {
t.Run(iface.Name, func(t *testing.T) {
require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses {
ifaceName := InterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName)
}
})
}
}
func TestBroadcastFromIPNet(t *testing.T) {
known4 := netip.MustParseAddr("192.168.0.1")
fullBroadcast4 := netip.MustParseAddr("255.255.255.255")
known6 := netip.MustParseAddr("102:304:506:708:90a:b0c:d0e:f10")
testCases := []struct {
pref netip.Prefix
want netip.Addr
name string
}{{
pref: netip.PrefixFrom(known4, 0),
want: fullBroadcast4,
name: "full",
}, {
pref: netip.PrefixFrom(known4, 20),
want: netip.MustParseAddr("192.168.15.255"),
name: "full",
}, {
pref: netip.PrefixFrom(known6, netutil.IPv6BitLen),
want: known6,
name: "ipv6_no_mask",
}, {
pref: netip.PrefixFrom(known4, netutil.IPv4BitLen),
want: known4,
name: "ipv4_no_mask",
}, {
pref: netip.PrefixFrom(netip.IPv4Unspecified(), 0),
want: fullBroadcast4,
name: "unspecified",
}, {
pref: netip.Prefix{},
want: netip.Addr{},
name: "invalid",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.want, BroadcastFromPref(tc.pref))
})
}
}
func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)
t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("tcp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("udp_bound", func(t *testing.T) {
conn, err := net.ListenPacket("udp", laddr.String())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, conn.Close)
ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort()
require.Equal(t, laddr.Addr(), ipp.Addr())
require.NotZero(t, ipp.Port())
err = CheckPort("udp", ipp)
target := &net.OpError{}
require.ErrorAs(t, err, &target)
assert.Equal(t, "listen", target.Op)
})
t.Run("bad_network", func(t *testing.T) {
err := CheckPort("bad_network", netip.AddrPortFrom(netip.Addr{}, 0))
assert.NoError(t, err)
})
t.Run("can_bind", func(t *testing.T) {
err := CheckPort("udp", netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
assert.NoError(t, err)
})
}
func TestCollectAllIfacesAddrs(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
addrs []net.Addr
wantAddrs []string
}{{
name: "success",
wantErrMsg: ``,
addrs: []net.Addr{&net.IPNet{
IP: net.IP{1, 2, 3, 4},
Mask: net.CIDRMask(24, netutil.IPv4BitLen),
}, &net.IPNet{
IP: net.IP{4, 3, 2, 1},
Mask: net.CIDRMask(16, netutil.IPv4BitLen),
}},
wantAddrs: []string{"1.2.3.4", "4.3.2.1"},
}, {
name: "not_cidr",
wantErrMsg: `parsing cidr: invalid CIDR address: 1.2.3.4`,
addrs: []net.Addr{&net.IPAddr{
IP: net.IP{1, 2, 3, 4},
}},
wantAddrs: nil,
}, {
name: "empty",
wantErrMsg: ``,
addrs: []net.Addr{},
wantAddrs: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return tc.addrs, nil })
addrs, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.wantAddrs, addrs)
})
}
t.Run("internal_error", func(t *testing.T) {
const errAddrs errors.Error = "can't get addresses"
const wantErrMsg string = `getting interfaces addresses: ` + string(errAddrs)
substNetInterfaceAddrs(t, func() ([]net.Addr, error) { return nil, errAddrs })
_, err := CollectAllIfacesAddrs()
testutil.AssertErrorMsg(t, wantErrMsg, err)
})
}
func TestIsAddrInUse(t *testing.T) {
t.Run("addr_in_use", func(t *testing.T) {
l, err := net.Listen("tcp", "0.0.0.0:0")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, l.Close)
_, err = net.Listen(l.Addr().Network(), l.Addr().String())
assert.True(t, IsAddrInUse(err))
})
t.Run("another", func(t *testing.T) {
const anotherErr errors.Error = "not addr in use"
assert.False(t, IsAddrInUse(anotherErr))
})
}
func TestNetInterface_MarshalJSON(t *testing.T) {
const want = `{` +
`"hardware_address":"aa:bb:cc:dd:ee:ff",` +
`"flags":"up|multicast",` +
`"ip_addresses":["1.2.3.4","aaaa::1"],` +
`"name":"iface0",` +
`"mtu":1500` +
`}` + "\n"
ip4, ok := netip.AddrFromSlice([]byte{1, 2, 3, 4})
require.True(t, ok)
ip6, ok := netip.AddrFromSlice([]byte{0xAA, 0xAA, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
require.True(t, ok)
net4 := netip.PrefixFrom(ip4, 24)
net6 := netip.PrefixFrom(ip6, 8)
iface := &NetInterface{
Addresses: []netip.Addr{ip4, ip6},
Subnets: []netip.Prefix{net4, net6},
Name: "iface0",
HardwareAddr: net.HardwareAddr{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF},
Flags: net.FlagUp | net.FlagMulticast,
MTU: 1500,
}
b := &bytes.Buffer{}
err := json.NewEncoder(b).Encode(iface)
require.NoError(t, err)
assert.Equal(t, want, b.String())
}

View File

@ -2,7 +2,9 @@
package aghtest package aghtest
import ( import (
"crypto/sha256"
"io" "io"
"net"
"testing" "testing"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
@ -34,3 +36,10 @@ func ReplaceLogLevel(t testing.TB, l log.Level) {
t.Cleanup(func() { log.SetLevel(prev) }) t.Cleanup(func() { log.SetLevel(prev) })
log.SetLevel(l) log.SetLevel(l)
} }
// HostToIPs is a helper that generates one IPv4 and one IPv6 address from host.
func HostToIPs(host string) (ipv4, ipv6 net.IP) {
hash := sha256.Sum256([]byte(host))
return net.IP(hash[:4]), net.IP(hash[4:20])
}

View File

@ -2,8 +2,7 @@ package aghtest
import ( import (
"context" "context"
"io" "net"
"io/fs"
"net/netip" "net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
@ -19,67 +18,6 @@ import (
// //
// Keep entities in this file in alphabetic order. // Keep entities in this file in alphabetic order.
// Standard Library
// Package fs
// FS is a fake [fs.FS] implementation for tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// type check
var _ fs.FS = (*FS)(nil)
// Open implements the [fs.FS] interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.GlobFS = (*GlobFS)(nil)
// GlobFS is a fake [fs.GlobFS] implementation for tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the [fs.GlobFS] interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}
// type check
var _ fs.StatFS = (*StatFS)(nil)
// StatFS is a fake [fs.StatFS] implementation for tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the [fs.StatFS] interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// Package io
// Writer is a fake [io.Writer] implementation for tests.
type Writer struct {
OnWrite func(b []byte) (n int, err error)
}
var _ io.Writer = (*Writer)(nil)
// Write implements the [io.Writer] interface for *Writer.
func (w *Writer) Write(b []byte) (n int, err error) {
return w.OnWrite(b)
}
// Module adguard-home // Module adguard-home
// Package aghos // Package aghos
@ -177,6 +115,18 @@ func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.I
p.OnUpdateAddress(ip, host, info) p.OnUpdateAddress(ip, host, info)
} }
// Package filtering
// Resolver is a fake [filtering.Resolver] implementation for tests.
type Resolver struct {
OnLookupIP func(ctx context.Context, network, host string) (ips []net.IP, err error)
}
// LookupIP implements the [filtering.Resolver] interface for *Resolver.
func (r *Resolver) LookupIP(ctx context.Context, network, host string) (ips []net.IP, err error) {
return r.OnLookupIP(ctx, network, host)
}
// Package rdns // Package rdns
// Exchanger is a fake [rdns.Exchanger] implementation for tests. // Exchanger is a fake [rdns.Exchanger] implementation for tests.

View File

@ -1,3 +1,11 @@
package aghtest_test package aghtest_test
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
)
// Put interface checks that cause import cycles here. // Put interface checks that cause import cycles here.
// type check
var _ filtering.Resolver = (*aghtest.Resolver)(nil)

View File

@ -1,57 +0,0 @@
package aghtest
import (
"context"
"crypto/sha256"
"net"
"sync"
)
// TestResolver is a Resolver for tests.
type TestResolver struct {
counter int
counterLock sync.Mutex
}
// HostToIPs generates IPv4 and IPv6 from host.
func (r *TestResolver) HostToIPs(host string) (ipv4, ipv6 net.IP) {
hash := sha256.Sum256([]byte(host))
return net.IP(hash[:4]), net.IP(hash[4:20])
}
// LookupIP implements Resolver interface for *testResolver. It returns the
// slice of net.IP with IPv4 and IPv6 instances.
func (r *TestResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) {
ipv4, ipv6 := r.HostToIPs(host)
addrs := []net.IP{ipv4, ipv6}
r.counterLock.Lock()
defer r.counterLock.Unlock()
r.counter++
return addrs, nil
}
// LookupHost implements Resolver interface for *testResolver. It returns the
// slice of IPv4 and IPv6 instances converted to strings.
func (r *TestResolver) LookupHost(host string) (addrs []string, err error) {
ipv4, ipv6 := r.HostToIPs(host)
r.counterLock.Lock()
defer r.counterLock.Unlock()
r.counter++
return []string{
ipv4.String(),
ipv6.String(),
}, nil
}
// Counter returns the number of requests handled.
func (r *TestResolver) Counter() int {
r.counterLock.Lock()
defer r.counterLock.Unlock()
return r.counter
}

View File

@ -6,6 +6,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/AdGuardHome/internal/rdns"
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -39,7 +40,7 @@ func (EmptyAddrProc) Close() (_ error) { return nil }
type DefaultAddrProcConfig struct { type DefaultAddrProcConfig struct {
// DialContext is used to create TCP connections to WHOIS servers. // DialContext is used to create TCP connections to WHOIS servers.
// DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true. // DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true.
DialContext whois.DialContextFunc DialContext aghnet.DialContextFunc
// Exchanger is used to perform rDNS queries. Exchanger must not be nil if // Exchanger is used to perform rDNS queries. Exchanger must not be nil if
// [DefaultAddrProcConfig.UseRDNS] is true. // [DefaultAddrProcConfig.UseRDNS] is true.
@ -161,7 +162,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) {
// newWHOIS returns a whois.Interface instance using the given function for // newWHOIS returns a whois.Interface instance using the given function for
// dialing. // dialing.
func newWHOIS(dialFunc whois.DialContextFunc) (w whois.Interface) { func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) {
// TODO(s.chzhen): Consider making configurable. // TODO(s.chzhen): Consider making configurable.
const ( const (
// defaultTimeout is the timeout for WHOIS requests. // defaultTimeout is the timeout for WHOIS requests.

View File

@ -10,7 +10,7 @@ import (
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
// DialContext is a [whois.DialContextFunc] that uses s to resolve hostnames. // DialContext is an [aghnet.DialContextFunc] that uses s to resolve hostnames.
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
log.Debug("dnsforward: dialing %q for network %q", addr, network) log.Debug("dnsforward: dialing %q for network %q", addr, network)

View File

@ -1,6 +1,7 @@
package dnsforward package dnsforward
import ( import (
"context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
@ -467,7 +468,14 @@ func TestServerRace(t *testing.T) {
} }
func TestSafeSearch(t *testing.T) { func TestSafeSearch(t *testing.T) {
resolver := &aghtest.TestResolver{} resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4, ip6}, nil
},
}
safeSearchConf := filtering.SafeSearchConfig{ safeSearchConf := filtering.SafeSearchConfig{
Enabled: true, Enabled: true,
Google: true, Google: true,
@ -506,7 +514,7 @@ func TestSafeSearch(t *testing.T) {
client := &dns.Client{} client := &dns.Client{}
yandexIP := net.IP{213, 180, 193, 56} yandexIP := net.IP{213, 180, 193, 56}
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com") googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
testCases := []struct { testCases := []struct {
host string host string
@ -954,7 +962,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
Upstream: aghtest.NewBlockUpstream(hostname, true), Upstream: aghtest.NewBlockUpstream(hostname, true),
}) })
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname) ans4, _ := aghtest.HostToIPs(hostname)
filterConf := &filtering.Config{ filterConf := &filtering.Config{
SafeBrowsingEnabled: true, SafeBrowsingEnabled: true,

View File

@ -6,10 +6,10 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakeio"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -159,7 +159,7 @@ func TestParser_Parse(t *testing.T) {
func TestParser_Parse_writeError(t *testing.T) { func TestParser_Parse_writeError(t *testing.T) {
t.Parallel() t.Parallel()
dst := &aghtest.Writer{ dst := &fakeio.Writer{
OnWrite: func(b []byte) (n int, err error) { OnWrite: func(b []byte) (n int, err error) {
return 1, errors.Error("test error") return 1, errors.Error("test error")
}, },

View File

@ -89,37 +89,34 @@ func TestSafeSearchCacheGoogle(t *testing.T) {
assert.False(t, res.IsFiltered) assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules) assert.Empty(t, res.Rules)
resolver := &aghtest.TestResolver{} resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4, ip6}, nil
},
}
ss = newForTest(t, defaultSafeSearchConf) ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver ss.resolver = resolver
// Lookup for safesearch domain. // Lookup for safesearch domain.
rewrite := ss.searchHost(domain, testQType) rewrite := ss.searchHost(domain, testQType)
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME) wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME)
require.NoError(t, err)
var foundIP net.IP
for _, ip := range ips {
if ip.To4() != nil {
foundIP = ip
break
}
}
res, err = ss.CheckHost(domain, testQType) res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, res.Rules, 1) require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(foundIP)) assert.True(t, res.Rules[0].IP.Equal(wantIP))
// Check cache. // Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType) cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound) require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1) require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP)) assert.True(t, cachedValue.Rules[0].IP.Equal(wantIP))
} }
const googleHost = "www.google.com" const googleHost = "www.google.com"

View File

@ -92,8 +92,15 @@ func TestDefault_CheckHost_yandexAAAA(t *testing.T) {
} }
func TestDefault_CheckHost_google(t *testing.T) { func TestDefault_CheckHost_google(t *testing.T) {
resolver := &aghtest.TestResolver{} resolver := &aghtest.Resolver{
ip, _ := resolver.HostToIPs("forcesafesearch.google.com") OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4, ip6}, nil
},
}
wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
conf := testConf conf := testConf
conf.CustomResolver = resolver conf.CustomResolver = resolver
@ -119,7 +126,7 @@ func TestDefault_CheckHost_google(t *testing.T) {
require.Len(t, res.Rules, 1) require.Len(t, res.Rules, 1)
assert.Equal(t, ip, res.Rules[0].IP) assert.Equal(t, wantIP, res.Rules[0].IP)
assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID) assert.EqualValues(t, filtering.SafeSearchListID, res.Rules[0].FilterListID)
}) })
} }

View File

@ -12,11 +12,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/fakefs"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -90,7 +90,7 @@ func newTestServer(
c := &websvc.Config{ c := &websvc.Config{
ConfigManager: confMgr, ConfigManager: confMgr,
Frontend: &aghtest.FS{ Frontend: &fakefs.FS{
OnOpen: func(_ string) (_ fs.File, _ error) { return nil, fs.ErrNotExist }, OnOpen: func(_ string) (_ fs.File, _ error) { return nil, fs.ErrNotExist },
}, },
TLS: nil, TLS: nil,

View File

@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@ -49,7 +50,7 @@ func (Empty) Process(_ context.Context, _ netip.Addr) (info *Info, changed bool)
// Config is the configuration structure for Default. // Config is the configuration structure for Default.
type Config struct { type Config struct {
// DialContext is used to create TCP connections to WHOIS servers. // DialContext is used to create TCP connections to WHOIS servers.
DialContext DialContextFunc DialContext aghnet.DialContextFunc
// ServerAddr is the address of the WHOIS server. // ServerAddr is the address of the WHOIS server.
ServerAddr string ServerAddr string
@ -77,13 +78,6 @@ type Config struct {
Port uint16 Port uint16
} }
// DialContextFunc is the semantic alias for dialing functions, such as
// [http.Transport.DialContext].
//
// TODO(a.garipov): Move to aghnet once it stops importing aghtest, because
// otherwise there is an import cycle.
type DialContextFunc = func(ctx context.Context, network, addr string) (conn net.Conn, err error)
// Default is the default WHOIS information processor. // Default is the default WHOIS information processor.
type Default struct { type Default struct {
// cache is the cache containing IP addresses of clients. An active IP // cache is the cache containing IP addresses of clients. An active IP
@ -93,7 +87,7 @@ type Default struct {
cache gcache.Cache cache gcache.Cache
// dialContext is used to create TCP connections to WHOIS servers. // dialContext is used to create TCP connections to WHOIS servers.
dialContext DialContextFunc dialContext aghnet.DialContextFunc
// serverAddr is the address of the WHOIS server. // serverAddr is the address of the WHOIS server.
serverAddr string serverAddr string