Merge: * dns: refactor

Squashed commit of the following:

commit e9469266cafa3df537b5a4d5e28ca51db8289a34
Merge: 17cf6d60 e7e946fa
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Jan 21 13:04:30 2020 +0300

    Merge remote-tracking branch 'origin/master' into refactor

commit 17cf6d60d11602df3837316119ba8828f41a95df
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Jan 20 15:25:43 2020 +0300

    minor

commit 7b79462ebbeb743a10417bd28ceb70262ff9fa5c
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jan 17 17:50:09 2020 +0300

    minor

commit d8b175c7eda36005c0277e7876f0f0a55a661b05
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jan 17 15:30:37 2020 +0300

    minor

commit 93370aa32aa560d42fc67c95fd13f027ddc01b94
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jan 17 14:28:14 2020 +0300

    * dns: refactor

    . introduce a local context object
    . move filtering, upstream logic, stats, querylog code to separate functions
This commit is contained in:
Simon Zolin 2020-01-21 13:49:34 +03:00
parent e7e946faa6
commit 3f7e2f7241
1 changed files with 158 additions and 72 deletions

View File

@ -425,14 +425,33 @@ func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool
return true, nil return true, nil
} }
// handleDNSRequest filters the incoming DNS requests and writes them to the query log // To transfer information between modules
// nolint (gocyclo) type dnsContext struct {
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { srv *Server
start := time.Now() proxyCtx *proxy.DNSContext
setts *dnsfilter.RequestFilteringSettings // filtering settings for this client
startTime time.Time
result *dnsfilter.Result
origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering
origQuestion dns.Question // question received from client. Set when Rewrites are used.
err error // error returned from the module
protectionEnabled bool // filtering is enabled, dnsfilter object is ready
responseFromUpstream bool // response is received from upstream servers
}
const (
resultDone = iota // module has completed its job, continue
resultFinish // module has completed its job, exit normally
resultError // an error occurred, exit with an error
)
// Perform initial checks; process WHOIS & rDNS
func processInitial(ctx *dnsContext) int {
s := ctx.srv
d := ctx.proxyCtx
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA { if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
_ = proxy.CheckDisabledAAAARequest(d, true) _ = proxy.CheckDisabledAAAARequest(d, true)
return nil return resultFinish
} }
if s.conf.OnDNSRequest != nil { if s.conf.OnDNSRequest != nil {
@ -443,10 +462,17 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) && if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) &&
d.Req.Question[0].Name == "use-application-dns.net." { d.Req.Question[0].Name == "use-application-dns.net." {
d.Res = s.genNXDomain(d.Req) d.Res = s.genNXDomain(d.Req)
return nil return resultFinish
} }
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise return resultDone
}
// Apply filtering logic
func processFilteringBeforeRequest(ctx *dnsContext) int {
s := ctx.srv
d := ctx.proxyCtx
s.RLock() s.RLock()
// Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use. // Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use.
// This could happen after proxy server has been stopped, but its workers are not yet exited. // This could happen after proxy server has been stopped, but its workers are not yet exited.
@ -455,28 +481,27 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
// but this would require the Upstream interface to have Close() function // but this would require the Upstream interface to have Close() function
// (to prevent from hanging while waiting for unresponsive DNS server to respond). // (to prevent from hanging while waiting for unresponsive DNS server to respond).
var setts *dnsfilter.RequestFilteringSettings
var err error var err error
res := &dnsfilter.Result{} ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil if ctx.protectionEnabled {
if protectionEnabled { ctx.setts = s.getClientRequestFilteringSettings(d)
setts = s.getClientRequestFilteringSettings(d) ctx.result, err = s.filterDNSRequest(ctx)
res, err = s.filterDNSRequest(d, setts)
} }
s.RUnlock() s.RUnlock()
if err != nil { if err != nil {
return err ctx.err = err
return resultError
} }
return resultDone
}
var origResp *dns.Msg // Pass request to upstream servers; process the response
if d.Res == nil { func processUpstream(ctx *dnsContext) int {
answer := []dns.RR{} s := ctx.srv
originalQuestion := d.Req.Question[0] d := ctx.proxyCtx
if d.Res != nil {
if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { return resultDone // response is already set - nothing to do
answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName))
// resolve canonical name, not the original host name
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
} }
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil { if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
@ -489,37 +514,60 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
} }
// request was not filtered so let it be processed further // request was not filtered so let it be processed further
err = p.Resolve(d) err := s.dnsProxy.Resolve(d)
if err != nil { if err != nil {
return err ctx.err = err
return resultError
}
ctx.responseFromUpstream = true
return resultDone
}
// Apply filtering logic after we have received response from upstream servers
func processFilteringAfterResponse(ctx *dnsContext) int {
s := ctx.srv
d := ctx.proxyCtx
res := ctx.result
var err error
if !ctx.responseFromUpstream {
return resultDone // don't process response if it's not from upstream servers
} }
if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
d.Req.Question[0] = originalQuestion d.Req.Question[0] = ctx.origQuestion
d.Res.Question[0] = originalQuestion d.Res.Question[0] = ctx.origQuestion
if len(d.Res.Answer) != 0 { if len(d.Res.Answer) != 0 {
answer := []dns.RR{}
answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName))
answer = append(answer, d.Res.Answer...) // host -> IP answer = append(answer, d.Res.Answer...) // host -> IP
d.Res.Answer = answer d.Res.Answer = answer
} }
} else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled { } else if res.Reason != dnsfilter.NotFilteredWhiteList && ctx.protectionEnabled {
origResp2 := d.Res origResp2 := d.Res
res, err = s.filterDNSResponse(d, setts) ctx.result, err = s.filterDNSResponse(ctx)
if err != nil { if err != nil {
return err ctx.err = err
return resultError
} }
if res != nil { if ctx.result != nil {
origResp = origResp2 // matched by response ctx.origResp = origResp2 // matched by response
} else { } else {
res = &dnsfilter.Result{} ctx.result = &dnsfilter.Result{}
}
} }
} }
if d.Res != nil { return resultDone
d.Res.Compress = true // some devices require DNS message compression }
}
// Write Stats data and logs
func processQueryLogsAndStats(ctx *dnsContext) int {
elapsed := time.Since(ctx.startTime)
s := ctx.srv
d := ctx.proxyCtx
shouldLog := true shouldLog := true
msg := d.Req msg := d.Req
@ -529,7 +577,6 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
shouldLog = false shouldLog = false
} }
elapsed := time.Since(start)
s.RLock() s.RLock()
// Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use. // Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use.
// This can happen after proxy server has been stopped, but its workers haven't yet exited. // This can happen after proxy server has been stopped, but its workers haven't yet exited.
@ -537,8 +584,8 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
p := querylog.AddParams{ p := querylog.AddParams{
Question: msg, Question: msg,
Answer: d.Res, Answer: d.Res,
OrigAnswer: origResp, OrigAnswer: ctx.origResp,
Result: res, Result: ctx.result,
Elapsed: elapsed, Elapsed: elapsed,
ClientIP: getIP(d.Addr), ClientIP: getIP(d.Addr),
} }
@ -548,9 +595,41 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
s.queryLog.Add(p) s.queryLog.Add(p)
} }
s.updateStats(d, elapsed, *res) s.updateStats(d, elapsed, *ctx.result)
s.RUnlock() s.RUnlock()
return resultDone
}
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
// nolint (gocyclo)
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{srv: s, proxyCtx: d}
ctx.result = &dnsfilter.Result{}
ctx.startTime = time.Now()
type modProcessFunc func(ctx *dnsContext) int
mods := []modProcessFunc{
processInitial,
processFilteringBeforeRequest,
processUpstream,
processFilteringAfterResponse,
}
for _, process := range mods {
r := process(ctx)
switch r {
case resultFinish:
return nil
case resultError:
return ctx.err
}
}
if d.Res != nil {
d.Res.Compress = true // some devices require DNS message compression
}
_ = processQueryLogsAndStats(ctx)
return nil return nil
} }
@ -619,10 +698,11 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
} }
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered // filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) { func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
d := ctx.proxyCtx
req := d.Req req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".") host := strings.TrimSuffix(req.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts) res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, ctx.setts)
if err != nil { if err != nil {
// Return immediately if there's an error // Return immediately if there's an error
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host) return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
@ -653,6 +733,11 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestF
} }
d.Res = resp d.Res = resp
} else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
ctx.origQuestion = d.Req.Question[0]
// resolve canonical name, not the original host name
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
} }
return &res, err return &res, err
@ -660,7 +745,8 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestF
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address. // If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
// If this is a match, we set a new response in d.Res and return. // If this is a match, we set a new response in d.Res and return.
func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) { func (s *Server) filterDNSResponse(ctx *dnsContext) (*dnsfilter.Result, error) {
d := ctx.proxyCtx
for _, a := range d.Res.Answer { for _, a := range d.Res.Answer {
host := "" host := ""
@ -688,7 +774,7 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.Request
s.RUnlock() s.RUnlock()
continue continue
} }
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts) res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
s.RUnlock() s.RUnlock()
if err != nil { if err != nil {