derp: prevent concurrent access to multiForwarder map
Instead of iterating over the map to determine the preferred forwarder on every packet (which could happen concurrently with map mutations), store it separately in an atomic variable.

Fixes #6445

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
This commit is contained in:
parent
6e33d2da2b
commit
6cc6c70d70
|
@ -40,6 +40,7 @@ import (
|
|||
"tailscale.com/disco"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/metrics"
|
||||
"tailscale.com/syncs"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/version"
|
||||
|
@ -1560,22 +1561,20 @@ func (s *Server) AddPacketForwarder(dst key.NodePublic, fwd PacketForwarder) {
|
|||
// Duplicate registration of same forwarder. Ignore.
|
||||
return
|
||||
}
|
||||
if m, ok := prev.(multiForwarder); ok {
|
||||
if _, ok := m[fwd]; ok {
|
||||
if m, ok := prev.(*multiForwarder); ok {
|
||||
if _, ok := m.all[fwd]; ok {
|
||||
// Duplicate registration of same forwarder in set; ignore.
|
||||
return
|
||||
}
|
||||
m[fwd] = m.maxVal() + 1
|
||||
m.add(fwd)
|
||||
return
|
||||
}
|
||||
if prev != nil {
|
||||
// Otherwise, the existing value is not a set,
|
||||
// not a dup, and not local-only (nil) so make
|
||||
// it a set.
|
||||
fwd = multiForwarder{
|
||||
prev: 1, // existed 1st, higher priority
|
||||
fwd: 2, // the passed in fwd is in 2nd place
|
||||
}
|
||||
// it a set. `prev` existed first, so will have higher
|
||||
// priority.
|
||||
fwd = newMultiForwarder(prev, fwd)
|
||||
s.multiForwarderCreated.Add(1)
|
||||
}
|
||||
}
|
||||
|
@ -1591,19 +1590,14 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder)
|
|||
if !ok {
|
||||
return
|
||||
}
|
||||
if m, ok := v.(multiForwarder); ok {
|
||||
if len(m) < 2 {
|
||||
if m, ok := v.(*multiForwarder); ok {
|
||||
if len(m.all) < 2 {
|
||||
panic("unexpected")
|
||||
}
|
||||
delete(m, fwd)
|
||||
if remain, isLast := m.deleteLocked(fwd); isLast {
|
||||
// If fwd was in m and we no longer need to be a
|
||||
// multiForwarder, replace the entry with the
|
||||
// remaining PacketForwarder.
|
||||
if len(m) == 1 {
|
||||
var remain PacketForwarder
|
||||
for k := range m {
|
||||
remain = k
|
||||
}
|
||||
s.clientsMesh[dst] = remain
|
||||
s.multiForwarderDeleted.Add(1)
|
||||
}
|
||||
|
@ -1635,27 +1629,65 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder)
|
|||
// client is. The map value is unique connection number; the lowest
|
||||
// one has been seen the longest. It's used to make sure we forward
|
||||
// packets consistently to the same node and don't pick randomly.
|
||||
type multiForwarder map[PacketForwarder]uint8
|
||||
type multiForwarder struct {
|
||||
fwd syncs.AtomicValue[PacketForwarder] // preferred forwarder.
|
||||
all map[PacketForwarder]uint8 // all forwarders, protected by s.mu.
|
||||
}
|
||||
|
||||
func (m multiForwarder) maxVal() (max uint8) {
|
||||
for _, v := range m {
|
||||
// newMultiForwarder creates a new multiForwarder.
|
||||
// The first PacketForwarder passed to this function will be the preferred one.
|
||||
func newMultiForwarder(fwds ...PacketForwarder) *multiForwarder {
|
||||
f := &multiForwarder{all: make(map[PacketForwarder]uint8)}
|
||||
f.fwd.Store(fwds[0])
|
||||
for idx, fwd := range fwds {
|
||||
f.all[fwd] = uint8(idx)
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// add adds a new forwarder to the map with a connection number that
|
||||
// is higher than the existing ones.
|
||||
func (f *multiForwarder) add(fwd PacketForwarder) {
|
||||
var max uint8
|
||||
for _, v := range f.all {
|
||||
if v > max {
|
||||
max = v
|
||||
}
|
||||
}
|
||||
return
|
||||
f.all[fwd] = max + 1
|
||||
}
|
||||
|
||||
func (m multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error {
|
||||
var fwd PacketForwarder
|
||||
// deleteLocked removes a packet forwarder from the map. It expects Server.mu to be held.
|
||||
// If only one forwarder remains after the removal, it will be returned alongside a `true` boolean value.
|
||||
func (f *multiForwarder) deleteLocked(fwd PacketForwarder) (_ PacketForwarder, isLast bool) {
|
||||
delete(f.all, fwd)
|
||||
|
||||
if fwd == f.fwd.Load() {
|
||||
// The preferred forwarder has been removed, choose a new one
|
||||
// based on the lowest index.
|
||||
var lowestfwd PacketForwarder
|
||||
var lowest uint8
|
||||
for k, v := range m {
|
||||
if fwd == nil || v < lowest {
|
||||
fwd = k
|
||||
for k, v := range f.all {
|
||||
if lowestfwd == nil || v < lowest {
|
||||
lowestfwd = k
|
||||
lowest = v
|
||||
}
|
||||
}
|
||||
return fwd.ForwardPacket(src, dst, payload)
|
||||
if lowestfwd != nil {
|
||||
f.fwd.Store(lowestfwd)
|
||||
}
|
||||
}
|
||||
|
||||
if len(f.all) == 1 {
|
||||
for k := range f.all {
|
||||
return k, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (f *multiForwarder) ForwardPacket(src, dst key.NodePublic, payload []byte) error {
|
||||
return f.fwd.Load().ForwardPacket(src, dst, payload)
|
||||
}
|
||||
|
||||
func (s *Server) expVarFunc(f func() any) expvar.Func {
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -723,20 +724,14 @@ func TestForwarderRegistration(t *testing.T) {
|
|||
s.AddPacketForwarder(u1, testFwd(100))
|
||||
s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path
|
||||
want(map[key.NodePublic]PacketForwarder{
|
||||
u1: multiForwarder{
|
||||
testFwd(1): 1,
|
||||
testFwd(100): 2,
|
||||
},
|
||||
u1: newMultiForwarder(testFwd(1), testFwd(100)),
|
||||
})
|
||||
wantCounter(&s.multiForwarderCreated, 1)
|
||||
|
||||
// Removing a forwarder in a multi set that doesn't exist; does nothing.
|
||||
s.RemovePacketForwarder(u1, testFwd(55))
|
||||
want(map[key.NodePublic]PacketForwarder{
|
||||
u1: multiForwarder{
|
||||
testFwd(1): 1,
|
||||
testFwd(100): 2,
|
||||
},
|
||||
u1: newMultiForwarder(testFwd(1), testFwd(100)),
|
||||
})
|
||||
|
||||
// Removing a forwarder in a multi set that does exist should collapse it away
|
||||
|
@ -785,6 +780,76 @@ func TestForwarderRegistration(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// channelFwd is a test PacketForwarder that delivers every forwarded packet
// to a channel.
type channelFwd struct {
	// id is to ensure that different instances that reference the
	// same channel are not equal, as they are used as keys in the
	// multiForwarder map.
	id int
	// c receives each packet passed to ForwardPacket.
	c chan []byte
}
|
||||
|
||||
func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error {
|
||||
f.c <- packet
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMultiForwarder(t *testing.T) {
|
||||
received := 0
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan []byte)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
s := &Server{
|
||||
clients: make(map[key.NodePublic]clientSet),
|
||||
clientsMesh: map[key.NodePublic]PacketForwarder{},
|
||||
}
|
||||
u := pubAll(1)
|
||||
s.AddPacketForwarder(u, channelFwd{1, ch})
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
received += 1
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
s.AddPacketForwarder(u, channelFwd{2, ch})
|
||||
s.AddPacketForwarder(u, channelFwd{3, ch})
|
||||
s.RemovePacketForwarder(u, channelFwd{2, ch})
|
||||
s.RemovePacketForwarder(u, channelFwd{1, ch})
|
||||
s.AddPacketForwarder(u, channelFwd{1, ch})
|
||||
s.RemovePacketForwarder(u, channelFwd{3, ch})
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Number of messages is chosen arbitrarily, just for this loop to
|
||||
// run long enough concurrently with {Add,Remove}PacketForwarder loop above.
|
||||
numMsgs := 5000
|
||||
var fwd PacketForwarder
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
s.mu.Lock()
|
||||
fwd = s.clientsMesh[u]
|
||||
s.mu.Unlock()
|
||||
fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i)))
|
||||
}
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
if received != numMsgs {
|
||||
t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received)
|
||||
}
|
||||
}
|
||||
func TestMetaCert(t *testing.T) {
|
||||
priv := key.NewNode()
|
||||
pub := priv.Public()
|
||||
|
|
Loading…
Reference in New Issue