ssh/tailssh: add start of real ssh tests

Updates #3802

Change-Id: I9aea4250062d3a06ca7a5e71a81d31c27a988615
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-02-24 13:31:54 -08:00 committed by Brad Fitzpatrick
parent c9eca9451a
commit 6e4f3614cf
2 changed files with 104 additions and 12 deletions

View File

@ -40,32 +40,39 @@ import (
// Handle handles an SSH connection from c.
func Handle(logf logger.Logf, lb *ipnlocal.LocalBackend, c net.Conn) error {
sshd := &server{lb, logf}
srv := &ssh.Server{
Handler: sshd.handleSSH,
srv := &server{lb, logf}
ss, err := srv.newSSHServer()
if err != nil {
return err
}
ss.HandleConn(c)
return nil
}
func (srv *server) newSSHServer() (*ssh.Server, error) {
ss := &ssh.Server{
Handler: srv.handleSSH,
RequestHandlers: map[string]ssh.RequestHandler{},
SubsystemHandlers: map[string]ssh.SubsystemHandler{},
ChannelHandlers: map[string]ssh.ChannelHandler{},
}
for k, v := range ssh.DefaultRequestHandlers {
srv.RequestHandlers[k] = v
ss.RequestHandlers[k] = v
}
for k, v := range ssh.DefaultChannelHandlers {
srv.ChannelHandlers[k] = v
ss.ChannelHandlers[k] = v
}
for k, v := range ssh.DefaultSubsystemHandlers {
srv.SubsystemHandlers[k] = v
ss.SubsystemHandlers[k] = v
}
keys, err := lb.GetSSH_HostKeys()
keys, err := srv.lb.GetSSH_HostKeys()
if err != nil {
return err
return nil, err
}
for _, signer := range keys {
srv.AddHostKey(signer)
ss.AddHostKey(signer)
}
srv.HandleConn(c)
return nil
return ss, nil
}
type server struct {

View File

@ -8,11 +8,24 @@
package tailssh
import (
"context"
"errors"
"fmt"
"net"
"os/exec"
"os/user"
"testing"
"time"
"github.com/gliderlabs/ssh"
"inet.af/netaddr"
"tailscale.com/ipn"
"tailscale.com/ipn/ipnlocal"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/types/logger"
"tailscale.com/wgengine"
)
func TestMatchRule(t *testing.T) {
@ -155,3 +168,75 @@ func TestMatchRule(t *testing.T) {
}
func timePtr(t time.Time) *time.Time { return &t }
func TestSSH(t *testing.T) {
ml := new(tstest.MemLogger)
var logf logger.Logf = ml.Logf
eng, err := wgengine.NewFakeUserspaceEngine(logf, 0)
if err != nil {
t.Fatal(err)
}
lb, err := ipnlocal.NewLocalBackend(logf, "",
new(ipn.MemoryStore),
new(tsdial.Dialer),
eng, 0)
if err != nil {
t.Fatal(err)
}
defer lb.Shutdown()
dir := t.TempDir()
lb.SetVarRoot(dir)
srv := &server{lb, logf}
ss, err := srv.newSSHServer()
if err != nil {
t.Fatal(err)
}
u, err := user.Current()
if err != nil {
t.Fatal(err)
}
ci := &sshConnInfo{
sshUser: "test",
srcIP: netaddr.MustParseIP("1.2.3.4"),
node: &tailcfg.Node{},
uprof: &tailcfg.UserProfile{},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ss.Handler = func(s ssh.Session) {
srv.handleAcceptedSSH(ctx, s, ci, u)
}
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
port := ln.Addr().(*net.TCPAddr).Port
go func() {
for {
c, err := ln.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
t.Errorf("Accept: %v", err)
}
return
}
go ss.HandleConn(c)
}
}()
got, err := exec.Command("ssh",
"-p", fmt.Sprint(port),
"-o", "StrictHostKeyChecking=no",
"user@127.0.0.1", "env").CombinedOutput()
if err != nil {
t.Fatal(err)
}
t.Logf("Got: %s", got)
}