diff --git a/changelog.config.js b/changelog.config.js index 95079477..427807cd 100644 --- a/changelog.config.js +++ b/changelog.config.js @@ -16,12 +16,14 @@ module.exports = { ], "scopes": [ "", + "ui", "global", "dnsfilter", "home", "dnsforward", "dhcpd", - "documentation" + "querylog", + "documentation", ], "types": { "+": { diff --git a/openapi/CHANGELOG.md b/openapi/CHANGELOG.md index 193b69fc..0638dddf 100644 --- a/openapi/CHANGELOG.md +++ b/openapi/CHANGELOG.md @@ -1,5 +1,12 @@ # AdGuard Home API Change Log +## v0.103: API changes + +### API: Get querylog: GET /control/querylog + +* Added optional "offset" and "limit" parameters + +We are still using "older_than" approach in AdGuard Home UI, but we realize that it's easier to use offset/limit so here is this option now. ## v0.102: API changes diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index a72bd816..304c8dc5 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -143,13 +143,26 @@ paths: tags: - log operationId: queryLog - summary: Get DNS server query log + summary: Get DNS server query log. parameters: - name: older_than in: query description: Filter by older than schema: type: string + - name: offset + in: query + description: + Specify the ranking number of the first item on the page. + Even though it is possible to use "offset" and "older_than", + we recommend choosing one of them and sticking to it. + schema: + type: integer + - name: limit + in: query + description: Limit the number of records to be returned + schema: + type: integer - name: filter_domain in: query description: Filter by domain name diff --git a/querylog/decode.go b/querylog/decode.go new file mode 100644 index 00000000..3f381140 --- /dev/null +++ b/querylog/decode.go @@ -0,0 +1,175 @@ +package querylog + +import ( + "encoding/base64" + "strconv" + "strings" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" +) + +// decodeLogEntry - decodes query log entry from a line +// nolint (gocyclo) +func decodeLogEntry(ent *logEntry, str string) { + var b bool + var i int + var err error + for { + k, v, t := readJSON(&str) + if t == jsonTErr { + break + } + switch k { + case "IP": + if len(ent.IP) == 0 { + ent.IP = v + } + case "T": + ent.Time, err = time.Parse(time.RFC3339, v) + + case "QH": + ent.QHost = v + case "QT": + ent.QType = v + case "QC": + ent.QClass = v + + case "Answer": + ent.Answer, err = base64.StdEncoding.DecodeString(v) + case "OrigAnswer": + ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v) + + case "IsFiltered": + b, err = strconv.ParseBool(v) + ent.Result.IsFiltered = b + case "Rule": + ent.Result.Rule = v + case "FilterID": + i, err = strconv.Atoi(v) + ent.Result.FilterID = int64(i) + case "Reason": + i, err = strconv.Atoi(v) + ent.Result.Reason = dnsfilter.Reason(i) + + case "Upstream": + ent.Upstream = v + case "Elapsed": + i, err = strconv.Atoi(v) + ent.Elapsed = time.Duration(i) + + // pre-v0.99.3 compatibility: + case "Question": + var qstr []byte + qstr, err = base64.StdEncoding.DecodeString(v) + if err != nil { + break + } + q := new(dns.Msg) + err = q.Unpack(qstr) + if err != nil { + break + } + ent.QHost = q.Question[0].Name + if len(ent.QHost) == 0 { + break + } + ent.QHost = ent.QHost[:len(ent.QHost)-1] + ent.QType = dns.TypeToString[q.Question[0].Qtype] + ent.QClass = dns.ClassToString[q.Question[0].Qclass] + case "Time": + ent.Time, err = time.Parse(time.RFC3339, v) + } + + if err != nil { + log.Debug("decodeLogEntry err: %s", err) + break + } + } +} + +// Get value from "key":"value" +func readJSONValue(s, name string) string { + i := strings.Index(s, "\""+name+"\":\"") + if i == -1 { + return "" + } + start := i + 1 + len(name) + 3 + i = strings.IndexByte(s[start:], '"') + if i == -1 { + return "" + } + end := start + i + return s[start:end] +} + +const ( + jsonTErr = iota + jsonTObj + jsonTStr + jsonTNum + jsonTBool +) + +// Parse JSON key-value pair +// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number) +// Note the limitations: +// . doesn't support whitespace +// . doesn't support "null" +// . doesn't validate boolean or number +// . no proper handling of {} braces +// . no handling of [] brackets +// Return (key, value, type) +func readJSON(ps *string) (string, string, int32) { + s := *ps + k := "" + v := "" + t := int32(jsonTErr) + + q1 := strings.IndexByte(s, '"') + if q1 == -1 { + return k, v, t + } + q2 := strings.IndexByte(s[q1+1:], '"') + if q2 == -1 { + return k, v, t + } + k = s[q1+1 : q1+1+q2] + s = s[q1+1+q2+1:] + + if len(s) < 2 || s[0] != ':' { + return k, v, t + } + + if s[1] == '"' { + q2 = strings.IndexByte(s[2:], '"') + if q2 == -1 { + return k, v, t + } + v = s[2 : 2+q2] + t = jsonTStr + s = s[2+q2+1:] + + } else if s[1] == '{' { + t = jsonTObj + s = s[1+1:] + + } else { + sep := strings.IndexAny(s[1:], ",}") + if sep == -1 { + return k, v, t + } + v = s[1 : 1+sep] + if s[1] == 't' || s[1] == 'f' { + t = jsonTBool + } else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') { + t = jsonTNum + } + s = s[1+sep+1:] + } + + *ps = s + return k, v, t +} diff --git a/querylog/decode_test.go b/querylog/decode_test.go new file mode 100644 index 00000000..b5b3b7d9 --- /dev/null +++ b/querylog/decode_test.go @@ -0,0 +1,34 @@ +package querylog + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestJSON(t *testing.T) { + s := ` + {"keystr":"val","obj":{"keybool":true,"keyint":123456}} + ` + k, v, jtype := readJSON(&s) + assert.Equal(t, jtype, int32(jsonTStr)) + assert.Equal(t, "keystr", k) + assert.Equal(t, "val", v) + + k, v, jtype = readJSON(&s) + assert.Equal(t, jtype, int32(jsonTObj)) + assert.Equal(t, "obj", k) + + k, v, jtype = readJSON(&s) + assert.Equal(t, jtype, int32(jsonTBool)) + assert.Equal(t, "keybool", k) + assert.Equal(t, "true", v) + + k, v, jtype = readJSON(&s) + assert.Equal(t, jtype, int32(jsonTNum)) + assert.Equal(t, "keyint", k) + assert.Equal(t, "123456", v) + + k, v, jtype = readJSON(&s) + assert.True(t, jtype == jsonTErr) +} diff --git a/querylog/json.go b/querylog/json.go new file mode 100644 index 00000000..86f35e11 --- /dev/null +++ b/querylog/json.go @@ -0,0 +1,164 @@ +package querylog + +import ( + "fmt" + "net" + "strconv" + "time" + + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" +) + +// Get Client IP address +func (l *queryLog) getClientIP(clientIP string) string { + if l.conf.AnonymizeClientIP { + ip := net.ParseIP(clientIP) + if ip != nil { + ip4 := ip.To4() + const AnonymizeClientIP4Mask = 24 + const AnonymizeClientIP6Mask = 112 + if ip4 != nil { + clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String() + } else { + clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String() + } + } + } + + return clientIP +} + +// entriesToJSON - converts log entries to JSON +func (l *queryLog) entriesToJSON(entries []*logEntry, oldest time.Time) map[string]interface{} { + // init the response object + var data = []map[string]interface{}{} + + // the elements order is already reversed (from newer to older) + for i := 0; i < len(entries); i++ { + entry := entries[i] + jsonEntry := l.logEntryToJSONEntry(entry) + data = append(data, jsonEntry) + } + + var result = map[string]interface{}{} + result["oldest"] = "" + if !oldest.IsZero() { + result["oldest"] = oldest.Format(time.RFC3339Nano) + } + result["data"] = data + + return result +} + +func (l *queryLog) logEntryToJSONEntry(entry *logEntry) map[string]interface{} { + var msg *dns.Msg + + if len(entry.Answer) > 0 { + msg = new(dns.Msg) + if err := msg.Unpack(entry.Answer); err != nil { + log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer)) + msg = nil + } + } + + jsonEntry := map[string]interface{}{ + "reason": entry.Result.Reason.String(), + "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), + "time": entry.Time.Format(time.RFC3339Nano), + "client": l.getClientIP(entry.IP), + } + jsonEntry["question"] = map[string]interface{}{ + "host": entry.QHost, + "type": entry.QType, + "class": entry.QClass, + } + + if msg != nil { + jsonEntry["status"] = dns.RcodeToString[msg.Rcode] + + opt := msg.IsEdns0() + dnssecOk := false + if opt != nil { + dnssecOk = opt.Do() + } + jsonEntry["answer_dnssec"] = dnssecOk + } + + if len(entry.Result.Rule) > 0 { + jsonEntry["rule"] = entry.Result.Rule + jsonEntry["filterId"] = entry.Result.FilterID + } + + if len(entry.Result.ServiceName) != 0 { + jsonEntry["service_name"] = entry.Result.ServiceName + } + + answers := answerToMap(msg) + if answers != nil { + jsonEntry["answer"] = answers + } + + if len(entry.OrigAnswer) != 0 { + a := new(dns.Msg) + err := a.Unpack(entry.OrigAnswer) + if err == nil { + answers = answerToMap(a) + if answers != nil { + jsonEntry["original_answer"] = answers + } + } else { + log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer)) + } + } + + return jsonEntry +} + +func answerToMap(a *dns.Msg) []map[string]interface{} { + if a == nil || len(a.Answer) == 0 { + return nil + } + + var answers = []map[string]interface{}{} + for _, k := range a.Answer { + header := k.Header() + answer := map[string]interface{}{ + "type": dns.TypeToString[header.Rrtype], + "ttl": header.Ttl, + } + // try most common record types + switch v := k.(type) { + case *dns.A: + answer["value"] = v.A.String() + case *dns.AAAA: + answer["value"] = v.AAAA.String() + case *dns.MX: + answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx) + case *dns.CNAME: + answer["value"] = v.Target + case *dns.NS: + answer["value"] = v.Ns + case *dns.SPF: + answer["value"] = v.Txt + case *dns.TXT: + answer["value"] = v.Txt + case *dns.PTR: + answer["value"] = v.Ptr + case *dns.SOA: + answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl) + case *dns.CAA: + answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value) + case *dns.HINFO: + answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os) + case *dns.RRSIG: + answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature) + default: + // type unknown, marshall it as-is + answer["value"] = v + } + answers = append(answers, answer) + } + + return answers +} diff --git a/querylog/qlog.go b/querylog/qlog.go index 247bb519..e85da3b9 100644 --- a/querylog/qlog.go +++ b/querylog/qlog.go @@ -1,11 +1,8 @@ package querylog import ( - "fmt" - "net" "os" "path/filepath" - "strconv" "strings" "sync" "time" @@ -17,10 +14,6 @@ import ( const ( queryLogFileName = "querylog.json" // .gz added during compression - getDataLimit = 500 // GetData(): maximum log entries to return - - // maximum entries to parse when searching - maxSearchEntries = 50000 ) // queryLog is a structure that writes and reads the DNS query log @@ -36,6 +29,23 @@ type queryLog struct { fileWriteLock sync.Mutex } +// logEntry - represents a single log entry +type logEntry struct { + IP string `json:"IP"` // Client IP + Time time.Time `json:"T"` + + QHost string `json:"QH"` + QType string `json:"QT"` + QClass string `json:"QC"` + + Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net + OrigAnswer []byte `json:",omitempty"` + + Result dnsfilter.Result + Elapsed time.Duration + Upstream string `json:",omitempty"` // if empty, means it was cached +} + // create a new instance of the query log func newQueryLog(conf Config) *queryLog { l := queryLog{} @@ -93,22 +103,6 @@ func (l *queryLog) clear() { log.Debug("Query log: cleared") } -type logEntry struct { - IP string `json:"IP"` - Time time.Time `json:"T"` - - QHost string `json:"QH"` - QType string `json:"QT"` - QClass string `json:"QC"` - - Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net - OrigAnswer []byte `json:",omitempty"` - - Result dnsfilter.Result - Elapsed time.Duration - Upstream string `json:",omitempty"` // if empty, means it was cached -} - func (l *queryLog) Add(params AddParams) { if !l.conf.Enabled { return @@ -173,230 +167,3 @@ func (l *queryLog) Add(params AddParams) { go l.flushLogBuffer(false) // nolint } } - -// Parameters for getData() -type getDataParams struct { - OlderThan time.Time // return entries that are older than this value - Domain string // filter by domain name in question - Client string // filter by client IP - QuestionType string // filter by question type - ResponseStatus responseStatusType // filter by response status - StrictMatchDomain bool // if Domain value must be matched strictly - StrictMatchClient bool // if Client value must be matched strictly -} - -// Response status -type responseStatusType int32 - -// Response status constants -const ( - responseStatusAll responseStatusType = iota + 1 - responseStatusFiltered -) - -// Gets log entries -func (l *queryLog) getData(params getDataParams) map[string]interface{} { - now := time.Now() - - if len(params.Client) != 0 && l.conf.AnonymizeClientIP { - params.Client = l.getClientIP(params.Client) - } - - // add from file - fileEntries, oldest, total := l.searchFiles(params) - - if params.OlderThan.IsZero() { - // In case if the timer is not precise (for instance, on Windows) - // We really want to get all records including those added just before the call - params.OlderThan = now.Add(time.Millisecond) - } - - // add from memory buffer - l.bufferLock.Lock() - total += len(l.buffer) - memoryEntries := make([]*logEntry, 0) - - // go through the buffer in the reverse order - // from NEWER to OLDER - for i := len(l.buffer) - 1; i >= 0; i-- { - entry := l.buffer[i] - - if entry.Time.UnixNano() >= params.OlderThan.UnixNano() { - // Ignore entries newer than what was requested - continue - } - - if !matchesGetDataParams(entry, params) { - continue - } - - memoryEntries = append(memoryEntries, entry) - } - l.bufferLock.Unlock() - - // now let's get a unified collection - entries := append(memoryEntries, fileEntries...) - if len(entries) > getDataLimit { - // remove extra records - entries = entries[:getDataLimit] - } - if len(entries) == getDataLimit { - // change the "oldest" value here. - // we cannot use the "oldest" we got from "searchFiles" anymore - // because after adding in-memory records and removing extra records - // the situation has changed - oldest = entries[len(entries)-1].Time - } - - // init the response object - var data = []map[string]interface{}{} - - // the elements order is already reversed (from newer to older) - for i := 0; i < len(entries); i++ { - entry := entries[i] - jsonEntry := l.logEntryToJSONEntry(entry) - data = append(data, jsonEntry) - } - - log.Debug("QueryLog: prepared data (%d/%d) older than %s in %s", - len(entries), total, params.OlderThan, time.Since(now)) - - var result = map[string]interface{}{} - result["oldest"] = "" - if !oldest.IsZero() { - result["oldest"] = oldest.Format(time.RFC3339Nano) - } - result["data"] = data - return result -} - -// Get Client IP address -func (l *queryLog) getClientIP(clientIP string) string { - if l.conf.AnonymizeClientIP { - ip := net.ParseIP(clientIP) - if ip != nil { - ip4 := ip.To4() - const AnonymizeClientIP4Mask = 24 - const AnonymizeClientIP6Mask = 112 - if ip4 != nil { - clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String() - } else { - clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String() - } - } - } - - return clientIP -} - -func (l *queryLog) logEntryToJSONEntry(entry *logEntry) map[string]interface{} { - var msg *dns.Msg - - if len(entry.Answer) > 0 { - msg = new(dns.Msg) - if err := msg.Unpack(entry.Answer); err != nil { - log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer)) - msg = nil - } - } - - jsonEntry := map[string]interface{}{ - "reason": entry.Result.Reason.String(), - "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), - "time": entry.Time.Format(time.RFC3339Nano), - "client": l.getClientIP(entry.IP), - } - jsonEntry["question"] = map[string]interface{}{ - "host": entry.QHost, - "type": entry.QType, - "class": entry.QClass, - } - - if msg != nil { - jsonEntry["status"] = dns.RcodeToString[msg.Rcode] - - opt := msg.IsEdns0() - dnssecOk := false - if opt != nil { - dnssecOk = opt.Do() - } - jsonEntry["answer_dnssec"] = dnssecOk - } - - if len(entry.Result.Rule) > 0 { - jsonEntry["rule"] = entry.Result.Rule - jsonEntry["filterId"] = entry.Result.FilterID - } - - if len(entry.Result.ServiceName) != 0 { - jsonEntry["service_name"] = entry.Result.ServiceName - } - - answers := answerToMap(msg) - if answers != nil { - jsonEntry["answer"] = answers - } - - if len(entry.OrigAnswer) != 0 { - a := new(dns.Msg) - err := a.Unpack(entry.OrigAnswer) - if err == nil { - answers = answerToMap(a) - if answers != nil { - jsonEntry["original_answer"] = answers - } - } else { - log.Debug("Querylog: msg.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer)) - } - } - - return jsonEntry -} - -func answerToMap(a *dns.Msg) []map[string]interface{} { - if a == nil || len(a.Answer) == 0 { - return nil - } - - var answers = []map[string]interface{}{} - for _, k := range a.Answer { - header := k.Header() - answer := map[string]interface{}{ - "type": dns.TypeToString[header.Rrtype], - "ttl": header.Ttl, - } - // try most common record types - switch v := k.(type) { - case *dns.A: - answer["value"] = v.A.String() - case *dns.AAAA: - answer["value"] = v.AAAA.String() - case *dns.MX: - answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx) - case *dns.CNAME: - answer["value"] = v.Target - case *dns.NS: - answer["value"] = v.Ns - case *dns.SPF: - answer["value"] = v.Txt - case *dns.TXT: - answer["value"] = v.Txt - case *dns.PTR: - answer["value"] = v.Ptr - case *dns.SOA: - answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl) - case *dns.CAA: - answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value) - case *dns.HINFO: - answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os) - case *dns.RRSIG: - answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature) - default: - // type unknown, marshall it as-is - answer["value"] = v - } - answers = append(answers, answer) - } - - return answers -} diff --git a/querylog/qlog_http.go b/querylog/qlog_http.go index fae8dba6..19caa35c 100644 --- a/querylog/qlog_http.go +++ b/querylog/qlog_http.go @@ -4,13 +4,30 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" + "strconv" "time" + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/jsonutil" "github.com/AdguardTeam/golibs/log" - "github.com/miekg/dns" ) +type qlogConfig struct { + Enabled bool `json:"enabled"` + Interval uint32 `json:"interval"` + AnonymizeClientIP bool `json:"anonymize_client_ip"` +} + +// Register web handlers +func (l *queryLog) initWeb() { + l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog) + l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo) + l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear) + l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig) +} + func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) { text := fmt.Sprintf(format, args...) @@ -19,74 +36,18 @@ func httpError(r *http.Request, w http.ResponseWriter, code int, format string, http.Error(w, text, code) } -type request struct { - olderThan string - filterDomain string - filterClient string - filterQuestionType string - filterResponseStatus string -} - -// "value" -> value, return TRUE -func getDoubleQuotesEnclosedValue(s *string) bool { - t := *s - if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' { - *s = t[1 : len(t)-1] - return true - } - return false -} - func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { - var err error - req := request{} - q := r.URL.Query() - req.olderThan = q.Get("older_than") - req.filterDomain = q.Get("filter_domain") - req.filterClient = q.Get("filter_client") - req.filterQuestionType = q.Get("filter_question_type") - req.filterResponseStatus = q.Get("filter_response_status") - - params := getDataParams{ - Domain: req.filterDomain, - Client: req.filterClient, - ResponseStatus: responseStatusAll, - } - if len(req.olderThan) != 0 { - params.OlderThan, err = time.Parse(time.RFC3339Nano, req.olderThan) - if err != nil { - httpError(r, w, http.StatusBadRequest, "invalid time stamp: %s", err) - return - } + params, err := l.parseSearchParams(r) + if err != nil { + httpError(r, w, http.StatusBadRequest, "failed to parse params: %s", err) + return } - if getDoubleQuotesEnclosedValue(¶ms.Domain) { - params.StrictMatchDomain = true - } - if getDoubleQuotesEnclosedValue(¶ms.Client) { - params.StrictMatchClient = true - } + // search for the log entries + entries, oldest := l.search(params) - if len(req.filterQuestionType) != 0 { - _, ok := dns.StringToType[req.filterQuestionType] - if !ok { - httpError(r, w, http.StatusBadRequest, "invalid question_type") - return - } - params.QuestionType = req.filterQuestionType - } - - if len(req.filterResponseStatus) != 0 { - switch req.filterResponseStatus { - case "filtered": - params.ResponseStatus = responseStatusFiltered - default: - httpError(r, w, http.StatusBadRequest, "invalid response_status") - return - } - } - - data := l.getData(params) + // convert log entries to JSON + var data = l.entriesToJSON(entries, oldest) jsonVal, err := json.Marshal(data) if err != nil { @@ -101,16 +62,10 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { } } -func (l *queryLog) handleQueryLogClear(w http.ResponseWriter, r *http.Request) { +func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) { l.clear() } -type qlogConfig struct { - Enabled bool `json:"enabled"` - Interval uint32 `json:"interval"` - AnonymizeClientIP bool `json:"anonymize_client_ip"` -} - // Get configuration func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) { resp := qlogConfig{} @@ -162,10 +117,85 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request) l.conf.ConfigModified() } -// Register web handlers -func (l *queryLog) initWeb() { - l.conf.HTTPRegister("GET", "/control/querylog", l.handleQueryLog) - l.conf.HTTPRegister("GET", "/control/querylog_info", l.handleQueryLogInfo) - l.conf.HTTPRegister("POST", "/control/querylog_clear", l.handleQueryLogClear) - l.conf.HTTPRegister("POST", "/control/querylog_config", l.handleQueryLogConfig) +// "value" -> value, return TRUE +func getDoubleQuotesEnclosedValue(s *string) bool { + t := *s + if len(t) >= 2 && t[0] == '"' && t[len(t)-1] == '"' { + *s = t[1 : len(t)-1] + return true + } + return false +} + +// parseSearchCriteria - parses "searchCriteria" from the specified query parameter +func (l *queryLog) parseSearchCriteria(q url.Values, name string, ct criteriaType) (bool, searchCriteria, error) { + val := q.Get(name) + if len(val) == 0 { + return false, searchCriteria{}, nil + } + + c := searchCriteria{ + criteriaType: ct, + value: val, + } + if getDoubleQuotesEnclosedValue(&c.value) { + c.strict = true + } + + if ct == ctClient && l.conf.AnonymizeClientIP { + c.value = l.getClientIP(c.value) + } + + if ct == ctFilteringStatus && !util.ContainsString(filteringStatusValues, c.value) { + return false, c, fmt.Errorf("invalid value %s", c.value) + } + + return true, c, nil +} + +// parseSearchParams - parses "searchParams" from the HTTP request's query string +func (l *queryLog) parseSearchParams(r *http.Request) (*searchParams, error) { + p := newSearchParams() + + var err error + q := r.URL.Query() + olderThan := q.Get("older_than") + if len(olderThan) != 0 { + p.olderThan, err = time.Parse(time.RFC3339Nano, olderThan) + if err != nil { + return nil, err + } + } + + if limit, err := strconv.ParseInt(q.Get("limit"), 10, 64); err == nil { + p.limit = int(limit) + + // If limit or offset are specified explicitly, we should change the default behavior + // and scan all log records until we found enough log entries + p.maxFileScanEntries = 0 + } + if offset, err := strconv.ParseInt(q.Get("offset"), 10, 64); err == nil { + p.offset = int(offset) + p.maxFileScanEntries = 0 + } + + paramNames := map[string]criteriaType{ + "filter_domain": ctDomain, + "filter_client": ctClient, + "filter_question_type": ctQuestionType, + "filter_response_status": ctFilteringStatus, + } + + for k, v := range paramNames { + ok, c, err := l.parseSearchCriteria(q, k, v) + if err != nil { + return nil, err + } + + if ok { + p.searchCriteria = append(p.searchCriteria, c) + } + } + + return p, nil } diff --git a/querylog/qlog_test.go b/querylog/qlog_test.go new file mode 100644 index 00000000..a5392fc0 --- /dev/null +++ b/querylog/qlog_test.go @@ -0,0 +1,228 @@ +package querylog + +import ( + "net" + "os" + "testing" + + "github.com/AdguardTeam/dnsproxy/proxyutil" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +func prepareTestDir() string { + const dir = "./agh-test" + _ = os.RemoveAll(dir) + _ = os.MkdirAll(dir, 0755) + return dir +} + +// Check adding and loading (with filtering) entries from disk and memory +func TestQueryLog(t *testing.T) { + conf := Config{ + Enabled: true, + Interval: 1, + MemSize: 100, + } + conf.BaseDir = prepareTestDir() + defer func() { _ = os.RemoveAll(conf.BaseDir) }() + l := newQueryLog(conf) + + // add disk entries + addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") + // write to disk (first file) + _ = l.flushLogBuffer(true) + // start writing to the second file + _ = l.rotate() + // add disk entries + addEntry(l, "example.org", "1.1.1.2", "2.2.2.2") + // write to disk + _ = l.flushLogBuffer(true) + // add memory entries + addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3") + addEntry(l, "example.com", "1.1.1.4", "2.2.2.4") + + // get all entries + params := newSearchParams() + entries, _ := l.search(params) + assert.Equal(t, 4, len(entries)) + assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") + assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") + assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") + assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1") + + // search by domain (strict) + params = newSearchParams() + params.searchCriteria = append(params.searchCriteria, searchCriteria{ + criteriaType: ctDomain, + strict: true, + value: "test.example.org", + }) + entries, _ = l.search(params) + assert.Equal(t, 1, len(entries)) + assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") + + // search by domain (not strict) + params = newSearchParams() + params.searchCriteria = append(params.searchCriteria, searchCriteria{ + criteriaType: ctDomain, + strict: false, + value: "example.org", + }) + entries, _ = l.search(params) + assert.Equal(t, 3, len(entries)) + assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") + assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2") + assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1") + + // search by client IP (strict) + params = newSearchParams() + params.searchCriteria = append(params.searchCriteria, searchCriteria{ + criteriaType: ctClient, + strict: true, + value: "2.2.2.2", + }) + entries, _ = l.search(params) + assert.Equal(t, 1, len(entries)) + assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2") + + // search by client IP (part of) + params = newSearchParams() + params.searchCriteria = append(params.searchCriteria, searchCriteria{ + criteriaType: ctClient, + strict: false, + value: "2.2.2", + }) + entries, _ = l.search(params) + assert.Equal(t, 4, len(entries)) + assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") + assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") + assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") + assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1") +} + +func TestQueryLogOffsetLimit(t *testing.T) { + conf := Config{ + Enabled: true, + Interval: 1, + MemSize: 100, + } + conf.BaseDir = prepareTestDir() + defer func() { _ = os.RemoveAll(conf.BaseDir) }() + l := newQueryLog(conf) + + // add 10 entries to the log + for i := 0; i < 10; i++ { + addEntry(l, "second.example.org", "1.1.1.1", "2.2.2.1") + } + // write them to disk (first file) + _ = l.flushLogBuffer(true) + // add 10 more entries to the log (memory) + for i := 0; i < 10; i++ { + addEntry(l, "first.example.org", "1.1.1.1", "2.2.2.1") + } + + // First page + params := newSearchParams() + params.offset = 0 + params.limit = 10 + entries, _ := l.search(params) + assert.Equal(t, 10, len(entries)) + assert.Equal(t, entries[0].QHost, "first.example.org") + assert.Equal(t, entries[9].QHost, "first.example.org") + + // Second page + params.offset = 10 + params.limit = 10 + entries, _ = l.search(params) + assert.Equal(t, 10, len(entries)) + assert.Equal(t, entries[0].QHost, "second.example.org") + assert.Equal(t, entries[9].QHost, "second.example.org") + + // Second and a half page + params.offset = 15 + params.limit = 10 + entries, _ = l.search(params) + assert.Equal(t, 5, len(entries)) + assert.Equal(t, entries[0].QHost, "second.example.org") + assert.Equal(t, entries[4].QHost, "second.example.org") + + // Third page + params.offset = 20 + params.limit = 10 + entries, _ = l.search(params) + assert.Equal(t, 0, len(entries)) +} + +func TestQueryLogMaxFileScanEntries(t *testing.T) { + conf := Config{ + Enabled: true, + Interval: 1, + MemSize: 100, + } + conf.BaseDir = prepareTestDir() + defer func() { _ = os.RemoveAll(conf.BaseDir) }() + l := newQueryLog(conf) + + // add 10 entries to the log + for i := 0; i < 10; i++ { + addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") + } + // write them to disk (first file) + _ = l.flushLogBuffer(true) + + params := newSearchParams() + params.maxFileScanEntries = 5 // do not scan more than 5 records + entries, _ := l.search(params) + assert.Equal(t, 5, len(entries)) + + params.maxFileScanEntries = 0 // disable the limit + entries, _ = l.search(params) + assert.Equal(t, 10, len(entries)) +} + +func addEntry(l *queryLog, host, answerStr, client string) { + q := dns.Msg{} + q.Question = append(q.Question, dns.Question{ + Name: host + ".", + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + }) + + a := dns.Msg{} + a.Question = append(a.Question, q.Question[0]) + answer := new(dns.A) + answer.Hdr = dns.RR_Header{ + Name: q.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + } + answer.A = net.ParseIP(answerStr) + a.Answer = append(a.Answer, answer) + res := dnsfilter.Result{} + params := AddParams{ + Question: &q, + Answer: &a, + Result: &res, + ClientIP: net.ParseIP(client), + Upstream: "upstream", + } + l.Add(params) +} + +func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) bool { + assert.Equal(t, host, entry.QHost) + assert.Equal(t, client, entry.IP) + assert.Equal(t, "A", entry.QType) + assert.Equal(t, "IN", entry.QClass) + + msg := new(dns.Msg) + assert.Nil(t, msg.Unpack(entry.Answer)) + assert.Equal(t, 1, len(msg.Answer)) + ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]) + assert.NotNil(t, ip) + assert.Equal(t, answer, ip.String()) + return true +} diff --git a/querylog/querylog_search.go b/querylog/querylog_search.go index f9493af9..eda1f92d 100644 --- a/querylog/querylog_search.go +++ b/querylog/querylog_search.go @@ -1,18 +1,72 @@ package querylog import ( - "encoding/base64" "io" - "strconv" - "strings" "time" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/log" - "github.com/miekg/dns" ) +// search - searches log entries in the query log using specified parameters +// returns the list of entries found + time of the oldest entry +func (l *queryLog) search(params *searchParams) ([]*logEntry, time.Time) { + now := time.Now() + + if params.limit == 0 { + return []*logEntry{}, time.Time{} + } + + // add from file + fileEntries, oldest, total := l.searchFiles(params) + + // add from memory buffer + l.bufferLock.Lock() + total += len(l.buffer) + memoryEntries := make([]*logEntry, 0) + + // go through the buffer in the reverse order + // from NEWER to OLDER + for i := len(l.buffer) - 1; i >= 0; i-- { + entry := l.buffer[i] + if !params.match(entry) { + continue + } + memoryEntries = append(memoryEntries, entry) + } + l.bufferLock.Unlock() + + // limits + totalLimit := params.offset + params.limit + + // now let's get a unified collection + entries := append(memoryEntries, fileEntries...) + if len(entries) > totalLimit { + // remove extra records + entries = entries[:totalLimit] + } + if params.offset > 0 { + if len(entries) > params.offset { + entries = entries[params.offset:] + } else { + entries = make([]*logEntry, 0) + oldest = time.Time{} + } + } + if len(entries) == totalLimit { + // change the "oldest" value here. + // we cannot use the "oldest" we got from "searchFiles" anymore + // because after adding in-memory records and removing extra records + // the situation has changed + oldest = entries[len(entries)-1].Time + } + + log.Debug("QueryLog: prepared data (%d/%d) older than %s in %s", + len(entries), total, params.olderThan, time.Since(now)) + + return entries, oldest +} + // searchFiles reads log entries from all log files and applies the specified search criteria. // IMPORTANT: this method does not scan more than "maxSearchEntries" so you // may need to call it many times. @@ -21,7 +75,7 @@ import ( // * an array of log entries that we have read // * time of the oldest processed entry (even if it was discarded) // * total number of processed entries (including discarded). -func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, int) { +func (l *queryLog) searchFiles(params *searchParams) ([]*logEntry, time.Time, int) { entries := make([]*logEntry, 0) oldest := time.Time{} @@ -32,10 +86,10 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in } defer r.Close() - if params.OlderThan.IsZero() { + if params.olderThan.IsZero() { err = r.SeekStart() } else { - err = r.Seek(params.OlderThan.UnixNano()) + err = r.Seek(params.olderThan.UnixNano()) if err == nil { // Read to the next record right away // The one that was specified in the "oldest" param is not needed, @@ -45,14 +99,17 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in } if err != nil { - log.Debug("Cannot Seek() to %v: %v", params.OlderThan, err) + log.Debug("Cannot Seek() to %v: %v", params.olderThan, err) return entries, oldest, 0 } + totalLimit := params.offset + params.limit total := 0 oldestNano := int64(0) - // Do not scan more than 50k at once - for total <= maxSearchEntries { + // By default, we do not scan more than "maxFileScanEntries" at once + // The idea is to make search calls faster so that the UI could handle it and show something + // This behavior can be overridden if "maxFileScanEntries" is set to 0 + for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 { entry, ts, err := l.readNextEntry(r, params) if err == io.EOF { @@ -65,8 +122,8 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in if entry != nil { entries = append(entries, entry) - if len(entries) == getDataLimit { - // Do not read more than "getDataLimit" records at once + if len(entries) == totalLimit { + // Do not read more than "totalLimit" records at once break } } @@ -82,7 +139,7 @@ func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, in // * log entry that matches search criteria or null if it was discarded (or if there's nothing to read) // * timestamp of the processed log entry // * error if we can't read anymore -func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry, int64, error) { +func (l *queryLog) readNextEntry(r *QLogReader, params *searchParams) (*logEntry, int64, error) { line, err := r.ReadNext() if err != nil { return nil, 0, err @@ -92,7 +149,7 @@ func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry timestamp := readQLogTimestamp(line) // Quick check without deserializing log entry - if !quickMatchesGetDataParams(line, params) { + if !params.quickMatch(line) { return nil, timestamp, nil } @@ -100,7 +157,7 @@ func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry decodeLogEntry(&entry, line) // Full check of the deserialized log entry - if !matchesGetDataParams(&entry, params) { + if !params.match(&entry) { return nil, timestamp, nil } @@ -120,257 +177,3 @@ func (l *queryLog) openReader() (*QLogReader, error) { return NewQLogReader(files) } - -// quickMatchesGetDataParams - quickly checks if the line matches getDataParams -// this method does not guarantee anything and the reason is to do a quick check -// without deserializing anything -func quickMatchesGetDataParams(line string, params getDataParams) bool { - if params.ResponseStatus == responseStatusFiltered { - boolVal, ok := readJSONBool(line, "IsFiltered") - if !ok || !boolVal { - return false - } - } - - if len(params.Domain) != 0 { - val := readJSONValue(line, "QH") - if len(val) == 0 { - return false - } - - if (params.StrictMatchDomain && val != params.Domain) || - (!params.StrictMatchDomain && strings.Index(val, params.Domain) == -1) { - return false - } - } - - if len(params.QuestionType) != 0 { - val := readJSONValue(line, "QT") - if val != params.QuestionType { - return false - } - } - - if len(params.Client) != 0 { - val := readJSONValue(line, "IP") - if len(val) == 0 { - log.Debug("QueryLog: failed to decodeLogEntry") - return false - } - - if (params.StrictMatchClient && val != params.Client) || - (!params.StrictMatchClient && strings.Index(val, params.Client) == -1) { - return false - } - } - - return true -} - -// matchesGetDataParams - returns true if the entry matches the search parameters -func matchesGetDataParams(entry *logEntry, params getDataParams) bool { - if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered { - return false - } - - if len(params.QuestionType) != 0 { - if entry.QType != params.QuestionType { - return false - } - } - - if len(params.Domain) != 0 { - if (params.StrictMatchDomain && entry.QHost != params.Domain) || - (!params.StrictMatchDomain && strings.Index(entry.QHost, params.Domain) == -1) { - return false - } - } - - if len(params.Client) != 0 { - if (params.StrictMatchClient && entry.IP != params.Client) || - (!params.StrictMatchClient && strings.Index(entry.IP, params.Client) == -1) { - return false - } - } - - return true -} - -// decodeLogEntry - decodes query log entry from a line -// nolint (gocyclo) -func decodeLogEntry(ent *logEntry, str string) { - var b bool - var i int - var err error - for { - k, v, t := readJSON(&str) - if t == jsonTErr { - break - } - switch k { - case "IP": - if len(ent.IP) == 0 { - ent.IP = v - } - case "T": - ent.Time, err = time.Parse(time.RFC3339, v) - - case "QH": - ent.QHost = v - case "QT": - ent.QType = v - case "QC": - ent.QClass = v - - case "Answer": - ent.Answer, err = base64.StdEncoding.DecodeString(v) - case "OrigAnswer": - ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v) - - case "IsFiltered": - b, err = strconv.ParseBool(v) - ent.Result.IsFiltered = b - case "Rule": - ent.Result.Rule = v - case "FilterID": - i, err = strconv.Atoi(v) - ent.Result.FilterID = int64(i) - case "Reason": - i, err = strconv.Atoi(v) - ent.Result.Reason = dnsfilter.Reason(i) - - case "Upstream": - ent.Upstream = v - case "Elapsed": - i, err = strconv.Atoi(v) - ent.Elapsed = time.Duration(i) - - // pre-v0.99.3 compatibility: - case "Question": - var qstr []byte - qstr, err = base64.StdEncoding.DecodeString(v) - if err != nil { - break - } - q := new(dns.Msg) - err = q.Unpack(qstr) - if err != nil { - break - } - ent.QHost = q.Question[0].Name - if len(ent.QHost) == 0 { - break - } - ent.QHost = ent.QHost[:len(ent.QHost)-1] - ent.QType = dns.TypeToString[q.Question[0].Qtype] - ent.QClass = dns.ClassToString[q.Question[0].Qclass] - case "Time": - ent.Time, err = time.Parse(time.RFC3339, v) - } - - if err != nil { - log.Debug("decodeLogEntry err: %s", err) - break - } - } -} - -// Get bool value from "key":bool -func readJSONBool(s, name string) (bool, bool) { - i := strings.Index(s, "\""+name+"\":") - if i == -1 { - return false, false - } - start := i + 1 + len(name) + 2 - b := false - if strings.HasPrefix(s[start:], "true") { - b = true - } else if !strings.HasPrefix(s[start:], "false") { - return false, false - } - return b, true -} - -// Get value from "key":"value" -func readJSONValue(s, name string) string { - i := strings.Index(s, "\""+name+"\":\"") - if i == -1 { - return "" - } - start := i + 1 + len(name) + 3 - i = strings.IndexByte(s[start:], '"') - if i == -1 { - return "" - } - end := start + i - return s[start:end] -} - -const ( - jsonTErr = iota - jsonTObj - jsonTStr - jsonTNum - jsonTBool -) - -// Parse JSON key-value pair -// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number) -// Note the limitations: -// . doesn't support whitespace -// . doesn't support "null" -// . doesn't validate boolean or number -// . no proper handling of {} braces -// . no handling of [] brackets -// Return (key, value, type) -func readJSON(ps *string) (string, string, int32) { - s := *ps - k := "" - v := "" - t := int32(jsonTErr) - - q1 := strings.IndexByte(s, '"') - if q1 == -1 { - return k, v, t - } - q2 := strings.IndexByte(s[q1+1:], '"') - if q2 == -1 { - return k, v, t - } - k = s[q1+1 : q1+1+q2] - s = s[q1+1+q2+1:] - - if len(s) < 2 || s[0] != ':' { - return k, v, t - } - - if s[1] == '"' { - q2 = strings.IndexByte(s[2:], '"') - if q2 == -1 { - return k, v, t - } - v = s[2 : 2+q2] - t = jsonTStr - s = s[2+q2+1:] - - } else if s[1] == '{' { - t = jsonTObj - s = s[1+1:] - - } else { - sep := strings.IndexAny(s[1:], ",}") - if sep == -1 { - return k, v, t - } - v = s[1 : 1+sep] - if s[1] == 't' || s[1] == 'f' { - t = jsonTBool - } else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') { - t = jsonTNum - } - s = s[1+sep+1:] - } - - *ps = s - return k, v, t -} diff --git a/querylog/querylog_test.go b/querylog/querylog_test.go deleted file mode 100644 index 06de4101..00000000 --- a/querylog/querylog_test.go +++ /dev/null @@ -1,176 +0,0 @@ -package querylog - -import ( - "net" - "os" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" -) - -func prepareTestDir() string { - const dir = "./agh-test" - _ = os.RemoveAll(dir) - _ = os.MkdirAll(dir, 0755) - return dir -} - -// Check adding and loading (with filtering) entries from disk and memory -func TestQueryLog(t *testing.T) { - conf := Config{ - Enabled: true, - Interval: 1, - MemSize: 100, - } - conf.BaseDir = prepareTestDir() - defer func() { _ = os.RemoveAll(conf.BaseDir) }() - l := newQueryLog(conf) - - // add disk entries - addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") - // write to disk (first file) - _ = l.flushLogBuffer(true) - // start writing to the second file - _ = l.rotate() - // add disk entries - addEntry(l, "example.org", "1.1.1.2", "2.2.2.2") - // write to disk - _ = l.flushLogBuffer(true) - // add memory entries - addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3") - addEntry(l, "example.com", "1.1.1.4", "2.2.2.4") - - // get all entries - params := getDataParams{ - OlderThan: time.Time{}, - } - d := l.getData(params) - mdata := d["data"].([]map[string]interface{}) - assert.Equal(t, 4, len(mdata)) - assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4")) - assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3")) - assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2")) - assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1")) - - // search by domain (strict) - params = getDataParams{ - OlderThan: time.Time{}, - Domain: "test.example.org", - StrictMatchDomain: true, - } - d = l.getData(params) - mdata = d["data"].([]map[string]interface{}) - assert.True(t, len(mdata) == 1) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) - - // search by domain (not strict) - params = getDataParams{ - OlderThan: time.Time{}, - Domain: "example.org", - StrictMatchDomain: false, - } - d = l.getData(params) - mdata = d["data"].([]map[string]interface{}) - assert.Equal(t, 3, len(mdata)) - assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.1.1.2", "2.2.2.2")) - assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.1", "2.2.2.1")) - - // search by client IP (strict) - params = getDataParams{ - OlderThan: time.Time{}, - Client: "2.2.2.2", - StrictMatchClient: true, - } - d = l.getData(params) - mdata = d["data"].([]map[string]interface{}) - assert.Equal(t, 1, len(mdata)) - assert.True(t, checkEntry(t, mdata[0], "example.org", "1.1.1.2", "2.2.2.2")) - - // search by client IP (part of) - params = getDataParams{ - OlderThan: time.Time{}, - Client: "2.2.2", - StrictMatchClient: false, - } - d = l.getData(params) - mdata = d["data"].([]map[string]interface{}) - assert.Equal(t, 4, len(mdata)) - assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4")) - assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3")) - assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2")) - assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1")) -} - -func addEntry(l *queryLog, host, answerStr, client string) { - q := dns.Msg{} - q.Question = append(q.Question, dns.Question{ - Name: host + ".", - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - }) - - a := dns.Msg{} - a.Question = append(a.Question, q.Question[0]) - answer := new(dns.A) - answer.Hdr = dns.RR_Header{ - Name: q.Question[0].Name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - } - answer.A = net.ParseIP(answerStr) - a.Answer = append(a.Answer, answer) - res := dnsfilter.Result{} - params := AddParams{ - Question: &q, - Answer: &a, - Result: &res, - ClientIP: net.ParseIP(client), - Upstream: "upstream", - } - l.Add(params) -} - -func checkEntry(t *testing.T, m map[string]interface{}, host, answer, client string) bool { - mq := m["question"].(map[string]interface{}) - ma := m["answer"].([]map[string]interface{}) - ma0 := ma[0] - if !assert.Equal(t, host, mq["host"].(string)) || - !assert.Equal(t, "IN", mq["class"].(string)) || - !assert.Equal(t, "A", mq["type"].(string)) || - !assert.Equal(t, answer, ma0["value"].(string)) || - !assert.Equal(t, client, m["client"].(string)) { - return false - } - return true -} - -func TestJSON(t *testing.T) { - s := ` - {"keystr":"val","obj":{"keybool":true,"keyint":123456}} - ` - k, v, jtype := readJSON(&s) - assert.Equal(t, jtype, int32(jsonTStr)) - assert.Equal(t, "keystr", k) - assert.Equal(t, "val", v) - - k, v, jtype = readJSON(&s) - assert.Equal(t, jtype, int32(jsonTObj)) - assert.Equal(t, "obj", k) - - k, v, jtype = readJSON(&s) - assert.Equal(t, jtype, int32(jsonTBool)) - assert.Equal(t, "keybool", k) - assert.Equal(t, "true", v) - - k, v, jtype = readJSON(&s) - assert.Equal(t, jtype, int32(jsonTNum)) - assert.Equal(t, "keyint", k) - assert.Equal(t, "123456", v) - - k, v, jtype = readJSON(&s) - assert.True(t, jtype == jsonTErr) -} diff --git a/querylog/search_criteria.go b/querylog/search_criteria.go new file mode 100644 index 00000000..b2ba63f6 --- /dev/null +++ b/querylog/search_criteria.go @@ -0,0 +1,139 @@ +package querylog + +import ( + "strings" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" +) + +type criteriaType int + +const ( + ctDomain criteriaType = iota // domain name + ctClient // client IP address + ctQuestionType // question type + ctFilteringStatus // filtering status +) + +const ( + filteringStatusAll = "all" + filteringStatusFiltered = "filtered" // all kinds of filtering + + filteringStatusBlocked = "blocked" // blocked or blocked service + filteringStatusBlockedSafebrowsing = "blocked_safebrowsing" // blocked by safebrowsing + filteringStatusBlockedParental = "blocked_parental" // blocked by parental control + filteringStatusWhitelisted = "whitelisted" // whitelisted + filteringStatusRewritten = "rewritten" // all kinds of rewrites + filteringStatusSafeSearch = "safe_search" // enforced safe search +) + +// filteringStatusValues -- array with all possible filteringStatus values +var filteringStatusValues = []string{ + filteringStatusAll, filteringStatusFiltered, filteringStatusBlocked, + filteringStatusBlockedSafebrowsing, filteringStatusBlockedParental, + filteringStatusWhitelisted, filteringStatusRewritten, filteringStatusSafeSearch, +} + +// searchCriteria - every search request may contain a list of different search criteria +// we use each of them to match the query +type searchCriteria struct { + criteriaType criteriaType // type of the criteria + strict bool // should we strictly match (equality) or not (indexOf) + value string // search criteria value +} + +// quickMatch - quickly checks if the log entry matches this search criteria +// the reason is to do it as quickly as possible without de-serializing the entry +func (c *searchCriteria) quickMatch(line string) bool { + // note that we do this only for a limited set of criteria + + switch c.criteriaType { + case ctDomain: + return c.quickMatchJSONValue(line, "QH") + case ctClient: + return c.quickMatchJSONValue(line, "IP") + case ctQuestionType: + return c.quickMatchJSONValue(line, "QT") + default: + return true + } +} + +// quickMatchJSONValue - helper used by quickMatch +func (c *searchCriteria) quickMatchJSONValue(line string, propertyName string) bool { + val := readJSONValue(line, propertyName) + if len(val) == 0 { + return false + } + + if c.strict && c.value == val { + return true + } + if !c.strict && strings.Contains(val, c.value) { + return true + } + + return false +} + +// match - checks if the log entry matches this search criteria +// nolint (gocyclo) +func (c *searchCriteria) match(entry *logEntry) bool { + switch c.criteriaType { + case ctDomain: + if c.strict && entry.QHost == c.value { + return true + } + if !c.strict && strings.Contains(entry.QHost, c.value) { + return true + } + return false + case ctClient: + if c.strict && entry.IP == c.value { + return true + } + if !c.strict && strings.Contains(entry.IP, c.value) { + return true + } + return false + case ctQuestionType: + if c.strict && entry.QType == c.value { + return true + } + if !c.strict && strings.Contains(entry.QType, c.value) { + return true + } + case ctFilteringStatus: + res := entry.Result + + switch c.value { + case filteringStatusAll: + return true + case filteringStatusFiltered: + return res.IsFiltered + case filteringStatusBlocked: + return res.IsFiltered && + (res.Reason == dnsfilter.FilteredBlackList || + res.Reason == dnsfilter.FilteredBlockedService) + case filteringStatusBlockedParental: + return res.IsFiltered && res.Reason == dnsfilter.FilteredParental + case filteringStatusBlockedSafebrowsing: + return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeBrowsing + case filteringStatusWhitelisted: + return res.IsFiltered && res.Reason == dnsfilter.NotFilteredWhiteList + case filteringStatusRewritten: + return res.IsFiltered && + (res.Reason == dnsfilter.ReasonRewrite || + res.Reason == dnsfilter.RewriteEtcHosts) + case filteringStatusSafeSearch: + return res.IsFiltered && res.Reason == dnsfilter.FilteredSafeSearch + default: + return false + } + + default: + return false + } + + return false +} diff --git a/querylog/search_params.go b/querylog/search_params.go new file mode 100644 index 00000000..da083e45 --- /dev/null +++ b/querylog/search_params.go @@ -0,0 +1,57 @@ +package querylog + +import "time" + +// searchParams represent the search query sent by the client +type searchParams struct { + // searchCriteria - list of search criteria that we use to get filter results + searchCriteria []searchCriteria + + // olderThen - return entries that are older than this value + // if not set - disregard it and return any value + olderThan time.Time + + offset int // offset for the search + limit int // limit the number of records returned + maxFileScanEntries int // maximum log entries to scan in query log files. if 0 - no limit +} + +// newSearchParams - creates an empty instance of searchParams +func newSearchParams() *searchParams { + return &searchParams{ + // default max log entries to return + limit: 500, + + // by default, we scan up to 50k entries at once + maxFileScanEntries: 50000, + } +} + +// quickMatchesGetDataParams - quickly checks if the line matches the searchParams +// this method does not guarantee anything and the reason is to do a quick check +// without deserializing anything +func (s *searchParams) quickMatch(line string) bool { + for _, c := range s.searchCriteria { + if !c.quickMatch(line) { + return false + } + } + + return true +} + +// match - checks if the logEntry matches the searchParams +func (s *searchParams) match(entry *logEntry) bool { + if !s.olderThan.IsZero() && entry.Time.UnixNano() >= s.olderThan.UnixNano() { + // Ignore entries newer than what was requested + return false + } + + for _, c := range s.searchCriteria { + if !c.match(entry) { + return false + } + } + + return true +} diff --git a/util/helpers.go b/util/helpers.go index 27ac4d71..b32ac456 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -10,6 +10,16 @@ import ( "strings" ) +// ContainsString checks if "v" is in the array "arr" +func ContainsString(arr []string, v string) bool { + for _, i := range arr { + if i == v { + return true + } + } + return false +} + // fileExists returns TRUE if file exists func FileExists(fn string) bool { _, err := os.Stat(fn)