142 lines
2.9 KiB
Go
142 lines
2.9 KiB
Go
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
package device
|
|
|
|
import (
|
|
"math/rand"
|
|
"net"
|
|
"net/netip"
|
|
"sort"
|
|
"testing"
|
|
)
|
|
|
|
const (
|
|
NumberOfPeers = 100
|
|
NumberOfPeerRemovals = 4
|
|
NumberOfAddresses = 250
|
|
NumberOfTests = 10000
|
|
)
|
|
|
|
type SlowNode struct {
|
|
peer *Peer
|
|
cidr uint8
|
|
bits []byte
|
|
}
|
|
|
|
type SlowRouter []*SlowNode
|
|
|
|
func (r SlowRouter) Len() int {
|
|
return len(r)
|
|
}
|
|
|
|
func (r SlowRouter) Less(i, j int) bool {
|
|
return r[i].cidr > r[j].cidr
|
|
}
|
|
|
|
func (r SlowRouter) Swap(i, j int) {
|
|
r[i], r[j] = r[j], r[i]
|
|
}
|
|
|
|
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
|
|
for _, t := range r {
|
|
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
|
|
t.peer = peer
|
|
t.bits = addr
|
|
return r
|
|
}
|
|
}
|
|
r = append(r, &SlowNode{
|
|
cidr: cidr,
|
|
bits: addr,
|
|
peer: peer,
|
|
})
|
|
sort.Sort(r)
|
|
return r
|
|
}
|
|
|
|
func (r SlowRouter) Lookup(addr []byte) *Peer {
|
|
for _, t := range r {
|
|
common := commonBits(t.bits, addr)
|
|
if common >= t.cidr {
|
|
return t.peer
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
|
|
n := 0
|
|
for _, x := range r {
|
|
if x.peer != peer {
|
|
r[n] = x
|
|
n++
|
|
}
|
|
}
|
|
return r[:n]
|
|
}
|
|
|
|
func TestTrieRandom(t *testing.T) {
|
|
var slow4, slow6 SlowRouter
|
|
var peers []*Peer
|
|
var allowedIPs AllowedIPs
|
|
|
|
rand.Seed(1)
|
|
|
|
for n := 0; n < NumberOfPeers; n++ {
|
|
peers = append(peers, &Peer{})
|
|
}
|
|
|
|
for n := 0; n < NumberOfAddresses; n++ {
|
|
var addr4 [4]byte
|
|
rand.Read(addr4[:])
|
|
cidr := uint8(rand.Intn(32) + 1)
|
|
index := rand.Intn(NumberOfPeers)
|
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
|
|
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
|
|
|
var addr6 [16]byte
|
|
rand.Read(addr6[:])
|
|
cidr = uint8(rand.Intn(128) + 1)
|
|
index = rand.Intn(NumberOfPeers)
|
|
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
|
|
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
|
|
}
|
|
|
|
var p int
|
|
for p = 0; ; p++ {
|
|
for n := 0; n < NumberOfTests; n++ {
|
|
var addr4 [4]byte
|
|
rand.Read(addr4[:])
|
|
peer1 := slow4.Lookup(addr4[:])
|
|
peer2 := allowedIPs.Lookup(addr4[:])
|
|
if peer1 != peer2 {
|
|
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
|
|
}
|
|
|
|
var addr6 [16]byte
|
|
rand.Read(addr6[:])
|
|
peer1 = slow6.Lookup(addr6[:])
|
|
peer2 = allowedIPs.Lookup(addr6[:])
|
|
if peer1 != peer2 {
|
|
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
|
|
}
|
|
}
|
|
if p >= len(peers) || p >= NumberOfPeerRemovals {
|
|
break
|
|
}
|
|
allowedIPs.RemoveByPeer(peers[p])
|
|
slow4 = slow4.RemoveByPeer(peers[p])
|
|
slow6 = slow6.RemoveByPeer(peers[p])
|
|
}
|
|
for ; p < len(peers); p++ {
|
|
allowedIPs.RemoveByPeer(peers[p])
|
|
}
|
|
|
|
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
|
|
t.Error("Failed to remove all nodes from trie by peer")
|
|
}
|
|
}
|