derp: prevent concurrent access to multiForwarder map

Instead of iterating over the map to determine the preferred forwarder
on every packet (which could happen concurrently with map mutations),
store it separately in an atomic variable.

Fixes #6445

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
This commit is contained in:
Anton Tolchanov 2022-11-22 16:13:53 +00:00 committed by Anton Tolchanov
parent 6e33d2da2b
commit 6cc6c70d70
2 changed files with 136 additions and 39 deletions

View File

@ -40,6 +40,7 @@ import (
"tailscale.com/disco" "tailscale.com/disco"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/metrics" "tailscale.com/metrics"
"tailscale.com/syncs"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/version" "tailscale.com/version"
@ -1560,22 +1561,20 @@ func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) {
// Duplicate registration of same forwarder. Ignore. // Duplicate registration of same forwarder. Ignore.
return return
} }
if m, ok := prev.(multiForwarder); ok { if m, ok := prev.(*multiForwarder); ok {
if _, ok := m[fwd]; ok { if _, ok := m.all[fwd]; ok {
// Duplicate registration of same forwarder in set; ignore. // Duplicate registration of same forwarder in set; ignore.
return return
} }
m[fwd] = m.maxVal() + 1 m.add(fwd)
return return
} }
if prev != nil { if prev != nil {
// Otherwise, the existing value is not a set, // Otherwise, the existing value is not a set,
// not a dup, and not local-only (nil) so make // not a dup, and not local-only (nil) so make
// it a set. // it a set. `prev` existed first, so will have higher
fwd = multiForwarder{ // priority.
prev: 1, // existed 1st, higher priority fwd = newMultiForwarder(prev, fwd)
fwd: 2, // the passed in fwd is in 2nd place
}
s.multiForwarderCreated.Add(1) s.multiForwarderCreated.Add(1)
} }
} }
@ -1591,19 +1590,14 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder)
if !ok { if !ok {
return return
} }
if m, ok := v.(multiForwarder); ok { if m, ok := v.(*multiForwarder); ok {
if len(m) < 2 { if len(m.all) < 2 {
panic("unexpected") panic("unexpected")
} }
delete(m, fwd) if remain, isLast := m.deleteLocked(fwd); isLast {
// If fwd was in m and we no longer need to be a // If fwd was in m and we no longer need to be a
// multiForwarder, replace the entry with the // multiForwarder, replace the entry with the
// remaining PacketForwarder. // remaining PacketForwarder.
if len(m) == 1 {
var remain PacketForwarder
for k := range m {
remain = k
}
s.clientsMesh[dst] = remain s.clientsMesh[dst] = remain
s.multiForwarderDeleted.Add(1) s.multiForwarderDeleted.Add(1)
} }
@ -1635,27 +1629,65 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder)
// client is. The map value is unique connection number; the lowest // client is. The map value is unique connection number; the lowest
// one has been seen the longest. It's used to make sure we forward // one has been seen the longest. It's used to make sure we forward
// packets consistently to the same node and don't pick randomly. // packets consistently to the same node and don't pick randomly.
type multiForwarder map[PacketForwarder]uint8 type multiForwarder struct {
fwd syncs.AtomicValue[PacketForwarder] // preferred forwarder.
all map[PacketForwarder]uint8 // all forwarders, protected by s.mu.
}
func (m multiForwarder) maxVal() (max uint8) { // newMultiForwarder creates a new multiForwarder.
for _, v := range m { // The first PacketForwarder passed to this function will be the preferred one.
func newMultiForwarder(fwds ...PacketForwarder) *multiForwarder {
f := &multiForwarder{all: make(map[PacketForwarder]uint8)}
f.fwd.Store(fwds[0])
for idx, fwd := range fwds {
f.all[fwd] = uint8(idx)
}
return f
}
// add adds a new forwarder to the map with a connection number that
// is higher than the existing ones.
func (f *multiForwarder) add(fwd PacketForwarder) {
var max uint8
for _, v := range f.all {
if v > max { if v > max {
max = v max = v
} }
} }
return f.all[fwd] = max + 1
} }
func (m multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error { // deleteLocked removes a packet forwarder from the map. It expects Server.mu to be held.
var fwd PacketForwarder // If only one forwarder remains after the removal, it will be returned alongside a `true` boolean value.
var lowest uint8 func (f *multiForwarder) deleteLocked(fwd PacketForwarder) (_ PacketForwarder, isLast bool) {
for k, v := range m { delete(f.all, fwd)
if fwd == nil || v < lowest {
fwd = k if fwd == f.fwd.Load() {
lowest = v // The preferred forwarder has been removed, choose a new one
// based on the lowest index.
var lowestfwd PacketForwarder
var lowest uint8
for k, v := range f.all {
if lowestfwd == nil || v < lowest {
lowestfwd = k
lowest = v
}
}
if lowestfwd != nil {
f.fwd.Store(lowestfwd)
} }
} }
return fwd.ForwardPacket(src, dst, payload)
if len(f.all) == 1 {
for k := range f.all {
return k, true
}
}
return nil, false
}
func (f *multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error {
return f.fwd.Load().ForwardPacket(src, dst, payload)
} }
func (s *Server) expVarFunc(f func() any) expvar.Func { func (s *Server) expVarFunc(f func() any) expvar.Func {

View File

@ -19,6 +19,7 @@ import (
"net" "net"
"os" "os"
"reflect" "reflect"
"strconv"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -723,20 +724,14 @@ func TestForwarderRegistration(t *testing.T) {
s.AddPacketForwarder(u1, testFwd(100)) s.AddPacketForwarder(u1, testFwd(100))
s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path
want(map[key.NodePublic]PacketForwarder{ want(map[key.NodePublic]PacketForwarder{
u1: multiForwarder{ u1: newMultiForwarder(testFwd(1), testFwd(100)),
testFwd(1): 1,
testFwd(100): 2,
},
}) })
wantCounter(&s.multiForwarderCreated, 1) wantCounter(&s.multiForwarderCreated, 1)
// Removing a forwarder in a multi set that doesn't exist; does nothing. // Removing a forwarder in a multi set that doesn't exist; does nothing.
s.RemovePacketForwarder(u1, testFwd(55)) s.RemovePacketForwarder(u1, testFwd(55))
want(map[key.NodePublic]PacketForwarder{ want(map[key.NodePublic]PacketForwarder{
u1: multiForwarder{ u1: newMultiForwarder(testFwd(1), testFwd(100)),
testFwd(1): 1,
testFwd(100): 2,
},
}) })
// Removing a forwarder in a multi set that does exist should collapse it away // Removing a forwarder in a multi set that does exist should collapse it away
@ -785,6 +780,76 @@ func TestForwarderRegistration(t *testing.T) {
}) })
} }
type channelFwd struct {
// id is to ensure that different instances that reference the
// same channel are not equal, as they are used as keys in the
// multiForwarder map.
id int
c chan []byte
}
func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error {
f.c <- packet
return nil
}
func TestMultiForwarder(t *testing.T) {
received := 0
var wg sync.WaitGroup
ch := make(chan []byte)
ctx, cancel := context.WithCancel(context.Background())
s := &Server{
clients: make(map[key.NodePublic]clientSet),
clientsMesh: map[key.NodePublic]PacketForwarder{},
}
u := pubAll(1)
s.AddPacketForwarder(u, channelFwd{1, ch})
wg.Add(2)
go func() {
defer wg.Done()
for {
select {
case <-ch:
received += 1
case <-ctx.Done():
return
}
}
}()
go func() {
defer wg.Done()
for {
s.AddPacketForwarder(u, channelFwd{2, ch})
s.AddPacketForwarder(u, channelFwd{3, ch})
s.RemovePacketForwarder(u, channelFwd{2, ch})
s.RemovePacketForwarder(u, channelFwd{1, ch})
s.AddPacketForwarder(u, channelFwd{1, ch})
s.RemovePacketForwarder(u, channelFwd{3, ch})
if ctx.Err() != nil {
return
}
}
}()
// Number of messages is chosen arbitrarily, just for this loop to
// run long enough concurrently with {Add,Remove}PacketForwarder loop above.
numMsgs := 5000
var fwd PacketForwarder
for i := 0; i < numMsgs; i++ {
s.mu.Lock()
fwd = s.clientsMesh[u]
s.mu.Unlock()
fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i)))
}
cancel()
wg.Wait()
if received != numMsgs {
t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received)
}
}
func TestMetaCert(t *testing.T) { func TestMetaCert(t *testing.T) {
priv := key.NewNode() priv := key.NewNode()
pub := priv.Public() pub := priv.Public()