From 15d07a40eb83c021032bfca5bbd532f3b032f2c5 Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Thu, 25 Jul 2019 16:37:06 +0300 Subject: [PATCH] * refactor --- dnsfilter/dnsfilter.go | 11 +------- dnsfilter/dnsfilter_test.go | 52 +++++++++++++++++++++---------------- dnsforward/dnsforward.go | 12 ++++++++- 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index cc724ff7..47ce1f0e 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -206,7 +206,7 @@ func (r Reason) Matched() bool { } // CheckHost tries to match host against rules, then safebrowsing and parental if they are enabled -func (d *Dnsfilter) CheckHost(host string, qtype uint16, clientAddr string) (Result, error) { +func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFilteringSettings) (Result, error) { // sometimes DNS clients will try to resolve ".", which is a request to get root servers if host == "" { return Result{Reason: NotFilteredNotFound}, nil @@ -217,15 +217,6 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, clientAddr string) (Res return Result{}, nil } - var setts RequestFilteringSettings - setts.FilteringEnabled = true - setts.SafeSearchEnabled = d.SafeSearchEnabled - setts.SafeBrowsingEnabled = d.SafeBrowsingEnabled - setts.ParentalEnabled = d.ParentalEnabled - if d.FilterHandler != nil { - d.FilterHandler(clientAddr, &setts) - } - var result Result var err error diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 7df5fe09..d5eabafe 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -16,6 +16,8 @@ import ( "github.com/stretchr/testify/assert" ) +var setts RequestFilteringSettings + // HELPERS // SAFE BROWSING // SAFE SEARCH @@ -46,10 +48,16 @@ func _Func() string { } func NewForTest(c *Config, filters map[int]string) *Dnsfilter { + setts = RequestFilteringSettings{} + setts.FilteringEnabled = true if c != nil { c.SafeBrowsingCacheSize = 1024 c.SafeSearchCacheSize = 1024 c.ParentalCacheSize = 1024 + + setts.SafeSearchEnabled = c.SafeSearchEnabled + setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled + setts.ParentalEnabled = c.ParentalEnabled } d := New(c, filters) purgeCaches() @@ -58,7 +66,7 @@ func NewForTest(c *Config, filters map[int]string) *Dnsfilter { func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) { t.Helper() - ret, err := d.CheckHost(hostname, dns.TypeA, "") + ret, err := d.CheckHost(hostname, dns.TypeA, &setts) if err != nil { t.Errorf("Error while matching host %s: %s", hostname, err) } @@ -69,7 +77,7 @@ func (d *Dnsfilter) checkMatch(t *testing.T, hostname string) { func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype uint16) { t.Helper() - ret, err := d.CheckHost(hostname, qtype, "") + ret, err := d.CheckHost(hostname, qtype, &setts) if err != nil { t.Errorf("Error while matching host %s: %s", hostname, err) } @@ -83,7 +91,7 @@ func (d *Dnsfilter) checkMatchIP(t *testing.T, hostname string, ip string, qtype func (d *Dnsfilter) checkMatchEmpty(t *testing.T, hostname string) { t.Helper() - ret, err := d.CheckHost(hostname, dns.TypeA, "") + ret, err := d.CheckHost(hostname, dns.TypeA, &setts) if err != nil { t.Errorf("Error while matching host %s: %s", hostname, err) } @@ -214,7 +222,7 @@ func TestCheckHostSafeSearchYandex(t *testing.T) { // Check host for each domain for _, host := range yandex { - result, err := d.CheckHost(host, dns.TypeA, "") + result, err := d.CheckHost(host, dns.TypeA, &setts) if err != nil { t.Errorf("SafeSearch doesn't work for yandex domain `%s` cause %s", host, err) } @@ -234,7 +242,7 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) { // Check host for each domain for _, host := range googleDomains { - result, err := d.CheckHost(host, dns.TypeA, "") + result, err := d.CheckHost(host, dns.TypeA, &setts) if err != nil { t.Errorf("SafeSearch doesn't work for %s cause %s", host, err) } @@ -254,7 +262,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { var err error // Check host with disabled safesearch - result, err = d.CheckHost(domain, dns.TypeA, "") + result, err = d.CheckHost(domain, dns.TypeA, &setts) if err != nil { t.Fatalf("Cannot check host due to %s", err) } @@ -265,7 +273,7 @@ func TestSafeSearchCacheYandex(t *testing.T) { d = NewForTest(&Config{SafeSearchEnabled: true}, nil) defer d.Destroy() - result, err = d.CheckHost(domain, dns.TypeA, "") + result, err = d.CheckHost(domain, dns.TypeA, &setts) if err != nil { t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err) } @@ -291,7 +299,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) { d := NewForTest(nil, nil) defer d.Destroy() domain := "www.google.ru" - result, err := d.CheckHost(domain, dns.TypeA, "") + result, err := d.CheckHost(domain, dns.TypeA, &setts) if err != nil { t.Fatalf("Cannot check host due to %s", err) } @@ -322,7 +330,7 @@ func TestSafeSearchCacheGoogle(t *testing.T) { } } - result, err = d.CheckHost(domain, dns.TypeA, "") + result, err = d.CheckHost(domain, dns.TypeA, &setts) if err != nil { t.Fatalf("CheckHost for safesearh domain %s failed cause %s", domain, err) } @@ -435,7 +443,7 @@ func TestMatching(t *testing.T) { d := NewForTest(nil, filters) defer d.Destroy() - ret, err := d.CheckHost(test.hostname, dns.TypeA, "") + ret, err := d.CheckHost(test.hostname, dns.TypeA, &setts) if err != nil { t.Errorf("Error while matching host %s: %s", test.hostname, err) } @@ -451,7 +459,7 @@ func TestMatching(t *testing.T) { // CLIENT SETTINGS -func applyClientSettings(clientAddr string, setts *RequestFilteringSettings) { +func applyClientSettings(setts *RequestFilteringSettings) { setts.FilteringEnabled = false setts.ParentalEnabled = false setts.SafeBrowsingEnabled = true @@ -476,50 +484,50 @@ func TestClientSettings(t *testing.T) { // no client settings: // blocked by filters - r, _ = d.CheckHost("example.org", dns.TypeA, "1.1.1.1") + r, _ = d.CheckHost("example.org", dns.TypeA, &setts) if !r.IsFiltered || r.Reason != FilteredBlackList { t.Fatalf("CheckHost FilteredBlackList") } // blocked by parental - r, _ = d.CheckHost("pornhub.com", dns.TypeA, "1.1.1.1") + r, _ = d.CheckHost("pornhub.com", dns.TypeA, &setts) if !r.IsFiltered || r.Reason != FilteredParental { t.Fatalf("CheckHost FilteredParental") } // safesearch is disabled - r, _ = d.CheckHost("wmconvirus.narod.ru", dns.TypeA, "1.1.1.1") + r, _ = d.CheckHost("wmconvirus.narod.ru", dns.TypeA, &setts) if r.IsFiltered { t.Fatalf("CheckHost safesearch") } // not blocked - r, _ = d.CheckHost("facebook.com", dns.TypeA, "1.1.1.1") + r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts) assert.True(t, !r.IsFiltered) // override client settings: - d.FilterHandler = applyClientSettings + applyClientSettings(&setts) // override filtering settings - r, _ = d.CheckHost("example.org", dns.TypeA, "1.1.1.1") + r, _ = d.CheckHost("example.org", dns.TypeA, &setts) if r.IsFiltered { t.Fatalf("CheckHost") } // override parental settings (force disable parental) - r, _ = d.CheckHost("pornhub.com", dns.TypeA, "1.1.1.1") + r, _ = d.CheckHost("pornhub.com", dns.TypeA, &setts) if r.IsFiltered { t.Fatalf("CheckHost") } // override safesearch settings (force enable safesearch) - r, _ = d.CheckHost("wmconvirus.narod.ru", dns.TypeA, "1.1.1.1") + r, _ = d.CheckHost("wmconvirus.narod.ru", dns.TypeA, &setts) if !r.IsFiltered || r.Reason != FilteredSafeBrowsing { t.Fatalf("CheckHost FilteredSafeBrowsing") } // blocked by additional rules - r, _ = d.CheckHost("facebook.com", dns.TypeA, "1.1.1.1") + r, _ = d.CheckHost("facebook.com", dns.TypeA, &setts) assert.True(t, r.IsFiltered && r.Reason == FilteredBlockedService) } @@ -530,7 +538,7 @@ func BenchmarkSafeBrowsing(b *testing.B) { defer d.Destroy() for n := 0; n < b.N; n++ { hostname := "wmconvirus.narod.ru" - ret, err := d.CheckHost(hostname, dns.TypeA, "") + ret, err := d.CheckHost(hostname, dns.TypeA, &setts) if err != nil { b.Errorf("Error while matching host %s: %s", hostname, err) } @@ -546,7 +554,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { hostname := "wmconvirus.narod.ru" - ret, err := d.CheckHost(hostname, dns.TypeA, "") + ret, err := d.CheckHost(hostname, dns.TypeA, &setts) if err != nil { b.Errorf("Error while matching host %s: %s", hostname, err) } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index d09bf7be..a5aa81fd 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -533,7 +533,17 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error if d.Addr != nil { clientAddr, _, _ = net.SplitHostPort(d.Addr.String()) } - res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, clientAddr) + + var setts dnsfilter.RequestFilteringSettings + setts.FilteringEnabled = true + setts.SafeSearchEnabled = s.conf.SafeSearchEnabled + setts.SafeBrowsingEnabled = s.conf.SafeBrowsingEnabled + setts.ParentalEnabled = s.conf.ParentalEnabled + if s.conf.FilterHandler != nil { + s.conf.FilterHandler(clientAddr, &setts) + } + + res, err = dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts) if err != nil { // Return immediately if there's an error return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)