all: imp code
This commit is contained in:
parent
0b1f022094
commit
6b3fda0a37
|
@ -219,7 +219,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||||
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
conf: ServerConfig{TLSConfig: tlsConf},
|
conf: ServerConfig{TLSConfig: tlsConf},
|
||||||
logger: slogutil.NewDiscardLogger(),
|
baseLogger: slogutil.NewDiscardLogger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
|
@ -318,6 +318,7 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||||
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
|
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
|
||||||
|
|
||||||
conf = &proxy.Config{
|
conf = &proxy.Config{
|
||||||
|
Logger: s.baseLogger.With(slogutil.KeyPrefix, "dnsproxy"),
|
||||||
HTTP3: srvConf.ServeHTTP3,
|
HTTP3: srvConf.ServeHTTP3,
|
||||||
Ratelimit: int(srvConf.Ratelimit),
|
Ratelimit: int(srvConf.Ratelimit),
|
||||||
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
|
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
|
||||||
|
@ -342,10 +343,6 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||||
MessageConstructor: s,
|
MessageConstructor: s,
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.logger != nil {
|
|
||||||
conf.Logger = s.logger.With(slogutil.KeyPrefix, "dnsproxy")
|
|
||||||
}
|
|
||||||
|
|
||||||
if srvConf.EDNSClientSubnet.UseCustom {
|
if srvConf.EDNSClientSubnet.UseCustom {
|
||||||
// TODO(s.chzhen): Use netip.Addr instead of net.IP inside dnsproxy.
|
// TODO(s.chzhen): Use netip.Addr instead of net.IP inside dnsproxy.
|
||||||
conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice())
|
conf.EDNSAddr = net.IP(srvConf.EDNSClientSubnet.CustomIP.AsSlice())
|
||||||
|
|
|
@ -123,11 +123,9 @@ type Server struct {
|
||||||
// access drops disallowed clients.
|
// access drops disallowed clients.
|
||||||
access *accessManager
|
access *accessManager
|
||||||
|
|
||||||
// logger is used for logging during server routines.
|
// baseLogger is used to create loggers for other entities. It should not
|
||||||
//
|
// have a prefix and must not be nil.
|
||||||
// TODO(d.kolyshev): Make it never nil.
|
baseLogger *slog.Logger
|
||||||
// TODO(d.kolyshev): Use this logger.
|
|
||||||
logger *slog.Logger
|
|
||||||
|
|
||||||
// localDomainSuffix is the suffix used to detect internal hosts. It
|
// localDomainSuffix is the suffix used to detect internal hosts. It
|
||||||
// must be a valid domain name plus dots on each side.
|
// must be a valid domain name plus dots on each side.
|
||||||
|
@ -246,7 +244,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||||
stats: p.Stats,
|
stats: p.Stats,
|
||||||
queryLog: p.QueryLog,
|
queryLog: p.QueryLog,
|
||||||
privateNets: p.PrivateNets,
|
privateNets: p.PrivateNets,
|
||||||
logger: p.Logger.With(slogutil.KeyPrefix, "dnsforward"),
|
baseLogger: p.Logger,
|
||||||
// TODO(e.burkov): Use some case-insensitive string comparison.
|
// TODO(e.burkov): Use some case-insensitive string comparison.
|
||||||
localDomainSuffix: strings.ToLower(localDomainSuffix),
|
localDomainSuffix: strings.ToLower(localDomainSuffix),
|
||||||
etcHosts: etcHosts,
|
etcHosts: etcHosts,
|
||||||
|
@ -615,7 +613,7 @@ func (s *Server) prepareInternalDNS() (err error) {
|
||||||
return fmt.Errorf("preparing ipset settings: %w", err)
|
return fmt.Errorf("preparing ipset settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ipsetLogger := s.logger.With(slogutil.KeyPrefix, "ipset")
|
ipsetLogger := s.baseLogger.With(slogutil.KeyPrefix, "ipset")
|
||||||
s.ipset, err = newIpsetHandler(context.TODO(), ipsetLogger, ipsetList)
|
s.ipset, err = newIpsetHandler(context.TODO(), ipsetLogger, ipsetList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
// Don't wrap the error, because it's informative enough as is.
|
||||||
|
@ -685,7 +683,7 @@ func (s *Server) setupAddrProc() {
|
||||||
s.addrProc = client.EmptyAddrProc{}
|
s.addrProc = client.EmptyAddrProc{}
|
||||||
} else {
|
} else {
|
||||||
c := s.conf.AddrProcConf
|
c := s.conf.AddrProcConf
|
||||||
c.BaseLogger = s.logger
|
c.BaseLogger = s.baseLogger
|
||||||
c.DialContext = s.DialContext
|
c.DialContext = s.DialContext
|
||||||
c.PrivateSubnets = s.privateNets
|
c.PrivateSubnets = s.privateNets
|
||||||
c.UsePrivateRDNS = s.conf.UsePrivateRDNS
|
c.UsePrivateRDNS = s.conf.UsePrivateRDNS
|
||||||
|
@ -729,6 +727,7 @@ func validateBlockingMode(
|
||||||
func (s *Server) prepareInternalProxy() (err error) {
|
func (s *Server) prepareInternalProxy() (err error) {
|
||||||
srvConf := s.conf
|
srvConf := s.conf
|
||||||
conf := &proxy.Config{
|
conf := &proxy.Config{
|
||||||
|
Logger: s.baseLogger.With(slogutil.KeyPrefix, "dnsproxy"),
|
||||||
CacheEnabled: true,
|
CacheEnabled: true,
|
||||||
CacheSizeBytes: 4096,
|
CacheSizeBytes: 4096,
|
||||||
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
|
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
|
||||||
|
@ -741,10 +740,6 @@ func (s *Server) prepareInternalProxy() (err error) {
|
||||||
MessageConstructor: s,
|
MessageConstructor: s,
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.logger != nil {
|
|
||||||
conf.Logger = s.logger.With(slogutil.KeyPrefix, "dnsproxy")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
|
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid upstream mode: %w", err)
|
return fmt.Errorf("invalid upstream mode: %w", err)
|
||||||
|
|
|
@ -431,7 +431,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||||
dnsFilter: createTestDNSFilter(t),
|
dnsFilter: createTestDNSFilter(t),
|
||||||
dhcpServer: dhcp,
|
dhcpServer: dhcp,
|
||||||
localDomainSuffix: localDomainSuffix,
|
localDomainSuffix: localDomainSuffix,
|
||||||
logger: slogutil.NewDiscardLogger(),
|
baseLogger: slogutil.NewDiscardLogger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &dns.Msg{
|
req := &dns.Msg{
|
||||||
|
@ -567,7 +567,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||||
dnsFilter: createTestDNSFilter(t),
|
dnsFilter: createTestDNSFilter(t),
|
||||||
dhcpServer: testDHCP,
|
dhcpServer: testDHCP,
|
||||||
localDomainSuffix: tc.suffix,
|
localDomainSuffix: tc.suffix,
|
||||||
logger: slogutil.NewDiscardLogger(),
|
baseLogger: slogutil.NewDiscardLogger(),
|
||||||
}
|
}
|
||||||
|
|
||||||
req := &dns.Msg{
|
req := &dns.Msg{
|
||||||
|
|
|
@ -203,7 +203,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||||
ql := &testQueryLog{}
|
ql := &testQueryLog{}
|
||||||
st := &testStats{}
|
st := &testStats{}
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
logger: slogutil.NewDiscardLogger(),
|
baseLogger: slogutil.NewDiscardLogger(),
|
||||||
queryLog: ql,
|
queryLog: ql,
|
||||||
stats: st,
|
stats: st,
|
||||||
anonymizer: aghnet.NewIPMut(nil),
|
anonymizer: aghnet.NewIPMut(nil),
|
||||||
|
|
|
@ -3,7 +3,6 @@ package rdns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
@ -96,7 +95,7 @@ func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, chan
|
||||||
|
|
||||||
host, ttl, err := r.exchanger.Exchange(ip)
|
host, ttl, err := r.exchanger.Exchange(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.DebugContext(ctx, "resolving ip", "ip", ip, slogutil.KeyError, err)
|
r.logger.DebugContext(ctx, "resolving", "ip", ip, slogutil.KeyError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ttl = max(ttl, r.cacheTTL)
|
ttl = max(ttl, r.cacheTTL)
|
||||||
|
@ -108,7 +107,7 @@ func (r *Default) Process(ctx context.Context, ip netip.Addr) (host string, chan
|
||||||
|
|
||||||
err = r.cache.Set(ip, item)
|
err = r.cache.Set(ip, item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.logger.DebugContext(ctx, "adding item to cache", "item", ip, slogutil.KeyError, err)
|
r.logger.DebugContext(ctx, "adding item to cache", "key", ip, slogutil.KeyError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(e.burkov): The name doesn't change if it's neither stored in cache
|
// TODO(e.burkov): The name doesn't change if it's neither stored in cache
|
||||||
|
@ -125,7 +124,7 @@ func (r *Default) findInCache(ctx context.Context, ip netip.Addr) (host string,
|
||||||
r.logger.DebugContext(
|
r.logger.DebugContext(
|
||||||
ctx,
|
ctx,
|
||||||
"retrieving item from cache",
|
"retrieving item from cache",
|
||||||
"item", ip,
|
"key", ip,
|
||||||
slogutil.KeyError, err,
|
slogutil.KeyError, err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -133,17 +132,7 @@ func (r *Default) findInCache(ctx context.Context, ip netip.Addr) (host string,
|
||||||
return "", true
|
return "", true
|
||||||
}
|
}
|
||||||
|
|
||||||
item, ok := val.(*cacheItem)
|
item := val.(*cacheItem)
|
||||||
if !ok {
|
|
||||||
r.logger.DebugContext(
|
|
||||||
ctx,
|
|
||||||
"bad type of cache item",
|
|
||||||
"item", ip,
|
|
||||||
"type", fmt.Sprintf("%T", val),
|
|
||||||
)
|
|
||||||
|
|
||||||
return "", true
|
|
||||||
}
|
|
||||||
|
|
||||||
return item.host, time.Now().After(item.expiry)
|
return item.host, time.Now().After(item.expiry)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package rdns_test
|
package rdns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -114,14 +113,15 @@ func TestDefault_Process(t *testing.T) {
|
||||||
return revAddr2, time.Hour, nil
|
return revAddr2, time.Hour, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||||
require.EventuallyWithT(t, func(t *assert.CollectT) {
|
require.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||||
got, changed = r.Process(context.TODO(), ip1)
|
got, changed = r.Process(ctx, ip1)
|
||||||
assert.True(t, changed)
|
assert.True(t, changed)
|
||||||
assert.Equal(t, revAddr2, got)
|
assert.Equal(t, revAddr2, got)
|
||||||
}, 2*cacheTTL, time.Millisecond*100)
|
}, 2*cacheTTL, time.Millisecond*100)
|
||||||
|
|
||||||
assert.Never(t, func() (changed bool) {
|
assert.Never(t, func() (changed bool) {
|
||||||
_, changed = r.Process(context.TODO(), ip1)
|
_, changed = r.Process(testutil.ContextWithTimeout(t, testTimeout), ip1)
|
||||||
|
|
||||||
return changed
|
return changed
|
||||||
}, 2*cacheTTL, time.Millisecond*100)
|
}, 2*cacheTTL, time.Millisecond*100)
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/bluele/gcache"
|
"github.com/bluele/gcache"
|
||||||
|
"github.com/c2h5oh/datasize"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -240,9 +241,9 @@ func (w *Default) query(ctx context.Context, target, serverAddr string) (data []
|
||||||
// queryAll queries WHOIS server and handles redirects.
|
// queryAll queries WHOIS server and handles redirects.
|
||||||
func (w *Default) queryAll(ctx context.Context, target string) (info map[string]string, err error) {
|
func (w *Default) queryAll(ctx context.Context, target string) (info map[string]string, err error) {
|
||||||
server := net.JoinHostPort(w.serverAddr, w.portStr)
|
server := net.JoinHostPort(w.serverAddr, w.portStr)
|
||||||
var data []byte
|
|
||||||
|
|
||||||
for range w.maxRedirects {
|
for range w.maxRedirects {
|
||||||
|
var data []byte
|
||||||
data, err = w.query(ctx, target, server)
|
data, err = w.query(ctx, target, server)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error since it's informative enough as is.
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
@ -252,7 +253,7 @@ func (w *Default) queryAll(ctx context.Context, target string) (info map[string]
|
||||||
w.logger.DebugContext(
|
w.logger.DebugContext(
|
||||||
ctx,
|
ctx,
|
||||||
"received response",
|
"received response",
|
||||||
"size", len(data),
|
"size", datasize.ByteSize(len(data)),
|
||||||
"source", server,
|
"source", server,
|
||||||
"target", target,
|
"target", target,
|
||||||
)
|
)
|
||||||
|
@ -315,7 +316,7 @@ func (w *Default) requestInfo(
|
||||||
item := toCacheItem(info, w.cacheTTL)
|
item := toCacheItem(info, w.cacheTTL)
|
||||||
err := w.cache.Set(ip, item)
|
err := w.cache.Set(ip, item)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.logger.DebugContext(ctx, "adding item to cache", "item", ip, slogutil.KeyError, err)
|
w.logger.DebugContext(ctx, "adding item to cache", "key", ip, slogutil.KeyError, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -350,7 +351,7 @@ func (w *Default) findInCache(ctx context.Context, ip netip.Addr) (wi *Info, exp
|
||||||
w.logger.DebugContext(
|
w.logger.DebugContext(
|
||||||
ctx,
|
ctx,
|
||||||
"retrieving item from cache",
|
"retrieving item from cache",
|
||||||
"item", ip,
|
"key", ip,
|
||||||
slogutil.KeyError, err,
|
slogutil.KeyError, err,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -358,19 +359,7 @@ func (w *Default) findInCache(ctx context.Context, ip netip.Addr) (wi *Info, exp
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
item, ok := val.(*cacheItem)
|
return fromCacheItem(val.(*cacheItem))
|
||||||
if !ok {
|
|
||||||
w.logger.DebugContext(
|
|
||||||
ctx,
|
|
||||||
"bad type of cache item",
|
|
||||||
"item", ip,
|
|
||||||
"type", fmt.Sprintf("%T", val),
|
|
||||||
)
|
|
||||||
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return fromCacheItem(item)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Info is the filtered WHOIS data for a runtime client.
|
// Info is the filtered WHOIS data for a runtime client.
|
||||||
|
|
Loading…
Reference in New Issue