all: imp code

This commit is contained in:
Stanislav Chzhen 2024-11-18 16:39:35 +03:00
parent cbb993f7ae
commit ef715c58cb
8 changed files with 49 additions and 23 deletions

View File

@ -471,7 +471,11 @@ func startDNSServer() error {
Context.filters.Start() Context.filters.Start()
Context.stats.Start() Context.stats.Start()
Context.queryLog.Start(ctx)
err = Context.queryLog.Start(ctx)
if err != nil {
return fmt.Errorf("starting query log: %w", err)
}
return nil return nil
} }
@ -533,7 +537,10 @@ func closeDNSServer() {
if Context.queryLog != nil { if Context.queryLog != nil {
// TODO(s.chzhen): Pass context. // TODO(s.chzhen): Pass context.
Context.queryLog.Shutdown(context.TODO()) err := Context.queryLog.Shutdown(context.TODO())
if err != nil {
log.Debug("closing query log: %s", err)
}
} }
log.Debug("all dns modules are closed") log.Debug("all dns modules are closed")

View File

@ -681,8 +681,8 @@ func (l *queryLog) resultDecHandler(
name string, name string,
dec *json.Decoder, dec *json.Decoder,
ent *logEntry, ent *logEntry,
) (found bool) { ) (ok bool) {
found = true ok = true
switch name { switch name {
case "ReverseHosts": case "ReverseHosts":
l.decodeResultReverseHosts(ctx, dec, ent) l.decodeResultReverseHosts(ctx, dec, ent)
@ -693,10 +693,10 @@ func (l *queryLog) resultDecHandler(
case "DNSRewriteResult": case "DNSRewriteResult":
l.decodeResultDNSRewriteResult(ctx, dec, ent) l.decodeResultDNSRewriteResult(ctx, dec, ent)
default: default:
found = false ok = false
} }
return found return ok
} }
// decodeLogEntry decodes string str to logEntry ent. // decodeLogEntry decodes string str to logEntry ent.

View File

@ -86,25 +86,30 @@ func NewClientProto(s string) (cp ClientProto, err error) {
var _ QueryLog = (*queryLog)(nil) var _ QueryLog = (*queryLog)(nil)
// Start implements the [QueryLog] interface for *queryLog. // Start implements the [QueryLog] interface for *queryLog.
func (l *queryLog) Start(ctx context.Context) { func (l *queryLog) Start(ctx context.Context) (err error) {
if l.conf.HTTPRegister != nil { if l.conf.HTTPRegister != nil {
l.initWeb() l.initWeb()
} }
go l.periodicRotate(ctx) go l.periodicRotate(ctx)
return nil
} }
// Shutdown implements the [QueryLog] interface for *queryLog. // Shutdown implements the [QueryLog] interface for *queryLog.
func (l *queryLog) Shutdown(ctx context.Context) { func (l *queryLog) Shutdown(ctx context.Context) (err error) {
l.confMu.RLock() l.confMu.RLock()
defer l.confMu.RUnlock() defer l.confMu.RUnlock()
if l.conf.FileEnabled { if l.conf.FileEnabled {
err := l.flushLogBuffer(ctx) err = l.flushLogBuffer(ctx)
if err != nil { if err != nil {
l.logger.ErrorContext(ctx, "closing", slogutil.KeyError, err) // Don't wrap the error because it's informative enough as is.
return err
} }
} }
return nil
} }
func checkInterval(ivl time.Duration) (ok bool) { func checkInterval(ivl time.Duration) (ok bool) {

View File

@ -34,12 +34,15 @@ func TestQueryLog(t *testing.T) {
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// Write to disk (first file). // Write to disk (first file).
require.NoError(t, l.flushLogBuffer(ctx)) require.NoError(t, l.flushLogBuffer(ctx))
// Start writing to the second file. // Start writing to the second file.
require.NoError(t, l.rotate(ctx)) require.NoError(t, l.rotate(ctx))
// Add disk entries. // Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// Write to disk. // Write to disk.
require.NoError(t, l.flushLogBuffer(ctx)) require.NoError(t, l.flushLogBuffer(ctx))
// Add memory entries. // Add memory entries.
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
@ -121,6 +124,7 @@ func TestQueryLog(t *testing.T) {
entries, _ := l.search(ctx, params) entries, _ := l.search(ctx, params)
require.Len(t, entries, len(tc.want)) require.Len(t, entries, len(tc.want))
for _, want := range tc.want { for _, want := range tc.want {
assertLogEntry(t, entries[want.num], want.host, want.answer, want.client) assertLogEntry(t, entries[want.num], want.host, want.answer, want.client)
} }
@ -152,6 +156,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
} }
// Write them to the first file. // Write them to the first file.
require.NoError(t, l.flushLogBuffer(ctx)) require.NoError(t, l.flushLogBuffer(ctx))
// Add more to the in-memory part of log. // Add more to the in-memory part of log.
for range entNum { for range entNum {
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
@ -196,7 +201,6 @@ func TestQueryLogOffsetLimit(t *testing.T) {
params.offset = tc.offset params.offset = tc.offset
params.limit = tc.limit params.limit = tc.limit
entries, _ := l.search(ctx, params) entries, _ := l.search(ctx, params)
require.Len(t, entries, tc.wantLen) require.Len(t, entries, tc.wantLen)
if tc.wantLen > 0 { if tc.wantLen > 0 {
@ -229,7 +233,6 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
require.NoError(t, l.flushLogBuffer(ctx)) require.NoError(t, l.flushLogBuffer(ctx))
params := newSearchParams() params := newSearchParams()
for _, maxFileScanEntries := range []int{5, 0} { for _, maxFileScanEntries := range []int{5, 0} {
t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) { t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) {
params.maxFileScanEntries = maxFileScanEntries params.maxFileScanEntries = maxFileScanEntries
@ -259,6 +262,7 @@ func TestQueryLogFileDisabled(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout) ctx := testutil.ContextWithTimeout(t, testTimeout)
ll, _ := l.search(ctx, params) ll, _ := l.search(ctx, params)
require.Len(t, ll, 2) require.Len(t, ll, 2)
assert.Equal(t, "example3.org", ll[0].QHost) assert.Equal(t, "example3.org", ll[0].QHost)
assert.Equal(t, "example2.org", ll[1].QHost) assert.Equal(t, "example2.org", ll[1].QHost)
} }

View File

@ -25,6 +25,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) {
f, err := os.CreateTemp(dir, "*.txt") f, err := os.CreateTemp(dir, "*.txt")
require.NoError(t, err) require.NoError(t, err)
// Use defer and not t.Cleanup to make sure that the file is closed // Use defer and not t.Cleanup to make sure that the file is closed
// after this function is done. // after this function is done.
defer func() { defer func() {
@ -109,6 +110,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
// Calculate the expected position. // Calculate the expected position.
fileInfo, err := q.file.Stat() fileInfo, err := q.file.Stat()
require.NoError(t, err) require.NoError(t, err)
var expPos int64 var expPos int64
if expPos = fileInfo.Size(); expPos > 0 { if expPos = fileInfo.Size(); expPos > 0 {
expPos-- expPos--
@ -130,6 +132,7 @@ func TestQLogFile_ReadNext(t *testing.T) {
} }
require.Equal(t, io.EOF, err) require.Equal(t, io.EOF, err)
assert.Equal(t, tc.linesNum, read) assert.Equal(t, tc.linesNum, read)
}) })
} }
@ -175,16 +178,19 @@ func TestQLogFile_SeekTS_good(t *testing.T) {
t.Run(l.name+"_"+tc.name, func(t *testing.T) { t.Run(l.name+"_"+tc.name, func(t *testing.T) {
line, err := getQLogFileLine(q, tc.line) line, err := getQLogFileLine(q, tc.line)
require.NoError(t, err) require.NoError(t, err)
ts := readQLogTimestamp(ctx, logger, line) ts := readQLogTimestamp(ctx, logger, line)
assert.NotEqualValues(t, 0, ts) assert.NotEqualValues(t, 0, ts)
// Try seeking to that line now. // Try seeking to that line now.
pos, _, err := q.seekTS(ctx, logger, ts) pos, _, err := q.seekTS(ctx, logger, ts)
require.NoError(t, err) require.NoError(t, err)
assert.NotEqualValues(t, 0, pos) assert.NotEqualValues(t, 0, pos)
testLine, err := q.ReadNext() testLine, err := q.ReadNext()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, line, testLine) assert.Equal(t, line, testLine)
}) })
} }
@ -228,8 +234,8 @@ func TestQLogFile_SeekTS_bad(t *testing.T) {
line, err := getQLogFileLine(q, l.num/2) line, err := getQLogFileLine(q, l.num/2)
require.NoError(t, err) require.NoError(t, err)
testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1
testCases[2].ts = readQLogTimestamp(ctx, logger, line) - 1
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
assert.NotEqualValues(t, 0, tc.ts) assert.NotEqualValues(t, 0, tc.ts)
@ -269,11 +275,13 @@ func TestQLogFile(t *testing.T) {
// Seek to the start. // Seek to the start.
pos, err := q.SeekStart() pos, err := q.SeekStart()
require.NoError(t, err) require.NoError(t, err)
assert.Greater(t, pos, int64(0)) assert.Greater(t, pos, int64(0))
// Read first line. // Read first line.
line, err := q.ReadNext() line, err := q.ReadNext()
require.NoError(t, err) require.NoError(t, err)
assert.Contains(t, line, "0.0.0.2") assert.Contains(t, line, "0.0.0.2")
assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasPrefix(line, "{"), line)
assert.True(t, strings.HasSuffix(line, "}"), line) assert.True(t, strings.HasSuffix(line, "}"), line)
@ -281,6 +289,7 @@ func TestQLogFile(t *testing.T) {
// Read second line. // Read second line.
line, err = q.ReadNext() line, err = q.ReadNext()
require.NoError(t, err) require.NoError(t, err)
assert.EqualValues(t, 0, q.position) assert.EqualValues(t, 0, q.position)
assert.Contains(t, line, "0.0.0.1") assert.Contains(t, line, "0.0.0.1")
assert.True(t, strings.HasPrefix(line, "{"), line) assert.True(t, strings.HasPrefix(line, "{"), line)
@ -289,12 +298,14 @@ func TestQLogFile(t *testing.T) {
// Try reading again (there's nothing to read anymore). // Try reading again (there's nothing to read anymore).
line, err = q.ReadNext() line, err = q.ReadNext()
require.Equal(t, io.EOF, err) require.Equal(t, io.EOF, err)
assert.Empty(t, line) assert.Empty(t, line)
} }
func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) { func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) {
f, err := os.CreateTemp(t.TempDir(), "*.txt") f, err := os.CreateTemp(t.TempDir(), "*.txt")
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, f.Close) testutil.CleanupAndRequireSuccess(t, f.Close)
_, err = f.WriteString(data) _, err = f.WriteString(data)
@ -302,6 +313,7 @@ func newTestQLogFileData(t *testing.T, data string) (file *qLogFile) {
file, err = newQLogFile(f.Name()) file, err = newQLogFile(f.Name())
require.NoError(t, err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, file.Close) testutil.CleanupAndRequireSuccess(t, file.Close)
return file return file
@ -353,6 +365,7 @@ func TestQLog_Seek(t *testing.T) {
ts := timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano() ts := timestamp.Add(time.Second * time.Duration(tc.delta)).UnixNano()
_, depth, err := q.seekTS(ctx, logger, ts) _, depth, err := q.seekTS(ctx, logger, ts)
require.Truef(t, errors.Is(err, tc.wantErr), "%v", err) require.Truef(t, errors.Is(err, tc.wantErr), "%v", err)
assert.Equal(t, tc.wantDepth, depth) assert.Equal(t, tc.wantDepth, depth)
}) })
} }

