+ dnsfilter: use AG DNS server for SB/PC services

* move SS/SB/PC services to security.go
* remove old useless code (HTTP client)
This commit is contained in:
Simon Zolin 2019-10-16 12:57:49 +03:00
parent 2f1e631c66
commit 7f69848084
6 changed files with 308 additions and 602 deletions

View File

@ -1,13 +1,7 @@
package dnsfilter package dnsfilter
import ( import (
"bufio"
"bytes" "bytes"
"context"
"crypto/sha256"
"encoding/binary"
"encoding/gob"
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"net" "net"
@ -16,30 +10,14 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time"
"github.com/joomcode/errorx"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter"
"github.com/bluele/gcache"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/publicsuffix"
) )
const defaultHTTPTimeout = 5 * time.Minute
const defaultHTTPMaxIdleConnections = 100
const defaultSafebrowsingServer = "sb.adtidy.org"
const defaultSafebrowsingURL = "%s://%s/safebrowsing-lookup-hash.html?prefixes=%s"
const defaultParentalServer = "pctrl.adguard.com"
const defaultParentalURL = "%s://%s/check-parental-control-hash?prefixes=%s&sensitivity=%d"
const defaultParentalSensitivity = 13 // use "TEEN" by default
const maxDialCacheSize = 2 // the number of host names for safebrowsing and parental control
// ServiceEntry - blocked service array element // ServiceEntry - blocked service array element
type ServiceEntry struct { type ServiceEntry struct {
Name string Name string
@ -65,7 +43,6 @@ type RewriteEntry struct {
type Config struct { type Config struct {
ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17 ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17
ParentalEnabled bool `yaml:"parental_enabled"` ParentalEnabled bool `yaml:"parental_enabled"`
UsePlainHTTP bool `yaml:"-"` // use plain HTTP for requests to parental and safe browsing servers
SafeSearchEnabled bool `yaml:"safesearch_enabled"` SafeSearchEnabled bool `yaml:"safesearch_enabled"`
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
ResolverAddress string `yaml:"-"` // DNS server address ResolverAddress string `yaml:"-"` // DNS server address
@ -110,12 +87,10 @@ type Dnsfilter struct {
filteringEngine *urlfilter.DNSEngine filteringEngine *urlfilter.DNSEngine
engineLock sync.RWMutex engineLock sync.RWMutex
// HTTP lookups for safebrowsing and parental parentalServer string // access via methods
client http.Client // handle for http client -- single instance as recommended by docs safeBrowsingServer string // access via methods
transport *http.Transport // handle for http transport used by http client parentalUpstream upstream.Upstream
safeBrowsingUpstream upstream.Upstream
parentalServer string // access via methods
safeBrowsingServer string // access via methods
Config // for direct access by library users, even a = assignment Config // for direct access by library users, even a = assignment
confLock sync.RWMutex confLock sync.RWMutex
@ -251,9 +226,6 @@ func (d *Dnsfilter) filtersInitializer() {
// Close - close the object // Close - close the object
func (d *Dnsfilter) Close() { func (d *Dnsfilter) Close() {
if d != nil && d.transport != nil {
d.transport.CloseIdleConnections()
}
if d.rulesStorage != nil { if d.rulesStorage != nil {
d.rulesStorage.Close() d.rulesStorage.Close()
} }
@ -261,7 +233,6 @@ func (d *Dnsfilter) Close() {
type dnsFilterContext struct { type dnsFilterContext struct {
stats Stats stats Stats
dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
safebrowsingCache cache.Cache safebrowsingCache cache.Cache
parentalCache cache.Cache parentalCache cache.Cache
safeSearchCache cache.Cache safeSearchCache cache.Cache
@ -328,11 +299,10 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
} }
} }
// check safeSearch if no match
if setts.SafeSearchEnabled { if setts.SafeSearchEnabled {
result, err = d.checkSafeSearch(host) result, err = d.checkSafeSearch(host)
if err != nil { if err != nil {
log.Printf("Failed to safesearch HTTP lookup, ignoring check: %v", err) log.Info("SafeSearch: failed: %v", err)
return Result{}, nil return Result{}, nil
} }
@ -341,12 +311,10 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
} }
} }
// check safebrowsing if no match
if setts.SafeBrowsingEnabled { if setts.SafeBrowsingEnabled {
result, err = d.checkSafeBrowsing(host) result, err = d.checkSafeBrowsing(host)
if err != nil { if err != nil {
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache log.Info("SafeBrowsing: failed: %v", err)
log.Printf("Failed to do safebrowsing HTTP lookup, ignoring check: %v", err)
return Result{}, nil return Result{}, nil
} }
if result.Reason.Matched() { if result.Reason.Matched() {
@ -354,12 +322,10 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
} }
} }
// check parental if no match
if setts.ParentalEnabled { if setts.ParentalEnabled {
result, err = d.checkParental(host) result, err = d.checkParental(host)
if err != nil { if err != nil {
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache log.Printf("Parental: failed: %v", err)
log.Printf("Failed to do parental HTTP lookup, ignoring check: %v", err)
return Result{}, nil return Result{}, nil
} }
if result.Reason.Matched() { if result.Reason.Matched() {
@ -367,7 +333,6 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering
} }
} }
// nothing matched, return nothing
return Result{}, nil return Result{}, nil
} }
@ -445,311 +410,6 @@ func matchBlockedServicesRules(host string, svcs []ServiceEntry) Result {
return res return res
} }
/*
expire byte[4]
res Result
*/
func (d *Dnsfilter) setCacheResult(cache cache.Cache, host string, res Result) {
var buf bytes.Buffer
expire := uint(time.Now().Unix()) + d.Config.CacheTime*60
var exp []byte
exp = make([]byte, 4)
binary.BigEndian.PutUint32(exp, uint32(expire))
_, _ = buf.Write(exp)
enc := gob.NewEncoder(&buf)
err := enc.Encode(res)
if err != nil {
log.Error("gob.Encode(): %s", err)
return
}
_ = cache.Set([]byte(host), buf.Bytes())
log.Debug("Stored in cache %p: %s", cache, host)
}
func getCachedResult(cache cache.Cache, host string) (Result, bool) {
data := cache.Get([]byte(host))
if data == nil {
return Result{}, false
}
exp := int(binary.BigEndian.Uint32(data[:4]))
if exp <= int(time.Now().Unix()) {
cache.Del([]byte(host))
return Result{}, false
}
var buf bytes.Buffer
buf.Write(data[4:])
dec := gob.NewDecoder(&buf)
r := Result{}
err := dec.Decode(&r)
if err != nil {
log.Debug("gob.Decode(): %s", err)
return Result{}, false
}
return r, true
}
// for each dot, hash it and add it to string
func hostnameToHashParam(host string, addslash bool) (string, map[string]bool) {
var hashparam bytes.Buffer
hashes := map[string]bool{}
tld, icann := publicsuffix.PublicSuffix(host)
if !icann {
// private suffixes like cloudfront.net
tld = ""
}
curhost := host
for {
if curhost == "" {
// we've reached end of string
break
}
if tld != "" && curhost == tld {
// we've reached the TLD, don't hash it
break
}
tohash := []byte(curhost)
if addslash {
tohash = append(tohash, '/')
}
sum := sha256.Sum256(tohash)
hexhash := fmt.Sprintf("%X", sum)
hashes[hexhash] = true
hashparam.WriteString(fmt.Sprintf("%02X%02X%02X%02X/", sum[0], sum[1], sum[2], sum[3]))
pos := strings.IndexByte(curhost, byte('.'))
if pos < 0 {
break
}
curhost = curhost[pos+1:]
}
return hashparam.String(), hashes
}
func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host)
}
// Check cache. Return cached result if it was found
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host)
if isFound {
atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
log.Tracef("%s: found in SafeSearch cache", host)
return cachedValue, nil
}
safeHost, ok := d.SafeSearchDomain(host)
if !ok {
return Result{}, nil
}
res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
if ip := net.ParseIP(safeHost); ip != nil {
res.IP = ip
d.setCacheResult(gctx.safeSearchCache, host, res)
return res, nil
}
// TODO this address should be resolved with upstream that was configured in dnsforward
addrs, err := net.LookupIP(safeHost)
if err != nil {
log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err)
return Result{}, err
}
for _, i := range addrs {
if ipv4 := i.To4(); ipv4 != nil {
res.IP = ipv4
break
}
}
if len(res.IP) == 0 {
return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost)
}
// Cache result
d.setCacheResult(gctx.safeSearchCache, host, res)
return res, nil
}
func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeBrowsing HTTP lookup for %s", host)
}
format := func(hashparam string) string {
schema := "https"
if d.UsePlainHTTP {
schema = "http"
}
url := fmt.Sprintf(defaultSafebrowsingURL, schema, d.safeBrowsingServer, hashparam)
return url
}
handleBody := func(body []byte, hashes map[string]bool) (Result, error) {
result := Result{}
scanner := bufio.NewScanner(strings.NewReader(string(body)))
for scanner.Scan() {
line := scanner.Text()
splitted := strings.Split(line, ":")
if len(splitted) < 3 {
continue
}
hash := splitted[2]
if _, ok := hashes[hash]; ok {
// it's in the hash
result.IsFiltered = true
result.Reason = FilteredSafeBrowsing
result.Rule = splitted[0]
break
}
}
if err := scanner.Err(); err != nil {
// error, don't save cache
return Result{}, err
}
return result, nil
}
// check cache
cachedValue, isFound := getCachedResult(gctx.safebrowsingCache, host)
if isFound {
atomic.AddUint64(&gctx.stats.Safebrowsing.CacheHits, 1)
log.Tracef("%s: found in the lookup cache %p", host, gctx.safebrowsingCache)
return cachedValue, nil
}
result, err := d.lookupCommon(host, &gctx.stats.Safebrowsing, true, format, handleBody)
if err == nil {
d.setCacheResult(gctx.safebrowsingCache, host, result)
}
return result, err
}
func (d *Dnsfilter) checkParental(host string) (Result, error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("Parental HTTP lookup for %s", host)
}
format := func(hashparam string) string {
schema := "https"
if d.UsePlainHTTP {
schema = "http"
}
sensitivity := d.ParentalSensitivity
if sensitivity == 0 {
sensitivity = defaultParentalSensitivity
}
url := fmt.Sprintf(defaultParentalURL, schema, d.parentalServer, hashparam, sensitivity)
return url
}
handleBody := func(body []byte, hashes map[string]bool) (Result, error) {
// parse json
var m []struct {
Blocked bool `json:"blocked"`
ClientTTL int `json:"clientTtl"`
Reason string `json:"reason"`
Hash string `json:"hash"`
}
err := json.Unmarshal(body, &m)
if err != nil {
// error, don't save cache
log.Printf("Couldn't parse json '%s': %s", body, err)
return Result{}, err
}
result := Result{}
for i := range m {
if !hashes[m[i].Hash] {
continue
}
if m[i].Blocked {
result.IsFiltered = true
result.Reason = FilteredParental
result.Rule = fmt.Sprintf("parental %s", m[i].Reason)
break
}
}
return result, nil
}
// check cache
cachedValue, isFound := getCachedResult(gctx.parentalCache, host)
if isFound {
atomic.AddUint64(&gctx.stats.Parental.CacheHits, 1)
log.Tracef("%s: found in the lookup cache %p", host, gctx.parentalCache)
return cachedValue, nil
}
result, err := d.lookupCommon(host, &gctx.stats.Parental, false, format, handleBody)
if err == nil {
d.setCacheResult(gctx.parentalCache, host, result)
}
return result, err
}
type formatHandler func(hashparam string) string
type bodyHandler func(body []byte, hashes map[string]bool) (Result, error)
// real implementation of lookup/check
func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, hashparamNeedSlash bool, format formatHandler, handleBody bodyHandler) (Result, error) {
// convert hostname to hash parameters
hashparam, hashes := hostnameToHashParam(host, hashparamNeedSlash)
// format URL with our hashes
url := format(hashparam)
// do HTTP request
atomic.AddUint64(&lookupstats.Requests, 1)
atomic.AddInt64(&lookupstats.Pending, 1)
updateMax(&lookupstats.Pending, &lookupstats.PendingMax)
resp, err := d.client.Get(url)
atomic.AddInt64(&lookupstats.Pending, -1)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
if err != nil {
// error, don't save cache
return Result{}, err
}
// get body text
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
// error, don't save cache
return Result{}, err
}
// handle status code
switch {
case resp.StatusCode == 204:
// empty result, save cache
return Result{}, nil
case resp.StatusCode != 200:
return Result{}, fmt.Errorf("HTTP status code: %d", resp.StatusCode)
}
result, err := handleBody(body, hashes)
if err != nil {
return Result{}, err
}
return result, nil
}
// //
// Adding rule and matching against the rules // Adding rule and matching against the rules
// //
@ -887,97 +547,6 @@ func (d *Dnsfilter) matchHost(host string, qtype uint16) (Result, error) {
return Result{}, nil return Result{}, nil
} }
//
// lifecycle helper functions
//
// Return TRUE if this host's IP should be cached
func (d *Dnsfilter) shouldBeInDialCache(host string) bool {
return host == d.safeBrowsingServer ||
host == d.parentalServer
}
// Search for an IP address by host name
func searchInDialCache(host string) string {
rawValue, err := gctx.dialCache.Get(host)
if err != nil {
return ""
}
ip, _ := rawValue.(string)
log.Debug("Found in cache: %s -> %s", host, ip)
return ip
}
// Add "hostname" -> "IP address" entry to cache
func addToDialCache(host, ip string) {
err := gctx.dialCache.Set(host, ip)
if err != nil {
log.Debug("dialCache.Set: %s", err)
}
log.Debug("Added to cache: %s -> %s", host, ip)
}
type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error)
// Connect to a remote server resolving hostname using our own DNS server
func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionType {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr)
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
dialer := &net.Dialer{
Timeout: time.Minute * 5,
}
if net.ParseIP(host) != nil {
con, err := dialer.DialContext(ctx, network, addr)
return con, err
}
cache := d.shouldBeInDialCache(host)
if cache {
ip := searchInDialCache(host)
if len(ip) != 0 {
addr = fmt.Sprintf("%s:%s", ip, port)
return dialer.DialContext(ctx, network, addr)
}
}
r := upstream.NewResolver(resolverAddr, 30*time.Second)
addrs, e := r.LookupIPAddr(ctx, host)
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
if e != nil {
return nil, e
}
if len(addrs) == 0 {
return nil, fmt.Errorf("couldn't lookup host: %s", host)
}
var dialErrs []error
for _, a := range addrs {
addr = fmt.Sprintf("%s:%s", a.String(), port)
con, err := dialer.DialContext(ctx, network, addr)
if err != nil {
dialErrs = append(dialErrs, err)
continue
}
if cache {
addToDialCache(host, a.String())
}
return con, err
}
return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
}
}
// New creates properly initialized DNS Filter that is ready to be used // New creates properly initialized DNS Filter that is ready to be used
func New(c *Config, filters map[int]string) *Dnsfilter { func New(c *Config, filters map[int]string) *Dnsfilter {
@ -1002,34 +571,16 @@ func New(c *Config, filters map[int]string) *Dnsfilter {
cacheConf.MaxSize = c.ParentalCacheSize cacheConf.MaxSize = c.ParentalCacheSize
gctx.parentalCache = cache.New(cacheConf) gctx.parentalCache = cache.New(cacheConf)
} }
if len(c.ResolverAddress) != 0 && gctx.dialCache == nil {
dur := time.Duration(c.CacheTime) * time.Minute
gctx.dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(dur).Build()
}
} }
d := new(Dnsfilter) d := new(Dnsfilter)
// Customize the Transport to have larger connection pool, err := d.initSecurityServices()
// We are not (re)using http.DefaultTransport because of race conditions found by tests if err != nil {
d.transport = &http.Transport{ log.Error("dnsfilter: initialize services: %s", err)
Proxy: http.ProxyFromEnvironment, return nil
MaxIdleConns: defaultHTTPMaxIdleConnections, // default 100
MaxIdleConnsPerHost: defaultHTTPMaxIdleConnections, // default 2
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
} }
if c != nil && len(c.ResolverAddress) != 0 {
d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress)
}
d.client = http.Client{
Transport: d.transport,
Timeout: defaultHTTPTimeout,
}
d.safeBrowsingServer = defaultSafebrowsingServer
d.parentalServer = defaultParentalServer
if c != nil { if c != nil {
d.Config = *c d.Config = *c
} }
@ -1053,38 +604,6 @@ func New(c *Config, filters map[int]string) *Dnsfilter {
return d return d
} }
//
// config manipulation helpers
//
// SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup
func (d *Dnsfilter) SetSafeBrowsingServer(host string) {
if len(host) == 0 {
d.safeBrowsingServer = defaultSafebrowsingServer
} else {
d.safeBrowsingServer = host
}
}
// SetHTTPTimeout lets you optionally change timeout during lookups
func (d *Dnsfilter) SetHTTPTimeout(t time.Duration) {
d.client.Timeout = t
}
// ResetHTTPTimeout resets lookup timeouts
func (d *Dnsfilter) ResetHTTPTimeout() {
d.client.Timeout = defaultHTTPTimeout
}
// SafeSearchDomain returns replacement address for search engine
func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
if d.SafeSearchEnabled {
val, ok := safeSearchDomains[host]
return val, ok
}
return "", false
}
// //
// stats // stats
// //

