all: slog querylog

This commit is contained in:
Stanislav Chzhen 2024-11-15 15:57:07 +03:00
parent 1d6d85cff4
commit 8d88d79930
18 changed files with 462 additions and 224 deletions

View File

@ -73,6 +73,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
} }
conf := querylog.Config{ conf := querylog.Config{
Logger: l.With(slogutil.KeyPrefix, "querylog"),
Anonymizer: anonymizer, Anonymizer: anonymizer,
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
HTTPRegister: httpRegister, HTTPRegister: httpRegister,
@ -457,7 +458,8 @@ func startDNSServer() error {
Context.filters.EnableFilters(false) Context.filters.EnableFilters(false)
// TODO(s.chzhen): Pass context. // TODO(s.chzhen): Pass context.
err := Context.clients.Start(context.TODO()) ctx := context.TODO()
err := Context.clients.Start(ctx)
if err != nil { if err != nil {
return fmt.Errorf("starting clients container: %w", err) return fmt.Errorf("starting clients container: %w", err)
} }
@ -469,7 +471,7 @@ func startDNSServer() error {
Context.filters.Start() Context.filters.Start()
Context.stats.Start() Context.stats.Start()
Context.queryLog.Start() Context.queryLog.Start(ctx)
return nil return nil
} }
@ -530,7 +532,8 @@ func closeDNSServer() {
} }
if Context.queryLog != nil { if Context.queryLog != nil {
Context.queryLog.Close() // TODO(s.chzhen): Pass context.
Context.queryLog.Close(context.TODO())
} }
log.Debug("all dns modules are closed") log.Debug("all dns modules are closed")

View File

@ -1,6 +1,7 @@
package querylog package querylog
import ( import (
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -13,7 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -174,26 +175,32 @@ var logEntryHandlers = map[string]logEntryHandler{
} }
// decodeResultRuleKey decodes the token of "Rules" type to logEntry struct. // decodeResultRuleKey decodes the token of "Rules" type to logEntry struct.
func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultRuleKey(
ctx context.Context,
key string,
i int,
dec *json.Decoder,
ent *logEntry,
) {
var vToken json.Token var vToken json.Token
switch key { switch key {
case "FilterListID": case "FilterListID":
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
if n, ok := vToken.(json.Number); ok { if n, ok := vToken.(json.Number); ok {
id, _ := n.Int64() id, _ := n.Int64()
ent.Result.Rules[i].FilterListID = rulelist.URLFilterID(id) ent.Result.Rules[i].FilterListID = rulelist.URLFilterID(id)
} }
case "IP": case "IP":
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
if ipStr, ok := vToken.(string); ok { if ipStr, ok := vToken.(string); ok {
if ip, err := netip.ParseAddr(ipStr); err == nil { if ip, err := netip.ParseAddr(ipStr); err == nil {
ent.Result.Rules[i].IP = ip ent.Result.Rules[i].IP = ip
} else { } else {
log.Debug("querylog: decoding ipStr value: %s", err) l.logger.DebugContext(ctx, "decoding ip", "value", ipStr, slogutil.KeyError, err)
} }
} }
case "Text": case "Text":
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
if s, ok := vToken.(string); ok { if s, ok := vToken.(string); ok {
ent.Result.Rules[i].Text = s ent.Result.Rules[i].Text = s
} }
@ -204,7 +211,8 @@ func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
// decodeVTokenAndAddRule decodes the "Rules" toke as [filtering.ResultRule] // decodeVTokenAndAddRule decodes the "Rules" toke as [filtering.ResultRule]
// and then adds the decoded object to the slice of result rules. // and then adds the decoded object to the slice of result rules.
func decodeVTokenAndAddRule( func (l *queryLog) decodeVTokenAndAddRule(
ctx context.Context,
key string, key string,
i int, i int,
dec *json.Decoder, dec *json.Decoder,
@ -215,7 +223,7 @@ func decodeVTokenAndAddRule(
vToken, err := dec.Token() vToken, err := dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeResultRuleKey %s err: %s", key, err) l.logger.DebugContext(ctx, "decodeResultRuleKey", "key", key, slogutil.KeyError, err)
} }
return newRules, nil return newRules, nil
@ -230,12 +238,12 @@ func decodeVTokenAndAddRule(
// decodeResultRules parses the dec's tokens into logEntry ent interpreting it // decodeResultRules parses the dec's tokens into logEntry ent interpreting it
// as a slice of the result rules. // as a slice of the result rules.
func decodeResultRules(dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultRules(ctx context.Context, dec *json.Decoder, ent *logEntry) {
for { for {
delimToken, err := dec.Token() delimToken, err := dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeResultRules err: %s", err) l.logger.DebugContext(ctx, "decodeResultRules", slogutil.KeyError, err)
} }
return return
@ -244,13 +252,17 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) {
if d, ok := delimToken.(json.Delim); !ok { if d, ok := delimToken.(json.Delim); !ok {
return return
} else if d != '[' { } else if d != '[' {
log.Debug("decodeResultRules: unexpected delim %q", d) l.logger.DebugContext(
ctx,
"decodeResultRules",
slogutil.KeyError, unexpectedDelimiterError(d),
)
} }
err = decodeResultRuleToken(dec, ent) err = l.decodeResultRuleToken(ctx, dec, ent)
if err != nil { if err != nil {
if err != io.EOF && !errors.Is(err, ErrEndOfToken) { if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
log.Debug("decodeResultRules err: %s", err) l.logger.DebugContext(ctx, "decodeResultRules", slogutil.KeyError, err)
} }
return return
@ -259,7 +271,11 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) {
} }
// decodeResultRuleToken decodes the tokens of "Rules" type to the logEntry ent. // decodeResultRuleToken decodes the tokens of "Rules" type to the logEntry ent.
func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) { func (l *queryLog) decodeResultRuleToken(
ctx context.Context,
dec *json.Decoder,
ent *logEntry,
) (err error) {
i := 0 i := 0
for { for {
var keyToken json.Token var keyToken json.Token
@ -287,7 +303,7 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
return fmt.Errorf("keyToken is %T (%[1]v) and not string", keyToken) return fmt.Errorf("keyToken is %T (%[1]v) and not string", keyToken)
} }
decodeResultRuleKey(key, i, dec, ent) l.decodeResultRuleKey(ctx, key, i, dec, ent)
} }
} }
@ -296,12 +312,12 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
// other occurrences of DNSRewriteResult in the entry since hosts container's // other occurrences of DNSRewriteResult in the entry since hosts container's
// rewrites currently has the highest priority along the entire filtering // rewrites currently has the highest priority along the entire filtering
// pipeline. // pipeline.
func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultReverseHosts(ctx context.Context, dec *json.Decoder, ent *logEntry) {
for { for {
itemToken, err := dec.Token() itemToken, err := dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeResultReverseHosts err: %s", err) l.logger.DebugContext(ctx, "decodeResultReverseHosts", slogutil.KeyError, err)
} }
return return
@ -315,7 +331,11 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
return return
} }
log.Debug("decodeResultReverseHosts: unexpected delim %q", v) l.logger.DebugContext(
ctx,
"decodeResultReverseHosts",
slogutil.KeyError, unexpectedDelimiterError(v),
)
return return
case string: case string:
@ -346,12 +366,12 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
// decodeResultIPList parses the dec's tokens into logEntry ent interpreting it // decodeResultIPList parses the dec's tokens into logEntry ent interpreting it
// as the result IP addresses list. // as the result IP addresses list.
func decodeResultIPList(dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultIPList(ctx context.Context, dec *json.Decoder, ent *logEntry) {
for { for {
itemToken, err := dec.Token() itemToken, err := dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeResultIPList err: %s", err) l.logger.DebugContext(ctx, "decodeResultIPList", slogutil.KeyError, err)
} }
return return
@ -365,7 +385,11 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
return return
} }
log.Debug("decodeResultIPList: unexpected delim %q", v) l.logger.DebugContext(
ctx,
"decodeResultIPList",
slogutil.KeyError, unexpectedDelimiterError(v),
)
return return
case string: case string:
@ -382,7 +406,12 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
// decodeResultDNSRewriteResultKey decodes the token of "DNSRewriteResult" type // decodeResultDNSRewriteResultKey decodes the token of "DNSRewriteResult" type
// to the logEntry struct. // to the logEntry struct.
func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultDNSRewriteResultKey(
ctx context.Context,
key string,
dec *json.Decoder,
ent *logEntry,
) {
var err error var err error
switch key { switch key {
@ -391,7 +420,11 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
vToken, err = dec.Token() vToken, err = dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeResultDNSRewriteResultKey err: %s", err) l.logger.DebugContext(
ctx,
"decodeResultDNSRewriteResultKey",
slogutil.KeyError, err,
)
} }
return return
@ -419,7 +452,11 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
// decoding and correct the values. // decoding and correct the values.
err = dec.Decode(&ent.Result.DNSRewriteResult.Response) err = dec.Decode(&ent.Result.DNSRewriteResult.Response)
if err != nil { if err != nil {
log.Debug("decodeResultDNSRewriteResultKey response err: %s", err) l.logger.DebugContext(
ctx,
"decodeResultDNSRewriteResultKey response",
slogutil.KeyError, err,
)
} }
ent.parseDNSRewriteResultIPs() ent.parseDNSRewriteResultIPs()
@ -430,12 +467,16 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
// decodeResultDNSRewriteResult parses the dec's tokens into logEntry ent // decodeResultDNSRewriteResult parses the dec's tokens into logEntry ent
// interpreting it as the result DNSRewriteResult. // interpreting it as the result DNSRewriteResult.
func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultDNSRewriteResult(
ctx context.Context,
dec *json.Decoder,
ent *logEntry,
) {
for { for {
key, err := parseKeyToken(dec) key, err := parseKeyToken(dec)
if err != nil { if err != nil {
if err != io.EOF && !errors.Is(err, ErrEndOfToken) { if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
log.Debug("decodeResultDNSRewriteResult: %s", err) l.logger.DebugContext(ctx, "decodeResultDNSRewriteResult", slogutil.KeyError, err)
} }
return return
@ -445,7 +486,7 @@ func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
continue continue
} }
decodeResultDNSRewriteResultKey(key, dec, ent) l.decodeResultDNSRewriteResultKey(ctx, key, dec, ent)
} }
} }
@ -508,14 +549,14 @@ func parseKeyToken(dec *json.Decoder) (key string, err error) {
} }
// decodeResult decodes a token of "Result" type to logEntry struct. // decodeResult decodes a token of "Result" type to logEntry struct.
func decodeResult(dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResult(ctx context.Context, dec *json.Decoder, ent *logEntry) {
defer translateResult(ent) defer translateResult(ent)
for { for {
key, err := parseKeyToken(dec) key, err := parseKeyToken(dec)
if err != nil { if err != nil {
if err != io.EOF && !errors.Is(err, ErrEndOfToken) { if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
log.Debug("decodeResult: %s", err) l.logger.DebugContext(ctx, "decodeResult", slogutil.KeyError, err)
} }
return return
@ -525,10 +566,8 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
continue continue
} }
decHandler, ok := resultDecHandlers[key] ok := l.resultDecHandler(ctx, key, dec, ent)
if ok { if ok {
decHandler(dec, ent)
continue continue
} }
@ -543,7 +582,7 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
} }
if err = handler(val, ent); err != nil { if err = handler(val, ent); err != nil {
log.Debug("decodeResult handler err: %s", err) l.logger.DebugContext(ctx, "decodeResult handler", slogutil.KeyError, err)
return return
} }
@ -636,16 +675,32 @@ var resultHandlers = map[string]logEntryHandler{
}, },
} }
// resultDecHandlers is the map of decode handlers for various keys. // resultDecHandlers calls a decode handler for key if there is one.
var resultDecHandlers = map[string]func(dec *json.Decoder, ent *logEntry){ func (l *queryLog) resultDecHandler(
"ReverseHosts": decodeResultReverseHosts, ctx context.Context,
"IPList": decodeResultIPList, name string,
"Rules": decodeResultRules, dec *json.Decoder,
"DNSRewriteResult": decodeResultDNSRewriteResult, ent *logEntry,
) (found bool) {
notFound := false
switch name {
case "ReverseHosts":
l.decodeResultReverseHosts(ctx, dec, ent)
case "IPList":
l.decodeResultIPList(ctx, dec, ent)
case "Rules":
l.decodeResultRules(ctx, dec, ent)
case "DNSRewriteResult":
l.decodeResultDNSRewriteResult(ctx, dec, ent)
default:
notFound = true
}
return !notFound
} }
// decodeLogEntry decodes string str to logEntry ent. // decodeLogEntry decodes string str to logEntry ent.
func decodeLogEntry(ent *logEntry, str string) { func (l *queryLog) decodeLogEntry(ctx context.Context, ent *logEntry, str string) {
dec := json.NewDecoder(strings.NewReader(str)) dec := json.NewDecoder(strings.NewReader(str))
dec.UseNumber() dec.UseNumber()
@ -653,7 +708,7 @@ func decodeLogEntry(ent *logEntry, str string) {
keyToken, err := dec.Token() keyToken, err := dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeLogEntry err: %s", err) l.logger.DebugContext(ctx, "decodeLogEntry err", slogutil.KeyError, err)
} }
return return
@ -665,13 +720,14 @@ func decodeLogEntry(ent *logEntry, str string) {
key, ok := keyToken.(string) key, ok := keyToken.(string)
if !ok { if !ok {
log.Debug("decodeLogEntry: keyToken is %T (%[1]v) and not string", keyToken) err = fmt.Errorf("decodeLogEntry: keyToken is %T (%[1]v) and not string", keyToken)
l.logger.DebugContext(ctx, "decodeLogEntry", slogutil.KeyError, err)
return return
} }
if key == "Result" { if key == "Result" {
decodeResult(dec, ent) l.decodeResult(ctx, dec, ent)
continue continue
} }
@ -687,9 +743,14 @@ func decodeLogEntry(ent *logEntry, str string) {
} }
if err = handler(val, ent); err != nil { if err = handler(val, ent); err != nil {
log.Debug("decodeLogEntry handler err: %s", err) l.logger.DebugContext(ctx, "decodeLogEntry handler", slogutil.KeyError, err)
return return
} }
} }
} }
// unexpectedDelimiterError is a helper for creating informative errors.
func unexpectedDelimiterError(d json.Delim) (err error) {
return fmt.Errorf("unexpected delimiter: %q", d)
}

