ipn/ipnlocal: improve testability of random node selection

In order to test the sticky last suggestion code, a test was written for
LocalBackend.SuggestExitNode but it contains a random number generator
which makes writing comprehensive tests very difficult. This doesn't
change how the last suggestion works, but it adds some infrastructure to
make that easier in a later PR.

This adds func parameters for the two randomized parts: breaking ties
between DERP regions and breaking ties between nodes. This way tests can
validate the entire list of tied options, rather than expecting a
particular outcome given a particular random seed.

As a result of this, the global random number generator can be used
rather than seeding a local one each time.

In order to see the tied nodes for the location based (i.e. Mullvad)
case, pickWeighted needed to return a slice instead of a single
arbitrary option, so there is a small change in how that works.

Updates tailscale/corp#19681

Change-Id: I83c48a752abdec0f59c58ccfd8bfb3f3f17d0ea8
Signed-off-by: Adrian Dewhurst <adrian@tailscale.com>
This commit is contained in:
Adrian Dewhurst 2024-06-03 16:12:12 -04:00 committed by Adrian Dewhurst
parent d21c00205d
commit 3bf2bddbb5
2 changed files with 143 additions and 54 deletions

View File

@ -6420,9 +6420,8 @@ func (b *LocalBackend) SuggestExitNode() (response apitype.ExitNodeSuggestionRes
} }
return last, err return last, err
} }
seed := time.Now().UnixNano()
r := rand.New(rand.NewSource(seed)) res, err := suggestExitNode(lastReport, netMap, randomRegion, randomNode)
res, err := suggestExitNode(lastReport, netMap, r)
if err != nil { if err != nil {
last, err := lastSuggestedExitNode.asAPIType() last, err := lastSuggestedExitNode.asAPIType()
if err != nil { if err != nil {
@ -6437,6 +6436,13 @@ func (b *LocalBackend) SuggestExitNode() (response apitype.ExitNodeSuggestionRes
return res, err return res, err
} }
// selectRegionFunc returns a DERP region from the slice of candidate regions.
// The value is returned, not the slice index.
type selectRegionFunc func(views.Slice[int]) int
// selectNodeFunc returns a node from the slice of candidate nodes.
type selectNodeFunc func(nodes views.Slice[tailcfg.NodeView]) tailcfg.NodeView
// asAPIType formats a response with the last suggested exit node's ID and name. // asAPIType formats a response with the last suggested exit node's ID and name.
// Returns error if there is no id or name. // Returns error if there is no id or name.
// Used as a fallback before returning a nil response and error. // Used as a fallback before returning a nil response and error.
@ -6449,8 +6455,8 @@ func (n lastSuggestedExitNode) asAPIType() (res apitype.ExitNodeSuggestionRespon
return res, ErrUnableToSuggestLastExitNode return res, ErrUnableToSuggestLastExitNode
} }
func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, r *rand.Rand) (res apitype.ExitNodeSuggestionResponse, err error) { func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, selectRegion selectRegionFunc, selectNode selectNodeFunc) (res apitype.ExitNodeSuggestionResponse, err error) {
if report.PreferredDERP == 0 { if report.PreferredDERP == 0 || netMap == nil || netMap.DERPMap == nil {
return res, ErrNoPreferredDERP return res, ErrNoPreferredDERP
} }
var allowedCandidates set.Set[string] var allowedCandidates set.Set[string]
@ -6461,6 +6467,9 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, r *rand
} }
candidates := make([]tailcfg.NodeView, 0, len(netMap.Peers)) candidates := make([]tailcfg.NodeView, 0, len(netMap.Peers))
for _, peer := range netMap.Peers { for _, peer := range netMap.Peers {
if !peer.Valid() {
continue
}
if allowedCandidates != nil && !allowedCandidates.Contains(string(peer.StableID())) { if allowedCandidates != nil && !allowedCandidates.Contains(string(peer.StableID())) {
continue continue
} }
@ -6484,7 +6493,10 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, r *rand
} }
candidatesByRegion := make(map[int][]tailcfg.NodeView, len(netMap.DERPMap.Regions)) candidatesByRegion := make(map[int][]tailcfg.NodeView, len(netMap.DERPMap.Regions))
var preferredDERP *tailcfg.DERPRegion = netMap.DERPMap.Regions[report.PreferredDERP] preferredDERP, ok := netMap.DERPMap.Regions[report.PreferredDERP]
if !ok {
return res, ErrNoPreferredDERP
}
var minDistance float64 = math.MaxFloat64 var minDistance float64 = math.MaxFloat64
type nodeDistance struct { type nodeDistance struct {
nv tailcfg.NodeView nv tailcfg.NodeView
@ -6492,9 +6504,6 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, r *rand
} }
distances := make([]nodeDistance, 0, len(candidates)) distances := make([]nodeDistance, 0, len(candidates))
for _, c := range candidates { for _, c := range candidates {
if !c.Valid() {
continue
}
if c.DERP() != "" { if c.DERP() != "" {
ipp, err := netip.ParseAddrPort(c.DERP()) ipp, err := netip.ParseAddrPort(c.DERP())
if err != nil { if err != nil {
@ -6533,13 +6542,13 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, r *rand
if len(candidatesByRegion) > 0 { if len(candidatesByRegion) > 0 {
minRegion := minLatencyDERPRegion(xmaps.Keys(candidatesByRegion), report) minRegion := minLatencyDERPRegion(xmaps.Keys(candidatesByRegion), report)
if minRegion == 0 { if minRegion == 0 {
minRegion = randomRegion(xmaps.Keys(candidatesByRegion), r) minRegion = selectRegion(views.SliceOf(xmaps.Keys(candidatesByRegion)))
} }
regionCandidates, ok := candidatesByRegion[minRegion] regionCandidates, ok := candidatesByRegion[minRegion]
if !ok { if !ok {
return res, errors.New("no candidates in expected region: this is a bug") return res, errors.New("no candidates in expected region: this is a bug")
} }
chosen := randomNode(regionCandidates, r) chosen := selectNode(views.SliceOf(regionCandidates))
res.ID = chosen.StableID() res.ID = chosen.StableID()
res.Name = chosen.Name() res.Name = chosen.Name()
if hi := chosen.Hostinfo(); hi.Valid() { if hi := chosen.Hostinfo(); hi.Valid() {
@ -6565,7 +6574,8 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, r *rand
pickFrom = append(pickFrom, candidate.nv) pickFrom = append(pickFrom, candidate.nv)
} }
} }
chosen := pickWeighted(pickFrom) bestCandidates := pickWeighted(pickFrom)
chosen := selectNode(views.SliceOf(bestCandidates))
if !chosen.Valid() { if !chosen.Valid() {
return res, errors.New("chosen candidate invalid: this is a bug") return res, errors.New("chosen candidate invalid: this is a bug")
} }
@ -6580,36 +6590,35 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, r *rand
} }
// pickWeighted chooses the node with highest priority given a list of mullvad nodes. // pickWeighted chooses the node with highest priority given a list of mullvad nodes.
func pickWeighted(candidates []tailcfg.NodeView) tailcfg.NodeView { func pickWeighted(candidates []tailcfg.NodeView) []tailcfg.NodeView {
maxWeight := 0 maxWeight := 0
var best tailcfg.NodeView best := make([]tailcfg.NodeView, 0, 1)
for _, c := range candidates { for _, c := range candidates {
hi := c.Hostinfo() hi := c.Hostinfo()
if !hi.Valid() { if !hi.Valid() {
continue continue
} }
loc := hi.Location() loc := hi.Location()
if loc == nil || loc.Priority <= maxWeight { if loc == nil || loc.Priority < maxWeight {
continue continue
} }
if maxWeight != loc.Priority {
best = best[:0]
}
maxWeight = loc.Priority maxWeight = loc.Priority
best = c best = append(best, c)
} }
return best return best
} }
// randomNode chooses a node randomly given a list of nodes and a *rand.Rand. // randomRegion is a selectRegionFunc that selects a uniformly random region.
func randomNode(nodes []tailcfg.NodeView, r *rand.Rand) tailcfg.NodeView { func randomRegion(regions views.Slice[int]) int {
return nodes[r.Intn(len(nodes))] return regions.At(rand.Intn(regions.Len()))
} }
// randomRegion chooses a region randomly given a list of ints and a *rand.Rand // randomNode is a selectNodeFunc that returns a uniformly random node.
func randomRegion(regions []int, r *rand.Rand) int { func randomNode(nodes views.Slice[tailcfg.NodeView]) tailcfg.NodeView {
if testenv.InTest() { return nodes.At(rand.Intn(nodes.Len()))
regions = slices.Clone(regions)
slices.Sort(regions)
}
return regions[r.Intn(len(regions))]
} }
// minLatencyDERPRegion returns the region with the lowest latency value given the last netcheck report. // minLatencyDERPRegion returns the region with the lowest latency value given the last netcheck report.

View File

@ -9,7 +9,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"math" "math"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@ -2775,6 +2774,54 @@ func withSuggest() peerOptFunc {
} }
} }
func deterministicRegionForTest(t testing.TB, want views.Slice[int], use int) selectRegionFunc {
t.Helper()
if !views.SliceContains(want, use) {
t.Errorf("invalid test: use %v is not in want %v", use, want)
}
return func(got views.Slice[int]) int {
if !views.SliceEqualAnyOrder(got, want) {
t.Errorf("candidate regions = %v, want %v", got, want)
}
return use
}
}
func deterministicNodeForTest(t testing.TB, want views.Slice[tailcfg.StableNodeID], use tailcfg.StableNodeID) selectNodeFunc {
t.Helper()
if !views.SliceContains(want, use) {
t.Errorf("invalid test: use %v is not in want %v", use, want)
}
return func(got views.Slice[tailcfg.NodeView]) tailcfg.NodeView {
var ret tailcfg.NodeView
gotIDs := make([]tailcfg.StableNodeID, got.Len())
for i := range got.Len() {
nv := got.At(i)
if !nv.Valid() {
t.Fatalf("invalid node at index %v", i)
}
gotIDs[i] = nv.StableID()
if nv.StableID() == use {
ret = nv
}
}
if !views.SliceEqualAnyOrder(views.SliceOf(gotIDs), want) {
t.Errorf("candidate nodes = %v, want %v", gotIDs, want)
}
if !ret.Valid() {
t.Fatalf("did not find matching node in %v, want %v", gotIDs, use)
}
return ret
}
}
func TestSuggestExitNode(t *testing.T) { func TestSuggestExitNode(t *testing.T) {
defaultDERPMap := &tailcfg.DERPMap{ defaultDERPMap := &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{ Regions: map[int]*tailcfg.DERPRegion{
@ -2911,7 +2958,6 @@ func TestSuggestExitNode(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
seed int64
lastReport *netcheck.Report lastReport *netcheck.Report
netMap *netmap.NetworkMap netMap *netmap.NetworkMap
@ -2919,6 +2965,11 @@ func TestSuggestExitNode(t *testing.T) {
allowPolicy []string allowPolicy []string
wantRegions []int
useRegion int
wantNodes []tailcfg.StableNodeID
wantID tailcfg.StableNodeID wantID tailcfg.StableNodeID
wantName string wantName string
wantLocation tailcfg.LocationView wantLocation tailcfg.LocationView
@ -2936,15 +2987,21 @@ func TestSuggestExitNode(t *testing.T) {
peer2DERP1, peer2DERP1,
}, },
}, },
wantNodes: []tailcfg.StableNodeID{
"stable1",
"stable2",
},
wantName: "peer1", wantName: "peer1",
wantID: "stable1", wantID: "stable1",
}, },
{ {
name: "2 exit nodes different regions unknown latency", name: "2 exit nodes different regions unknown latency",
lastReport: noLatency1Report, lastReport: noLatency1Report,
netMap: defaultNetmap, netMap: defaultNetmap,
wantName: "peer2", wantRegions: []int{1, 3}, // the only regions with peers
wantID: "stable2", useRegion: 1,
wantName: "peer2",
wantID: "stable2",
}, },
{ {
name: "2 derp based exit nodes, different regions, equal latency", name: "2 derp based exit nodes, different regions, equal latency",
@ -2964,8 +3021,10 @@ func TestSuggestExitNode(t *testing.T) {
peer3, peer3,
}, },
}, },
wantName: "peer1", wantRegions: []int{1, 2},
wantID: "stable1", useRegion: 1,
wantName: "peer1",
wantID: "stable1",
}, },
{ {
name: "mullvad nodes, no derp based exit nodes", name: "mullvad nodes, no derp based exit nodes",
@ -3003,6 +3062,7 @@ func TestSuggestExitNode(t *testing.T) {
fortWorthPeer8LowPriority, fortWorthPeer8LowPriority,
}, },
}, },
wantNodes: []tailcfg.StableNodeID{"stable5", "stable8"},
wantID: "stable5", wantID: "stable5",
wantLocation: dallas.View(), wantLocation: dallas.View(),
wantName: "Dallas", wantName: "Dallas",
@ -3019,8 +3079,9 @@ func TestSuggestExitNode(t *testing.T) {
peer4DERP3, peer4DERP3,
}, },
}, },
wantID: "stable4", useRegion: 3,
wantName: "peer4", wantID: "stable4",
wantName: "peer4",
}, },
{ {
name: "no peers", name: "no peers",
@ -3053,7 +3114,6 @@ func TestSuggestExitNode(t *testing.T) {
}, },
{ {
name: "prefer last node", name: "prefer last node",
seed: 1,
lastReport: preferred1Report, lastReport: preferred1Report,
netMap: &netmap.NetworkMap{ netMap: &netmap.NetworkMap{
SelfNode: selfNode.View(), SelfNode: selfNode.View(),
@ -3064,8 +3124,12 @@ func TestSuggestExitNode(t *testing.T) {
}, },
}, },
lastSuggestion: "stable2", lastSuggestion: "stable2",
wantName: "peer2", wantNodes: []tailcfg.StableNodeID{
wantID: "stable2", "stable1",
"stable2",
},
wantName: "peer2",
wantID: "stable2",
}, },
{ {
name: "found better derp node", name: "found better derp node",
@ -3088,6 +3152,7 @@ func TestSuggestExitNode(t *testing.T) {
fortWorthPeer8LowPriority, fortWorthPeer8LowPriority,
}, },
}, },
wantNodes: []tailcfg.StableNodeID{"stable5", "stable8"},
wantID: "stable5", wantID: "stable5",
wantName: "Dallas", wantName: "Dallas",
wantLocation: dallas.View(), wantLocation: dallas.View(),
@ -3105,15 +3170,16 @@ func TestSuggestExitNode(t *testing.T) {
fortWorthPeer7, fortWorthPeer7,
}, },
}, },
wantNodes: []tailcfg.StableNodeID{"stable7"},
wantID: "stable7", wantID: "stable7",
wantName: "Fort Worth", wantName: "Fort Worth",
wantLocation: fortWorth.View(), wantLocation: fortWorth.View(),
}, },
{ {
name: "large netmap", name: "large netmap",
seed: 1,
lastReport: preferred1Report, lastReport: preferred1Report,
netMap: largeNetmap, netMap: largeNetmap,
wantNodes: []tailcfg.StableNodeID{"stable1", "stable2"},
wantID: "stable2", wantID: "stable2",
wantName: "peer2", wantName: "peer2",
}, },
@ -3125,10 +3191,10 @@ func TestSuggestExitNode(t *testing.T) {
}, },
{ {
name: "only derp suggestions", name: "only derp suggestions",
seed: 1,
lastReport: preferred1Report, lastReport: preferred1Report,
netMap: largeNetmap, netMap: largeNetmap,
allowPolicy: []string{"stable1", "stable2", "stable3"}, allowPolicy: []string{"stable1", "stable2", "stable3"},
wantNodes: []tailcfg.StableNodeID{"stable1", "stable2"},
wantID: "stable2", wantID: "stable2",
wantName: "peer2", wantName: "peer2",
}, },
@ -3172,8 +3238,19 @@ func TestSuggestExitNode(t *testing.T) {
} }
syspolicy.SetHandlerForTest(t, &mh) syspolicy.SetHandlerForTest(t, &mh)
r := rand.New(rand.NewSource(tt.seed)) wantRegions := tt.wantRegions
got, err := suggestExitNode(tt.lastReport, tt.netMap, r) if wantRegions == nil {
wantRegions = []int{tt.useRegion}
}
selectRegion := deterministicRegionForTest(t, views.SliceOf(wantRegions), tt.useRegion)
wantNodes := tt.wantNodes
if wantNodes == nil {
wantNodes = []tailcfg.StableNodeID{tt.wantID}
}
selectNode := deterministicNodeForTest(t, views.SliceOf(wantNodes), tt.wantID)
got, err := suggestExitNode(tt.lastReport, tt.netMap, selectRegion, selectNode)
if got.Name != tt.wantName { if got.Name != tt.wantName {
t.Errorf("name=%v, want %v", got.Name, tt.wantName) t.Errorf("name=%v, want %v", got.Name, tt.wantName)
} }
@ -3204,7 +3281,7 @@ func TestSuggestExitNodePickWeighted(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
candidates []tailcfg.NodeView candidates []tailcfg.NodeView
wantID tailcfg.StableNodeID wantIDs []tailcfg.StableNodeID
}{ }{
{ {
name: "different priorities", name: "different priorities",
@ -3212,7 +3289,7 @@ func TestSuggestExitNodePickWeighted(t *testing.T) {
makePeer(2, withExitRoutes(), withLocation(location20.View())), makePeer(2, withExitRoutes(), withLocation(location20.View())),
makePeer(3, withExitRoutes(), withLocation(location10.View())), makePeer(3, withExitRoutes(), withLocation(location10.View())),
}, },
wantID: "stable2", wantIDs: []tailcfg.StableNodeID{"stable2"},
}, },
{ {
name: "same priorities", name: "same priorities",
@ -3220,31 +3297,34 @@ func TestSuggestExitNodePickWeighted(t *testing.T) {
makePeer(2, withExitRoutes(), withLocation(location10.View())), makePeer(2, withExitRoutes(), withLocation(location10.View())),
makePeer(3, withExitRoutes(), withLocation(location10.View())), makePeer(3, withExitRoutes(), withLocation(location10.View())),
}, },
wantID: "stable2", wantIDs: []tailcfg.StableNodeID{"stable2", "stable3"},
}, },
{ {
name: "<1 candidates", name: "<1 candidates",
candidates: []tailcfg.NodeView{}, candidates: []tailcfg.NodeView{},
wantID: "<invalid>",
}, },
{ {
name: "1 candidate", name: "1 candidate",
candidates: []tailcfg.NodeView{ candidates: []tailcfg.NodeView{
makePeer(2, withExitRoutes(), withLocation(location20.View())), makePeer(2, withExitRoutes(), withLocation(location20.View())),
}, },
wantID: "stable2", wantIDs: []tailcfg.StableNodeID{"stable2"},
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := pickWeighted(tt.candidates) got := pickWeighted(tt.candidates)
gotID := tailcfg.StableNodeID("<invalid>") gotIDs := make([]tailcfg.StableNodeID, 0, len(got))
if got.Valid() { for _, n := range got {
gotID = got.StableID() if !n.Valid() {
gotIDs = append(gotIDs, "<invalid>")
continue
}
gotIDs = append(gotIDs, n.StableID())
} }
if gotID != tt.wantID { if !views.SliceEqualAnyOrder(views.SliceOf(gotIDs), views.SliceOf(tt.wantIDs)) {
t.Errorf("node IDs = %v, want %v", gotID, tt.wantID) t.Errorf("node IDs = %v, want %v", gotIDs, tt.wantIDs)
} }
}) })
} }