wgengine/packet: add some tests, more docs, minor Go style, performance changes

This commit is contained in:
Brad Fitzpatrick 2020-05-25 08:57:04 -07:00
parent 3f4a567032
commit 43ded2b581
3 changed files with 115 additions and 50 deletions

View File

@ -10,6 +10,8 @@ import (
"log"
"net"
"strings"
"tailscale.com/types/strbuilder"
)
type IPProto int
@ -23,7 +25,7 @@ const (
)
// RFC1858: prevent overlapping fragment attacks.
const MIN_FRAG = 60 + 20 // max IPv4 header + basic TCP header
const minFrag = 60 + 20 // max IPv4 header + basic TCP header
func (p IPProto) String() string {
switch p {
@ -40,8 +42,11 @@ func (p IPProto) String() string {
}
}
// IP is an IPv4 address.
type IP uint32
// NewIP converts a standard library IP address into an IP.
// It panics if b is not an IPv4 address.
func NewIP(b net.IP) IP {
b4 := b.To4()
if b4 == nil {
@ -51,22 +56,21 @@ func NewIP(b net.IP) IP {
}
func (ip IP) String() string {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(ip))
return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3])
return fmt.Sprintf("%d.%d.%d.%d", byte(ip>>24), byte(ip>>16), byte(ip>>8), byte(ip))
}
// ICMP types.
const (
EchoReply uint8 = 0x00
EchoRequest uint8 = 0x08
Unreachable uint8 = 0x03
TimeExceeded uint8 = 0x0B
ICMPEchoReply = 0x00
ICMPEchoRequest = 0x08
ICMPUnreachable = 0x03
ICMPTimeExceeded = 0x0b
)
const (
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPSynAck uint8 = TCPSyn | TCPAck
TCPSyn = 0x02
TCPAck = 0x10
TCPSynAck = TCPSyn | TCPAck
)
type QDecode struct {
@ -81,18 +85,30 @@ type QDecode struct {
TCPFlags uint8 // TCP flags (SYN, ACK, etc)
}
func (q QDecode) String() string {
func (q *QDecode) String() string {
if q.IPProto == Junk {
return "Junk{}"
}
srcip := make([]byte, 4)
dstip := make([]byte, 4)
binary.BigEndian.PutUint32(srcip, uint32(q.SrcIP))
binary.BigEndian.PutUint32(dstip, uint32(q.DstIP))
return fmt.Sprintf("%v{%d.%d.%d.%d:%d > %d.%d.%d.%d:%d}",
q.IPProto,
srcip[0], srcip[1], srcip[2], srcip[3], q.SrcPort,
dstip[0], dstip[1], dstip[2], dstip[3], q.DstPort)
sb := strbuilder.Get()
sb.WriteString(q.IPProto.String())
sb.WriteByte('{')
writeIPPort(sb, q.SrcIP, q.SrcPort)
sb.WriteString(" > ")
writeIPPort(sb, q.DstIP, q.DstPort)
sb.WriteByte('}')
return sb.String()
}
func writeIPPort(sb *strbuilder.Builder, ip IP, port uint16) {
sb.WriteUint(uint64(byte(ip >> 24)))
sb.WriteByte('.')
sb.WriteUint(uint64(byte(ip >> 16)))
sb.WriteByte('.')
sb.WriteUint(uint64(byte(ip >> 8)))
sb.WriteByte('.')
sb.WriteUint(uint64(byte(ip)))
sb.WriteByte(':')
sb.WriteUint(uint64(port))
}
// based on https://tools.ietf.org/html/rfc1071
@ -114,7 +130,12 @@ func ipChecksum(b []byte) uint16 {
return uint16(^ac)
}
func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType uint8, icmpCode uint8, payload []byte) []byte {
var put16 = binary.BigEndian.PutUint16
var put32 = binary.BigEndian.PutUint32
// GenICMP returns the bytes of an ICMP packet.
// If payload is too short or too long, it returns nil.
func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType, icmpCode uint8, payload []byte) []byte {
if len(payload) < 4 {
return nil
}
@ -126,22 +147,22 @@ func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType uint8, icmpCode uint8, paylo
out := make([]byte, 24+len(payload))
out[0] = 0x45 // IPv4, 20-byte header
out[1] = 0x00 // DHCP, ECN
binary.BigEndian.PutUint16(out[2:4], uint16(sz))
binary.BigEndian.PutUint16(out[4:6], ipid)
binary.BigEndian.PutUint16(out[6:8], 0) // flags, offset
out[8] = 64 // TTL
out[9] = 0x01 // ICMPv4
put16(out[2:4], uint16(sz))
put16(out[4:6], ipid)
put16(out[6:8], 0) // flags, offset
out[8] = 64 // TTL
out[9] = 0x01 // ICMPv4
// out[10:12] = 0x00 // blank IP header checksum
binary.BigEndian.PutUint32(out[12:16], uint32(srcIP))
binary.BigEndian.PutUint32(out[16:20], uint32(dstIP))
put32(out[12:16], uint32(srcIP))
put32(out[16:20], uint32(dstIP))
out[20] = icmpType
out[21] = icmpCode
//out[22:24] = 0x00 // blank ICMP checksum
copy(out[24:], payload)
binary.BigEndian.PutUint16(out[10:12], ipChecksum(out[0:20]))
binary.BigEndian.PutUint16(out[22:24], ipChecksum(out))
put16(out[10:12], ipChecksum(out[0:20]))
put16(out[22:24], ipChecksum(out))
return out
}
@ -193,7 +214,7 @@ func (q *QDecode) Decode(b []byte) {
fragOfs := fragFlags & 0x1FFF
if fragOfs == 0 {
// This is the first fragment
if moreFrags && len(sub) < MIN_FRAG {
if moreFrags && len(sub) < minFrag {
// Suspiciously short first fragment, dump it.
log.Printf("junk1!\n")
q.IPProto = Junk
@ -241,7 +262,7 @@ func (q *QDecode) Decode(b []byte) {
}
} else {
// This is a fragment other than the first one.
if fragOfs < MIN_FRAG {
if fragOfs < minFrag {
// First frag was suspiciously short, so we can't
// trust the followup either.
q.IPProto = Junk
@ -263,57 +284,52 @@ func (q *QDecode) Sub(begin, n int) []byte {
return q.b[q.subofs+begin : q.subofs+begin+n]
}
// For a packet that is known to be IPv4, trim the buffer to its IPv4 length.
// Trim trims the buffer to its IPv4 length.
// Sometimes packets arrive from an interface with extra bytes on the end.
// This removes them.
func (q *QDecode) Trim() []byte {
n := binary.BigEndian.Uint16(q.b[2:4])
return q.b[0:n]
return q.b[:n]
}
// For a decoded TCP packet, return true if it's a TCP SYN packet (ie. the
// IsTCPSyn reports whether q is a TCP SYN packet (i.e. the
// first packet in a new connection).
func (q *QDecode) IsTCPSyn() bool {
const Syn = 0x02
const Ack = 0x10
const SynAck = Syn | Ack
return (q.TCPFlags & SynAck) == Syn
return (q.TCPFlags & TCPSynAck) == TCPSyn
}
// For a packet that has already been decoded, check if it's an IPv4 ICMP
// "Error" packet.
// IsError reports whether q is an IPv4 ICMP "Error" packet.
func (q *QDecode) IsError() bool {
if q.IPProto == ICMP && len(q.b) >= q.subofs+8 {
switch q.b[q.subofs] {
case Unreachable, TimeExceeded:
case ICMPUnreachable, ICMPTimeExceeded:
return true
}
}
return false
}
// For a packet that has already been decoded, check if it's an IPv4 ICMP
// Echo Request.
// IsEchoRequest reports whether q is an IPv4 ICMP Echo Request.
func (q *QDecode) IsEchoRequest() bool {
if q.IPProto == ICMP && len(q.b) >= q.subofs+8 {
return q.b[q.subofs] == EchoRequest && q.b[q.subofs+1] == 0
return q.b[q.subofs] == ICMPEchoRequest && q.b[q.subofs+1] == 0
}
return false
}
// For a packet that has already been decoded, check if it's an IPv4 ICMP
// Echo Response.
// IsEchoRequest reports whether q is an IPv4 ICMP Echo Response.
func (q *QDecode) IsEchoResponse() bool {
if q.IPProto == ICMP && len(q.b) >= q.subofs+8 {
return q.b[q.subofs] == EchoReply && q.b[q.subofs+1] == 0
return q.b[q.subofs] == ICMPEchoReply && q.b[q.subofs+1] == 0
}
return false
}
// EchoResponse returns an IPv4 ICMP echo reply to the request in q.
func (q *QDecode) EchoRespond() []byte {
icmpid := binary.BigEndian.Uint16(q.Sub(4, 2))
b := q.Trim()
return GenICMP(q.DstIP, q.SrcIP, icmpid, EchoReply, 0, b[q.subofs+4:])
return GenICMP(q.DstIP, q.SrcIP, icmpid, ICMPEchoReply, 0, b[q.subofs+4:])
}
func Hexdump(b []byte) string {

View File

@ -0,0 +1,49 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package packet
import (
"net"
"testing"
)
func TestIPString(t *testing.T) {
const str = "1.2.3.4"
ip := NewIP(net.ParseIP(str))
var got string
allocs := testing.AllocsPerRun(1000, func() {
got = ip.String()
})
if got != str {
t.Errorf("got %q; want %q", got, str)
}
if allocs != 1 {
t.Errorf("allocs = %v; want 1", allocs)
}
}
func TestQDecodeString(t *testing.T) {
q := QDecode{
IPProto: TCP,
SrcIP: NewIP(net.ParseIP("1.2.3.4")),
SrcPort: 123,
DstIP: NewIP(net.ParseIP("5.6.7.8")),
DstPort: 567,
}
got := q.String()
want := "TCP{1.2.3.4:123 > 5.6.7.8:567}"
if got != want {
t.Errorf("got %q; want %q", got, want)
}
allocs := testing.AllocsPerRun(1000, func() {
got = q.String()
})
if allocs != 1 {
t.Errorf("allocs = %v; want 1", allocs)
}
}

View File

@ -322,7 +322,7 @@ func (e *userspaceEngine) pinger(peerKey wgcfg.Key, ips []wgcfg.IP) {
return
}
for _, dstIP := range dstIPs {
b := packet.GenICMP(srcIP, dstIP, ipid, packet.EchoRequest, 0, payload)
b := packet.GenICMP(srcIP, dstIP, ipid, packet.ICMPEchoRequest, 0, payload)
e.tundev.InjectOutbound(b)
}
ipid++