Upstream plugin prototype

This commit is contained in:
Andrey Meshkov 2018-11-01 14:45:32 +03:00
parent 19e30dbccc
commit 484c0ceaff
6 changed files with 342 additions and 21 deletions

View File

@ -134,9 +134,9 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Failed to read request body: %s", err) errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
// if empty body -- user is asking for default servers // if empty body -- user is asking for default servers
@ -153,34 +153,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
err = writeAllConfigs() err = writeAllConfigs()
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write config file: %s", err) errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
tellCoreDNSToReload() tellCoreDNSToReload()
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body) body, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Failed to read request body: %s", err) errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, 400) http.Error(w, errorText, 400)
return return
} }
hosts := strings.Fields(string(body)) hosts := strings.Fields(string(body))
if len(hosts) == 0 { if len(hosts) == 0 {
errortext := fmt.Sprintf("No servers specified") errorText := fmt.Sprintf("No servers specified")
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusBadRequest) http.Error(w, errorText, http.StatusBadRequest)
return return
} }
@ -198,18 +198,18 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(result) jsonVal, err := json.Marshal(result)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err) errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
return return
} }
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal) _, err = w.Write(jsonVal)
if err != nil { if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err) errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext) log.Println(errorText)
http.Error(w, errortext, http.StatusInternalServerError) http.Error(w, errorText, http.StatusInternalServerError)
} }
} }

36
upstream/dns_upstream.go Normal file
View File

@ -0,0 +1,36 @@
package upstream
import (
"github.com/miekg/dns"
"golang.org/x/net/context"
"time"
)
// DnsUpstream is a very simple upstream implementation for plain DNS
type DnsUpstream struct {
nameServer string // IP:port
timeout time.Duration // Max read and write timeout
}
// NewDnsUpstream creates a new plain-DNS upstream
func NewDnsUpstream(nameServer string) (Upstream, error) {
return &DnsUpstream{nameServer: nameServer, timeout: defaultTimeout}, nil
}
// Exchange provides an implementation for the Upstream interface
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
dnsClient := &dns.Client{
ReadTimeout: u.timeout,
WriteTimeout: u.timeout,
}
resp, _, err := dnsClient.Exchange(query, u.nameServer)
if err != nil {
resp = &dns.Msg{}
resp.SetRcode(resp, dns.RcodeServerFailure)
}
return resp, err
}

109
upstream/https_upstream.go Normal file
View File

@ -0,0 +1,109 @@
package upstream
import (
"bytes"
"crypto/tls"
"fmt"
"github.com/miekg/dns"
"github.com/pkg/errors"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"io/ioutil"
"log"
"net/http"
"net/url"
)
const (
dnsMessageContentType = "application/dns-message"
)
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
type HttpsUpstream struct {
client *http.Client
endpoint *url.URL
}
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname
func NewHttpsUpstream(endpoint string) (Upstream, error) {
u, err := url.Parse(endpoint)
if err != nil {
return nil, err
}
// Update TLS and HTTP client configuration
tlsConfig := &tls.Config{ServerName: u.Hostname()}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
DisableCompression: true,
MaxIdleConns: 1,
}
http2.ConfigureTransport(transport)
client := &http.Client{
Timeout: defaultTimeout,
Transport: transport,
}
return &HttpsUpstream{client: client, endpoint: u}, nil
}
// Exchange provides an implementation for the Upstream interface
func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
queryBuf, err := query.Pack()
if err != nil {
return nil, errors.Wrap(err, "failed to pack DNS query")
}
// No content negotiation for now, use DNS wire format
buf, backendErr := u.exchangeWireformat(queryBuf)
if backendErr == nil {
response := &dns.Msg{}
if err := response.Unpack(buf); err != nil {
return nil, errors.Wrap(err, "failed to unpack DNS response from body")
}
response.Id = query.Id
return response, nil
}
log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr)
return nil, backendErr
}
// Perform message exchange with the default UDP wireformat defined in current draft
// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10
func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
if err != nil {
return nil, errors.Wrap(err, "failed to create an HTTPS request")
}
req.Header.Add("Content-Type", dnsMessageContentType)
req.Header.Add("Accept", dnsMessageContentType)
req.Host = u.endpoint.Hostname()
resp, err := u.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to perform an HTTPS request")
}
// Check response status code
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if contentType != dnsMessageContentType {
return nil, fmt.Errorf("return wrong content type %s", contentType)
}
// Read application/dns-message response from the body
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read the response body")
}
return buf, nil
}

