Merge: * dns: refactor

Squashed commit of the following:

commit e9469266cafa3df537b5a4d5e28ca51db8289a34
Merge: 17cf6d60 e7e946fa
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Tue Jan 21 13:04:30 2020 +0300

    Merge remote-tracking branch 'origin/master' into refactor

commit 17cf6d60d11602df3837316119ba8828f41a95df
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Mon Jan 20 15:25:43 2020 +0300

    minor

commit 7b79462ebbeb743a10417bd28ceb70262ff9fa5c
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jan 17 17:50:09 2020 +0300

    minor

commit d8b175c7eda36005c0277e7876f0f0a55a661b05
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jan 17 15:30:37 2020 +0300

    minor

commit 93370aa32aa560d42fc67c95fd13f027ddc01b94
Author: Simon Zolin <s.zolin@adguard.com>
Date:   Fri Jan 17 14:28:14 2020 +0300

    * dns: refactor

    . introduce a local context object
    . move filtering, upstream logic, stats, querylog code to separate functions
This commit is contained in:
Simon Zolin 2020-01-21 13:49:34 +03:00
parent e7e946faa6
commit 3f7e2f7241
1 changed files with 158 additions and 72 deletions

View File

@ -425,14 +425,33 @@ func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool
return true, nil
}
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
// nolint (gocyclo)
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
start := time.Now()
// To transfer information between modules
type dnsContext struct {
srv *Server
proxyCtx *proxy.DNSContext
setts *dnsfilter.RequestFilteringSettings // filtering settings for this client
startTime time.Time
result *dnsfilter.Result
origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering
origQuestion dns.Question // question received from client. Set when Rewrites are used.
err error // error returned from the module
protectionEnabled bool // filtering is enabled, dnsfilter object is ready
responseFromUpstream bool // response is received from upstream servers
}
const (
resultDone = iota // module has completed its job, continue
resultFinish // module has completed its job, exit normally
resultError // an error occurred, exit with an error
)
// Perform initial checks; process WHOIS & rDNS
func processInitial(ctx *dnsContext) int {
s := ctx.srv
d := ctx.proxyCtx
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
_ = proxy.CheckDisabledAAAARequest(d, true)
return nil
return resultFinish
}
if s.conf.OnDNSRequest != nil {
@ -443,10 +462,17 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) &&
d.Req.Question[0].Name == "use-application-dns.net." {
d.Res = s.genNXDomain(d.Req)
return nil
return resultFinish
}
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
return resultDone
}
// Apply filtering logic
func processFilteringBeforeRequest(ctx *dnsContext) int {
s := ctx.srv
d := ctx.proxyCtx
s.RLock()
// Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use.
// This could happen after proxy server has been stopped, but its workers are not yet exited.
@ -455,28 +481,27 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
// but this would require the Upstream interface to have Close() function
// (to prevent from hanging while waiting for unresponsive DNS server to respond).
var setts *dnsfilter.RequestFilteringSettings
var err error
res := &dnsfilter.Result{}
protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil
if protectionEnabled {
setts = s.getClientRequestFilteringSettings(d)
res, err = s.filterDNSRequest(d, setts)
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
if ctx.protectionEnabled {
ctx.setts = s.getClientRequestFilteringSettings(d)
ctx.result, err = s.filterDNSRequest(ctx)
}
s.RUnlock()
if err != nil {
return err
ctx.err = err
return resultError
}
return resultDone
}
var origResp *dns.Msg
if d.Res == nil {
answer := []dns.RR{}
originalQuestion := d.Req.Question[0]
if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName))
// resolve canonical name, not the original host name
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
// Pass request to upstream servers; process the response
func processUpstream(ctx *dnsContext) int {
s := ctx.srv
d := ctx.proxyCtx
if d.Res != nil {
return resultDone // response is already set - nothing to do
}
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
@ -489,38 +514,61 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
}
// request was not filtered so let it be processed further
err = p.Resolve(d)
err := s.dnsProxy.Resolve(d)
if err != nil {
return err
ctx.err = err
return resultError
}
ctx.responseFromUpstream = true
return resultDone
}
// Apply filtering logic after we have received response from upstream servers
func processFilteringAfterResponse(ctx *dnsContext) int {
s := ctx.srv
d := ctx.proxyCtx
res := ctx.result
var err error
if !ctx.responseFromUpstream {
return resultDone // don't process response if it's not from upstream servers
}
if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
d.Req.Question[0] = originalQuestion
d.Res.Question[0] = originalQuestion
d.Req.Question[0] = ctx.origQuestion
d.Res.Question[0] = ctx.origQuestion
if len(d.Res.Answer) != 0 {
answer := []dns.RR{}
answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName))
answer = append(answer, d.Res.Answer...) // host -> IP
d.Res.Answer = answer
}
} else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled {
} else if res.Reason != dnsfilter.NotFilteredWhiteList && ctx.protectionEnabled {
origResp2 := d.Res
res, err = s.filterDNSResponse(d, setts)
ctx.result, err = s.filterDNSResponse(ctx)
if err != nil {
return err
ctx.err = err
return resultError
}
if res != nil {
origResp = origResp2 // matched by response
if ctx.result != nil {
ctx.origResp = origResp2 // matched by response
} else {
res = &dnsfilter.Result{}
}
ctx.result = &dnsfilter.Result{}
}
}
if d.Res != nil {
d.Res.Compress = true // some devices require DNS message compression
return resultDone
}
// Write Stats data and logs
func processQueryLogsAndStats(ctx *dnsContext) int {
elapsed := time.Since(ctx.startTime)
s := ctx.srv
d := ctx.proxyCtx
shouldLog := true
msg := d.Req
@ -529,7 +577,6 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
shouldLog = false
}
elapsed := time.Since(start)
s.RLock()
// Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use.
// This can happen after proxy server has been stopped, but its workers haven't yet exited.
@ -537,8 +584,8 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
p := querylog.AddParams{
Question: msg,
Answer: d.Res,
OrigAnswer: origResp,
Result: res,
OrigAnswer: ctx.origResp,
Result: ctx.result,
Elapsed: elapsed,
ClientIP: getIP(d.Addr),
}
@ -548,9 +595,41 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
s.queryLog.Add(p)
}
s.updateStats(d, elapsed, *res)
s.updateStats(d, elapsed, *ctx.result)
s.RUnlock()
return resultDone
}
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
// nolint (gocyclo)
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{srv: s, proxyCtx: d}
ctx.result = &dnsfilter.Result{}
ctx.startTime = time.Now()
type modProcessFunc func(ctx *dnsContext) int
mods := []modProcessFunc{
processInitial,
processFilteringBeforeRequest,
processUpstream,
processFilteringAfterResponse,
}
for _, process := range mods {
r := process(ctx)
switch r {
case resultFinish:
return nil
case resultError:
return ctx.err
}
}
if d.Res != nil {
d.Res.Compress = true // some devices require DNS message compression
}
_ = processQueryLogsAndStats(ctx)
return nil
}
@ -619,10 +698,11 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
}
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
d := ctx.proxyCtx
req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".")
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, ctx.setts)
if err != nil {
// Return immediately if there's an error
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
@ -653,6 +733,11 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestF
}
d.Res = resp
} else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
ctx.origQuestion = d.Req.Question[0]
// resolve canonical name, not the original host name
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
}
return &res, err
@ -660,7 +745,8 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestF
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
// If this is a match, we set a new response in d.Res and return.
func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
func (s *Server) filterDNSResponse(ctx *dnsContext) (*dnsfilter.Result, error) {
d := ctx.proxyCtx
for _, a := range d.Res.Answer {
host := ""
@ -688,7 +774,7 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.Request
s.RUnlock()
continue
}
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts)
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
s.RUnlock()
if err != nil {