From 7e0d12e7ccb8d3062bf16ab03edce7d01944367a Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 22 Mar 2021 10:23:26 -0700 Subject: [PATCH] wgengine/magicsock: don't update control if only endpoint order changes Updates #1559 Signed-off-by: Brad Fitzpatrick --- wgengine/magicsock/magicsock.go | 32 ++++++++++++--- wgengine/magicsock/magicsock_test.go | 58 ++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 884056263..68c3f60db 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -630,7 +630,7 @@ func (c *Conn) setEndpoints(endpoints []string, reasons map[string]string) (chan delete(c.onEndpointRefreshed, de) } - if stringsEqual(endpoints, c.lastEndpoints) { + if stringSetsEqual(endpoints, c.lastEndpoints) { return false } c.lastEndpoints = endpoints @@ -1111,12 +1111,32 @@ func (c *Conn) determineEndpoints(ctx context.Context) (ipPorts []string, reason return eps, already, nil } -func stringsEqual(x, y []string) bool { - if len(x) != len(y) { - return false +// stringSetsEqual reports whether x and y represent the same set of +// strings. The order doesn't matter. +// +// It does not mutate the slices. +func stringSetsEqual(x, y []string) bool { + if len(x) == len(y) { + orderMatches := true + for i := range x { + if x[i] != y[i] { + orderMatches = false + break + } + } + if orderMatches { + return true + } } - for i := range x { - if x[i] != y[i] { + m := map[string]int{} + for _, v := range x { + m[v] |= 1 + } + for _, v := range y { + m[v] |= 2 + } + for _, n := range m { + if n != 3 { return false } } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 8e64a2696..1c9dfd1d7 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -1793,3 +1793,61 @@ func TestRebindStress(t *testing.T) { t.Fatalf("Got ReceiveIPv4 error: %v (is closed = %v). Log:\n%s", err, errors.Is(err, net.ErrClosed), logBuf.Bytes()) } } + +func TestStringSetsEqual(t *testing.T) { + s := func(nn ...int) (ret []string) { + for _, n := range nn { + ret = append(ret, strconv.Itoa(n)) + } + return + } + tests := []struct { + a, b []string + want bool + }{ + { + want: true, + }, + { + a: s(1, 2, 3), + b: s(1, 2, 3), + want: true, + }, + { + a: s(1, 2), + b: s(2, 1), + want: true, + }, + { + a: s(1, 2), + b: s(2, 1, 1), + want: true, + }, + { + a: s(1, 2, 2), + b: s(2, 1), + want: true, + }, + { + a: s(1, 2, 2), + b: s(2, 1, 1), + want: true, + }, + { + a: s(1, 2, 2, 3), + b: s(2, 1, 1), + want: false, + }, + { + a: s(1, 2, 2), + b: s(2, 1, 1, 3), + want: false, + }, + } + for _, tt := range tests { + if got := stringSetsEqual(tt.a, tt.b); got != tt.want { + t.Errorf("%q vs %q = %v; want %v", tt.a, tt.b, got, tt.want) + } + } + +}