Pull request 2159: Upd proxy

Squashed commit of the following:

commit 4468e826bcf8e856014059cac7cd83aa6d754ab0
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Mar 11 13:50:36 2024 +0300

    all: upd dnsproxy

commit 7887f521799b75ea43aac8eca0586d26a1287c94
Merge: 9120da6ab 36f9fecae
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Mar 11 13:50:09 2024 +0300

    Merge branch 'master' into upd-proxy

commit 9120da6ab841fc8f9090c16f0cfe6a103073c772
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Mar 5 11:50:16 2024 +0300

    all: upd proxy
This commit is contained in:
Eugene Burkov 2024-03-11 18:17:04 +03:00
parent 36f9fecaed
commit 28a6b9f303
12 changed files with 319 additions and 278 deletions

2
go.mod
View File

@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
go 1.21.8 go 1.21.8
require ( require (
github.com/AdguardTeam/dnsproxy v0.65.2 github.com/AdguardTeam/dnsproxy v0.66.0
github.com/AdguardTeam/golibs v0.20.1 github.com/AdguardTeam/golibs v0.20.1
github.com/AdguardTeam/urlfilter v0.18.0 github.com/AdguardTeam/urlfilter v0.18.0
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1

4
go.sum
View File

@ -1,5 +1,5 @@
github.com/AdguardTeam/dnsproxy v0.65.2 h1:D+BMw0Vu2lbQrYpoPctG2Xr+24KdfhgkzZb6QgPZheM= github.com/AdguardTeam/dnsproxy v0.66.0 h1:RyUbyDxRSXBFjVG1l2/4HV3I98DtfIgpnZkgXkgHKnc=
github.com/AdguardTeam/dnsproxy v0.65.2/go.mod h1:8NQTTNZY+qR9O1Fzgz3WQv30knfSgms68SRlzSnX74A= github.com/AdguardTeam/dnsproxy v0.66.0/go.mod h1:ZThEXbMUlP1RxfwtNW30ItPAHE6OF4YFygK8qjU/cvY=
github.com/AdguardTeam/golibs v0.20.1 h1:ol8qLjWGZhU9paMMwN+OLWVTUigGsXa29iVTyd62VKY= github.com/AdguardTeam/golibs v0.20.1 h1:ol8qLjWGZhU9paMMwN+OLWVTUigGsXa29iVTyd62VKY=
github.com/AdguardTeam/golibs v0.20.1/go.mod h1:bgcMgRviCKyU6mkrX+RtT/OsKPFzyppelfRsksMG3KU= github.com/AdguardTeam/golibs v0.20.1/go.mod h1:bgcMgRviCKyU6mkrX+RtT/OsKPFzyppelfRsksMG3KU=
github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ= github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ=

View File

@ -9,8 +9,13 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"testing" "testing"
"time"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -71,3 +76,49 @@ func StartHTTPServer(t testing.TB, data []byte) (c *http.Client, u *url.URL) {
return srv.Client(), u return srv.Client(), u
} }
// testTimeout is a timeout for tests.
//
// TODO(e.burkov): Move into agdctest.
const testTimeout = 1 * time.Second
// StartLocalhostUpstream is a test helper that starts a DNS server on
// localhost.
func StartLocalhostUpstream(t *testing.T, h dns.Handler) (addr *url.URL) {
t.Helper()
startCh := make(chan netip.AddrPort)
defer close(startCh)
errCh := make(chan error)
srv := &dns.Server{
Addr: "127.0.0.1:0",
Net: string(proxy.ProtoTCP),
Handler: h,
ReadTimeout: testTimeout,
WriteTimeout: testTimeout,
}
srv.NotifyStartedFunc = func() {
addrPort := srv.Listener.Addr()
startCh <- netutil.NetAddrToAddrPort(addrPort)
}
go func() { errCh <- srv.ListenAndServe() }()
select {
case addrPort := <-startCh:
addr = &url.URL{
Scheme: string(proxy.ProtoTCP),
Host: addrPort.String(),
}
testutil.CleanupAndRequireSuccess(t, func() (err error) { return <-errCh })
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)
case err := <-errCh:
require.NoError(t, err)
case <-time.After(testTimeout):
require.FailNow(t, "timeout exceeded")
}
return addr
}

View File

