Add support for bootstrapping upstream DNS servers by hostname.
This commit is contained in:
parent
ff1c19cac5
commit
0f5dd661f5
|
@ -116,7 +116,7 @@ Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possib
|
||||||
* `parental_enabled` — Parental control-based DNS requests filtering
|
* `parental_enabled` — Parental control-based DNS requests filtering
|
||||||
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17
|
* `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17
|
||||||
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes)
|
* `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes)
|
||||||
* `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname
|
* `bootstrap_dns` — DNS server used for initial hostname resolution in case if upstream server name is a hostname
|
||||||
* `upstream_dns` — List of upstream DNS servers
|
* `upstream_dns` — List of upstream DNS servers
|
||||||
* `filters` — List of filters, each filter has the following values:
|
* `filters` — List of filters, each filter has the following values:
|
||||||
* `ID` - filter ID (must be unique)
|
* `ID` - filter ID (must be unique)
|
||||||
|
|
|
@ -43,7 +43,6 @@ type dnsConfig struct {
|
||||||
|
|
||||||
dnsforward.FilteringConfig `yaml:",inline"`
|
dnsforward.FilteringConfig `yaml:",inline"`
|
||||||
|
|
||||||
BootstrapDNS string `yaml:"bootstrap_dns"`
|
|
||||||
UpstreamDNS []string `yaml:"upstream_dns"`
|
UpstreamDNS []string `yaml:"upstream_dns"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,8 +62,8 @@ var config = configuration{
|
||||||
QueryLogEnabled: true,
|
QueryLogEnabled: true,
|
||||||
Ratelimit: 20,
|
Ratelimit: 20,
|
||||||
RefuseAny: true,
|
RefuseAny: true,
|
||||||
},
|
|
||||||
BootstrapDNS: "8.8.8.8:53",
|
BootstrapDNS: "8.8.8.8:53",
|
||||||
|
},
|
||||||
UpstreamDNS: defaultDNS,
|
UpstreamDNS: defaultDNS,
|
||||||
},
|
},
|
||||||
Filters: []filter{
|
Filters: []filter{
|
||||||
|
|
|
@ -204,7 +204,7 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
func checkDNS(input string) error {
|
func checkDNS(input string) error {
|
||||||
log.Printf("Checking if DNS %s works...", input)
|
log.Printf("Checking if DNS %s works...", input)
|
||||||
u, err := dnsforward.GetUpstream(input)
|
u, err := dnsforward.AddressToUpstream(input, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to choose upstream for %s: %s", input, err)
|
return fmt.Errorf("Failed to choose upstream for %s: %s", input, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ func generateServerConfig() dnsforward.ServerConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, u := range config.DNS.UpstreamDNS {
|
for _, u := range config.DNS.UpstreamDNS {
|
||||||
upstream, err := dnsforward.GetUpstream(u)
|
upstream, err := dnsforward.AddressToUpstream(u, config.DNS.BootstrapDNS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Couldn't get upstream: %s", err)
|
log.Printf("Couldn't get upstream: %s", err)
|
||||||
// continue, just ignore the upstream
|
// continue, just ignore the upstream
|
||||||
|
|
|
@ -0,0 +1,107 @@
|
||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/joomcode/errorx"
|
||||||
|
)
|
||||||
|
|
||||||
|
type bootstrapper struct {
|
||||||
|
address string // in form of "tls://one.one.one.one:853"
|
||||||
|
resolver *net.Resolver // resolver to use to resolve hostname, if neccessary
|
||||||
|
resolved string // in form "IP:port"
|
||||||
|
resolvedConfig *tls.Config
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func toBoot(address, bootstrapAddr string) bootstrapper {
|
||||||
|
var resolver *net.Resolver
|
||||||
|
if bootstrapAddr != "" {
|
||||||
|
resolver = &net.Resolver{
|
||||||
|
PreferGo: true,
|
||||||
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
d := net.Dialer{}
|
||||||
|
return d.DialContext(ctx, network, bootstrapAddr)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bootstrapper{
|
||||||
|
address: address,
|
||||||
|
resolver: resolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// will get usable IP address from Address field, and caches the result
|
||||||
|
func (n *bootstrapper) get() (string, *tls.Config, error) {
|
||||||
|
// TODO: RLock() here but atomically upgrade to Lock() if fast path doesn't work
|
||||||
|
n.Lock()
|
||||||
|
if n.resolved != "" { // fast path
|
||||||
|
retval, tlsconfig := n.resolved, n.resolvedConfig
|
||||||
|
n.Unlock()
|
||||||
|
return retval, tlsconfig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// slow path
|
||||||
|
//
|
||||||
|
|
||||||
|
defer n.Unlock()
|
||||||
|
|
||||||
|
justHostPort := n.address
|
||||||
|
if strings.Contains(n.address, "://") {
|
||||||
|
url, err := url.Parse(n.address)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, errorx.Decorate(err, "Failed to parse %s", n.address)
|
||||||
|
}
|
||||||
|
|
||||||
|
justHostPort = url.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert host to IP if neccessary, we know that it's scheme://hostname:port/
|
||||||
|
|
||||||
|
// get a host without port
|
||||||
|
host, port, err := net.SplitHostPort(justHostPort)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("bootstrapper requires port in address %s", n.address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if it's an IP
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if ip != nil {
|
||||||
|
n.resolved = justHostPort
|
||||||
|
return n.resolved, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// if it's a hostname
|
||||||
|
//
|
||||||
|
|
||||||
|
resolver := n.resolver // no need to check for nil resolver -- documented that nil is default resolver
|
||||||
|
addrs, err := resolver.LookupIPAddr(context.TODO(), host)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, errorx.Decorate(err, "Failed to lookup %s", host)
|
||||||
|
}
|
||||||
|
for _, addr := range addrs {
|
||||||
|
// TODO: support ipv6, support multiple ipv4
|
||||||
|
if addr.IP.To4() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ip = addr.IP
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip == nil {
|
||||||
|
// couldn't find any suitable IP address
|
||||||
|
return "", nil, fmt.Errorf("Couldn't find any suitable IP address for host %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
n.resolved = net.JoinHostPort(ip.String(), port)
|
||||||
|
n.resolvedConfig = &tls.Config{ServerName: host}
|
||||||
|
return n.resolved, n.resolvedConfig, nil
|
||||||
|
}
|
|
@ -86,6 +86,7 @@ type FilteringConfig struct {
|
||||||
Ratelimit int `yaml:"ratelimit"`
|
Ratelimit int `yaml:"ratelimit"`
|
||||||
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
|
RatelimitWhitelist []string `yaml:"ratelimit_whitelist"`
|
||||||
RefuseAny bool `yaml:"refuse_any"`
|
RefuseAny bool `yaml:"refuse_any"`
|
||||||
|
BootstrapDNS string `yaml:"bootstrap_dns"`
|
||||||
|
|
||||||
dnsfilter.Config `yaml:",inline"`
|
dnsfilter.Config `yaml:",inline"`
|
||||||
}
|
}
|
||||||
|
@ -105,24 +106,24 @@ var defaultValues = ServerConfig{
|
||||||
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
|
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
|
||||||
Upstreams: []Upstream{
|
Upstreams: []Upstream{
|
||||||
//// dns over HTTPS
|
//// dns over HTTPS
|
||||||
// &dnsOverHTTPS{address: "https://1.1.1.1/dns-query"},
|
// &dnsOverHTTPS{boot: toBoot("https://1.1.1.1/dns-query", "")},
|
||||||
// &dnsOverHTTPS{address: "https://dns.google.com/experimental"},
|
// &dnsOverHTTPS{boot: toBoot("https://dns.google.com/experimental", "")},
|
||||||
// &dnsOverHTTPS{address: "https://doh.cleanbrowsing.org/doh/security-filter/"},
|
// &dnsOverHTTPS{boot: toBoot("https://doh.cleanbrowsing.org/doh/security-filter/", "")},
|
||||||
// &dnsOverHTTPS{address: "https://dns10.quad9.net/dns-query"},
|
// &dnsOverHTTPS{boot: toBoot("https://dns10.quad9.net/dns-query", "")},
|
||||||
// &dnsOverHTTPS{address: "https://doh.powerdns.org"},
|
// &dnsOverHTTPS{boot: toBoot("https://doh.powerdns.org", "")},
|
||||||
// &dnsOverHTTPS{address: "https://doh.securedns.eu/dns-query"},
|
// &dnsOverHTTPS{boot: toBoot("https://doh.securedns.eu/dns-query", "")},
|
||||||
|
|
||||||
//// dns over TLS
|
//// dns over TLS
|
||||||
// &dnsOverTLS{address: "tls://8.8.8.8:853"},
|
// &dnsOverTLS{boot: toBoot("tls://8.8.8.8:853", "")},
|
||||||
// &dnsOverTLS{address: "tls://8.8.4.4:853"},
|
// &dnsOverTLS{boot: toBoot("tls://8.8.4.4:853", "")},
|
||||||
// &dnsOverTLS{address: "tls://1.1.1.1:853"},
|
// &dnsOverTLS{boot: toBoot("tls://1.1.1.1:853", "")},
|
||||||
// &dnsOverTLS{address: "tls://1.0.0.1:853"},
|
// &dnsOverTLS{boot: toBoot("tls://1.0.0.1:853", "")},
|
||||||
|
|
||||||
//// plainDNS
|
//// plainDNS
|
||||||
&plainDNS{address: "8.8.8.8:53"},
|
&plainDNS{boot: toBoot("8.8.8.8:53", "")},
|
||||||
&plainDNS{address: "8.8.4.4:53"},
|
&plainDNS{boot: toBoot("8.8.4.4:53", "")},
|
||||||
&plainDNS{address: "1.1.1.1:53"},
|
&plainDNS{boot: toBoot("1.1.1.1:53", "")},
|
||||||
&plainDNS{address: "1.0.0.1:53"},
|
&plainDNS{boot: toBoot("1.0.0.1:53", "")},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,7 @@ type Upstream interface {
|
||||||
// plain DNS
|
// plain DNS
|
||||||
//
|
//
|
||||||
type plainDNS struct {
|
type plainDNS struct {
|
||||||
address string
|
boot bootstrapper
|
||||||
preferTCP bool
|
preferTCP bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,19 +44,25 @@ var defaultTCPClient = dns.Client{
|
||||||
Timeout: defaultTimeout,
|
Timeout: defaultTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *plainDNS) Address() string { return p.address }
|
// Address returns the original address that we've put in initially, not resolved one
|
||||||
|
func (p *plainDNS) Address() string { return p.boot.address }
|
||||||
|
|
||||||
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||||
|
addr, _, err := p.boot.get()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if p.preferTCP {
|
if p.preferTCP {
|
||||||
reply, _, err := defaultTCPClient.Exchange(m, p.address)
|
reply, _, err := defaultTCPClient.Exchange(m, addr)
|
||||||
return reply, err
|
return reply, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reply, _, err := defaultUDPClient.Exchange(m, p.address)
|
reply, _, err := defaultUDPClient.Exchange(m, addr)
|
||||||
if err != nil && reply != nil && reply.Truncated {
|
if err != nil && reply != nil && reply.Truncated {
|
||||||
log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String())
|
log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String())
|
||||||
reply, _, err = defaultTCPClient.Exchange(m, p.address)
|
reply, _, err = defaultTCPClient.Exchange(m, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return reply, err
|
return reply, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +70,7 @@ func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||||
// DNS-over-TLS
|
// DNS-over-TLS
|
||||||
//
|
//
|
||||||
type dnsOverTLS struct {
|
type dnsOverTLS struct {
|
||||||
address string
|
boot bootstrapper
|
||||||
pool *TLSPool
|
pool *TLSPool
|
||||||
|
|
||||||
sync.RWMutex // protects pool
|
sync.RWMutex // protects pool
|
||||||
|
@ -77,7 +83,7 @@ var defaultTLSClient = dns.Client{
|
||||||
TLSConfig: &tls.Config{},
|
TLSConfig: &tls.Config{},
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *dnsOverTLS) Address() string { return p.address }
|
func (p *dnsOverTLS) Address() string { return p.boot.address }
|
||||||
|
|
||||||
func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||||
var pool *TLSPool
|
var pool *TLSPool
|
||||||
|
@ -87,7 +93,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||||
if pool == nil {
|
if pool == nil {
|
||||||
p.Lock()
|
p.Lock()
|
||||||
// lazy initialize it
|
// lazy initialize it
|
||||||
p.pool = &TLSPool{Address: p.address}
|
p.pool = &TLSPool{boot: &p.boot}
|
||||||
p.Unlock()
|
p.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,19 +101,19 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||||
poolConn, err := p.pool.Get()
|
poolConn, err := p.pool.Get()
|
||||||
p.RUnlock()
|
p.RUnlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.address)
|
return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address())
|
||||||
}
|
}
|
||||||
c := dns.Conn{Conn: poolConn}
|
c := dns.Conn{Conn: poolConn}
|
||||||
err = c.WriteMsg(m)
|
err = c.WriteMsg(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
poolConn.Close()
|
poolConn.Close()
|
||||||
return nil, errorx.Decorate(err, "Failed to send a request to %s", p.address)
|
return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address())
|
||||||
}
|
}
|
||||||
|
|
||||||
reply, err := c.ReadMsg()
|
reply, err := c.ReadMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
poolConn.Close()
|
poolConn.Close()
|
||||||
return nil, errorx.Decorate(err, "Failed to read a request from %s", p.address)
|
return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address())
|
||||||
}
|
}
|
||||||
p.RLock()
|
p.RLock()
|
||||||
p.pool.Put(poolConn)
|
p.pool.Put(poolConn)
|
||||||
|
@ -119,7 +125,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||||
// DNS-over-https
|
// DNS-over-https
|
||||||
//
|
//
|
||||||
type dnsOverHTTPS struct {
|
type dnsOverHTTPS struct {
|
||||||
address string
|
boot bootstrapper
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultHTTPSTransport = http.Transport{}
|
var defaultHTTPSTransport = http.Transport{}
|
||||||
|
@ -129,35 +135,59 @@ var defaultHTTPSClient = http.Client{
|
||||||
Timeout: defaultTimeout,
|
Timeout: defaultTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *dnsOverHTTPS) Address() string { return p.address }
|
func (p *dnsOverHTTPS) Address() string { return p.boot.address }
|
||||||
|
|
||||||
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
|
||||||
|
addr, tlsConfig, err := p.boot.get()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorx.Decorate(err, "Couldn't bootstrap %s", p.boot.address)
|
||||||
|
}
|
||||||
|
|
||||||
buf, err := m.Pack()
|
buf, err := m.Pack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.Decorate(err, "Couldn't pack request msg")
|
return nil, errorx.Decorate(err, "Couldn't pack request msg")
|
||||||
}
|
}
|
||||||
bb := bytes.NewBuffer(buf)
|
bb := bytes.NewBuffer(buf)
|
||||||
resp, err := http.Post(p.address, "application/dns-message", bb)
|
|
||||||
|
// set up a custom request with custom URL
|
||||||
|
url, err := url.Parse(p.boot.address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errorx.Decorate(err, "Couldn't parse URL %s", p.boot.address)
|
||||||
|
}
|
||||||
|
req := http.Request{
|
||||||
|
Method: "POST",
|
||||||
|
URL: url,
|
||||||
|
Body: ioutil.NopCloser(bb),
|
||||||
|
Header: make(http.Header),
|
||||||
|
Host: url.Host,
|
||||||
|
}
|
||||||
|
url.Host = addr
|
||||||
|
req.Header.Set("Content-Type", "application/dns-message")
|
||||||
|
client := http.Client{
|
||||||
|
Transport: &http.Transport{TLSClientConfig: tlsConfig},
|
||||||
|
}
|
||||||
|
resp, err := client.Do(&req)
|
||||||
if resp != nil && resp.Body != nil {
|
if resp != nil && resp.Body != nil {
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.address)
|
return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.address)
|
return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", addr)
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.address)
|
return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, addr)
|
||||||
}
|
}
|
||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.address)
|
return nil, fmt.Errorf("Got an unexpected empty body from '%s'", addr)
|
||||||
}
|
}
|
||||||
response := dns.Msg{}
|
response := dns.Msg{}
|
||||||
err = response.Unpack(body)
|
err = response.Unpack(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.address, string(body))
|
return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", addr, string(body))
|
||||||
}
|
}
|
||||||
return &response, nil
|
return &response, nil
|
||||||
}
|
}
|
||||||
|
@ -178,7 +208,7 @@ func (s *Server) chooseUpstream() Upstream {
|
||||||
return upstream
|
return upstream
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUpstream(address string) (Upstream, error) {
|
func AddressToUpstream(address string, bootstrap string) (Upstream, error) {
|
||||||
if strings.Contains(address, "://") {
|
if strings.Contains(address, "://") {
|
||||||
url, err := url.Parse(address)
|
url, err := url.Parse(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -189,25 +219,28 @@ func GetUpstream(address string) (Upstream, error) {
|
||||||
if url.Port() == "" {
|
if url.Port() == "" {
|
||||||
url.Host += ":53"
|
url.Host += ":53"
|
||||||
}
|
}
|
||||||
return &plainDNS{address: url.Host}, nil
|
return &plainDNS{boot: toBoot(url.Host, bootstrap)}, nil
|
||||||
case "tcp":
|
case "tcp":
|
||||||
if url.Port() == "" {
|
if url.Port() == "" {
|
||||||
url.Host += ":53"
|
url.Host += ":53"
|
||||||
}
|
}
|
||||||
return &plainDNS{address: url.Host, preferTCP: true}, nil
|
return &plainDNS{boot: toBoot(url.Host, bootstrap), preferTCP: true}, nil
|
||||||
case "tls":
|
case "tls":
|
||||||
if url.Port() == "" {
|
if url.Port() == "" {
|
||||||
url.Host += ":853"
|
url.Host += ":853"
|
||||||
}
|
}
|
||||||
return &dnsOverTLS{address: url.String()}, nil
|
return &dnsOverTLS{boot: toBoot(url.String(), bootstrap)}, nil
|
||||||
case "https":
|
case "https":
|
||||||
return &dnsOverHTTPS{address: url.String()}, nil
|
if url.Port() == "" {
|
||||||
|
url.Host += ":443"
|
||||||
|
}
|
||||||
|
return &dnsOverHTTPS{boot: toBoot(url.String(), bootstrap)}, nil
|
||||||
default:
|
default:
|
||||||
// assume it's plain DNS
|
// assume it's plain DNS
|
||||||
if url.Port() == "" {
|
if url.Port() == "" {
|
||||||
url.Host += ":53"
|
url.Host += ":53"
|
||||||
}
|
}
|
||||||
return &plainDNS{address: url.String()}, nil
|
return &plainDNS{boot: toBoot(url.String(), bootstrap)}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -217,5 +250,5 @@ func GetUpstream(address string) (Upstream, error) {
|
||||||
// doesn't have port, default to 53
|
// doesn't have port, default to 53
|
||||||
address = net.JoinHostPort(address, "53")
|
address = net.JoinHostPort(address, "53")
|
||||||
}
|
}
|
||||||
return &plainDNS{address: address}, nil
|
return &plainDNS{boot: toBoot(address, bootstrap)}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,9 +2,7 @@ package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
|
@ -27,51 +25,29 @@ import (
|
||||||
// log.Println(r)
|
// log.Println(r)
|
||||||
// pool.Put(c.Conn)
|
// pool.Put(c.Conn)
|
||||||
type TLSPool struct {
|
type TLSPool struct {
|
||||||
Address string
|
boot *bootstrapper
|
||||||
parsedAddress *url.URL
|
|
||||||
parsedAddressMutex sync.RWMutex
|
|
||||||
|
|
||||||
|
// connections
|
||||||
conns []net.Conn
|
conns []net.Conn
|
||||||
sync.Mutex // protects conns
|
connsMutex sync.Mutex // protects conns
|
||||||
}
|
|
||||||
|
|
||||||
func (n *TLSPool) getHost() (string, error) {
|
|
||||||
n.parsedAddressMutex.RLock()
|
|
||||||
if n.parsedAddress != nil {
|
|
||||||
n.parsedAddressMutex.RUnlock()
|
|
||||||
return n.parsedAddress.Host, nil
|
|
||||||
}
|
|
||||||
n.parsedAddressMutex.RUnlock()
|
|
||||||
|
|
||||||
n.parsedAddressMutex.Lock()
|
|
||||||
defer n.parsedAddressMutex.Unlock()
|
|
||||||
url, err := url.Parse(n.Address)
|
|
||||||
if err != nil {
|
|
||||||
return "", errorx.Decorate(err, "Failed to parse %s", n.Address)
|
|
||||||
}
|
|
||||||
if url.Scheme != "tls" {
|
|
||||||
return "", fmt.Errorf("TLSPool only supports TLS")
|
|
||||||
}
|
|
||||||
n.parsedAddress = url
|
|
||||||
return n.parsedAddress.Host, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *TLSPool) Get() (net.Conn, error) {
|
func (n *TLSPool) Get() (net.Conn, error) {
|
||||||
host, err := n.getHost()
|
address, tlsConfig, err := n.boot.get()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the connection from the slice inside the lock
|
// get the connection from the slice inside the lock
|
||||||
var c net.Conn
|
var c net.Conn
|
||||||
n.Lock()
|
n.connsMutex.Lock()
|
||||||
num := len(n.conns)
|
num := len(n.conns)
|
||||||
if num > 0 {
|
if num > 0 {
|
||||||
last := num - 1
|
last := num - 1
|
||||||
c = n.conns[last]
|
c = n.conns[last]
|
||||||
n.conns = n.conns[:last]
|
n.conns = n.conns[:last]
|
||||||
}
|
}
|
||||||
n.Unlock()
|
n.connsMutex.Unlock()
|
||||||
|
|
||||||
// if we got connection from the slice, return it
|
// if we got connection from the slice, return it
|
||||||
if c != nil {
|
if c != nil {
|
||||||
|
@ -80,10 +56,10 @@ func (n *TLSPool) Get() (net.Conn, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// we'll need a new connection, dial now
|
// we'll need a new connection, dial now
|
||||||
// log.Printf("Dialing to %s", host)
|
// log.Printf("Dialing to %s", address)
|
||||||
conn, err := tls.Dial("tcp", host, nil)
|
conn, err := tls.Dial("tcp", address, tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.Decorate(err, "Failed to connect to %s", host)
|
return nil, errorx.Decorate(err, "Failed to connect to %s", address)
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
@ -92,7 +68,7 @@ func (n *TLSPool) Put(c net.Conn) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
n.Lock()
|
n.connsMutex.Lock()
|
||||||
n.conns = append(n.conns, c)
|
n.conns = append(n.conns, c)
|
||||||
n.Unlock()
|
n.connsMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,53 +7,65 @@ import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUpstreamDNS(t *testing.T) {
|
func TestUpstreams(t *testing.T) {
|
||||||
upstreams := []string{
|
upstreams := []struct {
|
||||||
"8.8.8.8:53",
|
address string
|
||||||
"1.1.1.1",
|
bootstrap string
|
||||||
"tcp://1.1.1.1:53",
|
}{
|
||||||
"176.103.130.130:5353",
|
{
|
||||||
|
address: "8.8.8.8:53",
|
||||||
|
bootstrap: "8.8.8.8:53",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
address: "1.1.1.1",
|
||||||
|
bootstrap: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
address: "tcp://1.1.1.1:53",
|
||||||
|
bootstrap: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
address: "176.103.130.130:5353",
|
||||||
|
bootstrap: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
address: "tls://1.1.1.1",
|
||||||
|
bootstrap: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
address: "tls://9.9.9.9:853",
|
||||||
|
bootstrap: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
address: "tls://security-filter-dns.cleanbrowsing.org",
|
||||||
|
bootstrap: "8.8.8.8:53",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
address: "tls://adult-filter-dns.cleanbrowsing.org:853",
|
||||||
|
bootstrap: "8.8.8.8:53",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
address: "https://cloudflare-dns.com/dns-query",
|
||||||
|
bootstrap: "8.8.8.8:53",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
address: "https://dns.google.com/experimental",
|
||||||
|
bootstrap: "8.8.8.8:53",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
address: "https://doh.cleanbrowsing.org/doh/security-filter/",
|
||||||
|
bootstrap: "",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, input := range upstreams {
|
for _, test := range upstreams {
|
||||||
u, err := GetUpstream(input)
|
t.Run(test.address, func(t *testing.T) {
|
||||||
|
u, err := AddressToUpstream(test.address, test.bootstrap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to choose upstream for %s: %s", input, err)
|
t.Fatalf("Failed to generate upstream from address %s: %s", test.address, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
checkUpstream(t, u, input)
|
checkUpstream(t, u, test.address)
|
||||||
}
|
})
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpstreamTLS(t *testing.T) {
|
|
||||||
upstreams := []string{
|
|
||||||
"tls://1.1.1.1",
|
|
||||||
"tls://9.9.9.9:853",
|
|
||||||
"tls://security-filter-dns.cleanbrowsing.org",
|
|
||||||
"tls://adult-filter-dns.cleanbrowsing.org:853",
|
|
||||||
}
|
|
||||||
for _, input := range upstreams {
|
|
||||||
u, err := GetUpstream(input)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to choose upstream for %s: %s", input, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
checkUpstream(t, u, input)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpstreamHTTPS(t *testing.T) {
|
|
||||||
upstreams := []string{
|
|
||||||
"https://cloudflare-dns.com/dns-query",
|
|
||||||
"https://dns.google.com/experimental",
|
|
||||||
"https://doh.cleanbrowsing.org/doh/security-filter/",
|
|
||||||
}
|
|
||||||
for _, input := range upstreams {
|
|
||||||
u, err := GetUpstream(input)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to choose upstream for %s: %s", input, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
checkUpstream(t, u, input)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue