From 6866aaeab39fee83a4f3292986e7022109d85685 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 14 Apr 2023 08:09:09 -0700 Subject: [PATCH] wgengine/magicsock: factor out receiveIPv4 & receiveIPv6 common code Updates #2331 Change-Id: I801df38b217f5d17203e8dc3b8654f44747e0f4b Signed-off-by: Brad Fitzpatrick --- wgengine/magicsock/magicsock.go | 123 +++++++++++---------------- wgengine/magicsock/magicsock_test.go | 9 +- 2 files changed, 57 insertions(+), 75 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 6fd7f4b1b..6fe0e8881 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -322,11 +322,6 @@ type Conn struct { // bind is the wireguard-go conn.Bind for Conn. bind *connBind - // ippEndpoint4 and ippEndpoint6 are owned by receiveIPv4 and - // receiveIPv6, respectively, to cache an IPPort->endpoint for - // hot flows. - ippEndpoint4, ippEndpoint6 ippEndpointCache - // ============================================================ // Fields that must be accessed via atomic load/stores. @@ -1851,79 +1846,63 @@ func (c *Conn) putReceiveBatch(batch *receiveBatch) { c.receiveBatchPool.Put(batch) } -func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { - health.ReceiveIPv6.Enter() - defer health.ReceiveIPv6.Exit() - - batch := c.getReceiveBatchForBuffs(buffs) - defer c.putReceiveBatch(batch) - for { - numMsgs, err := c.pconn6.ReadBatch(batch.msgs[:len(buffs)], 0) - if err != nil { - if neterror.PacketWasTruncated(err) { - // TODO(raggi): discuss whether to log? - continue - } - return 0, err - } - - reportToCaller := false - for i, msg := range batch.msgs[:numMsgs] { - if msg.N == 0 { - sizes[i] = 0 - continue - } - ipp := msg.Addr.(*net.UDPAddr).AddrPort() - if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6); ok { - metricRecvDataIPv6.Add(1) - eps[i] = ep - sizes[i] = msg.N - reportToCaller = true - } else { - sizes[i] = 0 - } - } - - if reportToCaller { - return numMsgs, nil - } - } +// receiveIPv4 creates an IPv4 ReceiveFunc reading from c.pconn4. +func (c *Conn) receiveIPv4() conn.ReceiveFunc { + return c.mkReceiveFunc(&c.pconn4, &health.ReceiveIPv4, metricRecvDataIPv4) } -func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { - health.ReceiveIPv4.Enter() - defer health.ReceiveIPv4.Exit() +// receiveIPv6 creates an IPv6 ReceiveFunc reading from c.pconn6. +func (c *Conn) receiveIPv6() conn.ReceiveFunc { + return c.mkReceiveFunc(&c.pconn6, &health.ReceiveIPv6, metricRecvDataIPv6) +} - batch := c.getReceiveBatchForBuffs(buffs) - defer c.putReceiveBatch(batch) - for { - numMsgs, err := c.pconn4.ReadBatch(batch.msgs[:len(buffs)], 0) - if err != nil { - if neterror.PacketWasTruncated(err) { - // TODO(raggi): discuss whether to log? - continue - } - return 0, err +// mkReceiveFunc creates a ReceiveFunc reading from ruc. +// The provided healthItem and metric are updated if non-nil. +func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, metric *clientmetric.Metric) conn.ReceiveFunc { + // epCache caches an IPPort->endpoint for hot flows. + var epCache ippEndpointCache + + return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { + if healthItem != nil { + healthItem.Enter() + defer healthItem.Exit() + } + if ruc == nil { + panic("nil RebindingUDPConn") } - reportToCaller := false - for i, msg := range batch.msgs[:numMsgs] { - if msg.N == 0 { - sizes[i] = 0 - continue + batch := c.getReceiveBatchForBuffs(buffs) + defer c.putReceiveBatch(batch) + for { + numMsgs, err := ruc.ReadBatch(batch.msgs[:len(buffs)], 0) + if err != nil { + if neterror.PacketWasTruncated(err) { + continue + } + return 0, err } - ipp := msg.Addr.(*net.UDPAddr).AddrPort() - if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4); ok { - metricRecvDataIPv4.Add(1) - eps[i] = ep - sizes[i] = msg.N - reportToCaller = true - } else { - sizes[i] = 0 + + reportToCaller := false + for i, msg := range batch.msgs[:numMsgs] { + if msg.N == 0 { + sizes[i] = 0 + continue + } + ipp := msg.Addr.(*net.UDPAddr).AddrPort() + if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &epCache); ok { + if metric != nil { + metric.Add(1) + } + eps[i] = ep + sizes[i] = msg.N + reportToCaller = true + } else { + sizes[i] = 0 + } + } + if reportToCaller { + return numMsgs, nil } - } - if reportToCaller { - return numMsgs, nil } } } @@ -3044,7 +3023,7 @@ func (c *connBind) Open(ignoredPort uint16) ([]conn.ReceiveFunc, uint16, error) return nil, 0, errors.New("magicsock: connBind already open") } c.closed = false - fns := []conn.ReceiveFunc{c.receiveIPv4, c.receiveIPv6, c.receiveDERP} + fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP} if runtime.GOOS == "js" { fns = []conn.ReceiveFunc{c.receiveDERP} } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 92b1f13b7..8c6e64fdc 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -374,8 +374,9 @@ func TestNewConn(t *testing.T) { sizes := make([]int, 1) eps := make([]wgconn.Endpoint, 1) pkts[0] = make([]byte, 64<<10) + receiveIPv4 := conn.receiveIPv4() for { - _, err := conn.receiveIPv4(pkts, sizes, eps) + _, err := receiveIPv4(pkts, sizes, eps) if err != nil { return } @@ -1284,11 +1285,12 @@ func setUpReceiveFrom(tb testing.TB) (roundTrip func()) { buffs[0] = make([]byte, 2<<10) sizes := make([]int, 1) eps := make([]wgconn.Endpoint, 1) + receiveIPv4 := conn.receiveIPv4() return func() { if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil { tb.Fatalf("WriteTo: %v", err) } - n, err := conn.receiveIPv4(buffs, sizes, eps) + n, err := receiveIPv4(buffs, sizes, eps) if err != nil { tb.Fatal(err) } @@ -1513,8 +1515,9 @@ func TestRebindStress(t *testing.T) { sizes := make([]int, 1) eps := make([]wgconn.Endpoint, 1) buffs[0] = make([]byte, 1500) + receiveIPv4 := conn.receiveIPv4() for { - _, err := conn.receiveIPv4(buffs, sizes, eps) + _, err := receiveIPv4(buffs, sizes, eps) if ctx.Err() != nil { errc <- nil return