@ -357,10 +357,6 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
conf.DNSCryptResolverCert = c.ResolverCert conf.DNSCryptResolverCert = c.ResolverCert
} }
if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 {
return nil, errors.Error("no default upstream servers configured")
}
conf, err = prepareCacheConfig(conf, conf, err = prepareCacheConfig(conf,
srvConf.CacheSize, srvConf.CacheSize,
srvConf.CacheMinTTL, srvConf.CacheMinTTL,

View File

@ -8,7 +8,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -101,21 +100,6 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
type answerMap = map[uint16][sectionsNum][]dns.RR type answerMap = map[uint16][sectionsNum][]dns.RR
pt := testutil.PanicT{} pt := testutil.PanicT{}
newUps := func(answers answerMap) (u upstream.Upstream) {
return aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
q := req.Question[0]
require.Contains(pt, answers, q.Qtype)
answer := answers[q.Qtype]
resp = (&dns.Msg{}).SetReply(req)
resp.Answer = answer[sectionAnswer]
resp.Ns = answer[sectionAuthority]
resp.Extra = answer[sectionAdditional]
return resp, nil
})
}
testCases := []struct { testCases := []struct {
name string name string
@ -265,13 +249,16 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
}} }}
localRR := newRR(t, ptr64Domain, dns.TypePTR, 3600, pointedDomain) localRR := newRR(t, ptr64Domain, dns.TypePTR, 3600, pointedDomain)
localUps := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
require.Equal(pt, req.Question[0].Name, ptr64Domain) require.Len(pt, m.Question, 1)
resp = (&dns.Msg{}).SetReply(req) require.Equal(pt, m.Question[0].Name, ptr64Domain)
resp.Answer = []dns.RR{localRR} resp := (&dns.Msg{
Answer: []dns.RR{localRR},
}).SetReply(m)
return resp, nil require.NoError(t, w.WriteMsg(resp))
}) })
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
client := &dns.Client{ client := &dns.Client{
Net: "tcp", Net: "tcp",
@ -279,25 +266,44 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
} }
for _, tc := range testCases { for _, tc := range testCases {
// TODO(e.burkov): It seems [proxy.Proxy] isn't intended to be reused tc := tc
// right after stop, due to a data race in [proxy.Proxy.Init] method
// when setting an OOB size. As a temporary workaround, recreate the
// whole server for each test case.
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
}, localUps)
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newUps(tc.upsAns)} upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0]
require.Contains(pt, tc.upsAns, q.Qtype)
answer := tc.upsAns[q.Qtype]
resp := (&dns.Msg{
Answer: answer[sectionAnswer],
Ns: answer[sectionAuthority],
Extra: answer[sectionAdditional],
}).SetReply(req)
require.NoError(pt, w.WriteMsg(resp))
})
upsAddr := aghtest.StartLocalhostUpstream(t, upsHdlr).String()
// TODO(e.burkov): It seems [proxy.Proxy] isn't intended to be
// reused right after stop, due to a data race in [proxy.Proxy.Init]
// method when setting an OOB size. As a temporary workaround,
// recreate the whole server for each test case.
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
UpstreamDNS: []string{upsAddr},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
})
startDeferStop(t, s) startDeferStop(t, s)
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype) req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)

View File

