From 3f7e2f7241caa0cf0847579efefaa276e273f89e Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Tue, 21 Jan 2020 13:49:34 +0300 Subject: [PATCH] Merge: * dns: refactor Squashed commit of the following: commit e9469266cafa3df537b5a4d5e28ca51db8289a34 Merge: 17cf6d60 e7e946fa Author: Simon Zolin Date: Tue Jan 21 13:04:30 2020 +0300 Merge remote-tracking branch 'origin/master' into refactor commit 17cf6d60d11602df3837316119ba8828f41a95df Author: Simon Zolin Date: Mon Jan 20 15:25:43 2020 +0300 minor commit 7b79462ebbeb743a10417bd28ceb70262ff9fa5c Author: Simon Zolin Date: Fri Jan 17 17:50:09 2020 +0300 minor commit d8b175c7eda36005c0277e7876f0f0a55a661b05 Author: Simon Zolin Date: Fri Jan 17 15:30:37 2020 +0300 minor commit 93370aa32aa560d42fc67c95fd13f027ddc01b94 Author: Simon Zolin Date: Fri Jan 17 14:28:14 2020 +0300 * dns: refactor . introduce a local context object . move filtering, upstream logic, stats, querylog code to separate functions --- dnsforward/dnsforward.go | 230 +++++++++++++++++++++++++++------------ 1 file changed, 158 insertions(+), 72 deletions(-) diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 462cf4d3..85be3ac9 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -425,14 +425,33 @@ func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool return true, nil } -// handleDNSRequest filters the incoming DNS requests and writes them to the query log -// nolint (gocyclo) -func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { - start := time.Now() +// To transfer information between modules +type dnsContext struct { + srv *Server + proxyCtx *proxy.DNSContext + setts *dnsfilter.RequestFilteringSettings // filtering settings for this client + startTime time.Time + result *dnsfilter.Result + origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering + origQuestion dns.Question // question received from client. Set when Rewrites are used. + err error // error returned from the module + protectionEnabled bool // filtering is enabled, dnsfilter object is ready + responseFromUpstream bool // response is received from upstream servers +} +const ( + resultDone = iota // module has completed its job, continue + resultFinish // module has completed its job, exit normally + resultError // an error occurred, exit with an error +) + +// Perform initial checks; process WHOIS & rDNS +func processInitial(ctx *dnsContext) int { + s := ctx.srv + d := ctx.proxyCtx if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA { _ = proxy.CheckDisabledAAAARequest(d, true) - return nil + return resultFinish } if s.conf.OnDNSRequest != nil { @@ -443,10 +462,17 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) && d.Req.Question[0].Name == "use-application-dns.net." { d.Res = s.genNXDomain(d.Req) - return nil + return resultFinish } - // use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise + return resultDone +} + +// Apply filtering logic +func processFilteringBeforeRequest(ctx *dnsContext) int { + s := ctx.srv + d := ctx.proxyCtx + s.RLock() // Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use. // This could happen after proxy server has been stopped, but its workers are not yet exited. @@ -455,72 +481,94 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { // but this would require the Upstream interface to have Close() function // (to prevent from hanging while waiting for unresponsive DNS server to respond). - var setts *dnsfilter.RequestFilteringSettings var err error - res := &dnsfilter.Result{} - protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil - if protectionEnabled { - setts = s.getClientRequestFilteringSettings(d) - res, err = s.filterDNSRequest(d, setts) + ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil + if ctx.protectionEnabled { + ctx.setts = s.getClientRequestFilteringSettings(d) + ctx.result, err = s.filterDNSRequest(ctx) } s.RUnlock() + if err != nil { - return err - } - - var origResp *dns.Msg - if d.Res == nil { - answer := []dns.RR{} - originalQuestion := d.Req.Question[0] - - if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { - answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName)) - // resolve canonical name, not the original host name - d.Req.Question[0].Name = dns.Fqdn(res.CanonName) - } - - if d.Addr != nil && s.conf.GetUpstreamsByClient != nil { - clientIP := ipFromAddr(d.Addr) - upstreams := s.conf.GetUpstreamsByClient(clientIP) - if len(upstreams) > 0 { - log.Debug("Using custom upstreams for %s", clientIP) - d.Upstreams = upstreams - } - } - - // request was not filtered so let it be processed further - err = p.Resolve(d) - if err != nil { - return err - } - - if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { - d.Req.Question[0] = originalQuestion - d.Res.Question[0] = originalQuestion - - if len(d.Res.Answer) != 0 { - answer = append(answer, d.Res.Answer...) // host -> IP - d.Res.Answer = answer - } - - } else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled { - origResp2 := d.Res - res, err = s.filterDNSResponse(d, setts) - if err != nil { - return err - } - if res != nil { - origResp = origResp2 // matched by response - } else { - res = &dnsfilter.Result{} - } - } + ctx.err = err + return resultError } + return resultDone +} +// Pass request to upstream servers; process the response +func processUpstream(ctx *dnsContext) int { + s := ctx.srv + d := ctx.proxyCtx if d.Res != nil { - d.Res.Compress = true // some devices require DNS message compression + return resultDone // response is already set - nothing to do } + if d.Addr != nil && s.conf.GetUpstreamsByClient != nil { + clientIP := ipFromAddr(d.Addr) + upstreams := s.conf.GetUpstreamsByClient(clientIP) + if len(upstreams) > 0 { + log.Debug("Using custom upstreams for %s", clientIP) + d.Upstreams = upstreams + } + } + + // request was not filtered so let it be processed further + err := s.dnsProxy.Resolve(d) + if err != nil { + ctx.err = err + return resultError + } + + ctx.responseFromUpstream = true + return resultDone +} + +// Apply filtering logic after we have received response from upstream servers +func processFilteringAfterResponse(ctx *dnsContext) int { + s := ctx.srv + d := ctx.proxyCtx + res := ctx.result + var err error + + if !ctx.responseFromUpstream { + return resultDone // don't process response if it's not from upstream servers + } + + if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { + d.Req.Question[0] = ctx.origQuestion + d.Res.Question[0] = ctx.origQuestion + + if len(d.Res.Answer) != 0 { + answer := []dns.RR{} + answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName)) + answer = append(answer, d.Res.Answer...) // host -> IP + d.Res.Answer = answer + } + + } else if res.Reason != dnsfilter.NotFilteredWhiteList && ctx.protectionEnabled { + origResp2 := d.Res + ctx.result, err = s.filterDNSResponse(ctx) + if err != nil { + ctx.err = err + return resultError + } + if ctx.result != nil { + ctx.origResp = origResp2 // matched by response + } else { + ctx.result = &dnsfilter.Result{} + } + } + + return resultDone +} + +// Write Stats data and logs +func processQueryLogsAndStats(ctx *dnsContext) int { + elapsed := time.Since(ctx.startTime) + s := ctx.srv + d := ctx.proxyCtx + shouldLog := true msg := d.Req @@ -529,7 +577,6 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { shouldLog = false } - elapsed := time.Since(start) s.RLock() // Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use. // This can happen after proxy server has been stopped, but its workers haven't yet exited. @@ -537,8 +584,8 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { p := querylog.AddParams{ Question: msg, Answer: d.Res, - OrigAnswer: origResp, - Result: res, + OrigAnswer: ctx.origResp, + Result: ctx.result, Elapsed: elapsed, ClientIP: getIP(d.Addr), } @@ -548,9 +595,41 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { s.queryLog.Add(p) } - s.updateStats(d, elapsed, *res) + s.updateStats(d, elapsed, *ctx.result) s.RUnlock() + return resultDone +} + +// handleDNSRequest filters the incoming DNS requests and writes them to the query log +// nolint (gocyclo) +func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { + ctx := &dnsContext{srv: s, proxyCtx: d} + ctx.result = &dnsfilter.Result{} + ctx.startTime = time.Now() + + type modProcessFunc func(ctx *dnsContext) int + mods := []modProcessFunc{ + processInitial, + processFilteringBeforeRequest, + processUpstream, + processFilteringAfterResponse, + } + for _, process := range mods { + r := process(ctx) + switch r { + case resultFinish: + return nil + case resultError: + return ctx.err + } + } + + if d.Res != nil { + d.Res.Compress = true // some devices require DNS message compression + } + + _ = processQueryLogsAndStats(ctx) return nil } @@ -619,10 +698,11 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt } // filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered -func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) { +func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { + d := ctx.proxyCtx req := d.Req host := strings.TrimSuffix(req.Question[0].Name, ".") - res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts) + res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, ctx.setts) if err != nil { // Return immediately if there's an error return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) @@ -653,6 +733,11 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestF } d.Res = resp + + } else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { + ctx.origQuestion = d.Req.Question[0] + // resolve canonical name, not the original host name + d.Req.Question[0].Name = dns.Fqdn(res.CanonName) } return &res, err @@ -660,7 +745,8 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestF // If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address. // If this is a match, we set a new response in d.Res and return. -func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) { +func (s *Server) filterDNSResponse(ctx *dnsContext) (*dnsfilter.Result, error) { + d := ctx.proxyCtx for _, a := range d.Res.Answer { host := "" @@ -688,7 +774,7 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.Request s.RUnlock() continue } - res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts) + res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, ctx.setts) s.RUnlock() if err != nil {