diff --git a/.gometalinter.json b/.gometalinter.json index 2c3f557e..10de7f42 100644 --- a/.gometalinter.json +++ b/.gometalinter.json @@ -20,7 +20,8 @@ "DisableAll": false, "Disable": [ "maligned", - "goconst" + "goconst", + "vetshadow" ], "Cyclo": 20, diff --git a/README.md b/README.md index f1b80df1..52b678db 100644 --- a/README.md +++ b/README.md @@ -133,13 +133,14 @@ Usage: ./AdGuardHome [options] Options: - -c, --config path to config file + -c, --config path to the config file + -w, --work-dir path to the working directory -h, --host host address to bind HTTP server on -p, --port port to serve HTTP pages on -s, --service service control action: status, install, uninstall, start, stop, restart -l, --logfile path to the log file. If empty, writes to stdout, if 'syslog' -- system log -v, --verbose enable verbose output - --help print this help + --help print this help ``` Please note, that the command-line arguments override settings from the configuration file. diff --git a/app.go b/app.go index e7ca88c8..61870582 100644 --- a/app.go +++ b/app.go @@ -66,7 +66,7 @@ func run(args options) { // print the first message after logger is configured log.Printf("AdGuard Home, version %s\n", VersionString) - log.Tracef("Current working directory is %s", config.ourBinaryDir) + log.Tracef("Current working directory is %s", config.ourWorkingDir) if args.runningAsService { log.Printf("AdGuard Home is running as a service") } @@ -106,6 +106,7 @@ func run(args options) { log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err) // clear LastUpdated so it gets fetched right away } + if len(filter.Rules) == 0 { filter.LastUpdated = time.Time{} } @@ -117,6 +118,10 @@ func run(args options) { log.Fatal(err) } + // Init the DNS server instance before registering HTTP handlers + dnsBaseDir := filepath.Join(config.ourWorkingDir, dataDir) + initDNSServer(dnsBaseDir) + if !config.firstRun { err = startDNSServer() if err != nil { @@ -172,18 +177,19 @@ func run(args 
options) { } } -// initWorkingDir initializes the ourBinaryDir (basically, we use it as a working dir) +// initWorkingDir initializes the ourWorkingDir +// if no command-line arguments specified, we use the directory where our binary file is located func initWorkingDir(args options) { exec, err := os.Executable() if err != nil { panic(err) } - if args.configFilename != "" { + if args.workDir != "" { // If there is a custom config file, use it's directory as our working dir - config.ourBinaryDir = filepath.Dir(args.configFilename) + config.ourWorkingDir = args.workDir } else { - config.ourBinaryDir = filepath.Dir(exec) + config.ourWorkingDir = filepath.Dir(exec) } } @@ -218,7 +224,7 @@ func configureLogger(args options) { log.Fatalf("cannot initialize syslog: %s", err) } } else { - logFilePath := filepath.Join(config.ourBinaryDir, ls.LogFile) + logFilePath := filepath.Join(config.ourWorkingDir, ls.LogFile) file, err := os.OpenFile(logFilePath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0755) if err != nil { log.Fatalf("cannot create a log file: %s", err) @@ -244,6 +250,7 @@ func cleanup() { type options struct { verbose bool // is verbose logging enabled configFilename string // path to the config file + workDir string // path to the working directory where we will store the filters data and the querylog bindHost string // host address to bind HTTP server on bindPort int // port to serve HTTP pages on logFile string // Path to the log file. If empty, write to stdout. 
If "syslog", writes to syslog @@ -267,7 +274,8 @@ func loadOptions() options { callbackWithValue func(value string) callbackNoValue func() }{ - {"config", "c", "path to config file", func(value string) { o.configFilename = value }, nil}, + {"config", "c", "path to the config file", func(value string) { o.configFilename = value }, nil}, + {"work-dir", "w", "path to the working directory", func(value string) { o.workDir = value }, nil}, {"host", "h", "host address to bind HTTP server on", func(value string) { o.bindHost = value }, nil}, {"port", "p", "port to serve HTTP pages on", func(value string) { v, err := strconv.Atoi(value) diff --git a/config.go b/config.go index 751e3c41..0e680a1b 100644 --- a/config.go +++ b/config.go @@ -28,7 +28,7 @@ type logSettings struct { // field ordering is important -- yaml fields will mirror ordering from here type configuration struct { ourConfigFilename string // Config filename (can be overridden via the command line arguments) - ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else + ourWorkingDir string // Location of our directory, used to protect against CWD being somewhere else firstRun bool // if set to true, don't run any services except HTTP web inteface, and serve only first-run html BindHost string `yaml:"bind_host"` @@ -92,7 +92,7 @@ var config = configuration{ func (c *configuration) getConfigFilename() string { configFile := config.ourConfigFilename if !filepath.IsAbs(configFile) { - configFile = filepath.Join(config.ourBinaryDir, config.ourConfigFilename) + configFile = filepath.Join(config.ourWorkingDir, config.ourConfigFilename) } return configFile } @@ -115,7 +115,7 @@ func getLogSettings() logSettings { // parseConfig loads configuration from the YAML file func parseConfig() error { configFile := config.getConfigFilename() - log.Tracef("Reading YAML file: %s", configFile) + log.Printf("Reading config file: %s", configFile) yamlFile, err := readConfigFile() if err != 
nil { log.Printf("Couldn't read config file: %s", err) diff --git a/control.go b/control.go index 2077549d..35373c2c 100644 --- a/control.go +++ b/control.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "encoding/json" "fmt" @@ -8,6 +9,7 @@ import ( "net" "net/http" "os" + "sort" "strconv" "strings" "time" @@ -32,9 +34,28 @@ var client = &http.Client{ Timeout: time.Second * 30, } -// ------------------- +// ---------------- +// helper functions +// ---------------- + +func returnOK(w http.ResponseWriter) { + _, err := fmt.Fprintf(w, "OK\n") + if err != nil { + errorText := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) + } +} + +func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) { + text := fmt.Sprintf(format, args...) + log.Println(text) + http.Error(w, text, code) +} + +// --------------- // dns run control -// ------------------- +// --------------- func writeAllConfigsAndReloadDNS() error { err := writeAllConfigs() if err != nil { @@ -55,15 +76,6 @@ func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) { returnOK(w) } -func returnOK(w http.ResponseWriter) { - _, err := fmt.Fprintf(w, "OK\n") - if err != nil { - errorText := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - } -} - func handleStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ "dns_address": config.BindHost, @@ -117,12 +129,190 @@ func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) { httpUpdateConfigReloadDNSReturnOK(w, r) } -func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) { - text := fmt.Sprintf(format, args...) 
- log.Println(text) - http.Error(w, text, code) +func handleQueryLog(w http.ResponseWriter, r *http.Request) { + data := dnsServer.GetQueryLog() + + jsonVal, err := json.Marshal(data) + if err != nil { + errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonVal) + if err != nil { + errorText := fmt.Sprintf("Unable to write response json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) + } } +func handleStatsTop(w http.ResponseWriter, r *http.Request) { + s := dnsServer.GetStatsTop() + + // use manual json marshalling because we want maps to be sorted by value + statsJSON := bytes.Buffer{} + statsJSON.WriteString("{\n") + + gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) { + json.WriteString(" ") + json.WriteString(fmt.Sprintf("%q", name)) + json.WriteString(": {\n") + sorted := sortByValue(top) + // no more than 50 entries + if len(sorted) > 50 { + sorted = sorted[:50] + } + for i, key := range sorted { + json.WriteString(" ") + json.WriteString(fmt.Sprintf("%q", key)) + json.WriteString(": ") + json.WriteString(strconv.Itoa(top[key])) + if i+1 != len(sorted) { + json.WriteByte(',') + } + json.WriteByte('\n') + } + json.WriteString(" }") + if addComma { + json.WriteByte(',') + } + json.WriteByte('\n') + } + gen(&statsJSON, "top_queried_domains", s.Domains, true) + gen(&statsJSON, "top_blocked_domains", s.Blocked, true) + gen(&statsJSON, "top_clients", s.Clients, true) + statsJSON.WriteString(" \"stats_period\": \"24 hours\"\n") + statsJSON.WriteString("}\n") + + w.Header().Set("Content-Type", "application/json") + _, err := w.Write(statsJSON.Bytes()) + if err != nil { + errorText := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errorText) + http.Error(w, errorText, 
http.StatusInternalServerError) + } +} + +// handleStatsReset resets the stats caches +func handleStatsReset(w http.ResponseWriter, r *http.Request) { + dnsServer.PurgeStats() + _, err := fmt.Fprintf(w, "OK\n") + if err != nil { + errorText := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) + } +} + +// handleStats returns aggregated stats data for the 24 hours +func handleStats(w http.ResponseWriter, r *http.Request) { + summed := dnsServer.GetAggregatedStats() + + statsJSON, err := json.Marshal(summed) + if err != nil { + errorText := fmt.Sprintf("Unable to marshal status json: %s", err) + log.Println(errorText) + http.Error(w, errorText, 500) + return + } + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(statsJSON) + if err != nil { + errorText := fmt.Sprintf("Unable to write response json: %s", err) + log.Println(errorText) + http.Error(w, errorText, 500) + return + } +} + +// HandleStatsHistory returns historical stats data for the 24 hours +func handleStatsHistory(w http.ResponseWriter, r *http.Request) { + // handle time unit and prepare our time window size + timeUnitString := r.URL.Query().Get("time_unit") + var timeUnit time.Duration + switch timeUnitString { + case "seconds": + timeUnit = time.Second + case "minutes": + timeUnit = time.Minute + case "hours": + timeUnit = time.Hour + case "days": + timeUnit = time.Hour * 24 + default: + http.Error(w, "Must specify valid time_unit parameter", http.StatusBadRequest) + return + } + + // parse start and end time + startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time")) + if err != nil { + errorText := fmt.Sprintf("Must specify valid start_time parameter: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) + return + } + endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time")) + if err != nil { + errorText := fmt.Sprintf("Must specify valid 
end_time parameter: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) + return + } + + data, err := dnsServer.GetStatsHistory(timeUnit, startTime, endTime) + if err != nil { + errorText := fmt.Sprintf("Cannot get stats history: %s", err) + http.Error(w, errorText, http.StatusBadRequest) + return + } + + statsJSON, err := json.Marshal(data) + if err != nil { + errorText := fmt.Sprintf("Unable to marshal status json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(statsJSON) + if err != nil { + errorText := fmt.Sprintf("Unable to write response json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) + return + } +} + +// sortByValue is a helper function for querylog API +func sortByValue(m map[string]int) []string { + type kv struct { + k string + v int + } + var ss []kv + for k, v := range m { + ss = append(ss, kv{k, v}) + } + sort.Slice(ss, func(l, r int) bool { + return ss[l].v > ss[r].v + }) + + sorted := []string{} + for _, v := range ss { + sorted = append(sorted, v.k) + } + return sorted +} + +// ----------------------- +// upstreams configuration +// ----------------------- + func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { @@ -737,8 +927,8 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { data.Interfaces = make(map[string]interface{}) for _, iface := range ifaces { - addrs, err := iface.Addrs() - if err != nil { + addrs, e := iface.Addrs() + if e != nil { - httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, err) + httpError(w, http.StatusInternalServerError, "Failed to get addresses for interface %s: %s", iface.Name, e) return } @@ -844,17 +1034,17 @@ func registerControlHandlers() { http.HandleFunc("/control/status", postInstall(optionalAuth(ensureGET(handleStatus)))) 
http.HandleFunc("/control/enable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionEnable)))) http.HandleFunc("/control/disable_protection", postInstall(optionalAuth(ensurePOST(handleProtectionDisable)))) - http.HandleFunc("/control/querylog", postInstall(optionalAuth(ensureGET(dnsforward.HandleQueryLog)))) + http.HandleFunc("/control/querylog", postInstall(optionalAuth(ensureGET(handleQueryLog)))) http.HandleFunc("/control/querylog_enable", postInstall(optionalAuth(ensurePOST(handleQueryLogEnable)))) http.HandleFunc("/control/querylog_disable", postInstall(optionalAuth(ensurePOST(handleQueryLogDisable)))) http.HandleFunc("/control/set_upstream_dns", postInstall(optionalAuth(ensurePOST(handleSetUpstreamDNS)))) http.HandleFunc("/control/test_upstream_dns", postInstall(optionalAuth(ensurePOST(handleTestUpstreamDNS)))) http.HandleFunc("/control/i18n/change_language", postInstall(optionalAuth(ensurePOST(handleI18nChangeLanguage)))) http.HandleFunc("/control/i18n/current_language", postInstall(optionalAuth(ensureGET(handleI18nCurrentLanguage)))) - http.HandleFunc("/control/stats_top", postInstall(optionalAuth(ensureGET(dnsforward.HandleStatsTop)))) - http.HandleFunc("/control/stats", postInstall(optionalAuth(ensureGET(dnsforward.HandleStats)))) - http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(dnsforward.HandleStatsHistory)))) - http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(dnsforward.HandleStatsReset)))) + http.HandleFunc("/control/stats_top", postInstall(optionalAuth(ensureGET(handleStatsTop)))) + http.HandleFunc("/control/stats", postInstall(optionalAuth(ensureGET(handleStats)))) + http.HandleFunc("/control/stats_history", postInstall(optionalAuth(ensureGET(handleStatsHistory)))) + http.HandleFunc("/control/stats_reset", postInstall(optionalAuth(ensurePOST(handleStatsReset)))) http.HandleFunc("/control/version.json", postInstall(optionalAuth(handleGetVersionJSON))) 
http.HandleFunc("/control/filtering/enable", postInstall(optionalAuth(ensurePOST(handleFilteringEnable)))) http.HandleFunc("/control/filtering/disable", postInstall(optionalAuth(ensurePOST(handleFilteringDisable)))) diff --git a/dns.go b/dns.go index 12c71def..3e800892 100644 --- a/dns.go +++ b/dns.go @@ -3,6 +3,7 @@ package main import ( "fmt" "net" + "os" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" @@ -11,10 +12,22 @@ import ( "github.com/joomcode/errorx" ) -var dnsServer = dnsforward.Server{} +var dnsServer *dnsforward.Server + +// initDNSServer creates an instance of the dnsforward.Server +// Please note that we must do it even if we don't start it +// so that we had access to the query log and the stats +func initDNSServer(baseDir string) { + err := os.MkdirAll(baseDir, 0755) + if err != nil { + log.Fatalf("Cannot create DNS data dir at %s: %s", baseDir, err) + } + + dnsServer = dnsforward.NewServer(baseDir) +} func isRunning() bool { - return dnsServer.IsRunning() + return dnsServer != nil && dnsServer.IsRunning() } func generateServerConfig() dnsforward.ServerConfig { diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index b2cf0556..e1006f83 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -35,14 +35,25 @@ const ( // // The zero Server is empty and ready for use. 
type Server struct { - dnsProxy *proxy.Proxy // DNS proxy instance - + dnsProxy *proxy.Proxy // DNS proxy instance dnsFilter *dnsfilter.Dnsfilter // DNS filter instance + queryLog *queryLog // Query log instance + stats *stats // General server statistics + once sync.Once sync.RWMutex ServerConfig } +// NewServer creates a new instance of the dnsforward.Server +// baseDir is the base directory for query logs +func NewServer(baseDir string) *Server { + return &Server{ + queryLog: newQueryLog(baseDir), + stats: newStats(), + } +} + // FilteringConfig represents the DNS filtering configuration of AdGuard Home type FilteringConfig struct { ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features @@ -105,21 +116,31 @@ func (s *Server) startInternal(config *ServerConfig) error { return errors.New("DNS server is already started") } + if s.queryLog == nil { + s.queryLog = newQueryLog(".") + } + + if s.stats == nil { + s.stats = newStats() + } + err := s.initDNSFilter() if err != nil { return err } log.Tracef("Loading stats from querylog") - err = fillStatsFromQueryLog() + err = s.queryLog.fillStatsFromQueryLog(s.stats) if err != nil { return errorx.Decorate(err, "failed to load stats from querylog") } - once.Do(func() { - go periodicQueryLogRotate() - go periodicHourlyTopRotate() - go statsRotator() + // TODO: Think about reworking this, the current approach won't work properly if AG Home is restarted periodically + s.once.Do(func() { + log.Printf("Start DNS server periodic jobs") + go s.queryLog.periodicQueryLogRotate() + go s.queryLog.runningTop.periodicHourlyTopRotate() + go s.stats.statsRotator() }) proxyConfig := proxy.Config{ @@ -187,17 +208,7 @@ func (s *Server) stopInternal() error { } // flush remainder to file - logBufferLock.Lock() - flushBuffer := logBuffer - logBuffer = nil - logBufferLock.Unlock() - err := flushToFile(flushBuffer) - if err != nil { - log.Printf("Saving querylog to file failed: %s", err) - return err - 
} - - return nil + return s.queryLog.flushLogBuffer() } // IsRunning returns true if the DNS server is running @@ -229,6 +240,36 @@ func (s *Server) Reconfigure(config *ServerConfig) error { return nil } +// GetQueryLog returns the current query log entries ready to be converted to a JSON +func (s *Server) GetQueryLog() []map[string]interface{} { + return s.queryLog.getQueryLog() +} + +// GetStatsTop returns the current top stats +func (s *Server) GetStatsTop() *StatsTop { + return s.queryLog.runningTop.getStatsTop() +} + +// PurgeStats purges current server stats +func (s *Server) PurgeStats() { + // TODO: Locks? + s.stats.purgeStats() +} + +// GetAggregatedStats returns aggregated stats data for the 24 hours +func (s *Server) GetAggregatedStats() map[string]interface{} { + return s.stats.getAggregatedStats() +} + +// GetStatsHistory gets stats history aggregated by the specified time unit +// timeUnit is either time.Second, time.Minute, time.Hour, or 24*time.Hour +// startTime is the start of the time range +// endTime is the end of the time range +// returns nil if time unit is not supported +func (s *Server) GetStatsHistory(timeUnit time.Duration, startTime time.Time, endTime time.Time) (map[string]interface{}, error) { + return s.stats.getStatsHistory(timeUnit, startTime, endTime) +} + // handleDNSRequest filters the incoming DNS requests and writes them to the query log func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { start := time.Now() @@ -261,7 +302,10 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { if d.Upstream != nil { upstreamAddr = d.Upstream.Address() } - logRequest(msg, d.Res, res, elapsed, d.Addr, upstreamAddr) + entry := s.queryLog.logRequest(msg, d.Res, res, elapsed, d.Addr, upstreamAddr) + if entry != nil { + s.stats.incrementCounters(entry) + } } return nil @@ -402,5 +446,3 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR { } return []dns.RR{&soa} } - -var once sync.Once diff --git 
a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 0edde88b..9553b9ed 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -2,18 +2,19 @@ package dnsforward import ( "net" + "os" "testing" "time" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/stretchr/testify/assert" + "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/miekg/dns" ) func TestServer(t *testing.T) { - s := Server{} - s.UDPListenAddr = &net.UDPAddr{Port: 0} - s.TCPListenAddr = &net.TCPAddr{Port: 0} + s := createTestServer(t) + defer removeDataDir(t) err := s.Start(nil) if err != nil { t.Fatalf("Failed to start server: %s", err) @@ -29,6 +30,14 @@ func TestServer(t *testing.T) { } assertResponse(t, reply) + // check query log and stats + log := s.GetQueryLog() + assert.Equal(t, 1, len(log), "Log size") + stats := s.GetStatsTop() + assert.Equal(t, 1, len(stats.Domains), "Top domains length") + assert.Equal(t, 0, len(stats.Blocked), "Top blocked length") + assert.Equal(t, 1, len(stats.Clients), "Top clients length") + // message over TCP req = createTestMessage() addr = s.dnsProxy.Addr("tcp") @@ -39,6 +48,15 @@ func TestServer(t *testing.T) { } assertResponse(t, reply) + // check query log and stats again + log = s.GetQueryLog() + assert.Equal(t, 2, len(log), "Log size") + stats = s.GetStatsTop() + // Length did not change as we queried the same domain + assert.Equal(t, 1, len(stats.Domains), "Top domains length") + assert.Equal(t, 0, len(stats.Blocked), "Top blocked length") + assert.Equal(t, 1, len(stats.Clients), "Top clients length") + err = s.Stop() if err != nil { t.Fatalf("DNS server failed to stop: %s", err) @@ -46,9 +64,8 @@ func TestServer(t *testing.T) { } func TestInvalidRequest(t *testing.T) { - s := Server{} - s.UDPListenAddr = &net.UDPAddr{Port: 0} - s.TCPListenAddr = &net.TCPAddr{Port: 0} + s := createTestServer(t) + defer removeDataDir(t) err := s.Start(nil) if err != nil { t.Fatalf("Failed to start server: %s", err) 
@@ -67,6 +84,15 @@ func TestInvalidRequest(t *testing.T) { t.Fatalf("got a response to an invalid query") } + // check query log and stats + // invalid requests aren't written to the query log + log := s.GetQueryLog() + assert.Equal(t, 0, len(log), "Log size") + stats := s.GetStatsTop() + assert.Equal(t, 0, len(stats.Domains), "Top domains length") + assert.Equal(t, 0, len(stats.Blocked), "Top blocked length") + assert.Equal(t, 0, len(stats.Clients), "Top clients length") + err = s.Stop() if err != nil { t.Fatalf("DNS server failed to stop: %s", err) @@ -74,7 +100,8 @@ func TestInvalidRequest(t *testing.T) { } func TestBlockedRequest(t *testing.T) { - s := createTestServer() + s := createTestServer(t) + defer removeDataDir(t) err := s.Start(nil) if err != nil { t.Fatalf("Failed to start server: %s", err) @@ -99,6 +126,14 @@ func TestBlockedRequest(t *testing.T) { t.Fatalf("Wrong response: %s", reply.String()) } + // check query log and stats + log := s.GetQueryLog() + assert.Equal(t, 1, len(log), "Log size") + stats := s.GetStatsTop() + assert.Equal(t, 1, len(stats.Domains), "Top domains length") + assert.Equal(t, 1, len(stats.Blocked), "Top blocked length") + assert.Equal(t, 1, len(stats.Clients), "Top clients length") + err = s.Stop() if err != nil { t.Fatalf("DNS server failed to stop: %s", err) @@ -106,7 +141,8 @@ func TestBlockedRequest(t *testing.T) { } func TestBlockedByHosts(t *testing.T) { - s := createTestServer() + s := createTestServer(t) + defer removeDataDir(t) err := s.Start(nil) if err != nil { t.Fatalf("Failed to start server: %s", err) @@ -138,6 +174,14 @@ func TestBlockedByHosts(t *testing.T) { t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) } + // check query log and stats + log := s.GetQueryLog() + assert.Equal(t, 1, len(log), "Log size") + stats := s.GetStatsTop() + assert.Equal(t, 1, len(stats.Domains), "Top domains length") + assert.Equal(t, 1, len(stats.Blocked), "Top blocked length") + 
assert.Equal(t, 1, len(stats.Clients), "Top clients length") + err = s.Stop() if err != nil { t.Fatalf("DNS server failed to stop: %s", err) @@ -145,7 +189,8 @@ func TestBlockedByHosts(t *testing.T) { } func TestBlockedBySafeBrowsing(t *testing.T) { - s := createTestServer() + s := createTestServer(t) + defer removeDataDir(t) err := s.Start(nil) if err != nil { t.Fatalf("Failed to start server: %s", err) @@ -188,16 +233,25 @@ func TestBlockedBySafeBrowsing(t *testing.T) { t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) } + // check query log and stats + log := s.GetQueryLog() + assert.Equal(t, 1, len(log), "Log size") + stats := s.GetStatsTop() + assert.Equal(t, 1, len(stats.Domains), "Top domains length") + assert.Equal(t, 1, len(stats.Blocked), "Top blocked length") + assert.Equal(t, 1, len(stats.Clients), "Top clients length") + err = s.Stop() if err != nil { t.Fatalf("DNS server failed to stop: %s", err) } } -func createTestServer() *Server { - s := Server{} +func createTestServer(t *testing.T) *Server { + s := NewServer(createDataDir(t)) s.UDPListenAddr = &net.UDPAddr{Port: 0} s.TCPListenAddr = &net.TCPAddr{Port: 0} + s.QueryLogEnabled = true s.FilteringConfig.FilteringEnabled = true s.FilteringConfig.ProtectionEnabled = true s.FilteringConfig.SafeBrowsingEnabled = true @@ -209,7 +263,24 @@ func createTestServer() *Server { } filter := dnsfilter.Filter{ID: 1, Rules: rules} s.Filters = append(s.Filters, filter) - return &s + return s +} + +func createDataDir(t *testing.T) string { + dir := "testData" + err := os.MkdirAll(dir, 0755) + if err != nil { + t.Fatalf("Cannot create %s: %s", dir, err) + } + return dir +} + +func removeDataDir(t *testing.T) { + dir := "testData" + err := os.RemoveAll(dir) + if err != nil { + t.Fatalf("Cannot remove %s: %s", dir, err) + } } func createTestMessage() *dns.Msg { diff --git a/dnsforward/querylog.go b/dnsforward/querylog.go index fc51d165..52fa115c 100644 --- 
a/dnsforward/querylog.go +++ b/dnsforward/querylog.go @@ -1,10 +1,9 @@ package dnsforward import ( - "encoding/json" "fmt" "net" - "net/http" + "path/filepath" "strconv" "strings" "sync" @@ -24,13 +23,27 @@ const ( queryLogTopSize = 500 // Keep in memory only top N values ) -var ( +// queryLog is a structure that writes and reads the DNS query log +type queryLog struct { + logFile string // path to the log file + runningTop *dayTop // current top charts + logBufferLock sync.RWMutex logBuffer []*logEntry queryLogCache []*logEntry queryLogLock sync.RWMutex -) +} + +// newQueryLog creates a new instance of the query log +func newQueryLog(baseDir string) *queryLog { + l := &queryLog{ + logFile: filepath.Join(baseDir, queryLogFileName), + runningTop: &dayTop{}, + } + l.runningTop.init() + return l +} type logEntry struct { Question []byte @@ -42,7 +55,7 @@ type logEntry struct { Upstream string `json:",omitempty"` // if empty, means it was cached } -func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, addr net.Addr, upstream string) { +func (l *queryLog) logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, addr net.Addr, upstream string) *logEntry { var q []byte var a []byte var err error @@ -52,7 +65,7 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el q, err = question.Pack() if err != nil { log.Printf("failed to pack question for querylog: %s", err) - return + return nil } } @@ -60,7 +73,7 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el a, err = answer.Pack() if err != nil { log.Printf("failed to pack answer for querylog: %s", err) - return + return nil } } @@ -80,49 +93,49 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, el } var flushBuffer []*logEntry - logBufferLock.Lock() - logBuffer = append(logBuffer, &entry) - if len(logBuffer) >= logBufferCap { - flushBuffer = logBuffer 
- logBuffer = nil + l.logBufferLock.Lock() + l.logBuffer = append(l.logBuffer, &entry) + if len(l.logBuffer) >= logBufferCap { + flushBuffer = l.logBuffer + l.logBuffer = nil } - logBufferLock.Unlock() - queryLogLock.Lock() - queryLogCache = append(queryLogCache, &entry) - if len(queryLogCache) > queryLogSize { - toremove := len(queryLogCache) - queryLogSize - queryLogCache = queryLogCache[toremove:] + l.logBufferLock.Unlock() + l.queryLogLock.Lock() + l.queryLogCache = append(l.queryLogCache, &entry) + if len(l.queryLogCache) > queryLogSize { + toremove := len(l.queryLogCache) - queryLogSize + l.queryLogCache = l.queryLogCache[toremove:] } - queryLogLock.Unlock() + l.queryLogLock.Unlock() // add it to running top - err = runningTop.addEntry(&entry, question, now) + err = l.runningTop.addEntry(&entry, question, now) if err != nil { log.Printf("Failed to add entry to running top: %s", err) // don't do failure, just log } - incrementCounters(&entry) - // if buffer needs to be flushed to disk, do it now if len(flushBuffer) > 0 { // write to file // do it in separate goroutine -- we are stalling DNS response this whole time go func() { - err := flushToFile(flushBuffer) + err := l.flushToFile(flushBuffer) if err != nil { log.Printf("Failed to flush the query log: %s", err) } }() } + + return &entry } -// HandleQueryLog handles query log web request -func HandleQueryLog(w http.ResponseWriter, r *http.Request) { - queryLogLock.RLock() - values := make([]*logEntry, len(queryLogCache)) - copy(values, queryLogCache) - queryLogLock.RUnlock() +// getQueryLog returns the current query log entries ready to be converted to a JSON +func (l *queryLog) getQueryLog() []map[string]interface{} { + l.queryLogLock.RLock() + values := make([]*logEntry, len(l.queryLogCache)) + copy(values, l.queryLogCache) + l.queryLogLock.RUnlock() // reverse it so that newest is first for left, right := 0, len(values)-1; left < right; left, right = left+1, right-1 { @@ -182,21 +195,7 @@ func 
HandleQueryLog(w http.ResponseWriter, r *http.Request) { data = append(data, jsonEntry) } - jsonVal, err := json.Marshal(data) - if err != nil { - errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - errorText := fmt.Sprintf("Unable to write response json: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - } + return data } func answerToMap(a *dns.Msg) []map[string]interface{} { diff --git a/dnsforward/querylog_file.go b/dnsforward/querylog_file.go index 9ab048db..8aadd5ae 100644 --- a/dnsforward/querylog_file.go +++ b/dnsforward/querylog_file.go @@ -19,7 +19,23 @@ var ( const enableGzip = false -func flushToFile(buffer []*logEntry) error { +// flushLogBuffer flushes the current buffer to file and resets the current buffer +func (l *queryLog) flushLogBuffer() error { + // flush remainder to file + l.logBufferLock.Lock() + flushBuffer := l.logBuffer + l.logBuffer = nil + l.logBufferLock.Unlock() + err := l.flushToFile(flushBuffer) + if err != nil { + log.Printf("Saving querylog to file failed: %s", err) + return err + } + return nil +} + +// flushToFile saves the specified log entries to the query log file +func (l *queryLog) flushToFile(buffer []*logEntry) error { if len(buffer) == 0 { return nil } @@ -45,14 +61,14 @@ func flushToFile(buffer []*logEntry) error { } var zb bytes.Buffer - filename := queryLogFileName + filename := l.logFile // gzip enabled? 
if enableGzip { filename += ".gz" zw := gzip.NewWriter(&zb) - zw.Name = queryLogFileName + zw.Name = l.logFile zw.ModTime = time.Now() _, err = zw.Write(b.Bytes()) @@ -118,13 +134,13 @@ func checkBuffer(buffer []*logEntry, b bytes.Buffer) error { return nil } -func rotateQueryLog() error { - from := queryLogFileName - to := queryLogFileName + ".1" +func (l *queryLog) rotateQueryLog() error { + from := l.logFile + to := l.logFile + ".1" if enableGzip { - from = queryLogFileName + ".gz" - to = queryLogFileName + ".gz.1" + from = l.logFile + ".gz" + to = l.logFile + ".gz.1" } if _, err := os.Stat(from); os.IsNotExist(err) { @@ -143,9 +159,9 @@ func rotateQueryLog() error { return nil } -func periodicQueryLogRotate() { +func (l *queryLog) periodicQueryLogRotate() { for range time.Tick(queryLogRotationPeriod) { - err := rotateQueryLog() + err := l.rotateQueryLog() if err != nil { log.Printf("Failed to rotate querylog: %s", err) // do nothing, continue rotating @@ -153,20 +169,20 @@ func periodicQueryLogRotate() { } } -func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error { +func (l *queryLog) genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error { now := time.Now() // read from querylog files, try newest file first var files []string if enableGzip { files = []string{ - queryLogFileName + ".gz", - queryLogFileName + ".gz.1", + l.logFile + ".gz", + l.logFile + ".gz.1", } } else { files = []string{ - queryLogFileName, - queryLogFileName + ".1", + l.logFile, + l.logFile + ".1", } } diff --git a/dnsforward/querylog_top.go b/dnsforward/querylog_top.go index 5c08a223..25ad9791 100644 --- a/dnsforward/querylog_top.go +++ b/dnsforward/querylog_top.go @@ -1,14 +1,10 @@ package dnsforward import ( - "bytes" "fmt" - "net/http" "os" "path" "runtime" - "sort" - "strconv" "strings" "sync" "time" @@ -40,32 +36,30 @@ type dayTop struct { loadedLock sync.Mutex } -var runningTop 
dayTop - -func init() { - runningTop.hoursWriteLock() +func (d *dayTop) init() { + d.hoursWriteLock() for i := 0; i < 24; i++ { hour := hourTop{} hour.init() - runningTop.hours = append(runningTop.hours, &hour) + d.hours = append(d.hours, &hour) } - runningTop.hoursWriteUnlock() + d.hoursWriteUnlock() } -func rotateHourlyTop() { +func (d *dayTop) rotateHourlyTop() { log.Printf("Rotating hourly top") hour := &hourTop{} hour.init() - runningTop.hoursWriteLock() - runningTop.hours = append([]*hourTop{hour}, runningTop.hours...) - runningTop.hours = runningTop.hours[:24] - runningTop.hoursWriteUnlock() + d.hoursWriteLock() + d.hours = append([]*hourTop{hour}, d.hours...) + d.hours = d.hours[:24] + d.hoursWriteUnlock() } -func periodicHourlyTopRotate() { +func (d *dayTop) periodicHourlyTopRotate() { t := time.Hour for range time.Tick(t) { - rotateHourlyTop() + d.rotateHourlyTop() } } @@ -165,16 +159,16 @@ func (d *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { hostname := strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")) // get value, if not set, crate one - runningTop.hoursReadLock() - defer runningTop.hoursReadUnlock() - err := runningTop.hours[hour].incrementDomains(hostname) + d.hoursReadLock() + defer d.hoursReadUnlock() + err := d.hours[hour].incrementDomains(hostname) if err != nil { log.Printf("Failed to increment value: %s", err) return err } if entry.Result.IsFiltered { - err := runningTop.hours[hour].incrementBlocked(hostname) + err := d.hours[hour].incrementBlocked(hostname) if err != nil { log.Printf("Failed to increment value: %s", err) return err @@ -182,7 +176,7 @@ func (d *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { } if len(entry.IP) > 0 { - err := runningTop.hours[hour].incrementClients(entry.IP) + err := d.hours[hour].incrementClients(entry.IP) if err != nil { log.Printf("Failed to increment value: %s", err) return err @@ -192,11 +186,11 @@ func (d *dayTop) addEntry(entry *logEntry, q 
*dns.Msg, now time.Time) error { return nil } -func fillStatsFromQueryLog() error { +func (l *queryLog) fillStatsFromQueryLog(s *stats) error { now := time.Now() - runningTop.loadedWriteLock() - defer runningTop.loadedWriteUnlock() - if runningTop.loaded { + l.runningTop.loadedWriteLock() + defer l.runningTop.loadedWriteUnlock() + if l.runningTop.loaded { return nil } onEntry := func(entry *logEntry) error { @@ -221,42 +215,49 @@ func fillStatsFromQueryLog() error { return nil } - err := runningTop.addEntry(entry, q, now) + err := l.runningTop.addEntry(entry, q, now) if err != nil { log.Printf("Failed to add entry to running top: %s", err) return err } - queryLogLock.Lock() - queryLogCache = append(queryLogCache, entry) - if len(queryLogCache) > queryLogSize { - toremove := len(queryLogCache) - queryLogSize - queryLogCache = queryLogCache[toremove:] + l.queryLogLock.Lock() + l.queryLogCache = append(l.queryLogCache, entry) + if len(l.queryLogCache) > queryLogSize { + toremove := len(l.queryLogCache) - queryLogSize + l.queryLogCache = l.queryLogCache[toremove:] } - queryLogLock.Unlock() - - incrementCounters(entry) + l.queryLogLock.Unlock() + s.incrementCounters(entry) return nil } needMore := func() bool { return true } - err := genericLoader(onEntry, needMore, queryLogTimeLimit) + err := l.genericLoader(onEntry, needMore, queryLogTimeLimit) if err != nil { log.Printf("Failed to load entries from querylog: %s", err) return err } - runningTop.loaded = true - + l.runningTop.loaded = true return nil } -// HandleStatsTop returns the current top stats -func HandleStatsTop(w http.ResponseWriter, r *http.Request) { - domains := map[string]int{} - blocked := map[string]int{} - clients := map[string]int{} +// StatsTop represents top stat charts +type StatsTop struct { + Domains map[string]int // Domains - top requested domains + Blocked map[string]int // Blocked - top blocked domains + Clients map[string]int // Clients - top DNS clients +} + +// getStatsTop returns the 
current top stats +func (d *dayTop) getStatsTop() *StatsTop { + s := &StatsTop{ + Domains: map[string]int{}, + Blocked: map[string]int{}, + Clients: map[string]int{}, + } do := func(keys []interface{}, getter func(key string) (int, error), result map[string]int) { for _, ikey := range keys { @@ -273,79 +274,17 @@ func HandleStatsTop(w http.ResponseWriter, r *http.Request) { } } - runningTop.hoursReadLock() + d.hoursReadLock() for hour := 0; hour < 24; hour++ { - runningTop.hours[hour].RLock() - do(runningTop.hours[hour].domains.Keys(), runningTop.hours[hour].lockedGetDomains, domains) - do(runningTop.hours[hour].blocked.Keys(), runningTop.hours[hour].lockedGetBlocked, blocked) - do(runningTop.hours[hour].clients.Keys(), runningTop.hours[hour].lockedGetClients, clients) - runningTop.hours[hour].RUnlock() + d.hours[hour].RLock() + do(d.hours[hour].domains.Keys(), d.hours[hour].lockedGetDomains, s.Domains) + do(d.hours[hour].blocked.Keys(), d.hours[hour].lockedGetBlocked, s.Blocked) + do(d.hours[hour].clients.Keys(), d.hours[hour].lockedGetClients, s.Clients) + d.hours[hour].RUnlock() } - runningTop.hoursReadUnlock() + d.hoursReadUnlock() - // use manual json marshalling because we want maps to be sorted by value - json := bytes.Buffer{} - json.WriteString("{\n") - - gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) { - json.WriteString(" ") - json.WriteString(fmt.Sprintf("%q", name)) - json.WriteString(": {\n") - sorted := sortByValue(top) - // no more than 50 entries - if len(sorted) > 50 { - sorted = sorted[:50] - } - for i, key := range sorted { - json.WriteString(" ") - json.WriteString(fmt.Sprintf("%q", key)) - json.WriteString(": ") - json.WriteString(strconv.Itoa(top[key])) - if i+1 != len(sorted) { - json.WriteByte(',') - } - json.WriteByte('\n') - } - json.WriteString(" }") - if addComma { - json.WriteByte(',') - } - json.WriteByte('\n') - } - gen(&json, "top_queried_domains", domains, true) - gen(&json, "top_blocked_domains", 
blocked, true) - gen(&json, "top_clients", clients, true) - json.WriteString(" \"stats_period\": \"24 hours\"\n") - json.WriteString("}\n") - - w.Header().Set("Content-Type", "application/json") - _, err := w.Write(json.Bytes()) - if err != nil { - errorText := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - } -} - -// helper function for querylog API -func sortByValue(m map[string]int) []string { - type kv struct { - k string - v int - } - var ss []kv - for k, v := range m { - ss = append(ss, kv{k, v}) - } - sort.Slice(ss, func(l, r int) bool { - return ss[l].v > ss[r].v - }) - - sorted := []string{} - for _, v := range ss { - sorted = append(sorted, v.k) - } - return sorted + return s } func (d *dayTop) hoursWriteLock() { tracelock(); d.hoursLock.Lock() } diff --git a/dnsforward/stats.go b/dnsforward/stats.go index cbc25af7..705a250f 100644 --- a/dnsforward/stats.go +++ b/dnsforward/stats.go @@ -1,68 +1,76 @@ package dnsforward import ( - "encoding/json" "fmt" - "net/http" "sync" "time" "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/hmage/golibs/log" ) -var ( - requests = newDNSCounter("requests_total") - filtered = newDNSCounter("filtered_total") - filteredLists = newDNSCounter("filtered_lists_total") - filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total") - filteredParental = newDNSCounter("filtered_parental_total") - whitelisted = newDNSCounter("whitelisted_total") - safesearch = newDNSCounter("safesearch_total") - errorsTotal = newDNSCounter("errors_total") - elapsedTime = newDNSHistogram("request_duration") -) - -// entries for single time period (for example all per-second entries) -type statsEntries map[string][statsHistoryElements]float64 - // how far back to keep the stats const statsHistoryElements = 60 + 1 // +1 for calculating delta +// entries for single time period (for example all per-second entries) +type statsEntries 
map[string][statsHistoryElements]float64 + // each periodic stat is a map of arrays type periodicStats struct { - Entries statsEntries + entries statsEntries period time.Duration // how long one entry lasts - LastRotate time.Time // last time this data was rotated + lastRotate time.Time // last time this data was rotated sync.RWMutex } +// stats is the DNS server historical statistics type stats struct { - PerSecond periodicStats - PerMinute periodicStats - PerHour periodicStats - PerDay periodicStats + perSecond periodicStats + perMinute periodicStats + perHour periodicStats + perDay periodicStats + + requests *counter // total number of requests + filtered *counter // total number of filtered requests + filteredLists *counter // total number of requests blocked by filter lists + filteredSafebrowsing *counter // total number of requests blocked by safebrowsing + filteredParental *counter // total number of requests blocked by the parental control + whitelisted *counter // total number of requests whitelisted by filter lists + safesearch *counter // total number of requests for which safe search rules were applied + errorsTotal *counter // total number of errors + elapsedTime *histogram // requests duration histogram } -// per-second/per-minute/per-hour/per-day stats -var statistics stats +// initializes an empty stats structure +func newStats() *stats { + s := &stats{ + requests: newDNSCounter("requests_total"), + filtered: newDNSCounter("filtered_total"), + filteredLists: newDNSCounter("filtered_lists_total"), + filteredSafebrowsing: newDNSCounter("filtered_safebrowsing_total"), + filteredParental: newDNSCounter("filtered_parental_total"), + whitelisted: newDNSCounter("whitelisted_total"), + safesearch: newDNSCounter("safesearch_total"), + errorsTotal: newDNSCounter("errors_total"), + elapsedTime: newDNSHistogram("request_duration"), + } + + // Initializes empty per-sec/minute/hour/day stats + s.purgeStats() + return s +} func initPeriodicStats(periodic 
*periodicStats, period time.Duration) { - periodic.Entries = statsEntries{} - periodic.LastRotate = time.Now() + periodic.entries = statsEntries{} + periodic.lastRotate = time.Now() periodic.period = period } -func init() { - purgeStats() -} - -func purgeStats() { - initPeriodicStats(&statistics.PerSecond, time.Second) - initPeriodicStats(&statistics.PerMinute, time.Minute) - initPeriodicStats(&statistics.PerHour, time.Hour) - initPeriodicStats(&statistics.PerDay, time.Hour*24) +func (s *stats) purgeStats() { + initPeriodicStats(&s.perSecond, time.Second) + initPeriodicStats(&s.perMinute, time.Minute) + initPeriodicStats(&s.perHour, time.Hour) + initPeriodicStats(&s.perDay, time.Hour*24) } func (p *periodicStats) Inc(name string, when time.Time) { @@ -73,9 +81,9 @@ func (p *periodicStats) Inc(name string, when time.Time) { return // outside of our timeframe } p.Lock() - currentValues := p.Entries[name] + currentValues := p.entries[name] currentValues[elapsed]++ - p.Entries[name] = currentValues + p.entries[name] = currentValues p.Unlock() } @@ -89,51 +97,51 @@ func (p *periodicStats) Observe(name string, when time.Time, value float64) { p.Lock() { countname := name + "_count" - currentValues := p.Entries[countname] + currentValues := p.entries[countname] v := currentValues[elapsed] - // log.Tracef("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1) + // log.Tracef("Will change p.entries[%s][%d] from %v to %v", countname, elapsed, value, value+1) v++ currentValues[elapsed] = v - p.Entries[countname] = currentValues + p.entries[countname] = currentValues } { totalname := name + "_sum" - currentValues := p.Entries[totalname] + currentValues := p.entries[totalname] currentValues[elapsed] += value - p.Entries[totalname] = currentValues + p.entries[totalname] = currentValues } p.Unlock() } func (p *periodicStats) statsRotate(now time.Time) { p.Lock() - rotations := int64(now.Sub(p.LastRotate) / p.period) + rotations := 
int64(now.Sub(p.lastRotate) / p.period) if rotations > statsHistoryElements { rotations = statsHistoryElements } // calculate how many times we should rotate for r := int64(0); r < rotations; r++ { - for key, values := range p.Entries { + for key, values := range p.entries { newValues := [statsHistoryElements]float64{} for i := 1; i < len(values); i++ { newValues[i] = values[i-1] } - p.Entries[key] = newValues + p.entries[key] = newValues } } if rotations > 0 { - p.LastRotate = now + p.lastRotate = now } p.Unlock() } -func statsRotator() { +func (s *stats) statsRotator() { for range time.Tick(time.Second) { now := time.Now() - statistics.PerSecond.statsRotate(now) - statistics.PerMinute.statsRotate(now) - statistics.PerHour.statsRotate(now) - statistics.PerDay.statsRotate(now) + s.perSecond.statsRotate(now) + s.perMinute.statsRotate(now) + s.perHour.statsRotate(now) + s.perDay.statsRotate(now) } } @@ -152,20 +160,16 @@ func newDNSCounter(name string) *counter { } } -func (c *counter) IncWithTime(when time.Time) { - statistics.PerSecond.Inc(c.name, when) - statistics.PerMinute.Inc(c.name, when) - statistics.PerHour.Inc(c.name, when) - statistics.PerDay.Inc(c.name, when) +func (s *stats) incWithTime(c *counter, when time.Time) { + s.perSecond.Inc(c.name, when) + s.perMinute.Inc(c.name, when) + s.perHour.Inc(c.name, when) + s.perDay.Inc(c.name, when) c.Lock() c.value++ c.Unlock() } -func (c *counter) Inc() { - c.IncWithTime(time.Now()) -} - type histogram struct { name string // used as key in periodic stats count int64 @@ -180,56 +184,52 @@ func newDNSHistogram(name string) *histogram { } } -func (h *histogram) ObserveWithTime(value float64, when time.Time) { - statistics.PerSecond.Observe(h.name, when, value) - statistics.PerMinute.Observe(h.name, when, value) - statistics.PerHour.Observe(h.name, when, value) - statistics.PerDay.Observe(h.name, when, value) +func (s *stats) observeWithTime(h *histogram, value float64, when time.Time) { + s.perSecond.Observe(h.name, 
when, value) + s.perMinute.Observe(h.name, when, value) + s.perHour.Observe(h.name, when, value) + s.perDay.Observe(h.name, when, value) h.Lock() h.count++ h.total += value h.Unlock() } -func (h *histogram) Observe(value float64) { - h.ObserveWithTime(value, time.Now()) -} - // ----- // stats // ----- -func incrementCounters(entry *logEntry) { - requests.IncWithTime(entry.Time) +func (s *stats) incrementCounters(entry *logEntry) { + s.incWithTime(s.requests, entry.Time) if entry.Result.IsFiltered { - filtered.IncWithTime(entry.Time) + s.incWithTime(s.filtered, entry.Time) } switch entry.Result.Reason { case dnsfilter.NotFilteredWhiteList: - whitelisted.IncWithTime(entry.Time) + s.incWithTime(s.whitelisted, entry.Time) case dnsfilter.NotFilteredError: - errorsTotal.IncWithTime(entry.Time) + s.incWithTime(s.errorsTotal, entry.Time) case dnsfilter.FilteredBlackList: - filteredLists.IncWithTime(entry.Time) + s.incWithTime(s.filteredLists, entry.Time) case dnsfilter.FilteredSafeBrowsing: - filteredSafebrowsing.IncWithTime(entry.Time) + s.incWithTime(s.filteredSafebrowsing, entry.Time) case dnsfilter.FilteredParental: - filteredParental.IncWithTime(entry.Time) + s.incWithTime(s.filteredParental, entry.Time) case dnsfilter.FilteredInvalid: // do nothing case dnsfilter.FilteredSafeSearch: - safesearch.IncWithTime(entry.Time) + s.incWithTime(s.safesearch, entry.Time) } - elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) + s.observeWithTime(s.elapsedTime, entry.Elapsed.Seconds(), entry.Time) } -// HandleStats returns aggregated stats data for the 24 hours -func HandleStats(w http.ResponseWriter, r *http.Request) { +// getAggregatedStats returns aggregated stats data for the 24 hours +func (s *stats) getAggregatedStats() map[string]interface{} { const numHours = 24 - histrical := generateMapFromStats(&statistics.PerHour, 0, numHours) + historical := s.generateMapFromStats(&s.perHour, 0, numHours) // sum them up summed := map[string]interface{}{} - for key, 
values := range histrical { + for key, values := range historical { summedValue := 0.0 floats, ok := values.([]float64) if !ok { @@ -249,33 +249,18 @@ func HandleStats(w http.ResponseWriter, r *http.Request) { } summed["stats_period"] = "24 hours" - - json, err := json.Marshal(summed) - if err != nil { - errorText := fmt.Sprintf("Unable to marshal status json: %s", err) - log.Println(errorText) - http.Error(w, errorText, 500) - return - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(json) - if err != nil { - errorText := fmt.Sprintf("Unable to write response json: %s", err) - log.Println(errorText) - http.Error(w, errorText, 500) - return - } + return summed } -func generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} { +func (s *stats) generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} { // clamp start = clamp(start, 0, statsHistoryElements) end = clamp(end, 0, statsHistoryElements) avgProcessingTime := make([]float64, 0) - count := getReversedSlice(stats.Entries[elapsedTime.name+"_count"], start, end) - sum := getReversedSlice(stats.Entries[elapsedTime.name+"_sum"], start, end) + count := getReversedSlice(stats.entries[s.elapsedTime.name+"_count"], start, end) + sum := getReversedSlice(stats.entries[s.elapsedTime.name+"_sum"], start, end) for i := 0; i < len(count); i++ { var avg float64 if count[i] != 0 { @@ -286,66 +271,48 @@ func generateMapFromStats(stats *periodicStats, start int, end int) map[string]i } result := map[string]interface{}{ - "dns_queries": getReversedSlice(stats.Entries[requests.name], start, end), - "blocked_filtering": getReversedSlice(stats.Entries[filtered.name], start, end), - "replaced_safebrowsing": getReversedSlice(stats.Entries[filteredSafebrowsing.name], start, end), - "replaced_safesearch": getReversedSlice(stats.Entries[safesearch.name], start, end), - "replaced_parental": getReversedSlice(stats.Entries[filteredParental.name], start, 
end), + "dns_queries": getReversedSlice(stats.entries[s.requests.name], start, end), + "blocked_filtering": getReversedSlice(stats.entries[s.filtered.name], start, end), + "replaced_safebrowsing": getReversedSlice(stats.entries[s.filteredSafebrowsing.name], start, end), + "replaced_safesearch": getReversedSlice(stats.entries[s.safesearch.name], start, end), + "replaced_parental": getReversedSlice(stats.entries[s.filteredParental.name], start, end), "avg_processing_time": avgProcessingTime, } return result } -// HandleStatsHistory returns historical stats data for the 24 hours -func HandleStatsHistory(w http.ResponseWriter, r *http.Request) { - // handle time unit and prepare our time window size - now := time.Now() - timeUnitString := r.URL.Query().Get("time_unit") +// getStatsHistory gets stats history aggregated by the specified time unit +// timeUnit is either time.Second, time.Minute, time.Hour, or 24*time.Hour +// startTime is the start of the time range +// endTime is the end of the time range +// returns an error if the time unit is not supported +func (s *stats) getStatsHistory(timeUnit time.Duration, startTime time.Time, endTime time.Time) (map[string]interface{}, error) { var stats *periodicStats - var timeUnit time.Duration - switch timeUnitString { - case "seconds": - timeUnit = time.Second - stats = &statistics.PerSecond - case "minutes": - timeUnit = time.Minute - stats = &statistics.PerMinute - case "hours": - timeUnit = time.Hour - stats = &statistics.PerHour - case "days": - timeUnit = time.Hour * 24 - stats = &statistics.PerDay - default: - http.Error(w, "Must specify valid time_unit parameter", 400) - return + + switch timeUnit { + case time.Second: + stats = &s.perSecond + case time.Minute: + stats = &s.perMinute + case time.Hour: + stats = &s.perHour + case 24 * time.Hour: + stats = &s.perDay } - // parse start and end time - startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time")) - if err != nil { - errorText := fmt.Sprintf("Must specify valid 
start_time parameter: %s", err) - log.Println(errorText) - http.Error(w, errorText, 400) - return - } - endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time")) - if err != nil { - errorText := fmt.Sprintf("Must specify valid end_time parameter: %s", err) - log.Println(errorText) - http.Error(w, errorText, 400) - return + if stats == nil { + return nil, fmt.Errorf("unsupported time unit: %v", timeUnit) } + now := time.Now() + // check if start and end times are within supported time range timeRange := timeUnit * statsHistoryElements if startTime.Add(timeRange).Before(now) { - http.Error(w, "start_time parameter is outside of supported range", http.StatusBadRequest) - return + return nil, fmt.Errorf("start_time parameter is outside of supported range: %s", startTime.String()) } if endTime.Add(timeRange).Before(now) { - http.Error(w, "end_time parameter is outside of supported range", http.StatusBadRequest) - return + return nil, fmt.Errorf("end_time parameter is outside of supported range: %s", endTime.String()) } // calculate start and end of our array @@ -358,33 +325,7 @@ func HandleStatsHistory(w http.ResponseWriter, r *http.Request) { start, end = end, start } - data := generateMapFromStats(stats, start, end) - json, err := json.Marshal(data) - if err != nil { - errorText := fmt.Sprintf("Unable to marshal status json: %s", err) - log.Println(errorText) - http.Error(w, errorText, 500) - return - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(json) - if err != nil { - errorText := fmt.Sprintf("Unable to write response json: %s", err) - log.Println(errorText) - http.Error(w, errorText, 500) - return - } -} - -// HandleStatsReset resets the stats caches -func HandleStatsReset(w http.ResponseWriter, r *http.Request) { - purgeStats() - _, err := fmt.Fprintf(w, "OK\n") - if err != nil { - errorText := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errorText) - http.Error(w, errorText, 
http.StatusInternalServerError) - } + return s.generateMapFromStats(stats, start, end), nil } func clamp(value, low, high int) int { diff --git a/filter.go b/filter.go index 49b54f0e..dcdd40be 100644 --- a/filter.go +++ b/filter.go @@ -26,7 +26,7 @@ type filter struct { URL string `json:"url"` Name string `json:"name" yaml:"name"` RulesCount int `json:"rulesCount" yaml:"-"` - LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"` + LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"-"` dnsfilter.Filter `yaml:",inline"` } @@ -95,6 +95,12 @@ func refreshFiltersIfNecessary(force bool) int { filter.ID = assignUniqueFilterID() } + if len(filter.Rules) == 0 { + // Try reloading filter from the disk before updating + // This is useful for the case when we simply enable a previously downloaded filter + _ = filter.load() + } + updated, err := filter.update(force) if err != nil { log.Printf("Failed to update filter %s: %s\n", filter.URL, err) @@ -162,9 +168,6 @@ func (filter *filter) update(force bool) (bool, error) { log.Printf("Downloading update for filter %d from %s", filter.ID, filter.URL) - // use the same update period for failed filter downloads to avoid flooding with requests - filter.LastUpdated = time.Now() - resp, err := client.Get(filter.URL) if resp != nil && resp.Body != nil { defer resp.Body.Close() @@ -217,7 +220,11 @@ func (filter *filter) save() error { log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) body := []byte(strings.Join(filter.Rules, "\n")) - return safeWriteFile(filterFilePath, body) + err := safeWriteFile(filterFilePath, body) + + // update LastUpdated field after saving the file + filter.LastUpdated = filter.LastTimeUpdated() + return err } // loads filter contents from the file in dataDir @@ -245,11 +252,30 @@ func (filter *filter) load() error { filter.RulesCount = rulesCount filter.Rules = rules + filter.LastUpdated = filter.LastTimeUpdated() return nil } // Path to the filter 
contents func (filter *filter) Path() string { - return filepath.Join(config.ourBinaryDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt") + return filepath.Join(config.ourWorkingDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt") +} + +// LastUpdated returns the time when the filter was last time updated +func (filter *filter) LastTimeUpdated() time.Time { + filterFilePath := filter.Path() + if _, err := os.Stat(filterFilePath); os.IsNotExist(err) { + // if the filter file does not exist, return 0001-01-01 + return time.Time{} + } + + s, err := os.Stat(filterFilePath) + if err != nil { + // if the filter file does not exist, return 0001-01-01 + return time.Time{} + } + + // filter file modified time + return s.ModTime() } diff --git a/go.mod b/go.mod index b2b0a049..68a55326 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/miekg/dns v1.1.1 github.com/shirou/gopsutil v2.18.10+incompatible github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect + github.com/stretchr/testify v1.2.2 go.uber.org/goleak v0.10.0 golang.org/x/net v0.0.0-20181220203305-927f97764cc3 golang.org/x/sys v0.0.0-20181228144115-9a3f9b0469bb diff --git a/helpers.go b/helpers.go index 1bea694e..a0cf1fd7 100644 --- a/helpers.go +++ b/helpers.go @@ -100,7 +100,7 @@ func optionalAuthHandler(handler http.Handler) http.Handler { func detectFirstRun() bool { configfile := config.ourConfigFilename if !filepath.IsAbs(configfile) { - configfile = filepath.Join(config.ourBinaryDir, config.ourConfigFilename) + configfile = filepath.Join(config.ourWorkingDir, config.ourConfigFilename) } _, err := os.Stat(configfile) if !os.IsNotExist(err) { diff --git a/upgrade.go b/upgrade.go index 02629797..0b3ddc5c 100644 --- a/upgrade.go +++ b/upgrade.go @@ -95,7 +95,7 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err func upgradeSchema0to1(diskConfig *map[string]interface{}) error { log.Printf("%s(): called", _Func()) - 
dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt") + dnsFilterPath := filepath.Join(config.ourWorkingDir, "dnsfilter.txt") if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) { log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath) err = os.Remove(dnsFilterPath) @@ -116,7 +116,7 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error { func upgradeSchema1to2(diskConfig *map[string]interface{}) error { log.Printf("%s(): called", _Func()) - coreFilePath := filepath.Join(config.ourBinaryDir, "Corefile") + coreFilePath := filepath.Join(config.ourWorkingDir, "Corefile") if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) { log.Printf("Deleting %s as we don't need it anymore", coreFilePath) err = os.Remove(coreFilePath)