dnsforward -- Make Upstream interface give access to Address field.

This commit is contained in:
Eugene Bujak 2018-12-05 12:57:14 +03:00
parent 8396dc2fdb
commit e5d2f883ac
2 changed files with 41 additions and 34 deletions

View File

@ -96,24 +96,24 @@ var defaultValues = ServerConfig{
FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600},
Upstreams: []Upstream{ Upstreams: []Upstream{
//// dns over HTTPS //// dns over HTTPS
// &dnsOverHTTPS{Address: "https://1.1.1.1/dns-query"}, // &dnsOverHTTPS{address: "https://1.1.1.1/dns-query"},
// &dnsOverHTTPS{Address: "https://dns.google.com/experimental"}, // &dnsOverHTTPS{address: "https://dns.google.com/experimental"},
// &dnsOverHTTPS{Address: "https://doh.cleanbrowsing.org/doh/security-filter/"}, // &dnsOverHTTPS{address: "https://doh.cleanbrowsing.org/doh/security-filter/"},
// &dnsOverHTTPS{Address: "https://dns10.quad9.net/dns-query"}, // &dnsOverHTTPS{address: "https://dns10.quad9.net/dns-query"},
// &dnsOverHTTPS{Address: "https://doh.powerdns.org"}, // &dnsOverHTTPS{address: "https://doh.powerdns.org"},
// &dnsOverHTTPS{Address: "https://doh.securedns.eu/dns-query"}, // &dnsOverHTTPS{address: "https://doh.securedns.eu/dns-query"},
//// dns over TLS //// dns over TLS
// &dnsOverTLS{Address: "tls://8.8.8.8:853"}, // &dnsOverTLS{address: "tls://8.8.8.8:853"},
// &dnsOverTLS{Address: "tls://8.8.4.4:853"}, // &dnsOverTLS{address: "tls://8.8.4.4:853"},
// &dnsOverTLS{Address: "tls://1.1.1.1:853"}, // &dnsOverTLS{address: "tls://1.1.1.1:853"},
// &dnsOverTLS{Address: "tls://1.0.0.1:853"}, // &dnsOverTLS{address: "tls://1.0.0.1:853"},
//// plainDNS //// plainDNS
&plainDNS{Address: "8.8.8.8:53"}, &plainDNS{address: "8.8.8.8:53"},
&plainDNS{Address: "8.8.4.4:53"}, &plainDNS{address: "8.8.4.4:53"},
&plainDNS{Address: "1.1.1.1:53"}, &plainDNS{address: "1.1.1.1:53"},
&plainDNS{Address: "1.0.0.1:53"}, &plainDNS{address: "1.0.0.1:53"},
}, },
} }

View File