View File

@ -3,27 +3,35 @@ package querylog
import ( import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"strings"
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// Common constants for tests.
const testTimeout = 1 * time.Second
func TestDecodeLogEntry(t *testing.T) { func TestDecodeLogEntry(t *testing.T) {
logOutput := &bytes.Buffer{} logOutput := &bytes.Buffer{}
l := &queryLog{
logger: slog.New(slog.NewTextHandler(logOutput, &slog.HandlerOptions{
Level: slog.LevelDebug,
ReplaceAttr: slogutil.RemoveTime,
})),
}
aghtest.ReplaceLogWriter(t, logOutput) ctx := testutil.ContextWithTimeout(t, testTimeout)
aghtest.ReplaceLogLevel(t, log.DEBUG)
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==` const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==`
@ -92,7 +100,7 @@ func TestDecodeLogEntry(t *testing.T) {
} }
got := &logEntry{} got := &logEntry{}
decodeLogEntry(got, data) l.decodeLogEntry(ctx, got, data)
s := logOutput.String() s := logOutput.String()
assert.Empty(t, s) assert.Empty(t, s)
@ -113,11 +121,11 @@ func TestDecodeLogEntry(t *testing.T) {
}, { }, {
name: "bad_filter_id_old_rule", name: "bad_filter_id_old_rule",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"FilterID":1.5},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"FilterID":1.5},"Elapsed":837429}`,
want: "decodeResult handler err: strconv.ParseInt: parsing \"1.5\": invalid syntax\n", want: `level=DEBUG msg="decodeResult handler" err="strconv.ParseInt: parsing \"1.5\": invalid syntax"`,
}, { }, {
name: "bad_is_filtered", name: "bad_is_filtered",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":trooe,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":trooe,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry err: invalid character 'o' in literal true (expecting 'u')\n", want: `level=DEBUG msg="decodeLogEntry err" err="invalid character 'o' in literal true (expecting 'u')"`,
}, { }, {
name: "bad_elapsed", name: "bad_elapsed",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":-1}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":-1}`,
@ -129,7 +137,7 @@ func TestDecodeLogEntry(t *testing.T) {
}, { }, {
name: "bad_time", name: "bad_time",
log: `{"IP":"127.0.0.1","T":"12/09/1998T15:00:00.000000+05:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"12/09/1998T15:00:00.000000+05:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry handler err: parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\"\n", want: `level=DEBUG msg="decodeLogEntry handler" err="parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\""`,
}, { }, {
name: "bad_host", name: "bad_host",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":6,"QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":6,"QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
@ -149,7 +157,7 @@ func TestDecodeLogEntry(t *testing.T) {
}, { }, {
name: "very_bad_client_proto", name: "very_bad_client_proto",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"dog","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"dog","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry handler err: invalid client proto: \"dog\"\n", want: `level=DEBUG msg="decodeLogEntry handler" err="invalid client proto: \"dog\""`,
}, { }, {
name: "bad_answer", name: "bad_answer",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":0.9,"Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":0.9,"Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
@ -157,7 +165,7 @@ func TestDecodeLogEntry(t *testing.T) {
}, { }, {
name: "very_bad_answer", name: "very_bad_answer",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry handler err: illegal base64 data at input byte 61\n", want: `level=DEBUG msg="decodeLogEntry handler" err="illegal base64 data at input byte 61"`,
}, { }, {
name: "bad_rule", name: "bad_rule",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"Rule":false},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"Rule":false},"Elapsed":837429}`,
@ -169,22 +177,25 @@ func TestDecodeLogEntry(t *testing.T) {
}, { }, {
name: "bad_reverse_hosts", name: "bad_reverse_hosts",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":[{}]},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":[{}]},"Elapsed":837429}`,
want: "decodeResultReverseHosts: unexpected delim \"{\"\n", want: `level=DEBUG msg=decodeResultReverseHosts err="unexpected delimiter: \"{\""`,
}, { }, {
name: "bad_ip_list", name: "bad_ip_list",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":["example.net"],"IPList":[{}]},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":["example.net"],"IPList":[{}]},"Elapsed":837429}`,
want: "decodeResultIPList: unexpected delim \"{\"\n", want: `level=DEBUG msg=decodeResultIPList err="unexpected delimiter: \"{\""`,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
decodeLogEntry(new(logEntry), tc.log) l.decodeLogEntry(ctx, new(logEntry), tc.log)
got := logOutput.String()
s := logOutput.String()
if tc.want == "" { if tc.want == "" {
assert.Empty(t, s) assert.Empty(t, got)
} else { } else {
assert.True(t, strings.HasSuffix(s, tc.want), "got %q", s) require.NotEmpty(t, got)
// Remove newline.
got = got[:len(got)-1]
assert.Equal(t, tc.want, got)
} }
logOutput.Reset() logOutput.Reset()
@ -200,6 +211,12 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) {
aaaa2 = aaaa1.Next() aaaa2 = aaaa1.Next()
) )
l := &queryLog{
logger: slogutil.NewDiscardLogger(),
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
testCases := []struct { testCases := []struct {
want *logEntry want *logEntry
entry string entry string
@ -249,7 +266,7 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
e := &logEntry{} e := &logEntry{}
decodeLogEntry(e, tc.entry) l.decodeLogEntry(ctx, e, tc.entry)
assert.Equal(t, tc.want, e) assert.Equal(t, tc.want, e)
}) })

