diff --git a/internal/client/addrproc.go b/internal/client/addrproc.go index 76ff1367..35293609 100644 --- a/internal/client/addrproc.go +++ b/internal/client/addrproc.go @@ -2,6 +2,7 @@ package client import ( "context" + "log/slog" "net/netip" "sync" "time" @@ -11,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" ) @@ -38,6 +40,10 @@ func (EmptyAddrProc) Close() (_ error) { return nil } // DefaultAddrProcConfig is the configuration structure for address processors. type DefaultAddrProcConfig struct { + // BaseLogger is used to create loggers with custom prefixes for sources of + // information about runtime clients. It must not be nil. + BaseLogger *slog.Logger + // DialContext is used to create TCP connections to WHOIS servers. // DialContext must not be nil if [DefaultAddrProcConfig.UseWHOIS] is true. DialContext aghnet.DialContextFunc @@ -147,6 +153,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) { if c.UseRDNS { p.rdns = rdns.New(&rdns.Config{ + Logger: c.BaseLogger.With(slogutil.KeyPrefix, "rdns"), Exchanger: c.Exchanger, CacheSize: defaultCacheSize, CacheTTL: defaultIPTTL, @@ -154,7 +161,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) { } if c.UseWHOIS { - p.whois = newWHOIS(c.DialContext) + p.whois = newWHOIS(c.BaseLogger.With(slogutil.KeyPrefix, "whois"), c.DialContext) } go p.process(c.CatchPanics) @@ -168,7 +175,7 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) { // newWHOIS returns a whois.Interface instance using the given function for // dialing. -func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) { +func newWHOIS(logger *slog.Logger, dialFunc aghnet.DialContextFunc) (w whois.Interface) { // TODO(s.chzhen): Consider making configurable. const ( // defaultTimeout is the timeout for WHOIS requests. @@ -186,6 +193,7 @@ func newWHOIS(dialFunc aghnet.DialContextFunc) (w whois.Interface) { ) return whois.New(&whois.Config{ + Logger: logger, DialContext: dialFunc, ServerAddr: whois.DefaultServer, Port: whois.DefaultPort, @@ -227,9 +235,11 @@ func (p *DefaultAddrProc) process(catchPanics bool) { log.Info("clients: processing addresses") + ctx := context.TODO() + for ip := range p.clientIPs { - host := p.processRDNS(ip) - info := p.processWHOIS(ip) + host := p.processRDNS(ctx, ip) + info := p.processWHOIS(ctx, ip) p.addrUpdater.UpdateAddress(ip, host, info) } @@ -239,7 +249,7 @@ func (p *DefaultAddrProc) process(catchPanics bool) { // processRDNS resolves the clients' IP addresses using reverse DNS. host is // empty if there were errors or if the information hasn't changed. -func (p *DefaultAddrProc) processRDNS(ip netip.Addr) (host string) { +func (p *DefaultAddrProc) processRDNS(ctx context.Context, ip netip.Addr) (host string) { start := time.Now() log.Debug("clients: processing %s with rdns", ip) defer func() { @@ -251,7 +261,7 @@ func (p *DefaultAddrProc) processRDNS(ip netip.Addr) (host string) { return } - host, changed := p.rdns.Process(ip) + host, changed := p.rdns.Process(ctx, ip) if !changed { host = "" } @@ -268,7 +278,7 @@ func (p *DefaultAddrProc) shouldResolve(ip netip.Addr) (ok bool) { // processWHOIS looks up the information about clients' IP addresses in the // WHOIS databases. info is nil if there were errors or if the information // hasn't changed. -func (p *DefaultAddrProc) processWHOIS(ip netip.Addr) (info *whois.Info) { +func (p *DefaultAddrProc) processWHOIS(ctx context.Context, ip netip.Addr) (info *whois.Info) { start := time.Now() log.Debug("clients: processing %s with whois", ip) defer func() { @@ -277,7 +287,7 @@ func (p *DefaultAddrProc) processWHOIS(ip netip.Addr) (info *whois.Info) { // TODO(s.chzhen): Move the timeout logic from WHOIS configuration to the // context. - info, changed := p.whois.Process(context.Background(), ip) + info, changed := p.whois.Process(ctx, ip) if !changed { info = nil } diff --git a/internal/client/addrproc_test.go b/internal/client/addrproc_test.go index f0d0a8f7..3df3a5c7 100644 --- a/internal/client/addrproc_test.go +++ b/internal/client/addrproc_test.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil/fakenet" @@ -99,6 +100,7 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) { updInfoCh := make(chan *whois.Info, 1) p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{ + BaseLogger: slogutil.NewDiscardLogger(), DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) { panic("not implemented") }, @@ -208,6 +210,7 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) { updInfoCh := make(chan *whois.Info, 1) p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{ + BaseLogger: slogutil.NewDiscardLogger(), DialContext: func(_ context.Context, _, _ string) (conn net.Conn, err error) { return whoisConn, nil }, diff --git a/internal/dnsforward/clientid_test.go b/internal/dnsforward/clientid_test.go index c0db6c9d..097930bc 100644 --- a/internal/dnsforward/clientid_test.go +++ b/internal/dnsforward/clientid_test.go @@ -218,8 +218,8 @@ func TestServer_clientIDFromDNSContext(t *testing.T) { } srv := &Server{ - conf: ServerConfig{TLSConfig: tlsConf}, - logger: slogutil.NewDiscardLogger(), + conf: ServerConfig{TLSConfig: tlsConf}, + baseLogger: slogutil.NewDiscardLogger(), } var ( diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 7e12e8c3..c2054217 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -318,6 +318,7 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) { trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies) conf = &proxy.Config{ + Logger: s.baseLogger.With(slogutil.KeyPrefix, "dnsproxy"), HTTP3: srvConf.ServeHTTP3, Ratelimit: int(srvConf.Ratelimit), RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4, @@ -342,10 +343,6 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) { MessageConstructor: s, } - if s.logger != nil { - conf.Logger = s.logger.With(slogutil.KeyPrefix, "dnsproxy") - } - if srvConf.EDNSClientSubnet.UseCustom { // TODO(s.chzhen): Use netip.Addr instead of net.IP inside dnsproxy. conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice()) diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 30d7e731..107eea39 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -123,11 +123,9 @@ type Server struct { // access drops disallowed clients. access *accessManager - // logger is used for logging during server routines. - // - // TODO(d.kolyshev): Make it never nil. - // TODO(d.kolyshev): Use this logger. - logger *slog.Logger + // baseLogger is used to create loggers for other entities. It should not + // have a prefix and must not be nil. + baseLogger *slog.Logger // localDomainSuffix is the suffix used to detect internal hosts. It // must be a valid domain name plus dots on each side. @@ -246,7 +244,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { stats: p.Stats, queryLog: p.QueryLog, privateNets: p.PrivateNets, - logger: p.Logger.With(slogutil.KeyPrefix, "dnsforward"), + baseLogger: p.Logger, // TODO(e.burkov): Use some case-insensitive string comparison. localDomainSuffix: strings.ToLower(localDomainSuffix), etcHosts: etcHosts, @@ -615,7 +613,7 @@ func (s *Server) prepareInternalDNS() (err error) { return fmt.Errorf("preparing ipset settings: %w", err) } - ipsetLogger := s.logger.With(slogutil.KeyPrefix, "ipset") + ipsetLogger := s.baseLogger.With(slogutil.KeyPrefix, "ipset") s.ipset, err = newIpsetHandler(context.TODO(), ipsetLogger, ipsetList) if err != nil { // Don't wrap the error, because it's informative enough as is. @@ -685,6 +683,7 @@ func (s *Server) setupAddrProc() { s.addrProc = client.EmptyAddrProc{} } else { c := s.conf.AddrProcConf + c.BaseLogger = s.baseLogger c.DialContext = s.DialContext c.PrivateSubnets = s.privateNets c.UsePrivateRDNS = s.conf.UsePrivateRDNS @@ -728,6 +727,7 @@ func validateBlockingMode( func (s *Server) prepareInternalProxy() (err error) { srvConf := s.conf conf := &proxy.Config{ + Logger: s.baseLogger.With(slogutil.KeyPrefix, "dnsproxy"), CacheEnabled: true, CacheSizeBytes: 4096, PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig, @@ -740,10 +740,6 @@ func (s *Server) prepareInternalProxy() (err error) { MessageConstructor: s, } - if s.logger != nil { - conf.Logger = s.logger.With(slogutil.KeyPrefix, "dnsproxy") - } - err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration) if err != nil { return fmt.Errorf("invalid upstream mode: %w", err) diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index ba1133b9..9781b3b0 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -431,7 +431,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { dnsFilter: createTestDNSFilter(t), dhcpServer: dhcp, localDomainSuffix: localDomainSuffix, - logger: slogutil.NewDiscardLogger(), + baseLogger: slogutil.NewDiscardLogger(), } req := &dns.Msg{ @@ -567,7 +567,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { dnsFilter: createTestDNSFilter(t), dhcpServer: testDHCP, localDomainSuffix: tc.suffix, - logger: slogutil.NewDiscardLogger(), + baseLogger: slogutil.NewDiscardLogger(), } req := &dns.Msg{ diff --git a/internal/dnsforward/stats_test.go b/internal/dnsforward/stats_test.go index 8626c180..6e4d5d86 100644 --- a/internal/dnsforward/stats_test.go +++ b/internal/dnsforward/stats_test.go @@ -203,7 +203,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) { ql := &testQueryLog{} st := &testStats{} srv := &Server{ - logger: slogutil.NewDiscardLogger(), + baseLogger: slogutil.NewDiscardLogger(), queryLog: ql, stats: st, anonymizer: aghnet.NewIPMut(nil), diff --git a/internal/rdns/rdns.go b/internal/rdns/rdns.go index 7130ccf3..de84243c 100644 --- a/internal/rdns/rdns.go +++ b/internal/rdns/rdns.go @@ -2,11 +2,13 @@ package rdns import ( + "context" + "log/slog" "net/netip" "time" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/bluele/gcache" ) @@ -14,7 +16,7 @@ import ( type Interface interface { // Process makes rDNS request and returns domain name. changed indicates // that domain name was updated since last request. - Process(ip netip.Addr) (host string, changed bool) + Process(ctx context.Context, ip netip.Addr) (host string, changed bool) } // Empty is an empty [Interface] implementation which does nothing. @@ -24,7 +26,7 @@ type Empty struct{} var _ Interface = (*Empty)(nil) // Process implements the [Interface] interface for Empty. -func (Empty) Process(_ netip.Addr) (host string, changed bool) { +func (Empty) Process(_ context.Context, _ netip.Addr) (host string, changed bool) { return "", false } @@ -37,6 +39,10 @@ type Exchanger interface { // Config is the configuration structure for Default. type Config struct { + // Logger is used for logging the operation of the reverse DNS lookup + // queries. It must not be nil. + Logger *slog.Logger + // Exchanger resolves IP addresses to domain names. Exchanger Exchanger @@ -50,6 +56,10 @@ type Config struct { // Default is the default rDNS query processor. type Default struct { + // logger is used for logging the operation of the reverse DNS lookup + // queries. It must not be nil. + logger *slog.Logger + // cache is the cache containing IP addresses of clients. An active IP // address is resolved once again after it expires. If IP address couldn't // be resolved, it stays here for some time to prevent further attempts to @@ -66,6 +76,7 @@ type Default struct { // New returns a new default rDNS query processor. conf must not be nil. func New(conf *Config) (r *Default) { return &Default{ + logger: conf.Logger, cache: gcache.New(conf.CacheSize).LRU().Build(), exchanger: conf.Exchanger, cacheTTL: conf.CacheTTL, @@ -76,15 +87,15 @@ func New(conf *Config) (r *Default) { var _ Interface = (*Default)(nil) // Process implements the [Interface] interface for Default. -func (r *Default) Process(ip netip.Addr) (host string, changed bool) { - fromCache, expired := r.findInCache(ip) +func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, changed bool) { + fromCache, expired := r.findInCache(ctx, ip) if !expired { return fromCache, false } host, ttl, err := r.exchanger.Exchange(ip) if err != nil { - log.Debug("rdns: resolving %q: %s", ip, err) + r.logger.DebugContext(ctx, "resolving", "ip", ip, slogutil.KeyError, err) } ttl = max(ttl, r.cacheTTL) @@ -96,7 +107,7 @@ func (r *Default) Process(ip netip.Addr) (host string, changed bool) { err = r.cache.Set(ip, item) if err != nil { - log.Debug("rdns: cache: adding item %q: %s", ip, err) + r.logger.DebugContext(ctx, "adding item to cache", "key", ip, slogutil.KeyError, err) } // TODO(e.burkov): The name doesn't change if it's neither stored in cache @@ -106,22 +117,22 @@ func (r *Default) Process(ip netip.Addr) (host string, changed bool) { // findInCache finds domain name in the cache. expired is true if host is not // valid anymore. -func (r *Default) findInCache(ip netip.Addr) (host string, expired bool) { +func (r *Default) findInCache(ctx context.Context, ip netip.Addr) (host string, expired bool) { val, err := r.cache.Get(ip) if err != nil { if !errors.Is(err, gcache.KeyNotFoundError) { - log.Debug("rdns: cache: retrieving %q: %s", ip, err) + r.logger.DebugContext( + ctx, + "retrieving item from cache", + "key", ip, + slogutil.KeyError, err, + ) } return "", true } - item, ok := val.(*cacheItem) - if !ok { - log.Debug("rdns: cache: %q bad type %T", ip, val) - - return "", true - } + item := val.(*cacheItem) return item.host, time.Now().After(item.expiry) } diff --git a/internal/rdns/rdns_test.go b/internal/rdns/rdns_test.go index 0db13728..f0b27ed8 100644 --- a/internal/rdns/rdns_test.go +++ b/internal/rdns/rdns_test.go @@ -8,10 +8,14 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// testTimeout is a common timeout for tests and contexts. +const testTimeout = 1 * time.Second + func TestDefault_Process(t *testing.T) { ip1 := netip.MustParseAddr("1.2.3.4") revAddr1, err := netutil.IPToReversedAddr(ip1.AsSlice()) @@ -71,14 +75,14 @@ func TestDefault_Process(t *testing.T) { Exchanger: &aghtest.Exchanger{OnExchange: onExchange}, }) - got, changed := r.Process(tc.addr) + got, changed := r.Process(testutil.ContextWithTimeout(t, testTimeout), tc.addr) require.True(t, changed) assert.Equal(t, tc.want, got) assert.Equal(t, 1, hit) // From cache. - got, changed = r.Process(tc.addr) + got, changed = r.Process(testutil.ContextWithTimeout(t, testTimeout), tc.addr) require.False(t, changed) assert.Equal(t, tc.want, got) @@ -101,7 +105,7 @@ func TestDefault_Process(t *testing.T) { Exchanger: zeroTTLExchanger, }) - got, changed := r.Process(ip1) + got, changed := r.Process(testutil.ContextWithTimeout(t, testTimeout), ip1) require.True(t, changed) assert.Equal(t, revAddr1, got) @@ -109,14 +113,15 @@ func TestDefault_Process(t *testing.T) { return revAddr2, time.Hour, nil } + ctx := testutil.ContextWithTimeout(t, testTimeout) require.EventuallyWithT(t, func(t *assert.CollectT) { - got, changed = r.Process(ip1) + got, changed = r.Process(ctx, ip1) assert.True(t, changed) assert.Equal(t, revAddr2, got) }, 2*cacheTTL, time.Millisecond*100) assert.Never(t, func() (changed bool) { - _, changed = r.Process(ip1) + _, changed = r.Process(testutil.ContextWithTimeout(t, testTimeout), ip1) return changed }, 2*cacheTTL, time.Millisecond*100) diff --git a/internal/whois/whois.go b/internal/whois/whois.go index 3f48894e..3dffd7a1 100644 --- a/internal/whois/whois.go +++ b/internal/whois/whois.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "io" + "log/slog" "net" "net/netip" "strconv" @@ -16,9 +17,10 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/ioutil" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/bluele/gcache" + "github.com/c2h5oh/datasize" ) const ( @@ -49,6 +51,10 @@ func (Empty) Process(_ context.Context, _ netip.Addr) (info *Info, changed bool) // Config is the configuration structure for Default. type Config struct { + // Logger is used for logging the operation of the WHOIS lookup queries. It + // must not be nil. + Logger *slog.Logger + // DialContext is used to create TCP connections to WHOIS servers. DialContext aghnet.DialContextFunc @@ -80,6 +86,10 @@ type Config struct { // Default is the default WHOIS information processor. type Default struct { + // logger is used for logging the operation of the WHOIS lookup queries. It + // must not be nil. + logger *slog.Logger + // cache is the cache containing IP addresses of clients. An active IP // address is resolved once again after it expires. If IP address couldn't // be resolved, it stays here for some time to prevent further attempts to @@ -115,6 +125,7 @@ type Default struct { // nil. func New(conf *Config) (w *Default) { return &Default{ + logger: conf.Logger, serverAddr: conf.ServerAddr, dialContext: conf.DialContext, timeout: conf.Timeout, @@ -230,16 +241,22 @@ func (w *Default) query(ctx context.Context, target, serverAddr string) (data [] // queryAll queries WHOIS server and handles redirects. func (w *Default) queryAll(ctx context.Context, target string) (info map[string]string, err error) { server := net.JoinHostPort(w.serverAddr, w.portStr) - var data []byte for range w.maxRedirects { + var data []byte data, err = w.query(ctx, target, server) if err != nil { // Don't wrap the error since it's informative enough as is. return nil, err } - log.Debug("whois: received response (%d bytes) from %q about %q", len(data), server, target) + w.logger.DebugContext( + ctx, + "received response", + "size", datasize.ByteSize(len(data)), + "source", server, + "target", target, + ) info = whoisParse(data, w.maxInfoLen) redir, ok := info["whois"] @@ -256,7 +273,7 @@ func (w *Default) queryAll(ctx context.Context, target string) (info map[string] server = redir } - log.Debug("whois: redirected to %q about %q", redir, target) + w.logger.DebugContext(ctx, "redirected", "destination", redir, "target", target) } return nil, fmt.Errorf("whois: redirect loop") @@ -272,7 +289,7 @@ func (w *Default) Process(ctx context.Context, ip netip.Addr) (wi *Info, changed return nil, false } - wi, expired := w.findInCache(ip) + wi, expired := w.findInCache(ctx, ip) if wi != nil && !expired { // Don't return an empty struct so that the frontend doesn't get // confused. @@ -299,13 +316,13 @@ func (w *Default) requestInfo( item := toCacheItem(info, w.cacheTTL) err := w.cache.Set(ip, item) if err != nil { - log.Debug("whois: cache: adding item %q: %s", ip, err) + w.logger.DebugContext(ctx, "adding item to cache", "key", ip, slogutil.KeyError, err) } }() kv, err := w.queryAll(ctx, ip.String()) if err != nil { - log.Debug("whois: querying %q: %s", ip, err) + w.logger.DebugContext(ctx, "querying", "target", ip, slogutil.KeyError, err) return nil, true } @@ -327,24 +344,22 @@ func (w *Default) requestInfo( } // findInCache finds Info in the cache. expired indicates that Info is valid. -func (w *Default) findInCache(ip netip.Addr) (wi *Info, expired bool) { +func (w *Default) findInCache(ctx context.Context, ip netip.Addr) (wi *Info, expired bool) { val, err := w.cache.Get(ip) if err != nil { if !errors.Is(err, gcache.KeyNotFoundError) { - log.Debug("whois: cache: retrieving info about %q: %s", ip, err) + w.logger.DebugContext( + ctx, + "retrieving item from cache", + "key", ip, + slogutil.KeyError, err, + ) } return nil, false } - item, ok := val.(*cacheItem) - if !ok { - log.Debug("whois: cache: %q bad type %T", ip, val) - - return nil, false - } - - return fromCacheItem(item) + return fromCacheItem(val.(*cacheItem)) } // Info is the filtered WHOIS data for a runtime client. diff --git a/internal/whois/whois_test.go b/internal/whois/whois_test.go index 109ee1e9..4b425af2 100644 --- a/internal/whois/whois_test.go +++ b/internal/whois/whois_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/whois" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil/fakenet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -119,6 +120,7 @@ func TestDefault_Process(t *testing.T) { } w := whois.New(&whois.Config{ + Logger: slogutil.NewDiscardLogger(), Timeout: 5 * time.Second, DialContext: func(_ context.Context, _, _ string) (_ net.Conn, _ error) { hit = 0