diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 204995fb..ec27dfd0 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -16,9 +16,6 @@ import ( // To transfer information between modules type dnsContext struct { - // TODO(a.garipov): Remove this and rewrite processors to be methods of - // *Server instead. - srv *Server proxyCtx *proxy.DNSContext // setts are the filtering settings for the client. setts *filtering.Settings @@ -28,7 +25,8 @@ type dnsContext struct { // response is modified by filters. origResp *dns.Msg // unreversedReqIP stores an IP address obtained from PTR request if it - // was successfully parsed. + // parsed successfully and belongs to one of locally-served IP ranges as per + // RFC 6303. unreversedReqIP net.IP // err is the error returned from a processing function. err error @@ -69,7 +67,6 @@ const ( // handleDNSRequest filters the incoming DNS requests and writes them to the query log func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { ctx := &dnsContext{ - srv: s, proxyCtx: d, result: &filtering.Result{}, startTime: time.Now(), @@ -84,7 +81,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { // appropriate handler. mods := []modProcessFunc{ s.processRecursion, - processInitial, + s.processInitial, s.processDetermineLocal, s.processInternalHosts, s.processRestrictLocal, @@ -93,10 +90,10 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { s.processFilteringBeforeRequest, s.processLocalPTR, s.processUpstream, - processDNSSECAfterResponse, - processFilteringAfterResponse, + s.processDNSSECAfterResponse, + s.processFilteringAfterResponse, s.ipset.process, - processQueryLogsAndStats, + s.processQueryLogsAndStats, } for _, process := range mods { r := process(ctx) @@ -135,8 +132,7 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) { } // Perform initial checks; process WHOIS & rDNS -func processInitial(ctx *dnsContext) (rc resultCode) { - s := ctx.srv +func (s *Server) processInitial(ctx *dnsContext) (rc resultCode) { d := ctx.proxyCtx if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA { _ = proxy.CheckDisabledAAAARequest(d, true) @@ -155,6 +151,9 @@ func processInitial(ctx *dnsContext) (rc resultCode) { return resultCodeFinish } + ctx.protectionEnabled = s.conf.ProtectionEnabled + ctx.setts = s.getClientRequestFilteringSettings(ctx) + return resultCodeSuccess } @@ -339,10 +338,16 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) { } // Restrict an access to local addresses for external clients. We also - // assume that all the DHCP leases we give are locally-served or at - // least don't need to be inaccessible externally. - if s.subnetDetector.IsLocallyServedNetwork(ip) && !ctx.isLocalClient { - log.Debug("dns: %q requests for internal ip", d.Addr) + // assume that all the DHCP leases we give are locally-served or at least + // don't need to be inaccessible externally. + if !s.subnetDetector.IsLocallyServedNetwork(ip) { + log.Debug("dns: addr %s is not from locally-served network", ip) + + return resultCodeSuccess + } + + if !ctx.isLocalClient { + log.Debug("dns: %q requests an internal ip", d.Addr) d.Res = s.genNXDomain(req) // Do not even put into query log. @@ -352,13 +357,13 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) { // Do not perform unreversing ever again. ctx.unreversedReqIP = ip - // Disable redundant filtering. - filterSetts := s.getClientRequestFilteringSettings(ctx) - filterSetts.ParentalEnabled = false - filterSetts.SafeBrowsingEnabled = false - filterSetts.SafeSearchEnabled = false - filterSetts.ServicesRules = nil - ctx.setts = filterSetts + // There is no need to filter request from external addresses since this + // code is only executed when the request is for locally-served ARPA + // hostname so disable redundant filters. + ctx.setts.ParentalEnabled = false + ctx.setts.SafeBrowsingEnabled = false + ctx.setts.SafeSearchEnabled = false + ctx.setts.ServicesRules = nil // Nothing to restrict. return resultCodeSuccess @@ -475,16 +480,10 @@ func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) s.serverLock.RLock() defer s.serverLock.RUnlock() - ctx.protectionEnabled = s.conf.ProtectionEnabled - if s.dnsFilter == nil { return resultCodeSuccess } - if ctx.setts == nil { - ctx.setts = s.getClientRequestFilteringSettings(ctx) - } - var err error if ctx.result, err = s.filterDNSRequest(ctx); err != nil { ctx.err = err @@ -555,11 +554,11 @@ func (s *Server) processUpstream(ctx *dnsContext) (rc resultCode) { } // Process DNSSEC after response from upstream server -func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) { +func (s *Server) processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) { d := ctx.proxyCtx - if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers - !ctx.srv.conf.EnableDNSSEC { + // Don't process response if it's not from upstream servers. + if !ctx.responseFromUpstream || !s.conf.EnableDNSSEC { return resultCodeSuccess } @@ -601,8 +600,7 @@ func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) { } // Apply filtering logic after we have received response from upstream servers -func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) { - s := ctx.srv +func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) { d := ctx.proxyCtx switch res := ctx.result; res.Reason { diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go index 1714563f..b7760c68 100644 --- a/internal/dnsforward/stats.go +++ b/internal/dnsforward/stats.go @@ -13,9 +13,8 @@ import ( ) // Write Stats data and logs -func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) { +func (s *Server) processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) { elapsed := time.Since(ctx.startTime) - s := ctx.srv pctx := ctx.proxyCtx shouldLog := true diff --git a/internal/dnsforward/stats_test.go b/internal/dnsforward/stats_test.go index 92985fd4..22780ef2 100644 --- a/internal/dnsforward/stats_test.go +++ b/internal/dnsforward/stats_test.go @@ -160,6 +160,12 @@ func TestProcessQueryLogsAndStats(t *testing.T) { require.NoError(t, err) for _, tc := range testCases { + ql := &testQueryLog{} + st := &testStats{} + srv := &Server{ + queryLog: ql, + stats: st, + } t.Run(tc.name, func(t *testing.T) { req := &dns.Msg{ Question: []dns.Question{{ @@ -173,14 +179,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) { Addr: tc.addr, Upstream: ups, } - - ql := &testQueryLog{} - st := &testStats{} dctx := &dnsContext{ - srv: &Server{ - queryLog: ql, - stats: st, - }, proxyCtx: pctx, startTime: time.Now(), result: &filtering.Result{ @@ -189,7 +188,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) { clientID: tc.clientID, } - code := processQueryLogsAndStats(dctx) + code := srv.processQueryLogsAndStats(dctx) assert.Equal(t, tc.wantCode, code) assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto) assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)