View File

@ -3,15 +3,11 @@ package dnsfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/http"
"net/http/httptest"
"path" "path"
"runtime" "runtime"
"testing" "testing"
"time"
"github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter"
"github.com/bluele/gcache"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -23,7 +19,6 @@ var setts RequestFilteringSettings
// SAFE SEARCH // SAFE SEARCH
// PARENTAL // PARENTAL
// FILTERING // FILTERING
// CLIENTS SETTINGS
// BENCHMARKS // BENCHMARKS
// HELPERS // HELPERS
@ -126,34 +121,19 @@ func TestEtcHostsMatching(t *testing.T) {
// SAFE BROWSING // SAFE BROWSING
func TestSafeBrowsing(t *testing.T) { func TestSafeBrowsing(t *testing.T) {
testCases := []string{ d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil)
"", defer d.Close()
"sb.adtidy.org", gctx.stats.Safebrowsing.Requests = 0
} d.checkMatch(t, "wmconvirus.narod.ru")
for _, tc := range testCases { d.checkMatch(t, "test.wmconvirus.narod.ru")
t.Run(fmt.Sprintf("%s in %s", tc, _Func()), func(t *testing.T) { d.checkMatchEmpty(t, "yandex.ru")
d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil) d.checkMatchEmpty(t, "pornhub.com")
defer d.Close()
gctx.stats.Safebrowsing.Requests = 0 // test cached result
d.checkMatch(t, "wmconvirus.narod.ru") d.safeBrowsingServer = "127.0.0.1"
d.checkMatch(t, "wmconvirus.narod.ru") d.checkMatch(t, "wmconvirus.narod.ru")
if gctx.stats.Safebrowsing.Requests != 1 { d.checkMatchEmpty(t, "pornhub.com")
t.Errorf("Safebrowsing lookup positive cache is not working: %v", gctx.stats.Safebrowsing.Requests) d.safeBrowsingServer = defaultSafebrowsingServer
}
d.checkMatch(t, "WMconvirus.narod.ru")
if gctx.stats.Safebrowsing.Requests != 1 {
t.Errorf("Safebrowsing lookup positive cache is not working: %v", gctx.stats.Safebrowsing.Requests)
}
d.checkMatch(t, "test.wmconvirus.narod.ru")
d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "pornhub.com")
l := gctx.stats.Safebrowsing.Requests
d.checkMatchEmpty(t, "pornhub.com")
if gctx.stats.Safebrowsing.Requests != l {
t.Errorf("Safebrowsing lookup negative cache is not working: %v", gctx.stats.Safebrowsing.Requests)
}
})
}
} }
func TestParallelSB(t *testing.T) { func TestParallelSB(t *testing.T) {
@ -172,33 +152,10 @@ func TestParallelSB(t *testing.T) {
}) })
} }
// the only way to verify that custom server option is working is to point it at a server that does serve safebrowsing
func TestSafeBrowsingCustomServerFail(t *testing.T) {
d := NewForTest(&Config{SafeBrowsingEnabled: true}, nil)
defer d.Close()
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// w.Write("Hello, client")
fmt.Fprintln(w, "Hello, client")
}))
defer ts.Close()
address := ts.Listener.Addr().String()
d.SetHTTPTimeout(time.Second * 5)
d.SetSafeBrowsingServer(address) // this will ensure that test fails
d.checkMatchEmpty(t, "wmconvirus.narod.ru")
}
// SAFE SEARCH // SAFE SEARCH
func TestSafeSearch(t *testing.T) { func TestSafeSearch(t *testing.T) {
d := NewForTest(nil, nil) d := NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Close()
_, ok := d.SafeSearchDomain("www.google.com")
if ok {
t.Errorf("Expected safesearch to error when disabled")
}
d = NewForTest(&Config{SafeSearchEnabled: true}, nil)
defer d.Close() defer d.Close()
val, ok := d.SafeSearchDomain("www.google.com") val, ok := d.SafeSearchDomain("www.google.com")
if !ok { if !ok {
@ -355,24 +312,16 @@ func TestParentalControl(t *testing.T) {
defer d.Close() defer d.Close()
d.ParentalSensitivity = 3 d.ParentalSensitivity = 3
d.checkMatch(t, "pornhub.com") d.checkMatch(t, "pornhub.com")
d.checkMatch(t, "pornhub.com")
if gctx.stats.Parental.Requests != 1 {
t.Errorf("Parental lookup positive cache is not working")
}
d.checkMatch(t, "PORNhub.com")
if gctx.stats.Parental.Requests != 1 {
t.Errorf("Parental lookup positive cache is not working")
}
d.checkMatch(t, "www.pornhub.com") d.checkMatch(t, "www.pornhub.com")
d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
l := gctx.stats.Parental.Requests
d.checkMatchEmpty(t, "yandex.ru")
if gctx.stats.Parental.Requests != l {
t.Errorf("Parental lookup negative cache is not working")
}
d.checkMatchEmpty(t, "api.jquery.com") d.checkMatchEmpty(t, "api.jquery.com")
// test cached result
d.parentalServer = "127.0.0.1"
d.checkMatch(t, "pornhub.com")
d.checkMatchEmpty(t, "yandex.ru")
d.parentalServer = defaultParentalServer
} }
// FILTERING // FILTERING
@ -588,17 +537,3 @@ func BenchmarkSafeSearchParallel(b *testing.B) {
} }
}) })
} }
func TestDnsfilterDialCache(t *testing.T) {
d := Dnsfilter{}
gctx.dialCache = gcache.New(1).LRU().Expiration(30 * time.Minute).Build()
d.shouldBeInDialCache("hostname")
if searchInDialCache("hostname") != "" {
t.Errorf("searchInDialCache")
}
addToDialCache("hostname", "1.1.1.1")
if searchInDialCache("hostname") != "1.1.1.1" {
t.Errorf("searchInDialCache")
}
}

