cmd/tsconnect: add progress and connection callbacks

Allows UI to display slightly more fine-grained progress when the SSH
connection is being established.

Updates tailscale/corp#7186

Signed-off-by: Mihai Parparita <mihai@tailscale.com>
This commit is contained in:
Mihai Parparita 2022-10-17 15:02:59 -07:00 committed by Mihai Parparita
parent 246274b8e9
commit 7741e9feb0
4 changed files with 35 additions and 7 deletions

View File

@ -46,7 +46,12 @@ function SSHSession({
const ref = useRef<HTMLDivElement>(null) const ref = useRef<HTMLDivElement>(null)
useEffect(() => { useEffect(() => {
if (ref.current) { if (ref.current) {
runSSHSession(ref.current, def, ipn, onDone, (err) => console.error(err)) runSSHSession(ref.current, def, ipn, {
onConnectionProgress: (p) => console.log("Connection progress", p),
onConnected() {},
onError: (err) => console.error(err),
onDone,
})
} }
}, [ref]) }, [ref])

View File

@ -9,12 +9,18 @@ export type SSHSessionDef = {
timeoutSeconds?: number timeoutSeconds?: number
} }
export type SSHSessionCallbacks = {
onConnectionProgress: (messsage: string) => void
onConnected: () => void
onDone: () => void
onError?: (err: string) => void
}
export function runSSHSession( export function runSSHSession(
termContainerNode: HTMLDivElement, termContainerNode: HTMLDivElement,
def: SSHSessionDef, def: SSHSessionDef,
ipn: IPN, ipn: IPN,
onDone: () => void, callbacks: SSHSessionCallbacks,
onError?: (err: string) => void,
terminalOptions?: ITerminalOptions terminalOptions?: ITerminalOptions
) { ) {
const parentWindow = termContainerNode.ownerDocument.defaultView ?? window const parentWindow = termContainerNode.ownerDocument.defaultView ?? window
@ -49,7 +55,7 @@ export function runSSHSession(
term.write(input) term.write(input)
}, },
writeErrorFn(err) { writeErrorFn(err) {
onError?.(err) callbacks.onError?.(err)
term.write(err) term.write(err)
}, },
setReadFn(hook) { setReadFn(hook) {
@ -57,13 +63,15 @@ export function runSSHSession(
}, },
rows: term.rows, rows: term.rows,
cols: term.cols, cols: term.cols,
onConnectionProgress: callbacks.onConnectionProgress,
onConnected: callbacks.onConnected,
onDone() { onDone() {
resizeObserver?.disconnect() resizeObserver?.disconnect()
term.dispose() term.dispose()
if (handleUnload) { if (handleUnload) {
parentWindow.removeEventListener("unload", handleUnload) parentWindow.removeEventListener("unload", handleUnload)
} }
onDone() callbacks.onDone()
}, },
timeoutSeconds: def.timeoutSeconds, timeoutSeconds: def.timeoutSeconds,
}) })

View File

@ -25,6 +25,8 @@ declare global {
cols: number cols: number
/** Defaults to 5 seconds */ /** Defaults to 5 seconds */
timeoutSeconds?: number timeoutSeconds?: number
onConnectionProgress: (message: string) => void
onConnected: () => void
onDone: () => void onDone: () => void
} }
): IPNSSHSession ): IPNSSHSession

View File

@ -364,15 +364,21 @@ func (s *jsSSHSession) Run() {
if jsTimeoutSeconds := s.termConfig.Get("timeoutSeconds"); jsTimeoutSeconds.Type() == js.TypeNumber { if jsTimeoutSeconds := s.termConfig.Get("timeoutSeconds"); jsTimeoutSeconds.Type() == js.TypeNumber {
timeoutSeconds = jsTimeoutSeconds.Float() timeoutSeconds = jsTimeoutSeconds.Float()
} }
onConnectionProgress := s.termConfig.Get("onConnectionProgress")
onConnected := s.termConfig.Get("onConnected")
onDone := s.termConfig.Get("onDone") onDone := s.termConfig.Get("onDone")
defer onDone.Invoke() defer onDone.Invoke()
writeError := func(label string, err error) { writeError := func(label string, err error) {
writeErrorFn.Invoke(fmt.Sprintf("%s Error: %v\r\n", label, err)) writeErrorFn.Invoke(fmt.Sprintf("%s Error: %v\r\n", label, err))
} }
reportProgress := func(message string) {
onConnectionProgress.Invoke(message)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds*float64(time.Second))) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds*float64(time.Second)))
defer cancel() defer cancel()
reportProgress(fmt.Sprintf("Connecting to %s…", strings.Split(s.host, ".")[0]))
c, err := s.jsIPN.dialer.UserDial(ctx, "tcp", net.JoinHostPort(s.host, "22")) c, err := s.jsIPN.dialer.UserDial(ctx, "tcp", net.JoinHostPort(s.host, "22"))
if err != nil { if err != nil {
writeError("Dial", err) writeError("Dial", err)
@ -381,10 +387,16 @@ func (s *jsSSHSession) Run() {
defer c.Close() defer c.Close()
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(), HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
// Host keys are not used with Tailscale SSH, but we can use this
// callback to know that the connection has been established.
reportProgress("SSH connection established…")
return nil
},
User: s.username, User: s.username,
} }
reportProgress("Starting SSH client…")
sshConn, _, _, err := ssh.NewClientConn(c, s.host, config) sshConn, _, _, err := ssh.NewClientConn(c, s.host, config)
if err != nil { if err != nil {
writeError("SSH Connection", err) writeError("SSH Connection", err)
@ -442,6 +454,7 @@ func (s *jsSSHSession) Run() {
return return
} }
onConnected.Invoke()
err = session.Wait() err = session.Wait()
if err != nil { if err != nil {
writeError("Wait", err) writeError("Wait", err)