Merge branch 'master' into AGDNS-2556-custom-update-url

This commit is contained in:
Eugene Burkov 2024-11-22 17:51:46 +03:00
commit c9efb412e7
25 changed files with 586 additions and 253 deletions

View File

@ -32,6 +32,14 @@ NOTE: Add new changes BELOW THIS COMMENT.
- The release executables are now signed.
### Added
- The `--no-permcheck` command-line option to disable checking and migration of
permissions for the security-sensitive files and directories, which caused
issues on Windows ([#7400]).
[#7400]: https://github.com/AdguardTeam/AdGuardHome/issues/7400
[go-1.23.3]: https://groups.google.com/g/golang-announce/c/X5KodEJYuqI
<!--

View File

@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
@ -506,7 +505,7 @@ func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
// RemoveByName removes persistent client information. ok is false if no such
// client exists by that name.
func (s *Storage) RemoveByName(name string) (ok bool) {
func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
s.mu.Lock()
defer s.mu.Unlock()
@ -516,7 +515,7 @@ func (s *Storage) RemoveByName(name string) (ok bool) {
}
if err := p.CloseUpstreams(); err != nil {
log.Error("client storage: removing client %q: %s", p.Name, err)
s.logger.ErrorContext(ctx, "removing client", "name", p.Name, slogutil.KeyError, err)
}
s.index.remove(p)

View File

@ -735,7 +735,7 @@ func TestStorage_RemoveByName(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.want(t, s.RemoveByName(tc.cliName))
tc.want(t, s.RemoveByName(ctx, tc.cliName))
})
}
@ -744,8 +744,8 @@ func TestStorage_RemoveByName(t *testing.T) {
err = s.Add(ctx, existingClient)
require.NoError(t, err)
assert.True(t, s.RemoveByName(existingName))
assert.False(t, s.RemoveByName(existingName))
assert.True(t, s.RemoveByName(ctx, existingName))
assert.False(t, s.RemoveByName(ctx, existingName))
})
}

View File

