From 7aec8d4e6b4e72d5053d9ff7d819c28f9e035c2c Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 12 Aug 2024 14:09:45 -0700 Subject: [PATCH] cmd/stunstamp: refactor connection construction (#13110) getConns() is now responsible for returning both stable and unstable conns. conn and measureFn are now passed together via connAndMeasureFn. newConnAndMeasureFn() is responsible for constructing them. TCP measurement timeouts are adjusted to more closely match netcheck. Updates tailscale/corp#22114 Signed-off-by: Jordan Whited --- cmd/stunstamp/stunstamp.go | 297 +++++++++++++++-------------- cmd/stunstamp/stunstamp_default.go | 24 ++- cmd/stunstamp/stunstamp_linux.go | 27 ++- 3 files changed, 199 insertions(+), 149 deletions(-) diff --git a/cmd/stunstamp/stunstamp.go b/cmd/stunstamp/stunstamp.go index e01f3ac92..950fdc2cd 100644 --- a/cmd/stunstamp/stunstamp.go +++ b/cmd/stunstamp/stunstamp.go @@ -23,6 +23,7 @@ import ( "net/url" "os" "os/signal" + "runtime" "slices" "strconv" "strings" @@ -190,11 +191,10 @@ func addrInUse(err error, lport *lportForTCPConn) bool { return false } -func tcpDial(lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) { +func tcpDial(ctx context.Context, lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) { for { var opErr error dialer := &net.Dialer{ - Timeout: time.Second * 2, LocalAddr: &net.TCPAddr{ Port: int(*lport), }, @@ -208,7 +208,7 @@ func tcpDial(lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) { if opErr != nil { panic(opErr) } - tcpConn, err := dialer.Dial("tcp", dst.String()) + tcpConn, err := dialer.DialContext(ctx, "tcp", dst.String()) if err != nil { if addrInUse(err, lport) { continue @@ -232,11 +232,23 @@ func measureTCPRTT(conn io.ReadWriteCloser, _ string, dst netip.AddrPort) (rtt t if !ok { return 0, fmt.Errorf("unexpected conn type: %T", conn) } - tcpConn, err := tcpDial(lport, dst) + // Set a dial timeout < 1s (TCP_TIMEOUT_INIT on Linux) as a means to avoid + // SYN retries, which can contribute to tcpi->rtt below. This simply limits + // retries from the initiator, but SYN+ACK on the reverse path can also + // time out and be retransmitted. + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*750) + defer cancel() + tcpConn, err := tcpDial(ctx, lport, dst) if err != nil { return 0, tempError{err} } defer tcpConn.Close() + // This is an unreliable method to measure TCP RTT. The Linux kernel + // describes it as such in tcp_rtt_estimator(). We take some care in how we + // hold tcp_info->rtt here, e.g. clamping dial timeout, but if we are to + // actually use this elsewhere as an input to some decision it warrants a + // deeper study and consideration for alternative methods. Its usefulness + // here is as a point of comparison against the other methods. rtt, err = tcpinfo.RTT(tcpConn) if err != nil { return 0, tempError{err} @@ -250,15 +262,19 @@ func measureHTTPSRTT(conn io.ReadWriteCloser, hostname string, dst netip.AddrPor return 0, fmt.Errorf("unexpected conn type: %T", conn) } var httpResult httpstat.Result - ctx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &httpResult), time.Second*3) + // 5s mirrors net/netcheck.overallProbeTimeout used in net/netcheck.Client.measureHTTPSLatency. + reqCtx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &httpResult), time.Second*5) defer cancel() reqURL := "https://" + dst.String() + "/derp/latency-check" - req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + req, err := http.NewRequestWithContext(reqCtx, "GET", reqURL, nil) if err != nil { return 0, err } client := &http.Client{} - tcpConn, err := tcpDial(lport, dst) + // 1.5s mirrors derp/derphttp.dialnodeTimeout used in derp/derphttp.DialNode(). + dialCtx, dialCancel := context.WithTimeout(reqCtx, time.Millisecond*1500) + defer dialCancel() + tcpConn, err := tcpDial(dialCtx, lport, dst) if err != nil { return 0, tempError{err} } @@ -355,18 +371,17 @@ type nodeMeta struct { type measureFn func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) -// probe measures round trip time for the node described by meta over -// conn against dstPort using fn. It may return a nil duration and nil error in -// the event of a timeout. A non-nil error indicates an unrecoverable or -// non-temporary error. -func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*time.Duration, error) { +// probe measures round trip time for the node described by meta over cf against +// dstPort. It may return a nil duration and nil error in the event of a +// timeout. A non-nil error indicates an unrecoverable or non-temporary error. +func probe(meta nodeMeta, cf *connAndMeasureFn, dstPort int) (*time.Duration, error) { ua := &net.UDPAddr{ IP: net.IP(meta.addr.AsSlice()), Port: dstPort, } time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx - rtt, err := fn(conn, meta.hostname, netip.AddrPortFrom(meta.addr, uint16(dstPort))) + rtt, err := cf.fn(cf.conn, meta.hostname, netip.AddrPortFrom(meta.addr, uint16(dstPort))) if err != nil { if isTemporaryOrTimeoutErr(err) { log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err) @@ -437,31 +452,69 @@ func nodeMetaFromDERPMap(dm *tailcfg.DERPMap, nodeMetaByAddr map[netip.Addr]node return stale, nil } -func newConn(source timestampSource, protocol protocol, stable connStability) (io.ReadWriteCloser, error) { +type connAndMeasureFn struct { + conn io.ReadWriteCloser + fn measureFn +} + +// newConnAndMeasureFn returns a connAndMeasureFn or an error. It may return +// nil for both if some combination of the supplied timestampSource, protocol, +// or connStability is unsupported. +func newConnAndMeasureFn(source timestampSource, protocol protocol, stable connStability) (*connAndMeasureFn, error) { + info := getProtocolSupportInfo(protocol) + if !info.stableConn && bool(stable) { + return nil, nil + } + if !info.userspaceTS && source == timestampSourceUserspace { + return nil, nil + } + if !info.kernelTS && source == timestampSourceKernel { + return nil, nil + } switch protocol { case protocolSTUN: if source == timestampSourceKernel { - return getUDPConnKernelTimestamp() + conn, err := getUDPConnKernelTimestamp() + if err != nil { + return nil, err + } + return &connAndMeasureFn{ + conn: conn, + fn: measureSTUNRTTKernel, + }, nil } else { - return net.ListenUDP("udp", &net.UDPAddr{}) + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + return nil, err + } + return &connAndMeasureFn{ + conn: conn, + fn: measureSTUNRTT, + }, nil } case protocolICMP: // TODO(jwhited): implement - return nil, errors.New("unimplemented protocol") + return nil, nil case protocolHTTPS: localPort := 0 if stable { localPort = lports.get() } - ret := lportForTCPConn(localPort) - return &ret, nil + conn := lportForTCPConn(localPort) + return &connAndMeasureFn{ + conn: &conn, + fn: measureHTTPSRTT, + }, nil case protocolTCP: localPort := 0 if stable { localPort = lports.get() } - ret := lportForTCPConn(localPort) - return &ret, nil + conn := lportForTCPConn(localPort) + return &connAndMeasureFn{ + conn: &conn, + fn: measureTCPRTT, + }, nil } return nil, errors.New("unknown protocol") } @@ -472,41 +525,57 @@ type stableConnKey struct { port int } -func getStableConns(stableConns map[stableConnKey][2]io.ReadWriteCloser, addr netip.Addr, protocol protocol, dstPort int) ([2]io.ReadWriteCloser, error) { - if !protocolSupportsStableConn(protocol) { - return [2]io.ReadWriteCloser{}, nil - } - key := stableConnKey{addr, protocol, dstPort} - conns, ok := stableConns[key] - if ok { - return conns, nil - } - - if protocolSupportsKernelTS(protocol) { - kconn, err := newConn(timestampSourceKernel, protocol, stableConn) - if err != nil { - return conns, err - } - conns[timestampSourceKernel] = kconn - } - uconn, err := newConn(timestampSourceUserspace, protocol, stableConn) - if err != nil { - if protocolSupportsKernelTS(protocol) { - conns[timestampSourceKernel].Close() - } - return conns, err - } - conns[timestampSourceUserspace] = uconn - stableConns[key] = conns - return conns, nil +type protocolSupportInfo struct { + kernelTS bool + userspaceTS bool + stableConn bool } -func protocolSupportsStableConn(p protocol) bool { - if p == protocolICMP { - // no value for ICMP - return false +func getConns( + stableConns map[stableConnKey][2]*connAndMeasureFn, + addr netip.Addr, + protocol protocol, + dstPort int, +) (stable, unstable [2]*connAndMeasureFn, err error) { + key := stableConnKey{addr, protocol, dstPort} + defer func() { + if err != nil { + for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} { + c := stable[source] + if c != nil { + c.conn.Close() + } + c = unstable[source] + if c != nil { + c.conn.Close() + } + } + } + }() + + var ok bool + stable, ok = stableConns[key] + if !ok { + for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} { + var cf *connAndMeasureFn + cf, err = newConnAndMeasureFn(source, protocol, stableConn) + if err != nil { + return + } + stable[source] = cf + } + stableConns[key] = stable } - return true + + for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} { + var cf *connAndMeasureFn + cf, err = newConnAndMeasureFn(source, protocol, unstableConn) + if err != nil { + return + } + unstable[source] = cf + } + return stable, unstable, nil } // probeNodes measures the round-trip time for the protocols and ports described @@ -514,7 +583,7 @@ func protocolSupportsStableConn(p protocol) bool { // stableConns are used to recycle connections across calls to probeNodes. // probeNodes is also responsible for trimming stableConns based on node // lifetime in nodeMetaByAddr. It returns the results or an error if one occurs. -func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableConnKey][2]io.ReadWriteCloser, portsByProtocol map[protocol][]int) ([]result, error) { +func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableConnKey][2]*connAndMeasureFn, portsByProtocol map[protocol][]int) ([]result, error) { wg := sync.WaitGroup{} results := make([]result, 0) resultsCh := make(chan result) @@ -524,47 +593,19 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo at := time.Now() addrsToProbe := make(map[netip.Addr]bool) - doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource, protocol protocol, dstPort int) { + doProbe := func(cf *connAndMeasureFn, meta nodeMeta, source timestampSource, stable connStability, protocol protocol, dstPort int) { defer wg.Done() r := result{ key: resultKey{ meta: meta, timestampSource: source, + connStability: stable, dstPort: dstPort, protocol: protocol, }, at: at, } - if conn == nil { - var err error - conn, err = newConn(source, protocol, unstableConn) - if err != nil { - select { - case <-doneCh: - return - case errCh <- err: - return - } - } - defer conn.Close() - } else { - r.key.connStability = stableConn - } - var fn measureFn - switch protocol { - case protocolSTUN: - fn = measureSTUNRTT - if source == timestampSourceKernel { - fn = measureSTUNRTTKernel - } - case protocolICMP: - // TODO(jwhited): implement - case protocolHTTPS: - fn = measureHTTPSRTT - case protocolTCP: - fn = measureTCPRTT - } - rtt, err := probe(meta, conn, fn, dstPort) + rtt, err := probe(meta, cf, dstPort) if err != nil { select { case <-doneCh: @@ -584,44 +625,39 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo addrsToProbe[meta.addr] = true for p, ports := range portsByProtocol { for _, port := range ports { - stable, err := getStableConns(stableConns, meta.addr, p, port) + stable, unstable, err := getConns(stableConns, meta.addr, p, port) if err != nil { close(doneCh) wg.Wait() return nil, err } - if protocolSupportsStableConn(p) { - wg.Add(1) - numProbes++ - go doProbe(stable[timestampSourceUserspace], meta, timestampSourceUserspace, p, port) - } - wg.Add(1) - numProbes++ - go doProbe(nil, meta, timestampSourceUserspace, p, port) - - if protocolSupportsKernelTS(p) { - if protocolSupportsStableConn(p) { + for i, cf := range stable { + if cf != nil { wg.Add(1) numProbes++ - go doProbe(stable[timestampSourceKernel], meta, timestampSourceKernel, p, port) + go doProbe(cf, meta, timestampSource(i), stableConn, p, port) } + } - wg.Add(1) - numProbes++ - go doProbe(nil, meta, timestampSourceKernel, p, port) + for i, cf := range unstable { + if cf != nil { + wg.Add(1) + numProbes++ + go doProbe(cf, meta, timestampSource(i), unstableConn, p, port) + } } } } } // cleanup conns we no longer need - for k, conns := range stableConns { + for k, cf := range stableConns { if !addrsToProbe[k.node] { - if conns[timestampSourceKernel] != nil { - conns[timestampSourceKernel].Close() + if cf[timestampSourceKernel] != nil { + cf[timestampSourceKernel].conn.Close() } - conns[timestampSourceUserspace].Close() + cf[timestampSourceUserspace].conn.Close() delete(stableConns, k) } } @@ -728,42 +764,16 @@ func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, portsByProtocol Value: math.Float64frombits(staleNaN), }, } - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port), - Samples: samples, - }) - if protocolSupportsStableConn(p) { - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, stableConn, p, port), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, stableConn, p, port), - Samples: samples, - }) - } - if protocolSupportsKernelTS(p) { - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, unstableConn, p, port), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, unstableConn, p, port), - Samples: samples, - }) - if protocolSupportsStableConn(p) { - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, stableConn, p, port), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, stableConn, p, port), - Samples: samples, - }) + // We send stale markers for all combinations in the interest + // of simplicity. + for _, name := range []string{rttMetricName, timeoutsMetricName} { + for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} { + for _, stable := range []connStability{unstableConn, stableConn} { + staleMarkers = append(staleMarkers, prompb.TimeSeries{ + Labels: timeSeriesLabels(name, s, instance, source, stable, p, port), + Samples: samples, + }) + } } } } @@ -909,6 +919,9 @@ func getPortsFromFlag(f string) ([]int, error) { } func main() { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + log.Fatal("unsupported platform") + } flag.Parse() portsByProtocol := make(map[protocol][]int) @@ -1035,7 +1048,7 @@ func main() { // Comparison of stable and unstable 5-tuple results can shed light on // differences between paths where hashing (multipathing/load balancing) // comes into play. The inner 2 element array index is timestampSource. - stableConns := make(map[stableConnKey][2]io.ReadWriteCloser) + stableConns := make(map[stableConnKey][2]*connAndMeasureFn) // timeouts holds counts of timeout events. Values are persisted for the // lifetime of the related node in the DERP map. diff --git a/cmd/stunstamp/stunstamp_default.go b/cmd/stunstamp/stunstamp_default.go index 017af1251..36afdbb8f 100644 --- a/cmd/stunstamp/stunstamp_default.go +++ b/cmd/stunstamp/stunstamp_default.go @@ -20,8 +20,28 @@ func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.Ad return 0, errors.New("unimplemented") } -func protocolSupportsKernelTS(_ protocol) bool { - return false +func getProtocolSupportInfo(p protocol) protocolSupportInfo { + switch p { + case protocolSTUN: + return protocolSupportInfo{ + kernelTS: false, + userspaceTS: true, + stableConn: true, + } + case protocolHTTPS: + return protocolSupportInfo{ + kernelTS: false, + userspaceTS: true, + stableConn: true, + } + case protocolTCP: + return protocolSupportInfo{ + kernelTS: true, + userspaceTS: false, + stableConn: true, + } + } + return protocolSupportInfo{} } func setSOReuseAddr(fd uintptr) error { diff --git a/cmd/stunstamp/stunstamp_linux.go b/cmd/stunstamp/stunstamp_linux.go index 148e4b0ef..e73d1ee3c 100644 --- a/cmd/stunstamp/stunstamp_linux.go +++ b/cmd/stunstamp/stunstamp_linux.go @@ -138,12 +138,29 @@ func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.Ad } -func protocolSupportsKernelTS(p protocol) bool { - if p == protocolSTUN { - return true +func getProtocolSupportInfo(p protocol) protocolSupportInfo { + switch p { + case protocolSTUN: + return protocolSupportInfo{ + kernelTS: true, + userspaceTS: true, + stableConn: true, + } + case protocolHTTPS: + return protocolSupportInfo{ + kernelTS: false, + userspaceTS: true, + stableConn: true, + } + case protocolTCP: + return protocolSupportInfo{ + kernelTS: true, + userspaceTS: false, + stableConn: true, + } + // TODO(jwhited): add ICMP } - // TODO: jwhited support ICMP - return false + return protocolSupportInfo{} } func setSOReuseAddr(fd uintptr) error {