diff --git a/internal/home/dns.go b/internal/home/dns.go index 25d869e8..23130cce 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -471,7 +471,11 @@ func startDNSServer() error { Context.filters.Start() Context.stats.Start() - Context.queryLog.Start(ctx) + + err = Context.queryLog.Start(ctx) + if err != nil { + return fmt.Errorf("starting query log: %w", err) + } return nil } @@ -533,7 +537,10 @@ func closeDNSServer() { if Context.queryLog != nil { // TODO(s.chzhen): Pass context. - Context.queryLog.Shutdown(context.TODO()) + err := Context.queryLog.Shutdown(context.TODO()) + if err != nil { + log.Debug("closing query log: %s", err) + } } log.Debug("all dns modules are closed") diff --git a/internal/querylog/decode.go b/internal/querylog/decode.go index e9531ed3..b1993941 100644 --- a/internal/querylog/decode.go +++ b/internal/querylog/decode.go @@ -681,8 +681,8 @@ func (l *queryLog) resultDecHandler( name string, dec *json.Decoder, ent *logEntry, -) (found bool) { - found = true +) (ok bool) { + ok = true switch name { case "ReverseHosts": l.decodeResultReverseHosts(ctx, dec, ent) @@ -693,10 +693,10 @@ func (l *queryLog) resultDecHandler( case "DNSRewriteResult": l.decodeResultDNSRewriteResult(ctx, dec, ent) default: - found = false + ok = false } - return found + return ok } // decodeLogEntry decodes string str to logEntry ent. diff --git a/internal/querylog/qlog.go b/internal/querylog/qlog.go index ca54bd3f..0f89854f 100644 --- a/internal/querylog/qlog.go +++ b/internal/querylog/qlog.go @@ -86,25 +86,30 @@ func NewClientProto(s string) (cp ClientProto, err error) { var _ QueryLog = (*queryLog)(nil) // Start implements the [QueryLog] interface for *queryLog. -func (l *queryLog) Start(ctx context.Context) { +func (l *queryLog) Start(ctx context.Context) (err error) { if l.conf.HTTPRegister != nil { l.initWeb() } go l.periodicRotate(ctx) + + return nil } // Shutdown implements the [QueryLog] interface for *queryLog. -func (l *queryLog) Shutdown(ctx context.Context) { +func (l *queryLog) Shutdown(ctx context.Context) (err error) { l.confMu.RLock() defer l.confMu.RUnlock() if l.conf.FileEnabled { - err := l.flushLogBuffer(ctx) + err = l.flushLogBuffer(ctx) if err != nil { - l.logger.ErrorContext(ctx, "closing", slogutil.KeyError, err) + // Don't wrap the error because it's informative enough as is. + return err } } + + return nil } func checkInterval(ivl time.Duration) (ok bool) { diff --git a/internal/querylog/qlog_test.go b/internal/querylog/qlog_test.go index 0fe6e0bb..2a688552 100644 --- a/internal/querylog/qlog_test.go +++ b/internal/querylog/qlog_test.go @@ -34,12 +34,15 @@ func TestQueryLog(t *testing.T) { addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) // Write to disk (first file). require.NoError(t, l.flushLogBuffer(ctx)) + // Start writing to the second file. require.NoError(t, l.rotate(ctx)) + // Add disk entries. addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) // Write to disk. require.NoError(t, l.flushLogBuffer(ctx)) + // Add memory entries. addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) @@ -121,6 +124,7 @@ func TestQueryLog(t *testing.T) { entries, _ := l.search(ctx, params) require.Len(t, entries, len(tc.want)) + for _, want := range tc.want { assertLogEntry(t, entries[want.num], want.host, want.answer, want.client) } @@ -152,6 +156,7 @@ func TestQueryLogOffsetLimit(t *testing.T) { } // Write them to the first file. require.NoError(t, l.flushLogBuffer(ctx)) + // Add more to the in-memory part of log. for range entNum { addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) @@ -196,7 +201,6 @@ func TestQueryLogOffsetLimit(t *testing.T) { params.offset = tc.offset params.limit = tc.limit entries, _ := l.search(ctx, params) - require.Len(t, entries, tc.wantLen) if tc.wantLen > 0 { @@ -229,7 +233,6 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) { require.NoError(t, l.flushLogBuffer(ctx)) params := newSearchParams() - for _, maxFileScanEntries := range []int{5, 0} { t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) { params.maxFileScanEntries = maxFileScanEntries @@ -259,6 +262,7 @@ func TestQueryLogFileDisabled(t *testing.T) { ctx := testutil.ContextWithTimeout(t, testTimeout) ll, _ := l.search(ctx, params) require.Len(t, ll, 2) + assert.Equal(t, "example3.org", ll[0].QHost) assert.Equal(t, "example2.org", ll[1].QHost) } diff --git a/internal/querylog/qlogfile_test.go b/internal/querylog/qlogfile_test.go index c2ae19af..087d43aa 100644 --- a/internal/querylog/qlogfile_test.go +++ b/internal/querylog/qlogfile_test.go @@ -25,6 +25,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) { f, err := os.CreateTemp(dir, "*.txt") require.NoError(t, err) + // Use defer and not t.Cleanup to make sure that the file is closed // after this function is done. defer func() { @@ -109,6 +110,7 @@ func TestQLogFile_ReadNext(t *testing.T) { // Calculate the expected position. fileInfo, err := q.file.Stat() require.NoError(t, err) + var expPos int64 if expPos = fileInfo.Size(); expPos > 0 { expPos-- @@ -130,6 +132,7 @@ func TestQLogFile_ReadNext(t *testing.T) { } require.Equal(t, io.EOF, err) + assert.Equal(t, tc.linesNum, read) }) } @@ -175,16 +178,19 @@ func TestQLogFile_SeekTS_good(t *testing.T) { t.Run(l.name+"_"+tc.name, func(t *testing.T) { line, err := getQLogFileLine(q, tc.line) require.NoError(t, err) + ts := readQLogTimestamp(ctx, logger, line) assert.NotEqualValues(t, 0, ts) // Try seeking to that line now. pos, _, err := q.seekTS(ctx, logger, ts) require.NoError(t, err) + assert.NotEqualValues(t, 0, pos) testLine, err := q.ReadNext() require.NoError(t, err) + assert.Equal(t, line, testLine) }) } @@ -228,8 +234,8 @@ func TestQLogFile_SeekTS_bad(t *testing.T) { line, err := getQLogFileLine(q, l.num/2) require.NoError(t, err) - testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1 + testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1 for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { assert.NotEqualValues(t, 0, tc.ts) @@ -269,11 +275,13 @@ func TestQLogFile(t *testing.T) { // Seek to the start. pos, err := q.SeekStart() require.NoError(t, err) + assert.Greater(t, pos, int64(0)) // Read first line. line, err := q.ReadNext() require.NoError(t, err) + assert.Contains(t, line, "0.0.0.2") assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasSuffix(line, "}"), line) @@ -281,6 +289,7 @@ func TestQLogFile(t *testing.T) { // Read second line. line, err = q.ReadNext() require.NoError(t, err) + assert.EqualValues(t, 0, q.position) assert.Contains(t, line, "0.0.0.1") assert.True(t, strings.HasPrefix(line, "{"), line) @@ -289,12 +298,14 @@ func TestQLogFile(t *testing.T) { // Try reading again (there's nothing to read anymore). line, err = q.ReadNext() require.Equal(t, io.EOF, err) + assert.Empty(t, line) } func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) { f, err := os.CreateTemp(t.TempDir(), "*.txt") require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, f.Close) _, err = f.WriteString(data) @@ -302,6 +313,7 @@ func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) { file, err = newQLogFile(f.Name()) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, file.Close) return file @@ -353,6 +365,7 @@ func TestQLog_Seek(t *testing.T) { ts := timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano() _, depth, err := q.seekTS(ctx, logger, ts) require.Truef(t, errors.Is(err, tc.wantErr), "%v", err) + assert.Equal(t, tc.wantDepth, depth) }) } diff --git a/internal/querylog/querylog.go b/internal/querylog/querylog.go index 3968db70..c7350f70 100644 --- a/internal/querylog/querylog.go +++ b/internal/querylog/querylog.go @@ -1,7 +1,6 @@ package querylog import ( - "context" "fmt" "log/slog" "net" @@ -14,16 +13,14 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/service" "github.com/miekg/dns" ) // QueryLog is the query log interface for use by other packages. type QueryLog interface { - // Start starts the query log. - Start(ctx context.Context) - - // Shutdown stops the query log. - Shutdown(ctx context.Context) + // Interface starts and stops the query log. + service.Interface // Add adds a log entry. Add(params *AddParams) diff --git a/internal/querylog/search.go b/internal/querylog/search.go index 7e5dc3f6..4a066887 100644 --- a/internal/querylog/search.go +++ b/internal/querylog/search.go @@ -190,7 +190,7 @@ func (l *queryLog) setQLogReader( r, err := newQLogReader(ctx, l.logger, files) if err != nil { - return nil, fmt.Errorf("opening qlog reader: %s", err) + return nil, fmt.Errorf("opening qlog reader: %w", err) } err = r.seekRecord(ctx, olderThan) @@ -345,7 +345,7 @@ func (l *queryLog) readNextEntry( if err != nil { l.logger.ErrorContext( ctx, - "enriching file record at time", + "enriching file record", "at", e.Time, "client_ip", e.IP, "client_id", e.ClientID, diff --git a/internal/querylog/search_test.go b/internal/querylog/search_test.go index 304a4566..7bc97f70 100644 --- a/internal/querylog/search_test.go +++ b/internal/querylog/search_test.go @@ -50,8 +50,8 @@ func TestQueryLog_Search_findClient(t *testing.T) { require.NoError(t, err) ctx := testutil.ContextWithTimeout(t, testTimeout) - t.Cleanup(func() { - l.Shutdown(ctx) + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return l.Shutdown(ctx) }) q := &dns.Msg{