diff --git a/AGHTechDoc.md b/AGHTechDoc.md index 167b9caa..48974e2b 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -217,10 +217,13 @@ Algorithm of an update by command: * Copy the current configuration file to the directory we unpacked new AGH to * Check configuration compatibility by executing `./AGH --check-config`. If this command fails, we won't be able to update. * Create `backup-vXXX` directory and copy the current configuration file there - * Stop all tasks, including DNS server, DHCP server, HTTP server + * Copy supporting files (README, LICENSE, etc.) to backup directory + * Copy supporting files from the update directory to the current directory * Move the current binary file to backup directory * Note: if power fails here, AGH won't be able to start at system boot. Administrator has to fix it manually * Move new binary file to the current directory + * Send response to UI + * Stop all tasks, including DNS server, DHCP server, HTTP server * If AGH is running as a service, use service control functionality to restart * If AGH is not running as a service, use the current process arguments to start a new process * Exit process diff --git a/control_update.go b/control_update.go index 799c4968..1b428bce 100644 --- a/control_update.go +++ b/control_update.go @@ -305,11 +305,35 @@ func targzFileUnpack(tarfile, outdir string) ([]string, error) { return files, err2 } -// Perform an update procedure -func doUpdate(u *updateInfo) error { - log.Info("Updating from %s to %s. URL:%s Package:%s", - VersionString, u.newVer, u.pkgURL, u.pkgName) +func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly, useDstNameOnly bool) error { + for _, f := range files { + _, name := filepath.Split(f) + if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" { + continue + } + src := filepath.Join(srcdir, f) + if useSrcNameOnly { + src = filepath.Join(srcdir, name) + } + + dst := filepath.Join(dstdir, f) + if useDstNameOnly { + dst = filepath.Join(dstdir, name) + } + + err := copyFile(src, dst) + if err != nil && !os.IsNotExist(err) { + return err + } + + log.Tracef("Copied: %s -> %s", src, dst) + } + return nil +} + +// Download package file and save it to disk +func getPackageFile(u *updateInfo) error { resp, err := client.Get(u.pkgURL) if err != nil { return fmt.Errorf("HTTP request failed: %s", err) @@ -329,19 +353,33 @@ func doUpdate(u *updateInfo) error { if err != nil { return fmt.Errorf("ioutil.WriteFile() failed: %s", err) } + return nil +} + +// Perform an update procedure +func doUpdate(u *updateInfo) error { + log.Info("Updating from %s to %s. URL:%s Package:%s", + VersionString, u.newVer, u.pkgURL, u.pkgName) + + var err error + err = getPackageFile(u) + if err != nil { + return err + } log.Tracef("Unpacking the package") _ = os.Mkdir(u.updateDir, 0755) _, file := filepath.Split(u.pkgName) + var files []string if strings.HasSuffix(file, ".zip") { - _, err = zipFileUnpack(u.pkgName, u.updateDir) + files, err = zipFileUnpack(u.pkgName, u.updateDir) if err != nil { return fmt.Errorf("zipFileUnpack() failed: %s", err) } } else if strings.HasSuffix(file, ".tar.gz") { - _, err = targzFileUnpack(u.pkgName, u.updateDir) + files, err = targzFileUnpack(u.pkgName, u.updateDir) if err != nil { - return fmt.Errorf("zipFileUnpack() failed: %s", err) + return fmt.Errorf("targzFileUnpack() failed: %s", err) } } else { return fmt.Errorf("Unknown package extension") @@ -365,6 +403,20 @@ func doUpdate(u *updateInfo) error { return fmt.Errorf("copyFile() failed: %s", err) } + // ./README.md -> backup/README.md + err = copySupportingFiles(files, config.ourWorkingDir, u.backupDir, true, true) + if err != nil { + return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", + config.ourWorkingDir, u.backupDir, err) + } + + // update/[AdGuardHome/]README.md -> ./README.md + err = copySupportingFiles(files, u.updateDir, config.ourWorkingDir, false, true) + if err != nil { + return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", + u.updateDir, config.ourWorkingDir, err) + } + log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName) err = os.Rename(u.curBinName, u.bkpBinName) if err != nil { diff --git a/control_update_test.go b/control_update_test.go index 8c86852c..0e75562a 100644 --- a/control_update_test.go +++ b/control_update_test.go @@ -7,6 +7,7 @@ import ( func testDoUpdate(t *testing.T) { config.DNS.Port = 0 + config.ourWorkingDir = "." u := updateInfo{ pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.95/AdGuardHome_v0.95_linux_amd64.tar.gz", pkgName: "./AdGuardHome_v0.95_linux_amd64.tar.gz",