137 lines
4.4 KiB
Go
137 lines
4.4 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package tailssh
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptrace"
|
|
"net/netip"
|
|
"time"
|
|
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/util/multierr"
|
|
)
|
|
|
|
// ConnectToRecorder connects to the recorder at any of the provided addresses.
|
|
// It returns the first successful response, or a multierr if all attempts fail.
|
|
//
|
|
// On success, it returns a WriteCloser that can be used to upload the
|
|
// recording, and a channel that will be sent an error (or nil) when the upload
|
|
// fails or completes.
|
|
//
|
|
// In both cases, a slice of SSHRecordingAttempts is returned which detail the
|
|
// attempted recorder IP and the error message, if the attempt failed. The
|
|
// attempts are in order the recorder(s) was attempted. If successful a
|
|
// successful connection is made, the last attempt in the slice is the
|
|
// attempt for connected recorder.
|
|
func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) {
|
|
if len(recs) == 0 {
|
|
return nil, nil, nil, errors.New("no recorders configured")
|
|
}
|
|
// We use a special context for dialing the recorder, so that we can
|
|
// limit the time we spend dialing to 30 seconds and still have an
|
|
// unbounded context for the upload.
|
|
dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second)
|
|
defer dialCancel()
|
|
hc, err := SessionRecordingClientForDialer(dialCtx, dial)
|
|
if err != nil {
|
|
return nil, nil, nil, err
|
|
}
|
|
|
|
var errs []error
|
|
var attempts []*tailcfg.SSHRecordingAttempt
|
|
for _, ap := range recs {
|
|
attempt := &tailcfg.SSHRecordingAttempt{
|
|
Recorder: ap,
|
|
}
|
|
attempts = append(attempts, attempt)
|
|
|
|
// We dial the recorder and wait for it to send a 100-continue
|
|
// response before returning from this function. This ensures that
|
|
// the recorder is ready to accept the recording.
|
|
|
|
// got100 is closed when we receive the 100-continue response.
|
|
got100 := make(chan struct{})
|
|
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
|
|
Got100Continue: func() {
|
|
close(got100)
|
|
},
|
|
})
|
|
|
|
pr, pw := io.Pipe()
|
|
req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr)
|
|
if err != nil {
|
|
err = fmt.Errorf("recording: error starting recording: %w", err)
|
|
attempt.FailureMessage = err.Error()
|
|
errs = append(errs, err)
|
|
continue
|
|
}
|
|
// We set the Expect header to 100-continue, so that the recorder
|
|
// will send a 100-continue response before it starts reading the
|
|
// request body.
|
|
req.Header.Set("Expect", "100-continue")
|
|
|
|
// errChan is used to indicate the result of the request.
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
resp, err := hc.Do(req)
|
|
if err != nil {
|
|
errChan <- fmt.Errorf("recording: error starting recording: %w", err)
|
|
return
|
|
}
|
|
if resp.StatusCode != 200 {
|
|
errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status)
|
|
return
|
|
}
|
|
errChan <- nil
|
|
}()
|
|
select {
|
|
case <-got100:
|
|
case err := <-errChan:
|
|
// If we get an error before we get the 100-continue response,
|
|
// we need to try another recorder.
|
|
if err == nil {
|
|
// If the error is nil, we got a 200 response, which
|
|
// is unexpected as we haven't sent any data yet.
|
|
err = errors.New("recording: unexpected EOF")
|
|
}
|
|
attempt.FailureMessage = err.Error()
|
|
errs = append(errs, err)
|
|
continue
|
|
}
|
|
return pw, attempts, errChan, nil
|
|
}
|
|
return nil, attempts, nil, multierr.New(errs...)
|
|
}
|
|
|
|
// SessionRecordingClientForDialer returns an http.Client that dials
// connections with the provided dial function, using a clone of
// http.DefaultTransport for all other transport settings. It is used to make
// requests to the session recording server to upload session recordings.
//
// Each individual dial is limited to 5 seconds and is additionally canceled
// as soon as dialCtx is done, so dialCtx bounds how long connection
// establishment may take. dialCtx is consulted only while dialing; it does
// not bound requests on connections that are already established.
//
// It returns an error if http.DefaultTransport is not an *http.Transport and
// therefore cannot be cloned.
func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.Context, string, string) (net.Conn, error)) (*http.Client, error) {
	base, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		// Report instead of panicking on the type assertion; callers
		// already handle the error return.
		return nil, fmt.Errorf("recording: http.DefaultTransport is %T, not *http.Transport", http.DefaultTransport)
	}
	tr := base.Clone()

	tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
		defer cancel()
		// Also cancel this dial if dialCtx finishes first. The goroutine
		// exits once perAttemptCtx is done, which the deferred cancel
		// guarantees when the dial returns.
		go func() {
			select {
			case <-perAttemptCtx.Done():
			case <-dialCtx.Done():
				cancel()
			}
		}()
		return dial(perAttemptCtx, network, addr)
	}
	return &http.Client{
		Transport: tr,
	}, nil
}
|