ssh/tailssh: terminate sessions when tailscaled shutsdown

Ideally we would re-establish these sessions when tailscaled comes back
up, however we do not do that yet so this is better than leaking the
sessions.

Updates #3802

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2022-05-28 04:33:46 -07:00 committed by Maisem Ali
parent 760740905e
commit 7cd8c3e839
2 changed files with 52 additions and 9 deletions

View File

@ -81,6 +81,9 @@ type SSHServer interface {
// so that existing sessions can be re-evaluated for validity // so that existing sessions can be re-evaluated for validity
// and closed if they'd no longer be accepted. // and closed if they'd no longer be accepted.
OnPolicyChange() OnPolicyChange()
// Shutdown is called when tailscaled is shutting down.
Shutdown()
} }
type newSSHServerFunc func(logger.Logf, *LocalBackend) (SSHServer, error) type newSSHServerFunc func(logger.Logf, *LocalBackend) (SSHServer, error)
@ -346,6 +349,9 @@ func (b *LocalBackend) Shutdown() {
b.mu.Lock() b.mu.Lock()
b.shutdownCalled = true b.shutdownCalled = true
cc := b.cc cc := b.cc
if b.sshServer != nil {
b.sshServer.Shutdown()
}
b.closePeerAPIListenersLocked() b.closePeerAPIListenersLocked()
b.mu.Unlock() b.mu.Unlock()

View File

@ -58,11 +58,14 @@ type server struct {
pubKeyHTTPClient *http.Client // or nil for http.DefaultClient pubKeyHTTPClient *http.Client // or nil for http.DefaultClient
timeNow func() time.Time // or nil for time.Now timeNow func() time.Time // or nil for time.Now
sessionWaitGroup sync.WaitGroup
// mu protects the following // mu protects the following
mu sync.Mutex mu sync.Mutex
activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session activeSessionByH map[string]*sshSession // ssh.SessionID (DH H) => session
activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session activeSessionBySharedID map[string]*sshSession // yyymmddThhmmss-XXXXX => session
fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL
shutdownCalled bool
} }
func (srv *server) now() time.Time { func (srv *server) now() time.Time {
@ -101,6 +104,20 @@ func (srv *server) HandleSSHConn(c net.Conn) error {
return nil return nil
} }
// Shutdown terminates all active sessions.
func (srv *server) Shutdown() {
srv.mu.Lock()
srv.shutdownCalled = true
for _, s := range srv.activeSessionByH {
s.ctx.CloseWithError(userVisibleError{
fmt.Sprintf("Tailscale shutting down.\r\n"),
context.Canceled,
})
}
srv.mu.Unlock()
srv.sessionWaitGroup.Wait()
}
// OnPolicyChange terminates any active sessions that no longer match // OnPolicyChange terminates any active sessions that no longer match
// the SSH access policy. // the SSH access policy.
func (srv *server) OnPolicyChange() { func (srv *server) OnPolicyChange() {
@ -227,6 +244,15 @@ func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig {
} }
func (srv *server) newConn() (*conn, error) { func (srv *server) newConn() (*conn, error) {
srv.mu.Lock()
shutdownCalled := srv.shutdownCalled
srv.mu.Unlock()
if shutdownCalled {
// Stop accepting new connections.
// Connections in the auth phase are handled in handleConnPostSSHAuth.
// Existing sessions are terminated by Shutdown.
return nil, gossh.ErrDenied
}
c := &conn{srv: srv, now: srv.now()} c := &conn{srv: srv, now: srv.now()}
c.Server = &ssh.Server{ c.Server = &ssh.Server{
Version: "Tailscale", Version: "Tailscale",
@ -756,10 +782,10 @@ func (srv *server) getSessionForContext(sctx ssh.Context) (ss *sshSession, ok bo
return return
} }
// startSession registers ss as an active session. // startSessionLocked registers ss as an active session.
func (srv *server) startSession(ss *sshSession) { // It must be called with srv.mu held.
srv.mu.Lock() func (srv *server) startSessionLocked(ss *sshSession) {
defer srv.mu.Unlock() srv.sessionWaitGroup.Add(1)
if ss.idH == "" { if ss.idH == "" {
panic("empty idH") panic("empty idH")
} }
@ -778,6 +804,7 @@ func (srv *server) startSession(ss *sshSession) {
// endSession unregisters s from the list of active sessions. // endSession unregisters s from the list of active sessions.
func (srv *server) endSession(ss *sshSession) { func (srv *server) endSession(ss *sshSession) {
defer srv.sessionWaitGroup.Done()
srv.mu.Lock() srv.mu.Lock()
defer srv.mu.Unlock() defer srv.mu.Unlock()
delete(srv.activeSessionByH, ss.idH) delete(srv.activeSessionByH, ss.idH)
@ -842,11 +869,21 @@ var recordSSH = envknob.Bool("TS_DEBUG_LOG_SSH")
// It handles ss once it's been accepted and determined // It handles ss once it's been accepted and determined
// that it should run. // that it should run.
func (ss *sshSession) run() { func (ss *sshSession) run() {
srv := ss.conn.srv
srv.startSession(ss)
defer srv.endSession(ss)
defer ss.ctx.CloseWithError(errSessionDone) defer ss.ctx.CloseWithError(errSessionDone)
srv := ss.conn.srv
srv.mu.Lock()
if srv.shutdownCalled {
srv.mu.Unlock()
// Do not start any new sessions.
fmt.Fprintf(ss, "Tailscale is shutting down\r\n")
ss.Exit(1)
return
}
srv.startSessionLocked(ss)
srv.mu.Unlock()
defer srv.endSession(ss)
if ss.action.SessionDuration != 0 { if ss.action.SessionDuration != 0 {
t := time.AfterFunc(ss.action.SessionDuration, func() { t := time.AfterFunc(ss.action.SessionDuration, func() {
@ -858,7 +895,7 @@ func (ss *sshSession) run() {
defer t.Stop() defer t.Stop()
} }
logf := srv.logf logf := ss.logf
lu := ss.conn.localUser lu := ss.conn.localUser
localUser := lu.Username localUser := lu.Username