diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index 730f1b4e..f4275c6c 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -11,7 +11,6 @@ import ( "strings" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" @@ -20,7 +19,7 @@ import ( ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } // testdata is the filesystem containing data for testing the package. @@ -196,10 +195,7 @@ func TestCheckPort(t *testing.T) { require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, l.Close) - addr := l.Addr() - require.IsType(t, new(net.TCPAddr), addr) - - ipp := addr.(*net.TCPAddr).AddrPort() + ipp := testutil.RequireTypeAssert[*net.TCPAddr](t, l.Addr()).AddrPort() require.Equal(t, laddr.Addr(), ipp.Addr()) require.NotZero(t, ipp.Port()) @@ -215,10 +211,7 @@ func TestCheckPort(t *testing.T) { require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, conn.Close) - addr := conn.LocalAddr() - require.IsType(t, new(net.UDPAddr), addr) - - ipp := addr.(*net.UDPAddr).AddrPort() + ipp := testutil.RequireTypeAssert[*net.UDPAddr](t, conn.LocalAddr()).AddrPort() require.Equal(t, laddr.Addr(), ipp.Addr()) require.NotZero(t, ipp.Port()) diff --git a/internal/aghnet/systemresolvers_others_test.go b/internal/aghnet/systemresolvers_others_test.go index a9974e0a..8bc506a8 100644 --- a/internal/aghnet/systemresolvers_others_test.go +++ b/internal/aghnet/systemresolvers_others_test.go @@ -6,6 +6,7 @@ import ( "context" "testing" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -17,9 +18,8 @@ func createTestSystemResolversImpl( t.Helper() sr := createTestSystemResolvers(t, hostGenFunc) - require.IsType(t, (*systemResolvers)(nil), sr) - return sr.(*systemResolvers) + return testutil.RequireTypeAssert[*systemResolvers](t, sr) } func TestSystemResolvers_Refresh(t *testing.T) { diff --git a/internal/aghos/aghos_test.go b/internal/aghos/aghos_test.go index 684f646e..3916d98e 100644 --- a/internal/aghos/aghos_test.go +++ b/internal/aghos/aghos_test.go @@ -3,9 +3,9 @@ package aghos_test import ( "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/golibs/testutil" ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } diff --git a/internal/aghtest/aghtest.go b/internal/aghtest/aghtest.go index 31274718..850446e0 100644 --- a/internal/aghtest/aghtest.go +++ b/internal/aghtest/aghtest.go @@ -3,23 +3,11 @@ package aghtest import ( "io" - "os" "testing" "github.com/AdguardTeam/golibs/log" ) -// DiscardLogOutput runs tests with discarded logger output. -// -// TODO(a.garipov): Replace with testutil. -func DiscardLogOutput(m *testing.M) { - // TODO(e.burkov): Refactor code and tests to not use the global mutable - // logger. - log.SetOutput(io.Discard) - - os.Exit(m.Run()) -} - // ReplaceLogWriter moves logger output to w and uses Cleanup method of t to // revert changes. func ReplaceLogWriter(t testing.TB, w io.Writer) { diff --git a/internal/aghtls/aghtls_test.go b/internal/aghtls/aghtls_test.go index 923ff063..7e5b99f9 100644 --- a/internal/aghtls/aghtls_test.go +++ b/internal/aghtls/aghtls_test.go @@ -4,14 +4,13 @@ import ( "crypto/tls" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } func TestParseCiphers(t *testing.T) { diff --git a/internal/dhcpd/dhcpd_unix_test.go b/internal/dhcpd/dhcpd_unix_test.go index 38305076..0bf516cb 100644 --- a/internal/dhcpd/dhcpd_unix_test.go +++ b/internal/dhcpd/dhcpd_unix_test.go @@ -9,7 +9,6 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,7 +16,7 @@ import ( ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } func testNotify(flags uint32) { diff --git a/internal/dhcpd/v4_unix_test.go b/internal/dhcpd/v4_unix_test.go index c8e4dd1f..411e36d9 100644 --- a/internal/dhcpd/v4_unix_test.go +++ b/internal/dhcpd/v4_unix_test.go @@ -482,7 +482,6 @@ func TestV4Server_updateOptions(t *testing.T) { s, err := v4Create(conf) require.NoError(t, err) - require.IsType(t, (*v4Server)(nil), s) t.Run(tc.name, func(t *testing.T) { diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index 915455d2..acaf5f55 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -5,6 +5,7 @@ import ( "net/netip" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" @@ -449,12 +450,27 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { } func TestServer_ProcessRestrictLocal(t *testing.T) { - ups := &aghtest.Upstream{ - Reverse: map[string][]string{ - "251.252.253.254.in-addr.arpa.": {"host1.example.net."}, - "1.1.168.192.in-addr.arpa.": {"some.local-client."}, + const ( + extPTRQuestion = "251.252.253.254.in-addr.arpa." + extPTRAnswer = "host1.example.net." + intPTRQuestion = "1.1.168.192.in-addr.arpa." + intPTRAnswer = "some.local-client." + ) + + ups := &aghtest.UpstreamMock{ + OnAddress: func() (addr string) { return "upstream.example" }, + OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + resp = aghalg.Coalesce( + aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, extPTRQuestion, extPTRAnswer), + aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, intPTRQuestion, intPTRAnswer), + new(dns.Msg).SetRcode(req, dns.RcodeNameError), + ) + + return resp, nil }, + OnClose: func() (err error) { return nil }, } + s := createTestServer(t, &filtering.Config{}, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, @@ -524,14 +540,26 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { const locDomain = "some.local." const reqAddr = "1.1.168.192.in-addr.arpa." - s := createTestServer(t, &filtering.Config{}, ServerConfig{ - UDPListenAddrs: []*net.UDPAddr{{}}, - TCPListenAddrs: []*net.TCPAddr{{}}, - }, &aghtest.Upstream{ - Reverse: map[string][]string{ - reqAddr: {locDomain}, + s := createTestServer( + t, + &filtering.Config{}, + ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, }, - }) + &aghtest.UpstreamMock{ + OnAddress: func() (addr string) { return "upstream.example" }, + OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + resp = aghalg.Coalesce( + aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, reqAddr, locDomain), + new(dns.Msg).SetRcode(req, dns.RcodeNameError), + ) + + return resp, nil + }, + OnClose: func() (err error) { return nil }, + }, + ) var proxyCtx *proxy.DNSContext var dnsCtx *dnsContext diff --git a/internal/filtering/filtering_test.go b/internal/filtering/filtering_test.go index 4fc9182d..05869765 100644 --- a/internal/filtering/filtering_test.go +++ b/internal/filtering/filtering_test.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -18,7 +19,7 @@ import ( ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } const ( diff --git a/internal/home/home_test.go b/internal/home/home_test.go index 1a611588..2ce1d76d 100644 --- a/internal/home/home_test.go +++ b/internal/home/home_test.go @@ -3,10 +3,10 @@ package home import ( "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/golibs/testutil" ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) initCmdLineOpts() } diff --git a/internal/next/dnssvc/dnssvc_test.go b/internal/next/dnssvc/dnssvc_test.go index 814e3bad..5e53ba49 100644 --- a/internal/next/dnssvc/dnssvc_test.go +++ b/internal/next/dnssvc/dnssvc_test.go @@ -10,13 +10,14 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } // testTimeout is the common timeout for tests. diff --git a/internal/next/websvc/websvc_test.go b/internal/next/websvc/websvc_test.go index 39ab3038..56acb862 100644 --- a/internal/next/websvc/websvc_test.go +++ b/internal/next/websvc/websvc_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc" @@ -21,7 +20,7 @@ import ( ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } // testTimeout is the common timeout for tests. diff --git a/internal/querylog/qlog_test.go b/internal/querylog/qlog_test.go index f33401cc..cc470438 100644 --- a/internal/querylog/qlog_test.go +++ b/internal/querylog/qlog_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxyutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -18,7 +18,7 @@ import ( ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } // TestQueryLog tests adding and loading (with filtering) entries from disk and diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index 5d86024b..bb2cc0d8 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -10,7 +10,6 @@ import ( "sync/atomic" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" @@ -18,7 +17,7 @@ import ( ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } // constUnitID is the UnitIDGenFunc which always return 0. diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index 6eed74fd..dbf0e069 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -11,7 +11,6 @@ import ( "testing" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" @@ -21,7 +20,7 @@ import ( // TODO(a.garipov): Rewrite these tests. func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } func startHTTPServer(data string) (l net.Listener, portStr string) {