cmd/stunstamp: refactor connection construction (#13110)

getConns() is now responsible for returning both stable and unstable
conns. conn and measureFn are now passed together via connAndMeasureFn.
newConnAndMeasureFn() is responsible for constructing them.

TCP measurement timeouts are adjusted to more closely match netcheck.

Updates tailscale/corp#22114

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited 2024-08-12 14:09:45 -07:00 committed by GitHub
parent 218110963d
commit 7aec8d4e6b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 199 additions and 149 deletions

View File

@ -23,6 +23,7 @@ import (
"net/url"
"os"
"os/signal"
"runtime"
"slices"
"strconv"
"strings"
@ -190,11 +191,10 @@ func addrInUse(err error, lport *lportForTCPConn) bool {
return false
}
func tcpDial(lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) {
func tcpDial(ctx context.Context, lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) {
for {
var opErr error
dialer := &net.Dialer{
Timeout: time.Second * 2,
LocalAddr: &net.TCPAddr{
Port: int(*lport),
},
@ -208,7 +208,7 @@ func tcpDial(lport *lportForTCPConn, dst netip.AddrPort) (net.Conn, error) {
if opErr != nil {
panic(opErr)
}
tcpConn, err := dialer.Dial("tcp", dst.String())
tcpConn, err := dialer.DialContext(ctx, "tcp", dst.String())
if err != nil {
if addrInUse(err, lport) {
continue
@ -232,11 +232,23 @@ func measureTCPRTT(conn io.ReadWriteCloser, _ string, dst netip.AddrPort) (rtt t
if !ok {
return 0, fmt.Errorf("unexpected conn type: %T", conn)
}
tcpConn, err := tcpDial(lport, dst)
// Set a dial timeout < 1s (TCP_TIMEOUT_INIT on Linux) as a means to avoid
// SYN retries, which can contribute to tcpi->rtt below. This simply limits
// retries from the initiator, but SYN+ACK on the reverse path can also
// time out and be retransmitted.
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*750)
defer cancel()
tcpConn, err := tcpDial(ctx, lport, dst)
if err != nil {
return 0, tempError{err}
}
defer tcpConn.Close()
// This is an unreliable method to measure TCP RTT. The Linux kernel
// describes it as such in tcp_rtt_estimator(). We take some care in how we
// hold tcp_info->rtt here, e.g. clamping dial timeout, but if we are to
// actually use this elsewhere as an input to some decision it warrants a
// deeper study and consideration for alternative methods. Its usefulness
// here is as a point of comparison against the other methods.
rtt, err = tcpinfo.RTT(tcpConn)
if err != nil {
return 0, tempError{err}
@ -250,15 +262,19 @@ func measureHTTPSRTT(conn io.ReadWriteCloser, hostname string, dst netip.AddrPor
return 0, fmt.Errorf("unexpected conn type: %T", conn)
}
var httpResult httpstat.Result
ctx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &httpResult), time.Second*3)
// 5s mirrors net/netcheck.overallProbeTimeout used in net/netcheck.Client.measureHTTPSLatency.
reqCtx, cancel := context.WithTimeout(httpstat.WithHTTPStat(context.Background(), &httpResult), time.Second*5)
defer cancel()
reqURL := "https://" + dst.String() + "/derp/latency-check"
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
req, err := http.NewRequestWithContext(reqCtx, "GET", reqURL, nil)
if err != nil {
return 0, err
}
client := &http.Client{}
tcpConn, err := tcpDial(lport, dst)
// 1.5s mirrors derp/derphttp.dialnodeTimeout used in derp/derphttp.DialNode().
dialCtx, dialCancel := context.WithTimeout(reqCtx, time.Millisecond*1500)
defer dialCancel()
tcpConn, err := tcpDial(dialCtx, lport, dst)
if err != nil {
return 0, tempError{err}
}
@ -355,18 +371,17 @@ type nodeMeta struct {
type measureFn func(conn io.ReadWriteCloser, hostname string, dst netip.AddrPort) (rtt time.Duration, err error)
// probe measures round trip time for the node described by meta over
// conn against dstPort using fn. It may return a nil duration and nil error in
// the event of a timeout. A non-nil error indicates an unrecoverable or
// non-temporary error.
func probe(meta nodeMeta, conn io.ReadWriteCloser, fn measureFn, dstPort int) (*time.Duration, error) {
// probe measures round trip time for the node described by meta over cf against
// dstPort. It may return a nil duration and nil error in the event of a
// timeout. A non-nil error indicates an unrecoverable or non-temporary error.
func probe(meta nodeMeta, cf *connAndMeasureFn, dstPort int) (*time.Duration, error) {
ua := &net.UDPAddr{
IP: net.IP(meta.addr.AsSlice()),
Port: dstPort,
}
time.Sleep(rand.N(200 * time.Millisecond)) // jitter across tx
rtt, err := fn(conn, meta.hostname, netip.AddrPortFrom(meta.addr, uint16(dstPort)))
rtt, err := cf.fn(cf.conn, meta.hostname, netip.AddrPortFrom(meta.addr, uint16(dstPort)))
if err != nil {
if isTemporaryOrTimeoutErr(err) {
log.Printf("temp error measuring RTT to %s(%s): %v", meta.hostname, ua.String(), err)
@ -437,31 +452,69 @@ func nodeMetaFromDERPMap(dm *tailcfg.DERPMap, nodeMetaByAddr map[netip.Addr]node
return stale, nil
}
func newConn(source timestampSource, protocol protocol, stable connStability) (io.ReadWriteCloser, error) {
type connAndMeasureFn struct {
conn io.ReadWriteCloser
fn measureFn
}
// newConnAndMeasureFn returns a connAndMeasureFn or an error. It may return
// nil for both if some combination of the supplied timestampSource, protocol,
// or connStability is unsupported.
func newConnAndMeasureFn(source timestampSource, protocol protocol, stable connStability) (*connAndMeasureFn, error) {
info := getProtocolSupportInfo(protocol)
if !info.stableConn && bool(stable) {
return nil, nil
}
if !info.userspaceTS && source == timestampSourceUserspace {
return nil, nil
}
if !info.kernelTS && source == timestampSourceKernel {
return nil, nil
}
switch protocol {
case protocolSTUN:
if source == timestampSourceKernel {
return getUDPConnKernelTimestamp()
conn, err := getUDPConnKernelTimestamp()
if err != nil {
return nil, err
}
return &connAndMeasureFn{
conn: conn,
fn: measureSTUNRTTKernel,
}, nil
} else {
return net.ListenUDP("udp", &net.UDPAddr{})
conn, err := net.ListenUDP("udp", &net.UDPAddr{})
if err != nil {
return nil, err
}
return &connAndMeasureFn{
conn: conn,
fn: measureSTUNRTT,
}, nil
}
case protocolICMP:
// TODO(jwhited): implement
return nil, errors.New("unimplemented protocol")
return nil, nil
case protocolHTTPS:
localPort := 0
if stable {
localPort = lports.get()
}
ret := lportForTCPConn(localPort)
return &ret, nil
conn := lportForTCPConn(localPort)
return &connAndMeasureFn{
conn: &conn,
fn: measureHTTPSRTT,
}, nil
case protocolTCP:
localPort := 0
if stable {
localPort = lports.get()
}
ret := lportForTCPConn(localPort)
return &ret, nil
conn := lportForTCPConn(localPort)
return &connAndMeasureFn{
conn: &conn,
fn: measureTCPRTT,
}, nil
}
return nil, errors.New("unknown protocol")
}
@ -472,41 +525,57 @@ type stableConnKey struct {
port int
}
func getStableConns(stableConns map[stableConnKey][2]io.ReadWriteCloser, addr netip.Addr, protocol protocol, dstPort int) ([2]io.ReadWriteCloser, error) {
if !protocolSupportsStableConn(protocol) {
return [2]io.ReadWriteCloser{}, nil
}
key := stableConnKey{addr, protocol, dstPort}
conns, ok := stableConns[key]
if ok {
return conns, nil
}
if protocolSupportsKernelTS(protocol) {
kconn, err := newConn(timestampSourceKernel, protocol, stableConn)
if err != nil {
return conns, err
}
conns[timestampSourceKernel] = kconn
}
uconn, err := newConn(timestampSourceUserspace, protocol, stableConn)
if err != nil {
if protocolSupportsKernelTS(protocol) {
conns[timestampSourceKernel].Close()
}
return conns, err
}
conns[timestampSourceUserspace] = uconn
stableConns[key] = conns
return conns, nil
type protocolSupportInfo struct {
kernelTS bool
userspaceTS bool
stableConn bool
}
func protocolSupportsStableConn(p protocol) bool {
if p == protocolICMP {
// no value for ICMP
return false
func getConns(
stableConns map[stableConnKey][2]*connAndMeasureFn,
addr netip.Addr,
protocol protocol,
dstPort int,
) (stable, unstable [2]*connAndMeasureFn, err error) {
key := stableConnKey{addr, protocol, dstPort}
defer func() {
if err != nil {
for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} {
c := stable[source]
if c != nil {
c.conn.Close()
}
c = unstable[source]
if c != nil {
c.conn.Close()
}
}
}
}()
var ok bool
stable, ok = stableConns[key]
if !ok {
for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} {
var cf *connAndMeasureFn
cf, err = newConnAndMeasureFn(source, protocol, stableConn)
if err != nil {
return
}
stable[source] = cf
}
stableConns[key] = stable
}
return true
for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} {
var cf *connAndMeasureFn
cf, err = newConnAndMeasureFn(source, protocol, unstableConn)
if err != nil {
return
}
unstable[source] = cf
}
return stable, unstable, nil
}
// probeNodes measures the round-trip time for the protocols and ports described
@ -514,7 +583,7 @@ func protocolSupportsStableConn(p protocol) bool {
// stableConns are used to recycle connections across calls to probeNodes.
// probeNodes is also responsible for trimming stableConns based on node
// lifetime in nodeMetaByAddr. It returns the results or an error if one occurs.
func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableConnKey][2]io.ReadWriteCloser, portsByProtocol map[protocol][]int) ([]result, error) {
func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableConnKey][2]*connAndMeasureFn, portsByProtocol map[protocol][]int) ([]result, error) {
wg := sync.WaitGroup{}
results := make([]result, 0)
resultsCh := make(chan result)
@ -524,47 +593,19 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo
at := time.Now()
addrsToProbe := make(map[netip.Addr]bool)
doProbe := func(conn io.ReadWriteCloser, meta nodeMeta, source timestampSource, protocol protocol, dstPort int) {
doProbe := func(cf *connAndMeasureFn, meta nodeMeta, source timestampSource, stable connStability, protocol protocol, dstPort int) {
defer wg.Done()
r := result{
key: resultKey{
meta: meta,
timestampSource: source,
connStability: stable,
dstPort: dstPort,
protocol: protocol,
},
at: at,
}
if conn == nil {
var err error
conn, err = newConn(source, protocol, unstableConn)
if err != nil {
select {
case <-doneCh:
return
case errCh <- err:
return
}
}
defer conn.Close()
} else {
r.key.connStability = stableConn
}
var fn measureFn
switch protocol {
case protocolSTUN:
fn = measureSTUNRTT
if source == timestampSourceKernel {
fn = measureSTUNRTTKernel
}
case protocolICMP:
// TODO(jwhited): implement
case protocolHTTPS:
fn = measureHTTPSRTT
case protocolTCP:
fn = measureTCPRTT
}
rtt, err := probe(meta, conn, fn, dstPort)
rtt, err := probe(meta, cf, dstPort)
if err != nil {
select {
case <-doneCh:
@ -584,44 +625,39 @@ func probeNodes(nodeMetaByAddr map[netip.Addr]nodeMeta, stableConns map[stableCo
addrsToProbe[meta.addr] = true
for p, ports := range portsByProtocol {
for _, port := range ports {
stable, err := getStableConns(stableConns, meta.addr, p, port)
stable, unstable, err := getConns(stableConns, meta.addr, p, port)
if err != nil {
close(doneCh)
wg.Wait()
return nil, err
}
if protocolSupportsStableConn(p) {
wg.Add(1)
numProbes++
go doProbe(stable[timestampSourceUserspace], meta, timestampSourceUserspace, p, port)
}
wg.Add(1)
numProbes++
go doProbe(nil, meta, timestampSourceUserspace, p, port)
if protocolSupportsKernelTS(p) {
if protocolSupportsStableConn(p) {
for i, cf := range stable {
if cf != nil {
wg.Add(1)
numProbes++
go doProbe(stable[timestampSourceKernel], meta, timestampSourceKernel, p, port)
go doProbe(cf, meta, timestampSource(i), stableConn, p, port)
}
}
wg.Add(1)
numProbes++
go doProbe(nil, meta, timestampSourceKernel, p, port)
for i, cf := range unstable {
if cf != nil {
wg.Add(1)
numProbes++
go doProbe(cf, meta, timestampSource(i), unstableConn, p, port)
}
}
}
}
}
// cleanup conns we no longer need
for k, conns := range stableConns {
for k, cf := range stableConns {
if !addrsToProbe[k.node] {
if conns[timestampSourceKernel] != nil {
conns[timestampSourceKernel].Close()
if cf[timestampSourceKernel] != nil {
cf[timestampSourceKernel].conn.Close()
}
conns[timestampSourceUserspace].Close()
cf[timestampSourceUserspace].conn.Close()
delete(stableConns, k)
}
}
@ -728,42 +764,16 @@ func staleMarkersFromNodeMeta(stale []nodeMeta, instance string, portsByProtocol
Value: math.Float64frombits(staleNaN),
},
}
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, unstableConn, p, port),
Samples: samples,
})
if protocolSupportsStableConn(p) {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceUserspace, stableConn, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceUserspace, stableConn, p, port),
Samples: samples,
})
}
if protocolSupportsKernelTS(p) {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, unstableConn, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, unstableConn, p, port),
Samples: samples,
})
if protocolSupportsStableConn(p) {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(rttMetricName, s, instance, timestampSourceKernel, stableConn, p, port),
Samples: samples,
})
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(timeoutsMetricName, s, instance, timestampSourceKernel, stableConn, p, port),
Samples: samples,
})
// We send stale markers for all combinations in the interest
// of simplicity.
for _, name := range []string{rttMetricName, timeoutsMetricName} {
for _, source := range []timestampSource{timestampSourceUserspace, timestampSourceKernel} {
for _, stable := range []connStability{unstableConn, stableConn} {
staleMarkers = append(staleMarkers, prompb.TimeSeries{
Labels: timeSeriesLabels(name, s, instance, source, stable, p, port),
Samples: samples,
})
}
}
}
}
@ -909,6 +919,9 @@ func getPortsFromFlag(f string) ([]int, error) {
}
func main() {
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
log.Fatal("unsupported platform")
}
flag.Parse()
portsByProtocol := make(map[protocol][]int)
@ -1035,7 +1048,7 @@ func main() {
// Comparison of stable and unstable 5-tuple results can shed light on
// differences between paths where hashing (multipathing/load balancing)
// comes into play. The inner 2 element array index is timestampSource.
stableConns := make(map[stableConnKey][2]io.ReadWriteCloser)
stableConns := make(map[stableConnKey][2]*connAndMeasureFn)
// timeouts holds counts of timeout events. Values are persisted for the
// lifetime of the related node in the DERP map.

View File

@ -20,8 +20,28 @@ func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.Ad
return 0, errors.New("unimplemented")
}
func protocolSupportsKernelTS(_ protocol) bool {
return false
func getProtocolSupportInfo(p protocol) protocolSupportInfo {
switch p {
case protocolSTUN:
return protocolSupportInfo{
kernelTS: false,
userspaceTS: true,
stableConn: true,
}
case protocolHTTPS:
return protocolSupportInfo{
kernelTS: false,
userspaceTS: true,
stableConn: true,
}
case protocolTCP:
return protocolSupportInfo{
kernelTS: true,
userspaceTS: false,
stableConn: true,
}
}
return protocolSupportInfo{}
}
func setSOReuseAddr(fd uintptr) error {

View File

@ -138,12 +138,29 @@ func measureSTUNRTTKernel(conn io.ReadWriteCloser, hostname string, dst netip.Ad
}
func protocolSupportsKernelTS(p protocol) bool {
if p == protocolSTUN {
return true
func getProtocolSupportInfo(p protocol) protocolSupportInfo {
switch p {
case protocolSTUN:
return protocolSupportInfo{
kernelTS: true,
userspaceTS: true,
stableConn: true,
}
case protocolHTTPS:
return protocolSupportInfo{
kernelTS: false,
userspaceTS: true,
stableConn: true,
}
case protocolTCP:
return protocolSupportInfo{
kernelTS: true,
userspaceTS: false,
stableConn: true,
}
// TODO(jwhited): add ICMP
}
// TODO: jwhited support ICMP
return false
return protocolSupportInfo{}
}
func setSOReuseAddr(fd uintptr) error {