diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 6f960cc71..4178c1e06 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -203,16 +203,19 @@ func (srv *server) handleAcceptedSSH(ctx context.Context, s ssh.Session, ci *ssh return } cmd = exec.Command(loginShell(lu.Uid)) + if rawCmd := s.RawCommand(); rawCmd != "" { + cmd.Args = append(cmd.Args, "-c", rawCmd) + } } else { if rawCmd := s.RawCommand(); rawCmd != "" { cmd = exec.Command("/usr/bin/env", "su", "-c", rawCmd, localUser) - cmd.Dir = lu.HomeDir - cmd.Env = append(cmd.Env, envForUser(lu)...) // TODO: and Env for PATH, SSH_CONNECTION, SSH_CLIENT, XDG_SESSION_TYPE, XDG_*, etc } else { cmd = exec.Command("/usr/bin/env", "su", "-", localUser) } } + cmd.Dir = lu.HomeDir + cmd.Env = append(cmd.Env, envForUser(lu)...) if ptyReq.Term != "" { cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) } @@ -397,7 +400,7 @@ func loginShell(uid string) string { // out is "root:x:0:0:root:/root:/bin/bash" f := strings.SplitN(string(out), ":", 10) if len(f) > 6 { - return f[6] // shell + return strings.TrimSpace(f[6]) // shell } } if e := os.Getenv("SHELL"); e != "" {