* dnsfilter: use a single global context object

This commit is contained in:
Simon Zolin 2019-06-24 19:00:03 +03:00
parent f1e6a30931
commit 2307f55715
1 changed files with 30 additions and 23 deletions

View File

@ -128,14 +128,15 @@ const (
FilteredSafeSearch
)
// these variables need to survive coredns reload
var (
type dnsfContext struct {
stats Stats
dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
safebrowsingCache gcache.Cache
parentalCache gcache.Cache
safeSearchCache gcache.Cache
)
}
var gctx dnsfContext // global dnsfilter context
// Result holds state of hostname check
type Result struct {
@ -298,14 +299,10 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host)
}
if safeSearchCache == nil {
safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
// Check cache. Return cached result if it was found
cachedValue, isFound, err := getCachedReason(safeSearchCache, host)
cachedValue, isFound, err := getCachedReason(gctx.safeSearchCache, host)
if isFound {
atomic.AddUint64(&stats.Safesearch.CacheHits, 1)
atomic.AddUint64(&gctx.stats.Safesearch.CacheHits, 1)
log.Tracef("%s: found in SafeSearch cache", host)
return cachedValue, nil
}
@ -322,7 +319,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
if ip := net.ParseIP(safeHost); ip != nil {
res.IP = ip
err = safeSearchCache.Set(host, res)
err = gctx.safeSearchCache.Set(host, res)
if err != nil {
return Result{}, nil
}
@ -349,7 +346,7 @@ func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
}
// Cache result
err = safeSearchCache.Set(host, res)
err = gctx.safeSearchCache.Set(host, res)
if err != nil {
return Result{}, nil
}
@ -395,10 +392,7 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
}
return result, nil
}
if safebrowsingCache == nil {
safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
result, err := d.lookupCommon(host, &stats.Safebrowsing, safebrowsingCache, true, format, handleBody)
result, err := d.lookupCommon(host, &gctx.stats.Safebrowsing, gctx.safebrowsingCache, true, format, handleBody)
return result, err
}
@ -450,10 +444,7 @@ func (d *Dnsfilter) checkParental(host string) (Result, error) {
}
return result, nil
}
if parentalCache == nil {
parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
result, err := d.lookupCommon(host, &stats.Parental, parentalCache, false, format, handleBody)
result, err := d.lookupCommon(host, &gctx.stats.Parental, gctx.parentalCache, false, format, handleBody)
return result, err
}
@ -620,7 +611,7 @@ func (d *Dnsfilter) shouldBeInDialCache(host string) bool {
// Search for an IP address by host name
func searchInDialCache(host string) string {
rawValue, err := dialCache.Get(host)
rawValue, err := gctx.dialCache.Get(host)
if err != nil {
return ""
}
@ -632,7 +623,7 @@ func searchInDialCache(host string) string {
// Add "hostname" -> "IP address" entry to cache
func addToDialCache(host, ip string) {
err := dialCache.Set(host, ip)
err := gctx.dialCache.Set(host, ip)
if err != nil {
log.Debug("dialCache.Set: %s", err)
}
@ -701,6 +692,23 @@ func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionTyp
// New creates properly initialized DNS Filter that is ready to be used
func New(c *Config, filters map[int]string) *Dnsfilter {
if c != nil {
// initialize objects only once
if c.SafeBrowsingEnabled && gctx.safebrowsingCache == nil {
gctx.safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
if c.SafeSearchEnabled && gctx.safeSearchCache == nil {
gctx.safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
if c.ParentalEnabled && gctx.parentalCache == nil {
gctx.parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
if len(c.ResolverAddress) != 0 && gctx.dialCache == nil {
gctx.dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
}
}
d := new(Dnsfilter)
// Customize the Transport to have larger connection pool,
@ -714,7 +722,6 @@ func New(c *Config, filters map[int]string) *Dnsfilter {
ExpectContinueTimeout: 1 * time.Second,
}
if c != nil && len(c.ResolverAddress) != 0 {
dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress)
}
d.client = http.Client{
@ -790,5 +797,5 @@ func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
// GetStats return dns filtering stats since startup
func (d *Dnsfilter) GetStats() Stats {
return stats
return gctx.stats
}