wgengine/magicksock: remove nullability of RebindingUDPConns

Both RebindingUDPConns now always exist. the initial bind (which now
just calls rebind) now ensures that bind is called for both, such that
they both at least contain a blockForeverConn. Calling code no longer
needs to assert their state.

Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
James Tucker 2022-09-02 14:24:51 -07:00 committed by James Tucker
parent 56f6fe204b
commit 1f959edeb0
1 changed files with 13 additions and 47 deletions

View File

@ -262,8 +262,8 @@ type Conn struct {
// pconn4 and pconn6 are the underlying UDP sockets used to
// send/receive packets for wireguard and other magicsock
// protocols.
pconn4 *RebindingUDPConn
pconn6 *RebindingUDPConn
pconn4 RebindingUDPConn
pconn6 RebindingUDPConn
// closeDisco4 and closeDisco6 are io.Closers to shut down the raw
// disco packet receivers. If nil, no raw disco receiver is
@ -570,7 +570,7 @@ func NewConn(opts Options) (*Conn, error) {
}
c.linkMon = opts.LinkMonitor
if err := c.initialBind(); err != nil {
if err := c.rebind(keepCurrentPort); err != nil {
return nil, err
}
@ -578,15 +578,12 @@ func NewConn(opts Options) (*Conn, error) {
c.donec = c.connCtx.Done()
c.netChecker = &netcheck.Client{
Logf: logger.WithPrefix(c.logf, "netcheck: "),
GetSTUNConn4: func() netcheck.STUNConn { return c.pconn4 },
GetSTUNConn4: func() netcheck.STUNConn { return &c.pconn4 },
GetSTUNConn6: func() netcheck.STUNConn { return &c.pconn6 },
SkipExternalNetwork: inTest(),
PortMapper: c.portMapper,
}
if c.pconn6 != nil {
c.netChecker.GetSTUNConn6 = func() netcheck.STUNConn { return c.pconn6 }
}
c.ignoreSTUNPackets()
if d4, err := c.listenRawDisco("ip4"); err == nil {
@ -1240,10 +1237,6 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error)
return false, nil
}
case addr.Addr().Is6():
if c.pconn6 == nil {
// ignore IPv6 dest if we don't have an IPv6 address.
return false, nil
}
_, err = c.pconn6.WriteToUDPAddrPort(b, addr)
if err != nil && (c.noV6.Load() || neterror.TreatAsLostUDP(err)) {
return false, nil
@ -2660,12 +2653,8 @@ func (c *connBind) Close() error {
}
c.closed = true
// Unblock all outstanding receives.
if c.pconn4 != nil {
c.pconn4.Close()
}
if c.pconn6 != nil {
c.pconn6.Close()
}
c.pconn4.Close()
c.pconn6.Close()
if c.closeDisco4 != nil {
c.closeDisco4.Close()
}
@ -2710,12 +2699,8 @@ func (c *Conn) Close() error {
c.closeAllDerpLocked("conn-close")
// Ignore errors from c.pconnN.Close.
// They will frequently have been closed already by a call to connBind.Close.
if c.pconn6 != nil {
c.pconn6.Close()
}
if c.pconn4 != nil {
c.pconn4.Close()
}
c.pconn6.Close()
c.pconn4.Close()
// Wait on goroutines updating right at the end, once everything is
// already closed. We want everything else in the Conn to be
@ -2821,20 +2806,6 @@ func (c *Conn) ReSTUN(why string) {
}
}
func (c *Conn) initialBind() error {
if runtime.GOOS == "js" {
return nil
}
if err := c.bindSocket(&c.pconn4, "udp4", keepCurrentPort); err != nil {
return fmt.Errorf("magicsock: initialBind IPv4 failed: %w", err)
}
c.portMapper.SetLocalPort(c.LocalPort())
if err := c.bindSocket(&c.pconn6, "udp6", keepCurrentPort); err != nil {
c.logf("magicsock: ignoring IPv6 bind failure: %v", err)
}
return nil
}
// listenPacket opens a packet listener.
// The network must be "udp4" or "udp6".
func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, error) {
@ -2852,12 +2823,7 @@ func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, er
// The caller is responsible for informing the portMapper of any changes.
// If curPortFate is set to dropCurrentPort, no attempt is made to reuse
// the current port.
func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string, curPortFate currentPortFate) error {
if *rucPtr == nil {
*rucPtr = new(RebindingUDPConn)
}
ruc := *rucPtr
func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate currentPortFate) error {
// Hold the ruc lock the entire time, so that the close+bind is atomic
// from the perspective of ruc receive functions.
ruc.mu.Lock()
@ -2930,13 +2896,13 @@ func (c *Conn) rebind(curPortFate currentPortFate) error {
if runtime.GOOS == "js" {
return nil
}
if err := c.bindSocket(&c.pconn6, "udp6", curPortFate); err != nil {
c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err)
}
if err := c.bindSocket(&c.pconn4, "udp4", curPortFate); err != nil {
return fmt.Errorf("magicsock: Rebind IPv4 failed: %w", err)
}
c.portMapper.SetLocalPort(c.LocalPort())
if err := c.bindSocket(&c.pconn6, "udp6", curPortFate); err != nil {
c.logf("magicsock: Rebind ignoring IPv6 bind failure: %v", err)
}
return nil
}