Upstream plugin prototype
This commit is contained in:
parent
19e30dbccc
commit
484c0ceaff
42
control.go
42
control.go
|
@ -134,9 +134,9 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
|
||||||
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
errorText := fmt.Sprintf("Failed to read request body: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusBadRequest)
|
http.Error(w, errorText, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// if empty body -- user is asking for default servers
|
// if empty body -- user is asking for default servers
|
||||||
|
@ -153,34 +153,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
err = writeAllConfigs()
|
err = writeAllConfigs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Couldn't write config file: %s", err)
|
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusInternalServerError)
|
http.Error(w, errorText, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tellCoreDNSToReload()
|
tellCoreDNSToReload()
|
||||||
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
|
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusInternalServerError)
|
http.Error(w, errorText, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
body, err := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
errorText := fmt.Sprintf("Failed to read request body: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, 400)
|
http.Error(w, errorText, 400)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
hosts := strings.Fields(string(body))
|
hosts := strings.Fields(string(body))
|
||||||
|
|
||||||
if len(hosts) == 0 {
|
if len(hosts) == 0 {
|
||||||
errortext := fmt.Sprintf("No servers specified")
|
errorText := fmt.Sprintf("No servers specified")
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusBadRequest)
|
http.Error(w, errorText, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,18 +198,18 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
jsonVal, err := json.Marshal(result)
|
jsonVal, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
|
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusInternalServerError)
|
http.Error(w, errorText, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
_, err = w.Write(jsonVal)
|
_, err = w.Write(jsonVal)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||||
log.Println(errortext)
|
log.Println(errorText)
|
||||||
http.Error(w, errortext, http.StatusInternalServerError)
|
http.Error(w, errorText, http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DnsUpstream is a very simple upstream implementation for plain DNS
|
||||||
|
type DnsUpstream struct {
|
||||||
|
nameServer string // IP:port
|
||||||
|
timeout time.Duration // Max read and write timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDnsUpstream creates a new plain-DNS upstream
|
||||||
|
func NewDnsUpstream(nameServer string) (Upstream, error) {
|
||||||
|
return &DnsUpstream{nameServer: nameServer, timeout: defaultTimeout}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange provides an implementation for the Upstream interface
|
||||||
|
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||||
|
|
||||||
|
dnsClient := &dns.Client{
|
||||||
|
ReadTimeout: u.timeout,
|
||||||
|
WriteTimeout: u.timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, _, err := dnsClient.Exchange(query, u.nameServer)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
resp = &dns.Msg{}
|
||||||
|
resp.SetRcode(resp, dns.RcodeServerFailure)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, err
|
||||||
|
}
|
|
@ -0,0 +1,109 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dnsMessageContentType = "application/dns-message"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
|
||||||
|
type HttpsUpstream struct {
|
||||||
|
client *http.Client
|
||||||
|
endpoint *url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from hostname
|
||||||
|
func NewHttpsUpstream(endpoint string) (Upstream, error) {
|
||||||
|
u, err := url.Parse(endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update TLS and HTTP client configuration
|
||||||
|
tlsConfig := &tls.Config{ServerName: u.Hostname()}
|
||||||
|
transport := &http.Transport{
|
||||||
|
TLSClientConfig: tlsConfig,
|
||||||
|
DisableCompression: true,
|
||||||
|
MaxIdleConns: 1,
|
||||||
|
}
|
||||||
|
http2.ConfigureTransport(transport)
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: defaultTimeout,
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &HttpsUpstream{client: client, endpoint: u}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange provides an implementation for the Upstream interface
|
||||||
|
func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||||
|
queryBuf, err := query.Pack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to pack DNS query")
|
||||||
|
}
|
||||||
|
|
||||||
|
// No content negotiation for now, use DNS wire format
|
||||||
|
buf, backendErr := u.exchangeWireformat(queryBuf)
|
||||||
|
if backendErr == nil {
|
||||||
|
response := &dns.Msg{}
|
||||||
|
if err := response.Unpack(buf); err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to unpack DNS response from body")
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Id = query.Id
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr)
|
||||||
|
return nil, backendErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform message exchange with the default UDP wireformat defined in current draft
|
||||||
|
// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10
|
||||||
|
func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
|
||||||
|
req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to create an HTTPS request")
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Add("Content-Type", dnsMessageContentType)
|
||||||
|
req.Header.Add("Accept", dnsMessageContentType)
|
||||||
|
req.Host = u.endpoint.Hostname()
|
||||||
|
|
||||||
|
resp, err := u.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to perform an HTTPS request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check response status code
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentType := resp.Header.Get("Content-Type")
|
||||||
|
if contentType != dnsMessageContentType {
|
||||||
|
return nil, fmt.Errorf("return wrong content type %s", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read application/dns-message response from the body
|
||||||
|
buf, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "failed to read the response body")
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf, nil
|
||||||
|
}
|
|
@ -0,0 +1,47 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: Use persistent connection here
|
||||||
|
|
||||||
|
// DnsOverTlsUpstream is the upstream implementation for plain DNS-over-TLS
|
||||||
|
type DnsOverTlsUpstream struct {
|
||||||
|
endpoint string
|
||||||
|
tlsServerName string
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHttpsUpstream creates a new DNS-over-TLS upstream from the endpoint address and TLS server name
|
||||||
|
func NewDnsOverTlsUpstream(endpoint string, tlsServerName string) (Upstream, error) {
|
||||||
|
return &DnsOverTlsUpstream{
|
||||||
|
endpoint: endpoint,
|
||||||
|
tlsServerName: tlsServerName,
|
||||||
|
timeout: defaultTimeout,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange provides an implementation for the Upstream interface
|
||||||
|
func (u *DnsOverTlsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||||
|
|
||||||
|
dnsClient := &dns.Client{
|
||||||
|
Net: "tcp-tls",
|
||||||
|
ReadTimeout: u.timeout,
|
||||||
|
WriteTimeout: u.timeout,
|
||||||
|
TLSConfig: new(tls.Config),
|
||||||
|
}
|
||||||
|
dnsClient.TLSConfig.ServerName = u.tlsServerName
|
||||||
|
|
||||||
|
resp, _, err := dnsClient.Exchange(query, u.endpoint)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
resp = &dns.Msg{}
|
||||||
|
resp.SetRcode(resp, dns.RcodeServerFailure)
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, err
|
||||||
|
}
|
|
@ -0,0 +1,43 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/coredns/coredns/plugin"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Upstream is a simplified interface for proxy destination
|
||||||
|
type Upstream interface {
|
||||||
|
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
|
||||||
|
type UpstreamPlugin struct {
|
||||||
|
Upstreams []Upstream
|
||||||
|
Next plugin.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeDNS implements interface for CoreDNS plugin
|
||||||
|
func (p UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||||
|
var reply *dns.Msg
|
||||||
|
var backendErr error
|
||||||
|
|
||||||
|
for _, upstream := range p.Upstreams {
|
||||||
|
reply, backendErr = upstream.Exchange(ctx, r)
|
||||||
|
if backendErr == nil {
|
||||||
|
w.WriteMsg(reply)
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name implements interface for CoreDNS plugin
|
||||||
|
func (p UpstreamPlugin) Name() string { return "upstream" }
|
|
@ -0,0 +1,86 @@
|
||||||
|
package upstream
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDnsUpstream(t *testing.T) {
|
||||||
|
|
||||||
|
u, err := NewDnsUpstream("8.8.8.8:53")
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot create a DNS upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
testUpstream(t, u)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHttpsUpstream(t *testing.T) {
|
||||||
|
|
||||||
|
testCases := []string{
|
||||||
|
"https://cloudflare-dns.com/dns-query",
|
||||||
|
"https://dns.google.com/experimental",
|
||||||
|
"https://doh.cleanbrowsing.org/doh/security-filter/",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, url := range testCases {
|
||||||
|
u, err := NewHttpsUpstream(url)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
testUpstream(t, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDnsOverTlsUpstream(t *testing.T) {
|
||||||
|
|
||||||
|
var tests = []struct {
|
||||||
|
endpoint string
|
||||||
|
tlsServerName string
|
||||||
|
}{
|
||||||
|
{"1.1.1.1:853", ""},
|
||||||
|
{"8.8.8.8:853", ""},
|
||||||
|
{"185.228.168.10:853", "security-filter-dns.cleanbrowsing.org"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
u, err := NewDnsOverTlsUpstream(test.endpoint, test.tlsServerName)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||||
|
}
|
||||||
|
|
||||||
|
testUpstream(t, u)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testUpstream(t *testing.T, u Upstream) {
|
||||||
|
req := dns.Msg{}
|
||||||
|
req.Id = dns.Id()
|
||||||
|
req.RecursionDesired = true
|
||||||
|
req.Question = []dns.Question{
|
||||||
|
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := u.Exchange(nil, &req)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error while making an upstream request: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Answer) != 1 {
|
||||||
|
t.Errorf("no answer section in the response")
|
||||||
|
}
|
||||||
|
if answer, ok := resp.Answer[0].(*dns.A); ok {
|
||||||
|
if !net.IPv4(8, 8, 8, 8).Equal(answer.A) {
|
||||||
|
t.Errorf("wrong IP in the response: %v", answer.A)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("response: %v", resp)
|
||||||
|
}
|
Loading…
Reference in New Issue