Merge branch 'master' into AGDNS-2556-custom-update-url

This commit is contained in:
Eugene Burkov 2024-11-22 17:51:46 +03:00
commit c9efb412e7
25 changed files with 586 additions and 253 deletions

View File

@ -32,6 +32,14 @@ NOTE: Add new changes BELOW THIS COMMENT.
- The release executables are now signed. - The release executables are now signed.
### Added
- The `--no-permcheck` command-line option to disable checking and migration of
permissions for the security-sensitive files and directories, which caused
issues on Windows ([#7400]).
[#7400]: https://github.com/AdguardTeam/AdGuardHome/issues/7400
[go-1.23.3]: https://groups.google.com/g/golang-announce/c/X5KodEJYuqI [go-1.23.3]: https://groups.google.com/g/golang-announce/c/X5KodEJYuqI
<!-- <!--

View File

@ -15,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
) )
@ -506,7 +505,7 @@ func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
// RemoveByName removes persistent client information. ok is false if no such // RemoveByName removes persistent client information. ok is false if no such
// client exists by that name. // client exists by that name.
func (s *Storage) RemoveByName(name string) (ok bool) { func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -516,7 +515,7 @@ func (s *Storage) RemoveByName(name string) (ok bool) {
} }
if err := p.CloseUpstreams(); err != nil { if err := p.CloseUpstreams(); err != nil {
log.Error("client storage: removing client %q: %s", p.Name, err) s.logger.ErrorContext(ctx, "removing client", "name", p.Name, slogutil.KeyError, err)
} }
s.index.remove(p) s.index.remove(p)

View File

@ -735,7 +735,7 @@ func TestStorage_RemoveByName(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
tc.want(t, s.RemoveByName(tc.cliName)) tc.want(t, s.RemoveByName(ctx, tc.cliName))
}) })
} }
@ -744,8 +744,8 @@ func TestStorage_RemoveByName(t *testing.T) {
err = s.Add(ctx, existingClient) err = s.Add(ctx, existingClient)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, s.RemoveByName(existingName)) assert.True(t, s.RemoveByName(ctx, existingName))
assert.False(t, s.RemoveByName(existingName)) assert.False(t, s.RemoveByName(ctx, existingName))
}) })
} }

View File

