wgengine/magicsock: make peerMap also keyed by NodeID

In prep for incremental netmap update plumbing (#1909), make peerMap
also keyed by NodeID, as all the netmap node mutations passed around
later will be keyed by NodeID.

In the process, also:

* add envknob.CrashOnUnexpected, as a signal that we can panic more
  aggressively in unexpected cases.
* pull two moderately large blocks of code in Conn.SetNetworkMap out
  into their own methods
* convert a few more sets from maps to set.Set

Updates #1909

Change-Id: I7acdd64452ba58e9d554140ee7a8760f9043f961
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
Brad Fitzpatrick 2023-09-11 10:13:00 -07:00 committed by Brad Fitzpatrick
parent 683ba62f3e
commit d050700a3b
8 changed files with 188 additions and 84 deletions

View File

@@ -389,12 +389,24 @@ func CanTaildrop() bool { return !Bool("TS_DISABLE_TAILDROP") }
// SSHPolicyFile returns the path, if any, to the SSHPolicy JSON file for development.
func SSHPolicyFile() string { return String("TS_DEBUG_SSH_POLICY_FILE") }
// SSHIgnoreTailnetPolicy is whether to ignore the Tailnet SSH policy for development.
// SSHIgnoreTailnetPolicy reports whether to ignore the Tailnet SSH policy for development.
func SSHIgnoreTailnetPolicy() bool { return Bool("TS_DEBUG_SSH_IGNORE_TAILNET_POLICY") }
// TKASkipSignatureCheck is whether to skip node-key signature checking for development.
// TKASkipSignatureCheck reports whether to skip node-key signature checking for development.
func TKASkipSignatureCheck() bool { return Bool("TS_UNSAFE_SKIP_NKS_VERIFICATION") }
// CrashOnUnexpected reports whether the Tailscale client should panic
// on unexpected conditions. If TS_DEBUG_CRASH_ON_UNEXPECTED is set, that's
// used. Otherwise the default value is true for unstable builds.
func CrashOnUnexpected() bool {
if v, ok := crashOnUnexpected().Get(); ok {
return v
}
return version.IsUnstableBuild()
}
var crashOnUnexpected = RegisterOptBool("TS_DEBUG_CRASH_ON_UNEXPECTED")
// NoLogsNoSupport reports whether the client's opted out of log uploads and
// technical support.
func NoLogsNoSupport() bool {
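
A minimal sketch (hypothetical helper, not from this commit) of how a caller can gate a "should never happen" panic on CrashOnUnexpected, so production builds keep running while unstable builds crash loudly:

package example

import (
	"fmt"

	"tailscale.com/envknob"
)

// unexpectedf reports a should-never-happen condition: it panics on
// unstable builds (or with TS_DEBUG_CRASH_ON_UNEXPECTED=true) and
// otherwise returns the message for the caller to log.
// Hypothetical helper; the commit's real equivalent is devPanicf,
// added in magicsock.go below.
func unexpectedf(format string, args ...any) string {
	msg := fmt.Sprintf(format, args...)
	if envknob.CrashOnUnexpected() {
		panic(msg)
	}
	return msg
}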

View File

@@ -15,6 +15,8 @@ var (
// debugDisco prints verbose logs of active discovery events as
// they happen.
debugDisco = envknob.RegisterBool("TS_DEBUG_DISCO")
// debugPeerMap prints verbose logs of changes to the peermap.
debugPeerMap = envknob.RegisterBool("TS_DEBUG_MAGICSOCK_PEERMAP")
// debugOmitLocalAddresses removes all local interface addresses
// from magicsock's discovered local endpoints. Used in some tests.
debugOmitLocalAddresses = envknob.RegisterBool("TS_DEBUG_OMIT_LOCAL_ADDRS")
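
For context: envknob.RegisterBool returns a func() bool, so a knob like debugPeerMap is registered once and is then cheap to consult on hot paths. A minimal sketch with a hypothetical knob name:

package example

import "tailscale.com/envknob"

// Registered once at package init; the returned func reports the knob's
// value without re-parsing the environment at every call site.
var debugFrobnicate = envknob.RegisterBool("TS_DEBUG_FROBNICATE") // hypothetical knob

func frobnicate(logf func(format string, args ...any)) {
	if debugFrobnicate() { // cheap boolean check, fine on hot paths
		logf("frobnicate: starting")
	}
}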

View File

@@ -26,3 +26,4 @@ func debugUseDerpRouteEnv() string { return "" }
func debugUseDerpRoute() opt.Bool { return "" }
func debugRingBufferMaxSizeBytes() int { return 0 }
func inTest() bool { return false }
func debugPeerMap() bool { return false }

View File

@@ -49,6 +49,7 @@ type endpoint struct {
// These fields are initialized once and never modified.
c *Conn
nodeID tailcfg.NodeID
publicKey key.NodePublic // peer public key (for WireGuard + DERP)
publicKeyHex string // cached output of publicKey.UntypedHexString
fakeWGAddr netip.AddrPort // the UDP address we tell wireguard-go we're using

View File

@@ -55,6 +55,7 @@ import (
"tailscale.com/util/mak"
"tailscale.com/util/ringbuffer"
"tailscale.com/util/set"
"tailscale.com/util/testenv"
"tailscale.com/util/uniq"
"tailscale.com/wgengine/capture"
)
@@ -232,8 +233,8 @@ type Conn struct {
// in other maps below that are keyed by peer public key.
peerSet set.Set[key.NodePublic]
// nodeOfDisco tracks the networkmap Node entity for each peer
// discovery key.
// peerMap tracks the networkmap Node entity for each peer
// by node key, node ID, and discovery key.
peerMap peerMap
// discoInfo is the state for an active DiscoKey.
@@ -1742,6 +1743,30 @@ func nodesEqual(x, y []tailcfg.NodeView) bool {
return true
}
// debugRingBufferSize returns how many entries each per-endpoint debug ring
// buffer should hold. It assumes a single large update is ~500 bytes and that
// we don't want to use more than 1 MiB of memory on phones / 4 MiB on other
// devices, then divides that budget across numPeers, always storing at least
// two entries.
func debugRingBufferSize(numPeers int) int {
const defaultVal = 2
if numPeers == 0 {
return defaultVal
}
var maxRingBufferSize int
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
maxRingBufferSize = 1 * 1024 * 1024
} else {
maxRingBufferSize = 4 * 1024 * 1024
}
if v := debugRingBufferMaxSizeBytes(); v > 0 {
maxRingBufferSize = v
}
const averageRingBufferElemSize = 512
return max(defaultVal, maxRingBufferSize/(averageRingBufferElemSize*numPeers))
}
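
To sanity-check the arithmetic, here is a standalone sketch with the non-phone 4 MiB budget hard-coded (the real function also consults runtime.GOOS and the TS_DEBUG_* override):

package main

import "fmt"

// ringBufferSize mirrors debugRingBufferSize above, with the 4 MiB
// desktop budget fixed for illustration.
func ringBufferSize(numPeers int) int {
	const defaultVal = 2
	if numPeers == 0 {
		return defaultVal
	}
	const budget = 4 * 1024 * 1024 // non-phone memory cap, bytes
	const avgElem = 512            // assumed bytes per ring entry
	return max(defaultVal, budget/(avgElem*numPeers))
}

func main() {
	for _, n := range []int{0, 10, 100, 5000} {
		// Prints: 0 -> 2, 10 -> 819, 100 -> 81, 5000 -> 2 (clamped)
		fmt.Printf("%4d peers -> %d entries per endpoint\n", n, ringBufferSize(n))
	}
}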
// SetNetworkMap is called when the control client gets a new network
// map from the control server. It must always be non-nil.
//
@@ -1771,29 +1796,7 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
c.logf("[v1] magicsock: got updated network map; %d peers", len(nm.Peers))
heartbeatDisabled := debugEnableSilentDisco()
// Set a maximum size for our set of endpoint ring buffers by assuming
// that a single large update is ~500 bytes, and that we want to not
// use more than 1MiB of memory on phones / 4MiB on other devices.
// Calculate the per-endpoint ring buffer size by dividing that out,
// but always storing at least two entries.
var entriesPerBuffer int = 2
if len(nm.Peers) > 0 {
var maxRingBufferSize int
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
maxRingBufferSize = 1 * 1024 * 1024
} else {
maxRingBufferSize = 4 * 1024 * 1024
}
if v := debugRingBufferMaxSizeBytes(); v > 0 {
maxRingBufferSize = v
}
const averageRingBufferElemSize = 512
entriesPerBuffer = maxRingBufferSize / (averageRingBufferElemSize * len(nm.Peers))
if entriesPerBuffer < 2 {
entriesPerBuffer = 2
}
}
entriesPerBuffer := debugRingBufferSize(len(nm.Peers))
// Try a pass of just upserting nodes and creating missing
// endpoints. If the set of nodes is the same, this is an
@@ -1801,7 +1804,26 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
// we'll fall through to the next pass, which allocates but can
// handle full set updates.
for _, n := range nm.Peers {
if ep, ok := c.peerMap.endpointForNodeKey(n.Key()); ok {
if n.ID() == 0 {
devPanicf("node with zero ID")
continue
}
if n.Key().IsZero() {
devPanicf("node with zero key")
continue
}
ep, ok := c.peerMap.endpointForNodeID(n.ID())
if ok && ep.publicKey != n.Key() {
// The node rotated public keys. Delete the old endpoint and create
// it anew.
c.peerMap.deleteEndpoint(ep)
ok = false
}
if ok {
// At this point we're modifying an existing endpoint (ep) whose
// public key and nodeID match n. Its other fields (such as disco
// key or endpoints) might've changed.
if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() {
// Discokey transitioned from non-zero to zero? This should not
// happen in the wild, however it could mean:
@@ -1821,14 +1843,31 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
c.peerMap.upsertEndpoint(ep, oldDiscoKey) // maybe update discokey mappings in peerMap
continue
}
if ep, ok := c.peerMap.endpointForNodeKey(n.Key()); ok {
// At this point n.Key() should be for a key we've never seen before. If
// ok was true above, it was an update to an existing matching key and
// we don't get this far. If ok was false above, that means it's a key
// that differs from the one the NodeID had. But double check.
if ep.nodeID != n.ID() {
// Server error.
devPanicf("public key moved between nodeIDs")
} else {
// Internal data structures out of sync.
devPanicf("public key found in peerMap but not by nodeID")
}
continue
}
if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() {
// Ancient pre-0.100 node, which does not have a disco key, and will only be reachable via DERP.
// Ancient pre-0.100 node, which does not have a disco key.
// No longer supported.
continue
}
ep := &endpoint{
ep = &endpoint{
c: c,
debugUpdates: ringbuffer.New[EndpointChange](entriesPerBuffer),
nodeID: n.ID(),
publicKey: n.Key(),
publicKeyHex: n.Key().UntypedHexString(),
sentPing: map[stun.TxID]sentPing{},
@@ -1847,35 +1886,12 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
key: n.DiscoKey(),
short: n.DiscoKey().ShortString(),
})
if debugDisco() { // rather than making a new knob
c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key().ShortString(), n.DiscoKey().ShortString(), logger.ArgWriter(func(w *bufio.Writer) {
const derpPrefix = "127.3.3.40:"
if strings.HasPrefix(n.DERP(), derpPrefix) {
ipp, _ := netip.ParseAddrPort(n.DERP())
regionID := int(ipp.Port())
code := c.derpRegionCodeLocked(regionID)
if code != "" {
code = "(" + code + ")"
}
fmt.Fprintf(w, "derp=%v%s ", regionID, code)
}
for i := range n.AllowedIPs().LenIter() {
a := n.AllowedIPs().At(i)
if a.IsSingleIP() {
fmt.Fprintf(w, "aip=%v ", a.Addr())
} else {
fmt.Fprintf(w, "aip=%v ", a)
}
}
for i := range n.Endpoints().LenIter() {
ep := n.Endpoints().At(i)
fmt.Fprintf(w, "ep=%v ", ep)
}
}))
}
}
if debugPeerMap() {
c.logEndpointCreated(n)
}
ep.updateFromNode(n, heartbeatDisabled)
c.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
}
@@ -1886,12 +1902,12 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
// current netmap. If that happens, go through the allocful
// deletion path to clean up moribund nodes.
if c.peerMap.nodeCount() != len(nm.Peers) {
keep := make(map[key.NodePublic]bool, len(nm.Peers))
keep := set.Set[key.NodePublic]{}
for _, n := range nm.Peers {
keep[n.Key()] = true
keep.Add(n.Key())
}
c.peerMap.forEachEndpoint(func(ep *endpoint) {
if !keep[ep.publicKey] {
if !keep.Contains(ep.publicKey) {
c.peerMap.deleteEndpoint(ep)
}
})
@@ -1905,6 +1921,40 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
}
}
func devPanicf(format string, a ...any) {
if testenv.InTest() || envknob.CrashOnUnexpected() {
panic(fmt.Sprintf(format, a...))
}
}
func (c *Conn) logEndpointCreated(n tailcfg.NodeView) {
c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key().ShortString(), n.DiscoKey().ShortString(), logger.ArgWriter(func(w *bufio.Writer) {
const derpPrefix = "127.3.3.40:"
if strings.HasPrefix(n.DERP(), derpPrefix) {
ipp, _ := netip.ParseAddrPort(n.DERP())
regionID := int(ipp.Port())
code := c.derpRegionCodeLocked(regionID)
if code != "" {
code = "(" + code + ")"
}
fmt.Fprintf(w, "derp=%v%s ", regionID, code)
}
for i := range n.AllowedIPs().LenIter() {
a := n.AllowedIPs().At(i)
if a.IsSingleIP() {
fmt.Fprintf(w, "aip=%v ", a.Addr())
} else {
fmt.Fprintf(w, "aip=%v ", a)
}
}
for i := range n.Endpoints().LenIter() {
ep := n.Endpoints().At(i)
fmt.Fprintf(w, "ep=%v ", ep)
}
}))
}
func (c *Conn) logEndpointChange(endpoints []tailcfg.Endpoint) {
c.logf("magicsock: endpoints changed: %s", logger.ArgWriter(func(buf *bufio.Writer) {
for i, ep := range endpoints {

View File

@@ -1161,6 +1161,7 @@ func TestDiscoMessage(t *testing.T) {
DiscoKey: peer1Pub,
}
ep := &endpoint{
nodeID: 1,
publicKey: n.Key,
}
ep.disco.Store(&endpointDisco{
@@ -1257,6 +1258,7 @@ func addTestEndpoint(tb testing.TB, conn *Conn, sendConn net.PacketConn) (key.No
conn.SetNetworkMap(&netmap.NetworkMap{
Peers: nodeViews([]*tailcfg.Node{
{
ID: 1,
Key: nodeKey,
DiscoKey: discoKey,
Endpoints: []string{sendConn.LocalAddr().String()},
@@ -1461,6 +1463,7 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) {
conn.SetNetworkMap(&netmap.NetworkMap{
Peers: nodeViews([]*tailcfg.Node{
{
ID: 1,
Key: nodeKey1,
DiscoKey: discoKey,
Endpoints: []string{"192.168.1.2:345"},
@@ -1476,6 +1479,7 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) {
conn.SetNetworkMap(&netmap.NetworkMap{
Peers: nodeViews([]*tailcfg.Node{
{
ID: 2,
Key: nodeKey2,
DiscoKey: discoKey,
Endpoints: []string{"192.168.1.2:345"},
@@ -1767,6 +1771,7 @@ func TestStressSetNetworkMap(t *testing.T) {
for i := range allPeers {
present[i] = true
allPeers[i] = &tailcfg.Node{
ID: tailcfg.NodeID(i) + 1,
DiscoKey: randDiscoKey(),
Key: randNodeKey(),
Endpoints: []string{fmt.Sprintf("192.168.1.2:%d", i)},
@@ -1831,18 +1836,26 @@ func (m *peerMap) validate() error {
return fmt.Errorf("duplicate endpoint present: %v", pi.ep.publicKey)
}
seenEps[pi.ep] = true
for ipp, v := range pi.ipPorts {
if !v {
return fmt.Errorf("m.byIPPort[%v] is false, expected map to be set-like", ipp)
}
for ipp := range pi.ipPorts {
if got := m.byIPPort[ipp]; got != pi {
return fmt.Errorf("m.byIPPort[%v] = %v, want %v", ipp, got, pi)
}
}
}
if len(m.byNodeKey) != len(m.byNodeID) {
return fmt.Errorf("len(m.byNodeKey)=%d != len(m.byNodeID)=%d", len(m.byNodeKey), len(m.byNodeID))
}
for nodeID, pi := range m.byNodeID {
ep := pi.ep
if pi2, ok := m.byNodeKey[ep.publicKey]; !ok {
return fmt.Errorf("nodeID %d in map with publicKey %v that's missing from map", nodeID, ep.publicKey)
} else if pi2 != pi {
return fmt.Errorf("nodeID %d in map with publicKey %v that points to different endpoint", nodeID, ep.publicKey)
}
}
for ipp, pi := range m.byIPPort {
if !pi.ipPorts[ipp] {
if !pi.ipPorts.Contains(ipp) {
return fmt.Errorf("ipPorts[%v] for %v is false", ipp, pi.ep.publicKey)
}
pi2 := m.byNodeKey[pi.ep.publicKey]
@@ -1853,10 +1866,7 @@ func (m *peerMap) validate() error {
publicToDisco := make(map[key.NodePublic]key.DiscoPublic)
for disco, nodes := range m.nodesOfDisco {
for pub, v := range nodes {
if !v {
return fmt.Errorf("m.nodeOfDisco[%v][%v] is false, expected map to be set-like", disco, pub)
}
for pub := range nodes {
if _, ok := m.byNodeKey[pub]; !ok {
return fmt.Errorf("nodesOfDisco refers to public key %v, which is not present in byNodeKey", pub)
}
@@ -2254,6 +2264,7 @@ func TestIsWireGuardOnlyPeer(t *testing.T) {
Addresses: []netip.Prefix{tsaip},
Peers: nodeViews([]*tailcfg.Node{
{
ID: 1,
Key: wgkey.Public(),
Endpoints: []string{wgEp.String()},
IsWireGuardOnly: true,
@@ -2312,6 +2323,7 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) {
Addresses: []netip.Prefix{tsaip},
Peers: nodeViews([]*tailcfg.Node{
{
ID: 1,
Key: wgkey.Public(),
Endpoints: []string{wgEp.String()},
IsWireGuardOnly: true,

View File

@@ -6,7 +6,9 @@ package magicsock
import (
"net/netip"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/util/set"
)
// peerInfo is all the information magicsock tracks about a particular
@@ -17,39 +19,44 @@ type peerInfo struct {
// that when we're deleting this node, we can rapidly find out the
// keys that need deleting from peerMap.byIPPort without having to
// iterate over every IPPort known for any peer.
ipPorts map[netip.AddrPort]bool
ipPorts set.Set[netip.AddrPort]
}
func newPeerInfo(ep *endpoint) *peerInfo {
return &peerInfo{
ep: ep,
ipPorts: map[netip.AddrPort]bool{},
ipPorts: set.Set[netip.AddrPort]{},
}
}
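
That reverse index is what makes peer deletion cheap: the cleanup loop (a sketch — the corresponding lines of deleteEndpoint fall outside this hunk) only touches the ports this peer was seen on:

for ipp := range pi.ipPorts {
	delete(m.byIPPort, ipp) // O(ports for this peer), not O(all known ip:ports)
}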
// peerMap is an index of peerInfos by node (WireGuard) key, node ID,
// disco key, and discovered ip:port endpoints.
//
// Doesn't do any locking, all access must be done with Conn.mu held.
// It doesn't do any locking; all access must be done with Conn.mu held.
type peerMap struct {
byNodeKey map[key.NodePublic]*peerInfo
byIPPort map[netip.AddrPort]*peerInfo
byNodeID map[tailcfg.NodeID]*peerInfo
// nodesOfDisco contains the set of nodes that are using a
// DiscoKey. Usually those sets will be just one node.
nodesOfDisco map[key.DiscoPublic]map[key.NodePublic]bool
nodesOfDisco map[key.DiscoPublic]set.Set[key.NodePublic]
}
func newPeerMap() peerMap {
return peerMap{
byNodeKey: map[key.NodePublic]*peerInfo{},
byIPPort: map[netip.AddrPort]*peerInfo{},
nodesOfDisco: map[key.DiscoPublic]map[key.NodePublic]bool{},
byNodeID: map[tailcfg.NodeID]*peerInfo{},
nodesOfDisco: map[key.DiscoPublic]set.Set[key.NodePublic]{},
}
}
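
The new invariant is that byNodeKey and byNodeID always point at the same *peerInfo. A simplified standalone model of that dual index (toy types, not the real ones):

package example

// record stands in for peerInfo; key and id for the node key and NodeID.
type record struct {
	key string
	id  int64
}

type dualIndex struct {
	byKey map[string]*record
	byID  map[int64]*record
}

// upsert keeps both views pointing at the same record, as
// peerMap.upsertEndpoint does for byNodeKey and byNodeID.
func (x *dualIndex) upsert(r *record) {
	x.byKey[r.key] = r
	x.byID[r.id] = r
}

// remove mirrors deleteEndpoint's "was.ep == ep" guard: drop the ID entry
// only if it still points at the record being removed.
func (x *dualIndex) remove(r *record) {
	delete(x.byKey, r.key)
	if cur, ok := x.byID[r.id]; ok && cur == r {
		delete(x.byID, r.id)
	}
}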
// nodeCount returns the number of nodes currently in m.
func (m *peerMap) nodeCount() int {
if len(m.byNodeKey) != len(m.byNodeID) {
devPanicf("internal error: peerMap.byNodeKey and byNodeID out of sync")
}
return len(m.byNodeKey)
}
@@ -71,6 +78,15 @@ func (m *peerMap) endpointForNodeKey(nk key.NodePublic) (ep *endpoint, ok bool)
return nil, false
}
// endpointForNodeID returns the endpoint for nodeID, or nil if
// nodeID is not known to us.
func (m *peerMap) endpointForNodeID(nodeID tailcfg.NodeID) (ep *endpoint, ok bool) {
if info, ok := m.byNodeID[nodeID]; ok {
return info.ep, true
}
return nil, false
}
// endpointForIPPort returns the endpoint for the peer we
// believe to be at ipp, or nil if we don't know of any such peer.
func (m *peerMap) endpointForIPPort(ipp netip.AddrPort) (ep *endpoint, ok bool) {
@@ -111,9 +127,16 @@ func (m *peerMap) forEachEndpointWithDiscoKey(dk key.DiscoPublic, f func(*endpoi
// ep.publicKey, and updates indexes. m must already have a
// tailcfg.Node for ep.publicKey.
func (m *peerMap) upsertEndpoint(ep *endpoint, oldDiscoKey key.DiscoPublic) {
if m.byNodeKey[ep.publicKey] == nil {
m.byNodeKey[ep.publicKey] = newPeerInfo(ep)
if ep.nodeID == 0 {
panic("internal error: upsertEndpoint called with zero NodeID")
}
pi, ok := m.byNodeKey[ep.publicKey]
if !ok {
pi = newPeerInfo(ep)
m.byNodeKey[ep.publicKey] = pi
}
m.byNodeID[ep.nodeID] = pi
epDisco := ep.disco.Load()
if epDisco == nil || oldDiscoKey != epDisco.key {
delete(m.nodesOfDisco[oldDiscoKey], ep.publicKey)
@@ -129,15 +152,14 @@ func (m *peerMap) upsertEndpoint(ep *endpoint, oldDiscoKey key.DiscoPublic) {
for ipp := range ep.endpointState {
m.setNodeKeyForIPPort(ipp, ep.publicKey)
}
return
}
set := m.nodesOfDisco[epDisco.key]
if set == nil {
set = map[key.NodePublic]bool{}
m.nodesOfDisco[epDisco.key] = set
discoSet := m.nodesOfDisco[epDisco.key]
if discoSet == nil {
discoSet = set.Set[key.NodePublic]{}
m.nodesOfDisco[epDisco.key] = discoSet
}
set[ep.publicKey] = true
discoSet.Add(ep.publicKey)
}
// setNodeKeyForIPPort makes future peer lookups by ipp return the
@ -152,7 +174,7 @@ func (m *peerMap) setNodeKeyForIPPort(ipp netip.AddrPort, nk key.NodePublic) {
delete(m.byIPPort, ipp)
}
if pi, ok := m.byNodeKey[nk]; ok {
pi.ipPorts[ipp] = true
pi.ipPorts.Add(ipp)
m.byIPPort[ipp] = pi
}
}
@@ -172,6 +194,9 @@ func (m *peerMap) deleteEndpoint(ep *endpoint) {
delete(m.nodesOfDisco[epDisco.key], ep.publicKey)
}
delete(m.byNodeKey, ep.publicKey)
if was, ok := m.byNodeID[ep.nodeID]; ok && was.ep == ep {
delete(m.byNodeID, ep.nodeID)
}
if pi == nil {
// Kneejerk paranoia from earlier issue 2801.
// Unexpected. But there's no logger plumbed here to log it.

View File

@@ -109,6 +109,7 @@ func TestUserspaceEngineReconfig(t *testing.T) {
nm := &netmap.NetworkMap{
Peers: nodeViews([]*tailcfg.Node{
{
ID: 1,
Key: nkFromHex(nodeHex),
},
}),