Merge pull request #93 in DNS/adguard-dns from fix/414 to master

* commit '914eb612cd0da015b98c151b7ac603fb4126a2c3':
  Add bootstrap DNS to readme
  Fix review comments
  Close test upstream
  Added bootstrap DNS to the config file DNS healthcheck now uses the upstream package methods
  goimports files
  Added CoreDNS plugin setup and replaced forward
  Added factory method for creating DNS upstreams
  Added health-check method
  Added persistent connections cache
  Upstream plugin prototype
This commit is contained in:
Andrey Meshkov 2018-11-06 12:27:08 +03:00
commit 4a357f1345
13 changed files with 955 additions and 114 deletions

View File

@ -106,8 +106,10 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib
* `parental_enabled` — Parental control-based DNS requests filtering
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes)
* `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname
* `upstream_dns` — List of upstream DNS servers
* `filters` — List of filters, each filter has the following values:
* `ID` - filter ID (must be unique)
* `url` — URL pointing to the filter contents (filtering rules)
* `enabled` — Current filter's status (enabled/disabled)
* `user_rules` — User-specified filtering rules

View File

@ -70,6 +70,7 @@ type coreDNSConfig struct {
Pprof string `yaml:"-"`
Cache string `yaml:"-"`
Prometheus string `yaml:"-"`
BootstrapDNS string `yaml:"bootstrap_dns"`
UpstreamDNS []string `yaml:"upstream_dns"`
}
@ -100,6 +101,7 @@ var config = configuration{
SafeBrowsingEnabled: false,
BlockedResponseTTL: 10, // in seconds
QueryLogEnabled: true,
BootstrapDNS: "8.8.8.8:53",
UpstreamDNS: defaultDNS,
Cache: "cache",
Prometheus: "prometheus :9153",
@ -253,7 +255,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} {
hosts {
fallthrough
}
{{if .UpstreamDNS}}forward . {{range .UpstreamDNS}}{{.}} {{end}}{{end}}
{{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}}
{{.Cache}}
{{.Prometheus}}
}

View File

@ -6,7 +6,6 @@ import (
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"path/filepath"
@ -15,8 +14,9 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/upstream"
corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
"github.com/miekg/dns"
"gopkg.in/asaskevich/govalidator.v4"
)
@ -81,6 +81,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
"protection_enabled": config.CoreDNS.ProtectionEnabled,
"querylog_enabled": config.CoreDNS.QueryLogEnabled,
"running": isRunning(),
"bootstrap_dns": config.CoreDNS.BootstrapDNS,
"upstream_dns": config.CoreDNS.UpstreamDNS,
"version": VersionString,
}
@ -134,17 +135,14 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errortext := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusBadRequest)
errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
// if empty body -- user is asking for default servers
hosts, err := sanitiseDNSServers(string(body))
if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
return
}
hosts := strings.Fields(string(body))
if len(hosts) == 0 {
config.CoreDNS.UpstreamDNS = defaultDNS
} else {
@ -153,34 +151,34 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
err = writeAllConfigs()
if err != nil {
errortext := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
tellCoreDNSToReload()
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errortext := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errortext)
http.Error(w, errortext, 400)
errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errorText)
http.Error(w, errorText, 400)
return
}
hosts := strings.Fields(string(body))
if len(hosts) == 0 {
errortext := fmt.Sprintf("No servers specified")
log.Println(errortext)
http.Error(w, errortext, http.StatusBadRequest)
errorText := fmt.Sprintf("No servers specified")
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
@ -198,120 +196,43 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
jsonVal, err := json.Marshal(result)
if err != nil {
errortext := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Unable to marshal status json: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
errortext := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errortext)
http.Error(w, errortext, http.StatusInternalServerError)
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
func checkDNS(input string) error {
input, err := sanitizeDNSServer(input)
u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS)
if err != nil {
return err
}
defer u.Close()
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
alive, err := upstream.IsAlive(u)
prefix, host := splitDNSServerPrefixServer(input)
c := dns.Client{
Timeout: time.Minute,
}
switch prefix {
case "tls://":
c.Net = "tcp-tls"
}
resp, rtt, err := c.Exchange(&req, host)
if err != nil {
return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err)
}
trace("exchange with %s took %v", input, rtt)
if len(resp.Answer) != 1 {
return fmt.Errorf("DNS server %s returned wrong answer", input)
}
if t, ok := resp.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
}
if !alive {
return fmt.Errorf("DNS server has not passed the healthcheck: %s", input)
}
return nil
}
func sanitiseDNSServers(input string) ([]string, error) {
fields := strings.Fields(input)
hosts := make([]string, 0)
for _, field := range fields {
sanitized, err := sanitizeDNSServer(field)
if err != nil {
return hosts, err
}
hosts = append(hosts, sanitized)
}
return hosts, nil
}
func getDNSServerPrefix(input string) string {
prefix := ""
switch {
case strings.HasPrefix(input, "dns://"):
prefix = "dns://"
case strings.HasPrefix(input, "tls://"):
prefix = "tls://"
}
return prefix
}
func splitDNSServerPrefixServer(input string) (string, string) {
prefix := getDNSServerPrefix(input)
host := strings.TrimPrefix(input, prefix)
return prefix, host
}
func sanitizeDNSServer(input string) (string, error) {
prefix, host := splitDNSServerPrefixServer(input)
host = appendPortIfMissing(prefix, host)
{
h, _, err := net.SplitHostPort(host)
if err != nil {
return "", err
}
ip := net.ParseIP(h)
if ip == nil {
return "", fmt.Errorf("invalid DNS server field: %s", h)
}
}
return prefix + host, nil
}
func appendPortIfMissing(prefix, input string) string {
port := "53"
switch prefix {
case "tls://":
port = "853"
}
_, _, err := net.SplitHostPort(input)
if err == nil {
return input
}
return net.JoinHostPort(input, port)
}
//noinspection GoUnusedParameter
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
now := time.Now()

View File

@ -8,6 +8,7 @@ import (
"sync" // Include all plugins.
_ "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
_ "github.com/AdguardTeam/AdGuardHome/upstream"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/coremain"
_ "github.com/coredns/coredns/plugin/auto"
@ -79,6 +80,7 @@ var directives = []string{
"loop",
"forward",
"proxy",
"upstream",
"erratic",
"whoami",
"on",

View File

@ -41,6 +41,7 @@ paths:
protection_enabled: true
querylog_enabled: true
running: true
bootstrap_dns: 8.8.8.8:53
upstream_dns:
- 1.1.1.1
- 1.0.0.1

109
upstream/dns_upstream.go Normal file
View File

@ -0,0 +1,109 @@
package upstream
import (
"crypto/tls"
"time"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// DnsUpstream is a very simple upstream implementation for plain DNS
type DnsUpstream struct {
endpoint string // IP:port
timeout time.Duration // Max read and write timeout
proto string // Protocol (tcp, tcp-tls, or udp)
transport *Transport // Persistent connections cache
}
// NewDnsUpstream creates a new DNS upstream
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
u := &DnsUpstream{
endpoint: endpoint,
timeout: defaultTimeout,
proto: proto,
}
var tlsConfig *tls.Config
if proto == "tcp-tls" {
tlsConfig = new(tls.Config)
tlsConfig.ServerName = tlsServerName
}
// Initialize the connections cache
u.transport = NewTransport(endpoint)
u.transport.tlsConfig = tlsConfig
u.transport.Start()
return u, nil
}
// Exchange provides an implementation for the Upstream interface
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
resp, err := u.exchange(u.proto, query)
// Retry over TCP if response is truncated
if err == dns.ErrTruncated && u.proto == "udp" {
resp, err = u.exchange("tcp", query)
} else if err == dns.ErrTruncated && resp != nil {
// Reassemble something to be sent to client
m := new(dns.Msg)
m.SetReply(query)
m.Truncated = true
m.Authoritative = true
m.Rcode = dns.RcodeSuccess
return m, nil
}
if err != nil {
resp = &dns.Msg{}
resp.SetRcode(resp, dns.RcodeServerFailure)
}
return resp, err
}
// Clear resources
func (u *DnsUpstream) Close() error {
// Close active connections
u.transport.Stop()
return nil
}
// Performs a synchronous query. It sends the message m via the conn
// c and waits for a reply. The conn c is not closed.
func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) {
// Establish a connection if needed (or reuse cached)
conn, err := u.transport.Dial(proto)
if err != nil {
return nil, err
}
// Write the request with a timeout
conn.SetWriteDeadline(time.Now().Add(u.timeout))
if err = conn.WriteMsg(query); err != nil {
conn.Close() // Not giving it back
return nil, err
}
// Write response with a timeout
conn.SetReadDeadline(time.Now().Add(u.timeout))
r, err = conn.ReadMsg()
if err != nil {
conn.Close() // Not giving it back
} else if err == nil && r.Id != query.Id {
err = dns.ErrId
conn.Close() // Not giving it back
}
if err == nil {
// Return it back to the connections cache if there were no errors
u.transport.Yield(conn)
}
return r, err
}

101
upstream/helpers.go Normal file
View File

@ -0,0 +1,101 @@
package upstream
import (
"net"
"strings"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// Detects the upstream type from the specified url and creates a proper Upstream object
func NewUpstream(url string, bootstrap string) (Upstream, error) {
proto := "udp"
prefix := ""
switch {
case strings.HasPrefix(url, "tcp://"):
proto = "tcp"
prefix = "tcp://"
case strings.HasPrefix(url, "tls://"):
proto = "tcp-tls"
prefix = "tls://"
case strings.HasPrefix(url, "https://"):
return NewHttpsUpstream(url, bootstrap)
}
hostname := strings.TrimPrefix(url, prefix)
host, port, err := net.SplitHostPort(hostname)
if err != nil {
// Set port depending on the protocol
switch proto {
case "udp":
port = "53"
case "tcp":
port = "53"
case "tcp-tls":
port = "853"
}
// Set host = hostname
host = hostname
}
// Try to resolve the host address (or check if it's an IP address)
bootstrapResolver := CreateResolver(bootstrap)
ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host)
if err != nil || len(ips) == 0 {
return nil, err
}
addr := ips[0].String()
endpoint := net.JoinHostPort(addr, port)
tlsServerName := ""
if proto == "tcp-tls" && host != addr {
// Check if we need to specify TLS server name
tlsServerName = host
}
return NewDnsUpstream(endpoint, proto, tlsServerName)
}
func CreateResolver(bootstrap string) *net.Resolver {
bootstrapResolver := net.DefaultResolver
if bootstrap != "" {
bootstrapResolver = &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, network, bootstrap)
},
}
}
return bootstrapResolver
}
// Performs a simple health-check of the specified upstream
func IsAlive(u Upstream) (bool, error) {
// Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere
ping := new(dns.Msg)
ping.SetQuestion("ipv4only.arpa.", dns.TypeA)
resp, err := u.Exchange(context.Background(), ping)
// If we got a header, we're alright, basically only care about I/O errors 'n stuff.
if err != nil && resp != nil {
// Silly check, something sane came back.
if resp.Rcode != dns.RcodeServerFailure {
err = nil
}
}
return err == nil, err
}

128
upstream/https_upstream.go Normal file
View File

@ -0,0 +1,128 @@
package upstream
import (
"bytes"
"crypto/tls"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"net/url"
"time"
"github.com/miekg/dns"
"github.com/pkg/errors"
"golang.org/x/net/context"
"golang.org/x/net/http2"
)
const (
dnsMessageContentType = "application/dns-message"
defaultKeepAlive = 30 * time.Second
)
// HttpsUpstream is the upstream implementation for DNS-over-HTTPS
type HttpsUpstream struct {
client *http.Client
endpoint *url.URL
}
// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url
func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) {
u, err := url.Parse(endpoint)
if err != nil {
return nil, err
}
// Initialize bootstrap resolver
bootstrapResolver := CreateResolver(bootstrap)
dialer := &net.Dialer{
Timeout: defaultTimeout,
KeepAlive: defaultKeepAlive,
DualStack: true,
Resolver: bootstrapResolver,
}
// Update TLS and HTTP client configuration
tlsConfig := &tls.Config{ServerName: u.Hostname()}
transport := &http.Transport{
TLSClientConfig: tlsConfig,
DisableCompression: true,
MaxIdleConns: 1,
DialContext: dialer.DialContext,
}
http2.ConfigureTransport(transport)
client := &http.Client{
Timeout: defaultTimeout,
Transport: transport,
}
return &HttpsUpstream{client: client, endpoint: u}, nil
}
// Exchange provides an implementation for the Upstream interface
func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
queryBuf, err := query.Pack()
if err != nil {
return nil, errors.Wrap(err, "failed to pack DNS query")
}
// No content negotiation for now, use DNS wire format
buf, backendErr := u.exchangeWireformat(queryBuf)
if backendErr == nil {
response := &dns.Msg{}
if err := response.Unpack(buf); err != nil {
return nil, errors.Wrap(err, "failed to unpack DNS response from body")
}
response.Id = query.Id
return response, nil
}
log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr)
return nil, backendErr
}
// Perform message exchange with the default UDP wireformat defined in current draft
// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10
func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) {
req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg))
if err != nil {
return nil, errors.Wrap(err, "failed to create an HTTPS request")
}
req.Header.Add("Content-Type", dnsMessageContentType)
req.Header.Add("Accept", dnsMessageContentType)
req.Host = u.endpoint.Hostname()
resp, err := u.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "failed to perform an HTTPS request")
}
// Check response status code
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("returned status code %d", resp.StatusCode)
}
contentType := resp.Header.Get("Content-Type")
if contentType != dnsMessageContentType {
return nil, fmt.Errorf("return wrong content type %s", contentType)
}
// Read application/dns-message response from the body
buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read the response body")
}
return buf, nil
}
// Clear resources
func (u *HttpsUpstream) Close() error {
return nil
}

210
upstream/persistent.go Normal file
View File

@ -0,0 +1,210 @@
package upstream
import (
"crypto/tls"
"net"
"sort"
"sync/atomic"
"time"
"github.com/miekg/dns"
)
// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin
const (
defaultExpire = 10 * time.Second
minDialTimeout = 100 * time.Millisecond
maxDialTimeout = 30 * time.Second
defaultDialTimeout = 30 * time.Second
cumulativeAvgWeight = 4
)
// a persistConn hold the dns.Conn and the last used time.
type persistConn struct {
c *dns.Conn
used time.Time
}
// Transport hold the persistent cache.
type Transport struct {
avgDialTime int64 // kind of average time of dial time
conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls.
expire time.Duration // After this duration a connection is expired.
addr string
tlsConfig *tls.Config
dial chan string
yield chan *dns.Conn
ret chan *dns.Conn
stop chan bool
}
// Dial dials the address configured in transport, potentially reusing a connection or creating a new one.
func (t *Transport) Dial(proto string) (*dns.Conn, error) {
// If tls has been configured; use it.
if t.tlsConfig != nil {
proto = "tcp-tls"
}
t.dial <- proto
c := <-t.ret
if c != nil {
return c, nil
}
reqTime := time.Now()
timeout := t.dialTimeout()
if proto == "tcp-tls" {
conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout)
t.updateDialTimeout(time.Since(reqTime))
return conn, err
}
conn, err := dns.DialTimeout(proto, t.addr, timeout)
t.updateDialTimeout(time.Since(reqTime))
return conn, err
}
// Yield return the connection to transport for reuse.
func (t *Transport) Yield(c *dns.Conn) { t.yield <- c }
// Start starts the transport's connection manager.
func (t *Transport) Start() { go t.connManager() }
// Stop stops the transport's connection manager.
func (t *Transport) Stop() { close(t.stop) }
// SetExpire sets the connection expire time in transport.
func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire }
// SetTLSConfig sets the TLS config in transport.
func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg }
func NewTransport(addr string) *Transport {
t := &Transport{
avgDialTime: int64(defaultDialTimeout / 2),
conns: make(map[string][]*persistConn),
expire: defaultExpire,
addr: addr,
dial: make(chan string),
yield: make(chan *dns.Conn),
ret: make(chan *dns.Conn),
stop: make(chan bool),
}
return t
}
func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) {
dt := time.Duration(atomic.LoadInt64(currentAvg))
atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight)
}
func (t *Transport) dialTimeout() time.Duration {
return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout)
}
func (t *Transport) updateDialTimeout(newDialTime time.Duration) {
averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight)
}
// limitTimeout is a utility function to auto-tune timeout values
// average observed time is moved towards the last observed delay moderated by a weight
// next timeout to use will be the double of the computed average, limited by min and max frame.
func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration {
rt := time.Duration(atomic.LoadInt64(currentAvg))
if rt < minValue {
return minValue
}
if rt < maxValue/2 {
return 2 * rt
}
return maxValue
}
// connManagers manages the persistent connection cache for UDP and TCP.
func (t *Transport) connManager() {
ticker := time.NewTicker(t.expire)
Wait:
for {
select {
case proto := <-t.dial:
// take the last used conn - complexity O(1)
if stack := t.conns[proto]; len(stack) > 0 {
pc := stack[len(stack)-1]
if time.Since(pc.used) < t.expire {
// Found one, remove from pool and return this conn.
t.conns[proto] = stack[:len(stack)-1]
t.ret <- pc.c
continue Wait
}
// clear entire cache if the last conn is expired
t.conns[proto] = nil
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack)
}
t.ret <- nil
case conn := <-t.yield:
// no proto here, infer from config and conn
if _, ok := conn.Conn.(*net.UDPConn); ok {
t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()})
continue Wait
}
if t.tlsConfig == nil {
t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()})
continue Wait
}
t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()})
case <-ticker.C:
t.cleanup(false)
case <-t.stop:
t.cleanup(true)
close(t.ret)
return
}
}
}
// closeConns closes connections.
func closeConns(conns []*persistConn) {
for _, pc := range conns {
pc.c.Close()
}
}
// cleanup removes connections from cache.
func (t *Transport) cleanup(all bool) {
staleTime := time.Now().Add(-t.expire)
for proto, stack := range t.conns {
if len(stack) == 0 {
continue
}
if all {
t.conns[proto] = nil
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack)
continue
}
if stack[0].used.After(staleTime) {
continue
}
// connections in stack are sorted by "used"
good := sort.Search(len(stack), func(i int) bool {
return stack[i].used.After(staleTime)
})
t.conns[proto] = stack[good:]
// now, the connections being passed to closeConns() are not reachable from
// transport methods anymore. So, it's safe to close them in a separate goroutine
go closeConns(stack[:good])
}
}

84
upstream/setup.go Normal file
View File

@ -0,0 +1,84 @@
package upstream
import (
"log"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/mholt/caddy"
)
func init() {
caddy.RegisterPlugin("upstream", caddy.Plugin{
ServerType: "dns",
Action: setup,
})
}
// Read the configuration and initialize upstreams
func setup(c *caddy.Controller) error {
p, err := setupPlugin(c)
if err != nil {
return err
}
config := dnsserver.GetConfig(c)
config.AddPlugin(func(next plugin.Handler) plugin.Handler {
p.Next = next
return p
})
c.OnShutdown(p.onShutdown)
return nil
}
// Read the configuration
func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) {
p := New()
log.Println("Initializing the Upstream plugin")
bootstrap := ""
upstreamUrls := []string{}
for c.Next() {
args := c.RemainingArgs()
if len(args) > 0 {
upstreamUrls = append(upstreamUrls, args...)
}
for c.NextBlock() {
switch c.Val() {
case "bootstrap":
if !c.NextArg() {
return nil, c.ArgErr()
}
bootstrap = c.Val()
}
}
}
for _, url := range upstreamUrls {
u, err := NewUpstream(url, bootstrap)
if err != nil {
log.Printf("Cannot initialize upstream %s", url)
return nil, err
}
p.Upstreams = append(p.Upstreams, u)
}
return p, nil
}
func (p *UpstreamPlugin) onShutdown() error {
for i := range p.Upstreams {
u := p.Upstreams[i]
err := u.Close()
if err != nil {
log.Printf("Error while closing the upstream: %s", err)
}
}
return nil
}

30
upstream/setup_test.go Normal file
View File

@ -0,0 +1,30 @@
package upstream
import (
"testing"
"github.com/mholt/caddy"
)
func TestSetup(t *testing.T) {
var tests = []struct {
config string
}{
{`upstream 8.8.8.8`},
{`upstream 8.8.8.8 {
bootstrap 8.8.8.8:53
}`},
{`upstream tls://1.1.1.1 8.8.8.8 {
bootstrap 1.1.1.1
}`},
}
for _, test := range tests {
c := caddy.NewTestController("dns", test.config)
err := setup(c)
if err != nil {
t.Fatalf("Test failed")
}
}
}

57
upstream/upstream.go Normal file
View File

@ -0,0 +1,57 @@
package upstream
import (
"time"
"github.com/coredns/coredns/plugin"
"github.com/miekg/dns"
"github.com/pkg/errors"
"golang.org/x/net/context"
)
const (
defaultTimeout = 5 * time.Second
)
// Upstream is a simplified interface for proxy destination
type Upstream interface {
Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error)
Close() error
}
// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface
type UpstreamPlugin struct {
Upstreams []Upstream
Next plugin.Handler
}
// Initialize the upstream plugin
func New() *UpstreamPlugin {
p := &UpstreamPlugin{
Upstreams: []Upstream{},
}
return p
}
// ServeDNS implements interface for CoreDNS plugin
func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
var reply *dns.Msg
var backendErr error
for i := range p.Upstreams {
upstream := p.Upstreams[i]
reply, backendErr = upstream.Exchange(ctx, r)
if backendErr == nil {
w.WriteMsg(reply)
return 0, nil
}
}
return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams")
}
// Name implements interface for CoreDNS plugin
func (p *UpstreamPlugin) Name() string {
return "upstream"
}

194
upstream/upstream_test.go Normal file
View File

@ -0,0 +1,194 @@
package upstream
import (
"net"
"testing"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
func TestDnsUpstreamIsAlive(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"8.8.8.8:53", "8.8.8.8:53"},
{"1.1.1.1", ""},
{"tcp://1.1.1.1:53", ""},
{"176.103.130.130:5353", ""},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS upstream")
}
testUpstreamIsAlive(t, u)
}
}
func TestHttpsUpstreamIsAlive(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
{"https://dns.google.com/experimental", "8.8.8.8:53"},
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-HTTPS upstream")
}
testUpstreamIsAlive(t, u)
}
}
func TestDnsOverTlsIsAlive(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"tls://1.1.1.1", ""},
{"tls://9.9.9.9:853", ""},
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-TLS upstream")
}
testUpstreamIsAlive(t, u)
}
}
func TestDnsUpstream(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"8.8.8.8:53", "8.8.8.8:53"},
{"1.1.1.1", ""},
{"tcp://1.1.1.1:53", ""},
{"176.103.130.130:5353", ""},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS upstream")
}
testUpstream(t, u)
}
}
func TestHttpsUpstream(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"},
{"https://dns.google.com/experimental", "8.8.8.8:53"},
{"https://doh.cleanbrowsing.org/doh/security-filter/", ""},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-HTTPS upstream")
}
testUpstream(t, u)
}
}
func TestDnsOverTlsUpstream(t *testing.T) {
var tests = []struct {
url string
bootstrap string
}{
{"tls://1.1.1.1", ""},
{"tls://9.9.9.9:853", ""},
{"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"},
{"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"},
}
for _, test := range tests {
u, err := NewUpstream(test.url, test.bootstrap)
if err != nil {
t.Errorf("cannot create a DNS-over-TLS upstream")
}
testUpstream(t, u)
}
}
func testUpstreamIsAlive(t *testing.T, u Upstream) {
alive, err := IsAlive(u)
if !alive || err != nil {
t.Errorf("Upstream is not alive")
}
u.Close()
}
func testUpstream(t *testing.T, u Upstream) {
var tests = []struct {
name string
expected net.IP
}{
{"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)},
{"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)},
}
for _, test := range tests {
req := dns.Msg{}
req.Id = dns.Id()
req.RecursionDesired = true
req.Question = []dns.Question{
{Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET},
}
resp, err := u.Exchange(context.Background(), &req)
if err != nil {
t.Errorf("error while making an upstream request: %s", err)
}
if len(resp.Answer) != 1 {
t.Errorf("no answer section in the response")
}
if answer, ok := resp.Answer[0].(*dns.A); ok {
if !test.expected.Equal(answer.A) {
t.Errorf("wrong IP in the response: %v", answer.A)
}
}
}
err := u.Close()
if err != nil {
t.Errorf("Error while closing the upstream: %s", err)
}
}