@ -464,7 +464,8 @@ func (s *Server) Start() error {
// startLocked starts the DNS server without locking. s.serverLock is expected // startLocked starts the DNS server without locking. s.serverLock is expected
// to be locked. // to be locked.
func (s *Server) startLocked() error { func (s *Server) startLocked() error {
err := s.dnsProxy.Start() // TODO(e.burkov): Use context properly.
err := s.dnsProxy.Start(context.Background())
if err == nil { if err == nil {
s.isRunning = true s.isRunning = true
} }
@ -518,34 +519,30 @@ func (s *Server) prepareLocalResolvers(
} }
// setupLocalResolvers initializes and sets the resolvers for local addresses. // setupLocalResolvers initializes and sets the resolvers for local addresses.
// It assumes s.serverLock is locked or s not running. // It assumes s.serverLock is locked or s not running. It returns the upstream
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) { // configuration used for private PTR resolving, or nil if it's disabled. Note,
uc, err := s.prepareLocalResolvers(boot) // that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
return nil, nil
}
uc, err = s.prepareLocalResolvers(boot)
if err != nil { if err != nil {
// Don't wrap the error because it's informative enough as is. // Don't wrap the error because it's informative enough as is.
return err return nil, err
} }
s.localResolvers = &proxy.Proxy{ s.localResolvers, err = proxy.New(&proxy.Config{
Config: proxy.Config{ UpstreamConfig: uc,
UpstreamConfig: uc, })
},
}
err = s.localResolvers.Init()
if err != nil { if err != nil {
return fmt.Errorf("initializing proxy: %w", err) return nil, fmt.Errorf("creating local resolvers: %w", err)
} }
// TODO(e.burkov): Should we also consider the DNS64 usage? // TODO(e.burkov): Should we also consider the DNS64 usage?
if s.conf.UsePrivateRDNS && return uc, nil
// Only set the upstream config if there are any upstreams. It's safe
// to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
len(uc.Upstreams)+len(uc.DomainReservedUpstreams)+len(uc.SpecifiedDomainUpstreams) > 0 {
s.dnsProxy.PrivateRDNSUpstreamConfig = uc
}
return nil
} }
// Prepare initializes parameters of s using data from conf. conf must not be // Prepare initializes parameters of s using data from conf. conf must not be
@ -586,21 +583,22 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return fmt.Errorf("preparing access: %w", err) return fmt.Errorf("preparing access: %w", err)
} }
// Set the proxy here because [setupLocalResolvers] sets its values.
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy. // TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
s.dnsProxy = &proxy.Proxy{Config: *proxyConfig} proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
err = s.setupLocalResolvers(boot)
if err != nil { if err != nil {
return fmt.Errorf("setting up resolvers: %w", err) return fmt.Errorf("setting up resolvers: %w", err)
} }
err = s.setupFallbackDNS() proxyConfig.Fallbacks, err = s.setupFallbackDNS()
if err != nil { if err != nil {
return fmt.Errorf("setting up fallback dns servers: %w", err) return fmt.Errorf("setting up fallback dns servers: %w", err)
} }
s.dnsProxy, err = proxy.New(proxyConfig)
if err != nil {
return fmt.Errorf("creating proxy: %w", err)
}
s.recDetector.clear() s.recDetector.clear()
s.setupAddrProc() s.setupAddrProc()
@ -643,26 +641,25 @@ func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
} }
// setupFallbackDNS initializes the fallback DNS servers. // setupFallbackDNS initializes the fallback DNS servers.
func (s *Server) setupFallbackDNS() (err error) { func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) {
fallbacks := s.conf.FallbackDNS fallbacks := s.conf.FallbackDNS
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty) fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty)
if len(fallbacks) == 0 { if len(fallbacks) == 0 {
return nil return nil, nil
} }
uc, err := proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{ uc, err = proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{
// TODO(s.chzhen): Investigate if other options are needed. // TODO(s.chzhen): Investigate if other options are needed.
Timeout: s.conf.UpstreamTimeout, Timeout: s.conf.UpstreamTimeout,
PreferIPv6: s.conf.BootstrapPreferIPv6, PreferIPv6: s.conf.BootstrapPreferIPv6,
// TODO(e.burkov): Use bootstrap.
}) })
if err != nil { if err != nil {
// Do not wrap the error because it's informative enough as is. // Do not wrap the error because it's informative enough as is.
return err return nil, err
} }
s.dnsProxy.Fallbacks = uc return uc, nil
return nil
} }
// setupAddrProc initializes the address processor. It assumes s.serverLock is // setupAddrProc initializes the address processor. It assumes s.serverLock is
@ -730,19 +727,9 @@ func (s *Server) prepareInternalProxy() (err error) {
return fmt.Errorf("invalid upstream mode: %w", err) return fmt.Errorf("invalid upstream mode: %w", err)
} }
// TODO(a.garipov): Make a proper constructor for proxy.Proxy. s.internalProxy, err = proxy.New(conf)
p := &proxy.Proxy{
Config: *conf,
}
err = p.Init() return err
if err != nil {
return err
}
s.internalProxy = p
return nil
} }
// Stop stops the DNS server. // Stop stops the DNS server.
@ -761,14 +748,17 @@ func (s *Server) stopLocked() (err error) {
// [upstream.Upstream] implementations. // [upstream.Upstream] implementations.
if s.dnsProxy != nil { if s.dnsProxy != nil {
err = s.dnsProxy.Stop() // TODO(e.burkov): Use context properly.
err = s.dnsProxy.Shutdown(context.Background())
if err != nil { if err != nil {
log.Error("dnsforward: closing primary resolvers: %s", err) log.Error("dnsforward: closing primary resolvers: %s", err)
} }
} }
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s") logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s") if s.localResolvers != nil {
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
}
for _, b := range s.bootResolvers { for _, b := range s.bootResolvers {
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address()) logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())

View File

