all: imp code

This commit is contained in:
Stanislav Chzhen 2024-08-27 20:40:01 +03:00
parent 0b1f022094
commit 6b3fda0a37
8 changed files with 26 additions and 56 deletions

View File

@ -219,7 +219,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
srv := &Server{ srv := &Server{
conf: ServerConfig{TLSConfig: tlsConf}, conf: ServerConfig{TLSConfig: tlsConf},
logger: slogutil.NewDiscardLogger(), baseLogger: slogutil.NewDiscardLogger(),
} }
var ( var (

View File

@ -318,6 +318,7 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies) trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
conf = &proxy.Config{ conf = &proxy.Config{
Logger: s.baseLogger.With(slogutil.KeyPrefix, "dnsproxy"),
HTTP3: srvConf.ServeHTTP3, HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit), Ratelimit: int(srvConf.Ratelimit),
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4, RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
@ -342,10 +343,6 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
MessageConstructor: s, MessageConstructor: s,
} }
if s.logger != nil {
conf.Logger = s.logger.With(slogutil.KeyPrefix, "dnsproxy")
}
if srvConf.EDNSClientSubnet.UseCustom { if srvConf.EDNSClientSubnet.UseCustom {
// TODO(s.chzhen): Use netip.Addr instead of net.IP inside dnsproxy. // TODO(s.chzhen): Use netip.Addr instead of net.IP inside dnsproxy.
conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice()) conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice())

View File

