ipn/ipnserver: remove protoSwitchConn shenanigans; just use http.Server early

Now that everything's just HTTP, there's no longer a need to have a
header-sniffing net.Conn wraper that dispatches which route to
take. Refactor to just use an http.Server earlier instead.

Updates #6417

Change-Id: I12a2054db4e56f48660c46f81233db224fdc77cb
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-11-25 20:54:37 -08:00 committed by Brad Fitzpatrick
parent e567902aa9
commit b0545873e5
4 changed files with 142 additions and 164 deletions

View File

@ -57,7 +57,6 @@ func (ci *ConnIdentity) UserID() string { return ci.userID }
func (ci *ConnIdentity) User() *user.User { return ci.user }
func (ci *ConnIdentity) Pid() int { return ci.pid }
func (ci *ConnIdentity) IsUnixSock() bool { return ci.isUnixSock }
func (ci *ConnIdentity) NotWindows() bool { return ci.notWindows }
func (ci *ConnIdentity) Creds() *peercred.Creds { return ci.creds }
// GetConnIdentity returns the localhost TCP connection's identity information

View File

@ -7,15 +7,11 @@
package ipnserver
import (
"bufio"
"context"
"io"
"net"
"net/http"
"time"
"tailscale.com/logpolicy"
"tailscale.com/types/logger"
)
// handleProxyConnectConn handles a CONNECT request to
@ -27,40 +23,42 @@ import (
// "Internet Kill Switch" installed by tailscaled for exit nodes
// precludes that from working and instead the GUI fails to dial out.
// So, go through tailscaled (with a CONNECT request) instead.
func (s *Server) handleProxyConnectConn(ctx context.Context, br *bufio.Reader, c net.Conn, logf logger.Logf) {
defer c.Close()
c.SetReadDeadline(time.Now().Add(5 * time.Second)) // should be long enough to send the HTTP headers
req, err := http.ReadRequest(br)
if err != nil {
logf("ReadRequest: %v", err)
return
}
c.SetReadDeadline(time.Time{})
if req.Method != "CONNECT" {
logf("ReadRequest: unexpected method %q, not CONNECT", req.Method)
return
func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if r.Method != "CONNECT" {
panic("[unexpected] miswired")
}
hostPort := req.RequestURI
hostPort := r.RequestURI
logHost := logpolicy.LogHost()
allowed := net.JoinHostPort(logHost, "443")
if hostPort != allowed {
logf("invalid CONNECT target %q; want %q", hostPort, allowed)
io.WriteString(c, "HTTP/1.1 403 Forbidden\r\n\r\nBad CONNECT target.\n")
s.logf("invalid CONNECT target %q; want %q", hostPort, allowed)
http.Error(w, "Bad CONNECT target.", http.StatusForbidden)
return
}
tr := logpolicy.NewLogtailTransport(logHost)
back, err := tr.DialContext(ctx, "tcp", hostPort)
if err != nil {
logf("error CONNECT dialing %v: %v", hostPort, err)
io.WriteString(c, "HTTP/1.1 502 Fail\r\n\r\nConnect failure.\n")
s.logf("error CONNECT dialing %v: %v", hostPort, err)
http.Error(w, "Connect failure", http.StatusBadGateway)
return
}
defer back.Close()
hj, ok := w.(http.Hijacker)
if !ok {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
c, br, err := hj.Hijack()
if err != nil {
s.logf("CONNECT hijack: %v", err)
return
}
defer c.Close()
io.WriteString(c, "HTTP/1.1 200 OK\r\n\r\n")
errc := make(chan error, 2)

View File

@ -4,14 +4,8 @@
package ipnserver
import (
"bufio"
"context"
"net"
import "net/http"
"tailscale.com/types/logger"
)
func (s *Server) handleProxyConnectConn(ctx context.Context, br *bufio.Reader, c net.Conn, logf logger.Logf) {
func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request) {
panic("unreachable")
}

View File

@ -5,8 +5,8 @@
package ipnserver
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
@ -20,7 +20,6 @@ import (
"time"
"unicode"
"go4.org/mem"
"tailscale.com/control/controlclient"
"tailscale.com/envknob"
"tailscale.com/ipn"
@ -29,10 +28,10 @@ import (
"tailscale.com/ipn/localapi"
"tailscale.com/logtail/backoff"
"tailscale.com/net/dnsfallback"
"tailscale.com/net/netutil"
"tailscale.com/net/tsdial"
"tailscale.com/smallzstd"
"tailscale.com/types/logger"
"tailscale.com/util/mak"
"tailscale.com/util/systemd"
"tailscale.com/version/distro"
"tailscale.com/wgengine"
@ -82,9 +81,9 @@ type Server struct {
logf logger.Logf
backendLogID string
// resetOnZero is whether to call bs.Reset on transition from
// 1->0 connections. That is, this is whether the backend is
// 1->0 active HTTP requests. That is, this is whether the backend is
// being run in "client mode" that requires an active GUI
// connection (such as on Windows by default). Even if this
// connection (such as on Windows by default). Even if this
// is true, the ForceDaemon pref can override this.
resetOnZero bool
@ -92,56 +91,64 @@ type Server struct {
// lock order: mu, then LocalBackend.mu
mu sync.Mutex
lastUserID string // tracks last userid; on change, Reset state for paranoia
allClients map[net.Conn]*ipnauth.ConnIdentity
activeReqs map[*http.Request]*ipnauth.ConnIdentity
}
// LocalBackend returns the server's LocalBackend.
func (s *Server) LocalBackend() *ipnlocal.LocalBackend { return s.b }
// bufferIsConnect reports whether br looks like it's likely an HTTP
// CONNECT request.
//
// Invariant: br has already had at least 4 bytes Peek'ed.
func bufferIsConnect(br *bufio.Reader) bool {
peek, _ := br.Peek(br.Buffered())
return mem.HasPrefix(mem.B(peek), mem.S("CONN"))
}
func (s *Server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) {
// First sniff a few bytes to check its HTTP method.
br := bufio.NewReader(c)
c.SetReadDeadline(time.Now().Add(30 * time.Second))
br.Peek(len("GET / HTTP/1.1\r\n")) // reasonable sniff size to get HTTP method
c.SetReadDeadline(time.Time{})
// Handle logtail CONNECT requests early. (See docs on handleProxyConnectConn)
if bufferIsConnect(br) {
s.handleProxyConnectConn(ctx, br, c, logf)
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "CONNECT" {
if envknob.GOOS() == "windows" {
// For the GUI client when using an exit node. See docs on handleProxyConnectConn.
s.handleProxyConnectConn(w, r)
} else {
http.Error(w, "bad method for platform", http.StatusMethodNotAllowed)
}
return
}
ci, err := s.addConn(c)
var ci *ipnauth.ConnIdentity
switch v := r.Context().Value(connIdentityContextKey{}).(type) {
case *ipnauth.ConnIdentity:
ci = v
case error:
http.Error(w, v.Error(), http.StatusUnauthorized)
return
case nil:
http.Error(w, "internal error: no connIdentityContextKey", http.StatusInternalServerError)
return
}
onDone, err := s.addActiveHTTPRequest(r, ci)
if err != nil {
fmt.Fprintf(c, "HTTP/1.0 500 Nope\r\nContent-Type: text/plain\r\nX-Content-Type-Options: nosniff\r\n\r\n%s\n", err.Error())
c.Close()
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
defer onDone()
if strings.HasPrefix(r.URL.Path, "/localapi/") {
lah := localapi.NewHandler(s.b, s.logf, s.backendLogID)
lah.PermitRead, lah.PermitWrite = s.localAPIPermissions(ci)
lah.PermitCert = s.connCanFetchCerts(ci)
lah.ServeHTTP(w, r)
return
}
// Tell the LocalBackend about the identity we're now running as.
s.b.SetCurrentUserID(ci.UserID())
httpServer := &http.Server{
// Localhost connections are cheap; so only do
// keep-alives for a short period of time, as these
// active connections lock the server into only serving
// that user. If the user has this page open, we don't
// want another switching user to be locked out for
// minutes. 5 seconds is enough to let browser hit
// favicon.ico and such.
IdleTimeout: 5 * time.Second,
ErrorLog: logger.StdLogger(logf),
Handler: s.localhostHandler(ci),
if r.URL.Path != "/" {
http.NotFound(w, r)
return
}
httpServer.Serve(netutil.NewOneConnListener(&protoSwitchConn{s: s, br: br, Conn: c}, nil))
if envknob.GOOS() == "windows" {
// TODO(bradfitz): remove this once we moved to named pipes for LocalAPI
// on Windows. This could then move to all platforms instead at
// 100.100.100.100 or something (quad100 handler in LocalAPI)
s.ServeHTMLStatus(w, r)
return
}
io.WriteString(w, "<html><title>Tailscale</title><body><h1>Tailscale</h1>This is the local Tailscale daemon.\n")
}
// inUseOtherUserError is the error type for when the server is in use
@ -159,9 +166,9 @@ func (e inUseOtherUserError) Unwrap() error { return e.error }
func (s *Server) checkConnIdentityLocked(ci *ipnauth.ConnIdentity) error {
// If clients are already connected, verify they're the same user.
// This mostly matters on Windows at the moment.
if len(s.allClients) > 0 {
if len(s.activeReqs) > 0 {
var active *ipnauth.ConnIdentity
for _, active = range s.allClients {
for _, active = range s.activeReqs {
break
}
if active != nil && ci.UserID() != active.UserID() {
@ -239,14 +246,14 @@ func (s *Server) connCanFetchCerts(ci *ipnauth.ConnIdentity) bool {
return false
}
// addConn adds c to the server's list of clients.
// addActiveHTTPRequest adds c to the server's list of active HTTP requests.
//
// If the returned error is of type inUseOtherUserError then the
// returned connIdentity is also valid.
func (s *Server) addConn(c net.Conn) (ci *ipnauth.ConnIdentity, err error) {
ci, err = ipnauth.GetConnIdentity(s.logf, c)
if err != nil {
return
// If the returned error may be of type inUseOtherUserError.
//
// onDone must be called when the HTTP request is done.
func (s *Server) addActiveHTTPRequest(req *http.Request, ci *ipnauth.ConnIdentity) (onDone func(), err error) {
if ci == nil {
return nil, errors.New("internal error: nil connIdentity")
}
// If the connected user changes, reset the backend server state to make
@ -262,40 +269,41 @@ func (s *Server) addConn(c net.Conn) (ci *ipnauth.ConnIdentity, err error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.allClients == nil {
s.allClients = map[net.Conn]*ipnauth.ConnIdentity{}
}
if err := s.checkConnIdentityLocked(ci); err != nil {
return ci, err
return nil, err
}
s.allClients[c] = ci
mak.Set(&s.activeReqs, req, ci)
if s.lastUserID != ci.UserID() {
if s.lastUserID != "" {
doReset = true
}
s.lastUserID = ci.UserID()
}
return ci, nil
}
func (s *Server) removeAndCloseConn(c net.Conn) {
s.mu.Lock()
delete(s.allClients, c)
remain := len(s.allClients)
s.mu.Unlock()
if remain == 0 && s.resetOnZero {
if s.b.InServerMode() {
s.logf("client disconnected; staying alive in server mode")
} else {
s.logf("client disconnected; stopping server")
s.b.ResetForClientDisconnect()
if envknob.GOOS() == "windows" && len(s.activeReqs) == 1 {
uid := ci.UserID()
// Tell the LocalBackend about the identity we're now running as.
s.b.SetCurrentUserID(uid)
if s.lastUserID != uid {
if s.lastUserID != "" {
doReset = true
}
s.lastUserID = uid
}
}
c.Close()
onDone = func() {
s.mu.Lock()
delete(s.activeReqs, req)
remain := len(s.activeReqs)
s.mu.Unlock()
if remain == 0 && s.resetOnZero {
if s.b.InServerMode() {
s.logf("client disconnected; staying alive in server mode")
} else {
s.logf("client disconnected; stopping server")
s.b.ResetForClientDisconnect()
}
}
}
return onDone, nil
}
// Run runs a Tailscale backend service.
@ -407,9 +415,14 @@ func New(logf logger.Logf, logid string, store ipn.StateStore, eng wgengine.Engi
return server, nil
}
// connIdentityContextKey is the http.Request.Context's context.Value key for either an
// *ipnauth.ConnIdentity or an error.
type connIdentityContextKey struct{}
// Run runs the server, accepting connections from ln forever.
//
// If the context is done, the listener is closed.
// If the context is done, the listener is closed. It is also the base context
// of all HTTP requests.
func (s *Server) Run(ctx context.Context, ln net.Listener) error {
defer s.b.Shutdown()
@ -429,26 +442,34 @@ func (s *Server) Run(ctx context.Context, ln net.Listener) error {
if s.b.Prefs().Valid() {
s.b.Start(ipn.Options{})
}
systemd.Ready()
bo := backoff.NewBackoff("ipnserver", s.logf, 30*time.Second)
var connNum int
for {
if ctx.Err() != nil {
return ctx.Err()
}
c, err := ln.Accept()
if err != nil {
if ctx.Err() != nil {
return ctx.Err()
hs := &http.Server{
Handler: http.HandlerFunc(s.serveHTTP),
BaseContext: func(_ net.Listener) context.Context { return ctx },
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
ci, err := ipnauth.GetConnIdentity(s.logf, c)
if err != nil {
return context.WithValue(ctx, connIdentityContextKey{}, err)
}
s.logf("ipnserver: Accept: %v", err)
bo.BackOff(ctx, err)
continue
}
connNum++
go s.serveConn(ctx, c, logger.WithPrefix(s.logf, fmt.Sprintf("ipnserver: conn%d: ", connNum)))
return context.WithValue(ctx, connIdentityContextKey{}, ci)
},
// Localhost connections are cheap; so only do
// keep-alives for a short period of time, as these
// active connections lock the server into only serving
// that user. If the user has this page open, we don't
// want another switching user to be locked out for
// minutes. 5 seconds is enough to let browser hit
// favicon.ico and such.
IdleTimeout: 5 * time.Second,
ErrorLog: logger.StdLogger(logger.WithPrefix(s.logf, "ipnserver: ")),
}
if err := hs.Serve(ln); err != nil {
if err := ctx.Err(); err != nil {
return err
}
}
return nil
}
// getEngineUntilItWorksWrapper returns a getEngine wrapper that does
@ -474,40 +495,6 @@ func getEngineUntilItWorksWrapper(getEngine func() (wgengine.Engine, *netstack.I
}
}
// protoSwitchConn is a net.Conn with which we want to speak HTTP to but
// it's already had a few bytes read from it to determine its HTTP method.
// So we Read from its bufio.Reader. On Close, we we tell the
type protoSwitchConn struct {
s *Server
net.Conn
br *bufio.Reader
closeOnce sync.Once
}
func (psc *protoSwitchConn) Read(p []byte) (int, error) { return psc.br.Read(p) }
func (psc *protoSwitchConn) Close() error {
psc.closeOnce.Do(func() { psc.s.removeAndCloseConn(psc.Conn) })
return nil
}
func (s *Server) localhostHandler(ci *ipnauth.ConnIdentity) http.Handler {
lah := localapi.NewHandler(s.b, s.logf, s.backendLogID)
lah.PermitRead, lah.PermitWrite = s.localAPIPermissions(ci)
lah.PermitCert = s.connCanFetchCerts(ci)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/localapi/") {
lah.ServeHTTP(w, r)
return
}
if ci.NotWindows() {
io.WriteString(w, "<html><title>Tailscale</title><body><h1>Tailscale</h1>This is the local Tailscale daemon.")
return
}
s.ServeHTMLStatus(w, r)
})
}
// ServeHTMLStatus serves an HTML status page at http://localhost:41112/ for
// Windows and via $DEBUG_LISTENER/debug/ipn when tailscaled's --debug flag
// is used to run a debug server.