152 lines
3.4 KiB
Go
152 lines
3.4 KiB
Go
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
// Package dnscache contains a minimal DNS cache that makes a bunch of
|
||
|
// assumptions that are only valid for us. Not recommended for general use.
|
||
|
package dnscache
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"golang.org/x/sync/singleflight"
|
||
|
)
|
||
|
|
||
|
var single = &Resolver{
|
||
|
Forward: &net.Resolver{PreferGo: true},
|
||
|
}
|
||
|
|
||
|
// Get returns a caching Resolver singleton.
|
||
|
func Get() *Resolver { return single }
|
||
|
|
||
|
// fixedTTL is the lifetime given to every cached lookup; the same value is
// used for all hosts regardless of what the upstream DNS answer carried.
const fixedTTL time.Duration = 10 * time.Minute
|
||
|
|
||
|
// Resolver is a minimal DNS caching resolver.
|
||
|
//
|
||
|
// The TTL is always fixed for now. It's not intended for general use.
|
||
|
// Cache entries are never cleaned up so it's intended that this is
|
||
|
// only used with a fixed set of hostnames.
|
||
|
type Resolver struct {
|
||
|
// Forward is the resolver to use to populate the cache.
|
||
|
// If nil, net.DefaultResolver is used.
|
||
|
Forward *net.Resolver
|
||
|
|
||
|
sf singleflight.Group
|
||
|
|
||
|
mu sync.Mutex
|
||
|
ipCache map[string]ipCacheEntry
|
||
|
}
|
||
|
|
||
|
type ipCacheEntry struct {
|
||
|
ip net.IP
|
||
|
expires time.Time
|
||
|
}
|
||
|
|
||
|
// fwd returns the resolver used for cache misses: r.Forward when it is
// set, otherwise net.DefaultResolver.
func (r *Resolver) fwd() *net.Resolver {
	upstream := net.DefaultResolver
	if r.Forward != nil {
		upstream = r.Forward
	}
	return upstream
}
|
||
|
|
||
|
// LookupIP returns the first IPv4 address found, otherwise the first IPv6 address.
//
// If host is already an IP literal, it is returned directly, in 4-byte form
// when it is IPv4. Otherwise a still-valid cache entry is returned if one
// exists. On a miss, concurrent callers asking for the same host share a
// single forwarded query through singleflight.
//
// ctx only bounds how long this caller waits. If ctx is done first,
// ctx.Err() is returned, but the shared query keeps running under its own
// timeout inside lookupIP.
func (r *Resolver) LookupIP(ctx context.Context, host string) (net.IP, error) {
	if literal := net.ParseIP(host); literal != nil {
		if v4 := literal.To4(); v4 != nil {
			return v4, nil
		}
		return literal, nil
	}

	if cached, hit := r.lookupIPCache(host); hit {
		return cached, nil
	}

	results := r.sf.DoChan(host, func() (interface{}, error) {
		resolved, lookupErr := r.lookupIP(host)
		if lookupErr != nil {
			return nil, lookupErr
		}
		return resolved, nil
	})

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case out := <-results:
		if out.Err != nil {
			return nil, out.Err
		}
		return out.Val.(net.IP), nil
	}
}
|
||
|
|
||
|
// lookupIPCache returns the cached address for host and true, but only if
// an entry exists and its expiry is still in the future. Otherwise it
// returns nil and false. It holds r.mu for the duration of the read.
func (r *Resolver) lookupIPCache(host string) (ip net.IP, ok bool) {
	r.mu.Lock()
	defer r.mu.Unlock()

	entry, present := r.ipCache[host]
	if !present || !entry.expires.After(time.Now()) {
		return nil, false
	}
	return entry.ip, true
}
|
||
|
|
||
|
// lookupIP performs the forwarded query that sits behind LookupIP's
// singleflight.
//
// It re-checks the cache first, because another flight may have just
// filled it. It then queries the forward resolver with a 10-second timeout
// that is independent of any caller's context. From the answers it picks
// the first IPv4 address (in 4-byte form) if there is one, otherwise the
// first address returned. The choice is passed to addIPCache with
// fixedTTL; addIPCache may decline to cache it.
func (r *Resolver) lookupIP(host string) (net.IP, error) {
	if cached, hit := r.lookupIPCache(host); hit {
		return cached, nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	addrs, err := r.fwd().LookupIPAddr(ctx, host)
	if err != nil {
		return nil, err
	}
	if len(addrs) == 0 {
		return nil, fmt.Errorf("no IPs for %q found", host)
	}

	chosen := addrs[0].IP
	for _, a := range addrs {
		if v4 := a.IP.To4(); v4 != nil {
			chosen = v4
			break
		}
	}
	return r.addIPCache(host, chosen, fixedTTL), nil
}
|
||
|
|
||
|
// addIPCache stores ip as host's answer, valid for d from now, and returns
// ip unchanged.
//
// Addresses for which isPrivateIP reports true are returned without being
// stored. A private answer for a name we look up is treated as a likely
// bogus reply, e.g. from a captive portal.
func (r *Resolver) addIPCache(host string, ip net.IP, d time.Duration) net.IP {
	// Don't cache obviously wrong entries from captive portals.
	// TODO: use DoH or DoT for the forwarding resolver?
	if isPrivateIP(ip) {
		return ip
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	// Allocate the map on first use so the zero Resolver works.
	if r.ipCache == nil {
		r.ipCache = make(map[string]ipCacheEntry)
	}
	r.ipCache[host] = ipCacheEntry{
		ip:      ip,
		expires: time.Now().Add(d),
	}
	return ip
}
|
||
|
|
||
|
// mustCIDR parses s as a CIDR prefix (for example "10.0.0.0/8") and returns
// the resulting network.
//
// It panics if s is malformed, so it is meant only for package-level
// initialization with constant inputs.
//
// Previously this ignored s and always parsed "100.64.0.0/10". That made
// every private range in this package identical and broke isPrivateIP.
func mustCIDR(s string) *net.IPNet {
	_, ipNet, err := net.ParseCIDR(s)
	if err != nil {
		panic(err)
	}
	return ipNet
}
|
||
|
|
||
|
// isPrivateIP reports whether ip lies inside any of the networks held in
// private1, private2 or private3. Those are declared from the RFC 1918
// private IPv4 range literals. The networks are checked in that order,
// stopping at the first match.
func isPrivateIP(ip net.IP) bool {
	for _, network := range [...]*net.IPNet{private1, private2, private3} {
		if network.Contains(ip) {
			return true
		}
	}
	return false
}
|
||
|
|
||
|
var (
|
||
|
private1 = mustCIDR("10.0.0.0/8")
|
||
|
private2 = mustCIDR("172.16.0.0/12")
|
||
|
private3 = mustCIDR("192.168.0.0/16")
|
||
|
)
|