diff --git a/AGHTechDoc.md b/AGHTechDoc.md index f4e4d4c7..48974e2b 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -217,10 +217,13 @@ Algorithm of an update by command: * Copy the current configuration file to the directory we unpacked new AGH to * Check configuration compatibility by executing `./AGH --check-config`. If this command fails, we won't be able to update. * Create `backup-vXXX` directory and copy the current configuration file there - * Stop all tasks, including DNS server, DHCP server, HTTP server + * Copy supporting files (README, LICENSE, etc.) to backup directory + * Copy supporting files from the update directory to the current directory * Move the current binary file to backup directory * Note: if power fails here, AGH won't be able to start at system boot. Administrator has to fix it manually * Move new binary file to the current directory + * Send response to UI + * Stop all tasks, including DNS server, DHCP server, HTTP server * If AGH is running as a service, use service control functionality to restart * If AGH is not running as a service, use the current process arguments to start a new process * Exit process @@ -250,6 +253,8 @@ Example of version.json data: "selfupdate_min_version": "v0.0" } +Server can only auto-update if the current version is equal or higher than `selfupdate_min_version`. + Request: GET /control/version.json diff --git a/client/src/actions/index.js b/client/src/actions/index.js index d683a3f5..3ceed2c4 100644 --- a/client/src/actions/index.js +++ b/client/src/actions/index.js @@ -160,7 +160,9 @@ export const getUpdateRequest = createAction('GET_UPDATE_REQUEST'); export const getUpdateFailure = createAction('GET_UPDATE_FAILURE'); export const getUpdateSuccess = createAction('GET_UPDATE_SUCCESS'); -export const getUpdate = () => async (dispatch) => { +export const getUpdate = () => async (dispatch, getState) => { + const { dnsVersion } = getState().dashboard; + dispatch(getUpdateRequest()); try { await apiClient.getUpdate(); @@ -185,9 +187,13 @@ export const getUpdate = () => async (dispatch) => { axios.get('control/status') .then((response) => { rmTimeout(timeout); - if (response) { - dispatch(getUpdateSuccess()); - window.location.reload(true); + if (response && response.status === 200) { + const responseVersion = response.data && response.data.version; + + if (dnsVersion !== responseVersion) { + dispatch(getUpdateSuccess()); + window.location.reload(true); + } } timeout = setRecursiveTimeout(CHECK_TIMEOUT, count += 1); }) diff --git a/control_update.go b/control_update.go index 987457da..6a819a09 100644 --- a/control_update.go +++ b/control_update.go @@ -1,7 +1,9 @@ package main import ( + "archive/tar" "archive/zip" + "compress/gzip" "encoding/json" "fmt" "io" @@ -34,13 +36,14 @@ func getVersionResp(data []byte) []byte { ret["new_version"], ok1 = versionJSON["version"].(string) ret["announcement"], ok2 = versionJSON["announcement"].(string) ret["announcement_url"], ok3 = versionJSON["announcement_url"].(string) - if !ok1 || !ok2 || !ok3 { + selfUpdateMinVersion, ok4 := versionJSON["selfupdate_min_version"].(string) + if !ok1 || !ok2 || !ok3 || !ok4 { log.Error("version.json: invalid data") return []byte{} } _, ok := versionJSON[fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH)] - if ok && ret["new_version"] != VersionString { + if ok && ret["new_version"] != VersionString && VersionString >= selfUpdateMinVersion { ret["can_autoupdate"] = true } @@ -146,14 +149,15 @@ func getUpdateInfo(jsonData []byte) (*updateInfo, error) { return nil, fmt.Errorf("No need to update") } + u.updateDir = filepath.Join(workDir, fmt.Sprintf("agh-update-%s", u.newVer)) + u.backupDir = filepath.Join(workDir, fmt.Sprintf("agh-backup-%s", VersionString)) + _, pkgFileName := filepath.Split(u.pkgURL) if len(pkgFileName) == 0 { return nil, fmt.Errorf("Invalid JSON") } - u.pkgName = filepath.Join(workDir, pkgFileName) + u.pkgName = filepath.Join(u.updateDir, pkgFileName) - u.updateDir = filepath.Join(workDir, fmt.Sprintf("update-%s", u.newVer)) - u.backupDir = filepath.Join(workDir, fmt.Sprintf("backup-%s", VersionString)) u.configName = config.getConfigFilename() u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome", "AdGuardHome.yaml") if strings.HasSuffix(pkgFileName, ".zip") { @@ -175,60 +179,162 @@ func getUpdateInfo(jsonData []byte) (*updateInfo, error) { } // Unpack all files from .zip file to the specified directory -func zipFileUnpack(zipfile, outdir string) error { +// Existing files are overwritten +// Return the list of files (not directories) written +func zipFileUnpack(zipfile, outdir string) ([]string, error) { + r, err := zip.OpenReader(zipfile) if err != nil { - return fmt.Errorf("zip.OpenReader(): %s", err) + return nil, fmt.Errorf("zip.OpenReader(): %s", err) } defer r.Close() + var files []string + var err2 error + var zr io.ReadCloser for _, zf := range r.File { - zr, err := zf.Open() + zr, err = zf.Open() if err != nil { - return fmt.Errorf("zip file Open(): %s", err) + err2 = fmt.Errorf("zip file Open(): %s", err) + break } + fi := zf.FileInfo() + if len(fi.Name()) == 0 { + continue + } + fn := filepath.Join(outdir, fi.Name()) if fi.IsDir() { err = os.Mkdir(fn, fi.Mode()) - if err != nil { - return fmt.Errorf("zip file Read(): %s", err) + if err != nil && !os.IsExist(err) { + err2 = fmt.Errorf("os.Mkdir(): %s", err) + break } + log.Tracef("created directory %s", fn) continue } f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) if err != nil { - zr.Close() - return fmt.Errorf("os.OpenFile(): %s", err) + err2 = fmt.Errorf("os.OpenFile(): %s", err) + break } _, err = io.Copy(f, zr) if err != nil { - zr.Close() - return fmt.Errorf("io.Copy(): %s", err) + f.Close() + err2 = fmt.Errorf("io.Copy(): %s", err) + break } - zr.Close() + f.Close() + + log.Tracef("created file %s", fn) + files = append(files, fi.Name()) } - return nil + + zr.Close() + return files, err2 } // Unpack all files from .tar.gz file to the specified directory -func targzFileUnpack(tarfile, outdir string) error { - cmd := exec.Command("tar", "zxf", tarfile, "-C", outdir) - log.Tracef("Unpacking: %v", cmd.Args) - _, err := cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - return fmt.Errorf("exec.Command() failed: %s", err) +// Existing files are overwritten +// Return the list of files (not directories) written +func targzFileUnpack(tarfile, outdir string) ([]string, error) { + + f, err := os.Open(tarfile) + if err != nil { + return nil, fmt.Errorf("os.Open(): %s", err) + } + defer f.Close() + + gzReader, err := gzip.NewReader(f) + if err != nil { + return nil, fmt.Errorf("gzip.NewReader(): %s", err) + } + + var files []string + var err2 error + tarReader := tar.NewReader(gzReader) + for { + header, err := tarReader.Next() + if err == io.EOF { + err2 = nil + break + } + if err != nil { + err2 = fmt.Errorf("tarReader.Next(): %s", err) + break + } + if len(header.Name) == 0 { + continue + } + + fn := filepath.Join(outdir, header.Name) + + if header.Typeflag == tar.TypeDir { + err = os.Mkdir(fn, os.FileMode(header.Mode&0777)) + if err != nil && !os.IsExist(err) { + err2 = fmt.Errorf("os.Mkdir(%s): %s", fn, err) + break + } + log.Tracef("created directory %s", fn) + continue + } else if header.Typeflag != tar.TypeReg { + log.Tracef("%s: unknown file type %d, skipping", header.Name, header.Typeflag) + continue + } + + f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode&0777)) + if err != nil { + err2 = fmt.Errorf("os.OpenFile(%s): %s", fn, err) + break + } + _, err = io.Copy(f, tarReader) + if err != nil { + f.Close() + err2 = fmt.Errorf("io.Copy(): %s", err) + break + } + f.Close() + + log.Tracef("created file %s", fn) + files = append(files, header.Name) + } + + gzReader.Close() + return files, err2 +} + +func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly, useDstNameOnly bool) error { + for _, f := range files { + _, name := filepath.Split(f) + if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" { + continue + } + + src := filepath.Join(srcdir, f) + if useSrcNameOnly { + src = filepath.Join(srcdir, name) + } + + dst := filepath.Join(dstdir, f) + if useDstNameOnly { + dst = filepath.Join(dstdir, name) + } + + err := copyFile(src, dst) + if err != nil && !os.IsNotExist(err) { + return err + } + + log.Tracef("Copied: %s -> %s", src, dst) } return nil } -// Perform an update procedure -func doUpdate(u *updateInfo) error { - log.Info("Updating from %s to %s. URL:%s Package:%s", - VersionString, u.newVer, u.pkgURL, u.pkgName) - +// Download package file and save it to disk +func getPackageFile(u *updateInfo) error { resp, err := client.Get(u.pkgURL) if err != nil { return fmt.Errorf("HTTP request failed: %s", err) @@ -248,19 +354,34 @@ func doUpdate(u *updateInfo) error { if err != nil { return fmt.Errorf("ioutil.WriteFile() failed: %s", err) } + return nil +} + +// Perform an update procedure +func doUpdate(u *updateInfo) error { + log.Info("Updating from %s to %s. URL:%s Package:%s", + VersionString, u.newVer, u.pkgURL, u.pkgName) + + _ = os.Mkdir(u.updateDir, 0755) + + var err error + err = getPackageFile(u) + if err != nil { + return err + } log.Tracef("Unpacking the package") - _ = os.Mkdir(u.updateDir, 0755) _, file := filepath.Split(u.pkgName) + var files []string if strings.HasSuffix(file, ".zip") { - err = zipFileUnpack(u.pkgName, u.updateDir) + files, err = zipFileUnpack(u.pkgName, u.updateDir) if err != nil { return fmt.Errorf("zipFileUnpack() failed: %s", err) } } else if strings.HasSuffix(file, ".tar.gz") { - err = targzFileUnpack(u.pkgName, u.updateDir) + files, err = targzFileUnpack(u.pkgName, u.updateDir) if err != nil { - return fmt.Errorf("zipFileUnpack() failed: %s", err) + return fmt.Errorf("targzFileUnpack() failed: %s", err) } } else { return fmt.Errorf("Unknown package extension") @@ -284,6 +405,20 @@ func doUpdate(u *updateInfo) error { return fmt.Errorf("copyFile() failed: %s", err) } + // ./README.md -> backup/README.md + err = copySupportingFiles(files, config.ourWorkingDir, u.backupDir, true, true) + if err != nil { + return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", + config.ourWorkingDir, u.backupDir, err) + } + + // update/[AdGuardHome/]README.md -> ./README.md + err = copySupportingFiles(files, u.updateDir, config.ourWorkingDir, false, true) + if err != nil { + return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", + u.updateDir, config.ourWorkingDir, err) + } + log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName) err = os.Rename(u.curBinName, u.bkpBinName) if err != nil { @@ -301,7 +436,7 @@ func doUpdate(u *updateInfo) error { log.Tracef("Renamed: %s -> %s", u.newBinName, u.curBinName) _ = os.Remove(u.pkgName) - // _ = os.RemoveAll(u.updateDir) + _ = os.RemoveAll(u.updateDir) return nil } diff --git a/control_update_test.go b/control_update_test.go index 346b98f3..3e142a78 100644 --- a/control_update_test.go +++ b/control_update_test.go @@ -1,3 +1,5 @@ +// +build ignore + package main import ( @@ -5,35 +7,48 @@ import ( "testing" ) -func testDoUpdate(t *testing.T) { +func TestDoUpdate(t *testing.T) { config.DNS.Port = 0 + config.ourWorkingDir = "." u := updateInfo{ pkgURL: "https://github.com/AdguardTeam/AdGuardHome/releases/download/v0.95/AdGuardHome_v0.95_linux_amd64.tar.gz", pkgName: "./AdGuardHome_v0.95_linux_amd64.tar.gz", newVer: "v0.95", - updateDir: "./update-v0.95", - backupDir: "./backup-v0.94", + updateDir: "./agh-update-v0.95", + backupDir: "./agh-backup-v0.94", configName: "./AdGuardHome.yaml", - updateConfigName: "./update-v0.95/AdGuardHome/AdGuardHome.yaml", + updateConfigName: "./agh-update-v0.95/AdGuardHome/AdGuardHome.yaml", curBinName: "./AdGuardHome", - bkpBinName: "./backup-v0.94/AdGuardHome", - newBinName: "./update-v0.95/AdGuardHome/AdGuardHome", + bkpBinName: "./agh-backup-v0.94/AdGuardHome", + newBinName: "./agh-update-v0.95/AdGuardHome/AdGuardHome", } e := doUpdate(&u) if e != nil { t.Fatalf("FAILED: %s", e) } os.RemoveAll(u.backupDir) - os.RemoveAll(u.updateDir) } -func testZipFileUnpack(t *testing.T) { - fn := "./dist/AdGuardHome_v0.95_Windows_amd64.zip" +func TestTargzFileUnpack(t *testing.T) { + fn := "./dist/AdGuardHome_v0.95_linux_amd64.tar.gz" outdir := "./test-unpack" _ = os.Mkdir(outdir, 0755) - e := zipFileUnpack(fn, outdir) + files, e := targzFileUnpack(fn, outdir) if e != nil { t.Fatalf("FAILED: %s", e) } + t.Logf("%v", files) + os.RemoveAll(outdir) +} + +func TestZipFileUnpack(t *testing.T) { + fn := "./dist/AdGuardHome_v0.95_Windows_amd64.zip" + outdir := "./test-unpack" + _ = os.Mkdir(outdir, 0755) + files, e := zipFileUnpack(fn, outdir) + if e != nil { + t.Fatalf("FAILED: %s", e) + } + t.Logf("%v", files) os.RemoveAll(outdir) } diff --git a/helpers.go b/helpers.go index afd4e34b..a32e4745 100644 --- a/helpers.go +++ b/helpers.go @@ -341,7 +341,7 @@ func customDialContext(ctx context.Context, network, addr string) (net.Conn, err var firstErr error firstErr = nil for _, a := range addrs { - addr = fmt.Sprintf("%s:%s", a.String(), port) + addr = net.JoinHostPort(a.String(), port) con, err := dialer.DialContext(ctx, network, addr) if err != nil { if firstErr == nil { diff --git a/release.sh b/release.sh index 9fbb86da..abf87ac5 100755 --- a/release.sh +++ b/release.sh @@ -20,7 +20,7 @@ f() { mkdir -p dist/AdGuardHome cp -pv {AdGuardHome,LICENSE.txt,README.md} dist/AdGuardHome/ pushd dist - tar zcvf AdGuardHome_"$GOOS"_"$GOARCH".tar.gz AdGuardHome/{AdGuardHome,LICENSE.txt,README.md} + tar zcvf AdGuardHome_"$GOOS"_"$GOARCH".tar.gz AdGuardHome/ popd rm -rf dist/AdguardHome fi