diff --git a/internal/aghnet/exchanger.go b/internal/aghnet/exchanger.go deleted file mode 100644 index c148e290..00000000 --- a/internal/aghnet/exchanger.go +++ /dev/null @@ -1,86 +0,0 @@ -package aghnet - -import ( - "time" - - "github.com/AdguardTeam/AdGuardHome/internal/agherr" - "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/miekg/dns" -) - -// This package is not the best place for this functionality, but we put it here -// since we need to use it in both rDNS (home) and dnsServer (dnsforward). - -// NoUpstreamsErr should be returned when there are no upstreams inside -// Exchanger implementation. -const NoUpstreamsErr agherr.Error = "no upstreams specified" - -// Exchanger represents an object able to resolve DNS messages. -// -// TODO(e.burkov): Maybe expand with method like ExchangeParallel to be able to -// use user's upstream mode settings. Also, think about Update method to -// refresh the internal state. -type Exchanger interface { - Exchange(req *dns.Msg) (resp *dns.Msg, err error) -} - -// multiAddrExchanger is the default implementation of Exchanger interface. -type multiAddrExchanger struct { - ups []upstream.Upstream -} - -// NewMultiAddrExchanger creates an Exchanger instance from passed addresses. -// It returns an error if any of addrs failed to become an upstream. -func NewMultiAddrExchanger( - addrs []string, - bootstraps []string, - timeout time.Duration, -) (e Exchanger, err error) { - defer agherr.Annotate("exchanger: %w", &err) - - if len(addrs) == 0 { - return &multiAddrExchanger{}, nil - } - - var ups []upstream.Upstream = make([]upstream.Upstream, 0, len(addrs)) - for _, addr := range addrs { - var u upstream.Upstream - u, err = upstream.AddressToUpstream(addr, upstream.Options{ - Bootstrap: bootstraps, - Timeout: timeout, - }) - if err != nil { - return nil, err - } - - ups = append(ups, u) - } - - return &multiAddrExchanger{ups: ups}, nil -} - -// Exсhange performs a query to each resolver until first response. -func (e *multiAddrExchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { - defer agherr.Annotate("exchanger: %w", &err) - - // TODO(e.burkov): Maybe prohibit the initialization without upstreams. - if len(e.ups) == 0 { - return nil, NoUpstreamsErr - } - - var errs []error - for _, u := range e.ups { - resp, err = u.Exchange(req) - if err != nil { - errs = append(errs, err) - - continue - } - - if resp != nil { - return resp, nil - } - } - - return nil, agherr.Many("can't exchange", errs...) -} diff --git a/internal/aghnet/exchanger_test.go b/internal/aghnet/exchanger_test.go deleted file mode 100644 index ace4b76b..00000000 --- a/internal/aghnet/exchanger_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package aghnet - -import ( - "testing" - - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestNewMultiAddrExchanger(t *testing.T) { - var e Exchanger - var err error - - t.Run("empty", func(t *testing.T) { - e, err = NewMultiAddrExchanger([]string{}, nil, 0) - require.NoError(t, err) - assert.NotNil(t, e) - }) - - t.Run("successful", func(t *testing.T) { - e, err = NewMultiAddrExchanger([]string{"www.example.com"}, nil, 0) - require.NoError(t, err) - assert.NotNil(t, e) - }) - - t.Run("unsuccessful", func(t *testing.T) { - e, err = NewMultiAddrExchanger([]string{"invalid-proto://www.example.com"}, nil, 0) - require.Error(t, err) - assert.Nil(t, e) - }) -} - -func TestMultiAddrExchanger_Exchange(t *testing.T) { - e := &multiAddrExchanger{} - - t.Run("error", func(t *testing.T) { - e.ups = []upstream.Upstream{&aghtest.TestErrUpstream{}} - - resp, err := e.Exchange(nil) - require.Error(t, err) - assert.Nil(t, resp) - }) - - t.Run("success", func(t *testing.T) { - e.ups = []upstream.Upstream{&aghtest.TestUpstream{ - Reverse: map[string][]string{ - "abc": {"cba"}, - }, - }} - - resp, err := e.Exchange(&dns.Msg{ - Question: []dns.Question{{ - Name: "abc", - Qtype: dns.TypePTR, - }}, - }) - require.NoError(t, err) - require.Len(t, resp.Answer, 1) - assert.Equal(t, "cba", resp.Answer[0].Header().Name) - }) -} diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 0e181e10..47aad064 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -1,7 +1,6 @@ package dnsforward import ( - "errors" "net" "strings" "time" @@ -403,21 +402,18 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - req := d.Req - resp, err := s.localResolvers.Exchange(req) + err := s.localResolvers.Resolve(d) if err != nil { - if errors.Is(err, aghnet.NoUpstreamsErr) { - d.Res = s.genNXDomain(req) - - return resultCodeFinish - } - ctx.err = err return resultCodeError } - d.Res = resp + if d.Res == nil { + d.Res = s.genNXDomain(d.Req) + + return resultCodeFinish + } return resultCodeSuccess } diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index 45362956..1b060cbb 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -259,17 +259,16 @@ func TestServer_ProcessInternalHosts(t *testing.T) { } func TestLocalRestriction(t *testing.T) { - s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ - UDPListenAddrs: []*net.UDPAddr{{}}, - TCPListenAddrs: []*net.TCPAddr{{}}, - }) ups := &aghtest.TestUpstream{ Reverse: map[string][]string{ "251.252.253.254.in-addr.arpa.": {"host1.example.net."}, "1.1.168.192.in-addr.arpa.": {"some.local-client."}, }, } - s.localResolvers = &aghtest.Exchanger{Ups: ups} + s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, + }, ups) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups} startDeferStop(t, s) diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index ab65a935..450897aa 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -20,6 +20,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) @@ -66,7 +67,7 @@ type Server struct { ipset ipsetCtx subnetDetector *aghnet.SubnetDetector - localResolvers aghnet.Exchanger + localResolvers *proxy.Proxy tableHostToIP map[string]net.IP // "hostname -> IP" table for internal addresses (DHCP) tableHostToIPLock sync.Mutex @@ -243,24 +244,24 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) { Qclass: dns.ClassINET, }}, } + ctx := &proxy.DNSContext{ + Proto: "udp", + Req: req, + StartTime: time.Now(), + } var resp *dns.Msg if s.subnetDetector.IsLocallyServedNetwork(ip) { - resp, err = s.localResolvers.Exchange(req) + err = s.localResolvers.Resolve(ctx) } else { - ctx := &proxy.DNSContext{ - Proto: "udp", - Req: req, - StartTime: time.Now(), - } err = s.internalProxy.Resolve(ctx) - - resp = ctx.Res } if err != nil { return "", err } + resp = ctx.Res + if len(resp.Answer) == 0 { return "", fmt.Errorf("lookup for %q: %w", arpa, rDNSEmptyAnswerErr) } @@ -376,18 +377,26 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) { return err } - // TODO(e.burkov): The approach of subtracting sets of strings - // is not really applicable here since in case of listening on - // all network interfaces we should check the whole interface's - // network to cut off all the loopback addresses as well. + // TODO(e.burkov): The approach of subtracting sets of strings is not + // really applicable here since in case of listening on all network + // interfaces we should check the whole interface's network to cut off + // all the loopback addresses as well. localAddrs = stringSetSubtract(localAddrs, ourAddrs) - if s.localResolvers, err = aghnet.NewMultiAddrExchanger( - localAddrs, - bootstraps, - defaultLocalTimeout, - ); err != nil { - return err + var upsConfig proxy.UpstreamConfig + upsConfig, err = proxy.ParseUpstreamsConfig(localAddrs, upstream.Options{ + Bootstrap: bootstraps, + Timeout: defaultLocalTimeout, + // TODO(e.burkov): Should we verify server's ceritificates? + }) + if err != nil { + return fmt.Errorf("parsing upstreams: %w", err) + } + + s.localResolvers = &proxy.Proxy{ + Config: proxy.Config{ + UpstreamConfig: &upsConfig, + }, } return nil diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 6d16ac12..510fb87d 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -52,7 +52,12 @@ func startDeferStop(t *testing.T, s *Server) { }) } -func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf ServerConfig) *Server { +func createTestServer( + t *testing.T, + filterConf *dnsfilter.Config, + forwardConf ServerConfig, + localUps upstream.Upstream, +) (s *Server) { t.Helper() rules := `||nxdomain.example.org @@ -70,7 +75,6 @@ func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf Se require.NoError(t, err) require.NotNil(t, snd) - var s *Server s, err = NewServer(DNSCreateParams{ DNSFilter: f, SubnetDetector: snd, @@ -85,7 +89,9 @@ func createTestServer(t *testing.T, filterConf *dnsfilter.Config, forwardConf Se s.Lock() defer s.Unlock() - s.localResolvers = &aghtest.Exchanger{} + if localUps != nil { + s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps} + } return s } @@ -143,7 +149,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) s = createTestServer(t, &dnsfilter.Config{}, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, - }) + }, nil) tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem s.conf.TLSConfig = tlsConf @@ -239,7 +245,7 @@ func TestServer(t *testing.T) { s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, - }) + }, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.TestUpstream{ IPv4: map[string][]net.IP{ @@ -277,7 +283,7 @@ func TestServerWithProtectionDisabled(t *testing.T) { s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, - }) + }, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.TestUpstream{ IPv4: map[string][]net.IP{ @@ -374,7 +380,7 @@ func TestServerRace(t *testing.T) { }, ConfigModified: func() {}, } - s := createTestServer(t, filterConf, forwardConf) + s := createTestServer(t, filterConf, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.TestUpstream{ IPv4: map[string][]net.IP{ @@ -407,7 +413,7 @@ func TestSafeSearch(t *testing.T) { ProtectionEnabled: true, }, } - s := createTestServer(t, filterConf, forwardConf) + s := createTestServer(t, filterConf, forwardConf, nil) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() @@ -460,7 +466,7 @@ func TestInvalidRequest(t *testing.T) { s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, - }) + }, nil) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() @@ -488,7 +494,7 @@ func TestBlockedRequest(t *testing.T) { ProtectionEnabled: true, }, } - s := createTestServer(t, &dnsfilter.Config{}, forwardConf) + s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -513,7 +519,7 @@ func TestServerCustomClientUpstream(t *testing.T) { ProtectionEnabled: true, }, } - s := createTestServer(t, &dnsfilter.Config{}, forwardConf) + s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil) s.conf.GetCustomUpstreamByClient = func(_ string) *proxy.UpstreamConfig { return &proxy.UpstreamConfig{ Upstreams: []upstream.Upstream{ @@ -558,7 +564,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { s := createTestServer(t, &dnsfilter.Config{}, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, - }) + }, nil) testUpstm := &aghtest.TestUpstream{ CName: testCNAMEs, IPv4: testIPv4, @@ -590,7 +596,7 @@ func TestBlockCNAME(t *testing.T) { ProtectionEnabled: true, }, } - s := createTestServer(t, &dnsfilter.Config{}, forwardConf) + s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.TestUpstream{ CName: testCNAMEs, @@ -652,7 +658,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { }, }, } - s := createTestServer(t, &dnsfilter.Config{}, forwardConf) + s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ &aghtest.TestUpstream{ CName: testCNAMEs, @@ -693,7 +699,7 @@ func TestNullBlockedRequest(t *testing.T) { BlockingMode: "null_ip", }, } - s := createTestServer(t, &dnsfilter.Config{}, forwardConf) + s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -792,7 +798,7 @@ func TestBlockedByHosts(t *testing.T) { ProtectionEnabled: true, }, } - s := createTestServer(t, &dnsfilter.Config{}, forwardConf) + s := createTestServer(t, &dnsfilter.Config{}, forwardConf, nil) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -827,7 +833,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { ProtectionEnabled: true, }, } - s := createTestServer(t, filterConf, forwardConf) + s := createTestServer(t, filterConf, forwardConf, nil) s.dnsFilter.SetSafeBrowsingUpstream(sbUps) startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -1240,8 +1246,13 @@ func TestServer_Exchange(t *testing.T) { }} for _, tc := range testCases { - dns.localResolvers = &aghtest.Exchanger{ - Ups: tc.locUpstream, + pcfg := proxy.Config{ + UpstreamConfig: &proxy.UpstreamConfig{ + Upstreams: []upstream.Upstream{tc.locUpstream}, + }, + } + dns.localResolvers = &proxy.Proxy{ + Config: pcfg, } t.Run(tc.name, func(t *testing.T) { diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index 273d9235..1724ae97 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -50,7 +50,7 @@ func TestDNSForwardHTTTP_handleGetConfig(t *testing.T) { }, ConfigModified: func() {}, } - s := createTestServer(t, filterConf, forwardConf) + s := createTestServer(t, filterConf, forwardConf, nil) require.Nil(t, s.Start()) t.Cleanup(func() { require.Nil(t, s.Stop()) @@ -123,7 +123,7 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) { }, ConfigModified: func() {}, } - s := createTestServer(t, filterConf, forwardConf) + s := createTestServer(t, filterConf, forwardConf, nil) defaultConf := s.conf