@ -5,9 +5,11 @@ import (
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/sha256"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/hex"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"math/big" "math/big"
@ -63,8 +65,7 @@ func startDeferStop(t *testing.T, s *Server) {
t.Helper() t.Helper()
err := s.Start() err := s.Start()
require.NoErrorf(t, err, "failed to start server: %s", err) require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, s.Stop) testutil.CleanupAndRequireSuccess(t, s.Stop)
} }
@ -72,7 +73,6 @@ func createTestServer(
t *testing.T, t *testing.T,
filterConf *filtering.Config, filterConf *filtering.Config,
forwardConf ServerConfig, forwardConf ServerConfig,
localUps upstream.Upstream,
) (s *Server) { ) (s *Server) {
t.Helper() t.Helper()
@ -82,7 +82,8 @@ func createTestServer(
@@||whitelist.example.org^ @@||whitelist.example.org^
||127.0.0.255` ||127.0.0.255`
filters := []filtering.Filter{{ filters := []filtering.Filter{{
ID: 0, Data: []byte(rules), ID: 0,
Data: []byte(rules),
}} }}
f, err := filtering.New(filterConf, filters) f, err := filtering.New(filterConf, filters)
@ -105,19 +106,6 @@ func createTestServer(
err = s.Prepare(&forwardConf) err = s.Prepare(&forwardConf)
require.NoError(t, err) require.NoError(t, err)
s.serverLock.Lock()
defer s.serverLock.Unlock()
// TODO(e.burkov): Try to move it higher.
if localUps != nil {
ups := []upstream.Upstream{localUps}
s.localResolvers.UpstreamConfig.Upstreams = ups
s.conf.UsePrivateRDNS = true
s.dnsProxy.PrivateRDNSUpstreamConfig = &proxy.UpstreamConfig{
Upstreams: ups,
}
}
return s return s
} }
@ -181,7 +169,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}, nil) })
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
s.conf.TLSConfig = tlsConf s.conf.TLSConfig = tlsConf
@ -310,7 +298,7 @@ func TestServer(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}, nil) })
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()} s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s) startDeferStop(t, s)
@ -410,7 +398,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}, nil) })
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()} s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s) startDeferStop(t, s)
@ -490,7 +478,7 @@ func TestServerRace(t *testing.T) {
ConfigModified: func() {}, ConfigModified: func() {},
ServePlainDNS: true, ServePlainDNS: true,
} }
s := createTestServer(t, filterConf, forwardConf, nil) s := createTestServer(t, filterConf, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()} s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
startDeferStop(t, s) startDeferStop(t, s)
@ -545,7 +533,7 @@ func TestSafeSearch(t *testing.T) {
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
s := createTestServer(t, filterConf, forwardConf, nil) s := createTestServer(t, filterConf, forwardConf)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
@ -628,7 +616,7 @@ func TestInvalidRequest(t *testing.T) {
}, },
}, },
ServePlainDNS: true, ServePlainDNS: true,
}, nil) })
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
@ -662,7 +650,7 @@ func TestBlockedRequest(t *testing.T) {
s := createTestServer(t, &filtering.Config{ s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true, ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault, BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil) }, forwardConf)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -698,7 +686,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
} }
s := createTestServer(t, &filtering.Config{ s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault, BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil) }, forwardConf)
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
atomic.AddUint32(&upsCalledCounter, 1) atomic.AddUint32(&upsCalledCounter, 1)
@ -773,7 +761,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
}, },
}, },
ServePlainDNS: true, ServePlainDNS: true,
}, nil) })
testUpstm := &aghtest.Upstream{ testUpstm := &aghtest.Upstream{
CName: testCNAMEs, CName: testCNAMEs,
IPv4: testIPv4, IPv4: testIPv4,
@ -811,7 +799,7 @@ func TestBlockCNAME(t *testing.T) {
s := createTestServer(t, &filtering.Config{ s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true, ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault, BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil) }, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{ &aghtest.Upstream{
CName: testCNAMEs, CName: testCNAMEs,
@ -886,7 +874,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
} }
s := createTestServer(t, &filtering.Config{ s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault, BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil) }, forwardConf)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
&aghtest.Upstream{ &aghtest.Upstream{
CName: testCNAMEs, CName: testCNAMEs,
@ -933,7 +921,7 @@ func TestNullBlockedRequest(t *testing.T) {
s := createTestServer(t, &filtering.Config{ s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true, ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeNullIP, BlockingMode: filtering.BlockingModeNullIP,
}, forwardConf, nil) }, forwardConf)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -1054,7 +1042,7 @@ func TestBlockedByHosts(t *testing.T) {
s := createTestServer(t, &filtering.Config{ s := createTestServer(t, &filtering.Config{
ProtectionEnabled: true, ProtectionEnabled: true,
BlockingMode: filtering.BlockingModeDefault, BlockingMode: filtering.BlockingModeDefault,
}, forwardConf, nil) }, forwardConf)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -1102,7 +1090,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
}, },
ServePlainDNS: true, ServePlainDNS: true,
} }
s := createTestServer(t, filterConf, forwardConf, nil) s := createTestServer(t, filterConf, forwardConf)
startDeferStop(t, s) startDeferStop(t, s)
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -1482,6 +1470,8 @@ func TestServer_Exchange(t *testing.T) {
onesIP = netip.MustParseAddr("1.1.1.1") onesIP = netip.MustParseAddr("1.1.1.1")
twosIP = netip.MustParseAddr("2.2.2.2") twosIP = netip.MustParseAddr("2.2.2.2")
localIP = netip.MustParseAddr("192.168.1.1") localIP = netip.MustParseAddr("192.168.1.1")
pt = testutil.PanicT{}
) )
onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice()) onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
@ -1490,72 +1480,73 @@ func TestServer_Exchange(t *testing.T) {
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice()) twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
require.NoError(t, err) require.NoError(t, err)
extUpstream := &aghtest.UpstreamMock{ extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
OnAddress: func() (addr string) { return "external.upstream.example" }, resp := aghalg.Coalesce(
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)),
return aghalg.Coalesce( doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))),
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost), new(dns.Msg).SetRcode(req, dns.RcodeNameError),
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)), )
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil require.NoError(pt, w.WriteMsg(resp))
}, })
} upsAddr := aghtest.StartLocalhostUpstream(t, extUpsHdlr).String()
revLocIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice()) revLocIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
require.NoError(t, err) require.NoError(t, err)
locUpstream := &aghtest.UpstreamMock{ locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
OnAddress: func() (addr string) { return "local.upstream.example" }, resp := aghalg.Coalesce(
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)),
return aghalg.Coalesce( new(dns.Msg).SetRcode(req, dns.RcodeNameError),
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, localDomainHost), )
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
},
}
errUpstream := aghtest.NewErrorUpstream() require.NoError(pt, w.WriteMsg(resp))
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
refusingUpstream := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return new(dns.Msg).SetRcode(req, dns.RcodeRefused), nil
}) })
zeroTTLUps := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "zero.ttl.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = new(dns.Msg).SetReply(req)
hdr := dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 0,
}
resp.Answer = []dns.RR{&dns.PTR{
Hdr: hdr,
Ptr: localDomainHost,
}}
return resp, nil errUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
}, require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeServerFailure)))
} })
srv := &Server{ nonPtrHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
recDetector: newRecursionDetector(0, 1), hash := sha256.Sum256([]byte("some-host"))
internalProxy: &proxy.Proxy{ resp := (&dns.Msg{
Config: proxy.Config{ Answer: []dns.RR{&dns.TXT{
UpstreamConfig: &proxy.UpstreamConfig{ Hdr: dns.RR_Header{
Upstreams: []upstream.Upstream{extUpstream}, Name: req.Question[0].Name,
Rrtype: dns.TypeTXT,
Class: dns.ClassINET,
Ttl: 60,
}, },
}, Txt: []string{hex.EncodeToString(hash[:])},
}, }},
} }).SetReply(req)
srv.conf.UsePrivateRDNS = true
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) require.NoError(pt, w.WriteMsg(resp))
require.NoError(t, srv.internalProxy.Init()) })
refusingHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeRefused)))
})
zeroTTLHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := (&dns.Msg{
Answer: []dns.RR{&dns.PTR{
Hdr: dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 0,
},
Ptr: dns.Fqdn(localDomainHost),
}},
}).SetReply(req)
require.NoError(pt, w.WriteMsg(resp))
})
testCases := []struct { testCases := []struct {
req netip.Addr req netip.Addr
wantErr error wantErr error
locUpstream upstream.Upstream locUpstream dns.Handler
name string name string
want string want string
wantTTL time.Duration wantTTL time.Duration
@ -1570,35 +1561,35 @@ func TestServer_Exchange(t *testing.T) {
name: "local_good", name: "local_good",
want: localDomainHost, want: localDomainHost,
wantErr: nil, wantErr: nil,
locUpstream: locUpstream, locUpstream: locUpsHdlr,
req: localIP, req: localIP,
wantTTL: defaultTTL, wantTTL: defaultTTL,
}, { }, {
name: "upstream_error", name: "upstream_error",
want: "", want: "",
wantErr: aghtest.ErrUpstream, wantErr: ErrRDNSFailed,
locUpstream: errUpstream, locUpstream: errUpsHdlr,
req: localIP, req: localIP,
wantTTL: 0, wantTTL: 0,
}, { }, {
name: "empty_answer_error", name: "empty_answer_error",
want: "", want: "",
wantErr: ErrRDNSNoData, wantErr: ErrRDNSNoData,
locUpstream: locUpstream, locUpstream: locUpsHdlr,
req: netip.MustParseAddr("192.168.1.2"), req: netip.MustParseAddr("192.168.1.2"),
wantTTL: 0, wantTTL: 0,
}, { }, {
name: "invalid_answer", name: "invalid_answer",
want: "", want: "",
wantErr: ErrRDNSNoData, wantErr: ErrRDNSNoData,
locUpstream: nonPtrUpstream, locUpstream: nonPtrHdlr,
req: localIP, req: localIP,
wantTTL: 0, wantTTL: 0,
}, { }, {
name: "refused", name: "refused",
want: "", want: "",
wantErr: ErrRDNSFailed, wantErr: ErrRDNSFailed,
locUpstream: refusingUpstream, locUpstream: refusingHdlr,
req: localIP, req: localIP,
wantTTL: 0, wantTTL: 0,
}, { }, {
@ -1612,23 +1603,28 @@ func TestServer_Exchange(t *testing.T) {
name: "zero_ttl", name: "zero_ttl",
want: localDomainHost, want: localDomainHost,
wantErr: nil, wantErr: nil,
locUpstream: zeroTTLUps, locUpstream: zeroTTLHdlr,
req: localIP, req: localIP,
wantTTL: 0, wantTTL: 0,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
pcfg := proxy.Config{ localUpsAddr := aghtest.StartLocalhostUpstream(t, tc.locUpstream).String()
UpstreamConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{tc.locUpstream},
},
}
srv.localResolvers = &proxy.Proxy{
Config: pcfg,
}
require.NoError(t, srv.localResolvers.Init())
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
srv := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
UpstreamDNS: []string{upsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
LocalPTRResolvers: []string{localUpsAddr},
UsePrivateRDNS: true,
ServePlainDNS: true,
})
host, ttl, eerr := srv.Exchange(tc.req) host, ttl, eerr := srv.Exchange(tc.req)
require.ErrorIs(t, eerr, tc.wantErr) require.ErrorIs(t, eerr, tc.wantErr)
@ -1638,8 +1634,17 @@ func TestServer_Exchange(t *testing.T) {
} }
t.Run("resolving_disabled", func(t *testing.T) { t.Run("resolving_disabled", func(t *testing.T) {
srv.conf.UsePrivateRDNS = false srv := createTestServer(t, &filtering.Config{
t.Cleanup(func() { srv.conf.UsePrivateRDNS = true }) BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
UpstreamDNS: []string{upsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
LocalPTRResolvers: []string{},
ServePlainDNS: true,
})
host, _, eerr := srv.Exchange(localIP) host, _, eerr := srv.Exchange(localIP)

View File

@ -42,7 +42,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}, nil) })
makeQ := func(qtype rules.RRType) (req *dns.Msg) { makeQ := func(qtype rules.RRType) (req *dns.Msg) {
return &dns.Msg{ return &dns.Msg{

View File

@ -83,7 +83,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
ConfigModified: func() {}, ConfigModified: func() {},
ServePlainDNS: true, ServePlainDNS: true,
} }
s := createTestServer(t, filterConf, forwardConf, nil) s := createTestServer(t, filterConf, forwardConf)
s.sysResolvers = &emptySysResolvers{} s.sysResolvers = &emptySysResolvers{}
require.NoError(t, s.Start()) require.NoError(t, s.Start())
@ -164,7 +164,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
ConfigModified: func() {}, ConfigModified: func() {},
ServePlainDNS: true, ServePlainDNS: true,
} }
s := createTestServer(t, filterConf, forwardConf, nil) s := createTestServer(t, filterConf, forwardConf)
s.sysResolvers = &emptySysResolvers{} s.sysResolvers = &emptySysResolvers{}
defaultConf := s.conf defaultConf := s.conf
@ -439,7 +439,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}, nil) })
srv.etcHosts = upstream.NewHostsResolver(hc) srv.etcHosts = upstream.NewHostsResolver(hc)
startDeferStop(t, srv) startDeferStop(t, srv)

