diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 45a7db444..73ef1bd29 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -78,6 +78,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http L github.com/aws/smithy-go/waiter from github.com/aws/aws-sdk-go-v2/service/ssm + github.com/bits-and-blooms/bitset from github.com/gaissmai/bart L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw L github.com/coreos/go-systemd/v22/dbus from tailscale.com/clientupdate LD 💣 github.com/creack/pty from tailscale.com/ssh/tailssh @@ -89,6 +90,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de LW 💣 github.com/digitalocean/go-smbios/smbios from tailscale.com/posture 💣 github.com/djherbis/times from tailscale.com/tailfs/tailfsimpl github.com/fxamacker/cbor/v2 from tailscale.com/tka + github.com/gaissmai/bart from tailscale.com/net/tstun W 💣 github.com/go-ole/go-ole from github.com/go-ole/go-ole/oleutil+ W 💣 github.com/go-ole/go-ole/oleutil from tailscale.com/wgengine/winnet L 💣 github.com/godbus/dbus/v5 from github.com/coreos/go-systemd/v22/dbus+ @@ -308,7 +310,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ 💣 tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+ - tailscale.com/net/tstun/table from tailscale.com/net/tstun tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ tailscale.com/paths from tailscale.com/client/tailscale+ 💣 tailscale.com/portlist from tailscale.com/ipn/ipnlocal @@ -324,7 +325,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/tailfs/tailfsimpl/compositedav from tailscale.com/tailfs/tailfsimpl tailscale.com/tailfs/tailfsimpl/dirfs from tailscale.com/tailfs/tailfsimpl+ tailscale.com/tailfs/tailfsimpl/shared from tailscale.com/tailfs/tailfsimpl+ - 💣 tailscale.com/tempfork/device from tailscale.com/net/tstun/table LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock tailscale.com/tka from tailscale.com/client/tailscale+ diff --git a/go.mod b/go.mod index 21816a21b..f255b5821 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/evanw/esbuild v0.19.11 github.com/frankban/quicktest v1.14.6 github.com/fxamacker/cbor/v2 v2.5.0 + github.com/gaissmai/bart v0.4.0 github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 github.com/go-logr/zapr v1.3.0 github.com/go-ole/go-ole v1.3.0 @@ -115,6 +116,7 @@ require ( require ( github.com/Microsoft/go-winio v0.6.1 // indirect + github.com/bits-and-blooms/bitset v1.13.0 // indirect github.com/cyphar/filepath-securejoin v0.2.4 // indirect github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14 // indirect github.com/dave/brenda v1.1.0 // indirect diff --git a/go.sum b/go.sum index f91d77833..30888dbcf 100644 --- a/go.sum +++ b/go.sum @@ -167,6 +167,8 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= +github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/bkielbasa/cyclop v1.2.0 h1:7Jmnh0yL2DjKfw28p86YTd/B4lRGcNuu12sKE35sM7A= github.com/bkielbasa/cyclop v1.2.0/go.mod h1:qOI0yy6A7dYC4Zgsa72Ppm9kONl0RoIlPbzot9mhmeI= github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb h1:m935MPodAbYS46DG4pJSv7WO+VECIWUQ7OJYSoTrMh4= @@ -294,6 +296,8 @@ github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADi github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo= github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA= +github.com/gaissmai/bart v0.4.0 h1:ImIFoETsNMBzUr21tMGD82GQIwAb555fI6uxEyCHBTI= +github.com/gaissmai/bart v0.4.0/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/github/fakeca v0.1.0 h1:Km/MVOFvclqxPM9dZBC4+QE564nU4gz4iZ0D9pMw28I= github.com/github/fakeca v0.1.0/go.mod h1:+bormgoGMMuamOscx7N91aOuUST7wdaJ2rNjeohylyo= github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY= diff --git a/licenses/tailscale.md b/licenses/tailscale.md index 72799d53e..02d41a512 100644 --- a/licenses/tailscale.md +++ b/licenses/tailscale.md @@ -111,5 +111,5 @@ Some packages may only be included on certain architectures or operating systems - [sigs.k8s.io/yaml/goyaml.v2](https://pkg.go.dev/sigs.k8s.io/yaml/goyaml.v2) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/goyaml.v2/LICENSE)) - [software.sslmate.com/src/go-pkcs12](https://pkg.go.dev/software.sslmate.com/src/go-pkcs12) ([BSD-3-Clause](https://github.com/SSLMate/go-pkcs12/blob/v0.4.0/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) - - [tailscale.com/tempfork/device](https://pkg.go.dev/tailscale.com/tempfork/device) ([MIT](https://github.com/tailscale/tailscale/blob/HEAD/tempfork/device/LICENSE)) + - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart?tab=MIT-1-ov-file#readme)) - [tailscale.com/tempfork/gliderlabs/ssh](https://pkg.go.dev/tailscale.com/tempfork/gliderlabs/ssh) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/tempfork/gliderlabs/ssh/LICENSE)) diff --git a/net/tstun/table/table.go b/net/tstun/table/table.go deleted file mode 100644 index ad359587a..000000000 --- a/net/tstun/table/table.go +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package table provides a Routing Table implementation which allows -// looking up the peer that should be used to route a given IP address. -package table - -import ( - "net/netip" - - "tailscale.com/tempfork/device" - "tailscale.com/types/key" - "tailscale.com/util/mak" -) - -// RoutingTableBuilder is a builder for a RoutingTable. -// It is not safe for concurrent use. -type RoutingTableBuilder struct { - // peers is a map from node public key to the peer that owns that key. - // It is only used to handle insertions and deletions. - peers map[key.NodePublic]*device.Peer - - // prefixTrie is a trie of prefixes which facilitates looking up the - // peer that owns a given IP address. - prefixTrie *device.AllowedIPs -} - -// Remove removes the given peer from the routing table. -func (t *RoutingTableBuilder) Remove(peer key.NodePublic) { - p, ok := t.peers[peer] - if !ok { - return - } - t.prefixTrie.RemoveByPeer(p) - delete(t.peers, peer) -} - -// InsertOrReplace inserts the given peer and prefixes into the routing table. -func (t *RoutingTableBuilder) InsertOrReplace(peer key.NodePublic, pfxs ...netip.Prefix) { - p, ok := t.peers[peer] - if !ok { - p = device.NewPeer(peer) - mak.Set(&t.peers, peer, p) - } else { - t.prefixTrie.RemoveByPeer(p) - } - if len(pfxs) == 0 { - return - } - if t.prefixTrie == nil { - t.prefixTrie = new(device.AllowedIPs) - } - for _, pfx := range pfxs { - t.prefixTrie.Insert(pfx, p) - } -} - -// Build returns a RoutingTable that can be used to look up peers. -// Build resets the RoutingTableBuilder to its zero value. -func (t *RoutingTableBuilder) Build() *RoutingTable { - pt := t.prefixTrie - t.prefixTrie = nil - t.peers = nil - return &RoutingTable{ - prefixTrie: pt, - } -} - -// RoutingTable provides a mapping from IP addresses to peers identified by -// their public node key. It is used to find the peer that should be used to -// route a given IP address. -// It is immutable after creation. -// -// It is safe for concurrent use. -type RoutingTable struct { - prefixTrie *device.AllowedIPs -} - -// Lookup returns the peer that would be used to route the given IP address. -// If no peer is found, Lookup returns the zero value. -func (t *RoutingTable) Lookup(ip netip.Addr) (_ key.NodePublic, ok bool) { - if t == nil { - return key.NodePublic{}, false - } - p := t.prefixTrie.Lookup(ip.AsSlice()) - if p == nil { - return key.NodePublic{}, false - } - return p.Key(), true -} diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 84d0c0821..0ef8d1baa 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "go4.org/mem" @@ -27,7 +28,6 @@ import ( "tailscale.com/net/packet" "tailscale.com/net/packet/checksum" "tailscale.com/net/tsaddr" - "tailscale.com/net/tstun/table" "tailscale.com/syncs" "tailscale.com/tstime/mono" "tailscale.com/types/ipproto" @@ -611,15 +611,13 @@ type natFamilyConfig struct { // peers will use to connect to this node. listenAddrs views.Map[netip.Addr, struct{}] // masqAddr -> struct{} - // dstMasqAddrs is map of dst addresses to their respective MasqueradeAsIP - // addresses. The MasqueradeAsIP address is the address that should be used - // as the source address for packets to dst. - dstMasqAddrs views.Map[key.NodePublic, netip.Addr] // dst -> masqAddr + // dstMasqAddrs is the routing table used to map a given dst IP to the + // respective MasqueradeAsIP address. The MasqueradeAsIP address is the + // address that should be used as the source address for packets to dst. + dstMasqAddrs *bart.Table[netip.Addr] - // dstAddrToPeerKeyMapper is the routing table used to map a given dst IP to - // the peer key responsible for that IP. - // It only contains peers that require a MasqueradeAsIP address. - dstAddrToPeerKeyMapper *table.RoutingTable + // masqAddrCounts is a count of peers by MasqueradeAsIP. + masqAddrCounts map[netip.Addr]int } func (c *natFamilyConfig) String() string { @@ -640,15 +638,10 @@ func (c *natFamilyConfig) String() string { i++ return true }) - count := map[netip.Addr]int{} - c.dstMasqAddrs.Range(func(_ key.NodePublic, v netip.Addr) bool { - count[v]++ - return true - }) i = 0 b.WriteString("], dstMasqAddrs: [") - for k, v := range count { + for k, v := range c.masqAddrCounts { if i > 0 { b.WriteString(", ") } @@ -682,14 +675,11 @@ func (c *natFamilyConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr { if oldSrc != c.nativeAddr { return oldSrc } - p, ok := c.dstAddrToPeerKeyMapper.Lookup(dst) + eip, ok := c.dstMasqAddrs.Get(dst) if !ok { return oldSrc } - if eip, ok := c.dstMasqAddrs.GetOk(p); ok { - return eip - } - return oldSrc + return eip } // natConfigFromWGConfig generates a natFamilyConfig from nm, @@ -712,9 +702,9 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami } var ( - rt table.RoutingTableBuilder - dstMasqAddrs map[key.NodePublic]netip.Addr - listenAddrs set.Set[netip.Addr] + rt bart.Table[netip.Addr] + masqAddrCounts = map[netip.Addr]int{} + listenAddrs set.Set[netip.Addr] ) // When using an exit node that requires masquerading, we need to @@ -747,17 +737,20 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami } else { continue } - rt.InsertOrReplace(p.PublicKey, p.AllowedIPs...) - mak.Set(&dstMasqAddrs, p.PublicKey, addrToUse) + + masqAddrCounts[addrToUse]++ + for _, ip := range p.AllowedIPs { + rt.Insert(ip, addrToUse) + } } - if len(listenAddrs) == 0 && len(dstMasqAddrs) == 0 { + if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 { return nil } return &natFamilyConfig{ - nativeAddr: nativeAddr, - listenAddrs: views.MapOf(listenAddrs), - dstMasqAddrs: views.MapOf(dstMasqAddrs), - dstAddrToPeerKeyMapper: rt.Build(), + nativeAddr: nativeAddr, + listenAddrs: views.MapOf(listenAddrs), + dstMasqAddrs: &rt, + masqAddrCounts: masqAddrCounts, } } diff --git a/tempfork/device/LICENSE b/tempfork/device/LICENSE deleted file mode 100644 index a3fdf73d5..000000000 --- a/tempfork/device/LICENSE +++ /dev/null @@ -1,17 +0,0 @@ -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies -of the Software, and to permit persons to whom the Software is furnished to do -so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. \ No newline at end of file diff --git a/tempfork/device/README.md b/tempfork/device/README.md deleted file mode 100644 index 945e25527..000000000 --- a/tempfork/device/README.md +++ /dev/null @@ -1,3 +0,0 @@ -This is a fork of golang.zx2c4.com/wireguard/device that only keeps the bare -minimum data structures required for AllowedIPs. It is meant to be short lived -until we replace it with our version of a routing table. diff --git a/tempfork/device/allowedips.go b/tempfork/device/allowedips.go deleted file mode 100644 index fa46f97c1..000000000 --- a/tempfork/device/allowedips.go +++ /dev/null @@ -1,294 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "container/list" - "encoding/binary" - "errors" - "math/bits" - "net" - "net/netip" - "sync" - "unsafe" -) - -type parentIndirection struct { - parentBit **trieEntry - parentBitType uint8 -} - -type trieEntry struct { - peer *Peer - child [2]*trieEntry - parent parentIndirection - cidr uint8 - bitAtByte uint8 - bitAtShift uint8 - bits []byte - perPeerElem *list.Element -} - -func commonBits(ip1, ip2 []byte) uint8 { - size := len(ip1) - if size == net.IPv4len { - a := binary.BigEndian.Uint32(ip1) - b := binary.BigEndian.Uint32(ip2) - x := a ^ b - return uint8(bits.LeadingZeros32(x)) - } else if size == net.IPv6len { - a := binary.BigEndian.Uint64(ip1) - b := binary.BigEndian.Uint64(ip2) - x := a ^ b - if x != 0 { - return uint8(bits.LeadingZeros64(x)) - } - a = binary.BigEndian.Uint64(ip1[8:]) - b = binary.BigEndian.Uint64(ip2[8:]) - x = a ^ b - return 64 + uint8(bits.LeadingZeros64(x)) - } else { - panic("Wrong size bit string") - } -} - -func (node *trieEntry) addToPeerEntries() { - node.perPeerElem = node.peer.trieEntries.PushBack(node) -} - -func (node *trieEntry) removeFromPeerEntries() { - if node.perPeerElem != nil { - node.peer.trieEntries.Remove(node.perPeerElem) - node.perPeerElem = nil - } -} - -func (node *trieEntry) choose(ip []byte) byte { - return (ip[node.bitAtByte] >> node.bitAtShift) & 1 -} - -func (node *trieEntry) maskSelf() { - mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) - for i := 0; i < len(mask); i++ { - node.bits[i] &= mask[i] - } -} - -func (node *trieEntry) zeroizePointers() { - // Make the garbage collector's life slightly easier - node.peer = nil - node.child[0] = nil - node.child[1] = nil - node.parent.parentBit = nil -} - -func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { - for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { - parent = node - if parent.cidr == cidr { - exact = true - return - } - bit := node.choose(ip) - node = node.child[bit] - } - return -} - -func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { - if *trie.parentBit == nil { - node := &trieEntry{ - peer: peer, - parent: trie, - bits: ip, - cidr: cidr, - bitAtByte: cidr / 8, - bitAtShift: 7 - (cidr % 8), - } - node.maskSelf() - node.addToPeerEntries() - *trie.parentBit = node - return - } - node, exact := (*trie.parentBit).nodePlacement(ip, cidr) - if exact { - node.removeFromPeerEntries() - node.peer = peer - node.addToPeerEntries() - return - } - - newNode := &trieEntry{ - peer: peer, - bits: ip, - cidr: cidr, - bitAtByte: cidr / 8, - bitAtShift: 7 - (cidr % 8), - } - newNode.maskSelf() - newNode.addToPeerEntries() - - var down *trieEntry - if node == nil { - down = *trie.parentBit - } else { - bit := node.choose(ip) - down = node.child[bit] - if down == nil { - newNode.parent = parentIndirection{&node.child[bit], bit} - node.child[bit] = newNode - return - } - } - common := commonBits(down.bits, ip) - if common < cidr { - cidr = common - } - parent := node - - if newNode.cidr == cidr { - bit := newNode.choose(down.bits) - down.parent = parentIndirection{&newNode.child[bit], bit} - newNode.child[bit] = down - if parent == nil { - newNode.parent = trie - *trie.parentBit = newNode - } else { - bit := parent.choose(newNode.bits) - newNode.parent = parentIndirection{&parent.child[bit], bit} - parent.child[bit] = newNode - } - return - } - - node = &trieEntry{ - bits: append([]byte{}, newNode.bits...), - cidr: cidr, - bitAtByte: cidr / 8, - bitAtShift: 7 - (cidr % 8), - } - node.maskSelf() - - bit := node.choose(down.bits) - down.parent = parentIndirection{&node.child[bit], bit} - node.child[bit] = down - bit = node.choose(newNode.bits) - newNode.parent = parentIndirection{&node.child[bit], bit} - node.child[bit] = newNode - if parent == nil { - node.parent = trie - *trie.parentBit = node - } else { - bit := parent.choose(node.bits) - node.parent = parentIndirection{&parent.child[bit], bit} - parent.child[bit] = node - } -} - -func (node *trieEntry) lookup(ip []byte) *Peer { - var found *Peer - size := uint8(len(ip)) - for node != nil && commonBits(node.bits, ip) >= node.cidr { - if node.peer != nil { - found = node.peer - } - if node.bitAtByte == size { - break - } - bit := node.choose(ip) - node = node.child[bit] - } - return found -} - -type AllowedIPs struct { - IPv4 *trieEntry - IPv6 *trieEntry - mutex sync.RWMutex -} - -func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { - table.mutex.RLock() - defer table.mutex.RUnlock() - - for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { - node := elem.Value.(*trieEntry) - a, _ := netip.AddrFromSlice(node.bits) - if !cb(netip.PrefixFrom(a, int(node.cidr))) { - return - } - } -} - -func (table *AllowedIPs) RemoveByPeer(peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - var next *list.Element - for elem := peer.trieEntries.Front(); elem != nil; elem = next { - next = elem.Next() - node := elem.Value.(*trieEntry) - - node.removeFromPeerEntries() - node.peer = nil - if node.child[0] != nil && node.child[1] != nil { - continue - } - bit := 0 - if node.child[0] == nil { - bit = 1 - } - child := node.child[bit] - if child != nil { - child.parent = node.parent - } - *node.parent.parentBit = child - if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { - node.zeroizePointers() - continue - } - parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) - if parent.peer != nil { - node.zeroizePointers() - continue - } - child = parent.child[node.parent.parentBitType^1] - if child != nil { - child.parent = parent.parent - } - *parent.parent.parentBit = child - node.zeroizePointers() - parent.zeroizePointers() - } -} - -func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - if prefix.Addr().Is6() { - ip := prefix.Addr().As16() - parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) - } else if prefix.Addr().Is4() { - ip := prefix.Addr().As4() - parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) - } else { - panic(errors.New("inserting unknown address type")) - } -} - -func (table *AllowedIPs) Lookup(ip []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - switch len(ip) { - case net.IPv6len: - return table.IPv6.lookup(ip) - case net.IPv4len: - return table.IPv4.lookup(ip) - default: - panic(errors.New("looking up unknown address type")) - } -} diff --git a/tempfork/device/allowedips_rand_test.go b/tempfork/device/allowedips_rand_test.go deleted file mode 100644 index 07065c30a..000000000 --- a/tempfork/device/allowedips_rand_test.go +++ /dev/null @@ -1,141 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "math/rand" - "net" - "net/netip" - "sort" - "testing" -) - -const ( - NumberOfPeers = 100 - NumberOfPeerRemovals = 4 - NumberOfAddresses = 250 - NumberOfTests = 10000 -) - -type SlowNode struct { - peer *Peer - cidr uint8 - bits []byte -} - -type SlowRouter []*SlowNode - -func (r SlowRouter) Len() int { - return len(r) -} - -func (r SlowRouter) Less(i, j int) bool { - return r[i].cidr > r[j].cidr -} - -func (r SlowRouter) Swap(i, j int) { - r[i], r[j] = r[j], r[i] -} - -func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { - for _, t := range r { - if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { - t.peer = peer - t.bits = addr - return r - } - } - r = append(r, &SlowNode{ - cidr: cidr, - bits: addr, - peer: peer, - }) - sort.Sort(r) - return r -} - -func (r SlowRouter) Lookup(addr []byte) *Peer { - for _, t := range r { - common := commonBits(t.bits, addr) - if common >= t.cidr { - return t.peer - } - } - return nil -} - -func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { - n := 0 - for _, x := range r { - if x.peer != peer { - r[n] = x - n++ - } - } - return r[:n] -} - -func TestTrieRandom(t *testing.T) { - var slow4, slow6 SlowRouter - var peers []*Peer - var allowedIPs AllowedIPs - - rand.Seed(1) - - for n := 0; n < NumberOfPeers; n++ { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n++ { - var addr4 [4]byte - rand.Read(addr4[:]) - cidr := uint8(rand.Intn(32) + 1) - index := rand.Intn(NumberOfPeers) - allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) - slow4 = slow4.Insert(addr4[:], cidr, peers[index]) - - var addr6 [16]byte - rand.Read(addr6[:]) - cidr = uint8(rand.Intn(128) + 1) - index = rand.Intn(NumberOfPeers) - allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) - slow6 = slow6.Insert(addr6[:], cidr, peers[index]) - } - - var p int - for p = 0; ; p++ { - for n := 0; n < NumberOfTests; n++ { - var addr4 [4]byte - rand.Read(addr4[:]) - peer1 := slow4.Lookup(addr4[:]) - peer2 := allowedIPs.Lookup(addr4[:]) - if peer1 != peer2 { - t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) - } - - var addr6 [16]byte - rand.Read(addr6[:]) - peer1 = slow6.Lookup(addr6[:]) - peer2 = allowedIPs.Lookup(addr6[:]) - if peer1 != peer2 { - t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) - } - } - if p >= len(peers) || p >= NumberOfPeerRemovals { - break - } - allowedIPs.RemoveByPeer(peers[p]) - slow4 = slow4.RemoveByPeer(peers[p]) - slow6 = slow6.RemoveByPeer(peers[p]) - } - for ; p < len(peers); p++ { - allowedIPs.RemoveByPeer(peers[p]) - } - - if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { - t.Error("Failed to remove all nodes from trie by peer") - } -} diff --git a/tempfork/device/allowedips_test.go b/tempfork/device/allowedips_test.go deleted file mode 100644 index cde068ec3..000000000 --- a/tempfork/device/allowedips_test.go +++ /dev/null @@ -1,247 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "math/rand" - "net" - "net/netip" - "testing" -) - -type testPairCommonBits struct { - s1 []byte - s2 []byte - match uint8 -} - -func TestCommonBits(t *testing.T) { - tests := []testPairCommonBits{ - {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, - {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, - {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, - {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, - {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, - } - - for _, p := range tests { - v := commonBits(p.s1, p.s2) - if v != p.match { - t.Error( - "For slice", p.s1, p.s2, - "expected match", p.match, - ",but got", v, - ) - } - } -} - -func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { - var trie *trieEntry - var peers []*Peer - root := parentIndirection{&trie, 2} - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < peerNumber; n++ { - peers = append(peers, &Peer{}) - } - - for n := 0; n < addressNumber; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint8(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % peerNumber - root.insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < b.N; n++ { - var addr [AddressLength]byte - rand.Read(addr[:]) - trie.lookup(addr[:]) - } -} - -func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv4len, b) -} - -func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv4len, b) -} - -func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { - benchmarkTrie(100, 1000, net.IPv6len, b) -} - -func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { - benchmarkTrie(10, 10, net.IPv6len, b) -} - -/* Test ported from kernel implementation: - * selftest/allowedips.h - */ -func TestTrieIPv4(t *testing.T) { - a := &Peer{} - b := &Peer{} - c := &Peer{} - d := &Peer{} - e := &Peer{} - g := &Peer{} - h := &Peer{} - - var allowedIPs AllowedIPs - - insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { - allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) - } - - assertEQ := func(peer *Peer, a, b, c, d byte) { - p := allowedIPs.Lookup([]byte{a, b, c, d}) - if p != peer { - t.Error("Assert EQ failed") - } - } - - assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := allowedIPs.Lookup([]byte{a, b, c, d}) - if p == peer { - t.Error("Assert NEQ failed") - } - } - - insert(a, 192, 168, 4, 0, 24) - insert(b, 192, 168, 4, 4, 32) - insert(c, 192, 168, 0, 0, 16) - insert(d, 192, 95, 5, 64, 27) - insert(c, 192, 95, 5, 65, 27) - insert(e, 0, 0, 0, 0, 0) - insert(g, 64, 15, 112, 0, 20) - insert(h, 64, 15, 123, 211, 25) - insert(a, 10, 0, 0, 0, 25) - insert(b, 10, 0, 0, 128, 25) - insert(a, 10, 1, 0, 0, 30) - insert(b, 10, 1, 0, 4, 30) - insert(c, 10, 1, 0, 8, 29) - insert(d, 10, 1, 0, 16, 29) - - assertEQ(a, 192, 168, 4, 20) - assertEQ(a, 192, 168, 4, 0) - assertEQ(b, 192, 168, 4, 4) - assertEQ(c, 192, 168, 200, 182) - assertEQ(c, 192, 95, 5, 68) - assertEQ(e, 192, 95, 5, 96) - assertEQ(g, 64, 15, 116, 26) - assertEQ(g, 64, 15, 127, 3) - - insert(a, 1, 0, 0, 0, 32) - insert(a, 64, 0, 0, 0, 32) - insert(a, 128, 0, 0, 0, 32) - insert(a, 192, 0, 0, 0, 32) - insert(a, 255, 0, 0, 0, 32) - - assertEQ(a, 1, 0, 0, 0) - assertEQ(a, 64, 0, 0, 0) - assertEQ(a, 128, 0, 0, 0) - assertEQ(a, 192, 0, 0, 0) - assertEQ(a, 255, 0, 0, 0) - - allowedIPs.RemoveByPeer(a) - - assertNEQ(a, 1, 0, 0, 0) - assertNEQ(a, 64, 0, 0, 0) - assertNEQ(a, 128, 0, 0, 0) - assertNEQ(a, 192, 0, 0, 0) - assertNEQ(a, 255, 0, 0, 0) - - allowedIPs.RemoveByPeer(a) - allowedIPs.RemoveByPeer(b) - allowedIPs.RemoveByPeer(c) - allowedIPs.RemoveByPeer(d) - allowedIPs.RemoveByPeer(e) - allowedIPs.RemoveByPeer(g) - allowedIPs.RemoveByPeer(h) - if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { - t.Error("Expected removing all the peers to empty trie, but it did not") - } - - insert(a, 192, 168, 0, 0, 16) - insert(a, 192, 168, 0, 0, 24) - - allowedIPs.RemoveByPeer(a) - - assertNEQ(a, 192, 168, 0, 1) -} - -/* Test ported from kernel implementation: - * selftest/allowedips.h - */ -func TestTrieIPv6(t *testing.T) { - a := &Peer{} - b := &Peer{} - c := &Peer{} - d := &Peer{} - e := &Peer{} - f := &Peer{} - g := &Peer{} - h := &Peer{} - - var allowedIPs AllowedIPs - - expand := func(a uint32) []byte { - var out [4]byte - out[0] = byte(a >> 24 & 0xff) - out[1] = byte(a >> 16 & 0xff) - out[2] = byte(a >> 8 & 0xff) - out[3] = byte(a & 0xff) - return out[:] - } - - insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) - } - - assertEQ := func(peer *Peer, a, b, c, d uint32) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - p := allowedIPs.Lookup(addr) - if p != peer { - t.Error("Assert EQ failed") - } - } - - insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) - insert(c, 0x26075300, 0x60006b00, 0, 0, 64) - insert(e, 0, 0, 0, 0, 0) - insert(f, 0, 0, 0, 0, 0) - insert(g, 0x24046800, 0, 0, 0, 32) - insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) - insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) - insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) - insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) - - assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) - assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) - assertEQ(f, 0x26075300, 0x60006b01, 0, 0) - assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) - assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) - assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) - assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) - assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) - assertEQ(h, 0x24046800, 0x40040800, 0, 0) - assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) - assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) -} diff --git a/tempfork/device/peer.go b/tempfork/device/peer.go deleted file mode 100644 index c4f4d42ff..000000000 --- a/tempfork/device/peer.go +++ /dev/null @@ -1,28 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "container/list" - - "tailscale.com/types/key" -) - -type Peer struct { - trieEntries list.List - - key key.NodePublic -} - -func NewPeer(k key.NodePublic) *Peer { - return &Peer{ - key: k, - } -} - -func (p *Peer) Key() key.NodePublic { - return p.key -}