diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt
index d8bca1130..b016d5727 100644
--- a/cmd/k8s-operator/depaware.txt
+++ b/cmd/k8s-operator/depaware.txt
@@ -310,6 +310,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/
   gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+
   gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+
 💣  gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+
+  gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack
   gvisor.dev/gvisor/pkg/tcpip/transport from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
   gvisor.dev/gvisor/pkg/tcpip/transport/icmp from tailscale.com/wgengine/netstack
   gvisor.dev/gvisor/pkg/tcpip/transport/internal/network from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt
index f6edbe7d7..42a1c4579 100644
--- a/cmd/tailscaled/depaware.txt
+++ b/cmd/tailscaled/depaware.txt
@@ -221,6 +221,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
   gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+
   gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+
 💣  gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+
+  gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack
   gvisor.dev/gvisor/pkg/tcpip/transport from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
   gvisor.dev/gvisor/pkg/tcpip/transport/icmp from tailscale.com/wgengine/netstack
   gvisor.dev/gvisor/pkg/tcpip/transport/internal/network from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+
diff --git a/go.mod b/go.mod
index 4ec098a71..ff8adf5b4 100644
--- a/go.mod
+++ b/go.mod
@@ -80,7 +80,7 @@ require (
 	github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4
 	github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1
 	github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6
-	github.com/tailscale/wireguard-go v0.0.0-20240724015428-60eeedfd624b
+	github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98
 	github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e
 	github.com/tc-hib/winres v0.2.1
 	github.com/tcnksm/go-httpstat v0.2.0
@@ -104,7 +104,7 @@
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
 	golang.zx2c4.com/wireguard/windows v0.5.3
 	gopkg.in/square/go-jose.v2 v2.6.0
-	gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3
+	gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987
 	honnef.co/go/tools v0.4.6
 	k8s.io/api v0.30.3
 	k8s.io/apimachinery v0.30.3
diff --git a/go.sum b/go.sum
index 9758ff885..e5f8a4673 100644
--- a/go.sum
+++ b/go.sum
@@ -934,8 +934,8 @@ github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:t
 github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ=
 github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M=
 github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y=
-github.com/tailscale/wireguard-go v0.0.0-20240724015428-60eeedfd624b h1:8U9NaPB32iFoNjJ+H/yPkAVqXw/dudtj+fLTE4edF+Q=
-github.com/tailscale/wireguard-go v0.0.0-20240724015428-60eeedfd624b/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4=
+github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98 h1:RNpJrXfI5u6e+uzyIzvmnXbhmhdRkVf//90sMBH3lso=
+github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4=
 github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek=
 github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg=
 github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA=
@@ -1491,8 +1491,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o=
 gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g=
-gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3 h1:/8/t5pz/mgdRXhYOIeqqYhFAQLE4DDGegc0Y4ZjyFJM=
-gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3/go.mod h1:NQHVAzMwvZ+Qe3ElSiHmq9RUm1MdNHpUZ52fiEqvn+0=
+gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8=
+gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU=
 honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
 honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
 honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go
index 44b07b1ff..24defba27 100644
--- a/net/tstun/wrap.go
+++ b/net/tstun/wrap.go
@@ -162,6 +162,10 @@ type Wrapper struct {
 	PreFilterPacketInboundFromWireGuard FilterFunc
 	// PostFilterPacketInboundFromWireGuard is the inbound filter function that runs after the main filter.
 	PostFilterPacketInboundFromWireGuard FilterFunc
+	// EndPacketVectorInboundFromWireGuardFlush is a function that runs after all packets in a given vector
+	// have been handled by all filters. Filters may queue packets for the purposes of GRO, requiring an
+	// explicit flush.
+	EndPacketVectorInboundFromWireGuardFlush func()
 	// PreFilterPacketOutboundToWireGuardNetstackIntercept is a filter function that runs before the main filter
 	// for packets from the local system. This filter is populated by netstack to hook
 	// packets that should be handled by netstack. If set, this filter runs before
@@ -1179,6 +1183,9 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
 			}
 		}
 	}
+	if t.EndPacketVectorInboundFromWireGuardFlush != nil {
+		t.EndPacketVectorInboundFromWireGuardFlush()
+	}
 	if t.disableFilter {
 		i = len(buffs)
 	}
diff --git a/wgengine/netstack/link_endpoint.go b/wgengine/netstack/link_endpoint.go
index 77d6075ca..22fdcf8b3 100644
--- a/wgengine/netstack/link_endpoint.go
+++ b/wgengine/netstack/link_endpoint.go
@@ -4,12 +4,19 @@
 package netstack
 
 import (
+	"bytes"
 	"context"
 	"sync"
 
+	"github.com/tailscale/wireguard-go/tun"
+	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/header/parse"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/stack/gro"
+	"tailscale.com/net/packet"
+	"tailscale.com/types/ipproto"
 )
 
 type queue struct {
@@ -79,34 +86,53 @@ var _ stack.GSOEndpoint = (*linkEndpoint)(nil)
 
 // linkEndpoint implements stack.LinkEndpoint and stack.GSOEndpoint. Outbound
 // packets written by gVisor towards Tailscale are stored in a channel.
-// Inbound is fed to gVisor via InjectInbound. This is loosely modeled after
-// gvisor.dev/pkg/tcpip/link/channel.Endpoint.
+// Inbound is fed to gVisor via injectInbound or enqueueGRO. This is loosely
+// modeled after gvisor.dev/pkg/tcpip/link/channel.Endpoint.
 type linkEndpoint struct {
-	LinkEPCapabilities stack.LinkEndpointCapabilities
-	SupportedGSOKind   stack.SupportedGSO
+	SupportedGSOKind stack.SupportedGSO
+	initGRO          initGRO
 
 	mu         sync.RWMutex // mu guards the following fields
 	dispatcher stack.NetworkDispatcher
 	linkAddr   tcpip.LinkAddress
 	mtu        uint32
+	gro        gro.GRO // mu only guards access to gro.Dispatcher
 
 	q *queue // outbound
 }
 
-func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress) *linkEndpoint {
-	return &linkEndpoint{
+// TODO(jwhited): move to linkEndpointOpts struct or similar.
+type initGRO bool
+
+const (
+	disableGRO initGRO = false
+	enableGRO  initGRO = true
+)
+
+func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress, gro initGRO) *linkEndpoint {
+	le := &linkEndpoint{
 		q: &queue{
 			c: make(chan *stack.PacketBuffer, size),
 		},
 		mtu:      mtu,
 		linkAddr: linkAddr,
 	}
+	le.initGRO = gro
+	le.gro.Init(bool(gro))
+	return le
 }
 
 // Close closes l. Further packet injections will return an error, and all
 // pending packets are discarded. Close may be called concurrently with
 // WritePackets.
 func (l *linkEndpoint) Close() {
+	l.mu.Lock()
+	if l.gro.Dispatcher != nil {
+		l.gro.Flush()
+	}
+	l.dispatcher = nil
+	l.gro.Dispatcher = nil
+	l.mu.Unlock()
 	l.q.Close()
 	l.Drain()
 }
@@ -132,19 +158,149 @@ func (l *linkEndpoint) Drain() int {
 	return c
 }
 
-// NumQueued returns the number of packet queued for outbound.
+// NumQueued returns the number of packets queued for outbound.
 func (l *linkEndpoint) NumQueued() int {
 	return l.q.Num()
 }
 
-// InjectInbound injects an inbound packet. If the endpoint is not attached, the
-// packet is not delivered.
-func (l *linkEndpoint) InjectInbound(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+// rxChecksumOffload validates IPv4, TCP, and UDP header checksums in p,
+// returning an equivalent *stack.PacketBuffer if they are valid, otherwise nil.
+// The set of headers validated covers where gVisor would perform validation if
+// !stack.PacketBuffer.RXChecksumValidated, i.e. it satisfies
+// stack.CapabilityRXChecksumOffload. Other protocols with checksum fields,
+// e.g. ICMP{v6}, are still validated by gVisor regardless of rx checksum
+// offloading capabilities.
+func rxChecksumOffload(p *packet.Parsed) *stack.PacketBuffer {
+	var (
+		pn        tcpip.NetworkProtocolNumber
+		csumStart int
+	)
+	buf := p.Buffer()
+
+	switch p.IPVersion {
+	case 4:
+		if len(buf) < header.IPv4MinimumSize {
+			return nil
+		}
+		csumStart = int((buf[0] & 0x0F) * 4)
+		if csumStart < header.IPv4MinimumSize || csumStart > header.IPv4MaximumHeaderSize || len(buf) < csumStart {
+			return nil
+		}
+		if ^tun.Checksum(buf[:csumStart], 0) != 0 {
+			return nil
+		}
+		pn = header.IPv4ProtocolNumber
+	case 6:
+		if len(buf) < header.IPv6FixedHeaderSize {
+			return nil
+		}
+		csumStart = header.IPv6FixedHeaderSize
+		pn = header.IPv6ProtocolNumber
+		if p.IPProto != ipproto.ICMPv6 && p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP {
+			// buf could have extension headers before a UDP or TCP header, but
+			// packet.Parsed.IPProto will be set to the ext header type, so we
+			// have to look deeper. We are still responsible for validating the
+			// L4 checksum in this case. So, make use of gVisor's existing
+			// extension header parsing via parse.IPv6() in order to unpack the
+			// L4 csumStart index. This is not particularly efficient as we have
+			// to allocate a short-lived stack.PacketBuffer that cannot be
+			// re-used. parse.IPv6() "consumes" the IPv6 headers, so we can't
+			// inject this stack.PacketBuffer into the stack at a later point.
+			packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+				Payload: buffer.MakeWithData(bytes.Clone(buf)),
+			})
+			defer packetBuf.DecRef()
+			// The rightmost bool returns false only if packetBuf is too short,
+			// which we've already accounted for above.
+			transportProto, _, _, _, _ := parse.IPv6(packetBuf)
+			if transportProto == header.TCPProtocolNumber || transportProto == header.UDPProtocolNumber {
+				csumLen := packetBuf.Data().Size()
+				if len(buf) < csumLen {
+					return nil
+				}
+				csumStart = len(buf) - csumLen
+				p.IPProto = ipproto.Proto(transportProto)
+			}
+		}
+	}
+
+	if p.IPProto == ipproto.TCP || p.IPProto == ipproto.UDP {
+		lenForPseudo := len(buf) - csumStart
+		csum := tun.PseudoHeaderChecksum(
+			uint8(p.IPProto),
+			p.Src.Addr().AsSlice(),
+			p.Dst.Addr().AsSlice(),
+			uint16(lenForPseudo))
+		csum = tun.Checksum(buf[csumStart:], csum)
+		if ^csum != 0 {
+			return nil
+		}
+	}
+
+	packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
+		Payload: buffer.MakeWithData(bytes.Clone(buf)),
+	})
+	packetBuf.NetworkProtocolNumber = pn
+	// Setting this is not technically required. gVisor overrides where
+	// stack.CapabilityRXChecksumOffload is advertised from Capabilities().
+	// https://github.com/google/gvisor/blob/64c016c92987cc04dfd4c7b091ddd21bdad875f8/pkg/tcpip/stack/nic.go#L763
+	// This is also why we offload for all packets since we cannot signal this
+	// per-packet.
+	packetBuf.RXChecksumValidated = true
+	return packetBuf
+}
+
+func (l *linkEndpoint) injectInbound(p *packet.Parsed) {
 	l.mu.RLock()
 	d := l.dispatcher
 	l.mu.RUnlock()
-	if d != nil {
-		d.DeliverNetworkPacket(protocol, pkt)
+	if d == nil {
+		return
+	}
+	pkt := rxChecksumOffload(p)
+	if pkt == nil {
+		return
+	}
+	d.DeliverNetworkPacket(pkt.NetworkProtocolNumber, pkt)
+	pkt.DecRef()
+}
+
+// enqueueGRO enqueues the provided packet for GRO. It may immediately deliver
+// it to the underlying stack.NetworkDispatcher depending on its contents and if
+// GRO was initialized via newLinkEndpoint. To explicitly flush previously
+// enqueued packets see flushGRO. enqueueGRO is not thread-safe and must not
+// be called concurrently with flushGRO.
+func (l *linkEndpoint) enqueueGRO(p *packet.Parsed) {
+	l.mu.RLock()
+	defer l.mu.RUnlock()
+	if l.gro.Dispatcher == nil {
+		return
+	}
+	pkt := rxChecksumOffload(p)
+	if pkt == nil {
+		return
+	}
+	// TODO(jwhited): gro.Enqueue() duplicates a lot of p.Decode().
+	// We may want to push stack.PacketBuffer further up as a
+	// replacement for packet.Parsed, or inversely push packet.Parsed
+	// down into refactored GRO logic.
+	l.gro.Enqueue(pkt)
+	pkt.DecRef()
+}
+
+// flushGRO flushes previously enqueueGRO'd packets to the underlying
+// stack.NetworkDispatcher. flushGRO is not thread-safe, and must not be
+// called concurrently with enqueueGRO.
+func (l *linkEndpoint) flushGRO() {
+	if !l.initGRO {
+		// If GRO was not initialized fast path return to avoid scanning GRO
+		// buckets (see l.gro.Flush()) that will always be empty.
+		return
+	}
+	l.mu.RLock()
+	defer l.mu.RUnlock()
+	if l.gro.Dispatcher != nil {
+		l.gro.Flush()
 	}
 }
 
@@ -154,6 +310,7 @@ func (l *linkEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
 	l.mu.Lock()
 	defer l.mu.Unlock()
 	l.dispatcher = dispatcher
+	l.gro.Dispatcher = dispatcher
 }
 
 // IsAttached implements stack.LinkEndpoint.IsAttached.
@@ -179,7 +336,9 @@ func (l *linkEndpoint) SetMTU(mtu uint32) {
 
 // Capabilities implements stack.LinkEndpoint.Capabilities.
 func (l *linkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
-	return l.LinkEPCapabilities
+	// We are required to offload RX checksum validation for the purposes of
+	// GRO.
+	return stack.CapabilityRXChecksumOffload
 }
 
 // GSOMaxSize implements stack.GSOEndpoint.
diff --git a/wgengine/netstack/link_endpoint_test.go b/wgengine/netstack/link_endpoint_test.go
new file mode 100644
index 000000000..97bc9e70a
--- /dev/null
+++ b/wgengine/netstack/link_endpoint_test.go
@@ -0,0 +1,112 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package netstack
+
+import (
+	"bytes"
+	"net/netip"
+	"testing"
+
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"tailscale.com/net/packet"
+)
+
+func Test_rxChecksumOffload(t *testing.T) {
+	payloadLen := 100
+
+	tcpFields := &header.TCPFields{
+		SrcPort:    1,
+		DstPort:    1,
+		SeqNum:     1,
+		AckNum:     1,
+		DataOffset: 20,
+		Flags:      header.TCPFlagAck | header.TCPFlagPsh,
+		WindowSize: 3000,
+	}
+	tcp4 := make([]byte, 20+20+payloadLen)
+	ipv4H := header.IPv4(tcp4)
+	ipv4H.Encode(&header.IPv4Fields{
+		SrcAddr:     tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.1").AsSlice()),
+		DstAddr:     tcpip.AddrFromSlice(netip.MustParseAddr("192.0.2.2").AsSlice()),
+		Protocol:    uint8(header.TCPProtocolNumber),
+		TTL:         64,
+		TotalLength: uint16(len(tcp4)),
+	})
+	ipv4H.SetChecksum(^ipv4H.CalculateChecksum())
+	tcpH := header.TCP(tcp4[20:])
+	tcpH.Encode(tcpFields)
+	pseudoCsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+payloadLen))
+	tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+
+	tcp6ExtHeader := make([]byte, 40+8+20+payloadLen)
+	ipv6H := header.IPv6(tcp6ExtHeader)
+	ipv6H.Encode(&header.IPv6Fields{
+		SrcAddr:           tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::1").AsSlice()),
+		DstAddr:           tcpip.AddrFromSlice(netip.MustParseAddr("2001:db8::2").AsSlice()),
+		TransportProtocol: 60, // really next header; destination options ext header
+		HopLimit:          64,
+		PayloadLength:     uint16(8 + 20 + payloadLen),
+	})
+	tcp6ExtHeader[40] = uint8(header.TCPProtocolNumber) // next header
+	tcp6ExtHeader[41] = 0 // length of ext header in 8-octet units, exclusive of first 8 octets.
+	// 42-47 options and padding
+	tcpH = header.TCP(tcp6ExtHeader[48:])
+	tcpH.Encode(tcpFields)
+	pseudoCsum = header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+payloadLen))
+	tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum))
+
+	tcp4InvalidCsum := make([]byte, len(tcp4))
+	copy(tcp4InvalidCsum, tcp4)
+	at := 20 + 16
+	tcp4InvalidCsum[at] = ^tcp4InvalidCsum[at]
+
+	tcp6ExtHeaderInvalidCsum := make([]byte, len(tcp6ExtHeader))
+	copy(tcp6ExtHeaderInvalidCsum, tcp6ExtHeader)
+	at = 40 + 8 + 16
+	tcp6ExtHeaderInvalidCsum[at] = ^tcp6ExtHeaderInvalidCsum[at]
+
+	tests := []struct {
+		name   string
+		input  []byte
+		wantPB bool
+	}{
+		{
+			"tcp4 packet valid csum",
+			tcp4,
+			true,
+		},
+		{
+			"tcp6 with ext header valid csum",
+			tcp6ExtHeader,
+			true,
+		},
+		{
+			"tcp4 packet invalid csum",
+			tcp4InvalidCsum,
+			false,
+		},
+		{
+			"tcp6 with ext header invalid csum",
+			tcp6ExtHeaderInvalidCsum,
+			false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			p := &packet.Parsed{}
+			p.Decode(tt.input)
+			got := rxChecksumOffload(p)
+			if tt.wantPB != (got != nil) {
+				t.Fatalf("wantPB = %v != (got != nil): %v", tt.wantPB, got != nil)
+			}
+			if tt.wantPB {
+				gotBuf := got.ToBuffer()
+				if !bytes.Equal(tt.input, gotBuf.Flatten()) {
+					t.Fatal("output packet unequal to input")
+				}
+			}
+		})
+	}
+}
diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go
index c575607cb..c7ec29437 100644
--- a/wgengine/netstack/netstack.go
+++ b/wgengine/netstack/netstack.go
@@ -5,7 +5,6 @@
 package netstack
 
 import (
-	"bytes"
 	"context"
 	"errors"
 	"expvar"
@@ -21,7 +20,6 @@ import (
 	"sync/atomic"
 	"time"
 
-	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/refs"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@@ -284,10 +282,13 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
 			return nil, fmt.Errorf("could not disable TCP RACK: %v", tcpipErr)
 		}
 	}
-	linkEP := newLinkEndpoint(512, uint32(tstun.DefaultTUNMTU()), "")
+	var linkEP *linkEndpoint
 	if runtime.GOOS == "linux" {
 		// TODO(jwhited): add Windows support https://github.com/tailscale/corp/issues/21874
+		linkEP = newLinkEndpoint(512, uint32(tstun.DefaultTUNMTU()), "", enableGRO)
 		linkEP.SupportedGSOKind = stack.HostGSOSupported
+	} else {
+		linkEP = newLinkEndpoint(512, uint32(tstun.DefaultTUNMTU()), "", disableGRO)
 	}
 	if tcpipProblem := ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
 		return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
 	}
@@ -336,6 +337,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi
 	ns.ctx, ns.ctxCancel = context.WithCancel(context.Background())
 	ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc())
 	ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound
+	ns.tundev.EndPacketVectorInboundFromWireGuardFlush = linkEP.flushGRO
 	ns.tundev.PreFilterPacketOutboundToWireGuardNetstackIntercept = ns.handleLocalPackets
 	stacksForMetrics.Store(ns, struct{}{})
 	return ns, nil
@@ -737,23 +739,11 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) filter.Re
 		// care about the packet; resume processing.
 		return filter.Accept
 	}
-
-	var pn tcpip.NetworkProtocolNumber
-	switch p.IPVersion {
-	case 4:
-		pn = header.IPv4ProtocolNumber
-	case 6:
-		pn = header.IPv6ProtocolNumber
-	}
 
 	if debugPackets {
 		ns.logf("[v2] service packet in (from %v): % x", p.Src, p.Buffer())
 	}
-	packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
-		Payload: buffer.MakeWithData(bytes.Clone(p.Buffer())),
-	})
-	ns.linkEP.InjectInbound(pn, packetBuf)
-	packetBuf.DecRef()
+	ns.linkEP.injectInbound(p)
 	return filter.DropSilently
 }
 
@@ -794,7 +784,7 @@ func (ns *Impl) DialContextUDP(ctx context.Context, ipp netip.AddrPort) (*gonet.
 func (ns *Impl) inject() {
 	for {
 		pkt := ns.linkEP.ReadContext(ns.ctx)
-		if pkt.IsNil() {
+		if pkt == nil {
 			if ns.ctx.Err() != nil {
 				// Return without logging.
 				return
@@ -1038,21 +1028,10 @@ func (ns *Impl) injectInbound(p *packet.Parsed, t *tstun.Wrapper) filter.Respons
 		return filter.DropSilently
 	}
 
-	var pn tcpip.NetworkProtocolNumber
-	switch p.IPVersion {
-	case 4:
-		pn = header.IPv4ProtocolNumber
-	case 6:
-		pn = header.IPv6ProtocolNumber
-	}
 	if debugPackets {
 		ns.logf("[v2] packet in (from %v): % x", p.Src, p.Buffer())
 	}
-	packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
-		Payload: buffer.MakeWithData(bytes.Clone(p.Buffer())),
-	})
-	ns.linkEP.InjectInbound(pn, packetBuf)
-	packetBuf.DecRef()
+	ns.linkEP.enqueueGRO(p)
 
 	// We've now delivered this to netstack, so we're done.
 	// Instead of returning a filter.Accept here (which would also