ssh/tailssh: allow multiple sessions on the same conn
Fixes #4920
Fixes tailscale/corp#5633
Updates #4479

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
1d04e01d1e
commit
a7d2024e35
|
@ -62,11 +62,10 @@ type server struct {
|
|||
sessionWaitGroup sync.WaitGroup
|
||||
|
||||
// mu protects the following
|
||||
mu sync.Mutex
|
||||
activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session
|
||||
activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
|
||||
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
|
||||
shutdownCalled bool
|
||||
mu sync.Mutex
|
||||
activeConns map[*conn]bool // set; value is always true
|
||||
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
|
||||
shutdownCalled bool
|
||||
}
|
||||
|
||||
func (srv *server) now() time.Time {
|
||||
|
@ -91,14 +90,28 @@ func init() {
|
|||
})
|
||||
}
|
||||
|
||||
func (srv *server) trackActiveConn(c *conn, add bool) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
if add {
|
||||
mak.Set(&srv.activeConns, c, true)
|
||||
return
|
||||
}
|
||||
delete(srv.activeConns, c)
|
||||
}
|
||||
|
||||
// HandleSSHConn handles a Tailscale SSH connection from c.
|
||||
func (srv *server) HandleSSHConn(c net.Conn) error {
|
||||
// This is the entry point for all SSH connections.
|
||||
// When this returns, the connection is closed.
|
||||
func (srv *server) HandleSSHConn(nc net.Conn) error {
|
||||
metricIncomingConnections.Add(1)
|
||||
ss, err := srv.newConn()
|
||||
c, err := srv.newConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ss.HandleConn(c)
|
||||
srv.trackActiveConn(c, true) // add
|
||||
defer srv.trackActiveConn(c, false) // remove
|
||||
c.HandleConn(nc)
|
||||
|
||||
// Return nil to signal to netstack's interception that it doesn't need to
|
||||
// log. If ss.HandleConn had problems, it can log itself (ideally on an
|
||||
|
@ -110,11 +123,13 @@ func (srv *server) HandleSSHConn(c net.Conn) error {
|
|||
func (srv *server) Shutdown() {
|
||||
srv.mu.Lock()
|
||||
srv.shutdownCalled = true
|
||||
for _, s := range srv.activeSessionByH {
|
||||
s.ctx.CloseWithError(userVisibleError{
|
||||
fmt.Sprintf("Tailscale SSH is shutting down.\r\n"),
|
||||
context.Canceled,
|
||||
})
|
||||
for c := range srv.activeConns {
|
||||
for _, s := range c.sessions {
|
||||
s.ctx.CloseWithError(userVisibleError{
|
||||
fmt.Sprintf("Tailscale SSH is shutting down.\r\n"),
|
||||
context.Canceled,
|
||||
})
|
||||
}
|
||||
}
|
||||
srv.mu.Unlock()
|
||||
srv.sessionWaitGroup.Wait()
|
||||
|
@ -125,8 +140,8 @@ func (srv *server) Shutdown() {
|
|||
func (srv *server) OnPolicyChange() {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
for _, s := range srv.activeSessionByH {
|
||||
go s.checkStillValid()
|
||||
for c := range srv.activeConns {
|
||||
go c.checkStillValid()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -135,25 +150,33 @@ func (srv *server) OnPolicyChange() {
|
|||
type conn struct {
|
||||
*ssh.Server
|
||||
|
||||
insecureSkipTailscaleAuth bool // used by tests.
|
||||
|
||||
// now is the time to consider the present moment for the
|
||||
// purposes of rule evaluation.
|
||||
now time.Time
|
||||
|
||||
connID string // ID that's shared with control
|
||||
action0 *tailcfg.SSHAction // first matching action
|
||||
srv *server
|
||||
info *sshConnInfo // set by setInfo
|
||||
localUser *user.User // set by checkAuth
|
||||
userGroupIDs []string // set by checkAuth
|
||||
|
||||
insecureSkipTailscaleAuth bool // used by tests.
|
||||
mu sync.Mutex // protects the following
|
||||
// idH is the RFC4253 sec8 hash H. It is used to identify the connection,
|
||||
// and is shared among all sessions. It should not be shared outside
|
||||
// the process. It is confusingly referred to as SessionID by the gliderlabs/ssh
|
||||
// library.
|
||||
idH string
|
||||
pubKey gossh.PublicKey // set by authorizeSession
|
||||
finalAction *tailcfg.SSHAction // set by authorizeSession
|
||||
finalActionErr error // set by authorizeSession
|
||||
sessions []*sshSession
|
||||
}
|
||||
|
||||
func (c *conn) logf(format string, args ...any) {
|
||||
if c.info == nil {
|
||||
c.srv.logf(format, args...)
|
||||
return
|
||||
}
|
||||
format = fmt.Sprintf("%v: %v", c.info.String(), format)
|
||||
format = fmt.Sprintf("%v: %v", c.connID, format)
|
||||
c.srv.logf(format, args...)
|
||||
}
|
||||
|
||||
|
@ -247,21 +270,22 @@ func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig {
|
|||
|
||||
func (srv *server) newConn() (*conn, error) {
|
||||
srv.mu.Lock()
|
||||
shutdownCalled := srv.shutdownCalled
|
||||
srv.mu.Unlock()
|
||||
if shutdownCalled {
|
||||
if srv.shutdownCalled {
|
||||
srv.mu.Unlock()
|
||||
// Stop accepting new connections.
|
||||
// Connections in the auth phase are handled in handleConnPostSSHAuth.
|
||||
// Existing sessions are terminated by Shutdown.
|
||||
return nil, gossh.ErrDenied
|
||||
}
|
||||
srv.mu.Unlock()
|
||||
c := &conn{srv: srv, now: srv.now()}
|
||||
c.connID = fmt.Sprintf("conn-%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5))
|
||||
c.Server = &ssh.Server{
|
||||
Version: "Tailscale",
|
||||
Handler: c.handleConnPostSSHAuth,
|
||||
Handler: c.handleSessionPostSSHAuth,
|
||||
RequestHandlers: map[string]ssh.RequestHandler{},
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": c.handleConnPostSSHAuth,
|
||||
"sftp": c.handleSessionPostSSHAuth,
|
||||
},
|
||||
|
||||
// Note: the direct-tcpip channel handler and LocalPortForwardingCallback
|
||||
|
@ -270,7 +294,7 @@ func (srv *server) newConn() (*conn, error) {
|
|||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"direct-tcpip": ssh.DirectTCPIPHandler,
|
||||
},
|
||||
LocalPortForwardingCallback: srv.mayForwardLocalPortTo,
|
||||
LocalPortForwardingCallback: c.mayForwardLocalPortTo,
|
||||
|
||||
PublicKeyHandler: c.PublicKeyHandler,
|
||||
ServerConfigCallback: c.ServerConfig,
|
||||
|
@ -298,16 +322,12 @@ func (srv *server) newConn() (*conn, error) {
|
|||
// mayForwardLocalPortTo reports whether the ctx should be allowed to port forward
|
||||
// to the specified host and port.
|
||||
// TODO(bradfitz/maisem): should we have more checks on host/port?
|
||||
func (srv *server) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||
ss, ok := srv.getSessionForContext(ctx)
|
||||
if !ok {
|
||||
return false
|
||||
func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
|
||||
if c.finalAction != nil && c.finalAction.AllowLocalPortForwarding {
|
||||
metricLocalPortForward.Add(1)
|
||||
return true
|
||||
}
|
||||
if !ss.action.AllowLocalPortForwarding {
|
||||
return false
|
||||
}
|
||||
metricLocalPortForward.Add(1)
|
||||
return true
|
||||
return false
|
||||
}
|
||||
|
||||
// havePubKeyPolicy reports whether any policy rule may provide access by means
|
||||
|
@ -401,6 +421,7 @@ func (c *conn) setInfo(cm gossh.ConnMetadata) error {
|
|||
ci.uprof = &uprof
|
||||
|
||||
c.info = ci
|
||||
c.logf("handling conn: %v", ci.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -516,32 +537,47 @@ func (srv *server) fetchPublicKeysURL(url string) ([]string, error) {
|
|||
return lines, err
|
||||
}
|
||||
|
||||
// handleConnPostSSHAuth runs an SSH session after the SSH-level authentication,
|
||||
// but not necessarily before all the Tailscale-level extra verification has
|
||||
// completed. It also handles SFTP requests.
|
||||
func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
|
||||
if s.PublicKey() != nil {
|
||||
metricPublicKeyConnections.Add(1)
|
||||
func (c *conn) authorizeSession(s ssh.Session) (_ *contextReader, ok bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
idH := s.Context().(ssh.Context).SessionID()
|
||||
if c.idH == "" {
|
||||
c.idH = idH
|
||||
} else if c.idH != idH {
|
||||
c.logf("ssh: session ID mismatch: %q != %q", c.idH, idH)
|
||||
s.Exit(1)
|
||||
return nil, false
|
||||
}
|
||||
sshUser := s.User()
|
||||
cr := &contextReader{r: s}
|
||||
action, err := c.resolveTerminalAction(s, cr)
|
||||
action, err := c.resolveTerminalActionLocked(s, cr)
|
||||
if err != nil {
|
||||
c.logf("resolveTerminalAction: %v", err)
|
||||
io.WriteString(s.Stderr(), "Access Denied: failed during authorization check.\r\n")
|
||||
s.Exit(1)
|
||||
return
|
||||
return nil, false
|
||||
}
|
||||
if action.Reject || !action.Accept {
|
||||
c.logf("access denied for %v", c.info.uprof.LoginName)
|
||||
s.Exit(1)
|
||||
return nil, false
|
||||
}
|
||||
return cr, true
|
||||
}
|
||||
|
||||
// handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication,
|
||||
// but not necessarily before all the Tailscale-level extra verification has
|
||||
// completed. It also handles SFTP requests.
|
||||
func (c *conn) handleSessionPostSSHAuth(s ssh.Session) {
|
||||
// Now that we have passed the SSH-level authentication, we can start the
|
||||
// Tailscale-level extra verification. This means that we are going to
|
||||
// evaluate the policy provided by control against the incoming SSH session.
|
||||
cr, ok := c.authorizeSession(s)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if s.PublicKey() != nil {
|
||||
metricPublicKeyAccepts.Add(1)
|
||||
}
|
||||
|
||||
if cr.HasOutstandingRead() {
|
||||
// There was some buffered input while we were waiting for the policy
|
||||
// decision.
|
||||
s = contextReaderSesssion{s, cr}
|
||||
}
|
||||
|
||||
|
@ -555,20 +591,37 @@ func (c *conn) handleConnPostSSHAuth(s ssh.Session) {
|
|||
return
|
||||
}
|
||||
|
||||
ss := c.newSSHSession(s, action)
|
||||
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), sshUser)
|
||||
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, sshUser)
|
||||
ss := c.newSSHSession(s)
|
||||
ss.logf("handling new SSH connection from %v (%v) to ssh-user %q", c.info.uprof.LoginName, c.info.src.IP(), c.localUser.Username)
|
||||
ss.logf("access granted to %v as ssh-user %q", c.info.uprof.LoginName, c.localUser.Name)
|
||||
ss.run()
|
||||
}
|
||||
|
||||
// resolveTerminalAction either returns action0 (if it's Accept or Reject) or
|
||||
// resolveTerminalActionLocked either returns action0 (if it's Accept or Reject) or
|
||||
// else loops, fetching new SSHActions from the control plane.
|
||||
//
|
||||
// Any action with a Message in the chain will be printed to s.
|
||||
//
|
||||
// The returned SSHAction will be either Reject or Accept.
|
||||
func (c *conn) resolveTerminalAction(s ssh.Session, cr *contextReader) (*tailcfg.SSHAction, error) {
|
||||
action := c.action0
|
||||
//
|
||||
// c.mu must be held.
|
||||
func (c *conn) resolveTerminalActionLocked(s ssh.Session, cr *contextReader) (action *tailcfg.SSHAction, err error) {
|
||||
if c.finalAction != nil || c.finalActionErr != nil {
|
||||
return c.finalAction, c.finalActionErr
|
||||
}
|
||||
|
||||
if s.PublicKey() != nil {
|
||||
metricPublicKeyConnections.Add(1)
|
||||
}
|
||||
defer func() {
|
||||
c.finalAction = action
|
||||
c.finalActionErr = err
|
||||
c.pubKey = s.PublicKey()
|
||||
if c.pubKey != nil && action.Accept {
|
||||
metricPublicKeyAccepts.Add(1)
|
||||
}
|
||||
}()
|
||||
action = c.action0
|
||||
|
||||
var awaitReadOnce sync.Once // to start Reads on cr
|
||||
var sawInterrupt syncs.AtomicBool
|
||||
|
@ -672,13 +725,11 @@ func (c *conn) expandPublicKeyURL(pubKeyURL string) string {
|
|||
// sshSession is an accepted Tailscale SSH session.
|
||||
type sshSession struct {
|
||||
ssh.Session
|
||||
idH string // the RFC4253 sec8 hash H; don't share outside process
|
||||
sharedID string // ID that's shared with control
|
||||
logf logger.Logf
|
||||
|
||||
ctx *sshContext // implements context.Context
|
||||
conn *conn
|
||||
action *tailcfg.SSHAction
|
||||
agentListener net.Listener // non-nil if agent-forwarding requested+allowed
|
||||
|
||||
// initialized by launchProcess:
|
||||
|
@ -699,22 +750,21 @@ func (ss *sshSession) vlogf(format string, args ...interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *conn) newSSHSession(s ssh.Session, action *tailcfg.SSHAction) *sshSession {
|
||||
sharedID := fmt.Sprintf("%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5))
|
||||
func (c *conn) newSSHSession(s ssh.Session) *sshSession {
|
||||
sharedID := fmt.Sprintf("sess-%s-%02x", c.now.UTC().Format("20060102T150405"), randBytes(5))
|
||||
c.logf("starting session: %v", sharedID)
|
||||
return &sshSession{
|
||||
Session: s,
|
||||
idH: s.Context().(ssh.Context).SessionID(),
|
||||
sharedID: sharedID,
|
||||
ctx: newSSHContext(),
|
||||
conn: c,
|
||||
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
|
||||
action: action,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) isStillValid(pubKey ssh.PublicKey) bool {
|
||||
a, localUser, err := c.evaluatePolicy(pubKey)
|
||||
// isStillValid reports whether the conn is still valid.
|
||||
func (c *conn) isStillValid() bool {
|
||||
a, localUser, err := c.evaluatePolicy(c.pubKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
@ -724,18 +774,20 @@ func (c *conn) isStillValid(pubKey ssh.PublicKey) bool {
|
|||
return c.localUser.Username == localUser
|
||||
}
|
||||
|
||||
// checkStillValid checks that the session is still valid per the latest SSHPolicy.
|
||||
// If not, it terminates the session.
|
||||
func (ss *sshSession) checkStillValid() {
|
||||
if ss.conn.isStillValid(ss.PublicKey()) {
|
||||
// checkStillValid checks that the conn is still valid per the latest SSHPolicy.
|
||||
// If not, it terminates all sessions associated with the conn.
|
||||
func (c *conn) checkStillValid() {
|
||||
if c.isStillValid() {
|
||||
return
|
||||
}
|
||||
metricPolicyChangeKick.Add(1)
|
||||
ss.logf("session no longer valid per new SSH policy; closing")
|
||||
ss.ctx.CloseWithError(userVisibleError{
|
||||
fmt.Sprintf("Access revoked.\r\n"),
|
||||
context.Canceled,
|
||||
})
|
||||
c.logf("session no longer valid per new SSH policy; closing")
|
||||
for _, s := range c.sessions {
|
||||
s.ctx.CloseWithError(userVisibleError{
|
||||
fmt.Sprintf("Access revoked.\r\n"),
|
||||
context.Canceled,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) fetchSSHAction(ctx context.Context, url string) (*tailcfg.SSHAction, error) {
|
||||
|
@ -798,41 +850,27 @@ func (ss *sshSession) killProcessOnContextDone() {
|
|||
})
|
||||
}
|
||||
|
||||
// sessionAction returns the SSHAction associated with the session.
|
||||
func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bool) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
ss, ok = srv.activeSessionByH[sctx.SessionID()]
|
||||
return
|
||||
}
|
||||
|
||||
// startSessionLocked registers ss as an active session.
|
||||
// It must be called with srv.mu held.
|
||||
func (srv *server) startSessionLocked(ss *sshSession) {
|
||||
srv.sessionWaitGroup.Add(1)
|
||||
if ss.idH == "" {
|
||||
panic("empty idH")
|
||||
}
|
||||
func (c *conn) startSessionLocked(ss *sshSession) {
|
||||
c.srv.sessionWaitGroup.Add(1)
|
||||
if ss.sharedID == "" {
|
||||
panic("empty sharedID")
|
||||
}
|
||||
if _, dup := srv.activeSessionByH[ss.idH]; dup {
|
||||
panic("dup idH")
|
||||
}
|
||||
if _, dup := srv.activeSessionBySharedID[ss.sharedID]; dup {
|
||||
panic("dup sharedID")
|
||||
}
|
||||
mak.Set(&srv.activeSessionByH, ss.idH, ss)
|
||||
mak.Set(&srv.activeSessionBySharedID, ss.sharedID, ss)
|
||||
c.sessions = append(c.sessions, ss)
|
||||
}
|
||||
|
||||
// endSession unregisters ss from the list of active sessions.
|
||||
func (srv *server) endSession(ss *sshSession) {
|
||||
defer srv.sessionWaitGroup.Done()
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
delete(srv.activeSessionByH, ss.idH)
|
||||
delete(srv.activeSessionBySharedID, ss.sharedID)
|
||||
func (c *conn) endSession(ss *sshSession) {
|
||||
defer c.srv.sessionWaitGroup.Done()
|
||||
c.srv.mu.Lock()
|
||||
defer c.srv.mu.Unlock()
|
||||
for i, s := range c.sessions {
|
||||
if s == ss {
|
||||
c.sessions = append(c.sessions[:i], c.sessions[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var errSessionDone = errors.New("session is done")
|
||||
|
@ -841,7 +879,7 @@ var errSessionDone = errors.New("session is done")
|
|||
// forwards agent connections between the listener and the ssh.Session.
|
||||
// On success, it assigns ss.agentListener.
|
||||
func (ss *sshSession) handleSSHAgentForwarding(s ssh.Session, lu *user.User) error {
|
||||
if !ssh.AgentRequested(ss) || !ss.action.AllowAgentForwarding {
|
||||
if !ssh.AgentRequested(ss) || !ss.conn.finalAction.AllowAgentForwarding {
|
||||
return nil
|
||||
}
|
||||
ss.logf("ssh: agent forwarding requested")
|
||||
|
@ -906,15 +944,15 @@ func (ss *sshSession) run() {
|
|||
ss.Exit(1)
|
||||
return
|
||||
}
|
||||
srv.startSessionLocked(ss)
|
||||
ss.conn.startSessionLocked(ss)
|
||||
srv.mu.Unlock()
|
||||
|
||||
defer srv.endSession(ss)
|
||||
defer ss.conn.endSession(ss)
|
||||
|
||||
if ss.action.SessionDuration != 0 {
|
||||
t := time.AfterFunc(ss.action.SessionDuration, func() {
|
||||
if ss.conn.finalAction.SessionDuration != 0 {
|
||||
t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() {
|
||||
ss.ctx.CloseWithError(userVisibleError{
|
||||
fmt.Sprintf("Session timeout of %v elapsed.", ss.action.SessionDuration),
|
||||
fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration),
|
||||
context.DeadlineExceeded,
|
||||
})
|
||||
})
|
||||
|
|
|
@ -238,9 +238,10 @@ func TestSSH(t *testing.T) {
|
|||
node: &tailcfg.Node{},
|
||||
uprof: &tailcfg.UserProfile{},
|
||||
}
|
||||
sc.finalAction = &tailcfg.SSHAction{Accept: true}
|
||||
|
||||
sc.Handler = func(s ssh.Session) {
|
||||
sc.newSSHSession(s, &tailcfg.SSHAction{Accept: true}).run()
|
||||
sc.newSSHSession(s).run()
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||
|
|
Loading…
Reference in New Issue