From 67a39045fccfed978120e17a9527f8e9a9a7a668 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Wed, 13 May 2020 20:31:43 +0300 Subject: [PATCH] -(dnsforward): custom client per-domain upstreams Closes: https://github.com/AdguardTeam/AdGuardHome/issues/1539 --- .golangci.yml | 4 +++ dnsforward/config.go | 57 +++++++++++++++++------------------ dnsforward/dnsforward_test.go | 40 ++++++++++++------------ dnsforward/filter.go | 2 +- dnsforward/handle_dns.go | 10 +++--- go.mod | 2 +- go.sum | 4 +-- home/clients.go | 40 ++++++++++-------------- home/clients_test.go | 28 +++++++++++++++++ home/dns.go | 7 +---- 10 files changed, 106 insertions(+), 88 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 60b7ee1c..47425e3f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -70,3 +70,7 @@ issues: - G108 # gosec: Subprocess launched with function call as argument or cmd arguments - G204 + # gosec: Potential DoS vulnerability via decompression bomb + - G110 + # gosec: Expect WriteFile permissions to be 0600 or less + - G306 diff --git a/dnsforward/config.go b/dnsforward/config.go index b7523ab3..ed5ec3e9 100644 --- a/dnsforward/config.go +++ b/dnsforward/config.go @@ -26,8 +26,9 @@ type FilteringConfig struct { // Filtering callback function FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` - // This callback function returns the list of upstream servers for a client specified by IP address - GetUpstreamsByClient func(clientAddr string) []upstream.Upstream `yaml:"-"` + // GetCustomUpstreamByClient - a callback function that returns upstreams configuration + // based on the client IP address. Returns nil if there are no custom upstreams for the client + GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"` // Protection configuration // -- @@ -102,11 +103,10 @@ type TLSConfig struct { // ServerConfig represents server configuration. // The zero ServerConfig is empty and ready for use. type ServerConfig struct { - UDPListenAddr *net.UDPAddr // UDP listen address - TCPListenAddr *net.TCPAddr // TCP listen address - Upstreams []upstream.Upstream // Configured upstreams - DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams - OnDNSRequest func(d *proxy.DNSContext) + UDPListenAddr *net.UDPAddr // UDP listen address + TCPListenAddr *net.TCPAddr // TCP listen address + UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config + OnDNSRequest func(d *proxy.DNSContext) FilteringConfig TLSConfig @@ -132,22 +132,21 @@ var defaultValues = ServerConfig{ // createProxyConfig creates and validates configuration for the main proxy func (s *Server) createProxyConfig() (proxy.Config, error) { proxyConfig := proxy.Config{ - UDPListenAddr: s.conf.UDPListenAddr, - TCPListenAddr: s.conf.TCPListenAddr, - Ratelimit: int(s.conf.Ratelimit), - RatelimitWhitelist: s.conf.RatelimitWhitelist, - RefuseAny: s.conf.RefuseAny, - CacheEnabled: true, - CacheSizeBytes: int(s.conf.CacheSize), - CacheMinTTL: s.conf.CacheMinTTL, - CacheMaxTTL: s.conf.CacheMaxTTL, - Upstreams: s.conf.Upstreams, - DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams, - BeforeRequestHandler: s.beforeRequestHandler, - RequestHandler: s.handleDNSRequest, - AllServers: s.conf.AllServers, - EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet, - FindFastestAddr: s.conf.FastestAddr, + UDPListenAddr: s.conf.UDPListenAddr, + TCPListenAddr: s.conf.TCPListenAddr, + Ratelimit: int(s.conf.Ratelimit), + RatelimitWhitelist: s.conf.RatelimitWhitelist, + RefuseAny: s.conf.RefuseAny, + CacheEnabled: true, + CacheSizeBytes: int(s.conf.CacheSize), + CacheMinTTL: s.conf.CacheMinTTL, + CacheMaxTTL: s.conf.CacheMaxTTL, + UpstreamConfig: s.conf.UpstreamConfig, + BeforeRequestHandler: s.beforeRequestHandler, + RequestHandler: s.handleDNSRequest, + AllServers: s.conf.AllServers, + EnableEDNSClientSubnet: s.conf.EnableEDNSClientSubnet, + FindFastestAddr: s.conf.FastestAddr, } if len(s.conf.BogusNXDomain) > 0 { @@ -168,7 +167,7 @@ func (s *Server) createProxyConfig() (proxy.Config, error) { } // Validate proxy config - if len(proxyConfig.Upstreams) == 0 { + if proxyConfig.UpstreamConfig == nil || len(proxyConfig.UpstreamConfig.Upstreams) == 0 { return proxyConfig, errors.New("no upstream servers configured") } @@ -204,18 +203,16 @@ func (s *Server) prepareUpstreamSettings() error { if err != nil { return fmt.Errorf("DNS: proxy.ParseUpstreamsConfig: %s", err) } - s.conf.Upstreams = upstreamConfig.Upstreams - s.conf.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams + s.conf.UpstreamConfig = &upstreamConfig return nil } // prepareIntlProxy - initializes DNS proxy that we use for internal DNS queries func (s *Server) prepareIntlProxy() { intlProxyConfig := proxy.Config{ - CacheEnabled: true, - CacheSizeBytes: 4096, - Upstreams: s.conf.Upstreams, - DomainsReservedUpstreams: s.conf.DomainsReservedUpstreams, + CacheEnabled: true, + CacheSizeBytes: 4096, + UpstreamConfig: s.conf.UpstreamConfig, } s.internalProxy = &proxy.Proxy{Config: intlProxyConfig} } diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 60b606bf..430d2f35 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -325,7 +325,9 @@ func (s *Server) startWithUpstream(u upstream.Upstream) error { if err != nil { return err } - s.dnsProxy.Upstreams = []upstream.Upstream{u} + s.dnsProxy.UpstreamConfig = &proxy.UpstreamConfig{ + Upstreams: []upstream.Upstream{u}, + } return s.dnsProxy.Start() } @@ -353,8 +355,8 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { // but protection is disabled - response is NOT blocked req := createTestMessage("badhost.") reply, err := dns.Exchange(req, addr.String()) - assert.True(t, err == nil) - assert.True(t, reply.Rcode == dns.RcodeSuccess) + assert.Nil(t, err) + assert.Equal(t, dns.RcodeSuccess, reply.Rcode) } func TestBlockCNAME(t *testing.T) { @@ -368,23 +370,23 @@ func TestBlockCNAME(t *testing.T) { // response is blocked req := createTestMessage("badhost.") reply, err := dns.Exchange(req, addr.String()) - assert.True(t, err == nil) - assert.True(t, reply.Rcode == dns.RcodeNameError) + assert.Nil(t, err, nil) + assert.Equal(t, dns.RcodeNameError, reply.Rcode) // 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters // but 'whitelist.example.org' is in a whitelist: // response isn't blocked req = createTestMessage("whitelist.example.org.") reply, err = dns.Exchange(req, addr.String()) - assert.True(t, err == nil) - assert.True(t, reply.Rcode == dns.RcodeSuccess) + assert.Nil(t, err) + assert.Equal(t, dns.RcodeSuccess, reply.Rcode) // 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters: // response is blocked req = createTestMessage("example.org.") reply, err = dns.Exchange(req, addr.String()) - assert.True(t, err == nil) - assert.True(t, reply.Rcode == dns.RcodeNameError) + assert.Nil(t, err) + assert.Equal(t, dns.RcodeNameError, reply.Rcode) _ = s.Stop() } @@ -455,7 +457,7 @@ func TestNullBlockedRequest(t *testing.T) { func TestBlockedCustomIP(t *testing.T) { rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n" - filters := []dnsfilter.Filter{dnsfilter.Filter{ + filters := []dnsfilter.Filter{{ ID: 0, Data: []byte(rules), }} c := dnsfilter.Config{} @@ -475,27 +477,27 @@ func TestBlockedCustomIP(t *testing.T) { conf.BlockingIPv4 = "0.0.0.1" conf.BlockingIPv6 = "::1" err = s.Prepare(&conf) - assert.True(t, err == nil) + assert.Nil(t, err) err = s.Start() - assert.True(t, err == nil, "%s", err) + assert.Nil(t, err) addr := s.dnsProxy.Addr(proxy.ProtoUDP) req := createTestMessageWithType("null.example.org.", dns.TypeA) reply, err := dns.Exchange(req, addr.String()) - assert.True(t, err == nil) - assert.True(t, len(reply.Answer) == 1) + assert.Nil(t, err) + assert.Equal(t, 1, len(reply.Answer)) a, ok := reply.Answer[0].(*dns.A) assert.True(t, ok) - assert.True(t, a.A.String() == "0.0.0.1") + assert.Equal(t, "0.0.0.1", a.A.String()) req = createTestMessageWithType("null.example.org.", dns.TypeAAAA) reply, err = dns.Exchange(req, addr.String()) - assert.True(t, err == nil) - assert.True(t, len(reply.Answer) == 1) + assert.Nil(t, err) + assert.Equal(t, 1, len(reply.Answer)) a6, ok := reply.Answer[0].(*dns.AAAA) assert.True(t, ok) - assert.True(t, a6.AAAA.String() == "::1") + assert.Equal(t, "::1", a6.AAAA.String()) err = s.Stop() if err != nil { @@ -598,7 +600,7 @@ func createTestServer(t *testing.T) *Server { 127.0.0.1 host.example.org @@||whitelist.example.org^ ||127.0.0.255` - filters := []dnsfilter.Filter{dnsfilter.Filter{ + filters := []dnsfilter.Filter{{ ID: 0, Data: []byte(rules), }} c := dnsfilter.Config{} diff --git a/dnsforward/filter.go b/dnsforward/filter.go index f6b9da63..08c81c40 100644 --- a/dnsforward/filter.go +++ b/dnsforward/filter.go @@ -10,7 +10,7 @@ import ( "github.com/miekg/dns" ) -func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool, error) { +func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) { ip := ipFromAddr(d.Addr) if s.access.IsBlockedIP(ip) { log.Tracef("Client IP %s is blocked by settings", ip) diff --git a/dnsforward/handle_dns.go b/dnsforward/handle_dns.go index 5bd663cf..945413d1 100644 --- a/dnsforward/handle_dns.go +++ b/dnsforward/handle_dns.go @@ -31,7 +31,7 @@ const ( ) // handleDNSRequest filters the incoming DNS requests and writes them to the query log -func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { +func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { ctx := &dnsContext{srv: s, proxyCtx: d} ctx.result = &dnsfilter.Result{} ctx.startTime = time.Now() @@ -124,12 +124,12 @@ func processUpstream(ctx *dnsContext) int { return resultDone // response is already set - nothing to do } - if d.Addr != nil && s.conf.GetUpstreamsByClient != nil { + if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil { clientIP := ipFromAddr(d.Addr) - upstreams := s.conf.GetUpstreamsByClient(clientIP) - if len(upstreams) > 0 { + upstreamsConf := s.conf.GetCustomUpstreamByClient(clientIP) + if upstreamsConf != nil { log.Debug("Using custom upstreams for %s", clientIP) - d.Upstreams = upstreams + d.CustomUpstreamConfig = upstreamsConf } } diff --git a/go.mod b/go.mod index b33129c5..24e8a312 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome go 1.14 require ( - github.com/AdguardTeam/dnsproxy v0.28.0 + github.com/AdguardTeam/dnsproxy v0.28.1 github.com/AdguardTeam/golibs v0.4.2 github.com/AdguardTeam/urlfilter v0.10.0 github.com/NYTimes/gziphandler v1.1.1 diff --git a/go.sum b/go.sum index 14b7faba..0ac64199 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/dnsproxy v0.28.0 h1:w6ITGjSMLztUOTVNVVcE0JU1bV2U0bOPyDHGwyZgTc4= -github.com/AdguardTeam/dnsproxy v0.28.0/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58= +github.com/AdguardTeam/dnsproxy v0.28.1 h1:WkLjrUcVf/njbTLyL7bNt6e18zQjF2ZYv/HWwL9cMmU= +github.com/AdguardTeam/dnsproxy v0.28.1/go.mod h1:hOYFV9TW+pd5XKYz7KZf2FFD8SvSPqjyGTxUae86s58= github.com/AdguardTeam/golibs v0.4.0 h1:4VX6LoOqFe9p9Gf55BeD8BvJD6M6RDYmgEiHrENE9KU= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o= diff --git a/home/clients.go b/home/clients.go index 90750776..7457696b 100644 --- a/home/clients.go +++ b/home/clients.go @@ -11,11 +11,12 @@ import ( "sync" "time" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/AdGuardHome/util" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/utils" ) @@ -41,11 +42,12 @@ type Client struct { BlockedServices []string Upstreams []string // list of upstream servers to be used for the client's requests - // Upstream objects: + + // Custom upstream config for this client // nil: not yet initialized // not nil, but empty: initialized, no good upstreams // not nil, not empty: Upstreams ready to be used - upstreamObjects []upstream.Upstream + upstreamConfig *proxy.UpstreamConfig } type clientSource uint @@ -273,16 +275,10 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) { return c, true } -func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream { - a2 := make([]upstream.Upstream, len(a)) - copy(a2, a) - return a2 -} - // FindUpstreams looks for upstreams configured for the client // If no client found for this IP, or if no custom upstreams are configured, // this method returns nil -func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream { +func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig { clients.lock.Lock() defer clients.lock.Unlock() @@ -291,22 +287,18 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream { return nil } - if c.upstreamObjects == nil { - c.upstreamObjects = make([]upstream.Upstream, 0) - for _, us := range c.Upstreams { - u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout}) - if err != nil { - log.Error("upstream.AddressToUpstream: %s: %s", us, err) - continue - } - c.upstreamObjects = append(c.upstreamObjects, u) + if len(c.Upstreams) == 0 { + return nil + } + + if c.upstreamConfig == nil { + config, err := proxy.ParseUpstreamsConfig(c.Upstreams, config.DNS.BootstrapDNS, dnsforward.DefaultTimeout) + if err == nil { + c.upstreamConfig = &config } } - if len(c.upstreamObjects) == 0 { - return nil - } - return upstreamArrayCopy(c.upstreamObjects) + return c.upstreamConfig } // Find searches for a client by IP (and does not lock anything) @@ -537,7 +529,7 @@ func (clients *clientsContainer) Update(name string, c Client) error { } // update upstreams cache - c.upstreamObjects = nil + c.upstreamConfig = nil *old = c return nil diff --git a/home/clients_test.go b/home/clients_test.go index 50b96121..9f6131ca 100644 --- a/home/clients_test.go +++ b/home/clients_test.go @@ -236,3 +236,31 @@ func TestClientsAddExisting(t *testing.T) { assert.True(t, ok) assert.Nil(t, err) } + +func TestClientsCustomUpstream(t *testing.T) { + clients := clientsContainer{} + clients.testing = true + + clients.Init(nil, nil, nil) + + // add client with upstreams + client := Client{ + IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"}, + Name: "client1", + Upstreams: []string{ + "1.1.1.1", + "[/example.org/]8.8.8.8", + }, + } + ok, err := clients.Add(client) + assert.Nil(t, err) + assert.True(t, ok) + + config := clients.FindUpstreams("1.2.3.4") + assert.Nil(t, config) + + config = clients.FindUpstreams("1.1.1.1") + assert.NotNil(t, config) + assert.Equal(t, 1, len(config.Upstreams)) + assert.Equal(t, 1, len(config.DomainReservedUpstreams)) +} diff --git a/home/dns.go b/home/dns.go index b03d0de3..36ee9ec1 100644 --- a/home/dns.go +++ b/home/dns.go @@ -11,7 +11,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/stats" "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" ) @@ -176,7 +175,7 @@ func generateServerConfig() dnsforward.ServerConfig { newconfig.TLSAllowUnencryptedDOH = tlsConf.AllowUnencryptedDOH newconfig.FilterHandler = applyAdditionalFiltering - newconfig.GetUpstreamsByClient = getUpstreamsByClient + newconfig.GetCustomUpstreamByClient = Context.clients.FindUpstreams return newconfig } @@ -222,10 +221,6 @@ func getDNSAddresses() []string { return dnsAddresses } -func getUpstreamsByClient(clientAddr string) []upstream.Upstream { - return Context.clients.FindUpstreams(clientAddr) -} - // If a client has his own settings, apply them func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { Context.dnsFilter.ApplyBlockedServices(setts, nil, true)