From 8d88d799303c7e3d15322fee87780fedb408ea13 Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Fri, 15 Nov 2024 15:57:07 +0300 Subject: [PATCH] all: slog querylog --- internal/home/dns.go | 9 +- internal/querylog/decode.go | 149 +++++++++++++++++++-------- internal/querylog/decode_test.go | 55 ++++++---- internal/querylog/entry.go | 9 +- internal/querylog/http.go | 30 +++--- internal/querylog/json.go | 43 ++++++-- internal/querylog/qlog.go | 51 +++++---- internal/querylog/qlog_test.go | 35 ++++--- internal/querylog/qlogfile.go | 19 ++-- internal/querylog/qlogfile_test.go | 21 +++- internal/querylog/qlogreader.go | 22 ++-- internal/querylog/qlogreader_test.go | 9 +- internal/querylog/querylog.go | 11 +- internal/querylog/querylogfile.go | 66 +++++++----- internal/querylog/search.go | 114 ++++++++++++-------- internal/querylog/search_test.go | 11 +- internal/querylog/searchcriterion.go | 11 +- internal/querylog/searchparams.go | 21 +++- 18 files changed, 462 insertions(+), 224 deletions(-) diff --git a/internal/home/dns.go b/internal/home/dns.go index 6fcc09ca..601de99d 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -73,6 +73,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) { } conf := querylog.Config{ + Logger: l.With(slogutil.KeyPrefix, "querylog"), Anonymizer: anonymizer, ConfigModified: onConfigModified, HTTPRegister: httpRegister, @@ -457,7 +458,8 @@ func startDNSServer() error { Context.filters.EnableFilters(false) // TODO(s.chzhen): Pass context. - err := Context.clients.Start(context.TODO()) + ctx := context.TODO() + err := Context.clients.Start(ctx) if err != nil { return fmt.Errorf("starting clients container: %w", err) } @@ -469,7 +471,7 @@ func startDNSServer() error { Context.filters.Start() Context.stats.Start() - Context.queryLog.Start() + Context.queryLog.Start(ctx) return nil } @@ -530,7 +532,8 @@ func closeDNSServer() { } if Context.queryLog != nil { - Context.queryLog.Close() + // TODO(s.chzhen): Pass context. + Context.queryLog.Close(context.TODO()) } log.Debug("all dns modules are closed") diff --git a/internal/querylog/decode.go b/internal/querylog/decode.go index d4dea04e..531bcf1c 100644 --- a/internal/querylog/decode.go +++ b/internal/querylog/decode.go @@ -1,6 +1,7 @@ package querylog import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -13,7 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" ) @@ -174,26 +175,32 @@ var logEntryHandlers = map[string]logEntryHandler{ } // decodeResultRuleKey decodes the token of "Rules" type to logEntry struct. -func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultRuleKey( + ctx context.Context, + key string, + i int, + dec *json.Decoder, + ent *logEntry, +) { var vToken json.Token switch key { case "FilterListID": - ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) + ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules) if n, ok := vToken.(json.Number); ok { id, _ := n.Int64() ent.Result.Rules[i].FilterListID = rulelist.URLFilterID(id) } case "IP": - ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) + ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules) if ipStr, ok := vToken.(string); ok { if ip, err := netip.ParseAddr(ipStr); err == nil { ent.Result.Rules[i].IP = ip } else { - log.Debug("querylog: decoding ipStr value: %s", err) + l.logger.DebugContext(ctx, "decoding ip", "value", ipStr, slogutil.KeyError, err) } } case "Text": - ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) + ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules) if s, ok := vToken.(string); ok { ent.Result.Rules[i].Text = s } @@ -204,7 +211,8 @@ func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) { // decodeVTokenAndAddRule decodes the "Rules" toke as [filtering.ResultRule] // and then adds the decoded object to the slice of result rules. -func decodeVTokenAndAddRule( +func (l *queryLog) decodeVTokenAndAddRule( + ctx context.Context, key string, i int, dec *json.Decoder, @@ -215,7 +223,7 @@ func decodeVTokenAndAddRule( vToken, err := dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeResultRuleKey %s err: %s", key, err) + l.logger.DebugContext(ctx, "decodeResultRuleKey", "key", key, slogutil.KeyError, err) } return newRules, nil @@ -230,12 +238,12 @@ func decodeVTokenAndAddRule( // decodeResultRules parses the dec's tokens into logEntry ent interpreting it // as a slice of the result rules. -func decodeResultRules(dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultRules(ctx context.Context, dec *json.Decoder, ent *logEntry) { for { delimToken, err := dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeResultRules err: %s", err) + l.logger.DebugContext(ctx, "decodeResultRules", slogutil.KeyError, err) } return @@ -244,13 +252,17 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) { if d, ok := delimToken.(json.Delim); !ok { return } else if d != '[' { - log.Debug("decodeResultRules: unexpected delim %q", d) + l.logger.DebugContext( + ctx, + "decodeResultRules", + slogutil.KeyError, unexpectedDelimiterError(d), + ) } - err = decodeResultRuleToken(dec, ent) + err = l.decodeResultRuleToken(ctx, dec, ent) if err != nil { if err != io.EOF && !errors.Is(err, ErrEndOfToken) { - log.Debug("decodeResultRules err: %s", err) + l.logger.DebugContext(ctx, "decodeResultRules", slogutil.KeyError, err) } return @@ -259,7 +271,11 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) { } // decodeResultRuleToken decodes the tokens of "Rules" type to the logEntry ent. -func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) { +func (l *queryLog) decodeResultRuleToken( + ctx context.Context, + dec *json.Decoder, + ent *logEntry, +) (err error) { i := 0 for { var keyToken json.Token @@ -287,7 +303,7 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) { return fmt.Errorf("keyToken is %T (%[1]v) and not string", keyToken) } - decodeResultRuleKey(key, i, dec, ent) + l.decodeResultRuleKey(ctx, key, i, dec, ent) } } @@ -296,12 +312,12 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) { // other occurrences of DNSRewriteResult in the entry since hosts container's // rewrites currently has the highest priority along the entire filtering // pipeline. -func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultReverseHosts(ctx context.Context, dec *json.Decoder, ent *logEntry) { for { itemToken, err := dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeResultReverseHosts err: %s", err) + l.logger.DebugContext(ctx, "decodeResultReverseHosts", slogutil.KeyError, err) } return @@ -315,7 +331,11 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) { return } - log.Debug("decodeResultReverseHosts: unexpected delim %q", v) + l.logger.DebugContext( + ctx, + "decodeResultReverseHosts", + slogutil.KeyError, unexpectedDelimiterError(v), + ) return case string: @@ -346,12 +366,12 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) { // decodeResultIPList parses the dec's tokens into logEntry ent interpreting it // as the result IP addresses list. -func decodeResultIPList(dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultIPList(ctx context.Context, dec *json.Decoder, ent *logEntry) { for { itemToken, err := dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeResultIPList err: %s", err) + l.logger.DebugContext(ctx, "decodeResultIPList", slogutil.KeyError, err) } return @@ -365,7 +385,11 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) { return } - log.Debug("decodeResultIPList: unexpected delim %q", v) + l.logger.DebugContext( + ctx, + "decodeResultIPList", + slogutil.KeyError, unexpectedDelimiterError(v), + ) return case string: @@ -382,7 +406,12 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) { // decodeResultDNSRewriteResultKey decodes the token of "DNSRewriteResult" type // to the logEntry struct. -func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultDNSRewriteResultKey( + ctx context.Context, + key string, + dec *json.Decoder, + ent *logEntry, +) { var err error switch key { @@ -391,7 +420,11 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr vToken, err = dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeResultDNSRewriteResultKey err: %s", err) + l.logger.DebugContext( + ctx, + "decodeResultDNSRewriteResultKey", + slogutil.KeyError, err, + ) } return @@ -419,7 +452,11 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr // decoding and correct the values. err = dec.Decode(&ent.Result.DNSRewriteResult.Response) if err != nil { - log.Debug("decodeResultDNSRewriteResultKey response err: %s", err) + l.logger.DebugContext( + ctx, + "decodeResultDNSRewriteResultKey response", + slogutil.KeyError, err, + ) } ent.parseDNSRewriteResultIPs() @@ -430,12 +467,16 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr // decodeResultDNSRewriteResult parses the dec's tokens into logEntry ent // interpreting it as the result DNSRewriteResult. -func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultDNSRewriteResult( + ctx context.Context, + dec *json.Decoder, + ent *logEntry, +) { for { key, err := parseKeyToken(dec) if err != nil { if err != io.EOF && !errors.Is(err, ErrEndOfToken) { - log.Debug("decodeResultDNSRewriteResult: %s", err) + l.logger.DebugContext(ctx, "decodeResultDNSRewriteResult", slogutil.KeyError, err) } return @@ -445,7 +486,7 @@ func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) { continue } - decodeResultDNSRewriteResultKey(key, dec, ent) + l.decodeResultDNSRewriteResultKey(ctx, key, dec, ent) } } @@ -508,14 +549,14 @@ func parseKeyToken(dec *json.Decoder) (key string, err error) { } // decodeResult decodes a token of "Result" type to logEntry struct. -func decodeResult(dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResult(ctx context.Context, dec *json.Decoder, ent *logEntry) { defer translateResult(ent) for { key, err := parseKeyToken(dec) if err != nil { if err != io.EOF && !errors.Is(err, ErrEndOfToken) { - log.Debug("decodeResult: %s", err) + l.logger.DebugContext(ctx, "decodeResult", slogutil.KeyError, err) } return @@ -525,10 +566,8 @@ func decodeResult(dec *json.Decoder, ent *logEntry) { continue } - decHandler, ok := resultDecHandlers[key] + ok := l.resultDecHandler(ctx, key, dec, ent) if ok { - decHandler(dec, ent) - continue } @@ -543,7 +582,7 @@ func decodeResult(dec *json.Decoder, ent *logEntry) { } if err = handler(val, ent); err != nil { - log.Debug("decodeResult handler err: %s", err) + l.logger.DebugContext(ctx, "decodeResult handler", slogutil.KeyError, err) return } @@ -636,16 +675,32 @@ var resultHandlers = map[string]logEntryHandler{ }, } -// resultDecHandlers is the map of decode handlers for various keys. -var resultDecHandlers = map[string]func(dec *json.Decoder, ent *logEntry){ - "ReverseHosts": decodeResultReverseHosts, - "IPList": decodeResultIPList, - "Rules": decodeResultRules, - "DNSRewriteResult": decodeResultDNSRewriteResult, +// resultDecHandlers calls a decode handler for key if there is one. +func (l *queryLog) resultDecHandler( + ctx context.Context, + name string, + dec *json.Decoder, + ent *logEntry, +) (found bool) { + notFound := false + switch name { + case "ReverseHosts": + l.decodeResultReverseHosts(ctx, dec, ent) + case "IPList": + l.decodeResultIPList(ctx, dec, ent) + case "Rules": + l.decodeResultRules(ctx, dec, ent) + case "DNSRewriteResult": + l.decodeResultDNSRewriteResult(ctx, dec, ent) + default: + notFound = true + } + + return !notFound } // decodeLogEntry decodes string str to logEntry ent. -func decodeLogEntry(ent *logEntry, str string) { +func (l *queryLog) decodeLogEntry(ctx context.Context, ent *logEntry, str string) { dec := json.NewDecoder(strings.NewReader(str)) dec.UseNumber() @@ -653,7 +708,7 @@ func decodeLogEntry(ent *logEntry, str string) { keyToken, err := dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeLogEntry err: %s", err) + l.logger.DebugContext(ctx, "decodeLogEntry err", slogutil.KeyError, err) } return @@ -665,13 +720,14 @@ func decodeLogEntry(ent *logEntry, str string) { key, ok := keyToken.(string) if !ok { - log.Debug("decodeLogEntry: keyToken is %T (%[1]v) and not string", keyToken) + err = fmt.Errorf("decodeLogEntry: keyToken is %T (%[1]v) and not string", keyToken) + l.logger.DebugContext(ctx, "decodeLogEntry", slogutil.KeyError, err) return } if key == "Result" { - decodeResult(dec, ent) + l.decodeResult(ctx, dec, ent) continue } @@ -687,9 +743,14 @@ func decodeLogEntry(ent *logEntry, str string) { } if err = handler(val, ent); err != nil { - log.Debug("decodeLogEntry handler err: %s", err) + l.logger.DebugContext(ctx, "decodeLogEntry handler", slogutil.KeyError, err) return } } } + +// unexpectedDelimiterError is a helper for creating informative errors. +func unexpectedDelimiterError(d json.Delim) (err error) { + return fmt.Errorf("unexpected delimiter: %q", d) +} diff --git a/internal/querylog/decode_test.go b/internal/querylog/decode_test.go index 1f907e3d..f4a2e5cf 100644 --- a/internal/querylog/decode_test.go +++ b/internal/querylog/decode_test.go @@ -3,27 +3,35 @@ package querylog import ( "bytes" "encoding/base64" + "log/slog" "net" "net/netip" - "strings" "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// Common constants for tests. +const testTimeout = 1 * time.Second + func TestDecodeLogEntry(t *testing.T) { logOutput := &bytes.Buffer{} + l := &queryLog{ + logger: slog.New(slog.NewTextHandler(logOutput, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: slogutil.RemoveTime, + })), + } - aghtest.ReplaceLogWriter(t, logOutput) - aghtest.ReplaceLogLevel(t, log.DEBUG) + ctx := testutil.ContextWithTimeout(t, testTimeout) t.Run("success", func(t *testing.T) { const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==` @@ -92,7 +100,7 @@ func TestDecodeLogEntry(t *testing.T) { } got := &logEntry{} - decodeLogEntry(got, data) + l.decodeLogEntry(ctx, got, data) s := logOutput.String() assert.Empty(t, s) @@ -113,11 +121,11 @@ func TestDecodeLogEntry(t *testing.T) { }, { name: "bad_filter_id_old_rule", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"FilterID":1.5},"Elapsed":837429}`, - want: "decodeResult handler err: strconv.ParseInt: parsing \"1.5\": invalid syntax\n", + want: `level=DEBUG msg="decodeResult handler" err="strconv.ParseInt: parsing \"1.5\": invalid syntax"`, }, { name: "bad_is_filtered", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":trooe,"Reason":3},"Elapsed":837429}`, - want: "decodeLogEntry err: invalid character 'o' in literal true (expecting 'u')\n", + want: `level=DEBUG msg="decodeLogEntry err" err="invalid character 'o' in literal true (expecting 'u')"`, }, { name: "bad_elapsed", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":-1}`, @@ -129,7 +137,7 @@ func TestDecodeLogEntry(t *testing.T) { }, { name: "bad_time", log: `{"IP":"127.0.0.1","T":"12/09/1998T15:00:00.000000+05:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, - want: "decodeLogEntry handler err: parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\"\n", + want: `level=DEBUG msg="decodeLogEntry handler" err="parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\""`, }, { name: "bad_host", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":6,"QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, @@ -149,7 +157,7 @@ func TestDecodeLogEntry(t *testing.T) { }, { name: "very_bad_client_proto", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"dog","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, - want: "decodeLogEntry handler err: invalid client proto: \"dog\"\n", + want: `level=DEBUG msg="decodeLogEntry handler" err="invalid client proto: \"dog\""`, }, { name: "bad_answer", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":0.9,"Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, @@ -157,7 +165,7 @@ func TestDecodeLogEntry(t *testing.T) { }, { name: "very_bad_answer", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, - want: "decodeLogEntry handler err: illegal base64 data at input byte 61\n", + want: `level=DEBUG msg="decodeLogEntry handler" err="illegal base64 data at input byte 61"`, }, { name: "bad_rule", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"Rule":false},"Elapsed":837429}`, @@ -169,22 +177,25 @@ func TestDecodeLogEntry(t *testing.T) { }, { name: "bad_reverse_hosts", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":[{}]},"Elapsed":837429}`, - want: "decodeResultReverseHosts: unexpected delim \"{\"\n", + want: `level=DEBUG msg=decodeResultReverseHosts err="unexpected delimiter: \"{\""`, }, { name: "bad_ip_list", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":["example.net"],"IPList":[{}]},"Elapsed":837429}`, - want: "decodeResultIPList: unexpected delim \"{\"\n", + want: `level=DEBUG msg=decodeResultIPList err="unexpected delimiter: \"{\""`, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - decodeLogEntry(new(logEntry), tc.log) - - s := logOutput.String() + l.decodeLogEntry(ctx, new(logEntry), tc.log) + got := logOutput.String() if tc.want == "" { - assert.Empty(t, s) + assert.Empty(t, got) } else { - assert.True(t, strings.HasSuffix(s, tc.want), "got %q", s) + require.NotEmpty(t, got) + + // Remove newline. + got = got[:len(got)-1] + assert.Equal(t, tc.want, got) } logOutput.Reset() @@ -200,6 +211,12 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) { aaaa2 = aaaa1.Next() ) + l := &queryLog{ + logger: slogutil.NewDiscardLogger(), + } + + ctx := testutil.ContextWithTimeout(t, testTimeout) + testCases := []struct { want *logEntry entry string @@ -249,7 +266,7 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := &logEntry{} - decodeLogEntry(e, tc.entry) + l.decodeLogEntry(ctx, e, tc.entry) assert.Equal(t, tc.want, e) }) diff --git a/internal/querylog/entry.go b/internal/querylog/entry.go index ed3319b0..67272bba 100644 --- a/internal/querylog/entry.go +++ b/internal/querylog/entry.go @@ -1,12 +1,14 @@ package querylog import ( + "context" + "log/slog" "net" "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/miekg/dns" ) @@ -52,7 +54,7 @@ func (e *logEntry) shallowClone() (clone *logEntry) { // addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is // true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any // errors are logged. -func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) { +func (e *logEntry) addResponse(ctx context.Context, l *slog.Logger, resp *dns.Msg, isOrig bool) { if resp == nil { return } @@ -65,8 +67,9 @@ func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) { e.Answer, err = resp.Pack() err = errors.Annotate(err, "packing answer: %w") } + if err != nil { - log.Error("querylog: %s", err) + l.ErrorContext(ctx, "adding data from response", slogutil.KeyError, err) } } diff --git a/internal/querylog/http.go b/internal/querylog/http.go index 1fb7cce4..fe5050d8 100644 --- a/internal/querylog/http.go +++ b/internal/querylog/http.go @@ -1,6 +1,7 @@ package querylog import ( + "context" "encoding/json" "fmt" "math" @@ -15,7 +16,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/timeutil" "golang.org/x/net/idna" ) @@ -74,7 +75,7 @@ func (l *queryLog) initWeb() { // handleQueryLog is the handler for the GET /control/querylog HTTP API. func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { - params, err := parseSearchParams(r) + params, err := l.parseSearchParams(r) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err) @@ -87,18 +88,18 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { l.confMu.RLock() defer l.confMu.RUnlock() - entries, oldest = l.search(params) + entries, oldest = l.search(r.Context(), params) }() - resp := entriesToJSON(entries, oldest, l.anonymizer.Load()) + resp := l.entriesToJSON(r.Context(), entries, oldest, l.anonymizer.Load()) aghhttp.WriteJSONResponseOK(w, r, resp) } // handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP // API. -func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) { - l.clear() +func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, r *http.Request) { + l.clear(r.Context()) } // handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP @@ -280,11 +281,12 @@ func getDoubleQuotesEnclosedValue(s *string) bool { } // parseSearchCriterion parses a search criterion from the query parameter. -func parseSearchCriterion(q url.Values, name string, ct criterionType) ( - ok bool, - sc searchCriterion, - err error, -) { +func (l *queryLog) parseSearchCriterion( + ctx context.Context, + q url.Values, + name string, + ct criterionType, +) (ok bool, sc searchCriterion, err error) { val := q.Get(name) if val == "" { return false, sc, nil @@ -301,7 +303,7 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) ( // TODO(e.burkov): Make it work with parts of IDNAs somehow. loweredVal := strings.ToLower(val) if asciiVal, err = idna.ToASCII(loweredVal); err != nil { - log.Debug("can't convert %q to ascii: %s", val, err) + l.logger.DebugContext(ctx, "converting to ascii", "value", val, slogutil.KeyError, err) } else if asciiVal == loweredVal { // Purge asciiVal to prevent checking the same value // twice. @@ -331,7 +333,7 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) ( // parseSearchParams parses search parameters from the HTTP request's query // string. -func parseSearchParams(r *http.Request) (p *searchParams, err error) { +func (l *queryLog) parseSearchParams(r *http.Request) (p *searchParams, err error) { p = newSearchParams() q := r.URL.Query() @@ -369,7 +371,7 @@ func parseSearchParams(r *http.Request) (p *searchParams, err error) { }} { var ok bool var c searchCriterion - ok, c, err = parseSearchCriterion(q, v.urlField, v.ct) + ok, c, err = l.parseSearchCriterion(r.Context(), q, v.urlField, v.ct) if err != nil { return nil, err } diff --git a/internal/querylog/json.go b/internal/querylog/json.go index 07a9d62b..ab372987 100644 --- a/internal/querylog/json.go +++ b/internal/querylog/json.go @@ -1,6 +1,7 @@ package querylog import ( + "context" "slices" "strconv" "strings" @@ -8,7 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/miekg/dns" "golang.org/x/net/idna" ) @@ -19,7 +20,8 @@ import ( type jobject = map[string]any // entriesToJSON converts query log entries to JSON. -func entriesToJSON( +func (l *queryLog) entriesToJSON( + ctx context.Context, entries []*logEntry, oldest time.Time, anonFunc aghnet.IPMutFunc, @@ -28,7 +30,7 @@ func entriesToJSON( // The elements order is already reversed to be from newer to older. for _, entry := range entries { - jsonEntry := entryToJSON(entry, anonFunc) + jsonEntry := l.entryToJSON(ctx, entry, anonFunc) data = append(data, jsonEntry) } @@ -44,7 +46,11 @@ func entriesToJSON( } // entryToJSON converts a log entry's data into an entry for the JSON API. -func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) { +func (l *queryLog) entryToJSON( + ctx context.Context, + entry *logEntry, + anonFunc aghnet.IPMutFunc, +) (jsonEntry jobject) { hostname := entry.QHost question := jobject{ "type": entry.QType, @@ -53,7 +59,12 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) } if qhost, err := idna.ToUnicode(hostname); err != nil { - log.Debug("querylog: translating %q into unicode: %s", hostname, err) + l.logger.DebugContext( + ctx, + "translating into unicode", + "hostname", hostname, + slogutil.KeyError, err, + ) } else if qhost != hostname && qhost != "" { question["unicode_name"] = qhost } @@ -96,21 +107,26 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) jsonEntry["service_name"] = entry.Result.ServiceName } - setMsgData(entry, jsonEntry) - setOrigAns(entry, jsonEntry) + l.setMsgData(ctx, entry, jsonEntry) + l.setOrigAns(ctx, entry, jsonEntry) return jsonEntry } // setMsgData sets the message data in jsonEntry. -func setMsgData(entry *logEntry, jsonEntry jobject) { +func (l *queryLog) setMsgData(ctx context.Context, entry *logEntry, jsonEntry jobject) { if len(entry.Answer) == 0 { return } msg := &dns.Msg{} if err := msg.Unpack(entry.Answer); err != nil { - log.Debug("querylog: failed to unpack dns msg answer: %v: %s", entry.Answer, err) + l.logger.DebugContext( + ctx, + "failed to unpack dns msg", + "answer", entry.Answer, + slogutil.KeyError, err, + ) return } @@ -126,7 +142,7 @@ func setMsgData(entry *logEntry, jsonEntry jobject) { } // setOrigAns sets the original answer data in jsonEntry. -func setOrigAns(entry *logEntry, jsonEntry jobject) { +func (l *queryLog) setOrigAns(ctx context.Context, entry *logEntry, jsonEntry jobject) { if len(entry.OrigAnswer) == 0 { return } @@ -134,7 +150,12 @@ func setOrigAns(entry *logEntry, jsonEntry jobject) { orig := &dns.Msg{} err := orig.Unpack(entry.OrigAnswer) if err != nil { - log.Debug("querylog: orig.Unpack(entry.OrigAnswer): %v: %s", entry.OrigAnswer, err) + l.logger.DebugContext( + ctx, + "setting original answer", + "answer", entry.OrigAnswer, + slogutil.KeyError, err, + ) return } diff --git a/internal/querylog/qlog.go b/internal/querylog/qlog.go index c0b2edd1..2cb4deca 100644 --- a/internal/querylog/qlog.go +++ b/internal/querylog/qlog.go @@ -2,7 +2,9 @@ package querylog import ( + "context" "fmt" + "log/slog" "os" "sync" "time" @@ -11,7 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/miekg/dns" ) @@ -22,6 +24,10 @@ const queryLogFileName = "querylog.json" // queryLog is a structure that writes and reads the DNS query log. type queryLog struct { + // logger is used for logging the operation of the query log. It must not + // be nil. + logger *slog.Logger + // confMu protects conf. confMu *sync.RWMutex @@ -76,22 +82,22 @@ func NewClientProto(s string) (cp ClientProto, err error) { } } -func (l *queryLog) Start() { +func (l *queryLog) Start(ctx context.Context) { if l.conf.HTTPRegister != nil { l.initWeb() } - go l.periodicRotate() + go l.periodicRotate(ctx) } -func (l *queryLog) Close() { +func (l *queryLog) Close(ctx context.Context) { l.confMu.RLock() defer l.confMu.RUnlock() if l.conf.FileEnabled { - err := l.flushLogBuffer() + err := l.flushLogBuffer(ctx) if err != nil { - log.Error("querylog: closing: %s", err) + l.logger.ErrorContext(ctx, "closing", slogutil.KeyError, err) } } } @@ -131,7 +137,7 @@ func (l *queryLog) WriteDiskConfig(c *Config) { } // Clear memory buffer and remove log files -func (l *queryLog) clear() { +func (l *queryLog) clear(ctx context.Context) { l.fileFlushLock.Lock() defer l.fileFlushLock.Unlock() @@ -146,19 +152,24 @@ func (l *queryLog) clear() { oldLogFile := l.logFile + ".1" err := os.Remove(oldLogFile) if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Error("removing old log file %q: %s", oldLogFile, err) + l.logger.ErrorContext( + ctx, + "removing old log file", + "file", oldLogFile, + slogutil.KeyError, err, + ) } err = os.Remove(l.logFile) if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Error("removing log file %q: %s", l.logFile, err) + l.logger.ErrorContext(ctx, "removing log file", "file", l.logFile, slogutil.KeyError, err) } - log.Debug("querylog: cleared") + l.logger.DebugContext(ctx, "cleared") } // newLogEntry creates an instance of logEntry from parameters. -func newLogEntry(params *AddParams) (entry *logEntry) { +func newLogEntry(ctx context.Context, logger *slog.Logger, params *AddParams) (entry *logEntry) { q := params.Question.Question[0] qHost := aghnet.NormalizeDomain(q.Name) @@ -187,8 +198,8 @@ func newLogEntry(params *AddParams) (entry *logEntry) { entry.ReqECS = params.ReqECS.String() } - entry.addResponse(params.Answer, false) - entry.addResponse(params.OrigAnswer, true) + entry.addResponse(ctx, logger, params.Answer, false) + entry.addResponse(ctx, logger, params.OrigAnswer, true) return entry } @@ -209,9 +220,12 @@ func (l *queryLog) Add(params *AddParams) { return } + // TODO(s.chzhen): Pass context. + ctx := context.TODO() + err := params.validate() if err != nil { - log.Error("querylog: adding record: %s, skipping", err) + l.logger.ErrorContext(ctx, "adding record", slogutil.KeyError, err) return } @@ -220,7 +234,7 @@ func (l *queryLog) Add(params *AddParams) { params.Result = &filtering.Result{} } - entry := newLogEntry(params) + entry := newLogEntry(ctx, l.logger, params) l.bufferLock.Lock() defer l.bufferLock.Unlock() @@ -232,9 +246,9 @@ func (l *queryLog) Add(params *AddParams) { // TODO(s.chzhen): Fix occasional rewrite of entires. go func() { - flushErr := l.flushLogBuffer() + flushErr := l.flushLogBuffer(ctx) if flushErr != nil { - log.Error("querylog: flushing after adding: %s", flushErr) + l.logger.ErrorContext(ctx, "flushing after adding", slogutil.KeyError, flushErr) } }() } @@ -247,7 +261,8 @@ func (l *queryLog) ShouldLog(host string, _, _ uint16, ids []string) bool { c, err := l.findClient(ids) if err != nil { - log.Error("querylog: finding client: %s", err) + // TODO(s.chzhen): Pass context. + l.logger.ErrorContext(context.TODO(), "finding client", slogutil.KeyError, err) } if c != nil && c.IgnoreQueryLog { diff --git a/internal/querylog/qlog_test.go b/internal/querylog/qlog_test.go index 57d8b68d..0fe6e0bb 100644 --- a/internal/querylog/qlog_test.go +++ b/internal/querylog/qlog_test.go @@ -7,6 +7,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/miekg/dns" @@ -14,14 +15,11 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) -} - // TestQueryLog tests adding and loading (with filtering) entries from disk and // memory. func TestQueryLog(t *testing.T) { l, err := newQueryLog(Config{ + Logger: slogutil.NewDiscardLogger(), Enabled: true, FileEnabled: true, RotationIvl: timeutil.Day, @@ -30,16 +28,18 @@ func TestQueryLog(t *testing.T) { }) require.NoError(t, err) + ctx := testutil.ContextWithTimeout(t, testTimeout) + // Add disk entries. addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) // Write to disk (first file). - require.NoError(t, l.flushLogBuffer()) + require.NoError(t, l.flushLogBuffer(ctx)) // Start writing to the second file. - require.NoError(t, l.rotate()) + require.NoError(t, l.rotate(ctx)) // Add disk entries. addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) // Write to disk. - require.NoError(t, l.flushLogBuffer()) + require.NoError(t, l.flushLogBuffer(ctx)) // Add memory entries. addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) @@ -119,7 +119,7 @@ func TestQueryLog(t *testing.T) { params := newSearchParams() params.searchCriteria = tc.sCr - entries, _ := l.search(params) + entries, _ := l.search(ctx, params) require.Len(t, entries, len(tc.want)) for _, want := range tc.want { assertLogEntry(t, entries[want.num], want.host, want.answer, want.client) @@ -130,6 +130,7 @@ func TestQueryLog(t *testing.T) { func TestQueryLogOffsetLimit(t *testing.T) { l, err := newQueryLog(Config{ + Logger: slogutil.NewDiscardLogger(), Enabled: true, RotationIvl: timeutil.Day, MemSize: 100, @@ -142,12 +143,15 @@ func TestQueryLogOffsetLimit(t *testing.T) { firstPageDomain = "first.example.org" secondPageDomain = "second.example.org" ) + + ctx := testutil.ContextWithTimeout(t, testTimeout) + // Add entries to the log. for range entNum { addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } // Write them to the first file. - require.NoError(t, l.flushLogBuffer()) + require.NoError(t, l.flushLogBuffer(ctx)) // Add more to the in-memory part of log. for range entNum { addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) @@ -191,7 +195,7 @@ func TestQueryLogOffsetLimit(t *testing.T) { t.Run(tc.name, func(t *testing.T) { params.offset = tc.offset params.limit = tc.limit - entries, _ := l.search(params) + entries, _ := l.search(ctx, params) require.Len(t, entries, tc.wantLen) @@ -205,6 +209,7 @@ func TestQueryLogOffsetLimit(t *testing.T) { func TestQueryLogMaxFileScanEntries(t *testing.T) { l, err := newQueryLog(Config{ + Logger: slogutil.NewDiscardLogger(), Enabled: true, FileEnabled: true, RotationIvl: timeutil.Day, @@ -213,20 +218,22 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) { }) require.NoError(t, err) + ctx := testutil.ContextWithTimeout(t, testTimeout) + const entNum = 10 // Add entries to the log. for range entNum { addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } // Write them to disk. - require.NoError(t, l.flushLogBuffer()) + require.NoError(t, l.flushLogBuffer(ctx)) params := newSearchParams() for _, maxFileScanEntries := range []int{5, 0} { t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) { params.maxFileScanEntries = maxFileScanEntries - entries, _ := l.search(params) + entries, _ := l.search(ctx, params) assert.Len(t, entries, entNum-maxFileScanEntries) }) } @@ -234,6 +241,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) { func TestQueryLogFileDisabled(t *testing.T) { l, err := newQueryLog(Config{ + Logger: slogutil.NewDiscardLogger(), Enabled: true, FileEnabled: false, RotationIvl: timeutil.Day, @@ -248,7 +256,8 @@ func TestQueryLogFileDisabled(t *testing.T) { addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) params := newSearchParams() - ll, _ := l.search(params) + ctx := testutil.ContextWithTimeout(t, testTimeout) + ll, _ := l.search(ctx, params) require.Len(t, ll, 2) assert.Equal(t, "example3.org", ll[0].QHost) assert.Equal(t, "example2.org", ll[1].QHost) diff --git a/internal/querylog/qlogfile.go b/internal/querylog/qlogfile.go index c145f520..7b18f3d2 100644 --- a/internal/querylog/qlogfile.go +++ b/internal/querylog/qlogfile.go @@ -1,8 +1,10 @@ package querylog import ( + "context" "fmt" "io" + "log/slog" "os" "strings" "sync" @@ -10,7 +12,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" ) const ( @@ -102,7 +103,11 @@ func (q *qLogFile) validateQLogLineIdx(lineIdx, lastProbeLineIdx, ts, fSize int6 // for so that when we call "ReadNext" this line was returned. // - Depth of the search (how many times we compared timestamps). // - If we could not find it, it returns one of the errors described above. -func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) { +func (q *qLogFile) seekTS( + ctx context.Context, + logger *slog.Logger, + timestamp int64, +) (pos int64, depth int, err error) { q.lock.Lock() defer q.lock.Unlock() @@ -151,7 +156,7 @@ func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) { lastProbeLineIdx = lineIdx // Get the timestamp from the query log record. - ts := readQLogTimestamp(line) + ts := readQLogTimestamp(ctx, logger, line) if ts == 0 { return 0, depth, fmt.Errorf( "looking up timestamp %d in %q: record %q has empty timestamp", @@ -385,20 +390,22 @@ func readJSONValue(s, prefix string) string { } // readQLogTimestamp reads the timestamp field from the query log line. -func readQLogTimestamp(str string) int64 { +func readQLogTimestamp(ctx context.Context, logger *slog.Logger, str string) int64 { val := readJSONValue(str, `"T":"`) if len(val) == 0 { val = readJSONValue(str, `"Time":"`) } if len(val) == 0 { - log.Error("Couldn't find timestamp: %s", str) + logger.ErrorContext(ctx, "couldn't find timestamp", "line", str) + return 0 } tm, err := time.Parse(time.RFC3339Nano, val) if err != nil { - log.Error("Couldn't parse timestamp: %s", val) + logger.ErrorContext(ctx, "couldn't parse timestamp", "value", val) + return 0 } diff --git a/internal/querylog/qlogfile_test.go b/internal/querylog/qlogfile_test.go index 8462e950..c2ae19af 100644 --- a/internal/querylog/qlogfile_test.go +++ b/internal/querylog/qlogfile_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -146,6 +147,9 @@ func TestQLogFile_SeekTS_good(t *testing.T) { num: 10, }} + logger := slogutil.NewDiscardLogger() + ctx := testutil.ContextWithTimeout(t, testTimeout) + for _, l := range linesCases { testCases := []struct { name string @@ -171,11 +175,11 @@ func TestQLogFile_SeekTS_good(t *testing.T) { t.Run(l.name+"_"+tc.name, func(t *testing.T) { line, err := getQLogFileLine(q, tc.line) require.NoError(t, err) - ts := readQLogTimestamp(line) + ts := readQLogTimestamp(ctx, logger, line) assert.NotEqualValues(t, 0, ts) // Try seeking to that line now. - pos, _, err := q.seekTS(ts) + pos, _, err := q.seekTS(ctx, logger, ts) require.NoError(t, err) assert.NotEqualValues(t, 0, pos) @@ -199,6 +203,9 @@ func TestQLogFile_SeekTS_bad(t *testing.T) { num: 10, }} + logger := slogutil.NewDiscardLogger() + ctx := testutil.ContextWithTimeout(t, testTimeout) + for _, l := range linesCases { testCases := []struct { name string @@ -221,14 +228,14 @@ func TestQLogFile_SeekTS_bad(t *testing.T) { line, err := getQLogFileLine(q, l.num/2) require.NoError(t, err) - testCases[2].ts = readQLogTimestamp(line) - 1 + testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1 for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { assert.NotEqualValues(t, 0, tc.ts) var depth int - _, depth, err = q.seekTS(tc.ts) + _, depth, err = q.seekTS(ctx, logger, tc.ts) assert.NotEmpty(t, l.num) require.Error(t, err) @@ -308,6 +315,9 @@ func TestQLog_Seek(t *testing.T) { `{"T":"` + strV + `"}` + nl timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00") + logger := slogutil.NewDiscardLogger() + ctx := testutil.ContextWithTimeout(t, testTimeout) + testCases := []struct { wantErr error name string @@ -340,7 +350,8 @@ func TestQLog_Seek(t *testing.T) { q := newTestQLogFileData(t, data) - _, depth, err := q.seekTS(timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano()) + ts := timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano() + _, depth, err := q.seekTS(ctx, logger, ts) require.Truef(t, errors.Is(err, tc.wantErr), "%v", err) assert.Equal(t, tc.wantDepth, depth) }) diff --git a/internal/querylog/qlogreader.go b/internal/querylog/qlogreader.go index a222323e..dad4a594 100644 --- a/internal/querylog/qlogreader.go +++ b/internal/querylog/qlogreader.go @@ -1,12 +1,14 @@ package querylog import ( + "context" "fmt" "io" + "log/slog" "os" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" ) // qLogReader allows reading from multiple query log files in the reverse @@ -16,6 +18,10 @@ import ( // pointer to a particular query log file, and to a specific position in this // file, and it reads lines in reverse order starting from that position. type qLogReader struct { + // logger is used for logging the operation of the query log reader. It + // must not be nil. + logger *slog.Logger + // qFiles is an array with the query log files. The order is from oldest // to newest. qFiles []*qLogFile @@ -25,7 +31,7 @@ type qLogReader struct { } // newQLogReader initializes a qLogReader instance with the specified files. -func newQLogReader(files []string) (*qLogReader, error) { +func newQLogReader(ctx context.Context, logger *slog.Logger, files []string) (*qLogReader, error) { qFiles := make([]*qLogFile, 0) for _, f := range files { @@ -38,7 +44,7 @@ func newQLogReader(files []string) (*qLogReader, error) { // Close what we've already opened. cErr := closeQFiles(qFiles) if cErr != nil { - log.Debug("querylog: closing files: %s", cErr) + logger.DebugContext(ctx, "closing files", slogutil.KeyError, cErr) } return nil, err @@ -47,16 +53,20 @@ func newQLogReader(files []string) (*qLogReader, error) { qFiles = append(qFiles, q) } - return &qLogReader{qFiles: qFiles, currentFile: len(qFiles) - 1}, nil + return &qLogReader{ + logger: logger, + qFiles: qFiles, + currentFile: len(qFiles) - 1, + }, nil } // seekTS performs binary search of a query log record with the specified // timestamp. If the record is found, it sets qLogReader's position to point // to that line, so that the next ReadNext call returned this line. -func (r *qLogReader) seekTS(timestamp int64) (err error) { +func (r *qLogReader) seekTS(ctx context.Context, timestamp int64) (err error) { for i := len(r.qFiles) - 1; i >= 0; i-- { q := r.qFiles[i] - _, _, err = q.seekTS(timestamp) + _, _, err = q.seekTS(ctx, r.logger, timestamp) if err != nil { if errors.Is(err, errTSTooEarly) { // Look at the next file, since we've reached the end of this diff --git a/internal/querylog/qlogreader_test.go b/internal/querylog/qlogreader_test.go index 43bb3d5c..bb3ce164 100644 --- a/internal/querylog/qlogreader_test.go +++ b/internal/querylog/qlogreader_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,8 +18,11 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *qLogReader testFiles := prepareTestFiles(t, filesNum, linesNum) + logger := slogutil.NewDiscardLogger() + ctx := testutil.ContextWithTimeout(t, testTimeout) + // Create the new qLogReader instance. - reader, err := newQLogReader(testFiles) + reader, err := newQLogReader(ctx, logger, testFiles) require.NoError(t, err) assert.NotNil(t, reader) @@ -73,6 +77,7 @@ func TestQLogReader(t *testing.T) { func TestQLogReader_Seek(t *testing.T) { r := newTestQLogReader(t, 2, 10000) + ctx := testutil.ContextWithTimeout(t, testTimeout) testCases := []struct { want error @@ -113,7 +118,7 @@ func TestQLogReader_Seek(t *testing.T) { ts, err := time.Parse(time.RFC3339Nano, tc.time) require.NoError(t, err) - err = r.seekTS(ts.UnixNano()) + err = r.seekTS(ctx, ts.UnixNano()) assert.ErrorIs(t, err, tc.want) }) } diff --git a/internal/querylog/querylog.go b/internal/querylog/querylog.go index bccc264a..9c2c53cd 100644 --- a/internal/querylog/querylog.go +++ b/internal/querylog/querylog.go @@ -1,7 +1,9 @@ package querylog import ( + "context" "fmt" + "log/slog" "net" "path/filepath" "sync" @@ -17,10 +19,10 @@ import ( // QueryLog - main interface type QueryLog interface { - Start() + Start(ctx context.Context) // Close query log object - Close() + Close(ctx context.Context) // Add a log entry Add(params *AddParams) @@ -36,6 +38,10 @@ type QueryLog interface { // // Do not alter any fields of this structure after using it. type Config struct { + // Logger is used for logging the operation of the query log. It must not + // be nil. + Logger *slog.Logger + // Ignored contains the list of host names, which should not be written to // log, and matches them. Ignored *aghnet.IgnoreEngine @@ -151,6 +157,7 @@ func newQueryLog(conf Config) (l *queryLog, err error) { } l = &queryLog{ + logger: conf.Logger, findClient: findClient, buffer: container.NewRingBuffer[*logEntry](memSize), diff --git a/internal/querylog/querylogfile.go b/internal/querylog/querylogfile.go index 6b4760c0..57539669 100644 --- a/internal/querylog/querylogfile.go +++ b/internal/querylog/querylogfile.go @@ -2,6 +2,7 @@ package querylog import ( "bytes" + "context" "encoding/json" "fmt" "os" @@ -9,28 +10,29 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/c2h5oh/datasize" ) // flushLogBuffer flushes the current buffer to file and resets the current // buffer. -func (l *queryLog) flushLogBuffer() (err error) { +func (l *queryLog) flushLogBuffer(ctx context.Context) (err error) { defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }() l.fileFlushLock.Lock() defer l.fileFlushLock.Unlock() - b, err := l.encodeEntries() + b, err := l.encodeEntries(ctx) if err != nil { // Don't wrap the error since it's informative enough as is. return err } - return l.flushToFile(b) + return l.flushToFile(ctx, b) } // encodeEntries returns JSON encoded log entries, logs estimated time, clears // the log buffer. -func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) { +func (l *queryLog) encodeEntries(ctx context.Context) (b *bytes.Buffer, err error) { l.bufferLock.Lock() defer l.bufferLock.Unlock() @@ -55,8 +57,17 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) { return nil, err } + size := b.Len() elapsed := time.Since(start) - log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", bufLen, elapsed, b.Len()/1024, float64(b.Len())/float64(bufLen), elapsed/time.Duration(bufLen)) + l.logger.DebugContext( + ctx, + "serialized elements via json", + "count", bufLen, + "elapsed", elapsed, + "size", datasize.ByteSize(size), + "size_per_entry", datasize.ByteSize(float64(size)/float64(bufLen)), + "time_per_entry", elapsed/time.Duration(bufLen), + ) l.buffer.Clear() l.flushPending = false @@ -65,7 +76,7 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) { } // flushToFile saves the encoded log entries to the query log file. -func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) { +func (l *queryLog) flushToFile(ctx context.Context, b *bytes.Buffer) (err error) { l.fileWriteLock.Lock() defer l.fileWriteLock.Unlock() @@ -83,19 +94,19 @@ func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) { return fmt.Errorf("writing to file %q: %w", filename, err) } - log.Debug("querylog: ok %q: %v bytes written", filename, n) + l.logger.DebugContext(ctx, "flushed to file", "file", filename, "size", datasize.ByteSize(n)) return nil } -func (l *queryLog) rotate() error { +func (l *queryLog) rotate(ctx context.Context) error { from := l.logFile to := l.logFile + ".1" err := os.Rename(from, to) if err != nil { if errors.Is(err, os.ErrNotExist) { - log.Debug("querylog: no log to rotate") + l.logger.DebugContext(ctx, "no log to rotate") return nil } @@ -103,12 +114,12 @@ func (l *queryLog) rotate() error { return fmt.Errorf("failed to rename old file: %w", err) } - log.Debug("querylog: renamed %s into %s", from, to) + l.logger.DebugContext(ctx, "renamed log file", "from", from, "to", to) return nil } -func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) { +func (l *queryLog) readFileFirstTimeValue(ctx context.Context) (first time.Time, err error) { var f *os.File f, err = os.Open(l.logFile) if err != nil { @@ -130,15 +141,15 @@ func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) { return time.Time{}, err } - log.Debug("querylog: the oldest log entry: %s", val) + l.logger.DebugContext(ctx, "oldest log entry", "entry_time", val) return t, nil } -func (l *queryLog) periodicRotate() { - defer log.OnPanic("querylog: rotating") +func (l *queryLog) periodicRotate(ctx context.Context) { + defer slogutil.RecoverAndLog(ctx, l.logger) - l.checkAndRotate() + l.checkAndRotate(ctx) // rotationCheckIvl is the period of time between checking the need for // rotating log files. It's smaller of any available rotation interval to @@ -151,13 +162,13 @@ func (l *queryLog) periodicRotate() { defer rotations.Stop() for range rotations.C { - l.checkAndRotate() + l.checkAndRotate(ctx) } } // checkAndRotate rotates log files if those are older than the specified // rotation interval. -func (l *queryLog) checkAndRotate() { +func (l *queryLog) checkAndRotate(ctx context.Context) { var rotationIvl time.Duration func() { l.confMu.RLock() @@ -166,29 +177,30 @@ func (l *queryLog) checkAndRotate() { rotationIvl = l.conf.RotationIvl }() - oldest, err := l.readFileFirstTimeValue() + oldest, err := l.readFileFirstTimeValue(ctx) if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Error("querylog: reading oldest record for rotation: %s", err) + l.logger.ErrorContext(ctx, "reading oldest record for rotation", slogutil.KeyError, err) return } if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) { - log.Debug( - "querylog: %s <= %s, not rotating", - now.Format(time.RFC3339), - rotTime.Format(time.RFC3339), + l.logger.DebugContext( + ctx, + "not rotating", + "now", now.Format(time.RFC3339), + "rotate_time", rotTime.Format(time.RFC3339), ) return } - err = l.rotate() + err = l.rotate(ctx) if err != nil { - log.Error("querylog: rotating: %s", err) + l.logger.ErrorContext(ctx, "rotating", slogutil.KeyError, err) return } - log.Debug("querylog: rotated successfully") + l.logger.DebugContext(ctx, "rotated successfully") } diff --git a/internal/querylog/search.go b/internal/querylog/search.go index f8c8f90e..adb9ec59 100644 --- a/internal/querylog/search.go +++ b/internal/querylog/search.go @@ -1,13 +1,15 @@ package querylog import ( + "context" "fmt" "io" + "log/slog" "slices" "time" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" ) // client finds the client info, if any, by its ClientID and IP address, @@ -48,7 +50,11 @@ func (l *queryLog) client(clientID, ip string, cache clientCache) (c *Client, er // buffer. It optionally uses the client cache, if provided. It also returns // the total amount of records in the buffer at the moment of searching. // l.confMu is expected to be locked. -func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entries []*logEntry, total int) { +func (l *queryLog) searchMemory( + ctx context.Context, + params *searchParams, + cache clientCache, +) (entries []*logEntry, total int) { // Check memory size, as the buffer can contain a single log record. See // [newQueryLog]. if l.conf.MemSize == 0 { @@ -66,9 +72,14 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie var err error e.client, err = l.client(e.ClientID, e.IP.String(), cache) if err != nil { - msg := "querylog: enriching memory record at time %s" + - " for client %q (clientid %q): %s" - log.Error(msg, e.Time, e.IP, e.ClientID, err) + l.logger.ErrorContext( + ctx, + "enriching memory record", + "time", e.Time, + "client_ip", e.IP, + "client_id", e.ClientID, + slogutil.KeyError, err, + ) // Go on and try to match anyway. } @@ -86,7 +97,10 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie // search searches log entries in memory buffer and log file using specified // parameters and returns the list of entries found and the time of the oldest // entry. l.confMu is expected to be locked. -func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest time.Time) { +func (l *queryLog) search( + ctx context.Context, + params *searchParams, +) (entries []*logEntry, oldest time.Time) { start := time.Now() if params.limit == 0 { @@ -95,11 +109,11 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim cache := clientCache{} - memoryEntries, bufLen := l.searchMemory(params, cache) - log.Debug("querylog: got %d entries from memory", len(memoryEntries)) + memoryEntries, bufLen := l.searchMemory(ctx, params, cache) + l.logger.DebugContext(ctx, "got entries from memory", "count", len(memoryEntries)) - fileEntries, oldest, total := l.searchFiles(params, cache) - log.Debug("querylog: got %d entries from files", len(fileEntries)) + fileEntries, oldest, total := l.searchFiles(ctx, params, cache) + l.logger.DebugContext(ctx, "got entries from files", "count", len(fileEntries)) total += bufLen @@ -134,12 +148,13 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim oldest = entries[len(entries)-1].Time } - log.Debug( - "querylog: prepared data (%d/%d) older than %s in %s", - len(entries), - total, - params.olderThan, - time.Since(start), + l.logger.DebugContext( + ctx, + "prepared data", + "count", len(entries), + "total", total, + "older_than", params.olderThan, + "elapsed", time.Since(start), ) return entries, oldest @@ -147,12 +162,12 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim // seekRecord changes the current position to the next record older than the // provided parameter. -func (r *qLogReader) seekRecord(olderThan time.Time) (err error) { +func (r *qLogReader) seekRecord(ctx context.Context, olderThan time.Time) (err error) { if olderThan.IsZero() { return r.SeekStart() } - err = r.seekTS(olderThan.UnixNano()) + err = r.seekTS(ctx, olderThan.UnixNano()) if err == nil { // Read to the next record, because we only need the one that goes // after it. @@ -164,21 +179,24 @@ func (r *qLogReader) seekRecord(olderThan time.Time) (err error) { // setQLogReader creates a reader with the specified files and sets the // position to the next record older than the provided parameter. -func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error) { +func (l *queryLog) setQLogReader( + ctx context.Context, + olderThan time.Time, +) (qr *qLogReader, err error) { files := []string{ l.logFile + ".1", l.logFile, } - r, err := newQLogReader(files) + r, err := newQLogReader(ctx, l.logger, files) if err != nil { return nil, fmt.Errorf("opening qlog reader: %s", err) } - err = r.seekRecord(olderThan) + err = r.seekRecord(ctx, olderThan) if err != nil { defer func() { err = errors.WithDeferred(err, r.Close()) }() - log.Debug("querylog: cannot seek to %s: %s", olderThan, err) + l.logger.DebugContext(ctx, "cannot seek", "older_than", olderThan, slogutil.KeyError, err) return nil, nil } @@ -191,13 +209,14 @@ func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error // calls faster so that the UI could handle it and show something quicker. // This behavior can be overridden if maxFileScanEntries is set to 0. func (l *queryLog) readEntries( + ctx context.Context, r *qLogReader, params *searchParams, cache clientCache, totalLimit int, ) (entries []*logEntry, oldestNano int64, total int) { for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 { - ent, ts, rErr := l.readNextEntry(r, params, cache) + ent, ts, rErr := l.readNextEntry(ctx, r, params, cache) if rErr != nil { if rErr == io.EOF { oldestNano = 0 @@ -205,7 +224,7 @@ func (l *queryLog) readEntries( break } - log.Error("querylog: reading next entry: %s", rErr) + l.logger.ErrorContext(ctx, "reading next entry", slogutil.KeyError, rErr) } oldestNano = ts @@ -231,12 +250,13 @@ func (l *queryLog) readEntries( // and the total number of processed entries, including discarded ones, // correspondingly. func (l *queryLog) searchFiles( + ctx context.Context, params *searchParams, cache clientCache, ) (entries []*logEntry, oldest time.Time, total int) { - r, err := l.setQLogReader(params.olderThan) + r, err := l.setQLogReader(ctx, params.olderThan) if err != nil { - log.Error("querylog: %s", err) + l.logger.ErrorContext(ctx, "searching files", slogutil.KeyError, err) } if r == nil { @@ -245,12 +265,12 @@ func (l *queryLog) searchFiles( defer func() { if closeErr := r.Close(); closeErr != nil { - log.Error("querylog: closing file: %s", closeErr) + l.logger.ErrorContext(ctx, "closing files", slogutil.KeyError, closeErr) } }() totalLimit := params.offset + params.limit - entries, oldestNano, total := l.readEntries(r, params, cache, totalLimit) + entries, oldestNano, total := l.readEntries(ctx, r, params, cache, totalLimit) if oldestNano != 0 { oldest = time.Unix(0, oldestNano) } @@ -266,15 +286,21 @@ type quickMatchClientFinder struct { } // findClient is a method that can be used as a quickMatchClientFinder. -func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) { +func (f quickMatchClientFinder) findClient( + ctx context.Context, + logger *slog.Logger, + clientID string, + ip string, +) (c *Client) { var err error c, err = f.client(clientID, ip, f.cache) if err != nil { - log.Error( - "querylog: enriching file record for quick search: for client %q (clientid %q): %s", - ip, - clientID, - err, + logger.ErrorContext( + ctx, + "enriching file record for quick search", + "client_ip", ip, + "client_id", clientID, + slogutil.KeyError, err, ) } @@ -286,6 +312,7 @@ func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) { // the entry doesn't match the search criteria. ts is the timestamp of the // processed entry. func (l *queryLog) readNextEntry( + ctx context.Context, r *qLogReader, params *searchParams, cache clientCache, @@ -301,14 +328,14 @@ func (l *queryLog) readNextEntry( cache: cache, } - if !params.quickMatch(line, clientFinder.findClient) { - ts = readQLogTimestamp(line) + if !params.quickMatch(ctx, l.logger, line, clientFinder.findClient) { + ts = readQLogTimestamp(ctx, l.logger, line) return nil, ts, nil } e = &logEntry{} - decodeLogEntry(e, line) + l.decodeLogEntry(ctx, e, line) if l.isIgnored(e.QHost) { return nil, ts, nil @@ -316,12 +343,13 @@ func (l *queryLog) readNextEntry( e.client, err = l.client(e.ClientID, e.IP.String(), cache) if err != nil { - log.Error( - "querylog: enriching file record at time %s for client %q (clientid %q): %s", - e.Time, - e.IP, - e.ClientID, - err, + l.logger.ErrorContext( + ctx, + "enriching file record at time", + "at_time", e.Time, + "client_ip", e.IP, + "client_id", e.ClientID, + slogutil.KeyError, err, ) // Go on and try to match anyway. diff --git a/internal/querylog/search_test.go b/internal/querylog/search_test.go index 939942ad..fb6cf5ce 100644 --- a/internal/querylog/search_test.go +++ b/internal/querylog/search_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -36,6 +38,7 @@ func TestQueryLog_Search_findClient(t *testing.T) { } l, err := newQueryLog(Config{ + Logger: slogutil.NewDiscardLogger(), FindClient: findClient, BaseDir: t.TempDir(), RotationIvl: timeutil.Day, @@ -45,7 +48,11 @@ func TestQueryLog_Search_findClient(t *testing.T) { AnonymizeClientIP: false, }) require.NoError(t, err) - t.Cleanup(l.Close) + + ctx := testutil.ContextWithTimeout(t, testTimeout) + t.Cleanup(func() { + l.Close(ctx) + }) q := &dns.Msg{ Question: []dns.Question{{ @@ -81,7 +88,7 @@ func TestQueryLog_Search_findClient(t *testing.T) { olderThan: time.Now().Add(10 * time.Second), limit: 3, } - entries, _ := l.search(sp) + entries, _ := l.search(ctx, sp) assert.Equal(t, 2, findClientCalls) require.Len(t, entries, 3) diff --git a/internal/querylog/searchcriterion.go b/internal/querylog/searchcriterion.go index a8942a83..7397fc55 100644 --- a/internal/querylog/searchcriterion.go +++ b/internal/querylog/searchcriterion.go @@ -1,7 +1,9 @@ package querylog import ( + "context" "fmt" + "log/slog" "strings" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -87,7 +89,12 @@ func ctDomainOrClientCaseNonStrict( // quickMatch quickly checks if the line matches the given search criterion. // It returns false if the like doesn't match. This method is only here for // optimization purposes. -func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) { +func (c *searchCriterion) quickMatch( + ctx context.Context, + logger *slog.Logger, + line string, + findClient quickMatchClientFunc, +) (ok bool) { switch c.criterionType { case ctTerm: host := readJSONValue(line, `"QH":"`) @@ -95,7 +102,7 @@ func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFun clientID := readJSONValue(line, `"CID":"`) var name string - if cli := findClient(clientID, ip); cli != nil { + if cli := findClient(ctx, logger, clientID, ip); cli != nil { name = cli.Name } diff --git a/internal/querylog/searchparams.go b/internal/querylog/searchparams.go index a0a0ff6c..f0d6a0c6 100644 --- a/internal/querylog/searchparams.go +++ b/internal/querylog/searchparams.go @@ -1,6 +1,10 @@ package querylog -import "time" +import ( + "context" + "log/slog" + "time" +) // searchParams represent the search query sent by the client. type searchParams struct { @@ -35,14 +39,23 @@ func newSearchParams() *searchParams { } // quickMatchClientFunc is a simplified client finder for quick matches. -type quickMatchClientFunc = func(clientID, ip string) (c *Client) +type quickMatchClientFunc = func( + ctx context.Context, + logger *slog.Logger, + clientID, ip string, +) (c *Client) // quickMatch quickly checks if the line matches the given search parameters. // It returns false if the line doesn't match. This method is only here for // optimization purposes. -func (s *searchParams) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) { +func (s *searchParams) quickMatch( + ctx context.Context, + logger *slog.Logger, + line string, + findClient quickMatchClientFunc, +) (ok bool) { for _, c := range s.searchCriteria { - if !c.quickMatch(line, findClient) { + if !c.quickMatch(ctx, logger, line, findClient) { return false } }