Merge branch 'master' into AGDNS-2556-custom-update-url

commit c9efb412e7

@@ -32,6 +32,14 @@ NOTE: Add new changes BELOW THIS COMMENT.
 
 - The release executables are now signed.
 
+### Added
+
+- The `--no-permcheck` command-line option to disable checking and migration of
+  permissions for the security-sensitive files and directories, which caused
+  issues on Windows ([#7400]).
+
+[#7400]: https://github.com/AdguardTeam/AdGuardHome/issues/7400
+
 [go-1.23.3]: https://groups.google.com/g/golang-announce/c/X5KodEJYuqI
 
 <!--
@@ -15,7 +15,6 @@ import (
     "github.com/AdguardTeam/AdGuardHome/internal/whois"
     "github.com/AdguardTeam/golibs/errors"
     "github.com/AdguardTeam/golibs/hostsfile"
-    "github.com/AdguardTeam/golibs/log"
     "github.com/AdguardTeam/golibs/logutil/slogutil"
 )
 

@@ -506,7 +505,7 @@ func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
 
 // RemoveByName removes persistent client information. ok is false if no such
 // client exists by that name.
-func (s *Storage) RemoveByName(name string) (ok bool) {
+func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
     s.mu.Lock()
     defer s.mu.Unlock()
 

@@ -516,7 +515,7 @@ func (s *Storage) RemoveByName(name string) (ok bool) {
     }
 
     if err := p.CloseUpstreams(); err != nil {
-        log.Error("client storage: removing client %q: %s", p.Name, err)
+        s.logger.ErrorContext(ctx, "removing client", "name", p.Name, slogutil.KeyError, err)
     }
 
     s.index.remove(p)
@@ -735,7 +735,7 @@ func TestStorage_RemoveByName(t *testing.T) {
 
     for _, tc := range testCases {
         t.Run(tc.name, func(t *testing.T) {
-            tc.want(t, s.RemoveByName(tc.cliName))
+            tc.want(t, s.RemoveByName(ctx, tc.cliName))
         })
     }
 

@@ -744,8 +744,8 @@ func TestStorage_RemoveByName(t *testing.T) {
         err = s.Add(ctx, existingClient)
         require.NoError(t, err)
 
-        assert.True(t, s.RemoveByName(existingName))
-        assert.False(t, s.RemoveByName(existingName))
+        assert.True(t, s.RemoveByName(ctx, existingName))
+        assert.False(t, s.RemoveByName(ctx, existingName))
     })
 }
 
@@ -369,7 +369,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
         return
     }
 
-    if !clients.storage.RemoveByName(cj.Name) {
+    if !clients.storage.RemoveByName(r.Context(), cj.Name) {
         aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
 
         return
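The change above threads the request's context from the HTTP handler into the client storage so that removal can be logged with `slog`'s context-aware methods. Below is a minimal sketch of that pattern under stated assumptions: the `Storage` type, the route, and the port are illustrative, not the project's actual API, and the plain `"err"` key stands in for `slogutil.KeyError`.

```go
package main

import (
	"context"
	"log/slog"
	"net/http"
	"os"
	"sync"
)

// Storage is a hypothetical client store; mutating methods take a context so
// that logging can carry it, mirroring RemoveByName above.
type Storage struct {
	logger *slog.Logger
	mu     sync.Mutex
	byName map[string]struct{}
}

// RemoveByName deletes the named client. ok is false if no such client exists.
func (s *Storage) RemoveByName(ctx context.Context, name string) (ok bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if _, ok = s.byName[name]; !ok {
		s.logger.DebugContext(ctx, "client not found", "name", name)

		return false
	}

	delete(s.byName, name)

	return true
}

func main() {
	s := &Storage{
		logger: slog.New(slog.NewTextHandler(os.Stderr, nil)),
		byName: map[string]struct{}{"laptop": {}},
	}

	http.HandleFunc("/control/clients/delete", func(w http.ResponseWriter, r *http.Request) {
		// Pass the request's context, as handleDelClient does in the diff;
		// http.Error stands in for the project's aghhttp.Error helper.
		if !s.RemoveByName(r.Context(), r.URL.Query().Get("name")) {
			http.Error(w, "Client not found", http.StatusBadRequest)
		}
	})

	_ = http.ListenAndServe("localhost:8080", nil)
}
```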
@@ -48,11 +48,11 @@ func onConfigModified() {
 // initDNS updates all the fields of the [Context] needed to initialize the DNS
 // server and initializes it at last. It also must not be called unless
 // [config] and [Context] are initialized. l must not be nil.
-func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
+func initDNS(baseLogger *slog.Logger, statsDir, querylogDir string) (err error) {
     anonymizer := config.anonymizer()
 
     statsConf := stats.Config{
-        Logger: l.With(slogutil.KeyPrefix, "stats"),
+        Logger: baseLogger.With(slogutil.KeyPrefix, "stats"),
         Filename: filepath.Join(statsDir, "stats.db"),
         Limit: config.Stats.Interval.Duration,
         ConfigModified: onConfigModified,

@@ -73,6 +73,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
     }
 
     conf := querylog.Config{
+        Logger: baseLogger.With(slogutil.KeyPrefix, "querylog"),
         Anonymizer: anonymizer,
         ConfigModified: onConfigModified,
         HTTPRegister: httpRegister,

@@ -113,7 +114,7 @@ func initDNS(l *slog.Logger, statsDir, querylogDir string) (err error) {
         anonymizer,
         httpRegister,
         tlsConf,
-        l,
+        baseLogger,
     )
 }
 

@@ -457,7 +458,8 @@ func startDNSServer() error {
     Context.filters.EnableFilters(false)
 
     // TODO(s.chzhen): Pass context.
-    err := Context.clients.Start(context.TODO())
+    ctx := context.TODO()
+    err := Context.clients.Start(ctx)
     if err != nil {
         return fmt.Errorf("starting clients container: %w", err)
     }

@@ -469,7 +471,11 @@ func startDNSServer() error {
 
     Context.filters.Start()
     Context.stats.Start()
-    Context.queryLog.Start()
+
+    err = Context.queryLog.Start(ctx)
+    if err != nil {
+        return fmt.Errorf("starting query log: %w", err)
+    }
 
     return nil
 }

@@ -525,12 +531,16 @@ func closeDNSServer() {
     if Context.stats != nil {
         err := Context.stats.Close()
         if err != nil {
-            log.Debug("closing stats: %s", err)
+            log.Error("closing stats: %s", err)
         }
     }
 
     if Context.queryLog != nil {
-        Context.queryLog.Close()
+        // TODO(s.chzhen): Pass context.
+        err := Context.queryLog.Shutdown(context.TODO())
+        if err != nil {
+            log.Error("closing query log: %s", err)
+        }
     }
 
     log.Debug("all dns modules are closed")
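The query log moves from a `Start()`/`Close()` pair to a `Start(ctx)`/`Shutdown(ctx)` lifecycle that can report errors. A small sketch of that contract follows; the `Service` interface, the stub, and the timeout are illustrative assumptions, not AdGuard Home's actual types.

```go
package main

import (
	"context"
	"fmt"
	"log/slog"
	"time"
)

// Service mirrors the lifecycle adopted above: starting and shutting down both
// take a context and may return an error.
type Service interface {
	Start(ctx context.Context) (err error)
	Shutdown(ctx context.Context) (err error)
}

// queryLogStub is a stand-in for the real query log.
type queryLogStub struct{ logger *slog.Logger }

func (l *queryLogStub) Start(ctx context.Context) (err error) {
	l.logger.DebugContext(ctx, "query log started")

	return nil
}

func (l *queryLogStub) Shutdown(ctx context.Context) (err error) {
	l.logger.DebugContext(ctx, "query log flushed and closed")

	return nil
}

func main() {
	var svc Service = &queryLogStub{logger: slog.Default()}

	// The diff still starts with context.TODO; a real caller could eventually
	// pass a cancellable context instead.
	if err := svc.Start(context.TODO()); err != nil {
		panic(fmt.Errorf("starting query log: %w", err))
	}

	// On shutdown, bound the final flush with a timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	if err := svc.Shutdown(ctx); err != nil {
		fmt.Println("closing query log:", err)
	}
}
```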
@@ -159,7 +159,7 @@ func setupContext(opts options) (err error) {
 
     if Context.firstRun {
         log.Info("This is the first time AdGuard Home is launched")
-        checkPermissions()
+        checkNetworkPermissions()
 
         return nil
     }

@@ -666,12 +666,10 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
         }
     }
 
-    if permcheck.NeedsMigration(confPath) {
-        permcheck.Migrate(Context.workDir, dataDir, statsDir, querylogDir, confPath)
+    if !opts.noPermCheck {
+        checkPermissions(Context.workDir, confPath, dataDir, statsDir, querylogDir)
     }
 
-    permcheck.Check(Context.workDir, dataDir, statsDir, querylogDir, confPath)
-
     Context.web.start()
 
     // Wait for other goroutines to complete their job.

@@ -740,6 +738,16 @@ func newUpdater(
     })
 }
 
+// checkPermissions checks and migrates permissions of the files and directories
+// used by AdGuard Home, if needed.
+func checkPermissions(workDir, confPath, dataDir, statsDir, querylogDir string) {
+    if permcheck.NeedsMigration(confPath) {
+        permcheck.Migrate(workDir, dataDir, statsDir, querylogDir, confPath)
+    }
+
+    permcheck.Check(workDir, dataDir, statsDir, querylogDir, confPath)
+}
+
 // initUsers initializes context auth module. Clears config users field.
 func initUsers() (auth *Auth, err error) {
     sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")

@@ -799,8 +807,9 @@ func startMods(l *slog.Logger) (err error) {
     return nil
 }
 
-// Check if the current user permissions are enough to run AdGuard Home
-func checkPermissions() {
+// checkNetworkPermissions checks if the current user permissions are enough to
+// use the required networking functionality.
+func checkNetworkPermissions() {
     log.Info("Checking if AdGuard Home has necessary permissions")
 
     if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {
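The extracted `checkPermissions` helper keeps the migrate-then-check sequence behind a single optional guard. The sketch below shows the same shape with a hypothetical `checkPerms` standing in for the real `permcheck` package; the file names and the warning logic are assumptions for illustration only.

```go
package main

import (
	"fmt"
	"os"
)

// checkPerms is a hypothetical stand-in for permcheck.Check: it warns about
// files that are readable or writable by users other than the owner.
func checkPerms(paths ...string) {
	for _, p := range paths {
		fi, err := os.Stat(p)
		if err != nil {
			fmt.Printf("checking %q: %s\n", p, err)

			continue
		}

		if fi.Mode().Perm()&0o077 != 0 {
			fmt.Printf("%q is accessible by other users: %v\n", p, fi.Mode().Perm())
		}
	}
}

func main() {
	// noPermCheck mirrors the --no-permcheck option: when set, the check and
	// migration are skipped entirely, which matters on Windows where POSIX
	// permission bits are not meaningful.
	noPermCheck := false

	if !noPermCheck {
		checkPerms("AdGuardHome.yaml", "data")
	}
}
```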
@@ -78,6 +78,10 @@ type options struct {
     // localFrontend forces AdGuard Home to use the frontend files from disk
     // rather than the ones that have been compiled into the binary.
     localFrontend bool
+
+    // noPermCheck disables checking and migration of permissions for the
+    // security-sensitive files.
+    noPermCheck bool
 }
 
 // initCmdLineOpts completes initialization of the global command-line option

@@ -305,6 +309,15 @@ var cmdLineOpts = []cmdLineOpt{{
     description: "Run in GL-Inet compatibility mode.",
     longName: "glinet",
     shortName: "",
 }, {
+    updateWithValue: nil,
+    updateNoValue: func(o options) (options, error) { o.noPermCheck = true; return o, nil },
+    effect: nil,
+    serialize: func(o options) (val string, ok bool) { return "", o.noPermCheck },
+    description: "Skip checking and migration of permissions " +
+        "of security-sensitive files.",
+    longName: "no-permcheck",
+    shortName: "",
+}, {
     updateWithValue: nil,
     updateNoValue: nil,
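The new entry reuses the existing functional-updater pattern: a flag without a value flips a boolean on the options struct, and `serialize` reports it back only when set. A self-contained sketch of that pattern is below; the `option` type and the tiny parser are illustrative, not the project's real `cmdLineOpt` machinery.

```go
package main

import "fmt"

type options struct {
	noPermCheck bool
	verbose     bool
}

// option describes one no-value command-line flag in the same functional style
// as the diff above: updating returns a new options value instead of mutating.
type option struct {
	longName      string
	updateNoValue func(o options) (options, error)
}

var opts = []option{{
	longName:      "no-permcheck",
	updateNoValue: func(o options) (options, error) { o.noPermCheck = true; return o, nil },
}, {
	longName:      "verbose",
	updateNoValue: func(o options) (options, error) { o.verbose = true; return o, nil },
}}

// parse applies the matching updater for every "--name" argument.
func parse(args []string) (o options, err error) {
	for _, arg := range args {
		for _, opt := range opts {
			if arg == "--"+opt.longName {
				if o, err = opt.updateNoValue(o); err != nil {
					return options{}, err
				}
			}
		}
	}

	return o, nil
}

func main() {
	o, err := parse([]string{"--no-permcheck"})
	fmt.Println(o.noPermCheck, err)
}
```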
@@ -89,6 +89,12 @@ type options struct {
     // TODO(a.garipov): Use.
     performUpdate bool
 
+    // noPermCheck, if true, instructs AdGuard Home to skip checking and
+    // migrating the permissions of its security-sensitive files.
+    //
+    // TODO(e.burkov): Use.
+    noPermCheck bool
+
     // verbose, if true, instructs AdGuard Home to enable verbose logging.
     verbose bool
 

@@ -110,7 +116,8 @@ const (
     disableUpdateIdx
     glinetModeIdx
     helpIdx
-    localFrontend
+    localFrontendIdx
+    noPermCheckIdx
     performUpdateIdx
     verboseIdx
     versionIdx

@@ -214,7 +221,7 @@ var commandLineOptions = []*commandLineOption{
     valueType: "",
     },
 
-    localFrontend: {
+    localFrontendIdx: {
     defaultValue: false,
     description: "Use local frontend directories.",
     long: "local-frontend",

@@ -222,6 +229,14 @@ var commandLineOptions = []*commandLineOption{
     valueType: "",
     },
 
+    noPermCheckIdx: {
+    defaultValue: false,
+    description: "Skip checking the permissions of security-sensitive files.",
+    long: "no-permcheck",
+    short: "",
+    valueType: "",
+    },
+
     performUpdateIdx: {
     defaultValue: false,
     description: "Update the current binary and restart the service in case it's installed.",

@@ -264,7 +279,8 @@ func parseOptions(cmdName string, args []string) (opts *options, err error) {
     disableUpdateIdx: &opts.disableUpdate,
     glinetModeIdx: &opts.glinetMode,
     helpIdx: &opts.help,
-    localFrontend: &opts.localFrontend,
+    localFrontendIdx: &opts.localFrontend,
+    noPermCheckIdx: &opts.noPermCheck,
     performUpdateIdx: &opts.performUpdate,
     verboseIdx: &opts.verbose,
     versionIdx: &opts.version,
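The `…Idx` constants are iota indices into `commandLineOptions`, which is why renaming `localFrontend` to `localFrontendIdx` and inserting `noPermCheckIdx` must keep the constant order and the slice order in sync. A tiny illustration of that invariant, with simplified, hypothetical names:

```go
package main

import "fmt"

// Indexes into the descriptors slice; the order must match the slice below.
const (
	helpIdx = iota
	localFrontendIdx
	noPermCheckIdx
	verboseIdx
)

// Keyed slice elements make the index-to-descriptor mapping explicit, so a
// reordered constant surfaces as a misplaced element.
var descriptors = []string{
	helpIdx:          "help",
	localFrontendIdx: "local-frontend",
	noPermCheckIdx:   "no-permcheck",
	verboseIdx:       "verbose",
}

func main() {
	fmt.Println(descriptors[noPermCheckIdx]) // no-permcheck
}
```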
@@ -1,6 +1,7 @@
 package querylog
 
 import (
+    "context"
     "encoding/base64"
     "encoding/json"
     "fmt"

@@ -13,7 +14,7 @@ import (
     "github.com/AdguardTeam/AdGuardHome/internal/filtering"
     "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
     "github.com/AdguardTeam/golibs/errors"
-    "github.com/AdguardTeam/golibs/log"
+    "github.com/AdguardTeam/golibs/logutil/slogutil"
     "github.com/AdguardTeam/urlfilter/rules"
     "github.com/miekg/dns"
 )

@@ -174,26 +175,32 @@ var logEntryHandlers = map[string]logEntryHandler{
 }
 
 // decodeResultRuleKey decodes the token of "Rules" type to logEntry struct.
-func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
+func (l *queryLog) decodeResultRuleKey(
+    ctx context.Context,
+    key string,
+    i int,
+    dec *json.Decoder,
+    ent *logEntry,
+) {
     var vToken json.Token
     switch key {
     case "FilterListID":
-        ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
+        ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
         if n, ok := vToken.(json.Number); ok {
             id, _ := n.Int64()
             ent.Result.Rules[i].FilterListID = rulelist.URLFilterID(id)
         }
     case "IP":
-        ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
+        ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
         if ipStr, ok := vToken.(string); ok {
             if ip, err := netip.ParseAddr(ipStr); err == nil {
                 ent.Result.Rules[i].IP = ip
             } else {
-                log.Debug("querylog: decoding ipStr value: %s", err)
+                l.logger.DebugContext(ctx, "decoding ip", "value", ipStr, slogutil.KeyError, err)
             }
         }
     case "Text":
-        ent.Result.Rules, vToken = decodeVTokenAndAddRule(key, i, dec, ent.Result.Rules)
+        ent.Result.Rules, vToken = l.decodeVTokenAndAddRule(ctx, key, i, dec, ent.Result.Rules)
         if s, ok := vToken.(string); ok {
             ent.Result.Rules[i].Text = s
         }

@@ -204,7 +211,8 @@ func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
 
 // decodeVTokenAndAddRule decodes the "Rules" toke as [filtering.ResultRule]
 // and then adds the decoded object to the slice of result rules.
-func decodeVTokenAndAddRule(
+func (l *queryLog) decodeVTokenAndAddRule(
+    ctx context.Context,
     key string,
     i int,
     dec *json.Decoder,

@@ -215,7 +223,12 @@ func decodeVTokenAndAddRule(
     vToken, err := dec.Token()
     if err != nil {
         if err != io.EOF {
-            log.Debug("decodeResultRuleKey %s err: %s", key, err)
+            l.logger.DebugContext(
+                ctx,
+                "decoding result rule key",
+                "key", key,
+                slogutil.KeyError, err,
+            )
         }
 
         return newRules, nil

@@ -230,12 +243,14 @@ func decodeVTokenAndAddRule(
 
 // decodeResultRules parses the dec's tokens into logEntry ent interpreting it
 // as a slice of the result rules.
-func decodeResultRules(dec *json.Decoder, ent *logEntry) {
+func (l *queryLog) decodeResultRules(ctx context.Context, dec *json.Decoder, ent *logEntry) {
+    const msgPrefix = "decoding result rules"
+
     for {
         delimToken, err := dec.Token()
         if err != nil {
             if err != io.EOF {
-                log.Debug("decodeResultRules err: %s", err)
+                l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
             }
 
             return

@@ -244,13 +259,17 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) {
         if d, ok := delimToken.(json.Delim); !ok {
             return
         } else if d != '[' {
-            log.Debug("decodeResultRules: unexpected delim %q", d)
+            l.logger.DebugContext(
+                ctx,
+                msgPrefix,
+                slogutil.KeyError, newUnexpectedDelimiterError(d),
+            )
         }
 
-        err = decodeResultRuleToken(dec, ent)
+        err = l.decodeResultRuleToken(ctx, dec, ent)
         if err != nil {
             if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
-                log.Debug("decodeResultRules err: %s", err)
+                l.logger.DebugContext(ctx, msgPrefix+"; rule token", slogutil.KeyError, err)
             }
 
             return

@@ -259,7 +278,11 @@ func decodeResultRules(dec *json.Decoder, ent *logEntry) {
 }
 
 // decodeResultRuleToken decodes the tokens of "Rules" type to the logEntry ent.
-func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
+func (l *queryLog) decodeResultRuleToken(
+    ctx context.Context,
+    dec *json.Decoder,
+    ent *logEntry,
+) (err error) {
     i := 0
     for {
         var keyToken json.Token

@@ -287,7 +310,7 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
             return fmt.Errorf("keyToken is %T (%[1]v) and not string", keyToken)
         }
 
-        decodeResultRuleKey(key, i, dec, ent)
+        l.decodeResultRuleKey(ctx, key, i, dec, ent)
     }
 }
 

@@ -296,12 +319,14 @@ func decodeResultRuleToken(dec *json.Decoder, ent *logEntry) (err error) {
 // other occurrences of DNSRewriteResult in the entry since hosts container's
 // rewrites currently has the highest priority along the entire filtering
 // pipeline.
-func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
+func (l *queryLog) decodeResultReverseHosts(ctx context.Context, dec *json.Decoder, ent *logEntry) {
+    const msgPrefix = "decoding result reverse hosts"
+
     for {
         itemToken, err := dec.Token()
         if err != nil {
             if err != io.EOF {
-                log.Debug("decodeResultReverseHosts err: %s", err)
+                l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
             }
 
             return

@@ -315,7 +340,11 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
             return
         }
 
-        log.Debug("decodeResultReverseHosts: unexpected delim %q", v)
+        l.logger.DebugContext(
+            ctx,
+            msgPrefix,
+            slogutil.KeyError, newUnexpectedDelimiterError(v),
+        )
 
         return
     case string:

@@ -346,12 +375,14 @@ func decodeResultReverseHosts(dec *json.Decoder, ent *logEntry) {
 
 // decodeResultIPList parses the dec's tokens into logEntry ent interpreting it
 // as the result IP addresses list.
-func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
+func (l *queryLog) decodeResultIPList(ctx context.Context, dec *json.Decoder, ent *logEntry) {
+    const msgPrefix = "decoding result ip list"
+
     for {
         itemToken, err := dec.Token()
         if err != nil {
             if err != io.EOF {
-                log.Debug("decodeResultIPList err: %s", err)
+                l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
             }
 
             return

@@ -365,7 +396,11 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
             return
         }
 
-        log.Debug("decodeResultIPList: unexpected delim %q", v)
+        l.logger.DebugContext(
+            ctx,
+            msgPrefix,
+            slogutil.KeyError, newUnexpectedDelimiterError(v),
+        )
 
         return
     case string:

@@ -382,7 +417,14 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
 
 // decodeResultDNSRewriteResultKey decodes the token of "DNSRewriteResult" type
 // to the logEntry struct.
-func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntry) {
+func (l *queryLog) decodeResultDNSRewriteResultKey(
+    ctx context.Context,
+    key string,
+    dec *json.Decoder,
+    ent *logEntry,
+) {
+    const msgPrefix = "decoding result dns rewrite result key"
+
     var err error
 
     switch key {

@@ -391,7 +433,7 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
         vToken, err = dec.Token()
         if err != nil {
             if err != io.EOF {
-                log.Debug("decodeResultDNSRewriteResultKey err: %s", err)
+                l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
             }
 
             return

@@ -419,7 +461,7 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
         // decoding and correct the values.
         err = dec.Decode(&ent.Result.DNSRewriteResult.Response)
         if err != nil {
-            log.Debug("decodeResultDNSRewriteResultKey response err: %s", err)
+            l.logger.DebugContext(ctx, msgPrefix+"; response", slogutil.KeyError, err)
         }
 
         ent.parseDNSRewriteResultIPs()

@@ -430,12 +472,18 @@ func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntr
 
 // decodeResultDNSRewriteResult parses the dec's tokens into logEntry ent
 // interpreting it as the result DNSRewriteResult.
-func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
+func (l *queryLog) decodeResultDNSRewriteResult(
+    ctx context.Context,
+    dec *json.Decoder,
+    ent *logEntry,
+) {
+    const msgPrefix = "decoding result dns rewrite result"
+
     for {
         key, err := parseKeyToken(dec)
         if err != nil {
             if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
-                log.Debug("decodeResultDNSRewriteResult: %s", err)
+                l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
             }
 
             return

@@ -445,7 +493,7 @@ func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
             continue
         }
 
-        decodeResultDNSRewriteResultKey(key, dec, ent)
+        l.decodeResultDNSRewriteResultKey(ctx, key, dec, ent)
     }
 }
 

@@ -508,14 +556,16 @@ func parseKeyToken(dec *json.Decoder) (key string, err error) {
 }
 
 // decodeResult decodes a token of "Result" type to logEntry struct.
-func decodeResult(dec *json.Decoder, ent *logEntry) {
+func (l *queryLog) decodeResult(ctx context.Context, dec *json.Decoder, ent *logEntry) {
+    const msgPrefix = "decoding result"
+
     defer translateResult(ent)
 
     for {
         key, err := parseKeyToken(dec)
         if err != nil {
             if err != io.EOF && !errors.Is(err, ErrEndOfToken) {
-                log.Debug("decodeResult: %s", err)
+                l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
             }
 
             return

@@ -525,10 +575,8 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
             continue
         }
 
-        decHandler, ok := resultDecHandlers[key]
+        ok := l.resultDecHandler(ctx, key, dec, ent)
         if ok {
-            decHandler(dec, ent)
-
             continue
         }
 

@@ -543,7 +591,7 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
         }
 
         if err = handler(val, ent); err != nil {
-            log.Debug("decodeResult handler err: %s", err)
+            l.logger.DebugContext(ctx, msgPrefix+"; handler", slogutil.KeyError, err)
 
             return
         }

@@ -636,16 +684,34 @@ var resultHandlers = map[string]logEntryHandler{
     },
 }
 
-// resultDecHandlers is the map of decode handlers for various keys.
-var resultDecHandlers = map[string]func(dec *json.Decoder, ent *logEntry){
-    "ReverseHosts": decodeResultReverseHosts,
-    "IPList": decodeResultIPList,
-    "Rules": decodeResultRules,
-    "DNSRewriteResult": decodeResultDNSRewriteResult,
-}
+// resultDecHandlers calls a decode handler for key if there is one.
+func (l *queryLog) resultDecHandler(
+    ctx context.Context,
+    name string,
+    dec *json.Decoder,
+    ent *logEntry,
+) (ok bool) {
+    ok = true
+    switch name {
+    case "ReverseHosts":
+        l.decodeResultReverseHosts(ctx, dec, ent)
+    case "IPList":
+        l.decodeResultIPList(ctx, dec, ent)
+    case "Rules":
+        l.decodeResultRules(ctx, dec, ent)
+    case "DNSRewriteResult":
+        l.decodeResultDNSRewriteResult(ctx, dec, ent)
+    default:
+        ok = false
+    }
+
+    return ok
+}
 
 // decodeLogEntry decodes string str to logEntry ent.
-func decodeLogEntry(ent *logEntry, str string) {
+func (l *queryLog) decodeLogEntry(ctx context.Context, ent *logEntry, str string) {
+    const msgPrefix = "decoding log entry"
+
     dec := json.NewDecoder(strings.NewReader(str))
     dec.UseNumber()

@@ -653,7 +719,7 @@ func decodeLogEntry(ent *logEntry, str string) {
         keyToken, err := dec.Token()
         if err != nil {
             if err != io.EOF {
-                log.Debug("decodeLogEntry err: %s", err)
+                l.logger.DebugContext(ctx, msgPrefix+"; token", slogutil.KeyError, err)
             }
 
             return

@@ -665,13 +731,14 @@ func decodeLogEntry(ent *logEntry, str string) {
 
         key, ok := keyToken.(string)
         if !ok {
-            log.Debug("decodeLogEntry: keyToken is %T (%[1]v) and not string", keyToken)
+            err = fmt.Errorf("%s: keyToken is %T (%[2]v) and not string", msgPrefix, keyToken)
+            l.logger.DebugContext(ctx, msgPrefix, slogutil.KeyError, err)
 
             return
         }
 
         if key == "Result" {
-            decodeResult(dec, ent)
+            l.decodeResult(ctx, dec, ent)
 
             continue
         }

@@ -687,9 +754,14 @@ func decodeLogEntry(ent *logEntry, str string) {
         }
 
         if err = handler(val, ent); err != nil {
-            log.Debug("decodeLogEntry handler err: %s", err)
+            l.logger.DebugContext(ctx, msgPrefix+"; handler", slogutil.KeyError, err)
 
             return
         }
     }
 }
+
+// newUnexpectedDelimiterError is a helper for creating informative errors.
+func newUnexpectedDelimiterError(d json.Delim) (err error) {
+    return fmt.Errorf("unexpected delimiter: %q", d)
+}
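The decoder walks each stored query log line with `json.Decoder`'s token API rather than unmarshalling into a fixed struct, and now reports problems through a context-aware `slog` logger. The sketch below shows the core of that technique with the standard library only; `decodeEntry` is an illustrative simplification, and the plain `"err"` key stands in for `slogutil.KeyError`.

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"strings"
)

// decodeEntry walks the top-level keys of one query log line with the token
// API, mirroring the approach above: malformed input is only logged at debug
// level and decoding stops, it never fails the caller.
func decodeEntry(ctx context.Context, logger *slog.Logger, str string) (m map[string]any) {
	const msgPrefix = "decoding log entry"

	m = map[string]any{}
	dec := json.NewDecoder(strings.NewReader(str))
	dec.UseNumber()

	for {
		tok, err := dec.Token()
		if err != nil {
			if err != io.EOF {
				logger.DebugContext(ctx, msgPrefix+"; token", "err", err)
			}

			return m
		}

		key, ok := tok.(string)
		if !ok {
			// Skip the object delimiters and other non-key tokens.
			continue
		}

		var val any
		if err = dec.Decode(&val); err != nil {
			logger.DebugContext(ctx, msgPrefix+"; value", "key", key, "err", err)

			return m
		}

		m[key] = val
	}
}

func main() {
	ctx := context.Background()
	m := decodeEntry(ctx, slog.Default(), `{"QH":"example.org","Elapsed":837429}`)
	fmt.Println(m["QH"], m["Elapsed"])
}
```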
@@ -3,27 +3,35 @@ package querylog
 import (
     "bytes"
     "encoding/base64"
+    "log/slog"
     "net"
     "net/netip"
-    "strings"
     "testing"
     "time"
 
-    "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
     "github.com/AdguardTeam/AdGuardHome/internal/filtering"
-    "github.com/AdguardTeam/golibs/log"
+    "github.com/AdguardTeam/golibs/logutil/slogutil"
     "github.com/AdguardTeam/golibs/netutil"
     "github.com/AdguardTeam/golibs/testutil"
     "github.com/AdguardTeam/urlfilter/rules"
     "github.com/miekg/dns"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
 )
 
+// Common constants for tests.
+const testTimeout = 1 * time.Second
+
 func TestDecodeLogEntry(t *testing.T) {
     logOutput := &bytes.Buffer{}
-    aghtest.ReplaceLogWriter(t, logOutput)
-    aghtest.ReplaceLogLevel(t, log.DEBUG)
+    l := &queryLog{
+        logger: slog.New(slog.NewTextHandler(logOutput, &slog.HandlerOptions{
+            Level: slog.LevelDebug,
+            ReplaceAttr: slogutil.RemoveTime,
+        })),
+    }
+
+    ctx := testutil.ContextWithTimeout(t, testTimeout)
 
     t.Run("success", func(t *testing.T) {
         const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==`

@@ -92,7 +100,7 @@ func TestDecodeLogEntry(t *testing.T) {
         }
 
         got := &logEntry{}
-        decodeLogEntry(got, data)
+        l.decodeLogEntry(ctx, got, data)
 
         s := logOutput.String()
         assert.Empty(t, s)

@@ -113,11 +121,11 @@ func TestDecodeLogEntry(t *testing.T) {
     }, {
         name: "bad_filter_id_old_rule",
         log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"FilterID":1.5},"Elapsed":837429}`,
-        want: "decodeResult handler err: strconv.ParseInt: parsing \"1.5\": invalid syntax\n",
+        want: `level=DEBUG msg="decoding result; handler" err="strconv.ParseInt: parsing \"1.5\": invalid syntax"`,
     }, {
         name: "bad_is_filtered",
         log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":trooe,"Reason":3},"Elapsed":837429}`,
-        want: "decodeLogEntry err: invalid character 'o' in literal true (expecting 'u')\n",
+        want: `level=DEBUG msg="decoding log entry; token" err="invalid character 'o' in literal true (expecting 'u')"`,
     }, {
         name: "bad_elapsed",
         log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":-1}`,

@@ -129,7 +137,7 @@ func TestDecodeLogEntry(t *testing.T) {
     }, {
         name: "bad_time",
         log: `{"IP":"127.0.0.1","T":"12/09/1998T15:00:00.000000+05:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
-        want: "decodeLogEntry handler err: parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\"\n",
+        want: `level=DEBUG msg="decoding log entry; handler" err="parsing time \"12/09/1998T15:00:00.000000+05:00\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"12/09/1998T15:00:00.000000+05:00\" as \"2006\""`,
     }, {
         name: "bad_host",
         log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":6,"QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,

@@ -149,7 +157,7 @@ func TestDecodeLogEntry(t *testing.T) {
     }, {
         name: "very_bad_client_proto",
         log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"dog","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
-        want: "decodeLogEntry handler err: invalid client proto: \"dog\"\n",
+        want: `level=DEBUG msg="decoding log entry; handler" err="invalid client proto: \"dog\""`,
     }, {
         name: "bad_answer",
         log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":0.9,"Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,

@@ -157,7 +165,7 @@ func TestDecodeLogEntry(t *testing.T) {
     }, {
         name: "very_bad_answer",
         log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3},"Elapsed":837429}`,
-        want: "decodeLogEntry handler err: illegal base64 data at input byte 61\n",
+        want: `level=DEBUG msg="decoding log entry; handler" err="illegal base64 data at input byte 61"`,
     }, {
         name: "bad_rule",
         log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"Rule":false},"Elapsed":837429}`,

@@ -169,22 +177,25 @@ func TestDecodeLogEntry(t *testing.T) {
     }, {
         name: "bad_reverse_hosts",
         log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":[{}]},"Elapsed":837429}`,
-        want: "decodeResultReverseHosts: unexpected delim \"{\"\n",
+        want: `level=DEBUG msg="decoding result reverse hosts" err="unexpected delimiter: \"{\""`,
     }, {
         name: "bad_ip_list",
        log: `{"IP":"127.0.0.1","T":"2020-11-25T18:55:56.519796+03:00","QH":"an.yandex.ru","QT":"A","QC":"IN","CP":"","Answer":"Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==","Result":{"IsFiltered":true,"Reason":3,"ReverseHosts":["example.net"],"IPList":[{}]},"Elapsed":837429}`,
-        want: "decodeResultIPList: unexpected delim \"{\"\n",
+        want: `level=DEBUG msg="decoding result ip list" err="unexpected delimiter: \"{\""`,
     }}
 
     for _, tc := range testCases {
         t.Run(tc.name, func(t *testing.T) {
-            decodeLogEntry(new(logEntry), tc.log)
-
-            s := logOutput.String()
+            l.decodeLogEntry(ctx, new(logEntry), tc.log)
+            got := logOutput.String()
             if tc.want == "" {
-                assert.Empty(t, s)
+                assert.Empty(t, got)
             } else {
-                assert.True(t, strings.HasSuffix(s, tc.want), "got %q", s)
+                require.NotEmpty(t, got)
+
+                // Remove newline.
+                got = got[:len(got)-1]
+                assert.Equal(t, tc.want, got)
             }
 
             logOutput.Reset()

@@ -200,6 +211,12 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) {
         aaaa2 = aaaa1.Next()
     )
 
+    l := &queryLog{
+        logger: slogutil.NewDiscardLogger(),
+    }
+
+    ctx := testutil.ContextWithTimeout(t, testTimeout)
+
     testCases := []struct {
         want *logEntry
         entry string

@@ -249,7 +266,7 @@ func TestDecodeLogEntry_backwardCompatability(t *testing.T) {
     for _, tc := range testCases {
         t.Run(tc.name, func(t *testing.T) {
             e := &logEntry{}
-            decodeLogEntry(e, tc.entry)
+            l.decodeLogEntry(ctx, e, tc.entry)
 
             assert.Equal(t, tc.want, e)
         })
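The tests now assert on the exact `log/slog` text output captured in a buffer instead of checking suffixes of the old global log writer. A stdlib-only sketch of that testing technique follows; `removeTime` is a local stand-in for `slogutil.RemoveTime`, and the message and error text are illustrative.

```go
package example

import (
	"bytes"
	"context"
	"log/slog"
	"strings"
	"testing"
)

// removeTime drops the time attribute so that log lines are deterministic and
// can be compared with exact equality.
func removeTime(_ []string, a slog.Attr) slog.Attr {
	if a.Key == slog.TimeKey {
		return slog.Attr{}
	}

	return a
}

func TestLogCapture(t *testing.T) {
	logOutput := &bytes.Buffer{}
	logger := slog.New(slog.NewTextHandler(logOutput, &slog.HandlerOptions{
		Level:       slog.LevelDebug,
		ReplaceAttr: removeTime,
	}))

	logger.DebugContext(context.Background(), "decoding log entry; token", "err", "unexpected EOF")

	got := strings.TrimSuffix(logOutput.String(), "\n")
	want := `level=DEBUG msg="decoding log entry; token" err="unexpected EOF"`
	if got != want {
		t.Fatalf("got %q, want %q", got, want)
	}
}
```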
@@ -1,12 +1,14 @@
 package querylog
 
 import (
+    "context"
+    "log/slog"
     "net"
     "time"
 
     "github.com/AdguardTeam/AdGuardHome/internal/filtering"
     "github.com/AdguardTeam/golibs/errors"
-    "github.com/AdguardTeam/golibs/log"
+    "github.com/AdguardTeam/golibs/logutil/slogutil"
     "github.com/miekg/dns"
 )
 

@@ -52,7 +54,7 @@ func (e *logEntry) shallowClone() (clone *logEntry) {
 // addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is
 // true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any
 // errors are logged.
-func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
+func (e *logEntry) addResponse(ctx context.Context, l *slog.Logger, resp *dns.Msg, isOrig bool) {
     if resp == nil {
         return
     }

@@ -65,8 +67,9 @@ func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
         e.Answer, err = resp.Pack()
         err = errors.Annotate(err, "packing answer: %w")
     }
 
     if err != nil {
-        log.Error("querylog: %s", err)
+        l.ErrorContext(ctx, "adding data from response", slogutil.KeyError, err)
     }
 }
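`addResponse` packs a DNS response into wire format and only logs a failure instead of propagating it. A compact sketch of the same behaviour is below; it uses the real `github.com/miekg/dns` API, but `packAnswer` itself is illustrative, `fmt.Errorf` stands in for `errors.Annotate`, and `"err"` stands in for `slogutil.KeyError`.

```go
package main

import (
	"context"
	"fmt"
	"log/slog"

	"github.com/miekg/dns"
)

// packAnswer mirrors addResponse above: it packs a response into wire format
// and only logs a failure instead of returning it to the caller.
func packAnswer(ctx context.Context, l *slog.Logger, resp *dns.Msg) (answer []byte) {
	if resp == nil {
		return nil
	}

	answer, err := resp.Pack()
	if err != nil {
		err = fmt.Errorf("packing answer: %w", err)
		l.ErrorContext(ctx, "adding data from response", "err", err)

		return nil
	}

	return answer
}

func main() {
	req := (&dns.Msg{}).SetQuestion("example.org.", dns.TypeA)
	resp := (&dns.Msg{}).SetReply(req)

	b := packAnswer(context.Background(), slog.Default(), resp)
	fmt.Println(len(b) > 0)
}
```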
@@ -1,6 +1,7 @@
 package querylog
 
 import (
+    "context"
     "encoding/json"
     "fmt"
     "math"

@@ -15,7 +16,7 @@ import (
     "github.com/AdguardTeam/AdGuardHome/internal/aghalg"
     "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
     "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
-    "github.com/AdguardTeam/golibs/log"
+    "github.com/AdguardTeam/golibs/logutil/slogutil"
     "github.com/AdguardTeam/golibs/timeutil"
     "golang.org/x/net/idna"
 )

@@ -74,7 +75,8 @@ func (l *queryLog) initWeb() {
 
 // handleQueryLog is the handler for the GET /control/querylog HTTP API.
 func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
-    params, err := parseSearchParams(r)
+    ctx := r.Context()
+    params, err := l.parseSearchParams(ctx, r)
     if err != nil {
         aghhttp.Error(r, w, http.StatusBadRequest, "parsing params: %s", err)
 

@@ -87,18 +89,18 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
         l.confMu.RLock()
         defer l.confMu.RUnlock()
 
-        entries, oldest = l.search(params)
+        entries, oldest = l.search(ctx, params)
     }()
 
-    resp := entriesToJSON(entries, oldest, l.anonymizer.Load())
+    resp := l.entriesToJSON(ctx, entries, oldest, l.anonymizer.Load())
 
     aghhttp.WriteJSONResponseOK(w, r, resp)
 }
 
 // handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP
 // API.
-func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
-    l.clear()
+func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, r *http.Request) {
+    l.clear(r.Context())
 }
 
 // handleQueryLogInfo is the handler for the GET /control/querylog_info HTTP

@@ -280,11 +282,12 @@ func getDoubleQuotesEnclosedValue(s *string) bool {
 }
 
 // parseSearchCriterion parses a search criterion from the query parameter.
-func parseSearchCriterion(q url.Values, name string, ct criterionType) (
-    ok bool,
-    sc searchCriterion,
-    err error,
-) {
+func (l *queryLog) parseSearchCriterion(
+    ctx context.Context,
+    q url.Values,
+    name string,
+    ct criterionType,
+) (ok bool, sc searchCriterion, err error) {
     val := q.Get(name)
     if val == "" {
         return false, sc, nil

@@ -301,7 +304,7 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
         // TODO(e.burkov): Make it work with parts of IDNAs somehow.
         loweredVal := strings.ToLower(val)
         if asciiVal, err = idna.ToASCII(loweredVal); err != nil {
-            log.Debug("can't convert %q to ascii: %s", val, err)
+            l.logger.DebugContext(ctx, "converting to ascii", "value", val, slogutil.KeyError, err)
         } else if asciiVal == loweredVal {
             // Purge asciiVal to prevent checking the same value
             // twice.

@@ -331,7 +334,10 @@ func parseSearchCriterion(q url.Values, name string, ct criterionType) (
 
 // parseSearchParams parses search parameters from the HTTP request's query
 // string.
-func parseSearchParams(r *http.Request) (p *searchParams, err error) {
+func (l *queryLog) parseSearchParams(
+    ctx context.Context,
+    r *http.Request,
+) (p *searchParams, err error) {
     p = newSearchParams()
 
     q := r.URL.Query()

@@ -369,7 +375,7 @@ func parseSearchParams(r *http.Request) (p *searchParams, err error) {
     }} {
         var ok bool
         var c searchCriterion
-        ok, c, err = parseSearchCriterion(q, v.urlField, v.ct)
+        ok, c, err = l.parseSearchCriterion(ctx, q, v.urlField, v.ct)
         if err != nil {
             return nil, err
         }
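Search criteria are lowercased and converted to their punycode form with `golang.org/x/net/idna` so that internationalized domain names match the ASCII form stored in the log. A short, runnable illustration of that conversion step:

```go
package main

import (
	"fmt"
	"strings"

	"golang.org/x/net/idna"
)

func main() {
	// Lowercase first, then convert, mirroring parseSearchCriterion above.
	val := "пример.рф"
	loweredVal := strings.ToLower(val)

	asciiVal, err := idna.ToASCII(loweredVal)
	if err != nil {
		// The diff logs this at debug level and keeps going; the criterion
		// then matches only the original string.
		fmt.Println("converting to ascii:", err)
	}

	fmt.Println(asciiVal) // xn--e1afmkfd.xn--p1ai
}
```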
@@ -1,6 +1,7 @@
 package querylog
 
 import (
+    "context"
     "slices"
     "strconv"
     "strings"

@@ -8,7 +9,7 @@ import (
 
     "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
     "github.com/AdguardTeam/AdGuardHome/internal/filtering"
-    "github.com/AdguardTeam/golibs/log"
+    "github.com/AdguardTeam/golibs/logutil/slogutil"
     "github.com/miekg/dns"
     "golang.org/x/net/idna"
 )

@@ -19,7 +20,8 @@ import (
 type jobject = map[string]any
 
 // entriesToJSON converts query log entries to JSON.
-func entriesToJSON(
+func (l *queryLog) entriesToJSON(
+    ctx context.Context,
     entries []*logEntry,
     oldest time.Time,
     anonFunc aghnet.IPMutFunc,

@@ -28,7 +30,7 @@ func entriesToJSON(
 
     // The elements order is already reversed to be from newer to older.
     for _, entry := range entries {
-        jsonEntry := entryToJSON(entry, anonFunc)
+        jsonEntry := l.entryToJSON(ctx, entry, anonFunc)
         data = append(data, jsonEntry)
     }
 

@@ -44,7 +46,11 @@ func entriesToJSON(
 }
 
 // entryToJSON converts a log entry's data into an entry for the JSON API.
-func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject) {
+func (l *queryLog) entryToJSON(
+    ctx context.Context,
+    entry *logEntry,
+    anonFunc aghnet.IPMutFunc,
+) (jsonEntry jobject) {
     hostname := entry.QHost
     question := jobject{
         "type": entry.QType,

@@ -53,7 +59,12 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject)
     }
 
     if qhost, err := idna.ToUnicode(hostname); err != nil {
-        log.Debug("querylog: translating %q into unicode: %s", hostname, err)
+        l.logger.DebugContext(
+            ctx,
+            "translating into unicode",
+            "hostname", hostname,
+            slogutil.KeyError, err,
+        )
     } else if qhost != hostname && qhost != "" {
         question["unicode_name"] = qhost
     }

@@ -96,21 +107,26 @@ func entryToJSON(entry *logEntry, anonFunc aghnet.IPMutFunc) (jsonEntry jobject)
         jsonEntry["service_name"] = entry.Result.ServiceName
     }
 
-    setMsgData(entry, jsonEntry)
-    setOrigAns(entry, jsonEntry)
+    l.setMsgData(ctx, entry, jsonEntry)
+    l.setOrigAns(ctx, entry, jsonEntry)
 
     return jsonEntry
 }
 
 // setMsgData sets the message data in jsonEntry.
-func setMsgData(entry *logEntry, jsonEntry jobject) {
+func (l *queryLog) setMsgData(ctx context.Context, entry *logEntry, jsonEntry jobject) {
     if len(entry.Answer) == 0 {
         return
     }
 
     msg := &dns.Msg{}
     if err := msg.Unpack(entry.Answer); err != nil {
-        log.Debug("querylog: failed to unpack dns msg answer: %v: %s", entry.Answer, err)
+        l.logger.DebugContext(
+            ctx,
+            "unpacking dns message",
+            "answer", entry.Answer,
+            slogutil.KeyError, err,
+        )
 
         return
     }

@@ -126,7 +142,7 @@ func setMsgData(entry *logEntry, jsonEntry jobject) {
 }
 
 // setOrigAns sets the original answer data in jsonEntry.
-func setOrigAns(entry *logEntry, jsonEntry jobject) {
+func (l *queryLog) setOrigAns(ctx context.Context, entry *logEntry, jsonEntry jobject) {
     if len(entry.OrigAnswer) == 0 {
         return
     }

@@ -134,7 +150,12 @@ func setOrigAns(entry *logEntry, jsonEntry jobject) {
     orig := &dns.Msg{}
     err := orig.Unpack(entry.OrigAnswer)
     if err != nil {
-        log.Debug("querylog: orig.Unpack(entry.OrigAnswer): %v: %s", entry.OrigAnswer, err)
+        l.logger.DebugContext(
+            ctx,
+            "setting original answer",
+            "answer", entry.OrigAnswer,
+            slogutil.KeyError, err,
+        )
 
         return
     }
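`setMsgData` unpacks the stored wire-format answer back into a `dns.Msg` before exporting it to the JSON API, and now skips the entry with a context-aware debug log on failure. The sketch below shows that flow; `answerToJSON` and the fields it emits are illustrative, while the `github.com/miekg/dns` calls are the real API.

```go
package main

import (
	"context"
	"fmt"
	"log/slog"

	"github.com/miekg/dns"
)

type jobject = map[string]any

// answerToJSON mirrors setMsgData above: wire-format bytes that fail to unpack
// are logged at debug level and simply omitted from the JSON entry.
func answerToJSON(ctx context.Context, l *slog.Logger, answer []byte) (entries []jobject) {
	if len(answer) == 0 {
		return nil
	}

	msg := &dns.Msg{}
	if err := msg.Unpack(answer); err != nil {
		l.DebugContext(ctx, "unpacking dns message", "answer", answer, "err", err)

		return nil
	}

	for _, rr := range msg.Answer {
		h := rr.Header()
		entries = append(entries, jobject{
			"type": dns.TypeToString[h.Rrtype],
			"ttl":  h.Ttl,
		})
	}

	return entries
}

func main() {
	req := (&dns.Msg{}).SetQuestion("example.org.", dns.TypeA)
	resp := (&dns.Msg{}).SetReply(req)

	b, _ := resp.Pack()
	fmt.Println(answerToJSON(context.Background(), slog.Default(), b))
}
```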
@@ -2,7 +2,9 @@
 package querylog
 
 import (
+    "context"
     "fmt"
+    "log/slog"
     "os"
     "sync"
     "time"

@@ -11,7 +13,7 @@ import (
     "github.com/AdguardTeam/AdGuardHome/internal/filtering"
     "github.com/AdguardTeam/golibs/container"
     "github.com/AdguardTeam/golibs/errors"
-    "github.com/AdguardTeam/golibs/log"
+    "github.com/AdguardTeam/golibs/logutil/slogutil"
     "github.com/AdguardTeam/golibs/timeutil"
     "github.com/miekg/dns"
 )

@@ -22,6 +24,10 @@ const queryLogFileName = "querylog.json"
 
 // queryLog is a structure that writes and reads the DNS query log.
 type queryLog struct {
+    // logger is used for logging the operation of the query log. It must not
+    // be nil.
+    logger *slog.Logger
+
     // confMu protects conf.
     confMu *sync.RWMutex
 

@@ -76,24 +82,34 @@ func NewClientProto(s string) (cp ClientProto, err error) {
     }
 }
 
-func (l *queryLog) Start() {
+// type check
+var _ QueryLog = (*queryLog)(nil)
+
+// Start implements the [QueryLog] interface for *queryLog.
+func (l *queryLog) Start(ctx context.Context) (err error) {
     if l.conf.HTTPRegister != nil {
         l.initWeb()
     }
 
-    go l.periodicRotate()
+    go l.periodicRotate(ctx)
+
+    return nil
 }
 
-func (l *queryLog) Close() {
+// Shutdown implements the [QueryLog] interface for *queryLog.
+func (l *queryLog) Shutdown(ctx context.Context) (err error) {
     l.confMu.RLock()
     defer l.confMu.RUnlock()
 
     if l.conf.FileEnabled {
-        err := l.flushLogBuffer()
+        err = l.flushLogBuffer(ctx)
         if err != nil {
-            log.Error("querylog: closing: %s", err)
+            // Don't wrap the error because it's informative enough as is.
+            return err
         }
     }
+
+    return nil
 }
 
 func checkInterval(ivl time.Duration) (ok bool) {

@@ -123,6 +139,7 @@ func validateIvl(ivl time.Duration) (err error) {
     return nil
 }
 
+// WriteDiskConfig implements the [QueryLog] interface for *queryLog.
 func (l *queryLog) WriteDiskConfig(c *Config) {
     l.confMu.RLock()
     defer l.confMu.RUnlock()

@@ -131,7 +148,7 @@ func (l *queryLog) WriteDiskConfig(c *Config) {
 }
 
 // Clear memory buffer and remove log files
-func (l *queryLog) clear() {
+func (l *queryLog) clear(ctx context.Context) {
     l.fileFlushLock.Lock()
     defer l.fileFlushLock.Unlock()
 

@@ -146,19 +163,24 @@ func (l *queryLog) clear() {
     oldLogFile := l.logFile + ".1"
     err := os.Remove(oldLogFile)
     if err != nil && !errors.Is(err, os.ErrNotExist) {
-        log.Error("removing old log file %q: %s", oldLogFile, err)
+        l.logger.ErrorContext(
+            ctx,
+            "removing old log file",
+            "file", oldLogFile,
+            slogutil.KeyError, err,
+        )
     }
 
     err = os.Remove(l.logFile)
     if err != nil && !errors.Is(err, os.ErrNotExist) {
-        log.Error("removing log file %q: %s", l.logFile, err)
+        l.logger.ErrorContext(ctx, "removing log file", "file", l.logFile, slogutil.KeyError, err)
     }
 
-    log.Debug("querylog: cleared")
+    l.logger.DebugContext(ctx, "cleared")
 }
 
 // newLogEntry creates an instance of logEntry from parameters.
-func newLogEntry(params *AddParams) (entry *logEntry) {
+func newLogEntry(ctx context.Context, logger *slog.Logger, params *AddParams) (entry *logEntry) {
     q := params.Question.Question[0]
     qHost := aghnet.NormalizeDomain(q.Name)
 

@@ -187,8 +209,8 @@ func newLogEntry(params *AddParams) (entry *logEntry) {
         entry.ReqECS = params.ReqECS.String()
     }
 
-    entry.addResponse(params.Answer, false)
-    entry.addResponse(params.OrigAnswer, true)
+    entry.addResponse(ctx, logger, params.Answer, false)
+    entry.addResponse(ctx, logger, params.OrigAnswer, true)
 
     return entry
 }

@@ -209,9 +231,12 @@ func (l *queryLog) Add(params *AddParams) {
         return
     }
 
+    // TODO(s.chzhen): Pass context.
+    ctx := context.TODO()
+
     err := params.validate()
     if err != nil {
-        log.Error("querylog: adding record: %s, skipping", err)
+        l.logger.ErrorContext(ctx, "adding record", slogutil.KeyError, err)
 
         return
     }

@@ -220,7 +245,7 @@ func (l *queryLog) Add(params *AddParams) {
         params.Result = &filtering.Result{}
     }
 
-    entry := newLogEntry(params)
+    entry := newLogEntry(ctx, l.logger, params)
 
     l.bufferLock.Lock()
     defer l.bufferLock.Unlock()

@@ -232,9 +257,9 @@ func (l *queryLog) Add(params *AddParams) {
 
     // TODO(s.chzhen): Fix occasional rewrite of entires.
     go func() {
-        flushErr := l.flushLogBuffer()
+        flushErr := l.flushLogBuffer(ctx)
         if flushErr != nil {
-            log.Error("querylog: flushing after adding: %s", flushErr)
+            l.logger.ErrorContext(ctx, "flushing after adding", slogutil.KeyError, flushErr)
         }
     }()
 }

@@ -247,7 +272,8 @@ func (l *queryLog) ShouldLog(host string, _, _ uint16, ids []string) bool {
 
     c, err := l.findClient(ids)
     if err != nil {
-        log.Error("querylog: finding client: %s", err)
+        // TODO(s.chzhen): Pass context.
+        l.logger.ErrorContext(context.TODO(), "finding client", slogutil.KeyError, err)
     }
 
     if c != nil && c.IgnoreQueryLog {
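The `var _ QueryLog = (*queryLog)(nil)` line is a compile-time assertion that the struct still satisfies the interface after its method set changed. A minimal sketch of the idiom, with an illustrative interface and struct rather than the project's real ones:

```go
package main

import (
	"context"
	"fmt"
)

// QueryLog is a trimmed-down, hypothetical version of the interface: the
// methods that changed above now take a context and return an error.
type QueryLog interface {
	Start(ctx context.Context) (err error)
	Shutdown(ctx context.Context) (err error)
}

type queryLog struct{}

// The blank assignment fails to compile as soon as *queryLog stops satisfying
// QueryLog, e.g. if Start loses its context parameter.
var _ QueryLog = (*queryLog)(nil)

func (l *queryLog) Start(_ context.Context) (err error)    { return nil }
func (l *queryLog) Shutdown(_ context.Context) (err error) { return nil }

func main() {
	fmt.Println("compiles: *queryLog implements QueryLog")
}
```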
@@ -7,6 +7,7 @@ import (
 
     "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
     "github.com/AdguardTeam/AdGuardHome/internal/filtering"
+    "github.com/AdguardTeam/golibs/logutil/slogutil"
     "github.com/AdguardTeam/golibs/testutil"
     "github.com/AdguardTeam/golibs/timeutil"
     "github.com/miekg/dns"

@@ -14,14 +15,11 @@ import (
     "github.com/stretchr/testify/require"
 )
 
-func TestMain(m *testing.M) {
-    testutil.DiscardLogOutput(m)
-}
-
 // TestQueryLog tests adding and loading (with filtering) entries from disk and
 // memory.
 func TestQueryLog(t *testing.T) {
     l, err := newQueryLog(Config{
+        Logger: slogutil.NewDiscardLogger(),
         Enabled: true,
         FileEnabled: true,
         RotationIvl: timeutil.Day,

@@ -30,16 +28,21 @@ func TestQueryLog(t *testing.T) {
     })
     require.NoError(t, err)
 
+    ctx := testutil.ContextWithTimeout(t, testTimeout)
+
     // Add disk entries.
     addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
     // Write to disk (first file).
-    require.NoError(t, l.flushLogBuffer())
+    require.NoError(t, l.flushLogBuffer(ctx))
+
     // Start writing to the second file.
-    require.NoError(t, l.rotate())
+    require.NoError(t, l.rotate(ctx))
+
     // Add disk entries.
     addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
     // Write to disk.
-    require.NoError(t, l.flushLogBuffer())
+    require.NoError(t, l.flushLogBuffer(ctx))
+
     // Add memory entries.
     addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
     addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))

@@ -119,8 +122,9 @@ func TestQueryLog(t *testing.T) {
             params := newSearchParams()
             params.searchCriteria = tc.sCr
 
-            entries, _ := l.search(params)
+            entries, _ := l.search(ctx, params)
             require.Len(t, entries, len(tc.want))
+
             for _, want := range tc.want {
                 assertLogEntry(t, entries[want.num], want.host, want.answer, want.client)
             }

@@ -130,6 +134,7 @@ func TestQueryLog(t *testing.T) {
 
 func TestQueryLogOffsetLimit(t *testing.T) {
     l, err := newQueryLog(Config{
+        Logger: slogutil.NewDiscardLogger(),
         Enabled: true,
         RotationIvl: timeutil.Day,
         MemSize: 100,

@@ -142,12 +147,16 @@ func TestQueryLogOffsetLimit(t *testing.T) {
         firstPageDomain = "first.example.org"
         secondPageDomain = "second.example.org"
     )
 
+    ctx := testutil.ContextWithTimeout(t, testTimeout)
+
     // Add entries to the log.
     for range entNum {
         addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
     }
     // Write them to the first file.
-    require.NoError(t, l.flushLogBuffer())
+    require.NoError(t, l.flushLogBuffer(ctx))
+
     // Add more to the in-memory part of log.
     for range entNum {
         addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))

@@ -191,8 +200,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
         t.Run(tc.name, func(t *testing.T) {
             params.offset = tc.offset
             params.limit = tc.limit
-            entries, _ := l.search(params)
-
+            entries, _ := l.search(ctx, params)
             require.Len(t, entries, tc.wantLen)
 
             if tc.wantLen > 0 {

@@ -205,6 +213,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
 
 func TestQueryLogMaxFileScanEntries(t *testing.T) {
     l, err := newQueryLog(Config{
+        Logger: slogutil.NewDiscardLogger(),
         Enabled: true,
         FileEnabled: true,
         RotationIvl: timeutil.Day,

@@ -213,20 +222,21 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
     })
     require.NoError(t, err)
 
+    ctx := testutil.ContextWithTimeout(t, testTimeout)
+
     const entNum = 10
     // Add entries to the log.
     for range entNum {
         addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
     }
     // Write them to disk.
-    require.NoError(t, l.flushLogBuffer())
+    require.NoError(t, l.flushLogBuffer(ctx))
 
     params := newSearchParams()
 
     for _, maxFileScanEntries := range []int{5, 0} {
         t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) {
             params.maxFileScanEntries = maxFileScanEntries
-            entries, _ := l.search(params)
+            entries, _ := l.search(ctx, params)
             assert.Len(t, entries, entNum-maxFileScanEntries)
         })
     }

@@ -234,6 +244,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
 
 func TestQueryLogFileDisabled(t *testing.T) {
     l, err := newQueryLog(Config{
+        Logger: slogutil.NewDiscardLogger(),
        Enabled: true,
        FileEnabled: false,
        RotationIvl: timeutil.Day,

@@ -248,8 +259,10 @@ func TestQueryLogFileDisabled(t *testing.T) {
     addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
 
     params := newSearchParams()
-    ll, _ := l.search(params)
+    ctx := testutil.ContextWithTimeout(t, testTimeout)
+    ll, _ := l.search(ctx, params)
     require.Len(t, ll, 2)
 
     assert.Equal(t, "example3.org", ll[0].QHost)
     assert.Equal(t, "example2.org", ll[1].QHost)
 }
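The tests drop the global `TestMain` log discarding in favour of a per-test discard logger and a bounded context from `testutil.ContextWithTimeout`. The helper below sketches what such a helper presumably does, using only the standard library; the real golibs implementation may differ.

```go
package example

import (
	"context"
	"testing"
	"time"
)

// contextWithTimeout returns a context that is cancelled automatically when
// the test finishes, so goroutines started by the code under test cannot
// outlive it.
func contextWithTimeout(tb testing.TB, timeout time.Duration) (ctx context.Context) {
	tb.Helper()

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	tb.Cleanup(cancel)

	return ctx
}

func TestSearchDoesNotHang(t *testing.T) {
	ctx := contextWithTimeout(t, 1*time.Second)

	select {
	case <-time.After(10 * time.Millisecond):
		// The operation under test finished in time.
	case <-ctx.Done():
		t.Fatal("timed out:", ctx.Err())
	}
}
```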
@@ -1,8 +1,10 @@
 package querylog
 
 import (
+    "context"
     "fmt"
     "io"
+    "log/slog"
     "os"
     "strings"
     "sync"

@@ -10,7 +12,7 @@ import (
 
     "github.com/AdguardTeam/AdGuardHome/internal/aghos"
     "github.com/AdguardTeam/golibs/errors"
-    "github.com/AdguardTeam/golibs/log"
+    "github.com/AdguardTeam/golibs/logutil/slogutil"
 )
 
 const (

@@ -102,7 +104,11 @@ func (q *qLogFile) validateQLogLineIdx(lineIdx, lastProbeLineIdx, ts, fSize int6
 // for so that when we call "ReadNext" this line was returned.
 // - Depth of the search (how many times we compared timestamps).
 // - If we could not find it, it returns one of the errors described above.
-func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) {
+func (q *qLogFile) seekTS(
+    ctx context.Context,
+    logger *slog.Logger,
+    timestamp int64,
+) (pos int64, depth int, err error) {
     q.lock.Lock()
     defer q.lock.Unlock()
 

@@ -151,7 +157,7 @@ func (q *qLogFile) seekTS(timestamp int64) (pos int64, depth int, err error) {
         lastProbeLineIdx = lineIdx
 
         // Get the timestamp from the query log record.
-        ts := readQLogTimestamp(line)
+        ts := readQLogTimestamp(ctx, logger, line)
         if ts == 0 {
             return 0, depth, fmt.Errorf(
                 "looking up timestamp %d in %q: record %q has empty timestamp",

@@ -385,20 +391,22 @@ func readJSONValue(s, prefix string) string {
 }
 
 // readQLogTimestamp reads the timestamp field from the query log line.
-func readQLogTimestamp(str string) int64 {
+func readQLogTimestamp(ctx context.Context, logger *slog.Logger, str string) int64 {
     val := readJSONValue(str, `"T":"`)
     if len(val) == 0 {
         val = readJSONValue(str, `"Time":"`)
     }
 
     if len(val) == 0 {
-        log.Error("Couldn't find timestamp: %s", str)
+        logger.ErrorContext(ctx, "couldn't find timestamp", "line", str)
 
         return 0
     }
 
     tm, err := time.Parse(time.RFC3339Nano, val)
     if err != nil {
-        log.Error("Couldn't parse timestamp: %s", val)
+        logger.ErrorContext(ctx, "couldn't parse timestamp", "value", val, slogutil.KeyError, err)
 
         return 0
     }
 
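`readQLogTimestamp` extracts the raw `"T"` value from one JSON log line and parses it as RFC 3339 with nanoseconds, returning 0 and logging on any failure so the binary search in `seekTS` can bail out. The stdlib-only sketch below mimics that behaviour; `readTimestamp` is illustrative and uses the plain `"err"` key rather than `slogutil.KeyError`.

```go
package main

import (
	"context"
	"fmt"
	"log/slog"
	"strings"
	"time"
)

// readTimestamp mimics readQLogTimestamp above: it pulls the raw value of the
// "T" field out of one JSON-encoded log line and parses it, returning 0 on any
// failure.
func readTimestamp(ctx context.Context, logger *slog.Logger, str string) int64 {
	const prefix = `"T":"`

	i := strings.Index(str, prefix)
	if i == -1 {
		logger.ErrorContext(ctx, "couldn't find timestamp", "line", str)

		return 0
	}

	rest := str[i+len(prefix):]
	end := strings.IndexByte(rest, '"')
	if end == -1 {
		logger.ErrorContext(ctx, "couldn't find timestamp", "line", str)

		return 0
	}

	tm, err := time.Parse(time.RFC3339Nano, rest[:end])
	if err != nil {
		logger.ErrorContext(ctx, "couldn't parse timestamp", "value", rest[:end], "err", err)

		return 0
	}

	return tm.UnixNano()
}

func main() {
	line := `{"T":"2020-08-31T18:44:25.376690873+03:00","QH":"example.org"}`
	fmt.Println(readTimestamp(context.Background(), slog.Default(), line))
}
```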
@@ -12,6 +12,7 @@ import (
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

@@ -24,6 +25,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) {
f, err := os.CreateTemp(dir, "*.txt")
require.NoError(t, err)

// Use defer and not t.Cleanup to make sure that the file is closed
// after this function is done.
defer func() {

@@ -108,6 +110,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
// Calculate the expected position.
fileInfo, err := q.file.Stat()
require.NoError(t, err)

var expPos int64
if expPos = fileInfo.Size(); expPos > 0 {
expPos--

@@ -129,6 +132,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
}

require.Equal(t, io.EOF, err)

assert.Equal(t, tc.linesNum, read)
})
}

@@ -146,6 +150,9 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
num: 10,
}}

logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)

for _, l := range linesCases {
testCases := []struct {
name string

@@ -171,16 +178,19 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
t.Run(l.name+"_"+tc.name, func(t *testing.T) {
line, err := getQLogFileLine(q, tc.line)
require.NoError(t, err)
ts := readQLogTimestamp(line)

ts := readQLogTimestamp(ctx, logger, line)
assert.NotEqualValues(t, 0, ts)

// Try seeking to that line now.
pos, _, err := q.seekTS(ts)
pos, _, err := q.seekTS(ctx, logger, ts)
require.NoError(t, err)

assert.NotEqualValues(t, 0, pos)

testLine, err := q.ReadNext()
require.NoError(t, err)

assert.Equal(t, line, testLine)
})
}

@@ -199,6 +209,9 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
num: 10,
}}

logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)

for _, l := range linesCases {
testCases := []struct {
name string

@@ -221,14 +234,14 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
line, err := getQLogFileLine(q, l.num/2)
require.NoError(t, err)
testCases[2].ts = readQLogTimestamp(line) - 1

testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.NotEqualValues(t, 0, tc.ts)

var depth int
_, depth, err = q.seekTS(tc.ts)
_, depth, err = q.seekTS(ctx, logger, tc.ts)
assert.NotEmpty(t, l.num)
require.Error(t, err)

@@ -262,11 +275,13 @@ func TestQLogFile(t *testing.T) {
// Seek to the start.
pos, err := q.SeekStart()
require.NoError(t, err)

assert.Greater(t, pos, int64(0))

// Read first line.
line, err := q.ReadNext()
require.NoError(t, err)

assert.Contains(t, line, "0.0.0.2")
assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line)

@@ -274,6 +289,7 @@ func TestQLogFile(t *testing.T) {
// Read second line.
line, err = q.ReadNext()
require.NoError(t, err)

assert.EqualValues(t, 0, q.position)
assert.Contains(t, line, "0.0.0.1")
assert.True(t, strings.HasPrefix(line, "{"), line)

@@ -282,12 +298,14 @@ func TestQLogFile(t *testing.T) {
// Try reading again (there's nothing to read anymore).
line, err = q.ReadNext()
require.Equal(t, io.EOF, err)

assert.Empty(t, line)
}

func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) {
f, err := os.CreateTemp(t.TempDir(), "*.txt")
require.NoError(t, err)

testutil.CleanupAndRequireSuccess(t, f.Close)

_, err = f.WriteString(data)

@@ -295,6 +313,7 @@ func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) {
file, err = newQLogFile(f.Name())
require.NoError(t, err)

testutil.CleanupAndRequireSuccess(t, file.Close)

return file

@@ -308,6 +327,9 @@ func TestQLog_Seek(t *testing.T) {
`{"T":"` + strV + `"}` + nl
timestamp, _ := time.Parse(time.RFC3339Nano, "2020-08-31T18:44:25.376690873+03:00")

logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)

testCases := []struct {
wantErr error
name string

@@ -340,8 +362,10 @@ func TestQLog_Seek(t *testing.T) {
q := newTestQLogFileData(t, data)

_, depth, err := q.seekTS(timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano())
ts := timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano()
_, depth, err := q.seekTS(ctx, logger, ts)
require.Truef(t, errors.Is(err, tc.wantErr), "%v", err)

assert.Equal(t, tc.wantDepth, depth)
})
}
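Every updated test obtains a no-op logger and a deadline-bound context before calling the reworked functions. A small sketch of just that setup in isolation; the `testTimeout` constant below is a stand-in for the package's own value, which is not shown in this diff:

```go
package querylog_test

import (
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
)

// testTimeout stands in for the package-level constant the real tests use.
const testTimeout = 1 * time.Second

func TestSetupPattern(t *testing.T) {
	// A logger that drops every record, so test output stays clean.
	logger := slogutil.NewDiscardLogger()

	// A context that is canceled when the timeout elapses or the test ends.
	ctx := testutil.ContextWithTimeout(t, testTimeout)

	logger.DebugContext(ctx, "ready to call seekTS and friends")
}
```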
@@ -1,12 +1,14 @@
package querylog

import (
"context"
"fmt"
"io"
"log/slog"
"os"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)

// qLogReader allows reading from multiple query log files in the reverse

@@ -16,6 +18,10 @@ import (
// pointer to a particular query log file, and to a specific position in this
// file, and it reads lines in reverse order starting from that position.
type qLogReader struct {
// logger is used for logging the operation of the query log reader. It
// must not be nil.
logger *slog.Logger

// qFiles is an array with the query log files. The order is from oldest
// to newest.
qFiles []*qLogFile

@@ -25,7 +31,7 @@ type qLogReader struct {
}

// newQLogReader initializes a qLogReader instance with the specified files.
func newQLogReader(files []string) (*qLogReader, error) {
func newQLogReader(ctx context.Context, logger *slog.Logger, files []string) (*qLogReader, error) {
qFiles := make([]*qLogFile, 0)

for _, f := range files {

@@ -38,7 +44,7 @@ func newQLogReader(files []string) (*qLogReader, error) {
// Close what we've already opened.
cErr := closeQFiles(qFiles)
if cErr != nil {
log.Debug("querylog: closing files: %s", cErr)
logger.DebugContext(ctx, "closing files", slogutil.KeyError, cErr)
}

return nil, err

@@ -47,16 +53,20 @@ func newQLogReader(files []string) (*qLogReader, error) {
qFiles = append(qFiles, q)
}

return &qLogReader{qFiles: qFiles, currentFile: len(qFiles) - 1}, nil
return &qLogReader{
logger: logger,
qFiles: qFiles,
currentFile: len(qFiles) - 1,
}, nil
}

// seekTS performs binary search of a query log record with the specified
// timestamp. If the record is found, it sets qLogReader's position to point
// to that line, so that the next ReadNext call returned this line.
func (r *qLogReader) seekTS(timestamp int64) (err error) {
func (r *qLogReader) seekTS(ctx context.Context, timestamp int64) (err error) {
for i := len(r.qFiles) - 1; i >= 0; i-- {
q := r.qFiles[i]
_, _, err = q.seekTS(timestamp)
_, _, err = q.seekTS(ctx, r.logger, timestamp)
if err != nil {
if errors.Is(err, errTSTooEarly) {
// Look at the next file, since we've reached the end of this
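The reader now receives its logger once, in the constructor, and keeps it in a field so that methods can log with whatever context the caller passes in. A toy standalone sketch of that shape; the type, field, and file names below are illustrative, not the actual querylog types:

```go
package main

import (
	"context"
	"log/slog"
	"os"
)

// reader is a toy analogue of the reworked qLogReader: the constructor takes
// the logger, stores it, and methods reuse it with the caller's context.
type reader struct {
	logger *slog.Logger
	files  []string
}

func newReader(ctx context.Context, logger *slog.Logger, files []string) *reader {
	logger.DebugContext(ctx, "opening reader", "files", len(files))

	return &reader{logger: logger, files: files}
}

func (r *reader) close(ctx context.Context) {
	r.logger.DebugContext(ctx, "closing files", "files", len(r.files))
}

func main() {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))

	r := newReader(ctx, logger, []string{"querylog.json.1", "querylog.json"})
	r.close(ctx)
}
```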
@@ -5,6 +5,7 @@ import (
"testing"
"time"

"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

@@ -17,8 +18,11 @@ func newTestQLogReader(t *testing.T, filesNum, linesNum int) (reader *qLogReader
testFiles := prepareTestFiles(t, filesNum, linesNum)

logger := slogutil.NewDiscardLogger()
ctx := testutil.ContextWithTimeout(t, testTimeout)

// Create the new qLogReader instance.
reader, err := newQLogReader(testFiles)
reader, err := newQLogReader(ctx, logger, testFiles)
require.NoError(t, err)

assert.NotNil(t, reader)

@@ -73,6 +77,7 @@ func TestQLogReader(t *testing.T) {
func TestQLogReader_Seek(t *testing.T) {
r := newTestQLogReader(t, 2, 10000)
ctx := testutil.ContextWithTimeout(t, testTimeout)

testCases := []struct {
want error

@@ -113,7 +118,7 @@ func TestQLogReader_Seek(t *testing.T) {
ts, err := time.Parse(time.RFC3339Nano, tc.time)
require.NoError(t, err)

err = r.seekTS(ts.UnixNano())
err = r.seekTS(ctx, ts.UnixNano())
assert.ErrorIs(t, err, tc.want)
})
}
@@ -2,6 +2,7 @@ package querylog

import (
"fmt"
"log/slog"
"net"
"path/filepath"
"sync"

@@ -12,20 +13,19 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/service"
"github.com/miekg/dns"
)

// QueryLog - main interface
// QueryLog is the query log interface for use by other packages.
type QueryLog interface {
Start()
// Interface starts and stops the query log.
service.Interface

// Close query log object
Close()

// Add a log entry
// Add adds a log entry.
Add(params *AddParams)

// WriteDiskConfig - write configuration
// WriteDiskConfig writes the query log configuration to c.
WriteDiskConfig(c *Config)

// ShouldLog returns true if request for the host should be logged.

@@ -36,6 +36,10 @@ type QueryLog interface {
//
// Do not alter any fields of this structure after using it.
type Config struct {
// Logger is used for logging the operation of the query log. It must not
// be nil.
Logger *slog.Logger

// Ignored contains the list of host names, which should not be written to
// log, and matches them.
Ignored *aghnet.IgnoreEngine

@@ -151,6 +155,7 @@ func newQueryLog(conf Config) (l *queryLog, err error) {
}

l = &queryLog{
logger: conf.Logger,
findClient: findClient,

buffer: container.NewRingBuffer[*logEntry](memSize),
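The new `Config.Logger` field is documented as "must not be nil". A hedged, standalone sketch of that convention, with a constructor that makes the requirement explicit; all names below are illustrative and not the querylog API:

```go
package main

import (
	"errors"
	"fmt"
	"log/slog"
	"os"
)

// Config mirrors the shape of a configuration struct with a required logger.
type Config struct {
	// Logger is used for logging the operation of the component.  It must
	// not be nil.
	Logger *slog.Logger
}

// Service is a placeholder component that keeps the logger it was given.
type Service struct {
	logger *slog.Logger
}

// New validates the configuration, making the "must not be nil" contract
// explicit instead of failing later with a nil dereference.
func New(conf *Config) (*Service, error) {
	if conf == nil || conf.Logger == nil {
		return nil, errors.New("config: logger must not be nil")
	}

	return &Service{logger: conf.Logger}, nil
}

func main() {
	svc, err := New(&Config{Logger: slog.New(slog.NewTextHandler(os.Stdout, nil))})
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	svc.logger.Info("configured")
}
```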
@@ -2,6 +2,7 @@ package querylog

import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"

@@ -9,28 +10,30 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/c2h5oh/datasize"
)

// flushLogBuffer flushes the current buffer to file and resets the current
// buffer.
func (l *queryLog) flushLogBuffer() (err error) {
func (l *queryLog) flushLogBuffer(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "flushing log buffer: %w") }()

l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock()

b, err := l.encodeEntries()
b, err := l.encodeEntries(ctx)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
}

return l.flushToFile(b)
return l.flushToFile(ctx, b)
}

// encodeEntries returns JSON encoded log entries, logs estimated time, clears
// the log buffer.
func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
func (l *queryLog) encodeEntries(ctx context.Context) (b *bytes.Buffer, err error) {
l.bufferLock.Lock()
defer l.bufferLock.Unlock()

@@ -55,8 +58,17 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
return nil, err
}

size := b.Len()
elapsed := time.Since(start)
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", bufLen, elapsed, b.Len()/1024, float64(b.Len())/float64(bufLen), elapsed/time.Duration(bufLen))
l.logger.DebugContext(
ctx,
"serialized elements via json",
"count", bufLen,
"elapsed", elapsed,
"size", datasize.ByteSize(size),
"size_per_entry", datasize.ByteSize(float64(size)/float64(bufLen)),
"time_per_entry", elapsed/time.Duration(bufLen),
)

l.buffer.Clear()
l.flushPending = false

@@ -65,7 +77,7 @@ func (l *queryLog) encodeEntries() (b *bytes.Buffer, err error) {
}

// flushToFile saves the encoded log entries to the query log file.
func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
func (l *queryLog) flushToFile(ctx context.Context, b *bytes.Buffer) (err error) {
l.fileWriteLock.Lock()
defer l.fileWriteLock.Unlock()

@@ -83,19 +95,19 @@ func (l *queryLog) flushToFile(b *bytes.Buffer) (err error) {
return fmt.Errorf("writing to file %q: %w", filename, err)
}

log.Debug("querylog: ok %q: %v bytes written", filename, n)
l.logger.DebugContext(ctx, "flushed to file", "file", filename, "size", datasize.ByteSize(n))

return nil
}

func (l *queryLog) rotate() error {
func (l *queryLog) rotate(ctx context.Context) error {
from := l.logFile
to := l.logFile + ".1"

err := os.Rename(from, to)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
log.Debug("querylog: no log to rotate")
l.logger.DebugContext(ctx, "no log to rotate")

return nil
}

@@ -103,12 +115,12 @@ func (l *queryLog) rotate() error {
return fmt.Errorf("failed to rename old file: %w", err)
}

log.Debug("querylog: renamed %s into %s", from, to)
l.logger.DebugContext(ctx, "renamed log file", "from", from, "to", to)

return nil
}

func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) {
func (l *queryLog) readFileFirstTimeValue(ctx context.Context) (first time.Time, err error) {
var f *os.File
f, err = os.Open(l.logFile)
if err != nil {

@@ -130,15 +142,15 @@ func (l *queryLog) readFileFirstTimeValue() (first time.Time, err error) {
return time.Time{}, err
}

log.Debug("querylog: the oldest log entry: %s", val)
l.logger.DebugContext(ctx, "oldest log entry", "entry_time", val)

return t, nil
}

func (l *queryLog) periodicRotate() {
defer log.OnPanic("querylog: rotating")
func (l *queryLog) periodicRotate(ctx context.Context) {
defer slogutil.RecoverAndLog(ctx, l.logger)

l.checkAndRotate()
l.checkAndRotate(ctx)

// rotationCheckIvl is the period of time between checking the need for
// rotating log files. It's smaller of any available rotation interval to

@@ -151,13 +163,13 @@ func (l *queryLog) periodicRotate() {
defer rotations.Stop()

for range rotations.C {
l.checkAndRotate()
l.checkAndRotate(ctx)
}
}

// checkAndRotate rotates log files if those are older than the specified
// rotation interval.
func (l *queryLog) checkAndRotate() {
func (l *queryLog) checkAndRotate(ctx context.Context) {
var rotationIvl time.Duration
func() {
l.confMu.RLock()

@@ -166,29 +178,30 @@ func (l *queryLog) checkAndRotate() {
rotationIvl = l.conf.RotationIvl
}()

oldest, err := l.readFileFirstTimeValue()
oldest, err := l.readFileFirstTimeValue(ctx)
if err != nil && !errors.Is(err, os.ErrNotExist) {
log.Error("querylog: reading oldest record for rotation: %s", err)
l.logger.ErrorContext(ctx, "reading oldest record for rotation", slogutil.KeyError, err)

return
}

if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) {
log.Debug(
"querylog: %s <= %s, not rotating",
now.Format(time.RFC3339),
rotTime.Format(time.RFC3339),
l.logger.DebugContext(
ctx,
"not rotating",
"now", now.Format(time.RFC3339),
"rotate_time", rotTime.Format(time.RFC3339),
)

return
}

err = l.rotate()
err = l.rotate(ctx)
if err != nil {
log.Error("querylog: rotating: %s", err)
l.logger.ErrorContext(ctx, "rotating", slogutil.KeyError, err)

return
}

log.Debug("querylog: rotated successfully")
l.logger.DebugContext(ctx, "rotated successfully")
}
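`periodicRotate` now defers `slogutil.RecoverAndLog` instead of `log.OnPanic`, so a panic in the rotation goroutine is recovered and reported through the same structured logger. A small standalone sketch of that behavior, assuming only the `RecoverAndLog(ctx, logger)` call seen in the diff; the function and message names are otherwise illustrative:

```go
package main

import (
	"context"
	"log/slog"
	"os"

	"github.com/AdguardTeam/golibs/logutil/slogutil"
)

// rotateLoop panics on purpose; the deferred slogutil.RecoverAndLog call
// recovers the panic and reports it through the structured logger instead of
// letting it crash the process.
func rotateLoop(ctx context.Context, logger *slog.Logger) {
	defer slogutil.RecoverAndLog(ctx, logger)

	panic("simulated rotation failure")
}

func main() {
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
	rotateLoop(context.Background(), logger)

	logger.Info("still running after the recovered panic")
}
```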
@@ -1,13 +1,15 @@
package querylog

import (
"context"
"fmt"
"io"
"log/slog"
"slices"
"time"

"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)

// client finds the client info, if any, by its ClientID and IP address,

@@ -48,7 +50,11 @@ func (l *queryLog) client(clientID, ip string, cache clientCache) (c *Client, er
// buffer. It optionally uses the client cache, if provided. It also returns
// the total amount of records in the buffer at the moment of searching.
// l.confMu is expected to be locked.
func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entries []*logEntry, total int) {
func (l *queryLog) searchMemory(
ctx context.Context,
params *searchParams,
cache clientCache,
) (entries []*logEntry, total int) {
// Check memory size, as the buffer can contain a single log record. See
// [newQueryLog].
if l.conf.MemSize == 0 {

@@ -66,9 +72,14 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
var err error
e.client, err = l.client(e.ClientID, e.IP.String(), cache)
if err != nil {
msg := "querylog: enriching memory record at time %s" +
" for client %q (clientid %q): %s"
log.Error(msg, e.Time, e.IP, e.ClientID, err)
l.logger.ErrorContext(
ctx,
"enriching memory record",
"at", e.Time,
"client_ip", e.IP,
"client_id", e.ClientID,
slogutil.KeyError, err,
)

// Go on and try to match anyway.
}

@@ -86,7 +97,10 @@ func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entrie
// search searches log entries in memory buffer and log file using specified
// parameters and returns the list of entries found and the time of the oldest
// entry. l.confMu is expected to be locked.
func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest time.Time) {
func (l *queryLog) search(
ctx context.Context,
params *searchParams,
) (entries []*logEntry, oldest time.Time) {
start := time.Now()

if params.limit == 0 {

@@ -95,11 +109,11 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
cache := clientCache{}

memoryEntries, bufLen := l.searchMemory(params, cache)
log.Debug("querylog: got %d entries from memory", len(memoryEntries))
memoryEntries, bufLen := l.searchMemory(ctx, params, cache)
l.logger.DebugContext(ctx, "got entries from memory", "count", len(memoryEntries))

fileEntries, oldest, total := l.searchFiles(params, cache)
log.Debug("querylog: got %d entries from files", len(fileEntries))
fileEntries, oldest, total := l.searchFiles(ctx, params, cache)
l.logger.DebugContext(ctx, "got entries from files", "count", len(fileEntries))

total += bufLen

@@ -134,12 +148,13 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
oldest = entries[len(entries)-1].Time
}

log.Debug(
"querylog: prepared data (%d/%d) older than %s in %s",
len(entries),
total,
params.olderThan,
time.Since(start),
l.logger.DebugContext(
ctx,
"prepared data",
"count", len(entries),
"total", total,
"older_than", params.olderThan,
"elapsed", time.Since(start),
)

return entries, oldest

@@ -147,12 +162,12 @@ func (l *queryLog) search(params *searchParams) (entries []*logEntry, oldest tim
// seekRecord changes the current position to the next record older than the
// provided parameter.
func (r *qLogReader) seekRecord(olderThan time.Time) (err error) {
func (r *qLogReader) seekRecord(ctx context.Context, olderThan time.Time) (err error) {
if olderThan.IsZero() {
return r.SeekStart()
}

err = r.seekTS(olderThan.UnixNano())
err = r.seekTS(ctx, olderThan.UnixNano())
if err == nil {
// Read to the next record, because we only need the one that goes
// after it.

@@ -164,21 +179,24 @@ func (r *qLogReader) seekRecord(olderThan time.Time) (err error) {
// setQLogReader creates a reader with the specified files and sets the
// position to the next record older than the provided parameter.
func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error) {
func (l *queryLog) setQLogReader(
ctx context.Context,
olderThan time.Time,
) (qr *qLogReader, err error) {
files := []string{
l.logFile + ".1",
l.logFile,
}

r, err := newQLogReader(files)
r, err := newQLogReader(ctx, l.logger, files)
if err != nil {
return nil, fmt.Errorf("opening qlog reader: %s", err)
return nil, fmt.Errorf("opening qlog reader: %w", err)
}

err = r.seekRecord(olderThan)
err = r.seekRecord(ctx, olderThan)
if err != nil {
defer func() { err = errors.WithDeferred(err, r.Close()) }()
log.Debug("querylog: cannot seek to %s: %s", olderThan, err)
l.logger.DebugContext(ctx, "cannot seek", "older_than", olderThan, slogutil.KeyError, err)

return nil, nil
}

@@ -191,13 +209,14 @@ func (l *queryLog) setQLogReader(olderThan time.Time) (qr *qLogReader, err error
// calls faster so that the UI could handle it and show something quicker.
// This behavior can be overridden if maxFileScanEntries is set to 0.
func (l *queryLog) readEntries(
ctx context.Context,
r *qLogReader,
params *searchParams,
cache clientCache,
totalLimit int,
) (entries []*logEntry, oldestNano int64, total int) {
for total < params.maxFileScanEntries || params.maxFileScanEntries <= 0 {
ent, ts, rErr := l.readNextEntry(r, params, cache)
ent, ts, rErr := l.readNextEntry(ctx, r, params, cache)
if rErr != nil {
if rErr == io.EOF {
oldestNano = 0

@@ -205,7 +224,7 @@ func (l *queryLog) readEntries(
break
}

log.Error("querylog: reading next entry: %s", rErr)
l.logger.ErrorContext(ctx, "reading next entry", slogutil.KeyError, rErr)
}

oldestNano = ts

@@ -231,12 +250,13 @@ func (l *queryLog) readEntries(
// and the total number of processed entries, including discarded ones,
// correspondingly.
func (l *queryLog) searchFiles(
ctx context.Context,
params *searchParams,
cache clientCache,
) (entries []*logEntry, oldest time.Time, total int) {
r, err := l.setQLogReader(params.olderThan)
r, err := l.setQLogReader(ctx, params.olderThan)
if err != nil {
log.Error("querylog: %s", err)
l.logger.ErrorContext(ctx, "searching files", slogutil.KeyError, err)
}

if r == nil {

@@ -245,12 +265,12 @@ func (l *queryLog) searchFiles(
defer func() {
if closeErr := r.Close(); closeErr != nil {
log.Error("querylog: closing file: %s", closeErr)
l.logger.ErrorContext(ctx, "closing files", slogutil.KeyError, closeErr)
}
}()

totalLimit := params.offset + params.limit
entries, oldestNano, total := l.readEntries(r, params, cache, totalLimit)
entries, oldestNano, total := l.readEntries(ctx, r, params, cache, totalLimit)
if oldestNano != 0 {
oldest = time.Unix(0, oldestNano)
}

@@ -266,15 +286,21 @@ type quickMatchClientFinder struct {
}

// findClient is a method that can be used as a quickMatchClientFinder.
func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) {
func (f quickMatchClientFinder) findClient(
ctx context.Context,
logger *slog.Logger,
clientID string,
ip string,
) (c *Client) {
var err error
c, err = f.client(clientID, ip, f.cache)
if err != nil {
log.Error(
"querylog: enriching file record for quick search: for client %q (clientid %q): %s",
ip,
clientID,
err,
logger.ErrorContext(
ctx,
"enriching file record for quick search",
"client_ip", ip,
"client_id", clientID,
slogutil.KeyError, err,
)
}

@@ -286,6 +312,7 @@ func (f quickMatchClientFinder) findClient(clientID, ip string) (c *Client) {
// the entry doesn't match the search criteria. ts is the timestamp of the
// processed entry.
func (l *queryLog) readNextEntry(
ctx context.Context,
r *qLogReader,
params *searchParams,
cache clientCache,

@@ -301,14 +328,14 @@ func (l *queryLog) readNextEntry(
cache: cache,
}

if !params.quickMatch(line, clientFinder.findClient) {
ts = readQLogTimestamp(line)
if !params.quickMatch(ctx, l.logger, line, clientFinder.findClient) {
ts = readQLogTimestamp(ctx, l.logger, line)

return nil, ts, nil
}

e = &logEntry{}
decodeLogEntry(e, line)
l.decodeLogEntry(ctx, e, line)

if l.isIgnored(e.QHost) {
return nil, ts, nil

@@ -316,12 +343,13 @@ func (l *queryLog) readNextEntry(
e.client, err = l.client(e.ClientID, e.IP.String(), cache)
if err != nil {
log.Error(
"querylog: enriching file record at time %s for client %q (clientid %q): %s",
e.Time,
e.IP,
e.ClientID,
err,
l.logger.ErrorContext(
ctx,
"enriching file record",
"at", e.Time,
"client_ip", e.IP,
"client_id", e.ClientID,
slogutil.KeyError, err,
)

// Go on and try to match anyway.
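The search code keeps the existing `errors.WithDeferred(err, r.Close())` idiom so that a failure while closing the reader is not silently dropped. A small standalone sketch of that helper as it appears in the diff; the file name and surrounding function are illustrative only:

```go
package main

import (
	"fmt"
	"io"
	"os"

	"github.com/AdguardTeam/golibs/errors"
)

// readAll shows the deferred-close idiom: errors.WithDeferred merges the
// function's main error with the error from the deferred Close call.
func readAll(name string) (data []byte, err error) {
	f, err := os.Open(name)
	if err != nil {
		return nil, err
	}
	defer func() { err = errors.WithDeferred(err, f.Close()) }()

	return io.ReadAll(f)
}

func main() {
	_, err := readAll("nonexistent.json")
	fmt.Println(err)
}
```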
@@ -5,6 +5,8 @@ import (
"testing"
"time"

"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"

@@ -36,6 +38,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
}

l, err := newQueryLog(Config{
Logger: slogutil.NewDiscardLogger(),
FindClient: findClient,
BaseDir: t.TempDir(),
RotationIvl: timeutil.Day,

@@ -45,7 +48,11 @@ func TestQueryLog_Search_findClient(t *testing.T) {
AnonymizeClientIP: false,
})
require.NoError(t, err)
t.Cleanup(l.Close)

ctx := testutil.ContextWithTimeout(t, testTimeout)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return l.Shutdown(ctx)
})

q := &dns.Msg{
Question: []dns.Question{{

@@ -81,7 +88,7 @@ func TestQueryLog_Search_findClient(t *testing.T) {
olderThan: time.Now().Add(10 * time.Second),
limit: 3,
}
entries, _ := l.search(sp)
entries, _ := l.search(ctx, sp)
assert.Equal(t, 2, findClientCalls)

require.Len(t, entries, 3)
@@ -1,7 +1,9 @@
package querylog

import (
"context"
"fmt"
"log/slog"
"strings"

"github.com/AdguardTeam/AdGuardHome/internal/filtering"

@@ -87,7 +89,12 @@ func ctDomainOrClientCaseNonStrict(
// quickMatch quickly checks if the line matches the given search criterion.
// It returns false if the line doesn't match. This method is only here for
// optimization purposes.
func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) {
func (c *searchCriterion) quickMatch(
ctx context.Context,
logger *slog.Logger,
line string,
findClient quickMatchClientFunc,
) (ok bool) {
switch c.criterionType {
case ctTerm:
host := readJSONValue(line, `"QH":"`)

@@ -95,7 +102,7 @@ func (c *searchCriterion) quickMatch(line string, findClient quickMatchClientFun
clientID := readJSONValue(line, `"CID":"`)

var name string
if cli := findClient(clientID, ip); cli != nil {
if cli := findClient(ctx, logger, clientID, ip); cli != nil {
name = cli.Name
}
@@ -1,6 +1,10 @@
package querylog

import "time"
import (
"context"
"log/slog"
"time"
)

// searchParams represent the search query sent by the client.
type searchParams struct {

@@ -35,14 +39,23 @@ func newSearchParams() *searchParams {
}

// quickMatchClientFunc is a simplified client finder for quick matches.
type quickMatchClientFunc = func(clientID, ip string) (c *Client)
type quickMatchClientFunc = func(
ctx context.Context,
logger *slog.Logger,
clientID, ip string,
) (c *Client)

// quickMatch quickly checks if the line matches the given search parameters.
// It returns false if the line doesn't match. This method is only here for
// optimization purposes.
func (s *searchParams) quickMatch(
ctx context.Context,
logger *slog.Logger,
line string,
findClient quickMatchClientFunc,
) (ok bool) {
func (s *searchParams) quickMatch(line string, findClient quickMatchClientFunc) (ok bool) {
for _, c := range s.searchCriteria {
if !c.quickMatch(line, findClient) {
if !c.quickMatch(ctx, logger, line, findClient) {
return false
}
}
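`quickMatchClientFunc` is now a callback that receives the context and logger from its caller rather than relying on a package-level logger. A standalone sketch of that shape; the type name, result, and client values below are illustrative only:

```go
package main

import (
	"context"
	"fmt"
	"log/slog"
	"os"
)

// matchFunc mirrors the reshaped callback: context and logger travel with
// every invocation.
type matchFunc = func(ctx context.Context, logger *slog.Logger, clientID, ip string) (name string)

func runMatch(ctx context.Context, logger *slog.Logger, find matchFunc) {
	name := find(ctx, logger, "client-1", "192.0.2.1")
	fmt.Println("matched:", name)
}

func main() {
	ctx := context.Background()
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))

	runMatch(ctx, logger, func(ctx context.Context, logger *slog.Logger, clientID, ip string) (name string) {
		logger.DebugContext(ctx, "looking up client", "client_id", clientID, "client_ip", ip)

		return "example-client"
	})
}
```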