net/dnscache: use net/netip

Removes usage of net.IP and net.IPAddr where possible from net/dnscache.

Fixes #5282

Signed-off-by: Kris Brandow <kris.brandow@gmail.com>
This commit is contained in:
Kris Brandow 2022-08-16 14:45:46 -04:00
parent ba3cc08b62
commit 19008a3023
2 changed files with 44 additions and 67 deletions

View File

@ -25,6 +25,8 @@ import (
"tailscale.com/util/singleflight" "tailscale.com/util/singleflight"
) )
var zaddr netip.Addr
var single = &Resolver{ var single = &Resolver{
Forward: &net.Resolver{PreferGo: preferGoResolver()}, Forward: &net.Resolver{PreferGo: preferGoResolver()},
} }
@ -90,14 +92,14 @@ type Resolver struct {
// ipRes is the type used by the Resolver.sf singleflight group. // ipRes is the type used by the Resolver.sf singleflight group.
type ipRes struct { type ipRes struct {
ip, ip6 net.IP ip, ip6 netip.Addr
allIPs []net.IPAddr allIPs []netip.Addr
} }
type ipCacheEntry struct { type ipCacheEntry struct {
ip net.IP // either v4 or v6 ip netip.Addr // either v4 or v6
ip6 net.IP // nil if no v4 or no v6 ip6 netip.Addr // nil if no v4 or no v6
allIPs []net.IPAddr // 1+ v4 and/or v6 allIPs []netip.Addr // 1+ v4 and/or v6
expires time.Time expires time.Time
} }
@ -147,34 +149,28 @@ var debug = envknob.Bool("TS_DEBUG_DNS_CACHE")
// //
// If err is nil, ip will be non-nil. The v6 address may be nil even // If err is nil, ip will be non-nil. The v6 address may be nil even
// with a nil error. // with a nil error.
func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, allIPs []net.IPAddr, err error) { func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 netip.Addr, allIPs []netip.Addr, err error) {
if r.SingleHostStaticResult != nil { if r.SingleHostStaticResult != nil {
if r.SingleHost != host { if r.SingleHost != host {
return nil, nil, nil, fmt.Errorf("dnscache: unexpected hostname %q doesn't match expected %q", host, r.SingleHost) return zaddr, zaddr, nil, fmt.Errorf("dnscache: unexpected hostname %q doesn't match expected %q", host, r.SingleHost)
} }
for _, naIP := range r.SingleHostStaticResult { for _, naIP := range r.SingleHostStaticResult {
ipa := &net.IPAddr{ if !ip.IsValid() && naIP.Is4() {
IP: naIP.AsSlice(), ip = naIP
Zone: naIP.Zone(),
} }
if ip == nil && naIP.Is4() { if !v6.IsValid() && naIP.Is6() {
ip = ipa.IP v6 = naIP
} }
if v6 == nil && naIP.Is6() { allIPs = append(allIPs, naIP)
v6 = ipa.IP
}
allIPs = append(allIPs, *ipa)
} }
return return
} }
if ip := net.ParseIP(host); ip != nil { if ip, err := netip.ParseAddr(host); err == nil {
if ip4 := ip.To4(); ip4 != nil { ip = ip.Unmap()
return ip4, nil, []net.IPAddr{{IP: ip4}}, nil
}
if debug { if debug {
log.Printf("dnscache: %q is an IP", host) log.Printf("dnscache: %q is an IP", host)
} }
return ip, nil, []net.IPAddr{{IP: ip}}, nil return ip, zaddr, []netip.Addr{ip}, nil
} }
if ip, ip6, allIPs, ok := r.lookupIPCache(host); ok { if ip, ip6, allIPs, ok := r.lookupIPCache(host); ok {
@ -205,7 +201,7 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, al
if debug { if debug {
log.Printf("dnscache: error resolving %q: %v", host, res.Err) log.Printf("dnscache: error resolving %q: %v", host, res.Err)
} }
return nil, nil, nil, res.Err return zaddr, zaddr, nil, res.Err
} }
r := res.Val r := res.Val
return r.ip, r.ip6, r.allIPs, nil return r.ip, r.ip6, r.allIPs, nil
@ -213,26 +209,26 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, al
if debug { if debug {
log.Printf("dnscache: context done while resolving %q: %v", host, ctx.Err()) log.Printf("dnscache: context done while resolving %q: %v", host, ctx.Err())
} }
return nil, nil, nil, ctx.Err() return zaddr, zaddr, nil, ctx.Err()
} }
} }
func (r *Resolver) lookupIPCache(host string) (ip, ip6 net.IP, allIPs []net.IPAddr, ok bool) { func (r *Resolver) lookupIPCache(host string) (ip, ip6 netip.Addr, allIPs []netip.Addr, ok bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) { if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) {
return ent.ip, ent.ip6, ent.allIPs, true return ent.ip, ent.ip6, ent.allIPs, true
} }
return nil, nil, nil, false return zaddr, zaddr, nil, false
} }
func (r *Resolver) lookupIPCacheExpired(host string) (ip, ip6 net.IP, allIPs []net.IPAddr, ok bool) { func (r *Resolver) lookupIPCacheExpired(host string) (ip, ip6 netip.Addr, allIPs []netip.Addr, ok bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if ent, ok := r.ipCache[host]; ok { if ent, ok := r.ipCache[host]; ok {
return ent.ip, ent.ip6, ent.allIPs, true return ent.ip, ent.ip6, ent.allIPs, true
} }
return nil, nil, nil, false return zaddr, zaddr, nil, false
} }
func (r *Resolver) lookupTimeoutForHost(host string) time.Duration { func (r *Resolver) lookupTimeoutForHost(host string) time.Duration {
@ -252,7 +248,7 @@ func (r *Resolver) lookupTimeoutForHost(host string) time.Duration {
return 10 * time.Second return 10 * time.Second
} }
func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, allIPs []net.IPAddr, err error) { func (r *Resolver) lookupIP(host string) (ip, ip6 netip.Addr, allIPs []netip.Addr, err error) {
if ip, ip6, allIPs, ok := r.lookupIPCache(host); ok { if ip, ip6, allIPs, ok := r.lookupIPCache(host); ok {
if debug { if debug {
log.Printf("dnscache: %q found in cache as %v", host, ip) log.Printf("dnscache: %q found in cache as %v", host, ip)
@ -262,47 +258,37 @@ func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, allIPs []net.IPAddr, e
ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host)) ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host))
defer cancel() defer cancel()
ips, err := r.fwd().LookupIPAddr(ctx, host) ips, err := r.fwd().LookupNetIP(ctx, "ip", host)
if err != nil || len(ips) == 0 { if err != nil || len(ips) == 0 {
if resolver, ok := r.cloudHostResolver(); ok { if resolver, ok := r.cloudHostResolver(); ok {
ips, err = resolver.LookupIPAddr(ctx, host) ips, err = resolver.LookupNetIP(ctx, "ip", host)
} }
} }
if (err != nil || len(ips) == 0) && r.LookupIPFallback != nil { if (err != nil || len(ips) == 0) && r.LookupIPFallback != nil {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel() defer cancel()
var fips []netip.Addr ips, err = r.LookupIPFallback(ctx, host)
fips, err = r.LookupIPFallback(ctx, host)
if err == nil {
ips = nil
for _, fip := range fips {
ips = append(ips, net.IPAddr{
IP: fip.AsSlice(),
Zone: fip.Zone(),
})
}
}
} }
if err != nil { if err != nil {
return nil, nil, nil, err return netip.Addr{}, netip.Addr{}, nil, err
} }
if len(ips) == 0 { if len(ips) == 0 {
return nil, nil, nil, fmt.Errorf("no IPs for %q found", host) return netip.Addr{}, netip.Addr{}, nil, fmt.Errorf("no IPs for %q found", host)
} }
have4 := false have4 := false
for _, ipa := range ips { for _, ipa := range ips {
if ip4 := ipa.IP.To4(); ip4 != nil { if ipa.Is4() {
if !have4 { if !have4 {
ip6 = ip ip6 = ip
ip = ip4 ip = ipa
have4 = true have4 = true
} }
} else { } else {
if have4 { if have4 {
ip6 = ipa.IP ip6 = ipa
} else { } else {
ip = ipa.IP ip = ipa
} }
} }
} }
@ -310,7 +296,7 @@ func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, allIPs []net.IPAddr, e
return ip, ip6, ips, nil return ip, ip6, ips, nil
} }
func (r *Resolver) addIPCache(host string, ip, ip6 net.IP, allIPs []net.IPAddr, d time.Duration) { func (r *Resolver) addIPCache(host string, ip, ip6 netip.Addr, allIPs []netip.Addr, d time.Duration) {
if ip.IsPrivate() { if ip.IsPrivate() {
// Don't cache obviously wrong entries from captive portals. // Don't cache obviously wrong entries from captive portals.
// TODO: use DoH or DoT for the forwarding resolver? // TODO: use DoH or DoT for the forwarding resolver?
@ -399,20 +385,12 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC
if debug { if debug {
log.Printf("dnscache: dialing %s, %s for %s", network, ip, address) log.Printf("dnscache: dialing %s, %s for %s", network, ip, address)
} }
ipNA, ok := netip.AddrFromSlice(ip) c, err := dc.dialOne(ctx, ip.Unmap())
if !ok {
return nil, fmt.Errorf("invalid IP %q", ip)
}
c, err := dc.dialOne(ctx, ipNA.Unmap())
if err == nil || ctx.Err() != nil { if err == nil || ctx.Err() != nil {
return c, err return c, err
} }
// Fall back to trying IPv6, if any. // Fall back to trying IPv6, if any.
ip6NA, ok := netip.AddrFromSlice(ip6) return dc.dialOne(ctx, ip6)
if !ok {
return nil, err
}
return dc.dialOne(ctx, ip6NA)
} }
// Multiple IPv4 candidates, and 0+ IPv6. // Multiple IPv4 candidates, and 0+ IPv6.
@ -610,21 +588,20 @@ func interleaveSlices[T any](a, b []T) []T {
return ret return ret
} }
func v4addrs(aa []net.IPAddr) (ret []netip.Addr) { func v4addrs(aa []netip.Addr) (ret []netip.Addr) {
for _, a := range aa { for _, a := range aa {
ip, ok := netip.AddrFromSlice(a.IP) a = a.Unmap()
ip = ip.Unmap() if a.Is4() {
if ok && ip.Is4() { ret = append(ret, a)
ret = append(ret, ip)
} }
} }
return ret return ret
} }
func v6addrs(aa []net.IPAddr) (ret []netip.Addr) { func v6addrs(aa []netip.Addr) (ret []netip.Addr) {
for _, a := range aa { for _, a := range aa {
if ip, ok := netip.AddrFromSlice(a.IP); ok && ip.Is6() { if a.Is6() && !a.Is4In6() {
ret = append(ret, ip) ret = append(ret, a)
} }
} }
return ret return ret

View File

@ -131,7 +131,7 @@ func TestResolverAllHostStaticResult(t *testing.T) {
if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { if got, want := ip6.String(), "2001:4860:4860::8888"; got != want {
t.Errorf("ip4 got %q; want %q", got, want) t.Errorf("ip4 got %q; want %q", got, want)
} }
if got, want := fmt.Sprintf("%q", allIPs), `[{"2001:4860:4860::8888" ""} {"2001:4860:4860::8844" ""} {"8.8.8.8" ""} {"8.8.4.4" ""}]`; got != want { if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want {
t.Errorf("allIPs got %q; want %q", got, want) t.Errorf("allIPs got %q; want %q", got, want)
} }