View File

@ -1,7 +1,6 @@
package querylog package querylog
import ( import (
"context"
"fmt" "fmt"
"log/slog" "log/slog"
"net" "net"
@ -14,16 +13,14 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/service"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// QueryLog is the query log interface for use by other packages. // QueryLog is the query log interface for use by other packages.
type QueryLog interface { type QueryLog interface {
// Start starts the query log. // Interface starts and stops the query log.
Start(ctx context.Context) service.Interface
// Shutdown stops the query log.
Shutdown(ctx context.Context)
// Add adds a log entry. // Add adds a log entry.
Add(params *AddParams) Add(params *AddParams)

View File

@ -190,7 +190,7 @@ func (l *queryLog) setQLogReader(
r, err := newQLogReader(ctx, l.logger, files) r, err := newQLogReader(ctx, l.logger, files)
if err != nil { if err != nil {
return nil, fmt.Errorf("opening qlog reader: %s", err) return nil, fmt.Errorf("opening qlog reader: %w", err)
} }
err = r.seekRecord(ctx, olderThan) err = r.seekRecord(ctx, olderThan)
@ -345,7 +345,7 @@ func (l *queryLog) readNextEntry(
if err != nil { if err != nil {
l.logger.ErrorContext( l.logger.ErrorContext(
ctx, ctx,
"enriching file record at time", "enriching file record",
"at", e.Time, "at", e.Time,
"client_ip", e.IP, "client_ip", e.IP,
"client_id", e.ClientID, "client_id", e.ClientID,

View File

@ -50,8 +50,8 @@ func TestQueryLog_Search_findClient(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
ctx := testutil.ContextWithTimeout(t, testTimeout) ctx := testutil.ContextWithTimeout(t, testTimeout)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, func() (err error) {
l.Shutdown(ctx) return l.Shutdown(ctx)
}) })
q := &dns.Msg{ q := &dns.Msg{