ssh/tailssh: add start of real ssh tests
Updates #3802 Change-Id: I9aea4250062d3a06ca7a5e71a81d31c27a988615 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
c9eca9451a
commit
6e4f3614cf
|
@ -40,32 +40,39 @@ import (
|
|||
|
||||
// Handle handles an SSH connection from c.
|
||||
func Handle(logf logger.Logf, lb *ipnlocal.LocalBackend, c net.Conn) error {
|
||||
sshd := &server{lb, logf}
|
||||
srv := &ssh.Server{
|
||||
Handler: sshd.handleSSH,
|
||||
srv := &server{lb, logf}
|
||||
ss, err := srv.newSSHServer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ss.HandleConn(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (srv *server) newSSHServer() (*ssh.Server, error) {
|
||||
ss := &ssh.Server{
|
||||
Handler: srv.handleSSH,
|
||||
RequestHandlers: map[string]ssh.RequestHandler{},
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{},
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{},
|
||||
}
|
||||
for k, v := range ssh.DefaultRequestHandlers {
|
||||
srv.RequestHandlers[k] = v
|
||||
ss.RequestHandlers[k] = v
|
||||
}
|
||||
for k, v := range ssh.DefaultChannelHandlers {
|
||||
srv.ChannelHandlers[k] = v
|
||||
ss.ChannelHandlers[k] = v
|
||||
}
|
||||
for k, v := range ssh.DefaultSubsystemHandlers {
|
||||
srv.SubsystemHandlers[k] = v
|
||||
ss.SubsystemHandlers[k] = v
|
||||
}
|
||||
keys, err := lb.GetSSH_HostKeys()
|
||||
keys, err := srv.lb.GetSSH_HostKeys()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
for _, signer := range keys {
|
||||
srv.AddHostKey(signer)
|
||||
ss.AddHostKey(signer)
|
||||
}
|
||||
|
||||
srv.HandleConn(c)
|
||||
return nil
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
type server struct {
|
||||
|
|
|
@ -8,11 +8,24 @@
|
|||
package tailssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine"
|
||||
)
|
||||
|
||||
func TestMatchRule(t *testing.T) {
|
||||
|
@ -155,3 +168,75 @@ func TestMatchRule(t *testing.T) {
|
|||
}
|
||||
|
||||
func timePtr(t time.Time) *time.Time { return &t }
|
||||
|
||||
func TestSSH(t *testing.T) {
|
||||
ml := new(tstest.MemLogger)
|
||||
var logf logger.Logf = ml.Logf
|
||||
eng, err := wgengine.NewFakeUserspaceEngine(logf, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
lb, err := ipnlocal.NewLocalBackend(logf, "",
|
||||
new(ipn.MemoryStore),
|
||||
new(tsdial.Dialer),
|
||||
eng, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer lb.Shutdown()
|
||||
dir := t.TempDir()
|
||||
lb.SetVarRoot(dir)
|
||||
|
||||
srv := &server{lb, logf}
|
||||
ss, err := srv.newSSHServer()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ci := &sshConnInfo{
|
||||
sshUser: "test",
|
||||
srcIP: netaddr.MustParseIP("1.2.3.4"),
|
||||
node: &tailcfg.Node{},
|
||||
uprof: &tailcfg.UserProfile{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ss.Handler = func(s ssh.Session) {
|
||||
srv.handleAcceptedSSH(ctx, s, ci, u)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
port := ln.Addr().(*net.TCPAddr).Port
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
t.Errorf("Accept: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
go ss.HandleConn(c)
|
||||
}
|
||||
}()
|
||||
|
||||
got, err := exec.Command("ssh",
|
||||
"-p", fmt.Sprint(port),
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"user@127.0.0.1", "env").CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("Got: %s", got)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue