tailscale/derp/xdp/xdp_linux_test.go

976 lines
32 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package xdp
import (
"bytes"
"errors"
"fmt"
"net/netip"
"testing"
"github.com/cilium/ebpf"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"gvisor.dev/gvisor/pkg/tcpip/header"
"tailscale.com/net/stun"
)
// xdpAction is the numeric return code produced by running the XDP program.
type xdpAction uint32

// Known XDP return codes, numbered consecutively from zero.
const (
	xdpActionAborted xdpAction = iota
	xdpActionDrop
	xdpActionPass
	xdpActionTX
	xdpActionRedirect
)

// String returns the symbolic XDP_* name for x, or "unknown(N)" when x is
// not one of the codes declared above.
func (x xdpAction) String() string {
	names := [...]string{
		xdpActionAborted:  "XDP_ABORTED",
		xdpActionDrop:     "XDP_DROP",
		xdpActionPass:     "XDP_PASS",
		xdpActionTX:       "XDP_TX",
		xdpActionRedirect: "XDP_REDIRECT",
	}
	if x < xdpAction(len(names)) {
		return names[x]
	}
	return fmt.Sprintf("unknown(%d)", x)
}
// Fixed header sizes, in bytes, used to lay out the test frames. The IPv4
// size assumes a header without options.
const (
	ethHLen  = 14
	udpHLen  = 8
	ipv4HLen = 20
	ipv6HLen = 40
)

// Parameters shared by every generated request and response.
const (
	// defaultSTUNPort is the UDP port requests are addressed to and the
	// port the server under test is configured with.
	defaultSTUNPort = 3478
	// defaultTTL is the IPv4 TTL / IPv6 hop limit written into test packets.
	defaultTTL = 64
	// reqSrcPort is the UDP source port of the simulated client.
	reqSrcPort uint16 = 1025
)
// Addresses of the simulated client ("Src") and of the host running the XDP
// program ("Dst"). Requests travel Src -> Dst; expected responses swap them.
// The MAC addresses come from the 00:00:5e:00:53:xx documentation range and
// the IP addresses from the 192.0.2.0/24 and 2001:db8::/32 documentation
// prefixes.
var (
	reqEthSrc = tcpip.LinkAddress("\x00\x00\x5e\x00\x53\x01")
	reqEthDst = tcpip.LinkAddress("\x00\x00\x5e\x00\x53\x02")

	reqIPv4Src = netip.MustParseAddr("192.0.2.1")
	reqIPv4Dst = netip.MustParseAddr("192.0.2.2")

	reqIPv6Src = netip.MustParseAddr("2001:db8::1")
	reqIPv6Dst = netip.MustParseAddr("2001:db8::2")
)

// testTXID is the fixed STUN transaction ID used by every request and
// response built in this file.
var testTXID = stun.TxID{
	0x00, 0x01, 0x02, 0x03, 0x04, 0x05,
	0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
}
// ipv4Mutations holds optional hooks that getIPv4STUNBindingReq invokes
// while building a frame, so a test can corrupt one specific part of an
// otherwise valid request. Any hook may be left nil.
type ipv4Mutations struct {
// ipHeaderFn, if set, is called with the IPv4 header after it has been
// encoded and its checksum written.
ipHeaderFn func(header.IPv4)
// udpHeaderFn, if set, is called with the UDP header after its checksum
// has been written.
udpHeaderFn func(header.UDP)
// stunReqFn, if set, is called with the STUN request bytes before they
// are copied into the frame and before any checksum is computed.
stunReqFn func([]byte)
}
// getIPv4STUNBindingReq returns an Ethernet + IPv4 + UDP frame carrying a
// STUN binding request for testTXID, sent from the request source
// address/port to the request destination address on defaultSTUNPort.
//
// mutations may be nil. Its non-nil hooks run in this order: stunReqFn on the
// raw request before the frame is laid out, ipHeaderFn once the IPv4 header
// and its checksum are written, and udpHeaderFn once the UDP checksum is
// written. Nothing is recomputed after a hook runs, so whatever a hook
// changes is what ends up in the returned frame.
func getIPv4STUNBindingReq(mutations *ipv4Mutations) []byte {
	var hooks ipv4Mutations
	if mutations != nil {
		hooks = *mutations
	}

	stunReq := stun.Request(testTXID)
	if hooks.stunReqFn != nil {
		hooks.stunReqFn(stunReq)
	}

	// Byte offsets of each layer within the frame.
	const (
		ipOff   = ethHLen
		udpOff  = ipOff + ipv4HLen
		stunOff = udpOff + udpHLen
	)
	udpLen := udpHLen + len(stunReq)
	frame := make([]byte, stunOff+len(stunReq))

	header.Ethernet(frame).Encode(&header.EthernetFields{
		SrcAddr: reqEthSrc,
		DstAddr: reqEthDst,
		Type:    unix.ETH_P_IP,
	})

	ip := header.IPv4(frame[ipOff:])
	ip.Encode(&header.IPv4Fields{
		SrcAddr:     tcpip.AddrFrom4(reqIPv4Src.As4()),
		DstAddr:     tcpip.AddrFrom4(reqIPv4Dst.As4()),
		Protocol:    unix.IPPROTO_UDP,
		TTL:         defaultTTL,
		TotalLength: uint16(ipv4HLen + udpLen),
	})
	ip.SetChecksum(^ip.CalculateChecksum())
	if hooks.ipHeaderFn != nil {
		hooks.ipHeaderFn(ip)
	}

	udp := header.UDP(frame[udpOff:])
	udp.Encode(&header.UDPFields{
		SrcPort:  reqSrcPort,
		DstPort:  defaultSTUNPort,
		Length:   uint16(udpLen),
		Checksum: 0,
	})
	copy(frame[stunOff:], stunReq)

	// UDP checksum = pseudo-header + payload + UDP header. The addresses are
	// read back from the (possibly mutated) IP header.
	sum := header.PseudoHeaderChecksum(
		unix.IPPROTO_UDP,
		ip.SourceAddress(),
		ip.DestinationAddress(),
		uint16(udpLen),
	)
	sum = checksum.Checksum(stunReq, sum)
	udp.SetChecksum(^udp.CalculateChecksum(sum))
	if hooks.udpHeaderFn != nil {
		hooks.udpHeaderFn(udp)
	}

	return frame
}
// ipv6Mutations holds optional hooks that getIPv6STUNBindingReq invokes
// while building a frame, so a test can corrupt one specific part of an
// otherwise valid request. Any hook may be left nil.
type ipv6Mutations struct {
// ipHeaderFn, if set, is called with the IPv6 header after it has been
// encoded.
ipHeaderFn func(header.IPv6)
// udpHeaderFn, if set, is called with the UDP header after its checksum
// has been written.
udpHeaderFn func(header.UDP)
// stunReqFn, if set, is called with the STUN request bytes before they
// are copied into the frame and before any checksum is computed.
stunReqFn func([]byte)
}
// getIPv6STUNBindingReq returns an Ethernet + IPv6 + UDP frame carrying a
// STUN binding request for testTXID, sent from reqIPv6Src:reqSrcPort to
// reqIPv6Dst:defaultSTUNPort.
//
// The link/network addresses, hop limit and source port come from the shared
// package-level fixtures rather than local literals, so this request always
// stays in sync with the frame getIPv6STUNBindingResp expects in reply.
//
// mutations may be nil. Its non-nil hooks run in this order: stunReqFn on the
// raw request before the frame is laid out, ipHeaderFn once the IPv6 header
// is encoded, and udpHeaderFn once the UDP checksum is written. Nothing is
// recomputed after a hook runs.
func getIPv6STUNBindingReq(mutations *ipv6Mutations) []byte {
	req := stun.Request(testTXID)
	if mutations != nil && mutations.stunReqFn != nil {
		mutations.stunReqFn(req)
	}
	payloadLen := len(req)
	b := make([]byte, ethHLen+ipv6HLen+udpHLen+payloadLen)
	ipv6H := header.IPv6(b[ethHLen:])
	ethH := header.Ethernet(b)
	ethFields := header.EthernetFields{
		SrcAddr: reqEthSrc,
		DstAddr: reqEthDst,
		Type:    unix.ETH_P_IPV6,
	}
	ethH.Encode(&ethFields)
	ipFields := header.IPv6Fields{
		SrcAddr:           tcpip.AddrFrom16(reqIPv6Src.As16()),
		DstAddr:           tcpip.AddrFrom16(reqIPv6Dst.As16()),
		TransportProtocol: unix.IPPROTO_UDP,
		HopLimit:          defaultTTL,
		PayloadLength:     uint16(udpHLen + payloadLen),
	}
	ipv6H.Encode(&ipFields)
	if mutations != nil && mutations.ipHeaderFn != nil {
		mutations.ipHeaderFn(ipv6H)
	}
	udpH := header.UDP(b[ethHLen+ipv6HLen:])
	udpFields := header.UDPFields{
		SrcPort:  reqSrcPort,
		DstPort:  defaultSTUNPort,
		Length:   uint16(udpHLen + payloadLen),
		Checksum: 0,
	}
	udpH.Encode(&udpFields)
	copy(b[ethHLen+ipv6HLen+udpHLen:], req)
	// UDP checksum = pseudo-header + payload + UDP header. The addresses are
	// read back from the (possibly mutated) IP header.
	cs := header.PseudoHeaderChecksum(
		unix.IPPROTO_UDP,
		ipv6H.SourceAddress(),
		ipv6H.DestinationAddress(),
		uint16(udpHLen+payloadLen),
	)
	cs = checksum.Checksum(req, cs)
	udpH.SetChecksum(^udpH.CalculateChecksum(cs))
	if mutations != nil && mutations.udpHeaderFn != nil {
		mutations.udpHeaderFn(udpH)
	}
	return b
}
func getIPv4STUNBindingResp() []byte {
addrPort := netip.AddrPortFrom(reqIPv4Src, reqSrcPort)
resp := stun.Response(testTXID, addrPort)
payloadLen := len(resp)
totalLen := ipv4HLen + udpHLen + payloadLen
b := make([]byte, ethHLen+totalLen)
ipv4H := header.IPv4(b[ethHLen:])
ethH := header.Ethernet(b)
ethFields := header.EthernetFields{
SrcAddr: reqEthDst,
DstAddr: reqEthSrc,
Type: unix.ETH_P_IP,
}
ethH.Encode(&ethFields)
ipFields := header.IPv4Fields{
SrcAddr: tcpip.AddrFrom4(reqIPv4Dst.As4()),
DstAddr: tcpip.AddrFrom4(reqIPv4Src.As4()),
Protocol: unix.IPPROTO_UDP,
TTL: defaultTTL,
TotalLength: uint16(totalLen),
}
ipv4H.Encode(&ipFields)
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
udpH := header.UDP(b[ethHLen+ipv4HLen:])
udpFields := header.UDPFields{
SrcPort: defaultSTUNPort,
DstPort: reqSrcPort,
Length: uint16(udpHLen + payloadLen),
Checksum: 0,
}
udpH.Encode(&udpFields)
copy(b[ethHLen+ipv4HLen+udpHLen:], resp)
cs := header.PseudoHeaderChecksum(
unix.IPPROTO_UDP,
ipv4H.SourceAddress(),
ipv4H.DestinationAddress(),
uint16(udpHLen+payloadLen),
)
cs = checksum.Checksum(resp, cs)
udpH.SetChecksum(^udpH.CalculateChecksum(cs))
return b
}
func getIPv6STUNBindingResp() []byte {
addrPort := netip.AddrPortFrom(reqIPv6Src, reqSrcPort)
resp := stun.Response(testTXID, addrPort)
payloadLen := len(resp)
totalLen := ipv6HLen + udpHLen + payloadLen
b := make([]byte, ethHLen+totalLen)
ipv6H := header.IPv6(b[ethHLen:])
ethH := header.Ethernet(b)
ethFields := header.EthernetFields{
SrcAddr: reqEthDst,
DstAddr: reqEthSrc,
Type: unix.ETH_P_IPV6,
}
ethH.Encode(&ethFields)
ipFields := header.IPv6Fields{
SrcAddr: tcpip.AddrFrom16(reqIPv6Dst.As16()),
DstAddr: tcpip.AddrFrom16(reqIPv6Src.As16()),
TransportProtocol: unix.IPPROTO_UDP,
HopLimit: defaultTTL,
PayloadLength: uint16(udpHLen + payloadLen),
}
ipv6H.Encode(&ipFields)
udpH := header.UDP(b[ethHLen+ipv6HLen:])
udpFields := header.UDPFields{
SrcPort: defaultSTUNPort,
DstPort: reqSrcPort,
Length: uint16(udpHLen + payloadLen),
Checksum: 0,
}
udpH.Encode(&udpFields)
copy(b[ethHLen+ipv6HLen+udpHLen:], resp)
cs := header.PseudoHeaderChecksum(
unix.IPPROTO_UDP,
ipv6H.SourceAddress(),
ipv6H.DestinationAddress(),
uint16(udpHLen+payloadLen),
)
cs = checksum.Checksum(resp, cs)
udpH.SetChecksum(^udpH.CalculateChecksum(cs))
return b
}
func TestXDP(t *testing.T) {
ipv4STUNBindingReqTX := getIPv4STUNBindingReq(nil)
ipv6STUNBindingReqTX := getIPv6STUNBindingReq(nil)
ipv4STUNBindingReqIPCsumPass := getIPv4STUNBindingReq(&ipv4Mutations{
ipHeaderFn: func(ipv4H header.IPv4) {
oldCS := ipv4H.Checksum()
newCS := oldCS
for newCS == 0 || newCS == oldCS {
newCS++
}
ipv4H.SetChecksum(newCS)
},
})
ipv4STUNBindingReqIHLPass := getIPv4STUNBindingReq(&ipv4Mutations{
ipHeaderFn: func(ipv4H header.IPv4) {
ipv4H[0] &= 0xF0
},
})
ipv4STUNBindingReqIPVerPass := getIPv4STUNBindingReq(&ipv4Mutations{
ipHeaderFn: func(ipv4H header.IPv4) {
ipv4H[0] &= 0x0F
},
})
ipv4STUNBindingReqIPProtoPass := getIPv4STUNBindingReq(&ipv4Mutations{
ipHeaderFn: func(ipv4H header.IPv4) {
ipv4H[9] = unix.IPPROTO_TCP
},
})
ipv4STUNBindingReqFragOffsetPass := getIPv4STUNBindingReq(&ipv4Mutations{
ipHeaderFn: func(ipv4H header.IPv4) {
ipv4H.SetFlagsFragmentOffset(ipv4H.Flags(), 8)
},
})
ipv4STUNBindingReqFlagsMFPass := getIPv4STUNBindingReq(&ipv4Mutations{
ipHeaderFn: func(ipv4H header.IPv4) {
ipv4H.SetFlagsFragmentOffset(header.IPv4FlagMoreFragments, 0)
},
})
ipv4STUNBindingReqTotLenPass := getIPv4STUNBindingReq(&ipv4Mutations{
ipHeaderFn: func(ipv4H header.IPv4) {
ipv4H.SetTotalLength(ipv4H.TotalLength() + 1)
ipv4H.SetChecksum(0)
ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
},
})
ipv6STUNBindingReqIPVerPass := getIPv6STUNBindingReq(&ipv6Mutations{
ipHeaderFn: func(ipv6H header.IPv6) {
ipv6H[0] &= 0x0F
},
udpHeaderFn: func(udp header.UDP) {},
})
ipv6STUNBindingReqNextHdrPass := getIPv6STUNBindingReq(&ipv6Mutations{
ipHeaderFn: func(ipv6H header.IPv6) {
ipv6H.SetNextHeader(unix.IPPROTO_TCP)
},
udpHeaderFn: func(udp header.UDP) {},
})
ipv6STUNBindingReqPayloadLenPass := getIPv6STUNBindingReq(&ipv6Mutations{
ipHeaderFn: func(ipv6H header.IPv6) {
ipv6H.SetPayloadLength(ipv6H.PayloadLength() + 1)
},
udpHeaderFn: func(udp header.UDP) {},
})
ipv4STUNBindingReqUDPCsumPass := getIPv4STUNBindingReq(&ipv4Mutations{
udpHeaderFn: func(udpH header.UDP) {
oldCS := udpH.Checksum()
newCS := oldCS
for newCS == 0 || newCS == oldCS {
newCS++
}
udpH.SetChecksum(newCS)
},
})
ipv6STUNBindingReqUDPCsumPass := getIPv6STUNBindingReq(&ipv6Mutations{
udpHeaderFn: func(udpH header.UDP) {
oldCS := udpH.Checksum()
newCS := oldCS
for newCS == 0 || newCS == oldCS {
newCS++
}
udpH.SetChecksum(newCS)
},
})
ipv4STUNBindingReqSTUNTypePass := getIPv4STUNBindingReq(&ipv4Mutations{
stunReqFn: func(req []byte) {
req[1] = ^req[1]
},
})
ipv6STUNBindingReqSTUNTypePass := getIPv6STUNBindingReq(&ipv6Mutations{
stunReqFn: func(req []byte) {
req[1] = ^req[1]
},
})
ipv4STUNBindingReqSTUNMagicPass := getIPv4STUNBindingReq(&ipv4Mutations{
stunReqFn: func(req []byte) {
req[4] = ^req[4]
},
})
ipv6STUNBindingReqSTUNMagicPass := getIPv6STUNBindingReq(&ipv6Mutations{
stunReqFn: func(req []byte) {
req[4] = ^req[4]
},
})
ipv4STUNBindingReqSTUNAttrsLenPass := getIPv4STUNBindingReq(&ipv4Mutations{
stunReqFn: func(req []byte) {
req[2] = ^req[2]
},
})
ipv6STUNBindingReqSTUNAttrsLenPass := getIPv6STUNBindingReq(&ipv6Mutations{
stunReqFn: func(req []byte) {
req[2] = ^req[2]
},
})
ipv4STUNBindingReqSTUNSWValPass := getIPv4STUNBindingReq(&ipv4Mutations{
stunReqFn: func(req []byte) {
req[24] = ^req[24]
},
})
ipv6STUNBindingReqSTUNSWValPass := getIPv6STUNBindingReq(&ipv6Mutations{
stunReqFn: func(req []byte) {
req[24] = ^req[24]
},
})
ipv4STUNBindingReqSTUNFirstAttrPass := getIPv4STUNBindingReq(&ipv4Mutations{
stunReqFn: func(req []byte) {
req[21] = ^req[21]
},
})
ipv6STUNBindingReqSTUNFirstAttrPass := getIPv6STUNBindingReq(&ipv6Mutations{
stunReqFn: func(req []byte) {
req[21] = ^req[21]
},
})
cases := []struct {
name string
packetIn []byte
wantCode xdpAction
wantPacketOut []byte
wantMetrics map[bpfCountersKey]uint64
}{
{
name: "ipv4 STUN Binding Request TX",
packetIn: ipv4STUNBindingReqTX,
wantCode: xdpActionTX,
wantPacketOut: getIPv4STUNBindingResp(),
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_TX_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_TX_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(getIPv4STUNBindingResp())),
},
},
{
name: "ipv6 STUN Binding Request TX",
packetIn: ipv6STUNBindingReqTX,
wantCode: xdpActionTX,
wantPacketOut: getIPv6STUNBindingResp(),
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_TX_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_TX_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(getIPv6STUNBindingResp())),
},
},
{
name: "ipv4 STUN Binding Request invalid ip csum PASS",
packetIn: ipv4STUNBindingReqIPCsumPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqIPCsumPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_IP_CSUM),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_IP_CSUM),
}: uint64(len(ipv4STUNBindingReqIPCsumPass)),
},
},
{
name: "ipv4 STUN Binding Request ihl PASS",
packetIn: ipv4STUNBindingReqIHLPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqIHLPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv4STUNBindingReqIHLPass)),
},
},
{
name: "ipv4 STUN Binding Request ip version PASS",
packetIn: ipv4STUNBindingReqIPVerPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqIPVerPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv4STUNBindingReqIPVerPass)),
},
},
{
name: "ipv4 STUN Binding Request ip proto PASS",
packetIn: ipv4STUNBindingReqIPProtoPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqIPProtoPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv4STUNBindingReqIPProtoPass)),
},
},
{
name: "ipv4 STUN Binding Request frag offset PASS",
packetIn: ipv4STUNBindingReqFragOffsetPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqFragOffsetPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv4STUNBindingReqFragOffsetPass)),
},
},
{
name: "ipv4 STUN Binding Request flags mf PASS",
packetIn: ipv4STUNBindingReqFlagsMFPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqFlagsMFPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv4STUNBindingReqFlagsMFPass)),
},
},
{
name: "ipv4 STUN Binding Request tot len PASS",
packetIn: ipv4STUNBindingReqTotLenPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqTotLenPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv4STUNBindingReqTotLenPass)),
},
},
{
name: "ipv6 STUN Binding Request ip version PASS",
packetIn: ipv6STUNBindingReqIPVerPass,
wantCode: xdpActionPass,
wantPacketOut: ipv6STUNBindingReqIPVerPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv6STUNBindingReqIPVerPass)),
},
},
{
name: "ipv6 STUN Binding Request next hdr PASS",
packetIn: ipv6STUNBindingReqNextHdrPass,
wantCode: xdpActionPass,
wantPacketOut: ipv6STUNBindingReqNextHdrPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv6STUNBindingReqNextHdrPass)),
},
},
{
name: "ipv6 STUN Binding Request payload len PASS",
packetIn: ipv6STUNBindingReqPayloadLenPass,
wantCode: xdpActionPass,
wantPacketOut: ipv6STUNBindingReqPayloadLenPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv6STUNBindingReqPayloadLenPass)),
},
},
{
name: "ipv4 STUN Binding Request UDP csum PASS",
packetIn: ipv4STUNBindingReqUDPCsumPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqUDPCsumPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_UDP_CSUM),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_UDP_CSUM),
}: uint64(len(ipv4STUNBindingReqUDPCsumPass)),
},
},
{
name: "ipv6 STUN Binding Request UDP csum PASS",
packetIn: ipv6STUNBindingReqUDPCsumPass,
wantCode: xdpActionPass,
wantPacketOut: ipv6STUNBindingReqUDPCsumPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_UDP_CSUM),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_UDP_CSUM),
}: uint64(len(ipv6STUNBindingReqUDPCsumPass)),
},
},
{
name: "ipv4 STUN Binding Request STUN type PASS",
packetIn: ipv4STUNBindingReqSTUNTypePass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqSTUNTypePass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv4STUNBindingReqSTUNTypePass)),
},
},
{
name: "ipv6 STUN Binding Request STUN type PASS",
packetIn: ipv6STUNBindingReqSTUNTypePass,
wantCode: xdpActionPass,
wantPacketOut: ipv6STUNBindingReqSTUNTypePass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv6STUNBindingReqSTUNTypePass)),
},
},
{
name: "ipv4 STUN Binding Request STUN magic PASS",
packetIn: ipv4STUNBindingReqSTUNMagicPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqSTUNMagicPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv4STUNBindingReqSTUNMagicPass)),
},
},
{
name: "ipv6 STUN Binding Request STUN magic PASS",
packetIn: ipv6STUNBindingReqSTUNMagicPass,
wantCode: xdpActionPass,
wantPacketOut: ipv6STUNBindingReqSTUNMagicPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv6STUNBindingReqSTUNMagicPass)),
},
},
{
name: "ipv4 STUN Binding Request STUN attrs len PASS",
packetIn: ipv4STUNBindingReqSTUNAttrsLenPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqSTUNAttrsLenPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv4STUNBindingReqSTUNAttrsLenPass)),
},
},
{
name: "ipv6 STUN Binding Request STUN attrs len PASS",
packetIn: ipv6STUNBindingReqSTUNAttrsLenPass,
wantCode: xdpActionPass,
wantPacketOut: ipv6STUNBindingReqSTUNAttrsLenPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNSPECIFIED),
}: uint64(len(ipv6STUNBindingReqSTUNAttrsLenPass)),
},
},
{
name: "ipv4 STUN Binding Request STUN SW val PASS",
packetIn: ipv4STUNBindingReqSTUNSWValPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqSTUNSWValPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_SW_ATTR_VAL),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_SW_ATTR_VAL),
}: uint64(len(ipv4STUNBindingReqSTUNSWValPass)),
},
},
{
name: "ipv6 STUN Binding Request STUN SW val PASS",
packetIn: ipv6STUNBindingReqSTUNSWValPass,
wantCode: xdpActionPass,
wantPacketOut: ipv6STUNBindingReqSTUNSWValPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_SW_ATTR_VAL),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_INVALID_SW_ATTR_VAL),
}: uint64(len(ipv6STUNBindingReqSTUNSWValPass)),
},
},
{
name: "ipv4 STUN Binding Request STUN first attr PASS",
packetIn: ipv4STUNBindingReqSTUNFirstAttrPass,
wantCode: xdpActionPass,
wantPacketOut: ipv4STUNBindingReqSTUNFirstAttrPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNEXPECTED_FIRST_STUN_ATTR),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV4),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNEXPECTED_FIRST_STUN_ATTR),
}: uint64(len(ipv4STUNBindingReqSTUNFirstAttrPass)),
},
},
{
name: "ipv6 STUN Binding Request STUN first attr PASS",
packetIn: ipv6STUNBindingReqSTUNFirstAttrPass,
wantCode: xdpActionPass,
wantPacketOut: ipv6STUNBindingReqSTUNFirstAttrPass,
wantMetrics: map[bpfCountersKey]uint64{
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNEXPECTED_FIRST_STUN_ATTR),
}: 1,
{
Af: uint8(bpfCounterKeyAfCOUNTER_KEY_AF_IPV6),
Pba: uint8(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_BYTES_PASS_TOTAL),
ProgEnd: uint8(bpfCounterKeyProgEndCOUNTER_KEY_END_UNEXPECTED_FIRST_STUN_ATTR),
}: uint64(len(ipv6STUNBindingReqSTUNFirstAttrPass)),
},
},
}
server, err := NewSTUNServer(&STUNServerConfig{DeviceName: "fake", DstPort: defaultSTUNPort},
&noAttachOption{})
if err != nil {
if errors.Is(err, unix.EPERM) {
// TODO(jwhited): get this running
t.Skip("skipping due to EPERM error; test requires elevated privileges")
}
t.Fatalf("error constructing STUN server: %v", err)
}
defer server.Close()
clearCounters := func() error {
server.metrics.last = make(map[bpfCountersKey]uint64)
var cur, next bpfCountersKey
keys := make([]bpfCountersKey, 0)
for err = server.objs.CountersMap.NextKey(nil, &next); ; err = server.objs.CountersMap.NextKey(cur, &next) {
if err != nil {
if errors.Is(err, ebpf.ErrKeyNotExist) {
break
}
return err
}
keys = append(keys, next)
cur = next
}
for _, key := range keys {
err = server.objs.CountersMap.Delete(&key)
if err != nil {
return err
}
}
err = server.objs.CountersMap.NextKey(nil, &next)
if !errors.Is(err, ebpf.ErrKeyNotExist) {
return errors.New("counters map is not empty")
}
return nil
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
err = clearCounters()
if err != nil {
t.Fatalf("error clearing counters: %v", err)
}
opts := ebpf.RunOptions{
Data: c.packetIn,
DataOut: make([]byte, 1514),
}
got, err := server.objs.XdpProgFunc.Run(&opts)
if err != nil {
t.Fatalf("error running program: %v", err)
}
if xdpAction(got) != c.wantCode {
t.Fatalf("got code: %s != %s", xdpAction(got), c.wantCode)
}
if !bytes.Equal(opts.DataOut, c.wantPacketOut) {
t.Fatal("packets not equal")
}
err = server.updateMetrics()
if err != nil {
t.Fatalf("error updating metrics: %v", err)
}
if c.wantMetrics != nil {
for k, v := range c.wantMetrics {
gotCounter, ok := server.metrics.last[k]
if !ok {
t.Errorf("expected counter at key %+v not found", k)
}
if gotCounter != v {
t.Errorf("key: %+v gotCounter: %d != %d", k, gotCounter, v)
}
}
for k := range server.metrics.last {
_, ok := c.wantMetrics[k]
if !ok {
t.Errorf("counter at key: %+v incremented unexpectedly", k)
}
}
}
})
}
}
// TestCountersMapKey guards the assumptions the Go side makes about the
// counter-key enums generated from xdp.c: each enum's *_LEN sentinel must not
// exceed 256 (so every real value fits in the uint8 key fields), and the Go
// lookup tables derived from those enums must have matching sizes.
func TestCountersMapKey(t *testing.T) {
	if bpfCounterKeyAfCOUNTER_KEY_AF_LEN > 256 {
		t.Error("COUNTER_KEY_AF_LEN no longer fits within uint8")
	}
	if bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_BYTES_ACTION_LEN > 256 {
		t.Error("COUNTER_KEY_PACKETS_BYTES_ACTION_LEN no longer fits within uint8")
	}
	if bpfCounterKeyProgEndCOUNTER_KEY_END_LEN > 256 {
		t.Error("COUNTER_KEY_END_LEN no longer fits within uint8")
	}
	if len(pbaToOutcomeLV) != int(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_BYTES_ACTION_LEN) {
		t.Error("pbaToOutcomeLV is not in sync with xdp.c")
	}
	if len(progEndLV) != int(bpfCounterKeyProgEndCOUNTER_KEY_END_LEN) {
		t.Error("progEndLV is not in sync with xdp.c")
	}
	// Every packets/bytes action must be classified as exactly one of a
	// packet counter or a byte counter.
	if len(packetCounterKeys)+len(bytesCounterKeys) != int(bpfCounterKeyPacketsBytesActionCOUNTER_KEY_PACKETS_BYTES_ACTION_LEN) {
		t.Error("packetCounterKeys and/or bytesCounterKeys is not in sync with xdp.c")
	}
}