diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 368a33283..ecb655fe9 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -40,6 +40,7 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store" "tailscale.com/net/tsaddr" + "tailscale.com/net/tstun" "tailscale.com/safesocket" "tailscale.com/syncs" "tailscale.com/tailcfg" @@ -1207,7 +1208,6 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) { // TestNetstackTCPLoopback tests netstack loopback of a TCP stream, in both // directions. -// TODO(jwhited): do the same for UDP func TestNetstackTCPLoopback(t *testing.T) { tstest.Shard(t) if os.Getuid() != 0 { @@ -1216,9 +1216,9 @@ func TestNetstackTCPLoopback(t *testing.T) { env := newTestEnv(t) env.tunMode = true - loopbackPort := uint16(5201) + loopbackPort := 5201 env.loopbackPort = &loopbackPort - loopbackPortStr := strconv.Itoa(int(loopbackPort)) + loopbackPortStr := strconv.Itoa(loopbackPort) n1 := newTestNode(t, env) d1 := n1.StartDaemon() @@ -1348,6 +1348,153 @@ func TestNetstackTCPLoopback(t *testing.T) { d1.MustCleanShutdown(t) } +// TestNetstackUDPLoopback tests netstack loopback of UDP packets, in both +// directions. +func TestNetstackUDPLoopback(t *testing.T) { + tstest.Shard(t) + if os.Getuid() != 0 { + t.Skip("skipping when not root") + } + + env := newTestEnv(t) + env.tunMode = true + loopbackPort := 5201 + env.loopbackPort = &loopbackPort + n1 := newTestNode(t, env) + d1 := n1.StartDaemon() + + n1.AwaitResponding() + n1.MustUp() + + ip4 := n1.AwaitIP4() + ip6 := n1.AwaitIP6() + n1.AwaitRunning() + + cases := []struct { + pingerLAddr *net.UDPAddr + pongerLAddr *net.UDPAddr + network string + dialAddr *net.UDPAddr + }{ + { + pingerLAddr: &net.UDPAddr{IP: ip4.AsSlice(), Port: loopbackPort + 1}, + pongerLAddr: &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: loopbackPort}, + network: "udp4", + dialAddr: &net.UDPAddr{IP: tsaddr.TailscaleServiceIP().AsSlice(), Port: loopbackPort}, + }, + { + pingerLAddr: &net.UDPAddr{IP: ip6.AsSlice(), Port: loopbackPort + 1}, + pongerLAddr: &net.UDPAddr{IP: net.ParseIP("::1"), Port: loopbackPort}, + network: "udp6", + dialAddr: &net.UDPAddr{IP: tsaddr.TailscaleServiceIPv6().AsSlice(), Port: loopbackPort}, + }, + } + + writeBufSize := int(tstun.DefaultTUNMTU()) - 40 - 8 // mtu - ipv6 header - udp header + wantPongs := 100 + + for _, c := range cases { + pongerConn, err := net.ListenUDP(c.network, c.pongerLAddr) + if err != nil { + t.Fatal(err) + } + defer pongerConn.Close() + + var pingerConn *net.UDPConn + err = tstest.WaitFor(time.Second*5, func() error { + pingerConn, err = net.DialUDP(c.network, c.pingerLAddr, c.dialAddr) + return err + }) + if err != nil { + t.Fatal(err) + } + defer pingerConn.Close() + + pingerFn := func(conn *net.UDPConn) error { + b := make([]byte, writeBufSize) + n, err := conn.Write(b) + if err != nil { + return err + } + if n != len(b) { + return fmt.Errorf("bad write size: %d", n) + } + err = conn.SetReadDeadline(time.Now().Add(time.Millisecond * 500)) + if err != nil { + return err + } + n, err = conn.Read(b) + if err != nil { + return err + } + if n != len(b) { + return fmt.Errorf("bad read size: %d", n) + } + return nil + } + + pongerFn := func(conn *net.UDPConn) error { + for { + b := make([]byte, writeBufSize) + n, from, err := conn.ReadFromUDP(b) + if err != nil { + return err + } + if n != len(b) { + return fmt.Errorf("bad read size: %d", n) + } + n, err = conn.WriteToUDP(b, from) + if err != nil { + return err + } + if n != len(b) { + return fmt.Errorf("bad write size: %d", n) + } + } + } + + pongerErrCh := make(chan error, 1) + go func() { + pongerErrCh <- pongerFn(pongerConn) + }() + + err = tstest.WaitFor(time.Second*5, func() error { + err = pingerFn(pingerConn) + if err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatal(err) + } + + var pongsRX int + for { + pingerErrCh := make(chan error) + go func() { + pingerErrCh <- pingerFn(pingerConn) + }() + + select { + case err := <-pongerErrCh: + t.Fatal(err) + case err := <-pingerErrCh: + if err != nil { + t.Fatal(err) + } + } + + pongsRX++ + if pongsRX == wantPongs { + break + } + } + } + + d1.MustCleanShutdown(t) +} + // testEnv contains the test environment (set of servers) used by one // or more nodes. type testEnv struct { @@ -1355,7 +1502,7 @@ type testEnv struct { tunMode bool cli string daemon string - loopbackPort *uint16 + loopbackPort *int LogCatcher *LogCatcher LogCatcherServer *httptest.Server @@ -1657,7 +1804,7 @@ func (n *testNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon { "TS_DEBUG_LOG_RATE=all", ) if n.env.loopbackPort != nil { - cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(int(*n.env.loopbackPort))) + cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(*n.env.loopbackPort)) } if version.IsRace() { cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1")