AdGuardHome/upstream/dns_upstream.go

106 lines
2.6 KiB
Go

package upstream
import (
"crypto/tls"
"time"
"github.com/miekg/dns"
"golang.org/x/net/context"
)
// DnsUpstream is a very simple upstream implementation for plain DNS
type DnsUpstream struct {
endpoint string // IP:port
timeout time.Duration // Max read and write timeout
proto string // Protocol (tcp, tcp-tls, or udp)
transport *Transport // Persistent connections cache
}
// NewDnsUpstream creates a new DNS upstream
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
u := &DnsUpstream{
endpoint: endpoint,
timeout: defaultTimeout,
proto: proto,
}
var tlsConfig *tls.Config
if proto == "tcp-tls" {
tlsConfig = new(tls.Config)
tlsConfig.ServerName = tlsServerName
}
// Initialize the connections cache
u.transport = NewTransport(endpoint)
u.transport.tlsConfig = tlsConfig
u.transport.Start()
return u, nil
}
// Exchange provides an implementation for the Upstream interface
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
resp, err := u.exchange(u.proto, query)
// Retry over TCP if response is truncated
if err == dns.ErrTruncated && u.proto == "udp" {
resp, err = u.exchange("tcp", query)
} else if err == dns.ErrTruncated && resp != nil {
// Reassemble something to be sent to client
m := new(dns.Msg)
m.SetReply(query)
m.Truncated = true
m.Authoritative = true
m.Rcode = dns.RcodeSuccess
return m, nil
}
if err != nil {
resp = &dns.Msg{}
resp.SetRcode(resp, dns.RcodeServerFailure)
}
return resp, err
}
// Clear resources
func (u *DnsUpstream) Close() error {
// Close active connections
u.transport.Stop()
return nil
}
// Performs a synchronous query. It sends the message m via the conn
// c and waits for a reply. The conn c is not closed.
func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) {
// Establish a connection if needed (or reuse cached)
conn, err := u.transport.Dial(proto)
if err != nil {
return nil, err
}
// Write the request with a timeout
conn.SetWriteDeadline(time.Now().Add(u.timeout))
if err = conn.WriteMsg(query); err != nil {
conn.Close() // Not giving it back
return nil, err
}
// Write response with a timeout
conn.SetReadDeadline(time.Now().Add(u.timeout))
r, err = conn.ReadMsg()
if err != nil {
conn.Close() // Not giving it back
} else if err == nil && r.Id != query.Id {
err = dns.ErrId
conn.Close() // Not giving it back
}
if err == nil {
// Return it back to the connections cache if there were no errors
u.transport.Yield(conn)
}
return r, err
}