diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index f14845a8..e94b47cc 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -445,7 +445,15 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { // A better approach is for proxy.Stop() to wait until all its workers exit, // but this would require the Upstream interface to have Close() function // (to prevent from hanging while waiting for unresponsive DNS server to respond). - res, err := s.filterDNSRequest(d) + + var setts *dnsfilter.RequestFilteringSettings + var err error + res := &dnsfilter.Result{} + protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil + if protectionEnabled { + setts = s.getClientRequestFilteringSettings(d) + res, err = s.filterDNSRequest(d, setts) + } s.RUnlock() if err != nil { return err @@ -486,9 +494,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { d.Res.Answer = answer } - } else if res.Reason != dnsfilter.NotFilteredWhiteList { + } else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled { origResp2 := d.Res - res, err = s.filterDNSResponse(d) + res, err = s.filterDNSResponse(d, setts) if err != nil { return err } @@ -602,12 +610,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt } // filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered -func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) { - if !s.conf.ProtectionEnabled || s.dnsFilter == nil { - return &dnsfilter.Result{}, nil - } - - setts := s.getClientRequestFilteringSettings(d) +func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) { req := d.Req host := strings.TrimSuffix(req.Question[0].Name, ".") res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts) @@ -648,7 +651,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error // If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address. // If this is a match, we set a new response in d.Res and return. -func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) { +func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) { for _, a := range d.Res.Answer { host := "" @@ -676,7 +679,6 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, erro s.RUnlock() continue } - setts := s.getClientRequestFilteringSettings(d) res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts) s.RUnlock() diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 76f8f028..5a9d8434 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -340,6 +340,22 @@ var testIPv4 = map[string][]net.IP{ "example.org.": {{127, 0, 0, 255}}, } +func TestBlockCNAMEProtectionEnabled(t *testing.T) { + s := createTestServer(t) + testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} + s.conf.ProtectionEnabled = false + err := s.startWithUpstream(testUpstm) + assert.True(t, err == nil) + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + + // 'badhost' has a canonical name 'null.example.org' which is blocked by filters: + // but protection is disabled - response is NOT blocked + req := createTestMessage("badhost.") + reply, err := dns.Exchange(req, addr.String()) + assert.True(t, err == nil) + assert.True(t, reply.Rcode == dns.RcodeSuccess) +} + func TestBlockCNAME(t *testing.T) { s := createTestServer(t) testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}