From e69682678fc73787dedb6c9198d101b9307e9fb0 Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Wed, 8 Mar 2023 22:38:29 -0800 Subject: [PATCH] ssh/tailssh: use context.WithCancelCause It was using a custom implmentation of the context.WithCancelCause, replace usage with stdlib. Signed-off-by: Maisem Ali --- ssh/tailssh/context.go | 63 ------------------------------------------ ssh/tailssh/tailssh.go | 43 ++++++++++++++++++++-------- 2 files changed, 31 insertions(+), 75 deletions(-) delete mode 100644 ssh/tailssh/context.go diff --git a/ssh/tailssh/context.go b/ssh/tailssh/context.go deleted file mode 100644 index de3886da4..000000000 --- a/ssh/tailssh/context.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package tailssh - -import ( - "context" - "sync" - "time" -) - -// sshContext is the context.Context implementation we use for SSH -// that adds a CloseWithError method. Otherwise it's just a normalish -// Context. -type sshContext struct { - underlying context.Context - cancel context.CancelFunc // cancels underlying - mu sync.Mutex - closed bool - err error -} - -func newSSHContext(ctx context.Context) *sshContext { - ctx, cancel := context.WithCancel(ctx) - return &sshContext{underlying: ctx, cancel: cancel} -} - -func (ctx *sshContext) CloseWithError(err error) { - ctx.mu.Lock() - defer ctx.mu.Unlock() - if ctx.closed { - return - } - ctx.closed = true - ctx.err = err - ctx.cancel() -} - -func (ctx *sshContext) Err() error { - ctx.mu.Lock() - defer ctx.mu.Unlock() - return ctx.err -} - -func (ctx *sshContext) Done() <-chan struct{} { return ctx.underlying.Done() } -func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return } -func (ctx *sshContext) Value(k any) any { return ctx.underlying.Value(k) } - -// userVisibleError is a wrapper around an error that implements -// SSHTerminationError, so msg is written to their session. -type userVisibleError struct { - msg string - error -} - -func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg } - -// SSHTerminationError is implemented by errors that terminate an SSH -// session and should be written to user's sessions. -type SSHTerminationError interface { - error - SSHTerminationMessage() string -} diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index f4167ffbe..2dce300bb 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -787,7 +787,8 @@ type sshSession struct { sharedID string // ID that's shared with control logf logger.Logf - ctx *sshContext // implements context.Context + ctx context.Context + cancelCtx context.CancelCauseFunc conn *conn agentListener net.Listener // non-nil if agent-forwarding requested+allowed @@ -812,12 +813,14 @@ func (ss *sshSession) vlogf(format string, args ...interface{}) { func (c *conn) newSSHSession(s ssh.Session) *sshSession { sharedID := fmt.Sprintf("sess-%s-%02x", c.srv.now().UTC().Format("20060102T150405"), randBytes(5)) c.logf("starting session: %v", sharedID) + ctx, cancel := context.WithCancelCause(s.Context()) return &sshSession{ - Session: s, - sharedID: sharedID, - ctx: newSSHContext(s.Context()), - conn: c, - logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "), + Session: s, + sharedID: sharedID, + ctx: ctx, + cancelCtx: cancel, + conn: c, + logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "), } } @@ -844,7 +847,7 @@ func (c *conn) checkStillValid() { c.mu.Lock() defer c.mu.Unlock() for _, s := range c.sessions { - s.ctx.CloseWithError(userVisibleError{ + s.cancelCtx(userVisibleError{ fmt.Sprintf("Access revoked.\r\n"), context.Canceled, }) @@ -897,7 +900,7 @@ func (ss *sshSession) killProcessOnContextDone() { // Either the process has already exited, in which case this does nothing. // Or, the process is still running in which case this will kill it. ss.exitOnce.Do(func() { - err := ss.ctx.Err() + err := context.Cause(ss.ctx) if serr, ok := err.(SSHTerminationError); ok { msg := serr.SSHTerminationMessage() if msg != "" { @@ -997,7 +1000,7 @@ var recordSSH = envknob.RegisterBool("TS_DEBUG_LOG_SSH") func (ss *sshSession) run() { metricActiveSessions.Add(1) defer metricActiveSessions.Add(-1) - defer ss.ctx.CloseWithError(errSessionDone) + defer ss.cancelCtx(errSessionDone) if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached { fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n") @@ -1011,7 +1014,7 @@ func (ss *sshSession) run() { if ss.conn.finalAction.SessionDuration != 0 { t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() { - ss.ctx.CloseWithError(userVisibleError{ + ss.cancelCtx(userVisibleError{ fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration), context.DeadlineExceeded, }) @@ -1066,7 +1069,7 @@ func (ss *sshSession) run() { defer ss.stdin.Close() if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil { logf("stdin copy: %v", err) - ss.ctx.CloseWithError(err) + ss.cancelCtx(err) } }() var openOutputStreams atomic.Int32 @@ -1080,7 +1083,7 @@ func (ss *sshSession) run() { _, err := io.Copy(rec.writer("o", ss), ss.stdout) if err != nil && !errors.Is(err, io.EOF) { logf("stdout copy: %v", err) - ss.ctx.CloseWithError(err) + ss.cancelCtx(err) } if openOutputStreams.Add(-1) == 0 { ss.CloseWrite() @@ -1489,3 +1492,19 @@ var ( metricSFTP = clientmetric.NewCounter("ssh_sftp_requests") metricLocalPortForward = clientmetric.NewCounter("ssh_local_port_forward_requests") ) + +// userVisibleError is a wrapper around an error that implements +// SSHTerminationError, so msg is written to their session. +type userVisibleError struct { + msg string + error +} + +func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg } + +// SSHTerminationError is implemented by errors that terminate an SSH +// session and should be written to user's sessions. +type SSHTerminationError interface { + error + SSHTerminationMessage() string +}