magicsock: simple ping test via magicsock

Passes `go test -count=20 -race ./wgengine/magicsock`

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
David Crawshaw 2020-03-03 10:39:40 -05:00 committed by David Crawshaw
parent 34859f8e7d
commit 9f584414d9
1 changed files with 175 additions and 0 deletions

View File

@ -5,6 +5,7 @@
package magicsock package magicsock
import ( import (
"bytes"
"fmt" "fmt"
"log" "log"
"net" "net"
@ -13,6 +14,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun/tuntest"
"github.com/tailscale/wireguard-go/wgcfg"
"tailscale.com/stun" "tailscale.com/stun"
) )
@ -179,3 +183,174 @@ func runSTUN(pc net.PacketConn, stats *stunStats) {
} }
} }
} }
func makeConfigs(t *testing.T, ports []uint16) []wgcfg.Config {
t.Helper()
var privKeys []wgcfg.PrivateKey
var addresses [][]wgcfg.CIDR
for i := range ports {
privKey, err := wgcfg.NewPrivateKey()
if err != nil {
t.Fatal(err)
}
privKeys = append(privKeys, privKey)
addresses = append(addresses, []wgcfg.CIDR{
parseCIDR(t, fmt.Sprintf("1.0.0.%d/32", i+1)),
})
}
var cfgs []wgcfg.Config
for i, port := range ports {
cfg := wgcfg.Config{
Name: fmt.Sprintf("peer%d", i+1),
PrivateKey: privKeys[i],
Addresses: addresses[i],
ListenPort: port,
}
for peerNum, port := range ports {
if peerNum == i {
continue
}
peer := wgcfg.Peer{
PublicKey: privKeys[peerNum].Public(),
AllowedIPs: addresses[peerNum],
Endpoints: []wgcfg.Endpoint{{
Host: "127.0.0.1",
Port: port,
}},
}
cfg.Peers = append(cfg.Peers, peer)
}
cfgs = append(cfgs, cfg)
}
return cfgs
}
func parseCIDR(t *testing.T, addr string) wgcfg.CIDR {
t.Helper()
cidr, err := wgcfg.ParseCIDR(addr)
if err != nil {
t.Fatal(err)
}
return *cidr
}
func TestTwoDevicePing(t *testing.T) {
stunAddr, stunCleanupFn := serveSTUN(t)
defer stunCleanupFn()
epCh1 := make(chan []string, 16)
conn1, err := Listen(Options{
STUN: []string{stunAddr.String()},
EndpointsFunc: func(eps []string) {
epCh1 <- eps
},
})
if err != nil {
t.Fatal(err)
}
defer conn1.Close()
epCh2 := make(chan []string, 16)
conn2, err := Listen(Options{
STUN: []string{stunAddr.String()},
EndpointsFunc: func(eps []string) {
epCh2 <- eps
},
})
if err != nil {
t.Fatal(err)
}
defer conn2.Close()
ports := []uint16{conn1.LocalPort(), conn2.LocalPort()}
cfgs := makeConfigs(t, ports)
uapi1, _ := cfgs[0].ToUAPI()
t.Logf("cfg0: %v", uapi1)
uapi2, _ := cfgs[1].ToUAPI()
t.Logf("cfg1: %v", uapi2)
tun1 := tuntest.NewChannelTUN()
dev1 := device.NewDevice(tun1.TUN(), &device.DeviceOptions{
Logger: device.NewLogger(device.LogLevelDebug, "dev1: "),
CreateEndpoint: conn1.CreateEndpoint,
CreateBind: conn1.CreateBind,
SkipBindUpdate: true,
})
dev1.Up()
//defer dev1.Close() TODO(crawshaw): this hangs
if err := dev1.Reconfig(&cfgs[0]); err != nil {
t.Fatal(err)
}
tun2 := tuntest.NewChannelTUN()
dev2 := device.NewDevice(tun2.TUN(), &device.DeviceOptions{
Logger: device.NewLogger(device.LogLevelDebug, "dev2: "),
CreateEndpoint: conn2.CreateEndpoint,
CreateBind: conn2.CreateBind,
SkipBindUpdate: true,
})
dev2.Up()
//defer dev2.Close() TODO(crawshaw): this hangs
if err := dev2.Reconfig(&cfgs[1]); err != nil {
t.Fatal(err)
}
ping1 := func(t *testing.T) {
t.Helper()
msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1
select {
case msgRecv := <-tun1.Inbound:
if !bytes.Equal(msg2to1, msgRecv) {
t.Error("ping did not transit correctly")
}
case <-time.After(1 * time.Second):
t.Error("ping did not transit")
}
}
ping2 := func(t *testing.T) {
t.Helper()
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2
select {
case msgRecv := <-tun2.Inbound:
if !bytes.Equal(msg1to2, msgRecv) {
t.Error("return ping did not transit correctly")
}
case <-time.After(1 * time.Second):
t.Error("return ping did not transit")
}
}
t.Run("ping 1.0.0.1", func(t *testing.T) { ping1(t) })
t.Run("ping 1.0.0.2", func(t *testing.T) { ping2(t) })
t.Run("ping 1.0.0.2 via SendPacket", func(t *testing.T) {
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
if err := dev1.SendPacket(msg1to2); err != nil {
t.Fatal(err)
}
select {
case msgRecv := <-tun2.Inbound:
if !bytes.Equal(msg1to2, msgRecv) {
t.Error("return ping did not transit correctly")
}
case <-time.After(1 * time.Second):
t.Error("return ping did not transit")
}
})
t.Run("no-op dev1 reconfig", func(t *testing.T) {
if err := dev1.Reconfig(&cfgs[0]); err != nil {
t.Fatal(err)
}
ping1(t)
ping2(t)
})
}