diff --git a/go.mod b/go.mod
index 28b2d1a4..18e8de6e 100644
--- a/go.mod
+++ b/go.mod
@@ -15,6 +15,7 @@ require (
 	github.com/kardianos/service v0.0.0-20181115005516-4c239ee84e7b
 	github.com/krolaw/dhcp4 v0.0.0-20180925202202-7cead472c414
 	github.com/miekg/dns v1.1.26
+	github.com/pkg/errors v0.8.1
 	github.com/sparrc/go-ping v0.0.0-20181106165434-ef3ab45e41b0
 	github.com/stretchr/testify v1.4.0
 	go.etcd.io/bbolt v1.3.3 // indirect
diff --git a/querylog/qlog.go b/querylog/qlog.go
index 8a34130e..a0570011 100644
--- a/querylog/qlog.go
+++ b/querylog/qlog.go
@@ -171,82 +171,6 @@ func (l *queryLog) Add(params AddParams) {
 	}
 }
 
-// Return TRUE if this entry is needed
-func isNeeded(entry *logEntry, params getDataParams) bool {
-	if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered {
-		return false
-	}
-
-	if len(params.QuestionType) != 0 {
-		if entry.QType != params.QuestionType {
-			return false
-		}
-	}
-
-	if len(params.Domain) != 0 {
-		if (params.StrictMatchDomain && entry.QHost != params.Domain) ||
-			(!params.StrictMatchDomain && strings.Index(entry.QHost, params.Domain) == -1) {
-			return false
-		}
-	}
-
-	if len(params.Client) != 0 {
-		if (params.StrictMatchClient && entry.IP != params.Client) ||
-			(!params.StrictMatchClient && strings.Index(entry.IP, params.Client) == -1) {
-			return false
-		}
-	}
-
-	return true
-}
-
-func (l *queryLog) readFromFile(params getDataParams) ([]*logEntry, time.Time, int) {
-	entries := []*logEntry{}
-	oldest := time.Time{}
-
-	r := l.OpenReader()
-	if r == nil {
-		return entries, time.Time{}, 0
-	}
-	r.BeginRead(params.OlderThan, getDataLimit, &params)
-	total := uint64(0)
-	for total <= maxSearchEntries {
-		newEntries := []*logEntry{}
-		for {
-			entry := r.Next()
-			if entry == nil {
-				break
-			}
-
-			if !isNeeded(entry, params) {
-				continue
-			}
-			if len(newEntries) == getDataLimit {
-				newEntries = newEntries[1:]
-			}
-			newEntries = append(newEntries, entry)
-		}
-
-		log.Debug("entries: +%d (%d) [%d]", len(newEntries), len(entries), r.Total())
-
-		entries = append(newEntries, entries...)
-		if len(entries) > getDataLimit {
-			toremove := len(entries) - getDataLimit
-			entries = entries[toremove:]
-			break
-		}
-		if r.Total() == 0 || len(entries) == getDataLimit {
-			break
-		}
-		total += r.Total()
-		oldest = r.Oldest()
-		r.BeginReadPrev(getDataLimit)
-	}
-
-	r.Close()
-	return entries, oldest, int(total)
-}
-
 // Parameters for getData()
 type getDataParams struct {
 	OlderThan time.Time // return entries that are older than this value
@@ -267,17 +191,12 @@ const (
 	responseStatusFiltered
 )
 
-// Get log entries
+// getData returns log entries from the in-memory buffer and the log files
 func (l *queryLog) getData(params getDataParams) map[string]interface{} {
-	var data = []map[string]interface{}{}
-
-	var oldest time.Time
 	now := time.Now()
-	entries := []*logEntry{}
-	total := 0
 
 	// add from file
-	entries, oldest, total = l.readFromFile(params)
+	fileEntries, oldest, total := l.searchFiles(params)
 
 	if params.OlderThan.IsZero() {
 		params.OlderThan = now
@@ -286,78 +205,40 @@ func (l *queryLog) getData(params getDataParams) map[string]interface{} {
 	// add from memory buffer
 	l.bufferLock.Lock()
 	total += len(l.buffer)
-	for _, entry := range l.buffer {
+	memoryEntries := make([]*logEntry, 0)
 
-		if !isNeeded(entry, params) {
+	// go through the buffer in the reverse order,
+	// from NEWER to OLDER
+	for i := len(l.buffer) - 1; i >= 0; i-- {
+		entry := l.buffer[i]
+
+		if entry.Time.UnixNano() >= params.OlderThan.UnixNano() {
+			// Ignore entries newer than what was requested
 			continue
 		}
 
-		if entry.Time.UnixNano() >= params.OlderThan.UnixNano() {
-			break
+		if !matchesGetDataParams(entry, params) {
+			continue
 		}
 
-		if len(entries) == getDataLimit {
-			entries = entries[1:]
-		}
-		entries = append(entries, entry)
+		memoryEntries = append(memoryEntries, entry)
 	}
 	l.bufferLock.Unlock()
 
-	// process the elements from latest to oldest
-	for i := len(entries) - 1; i >= 0; i-- {
+	// now let's get a unified collection
+	entries := append(memoryEntries, fileEntries...)
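+
+	// NOTE: both memoryEntries and fileEntries are ordered from NEWER to
+	// OLDER, so the merged slice keeps that order and can simply be capped
+	// at getDataLimit below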
+	if len(entries) > getDataLimit {
+		// remove extra records -- the oldest ones, since the slice is ordered from newer to older
+		entries = entries[:getDataLimit]
+	}
+
+	// init the response object
+	var data = []map[string]interface{}{}
+
+	// the entries are already ordered from newer to older
+	for i := 0; i < len(entries); i++ {
 		entry := entries[i]
-		var a *dns.Msg
-
-		if len(entry.Answer) > 0 {
-			a = new(dns.Msg)
-			if err := a.Unpack(entry.Answer); err != nil {
-				log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer))
-				a = nil
-			}
-		}
-
-		jsonEntry := map[string]interface{}{
-			"reason":    entry.Result.Reason.String(),
-			"elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64),
-			"time":      entry.Time.Format(time.RFC3339Nano),
-			"client":    entry.IP,
-		}
-		jsonEntry["question"] = map[string]interface{}{
-			"host":  entry.QHost,
-			"type":  entry.QType,
-			"class": entry.QClass,
-		}
-
-		if a != nil {
-			jsonEntry["status"] = dns.RcodeToString[a.Rcode]
-		}
-		if len(entry.Result.Rule) > 0 {
-			jsonEntry["rule"] = entry.Result.Rule
-			jsonEntry["filterId"] = entry.Result.FilterID
-		}
-
-		if len(entry.Result.ServiceName) != 0 {
-			jsonEntry["service_name"] = entry.Result.ServiceName
-		}
-
-		answers := answerToMap(a)
-		if answers != nil {
-			jsonEntry["answer"] = answers
-		}
-
-		if len(entry.OrigAnswer) != 0 {
-			a := new(dns.Msg)
-			err := a.Unpack(entry.OrigAnswer)
-			if err == nil {
-				answers = answerToMap(a)
-				if answers != nil {
-					jsonEntry["original_answer"] = answers
-				}
-			} else {
-				log.Debug("Querylog: a.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer))
-			}
-		}
-
+		jsonEntry := logEntryToJSONEntry(entry)
 		data = append(data, jsonEntry)
 	}
@@ -376,6 +257,62 @@
+func logEntryToJSONEntry(entry *logEntry) map[string]interface{} {
+	var msg *dns.Msg
+
+	if len(entry.Answer) > 0 {
+		msg = new(dns.Msg)
+		if err := msg.Unpack(entry.Answer); err != nil {
+			log.Debug("Failed to unpack dns message answer: %s: %s", err, string(entry.Answer))
+			msg = nil
+		}
+	}
+
+	jsonEntry := map[string]interface{}{
+		"reason":    entry.Result.Reason.String(),
+		"elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64),
+		"time":      entry.Time.Format(time.RFC3339Nano),
+		"client":    entry.IP,
+	}
+	jsonEntry["question"] = map[string]interface{}{
+		"host":  entry.QHost,
+		"type":  entry.QType,
+		"class": entry.QClass,
+	}
+
+	if msg != nil {
+		jsonEntry["status"] = dns.RcodeToString[msg.Rcode]
+	}
+	if len(entry.Result.Rule) > 0 {
+		jsonEntry["rule"] = entry.Result.Rule
+		jsonEntry["filterId"] = entry.Result.FilterID
+	}
+
+	if len(entry.Result.ServiceName) != 0 {
+		jsonEntry["service_name"] = entry.Result.ServiceName
+	}
+
+	answers := answerToMap(msg)
+	if answers != nil {
+		jsonEntry["answer"] = answers
+	}
+
+	if len(entry.OrigAnswer) != 0 {
+		a := new(dns.Msg)
+		err := a.Unpack(entry.OrigAnswer)
+		if err == nil {
+			answers = answerToMap(a)
+			if answers != nil {
+				jsonEntry["original_answer"] = answers
+			}
+		} else {
+			log.Debug("Querylog: a.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer))
+		}
+	}
+
+	return jsonEntry
+}
+
 func answerToMap(a *dns.Msg) []map[string]interface{} {
 	if a == nil || len(a.Answer) == 0 {
 		return nil
diff --git a/querylog/qlog_file.go b/querylog/qlog_file.go
new file mode 100644
index 00000000..c1eeefa2
--- /dev/null
+++ b/querylog/qlog_file.go
@@ -0,0 +1,333 @@
+package querylog
+
+import (
+	"io"
+	"os"
+	"sync"
+	"time"
+
+	"github.com/AdguardTeam/golibs/log"
+
"github.com/pkg/errors" +) + +// ErrSeekNotFound is returned from the Seek method +// if we failed to find the desired record +var ErrSeekNotFound = errors.New("Seek not found the record") + +// TODO: Find a way to grow buffer instead of relying on this value when reading strings +const maxEntrySize = 16 * 1024 + +// buffer should be enough for at least this number of entries +const bufferSize = 100 * maxEntrySize + +// QLogFile represents a single query log file +// It allows reading from the file in the reverse order +// +// Please note that this is a stateful object. +// Internally, it contains a pointer to a specific position in the file, +// and it reads lines in reverse order starting from that position. +type QLogFile struct { + file *os.File // the query log file + position int64 // current position in the file + + buffer []byte // buffer that we've read from the file + bufferStart int64 // start of the buffer (in the file) + bufferLen int // buffer len + + lock sync.Mutex // We use mutex to make it thread-safe +} + +// NewQLogFile initializes a new instance of the QLogFile +func NewQLogFile(path string) (*QLogFile, error) { + f, err := os.OpenFile(path, os.O_RDONLY, 0644) + + if err != nil { + return nil, err + } + + return &QLogFile{ + file: f, + }, nil +} + +// Seek performs binary search in the query log file looking for a record +// with the specified timestamp. Once the record is found, it sets +// "position" so that the next ReadNext call returned that record. +// +// The algorithm is rather simple: +// 1. It starts with the position in the middle of a file +// 2. Shifts back to the beginning of the line +// 3. Checks the log record timestamp +// 4. If it is lower than the timestamp we are looking for, +// it shifts seek position to 3/4 of the file. Otherwise, to 1/4 of the file. +// 5. It performs the search again, every time the search scope is narrowed twice. +// +// Returns: +// * It returns the position of the the line with the timestamp we were looking for +// so that when we call "ReadNext" this line was returned. +// * Depth of the search (how many times we compared timestamps). 
+// * If we could not find it, it returns ErrSeekNotFound
+func (q *QLogFile) Seek(timestamp int64) (int64, int, error) {
+	q.lock.Lock()
+	defer q.lock.Unlock()
+
+	// Empty the buffer
+	q.buffer = nil
+
+	// First of all, check the file size
+	fileInfo, err := q.file.Stat()
+	if err != nil {
+		return 0, 0, err
+	}
+
+	// Define the search scope
+	start := int64(0)          // start of the search interval (position in the file)
+	end := fileInfo.Size()     // end of the search interval (position in the file)
+	probe := (end - start) / 2 // probe -- approximate index of the line we'll try to check
+	var line string
+	var lineIdx int64          // index of the probe line in the file
+	var lastProbeLineIdx int64 // index of the last probe line
+
+	// Count the seek depth in order to detect mistakes.
+	// If the depth is too large, we should stop the search.
+	depth := 0
+
+	for {
+		// Get the line at the specified position
+		line, lineIdx, err = q.readProbeLine(probe)
+		if err != nil {
+			return 0, depth, err
+		}
+
+		// Get the timestamp from the query log record
+		ts := readQLogTimestamp(line)
+
+		if ts == 0 {
+			return 0, depth, ErrSeekNotFound
+		}
+
+		if ts == timestamp {
+			// Hurray, returning the result
+			break
+		}
+
+		if lastProbeLineIdx == lineIdx {
+			// If we're testing the same line twice, then most likely
+			// the scope is too narrow and we won't find anything anymore
+			return 0, depth, ErrSeekNotFound
+		}
+
+		// Save the last found idx
+		lastProbeLineIdx = lineIdx
+
+		// Narrow the scope and repeat the search
+		if ts > timestamp {
+			// If the timestamp we're looking for is OLDER than what we found,
+			// then the line is somewhere on the LEFT side of the current probe position
+			end = probe
+			probe = start + (end-start)/2
+		} else {
+			// If the timestamp we're looking for is NEWER than what we found,
+			// then the line is somewhere on the RIGHT side of the current probe position
+			start = probe
+			probe = start + (end-start)/2
+		}
+
+		depth++
+		if depth >= 100 {
+			log.Error("Seek depth is too high, aborting. File %s, ts %v", q.file.Name(), timestamp)
+			return 0, depth, ErrSeekNotFound
+		}
+	}
+
+	q.position = lineIdx + int64(len(line))
+	return q.position, depth, nil
+}
+
+// SeekStart changes the current position to the end of the file.
+// Please note that we read the query log in reverse order,
+// which is why the log start is actually the end of the file.
+//
+// Returns nil if we were able to change the current position.
+// Returns an error in any other case.
+func (q *QLogFile) SeekStart() (int64, error) {
+	q.lock.Lock()
+	defer q.lock.Unlock()
+
+	// Empty the buffer
+	q.buffer = nil
+
+	// First of all, check the file size
+	fileInfo, err := q.file.Stat()
+	if err != nil {
+		return 0, err
+	}
+
+	// Place the position at the very end of the file
+	q.position = fileInfo.Size() - 1
+	if q.position < 0 {
+		q.position = 0
+	}
+	return q.position, nil
+}
+
+// ReadNext reads the next line (in reverse order) from the file
+// and shifts the current position left to the next (actually prev) line.
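+//
+// A minimal usage sketch (the file path here is hypothetical): position at
+// the record with timestamp ts, then keep reading towards older records:
+//
+//	q, _ := NewQLogFile("/path/to/querylog.json")
+//	defer q.Close()
+//	if _, _, err := q.Seek(ts); err != nil {
+//		// no record with this timestamp (see ErrSeekNotFound)
+//	}
+//	for {
+//		line, err := q.ReadNext()
+//		if err == io.EOF {
+//			break // nothing left to read
+//		}
+//		_ = line // one JSON log record per line
+//	}
+//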
+// It returns io.EOF when there is nothing more to read.
+func (q *QLogFile) ReadNext() (string, error) {
+	q.lock.Lock()
+	defer q.lock.Unlock()
+
+	if q.position == 0 {
+		return "", io.EOF
+	}
+
+	line, lineIdx, err := q.readNextLine(q.position)
+	if err != nil {
+		return "", err
+	}
+
+	// Shift position
+	if lineIdx == 0 {
+		q.position = 0
+	} else {
+		// there's usually a line break before the line,
+		// so we should shift one more char left from the line:
+		// line\nline
+		q.position = lineIdx - 1
+	}
+	return line, err
+}
+
+// Close frees the underlying resources
+func (q *QLogFile) Close() error {
+	return q.file.Close()
+}
+
+// readNextLine reads the next line from the specified position.
+// This line actually has to END at that position.
+//
+// the algorithm is:
+// 1. check if we have the buffer initialized
+// 2. if it is, scan it and look for the line there
+// 3. if we cannot find the line there, read the prev chunk into the buffer
+// 4. read the line from the buffer
+func (q *QLogFile) readNextLine(position int64) (string, int64, error) {
+	relativePos := position - q.bufferStart
+	if q.buffer == nil || (relativePos < maxEntrySize && q.bufferStart != 0) {
+		// Time to re-init the buffer
+		err := q.initBuffer(position)
+		if err != nil {
+			return "", 0, err
+		}
+		relativePos = position - q.bufferStart
+	}
+
+	// Look for the end of the prev line,
+	// this is where we'll read from
+	var startLine = int64(0)
+	for i := relativePos - 1; i >= 0; i-- {
+		if q.buffer[i] == '\n' {
+			startLine = i + 1
+			break
+		}
+	}
+
+	line := string(q.buffer[startLine:relativePos])
+	lineIdx := q.bufferStart + startLine
+	return line, lineIdx, nil
+}
+
+// initBuffer initializes the QLogFile buffer.
+// The goal is to read a chunk of the file that includes the line ending at the specified position.
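+//
+// For example (the numbers are purely illustrative): with a bufferSize of
+// 1,600 KiB and position = 2,000 KiB, the buffer is read starting at offset
+// 400 KiB; with position = 1,000 KiB (less than bufferSize), it is read from
+// offset 0. In other words, the buffer always starts at
+// max(0, position-bufferSize), so the line ending at "position" is
+// guaranteed to be inside it.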
+func (q *QLogFile) initBuffer(position int64) error {
+	q.bufferStart = int64(0)
+	if (position - bufferSize) > 0 {
+		q.bufferStart = position - bufferSize
+	}
+
+	// Seek to this position
+	_, err := q.file.Seek(q.bufferStart, io.SeekStart)
+	if err != nil {
+		return err
+	}
+
+	if q.buffer == nil {
+		q.buffer = make([]byte, bufferSize)
+	}
+	q.bufferLen, err = q.file.Read(q.buffer)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// readProbeLine reads a line that includes the specified position.
+// This method is supposed to be used when we use binary search in the Seek method.
+// In the case of consecutive reads, use readNextLine (it uses a better buffer).
+func (q *QLogFile) readProbeLine(position int64) (string, int64, error) {
+	// First of all, we should read a buffer that will include the query log line.
+	// In order to do this, we'll define the boundaries.
+	seekPosition := int64(0)
+	relativePos := position // position relative to the buffer we're going to read
+	if (position - maxEntrySize) > 0 {
+		seekPosition = position - maxEntrySize
+		relativePos = maxEntrySize
+	}
+
+	// Seek to this position
+	_, err := q.file.Seek(seekPosition, io.SeekStart)
+	if err != nil {
+		return "", 0, err
+	}
+
+	// The buffer size is 2*maxEntrySize
+	buffer := make([]byte, maxEntrySize*2)
+	bufferLen, err := q.file.Read(buffer)
+	if err != nil {
+		return "", 0, err
+	}
+
+	// Now start looking for the new line character starting
+	// from the relativePos and going left
+	var startLine = int64(0)
+	for i := relativePos - 1; i >= 0; i-- {
+		if buffer[i] == '\n' {
+			startLine = i + 1
+			break
+		}
+	}
+	// Looking for the end of line now
+	var endLine = int64(bufferLen)
+	for i := relativePos; i < int64(bufferLen); i++ {
+		if buffer[i] == '\n' {
+			endLine = i
+			break
+		}
+	}
+
+	// Finally we can return the string we were looking for
+	lineIdx := startLine + seekPosition
+	return string(buffer[startLine:endLine]), lineIdx, nil
+}
+
+// readQLogTimestamp reads the timestamp field from the query log line
+func readQLogTimestamp(str string) int64 {
+	val := readJSONValue(str, "T")
+	if len(val) == 0 {
+		val = readJSONValue(str, "Time")
+	}
+
+	if len(val) == 0 {
+		log.Error("Couldn't find timestamp: %s", str)
+		return 0
+	}
+	tm, err := time.Parse(time.RFC3339Nano, val)
+	if err != nil {
+		log.Error("Couldn't parse timestamp: %s", val)
+		return 0
+	}
+	return tm.UnixNano()
+}
diff --git a/querylog/qlog_file_test.go b/querylog/qlog_file_test.go
new file mode 100644
index 00000000..a0fa07f1
--- /dev/null
+++ b/querylog/qlog_file_test.go
@@ -0,0 +1,268 @@
+package querylog
+
+import (
+	"encoding/binary"
+	"io"
+	"io/ioutil"
+	"math"
+	"net"
+	"os"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestQLogFileEmpty(t *testing.T) {
+	testDir := prepareTestDir()
+	defer func() { _ = os.RemoveAll(testDir) }()
+	testFile := prepareTestFile(testDir, 0)
+
+	// create the new QLogFile instance
+	q, err := NewQLogFile(testFile)
+	assert.Nil(t, err)
+	assert.NotNil(t, q)
+	defer q.Close()
+
+	// seek to the start
+	pos, err := q.SeekStart()
+	assert.Nil(t, err)
+	assert.Equal(t, int64(0), pos)
+
+	// try reading anyway
+	line, err := q.ReadNext()
+	assert.Equal(t, io.EOF, err)
+	assert.Equal(t, "", line)
+}
+
+func TestQLogFileLarge(t *testing.T) {
+	// should be large enough
+	count := 50000
+
+	testDir := prepareTestDir()
+	defer func() { _ = os.RemoveAll(testDir) }()
+	testFile := prepareTestFile(testDir, count)
+
+	// create the new QLogFile instance
+	q, err := NewQLogFile(testFile)
+	assert.Nil(t, err)
+	assert.NotNil(t, q)
+	defer q.Close()
+
+	// seek to the start
+	pos, err := q.SeekStart()
+	assert.Nil(t, err)
+	assert.NotEqual(t, int64(0), pos)
+
+	read := 0
+	var line string
+	for err == nil {
+		line, err = q.ReadNext()
+		if err == nil {
+			assert.True(t, len(line) > 0)
+			read++
+		}
+	}
+
+	assert.Equal(t, count, read)
+	assert.Equal(t, io.EOF, err)
+}
+
+func TestQLogFileSeekLargeFile(t *testing.T) {
+	// a reasonably large file
+	count := 10000
+
+	testDir := prepareTestDir()
+	defer func() { _ = os.RemoveAll(testDir) }()
+	testFile := prepareTestFile(testDir, count)
+
+	// create the new QLogFile instance
+	q, err := NewQLogFile(testFile)
+	assert.Nil(t, err)
+	assert.NotNil(t, q)
+	defer q.Close()
+
+	// CASE 1: NOT TOO OLD LINE
+	testSeekLineQLogFile(t, q, 300)
+
+	// CASE 2: OLD LINE
+	testSeekLineQLogFile(t, q, count-300)
+
+	// CASE 3: FIRST LINE
+	testSeekLineQLogFile(t, q, 0)
+
+	// CASE 4: LAST LINE
+	testSeekLineQLogFile(t, q, count)
+
+	// CASE 5: Seek non-existent (too low)
+	_, _, err = q.Seek(123)
+	assert.NotNil(t, err)
+
+	// CASE 6: Seek non-existent (too high)
+	ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z")
+	_, _, err = q.Seek(ts.UnixNano())
+	assert.NotNil(t, err)
+
+	// CASE 7: "Almost" found
+	line, err := getQLogFileLine(q, count/2)
+	assert.Nil(t, err)
+	// ALMOST the record we need
+	timestamp := readQLogTimestamp(line) - 1
+	assert.NotEqual(t, int64(0), timestamp)
+	_, depth, err := q.Seek(timestamp)
+	assert.NotNil(t, err)
+	assert.True(t, depth <= int(math.Log2(float64(count))+3))
+}
+
+func TestQLogFileSeekSmallFile(t *testing.T) {
+	// a small file
+	count := 10
+
+	testDir := prepareTestDir()
+	defer func() { _ = os.RemoveAll(testDir) }()
+	testFile := prepareTestFile(testDir, count)
+
+	// create the new QLogFile instance
+	q, err := NewQLogFile(testFile)
+	assert.Nil(t, err)
+	assert.NotNil(t, q)
+	defer q.Close()
+
+	// CASE 1: NOT TOO OLD LINE
+	testSeekLineQLogFile(t, q, 2)
+
+	// CASE 2: OLD LINE
+	testSeekLineQLogFile(t, q, count-2)
+
+	// CASE 3: FIRST LINE
+	testSeekLineQLogFile(t, q, 0)
+
+	// CASE 4: LAST LINE
+	testSeekLineQLogFile(t, q, count)
+
+	// CASE 5: Seek non-existent (too low)
+	_, _, err = q.Seek(123)
+	assert.NotNil(t, err)
+
+	// CASE 6: Seek non-existent (too high)
+	ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z")
+	_, _, err = q.Seek(ts.UnixNano())
+	assert.NotNil(t, err)
+
+	// CASE 7: "Almost" found
+	line, err := getQLogFileLine(q, count/2)
+	assert.Nil(t, err)
+	// ALMOST the record we need
+	timestamp := readQLogTimestamp(line) - 1
+	assert.NotEqual(t, int64(0), timestamp)
+	_, depth, err := q.Seek(timestamp)
+	assert.NotNil(t, err)
+	assert.True(t, depth <= int(math.Log2(float64(count))+3))
+}
+
+func testSeekLineQLogFile(t *testing.T, q *QLogFile, lineNumber int) {
+	line, err := getQLogFileLine(q, lineNumber)
+	assert.Nil(t, err)
+	ts := readQLogTimestamp(line)
+	assert.NotEqual(t, int64(0), ts)
+
+	// try seeking to that line now
+	pos, _, err := q.Seek(ts)
+	assert.Nil(t, err)
+	assert.NotEqual(t, int64(0), pos)
+
+	testLine, err := q.ReadNext()
+	assert.Nil(t, err)
+	assert.Equal(t, line, testLine)
+}
+
+func getQLogFileLine(q *QLogFile, lineNumber int) (string, error) {
+	_, err := q.SeekStart()
+	if err != nil {
+		return "", err
+	}
+
+	for i := 1; i < lineNumber; i++ {
+		_, err := q.ReadNext()
+		if err != nil {
+			return "", err
+		}
+	}
+	return q.ReadNext()
+}
+
+// Check reading the log entries from a file, line by line, in reverse order
+func TestQLogFile(t *testing.T) {
+	testDir := prepareTestDir()
+	defer func() { _ = os.RemoveAll(testDir) }()
+	testFile := prepareTestFile(testDir, 2)
+
+	// create the new QLogFile instance
+	q, err := NewQLogFile(testFile)
+	assert.Nil(t, err)
+	assert.NotNil(t, q)
+	defer q.Close()
+
+	// seek to the start
+	pos, err := q.SeekStart()
+	assert.Nil(t, err)
+	assert.True(t, pos > 0)
+
+	// read the first line
+	line, err := q.ReadNext()
+	assert.Nil(t, err)
+	assert.True(t, strings.Contains(line, "0.0.0.2"), line)
+	assert.True(t, strings.HasPrefix(line, "{"), line)
+	assert.True(t, strings.HasSuffix(line, "}"), line)
+
+	// read the second line
+	line, err = q.ReadNext()
+	assert.Nil(t, err)
+	assert.Equal(t, int64(0), q.position)
+	assert.True(t, strings.Contains(line, "0.0.0.1"), line)
+	assert.True(t, strings.HasPrefix(line, "{"), line)
+	assert.True(t, strings.HasSuffix(line, "}"), line)
+
+	// try reading again (there's nothing to read anymore)
+	line, err = q.ReadNext()
+	assert.Equal(t, io.EOF, err)
+	assert.Equal(t, "", line)
+}
+
+// prepareTestFile prepares a test query log file with the specified number of lines
+func prepareTestFile(dir string, linesCount int) string {
+	return prepareTestFiles(dir, 1, linesCount)[0]
+}
+
+// prepareTestFiles prepares several test query log files,
+// each of them with the specified linesCount
+func prepareTestFiles(dir string, filesCount, linesCount int) []string {
+	format := `{"IP":"${IP}","T":"${TIMESTAMP}","QH":"example.org","QT":"A","QC":"IN","Answer":"AAAAAAABAAEAAAAAB2V4YW1wbGUDb3JnAAABAAEHZXhhbXBsZQNvcmcAAAEAAQAAAAAABAECAwQ=","Result":{},"Elapsed":0,"Upstream":"upstream"}`
+
+	lineTime, _ := time.Parse(time.RFC3339Nano, "2020-02-18T22:36:35.920973+03:00")
+	lineIP := uint32(0)
+
+	files := make([]string, 0)
+	for j := 0; j < filesCount; j++ {
+		f, _ := ioutil.TempFile(dir, "*.txt")
+		files = append(files, f.Name())
+
+		for i := 0; i < linesCount; i++ {
+			lineIP++
+			lineTime = lineTime.Add(time.Second)
+
+			ip := make(net.IP, 4)
+			binary.BigEndian.PutUint32(ip, lineIP)
+
+			line := format
+			line = strings.ReplaceAll(line, "${IP}", ip.String())
+			line = strings.ReplaceAll(line, "${TIMESTAMP}", lineTime.Format(time.RFC3339Nano))
+
+			_, _ = f.WriteString(line)
+			_, _ = f.WriteString("\n")
+		}
+	}
+
+	return files
+}
diff --git a/querylog/qlog_reader.go b/querylog/qlog_reader.go
new file mode 100644
index 00000000..ee2f617d
--- /dev/null
+++ b/querylog/qlog_reader.go
@@ -0,0 +1,139 @@
+package querylog
+
+import (
+	"io"
+
+	"github.com/joomcode/errorx"
+)
+
+// QLogReader allows reading from multiple query log files in the reverse order.
+//
+// Please note that this is a stateful object.
+// Internally, it contains a pointer to a particular query log file, and
+// to a specific position in this file, and it reads lines in reverse order
+// starting from that position.
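+//
+// A minimal usage sketch (the file names are hypothetical; they must be
+// ordered from the oldest to the newest):
+//
+//	r, _ := NewQLogReader([]string{"querylog.json.1", "querylog.json"})
+//	defer r.Close()
+//	_ = r.SeekStart() // start from the newest record
+//	for {
+//		line, err := r.ReadNext()
+//		if err == io.EOF {
+//			break // all files have been read
+//		}
+//		_ = line
+//	}
+//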
+type QLogReader struct {
+	// qFiles - array with the query log files
+	// The order is - from oldest to newest
+	qFiles []*QLogFile
+
+	currentFile int // Index of the current file
+}
+
+// NewQLogReader initializes a QLogReader instance
+// with the specified files
+func NewQLogReader(files []string) (*QLogReader, error) {
+	qFiles := make([]*QLogFile, 0)
+
+	for _, f := range files {
+		q, err := NewQLogFile(f)
+		if err != nil {
+			// Close what we've already opened
+			_ = closeQFiles(qFiles)
+			return nil, err
+		}
+
+		qFiles = append(qFiles, q)
+	}
+
+	return &QLogReader{
+		qFiles:      qFiles,
+		currentFile: (len(qFiles) - 1),
+	}, nil
+}
+
+// Seek performs binary search for a query log record with the specified timestamp.
+// If the record is found, it sets QLogReader's position to point to that line,
+// so that the next ReadNext call returns this line.
+//
+// Returns nil if the record is successfully found.
+// Returns an error if for some reason we could not find a record with the specified timestamp.
+func (r *QLogReader) Seek(timestamp int64) error {
+	for i := len(r.qFiles) - 1; i >= 0; i-- {
+		q := r.qFiles[i]
+		_, _, err := q.Seek(timestamp)
+		if err == nil {
+			// Our search is finished, we found the element we were looking for.
+			// Update currentFile only, the position is already set properly in the QLogFile.
+			r.currentFile = i
+			return nil
+		}
+	}
+
+	return ErrSeekNotFound
+}
+
+// SeekStart changes the current position to the end of the newest file.
+// Please note that we read the query log in reverse order,
+// which is why the log start is actually the end of the file.
+//
+// Returns nil if we were able to change the current position.
+// Returns an error in any other case.
+func (r *QLogReader) SeekStart() error {
+	if len(r.qFiles) == 0 {
+		return nil
+	}
+
+	r.currentFile = len(r.qFiles) - 1
+	_, err := r.qFiles[r.currentFile].SeekStart()
+	return err
+}
+
+// ReadNext reads the next line (in reverse order) from the query log files
+// and shifts the current position left to the next (actually prev) line (or the next file).
+// Returns io.EOF when there is nothing more to read.
+func (r *QLogReader) ReadNext() (string, error) {
+	if len(r.qFiles) == 0 {
+		return "", io.EOF
+	}
+
+	for r.currentFile >= 0 {
+		q := r.qFiles[r.currentFile]
+		line, err := q.ReadNext()
+		if err != nil {
+			// Shift to the older file
+			r.currentFile--
+			if r.currentFile < 0 {
+				break
+			}
+
+			q = r.qFiles[r.currentFile]
+
+			// Set its position to the start right away
+			_, err = q.SeekStart()
+
+			// This is unexpected, return an error right away
+			if err != nil {
+				return "", err
+			}
+		} else {
+			return line, nil
+		}
+	}
+
+	// Nothing to read anymore
+	return "", io.EOF
+}
+
+// Close closes the QLogReader
+func (r *QLogReader) Close() error {
+	return closeQFiles(r.qFiles)
+}
+
+// closeQFiles - helper method to close multiple QLogFile instances
+func closeQFiles(qFiles []*QLogFile) error {
+	var errs []error
+
+	for _, q := range qFiles {
+		err := q.Close()
+		if err != nil {
+			errs = append(errs, err)
+		}
+	}
+
+	if len(errs) > 0 {
+		return errorx.DecorateMany("Error while closing QLogReader", errs...)
+	}
+
+	return nil
+}
diff --git a/querylog/qlog_reader_test.go b/querylog/qlog_reader_test.go
new file mode 100644
index 00000000..357b4f9d
--- /dev/null
+++ b/querylog/qlog_reader_test.go
@@ -0,0 +1,157 @@
+package querylog
+
+import (
+	"io"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestQLogReaderEmpty(t *testing.T) {
+	r, err := NewQLogReader([]string{})
+	assert.Nil(t, err)
+	assert.NotNil(t, r)
+	defer r.Close()
+
+	// seek to the start
+	err = r.SeekStart()
+	assert.Nil(t, err)
+
+	line, err := r.ReadNext()
+	assert.Equal(t, "", line)
+	assert.Equal(t, io.EOF, err)
+}
+
+func TestQLogReaderOneFile(t *testing.T) {
+	// let's do one small file
+	count := 10
+	filesCount := 1
+
+	testDir := prepareTestDir()
+	defer func() { _ = os.RemoveAll(testDir) }()
+	testFiles := prepareTestFiles(testDir, filesCount, count)
+
+	r, err := NewQLogReader(testFiles)
+	assert.Nil(t, err)
+	assert.NotNil(t, r)
+	defer r.Close()
+
+	// seek to the start
+	err = r.SeekStart()
+	assert.Nil(t, err)
+
+	// read everything
+	read := 0
+	var line string
+	for err == nil {
+		line, err = r.ReadNext()
+		if err == nil {
+			assert.True(t, len(line) > 0)
+			read++
+		}
+	}
+
+	assert.Equal(t, count*filesCount, read)
+	assert.Equal(t, io.EOF, err)
+}
+
+func TestQLogReaderMultipleFiles(t *testing.T) {
+	// should be large enough
+	count := 10000
+	filesCount := 5
+
+	testDir := prepareTestDir()
+	defer func() { _ = os.RemoveAll(testDir) }()
+	testFiles := prepareTestFiles(testDir, filesCount, count)
+
+	r, err := NewQLogReader(testFiles)
+	assert.Nil(t, err)
+	assert.NotNil(t, r)
+	defer r.Close()
+
+	// seek to the start
+	err = r.SeekStart()
+	assert.Nil(t, err)
+
+	// read everything
+	read := 0
+	var line string
+	for err == nil {
+		line, err = r.ReadNext()
+		if err == nil {
+			assert.True(t, len(line) > 0)
+			read++
+		}
+	}
+
+	assert.Equal(t, count*filesCount, read)
+	assert.Equal(t, io.EOF, err)
+}
+
+func TestQLogReaderSeek(t *testing.T) {
+	// a reasonably large file
+	count := 10000
+	filesCount := 2
+
+	testDir := prepareTestDir()
+	defer func() { _ = os.RemoveAll(testDir) }()
+	testFiles := prepareTestFiles(testDir, filesCount, count)
+
+	r, err := NewQLogReader(testFiles)
+	assert.Nil(t, err)
+	assert.NotNil(t, r)
+	defer r.Close()
+
+	// CASE 1: NOT TOO OLD LINE
+	testSeekLineQLogReader(t, r, 300)
+
+	// CASE 2: OLD LINE
+	testSeekLineQLogReader(t, r, count-300)
+
+	// CASE 3: FIRST LINE
+	testSeekLineQLogReader(t, r, 0)
+
+	// CASE 4: LAST LINE
+	testSeekLineQLogReader(t, r, count)
+
+	// CASE 5: Seek non-existent (too low)
+	err = r.Seek(123)
+	assert.NotNil(t, err)
+
+	// CASE 6: Seek non-existent (too high)
+	ts, _ := time.Parse(time.RFC3339, "2100-01-02T15:04:05Z")
+	err = r.Seek(ts.UnixNano())
+	assert.NotNil(t, err)
+}
+
+func testSeekLineQLogReader(t *testing.T, r *QLogReader, lineNumber int) {
+	line, err := getQLogReaderLine(r, lineNumber)
+	assert.Nil(t, err)
+	ts := readQLogTimestamp(line)
+	assert.NotEqual(t, int64(0), ts)
+
+	// try seeking to that line now
+	err = r.Seek(ts)
+	assert.Nil(t, err)
+
+	testLine, err := r.ReadNext()
+	assert.Nil(t, err)
+	assert.Equal(t, line, testLine)
+}
+
+func getQLogReaderLine(r *QLogReader, lineNumber int) (string, error) {
+	err := r.SeekStart()
+	if err != nil {
+		return "", err
+	}
+
+	for i := 1; i < lineNumber; i++ {
+		_, err := r.ReadNext()
+		if err != nil {
+			return "", err
+		}
+	}
+	return r.ReadNext()
+}
diff --git a/querylog/querylog_file.go b/querylog/querylog_file.go
index 02296a98..1eeeea7c 100644
--- a/querylog/querylog_file.go
+++ b/querylog/querylog_file.go
@@ -1,25 +1,14 @@
 package querylog
 
 import (
-	"bufio"
 	"bytes"
-	"compress/gzip"
-	"encoding/base64"
 	"encoding/json"
-	"io"
 	"os"
-	"strconv"
-	"strings"
 	"time"
 
-	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
 	"github.com/AdguardTeam/golibs/log"
-	"github.com/miekg/dns"
 )
 
-const enableGzip = false
-const maxEntrySize = 1000
-
 // flushLogBuffer flushes the current buffer to file and resets the current buffer
 func (l *queryLog) flushLogBuffer(fullFlush bool) error {
 	l.fileFlushLock.Lock()
@@ -68,29 +57,7 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
 	var err error
 	var zb bytes.Buffer
 	filename := l.logFile
-
-	// gzip enabled?
-	if enableGzip {
-		filename += ".gz"
-
-		zw := gzip.NewWriter(&zb)
-		zw.Name = l.logFile
-		zw.ModTime = time.Now()
-
-		_, err = zw.Write(b.Bytes())
-		if err != nil {
-			log.Error("Couldn't compress to gzip: %s", err)
-			zw.Close()
-			return err
-		}
-
-		if err = zw.Close(); err != nil {
-			log.Error("Couldn't close gzip writer: %s", err)
-			return err
-		}
-	} else {
-		zb = b
-	}
+	zb = b
 
 	l.fileWriteLock.Lock()
 	defer l.fileWriteLock.Unlock()
@@ -116,11 +83,6 @@ func (l *queryLog) rotate() error {
 	from := l.logFile
 	to := l.logFile + ".1"
 
-	if enableGzip {
-		from = l.logFile + ".gz"
-		to = l.logFile + ".gz.1"
-	}
-
 	if _, err := os.Stat(from); os.IsNotExist(err) {
 		// do nothing, file doesn't exist
 		return nil
@@ -133,7 +95,6 @@ func (l *queryLog) rotate() error {
 	}
 
 	log.Debug("Rotated from %s to %s successfully", from, to)
-
 	return nil
 }
 
@@ -146,591 +107,3 @@ func (l *queryLog) periodicRotate() {
 		}
 	}
 }
-
-// Reader is the DB reader context
-type Reader struct {
-	ql     *queryLog
-	search *getDataParams
-
-	f         *os.File
-	reader    *bufio.Reader // reads file line by line
-	now       time.Time
-	validFrom int64 // UNIX time (ns)
-	olderThan int64 // UNIX time (ns)
-	oldest    time.Time
-
-	files []string
-	ifile int
-
-	limit        uint64
-	count        uint64 // counter for returned elements
-	latest       bool   // return the latest entries
-	filePrepared bool
-
-	seeking       bool       // we're seaching for an entry with exact time stamp
-	fseeker       fileSeeker // file seeker object
-	fpos          uint64     // current file offset
-	nSeekRequests uint32     // number of Seek() requests made (finding a new line doesn't count)
-}
-
-type fileSeeker struct {
-	target uint64 // target value
-
-	pos     uint64 // current offset, may be adjusted by user for increased accuracy
-	lastpos uint64 // the last offset returned
-	lo      uint64 // low boundary offset
-	hi      uint64 // high boundary offset
-}
-
-// OpenReader - return reader object
-func (l *queryLog) OpenReader() *Reader {
-	r := Reader{}
-	r.ql = l
-	r.now = time.Now()
-	r.validFrom = r.now.Unix() - int64(l.conf.Interval*24*60*60)
-	r.validFrom *= 1000000000
-	r.files = []string{
-		r.ql.logFile,
-		r.ql.logFile + ".1",
-	}
-	return &r
-}
-
-// Close - close the reader
-func (r *Reader) Close() {
-	elapsed := time.Since(r.now)
-	var perunit time.Duration
-	if r.count > 0 {
-		perunit = elapsed / time.Duration(r.count)
-	}
-	log.Debug("querylog: read %d entries in %v, %v/entry, seek-reqs:%d",
-		r.count, elapsed, perunit, r.nSeekRequests)
-
-	if r.f != nil {
-		r.f.Close()
-	}
-}
-
-// BeginRead - start reading
-// olderThan: stop returning entries when an entry with this time is reached
-// count: minimum number of entries to return
-func (r *Reader) BeginRead(olderThan time.Time, count uint64, search *getDataParams) {
-	r.olderThan = olderThan.UnixNano()
-	r.latest = olderThan.IsZero()
-	r.oldest = time.Time{}
-	r.search = search
-	r.limit = count
-	if r.latest {
-		r.olderThan = r.now.UnixNano()
-	}
-	r.filePrepared = false
-	r.seeking = false
-}
-
-// BeginReadPrev - start reading the previous data chunk
-func (r *Reader) BeginReadPrev(count uint64) {
-	r.olderThan = r.oldest.UnixNano()
-	r.oldest = time.Time{}
-	r.latest = false
-	r.limit = count
-	r.count = 0
-
-	off := r.fpos - maxEntrySize*(r.limit+1)
-	if int64(off) < maxEntrySize {
-		off = 0
-	}
-	r.fpos = off
-	log.Debug("QueryLog: seek: %x", off)
-	_, err := r.f.Seek(int64(off), io.SeekStart)
-	if err != nil {
-		log.Error("file.Seek: %s: %s", r.files[r.ifile], err)
-		return
-	}
-	r.nSeekRequests++
-
-	r.seekToNewLine()
-	r.fseeker.pos = r.fpos
-
-	r.filePrepared = true
-	r.seeking = false
-}
-
-// Perform binary seek
-// Return 0: success; 1: seek reqiured; -1: error
-func (fs *fileSeeker) seekBinary(cur uint64) int32 {
-	log.Debug("QueryLog: seek: tgt=%x cur=%x, %x: [%x..%x]", fs.target, cur, fs.pos, fs.lo, fs.hi)
-
-	off := uint64(0)
-	if fs.pos >= fs.lo && fs.pos < fs.hi {
-		if cur == fs.target {
-			return 0
-		} else if cur < fs.target {
-			fs.lo = fs.pos + 1
-		} else {
-			fs.hi = fs.pos
-		}
-		off = fs.lo + (fs.hi-fs.lo)/2
-	} else {
-		// we didn't find another entry from the last file offset: now return the boundary beginning
-		off = fs.lo
-	}
-
-	if off == fs.lastpos {
-		return -1
-	}
-
-	fs.lastpos = off
-	fs.pos = off
-	return 1
-}
-
-// Seek to a new line
-func (r *Reader) seekToNewLine() bool {
-	r.reader = bufio.NewReader(r.f)
-	b, err := r.reader.ReadBytes('\n')
-	if err != nil {
-		r.reader = nil
-		log.Error("QueryLog: file.Read: %s: %s", r.files[r.ifile], err)
-		return false
-	}
-
-	off := len(b)
-	r.fpos += uint64(off)
-	log.Debug("QueryLog: seek: %x (+%d)", r.fpos, off)
-	return true
-}
-
-// Open a file
-func (r *Reader) openFile() bool {
-	var err error
-	fn := r.files[r.ifile]
-
-	r.f, err = os.Open(fn)
-	if err != nil {
-		if !os.IsNotExist(err) {
-			log.Error("QueryLog: Failed to open file \"%s\": %s", fn, err)
-		}
-		return false
-	}
-	return true
-}
-
-// Seek to the needed position
-func (r *Reader) prepareRead() bool {
-	fn := r.files[r.ifile]
-
-	fi, err := r.f.Stat()
-	if err != nil {
-		log.Error("QueryLog: file.Stat: %s: %s", fn, err)
-		return false
-	}
-	fsize := uint64(fi.Size())
-
-	off := uint64(0)
-	if r.latest {
-		// read data from the end of file
-		off = fsize - maxEntrySize*(r.limit+1)
-		if int64(off) < maxEntrySize {
-			off = 0
-		}
-		r.fpos = off
-		log.Debug("QueryLog: seek: %x", off)
-		_, err = r.f.Seek(int64(off), io.SeekStart)
-		if err != nil {
-			log.Error("QueryLog: file.Seek: %s: %s", fn, err)
-			return false
-		}
-	} else {
-		// start searching in file: we'll read the first chunk of data from the middle of file
-		r.seeking = true
-		r.fseeker = fileSeeker{}
-		r.fseeker.target = uint64(r.olderThan)
-		r.fseeker.hi = fsize
-		rc := r.fseeker.seekBinary(0)
-		r.fpos = r.fseeker.pos
-		if rc == 1 {
-			_, err = r.f.Seek(int64(r.fpos), io.SeekStart)
-			if err != nil {
-				log.Error("QueryLog: file.Seek: %s: %s", fn, err)
-				return false
-			}
-		}
-	}
-	r.nSeekRequests++
-
-	if !r.seekToNewLine() {
-		return false
-	}
-	r.fseeker.pos = r.fpos
-	return true
-}
-
-// Get bool value from "key":bool
-func readJSONBool(s, name string) (bool, bool) {
-	i := strings.Index(s, "\""+name+"\":")
-	if i == -1 {
-		return false, false
-	}
-	start := i + 1 + len(name) + 2
-	b := false
-	if strings.HasPrefix(s[start:], "true") {
-		b = true
-	} else if !strings.HasPrefix(s[start:], "false") {
-		return false, false
-	}
-	return b, true
-}
-
-// Get value from "key":"value"
"key":"value" -func readJSONValue(s, name string) string { - i := strings.Index(s, "\""+name+"\":\"") - if i == -1 { - return "" - } - start := i + 1 + len(name) + 3 - i = strings.IndexByte(s[start:], '"') - if i == -1 { - return "" - } - end := start + i - return s[start:end] -} - -// nolint (gocyclo) -func (r *Reader) applySearch(str string) bool { - if r.search.ResponseStatus == responseStatusFiltered { - boolVal, ok := readJSONBool(str, "IsFiltered") - if !ok || !boolVal { - return false - } - } - - mq := dns.Msg{} - - if len(r.search.Domain) != 0 { - val := readJSONValue(str, "QH") - if len(val) == 0 { - // pre-v0.99.3 compatibility - val = readJSONValue(str, "Question") - if len(val) == 0 { - return false - } - bval, err := base64.StdEncoding.DecodeString(val) - if err != nil { - return false - } - err = mq.Unpack(bval) - if err != nil { - return false - } - val = strings.TrimSuffix(mq.Question[0].Name, ".") - } - if len(val) == 0 { - return false - } - - if (r.search.StrictMatchDomain && val != r.search.Domain) || - (!r.search.StrictMatchDomain && strings.Index(val, r.search.Domain) == -1) { - return false - } - } - - if len(r.search.QuestionType) != 0 { - val := readJSONValue(str, "QT") - if len(val) == 0 { - // pre-v0.99.3 compatibility - if len(mq.Question) == 0 { - val = readJSONValue(str, "Question") - if len(val) == 0 { - return false - } - bval, err := base64.StdEncoding.DecodeString(val) - if err != nil { - return false - } - err = mq.Unpack(bval) - if err != nil { - return false - } - } - ok := false - val, ok = dns.TypeToString[mq.Question[0].Qtype] - if !ok { - return false - } - } - if val != r.search.QuestionType { - return false - } - } - - if len(r.search.Client) != 0 { - val := readJSONValue(str, "IP") - if len(val) == 0 { - log.Debug("QueryLog: failed to decode") - return false - } - - if (r.search.StrictMatchClient && val != r.search.Client) || - (!r.search.StrictMatchClient && strings.Index(val, r.search.Client) == -1) { - return false - } - } - - return true -} - -const ( - jsonTErr = iota - jsonTObj - jsonTStr - jsonTNum - jsonTBool -) - -// Parse JSON key-value pair -// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number) -// Note the limitations: -// . doesn't support whitespace -// . doesn't support "null" -// . doesn't validate boolean or number -// . no proper handling of {} braces -// . no handling of [] brackets -// Return (key, value, type) -func readJSON(ps *string) (string, string, int32) { - s := *ps - k := "" - v := "" - t := int32(jsonTErr) - - q1 := strings.IndexByte(s, '"') - if q1 == -1 { - return k, v, t - } - q2 := strings.IndexByte(s[q1+1:], '"') - if q2 == -1 { - return k, v, t - } - k = s[q1+1 : q1+1+q2] - s = s[q1+1+q2+1:] - - if len(s) < 2 || s[0] != ':' { - return k, v, t - } - - if s[1] == '"' { - q2 = strings.IndexByte(s[2:], '"') - if q2 == -1 { - return k, v, t - } - v = s[2 : 2+q2] - t = jsonTStr - s = s[2+q2+1:] - - } else if s[1] == '{' { - t = jsonTObj - s = s[1+1:] - - } else { - sep := strings.IndexAny(s[1:], ",}") - if sep == -1 { - return k, v, t - } - v = s[1 : 1+sep] - if s[1] == 't' || s[1] == 'f' { - t = jsonTBool - } else if s[1] == '.' 
-			t = jsonTNum
-		}
-		s = s[1+sep+1:]
-	}
-
-	*ps = s
-	return k, v, t
-}
-
-// nolint (gocyclo)
-func decode(ent *logEntry, str string) {
-	var b bool
-	var i int
-	var err error
-	for {
-		k, v, t := readJSON(&str)
-		if t == jsonTErr {
-			break
-		}
-		switch k {
-		case "IP":
-			if len(ent.IP) == 0 {
-				ent.IP = v
-			}
-		case "T":
-			ent.Time, err = time.Parse(time.RFC3339, v)
-
-		case "QH":
-			ent.QHost = v
-		case "QT":
-			ent.QType = v
-		case "QC":
-			ent.QClass = v
-
-		case "Answer":
-			ent.Answer, err = base64.StdEncoding.DecodeString(v)
-		case "OrigAnswer":
-			ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v)
-
-		case "IsFiltered":
-			b, err = strconv.ParseBool(v)
-			ent.Result.IsFiltered = b
-		case "Rule":
-			ent.Result.Rule = v
-		case "FilterID":
-			i, err = strconv.Atoi(v)
-			ent.Result.FilterID = int64(i)
-		case "Reason":
-			i, err = strconv.Atoi(v)
-			ent.Result.Reason = dnsfilter.Reason(i)
-
-		case "Upstream":
-			ent.Upstream = v
-		case "Elapsed":
-			i, err = strconv.Atoi(v)
-			ent.Elapsed = time.Duration(i)
-
-		// pre-v0.99.3 compatibility:
-		case "Question":
-			var qstr []byte
-			qstr, err = base64.StdEncoding.DecodeString(v)
-			if err != nil {
-				break
-			}
-			q := new(dns.Msg)
-			err = q.Unpack(qstr)
-			if err != nil {
-				break
-			}
-			ent.QHost = q.Question[0].Name
-			if len(ent.QHost) == 0 {
-				break
-			}
-			ent.QHost = ent.QHost[:len(ent.QHost)-1]
-			ent.QType = dns.TypeToString[q.Question[0].Qtype]
-			ent.QClass = dns.ClassToString[q.Question[0].Qclass]
-		case "Time":
-			ent.Time, err = time.Parse(time.RFC3339, v)
-		}
-
-		if err != nil {
-			log.Debug("decode err: %s", err)
-			break
-		}
-	}
-}
-
-// Next - return the next entry or nil if reading is finished
-func (r *Reader) Next() *logEntry { // nolint
-	for {
-		// open file if needed
-		if r.f == nil {
-			if r.ifile == len(r.files) {
-				return nil
-			}
-			if !r.openFile() {
-				r.ifile++
-				continue
-			}
-		}
-
-		if !r.filePrepared {
-			if !r.prepareRead() {
-				return nil
-			}
-			r.filePrepared = true
-		}
-
-		b, err := r.reader.ReadBytes('\n')
-		if err != nil {
-			return nil
-		}
-		str := string(b)
-
-		val := readJSONValue(str, "T")
-		if len(val) == 0 {
-			val = readJSONValue(str, "Time")
-		}
-		if len(val) == 0 {
-			log.Debug("QueryLog: failed to decode")
-			continue
-		}
-		tm, err := time.Parse(time.RFC3339, val)
-		if err != nil {
-			log.Debug("QueryLog: failed to decode")
-			continue
-		}
-		t := tm.UnixNano()
-
-		if r.seeking {
-
-			r.reader = nil
-			rr := r.fseeker.seekBinary(uint64(t))
-			r.fpos = r.fseeker.pos
-			if rr < 0 {
-				log.Error("QueryLog: File seek error: can't find the target entry: %s", r.files[r.ifile])
-				return nil
-			} else if rr == 0 {
-				// We found the target entry.
-				// We'll start reading the previous chunk of data.
-				r.seeking = false
-
-				off := r.fpos - (maxEntrySize * (r.limit + 1))
-				if int64(off) < maxEntrySize {
-					off = 0
-				}
-				r.fpos = off
-			}
-
-			_, err = r.f.Seek(int64(r.fpos), io.SeekStart)
-			if err != nil {
-				log.Error("QueryLog: file.Seek: %s: %s", r.files[r.ifile], err)
-				return nil
-			}
-			r.nSeekRequests++
-
-			if !r.seekToNewLine() {
-				return nil
-			}
-			r.fseeker.pos = r.fpos
-			continue
-		}
-
-		if r.oldest.IsZero() {
-			r.oldest = tm
-		}
-
-		if t < r.validFrom {
-			continue
-		}
-		if t >= r.olderThan {
-			return nil
-		}
-		r.count++
-
-		if !r.applySearch(str) {
-			continue
-		}
-
-		var ent logEntry
-		decode(&ent, str)
-		return &ent
-	}
-}
-
-// Total returns the total number of processed items
-func (r *Reader) Total() uint64 {
-	return r.count
-}
-
-// Oldest returns the time of the oldest processed entry
-func (r *Reader) Oldest() time.Time {
-	return r.oldest
-}
diff --git a/querylog/querylog_search.go b/querylog/querylog_search.go
new file mode 100644
index 00000000..236c1940
--- /dev/null
+++ b/querylog/querylog_search.go
@@ -0,0 +1,366 @@
+package querylog
+
+import (
+	"encoding/base64"
+	"io"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/AdguardTeam/AdGuardHome/dnsfilter"
+	"github.com/AdguardTeam/AdGuardHome/util"
+	"github.com/AdguardTeam/golibs/log"
+	"github.com/miekg/dns"
+)
+
+// searchFiles reads log entries from all log files and applies the specified search criteria.
+// IMPORTANT: this method does not scan more than "maxSearchEntries" entries,
+// so you may need to call it multiple times.
+//
+// it returns:
+// * an array of log entries that we have read
+// * the time of the oldest processed entry (even if it was discarded)
+// * the total number of processed entries (including discarded ones)
+func (l *queryLog) searchFiles(params getDataParams) ([]*logEntry, time.Time, int) {
+	entries := make([]*logEntry, 0)
+	oldest := time.Time{}
+
+	r, err := l.openReader()
+	if err != nil {
+		log.Error("Failed to open qlog reader: %v", err)
+		return entries, oldest, 0
+	}
+	defer r.Close()
+
+	if params.OlderThan.IsZero() {
+		err = r.SeekStart()
+	} else {
+		err = r.Seek(params.OlderThan.UnixNano())
+	}
+
+	if err != nil {
+		log.Error("Failed to Seek(): %v", err)
+		return entries, oldest, 0
+	}
+
+	total := 0
+	oldestNano := int64(0)
+	// Do not scan more than maxSearchEntries at once
+	for total <= maxSearchEntries {
+		entry, ts, err := l.readNextEntry(r, params)
+
+		if err == io.EOF {
+			// there's nothing to read anymore
+			break
+		}
+
+		if entry != nil {
+			entries = append(entries, entry)
+		}
+
+		oldestNano = ts
+		total++
+	}
+
+	oldest = time.Unix(0, oldestNano)
+	return entries, oldest, total
+}
+
+// readNextEntry reads the next log entry and checks if it matches the search criteria (getDataParams)
+//
+// returns:
+// * the log entry that matches the search criteria, or nil if it was discarded (or if there's nothing to read)
+// * the timestamp of the processed log entry
+// * an error if we cannot read anymore
+func (l *queryLog) readNextEntry(r *QLogReader, params getDataParams) (*logEntry, int64, error) {
+	line, err := r.ReadNext()
+	if err != nil {
+		return nil, 0, err
+	}
+
+	// Read the log record timestamp right away
+	timestamp := readQLogTimestamp(line)
+
+	// Quick check without deserializing the log entry
+	if !quickMatchesGetDataParams(line, params) {
+		return nil, timestamp, nil
+	}
+
+	entry := logEntry{}
+	decodeLogEntry(&entry, line)
+
+	// Full check of the deserialized log entry
+	if !matchesGetDataParams(&entry, params) {
+		return nil, timestamp, nil
+	}
+
+	return &entry, timestamp, nil
+}
+
+// openReader opens a QLogReader instance over the existing query log files
+func (l *queryLog) openReader() (*QLogReader, error) {
+	files := make([]string, 0)
+
+	if util.FileExists(l.logFile + ".1") {
+		files = append(files, l.logFile+".1")
+	}
+	if util.FileExists(l.logFile) {
+		files = append(files, l.logFile)
+	}
+
+	return NewQLogReader(files)
+}
+
+// quickMatchesGetDataParams quickly checks if the line matches the getDataParams.
+// This check is approximate; its purpose is to quickly filter out lines
+// without deserializing anything.
+func quickMatchesGetDataParams(line string, params getDataParams) bool {
+	if params.ResponseStatus == responseStatusFiltered {
+		boolVal, ok := readJSONBool(line, "IsFiltered")
+		if !ok || !boolVal {
+			return false
+		}
+	}
+
+	if len(params.Domain) != 0 {
+		val := readJSONValue(line, "QH")
+		if len(val) == 0 {
+			return false
+		}
+
+		if (params.StrictMatchDomain && val != params.Domain) ||
+			(!params.StrictMatchDomain && strings.Index(val, params.Domain) == -1) {
+			return false
+		}
+	}
+
+	if len(params.QuestionType) != 0 {
+		val := readJSONValue(line, "QT")
+		if val != params.QuestionType {
+			return false
+		}
+	}
+
+	if len(params.Client) != 0 {
+		val := readJSONValue(line, "IP")
+		if len(val) == 0 {
+			log.Debug("QueryLog: failed to decodeLogEntry")
+			return false
+		}
+
+		if (params.StrictMatchClient && val != params.Client) ||
+			(!params.StrictMatchClient && strings.Index(val, params.Client) == -1) {
+			return false
+		}
+	}
+
+	return true
+}
+
+// matchesGetDataParams returns true if the entry matches the search parameters
+func matchesGetDataParams(entry *logEntry, params getDataParams) bool {
+	if params.ResponseStatus == responseStatusFiltered && !entry.Result.IsFiltered {
+		return false
+	}
+
+	if len(params.QuestionType) != 0 {
+		if entry.QType != params.QuestionType {
+			return false
+		}
+	}
+
+	if len(params.Domain) != 0 {
+		if (params.StrictMatchDomain && entry.QHost != params.Domain) ||
+			(!params.StrictMatchDomain && strings.Index(entry.QHost, params.Domain) == -1) {
+			return false
+		}
+	}
+
+	if len(params.Client) != 0 {
+		if (params.StrictMatchClient && entry.IP != params.Client) ||
+			(!params.StrictMatchClient && strings.Index(entry.IP, params.Client) == -1) {
+			return false
+		}
+	}
+
+	return true
+}
+
+// decodeLogEntry decodes a query log entry from a line
+// nolint (gocyclo)
+func decodeLogEntry(ent *logEntry, str string) {
+	var b bool
+	var i int
+	var err error
+	for {
+		k, v, t := readJSON(&str)
+		if t == jsonTErr {
+			break
+		}
+		switch k {
+		case "IP":
+			if len(ent.IP) == 0 {
+				ent.IP = v
+			}
+		case "T":
+			ent.Time, err = time.Parse(time.RFC3339, v)
+
+		case "QH":
+			ent.QHost = v
+		case "QT":
+			ent.QType = v
+		case "QC":
+			ent.QClass = v
+
+		case "Answer":
+			ent.Answer, err = base64.StdEncoding.DecodeString(v)
+		case "OrigAnswer":
+			ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v)
+
+		case "IsFiltered":
+			b, err = strconv.ParseBool(v)
+			ent.Result.IsFiltered = b
+		case "Rule":
+			ent.Result.Rule = v
+		case "FilterID":
+			i, err = strconv.Atoi(v)
+			ent.Result.FilterID = int64(i)
+		case "Reason":
+			i, err = strconv.Atoi(v)
+			ent.Result.Reason = dnsfilter.Reason(i)
+
+		case "Upstream":
+			ent.Upstream = v
+		case "Elapsed":
+			i, err = strconv.Atoi(v)
+			ent.Elapsed = time.Duration(i)
+
+		// pre-v0.99.3 compatibility:
+		case "Question":
+			var qstr []byte
+			qstr, err = base64.StdEncoding.DecodeString(v)
+			if err != nil {
+				break
+			}
+			q := new(dns.Msg)
+			err = q.Unpack(qstr)
+			if err != nil {
+				break
+			}
+			ent.QHost = q.Question[0].Name
+			if len(ent.QHost) == 0 {
+				break
+			}
+			ent.QHost = ent.QHost[:len(ent.QHost)-1]
+			ent.QType = dns.TypeToString[q.Question[0].Qtype]
+			ent.QClass = dns.ClassToString[q.Question[0].Qclass]
+		case "Time":
+			ent.Time, err = time.Parse(time.RFC3339, v)
+		}
+
+		if err != nil {
+			log.Debug("decodeLogEntry err: %s", err)
+			break
+		}
+	}
+}
+
+// Get bool value from "key":bool
+func readJSONBool(s, name string) (bool, bool) {
+	i := strings.Index(s, "\""+name+"\":")
+	if i == -1 {
+		return false, false
+	}
+	start := i + 1 + len(name) + 2
+	b := false
+	if strings.HasPrefix(s[start:], "true") {
+		b = true
+	} else if !strings.HasPrefix(s[start:], "false") {
+		return false, false
+	}
+	return b, true
+}
+
+// Get value from "key":"value"
+func readJSONValue(s, name string) string {
+	i := strings.Index(s, "\""+name+"\":\"")
+	if i == -1 {
+		return ""
+	}
+	start := i + 1 + len(name) + 3
+	i = strings.IndexByte(s[start:], '"')
+	if i == -1 {
+		return ""
+	}
+	end := start + i
+	return s[start:end]
+}
+
+const (
+	jsonTErr = iota
+	jsonTObj
+	jsonTStr
+	jsonTNum
+	jsonTBool
+)
+
+// Parse JSON key-value pair
+// e.g.: "key":VALUE where VALUE is "string", true|false (boolean), or 123.456 (number)
+// Note the limitations:
+// . doesn't support whitespace
+// . doesn't support "null"
+// . doesn't validate boolean or number
+// . no proper handling of {} braces
+// . no handling of [] brackets
+// Return (key, value, type)
+func readJSON(ps *string) (string, string, int32) {
+	s := *ps
+	k := ""
+	v := ""
+	t := int32(jsonTErr)
+
+	q1 := strings.IndexByte(s, '"')
+	if q1 == -1 {
+		return k, v, t
+	}
+	q2 := strings.IndexByte(s[q1+1:], '"')
+	if q2 == -1 {
+		return k, v, t
+	}
+	k = s[q1+1 : q1+1+q2]
+	s = s[q1+1+q2+1:]
+
+	if len(s) < 2 || s[0] != ':' {
+		return k, v, t
+	}
+
+	if s[1] == '"' {
+		q2 = strings.IndexByte(s[2:], '"')
+		if q2 == -1 {
+			return k, v, t
+		}
+		v = s[2 : 2+q2]
+		t = jsonTStr
+		s = s[2+q2+1:]
+
+	} else if s[1] == '{' {
+		t = jsonTObj
+		s = s[1+1:]
+
+	} else {
+		sep := strings.IndexAny(s[1:], ",}")
+		if sep == -1 {
+			return k, v, t
+		}
+		v = s[1 : 1+sep]
+		if s[1] == 't' || s[1] == 'f' {
+			t = jsonTBool
+		} else if s[1] == '.' || (s[1] >= '0' && s[1] <= '9') {
+			t = jsonTNum
+		}
+		s = s[1+sep+1:]
+	}
+
+	*ps = s
+	return k, v, t
+}
diff --git a/querylog/querylog_test.go b/querylog/querylog_test.go
index 8c8b9bb4..06de4101 100644
--- a/querylog/querylog_test.go
+++ b/querylog/querylog_test.go
@@ -30,14 +30,18 @@ func TestQueryLog(t *testing.T) {
 	l := newQueryLog(conf)
 
 	// add disk entries
-	addEntry(l, "example.org", "1.2.3.4", "0.1.2.3")
-	addEntry(l, "example.org", "1.2.3.4", "0.1.2.3")
-
+	addEntry(l, "example.org", "1.1.1.1", "2.2.2.1")
+	// write to disk (first file)
+	_ = l.flushLogBuffer(true)
+	// start writing to the second file
+	_ = l.rotate()
+	// add disk entries
+	addEntry(l, "example.org", "1.1.1.2", "2.2.2.2")
 	// write to disk
-	l.flushLogBuffer(true)
-
+	_ = l.flushLogBuffer(true)
 	// add memory entries
-	addEntry(l, "test.example.org", "2.2.3.4", "0.1.2.4")
+	addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3")
+	addEntry(l, "example.com", "1.1.1.4", "2.2.2.4")
 
 	// get all entries
 	params := getDataParams{
@@ -45,9 +49,11 @@ func TestQueryLog(t *testing.T) {
 	}
 	d := l.getData(params)
 	mdata := d["data"].([]map[string]interface{})
-	assert.True(t, len(mdata) == 2)
-	assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4"))
-	assert.True(t, checkEntry(t, mdata[1], "example.org", "1.2.3.4", "0.1.2.3"))
+	assert.Equal(t, 4, len(mdata))
+	assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4"))
+	assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3"))
+	assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2"))
+	assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1"))
 
 	// search by domain (strict)
 	params = getDataParams{
@@ -58,9 +64,9 @@ func TestQueryLog(t *testing.T) {
 	d = l.getData(params)
 	mdata = d["data"].([]map[string]interface{})
 	assert.True(t, len(mdata) == 1)
-	assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4"))
+	assert.True(t, checkEntry(t, mdata[0], "test.example.org", "1.1.1.3", "2.2.2.3"))
 
-	// search by domain
+	// search by domain (not strict)
 	params = getDataParams{
 		OlderThan:         time.Time{},
 		Domain:            "example.org",
@@ -68,32 +74,35 @@ func TestQueryLog(t *testing.T) {
 	}
 	d = l.getData(params)
 	mdata = d["data"].([]map[string]interface{})
-	assert.True(t, len(mdata) == 2)
-	assert.True(t, checkEntry(t, mdata[0], "test.example.org", "2.2.3.4", "0.1.2.4"))
"test.example.org", "2.2.3.4", "0.1.2.4")) - assert.True(t, checkEntry(t, mdata[1], "example.org", "1.2.3.4", "0.1.2.3")) + assert.Equal(t, 4, len(mdata)) + assert.True(t, checkEntry(t, mdata[0], "example.com", "1.1.1.4", "2.2.2.4")) + assert.True(t, checkEntry(t, mdata[1], "test.example.org", "1.1.1.3", "2.2.2.3")) + assert.True(t, checkEntry(t, mdata[2], "example.org", "1.1.1.2", "2.2.2.2")) + assert.True(t, checkEntry(t, mdata[3], "example.org", "1.1.1.1", "2.2.2.1")) } func addEntry(l *queryLog, host, answerStr, client string) { @@ -129,11 +138,11 @@ func checkEntry(t *testing.T, m map[string]interface{}, host, answer, client str mq := m["question"].(map[string]interface{}) ma := m["answer"].([]map[string]interface{}) ma0 := ma[0] - if !assert.True(t, mq["host"].(string) == host) || - !assert.True(t, mq["class"].(string) == "IN") || - !assert.True(t, mq["type"].(string) == "A") || - !assert.True(t, ma0["value"].(string) == answer) || - !assert.True(t, m["client"].(string) == client) { + if !assert.Equal(t, host, mq["host"].(string)) || + !assert.Equal(t, "IN", mq["class"].(string)) || + !assert.Equal(t, "A", mq["type"].(string)) || + !assert.Equal(t, answer, ma0["value"].(string)) || + !assert.Equal(t, client, m["client"].(string)) { return false } return true