113 lines
2.7 KiB
Go
113 lines
2.7 KiB
Go
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package tailssh
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"sync"
|
|
|
|
"tailscale.com/tempfork/gliderlabs/ssh"
|
|
)
|
|
|
|
// readResult is a result from an io.Reader.Read call,
// as used by contextReader. ReadContext's background goroutine sends exactly
// one readResult per Read it performs, over contextReader.ch.
type readResult struct {
	buf []byte // bytes read, already sliced to the returned n; ownership passed on chan send
	err error  // error returned by that same Read call, if any
}
|
|
|
|
// contextReader wraps an io.Reader, providing a ReadContext method
// that can be aborted before yielding bytes. If it's aborted, the underlying
// Read keeps running in the background and subsequent reads can get those
// byte(s) later.
type contextReader struct {
	// r is the underlying reader; ReadContext calls r.Read from a
	// background goroutine.
	r io.Reader

	// buffered is leftover data from a previous read call that wasn't entirely
	// consumed.
	buffered []byte
	// readErr is a previous read error that was seen while filling buffered. It
	// should be returned to the caller after buffered is consumed.
	readErr error

	// Note that buffered and readErr above are NOT guarded by mu.
	mu sync.Mutex // guards ch only

	// ch is non-nil if a goroutine had been started and has a result to be
	// read. The goroutine may be either still running or done and has
	// sent to the channel.
	ch chan readResult
}
|
|
|
|
// HasOutstandingRead reports whether there's an outstanding Read call that's
|
|
// either currently blocked in a Read or whose result hasn't been consumed.
|
|
func (w *contextReader) HasOutstandingRead() bool {
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
return w.ch != nil
|
|
}
|
|
|
|
func (w *contextReader) setChan(c chan readResult) {
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
w.ch = c
|
|
}
|
|
|
|
// ReadContext is like Read, but takes a context permitting the read to be canceled.
|
|
//
|
|
// If the context becomes done, the underlying Read call continues and its result
|
|
// will be given to the next caller to ReadContext.
|
|
func (w *contextReader) ReadContext(ctx context.Context, p []byte) (n int, err error) {
|
|
if len(p) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
n = copy(p, w.buffered)
|
|
if n > 0 {
|
|
w.buffered = w.buffered[n:]
|
|
if len(w.buffered) == 0 {
|
|
err = w.readErr
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
if w.ch == nil {
|
|
ch := make(chan readResult, 1)
|
|
w.setChan(ch)
|
|
go func() {
|
|
rbuf := make([]byte, len(p))
|
|
n, err := w.r.Read(rbuf)
|
|
ch <- readResult{rbuf[:n], err}
|
|
}()
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return 0, ctx.Err()
|
|
case rr := <-w.ch:
|
|
w.setChan(nil)
|
|
n = copy(p, rr.buf)
|
|
w.buffered = rr.buf[n:]
|
|
w.readErr = rr.err
|
|
if len(w.buffered) == 0 {
|
|
err = rr.err
|
|
}
|
|
return n, err
|
|
}
|
|
}
|
|
|
|
// contextReaderSession implements ssh.Session, wrapping another
// ssh.Session but changing its Read method to use contextReader.
// Every other method is promoted unchanged from the embedded session.
type contextReaderSession struct {
	ssh.Session // the wrapped session; its Read is shadowed by contextReaderSession.Read

	cr *contextReader // holds results of earlier ReadContext calls that Read must deliver first
}
|
|
|
|
func (a contextReaderSession) Read(p []byte) (n int, err error) {
|
|
if a.cr.HasOutstandingRead() {
|
|
return a.cr.ReadContext(context.Background(), p)
|
|
}
|
|
return a.Session.Read(p)
|
|
}
|