net/dnscache: add overly simplistic DNS cache package for selective use

I started to write a full DNS caching resolver and I realized it was
overkill and wouldn't work on Windows even in Go 1.14 yet, so I'm
doing this tiny one instead for now, just for all our netcheck STUN
derp lookups, and connections to DERP servers. (This will be caching a
exactly 8 DNS entries, all ours.)

Fixes #145 (can be better later, of course)
This commit is contained in:
Brad Fitzpatrick 2020-03-05 10:29:19 -08:00
parent a36ccb8525
commit 2cff9016e4
5 changed files with 191 additions and 13 deletions

View File

@ -26,6 +26,7 @@ import (
"time"
"tailscale.com/derp"
"tailscale.com/net/dnscache"
"tailscale.com/types/key"
"tailscale.com/types/logger"
)
@ -37,7 +38,8 @@ import (
// Send/Recv will completely re-establish the connection (unless Close
// has been called).
type Client struct {
TLSConfig *tls.Config // for sever connection, optional, nil means default
TLSConfig *tls.Config // for sever connection, optional, nil means default
DNSCache *dnscache.Resolver // optional; if nil, no caching
privateKey key.Private
logf logger.Logf
@ -137,11 +139,23 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
}
}()
host := c.url.Hostname()
hostOrIP := host
var d net.Dialer
log.Printf("Dialing: %q", net.JoinHostPort(c.url.Hostname(), urlPort(c.url)))
tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(c.url.Hostname(), urlPort(c.url)))
log.Printf("Dialing: %q", net.JoinHostPort(host, urlPort(c.url)))
if c.DNSCache != nil {
ip, err := c.DNSCache.LookupIP(ctx, host)
if err != nil {
return nil, err
}
hostOrIP = ip.String()
}
tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url)))
if err != nil {
return nil, err
return nil, fmt.Errorf("Dial of %q: %v", host, err)
}
// Now that we have a TCP connection, force close it.

151
net/dnscache/dnscache.go Normal file
View File

