diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 88f6467a6..45ac45c9d 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -505,9 +505,9 @@ var logPacketDests, _ = strconv.ParseBool(os.Getenv("DEBUG_LOG_PACKET_DESTS")) // be fake addrs representing DERP servers. // // It also returns as's current roamAddr, if any. -func appendDests(dsts []*net.UDPAddr, as *AddrSet, b []byte) (_ []*net.UDPAddr, roamAddr *net.UDPAddr) { +func (as *AddrSet) appendDests(dsts []*net.UDPAddr, b []byte) (_ []*net.UDPAddr, roamAddr *net.UDPAddr) { spray := shouldSprayPacket(b) // true for handshakes - now := time.Now() + now := as.timeNow() as.mu.Lock() defer as.mu.Unlock() @@ -542,7 +542,6 @@ func appendDests(dsts []*net.UDPAddr, as *AddrSet, b []byte) (_ []*net.UDPAddr, } // Pick our destination address(es). - roamAddr = as.roamAddr switch { case spray: // This packet is being sprayed to all addresses. @@ -575,9 +574,9 @@ func appendDests(dsts []*net.UDPAddr, as *AddrSet, b []byte) (_ []*net.UDPAddr, } if logPacketDests { - log.Printf("spray=%v; roam=%v; dests=%v", spray, roamAddr, dsts) + log.Printf("spray=%v; roam=%v; dests=%v", spray, as.roamAddr, dsts) } - return dsts, roamAddr + return dsts, as.roamAddr } var errNoDestinations = errors.New("magicsock: no destinations") @@ -600,7 +599,7 @@ func (c *Conn) Send(b []byte, ep conn.Endpoint) error { } var addrBuf [8]*net.UDPAddr - dsts, roamAddr := appendDests(addrBuf[:0], as, b) + dsts, roamAddr := as.appendDests(addrBuf[:0], b) if len(dsts) == 0 { return errNoDestinations @@ -1104,6 +1103,10 @@ type AddrSet struct { // But there could be multiple or none of each. addrs []net.UDPAddr + // clock, if non-nil, is used in tests instead of time.Now. + clock func() time.Time + Logf logger.Logf // Logf, if non-nil, is used instead of log.Printf + mu sync.Mutex // guards following fields // roamAddr is non-nil if/when we receive a correctly signed @@ -1126,6 +1129,21 @@ type AddrSet struct { lastSpray time.Time } +func (as *AddrSet) timeNow() time.Time { + if as.clock != nil { + return as.clock() + } + return time.Now() +} + +func (as *AddrSet) logf(format string, args ...interface{}) { + if as.Logf != nil { + as.Logf(format, args...) + } else { + log.Printf(format, args...) + } +} + var noAddr = &net.UDPAddr{ IP: net.ParseIP("127.127.127.127"), Port: 127, @@ -1192,7 +1210,7 @@ func (a *AddrSet) UpdateDst(new *net.UDPAddr) error { // This is a hot path for established connections. return nil } - if a.curAddr >= 0 && equalUDPAddr(new, &a.addrs[a.curAddr]) { + if a.roamAddr == nil && a.curAddr >= 0 && equalUDPAddr(new, &a.addrs[a.curAddr]) { // Packet from current-priority address, no logging. // This is a hot path for established connections. return nil @@ -1216,26 +1234,26 @@ func (a *AddrSet) UpdateDst(new *net.UDPAddr) error { switch { case index == -1: if a.roamAddr == nil { - log.Printf("magicsock: rx %s from roaming address %s, set as new priority", pk, new) + a.logf("magicsock: rx %s from roaming address %s, set as new priority", pk, new) } else { - log.Printf("magicsock: rx %s from roaming address %s, replaces roaming address %s", pk, new, a.roamAddr) + a.logf("magicsock: rx %s from roaming address %s, replaces roaming address %s", pk, new, a.roamAddr) } a.roamAddr = new case a.roamAddr != nil: - log.Printf("magicsock: rx %s from known %s (%d), replaces roaming address %s", pk, new, index, a.roamAddr) + a.logf("magicsock: rx %s from known %s (%d), replaces roaming address %s", pk, new, index, a.roamAddr) a.roamAddr = nil a.curAddr = index case a.curAddr == -1: - log.Printf("magicsock: rx %s from %s (%d/%d), set as new priority", pk, new, index, len(a.addrs)) + a.logf("magicsock: rx %s from %s (%d/%d), set as new priority", pk, new, index, len(a.addrs)) a.curAddr = index case index < a.curAddr: - log.Printf("magicsock: rx %s from low-pri %s (%d), keeping current %s (%d)", pk, new, index, old, a.curAddr) + a.logf("magicsock: rx %s from low-pri %s (%d), keeping current %s (%d)", pk, new, index, old, a.curAddr) default: // index > a.curAddr - log.Printf("magicsock: rx %s from %s (%d/%d), replaces old priority %s", pk, new, index, len(a.addrs), old) + a.logf("magicsock: rx %s from %s (%d/%d), replaces old priority %s", pk, new, index, len(a.addrs), old) a.curAddr = index } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index bb26ee589..001c2fd1b 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -8,6 +8,7 @@ import ( "bytes" crand "crypto/rand" "crypto/tls" + "encoding/binary" "fmt" "net" "net/http" @@ -596,3 +597,149 @@ func TestTwoDevicePing(t *testing.T) { } }) } + +// TestAddrSet tests AddrSet appendDests and UpdateDst. +func TestAddrSet(t *testing.T) { + mustUDPAddr := func(s string) *net.UDPAddr { + t.Helper() + ua, err := net.ResolveUDPAddr("udp", s) + if err != nil { + t.Fatal(err) + } + return ua + } + udpAddrs := func(ss ...string) (ret []net.UDPAddr) { + t.Helper() + for _, s := range ss { + ret = append(ret, *mustUDPAddr(s)) + } + return ret + } + joinUDPs := func(in []*net.UDPAddr) string { + var sb strings.Builder + for i, ua := range in { + if i > 0 { + sb.WriteByte(',') + } + sb.WriteString(ua.String()) + } + return sb.String() + } + var ( + regPacket = []byte("some regular packet") + sprayPacket = []byte("0000") + ) + binary.LittleEndian.PutUint32(sprayPacket[:4], device.MessageInitiationType) + if !shouldSprayPacket(sprayPacket) { + t.Fatal("sprayPacket should be classified as a spray packet for testing") + } + + // A step is either a b+want appendDests tests, or an + // UpdateDst call, depending on which fields are set. + type step struct { + // advance is the time to advance the fake clock + // before the step. + advance time.Duration + + // updateDst, if set, does an UpdateDst call and + // b+want are ignored. + updateDst *net.UDPAddr + + b []byte + want string // comma-separated + } + tests := []struct { + name string + as *AddrSet + steps []step + }{ + { + name: "reg_packet_no_curaddr", + as: &AddrSet{ + addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + curAddr: -1, // unknown + roamAddr: nil, + }, + steps: []step{ + {b: regPacket, want: "127.3.3.40:1"}, + }, + }, + { + name: "reg_packet_have_curaddr", + as: &AddrSet{ + addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + curAddr: 1, // global IP + roamAddr: nil, + }, + steps: []step{ + {b: regPacket, want: "123.45.67.89:123"}, + }, + }, + { + name: "reg_packet_have_roamaddr", + as: &AddrSet{ + addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + curAddr: 2, // should be ignored + roamAddr: mustUDPAddr("5.6.7.8:123"), + }, + steps: []step{ + {b: regPacket, want: "5.6.7.8:123"}, + {updateDst: mustUDPAddr("10.0.0.1:123")}, // no more roaming + {b: regPacket, want: "10.0.0.1:123"}, + }, + }, + { + name: "start_roaming", + as: &AddrSet{ + addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + curAddr: 2, + }, + steps: []step{ + {b: regPacket, want: "10.0.0.1:123"}, + {updateDst: mustUDPAddr("4.5.6.7:123")}, + {b: regPacket, want: "4.5.6.7:123"}, + {updateDst: mustUDPAddr("5.6.7.8:123")}, + {b: regPacket, want: "5.6.7.8:123"}, + {updateDst: mustUDPAddr("123.45.67.89:123")}, // end roaming + {b: regPacket, want: "123.45.67.89:123"}, + }, + }, + { + name: "spray_packet", + as: &AddrSet{ + addrs: udpAddrs("127.3.3.40:1", "123.45.67.89:123", "10.0.0.1:123"), + curAddr: 2, // should be ignored + roamAddr: mustUDPAddr("5.6.7.8:123"), + }, + steps: []step{ + {b: sprayPacket, want: "127.3.3.40:1,123.45.67.89:123,10.0.0.1:123,5.6.7.8:123"}, + {advance: 300 * time.Millisecond, b: regPacket, want: "127.3.3.40:1,123.45.67.89:123,10.0.0.1:123,5.6.7.8:123"}, + {advance: 300 * time.Millisecond, b: regPacket, want: "127.3.3.40:1,123.45.67.89:123,10.0.0.1:123,5.6.7.8:123"}, + {advance: 3, b: regPacket, want: "5.6.7.8:123"}, + {advance: 2 * time.Millisecond, updateDst: mustUDPAddr("10.0.0.1:123")}, + {advance: 3, b: regPacket, want: "10.0.0.1:123"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + faket := time.Unix(0, 0) + tt.as.Logf = t.Logf + tt.as.clock = func() time.Time { return faket } + for i, st := range tt.steps { + faket = faket.Add(st.advance) + + if st.updateDst != nil { + if err := tt.as.UpdateDst(st.updateDst); err != nil { + t.Fatal(err) + } + continue + } + got, _ := tt.as.appendDests(nil, st.b) + if gotStr := joinUDPs(got); gotStr != st.want { + t.Errorf("step %d: got %v; want %v", i, gotStr, st.want) + } + } + }) + } +}