* dnsforward: get per-client settings only once

+ dnsforward: add 'ProtectionEnabled = false' test
This commit is contained in:
Simon Zolin 2020-01-09 19:52:06 +03:00
parent b3ddae7f85
commit 0ef8e5cdae
2 changed files with 29 additions and 11 deletions

View File

@ -445,7 +445,15 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
// A better approach is for proxy.Stop() to wait until all its workers exit,
// but this would require the Upstream interface to have Close() function
// (to prevent from hanging while waiting for unresponsive DNS server to respond).
res, err := s.filterDNSRequest(d)
var setts *dnsfilter.RequestFilteringSettings
var err error
res := &dnsfilter.Result{}
protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil
if protectionEnabled {
setts = s.getClientRequestFilteringSettings(d)
res, err = s.filterDNSRequest(d, setts)
}
s.RUnlock()
if err != nil {
return err
@ -486,9 +494,9 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
d.Res.Answer = answer
}
} else if res.Reason != dnsfilter.NotFilteredWhiteList {
} else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled {
origResp2 := d.Res
res, err = s.filterDNSResponse(d)
res, err = s.filterDNSResponse(d, setts)
if err != nil {
return err
}
@ -602,12 +610,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
}
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error) {
if !s.conf.ProtectionEnabled || s.dnsFilter == nil {
return &dnsfilter.Result{}, nil
}
setts := s.getClientRequestFilteringSettings(d)
func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
@ -648,7 +651,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
// If this is a match, we set a new response in d.Res and return.
func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) {
func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
for _, a := range d.Res.Answer {
host := ""
@ -676,7 +679,6 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext) (*dnsfilter.Result, erro
s.RUnlock()
continue
}
setts := s.getClientRequestFilteringSettings(d)
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts)
s.RUnlock()

View File

@ -340,6 +340,22 @@ var testIPv4 = map[string][]net.IP{
"example.org.": {{127, 0, 0, 255}},
}
func TestBlockCNAMEProtectionEnabled(t *testing.T) {
s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
s.conf.ProtectionEnabled = false
err := s.startWithUpstream(testUpstm)
assert.True(t, err == nil)
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
// 'badhost' has a canonical name 'null.example.org' which is blocked by filters:
// but protection is disabled - response is NOT blocked
req := createTestMessage("badhost.")
reply, err := dns.Exchange(req, addr.String())
assert.True(t, err == nil)
assert.True(t, reply.Rcode == dns.RcodeSuccess)
}
func TestBlockCNAME(t *testing.T) {
s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}