@ -0,0 +1,151 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package dnscache contains a minimal DNS cache that makes a bunch of
// assumptions that are only valid for us. Not recommended for general use.
package dnscache
import (
"context"
"fmt"
"net"
"sync"
"time"
"golang.org/x/sync/singleflight"
)
var single = &Resolver{
Forward: &net.Resolver{PreferGo: true},
}
// Get returns a caching Resolver singleton.
func Get() *Resolver { return single }
const fixedTTL = 10 * time.Minute
// Resolver is a minimal DNS caching resolver.
//
// The TTL is always fixed for now. It's not intended for general use.
// Cache entries are never cleaned up so it's intended that this is
// only used with a fixed set of hostnames.
type Resolver struct {
// Forward is the resolver to use to populate the cache.
// If nil, net.DefaultResolver is used.
Forward *net.Resolver
sf singleflight.Group
mu sync.Mutex
ipCache map[string]ipCacheEntry
}
type ipCacheEntry struct {
ip net.IP
expires time.Time
}
func (r *Resolver) fwd() *net.Resolver {
if r.Forward != nil {
return r.Forward
}
return net.DefaultResolver
}
// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address.
func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) {
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
return ip4, nil
}
return ip, nil
}
if ip, ok := r.lookupIPCache(host); ok {
return ip, nil
}
ch := r.sf.DoChan(host, func() (interface{}, error) {
ip, err := r.lookupIP(host)
if err != nil {
return nil, err
}
return ip, nil
})
select {
case res := <-ch:
if res.Err != nil {
return nil, res.Err
}
return res.Val.(net.IP), nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) {
r.mu.Lock()
defer r.mu.Unlock()
if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) {
return ent.ip, true
}
return nil, false
}
func (r *Resolver) lookupIP(host string) (net.IP, error) {
if ip, ok := r.lookupIPCache(host); ok {
return ip, nil
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
ips, err := r.fwd().LookupIPAddr(ctx, host)
if err != nil {
return nil, err
}
if len(ips) == 0 {
return nil, fmt.Errorf("no IPs for %q found", host)
}
for _, ipa := range ips {
if ip4 := ipa.IP.To4(); ip4 != nil {
return r.addIPCache(host, ip4, fixedTTL), nil
}
}
return r.addIPCache(host, ips[0].IP, fixedTTL), nil
}
func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP {
if isPrivateIP(ip) {
// Don't cache obviously wrong entries from captive portals.
// TODO: use DoH or DoT for the forwarding resolver?
return ip
}
r.mu.Lock()
defer r.mu.Unlock()
if r.ipCache == nil {
r.ipCache = make(map[string]ipCacheEntry)
}
r.ipCache[host] = ipCacheEntry{ip: ip, expires: time.Now().Add(d)}
return ip
}
func mustCIDR(s string) *net.IPNet {
_, ipNet, err := net.ParseCIDR("100.64.0.0/10")
if err != nil {
panic(err)
}
return ipNet
}
func isPrivateIP(ip net.IP) bool {
return private1.Contains(ip) || private2.Contains(ip) || private3.Contains(ip)
}
var (
private1 = mustCIDR("10.0.0.0/8")
private2 = mustCIDR("172.16.0.0/12")
private3 = mustCIDR("192.168.0.0/16")
)

View File

@ -17,6 +17,7 @@ import (
"golang.org/x/sync/errgroup"
"tailscale.com/interfaces"
"tailscale.com/net/dnscache"
"tailscale.com/stun"
"tailscale.com/stunner"
"tailscale.com/types/logger"
@ -181,6 +182,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) {
Endpoint: add,
Servers: stunServers,
Logf: logf,
DNSCache: dnscache.Get(),
}
grp.Go(func() error { return s4.Run(ctx) })
go reader(s4, pc4, unlimited)
@ -190,6 +192,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) {
Endpoint: addHair,
Servers: stunServers,
Logf: logf,
DNSCache: dnscache.Get(),
}
grp.Go(func() error { return s4Hair.Run(ctx) })
go reader(s4Hair, pc4Hair, 2)
@ -201,6 +204,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) {
Servers: stunServers6,
Logf: logf,
OnlyIPv6: true,
DNSCache: dnscache.Get(),
}
grp.Go(func() error { return s6.Run(ctx) })
go reader(s6, pc6, unlimited)

View File

@ -12,6 +12,7 @@ import (
"sync"
"time"
"tailscale.com/net/dnscache"
"tailscale.com/stun"
)
@ -38,9 +39,9 @@ type Stunner struct {
Servers []string // STUN servers to contact
// Resolver optionally specifies a resolver to use for DNS lookups.
// If nil, net.DefaultResolver is used.
Resolver *net.Resolver
// DNSCache optionally specifies a DNSCache to use.
// If nil, a DNS cache is not used.
DNSCache *dnscache.Resolver
// Logf optionally specifies a log function. If nil, logging is disabled.
Logf func(format string, args ...interface{})
@ -118,9 +119,6 @@ func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
}
func (s *Stunner) resolver() *net.Resolver {
if s.Resolver != nil {
return s.Resolver
}
return net.DefaultResolver
}
@ -192,9 +190,18 @@ func (s *Stunner) sendSTUN(ctx context.Context, server string) error {
}
addr := &net.UDPAddr{Port: addrPort}
ipAddrs, err := s.resolver().LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("lookup ip addr: %v", err)
var ipAddrs []net.IPAddr
if s.DNSCache != nil {
ip, err := s.DNSCache.LookupIP(ctx, host)
if err != nil {
return fmt.Errorf("lookup ip addr: %v", err)
}
ipAddrs = []net.IPAddr{{IP: ip}}
} else {
ipAddrs, err = s.resolver().LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("lookup ip addr: %v", err)
}
}
for _, ipAddr := range ipAddrs {
ip4 := ipAddr.IP.To4()

View File

@ -31,6 +31,7 @@ import (
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/interfaces"
"tailscale.com/net/dnscache"
"tailscale.com/netcheck"
"tailscale.com/stun"
"tailscale.com/stunner"
@ -638,6 +639,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest {
c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err)
return nil
}
dc.DNSCache = dnscache.Get()
dc.TLSConfig = c.derpTLSConfig
ctx, cancel := context.WithCancel(context.Background())