View File

@ -1,12 +1,14 @@
package querylog package querylog
import ( import (
"context"
"log/slog"
"net" "net"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -52,7 +54,7 @@ func (e *logEntry) shallowClone() (clone *logEntry) {
// addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is // addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is
// true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any // true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any
// errors are logged. // errors are logged.
func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) { func (e *logEntry) addResponse(ctx context.Context, l *slog.Logger, resp *dns.Msg, isOrig bool) {
if resp == nil { if resp == nil {
return return
} }
@ -65,8 +67,9 @@ func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
e.Answer, err = resp.Pack() e.Answer, err = resp.Pack()
err = errors.Annotate(err, "packing answer: %w") err = errors.Annotate(err, "packing answer: %w")
} }
if err != nil { if err != nil {
log.Error("querylog: %s", err) l.ErrorContext(ctx, "adding data from response", slogutil.KeyError, err)
} }
} }

View File

@ -1,6 +1,7 @@
package querylog package querylog
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math" "math"
@ -15,7 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
@ -74,7 +75,7 @@ func (l *queryLog) initWeb() {
// handleQueryLog is the handler for the GET /control/querylog HTTP API. // handleQueryLog is the handler for the GET /control/querylog HTTP API.
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
params, err := parseSearchParams(r) params, err := l.parseSearchParams(r)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err)
@ -87,18 +88,18 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
l.confMu.RLock() l.confMu.RLock()
defer l.confMu.RUnlock() defer l.confMu.RUnlock()
entries, oldest = l.search(params) entries, oldest = l.search(r.Context(), params)
}() }()
resp := entriesToJSON(entries, oldest, l.anonymizer.Load()) resp := l.entriesToJSON(r.Context(), entries, oldest, l.anonymizer.Load())
aghhttp.WriteJSONResponseOK(w, r, resp) aghhttp.WriteJSONResponseOK(w, r, resp)
} }
// handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP // handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP
// API. // API.
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) { func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, r *http.Request) {
l.clear() l.clear(r.Context())
} }
// handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP // handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP
@ -280,11 +281,12 @@ func getDoubleQuotesEnclosedValue(s *string) bool {
} }
// parseSearchCriterion parses a search criterion from the query parameter. // parseSearchCriterion parses a search criterion from the query parameter.
func parseSearchCriterion(q url.Values, name string, ct criterionType) ( func (l *queryLog) parseSearchCriterion(
ok bool, ctx context.Context,
sc searchCriterion, q url.Values,
err error, name string,
) { ct criterionType,
) (ok bool, sc searchCriterion, err error) {
val := q.Get(name) val := q.Get(name)
if val == "" { if val == "" {
return false, sc, nil return false, sc, nil
@ -301,7 +303,7 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
// TODO(e.burkov): Make it work with parts of IDNAs somehow. // TODO(e.burkov): Make it work with parts of IDNAs somehow.
loweredVal := strings.ToLower(val) loweredVal := strings.ToLower(val)
if asciiVal, err = idna.ToASCII(loweredVal); err != nil { if asciiVal, err = idna.ToASCII(loweredVal); err != nil {
log.Debug("can't convert %q to ascii: %s", val, err) l.logger.DebugContext(ctx, "converting to ascii", "value", val, slogutil.KeyError, err)
} else if asciiVal == loweredVal { } else if asciiVal == loweredVal {
// Purge asciiVal to prevent checking the same value // Purge asciiVal to prevent checking the same value
// twice. // twice.
@ -331,7 +333,7 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
// parseSearchParams parses search parameters from the HTTP request's query // parseSearchParams parses search parameters from the HTTP request's query
// string. // string.
func parseSearchParams(r *http.Request) (p *searchParams, err error) { func (l *queryLog) parseSearchParams(r *http.Request) (p *searchParams, err error) {
p = newSearchParams() p = newSearchParams()
q := r.URL.Query() q := r.URL.Query()
@ -369,7 +371,7 @@ func parseSearchParams(r *http.Request) (p *searchParams, err error) {
}} { }} {
var ok bool var ok bool
var c searchCriterion var c searchCriterion
ok, c, err = parseSearchCriterion(q, v.urlField, v.ct) ok, c, err = l.parseSearchCriterion(r.Context(), q, v.urlField, v.ct)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package querylog package querylog
import ( import (
"context"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@ -8,7 +9,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
@ -19,7 +20,8 @@ import (
type jobject = map[string]any type jobject = map[string]any
// entriesToJSON converts query log entries to JSON. // entriesToJSON converts query log entries to JSON.
func entriesToJSON( func (l *queryLog) entriesToJSON(
ctx context.Context,
entries []*logEntry, entries []*logEntry,
oldest time.Time, oldest time.Time,
anonFunc aghnet.IPMutFunc, anonFunc aghnet.IPMutFunc,
@ -28,7 +30,7 @@ func entriesToJSON(
// The elements order is already reversed to be from newer to older. // The elements order is already reversed to be from newer to older.
for _, entry := range entries { for _, entry := range entries {
jsonEntry := entryToJSON(entry, anonFunc) jsonEntry := l.entryToJSON(ctx, entry, anonFunc)
data = append(data, jsonEntry) data = append(data, jsonEntry)
} }
@ -44,7 +46,11 @@ func entriesToJSON(
} }
// entryToJSON converts a log entry's data into an entry for the JSON API. // entryToJSON converts a log entry's data into an entry for the JSON API.
func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) { func (l *queryLog) entryToJSON(
ctx context.Context,
entry *logEntry,
anonFunc aghnet.IPMutFunc,
) (jsonEntry jobject) {
hostname := entry.QHost hostname := entry.QHost
question := jobject{ question := jobject{
"type": entry.QType, "type": entry.QType,
@ -53,7 +59,12 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject)
} }
if qhost, err := idna.ToUnicode(hostname); err != nil { if qhost, err := idna.ToUnicode(hostname); err != nil {
log.Debug("querylog: translating %q into unicode: %s", hostname, err) l.logger.DebugContext(
ctx,
"translating into unicode",
"hostname", hostname,
slogutil.KeyError, err,
)
} else if qhost != hostname && qhost != "" { } else if qhost != hostname && qhost != "" {
question["unicode_name"] = qhost question["unicode_name"] = qhost
} }
@ -96,21 +107,26 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject)
jsonEntry["service_name"] = entry.Result.ServiceName jsonEntry["service_name"] = entry.Result.ServiceName
} }
setMsgData(entry, jsonEntry) l.setMsgData(ctx, entry, jsonEntry)
setOrigAns(entry, jsonEntry) l.setOrigAns(ctx, entry, jsonEntry)
return jsonEntry return jsonEntry
} }
// setMsgData sets the message data in jsonEntry. // setMsgData sets the message data in jsonEntry.
func setMsgData(entry *logEntry, jsonEntry jobject) { func (l *queryLog) setMsgData(ctx context.Context, entry *logEntry, jsonEntry jobject) {
if len(entry.Answer) == 0 { if len(entry.Answer) == 0 {
return return
} }
msg := &dns.Msg{} msg := &dns.Msg{}
if err := msg.Unpack(entry.Answer); err != nil { if err := msg.Unpack(entry.Answer); err != nil {
log.Debug("querylog: failed to unpack dns msg answer: %v: %s", entry.Answer, err) l.logger.DebugContext(
ctx,
"failed to unpack dns msg",
"answer", entry.Answer,
slogutil.KeyError, err,
)
return return
} }
@ -126,7 +142,7 @@ func setMsgData(entry *logEntry, jsonEntry jobject) {
} }
// setOrigAns sets the original answer data in jsonEntry. // setOrigAns sets the original answer data in jsonEntry.
func setOrigAns(entry *logEntry, jsonEntry jobject) { func (l *queryLog) setOrigAns(ctx context.Context, entry *logEntry, jsonEntry jobject) {
if len(entry.OrigAnswer) == 0 { if len(entry.OrigAnswer) == 0 {
return return
} }
@ -134,7 +150,12 @@ func setOrigAns(entry *logEntry, jsonEntry jobject) {
orig := &dns.Msg{} orig := &dns.Msg{}
err := orig.Unpack(entry.OrigAnswer) err := orig.Unpack(entry.OrigAnswer)
if err != nil { if err != nil {
log.Debug("querylog: orig.Unpack(entry.OrigAnswer): %v: %s", entry.OrigAnswer, err) l.logger.DebugContext(
ctx,
"setting original answer",
"answer", entry.OrigAnswer,
slogutil.KeyError, err,
)
return return
} }

View File

@ -2,7 +2,9 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"log/slog"
"os" "os"
"sync" "sync"
"time" "time"
@ -11,7 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -22,6 +24,10 @@ const queryLogFileName = "querylog.json"
// queryLog is a structure that writes and reads the DNS query log. // queryLog is a structure that writes and reads the DNS query log.
type queryLog struct { type queryLog struct {
// logger is used for logging the operation of the query log. It must not
// be nil.
logger *slog.Logger
// confMu protects conf. // confMu protects conf.
confMu *sync.RWMutex confMu *sync.RWMutex
@ -76,22 +82,22 @@ func NewClientProto(s string) (cp ClientProto, err error) {
} }
} }
func (l *queryLog) Start() { func (l *queryLog) Start(ctx context.Context) {
if l.conf.HTTPRegister != nil { if l.conf.HTTPRegister != nil {
l.initWeb() l.initWeb()
} }
go l.periodicRotate() go l.periodicRotate(ctx)
} }
func (l *queryLog) Close() { func (l *queryLog) Close(ctx context.Context) {
l.confMu.RLock() l.confMu.RLock()
defer l.confMu.RUnlock() defer l.confMu.RUnlock()
if l.conf.FileEnabled { if l.conf.FileEnabled {
err := l.flushLogBuffer() err := l.flushLogBuffer(ctx)
if err != nil { if err != nil {
log.Error("querylog: closing: %s", err) l.logger.ErrorContext(ctx, "closing", slogutil.KeyError, err)
} }
} }
} }
@ -131,7 +137,7 @@ func (l *queryLog) WriteDiskConfig(c *Config) {
} }
// Clear memory buffer and remove log files // Clear memory buffer and remove log files
func (l *queryLog) clear() { func (l *queryLog) clear(ctx context.Context) {
l.fileFlushLock.Lock() l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock() defer l.fileFlushLock.Unlock()
@ -146,19 +152,24 @@ func (l *queryLog) clear() {
oldLogFile := l.logFile + ".1" oldLogFile := l.logFile + ".1"
err := os.Remove(oldLogFile) err := os.Remove(oldLogFile)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("removing old log file %q: %s", oldLogFile, err) l.logger.ErrorContext(
ctx,
"removing old log file",
"file", oldLogFile,
slogutil.KeyError, err,
)
} }
err = os.Remove(l.logFile) err = os.Remove(l.logFile)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("removing log file %q: %s", l.logFile, err) l.logger.ErrorContext(ctx, "removing log file", "file", l.logFile, slogutil.KeyError, err)
} }
log.Debug("querylog: cleared") l.logger.DebugContext(ctx, "cleared")
} }
// newLogEntry creates an instance of logEntry from parameters. // newLogEntry creates an instance of logEntry from parameters.
func newLogEntry(params *AddParams) (entry *logEntry) { func newLogEntry(ctx context.Context, logger *slog.Logger, params *AddParams) (entry *logEntry) {
q := params.Question.Question[0] q := params.Question.Question[0]
qHost := aghnet.NormalizeDomain(q.Name) qHost := aghnet.NormalizeDomain(q.Name)
@ -187,8 +198,8 @@ func newLogEntry(params *AddParams) (entry *logEntry) {
entry.ReqECS = params.ReqECS.String() entry.ReqECS = params.ReqECS.String()
} }
entry.addResponse(params.Answer, false) entry.addResponse(ctx, logger, params.Answer, false)
entry.addResponse(params.OrigAnswer, true) entry.addResponse(ctx, logger, params.OrigAnswer, true)
return entry return entry
} }
@ -209,9 +220,12 @@ func (l *queryLog) Add(params *AddParams) {
return return
} }
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
err := params.validate() err := params.validate()
if err != nil { if err != nil {
log.Error("querylog: adding record: %s, skipping", err) l.logger.ErrorContext(ctx, "adding record", slogutil.KeyError, err)
return return
} }
@ -220,7 +234,7 @@ func (l *queryLog) Add(params *AddParams) {
params.Result = &filtering.Result{} params.Result = &filtering.Result{}
} }
entry := newLogEntry(params) entry := newLogEntry(ctx, l.logger, params)
l.bufferLock.Lock() l.bufferLock.Lock()
defer l.bufferLock.Unlock() defer l.bufferLock.Unlock()
@ -232,9 +246,9 @@ func (l *queryLog) Add(params *AddParams) {
// TODO(s.chzhen): Fix occasional rewrite of entires. // TODO(s.chzhen): Fix occasional rewrite of entires.
go func() { go func() {
flushErr := l.flushLogBuffer() flushErr := l.flushLogBuffer(ctx)
if flushErr != nil { if flushErr != nil {
log.Error("querylog: flushing after adding: %s", flushErr) l.logger.ErrorContext(ctx, "flushing after adding", slogutil.KeyError, flushErr)
} }
}() }()
} }
@ -247,7 +261,8 @@ func (l *queryLog) ShouldLog(host string, _, _ uint16, ids []string) bool {
c, err := l.findClient(ids) c, err := l.findClient(ids)
if err != nil { if err != nil {
log.Error("querylog: finding client: %s", err) // TODO(s.chzhen): Pass context.
l.logger.ErrorContext(context.TODO(), "finding client", slogutil.KeyError, err)
} }
if c != nil && c.IgnoreQueryLog { if c != nil && c.IgnoreQueryLog {

View File

@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -14,14 +15,11 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// TestQueryLog tests adding and loading (with filtering) entries from disk and // TestQueryLog tests adding and loading (with filtering) entries from disk and
// memory. // memory.
func TestQueryLog(t *testing.T) { func TestQueryLog(t *testing.T) {
l, err := newQueryLog(Config{ l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true, Enabled: true,
FileEnabled: true, FileEnabled: true,
RotationIvl: timeutil.Day, RotationIvl: timeutil.Day,
@ -30,16 +28,18 @@ func TestQueryLog(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add disk entries. // Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// Write to disk (first file). // Write to disk (first file).
require.NoError(t, l.flushLogBuffer()) require.NoError(t, l.flushLogBuffer(ctx))
// Start writing to the second file. // Start writing to the second file.
require.NoError(t, l.rotate()) require.NoError(t, l.rotate(ctx))
// Add disk entries. // Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// Write to disk. // Write to disk.
require.NoError(t, l.flushLogBuffer()) require.NoError(t, l.flushLogBuffer(ctx))
// Add memory entries. // Add memory entries.
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
@ -119,7 +119,7 @@ func TestQueryLog(t *testing.T) {
params := newSearchParams() params := newSearchParams()
params.searchCriteria = tc.sCr params.searchCriteria = tc.sCr
entries, _ := l.search(params) entries, _ := l.search(ctx, params)
require.Len(t, entries, len(tc.want)) require.Len(t, entries, len(tc.want))
for _, want := range tc.want { for _, want := range tc.want {
assertLogEntry(t, entries[want.num], want.host, want.answer, want.client) assertLogEntry(t, entries[want.num], want.host, want.answer, want.client)
@ -130,6 +130,7 @@ func TestQueryLog(t *testing.T) {
func TestQueryLogOffsetLimit(t *testing.T) { func TestQueryLogOffsetLimit(t *testing.T) {
l, err := newQueryLog(Config{ l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true, Enabled: true,
RotationIvl: timeutil.Day, RotationIvl: timeutil.Day,
MemSize: 100, MemSize: 100,
@ -142,12 +143,15 @@ func TestQueryLogOffsetLimit(t *testing.T) {
firstPageDomain = "first.example.org" firstPageDomain = "first.example.org"
secondPageDomain = "second.example.org" secondPageDomain = "second.example.org"
) )
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add entries to the log. // Add entries to the log.
for range entNum { for range entNum {
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// Write them to the first file. // Write them to the first file.
require.NoError(t, l.flushLogBuffer()) require.NoError(t, l.flushLogBuffer(ctx))
// Add more to the in-memory part of log. // Add more to the in-memory part of log.
for range entNum { for range entNum {
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
@ -191,7 +195,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
params.offset = tc.offset params.offset = tc.offset
params.limit = tc.limit params.limit = tc.limit
entries, _ := l.search(params) entries, _ := l.search(ctx, params)
require.Len(t, entries, tc.wantLen) require.Len(t, entries, tc.wantLen)
@ -205,6 +209,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
func TestQueryLogMaxFileScanEntries(t *testing.T) { func TestQueryLogMaxFileScanEntries(t *testing.T) {
l, err := newQueryLog(Config{ l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true, Enabled: true,
FileEnabled: true, FileEnabled: true,
RotationIvl: timeutil.Day, RotationIvl: timeutil.Day,
@ -213,20 +218,22 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
ctx := testutil.ContextWithTimeout(t, testTimeout)
const entNum = 10 const entNum = 10
// Add entries to the log. // Add entries to the log.
for range entNum { for range entNum {
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// Write them to disk. // Write them to disk.
require.NoError(t, l.flushLogBuffer()) require.NoError(t, l.flushLogBuffer(ctx))
params := newSearchParams() params := newSearchParams()
for _, maxFileScanEntries := range []int{5, 0} { for _, maxFileScanEntries := range []int{5, 0} {
t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) { t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) {
params.maxFileScanEntries = maxFileScanEntries params.maxFileScanEntries = maxFileScanEntries
entries, _ := l.search(params) entries, _ := l.search(ctx, params)
assert.Len(t, entries, entNum-maxFileScanEntries) assert.Len(t, entries, entNum-maxFileScanEntries)
}) })
} }
@ -234,6 +241,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
func TestQueryLogFileDisabled(t *testing.T) { func TestQueryLogFileDisabled(t *testing.T) {
l, err := newQueryLog(Config{ l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true, Enabled: true,
FileEnabled: false, FileEnabled: false,
RotationIvl: timeutil.Day, RotationIvl: timeutil.Day,
@ -248,7 +256,8 @@ func TestQueryLogFileDisabled(t *testing.T) {
addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
params := newSearchParams() params := newSearchParams()
ll, _ := l.search(params) ctx := testutil.ContextWithTimeout(t, testTimeout)
ll, _ := l.search(ctx, params)
require.Len(t, ll, 2) require.Len(t, ll, 2)
assert.Equal(t, "example3.org", ll[0].QHost) assert.Equal(t, "example3.org", ll[0].QHost)
assert.Equal(t, "example2.org", ll[1].QHost) assert.Equal(t, "example2.org", ll[1].QHost)

View File

@ -1,8 +1,10 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"strings" "strings"
"sync" "sync"
@ -10,7 +12,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
) )
const ( const (
@ -102,7 +103,11 @@ func (q *qLogFile) validateQLogLineIdx(lineIdx, lastProbeLineIdx, ts, fSize int6
// for so that when we call "ReadNext" this line was returned. // for so that when we call "ReadNext" this line was returned.
// - Depth of the search (how many times we compared timestamps). // - Depth of the search (how many times we compared timestamps).
// - If we could not find it, it returns one of the errors described above. // - If we could not find it, it returns one of the errors described above.
func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) { func (q *qLogFile) seekTS(
ctx context.Context,
logger *slog.Logger,
timestamp int64,
) (pos int64, depth int, err error) {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
@ -151,7 +156,7 @@ func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) {
lastProbeLineIdx = lineIdx lastProbeLineIdx = lineIdx
// Get the timestamp from the query log record. // Get the timestamp from the query log record.
ts := readQLogTimestamp(line) ts := readQLogTimestamp(ctx, logger, line)
if ts == 0 { if ts == 0 {
return 0, depth, fmt.Errorf( return 0, depth, fmt.Errorf(
"looking up timestamp %d in %q: record %q has empty timestamp", "looking up timestamp %d in %q: record %q has empty timestamp",
@ -385,20 +390,22 @@ func readJSONValue(s, prefix string) string {
} }
// readQLogTimestamp reads the timestamp field from the query log line. // readQLogTimestamp reads the timestamp field from the query log line.
func readQLogTimestamp(str string) int64 { func readQLogTimestamp(ctx context.Context, logger *slog.Logger, str string) int64 {
val := readJSONValue(str, `"T":"`) val := readJSONValue(str, `"T":"`)
if len(val) == 0 { if len(val) == 0 {
val = readJSONValue(str, `"Time":"`) val = readJSONValue(str, `"Time":"`)
} }
if len(val) == 0 { if len(val) == 0 {
log.Error("Couldn't find timestamp: %s", str) logger.ErrorContext(ctx, "couldn't find timestamp", "line", str)
return 0 return 0
} }
tm, err := time.Parse(time.RFC3339Nano, val) tm, err := time.Parse(time.RFC3339Nano, val)
if err != nil { if err != nil {
log.Error("Couldn't parse timestamp: %s", val) logger.ErrorContext(ctx, "couldn't parse timestamp", "value", val)
return 0 return 0
} }

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -146,6 +147,9 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
num: 10, num: 10,
}} }}
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
for _, l := range linesCases { for _, l := range linesCases {
testCases := []struct { testCases := []struct {
name string name string
@ -171,11 +175,11 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
t.Run(l.name+"_"+tc.name, func(t *testing.T) { t.Run(l.name+"_"+tc.name, func(t *testing.T) {
line, err := getQLogFileLine(q, tc.line) line, err := getQLogFileLine(q, tc.line)
require.NoError(t, err) require.NoError(t, err)
ts := readQLogTimestamp(line) ts := readQLogTimestamp(ctx, logger, line)
assert.NotEqualValues(t, 0, ts) assert.NotEqualValues(t, 0, ts)
// Try seeking to that line now. // Try seeking to that line now.
pos, _, err := q.seekTS(ts) pos, _, err := q.seekTS(ctx, logger, ts)
require.NoError(t, err) require.NoError(t, err)
assert.NotEqualValues(t, 0, pos) assert.NotEqualValues(t, 0, pos)
@ -199,6 +203,9 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
num: 10, num: 10,
}} }}
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
for _, l := range linesCases { for _, l := range linesCases {
testCases := []struct { testCases := []struct {
name string name string
@ -221,14 +228,14 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
line, err := getQLogFileLine(q, l.num/2) line, err := getQLogFileLine(q, l.num/2)
require.NoError(t, err) require.NoError(t, err)
testCases[2].ts = readQLogTimestamp(line) - 1 testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
assert.NotEqualValues(t, 0, tc.ts) assert.NotEqualValues(t, 0, tc.ts)
var depth int var depth int
_, depth, err = q.seekTS(tc.ts) _, depth, err = q.seekTS(ctx, logger, tc.ts)
assert.NotEmpty(t, l.num) assert.NotEmpty(t, l.num)
require.Error(t, err) require.Error(t, err)
@ -308,6 +315,9 @@ func TestQLog_Seek(t *testing.T) {
`{"T":"` + strV + `"}` + nl `{"T":"` + strV + `"}` + nl
timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00") timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00")
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
testCases := []struct { testCases := []struct {
wantErr error wantErr error
name string name string
@ -340,7 +350,8 @@ func TestQLog_Seek(t *testing.T) {
q := newTestQLogFileData(t, data) q := newTestQLogFileData(t, data)
_, depth, err := q.seekTS(timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano()) ts := timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano()
_, depth, err := q.seekTS(ctx, logger, ts)
require.Truef(t, errors.Is(err, tc.wantErr), "%v", err) require.Truef(t, errors.Is(err, tc.wantErr), "%v", err)
assert.Equal(t, tc.wantDepth, depth) assert.Equal(t, tc.wantDepth, depth)
}) })

View File

@ -1,12 +1,14 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
) )
// qLogReader allows reading from multiple query log files in the reverse // qLogReader allows reading from multiple query log files in the reverse
@ -16,6 +18,10 @@ import (
// pointer to a particular query log file, and to a specific position in this // pointer to a particular query log file, and to a specific position in this
// file, and it reads lines in reverse order starting from that position. // file, and it reads lines in reverse order starting from that position.
type qLogReader struct { type qLogReader struct {
// logger is used for logging the operation of the query log reader. It
// must not be nil.
logger *slog.Logger
// qFiles is an array with the query log files. The order is from oldest // qFiles is an array with the query log files. The order is from oldest
// to newest. // to newest.
qFiles []*qLogFile qFiles []*qLogFile
@ -25,7 +31,7 @@ type qLogReader struct {
} }
// newQLogReader initializes a qLogReader instance with the specified files. // newQLogReader initializes a qLogReader instance with the specified files.
func newQLogReader(files []string) (*qLogReader, error) { func newQLogReader(ctx context.Context, logger *slog.Logger, files []string) (*qLogReader, error) {
qFiles := make([]*qLogFile, 0) qFiles := make([]*qLogFile, 0)
for _, f := range files { for _, f := range files {
@ -38,7 +44,7 @@ func newQLogReader(files []string) (*qLogReader, error) {
// Close what we've already opened. // Close what we've already opened.
cErr := closeQFiles(qFiles) cErr := closeQFiles(qFiles)
if cErr != nil { if cErr != nil {
log.Debug("querylog: closing files: %s", cErr) logger.DebugContext(ctx, "closing files", slogutil.KeyError, cErr)
} }
return nil, err return nil, err
@ -47,16 +53,20 @@ func newQLogReader(files []string) (*qLogReader, error) {
qFiles = append(qFiles, q) qFiles = append(qFiles, q)
} }
return &qLogReader{qFiles: qFiles, currentFile: len(qFiles) - 1}, nil return &qLogReader{
logger: logger,
qFiles: qFiles,
currentFile: len(qFiles) - 1,
}, nil
} }
// seekTS performs binary search of a query log record with the specified // seekTS performs binary search of a query log record with the specified
// timestamp. If the record is found, it sets qLogReader's position to point // timestamp. If the record is found, it sets qLogReader's position to point
// to that line, so that the next ReadNext call returned this line. // to that line, so that the next ReadNext call returned this line.
func (r *qLogReader) seekTS(timestamp int64) (err error) { func (r *qLogReader) seekTS(ctx context.Context, timestamp int64) (err error) {
for i := len(r.qFiles) - 1; i >= 0; i-- { for i := len(r.qFiles) - 1; i >= 0; i-- {
q := r.qFiles[i] q := r.qFiles[i]
_, _, err = q.seekTS(timestamp) _, _, err = q.seekTS(ctx, r.logger, timestamp)
if err != nil { if err != nil {
if errors.Is(err, errTSTooEarly) { if errors.Is(err, errTSTooEarly) {
// Look at the next file, since we've reached the end of this // Look at the next file, since we've reached the end of this

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -17,8 +18,11 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *qLogReader
testFiles := prepareTestFiles(t, filesNum, linesNum) testFiles := prepareTestFiles(t, filesNum, linesNum)
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Create the new qLogReader instance. // Create the new qLogReader instance.
reader, err := newQLogReader(testFiles) reader, err := newQLogReader(ctx, logger, testFiles)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, reader) assert.NotNil(t, reader)
@ -73,6 +77,7 @@ func TestQLogReader(t *testing.T) {
func TestQLogReader_Seek(t *testing.T) { func TestQLogReader_Seek(t *testing.T) {
r := newTestQLogReader(t, 2, 10000) r := newTestQLogReader(t, 2, 10000)
ctx := testutil.ContextWithTimeout(t, testTimeout)
testCases := []struct { testCases := []struct {
want error want error
@ -113,7 +118,7 @@ func TestQLogReader_Seek(t *testing.T) {
ts, err := time.Parse(time.RFC3339Nano, tc.time) ts, err := time.Parse(time.RFC3339Nano, tc.time)
require.NoError(t, err) require.NoError(t, err)
err = r.seekTS(ts.UnixNano()) err = r.seekTS(ctx, ts.UnixNano())
assert.ErrorIs(t, err, tc.want) assert.ErrorIs(t, err, tc.want)
}) })
} }

View File

@ -1,7 +1,9 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"log/slog"
"net" "net"
"path/filepath" "path/filepath"
"sync" "sync"
@ -17,10 +19,10 @@ import (
// QueryLog - main interface // QueryLog - main interface
type QueryLog interface { type QueryLog interface {
Start() Start(ctx context.Context)
// Close query log object // Close query log object
Close() Close(ctx context.Context)
// Add a log entry // Add a log entry
Add(params *AddParams) Add(params *AddParams)
@ -36,6 +38,10 @@ type QueryLog interface {
// //
// Do not alter any fields of this structure after using it. // Do not alter any fields of this structure after using it.
type Config struct { type Config struct {
// Logger is used for logging the operation of the query log. It must not
// be nil.
Logger *slog.Logger
// Ignored contains the list of host names, which should not be written to // Ignored contains the list of host names, which should not be written to
// log, and matches them. // log, and matches them.
Ignored *aghnet.IgnoreEngine Ignored *aghnet.IgnoreEngine
@ -151,6 +157,7 @@ func newQueryLog(conf Config) (l *queryLog, err error) {
} }
l = &queryLog{ l = &queryLog{
logger: conf.Logger,
findClient: findClient, findClient: findClient,
buffer: container.NewRingBuffer[*logEntry](memSize), buffer: container.NewRingBuffer[*logEntry](memSize),

View File

@ -2,6 +2,7 @@ package querylog
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -9,28 +10,29 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/c2h5oh/datasize"
) )
// flushLogBuffer flushes the current buffer to file and resets the current // flushLogBuffer flushes the current buffer to file and resets the current
// buffer. // buffer.
func (l *queryLog) flushLogBuffer() (err error) { func (l *queryLog) flushLogBuffer(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }() defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }()
l.fileFlushLock.Lock() l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock() defer l.fileFlushLock.Unlock()
b, err := l.encodeEntries() b, err := l.encodeEntries(ctx)
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. // Don't wrap the error since it's informative enough as is.
return err return err
} }
return l.flushToFile(b) return l.flushToFile(ctx, b)
} }
// encodeEntries returns JSON encoded log entries, logs estimated time, clears // encodeEntries returns JSON encoded log entries, logs estimated time, clears
// the log buffer. // the log buffer.
func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) { func (l *queryLog) encodeEntries(ctx context.Context) (b *bytes.Buffer, err error) {
l.bufferLock.Lock() l.bufferLock.Lock()
defer l.bufferLock.Unlock() defer l.bufferLock.Unlock()
@ -55,8 +57,17 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
return nil, err return nil, err
} }
size := b.Len()
elapsed := time.Since(start) elapsed := time.Since(start)
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", bufLen, elapsed, b.Len()/1024, float64(b.Len())/float64(bufLen), elapsed/time.Duration(bufLen)) l.logger.DebugContext(
ctx,
"serialized elements via json",
"count", bufLen,
"elapsed", elapsed,
"size", datasize.ByteSize(size),
"size_per_entry", datasize.ByteSize(float64(size)/float64(bufLen)),
"time_per_entry", elapsed/time.Duration(bufLen),
)
l.buffer.Clear() l.buffer.Clear()
l.flushPending = false l.flushPending = false
@ -65,7 +76,7 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
} }
// flushToFile saves the encoded log entries to the query log file. // flushToFile saves the encoded log entries to the query log file.
func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) { func (l *queryLog) flushToFile(ctx context.Context, b *bytes.Buffer) (err error) {
l.fileWriteLock.Lock() l.fileWriteLock.Lock()
defer l.fileWriteLock.Unlock() defer l.fileWriteLock.Unlock()
@ -83,19 +94,19 @@ func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
return fmt.Errorf("writing to file %q: %w", filename, err) return fmt.Errorf("writing to file %q: %w", filename, err)
} }
log.Debug("querylog: ok %q: %v bytes written", filename, n) l.logger.DebugContext(ctx, "flushed to file", "file", filename, "size", datasize.ByteSize(n))
return nil return nil
} }
func (l *queryLog) rotate() error { func (l *queryLog) rotate(ctx context.Context) error {
from := l.logFile from := l.logFile
to := l.logFile + ".1" to := l.logFile + ".1"
err := os.Rename(from, to) err := os.Rename(from, to)
if err != nil { if err != nil {
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
log.Debug("querylog: no log to rotate") l.logger.DebugContext(ctx, "no log to rotate")
return nil return nil
} }
@ -103,12 +114,12 @@ func (l *queryLog) rotate() error {
return fmt.Errorf("failed to rename old file: %w", err) return fmt.Errorf("failed to rename old file: %w", err)
} }
log.Debug("querylog: renamed %s into %s", from, to) l.logger.DebugContext(ctx, "renamed log file", "from", from, "to", to)
return nil return nil
} }
func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) { func (l *queryLog) readFileFirstTimeValue(ctx context.Context) (first time.Time, err error) {
var f *os.File var f *os.File
f, err = os.Open(l.logFile) f, err = os.Open(l.logFile)
if err != nil { if err != nil {
@ -130,15 +141,15 @@ func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) {
return time.Time{}, err return time.Time{}, err
} }
log.Debug("querylog: the oldest log entry: %s", val) l.logger.DebugContext(ctx, "oldest log entry", "entry_time", val)
return t, nil return t, nil
} }
func (l *queryLog) periodicRotate() { func (l *queryLog) periodicRotate(ctx context.Context) {
defer log.OnPanic("querylog: rotating") defer slogutil.RecoverAndLog(ctx, l.logger)
l.checkAndRotate() l.checkAndRotate(ctx)
// rotationCheckIvl is the period of time between checking the need for // rotationCheckIvl is the period of time between checking the need for
// rotating log files. It's smaller of any available rotation interval to // rotating log files. It's smaller of any available rotation interval to
@ -151,13 +162,13 @@ func (l *queryLog) periodicRotate() {
defer rotations.Stop() defer rotations.Stop()
for range rotations.C { for range rotations.C {
l.checkAndRotate() l.checkAndRotate(ctx)
} }
} }
// checkAndRotate rotates log files if those are older than the specified // checkAndRotate rotates log files if those are older than the specified
// rotation interval. // rotation interval.
func (l *queryLog) checkAndRotate() { func (l *queryLog) checkAndRotate(ctx context.Context) {
var rotationIvl time.Duration var rotationIvl time.Duration
func() { func() {
l.confMu.RLock() l.confMu.RLock()
@ -166,29 +177,30 @@ func (l *queryLog) checkAndRotate() {
rotationIvl = l.conf.RotationIvl rotationIvl = l.conf.RotationIvl
}() }()
oldest, err := l.readFileFirstTimeValue() oldest, err := l.readFileFirstTimeValue(ctx)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("querylog: reading oldest record for rotation: %s", err) l.logger.ErrorContext(ctx, "reading oldest record for rotation", slogutil.KeyError, err)
return return
} }
if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) { if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) {
log.Debug( l.logger.DebugContext(
"querylog: %s <= %s, not rotating", ctx,
now.Format(time.RFC3339), "not rotating",
rotTime.Format(time.RFC3339), "now", now.Format(time.RFC3339),
"rotate_time", rotTime.Format(time.RFC3339),
) )
return return
} }
err = l.rotate() err = l.rotate(ctx)
if err != nil { if err != nil {
log.Error("querylog: rotating: %s", err) l.logger.ErrorContext(ctx, "rotating", slogutil.KeyError, err)
return return
} }
log.Debug("querylog: rotated successfully") l.logger.DebugContext(ctx, "rotated successfully")
} }

View File

@ -1,13 +1,15 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log/slog"
"slices" "slices"
"time" "time"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
) )
// client finds the client info, if any, by its ClientID and IP address, // client finds the client info, if any, by its ClientID and IP address,
@ -48,7 +50,11 @@ func (l *queryLog) client(clientID, ip string, cache clientCache) (c *Client, er
// buffer. It optionally uses the client cache, if provided. It also returns // buffer. It optionally uses the client cache, if provided. It also returns
// the total amount of records in the buffer at the moment of searching. // the total amount of records in the buffer at the moment of searching.
// l.confMu is expected to be locked. // l.confMu is expected to be locked.
func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entries []*logEntry, total int) { func (l *queryLog) searchMemory(
ctx context.Context,
params *searchParams,
cache clientCache,
) (entries []*logEntry, total int) {
// Check memory size, as the buffer can contain a single log record. See // Check memory size, as the buffer can contain a single log record. See
// [newQueryLog]. // [newQueryLog].
if l.conf.MemSize == 0 { if l.conf.MemSize == 0 {
@ -66,9 +72,14 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
var err error var err error
e.client, err = l.client(e.ClientID, e.IP.String(), cache) e.client, err = l.client(e.ClientID, e.IP.String(), cache)
if err != nil { if err != nil {
msg := "querylog: enriching memory record at time %s" + l.logger.ErrorContext(
" for client %q (clientid %q): %s" ctx,
log.Error(msg, e.Time, e.IP, e.ClientID, err) "enriching memory record",
"time", e.Time,
"client_ip", e.IP,
"client_id", e.ClientID,
slogutil.KeyError, err,
)
// Go on and try to match anyway. // Go on and try to match anyway.
} }
@ -86,7 +97,10 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
// search searches log entries in memory buffer and log file using specified // search searches log entries in memory buffer and log file using specified
// parameters and returns the list of entries found and the time of the oldest // parameters and returns the list of entries found and the time of the oldest
// entry. l.confMu is expected to be locked. // entry. l.confMu is expected to be locked.
func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest time.Time) { func (l *queryLog) search(
ctx context.Context,
params *searchParams,
) (entries []*logEntry, oldest time.Time) {
start := time.Now() start := time.Now()
if params.limit == 0 { if params.limit == 0 {
@ -95,11 +109,11 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
cache := clientCache{} cache := clientCache{}
memoryEntries, bufLen := l.searchMemory(params, cache) memoryEntries, bufLen := l.searchMemory(ctx, params, cache)
log.Debug("querylog: got %d entries from memory", len(memoryEntries)) l.logger.DebugContext(ctx, "got entries from memory", "count", len(memoryEntries))
fileEntries, oldest, total := l.searchFiles(params, cache) fileEntries, oldest, total := l.searchFiles(ctx, params, cache)
log.Debug("querylog: got %d entries from files", len(fileEntries)) l.logger.DebugContext(ctx, "got entries from files", "count", len(fileEntries))
total += bufLen total += bufLen
@ -134,12 +148,13 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
oldest = entries[len(entries)-1].Time oldest = entries[len(entries)-1].Time
} }
log.Debug( l.logger.DebugContext(
"querylog: prepared data (%d/%d) older than %s in %s", ctx,
len(entries), "prepared data",
total, "count", len(entries),
params.olderThan, "total", total,
time.Since(start), "older_than", params.olderThan,
"elapsed", time.Since(start),
) )
return entries, oldest return entries, oldest
@ -147,12 +162,12 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
// seekRecord changes the current position to the next record older than the // seekRecord changes the current position to the next record older than the
// provided parameter. // provided parameter.
func (r *qLogReader) seekRecord(olderThan time.Time) (err error) { func (r *qLogReader) seekRecord(ctx context.Context, olderThan time.Time) (err error) {
if olderThan.IsZero() { if olderThan.IsZero() {
return r.SeekStart() return r.SeekStart()
} }
err = r.seekTS(olderThan.UnixNano()) err = r.seekTS(ctx, olderThan.UnixNano())
if err == nil { if err == nil {
// Read to the next record, because we only need the one that goes // Read to the next record, because we only need the one that goes
// after it. // after it.
@ -164,21 +179,24 @@ func (r *qLogReader) seekRecord(olderThan time.Time) (err error) {
// setQLogReader creates a reader with the specified files and sets the // setQLogReader creates a reader with the specified files and sets the
// position to the next record older than the provided parameter. // position to the next record older than the provided parameter.
func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error) { func (l *queryLog) setQLogReader(
ctx context.Context,
olderThan time.Time,
) (qr *qLogReader, err error) {
files := []string{ files := []string{
l.logFile + ".1", l.logFile + ".1",
l.logFile, l.logFile,
} }
r, err := newQLogReader(files) r, err := newQLogReader(ctx, l.logger, files)
if err != nil { if err != nil {
return nil, fmt.Errorf("opening qlog reader: %s", err) return nil, fmt.Errorf("opening qlog reader: %s", err)
} }
err = r.seekRecord(olderThan) err = r.seekRecord(ctx, olderThan)
if err != nil { if err != nil {
defer func() { err = errors.WithDeferred(err, r.Close()) }() defer func() { err = errors.WithDeferred(err, r.Close()) }()
log.Debug("querylog: cannot seek to %s: %s", olderThan, err) l.logger.DebugContext(ctx, "cannot seek", "older_than", olderThan, slogutil.KeyError, err)
return nil, nil return nil, nil
} }
@ -191,13 +209,14 @@ func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error
// calls faster so that the UI could handle it and show something quicker. // calls faster so that the UI could handle it and show something quicker.
// This behavior can be overridden if maxFileScanEntries is set to 0. // This behavior can be overridden if maxFileScanEntries is set to 0.
func (l *queryLog) readEntries( func (l *queryLog) readEntries(
ctx context.Context,
r *qLogReader, r *qLogReader,
params *searchParams, params *searchParams,
cache clientCache, cache clientCache,
totalLimit int, totalLimit int,
) (entries []*logEntry, oldestNano int64, total int) { ) (entries []*logEntry, oldestNano int64, total int) {
for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 { for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 {
ent, ts, rErr := l.readNextEntry(r, params, cache) ent, ts, rErr := l.readNextEntry(ctx, r, params, cache)
if rErr != nil { if rErr != nil {
if rErr == io.EOF { if rErr == io.EOF {
oldestNano = 0 oldestNano = 0
@ -205,7 +224,7 @@ func (l *queryLog) readEntries(
break break
} }
log.Error("querylog: reading next entry: %s", rErr) l.logger.ErrorContext(ctx, "reading next entry", slogutil.KeyError, rErr)
} }
oldestNano = ts oldestNano = ts
@ -231,12 +250,13 @@ func (l *queryLog) readEntries(
// and the total number of processed entries, including discarded ones, // and the total number of processed entries, including discarded ones,
// correspondingly. // correspondingly.
func (l *queryLog) searchFiles( func (l *queryLog) searchFiles(
ctx context.Context,
params *searchParams, params *searchParams,
cache clientCache, cache clientCache,
) (entries []*logEntry, oldest time.Time, total int) { ) (entries []*logEntry, oldest time.Time, total int) {
r, err := l.setQLogReader(params.olderThan) r, err := l.setQLogReader(ctx, params.olderThan)
if err != nil { if err != nil {
log.Error("querylog: %s", err) l.logger.ErrorContext(ctx, "searching files", slogutil.KeyError, err)
} }
if r == nil { if r == nil {
@ -245,12 +265,12 @@ func (l *queryLog) searchFiles(
defer func() { defer func() {
if closeErr := r.Close(); closeErr != nil { if closeErr := r.Close(); closeErr != nil {
log.Error("querylog: closing file: %s", closeErr) l.logger.ErrorContext(ctx, "closing files", slogutil.KeyError, closeErr)
} }
}() }()
totalLimit := params.offset + params.limit totalLimit := params.offset + params.limit
entries, oldestNano, total := l.readEntries(r, params, cache, totalLimit) entries, oldestNano, total := l.readEntries(ctx, r, params, cache, totalLimit)
if oldestNano != 0 { if oldestNano != 0 {
oldest = time.Unix(0, oldestNano) oldest = time.Unix(0, oldestNano)
} }
@ -266,15 +286,21 @@ type quickMatchClientFinder struct {
} }
// findClient is a method that can be used as a quickMatchClientFinder. // findClient is a method that can be used as a quickMatchClientFinder.
func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) { func (f quickMatchClientFinder) findClient(
ctx context.Context,
logger *slog.Logger,
clientID string,
ip string,
) (c *Client) {
var err error var err error
c, err = f.client(clientID, ip, f.cache) c, err = f.client(clientID, ip, f.cache)
if err != nil { if err != nil {
log.Error( logger.ErrorContext(
"querylog: enriching file record for quick search: for client %q (clientid %q): %s", ctx,
ip, "enriching file record for quick search",
clientID, "client_ip", ip,
err, "client_id", clientID,
slogutil.KeyError, err,
) )
} }
@ -286,6 +312,7 @@ func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) {
// the entry doesn't match the search criteria. ts is the timestamp of the // the entry doesn't match the search criteria. ts is the timestamp of the
// processed entry. // processed entry.
func (l *queryLog) readNextEntry( func (l *queryLog) readNextEntry(
ctx context.Context,
r *qLogReader, r *qLogReader,
params *searchParams, params *searchParams,
cache clientCache, cache clientCache,
@ -301,14 +328,14 @@ func (l *queryLog) readNextEntry(
cache: cache, cache: cache,
} }
if !params.quickMatch(line, clientFinder.findClient) { if !params.quickMatch(ctx, l.logger, line, clientFinder.findClient) {
ts = readQLogTimestamp(line) ts = readQLogTimestamp(ctx, l.logger, line)
return nil, ts, nil return nil, ts, nil
} }
e = &logEntry{} e = &logEntry{}
decodeLogEntry(e, line) l.decodeLogEntry(ctx, e, line)
if l.isIgnored(e.QHost) { if l.isIgnored(e.QHost) {
return nil, ts, nil return nil, ts, nil
@ -316,12 +343,13 @@ func (l *queryLog) readNextEntry(
e.client, err = l.client(e.ClientID, e.IP.String(), cache) e.client, err = l.client(e.ClientID, e.IP.String(), cache)
if err != nil { if err != nil {
log.Error( l.logger.ErrorContext(
"querylog: enriching file record at time %s for client %q (clientid %q): %s", ctx,
e.Time, "enriching file record at time",
e.IP, "at_time", e.Time,
e.ClientID, "client_ip", e.IP,
err, "client_id", e.ClientID,
slogutil.KeyError, err,
) )
// Go on and try to match anyway. // Go on and try to match anyway.

View File

@ -5,6 +5,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -36,6 +38,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
} }
l, err := newQueryLog(Config{ l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
FindClient: findClient, FindClient: findClient,
BaseDir: t.TempDir(), BaseDir: t.TempDir(),
RotationIvl: timeutil.Day, RotationIvl: timeutil.Day,
@ -45,7 +48,11 @@ func TestQueryLog_Search_findClient(t *testing.T) {
AnonymizeClientIP: false, AnonymizeClientIP: false,
}) })
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(l.Close)
ctx := testutil.ContextWithTimeout(t, testTimeout)
t.Cleanup(func() {
l.Close(ctx)
})
q := &dns.Msg{ q := &dns.Msg{
Question: []dns.Question{{ Question: []dns.Question{{
@ -81,7 +88,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
olderThan: time.Now().Add(10 * time.Second), olderThan: time.Now().Add(10 * time.Second),
limit: 3, limit: 3,
} }
entries, _ := l.search(sp) entries, _ := l.search(ctx, sp)
assert.Equal(t, 2, findClientCalls) assert.Equal(t, 2, findClientCalls)
require.Len(t, entries, 3) require.Len(t, entries, 3)

View File

@ -1,7 +1,9 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"log/slog"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -87,7 +89,12 @@ func ctDomainOrClientCaseNonStrict(
// quickMatch quickly checks if the line matches the given search criterion. // quickMatch quickly checks if the line matches the given search criterion.
// It returns false if the like doesn't match. This method is only here for // It returns false if the like doesn't match. This method is only here for
// optimization purposes. // optimization purposes.
func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) { func (c *searchCriterion) quickMatch(
ctx context.Context,
logger *slog.Logger,
line string,
findClient quickMatchClientFunc,
) (ok bool) {
switch c.criterionType { switch c.criterionType {
case ctTerm: case ctTerm:
host := readJSONValue(line, `"QH":"`) host := readJSONValue(line, `"QH":"`)
@ -95,7 +102,7 @@ func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFun
clientID := readJSONValue(line, `"CID":"`) clientID := readJSONValue(line, `"CID":"`)
var name string var name string
if cli := findClient(clientID, ip); cli != nil { if cli := findClient(ctx, logger, clientID, ip); cli != nil {
name = cli.Name name = cli.Name
} }

View File

@ -1,6 +1,10 @@
package querylog package querylog
import "time" import (
"context"
"log/slog"
"time"
)
// searchParams represent the search query sent by the client. // searchParams represent the search query sent by the client.
type searchParams struct { type searchParams struct {
@ -35,14 +39,23 @@ func newSearchParams() *searchParams {
} }
// quickMatchClientFunc is a simplified client finder for quick matches. // quickMatchClientFunc is a simplified client finder for quick matches.
type quickMatchClientFunc = func(clientID, ip string) (c *Client) type quickMatchClientFunc = func(
ctx context.Context,
logger *slog.Logger,
clientID, ip string,
) (c *Client)
// quickMatch quickly checks if the line matches the given search parameters. // quickMatch quickly checks if the line matches the given search parameters.
// It returns false if the line doesn't match. This method is only here for // It returns false if the line doesn't match. This method is only here for
// optimization purposes. // optimization purposes.
func (s *searchParams) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) { func (s *searchParams) quickMatch(
ctx context.Context,
logger *slog.Logger,
line string,
findClient quickMatchClientFunc,
) (ok bool) {
for _, c := range s.searchCriteria { for _, c := range s.searchCriteria {
if !c.quickMatch(line, findClient) { if !c.quickMatch(ctx, logger, line, findClient) {
return false return false
} }
} }