887 lines
21 KiB
Go
887 lines
21 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package tstun
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"net/netip"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
"unicode"
|
|
"unsafe"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/google/go-cmp/cmp/cmpopts"
|
|
"github.com/tailscale/wireguard-go/tun/tuntest"
|
|
"go4.org/mem"
|
|
"go4.org/netipx"
|
|
"gvisor.dev/gvisor/pkg/buffer"
|
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
"tailscale.com/disco"
|
|
"tailscale.com/net/connstats"
|
|
"tailscale.com/net/netaddr"
|
|
"tailscale.com/net/packet"
|
|
"tailscale.com/tstest"
|
|
"tailscale.com/tstime/mono"
|
|
"tailscale.com/types/ipproto"
|
|
"tailscale.com/types/key"
|
|
"tailscale.com/types/logger"
|
|
"tailscale.com/types/netlogtype"
|
|
"tailscale.com/types/ptr"
|
|
"tailscale.com/util/must"
|
|
"tailscale.com/wgengine/capture"
|
|
"tailscale.com/wgengine/filter"
|
|
"tailscale.com/wgengine/wgcfg"
|
|
)
|
|
|
|
// udp4 returns a serialized IPv4 UDP packet from src:sport to dst:dport
// whose payload is the literal "udp_payload". Both addresses are parsed
// with netip; an unparseable address causes a panic.
func udp4(src, dst string, sport, dport uint16) []byte {
	hdr := &packet.UDP4Header{
		IP4Header: packet.IP4Header{
			Src:  netip.MustParseAddr(src),
			Dst:  netip.MustParseAddr(dst),
			IPID: 0,
		},
		SrcPort: sport,
		DstPort: dport,
	}
	return packet.Generate(hdr, []byte("udp_payload"))
}
|
|
|
|
func tcp4syn(src, dst string, sport, dport uint16) []byte {
|
|
sip, err := netip.ParseAddr(src)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
dip, err := netip.ParseAddr(dst)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
ipHeader := packet.IP4Header{
|
|
IPProto: ipproto.TCP,
|
|
Src: sip,
|
|
Dst: dip,
|
|
IPID: 0,
|
|
}
|
|
tcpHeader := make([]byte, 20)
|
|
binary.BigEndian.PutUint16(tcpHeader[0:], sport)
|
|
binary.BigEndian.PutUint16(tcpHeader[2:], dport)
|
|
tcpHeader[13] |= 2 // SYN
|
|
|
|
both := packet.Generate(ipHeader, tcpHeader)
|
|
|
|
// 20 byte IP4 + 20 byte TCP
|
|
binary.BigEndian.PutUint16(both[2:4], 40)
|
|
|
|
return both
|
|
}
|
|
|
|
// nets converts each argument into a netip.Prefix. Arguments containing a
// '/' are parsed as CIDR prefixes; anything else is parsed as a single
// address and turned into a full-length prefix (/32 for IPv4, /128 for
// IPv6). It panics on unparseable input and returns nil when given no
// arguments.
func nets(nets ...string) (ret []netip.Prefix) {
	for _, s := range nets {
		if strings.Contains(s, "/") {
			ret = append(ret, netip.MustParsePrefix(s))
			continue
		}
		ip := netip.MustParseAddr(s)
		ret = append(ret, netip.PrefixFrom(ip, ip.BitLen()))
	}
	return ret
}
|
|
|
|
func ports(s string) filter.PortRange {
|
|
if s == "*" {
|
|
return filter.PortRange{First: 0, Last: 65535}
|
|
}
|
|
|
|
var fs, ls string
|
|
i := strings.IndexByte(s, '-')
|
|
if i == -1 {
|
|
fs = s
|
|
ls = fs
|
|
} else {
|
|
fs = s[:i]
|
|
ls = s[i+1:]
|
|
}
|
|
first, err := strconv.ParseInt(fs, 10, 16)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
|
}
|
|
last, err := strconv.ParseInt(ls, 10, 16)
|
|
if err != nil {
|
|
panic(fmt.Sprintf("invalid NetPortRange %q", s))
|
|
}
|
|
return filter.PortRange{First: uint16(first), Last: uint16(last)}
|
|
}
|
|
|
|
// netports parses "addr:ports" specifications into filter.NetPortRanges.
// Each spec is split at its last ':' so that the colons of an IPv6 address
// stay in the address part. The address part goes through nets and the
// port part through ports. A spec without any ':' panics.
func netports(netPorts ...string) (ret []filter.NetPortRange) {
	for _, spec := range netPorts {
		colon := strings.LastIndexByte(spec, ':')
		if colon < 0 {
			panic(fmt.Sprintf("invalid NetPortRange %q", spec))
		}
		addrPart, portPart := spec[:colon], spec[colon+1:]
		ret = append(ret, filter.NetPortRange{
			Net:   nets(addrPart)[0],
			Ports: ports(portPart),
		})
	}
	return ret
}
|
|
|
|
// setfilter installs the test packet filter on tun. It allows TCP and UDP
// from 5.6.7.8 to 1.2.3.4 ports 89-90, and from 1.2.3.4 to 5.6.7.8 port 98.
// The set {1.2.0.0/16} is passed as both IPSet arguments of filter.New.
// NOTE(review): presumably these are the local-network sets; confirm
// against filter.New.
func setfilter(logf logger.Logf, tun *Wrapper) {
	protos := []ipproto.Proto{ipproto.TCP, ipproto.UDP}

	var builder netipx.IPSetBuilder
	builder.AddPrefix(netip.MustParsePrefix("1.2.0.0/16"))
	// The builder only ever receives one valid constant prefix, so the
	// error from IPSet is deliberately ignored.
	set, _ := builder.IPSet()

	matches := []filter.Match{
		{IPProto: protos, Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")},
		{IPProto: protos, Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")},
	}
	tun.SetFilter(filter.New(matches, set, set, nil, logf))
}
|
|
|
|
func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) {
|
|
chtun := tuntest.NewChannelTUN()
|
|
tun := Wrap(logf, chtun.TUN())
|
|
if secure {
|
|
setfilter(logf, tun)
|
|
} else {
|
|
tun.disableFilter = true
|
|
}
|
|
return chtun, tun
|
|
}
|
|
|
|
func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *Wrapper) {
|
|
ftun := NewFake()
|
|
tun := Wrap(logf, ftun)
|
|
if secure {
|
|
setfilter(logf, tun)
|
|
} else {
|
|
tun.disableFilter = true
|
|
}
|
|
return ftun.(*fakeTUN), tun
|
|
}
|
|
|
|
// TestReadAndInject checks that Wrapper.Read returns, in any order, both the
// packets fed to the underlying TUN's Outbound channel and the packets
// queued via InjectOutbound.
func TestReadAndInject(t *testing.T) {
	chtun, tun := newChannelTUN(t.Logf, false)
	defer tun.Close()

	const size = 2 // all payloads have this size
	written := []string{"w0", "w1"}
	injected := []string{"i0", "i1"}

	go func() {
		for _, s := range written {
			chtun.Outbound <- []byte(s)
		}
	}()

	for _, s := range injected {
		go func(s string) {
			if err := tun.InjectOutbound([]byte(s)); err != nil {
				t.Errorf("%s: error: %v", s, err)
			}
		}(s)
	}

	var buf [MaxPacketSize]byte
	seen := make(map[string]bool)
	sizes := make([]int, 1)
	// We expect the same packets back, in no particular order.
	for i, total := 0, len(written)+len(injected); i < total; i++ {
		n, err := tun.Read([][]byte{buf[:]}, sizes, 0)
		if err != nil {
			t.Errorf("read %d: error: %v", i, err)
		}
		if n != 1 {
			t.Fatalf("read %d packets, expected %d", n, 1)
		}
		data := buf[:sizes[0]]
		if len(data) != size {
			t.Errorf("read %d: got size %d; want %d", i, len(data), size)
		}
		got := string(data)
		t.Logf("read %d: got %s", i, got)
		seen[got] = true
	}

	for _, group := range [][]string{written, injected} {
		for _, s := range group {
			if !seen[s] {
				t.Errorf("%s not received", s)
			}
		}
	}
}
|
|
|
|
func TestWriteAndInject(t *testing.T) {
|
|
chtun, tun := newChannelTUN(t.Logf, false)
|
|
defer tun.Close()
|
|
|
|
const size = 2 // all payloads have this size
|
|
written := []string{"w0", "w1"}
|
|
injected := []string{"i0", "i1"}
|
|
|
|
go func() {
|
|
for _, packet := range written {
|
|
payload := []byte(packet)
|
|
_, err := tun.Write([][]byte{payload}, 0)
|
|
if err != nil {
|
|
t.Errorf("%s: error: %v", packet, err)
|
|
}
|
|
}
|
|
}()
|
|
|
|
for _, packet := range injected {
|
|
go func(packet string) {
|
|
payload := []byte(packet)
|
|
err := tun.InjectInboundCopy(payload)
|
|
if err != nil {
|
|
t.Errorf("%s: error: %v", packet, err)
|
|
}
|
|
}(packet)
|
|
}
|
|
|
|
seen := make(map[string]bool)
|
|
// We expect the same packets back, in no particular order.
|
|
for i := 0; i < len(written)+len(injected); i++ {
|
|
packet := <-chtun.Inbound
|
|
got := string(packet)
|
|
t.Logf("read %d: got %s", i, got)
|
|
seen[got] = true
|
|
}
|
|
|
|
for _, packet := range written {
|
|
if !seen[packet] {
|
|
t.Errorf("%s not received", packet)
|
|
}
|
|
}
|
|
for _, packet := range injected {
|
|
if !seen[packet] {
|
|
t.Errorf("%s not received", packet)
|
|
}
|
|
}
|
|
}
|
|
|
|
// mustHexDecode is like hex.DecodeString, but panics on error
// and ignores whitespace in s.
func mustHexDecode(s string) []byte {
	// dropSpace maps every Unicode whitespace rune to -1, which tells
	// strings.Map to delete it; all other runes pass through unchanged.
	dropSpace := func(r rune) rune {
		if unicode.IsSpace(r) {
			return -1
		}
		return r
	}
	compact := strings.Map(dropSpace, s)
	return must.Get(hex.DecodeString(compact))
}
|
|
|
|
// TestFilter runs a table of packets through a Wrapper that has the
// setfilter rules installed, in both directions. For each packet it checks
// the drop/accept decision. For accepted UDP packets it also checks the
// connstats counters.
func TestFilter(t *testing.T) {
	chtun, tun := newChannelTUN(t.Logf, true)
	defer tun.Close()

	type direction int

	const (
		in direction = iota
		out
	)

	tests := []struct {
		name string
		dir  direction
		drop bool
		data []byte
	}{
		{"short_in", in, true, []byte("\x45xxx")},
		{"short_out", out, true, []byte("\x45xxx")},
		{"ip97_out", out, false, mustHexDecode("4500 0019 d186 4000 4061 751d 644a 4603 6449 e549 6865 6c6c 6f")},
		{"bad_port_in", in, true, udp4("5.6.7.8", "1.2.3.4", 22, 22)},
		{"bad_port_out", out, false, udp4("1.2.3.4", "5.6.7.8", 22, 22)},
		{"bad_ip_in", in, true, udp4("8.1.1.1", "1.2.3.4", 89, 89)},
		{"bad_ip_out", out, false, udp4("1.2.3.4", "8.1.1.1", 98, 98)},
		{"good_packet_in", in, false, udp4("5.6.7.8", "1.2.3.4", 89, 89)},
		{"good_packet_out", out, false, udp4("1.2.3.4", "5.6.7.8", 98, 98)},
	}

	// A reader on the other end of the tun. It fails the test if any packet
	// that should have been dropped shows up on the Inbound channel.
	go func() {
		var recvbuf []byte
		for {
			select {
			case <-tun.closed:
				return
			case recvbuf = <-chtun.Inbound:
				// continue
			}
			for _, tt := range tests {
				if tt.drop && bytes.Equal(recvbuf, tt.data) {
					t.Errorf("did not drop %s", tt.name)
				}
			}
		}
	}()

	var buf [MaxPacketSize]byte
	stats := connstats.NewStatistics(0, 0, nil)
	defer stats.Shutdown(context.Background())
	tun.SetStatistics(stats)
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var n int
			var err error
			var filtered bool
			sizes := make([]int, 1)

			// Earlier subtests must have drained the counters.
			tunStats, _ := stats.TestExtract()
			if len(tunStats) > 0 {
				// Report the extracted counters, not the Statistics object.
				t.Errorf("connstats.Statistics.Extract = %v, want {}", tunStats)
			}

			if tt.dir == in {
				// Use the side effect of updating the last
				// activity atomic to determine whether the
				// data was actually filtered.
				// If it stays zero, nothing made it through
				// to the wrapped TUN.
				tun.lastActivityAtomic.StoreAtomic(0)
				_, err = tun.Write([][]byte{tt.data}, 0)
				filtered = tun.lastActivityAtomic.LoadAtomic() == 0
			} else {
				chtun.Outbound <- tt.data
				n, err = tun.Read([][]byte{buf[:]}, sizes, 0)
				// In the read direction, errors are fatal, so we return n = 0 instead.
				filtered = (n == 0)
			}

			if err != nil {
				t.Errorf("got err %v; want nil", err)
			}

			switch {
			case filtered && !tt.drop:
				t.Errorf("got drop; want accept")
			case !filtered && tt.drop:
				t.Errorf("got accept; want drop")
			}

			got, _ := stats.TestExtract()
			want := map[netlogtype.Connection]netlogtype.Counts{}
			var wasUDP bool
			if !tt.drop {
				var p packet.Parsed
				p.Decode(tt.data)
				wasUDP = p.IPProto == ipproto.UDP
				switch tt.dir {
				case in:
					// Inbound traffic is recorded from the local side's
					// point of view, so source and destination swap.
					conn := netlogtype.Connection{Proto: ipproto.UDP, Src: p.Dst, Dst: p.Src}
					want[conn] = netlogtype.Counts{RxPackets: 1, RxBytes: uint64(len(tt.data))}
				case out:
					conn := netlogtype.Connection{Proto: ipproto.UDP, Src: p.Src, Dst: p.Dst}
					want[conn] = netlogtype.Counts{TxPackets: 1, TxBytes: uint64(len(tt.data))}
				}
			}
			// Only accepted UDP packets have a computed expectation.
			if wasUDP {
				if diff := cmp.Diff(got, want, cmpopts.EquateEmpty()); diff != "" {
					t.Errorf("stats.TestExtract (-got +want):\n%s", diff)
				}
			}
		})
	}
}
|
|
|
|
// TestAllocs asserts that writing a one-byte buffer to the fake TUN performs
// zero allocations per run.
func TestAllocs(t *testing.T) {
	ftun, tun := newFakeTUN(t.Logf, false)
	defer tun.Close()

	buf := [][]byte{{0x00}}
	writeOnce := func() {
		if _, err := ftun.Write(buf, 0); err != nil {
			t.Errorf("write: error: %v", err)
		}
	}

	if err := tstest.MinAllocsPerRun(t, 0, writeOnce); err != nil {
		t.Error(err)
	}
}
|
|
|
|
// TestClose checks that writes to the underlying fake TUN succeed before the
// Wrapper is closed and fail afterwards.
func TestClose(t *testing.T) {
	ftun, tun := newFakeTUN(t.Logf, false)
	data := [][]byte{udp4("1.2.3.4", "5.6.7.8", 98, 98)}

	if _, err := ftun.Write(data, 0); err != nil {
		t.Error(err)
	}

	tun.Close()
	if _, err := ftun.Write(data, 0); err == nil {
		t.Error("Expected error from ftun.Write() after Close()")
	}
}
|
|
|
|
// BenchmarkWrite measures writing, to the fake TUN, a UDP packet
// (5.6.7.8:89 -> 1.2.3.4:89) that the setfilter rules permit.
func BenchmarkWrite(b *testing.B) {
	b.ReportAllocs()
	ftun, tun := newFakeTUN(b.Logf, true)
	defer tun.Close()

	pkt := [][]byte{udp4("5.6.7.8", "1.2.3.4", 89, 89)}

	// Keep the TUN, filter, and packet setup above out of the measured time
	// and the reported allocation counts.
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := ftun.Write(pkt, 0); err != nil {
			b.Errorf("err = %v; want nil", err)
		}
	}
}
|
|
|
|
// TestAtomic64Alignment checks that Wrapper.lastActivityAtomic sits at an
// 8-byte-aligned offset within Wrapper. It then performs an atomic store on
// the field of a freshly allocated Wrapper.
// NOTE(review): presumably this guards 64-bit atomics on 32-bit platforms,
// where misaligned access faults.
func TestAtomic64Alignment(t *testing.T) {
	if off := unsafe.Offsetof(Wrapper{}.lastActivityAtomic); off%8 != 0 {
		t.Errorf("offset %v not 8-byte aligned", off)
	}

	w := new(Wrapper)
	w.lastActivityAtomic.StoreAtomic(mono.Now())
}
|
|
|
|
// TestPeerAPIBypass checks when inbound TCP SYNs to the PeerAPI port bypass
// the packet filter. A nil filter always drops. An allow-none filter is
// bypassed only when PeerAPIPort reports the destination IP with a matching
// port.
func TestPeerAPIBypass(t *testing.T) {
	peerAPIIP := netip.MustParseAddr("100.64.1.2")
	wrapperWithPeerAPI := &Wrapper{
		PeerAPIPort: func(ip netip.Addr) (port uint16, ok bool) {
			if ip != peerAPIIP {
				return 0, false
			}
			return 60000, true
		},
	}

	// allowNone returns a new filter that permits nothing.
	allowNone := func() *filter.Filter {
		return filter.NewAllowNone(logger.Discard, new(netipx.IPSet))
	}

	tests := []struct {
		name   string
		w      *Wrapper
		filter *filter.Filter
		pkt    []byte
		want   filter.Response
	}{
		{
			name: "reject_nil_filter",
			w: &Wrapper{
				PeerAPIPort: func(netip.Addr) (port uint16, ok bool) {
					return 60000, true
				},
			},
			pkt:  tcp4syn("1.2.3.4", "100.64.1.2", 1234, 60000),
			want: filter.Drop,
		},
		{
			name:   "reject_with_filter",
			w:      &Wrapper{},
			filter: allowNone(),
			pkt:    tcp4syn("1.2.3.4", "100.64.1.2", 1234, 60000),
			want:   filter.Drop,
		},
		{
			name:   "peerapi_bypass_filter",
			w:      wrapperWithPeerAPI,
			filter: allowNone(),
			pkt:    tcp4syn("1.2.3.4", "100.64.1.2", 1234, 60000),
			want:   filter.Accept,
		},
		{
			name:   "peerapi_dont_bypass_filter_wrong_port",
			w:      wrapperWithPeerAPI,
			filter: allowNone(),
			pkt:    tcp4syn("1.2.3.4", "100.64.1.2", 1234, 60001),
			want:   filter.Drop,
		},
		{
			name:   "peerapi_dont_bypass_filter_wrong_dst_ip",
			w:      wrapperWithPeerAPI,
			filter: allowNone(),
			pkt:    tcp4syn("1.2.3.4", "100.64.1.3", 1234, 60000),
			want:   filter.Drop,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			w := tt.w
			w.SetFilter(tt.filter)
			w.disableTSMPRejected = true
			w.logf = t.Logf

			parsed := new(packet.Parsed)
			parsed.Decode(tt.pkt)
			if got := w.filterPacketInboundFromWireGuard(parsed, nil); got != tt.want {
				t.Errorf("got = %v; want %v", got, tt.want)
			}
		})
	}
}
|
|
|
|
// Issue 1526: drop disco frames from ourselves.
//
// TestFilterDiscoLoop builds a UDP packet whose payload is a disco frame
// carrying this Wrapper's own disco public key. It checks that both the
// inbound and outbound filters return DropSilently and log the expected
// message.
func TestFilterDiscoLoop(t *testing.T) {
	var memLog tstest.MemLogger
	discoPub := key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 31: 0}))
	tw := &Wrapper{logf: memLog.Logf, limitedLogf: memLog.Logf}
	tw.SetDiscoKey(discoPub)

	uh := packet.UDP4Header{
		IP4Header: packet.IP4Header{
			IPProto: ipproto.UDP,
			Src:     netaddr.IPv4(1, 2, 3, 4),
			Dst:     netaddr.IPv4(5, 6, 7, 8),
		},
		SrcPort: 9,
		DstPort: 10,
	}
	// Disco payload layout: magic, sender key, then a zero nonce.
	discobs := discoPub.Raw32()
	discoPayload := fmt.Sprintf("%s%s%s", disco.Magic, discobs[:], [disco.NonceLen]byte{})
	pkt := make([]byte, uh.Len()+len(discoPayload))
	uh.Marshal(pkt)
	copy(pkt[uh.Len():], discoPayload)

	// check verifies one filter verdict and the log output it produced.
	check := func(got filter.Response, wantLog string) {
		t.Helper()
		if got != filter.DropSilently {
			t.Errorf("got %v; want DropSilently", got)
		}
		if gotLog := memLog.String(); gotLog != wantLog {
			t.Errorf("log output mismatch\n got: %q\nwant: %q\n", gotLog, wantLog)
		}
	}

	inbound := new(packet.Parsed)
	inbound.Decode(pkt)
	check(tw.filterPacketInboundFromWireGuard(inbound, nil),
		"[unexpected] received self disco in packet over tstun; dropping\n")

	memLog.Reset()
	outbound := new(packet.Parsed)
	outbound.Decode(pkt)
	check(tw.filterPacketOutboundToWireGuard(outbound),
		"[unexpected] received self disco out packet over tstun; dropping\n")
}
|
|
|
|
// TestNATCfg checks the NAT config built by natV4ConfigFromWGConfig for a
// table of WireGuard configs. For each config it checks the source IP chosen
// toward various destinations (selectSrcIP) and how inbound destination IPs
// are translated (mapDstIP).
func TestNATCfg(t *testing.T) {
	// node builds a peer whose AllowedIPs are ip's full-length prefix plus
	// any extra prefixes, with V4MasqAddr pointing at masqIP.
	node := func(ip, masqIP netip.Addr, otherAllowedIPs ...netip.Prefix) wgcfg.Peer {
		allowed := append([]netip.Prefix{netip.PrefixFrom(ip, ip.BitLen())}, otherAllowedIPs...)
		return wgcfg.Peer{
			PublicKey:  key.NewNode().Public(),
			AllowedIPs: allowed,
			V4MasqAddr: ptr.To(masqIP),
		}
	}
	var (
		noIP netip.Addr

		selfNativeIP = netip.MustParseAddr("100.64.0.1")
		selfEIP1     = netip.MustParseAddr("100.64.1.1")
		selfEIP2     = netip.MustParseAddr("100.64.1.2")
		selfAddrs    = []netip.Prefix{netip.PrefixFrom(selfNativeIP, selfNativeIP.BitLen())}

		peer1IP = netip.MustParseAddr("100.64.0.2")
		peer2IP = netip.MustParseAddr("100.64.0.3")

		subnet   = netip.MustParsePrefix("192.168.0.0/24")
		subnetIP = netip.MustParseAddr("192.168.0.1")

		exitRoute = netip.MustParsePrefix("0.0.0.0/0")
		publicIP  = netip.MustParseAddr("8.8.8.8")
	)
	// config builds a wgcfg.Config with our own addresses and the given peers.
	config := func(peers ...wgcfg.Peer) *wgcfg.Config {
		return &wgcfg.Config{Addresses: selfAddrs, Peers: peers}
	}

	tests := []struct {
		name    string
		wcfg    *wgcfg.Config
		snatMap map[netip.Addr]netip.Addr // dst -> src
		dnatMap map[netip.Addr]netip.Addr // inbound dst -> translated dst
	}{
		{
			name: "no-cfg",
			wcfg: nil,
			snatMap: map[netip.Addr]netip.Addr{
				peer1IP:  selfNativeIP,
				peer2IP:  selfNativeIP,
				subnetIP: selfNativeIP,
			},
			dnatMap: map[netip.Addr]netip.Addr{
				selfNativeIP: selfNativeIP,
				selfEIP1:     selfEIP1,
				selfEIP2:     selfEIP2,
			},
		},
		{
			name: "single-peer-requires-nat",
			wcfg: config(
				node(peer1IP, noIP),
				node(peer2IP, selfEIP1),
			),
			snatMap: map[netip.Addr]netip.Addr{
				peer1IP:  selfNativeIP,
				peer2IP:  selfEIP1,
				subnetIP: selfNativeIP,
			},
			dnatMap: map[netip.Addr]netip.Addr{
				selfNativeIP: selfNativeIP,
				selfEIP1:     selfNativeIP,
				selfEIP2:     selfEIP2,
				subnetIP:     subnetIP,
			},
		},
		{
			name: "multiple-peers-require-nat",
			wcfg: config(
				node(peer1IP, selfEIP1),
				node(peer2IP, selfEIP2),
			),
			snatMap: map[netip.Addr]netip.Addr{
				peer1IP:  selfEIP1,
				peer2IP:  selfEIP2,
				subnetIP: selfNativeIP,
			},
			dnatMap: map[netip.Addr]netip.Addr{
				selfNativeIP: selfNativeIP,
				selfEIP1:     selfNativeIP,
				selfEIP2:     selfNativeIP,
				subnetIP:     subnetIP,
			},
		},
		{
			name: "multiple-peers-require-nat-with-subnet",
			wcfg: config(
				node(peer1IP, selfEIP1),
				node(peer2IP, selfEIP2, subnet),
			),
			snatMap: map[netip.Addr]netip.Addr{
				peer1IP:  selfEIP1,
				peer2IP:  selfEIP2,
				subnetIP: selfEIP2,
			},
			dnatMap: map[netip.Addr]netip.Addr{
				selfNativeIP: selfNativeIP,
				selfEIP1:     selfNativeIP,
				selfEIP2:     selfNativeIP,
				subnetIP:     subnetIP,
			},
		},
		{
			name: "multiple-peers-require-nat-with-default-route",
			wcfg: config(
				node(peer1IP, selfEIP1),
				node(peer2IP, selfEIP2, exitRoute),
			),
			snatMap: map[netip.Addr]netip.Addr{
				peer1IP:  selfEIP1,
				peer2IP:  selfEIP2,
				publicIP: selfEIP2,
			},
			dnatMap: map[netip.Addr]netip.Addr{
				selfNativeIP: selfNativeIP,
				selfEIP1:     selfNativeIP,
				selfEIP2:     selfNativeIP,
				subnetIP:     subnetIP,
			},
		},
		{
			name: "no-nat",
			wcfg: config(
				node(peer1IP, noIP),
				node(peer2IP, noIP),
			),
			snatMap: map[netip.Addr]netip.Addr{
				peer1IP:  selfNativeIP,
				peer2IP:  selfNativeIP,
				subnetIP: selfNativeIP,
			},
			dnatMap: map[netip.Addr]netip.Addr{
				selfNativeIP: selfNativeIP,
				selfEIP1:     selfEIP1,
				selfEIP2:     selfEIP2,
				subnetIP:     subnetIP,
			},
		},
		{
			name: "exit-node-require-nat-peer-doesnt",
			wcfg: config(
				node(peer1IP, noIP),
				node(peer2IP, selfEIP2, exitRoute),
			),
			snatMap: map[netip.Addr]netip.Addr{
				peer1IP:  selfNativeIP,
				peer2IP:  selfEIP2,
				publicIP: selfEIP2,
			},
			dnatMap: map[netip.Addr]netip.Addr{
				selfNativeIP: selfNativeIP,
				selfEIP2:     selfNativeIP,
				subnetIP:     subnetIP,
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := natV4ConfigFromWGConfig(tt.wcfg)
			for dst, want := range tt.snatMap {
				if got := cfg.selectSrcIP(selfNativeIP, dst); got != want {
					t.Errorf("selectSrcIP[%v]: got %v; want %v", dst, got, want)
				}
			}
			for dst, want := range tt.dnatMap {
				if got := cfg.mapDstIP(dst); got != want {
					t.Errorf("mapDstIP[%v]: got %v; want %v", dst, got, want)
				}
			}
		})
	}
}
|
|
|
|
// TestCaptureHook verifies that the Wrapper.captureHook callback is called
// with the correct parameters when various packet operations are performed.
func TestCaptureHook(t *testing.T) {
	type captureRecord struct {
		path capture.Path
		now  time.Time
		pkt  []byte
		meta packet.CaptureMeta
	}

	var captured []captureRecord
	hook := func(path capture.Path, ts time.Time, pkt []byte, meta packet.CaptureMeta) {
		captured = append(captured, captureRecord{path: path, now: ts, pkt: pkt, meta: meta})
	}

	now := time.Unix(1682085856, 0)

	_, w := newFakeTUN(t.Logf, true)
	w.timeNow = func() time.Time { return now }
	w.InstallCaptureHook(hook)
	defer w.Close()

	// Loop reading and discarding packets; this ensures that we don't have
	// packets stuck in vectorOutbound. The loop exits on the first Read error.
	go func() {
		var buf [MaxPacketSize]byte
		sizes := make([]int, 1)
		for {
			if _, err := w.Read([][]byte{buf[:]}, sizes, 0); err != nil {
				return
			}
		}
	}()

	// Do operations that should result in a packet being captured.
	w.Write([][]byte{[]byte("Write1"), []byte("Write2")}, 0)
	w.InjectInboundPacketBuffer(stack.NewPacketBuffer(stack.PacketBufferOptions{
		Payload: buffer.MakeWithData([]byte("InjectInboundPacketBuffer")),
	}))
	w.InjectOutboundPacketBuffer(stack.NewPacketBuffer(stack.PacketBufferOptions{
		Payload: buffer.MakeWithData([]byte("InjectOutboundPacketBuffer")),
	}))

	// TODO: test Read
	// TODO: determine if we want InjectOutbound to log

	// Assert that the right packets are captured. Every record carries the
	// fixed timestamp from w.timeNow and a zero CaptureMeta.
	record := func(path capture.Path, payload string) captureRecord {
		return captureRecord{path: path, now: now, pkt: []byte(payload)}
	}
	want := []captureRecord{
		record(capture.FromPeer, "Write1"),
		record(capture.FromPeer, "Write2"),
		record(capture.SynthesizedToLocal, "InjectInboundPacketBuffer"),
		record(capture.SynthesizedToPeer, "InjectOutboundPacketBuffer"),
	}
	if !reflect.DeepEqual(captured, want) {
		t.Errorf("mismatch between captured and expected packets\ngot: %+v\nwant: %+v",
			captured, want)
	}
}
|