@ -369,7 +369,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
return return
} }
if !clients.storage.RemoveByName(cj.Name) { if !clients.storage.RemoveByName(r.Context(), cj.Name) {
aghhttp.Error(r, w, http.StatusBadRequest, "Client not found") aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
return return

View File

@ -48,11 +48,11 @@ func onConfigModified() {
// initDNS updates all the fields of the [Context] needed to initialize the DNS // initDNS updates all the fields of the [Context] needed to initialize the DNS
// server and initializes it at last. It also must not be called unless // server and initializes it at last. It also must not be called unless
// [config] and [Context] are initialized. l must not be nil. // [config] and [Context] are initialized. l must not be nil.
func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) { func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
anonymizer := config.anonymizer() anonymizer := config.anonymizer()
statsConf := stats.Config{ statsConf := stats.Config{
Logger: l.With(slogutil.KeyPrefix, "stats"), Logger: baseLogger.With(slogutil.KeyPrefix, "stats"),
Filename: filepath.Join(statsDir, "stats.db"), Filename: filepath.Join(statsDir, "stats.db"),
Limit: config.Stats.Interval.Duration, Limit: config.Stats.Interval.Duration,
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
@ -73,6 +73,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
} }
conf := querylog.Config{ conf := querylog.Config{
Logger: baseLogger.With(slogutil.KeyPrefix, "querylog"),
Anonymizer: anonymizer, Anonymizer: anonymizer,
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
HTTPRegister: httpRegister, HTTPRegister: httpRegister,
@ -113,7 +114,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
anonymizer, anonymizer,
httpRegister, httpRegister,
tlsConf, tlsConf,
l, baseLogger,
) )
} }
@ -457,7 +458,8 @@ func startDNSServer() error {
Context.filters.EnableFilters(false) Context.filters.EnableFilters(false)
// TODO(s.chzhen): Pass context. // TODO(s.chzhen): Pass context.
err := Context.clients.Start(context.TODO()) ctx := context.TODO()
err := Context.clients.Start(ctx)
if err != nil { if err != nil {
return fmt.Errorf("starting clients container: %w", err) return fmt.Errorf("starting clients container: %w", err)
} }
@ -469,7 +471,11 @@ func startDNSServer() error {
Context.filters.Start() Context.filters.Start()
Context.stats.Start() Context.stats.Start()
Context.queryLog.Start()
err = Context.queryLog.Start(ctx)
if err != nil {
return fmt.Errorf("starting query log: %w", err)
}
return nil return nil
} }
@ -525,12 +531,16 @@ func closeDNSServer() {
if Context.stats != nil { if Context.stats != nil {
err := Context.stats.Close() err := Context.stats.Close()
if err != nil { if err != nil {
log.Debug("closing stats: %s", err) log.Error("closing stats: %s", err)
} }
} }
if Context.queryLog != nil { if Context.queryLog != nil {
Context.queryLog.Close() // TODO(s.chzhen): Pass context.
err := Context.queryLog.Shutdown(context.TODO())
if err != nil {
log.Error("closing query log: %s", err)
}
} }
log.Debug("all dns modules are closed") log.Debug("all dns modules are closed")

View File

@ -159,7 +159,7 @@ func setupContext(opts options) (err error) {
if Context.firstRun { if Context.firstRun {
log.Info("This is the first time AdGuard Home is launched") log.Info("This is the first time AdGuard Home is launched")
checkPermissions() checkNetworkPermissions()
return nil return nil
} }
@ -666,12 +666,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
} }
} }
if permcheck.NeedsMigration(confPath) { if !opts.noPermCheck {
permcheck.Migrate(Context.workDir, dataDir, statsDir, querylogDir, confPath) checkPermissions(Context.workDir, confPath, dataDir, statsDir, querylogDir)
} }
permcheck.Check(Context.workDir, dataDir, statsDir, querylogDir, confPath)
Context.web.start() Context.web.start()
// Wait for other goroutines to complete their job. // Wait for other goroutines to complete their job.
@ -740,6 +738,16 @@ func newUpdater(
}) })
} }
// checkPermissions checks and migrates permissions of the files and directories
// used by AdGuard Home, if needed.
func checkPermissions(workDir, confPath, dataDir, statsDir, querylogDir string) {
if permcheck.NeedsMigration(confPath) {
permcheck.Migrate(workDir, dataDir, statsDir, querylogDir, confPath)
}
permcheck.Check(workDir, dataDir, statsDir, querylogDir, confPath)
}
// initUsers initializes context auth module. Clears config users field. // initUsers initializes context auth module. Clears config users field.
func initUsers() (auth *Auth, err error) { func initUsers() (auth *Auth, err error) {
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db") sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
@ -799,8 +807,9 @@ func startMods(l *slog.Logger) (err error) {
return nil return nil
} }
// Check if the current user permissions are enough to run AdGuard Home // checkNetworkPermissions checks if the current user permissions are enough to
func checkPermissions() { // use the required networking functionality.
func checkNetworkPermissions() {
log.Info("Checking if AdGuard Home has necessary permissions") log.Info("Checking if AdGuard Home has necessary permissions")
if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil { if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {

View File

@ -78,6 +78,10 @@ type options struct {
// localFrontend forces AdGuard Home to use the frontend files from disk // localFrontend forces AdGuard Home to use the frontend files from disk
// rather than the ones that have been compiled into the binary. // rather than the ones that have been compiled into the binary.
localFrontend bool localFrontend bool
// noPermCheck disables checking and migration of permissions for the
// security-sensitive files.
noPermCheck bool
} }
// initCmdLineOpts completes initialization of the global command-line option // initCmdLineOpts completes initialization of the global command-line option
@ -305,6 +309,15 @@ var cmdLineOpts = []cmdLineOpt{{
description: "Run in GL-Inet compatibility mode.", description: "Run in GL-Inet compatibility mode.",
longName: "glinet", longName: "glinet",
shortName: "", shortName: "",
}, {
updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.noPermCheck = true; return o, nil },
effect: nil,
serialize: func(o options) (val string, ok bool) { return "", o.noPermCheck },
description: "Skip checking and migration of permissions " +
"of security-sensitive files.",
longName: "no-permcheck",
shortName: "",
}, { }, {
updateWithValue: nil, updateWithValue: nil,
updateNoValue: nil, updateNoValue: nil,

View File

@ -89,6 +89,12 @@ type options struct {
// TODO(a.garipov): Use. // TODO(a.garipov): Use.
performUpdate bool performUpdate bool
// noPermCheck, if true, instructs AdGuard Home to skip checking and
// migrating the permissions of its security-sensitive files.
//
// TODO(e.burkov): Use.
noPermCheck bool
// verbose, if true, instructs AdGuard Home to enable verbose logging. // verbose, if true, instructs AdGuard Home to enable verbose logging.
verbose bool verbose bool
@ -110,7 +116,8 @@ const (
disableUpdateIdx disableUpdateIdx
glinetModeIdx glinetModeIdx
helpIdx helpIdx
localFrontend localFrontendIdx
noPermCheckIdx
performUpdateIdx performUpdateIdx
verboseIdx verboseIdx
versionIdx versionIdx
@ -214,7 +221,7 @@ var commandLineOptions = []*commandLineOption{
valueType: "", valueType: "",
}, },
localFrontend: { localFrontendIdx: {
defaultValue: false, defaultValue: false,
description: "Use local frontend directories.", description: "Use local frontend directories.",
long: "local-frontend", long: "local-frontend",
@ -222,6 +229,14 @@ var commandLineOptions = []*commandLineOption{
valueType: "", valueType: "",
}, },
noPermCheckIdx: {
defaultValue: false,
description: "Skip checking the permissions of security-sensitive files.",
long: "no-permcheck",
short: "",
valueType: "",
},
performUpdateIdx: { performUpdateIdx: {
defaultValue: false, defaultValue: false,
description: "Update the current binary and restart the service in case it's installed.", description: "Update the current binary and restart the service in case it's installed.",
@ -264,7 +279,8 @@ func parseOptions(cmdName string, args []string) (opts *options, err error) {
disableUpdateIdx: &opts.disableUpdate, disableUpdateIdx: &opts.disableUpdate,
glinetModeIdx: &opts.glinetMode, glinetModeIdx: &opts.glinetMode,
helpIdx: &opts.help, helpIdx: &opts.help,
localFrontend: &opts.localFrontend, localFrontendIdx: &opts.localFrontend,
noPermCheckIdx: &opts.noPermCheck,
performUpdateIdx: &opts.performUpdate, performUpdateIdx: &opts.performUpdate,
verboseIdx: &opts.verbose, verboseIdx: &opts.verbose,
versionIdx: &opts.version, versionIdx: &opts.version,

View File

@ -1,6 +1,7 @@
package querylog package querylog
import ( import (
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -13,7 +14,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -174,26 +175,32 @@ var logEntryHandlers = map[string]logEntryHandler{
} }
// decodeResultRuleKey decodes the token of "Rules" type to logEntry struct. // decodeResultRuleKey decodes the token of "Rules" type to logEntry struct.
func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultRuleKey(
ctx context.Context,
key string,
i int,
dec *json.Decoder,
ent *logEntry,
) {
var vToken json.Token var vToken json.Token
switch key { switch key {
case "FilterListID": case "FilterListID":
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
if n, ok := vToken.(json.Number); ok { if n, ok := vToken.(json.Number); ok {
id, _ := n.Int64() id, _ := n.Int64()
ent.Result.Rules[i].FilterListID = rulelist.URLFilterID(id) ent.Result.Rules[i].FilterListID = rulelist.URLFilterID(id)
} }
case "IP": case "IP":
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
if ipStr, ok := vToken.(string); ok { if ipStr, ok := vToken.(string); ok {
if ip, err := netip.ParseAddr(ipStr); err == nil { if ip, err := netip.ParseAddr(ipStr); err == nil {
ent.Result.Rules[i].IP = ip ent.Result.Rules[i].IP = ip
} else { } else {
log.Debug("querylog: decoding ipStr value: %s", err) l.logger.DebugContext(ctx, "decoding ip", "value", ipStr, slogutil.KeyError, err)
} }
} }
case "Text": case "Text":
ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules) ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
if s, ok := vToken.(string); ok { if s, ok := vToken.(string); ok {
ent.Result.Rules[i].Text = s ent.Result.Rules[i].Text = s
} }
@ -204,7 +211,8 @@ func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
// decodeVTokenAndAddRule decodes the "Rules" toke as [filtering.ResultRule] // decodeVTokenAndAddRule decodes the "Rules" toke as [filtering.ResultRule]
// and then adds the decoded object to the slice of result rules. // and then adds the decoded object to the slice of result rules.
func decodeVTokenAndAddRule( func (l *queryLog) decodeVTokenAndAddRule(
ctx context.Context,
key string, key string,
i int, i int,
dec *json.Decoder, dec *json.Decoder,
@ -215,7 +223,12 @@ func decodeVTokenAndAddRule(
vToken, err := dec.Token() vToken, err := dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeResultRuleKey %s err: %s", key, err) l.logger.DebugContext(
ctx,
"decoding result rule key",
"key", key,
slogutil.KeyError, err,
)
} }
return newRules, nil return newRules, nil
@ -230,12 +243,14 @@ func decodeVTokenAndAddRule(
// decodeResultRules parses the dec's tokens into logEntry ent interpreting it // decodeResultRules parses the dec's tokens into logEntry ent interpreting it
// as a slice of the result rules. // as a slice of the result rules.
func decodeResultRules(dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultRules(ctx context.Context, dec *json.Decoder, ent *logEntry) {
const msgPrefix = "decoding result rules"
for { for {
delimToken, err := dec.Token() delimToken, err := dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeResultRules err: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
} }
return return
@ -244,13 +259,17 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) {
if d, ok := delimToken.(json.Delim); !ok { if d, ok := delimToken.(json.Delim); !ok {
return return
} else if d != '[' { } else if d != '[' {
log.Debug("decodeResultRules: unexpected delim %q", d) l.logger.DebugContext(
ctx,
msgPrefix,
slogutil.KeyError, newUnexpectedDelimiterError(d),
)
} }
err = decodeResultRuleToken(dec, ent) err = l.decodeResultRuleToken(ctx, dec, ent)
if err != nil { if err != nil {
if err != io.EOF && !errors.Is(err, ErrEndOfToken) { if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
log.Debug("decodeResultRules err: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; rule token", slogutil.KeyError, err)
} }
return return
@ -259,7 +278,11 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) {
} }
// decodeResultRuleToken decodes the tokens of "Rules" type to the logEntry ent. // decodeResultRuleToken decodes the tokens of "Rules" type to the logEntry ent.
func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) { func (l *queryLog) decodeResultRuleToken(
ctx context.Context,
dec *json.Decoder,
ent *logEntry,
) (err error) {
i := 0 i := 0
for { for {
var keyToken json.Token var keyToken json.Token
@ -287,7 +310,7 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
return fmt.Errorf("keyToken is %T (%[1]v) and not string", keyToken) return fmt.Errorf("keyToken is %T (%[1]v) and not string", keyToken)
} }
decodeResultRuleKey(key, i, dec, ent) l.decodeResultRuleKey(ctx, key, i, dec, ent)
} }
} }
@ -296,12 +319,14 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
// other occurrences of DNSRewriteResult in the entry since hosts container's // other occurrences of DNSRewriteResult in the entry since hosts container's
// rewrites currently has the highest priority along the entire filtering // rewrites currently has the highest priority along the entire filtering
// pipeline. // pipeline.
func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultReverseHosts(ctx context.Context, dec *json.Decoder, ent *logEntry) {
const msgPrefix = "decoding result reverse hosts"
for { for {
itemToken, err := dec.Token() itemToken, err := dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeResultReverseHosts err: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
} }
return return
@ -315,7 +340,11 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
return return
} }
log.Debug("decodeResultReverseHosts: unexpected delim %q", v) l.logger.DebugContext(
ctx,
msgPrefix,
slogutil.KeyError, newUnexpectedDelimiterError(v),
)
return return
case string: case string:
@ -346,12 +375,14 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
// decodeResultIPList parses the dec's tokens into logEntry ent interpreting it // decodeResultIPList parses the dec's tokens into logEntry ent interpreting it
// as the result IP addresses list. // as the result IP addresses list.
func decodeResultIPList(dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultIPList(ctx context.Context, dec *json.Decoder, ent *logEntry) {
const msgPrefix = "decoding result ip list"
for { for {
itemToken, err := dec.Token() itemToken, err := dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeResultIPList err: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
} }
return return
@ -365,7 +396,11 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
return return
} }
log.Debug("decodeResultIPList: unexpected delim %q", v) l.logger.DebugContext(
ctx,
msgPrefix,
slogutil.KeyError, newUnexpectedDelimiterError(v),
)
return return
case string: case string:
@ -382,7 +417,14 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
// decodeResultDNSRewriteResultKey decodes the token of "DNSRewriteResult" type // decodeResultDNSRewriteResultKey decodes the token of "DNSRewriteResult" type
// to the logEntry struct. // to the logEntry struct.
func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultDNSRewriteResultKey(
ctx context.Context,
key string,
dec *json.Decoder,
ent *logEntry,
) {
const msgPrefix = "decoding result dns rewrite result key"
var err error var err error
switch key { switch key {
@ -391,7 +433,7 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
vToken, err = dec.Token() vToken, err = dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeResultDNSRewriteResultKey err: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
} }
return return
@ -419,7 +461,7 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
// decoding and correct the values. // decoding and correct the values.
err = dec.Decode(&ent.Result.DNSRewriteResult.Response) err = dec.Decode(&ent.Result.DNSRewriteResult.Response)
if err != nil { if err != nil {
log.Debug("decodeResultDNSRewriteResultKey response err: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; response", slogutil.KeyError, err)
} }
ent.parseDNSRewriteResultIPs() ent.parseDNSRewriteResultIPs()
@ -430,12 +472,18 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
// decodeResultDNSRewriteResult parses the dec's tokens into logEntry ent // decodeResultDNSRewriteResult parses the dec's tokens into logEntry ent
// interpreting it as the result DNSRewriteResult. // interpreting it as the result DNSRewriteResult.
func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResultDNSRewriteResult(
ctx context.Context,
dec *json.Decoder,
ent *logEntry,
) {
const msgPrefix = "decoding result dns rewrite result"
for { for {
key, err := parseKeyToken(dec) key, err := parseKeyToken(dec)
if err != nil { if err != nil {
if err != io.EOF && !errors.Is(err, ErrEndOfToken) { if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
log.Debug("decodeResultDNSRewriteResult: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
} }
return return
@ -445,7 +493,7 @@ func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
continue continue
} }
decodeResultDNSRewriteResultKey(key, dec, ent) l.decodeResultDNSRewriteResultKey(ctx, key, dec, ent)
} }
} }
@ -508,14 +556,16 @@ func parseKeyToken(dec *json.Decoder) (key string, err error) {
} }
// decodeResult decodes a token of "Result" type to logEntry struct. // decodeResult decodes a token of "Result" type to logEntry struct.
func decodeResult(dec *json.Decoder, ent *logEntry) { func (l *queryLog) decodeResult(ctx context.Context, dec *json.Decoder, ent *logEntry) {
const msgPrefix = "decoding result"
defer translateResult(ent) defer translateResult(ent)
for { for {
key, err := parseKeyToken(dec) key, err := parseKeyToken(dec)
if err != nil { if err != nil {
if err != io.EOF && !errors.Is(err, ErrEndOfToken) { if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
log.Debug("decodeResult: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
} }
return return
@ -525,10 +575,8 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
continue continue
} }
decHandler, ok := resultDecHandlers[key] ok := l.resultDecHandler(ctx, key, dec, ent)
if ok { if ok {
decHandler(dec, ent)
continue continue
} }
@ -543,7 +591,7 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
} }
if err = handler(val, ent); err != nil { if err = handler(val, ent); err != nil {
log.Debug("decodeResult handler err: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; handler", slogutil.KeyError, err)
return return
} }
@ -636,16 +684,34 @@ var resultHandlers = map[string]logEntryHandler{
}, },
} }
// resultDecHandlers is the map of decode handlers for various keys. // resultDecHandlers calls a decode handler for key if there is one.
var resultDecHandlers = map[string]func(dec *json.Decoder, ent *logEntry){ func (l *queryLog) resultDecHandler(
"ReverseHosts": decodeResultReverseHosts, ctx context.Context,
"IPList": decodeResultIPList, name string,
"Rules": decodeResultRules, dec *json.Decoder,
"DNSRewriteResult": decodeResultDNSRewriteResult, ent *logEntry,
) (ok bool) {
ok = true
switch name {
case "ReverseHosts":
l.decodeResultReverseHosts(ctx, dec, ent)
case "IPList":
l.decodeResultIPList(ctx, dec, ent)
case "Rules":
l.decodeResultRules(ctx, dec, ent)
case "DNSRewriteResult":
l.decodeResultDNSRewriteResult(ctx, dec, ent)
default:
ok = false
}
return ok
} }
// decodeLogEntry decodes string str to logEntry ent. // decodeLogEntry decodes string str to logEntry ent.
func decodeLogEntry(ent *logEntry, str string) { func (l *queryLog) decodeLogEntry(ctx context.Context, ent *logEntry, str string) {
const msgPrefix = "decoding log entry"
dec := json.NewDecoder(strings.NewReader(str)) dec := json.NewDecoder(strings.NewReader(str))
dec.UseNumber() dec.UseNumber()
@ -653,7 +719,7 @@ func decodeLogEntry(ent *logEntry, str string) {
keyToken, err := dec.Token() keyToken, err := dec.Token()
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
log.Debug("decodeLogEntry err: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
} }
return return
@ -665,13 +731,14 @@ func decodeLogEntry(ent *logEntry, str string) {
key, ok := keyToken.(string) key, ok := keyToken.(string)
if !ok { if !ok {
log.Debug("decodeLogEntry: keyToken is %T (%[1]v) and not string", keyToken) err = fmt.Errorf("%s: keyToken is %T (%[2]v) and not string", msgPrefix, keyToken)
l.logger.DebugContext(ctx, msgPrefix, slogutil.KeyError, err)
return return
} }
if key == "Result" { if key == "Result" {
decodeResult(dec, ent) l.decodeResult(ctx, dec, ent)
continue continue
} }
@ -687,9 +754,14 @@ func decodeLogEntry(ent *logEntry, str string) {
} }
if err = handler(val, ent); err != nil { if err = handler(val, ent); err != nil {
log.Debug("decodeLogEntry handler err: %s", err) l.logger.DebugContext(ctx, msgPrefix+"; handler", slogutil.KeyError, err)
return return
} }
} }
} }
// newUnexpectedDelimiterError is a helper for creating informative errors.
func newUnexpectedDelimiterError(d json.Delim) (err error) {
return fmt.Errorf("unexpected delimiter: %q", d)
}

View File

@ -3,27 +3,35 @@ package querylog
import ( import (
"bytes" "bytes"
"encoding/base64" "encoding/base64"
"log/slog"
"net" "net"
"net/netip" "net/netip"
"strings"
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// Common constants for tests.
const testTimeout = 1 * time.Second
func TestDecodeLogEntry(t *testing.T) { func TestDecodeLogEntry(t *testing.T) {
logOutput := &bytes.Buffer{} logOutput := &bytes.Buffer{}
l := &queryLog{
logger: slog.New(slog.NewTextHandler(logOutput, &slog.HandlerOptions{
Level: slog.LevelDebug,
ReplaceAttr: slogutil.RemoveTime,
})),
}
aghtest.ReplaceLogWriter(t, logOutput) ctx := testutil.ContextWithTimeout(t, testTimeout)
aghtest.ReplaceLogLevel(t, log.DEBUG)
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==` const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==`
@ -92,7 +100,7 @@ func TestDecodeLogEntry(t *testing.T) {
} }
got := &logEntry{} got := &logEntry{}
decodeLogEntry(got, data) l.decodeLogEntry(ctx, got, data)
s := logOutput.String() s := logOutput.String()
assert.Empty(t, s) assert.Empty(t, s)
@ -113,11 +121,11 @@ func TestDecodeLogEntry(t *testing.T) {
}, { }, {
name: "bad_filter_id_old_rule", name: "bad_filter_id_old_rule",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"FilterID":1.5},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"FilterID":1.5},"Elapsed":837429}`,
want: "decodeResult handler err: strconv.ParseInt: parsing \"1.5\": invalid syntax\n", want: `level=DEBUG msg="decoding result; handler" err="strconv.ParseInt: parsing \"1.5\": invalid syntax"`,
}, { }, {
name: "bad_is_filtered", name: "bad_is_filtered",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":trooe,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":trooe,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry err: invalid character 'o' in literal true (expecting 'u')\n", want: `level=DEBUG msg="decoding log entry; token" err="invalid character 'o' in literal true (expecting 'u')"`,
}, { }, {
name: "bad_elapsed", name: "bad_elapsed",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":-1}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":-1}`,
@ -129,7 +137,7 @@ func TestDecodeLogEntry(t *testing.T) {
}, { }, {
name: "bad_time", name: "bad_time",
log: `{"IP":"127.0.0.1","T":"12/09/1998T15:00:00.000000+05:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"12/09/1998T15:00:00.000000+05:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry handler err: parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\"\n", want: `level=DEBUG msg="decoding log entry; handler" err="parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\""`,
}, { }, {
name: "bad_host", name: "bad_host",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":6,"QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":6,"QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
@ -149,7 +157,7 @@ func TestDecodeLogEntry(t *testing.T) {
}, { }, {
name: "very_bad_client_proto", name: "very_bad_client_proto",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"dog","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"dog","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry handler err: invalid client proto: \"dog\"\n", want: `level=DEBUG msg="decoding log entry; handler" err="invalid client proto: \"dog\""`,
}, { }, {
name: "bad_answer", name: "bad_answer",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":0.9,"Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":0.9,"Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
@ -157,7 +165,7 @@ func TestDecodeLogEntry(t *testing.T) {
}, { }, {
name: "very_bad_answer", name: "very_bad_answer",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
want: "decodeLogEntry handler err: illegal base64 data at input byte 61\n", want: `level=DEBUG msg="decoding log entry; handler" err="illegal base64 data at input byte 61"`,
}, { }, {
name: "bad_rule", name: "bad_rule",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"Rule":false},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"Rule":false},"Elapsed":837429}`,
@ -169,22 +177,25 @@ func TestDecodeLogEntry(t *testing.T) {
}, { }, {
name: "bad_reverse_hosts", name: "bad_reverse_hosts",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":[{}]},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":[{}]},"Elapsed":837429}`,
want: "decodeResultReverseHosts: unexpected delim \"{\"\n", want: `level=DEBUG msg="decoding result reverse hosts" err="unexpected delimiter: \"{\""`,
}, { }, {
name: "bad_ip_list", name: "bad_ip_list",
log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":["example.net"],"IPList":[{}]},"Elapsed":837429}`, log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":["example.net"],"IPList":[{}]},"Elapsed":837429}`,
want: "decodeResultIPList: unexpected delim \"{\"\n", want: `level=DEBUG msg="decoding result ip list" err="unexpected delimiter: \"{\""`,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
decodeLogEntry(new(logEntry), tc.log) l.decodeLogEntry(ctx, new(logEntry), tc.log)
got := logOutput.String()
s := logOutput.String()
if tc.want == "" { if tc.want == "" {
assert.Empty(t, s) assert.Empty(t, got)
} else { } else {
assert.True(t, strings.HasSuffix(s, tc.want), "got %q", s) require.NotEmpty(t, got)
// Remove newline.
got = got[:len(got)-1]
assert.Equal(t, tc.want, got)
} }
logOutput.Reset() logOutput.Reset()
@ -200,6 +211,12 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) {
aaaa2 = aaaa1.Next() aaaa2 = aaaa1.Next()
) )
l := &queryLog{
logger: slogutil.NewDiscardLogger(),
}
ctx := testutil.ContextWithTimeout(t, testTimeout)
testCases := []struct { testCases := []struct {
want *logEntry want *logEntry
entry string entry string
@ -249,7 +266,7 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
e := &logEntry{} e := &logEntry{}
decodeLogEntry(e, tc.entry) l.decodeLogEntry(ctx, e, tc.entry)
assert.Equal(t, tc.want, e) assert.Equal(t, tc.want, e)
}) })

View File

@ -1,12 +1,14 @@
package querylog package querylog
import ( import (
"context"
"log/slog"
"net" "net"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -52,7 +54,7 @@ func (e *logEntry) shallowClone() (clone *logEntry) {
// addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is // addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is
// true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any // true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any
// errors are logged. // errors are logged.
func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) { func (e *logEntry) addResponse(ctx context.Context, l *slog.Logger, resp *dns.Msg, isOrig bool) {
if resp == nil { if resp == nil {
return return
} }
@ -65,8 +67,9 @@ func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
e.Answer, err = resp.Pack() e.Answer, err = resp.Pack()
err = errors.Annotate(err, "packing answer: %w") err = errors.Annotate(err, "packing answer: %w")
} }
if err != nil { if err != nil {
log.Error("querylog: %s", err) l.ErrorContext(ctx, "adding data from response", slogutil.KeyError, err)
} }
} }

View File

@ -1,6 +1,7 @@
package querylog package querylog
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"math" "math"
@ -15,7 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
@ -74,7 +75,8 @@ func (l *queryLog) initWeb() {
// handleQueryLog is the handler for the GET /control/querylog HTTP API. // handleQueryLog is the handler for the GET /control/querylog HTTP API.
func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) { func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
params, err := parseSearchParams(r) ctx := r.Context()
params, err := l.parseSearchParams(ctx, r)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err)
@ -87,18 +89,18 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
l.confMu.RLock() l.confMu.RLock()
defer l.confMu.RUnlock() defer l.confMu.RUnlock()
entries, oldest = l.search(params) entries, oldest = l.search(ctx, params)
}() }()
resp := entriesToJSON(entries, oldest, l.anonymizer.Load()) resp := l.entriesToJSON(ctx, entries, oldest, l.anonymizer.Load())
aghhttp.WriteJSONResponseOK(w, r, resp) aghhttp.WriteJSONResponseOK(w, r, resp)
} }
// handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP // handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP
// API. // API.
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) { func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, r *http.Request) {
l.clear() l.clear(r.Context())
} }
// handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP // handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP
@ -280,11 +282,12 @@ func getDoubleQuotesEnclosedValue(s *string) bool {
} }
// parseSearchCriterion parses a search criterion from the query parameter. // parseSearchCriterion parses a search criterion from the query parameter.
func parseSearchCriterion(q url.Values, name string, ct criterionType) ( func (l *queryLog) parseSearchCriterion(
ok bool, ctx context.Context,
sc searchCriterion, q url.Values,
err error, name string,
) { ct criterionType,
) (ok bool, sc searchCriterion, err error) {
val := q.Get(name) val := q.Get(name)
if val == "" { if val == "" {
return false, sc, nil return false, sc, nil
@ -301,7 +304,7 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
// TODO(e.burkov): Make it work with parts of IDNAs somehow. // TODO(e.burkov): Make it work with parts of IDNAs somehow.
loweredVal := strings.ToLower(val) loweredVal := strings.ToLower(val)
if asciiVal, err = idna.ToASCII(loweredVal); err != nil { if asciiVal, err = idna.ToASCII(loweredVal); err != nil {
log.Debug("can't convert %q to ascii: %s", val, err) l.logger.DebugContext(ctx, "converting to ascii", "value", val, slogutil.KeyError, err)
} else if asciiVal == loweredVal { } else if asciiVal == loweredVal {
// Purge asciiVal to prevent checking the same value // Purge asciiVal to prevent checking the same value
// twice. // twice.
@ -331,7 +334,10 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
// parseSearchParams parses search parameters from the HTTP request's query // parseSearchParams parses search parameters from the HTTP request's query
// string. // string.
func parseSearchParams(r *http.Request) (p *searchParams, err error) { func (l *queryLog) parseSearchParams(
ctx context.Context,
r *http.Request,
) (p *searchParams, err error) {
p = newSearchParams() p = newSearchParams()
q := r.URL.Query() q := r.URL.Query()
@ -369,7 +375,7 @@ func parseSearchParams(r *http.Request) (p *searchParams, err error) {
}} { }} {
var ok bool var ok bool
var c searchCriterion var c searchCriterion
ok, c, err = parseSearchCriterion(q, v.urlField, v.ct) ok, c, err = l.parseSearchCriterion(ctx, q, v.urlField, v.ct)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,6 +1,7 @@
package querylog package querylog
import ( import (
"context"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@ -8,7 +9,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.org/x/net/idna" "golang.org/x/net/idna"
) )
@ -19,7 +20,8 @@ import (
type jobject = map[string]any type jobject = map[string]any
// entriesToJSON converts query log entries to JSON. // entriesToJSON converts query log entries to JSON.
func entriesToJSON( func (l *queryLog) entriesToJSON(
ctx context.Context,
entries []*logEntry, entries []*logEntry,
oldest time.Time, oldest time.Time,
anonFunc aghnet.IPMutFunc, anonFunc aghnet.IPMutFunc,
@ -28,7 +30,7 @@ func entriesToJSON(
// The elements order is already reversed to be from newer to older. // The elements order is already reversed to be from newer to older.
for _, entry := range entries { for _, entry := range entries {
jsonEntry := entryToJSON(entry, anonFunc) jsonEntry := l.entryToJSON(ctx, entry, anonFunc)
data = append(data, jsonEntry) data = append(data, jsonEntry)
} }
@ -44,7 +46,11 @@ func entriesToJSON(
} }
// entryToJSON converts a log entry's data into an entry for the JSON API. // entryToJSON converts a log entry's data into an entry for the JSON API.
func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) { func (l *queryLog) entryToJSON(
ctx context.Context,
entry *logEntry,
anonFunc aghnet.IPMutFunc,
) (jsonEntry jobject) {
hostname := entry.QHost hostname := entry.QHost
question := jobject{ question := jobject{
"type": entry.QType, "type": entry.QType,
@ -53,7 +59,12 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject)
} }
if qhost, err := idna.ToUnicode(hostname); err != nil { if qhost, err := idna.ToUnicode(hostname); err != nil {
log.Debug("querylog: translating %q into unicode: %s", hostname, err) l.logger.DebugContext(
ctx,
"translating into unicode",
"hostname", hostname,
slogutil.KeyError, err,
)
} else if qhost != hostname && qhost != "" { } else if qhost != hostname && qhost != "" {
question["unicode_name"] = qhost question["unicode_name"] = qhost
} }
@ -96,21 +107,26 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject)
jsonEntry["service_name"] = entry.Result.ServiceName jsonEntry["service_name"] = entry.Result.ServiceName
} }
setMsgData(entry, jsonEntry) l.setMsgData(ctx, entry, jsonEntry)
setOrigAns(entry, jsonEntry) l.setOrigAns(ctx, entry, jsonEntry)
return jsonEntry return jsonEntry
} }
// setMsgData sets the message data in jsonEntry. // setMsgData sets the message data in jsonEntry.
func setMsgData(entry *logEntry, jsonEntry jobject) { func (l *queryLog) setMsgData(ctx context.Context, entry *logEntry, jsonEntry jobject) {
if len(entry.Answer) == 0 { if len(entry.Answer) == 0 {
return return
} }
msg := &dns.Msg{} msg := &dns.Msg{}
if err := msg.Unpack(entry.Answer); err != nil { if err := msg.Unpack(entry.Answer); err != nil {
log.Debug("querylog: failed to unpack dns msg answer: %v: %s", entry.Answer, err) l.logger.DebugContext(
ctx,
"unpacking dns message",
"answer", entry.Answer,
slogutil.KeyError, err,
)
return return
} }
@ -126,7 +142,7 @@ func setMsgData(entry *logEntry, jsonEntry jobject) {
} }
// setOrigAns sets the original answer data in jsonEntry. // setOrigAns sets the original answer data in jsonEntry.
func setOrigAns(entry *logEntry, jsonEntry jobject) { func (l *queryLog) setOrigAns(ctx context.Context, entry *logEntry, jsonEntry jobject) {
if len(entry.OrigAnswer) == 0 { if len(entry.OrigAnswer) == 0 {
return return
} }
@ -134,7 +150,12 @@ func setOrigAns(entry *logEntry, jsonEntry jobject) {
orig := &dns.Msg{} orig := &dns.Msg{}
err := orig.Unpack(entry.OrigAnswer) err := orig.Unpack(entry.OrigAnswer)
if err != nil { if err != nil {
log.Debug("querylog: orig.Unpack(entry.OrigAnswer): %v: %s", entry.OrigAnswer, err) l.logger.DebugContext(
ctx,
"setting original answer",
"answer", entry.OrigAnswer,
slogutil.KeyError, err,
)
return return
} }

View File

@ -2,7 +2,9 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"log/slog"
"os" "os"
"sync" "sync"
"time" "time"
@ -11,7 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -22,6 +24,10 @@ const queryLogFileName = "querylog.json"
// queryLog is a structure that writes and reads the DNS query log. // queryLog is a structure that writes and reads the DNS query log.
type queryLog struct { type queryLog struct {
// logger is used for logging the operation of the query log. It must not
// be nil.
logger *slog.Logger
// confMu protects conf. // confMu protects conf.
confMu *sync.RWMutex confMu *sync.RWMutex
@ -76,24 +82,34 @@ func NewClientProto(s string) (cp ClientProto, err error) {
} }
} }
func (l *queryLog) Start() { // type check
var _ QueryLog = (*queryLog)(nil)
// Start implements the [QueryLog] interface for *queryLog.
func (l *queryLog) Start(ctx context.Context) (err error) {
if l.conf.HTTPRegister != nil { if l.conf.HTTPRegister != nil {
l.initWeb() l.initWeb()
} }
go l.periodicRotate() go l.periodicRotate(ctx)
return nil
} }
func (l *queryLog) Close() { // Shutdown implements the [QueryLog] interface for *queryLog.
func (l *queryLog) Shutdown(ctx context.Context) (err error) {
l.confMu.RLock() l.confMu.RLock()
defer l.confMu.RUnlock() defer l.confMu.RUnlock()
if l.conf.FileEnabled { if l.conf.FileEnabled {
err := l.flushLogBuffer() err = l.flushLogBuffer(ctx)
if err != nil { if err != nil {
log.Error("querylog: closing: %s", err) // Don't wrap the error because it's informative enough as is.
return err
} }
} }
return nil
} }
func checkInterval(ivl time.Duration) (ok bool) { func checkInterval(ivl time.Duration) (ok bool) {
@ -123,6 +139,7 @@ func validateIvl(ivl time.Duration) (err error) {
return nil return nil
} }
// WriteDiskConfig implements the [QueryLog] interface for *queryLog.
func (l *queryLog) WriteDiskConfig(c *Config) { func (l *queryLog) WriteDiskConfig(c *Config) {
l.confMu.RLock() l.confMu.RLock()
defer l.confMu.RUnlock() defer l.confMu.RUnlock()
@ -131,7 +148,7 @@ func (l *queryLog) WriteDiskConfig(c *Config) {
} }
// Clear memory buffer and remove log files // Clear memory buffer and remove log files
func (l *queryLog) clear() { func (l *queryLog) clear(ctx context.Context) {
l.fileFlushLock.Lock() l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock() defer l.fileFlushLock.Unlock()
@ -146,19 +163,24 @@ func (l *queryLog) clear() {
oldLogFile := l.logFile + ".1" oldLogFile := l.logFile + ".1"
err := os.Remove(oldLogFile) err := os.Remove(oldLogFile)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("removing old log file %q: %s", oldLogFile, err) l.logger.ErrorContext(
ctx,
"removing old log file",
"file", oldLogFile,
slogutil.KeyError, err,
)
} }
err = os.Remove(l.logFile) err = os.Remove(l.logFile)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("removing log file %q: %s", l.logFile, err) l.logger.ErrorContext(ctx, "removing log file", "file", l.logFile, slogutil.KeyError, err)
} }
log.Debug("querylog: cleared") l.logger.DebugContext(ctx, "cleared")
} }
// newLogEntry creates an instance of logEntry from parameters. // newLogEntry creates an instance of logEntry from parameters.
func newLogEntry(params *AddParams) (entry *logEntry) { func newLogEntry(ctx context.Context, logger *slog.Logger, params *AddParams) (entry *logEntry) {
q := params.Question.Question[0] q := params.Question.Question[0]
qHost := aghnet.NormalizeDomain(q.Name) qHost := aghnet.NormalizeDomain(q.Name)
@ -187,8 +209,8 @@ func newLogEntry(params *AddParams) (entry *logEntry) {
entry.ReqECS = params.ReqECS.String() entry.ReqECS = params.ReqECS.String()
} }
entry.addResponse(params.Answer, false) entry.addResponse(ctx, logger, params.Answer, false)
entry.addResponse(params.OrigAnswer, true) entry.addResponse(ctx, logger, params.OrigAnswer, true)
return entry return entry
} }
@ -209,9 +231,12 @@ func (l *queryLog) Add(params *AddParams) {
return return
} }
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
err := params.validate() err := params.validate()
if err != nil { if err != nil {
log.Error("querylog: adding record: %s, skipping", err) l.logger.ErrorContext(ctx, "adding record", slogutil.KeyError, err)
return return
} }
@ -220,7 +245,7 @@ func (l *queryLog) Add(params *AddParams) {
params.Result = &filtering.Result{} params.Result = &filtering.Result{}
} }
entry := newLogEntry(params) entry := newLogEntry(ctx, l.logger, params)
l.bufferLock.Lock() l.bufferLock.Lock()
defer l.bufferLock.Unlock() defer l.bufferLock.Unlock()
@ -232,9 +257,9 @@ func (l *queryLog) Add(params *AddParams) {
// TODO(s.chzhen): Fix occasional rewrite of entires. // TODO(s.chzhen): Fix occasional rewrite of entires.
go func() { go func() {
flushErr := l.flushLogBuffer() flushErr := l.flushLogBuffer(ctx)
if flushErr != nil { if flushErr != nil {
log.Error("querylog: flushing after adding: %s", flushErr) l.logger.ErrorContext(ctx, "flushing after adding", slogutil.KeyError, flushErr)
} }
}() }()
} }
@ -247,7 +272,8 @@ func (l *queryLog) ShouldLog(host string, _, _ uint16, ids []string) bool {
c, err := l.findClient(ids) c, err := l.findClient(ids)
if err != nil { if err != nil {
log.Error("querylog: finding client: %s", err) // TODO(s.chzhen): Pass context.
l.logger.ErrorContext(context.TODO(), "finding client", slogutil.KeyError, err)
} }
if c != nil && c.IgnoreQueryLog { if c != nil && c.IgnoreQueryLog {

View File

@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -14,14 +15,11 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// TestQueryLog tests adding and loading (with filtering) entries from disk and // TestQueryLog tests adding and loading (with filtering) entries from disk and
// memory. // memory.
func TestQueryLog(t *testing.T) { func TestQueryLog(t *testing.T) {
l, err := newQueryLog(Config{ l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true, Enabled: true,
FileEnabled: true, FileEnabled: true,
RotationIvl: timeutil.Day, RotationIvl: timeutil.Day,
@ -30,16 +28,21 @@ func TestQueryLog(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add disk entries. // Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// Write to disk (first file). // Write to disk (first file).
require.NoError(t, l.flushLogBuffer()) require.NoError(t, l.flushLogBuffer(ctx))
// Start writing to the second file. // Start writing to the second file.
require.NoError(t, l.rotate()) require.NoError(t, l.rotate(ctx))
// Add disk entries. // Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// Write to disk. // Write to disk.
require.NoError(t, l.flushLogBuffer()) require.NoError(t, l.flushLogBuffer(ctx))
// Add memory entries. // Add memory entries.
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
@ -119,8 +122,9 @@ func TestQueryLog(t *testing.T) {
params := newSearchParams() params := newSearchParams()
params.searchCriteria = tc.sCr params.searchCriteria = tc.sCr
entries, _ := l.search(params) entries, _ := l.search(ctx, params)
require.Len(t, entries, len(tc.want)) require.Len(t, entries, len(tc.want))
for _, want := range tc.want { for _, want := range tc.want {
assertLogEntry(t, entries[want.num], want.host, want.answer, want.client) assertLogEntry(t, entries[want.num], want.host, want.answer, want.client)
} }
@ -130,6 +134,7 @@ func TestQueryLog(t *testing.T) {
func TestQueryLogOffsetLimit(t *testing.T) { func TestQueryLogOffsetLimit(t *testing.T) {
l, err := newQueryLog(Config{ l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true, Enabled: true,
RotationIvl: timeutil.Day, RotationIvl: timeutil.Day,
MemSize: 100, MemSize: 100,
@ -142,12 +147,16 @@ func TestQueryLogOffsetLimit(t *testing.T) {
firstPageDomain = "first.example.org" firstPageDomain = "first.example.org"
secondPageDomain = "second.example.org" secondPageDomain = "second.example.org"
) )
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Add entries to the log. // Add entries to the log.
for range entNum { for range entNum {
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// Write them to the first file. // Write them to the first file.
require.NoError(t, l.flushLogBuffer()) require.NoError(t, l.flushLogBuffer(ctx))
// Add more to the in-memory part of log. // Add more to the in-memory part of log.
for range entNum { for range entNum {
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
@ -191,8 +200,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
params.offset = tc.offset params.offset = tc.offset
params.limit = tc.limit params.limit = tc.limit
entries, _ := l.search(params) entries, _ := l.search(ctx, params)
require.Len(t, entries, tc.wantLen) require.Len(t, entries, tc.wantLen)
if tc.wantLen > 0 { if tc.wantLen > 0 {
@ -205,6 +213,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
func TestQueryLogMaxFileScanEntries(t *testing.T) { func TestQueryLogMaxFileScanEntries(t *testing.T) {
l, err := newQueryLog(Config{ l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true, Enabled: true,
FileEnabled: true, FileEnabled: true,
RotationIvl: timeutil.Day, RotationIvl: timeutil.Day,
@ -213,20 +222,21 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
ctx := testutil.ContextWithTimeout(t, testTimeout)
const entNum = 10 const entNum = 10
// Add entries to the log. // Add entries to the log.
for range entNum { for range entNum {
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// Write them to disk. // Write them to disk.
require.NoError(t, l.flushLogBuffer()) require.NoError(t, l.flushLogBuffer(ctx))
params := newSearchParams() params := newSearchParams()
for _, maxFileScanEntries := range []int{5, 0} { for _, maxFileScanEntries := range []int{5, 0} {
t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) { t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) {
params.maxFileScanEntries = maxFileScanEntries params.maxFileScanEntries = maxFileScanEntries
entries, _ := l.search(params) entries, _ := l.search(ctx, params)
assert.Len(t, entries, entNum-maxFileScanEntries) assert.Len(t, entries, entNum-maxFileScanEntries)
}) })
} }
@ -234,6 +244,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
func TestQueryLogFileDisabled(t *testing.T) { func TestQueryLogFileDisabled(t *testing.T) {
l, err := newQueryLog(Config{ l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
Enabled: true, Enabled: true,
FileEnabled: false, FileEnabled: false,
RotationIvl: timeutil.Day, RotationIvl: timeutil.Day,
@ -248,8 +259,10 @@ func TestQueryLogFileDisabled(t *testing.T) {
addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
params := newSearchParams() params := newSearchParams()
ll, _ := l.search(params) ctx := testutil.ContextWithTimeout(t, testTimeout)
ll, _ := l.search(ctx, params)
require.Len(t, ll, 2) require.Len(t, ll, 2)
assert.Equal(t, "example3.org", ll[0].QHost) assert.Equal(t, "example3.org", ll[0].QHost)
assert.Equal(t, "example2.org", ll[1].QHost) assert.Equal(t, "example2.org", ll[1].QHost)
} }

View File

@ -1,8 +1,10 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"strings" "strings"
"sync" "sync"
@ -10,7 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
) )
const ( const (
@ -102,7 +104,11 @@ func (q *qLogFile) validateQLogLineIdx(lineIdx, lastProbeLineIdx, ts, fSize int6
// for so that when we call "ReadNext" this line was returned. // for so that when we call "ReadNext" this line was returned.
// - Depth of the search (how many times we compared timestamps). // - Depth of the search (how many times we compared timestamps).
// - If we could not find it, it returns one of the errors described above. // - If we could not find it, it returns one of the errors described above.
func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) { func (q *qLogFile) seekTS(
ctx context.Context,
logger *slog.Logger,
timestamp int64,
) (pos int64, depth int, err error) {
q.lock.Lock() q.lock.Lock()
defer q.lock.Unlock() defer q.lock.Unlock()
@ -151,7 +157,7 @@ func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) {
lastProbeLineIdx = lineIdx lastProbeLineIdx = lineIdx
// Get the timestamp from the query log record. // Get the timestamp from the query log record.
ts := readQLogTimestamp(line) ts := readQLogTimestamp(ctx, logger, line)
if ts == 0 { if ts == 0 {
return 0, depth, fmt.Errorf( return 0, depth, fmt.Errorf(
"looking up timestamp %d in %q: record %q has empty timestamp", "looking up timestamp %d in %q: record %q has empty timestamp",
@ -385,20 +391,22 @@ func readJSONValue(s, prefix string) string {
} }
// readQLogTimestamp reads the timestamp field from the query log line. // readQLogTimestamp reads the timestamp field from the query log line.
func readQLogTimestamp(str string) int64 { func readQLogTimestamp(ctx context.Context, logger *slog.Logger, str string) int64 {
val := readJSONValue(str, `"T":"`) val := readJSONValue(str, `"T":"`)
if len(val) == 0 { if len(val) == 0 {
val = readJSONValue(str, `"Time":"`) val = readJSONValue(str, `"Time":"`)
} }
if len(val) == 0 { if len(val) == 0 {
log.Error("Couldn't find timestamp: %s", str) logger.ErrorContext(ctx, "couldn't find timestamp", "line", str)
return 0 return 0
} }
tm, err := time.Parse(time.RFC3339Nano, val) tm, err := time.Parse(time.RFC3339Nano, val)
if err != nil { if err != nil {
log.Error("Couldn't parse timestamp: %s", val) logger.ErrorContext(ctx, "couldn't parse timestamp", "value", val, slogutil.KeyError, err)
return 0 return 0
} }

View File

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -24,6 +25,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) {
f, err := os.CreateTemp(dir, "*.txt") f, err := os.CreateTemp(dir, "*.txt")
require.NoError(t, err) require.NoError(t, err)
// Use defer and not t.Cleanup to make sure that the file is closed // Use defer and not t.Cleanup to make sure that the file is closed
// after this function is done. // after this function is done.
defer func() { defer func() {
@ -108,6 +110,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
// Calculate the expected position. // Calculate the expected position.
fileInfo, err := q.file.Stat() fileInfo, err := q.file.Stat()
require.NoError(t, err) require.NoError(t, err)
var expPos int64 var expPos int64
if expPos = fileInfo.Size(); expPos > 0 { if expPos = fileInfo.Size(); expPos > 0 {
expPos-- expPos--
@ -129,6 +132,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
} }
require.Equal(t, io.EOF, err) require.Equal(t, io.EOF, err)
assert.Equal(t, tc.linesNum, read) assert.Equal(t, tc.linesNum, read)
}) })
} }
@ -146,6 +150,9 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
num: 10, num: 10,
}} }}
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
for _, l := range linesCases { for _, l := range linesCases {
testCases := []struct { testCases := []struct {
name string name string
@ -171,16 +178,19 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
t.Run(l.name+"_"+tc.name, func(t *testing.T) { t.Run(l.name+"_"+tc.name, func(t *testing.T) {
line, err := getQLogFileLine(q, tc.line) line, err := getQLogFileLine(q, tc.line)
require.NoError(t, err) require.NoError(t, err)
ts := readQLogTimestamp(line)
ts := readQLogTimestamp(ctx, logger, line)
assert.NotEqualValues(t, 0, ts) assert.NotEqualValues(t, 0, ts)
// Try seeking to that line now. // Try seeking to that line now.
pos, _, err := q.seekTS(ts) pos, _, err := q.seekTS(ctx, logger, ts)
require.NoError(t, err) require.NoError(t, err)
assert.NotEqualValues(t, 0, pos) assert.NotEqualValues(t, 0, pos)
testLine, err := q.ReadNext() testLine, err := q.ReadNext()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, line, testLine) assert.Equal(t, line, testLine)
}) })
} }
@ -199,6 +209,9 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
num: 10, num: 10,
}} }}
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
for _, l := range linesCases { for _, l := range linesCases {
testCases := []struct { testCases := []struct {
name string name string
@ -221,14 +234,14 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
line, err := getQLogFileLine(q, l.num/2) line, err := getQLogFileLine(q, l.num/2)
require.NoError(t, err) require.NoError(t, err)
testCases[2].ts = readQLogTimestamp(line) - 1
testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
assert.NotEqualValues(t, 0, tc.ts) assert.NotEqualValues(t, 0, tc.ts)
var depth int var depth int
_, depth, err = q.seekTS(tc.ts) _, depth, err = q.seekTS(ctx, logger, tc.ts)
assert.NotEmpty(t, l.num) assert.NotEmpty(t, l.num)
require.Error(t, err) require.Error(t, err)
@ -262,11 +275,13 @@ func TestQLogFile(t *testing.T) {
// Seek to the start. // Seek to the start.
pos, err := q.SeekStart() pos, err := q.SeekStart()
require.NoError(t, err) require.NoError(t, err)
assert.Greater(t, pos, int64(0)) assert.Greater(t, pos, int64(0))
// Read first line. // Read first line.
line, err := q.ReadNext() line, err := q.ReadNext()
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, line, "0.0.0.2") assert.Contains(t, line, "0.0.0.2")
assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line) assert.True(t, strings.HasSuffix(line, "}"), line)
@ -274,6 +289,7 @@ func TestQLogFile(t *testing.T) {
// Read second line. // Read second line.
line, err = q.ReadNext() line, err = q.ReadNext()
require.NoError(t, err) require.NoError(t, err)
assert.EqualValues(t, 0, q.position) assert.EqualValues(t, 0, q.position)
assert.Contains(t, line, "0.0.0.1") assert.Contains(t, line, "0.0.0.1")
assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasPrefix(line, "{"), line)
@ -282,12 +298,14 @@ func TestQLogFile(t *testing.T) {
// Try reading again (there's nothing to read anymore). // Try reading again (there's nothing to read anymore).
line, err = q.ReadNext() line, err = q.ReadNext()
require.Equal(t, io.EOF, err) require.Equal(t, io.EOF, err)
assert.Empty(t, line) assert.Empty(t, line)
} }
func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) { func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) {
f, err := os.CreateTemp(t.TempDir(), "*.txt") f, err := os.CreateTemp(t.TempDir(), "*.txt")
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, f.Close) testutil.CleanupAndRequireSuccess(t, f.Close)
_, err = f.WriteString(data) _, err = f.WriteString(data)
@ -295,6 +313,7 @@ func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) {
file, err = newQLogFile(f.Name()) file, err = newQLogFile(f.Name())
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, file.Close) testutil.CleanupAndRequireSuccess(t, file.Close)
return file return file
@ -308,6 +327,9 @@ func TestQLog_Seek(t *testing.T) {
`{"T":"` + strV + `"}` + nl `{"T":"` + strV + `"}` + nl
timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00") timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00")
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
testCases := []struct { testCases := []struct {
wantErr error wantErr error
name string name string
@ -340,8 +362,10 @@ func TestQLog_Seek(t *testing.T) {
q := newTestQLogFileData(t, data) q := newTestQLogFileData(t, data)
_, depth, err := q.seekTS(timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano()) ts := timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano()
_, depth, err := q.seekTS(ctx, logger, ts)
require.Truef(t, errors.Is(err, tc.wantErr), "%v", err) require.Truef(t, errors.Is(err, tc.wantErr), "%v", err)
assert.Equal(t, tc.wantDepth, depth) assert.Equal(t, tc.wantDepth, depth)
}) })
} }

View File

@ -1,12 +1,14 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
) )
// qLogReader allows reading from multiple query log files in the reverse // qLogReader allows reading from multiple query log files in the reverse
@ -16,6 +18,10 @@ import (
// pointer to a particular query log file, and to a specific position in this // pointer to a particular query log file, and to a specific position in this
// file, and it reads lines in reverse order starting from that position. // file, and it reads lines in reverse order starting from that position.
type qLogReader struct { type qLogReader struct {
// logger is used for logging the operation of the query log reader. It
// must not be nil.
logger *slog.Logger
// qFiles is an array with the query log files. The order is from oldest // qFiles is an array with the query log files. The order is from oldest
// to newest. // to newest.
qFiles []*qLogFile qFiles []*qLogFile
@ -25,7 +31,7 @@ type qLogReader struct {
} }
// newQLogReader initializes a qLogReader instance with the specified files. // newQLogReader initializes a qLogReader instance with the specified files.
func newQLogReader(files []string) (*qLogReader, error) { func newQLogReader(ctx context.Context, logger *slog.Logger, files []string) (*qLogReader, error) {
qFiles := make([]*qLogFile, 0) qFiles := make([]*qLogFile, 0)
for _, f := range files { for _, f := range files {
@ -38,7 +44,7 @@ func newQLogReader(files []string) (*qLogReader, error) {
// Close what we've already opened. // Close what we've already opened.
cErr := closeQFiles(qFiles) cErr := closeQFiles(qFiles)
if cErr != nil { if cErr != nil {
log.Debug("querylog: closing files: %s", cErr) logger.DebugContext(ctx, "closing files", slogutil.KeyError, cErr)
} }
return nil, err return nil, err
@ -47,16 +53,20 @@ func newQLogReader(files []string) (*qLogReader, error) {
qFiles = append(qFiles, q) qFiles = append(qFiles, q)
} }
return &qLogReader{qFiles: qFiles, currentFile: len(qFiles) - 1}, nil return &qLogReader{
logger: logger,
qFiles: qFiles,
currentFile: len(qFiles) - 1,
}, nil
} }
// seekTS performs binary search of a query log record with the specified // seekTS performs binary search of a query log record with the specified
// timestamp. If the record is found, it sets qLogReader's position to point // timestamp. If the record is found, it sets qLogReader's position to point
// to that line, so that the next ReadNext call returned this line. // to that line, so that the next ReadNext call returned this line.
func (r *qLogReader) seekTS(timestamp int64) (err error) { func (r *qLogReader) seekTS(ctx context.Context, timestamp int64) (err error) {
for i := len(r.qFiles) - 1; i >= 0; i-- { for i := len(r.qFiles) - 1; i >= 0; i-- {
q := r.qFiles[i] q := r.qFiles[i]
_, _, err = q.seekTS(timestamp) _, _, err = q.seekTS(ctx, r.logger, timestamp)
if err != nil { if err != nil {
if errors.Is(err, errTSTooEarly) { if errors.Is(err, errTSTooEarly) {
// Look at the next file, since we've reached the end of this // Look at the next file, since we've reached the end of this

View File

@ -5,6 +5,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -17,8 +18,11 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *qLogReader
testFiles := prepareTestFiles(t, filesNum, linesNum) testFiles := prepareTestFiles(t, filesNum, linesNum)
logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)
// Create the new qLogReader instance. // Create the new qLogReader instance.
reader, err := newQLogReader(testFiles) reader, err := newQLogReader(ctx, logger, testFiles)
require.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, reader) assert.NotNil(t, reader)
@ -73,6 +77,7 @@ func TestQLogReader(t *testing.T) {
func TestQLogReader_Seek(t *testing.T) { func TestQLogReader_Seek(t *testing.T) {
r := newTestQLogReader(t, 2, 10000) r := newTestQLogReader(t, 2, 10000)
ctx := testutil.ContextWithTimeout(t, testTimeout)
testCases := []struct { testCases := []struct {
want error want error
@ -113,7 +118,7 @@ func TestQLogReader_Seek(t *testing.T) {
ts, err := time.Parse(time.RFC3339Nano, tc.time) ts, err := time.Parse(time.RFC3339Nano, tc.time)
require.NoError(t, err) require.NoError(t, err)
err = r.seekTS(ts.UnixNano()) err = r.seekTS(ctx, ts.UnixNano())
assert.ErrorIs(t, err, tc.want) assert.ErrorIs(t, err, tc.want)
}) })
} }

View File

@ -2,6 +2,7 @@ package querylog
import ( import (
"fmt" "fmt"
"log/slog"
"net" "net"
"path/filepath" "path/filepath"
"sync" "sync"
@ -12,20 +13,19 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/service"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// QueryLog - main interface // QueryLog is the query log interface for use by other packages.
type QueryLog interface { type QueryLog interface {
Start() // Interface starts and stops the query log.
service.Interface
// Close query log object // Add adds a log entry.
Close()
// Add a log entry
Add(params *AddParams) Add(params *AddParams)
// WriteDiskConfig - write configuration // WriteDiskConfig writes the query log configuration to c.
WriteDiskConfig(c *Config) WriteDiskConfig(c *Config)
// ShouldLog returns true if request for the host should be logged. // ShouldLog returns true if request for the host should be logged.
@ -36,6 +36,10 @@ type QueryLog interface {
// //
// Do not alter any fields of this structure after using it. // Do not alter any fields of this structure after using it.
type Config struct { type Config struct {
// Logger is used for logging the operation of the query log. It must not
// be nil.
Logger *slog.Logger
// Ignored contains the list of host names, which should not be written to // Ignored contains the list of host names, which should not be written to
// log, and matches them. // log, and matches them.
Ignored *aghnet.IgnoreEngine Ignored *aghnet.IgnoreEngine
@ -151,6 +155,7 @@ func newQueryLog(conf Config) (l *queryLog, err error) {
} }
l = &queryLog{ l = &queryLog{
logger: conf.Logger,
findClient: findClient, findClient: findClient,
buffer: container.NewRingBuffer[*logEntry](memSize), buffer: container.NewRingBuffer[*logEntry](memSize),

View File

@ -2,6 +2,7 @@ package querylog
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
@ -9,28 +10,30 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/c2h5oh/datasize"
) )
// flushLogBuffer flushes the current buffer to file and resets the current // flushLogBuffer flushes the current buffer to file and resets the current
// buffer. // buffer.
func (l *queryLog) flushLogBuffer() (err error) { func (l *queryLog) flushLogBuffer(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }() defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }()
l.fileFlushLock.Lock() l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock() defer l.fileFlushLock.Unlock()
b, err := l.encodeEntries() b, err := l.encodeEntries(ctx)
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. // Don't wrap the error since it's informative enough as is.
return err return err
} }
return l.flushToFile(b) return l.flushToFile(ctx, b)
} }
// encodeEntries returns JSON encoded log entries, logs estimated time, clears // encodeEntries returns JSON encoded log entries, logs estimated time, clears
// the log buffer. // the log buffer.
func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) { func (l *queryLog) encodeEntries(ctx context.Context) (b *bytes.Buffer, err error) {
l.bufferLock.Lock() l.bufferLock.Lock()
defer l.bufferLock.Unlock() defer l.bufferLock.Unlock()
@ -55,8 +58,17 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
return nil, err return nil, err
} }
size := b.Len()
elapsed := time.Since(start) elapsed := time.Since(start)
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", bufLen, elapsed, b.Len()/1024, float64(b.Len())/float64(bufLen), elapsed/time.Duration(bufLen)) l.logger.DebugContext(
ctx,
"serialized elements via json",
"count", bufLen,
"elapsed", elapsed,
"size", datasize.ByteSize(size),
"size_per_entry", datasize.ByteSize(float64(size)/float64(bufLen)),
"time_per_entry", elapsed/time.Duration(bufLen),
)
l.buffer.Clear() l.buffer.Clear()
l.flushPending = false l.flushPending = false
@ -65,7 +77,7 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
} }
// flushToFile saves the encoded log entries to the query log file. // flushToFile saves the encoded log entries to the query log file.
func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) { func (l *queryLog) flushToFile(ctx context.Context, b *bytes.Buffer) (err error) {
l.fileWriteLock.Lock() l.fileWriteLock.Lock()
defer l.fileWriteLock.Unlock() defer l.fileWriteLock.Unlock()
@ -83,19 +95,19 @@ func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
return fmt.Errorf("writing to file %q: %w", filename, err) return fmt.Errorf("writing to file %q: %w", filename, err)
} }
log.Debug("querylog: ok %q: %v bytes written", filename, n) l.logger.DebugContext(ctx, "flushed to file", "file", filename, "size", datasize.ByteSize(n))
return nil return nil
} }
func (l *queryLog) rotate() error { func (l *queryLog) rotate(ctx context.Context) error {
from := l.logFile from := l.logFile
to := l.logFile + ".1" to := l.logFile + ".1"
err := os.Rename(from, to) err := os.Rename(from, to)
if err != nil { if err != nil {
if errors.Is(err, os.ErrNotExist) { if errors.Is(err, os.ErrNotExist) {
log.Debug("querylog: no log to rotate") l.logger.DebugContext(ctx, "no log to rotate")
return nil return nil
} }
@ -103,12 +115,12 @@ func (l *queryLog) rotate() error {
return fmt.Errorf("failed to rename old file: %w", err) return fmt.Errorf("failed to rename old file: %w", err)
} }
log.Debug("querylog: renamed %s into %s", from, to) l.logger.DebugContext(ctx, "renamed log file", "from", from, "to", to)
return nil return nil
} }
func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) { func (l *queryLog) readFileFirstTimeValue(ctx context.Context) (first time.Time, err error) {
var f *os.File var f *os.File
f, err = os.Open(l.logFile) f, err = os.Open(l.logFile)
if err != nil { if err != nil {
@ -130,15 +142,15 @@ func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) {
return time.Time{}, err return time.Time{}, err
} }
log.Debug("querylog: the oldest log entry: %s", val) l.logger.DebugContext(ctx, "oldest log entry", "entry_time", val)
return t, nil return t, nil
} }
func (l *queryLog) periodicRotate() { func (l *queryLog) periodicRotate(ctx context.Context) {
defer log.OnPanic("querylog: rotating") defer slogutil.RecoverAndLog(ctx, l.logger)
l.checkAndRotate() l.checkAndRotate(ctx)
// rotationCheckIvl is the period of time between checking the need for // rotationCheckIvl is the period of time between checking the need for
// rotating log files. It's smaller of any available rotation interval to // rotating log files. It's smaller of any available rotation interval to
@ -151,13 +163,13 @@ func (l *queryLog) periodicRotate() {
defer rotations.Stop() defer rotations.Stop()
for range rotations.C { for range rotations.C {
l.checkAndRotate() l.checkAndRotate(ctx)
} }
} }
// checkAndRotate rotates log files if those are older than the specified // checkAndRotate rotates log files if those are older than the specified
// rotation interval. // rotation interval.
func (l *queryLog) checkAndRotate() { func (l *queryLog) checkAndRotate(ctx context.Context) {
var rotationIvl time.Duration var rotationIvl time.Duration
func() { func() {
l.confMu.RLock() l.confMu.RLock()
@ -166,29 +178,30 @@ func (l *queryLog) checkAndRotate() {
rotationIvl = l.conf.RotationIvl rotationIvl = l.conf.RotationIvl
}() }()
oldest, err := l.readFileFirstTimeValue() oldest, err := l.readFileFirstTimeValue(ctx)
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("querylog: reading oldest record for rotation: %s", err) l.logger.ErrorContext(ctx, "reading oldest record for rotation", slogutil.KeyError, err)
return return
} }
if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) { if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) {
log.Debug( l.logger.DebugContext(
"querylog: %s <= %s, not rotating", ctx,
now.Format(time.RFC3339), "not rotating",
rotTime.Format(time.RFC3339), "now", now.Format(time.RFC3339),
"rotate_time", rotTime.Format(time.RFC3339),
) )
return return
} }
err = l.rotate() err = l.rotate(ctx)
if err != nil { if err != nil {
log.Error("querylog: rotating: %s", err) l.logger.ErrorContext(ctx, "rotating", slogutil.KeyError, err)
return return
} }
log.Debug("querylog: rotated successfully") l.logger.DebugContext(ctx, "rotated successfully")
} }

View File

@ -1,13 +1,15 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log/slog"
"slices" "slices"
"time" "time"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil"
) )
// client finds the client info, if any, by its ClientID and IP address, // client finds the client info, if any, by its ClientID and IP address,
@ -48,7 +50,11 @@ func (l *queryLog) client(clientID, ip string, cache clientCache) (c *Client, er
// buffer. It optionally uses the client cache, if provided. It also returns // buffer. It optionally uses the client cache, if provided. It also returns
// the total amount of records in the buffer at the moment of searching. // the total amount of records in the buffer at the moment of searching.
// l.confMu is expected to be locked. // l.confMu is expected to be locked.
func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entries []*logEntry, total int) { func (l *queryLog) searchMemory(
ctx context.Context,
params *searchParams,
cache clientCache,
) (entries []*logEntry, total int) {
// Check memory size, as the buffer can contain a single log record. See // Check memory size, as the buffer can contain a single log record. See
// [newQueryLog]. // [newQueryLog].
if l.conf.MemSize == 0 { if l.conf.MemSize == 0 {
@ -66,9 +72,14 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
var err error var err error
e.client, err = l.client(e.ClientID, e.IP.String(), cache) e.client, err = l.client(e.ClientID, e.IP.String(), cache)
if err != nil { if err != nil {
msg := "querylog: enriching memory record at time %s" + l.logger.ErrorContext(
" for client %q (clientid %q): %s" ctx,
log.Error(msg, e.Time, e.IP, e.ClientID, err) "enriching memory record",
"at", e.Time,
"client_ip", e.IP,
"client_id", e.ClientID,
slogutil.KeyError, err,
)
// Go on and try to match anyway. // Go on and try to match anyway.
} }
@ -86,7 +97,10 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
// search searches log entries in memory buffer and log file using specified // search searches log entries in memory buffer and log file using specified
// parameters and returns the list of entries found and the time of the oldest // parameters and returns the list of entries found and the time of the oldest
// entry. l.confMu is expected to be locked. // entry. l.confMu is expected to be locked.
func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest time.Time) { func (l *queryLog) search(
ctx context.Context,
params *searchParams,
) (entries []*logEntry, oldest time.Time) {
start := time.Now() start := time.Now()
if params.limit == 0 { if params.limit == 0 {
@ -95,11 +109,11 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
cache := clientCache{} cache := clientCache{}
memoryEntries, bufLen := l.searchMemory(params, cache) memoryEntries, bufLen := l.searchMemory(ctx, params, cache)
log.Debug("querylog: got %d entries from memory", len(memoryEntries)) l.logger.DebugContext(ctx, "got entries from memory", "count", len(memoryEntries))
fileEntries, oldest, total := l.searchFiles(params, cache) fileEntries, oldest, total := l.searchFiles(ctx, params, cache)
log.Debug("querylog: got %d entries from files", len(fileEntries)) l.logger.DebugContext(ctx, "got entries from files", "count", len(fileEntries))
total += bufLen total += bufLen
@ -134,12 +148,13 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
oldest = entries[len(entries)-1].Time oldest = entries[len(entries)-1].Time
} }
log.Debug( l.logger.DebugContext(
"querylog: prepared data (%d/%d) older than %s in %s", ctx,
len(entries), "prepared data",
total, "count", len(entries),
params.olderThan, "total", total,
time.Since(start), "older_than", params.olderThan,
"elapsed", time.Since(start),
) )
return entries, oldest return entries, oldest
@ -147,12 +162,12 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
// seekRecord changes the current position to the next record older than the // seekRecord changes the current position to the next record older than the
// provided parameter. // provided parameter.
func (r *qLogReader) seekRecord(olderThan time.Time) (err error) { func (r *qLogReader) seekRecord(ctx context.Context, olderThan time.Time) (err error) {
if olderThan.IsZero() { if olderThan.IsZero() {
return r.SeekStart() return r.SeekStart()
} }
err = r.seekTS(olderThan.UnixNano()) err = r.seekTS(ctx, olderThan.UnixNano())
if err == nil { if err == nil {
// Read to the next record, because we only need the one that goes // Read to the next record, because we only need the one that goes
// after it. // after it.
@ -164,21 +179,24 @@ func (r *qLogReader) seekRecord(olderThan time.Time) (err error) {
// setQLogReader creates a reader with the specified files and sets the // setQLogReader creates a reader with the specified files and sets the
// position to the next record older than the provided parameter. // position to the next record older than the provided parameter.
func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error) { func (l *queryLog) setQLogReader(
ctx context.Context,
olderThan time.Time,
) (qr *qLogReader, err error) {
files := []string{ files := []string{
l.logFile + ".1", l.logFile + ".1",
l.logFile, l.logFile,
} }
r, err := newQLogReader(files) r, err := newQLogReader(ctx, l.logger, files)
if err != nil { if err != nil {
return nil, fmt.Errorf("opening qlog reader: %s", err) return nil, fmt.Errorf("opening qlog reader: %w", err)
} }
err = r.seekRecord(olderThan) err = r.seekRecord(ctx, olderThan)
if err != nil { if err != nil {
defer func() { err = errors.WithDeferred(err, r.Close()) }() defer func() { err = errors.WithDeferred(err, r.Close()) }()
log.Debug("querylog: cannot seek to %s: %s", olderThan, err) l.logger.DebugContext(ctx, "cannot seek", "older_than", olderThan, slogutil.KeyError, err)
return nil, nil return nil, nil
} }
@ -191,13 +209,14 @@ func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error
// calls faster so that the UI could handle it and show something quicker. // calls faster so that the UI could handle it and show something quicker.
// This behavior can be overridden if maxFileScanEntries is set to 0. // This behavior can be overridden if maxFileScanEntries is set to 0.
func (l *queryLog) readEntries( func (l *queryLog) readEntries(
ctx context.Context,
r *qLogReader, r *qLogReader,
params *searchParams, params *searchParams,
cache clientCache, cache clientCache,
totalLimit int, totalLimit int,
) (entries []*logEntry, oldestNano int64, total int) { ) (entries []*logEntry, oldestNano int64, total int) {
for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 { for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 {
ent, ts, rErr := l.readNextEntry(r, params, cache) ent, ts, rErr := l.readNextEntry(ctx, r, params, cache)
if rErr != nil { if rErr != nil {
if rErr == io.EOF { if rErr == io.EOF {
oldestNano = 0 oldestNano = 0
@ -205,7 +224,7 @@ func (l *queryLog) readEntries(
break break
} }
log.Error("querylog: reading next entry: %s", rErr) l.logger.ErrorContext(ctx, "reading next entry", slogutil.KeyError, rErr)
} }
oldestNano = ts oldestNano = ts
@ -231,12 +250,13 @@ func (l *queryLog) readEntries(
// and the total number of processed entries, including discarded ones, // and the total number of processed entries, including discarded ones,
// correspondingly. // correspondingly.
func (l *queryLog) searchFiles( func (l *queryLog) searchFiles(
ctx context.Context,
params *searchParams, params *searchParams,
cache clientCache, cache clientCache,
) (entries []*logEntry, oldest time.Time, total int) { ) (entries []*logEntry, oldest time.Time, total int) {
r, err := l.setQLogReader(params.olderThan) r, err := l.setQLogReader(ctx, params.olderThan)
if err != nil { if err != nil {
log.Error("querylog: %s", err) l.logger.ErrorContext(ctx, "searching files", slogutil.KeyError, err)
} }
if r == nil { if r == nil {
@ -245,12 +265,12 @@ func (l *queryLog) searchFiles(
defer func() { defer func() {
if closeErr := r.Close(); closeErr != nil { if closeErr := r.Close(); closeErr != nil {
log.Error("querylog: closing file: %s", closeErr) l.logger.ErrorContext(ctx, "closing files", slogutil.KeyError, closeErr)
} }
}() }()
totalLimit := params.offset + params.limit totalLimit := params.offset + params.limit
entries, oldestNano, total := l.readEntries(r, params, cache, totalLimit) entries, oldestNano, total := l.readEntries(ctx, r, params, cache, totalLimit)
if oldestNano != 0 { if oldestNano != 0 {
oldest = time.Unix(0, oldestNano) oldest = time.Unix(0, oldestNano)
} }
@ -266,15 +286,21 @@ type quickMatchClientFinder struct {
} }
// findClient is a method that can be used as a quickMatchClientFinder. // findClient is a method that can be used as a quickMatchClientFinder.
func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) { func (f quickMatchClientFinder) findClient(
ctx context.Context,
logger *slog.Logger,
clientID string,
ip string,
) (c *Client) {
var err error var err error
c, err = f.client(clientID, ip, f.cache) c, err = f.client(clientID, ip, f.cache)
if err != nil { if err != nil {
log.Error( logger.ErrorContext(
"querylog: enriching file record for quick search: for client %q (clientid %q): %s", ctx,
ip, "enriching file record for quick search",
clientID, "client_ip", ip,
err, "client_id", clientID,
slogutil.KeyError, err,
) )
} }
@ -286,6 +312,7 @@ func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) {
// the entry doesn't match the search criteria. ts is the timestamp of the // the entry doesn't match the search criteria. ts is the timestamp of the
// processed entry. // processed entry.
func (l *queryLog) readNextEntry( func (l *queryLog) readNextEntry(
ctx context.Context,
r *qLogReader, r *qLogReader,
params *searchParams, params *searchParams,
cache clientCache, cache clientCache,
@ -301,14 +328,14 @@ func (l *queryLog) readNextEntry(
cache: cache, cache: cache,
} }
if !params.quickMatch(line, clientFinder.findClient) { if !params.quickMatch(ctx, l.logger, line, clientFinder.findClient) {
ts = readQLogTimestamp(line) ts = readQLogTimestamp(ctx, l.logger, line)
return nil, ts, nil return nil, ts, nil
} }
e = &logEntry{} e = &logEntry{}
decodeLogEntry(e, line) l.decodeLogEntry(ctx, e, line)
if l.isIgnored(e.QHost) { if l.isIgnored(e.QHost) {
return nil, ts, nil return nil, ts, nil
@ -316,12 +343,13 @@ func (l *queryLog) readNextEntry(
e.client, err = l.client(e.ClientID, e.IP.String(), cache) e.client, err = l.client(e.ClientID, e.IP.String(), cache)
if err != nil { if err != nil {
log.Error( l.logger.ErrorContext(
"querylog: enriching file record at time %s for client %q (clientid %q): %s", ctx,
e.Time, "enriching file record",
e.IP, "at", e.Time,
e.ClientID, "client_ip", e.IP,
err, "client_id", e.ClientID,
slogutil.KeyError, err,
) )
// Go on and try to match anyway. // Go on and try to match anyway.

View File

@ -5,6 +5,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -36,6 +38,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
} }
l, err := newQueryLog(Config{ l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
FindClient: findClient, FindClient: findClient,
BaseDir: t.TempDir(), BaseDir: t.TempDir(),
RotationIvl: timeutil.Day, RotationIvl: timeutil.Day,
@ -45,7 +48,11 @@ func TestQueryLog_Search_findClient(t *testing.T) {
AnonymizeClientIP: false, AnonymizeClientIP: false,
}) })
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(l.Close)
ctx := testutil.ContextWithTimeout(t, testTimeout)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return l.Shutdown(ctx)
})
q := &dns.Msg{ q := &dns.Msg{
Question: []dns.Question{{ Question: []dns.Question{{
@ -81,7 +88,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
olderThan: time.Now().Add(10 * time.Second), olderThan: time.Now().Add(10 * time.Second),
limit: 3, limit: 3,
} }
entries, _ := l.search(sp) entries, _ := l.search(ctx, sp)
assert.Equal(t, 2, findClientCalls) assert.Equal(t, 2, findClientCalls)
require.Len(t, entries, 3) require.Len(t, entries, 3)

View File

@ -1,7 +1,9 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"log/slog"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -87,7 +89,12 @@ func ctDomainOrClientCaseNonStrict(
// quickMatch quickly checks if the line matches the given search criterion. // quickMatch quickly checks if the line matches the given search criterion.
// It returns false if the like doesn't match. This method is only here for // It returns false if the like doesn't match. This method is only here for
// optimization purposes. // optimization purposes.
func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) { func (c *searchCriterion) quickMatch(
ctx context.Context,
logger *slog.Logger,
line string,
findClient quickMatchClientFunc,
) (ok bool) {
switch c.criterionType { switch c.criterionType {
case ctTerm: case ctTerm:
host := readJSONValue(line, `"QH":"`) host := readJSONValue(line, `"QH":"`)
@ -95,7 +102,7 @@ func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFun
clientID := readJSONValue(line, `"CID":"`) clientID := readJSONValue(line, `"CID":"`)
var name string var name string
if cli := findClient(clientID, ip); cli != nil { if cli := findClient(ctx, logger, clientID, ip); cli != nil {
name = cli.Name name = cli.Name
} }

View File

@ -1,6 +1,10 @@
package querylog package querylog
import "time" import (
"context"
"log/slog"
"time"
)
// searchParams represent the search query sent by the client. // searchParams represent the search query sent by the client.
type searchParams struct { type searchParams struct {
@ -35,14 +39,23 @@ func newSearchParams() *searchParams {
} }
// quickMatchClientFunc is a simplified client finder for quick matches. // quickMatchClientFunc is a simplified client finder for quick matches.
type quickMatchClientFunc = func(clientID, ip string) (c *Client) type quickMatchClientFunc = func(
ctx context.Context,
logger *slog.Logger,
clientID, ip string,
) (c *Client)
// quickMatch quickly checks if the line matches the given search parameters. // quickMatch quickly checks if the line matches the given search parameters.
// It returns false if the line doesn't match. This method is only here for // It returns false if the line doesn't match. This method is only here for
// optimization purposes. // optimization purposes.
func (s *searchParams) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) { func (s *searchParams) quickMatch(
ctx context.Context,
logger *slog.Logger,
line string,
findClient quickMatchClientFunc,
) (ok bool) {
for _, c := range s.searchCriteria { for _, c := range s.searchCriteria {
if !c.quickMatch(line, findClient) { if !c.quickMatch(ctx, logger, line, findClient) {
return false return false
} }
} }