diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index c3b2e54b8..e01cb1f9a 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -46,6 +46,7 @@ import ( "tailscale.com/net/sockstats" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/util/multierr" @@ -497,11 +498,9 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, tr.DisableCompression = true // (mis)use httptrace to extract the underlying net.Conn from the - // transport. We make exactly 1 request using this transport, so - // there will be exactly 1 GotConn call. Additionally, the - // transport handles 101 Switching Protocols correctly, such that - // the Conn will not be reused or kept alive by the transport once - // the response has been handed back from RoundTrip. + // transport. The transport handles 101 Switching Protocols correctly, + // such that the Conn will not be reused or kept alive by the transport + // once the response has been handed back from RoundTrip. // // In theory, the machinery of net/http should make it such that // the trace callback happens-before we get the response, but @@ -517,10 +516,16 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, // unexpected EOFs...), and we're bound to forget someday and // introduce a protocol optimization at a higher level that starts // eagerly transmitting from the server. - connCh := make(chan net.Conn, 1) + var lastConn syncs.AtomicValue[net.Conn] trace := httptrace.ClientTrace{ + // Even though we only make a single HTTP request which should + // require a single connection, the context (with the attached + // trace configuration) might be used by our custom dialer to + // make other HTTP requests (e.g. BootstrapDNS). We only care + // about the last connection made, which should be the one to + // the control server. GotConn: func(info httptrace.GotConnInfo) { - connCh <- info.Conn + lastConn.Store(info.Conn) }, } ctx = httptrace.WithClientTrace(ctx, &trace) @@ -548,11 +553,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, addr netip.Addr, // is still a read buffer attached to it within resp.Body. So, we // must direct I/O through resp.Body, but we can still use the // underlying net.Conn for stuff like deadlines. - var switchedConn net.Conn - select { - case switchedConn = <-connCh: - default: - } + switchedConn := lastConn.Load() if switchedConn == nil { resp.Body.Close() return nil, fmt.Errorf("httptrace didn't provide a connection") diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index ba2767289..8c8ed7f57 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -11,10 +11,12 @@ import ( "log" "net" "net/http" + "net/http/httptest" "net/http/httputil" "net/netip" "net/url" "runtime" + "slices" "strconv" "sync" "testing" @@ -41,6 +43,8 @@ type httpTestParam struct { makeHTTPHangAfterUpgrade bool doEarlyWrite bool + + httpInDial bool } func TestControlHTTP(t *testing.T) { @@ -120,6 +124,12 @@ func TestControlHTTP(t *testing.T) { name: "early_write", doEarlyWrite: true, }, + // Dialer needed to make another HTTP request along the way (e.g. to + // resolve the hostname via BootstrapDNS). + { + name: "http_request_in_dial", + httpInDial: true, + }, } for _, test := range tests { @@ -217,6 +227,29 @@ func testControlHTTP(t *testing.T, param httpTestParam) { Clock: clock, } + if param.httpInDial { + // Spin up a separate server to get a different port on localhost. + secondServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return })) + defer secondServer.Close() + + prev := a.Dialer + a.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, "GET", secondServer.URL, nil) + if err != nil { + t.Errorf("http.NewRequest: %v", err) + } + r, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("http.Get: %v", err) + } + r.Body.Close() + + return prev(ctx, network, addr) + } + } + if proxy != nil { proxyEnv := proxy.Start(t) defer proxy.Close() @@ -238,6 +271,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) { t.Fatalf("dialing controlhttp: %v", err) } defer conn.Close() + si := <-sch if si.conn != nil { defer si.conn.Close() @@ -266,6 +300,19 @@ func testControlHTTP(t *testing.T, param httpTestParam) { t.Errorf("early write = %q; want %q", buf, earlyWriteMsg) } } + + // When no proxy is used, the RemoteAddr of the returned connection should match + // one of the listeners of the test server. + if proxy == nil { + var expectedAddrs []string + for _, ln := range []net.Listener{httpLn, httpsLn} { + expectedAddrs = append(expectedAddrs, fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port)) + expectedAddrs = append(expectedAddrs, fmt.Sprintf("[::1]:%d", ln.Addr().(*net.TCPAddr).Port)) + } + if !slices.Contains(expectedAddrs, conn.RemoteAddr().String()) { + t.Errorf("unexpected remote addr: %s, want %s", conn.RemoteAddr(), expectedAddrs) + } + } } type serverResult struct {