diff --git a/internal/home/dns.go b/internal/home/dns.go index 6fcc09ca..5b07e795 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -48,11 +48,11 @@ func onConfigModified() { // initDNS updates all the fields of the [Context] needed to initialize the DNS // server and initializes it at last. It also must not be called unless // [config] and [Context] are initialized. l must not be nil. -func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) { +func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) { anonymizer := config.anonymizer() statsConf := stats.Config{ - Logger: l.With(slogutil.KeyPrefix, "stats"), + Logger: baseLogger.With(slogutil.KeyPrefix, "stats"), Filename: filepath.Join(statsDir, "stats.db"), Limit: config.Stats.Interval.Duration, ConfigModified: onConfigModified, @@ -73,6 +73,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) { } conf := querylog.Config{ + Logger: baseLogger.With(slogutil.KeyPrefix, "querylog"), Anonymizer: anonymizer, ConfigModified: onConfigModified, HTTPRegister: httpRegister, @@ -113,7 +114,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) { anonymizer, httpRegister, tlsConf, - l, + baseLogger, ) } @@ -457,7 +458,8 @@ func startDNSServer() error { Context.filters.EnableFilters(false) // TODO(s.chzhen): Pass context. - err := Context.clients.Start(context.TODO()) + ctx := context.TODO() + err := Context.clients.Start(ctx) if err != nil { return fmt.Errorf("starting clients container: %w", err) } @@ -469,7 +471,11 @@ func startDNSServer() error { Context.filters.Start() Context.stats.Start() - Context.queryLog.Start() + + err = Context.queryLog.Start(ctx) + if err != nil { + return fmt.Errorf("starting query log: %w", err) + } return nil } @@ -525,12 +531,16 @@ func closeDNSServer() { if Context.stats != nil { err := Context.stats.Close() if err != nil { - log.Debug("closing stats: %s", err) + log.Error("closing stats: %s", err) } } if Context.queryLog != nil { - Context.queryLog.Close() + // TODO(s.chzhen): Pass context. + err := Context.queryLog.Shutdown(context.TODO()) + if err != nil { + log.Error("closing query log: %s", err) + } } log.Debug("all dns modules are closed") diff --git a/internal/querylog/decode.go b/internal/querylog/decode.go index d4dea04e..47d87efc 100644 --- a/internal/querylog/decode.go +++ b/internal/querylog/decode.go @@ -1,6 +1,7 @@ package querylog import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -13,7 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" ) @@ -174,26 +175,32 @@ var logEntryHandlers = map[string]logEntryHandler{ } // decodeResultRuleKey decodes the token of "Rules" type to logEntry struct. -func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultRuleKey( + ctx context.Context, + key string, + i int, + dec *json.Decoder, + ent *logEntry, +) { var vToken json.Token switch key { case "FilterListID": - ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) + ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules) if n, ok := vToken.(json.Number); ok { id, _ := n.Int64() ent.Result.Rules[i].FilterListID = rulelist.URLFilterID(id) } case "IP": - ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) + ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules) if ipStr, ok := vToken.(string); ok { if ip, err := netip.ParseAddr(ipStr); err == nil { ent.Result.Rules[i].IP = ip } else { - log.Debug("querylog: decoding ipStr value: %s", err) + l.logger.DebugContext(ctx, "decoding ip", "value", ipStr, slogutil.KeyError, err) } } case "Text": - ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) + ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules) if s, ok := vToken.(string); ok { ent.Result.Rules[i].Text = s } @@ -204,7 +211,8 @@ func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) { // decodeVTokenAndAddRule decodes the "Rules" toke as [filtering.ResultRule] // and then adds the decoded object to the slice of result rules. -func decodeVTokenAndAddRule( +func (l *queryLog) decodeVTokenAndAddRule( + ctx context.Context, key string, i int, dec *json.Decoder, @@ -215,7 +223,12 @@ func decodeVTokenAndAddRule( vToken, err := dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeResultRuleKey %s err: %s", key, err) + l.logger.DebugContext( + ctx, + "decoding result rule key", + "key", key, + slogutil.KeyError, err, + ) } return newRules, nil @@ -230,12 +243,14 @@ func decodeVTokenAndAddRule( // decodeResultRules parses the dec's tokens into logEntry ent interpreting it // as a slice of the result rules. -func decodeResultRules(dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultRules(ctx context.Context, dec *json.Decoder, ent *logEntry) { + const msgPrefix = "decoding result rules" + for { delimToken, err := dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeResultRules err: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err) } return @@ -244,13 +259,17 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) { if d, ok := delimToken.(json.Delim); !ok { return } else if d != '[' { - log.Debug("decodeResultRules: unexpected delim %q", d) + l.logger.DebugContext( + ctx, + msgPrefix, + slogutil.KeyError, newUnexpectedDelimiterError(d), + ) } - err = decodeResultRuleToken(dec, ent) + err = l.decodeResultRuleToken(ctx, dec, ent) if err != nil { if err != io.EOF && !errors.Is(err, ErrEndOfToken) { - log.Debug("decodeResultRules err: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; rule token", slogutil.KeyError, err) } return @@ -259,7 +278,11 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) { } // decodeResultRuleToken decodes the tokens of "Rules" type to the logEntry ent. -func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) { +func (l *queryLog) decodeResultRuleToken( + ctx context.Context, + dec *json.Decoder, + ent *logEntry, +) (err error) { i := 0 for { var keyToken json.Token @@ -287,7 +310,7 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) { return fmt.Errorf("keyToken is %T (%[1]v) and not string", keyToken) } - decodeResultRuleKey(key, i, dec, ent) + l.decodeResultRuleKey(ctx, key, i, dec, ent) } } @@ -296,12 +319,14 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) { // other occurrences of DNSRewriteResult in the entry since hosts container's // rewrites currently has the highest priority along the entire filtering // pipeline. -func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultReverseHosts(ctx context.Context, dec *json.Decoder, ent *logEntry) { + const msgPrefix = "decoding result reverse hosts" + for { itemToken, err := dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeResultReverseHosts err: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err) } return @@ -315,7 +340,11 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) { return } - log.Debug("decodeResultReverseHosts: unexpected delim %q", v) + l.logger.DebugContext( + ctx, + msgPrefix, + slogutil.KeyError, newUnexpectedDelimiterError(v), + ) return case string: @@ -346,12 +375,14 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) { // decodeResultIPList parses the dec's tokens into logEntry ent interpreting it // as the result IP addresses list. -func decodeResultIPList(dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultIPList(ctx context.Context, dec *json.Decoder, ent *logEntry) { + const msgPrefix = "decoding result ip list" + for { itemToken, err := dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeResultIPList err: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err) } return @@ -365,7 +396,11 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) { return } - log.Debug("decodeResultIPList: unexpected delim %q", v) + l.logger.DebugContext( + ctx, + msgPrefix, + slogutil.KeyError, newUnexpectedDelimiterError(v), + ) return case string: @@ -382,7 +417,14 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) { // decodeResultDNSRewriteResultKey decodes the token of "DNSRewriteResult" type // to the logEntry struct. -func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultDNSRewriteResultKey( + ctx context.Context, + key string, + dec *json.Decoder, + ent *logEntry, +) { + const msgPrefix = "decoding result dns rewrite result key" + var err error switch key { @@ -391,7 +433,7 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr vToken, err = dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeResultDNSRewriteResultKey err: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err) } return @@ -419,7 +461,7 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr // decoding and correct the values. err = dec.Decode(&ent.Result.DNSRewriteResult.Response) if err != nil { - log.Debug("decodeResultDNSRewriteResultKey response err: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; response", slogutil.KeyError, err) } ent.parseDNSRewriteResultIPs() @@ -430,12 +472,18 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr // decodeResultDNSRewriteResult parses the dec's tokens into logEntry ent // interpreting it as the result DNSRewriteResult. -func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResultDNSRewriteResult( + ctx context.Context, + dec *json.Decoder, + ent *logEntry, +) { + const msgPrefix = "decoding result dns rewrite result" + for { key, err := parseKeyToken(dec) if err != nil { if err != io.EOF && !errors.Is(err, ErrEndOfToken) { - log.Debug("decodeResultDNSRewriteResult: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err) } return @@ -445,7 +493,7 @@ func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) { continue } - decodeResultDNSRewriteResultKey(key, dec, ent) + l.decodeResultDNSRewriteResultKey(ctx, key, dec, ent) } } @@ -508,14 +556,16 @@ func parseKeyToken(dec *json.Decoder) (key string, err error) { } // decodeResult decodes a token of "Result" type to logEntry struct. -func decodeResult(dec *json.Decoder, ent *logEntry) { +func (l *queryLog) decodeResult(ctx context.Context, dec *json.Decoder, ent *logEntry) { + const msgPrefix = "decoding result" + defer translateResult(ent) for { key, err := parseKeyToken(dec) if err != nil { if err != io.EOF && !errors.Is(err, ErrEndOfToken) { - log.Debug("decodeResult: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err) } return @@ -525,10 +575,8 @@ func decodeResult(dec *json.Decoder, ent *logEntry) { continue } - decHandler, ok := resultDecHandlers[key] + ok := l.resultDecHandler(ctx, key, dec, ent) if ok { - decHandler(dec, ent) - continue } @@ -543,7 +591,7 @@ func decodeResult(dec *json.Decoder, ent *logEntry) { } if err = handler(val, ent); err != nil { - log.Debug("decodeResult handler err: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; handler", slogutil.KeyError, err) return } @@ -636,16 +684,34 @@ var resultHandlers = map[string]logEntryHandler{ }, } -// resultDecHandlers is the map of decode handlers for various keys. -var resultDecHandlers = map[string]func(dec *json.Decoder, ent *logEntry){ - "ReverseHosts": decodeResultReverseHosts, - "IPList": decodeResultIPList, - "Rules": decodeResultRules, - "DNSRewriteResult": decodeResultDNSRewriteResult, +// resultDecHandlers calls a decode handler for key if there is one. +func (l *queryLog) resultDecHandler( + ctx context.Context, + name string, + dec *json.Decoder, + ent *logEntry, +) (ok bool) { + ok = true + switch name { + case "ReverseHosts": + l.decodeResultReverseHosts(ctx, dec, ent) + case "IPList": + l.decodeResultIPList(ctx, dec, ent) + case "Rules": + l.decodeResultRules(ctx, dec, ent) + case "DNSRewriteResult": + l.decodeResultDNSRewriteResult(ctx, dec, ent) + default: + ok = false + } + + return ok } // decodeLogEntry decodes string str to logEntry ent. -func decodeLogEntry(ent *logEntry, str string) { +func (l *queryLog) decodeLogEntry(ctx context.Context, ent *logEntry, str string) { + const msgPrefix = "decoding log entry" + dec := json.NewDecoder(strings.NewReader(str)) dec.UseNumber() @@ -653,7 +719,7 @@ func decodeLogEntry(ent *logEntry, str string) { keyToken, err := dec.Token() if err != nil { if err != io.EOF { - log.Debug("decodeLogEntry err: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err) } return @@ -665,13 +731,14 @@ func decodeLogEntry(ent *logEntry, str string) { key, ok := keyToken.(string) if !ok { - log.Debug("decodeLogEntry: keyToken is %T (%[1]v) and not string", keyToken) + err = fmt.Errorf("%s: keyToken is %T (%[2]v) and not string", msgPrefix, keyToken) + l.logger.DebugContext(ctx, msgPrefix, slogutil.KeyError, err) return } if key == "Result" { - decodeResult(dec, ent) + l.decodeResult(ctx, dec, ent) continue } @@ -687,9 +754,14 @@ func decodeLogEntry(ent *logEntry, str string) { } if err = handler(val, ent); err != nil { - log.Debug("decodeLogEntry handler err: %s", err) + l.logger.DebugContext(ctx, msgPrefix+"; handler", slogutil.KeyError, err) return } } } + +// newUnexpectedDelimiterError is a helper for creating informative errors. +func newUnexpectedDelimiterError(d json.Delim) (err error) { + return fmt.Errorf("unexpected delimiter: %q", d) +} diff --git a/internal/querylog/decode_test.go b/internal/querylog/decode_test.go index 1f907e3d..8344479f 100644 --- a/internal/querylog/decode_test.go +++ b/internal/querylog/decode_test.go @@ -3,27 +3,35 @@ package querylog import ( "bytes" "encoding/base64" + "log/slog" "net" "net/netip" - "strings" "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// Common constants for tests. +const testTimeout = 1 * time.Second + func TestDecodeLogEntry(t *testing.T) { logOutput := &bytes.Buffer{} + l := &queryLog{ + logger: slog.New(slog.NewTextHandler(logOutput, &slog.HandlerOptions{ + Level: slog.LevelDebug, + ReplaceAttr: slogutil.RemoveTime, + })), + } - aghtest.ReplaceLogWriter(t, logOutput) - aghtest.ReplaceLogLevel(t, log.DEBUG) + ctx := testutil.ContextWithTimeout(t, testTimeout) t.Run("success", func(t *testing.T) { const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==` @@ -92,7 +100,7 @@ func TestDecodeLogEntry(t *testing.T) { } got := &logEntry{} - decodeLogEntry(got, data) + l.decodeLogEntry(ctx, got, data) s := logOutput.String() assert.Empty(t, s) @@ -113,11 +121,11 @@ func TestDecodeLogEntry(t *testing.T) { }, { name: "bad_filter_id_old_rule", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"FilterID":1.5},"Elapsed":837429}`, - want: "decodeResult handler err: strconv.ParseInt: parsing \"1.5\": invalid syntax\n", + want: `level=DEBUG msg="decoding result; handler" err="strconv.ParseInt: parsing \"1.5\": invalid syntax"`, }, { name: "bad_is_filtered", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":trooe,"Reason":3},"Elapsed":837429}`, - want: "decodeLogEntry err: invalid character 'o' in literal true (expecting 'u')\n", + want: `level=DEBUG msg="decoding log entry; token" err="invalid character 'o' in literal true (expecting 'u')"`, }, { name: "bad_elapsed", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":-1}`, @@ -129,7 +137,7 @@ func TestDecodeLogEntry(t *testing.T) { }, { name: "bad_time", log: `{"IP":"127.0.0.1","T":"12/09/1998T15:00:00.000000+05:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, - want: "decodeLogEntry handler err: parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\"\n", + want: `level=DEBUG msg="decoding log entry; handler" err="parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\""`, }, { name: "bad_host", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":6,"QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, @@ -149,7 +157,7 @@ func TestDecodeLogEntry(t *testing.T) { }, { name: "very_bad_client_proto", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"dog","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, - want: "decodeLogEntry handler err: invalid client proto: \"dog\"\n", + want: `level=DEBUG msg="decoding log entry; handler" err="invalid client proto: \"dog\""`, }, { name: "bad_answer", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":0.9,"Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, @@ -157,7 +165,7 @@ func TestDecodeLogEntry(t *testing.T) { }, { name: "very_bad_answer", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, - want: "decodeLogEntry handler err: illegal base64 data at input byte 61\n", + want: `level=DEBUG msg="decoding log entry; handler" err="illegal base64 data at input byte 61"`, }, { name: "bad_rule", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"Rule":false},"Elapsed":837429}`, @@ -169,22 +177,25 @@ func TestDecodeLogEntry(t *testing.T) { }, { name: "bad_reverse_hosts", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":[{}]},"Elapsed":837429}`, - want: "decodeResultReverseHosts: unexpected delim \"{\"\n", + want: `level=DEBUG msg="decoding result reverse hosts" err="unexpected delimiter: \"{\""`, }, { name: "bad_ip_list", log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":["example.net"],"IPList":[{}]},"Elapsed":837429}`, - want: "decodeResultIPList: unexpected delim \"{\"\n", + want: `level=DEBUG msg="decoding result ip list" err="unexpected delimiter: \"{\""`, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - decodeLogEntry(new(logEntry), tc.log) - - s := logOutput.String() + l.decodeLogEntry(ctx, new(logEntry), tc.log) + got := logOutput.String() if tc.want == "" { - assert.Empty(t, s) + assert.Empty(t, got) } else { - assert.True(t, strings.HasSuffix(s, tc.want), "got %q", s) + require.NotEmpty(t, got) + + // Remove newline. + got = got[:len(got)-1] + assert.Equal(t, tc.want, got) } logOutput.Reset() @@ -200,6 +211,12 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) { aaaa2 = aaaa1.Next() ) + l := &queryLog{ + logger: slogutil.NewDiscardLogger(), + } + + ctx := testutil.ContextWithTimeout(t, testTimeout) + testCases := []struct { want *logEntry entry string @@ -249,7 +266,7 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := &logEntry{} - decodeLogEntry(e, tc.entry) + l.decodeLogEntry(ctx, e, tc.entry) assert.Equal(t, tc.want, e) }) diff --git a/internal/querylog/entry.go b/internal/querylog/entry.go index ed3319b0..67272bba 100644 --- a/internal/querylog/entry.go +++ b/internal/querylog/entry.go @@ -1,12 +1,14 @@ package querylog import ( + "context" + "log/slog" "net" "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/miekg/dns" ) @@ -52,7 +54,7 @@ func (e *logEntry) shallowClone() (clone *logEntry) { // addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is // true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any // errors are logged. -func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) { +func (e *logEntry) addResponse(ctx context.Context, l *slog.Logger, resp *dns.Msg, isOrig bool) { if resp == nil { return } @@ -65,8 +67,9 @@ func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) { e.Answer, err = resp.Pack() err = errors.Annotate(err, "packing answer: %w") } + if err != nil { - log.Error("querylog: %s", err) + l.ErrorContext(ctx, "adding data from response", slogutil.KeyError, err) } } diff --git a/internal/querylog/http.go b/internal/querylog/http.go index 1fb7cce4..fb878e04 100644 --- a/internal/querylog/http.go +++ b/internal/querylog/http.go @@ -1,6 +1,7 @@ package querylog import ( + "context" "encoding/json" "fmt" "math" @@ -15,7 +16,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/timeutil" "golang.org/x/net/idna" ) @@ -74,7 +75,8 @@ func (l *queryLog) initWeb() { // handleQueryLog is the handler for the GET /control/querylog HTTP API. func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { - params, err := parseSearchParams(r) + ctx := r.Context() + params, err := l.parseSearchParams(ctx, r) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err) @@ -87,18 +89,18 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { l.confMu.RLock() defer l.confMu.RUnlock() - entries, oldest = l.search(params) + entries, oldest = l.search(ctx, params) }() - resp := entriesToJSON(entries, oldest, l.anonymizer.Load()) + resp := l.entriesToJSON(ctx, entries, oldest, l.anonymizer.Load()) aghhttp.WriteJSONResponseOK(w, r, resp) } // handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP // API. -func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) { - l.clear() +func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, r *http.Request) { + l.clear(r.Context()) } // handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP @@ -280,11 +282,12 @@ func getDoubleQuotesEnclosedValue(s *string) bool { } // parseSearchCriterion parses a search criterion from the query parameter. -func parseSearchCriterion(q url.Values, name string, ct criterionType) ( - ok bool, - sc searchCriterion, - err error, -) { +func (l *queryLog) parseSearchCriterion( + ctx context.Context, + q url.Values, + name string, + ct criterionType, +) (ok bool, sc searchCriterion, err error) { val := q.Get(name) if val == "" { return false, sc, nil @@ -301,7 +304,7 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) ( // TODO(e.burkov): Make it work with parts of IDNAs somehow. loweredVal := strings.ToLower(val) if asciiVal, err = idna.ToASCII(loweredVal); err != nil { - log.Debug("can't convert %q to ascii: %s", val, err) + l.logger.DebugContext(ctx, "converting to ascii", "value", val, slogutil.KeyError, err) } else if asciiVal == loweredVal { // Purge asciiVal to prevent checking the same value // twice. @@ -331,7 +334,10 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) ( // parseSearchParams parses search parameters from the HTTP request's query // string. -func parseSearchParams(r *http.Request) (p *searchParams, err error) { +func (l *queryLog) parseSearchParams( + ctx context.Context, + r *http.Request, +) (p *searchParams, err error) { p = newSearchParams() q := r.URL.Query() @@ -369,7 +375,7 @@ func parseSearchParams(r *http.Request) (p *searchParams, err error) { }} { var ok bool var c searchCriterion - ok, c, err = parseSearchCriterion(q, v.urlField, v.ct) + ok, c, err = l.parseSearchCriterion(ctx, q, v.urlField, v.ct) if err != nil { return nil, err } diff --git a/internal/querylog/json.go b/internal/querylog/json.go index 07a9d62b..295ce7da 100644 --- a/internal/querylog/json.go +++ b/internal/querylog/json.go @@ -1,6 +1,7 @@ package querylog import ( + "context" "slices" "strconv" "strings" @@ -8,7 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/miekg/dns" "golang.org/x/net/idna" ) @@ -19,7 +20,8 @@ import ( type jobject = map[string]any // entriesToJSON converts query log entries to JSON. -func entriesToJSON( +func (l *queryLog) entriesToJSON( + ctx context.Context, entries []*logEntry, oldest time.Time, anonFunc aghnet.IPMutFunc, @@ -28,7 +30,7 @@ func entriesToJSON( // The elements order is already reversed to be from newer to older. for _, entry := range entries { - jsonEntry := entryToJSON(entry, anonFunc) + jsonEntry := l.entryToJSON(ctx, entry, anonFunc) data = append(data, jsonEntry) } @@ -44,7 +46,11 @@ func entriesToJSON( } // entryToJSON converts a log entry's data into an entry for the JSON API. -func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) { +func (l *queryLog) entryToJSON( + ctx context.Context, + entry *logEntry, + anonFunc aghnet.IPMutFunc, +) (jsonEntry jobject) { hostname := entry.QHost question := jobject{ "type": entry.QType, @@ -53,7 +59,12 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) } if qhost, err := idna.ToUnicode(hostname); err != nil { - log.Debug("querylog: translating %q into unicode: %s", hostname, err) + l.logger.DebugContext( + ctx, + "translating into unicode", + "hostname", hostname, + slogutil.KeyError, err, + ) } else if qhost != hostname && qhost != "" { question["unicode_name"] = qhost } @@ -96,21 +107,26 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) jsonEntry["service_name"] = entry.Result.ServiceName } - setMsgData(entry, jsonEntry) - setOrigAns(entry, jsonEntry) + l.setMsgData(ctx, entry, jsonEntry) + l.setOrigAns(ctx, entry, jsonEntry) return jsonEntry } // setMsgData sets the message data in jsonEntry. -func setMsgData(entry *logEntry, jsonEntry jobject) { +func (l *queryLog) setMsgData(ctx context.Context, entry *logEntry, jsonEntry jobject) { if len(entry.Answer) == 0 { return } msg := &dns.Msg{} if err := msg.Unpack(entry.Answer); err != nil { - log.Debug("querylog: failed to unpack dns msg answer: %v: %s", entry.Answer, err) + l.logger.DebugContext( + ctx, + "unpacking dns message", + "answer", entry.Answer, + slogutil.KeyError, err, + ) return } @@ -126,7 +142,7 @@ func setMsgData(entry *logEntry, jsonEntry jobject) { } // setOrigAns sets the original answer data in jsonEntry. -func setOrigAns(entry *logEntry, jsonEntry jobject) { +func (l *queryLog) setOrigAns(ctx context.Context, entry *logEntry, jsonEntry jobject) { if len(entry.OrigAnswer) == 0 { return } @@ -134,7 +150,12 @@ func setOrigAns(entry *logEntry, jsonEntry jobject) { orig := &dns.Msg{} err := orig.Unpack(entry.OrigAnswer) if err != nil { - log.Debug("querylog: orig.Unpack(entry.OrigAnswer): %v: %s", entry.OrigAnswer, err) + l.logger.DebugContext( + ctx, + "setting original answer", + "answer", entry.OrigAnswer, + slogutil.KeyError, err, + ) return } diff --git a/internal/querylog/qlog.go b/internal/querylog/qlog.go index c0b2edd1..0f89854f 100644 --- a/internal/querylog/qlog.go +++ b/internal/querylog/qlog.go @@ -2,7 +2,9 @@ package querylog import ( + "context" "fmt" + "log/slog" "os" "sync" "time" @@ -11,7 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/miekg/dns" ) @@ -22,6 +24,10 @@ const queryLogFileName = "querylog.json" // queryLog is a structure that writes and reads the DNS query log. type queryLog struct { + // logger is used for logging the operation of the query log. It must not + // be nil. + logger *slog.Logger + // confMu protects conf. confMu *sync.RWMutex @@ -76,24 +82,34 @@ func NewClientProto(s string) (cp ClientProto, err error) { } } -func (l *queryLog) Start() { +// type check +var _ QueryLog = (*queryLog)(nil) + +// Start implements the [QueryLog] interface for *queryLog. +func (l *queryLog) Start(ctx context.Context) (err error) { if l.conf.HTTPRegister != nil { l.initWeb() } - go l.periodicRotate() + go l.periodicRotate(ctx) + + return nil } -func (l *queryLog) Close() { +// Shutdown implements the [QueryLog] interface for *queryLog. +func (l *queryLog) Shutdown(ctx context.Context) (err error) { l.confMu.RLock() defer l.confMu.RUnlock() if l.conf.FileEnabled { - err := l.flushLogBuffer() + err = l.flushLogBuffer(ctx) if err != nil { - log.Error("querylog: closing: %s", err) + // Don't wrap the error because it's informative enough as is. + return err } } + + return nil } func checkInterval(ivl time.Duration) (ok bool) { @@ -123,6 +139,7 @@ func validateIvl(ivl time.Duration) (err error) { return nil } +// WriteDiskConfig implements the [QueryLog] interface for *queryLog. func (l *queryLog) WriteDiskConfig(c *Config) { l.confMu.RLock() defer l.confMu.RUnlock() @@ -131,7 +148,7 @@ func (l *queryLog) WriteDiskConfig(c *Config) { } // Clear memory buffer and remove log files -func (l *queryLog) clear() { +func (l *queryLog) clear(ctx context.Context) { l.fileFlushLock.Lock() defer l.fileFlushLock.Unlock() @@ -146,19 +163,24 @@ func (l *queryLog) clear() { oldLogFile := l.logFile + ".1" err := os.Remove(oldLogFile) if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Error("removing old log file %q: %s", oldLogFile, err) + l.logger.ErrorContext( + ctx, + "removing old log file", + "file", oldLogFile, + slogutil.KeyError, err, + ) } err = os.Remove(l.logFile) if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Error("removing log file %q: %s", l.logFile, err) + l.logger.ErrorContext(ctx, "removing log file", "file", l.logFile, slogutil.KeyError, err) } - log.Debug("querylog: cleared") + l.logger.DebugContext(ctx, "cleared") } // newLogEntry creates an instance of logEntry from parameters. -func newLogEntry(params *AddParams) (entry *logEntry) { +func newLogEntry(ctx context.Context, logger *slog.Logger, params *AddParams) (entry *logEntry) { q := params.Question.Question[0] qHost := aghnet.NormalizeDomain(q.Name) @@ -187,8 +209,8 @@ func newLogEntry(params *AddParams) (entry *logEntry) { entry.ReqECS = params.ReqECS.String() } - entry.addResponse(params.Answer, false) - entry.addResponse(params.OrigAnswer, true) + entry.addResponse(ctx, logger, params.Answer, false) + entry.addResponse(ctx, logger, params.OrigAnswer, true) return entry } @@ -209,9 +231,12 @@ func (l *queryLog) Add(params *AddParams) { return } + // TODO(s.chzhen): Pass context. + ctx := context.TODO() + err := params.validate() if err != nil { - log.Error("querylog: adding record: %s, skipping", err) + l.logger.ErrorContext(ctx, "adding record", slogutil.KeyError, err) return } @@ -220,7 +245,7 @@ func (l *queryLog) Add(params *AddParams) { params.Result = &filtering.Result{} } - entry := newLogEntry(params) + entry := newLogEntry(ctx, l.logger, params) l.bufferLock.Lock() defer l.bufferLock.Unlock() @@ -232,9 +257,9 @@ func (l *queryLog) Add(params *AddParams) { // TODO(s.chzhen): Fix occasional rewrite of entires. go func() { - flushErr := l.flushLogBuffer() + flushErr := l.flushLogBuffer(ctx) if flushErr != nil { - log.Error("querylog: flushing after adding: %s", flushErr) + l.logger.ErrorContext(ctx, "flushing after adding", slogutil.KeyError, flushErr) } }() } @@ -247,7 +272,8 @@ func (l *queryLog) ShouldLog(host string, _, _ uint16, ids []string) bool { c, err := l.findClient(ids) if err != nil { - log.Error("querylog: finding client: %s", err) + // TODO(s.chzhen): Pass context. + l.logger.ErrorContext(context.TODO(), "finding client", slogutil.KeyError, err) } if c != nil && c.IgnoreQueryLog { diff --git a/internal/querylog/qlog_test.go b/internal/querylog/qlog_test.go index 57d8b68d..2a688552 100644 --- a/internal/querylog/qlog_test.go +++ b/internal/querylog/qlog_test.go @@ -7,6 +7,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/miekg/dns" @@ -14,14 +15,11 @@ import ( "github.com/stretchr/testify/require" ) -func TestMain(m *testing.M) { - testutil.DiscardLogOutput(m) -} - // TestQueryLog tests adding and loading (with filtering) entries from disk and // memory. func TestQueryLog(t *testing.T) { l, err := newQueryLog(Config{ + Logger: slogutil.NewDiscardLogger(), Enabled: true, FileEnabled: true, RotationIvl: timeutil.Day, @@ -30,16 +28,21 @@ func TestQueryLog(t *testing.T) { }) require.NoError(t, err) + ctx := testutil.ContextWithTimeout(t, testTimeout) + // Add disk entries. addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) // Write to disk (first file). - require.NoError(t, l.flushLogBuffer()) + require.NoError(t, l.flushLogBuffer(ctx)) + // Start writing to the second file. - require.NoError(t, l.rotate()) + require.NoError(t, l.rotate(ctx)) + // Add disk entries. addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) // Write to disk. - require.NoError(t, l.flushLogBuffer()) + require.NoError(t, l.flushLogBuffer(ctx)) + // Add memory entries. addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) @@ -119,8 +122,9 @@ func TestQueryLog(t *testing.T) { params := newSearchParams() params.searchCriteria = tc.sCr - entries, _ := l.search(params) + entries, _ := l.search(ctx, params) require.Len(t, entries, len(tc.want)) + for _, want := range tc.want { assertLogEntry(t, entries[want.num], want.host, want.answer, want.client) } @@ -130,6 +134,7 @@ func TestQueryLog(t *testing.T) { func TestQueryLogOffsetLimit(t *testing.T) { l, err := newQueryLog(Config{ + Logger: slogutil.NewDiscardLogger(), Enabled: true, RotationIvl: timeutil.Day, MemSize: 100, @@ -142,12 +147,16 @@ func TestQueryLogOffsetLimit(t *testing.T) { firstPageDomain = "first.example.org" secondPageDomain = "second.example.org" ) + + ctx := testutil.ContextWithTimeout(t, testTimeout) + // Add entries to the log. for range entNum { addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } // Write them to the first file. - require.NoError(t, l.flushLogBuffer()) + require.NoError(t, l.flushLogBuffer(ctx)) + // Add more to the in-memory part of log. for range entNum { addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) @@ -191,8 +200,7 @@ func TestQueryLogOffsetLimit(t *testing.T) { t.Run(tc.name, func(t *testing.T) { params.offset = tc.offset params.limit = tc.limit - entries, _ := l.search(params) - + entries, _ := l.search(ctx, params) require.Len(t, entries, tc.wantLen) if tc.wantLen > 0 { @@ -205,6 +213,7 @@ func TestQueryLogOffsetLimit(t *testing.T) { func TestQueryLogMaxFileScanEntries(t *testing.T) { l, err := newQueryLog(Config{ + Logger: slogutil.NewDiscardLogger(), Enabled: true, FileEnabled: true, RotationIvl: timeutil.Day, @@ -213,20 +222,21 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) { }) require.NoError(t, err) + ctx := testutil.ContextWithTimeout(t, testTimeout) + const entNum = 10 // Add entries to the log. for range entNum { addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } // Write them to disk. - require.NoError(t, l.flushLogBuffer()) + require.NoError(t, l.flushLogBuffer(ctx)) params := newSearchParams() - for _, maxFileScanEntries := range []int{5, 0} { t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) { params.maxFileScanEntries = maxFileScanEntries - entries, _ := l.search(params) + entries, _ := l.search(ctx, params) assert.Len(t, entries, entNum-maxFileScanEntries) }) } @@ -234,6 +244,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) { func TestQueryLogFileDisabled(t *testing.T) { l, err := newQueryLog(Config{ + Logger: slogutil.NewDiscardLogger(), Enabled: true, FileEnabled: false, RotationIvl: timeutil.Day, @@ -248,8 +259,10 @@ func TestQueryLogFileDisabled(t *testing.T) { addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) params := newSearchParams() - ll, _ := l.search(params) + ctx := testutil.ContextWithTimeout(t, testTimeout) + ll, _ := l.search(ctx, params) require.Len(t, ll, 2) + assert.Equal(t, "example3.org", ll[0].QHost) assert.Equal(t, "example2.org", ll[1].QHost) } diff --git a/internal/querylog/qlogfile.go b/internal/querylog/qlogfile.go index c145f520..67b1eeb3 100644 --- a/internal/querylog/qlogfile.go +++ b/internal/querylog/qlogfile.go @@ -1,8 +1,10 @@ package querylog import ( + "context" "fmt" "io" + "log/slog" "os" "strings" "sync" @@ -10,7 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" ) const ( @@ -102,7 +104,11 @@ func (q *qLogFile) validateQLogLineIdx(lineIdx, lastProbeLineIdx, ts, fSize int6 // for so that when we call "ReadNext" this line was returned. // - Depth of the search (how many times we compared timestamps). // - If we could not find it, it returns one of the errors described above. -func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) { +func (q *qLogFile) seekTS( + ctx context.Context, + logger *slog.Logger, + timestamp int64, +) (pos int64, depth int, err error) { q.lock.Lock() defer q.lock.Unlock() @@ -151,7 +157,7 @@ func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) { lastProbeLineIdx = lineIdx // Get the timestamp from the query log record. - ts := readQLogTimestamp(line) + ts := readQLogTimestamp(ctx, logger, line) if ts == 0 { return 0, depth, fmt.Errorf( "looking up timestamp %d in %q: record %q has empty timestamp", @@ -385,20 +391,22 @@ func readJSONValue(s, prefix string) string { } // readQLogTimestamp reads the timestamp field from the query log line. -func readQLogTimestamp(str string) int64 { +func readQLogTimestamp(ctx context.Context, logger *slog.Logger, str string) int64 { val := readJSONValue(str, `"T":"`) if len(val) == 0 { val = readJSONValue(str, `"Time":"`) } if len(val) == 0 { - log.Error("Couldn't find timestamp: %s", str) + logger.ErrorContext(ctx, "couldn't find timestamp", "line", str) + return 0 } tm, err := time.Parse(time.RFC3339Nano, val) if err != nil { - log.Error("Couldn't parse timestamp: %s", val) + logger.ErrorContext(ctx, "couldn't parse timestamp", "value", val, slogutil.KeyError, err) + return 0 } diff --git a/internal/querylog/qlogfile_test.go b/internal/querylog/qlogfile_test.go index 8462e950..087d43aa 100644 --- a/internal/querylog/qlogfile_test.go +++ b/internal/querylog/qlogfile_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -24,6 +25,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) { f, err := os.CreateTemp(dir, "*.txt") require.NoError(t, err) + // Use defer and not t.Cleanup to make sure that the file is closed // after this function is done. defer func() { @@ -108,6 +110,7 @@ func TestQLogFile_ReadNext(t *testing.T) { // Calculate the expected position. fileInfo, err := q.file.Stat() require.NoError(t, err) + var expPos int64 if expPos = fileInfo.Size(); expPos > 0 { expPos-- @@ -129,6 +132,7 @@ func TestQLogFile_ReadNext(t *testing.T) { } require.Equal(t, io.EOF, err) + assert.Equal(t, tc.linesNum, read) }) } @@ -146,6 +150,9 @@ func TestQLogFile_SeekTS_good(t *testing.T) { num: 10, }} + logger := slogutil.NewDiscardLogger() + ctx := testutil.ContextWithTimeout(t, testTimeout) + for _, l := range linesCases { testCases := []struct { name string @@ -171,16 +178,19 @@ func TestQLogFile_SeekTS_good(t *testing.T) { t.Run(l.name+"_"+tc.name, func(t *testing.T) { line, err := getQLogFileLine(q, tc.line) require.NoError(t, err) - ts := readQLogTimestamp(line) + + ts := readQLogTimestamp(ctx, logger, line) assert.NotEqualValues(t, 0, ts) // Try seeking to that line now. - pos, _, err := q.seekTS(ts) + pos, _, err := q.seekTS(ctx, logger, ts) require.NoError(t, err) + assert.NotEqualValues(t, 0, pos) testLine, err := q.ReadNext() require.NoError(t, err) + assert.Equal(t, line, testLine) }) } @@ -199,6 +209,9 @@ func TestQLogFile_SeekTS_bad(t *testing.T) { num: 10, }} + logger := slogutil.NewDiscardLogger() + ctx := testutil.ContextWithTimeout(t, testTimeout) + for _, l := range linesCases { testCases := []struct { name string @@ -221,14 +234,14 @@ func TestQLogFile_SeekTS_bad(t *testing.T) { line, err := getQLogFileLine(q, l.num/2) require.NoError(t, err) - testCases[2].ts = readQLogTimestamp(line) - 1 + testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1 for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { assert.NotEqualValues(t, 0, tc.ts) var depth int - _, depth, err = q.seekTS(tc.ts) + _, depth, err = q.seekTS(ctx, logger, tc.ts) assert.NotEmpty(t, l.num) require.Error(t, err) @@ -262,11 +275,13 @@ func TestQLogFile(t *testing.T) { // Seek to the start. pos, err := q.SeekStart() require.NoError(t, err) + assert.Greater(t, pos, int64(0)) // Read first line. line, err := q.ReadNext() require.NoError(t, err) + assert.Contains(t, line, "0.0.0.2") assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasSuffix(line, "}"), line) @@ -274,6 +289,7 @@ func TestQLogFile(t *testing.T) { // Read second line. line, err = q.ReadNext() require.NoError(t, err) + assert.EqualValues(t, 0, q.position) assert.Contains(t, line, "0.0.0.1") assert.True(t, strings.HasPrefix(line, "{"), line) @@ -282,12 +298,14 @@ func TestQLogFile(t *testing.T) { // Try reading again (there's nothing to read anymore). line, err = q.ReadNext() require.Equal(t, io.EOF, err) + assert.Empty(t, line) } func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) { f, err := os.CreateTemp(t.TempDir(), "*.txt") require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, f.Close) _, err = f.WriteString(data) @@ -295,6 +313,7 @@ func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) { file, err = newQLogFile(f.Name()) require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, file.Close) return file @@ -308,6 +327,9 @@ func TestQLog_Seek(t *testing.T) { `{"T":"` + strV + `"}` + nl timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00") + logger := slogutil.NewDiscardLogger() + ctx := testutil.ContextWithTimeout(t, testTimeout) + testCases := []struct { wantErr error name string @@ -340,8 +362,10 @@ func TestQLog_Seek(t *testing.T) { q := newTestQLogFileData(t, data) - _, depth, err := q.seekTS(timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano()) + ts := timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano() + _, depth, err := q.seekTS(ctx, logger, ts) require.Truef(t, errors.Is(err, tc.wantErr), "%v", err) + assert.Equal(t, tc.wantDepth, depth) }) } diff --git a/internal/querylog/qlogreader.go b/internal/querylog/qlogreader.go index a222323e..dad4a594 100644 --- a/internal/querylog/qlogreader.go +++ b/internal/querylog/qlogreader.go @@ -1,12 +1,14 @@ package querylog import ( + "context" "fmt" "io" + "log/slog" "os" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" ) // qLogReader allows reading from multiple query log files in the reverse @@ -16,6 +18,10 @@ import ( // pointer to a particular query log file, and to a specific position in this // file, and it reads lines in reverse order starting from that position. type qLogReader struct { + // logger is used for logging the operation of the query log reader. It + // must not be nil. + logger *slog.Logger + // qFiles is an array with the query log files. The order is from oldest // to newest. qFiles []*qLogFile @@ -25,7 +31,7 @@ type qLogReader struct { } // newQLogReader initializes a qLogReader instance with the specified files. -func newQLogReader(files []string) (*qLogReader, error) { +func newQLogReader(ctx context.Context, logger *slog.Logger, files []string) (*qLogReader, error) { qFiles := make([]*qLogFile, 0) for _, f := range files { @@ -38,7 +44,7 @@ func newQLogReader(files []string) (*qLogReader, error) { // Close what we've already opened. cErr := closeQFiles(qFiles) if cErr != nil { - log.Debug("querylog: closing files: %s", cErr) + logger.DebugContext(ctx, "closing files", slogutil.KeyError, cErr) } return nil, err @@ -47,16 +53,20 @@ func newQLogReader(files []string) (*qLogReader, error) { qFiles = append(qFiles, q) } - return &qLogReader{qFiles: qFiles, currentFile: len(qFiles) - 1}, nil + return &qLogReader{ + logger: logger, + qFiles: qFiles, + currentFile: len(qFiles) - 1, + }, nil } // seekTS performs binary search of a query log record with the specified // timestamp. If the record is found, it sets qLogReader's position to point // to that line, so that the next ReadNext call returned this line. -func (r *qLogReader) seekTS(timestamp int64) (err error) { +func (r *qLogReader) seekTS(ctx context.Context, timestamp int64) (err error) { for i := len(r.qFiles) - 1; i >= 0; i-- { q := r.qFiles[i] - _, _, err = q.seekTS(timestamp) + _, _, err = q.seekTS(ctx, r.logger, timestamp) if err != nil { if errors.Is(err, errTSTooEarly) { // Look at the next file, since we've reached the end of this diff --git a/internal/querylog/qlogreader_test.go b/internal/querylog/qlogreader_test.go index 43bb3d5c..bb3ce164 100644 --- a/internal/querylog/qlogreader_test.go +++ b/internal/querylog/qlogreader_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,8 +18,11 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *qLogReader testFiles := prepareTestFiles(t, filesNum, linesNum) + logger := slogutil.NewDiscardLogger() + ctx := testutil.ContextWithTimeout(t, testTimeout) + // Create the new qLogReader instance. - reader, err := newQLogReader(testFiles) + reader, err := newQLogReader(ctx, logger, testFiles) require.NoError(t, err) assert.NotNil(t, reader) @@ -73,6 +77,7 @@ func TestQLogReader(t *testing.T) { func TestQLogReader_Seek(t *testing.T) { r := newTestQLogReader(t, 2, 10000) + ctx := testutil.ContextWithTimeout(t, testTimeout) testCases := []struct { want error @@ -113,7 +118,7 @@ func TestQLogReader_Seek(t *testing.T) { ts, err := time.Parse(time.RFC3339Nano, tc.time) require.NoError(t, err) - err = r.seekTS(ts.UnixNano()) + err = r.seekTS(ctx, ts.UnixNano()) assert.ErrorIs(t, err, tc.want) }) } diff --git a/internal/querylog/querylog.go b/internal/querylog/querylog.go index bccc264a..c7350f70 100644 --- a/internal/querylog/querylog.go +++ b/internal/querylog/querylog.go @@ -2,6 +2,7 @@ package querylog import ( "fmt" + "log/slog" "net" "path/filepath" "sync" @@ -12,20 +13,19 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/service" "github.com/miekg/dns" ) -// QueryLog - main interface +// QueryLog is the query log interface for use by other packages. type QueryLog interface { - Start() + // Interface starts and stops the query log. + service.Interface - // Close query log object - Close() - - // Add a log entry + // Add adds a log entry. Add(params *AddParams) - // WriteDiskConfig - write configuration + // WriteDiskConfig writes the query log configuration to c. WriteDiskConfig(c *Config) // ShouldLog returns true if request for the host should be logged. @@ -36,6 +36,10 @@ type QueryLog interface { // // Do not alter any fields of this structure after using it. type Config struct { + // Logger is used for logging the operation of the query log. It must not + // be nil. + Logger *slog.Logger + // Ignored contains the list of host names, which should not be written to // log, and matches them. Ignored *aghnet.IgnoreEngine @@ -151,6 +155,7 @@ func newQueryLog(conf Config) (l *queryLog, err error) { } l = &queryLog{ + logger: conf.Logger, findClient: findClient, buffer: container.NewRingBuffer[*logEntry](memSize), diff --git a/internal/querylog/querylogfile.go b/internal/querylog/querylogfile.go index 6b4760c0..84da97cf 100644 --- a/internal/querylog/querylogfile.go +++ b/internal/querylog/querylogfile.go @@ -2,6 +2,7 @@ package querylog import ( "bytes" + "context" "encoding/json" "fmt" "os" @@ -9,28 +10,30 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/c2h5oh/datasize" ) // flushLogBuffer flushes the current buffer to file and resets the current // buffer. -func (l *queryLog) flushLogBuffer() (err error) { +func (l *queryLog) flushLogBuffer(ctx context.Context) (err error) { defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }() + l.fileFlushLock.Lock() defer l.fileFlushLock.Unlock() - b, err := l.encodeEntries() + b, err := l.encodeEntries(ctx) if err != nil { // Don't wrap the error since it's informative enough as is. return err } - return l.flushToFile(b) + return l.flushToFile(ctx, b) } // encodeEntries returns JSON encoded log entries, logs estimated time, clears // the log buffer. -func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) { +func (l *queryLog) encodeEntries(ctx context.Context) (b *bytes.Buffer, err error) { l.bufferLock.Lock() defer l.bufferLock.Unlock() @@ -55,8 +58,17 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) { return nil, err } + size := b.Len() elapsed := time.Since(start) - log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", bufLen, elapsed, b.Len()/1024, float64(b.Len())/float64(bufLen), elapsed/time.Duration(bufLen)) + l.logger.DebugContext( + ctx, + "serialized elements via json", + "count", bufLen, + "elapsed", elapsed, + "size", datasize.ByteSize(size), + "size_per_entry", datasize.ByteSize(float64(size)/float64(bufLen)), + "time_per_entry", elapsed/time.Duration(bufLen), + ) l.buffer.Clear() l.flushPending = false @@ -65,7 +77,7 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) { } // flushToFile saves the encoded log entries to the query log file. -func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) { +func (l *queryLog) flushToFile(ctx context.Context, b *bytes.Buffer) (err error) { l.fileWriteLock.Lock() defer l.fileWriteLock.Unlock() @@ -83,19 +95,19 @@ func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) { return fmt.Errorf("writing to file %q: %w", filename, err) } - log.Debug("querylog: ok %q: %v bytes written", filename, n) + l.logger.DebugContext(ctx, "flushed to file", "file", filename, "size", datasize.ByteSize(n)) return nil } -func (l *queryLog) rotate() error { +func (l *queryLog) rotate(ctx context.Context) error { from := l.logFile to := l.logFile + ".1" err := os.Rename(from, to) if err != nil { if errors.Is(err, os.ErrNotExist) { - log.Debug("querylog: no log to rotate") + l.logger.DebugContext(ctx, "no log to rotate") return nil } @@ -103,12 +115,12 @@ func (l *queryLog) rotate() error { return fmt.Errorf("failed to rename old file: %w", err) } - log.Debug("querylog: renamed %s into %s", from, to) + l.logger.DebugContext(ctx, "renamed log file", "from", from, "to", to) return nil } -func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) { +func (l *queryLog) readFileFirstTimeValue(ctx context.Context) (first time.Time, err error) { var f *os.File f, err = os.Open(l.logFile) if err != nil { @@ -130,15 +142,15 @@ func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) { return time.Time{}, err } - log.Debug("querylog: the oldest log entry: %s", val) + l.logger.DebugContext(ctx, "oldest log entry", "entry_time", val) return t, nil } -func (l *queryLog) periodicRotate() { - defer log.OnPanic("querylog: rotating") +func (l *queryLog) periodicRotate(ctx context.Context) { + defer slogutil.RecoverAndLog(ctx, l.logger) - l.checkAndRotate() + l.checkAndRotate(ctx) // rotationCheckIvl is the period of time between checking the need for // rotating log files. It's smaller of any available rotation interval to @@ -151,13 +163,13 @@ func (l *queryLog) periodicRotate() { defer rotations.Stop() for range rotations.C { - l.checkAndRotate() + l.checkAndRotate(ctx) } } // checkAndRotate rotates log files if those are older than the specified // rotation interval. -func (l *queryLog) checkAndRotate() { +func (l *queryLog) checkAndRotate(ctx context.Context) { var rotationIvl time.Duration func() { l.confMu.RLock() @@ -166,29 +178,30 @@ func (l *queryLog) checkAndRotate() { rotationIvl = l.conf.RotationIvl }() - oldest, err := l.readFileFirstTimeValue() + oldest, err := l.readFileFirstTimeValue(ctx) if err != nil && !errors.Is(err, os.ErrNotExist) { - log.Error("querylog: reading oldest record for rotation: %s", err) + l.logger.ErrorContext(ctx, "reading oldest record for rotation", slogutil.KeyError, err) return } if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) { - log.Debug( - "querylog: %s <= %s, not rotating", - now.Format(time.RFC3339), - rotTime.Format(time.RFC3339), + l.logger.DebugContext( + ctx, + "not rotating", + "now", now.Format(time.RFC3339), + "rotate_time", rotTime.Format(time.RFC3339), ) return } - err = l.rotate() + err = l.rotate(ctx) if err != nil { - log.Error("querylog: rotating: %s", err) + l.logger.ErrorContext(ctx, "rotating", slogutil.KeyError, err) return } - log.Debug("querylog: rotated successfully") + l.logger.DebugContext(ctx, "rotated successfully") } diff --git a/internal/querylog/search.go b/internal/querylog/search.go index f8c8f90e..153d81a4 100644 --- a/internal/querylog/search.go +++ b/internal/querylog/search.go @@ -1,13 +1,15 @@ package querylog import ( + "context" "fmt" "io" + "log/slog" "slices" "time" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" ) // client finds the client info, if any, by its ClientID and IP address, @@ -48,7 +50,11 @@ func (l *queryLog) client(clientID, ip string, cache clientCache) (c *Client, er // buffer. It optionally uses the client cache, if provided. It also returns // the total amount of records in the buffer at the moment of searching. // l.confMu is expected to be locked. -func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entries []*logEntry, total int) { +func (l *queryLog) searchMemory( + ctx context.Context, + params *searchParams, + cache clientCache, +) (entries []*logEntry, total int) { // Check memory size, as the buffer can contain a single log record. See // [newQueryLog]. if l.conf.MemSize == 0 { @@ -66,9 +72,14 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie var err error e.client, err = l.client(e.ClientID, e.IP.String(), cache) if err != nil { - msg := "querylog: enriching memory record at time %s" + - " for client %q (clientid %q): %s" - log.Error(msg, e.Time, e.IP, e.ClientID, err) + l.logger.ErrorContext( + ctx, + "enriching memory record", + "at", e.Time, + "client_ip", e.IP, + "client_id", e.ClientID, + slogutil.KeyError, err, + ) // Go on and try to match anyway. } @@ -86,7 +97,10 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie // search searches log entries in memory buffer and log file using specified // parameters and returns the list of entries found and the time of the oldest // entry. l.confMu is expected to be locked. -func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest time.Time) { +func (l *queryLog) search( + ctx context.Context, + params *searchParams, +) (entries []*logEntry, oldest time.Time) { start := time.Now() if params.limit == 0 { @@ -95,11 +109,11 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim cache := clientCache{} - memoryEntries, bufLen := l.searchMemory(params, cache) - log.Debug("querylog: got %d entries from memory", len(memoryEntries)) + memoryEntries, bufLen := l.searchMemory(ctx, params, cache) + l.logger.DebugContext(ctx, "got entries from memory", "count", len(memoryEntries)) - fileEntries, oldest, total := l.searchFiles(params, cache) - log.Debug("querylog: got %d entries from files", len(fileEntries)) + fileEntries, oldest, total := l.searchFiles(ctx, params, cache) + l.logger.DebugContext(ctx, "got entries from files", "count", len(fileEntries)) total += bufLen @@ -134,12 +148,13 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim oldest = entries[len(entries)-1].Time } - log.Debug( - "querylog: prepared data (%d/%d) older than %s in %s", - len(entries), - total, - params.olderThan, - time.Since(start), + l.logger.DebugContext( + ctx, + "prepared data", + "count", len(entries), + "total", total, + "older_than", params.olderThan, + "elapsed", time.Since(start), ) return entries, oldest @@ -147,12 +162,12 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim // seekRecord changes the current position to the next record older than the // provided parameter. -func (r *qLogReader) seekRecord(olderThan time.Time) (err error) { +func (r *qLogReader) seekRecord(ctx context.Context, olderThan time.Time) (err error) { if olderThan.IsZero() { return r.SeekStart() } - err = r.seekTS(olderThan.UnixNano()) + err = r.seekTS(ctx, olderThan.UnixNano()) if err == nil { // Read to the next record, because we only need the one that goes // after it. @@ -164,21 +179,24 @@ func (r *qLogReader) seekRecord(olderThan time.Time) (err error) { // setQLogReader creates a reader with the specified files and sets the // position to the next record older than the provided parameter. -func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error) { +func (l *queryLog) setQLogReader( + ctx context.Context, + olderThan time.Time, +) (qr *qLogReader, err error) { files := []string{ l.logFile + ".1", l.logFile, } - r, err := newQLogReader(files) + r, err := newQLogReader(ctx, l.logger, files) if err != nil { - return nil, fmt.Errorf("opening qlog reader: %s", err) + return nil, fmt.Errorf("opening qlog reader: %w", err) } - err = r.seekRecord(olderThan) + err = r.seekRecord(ctx, olderThan) if err != nil { defer func() { err = errors.WithDeferred(err, r.Close()) }() - log.Debug("querylog: cannot seek to %s: %s", olderThan, err) + l.logger.DebugContext(ctx, "cannot seek", "older_than", olderThan, slogutil.KeyError, err) return nil, nil } @@ -191,13 +209,14 @@ func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error // calls faster so that the UI could handle it and show something quicker. // This behavior can be overridden if maxFileScanEntries is set to 0. func (l *queryLog) readEntries( + ctx context.Context, r *qLogReader, params *searchParams, cache clientCache, totalLimit int, ) (entries []*logEntry, oldestNano int64, total int) { for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 { - ent, ts, rErr := l.readNextEntry(r, params, cache) + ent, ts, rErr := l.readNextEntry(ctx, r, params, cache) if rErr != nil { if rErr == io.EOF { oldestNano = 0 @@ -205,7 +224,7 @@ func (l *queryLog) readEntries( break } - log.Error("querylog: reading next entry: %s", rErr) + l.logger.ErrorContext(ctx, "reading next entry", slogutil.KeyError, rErr) } oldestNano = ts @@ -231,12 +250,13 @@ func (l *queryLog) readEntries( // and the total number of processed entries, including discarded ones, // correspondingly. func (l *queryLog) searchFiles( + ctx context.Context, params *searchParams, cache clientCache, ) (entries []*logEntry, oldest time.Time, total int) { - r, err := l.setQLogReader(params.olderThan) + r, err := l.setQLogReader(ctx, params.olderThan) if err != nil { - log.Error("querylog: %s", err) + l.logger.ErrorContext(ctx, "searching files", slogutil.KeyError, err) } if r == nil { @@ -245,12 +265,12 @@ func (l *queryLog) searchFiles( defer func() { if closeErr := r.Close(); closeErr != nil { - log.Error("querylog: closing file: %s", closeErr) + l.logger.ErrorContext(ctx, "closing files", slogutil.KeyError, closeErr) } }() totalLimit := params.offset + params.limit - entries, oldestNano, total := l.readEntries(r, params, cache, totalLimit) + entries, oldestNano, total := l.readEntries(ctx, r, params, cache, totalLimit) if oldestNano != 0 { oldest = time.Unix(0, oldestNano) } @@ -266,15 +286,21 @@ type quickMatchClientFinder struct { } // findClient is a method that can be used as a quickMatchClientFinder. -func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) { +func (f quickMatchClientFinder) findClient( + ctx context.Context, + logger *slog.Logger, + clientID string, + ip string, +) (c *Client) { var err error c, err = f.client(clientID, ip, f.cache) if err != nil { - log.Error( - "querylog: enriching file record for quick search: for client %q (clientid %q): %s", - ip, - clientID, - err, + logger.ErrorContext( + ctx, + "enriching file record for quick search", + "client_ip", ip, + "client_id", clientID, + slogutil.KeyError, err, ) } @@ -286,6 +312,7 @@ func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) { // the entry doesn't match the search criteria. ts is the timestamp of the // processed entry. func (l *queryLog) readNextEntry( + ctx context.Context, r *qLogReader, params *searchParams, cache clientCache, @@ -301,14 +328,14 @@ func (l *queryLog) readNextEntry( cache: cache, } - if !params.quickMatch(line, clientFinder.findClient) { - ts = readQLogTimestamp(line) + if !params.quickMatch(ctx, l.logger, line, clientFinder.findClient) { + ts = readQLogTimestamp(ctx, l.logger, line) return nil, ts, nil } e = &logEntry{} - decodeLogEntry(e, line) + l.decodeLogEntry(ctx, e, line) if l.isIgnored(e.QHost) { return nil, ts, nil @@ -316,12 +343,13 @@ func (l *queryLog) readNextEntry( e.client, err = l.client(e.ClientID, e.IP.String(), cache) if err != nil { - log.Error( - "querylog: enriching file record at time %s for client %q (clientid %q): %s", - e.Time, - e.IP, - e.ClientID, - err, + l.logger.ErrorContext( + ctx, + "enriching file record", + "at", e.Time, + "client_ip", e.IP, + "client_id", e.ClientID, + slogutil.KeyError, err, ) // Go on and try to match anyway. diff --git a/internal/querylog/search_test.go b/internal/querylog/search_test.go index 939942ad..7bc97f70 100644 --- a/internal/querylog/search_test.go +++ b/internal/querylog/search_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -36,6 +38,7 @@ func TestQueryLog_Search_findClient(t *testing.T) { } l, err := newQueryLog(Config{ + Logger: slogutil.NewDiscardLogger(), FindClient: findClient, BaseDir: t.TempDir(), RotationIvl: timeutil.Day, @@ -45,7 +48,11 @@ func TestQueryLog_Search_findClient(t *testing.T) { AnonymizeClientIP: false, }) require.NoError(t, err) - t.Cleanup(l.Close) + + ctx := testutil.ContextWithTimeout(t, testTimeout) + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return l.Shutdown(ctx) + }) q := &dns.Msg{ Question: []dns.Question{{ @@ -81,7 +88,7 @@ func TestQueryLog_Search_findClient(t *testing.T) { olderThan: time.Now().Add(10 * time.Second), limit: 3, } - entries, _ := l.search(sp) + entries, _ := l.search(ctx, sp) assert.Equal(t, 2, findClientCalls) require.Len(t, entries, 3) diff --git a/internal/querylog/searchcriterion.go b/internal/querylog/searchcriterion.go index a8942a83..7397fc55 100644 --- a/internal/querylog/searchcriterion.go +++ b/internal/querylog/searchcriterion.go @@ -1,7 +1,9 @@ package querylog import ( + "context" "fmt" + "log/slog" "strings" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -87,7 +89,12 @@ func ctDomainOrClientCaseNonStrict( // quickMatch quickly checks if the line matches the given search criterion. // It returns false if the like doesn't match. This method is only here for // optimization purposes. -func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) { +func (c *searchCriterion) quickMatch( + ctx context.Context, + logger *slog.Logger, + line string, + findClient quickMatchClientFunc, +) (ok bool) { switch c.criterionType { case ctTerm: host := readJSONValue(line, `"QH":"`) @@ -95,7 +102,7 @@ func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFun clientID := readJSONValue(line, `"CID":"`) var name string - if cli := findClient(clientID, ip); cli != nil { + if cli := findClient(ctx, logger, clientID, ip); cli != nil { name = cli.Name } diff --git a/internal/querylog/searchparams.go b/internal/querylog/searchparams.go index a0a0ff6c..f0d6a0c6 100644 --- a/internal/querylog/searchparams.go +++ b/internal/querylog/searchparams.go @@ -1,6 +1,10 @@ package querylog -import "time" +import ( + "context" + "log/slog" + "time" +) // searchParams represent the search query sent by the client. type searchParams struct { @@ -35,14 +39,23 @@ func newSearchParams() *searchParams { } // quickMatchClientFunc is a simplified client finder for quick matches. -type quickMatchClientFunc = func(clientID, ip string) (c *Client) +type quickMatchClientFunc = func( + ctx context.Context, + logger *slog.Logger, + clientID, ip string, +) (c *Client) // quickMatch quickly checks if the line matches the given search parameters. // It returns false if the line doesn't match. This method is only here for // optimization purposes. -func (s *searchParams) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) { +func (s *searchParams) quickMatch( + ctx context.Context, + logger *slog.Logger, + line string, + findClient quickMatchClientFunc, +) (ok bool) { for _, c := range s.searchCriteria { - if !c.quickMatch(line, findClient) { + if !c.quickMatch(ctx, logger, line, findClient) { return false } }