cmd/tsconnect: add progress and connection callbacks
Allows UI to display slightly more fine-grained progress when the SSH connection is being established. Updates tailscale/corp#7186 Signed-off-by: Mihai Parparita <mihai@tailscale.com>
This commit is contained in:
parent
246274b8e9
commit
7741e9feb0
|
@ -46,7 +46,12 @@ function SSHSession({
|
||||||
const ref = useRef<HTMLDivElement>(null)
|
const ref = useRef<HTMLDivElement>(null)
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (ref.current) {
|
if (ref.current) {
|
||||||
runSSHSession(ref.current, def, ipn, onDone, (err) => console.error(err))
|
runSSHSession(ref.current, def, ipn, {
|
||||||
|
onConnectionProgress: (p) => console.log("Connection progress", p),
|
||||||
|
onConnected() {},
|
||||||
|
onError: (err) => console.error(err),
|
||||||
|
onDone,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}, [ref])
|
}, [ref])
|
||||||
|
|
||||||
|
|
|
@ -9,12 +9,18 @@ export type SSHSessionDef = {
|
||||||
timeoutSeconds?: number
|
timeoutSeconds?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type SSHSessionCallbacks = {
|
||||||
|
onConnectionProgress: (messsage: string) => void
|
||||||
|
onConnected: () => void
|
||||||
|
onDone: () => void
|
||||||
|
onError?: (err: string) => void
|
||||||
|
}
|
||||||
|
|
||||||
export function runSSHSession(
|
export function runSSHSession(
|
||||||
termContainerNode: HTMLDivElement,
|
termContainerNode: HTMLDivElement,
|
||||||
def: SSHSessionDef,
|
def: SSHSessionDef,
|
||||||
ipn: IPN,
|
ipn: IPN,
|
||||||
onDone: () => void,
|
callbacks: SSHSessionCallbacks,
|
||||||
onError?: (err: string) => void,
|
|
||||||
terminalOptions?: ITerminalOptions
|
terminalOptions?: ITerminalOptions
|
||||||
) {
|
) {
|
||||||
const parentWindow = termContainerNode.ownerDocument.defaultView ?? window
|
const parentWindow = termContainerNode.ownerDocument.defaultView ?? window
|
||||||
|
@ -49,7 +55,7 @@ export function runSSHSession(
|
||||||
term.write(input)
|
term.write(input)
|
||||||
},
|
},
|
||||||
writeErrorFn(err) {
|
writeErrorFn(err) {
|
||||||
onError?.(err)
|
callbacks.onError?.(err)
|
||||||
term.write(err)
|
term.write(err)
|
||||||
},
|
},
|
||||||
setReadFn(hook) {
|
setReadFn(hook) {
|
||||||
|
@ -57,13 +63,15 @@ export function runSSHSession(
|
||||||
},
|
},
|
||||||
rows: term.rows,
|
rows: term.rows,
|
||||||
cols: term.cols,
|
cols: term.cols,
|
||||||
|
onConnectionProgress: callbacks.onConnectionProgress,
|
||||||
|
onConnected: callbacks.onConnected,
|
||||||
onDone() {
|
onDone() {
|
||||||
resizeObserver?.disconnect()
|
resizeObserver?.disconnect()
|
||||||
term.dispose()
|
term.dispose()
|
||||||
if (handleUnload) {
|
if (handleUnload) {
|
||||||
parentWindow.removeEventListener("unload", handleUnload)
|
parentWindow.removeEventListener("unload", handleUnload)
|
||||||
}
|
}
|
||||||
onDone()
|
callbacks.onDone()
|
||||||
},
|
},
|
||||||
timeoutSeconds: def.timeoutSeconds,
|
timeoutSeconds: def.timeoutSeconds,
|
||||||
})
|
})
|
||||||
|
|
|
@ -25,6 +25,8 @@ declare global {
|
||||||
cols: number
|
cols: number
|
||||||
/** Defaults to 5 seconds */
|
/** Defaults to 5 seconds */
|
||||||
timeoutSeconds?: number
|
timeoutSeconds?: number
|
||||||
|
onConnectionProgress: (message: string) => void
|
||||||
|
onConnected: () => void
|
||||||
onDone: () => void
|
onDone: () => void
|
||||||
}
|
}
|
||||||
): IPNSSHSession
|
): IPNSSHSession
|
||||||
|
|
|
@ -364,15 +364,21 @@ func (s *jsSSHSession) Run() {
|
||||||
if jsTimeoutSeconds := s.termConfig.Get("timeoutSeconds"); jsTimeoutSeconds.Type() == js.TypeNumber {
|
if jsTimeoutSeconds := s.termConfig.Get("timeoutSeconds"); jsTimeoutSeconds.Type() == js.TypeNumber {
|
||||||
timeoutSeconds = jsTimeoutSeconds.Float()
|
timeoutSeconds = jsTimeoutSeconds.Float()
|
||||||
}
|
}
|
||||||
|
onConnectionProgress := s.termConfig.Get("onConnectionProgress")
|
||||||
|
onConnected := s.termConfig.Get("onConnected")
|
||||||
onDone := s.termConfig.Get("onDone")
|
onDone := s.termConfig.Get("onDone")
|
||||||
defer onDone.Invoke()
|
defer onDone.Invoke()
|
||||||
|
|
||||||
writeError := func(label string, err error) {
|
writeError := func(label string, err error) {
|
||||||
writeErrorFn.Invoke(fmt.Sprintf("%s Error: %v\r\n", label, err))
|
writeErrorFn.Invoke(fmt.Sprintf("%s Error: %v\r\n", label, err))
|
||||||
}
|
}
|
||||||
|
reportProgress := func(message string) {
|
||||||
|
onConnectionProgress.Invoke(message)
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds*float64(time.Second)))
|
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSeconds*float64(time.Second)))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
reportProgress(fmt.Sprintf("Connecting to %s…", strings.Split(s.host, ".")[0]))
|
||||||
c, err := s.jsIPN.dialer.UserDial(ctx, "tcp", net.JoinHostPort(s.host, "22"))
|
c, err := s.jsIPN.dialer.UserDial(ctx, "tcp", net.JoinHostPort(s.host, "22"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError("Dial", err)
|
writeError("Dial", err)
|
||||||
|
@ -381,10 +387,16 @@ func (s *jsSSHSession) Run() {
|
||||||
defer c.Close()
|
defer c.Close()
|
||||||
|
|
||||||
config := &ssh.ClientConfig{
|
config := &ssh.ClientConfig{
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||||
|
// Host keys are not used with Tailscale SSH, but we can use this
|
||||||
|
// callback to know that the connection has been established.
|
||||||
|
reportProgress("SSH connection established…")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
User: s.username,
|
User: s.username,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
reportProgress("Starting SSH client…")
|
||||||
sshConn, _, _, err := ssh.NewClientConn(c, s.host, config)
|
sshConn, _, _, err := ssh.NewClientConn(c, s.host, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError("SSH Connection", err)
|
writeError("SSH Connection", err)
|
||||||
|
@ -442,6 +454,7 @@ func (s *jsSSHSession) Run() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
onConnected.Invoke()
|
||||||
err = session.Wait()
|
err = session.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError("Wait", err)
|
writeError("Wait", err)
|
||||||
|
|
Loading…
Reference in New Issue