Pull request 1797: AG-21072-querylog-conf-race

Merge in DNS/adguard-home from AG-21072-querylog-conf-race to master

Squashed commit of the following:

commit fcb14353ee63f582986e18affebf5ed965c3bfc7
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Apr 3 15:04:03 2023 +0300

    querylog: imp code, docs

commit 9070bc1d4eee5efc5795466b2c2a40c6ab495e68
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon Apr 3 13:58:56 2023 +0300

    querylog: fix races
This commit is contained in:
Ainar Garipov 2023-04-03 16:29:07 +03:00
parent 3575aa0570
commit 2a0d062947
6 changed files with 220 additions and 199 deletions

View File

@ -166,86 +166,6 @@ var logEntryHandlers = map[string]logEntryHandler{
}, },
} }
var resultHandlers = map[string]logEntryHandler{
"IsFiltered": func(t json.Token, ent *logEntry) error {
v, ok := t.(bool)
if !ok {
return nil
}
ent.Result.IsFiltered = v
return nil
},
"Rule": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
l := len(ent.Result.Rules)
if l == 0 {
ent.Result.Rules = []*filtering.ResultRule{{}}
l++
}
ent.Result.Rules[l-1].Text = s
return nil
},
"FilterID": func(t json.Token, ent *logEntry) error {
n, ok := t.(json.Number)
if !ok {
return nil
}
i, err := n.Int64()
if err != nil {
return err
}
l := len(ent.Result.Rules)
if l == 0 {
ent.Result.Rules = []*filtering.ResultRule{{}}
l++
}
ent.Result.Rules[l-1].FilterListID = i
return nil
},
"Reason": func(t json.Token, ent *logEntry) error {
v, ok := t.(json.Number)
if !ok {
return nil
}
i, err := v.Int64()
if err != nil {
return err
}
ent.Result.Reason = filtering.Reason(i)
return nil
},
"ServiceName": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
ent.Result.ServiceName = s
return nil
},
"CanonName": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
ent.Result.CanonName = s
return nil
},
}
func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) { func decodeResultRuleKey(key string, i int, dec *json.Decoder, ent *logEntry) {
var vToken json.Token var vToken json.Token
switch key { switch key {
@ -582,25 +502,11 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
return return
} }
switch key { decHandler, ok := resultDecHandlers[key]
case "ReverseHosts": if ok {
decodeResultReverseHosts(dec, ent) decHandler(dec, ent)
continue continue
case "IPList":
decodeResultIPList(dec, ent)
continue
case "Rules":
decodeResultRules(dec, ent)
continue
case "DNSRewriteResult":
decodeResultDNSRewriteResult(dec, ent)
continue
default:
// Go on.
} }
handler, ok := resultHandlers[key] handler, ok := resultHandlers[key]
@ -621,6 +527,93 @@ func decodeResult(dec *json.Decoder, ent *logEntry) {
} }
} }
var resultHandlers = map[string]logEntryHandler{
"IsFiltered": func(t json.Token, ent *logEntry) error {
v, ok := t.(bool)
if !ok {
return nil
}
ent.Result.IsFiltered = v
return nil
},
"Rule": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
l := len(ent.Result.Rules)
if l == 0 {
ent.Result.Rules = []*filtering.ResultRule{{}}
l++
}
ent.Result.Rules[l-1].Text = s
return nil
},
"FilterID": func(t json.Token, ent *logEntry) error {
n, ok := t.(json.Number)
if !ok {
return nil
}
i, err := n.Int64()
if err != nil {
return err
}
l := len(ent.Result.Rules)
if l == 0 {
ent.Result.Rules = []*filtering.ResultRule{{}}
l++
}
ent.Result.Rules[l-1].FilterListID = i
return nil
},
"Reason": func(t json.Token, ent *logEntry) error {
v, ok := t.(json.Number)
if !ok {
return nil
}
i, err := v.Int64()
if err != nil {
return err
}
ent.Result.Reason = filtering.Reason(i)
return nil
},
"ServiceName": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
ent.Result.ServiceName = s
return nil
},
"CanonName": func(t json.Token, ent *logEntry) error {
s, ok := t.(string)
if !ok {
return nil
}
ent.Result.CanonName = s
return nil
},
}
var resultDecHandlers = map[string]func(dec *json.Decoder, ent *logEntry){
"ReverseHosts": decodeResultReverseHosts,
"IPList": decodeResultIPList,
"Rules": decodeResultRules,
"DNSRewriteResult": decodeResultDNSRewriteResult,
}
func decodeLogEntry(ent *logEntry, str string) { func decodeLogEntry(ent *logEntry, str string) {
dec := json.NewDecoder(strings.NewReader(str)) dec := json.NewDecoder(strings.NewReader(str))
dec.UseNumber() dec.UseNumber()

View File

@ -0,0 +1,70 @@
package querylog
import (
"net"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// logEntry represents a single entry in the file.
type logEntry struct {
// client is the found client information, if any.
client *Client
Time time.Time `json:"T"`
QHost string `json:"QH"`
QType string `json:"QT"`
QClass string `json:"QC"`
ReqECS string `json:"ECS,omitempty"`
ClientID string `json:"CID,omitempty"`
ClientProto ClientProto `json:"CP"`
Upstream string `json:",omitempty"`
Answer []byte `json:",omitempty"`
OrigAnswer []byte `json:",omitempty"`
IP net.IP `json:"IP"`
Result filtering.Result
Elapsed time.Duration
Cached bool `json:",omitempty"`
AuthenticatedData bool `json:"AD,omitempty"`
}
// shallowClone returns a shallow clone of e.
func (e *logEntry) shallowClone() (clone *logEntry) {
cloneVal := *e
return &cloneVal
}
// addResponse adds data from resp to e.Answer if resp is not nil. If isOrig is
// true, addResponse sets the e.OrigAnswer field instead of e.Answer. Any
// errors are logged.
func (e *logEntry) addResponse(resp *dns.Msg, isOrig bool) {
if resp == nil {
return
}
var err error
if isOrig {
e.Answer, err = resp.Pack()
err = errors.Annotate(err, "packing answer: %w")
} else {
e.OrigAnswer, err = resp.Pack()
err = errors.Annotate(err, "packing orig answer: %w")
}
if err != nil {
log.Error("querylog: %s", err)
}
}

View File

@ -3,7 +3,6 @@ package querylog
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"strings" "strings"
"sync" "sync"
@ -13,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -74,52 +74,24 @@ func NewClientProto(s string) (cp ClientProto, err error) {
} }
} }
// logEntry - represents a single log entry
type logEntry struct {
// client is the found client information, if any.
client *Client
Time time.Time `json:"T"`
QHost string `json:"QH"`
QType string `json:"QT"`
QClass string `json:"QC"`
ReqECS string `json:"ECS,omitempty"`
ClientID string `json:"CID,omitempty"`
ClientProto ClientProto `json:"CP"`
Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net
OrigAnswer []byte `json:",omitempty"`
Result filtering.Result
Upstream string `json:",omitempty"`
IP net.IP `json:"IP"`
Elapsed time.Duration
Cached bool `json:",omitempty"`
AuthenticatedData bool `json:"AD,omitempty"`
}
// shallowClone returns a shallow clone of e.
func (e *logEntry) shallowClone() (clone *logEntry) {
cloneVal := *e
return &cloneVal
}
func (l *queryLog) Start() { func (l *queryLog) Start() {
if l.conf.HTTPRegister != nil { if l.conf.HTTPRegister != nil {
l.initWeb() l.initWeb()
} }
go l.periodicRotate() go l.periodicRotate()
} }
func (l *queryLog) Close() { func (l *queryLog) Close() {
_ = l.flushLogBuffer(true) l.confMu.RLock()
defer l.confMu.RUnlock()
if l.conf.FileEnabled {
err := l.flushLogBuffer()
if err != nil {
log.Error("querylog: closing: %s", err)
}
}
} }
func checkInterval(ivl time.Duration) (ok bool) { func checkInterval(ivl time.Duration) (ok bool) {
@ -150,7 +122,13 @@ func validateIvl(ivl time.Duration) (err error) {
} }
func (l *queryLog) WriteDiskConfig(c *Config) { func (l *queryLog) WriteDiskConfig(c *Config) {
l.confMu.RLock()
defer l.confMu.RUnlock()
*c = *l.conf *c = *l.conf
// TODO(a.garipov): Add stringutil.Set.Clone.
c.Ignored = stringutil.NewSet(l.conf.Ignored.Values()...)
} }
// Clear memory buffer and remove log files // Clear memory buffer and remove log files
@ -181,7 +159,17 @@ func (l *queryLog) clear() {
} }
func (l *queryLog) Add(params *AddParams) { func (l *queryLog) Add(params *AddParams) {
if !l.conf.Enabled { var isEnabled, fileIsEnabled bool
var memSize uint32
func() {
l.confMu.RLock()
defer l.confMu.RUnlock()
isEnabled, fileIsEnabled = l.conf.Enabled, l.conf.FileEnabled
memSize = l.conf.MemSize
}()
if !isEnabled {
return return
} }
@ -198,7 +186,7 @@ func (l *queryLog) Add(params *AddParams) {
now := time.Now() now := time.Now()
q := params.Question.Question[0] q := params.Question.Question[0]
entry := logEntry{ entry := &logEntry{
Time: now, Time: now,
QHost: strings.ToLower(q.Name[:len(q.Name)-1]), QHost: strings.ToLower(q.Name[:len(q.Name)-1]),
@ -223,39 +211,18 @@ func (l *queryLog) Add(params *AddParams) {
entry.ReqECS = params.ReqECS.String() entry.ReqECS = params.ReqECS.String()
} }
if params.Answer != nil { entry.addResponse(params.Answer, false)
var a []byte entry.addResponse(params.OrigAnswer, true)
a, err = params.Answer.Pack()
if err != nil {
log.Error("querylog: Answer.Pack(): %s", err)
return
}
entry.Answer = a
}
if params.OrigAnswer != nil {
var a []byte
a, err = params.OrigAnswer.Pack()
if err != nil {
log.Error("querylog: OrigAnswer.Pack(): %s", err)
return
}
entry.OrigAnswer = a
}
needFlush := false needFlush := false
func() { func() {
l.bufferLock.Lock() l.bufferLock.Lock()
defer l.bufferLock.Unlock() defer l.bufferLock.Unlock()
l.buffer = append(l.buffer, &entry) l.buffer = append(l.buffer, entry)
if !l.conf.FileEnabled { if !fileIsEnabled {
if len(l.buffer) > int(l.conf.MemSize) { if len(l.buffer) > int(memSize) {
// Writing to file is disabled, so just remove the oldest entry // Writing to file is disabled, so just remove the oldest entry
// from the slices. // from the slices.
// //
@ -265,17 +232,19 @@ func (l *queryLog) Add(params *AddParams) {
l.buffer = l.buffer[1:] l.buffer = l.buffer[1:]
} }
} else if !l.flushPending { } else if !l.flushPending {
needFlush = len(l.buffer) >= int(l.conf.MemSize) needFlush = len(l.buffer) >= int(memSize)
if needFlush { if needFlush {
l.flushPending = true l.flushPending = true
} }
} }
}() }()
// if buffer needs to be flushed to disk, do it now
if needFlush { if needFlush {
go func() { go func() {
_ = l.flushLogBuffer(false) flushErr := l.flushLogBuffer()
if flushErr != nil {
log.Error("querylog: flushing after adding: %s", err)
}
}() }()
} }
} }
@ -288,7 +257,8 @@ func (l *queryLog) ShouldLog(host string, _, _ uint16) bool {
return !l.isIgnored(host) return !l.isIgnored(host)
} }
// isIgnored returns true if the host is in the Ignored list. // isIgnored returns true if the host is in the ignored domains list. It
// assumes that l.confMu is locked for reading.
func (l *queryLog) isIgnored(host string) bool { func (l *queryLog) isIgnored(host string) bool {
return l.conf.Ignored.Has(host) return l.conf.Ignored.Has(host)
} }

View File

@ -34,13 +34,13 @@ func TestQueryLog(t *testing.T) {
// Add disk entries. // Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// Write to disk (first file). // Write to disk (first file).
require.NoError(t, l.flushLogBuffer(true)) require.NoError(t, l.flushLogBuffer())
// Start writing to the second file. // Start writing to the second file.
require.NoError(t, l.rotate()) require.NoError(t, l.rotate())
// Add disk entries. // Add disk entries.
addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// Write to disk. // Write to disk.
require.NoError(t, l.flushLogBuffer(true)) require.NoError(t, l.flushLogBuffer())
// Add memory entries. // Add memory entries.
addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3)) addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4)) addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
@ -144,7 +144,7 @@ func TestQueryLogOffsetLimit(t *testing.T) {
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// Write them to the first file. // Write them to the first file.
require.NoError(t, l.flushLogBuffer(true)) require.NoError(t, l.flushLogBuffer())
// Add more to the in-memory part of log. // Add more to the in-memory part of log.
for i := 0; i < entNum; i++ { for i := 0; i < entNum; i++ {
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
@ -216,7 +216,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// Write them to disk. // Write them to disk.
require.NoError(t, l.flushLogBuffer(true)) require.NoError(t, l.flushLogBuffer())
params := newSearchParams() params := newSearchParams()

View File

@ -13,40 +13,23 @@ import (
// flushLogBuffer flushes the current buffer to file and resets the current // flushLogBuffer flushes the current buffer to file and resets the current
// buffer. // buffer.
func (l *queryLog) flushLogBuffer(fullFlush bool) (err error) { func (l *queryLog) flushLogBuffer() (err error) {
if !l.conf.FileEnabled {
return nil
}
l.fileFlushLock.Lock() l.fileFlushLock.Lock()
defer l.fileFlushLock.Unlock() defer l.fileFlushLock.Unlock()
// Flush the remainder to file.
var flushBuffer []*logEntry var flushBuffer []*logEntry
needFlush := fullFlush
func() { func() {
l.bufferLock.Lock() l.bufferLock.Lock()
defer l.bufferLock.Unlock() defer l.bufferLock.Unlock()
needFlush = needFlush || len(l.buffer) >= int(l.conf.MemSize) flushBuffer = l.buffer
if needFlush { l.buffer = nil
flushBuffer = l.buffer l.flushPending = false
l.buffer = nil
l.flushPending = false
}
}() }()
if !needFlush {
return nil
}
err = l.flushToFile(flushBuffer) err = l.flushToFile(flushBuffer)
if err != nil {
log.Error("querylog: writing to file: %s", err)
return err return errors.Annotate(err, "writing to file: %w")
}
return nil
} }
// flushToFile saves the specified log entries to the query log file // flushToFile saves the specified log entries to the query log file
@ -167,8 +150,13 @@ func (l *queryLog) periodicRotate() {
// checkAndRotate rotates log files if those are older than the specified // checkAndRotate rotates log files if those are older than the specified
// rotation interval. // rotation interval.
func (l *queryLog) checkAndRotate() { func (l *queryLog) checkAndRotate() {
l.confMu.RLock() var rotationIvl time.Duration
defer l.confMu.RUnlock() func() {
l.confMu.RLock()
defer l.confMu.RUnlock()
rotationIvl = l.conf.RotationIvl
}()
oldest, err := l.readFileFirstTimeValue() oldest, err := l.readFileFirstTimeValue()
if err != nil && !errors.Is(err, os.ErrNotExist) { if err != nil && !errors.Is(err, os.ErrNotExist) {
@ -177,11 +165,11 @@ func (l *queryLog) checkAndRotate() {
return return
} }
if rot, now := oldest.Add(l.conf.RotationIvl), time.Now(); rot.After(now) { if rotTime, now := oldest.Add(rotationIvl), time.Now(); rotTime.After(now) {
log.Debug( log.Debug(
"querylog: %s <= %s, not rotating", "querylog: %s <= %s, not rotating",
now.Format(time.RFC3339), now.Format(time.RFC3339),
rot.Format(time.RFC3339), rotTime.Format(time.RFC3339),
) )
return return

View File

@ -161,10 +161,10 @@ run_linter "$GO" vet ./...
run_linter govulncheck ./... run_linter govulncheck ./...
# Apply more lax standards to the code we haven't properly refactored yet. # Apply more lax standards to the code we haven't properly refactored yet.
run_linter gocyclo --over 14 ./internal/querylog/
run_linter gocyclo --over 13\ run_linter gocyclo --over 13\
./internal/dhcpd\ ./internal/dhcpd\
./internal/home/\ ./internal/home/\
./internal/querylog/\
; ;
# Apply the normal standards to new or somewhat refactored code. # Apply the normal standards to new or somewhat refactored code.