diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 47ce1f0e..4acd4ee7 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -510,7 +510,21 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { } return result, nil } - result, err := d.lookupCommon(host, &gctx.stats.Safebrowsing, gctx.safebrowsingCache, true, format, handleBody) + + // check cache + cachedValue, isFound := getCachedResult(gctx.safebrowsingCache, host) + if isFound { + atomic.AddUint64(&gctx.stats.Safebrowsing.CacheHits, 1) + log.Tracef("%s: found in the lookup cache %p", host, gctx.safebrowsingCache) + return cachedValue, nil + } + + result, err := d.lookupCommon(host, &gctx.stats.Safebrowsing, true, format, handleBody) + + if err == nil { + setCacheResult(gctx.safebrowsingCache, host, result) + } + return result, err } @@ -562,7 +576,21 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) { } return result, nil } - result, err := d.lookupCommon(host, &gctx.stats.Parental, gctx.parentalCache, false, format, handleBody) + + // check cache + cachedValue, isFound := getCachedResult(gctx.parentalCache, host) + if isFound { + atomic.AddUint64(&gctx.stats.Parental.CacheHits, 1) + log.Tracef("%s: found in the lookup cache %p", host, gctx.parentalCache) + return cachedValue, nil + } + + result, err := d.lookupCommon(host, &gctx.stats.Parental, false, format, handleBody) + + if err == nil { + setCacheResult(gctx.parentalCache, host, result) + } + return result, err } @@ -570,18 +598,7 @@ type formatHandler func(hashparam string) string type bodyHandler func(body []byte, hashes map[string]bool) (Result, error) // real implementation of lookup/check -func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache *fastcache.Cache, hashparamNeedSlash bool, format formatHandler, handleBody bodyHandler) (Result, error) { - // if host ends with a dot, trim it - host = strings.ToLower(strings.Trim(host, ".")) - - // check cache - cachedValue, isFound := getCachedResult(cache, host) - if isFound { - atomic.AddUint64(&lookupstats.CacheHits, 1) - log.Tracef("%s: found in the lookup cache %p", host, cache) - return cachedValue, nil - } - +func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, hashparamNeedSlash bool, format formatHandler, handleBody bodyHandler) (Result, error) { // convert hostname to hash parameters hashparam, hashes := hostnameToHashParam(host, hashparamNeedSlash) @@ -613,20 +630,16 @@ func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache *f switch { case resp.StatusCode == 204: // empty result, save cache - setCacheResult(cache, host, Result{}) return Result{}, nil case resp.StatusCode != 200: - // error, don't save cache - return Result{}, nil + return Result{}, fmt.Errorf("HTTP status code: %d", resp.StatusCode) } result, err := handleBody(body, hashes) if err != nil { - // error, don't save cache return Result{}, err } - setCacheResult(cache, host, result) return result, nil } diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index d5eabafe..15ecec80 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -144,9 +144,7 @@ func TestSafeBrowsing(t *testing.T) { if gctx.stats.Safebrowsing.Requests != 1 { t.Errorf("Safebrowsing lookup positive cache is not working: %v", gctx.stats.Safebrowsing.Requests) } - d.checkMatch(t, "wmconvirus.narod.ru.") d.checkMatch(t, "test.wmconvirus.narod.ru") - d.checkMatch(t, "test.wmconvirus.narod.ru.") d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "pornhub.com") l := gctx.stats.Safebrowsing.Requests @@ -166,9 +164,7 @@ func TestParallelSB(t *testing.T) { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Parallel() d.checkMatch(t, "wmconvirus.narod.ru") - d.checkMatch(t, "wmconvirus.narod.ru.") d.checkMatch(t, "test.wmconvirus.narod.ru") - d.checkMatch(t, "test.wmconvirus.narod.ru.") d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "pornhub.com") }) @@ -368,8 +364,6 @@ func TestParentalControl(t *testing.T) { t.Errorf("Parental lookup positive cache is not working") } d.checkMatch(t, "www.pornhub.com") - d.checkMatch(t, "pornhub.com.") - d.checkMatch(t, "www.pornhub.com.") d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "yandex.ru") l := gctx.stats.Parental.Requests diff --git a/home/dns.go b/home/dns.go index b474bc5d..ee25ebc1 100644 --- a/home/dns.go +++ b/home/dns.go @@ -4,9 +4,7 @@ import ( "fmt" "net" "os" - "strings" "sync" - "time" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" @@ -17,10 +15,6 @@ import ( "github.com/miekg/dns" ) -const ( - rdnsTimeout = 3 * time.Second // max time to wait for rDNS response -) - type dnsContext struct { rdnsChannel chan string // pass data from DNS request handling thread to rDNS thread // contains IP addresses of clients to be resolved by rDNS @@ -41,115 +35,13 @@ func initDNSServer(baseDir string) { config.dnsServer = dnsforward.NewServer(baseDir) - bindhost := config.DNS.BindHost - if config.DNS.BindHost == "0.0.0.0" { - bindhost = "127.0.0.1" - } - resolverAddress := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) - opts := upstream.Options{ - Timeout: rdnsTimeout, - } - config.dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts) - if err != nil { - log.Error("upstream.AddressToUpstream: %s", err) - return - } - - config.dnsctx.rdnsIP = make(map[string]bool) - config.dnsctx.rdnsChannel = make(chan string, 256) - go asyncRDNSLoop() + initRDNS() } func isRunning() bool { return config.dnsServer != nil && config.dnsServer.IsRunning() } -func beginAsyncRDNS(ip string) { - if config.clients.Exists(ip) { - return - } - - // add IP to rdnsIP, if not exists - config.dnsctx.rdnsLock.Lock() - defer config.dnsctx.rdnsLock.Unlock() - _, ok := config.dnsctx.rdnsIP[ip] - if ok { - return - } - config.dnsctx.rdnsIP[ip] = true - - log.Tracef("Adding %s for rDNS resolve", ip) - select { - case config.dnsctx.rdnsChannel <- ip: - // - default: - log.Tracef("rDNS queue is full") - } -} - -// Use rDNS to get hostname by IP address -func resolveRDNS(ip string) string { - log.Tracef("Resolving host for %s", ip) - - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - { - Qtype: dns.TypePTR, - Qclass: dns.ClassINET, - }, - } - var err error - req.Question[0].Name, err = dns.ReverseAddr(ip) - if err != nil { - log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err) - return "" - } - - resp, err := config.dnsctx.upstream.Exchange(&req) - if err != nil { - log.Error("Error while making an rDNS lookup for %s: %s", ip, err) - return "" - } - if len(resp.Answer) != 1 { - log.Debug("No answer for rDNS lookup of %s", ip) - return "" - } - ptr, ok := resp.Answer[0].(*dns.PTR) - if !ok { - log.Error("not a PTR response for %s", ip) - return "" - } - - log.Tracef("PTR response for %s: %s", ip, ptr.String()) - if strings.HasSuffix(ptr.Ptr, ".") { - ptr.Ptr = ptr.Ptr[:len(ptr.Ptr)-1] - } - - return ptr.Ptr -} - -// Wait for a signal and then synchronously resolve hostname by IP address -// Add the hostname:IP pair to "Clients" array -func asyncRDNSLoop() { - for { - var ip string - ip = <-config.dnsctx.rdnsChannel - - host := resolveRDNS(ip) - if len(host) == 0 { - continue - } - - config.dnsctx.rdnsLock.Lock() - delete(config.dnsctx.rdnsIP, ip) - config.dnsctx.rdnsLock.Unlock() - - _, _ = config.clients.AddHost(ip, host, ClientSourceRDNS) - } -} - func onDNSRequest(d *proxy.DNSContext) { qType := d.Req.Question[0].Qtype if qType != dns.TypeA && qType != dns.TypeAAAA { diff --git a/home/rdns.go b/home/rdns.go new file mode 100644 index 00000000..9ea11a26 --- /dev/null +++ b/home/rdns.go @@ -0,0 +1,125 @@ +package home + +import ( + "fmt" + "strings" + "time" + + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" +) + +const ( + rdnsTimeout = 3 * time.Second // max time to wait for rDNS response +) + +func initRDNS() { + var err error + + bindhost := config.DNS.BindHost + if config.DNS.BindHost == "0.0.0.0" { + bindhost = "127.0.0.1" + } + resolverAddress := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) + + opts := upstream.Options{ + Timeout: rdnsTimeout, + } + config.dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts) + if err != nil { + log.Error("upstream.AddressToUpstream: %s", err) + return + } + + config.dnsctx.rdnsIP = make(map[string]bool) + config.dnsctx.rdnsChannel = make(chan string, 256) + go asyncRDNSLoop() +} + +// Add IP address to the rDNS queue +func beginAsyncRDNS(ip string) { + if config.clients.Exists(ip) { + return + } + + // add IP to rdnsIP, if not exists + config.dnsctx.rdnsLock.Lock() + defer config.dnsctx.rdnsLock.Unlock() + _, ok := config.dnsctx.rdnsIP[ip] + if ok { + return + } + config.dnsctx.rdnsIP[ip] = true + + log.Tracef("Adding %s for rDNS resolve", ip) + select { + case config.dnsctx.rdnsChannel <- ip: + // + default: + log.Tracef("rDNS queue is full") + } +} + +// Use rDNS to get hostname by IP address +func resolveRDNS(ip string) string { + log.Tracef("Resolving host for %s", ip) + + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + { + Qtype: dns.TypePTR, + Qclass: dns.ClassINET, + }, + } + var err error + req.Question[0].Name, err = dns.ReverseAddr(ip) + if err != nil { + log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err) + return "" + } + + resp, err := config.dnsctx.upstream.Exchange(&req) + if err != nil { + log.Debug("Error while making an rDNS lookup for %s: %s", ip, err) + return "" + } + if len(resp.Answer) != 1 { + log.Debug("No answer for rDNS lookup of %s", ip) + return "" + } + ptr, ok := resp.Answer[0].(*dns.PTR) + if !ok { + log.Debug("not a PTR response for %s", ip) + return "" + } + + log.Tracef("PTR response for %s: %s", ip, ptr.String()) + if strings.HasSuffix(ptr.Ptr, ".") { + ptr.Ptr = ptr.Ptr[:len(ptr.Ptr)-1] + } + + return ptr.Ptr +} + +// Wait for a signal and then synchronously resolve hostname by IP address +// Add the hostname:IP pair to "Clients" array +func asyncRDNSLoop() { + for { + var ip string + ip = <-config.dnsctx.rdnsChannel + + host := resolveRDNS(ip) + if len(host) == 0 { + continue + } + + config.dnsctx.rdnsLock.Lock() + delete(config.dnsctx.rdnsIP, ip) + config.dnsctx.rdnsLock.Unlock() + + _, _ = config.clients.AddHost(ip, host, ClientSourceRDNS) + } +}