@ -369,7 +369,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
return
}
if !clients.storage.RemoveByName(cj.Name) {
if !clients.storage.RemoveByName(r.Context(), cj.Name) {
aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
return

View File

@ -48,11 +48,11 @@ func onConfigModified() {
// initDNS updates all the fields of the [Context] needed to initialize the DNS
// server and initializes it at last. It also must not be called unless
// [config] and [Context] are initialized. l must not be nil.
func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
anonymizer := config.anonymizer()
statsConf := stats.Config{
Logger: l.With(slogutil.KeyPrefix, "stats"),
Logger: baseLogger.With(slogutil.KeyPrefix, "stats"),
Filename: filepath.Join(statsDir, "stats.db"),
Limit: config.Stats.Interval.Duration,
ConfigModified: onConfigModified,
@ -73,6 +73,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
}
conf := querylog.Config{
Logger: baseLogger.With(slogutil.KeyPrefix, "querylog"),
Anonymizer: anonymizer,
ConfigModified: onConfigModified,
HTTPRegister: httpRegister,
@ -113,7 +114,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
anonymizer,
httpRegister,
tlsConf,
l,
baseLogger,
)
}
@ -457,7 +458,8 @@ func startDNSServer() error {
Context.filters.EnableFilters(false)
// TODO(s.chzhen): Pass context.
err := Context.clients.Start(context.TODO())
ctx := context.TODO()
err := Context.clients.Start(ctx)
if err != nil {
return fmt.Errorf("starting clients container: %w", err)
}
@ -469,7 +471,11 @@ func startDNSServer() error {
Context.filters.Start()
Context.stats.Start()
Context.queryLog.Start()
err = Context.queryLog.Start(ctx)
if err != nil {
return fmt.Errorf("starting query log: %w", err)
}
return nil
}
@ -525,12 +531,16 @@ func closeDNSServer() {
if Context.stats != nil {
err := Context.stats.Close()
if err != nil {
log.Debug("closing stats: %s", err)
log.Error("closing stats: %s", err)
}
}
if Context.queryLog != nil {
Context.queryLog.Close()
// TODO(s.chzhen): Pass context.
err := Context.queryLog.Shutdown(context.TODO())
if err != nil {
log.Error("closing query log: %s", err)
}
}
log.Debug("all dns modules are closed")

View File

@ -159,7 +159,7 @@ func setupContext(opts options) (err error) {
if Context.firstRun {
log.Info("This is the first time AdGuard Home is launched")
checkPermissions()
checkNetworkPermissions()
return nil
}
@ -666,12 +666,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
}
}
if permcheck.NeedsMigration(confPath) {
permcheck.Migrate(Context.workDir, dataDir, statsDir, querylogDir, confPath)
if !opts.noPermCheck {
checkPermissions(Context.workDir, confPath, dataDir, statsDir, querylogDir)
}
permcheck.Check(Context.workDir, dataDir, statsDir, querylogDir, confPath)
Context.web.start()
// Wait for other goroutines to complete their job.
@ -740,6 +738,16 @@ func newUpdater(
})
}
// checkPermissions checks and migrates permissions of the files and directories
// used by AdGuard Home, if needed.
func checkPermissions(workDir, confPath, dataDir, statsDir, querylogDir string) {
if permcheck.NeedsMigration(confPath) {
permcheck.Migrate(workDir, dataDir, statsDir, querylogDir, confPath)
}
permcheck.Check(workDir, dataDir, statsDir, querylogDir, confPath)
}
// initUsers initializes context auth module. Clears config users field.
func initUsers() (auth *Auth, err error) {
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
@ -799,8 +807,9 @@ func startMods(l *slog.Logger) (err error) {
return nil
}
// Check if the current user permissions are enough to run AdGuard Home
func checkPermissions() {
// checkNetworkPermissions checks if the current user permissions are enough to
// use the required networking functionality.
func checkNetworkPermissions() {
log.Info("Checking if AdGuard Home has necessary permissions")
if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {

View File

@ -78,6 +78,10 @@ type options struct {
// localFrontend forces AdGuard Home to use the frontend files from disk
// rather than the ones that have been compiled into the binary.
localFrontend bool
// noPermCheck disables checking and migration of permissions for the
// security-sensitive files.
noPermCheck bool
}
// initCmdLineOpts completes initialization of the global command-line option
@ -305,6 +309,15 @@ var cmdLineOpts = []cmdLineOpt{{
description: "Run in GL-Inet compatibility mode.",
longName: "glinet",
shortName: "",
}, {
updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.noPermCheck = true; return o, nil },
effect: nil,
serialize: func(o options) (val string, ok bool) { return "", o.noPermCheck },
description: "Skip checking and migration of permissions " +
"of security-sensitive files.",
longName: "no-permcheck",
shortName: "",
}, {
updateWithValue: nil,
updateNoValue: nil,

View File

@ -89,6 +89,12 @@ type options struct {
// TODO(a.garipov): Use.
performUpdate bool
// noPermCheck, if true, instructs AdGuard Home to skip checking and
// migrating the permissions of its security-sensitive files.
//
// TODO(e.burkov): Use.
noPermCheck bool
// verbose, if true, instructs AdGuard Home to enable verbose logging.
verbose bool
@ -110,7 +116,8 @@ const (
disableUpdateIdx
glinetModeIdx
helpIdx
localFrontend
localFrontendIdx
noPermCheckIdx
performUpdateIdx
verboseIdx
versionIdx
@ -214,7 +221,7 @@ var commandLineOptions = []*commandLineOption{
valueType: "",
},
localFrontend: {
localFrontendIdx: {
defaultValue: false,
description: "Use local frontend directories.",
long: "local-frontend",
@ -222,6 +229,14 @@ var commandLineOptions = []*commandLineOption{
valueType: "",
},
noPermCheckIdx: {
defaultValue: false,
description: "Skip checking the permissions of security-sensitive files.",
long: "no-permcheck",
short: "",
valueType: "",
},
performUpdateIdx: {
defaultValue: false,
description: "Update the current binary and restart the service in case it's installed.",
@ -264,7 +279,8 @@ func parseOptions(cmdName string, args []string) (opts *options, err error) {
disableUpdateIdx: &opts.disableUpdate,
glinetModeIdx: &opts.glinetMode,
helpIdx: &opts.help,
localFrontend: &opts.localFrontend,
localFrontendIdx: &opts.localFrontend,
noPermCheckIdx: &opts.noPermCheck,
performUpdateIdx: &opts.performUpdate,
verboseIdx: &opts.verbose,
versionIdx: &opts.version,

View File

@ -1,6 +1,7 @@
package querylog
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
@ -13,7 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
)
@ -174,26 +175,32 @@ var logEntryHandlers = map[string]logEntryHandler{
}
// decodeResultRuleKey decodes the token of "Rules" type to logEntry struct.
func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
func (l *queryLog) decodeResultRuleKey(
ctx context.Context,
key string,
i int,
dec *json.Decoder,
ent *logEntry,
) {
var vToken json.Token
switch key {
case "FilterListID":
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
if n, ok := vToken.(json.Number); ok {
id, _ := n.Int64()
ent.Result.Rules[i].FilterListID = rulelist.URLFilterID(id)
}
case "IP":
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
if ipStr, ok := vToken.(string); ok {
if ip, err := netip.ParseAddr(ipStr); err == nil {
ent.Result.Rules[i].IP = ip
} else {
log.Debug("querylog: decoding ipStr value: %s", err)
l.logger.DebugContext(ctx, "decoding ip", "value", ipStr, slogutil.KeyError, err)
}
}
case "Text":
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
if s, ok := vToken.(string); ok {
ent.Result.Rules[i].Text = s
}
@ -204,7 +211,8 @@ func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
// decodeVTokenAndAddRule decodes the "Rules" toke as [filtering.ResultRule]
// and then adds the decoded object to the slice of result rules.
func decodeVTokenAndAddRule(
func (l *queryLog) decodeVTokenAndAddRule(
ctx context.Context,
key string,
i int,
dec *json.Decoder,
@ -215,7 +223,12 @@ func decodeVTokenAndAddRule(
vToken, err := dec.Token()
if err != nil {
if err != io.EOF {
log.Debug("decodeResultRuleKey %s err: %s", key, err)
l.logger.DebugContext(
ctx,
"decoding result rule key",
"key", key,
slogutil.KeyError, err,
)
}
return newRules, nil
@ -230,12 +243,14 @@ func decodeVTokenAndAddRule(
// decodeResultRules parses the dec's tokens into logEntry ent interpreting it
// as a slice of the result rules.
func decodeResultRules(dec *json.Decoder, ent *logEntry) {
func (l *queryLog) decodeResultRules(ctx context.Context, dec *json.Decoder, ent *logEntry) {
const msgPrefix = "decoding result rules"
for {
delimToken, err := dec.Token()
if err != nil {
if err != io.EOF {
log.Debug("decodeResultRules err: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
}
return
@ -244,13 +259,17 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) {
if d, ok := delimToken.(json.Delim); !ok {
return
} else if d != '[' {
log.Debug("decodeResultRules: unexpected delim %q", d)
l.logger.DebugContext(
ctx,
msgPrefix,
slogutil.KeyError, newUnexpectedDelimiterError(d),
)
}
err = decodeResultRuleToken(dec, ent)
err = l.decodeResultRuleToken(ctx, dec, ent)
if err != nil {
if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
log.Debug("decodeResultRules err: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; rule token", slogutil.KeyError, err)
}
return
@ -259,7 +278,11 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) {
}
// decodeResultRuleToken decodes the tokens of "Rules" type to the logEntry ent.
func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
func (l *queryLog) decodeResultRuleToken(
ctx context.Context,
dec *json.Decoder,
ent *logEntry,
) (err error) {
i := 0
for {
var keyToken json.Token
@ -287,7 +310,7 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
return fmt.Errorf("keyToken is %T (%[1]v) and not string", keyToken)
}
decodeResultRuleKey(key, i, dec, ent)
l.decodeResultRuleKey(ctx, key, i, dec, ent)
}
}
@ -296,12 +319,14 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
// other occurrences of DNSRewriteResult in the entry since hosts container's
// rewrites currently has the highest priority along the entire filtering
// pipeline.
func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
func (l *queryLog) decodeResultReverseHosts(ctx context.Context, dec *json.Decoder, ent *logEntry) {
const msgPrefix = "decoding result reverse hosts"
for {
itemToken, err := dec.Token()
if err != nil {
if err != io.EOF {
log.Debug("decodeResultReverseHosts err: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
}
return
@ -315,7 +340,11 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
return
}
log.Debug("decodeResultReverseHosts: unexpected delim %q", v)
l.logger.DebugContext(
ctx,
msgPrefix,
slogutil.KeyError, newUnexpectedDelimiterError(v),
)
return
case string:
@ -346,12 +375,14 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
// decodeResultIPList parses the dec's tokens into logEntry ent interpreting it
// as the result IP addresses list.
func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
func (l *queryLog) decodeResultIPList(ctx context.Context, dec *json.Decoder, ent *logEntry) {
const msgPrefix = "decoding result ip list"
for {
itemToken, err := dec.Token()
if err != nil {
if err != io.EOF {
log.Debug("decodeResultIPList err: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
}
return
@ -365,7 +396,11 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
return
}
log.Debug("decodeResultIPList: unexpected delim %q", v)
l.logger.DebugContext(
ctx,
msgPrefix,
slogutil.KeyError, newUnexpectedDelimiterError(v),
)
return
case string:
@ -382,7 +417,14 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
// decodeResultDNSRewriteResultKey decodes the token of "DNSRewriteResult" type
// to the logEntry struct.
func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntry) {
func (l *queryLog) decodeResultDNSRewriteResultKey(
ctx context.Context,
key string,
dec *json.Decoder,
ent *logEntry,
) {
const msgPrefix = "decoding result dns rewrite result key"
var err error
switch key {
@ -391,7 +433,7 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
vToken, err = dec.Token()
if err != nil {
if err != io.EOF {
log.Debug("decodeResultDNSRewriteResultKey err: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
}
return
@ -419,7 +461,7 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
// decoding and correct the values.
err = dec.Decode(&ent.Result.DNSRewriteResult.Response)
if err != nil {
log.Debug("decodeResultDNSRewriteResultKey response err: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; response", slogutil.KeyError, err)
}
ent.parseDNSRewriteResultIPs()
@ -430,12 +472,18 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
// decodeResultDNSRewriteResult parses the dec's tokens into logEntry ent
// interpreting it as the result DNSRewriteResult.
func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
func (l *queryLog) decodeResultDNSRewriteResult(
ctx context.Context,
dec *json.Decoder,
ent *logEntry,
) {
const msgPrefix = "decoding result dns rewrite result"
for {
key, err := parseKeyToken(dec)
if err != nil {
if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
log.Debug("decodeResultDNSRewriteResult: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
}
return
@ -445,7 +493,7 @@ func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
continue
}
decodeResultDNSRewriteResultKey(key, dec, ent)
l.decodeResultDNSRewriteResultKey(ctx, key, dec, ent)
}
}
@ -508,14 +556,16 @@ func parseKeyToken(dec *json.Decoder) (key string, err error) {
}
// decodeResult decodes a token of "Result" type to logEntry struct.
func decodeResult(dec *json.Decoder, ent *logEntry) {
func (l *queryLog) decodeResult(ctx context.Context, dec *json.Decoder, ent *logEntry) {
const msgPrefix = "decoding result"
defer translateResult(ent)
for {
key, err := parseKeyToken(dec)
if err != nil {
if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
log.Debug("decodeResult: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
}
return
@ -525,10 +575,8 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
continue
}
decHandler, ok := resultDecHandlers[key]
ok := l.resultDecHandler(ctx, key, dec, ent)
if ok {
decHandler(dec, ent)
continue
}
@ -543,7 +591,7 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
}
if err = handler(val, ent); err != nil {
log.Debug("decodeResult handler err: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; handler", slogutil.KeyError, err)
return
}
@ -636,16 +684,34 @@ var resultHandlers = map[string]logEntryHandler{
},
}
// resultDecHandlers is the map of decode handlers for various keys.
var resultDecHandlers = map[string]func(dec *json.Decoder, ent *logEntry){
"ReverseHosts": decodeResultReverseHosts,
"IPList": decodeResultIPList,
"Rules": decodeResultRules,
"DNSRewriteResult": decodeResultDNSRewriteResult,
// resultDecHandlers calls a decode handler for key if there is one.
func (l *queryLog) resultDecHandler(
ctx context.Context,
name string,
dec *json.Decoder,
ent *logEntry,
) (ok bool) {
ok = true
switch name {
case "ReverseHosts":
l.decodeResultReverseHosts(ctx, dec, ent)
case "IPList":
l.decodeResultIPList(ctx, dec, ent)
case "Rules":
l.decodeResultRules(ctx, dec, ent)
case "DNSRewriteResult":
l.decodeResultDNSRewriteResult(ctx, dec, ent)
default:
ok = false
}
return ok
}
// decodeLogEntry decodes string str to logEntry ent.
func decodeLogEntry(ent *logEntry, str string) {
func (l *queryLog) decodeLogEntry(ctx context.Context, ent *logEntry, str string) {
const msgPrefix = "decoding log entry"
dec := json.NewDecoder(strings.NewReader(str))
dec.UseNumber()
@ -653,7 +719,7 @@ func decodeLogEntry(ent *logEntry, str string) {
keyToken, err := dec.Token()
if err != nil {
if err != io.EOF {
log.Debug("decodeLogEntry err: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
}
return
@ -665,13 +731,14 @@ func decodeLogEntry(ent *logEntry, str string) {
key, ok := keyToken.(string)
if !ok {
log.Debug("decodeLogEntry: keyToken is %T (%[1]v) and not string", keyToken)
err = fmt.Errorf("%s: keyToken is %T (%[2]v) and not string", msgPrefix, keyToken)
l.logger.DebugContext(ctx, msgPrefix, slogutil.KeyError, err)
return
}
if key == "Result" {
decodeResult(dec, ent)
l.decodeResult(ctx, dec, ent)
continue
}
@ -687,9 +754,14 @@ func decodeLogEntry(ent *logEntry, str string) {
}
if err = handler(val, ent); err != nil {
log.Debug("decodeLogEntry handler err: %s", err)
l.logger.DebugContext(ctx, msgPrefix+"; handler", slogutil.KeyError, err)
return
}
}
}
// newUnexpectedDelimiterError is a helper for creating informative errors.
func newUnexpectedDelimiterError(d json.Delim) (err error) {
return fmt.Errorf("unexpected delimiter: %q", d)
}

View File

@ -3,27 +3,35 @@ package querylog
import (
"bytes"
"encoding/base64"
"log/slog"
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Common constants for tests.
const testTimeout = 1 * time.Second
func TestDecodeLogEntry(t *testing.T) {
logOutput := &bytes.Buffer{}
l := &queryLog{
logger: slog.New(slog.NewTextHandler(logOutput, &slog.HandlerOptions{
Level: slog.LevelDebug,
ReplaceAttr: slogutil.RemoveTime,
})),
}
aghtest.ReplaceLogWriter(t, logOutput)
aghtest.ReplaceLogLevel(t, log.DEBUG)
ctx := testutil.ContextWithTimeout(t, testTimeout)
t.Run("success", func(t *testing.T) {
const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==`
@ -92,7 +100,7 @@ func TestDecodeLogEntry(t *testing.T) {
}
got := &logEntry{}
decodeLogEntry(got, data)
l.decodeLogEntry(ctx, got, data)
s := logOutput.String()
assert.Empty(t, s)
@ -113,11 +121,11 @@ func TestDecodeLogEntry(t *testing.T) {
}, {
name: "bad_filter_id_old_rule",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"FilterID":1.5},"Elapsed":837429}`,
want: "decodeResult handler err: strconv.ParseInt: parsing \"1.5\": invalid syntax\n",
want: `level=DEBUG msg="decoding result; handler" err="strconv.ParseInt: parsing \"1.5\": invalid syntax"`,
}, {
name: "bad_is_filtered",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":trooe,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry err: invalid character 'o' in literal true (expecting 'u')\n",
want: `level=DEBUG msg="decoding log entry; token" err="invalid character 'o' in literal true (expecting 'u')"`,
}, {
name: "bad_elapsed",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":-1}`,
@ -129,7 +137,7 @@ func TestDecodeLogEntry(t *testing.T) {
}, {
name: "bad_time",
log: `{"IP":"127.0.0.1","T":"12/09/1998T15:00:00.000000+05:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry handler err: parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\"\n",
want: `level=DEBUG msg="decoding log entry; handler" err="parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\""`,
}, {
name: "bad_host",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":6,"QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
@ -149,7 +157,7 @@ func TestDecodeLogEntry(t *testing.T) {
}, {
name: "very_bad_client_proto",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"dog","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry handler err: invalid client proto: \"dog\"\n",
want: `level=DEBUG msg="decoding log entry; handler" err="invalid client proto: \"dog\""`,
}, {
name: "bad_answer",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":0.9,"Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
@ -157,7 +165,7 @@ func TestDecodeLogEntry(t *testing.T) {
}, {
name: "very_bad_answer",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry handler err: illegal base64 data at input byte 61\n",
want: `level=DEBUG msg="decoding log entry; handler" err="illegal base64 data at input byte 61"`,
}, {
name: "bad_rule",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"Rule":false},"Elapsed":837429}`,
@ -169,22 +177,25 @@ func TestDecodeLogEntry(t *testing.T) {
}, {
name: "bad_reverse_hosts",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":[{}]},"Elapsed":837429}`,
want: "decodeResultReverseHosts: unexpected delim \"{\"\n",
want: `level=DEBUG msg="decoding result reverse hosts" err="unexpected delimiter: \"{\""`,
}, {
name: "bad_ip_list",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":["example.net"],"IPList":[{}]},"Elapsed":837429}`,
want: "decodeResultIPList: unexpected delim \"{\"\n",
want: `level=DEBUG msg="decoding result ip list" err="unexpected delimiter: \"{\""`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
decodeLogEntry(new(logEntry), tc.log)
s := logOutput.String()
l.decodeLogEntry(ctx, new(logEntry), tc.log)
got := logOutput.String()
if tc.want == "" {
assert.Empty(t, s)
assert.Empty(t, got)
} else {
assert.True(t, strings.HasSuffix(s, tc.want), "got %q", s)
require.NotEmpty(t, got)
// Remove newline.
got = got[:len(got)-1]
assert.Equal(t, tc.want, got)
}
logOutput.Reset()
@ -200,6 +211,12 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) {
aaaa2 = aaaa1.Next()
)
l := &queryLog{
logger: slogutil.NewDiscardLogger(),
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
testCases := []struct {
want *logEntry
entry string
@ -249,7 +266,7 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := &logEntry{}
decodeLogEntry(e, tc.entry)
l.decodeLogEntry(ctx, e, tc.entry)
assert.Equal(t, tc.want, e)
})

View File

@ -1,12 +1,14 @@
package querylog
import (
"context"
"log/slog"
"net"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
)
@ -52,7 +54,7 @@ func (e *logEntry) shallowClone() (clone *logEntry) {
// addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is
// true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any
// errors are logged.
func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
func (e *logEntry) addResponse(ctx context.Context, l *slog.Logger, resp *dns.Msg, isOrig bool) {
if resp == nil {
return
}
@ -65,8 +67,9 @@ func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
e.Answer, err = resp.Pack()
err = errors.Annotate(err, "packing answer: %w")
}
if err != nil {
log.Error("querylog: %s", err)
l.ErrorContext(ctx, "adding data from response", slogutil.KeyError, err)
}
}

View File

@ -1,6 +1,7 @@
package querylog
import (
"context"
"encoding/json"
"fmt"
"math"
@ -15,7 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil"
"golang.org/x/net/idna"
)
@ -74,7 +75,8 @@ func (l *queryLog) initWeb() {
// handleQueryLog is the handler for the GET /control/querylog HTTP API.
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
params, err := parseSearchParams(r)
ctx := r.Context()
params, err := l.parseSearchParams(ctx, r)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err)
@ -87,18 +89,18 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
l.confMu.RLock()
defer l.confMu.RUnlock()
entries, oldest = l.search(params)
entries, oldest = l.search(ctx, params)
}()
resp := entriesToJSON(entries, oldest, l.anonymizer.Load())
resp := l.entriesToJSON(ctx, entries, oldest, l.anonymizer.Load())
aghhttp.WriteJSONResponseOK(w, r, resp)
}
// handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP
// API.
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
l.clear()
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, r *http.Request) {
l.clear(r.Context())
}
// handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP
@ -280,11 +282,12 @@ func getDoubleQuotesEnclosedValue(s *string) bool {
}
// parseSearchCriterion parses a search criterion from the query parameter.
func parseSearchCriterion(q url.Values, name string, ct criterionType) (
ok bool,
sc searchCriterion,
err error,
) {
func (l *queryLog) parseSearchCriterion(
ctx context.Context,
q url.Values,
name string,
ct criterionType,
) (ok bool, sc searchCriterion, err error) {
val := q.Get(name)
if val == "" {
return false, sc, nil
@ -301,7 +304,7 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
// TODO(e.burkov): Make it work with parts of IDNAs somehow.
loweredVal := strings.ToLower(val)
if asciiVal, err = idna.ToASCII(loweredVal); err != nil {
log.Debug("can't convert %q to ascii: %s", val, err)
l.logger.DebugContext(ctx, "converting to ascii", "value", val, slogutil.KeyError, err)
} else if asciiVal == loweredVal {
// Purge asciiVal to prevent checking the same value
// twice.
@ -331,7 +334,10 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
// parseSearchParams parses search parameters from the HTTP request's query
// string.
func parseSearchParams(r *http.Request) (p *searchParams, err error) {
func (l *queryLog) parseSearchParams(
ctx context.Context,
r *http.Request,
) (p *searchParams, err error) {
p = newSearchParams()
q := r.URL.Query()
@ -369,7 +375,7 @@ func parseSearchParams(r *http.Request) (p *searchParams, err error) {
}} {
var ok bool
var c searchCriterion
ok, c, err = parseSearchCriterion(q, v.urlField, v.ct)
ok, c, err = l.parseSearchCriterion(ctx, q, v.urlField, v.ct)
if err != nil {
return nil, err
}

View File

@ -1,6 +1,7 @@
package querylog
import (
"context"
"slices"
"strconv"
"strings"
@ -8,7 +9,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns"
"golang.org/x/net/idna"
)
@ -19,7 +20,8 @@ import (
type jobject = map[string]any
// entriesToJSON converts query log entries to JSON.
func entriesToJSON(
func (l *queryLog) entriesToJSON(
ctx context.Context,
entries []*logEntry,
oldest time.Time,
anonFunc aghnet.IPMutFunc,
@ -28,7 +30,7 @@ func entriesToJSON(
// The elements order is already reversed to be from newer to older.
for _, entry := range entries {
jsonEntry := entryToJSON(entry, anonFunc)
jsonEntry := l.entryToJSON(ctx, entry, anonFunc)
data = append(data, jsonEntry)
}
@ -44,7 +46,11 @@ func entriesToJSON(
}
// entryToJSON converts a log entry's data into an entry for the JSON API.
func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) {
func (l *queryLog) entryToJSON(
ctx context.Context,
entry *logEntry,
anonFunc aghnet.IPMutFunc,
) (jsonEntry jobject) {
hostname := entry.QHost
question := jobject{
"type": entry.QType,
@ -53,7 +59,12 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject)
}
if qhost, err := idna.ToUnicode(hostname); err != nil {
log.Debug("querylog: translating %q into unicode: %s", hostname, err)
l.logger.DebugContext(
ctx,
"translating into unicode",
"hostname", hostname,
slogutil.KeyError, err,
)
} else if qhost != hostname && qhost != "" {
question["unicode_name"] = qhost
}
@ -96,21 +107,26 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject)
jsonEntry["service_name"] = entry.Result.ServiceName
}
setMsgData(entry, jsonEntry)
setOrigAns(entry, jsonEntry)
l.setMsgData(ctx, entry, jsonEntry)
l.setOrigAns(ctx, entry, jsonEntry)
return jsonEntry
}
// setMsgData sets the message data in jsonEntry.
func setMsgData(entry *logEntry, jsonEntry jobject) {
func (l *queryLog) setMsgData(ctx context.Context, entry *logEntry, jsonEntry jobject) {
if len(entry.Answer) == 0 {
return
}
msg := &dns.Msg{}
if err := msg.Unpack(entry.Answer); err != nil {
log.Debug("querylog: failed to unpack dns msg answer: %v: %s", entry.Answer, err)
l.logger.DebugContext(
ctx,
"unpacking dns message",
"answer", entry.Answer,
slogutil.KeyError, err,
)
return
}
@ -126,7 +142,7 @@ func setMsgData(entry *logEntry, jsonEntry jobject) {
}
// setOrigAns sets the original answer data in jsonEntry.
func setOrigAns(entry *logEntry, jsonEntry jobject) {
func (l *queryLog) setOrigAns(ctx context.Context, entry *logEntry, jsonEntry jobject) {
if len(entry.OrigAnswer) == 0 {
return
}
@ -134,7 +150,12 @@ func setOrigAns(entry *logEntry, jsonEntry jobject) {
orig := &dns.Msg{}
err := orig.Unpack(entry.OrigAnswer)
if err != nil {
log.Debug("querylog: orig.Unpack(entry.OrigAnswer): %v: %s", entry.OrigAnswer, err)
l.logger.DebugContext(
ctx,
"setting original answer",
"answer", entry.OrigAnswer,
slogutil.KeyError, err,
)
return
}

View File

@ -2,7 +2,9 @@
package querylog
import (
"context"
"fmt"
"log/slog"
"os"
"sync"
"time"
@ -11,7 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
)
@ -22,6 +24,10 @@ const queryLogFileName = "querylog.json"
// queryLog is a structure that writes and reads the DNS query log.
type queryLog struct {
// logger is used for logging the operation of the query log. It must not
// be nil.
logger *slog.Logger
// confMu protects conf.
confMu *sync.RWMutex
@ -76,24 +82,34 @@ func NewClientProto(s string) (cp ClientProto, err error) {
}
}
func (l *queryLog) Start() {
// type check
var _ QueryLog = (*queryLog)(nil)
// Start implements the [QueryLog] interface for *queryLog.
func (l *queryLog) Start(ctx context.Context) (err error) {
if l.conf.HTTPRegister != nil {
l.initWeb()
}
go l.periodicRotate()
go l.periodicRotate(ctx)
return nil
}
func (l *queryLog) Close() {
// Shutdown implements the [QueryLog] interface for *queryLog.
func (l *queryLog) Shutdown(ctx context.Context) (err error) {
l.confMu.RLock()
defer l.confMu.RUnlock()
if l.conf.FileEnabled {
err := l.flushLogBuffer()
err = l.flushLogBuffer(ctx)
if err != nil {
log.Error("querylog: closing: %s", err)
// Don't wrap the error because it's informative enough as is.
return err
}
}
return nil
}
func checkInterval(ivl time.Duration) (ok bool) {
@ -123,6 +139,7 @@ func validateIvl(ivl time.Duration) (err error) {
return nil
}
// WriteDiskConfig implements the [QueryLog] interface for *queryLog.
func (l *queryLog) WriteDiskConfig(c *Config) {
l.confMu.RLock()
defer l.confMu.RUnlock()
@ -131,7 +148,7 @@ func (l *queryLog) WriteDiskConfig(c *Config) {
}
// Clear memory buffer and remove log files
func (l *queryLog) clear() {
func (l *queryLog) clear(ctx context.Context) {
l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock()
@ -146,19 +163,24 @@ func (l *queryLog) clear() {
oldLogFile := l.logFile + ".1"
err := os.Remove(oldLogFile)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("removing old log file %q: %s", oldLogFile, err)
l.logger.ErrorContext(
ctx,
"removing old log file",
"file", oldLogFile,
slogutil.KeyError, err,
)
}
err = os.Remove(l.logFile)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("removing log file %q: %s", l.logFile, err)
l.logger.ErrorContext(ctx, "removing log file", "file", l.logFile, slogutil.KeyError, err)
}
log.Debug("querylog: cleared")
l.logger.DebugContext(ctx, "cleared")
}
// newLogEntry creates an instance of logEntry from parameters.
func newLogEntry(params *AddParams) (entry *logEntry) {
func newLogEntry(ctx context.Context, logger *slog.Logger, params *AddParams) (entry *logEntry) {
q := params.Question.Question[0]
qHost := aghnet.NormalizeDomain(q.Name)
@ -187,8 +209,8 @@ func newLogEntry(params *AddParams) (entry *logEntry) {
entry.ReqECS = params.ReqECS.String()
}
entry.addResponse(params.Answer, false)
entry.addResponse(params.OrigAnswer, true)
entry.addResponse(ctx, logger, params.Answer, false)
entry.addResponse(ctx, logger, params.OrigAnswer, true)
return entry
}
@ -209,9 +231,12 @@ func (l *queryLog) Add(params *AddParams) {
return
}
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
err := params.validate()
if err != nil {
log.Error("querylog: adding record: %s, skipping", err)
l.logger.ErrorContext(ctx, "adding record", slogutil.KeyError, err)
return
}
@ -220,7 +245,7 @@ func (l *queryLog) Add(params *AddParams) {
params.Result = &filtering.Result{}
}
entry := newLogEntry(params)
entry := newLogEntry(ctx, l.logger, params)
l.bufferLock.Lock()
defer l.bufferLock.Unlock()
@ -232,9 +257,9 @@ func (l *queryLog) Add(params *AddParams) {
// TODO(s.chzhen): Fix occasional rewrite of entires.
go func() {
flushErr := l.flushLogBuffer()
flushErr := l.flushLogBuffer(ctx)
if flushErr != nil {
log.Error("querylog: flushing after adding: %s", flushErr)
l.logger.ErrorContext(ctx, "flushing after adding", slogutil.KeyError, flushErr)
}
}()
}
@ -247,7 +272,8 @@ func (l *queryLog) ShouldLog(host string, _, _ uint16, ids []string) bool {
c, err := l.findClient(ids)
if err != nil {
log.Error("querylog: finding client: %s", err)
// TODO(s.chzhen): Pass context.
l.logger.ErrorContext(context.TODO(), "finding client", slogutil.KeyError, err)
}
if c != nil && c.IgnoreQueryLog {

View File

@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
@ -14,14 +15,11 @@ import (
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// TestQueryLog tests adding and loading (with filtering) entries from disk and
// memory.
func TestQueryLog(t *testing.T) {
l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true,
FileEnabled: true,
RotationIvl: timeutil.Day,
@ -30,16 +28,21 @@ func TestQueryLog(t *testing.T) {
})
require.NoError(t, err)
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// Write to disk (first file).
require.NoError(t, l.flushLogBuffer())
require.NoError(t, l.flushLogBuffer(ctx))
// Start writing to the second file.
require.NoError(t, l.rotate())
require.NoError(t, l.rotate(ctx))
// Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// Write to disk.
require.NoError(t, l.flushLogBuffer())
require.NoError(t, l.flushLogBuffer(ctx))
// Add memory entries.
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
@ -119,8 +122,9 @@ func TestQueryLog(t *testing.T) {
params := newSearchParams()
params.searchCriteria = tc.sCr
entries, _ := l.search(params)
entries, _ := l.search(ctx, params)
require.Len(t, entries, len(tc.want))
for _, want := range tc.want {
assertLogEntry(t, entries[want.num], want.host, want.answer, want.client)
}
@ -130,6 +134,7 @@ func TestQueryLog(t *testing.T) {
func TestQueryLogOffsetLimit(t *testing.T) {
l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true,
RotationIvl: timeutil.Day,
MemSize: 100,
@ -142,12 +147,16 @@ func TestQueryLogOffsetLimit(t *testing.T) {
firstPageDomain = "first.example.org"
secondPageDomain = "second.example.org"
)
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add entries to the log.
for range entNum {
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// Write them to the first file.
require.NoError(t, l.flushLogBuffer())
require.NoError(t, l.flushLogBuffer(ctx))
// Add more to the in-memory part of log.
for range entNum {
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
@ -191,8 +200,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
params.offset = tc.offset
params.limit = tc.limit
entries, _ := l.search(params)
entries, _ := l.search(ctx, params)
require.Len(t, entries, tc.wantLen)
if tc.wantLen > 0 {
@ -205,6 +213,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
func TestQueryLogMaxFileScanEntries(t *testing.T) {
l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true,
FileEnabled: true,
RotationIvl: timeutil.Day,
@ -213,20 +222,21 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
})
require.NoError(t, err)
ctx := testutil.ContextWithTimeout(t, testTimeout)
const entNum = 10
// Add entries to the log.
for range entNum {
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
}
// Write them to disk.
require.NoError(t, l.flushLogBuffer())
require.NoError(t, l.flushLogBuffer(ctx))
params := newSearchParams()
for _, maxFileScanEntries := range []int{5, 0} {
t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) {
params.maxFileScanEntries = maxFileScanEntries
entries, _ := l.search(params)
entries, _ := l.search(ctx, params)
assert.Len(t, entries, entNum-maxFileScanEntries)
})
}
@ -234,6 +244,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
func TestQueryLogFileDisabled(t *testing.T) {
l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true,
FileEnabled: false,
RotationIvl: timeutil.Day,
@ -248,8 +259,10 @@ func TestQueryLogFileDisabled(t *testing.T) {
addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
params := newSearchParams()
ll, _ := l.search(params)
ctx := testutil.ContextWithTimeout(t, testTimeout)
ll, _ := l.search(ctx, params)
require.Len(t, ll, 2)
assert.Equal(t, "example3.org", ll[0].QHost)
assert.Equal(t, "example2.org", ll[1].QHost)
}

View File

@ -1,8 +1,10 @@
package querylog
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"strings"
"sync"
@ -10,7 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
const (
@ -102,7 +104,11 @@ func (q *qLogFile) validateQLogLineIdx(lineIdx, lastProbeLineIdx, ts, fSize int6
// for so that when we call "ReadNext" this line was returned.
// - Depth of the search (how many times we compared timestamps).
// - If we could not find it, it returns one of the errors described above.
func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) {
func (q *qLogFile) seekTS(
ctx context.Context,
logger *slog.Logger,
timestamp int64,
) (pos int64, depth int, err error) {
q.lock.Lock()
defer q.lock.Unlock()
@ -151,7 +157,7 @@ func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) {
lastProbeLineIdx = lineIdx
// Get the timestamp from the query log record.
ts := readQLogTimestamp(line)
ts := readQLogTimestamp(ctx, logger, line)
if ts == 0 {
return 0, depth, fmt.Errorf(
"looking up timestamp %d in %q: record %q has empty timestamp",
@ -385,20 +391,22 @@ func readJSONValue(s, prefix string) string {
}
// readQLogTimestamp reads the timestamp field from the query log line.
func readQLogTimestamp(str string) int64 {
func readQLogTimestamp(ctx context.Context, logger *slog.Logger, str string) int64 {
val := readJSONValue(str, `"T":"`)
if len(val) == 0 {
val = readJSONValue(str, `"Time":"`)
}
if len(val) == 0 {
log.Error("Couldn't find timestamp: %s", str)
logger.ErrorContext(ctx, "couldn't find timestamp", "line", str)
return 0
}
tm, err := time.Parse(time.RFC3339Nano, val)
if err != nil {
log.Error("Couldn't parse timestamp: %s", val)
logger.ErrorContext(ctx, "couldn't parse timestamp", "value", val, slogutil.KeyError, err)
return 0
}

View File

@ -12,6 +12,7 @@ import (
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -24,6 +25,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) {
f, err := os.CreateTemp(dir, "*.txt")
require.NoError(t, err)
// Use defer and not t.Cleanup to make sure that the file is closed
// after this function is done.
defer func() {
@ -108,6 +110,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
// Calculate the expected position.
fileInfo, err := q.file.Stat()
require.NoError(t, err)
var expPos int64
if expPos = fileInfo.Size(); expPos > 0 {
expPos--
@ -129,6 +132,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
}
require.Equal(t, io.EOF, err)
assert.Equal(t, tc.linesNum, read)
})
}
@ -146,6 +150,9 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
num: 10,
}}
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
for _, l := range linesCases {
testCases := []struct {
name string
@ -171,16 +178,19 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
t.Run(l.name+"_"+tc.name, func(t *testing.T) {
line, err := getQLogFileLine(q, tc.line)
require.NoError(t, err)
ts := readQLogTimestamp(line)
ts := readQLogTimestamp(ctx, logger, line)
assert.NotEqualValues(t, 0, ts)
// Try seeking to that line now.
pos, _, err := q.seekTS(ts)
pos, _, err := q.seekTS(ctx, logger, ts)
require.NoError(t, err)
assert.NotEqualValues(t, 0, pos)
testLine, err := q.ReadNext()
require.NoError(t, err)
assert.Equal(t, line, testLine)
})
}
@ -199,6 +209,9 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
num: 10,
}}
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
for _, l := range linesCases {
testCases := []struct {
name string
@ -221,14 +234,14 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
line, err := getQLogFileLine(q, l.num/2)
require.NoError(t, err)
testCases[2].ts = readQLogTimestamp(line) - 1
testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.NotEqualValues(t, 0, tc.ts)
var depth int
_, depth, err = q.seekTS(tc.ts)
_, depth, err = q.seekTS(ctx, logger, tc.ts)
assert.NotEmpty(t, l.num)
require.Error(t, err)
@ -262,11 +275,13 @@ func TestQLogFile(t *testing.T) {
// Seek to the start.
pos, err := q.SeekStart()
require.NoError(t, err)
assert.Greater(t, pos, int64(0))
// Read first line.
line, err := q.ReadNext()
require.NoError(t, err)
assert.Contains(t, line, "0.0.0.2")
assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line)
@ -274,6 +289,7 @@ func TestQLogFile(t *testing.T) {
// Read second line.
line, err = q.ReadNext()
require.NoError(t, err)
assert.EqualValues(t, 0, q.position)
assert.Contains(t, line, "0.0.0.1")
assert.True(t, strings.HasPrefix(line, "{"), line)
@ -282,12 +298,14 @@ func TestQLogFile(t *testing.T) {
// Try reading again (there's nothing to read anymore).
line, err = q.ReadNext()
require.Equal(t, io.EOF, err)
assert.Empty(t, line)
}
func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) {
f, err := os.CreateTemp(t.TempDir(), "*.txt")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, f.Close)
_, err = f.WriteString(data)
@ -295,6 +313,7 @@ func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) {
file, err = newQLogFile(f.Name())
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, file.Close)
return file
@ -308,6 +327,9 @@ func TestQLog_Seek(t *testing.T) {
`{"T":"` + strV + `"}` + nl
timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00")
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
testCases := []struct {
wantErr error
name string
@ -340,8 +362,10 @@ func TestQLog_Seek(t *testing.T) {
q := newTestQLogFileData(t, data)
_, depth, err := q.seekTS(timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano())
ts := timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano()
_, depth, err := q.seekTS(ctx, logger, ts)
require.Truef(t, errors.Is(err, tc.wantErr), "%v", err)
assert.Equal(t, tc.wantDepth, depth)
})
}

View File

@ -1,12 +1,14 @@
package querylog
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// qLogReader allows reading from multiple query log files in the reverse
@ -16,6 +18,10 @@ import (
// pointer to a particular query log file, and to a specific position in this
// file, and it reads lines in reverse order starting from that position.
type qLogReader struct {
// logger is used for logging the operation of the query log reader. It
// must not be nil.
logger *slog.Logger
// qFiles is an array with the query log files. The order is from oldest
// to newest.
qFiles []*qLogFile
@ -25,7 +31,7 @@ type qLogReader struct {
}
// newQLogReader initializes a qLogReader instance with the specified files.
func newQLogReader(files []string) (*qLogReader, error) {
func newQLogReader(ctx context.Context, logger *slog.Logger, files []string) (*qLogReader, error) {
qFiles := make([]*qLogFile, 0)
for _, f := range files {
@ -38,7 +44,7 @@ func newQLogReader(files []string) (*qLogReader, error) {
// Close what we've already opened.
cErr := closeQFiles(qFiles)
if cErr != nil {
log.Debug("querylog: closing files: %s", cErr)
logger.DebugContext(ctx, "closing files", slogutil.KeyError, cErr)
}
return nil, err
@ -47,16 +53,20 @@ func newQLogReader(files []string) (*qLogReader, error) {
qFiles = append(qFiles, q)
}
return &qLogReader{qFiles: qFiles, currentFile: len(qFiles) - 1}, nil
return &qLogReader{
logger: logger,
qFiles: qFiles,
currentFile: len(qFiles) - 1,
}, nil
}
// seekTS performs binary search of a query log record with the specified
// timestamp. If the record is found, it sets qLogReader's position to point
// to that line, so that the next ReadNext call returned this line.
func (r *qLogReader) seekTS(timestamp int64) (err error) {
func (r *qLogReader) seekTS(ctx context.Context, timestamp int64) (err error) {
for i := len(r.qFiles) - 1; i >= 0; i-- {
q := r.qFiles[i]
_, _, err = q.seekTS(timestamp)
_, _, err = q.seekTS(ctx, r.logger, timestamp)
if err != nil {
if errors.Is(err, errTSTooEarly) {
// Look at the next file, since we've reached the end of this

View File

@ -5,6 +5,7 @@ import (
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -17,8 +18,11 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *qLogReader
testFiles := prepareTestFiles(t, filesNum, linesNum)
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Create the new qLogReader instance.
reader, err := newQLogReader(testFiles)
reader, err := newQLogReader(ctx, logger, testFiles)
require.NoError(t, err)
assert.NotNil(t, reader)
@ -73,6 +77,7 @@ func TestQLogReader(t *testing.T) {
func TestQLogReader_Seek(t *testing.T) {
r := newTestQLogReader(t, 2, 10000)
ctx := testutil.ContextWithTimeout(t, testTimeout)
testCases := []struct {
want error
@ -113,7 +118,7 @@ func TestQLogReader_Seek(t *testing.T) {
ts, err := time.Parse(time.RFC3339Nano, tc.time)
require.NoError(t, err)
err = r.seekTS(ts.UnixNano())
err = r.seekTS(ctx, ts.UnixNano())
assert.ErrorIs(t, err, tc.want)
})
}

View File

@ -2,6 +2,7 @@ package querylog
import (
"fmt"
"log/slog"
"net"
"path/filepath"
"sync"
@ -12,20 +13,19 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/service"
"github.com/miekg/dns"
)
// QueryLog - main interface
// QueryLog is the query log interface for use by other packages.
type QueryLog interface {
Start()
// Interface starts and stops the query log.
service.Interface
// Close query log object
Close()
// Add a log entry
// Add adds a log entry.
Add(params *AddParams)
// WriteDiskConfig - write configuration
// WriteDiskConfig writes the query log configuration to c.
WriteDiskConfig(c *Config)
// ShouldLog returns true if request for the host should be logged.
@ -36,6 +36,10 @@ type QueryLog interface {
//
// Do not alter any fields of this structure after using it.
type Config struct {
// Logger is used for logging the operation of the query log. It must not
// be nil.
Logger *slog.Logger
// Ignored contains the list of host names, which should not be written to
// log, and matches them.
Ignored *aghnet.IgnoreEngine
@ -151,6 +155,7 @@ func newQueryLog(conf Config) (l *queryLog, err error) {
}
l = &queryLog{
logger: conf.Logger,
findClient: findClient,
buffer: container.NewRingBuffer[*logEntry](memSize),

View File

@ -2,6 +2,7 @@ package querylog
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
@ -9,28 +10,30 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/c2h5oh/datasize"
)
// flushLogBuffer flushes the current buffer to file and resets the current
// buffer.
func (l *queryLog) flushLogBuffer() (err error) {
func (l *queryLog) flushLogBuffer(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }()
l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock()
b, err := l.encodeEntries()
b, err := l.encodeEntries(ctx)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}
return l.flushToFile(b)
return l.flushToFile(ctx, b)
}
// encodeEntries returns JSON encoded log entries, logs estimated time, clears
// the log buffer.
func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
func (l *queryLog) encodeEntries(ctx context.Context) (b *bytes.Buffer, err error) {
l.bufferLock.Lock()
defer l.bufferLock.Unlock()
@ -55,8 +58,17 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
return nil, err
}
size := b.Len()
elapsed := time.Since(start)
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", bufLen, elapsed, b.Len()/1024, float64(b.Len())/float64(bufLen), elapsed/time.Duration(bufLen))
l.logger.DebugContext(
ctx,
"serialized elements via json",
"count", bufLen,
"elapsed", elapsed,
"size", datasize.ByteSize(size),
"size_per_entry", datasize.ByteSize(float64(size)/float64(bufLen)),
"time_per_entry", elapsed/time.Duration(bufLen),
)
l.buffer.Clear()
l.flushPending = false
@ -65,7 +77,7 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
}
// flushToFile saves the encoded log entries to the query log file.
func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
func (l *queryLog) flushToFile(ctx context.Context, b *bytes.Buffer) (err error) {
l.fileWriteLock.Lock()
defer l.fileWriteLock.Unlock()
@ -83,19 +95,19 @@ func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
return fmt.Errorf("writing to file %q: %w", filename, err)
}
log.Debug("querylog: ok %q: %v bytes written", filename, n)
l.logger.DebugContext(ctx, "flushed to file", "file", filename, "size", datasize.ByteSize(n))
return nil
}
func (l *queryLog) rotate() error {
func (l *queryLog) rotate(ctx context.Context) error {
from := l.logFile
to := l.logFile + ".1"
err := os.Rename(from, to)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
log.Debug("querylog: no log to rotate")
l.logger.DebugContext(ctx, "no log to rotate")
return nil
}
@ -103,12 +115,12 @@ func (l *queryLog) rotate() error {
return fmt.Errorf("failed to rename old file: %w", err)
}
log.Debug("querylog: renamed %s into %s", from, to)
l.logger.DebugContext(ctx, "renamed log file", "from", from, "to", to)
return nil
}
func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) {
func (l *queryLog) readFileFirstTimeValue(ctx context.Context) (first time.Time, err error) {
var f *os.File
f, err = os.Open(l.logFile)
if err != nil {
@ -130,15 +142,15 @@ func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) {
return time.Time{}, err
}
log.Debug("querylog: the oldest log entry: %s", val)
l.logger.DebugContext(ctx, "oldest log entry", "entry_time", val)
return t, nil
}
func (l *queryLog) periodicRotate() {
defer log.OnPanic("querylog: rotating")
func (l *queryLog) periodicRotate(ctx context.Context) {
defer slogutil.RecoverAndLog(ctx, l.logger)
l.checkAndRotate()
l.checkAndRotate(ctx)
// rotationCheckIvl is the period of time between checking the need for
// rotating log files. It's smaller of any available rotation interval to
@ -151,13 +163,13 @@ func (l *queryLog) periodicRotate() {
defer rotations.Stop()
for range rotations.C {
l.checkAndRotate()
l.checkAndRotate(ctx)
}
}
// checkAndRotate rotates log files if those are older than the specified
// rotation interval.
func (l *queryLog) checkAndRotate() {
func (l *queryLog) checkAndRotate(ctx context.Context) {
var rotationIvl time.Duration
func() {
l.confMu.RLock()
@ -166,29 +178,30 @@ func (l *queryLog) checkAndRotate() {
rotationIvl = l.conf.RotationIvl
}()
oldest, err := l.readFileFirstTimeValue()
oldest, err := l.readFileFirstTimeValue(ctx)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("querylog: reading oldest record for rotation: %s", err)
l.logger.ErrorContext(ctx, "reading oldest record for rotation", slogutil.KeyError, err)
return
}
if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) {
log.Debug(
"querylog: %s <= %s, not rotating",
now.Format(time.RFC3339),
rotTime.Format(time.RFC3339),
l.logger.DebugContext(
ctx,
"not rotating",
"now", now.Format(time.RFC3339),
"rotate_time", rotTime.Format(time.RFC3339),
)
return
}
err = l.rotate()
err = l.rotate(ctx)
if err != nil {
log.Error("querylog: rotating: %s", err)
l.logger.ErrorContext(ctx, "rotating", slogutil.KeyError, err)
return
}
log.Debug("querylog: rotated successfully")
l.logger.DebugContext(ctx, "rotated successfully")
}

View File

@ -1,13 +1,15 @@
package querylog
import (
"context"
"fmt"
"io"
"log/slog"
"slices"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// client finds the client info, if any, by its ClientID and IP address,
@ -48,7 +50,11 @@ func (l *queryLog) client(clientID, ip string, cache clientCache) (c *Client, er
// buffer. It optionally uses the client cache, if provided. It also returns
// the total amount of records in the buffer at the moment of searching.
// l.confMu is expected to be locked.
func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entries []*logEntry, total int) {
func (l *queryLog) searchMemory(
ctx context.Context,
params *searchParams,
cache clientCache,
) (entries []*logEntry, total int) {
// Check memory size, as the buffer can contain a single log record. See
// [newQueryLog].
if l.conf.MemSize == 0 {
@ -66,9 +72,14 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
var err error
e.client, err = l.client(e.ClientID, e.IP.String(), cache)
if err != nil {
msg := "querylog: enriching memory record at time %s" +
" for client %q (clientid %q): %s"
log.Error(msg, e.Time, e.IP, e.ClientID, err)
l.logger.ErrorContext(
ctx,
"enriching memory record",
"at", e.Time,
"client_ip", e.IP,
"client_id", e.ClientID,
slogutil.KeyError, err,
)
// Go on and try to match anyway.
}
@ -86,7 +97,10 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
// search searches log entries in memory buffer and log file using specified
// parameters and returns the list of entries found and the time of the oldest
// entry. l.confMu is expected to be locked.
func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest time.Time) {
func (l *queryLog) search(
ctx context.Context,
params *searchParams,
) (entries []*logEntry, oldest time.Time) {
start := time.Now()
if params.limit == 0 {
@ -95,11 +109,11 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
cache := clientCache{}
memoryEntries, bufLen := l.searchMemory(params, cache)
log.Debug("querylog: got %d entries from memory", len(memoryEntries))
memoryEntries, bufLen := l.searchMemory(ctx, params, cache)
l.logger.DebugContext(ctx, "got entries from memory", "count", len(memoryEntries))
fileEntries, oldest, total := l.searchFiles(params, cache)
log.Debug("querylog: got %d entries from files", len(fileEntries))
fileEntries, oldest, total := l.searchFiles(ctx, params, cache)
l.logger.DebugContext(ctx, "got entries from files", "count", len(fileEntries))
total += bufLen
@ -134,12 +148,13 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
oldest = entries[len(entries)-1].Time
}
log.Debug(
"querylog: prepared data (%d/%d) older than %s in %s",
len(entries),
total,
params.olderThan,
time.Since(start),
l.logger.DebugContext(
ctx,
"prepared data",
"count", len(entries),
"total", total,
"older_than", params.olderThan,
"elapsed", time.Since(start),
)
return entries, oldest
@ -147,12 +162,12 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
// seekRecord changes the current position to the next record older than the
// provided parameter.
func (r *qLogReader) seekRecord(olderThan time.Time) (err error) {
func (r *qLogReader) seekRecord(ctx context.Context, olderThan time.Time) (err error) {
if olderThan.IsZero() {
return r.SeekStart()
}
err = r.seekTS(olderThan.UnixNano())
err = r.seekTS(ctx, olderThan.UnixNano())
if err == nil {
// Read to the next record, because we only need the one that goes
// after it.
@ -164,21 +179,24 @@ func (r *qLogReader) seekRecord(olderThan time.Time) (err error) {
// setQLogReader creates a reader with the specified files and sets the
// position to the next record older than the provided parameter.
func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error) {
func (l *queryLog) setQLogReader(
ctx context.Context,
olderThan time.Time,
) (qr *qLogReader, err error) {
files := []string{
l.logFile + ".1",
l.logFile,
}
r, err := newQLogReader(files)
r, err := newQLogReader(ctx, l.logger, files)
if err != nil {
return nil, fmt.Errorf("opening qlog reader: %s", err)
return nil, fmt.Errorf("opening qlog reader: %w", err)
}
err = r.seekRecord(olderThan)
err = r.seekRecord(ctx, olderThan)
if err != nil {
defer func() { err = errors.WithDeferred(err, r.Close()) }()
log.Debug("querylog: cannot seek to %s: %s", olderThan, err)
l.logger.DebugContext(ctx, "cannot seek", "older_than", olderThan, slogutil.KeyError, err)
return nil, nil
}
@ -191,13 +209,14 @@ func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error
// calls faster so that the UI could handle it and show something quicker.
// This behavior can be overridden if maxFileScanEntries is set to 0.
func (l *queryLog) readEntries(
ctx context.Context,
r *qLogReader,
params *searchParams,
cache clientCache,
totalLimit int,
) (entries []*logEntry, oldestNano int64, total int) {
for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 {
ent, ts, rErr := l.readNextEntry(r, params, cache)
ent, ts, rErr := l.readNextEntry(ctx, r, params, cache)
if rErr != nil {
if rErr == io.EOF {
oldestNano = 0
@ -205,7 +224,7 @@ func (l *queryLog) readEntries(
break
}
log.Error("querylog: reading next entry: %s", rErr)
l.logger.ErrorContext(ctx, "reading next entry", slogutil.KeyError, rErr)
}
oldestNano = ts
@ -231,12 +250,13 @@ func (l *queryLog) readEntries(
// and the total number of processed entries, including discarded ones,
// correspondingly.
func (l *queryLog) searchFiles(
ctx context.Context,
params *searchParams,
cache clientCache,
) (entries []*logEntry, oldest time.Time, total int) {
r, err := l.setQLogReader(params.olderThan)
r, err := l.setQLogReader(ctx, params.olderThan)
if err != nil {
log.Error("querylog: %s", err)
l.logger.ErrorContext(ctx, "searching files", slogutil.KeyError, err)
}
if r == nil {
@ -245,12 +265,12 @@ func (l *queryLog) searchFiles(
defer func() {
if closeErr := r.Close(); closeErr != nil {
log.Error("querylog: closing file: %s", closeErr)
l.logger.ErrorContext(ctx, "closing files", slogutil.KeyError, closeErr)
}
}()
totalLimit := params.offset + params.limit
entries, oldestNano, total := l.readEntries(r, params, cache, totalLimit)
entries, oldestNano, total := l.readEntries(ctx, r, params, cache, totalLimit)
if oldestNano != 0 {
oldest = time.Unix(0, oldestNano)
}
@ -266,15 +286,21 @@ type quickMatchClientFinder struct {
}
// findClient is a method that can be used as a quickMatchClientFinder.
func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) {
func (f quickMatchClientFinder) findClient(
ctx context.Context,
logger *slog.Logger,
clientID string,
ip string,
) (c *Client) {
var err error
c, err = f.client(clientID, ip, f.cache)
if err != nil {
log.Error(
"querylog: enriching file record for quick search: for client %q (clientid %q): %s",
ip,
clientID,
err,
logger.ErrorContext(
ctx,
"enriching file record for quick search",
"client_ip", ip,
"client_id", clientID,
slogutil.KeyError, err,
)
}
@ -286,6 +312,7 @@ func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) {
// the entry doesn't match the search criteria. ts is the timestamp of the
// processed entry.
func (l *queryLog) readNextEntry(
ctx context.Context,
r *qLogReader,
params *searchParams,
cache clientCache,
@ -301,14 +328,14 @@ func (l *queryLog) readNextEntry(
cache: cache,
}
if !params.quickMatch(line, clientFinder.findClient) {
ts = readQLogTimestamp(line)
if !params.quickMatch(ctx, l.logger, line, clientFinder.findClient) {
ts = readQLogTimestamp(ctx, l.logger, line)
return nil, ts, nil
}
e = &logEntry{}
decodeLogEntry(e, line)
l.decodeLogEntry(ctx, e, line)
if l.isIgnored(e.QHost) {
return nil, ts, nil
@ -316,12 +343,13 @@ func (l *queryLog) readNextEntry(
e.client, err = l.client(e.ClientID, e.IP.String(), cache)
if err != nil {
log.Error(
"querylog: enriching file record at time %s for client %q (clientid %q): %s",
e.Time,
e.IP,
e.ClientID,
err,
l.logger.ErrorContext(
ctx,
"enriching file record",
"at", e.Time,
"client_ip", e.IP,
"client_id", e.ClientID,
slogutil.KeyError, err,
)
// Go on and try to match anyway.

View File

@ -5,6 +5,8 @@ import (
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@ -36,6 +38,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
}
l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
FindClient: findClient,
BaseDir: t.TempDir(),
RotationIvl: timeutil.Day,
@ -45,7 +48,11 @@ func TestQueryLog_Search_findClient(t *testing.T) {
AnonymizeClientIP: false,
})
require.NoError(t, err)
t.Cleanup(l.Close)
ctx := testutil.ContextWithTimeout(t, testTimeout)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return l.Shutdown(ctx)
})
q := &dns.Msg{
Question: []dns.Question{{
@ -81,7 +88,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
olderThan: time.Now().Add(10 * time.Second),
limit: 3,
}
entries, _ := l.search(sp)
entries, _ := l.search(ctx, sp)
assert.Equal(t, 2, findClientCalls)
require.Len(t, entries, 3)

View File

@ -1,7 +1,9 @@
package querylog
import (
"context"
"fmt"
"log/slog"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -87,7 +89,12 @@ func ctDomainOrClientCaseNonStrict(
// quickMatch quickly checks if the line matches the given search criterion.
// It returns false if the like doesn't match. This method is only here for
// optimization purposes.
func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) {
func (c *searchCriterion) quickMatch(
ctx context.Context,
logger *slog.Logger,
line string,
findClient quickMatchClientFunc,
) (ok bool) {
switch c.criterionType {
case ctTerm:
host := readJSONValue(line, `"QH":"`)
@ -95,7 +102,7 @@ func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFun
clientID := readJSONValue(line, `"CID":"`)
var name string
if cli := findClient(clientID, ip); cli != nil {
if cli := findClient(ctx, logger, clientID, ip); cli != nil {
name = cli.Name
}

View File

@ -1,6 +1,10 @@
package querylog
import "time"
import (
"context"
"log/slog"
"time"
)
// searchParams represent the search query sent by the client.
type searchParams struct {
@ -35,14 +39,23 @@ func newSearchParams() *searchParams {
}
// quickMatchClientFunc is a simplified client finder for quick matches.
type quickMatchClientFunc = func(clientID, ip string) (c *Client)
type quickMatchClientFunc = func(
ctx context.Context,
logger *slog.Logger,
clientID, ip string,
) (c *Client)
// quickMatch quickly checks if the line matches the given search parameters.
// It returns false if the line doesn't match. This method is only here for
// optimization purposes.
func (s *searchParams) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) {
func (s *searchParams) quickMatch(
ctx context.Context,
logger *slog.Logger,
line string,
findClient quickMatchClientFunc,
) (ok bool) {
for _, c := range s.searchCriteria {
if !c.quickMatch(line, findClient) {
if !c.quickMatch(ctx, logger, line, findClient) {
return false
}
}