View File

@ -9,7 +9,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
@ -87,7 +86,7 @@ func TestServer_ProcessInitial(t *testing.T) {
s := createTestServer(t, &filtering.Config{ s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault, BlockingMode: filtering.BlockingModeDefault,
}, c, nil) }, c)
var gotAddr netip.Addr var gotAddr netip.Addr
s.addrProc = &aghtest.AddressProcessor{ s.addrProc = &aghtest.AddressProcessor{
@ -188,7 +187,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
s := createTestServer(t, &filtering.Config{ s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault, BlockingMode: filtering.BlockingModeDefault,
}, c, nil) }, c)
resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns) resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns)
dctx := &dnsContext{ dctx := &dnsContext{
@ -248,9 +247,9 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host string host string
want []*dns.SVCB want []*dns.SVCB
wantRes resultCode wantRes resultCode
portDoH int addrsDoH []*net.TCPAddr
portDoT int addrsDoT []*net.TCPAddr
portDoQ int addrsDoQ []*net.UDPAddr
qtype uint16 qtype uint16
ddrEnabled bool ddrEnabled bool
}{{ }{{
@ -259,14 +258,14 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: testQuestionTarget, host: testQuestionTarget,
qtype: dns.TypeSVCB, qtype: dns.TypeSVCB,
ddrEnabled: true, ddrEnabled: true,
portDoH: 8043, addrsDoH: []*net.TCPAddr{{Port: 8043}},
}, { }, {
name: "pass_qtype", name: "pass_qtype",
wantRes: resultCodeFinish, wantRes: resultCodeFinish,
host: ddrHostFQDN, host: ddrHostFQDN,
qtype: dns.TypeA, qtype: dns.TypeA,
ddrEnabled: true, ddrEnabled: true,
portDoH: 8043, addrsDoH: []*net.TCPAddr{{Port: 8043}},
}, { }, {
name: "pass_disabled_tls", name: "pass_disabled_tls",
wantRes: resultCodeFinish, wantRes: resultCodeFinish,
@ -279,7 +278,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: ddrHostFQDN, host: ddrHostFQDN,
qtype: dns.TypeSVCB, qtype: dns.TypeSVCB,
ddrEnabled: false, ddrEnabled: false,
portDoH: 8043, addrsDoH: []*net.TCPAddr{{Port: 8043}},
}, { }, {
name: "dot", name: "dot",
wantRes: resultCodeFinish, wantRes: resultCodeFinish,
@ -287,7 +286,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: ddrHostFQDN, host: ddrHostFQDN,
qtype: dns.TypeSVCB, qtype: dns.TypeSVCB,
ddrEnabled: true, ddrEnabled: true,
portDoT: 8043, addrsDoT: []*net.TCPAddr{{Port: 8043}},
}, { }, {
name: "doh", name: "doh",
wantRes: resultCodeFinish, wantRes: resultCodeFinish,
@ -295,7 +294,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: ddrHostFQDN, host: ddrHostFQDN,
qtype: dns.TypeSVCB, qtype: dns.TypeSVCB,
ddrEnabled: true, ddrEnabled: true,
portDoH: 8044, addrsDoH: []*net.TCPAddr{{Port: 8044}},
}, { }, {
name: "doq", name: "doq",
wantRes: resultCodeFinish, wantRes: resultCodeFinish,
@ -303,7 +302,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: ddrHostFQDN, host: ddrHostFQDN,
qtype: dns.TypeSVCB, qtype: dns.TypeSVCB,
ddrEnabled: true, ddrEnabled: true,
portDoQ: 8042, addrsDoQ: []*net.UDPAddr{{Port: 8042}},
}, { }, {
name: "dot_doh", name: "dot_doh",
wantRes: resultCodeFinish, wantRes: resultCodeFinish,
@ -311,13 +310,35 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
host: ddrHostFQDN, host: ddrHostFQDN,
qtype: dns.TypeSVCB, qtype: dns.TypeSVCB,
ddrEnabled: true, ddrEnabled: true,
portDoT: 8043, addrsDoT: []*net.TCPAddr{{Port: 8043}},
portDoH: 8044, addrsDoH: []*net.TCPAddr{{Port: 8044}},
}} }}
_, certPem, keyPem := createServerTLSConfig(t)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled) s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
Config: Config{
HandleDDR: tc.ddrEnabled,
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
TLSConfig: TLSConfig{
ServerName: ddrTestDomainName,
CertificateChainData: certPem,
PrivateKeyData: keyPem,
TLSListenAddrs: tc.addrsDoT,
HTTPSListenAddrs: tc.addrsDoH,
QUICListenAddrs: tc.addrsDoQ,
},
ServePlainDNS: true,
})
// TODO(e.burkov): Generate a certificate actually containing the
// IP addresses.
s.conf.hasIPAddrs = true
req := createTestMessageWithType(tc.host, tc.qtype) req := createTestMessageWithType(tc.host, tc.qtype)
@ -358,41 +379,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
return f return f
} }
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
t.Helper()
s = &Server{
dnsFilter: createTestDNSFilter(t),
dnsProxy: &proxy.Proxy{
Config: proxy.Config{},
},
conf: ServerConfig{
Config: Config{
HandleDDR: ddrEnabled,
},
TLSConfig: TLSConfig{
ServerName: ddrTestDomainName,
},
ServePlainDNS: true,
},
}
if portDoT > 0 {
s.dnsProxy.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
s.conf.hasIPAddrs = true
}
if portDoQ > 0 {
s.dnsProxy.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
}
if portDoH > 0 {
s.conf.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
}
return s
}
func TestServer_ProcessDetermineLocal(t *testing.T) { func TestServer_ProcessDetermineLocal(t *testing.T) {
s := &Server{ s := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
@ -680,13 +666,16 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
intPTRAnswer = "some.local-client." intPTRAnswer = "some.local-client."
) )
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
return aghalg.Coalesce( resp := aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil )
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
}) })
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
s := createTestServer(t, &filtering.Config{ s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault, BlockingMode: filtering.BlockingModeDefault,
@ -696,12 +685,14 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true. // TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
// Improve Config declaration for tests. // Improve Config declaration for tests.
Config: Config{ Config: Config{
UpstreamDNS: []string{localUpsAddr},
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
}, },
ServePlainDNS: true, UsePrivateRDNS: true,
}, ups) LocalPTRResolvers: []string{localUpsAddr},
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups} ServePlainDNS: true,
})
startDeferStop(t, s) startDeferStop(t, s)
testCases := []struct { testCases := []struct {
@ -764,6 +755,16 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
const locDomain = "some.local." const locDomain = "some.local."
const reqAddr = "1.1.168.192.in-addr.arpa." const reqAddr = "1.1.168.192.in-addr.arpa."
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
s := createTestServer( s := createTestServer(
t, t,
&filtering.Config{ &filtering.Config{
@ -776,14 +777,10 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
UpstreamMode: UpstreamModeLoadBalance, UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
}, },
ServePlainDNS: true, UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
}, },
aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
}),
) )
var proxyCtx *proxy.DNSContext var proxyCtx *proxy.DNSContext

