467 lines
11 KiB
Go
467 lines
11 KiB
Go
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package tsdns
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"hash/crc32"
|
|
"math/rand"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"inet.af/netaddr"
|
|
"tailscale.com/logtail/backoff"
|
|
"tailscale.com/net/netns"
|
|
"tailscale.com/types/logger"
|
|
)
|
|
|
|
// Sizes used when parsing DNS messages and when fanning out requests.
const (
	// headerBytes is the size, in bytes, of the fixed DNS message header.
	headerBytes = 12

	// connCount is how many UDP connections the forwarder spreads requests across.
	connCount = 32
)

// Timing knobs for forwarded transactions.
const (
	// cleanupInterval is how often timed-out entries are purged from txMap.
	cleanupInterval = 30 * time.Second
	// responseTimeout is the longest we wait for an upstream DNS response;
	// forwarding records older than this are eligible for purging.
	responseTimeout = 5 * time.Second
)

var (
	// errNoUpstreams is returned by forward when no nameservers are configured.
	errNoUpstreams = errors.New("upstream nameservers not set")

	// aLongTimeAgo is an instant in the distant past. It is deliberately
	// non-zero: a zero deadline means "no deadline" to net.Conn, whereas a past
	// deadline makes pending I/O fail immediately.
	aLongTimeAgo = time.Unix(0, 1)
)
|
|
|
|
type forwardingRecord struct {
|
|
src netaddr.IPPort
|
|
createdAt time.Time
|
|
}
|
|
|
|
// txid identifies a DNS transaction.
//
// A DNS message ID is only 16 bits wide, so txid widens it to 64 bits. The
// low 32 bits hold the zero-extended DNS message ID. The high 32 bits hold
// the CRC32 (IEEE) checksum of the request's question section. Mixing the
// question into the key makes collisions between unrelated transactions
// negligible.
type txid uint64
|
|
|
|
// getTxID computes the txid of the given DNS packet.
|
|
func getTxID(packet []byte) txid {
|
|
if len(packet) < headerBytes {
|
|
return 0
|
|
}
|
|
|
|
dnsid := binary.BigEndian.Uint16(packet[0:2])
|
|
qcount := binary.BigEndian.Uint16(packet[4:6])
|
|
if qcount == 0 {
|
|
return txid(dnsid)
|
|
}
|
|
|
|
offset := headerBytes
|
|
for i := uint16(0); i < qcount; i++ {
|
|
// Note: this relies on the fact that names are not compressed in questions,
|
|
// so they are guaranteed to end with a NUL byte.
|
|
//
|
|
// Justification:
|
|
// RFC 1035 doesn't seem to explicitly prohibit compressing names in questions,
|
|
// but this is exceedingly unlikely to be done in practice. A DNS request
|
|
// with multiple questions is ill-defined (which questions do the header flags apply to?)
|
|
// and a single question would have to contain a pointer to an *answer*,
|
|
// which would be excessively smart, pointless (an answer can just as well refer to the question)
|
|
// and perhaps even prohibited: a draft RFC (draft-ietf-dnsind-local-compression-05) states:
|
|
//
|
|
// > It is important that these pointers always point backwards.
|
|
//
|
|
// This is said in summarizing RFC 1035, although that phrase does not appear in the original RFC.
|
|
// Additionally, (https://cr.yp.to/djbdns/notes.html) states:
|
|
//
|
|
// > The precise rule is that a name can be compressed if it is a response owner name,
|
|
// > the name in NS data, the name in CNAME data, the name in PTR data, the name in MX data,
|
|
// > or one of the names in SOA data.
|
|
namebytes := bytes.IndexByte(packet[offset:], 0)
|
|
// ... | name | NUL | type | class
|
|
// ?? 1 2 2
|
|
offset = offset + namebytes + 5
|
|
if len(packet) < offset {
|
|
// Corrupt packet; don't crash.
|
|
return txid(dnsid)
|
|
}
|
|
}
|
|
|
|
hash := crc32.ChecksumIEEE(packet[headerBytes:offset])
|
|
return (txid(hash) << 32) | txid(dnsid)
|
|
}
|
|
|
|
// forwarder forwards DNS packets to a number of upstream nameservers.
|
|
type forwarder struct {
|
|
logf logger.Logf
|
|
|
|
// responses is a channel by which responses are returned.
|
|
responses chan Packet
|
|
// closed signals all goroutines to stop.
|
|
closed chan struct{}
|
|
// wg signals when all goroutines have stopped.
|
|
wg sync.WaitGroup
|
|
|
|
// conns are the UDP connections used for forwarding.
|
|
// A random one is selected for each request, regardless of the target upstream.
|
|
conns []*fwdConn
|
|
|
|
mu sync.Mutex
|
|
// upstreams are the nameserver addresses that should be used for forwarding.
|
|
upstreams []net.Addr
|
|
// txMap maps DNS txids to active forwarding records.
|
|
txMap map[txid]forwardingRecord
|
|
}
|
|
|
|
// init seeds math/rand's global source, which forwarder.send uses to pick a
// connection for each request.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
|
|
|
|
func newForwarder(logf logger.Logf, responses chan Packet) *forwarder {
|
|
return &forwarder{
|
|
logf: logger.WithPrefix(logf, "forward: "),
|
|
responses: responses,
|
|
closed: make(chan struct{}),
|
|
conns: make([]*fwdConn, connCount),
|
|
txMap: make(map[txid]forwardingRecord),
|
|
}
|
|
}
|
|
|
|
func (f *forwarder) Start() error {
|
|
f.wg.Add(connCount + 1)
|
|
for idx := range f.conns {
|
|
f.conns[idx] = newFwdConn(f.logf, idx)
|
|
go f.recv(f.conns[idx])
|
|
}
|
|
go f.cleanMap()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (f *forwarder) Close() {
|
|
select {
|
|
case <-f.closed:
|
|
return
|
|
default:
|
|
// continue
|
|
}
|
|
close(f.closed)
|
|
|
|
for _, conn := range f.conns {
|
|
conn.close()
|
|
}
|
|
|
|
f.wg.Wait()
|
|
}
|
|
|
|
func (f *forwarder) setUpstreams(upstreams []net.Addr) {
|
|
f.mu.Lock()
|
|
f.upstreams = upstreams
|
|
f.mu.Unlock()
|
|
}
|
|
|
|
// send sends packet to dst. It is best effort.
|
|
func (f *forwarder) send(packet []byte, dst net.Addr) {
|
|
connIdx := rand.Intn(connCount)
|
|
conn := f.conns[connIdx]
|
|
conn.send(packet, dst)
|
|
}
|
|
|
|
func (f *forwarder) recv(conn *fwdConn) {
|
|
defer f.wg.Done()
|
|
|
|
for {
|
|
select {
|
|
case <-f.closed:
|
|
return
|
|
default:
|
|
}
|
|
out := make([]byte, maxResponseBytes)
|
|
n := conn.read(out)
|
|
if n == 0 {
|
|
continue
|
|
}
|
|
if n < headerBytes {
|
|
f.logf("recv: packet too small (%d bytes)", n)
|
|
}
|
|
|
|
out = out[:n]
|
|
txid := getTxID(out)
|
|
|
|
f.mu.Lock()
|
|
|
|
record, found := f.txMap[txid]
|
|
// At most one nameserver will return a response:
|
|
// the first one to do so will delete txid from the map.
|
|
if !found {
|
|
f.mu.Unlock()
|
|
continue
|
|
}
|
|
delete(f.txMap, txid)
|
|
|
|
f.mu.Unlock()
|
|
|
|
packet := Packet{
|
|
Payload: out,
|
|
Addr: record.src,
|
|
}
|
|
select {
|
|
case <-f.closed:
|
|
return
|
|
case f.responses <- packet:
|
|
// continue
|
|
}
|
|
}
|
|
}
|
|
|
|
// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth.
|
|
func (f *forwarder) cleanMap() {
|
|
defer f.wg.Done()
|
|
|
|
t := time.NewTicker(cleanupInterval)
|
|
defer t.Stop()
|
|
|
|
var now time.Time
|
|
for {
|
|
select {
|
|
case <-f.closed:
|
|
return
|
|
case now = <-t.C:
|
|
// continue
|
|
}
|
|
|
|
f.mu.Lock()
|
|
for k, v := range f.txMap {
|
|
if now.Sub(v.createdAt) > responseTimeout {
|
|
delete(f.txMap, k)
|
|
}
|
|
}
|
|
f.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// forward forwards the query to all upstream nameservers and returns the first response.
|
|
func (f *forwarder) forward(query Packet) error {
|
|
txid := getTxID(query.Payload)
|
|
|
|
f.mu.Lock()
|
|
|
|
upstreams := f.upstreams
|
|
if len(upstreams) == 0 {
|
|
f.mu.Unlock()
|
|
return errNoUpstreams
|
|
}
|
|
f.txMap[txid] = forwardingRecord{
|
|
src: query.Addr,
|
|
createdAt: time.Now(),
|
|
}
|
|
|
|
f.mu.Unlock()
|
|
|
|
for _, upstream := range upstreams {
|
|
f.send(query.Payload, upstream)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// A fwdConn manages a single connection used to forward DNS requests.
|
|
// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS.
|
|
// fwdConn detects such situations and transparently creates new connections.
|
|
type fwdConn struct {
|
|
// logf allows a fwdConn to log.
|
|
logf logger.Logf
|
|
|
|
// wg tracks the number of outstanding conn.Read and conn.Write calls.
|
|
wg sync.WaitGroup
|
|
// change allows calls to read to block until a the network connection has been replaced.
|
|
change *sync.Cond
|
|
|
|
// mu protects fields that follow it; it is also change's Locker.
|
|
mu sync.Mutex
|
|
// closed tracks whether fwdConn has been permanently closed.
|
|
closed bool
|
|
// conn is the current active connection.
|
|
conn net.PacketConn
|
|
}
|
|
|
|
func newFwdConn(logf logger.Logf, idx int) *fwdConn {
|
|
c := new(fwdConn)
|
|
c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx))
|
|
c.change = sync.NewCond(&c.mu)
|
|
// c.conn is created lazily in send
|
|
return c
|
|
}
|
|
|
|
// send sends packet to dst using c's connection.
|
|
// It is best effort. It is UDP, after all. Failures are logged.
|
|
func (c *fwdConn) send(packet []byte, dst net.Addr) {
|
|
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
|
|
backOff := func(err error) {
|
|
if b == nil {
|
|
b = backoff.NewBackoff("tsdns-fwdConn-send", c.logf, 30*time.Second)
|
|
}
|
|
b.BackOff(context.Background(), err)
|
|
}
|
|
|
|
for {
|
|
// Gather the current connection.
|
|
// We can't hold the lock while we call WriteTo.
|
|
c.mu.Lock()
|
|
conn := c.conn
|
|
closed := c.closed
|
|
if closed {
|
|
c.mu.Unlock()
|
|
return
|
|
}
|
|
if conn == nil {
|
|
c.reconnectLocked()
|
|
c.mu.Unlock()
|
|
continue
|
|
}
|
|
c.mu.Unlock()
|
|
|
|
c.wg.Add(1)
|
|
_, err := conn.WriteTo(packet, dst)
|
|
c.wg.Done()
|
|
if err == nil {
|
|
// Success
|
|
return
|
|
}
|
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
|
// We intentionally closed this connection.
|
|
// It has been replaced by a new connection. Try again.
|
|
continue
|
|
}
|
|
// Something else went wrong.
|
|
// We have three choices here: try again, give up, or create a new connection.
|
|
var opErr *net.OpError
|
|
if !errors.As(err, &opErr) {
|
|
// Weird. All errors from the net package should be *net.OpError. Bail.
|
|
c.logf("send: non-*net.OpErr %v (%T)", err, err)
|
|
return
|
|
}
|
|
if opErr.Temporary() || opErr.Timeout() {
|
|
// I doubt that either of these can happen (this is UDP),
|
|
// but go ahead and try again.
|
|
backOff(err)
|
|
continue
|
|
}
|
|
if networkIsDown(err) {
|
|
// Fail.
|
|
c.logf("send: network is down")
|
|
return
|
|
}
|
|
if networkIsUnreachable(err) {
|
|
// This can be caused by a link change.
|
|
// Replace the existing connection with a new one.
|
|
c.mu.Lock()
|
|
// It's possible that multiple senders discovered simultaneously
|
|
// that the network is unreachable. Avoid reconnecting multiple times:
|
|
// Only reconnect if the current connection is the one that we
|
|
// discovered to be problematic.
|
|
if c.conn == conn {
|
|
backOff(err)
|
|
c.reconnectLocked()
|
|
}
|
|
c.mu.Unlock()
|
|
// Try again with our new network connection.
|
|
continue
|
|
}
|
|
// Unrecognized error. Fail.
|
|
c.logf("send: unrecognized error: %v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// read waits for a response from c's connection.
|
|
// It returns the number of bytes read, which may be 0
|
|
// in case of an error or a closed connection.
|
|
func (c *fwdConn) read(out []byte) int {
|
|
for {
|
|
// Gather the current connection.
|
|
// We can't hold the lock while we call ReadFrom.
|
|
c.mu.Lock()
|
|
conn := c.conn
|
|
closed := c.closed
|
|
if closed {
|
|
c.mu.Unlock()
|
|
return 0
|
|
}
|
|
if conn == nil {
|
|
// There is no current connection.
|
|
// Wait for the connection to change, then try again.
|
|
c.change.Wait()
|
|
c.mu.Unlock()
|
|
continue
|
|
}
|
|
c.mu.Unlock()
|
|
|
|
c.wg.Add(1)
|
|
n, _, err := conn.ReadFrom(out)
|
|
c.wg.Done()
|
|
if err == nil {
|
|
// Success.
|
|
return n
|
|
}
|
|
if errors.Is(err, os.ErrDeadlineExceeded) {
|
|
// We intentionally closed this connection.
|
|
// It has been replaced by a new connection. Try again.
|
|
continue
|
|
}
|
|
|
|
c.logf("read: unrecognized error: %v", err)
|
|
return 0
|
|
}
|
|
}
|
|
|
|
// reconnectLocked replaces the current connection with a new one.
|
|
// c.mu must be locked.
|
|
func (c *fwdConn) reconnectLocked() {
|
|
c.closeConnLocked()
|
|
// Make a new connection.
|
|
conn, err := netns.Listener().ListenPacket(context.Background(), "udp", "")
|
|
if err != nil {
|
|
c.logf("ListenPacket failed: %v", err)
|
|
} else {
|
|
c.conn = conn
|
|
}
|
|
// Broadcast that a new connection is available.
|
|
c.change.Broadcast()
|
|
}
|
|
|
|
// closeCurrentConn closes the current connection.
|
|
// c.mu must be locked.
|
|
func (c *fwdConn) closeConnLocked() {
|
|
if c.conn == nil {
|
|
return
|
|
}
|
|
// Unblock all readers/writers, wait for them, close the connection.
|
|
c.conn.SetDeadline(aLongTimeAgo)
|
|
c.wg.Wait()
|
|
c.conn.Close()
|
|
c.conn = nil
|
|
}
|
|
|
|
// close permanently closes c.
|
|
func (c *fwdConn) close() {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
if c.closed {
|
|
return
|
|
}
|
|
c.closed = true
|
|
c.closeConnLocked()
|
|
// Unblock any remaining readers.
|
|
c.change.Broadcast()
|
|
}
|