View File

@ -1,20 +0,0 @@
package dnsfilter
import (
"sync/atomic"
)
func updateMax(valuePtr *int64, maxPtr *int64) {
for {
current := atomic.LoadInt64(valuePtr)
max := atomic.LoadInt64(maxPtr)
if current <= max {
break
}
swapped := atomic.CompareAndSwapInt64(maxPtr, max, current)
if swapped {
break
}
// swapping failed because value has changed after reading, try again
}
}

View File

@ -4,17 +4,290 @@ package dnsfilter
import ( import (
"bufio" "bufio"
"bytes"
"crypto/sha256"
"encoding/binary"
"encoding/gob"
"encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"golang.org/x/net/publicsuffix"
) )
const dnsTimeout = 3 * time.Second
const defaultSafebrowsingServer = "https://dns-family.adguard.com/dns-query"
const defaultParentalServer = "https://dns-family.adguard.com/dns-query"
const sbTXTSuffix = "sb.dns.adguard.com."
const pcTXTSuffix = "pc.dns.adguard.com."
func (d *Dnsfilter) initSecurityServices() error {
var err error
d.safeBrowsingServer = defaultSafebrowsingServer
d.parentalServer = defaultParentalServer
d.parentalUpstream, err = upstream.AddressToUpstream(d.parentalServer, upstream.Options{Timeout: dnsTimeout})
if err != nil {
return err
}
d.safeBrowsingUpstream, err = upstream.AddressToUpstream(d.safeBrowsingServer, upstream.Options{Timeout: dnsTimeout})
if err != nil {
return err
}
return nil
}
/*
expire byte[4]
res Result
*/
func (d *Dnsfilter) setCacheResult(cache cache.Cache, host string, res Result) int {
var buf bytes.Buffer
expire := uint(time.Now().Unix()) + d.Config.CacheTime*60
var exp []byte
exp = make([]byte, 4)
binary.BigEndian.PutUint32(exp, uint32(expire))
_, _ = buf.Write(exp)
enc := gob.NewEncoder(&buf)
err := enc.Encode(res)
if err != nil {
log.Error("gob.Encode(): %s", err)
return 0
}
val := buf.Bytes()
_ = cache.Set([]byte(host), val)
return len(val)
}
func getCachedResult(cache cache.Cache, host string) (Result, bool) {
data := cache.Get([]byte(host))
if data == nil {
return Result{}, false
}
exp := int(binary.BigEndian.Uint32(data[:4]))
if exp <= int(time.Now().Unix()) {
cache.Del([]byte(host))
return Result{}, false
}
var buf bytes.Buffer
buf.Write(data[4:])
dec := gob.NewDecoder(&buf)
r := Result{}
err := dec.Decode(&r)
if err != nil {
log.Debug("gob.Decode(): %s", err)
return Result{}, false
}
return r, true
}
// SafeSearchDomain returns replacement address for search engine
func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
val, ok := safeSearchDomains[host]
return val, ok
}
func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeSearch: lookup for %s", host)
}
// Check cache. Return cached result if it was found
cachedValue, isFound := getCachedResult(gctx.safeSearchCache, host)
if isFound {
// atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
log.Tracef("SafeSearch: found in cache: %s", host)
return cachedValue, nil
}
safeHost, ok := d.SafeSearchDomain(host)
if !ok {
return Result{}, nil
}
res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
if ip := net.ParseIP(safeHost); ip != nil {
res.IP = ip
len := d.setCacheResult(gctx.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, len)
return res, nil
}
// TODO this address should be resolved with upstream that was configured in dnsforward
addrs, err := net.LookupIP(safeHost)
if err != nil {
log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err)
return Result{}, err
}
for _, i := range addrs {
if ipv4 := i.To4(); ipv4 != nil {
res.IP = ipv4
break
}
}
if len(res.IP) == 0 {
return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost)
}
// Cache result
len := d.setCacheResult(gctx.safeSearchCache, host, res)
log.Debug("SafeSearch: stored in cache: %s (%d bytes)", host, len)
return res, nil
}
// for each dot, hash it and add it to string
func hostnameToHashParam(host string) (string, map[string]bool) {
var hashparam bytes.Buffer
hashes := map[string]bool{}
tld, icann := publicsuffix.PublicSuffix(host)
if !icann {
// private suffixes like cloudfront.net
tld = ""
}
curhost := host
for {
if curhost == "" {
// we've reached end of string
break
}
if tld != "" && curhost == tld {
// we've reached the TLD, don't hash it
break
}
sum := sha256.Sum256([]byte(curhost))
hashes[hex.EncodeToString(sum[:])] = true
hashparam.WriteString(fmt.Sprintf("%s.", hex.EncodeToString(sum[0:4])))
pos := strings.IndexByte(curhost, byte('.'))
if pos < 0 {
break
}
curhost = curhost[pos+1:]
}
return hashparam.String(), hashes
}
// Find the target hash in TXT response
func (d *Dnsfilter) processTXT(svc, host string, resp *dns.Msg, hashes map[string]bool) bool {
for _, a := range resp.Answer {
txt, ok := a.(*dns.TXT)
if !ok {
continue
}
log.Tracef("%s: hashes for %s: %v", svc, host, txt.Txt)
for _, t := range txt.Txt {
_, ok := hashes[t]
if ok {
log.Tracef("%s: matched %s by %s", svc, host, t)
return true
}
}
}
return false
}
// Disabling "dupl": the algorithm of SB/PC is similar, but it uses different data
// nolint:dupl
func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
}
// check cache
cachedValue, isFound := getCachedResult(gctx.safebrowsingCache, host)
if isFound {
// atomic.AddUint64(&gctx.stats.Safebrowsing.CacheHits, 1)
log.Tracef("SafeBrowsing: found in cache: %s", host)
return cachedValue, nil
}
result := Result{}
question, hashes := hostnameToHashParam(host)
question = question + sbTXTSuffix
log.Tracef("SafeBrowsing: checking %s: %s", host, question)
req := dns.Msg{}
req.SetQuestion(question, dns.TypeTXT)
resp, err := d.safeBrowsingUpstream.Exchange(&req)
if err != nil {
return result, err
}
if d.processTXT("SafeBrowsing", host, resp, hashes) {
result.IsFiltered = true
result.Reason = FilteredSafeBrowsing
result.Rule = "adguard-malware-shavar"
}
len := d.setCacheResult(gctx.safebrowsingCache, host, result)
log.Debug("SafeBrowsing: stored in cache: %s (%d bytes)", host, len)
return result, nil
}
// Disabling "dupl": the algorithm of SB/PC is similar, but it uses different data
// nolint:dupl
func (d *Dnsfilter) checkParental(host string) (Result, error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("Parental lookup for %s", host)
}
// check cache
cachedValue, isFound := getCachedResult(gctx.parentalCache, host)
if isFound {
// atomic.AddUint64(&gctx.stats.Parental.CacheHits, 1)
log.Tracef("Parental: found in cache: %s", host)
return cachedValue, nil
}
result := Result{}
question, hashes := hostnameToHashParam(host)
question = question + pcTXTSuffix
log.Tracef("Parental: checking %s: %s", host, question)
req := dns.Msg{}
req.SetQuestion(question, dns.TypeTXT)
resp, err := d.parentalUpstream.Exchange(&req)
if err != nil {
return result, err
}
if d.processTXT("Parental", host, resp, hashes) {
result.IsFiltered = true
result.Reason = FilteredParental
result.Rule = "parental CATEGORY_BLACKLISTED"
}
len := d.setCacheResult(gctx.parentalCache, host, result)
log.Debug("Parental: stored in cache: %s (%d bytes)", host, len)
return result, err
}
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...) text := fmt.Sprintf(format, args...)
log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text) log.Info("DNSFilter: %s %s: %s", r.Method, r.URL, text)
@ -170,9 +443,11 @@ func (d *Dnsfilter) registerSecurityHandlers() {
d.Config.HTTPRegister("POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable) d.Config.HTTPRegister("POST", "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
d.Config.HTTPRegister("POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable) d.Config.HTTPRegister("POST", "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
d.Config.HTTPRegister("GET", "/control/safebrowsing/status", d.handleSafeBrowsingStatus) d.Config.HTTPRegister("GET", "/control/safebrowsing/status", d.handleSafeBrowsingStatus)
d.Config.HTTPRegister("POST", "/control/parental/enable", d.handleParentalEnable) d.Config.HTTPRegister("POST", "/control/parental/enable", d.handleParentalEnable)
d.Config.HTTPRegister("POST", "/control/parental/disable", d.handleParentalDisable) d.Config.HTTPRegister("POST", "/control/parental/disable", d.handleParentalDisable)
d.Config.HTTPRegister("GET", "/control/parental/status", d.handleParentalStatus) d.Config.HTTPRegister("GET", "/control/parental/status", d.handleParentalStatus)
d.Config.HTTPRegister("POST", "/control/safesearch/enable", d.handleSafeSearchEnable) d.Config.HTTPRegister("POST", "/control/safesearch/enable", d.handleSafeSearchEnable)
d.Config.HTTPRegister("POST", "/control/safesearch/disable", d.handleSafeSearchDisable) d.Config.HTTPRegister("POST", "/control/safesearch/disable", d.handleSafeSearchDisable)
d.Config.HTTPRegister("GET", "/control/safesearch/status", d.handleSafeSearchStatus) d.Config.HTTPRegister("GET", "/control/safesearch/status", d.handleSafeSearchStatus)

3
go.mod
View File

@ -7,9 +7,8 @@ require (
github.com/AdguardTeam/golibs v0.2.4 github.com/AdguardTeam/golibs v0.2.4
github.com/AdguardTeam/urlfilter v0.6.1 github.com/AdguardTeam/urlfilter v0.6.1
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/bluele/gcache v0.0.0-20190518031135-bc40bd653833
github.com/etcd-io/bbolt v1.3.3 github.com/etcd-io/bbolt v1.3.3
github.com/go-test/deep v1.0.4 github.com/go-test/deep v1.0.4 // indirect
github.com/gobuffalo/packr v1.19.0 github.com/gobuffalo/packr v1.19.0
github.com/joomcode/errorx v1.0.0 github.com/joomcode/errorx v1.0.0
github.com/kardianos/osext v0.0.0-20170510131534-ae77be60afb1 // indirect github.com/kardianos/osext v0.0.0-20170510131534-ae77be60afb1 // indirect

2
go.sum
View File

@ -27,8 +27,6 @@ github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4
github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY=
github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 h1:KXlsf+qt/X5ttPGEjR0tPH1xaWWoKBEg9Q1THAj2h3I= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 h1:KXlsf+qt/X5ttPGEjR0tPH1xaWWoKBEg9Q1THAj2h3I=
github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA=
github.com/bluele/gcache v0.0.0-20190518031135-bc40bd653833 h1:yCfXxYaelOyqnia8F/Yng47qhmfC9nKTRIbYRrRueq4=
github.com/bluele/gcache v0.0.0-20190518031135-bc40bd653833/go.mod h1:8c4/i2VlovMO2gBnHGQPN5EJw+H0lx1u/5p+cgsXtCk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=