diff --git a/internal/agherr/agherr_test.go b/internal/agherr/agherr_test.go index 123c45ef..3ac5aeab 100644 --- a/internal/agherr/agherr_test.go +++ b/internal/agherr/agherr_test.go @@ -5,14 +5,9 @@ import ( "fmt" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/stretchr/testify/assert" ) -func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) -} - func TestError_Error(t *testing.T) { testCases := []struct { name string diff --git a/internal/testutil/testutil.go b/internal/aghtest/aghtest.go similarity index 93% rename from internal/testutil/testutil.go rename to internal/aghtest/aghtest.go index 69187969..4c453055 100644 --- a/internal/testutil/testutil.go +++ b/internal/aghtest/aghtest.go @@ -1,5 +1,5 @@ -// Package testutil contains utilities for testing. -package testutil +// Package aghtest contains utilities for testing. +package aghtest import ( "io" diff --git a/internal/aghtest/resolver.go b/internal/aghtest/resolver.go new file mode 100644 index 00000000..75fb6ce0 --- /dev/null +++ b/internal/aghtest/resolver.go @@ -0,0 +1,63 @@ +package aghtest + +import ( + "context" + "crypto/sha256" + "net" + "sync" +) + +// TestResolver is a Resolver for tests. +type TestResolver struct { + counter int + counterLock sync.Mutex +} + +// HostToIPs generates IPv4 and IPv6 from host. +// +// TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15. +func (r *TestResolver) HostToIPs(host string) (ipv4, ipv6 net.IP) { + hash := sha256.Sum256([]byte(host)) + + return net.IP(hash[:4]), net.IP(hash[4:20]) +} + +// LookupIPAddr implements Resolver interface for *testResolver. It returns the +// slice of net.IPAddr with IPv4 and IPv6 instances. +func (r *TestResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) { + ipv4, ipv6 := r.HostToIPs(host) + addrs := []net.IPAddr{{ + IP: ipv4, + }, { + IP: ipv6, + }} + + r.counterLock.Lock() + defer r.counterLock.Unlock() + r.counter++ + + return addrs, nil +} + +// LookupHost implements Resolver interface for *testResolver. It returns the +// slice of IPv4 and IPv6 instances converted to strings. +func (r *TestResolver) LookupHost(host string) (addrs []string, err error) { + ipv4, ipv6 := r.HostToIPs(host) + + r.counterLock.Lock() + defer r.counterLock.Unlock() + r.counter++ + + return []string{ + ipv4.String(), + ipv6.String(), + }, nil +} + +// Counter returns the number of requests handled. +func (r *TestResolver) Counter() int { + r.counterLock.Lock() + defer r.counterLock.Unlock() + + return r.counter +} diff --git a/internal/aghtest/upstream.go b/internal/aghtest/upstream.go new file mode 100644 index 00000000..78622771 --- /dev/null +++ b/internal/aghtest/upstream.go @@ -0,0 +1,175 @@ +package aghtest + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net" + "strings" + "sync" + + "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/miekg/dns" +) + +// TestUpstream is a mock of real upstream. +type TestUpstream struct { + // Addr is the address for Address method. + Addr string + // CName is a map of hostname to canonical name. + CName map[string]string + // IPv4 is a map of hostname to IPv4. + IPv4 map[string][]net.IP + // IPv6 is a map of hostname to IPv6. + IPv6 map[string][]net.IP + // Reverse is a map of address to domain name. + Reverse map[string][]string +} + +// Exchange implements upstream.Upstream interface for *TestUpstream. +func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { + resp = &dns.Msg{} + resp.SetReply(m) + + if len(m.Question) == 0 { + return nil, fmt.Errorf("question should not be empty") + } + name := m.Question[0].Name + + if cname, ok := u.CName[name]; ok { + resp.Answer = append(resp.Answer, &dns.CNAME{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: dns.TypeCNAME, + }, + Target: cname, + }) + } + + var hasRec bool + var rrType uint16 + var ips []net.IP + switch m.Question[0].Qtype { + case dns.TypeA: + rrType = dns.TypeA + if ipv4addr, ok := u.IPv4[name]; ok { + hasRec = true + ips = ipv4addr + } + case dns.TypeAAAA: + rrType = dns.TypeAAAA + if ipv6addr, ok := u.IPv6[name]; ok { + hasRec = true + ips = ipv6addr + } + case dns.TypePTR: + names, ok := u.Reverse[name] + if !ok { + break + } + + for _, n := range names { + resp.Answer = append(resp.Answer, &dns.PTR{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: rrType, + }, + Ptr: n, + }) + } + } + + for _, ip := range ips { + resp.Answer = append(resp.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: name, + Rrtype: rrType, + }, + A: ip, + }) + } + + if len(resp.Answer) == 0 { + if hasRec { + // Set no error RCode if there are some records for + // given Qname but we didn't apply them. + resp.SetRcode(m, dns.RcodeSuccess) + + return resp, nil + } + // Set NXDomain RCode otherwise. + resp.SetRcode(m, dns.RcodeNameError) + } + + return resp, nil +} + +// Address implements upstream.Upstream interface for *TestUpstream. +func (u *TestUpstream) Address() string { + return u.Addr +} + +// TestBlockUpstream implements upstream.Upstream interface for replacing real +// upstream in tests. +type TestBlockUpstream struct { + Hostname string + Block bool + requestsCount int + lock sync.RWMutex +} + +// Exchange returns a message unique for TestBlockUpstream's Hostname-Block +// pair. +func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) { + u.lock.Lock() + defer u.lock.Unlock() + u.requestsCount++ + + hash := sha256.Sum256([]byte(u.Hostname)) + hashToReturn := hex.EncodeToString(hash[:]) + if !u.Block { + hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28) + } + + m := &dns.Msg{} + m.Answer = []dns.RR{ + &dns.TXT{ + Hdr: dns.RR_Header{ + Name: r.Question[0].Name, + }, + Txt: []string{ + hashToReturn, + }, + }, + } + + return m, nil +} + +// Address always returns an empty string. +func (u *TestBlockUpstream) Address() string { + return "" +} + +// RequestsCount returns the number of handled requests. It's safe for +// concurrent use. +func (u *TestBlockUpstream) RequestsCount() int { + u.lock.Lock() + defer u.lock.Unlock() + + return u.requestsCount +} + +// TestErrUpstream implements upstream.Upstream interface for replacing real +// upstream in tests. +type TestErrUpstream struct{} + +// Exchange always returns nil Msg and non-nil error. +func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) { + return nil, agherr.Error("bad") +} + +// Address always returns an empty string. +func (u *TestErrUpstream) Address() string { + return "" +} diff --git a/internal/dhcpd/dhcpd_test.go b/internal/dhcpd/dhcpd_test.go index f5c34154..1aa1b9a6 100644 --- a/internal/dhcpd/dhcpd_test.go +++ b/internal/dhcpd/dhcpd_test.go @@ -9,12 +9,12 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/stretchr/testify/assert" ) func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) + aghtest.DiscardLogOutput(m) } func testNotify(flags uint32) { diff --git a/internal/dhcpd/nclient4/client_test.go b/internal/dhcpd/nclient4/client_test.go index 99f99640..9ad376fe 100644 --- a/internal/dhcpd/nclient4/client_test.go +++ b/internal/dhcpd/nclient4/client_test.go @@ -17,14 +17,14 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/hugelgupf/socketpair" "github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4/server4" ) func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) + aghtest.DiscardLogOutput(m) } type handler struct { diff --git a/internal/dnsfilter/dnsfilter.go b/internal/dnsfilter/dnsfilter.go index 6ade9701..34ab408d 100644 --- a/internal/dnsfilter/dnsfilter.go +++ b/internal/dnsfilter/dnsfilter.go @@ -43,6 +43,12 @@ type RequestFilteringSettings struct { ServicesRules []ServiceEntry } +// Resolver is the interface for net.Resolver to simplify testing. +type Resolver interface { + // TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15. + LookupIPAddr(ctx context.Context, host string) (ips []net.IPAddr, err error) +} + // Config allows you to configure DNS filtering with New() or just change variables directly. type Config struct { ParentalEnabled bool `yaml:"parental_enabled"` @@ -69,6 +75,9 @@ type Config struct { // Register an HTTP handler HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"` + + // CustomResolver is the resolver used by DNSFilter. + CustomResolver Resolver } // LookupStats store stats collected during safebrowsing or parental checks @@ -92,12 +101,6 @@ type filtersInitializerParams struct { blockFilters []Filter } -// Resolver is the interface for net.Resolver to simplify testing. -type Resolver interface { - // TODO(e.burkov): Replace with LookupIP after upgrading go to v1.15. - LookupIPAddr(ctx context.Context, host string) (ips []net.IPAddr, err error) -} - // DNSFilter matches hostnames and DNS requests against filtering rules. type DNSFilter struct { rulesStorage *filterlist.RuleStorage @@ -796,6 +799,7 @@ func InitModule() { // New creates properly initialized DNS Filter that is ready to be used. func New(c *Config, blockFilters []Filter) *DNSFilter { + var resolver Resolver = net.DefaultResolver if c != nil { cacheConf := cache.Config{ EnableLRU: true, @@ -815,10 +819,14 @@ func New(c *Config, blockFilters []Filter) *DNSFilter { cacheConf.MaxSize = c.ParentalCacheSize gctx.parentalCache = cache.New(cacheConf) } + + if c.CustomResolver != nil { + resolver = c.CustomResolver + } } d := &DNSFilter{ - resolver: net.DefaultResolver, + resolver: resolver, } err := d.initSecurityServices() diff --git a/internal/dnsfilter/dnsfilter_test.go b/internal/dnsfilter/dnsfilter_test.go index 6ab4fbc4..25985dae 100644 --- a/internal/dnsfilter/dnsfilter_test.go +++ b/internal/dnsfilter/dnsfilter_test.go @@ -3,12 +3,11 @@ package dnsfilter import ( "bytes" "context" - "crypto/sha256" "fmt" "net" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" @@ -17,7 +16,7 @@ import ( ) func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) + aghtest.DiscardLogOutput(m) } var setts RequestFilteringSettings @@ -37,7 +36,9 @@ func purgeCaches() { } func newForTest(c *Config, filters []Filter) *DNSFilter { - setts = RequestFilteringSettings{} + setts = RequestFilteringSettings{ + FilteringEnabled: true, + } setts.FilteringEnabled = true if c != nil { c.SafeBrowsingCacheSize = 10000 @@ -149,16 +150,16 @@ func TestEtcHostsMatching(t *testing.T) { func TestSafeBrowsing(t *testing.T) { logOutput := &bytes.Buffer{} - testutil.ReplaceLogWriter(t, logOutput) - testutil.ReplaceLogLevel(t, log.DEBUG) + aghtest.ReplaceLogWriter(t, logOutput) + aghtest.ReplaceLogLevel(t, log.DEBUG) d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) matching := "wmconvirus.narod.ru" - d.safeBrowsingUpstream = &testSbUpstream{ - hostname: matching, - block: true, - } + d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ + Hostname: matching, + Block: true, + }) d.checkMatch(t, matching) assert.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching) @@ -178,10 +179,10 @@ func TestParallelSB(t *testing.T) { d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) matching := "wmconvirus.narod.ru" - d.safeBrowsingUpstream = &testSbUpstream{ - hostname: matching, - block: true, - } + d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ + Hostname: matching, + Block: true, + }) t.Run("group", func(t *testing.T) { for i := 0; i < 100; i++ { @@ -228,26 +229,12 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { } } -// testResolver is a Resolver for tests. -type testResolver struct{} - -// LookupIP implements Resolver interface for *testResolver. -func (r *testResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) { - hash := sha256.Sum256([]byte(host)) - addrs := []net.IPAddr{{ - IP: net.IP(hash[:4]), - Zone: "somezone", - }, { - IP: net.IP(hash[4:20]), - Zone: "somezone", - }} - return addrs, nil -} - func TestCheckHostSafeSearchGoogle(t *testing.T) { - d := newForTest(&Config{SafeSearchEnabled: true}, nil) + d := newForTest(&Config{ + SafeSearchEnabled: true, + CustomResolver: &aghtest.TestResolver{}, + }, nil) t.Cleanup(d.Close) - d.resolver = &testResolver{} // Check host for each domain. for _, host := range []string{ @@ -299,12 +286,12 @@ func TestSafeSearchCacheYandex(t *testing.T) { } func TestSafeSearchCacheGoogle(t *testing.T) { - d := newForTest(nil, nil) + resolver := &aghtest.TestResolver{} + d := newForTest(&Config{ + CustomResolver: resolver, + }, nil) t.Cleanup(d.Close) - resolver := &testResolver{} - d.resolver = resolver - domain := "www.google.ru" res, err := d.CheckHost(domain, dns.TypeA, &setts) assert.Nil(t, err) @@ -350,16 +337,16 @@ func TestSafeSearchCacheGoogle(t *testing.T) { func TestParentalControl(t *testing.T) { logOutput := &bytes.Buffer{} - testutil.ReplaceLogWriter(t, logOutput) - testutil.ReplaceLogLevel(t, log.DEBUG) + aghtest.ReplaceLogWriter(t, logOutput) + aghtest.ReplaceLogLevel(t, log.DEBUG) d := newForTest(&Config{ParentalEnabled: true}, nil) t.Cleanup(d.Close) matching := "pornhub.com" - d.parentalUpstream = &testSbUpstream{ - hostname: matching, - block: true, - } + d.SetParentalUpstream(&aghtest.TestBlockUpstream{ + Hostname: matching, + Block: true, + }) d.checkMatch(t, matching) assert.Contains(t, logOutput.String(), "Parental lookup for "+matching) @@ -733,14 +720,14 @@ func TestClientSettings(t *testing.T) { }}, ) t.Cleanup(d.Close) - d.parentalUpstream = &testSbUpstream{ - hostname: "pornhub.com", - block: true, - } - d.safeBrowsingUpstream = &testSbUpstream{ - hostname: "wmconvirus.narod.ru", - block: true, - } + d.SetParentalUpstream(&aghtest.TestBlockUpstream{ + Hostname: "pornhub.com", + Block: true, + }) + d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ + Hostname: "wmconvirus.narod.ru", + Block: true, + }) type testCase struct { name string @@ -801,10 +788,10 @@ func BenchmarkSafeBrowsing(b *testing.B) { d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) b.Cleanup(d.Close) blocked := "wmconvirus.narod.ru" - d.safeBrowsingUpstream = &testSbUpstream{ - hostname: blocked, - block: true, - } + d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ + Hostname: blocked, + Block: true, + }) for n := 0; n < b.N; n++ { res, err := d.CheckHost(blocked, dns.TypeA, &setts) assert.Nilf(b, err, "Error while matching host %s: %s", blocked, err) @@ -816,10 +803,10 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) { d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) b.Cleanup(d.Close) blocked := "wmconvirus.narod.ru" - d.safeBrowsingUpstream = &testSbUpstream{ - hostname: blocked, - block: true, - } + d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ + Hostname: blocked, + Block: true, + }) b.RunParallel(func(pb *testing.PB) { for pb.Next() { res, err := d.CheckHost(blocked, dns.TypeA, &setts) diff --git a/internal/dnsfilter/safebrowsing.go b/internal/dnsfilter/safebrowsing.go index 9a14b584..9142c7c4 100644 --- a/internal/dnsfilter/safebrowsing.go +++ b/internal/dnsfilter/safebrowsing.go @@ -30,6 +30,20 @@ const ( pcTXTSuffix = `pc.dns.adguard.com.` ) +// SetParentalUpstream sets the parental upstream for *DNSFilter. +// +// TODO(e.burkov): Remove this in v1 API to forbid the direct access. +func (d *DNSFilter) SetParentalUpstream(u upstream.Upstream) { + d.parentalUpstream = u +} + +// SetSafeBrowsingUpstream sets the safe browsing upstream for *DNSFilter. +// +// TODO(e.burkov): Remove this in v1 API to forbid the direct access. +func (d *DNSFilter) SetSafeBrowsingUpstream(u upstream.Upstream) { + d.safeBrowsingUpstream = u +} + func (d *DNSFilter) initSecurityServices() error { var err error d.safeBrowsingServer = defaultSafebrowsingServer @@ -44,15 +58,17 @@ func (d *DNSFilter) initSecurityServices() error { }, } - d.parentalUpstream, err = upstream.AddressToUpstream(d.parentalServer, opts) + parUps, err := upstream.AddressToUpstream(d.parentalServer, opts) if err != nil { return fmt.Errorf("converting parental server: %w", err) } + d.SetParentalUpstream(parUps) - d.safeBrowsingUpstream, err = upstream.AddressToUpstream(d.safeBrowsingServer, opts) + sbUps, err := upstream.AddressToUpstream(d.safeBrowsingServer, opts) if err != nil { return fmt.Errorf("converting safe browsing server: %w", err) } + d.SetSafeBrowsingUpstream(sbUps) return nil } @@ -227,7 +243,7 @@ func (c *sbCtx) processTXT(resp *dns.Msg) (bool, [][]byte) { func (c *sbCtx) storeCache(hashes [][]byte) { sort.Slice(hashes, func(a, b int) bool { - return bytes.Compare(hashes[a], hashes[b]) < 0 + return bytes.Compare(hashes[a], hashes[b]) == -1 }) var curData []byte diff --git a/internal/dnsfilter/safebrowsing_test.go b/internal/dnsfilter/safebrowsing_test.go index a1f627c4..f94a3c99 100644 --- a/internal/dnsfilter/safebrowsing_test.go +++ b/internal/dnsfilter/safebrowsing_test.go @@ -2,14 +2,11 @@ package dnsfilter import ( "crypto/sha256" - "encoding/hex" "strings" - "sync" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/agherr" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/cache" - "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) @@ -108,27 +105,14 @@ func TestSafeBrowsingCache(t *testing.T) { assert.Empty(t, c.getCached()) } -// testErrUpstream implements upstream.Upstream interface for replacing real -// upstream in tests. -type testErrUpstream struct{} - -// Exchange always returns nil Msg and non-nil error. -func (teu *testErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) { - return nil, agherr.Error("bad") -} - -func (teu *testErrUpstream) Address() string { - return "" -} - func TestSBPC_checkErrorUpstream(t *testing.T) { d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) - ups := &testErrUpstream{} + ups := &aghtest.TestErrUpstream{} - d.safeBrowsingUpstream = ups - d.parentalUpstream = ups + d.SetSafeBrowsingUpstream(ups) + d.SetParentalUpstream(ups) _, err := d.checkSafeBrowsing("smthng.com") assert.NotNil(t, err) @@ -137,122 +121,86 @@ func TestSBPC_checkErrorUpstream(t *testing.T) { assert.NotNil(t, err) } -// testSbUpstream implements upstream.Upstream interface for replacing real -// upstream in tests. -type testSbUpstream struct { - hostname string - block bool - requestsCount int - counterLock sync.RWMutex -} - -// Exchange returns a message depending on the upstream settings (hostname, block) -func (u *testSbUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) { - u.counterLock.Lock() - u.requestsCount++ - u.counterLock.Unlock() - - hash := sha256.Sum256([]byte(u.hostname)) - prefix := hash[0:2] - hashToReturn := hex.EncodeToString(prefix) + strings.Repeat("ab", 28) - if u.block { - hashToReturn = hex.EncodeToString(hash[:]) - } - - m := &dns.Msg{} - m.Answer = []dns.RR{ - &dns.TXT{ - Hdr: dns.RR_Header{ - Name: r.Question[0].Name, - }, - Txt: []string{ - hashToReturn, - }, - }, - } - - return m, nil -} - -func (u *testSbUpstream) Address() string { - return "" -} - -func TestSBPC_sbValidResponse(t *testing.T) { +func TestSBPC(t *testing.T) { d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) - ups := &testSbUpstream{} - d.safeBrowsingUpstream = ups - d.parentalUpstream = ups + const hostname = "example.org" - // Prepare the upstream - ups.hostname = "example.org" - ups.block = false - ups.requestsCount = 0 + testCases := []struct { + name string + block bool + testFunc func(string) (Result, error) + testCache cache.Cache + }{{ + name: "sb_no_block", + block: false, + testFunc: d.checkSafeBrowsing, + testCache: gctx.safebrowsingCache, + }, { + name: "sb_block", + block: true, + testFunc: d.checkSafeBrowsing, + testCache: gctx.safebrowsingCache, + }, { + name: "pc_no_block", + block: false, + testFunc: d.checkParental, + testCache: gctx.parentalCache, + }, { + name: "pc_block", + block: true, + testFunc: d.checkParental, + testCache: gctx.parentalCache, + }} - // First - check that the request is not blocked - res, err := d.checkSafeBrowsing("example.org") - assert.Nil(t, err) - assert.False(t, res.IsFiltered) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Prepare the upstream. + ups := &aghtest.TestBlockUpstream{ + Hostname: hostname, + Block: tc.block, + } + d.SetSafeBrowsingUpstream(ups) + d.SetParentalUpstream(ups) - // Check the cache state, check that the response is now cached - assert.Equal(t, 1, gctx.safebrowsingCache.Stats().Count) - assert.Equal(t, 0, gctx.safebrowsingCache.Stats().Hit) + // Firstly, check the request blocking. + hits := 0 + res, err := tc.testFunc(hostname) + assert.Nil(t, err) + if tc.block { + assert.True(t, res.IsFiltered) + assert.Len(t, res.Rules, 1) + hits++ + } else { + assert.False(t, res.IsFiltered) + } - // There was one request to an upstream - assert.Equal(t, 1, ups.requestsCount) + // Check the cache state, check the response is now cached. + assert.Equal(t, 1, tc.testCache.Stats().Count) + assert.Equal(t, hits, tc.testCache.Stats().Hit) - // Now make the same request to check that the cache was used - res, err = d.checkSafeBrowsing("example.org") - assert.Nil(t, err) - assert.False(t, res.IsFiltered) + // There was one request to an upstream. + assert.Equal(t, 1, ups.RequestsCount()) - // Check the cache state, it should've been used - assert.Equal(t, 1, gctx.safebrowsingCache.Stats().Count) - assert.Equal(t, 1, gctx.safebrowsingCache.Stats().Hit) + // Now make the same request to check the cache was used. + res, err = tc.testFunc(hostname) + assert.Nil(t, err) + if tc.block { + assert.True(t, res.IsFiltered) + assert.Len(t, res.Rules, 1) + } else { + assert.False(t, res.IsFiltered) + } - // Check that there were no additional requests - assert.Equal(t, 1, ups.requestsCount) -} - -func TestSBPC_pcBlockedResponse(t *testing.T) { - d := newForTest(&Config{SafeBrowsingEnabled: true}, nil) - t.Cleanup(d.Close) - - ups := &testSbUpstream{} - d.safeBrowsingUpstream = ups - d.parentalUpstream = ups - - // Prepare the upstream - // Make sure that the upstream will return a response that matches the queried domain - ups.hostname = "example.com" - ups.block = true - ups.requestsCount = 0 - - // Make a lookup - res, err := d.checkParental("example.com") - assert.Nil(t, err) - assert.True(t, res.IsFiltered) - assert.Len(t, res.Rules, 1) - - // Check the cache state, check that the response is now cached - assert.Equal(t, 1, gctx.parentalCache.Stats().Count) - assert.Equal(t, 1, gctx.parentalCache.Stats().Hit) - - // There was one request to an upstream - assert.Equal(t, 1, ups.requestsCount) - - // Make a second lookup for the same domain - res, err = d.checkParental("example.com") - assert.Nil(t, err) - assert.True(t, res.IsFiltered) - assert.Len(t, res.Rules, 1) - - // Check the cache state, it should've been used - assert.Equal(t, 1, gctx.parentalCache.Stats().Count) - assert.Equal(t, 2, gctx.parentalCache.Stats().Hit) - - // Check that there were no additional requests - assert.Equal(t, 1, ups.requestsCount) + // Check the cache state, it should've been used. + assert.Equal(t, 1, tc.testCache.Stats().Count) + assert.Equal(t, hits+1, tc.testCache.Stats().Hit) + + // Check that there were no additional requests. + assert.Equal(t, 1, ups.RequestsCount()) + + purgeCaches() + }) + } } diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 748b14b2..a80e7ad5 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -296,12 +296,13 @@ func (s *Server) prepareUpstreamSettings() error { // prepareIntlProxy - initializes DNS proxy that we use for internal DNS queries func (s *Server) prepareIntlProxy() { - intlProxyConfig := proxy.Config{ - CacheEnabled: true, - CacheSizeBytes: 4096, - UpstreamConfig: s.conf.UpstreamConfig, + s.internalProxy = &proxy.Proxy{ + Config: proxy.Config{ + CacheEnabled: true, + CacheSizeBytes: 4096, + UpstreamConfig: s.conf.UpstreamConfig, + }, } - s.internalProxy = &proxy.Proxy{Config: intlProxyConfig} } // prepareTLS - prepares TLS configuration for the DNS proxy diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index c24ba62b..68a0b6ae 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -85,10 +85,11 @@ type DNSCreateParams struct { // NewServer creates a new instance of the dnsforward.Server // Note: this function must be called only once func NewServer(p DNSCreateParams) *Server { - s := &Server{} - s.dnsFilter = p.DNSFilter - s.stats = p.Stats - s.queryLog = p.QueryLog + s := &Server{ + dnsFilter: p.DNSFilter, + stats: p.Stats, + queryLog: p.QueryLog, + } if p.DHCPServer != nil { s.dhcpServer = p.DHCPServer @@ -103,6 +104,16 @@ func NewServer(p DNSCreateParams) *Server { return s } +// NewCustomServer creates a new instance of *Server with custom internal proxy. +func NewCustomServer(internalProxy *proxy.Proxy) *Server { + s := &Server{} + if internalProxy != nil { + s.internalProxy = internalProxy + } + + return s +} + // Close - close object func (s *Server) Close() { s.Lock() diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 9b72dd52..93b6bcb7 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -1,11 +1,9 @@ package dnsforward import ( - "context" "crypto/ecdsa" "crypto/rand" "crypto/rsa" - "crypto/sha256" "crypto/tls" "crypto/x509" "crypto/x509/pkix" @@ -20,7 +18,7 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" @@ -32,7 +30,7 @@ import ( ) func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) + aghtest.DiscardLogOutput(m) } const ( @@ -53,10 +51,13 @@ func startDeferStop(t *testing.T, s *Server) { } func TestServer(t *testing.T) { - s := createTestServer(t) + s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &testUpstream{ - ipv4: map[string][]net.IP{ + &aghtest.TestUpstream{ + IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, }, @@ -88,11 +89,13 @@ func TestServer(t *testing.T) { } func TestServerWithProtectionDisabled(t *testing.T) { - s := createTestServer(t) - s.conf.ProtectionEnabled = false + s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &testUpstream{ - ipv4: map[string][]net.IP{ + &aghtest.TestUpstream{ + IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, }, @@ -113,7 +116,11 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) var keyPem []byte _, certPem, keyPem = createServerTLSConfig(t) - s = createTestServer(t) + + s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + }) tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem s.conf.TLSConfig = tlsConf @@ -126,11 +133,11 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) func TestDoTServer(t *testing.T) { s, certPem := createTestTLS(t, TLSConfig{ - TLSListenAddr: &net.TCPAddr{Port: 0}, + TLSListenAddr: &net.TCPAddr{}, }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &testUpstream{ - ipv4: map[string][]net.IP{ + &aghtest.TestUpstream{ + IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, }, @@ -156,11 +163,11 @@ func TestDoTServer(t *testing.T) { func TestDoQServer(t *testing.T) { s, _ := createTestTLS(t, TLSConfig{ - QUICListenAddr: &net.UDPAddr{Port: 0}, + QUICListenAddr: &net.UDPAddr{}, }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &testUpstream{ - ipv4: map[string][]net.IP{ + &aghtest.TestUpstream{ + IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, }, @@ -184,10 +191,27 @@ func TestDoQServer(t *testing.T) { func TestServerRace(t *testing.T) { t.Skip("TODO(e.burkov): inspect the golibs/cache package for locks") - s := createTestServer(t) + filterConf := &dnsfilter.Config{ + SafeBrowsingEnabled: true, + SafeBrowsingCacheSize: 1000, + SafeSearchEnabled: true, + SafeSearchCacheSize: 1000, + ParentalCacheSize: 1000, + CacheTime: 30, + } + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, + }, + ConfigModified: func() {}, + } + s := createTestServer(t, filterConf, forwardConf) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &testUpstream{ - ipv4: map[string][]net.IP{ + &aghtest.TestUpstream{ + IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, }, @@ -202,68 +226,74 @@ func TestServerRace(t *testing.T) { sendTestMessagesAsync(t, conn) } -// testResolver is a Resolver for tests. -// -//lint:ignore U1000 TODO(e.burkov): move into aghtest package. -type testResolver struct{} - -// LookupIPAddr implements Resolver interface for *testResolver. -// -//lint:ignore U1000 TODO(e.burkov): move into aghtest package. -func (r *testResolver) LookupIPAddr(_ context.Context, host string) (ips []net.IPAddr, err error) { - hash := sha256.Sum256([]byte(host)) - addrs := []net.IPAddr{{ - IP: net.IP(hash[:4]), - Zone: "somezone", - }, { - IP: net.IP(hash[4:20]), - Zone: "somezone", - }} - return addrs, nil -} - -// LookupHost implements Resolver interface for *testResolver. -// -//lint:ignore U1000 TODO(e.burkov): move into aghtest package. -func (r *testResolver) LookupHost(host string) (addrs []string, err error) { - hash := sha256.Sum256([]byte(host)) - addrs = []string{ - net.IP(hash[:4]).String(), - net.IP(hash[4:20]).String(), - } - return addrs, nil -} - func TestSafeSearch(t *testing.T) { - t.Skip("TODO(e.burkov): substitute the dnsfilter by one with custom resolver from aghtest") - - testUpstreamIP := net.IP{213, 180, 193, 56} - testCases := []string{ - "yandex.com.", - "yandex.by.", - "yandex.kz.", - "yandex.ru.", - "www.google.com.", - "www.google.com.af.", - "www.google.be.", - "www.google.by.", + resolver := &aghtest.TestResolver{} + filterConf := &dnsfilter.Config{ + SafeSearchEnabled: true, + SafeSearchCacheSize: 1000, + CacheTime: 30, + CustomResolver: resolver, } + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + }, + } + s := createTestServer(t, filterConf, forwardConf) + startDeferStop(t, s) + + addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() + client := dns.Client{Net: proxy.ProtoUDP} + + yandexIP := net.IP{213, 180, 193, 56} + googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com") + + testCases := []struct { + host string + want net.IP + }{{ + host: "yandex.com.", + want: yandexIP, + }, { + host: "yandex.by.", + want: yandexIP, + }, { + host: "yandex.kz.", + want: yandexIP, + }, { + host: "yandex.ru.", + want: yandexIP, + }, { + host: "www.google.com.", + want: googleIP, + }, { + host: "www.google.com.af.", + want: googleIP, + }, { + host: "www.google.be.", + want: googleIP, + }, { + host: "www.google.by.", + want: googleIP, + }} for _, tc := range testCases { - t.Run("safe_search_"+tc, func(t *testing.T) { - s := createTestServer(t) - startDeferStop(t, s) - - addr := s.dnsProxy.Addr(proxy.ProtoUDP) - client := dns.Client{Net: proxy.ProtoUDP} - - exchangeAndAssertResponse(t, &client, addr, tc, testUpstreamIP) + t.Run(tc.host, func(t *testing.T) { + req := createTestMessage(tc.host) + reply, _, err := client.Exchange(req, addr) + assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) + assertResponse(t, reply, tc.want) }) } } func TestInvalidRequest(t *testing.T) { - s := createTestServer(t) + s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + }) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() @@ -284,7 +314,14 @@ func TestInvalidRequest(t *testing.T) { } func TestBlockedRequest(t *testing.T) { - s := createTestServer(t) + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + }, + } + s := createTestServer(t, &dnsfilter.Config{}, forwardConf) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -299,12 +336,19 @@ func TestBlockedRequest(t *testing.T) { } func TestServerCustomClientUpstream(t *testing.T) { - s := createTestServer(t) + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + }, + } + s := createTestServer(t, &dnsfilter.Config{}, forwardConf) s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig { return &proxy.UpstreamConfig{ Upstreams: []upstream.Upstream{ - &testUpstream{ - ipv4: map[string][]net.IP{ + &aghtest.TestUpstream{ + IPv4: map[string][]net.IP{ "host.": {{192, 168, 0, 1}}, }, }, @@ -327,82 +371,6 @@ func TestServerCustomClientUpstream(t *testing.T) { assert.Equal(t, net.IP{192, 168, 0, 1}, reply.Answer[0].(*dns.A).A) } -// testUpstream is a mock of real upstream. specify fields with necessary values -// to simulate real upstream behaviour. -// -// TODO(e.burkov): move into aghtest package. -type testUpstream struct { - // cn is a map of hostname to canonical name. - cn map[string]string - // ipv4 is a map of hostname to IPv4. - ipv4 map[string][]net.IP - // ipv6 is a map of hostname to IPv6. - ipv6 map[string][]net.IP -} - -// Exchange implements upstream.Upstream interface for *testUpstream. -func (u *testUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) { - resp := &dns.Msg{} - resp.SetReply(m) - hasRec := false - - name := m.Question[0].Name - - if cname, ok := u.cn[name]; ok { - resp.Answer = append(resp.Answer, &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: name, - Rrtype: dns.TypeCNAME, - }, - Target: cname, - }) - } - - var rrtype uint16 - var a []net.IP - switch m.Question[0].Qtype { - case dns.TypeA: - rrtype = dns.TypeA - if ipv4addr, ok := u.ipv4[name]; ok { - hasRec = true - a = ipv4addr - } - case dns.TypeAAAA: - rrtype = dns.TypeAAAA - if ipv6addr, ok := u.ipv6[name]; ok { - hasRec = true - a = ipv6addr - } - } - for _, ip := range a { - resp.Answer = append(resp.Answer, &dns.A{ - Hdr: dns.RR_Header{ - Name: name, - Rrtype: rrtype, - }, - A: ip, - }) - } - - if len(resp.Answer) == 0 { - if hasRec { - // Set no error RCode if there are some records for - // given Qname but we didn't apply them. - resp.SetRcode(m, dns.RcodeSuccess) - return resp, nil - } - // Set NXDomain RCode otherwise. - resp.SetRcode(m, dns.RcodeNameError) - } - - return resp, nil -} - -// Address implements upstream.Upstream interface for *testUpstream. -func (u *testUpstream) Address() string { - return "test" -} - func (s *Server) startWithUpstream(u upstream.Upstream) error { s.Lock() defer s.Unlock() @@ -410,30 +378,35 @@ func (s *Server) startWithUpstream(u upstream.Upstream) error { if err != nil { return err } + s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{ Upstreams: []upstream.Upstream{u}, } + return s.dnsProxy.Start() } -// testCNAMEs is a simple map of names and CNAMEs necessary for the testUpstream work +// testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work. var testCNAMEs = map[string]string{ "badhost.": "null.example.org.", "whitelist.example.org.": "null.example.org.", } -// testIPv4 is a simple map of names and IPv4s necessary for the testUpstream work +// testIPv4 is a map of names and IPv4s necessary for the TestUpstream work. var testIPv4 = map[string][]net.IP{ "null.example.org.": {{1, 2, 3, 4}}, "example.org.": {{127, 0, 0, 255}}, } func TestBlockCNAMEProtectionEnabled(t *testing.T) { - s := createTestServer(t) - testUpstm := &testUpstream{ - cn: testCNAMEs, - ipv4: testIPv4, - ipv6: nil, + s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + }) + testUpstm := &aghtest.TestUpstream{ + CName: testCNAMEs, + IPv4: testIPv4, + IPv6: nil, } s.conf.ProtectionEnabled = false err := s.startWithUpstream(testUpstm) @@ -449,11 +422,18 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { } func TestBlockCNAME(t *testing.T) { - s := createTestServer(t) + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + }, + } + s := createTestServer(t, &dnsfilter.Config{}, forwardConf) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &testUpstream{ - cn: testCNAMEs, - ipv4: testIPv4, + &aghtest.TestUpstream{ + CName: testCNAMEs, + IPv4: testIPv4, }, } startDeferStop(t, s) @@ -496,14 +476,21 @@ func TestBlockCNAME(t *testing.T) { } func TestClientRulesForCNAMEMatching(t *testing.T) { - s := createTestServer(t) - s.conf.FilterHandler = func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) { - settings.FilteringEnabled = false + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + FilterHandler: func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) { + settings.FilteringEnabled = false + }, + }, } + s := createTestServer(t, &dnsfilter.Config{}, forwardConf) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &testUpstream{ - cn: testCNAMEs, - ipv4: testIPv4, + &aghtest.TestUpstream{ + CName: testCNAMEs, + IPv4: testIPv4, }, } startDeferStop(t, s) @@ -531,8 +518,15 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { } func TestNullBlockedRequest(t *testing.T) { - s := createTestServer(t) - s.conf.FilteringConfig.BlockingMode = "null_ip" + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + BlockingMode: "null_ip", + }, + } + s := createTestServer(t, &dnsfilter.Config{}, forwardConf) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -568,8 +562,8 @@ func TestBlockedCustomIP(t *testing.T) { DNSFilter: dnsfilter.New(&dnsfilter.Config{}, filters), }) conf := ServerConfig{ - UDPListenAddr: &net.UDPAddr{Port: 0}, - TCPListenAddr: &net.TCPAddr{Port: 0}, + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, BlockingMode: "custom_ip", @@ -606,7 +600,14 @@ func TestBlockedCustomIP(t *testing.T) { } func TestBlockedByHosts(t *testing.T) { - s := createTestServer(t) + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + }, + } + s := createTestServer(t, &dnsfilter.Config{}, forwardConf) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -623,24 +624,32 @@ func TestBlockedByHosts(t *testing.T) { } func TestBlockedBySafeBrowsing(t *testing.T) { - t.Skip("TODO(e.burkov): substitute the dnsfilter by one with custom safeBrowsingUpstream") - resolver := &testResolver{} - ips, _ := resolver.LookupIPAddr(context.Background(), safeBrowsingBlockHost) - addrs, _ := resolver.LookupHost(safeBrowsingBlockHost) + const hostname = "wmconvirus.narod.ru" - s := createTestServer(t) - s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &testUpstream{ - ipv4: map[string][]net.IP{ - "wmconvirus.narod.ru.": {ips[0].IP}, - }, + sbUps := &aghtest.TestBlockUpstream{ + Hostname: hostname, + Block: true, + } + ans, _ := (&aghtest.TestResolver{}).HostToIPs(hostname) + + filterConf := &dnsfilter.Config{ + SafeBrowsingEnabled: true, + } + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + SafeBrowsingBlockHost: ans.String(), + ProtectionEnabled: true, }, } + s := createTestServer(t, filterConf, forwardConf) + s.dnsFilter.SetSafeBrowsingUpstream(sbUps) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) // SafeBrowsing blocking. - req := createTestMessage("wmconvirus.narod.ru.") + req := createTestMessage(hostname + ".") reply, err := dns.Exchange(req, addr.String()) assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) @@ -648,14 +657,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { a, ok := reply.Answer[0].(*dns.A) if assert.Truef(t, ok, "dns server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) { - found := false - for _, blockAddr := range addrs { - if blockAddr == a.A.String() { - found = true - break - } - } - assert.Truef(t, found, "dns server %s returned wrong answer: %v", addr, a.A) + assert.Equal(t, ans, a.A, "dns server %s returned wrong answer: %v", addr, a.A) } } @@ -679,19 +681,19 @@ func TestRewrite(t *testing.T) { s := NewServer(DNSCreateParams{DNSFilter: f}) err := s.Prepare(&ServerConfig{ - UDPListenAddr: &net.UDPAddr{Port: 0}, - TCPListenAddr: &net.TCPAddr{Port: 0}, + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, UpstreamDNS: []string{"8.8.8.8:53"}, }, }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &testUpstream{ - cn: map[string]string{ + &aghtest.TestUpstream{ + CName: map[string]string{ "example.org": "somename", }, - ipv4: map[string][]net.IP{ + IPv4: map[string][]net.IP{ "example.org.": {{4, 3, 2, 1}}, }, }, @@ -724,13 +726,14 @@ func TestRewrite(t *testing.T) { req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) reply, err = dns.Exchange(req, addr.String()) assert.Nil(t, err) - assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) // the original question is restored + // The original question is restored. + assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) assert.Len(t, reply.Answer, 2) assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) } -func createTestServer(t *testing.T) *Server { +func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server { rules := `||nxdomain.example.org ||null.example.org^ 127.0.0.1 host.example.org @@ -739,30 +742,13 @@ func createTestServer(t *testing.T) *Server { filters := []dnsfilter.Filter{{ ID: 0, Data: []byte(rules), }} - c := dnsfilter.Config{ - SafeBrowsingEnabled: true, - SafeBrowsingCacheSize: 1000, - SafeSearchEnabled: true, - SafeSearchCacheSize: 1000, - ParentalCacheSize: 1000, - CacheTime: 30, - } - f := dnsfilter.New(&c, filters) + f := dnsfilter.New(filterConf, filters) s := NewServer(DNSCreateParams{DNSFilter: f}) - s.conf = ServerConfig{ - UDPListenAddr: &net.UDPAddr{Port: 0}, - TCPListenAddr: &net.TCPAddr{Port: 0}, - FilteringConfig: FilteringConfig{ - ProtectionEnabled: true, - UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, - }, - ConfigModified: func() {}, - } + s.conf = forwardConf + assert.Nil(t, s.Prepare(nil)) - err := s.Prepare(nil) - assert.Nil(t, err) return s } @@ -849,15 +835,6 @@ func sendTestMessages(t *testing.T, conn *dns.Conn) { } } -func exchangeAndAssertResponse(t *testing.T, client *dns.Client, addr net.Addr, host string, ip net.IP) { - t.Helper() - - req := createTestMessage(host) - reply, _, err := client.Exchange(req, addr.String()) - assert.Nilf(t, err, "couldn't talk to server %s: %s", addr, err) - assertResponse(t, reply, ip) -} - func createGoogleATestMessage() *dns.Msg { return createTestMessage("google-public-dns-a.google.com.") } @@ -879,6 +856,7 @@ func createTestMessage(host string) *dns.Msg { func createTestMessageWithType(host string, qtype uint16) *dns.Msg { req := createTestMessage(host) req.Question[0].Qtype = qtype + return req } @@ -889,7 +867,10 @@ func assertGoogleAResponse(t *testing.T, reply *dns.Msg) { func assertResponse(t *testing.T, reply *dns.Msg, ip net.IP) { t.Helper() - assert.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) + if !assert.Lenf(t, reply.Answer, 1, "dns server returned reply with wrong number of answers - %d", len(reply.Answer)) { + return + } + a, ok := reply.Answer[0].(*dns.A) if assert.Truef(t, ok, "dns server returned wrong answer type instead of A: %v", reply.Answer[0]) { assert.Truef(t, a.A.Equal(ip), "dns server returned wrong answer instead of %s: %s", ip, a.A) @@ -900,8 +881,10 @@ func publicKey(priv interface{}) interface{} { switch k := priv.(type) { case *rsa.PrivateKey: return &k.PublicKey + case *ecdsa.PrivateKey: return &k.PublicKey + default: return nil } @@ -1082,6 +1065,7 @@ func (d *testDHCP) Leases(flags int) []dhcpd.Lease { HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, Hostname: "localhost", } + return []dhcpd.Lease{l} } func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} @@ -1094,8 +1078,8 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { DHCPServer: dhcp, }) - s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} - s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} + s.conf.UDPListenAddr = &net.UDPAddr{} + s.conf.TCPListenAddr = &net.TCPAddr{} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.FilteringConfig.ProtectionEnabled = true err := s.Prepare(nil) @@ -1143,8 +1127,8 @@ func TestPTRResponseFromHosts(t *testing.T) { t.Cleanup(c.AutoHosts.Close) s := NewServer(DNSCreateParams{DNSFilter: dnsfilter.New(&c, nil)}) - s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} - s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} + s.conf.UDPListenAddr = &net.UDPAddr{} + s.conf.TCPListenAddr = &net.TCPAddr{} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.FilteringConfig.ProtectionEnabled = true assert.Nil(t, s.Prepare(nil)) diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index c8e2f9c5..f1ba7031 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -2,16 +2,35 @@ package dnsforward import ( "io/ioutil" + "net" "net/http" "net/http/httptest" "strings" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/stretchr/testify/assert" ) func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) { - s := createTestServer(t) + filterConf := &dnsfilter.Config{ + SafeBrowsingEnabled: true, + SafeBrowsingCacheSize: 1000, + SafeSearchEnabled: true, + SafeSearchCacheSize: 1000, + ParentalCacheSize: 1000, + CacheTime: 30, + } + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, + }, + ConfigModified: func() {}, + } + s := createTestServer(t, filterConf, forwardConf) err := s.Start() assert.Nil(t, err) defer assert.Nil(t, s.Stop()) @@ -35,6 +54,7 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) { conf: func() ServerConfig { conf := defaultConf conf.FastestAddr = true + return conf }, want: "{\"upstream_dns\":[\"8.8.8.8:53\",\"8.8.4.4:53\"],\"upstream_dns_file\":\"\",\"bootstrap_dns\":[\"9.9.9.10\",\"149.112.112.10\",\"2620:fe::10\",\"2620:fe::fe:10\"],\"protection_enabled\":true,\"ratelimit\":0,\"blocking_mode\":\"\",\"blocking_ipv4\":\"\",\"blocking_ipv6\":\"\",\"edns_cs_enabled\":false,\"dnssec_enabled\":false,\"disable_ipv6\":false,\"upstream_mode\":\"fastest_addr\",\"cache_size\":0,\"cache_ttl_min\":0,\"cache_ttl_max\":0}\n", @@ -43,6 +63,7 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) { conf: func() ServerConfig { conf := defaultConf conf.AllServers = true + return conf }, want: "{\"upstream_dns\":[\"8.8.8.8:53\",\"8.8.4.4:53\"],\"upstream_dns_file\":\"\",\"bootstrap_dns\":[\"9.9.9.10\",\"149.112.112.10\",\"2620:fe::10\",\"2620:fe::fe:10\"],\"protection_enabled\":true,\"ratelimit\":0,\"blocking_mode\":\"\",\"blocking_ipv4\":\"\",\"blocking_ipv6\":\"\",\"edns_cs_enabled\":false,\"dnssec_enabled\":false,\"disable_ipv6\":false,\"upstream_mode\":\"parallel\",\"cache_size\":0,\"cache_ttl_min\":0,\"cache_ttl_max\":0}\n", @@ -61,7 +82,24 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) { } func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) { - s := createTestServer(t) + filterConf := &dnsfilter.Config{ + SafeBrowsingEnabled: true, + SafeBrowsingCacheSize: 1000, + SafeSearchEnabled: true, + SafeSearchCacheSize: 1000, + ParentalCacheSize: 1000, + CacheTime: 30, + } + forwardConf := ServerConfig{ + UDPListenAddr: &net.UDPAddr{}, + TCPListenAddr: &net.TCPAddr{}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, + }, + ConfigModified: func() {}, + } + s := createTestServer(t, filterConf, forwardConf) defaultConf := s.conf diff --git a/internal/home/auth_test.go b/internal/home/auth_test.go index e44b7e83..4dbd07b6 100644 --- a/internal/home/auth_test.go +++ b/internal/home/auth_test.go @@ -9,12 +9,12 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/stretchr/testify/assert" ) func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) + aghtest.DiscardLogOutput(m) } func prepareTestDir() string { diff --git a/internal/home/rdns.go b/internal/home/rdns.go index dad75e44..c21a6f6e 100644 --- a/internal/home/rdns.go +++ b/internal/home/rdns.go @@ -27,18 +27,18 @@ type RDNS struct { // InitRDNS - create module context func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS { - r := RDNS{} - r.dnsServer = dnsServer - r.clients = clients + r := &RDNS{ + dnsServer: dnsServer, + clients: clients, + ipAddrs: cache.New(cache.Config{ + EnableLRU: true, + MaxCount: 10000, + }), + ipChannel: make(chan net.IP, 256), + } - cconf := cache.Config{} - cconf.EnableLRU = true - cconf.MaxCount = 10000 - r.ipAddrs = cache.New(cconf) - - r.ipChannel = make(chan net.IP, 256) go r.workerLoop() - return &r + return r } // Begin - add IP address to rDNS queue @@ -75,23 +75,23 @@ func (r *RDNS) Begin(ip net.IP) { func (r *RDNS) resolve(ip net.IP) string { log.Tracef("Resolving host for %s", ip) - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - { - Qtype: dns.TypePTR, - Qclass: dns.ClassINET, - }, - } - var err error - req.Question[0].Name, err = dns.ReverseAddr(ip.String()) + name, err := dns.ReverseAddr(ip.String()) if err != nil { log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err) return "" } - resp, err := r.dnsServer.Exchange(&req) + resp, err := r.dnsServer.Exchange(&dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: dns.Id(), + RecursionDesired: true, + }, + Question: []dns.Question{{ + Name: name, + Qtype: dns.TypePTR, + Qclass: dns.ClassINET, + }}, + }) if err != nil { log.Debug("Error while making an rDNS lookup for %s: %s", ip, err) return "" diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 53dd093d..b17efdd8 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -4,16 +4,26 @@ import ( "net" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/stretchr/testify/assert" ) func TestResolveRDNS(t *testing.T) { - dns := &dnsforward.Server{} - conf := &dnsforward.ServerConfig{} - conf.UpstreamDNS = []string{"8.8.8.8"} - err := dns.Prepare(conf) - assert.Nil(t, err) + ups := &aghtest.TestUpstream{ + Reverse: map[string][]string{ + "1.1.1.1.in-addr.arpa.": {"one.one.one.one"}, + }, + } + dns := dnsforward.NewCustomServer(&proxy.Proxy{ + Config: proxy.Config{ + UpstreamConfig: &proxy.UpstreamConfig{ + Upstreams: []upstream.Upstream{ups}, + }, + }, + }) clients := &clientsContainer{} rdns := InitRDNS(dns, clients) diff --git a/internal/home/whois_test.go b/internal/home/whois_test.go index a160cdda..ed72740d 100644 --- a/internal/home/whois_test.go +++ b/internal/home/whois_test.go @@ -13,9 +13,14 @@ func prepareTestDNSServer() error { Context.dnsServer = dnsforward.NewServer(dnsforward.DNSCreateParams{}) conf := &dnsforward.ServerConfig{} conf.UpstreamDNS = []string{"8.8.8.8"} + return Context.dnsServer.Prepare(conf) } +// TODO(e.burkov): It's kind of complicated to get rid of network access in this +// test. The thing is that *Whois creates new *net.Dialer each time it requests +// the server, so it becomes hard to simulate handling of request from test even +// with substituted upstream. However, it must be done. func TestWhois(t *testing.T) { assert.Nil(t, prepareTestDNSServer()) diff --git a/internal/querylog/decode_test.go b/internal/querylog/decode_test.go index 58ca5851..fe64a624 100644 --- a/internal/querylog/decode_test.go +++ b/internal/querylog/decode_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" @@ -19,8 +19,8 @@ import ( func TestDecodeLogEntry(t *testing.T) { logOutput := &bytes.Buffer{} - testutil.ReplaceLogWriter(t, logOutput) - testutil.ReplaceLogLevel(t, log.DEBUG) + aghtest.ReplaceLogWriter(t, logOutput) + aghtest.ReplaceLogLevel(t, log.DEBUG) t.Run("success", func(t *testing.T) { const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==` diff --git a/internal/querylog/qlog_test.go b/internal/querylog/qlog_test.go index 24d9064e..07fac8ff 100644 --- a/internal/querylog/qlog_test.go +++ b/internal/querylog/qlog_test.go @@ -10,14 +10,14 @@ import ( "github.com/AdguardTeam/dnsproxy/proxyutil" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) + aghtest.DiscardLogOutput(m) } func prepareTestDir() string { diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index 06163c6c..7643b31a 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -7,12 +7,12 @@ import ( "sync/atomic" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/stretchr/testify/assert" ) func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) + aghtest.DiscardLogOutput(m) } func UIntArrayEquals(a, b []uint64) bool { diff --git a/internal/sysutil/sysutil_test.go b/internal/sysutil/sysutil_test.go index 0cddbf42..8ca0ff66 100644 --- a/internal/sysutil/sysutil_test.go +++ b/internal/sysutil/sysutil_test.go @@ -3,9 +3,9 @@ package sysutil import ( "testing" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" ) func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) + aghtest.DiscardLogOutput(m) } diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index a8923ee2..197e2142 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -11,7 +11,7 @@ import ( "strconv" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/stretchr/testify/assert" ) @@ -19,7 +19,7 @@ import ( // TODO(a.garipov): Rewrite these tests. func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) + aghtest.DiscardLogOutput(m) } func startHTTPServer(data string) (l net.Listener, portStr string) { diff --git a/internal/util/autohosts_test.go b/internal/util/autohosts_test.go index b9632855..c5f26934 100644 --- a/internal/util/autohosts_test.go +++ b/internal/util/autohosts_test.go @@ -8,13 +8,13 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/testutil" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) + aghtest.DiscardLogOutput(m) } func prepareTestDir() string {