From 19008a3023b3a19de225b5e57633891a00098e8c Mon Sep 17 00:00:00 2001 From: Kris Brandow Date: Tue, 16 Aug 2022 14:45:46 -0400 Subject: [PATCH] net/dnscache: use net/netip Removes usage of net.IP and net.IPAddr where possible from net/dnscache. Fixes #5282 Signed-off-by: Kris Brandow --- net/dnscache/dnscache.go | 109 ++++++++++++++-------------------- net/dnscache/dnscache_test.go | 2 +- 2 files changed, 44 insertions(+), 67 deletions(-) diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index b891d0f65..2ea05a6b7 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -25,6 +25,8 @@ import ( "tailscale.com/util/singleflight" ) +var zaddr netip.Addr + var single = &Resolver{ Forward: &net.Resolver{PreferGo: preferGoResolver()}, } @@ -90,14 +92,14 @@ type Resolver struct { // ipRes is the type used by the Resolver.sf singleflight group. type ipRes struct { - ip, ip6 net.IP - allIPs []net.IPAddr + ip, ip6 netip.Addr + allIPs []netip.Addr } type ipCacheEntry struct { - ip net.IP // either v4 or v6 - ip6 net.IP // nil if no v4 or no v6 - allIPs []net.IPAddr // 1+ v4 and/or v6 + ip netip.Addr // either v4 or v6 + ip6 netip.Addr // nil if no v4 or no v6 + allIPs []netip.Addr // 1+ v4 and/or v6 expires time.Time } @@ -147,34 +149,28 @@ var debug = envknob.Bool("TS_DEBUG_DNS_CACHE") // // If err is nil, ip will be non-nil. The v6 address may be nil even // with a nil error. -func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, allIPs []net.IPAddr, err error) { +func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 netip.Addr, allIPs []netip.Addr, err error) { if r.SingleHostStaticResult != nil { if r.SingleHost != host { - return nil, nil, nil, fmt.Errorf("dnscache: unexpected hostname %q doesn't match expected %q", host, r.SingleHost) + return zaddr, zaddr, nil, fmt.Errorf("dnscache: unexpected hostname %q doesn't match expected %q", host, r.SingleHost) } for _, naIP := range r.SingleHostStaticResult { - ipa := &net.IPAddr{ - IP: naIP.AsSlice(), - Zone: naIP.Zone(), + if !ip.IsValid() && naIP.Is4() { + ip = naIP } - if ip == nil && naIP.Is4() { - ip = ipa.IP + if !v6.IsValid() && naIP.Is6() { + v6 = naIP } - if v6 == nil && naIP.Is6() { - v6 = ipa.IP - } - allIPs = append(allIPs, *ipa) + allIPs = append(allIPs, naIP) } return } - if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil { - return ip4, nil, []net.IPAddr{{IP: ip4}}, nil - } + if ip, err := netip.ParseAddr(host); err == nil { + ip = ip.Unmap() if debug { log.Printf("dnscache: %q is an IP", host) } - return ip, nil, []net.IPAddr{{IP: ip}}, nil + return ip, zaddr, []netip.Addr{ip}, nil } if ip, ip6, allIPs, ok := r.lookupIPCache(host); ok { @@ -205,7 +201,7 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, al if debug { log.Printf("dnscache: error resolving %q: %v", host, res.Err) } - return nil, nil, nil, res.Err + return zaddr, zaddr, nil, res.Err } r := res.Val return r.ip, r.ip6, r.allIPs, nil @@ -213,26 +209,26 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 net.IP, al if debug { log.Printf("dnscache: context done while resolving %q: %v", host, ctx.Err()) } - return nil, nil, nil, ctx.Err() + return zaddr, zaddr, nil, ctx.Err() } } -func (r *Resolver) lookupIPCache(host string) (ip, ip6 net.IP, allIPs []net.IPAddr, ok bool) { +func (r *Resolver) lookupIPCache(host string) (ip, ip6 netip.Addr, allIPs []netip.Addr, ok bool) { r.mu.Lock() defer r.mu.Unlock() if ent, ok := r.ipCache[host]; ok && ent.expires.After(time.Now()) { return ent.ip, ent.ip6, ent.allIPs, true } - return nil, nil, nil, false + return zaddr, zaddr, nil, false } -func (r *Resolver) lookupIPCacheExpired(host string) (ip, ip6 net.IP, allIPs []net.IPAddr, ok bool) { +func (r *Resolver) lookupIPCacheExpired(host string) (ip, ip6 netip.Addr, allIPs []netip.Addr, ok bool) { r.mu.Lock() defer r.mu.Unlock() if ent, ok := r.ipCache[host]; ok { return ent.ip, ent.ip6, ent.allIPs, true } - return nil, nil, nil, false + return zaddr, zaddr, nil, false } func (r *Resolver) lookupTimeoutForHost(host string) time.Duration { @@ -252,7 +248,7 @@ func (r *Resolver) lookupTimeoutForHost(host string) time.Duration { return 10 * time.Second } -func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, allIPs []net.IPAddr, err error) { +func (r *Resolver) lookupIP(host string) (ip, ip6 netip.Addr, allIPs []netip.Addr, err error) { if ip, ip6, allIPs, ok := r.lookupIPCache(host); ok { if debug { log.Printf("dnscache: %q found in cache as %v", host, ip) @@ -262,47 +258,37 @@ func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, allIPs []net.IPAddr, e ctx, cancel := context.WithTimeout(context.Background(), r.lookupTimeoutForHost(host)) defer cancel() - ips, err := r.fwd().LookupIPAddr(ctx, host) + ips, err := r.fwd().LookupNetIP(ctx, "ip", host) if err != nil || len(ips) == 0 { if resolver, ok := r.cloudHostResolver(); ok { - ips, err = resolver.LookupIPAddr(ctx, host) + ips, err = resolver.LookupNetIP(ctx, "ip", host) } } if (err != nil || len(ips) == 0) && r.LookupIPFallback != nil { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - var fips []netip.Addr - fips, err = r.LookupIPFallback(ctx, host) - if err == nil { - ips = nil - for _, fip := range fips { - ips = append(ips, net.IPAddr{ - IP: fip.AsSlice(), - Zone: fip.Zone(), - }) - } - } + ips, err = r.LookupIPFallback(ctx, host) } if err != nil { - return nil, nil, nil, err + return netip.Addr{}, netip.Addr{}, nil, err } if len(ips) == 0 { - return nil, nil, nil, fmt.Errorf("no IPs for %q found", host) + return netip.Addr{}, netip.Addr{}, nil, fmt.Errorf("no IPs for %q found", host) } have4 := false for _, ipa := range ips { - if ip4 := ipa.IP.To4(); ip4 != nil { + if ipa.Is4() { if !have4 { ip6 = ip - ip = ip4 + ip = ipa have4 = true } } else { if have4 { - ip6 = ipa.IP + ip6 = ipa } else { - ip = ipa.IP + ip = ipa } } } @@ -310,7 +296,7 @@ func (r *Resolver) lookupIP(host string) (ip, ip6 net.IP, allIPs []net.IPAddr, e return ip, ip6, ips, nil } -func (r *Resolver) addIPCache(host string, ip, ip6 net.IP, allIPs []net.IPAddr, d time.Duration) { +func (r *Resolver) addIPCache(host string, ip, ip6 netip.Addr, allIPs []netip.Addr, d time.Duration) { if ip.IsPrivate() { // Don't cache obviously wrong entries from captive portals. // TODO: use DoH or DoT for the forwarding resolver? @@ -399,20 +385,12 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC if debug { log.Printf("dnscache: dialing %s, %s for %s", network, ip, address) } - ipNA, ok := netip.AddrFromSlice(ip) - if !ok { - return nil, fmt.Errorf("invalid IP %q", ip) - } - c, err := dc.dialOne(ctx, ipNA.Unmap()) + c, err := dc.dialOne(ctx, ip.Unmap()) if err == nil || ctx.Err() != nil { return c, err } // Fall back to trying IPv6, if any. - ip6NA, ok := netip.AddrFromSlice(ip6) - if !ok { - return nil, err - } - return dc.dialOne(ctx, ip6NA) + return dc.dialOne(ctx, ip6) } // Multiple IPv4 candidates, and 0+ IPv6. @@ -610,21 +588,20 @@ func interleaveSlices[T any](a, b []T) []T { return ret } -func v4addrs(aa []net.IPAddr) (ret []netip.Addr) { +func v4addrs(aa []netip.Addr) (ret []netip.Addr) { for _, a := range aa { - ip, ok := netip.AddrFromSlice(a.IP) - ip = ip.Unmap() - if ok && ip.Is4() { - ret = append(ret, ip) + a = a.Unmap() + if a.Is4() { + ret = append(ret, a) } } return ret } -func v6addrs(aa []net.IPAddr) (ret []netip.Addr) { +func v6addrs(aa []netip.Addr) (ret []netip.Addr) { for _, a := range aa { - if ip, ok := netip.AddrFromSlice(a.IP); ok && ip.Is6() { - ret = append(ret, ip) + if a.Is6() && !a.Is4In6() { + ret = append(ret, a) } } return ret diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index ef09c112f..b99992148 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -131,7 +131,7 @@ func TestResolverAllHostStaticResult(t *testing.T) { if got, want := ip6.String(), "2001:4860:4860::8888"; got != want { t.Errorf("ip4 got %q; want %q", got, want) } - if got, want := fmt.Sprintf("%q", allIPs), `[{"2001:4860:4860::8888" ""} {"2001:4860:4860::8844" ""} {"8.8.8.8" ""} {"8.8.4.4" ""}]`; got != want { + if got, want := fmt.Sprintf("%q", allIPs), `["2001:4860:4860::8888" "2001:4860:4860::8844" "8.8.8.8" "8.8.4.4"]`; got != want { t.Errorf("allIPs got %q; want %q", got, want) }