diff --git a/config.go b/config.go index 99a6fe9d..6130dff6 100644 --- a/config.go +++ b/config.go @@ -50,7 +50,7 @@ type coreDNSConfig struct { type filter struct { URL string `json:"url"` - Name string `json:"name" yaml:"-"` + Name string `json:"name" yaml:"name"` Enabled bool `json:"enabled"` RulesCount int `json:"rules_count" yaml:"-"` contents []byte diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go index 195a9218..4b00eca6 100644 --- a/coredns_plugin/coredns_plugin.go +++ b/coredns_plugin/coredns_plugin.go @@ -45,6 +45,16 @@ func init() { }) } +type cacheEntry struct { + answer []dns.RR + lastUpdated time.Time +} + +var ( + lookupCacheTime = time.Minute * 30 + lookupCache = map[string]cacheEntry{} +) + type plugSettings struct { SafeBrowsingBlockHost string ParentalBlockHost string @@ -324,20 +334,29 @@ func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWri records = append(records, result) } else { // this is a domain name, need to look it up - req := new(dns.Msg) - req.SetQuestion(dns.Fqdn(val), question.Qtype) - req.RecursionDesired = true - reqstate := request.Request{W: w, Req: req, Context: ctx} - result, err := p.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType()) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) - } - if result != nil { - for _, answer := range result.Answer { - answer.Header().Name = question.Name + cacheentry := lookupCache[val] + if time.Since(cacheentry.lastUpdated) > lookupCacheTime { + req := new(dns.Msg) + req.SetQuestion(dns.Fqdn(val), question.Qtype) + req.RecursionDesired = true + reqstate := request.Request{W: w, Req: req, Context: ctx} + result, err := p.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType()) + if err != nil { + log.Printf("Got error %s\n", err) + return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) } - records = result.Answer + if result != nil { + for _, answer := range result.Answer { + answer.Header().Name = question.Name + } + records = result.Answer + cacheentry.answer = result.Answer + cacheentry.lastUpdated = time.Now() + lookupCache[val] = cacheentry + } + } else { + // get from cache + records = cacheentry.answer } } m := new(dns.Msg) diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go index 2ca0eb2d..b72d1719 100644 --- a/coredns_plugin/querylog.go +++ b/coredns_plugin/querylog.go @@ -25,16 +25,15 @@ const ( queryLogRotationPeriod = time.Hour * 24 // rotate the log every 24 hours queryLogFileName = "querylog.json" // .gz added during compression queryLogSize = 5000 // maximum API response for /querylog - queryLogCacheTime = time.Minute // if requested more often than this, give out cached response queryLogTopSize = 500 // Keep in memory only top N values queryLogAPIPort = "8618" // 8618 is sha512sum of "querylog" then each byte summed ) var ( logBufferLock sync.RWMutex - logBuffer []logEntry + logBuffer []*logEntry - queryLogCache []logEntry + queryLogCache []*logEntry queryLogLock sync.RWMutex queryLogTime time.Time ) @@ -77,15 +76,22 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela Elapsed: elapsed, IP: ip, } - var flushBuffer []logEntry + var flushBuffer []*logEntry logBufferLock.Lock() - logBuffer = append(logBuffer, entry) + logBuffer = append(logBuffer, &entry) if len(logBuffer) >= logBufferCap { flushBuffer = logBuffer logBuffer = nil } logBufferLock.Unlock() + queryLogLock.Lock() + queryLogCache = append(queryLogCache, &entry) + if len(queryLogCache) > queryLogSize { + toremove := len(queryLogCache) - queryLogSize + queryLogCache = queryLogCache[toremove:] + } + queryLogLock.Unlock() // add it to running top err = runningTop.addEntry(&entry, question, now) @@ -103,26 +109,14 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela } func handleQueryLog(w http.ResponseWriter, r *http.Request) { - now := time.Now() - queryLogLock.RLock() - values := queryLogCache - needRefresh := now.Sub(queryLogTime) >= queryLogCacheTime + values := make([]*logEntry, len(queryLogCache)) + copy(values, queryLogCache) queryLogLock.RUnlock() - if needRefresh { - // need to get fresh data - logBufferLock.RLock() - values = logBuffer - logBufferLock.RUnlock() - - if len(values) < queryLogSize { - values = appendFromLogFile(values, queryLogSize, queryLogTimeLimit) - } - queryLogLock.Lock() - queryLogCache = values - queryLogTime = now - queryLogLock.Unlock() + // reverse it so that newest is first + for left, right := 0, len(values)-1; left < right; left, right = left+1, right-1 { + values[left], values[right] = values[right], values[left] } var data = []map[string]interface{}{} diff --git a/coredns_plugin/querylog_file.go b/coredns_plugin/querylog_file.go index 932dc105..aebbee44 100644 --- a/coredns_plugin/querylog_file.go +++ b/coredns_plugin/querylog_file.go @@ -19,7 +19,7 @@ var ( const enableGzip = false -func flushToFile(buffer []logEntry) error { +func flushToFile(buffer []*logEntry) error { if len(buffer) == 0 { return nil } @@ -90,7 +90,7 @@ func flushToFile(buffer []logEntry) error { return nil } -func checkBuffer(buffer []logEntry, b bytes.Buffer) error { +func checkBuffer(buffer []*logEntry, b bytes.Buffer) error { l := len(buffer) d := json.NewDecoder(&b) @@ -237,11 +237,11 @@ func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, ti return nil } -func appendFromLogFile(values []logEntry, maxLen int, timeWindow time.Duration) []logEntry { - a := []logEntry{} +func appendFromLogFile(values []*logEntry, maxLen int, timeWindow time.Duration) []*logEntry { + a := []*logEntry{} onEntry := func(entry *logEntry) error { - a = append(a, *entry) + a = append(a, entry) if len(a) > maxLen { toskip := len(a) - maxLen a = a[toskip:] diff --git a/coredns_plugin/querylog_top.go b/coredns_plugin/querylog_top.go index d6bb7a6b..92f4ce9f 100644 --- a/coredns_plugin/querylog_top.go +++ b/coredns_plugin/querylog_top.go @@ -223,6 +223,14 @@ func fillStatsFromQueryLog() error { return err } + queryLogLock.Lock() + queryLogCache = append(queryLogCache, entry) + if len(queryLogCache) > queryLogSize { + toremove := len(queryLogCache) - queryLogSize + queryLogCache = queryLogCache[toremove:] + } + queryLogLock.Unlock() + requests.IncWithTime(entry.Time) if entry.Result.IsFiltered { filtered.IncWithTime(entry.Time) diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 1344ea9f..ec2ad5f7 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "time" + _ "github.com/benburkert/dns/init" "github.com/bluele/gcache" "golang.org/x/net/publicsuffix" )