ssh/tailssh: allow multiple sessions on the same conn

Fixes #4920
Fixes tailscale/corp#5633
Updates #4479

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2022-06-27 11:50:11 -07:00 committed by Maisem Ali
parent 1d04e01d1e
commit a7d2024e35
2 changed files with 145 additions and 106 deletions

View File

@ -62,11 +62,10 @@ type server struct {
sessionWaitGroup sync.WaitGroup
// mu protects the following
mu sync.Mutex
activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session
activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
shutdownCalled bool
mu sync.Mutex
activeConns map[*conn]bool // set; value is always true
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
shutdownCalled bool
}
func (srv *server) now() time.Time {
@ -91,14 +90,28 @@ func init() {
})
}
func (srv *server) trackActiveConn(c *conn, add bool) {
srv.mu.Lock()
defer srv.mu.Unlock()
if add {
mak.Set(&srv.activeConns, c, true)
return
}
delete(srv.activeConns, c)
}
// HandleSSHConn handles a Tailscale SSH connection from c.
func (srv *server) HandleSSHConn(c net.Conn) error {
// This is the entry point for all SSH connections.
// When this returns, the connection is closed.
func (srv *server) HandleSSHConn(nc net.Conn) error {
metricIncomingConnections.Add(1)
ss, err := srv.newConn()
c, err := srv.newConn()
if err != nil {
return err
}
ss.HandleConn(c)
srv.trackActiveConn(c, true) // add
defer srv.trackActiveConn(c, false) // remove
c.HandleConn(nc)
// Return nil to signal to netstack's interception that it doesn't need to
// log. If ss.HandleConn had problems, it can log itself (ideally on an
@ -110,11 +123,13 @@ func (srv *server) HandleSSHConn(c net.Conn) error {
func (srv *server) Shutdown() {
srv.mu.Lock()
srv.shutdownCalled = true
for _, s := range srv.activeSessionByH {
s.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Tailscale SSH is shutting down.\r\n"),
context.Canceled,
})
for c := range srv.activeConns {
for _, s := range c.sessions {
s.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Tailscale SSH is shutting down.\r\n"),
context.Canceled,
})
}
}
srv.mu.Unlock()
srv.sessionWaitGroup.Wait()
@ -125,8 +140,8 @@ func (srv *server) Shutdown() {
func (srv *server) OnPolicyChange() {
srv.mu.Lock()
defer srv.mu.Unlock()
for _, s := range srv.activeSessionByH {
go s.checkStillValid()
for c := range srv.activeConns {
go c.checkStillValid()
}
}
@ -135,25 +150,33 @@ func (srv *server) OnPolicyChange() {
type conn struct {
*ssh.Server
insecureSkipTailscaleAuth bool // used by tests.
// now is the time to consider the present moment for the
// purposes of rule evaluation.
now time.Time
connID string // ID that's shared with control
action0 *tailcfg.SSHAction // first matching action
srv *server
info *sshConnInfo // set by setInfo
localUser *user.User // set by checkAuth
userGroupIDs []string // set by checkAuth
insecureSkipTailscaleAuth bool // used by tests.
mu sync.Mutex // protects the following
// idH is the RFC4253 sec8 hash H. It is used to identify the connection,
// and is shared among all sessions. It should not be shared outside
// process. It is confusingly referred to as SessionID by the gliderlabs/ssh
// library.
idH string
pubKey gossh.PublicKey // set by authorizeSession
finalAction *tailcfg.SSHAction // set by authorizeSession
finalActionErr error // set by authorizeSession
sessions []*sshSession
}
func (c *conn) logf(format string, args ...any) {
if c.info == nil {
c.srv.logf(format, args...)
return
}
format = fmt.Sprintf("%v: %v", c.info.String(), format)
format = fmt.Sprintf("%v: %v", c.connID, format)
c.srv.logf(format, args...)
}
@ -247,21 +270,22 @@ func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig {
func (srv *server) newConn() (*conn, error) {
srv.mu.Lock()
shutdownCalled := srv.shutdownCalled
srv.mu.Unlock()
if shutdownCalled {
if srv.shutdownCalled {
srv.mu.Unlock()
// Stop accepting new connections.
// Connections in the auth phase are handled in handleConnPostSSHAuth.
// Existing sessions are terminated by Shutdown.
return nil, gossh.ErrDenied
}
srv.mu.Unlock()
c := &conn{srv: srv, now: srv.now()}
c.connID = fmt.Sprintf("conn-%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5))
c.Server = &ssh.Server{
Version: "Tailscale",
Handler: c.handleConnPostSSHAuth,
Handler: c.handleSessionPostSSHAuth,
RequestHandlers: map[string]ssh.RequestHandler{},
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": c.handleConnPostSSHAuth,
"sftp": c.handleSessionPostSSHAuth,
},
// Note: the direct-tcpip channel handler and LocalPortForwardingCallback
@ -270,7 +294,7 @@ func (srv *server) newConn() (*conn, error) {
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": ssh.DirectTCPIPHandler,
},
LocalPortForwardingCallback: srv.mayForwardLocalPortTo,
LocalPortForwardingCallback: c.mayForwardLocalPortTo,
PublicKeyHandler: c.PublicKeyHandler,
ServerConfigCallback: c.ServerConfig,
@ -298,16 +322,12 @@ func (srv *server) newConn() (*conn, error) {
// mayForwardLocalPortTo reports whether the ctx should be allowed to port forward
// to the specified host and port.
// TODO(bradfitz/maisem): should we have more checks on host/port?
func (srv *server) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
ss, ok := srv.getSessionForContext(ctx)
if !ok {
return false
func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
if c.finalAction != nil && c.finalAction.AllowLocalPortForwarding {
metricLocalPortForward.Add(1)
return true
}
if !ss.action.AllowLocalPortForwarding {
return false
}
metricLocalPortForward.Add(1)
return true
return false
}
// havePubKeyPolicy reports whether any policy rule may provide access by means
@ -401,6 +421,7 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
ci.uprof = &uprof
c.info = ci
c.logf("handling conn: %v", ci.String())
return nil
}
@ -516,32 +537,47 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
return lines, err
}
// handleConnPostSSHAuth runs an SSH session after the SSH-level authentication,
// but not necessarily before all the Tailscale-level extra verification has
// completed. It also handles SFTP requests.
func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
if s.PublicKey() != nil {
metricPublicKeyConnections.Add(1)
func (c *conn) authorizeSession(s ssh.Session) (_ *contextReader, ok bool) {
c.mu.Lock()
defer c.mu.Unlock()
idH := s.Context().(ssh.Context).SessionID()
if c.idH == "" {
c.idH = idH
} else if c.idH != idH {
c.logf("ssh: session ID mismatch: %q != %q", c.idH, idH)
s.Exit(1)
return nil, false
}
sshUser := s.User()
cr := &contextReader{r: s}
action, err := c.resolveTerminalAction(s, cr)
action, err := c.resolveTerminalActionLocked(s, cr)
if err != nil {
c.logf("resolveTerminalAction: %v", err)
io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n")
s.Exit(1)
return
return nil, false
}
if action.Reject || !action.Accept {
c.logf("access denied for %v", c.info.uprof.LoginName)
s.Exit(1)
return nil, false
}
return cr, true
}
// handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication,
// but not necessarily before all the Tailscale-level extra verification has
// completed. It also handles SFTP requests.
func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
// Now that we have passed the SSH-level authentication, we can start the
// Tailscale-level extra verification. This means that we are going to
// evaluate the policy provided by control against the incoming SSH session.
cr, ok := c.authorizeSession(s)
if !ok {
return
}
if s.PublicKey() != nil {
metricPublicKeyAccepts.Add(1)
}
if cr.HasOutstandingRead() {
// There was some buffered input while we were waiting for the policy
// decision.
s = contextReaderSesssion{s, cr}
}
@ -555,20 +591,37 @@ func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
return
}
ss := c.newSSHSession(s, action)
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), sshUser)
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, sshUser)
ss := c.newSSHSession(s)
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), c.localUser.Username)
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Name)
ss.run()
}
// resolveTerminalAction either returns action0 (if it's Accept or Reject) or
// resolveTerminalActionLocked either returns action0 (if it's Accept or Reject) or
// else loops, fetching new SSHActions from the control plane.
//
// Any action with a Message in the chain will be printed to s.
//
// The returned SSHAction will be either Reject or Accept.
func (c *conn) resolveTerminalAction(s ssh.Session, cr *contextReader) (*tailcfg.SSHAction, error) {
action := c.action0
//
// c.mu must be held.
func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (action *tailcfg.SSHAction, err error) {
if c.finalAction != nil || c.finalActionErr != nil {
return c.finalAction, c.finalActionErr
}
if s.PublicKey() != nil {
metricPublicKeyConnections.Add(1)
}
defer func() {
c.finalAction = action
c.finalActionErr = err
c.pubKey = s.PublicKey()
if c.pubKey != nil && action.Accept {
metricPublicKeyAccepts.Add(1)
}
}()
action = c.action0
var awaitReadOnce sync.Once // to start Reads on cr
var sawInterrupt syncs.AtomicBool
@ -672,13 +725,11 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string {
// sshSession is an accepted Tailscale SSH session.
type sshSession struct {
ssh.Session
idH string // the RFC4253 sec8 hash H; don't share outside process
sharedID string // ID that's shared with control
logf logger.Logf
ctx *sshContext // implements context.Context
conn *conn
action *tailcfg.SSHAction
agentListener net.Listener // non-nil if agent-forwarding requested+allowed
// initialized by launchProcess:
@ -699,22 +750,21 @@ func (ss *sshSession) vlogf(format string, args ...interface{}) {
}
}
func (c *conn) newSSHSession(s ssh.Session, action *tailcfg.SSHAction) *sshSession {
sharedID := fmt.Sprintf("%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5))
func (c *conn) newSSHSession(s ssh.Session) *sshSession {
sharedID := fmt.Sprintf("sess-%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5))
c.logf("starting session: %v", sharedID)
return &sshSession{
Session: s,
idH: s.Context().(ssh.Context).SessionID(),
sharedID: sharedID,
ctx: newSSHContext(),
conn: c,
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
action: action,
}
}
func (c *conn) isStillValid(pubKey ssh.PublicKey) bool {
a, localUser, err := c.evaluatePolicy(pubKey)
// isStillValid reports whether the conn is still valid.
func (c *conn) isStillValid() bool {
a, localUser, err := c.evaluatePolicy(c.pubKey)
if err != nil {
return false
}
@ -724,18 +774,20 @@ func (c *conn) isStillValid(pubKey ssh.PublicKey) bool {
return c.localUser.Username == localUser
}
// checkStillValid checks that the session is still valid per the latest SSHPolicy.
// If not, it terminates the session.
func (ss *sshSession) checkStillValid() {
if ss.conn.isStillValid(ss.PublicKey()) {
// checkStillValid checks that the conn is still valid per the latest SSHPolicy.
// If not, it terminates all sessions associated with the conn.
func (c *conn) checkStillValid() {
if c.isStillValid() {
return
}
metricPolicyChangeKick.Add(1)
ss.logf("session no longer valid per new SSH policy; closing")
ss.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Access revoked.\r\n"),
context.Canceled,
})
c.logf("session no longer valid per new SSH policy; closing")
for _, s := range c.sessions {
s.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Access revoked.\r\n"),
context.Canceled,
})
}
}
func (c *conn) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHAction, error) {
@ -798,41 +850,27 @@ func (ss *sshSession) killProcessOnContextDone() {
})
}
// sessionAction returns the SSHAction associated with the session.
func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bool) {
srv.mu.Lock()
defer srv.mu.Unlock()
ss, ok = srv.activeSessionByH[sctx.SessionID()]
return
}
// startSessionLocked registers ss as an active session.
// It must be called with srv.mu held.
func (srv *server) startSessionLocked(ss *sshSession) {
srv.sessionWaitGroup.Add(1)
if ss.idH == "" {
panic("empty idH")
}
func (c *conn) startSessionLocked(ss *sshSession) {
c.srv.sessionWaitGroup.Add(1)
if ss.sharedID == "" {
panic("empty sharedID")
}
if _, dup := srv.activeSessionByH[ss.idH]; dup {
panic("dup idH")
}
if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
panic("dup sharedID")
}
mak.Set(&srv.activeSessionByH, ss.idH, ss)
mak.Set(&srv.activeSessionBySharedID, ss.sharedID, ss)
c.sessions = append(c.sessions, ss)
}
// endSession unregisters s from the list of active sessions.
func (srv *server) endSession(ss *sshSession) {
defer srv.sessionWaitGroup.Done()
srv.mu.Lock()
defer srv.mu.Unlock()
delete(srv.activeSessionByH, ss.idH)
delete(srv.activeSessionBySharedID, ss.sharedID)
func (c *conn) endSession(ss *sshSession) {
defer c.srv.sessionWaitGroup.Done()
c.srv.mu.Lock()
defer c.srv.mu.Unlock()
for i, s := range c.sessions {
if s == ss {
c.sessions = append(c.sessions[:i], c.sessions[i+1:]...)
break
}
}
}
var errSessionDone = errors.New("session is done")
@ -841,7 +879,7 @@ var errSessionDone = errors.New("session is done")
// forwards agent connections between the listener and the ssh.Session.
// On success, it assigns ss.agentListener.
func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *user.User) error {
if !ssh.AgentRequested(ss) || !ss.action.AllowAgentForwarding {
if !ssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding {
return nil
}
ss.logf("ssh: agent forwarding requested")
@ -906,15 +944,15 @@ func (ss *sshSession) run() {
ss.Exit(1)
return
}
srv.startSessionLocked(ss)
ss.conn.startSessionLocked(ss)
srv.mu.Unlock()
defer srv.endSession(ss)
defer ss.conn.endSession(ss)
if ss.action.SessionDuration != 0 {
t := time.AfterFunc(ss.action.SessionDuration, func() {
if ss.conn.finalAction.SessionDuration != 0 {
t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() {
ss.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Session timeout of %v elapsed.", ss.action.SessionDuration),
fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration),
context.DeadlineExceeded,
})
})

View File

@ -238,9 +238,10 @@ func TestSSH(t *testing.T) {
node: &tailcfg.Node{},
uprof: &tailcfg.UserProfile{},
}
sc.finalAction = &tailcfg.SSHAction{Accept: true}
sc.Handler = func(s ssh.Session) {
sc.newSSHSession(s, &tailcfg.SSHAction{Accept: true}).run()
sc.newSSHSession(s).run()
}
ln, err := net.Listen("tcp4", "127.0.0.1:0")