diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index eed6bd33..2aabb371 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -1,13 +1,7 @@ package dnsfilter import ( - "bufio" "bytes" - "context" - "crypto/sha256" - "encoding/binary" - "encoding/gob" - "encoding/json" "fmt" "io/ioutil" "net" @@ -16,30 +10,14 @@ import ( "runtime" "strings" "sync" - "sync/atomic" - "time" - - "github.com/joomcode/errorx" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter" - "github.com/bluele/gcache" "github.com/miekg/dns" - "golang.org/x/net/publicsuffix" ) -const defaultHTTPTimeout = 5 * time.Minute -const defaultHTTPMaxIdleConnections = 100 - -const defaultSafebrowsingServer = "sb.adtidy.org" -const defaultSafebrowsingURL = "%s://%s/safebrowsing-lookup-hash.html?prefixes=%s" -const defaultParentalServer = "pctrl.adguard.com" -const defaultParentalURL = "%s://%s/check-parental-control-hash?prefixes=%s&sensitivity=%d" -const defaultParentalSensitivity = 13 // use "TEEN" by default -const maxDialCacheSize = 2 // the number of host names for safebrowsing and parental control - // ServiceEntry - blocked service array element type ServiceEntry struct { Name string @@ -65,7 +43,6 @@ type RewriteEntry struct { type Config struct { ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17 ParentalEnabled bool `yaml:"parental_enabled"` - UsePlainHTTP bool `yaml:"-"` // use plain HTTP for requests to parental and safe browsing servers SafeSearchEnabled bool `yaml:"safesearch_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` ResolverAddress string `yaml:"-"` // DNS server address @@ -110,12 +87,10 @@ type Dnsfilter struct { filteringEngine *urlfilter.DNSEngine engineLock sync.RWMutex - // HTTP lookups for safebrowsing and parental - client http.Client // handle for http client -- single instance as recommended by docs - transport *http.Transport // handle for http transport used by http client - - parentalServer string // access via methods - safeBrowsingServer string // access via methods + parentalServer string // access via methods + safeBrowsingServer string // access via methods + parentalUpstream upstream.Upstream + safeBrowsingUpstream upstream.Upstream Config // for direct access by library users, even a = assignment confLock sync.RWMutex @@ -251,9 +226,6 @@ func (d *Dnsfilter) filtersInitializer() { // Close - close the object func (d *Dnsfilter) Close() { - if d != nil && d.transport != nil { - d.transport.CloseIdleConnections() - } if d.rulesStorage != nil { d.rulesStorage.Close() } @@ -261,7 +233,6 @@ func (d *Dnsfilter) Close() { type dnsFilterContext struct { stats Stats - dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers safebrowsingCache cache.Cache parentalCache cache.Cache safeSearchCache cache.Cache @@ -328,11 +299,10 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering } } - // check safeSearch if no match if setts.SafeSearchEnabled { result, err = d.checkSafeSearch(host) if err != nil { - log.Printf("Failed to safesearch HTTP lookup, ignoring check: %v", err) + log.Info("SafeSearch: failed: %v", err) return Result{}, nil } @@ -341,12 +311,10 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering } } - // check safebrowsing if no match if setts.SafeBrowsingEnabled { result, err = d.checkSafeBrowsing(host) if err != nil { - // failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache - log.Printf("Failed to do safebrowsing HTTP lookup, ignoring check: %v", err) + log.Info("SafeBrowsing: failed: %v", err) return Result{}, nil } if result.Reason.Matched() { @@ -354,12 +322,10 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering } } - // check parental if no match if setts.ParentalEnabled { result, err = d.checkParental(host) if err != nil { - // failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache - log.Printf("Failed to do parental HTTP lookup, ignoring check: %v", err) + log.Printf("Parental: failed: %v", err) return Result{}, nil } if result.Reason.Matched() { @@ -367,7 +333,6 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering } } - // nothing matched, return nothing return Result{}, nil } @@ -445,311 +410,6 @@ func matchBlockedServicesRules(host string, svcs []ServiceEntry) Result { return res } -/* -expire byte[4] -res Result -*/ -func (d *Dnsfilter) setCacheResult(cache cache.Cache, host string, res Result) { - var buf bytes.Buffer - - expire := uint(time.Now().Unix()) + d.Config.CacheTime*60 - var exp []byte - exp = make([]byte, 4) - binary.BigEndian.PutUint32(exp, uint32(expire)) - _, _ = buf.Write(exp) - - enc := gob.NewEncoder(&buf) - err := enc.Encode(res) - if err != nil { - log.Error("gob.Encode(): %s", err) - return - } - _ = cache.Set([]byte(host), buf.Bytes()) - log.Debug("Stored in cache %p: %s", cache, host) -} - -func getCachedResult(cache cache.Cache, host string) (Result, bool) { - data := cache.Get([]byte(host)) - if data == nil { - return Result{}, false - } - - exp := int(binary.BigEndian.Uint32(data[:4])) - if exp <= int(time.Now().Unix()) { - cache.Del([]byte(host)) - return Result{}, false - } - - var buf bytes.Buffer - buf.Write(data[4:]) - dec := gob.NewDecoder(&buf) - r := Result{} - err := dec.Decode(&r) - if err != nil { - log.Debug("gob.Decode(): %s", err) - return Result{}, false - } - - return r, true -} - -// for each dot, hash it and add it to string -func hostnameToHashParam(host string, addslash bool) (string, map[string]bool) { - var hashparam bytes.Buffer - hashes := map[string]bool{} - tld, icann := publicsuffix.PublicSuffix(host) - if !icann { - // private suffixes like cloudfront.net - tld = "" - } - curhost := host - for { - if curhost == "" { - // we've reached end of string - break - } - if tld != "" && curhost == tld { - // we've reached the TLD, don't hash it - break - } - tohash := []byte(curhost) - if addslash { - tohash = append(tohash, '/') - } - sum := sha256.Sum256(tohash) - hexhash := fmt.Sprintf("%X", sum) - hashes[hexhash] = true - hashparam.WriteString(fmt.Sprintf("%02X%02X%02X%02X/", sum[0], sum[1], sum[2], sum[3])) - pos := strings.IndexByte(curhost, byte('.')) - if pos < 0 { - break - } - curhost = curhost[pos+1:] - } - return hashparam.String(), hashes -} - -func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) { - if log.GetLevel() >= log.DEBUG { - timer := log.StartTimer() - defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host) - } - - // Check cache. Return cached result if it was found - cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host) - if isFound { - atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1) - log.Tracef("%s: found in SafeSearch cache", host) - return cachedValue, nil - } - - safeHost, ok := d.SafeSearchDomain(host) - if !ok { - return Result{}, nil - } - - res := Result{IsFiltered: true, Reason: FilteredSafeSearch} - if ip := net.ParseIP(safeHost); ip != nil { - res.IP = ip - d.setCacheResult(gctx.safeSearchCache, host, res) - return res, nil - } - - // TODO this address should be resolved with upstream that was configured in dnsforward - addrs, err := net.LookupIP(safeHost) - if err != nil { - log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err) - return Result{}, err - } - - for _, i := range addrs { - if ipv4 := i.To4(); ipv4 != nil { - res.IP = ipv4 - break - } - } - - if len(res.IP) == 0 { - return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost) - } - - // Cache result - d.setCacheResult(gctx.safeSearchCache, host, res) - return res, nil -} - -func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { - if log.GetLevel() >= log.DEBUG { - timer := log.StartTimer() - defer timer.LogElapsed("SafeBrowsing HTTP lookup for %s", host) - } - - format := func(hashparam string) string { - schema := "https" - if d.UsePlainHTTP { - schema = "http" - } - url := fmt.Sprintf(defaultSafebrowsingURL, schema, d.safeBrowsingServer, hashparam) - return url - } - handleBody := func(body []byte, hashes map[string]bool) (Result, error) { - result := Result{} - scanner := bufio.NewScanner(strings.NewReader(string(body))) - for scanner.Scan() { - line := scanner.Text() - splitted := strings.Split(line, ":") - if len(splitted) < 3 { - continue - } - hash := splitted[2] - if _, ok := hashes[hash]; ok { - // it's in the hash - result.IsFiltered = true - result.Reason = FilteredSafeBrowsing - result.Rule = splitted[0] - break - } - } - - if err := scanner.Err(); err != nil { - // error, don't save cache - return Result{}, err - } - return result, nil - } - - // check cache - cachedValue, isFound := getCachedResult(gctx.safebrowsingCache, host) - if isFound { - atomic.AddUint64(&gctx.stats.Safebrowsing.CacheHits, 1) - log.Tracef("%s: found in the lookup cache %p", host, gctx.safebrowsingCache) - return cachedValue, nil - } - - result, err := d.lookupCommon(host, &gctx.stats.Safebrowsing, true, format, handleBody) - - if err == nil { - d.setCacheResult(gctx.safebrowsingCache, host, result) - } - - return result, err -} - -func (d *Dnsfilter) checkParental(host string) (Result, error) { - if log.GetLevel() >= log.DEBUG { - timer := log.StartTimer() - defer timer.LogElapsed("Parental HTTP lookup for %s", host) - } - - format := func(hashparam string) string { - schema := "https" - if d.UsePlainHTTP { - schema = "http" - } - sensitivity := d.ParentalSensitivity - if sensitivity == 0 { - sensitivity = defaultParentalSensitivity - } - url := fmt.Sprintf(defaultParentalURL, schema, d.parentalServer, hashparam, sensitivity) - return url - } - handleBody := func(body []byte, hashes map[string]bool) (Result, error) { - // parse json - var m []struct { - Blocked bool `json:"blocked"` - ClientTTL int `json:"clientTtl"` - Reason string `json:"reason"` - Hash string `json:"hash"` - } - err := json.Unmarshal(body, &m) - if err != nil { - // error, don't save cache - log.Printf("Couldn't parse json '%s': %s", body, err) - return Result{}, err - } - - result := Result{} - - for i := range m { - if !hashes[m[i].Hash] { - continue - } - if m[i].Blocked { - result.IsFiltered = true - result.Reason = FilteredParental - result.Rule = fmt.Sprintf("parental %s", m[i].Reason) - break - } - } - return result, nil - } - - // check cache - cachedValue, isFound := getCachedResult(gctx.parentalCache, host) - if isFound { - atomic.AddUint64(&gctx.stats.Parental.CacheHits, 1) - log.Tracef("%s: found in the lookup cache %p", host, gctx.parentalCache) - return cachedValue, nil - } - - result, err := d.lookupCommon(host, &gctx.stats.Parental, false, format, handleBody) - - if err == nil { - d.setCacheResult(gctx.parentalCache, host, result) - } - - return result, err -} - -type formatHandler func(hashparam string) string -type bodyHandler func(body []byte, hashes map[string]bool) (Result, error) - -// real implementation of lookup/check -func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, hashparamNeedSlash bool, format formatHandler, handleBody bodyHandler) (Result, error) { - // convert hostname to hash parameters - hashparam, hashes := hostnameToHashParam(host, hashparamNeedSlash) - - // format URL with our hashes - url := format(hashparam) - - // do HTTP request - atomic.AddUint64(&lookupstats.Requests, 1) - atomic.AddInt64(&lookupstats.Pending, 1) - updateMax(&lookupstats.Pending, &lookupstats.PendingMax) - resp, err := d.client.Get(url) - atomic.AddInt64(&lookupstats.Pending, -1) - if resp != nil && resp.Body != nil { - defer resp.Body.Close() - } - if err != nil { - // error, don't save cache - return Result{}, err - } - - // get body text - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - // error, don't save cache - return Result{}, err - } - - // handle status code - switch { - case resp.StatusCode == 204: - // empty result, save cache - return Result{}, nil - case resp.StatusCode != 200: - return Result{}, fmt.Errorf("HTTP status code: %d", resp.StatusCode) - } - - result, err := handleBody(body, hashes) - if err != nil { - return Result{}, err - } - - return result, nil -} - // // Adding rule and matching against the rules // @@ -887,97 +547,6 @@ func (d *Dnsfilter) matchHost(host string, qtype uint16) (Result, error) { return Result{}, nil } -// -// lifecycle helper functions -// - -// Return TRUE if this host's IP should be cached -func (d *Dnsfilter) shouldBeInDialCache(host string) bool { - return host == d.safeBrowsingServer || - host == d.parentalServer -} - -// Search for an IP address by host name -func searchInDialCache(host string) string { - rawValue, err := gctx.dialCache.Get(host) - if err != nil { - return "" - } - - ip, _ := rawValue.(string) - log.Debug("Found in cache: %s -> %s", host, ip) - return ip -} - -// Add "hostname" -> "IP address" entry to cache -func addToDialCache(host, ip string) { - err := gctx.dialCache.Set(host, ip) - if err != nil { - log.Debug("dialCache.Set: %s", err) - } - log.Debug("Added to cache: %s -> %s", host, ip) -} - -type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error) - -// Connect to a remote server resolving hostname using our own DNS server -func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionType { - return func(ctx context.Context, network, addr string) (net.Conn, error) { - log.Tracef("network:%v addr:%v", network, addr) - - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - - dialer := &net.Dialer{ - Timeout: time.Minute * 5, - } - - if net.ParseIP(host) != nil { - con, err := dialer.DialContext(ctx, network, addr) - return con, err - } - - cache := d.shouldBeInDialCache(host) - if cache { - ip := searchInDialCache(host) - if len(ip) != 0 { - addr = fmt.Sprintf("%s:%s", ip, port) - return dialer.DialContext(ctx, network, addr) - } - } - - r := upstream.NewResolver(resolverAddr, 30*time.Second) - addrs, e := r.LookupIPAddr(ctx, host) - log.Tracef("LookupIPAddr: %s: %v", host, addrs) - if e != nil { - return nil, e - } - - if len(addrs) == 0 { - return nil, fmt.Errorf("couldn't lookup host: %s", host) - } - - var dialErrs []error - for _, a := range addrs { - addr = fmt.Sprintf("%s:%s", a.String(), port) - con, err := dialer.DialContext(ctx, network, addr) - if err != nil { - dialErrs = append(dialErrs, err) - continue - } - - if cache { - addToDialCache(host, a.String()) - } - - return con, err - } - return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...) - } -} - // New creates properly initialized DNS Filter that is ready to be used func New(c *Config, filters map[int]string) *Dnsfilter { @@ -1002,34 +571,16 @@ func New(c *Config, filters map[int]string) *Dnsfilter { cacheConf.MaxSize = c.ParentalCacheSize gctx.parentalCache = cache.New(cacheConf) } - - if len(c.ResolverAddress) != 0 && gctx.dialCache == nil { - dur := time.Duration(c.CacheTime) * time.Minute - gctx.dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(dur).Build() - } } d := new(Dnsfilter) - // Customize the Transport to have larger connection pool, - // We are not (re)using http.DefaultTransport because of race conditions found by tests - d.transport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - MaxIdleConns: defaultHTTPMaxIdleConnections, // default 100 - MaxIdleConnsPerHost: defaultHTTPMaxIdleConnections, // default 2 - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, + err := d.initSecurityServices() + if err != nil { + log.Error("dnsfilter: initialize services: %s", err) + return nil } - if c != nil && len(c.ResolverAddress) != 0 { - d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress) - } - d.client = http.Client{ - Transport: d.transport, - Timeout: defaultHTTPTimeout, - } - d.safeBrowsingServer = defaultSafebrowsingServer - d.parentalServer = defaultParentalServer + if c != nil { d.Config = *c } @@ -1053,38 +604,6 @@ func New(c *Config, filters map[int]string) *Dnsfilter { return d } -// -// config manipulation helpers -// - -// SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup -func (d *Dnsfilter) SetSafeBrowsingServer(host string) { - if len(host) == 0 { - d.safeBrowsingServer = defaultSafebrowsingServer - } else { - d.safeBrowsingServer = host - } -} - -// SetHTTPTimeout lets you optionally change timeout during lookups -func (d *Dnsfilter) SetHTTPTimeout(t time.Duration) { - d.client.Timeout = t -} - -// ResetHTTPTimeout resets lookup timeouts -func (d *Dnsfilter) ResetHTTPTimeout() { - d.client.Timeout = defaultHTTPTimeout -} - -// SafeSearchDomain returns replacement address for search engine -func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) { - if d.SafeSearchEnabled { - val, ok := safeSearchDomains[host] - return val, ok - } - return "", false -} - // // stats // diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 37255b78..231340bc 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -3,15 +3,11 @@ package dnsfilter import ( "fmt" "net" - "net/http" - "net/http/httptest" "path" "runtime" "testing" - "time" "github.com/AdguardTeam/urlfilter" - "github.com/bluele/gcache" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) @@ -23,7 +19,6 @@ var setts RequestFilteringSettings // SAFE SEARCH // PARENTAL // FILTERING -// CLIENTS SETTINGS // BENCHMARKS // HELPERS @@ -126,34 +121,19 @@ func TestEtcHostsMatching(t *testing.T) { // SAFE BROWSING func TestSafeBrowsing(t *testing.T) { - testCases := []string{ - "", - "sb.adtidy.org", - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("%s in %s", tc, _Func()), func(t *testing.T) { - d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Close() - gctx.stats.Safebrowsing.Requests = 0 - d.checkMatch(t, "wmconvirus.narod.ru") - d.checkMatch(t, "wmconvirus.narod.ru") - if gctx.stats.Safebrowsing.Requests != 1 { - t.Errorf("Safebrowsing lookup positive cache is not working: %v", gctx.stats.Safebrowsing.Requests) - } - d.checkMatch(t, "WMconvirus.narod.ru") - if gctx.stats.Safebrowsing.Requests != 1 { - t.Errorf("Safebrowsing lookup positive cache is not working: %v", gctx.stats.Safebrowsing.Requests) - } - d.checkMatch(t, "test.wmconvirus.narod.ru") - d.checkMatchEmpty(t, "yandex.ru") - d.checkMatchEmpty(t, "pornhub.com") - l := gctx.stats.Safebrowsing.Requests - d.checkMatchEmpty(t, "pornhub.com") - if gctx.stats.Safebrowsing.Requests != l { - t.Errorf("Safebrowsing lookup negative cache is not working: %v", gctx.stats.Safebrowsing.Requests) - } - }) - } + d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) + defer d.Close() + gctx.stats.Safebrowsing.Requests = 0 + d.checkMatch(t, "wmconvirus.narod.ru") + d.checkMatch(t, "test.wmconvirus.narod.ru") + d.checkMatchEmpty(t, "yandex.ru") + d.checkMatchEmpty(t, "pornhub.com") + + // test cached result + d.safeBrowsingServer = "127.0.0.1" + d.checkMatch(t, "wmconvirus.narod.ru") + d.checkMatchEmpty(t, "pornhub.com") + d.safeBrowsingServer = defaultSafebrowsingServer } func TestParallelSB(t *testing.T) { @@ -172,33 +152,10 @@ func TestParallelSB(t *testing.T) { }) } -// the only way to verify that custom server option is working is to point it at a server that does serve safebrowsing -func TestSafeBrowsingCustomServerFail(t *testing.T) { - d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) - defer d.Close() - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // w.Write("Hello, client") - fmt.Fprintln(w, "Hello, client") - })) - defer ts.Close() - address := ts.Listener.Addr().String() - - d.SetHTTPTimeout(time.Second * 5) - d.SetSafeBrowsingServer(address) // this will ensure that test fails - d.checkMatchEmpty(t, "wmconvirus.narod.ru") -} - // SAFE SEARCH func TestSafeSearch(t *testing.T) { - d := NewForTest(nil, nil) - defer d.Close() - _, ok := d.SafeSearchDomain("www.google.com") - if ok { - t.Errorf("Expected safesearch to error when disabled") - } - - d = NewForTest(&Config{SafeSearchEnabled: true}, nil) + d := NewForTest(&Config{SafeSearchEnabled: true}, nil) defer d.Close() val, ok := d.SafeSearchDomain("www.google.com") if !ok { @@ -355,24 +312,16 @@ func TestParentalControl(t *testing.T) { defer d.Close() d.ParentalSensitivity = 3 d.checkMatch(t, "pornhub.com") - d.checkMatch(t, "pornhub.com") - if gctx.stats.Parental.Requests != 1 { - t.Errorf("Parental lookup positive cache is not working") - } - d.checkMatch(t, "PORNhub.com") - if gctx.stats.Parental.Requests != 1 { - t.Errorf("Parental lookup positive cache is not working") - } d.checkMatch(t, "www.pornhub.com") d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "yandex.ru") - l := gctx.stats.Parental.Requests - d.checkMatchEmpty(t, "yandex.ru") - if gctx.stats.Parental.Requests != l { - t.Errorf("Parental lookup negative cache is not working") - } - d.checkMatchEmpty(t, "api.jquery.com") + + // test cached result + d.parentalServer = "127.0.0.1" + d.checkMatch(t, "pornhub.com") + d.checkMatchEmpty(t, "yandex.ru") + d.parentalServer = defaultParentalServer } // FILTERING @@ -588,17 +537,3 @@ func BenchmarkSafeSearchParallel(b *testing.B) { } }) } - -func TestDnsfilterDialCache(t *testing.T) { - d := Dnsfilter{} - gctx.dialCache = gcache.New(1).LRU().Expiration(30 * time.Minute).Build() - - d.shouldBeInDialCache("hostname") - if searchInDialCache("hostname") != "" { - t.Errorf("searchInDialCache") - } - addToDialCache("hostname", "1.1.1.1") - if searchInDialCache("hostname") != "1.1.1.1" { - t.Errorf("searchInDialCache") - } -} diff --git a/dnsfilter/helpers.go b/dnsfilter/helpers.go deleted file mode 100644 index 2d60c47c..00000000 --- a/dnsfilter/helpers.go +++ /dev/null @@ -1,20 +0,0 @@ -package dnsfilter - -import ( - "sync/atomic" -) - -func updateMax(valuePtr *int64, maxPtr *int64) { - for { - current := atomic.LoadInt64(valuePtr) - max := atomic.LoadInt64(maxPtr) - if current <= max { - break - } - swapped := atomic.CompareAndSwapInt64(maxPtr, max, current) - if swapped { - break - } - // swapping failed because value has changed after reading, try again - } -} diff --git a/dnsfilter/security.go b/dnsfilter/security.go index c4ce32de..1de5cc8a 100644 --- a/dnsfilter/security.go +++ b/dnsfilter/security.go @@ -4,17 +4,290 @@ package dnsfilter import ( "bufio" + "bytes" + "crypto/sha256" + "encoding/binary" + "encoding/gob" + "encoding/hex" "encoding/json" "errors" "fmt" "io" + "net" "net/http" "strconv" "strings" + "time" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" + "golang.org/x/net/publicsuffix" ) +const dnsTimeout = 3 * time.Second +const defaultSafebrowsingServer = "https://dns-family.adguard.com/dns-query" +const defaultParentalServer = "https://dns-family.adguard.com/dns-query" +const sbTXTSuffix = "sb.dns.adguard.com." +const pcTXTSuffix = "pc.dns.adguard.com." + +func (d *Dnsfilter) initSecurityServices() error { + var err error + d.safeBrowsingServer = defaultSafebrowsingServer + d.parentalServer = defaultParentalServer + + d.parentalUpstream, err = upstream.AddressToUpstream(d.parentalServer, upstream.Options{Timeout: dnsTimeout}) + if err != nil { + return err + } + + d.safeBrowsingUpstream, err = upstream.AddressToUpstream(d.safeBrowsingServer, upstream.Options{Timeout: dnsTimeout}) + if err != nil { + return err + } + + return nil +} + +/* +expire byte[4] +res Result +*/ +func (d *Dnsfilter) setCacheResult(cache cache.Cache, host string, res Result) int { + var buf bytes.Buffer + + expire := uint(time.Now().Unix()) + d.Config.CacheTime*60 + var exp []byte + exp = make([]byte, 4) + binary.BigEndian.PutUint32(exp, uint32(expire)) + _, _ = buf.Write(exp) + + enc := gob.NewEncoder(&buf) + err := enc.Encode(res) + if err != nil { + log.Error("gob.Encode(): %s", err) + return 0 + } + val := buf.Bytes() + _ = cache.Set([]byte(host), val) + return len(val) +} + +func getCachedResult(cache cache.Cache, host string) (Result, bool) { + data := cache.Get([]byte(host)) + if data == nil { + return Result{}, false + } + + exp := int(binary.BigEndian.Uint32(data[:4])) + if exp <= int(time.Now().Unix()) { + cache.Del([]byte(host)) + return Result{}, false + } + + var buf bytes.Buffer + buf.Write(data[4:]) + dec := gob.NewDecoder(&buf) + r := Result{} + err := dec.Decode(&r) + if err != nil { + log.Debug("gob.Decode(): %s", err) + return Result{}, false + } + + return r, true +} + +// SafeSearchDomain returns replacement address for search engine +func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) { + val, ok := safeSearchDomains[host] + return val, ok +} + +func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) { + if log.GetLevel() >= log.DEBUG { + timer := log.StartTimer() + defer timer.LogElapsed("SafeSearch: lookup for %s", host) + } + + // Check cache. Return cached result if it was found + cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host) + if isFound { + // atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1) + log.Tracef("SafeSearch: found in cache: %s", host) + return cachedValue, nil + } + + safeHost, ok := d.SafeSearchDomain(host) + if !ok { + return Result{}, nil + } + + res := Result{IsFiltered: true, Reason: FilteredSafeSearch} + if ip := net.ParseIP(safeHost); ip != nil { + res.IP = ip + len := d.setCacheResult(gctx.safeSearchCache, host, res) + log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, len) + return res, nil + } + + // TODO this address should be resolved with upstream that was configured in dnsforward + addrs, err := net.LookupIP(safeHost) + if err != nil { + log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err) + return Result{}, err + } + + for _, i := range addrs { + if ipv4 := i.To4(); ipv4 != nil { + res.IP = ipv4 + break + } + } + + if len(res.IP) == 0 { + return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost) + } + + // Cache result + len := d.setCacheResult(gctx.safeSearchCache, host, res) + log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, len) + return res, nil +} + +// for each dot, hash it and add it to string +func hostnameToHashParam(host string) (string, map[string]bool) { + var hashparam bytes.Buffer + hashes := map[string]bool{} + tld, icann := publicsuffix.PublicSuffix(host) + if !icann { + // private suffixes like cloudfront.net + tld = "" + } + curhost := host + for { + if curhost == "" { + // we've reached end of string + break + } + if tld != "" && curhost == tld { + // we've reached the TLD, don't hash it + break + } + + sum := sha256.Sum256([]byte(curhost)) + hashes[hex.EncodeToString(sum[:])] = true + hashparam.WriteString(fmt.Sprintf("%s.", hex.EncodeToString(sum[0:4]))) + + pos := strings.IndexByte(curhost, byte('.')) + if pos < 0 { + break + } + curhost = curhost[pos+1:] + } + return hashparam.String(), hashes +} + +// Find the target hash in TXT response +func (d *Dnsfilter) processTXT(svc, host string, resp *dns.Msg, hashes map[string]bool) bool { + for _, a := range resp.Answer { + txt, ok := a.(*dns.TXT) + if !ok { + continue + } + log.Tracef("%s: hashes for %s: %v", svc, host, txt.Txt) + for _, t := range txt.Txt { + _, ok := hashes[t] + if ok { + log.Tracef("%s: matched %s by %s", svc, host, t) + return true + } + } + } + return false +} + +// Disabling "dupl": the algorithm of SB/PC is similar, but it uses different data +// nolint:dupl +func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { + if log.GetLevel() >= log.DEBUG { + timer := log.StartTimer() + defer timer.LogElapsed("SafeBrowsing lookup for %s", host) + } + + // check cache + cachedValue, isFound := getCachedResult(gctx.safebrowsingCache, host) + if isFound { + // atomic.AddUint64(&gctx.stats.Safebrowsing.CacheHits, 1) + log.Tracef("SafeBrowsing: found in cache: %s", host) + return cachedValue, nil + } + + result := Result{} + question, hashes := hostnameToHashParam(host) + question = question + sbTXTSuffix + + log.Tracef("SafeBrowsing: checking %s: %s", host, question) + + req := dns.Msg{} + req.SetQuestion(question, dns.TypeTXT) + resp, err := d.safeBrowsingUpstream.Exchange(&req) + if err != nil { + return result, err + } + + if d.processTXT("SafeBrowsing", host, resp, hashes) { + result.IsFiltered = true + result.Reason = FilteredSafeBrowsing + result.Rule = "adguard-malware-shavar" + } + + len := d.setCacheResult(gctx.safebrowsingCache, host, result) + log.Debug("SafeBrowsing: stored in cache: %s (%d bytes)", host, len) + return result, nil +} + +// Disabling "dupl": the algorithm of SB/PC is similar, but it uses different data +// nolint:dupl +func (d *Dnsfilter) checkParental(host string) (Result, error) { + if log.GetLevel() >= log.DEBUG { + timer := log.StartTimer() + defer timer.LogElapsed("Parental lookup for %s", host) + } + + // check cache + cachedValue, isFound := getCachedResult(gctx.parentalCache, host) + if isFound { + // atomic.AddUint64(&gctx.stats.Parental.CacheHits, 1) + log.Tracef("Parental: found in cache: %s", host) + return cachedValue, nil + } + + result := Result{} + question, hashes := hostnameToHashParam(host) + question = question + pcTXTSuffix + + log.Tracef("Parental: checking %s: %s", host, question) + + req := dns.Msg{} + req.SetQuestion(question, dns.TypeTXT) + resp, err := d.parentalUpstream.Exchange(&req) + if err != nil { + return result, err + } + + if d.processTXT("Parental", host, resp, hashes) { + result.IsFiltered = true + result.Reason = FilteredParental + result.Rule = "parental CATEGORY_BLACKLISTED" + } + + len := d.setCacheResult(gctx.parentalCache, host, result) + log.Debug("Parental: stored in cache: %s (%d bytes)", host, len) + return result, err +} + func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { text := fmt.Sprintf(format, args...) log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text) @@ -170,9 +443,11 @@ func (d *Dnsfilter) registerSecurityHandlers() { d.Config.HTTPRegister("POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable) d.Config.HTTPRegister("POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable) d.Config.HTTPRegister("GET", "/control/safebrowsing/status", d.handleSafeBrowsingStatus) + d.Config.HTTPRegister("POST", "/control/parental/enable", d.handleParentalEnable) d.Config.HTTPRegister("POST", "/control/parental/disable", d.handleParentalDisable) d.Config.HTTPRegister("GET", "/control/parental/status", d.handleParentalStatus) + d.Config.HTTPRegister("POST", "/control/safesearch/enable", d.handleSafeSearchEnable) d.Config.HTTPRegister("POST", "/control/safesearch/disable", d.handleSafeSearchDisable) d.Config.HTTPRegister("GET", "/control/safesearch/status", d.handleSafeSearchStatus) diff --git a/go.mod b/go.mod index fd4fa566..51a1c692 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,8 @@ require ( github.com/AdguardTeam/golibs v0.2.4 github.com/AdguardTeam/urlfilter v0.6.1 github.com/NYTimes/gziphandler v1.1.1 - github.com/bluele/gcache v0.0.0-20190518031135-bc40bd653833 github.com/etcd-io/bbolt v1.3.3 - github.com/go-test/deep v1.0.4 + github.com/go-test/deep v1.0.4 // indirect github.com/gobuffalo/packr v1.19.0 github.com/joomcode/errorx v1.0.0 github.com/kardianos/osext v0.0.0-20170510131534-ae77be60afb1 // indirect diff --git a/go.sum b/go.sum index 6be363ef..a999a567 100644 --- a/go.sum +++ b/go.sum @@ -27,8 +27,6 @@ github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4 github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 h1:KXlsf+qt/X5ttPGEjR0tPH1xaWWoKBEg9Q1THAj2h3I= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA= -github.com/bluele/gcache v0.0.0-20190518031135-bc40bd653833 h1:yCfXxYaelOyqnia8F/Yng47qhmfC9nKTRIbYRrRueq4= -github.com/bluele/gcache v0.0.0-20190518031135-bc40bd653833/go.mod h1:8c4/i2VlovMO2gBnHGQPN5EJw+H0lx1u/5p+cgsXtCk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=