diff --git a/AGHTechDoc.md b/AGHTechDoc.md index e70d1609..8eba5ed6 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -64,6 +64,7 @@ Contents: * API: Log in * API: Log out * API: Get current user info +* Safe services ## Relations between subsystems @@ -1747,3 +1748,40 @@ Response: } If no client is configured then authentication is disabled and server sends an empty response. + + +### Safe services + +Check if host name is blocked by SB/PC service: + +* For each host name component, search for the result in cache by the first 2 bytes of SHA-256 hashes of host name components (max. is 4, i.e. sub2.sub1.host.com), excluding TLD: + + hashes[] = cache_search(sha256(host.com)[0..1]) + ... + + If hash prefix is found, search for a full hash sum in the cached data. + If found, the host is blocked. + If not found, the host is not blocked - don't request data for this prefix from the Family server again. + If hash prefix is not found, request data for this prefix from the Family server. + +* Prepare query string which is generated from the first 2 bytes (converted to a 4-character string) of SHA-256 hashes of host name components (max. is 4, i.e. sub2.sub1.host.com), excluding TLD: + + qs = ... + string(sha256(sub.host.com)[0..1]) + "." + string(sha256(host.com)[0..1]) + ".sb.dns.adguard.com." + + For PC `.pc.dns.adguard.com` suffix is used. + +* Send TXT query to Family server, receive response which contains the array of complete hash sums of the blocked hosts + +* Check if one of received hash sums (`hashes[]`) matches hash sums for our host name + + hashes[0] <> sha256(host.com) + hashes[0] <> sha256(sub.host.com) + hashes[1] <> sha256(host.com) + hashes[1] <> sha256(sub.host.com) + ... + +* Store all received hash sums in cache: + + sha256(host.com)[0..1] -> hashes[0],hashes[1],... + sha256(sub.host.com)[0..1] -> hashes[2],... + ... diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index f0dff6ec..0b669367 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -47,9 +47,9 @@ func NewForTest(c *Config, filters []Filter) *Dnsfilter { setts = RequestFilteringSettings{} setts.FilteringEnabled = true if c != nil { - c.SafeBrowsingCacheSize = 1000 + c.SafeBrowsingCacheSize = 10000 + c.ParentalCacheSize = 10000 c.SafeSearchCacheSize = 1000 - c.ParentalCacheSize = 1000 c.CacheTime = 30 setts.SafeSearchEnabled = c.SafeSearchEnabled setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled @@ -146,12 +146,6 @@ func TestEtcHostsMatching(t *testing.T) { // SAFE BROWSING -func TestSafeBrowsingHash(t *testing.T) { - q, hashes := hostnameToHashParam("1.2.3.4.5.6") - assert.Equal(t, "0132d0fa.b5413b4e.5fa067c1.e7f6c011.", q) - assert.Equal(t, 4, len(hashes)) -} - func TestSafeBrowsing(t *testing.T) { d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) defer d.Close() diff --git a/dnsfilter/safe_search.go b/dnsfilter/safe_search.go new file mode 100644 index 00000000..fb1358d5 --- /dev/null +++ b/dnsfilter/safe_search.go @@ -0,0 +1,149 @@ +package dnsfilter + +import ( + "bytes" + "encoding/binary" + "encoding/gob" + "encoding/json" + "fmt" + "net" + "net/http" + "time" + + "github.com/AdguardTeam/golibs/cache" + "github.com/AdguardTeam/golibs/log" +) + +/* +expire byte[4] +res Result +*/ +func (d *Dnsfilter) setCacheResult(cache cache.Cache, host string, res Result) int { + var buf bytes.Buffer + + expire := uint(time.Now().Unix()) + d.Config.CacheTime*60 + var exp []byte + exp = make([]byte, 4) + binary.BigEndian.PutUint32(exp, uint32(expire)) + _, _ = buf.Write(exp) + + enc := gob.NewEncoder(&buf) + err := enc.Encode(res) + if err != nil { + log.Error("gob.Encode(): %s", err) + return 0 + } + val := buf.Bytes() + _ = cache.Set([]byte(host), val) + return len(val) +} + +func getCachedResult(cache cache.Cache, host string) (Result, bool) { + data := cache.Get([]byte(host)) + if data == nil { + return Result{}, false + } + + exp := int(binary.BigEndian.Uint32(data[:4])) + if exp <= int(time.Now().Unix()) { + cache.Del([]byte(host)) + return Result{}, false + } + + var buf bytes.Buffer + buf.Write(data[4:]) + dec := gob.NewDecoder(&buf) + r := Result{} + err := dec.Decode(&r) + if err != nil { + log.Debug("gob.Decode(): %s", err) + return Result{}, false + } + + return r, true +} + +// SafeSearchDomain returns replacement address for search engine +func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) { + val, ok := safeSearchDomains[host] + return val, ok +} + +func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) { + if log.GetLevel() >= log.DEBUG { + timer := log.StartTimer() + defer timer.LogElapsed("SafeSearch: lookup for %s", host) + } + + // Check cache. Return cached result if it was found + cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host) + if isFound { + // atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1) + log.Tracef("SafeSearch: found in cache: %s", host) + return cachedValue, nil + } + + safeHost, ok := d.SafeSearchDomain(host) + if !ok { + return Result{}, nil + } + + res := Result{IsFiltered: true, Reason: FilteredSafeSearch} + if ip := net.ParseIP(safeHost); ip != nil { + res.IP = ip + valLen := d.setCacheResult(gctx.safeSearchCache, host, res) + log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen) + return res, nil + } + + // TODO this address should be resolved with upstream that was configured in dnsforward + addrs, err := net.LookupIP(safeHost) + if err != nil { + log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err) + return Result{}, err + } + + for _, i := range addrs { + if ipv4 := i.To4(); ipv4 != nil { + res.IP = ipv4 + break + } + } + + if len(res.IP) == 0 { + return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost) + } + + // Cache result + valLen := d.setCacheResult(gctx.safeSearchCache, host, res) + log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen) + return res, nil +} + +func (d *Dnsfilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { + d.Config.SafeSearchEnabled = true + d.Config.ConfigModified() +} + +func (d *Dnsfilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { + d.Config.SafeSearchEnabled = false + d.Config.ConfigModified() +} + +func (d *Dnsfilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { + data := map[string]interface{}{ + "enabled": d.Config.SafeSearchEnabled, + } + jsonVal, err := json.Marshal(data) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonVal) + if err != nil { + httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) + return + } +} diff --git a/dnsfilter/security.go b/dnsfilter/sb_pc.go similarity index 52% rename from dnsfilter/security.go rename to dnsfilter/sb_pc.go index 9de23add..257513ae 100644 --- a/dnsfilter/security.go +++ b/dnsfilter/sb_pc.go @@ -1,4 +1,4 @@ -// Parental Control, Safe Browsing, Safe Search +// Safe Browsing, Parental Control package dnsfilter @@ -6,12 +6,12 @@ import ( "bytes" "crypto/sha256" "encoding/binary" - "encoding/gob" "encoding/hex" "encoding/json" "fmt" "net" "net/http" + "sort" "strings" "time" @@ -22,9 +22,6 @@ import ( "golang.org/x/net/publicsuffix" ) -// Servers to use for resolution of SB/PC server name -var bootstrapServers = []string{"176.103.130.130", "176.103.130.131"} - const dnsTimeout = 3 * time.Second const defaultSafebrowsingServer = "https://dns-family.adguard.com/dns-query" const defaultParentalServer = "https://dns-family.adguard.com/dns-query" @@ -35,7 +32,15 @@ func (d *Dnsfilter) initSecurityServices() error { var err error d.safeBrowsingServer = defaultSafebrowsingServer d.parentalServer = defaultParentalServer - opts := upstream.Options{Timeout: dnsTimeout, Bootstrap: bootstrapServers} + opts := upstream.Options{ + Timeout: dnsTimeout, + ServerIPAddrs: []net.IP{ + net.ParseIP("176.103.130.132"), + net.ParseIP("176.103.130.134"), + net.ParseIP("2a00:5a60::bad1:ff"), + net.ParseIP("2a00:5a60::bad2:ff"), + }, + } d.parentalUpstream, err = upstream.AddressToUpstream(d.parentalServer, opts) if err != nil { @@ -52,115 +57,65 @@ func (d *Dnsfilter) initSecurityServices() error { /* expire byte[4] -res Result +hash byte[32] +... */ -func (d *Dnsfilter) setCacheResult(cache cache.Cache, host string, res Result) int { - var buf bytes.Buffer - - expire := uint(time.Now().Unix()) + d.Config.CacheTime*60 - var exp []byte - exp = make([]byte, 4) - binary.BigEndian.PutUint32(exp, uint32(expire)) - _, _ = buf.Write(exp) - - enc := gob.NewEncoder(&buf) - err := enc.Encode(res) - if err != nil { - log.Error("gob.Encode(): %s", err) - return 0 - } - val := buf.Bytes() - _ = cache.Set([]byte(host), val) - return len(val) +func (c *sbCtx) setCache(prefix []byte, hashes []byte) { + d := make([]byte, 4+len(hashes)) + expire := uint(time.Now().Unix()) + c.cacheTime*60 + binary.BigEndian.PutUint32(d[:4], uint32(expire)) + copy(d[4:], hashes) + c.cache.Set(prefix, d) + log.Debug("%s: stored in cache: %v", c.svc, prefix) } -func getCachedResult(cache cache.Cache, host string) (Result, bool) { - data := cache.Get([]byte(host)) - if data == nil { - return Result{}, false - } - - exp := int(binary.BigEndian.Uint32(data[:4])) - if exp <= int(time.Now().Unix()) { - cache.Del([]byte(host)) - return Result{}, false - } - - var buf bytes.Buffer - buf.Write(data[4:]) - dec := gob.NewDecoder(&buf) - r := Result{} - err := dec.Decode(&r) - if err != nil { - log.Debug("gob.Decode(): %s", err) - return Result{}, false - } - - return r, true -} - -// SafeSearchDomain returns replacement address for search engine -func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) { - val, ok := safeSearchDomains[host] - return val, ok -} - -func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) { - if log.GetLevel() >= log.DEBUG { - timer := log.StartTimer() - defer timer.LogElapsed("SafeSearch: lookup for %s", host) - } - - // Check cache. Return cached result if it was found - cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host) - if isFound { - // atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1) - log.Tracef("SafeSearch: found in cache: %s", host) - return cachedValue, nil - } - - safeHost, ok := d.SafeSearchDomain(host) - if !ok { - return Result{}, nil - } - - res := Result{IsFiltered: true, Reason: FilteredSafeSearch} - if ip := net.ParseIP(safeHost); ip != nil { - res.IP = ip - valLen := d.setCacheResult(gctx.safeSearchCache, host, res) - log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen) - return res, nil - } - - // TODO this address should be resolved with upstream that was configured in dnsforward - addrs, err := net.LookupIP(safeHost) - if err != nil { - log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err) - return Result{}, err - } - - for _, i := range addrs { - if ipv4 := i.To4(); ipv4 != nil { - res.IP = ipv4 - break +func (c *sbCtx) getCached() int { + now := time.Now().Unix() + hashesToRequest := map[[32]byte]string{} + for k, v := range c.hashToHost { + key := k[0:2] + val := c.cache.Get(key) + if val != nil { + expire := binary.BigEndian.Uint32(val) + if now >= int64(expire) { + val = nil + } else { + for i := 4; i < len(val); i += 32 { + hash := val[i : i+32] + var hash32 [32]byte + copy(hash32[:], hash[0:32]) + _, found := c.hashToHost[hash32] + if found { + log.Debug("%s: found in cache: %s: blocked by %v", c.svc, c.host, hash32) + return 1 + } + } + } + } + if val == nil { + hashesToRequest[k] = v } } - if len(res.IP) == 0 { - return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost) + if len(hashesToRequest) == 0 { + log.Debug("%s: found in cache: %s: not blocked", c.svc, c.host) + return -1 } - // Cache result - valLen := d.setCacheResult(gctx.safeSearchCache, host, res) - log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, valLen) - return res, nil + c.hashToHost = hashesToRequest + return 0 } -// for each dot, hash it and add it to string -// The maximum is 4 components: "a.b.c.d" -func hostnameToHashParam(host string) (string, map[string]bool) { - var hashparam bytes.Buffer - hashes := map[string]bool{} +type sbCtx struct { + host string + svc string + hashToHost map[[32]byte]string + cache cache.Cache + cacheTime uint +} + +func hostnameToHashes(host string) map[[32]byte]string { + hashes := map[[32]byte]string{} tld, icann := publicsuffix.PublicSuffix(host) if !icann { // private suffixes like cloudfront.net @@ -190,8 +145,7 @@ func hostnameToHashParam(host string) (string, map[string]bool) { } sum := sha256.Sum256([]byte(curhost)) - hashes[hex.EncodeToString(sum[:])] = true - hashparam.WriteString(fmt.Sprintf("%s.", hex.EncodeToString(sum[0:4]))) + hashes[sum] = curhost pos := strings.IndexByte(curhost, byte('.')) if pos < 0 { @@ -199,26 +153,91 @@ func hostnameToHashParam(host string) (string, map[string]bool) { } curhost = curhost[pos+1:] } - return hashparam.String(), hashes + return hashes +} + +// convert hash array to string +func (c *sbCtx) getQuestion() string { + q := "" + for hash := range c.hashToHost { + q += fmt.Sprintf("%s.", hex.EncodeToString(hash[0:2])) + } + if c.svc == "SafeBrowsing" { + q += sbTXTSuffix + } else { + q += pcTXTSuffix + } + return q } // Find the target hash in TXT response -func (d *Dnsfilter) processTXT(svc, host string, resp *dns.Msg, hashes map[string]bool) bool { +func (c *sbCtx) processTXT(resp *dns.Msg) (bool, [][]byte) { + matched := false + hashes := [][]byte{} for _, a := range resp.Answer { txt, ok := a.(*dns.TXT) if !ok { continue } - log.Tracef("%s: hashes for %s: %v", svc, host, txt.Txt) + log.Debug("%s: received hashes for %s: %v", c.svc, c.host, txt.Txt) + for _, t := range txt.Txt { - _, ok := hashes[t] - if ok { - log.Tracef("%s: matched %s by %s", svc, host, t) - return true + + if len(t) != 32*2 { + continue + } + hash, err := hex.DecodeString(t) + if err != nil { + continue + } + + hashes = append(hashes, hash) + + if !matched { + var hash32 [32]byte + copy(hash32[:], hash) + hashHost, ok := c.hashToHost[hash32] + if ok { + log.Debug("%s: matched %s by %s/%s", c.svc, c.host, hashHost, t) + matched = true + } } } } - return false + + return matched, hashes +} + +func (c *sbCtx) storeCache(hashes [][]byte) { + sort.Slice(hashes, func(a, b int) bool { + return bytes.Compare(hashes[a], hashes[b]) < 0 + }) + + var curData []byte + var prevPrefix []byte + for i, hash := range hashes { + prefix := hash[0:2] + if !bytes.Equal(prefix, prevPrefix) { + if i != 0 { + c.setCache(prevPrefix, curData) + curData = nil + } + prevPrefix = hashes[i][0:2] + } + curData = append(curData, hash...) + } + + if len(prevPrefix) != 0 { + c.setCache(prevPrefix, curData) + } + + for hash := range c.hashToHost { + prefix := hash[0:2] + val := c.cache.Get(prefix) + if val == nil { + c.setCache(prefix, nil) + } + } } // Disabling "dupl": the algorithm of SB/PC is similar, but it uses different data @@ -229,18 +248,29 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { defer timer.LogElapsed("SafeBrowsing lookup for %s", host) } - // check cache - cachedValue, isFound := getCachedResult(gctx.safebrowsingCache, host) - if isFound { - // atomic.AddUint64(&gctx.stats.Safebrowsing.CacheHits, 1) - log.Tracef("SafeBrowsing: found in cache: %s", host) - return cachedValue, nil + result := Result{} + hashes := hostnameToHashes(host) + + c := &sbCtx{ + host: host, + svc: "SafeBrowsing", + hashToHost: hashes, + cache: gctx.safebrowsingCache, + cacheTime: d.Config.CacheTime, } - result := Result{} - question, hashes := hostnameToHashParam(host) - question = question + sbTXTSuffix + // check cache + match := c.getCached() + if match < 0 { + return result, nil + } else if match > 0 { + result.IsFiltered = true + result.Reason = FilteredSafeBrowsing + result.Rule = "adguard-malware-shavar" + return result, nil + } + question := c.getQuestion() log.Tracef("SafeBrowsing: checking %s: %s", host, question) req := dns.Msg{} @@ -250,14 +280,14 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { return result, err } - if d.processTXT("SafeBrowsing", host, resp, hashes) { + matched, receivedHashes := c.processTXT(resp) + if matched { result.IsFiltered = true result.Reason = FilteredSafeBrowsing result.Rule = "adguard-malware-shavar" } + c.storeCache(receivedHashes) - valLen := d.setCacheResult(gctx.safebrowsingCache, host, result) - log.Debug("SafeBrowsing: stored in cache: %s (%d bytes)", host, valLen) return result, nil } @@ -269,18 +299,29 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) { defer timer.LogElapsed("Parental lookup for %s", host) } - // check cache - cachedValue, isFound := getCachedResult(gctx.parentalCache, host) - if isFound { - // atomic.AddUint64(&gctx.stats.Parental.CacheHits, 1) - log.Tracef("Parental: found in cache: %s", host) - return cachedValue, nil + result := Result{} + hashes := hostnameToHashes(host) + + c := &sbCtx{ + host: host, + svc: "Parental", + hashToHost: hashes, + cache: gctx.parentalCache, + cacheTime: d.Config.CacheTime, } - result := Result{} - question, hashes := hostnameToHashParam(host) - question = question + pcTXTSuffix + // check cache + match := c.getCached() + if match < 0 { + return result, nil + } else if match > 0 { + result.IsFiltered = true + result.Reason = FilteredParental + result.Rule = "parental CATEGORY_BLACKLISTED" + return result, nil + } + question := c.getQuestion() log.Tracef("Parental: checking %s: %s", host, question) req := dns.Msg{} @@ -290,14 +331,14 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) { return result, err } - if d.processTXT("Parental", host, resp, hashes) { + matched, receivedHashes := c.processTXT(resp) + if matched { result.IsFiltered = true result.Reason = FilteredParental result.Rule = "parental CATEGORY_BLACKLISTED" } + c.storeCache(receivedHashes) - valLen := d.setCacheResult(gctx.parentalCache, host, result) - log.Debug("Parental: stored in cache: %s (%d bytes)", host, valLen) return result, err } @@ -362,34 +403,6 @@ func (d *Dnsfilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) } } -func (d *Dnsfilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { - d.Config.SafeSearchEnabled = true - d.Config.ConfigModified() -} - -func (d *Dnsfilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { - d.Config.SafeSearchEnabled = false - d.Config.ConfigModified() -} - -func (d *Dnsfilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { - data := map[string]interface{}{ - "enabled": d.Config.SafeSearchEnabled, - } - jsonVal, err := json.Marshal(data) - if err != nil { - httpError(r, w, http.StatusInternalServerError, "Unable to marshal status json: %s", err) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err) - return - } -} - func (d *Dnsfilter) registerSecurityHandlers() { d.Config.HTTPRegister("POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable) d.Config.HTTPRegister("POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable) diff --git a/dnsfilter/sb_pc_test.go b/dnsfilter/sb_pc_test.go new file mode 100644 index 00000000..4d4d4d04 --- /dev/null +++ b/dnsfilter/sb_pc_test.go @@ -0,0 +1,91 @@ +package dnsfilter + +import ( + "crypto/sha256" + "strings" + "testing" + + "github.com/AdguardTeam/golibs/cache" + "github.com/stretchr/testify/assert" +) + +func TestSafeBrowsingHash(t *testing.T) { + // test hostnameToHashes() + hashes := hostnameToHashes("1.2.3.sub.host.com") + assert.Equal(t, 3, len(hashes)) + _, ok := hashes[sha256.Sum256([]byte("3.sub.host.com"))] + assert.True(t, ok) + _, ok = hashes[sha256.Sum256([]byte("sub.host.com"))] + assert.True(t, ok) + _, ok = hashes[sha256.Sum256([]byte("host.com"))] + assert.True(t, ok) + _, ok = hashes[sha256.Sum256([]byte("com"))] + assert.False(t, ok) + + c := &sbCtx{ + svc: "SafeBrowsing", + } + + // test getQuestion() + c.hashToHost = hashes + q := c.getQuestion() + assert.True(t, strings.Index(q, "7a1b.") >= 0) + assert.True(t, strings.Index(q, "af5a.") >= 0) + assert.True(t, strings.Index(q, "eb11.") >= 0) + assert.True(t, strings.Index(q, "sb.dns.adguard.com.") > 0) +} + +func TestSafeBrowsingCache(t *testing.T) { + c := &sbCtx{ + svc: "SafeBrowsing", + cacheTime: 100, + } + conf := cache.Config{} + c.cache = cache.New(conf) + + // store in cache hashes for "3.sub.host.com" and "host.com" + // and empty data for hash-prefix for "sub.host.com" + hash := sha256.Sum256([]byte("sub.host.com")) + c.hashToHost = make(map[[32]byte]string) + c.hashToHost[hash] = "sub.host.com" + var hashesArray [][]byte + hash4 := sha256.Sum256([]byte("3.sub.host.com")) + hashesArray = append(hashesArray, hash4[:]) + hash2 := sha256.Sum256([]byte("host.com")) + hashesArray = append(hashesArray, hash2[:]) + c.storeCache(hashesArray) + + // match "3.sub.host.com" or "host.com" from cache + c.hashToHost = make(map[[32]byte]string) + hash = sha256.Sum256([]byte("3.sub.host.com")) + c.hashToHost[hash] = "3.sub.host.com" + hash = sha256.Sum256([]byte("sub.host.com")) + c.hashToHost[hash] = "sub.host.com" + hash = sha256.Sum256([]byte("host.com")) + c.hashToHost[hash] = "host.com" + assert.Equal(t, 1, c.getCached()) + + // match "sub.host.com" from cache + c.hashToHost = make(map[[32]byte]string) + hash = sha256.Sum256([]byte("sub.host.com")) + c.hashToHost[hash] = "sub.host.com" + assert.Equal(t, -1, c.getCached()) + + // match "sub.host.com" from cache, + // but another hash for "nonexisting.com" is not in cache + // which means that we must get data from server for it + c.hashToHost = make(map[[32]byte]string) + hash = sha256.Sum256([]byte("sub.host.com")) + c.hashToHost[hash] = "sub.host.com" + hash = sha256.Sum256([]byte("nonexisting.com")) + c.hashToHost[hash] = "nonexisting.com" + assert.Equal(t, 0, c.getCached()) + + hash = sha256.Sum256([]byte("sub.host.com")) + _, ok := c.hashToHost[hash] + assert.False(t, ok) + + hash = sha256.Sum256([]byte("nonexisting.com")) + _, ok = c.hashToHost[hash] + assert.True(t, ok) +} diff --git a/dnsforward/config.go b/dnsforward/config.go index db4a043f..24ed3f7f 100644 --- a/dnsforward/config.go +++ b/dnsforward/config.go @@ -133,8 +133,8 @@ var defaultValues = ServerConfig{ // createProxyConfig creates and validates configuration for the main proxy func (s *Server) createProxyConfig() (proxy.Config, error) { proxyConfig := proxy.Config{ - UDPListenAddr: s.conf.UDPListenAddr, - TCPListenAddr: s.conf.TCPListenAddr, + UDPListenAddr: []*net.UDPAddr{s.conf.UDPListenAddr}, + TCPListenAddr: []*net.TCPAddr{s.conf.TCPListenAddr}, Ratelimit: int(s.conf.Ratelimit), RatelimitWhitelist: s.conf.RatelimitWhitelist, RefuseAny: s.conf.RefuseAny, @@ -229,7 +229,7 @@ func (s *Server) prepareIntlProxy() { // prepareTLS - prepares TLS configuration for the DNS proxy func (s *Server) prepareTLS(proxyConfig *proxy.Config) error { if s.conf.TLSListenAddr != nil && len(s.conf.CertificateChainData) != 0 && len(s.conf.PrivateKeyData) != 0 { - proxyConfig.TLSListenAddr = s.conf.TLSListenAddr + proxyConfig.TLSListenAddr = []*net.TCPAddr{s.conf.TLSListenAddr} var err error s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData) if err != nil { diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 0ca5549e..051178cd 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -252,10 +252,6 @@ func TestBlockedRequest(t *testing.T) { func TestServerCustomClientUpstream(t *testing.T) { s := createTestServer(t) - err := s.Start() - if err != nil { - t.Fatalf("Failed to start server: %s", err) - } s.conf.GetCustomUpstreamByClient = func(clientAddr string) *proxy.UpstreamConfig { uc := &proxy.UpstreamConfig{} u := &testUpstream{} @@ -264,6 +260,9 @@ func TestServerCustomClientUpstream(t *testing.T) { uc.Upstreams = append(uc.Upstreams, u) return uc } + + assert.Nil(t, s.Start()) + addr := s.dnsProxy.Addr(proxy.ProtoUDP) // Send test request diff --git a/go.mod b/go.mod index 49902aaf..a9f5a605 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome go 1.14 require ( - github.com/AdguardTeam/dnsproxy v0.29.1 + github.com/AdguardTeam/dnsproxy v0.30.1 github.com/AdguardTeam/golibs v0.4.2 github.com/AdguardTeam/urlfilter v0.11.2 github.com/NYTimes/gziphandler v1.1.1 diff --git a/go.sum b/go.sum index 18a415ee..4271f61e 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/dnsproxy v0.29.1 h1:Stc+JLh67C9K38vbrH2920+3FnbXKkFzYQqRiu5auUo= -github.com/AdguardTeam/dnsproxy v0.29.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58= +github.com/AdguardTeam/dnsproxy v0.30.1 h1:SnsL5kM/eFTrtLLdww1EePOhVDZTWzMkse+5tadGhvc= +github.com/AdguardTeam/dnsproxy v0.30.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=