diff --git a/Makefile b/Makefile index 649a1262..864e6361 100644 --- a/Makefile +++ b/Makefile @@ -24,6 +24,7 @@ AdguardDNS: $(STATIC) *.go coredns: coredns_plugin/*.go dnsfilter/*.go echo mkfile_dir = $(mkfile_dir) go get -v -d github.com/coredns/coredns + cd $(GOPATH)/src/github.com/coredns/coredns && perl -p -i.bak -e 's/^(trace|route53|federation|kubernetes|etcd):.*//' plugin.cfg cd $(GOPATH)/src/github.com/coredns/coredns && grep -q '^dnsfilter:' plugin.cfg || perl -p -i.bak -e 's|^log:log|log:log\ndnsfilter:github.com/AdguardTeam/AdguardDNS/coredns_plugin|' plugin.cfg grep '^dnsfilter:' $(GOPATH)/src/github.com/coredns/coredns/plugin.cfg ## used to check that plugin.cfg was successfully edited by sed cd $(GOPATH)/src/github.com/coredns/coredns && GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) go generate diff --git a/app.go b/app.go index 3ca54175..57e9ea56 100644 --- a/app.go +++ b/app.go @@ -122,6 +122,11 @@ func main() { http.Handle("/", http.FileServer(box)) registerControlHandlers() + err = startDNSServer() + if err != nil { + log.Fatal(err) + } + URL := fmt.Sprintf("http://%s", address) log.Println("Go to " + URL) log.Fatal(http.ListenAndServe(address, nil)) diff --git a/config.go b/config.go index b87429b1..f0923601 100644 --- a/config.go +++ b/config.go @@ -108,11 +108,16 @@ func writeConfig() error { log.Printf("Couldn't generate YAML file: %s", err) return err } - err = ioutil.WriteFile(configfile, yamlText, 0644) + err = ioutil.WriteFile(configfile+".tmp", yamlText, 0644) if err != nil { log.Printf("Couldn't write YAML config: %s", err) return err } + err = os.Rename(configfile+".tmp", configfile) + if err != nil { + log.Printf("Couldn't rename YAML config: %s", err) + return err + } return nil } @@ -127,10 +132,14 @@ func writeCoreDNSConfig() error { log.Printf("Couldn't generate DNS config: %s", err) return err } - err = ioutil.WriteFile(corefile, []byte(configtext), 0644) + err = ioutil.WriteFile(corefile+".tmp", []byte(configtext), 0644) if err != nil { log.Printf("Couldn't write DNS config: %s", err) } + err = os.Rename(corefile+".tmp", corefile) + if err != nil { + log.Printf("Couldn't rename DNS config: %s", err) + } return err } diff --git a/control.go b/control.go index 48e71bf9..70cb9349 100644 --- a/control.go +++ b/control.go @@ -44,7 +44,6 @@ func tellCoreDNSToReload() { log.Printf("os.FindProcess(%d) returned err: %v\n", pid, err) return } - log.Printf("os.FindProcess(%d) returned: %v, %v\n", pid, process, err) err = process.Signal(syscall.SIGUSR1) if err != nil { log.Printf("process.Signal on pid %d returned: %v\n", pid, err) @@ -69,9 +68,10 @@ func isRunning() bool { if err != nil { log.Printf("os.FindProcess(%d) returned err: %v\n", pid, err) } else { - log.Printf("os.FindProcess(%d) returned: %v, %v\n", pid, process, err) err := process.Signal(syscall.Signal(0)) - log.Printf("process.Signal on pid %d returned: %v\n", pid, err) + if err != nil { + log.Printf("process.Signal on pid %d returned: %v\n", pid, err) + } if err == nil { return true } @@ -79,24 +79,22 @@ func isRunning() bool { } return false } -func handleStart(w http.ResponseWriter, r *http.Request) { + +func startDNSServer() error { if isRunning() { - http.Error(w, fmt.Sprintf("Unable to start coreDNS: Already running"), 400) - return + return fmt.Errorf("Unable to start coreDNS: Already running") } err := writeCoreDNSConfig() if err != nil { - errortext := fmt.Sprintf("Unable to write coredns config: %s", err) + errortext := fmt.Errorf("Unable to write coredns config: %s", err) log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return + return errortext } err = writeFilterFile() if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) + errortext := fmt.Errorf("Couldn't write filter file: %s", err) log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return + return errortext } binarypath := filepath.Join(config.ourBinaryDir, config.CoreDNS.binaryFile) configpath := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile) @@ -105,14 +103,27 @@ func handleStart(w http.ResponseWriter, r *http.Request) { coreDNSCommand.Stderr = os.Stderr err = coreDNSCommand.Start() if err != nil { - errortext := fmt.Sprintf("Unable to start coreDNS: %s", err) + errortext := fmt.Errorf("Unable to start coreDNS: %s", err) log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return + return errortext } log.Printf("coredns PID: %v\n", coreDNSCommand.Process.Pid) - fmt.Fprintf(w, "OK, PID %d\n", coreDNSCommand.Process.Pid) go childwaiter() + return nil +} + +func handleStart(w http.ResponseWriter, r *http.Request) { + if isRunning() { + http.Error(w, fmt.Sprintf("Unable to start coreDNS: Already running"), 400) + return + } + err := startDNSServer() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + fmt.Fprintf(w, "OK, PID %d\n", coreDNSCommand.Process.Pid) } func childwaiter() { @@ -188,21 +199,23 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { // stats // ----- func handleStats(w http.ResponseWriter, r *http.Request) { - snap := &statistics.lastsnap + histrical := generateMapFromStats(&statistics.perMinute, 0, 2) + // sum them up + summed := map[string]interface{}{} + for key, values := range histrical { + summedValue := 0.0 + floats, ok := values.([]float64) + if !ok { + continue + } + for _, v := range floats { + summedValue += v + } + summed[key] = summedValue + } + summed["stats_period"] = "3 minutes" - // generate from last 3 minutes - var last3mins statsSnapshot - last3mins.filteredTotal = snap.filteredTotal - statistics.perMinute.filteredTotal[2] - last3mins.filteredLists = snap.filteredLists - statistics.perMinute.filteredLists[2] - last3mins.filteredSafebrowsing = snap.filteredSafebrowsing - statistics.perMinute.filteredSafebrowsing[2] - last3mins.filteredParental = snap.filteredParental - statistics.perMinute.filteredParental[2] - last3mins.totalRequests = snap.totalRequests - statistics.perMinute.totalRequests[2] - last3mins.processingTimeSum = snap.processingTimeSum - statistics.perMinute.processingTimeSum[2] - last3mins.processingTimeCount = snap.processingTimeCount - statistics.perMinute.processingTimeCount[2] - // rate := computeRate(append([]float64(snap.totalRequests}, statistics.perMinute.totalRequests[0:2]) - - data := generateMapFromSnap(last3mins) - json, err := json.Marshal(data) + json, err := json.Marshal(summed) if err != nil { errortext := fmt.Sprintf("Unable to marshal status json: %s", err) log.Println(errortext) @@ -221,28 +234,29 @@ func handleStats(w http.ResponseWriter, r *http.Request) { func handleStatsHistory(w http.ResponseWriter, r *http.Request) { // handle time unit and prepare our time window size - limitTime := time.Now() - timeUnit := r.URL.Query().Get("time_unit") + now := time.Now() + timeUnitString := r.URL.Query().Get("time_unit") var stats *periodicStats - switch timeUnit { + var timeUnit time.Duration + switch timeUnitString { case "seconds": - limitTime = limitTime.Add(statsHistoryElements * -1 * time.Second) + timeUnit = time.Second stats = &statistics.perSecond case "minutes": - limitTime = limitTime.Add(statsHistoryElements * -1 * time.Minute) + timeUnit = time.Minute stats = &statistics.perMinute case "hours": - limitTime = limitTime.Add(statsHistoryElements * -1 * time.Hour) + timeUnit = time.Hour stats = &statistics.perHour case "days": - limitTime = limitTime.Add(statsHistoryElements * -1 * time.Hour * 24) + timeUnit = time.Hour * 24 stats = &statistics.perDay default: http.Error(w, "Must specify valid time_unit parameter", 400) return } - // check if start time is within supported time range + // parse start and end time startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time")) if err != nil { errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err) @@ -250,12 +264,6 @@ func handleStatsHistory(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, 400) return } - if startTime.Before(limitTime) { - http.Error(w, "start_time parameter is outside of supported range", 501) - return - } - - // check if end time is within supported time range endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time")) if err != nil { errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err) @@ -263,28 +271,22 @@ func handleStatsHistory(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, 400) return } - if endTime.Before(limitTime) { + + // check if start and time times are within supported time range + timeRange := timeUnit * statsHistoryElements + if startTime.Add(timeRange).Before(now) { + http.Error(w, "start_time parameter is outside of supported range", 501) + return + } + if endTime.Add(timeRange).Before(now) { http.Error(w, "end_time parameter is outside of supported range", 501) return } - // calculate how which slice range we need to provide - var start int - var end int - switch timeUnit { - case "seconds": - start = int(startTime.Sub(limitTime).Seconds()) - end = int(endTime.Sub(limitTime).Seconds()) - case "minutes": - start = int(startTime.Sub(limitTime).Minutes()) - end = int(endTime.Sub(limitTime).Minutes()) - case "hours": - start = int(startTime.Sub(limitTime).Hours()) - end = int(endTime.Sub(limitTime).Hours()) - case "days": - start = int(startTime.Sub(limitTime).Hours() / 24.0) - end = int(endTime.Sub(limitTime).Hours() / 24.0) - } + // calculate start and end of our array + // basically it's how many hours/minutes/etc have passed since now + start := int(now.Sub(endTime) / timeUnit) + end := int(now.Sub(startTime) / timeUnit) // swap them around if they're inverted if start > end { @@ -840,6 +842,7 @@ func refreshFiltersIfNeccessary() int { errortext := fmt.Sprintf("Couldn't write filter file: %s", err) log.Println(errortext) } + tellCoreDNSToReload() } return updateCount } @@ -918,11 +921,17 @@ func writeFilterFile() error { data = append(data, []byte(rule)...) data = append(data, '\n') } - err := ioutil.WriteFile(filterpath, data, 0644) + err := ioutil.WriteFile(filterpath+".tmp", data, 0644) if err != nil { log.Printf("Couldn't write filter file: %s", err) return err } + + err = os.Rename(filterpath+".tmp", filterpath) + if err != nil { + log.Printf("Couldn't rename filter file: %s", err) + return err + } return nil } diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go index a2347cb7..7f922d0e 100644 --- a/coredns_plugin/coredns_plugin.go +++ b/coredns_plugin/coredns_plugin.go @@ -132,7 +132,7 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) { } case "querylog": d.QueryLogEnabled = true - once.Do(func() { + onceQueryLog.Do(func() { go startQueryLogServer() // TODO: how to handle errors? }) } @@ -145,6 +145,7 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) { } defer file.Close() + count := 0 scanner := bufio.NewScanner(file) for scanner.Scan() { text := scanner.Text() @@ -158,7 +159,9 @@ func setupPlugin(c *caddy.Controller) (*Plugin, error) { if err != nil { return nil, err } + count++ } + log.Printf("Added %d rules from %s", count, filterFileName) if err = scanner.Err(); err != nil { return nil, err @@ -184,23 +187,21 @@ func setup(c *caddy.Controller) error { }) c.OnStartup(func() error { - once.Do(func() { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(requests) - x.MustRegister(filtered) - x.MustRegister(filteredLists) - x.MustRegister(filteredSafebrowsing) - x.MustRegister(filteredParental) - x.MustRegister(whitelisted) - x.MustRegister(safesearch) - x.MustRegister(errorsTotal) - x.MustRegister(d) - } - }) + m := dnsserver.GetConfig(c).Handler("prometheus") + if m == nil { + return nil + } + if x, ok := m.(*metrics.Metrics); ok { + x.MustRegister(requests) + x.MustRegister(filtered) + x.MustRegister(filteredLists) + x.MustRegister(filteredSafebrowsing) + x.MustRegister(filteredParental) + x.MustRegister(whitelisted) + x.MustRegister(safesearch) + x.MustRegister(errorsTotal) + x.MustRegister(d) + } return nil }) c.OnShutdown(d.OnShutdown) @@ -410,39 +411,44 @@ func (d *Plugin) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r * return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err), dnsfilter.Result{} } - // safebrowsing - if result.IsFiltered == true && result.Reason == dnsfilter.FilteredSafeBrowsing { - // return cname safebrowsing.block.dns.adguard.com - val := d.SafeBrowsingBlockHost - rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - return rcode, err, dnsfilter.Result{} + if result.IsFiltered { + switch result.Reason { + case dnsfilter.FilteredSafeBrowsing: + // return cname safebrowsing.block.dns.adguard.com + val := d.SafeBrowsingBlockHost + rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val, question) + if err != nil { + return rcode, err, dnsfilter.Result{} + } + return rcode, err, result + case dnsfilter.FilteredParental: + // return cname family.block.dns.adguard.com + val := d.ParentalBlockHost + rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val, question) + if err != nil { + return rcode, err, dnsfilter.Result{} + } + return rcode, err, result + case dnsfilter.FilteredBlackList: + // return NXdomain + rcode, err := writeNXdomain(ctx, w, r) + if err != nil { + return rcode, err, dnsfilter.Result{} + } + return rcode, err, result + default: + log.Printf("SHOULD NOT HAPPEN -- got unknown reason for filtering: %T %v %s", result.Reason, result.Reason, result.Reason.String()) } - return rcode, err, result - } - - // parental - if result.IsFiltered == true && result.Reason == dnsfilter.FilteredParental { - // return cname - val := d.ParentalBlockHost - rcode, err := d.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - return rcode, err, dnsfilter.Result{} + } else { + switch result.Reason { + case dnsfilter.NotFilteredWhiteList: + rcode, err := plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r) + return rcode, err, result + case dnsfilter.NotFilteredNotFound: + // do nothing, pass through to lower code + default: + log.Printf("SHOULD NOT HAPPEN -- got unknown reason for not filtering: %T %v %s", result.Reason, result.Reason, result.Reason.String()) } - return rcode, err, result - } - - // blacklist - if result.IsFiltered == true && result.Reason == dnsfilter.FilteredBlackList { - rcode, err := writeNXdomain(ctx, w, r) - if err != nil { - return rcode, err, dnsfilter.Result{} - } - return rcode, err, result - } - if result.IsFiltered == false && result.Reason == dnsfilter.NotFilteredWhiteList { - rcode, err := plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r) - return rcode, err, result } } rcode, err := plugin.NextOrFailure(d.Name(), d.Next, ctx, w, r) @@ -498,11 +504,11 @@ func (d *Plugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) // log if d.QueryLogEnabled { - logRequest(rrw.Msg, result, time.Since(start), ip) + logRequest(r, rrw.Msg, result, time.Since(start), ip) } return rcode, err } func (d *Plugin) Name() string { return "dnsfilter" } -var once sync.Once +var onceQueryLog sync.Once diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go index 4fdcd345..a29388fb 100644 --- a/coredns_plugin/querylog.go +++ b/coredns_plugin/querylog.go @@ -19,24 +19,26 @@ import ( var logBuffer = ring.Ring{} type logEntry struct { - R *dns.Msg - Result dnsfilter.Result - Time time.Time - Elapsed time.Duration - IP string + Question *dns.Msg + Answer *dns.Msg + Result dnsfilter.Result + Time time.Time + Elapsed time.Duration + IP string } func init() { logBuffer.SetCapacity(1000) } -func logRequest(r *dns.Msg, result dnsfilter.Result, elapsed time.Duration, ip string) { +func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, elapsed time.Duration, ip string) { entry := logEntry{ - R: r, - Result: result, - Time: time.Now(), - Elapsed: elapsed, - IP: ip, + Question: question, + Answer: answer, + Result: result, + Time: time.Now(), + Elapsed: elapsed, + IP: ip, } logBuffer.Enqueue(entry) } @@ -57,21 +59,21 @@ func handler(w http.ResponseWriter, r *http.Request) { "client": entry.IP, } question := map[string]interface{}{ - "host": strings.ToLower(strings.TrimSuffix(entry.R.Question[0].Name, ".")), - "type": dns.Type(entry.R.Question[0].Qtype).String(), - "class": dns.Class(entry.R.Question[0].Qclass).String(), + "host": strings.ToLower(strings.TrimSuffix(entry.Question.Question[0].Name, ".")), + "type": dns.Type(entry.Question.Question[0].Qtype).String(), + "class": dns.Class(entry.Question.Question[0].Qclass).String(), } jsonentry["question"] = question - status, _ := response.Typify(entry.R, time.Now().UTC()) + status, _ := response.Typify(entry.Answer, time.Now().UTC()) jsonentry["status"] = status.String() if len(entry.Result.Rule) > 0 { jsonentry["rule"] = entry.Result.Rule } - if len(entry.R.Answer) > 0 { + if entry.Answer != nil && len(entry.Answer.Answer) > 0 { var answers = []map[string]interface{}{} - for _, k := range entry.R.Answer { + for _, k := range entry.Answer.Answer { header := k.Header() answer := map[string]interface{}{ "type": dns.TypeToString[header.Rrtype], diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 8f8dda7e..387a58a1 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -385,43 +385,44 @@ var regexRules = []string{"/example\\.org/", "@@||test.example.org^"} var maskRules = []string{"test*.example.org^", "exam*.com"} var tests = []struct { - testname string - rules []string - hostname string - result bool + testname string + rules []string + hostname string + isFiltered bool + reason Reason }{ - {"sanity", []string{"||doubleclick.net^"}, "www.doubleclick.net", true}, - {"sanity", []string{"||doubleclick.net^"}, "nodoubleclick.net", false}, - {"sanity", []string{"||doubleclick.net^"}, "doubleclick.net.ru", false}, - {"sanity", []string{"||doubleclick.net^"}, "wmconvirus.narod.ru", false}, - {"blocking", blockingRules, "example.org", true}, - {"blocking", blockingRules, "test.example.org", true}, - {"blocking", blockingRules, "test.test.example.org", true}, - {"blocking", blockingRules, "testexample.org", false}, - {"blocking", blockingRules, "onemoreexample.org", false}, - {"whitelist", whitelistRules, "example.org", true}, - {"whitelist", whitelistRules, "test.example.org", false}, - {"whitelist", whitelistRules, "test.test.example.org", false}, - {"whitelist", whitelistRules, "testexample.org", false}, - {"whitelist", whitelistRules, "onemoreexample.org", false}, - {"important", importantRules, "example.org", false}, - {"important", importantRules, "test.example.org", true}, - {"important", importantRules, "test.test.example.org", true}, - {"important", importantRules, "testexample.org", false}, - {"important", importantRules, "onemoreexample.org", false}, - {"regex", regexRules, "example.org", true}, - {"regex", regexRules, "test.example.org", false}, - {"regex", regexRules, "test.test.example.org", false}, - {"regex", regexRules, "testexample.org", true}, - {"regex", regexRules, "onemoreexample.org", true}, - {"mask", maskRules, "test.example.org", true}, - {"mask", maskRules, "test2.example.org", true}, - {"mask", maskRules, "example.com", true}, - {"mask", maskRules, "exampleeee.com", true}, - {"mask", maskRules, "onemoreexamsite.com", true}, - {"mask", maskRules, "example.org", false}, - {"mask", maskRules, "testexample.org", false}, - {"mask", maskRules, "example.co.uk", false}, + {"sanity", []string{"||doubleclick.net^"}, "www.doubleclick.net", true, FilteredBlackList}, + {"sanity", []string{"||doubleclick.net^"}, "nodoubleclick.net", false, NotFilteredNotFound}, + {"sanity", []string{"||doubleclick.net^"}, "doubleclick.net.ru", false, NotFilteredNotFound}, + {"sanity", []string{"||doubleclick.net^"}, "wmconvirus.narod.ru", false, NotFilteredNotFound}, + {"blocking", blockingRules, "example.org", true, FilteredBlackList}, + {"blocking", blockingRules, "test.example.org", true, FilteredBlackList}, + {"blocking", blockingRules, "test.test.example.org", true, FilteredBlackList}, + {"blocking", blockingRules, "testexample.org", false, NotFilteredNotFound}, + {"blocking", blockingRules, "onemoreexample.org", false, NotFilteredNotFound}, + {"whitelist", whitelistRules, "example.org", true, FilteredBlackList}, + {"whitelist", whitelistRules, "test.example.org", false, NotFilteredWhiteList}, + {"whitelist", whitelistRules, "test.test.example.org", false, NotFilteredWhiteList}, + {"whitelist", whitelistRules, "testexample.org", false, NotFilteredNotFound}, + {"whitelist", whitelistRules, "onemoreexample.org", false, NotFilteredNotFound}, + {"important", importantRules, "example.org", false, NotFilteredWhiteList}, + {"important", importantRules, "test.example.org", true, FilteredBlackList}, + {"important", importantRules, "test.test.example.org", true, FilteredBlackList}, + {"important", importantRules, "testexample.org", false, NotFilteredNotFound}, + {"important", importantRules, "onemoreexample.org", false, NotFilteredNotFound}, + {"regex", regexRules, "example.org", true, FilteredBlackList}, + {"regex", regexRules, "test.example.org", false, NotFilteredWhiteList}, + {"regex", regexRules, "test.test.example.org", false, NotFilteredWhiteList}, + {"regex", regexRules, "testexample.org", true, FilteredBlackList}, + {"regex", regexRules, "onemoreexample.org", true, FilteredBlackList}, + {"mask", maskRules, "test.example.org", true, FilteredBlackList}, + {"mask", maskRules, "test2.example.org", true, FilteredBlackList}, + {"mask", maskRules, "example.com", true, FilteredBlackList}, + {"mask", maskRules, "exampleeee.com", true, FilteredBlackList}, + {"mask", maskRules, "onemoreexamsite.com", true, FilteredBlackList}, + {"mask", maskRules, "example.org", false, NotFilteredNotFound}, + {"mask", maskRules, "testexample.org", false, NotFilteredNotFound}, + {"mask", maskRules, "example.co.uk", false, NotFilteredNotFound}, } func TestMatching(t *testing.T) { @@ -439,8 +440,11 @@ func TestMatching(t *testing.T) { if err != nil { t.Errorf("Error while matching host %s: %s", test.hostname, err) } - if ret.IsFiltered != test.result { - t.Errorf("Hostname %s has wrong result (%v must be %v)", test.hostname, ret, test.result) + if ret.IsFiltered != test.isFiltered { + t.Errorf("Hostname %s has wrong result (%v must be %v)", test.hostname, ret.IsFiltered, test.isFiltered) + } + if ret.Reason != test.reason { + t.Errorf("Hostname %s has wrong reason (%v must be %v)", test.hostname, ret.Reason.String(), test.reason.String()) } }) } diff --git a/helpers.go b/helpers.go index c3ede134..a460f98f 100644 --- a/helpers.go +++ b/helpers.go @@ -51,33 +51,14 @@ func ensureDELETE(handler func(http.ResponseWriter, *http.Request)) func(http.Re // -------------------------- // helper functions for stats // -------------------------- -func computeRate(input []float64) []float64 { +func getReversedSlice(input [statsHistoryElements]float64, start int, end int) []float64 { output := make([]float64, 0) - for i := len(input) - 2; i >= 0; i-- { - value := input[i] - diff := value - input[i+1] - output = append([]float64{diff}, output...) + for i := start; i <= end; i++ { + output = append([]float64{input[i]}, output...) } return output } -func generateMapFromSnap(snap statsSnapshot) map[string]interface{} { - var avgProcessingTime float64 - if snap.processingTimeCount > 0 { - avgProcessingTime = snap.processingTimeSum / snap.processingTimeCount - } - - result := map[string]interface{}{ - "dns_queries": snap.totalRequests, - "blocked_filtering": snap.filteredLists, - "replaced_safebrowsing": snap.filteredSafebrowsing, - "replaced_safesearch": snap.filteredSafesearch, - "replaced_parental": snap.filteredParental, - "avg_processing_time": avgProcessingTime, - } - return result -} - func generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} { // clamp start = clamp(start, 0, statsHistoryElements) @@ -85,8 +66,8 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i avgProcessingTime := make([]float64, 0) - count := computeRate(stats.processingTimeCount[start:end]) - sum := computeRate(stats.processingTimeSum[start:end]) + count := getReversedSlice(stats.entries[processingTimeCount], start, end) + sum := getReversedSlice(stats.entries[processingTimeSum], start, end) for i := 0; i < len(count); i++ { var avg float64 if count[i] != 0 { @@ -97,11 +78,11 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i } result := map[string]interface{}{ - "dns_queries": computeRate(stats.totalRequests[start:end]), - "blocked_filtering": computeRate(stats.filteredLists[start:end]), - "replaced_safebrowsing": computeRate(stats.filteredSafebrowsing[start:end]), - "replaced_safesearch": computeRate(stats.filteredSafesearch[start:end]), - "replaced_parental": computeRate(stats.filteredParental[start:end]), + "dns_queries": getReversedSlice(stats.entries[totalRequests], start, end), + "blocked_filtering": getReversedSlice(stats.entries[filteredLists], start, end), + "replaced_safebrowsing": getReversedSlice(stats.entries[filteredSafebrowsing], start, end), + "replaced_safesearch": getReversedSlice(stats.entries[filteredSafesearch], start, end), + "replaced_parental": getReversedSlice(stats.entries[filteredParental], start, end), "avg_processing_time": avgProcessingTime, } return result diff --git a/stats.go b/stats.go index eef38b72..4829cd0c 100644 --- a/stats.go +++ b/stats.go @@ -8,70 +8,50 @@ import ( "net/http" "net/url" "os" - "regexp" "strconv" "strings" "syscall" "time" ) -type periodicStats struct { - totalRequests []float64 - - filteredTotal []float64 - filteredLists []float64 - filteredSafebrowsing []float64 - filteredSafesearch []float64 - filteredParental []float64 - - processingTimeSum []float64 - processingTimeCount []float64 - - lastRotate time.Time // last time this data was rotated -} - -type statsSnapshot struct { - totalRequests float64 - - filteredTotal float64 - filteredLists float64 - filteredSafebrowsing float64 - filteredSafesearch float64 - filteredParental float64 - - processingTimeSum float64 - processingTimeCount float64 -} - -type statsCollection struct { - perSecond periodicStats - perMinute periodicStats - perHour periodicStats - perDay periodicStats - lastsnap statsSnapshot -} - -var statistics statsCollection - var client = &http.Client{ Timeout: time.Second * 30, } -const statsHistoryElements = 60 + 1 // +1 for calculating delta +// as seen over HTTP +type statsEntry map[string]float64 +type statsEntries map[string][statsHistoryElements]float64 -var requestCountTotalRegex = regexp.MustCompile(`^coredns_dns_request_count_total`) -var requestDurationSecondsSum = regexp.MustCompile(`^coredns_dns_request_duration_seconds_sum`) -var requestDurationSecondsCount = regexp.MustCompile(`^coredns_dns_request_duration_seconds_count`) +const ( + statsHistoryElements = 60 + 1 // +1 for calculating delta + totalRequests = `coredns_dns_request_count_total` + filteredTotal = `coredns_dnsfilter_filtered_total` + filteredLists = `coredns_dnsfilter_filtered_lists_total` + filteredSafebrowsing = `coredns_dnsfilter_filtered_safebrowsing_total` + filteredSafesearch = `coredns_dnsfilter_safesearch_total` + filteredParental = `coredns_dnsfilter_filtered_parental_total` + processingTimeSum = `coredns_dns_request_duration_seconds_sum` + processingTimeCount = `coredns_dns_request_duration_seconds_count` +) -func initPeriodicStats(stats *periodicStats) { - stats.totalRequests = make([]float64, statsHistoryElements) - stats.filteredTotal = make([]float64, statsHistoryElements) - stats.filteredLists = make([]float64, statsHistoryElements) - stats.filteredSafebrowsing = make([]float64, statsHistoryElements) - stats.filteredSafesearch = make([]float64, statsHistoryElements) - stats.filteredParental = make([]float64, statsHistoryElements) - stats.processingTimeSum = make([]float64, statsHistoryElements) - stats.processingTimeCount = make([]float64, statsHistoryElements) +type periodicStats struct { + entries statsEntries + lastRotate time.Time // last time this data was rotated +} + +type stats struct { + perSecond periodicStats + perMinute periodicStats + perHour periodicStats + perDay periodicStats + + lastSeen statsEntry +} + +var statistics stats + +func initPeriodicStats(periodic *periodicStats) { + periodic.entries = statsEntries{} } func init() { @@ -106,37 +86,22 @@ func isConnRefused(err error) bool { return false } -func sliceRotate(slice *[]float64) { - a := (*slice)[:len(*slice)-1] - *slice = append([]float64{0}, a...) -} - -func statsRotate(stats *periodicStats, now time.Time) { - sliceRotate(&stats.totalRequests) - sliceRotate(&stats.filteredTotal) - sliceRotate(&stats.filteredLists) - sliceRotate(&stats.filteredSafebrowsing) - sliceRotate(&stats.filteredSafesearch) - sliceRotate(&stats.filteredParental) - sliceRotate(&stats.processingTimeSum) - sliceRotate(&stats.processingTimeCount) - stats.lastRotate = now -} - -func handleValue(input string, target *float64) { - value, err := strconv.ParseFloat(input, 64) - if err != nil { - log.Println("Failed to parse number input:", err) - return +func statsRotate(periodic *periodicStats, now time.Time) { + for key, values := range periodic.entries { + newValues := [statsHistoryElements]float64{} + for i := 1; i < len(values); i++ { + newValues[i] = values[i-1] + } + periodic.entries[key] = newValues } - *target = value + periodic.lastRotate = now } // called every second, accumulates stats for each second, minute, hour and day func collectStats() { now := time.Now() // rotate each second - // NOTE: since we are called every second, always rotate, otherwise aliasing problems cause the rotation to skip + // NOTE: since we are called every second, always rotate perSecond, otherwise aliasing problems cause the rotation to skip if true { statsRotate(&statistics.perSecond, now) } @@ -172,6 +137,8 @@ func collectStats() { return } + entry := statsEntry{} + // handle body scanner := bufio.NewScanner(strings.NewReader(string(body))) for scanner.Scan() { @@ -181,38 +148,61 @@ func collectStats() { continue } splitted := strings.Split(line, " ") - switch { - case splitted[0] == "coredns_dnsfilter_filtered_total": - handleValue(splitted[1], &statistics.lastsnap.filteredTotal) - case splitted[0] == "coredns_dnsfilter_filtered_lists_total": - handleValue(splitted[1], &statistics.lastsnap.filteredLists) - case splitted[0] == "coredns_dnsfilter_filtered_safebrowsing_total": - handleValue(splitted[1], &statistics.lastsnap.filteredSafebrowsing) - case splitted[0] == "coredns_dnsfilter_filtered_parental_total": - handleValue(splitted[1], &statistics.lastsnap.filteredParental) - case requestCountTotalRegex.MatchString(splitted[0]): - handleValue(splitted[1], &statistics.lastsnap.totalRequests) - case requestDurationSecondsSum.MatchString(splitted[0]): - handleValue(splitted[1], &statistics.lastsnap.processingTimeSum) - case requestDurationSecondsCount.MatchString(splitted[0]): - handleValue(splitted[1], &statistics.lastsnap.processingTimeCount) + if len(splitted) < 2 { + continue } + + value, err := strconv.ParseFloat(splitted[1], 64) + if err != nil { + log.Printf("Failed to parse number input %s: %s", splitted[1], err) + continue + } + + key := splitted[0] + index := strings.IndexByte(key, '{') + if index >= 0 { + key = key[:index] + } + + // empty keys are not ok + if key == "" { + continue + } + + got, ok := entry[key] + if ok { + value += got + } + entry[key] = value } - // put the snap into per-second, per-minute, per-hour and per-day - assignSnapToStats(&statistics.perSecond) - assignSnapToStats(&statistics.perMinute) - assignSnapToStats(&statistics.perHour) - assignSnapToStats(&statistics.perDay) + // calculate delta + delta := calcDelta(entry, statistics.lastSeen) + + // apply delta to second/minute/hour/day + applyDelta(&statistics.perSecond, delta) + applyDelta(&statistics.perMinute, delta) + applyDelta(&statistics.perHour, delta) + applyDelta(&statistics.perDay, delta) + + // save last seen + statistics.lastSeen = entry } -func assignSnapToStats(stats *periodicStats) { - stats.totalRequests[0] = statistics.lastsnap.totalRequests - stats.filteredTotal[0] = statistics.lastsnap.filteredTotal - stats.filteredLists[0] = statistics.lastsnap.filteredLists - stats.filteredSafebrowsing[0] = statistics.lastsnap.filteredSafebrowsing - stats.filteredSafesearch[0] = statistics.lastsnap.filteredSafesearch - stats.filteredParental[0] = statistics.lastsnap.filteredParental - stats.processingTimeSum[0] = statistics.lastsnap.processingTimeSum - stats.processingTimeCount[0] = statistics.lastsnap.processingTimeCount +func calcDelta(current, seen statsEntry) statsEntry { + delta := statsEntry{} + for key, currentValue := range current { + seenValue := seen[key] + deltaValue := currentValue - seenValue + delta[key] = deltaValue + } + return delta +} + +func applyDelta(current *periodicStats, delta statsEntry) { + for key, deltaValue := range delta { + currentValues := current.entries[key] + currentValues[0] += deltaValue + current.entries[key] = currentValues + } }