47
upstream/tls_upstream.go Normal file
View File

@ -0,0 +1,47 @@
package upstream
import (
"crypto/tls"
"github.com/miekg/dns"
"golang.org/x/net/context"
"time"
)
// TODO: Use persistent connection here
// DnsOverTlsUpstream is the upstream implementation for plain DNS-over-TLS
type DnsOverTlsUpstream struct {
endpoint string
tlsServerName string
timeout time.Duration
}
// NewHttpsUpstream creates a new DNS-over-TLS upstream from the endpoint address and TLS server name
func NewDnsOverTlsUpstream(endpoint string, tlsServerName string) (Upstream, error) {
return &DnsOverTlsUpstream{
endpoint: endpoint,
tlsServerName: tlsServerName,
timeout: defaultTimeout,
}, nil
}
// Exchange provides an implementation for the Upstream interface
func (u *DnsOverTlsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
dnsClient := &dns.Client{
Net: "tcp-tls",
ReadTimeout: u.timeout,
WriteTimeout: u.timeout,
TLSConfig: new(tls.Config),
}
dnsClient.TLSConfig.ServerName = u.tlsServerName
resp, _, err := dnsClient.Exchange(query, u.endpoint)
if err != nil {
resp = &dns.Msg{}
resp.SetRcode(resp, dns.RcodeServerFailure)
}
return resp, err
}

43
upstream/upstream.go Normal file
View File

@ -0,0 +1,43 @@
package upstream
import (
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
"github.com/pkg/errors"
"golang.org/x/net/context"
"time"
)
const (
defaultTimeout = 5 * time.Second
)
// Upstream is a simplified interface for proxy destination
type Upstream interface {
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
}
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
type UpstreamPlugin struct {
Upstreams []Upstream
Next plugin.Handler
}
// ServeDNS implements interface for CoreDNS plugin
func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
var reply *dns.Msg
var backendErr error
for _, upstream := range p.Upstreams {
reply, backendErr = upstream.Exchange(ctx, r)
if backendErr == nil {
w.WriteMsg(reply)
return 0, nil
}
}
return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
}
// Name implements interface for CoreDNS plugin
func (p UpstreamPlugin) Name() string { return "upstream" }

86
upstream/upstream_test.go Normal file
View File

@ -0,0 +1,86 @@
package upstream
import (
"github.com/miekg/dns"
"log"
"net"
"testing"
)
func TestDnsUpstream(t *testing.T) {
u, err := NewDnsUpstream("8.8.8.8:53")
if err != nil {
t.Errorf("cannot create a DNS upstream")
}
testUpstream(t, u)
}
func TestHttpsUpstream(t *testing.T) {
testCases := []string{
"https://cloudflare-dns.com/dns-query",
"https://dns.google.com/experimental",
"https://doh.cleanbrowsing.org/doh/security-filter/",
}
for _, url := range testCases {
u, err := NewHttpsUpstream(url)
if err != nil {
t.Errorf("cannot create a DNS-over-HTTPS upstream")
}
testUpstream(t, u)
}
}
func TestDnsOverTlsUpstream(t *testing.T) {
var tests = []struct {
endpoint string
tlsServerName string
}{
{"1.1.1.1:853", ""},
{"8.8.8.8:853", ""},
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
}
for _, test := range tests {
u, err := NewDnsOverTlsUpstream(test.endpoint, test.tlsServerName)
if err != nil {
t.Errorf("cannot create a DNS-over-TLS upstream")
}
testUpstream(t, u)
}
}
func testUpstream(t *testing.T, u Upstream) {
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
resp, err := u.Exchange(nil, &req)
if err != nil {
t.Errorf("error while making an upstream request: %s", err)
}
if len(resp.Answer) != 1 {
t.Errorf("no answer section in the response")
}
if answer, ok := resp.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(answer.A) {
t.Errorf("wrong IP in the response: %v", answer.A)
}
}
log.Printf("response: %v", resp)
}