Merge: * dns: refactor
Squashed commit of the following:
commit e9469266cafa3df537b5a4d5e28ca51db8289a34
Merge: 17cf6d60 e7e946fa
Author: Simon Zolin <s.zolin@adguard.com>
Date: Tue Jan 21 13:04:30 2020 +0300
Merge remote-tracking branch 'origin/master' into refactor
commit 17cf6d60d11602df3837316119ba8828f41a95df
Author: Simon Zolin <s.zolin@adguard.com>
Date: Mon Jan 20 15:25:43 2020 +0300
minor
commit 7b79462ebbeb743a10417bd28ceb70262ff9fa5c
Author: Simon Zolin <s.zolin@adguard.com>
Date: Fri Jan 17 17:50:09 2020 +0300
minor
commit d8b175c7eda36005c0277e7876f0f0a55a661b05
Author: Simon Zolin <s.zolin@adguard.com>
Date: Fri Jan 17 15:30:37 2020 +0300
minor
commit 93370aa32aa560d42fc67c95fd13f027ddc01b94
Author: Simon Zolin <s.zolin@adguard.com>
Date: Fri Jan 17 14:28:14 2020 +0300
* dns: refactor
. introduce a local context object
. move filtering, upstream logic, stats, querylog code to separate functions
This commit is contained in:
parent
e7e946faa6
commit
3f7e2f7241
|
@ -425,14 +425,33 @@ func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
// To transfer information between modules
|
||||||
// nolint (gocyclo)
|
type dnsContext struct {
|
||||||
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
srv *Server
|
||||||
start := time.Now()
|
proxyCtx *proxy.DNSContext
|
||||||
|
setts *dnsfilter.RequestFilteringSettings // filtering settings for this client
|
||||||
|
startTime time.Time
|
||||||
|
result *dnsfilter.Result
|
||||||
|
origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering
|
||||||
|
origQuestion dns.Question // question received from client. Set when Rewrites are used.
|
||||||
|
err error // error returned from the module
|
||||||
|
protectionEnabled bool // filtering is enabled, dnsfilter object is ready
|
||||||
|
responseFromUpstream bool // response is received from upstream servers
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
resultDone = iota // module has completed its job, continue
|
||||||
|
resultFinish // module has completed its job, exit normally
|
||||||
|
resultError // an error occurred, exit with an error
|
||||||
|
)
|
||||||
|
|
||||||
|
// Perform initial checks; process WHOIS & rDNS
|
||||||
|
func processInitial(ctx *dnsContext) int {
|
||||||
|
s := ctx.srv
|
||||||
|
d := ctx.proxyCtx
|
||||||
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
|
if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA {
|
||||||
_ = proxy.CheckDisabledAAAARequest(d, true)
|
_ = proxy.CheckDisabledAAAARequest(d, true)
|
||||||
return nil
|
return resultFinish
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.conf.OnDNSRequest != nil {
|
if s.conf.OnDNSRequest != nil {
|
||||||
|
@ -443,10 +462,17 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) &&
|
if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) &&
|
||||||
d.Req.Question[0].Name == "use-application-dns.net." {
|
d.Req.Question[0].Name == "use-application-dns.net." {
|
||||||
d.Res = s.genNXDomain(d.Req)
|
d.Res = s.genNXDomain(d.Req)
|
||||||
return nil
|
return resultFinish
|
||||||
}
|
}
|
||||||
|
|
||||||
// use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise
|
return resultDone
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply filtering logic
|
||||||
|
func processFilteringBeforeRequest(ctx *dnsContext) int {
|
||||||
|
s := ctx.srv
|
||||||
|
d := ctx.proxyCtx
|
||||||
|
|
||||||
s.RLock()
|
s.RLock()
|
||||||
// Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use.
|
// Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use.
|
||||||
// This could happen after proxy server has been stopped, but its workers are not yet exited.
|
// This could happen after proxy server has been stopped, but its workers are not yet exited.
|
||||||
|
@ -455,28 +481,27 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
// but this would require the Upstream interface to have Close() function
|
// but this would require the Upstream interface to have Close() function
|
||||||
// (to prevent from hanging while waiting for unresponsive DNS server to respond).
|
// (to prevent from hanging while waiting for unresponsive DNS server to respond).
|
||||||
|
|
||||||
var setts *dnsfilter.RequestFilteringSettings
|
|
||||||
var err error
|
var err error
|
||||||
res := &dnsfilter.Result{}
|
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil
|
||||||
protectionEnabled := s.conf.ProtectionEnabled && s.dnsFilter != nil
|
if ctx.protectionEnabled {
|
||||||
if protectionEnabled {
|
ctx.setts = s.getClientRequestFilteringSettings(d)
|
||||||
setts = s.getClientRequestFilteringSettings(d)
|
ctx.result, err = s.filterDNSRequest(ctx)
|
||||||
res, err = s.filterDNSRequest(d, setts)
|
|
||||||
}
|
}
|
||||||
s.RUnlock()
|
s.RUnlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
ctx.err = err
|
||||||
|
return resultError
|
||||||
|
}
|
||||||
|
return resultDone
|
||||||
}
|
}
|
||||||
|
|
||||||
var origResp *dns.Msg
|
// Pass request to upstream servers; process the response
|
||||||
if d.Res == nil {
|
func processUpstream(ctx *dnsContext) int {
|
||||||
answer := []dns.RR{}
|
s := ctx.srv
|
||||||
originalQuestion := d.Req.Question[0]
|
d := ctx.proxyCtx
|
||||||
|
if d.Res != nil {
|
||||||
if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
|
return resultDone // response is already set - nothing to do
|
||||||
answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName))
|
|
||||||
// resolve canonical name, not the original host name
|
|
||||||
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
|
if d.Addr != nil && s.conf.GetUpstreamsByClient != nil {
|
||||||
|
@ -489,38 +514,61 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// request was not filtered so let it be processed further
|
// request was not filtered so let it be processed further
|
||||||
err = p.Resolve(d)
|
err := s.dnsProxy.Resolve(d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
ctx.err = err
|
||||||
|
return resultError
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.responseFromUpstream = true
|
||||||
|
return resultDone
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply filtering logic after we have received response from upstream servers
|
||||||
|
func processFilteringAfterResponse(ctx *dnsContext) int {
|
||||||
|
s := ctx.srv
|
||||||
|
d := ctx.proxyCtx
|
||||||
|
res := ctx.result
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if !ctx.responseFromUpstream {
|
||||||
|
return resultDone // don't process response if it's not from upstream servers
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
|
if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
|
||||||
d.Req.Question[0] = originalQuestion
|
d.Req.Question[0] = ctx.origQuestion
|
||||||
d.Res.Question[0] = originalQuestion
|
d.Res.Question[0] = ctx.origQuestion
|
||||||
|
|
||||||
if len(d.Res.Answer) != 0 {
|
if len(d.Res.Answer) != 0 {
|
||||||
|
answer := []dns.RR{}
|
||||||
|
answer = append(answer, s.genCNAMEAnswer(d.Req, res.CanonName))
|
||||||
answer = append(answer, d.Res.Answer...) // host -> IP
|
answer = append(answer, d.Res.Answer...) // host -> IP
|
||||||
d.Res.Answer = answer
|
d.Res.Answer = answer
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if res.Reason != dnsfilter.NotFilteredWhiteList && protectionEnabled {
|
} else if res.Reason != dnsfilter.NotFilteredWhiteList && ctx.protectionEnabled {
|
||||||
origResp2 := d.Res
|
origResp2 := d.Res
|
||||||
res, err = s.filterDNSResponse(d, setts)
|
ctx.result, err = s.filterDNSResponse(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
ctx.err = err
|
||||||
|
return resultError
|
||||||
}
|
}
|
||||||
if res != nil {
|
if ctx.result != nil {
|
||||||
origResp = origResp2 // matched by response
|
ctx.origResp = origResp2 // matched by response
|
||||||
} else {
|
} else {
|
||||||
res = &dnsfilter.Result{}
|
ctx.result = &dnsfilter.Result{}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.Res != nil {
|
return resultDone
|
||||||
d.Res.Compress = true // some devices require DNS message compression
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write Stats data and logs
|
||||||
|
func processQueryLogsAndStats(ctx *dnsContext) int {
|
||||||
|
elapsed := time.Since(ctx.startTime)
|
||||||
|
s := ctx.srv
|
||||||
|
d := ctx.proxyCtx
|
||||||
|
|
||||||
shouldLog := true
|
shouldLog := true
|
||||||
msg := d.Req
|
msg := d.Req
|
||||||
|
|
||||||
|
@ -529,7 +577,6 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
shouldLog = false
|
shouldLog = false
|
||||||
}
|
}
|
||||||
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
s.RLock()
|
s.RLock()
|
||||||
// Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use.
|
// Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use.
|
||||||
// This can happen after proxy server has been stopped, but its workers haven't yet exited.
|
// This can happen after proxy server has been stopped, but its workers haven't yet exited.
|
||||||
|
@ -537,8 +584,8 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
p := querylog.AddParams{
|
p := querylog.AddParams{
|
||||||
Question: msg,
|
Question: msg,
|
||||||
Answer: d.Res,
|
Answer: d.Res,
|
||||||
OrigAnswer: origResp,
|
OrigAnswer: ctx.origResp,
|
||||||
Result: res,
|
Result: ctx.result,
|
||||||
Elapsed: elapsed,
|
Elapsed: elapsed,
|
||||||
ClientIP: getIP(d.Addr),
|
ClientIP: getIP(d.Addr),
|
||||||
}
|
}
|
||||||
|
@ -548,9 +595,41 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
s.queryLog.Add(p)
|
s.queryLog.Add(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.updateStats(d, elapsed, *res)
|
s.updateStats(d, elapsed, *ctx.result)
|
||||||
s.RUnlock()
|
s.RUnlock()
|
||||||
|
|
||||||
|
return resultDone
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||||
|
// nolint (gocyclo)
|
||||||
|
func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
|
ctx := &dnsContext{srv: s, proxyCtx: d}
|
||||||
|
ctx.result = &dnsfilter.Result{}
|
||||||
|
ctx.startTime = time.Now()
|
||||||
|
|
||||||
|
type modProcessFunc func(ctx *dnsContext) int
|
||||||
|
mods := []modProcessFunc{
|
||||||
|
processInitial,
|
||||||
|
processFilteringBeforeRequest,
|
||||||
|
processUpstream,
|
||||||
|
processFilteringAfterResponse,
|
||||||
|
}
|
||||||
|
for _, process := range mods {
|
||||||
|
r := process(ctx)
|
||||||
|
switch r {
|
||||||
|
case resultFinish:
|
||||||
|
return nil
|
||||||
|
case resultError:
|
||||||
|
return ctx.err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if d.Res != nil {
|
||||||
|
d.Res.Compress = true // some devices require DNS message compression
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = processQueryLogsAndStats(ctx)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -619,10 +698,11 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
|
// filterDNSRequest applies the dnsFilter and sets d.Res if the request was filtered
|
||||||
func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
|
func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
|
||||||
|
d := ctx.proxyCtx
|
||||||
req := d.Req
|
req := d.Req
|
||||||
host := strings.TrimSuffix(req.Question[0].Name, ".")
|
host := strings.TrimSuffix(req.Question[0].Name, ".")
|
||||||
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, setts)
|
res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, ctx.setts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Return immediately if there's an error
|
// Return immediately if there's an error
|
||||||
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
|
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
|
||||||
|
@ -653,6 +733,11 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestF
|
||||||
}
|
}
|
||||||
|
|
||||||
d.Res = resp
|
d.Res = resp
|
||||||
|
|
||||||
|
} else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 {
|
||||||
|
ctx.origQuestion = d.Req.Question[0]
|
||||||
|
// resolve canonical name, not the original host name
|
||||||
|
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &res, err
|
return &res, err
|
||||||
|
@ -660,7 +745,8 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext, setts *dnsfilter.RequestF
|
||||||
|
|
||||||
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
|
// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address.
|
||||||
// If this is a match, we set a new response in d.Res and return.
|
// If this is a match, we set a new response in d.Res and return.
|
||||||
func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.RequestFilteringSettings) (*dnsfilter.Result, error) {
|
func (s *Server) filterDNSResponse(ctx *dnsContext) (*dnsfilter.Result, error) {
|
||||||
|
d := ctx.proxyCtx
|
||||||
for _, a := range d.Res.Answer {
|
for _, a := range d.Res.Answer {
|
||||||
host := ""
|
host := ""
|
||||||
|
|
||||||
|
@ -688,7 +774,7 @@ func (s *Server) filterDNSResponse(d *proxy.DNSContext, setts *dnsfilter.Request
|
||||||
s.RUnlock()
|
s.RUnlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, setts)
|
res, err := s.dnsFilter.CheckHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
|
||||||
s.RUnlock()
|
s.RUnlock()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue