From 484c0ceaff2837db49f0ab679954511a9c916650 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Thu, 1 Nov 2018 14:45:32 +0300 Subject: [PATCH 01/10] Upstream plugin prototype --- control.go | 42 +++++++------- upstream/dns_upstream.go | 36 ++++++++++++ upstream/https_upstream.go | 109 +++++++++++++++++++++++++++++++++++++ upstream/tls_upstream.go | 47 ++++++++++++++++ upstream/upstream.go | 43 +++++++++++++++ upstream/upstream_test.go | 86 +++++++++++++++++++++++++++++ 6 files changed, 342 insertions(+), 21 deletions(-) create mode 100644 upstream/dns_upstream.go create mode 100644 upstream/https_upstream.go create mode 100644 upstream/tls_upstream.go create mode 100644 upstream/upstream.go create mode 100644 upstream/upstream_test.go diff --git a/control.go b/control.go index 1e1084e8..378a2aec 100644 --- a/control.go +++ b/control.go @@ -134,9 +134,9 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - errortext := fmt.Sprintf("Failed to read request body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusBadRequest) + errorText := fmt.Sprintf("Failed to read request body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) return } // if empty body -- user is asking for default servers @@ -153,34 +153,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { err = writeAllConfigs() if err != nil { - errortext := fmt.Sprintf("Couldn't write config file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write config file: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) return } tellCoreDNSToReload() _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) } } func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { - errortext := fmt.Sprintf("Failed to read request body: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) + errorText := fmt.Sprintf("Failed to read request body: %s", err) + log.Println(errorText) + http.Error(w, errorText, 400) return } hosts := strings.Fields(string(body)) if len(hosts) == 0 { - errortext := fmt.Sprintf("No servers specified") - log.Println(errortext) - http.Error(w, errortext, http.StatusBadRequest) + errorText := fmt.Sprintf("No servers specified") + log.Println(errorText) + http.Error(w, errorText, http.StatusBadRequest) return } @@ -198,18 +198,18 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { jsonVal, err := json.Marshal(result) if err != nil { - errortext := fmt.Sprintf("Unable to marshal status json: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Unable to marshal status json: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") _, err = w.Write(jsonVal) if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) + errorText := fmt.Sprintf("Couldn't write body: %s", err) + log.Println(errorText) + http.Error(w, errorText, http.StatusInternalServerError) } } diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go new file mode 100644 index 00000000..779c059e --- /dev/null +++ b/upstream/dns_upstream.go @@ -0,0 +1,36 @@ +package upstream + +import ( + "github.com/miekg/dns" + "golang.org/x/net/context" + "time" +) + +// DnsUpstream is a very simple upstream implementation for plain DNS +type DnsUpstream struct { + nameServer string // IP:port + timeout time.Duration // Max read and write timeout +} + +// NewDnsUpstream creates a new plain-DNS upstream +func NewDnsUpstream(nameServer string) (Upstream, error) { + return &DnsUpstream{nameServer: nameServer, timeout: defaultTimeout}, nil +} + +// Exchange provides an implementation for the Upstream interface +func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + + dnsClient := &dns.Client{ + ReadTimeout: u.timeout, + WriteTimeout: u.timeout, + } + + resp, _, err := dnsClient.Exchange(query, u.nameServer) + + if err != nil { + resp = &dns.Msg{} + resp.SetRcode(resp, dns.RcodeServerFailure) + } + + return resp, err +} diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go new file mode 100644 index 00000000..61c7a397 --- /dev/null +++ b/upstream/https_upstream.go @@ -0,0 +1,109 @@ +package upstream + +import ( + "bytes" + "crypto/tls" + "fmt" + "github.com/miekg/dns" + "github.com/pkg/errors" + "golang.org/x/net/context" + "golang.org/x/net/http2" + "io/ioutil" + "log" + "net/http" + "net/url" +) + +const ( + dnsMessageContentType = "application/dns-message" +) + +// HttpsUpstream is the upstream implementation for DNS-over-HTTPS +type HttpsUpstream struct { + client *http.Client + endpoint *url.URL +} + +// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname +func NewHttpsUpstream(endpoint string) (Upstream, error) { + u, err := url.Parse(endpoint) + if err != nil { + return nil, err + } + + // Update TLS and HTTP client configuration + tlsConfig := &tls.Config{ServerName: u.Hostname()} + transport := &http.Transport{ + TLSClientConfig: tlsConfig, + DisableCompression: true, + MaxIdleConns: 1, + } + http2.ConfigureTransport(transport) + + client := &http.Client{ + Timeout: defaultTimeout, + Transport: transport, + } + + return &HttpsUpstream{client: client, endpoint: u}, nil +} + +// Exchange provides an implementation for the Upstream interface +func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + queryBuf, err := query.Pack() + if err != nil { + return nil, errors.Wrap(err, "failed to pack DNS query") + } + + // No content negotiation for now, use DNS wire format + buf, backendErr := u.exchangeWireformat(queryBuf) + if backendErr == nil { + response := &dns.Msg{} + if err := response.Unpack(buf); err != nil { + return nil, errors.Wrap(err, "failed to unpack DNS response from body") + } + + response.Id = query.Id + return response, nil + } + + log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr) + return nil, backendErr +} + +// Perform message exchange with the default UDP wireformat defined in current draft +// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10 +func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) { + req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg)) + if err != nil { + return nil, errors.Wrap(err, "failed to create an HTTPS request") + } + + req.Header.Add("Content-Type", dnsMessageContentType) + req.Header.Add("Accept", dnsMessageContentType) + req.Host = u.endpoint.Hostname() + + resp, err := u.client.Do(req) + if err != nil { + return nil, errors.Wrap(err, "failed to perform an HTTPS request") + } + + // Check response status code + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("returned status code %d", resp.StatusCode) + } + + contentType := resp.Header.Get("Content-Type") + if contentType != dnsMessageContentType { + return nil, fmt.Errorf("return wrong content type %s", contentType) + } + + // Read application/dns-message response from the body + buf, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "failed to read the response body") + } + + return buf, nil +} diff --git a/upstream/tls_upstream.go b/upstream/tls_upstream.go new file mode 100644 index 00000000..aed55829 --- /dev/null +++ b/upstream/tls_upstream.go @@ -0,0 +1,47 @@ +package upstream + +import ( + "crypto/tls" + "github.com/miekg/dns" + "golang.org/x/net/context" + "time" +) + +// TODO: Use persistent connection here + +// DnsOverTlsUpstream is the upstream implementation for plain DNS-over-TLS +type DnsOverTlsUpstream struct { + endpoint string + tlsServerName string + timeout time.Duration +} + +// NewHttpsUpstream creates a new DNS-over-TLS upstream from the endpoint address and TLS server name +func NewDnsOverTlsUpstream(endpoint string, tlsServerName string) (Upstream, error) { + return &DnsOverTlsUpstream{ + endpoint: endpoint, + tlsServerName: tlsServerName, + timeout: defaultTimeout, + }, nil +} + +// Exchange provides an implementation for the Upstream interface +func (u *DnsOverTlsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { + + dnsClient := &dns.Client{ + Net: "tcp-tls", + ReadTimeout: u.timeout, + WriteTimeout: u.timeout, + TLSConfig: new(tls.Config), + } + dnsClient.TLSConfig.ServerName = u.tlsServerName + + resp, _, err := dnsClient.Exchange(query, u.endpoint) + + if err != nil { + resp = &dns.Msg{} + resp.SetRcode(resp, dns.RcodeServerFailure) + } + + return resp, err +} diff --git a/upstream/upstream.go b/upstream/upstream.go new file mode 100644 index 00000000..6d2570c5 --- /dev/null +++ b/upstream/upstream.go @@ -0,0 +1,43 @@ +package upstream + +import ( + "github.com/coredns/coredns/plugin" + "github.com/miekg/dns" + "github.com/pkg/errors" + "golang.org/x/net/context" + "time" +) + +const ( + defaultTimeout = 5 * time.Second +) + +// Upstream is a simplified interface for proxy destination +type Upstream interface { + Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) +} + +// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface +type UpstreamPlugin struct { + Upstreams []Upstream + Next plugin.Handler +} + +// ServeDNS implements interface for CoreDNS plugin +func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + var reply *dns.Msg + var backendErr error + + for _, upstream := range p.Upstreams { + reply, backendErr = upstream.Exchange(ctx, r) + if backendErr == nil { + w.WriteMsg(reply) + return 0, nil + } + } + + return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams") +} + +// Name implements interface for CoreDNS plugin +func (p UpstreamPlugin) Name() string { return "upstream" } diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go new file mode 100644 index 00000000..ca0df859 --- /dev/null +++ b/upstream/upstream_test.go @@ -0,0 +1,86 @@ +package upstream + +import ( + "github.com/miekg/dns" + "log" + "net" + "testing" +) + +func TestDnsUpstream(t *testing.T) { + + u, err := NewDnsUpstream("8.8.8.8:53") + + if err != nil { + t.Errorf("cannot create a DNS upstream") + } + + testUpstream(t, u) +} + +func TestHttpsUpstream(t *testing.T) { + + testCases := []string{ + "https://cloudflare-dns.com/dns-query", + "https://dns.google.com/experimental", + "https://doh.cleanbrowsing.org/doh/security-filter/", + } + + for _, url := range testCases { + u, err := NewHttpsUpstream(url) + + if err != nil { + t.Errorf("cannot create a DNS-over-HTTPS upstream") + } + + testUpstream(t, u) + } +} + +func TestDnsOverTlsUpstream(t *testing.T) { + + var tests = []struct { + endpoint string + tlsServerName string + }{ + {"1.1.1.1:853", ""}, + {"8.8.8.8:853", ""}, + {"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"}, + } + + for _, test := range tests { + u, err := NewDnsOverTlsUpstream(test.endpoint, test.tlsServerName) + + if err != nil { + t.Errorf("cannot create a DNS-over-TLS upstream") + } + + testUpstream(t, u) + } +} + +func testUpstream(t *testing.T, u Upstream) { + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + resp, err := u.Exchange(nil, &req) + + if err != nil { + t.Errorf("error while making an upstream request: %s", err) + } + + if len(resp.Answer) != 1 { + t.Errorf("no answer section in the response") + } + if answer, ok := resp.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(answer.A) { + t.Errorf("wrong IP in the response: %v", answer.A) + } + } + + log.Printf("response: %v", resp) +} From d6f560ecafd18ea7a8b969abb801ccdf31e8ec73 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Mon, 5 Nov 2018 20:40:10 +0300 Subject: [PATCH 02/10] Added persistent connections cache --- upstream/dns_upstream.go | 78 ++++++++++++-- upstream/https_upstream.go | 7 ++ upstream/persistent.go | 208 +++++++++++++++++++++++++++++++++++++ upstream/tls_upstream.go | 47 --------- upstream/upstream.go | 31 +++++- upstream/upstream_test.go | 53 ++++++---- 6 files changed, 344 insertions(+), 80 deletions(-) create mode 100644 upstream/persistent.go delete mode 100644 upstream/tls_upstream.go diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go index 779c059e..a40aec5a 100644 --- a/upstream/dns_upstream.go +++ b/upstream/dns_upstream.go @@ -1,6 +1,7 @@ package upstream import ( + "crypto/tls" "github.com/miekg/dns" "golang.org/x/net/context" "time" @@ -8,24 +9,40 @@ import ( // DnsUpstream is a very simple upstream implementation for plain DNS type DnsUpstream struct { - nameServer string // IP:port - timeout time.Duration // Max read and write timeout + endpoint string // IP:port + timeout time.Duration // Max read and write timeout + proto string // Protocol (tcp, tcp-tls, or udp) + transport *Transport // Persistent connections cache } -// NewDnsUpstream creates a new plain-DNS upstream -func NewDnsUpstream(nameServer string) (Upstream, error) { - return &DnsUpstream{nameServer: nameServer, timeout: defaultTimeout}, nil +// NewDnsUpstream creates a new DNS upstream +func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) { + + u := &DnsUpstream{ + endpoint: endpoint, + timeout: defaultTimeout, + proto: proto, + } + + var tlsConfig *tls.Config + + if tlsServerName != "" { + tlsConfig = new(tls.Config) + tlsConfig.ServerName = tlsServerName + } + + // Initialize the connections cache + u.transport = NewTransport(endpoint) + u.transport.tlsConfig = tlsConfig + u.transport.Start() + + return u, nil } // Exchange provides an implementation for the Upstream interface func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - dnsClient := &dns.Client{ - ReadTimeout: u.timeout, - WriteTimeout: u.timeout, - } - - resp, _, err := dnsClient.Exchange(query, u.nameServer) + resp, err := u.exchange(query) if err != nil { resp = &dns.Msg{} @@ -34,3 +51,42 @@ func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, e return resp, err } + +// Clear resources +func (u *DnsUpstream) Close() error { + + // Close active connections + u.transport.Stop() + return nil +} + +// Performs a synchronous query. It sends the message m via the conn +// c and waits for a reply. The conn c is not closed. +func (u *DnsUpstream) exchange(query *dns.Msg) (r *dns.Msg, err error) { + + // Establish a connection if needed (or reuse cached) + conn, err := u.transport.Dial(u.proto) + if err != nil { + return nil, err + } + + // Write the request with a timeout + conn.SetWriteDeadline(time.Now().Add(u.timeout)) + if err = conn.WriteMsg(query); err != nil { + conn.Close() // Not giving it back + return nil, err + } + + // Write response with a timeout + conn.SetReadDeadline(time.Now().Add(u.timeout)) + r, err = conn.ReadMsg() + if err != nil { + conn.Close() // Not giving it back + } else if err == nil && r.Id != query.Id { + err = dns.ErrId + conn.Close() // Not giving it back + } + + u.transport.Yield(conn) + return r, err +} diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go index 61c7a397..7daab106 100644 --- a/upstream/https_upstream.go +++ b/upstream/https_upstream.go @@ -18,6 +18,8 @@ const ( dnsMessageContentType = "application/dns-message" ) +// TODO: Add bootstrap DNS resolver field + // HttpsUpstream is the upstream implementation for DNS-over-HTTPS type HttpsUpstream struct { client *http.Client @@ -107,3 +109,8 @@ func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) { return buf, nil } + +// Clear resources +func (u *HttpsUpstream) Close() error { + return nil +} diff --git a/upstream/persistent.go b/upstream/persistent.go new file mode 100644 index 00000000..5c28a10e --- /dev/null +++ b/upstream/persistent.go @@ -0,0 +1,208 @@ +package upstream + +import ( + "crypto/tls" + "net" + "sort" + "sync/atomic" + "time" + + "github.com/miekg/dns" +) + +const ( + defaultExpire = 10 * time.Second + minDialTimeout = 100 * time.Millisecond + maxDialTimeout = 30 * time.Second + defaultDialTimeout = 30 * time.Second + cumulativeAvgWeight = 4 +) + +// a persistConn hold the dns.Conn and the last used time. +type persistConn struct { + c *dns.Conn + used time.Time +} + +// Transport hold the persistent cache. +type Transport struct { + avgDialTime int64 // kind of average time of dial time + conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. + expire time.Duration // After this duration a connection is expired. + addr string + tlsConfig *tls.Config + + dial chan string + yield chan *dns.Conn + ret chan *dns.Conn + stop chan bool +} + +// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. +func (t *Transport) Dial(proto string) (*dns.Conn, error) { + // If tls has been configured; use it. + if t.tlsConfig != nil { + proto = "tcp-tls" + } + + t.dial <- proto + c := <-t.ret + + if c != nil { + return c, nil + } + + reqTime := time.Now() + timeout := t.dialTimeout() + if proto == "tcp-tls" { + conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return conn, err + } + conn, err := dns.DialTimeout(proto, t.addr, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return conn, err +} + +// Yield return the connection to transport for reuse. +func (t *Transport) Yield(c *dns.Conn) { t.yield <- c } + +// Start starts the transport's connection manager. +func (t *Transport) Start() { go t.connManager() } + +// Stop stops the transport's connection manager. +func (t *Transport) Stop() { close(t.stop) } + +// SetExpire sets the connection expire time in transport. +func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } + +// SetTLSConfig sets the TLS config in transport. +func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } + +func NewTransport(addr string) *Transport { + t := &Transport{ + avgDialTime: int64(defaultDialTimeout / 2), + conns: make(map[string][]*persistConn), + expire: defaultExpire, + addr: addr, + dial: make(chan string), + yield: make(chan *dns.Conn), + ret: make(chan *dns.Conn), + stop: make(chan bool), + } + return t +} + +func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { + dt := time.Duration(atomic.LoadInt64(currentAvg)) + atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) +} + +func (t *Transport) dialTimeout() time.Duration { + return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) +} + +func (t *Transport) updateDialTimeout(newDialTime time.Duration) { + averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) +} + +// limitTimeout is a utility function to auto-tune timeout values +// average observed time is moved towards the last observed delay moderated by a weight +// next timeout to use will be the double of the computed average, limited by min and max frame. +func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { + rt := time.Duration(atomic.LoadInt64(currentAvg)) + if rt < minValue { + return minValue + } + if rt < maxValue/2 { + return 2 * rt + } + return maxValue +} + +// connManagers manages the persistent connection cache for UDP and TCP. +func (t *Transport) connManager() { + ticker := time.NewTicker(t.expire) +Wait: + for { + select { + case proto := <-t.dial: + // take the last used conn - complexity O(1) + if stack := t.conns[proto]; len(stack) > 0 { + pc := stack[len(stack)-1] + if time.Since(pc.used) < t.expire { + // Found one, remove from pool and return this conn. + t.conns[proto] = stack[:len(stack)-1] + t.ret <- pc.c + continue Wait + } + // clear entire cache if the last conn is expired + t.conns[proto] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + } + + t.ret <- nil + + case conn := <-t.yield: + + // no proto here, infer from config and conn + if _, ok := conn.Conn.(*net.UDPConn); ok { + t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()}) + continue Wait + } + + if t.tlsConfig == nil { + t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()}) + continue Wait + } + + t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) + + case <-ticker.C: + t.cleanup(false) + + case <-t.stop: + t.cleanup(true) + close(t.ret) + return + } + } +} + +// closeConns closes connections. +func closeConns(conns []*persistConn) { + for _, pc := range conns { + pc.c.Close() + } +} + +// cleanup removes connections from cache. +func (t *Transport) cleanup(all bool) { + staleTime := time.Now().Add(-t.expire) + for proto, stack := range t.conns { + if len(stack) == 0 { + continue + } + if all { + t.conns[proto] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + continue + } + if stack[0].used.After(staleTime) { + continue + } + + // connections in stack are sorted by "used" + good := sort.Search(len(stack), func(i int) bool { + return stack[i].used.After(staleTime) + }) + t.conns[proto] = stack[good:] + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack[:good]) + } +} diff --git a/upstream/tls_upstream.go b/upstream/tls_upstream.go deleted file mode 100644 index aed55829..00000000 --- a/upstream/tls_upstream.go +++ /dev/null @@ -1,47 +0,0 @@ -package upstream - -import ( - "crypto/tls" - "github.com/miekg/dns" - "golang.org/x/net/context" - "time" -) - -// TODO: Use persistent connection here - -// DnsOverTlsUpstream is the upstream implementation for plain DNS-over-TLS -type DnsOverTlsUpstream struct { - endpoint string - tlsServerName string - timeout time.Duration -} - -// NewHttpsUpstream creates a new DNS-over-TLS upstream from the endpoint address and TLS server name -func NewDnsOverTlsUpstream(endpoint string, tlsServerName string) (Upstream, error) { - return &DnsOverTlsUpstream{ - endpoint: endpoint, - tlsServerName: tlsServerName, - timeout: defaultTimeout, - }, nil -} - -// Exchange provides an implementation for the Upstream interface -func (u *DnsOverTlsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - - dnsClient := &dns.Client{ - Net: "tcp-tls", - ReadTimeout: u.timeout, - WriteTimeout: u.timeout, - TLSConfig: new(tls.Config), - } - dnsClient.TLSConfig.ServerName = u.tlsServerName - - resp, _, err := dnsClient.Exchange(query, u.endpoint) - - if err != nil { - resp = &dns.Msg{} - resp.SetRcode(resp, dns.RcodeServerFailure) - } - - return resp, err -} diff --git a/upstream/upstream.go b/upstream/upstream.go index 6d2570c5..44d4e389 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -5,6 +5,8 @@ import ( "github.com/miekg/dns" "github.com/pkg/errors" "golang.org/x/net/context" + "log" + "runtime" "time" ) @@ -12,9 +14,12 @@ const ( defaultTimeout = 5 * time.Second ) +// TODO: Add a helper method for health-checking an upstream (see health.go in coredns) + // Upstream is a simplified interface for proxy destination type Upstream interface { Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) + Close() error } // UpstreamPlugin is a simplified DNS proxy using a generic upstream interface @@ -23,11 +28,21 @@ type UpstreamPlugin struct { Next plugin.Handler } +// Initialize the upstream plugin +func New() *UpstreamPlugin { + p := &UpstreamPlugin{} + + // Make sure all resources are cleaned up + runtime.SetFinalizer(p, (*UpstreamPlugin).finalizer) + return p +} + // ServeDNS implements interface for CoreDNS plugin -func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { +func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { var reply *dns.Msg var backendErr error + // TODO: Change the way we call upstreams for _, upstream := range p.Upstreams { reply, backendErr = upstream.Exchange(ctx, r) if backendErr == nil { @@ -40,4 +55,16 @@ func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *d } // Name implements interface for CoreDNS plugin -func (p UpstreamPlugin) Name() string { return "upstream" } +func (p *UpstreamPlugin) Name() string { return "upstream" } + +func (p *UpstreamPlugin) finalizer() { + + for i := range p.Upstreams { + + u := p.Upstreams[i] + err := u.Close() + if err != nil { + log.Printf("Error while closing the upstream: %s", err) + } + } +} diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index ca0df859..5e60b63d 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -2,14 +2,13 @@ package upstream import ( "github.com/miekg/dns" - "log" "net" "testing" ) func TestDnsUpstream(t *testing.T) { - u, err := NewDnsUpstream("8.8.8.8:53") + u, err := NewDnsUpstream("8.8.8.8:53", "udp", "") if err != nil { t.Errorf("cannot create a DNS upstream") @@ -44,12 +43,12 @@ func TestDnsOverTlsUpstream(t *testing.T) { tlsServerName string }{ {"1.1.1.1:853", ""}, - {"8.8.8.8:853", ""}, + {"9.9.9.9:853", ""}, {"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"}, } for _, test := range tests { - u, err := NewDnsOverTlsUpstream(test.endpoint, test.tlsServerName) + u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName) if err != nil { t.Errorf("cannot create a DNS-over-TLS upstream") @@ -60,27 +59,41 @@ func TestDnsOverTlsUpstream(t *testing.T) { } func testUpstream(t *testing.T, u Upstream) { - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + + var tests = []struct { + name string + expected net.IP + }{ + {"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)}, + {"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)}, } - resp, err := u.Exchange(nil, &req) + for _, test := range tests { + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } - if err != nil { - t.Errorf("error while making an upstream request: %s", err) - } + resp, err := u.Exchange(nil, &req) - if len(resp.Answer) != 1 { - t.Errorf("no answer section in the response") - } - if answer, ok := resp.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(answer.A) { - t.Errorf("wrong IP in the response: %v", answer.A) + if err != nil { + t.Errorf("error while making an upstream request: %s", err) + } + + if len(resp.Answer) != 1 { + t.Errorf("no answer section in the response") + } + if answer, ok := resp.Answer[0].(*dns.A); ok { + if !test.expected.Equal(answer.A) { + t.Errorf("wrong IP in the response: %v", answer.A) + } } } - log.Printf("response: %v", resp) + err := u.Close() + if err != nil { + t.Errorf("Error while closing the upstream: %s", err) + } } From a6022fc198bc22f0fc69f276e57b11f85a701457 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Mon, 5 Nov 2018 21:19:01 +0300 Subject: [PATCH 03/10] Added health-check method --- upstream/helpers.go | 23 ++++++++ upstream/https_upstream.go | 28 ++++++++- upstream/upstream.go | 6 +- upstream/upstream_test.go | 113 +++++++++++++++++++++++++++++++++---- 4 files changed, 152 insertions(+), 18 deletions(-) create mode 100644 upstream/helpers.go diff --git a/upstream/helpers.go b/upstream/helpers.go new file mode 100644 index 00000000..6e9cc30a --- /dev/null +++ b/upstream/helpers.go @@ -0,0 +1,23 @@ +package upstream + +import "github.com/miekg/dns" + +// Performs a simple health-check of the specified upstream +func IsAlive(u Upstream) (bool, error) { + + // Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere + ping := new(dns.Msg) + ping.SetQuestion("ipv4only.arpa.", dns.TypeA) + + resp, err := u.Exchange(nil, ping) + + // If we got a header, we're alright, basically only care about I/O errors 'n stuff. + if err != nil && resp != nil { + // Silly check, something sane came back. + if resp.Response || resp.Opcode == dns.OpcodeQuery { + err = nil + } + } + + return err == nil, err +} diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go index 7daab106..906eadf2 100644 --- a/upstream/https_upstream.go +++ b/upstream/https_upstream.go @@ -10,16 +10,17 @@ import ( "golang.org/x/net/http2" "io/ioutil" "log" + "net" "net/http" "net/url" + "time" ) const ( dnsMessageContentType = "application/dns-message" + defaultKeepAlive = 30 * time.Second ) -// TODO: Add bootstrap DNS resolver field - // HttpsUpstream is the upstream implementation for DNS-over-HTTPS type HttpsUpstream struct { client *http.Client @@ -27,18 +28,39 @@ type HttpsUpstream struct { } // NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname -func NewHttpsUpstream(endpoint string) (Upstream, error) { +func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) { u, err := url.Parse(endpoint) if err != nil { return nil, err } + // Initialize bootstrap resolver + bootstrapResolver := net.DefaultResolver + if bootstrap != "" { + bootstrapResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, bootstrap) + return conn, err + }, + } + } + + dialer := &net.Dialer{ + Timeout: defaultTimeout, + KeepAlive: defaultKeepAlive, + DualStack: true, + Resolver: bootstrapResolver, + } + // Update TLS and HTTP client configuration tlsConfig := &tls.Config{ServerName: u.Hostname()} transport := &http.Transport{ TLSClientConfig: tlsConfig, DisableCompression: true, MaxIdleConns: 1, + DialContext: dialer.DialContext, } http2.ConfigureTransport(transport) diff --git a/upstream/upstream.go b/upstream/upstream.go index 44d4e389..9d2222dc 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -42,8 +42,8 @@ func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r * var reply *dns.Msg var backendErr error - // TODO: Change the way we call upstreams - for _, upstream := range p.Upstreams { + for i := range p.Upstreams { + upstream := p.Upstreams[i] reply, backendErr = upstream.Exchange(ctx, r) if backendErr == nil { w.WriteMsg(reply) @@ -67,4 +67,4 @@ func (p *UpstreamPlugin) finalizer() { log.Printf("Error while closing the upstream: %s", err) } } -} +} \ No newline at end of file diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index 5e60b63d..f612fc6e 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -6,27 +6,107 @@ import ( "testing" ) -func TestDnsUpstream(t *testing.T) { +func TestDnsUpstreamIsAlive(t *testing.T) { - u, err := NewDnsUpstream("8.8.8.8:53", "udp", "") - - if err != nil { - t.Errorf("cannot create a DNS upstream") + var tests = []struct { + endpoint string + proto string + }{ + {"8.8.8.8:53", "udp"}, + {"8.8.8.8:53", "tcp"}, + {"1.1.1.1:53", "udp"}, } - testUpstream(t, u) + for _, test := range tests { + u, err := NewDnsUpstream(test.endpoint, test.proto, "") + + if err != nil { + t.Errorf("cannot create a DNS upstream") + } + + testUpstreamIsAlive(t, u) + } +} + +func TestHttpsUpstreamIsAlive(t *testing.T) { + + var tests = []struct { + url string + bootstrap string + }{ + {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, + {"https://dns.google.com/experimental", "8.8.8.8:53"}, + {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, // TODO: status 201?? + } + + for _, test := range tests { + u, err := NewHttpsUpstream(test.url, test.bootstrap) + + if err != nil { + t.Errorf("cannot create a DNS-over-HTTPS upstream") + } + + testUpstreamIsAlive(t, u) + } +} + +func TestDnsOverTlsIsAlive(t *testing.T) { + + var tests = []struct { + endpoint string + tlsServerName string + }{ + {"1.1.1.1:853", ""}, + {"9.9.9.9:853", ""}, + {"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"}, + } + + for _, test := range tests { + u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName) + + if err != nil { + t.Errorf("cannot create a DNS-over-TLS upstream") + } + + testUpstreamIsAlive(t, u) + } +} + +func TestDnsUpstream(t *testing.T) { + + var tests = []struct { + endpoint string + proto string + }{ + {"8.8.8.8:53", "udp"}, + {"8.8.8.8:53", "tcp"}, + {"1.1.1.1:53", "udp"}, + } + + for _, test := range tests { + u, err := NewDnsUpstream(test.endpoint, test.proto, "") + + if err != nil { + t.Errorf("cannot create a DNS upstream") + } + + testUpstream(t, u) + } } func TestHttpsUpstream(t *testing.T) { - testCases := []string{ - "https://cloudflare-dns.com/dns-query", - "https://dns.google.com/experimental", - "https://doh.cleanbrowsing.org/doh/security-filter/", + var tests = []struct { + url string + bootstrap string + }{ + {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, + {"https://dns.google.com/experimental", "8.8.8.8:53"}, + {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, } - for _, url := range testCases { - u, err := NewHttpsUpstream(url) + for _, test := range tests { + u, err := NewHttpsUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS-over-HTTPS upstream") @@ -58,6 +138,15 @@ func TestDnsOverTlsUpstream(t *testing.T) { } } +func testUpstreamIsAlive(t *testing.T, u Upstream) { + alive, err := IsAlive(u) + if !alive || err != nil { + t.Errorf("Upstream is not alive") + } + + u.Close() +} + func testUpstream(t *testing.T, u Upstream) { var tests = []struct { From 9bc4bf66edb279f83aa7610894bd3cccaf12e477 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Mon, 5 Nov 2018 22:11:13 +0300 Subject: [PATCH 04/10] Added factory method for creating DNS upstreams --- upstream/helpers.go | 80 +++++++++++++++++++++++++++++++++++++- upstream/https_upstream.go | 15 +------ upstream/upstream_test.go | 58 ++++++++++++++------------- 3 files changed, 112 insertions(+), 41 deletions(-) diff --git a/upstream/helpers.go b/upstream/helpers.go index 6e9cc30a..0e698698 100644 --- a/upstream/helpers.go +++ b/upstream/helpers.go @@ -1,6 +1,84 @@ package upstream -import "github.com/miekg/dns" +import ( + "github.com/miekg/dns" + "golang.org/x/net/context" + "net" + "strings" +) + +// Detects the upstream type from the specified url and creates a proper Upstream object +func NewUpstream(url string, bootstrap string) (Upstream, error) { + + proto := "udp" + prefix := "" + + switch { + case strings.HasPrefix(url, "tcp://"): + proto = "tcp" + prefix = "tcp://" + case strings.HasPrefix(url, "tls://"): + proto = "tcp-tls" + prefix = "tls://" + case strings.HasPrefix(url, "https://"): + return NewHttpsUpstream(url, bootstrap) + } + + hostname := strings.TrimPrefix(url, prefix) + + host, port, err := net.SplitHostPort(hostname) + if err != nil { + // Set port depending on the protocol + switch proto { + case "udp": + port = "53" + case "tcp": + port = "53" + case "tcp-tls": + port = "853" + } + + // Set host = hostname + host = hostname + } + + // Try to resolve the host address (or check if it's an IP address) + bootstrapResolver := CreateResolver(bootstrap) + ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host) + + if err != nil || len(ips) == 0 { + return nil, err + } + + addr := ips[0].String() + endpoint := net.JoinHostPort(addr, port) + tlsServerName := "" + + if proto == "tcp-tls" && host != addr { + // Check if we need to specify TLS server name + tlsServerName = host + } + + return NewDnsUpstream(endpoint, proto, tlsServerName) +} + +func CreateResolver(bootstrap string) *net.Resolver { + + bootstrapResolver := net.DefaultResolver + + if bootstrap != "" { + bootstrapResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, network, bootstrap) + return conn, err + }, + } + } + + return bootstrapResolver +} // Performs a simple health-check of the specified upstream func IsAlive(u Upstream) (bool, error) { diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go index 906eadf2..ae705699 100644 --- a/upstream/https_upstream.go +++ b/upstream/https_upstream.go @@ -27,7 +27,7 @@ type HttpsUpstream struct { endpoint *url.URL } -// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname +// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) { u, err := url.Parse(endpoint) if err != nil { @@ -35,18 +35,7 @@ func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) { } // Initialize bootstrap resolver - bootstrapResolver := net.DefaultResolver - if bootstrap != "" { - bootstrapResolver = &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - var d net.Dialer - conn, err := d.DialContext(ctx, network, bootstrap) - return conn, err - }, - } - } - + bootstrapResolver := CreateResolver(bootstrap) dialer := &net.Dialer{ Timeout: defaultTimeout, KeepAlive: defaultKeepAlive, diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index f612fc6e..1b3235fe 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -9,16 +9,17 @@ import ( func TestDnsUpstreamIsAlive(t *testing.T) { var tests = []struct { - endpoint string - proto string + url string + bootstrap string }{ - {"8.8.8.8:53", "udp"}, - {"8.8.8.8:53", "tcp"}, - {"1.1.1.1:53", "udp"}, + {"8.8.8.8:53", "8.8.8.8:53"}, + {"1.1.1.1", ""}, + {"tcp://1.1.1.1:53", ""}, + {"176.103.130.130:5353", ""}, } for _, test := range tests { - u, err := NewDnsUpstream(test.endpoint, test.proto, "") + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS upstream") @@ -36,11 +37,11 @@ func TestHttpsUpstreamIsAlive(t *testing.T) { }{ {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, {"https://dns.google.com/experimental", "8.8.8.8:53"}, - {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, // TODO: status 201?? + {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, } for _, test := range tests { - u, err := NewHttpsUpstream(test.url, test.bootstrap) + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS-over-HTTPS upstream") @@ -53,16 +54,17 @@ func TestHttpsUpstreamIsAlive(t *testing.T) { func TestDnsOverTlsIsAlive(t *testing.T) { var tests = []struct { - endpoint string - tlsServerName string + url string + bootstrap string }{ - {"1.1.1.1:853", ""}, - {"9.9.9.9:853", ""}, - {"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"}, + {"tls://1.1.1.1", ""}, + {"tls://9.9.9.9:853", ""}, + {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, + {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, } for _, test := range tests { - u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName) + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS-over-TLS upstream") @@ -75,16 +77,17 @@ func TestDnsOverTlsIsAlive(t *testing.T) { func TestDnsUpstream(t *testing.T) { var tests = []struct { - endpoint string - proto string + url string + bootstrap string }{ - {"8.8.8.8:53", "udp"}, - {"8.8.8.8:53", "tcp"}, - {"1.1.1.1:53", "udp"}, + {"8.8.8.8:53", "8.8.8.8:53"}, + {"1.1.1.1", ""}, + {"tcp://1.1.1.1:53", ""}, + {"176.103.130.130:5353", ""}, } for _, test := range tests { - u, err := NewDnsUpstream(test.endpoint, test.proto, "") + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS upstream") @@ -106,7 +109,7 @@ func TestHttpsUpstream(t *testing.T) { } for _, test := range tests { - u, err := NewHttpsUpstream(test.url, test.bootstrap) + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS-over-HTTPS upstream") @@ -119,16 +122,17 @@ func TestHttpsUpstream(t *testing.T) { func TestDnsOverTlsUpstream(t *testing.T) { var tests = []struct { - endpoint string - tlsServerName string + url string + bootstrap string }{ - {"1.1.1.1:853", ""}, - {"9.9.9.9:853", ""}, - {"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"}, + {"tls://1.1.1.1", ""}, + {"tls://9.9.9.9:853", ""}, + {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, + {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, } for _, test := range tests { - u, err := NewDnsUpstream(test.endpoint, "tcp-tls", test.tlsServerName) + u, err := NewUpstream(test.url, test.bootstrap) if err != nil { t.Errorf("cannot create a DNS-over-TLS upstream") From efdd1c1ff2aff4643f0982734462e11c0218b02f Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Mon, 5 Nov 2018 23:49:31 +0300 Subject: [PATCH 05/10] Added CoreDNS plugin setup and replaced forward --- config.go | 2 +- coredns.go | 2 + upstream/dns_upstream.go | 19 +++++++-- upstream/helpers.go | 2 +- upstream/setup.go | 83 +++++++++++++++++++++++++++++++++++++++ upstream/setup_test.go | 29 ++++++++++++++ upstream/upstream.go | 22 ++--------- upstream/upstream_test.go | 3 +- 8 files changed, 137 insertions(+), 25 deletions(-) create mode 100644 upstream/setup.go create mode 100644 upstream/setup_test.go diff --git a/config.go b/config.go index bf63c0ee..1a0d18a4 100644 --- a/config.go +++ b/config.go @@ -253,7 +253,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} { hosts { fallthrough } - {{if .UpstreamDNS}}forward . {{range .UpstreamDNS}}{{.}} {{end}}{{end}} + {{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap 8.8.8.8:53 }{{end}} {{.Cache}} {{.Prometheus}} } diff --git a/coredns.go b/coredns.go index 5dbe01b4..a21fb986 100644 --- a/coredns.go +++ b/coredns.go @@ -8,6 +8,7 @@ import ( "sync" // Include all plugins. _ "github.com/AdguardTeam/AdGuardHome/coredns_plugin" + _ "github.com/AdguardTeam/AdGuardHome/upstream" "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/coremain" _ "github.com/coredns/coredns/plugin/auto" @@ -79,6 +80,7 @@ var directives = []string{ "loop", "forward", "proxy", + "upstream", "erratic", "whoami", "on", diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go index a40aec5a..e7c2e7bd 100644 --- a/upstream/dns_upstream.go +++ b/upstream/dns_upstream.go @@ -42,7 +42,20 @@ func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstre // Exchange provides an implementation for the Upstream interface func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - resp, err := u.exchange(query) + resp, err := u.exchange(u.proto, query) + + // Retry over TCP if response is truncated + if err == dns.ErrTruncated && u.proto == "udp" { + resp, err = u.exchange("tcp", query) + } else if err == dns.ErrTruncated && resp != nil { + // Reassemble something to be sent to client + m := new(dns.Msg) + m.SetReply(query) + m.Truncated = true + m.Authoritative = true + m.Rcode = dns.RcodeSuccess + return m, nil + } if err != nil { resp = &dns.Msg{} @@ -62,10 +75,10 @@ func (u *DnsUpstream) Close() error { // Performs a synchronous query. It sends the message m via the conn // c and waits for a reply. The conn c is not closed. -func (u *DnsUpstream) exchange(query *dns.Msg) (r *dns.Msg, err error) { +func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) { // Establish a connection if needed (or reuse cached) - conn, err := u.transport.Dial(u.proto) + conn, err := u.transport.Dial(proto) if err != nil { return nil, err } diff --git a/upstream/helpers.go b/upstream/helpers.go index 0e698698..e903f799 100644 --- a/upstream/helpers.go +++ b/upstream/helpers.go @@ -87,7 +87,7 @@ func IsAlive(u Upstream) (bool, error) { ping := new(dns.Msg) ping.SetQuestion("ipv4only.arpa.", dns.TypeA) - resp, err := u.Exchange(nil, ping) + resp, err := u.Exchange(context.Background(), ping) // If we got a header, we're alright, basically only care about I/O errors 'n stuff. if err != nil && resp != nil { diff --git a/upstream/setup.go b/upstream/setup.go new file mode 100644 index 00000000..e3420da4 --- /dev/null +++ b/upstream/setup.go @@ -0,0 +1,83 @@ +package upstream + +import ( + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/mholt/caddy" + "log" +) + +func init() { + caddy.RegisterPlugin("upstream", caddy.Plugin{ + ServerType: "dns", + Action: setup, + }) +} + +// Read the configuration and initialize upstreams +func setup(c *caddy.Controller) error { + + p, err := setupPlugin(c) + if err != nil { + return err + } + config := dnsserver.GetConfig(c) + config.AddPlugin(func(next plugin.Handler) plugin.Handler { + p.Next = next + return p + }) + + c.OnShutdown(p.onShutdown) + return nil +} + +// Read the configuration +func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) { + + p := New() + + log.Println("Initializing the Upstream plugin") + + bootstrap := "" + upstreamUrls := []string{} + for c.Next() { + args := c.RemainingArgs() + if len(args) > 0 { + upstreamUrls = append(upstreamUrls, args...) + } + for c.NextBlock() { + switch c.Val() { + case "bootstrap": + if !c.NextArg() { + return nil, c.ArgErr() + } + bootstrap = c.Val() + } + } + } + + for _, url := range upstreamUrls { + u, err := NewUpstream(url, bootstrap) + if err != nil { + log.Printf("Cannot initialize upstream %s", url) + return nil, err + } + + p.Upstreams = append(p.Upstreams, u) + } + + return p, nil +} + +func (p *UpstreamPlugin) onShutdown() error { + for i := range p.Upstreams { + + u := p.Upstreams[i] + err := u.Close() + if err != nil { + log.Printf("Error while closing the upstream: %s", err) + } + } + + return nil +} diff --git a/upstream/setup_test.go b/upstream/setup_test.go new file mode 100644 index 00000000..cff8abaf --- /dev/null +++ b/upstream/setup_test.go @@ -0,0 +1,29 @@ +package upstream + +import ( + "github.com/mholt/caddy" + "testing" +) + +func TestSetup(t *testing.T) { + + var tests = []struct { + config string + }{ + {`upstream 8.8.8.8`}, + {`upstream 8.8.8.8 { + bootstrap 8.8.8.8:53 +}`}, + {`upstream tls://1.1.1.1 8.8.8.8 { + bootstrap 1.1.1.1 +}`}, + } + + for _, test := range tests { + c := caddy.NewTestController("dns", test.config) + err := setup(c) + if err != nil { + t.Fatalf("Test failed") + } + } +} diff --git a/upstream/upstream.go b/upstream/upstream.go index 9d2222dc..6578c94e 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -5,8 +5,6 @@ import ( "github.com/miekg/dns" "github.com/pkg/errors" "golang.org/x/net/context" - "log" - "runtime" "time" ) @@ -14,8 +12,6 @@ const ( defaultTimeout = 5 * time.Second ) -// TODO: Add a helper method for health-checking an upstream (see health.go in coredns) - // Upstream is a simplified interface for proxy destination type Upstream interface { Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) @@ -30,10 +26,10 @@ type UpstreamPlugin struct { // Initialize the upstream plugin func New() *UpstreamPlugin { - p := &UpstreamPlugin{} + p := &UpstreamPlugin{ + Upstreams: []Upstream{}, + } - // Make sure all resources are cleaned up - runtime.SetFinalizer(p, (*UpstreamPlugin).finalizer) return p } @@ -56,15 +52,3 @@ func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r * // Name implements interface for CoreDNS plugin func (p *UpstreamPlugin) Name() string { return "upstream" } - -func (p *UpstreamPlugin) finalizer() { - - for i := range p.Upstreams { - - u := p.Upstreams[i] - err := u.Close() - if err != nil { - log.Printf("Error while closing the upstream: %s", err) - } - } -} \ No newline at end of file diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index 1b3235fe..171839a5 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -2,6 +2,7 @@ package upstream import ( "github.com/miekg/dns" + "golang.org/x/net/context" "net" "testing" ) @@ -169,7 +170,7 @@ func testUpstream(t *testing.T, u Upstream) { {Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET}, } - resp, err := u.Exchange(nil, &req) + resp, err := u.Exchange(context.Background(), &req) if err != nil { t.Errorf("error while making an upstream request: %s", err) From 7f018234f6b1629f922901a7fd800612a7abfc69 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Mon, 5 Nov 2018 23:52:11 +0300 Subject: [PATCH 06/10] goimports files --- upstream/dns_upstream.go | 3 ++- upstream/helpers.go | 5 +++-- upstream/https_upstream.go | 9 +++++---- upstream/setup.go | 3 ++- upstream/setup_test.go | 3 ++- upstream/upstream.go | 3 ++- upstream/upstream_test.go | 5 +++-- 7 files changed, 19 insertions(+), 12 deletions(-) diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go index e7c2e7bd..89584e11 100644 --- a/upstream/dns_upstream.go +++ b/upstream/dns_upstream.go @@ -2,9 +2,10 @@ package upstream import ( "crypto/tls" + "time" + "github.com/miekg/dns" "golang.org/x/net/context" - "time" ) // DnsUpstream is a very simple upstream implementation for plain DNS diff --git a/upstream/helpers.go b/upstream/helpers.go index e903f799..832d58b4 100644 --- a/upstream/helpers.go +++ b/upstream/helpers.go @@ -1,10 +1,11 @@ package upstream import ( - "github.com/miekg/dns" - "golang.org/x/net/context" "net" "strings" + + "github.com/miekg/dns" + "golang.org/x/net/context" ) // Detects the upstream type from the specified url and creates a proper Upstream object diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go index ae705699..d7d7bdde 100644 --- a/upstream/https_upstream.go +++ b/upstream/https_upstream.go @@ -4,16 +4,17 @@ import ( "bytes" "crypto/tls" "fmt" - "github.com/miekg/dns" - "github.com/pkg/errors" - "golang.org/x/net/context" - "golang.org/x/net/http2" "io/ioutil" "log" "net" "net/http" "net/url" "time" + + "github.com/miekg/dns" + "github.com/pkg/errors" + "golang.org/x/net/context" + "golang.org/x/net/http2" ) const ( diff --git a/upstream/setup.go b/upstream/setup.go index e3420da4..56f5da27 100644 --- a/upstream/setup.go +++ b/upstream/setup.go @@ -1,10 +1,11 @@ package upstream import ( + "log" + "github.com/coredns/coredns/core/dnsserver" "github.com/coredns/coredns/plugin" "github.com/mholt/caddy" - "log" ) func init() { diff --git a/upstream/setup_test.go b/upstream/setup_test.go index cff8abaf..b3918932 100644 --- a/upstream/setup_test.go +++ b/upstream/setup_test.go @@ -1,8 +1,9 @@ package upstream import ( - "github.com/mholt/caddy" "testing" + + "github.com/mholt/caddy" ) func TestSetup(t *testing.T) { diff --git a/upstream/upstream.go b/upstream/upstream.go index 6578c94e..c2ab1826 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -1,11 +1,12 @@ package upstream import ( + "time" + "github.com/coredns/coredns/plugin" "github.com/miekg/dns" "github.com/pkg/errors" "golang.org/x/net/context" - "time" ) const ( diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go index 171839a5..7ce5690f 100644 --- a/upstream/upstream_test.go +++ b/upstream/upstream_test.go @@ -1,10 +1,11 @@ package upstream import ( - "github.com/miekg/dns" - "golang.org/x/net/context" "net" "testing" + + "github.com/miekg/dns" + "golang.org/x/net/context" ) func TestDnsUpstreamIsAlive(t *testing.T) { From 451922b858c77145342bc568f7773433d281fee6 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Tue, 6 Nov 2018 00:47:59 +0300 Subject: [PATCH 07/10] Added bootstrap DNS to the config file DNS healthcheck now uses the upstream package methods --- config.go | 4 +- control.go | 104 +++++------------------------------------ openapi.yaml | 1 + upstream/helpers.go | 2 +- upstream/persistent.go | 2 + 5 files changed, 19 insertions(+), 94 deletions(-) diff --git a/config.go b/config.go index 1a0d18a4..a8534bc5 100644 --- a/config.go +++ b/config.go @@ -70,6 +70,7 @@ type coreDNSConfig struct { Pprof string `yaml:"-"` Cache string `yaml:"-"` Prometheus string `yaml:"-"` + BootstrapDNS string `yaml:"bootstrap_dns"` UpstreamDNS []string `yaml:"upstream_dns"` } @@ -100,6 +101,7 @@ var config = configuration{ SafeBrowsingEnabled: false, BlockedResponseTTL: 10, // in seconds QueryLogEnabled: true, + BootstrapDNS: "8.8.8.8:53", UpstreamDNS: defaultDNS, Cache: "cache", Prometheus: "prometheus :9153", @@ -253,7 +255,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} { hosts { fallthrough } - {{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap 8.8.8.8:53 }{{end}} + {{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}} {{.Cache}} {{.Prometheus}} } diff --git a/control.go b/control.go index 378a2aec..238bc131 100644 --- a/control.go +++ b/control.go @@ -6,7 +6,6 @@ import ( "fmt" "io/ioutil" "log" - "net" "net/http" "os" "path/filepath" @@ -15,8 +14,9 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/upstream" + corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin" - "github.com/miekg/dns" "gopkg.in/asaskevich/govalidator.v4" ) @@ -81,6 +81,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { "protection_enabled": config.CoreDNS.ProtectionEnabled, "querylog_enabled": config.CoreDNS.QueryLogEnabled, "running": isRunning(), + "bootstrap_dns": config.CoreDNS.BootstrapDNS, "upstream_dns": config.CoreDNS.UpstreamDNS, "version": VersionString, } @@ -140,11 +141,8 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { return } // if empty body -- user is asking for default servers - hosts, err := sanitiseDNSServers(string(body)) - if err != nil { - httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err) - return - } + hosts := strings.Fields(string(body)) + if len(hosts) == 0 { config.CoreDNS.UpstreamDNS = defaultDNS } else { @@ -214,104 +212,26 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { } func checkDNS(input string) error { - input, err := sanitizeDNSServer(input) + + u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS) + if err != nil { return err } - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } + alive, err := upstream.IsAlive(u) - prefix, host := splitDNSServerPrefixServer(input) - - c := dns.Client{ - Timeout: time.Minute, - } - switch prefix { - case "tls://": - c.Net = "tcp-tls" - } - - resp, rtt, err := c.Exchange(&req, host) if err != nil { return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) } - trace("exchange with %s took %v", input, rtt) - if len(resp.Answer) != 1 { - return fmt.Errorf("DNS server %s returned wrong answer", input) - } - if t, ok := resp.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(t.A) { - return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A) - } + + if !alive { + return fmt.Errorf("DNS server has not passed the healthcheck: %s", input) } return nil } -func sanitiseDNSServers(input string) ([]string, error) { - fields := strings.Fields(input) - hosts := make([]string, 0) - for _, field := range fields { - sanitized, err := sanitizeDNSServer(field) - if err != nil { - return hosts, err - } - hosts = append(hosts, sanitized) - } - return hosts, nil -} - -func getDNSServerPrefix(input string) string { - prefix := "" - switch { - case strings.HasPrefix(input, "dns://"): - prefix = "dns://" - case strings.HasPrefix(input, "tls://"): - prefix = "tls://" - } - return prefix -} - -func splitDNSServerPrefixServer(input string) (string, string) { - prefix := getDNSServerPrefix(input) - host := strings.TrimPrefix(input, prefix) - return prefix, host -} - -func sanitizeDNSServer(input string) (string, error) { - prefix, host := splitDNSServerPrefixServer(input) - host = appendPortIfMissing(prefix, host) - { - h, _, err := net.SplitHostPort(host) - if err != nil { - return "", err - } - ip := net.ParseIP(h) - if ip == nil { - return "", fmt.Errorf("invalid DNS server field: %s", h) - } - } - return prefix + host, nil -} - -func appendPortIfMissing(prefix, input string) string { - port := "53" - switch prefix { - case "tls://": - port = "853" - } - _, _, err := net.SplitHostPort(input) - if err == nil { - return input - } - return net.JoinHostPort(input, port) -} - //noinspection GoUnusedParameter func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { now := time.Now() diff --git a/openapi.yaml b/openapi.yaml index 35e32a90..9fc585f2 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -41,6 +41,7 @@ paths: protection_enabled: true querylog_enabled: true running: true + bootstrap_dns: 8.8.8.8:53 upstream_dns: - 1.1.1.1 - 1.0.0.1 diff --git a/upstream/helpers.go b/upstream/helpers.go index 832d58b4..1313b8e0 100644 --- a/upstream/helpers.go +++ b/upstream/helpers.go @@ -93,7 +93,7 @@ func IsAlive(u Upstream) (bool, error) { // If we got a header, we're alright, basically only care about I/O errors 'n stuff. if err != nil && resp != nil { // Silly check, something sane came back. - if resp.Response || resp.Opcode == dns.OpcodeQuery { + if resp.Rcode != dns.RcodeServerFailure { err = nil } } diff --git a/upstream/persistent.go b/upstream/persistent.go index 5c28a10e..91cc9094 100644 --- a/upstream/persistent.go +++ b/upstream/persistent.go @@ -10,6 +10,8 @@ import ( "github.com/miekg/dns" ) +// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin + const ( defaultExpire = 10 * time.Second minDialTimeout = 100 * time.Millisecond From 2e879896ff08220407f0530eaa814db3d61f9552 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Tue, 6 Nov 2018 00:52:27 +0300 Subject: [PATCH 08/10] Close test upstream --- control.go | 1 + 1 file changed, 1 insertion(+) diff --git a/control.go b/control.go index 238bc131..63a387b0 100644 --- a/control.go +++ b/control.go @@ -218,6 +218,7 @@ func checkDNS(input string) error { if err != nil { return err } + defer u.Close() alive, err := upstream.IsAlive(u) From cc4082629986524be4e6cb988aa3d63cdf2d6a44 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Tue, 6 Nov 2018 01:14:28 +0300 Subject: [PATCH 09/10] Fix review comments --- upstream/dns_upstream.go | 7 +++++-- upstream/helpers.go | 3 +-- upstream/upstream.go | 4 +++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go index 89584e11..2902ca2e 100644 --- a/upstream/dns_upstream.go +++ b/upstream/dns_upstream.go @@ -27,7 +27,7 @@ func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstre var tlsConfig *tls.Config - if tlsServerName != "" { + if proto == "tcp-tls" { tlsConfig = new(tls.Config) tlsConfig.ServerName = tlsServerName } @@ -101,6 +101,9 @@ func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err er conn.Close() // Not giving it back } - u.transport.Yield(conn) + if err == nil { + // Return it back to the connections cache if there were no errors + u.transport.Yield(conn) + } return r, err } diff --git a/upstream/helpers.go b/upstream/helpers.go index 1313b8e0..209da533 100644 --- a/upstream/helpers.go +++ b/upstream/helpers.go @@ -72,8 +72,7 @@ func CreateResolver(bootstrap string) *net.Resolver { PreferGo: true, Dial: func(ctx context.Context, network, address string) (net.Conn, error) { var d net.Dialer - conn, err := d.DialContext(ctx, network, bootstrap) - return conn, err + return d.DialContext(ctx, network, bootstrap) }, } } diff --git a/upstream/upstream.go b/upstream/upstream.go index c2ab1826..faef224e 100644 --- a/upstream/upstream.go +++ b/upstream/upstream.go @@ -52,4 +52,6 @@ func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r * } // Name implements interface for CoreDNS plugin -func (p *UpstreamPlugin) Name() string { return "upstream" } +func (p *UpstreamPlugin) Name() string { + return "upstream" +} From 914eb612cd0da015b98c151b7ac603fb4126a2c3 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Tue, 6 Nov 2018 01:20:53 +0300 Subject: [PATCH 10/10] Add bootstrap DNS to readme --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 8464a7dd..b7b43446 100644 --- a/README.md +++ b/README.md @@ -106,8 +106,10 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib * `parental_enabled` — Parental control-based DNS requests filtering * `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17 * `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes) + * `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname * `upstream_dns` — List of upstream DNS servers * `filters` — List of filters, each filter has the following values: + * `ID` - filter ID (must be unique) * `url` — URL pointing to the filter contents (filtering rules) * `enabled` — Current filter's status (enabled/disabled) * `user_rules` — User-specified filtering rules