@ -123,11 +123,9 @@ type Server struct {
// access drops disallowed clients. // access drops disallowed clients.
access *accessManager access *accessManager
// logger is used for logging during server routines. // baseLogger is used to create loggers for other entities. It should not
// // have a prefix and must not be nil.
// TODO(d.kolyshev): Make it never nil. baseLogger *slog.Logger
// TODO(d.kolyshev): Use this logger.
logger *slog.Logger
// localDomainSuffix is the suffix used to detect internal hosts. It // localDomainSuffix is the suffix used to detect internal hosts. It
// must be a valid domain name plus dots on each side. // must be a valid domain name plus dots on each side.
@ -246,7 +244,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
stats: p.Stats, stats: p.Stats,
queryLog: p.QueryLog, queryLog: p.QueryLog,
privateNets: p.PrivateNets, privateNets: p.PrivateNets,
logger: p.Logger.With(slogutil.KeyPrefix, "dnsforward"), baseLogger: p.Logger,
// TODO(e.burkov): Use some case-insensitive string comparison. // TODO(e.burkov): Use some case-insensitive string comparison.
localDomainSuffix: strings.ToLower(localDomainSuffix), localDomainSuffix: strings.ToLower(localDomainSuffix),
etcHosts: etcHosts, etcHosts: etcHosts,
@ -615,7 +613,7 @@ func (s *Server) prepareInternalDNS() (err error) {
return fmt.Errorf("preparing ipset settings: %w", err) return fmt.Errorf("preparing ipset settings: %w", err)
} }
ipsetLogger := s.logger.With(slogutil.KeyPrefix, "ipset") ipsetLogger := s.baseLogger.With(slogutil.KeyPrefix, "ipset")
s.ipset, err = newIpsetHandler(context.TODO(), ipsetLogger, ipsetList) s.ipset, err = newIpsetHandler(context.TODO(), ipsetLogger, ipsetList)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
@ -685,7 +683,7 @@ func (s *Server) setupAddrProc() {
s.addrProc = client.EmptyAddrProc{} s.addrProc = client.EmptyAddrProc{}
} else { } else {
c := s.conf.AddrProcConf c := s.conf.AddrProcConf
c.BaseLogger = s.logger c.BaseLogger = s.baseLogger
c.DialContext = s.DialContext c.DialContext = s.DialContext
c.PrivateSubnets = s.privateNets c.PrivateSubnets = s.privateNets
c.UsePrivateRDNS = s.conf.UsePrivateRDNS c.UsePrivateRDNS = s.conf.UsePrivateRDNS
@ -729,6 +727,7 @@ func validateBlockingMode(
func (s *Server) prepareInternalProxy() (err error) { func (s *Server) prepareInternalProxy() (err error) {
srvConf := s.conf srvConf := s.conf
conf := &proxy.Config{ conf := &proxy.Config{
Logger: s.baseLogger.With(slogutil.KeyPrefix, "dnsproxy"),
CacheEnabled: true, CacheEnabled: true,
CacheSizeBytes: 4096, CacheSizeBytes: 4096,
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig, PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
@ -741,10 +740,6 @@ func (s *Server) prepareInternalProxy() (err error) {
MessageConstructor: s, MessageConstructor: s,
} }
if s.logger != nil {
conf.Logger = s.logger.With(slogutil.KeyPrefix, "dnsproxy")
}
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration) err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
if err != nil { if err != nil {
return fmt.Errorf("invalid upstream mode: %w", err) return fmt.Errorf("invalid upstream mode: %w", err)

View File

@ -431,7 +431,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
dnsFilter: createTestDNSFilter(t), dnsFilter: createTestDNSFilter(t),
dhcpServer: dhcp, dhcpServer: dhcp,
localDomainSuffix: localDomainSuffix, localDomainSuffix: localDomainSuffix,
logger: slogutil.NewDiscardLogger(), baseLogger: slogutil.NewDiscardLogger(),
} }
req := &dns.Msg{ req := &dns.Msg{
@ -567,7 +567,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
dnsFilter: createTestDNSFilter(t), dnsFilter: createTestDNSFilter(t),
dhcpServer: testDHCP, dhcpServer: testDHCP,
localDomainSuffix: tc.suffix, localDomainSuffix: tc.suffix,
logger: slogutil.NewDiscardLogger(), baseLogger: slogutil.NewDiscardLogger(),
} }
req := &dns.Msg{ req := &dns.Msg{

View File

@ -203,7 +203,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
ql := &testQueryLog{} ql := &testQueryLog{}
st := &testStats{} st := &testStats{}
srv := &Server{ srv := &Server{
logger: slogutil.NewDiscardLogger(), baseLogger: slogutil.NewDiscardLogger(),
queryLog: ql, queryLog: ql,
stats: st, stats: st,
anonymizer: aghnet.NewIPMut(nil), anonymizer: aghnet.NewIPMut(nil),

View File

@ -3,7 +3,6 @@ package rdns
import ( import (
"context" "context"
"fmt"
"log/slog" "log/slog"
"net/netip" "net/netip"
"time" "time"
@ -96,7 +95,7 @@ func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, chan
host, ttl, err := r.exchanger.Exchange(ip) host, ttl, err := r.exchanger.Exchange(ip)
if err != nil { if err != nil {
r.logger.DebugContext(ctx, "resolving ip", "ip", ip, slogutil.KeyError, err) r.logger.DebugContext(ctx, "resolving", "ip", ip, slogutil.KeyError, err)
} }
ttl = max(ttl, r.cacheTTL) ttl = max(ttl, r.cacheTTL)
@ -108,7 +107,7 @@ func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, chan
err = r.cache.Set(ip, item) err = r.cache.Set(ip, item)
if err != nil { if err != nil {
r.logger.DebugContext(ctx, "adding item to cache", "item", ip, slogutil.KeyError, err) r.logger.DebugContext(ctx, "adding item to cache", "key", ip, slogutil.KeyError, err)
} }
// TODO(e.burkov): The name doesn't change if it's neither stored in cache // TODO(e.burkov): The name doesn't change if it's neither stored in cache
@ -125,7 +124,7 @@ func (r *Default) findInCache(ctx context.Context, ip netip.Addr) (host string,
r.logger.DebugContext( r.logger.DebugContext(
ctx, ctx,
"retrieving item from cache", "retrieving item from cache",
"item", ip, "key", ip,
slogutil.KeyError, err, slogutil.KeyError, err,
) )
} }
@ -133,17 +132,7 @@ func (r *Default) findInCache(ctx context.Context, ip netip.Addr) (host string,
return "", true return "", true
} }
item, ok := val.(*cacheItem) item := val.(*cacheItem)
if !ok {
r.logger.DebugContext(
ctx,
"bad type of cache item",
"item", ip,
"type", fmt.Sprintf("%T", val),
)
return "", true
}
return item.host, time.Now().After(item.expiry) return item.host, time.Now().After(item.expiry)
} }

View File

@ -1,7 +1,6 @@
package rdns_test package rdns_test
import ( import (
"context"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
@ -114,14 +113,15 @@ func TestDefault_Process(t *testing.T) {
return revAddr2, time.Hour, nil return revAddr2, time.Hour, nil
} }
ctx := testutil.ContextWithTimeout(t, testTimeout)
require.EventuallyWithT(t, func(t *assert.CollectT) { require.EventuallyWithT(t, func(t *assert.CollectT) {
got, changed = r.Process(context.TODO(), ip1) got, changed = r.Process(ctx, ip1)
assert.True(t, changed) assert.True(t, changed)
assert.Equal(t, revAddr2, got) assert.Equal(t, revAddr2, got)
}, 2*cacheTTL, time.Millisecond*100) }, 2*cacheTTL, time.Millisecond*100)
assert.Never(t, func() (changed bool) { assert.Never(t, func() (changed bool) {
_, changed = r.Process(context.TODO(), ip1) _, changed = r.Process(testutil.ContextWithTimeout(t, testTimeout), ip1)
return changed return changed
}, 2*cacheTTL, time.Millisecond*100) }, 2*cacheTTL, time.Millisecond*100)

View File

@ -20,6 +20,7 @@ import (
"github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/bluele/gcache" "github.com/bluele/gcache"
"github.com/c2h5oh/datasize"
) )
const ( const (
@ -240,9 +241,9 @@ func (w *Default) query(ctx context.Context, target, serverAddr string) (data []
// queryAll queries WHOIS server and handles redirects. // queryAll queries WHOIS server and handles redirects.
func (w *Default) queryAll(ctx context.Context, target string) (info map[string]string, err error) { func (w *Default) queryAll(ctx context.Context, target string) (info map[string]string, err error) {
server := net.JoinHostPort(w.serverAddr, w.portStr) server := net.JoinHostPort(w.serverAddr, w.portStr)
var data []byte
for range w.maxRedirects { for range w.maxRedirects {
var data []byte
data, err = w.query(ctx, target, server) data, err = w.query(ctx, target, server)
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. // Don't wrap the error since it's informative enough as is.
@ -252,7 +253,7 @@ func (w *Default) queryAll(ctx context.Context, target string) (info map[string]
w.logger.DebugContext( w.logger.DebugContext(
ctx, ctx,
"received response", "received response",
"size", len(data), "size", datasize.ByteSize(len(data)),
"source", server, "source", server,
"target", target, "target", target,
) )
@ -315,7 +316,7 @@ func (w *Default) requestInfo(
item := toCacheItem(info, w.cacheTTL) item := toCacheItem(info, w.cacheTTL)
err := w.cache.Set(ip, item) err := w.cache.Set(ip, item)
if err != nil { if err != nil {
w.logger.DebugContext(ctx, "adding item to cache", "item", ip, slogutil.KeyError, err) w.logger.DebugContext(ctx, "adding item to cache", "key", ip, slogutil.KeyError, err)
} }
}() }()
@ -350,7 +351,7 @@ func (w *Default) findInCache(ctx context.Context, ip netip.Addr) (wi *Info, exp
w.logger.DebugContext( w.logger.DebugContext(
ctx, ctx,
"retrieving item from cache", "retrieving item from cache",
"item", ip, "key", ip,
slogutil.KeyError, err, slogutil.KeyError, err,
) )
} }
@ -358,19 +359,7 @@ func (w *Default) findInCache(ctx context.Context, ip netip.Addr) (wi *Info, exp
return nil, false return nil, false
} }
item, ok := val.(*cacheItem) return fromCacheItem(val.(*cacheItem))
if !ok {
w.logger.DebugContext(
ctx,
"bad type of cache item",
"item", ip,
"type", fmt.Sprintf("%T", val),
)
return nil, false
}
return fromCacheItem(item)
} }
// Info is the filtered WHOIS data for a runtime client. // Info is the filtered WHOIS data for a runtime client.