wgengine/magicsock: avoid RebindingUDPConn mutex in common read/write case

Change-Id: I209fac567326f2e926bace2582dbc67a8bc94c78
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-08-02 09:29:16 -07:00 committed by Brad Fitzpatrick
parent 116f55ff66
commit fb82299f5a
1 changed files with 26 additions and 23 deletions

View File

@ -2826,7 +2826,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
if debugAlwaysDERP {
c.logf("disabled %v per TS_DEBUG_ALWAYS_USE_DERP", network)
ruc.pconn = newBlockForeverConn()
ruc.setConnLocked(newBlockForeverConn())
return nil
}
@ -2860,7 +2860,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
continue
}
// Success.
ruc.pconn = pconn
ruc.setConnLocked(pconn)
if network == "udp4" {
health.SetUDP4Unbound(false)
}
@ -2871,7 +2871,7 @@ func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate
// Set pconn to a dummy conn whose reads block until closed.
// This keeps the receive funcs alive for a future in which
// we get a link change and we can try binding again.
ruc.pconn = newBlockForeverConn()
ruc.setConnLocked(newBlockForeverConn())
if network == "udp4" {
health.SetUDP4Unbound(true)
}
@ -2974,11 +2974,26 @@ func (c *Conn) ParseEndpoint(nodeKeyStr string) (conn.Endpoint, error) {
// RebindingUDPConn is a UDP socket that can be re-bound.
// Unix has no notion of re-binding a socket, so we swap it out for a new one.
type RebindingUDPConn struct {
mu sync.Mutex
// pconnAtomic is the same as pconn, but doesn't require acquiring mu. It's
// used for reads/writes and only upon failure do the reads/writes then
// check pconn (after acquiring mu) to see if there's been a rebind
// meanwhile.
// pconn isn't really needed, but makes some of the code simpler
// to keep it in a type safe form. TODO(bradfitz): really we should make a generic
// atomic.Value. Unfortunately Go 1.19's atomic.Pointer[T] is only for pointers,
// not interfaces.
pconnAtomic atomic.Value // of nettype.PacketConn
mu sync.Mutex // held while changing pconn (and pconnAtomic)
pconn nettype.PacketConn
}
// currentConn returns c's current pconn.
func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn) {
c.pconn = p
c.pconnAtomic.Store(p)
}
// currentConn returns c's current pconn, acquiring c.mu in the process.
func (c *RebindingUDPConn) currentConn() nettype.PacketConn {
c.mu.Lock()
defer c.mu.Unlock()
@ -2989,7 +3004,7 @@ func (c *RebindingUDPConn) currentConn() nettype.PacketConn {
// It returns the number of bytes copied and the source address.
func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
for {
pconn := c.currentConn()
pconn := c.pconnAtomic.Load().(nettype.PacketConn)
n, addr, err := pconn.ReadFrom(b)
if err != nil && pconn != c.currentConn() {
continue
@ -3007,7 +3022,7 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
// when c's underlying connection is a net.UDPConn.
func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netip.AddrPort, err error) {
for {
pconn := c.currentConn()
pconn := c.pconnAtomic.Load().(nettype.PacketConn)
// Optimization: Treat *net.UDPConn specially.
// This lets us avoid allocations by calling ReadFromUDPAddrPort.
@ -3066,17 +3081,11 @@ func (c *RebindingUDPConn) closeLocked() error {
func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
for {
c.mu.Lock()
pconn := c.pconn
c.mu.Unlock()
pconn := c.pconnAtomic.Load().(nettype.PacketConn)
n, err := pconn.WriteTo(b, addr)
if err != nil {
c.mu.Lock()
pconn2 := c.pconn
c.mu.Unlock()
if pconn != pconn2 {
if pconn != c.currentConn() {
continue
}
}
@ -3086,17 +3095,11 @@ func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
for {
c.mu.Lock()
pconn := c.pconn
c.mu.Unlock()
pconn := c.pconnAtomic.Load().(nettype.PacketConn)
n, err := pconn.WriteToUDPAddrPort(b, addr)
if err != nil {
c.mu.Lock()
pconn2 := c.pconn
c.mu.Unlock()
if pconn != pconn2 {
if pconn != c.currentConn() {
continue
}
}