+ dnsfilter: cache IP addresses of safebrowsing and parental control servers

This commit is contained in:
Simon Zolin 2019-05-13 14:16:07 +03:00
parent a45f0c519e
commit 24ae61de3e
1 changed files with 42 additions and 2 deletions

View File

@ -157,6 +157,7 @@ const (
// these variables need to survive coredns reload
var (
stats Stats
securityCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
safebrowsingCache gcache.Cache
parentalCache gcache.Cache
safeSearchCache gcache.Cache
@ -972,10 +973,34 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) {
// lifecycle helper functions
//
// Return TRUE if this host's IP should be cached
func (d *Dnsfilter) shouldCache(host string) bool {
return host == d.safeBrowsingServer ||
host == d.parentalServer
}
// Search for an IP address by host name
func searchInCache(host string) string {
rawValue, err := securityCache.Get(host)
if err != nil {
return ""
}
ip, _ := rawValue.(string)
log.Debug("Found in cache: %s -> %s", host, ip)
return ip
}
// Add "hostname" -> "IP address" entry to cache
func addToCache(host, ip string) {
securityCache.Set(host, ip)
log.Debug("Added to cache: %s -> %s", host, ip)
}
type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error)
// Connect to a remote server resolving hostname using our own DNS server
func createCustomDialContext(resolverAddr string) dialFunctionType {
func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionType {
return func(ctx context.Context, network, addr string) (net.Conn, error) {
log.Tracef("network:%v addr:%v", network, addr)
@ -993,6 +1018,15 @@ func createCustomDialContext(resolverAddr string) dialFunctionType {
return con, err
}
cache := d.shouldCache(host)
if cache {
ip := searchInCache(host)
if len(ip) != 0 {
addr = fmt.Sprintf("%s:%s", ip, port)
return dialer.DialContext(ctx, network, addr)
}
}
r := upstream.NewResolver(resolverAddr, 30*time.Second)
addrs, e := r.LookupIPAddr(ctx, host)
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
@ -1011,6 +1045,11 @@ func createCustomDialContext(resolverAddr string) dialFunctionType {
}
continue
}
if cache {
addToCache(host, a.String())
}
return con, err
}
return nil, firstErr
@ -1037,7 +1076,8 @@ func New(c *Config) *Dnsfilter {
ExpectContinueTimeout: 1 * time.Second,
}
if c != nil && len(c.ResolverAddress) != 0 {
d.transport.DialContext = createCustomDialContext(c.ResolverAddress)
securityCache = gcache.New(2).LRU().Expiration(defaultCacheTime).Build()
d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress)
}
d.client = http.Client{
Transport: d.transport,