diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index 08ddd455..067ed83d 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -7,7 +7,6 @@ import ( "net/http" "os" "os/exec" - "path/filepath" "runtime" "syscall" "time" @@ -180,11 +179,10 @@ func finishUpdate(ctx context.Context) { cleanup(ctx) cleanupAlways() - exeName := "AdGuardHome" - if runtime.GOOS == "windows" { - exeName = "AdGuardHome.exe" + curBinName, err := os.Executable() + if err != nil { + log.Fatalf("executable path request failed: %s", err) } - curBinName := filepath.Join(Context.workDir, exeName) if runtime.GOOS == "windows" { if Context.runningAsService { @@ -192,7 +190,7 @@ func finishUpdate(ctx context.Context) { // we can't restart the service via "kardianos/service" package - it kills the process first // we can't start a new instance - Windows doesn't allow it cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome") - err := cmd.Start() + err = cmd.Start() if err != nil { log.Fatalf("exec.Command() failed: %s", err) } @@ -204,14 +202,14 @@ func finishUpdate(ctx context.Context) { cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - err := cmd.Start() + err = cmd.Start() if err != nil { log.Fatalf("exec.Command() failed: %s", err) } os.Exit(0) } else { log.Info("Restarting: %v", os.Args) - err := syscall.Exec(curBinName, os.Args, os.Environ()) + err = syscall.Exec(curBinName, os.Args, os.Environ()) if err != nil { log.Fatalf("syscall.Exec() failed: %s", err) } diff --git a/internal/updater/updater.go b/internal/updater/updater.go index d975d977..5077787c 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -111,7 +111,12 @@ func (u *Updater) Update() (err error) { log.Info("updater: updating") defer func() { log.Info("updater: finished; errors: %v", err) }() - err = u.prepare() + execPath, err := os.Executable() + if err != nil { + return err + } + + err = u.prepare(filepath.Base(execPath)) if err != nil { return err } @@ -162,7 +167,8 @@ func (u *Updater) VersionCheckURL() (vcu string) { return u.versionCheckURL } -func (u *Updater) prepare() (err error) { +// prepare fills all necessary fields in Updater object. +func (u *Updater) prepare(exeName string) (err error) { u.updateDir = filepath.Join(u.workDir, fmt.Sprintf("agh-update-%s", u.newVersion)) _, pkgNameOnly := filepath.Split(u.packageURL) @@ -173,13 +179,13 @@ func (u *Updater) prepare() (err error) { u.packageName = filepath.Join(u.updateDir, pkgNameOnly) u.backupDir = filepath.Join(u.workDir, "agh-backup") - exeName := "AdGuardHome" + updateExeName := "AdGuardHome" if u.goos == "windows" { - exeName = "AdGuardHome.exe" + updateExeName = "AdGuardHome.exe" } u.backupExeName = filepath.Join(u.backupDir, exeName) - u.updateExeName = filepath.Join(u.updateDir, exeName) + u.updateExeName = filepath.Join(u.updateDir, updateExeName) log.Debug( "updater: updating from %s to %s using url: %s", @@ -188,7 +194,6 @@ func (u *Updater) prepare() (err error) { u.packageURL, ) - // TODO(a.garipov): Use os.Args[0] instead? u.currentExeName = filepath.Join(u.workDir, exeName) _, err = os.Stat(u.currentExeName) if err != nil { diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index 771fb6d4..b3268f2f 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -131,7 +131,7 @@ func TestUpdate(t *testing.T) { u.newVersion = "v0.103.1" u.packageURL = fakeURL.String() - require.NoError(t, u.prepare()) + require.NoError(t, u.prepare("AdGuardHome")) u.currentExeName = filepath.Join(wd, "AdGuardHome") @@ -209,7 +209,7 @@ func TestUpdateWindows(t *testing.T) { u.newVersion = "v0.103.1" u.packageURL = fakeURL.String() - require.NoError(t, u.prepare()) + require.NoError(t, u.prepare("AdGuardHome.exe")) u.currentExeName = filepath.Join(wd, "AdGuardHome.exe")