net/dnscache: add overly simplistic DNS cache package for selective use
I started to write a full DNS caching resolver and I realized it was overkill and wouldn't work on Windows even in Go 1.14 yet, so I'm doing this tiny one instead for now, just for all our netcheck STUN derp lookups, and connections to DERP servers. (This will be caching a exactly 8 DNS entries, all ours.) Fixes #145 (can be better later, of course)
This commit is contained in:
parent
a36ccb8525
commit
2cff9016e4
|
@ -26,6 +26,7 @@ import (
|
|||
"time"
|
||||
|
||||
"tailscale.com/derp"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
@ -37,7 +38,8 @@ import (
|
|||
// Send/Recv will completely re-establish the connection (unless Close
|
||||
// has been called).
|
||||
type Client struct {
|
||||
TLSConfig *tls.Config // for sever connection, optional, nil means default
|
||||
TLSConfig *tls.Config // for sever connection, optional, nil means default
|
||||
DNSCache *dnscache.Resolver // optional; if nil, no caching
|
||||
|
||||
privateKey key.Private
|
||||
logf logger.Logf
|
||||
|
@ -137,11 +139,23 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
|
|||
}
|
||||
}()
|
||||
|
||||
host := c.url.Hostname()
|
||||
hostOrIP := host
|
||||
|
||||
var d net.Dialer
|
||||
log.Printf("Dialing: %q", net.JoinHostPort(c.url.Hostname(), urlPort(c.url)))
|
||||
tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(c.url.Hostname(), urlPort(c.url)))
|
||||
log.Printf("Dialing: %q", net.JoinHostPort(host, urlPort(c.url)))
|
||||
|
||||
if c.DNSCache != nil {
|
||||
ip, err := c.DNSCache.LookupIP(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hostOrIP = ip.String()
|
||||
}
|
||||
|
||||
tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(hostOrIP, urlPort(c.url)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Dial of %q: %v", host, err)
|
||||
}
|
||||
|
||||
// Now that we have a TCP connection, force close it.
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package dnscache contains a minimal DNS cache that makes a bunch of
|
||||
// assumptions that are only valid for us. Not recommended for general use.
|
||||
package dnscache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var single = &Resolver{
|
||||
Forward: &net.Resolver{PreferGo: true},
|
||||
}
|
||||
|
||||
// Get returns a caching Resolver singleton.
|
||||
func Get() *Resolver { return single }
|
||||
|
||||
const fixedTTL = 10 * time.Minute
|
||||
|
||||
// Resolver is a minimal DNS caching resolver.
|
||||
//
|
||||
// The TTL is always fixed for now. It's not intended for general use.
|
||||
// Cache entries are never cleaned up so it's intended that this is
|
||||
// only used with a fixed set of hostnames.
|
||||
type Resolver struct {
|
||||
// Forward is the resolver to use to populate the cache.
|
||||
// If nil, net.DefaultResolver is used.
|
||||
Forward *net.Resolver
|
||||
|
||||
sf singleflight.Group
|
||||
|
||||
mu sync.Mutex
|
||||
ipCache map[string]ipCacheEntry
|
||||
}
|
||||
|
||||
type ipCacheEntry struct {
|
||||
ip net.IP
|
||||
expires time.Time
|
||||
}
|
||||
|
||||
func (r *Resolver) fwd() *net.Resolver {
|
||||
if r.Forward != nil {
|
||||
return r.Forward
|
||||
}
|
||||
return net.DefaultResolver
|
||||
}
|
||||
|
||||
// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address.
|
||||
func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) {
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return ip4, nil
|
||||
}
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
if ip, ok := r.lookupIPCache(host); ok {
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
ch := r.sf.DoChan(host, func() (interface{}, error) {
|
||||
ip, err := r.lookupIP(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ip, nil
|
||||
})
|
||||
select {
|
||||
case res := <-ch:
|
||||
if res.Err != nil {
|
||||
return nil, res.Err
|
||||
}
|
||||
return res.Val.(net.IP), nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) {
|
||||
return ent.ip, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r *Resolver) lookupIP(host string) (net.IP, error) {
|
||||
if ip, ok := r.lookupIPCache(host); ok {
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
ips, err := r.fwd().LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("no IPs for %q found", host)
|
||||
}
|
||||
|
||||
for _, ipa := range ips {
|
||||
if ip4 := ipa.IP.To4(); ip4 != nil {
|
||||
return r.addIPCache(host, ip4, fixedTTL), nil
|
||||
}
|
||||
}
|
||||
return r.addIPCache(host, ips[0].IP, fixedTTL), nil
|
||||
}
|
||||
|
||||
func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP {
|
||||
if isPrivateIP(ip) {
|
||||
// Don't cache obviously wrong entries from captive portals.
|
||||
// TODO: use DoH or DoT for the forwarding resolver?
|
||||
return ip
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.ipCache == nil {
|
||||
r.ipCache = make(map[string]ipCacheEntry)
|
||||
}
|
||||
r.ipCache[host] = ipCacheEntry{ip: ip, expires: time.Now().Add(d)}
|
||||
return ip
|
||||
}
|
||||
|
||||
func mustCIDR(s string) *net.IPNet {
|
||||
_, ipNet, err := net.ParseCIDR("100.64.0.0/10")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ipNet
|
||||
}
|
||||
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
return private1.Contains(ip) || private2.Contains(ip) || private3.Contains(ip)
|
||||
}
|
||||
|
||||
var (
|
||||
private1 = mustCIDR("10.0.0.0/8")
|
||||
private2 = mustCIDR("172.16.0.0/12")
|
||||
private3 = mustCIDR("192.168.0.0/16")
|
||||
)
|
|
@ -17,6 +17,7 @@ import (
|
|||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"tailscale.com/interfaces"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/stun"
|
||||
"tailscale.com/stunner"
|
||||
"tailscale.com/types/logger"
|
||||
|
@ -181,6 +182,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) {
|
|||
Endpoint: add,
|
||||
Servers: stunServers,
|
||||
Logf: logf,
|
||||
DNSCache: dnscache.Get(),
|
||||
}
|
||||
grp.Go(func() error { return s4.Run(ctx) })
|
||||
go reader(s4, pc4, unlimited)
|
||||
|
@ -190,6 +192,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) {
|
|||
Endpoint: addHair,
|
||||
Servers: stunServers,
|
||||
Logf: logf,
|
||||
DNSCache: dnscache.Get(),
|
||||
}
|
||||
grp.Go(func() error { return s4Hair.Run(ctx) })
|
||||
go reader(s4Hair, pc4Hair, 2)
|
||||
|
@ -201,6 +204,7 @@ func GetReport(ctx context.Context, logf logger.Logf) (*Report, error) {
|
|||
Servers: stunServers6,
|
||||
Logf: logf,
|
||||
OnlyIPv6: true,
|
||||
DNSCache: dnscache.Get(),
|
||||
}
|
||||
grp.Go(func() error { return s6.Run(ctx) })
|
||||
go reader(s6, pc6, unlimited)
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/stun"
|
||||
)
|
||||
|
||||
|
@ -38,9 +39,9 @@ type Stunner struct {
|
|||
|
||||
Servers []string // STUN servers to contact
|
||||
|
||||
// Resolver optionally specifies a resolver to use for DNS lookups.
|
||||
// If nil, net.DefaultResolver is used.
|
||||
Resolver *net.Resolver
|
||||
// DNSCache optionally specifies a DNSCache to use.
|
||||
// If nil, a DNS cache is not used.
|
||||
DNSCache *dnscache.Resolver
|
||||
|
||||
// Logf optionally specifies a log function. If nil, logging is disabled.
|
||||
Logf func(format string, args ...interface{})
|
||||
|
@ -118,9 +119,6 @@ func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
|
|||
}
|
||||
|
||||
func (s *Stunner) resolver() *net.Resolver {
|
||||
if s.Resolver != nil {
|
||||
return s.Resolver
|
||||
}
|
||||
return net.DefaultResolver
|
||||
}
|
||||
|
||||
|
@ -192,9 +190,18 @@ func (s *Stunner) sendSTUN(ctx context.Context, server string) error {
|
|||
}
|
||||
addr := &net.UDPAddr{Port: addrPort}
|
||||
|
||||
ipAddrs, err := s.resolver().LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup ip addr: %v", err)
|
||||
var ipAddrs []net.IPAddr
|
||||
if s.DNSCache != nil {
|
||||
ip, err := s.DNSCache.LookupIP(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup ip addr: %v", err)
|
||||
}
|
||||
ipAddrs = []net.IPAddr{{IP: ip}}
|
||||
} else {
|
||||
ipAddrs, err = s.resolver().LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("lookup ip addr: %v", err)
|
||||
}
|
||||
}
|
||||
for _, ipAddr := range ipAddrs {
|
||||
ip4 := ipAddr.IP.To4()
|
||||
|
|
|
@ -31,6 +31,7 @@ import (
|
|||
"tailscale.com/derp"
|
||||
"tailscale.com/derp/derphttp"
|
||||
"tailscale.com/interfaces"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/netcheck"
|
||||
"tailscale.com/stun"
|
||||
"tailscale.com/stunner"
|
||||
|
@ -638,6 +639,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest {
|
|||
c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err)
|
||||
return nil
|
||||
}
|
||||
dc.DNSCache = dnscache.Get()
|
||||
dc.TLSConfig = c.derpTLSConfig
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
|
Loading…
Reference in New Issue