diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index c99b32532..15ca13133 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -40,32 +40,39 @@ import ( // Handle handles an SSH connection from c. func Handle(logf logger.Logf, lb *ipnlocal.LocalBackend, c net.Conn) error { - sshd := &server{lb, logf} - srv := &ssh.Server{ - Handler: sshd.handleSSH, + srv := &server{lb, logf} + ss, err := srv.newSSHServer() + if err != nil { + return err + } + ss.HandleConn(c) + return nil +} + +func (srv *server) newSSHServer() (*ssh.Server, error) { + ss := &ssh.Server{ + Handler: srv.handleSSH, RequestHandlers: map[string]ssh.RequestHandler{}, SubsystemHandlers: map[string]ssh.SubsystemHandler{}, ChannelHandlers: map[string]ssh.ChannelHandler{}, } for k, v := range ssh.DefaultRequestHandlers { - srv.RequestHandlers[k] = v + ss.RequestHandlers[k] = v } for k, v := range ssh.DefaultChannelHandlers { - srv.ChannelHandlers[k] = v + ss.ChannelHandlers[k] = v } for k, v := range ssh.DefaultSubsystemHandlers { - srv.SubsystemHandlers[k] = v + ss.SubsystemHandlers[k] = v } - keys, err := lb.GetSSH_HostKeys() + keys, err := srv.lb.GetSSH_HostKeys() if err != nil { - return err + return nil, err } for _, signer := range keys { - srv.AddHostKey(signer) + ss.AddHostKey(signer) } - - srv.HandleConn(c) - return nil + return ss, nil } type server struct { diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 252c3ff30..be0b4febf 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -8,11 +8,24 @@ package tailssh import ( + "context" + "errors" + "fmt" + "net" + "os/exec" + "os/user" "testing" "time" + "github.com/gliderlabs/ssh" "inet.af/netaddr" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/net/tsdial" "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/logger" + "tailscale.com/wgengine" ) func TestMatchRule(t *testing.T) { @@ -155,3 +168,75 @@ func TestMatchRule(t *testing.T) { } func timePtr(t time.Time) *time.Time { return &t } + +func TestSSH(t *testing.T) { + ml := new(tstest.MemLogger) + var logf logger.Logf = ml.Logf + eng, err := wgengine.NewFakeUserspaceEngine(logf, 0) + if err != nil { + t.Fatal(err) + } + lb, err := ipnlocal.NewLocalBackend(logf, "", + new(ipn.MemoryStore), + new(tsdial.Dialer), + eng, 0) + if err != nil { + t.Fatal(err) + } + defer lb.Shutdown() + dir := t.TempDir() + lb.SetVarRoot(dir) + + srv := &server{lb, logf} + ss, err := srv.newSSHServer() + if err != nil { + t.Fatal(err) + } + + u, err := user.Current() + if err != nil { + t.Fatal(err) + } + + ci := &sshConnInfo{ + sshUser: "test", + srcIP: netaddr.MustParseIP("1.2.3.4"), + node: &tailcfg.Node{}, + uprof: &tailcfg.UserProfile{}, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ss.Handler = func(s ssh.Session) { + srv.handleAcceptedSSH(ctx, s, ci, u) + } + + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + port := ln.Addr().(*net.TCPAddr).Port + + go func() { + for { + c, err := ln.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + t.Errorf("Accept: %v", err) + } + return + } + go ss.HandleConn(c) + } + }() + + got, err := exec.Command("ssh", + "-p", fmt.Sprint(port), + "-o", "StrictHostKeyChecking=no", + "user@127.0.0.1", "env").CombinedOutput() + if err != nil { + t.Fatal(err) + } + t.Logf("Got: %s", got) +}