ssh/tailssh: add session recording test for non-pty sessions

Updates tailscale/corp#9967

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2023-03-24 13:50:33 -07:00 committed by Maisem Ali
parent d39a5e4417
commit 09d0b632d4
1 changed files with 96 additions and 2 deletions

View File

@ -7,6 +7,7 @@ package tailssh
import ( import (
"bytes" "bytes"
"context"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
@ -14,6 +15,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -324,9 +326,101 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
} }
} }
// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
// when the client is not interactive (i.e. no PTY).
// It starts a local SSH server and a recording server. The recording server
// records the SSH session and returns it to the test.
// The test then verifies that the recording has a valid CastHeader, it does not
// validate the contents of the recording.
func TestSSHRecordingNonInteractive(t *testing.T) {
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
}
var recording []byte
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer cancel()
var err error
recording, err = ioutil.ReadAll(r.Body)
if err != nil {
t.Error(err)
return
}
w.WriteHeader(http.StatusOK)
}))
defer recordingServer.Close()
state := &localState{
sshEnabled: true,
matchingRule: newSSHRule(
&tailcfg.SSHAction{
Accept: true,
Recorders: []netip.AddrPort{
must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())),
},
},
),
}
s := &server{
logf: t.Logf,
httpc: recordingServer.Client(),
}
defer s.Shutdown()
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22"))
sc, dc := memnet.NewTCPConn(src, dst, 1024)
s.lb = state
const sshUser = "alice"
cfg := &gossh.ClientConfig{
User: sshUser,
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg)
if err != nil {
t.Errorf("client: %v", err)
return
}
client := gossh.NewClient(c, chans, reqs)
defer client.Close()
session, err := client.NewSession()
if err != nil {
t.Errorf("client: %v", err)
return
}
defer session.Close()
t.Logf("client established session")
_, err = session.CombinedOutput("echo Ran echo!")
if err != nil {
t.Errorf("client: %v", err)
}
}()
if err := s.HandleSSHConn(dc); err != nil {
t.Errorf("unexpected error: %v", err)
}
wg.Wait()
<-ctx.Done() // wait for recording to finish
var ch CastHeader
if err := json.NewDecoder(bytes.NewReader(recording)).Decode(&ch); err != nil {
t.Fatal(err)
}
if ch.SSHUser != sshUser {
t.Errorf("SSHUser = %q; want %q", ch.SSHUser, sshUser)
}
if ch.Command != "echo Ran echo!" {
t.Errorf("Command = %q; want %q", ch.Command, "echo Ran echo!")
}
}
func TestSSHAuthFlow(t *testing.T) { func TestSSHAuthFlow(t *testing.T) {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
t.Skip("Not running on Linux, skipping") t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
} }
acceptRule := newSSHRule(&tailcfg.SSHAction{ acceptRule := newSSHRule(&tailcfg.SSHAction{
Accept: true, Accept: true,