diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index b0537e4a3..1ac4cdc6b 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -127,7 +127,7 @@ func (c *Client) vlogf(format string, a ...interface{}) { // handleHairSTUN reports whether pkt (from src) was our magic hairpin // probe packet that we sent to ourselves. -func (c *Client) handleHairSTUNLocked(pkt []byte, src *net.UDPAddr) bool { +func (c *Client) handleHairSTUNLocked(pkt []byte, src netaddr.IPPort) bool { rs := c.curState if rs == nil { return false @@ -150,11 +150,7 @@ func (c *Client) MakeNextReportFull() { c.mu.Unlock() } -func (c *Client) ReceiveSTUNPacket(pkt []byte, src *net.UDPAddr) { - if src == nil || src.IP == nil { - panic("bogus src") - } - +func (c *Client) ReceiveSTUNPacket(pkt []byte, src netaddr.IPPort) { c.mu.Lock() if c.handleHairSTUNLocked(pkt, src) { c.mu.Unlock() @@ -421,7 +417,9 @@ func (c *Client) readPackets(ctx context.Context, pc net.PacketConn) { if !stun.Is(pkt) { continue } - c.ReceiveSTUNPacket(pkt, ua) + if ipp, ok := netaddr.FromStdAddr(ua.IP, ua.Port, ua.Zone); ok { + c.ReceiveSTUNPacket(pkt, ipp) + } } } @@ -429,7 +427,7 @@ func (c *Client) readPackets(ctx context.Context, pc net.PacketConn) { type reportState struct { c *Client hairTX stun.TxID - gotHairSTUN chan *net.UDPAddr + gotHairSTUN chan netaddr.IPPort hairTimeout chan struct{} // closed on timeout pc4 STUNConn pc6 STUNConn @@ -638,7 +636,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (*Report, e report: newReport(), inFlight: map[stun.TxID]func(netaddr.IPPort){}, hairTX: stun.NewTxID(), // random payload - gotHairSTUN: make(chan *net.UDPAddr, 1), + gotHairSTUN: make(chan netaddr.IPPort, 1), hairTimeout: make(chan struct{}), stopProbeCh: make(chan struct{}, 1), } diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 7b2154b77..34a12c19a 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -16,6 +16,7 @@ import ( "testing" "time" + "inet.af/netaddr" "tailscale.com/net/interfaces" "tailscale.com/net/stun" "tailscale.com/net/stun/stuntest" @@ -27,14 +28,14 @@ func TestHairpinSTUN(t *testing.T) { c := &Client{ curState: &reportState{ hairTX: tx, - gotHairSTUN: make(chan *net.UDPAddr, 1), + gotHairSTUN: make(chan netaddr.IPPort, 1), }, } req := stun.Request(tx) if !stun.Is(req) { t.Fatal("expected STUN message") } - if !c.handleHairSTUNLocked(req, nil) { + if !c.handleHairSTUNLocked(req, netaddr.IPPort{}) { t.Fatal("expected true") } select { diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index cd4635d7a..a8003b4ae 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -310,7 +310,7 @@ func (c *Conn) donec() <-chan struct{} { return c.connCtx.Done() } // ignoreSTUNPackets sets a STUN packet processing func that does nothing. func (c *Conn) ignoreSTUNPackets() { - c.stunReceiveFunc.Store(func([]byte, *net.UDPAddr) {}) + c.stunReceiveFunc.Store(func([]byte, netaddr.IPPort) {}) } // c.mu must NOT be held. @@ -1198,11 +1198,15 @@ func (c *Conn) awaitUDP4(b []byte) { return } addr := pAddr.(*net.UDPAddr) - if stun.Is(b[:n]) { - c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b[:n], addr) + ipp, ok := netaddr.FromStdAddr(addr.IP, addr.Port, addr.Zone) + if !ok { continue } - if c.handleDiscoMessage(b[:n], addr) { + if stun.Is(b[:n]) { + c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b[:n], ipp) + continue + } + if c.handleDiscoMessage(b[:n], ipp) { continue } @@ -1276,7 +1280,7 @@ Top: } addr := netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(regionID)} - if c.handleDiscoMessage(b[:n], addr.UDPAddr()) { + if c.handleDiscoMessage(b[:n], addr) { goto Top } @@ -1334,11 +1338,15 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) { return 0, nil, nil, err } addr := pAddr.(*net.UDPAddr) - if stun.Is(b[:n]) { - c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b[:n], addr) + ipp, ok := netaddr.FromStdAddr(addr.IP, addr.Port, addr.Zone) + if !ok { continue } - if c.handleDiscoMessage(b[:n], addr) { + if stun.Is(b[:n]) { + c.stunReceiveFunc.Load().(func([]byte, netaddr.IPPort))(b[:n], ipp) + continue + } + if c.handleDiscoMessage(b[:n], ipp) { continue } @@ -1359,7 +1367,7 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) { // // For messages received over DERP, the addr will be derpMagicIP (with // port being the region) -func (c *Conn) handleDiscoMessage(msg []byte, src *net.UDPAddr) bool { +func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool { const magic = "TS💬" const nonceLen = 24 const headerLen = len(magic) + len(tailcfg.DiscoKey{}) + nonceLen @@ -1369,11 +1377,6 @@ func (c *Conn) handleDiscoMessage(msg []byte, src *net.UDPAddr) bool { var sender tailcfg.DiscoKey copy(sender[:], msg[len(magic):]) - srca, ok := netaddr.FromStdAddr(src.IP, src.Port, src.Zone) - if !ok { - return false - } - c.mu.Lock() defer c.mu.Unlock() @@ -1421,11 +1424,11 @@ func (c *Conn) handleDiscoMessage(msg []byte, src *net.UDPAddr) bool { switch dm := dm.(type) { case *disco.Ping: - c.handlePingLocked(dm, senderNode, sender, srca) + c.handlePingLocked(dm, senderNode, sender, src) case *disco.Pong: - c.handlePongLocked(dm, senderNode, sender, srca) + c.handlePongLocked(dm, senderNode, sender, src) case disco.CallMeMaybe: - if srca.IP != derpMagicIPAddr { + if src.IP != derpMagicIPAddr { // CallMeMaybe messages should only come via DERP. c.logf("[unexpected] CallMeMaybe packets should only come via DERP") return true diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 1b8693b75..0524d5d02 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -873,7 +873,7 @@ func TestDiscoMessage(t *testing.T) { pkt = append(pkt, nonce[:]...) pkt = box.Seal(pkt, []byte(payload), &nonce, c.discoPrivate.Public().B32(), peer1Priv.B32()) - got := c.handleDiscoMessage(pkt, &net.UDPAddr{IP: net.ParseIP("1.2.3.4")}) + got := c.handleDiscoMessage(pkt, netaddr.IPPort{}) if !got { t.Error("failed to open it") }