From 638127530bd1ab2e67542d7d59aa798bf5cd3565 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 9 Oct 2020 12:15:57 -0700 Subject: [PATCH] ipn/ipnserver: prevent use by multiple Windows users, add HTML status page It was previously possible for two different Windows users to connect to the IPN server at once, but it didn't really work. They mostly stepped on each other's toes and caused chaos. Now only one can control it, but it can be active for everybody else. Necessary dependency step for Windows server/headless mode (#275) While here, finish wiring up the HTTP status page on Windows, now that all the dependent pieces are available. --- ipn/ipnserver/server.go | 204 +++++++++++++++++++++++++++++----------- 1 file changed, 150 insertions(+), 54 deletions(-) diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go index fb63de133..0fa8a5f0e 100644 --- a/ipn/ipnserver/server.go +++ b/ipn/ipnserver/server.go @@ -7,8 +7,8 @@ package ipnserver import ( "bufio" "context" + "errors" "fmt" - "html" "io" "log" "net" @@ -82,29 +82,112 @@ type Options struct { // talking to an IPN backend. type server struct { resetOnZero bool // call bs.Reset on transition from 1->0 connections + b *ipn.LocalBackend bsMu sync.Mutex // lock order: bsMu, then mu bs *ipn.BackendServer - mu sync.Mutex - clients map[net.Conn]bool + mu sync.Mutex + allClients map[net.Conn]connIdentity // HTTP or IPN + clients map[net.Conn]bool // subset of allClients; only IPN protocol +} + +// connIdentity represents the owner of a localhost TCP connection. +type connIdentity struct { + Unknown bool + Pid int + UserID string + User *user.User +} + +// getConnIdentity returns the localhost TCP connection's identity information +// (pid, userid, user). If it's not Windows (for now), it returns a nil error +// and a ConnIdentity with Unknown set true. It's only an error if we expected +// to be able to map it and couldn't. +func getConnIdentity(c net.Conn) (ci connIdentity, err error) { + if runtime.GOOS != "windows" { // for now; TODO: expand to other OSes + return connIdentity{Unknown: true}, nil + } + la, err := netaddr.ParseIPPort(c.LocalAddr().String()) + if err != nil { + return ci, fmt.Errorf("parsing local address: %w", err) + } + ra, err := netaddr.ParseIPPort(c.RemoteAddr().String()) + if err != nil { + return ci, fmt.Errorf("parsing local remote: %w", err) + } + if !la.IP.IsLoopback() || !ra.IP.IsLoopback() { + return ci, errors.New("non-loopback connection") + } + tab, err := netstat.Get() + if err != nil { + return ci, fmt.Errorf("failed to get local connection table: %w", err) + } + pid := peerPid(tab.Entries, la, ra) + if pid == 0 { + return ci, errors.New("no local process found matching localhost connection") + } + ci.Pid = pid + uid, err := pidowner.OwnerOfPID(pid) + if err != nil { + var hint string + if runtime.GOOS == "windows" { + hint = " (WSL?)" + } + return ci, fmt.Errorf("failed to map connection's pid to a user%s: %w", hint, err) + } + ci.UserID = uid + u, err := user.LookupId(uid) + if err != nil { + return ci, fmt.Errorf("failed to look up user from userid: %w", err) + } + ci.User = u + return ci, nil } func (s *server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) { - br := bufio.NewReader(c) - // First see if it's an HTTP request. + br := bufio.NewReader(c) c.SetReadDeadline(time.Now().Add(time.Second)) peek, _ := br.Peek(4) c.SetReadDeadline(time.Time{}) - if string(peek) == "GET " { - http.Serve(&oneConnListener{altReaderNetConn{br, c}}, localhostHandler(c)) + isHTTPReq := string(peek) == "GET " + + ci, err := s.addConn(c, isHTTPReq) + if err != nil { + if isHTTPReq { + fmt.Fprintf(c, "HTTP/1.0 500 Nope\r\nContent-Type: text/plain\r\nX-Content-Type-Options: nosniff\r\n\r\n%s\n", err.Error()) + c.Close() + return + } + defer c.Close() + serverToClient := func(b []byte) { ipn.WriteMsg(c, b) } + bs := ipn.NewBackendServer(logf, nil, serverToClient) + bs.SendErrorMessage(err.Error()) + time.Sleep(time.Second) + return + } + + if isHTTPReq { + httpServer := http.Server{ + // Localhost connections are cheap; so only do + // keep-alives for a short period of time, as these + // active connections lock the server into only serving + // that user. If the user has this page open, we don't + // want another switching user to be locked out for + // minutes. 5 seconds is enough to let browser hit + // favicon.ico and such. + IdleTimeout: 5 * time.Second, + ErrorLog: logger.StdLogger(logf), + Handler: s.localhostHandler(ci), + } + httpServer.Serve(&oneConnListener{&protoSwitchConn{s: s, br: br, Conn: c}}) return } - s.addConn(c) - logf("incoming control connection") defer s.removeAndCloseConn(c) + logf("incoming control connection") + for ctx.Err() == nil { msg, err := ipn.ReadMsg(br) if err != nil { @@ -125,19 +208,48 @@ func (s *server) serveConn(ctx context.Context, c net.Conn, logf logger.Logf) { } } -func (s *server) addConn(c net.Conn) { +func (s *server) addConn(c net.Conn, isHTTP bool) (ci connIdentity, err error) { + ci, err = getConnIdentity(c) + if err != nil { + return + } + s.mu.Lock() defer s.mu.Unlock() + if s.clients == nil { s.clients = map[net.Conn]bool{} } - s.clients[c] = true + if s.allClients == nil { + s.allClients = map[net.Conn]connIdentity{} + } + + // If clients are already connected, verify they're the same user. + // This mostly matters on Windows at the moment. + if len(s.allClients) > 0 { + var active connIdentity + for _, active = range s.allClients { + break + } + if ci.UserID != active.UserID { + //lint:ignore ST1005 we want to capitalize Tailscale here + return ci, fmt.Errorf("Tailscale already in use by %s, pid %d", active.User.Username, active.Pid) + } + } + + if !isHTTP { + s.clients[c] = true + } + s.allClients[c] = ci + + return ci, nil } func (s *server) removeAndCloseConn(c net.Conn) { s.mu.Lock() delete(s.clients, c) - remain := len(s.clients) + delete(s.allClients, c) + remain := len(s.allClients) s.mu.Unlock() if remain == 0 && s.resetOnZero { @@ -250,13 +362,11 @@ func Run(ctx context.Context, logf logger.Logf, logid string, getEngine func() ( if opts.DebugMux != nil { opts.DebugMux.HandleFunc("/debug/ipn", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - st := b.Status() - // TODO(bradfitz): add LogID and opts to st? - st.WriteHTML(w) + serveHTMLStatus(w, b) }) } + server.b = b server.bs = ipn.NewBackendServer(logf, b, server.writeToClients) if opts.AutostartStateKey != "" { @@ -436,54 +546,40 @@ func (l *oneConnListener) Addr() net.Addr { return dummyAddr("unused-address") } func (a dummyAddr) Network() string { return string(a) } func (a dummyAddr) String() string { return string(a) } -type altReaderNetConn struct { - r io.Reader +// protoSwitchConn is a net.Conn that's we want to speak HTTP to but +// it's already had a few bytes read from it to determine that it's +// HTTP. So we Read from its bufio.Reader. On Close, we we tell the +// server it's closed, so the server can account the who's connected. +type protoSwitchConn struct { + s *server net.Conn + br *bufio.Reader + closeOnce sync.Once } -func (a altReaderNetConn) Read(p []byte) (int, error) { return a.r.Read(p) } +func (psc *protoSwitchConn) Read(p []byte) (int, error) { return psc.br.Read(p) } +func (psc *protoSwitchConn) Close() error { + psc.closeOnce.Do(func() { psc.s.removeAndCloseConn(psc.Conn) }) + return nil +} -func localhostHandler(c net.Conn) http.Handler { - la, lerr := netaddr.ParseIPPort(c.LocalAddr().String()) - ra, rerr := netaddr.ParseIPPort(c.RemoteAddr().String()) +func (s *server) localhostHandler(ci connIdentity) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "

tailscale

\n") - if lerr != nil || rerr != nil { - io.WriteString(w, "failed to parse remote address") + if ci.Unknown { + io.WriteString(w, "Tailscale

Tailscale

This is the local Tailscale daemon.") return } - if !la.IP.IsLoopback() || !ra.IP.IsLoopback() { - io.WriteString(w, "not loopback") - return - } - tab, err := netstat.Get() - if err == netstat.ErrNotImplemented { - io.WriteString(w, "status page not available on "+runtime.GOOS) - return - } - if err != nil { - io.WriteString(w, "failed to get netstat table") - return - } - pid := peerPid(tab.Entries, la, ra) - if pid == 0 { - io.WriteString(w, "peer pid not found") - return - } - uid, err := pidowner.OwnerOfPID(pid) - if err != nil { - io.WriteString(w, "owner of peer pid not found") - return - } - u, err := user.LookupId(uid) - if err != nil { - io.WriteString(w, "User lookup failed") - return - } - fmt.Fprintf(w, "Hello, %s", html.EscapeString(u.Username)) + serveHTMLStatus(w, s.b) }) } +func serveHTMLStatus(w http.ResponseWriter, b *ipn.LocalBackend) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + st := b.Status() + // TODO(bradfitz): add LogID and opts to st? + st.WriteHTML(w) +} + func peerPid(entries []netstat.Entry, la, ra netaddr.IPPort) int { for _, e := range entries { if e.Local == ra && e.Remote == la {