diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 0e40377c7..c7cb9b4da 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -17,7 +17,6 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa github.com/fxamacker/cbor/v2 from tailscale.com/tka github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/golang/protobuf/proto from github.com/matttproud/golang_protobuf_extensions/pbutil - github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header L github.com/google/nftables from tailscale.com/util/linuxfw L 💣 github.com/google/nftables/alignedbuff from github.com/google/nftables/xt L 💣 github.com/google/nftables/binaryutil from github.com/google/nftables+ @@ -79,22 +78,6 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa google.golang.org/protobuf/runtime/protoimpl from github.com/golang/protobuf/proto+ google.golang.org/protobuf/types/descriptorpb from google.golang.org/protobuf/reflect/protodesc google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ - gvisor.dev/gvisor/pkg/atomicbitops from gvisor.dev/gvisor/pkg/buffer+ - gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/buffer - 💣 gvisor.dev/gvisor/pkg/buffer from gvisor.dev/gvisor/pkg/tcpip+ - gvisor.dev/gvisor/pkg/context from gvisor.dev/gvisor/pkg/refs - 💣 gvisor.dev/gvisor/pkg/gohacks from gvisor.dev/gvisor/pkg/state/wire+ - gvisor.dev/gvisor/pkg/linewriter from gvisor.dev/gvisor/pkg/log - gvisor.dev/gvisor/pkg/log from gvisor.dev/gvisor/pkg/context+ - gvisor.dev/gvisor/pkg/refs from gvisor.dev/gvisor/pkg/buffer - 💣 gvisor.dev/gvisor/pkg/state from gvisor.dev/gvisor/pkg/atomicbitops+ - gvisor.dev/gvisor/pkg/state/wire from gvisor.dev/gvisor/pkg/state - 💣 gvisor.dev/gvisor/pkg/sync from gvisor.dev/gvisor/pkg/atomicbitops+ - gvisor.dev/gvisor/pkg/tcpip from gvisor.dev/gvisor/pkg/tcpip/header+ - gvisor.dev/gvisor/pkg/tcpip/checksum from gvisor.dev/gvisor/pkg/buffer+ - gvisor.dev/gvisor/pkg/tcpip/header from tailscale.com/net/packet - gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header - gvisor.dev/gvisor/pkg/waiter from gvisor.dev/gvisor/pkg/context+ nhooyr.io/websocket from tailscale.com/cmd/derper+ nhooyr.io/websocket/internal/errd from nhooyr.io/websocket nhooyr.io/websocket/internal/xsync from nhooyr.io/websocket diff --git a/cmd/derper/derper_test.go b/cmd/derper/derper_test.go index b70d18d1b..b1112adea 100644 --- a/cmd/derper/derper_test.go +++ b/cmd/derper/derper_test.go @@ -12,6 +12,7 @@ import ( "testing" "tailscale.com/net/stun" + "tailscale.com/tstest/deptest" ) func TestProdAutocertHostPolicy(t *testing.T) { @@ -128,3 +129,14 @@ func TestNoContent(t *testing.T) { }) } } + +func TestDeps(t *testing.T) { + deptest.DepChecker{ + BadDeps: map[string]string{ + "gvisor.dev/gvisor/pkg/buffer": "https://github.com/tailscale/tailscale/issues/9756", + "gvisor.dev/gvisor/pkg/cpuid": "https://github.com/tailscale/tailscale/issues/9756", + "gvisor.dev/gvisor/pkg/tcpip": "https://github.com/tailscale/tailscale/issues/9756", + "gvisor.dev/gvisor/pkg/tcpip/header": "https://github.com/tailscale/tailscale/issues/9756", + }, + }.Check(t) +} diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 715596c43..6a4415067 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -17,7 +17,6 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep github.com/fxamacker/cbor/v2 from tailscale.com/tka L 💣 github.com/godbus/dbus/v5 from github.com/coreos/go-systemd/v22/dbus github.com/golang/groupcache/lru from tailscale.com/net/dnscache - github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header L github.com/google/nftables from tailscale.com/util/linuxfw L 💣 github.com/google/nftables/alignedbuff from github.com/google/nftables/xt L 💣 github.com/google/nftables/binaryutil from github.com/google/nftables+ @@ -65,22 +64,6 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep go4.org/netipx from tailscale.com/wgengine/filter+ W 💣 golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/interfaces+ gopkg.in/yaml.v2 from sigs.k8s.io/yaml - gvisor.dev/gvisor/pkg/atomicbitops from gvisor.dev/gvisor/pkg/buffer+ - gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/buffer - 💣 gvisor.dev/gvisor/pkg/buffer from gvisor.dev/gvisor/pkg/tcpip+ - gvisor.dev/gvisor/pkg/context from gvisor.dev/gvisor/pkg/refs - 💣 gvisor.dev/gvisor/pkg/gohacks from gvisor.dev/gvisor/pkg/state/wire+ - gvisor.dev/gvisor/pkg/linewriter from gvisor.dev/gvisor/pkg/log - gvisor.dev/gvisor/pkg/log from gvisor.dev/gvisor/pkg/context+ - gvisor.dev/gvisor/pkg/refs from gvisor.dev/gvisor/pkg/buffer - 💣 gvisor.dev/gvisor/pkg/state from gvisor.dev/gvisor/pkg/atomicbitops+ - gvisor.dev/gvisor/pkg/state/wire from gvisor.dev/gvisor/pkg/state - 💣 gvisor.dev/gvisor/pkg/sync from gvisor.dev/gvisor/pkg/atomicbitops+ - gvisor.dev/gvisor/pkg/tcpip from gvisor.dev/gvisor/pkg/tcpip/header+ - gvisor.dev/gvisor/pkg/tcpip/checksum from gvisor.dev/gvisor/pkg/buffer+ - gvisor.dev/gvisor/pkg/tcpip/header from tailscale.com/net/packet - gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header - gvisor.dev/gvisor/pkg/waiter from gvisor.dev/gvisor/pkg/context+ k8s.io/client-go/util/homedir from tailscale.com/cmd/tailscale/cli nhooyr.io/websocket from tailscale.com/derp/derphttp+ nhooyr.io/websocket/internal/errd from nhooyr.io/websocket diff --git a/cmd/tailscale/tailscale_test.go b/cmd/tailscale/tailscale_test.go new file mode 100644 index 000000000..0a554aa82 --- /dev/null +++ b/cmd/tailscale/tailscale_test.go @@ -0,0 +1,21 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "testing" + + "tailscale.com/tstest/deptest" +) + +func TestDeps(t *testing.T) { + deptest.DepChecker{ + BadDeps: map[string]string{ + "gvisor.dev/gvisor/pkg/buffer": "https://github.com/tailscale/tailscale/issues/9756", + "gvisor.dev/gvisor/pkg/cpuid": "https://github.com/tailscale/tailscale/issues/9756", + "gvisor.dev/gvisor/pkg/tcpip": "https://github.com/tailscale/tailscale/issues/9756", + "gvisor.dev/gvisor/pkg/tcpip/header": "https://github.com/tailscale/tailscale/issues/9756", + }, + }.Check(t) +} diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index b1cb0d295..c99601c0e 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -276,6 +276,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de 💣 tailscale.com/net/netstat from tailscale.com/ipn/ipnauth+ tailscale.com/net/netutil from tailscale.com/ipn/ipnlocal+ tailscale.com/net/packet from tailscale.com/net/tstun+ + tailscale.com/net/packet/checksum from tailscale.com/net/tstun tailscale.com/net/ping from tailscale.com/net/netcheck+ tailscale.com/net/portmapper from tailscale.com/net/netcheck+ tailscale.com/net/proxymux from tailscale.com/cmd/tailscaled diff --git a/ipn/ipn_test.go b/ipn/ipn_test.go new file mode 100644 index 000000000..bc9632fbb --- /dev/null +++ b/ipn/ipn_test.go @@ -0,0 +1,21 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipn + +import ( + "testing" + + "tailscale.com/tstest/deptest" +) + +func TestDeps(t *testing.T) { + deptest.DepChecker{ + BadDeps: map[string]string{ + "gvisor.dev/gvisor/pkg/buffer": "https://github.com/tailscale/tailscale/issues/9756", + "gvisor.dev/gvisor/pkg/cpuid": "https://github.com/tailscale/tailscale/issues/9756", + "gvisor.dev/gvisor/pkg/tcpip": "https://github.com/tailscale/tailscale/issues/9756", + "gvisor.dev/gvisor/pkg/tcpip/header": "https://github.com/tailscale/tailscale/issues/9756", + }, + }.Check(t) +} diff --git a/net/packet/checksum/checksum.go b/net/packet/checksum/checksum.go new file mode 100644 index 000000000..c49ae3626 --- /dev/null +++ b/net/packet/checksum/checksum.go @@ -0,0 +1,197 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package checksum provides functions for updating checksums in parsed packets. +package checksum + +import ( + "encoding/binary" + "net/netip" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" +) + +// UpdateSrcAddr updates the source address in the packet buffer (e.g. during +// SNAT). It also updates the checksum. Currently (2023-09-22) only TCP/UDP/ICMP +// is supported. It panics if provided with an address in a different +// family to the parsed packet. +func UpdateSrcAddr(q *packet.Parsed, src netip.Addr) { + if src.Is6() && q.IPVersion != 6 { + panic("UpdateSrcAddr: cannot write IPv6 address to v4 packet") + } else if src.Is4() && q.IPVersion != 4 { + panic("UpdateSrcAddr: cannot write IPv4 address to v6 packet") + } + q.CaptureMeta.DidSNAT = true + q.CaptureMeta.OriginalSrc = q.Src + + old := q.Src.Addr() + q.Src = netip.AddrPortFrom(src, q.Src.Port()) + + b := q.Buffer() + if src.Is6() { + v6 := src.As16() + copy(b[8:24], v6[:]) + updateV6PacketChecksums(q, old, src) + } else { + v4 := src.As4() + copy(b[12:16], v4[:]) + updateV4PacketChecksums(q, old, src) + } +} + +// UpdateDstAddr updates the destination address in the packet buffer (e.g. during +// DNAT). It also updates the checksum. Currently (2022-12-10) only TCP/UDP/ICMP +// is supported. It panics if provided with an address in a different +// family to the parsed packet. +func UpdateDstAddr(q *packet.Parsed, dst netip.Addr) { + if dst.Is6() && q.IPVersion != 6 { + panic("UpdateDstAddr: cannot write IPv6 address to v4 packet") + } else if dst.Is4() && q.IPVersion != 4 { + panic("UpdateDstAddr: cannot write IPv4 address to v6 packet") + } + q.CaptureMeta.DidDNAT = true + q.CaptureMeta.OriginalDst = q.Dst + + old := q.Dst.Addr() + q.Dst = netip.AddrPortFrom(dst, q.Dst.Port()) + + b := q.Buffer() + if dst.Is6() { + v6 := dst.As16() + copy(b[24:36], v6[:]) + updateV6PacketChecksums(q, old, dst) + } else { + v4 := dst.As4() + copy(b[16:20], v4[:]) + updateV4PacketChecksums(q, old, dst) + } +} + +// updateV4PacketChecksums updates the checksums in the packet buffer. +// Currently (2023-03-01) only TCP/UDP/ICMP over IPv4 is supported. +// p is modified in place. +// If p.IPProto is unknown, only the IP header checksum is updated. +func updateV4PacketChecksums(p *packet.Parsed, old, new netip.Addr) { + if len(p.Buffer()) < 12 { + // Not enough space for an IPv4 header. + return + } + o4, n4 := old.As4(), new.As4() + + // First update the checksum in the IP header. + updateV4Checksum(p.Buffer()[10:12], o4[:], n4[:]) + + // Now update the transport layer checksums, where applicable. + tr := p.Transport() + switch p.IPProto { + case ipproto.UDP, ipproto.DCCP: + if len(tr) < header.UDPMinimumSize { + // Not enough space for a UDP header. + return + } + updateV4Checksum(tr[6:8], o4[:], n4[:]) + case ipproto.TCP: + if len(tr) < header.TCPMinimumSize { + // Not enough space for a TCP header. + return + } + updateV4Checksum(tr[16:18], o4[:], n4[:]) + case ipproto.GRE: + if len(tr) < 6 { + // Not enough space for a GRE header. + return + } + if tr[0] == 1 { // checksum present + updateV4Checksum(tr[4:6], o4[:], n4[:]) + } + case ipproto.SCTP, ipproto.ICMPv4: + // No transport layer update required. + } +} + +// updateV6PacketChecksums updates the checksums in the packet buffer. +// p is modified in place. +// If p.IPProto is unknown, no checksums are updated. +func updateV6PacketChecksums(p *packet.Parsed, old, new netip.Addr) { + if len(p.Buffer()) < 40 { + // Not enough space for an IPv6 header. + return + } + o6, n6 := tcpip.AddrFrom16Slice(old.AsSlice()), tcpip.AddrFrom16Slice(new.AsSlice()) + + // Now update the transport layer checksums, where applicable. + tr := p.Transport() + switch p.IPProto { + case ipproto.ICMPv6: + if len(tr) < header.ICMPv6MinimumSize { + return + } + header.ICMPv6(tr).UpdateChecksumPseudoHeaderAddress(o6, n6) + case ipproto.UDP, ipproto.DCCP: + if len(tr) < header.UDPMinimumSize { + return + } + header.UDP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true) + case ipproto.TCP: + if len(tr) < header.TCPMinimumSize { + return + } + header.TCP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true) + case ipproto.SCTP: + // No transport layer update required. + } +} + +// updateV4Checksum calculates and updates the checksum in the packet buffer for +// a change between old and new. The oldSum must point to the 16-bit checksum +// field in the packet buffer that holds the old checksum value, it will be +// updated in place. +// +// The old and new must be the same length, and must be an even number of bytes. +func updateV4Checksum(oldSum, old, new []byte) { + if len(old) != len(new) { + panic("old and new must be the same length") + } + if len(old)%2 != 0 { + panic("old and new must be of even length") + } + /* + RFC 1624 + Given the following notation: + + HC - old checksum in header + C - one's complement sum of old header + HC' - new checksum in header + C' - one's complement sum of new header + m - old value of a 16-bit field + m' - new value of a 16-bit field + + HC' = ~(C + (-m) + m') -- [Eqn. 3] + HC' = ~(~HC + ~m + m') + + This can be simplified to: + HC' = ~(C + ~m + m') -- [Eqn. 3] + HC' = ~C' + C' = C + ~m + m' + */ + + c := uint32(^binary.BigEndian.Uint16(oldSum)) + + cPrime := c + for len(new) > 0 { + mNot := uint32(^binary.BigEndian.Uint16(old[:2])) + mPrime := uint32(binary.BigEndian.Uint16(new[:2])) + cPrime += mPrime + mNot + new, old = new[2:], old[2:] + } + + // Account for overflows by adding the carry bits back into the sum. + for (cPrime >> 16) > 0 { + cPrime = cPrime&0xFFFF + cPrime>>16 + } + hcPrime := ^uint16(cPrime) + binary.BigEndian.PutUint16(oldSum, hcPrime) +} diff --git a/net/packet/checksum/checksum_test.go b/net/packet/checksum/checksum_test.go new file mode 100644 index 000000000..aeb030c1c --- /dev/null +++ b/net/packet/checksum/checksum_test.go @@ -0,0 +1,196 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package checksum + +import ( + "encoding/binary" + "net/netip" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/checksum" + "gvisor.dev/gvisor/pkg/tcpip/header" + "tailscale.com/net/packet" +) + +func fullHeaderChecksumV4(b []byte) uint16 { + s := uint32(0) + for i := 0; i < len(b); i += 2 { + if i == 10 { + // Skip checksum field. + continue + } + s += uint32(binary.BigEndian.Uint16(b[i : i+2])) + } + for s>>16 > 0 { + s = s&0xFFFF + s>>16 + } + return ^uint16(s) +} + +func TestHeaderChecksumsV4(t *testing.T) { + // This is not a good enough test, because it doesn't + // check the various packet types or the many edge cases + // of the checksum algorithm. But it's a start. + + tests := []struct { + name string + packet []byte + }{ + { + name: "ICMPv4", + packet: []byte{ + 0x45, 0x00, 0x00, 0x54, 0xb7, 0x96, 0x40, 0x00, 0x40, 0x01, 0x7a, 0x06, 0x64, 0x7f, 0x3f, 0x4c, 0x64, 0x40, 0x01, 0x01, 0x08, 0x00, 0x47, 0x1a, 0x00, 0x11, 0x01, 0xac, 0xcc, 0xf5, 0x95, 0x63, 0x00, 0x00, 0x00, 0x00, 0x8d, 0xfc, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + }, + }, + { + name: "TLS", + packet: []byte{ + 0x45, 0x00, 0x00, 0x3c, 0x54, 0x29, 0x40, 0x00, 0x40, 0x06, 0xb1, 0xac, 0x64, 0x42, 0xd4, 0x33, 0x64, 0x61, 0x98, 0x0f, 0xb1, 0x94, 0x01, 0xbb, 0x0a, 0x51, 0xce, 0x7c, 0x00, 0x00, 0x00, 0x00, 0xa0, 0x02, 0xfb, 0xe0, 0x38, 0xf6, 0x00, 0x00, 0x02, 0x04, 0x04, 0xd8, 0x04, 0x02, 0x08, 0x0a, 0x86, 0x2b, 0xcc, 0xd5, 0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x03, 0x07, + }, + }, + { + name: "DNS", + packet: []byte{ + 0x45, 0x00, 0x00, 0x74, 0xe2, 0x85, 0x00, 0x00, 0x40, 0x11, 0x96, 0xb5, 0x64, 0x64, 0x64, 0x64, 0x64, 0x42, 0xd4, 0x33, 0x00, 0x35, 0xec, 0x55, 0x00, 0x60, 0xd9, 0x19, 0xed, 0xfd, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x34, 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x01, 0x1e, 0x00, 0x0c, 0x07, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x01, 0x6c, 0xc0, 0x15, 0xc0, 0x31, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x1e, 0x00, 0x04, 0x8e, 0xfa, 0xbd, 0xce, 0x00, 0x00, 0x29, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + { + name: "DCCP", + packet: []byte{ + 0x45, 0x00, 0x00, 0x28, 0x15, 0x06, 0x40, 0x00, 0x40, 0x21, 0x5f, 0x2f, 0xc0, 0xa8, 0x01, 0x1f, 0xc9, 0x0b, 0x3b, 0xad, 0x80, 0x04, 0x13, 0x89, 0x05, 0x00, 0x08, 0xdb, 0x01, 0x00, 0x00, 0x04, 0x29, 0x01, 0x6d, 0xdc, 0x00, 0x00, 0x00, 0x00, + }, + }, + { + name: "SCTP", + packet: []byte{ + 0x45, 0x00, 0x00, 0x30, 0x09, 0xd9, 0x40, 0x00, 0xff, 0x84, 0x50, 0xe2, 0x0a, 0x1c, 0x06, 0x2c, 0x0a, 0x1c, 0x06, 0x2b, 0x0b, 0x80, 0x40, 0x00, 0x21, 0x44, 0x15, 0x23, 0x2b, 0xf2, 0x02, 0x4e, 0x03, 0x00, 0x00, 0x10, 0x28, 0x02, 0x43, 0x45, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + }, + // TODO(maisem): add test for GRE. + } + var p packet.Parsed + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p.Decode(tt.packet) + t.Log(p.String()) + UpdateSrcAddr(&p, netip.MustParseAddr("100.64.0.1")) + + got := binary.BigEndian.Uint16(tt.packet[10:12]) + want := fullHeaderChecksumV4(tt.packet[:20]) + if got != want { + t.Fatalf("got %x want %x", got, want) + } + + UpdateDstAddr(&p, netip.MustParseAddr("100.64.0.2")) + got = binary.BigEndian.Uint16(tt.packet[10:12]) + want = fullHeaderChecksumV4(tt.packet[:20]) + if got != want { + t.Fatalf("got %x want %x", got, want) + } + }) + } +} + +func TestNatChecksumsV6UDP(t *testing.T) { + a1, a2 := netip.MustParseAddr("a::1"), netip.MustParseAddr("b::1") + + // Make a fake UDP packet with 32 bytes of zeros as the datagram payload. + b := header.IPv6(make([]byte, header.IPv6MinimumSize+header.UDPMinimumSize+32)) + b.Encode(&header.IPv6Fields{ + PayloadLength: header.UDPMinimumSize + 32, + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 16, + SrcAddr: tcpip.AddrFrom16Slice(a1.AsSlice()), + DstAddr: tcpip.AddrFrom16Slice(a2.AsSlice()), + }) + udp := header.UDP(b[header.IPv6MinimumSize:]) + udp.Encode(&header.UDPFields{ + SrcPort: 42, + DstPort: 43, + Length: header.UDPMinimumSize + 32, + }) + xsum := header.PseudoHeaderChecksum( + header.UDPProtocolNumber, + tcpip.AddrFrom16Slice(a1.AsSlice()), + tcpip.AddrFrom16Slice(a2.AsSlice()), + uint16(header.UDPMinimumSize+32), + ) + xsum = checksum.Checksum(b.Payload()[header.UDPMinimumSize:], xsum) + udp.SetChecksum(^udp.CalculateChecksum(xsum)) + if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a1.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) { + t.Fatal("test broken; initial packet has incorrect checksum") + } + + // Parse the packet. + var p packet.Parsed + p.Decode(b) + t.Log(p.String()) + + // Update the source address of the packet to be the same as the dest. + UpdateSrcAddr(&p, a2) + if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) { + t.Fatal("incorrect checksum after updating source address") + } + + // Update the dest address of the packet to be the original source address. + UpdateDstAddr(&p, a1) + if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a1.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) { + t.Fatal("incorrect checksum after updating destination address") + } +} + +func TestNatChecksumsV6TCP(t *testing.T) { + a1, a2 := netip.MustParseAddr("a::1"), netip.MustParseAddr("b::1") + + // Make a fake TCP packet with no payload. + b := header.IPv6(make([]byte, header.IPv6MinimumSize+header.TCPMinimumSize)) + b.Encode(&header.IPv6Fields{ + PayloadLength: header.TCPMinimumSize, + TransportProtocol: header.TCPProtocolNumber, + HopLimit: 16, + SrcAddr: tcpip.AddrFrom16Slice(a1.AsSlice()), + DstAddr: tcpip.AddrFrom16Slice(a2.AsSlice()), + }) + tcp := header.TCP(b[header.IPv6MinimumSize:]) + tcp.Encode(&header.TCPFields{ + SrcPort: 42, + DstPort: 43, + SeqNum: 1, + AckNum: 2, + DataOffset: header.TCPMinimumSize, + Flags: 3, + WindowSize: 4, + Checksum: 0, + UrgentPointer: 5, + }) + xsum := header.PseudoHeaderChecksum( + header.TCPProtocolNumber, + tcpip.AddrFrom16Slice(a1.AsSlice()), + tcpip.AddrFrom16Slice(a2.AsSlice()), + uint16(header.TCPMinimumSize), + ) + tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) + + if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a1.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), 0, 0) { + t.Fatal("test broken; initial packet has incorrect checksum") + } + + // Parse the packet. + var p packet.Parsed + p.Decode(b) + t.Log(p.String()) + + // Update the source address of the packet to be the same as the dest. + UpdateSrcAddr(&p, a2) + if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), 0, 0) { + t.Fatal("incorrect checksum after updating source address") + } + + // Update the dest address of the packet to be the original source address. + UpdateDstAddr(&p, a1) + if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a1.AsSlice()), 0, 0) { + t.Fatal("incorrect checksum after updating destination address") + } +} diff --git a/net/packet/packet.go b/net/packet/packet.go index 9760005ec..9e837e575 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -10,8 +10,6 @@ import ( "net/netip" "strings" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" "tailscale.com/net/netaddr" "tailscale.com/types/ipproto" ) @@ -454,62 +452,6 @@ func (q *Parsed) IsEchoResponse() bool { } } -// UpdateSrcAddr updates the source address in the packet buffer (e.g. during -// SNAT). It also updates the checksum. Currently (2023-09-22) only TCP/UDP/ICMP -// is supported. It panics if provided with an address in a different -// family to the parsed packet. -func (q *Parsed) UpdateSrcAddr(src netip.Addr) { - if src.Is6() && q.IPVersion != 6 { - panic("UpdateSrcAddr: cannot write IPv6 address to v4 packet") - } else if src.Is4() && q.IPVersion != 4 { - panic("UpdateSrcAddr: cannot write IPv4 address to v6 packet") - } - q.CaptureMeta.DidSNAT = true - q.CaptureMeta.OriginalSrc = q.Src - - old := q.Src.Addr() - q.Src = netip.AddrPortFrom(src, q.Src.Port()) - - b := q.Buffer() - if src.Is6() { - v6 := src.As16() - copy(b[8:24], v6[:]) - updateV6PacketChecksums(q, old, src) - } else { - v4 := src.As4() - copy(b[12:16], v4[:]) - updateV4PacketChecksums(q, old, src) - } -} - -// UpdateDstAddr updates the destination address in the packet buffer (e.g. during -// DNAT). It also updates the checksum. Currently (2022-12-10) only TCP/UDP/ICMP -// is supported. It panics if provided with an address in a different -// family to the parsed packet. -func (q *Parsed) UpdateDstAddr(dst netip.Addr) { - if dst.Is6() && q.IPVersion != 6 { - panic("UpdateDstAddr: cannot write IPv6 address to v4 packet") - } else if dst.Is4() && q.IPVersion != 4 { - panic("UpdateDstAddr: cannot write IPv4 address to v6 packet") - } - q.CaptureMeta.DidDNAT = true - q.CaptureMeta.OriginalDst = q.Dst - - old := q.Dst.Addr() - q.Dst = netip.AddrPortFrom(dst, q.Dst.Port()) - - b := q.Buffer() - if dst.Is6() { - v6 := dst.As16() - copy(b[24:36], v6[:]) - updateV6PacketChecksums(q, old, dst) - } else { - v4 := dst.As4() - copy(b[16:20], v4[:]) - updateV4PacketChecksums(q, old, dst) - } -} - // EchoIDSeq extracts the identifier/sequence bytes from an ICMP Echo response, // and returns them as a uint32, used to lookup internally routed ICMP echo // responses. This function is intentionally lightweight as it is called on @@ -572,129 +514,3 @@ func withIP(ap netip.AddrPort, ip netip.Addr) netip.AddrPort { func withPort(ap netip.AddrPort, port uint16) netip.AddrPort { return netip.AddrPortFrom(ap.Addr(), port) } - -// updateV4PacketChecksums updates the checksums in the packet buffer. -// Currently (2023-03-01) only TCP/UDP/ICMP over IPv4 is supported. -// p is modified in place. -// If p.IPProto is unknown, only the IP header checksum is updated. -func updateV4PacketChecksums(p *Parsed, old, new netip.Addr) { - if len(p.Buffer()) < 12 { - // Not enough space for an IPv4 header. - return - } - o4, n4 := old.As4(), new.As4() - - // First update the checksum in the IP header. - updateV4Checksum(p.Buffer()[10:12], o4[:], n4[:]) - - // Now update the transport layer checksums, where applicable. - tr := p.Transport() - switch p.IPProto { - case ipproto.UDP, ipproto.DCCP: - if len(tr) < header.UDPMinimumSize { - // Not enough space for a UDP header. - return - } - updateV4Checksum(tr[6:8], o4[:], n4[:]) - case ipproto.TCP: - if len(tr) < header.TCPMinimumSize { - // Not enough space for a TCP header. - return - } - updateV4Checksum(tr[16:18], o4[:], n4[:]) - case ipproto.GRE: - if len(tr) < 6 { - // Not enough space for a GRE header. - return - } - if tr[0] == 1 { // checksum present - updateV4Checksum(tr[4:6], o4[:], n4[:]) - } - case ipproto.SCTP, ipproto.ICMPv4: - // No transport layer update required. - } -} - -// updateV6PacketChecksums updates the checksums in the packet buffer. -// p is modified in place. -// If p.IPProto is unknown, no checksums are updated. -func updateV6PacketChecksums(p *Parsed, old, new netip.Addr) { - if len(p.Buffer()) < 40 { - // Not enough space for an IPv6 header. - return - } - o6, n6 := tcpip.AddrFrom16Slice(old.AsSlice()), tcpip.AddrFrom16Slice(new.AsSlice()) - - // Now update the transport layer checksums, where applicable. - tr := p.Transport() - switch p.IPProto { - case ipproto.ICMPv6: - if len(tr) < header.ICMPv6MinimumSize { - return - } - header.ICMPv6(tr).UpdateChecksumPseudoHeaderAddress(o6, n6) - case ipproto.UDP, ipproto.DCCP: - if len(tr) < header.UDPMinimumSize { - return - } - header.UDP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true) - case ipproto.TCP: - if len(tr) < header.TCPMinimumSize { - return - } - header.TCP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true) - case ipproto.SCTP: - // No transport layer update required. - } -} - -// updateV4Checksum calculates and updates the checksum in the packet buffer for -// a change between old and new. The oldSum must point to the 16-bit checksum -// field in the packet buffer that holds the old checksum value, it will be -// updated in place. -// -// The old and new must be the same length, and must be an even number of bytes. -func updateV4Checksum(oldSum, old, new []byte) { - if len(old) != len(new) { - panic("old and new must be the same length") - } - if len(old)%2 != 0 { - panic("old and new must be of even length") - } - /* - RFC 1624 - Given the following notation: - - HC - old checksum in header - C - one's complement sum of old header - HC' - new checksum in header - C' - one's complement sum of new header - m - old value of a 16-bit field - m' - new value of a 16-bit field - - HC' = ~(C + (-m) + m') -- [Eqn. 3] - HC' = ~(~HC + ~m + m') - - This can be simplified to: - HC' = ~(C + ~m + m') -- [Eqn. 3] - HC' = ~C' - C' = C + ~m + m' - */ - - c := uint32(^binary.BigEndian.Uint16(oldSum)) - - cPrime := c - for len(new) > 0 { - mNot := uint32(^binary.BigEndian.Uint16(old[:2])) - mPrime := uint32(binary.BigEndian.Uint16(new[:2])) - cPrime += mPrime + mNot - new, old = new[2:], old[2:] - } - - // Account for overflows by adding the carry bits back into the sum. - for (cPrime >> 16) > 0 { - cPrime = cPrime&0xFFFF + cPrime>>16 - } - hcPrime := ^uint16(cPrime) - binary.BigEndian.PutUint16(oldSum, hcPrime) -} diff --git a/net/packet/packet_test.go b/net/packet/packet_test.go index 553fd11f4..9d6254f09 100644 --- a/net/packet/packet_test.go +++ b/net/packet/packet_test.go @@ -5,7 +5,6 @@ package packet import ( "bytes" - "encoding/binary" "encoding/hex" "net/netip" "reflect" @@ -13,9 +12,6 @@ import ( "testing" "unicode" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/checksum" - "gvisor.dev/gvisor/pkg/tcpip/header" "tailscale.com/tstest" "tailscale.com/types/ipproto" "tailscale.com/util/must" @@ -33,187 +29,6 @@ const ( Fragment = ipproto.Fragment ) -func fullHeaderChecksumV4(b []byte) uint16 { - s := uint32(0) - for i := 0; i < len(b); i += 2 { - if i == 10 { - // Skip checksum field. - continue - } - s += uint32(binary.BigEndian.Uint16(b[i : i+2])) - } - for s>>16 > 0 { - s = s&0xFFFF + s>>16 - } - return ^uint16(s) -} - -func TestHeaderChecksumsV4(t *testing.T) { - // This is not a good enough test, because it doesn't - // check the various packet types or the many edge cases - // of the checksum algorithm. But it's a start. - - tests := []struct { - name string - packet []byte - }{ - { - name: "ICMPv4", - packet: []byte{ - 0x45, 0x00, 0x00, 0x54, 0xb7, 0x96, 0x40, 0x00, 0x40, 0x01, 0x7a, 0x06, 0x64, 0x7f, 0x3f, 0x4c, 0x64, 0x40, 0x01, 0x01, 0x08, 0x00, 0x47, 0x1a, 0x00, 0x11, 0x01, 0xac, 0xcc, 0xf5, 0x95, 0x63, 0x00, 0x00, 0x00, 0x00, 0x8d, 0xfc, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, - }, - }, - { - name: "TLS", - packet: []byte{ - 0x45, 0x00, 0x00, 0x3c, 0x54, 0x29, 0x40, 0x00, 0x40, 0x06, 0xb1, 0xac, 0x64, 0x42, 0xd4, 0x33, 0x64, 0x61, 0x98, 0x0f, 0xb1, 0x94, 0x01, 0xbb, 0x0a, 0x51, 0xce, 0x7c, 0x00, 0x00, 0x00, 0x00, 0xa0, 0x02, 0xfb, 0xe0, 0x38, 0xf6, 0x00, 0x00, 0x02, 0x04, 0x04, 0xd8, 0x04, 0x02, 0x08, 0x0a, 0x86, 0x2b, 0xcc, 0xd5, 0x00, 0x00, 0x00, 0x00, 0x01, 0x03, 0x03, 0x07, - }, - }, - { - name: "DNS", - packet: []byte{ - 0x45, 0x00, 0x00, 0x74, 0xe2, 0x85, 0x00, 0x00, 0x40, 0x11, 0x96, 0xb5, 0x64, 0x64, 0x64, 0x64, 0x64, 0x42, 0xd4, 0x33, 0x00, 0x35, 0xec, 0x55, 0x00, 0x60, 0xd9, 0x19, 0xed, 0xfd, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x08, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x34, 0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, 0x01, 0x1e, 0x00, 0x0c, 0x07, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x01, 0x6c, 0xc0, 0x15, 0xc0, 0x31, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x1e, 0x00, 0x04, 0x8e, 0xfa, 0xbd, 0xce, 0x00, 0x00, 0x29, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - }, - { - name: "DCCP", - packet: []byte{ - 0x45, 0x00, 0x00, 0x28, 0x15, 0x06, 0x40, 0x00, 0x40, 0x21, 0x5f, 0x2f, 0xc0, 0xa8, 0x01, 0x1f, 0xc9, 0x0b, 0x3b, 0xad, 0x80, 0x04, 0x13, 0x89, 0x05, 0x00, 0x08, 0xdb, 0x01, 0x00, 0x00, 0x04, 0x29, 0x01, 0x6d, 0xdc, 0x00, 0x00, 0x00, 0x00, - }, - }, - { - name: "SCTP", - packet: []byte{ - 0x45, 0x00, 0x00, 0x30, 0x09, 0xd9, 0x40, 0x00, 0xff, 0x84, 0x50, 0xe2, 0x0a, 0x1c, 0x06, 0x2c, 0x0a, 0x1c, 0x06, 0x2b, 0x0b, 0x80, 0x40, 0x00, 0x21, 0x44, 0x15, 0x23, 0x2b, 0xf2, 0x02, 0x4e, 0x03, 0x00, 0x00, 0x10, 0x28, 0x02, 0x43, 0x45, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, - }, - }, - // TODO(maisem): add test for GRE. - } - var p Parsed - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p.Decode(tt.packet) - t.Log(p.String()) - p.UpdateSrcAddr(netip.MustParseAddr("100.64.0.1")) - - got := binary.BigEndian.Uint16(tt.packet[10:12]) - want := fullHeaderChecksumV4(tt.packet[:20]) - if got != want { - t.Fatalf("got %x want %x", got, want) - } - - p.UpdateDstAddr(netip.MustParseAddr("100.64.0.2")) - got = binary.BigEndian.Uint16(tt.packet[10:12]) - want = fullHeaderChecksumV4(tt.packet[:20]) - if got != want { - t.Fatalf("got %x want %x", got, want) - } - }) - } -} - -func TestNatChecksumsV6UDP(t *testing.T) { - a1, a2 := netip.MustParseAddr("a::1"), netip.MustParseAddr("b::1") - - // Make a fake UDP packet with 32 bytes of zeros as the datagram payload. - b := header.IPv6(make([]byte, header.IPv6MinimumSize+header.UDPMinimumSize+32)) - b.Encode(&header.IPv6Fields{ - PayloadLength: header.UDPMinimumSize + 32, - TransportProtocol: header.UDPProtocolNumber, - HopLimit: 16, - SrcAddr: tcpip.AddrFrom16Slice(a1.AsSlice()), - DstAddr: tcpip.AddrFrom16Slice(a2.AsSlice()), - }) - udp := header.UDP(b[header.IPv6MinimumSize:]) - udp.Encode(&header.UDPFields{ - SrcPort: 42, - DstPort: 43, - Length: header.UDPMinimumSize + 32, - }) - xsum := header.PseudoHeaderChecksum( - header.UDPProtocolNumber, - tcpip.AddrFrom16Slice(a1.AsSlice()), - tcpip.AddrFrom16Slice(a2.AsSlice()), - uint16(header.UDPMinimumSize+32), - ) - xsum = checksum.Checksum(b.Payload()[header.UDPMinimumSize:], xsum) - udp.SetChecksum(^udp.CalculateChecksum(xsum)) - if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a1.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) { - t.Fatal("test broken; initial packet has incorrect checksum") - } - - // Parse the packet. - var p Parsed - p.Decode(b) - t.Log(p.String()) - - // Update the source address of the packet to be the same as the dest. - p.UpdateSrcAddr(a2) - if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) { - t.Fatal("incorrect checksum after updating source address") - } - - // Update the dest address of the packet to be the original source address. - p.UpdateDstAddr(a1) - if !udp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a1.AsSlice()), checksum.Checksum(b.Payload()[header.UDPMinimumSize:], 0)) { - t.Fatal("incorrect checksum after updating destination address") - } -} - -func TestNatChecksumsV6TCP(t *testing.T) { - a1, a2 := netip.MustParseAddr("a::1"), netip.MustParseAddr("b::1") - - // Make a fake TCP packet with no payload. - b := header.IPv6(make([]byte, header.IPv6MinimumSize+header.TCPMinimumSize)) - b.Encode(&header.IPv6Fields{ - PayloadLength: header.TCPMinimumSize, - TransportProtocol: header.TCPProtocolNumber, - HopLimit: 16, - SrcAddr: tcpip.AddrFrom16Slice(a1.AsSlice()), - DstAddr: tcpip.AddrFrom16Slice(a2.AsSlice()), - }) - tcp := header.TCP(b[header.IPv6MinimumSize:]) - tcp.Encode(&header.TCPFields{ - SrcPort: 42, - DstPort: 43, - SeqNum: 1, - AckNum: 2, - DataOffset: header.TCPMinimumSize, - Flags: 3, - WindowSize: 4, - Checksum: 0, - UrgentPointer: 5, - }) - xsum := header.PseudoHeaderChecksum( - header.TCPProtocolNumber, - tcpip.AddrFrom16Slice(a1.AsSlice()), - tcpip.AddrFrom16Slice(a2.AsSlice()), - uint16(header.TCPMinimumSize), - ) - tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) - - if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a1.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), 0, 0) { - t.Fatal("test broken; initial packet has incorrect checksum") - } - - // Parse the packet. - var p Parsed - p.Decode(b) - t.Log(p.String()) - - // Update the source address of the packet to be the same as the dest. - p.UpdateSrcAddr(a2) - if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a2.AsSlice()), 0, 0) { - t.Fatal("incorrect checksum after updating source address") - } - - // Update the dest address of the packet to be the original source address. - p.UpdateDstAddr(a1) - if !tcp.IsChecksumValid(tcpip.AddrFrom16Slice(a2.AsSlice()), tcpip.AddrFrom16Slice(a1.AsSlice()), 0, 0) { - t.Fatal("incorrect checksum after updating destination address") - } -} - func mustIPPort(s string) netip.AddrPort { ipp, err := netip.ParseAddrPort(s) if err != nil { diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index bad262b6c..f64a133aa 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -25,6 +25,7 @@ import ( "tailscale.com/disco" "tailscale.com/net/connstats" "tailscale.com/net/packet" + "tailscale.com/net/packet/checksum" "tailscale.com/net/tsaddr" "tailscale.com/net/tstun/table" "tailscale.com/syncs" @@ -487,7 +488,7 @@ func (t *Wrapper) snat(p *packet.Parsed) { oldSrc := p.Src.Addr() newSrc := nc.selectSrcIP(oldSrc, p.Dst.Addr()) if oldSrc != newSrc { - p.UpdateSrcAddr(newSrc) + checksum.UpdateSrcAddr(p, newSrc) } } @@ -497,7 +498,7 @@ func (t *Wrapper) dnat(p *packet.Parsed) { oldDst := p.Dst.Addr() newDst := nc.mapDstIP(oldDst) if newDst != oldDst { - p.UpdateDstAddr(newDst) + checksum.UpdateDstAddr(p, newDst) } }