diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 3e92e8de..1580e416 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -2,7 +2,6 @@ package dnsforward import ( "crypto/tls" - "encoding/binary" "fmt" "path" "strings" @@ -172,19 +171,3 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string return clientID, nil } - -// processClientID puts the clientID into the DNS context, if there is one. -func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) { - pctx := dctx.proxyCtx - - var key [8]byte - binary.BigEndian.PutUint64(key[:], pctx.RequestID) - clientIDData := s.clientIDCache.Get(key[:]) - if clientIDData == nil { - return resultCodeSuccess - } - - dctx.clientID = string(clientIDData) - - return resultCodeSuccess -} diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index ec27dfd0..d498fcf8 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -1,6 +1,7 @@ package dnsforward import ( + "encoding/binary" "net" "strings" "time" @@ -86,7 +87,6 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { s.processInternalHosts, s.processRestrictLocal, s.processInternalIPAddrs, - s.processClientID, s.processFilteringBeforeRequest, s.processLocalPTR, s.processUpstream, @@ -131,7 +131,10 @@ func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } -// Perform initial checks; process WHOIS & rDNS +// processInitial terminates the following processing for some requests if +// needed and enriches the ctx with some client-specific information. +// +// TODO(e.burkov): Decompose into less general processors. func (s *Server) processInitial(ctx *dnsContext) (rc resultCode) { d := ctx.proxyCtx if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA { @@ -151,6 +154,13 @@ func (s *Server) processInitial(ctx *dnsContext) (rc resultCode) { return resultCodeFinish } + // Get the client's ID if any. It should be performed before getting + // client-specific filtering settings. + var key [8]byte + binary.BigEndian.PutUint64(key[:], d.RequestID) + ctx.clientID = string(s.clientIDCache.Get(key[:])) + + // Get the client-specific filtering settings. ctx.protectionEnabled = s.conf.ProtectionEnabled ctx.setts = s.getClientRequestFilteringSettings(ctx)