Added bootstrap DNS to the config file
DNS healthcheck now uses the upstream package methods
This commit is contained in:
parent
7f018234f6
commit
451922b858
|
@ -70,6 +70,7 @@ type coreDNSConfig struct {
|
|||
Pprof string `yaml:"-"`
|
||||
Cache string `yaml:"-"`
|
||||
Prometheus string `yaml:"-"`
|
||||
BootstrapDNS string `yaml:"bootstrap_dns"`
|
||||
UpstreamDNS []string `yaml:"upstream_dns"`
|
||||
}
|
||||
|
||||
|
@ -100,6 +101,7 @@ var config = configuration{
|
|||
SafeBrowsingEnabled: false,
|
||||
BlockedResponseTTL: 10, // in seconds
|
||||
QueryLogEnabled: true,
|
||||
BootstrapDNS: "8.8.8.8:53",
|
||||
UpstreamDNS: defaultDNS,
|
||||
Cache: "cache",
|
||||
Prometheus: "prometheus :9153",
|
||||
|
@ -253,7 +255,7 @@ const coreDNSConfigTemplate = `.:{{.Port}} {
|
|||
hosts {
|
||||
fallthrough
|
||||
}
|
||||
{{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap 8.8.8.8:53 }{{end}}
|
||||
{{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}}
|
||||
{{.Cache}}
|
||||
{{.Prometheus}}
|
||||
}
|
||||
|
|
104
control.go
104
control.go
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -15,8 +14,9 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/upstream"
|
||||
|
||||
corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin"
|
||||
"github.com/miekg/dns"
|
||||
"gopkg.in/asaskevich/govalidator.v4"
|
||||
)
|
||||
|
||||
|
@ -81,6 +81,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
|||
"protection_enabled": config.CoreDNS.ProtectionEnabled,
|
||||
"querylog_enabled": config.CoreDNS.QueryLogEnabled,
|
||||
"running": isRunning(),
|
||||
"bootstrap_dns": config.CoreDNS.BootstrapDNS,
|
||||
"upstream_dns": config.CoreDNS.UpstreamDNS,
|
||||
"version": VersionString,
|
||||
}
|
||||
|
@ -140,11 +141,8 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
// if empty body -- user is asking for default servers
|
||||
hosts, err := sanitiseDNSServers(string(body))
|
||||
if err != nil {
|
||||
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
|
||||
return
|
||||
}
|
||||
hosts := strings.Fields(string(body))
|
||||
|
||||
if len(hosts) == 0 {
|
||||
config.CoreDNS.UpstreamDNS = defaultDNS
|
||||
} else {
|
||||
|
@ -214,104 +212,26 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
func checkDNS(input string) error {
|
||||
input, err := sanitizeDNSServer(input)
|
||||
|
||||
u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := dns.Msg{}
|
||||
req.Id = dns.Id()
|
||||
req.RecursionDesired = true
|
||||
req.Question = []dns.Question{
|
||||
{Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
||||
}
|
||||
alive, err := upstream.IsAlive(u)
|
||||
|
||||
prefix, host := splitDNSServerPrefixServer(input)
|
||||
|
||||
c := dns.Client{
|
||||
Timeout: time.Minute,
|
||||
}
|
||||
switch prefix {
|
||||
case "tls://":
|
||||
c.Net = "tcp-tls"
|
||||
}
|
||||
|
||||
resp, rtt, err := c.Exchange(&req, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err)
|
||||
}
|
||||
trace("exchange with %s took %v", input, rtt)
|
||||
if len(resp.Answer) != 1 {
|
||||
return fmt.Errorf("DNS server %s returned wrong answer", input)
|
||||
}
|
||||
if t, ok := resp.Answer[0].(*dns.A); ok {
|
||||
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
|
||||
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
|
||||
}
|
||||
|
||||
if !alive {
|
||||
return fmt.Errorf("DNS server has not passed the healthcheck: %s", input)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sanitiseDNSServers(input string) ([]string, error) {
|
||||
fields := strings.Fields(input)
|
||||
hosts := make([]string, 0)
|
||||
for _, field := range fields {
|
||||
sanitized, err := sanitizeDNSServer(field)
|
||||
if err != nil {
|
||||
return hosts, err
|
||||
}
|
||||
hosts = append(hosts, sanitized)
|
||||
}
|
||||
return hosts, nil
|
||||
}
|
||||
|
||||
func getDNSServerPrefix(input string) string {
|
||||
prefix := ""
|
||||
switch {
|
||||
case strings.HasPrefix(input, "dns://"):
|
||||
prefix = "dns://"
|
||||
case strings.HasPrefix(input, "tls://"):
|
||||
prefix = "tls://"
|
||||
}
|
||||
return prefix
|
||||
}
|
||||
|
||||
func splitDNSServerPrefixServer(input string) (string, string) {
|
||||
prefix := getDNSServerPrefix(input)
|
||||
host := strings.TrimPrefix(input, prefix)
|
||||
return prefix, host
|
||||
}
|
||||
|
||||
func sanitizeDNSServer(input string) (string, error) {
|
||||
prefix, host := splitDNSServerPrefixServer(input)
|
||||
host = appendPortIfMissing(prefix, host)
|
||||
{
|
||||
h, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
ip := net.ParseIP(h)
|
||||
if ip == nil {
|
||||
return "", fmt.Errorf("invalid DNS server field: %s", h)
|
||||
}
|
||||
}
|
||||
return prefix + host, nil
|
||||
}
|
||||
|
||||
func appendPortIfMissing(prefix, input string) string {
|
||||
port := "53"
|
||||
switch prefix {
|
||||
case "tls://":
|
||||
port = "853"
|
||||
}
|
||||
_, _, err := net.SplitHostPort(input)
|
||||
if err == nil {
|
||||
return input
|
||||
}
|
||||
return net.JoinHostPort(input, port)
|
||||
}
|
||||
|
||||
//noinspection GoUnusedParameter
|
||||
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
now := time.Now()
|
||||
|
|
|
@ -41,6 +41,7 @@ paths:
|
|||
protection_enabled: true
|
||||
querylog_enabled: true
|
||||
running: true
|
||||
bootstrap_dns: 8.8.8.8:53
|
||||
upstream_dns:
|
||||
- 1.1.1.1
|
||||
- 1.0.0.1
|
||||
|
|
|
@ -93,7 +93,7 @@ func IsAlive(u Upstream) (bool, error) {
|
|||
// If we got a header, we're alright, basically only care about I/O errors 'n stuff.
|
||||
if err != nil && resp != nil {
|
||||
// Silly check, something sane came back.
|
||||
if resp.Response || resp.Opcode == dns.OpcodeQuery {
|
||||
if resp.Rcode != dns.RcodeServerFailure {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin
|
||||
|
||||
const (
|
||||
defaultExpire = 10 * time.Second
|
||||
minDialTimeout = 100 * time.Millisecond
|
||||
|
|
Loading…
Reference in New Issue