diff --git a/README.md b/README.md index 4096ad93..b9916574 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib * `parental_enabled` — Parental control-based DNS requests filtering * `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17 * `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes) - * `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname + * `bootstrap_dns` — DNS server used for initial hostname resolution in case if upstream server name is a hostname * `upstream_dns` — List of upstream DNS servers * `filters` — List of filters, each filter has the following values: * `ID` - filter ID (must be unique) diff --git a/config.go b/config.go index 89822d3f..8facc405 100644 --- a/config.go +++ b/config.go @@ -43,8 +43,7 @@ type dnsConfig struct { dnsforward.FilteringConfig `yaml:",inline"` - BootstrapDNS string `yaml:"bootstrap_dns"` - UpstreamDNS []string `yaml:"upstream_dns"` + UpstreamDNS []string `yaml:"upstream_dns"` } var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} @@ -63,9 +62,9 @@ var config = configuration{ QueryLogEnabled: true, Ratelimit: 20, RefuseAny: true, + BootstrapDNS: "8.8.8.8:53", }, - BootstrapDNS: "8.8.8.8:53", - UpstreamDNS: defaultDNS, + UpstreamDNS: defaultDNS, }, Filters: []filter{ {Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, diff --git a/control.go b/control.go index 49869ddc..2674585c 100644 --- a/control.go +++ b/control.go @@ -204,7 +204,7 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { func checkDNS(input string) error { log.Printf("Checking if DNS %s works...", input) - u, err := dnsforward.GetUpstream(input) + u, err := dnsforward.AddressToUpstream(input, "") if err != nil { return fmt.Errorf("Failed to choose upstream for %s: %s", input, err) } diff --git a/coredns.go b/coredns.go index 250c9e37..42894336 100644 --- a/coredns.go +++ b/coredns.go @@ -37,7 +37,7 @@ func generateServerConfig() dnsforward.ServerConfig { } for _, u := range config.DNS.UpstreamDNS { - upstream, err := dnsforward.GetUpstream(u) + upstream, err := dnsforward.AddressToUpstream(u, config.DNS.BootstrapDNS) if err != nil { log.Printf("Couldn't get upstream: %s", err) // continue, just ignore the upstream diff --git a/dnsforward/bootstrap.go b/dnsforward/bootstrap.go new file mode 100644 index 00000000..2d263871 --- /dev/null +++ b/dnsforward/bootstrap.go @@ -0,0 +1,107 @@ +package dnsforward + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/url" + "strings" + "sync" + + "github.com/joomcode/errorx" +) + +type bootstrapper struct { + address string // in form of "tls://one.one.one.one:853" + resolver *net.Resolver // resolver to use to resolve hostname, if neccessary + resolved string // in form "IP:port" + resolvedConfig *tls.Config + sync.Mutex +} + +func toBoot(address, bootstrapAddr string) bootstrapper { + var resolver *net.Resolver + if bootstrapAddr != "" { + resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, network, bootstrapAddr) + }, + } + } + return bootstrapper{ + address: address, + resolver: resolver, + } +} + +// will get usable IP address from Address field, and caches the result +func (n *bootstrapper) get() (string, *tls.Config, error) { + // TODO: RLock() here but atomically upgrade to Lock() if fast path doesn't work + n.Lock() + if n.resolved != "" { // fast path + retval, tlsconfig := n.resolved, n.resolvedConfig + n.Unlock() + return retval, tlsconfig, nil + } + + // + // slow path + // + + defer n.Unlock() + + justHostPort := n.address + if strings.Contains(n.address, "://") { + url, err := url.Parse(n.address) + if err != nil { + return "", nil, errorx.Decorate(err, "Failed to parse %s", n.address) + } + + justHostPort = url.Host + } + + // convert host to IP if neccessary, we know that it's scheme://hostname:port/ + + // get a host without port + host, port, err := net.SplitHostPort(justHostPort) + if err != nil { + return "", nil, fmt.Errorf("bootstrapper requires port in address %s", n.address) + } + + // if it's an IP + ip := net.ParseIP(host) + if ip != nil { + n.resolved = justHostPort + return n.resolved, nil, nil + } + + // + // if it's a hostname + // + + resolver := n.resolver // no need to check for nil resolver -- documented that nil is default resolver + addrs, err := resolver.LookupIPAddr(context.TODO(), host) + if err != nil { + return "", nil, errorx.Decorate(err, "Failed to lookup %s", host) + } + for _, addr := range addrs { + // TODO: support ipv6, support multiple ipv4 + if addr.IP.To4() == nil { + continue + } + ip = addr.IP + break + } + + if ip == nil { + // couldn't find any suitable IP address + return "", nil, fmt.Errorf("Couldn't find any suitable IP address for host %s", host) + } + + n.resolved = net.JoinHostPort(ip.String(), port) + n.resolvedConfig = &tls.Config{ServerName: host} + return n.resolved, n.resolvedConfig, nil +} diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index bee85d3a..404bbfb3 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -86,6 +86,7 @@ type FilteringConfig struct { Ratelimit int `yaml:"ratelimit"` RatelimitWhitelist []string `yaml:"ratelimit_whitelist"` RefuseAny bool `yaml:"refuse_any"` + BootstrapDNS string `yaml:"bootstrap_dns"` dnsfilter.Config `yaml:",inline"` } @@ -105,24 +106,24 @@ var defaultValues = ServerConfig{ FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, Upstreams: []Upstream{ //// dns over HTTPS - // &dnsOverHTTPS{address: "https://1.1.1.1/dns-query"}, - // &dnsOverHTTPS{address: "https://dns.google.com/experimental"}, - // &dnsOverHTTPS{address: "https://doh.cleanbrowsing.org/doh/security-filter/"}, - // &dnsOverHTTPS{address: "https://dns10.quad9.net/dns-query"}, - // &dnsOverHTTPS{address: "https://doh.powerdns.org"}, - // &dnsOverHTTPS{address: "https://doh.securedns.eu/dns-query"}, + // &dnsOverHTTPS{boot: toBoot("https://1.1.1.1/dns-query", "")}, + // &dnsOverHTTPS{boot: toBoot("https://dns.google.com/experimental", "")}, + // &dnsOverHTTPS{boot: toBoot("https://doh.cleanbrowsing.org/doh/security-filter/", "")}, + // &dnsOverHTTPS{boot: toBoot("https://dns10.quad9.net/dns-query", "")}, + // &dnsOverHTTPS{boot: toBoot("https://doh.powerdns.org", "")}, + // &dnsOverHTTPS{boot: toBoot("https://doh.securedns.eu/dns-query", "")}, //// dns over TLS - // &dnsOverTLS{address: "tls://8.8.8.8:853"}, - // &dnsOverTLS{address: "tls://8.8.4.4:853"}, - // &dnsOverTLS{address: "tls://1.1.1.1:853"}, - // &dnsOverTLS{address: "tls://1.0.0.1:853"}, + // &dnsOverTLS{boot: toBoot("tls://8.8.8.8:853", "")}, + // &dnsOverTLS{boot: toBoot("tls://8.8.4.4:853", "")}, + // &dnsOverTLS{boot: toBoot("tls://1.1.1.1:853", "")}, + // &dnsOverTLS{boot: toBoot("tls://1.0.0.1:853", "")}, //// plainDNS - &plainDNS{address: "8.8.8.8:53"}, - &plainDNS{address: "8.8.4.4:53"}, - &plainDNS{address: "1.1.1.1:53"}, - &plainDNS{address: "1.0.0.1:53"}, + &plainDNS{boot: toBoot("8.8.8.8:53", "")}, + &plainDNS{boot: toBoot("8.8.4.4:53", "")}, + &plainDNS{boot: toBoot("1.1.1.1:53", "")}, + &plainDNS{boot: toBoot("1.0.0.1:53", "")}, }, } diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go index 5d672f5a..99142929 100644 --- a/dnsforward/upstream.go +++ b/dnsforward/upstream.go @@ -29,7 +29,7 @@ type Upstream interface { // plain DNS // type plainDNS struct { - address string + boot bootstrapper preferTCP bool } @@ -44,19 +44,25 @@ var defaultTCPClient = dns.Client{ Timeout: defaultTimeout, } -func (p *plainDNS) Address() string { return p.address } +// Address returns the original address that we've put in initially, not resolved one +func (p *plainDNS) Address() string { return p.boot.address } func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { + addr, _, err := p.boot.get() + if err != nil { + return nil, err + } if p.preferTCP { - reply, _, err := defaultTCPClient.Exchange(m, p.address) + reply, _, err := defaultTCPClient.Exchange(m, addr) return reply, err } - reply, _, err := defaultUDPClient.Exchange(m, p.address) + reply, _, err := defaultUDPClient.Exchange(m, addr) if err != nil && reply != nil && reply.Truncated { log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String()) - reply, _, err = defaultTCPClient.Exchange(m, p.address) + reply, _, err = defaultTCPClient.Exchange(m, addr) } + return reply, err } @@ -64,8 +70,8 @@ func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { // DNS-over-TLS // type dnsOverTLS struct { - address string - pool *TLSPool + boot bootstrapper + pool *TLSPool sync.RWMutex // protects pool } @@ -77,7 +83,7 @@ var defaultTLSClient = dns.Client{ TLSConfig: &tls.Config{}, } -func (p *dnsOverTLS) Address() string { return p.address } +func (p *dnsOverTLS) Address() string { return p.boot.address } func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { var pool *TLSPool @@ -87,7 +93,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { if pool == nil { p.Lock() // lazy initialize it - p.pool = &TLSPool{Address: p.address} + p.pool = &TLSPool{boot: &p.boot} p.Unlock() } @@ -95,19 +101,19 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { poolConn, err := p.pool.Get() p.RUnlock() if err != nil { - return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.address) + return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address()) } c := dns.Conn{Conn: poolConn} err = c.WriteMsg(m) if err != nil { poolConn.Close() - return nil, errorx.Decorate(err, "Failed to send a request to %s", p.address) + return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address()) } reply, err := c.ReadMsg() if err != nil { poolConn.Close() - return nil, errorx.Decorate(err, "Failed to read a request from %s", p.address) + return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address()) } p.RLock() p.pool.Put(poolConn) @@ -119,7 +125,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { // DNS-over-https // type dnsOverHTTPS struct { - address string + boot bootstrapper } var defaultHTTPSTransport = http.Transport{} @@ -129,35 +135,59 @@ var defaultHTTPSClient = http.Client{ Timeout: defaultTimeout, } -func (p *dnsOverHTTPS) Address() string { return p.address } +func (p *dnsOverHTTPS) Address() string { return p.boot.address } func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) { + addr, tlsConfig, err := p.boot.get() + if err != nil { + return nil, errorx.Decorate(err, "Couldn't bootstrap %s", p.boot.address) + } + buf, err := m.Pack() if err != nil { return nil, errorx.Decorate(err, "Couldn't pack request msg") } bb := bytes.NewBuffer(buf) - resp, err := http.Post(p.address, "application/dns-message", bb) + + // set up a custom request with custom URL + url, err := url.Parse(p.boot.address) + if err != nil { + return nil, errorx.Decorate(err, "Couldn't parse URL %s", p.boot.address) + } + req := http.Request{ + Method: "POST", + URL: url, + Body: ioutil.NopCloser(bb), + Header: make(http.Header), + Host: url.Host, + } + url.Host = addr + req.Header.Set("Content-Type", "application/dns-message") + client := http.Client{ + Transport: &http.Transport{TLSClientConfig: tlsConfig}, + } + resp, err := client.Do(&req) if resp != nil && resp.Body != nil { defer resp.Body.Close() } if err != nil { - return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.address) + return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", addr) } + body, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.address) + return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", addr) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.address) + return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, addr) } if len(body) == 0 { - return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.address) + return nil, fmt.Errorf("Got an unexpected empty body from '%s'", addr) } response := dns.Msg{} err = response.Unpack(body) if err != nil { - return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.address, string(body)) + return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", addr, string(body)) } return &response, nil } @@ -178,7 +208,7 @@ func (s *Server) chooseUpstream() Upstream { return upstream } -func GetUpstream(address string) (Upstream, error) { +func AddressToUpstream(address string, bootstrap string) (Upstream, error) { if strings.Contains(address, "://") { url, err := url.Parse(address) if err != nil { @@ -189,25 +219,28 @@ func GetUpstream(address string) (Upstream, error) { if url.Port() == "" { url.Host += ":53" } - return &plainDNS{address: url.Host}, nil + return &plainDNS{boot: toBoot(url.Host, bootstrap)}, nil case "tcp": if url.Port() == "" { url.Host += ":53" } - return &plainDNS{address: url.Host, preferTCP: true}, nil + return &plainDNS{boot: toBoot(url.Host, bootstrap), preferTCP: true}, nil case "tls": if url.Port() == "" { url.Host += ":853" } - return &dnsOverTLS{address: url.String()}, nil + return &dnsOverTLS{boot: toBoot(url.String(), bootstrap)}, nil case "https": - return &dnsOverHTTPS{address: url.String()}, nil + if url.Port() == "" { + url.Host += ":443" + } + return &dnsOverHTTPS{boot: toBoot(url.String(), bootstrap)}, nil default: // assume it's plain DNS if url.Port() == "" { url.Host += ":53" } - return &plainDNS{address: url.String()}, nil + return &plainDNS{boot: toBoot(url.String(), bootstrap)}, nil } } @@ -217,5 +250,5 @@ func GetUpstream(address string) (Upstream, error) { // doesn't have port, default to 53 address = net.JoinHostPort(address, "53") } - return &plainDNS{address: address}, nil + return &plainDNS{boot: toBoot(address, bootstrap)}, nil } diff --git a/dnsforward/upstream_pool.go b/dnsforward/upstream_pool.go index f944e695..ca597808 100644 --- a/dnsforward/upstream_pool.go +++ b/dnsforward/upstream_pool.go @@ -2,9 +2,7 @@ package dnsforward import ( "crypto/tls" - "fmt" "net" - "net/url" "sync" "github.com/joomcode/errorx" @@ -27,51 +25,29 @@ import ( // log.Println(r) // pool.Put(c.Conn) type TLSPool struct { - Address string - parsedAddress *url.URL - parsedAddressMutex sync.RWMutex + boot *bootstrapper + // connections conns []net.Conn - sync.Mutex // protects conns -} - -func (n *TLSPool) getHost() (string, error) { - n.parsedAddressMutex.RLock() - if n.parsedAddress != nil { - n.parsedAddressMutex.RUnlock() - return n.parsedAddress.Host, nil - } - n.parsedAddressMutex.RUnlock() - - n.parsedAddressMutex.Lock() - defer n.parsedAddressMutex.Unlock() - url, err := url.Parse(n.Address) - if err != nil { - return "", errorx.Decorate(err, "Failed to parse %s", n.Address) - } - if url.Scheme != "tls" { - return "", fmt.Errorf("TLSPool only supports TLS") - } - n.parsedAddress = url - return n.parsedAddress.Host, nil + connsMutex sync.Mutex // protects conns } func (n *TLSPool) Get() (net.Conn, error) { - host, err := n.getHost() + address, tlsConfig, err := n.boot.get() if err != nil { return nil, err } // get the connection from the slice inside the lock var c net.Conn - n.Lock() + n.connsMutex.Lock() num := len(n.conns) if num > 0 { last := num - 1 c = n.conns[last] n.conns = n.conns[:last] } - n.Unlock() + n.connsMutex.Unlock() // if we got connection from the slice, return it if c != nil { @@ -80,10 +56,10 @@ func (n *TLSPool) Get() (net.Conn, error) { } // we'll need a new connection, dial now - // log.Printf("Dialing to %s", host) - conn, err := tls.Dial("tcp", host, nil) + // log.Printf("Dialing to %s", address) + conn, err := tls.Dial("tcp", address, tlsConfig) if err != nil { - return nil, errorx.Decorate(err, "Failed to connect to %s", host) + return nil, errorx.Decorate(err, "Failed to connect to %s", address) } return conn, nil } @@ -92,7 +68,7 @@ func (n *TLSPool) Put(c net.Conn) { if c == nil { return } - n.Lock() + n.connsMutex.Lock() n.conns = append(n.conns, c) - n.Unlock() + n.connsMutex.Unlock() } diff --git a/dnsforward/upstream_test.go b/dnsforward/upstream_test.go index 975c5035..0b83670f 100644 --- a/dnsforward/upstream_test.go +++ b/dnsforward/upstream_test.go @@ -7,53 +7,65 @@ import ( "github.com/miekg/dns" ) -func TestUpstreamDNS(t *testing.T) { - upstreams := []string{ - "8.8.8.8:53", - "1.1.1.1", - "tcp://1.1.1.1:53", - "176.103.130.130:5353", +func TestUpstreams(t *testing.T) { + upstreams := []struct { + address string + bootstrap string + }{ + { + address: "8.8.8.8:53", + bootstrap: "8.8.8.8:53", + }, + { + address: "1.1.1.1", + bootstrap: "", + }, + { + address: "tcp://1.1.1.1:53", + bootstrap: "", + }, + { + address: "176.103.130.130:5353", + bootstrap: "", + }, + { + address: "tls://1.1.1.1", + bootstrap: "", + }, + { + address: "tls://9.9.9.9:853", + bootstrap: "", + }, + { + address: "tls://security-filter-dns.cleanbrowsing.org", + bootstrap: "8.8.8.8:53", + }, + { + address: "tls://adult-filter-dns.cleanbrowsing.org:853", + bootstrap: "8.8.8.8:53", + }, + { + address: "https://cloudflare-dns.com/dns-query", + bootstrap: "8.8.8.8:53", + }, + { + address: "https://dns.google.com/experimental", + bootstrap: "8.8.8.8:53", + }, + { + address: "https://doh.cleanbrowsing.org/doh/security-filter/", + bootstrap: "", + }, } - for _, input := range upstreams { - u, err := GetUpstream(input) - if err != nil { - t.Fatalf("Failed to choose upstream for %s: %s", input, err) - } + for _, test := range upstreams { + t.Run(test.address, func(t *testing.T) { + u, err := AddressToUpstream(test.address, test.bootstrap) + if err != nil { + t.Fatalf("Failed to generate upstream from address %s: %s", test.address, err) + } - checkUpstream(t, u, input) - } -} - -func TestUpstreamTLS(t *testing.T) { - upstreams := []string{ - "tls://1.1.1.1", - "tls://9.9.9.9:853", - "tls://security-filter-dns.cleanbrowsing.org", - "tls://adult-filter-dns.cleanbrowsing.org:853", - } - for _, input := range upstreams { - u, err := GetUpstream(input) - if err != nil { - t.Fatalf("Failed to choose upstream for %s: %s", input, err) - } - - checkUpstream(t, u, input) - } -} - -func TestUpstreamHTTPS(t *testing.T) { - upstreams := []string{ - "https://cloudflare-dns.com/dns-query", - "https://dns.google.com/experimental", - "https://doh.cleanbrowsing.org/doh/security-filter/", - } - for _, input := range upstreams { - u, err := GetUpstream(input) - if err != nil { - t.Fatalf("Failed to choose upstream for %s: %s", input, err) - } - - checkUpstream(t, u, input) + checkUpstream(t, u, test.address) + }) } }