@ -22,13 +22,14 @@ const defaultTimeout = time.Second * 10
type Upstream interface { type Upstream interface {
Exchange(m *dns.Msg) (*dns.Msg, error) Exchange(m *dns.Msg) (*dns.Msg, error)
Address() string
} }
// //
// plain DNS // plain DNS
// //
type plainDNS struct { type plainDNS struct {
Address string address string
} }
var defaultUDPClient = dns.Client{ var defaultUDPClient = dns.Client{
@ -42,11 +43,13 @@ var defaultTCPClient = dns.Client{
Timeout: defaultTimeout, Timeout: defaultTimeout,
} }
func (p *plainDNS) Address() string { return p.address }
func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
reply, _, err := defaultUDPClient.Exchange(m, p.Address) reply, _, err := defaultUDPClient.Exchange(m, p.address)
if err != nil && reply != nil && reply.Truncated { if err != nil && reply != nil && reply.Truncated {
log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String()) log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String())
reply, _, err = defaultTCPClient.Exchange(m, p.Address) reply, _, err = defaultTCPClient.Exchange(m, p.address)
} }
return reply, err return reply, err
} }
@ -55,7 +58,7 @@ func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) {
// DNS-over-TLS // DNS-over-TLS
// //
type dnsOverTLS struct { type dnsOverTLS struct {
Address string address string
pool *TLSPool pool *TLSPool
sync.RWMutex // protects pool sync.RWMutex // protects pool
@ -68,6 +71,8 @@ var defaultTLSClient = dns.Client{
TLSConfig: &tls.Config{}, TLSConfig: &tls.Config{},
} }
func (p *dnsOverTLS) Address() string { return p.address }
func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
var pool *TLSPool var pool *TLSPool
p.RLock() p.RLock()
@ -76,7 +81,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
if pool == nil { if pool == nil {
p.Lock() p.Lock()
// lazy initialize it // lazy initialize it
p.pool = &TLSPool{Address: p.Address} p.pool = &TLSPool{Address: p.address}
p.Unlock() p.Unlock()
} }
@ -84,19 +89,19 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
poolConn, err := p.pool.Get() poolConn, err := p.pool.Get()
p.RUnlock() p.RUnlock()
if err != nil { if err != nil {
return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address) return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.address)
} }
c := dns.Conn{Conn: poolConn} c := dns.Conn{Conn: poolConn}
err = c.WriteMsg(m) err = c.WriteMsg(m)
if err != nil { if err != nil {
poolConn.Close() poolConn.Close()
return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address) return nil, errorx.Decorate(err, "Failed to send a request to %s", p.address)
} }
reply, err := c.ReadMsg() reply, err := c.ReadMsg()
if err != nil { if err != nil {
poolConn.Close() poolConn.Close()
return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address) return nil, errorx.Decorate(err, "Failed to read a request from %s", p.address)
} }
p.RLock() p.RLock()
p.pool.Put(poolConn) p.pool.Put(poolConn)
@ -108,7 +113,7 @@ func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) {
// DNS-over-https // DNS-over-https
// //
type dnsOverHTTPS struct { type dnsOverHTTPS struct {
Address string address string
} }
var defaultHTTPSTransport = http.Transport{} var defaultHTTPSTransport = http.Transport{}
@ -118,33 +123,35 @@ var defaultHTTPSClient = http.Client{
Timeout: defaultTimeout, Timeout: defaultTimeout,
} }
func (p *dnsOverHTTPS) Address() string { return p.address }
func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) { func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) {
buf, err := m.Pack() buf, err := m.Pack()
if err != nil { if err != nil {
return nil, errorx.Decorate(err, "Couldn't pack request msg") return nil, errorx.Decorate(err, "Couldn't pack request msg")
} }
bb := bytes.NewBuffer(buf) bb := bytes.NewBuffer(buf)
resp, err := http.Post(p.Address, "application/dns-message", bb) resp, err := http.Post(p.address, "application/dns-message", bb)
if resp != nil && resp.Body != nil { if resp != nil && resp.Body != nil {
defer resp.Body.Close() defer resp.Body.Close()
} }
if err != nil { if err != nil {
return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.Address) return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", p.address)
} }
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.Address) return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", p.address)
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.Address) return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, p.address)
} }
if len(body) == 0 { if len(body) == 0 {
return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.Address) return nil, fmt.Errorf("Got an unexpected empty body from '%s'", p.address)
} }
response := dns.Msg{} response := dns.Msg{}
err = response.Unpack(body) err = response.Unpack(body)
if err != nil { if err != nil {
return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.Address, string(body)) return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", p.address, string(body))
} }
return &response, nil return &response, nil
} }
@ -176,20 +183,20 @@ func GetUpstream(address string) (Upstream, error) {
if url.Port() == "" { if url.Port() == "" {
url.Host += ":53" url.Host += ":53"
} }
return &plainDNS{Address: url.String()}, nil return &plainDNS{address: url.String()}, nil
case "tls": case "tls":
if url.Port() == "" { if url.Port() == "" {
url.Host += ":853" url.Host += ":853"
} }
return &dnsOverTLS{Address: url.String()}, nil return &dnsOverTLS{address: url.String()}, nil
case "https": case "https":
return &dnsOverHTTPS{Address: url.String()}, nil return &dnsOverHTTPS{address: url.String()}, nil
default: default:
// assume it's plain DNS // assume it's plain DNS
if url.Port() == "" { if url.Port() == "" {
url.Host += ":53" url.Host += ":53"
} }
return &plainDNS{Address: url.String()}, nil return &plainDNS{address: url.String()}, nil
} }
} }
@ -199,5 +206,5 @@ func GetUpstream(address string) (Upstream, error) {
// doesn't have port, default to 53 // doesn't have port, default to 53
address = net.JoinHostPort(address, "53") address = net.JoinHostPort(address, "53")
} }
return &plainDNS{Address: address}, nil return &plainDNS{address: address}, nil
} }