View File

@ -21,7 +21,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
}, },
ServePlainDNS: true, ServePlainDNS: true,
}, nil) })
req := &dns.Msg{ req := &dns.Msg{
Question: []dns.Question{{ Question: []dns.Question{{

View File

@ -67,19 +67,15 @@ func New(c *Config) (svc *Service, err error) {
} }
svc.bootstrapResolvers = resolvers svc.bootstrapResolvers = resolvers
svc.proxy = &proxy.Proxy{ svc.proxy, err = proxy.New(&proxy.Config{
Config: proxy.Config{ UDPListenAddr: udpAddrs(c.Addresses),
UDPListenAddr: udpAddrs(c.Addresses), TCPListenAddr: tcpAddrs(c.Addresses),
TCPListenAddr: tcpAddrs(c.Addresses), UpstreamConfig: &proxy.UpstreamConfig{
UpstreamConfig: &proxy.UpstreamConfig{ Upstreams: upstreams,
Upstreams: upstreams,
},
UseDNS64: c.UseDNS64,
DNS64Prefs: c.DNS64Prefixes,
}, },
} UseDNS64: c.UseDNS64,
DNS64Prefs: c.DNS64Prefixes,
err = svc.proxy.Init() })
if err != nil { if err != nil {
return nil, fmt.Errorf("proxy: %w", err) return nil, fmt.Errorf("proxy: %w", err)
} }
@ -174,7 +170,7 @@ func (svc *Service) Start() (err error) {
svc.running.Store(err == nil) svc.running.Store(err == nil)
}() }()
return svc.proxy.Start() return svc.proxy.Start(context.Background())
} }
// Shutdown implements the [agh.Service] interface for *Service. svc may be // Shutdown implements the [agh.Service] interface for *Service. svc may be
@ -185,7 +181,7 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
} }
errs := []error{ errs := []error{
svc.proxy.Stop(), svc.proxy.Shutdown(ctx),
} }
for _, b := range svc.bootstrapResolvers { for _, b := range svc.bootstrapResolvers {