Set default servers to tls://1.1.1.1 and tls://1.0.0.1

Also add support for tls:// in webUI API
This commit is contained in:
Eugene Bujak 2018-09-26 17:47:23 +03:00
parent 3afd8fccc7
commit ff86d6b7dc
3 changed files with 88 additions and 39 deletions

View File

@ -56,7 +56,7 @@ type filter struct {
LastUpdated time.Time `json:"last_updated" yaml:"-"`
}
var defaultDNS = []string{"1.1.1.1", "1.0.0.1"}
var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"}
// initialize to default values, will be changed later when reading config or parsing command line
var config = configuration{

View File

@ -506,16 +506,25 @@ func handleStatsTop(w http.ResponseWriter, r *http.Request) {
}
}
func httpError(w http.ResponseWriter, code int, format string, args ...interface{}) {
text := fmt.Sprintf(format, args...)
http.Error(w, text, code)
}
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errortext := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errortext)
http.Error(w, errortext, 400)
http.Error(w, errortext, http.StatusBadRequest)
return
}
// if empty body -- user is asking for default servers
hosts := parseIPsOptionalPort(string(body))
hosts, err := sanitiseDNSServers(string(body))
if err != nil {
httpError(w, http.StatusBadRequest, "Invalid DNS servers were given: %s", err)
return
}
if len(hosts) == 0 {
config.CoreDNS.UpstreamDNS = defaultDNS
} else {
@ -584,18 +593,11 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
}
}
func checkDNS(host string) error {
host = appendPortIfMissing(host)
{
h, _, err := net.SplitHostPort(host)
func checkDNS(input string) error {
input, err := sanitizeDNSServer(input)
if err != nil {
return err
}
ip := net.ParseIP(h)
if ip == nil {
return fmt.Errorf("Invalid DNS server field: %s", h)
}
}
req := dns.Msg{}
req.Id = dns.Id()
@ -603,45 +605,91 @@ func checkDNS(host string) error {
req.Question = []dns.Question{
{"google-public-dns-a.google.com.", dns.TypeA, dns.ClassINET},
}
resp, err := dns.Exchange(&req, host)
if err != nil {
return fmt.Errorf("Couldn't communicate with DNS server %s: %s", host, err)
prefix, host := splitDNSServerPrefixServer(input)
c := dns.Client{
Timeout: time.Minute,
}
switch prefix {
case "tls://":
c.Net = "tcp-tls"
}
resp, rtt, err := c.Exchange(&req, host)
if err != nil {
return fmt.Errorf("Couldn't communicate with DNS server %s: %s", input, err)
}
trace("exchange with %s took %v", input, rtt)
if len(resp.Answer) != 1 {
return fmt.Errorf("DNS server %s returned wrong answer", host)
return fmt.Errorf("DNS server %s returned wrong answer", input)
}
if t, ok := resp.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
return fmt.Errorf("DNS server %s returned wrong answer: %v", host, t.A)
return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A)
}
}
return nil
}
func appendPortIfMissing(input string) string {
func sanitiseDNSServers(input string) ([]string, error) {
fields := strings.Fields(input)
hosts := []string{}
for _, field := range fields {
sanitized, err := sanitizeDNSServer(field)
if err != nil {
return hosts, err
}
hosts = append(hosts, sanitized)
}
return hosts, nil
}
func getDNSServerPrefix(input string) string {
prefix := ""
switch {
case strings.HasPrefix(input, "dns://"):
prefix = "dns://"
case strings.HasPrefix(input, "tls://"):
prefix = "tls://"
}
return prefix
}
func splitDNSServerPrefixServer(input string) (string, string) {
prefix := getDNSServerPrefix(input)
host := strings.TrimPrefix(input, prefix)
return prefix, host
}
func sanitizeDNSServer(input string) (string, error) {
prefix, host := splitDNSServerPrefixServer(input)
host = appendPortIfMissing(prefix, host)
{
h, _, err := net.SplitHostPort(host)
if err != nil {
return "", err
}
ip := net.ParseIP(h)
if ip == nil {
return "", fmt.Errorf("Invalid DNS server field: %s", h)
}
}
return prefix + host, nil
}
func appendPortIfMissing(prefix, input string) string {
port := "53"
switch prefix {
case "tls://":
port = "853"
}
_, _, err := net.SplitHostPort(input)
if err == nil {
return input
}
return net.JoinHostPort(input, "53")
}
func parseIPsOptionalPort(input string) []string {
fields := strings.Fields(input)
hosts := []string{}
for _, field := range fields {
_, _, err := net.SplitHostPort(field)
if err != nil {
ip := net.ParseIP(field)
if ip == nil {
log.Printf("Invalid DNS server field: %s\n", field)
continue
}
}
hosts = append(hosts, field)
}
return hosts
return net.JoinHostPort(input, port)
}
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {

View File

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"os"
"path"
"runtime"
"sort"
@ -259,5 +260,5 @@ func trace(format string, args ...interface{}) {
if len(text) == 0 || text[len(text)-1] != '\n' {
buf.WriteRune('\n')
}
fmt.Print(buf.String())
fmt.Fprint(os.Stderr, buf.String())
}