tailscale/util/winutil/restartmgr_windows.go

843 lines
25 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package winutil
import (
"bytes"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"unicode/utf16"
"unsafe"
"github.com/dblohm7/wingoes"
"golang.org/x/sys/windows"
"tailscale.com/types/logger"
"tailscale.com/util/multierr"
)
var (
	// ErrDefunctProcess is returned by (*UniqueProcess).AsRestartableProcess
	// when the process no longer exists. This is detected when the process
	// currently using the PID has a creation time that differs from the one
	// recorded in the UniqueProcess (i.e. the PID has been reused).
	ErrDefunctProcess = errors.New("process is defunct")
	// ErrProcessNotRestartable is returned by (*UniqueProcess).AsRestartableProcess
	// when the process has previously indicated that it must not be restarted
	// during a patch/upgrade (its restart settings include _RESTART_NO_PATCH).
	ErrProcessNotRestartable = errors.New("process is not restartable")
)
// Implementation note: the code in this file will be invoked from within
// MSI custom actions, so please try to return windows.Errno error codes
// whenever possible; this makes the action return more accurate errors to
// the installer engine.
const (
	// _RESTART_NO_CRASH is set by registerForRestart when RestartOnCrash is false.
	_RESTART_NO_CRASH = 1
	// _RESTART_NO_HANG is set by registerForRestart when RestartOnHang is false.
	_RESTART_NO_HANG = 2
	// _RESTART_NO_PATCH is set by registerForRestart when RestartOnUpgrade is
	// false. AsRestartableProcess refuses processes that have this flag set.
	_RESTART_NO_PATCH = 4
	// _RESTART_NO_REBOOT is set by registerForRestart when RestartOnReboot is false.
	_RESTART_NO_REBOOT = 8
)
// registerForRestart registers the current process for restart according to
// opts.
//
// Every restart reason that opts leaves disabled is translated into the
// matching _RESTART_NO_* opt-out flag. When opts.UseCmdLineArgs is set, the
// restart command line is built from opts.CmdLineArgs (falling back to
// os.Args[1:] when that is empty), with each argument escaped via
// windows.EscapeArg and separated by single spaces. A failing HRESULT from
// registerApplicationRestart is returned as the corresponding wingoes error.
func registerForRestart(opts RegisterForRestartOpts) error {
	var flags uint32
	if !opts.RestartOnReboot {
		flags |= _RESTART_NO_REBOOT
	}
	if !opts.RestartOnUpgrade {
		flags |= _RESTART_NO_PATCH
	}
	if !opts.RestartOnHang {
		flags |= _RESTART_NO_HANG
	}
	if !opts.RestartOnCrash {
		flags |= _RESTART_NO_CRASH
	}

	// cmdLine stays nil unless we end up with a non-empty command line.
	var cmdLine *uint16
	if opts.UseCmdLineArgs {
		args := opts.CmdLineArgs
		if len(args) == 0 {
			// Nothing supplied: re-use our own arguments, minus the exe name.
			args = os.Args[1:]
		}

		var joined strings.Builder
		for i := range args {
			if joined.Len() > 0 {
				joined.WriteByte(' ')
			}
			joined.WriteString(windows.EscapeArg(args[i]))
		}

		if s := joined.String(); s != "" {
			p, err := windows.UTF16PtrFromString(s)
			if err != nil {
				return err
			}
			cmdLine = p
		}
	}

	e := wingoes.ErrorFromHRESULT(registerApplicationRestart(cmdLine, flags))
	if e.Failed() {
		return e
	}
	return nil
}
// _RMHANDLE is the handle type that identifies a Restart Manager session in
// the rm* calls made by this file.
type _RMHANDLE uint32
// See https://web.archive.org/web/20231128212837/https://learn.microsoft.com/en-us/windows/win32/rstmgr/using-restart-manager-with-a-secondary-installer
//
// _INVALID_RMHANDLE is the all-ones _RMHANDLE value. (*rmSession).Close stores
// it once the session has ended so that subsequent Close calls are no-ops.
const _INVALID_RMHANDLE = ^_RMHANDLE(0)
// _RM_UNIQUE_PROCESS pairs a PID with the process's start time.
// AsRestartableProcess compares ProcessStartTime against the creation time
// reported by GetProcessTimes in order to detect PID reuse.
// NOTE(review): presumably mirrors the Win32 RM_UNIQUE_PROCESS layout — confirm.
type _RM_UNIQUE_PROCESS struct {
	PID uint32
	ProcessStartTime windows.Filetime
}
// _RM_APP_TYPE is the type of the AppType field of _RM_PROCESS_INFO.
// AffectedProcesses treats _RmMainWindow and _RmOtherWindow as the types that
// are able to receive GUI messages.
type _RM_APP_TYPE int32
const (
	_RmUnknownApp _RM_APP_TYPE = 0
	_RmMainWindow _RM_APP_TYPE = 1
	_RmOtherWindow _RM_APP_TYPE = 2
	_RmService _RM_APP_TYPE = 3
	_RmExplorer _RM_APP_TYPE = 4
	_RmConsole _RM_APP_TYPE = 5
	_RmCritical _RM_APP_TYPE = 1000
)
// _RM_APP_STATUS is the type of the AppStatus field of _RM_PROCESS_INFO.
// Apart from _RmStatusUnknown, each value below occupies a distinct bit.
// NOTE(review): presumably these may be combined as a bitmask — confirm
// against the Win32 RM_APP_STATUS documentation.
type _RM_APP_STATUS uint32
const (
	//lint:ignore U1000 maps to a win32 API
	_RmStatusUnknown _RM_APP_STATUS = 0x0
	_RmStatusRunning _RM_APP_STATUS = 0x1
	_RmStatusStopped _RM_APP_STATUS = 0x2
	_RmStatusStoppedOther _RM_APP_STATUS = 0x4
	_RmStatusRestarted _RM_APP_STATUS = 0x8
	_RmStatusErrorOnStop _RM_APP_STATUS = 0x10
	_RmStatusErrorOnRestart _RM_APP_STATUS = 0x20
	_RmStatusShutdownMasked _RM_APP_STATUS = 0x40
	_RmStatusRestartMasked _RM_APP_STATUS = 0x80
)
// _RM_PROCESS_INFO is the element type of the buffer that processList hands
// to rmGetList. Within this file only the Process and AppType fields are
// consumed (by AffectedProcesses).
// NOTE(review): presumably mirrors the Win32 RM_PROCESS_INFO layout, so field
// order and sizes must not be changed — confirm.
type _RM_PROCESS_INFO struct {
	// Process identifies the process by PID and start time.
	Process _RM_UNIQUE_PROCESS
	// AppName and ServiceShortName are fixed-size UTF-16 buffers.
	AppName [256]uint16
	ServiceShortName [64]uint16
	AppType _RM_APP_TYPE
	AppStatus _RM_APP_STATUS
	TSSessionID uint32
	Restartable int32 // Win32 BOOL
}
// RestartManagerSession represents an open Restart Manager session.
// Instances are obtained from NewRestartManagerSession or
// AttachRestartManagerSession; Close (via io.Closer) ends the session.
type RestartManagerSession interface {
	io.Closer
	// AddPaths adds the fully-qualified paths in fqPaths to the set of binaries
	// that will be monitored by this restart manager session. NOTE: This
	// method is expensive to call, so it is better to make a single call with
	// a larger slice than to make multiple calls with smaller slices.
	AddPaths(fqPaths []string) error
	// AffectedProcesses returns the UniqueProcess information for all running
	// processes that utilize the binaries previously specified by calls to
	// AddPaths.
	AffectedProcesses() ([]UniqueProcess, error)
	// Key returns the session key associated with this instance. The key may
	// be passed to AttachRestartManagerSession to join the same session.
	Key() string
}
// rmSession encapsulates the necessary information to represent an open
// restart manager session.
//
// Implementation note: rmSession methods that return errors should use
// windows.Errno codes whenever possible, as we call them from the custom
// action DLL. MSI custom actions are expected to return windows.Errno values;
// to ensure our compliance with this expectation, we should also use those
// values. Failure to do so will result in a generic windows.Errno being
// returned to the Windows Installer, which obviously is less than ideal.
type rmSession struct {
	// session is the Restart Manager handle; Close sets it to
	// _INVALID_RMHANDLE once the session has been ended.
	session _RMHANDLE
	// key is the session key, as returned by Key.
	key string
	// logf is the logging function supplied when the session was created.
	logf logger.Logf
}
// _CCH_RM_SESSION_KEY is the length, in UTF-16 code units, of a Restart
// Manager session key. NewRestartManagerSession sizes its key buffer as
// _CCH_RM_SESSION_KEY + 1 to leave room for the terminator.
const _CCH_RM_SESSION_KEY = 32 // (excludes NUL terminator)
// NewRestartManagerSession starts a brand new Restart Manager session and
// returns it as a RestartManagerSession. logf is retained by the session for
// logging. The session key generated while starting the session is available
// through the Key method of the result.
func NewRestartManagerSession(logf logger.Logf) (RestartManagerSession, error) {
	// One extra element leaves room for the NUL terminator.
	var keyBuf [_CCH_RM_SESSION_KEY + 1]uint16

	sess := &rmSession{logf: logf}
	err := rmStartSession(&sess.session, 0, &keyBuf[0])
	if err != nil {
		return nil, err
	}

	sess.key = windows.UTF16ToString(keyBuf[:_CCH_RM_SESSION_KEY])
	return sess, nil
}
// AttachRestartManagerSession joins the already-existing Restart Manager
// session identified by sessionKey and returns it as a RestartManagerSession.
// logf is retained by the session for logging. An error is returned when
// sessionKey cannot be converted to UTF-16 or when joining the session fails.
func AttachRestartManagerSession(logf logger.Logf, sessionKey string) (RestartManagerSession, error) {
	key16, err := windows.UTF16PtrFromString(sessionKey)
	if err != nil {
		return nil, err
	}

	sess := &rmSession{logf: logf, key: sessionKey}
	if joinErr := rmJoinSession(&sess.session, key16); joinErr != nil {
		return nil, joinErr
	}
	return sess, nil
}
// Close ends the Restart Manager session. It is a no-op for a nil receiver
// and for a session that has already been closed. When ending the session
// fails, the error is returned and the handle is left intact.
func (rms *rmSession) Close() error {
	nothingToDo := rms == nil || rms.session == _INVALID_RMHANDLE
	if nothingToDo {
		return nil
	}

	err := rmEndSession(rms.session)
	if err == nil {
		// Mark the session as closed so that repeated calls do nothing.
		rms.session = _INVALID_RMHANDLE
	}
	return err
}
// Key returns the session key stored in rms.
func (rms *rmSession) Key() string {
	sessionKey := rms.key
	return sessionKey
}
// AffectedProcesses converts every entry reported by processList into a
// UniqueProcess. CanReceiveGUIMsgs is set for entries whose application type
// is _RmMainWindow or _RmOtherWindow.
func (rms *rmSession) AffectedProcesses() ([]UniqueProcess, error) {
	infos, err := rms.processList()
	if err != nil {
		return nil, err
	}

	procs := make([]UniqueProcess, len(infos))
	for i := range infos {
		var hasWindows bool
		switch infos[i].AppType {
		case _RmMainWindow, _RmOtherWindow:
			hasWindows = true
		}
		procs[i] = UniqueProcess{
			_RM_UNIQUE_PROCESS: infos[i].Process,
			CanReceiveGUIMsgs:  hasWindows,
		}
	}
	return procs, nil
}
// processList asks the Restart Manager for the processes affected by the
// session's registered resources.
//
// The buffer starts with room for a single entry. Whenever rmGetList reports
// windows.ERROR_MORE_DATA, the buffer is reallocated to the size it asked for
// and the call is retried, up to maxAttempts calls in total. If the final
// attempt still reports ERROR_MORE_DATA, that error is returned.
func (rms *rmSession) processList() ([]_RM_PROCESS_INFO, error) {
	const maxAttempts = 5

	var (
		infos         []_RM_PROCESS_INFO
		count         uint32
		rebootReasons uint32
		err           error
	)
	needed := uint32(1)
	for attempt := 0; attempt < maxAttempts; attempt++ {
		infos = make([]_RM_PROCESS_INFO, needed)
		count = needed
		err = rmGetList(rms.session, &needed, &count, unsafe.SliceData(infos), &rebootReasons)
		if err != windows.ERROR_MORE_DATA {
			break
		}
	}

	switch err {
	case nil:
		return infos[:count], nil
	case windows.ERROR_SESSION_CREDENTIAL_CONFLICT:
		// Add some more context about the meaning of this error.
		return nil, fmt.Errorf("%w (the Restart Manager does not permit calling RmGetList from a process that did not originally create the session)", err)
	default:
		return nil, err
	}
}
// AddPaths registers fqPaths as file resources of the session. It does
// nothing when fqPaths is empty. Every path must be absolute; the first one
// that is not causes an error wrapping windows.ERROR_BAD_PATHNAME, and no
// paths are registered in that case.
func (rms *rmSession) AddPaths(fqPaths []string) error {
	n := len(fqPaths)
	if n == 0 {
		return nil
	}

	ptrs := make([]*uint16, n)
	for i, p := range fqPaths {
		if !filepath.IsAbs(p) {
			return fmt.Errorf("%w: paths must be fully-qualified", windows.ERROR_BAD_PATHNAME)
		}
		p16, err := windows.UTF16PtrFromString(p)
		if err != nil {
			return err
		}
		ptrs[i] = p16
	}

	// Only file resources are registered; the two remaining (count, pointer)
	// argument pairs are left empty.
	return rmRegisterResources(rms.session, uint32(n), unsafe.SliceData(ptrs), 0, nil, 0, nil)
}
// UniqueProcess contains the necessary information to uniquely identify a
// process in the face of potential PID reuse. The embedded
// _RM_UNIQUE_PROCESS supplies the PID and ProcessStartTime fields.
type UniqueProcess struct {
	_RM_UNIQUE_PROCESS
	// CanReceiveGUIMsgs is true when the process has open top-level windows.
	// AffectedProcesses sets it when the Restart Manager reports the process's
	// application type as _RmMainWindow or _RmOtherWindow.
	CanReceiveGUIMsgs bool
}
// AsRestartableProcess obtains a RestartableProcess populated using the
// information obtained from up.
//
// It opens the process identified by up.PID and verifies that the process's
// creation time equals up.ProcessStartTime, returning ErrDefunctProcess when
// it does not. It returns ErrProcessNotRestartable when the process's restart
// settings include _RESTART_NO_PATCH. On success the returned
// RestartableProcess owns the open process handle, and a finalizer that calls
// Close is installed on it. On every failure path the handle is closed before
// returning.
func (up *UniqueProcess) AsRestartableProcess() (*RestartableProcess, error) {
	// We need PROCESS_QUERY_INFORMATION instead of PROCESS_QUERY_LIMITED_INFORMATION
	// in order for ProcessImageName to be able to work from within a privileged
	// Windows Installer process.
	// We need PROCESS_VM_READ for GetApplicationRestartSettings.
	// We need PROCESS_TERMINATE and SYNCHRONIZE to terminate the process and
	// to be able to wait for the terminated process's handle to signal.
	access := uint32(windows.PROCESS_QUERY_INFORMATION | windows.PROCESS_TERMINATE | windows.PROCESS_VM_READ | windows.SYNCHRONIZE)
	h, err := windows.OpenProcess(access, false, up.PID)
	if err != nil {
		return nil, fmt.Errorf("OpenProcess(%d[%#X]): %w", up.PID, up.PID, err)
	}
	// Close h on every error path. h is zeroed immediately before the
	// successful return, once ownership has been handed to the result.
	defer func() {
		if h == 0 {
			return
		}
		windows.CloseHandle(h)
	}()
	// Only creationTime is consumed below; the other values are required
	// out-parameters of GetProcessTimes.
	var creationTime, exitTime, kernelTime, userTime windows.Filetime
	if err := windows.GetProcessTimes(h, &creationTime, &exitTime, &kernelTime, &userTime); err != nil {
		return nil, fmt.Errorf("GetProcessTimes: %w", err)
	}
	if creationTime != up.ProcessStartTime {
		// The PID has been reused and does not actually reference the original process.
		return nil, ErrDefunctProcess
	}
	var tok windows.Token
	if err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &tok); err != nil {
		return nil, fmt.Errorf("OpenProcessToken: %w", err)
	}
	defer tok.Close()
	tsSessionID, err := TSSessionID(tok)
	if err != nil {
		return nil, fmt.Errorf("TSSessionID: %w", err)
	}
	logonSessionID, err := LogonSessionID(tok)
	if err != nil {
		return nil, fmt.Errorf("LogonSessionID: %w", err)
	}
	img, err := ProcessImageName(h)
	if err != nil {
		return nil, fmt.Errorf("ProcessImageName: %w", err)
	}
	const _RESTART_MAX_CMD_LINE = 1024
	var cmdLine [_RESTART_MAX_CMD_LINE]uint16
	// cmdLineLen is passed by pointer: it carries the buffer capacity in and
	// is read back afterwards to bound the command line slice.
	cmdLineLen := uint32(len(cmdLine))
	var rmFlags uint32
	hr := getApplicationRestartSettings(h, &cmdLine[0], &cmdLineLen, &rmFlags)
	// Not found is not an error; it just means that the app never set any restart settings.
	if e := wingoes.ErrorFromHRESULT(hr); e.Failed() && e != wingoes.ErrorFromErrno(windows.ERROR_NOT_FOUND) {
		return nil, fmt.Errorf("GetApplicationRestartSettings: %w", error(e))
	}
	if (rmFlags & _RESTART_NO_PATCH) != 0 {
		// The application explicitly stated that it cannot be restarted during
		// an upgrade.
		return nil, ErrProcessNotRestartable
	}
	var logonSID string
	// Non-fatal, so we'll proceed with best-effort.
	if tokenGroups, err := tok.GetTokenGroups(); err == nil {
		// The logon SID is the first group carrying the SE_GROUP_LOGON_ID attribute.
		for _, group := range tokenGroups.AllGroups() {
			if (group.Attributes & windows.SE_GROUP_LOGON_ID) != 0 {
				logonSID = group.Sid.String()
				break
			}
		}
	}
	var userSID string
	// Non-fatal, so we'll proceed with best-effort.
	if tokenUser, err := tok.GetTokenUser(); err == nil {
		// Save the user's SID so that we can later check it against the currently
		// logged-in Tailscale profile.
		userSID = tokenUser.User.Sid.String()
	}
	result := &RestartableProcess{
		Process: *up,
		SessionInfo: SessionID{
			LogonSession: logonSessionID,
			TSSession: tsSessionID,
		},
		CommandLineInfo: CommandLineInfo{
			ExePath: img,
			// UTF16ToString stops at the first NUL, so an untouched (all-zero)
			// buffer yields an empty Args.
			Args: windows.UTF16ToString(cmdLine[:cmdLineLen]),
		},
		LogonSID: logonSID,
		UserSID: userSID,
		handle: h,
	}
	// Ensure the handle is eventually released even if the caller never calls Close.
	runtime.SetFinalizer(result, func(rp *RestartableProcess) { rp.Close() })
	// Ownership of the handle now belongs to result; disarm the deferred close.
	h = 0
	return result, nil
}
// RestartableProcess contains the necessary information to uniquely identify
// an existing process, as well as the necessary information to be able to
// terminate it and later start a new instance in the identical logon session
// to the previous instance.
//
// Instances are created by (*UniqueProcess).AsRestartableProcess and hold an
// open process handle until Close is called.
type RestartableProcess struct {
	// Process uniquely identifies the existing process.
	Process UniqueProcess
	// SessionInfo uniquely identifies the Terminal Services (RDP) and logon
	// sessions the existing process is running under.
	SessionInfo SessionID
	// CommandLineInfo contains the command line information necessary for restarting.
	CommandLineInfo CommandLineInfo
	// LogonSID contains the stringified SID of the existing process's token's logon session.
	// It is populated on a best-effort basis and may be empty.
	LogonSID string
	// UserSID contains the stringified SID of the existing process's token's user.
	// It is populated on a best-effort basis and may be empty.
	UserSID string
	// handle specifies the Win32 HANDLE associated with the existing process.
	// When non-zero, it includes access rights for querying, terminating, and synchronizing.
	// Close resets it to zero.
	handle windows.Handle
	// hasExitCode is true when the exitCode field is valid.
	hasExitCode bool
	// exitCode contains the exit code returned by this RestartableProcess once
	// its termination has been recorded by (RestartableProcesses).Terminate.
	// It is only valid when hasExitCode == true.
	exitCode uint32
}
// Close releases the process handle held by rp and removes the finalizer
// installed on rp.
//
// Close is safe to call on a nil receiver and may be called repeatedly; it
// returns nil whenever there is no handle to release. Any error reported by
// CloseHandle is returned to the caller (as a windows.Errno), but rp is
// considered closed regardless: the handle field is always reset so that the
// same handle value is never closed twice.
func (rp *RestartableProcess) Close() error {
	if rp == nil || rp.handle == 0 {
		return nil
	}
	err := windows.CloseHandle(rp.handle)
	// The handle has been dealt with explicitly, so the finalizer is no longer needed.
	runtime.SetFinalizer(rp, nil)
	rp.handle = 0
	return err
}
// RestartableProcesses is a map of PID to *RestartableProcess instance.
// Add and Delete key entries by the process's Process.PID.
type RestartableProcesses map[uint32]*RestartableProcess
// NewRestartableProcesses returns an empty, ready-to-use RestartableProcesses.
func NewRestartableProcesses() RestartableProcesses {
	return RestartableProcesses{}
}
// Add stores rp in rps under its PID, replacing any entry already stored
// under that PID. A nil rp is ignored.
func (rps RestartableProcesses) Add(rp *RestartableProcess) {
	if rp == nil {
		return
	}
	pid := rp.Process.PID
	rps[pid] = rp
}
// Delete removes the entry stored under rp's PID from rps. A nil rp is
// ignored. The removed process is not closed.
func (rps RestartableProcesses) Delete(rp *RestartableProcess) {
	if rp == nil {
		return
	}
	pid := rp.Process.PID
	delete(rps, pid)
}
// Close calls (*RestartableProcess).Close on each process stored in rps and
// then empties rps. Results of the individual Close calls are discarded, and
// the returned error is always nil.
func (rps RestartableProcesses) Close() error {
	for pid := range rps {
		rps[pid].Close()
	}
	clear(rps)
	return nil
}
// _MAXIMUM_WAIT_OBJECTS is the Win32 constant for the maximum number of
// handles that a call to WaitForMultipleObjects may receive at once.
// (RestartableProcesses).Terminate uses it as its batch size.
const _MAXIMUM_WAIT_OBJECTS = 64
// Terminate forcibly terminates all processes in rps using exitCode, and then
// waits for their process handles to signal, up to timeout.
//
// A process whose TerminateProcess call fails with ERROR_ACCESS_DENIED is
// treated as having already exited: its exit code is recorded and it is
// closed without being waited on. Any other TerminateProcess failure is
// collected as an error and the process is neither waited on nor closed.
//
// Handles are waited on in batches of at most _MAXIMUM_WAIT_OBJECTS; timeout
// applies to each batch individually. Waiting stops at the first batch that
// fails or times out, leaving the remaining processes open. All collected
// errors are returned together via multierr.New; nil is returned when there
// were none. Failures of GetExitCodeProcess are only logged via logf.
func (rps RestartableProcesses) Terminate(logf logger.Logf, exitCode uint32, timeout time.Duration) error {
	if len(rps) == 0 {
		return nil
	}
	millis, err := wingoes.DurationToTimeoutMilliseconds(timeout)
	if err != nil {
		return err
	}
	errs := make([]error, 0, len(rps))
	// procs and handles are parallel slices: handles[i] is procs[i].handle.
	procs := make([]*RestartableProcess, 0, len(rps))
	handles := make([]windows.Handle, 0, len(rps))
	for _, v := range rps {
		if err := windows.TerminateProcess(v.handle, exitCode); err != nil {
			if err == windows.ERROR_ACCESS_DENIED {
				// If v terminated before we attempted to terminate, we'll receive
				// ERROR_ACCESS_DENIED, which is not really an error worth reporting in
				// our use case. Just obtain the exit code and then close the process.
				if err := windows.GetExitCodeProcess(v.handle, &v.exitCode); err != nil {
					logf("GetExitCodeProcess failed: %v", err)
				} else {
					v.hasExitCode = true
				}
				v.Close()
			} else {
				errs = append(errs, &terminationError{rp: v, err: err})
			}
			// Either way, there is nothing to wait on for this process.
			continue
		}
		procs = append(procs, v)
		handles = append(handles, v.handle)
	}
	for len(handles) > 0 {
		// WaitForMultipleObjects can only wait on _MAXIMUM_WAIT_OBJECTS handles per
		// call, so we batch them as necessary.
		count := uint32(min(len(handles), _MAXIMUM_WAIT_OBJECTS))
		// waitAll is true: the call succeeds only once every handle in the batch
		// has been signaled.
		waitCode, err := windows.WaitForMultipleObjects(handles[:count], true, millis)
		if err != nil {
			errs = append(errs, fmt.Errorf("waiting on terminated process handles: %w", err))
			break
		}
		if e := windows.Errno(waitCode); e == windows.WAIT_TIMEOUT {
			errs = append(errs, fmt.Errorf("waiting on terminated process handles: %w", error(e)))
			break
		}
		if waitCode >= windows.WAIT_OBJECT_0 && waitCode < (windows.WAIT_OBJECT_0+count) {
			// The first count process handles have all been signaled. Close them out.
			for _, proc := range procs[:count] {
				if err := windows.GetExitCodeProcess(proc.handle, &proc.exitCode); err != nil {
					logf("GetExitCodeProcess failed: %v", err)
				} else {
					proc.hasExitCode = true
				}
				proc.Close()
			}
			// Advance both slices past the batch that was just completed.
			procs = procs[count:]
			handles = handles[count:]
			continue
		}
		// We really shouldn't be reaching this point
		panic(fmt.Sprintf("unexpected state from WaitForMultipleObjects: %d", waitCode))
	}
	if len(errs) != 0 {
		return multierr.New(errs...)
	}
	return nil
}
// terminationError records a failure to terminate a specific process. It is
// produced by (RestartableProcesses).Terminate when TerminateProcess fails.
type terminationError struct {
	// rp is the process that could not be terminated.
	rp *RestartableProcess
	// err is the underlying error; it is exposed through Unwrap.
	err error
}
// Error implements the error interface. The message contains the PID of the
// affected process in both decimal and hexadecimal, followed by the
// underlying error.
func (te *terminationError) Error() string {
	return fmt.Sprintf("terminating process %d (%#X): %v", te.rp.Process.PID, te.rp.Process.PID, te.err)
}
// Unwrap returns the underlying error so that errors.Is and errors.As can
// inspect it.
func (te *terminationError) Unwrap() error {
	cause := te.err
	return cause
}
// SessionID encapsulates the necessary information for uniquely identifying
// sessions. In particular, SessionID contains enough information to detect
// reuse of Terminal Service session IDs: OpenToken and ContainsToken both
// require the logon session to match in addition to the TS session ID.
type SessionID struct {
	// LogonSession is the NT logon session ID.
	LogonSession windows.LUID
	// TSSession is the terminal services session ID.
	TSSession uint32
}
// OpenToken queries the user token of the terminal services session
// sessID.TSSession and returns it, provided that the token belongs to the
// logon session sessID.LogonSession. When the logon sessions differ,
// windows.ERROR_NO_SUCH_LOGON_SESSION is returned. On any failure after the
// token has been obtained, the token is closed before returning. On success
// the caller is responsible for closing the returned token.
func (sessID *SessionID) OpenToken() (windows.Token, error) {
	var tok windows.Token
	if err := windows.WTSQueryUserToken(sessID.TSSession, &tok); err != nil {
		return 0, err
	}

	logonSession, err := LogonSessionID(tok)
	if err == nil && logonSession != sessID.LogonSession {
		// The TS session ID now refers to a different logon session.
		err = windows.ERROR_NO_SUCH_LOGON_SESSION
	}
	if err != nil {
		tok.Close()
		return 0, err
	}
	return tok, nil
}
// ContainsToken reports whether token belongs to the session described by
// sessID, i.e. whether both its terminal services session ID and its logon
// session ID equal those stored in sessID. The logon session is only queried
// when the terminal services session IDs match.
func (sessID *SessionID) ContainsToken(token windows.Token) (bool, error) {
	tsSession, err := TSSessionID(token)
	switch {
	case err != nil:
		return false, err
	case tsSession != sessID.TSSession:
		return false, nil
	}

	logonSession, err := LogonSessionID(token)
	if err != nil {
		return false, err
	}
	sameLogon := logonSession == sessID.LogonSession
	return sameLogon, nil
}
// defaultDesktop points to the NUL-terminated UTF-16 string "Winsta0\\default".
// This is the Window Station and Desktop within a particular session that must
// be specified for interactive processes: "Winsta0\\default\x00"
// startProcessInSessionInternal assigns it to StartupInfo.Desktop.
var defaultDesktop = unsafe.SliceData([]uint16{'W', 'i', 'n', 's', 't', 'a', '0', '\\', 'd', 'e', 'f', 'a', 'u', 'l', 't', 0})
// CommandLineInfo manages the necessary information for creating a Win32
// process using a specific command line.
type CommandLineInfo struct {
	// ExePath must be a fully-qualified path to a Windows executable binary.
	// Validate rejects values that are not absolute paths.
	ExePath string
	// Args must be any arguments supplied to the process, excluding the
	// path to the binary itself. Args must be properly quoted according to
	// the rules for Win32 command lines, because it is appended verbatim after
	// the quoted ExePath. To create a properly quoted Args from scratch, call
	// the SetArgs method instead.
	Args string `json:",omitempty"`
}
// SetArgs replaces cli.Args with the elements of args, each escaped with
// windows.EscapeArg to satisfy the rules for Win32 command lines, and joined
// by single spaces. An empty args results in an empty cli.Args.
func (cli *CommandLineInfo) SetArgs(args []string) {
	var quoted strings.Builder
	for i := range args {
		if quoted.Len() != 0 {
			quoted.WriteByte(' ')
		}
		quoted.WriteString(windows.EscapeArg(args[i]))
	}
	cli.Args = quoted.String()
}
// Validate checks that cli is usable for creating a process. It returns
// windows.ERROR_INVALID_PARAMETER for a nil receiver, and an error wrapping
// windows.ERROR_BAD_PATHNAME when cli.ExePath is not an absolute path.
func (cli *CommandLineInfo) Validate() error {
	switch {
	case cli == nil:
		return windows.ERROR_INVALID_PARAMETER
	case !filepath.IsAbs(cli.ExePath):
		return fmt.Errorf("%w: CommandLineInfo requires absolute ExePath", windows.ERROR_BAD_PATHNAME)
	default:
		return nil
	}
}
// Resolve produces the arguments needed by the Win32 CreateProcess* family of
// APIs from cli: exePath and cmdLine are pointers to NUL-terminated UTF-16
// strings holding the executable path and the full command line,
// respectively. cmdLineStr is the same full command line as a Go string,
// intended for logging. On failure all three results are zero values.
func (cli *CommandLineInfo) Resolve() (exePath *uint16, cmdLine *uint16, cmdLineStr string, err error) {
	// The command line goes first because building it also validates cli.
	fullCmdLine, fullCmdLine16, err := cli.resolveArgsAsUTF16Ptr()
	if err != nil {
		return nil, nil, "", err
	}

	exe16, err := windows.UTF16PtrFromString(cli.ExePath)
	if err != nil {
		return nil, nil, "", err
	}
	return exe16, fullCmdLine16, fullCmdLine, nil
}
// resolveArgs validates cli and then builds the full command line: the
// escaped cli.ExePath, followed by a single space and cli.Args when Args is
// non-empty. Args is appended verbatim, without any further quoting.
func (cli *CommandLineInfo) resolveArgs() (string, error) {
	if err := cli.Validate(); err != nil {
		return "", err
	}

	quotedExe := windows.EscapeArg(cli.ExePath)
	if cli.Args == "" {
		return quotedExe, nil
	}
	return quotedExe + " " + cli.Args, nil
}
// resolveArgsAsUTF16Ptr returns the full command line produced by resolveArgs
// both as a Go string and as a pointer to its NUL-terminated UTF-16 encoding.
func (cli *CommandLineInfo) resolveArgsAsUTF16Ptr() (string, *uint16, error) {
	cmdLine, err := cli.resolveArgs()
	if err != nil {
		return "", nil, err
	}

	cmdLine16, convErr := windows.UTF16PtrFromString(cmdLine)
	if convErr != nil {
		return "", nil, convErr
	}
	return cmdLine, cmdLine16, nil
}
// StartProcessInSession launches the process described by cmdLineInfo inside
// the session identified by sessID, using the security token whose logon is
// associated with sessID. The child's environment is inherited from the
// session token's environment. It is equivalent to calling
// StartProcessInSessionWithHandler with a nil handler.
func StartProcessInSession(sessID SessionID, cmdLineInfo CommandLineInfo) error {
	var noHandler PostCreateProcessHandler
	return StartProcessInSessionWithHandler(sessID, cmdLineInfo, noHandler)
}
// PostCreateProcessHandler is a function that is invoked by
// StartProcessInSessionWithHandler when the child process has been successfully
// created. It is the responsibility of the handler to close the pi.Thread and
// pi.Process handles; StartProcessInSessionWithHandler does not close them
// when a handler is supplied.
type PostCreateProcessHandler func(pi *windows.ProcessInformation)
// StartProcessInSessionWithHandler launches the process described by
// cmdLineInfo inside the session identified by sessID, using the security
// token whose logon is associated with sessID. The child's environment is
// inherited from the session token's environment.
//
// Once the child has been created, a non-nil handler is invoked with the
// windows.ProcessInformation returned by the OS and takes over ownership of
// its handles. With a nil handler, the process and thread handles are closed
// before returning.
func StartProcessInSessionWithHandler(sessID SessionID, cmdLineInfo CommandLineInfo, handler PostCreateProcessHandler) error {
	procInfo, err := startProcessInSessionInternal(sessID, cmdLineInfo, 0)
	switch {
	case err != nil:
		return err
	case handler == nil:
		windows.CloseHandle(procInfo.Process)
		windows.CloseHandle(procInfo.Thread)
	default:
		handler(procInfo)
	}
	return nil
}
// RunProcessInSession launches the process described by cmdLineInfo inside
// the session identified by sessID (using the security token whose logon is
// associated with sessID, and the session token's environment), and then
// waits up to timeout for the child to exit.
//
// On success it returns the child's exit code. On any failure, including the
// wait timing out (reported as windows.WAIT_TIMEOUT), it returns 1 together
// with the error.
func RunProcessInSession(sessID SessionID, cmdLineInfo CommandLineInfo, timeout time.Duration) (uint32, error) {
	const failureCode = 1

	waitMillis, err := wingoes.DurationToTimeoutMilliseconds(timeout)
	if err != nil {
		return failureCode, err
	}

	child, err := startProcessInSessionInternal(sessID, cmdLineInfo, 0)
	if err != nil {
		return failureCode, err
	}
	// Only the process handle is needed from here on.
	windows.CloseHandle(child.Thread)
	defer windows.CloseHandle(child.Process)

	waitCode, err := windows.WaitForSingleObject(child.Process, waitMillis)
	if err != nil {
		return failureCode, fmt.Errorf("WaitForSingleObject: %w", err)
	}
	e := windows.Errno(waitCode)
	switch {
	case e == windows.WAIT_TIMEOUT:
		return failureCode, e
	case waitCode != windows.WAIT_OBJECT_0:
		// This should not be possible; log
		return failureCode, fmt.Errorf("unexpected state from WaitForSingleObject: %d", waitCode)
	}

	var code uint32
	if err := windows.GetExitCodeProcess(child.Process, &code); err != nil {
		return failureCode, err
	}
	return code, nil
}
// startProcessInSessionInternal creates a process from cmdLineInfo using the
// token obtained from sessID and returns the resulting
// windows.ProcessInformation, whose handles the caller must close.
//
// The child is created with the environment derived from the session token,
// with the directory containing the executable as its working directory, on
// the desktop named by defaultDesktop, and without inheriting handles.
// extraFlags is OR-ed into the creation flags alongside
// CREATE_UNICODE_ENVIRONMENT and DETACHED_PROCESS.
func startProcessInSessionInternal(sessID SessionID, cmdLineInfo CommandLineInfo, extraFlags uint32) (*windows.ProcessInformation, error) {
	if err := cmdLineInfo.Validate(); err != nil {
		return nil, err
	}

	userToken, err := sessID.OpenToken()
	if err != nil {
		return nil, fmt.Errorf("(*SessionID).OpenToken: %w", err)
	}
	defer userToken.Close()

	exe16, cmdLine16, _, err := cmdLineInfo.Resolve()
	if err != nil {
		return nil, fmt.Errorf("(*CommandLineInfo).Resolve(): %w", err)
	}

	workDir := filepath.Dir(cmdLineInfo.ExePath)
	workDir16, err := windows.UTF16PtrFromString(workDir)
	if err != nil {
		return nil, fmt.Errorf("UTF16PtrFromString(wd): %w", err)
	}

	environ, err := userToken.Environ(false)
	if err != nil {
		return nil, fmt.Errorf("token environment: %w", err)
	}
	envBlock := newEnvBlock(environ)

	// CreateProcessAsUser needs these privileges in order to start processes
	// as other users in other logon sessions.
	privNames := []string{
		"SeAssignPrimaryTokenPrivilege",
		"SeIncreaseQuotaPrivilege",
	}
	restorePrivs, err := EnableCurrentThreadPrivileges(privNames)
	if err != nil {
		return nil, fmt.Errorf("EnableCurrentThreadPrivileges(%#v): %w", privNames, err)
	}
	defer restorePrivs()

	var startup windows.StartupInfo
	startup.Cb = uint32(unsafe.Sizeof(startup))
	startup.Desktop = defaultDesktop

	flags := extraFlags | windows.CREATE_UNICODE_ENVIRONMENT | windows.DETACHED_PROCESS

	procInfo := new(windows.ProcessInformation)
	err = windows.CreateProcessAsUser(userToken, exe16, cmdLine16, nil, nil,
		false, flags, envBlock, workDir16, &startup, procInfo)
	if err != nil {
		return nil, fmt.Errorf("CreateProcessAsUser: %w", err)
	}
	return procInfo, nil
}
// newEnvBlock encodes env as a UTF-16 environment block: each entry is
// followed by a NUL, and the block as a whole ends with one additional NUL.
// An empty env yields a block consisting of exactly two NULs. The result
// points at the first UTF-16 code unit of the block.
func newEnvBlock(env []string) *uint16 {
	// bytes.Buffer is used deliberately since NUL bytes are being written
	// (the standard library does this too).
	var block bytes.Buffer
	for i := range env {
		block.WriteString(env[i])
		block.WriteByte(0)
	}

	switch block.Len() {
	case 0:
		// Empty environment: the block must still end with a double NUL.
		block.Write([]byte{0, 0})
	default:
		block.WriteByte(0)
	}

	encoded := utf16.Encode([]rune(block.String()))
	return unsafe.SliceData(encoded)
}