Merge: * querylog: refactor: move HTTP handlers to querylog/

* commit '90db91b0fd347d168ef9589405ba812c0cfc0c2d':
  * querylog: refactor: move HTTP handlers to querylog/
This commit is contained in:
Simon Zolin 2019-10-09 19:43:00 +03:00
commit b43c076c4d
11 changed files with 276 additions and 246 deletions

View File

@ -550,7 +550,7 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error {
s.RLock() s.RLock()
// Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use. // Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use.
// This can happen after proxy server has been stopped, but its workers haven't yet exited. // This can happen after proxy server has been stopped, but its workers haven't yet exited.
if s.conf.QueryLogEnabled && shouldLog && s.queryLog != nil { if shouldLog && s.queryLog != nil {
upstreamAddr := "" upstreamAddr := ""
if d.Upstream != nil { if d.Upstream != nil {
upstreamAddr = d.Upstream.Address() upstreamAddr = d.Upstream.Address()

View File

@ -380,7 +380,6 @@ func createTestServer(t *testing.T) *Server {
s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
s.conf.QueryLogEnabled = true
s.conf.FilteringConfig.FilteringEnabled = true s.conf.FilteringConfig.FilteringEnabled = true
s.conf.FilteringConfig.ProtectionEnabled = true s.conf.FilteringConfig.ProtectionEnabled = true
s.conf.FilteringConfig.SafeBrowsingEnabled = true s.conf.FilteringConfig.SafeBrowsingEnabled = true

View File

@ -278,10 +278,6 @@ func parseConfig() error {
config.DNS.FiltersUpdateIntervalHours = 24 config.DNS.FiltersUpdateIntervalHours = 24
} }
if !checkQueryLogInterval(config.DNS.QueryLogInterval) {
config.DNS.QueryLogInterval = 1
}
for _, cy := range config.Clients { for _, cy := range config.Clients {
cli := Client{ cli := Client{
Name: cy.Name, Name: cy.Name,
@ -364,6 +360,13 @@ func (c *configuration) write() error {
config.DNS.StatsInterval = sdc.Interval config.DNS.StatsInterval = sdc.Interval
} }
if config.queryLog != nil {
dc := querylog.DiskConfig{}
config.queryLog.WriteDiskConfig(&dc)
config.DNS.QueryLogEnabled = dc.Enabled
config.DNS.QueryLogInterval = dc.Interval
}
configFile := config.getConfigFilename() configFile := config.getConfigFilename()
log.Debug("Writing YAML file: %s", configFile) log.Debug("Writing YAML file: %s", configFile)
yamlText, err := yaml.Marshal(&config) yamlText, err := yaml.Marshal(&config)

View File

@ -111,7 +111,6 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
"http_port": config.BindPort, "http_port": config.BindPort,
"dns_port": config.DNS.Port, "dns_port": config.DNS.Port,
"protection_enabled": config.DNS.ProtectionEnabled, "protection_enabled": config.DNS.ProtectionEnabled,
"querylog_enabled": config.DNS.QueryLogEnabled,
"running": isRunning(), "running": isRunning(),
"bootstrap_dns": config.DNS.BootstrapDNS, "bootstrap_dns": config.DNS.BootstrapDNS,
"upstream_dns": config.DNS.UpstreamDNS, "upstream_dns": config.DNS.UpstreamDNS,
@ -568,7 +567,6 @@ func registerControlHandlers() {
RegisterClientsHandlers() RegisterClientsHandlers()
registerRewritesHandlers() registerRewritesHandlers()
RegisterBlockedServicesHandlers() RegisterBlockedServicesHandlers()
RegisterQueryLogHandlers()
RegisterAuthHandlers() RegisterAuthHandlers()
http.HandleFunc("/dns-query", postInstall(handleDOH)) http.HandleFunc("/dns-query", postInstall(handleDOH))

View File

@ -1,160 +0,0 @@
package home
import (
"encoding/json"
"net/http"
"time"
"github.com/AdguardTeam/AdGuardHome/querylog"
"github.com/miekg/dns"
)
type qlogFilterJSON struct {
Domain string `json:"domain"`
Client string `json:"client"`
QuestionType string `json:"question_type"`
ResponseStatus string `json:"response_status"`
}
type queryLogRequest struct {
OlderThan string `json:"older_than"`
Filter qlogFilterJSON `json:"filter"`
}
// "value" -> value, return TRUE
func getDoubleQuotesEnclosedValue(s *string) bool {
t := *s
if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' {
*s = t[1 : len(t)-1]
return true
}
return false
}
func handleQueryLog(w http.ResponseWriter, r *http.Request) {
req := queryLogRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
params := querylog.GetDataParams{
Domain: req.Filter.Domain,
Client: req.Filter.Client,
}
if len(req.OlderThan) != 0 {
params.OlderThan, err = time.Parse(time.RFC3339Nano, req.OlderThan)
if err != nil {
httpError(w, http.StatusBadRequest, "invalid time stamp: %s", err)
return
}
}
if getDoubleQuotesEnclosedValue(&params.Domain) {
params.StrictMatchDomain = true
}
if getDoubleQuotesEnclosedValue(&params.Client) {
params.StrictMatchClient = true
}
if len(req.Filter.QuestionType) != 0 {
qtype, ok := dns.StringToType[req.Filter.QuestionType]
if !ok {
httpError(w, http.StatusBadRequest, "invalid question_type")
return
}
params.QuestionType = qtype
}
if len(req.Filter.ResponseStatus) != 0 {
switch req.Filter.ResponseStatus {
case "filtered":
params.ResponseStatus = querylog.ResponseStatusFiltered
default:
httpError(w, http.StatusBadRequest, "invalid response_status")
return
}
}
data := config.queryLog.GetData(params)
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(w, http.StatusInternalServerError, "Couldn't marshal data into json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)
}
}
func handleQueryLogClear(w http.ResponseWriter, r *http.Request) {
config.queryLog.Clear()
returnOK(w)
}
type qlogConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"`
}
// Get configuration
func handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
resp := qlogConfig{}
resp.Enabled = config.DNS.QueryLogEnabled
resp.Interval = config.DNS.QueryLogInterval
jsonVal, err := json.Marshal(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(w, http.StatusInternalServerError, "http write: %s", err)
}
}
// Set configuration
func handleQueryLogConfig(w http.ResponseWriter, r *http.Request) {
reqData := qlogConfig{}
err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil {
httpError(w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !checkQueryLogInterval(reqData.Interval) {
httpError(w, http.StatusBadRequest, "Unsupported interval")
return
}
config.DNS.QueryLogEnabled = reqData.Enabled
config.DNS.QueryLogInterval = reqData.Interval
_ = config.write()
conf := querylog.Config{
Interval: config.DNS.QueryLogInterval * 24,
}
config.queryLog.Configure(conf)
returnOK(w)
}
func checkQueryLogInterval(i uint32) bool {
return i == 1 || i == 7 || i == 30 || i == 90
}
// RegisterQueryLogHandlers - register handlers
func RegisterQueryLogHandlers() {
httpRegister("POST", "/control/querylog", handleQueryLog)
httpRegister(http.MethodGet, "/control/querylog_info", handleQueryLogInfo)
httpRegister(http.MethodPost, "/control/querylog_clear", handleQueryLogClear)
httpRegister(http.MethodPost, "/control/querylog_config", handleQueryLogConfig)
}

View File

@ -48,8 +48,11 @@ func initDNSServer() {
log.Fatal("Couldn't initialize statistics module") log.Fatal("Couldn't initialize statistics module")
} }
conf := querylog.Config{ conf := querylog.Config{
BaseDir: baseDir, Enabled: config.DNS.QueryLogEnabled,
Interval: config.DNS.QueryLogInterval * 24, BaseDir: baseDir,
Interval: config.DNS.QueryLogInterval,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
} }
config.queryLog = querylog.New(conf) config.queryLog = querylog.New(conf)
config.dnsServer = dnsforward.NewServer(config.stats, config.queryLog) config.dnsServer = dnsforward.NewServer(config.stats, config.queryLog)

View File

@ -16,7 +16,7 @@ import (
) )
const ( const (
logBufferCap = 5000 // maximum capacity of logBuffer before it's flushed to disk logBufferCap = 5000 // maximum capacity of buffer before it's flushed to disk
queryLogFileName = "querylog.json" // .gz added during compression queryLogFileName = "querylog.json" // .gz added during compression
getDataLimit = 500 // GetData(): maximum log entries to return getDataLimit = 500 // GetData(): maximum log entries to return
@ -29,10 +29,11 @@ type queryLog struct {
conf Config conf Config
logFile string // path to the log file logFile string // path to the log file
logBufferLock sync.RWMutex bufferLock sync.RWMutex
logBuffer []*logEntry buffer []*logEntry
fileFlushLock sync.Mutex // synchronize a file-flushing goroutine and main thread fileFlushLock sync.Mutex // synchronize a file-flushing goroutine and main thread
flushPending bool // don't start another goroutine while the previous one is still running flushPending bool // don't start another goroutine while the previous one is still running
fileWriteLock sync.Mutex
} }
// create a new instance of the query log // create a new instance of the query log
@ -40,7 +41,13 @@ func newQueryLog(conf Config) *queryLog {
l := queryLog{} l := queryLog{}
l.logFile = filepath.Join(conf.BaseDir, queryLogFileName) l.logFile = filepath.Join(conf.BaseDir, queryLogFileName)
l.conf = conf l.conf = conf
go l.periodicQueryLogRotate() if !checkInterval(l.conf.Interval) {
l.conf.Interval = 1
}
if l.conf.HTTPRegister != nil {
l.initWeb()
}
go l.periodicRotate()
return &l return &l
} }
@ -48,18 +55,30 @@ func (l *queryLog) Close() {
_ = l.flushLogBuffer(true) _ = l.flushLogBuffer(true)
} }
func (l *queryLog) Configure(conf Config) { func checkInterval(days uint32) bool {
l.conf = conf return days == 1 || days == 7 || days == 30 || days == 90
} }
func (l *queryLog) Clear() { // Set new configuration at runtime
func (l *queryLog) configure(conf Config) {
l.conf.Enabled = conf.Enabled
l.conf.Interval = conf.Interval
}
func (l *queryLog) WriteDiskConfig(dc *DiskConfig) {
dc.Enabled = l.conf.Enabled
dc.Interval = l.conf.Interval
}
// Clear memory buffer and remove log files
func (l *queryLog) clear() {
l.fileFlushLock.Lock() l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock() defer l.fileFlushLock.Unlock()
l.logBufferLock.Lock() l.bufferLock.Lock()
l.logBuffer = nil l.buffer = nil
l.flushPending = false l.flushPending = false
l.logBufferLock.Unlock() l.bufferLock.Unlock()
err := os.Remove(l.logFile + ".1") err := os.Remove(l.logFile + ".1")
if err != nil && !os.IsNotExist(err) { if err != nil && !os.IsNotExist(err) {
@ -96,6 +115,10 @@ func getIPString(addr net.Addr) string {
} }
func (l *queryLog) Add(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, addr net.Addr, upstream string) { func (l *queryLog) Add(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, addr net.Addr, upstream string) {
if !l.conf.Enabled {
return
}
var q []byte var q []byte
var a []byte var a []byte
var err error var err error
@ -132,16 +155,16 @@ func (l *queryLog) Add(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Res
Upstream: upstream, Upstream: upstream,
} }
l.logBufferLock.Lock() l.bufferLock.Lock()
l.logBuffer = append(l.logBuffer, &entry) l.buffer = append(l.buffer, &entry)
needFlush := false needFlush := false
if !l.flushPending { if !l.flushPending {
needFlush = len(l.logBuffer) >= logBufferCap needFlush = len(l.buffer) >= logBufferCap
if needFlush { if needFlush {
l.flushPending = true l.flushPending = true
} }
} }
l.logBufferLock.Unlock() l.bufferLock.Unlock()
// if buffer needs to be flushed to disk, do it now // if buffer needs to be flushed to disk, do it now
if needFlush { if needFlush {
@ -152,11 +175,9 @@ func (l *queryLog) Add(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Res
} }
// Return TRUE if this entry is needed // Return TRUE if this entry is needed
func isNeeded(entry *logEntry, params GetDataParams) bool { func isNeeded(entry *logEntry, params getDataParams) bool {
if params.ResponseStatus != 0 { if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered {
if params.ResponseStatus == ResponseStatusFiltered && !entry.Result.IsFiltered { return false
return false
}
} }
if len(params.Domain) != 0 || params.QuestionType != 0 { if len(params.Domain) != 0 || params.QuestionType != 0 {
@ -193,7 +214,7 @@ func isNeeded(entry *logEntry, params GetDataParams) bool {
return true return true
} }
func (l *queryLog) readFromFile(params GetDataParams) ([]*logEntry, int) { func (l *queryLog) readFromFile(params getDataParams) ([]*logEntry, int) {
entries := []*logEntry{} entries := []*logEntry{}
olderThan := params.OlderThan olderThan := params.OlderThan
totalChunks := 0 totalChunks := 0
@ -247,7 +268,28 @@ func (l *queryLog) readFromFile(params GetDataParams) ([]*logEntry, int) {
return entries, total return entries, total
} }
func (l *queryLog) GetData(params GetDataParams) []map[string]interface{} { // Parameters for getData()
type getDataParams struct {
OlderThan time.Time // return entries that are older than this value
Domain string // filter by domain name in question
Client string // filter by client IP
QuestionType uint16 // filter by question type
ResponseStatus responseStatusType // filter by response status
StrictMatchDomain bool // if Domain value must be matched strictly
StrictMatchClient bool // if Client value must be matched strictly
}
// Response status
type responseStatusType int32
// Response status constants
const (
responseStatusAll responseStatusType = iota + 1
responseStatusFiltered
)
// Get log entries
func (l *queryLog) getData(params getDataParams) []map[string]interface{} {
var data = []map[string]interface{}{} var data = []map[string]interface{}{}
if len(params.Domain) != 0 && params.StrictMatchDomain { if len(params.Domain) != 0 && params.StrictMatchDomain {
@ -266,9 +308,9 @@ func (l *queryLog) GetData(params GetDataParams) []map[string]interface{} {
} }
// add from memory buffer // add from memory buffer
l.logBufferLock.Lock() l.bufferLock.Lock()
total += len(l.logBuffer) total += len(l.buffer)
for _, entry := range l.logBuffer { for _, entry := range l.buffer {
if !isNeeded(entry, params) { if !isNeeded(entry, params) {
continue continue
@ -283,7 +325,7 @@ func (l *queryLog) GetData(params GetDataParams) []map[string]interface{} {
} }
entries = append(entries, entry) entries = append(entries, entry)
} }
l.logBufferLock.Unlock() l.bufferLock.Unlock()
// process the elements from latest to oldest // process the elements from latest to oldest
for i := len(entries) - 1; i >= 0; i-- { for i := len(entries) - 1; i >= 0; i-- {

162
querylog/qlog_http.go Normal file
View File

@ -0,0 +1,162 @@
package querylog
import (
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
log.Info("QueryLog: %s %s: %s", r.Method, r.URL, text)
http.Error(w, text, code)
}
type filterJSON struct {
Domain string `json:"domain"`
Client string `json:"client"`
QuestionType string `json:"question_type"`
ResponseStatus string `json:"response_status"`
}
type request struct {
OlderThan string `json:"older_than"`
Filter filterJSON `json:"filter"`
}
// "value" -> value, return TRUE
func getDoubleQuotesEnclosedValue(s *string) bool {
t := *s
if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' {
*s = t[1 : len(t)-1]
return true
}
return false
}
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
req := request{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
params := getDataParams{
Domain: req.Filter.Domain,
Client: req.Filter.Client,
ResponseStatus: responseStatusAll,
}
if len(req.OlderThan) != 0 {
params.OlderThan, err = time.Parse(time.RFC3339Nano, req.OlderThan)
if err != nil {
httpError(r, w, http.StatusBadRequest, "invalid time stamp: %s", err)
return
}
}
if getDoubleQuotesEnclosedValue(&params.Domain) {
params.StrictMatchDomain = true
}
if getDoubleQuotesEnclosedValue(&params.Client) {
params.StrictMatchClient = true
}
if len(req.Filter.QuestionType) != 0 {
qtype, ok := dns.StringToType[req.Filter.QuestionType]
if !ok {
httpError(r, w, http.StatusBadRequest, "invalid question_type")
return
}
params.QuestionType = qtype
}
if len(req.Filter.ResponseStatus) != 0 {
switch req.Filter.ResponseStatus {
case "filtered":
params.ResponseStatus = responseStatusFiltered
default:
httpError(r, w, http.StatusBadRequest, "invalid response_status")
return
}
}
data := l.getData(params)
jsonVal, err := json.Marshal(data)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Couldn't marshal data into json: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
}
}
func (l *queryLog) handleQueryLogClear(w http.ResponseWriter, r *http.Request) {
l.clear()
}
type qlogConfig struct {
Enabled bool `json:"enabled"`
Interval uint32 `json:"interval"`
}
// Get configuration
func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
resp := qlogConfig{}
resp.Enabled = l.conf.Enabled
resp.Interval = l.conf.Interval
jsonVal, err := json.Marshal(resp)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
httpError(r, w, http.StatusInternalServerError, "http write: %s", err)
}
}
// Set configuration
func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request) {
reqData := qlogConfig{}
err := json.NewDecoder(r.Body).Decode(&reqData)
if err != nil {
httpError(r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
if !checkInterval(reqData.Interval) {
httpError(r, w, http.StatusBadRequest, "Unsupported interval")
return
}
conf := Config{
Enabled: reqData.Enabled,
Interval: reqData.Interval,
}
l.configure(conf)
l.conf.ConfigModified()
}
// Register web handlers
func (l *queryLog) initWeb() {
l.conf.HTTPRegister("POST", "/control/querylog", l.handleQueryLog)
l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear)
l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig)
}

View File

@ -2,58 +2,45 @@ package querylog
import ( import (
"net" "net"
"net/http"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// DiskConfig - configuration settings that are stored on disk
type DiskConfig struct {
Enabled bool
Interval uint32
}
// QueryLog - main interface // QueryLog - main interface
type QueryLog interface { type QueryLog interface {
// Close query log object // Close query log object
Close() Close()
// Set new configuration at runtime
// Currently only 'Interval' field is supported.
Configure(conf Config)
// Add a log entry // Add a log entry
Add(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, addr net.Addr, upstream string) Add(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, addr net.Addr, upstream string)
// Get log entries // WriteDiskConfig - write configuration
GetData(params GetDataParams) []map[string]interface{} WriteDiskConfig(dc *DiskConfig)
// Clear memory buffer and remove log files
Clear()
} }
// Config - configuration object // Config - configuration object
type Config struct { type Config struct {
Enabled bool
BaseDir string // directory where log file is stored BaseDir string // directory where log file is stored
Interval uint32 // interval to rotate logs (in hours) Interval uint32 // interval to rotate logs (in days)
// Called when the configuration is changed by HTTP request
ConfigModified func()
// Register an HTTP handler
HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request))
} }
// New - create a new instance of the query log // New - create a new instance of the query log
func New(conf Config) QueryLog { func New(conf Config) QueryLog {
return newQueryLog(conf) return newQueryLog(conf)
} }
// GetDataParams - parameters for GetData()
type GetDataParams struct {
OlderThan time.Time // return entries that are older than this value
Domain string // filter by domain name in question
Client string // filter by client IP
QuestionType uint16 // filter by question type
ResponseStatus ResponseStatusType // filter by response status
StrictMatchDomain bool // if Domain value must be matched strictly
StrictMatchClient bool // if Client value must be matched strictly
}
// ResponseStatusType - response status
type ResponseStatusType int32
// Response status constants
const (
ResponseStatusAll ResponseStatusType = iota + 1
ResponseStatusFiltered
)

View File

@ -7,17 +7,12 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"sync"
"time" "time"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/go-test/deep" "github.com/go-test/deep"
) )
var (
fileWriteLock sync.Mutex
)
const enableGzip = false const enableGzip = false
const maxEntrySize = 1000 const maxEntrySize = 1000
@ -27,16 +22,16 @@ func (l *queryLog) flushLogBuffer(fullFlush bool) error {
defer l.fileFlushLock.Unlock() defer l.fileFlushLock.Unlock()
// flush remainder to file // flush remainder to file
l.logBufferLock.Lock() l.bufferLock.Lock()
needFlush := len(l.logBuffer) >= logBufferCap needFlush := len(l.buffer) >= logBufferCap
if !needFlush && !fullFlush { if !needFlush && !fullFlush {
l.logBufferLock.Unlock() l.bufferLock.Unlock()
return nil return nil
} }
flushBuffer := l.logBuffer flushBuffer := l.buffer
l.logBuffer = nil l.buffer = nil
l.flushPending = false l.flushPending = false
l.logBufferLock.Unlock() l.bufferLock.Unlock()
err := l.flushToFile(flushBuffer) err := l.flushToFile(flushBuffer)
if err != nil { if err != nil {
log.Error("Saving querylog to file failed: %s", err) log.Error("Saving querylog to file failed: %s", err)
@ -98,8 +93,8 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
zb = b zb = b
} }
fileWriteLock.Lock() l.fileWriteLock.Lock()
defer fileWriteLock.Unlock() defer l.fileWriteLock.Unlock()
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil { if err != nil {
log.Error("failed to create file \"%s\": %s", filename, err) log.Error("failed to create file \"%s\": %s", filename, err)
@ -146,7 +141,7 @@ func checkBuffer(buffer []*logEntry, b bytes.Buffer) error {
return nil return nil
} }
func (l *queryLog) rotateQueryLog() error { func (l *queryLog) rotate() error {
from := l.logFile from := l.logFile
to := l.logFile + ".1" to := l.logFile + ".1"
@ -171,9 +166,9 @@ func (l *queryLog) rotateQueryLog() error {
return nil return nil
} }
func (l *queryLog) periodicQueryLogRotate() { func (l *queryLog) periodicRotate() {
for range time.Tick(time.Duration(l.conf.Interval) * time.Hour) { for range time.Tick(time.Duration(l.conf.Interval) * 24 * time.Hour) {
err := l.rotateQueryLog() err := l.rotate()
if err != nil { if err != nil {
log.Error("Failed to rotate querylog: %s", err) log.Error("Failed to rotate querylog: %s", err)
// do nothing, continue rotating // do nothing, continue rotating
@ -219,7 +214,7 @@ func (l *queryLog) OpenReader() *Reader {
r := Reader{} r := Reader{}
r.ql = l r.ql = l
r.now = time.Now() r.now = time.Now()
r.validFrom = r.now.Unix() - int64(l.conf.Interval*60*60) r.validFrom = r.now.Unix() - int64(l.conf.Interval*24*60*60)
r.validFrom *= 1000000000 r.validFrom *= 1000000000
r.files = []string{ r.files = []string{
r.ql.logFile, r.ql.logFile,

View File

@ -12,9 +12,10 @@ import (
func TestQueryLog(t *testing.T) { func TestQueryLog(t *testing.T) {
conf := Config{ conf := Config{
Enabled: true,
Interval: 1, Interval: 1,
} }
l := New(conf) l := newQueryLog(conf)
q := dns.Msg{} q := dns.Msg{}
q.Question = append(q.Question, dns.Question{ q.Question = append(q.Question, dns.Question{
@ -37,10 +38,10 @@ func TestQueryLog(t *testing.T) {
res := dnsfilter.Result{} res := dnsfilter.Result{}
l.Add(&q, &a, &res, 0, nil, "upstream") l.Add(&q, &a, &res, 0, nil, "upstream")
params := GetDataParams{ params := getDataParams{
OlderThan: time.Now(), OlderThan: time.Now(),
} }
d := l.GetData(params) d := l.getData(params)
m := d[0] m := d[0]
mq := m["question"].(map[string]interface{}) mq := m["question"].(map[string]interface{})
assert.True(t, mq["host"].(string) == "example.org") assert.True(t, mq["host"].(string) == "example.org")