diff --git a/cmd/stunstamp/stunstamp.go b/cmd/stunstamp/stunstamp.go index e01f3ac92..950fdc2cd 100644 --- a/cmd/stunstamp/stunstamp.go +++ b/cmd/stunstamp/stunstamp.go @@ -23,6 +23,7 @@ import ( "net/url" "os" "os/signal" + "runtime" "slices" "strconv" "strings" @@ -190,11 +191,10 @@ func addrInUse(err error, lport *lportForTCPConn) bool { return false } -func tcpDial(lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) { +func tcpDial(ctx context.Context, lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) { for { var opErr error dialer := &net.Dialer{ - Timeout: time.Second * 2, LocalAddr: &net.TCPAddr{ Port: int(*lport), }, @@ -208,7 +208,7 @@ func tcpDial(lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) { if opErr != nil { panic(opErr) } - tcpConn, err := dialer.Dial("tcp", dst.String()) + tcpConn, err := dialer.DialContext(ctx, "tcp", dst.String()) if err != nil { if addrInUse(err, lport) { continue @@ -232,11 +232,23 @@ func measureTCPRTT(conn io.ReadWriteCloser, _ string, dst netip.AddrPort) (rtt t if !ok { return 0, fmt.Errorf("unexpected conn type: %T", conn) } - tcpConn, err := tcpDial(lport, dst) + // Set a dial timeout < 1s (TCP_TIMEOUT_INIT on Linux) as a means to avoid + // SYN retries, which can contribute to tcpi->rtt below. This simply limits + // retries from the initiator, but SYN+ACK on the reverse path can also + // time out and be retransmitted. + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*750) + defer cancel() + tcpConn, err := tcpDial(ctx, lport, dst) if err != nil { return 0, tempError{err} } defer tcpConn.Close() + // This is an unreliable method to measure TCP RTT. The Linux kernel + // describes it as such in tcp_rtt_estimator(). We take some care in how we + // hold tcp_info->rtt here, e.g. clamping dial timeout, but if we are to + // actually use this elsewhere as an input to some decision it warrants a + // deeper study and consideration for alternative methods. Its usefulness + // here is as a point of comparison against the other methods. rtt, err = tcpinfo.RTT(tcpConn) if err != nil { return 0, tempError{err} @@ -250,15 +262,19 @@ func measureHTTPSRTT(conn io.ReadWriteCloser, hostname string, dst netip.AddrPor return 0, fmt.Errorf("unexpected conn type: %T", conn) } var httpResult httpstat.Result - ctx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &httpResult), time.Second*3) + // 5s mirrors net/netcheck.overallProbeTimeout used in net/netcheck.Client.measureHTTPSLatency. + reqCtx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &httpResult), time.Second*5) defer cancel() reqURL := "https://" + dst.String() + "/derp/latency-check" - req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + req, err := http.NewRequestWithContext(reqCtx, "GET", reqURL, nil) if err != nil { return 0, err } client := &http.Client{} - tcpConn, err := tcpDial(lport, dst) + // 1.5s mirrors derp/derphttp.dialnodeTimeout used in derp/derphttp.DialNode(). + dialCtx, dialCancel := context.WithTimeout(reqCtx, time.Millisecond*1500) + defer dialCancel() + tcpConn, err := tcpDial(dialCtx, lport, dst) if err != nil { return 0, tempError{err} } @@ -355,18 +371,17 @@ type nodeMeta struct { type measureFn func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error) -// probe measures round trip time for the node described by meta over -// conn against dstPort using fn. It may return a nil duration and nil error in -// the event of a timeout. A non-nil error indicates an unrecoverable or -// non-temporary error. -func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*time.Duration, error) { +// probe measures round trip time for the node described by meta over cf against +// dstPort. It may return a nil duration and nil error in the event of a +// timeout. A non-nil error indicates an unrecoverable or non-temporary error. +func probe(meta nodeMeta, cf *connAndMeasureFn, dstPort int) (*time.Duration, error) { ua := &net.UDPAddr{ IP: net.IP(meta.addr.AsSlice()), Port: dstPort, } time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx - rtt, err := fn(conn, meta.hostname, netip.AddrPortFrom(meta.addr, uint16(dstPort))) + rtt, err := cf.fn(cf.conn, meta.hostname, netip.AddrPortFrom(meta.addr, uint16(dstPort))) if err != nil { if isTemporaryOrTimeoutErr(err) { log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err) @@ -437,31 +452,69 @@ func nodeMetaFromDERPMap(dm *tailcfg.DERPMap, nodeMetaByAddr map[netip.Addr]node return stale, nil } -func newConn(source timestampSource, protocol protocol, stable connStability) (io.ReadWriteCloser, error) { +type connAndMeasureFn struct { + conn io.ReadWriteCloser + fn measureFn +} + +// newConnAndMeasureFn returns a connAndMeasureFn or an error. It may return +// nil for both if some combination of the supplied timestampSource, protocol, +// or connStability is unsupported. +func newConnAndMeasureFn(source timestampSource, protocol protocol, stable connStability) (*connAndMeasureFn, error) { + info := getProtocolSupportInfo(protocol) + if !info.stableConn && bool(stable) { + return nil, nil + } + if !info.userspaceTS && source == timestampSourceUserspace { + return nil, nil + } + if !info.kernelTS && source == timestampSourceKernel { + return nil, nil + } switch protocol { case protocolSTUN: if source == timestampSourceKernel { - return getUDPConnKernelTimestamp() + conn, err := getUDPConnKernelTimestamp() + if err != nil { + return nil, err + } + return &connAndMeasureFn{ + conn: conn, + fn: measureSTUNRTTKernel, + }, nil } else { - return net.ListenUDP("udp", &net.UDPAddr{}) + conn, err := net.ListenUDP("udp", &net.UDPAddr{}) + if err != nil { + return nil, err + } + return &connAndMeasureFn{ + conn: conn, + fn: measureSTUNRTT, + }, nil } case protocolICMP: // TODO(jwhited): implement - return nil, errors.New("unimplemented protocol") + return nil, nil case protocolHTTPS: localPort := 0 if stable { localPort = lports.get() } - ret := lportForTCPConn(localPort) - return &ret, nil + conn := lportForTCPConn(localPort) + return &connAndMeasureFn{ + conn: &conn, + fn: measureHTTPSRTT, + }, nil case protocolTCP: localPort := 0 if stable { localPort = lports.get() } - ret := lportForTCPConn(localPort) - return &ret, nil + conn := lportForTCPConn(localPort) + return &connAndMeasureFn{ + conn: &conn, + fn: measureTCPRTT, + }, nil } return nil, errors.New("unknown protocol") } @@ -472,41 +525,57 @@ type stableConnKey struct { port int } -func getStableConns(stableConns map[stableConnKey][2]io.ReadWriteCloser, addr netip.Addr, protocol protocol, dstPort int) ([2]io.ReadWriteCloser, error) { - if !protocolSupportsStableConn(protocol) { - return [2]io.ReadWriteCloser{}, nil - } - key := stableConnKey{addr, protocol, dstPort} - conns, ok := stableConns[key] - if ok { - return conns, nil - } - - if protocolSupportsKernelTS(protocol) { - kconn, err := newConn(timestampSourceKernel, protocol, stableConn) - if err != nil { - return conns, err - } - conns[timestampSourceKernel] = kconn - } - uconn, err := newConn(timestampSourceUserspace, protocol, stableConn) - if err != nil { - if protocolSupportsKernelTS(protocol) { - conns[timestampSourceKernel].Close() - } - return conns, err - } - conns[timestampSourceUserspace] = uconn - stableConns[key] = conns - return conns, nil +type protocolSupportInfo struct { + kernelTS bool + userspaceTS bool + stableConn bool } -func protocolSupportsStableConn(p protocol) bool { - if p == protocolICMP { - // no value for ICMP - return false +func getConns( + stableConns map[stableConnKey][2]*connAndMeasureFn, + addr netip.Addr, + protocol protocol, + dstPort int, +) (stable, unstable [2]*connAndMeasureFn, err error) { + key := stableConnKey{addr, protocol, dstPort} + defer func() { + if err != nil { + for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} { + c := stable[source] + if c != nil { + c.conn.Close() + } + c = unstable[source] + if c != nil { + c.conn.Close() + } + } + } + }() + + var ok bool + stable, ok = stableConns[key] + if !ok { + for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} { + var cf *connAndMeasureFn + cf, err = newConnAndMeasureFn(source, protocol, stableConn) + if err != nil { + return + } + stable[source] = cf + } + stableConns[key] = stable } - return true + + for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} { + var cf *connAndMeasureFn + cf, err = newConnAndMeasureFn(source, protocol, unstableConn) + if err != nil { + return + } + unstable[source] = cf + } + return stable, unstable, nil } // probeNodes measures the round-trip time for the protocols and ports described @@ -514,7 +583,7 @@ func protocolSupportsStableConn(p protocol) bool { // stableConns are used to recycle connections across calls to probeNodes. // probeNodes is also responsible for trimming stableConns based on node // lifetime in nodeMetaByAddr. It returns the results or an error if one occurs. -func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableConnKey][2]io.ReadWriteCloser, portsByProtocol map[protocol][]int) ([]result, error) { +func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableConnKey][2]*connAndMeasureFn, portsByProtocol map[protocol][]int) ([]result, error) { wg := sync.WaitGroup{} results := make([]result, 0) resultsCh := make(chan result) @@ -524,47 +593,19 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo at := time.Now() addrsToProbe := make(map[netip.Addr]bool) - doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource, protocol protocol, dstPort int) { + doProbe := func(cf *connAndMeasureFn, meta nodeMeta, source timestampSource, stable connStability, protocol protocol, dstPort int) { defer wg.Done() r := result{ key: resultKey{ meta: meta, timestampSource: source, + connStability: stable, dstPort: dstPort, protocol: protocol, }, at: at, } - if conn == nil { - var err error - conn, err = newConn(source, protocol, unstableConn) - if err != nil { - select { - case <-doneCh: - return - case errCh <- err: - return - } - } - defer conn.Close() - } else { - r.key.connStability = stableConn - } - var fn measureFn - switch protocol { - case protocolSTUN: - fn = measureSTUNRTT - if source == timestampSourceKernel { - fn = measureSTUNRTTKernel - } - case protocolICMP: - // TODO(jwhited): implement - case protocolHTTPS: - fn = measureHTTPSRTT - case protocolTCP: - fn = measureTCPRTT - } - rtt, err := probe(meta, conn, fn, dstPort) + rtt, err := probe(meta, cf, dstPort) if err != nil { select { case <-doneCh: @@ -584,44 +625,39 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo addrsToProbe[meta.addr] = true for p, ports := range portsByProtocol { for _, port := range ports { - stable, err := getStableConns(stableConns, meta.addr, p, port) + stable, unstable, err := getConns(stableConns, meta.addr, p, port) if err != nil { close(doneCh) wg.Wait() return nil, err } - if protocolSupportsStableConn(p) { - wg.Add(1) - numProbes++ - go doProbe(stable[timestampSourceUserspace], meta, timestampSourceUserspace, p, port) - } - wg.Add(1) - numProbes++ - go doProbe(nil, meta, timestampSourceUserspace, p, port) - - if protocolSupportsKernelTS(p) { - if protocolSupportsStableConn(p) { + for i, cf := range stable { + if cf != nil { wg.Add(1) numProbes++ - go doProbe(stable[timestampSourceKernel], meta, timestampSourceKernel, p, port) + go doProbe(cf, meta, timestampSource(i), stableConn, p, port) } + } - wg.Add(1) - numProbes++ - go doProbe(nil, meta, timestampSourceKernel, p, port) + for i, cf := range unstable { + if cf != nil { + wg.Add(1) + numProbes++ + go doProbe(cf, meta, timestampSource(i), unstableConn, p, port) + } } } } } // cleanup conns we no longer need - for k, conns := range stableConns { + for k, cf := range stableConns { if !addrsToProbe[k.node] { - if conns[timestampSourceKernel] != nil { - conns[timestampSourceKernel].Close() + if cf[timestampSourceKernel] != nil { + cf[timestampSourceKernel].conn.Close() } - conns[timestampSourceUserspace].Close() + cf[timestampSourceUserspace].conn.Close() delete(stableConns, k) } } @@ -728,42 +764,16 @@ func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, portsByProtocol Value: math.Float64frombits(staleNaN), }, } - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port), - Samples: samples, - }) - if protocolSupportsStableConn(p) { - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, stableConn, p, port), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, stableConn, p, port), - Samples: samples, - }) - } - if protocolSupportsKernelTS(p) { - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, unstableConn, p, port), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, unstableConn, p, port), - Samples: samples, - }) - if protocolSupportsStableConn(p) { - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, stableConn, p, port), - Samples: samples, - }) - staleMarkers = append(staleMarkers, prompb.TimeSeries{ - Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, stableConn, p, port), - Samples: samples, - }) + // We send stale markers for all combinations in the interest + // of simplicity. + for _, name := range []string{rttMetricName, timeoutsMetricName} { + for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} { + for _, stable := range []connStability{unstableConn, stableConn} { + staleMarkers = append(staleMarkers, prompb.TimeSeries{ + Labels: timeSeriesLabels(name, s, instance, source, stable, p, port), + Samples: samples, + }) + } } } } @@ -909,6 +919,9 @@ func getPortsFromFlag(f string) ([]int, error) { } func main() { + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { + log.Fatal("unsupported platform") + } flag.Parse() portsByProtocol := make(map[protocol][]int) @@ -1035,7 +1048,7 @@ func main() { // Comparison of stable and unstable 5-tuple results can shed light on // differences between paths where hashing (multipathing/load balancing) // comes into play. The inner 2 element array index is timestampSource. - stableConns := make(map[stableConnKey][2]io.ReadWriteCloser) + stableConns := make(map[stableConnKey][2]*connAndMeasureFn) // timeouts holds counts of timeout events. Values are persisted for the // lifetime of the related node in the DERP map. diff --git a/cmd/stunstamp/stunstamp_default.go b/cmd/stunstamp/stunstamp_default.go index 017af1251..36afdbb8f 100644 --- a/cmd/stunstamp/stunstamp_default.go +++ b/cmd/stunstamp/stunstamp_default.go @@ -20,8 +20,28 @@ func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.Ad return 0, errors.New("unimplemented") } -func protocolSupportsKernelTS(_ protocol) bool { - return false +func getProtocolSupportInfo(p protocol) protocolSupportInfo { + switch p { + case protocolSTUN: + return protocolSupportInfo{ + kernelTS: false, + userspaceTS: true, + stableConn: true, + } + case protocolHTTPS: + return protocolSupportInfo{ + kernelTS: false, + userspaceTS: true, + stableConn: true, + } + case protocolTCP: + return protocolSupportInfo{ + kernelTS: true, + userspaceTS: false, + stableConn: true, + } + } + return protocolSupportInfo{} } func setSOReuseAddr(fd uintptr) error { diff --git a/cmd/stunstamp/stunstamp_linux.go b/cmd/stunstamp/stunstamp_linux.go index 148e4b0ef..e73d1ee3c 100644 --- a/cmd/stunstamp/stunstamp_linux.go +++ b/cmd/stunstamp/stunstamp_linux.go @@ -138,12 +138,29 @@ func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.Ad } -func protocolSupportsKernelTS(p protocol) bool { - if p == protocolSTUN { - return true +func getProtocolSupportInfo(p protocol) protocolSupportInfo { + switch p { + case protocolSTUN: + return protocolSupportInfo{ + kernelTS: true, + userspaceTS: true, + stableConn: true, + } + case protocolHTTPS: + return protocolSupportInfo{ + kernelTS: false, + userspaceTS: true, + stableConn: true, + } + case protocolTCP: + return protocolSupportInfo{ + kernelTS: true, + userspaceTS: false, + stableConn: true, + } + // TODO(jwhited): add ICMP } - // TODO: jwhited support ICMP - return false + return protocolSupportInfo{} } func setSOReuseAddr(fd uintptr) error {