diff --git a/net/portmapper/igd_test.go b/net/portmapper/igd_test.go index 2defb21f6..3268aee6d 100644 --- a/net/portmapper/igd_test.go +++ b/net/portmapper/igd_test.go @@ -49,10 +49,11 @@ func NewTestIGD() (*TestIGD, error) { doUPnP: true, } var err error - if d.upnpConn, err = net.ListenPacket("udp", "127.0.0.1:1900"); err != nil { + if d.upnpConn, err = testListenUDP(); err != nil { return nil, err } - if d.pxpConn, err = net.ListenPacket("udp", "127.0.0.1:5351"); err != nil { + if d.pxpConn, err = testListenUDP(); err != nil { + d.upnpConn.Close() return nil, err } d.ts = httptest.NewServer(http.HandlerFunc(d.serveUPnPHTTP)) @@ -61,6 +62,18 @@ func NewTestIGD() (*TestIGD, error) { return d, nil } +func testListenUDP() (net.PacketConn, error) { + return net.ListenPacket("udp4", "127.0.0.1:0") +} + +func (d *TestIGD) TestPxPPort() uint16 { + return uint16(d.pxpConn.LocalAddr().(*net.UDPAddr).Port) +} + +func (d *TestIGD) TestUPnPPort() uint16 { + return uint16(d.upnpConn.LocalAddr().(*net.UDPAddr).Port) +} + func (d *TestIGD) Close() error { d.ts.Close() d.upnpConn.Close() diff --git a/net/portmapper/pcp.go b/net/portmapper/pcp.go index 8fb8e1384..9549d5e0e 100644 --- a/net/portmapper/pcp.go +++ b/net/portmapper/pcp.go @@ -12,7 +12,6 @@ import ( "time" "inet.af/netaddr" - "tailscale.com/net/netns" ) // References: @@ -22,8 +21,8 @@ import ( // PCP constants const ( - pcpVersion = 2 - pcpPort = 5351 + pcpVersion = 2 + pcpDefaultPort = 5351 pcpMapLifetimeSec = 7200 // TODO does the RFC recommend anything? This is taken from PMP. @@ -39,7 +38,8 @@ const ( ) type pcpMapping struct { - gw netaddr.IP + c *Client + gw netaddr.IPPort internal netaddr.IPPort external netaddr.IPPort @@ -54,13 +54,13 @@ func (p *pcpMapping) GoodUntil() time.Time { return p.goodUntil } func (p *pcpMapping) RenewAfter() time.Time { return p.renewAfter } func (p *pcpMapping) External() netaddr.IPPort { return p.external } func (p *pcpMapping) Release(ctx context.Context) { - uc, err := netns.Listener().ListenPacket(ctx, "udp4", ":0") + uc, err := p.c.listenPacket(ctx, "udp4", ":0") if err != nil { return } defer uc.Close() pkt := buildPCPRequestMappingPacket(p.internal.IP(), p.internal.Port(), p.external.Port(), 0, p.external.IP()) - uc.WriteTo(pkt, netaddr.IPPortFrom(p.gw, pcpPort).UDPAddr()) + uc.WriteTo(pkt, p.gw.UDPAddr()) } // buildPCPRequestMappingPacket generates a PCP packet with a MAP opcode. @@ -95,6 +95,8 @@ func buildPCPRequestMappingPacket( return pkt } +// parsePCPMapResponse parses resp into a partially populated pcpMapping. +// In particular, its Client is not populated. func parsePCPMapResponse(resp []byte) (*pcpMapping, error) { if len(resp) < 60 { return nil, fmt.Errorf("Does not appear to be PCP MAP response") diff --git a/net/portmapper/portmapper.go b/net/portmapper/portmapper.go index 33fb805ac..e4ca98d83 100644 --- a/net/portmapper/portmapper.go +++ b/net/portmapper/portmapper.go @@ -14,6 +14,7 @@ import ( "io" "net" "net/http" + "os" "sync" "time" @@ -55,6 +56,8 @@ type Client struct { logf logger.Logf ipAndGateway func() (gw, ip netaddr.IP, ok bool) onChange func() // or nil + testPxPPort uint16 // if non-zero, pxpPort to use for tests + testUPnPPort uint16 // if non-zero, uPnPPort to use for tests mu sync.Mutex // guards following, and all fields thereof @@ -113,7 +116,8 @@ func (c *Client) HaveMapping() bool { // // All fields are immutable once created. type pmpMapping struct { - gw netaddr.IP + c *Client + gw netaddr.IPPort external netaddr.IPPort internal netaddr.IPPort renewAfter time.Time // the time at which we want to renew the mapping @@ -132,13 +136,13 @@ func (p *pmpMapping) External() netaddr.IPPort { return p.external } // Release does a best effort fire-and-forget release of the PMP mapping m. func (m *pmpMapping) Release(ctx context.Context) { - uc, err := netns.Listener().ListenPacket(ctx, "udp4", ":0") + uc, err := m.c.listenPacket(ctx, "udp4", ":0") if err != nil { return } defer uc.Close() pkt := buildPMPRequestMappingPacket(m.internal.Port(), m.external.Port(), pmpMapLifetimeDelete) - uc.WriteTo(pkt, netaddr.IPPortFrom(m.gw, pmpPort).UDPAddr()) + uc.WriteTo(pkt, m.gw.UDPAddr()) } // NewClient returns a new portmapping client. @@ -213,6 +217,32 @@ func (c *Client) gatewayAndSelfIP() (gw, myIP netaddr.IP, ok bool) { return } +// pxpPort returns the NAT-PMP and PCP port number. +// It returns 5351, except for in tests where it varies by run. +func (c *Client) pxpPort() uint16 { + if c.testPxPPort != 0 { + return c.testPxPPort + } + return pmpDefaultPort +} + +// upnpPort returns the UPnP discovery port number. +// It returns 1900, except for in tests where it varies by run. +func (c *Client) upnpPort() uint16 { + if c.testUPnPPort != 0 { + return c.testUPnPPort + } + return upnpDefaultPort +} + +func (c *Client) listenPacket(ctx context.Context, network, addr string) (net.PacketConn, error) { + if (c.testPxPPort != 0 || c.testUPnPPort != 0) && os.Getenv("GITHUB_ACTIONS") == "true" { + var lc net.ListenConfig + return lc.ListenPacket(ctx, network, addr) + } + return netns.Listener().ListenPacket(ctx, network, addr) +} + func (c *Client) invalidateMappingsLocked(releaseOld bool) { if c.mapping != nil { if releaseOld { @@ -399,7 +429,8 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPor // PCP returns all the information necessary for a mapping in a single packet, so we can // construct it upon receiving that packet. m := &pmpMapping{ - gw: gw, + c: c, + gw: netaddr.IPPortFrom(gw, c.pxpPort()), internal: internalAddr, } if haveRecentPMP { @@ -415,7 +446,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPor } c.mu.Unlock() - uc, err := netns.Listener().ListenPacket(ctx, "udp4", ":0") + uc, err := c.listenPacket(ctx, "udp4", ":0") if err != nil { return netaddr.IPPort{}, err } @@ -424,7 +455,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPor uc.SetReadDeadline(time.Now().Add(portMapServiceTimeout)) defer closeCloserOnContextDone(ctx, uc)() - pxpAddr := netaddr.IPPortFrom(gw, pmpPort) + pxpAddr := netaddr.IPPortFrom(gw, c.pxpPort()) pxpAddru := pxpAddr.UDPAddr() preferPCP := !DisablePCP && (DisablePMP || (!haveRecentPMP && haveRecentPCP)) @@ -499,8 +530,9 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netaddr.IPPor // PCP should only have a single packet response return netaddr.IPPort{}, NoMappingError{ErrNoPortMappingServices} } + pcpMapping.c = c pcpMapping.internal = m.internal - pcpMapping.gw = gw + pcpMapping.gw = netaddr.IPPortFrom(gw, c.pxpPort()) c.mu.Lock() defer c.mu.Unlock() c.mapping = pcpMapping @@ -524,7 +556,7 @@ type pmpResultCode uint16 // NAT-PMP constants. const ( - pmpPort = 5351 + pmpDefaultPort = 5351 pmpMapLifetimeSec = 7200 // RFC recommended 2 hour map duration pmpMapLifetimeDelete = 0 // 0 second lifetime deletes @@ -622,7 +654,7 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { } }() - uc, err := netns.Listener().ListenPacket(context.Background(), "udp4", ":0") + uc, err := c.listenPacket(context.Background(), "udp4", ":0") if err != nil { c.logf("ProbePCP: %v", err) return res, err @@ -632,9 +664,8 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { defer cancel() defer closeCloserOnContextDone(ctx, uc)() - pcpAddr := netaddr.IPPortFrom(gw, pcpPort).UDPAddr() - pmpAddr := netaddr.IPPortFrom(gw, pmpPort).UDPAddr() - upnpAddr := netaddr.IPPortFrom(gw, upnpPort).UDPAddr() + pxpAddr := netaddr.IPPortFrom(gw, c.pxpPort()).UDPAddr() + upnpAddr := netaddr.IPPortFrom(gw, c.upnpPort()).UDPAddr() // Don't send probes to services that we recently learned (for // the same gw/myIP) are available. See @@ -642,12 +673,12 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { if c.sawPMPRecently() { res.PMP = true } else if !DisablePMP { - uc.WriteTo(pmpReqExternalAddrPacket, pmpAddr) + uc.WriteTo(pmpReqExternalAddrPacket, pxpAddr) } if c.sawPCPRecently() { res.PCP = true } else if !DisablePCP { - uc.WriteTo(pcpAnnounceRequest(myIP), pcpAddr) + uc.WriteTo(pcpAnnounceRequest(myIP), pxpAddr) } if c.sawUPnPRecently() { res.UPnP = true @@ -669,9 +700,9 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { } return res, err } - port := addr.(*net.UDPAddr).Port + port := uint16(addr.(*net.UDPAddr).Port) switch port { - case upnpPort: + case c.upnpPort(): if mem.Contains(mem.B(buf[:n]), mem.S(":InternetGatewayDevice:")) { meta, err := parseUPnPDiscoResponse(buf[:n]) if err != nil { @@ -686,7 +717,7 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { c.uPnPMeta = meta c.mu.Unlock() } - case pcpPort: // same as pmpPort + case c.pxpPort(): // same value for PMP and PCP if pres, ok := parsePCPResponse(buf[:n]); ok { if pres.OpCode == pcpOpReply|pcpOpAnnounce { pcpHeard = true @@ -729,7 +760,7 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { var pmpReqExternalAddrPacket = []byte{pmpVersion, pmpOpMapPublicAddr} // 0, 0 const ( - upnpPort = 1900 // for UDP discovery only; TCP port discovered later + upnpDefaultPort = 1900 // for UDP discovery only; TCP port discovered later ) // uPnPPacket is the UPnP UDP discovery packet's request body. diff --git a/net/portmapper/portmapper_test.go b/net/portmapper/portmapper_test.go index 7d5d77f1f..4c3026e81 100644 --- a/net/portmapper/portmapper_test.go +++ b/net/portmapper/portmapper_test.go @@ -7,6 +7,7 @@ package portmapper import ( "context" "os" + "reflect" "strconv" "testing" "time" @@ -72,7 +73,9 @@ func TestProbeIntegration(t *testing.T) { logf("portmapping changed.") logf("have mapping: %v", c.HaveMapping()) }) - + c.testPxPPort = igd.TestPxPPort() + c.testUPnPPort = igd.TestUPnPPort() + t.Logf("Listening on pxp=%v, upnp=%v", c.testPxPPort, c.testUPnPPort) c.SetGatewayLookupFunc(func() (gw, self netaddr.IP, ok bool) { return netaddr.IPv4(127, 0, 0, 1), netaddr.IPv4(1, 2, 3, 4), true }) @@ -81,7 +84,21 @@ func TestProbeIntegration(t *testing.T) { if err != nil { t.Fatalf("Probe: %v", err) } + if !res.UPnP { + t.Errorf("didn't detect UPnP") + } + st := igd.stats() + want := igdCounters{ + numUPnPDiscoRecv: 1, + numPMPRecv: 1, + numPCPRecv: 1, + numPMPPublicAddrRecv: 1, + } + if !reflect.DeepEqual(st, want) { + t.Errorf("unexpected stats:\n got: %+v\nwant: %+v", st, want) + } + t.Logf("Probe: %+v", res) - t.Logf("IGD stats: %+v", igd.stats()) + t.Logf("IGD stats: %+v", st) // TODO(bradfitz): finish }