ssh/tailssh: use context.WithCancelCause
It was using a custom implmentation of the context.WithCancelCause, replace usage with stdlib. Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
a2be1aabfa
commit
e69682678f
|
@ -1,63 +0,0 @@
|
|||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package tailssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// sshContext is the context.Context implementation we use for SSH
|
||||
// that adds a CloseWithError method. Otherwise it's just a normalish
|
||||
// Context.
|
||||
type sshContext struct {
|
||||
underlying context.Context
|
||||
cancel context.CancelFunc // cancels underlying
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
err error
|
||||
}
|
||||
|
||||
func newSSHContext(ctx context.Context) *sshContext {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &sshContext{underlying: ctx, cancel: cancel}
|
||||
}
|
||||
|
||||
func (ctx *sshContext) CloseWithError(err error) {
|
||||
ctx.mu.Lock()
|
||||
defer ctx.mu.Unlock()
|
||||
if ctx.closed {
|
||||
return
|
||||
}
|
||||
ctx.closed = true
|
||||
ctx.err = err
|
||||
ctx.cancel()
|
||||
}
|
||||
|
||||
func (ctx *sshContext) Err() error {
|
||||
ctx.mu.Lock()
|
||||
defer ctx.mu.Unlock()
|
||||
return ctx.err
|
||||
}
|
||||
|
||||
func (ctx *sshContext) Done() <-chan struct{} { return ctx.underlying.Done() }
|
||||
func (ctx *sshContext) Deadline() (deadline time.Time, ok bool) { return }
|
||||
func (ctx *sshContext) Value(k any) any { return ctx.underlying.Value(k) }
|
||||
|
||||
// userVisibleError is a wrapper around an error that implements
|
||||
// SSHTerminationError, so msg is written to their session.
|
||||
type userVisibleError struct {
|
||||
msg string
|
||||
error
|
||||
}
|
||||
|
||||
func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
|
||||
|
||||
// SSHTerminationError is implemented by errors that terminate an SSH
|
||||
// session and should be written to user's sessions.
|
||||
type SSHTerminationError interface {
|
||||
error
|
||||
SSHTerminationMessage() string
|
||||
}
|
|
@ -787,7 +787,8 @@ type sshSession struct {
|
|||
sharedID string // ID that's shared with control
|
||||
logf logger.Logf
|
||||
|
||||
ctx *sshContext // implements context.Context
|
||||
ctx context.Context
|
||||
cancelCtx context.CancelCauseFunc
|
||||
conn *conn
|
||||
agentListener net.Listener // non-nil if agent-forwarding requested+allowed
|
||||
|
||||
|
@ -812,12 +813,14 @@ func (ss *sshSession) vlogf(format string, args ...interface{}) {
|
|||
func (c *conn) newSSHSession(s ssh.Session) *sshSession {
|
||||
sharedID := fmt.Sprintf("sess-%s-%02x", c.srv.now().UTC().Format("20060102T150405"), randBytes(5))
|
||||
c.logf("starting session: %v", sharedID)
|
||||
ctx, cancel := context.WithCancelCause(s.Context())
|
||||
return &sshSession{
|
||||
Session: s,
|
||||
sharedID: sharedID,
|
||||
ctx: newSSHContext(s.Context()),
|
||||
conn: c,
|
||||
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
|
||||
Session: s,
|
||||
sharedID: sharedID,
|
||||
ctx: ctx,
|
||||
cancelCtx: cancel,
|
||||
conn: c,
|
||||
logf: logger.WithPrefix(c.srv.logf, "ssh-session("+sharedID+"): "),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -844,7 +847,7 @@ func (c *conn) checkStillValid() {
|
|||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
for _, s := range c.sessions {
|
||||
s.ctx.CloseWithError(userVisibleError{
|
||||
s.cancelCtx(userVisibleError{
|
||||
fmt.Sprintf("Access revoked.\r\n"),
|
||||
context.Canceled,
|
||||
})
|
||||
|
@ -897,7 +900,7 @@ func (ss *sshSession) killProcessOnContextDone() {
|
|||
// Either the process has already exited, in which case this does nothing.
|
||||
// Or, the process is still running in which case this will kill it.
|
||||
ss.exitOnce.Do(func() {
|
||||
err := ss.ctx.Err()
|
||||
err := context.Cause(ss.ctx)
|
||||
if serr, ok := err.(SSHTerminationError); ok {
|
||||
msg := serr.SSHTerminationMessage()
|
||||
if msg != "" {
|
||||
|
@ -997,7 +1000,7 @@ var recordSSH = envknob.RegisterBool("TS_DEBUG_LOG_SSH")
|
|||
func (ss *sshSession) run() {
|
||||
metricActiveSessions.Add(1)
|
||||
defer metricActiveSessions.Add(-1)
|
||||
defer ss.ctx.CloseWithError(errSessionDone)
|
||||
defer ss.cancelCtx(errSessionDone)
|
||||
|
||||
if attached := ss.conn.srv.attachSessionToConnIfNotShutdown(ss); !attached {
|
||||
fmt.Fprintf(ss, "Tailscale SSH is shutting down\r\n")
|
||||
|
@ -1011,7 +1014,7 @@ func (ss *sshSession) run() {
|
|||
|
||||
if ss.conn.finalAction.SessionDuration != 0 {
|
||||
t := time.AfterFunc(ss.conn.finalAction.SessionDuration, func() {
|
||||
ss.ctx.CloseWithError(userVisibleError{
|
||||
ss.cancelCtx(userVisibleError{
|
||||
fmt.Sprintf("Session timeout of %v elapsed.", ss.conn.finalAction.SessionDuration),
|
||||
context.DeadlineExceeded,
|
||||
})
|
||||
|
@ -1066,7 +1069,7 @@ func (ss *sshSession) run() {
|
|||
defer ss.stdin.Close()
|
||||
if _, err := io.Copy(rec.writer("i", ss.stdin), ss); err != nil {
|
||||
logf("stdin copy: %v", err)
|
||||
ss.ctx.CloseWithError(err)
|
||||
ss.cancelCtx(err)
|
||||
}
|
||||
}()
|
||||
var openOutputStreams atomic.Int32
|
||||
|
@ -1080,7 +1083,7 @@ func (ss *sshSession) run() {
|
|||
_, err := io.Copy(rec.writer("o", ss), ss.stdout)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
logf("stdout copy: %v", err)
|
||||
ss.ctx.CloseWithError(err)
|
||||
ss.cancelCtx(err)
|
||||
}
|
||||
if openOutputStreams.Add(-1) == 0 {
|
||||
ss.CloseWrite()
|
||||
|
@ -1489,3 +1492,19 @@ var (
|
|||
metricSFTP = clientmetric.NewCounter("ssh_sftp_requests")
|
||||
metricLocalPortForward = clientmetric.NewCounter("ssh_local_port_forward_requests")
|
||||
)
|
||||
|
||||
// userVisibleError is a wrapper around an error that implements
|
||||
// SSHTerminationError, so msg is written to their session.
|
||||
type userVisibleError struct {
|
||||
msg string
|
||||
error
|
||||
}
|
||||
|
||||
func (ue userVisibleError) SSHTerminationMessage() string { return ue.msg }
|
||||
|
||||
// SSHTerminationError is implemented by errors that terminate an SSH
|
||||
// session and should be written to user's sessions.
|
||||
type SSHTerminationError interface {
|
||||
error
|
||||
SSHTerminationMessage() string
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue