Merge pull request #93 in DNS/adguard-dns from fix/414 to master
* commit '914eb612cd0da015b98c151b7ac603fb4126a2c3': Add bootstrap DNS to readme Fix review comments Close test upstream Added bootstrap DNS to the config file DNS healthcheck now uses the upstream package methods goimports files Added CoreDNS plugin setup and replaced forward Added factory method for creating DNS upstreams Added health-check method Added persistent connections cache Upstream plugin prototype
This commit is contained in:
commit
4a357f1345
|
@ -106,8 +106,10 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib
|
|||
* `parental_enabled` — Parental control-based DNS requests filtering
|
||||
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17
|
||||
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes)
|
||||
* `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname
|
||||
* `upstream_dns` — List of upstream DNS servers
|
||||
* `filters` — List of filters, each filter has the following values:
|
||||
* `ID` - filter ID (must be unique)
|
||||
* `url` — URL pointing to the filter contents (filtering rules)
|
||||
* `enabled` — Current filter's status (enabled/disabled)
|
||||
* `user_rules` — User-specified filtering rules
|
||||
|
|
|
@ -70,6 +70,7 @@ type coreDNSConfig struct {
|
|||
Pprof string `yaml:"-"`
|
||||
Cache string `yaml:"-"`
|
||||
Prometheus string `yaml:"-"`
|
||||
BootstrapDNS string `yaml:"bootstrap_dns"`
|
||||
UpstreamDNS []string `yaml:"upstream_dns"`
|
||||
}
|
||||
|
||||
|
@ -100,6 +101,7 @@ var config = configuration{
|
|||
SafeBrowsingEnabled: false,
|
||||
BlockedResponseTTL: 10, // in seconds
|
||||
QueryLogEnabled: true,
|
||||
BootstrapDNS: "8.8.8.8:53",
|
||||
UpstreamDNS: defaultDNS,
|
||||
Cache: "cache",
|
||||
Prometheus: "prometheus :9153",
|
||||
|
@ -253,7 +255,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} {
|
|||
hosts {
|
||||
fallthrough
|
||||
}
|
||||
{{if .UpstreamDNS}}forward . {{range .UpstreamDNS}}{{.}} {{end}}{{end}}
|
||||
{{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}}
|
||||
{{.Cache}}
|
||||
{{.Prometheus}}
|
||||
}
|
||||
|
|
147
control.go
147
control.go
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -15,8 +14,9 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/upstream"
|
||||
|
||||
corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
||||
"github.com/miekg/dns"
|
||||
"gopkg.in/asaskevich/govalidator.v4"
|
||||
)
|
||||
|
||||
|
@ -81,6 +81,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
|||
"protection_enabled": config.CoreDNS.ProtectionEnabled,
|
||||
"querylog_enabled": config.CoreDNS.QueryLogEnabled,
|
||||
"running": isRunning(),
|
||||
"bootstrap_dns": config.CoreDNS.BootstrapDNS,
|
||||
"upstream_dns": config.CoreDNS.UpstreamDNS,
|
||||
"version": VersionString,
|
||||
}
|
||||
|
@ -134,17 +135,14 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
|
|||
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusBadRequest)
|
||||
errorText := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// if empty body -- user is asking for default servers
|
||||
hosts, err := sanitiseDNSServers(string(body))
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
|
||||
return
|
||||
}
|
||||
hosts := strings.Fields(string(body))
|
||||
|
||||
if len(hosts) == 0 {
|
||||
config.CoreDNS.UpstreamDNS = defaultDNS
|
||||
} else {
|
||||
|
@ -153,34 +151,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
err = writeAllConfigs()
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write config file: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
tellCoreDNSToReload()
|
||||
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, 400)
|
||||
errorText := fmt.Sprintf("Failed to read request body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, 400)
|
||||
return
|
||||
}
|
||||
hosts := strings.Fields(string(body))
|
||||
|
||||
if len(hosts) == 0 {
|
||||
errortext := fmt.Sprintf("No servers specified")
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusBadRequest)
|
||||
errorText := fmt.Sprintf("No servers specified")
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -198,120 +196,43 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
jsonVal, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonVal)
|
||||
if err != nil {
|
||||
errortext := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errortext)
|
||||
http.Error(w, errortext, http.StatusInternalServerError)
|
||||
errorText := fmt.Sprintf("Couldn't write body: %s", err)
|
||||
log.Println(errorText)
|
||||
http.Error(w, errorText, http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func checkDNS(input string) error {
|
||||
input, err := sanitizeDNSServer(input)
|
||||
|
||||
u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer u.Close()
|
||||
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
alive, err := upstream.IsAlive(u)
|
||||
|
||||
prefix, host := splitDNSServerPrefixServer(input)
|
||||
|
||||
c := dns.Client{
|
||||
Timeout: time.Minute,
|
||||
}
|
||||
switch prefix {
|
||||
case "tls://":
|
||||
c.Net = "tcp-tls"
|
||||
}
|
||||
|
||||
resp, rtt, err := c.Exchange(&req, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err)
|
||||
}
|
||||
trace("exchange with %s took %v", input, rtt)
|
||||
if len(resp.Answer) != 1 {
|
||||
return fmt.Errorf("DNS server %s returned wrong answer", input)
|
||||
}
|
||||
if t, ok := resp.Answer[0].(*dns.A); ok {
|
||||
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
|
||||
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
|
||||
}
|
||||
|
||||
if !alive {
|
||||
return fmt.Errorf("DNS server has not passed the healthcheck: %s", input)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sanitiseDNSServers(input string) ([]string, error) {
|
||||
fields := strings.Fields(input)
|
||||
hosts := make([]string, 0)
|
||||
for _, field := range fields {
|
||||
sanitized, err := sanitizeDNSServer(field)
|
||||
if err != nil {
|
||||
return hosts, err
|
||||
}
|
||||
hosts = append(hosts, sanitized)
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func getDNSServerPrefix(input string) string {
|
||||
prefix := ""
|
||||
switch {
|
||||
case strings.HasPrefix(input, "dns://"):
|
||||
prefix = "dns://"
|
||||
case strings.HasPrefix(input, "tls://"):
|
||||
prefix = "tls://"
|
||||
}
|
||||
return prefix
|
||||
}
|
||||
|
||||
func splitDNSServerPrefixServer(input string) (string, string) {
|
||||
prefix := getDNSServerPrefix(input)
|
||||
host := strings.TrimPrefix(input, prefix)
|
||||
return prefix, host
|
||||
}
|
||||
|
||||
func sanitizeDNSServer(input string) (string, error) {
|
||||
prefix, host := splitDNSServerPrefixServer(input)
|
||||
host = appendPortIfMissing(prefix, host)
|
||||
{
|
||||
h, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ip := net.ParseIP(h)
|
||||
if ip == nil {
|
||||
return "", fmt.Errorf("invalid DNS server field: %s", h)
|
||||
}
|
||||
}
|
||||
return prefix + host, nil
|
||||
}
|
||||
|
||||
func appendPortIfMissing(prefix, input string) string {
|
||||
port := "53"
|
||||
switch prefix {
|
||||
case "tls://":
|
||||
port = "853"
|
||||
}
|
||||
_, _, err := net.SplitHostPort(input)
|
||||
if err == nil {
|
||||
return input
|
||||
}
|
||||
return net.JoinHostPort(input, port)
|
||||
}
|
||||
|
||||
//noinspection GoUnusedParameter
|
||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
now := time.Now()
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"sync" // Include all plugins.
|
||||
|
||||
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
||||
_ "github.com/AdguardTeam/AdGuardHome/upstream"
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/coremain"
|
||||
_ "github.com/coredns/coredns/plugin/auto"
|
||||
|
@ -79,6 +80,7 @@ var directives = []string{
|
|||
"loop",
|
||||
"forward",
|
||||
"proxy",
|
||||
"upstream",
|
||||
"erratic",
|
||||
"whoami",
|
||||
"on",
|
||||
|
|
|
@ -41,6 +41,7 @@ paths:
|
|||
protection_enabled: true
|
||||
querylog_enabled: true
|
||||
running: true
|
||||
bootstrap_dns: 8.8.8.8:53
|
||||
upstream_dns:
|
||||
- 1.1.1.1
|
||||
- 1.0.0.1
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
package upstream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// DnsUpstream is a very simple upstream implementation for plain DNS
|
||||
type DnsUpstream struct {
|
||||
endpoint string // IP:port
|
||||
timeout time.Duration // Max read and write timeout
|
||||
proto string // Protocol (tcp, tcp-tls, or udp)
|
||||
transport *Transport // Persistent connections cache
|
||||
}
|
||||
|
||||
// NewDnsUpstream creates a new DNS upstream
|
||||
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
|
||||
|
||||
u := &DnsUpstream{
|
||||
endpoint: endpoint,
|
||||
timeout: defaultTimeout,
|
||||
proto: proto,
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
|
||||
if proto == "tcp-tls" {
|
||||
tlsConfig = new(tls.Config)
|
||||
tlsConfig.ServerName = tlsServerName
|
||||
}
|
||||
|
||||
// Initialize the connections cache
|
||||
u.transport = NewTransport(endpoint)
|
||||
u.transport.tlsConfig = tlsConfig
|
||||
u.transport.Start()
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// Exchange provides an implementation for the Upstream interface
|
||||
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
|
||||
resp, err := u.exchange(u.proto, query)
|
||||
|
||||
// Retry over TCP if response is truncated
|
||||
if err == dns.ErrTruncated && u.proto == "udp" {
|
||||
resp, err = u.exchange("tcp", query)
|
||||
} else if err == dns.ErrTruncated && resp != nil {
|
||||
// Reassemble something to be sent to client
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(query)
|
||||
m.Truncated = true
|
||||
m.Authoritative = true
|
||||
m.Rcode = dns.RcodeSuccess
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
resp = &dns.Msg{}
|
||||
resp.SetRcode(resp, dns.RcodeServerFailure)
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// Clear resources
|
||||
func (u *DnsUpstream) Close() error {
|
||||
|
||||
// Close active connections
|
||||
u.transport.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Performs a synchronous query. It sends the message m via the conn
|
||||
// c and waits for a reply. The conn c is not closed.
|
||||
func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) {
|
||||
|
||||
// Establish a connection if needed (or reuse cached)
|
||||
conn, err := u.transport.Dial(proto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write the request with a timeout
|
||||
conn.SetWriteDeadline(time.Now().Add(u.timeout))
|
||||
if err = conn.WriteMsg(query); err != nil {
|
||||
conn.Close() // Not giving it back
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write response with a timeout
|
||||
conn.SetReadDeadline(time.Now().Add(u.timeout))
|
||||
r, err = conn.ReadMsg()
|
||||
if err != nil {
|
||||
conn.Close() // Not giving it back
|
||||
} else if err == nil && r.Id != query.Id {
|
||||
err = dns.ErrId
|
||||
conn.Close() // Not giving it back
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// Return it back to the connections cache if there were no errors
|
||||
u.transport.Yield(conn)
|
||||
}
|
||||
return r, err
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
package upstream
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// Detects the upstream type from the specified url and creates a proper Upstream object
|
||||
func NewUpstream(url string, bootstrap string) (Upstream, error) {
|
||||
|
||||
proto := "udp"
|
||||
prefix := ""
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(url, "tcp://"):
|
||||
proto = "tcp"
|
||||
prefix = "tcp://"
|
||||
case strings.HasPrefix(url, "tls://"):
|
||||
proto = "tcp-tls"
|
||||
prefix = "tls://"
|
||||
case strings.HasPrefix(url, "https://"):
|
||||
return NewHttpsUpstream(url, bootstrap)
|
||||
}
|
||||
|
||||
hostname := strings.TrimPrefix(url, prefix)
|
||||
|
||||
host, port, err := net.SplitHostPort(hostname)
|
||||
if err != nil {
|
||||
// Set port depending on the protocol
|
||||
switch proto {
|
||||
case "udp":
|
||||
port = "53"
|
||||
case "tcp":
|
||||
port = "53"
|
||||
case "tcp-tls":
|
||||
port = "853"
|
||||
}
|
||||
|
||||
// Set host = hostname
|
||||
host = hostname
|
||||
}
|
||||
|
||||
// Try to resolve the host address (or check if it's an IP address)
|
||||
bootstrapResolver := CreateResolver(bootstrap)
|
||||
ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host)
|
||||
|
||||
if err != nil || len(ips) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addr := ips[0].String()
|
||||
endpoint := net.JoinHostPort(addr, port)
|
||||
tlsServerName := ""
|
||||
|
||||
if proto == "tcp-tls" && host != addr {
|
||||
// Check if we need to specify TLS server name
|
||||
tlsServerName = host
|
||||
}
|
||||
|
||||
return NewDnsUpstream(endpoint, proto, tlsServerName)
|
||||
}
|
||||
|
||||
func CreateResolver(bootstrap string) *net.Resolver {
|
||||
|
||||
bootstrapResolver := net.DefaultResolver
|
||||
|
||||
if bootstrap != "" {
|
||||
bootstrapResolver = &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, network, bootstrap)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return bootstrapResolver
|
||||
}
|
||||
|
||||
// Performs a simple health-check of the specified upstream
|
||||
func IsAlive(u Upstream) (bool, error) {
|
||||
|
||||
// Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere
|
||||
ping := new(dns.Msg)
|
||||
ping.SetQuestion("ipv4only.arpa.", dns.TypeA)
|
||||
|
||||
resp, err := u.Exchange(context.Background(), ping)
|
||||
|
||||
// If we got a header, we're alright, basically only care about I/O errors 'n stuff.
|
||||
if err != nil && resp != nil {
|
||||
// Silly check, something sane came back.
|
||||
if resp.Rcode != dns.RcodeServerFailure {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return err == nil, err
|
||||
}
|
|
@ -0,0 +1,128 @@
|
|||
package upstream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const (
|
||||
dnsMessageContentType = "application/dns-message"
|
||||
defaultKeepAlive = 30 * time.Second
|
||||
)
|
||||
|
||||
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
|
||||
type HttpsUpstream struct {
|
||||
client *http.Client
|
||||
endpoint *url.URL
|
||||
}
|
||||
|
||||
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url
|
||||
func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
|
||||
u, err := url.Parse(endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize bootstrap resolver
|
||||
bootstrapResolver := CreateResolver(bootstrap)
|
||||
dialer := &net.Dialer{
|
||||
Timeout: defaultTimeout,
|
||||
KeepAlive: defaultKeepAlive,
|
||||
DualStack: true,
|
||||
Resolver: bootstrapResolver,
|
||||
}
|
||||
|
||||
// Update TLS and HTTP client configuration
|
||||
tlsConfig := &tls.Config{ServerName: u.Hostname()}
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
DisableCompression: true,
|
||||
MaxIdleConns: 1,
|
||||
DialContext: dialer.DialContext,
|
||||
}
|
||||
http2.ConfigureTransport(transport)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: defaultTimeout,
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
return &HttpsUpstream{client: client, endpoint: u}, nil
|
||||
}
|
||||
|
||||
// Exchange provides an implementation for the Upstream interface
|
||||
func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
||||
queryBuf, err := query.Pack()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to pack DNS query")
|
||||
}
|
||||
|
||||
// No content negotiation for now, use DNS wire format
|
||||
buf, backendErr := u.exchangeWireformat(queryBuf)
|
||||
if backendErr == nil {
|
||||
response := &dns.Msg{}
|
||||
if err := response.Unpack(buf); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unpack DNS response from body")
|
||||
}
|
||||
|
||||
response.Id = query.Id
|
||||
return response, nil
|
||||
}
|
||||
|
||||
log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr)
|
||||
return nil, backendErr
|
||||
}
|
||||
|
||||
// Perform message exchange with the default UDP wireformat defined in current draft
|
||||
// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10
|
||||
func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
|
||||
req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create an HTTPS request")
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", dnsMessageContentType)
|
||||
req.Header.Add("Accept", dnsMessageContentType)
|
||||
req.Host = u.endpoint.Hostname()
|
||||
|
||||
resp, err := u.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to perform an HTTPS request")
|
||||
}
|
||||
|
||||
// Check response status code
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType != dnsMessageContentType {
|
||||
return nil, fmt.Errorf("return wrong content type %s", contentType)
|
||||
}
|
||||
|
||||
// Read application/dns-message response from the body
|
||||
buf, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read the response body")
|
||||
}
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// Clear resources
|
||||
func (u *HttpsUpstream) Close() error {
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,210 @@
|
|||
package upstream
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sort"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin
|
||||
|
||||
const (
|
||||
defaultExpire = 10 * time.Second
|
||||
minDialTimeout = 100 * time.Millisecond
|
||||
maxDialTimeout = 30 * time.Second
|
||||
defaultDialTimeout = 30 * time.Second
|
||||
cumulativeAvgWeight = 4
|
||||
)
|
||||
|
||||
// a persistConn hold the dns.Conn and the last used time.
|
||||
type persistConn struct {
|
||||
c *dns.Conn
|
||||
used time.Time
|
||||
}
|
||||
|
||||
// Transport hold the persistent cache.
|
||||
type Transport struct {
|
||||
avgDialTime int64 // kind of average time of dial time
|
||||
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
|
||||
expire time.Duration // After this duration a connection is expired.
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
dial chan string
|
||||
yield chan *dns.Conn
|
||||
ret chan *dns.Conn
|
||||
stop chan bool
|
||||
}
|
||||
|
||||
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
|
||||
func (t *Transport) Dial(proto string) (*dns.Conn, error) {
|
||||
// If tls has been configured; use it.
|
||||
if t.tlsConfig != nil {
|
||||
proto = "tcp-tls"
|
||||
}
|
||||
|
||||
t.dial <- proto
|
||||
c := <-t.ret
|
||||
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
reqTime := time.Now()
|
||||
timeout := t.dialTimeout()
|
||||
if proto == "tcp-tls" {
|
||||
conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout)
|
||||
t.updateDialTimeout(time.Since(reqTime))
|
||||
return conn, err
|
||||
}
|
||||
conn, err := dns.DialTimeout(proto, t.addr, timeout)
|
||||
t.updateDialTimeout(time.Since(reqTime))
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// Yield return the connection to transport for reuse.
|
||||
func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
|
||||
|
||||
// Start starts the transport's connection manager.
|
||||
func (t *Transport) Start() { go t.connManager() }
|
||||
|
||||
// Stop stops the transport's connection manager.
|
||||
func (t *Transport) Stop() { close(t.stop) }
|
||||
|
||||
// SetExpire sets the connection expire time in transport.
|
||||
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
|
||||
|
||||
// SetTLSConfig sets the TLS config in transport.
|
||||
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
|
||||
|
||||
func NewTransport(addr string) *Transport {
|
||||
t := &Transport{
|
||||
avgDialTime: int64(defaultDialTimeout / 2),
|
||||
conns: make(map[string][]*persistConn),
|
||||
expire: defaultExpire,
|
||||
addr: addr,
|
||||
dial: make(chan string),
|
||||
yield: make(chan *dns.Conn),
|
||||
ret: make(chan *dns.Conn),
|
||||
stop: make(chan bool),
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
|
||||
dt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
|
||||
}
|
||||
|
||||
func (t *Transport) dialTimeout() time.Duration {
|
||||
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
|
||||
}
|
||||
|
||||
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
|
||||
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
|
||||
}
|
||||
|
||||
// limitTimeout is a utility function to auto-tune timeout values
|
||||
// average observed time is moved towards the last observed delay moderated by a weight
|
||||
// next timeout to use will be the double of the computed average, limited by min and max frame.
|
||||
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
|
||||
rt := time.Duration(atomic.LoadInt64(currentAvg))
|
||||
if rt < minValue {
|
||||
return minValue
|
||||
}
|
||||
if rt < maxValue/2 {
|
||||
return 2 * rt
|
||||
}
|
||||
return maxValue
|
||||
}
|
||||
|
||||
// connManagers manages the persistent connection cache for UDP and TCP.
|
||||
func (t *Transport) connManager() {
|
||||
ticker := time.NewTicker(t.expire)
|
||||
Wait:
|
||||
for {
|
||||
select {
|
||||
case proto := <-t.dial:
|
||||
// take the last used conn - complexity O(1)
|
||||
if stack := t.conns[proto]; len(stack) > 0 {
|
||||
pc := stack[len(stack)-1]
|
||||
if time.Since(pc.used) < t.expire {
|
||||
// Found one, remove from pool and return this conn.
|
||||
t.conns[proto] = stack[:len(stack)-1]
|
||||
t.ret <- pc.c
|
||||
continue Wait
|
||||
}
|
||||
// clear entire cache if the last conn is expired
|
||||
t.conns[proto] = nil
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack)
|
||||
}
|
||||
|
||||
t.ret <- nil
|
||||
|
||||
case conn := <-t.yield:
|
||||
|
||||
// no proto here, infer from config and conn
|
||||
if _, ok := conn.Conn.(*net.UDPConn); ok {
|
||||
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
|
||||
continue Wait
|
||||
}
|
||||
|
||||
if t.tlsConfig == nil {
|
||||
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
|
||||
continue Wait
|
||||
}
|
||||
|
||||
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
|
||||
|
||||
case <-ticker.C:
|
||||
t.cleanup(false)
|
||||
|
||||
case <-t.stop:
|
||||
t.cleanup(true)
|
||||
close(t.ret)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeConns closes connections.
|
||||
func closeConns(conns []*persistConn) {
|
||||
for _, pc := range conns {
|
||||
pc.c.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes connections from cache.
|
||||
func (t *Transport) cleanup(all bool) {
|
||||
staleTime := time.Now().Add(-t.expire)
|
||||
for proto, stack := range t.conns {
|
||||
if len(stack) == 0 {
|
||||
continue
|
||||
}
|
||||
if all {
|
||||
t.conns[proto] = nil
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack)
|
||||
continue
|
||||
}
|
||||
if stack[0].used.After(staleTime) {
|
||||
continue
|
||||
}
|
||||
|
||||
// connections in stack are sorted by "used"
|
||||
good := sort.Search(len(stack), func(i int) bool {
|
||||
return stack[i].used.After(staleTime)
|
||||
})
|
||||
t.conns[proto] = stack[good:]
|
||||
// now, the connections being passed to closeConns() are not reachable from
|
||||
// transport methods anymore. So, it's safe to close them in a separate goroutine
|
||||
go closeConns(stack[:good])
|
||||
}
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
package upstream
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/coredns/coredns/core/dnsserver"
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("upstream", caddy.Plugin{
|
||||
ServerType: "dns",
|
||||
Action: setup,
|
||||
})
|
||||
}
|
||||
|
||||
// Read the configuration and initialize upstreams
|
||||
func setup(c *caddy.Controller) error {
|
||||
|
||||
p, err := setupPlugin(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config := dnsserver.GetConfig(c)
|
||||
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
|
||||
p.Next = next
|
||||
return p
|
||||
})
|
||||
|
||||
c.OnShutdown(p.onShutdown)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read the configuration
|
||||
func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) {
|
||||
|
||||
p := New()
|
||||
|
||||
log.Println("Initializing the Upstream plugin")
|
||||
|
||||
bootstrap := ""
|
||||
upstreamUrls := []string{}
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
if len(args) > 0 {
|
||||
upstreamUrls = append(upstreamUrls, args...)
|
||||
}
|
||||
for c.NextBlock() {
|
||||
switch c.Val() {
|
||||
case "bootstrap":
|
||||
if !c.NextArg() {
|
||||
return nil, c.ArgErr()
|
||||
}
|
||||
bootstrap = c.Val()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, url := range upstreamUrls {
|
||||
u, err := NewUpstream(url, bootstrap)
|
||||
if err != nil {
|
||||
log.Printf("Cannot initialize upstream %s", url)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p.Upstreams = append(p.Upstreams, u)
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (p *UpstreamPlugin) onShutdown() error {
|
||||
for i := range p.Upstreams {
|
||||
|
||||
u := p.Upstreams[i]
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
log.Printf("Error while closing the upstream: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
package upstream
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func TestSetup(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
config string
|
||||
}{
|
||||
{`upstream 8.8.8.8`},
|
||||
{`upstream 8.8.8.8 {
|
||||
bootstrap 8.8.8.8:53
|
||||
}`},
|
||||
{`upstream tls://1.1.1.1 8.8.8.8 {
|
||||
bootstrap 1.1.1.1
|
||||
}`},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
c := caddy.NewTestController("dns", test.config)
|
||||
err := setup(c)
|
||||
if err != nil {
|
||||
t.Fatalf("Test failed")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
package upstream
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coredns/coredns/plugin"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// Upstream is a simplified interface for proxy destination
|
||||
type Upstream interface {
|
||||
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
|
||||
type UpstreamPlugin struct {
|
||||
Upstreams []Upstream
|
||||
Next plugin.Handler
|
||||
}
|
||||
|
||||
// Initialize the upstream plugin
|
||||
func New() *UpstreamPlugin {
|
||||
p := &UpstreamPlugin{
|
||||
Upstreams: []Upstream{},
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// ServeDNS implements interface for CoreDNS plugin
|
||||
func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
|
||||
var reply *dns.Msg
|
||||
var backendErr error
|
||||
|
||||
for i := range p.Upstreams {
|
||||
upstream := p.Upstreams[i]
|
||||
reply, backendErr = upstream.Exchange(ctx, r)
|
||||
if backendErr == nil {
|
||||
w.WriteMsg(reply)
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
|
||||
}
|
||||
|
||||
// Name implements interface for CoreDNS plugin
|
||||
func (p *UpstreamPlugin) Name() string {
|
||||
return "upstream"
|
||||
}
|
|
@ -0,0 +1,194 @@
|
|||
package upstream
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestDnsUpstreamIsAlive(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||
{"1.1.1.1", ""},
|
||||
{"tcp://1.1.1.1:53", ""},
|
||||
{"176.103.130.130:5353", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS upstream")
|
||||
}
|
||||
|
||||
testUpstreamIsAlive(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpsUpstreamIsAlive(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
||||
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
||||
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||
}
|
||||
|
||||
testUpstreamIsAlive(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDnsOverTlsIsAlive(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"tls://1.1.1.1", ""},
|
||||
{"tls://9.9.9.9:853", ""},
|
||||
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||
}
|
||||
|
||||
testUpstreamIsAlive(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDnsUpstream(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"8.8.8.8:53", "8.8.8.8:53"},
|
||||
{"1.1.1.1", ""},
|
||||
{"tcp://1.1.1.1:53", ""},
|
||||
{"176.103.130.130:5353", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHttpsUpstream(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
|
||||
{"https://dns.google.com/experimental", "8.8.8.8:53"},
|
||||
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-HTTPS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDnsOverTlsUpstream(t *testing.T) {
|
||||
|
||||
var tests = []struct {
|
||||
url string
|
||||
bootstrap string
|
||||
}{
|
||||
{"tls://1.1.1.1", ""},
|
||||
{"tls://9.9.9.9:853", ""},
|
||||
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
|
||||
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
u, err := NewUpstream(test.url, test.bootstrap)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("cannot create a DNS-over-TLS upstream")
|
||||
}
|
||||
|
||||
testUpstream(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func testUpstreamIsAlive(t *testing.T, u Upstream) {
|
||||
alive, err := IsAlive(u)
|
||||
if !alive || err != nil {
|
||||
t.Errorf("Upstream is not alive")
|
||||
}
|
||||
|
||||
u.Close()
|
||||
}
|
||||
|
||||
func testUpstream(t *testing.T, u Upstream) {
|
||||
|
||||
var tests = []struct {
|
||||
name string
|
||||
expected net.IP
|
||||
}{
|
||||
{"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)},
|
||||
{"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
|
||||
resp, err := u.Exchange(context.Background(), &req)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error while making an upstream request: %s", err)
|
||||
}
|
||||
|
||||
if len(resp.Answer) != 1 {
|
||||
t.Errorf("no answer section in the response")
|
||||
}
|
||||
if answer, ok := resp.Answer[0].(*dns.A); ok {
|
||||
if !test.expected.Equal(answer.A) {
|
||||
t.Errorf("wrong IP in the response: %v", answer.A)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := u.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Error while closing the upstream: %s", err)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue