From ff1d0aa027f9e8de36d8f4a4aba67f575534cd06 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 26 Aug 2024 22:21:14 -0700 Subject: [PATCH] tstest/natlab/vnet: start adding tests And refactor some of vnet.go for testability. The only behavioral change (with a new test) is that ethernet broadcasts no longer get sent back to the sender. Updates #13038 Change-Id: Ic2e7e7d6d8805b7b7f2b5c52c2c5ba97101cef14 Signed-off-by: Brad Fitzpatrick --- tstest/natlab/vnet/conf.go | 20 ++- tstest/natlab/vnet/vnet.go | 251 ++++++++++++++++++++------------ tstest/natlab/vnet/vnet_test.go | 232 +++++++++++++++++++++++++++++ 3 files changed, 403 insertions(+), 100 deletions(-) create mode 100644 tstest/natlab/vnet/vnet_test.go diff --git a/tstest/natlab/vnet/conf.go b/tstest/natlab/vnet/conf.go index 70c666f7b..42629c34e 100644 --- a/tstest/natlab/vnet/conf.go +++ b/tstest/natlab/vnet/conf.go @@ -6,7 +6,6 @@ package vnet import ( "cmp" "fmt" - "log" "net/netip" "os" "slices" @@ -61,6 +60,11 @@ func (c *Config) FirstNetwork() *Network { return c.networks[0] } +func nodeMac(n int) MAC { + // 52=TS then 0xcc for cccclient + return MAC{0x52, 0xcc, 0xcc, 0xcc, 0xcc, byte(n)} +} + // AddNode creates a new node in the world. // // The opts may be of the following types: @@ -70,10 +74,10 @@ func (c *Config) FirstNetwork() *Network { // On an error or unknown opt type, AddNode returns a // node with a carried error that gets returned later. func (c *Config) AddNode(opts ...any) *Node { - num := len(c.nodes) + num := len(c.nodes) + 1 n := &Node{ - num: num + 1, - mac: MAC{0x52, 0xcc, 0xcc, 0xcc, 0xcc, byte(num) + 1}, // 52=TS then 0xcc for ccclient + num: num, + mac: nodeMac(num), } c.nodes = append(c.nodes, n) for _, o := range opts { @@ -130,10 +134,10 @@ type TailscaledEnv struct { // On an error or unknown opt type, AddNetwork returns a // network with a carried error that gets returned later. func (c *Config) AddNetwork(opts ...any) *Network { - num := len(c.networks) + num := len(c.networks) + 1 n := &Network{ - num: num + 1, - mac: MAC{0x52, 0xee, 0xee, 0xee, 0xee, byte(num) + 1}, // 52=TS then 0xee for 'etwork + num: num, + mac: MAC{0x52, 0xee, 0xee, 0xee, 0xee, byte(num)}, // 52=TS then 0xee for 'etwork } c.networks = append(c.networks, n) for _, o := range opts { @@ -330,7 +334,7 @@ func (s *Server) initFromConfig(c *Config) error { lanIP4: conf.lanIP4, nodesByIP4: map[netip.Addr]*node{}, nodesByMAC: map[MAC]*node{}, - logf: logger.WithPrefix(log.Printf, fmt.Sprintf("[net-%v] ", conf.mac)), + logf: logger.WithPrefix(s.logf, fmt.Sprintf("[net-%v] ", conf.mac)), } netOfConf[conf] = n s.networks.Add(n) diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index a8e5b6ae5..cdc7adb30 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -9,12 +9,9 @@ package vnet // TODO: -// - [ ] port mapping actually working -// - [ ] conf to let you firewall things // - [ ] tests for NAT tables import ( - "bufio" "bytes" "context" "crypto/tls" @@ -23,7 +20,9 @@ import ( "errors" "fmt" "io" + "iter" "log" + "maps" "math/rand/v2" "net" "net/http" @@ -493,18 +492,22 @@ type portMapping struct { expiry time.Time } -type writerFunc func([]byte, *net.UnixAddr, int) +// writerFunc is a function that writes an Ethernet frame to a connected client. +// +// ethFrame is the Ethernet frame to write. +// +// interfaceIndexID is the interface ID for the pcap file. +type writerFunc func(dst vmClient, ethFrame []byte, interfaceIndexID int) -// Encapsulates both a write function, an optional outbound socket address -// for dgram mode and an interfaceID for packet captures. +// networkWriter are the arguments to a writerFunc and the writerFunc. type networkWriter struct { - writer writerFunc // Function to write packets to the network - addr *net.UnixAddr // Outbound socket address for dgram mode - interfaceID int // The interface ID of the src node (for writing pcaps) + writer writerFunc // Function to write packets to the network + c vmClient + interfaceID int // The interface ID of the src node (for writing pcaps) } -func (nw *networkWriter) write(b []byte) { - nw.writer(b, nw.addr, nw.interfaceID) +func (nw networkWriter) write(b []byte) { + nw.writer(nw.c, b, nw.interfaceID) } type network struct { @@ -540,19 +543,20 @@ type network struct { writers syncs.Map[MAC, networkWriter] // MAC -> to networkWriter for that MAC } -// Regsiters a writerFunc for a MAC address. -// raddr is and optional outbound socket address of the client interface for dgram mode. -// Pass nil for the writerFunc to deregister the writer. -func (n *network) registerWriter(mac MAC, raddr *net.UnixAddr, interfaceID int, wf writerFunc) { - if wf != nil { - n.writers.Store(mac, networkWriter{ - writer: wf, - addr: raddr, - interfaceID: interfaceID, - }) - } else { - n.writers.Delete(mac) +// registerWriter registers a client address with a MAC address. +func (n *network) registerWriter(mac MAC, c vmClient) { + nw := networkWriter{ + writer: n.s.writeEthernetFrameToVM, + c: c, } + if node, ok := n.s.nodeByMAC[mac]; ok { + nw.interfaceID = node.interfaceID + } + n.writers.Store(mac, nw) +} + +func (n *network) unregisterWriter(mac MAC) { + n.writers.Delete(mac) } func (n *network) MACOfIP(ip netip.Addr) (_ MAC, ok bool) { @@ -616,6 +620,8 @@ type Server struct { wg sync.WaitGroup blendReality bool + optLogf func(format string, args ...any) // or nil to use log.Printf + derpIPs set.Set[netip.Addr] nodes []*node @@ -627,12 +633,28 @@ type Server struct { derps []*derpServer pcapWriter *pcapWriter + // writeMu serializes all writes to VM clients. + writeMu sync.Mutex + scratch []byte + mu sync.Mutex agentConnWaiter map[*node]chan<- struct{} // signaled after added to set agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all agentDialer map[*node]DialFunc } +func (s *Server) logf(format string, args ...any) { + if s.optLogf != nil { + s.optLogf(format, args...) + } else { + log.Printf(format, args...) + } +} + +func (s *Server) SetLoggerForTest(logf func(format string, args ...any)) { + s.optLogf = logf +} + type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) var derpMap = &tailcfg.DERPMap{ @@ -713,6 +735,23 @@ func (s *Server) Close() { s.wg.Wait() } +// MACs returns the MAC addresses of the configured nodes. +func (s *Server) MACs() iter.Seq[MAC] { + return maps.Keys(s.nodeByMAC) +} + +func (s *Server) RegisterSinkForTest(mac MAC, fn func(eth []byte)) { + n, ok := s.nodeByMAC[mac] + if !ok { + log.Fatalf("RegisterSinkForTest: unknown MAC %v", mac) + } + n.net.writers.Store(mac, networkWriter{ + writer: func(_ vmClient, eth []byte, _ int) { + fn(eth) + }, + }) +} + func (s *Server) HWAddr(mac MAC) net.HardwareAddr { // TODO: cache return net.HardwareAddr(mac[:]) @@ -725,6 +764,53 @@ const ( ProtocolUnixDGRAM // for macOS Virtualization.Framework and VZFileHandleNetworkDeviceAttachment ) +func (s *Server) writeEthernetFrameToVM(c vmClient, ethPkt []byte, interfaceID int) { + s.writeMu.Lock() + defer s.writeMu.Unlock() + + if ethPkt == nil { + return + } + switch c.proto() { + case ProtocolQEMU: + s.scratch = binary.BigEndian.AppendUint32(s.scratch[:0], uint32(len(ethPkt))) + s.scratch = append(s.scratch, ethPkt...) + if _, err := c.uc.Write(s.scratch); err != nil { + log.Printf("Write pkt: %v", err) + } + + case ProtocolUnixDGRAM: + if _, err := c.uc.WriteToUnix(ethPkt, c.raddr); err != nil { + log.Printf("Write pkt : %v", err) + return + } + } + + must.Do(s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(ethPkt), + Length: len(ethPkt), + InterfaceIndex: interfaceID, + }, ethPkt)) +} + +// vmClient is a comparable value representing a connection from a VM, either a +// QEMU-style client (with streams over a Unix socket) or a datagram based +// client (such as macOS Virtualization.framework clients). +type vmClient struct { + uc *net.UnixConn + raddr *net.UnixAddr // nil for QEMU-style clients using streams; else datagram source +} + +func (c vmClient) proto() Protocol { + if c.raddr == nil { + return ProtocolQEMU + } + return ProtocolUnixDGRAM +} + +const ethernetHeaderLen = 14 + // Handles a single connection from a QEMU-style client or muxd connections for dgram mode func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) { if s.shuttingDown.Load() { @@ -738,51 +824,8 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) { log.Printf("Got conn %T %p", uc, uc) defer uc.Close() - bw := bufio.NewWriterSize(uc, 2<<10) - var writeMu sync.Mutex - writePkt := func(pkt []byte, raddr *net.UnixAddr, interfaceID int) { - if pkt == nil { - return - } - writeMu.Lock() - defer writeMu.Unlock() - switch proto { - case ProtocolQEMU: - hdr := binary.BigEndian.AppendUint32(bw.AvailableBuffer()[:0], uint32(len(pkt))) - if _, err := bw.Write(hdr); err != nil { - log.Printf("Write hdr: %v", err) - return - } - - if _, err := bw.Write(pkt); err != nil { - log.Printf("Write pkt: %v", err) - return - - } - case ProtocolUnixDGRAM: - if raddr == nil { - log.Printf("Write pkt: dgram mode write failure, no outbound socket address") - return - } - - if _, err := uc.WriteToUnix(pkt, raddr); err != nil { - log.Printf("Write pkt : %v", err) - return - } - } - - if err := bw.Flush(); err != nil { - log.Printf("Flush: %v", err) - } - must.Do(s.pcapWriter.WritePacket(gopacket.CaptureInfo{ - Timestamp: time.Now(), - CaptureLength: len(pkt), - Length: len(pkt), - InterfaceIndex: interfaceID, - }, pkt)) - } - buf := make([]byte, 16<<10) + didReg := map[MAC]bool{} for { var packetRaw []byte var raddr *net.UnixAddr @@ -817,40 +860,55 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) { } packetRaw = buf[4 : 4+n] // raw ethernet frame } + c := vmClient{uc, raddr} - packet := gopacket.NewPacket(packetRaw, layers.LayerTypeEthernet, gopacket.Lazy) - le, ok := packet.LinkLayer().(*layers.Ethernet) - if !ok || len(le.SrcMAC) != 6 || len(le.DstMAC) != 6 { - log.Printf("ignoring non-Ethernet packet: % 02x", packetRaw) + // For the first packet from a MAC, register a writerFunc to write to the VM. + if len(packetRaw) < ethernetHeaderLen { continue } - ep := EthernetPacket{le, packet} - - srcMAC := ep.SrcMAC() + srcMAC := MAC(packetRaw[6:12]) srcNode, ok := s.nodeByMAC[srcMAC] if !ok { - log.Printf("[conn %p] got frame from unknown MAC %v", uc, srcMAC) + log.Printf("[conn %p] got frame from unknown MAC %v", c.uc, srcMAC) continue } - - // Register a writer for the source MAC address if one doesn't exist. - if _, ok := srcNode.net.writers.Load(srcMAC); !ok { - log.Printf("[conn %p] Registering writer for MAC %v is node %v", uc, srcMAC, srcNode.lanIP) - srcNode.net.registerWriter(srcMAC, raddr, srcNode.interfaceID, writePkt) - defer srcNode.net.registerWriter(srcMAC, nil, 0, nil) - continue + if !didReg[srcMAC] { + didReg[srcMAC] = true + log.Printf("[conn %p] Registering writer for MAC %v, node %v", c.uc, srcMAC, srcNode.lanIP) + srcNode.net.registerWriter(srcMAC, c) + defer srcNode.net.unregisterWriter(srcMAC) } - must.Do(s.pcapWriter.WritePacket(gopacket.CaptureInfo{ - Timestamp: time.Now(), - CaptureLength: len(packetRaw), - Length: len(packetRaw), - InterfaceIndex: srcNode.interfaceID, - }, packetRaw)) - srcNode.net.HandleEthernetPacket(ep) + if err := s.handleEthernetFrameFromVM(packetRaw); err != nil { + srcNode.net.logf("handleEthernetFrameFromVM: [conn %p], %v", c.uc, err) + } } } +func (s *Server) handleEthernetFrameFromVM(packetRaw []byte) error { + packet := gopacket.NewPacket(packetRaw, layers.LayerTypeEthernet, gopacket.Lazy) + le, ok := packet.LinkLayer().(*layers.Ethernet) + if !ok || len(le.SrcMAC) != 6 || len(le.DstMAC) != 6 { + return fmt.Errorf("ignoring non-Ethernet packet: % 02x", packetRaw) + } + ep := EthernetPacket{le, packet} + + srcMAC := ep.SrcMAC() + srcNode, ok := s.nodeByMAC[srcMAC] + if !ok { + return fmt.Errorf("got frame from unknown MAC %v", srcMAC) + } + + must.Do(s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(packetRaw), + Length: len(packetRaw), + InterfaceIndex: srcNode.interfaceID, + }, packetRaw)) + srcNode.net.HandleEthernetPacket(ep) + return nil +} + func (s *Server) routeUDPPacket(up UDPPacket) { // Find which network owns this based on the destination IP // and all the known networks' wan IPs. @@ -896,8 +954,10 @@ func (n *network) writeEth(res []byte) bool { if dstMAC.IsBroadcast() { num := 0 n.writers.Range(func(mac MAC, nw networkWriter) bool { - num++ - nw.write(res) + if mac != srcMAC { + num++ + nw.write(res) + } return true }) return num > 0 @@ -922,6 +982,11 @@ func (n *network) writeEth(res []byte) bool { var ( macAllRouters = MAC{0: 0x33, 1: 0x33, 5: 0x02} + macBroadcast = MAC{0xff, 0xff, 0xff, 0xff, 0xff, 0xff} +) + +const ( + testingEthertype layers.EthernetType = 0x1234 ) func (n *network) HandleEthernetPacket(ep EthernetPacket) { @@ -943,6 +1008,8 @@ func (n *network) HandleEthernetPacket(ep EthernetPacket) { default: n.logf("Dropping non-IP packet: %v", ep.le.EthernetType) return + case 0x1234: + // Permitted for testing. Not a real ethertype. case layers.EthernetTypeARP: res, err := n.createARPResponse(packet) if err != nil { @@ -1309,7 +1376,7 @@ func (n *network) handleIPv6RouterSolicitation(ep EthernetPacket, rs *layers.ICM DstMAC: ep.SrcMAC().HWAddr(), EthernetType: layers.EthernetTypeIPv6, } - n.logf("sending IPv6 router advertisement to %v from %v", eth.SrcMAC, eth.DstMAC) + n.logf("sending IPv6 router advertisement to %v from %v", eth.DstMAC, eth.SrcMAC) ip := &layers.IPv6{ Version: 6, HopLimit: 255, diff --git a/tstest/natlab/vnet/vnet_test.go b/tstest/natlab/vnet/vnet_test.go new file mode 100644 index 000000000..d76718d31 --- /dev/null +++ b/tstest/natlab/vnet/vnet_test.go @@ -0,0 +1,232 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package vnet + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "net" + "net/netip" + "strings" + "testing" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +// TestPacketSideEffects tests that upon receiving certain +// packets, other packets and/or log statements are generated. +func TestPacketSideEffects(t *testing.T) { + type netTest struct { + name string + pkt []byte // to send + check func(*sideEffects) error + } + tests := []struct { + netName string // name of the Server returned by setup + setup func() (*Server, error) + tests []netTest // to run against setup's Server + }{ + { + netName: "basic", + setup: func() (*Server, error) { + var c Config + nw := c.AddNetwork("192.168.0.1/24") + c.AddNode(nw) + c.AddNode(nw) + return New(&c) + }, + tests: []netTest{ + { + name: "drop-rando-ethertype", + pkt: mkEth(nodeMac(2), nodeMac(1), 0x4321, []byte("hello")), + check: all( + logSubstr("Dropping non-IP packet"), + ), + }, + { + name: "dst-mac-between-nodes", + pkt: mkEth(nodeMac(2), nodeMac(1), testingEthertype, []byte("hello")), + check: all( + numPkts(1), + pktSubstr("SrcMAC=52:cc:cc:cc:cc:01 DstMAC=52:cc:cc:cc:cc:02 EthernetType=UnknownEthernetType"), + pktSubstr("Unable to decode EthernetType 4660"), + ), + }, + { + name: "broadcast-mac", + pkt: mkEth(macBroadcast, nodeMac(1), testingEthertype, []byte("hello")), + check: all( + numPkts(1), + pktSubstr("SrcMAC=52:cc:cc:cc:cc:01 DstMAC=ff:ff:ff:ff:ff:ff EthernetType=UnknownEthernetType"), + pktSubstr("Unable to decode EthernetType 4660"), + ), + }, + }, + }, + { + netName: "v6", + setup: func() (*Server, error) { + var c Config + c.AddNode(c.AddNetwork("2000:52::1/64")) + return New(&c) + }, + tests: []netTest{ + { + name: "router-solicit", + pkt: mkIPv6RouterSolicit(nodeMac(1), netip.MustParseAddr("fe80::50cc:ccff:fecc:cc01")), + check: all( + logSubstr("sending IPv6 router advertisement to 52:cc:cc:cc:cc:01 from 52:ee:ee:ee:ee:01"), + numPkts(1), + pktSubstr("TypeCode=RouterAdvertisement"), + pktSubstr("= ICMPv6RouterAdvertisement"), + pktSubstr("SrcMAC=52:ee:ee:ee:ee:01 DstMAC=52:cc:cc:cc:cc:01 EthernetType=IPv6"), + ), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.netName, func(t *testing.T) { + s, err := tt.setup() + if err != nil { + t.Fatal(err) + } + defer s.Close() + + for _, tt := range tt.tests { + t.Run(tt.name, func(t *testing.T) { + se := &sideEffects{} + s.SetLoggerForTest(se.logf) + for mac := range s.MACs() { + s.RegisterSinkForTest(mac, func(eth []byte) { + se.got = append(se.got, eth) + }) + } + + s.handleEthernetFrameFromVM(tt.pkt) + if tt.check != nil { + if err := tt.check(se); err != nil { + t.Fatal(err) + } + } + }) + } + }) + } + +} + +// mkEth encodes an ethernet frame with the given payload. +func mkEth(dst, src MAC, ethType layers.EthernetType, payload []byte) []byte { + ret := make([]byte, 0, 14+len(payload)) + ret = append(ret, dst.HWAddr()...) + ret = append(ret, src.HWAddr()...) + ret = binary.BigEndian.AppendUint16(ret, uint16(ethType)) + return append(ret, payload...) +} + +// mkIPv6RouterSolicit makes a IPv6 router solicitation packet +// ethernet frame. +func mkIPv6RouterSolicit(srcMAC MAC, srcIP netip.Addr) []byte { + ip := &layers.IPv6{ + Version: 6, + HopLimit: 255, + NextHeader: layers.IPProtocolICMPv6, + SrcIP: srcIP.AsSlice(), + DstIP: net.ParseIP("ff02::2"), // all routers + } + icmp := &layers.ICMPv6{ + TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeRouterSolicitation, 0), + } + + ra := &layers.ICMPv6RouterSolicitation{ + Options: []layers.ICMPv6Option{{ + Type: layers.ICMPv6OptSourceAddress, + Data: srcMAC.HWAddr(), + }}, + } + icmp.SetNetworkLayerForChecksum(ip) + buf := gopacket.NewSerializeBuffer() + options := gopacket.SerializeOptions{FixLengths: true, ComputeChecksums: true} + if err := gopacket.SerializeLayers(buf, options, ip, icmp, ra); err != nil { + panic(fmt.Sprintf("serializing ICMPv6 RA: %v", err)) + } + + return mkEth(macAllRouters, srcMAC, layers.EthernetTypeIPv6, buf.Bytes()) +} + +// sideEffects gathers side effects as a result of sending a packet and tests +// whether those effects were as desired. +type sideEffects struct { + logs []string + got [][]byte // ethernet packets received +} + +func (se *sideEffects) logf(format string, args ...any) { + se.logs = append(se.logs, fmt.Sprintf(format, args...)) +} + +// all aggregates several side effects checkers into one. +func all(checks ...func(*sideEffects) error) func(*sideEffects) error { + return func(se *sideEffects) error { + var errs []error + for _, check := range checks { + if err := check(se); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) + } +} + +// logSubstr returns a side effect checker func that checks +// whether a log statement was output containing substring sub. +func logSubstr(sub string) func(*sideEffects) error { + return func(se *sideEffects) error { + for _, log := range se.logs { + if strings.Contains(log, sub) { + return nil + } + } + return fmt.Errorf("expected log substring %q not found; log statements were %q", sub, se.logs) + } +} + +// pkgSubstr returns a side effect checker func that checks whether an ethernet +// packet was received that, once decoded and stringified by gopacket, contains +// substring sub. +func pktSubstr(sub string) func(*sideEffects) error { + return func(se *sideEffects) error { + var pkts bytes.Buffer + for i, pkt := range se.got { + pkt := gopacket.NewPacket(pkt, layers.LayerTypeEthernet, gopacket.Lazy) + got := pkt.String() + fmt.Fprintf(&pkts, "[pkt%d]:\n%s\n", i, got) + if strings.Contains(got, sub) { + return nil + } + } + return fmt.Errorf("packet summary with substring %q not found; packets were:\n%s", sub, pkts.Bytes()) + } +} + +// numPkts returns a side effect checker func that checks whether +// the received number of ethernet packets was the given number. +func numPkts(want int) func(*sideEffects) error { + return func(se *sideEffects) error { + if len(se.got) == want { + return nil + } + var pkts bytes.Buffer + for i, pkt := range se.got { + pkt := gopacket.NewPacket(pkt, layers.LayerTypeEthernet, gopacket.Lazy) + got := pkt.String() + fmt.Fprintf(&pkts, "[pkt%d]:\n%s\n", i, got) + } + return fmt.Errorf("got %d packets, want %d. packets were:\n%s", len(se.got), want, pkts.Bytes()) + } +}