From 7313c3bc5387cdc77f9f688e846a4a4aca5b1ec8 Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Wed, 6 Nov 2019 17:24:15 +0300 Subject: [PATCH] + use per-client DNS servers --- AGHTechDoc.md | 3 +++ dnsforward/dnsforward.go | 16 ++++++++++++++++ dnsforward/dnsforward_http.go | 6 +++--- dnsforward/dnsforward_test.go | 6 +++--- home/clients.go | 17 +++++++++++++++++ home/clients_http.go | 6 ++++++ home/dns.go | 10 ++++++++++ 7 files changed, 58 insertions(+), 6 deletions(-) diff --git a/AGHTechDoc.md b/AGHTechDoc.md index d843195d..58a39dc6 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -669,6 +669,7 @@ Response: key: "value" ... } + upstreams: ["upstream1", ...] } ] auto_clients: [ @@ -703,6 +704,7 @@ Request: safesearch_enabled: false use_global_blocked_services: true blocked_services: [ "name1", ... ] + upstreams: ["upstream1", ...] } Response: @@ -732,6 +734,7 @@ Request: safesearch_enabled: false use_global_blocked_services: true blocked_services: [ "name1", ... ] + upstreams: ["upstream1", ...] } } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index a9fc04c9..a3dabc01 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -94,6 +94,9 @@ type FilteringConfig struct { // Filtering callback function FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` + // This callback function returns the list of upstream servers for a client specified by IP address + GetUpstreamsByClient func(clientAddr string) []string `yaml:"-"` + ProtectionEnabled bool `yaml:"protection_enabled"` // whether or not use any of dnsfilter features BlockingMode string `yaml:"blocking_mode"` // mode how to answer filtered requests @@ -393,6 +396,19 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { d.Req.Question[0].Name = dns.Fqdn(res.CanonName) } + if d.Addr != nil && s.conf.GetUpstreamsByClient != nil { + clientIP, _, _ := net.SplitHostPort(d.Addr.String()) + upstreams := s.conf.GetUpstreamsByClient(clientIP) + for _, us := range upstreams { + u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: 30 * time.Second}) + if err != nil { + log.Error("upstream.AddressToUpstream: %s: %s", us, err) + continue + } + d.Upstreams = append(d.Upstreams, u) + } + } + // request was not filtered so let it be processed further err = p.Resolve(d) if err != nil { diff --git a/dnsforward/dnsforward_http.go b/dnsforward/dnsforward_http.go index b8b51d76..3941aba1 100644 --- a/dnsforward/dnsforward_http.go +++ b/dnsforward/dnsforward_http.go @@ -44,7 +44,7 @@ func (s *Server) handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) return } - err = validateUpstreams(req.Upstreams) + err = ValidateUpstreams(req.Upstreams) if err != nil { httpError(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err) return @@ -78,8 +78,8 @@ func (s *Server) handleSetUpstreamConfig(w http.ResponseWriter, r *http.Request) } } -// validateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified -func validateUpstreams(upstreams []string) error { +// ValidateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified +func ValidateUpstreams(upstreams []string) error { var defaultUpstreamFound bool for _, u := range upstreams { d, err := validateUpstream(u) diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 7fcb5fb4..43970a7e 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -762,21 +762,21 @@ func TestValidateUpstreamsSet(t *testing.T) { "[/host.com/google.com/]8.8.8.8", "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20", } - err := validateUpstreams(upstreamsSet) + err := ValidateUpstreams(upstreamsSet) if err == nil { t.Fatalf("there is no default upstream") } // Let's add default upstream upstreamsSet = append(upstreamsSet, "8.8.8.8") - err = validateUpstreams(upstreamsSet) + err = ValidateUpstreams(upstreamsSet) if err != nil { t.Fatalf("upstreams set is valid, but doesn't pass through validation cause: %s", err) } // Let's add invalid upstream upstreamsSet = append(upstreamsSet, "dhcp://fake.dns") - err = validateUpstreams(upstreamsSet) + err = ValidateUpstreams(upstreamsSet) if err == nil { t.Fatalf("there is an invalid upstream in set, but it pass through validation") } diff --git a/home/clients.go b/home/clients.go index 0cc5c65a..b8e43d6f 100644 --- a/home/clients.go +++ b/home/clients.go @@ -13,6 +13,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/dhcpd" + "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/utils" ) @@ -34,6 +35,8 @@ type Client struct { UseOwnBlockedServices bool // false: use global settings BlockedServices []string + + Upstreams []string // list of upstream servers to be used for the client's requests } type clientSource uint @@ -96,6 +99,8 @@ type clientObject struct { UseGlobalBlockedServices bool `yaml:"use_global_blocked_services"` BlockedServices []string `yaml:"blocked_services"` + + Upstreams []string `yaml:"upstreams"` } func (clients *clientsContainer) addFromConfig(objects []clientObject) { @@ -111,6 +116,8 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) { UseOwnBlockedServices: !cy.UseGlobalBlockedServices, BlockedServices: cy.BlockedServices, + + Upstreams: cy.Upstreams, } _, err := clients.Add(cli) if err != nil { @@ -134,6 +141,8 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) { UseGlobalBlockedServices: !cli.UseOwnBlockedServices, BlockedServices: cli.BlockedServices, + + Upstreams: cli.Upstreams, } *objects = append(*objects, cy) } @@ -268,6 +277,14 @@ func (c *Client) check() error { return fmt.Errorf("Invalid ID: %s", id) } + + if len(c.Upstreams) != 0 { + err := dnsforward.ValidateUpstreams(c.Upstreams) + if err != nil { + return fmt.Errorf("Invalid upstream servers: %s", err) + } + } + return nil } diff --git a/home/clients_http.go b/home/clients_http.go index 5a6bd332..e1cea2cf 100644 --- a/home/clients_http.go +++ b/home/clients_http.go @@ -20,6 +20,8 @@ type clientJSON struct { UseGlobalBlockedServices bool `json:"use_global_blocked_services"` BlockedServices []string `json:"blocked_services"` + + Upstreams []string `json:"upstreams"` } type clientHostJSON struct { @@ -92,6 +94,8 @@ func jsonToClient(cj clientJSON) (*Client, error) { UseOwnBlockedServices: !cj.UseGlobalBlockedServices, BlockedServices: cj.BlockedServices, + + Upstreams: cj.Upstreams, } return &c, nil } @@ -109,6 +113,8 @@ func clientToJSON(c *Client) clientJSON { UseGlobalBlockedServices: !c.UseOwnBlockedServices, BlockedServices: c.BlockedServices, + + Upstreams: c.Upstreams, } cj.WhoisInfo = make(map[string]interface{}) diff --git a/home/dns.go b/home/dns.go index 0ec5a042..dbc48a73 100644 --- a/home/dns.go +++ b/home/dns.go @@ -170,9 +170,19 @@ func generateServerConfig() (dnsforward.ServerConfig, error) { } newconfig.FilterHandler = applyAdditionalFiltering + newconfig.GetUpstreamsByClient = getUpstreamsByClient return newconfig, nil } +func getUpstreamsByClient(clientAddr string) []string { + c, ok := config.clients.Find(clientAddr) + if !ok { + return []string{} + } + log.Debug("Using upstreams %v for client %s (IP: %s)", c.Upstreams, c.Name, clientAddr) + return c.Upstreams +} + // If a client has his own settings, apply them func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) {