diff --git a/home/config.go b/home/config.go index 05831532..42e08cc2 100644 --- a/home/config.go +++ b/home/config.go @@ -5,7 +5,6 @@ import ( "os" "path/filepath" "sync" - "time" "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" @@ -40,10 +39,6 @@ type configuration struct { // It's reset after config is parsed fileData []byte - // cached version.json to avoid hammering github.io for each page reload - versionCheckJSON []byte - versionCheckLastTime time.Time - BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server Users []User `yaml:"users"` // Users that can access HTTP server diff --git a/home/control_update.go b/home/control_update.go index 566b8327..fb160900 100644 --- a/home/control_update.go +++ b/home/control_update.go @@ -1,13 +1,7 @@ package home import ( - "archive/tar" - "archive/zip" - "compress/gzip" "encoding/json" - "fmt" - "io" - "io/ioutil" "net/http" "os" "os/exec" @@ -15,25 +9,12 @@ import ( "runtime" "strings" "syscall" - "time" + "github.com/AdguardTeam/AdGuardHome/update" "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/golibs/log" ) -type updateInfo struct { - pkgURL string // URL for the new package - pkgName string // Full path to package file - newVer string // New version string - updateDir string // Full path to the directory containing unpacked files from the new package - backupDir string // Full path to backup directory - configName string // Full path to the current configuration file - updateConfigName string // Full path to the configuration file to check by the new binary - curBinName string // Full path to the current executable file - bkpBinName string // Full path to the current executable file in backup directory - newBinName string // Full path to the new executable file -} - type getVersionJSONRequest struct { RecheckNow bool `json:"recheck_now"` } @@ -58,28 +39,11 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { } } - now := time.Now() - if !req.RecheckNow { - Context.controlLock.Lock() - cached := now.Sub(config.versionCheckLastTime) <= versionCheckPeriod && len(config.versionCheckJSON) != 0 - data := config.versionCheckJSON - Context.controlLock.Unlock() - - if cached { - log.Tracef("Returning cached data") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(getVersionResp(data)) - return - } - } - - var resp *http.Response + var info update.VersionInfo for i := 0; i != 3; i++ { - log.Tracef("Downloading data from %s", versionCheckURL) - resp, err = Context.client.Get(versionCheckURL) - if resp != nil && resp.Body != nil { - defer resp.Body.Close() - } + Context.controlLock.Lock() + info, err = Context.updater.GetVersionResponse(req.RecheckNow) + Context.controlLock.Unlock() if err != nil && strings.HasSuffix(err.Error(), "i/o timeout") { // This case may happen while we're restarting DNS server // https://github.com/AdguardTeam/AdGuardHome/issues/934 @@ -92,39 +56,21 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { return } - // read the body entirely - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - httpError(w, http.StatusBadGateway, "Couldn't read response body from %s: %s", versionCheckURL, err) - return - } - - Context.controlLock.Lock() - config.versionCheckLastTime = now - config.versionCheckJSON = body - Context.controlLock.Unlock() - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(getVersionResp(body)) + _, err = w.Write(getVersionResp(info)) if err != nil { httpError(w, http.StatusInternalServerError, "Couldn't write body: %s", err) } } // Perform an update procedure to the latest available version -func handleUpdate(w http.ResponseWriter, r *http.Request) { - if len(config.versionCheckJSON) == 0 { +func handleUpdate(w http.ResponseWriter, _ *http.Request) { + if len(Context.updater.NewVersion) == 0 { httpError(w, http.StatusBadRequest, "/update request isn't allowed now") return } - u, err := getUpdateInfo(config.versionCheckJSON) - if err != nil { - httpError(w, http.StatusInternalServerError, "%s", err) - return - } - - err = doUpdate(u) + err := Context.updater.DoUpdate() if err != nil { httpError(w, http.StatusInternalServerError, "%s", err) return @@ -135,33 +81,18 @@ func handleUpdate(w http.ResponseWriter, r *http.Request) { f.Flush() } - go finishUpdate(u) + go finishUpdate() } // Convert version.json data to our JSON response -func getVersionResp(data []byte) []byte { - versionJSON := make(map[string]interface{}) - err := json.Unmarshal(data, &versionJSON) - if err != nil { - log.Error("version.json: %s", err) - return []byte{} - } - +func getVersionResp(info update.VersionInfo) []byte { ret := make(map[string]interface{}) ret["can_autoupdate"] = false + ret["new_version"] = info.NewVersion + ret["announcement"] = info.Announcement + ret["announcement_url"] = info.AnnouncementURL - var ok1, ok2, ok3 bool - ret["new_version"], ok1 = versionJSON["version"].(string) - ret["announcement"], ok2 = versionJSON["announcement"].(string) - ret["announcement_url"], ok3 = versionJSON["announcement_url"].(string) - selfUpdateMinVersion, ok4 := versionJSON["selfupdate_min_version"].(string) - if !ok1 || !ok2 || !ok3 || !ok4 { - log.Error("version.json: invalid data") - return []byte{} - } - - _, ok := getDownloadURL(versionJSON) - if ok && ret["new_version"] != versionString && versionString >= selfUpdateMinVersion { + if info.CanAutoUpdate { canUpdate := true tlsConf := tlsConfigSettings{} @@ -185,373 +116,18 @@ func getVersionResp(data []byte) []byte { return d } -// Copy file on disk -func copyFile(src, dst string) error { - d, e := ioutil.ReadFile(src) - if e != nil { - return e - } - e = ioutil.WriteFile(dst, d, 0644) - if e != nil { - return e - } - return nil -} - -// Fill in updateInfo object -func getUpdateInfo(jsonData []byte) (*updateInfo, error) { - var u updateInfo - - workDir := Context.workDir - - versionJSON := make(map[string]interface{}) - err := json.Unmarshal(jsonData, &versionJSON) - if err != nil { - return nil, fmt.Errorf("JSON parse: %s", err) - } - - pkgURL, ok := getDownloadURL(versionJSON) - if !ok { - return nil, fmt.Errorf("failed to get download URL") - } - - u.pkgURL = pkgURL - u.newVer = versionJSON["version"].(string) - if len(u.pkgURL) == 0 || len(u.newVer) == 0 { - return nil, fmt.Errorf("invalid JSON") - } - - if u.newVer == versionString { - return nil, fmt.Errorf("no need to update") - } - - u.updateDir = filepath.Join(workDir, fmt.Sprintf("agh-update-%s", u.newVer)) - u.backupDir = filepath.Join(workDir, "agh-backup") - - _, pkgFileName := filepath.Split(u.pkgURL) - if len(pkgFileName) == 0 { - return nil, fmt.Errorf("invalid JSON") - } - u.pkgName = filepath.Join(u.updateDir, pkgFileName) - - u.configName = config.getConfigFilename() - u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome", "AdGuardHome.yaml") - if strings.HasSuffix(pkgFileName, ".zip") { - u.updateConfigName = filepath.Join(u.updateDir, "AdGuardHome.yaml") - } - - binName := "AdGuardHome" - if runtime.GOOS == "windows" { - binName = "AdGuardHome.exe" - } - u.curBinName = filepath.Join(workDir, binName) - if !util.FileExists(u.curBinName) { - return nil, fmt.Errorf("executable file %s doesn't exist", u.curBinName) - } - u.bkpBinName = filepath.Join(u.backupDir, binName) - u.newBinName = filepath.Join(u.updateDir, "AdGuardHome", binName) - if strings.HasSuffix(pkgFileName, ".zip") { - u.newBinName = filepath.Join(u.updateDir, binName) - } - - return &u, nil -} - -// getDownloadURL - gets download URL for the current GOOS/GOARCH -// returns -func getDownloadURL(json map[string]interface{}) (string, bool) { - var key string - - if runtime.GOARCH == "arm" && ARMVersion != "" { - // the key is: - // download_linux_armv5 for ARMv5 - // download_linux_armv6 for ARMv6 - // download_linux_armv7 for ARMv7 - key = fmt.Sprintf("download_%s_%sv%s", runtime.GOOS, runtime.GOARCH, ARMVersion) - } - - u, ok := json[key] - if !ok { - // the key is download_linux_arm or download_linux_arm64 for regular ARM versions - key = fmt.Sprintf("download_%s_%s", runtime.GOOS, runtime.GOARCH) - u, ok = json[key] - } - - if !ok { - return "", false - } - - return u.(string), true -} - -// Unpack all files from .zip file to the specified directory -// Existing files are overwritten -// Return the list of files (not directories) written -func zipFileUnpack(zipfile, outdir string) ([]string, error) { - - r, err := zip.OpenReader(zipfile) - if err != nil { - return nil, fmt.Errorf("zip.OpenReader(): %s", err) - } - defer r.Close() - - var files []string - var err2 error - var zr io.ReadCloser - for _, zf := range r.File { - zr, err = zf.Open() - if err != nil { - err2 = fmt.Errorf("zip file Open(): %s", err) - break - } - - fi := zf.FileInfo() - if len(fi.Name()) == 0 { - continue - } - - fn := filepath.Join(outdir, fi.Name()) - - if fi.IsDir() { - err = os.Mkdir(fn, fi.Mode()) - if err != nil && !os.IsExist(err) { - err2 = fmt.Errorf("os.Mkdir(): %s", err) - break - } - log.Tracef("created directory %s", fn) - continue - } - - f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) - if err != nil { - err2 = fmt.Errorf("os.OpenFile(): %s", err) - break - } - _, err = io.Copy(f, zr) - if err != nil { - f.Close() - err2 = fmt.Errorf("io.Copy(): %s", err) - break - } - f.Close() - - log.Tracef("created file %s", fn) - files = append(files, fi.Name()) - } - - zr.Close() - return files, err2 -} - -// Unpack all files from .tar.gz file to the specified directory -// Existing files are overwritten -// Return the list of files (not directories) written -func targzFileUnpack(tarfile, outdir string) ([]string, error) { - - f, err := os.Open(tarfile) - if err != nil { - return nil, fmt.Errorf("os.Open(): %s", err) - } - defer f.Close() - - gzReader, err := gzip.NewReader(f) - if err != nil { - return nil, fmt.Errorf("gzip.NewReader(): %s", err) - } - - var files []string - var err2 error - tarReader := tar.NewReader(gzReader) - for { - header, err := tarReader.Next() - if err == io.EOF { - err2 = nil - break - } - if err != nil { - err2 = fmt.Errorf("tarReader.Next(): %s", err) - break - } - if len(header.Name) == 0 { - continue - } - - fn := filepath.Join(outdir, header.Name) - - if header.Typeflag == tar.TypeDir { - err = os.Mkdir(fn, os.FileMode(header.Mode&0777)) - if err != nil && !os.IsExist(err) { - err2 = fmt.Errorf("os.Mkdir(%s): %s", fn, err) - break - } - log.Tracef("created directory %s", fn) - continue - } else if header.Typeflag != tar.TypeReg { - log.Tracef("%s: unknown file type %d, skipping", header.Name, header.Typeflag) - continue - } - - f, err := os.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode&0777)) - if err != nil { - err2 = fmt.Errorf("os.OpenFile(%s): %s", fn, err) - break - } - _, err = io.Copy(f, tarReader) - if err != nil { - f.Close() - err2 = fmt.Errorf("io.Copy(): %s", err) - break - } - f.Close() - - log.Tracef("created file %s", fn) - files = append(files, header.Name) - } - - gzReader.Close() - return files, err2 -} - -func copySupportingFiles(files []string, srcdir, dstdir string, useSrcNameOnly, useDstNameOnly bool) error { - for _, f := range files { - _, name := filepath.Split(f) - if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" { - continue - } - - src := filepath.Join(srcdir, f) - if useSrcNameOnly { - src = filepath.Join(srcdir, name) - } - - dst := filepath.Join(dstdir, f) - if useDstNameOnly { - dst = filepath.Join(dstdir, name) - } - - err := copyFile(src, dst) - if err != nil && !os.IsNotExist(err) { - return err - } - - log.Tracef("Copied: %s -> %s", src, dst) - } - return nil -} - -// Download package file and save it to disk -func getPackageFile(u *updateInfo) error { - resp, err := Context.client.Get(u.pkgURL) - if err != nil { - return fmt.Errorf("HTTP request failed: %s", err) - } - if resp != nil && resp.Body != nil { - defer resp.Body.Close() - } - - log.Tracef("Reading HTTP body") - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("ioutil.ReadAll() failed: %s", err) - } - - log.Tracef("Saving package to file") - err = ioutil.WriteFile(u.pkgName, body, 0644) - if err != nil { - return fmt.Errorf("ioutil.WriteFile() failed: %s", err) - } - return nil -} - -// Perform an update procedure -func doUpdate(u *updateInfo) error { - log.Info("Updating from %s to %s. URL:%s Package:%s", - versionString, u.newVer, u.pkgURL, u.pkgName) - - _ = os.Mkdir(u.updateDir, 0755) - - var err error - err = getPackageFile(u) - if err != nil { - return err - } - - log.Tracef("Unpacking the package") - _, file := filepath.Split(u.pkgName) - var files []string - if strings.HasSuffix(file, ".zip") { - files, err = zipFileUnpack(u.pkgName, u.updateDir) - if err != nil { - return fmt.Errorf("zipFileUnpack() failed: %s", err) - } - } else if strings.HasSuffix(file, ".tar.gz") { - files, err = targzFileUnpack(u.pkgName, u.updateDir) - if err != nil { - return fmt.Errorf("targzFileUnpack() failed: %s", err) - } - } else { - return fmt.Errorf("unknown package extension") - } - - log.Tracef("Checking configuration") - err = copyFile(u.configName, u.updateConfigName) - if err != nil { - return fmt.Errorf("copyFile() failed: %s", err) - } - cmd := exec.Command(u.newBinName, "--check-config") - err = cmd.Run() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode()) - } - - log.Tracef("Backing up the current configuration") - _ = os.Mkdir(u.backupDir, 0755) - err = copyFile(u.configName, filepath.Join(u.backupDir, "AdGuardHome.yaml")) - if err != nil { - return fmt.Errorf("copyFile() failed: %s", err) - } - - // ./README.md -> backup/README.md - err = copySupportingFiles(files, Context.workDir, u.backupDir, true, true) - if err != nil { - return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", - Context.workDir, u.backupDir, err) - } - - // update/[AdGuardHome/]README.md -> ./README.md - err = copySupportingFiles(files, u.updateDir, Context.workDir, false, true) - if err != nil { - return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", - u.updateDir, Context.workDir, err) - } - - log.Tracef("Renaming: %s -> %s", u.curBinName, u.bkpBinName) - err = os.Rename(u.curBinName, u.bkpBinName) - if err != nil { - return err - } - if runtime.GOOS == "windows" { - // rename fails with "File in use" error - err = copyFile(u.newBinName, u.curBinName) - } else { - err = os.Rename(u.newBinName, u.curBinName) - } - if err != nil { - return err - } - log.Tracef("Renamed: %s -> %s", u.newBinName, u.curBinName) - - _ = os.Remove(u.pkgName) - _ = os.RemoveAll(u.updateDir) - return nil -} - // Complete an update procedure -func finishUpdate(u *updateInfo) { +func finishUpdate() { log.Info("Stopping all tasks") cleanup() cleanupAlways() + exeName := "AdGuardHome" + if runtime.GOOS == "windows" { + exeName = "AdGuardHome.exe" + } + curBinName := filepath.Join(Context.workDir, exeName) + if runtime.GOOS == "windows" { if Context.runningAsService { // Note: @@ -565,7 +141,7 @@ func finishUpdate(u *updateInfo) { os.Exit(0) } - cmd := exec.Command(u.curBinName, os.Args[1:]...) + cmd := exec.Command(curBinName, os.Args[1:]...) log.Info("Restarting: %v", cmd.Args) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout @@ -577,7 +153,7 @@ func finishUpdate(u *updateInfo) { os.Exit(0) } else { log.Info("Restarting: %v", os.Args) - err := syscall.Exec(u.curBinName, os.Args, os.Environ()) + err := syscall.Exec(curBinName, os.Args, os.Environ()) if err != nil { log.Fatalf("syscall.Exec() failed: %s", err) } diff --git a/home/home.go b/home/home.go index 79cf0843..e297173c 100644 --- a/home/home.go +++ b/home/home.go @@ -20,6 +20,7 @@ import ( "gopkg.in/natefinch/lumberjack.v2" + "github.com/AdguardTeam/AdGuardHome/update" "github.com/AdguardTeam/AdGuardHome/util" "github.com/joomcode/errorx" @@ -47,8 +48,6 @@ var ( ARMVersion = "" ) -const versionCheckPeriod = time.Hour * 8 - // Global context type homeContext struct { // Modules @@ -67,6 +66,7 @@ type homeContext struct { web *Web // Web (HTTP, HTTPS) module tls *TLSMod // TLS module autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files + updater *update.Updater // Runtime properties // -- @@ -225,6 +225,18 @@ func run(args options) { os.Exit(1) } Context.autoHosts.Init("") + + Context.updater = update.NewUpdater(update.Config{ + Client: Context.client, + WorkDir: Context.workDir, + VersionURL: versionCheckURL, + VersionString: versionString, + OS: runtime.GOOS, + Arch: runtime.GOARCH, + ARMVersion: ARMVersion, + ConfigName: config.getConfigFilename(), + }) + Context.clients.Init(config.Clients, Context.dhcpServer, &Context.autoHosts) config.Clients = nil diff --git a/update/check.go b/update/check.go new file mode 100644 index 00000000..09755d65 --- /dev/null +++ b/update/check.go @@ -0,0 +1,103 @@ +package update + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "time" +) + +const versionCheckPeriod = 8 * 60 * 60 + +// VersionInfo - VersionInfo +type VersionInfo struct { + NewVersion string // New version string + Announcement string // Announcement text + AnnouncementURL string // Announcement URL + SelfUpdateMinVersion string // Min version starting with which we can auto-update + CanAutoUpdate bool // If true - we can auto-update +} + +// GetVersionResponse - downloads version.json (if needed) and deserializes it +func (u *Updater) GetVersionResponse(forceRecheck bool) (VersionInfo, error) { + if !forceRecheck && + u.versionCheckLastTime.Unix()+versionCheckPeriod > time.Now().Unix() { + return u.parseVersionResponse(u.versionJSON) + } + + resp, err := u.Client.Get(u.VersionURL) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + if err != nil { + return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %s", u.VersionURL, err) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %s", u.VersionURL, err) + } + + u.versionJSON = body + u.versionCheckLastTime = time.Now() + + return u.parseVersionResponse(body) +} + +func (u *Updater) parseVersionResponse(data []byte) (VersionInfo, error) { + info := VersionInfo{} + versionJSON := make(map[string]interface{}) + err := json.Unmarshal(data, &versionJSON) + if err != nil { + return info, fmt.Errorf("version.json: %s", err) + } + + var ok1, ok2, ok3, ok4 bool + info.NewVersion, ok1 = versionJSON["version"].(string) + info.Announcement, ok2 = versionJSON["announcement"].(string) + info.AnnouncementURL, ok3 = versionJSON["announcement_url"].(string) + info.SelfUpdateMinVersion, ok4 = versionJSON["selfupdate_min_version"].(string) + if !ok1 || !ok2 || !ok3 || !ok4 { + return info, fmt.Errorf("version.json: invalid data") + } + + packageURL, ok := u.getDownloadURL(versionJSON) + + if ok && + info.NewVersion != u.VersionString && + u.VersionString >= info.SelfUpdateMinVersion { + info.CanAutoUpdate = true + } + + u.NewVersion = info.NewVersion + u.PackageURL = packageURL + + return info, nil +} + +// Get download URL for the current GOOS/GOARCH/ARMVersion +func (u *Updater) getDownloadURL(json map[string]interface{}) (string, bool) { + var key string + + if u.Arch == "arm" && u.ARMVersion != "" { + // the key is: + // download_linux_armv5 for ARMv5 + // download_linux_armv6 for ARMv6 + // download_linux_armv7 for ARMv7 + key = fmt.Sprintf("download_%s_%sv%s", u.OS, u.Arch, u.ARMVersion) + } + + val, ok := json[key] + if !ok { + // the key is download_linux_arm or download_linux_arm64 for regular ARM versions + key = fmt.Sprintf("download_%s_%s", u.OS, u.Arch) + val, ok = json[key] + } + + if !ok { + return "", false + } + + return val.(string), true +} diff --git a/update/test/AdGuardHome.tar.gz b/update/test/AdGuardHome.tar.gz new file mode 100644 index 00000000..b292c6d9 Binary files /dev/null and b/update/test/AdGuardHome.tar.gz differ diff --git a/update/test/AdGuardHome.zip b/update/test/AdGuardHome.zip new file mode 100644 index 00000000..c984a347 Binary files /dev/null and b/update/test/AdGuardHome.zip differ diff --git a/update/update_test.go b/update/update_test.go new file mode 100644 index 00000000..7481ba3f --- /dev/null +++ b/update/update_test.go @@ -0,0 +1,210 @@ +package update + +import ( + "fmt" + "io/ioutil" + "net" + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func startHTTPServer(data string) (net.Listener, uint16) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(data)) + }) + + listener, err := net.Listen("tcp", ":0") + if err != nil { + panic(err) + } + + go func() { _ = http.Serve(listener, mux) }() + return listener, uint16(listener.Addr().(*net.TCPAddr).Port) +} + +func TestUpdateGetVersion(t *testing.T) { + const jsonData = `{ + "version": "v0.103.0-beta2", + "announcement": "AdGuard Home v0.103.0-beta2 is now available!", + "announcement_url": "https://github.com/AdguardTeam/AdGuardHome/releases", + "selfupdate_min_version": "v0.0", + "download_windows_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_amd64.zip", + "download_windows_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_windows_386.zip", + "download_darwin_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_amd64.zip", + "download_darwin_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_darwin_386.zip", + "download_linux_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_amd64.tar.gz", + "download_linux_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_386.tar.gz", + "download_linux_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz", + "download_linux_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv5.tar.gz", + "download_linux_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv6.tar.gz", + "download_linux_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_armv7.tar.gz", + "download_linux_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_arm64.tar.gz", + "download_linux_mips": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips_softfloat.tar.gz", + "download_linux_mipsle": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mipsle_softfloat.tar.gz", + "download_linux_mips64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64_softfloat.tar.gz", + "download_linux_mips64le": "https://static.adguard.com/adguardhome/beta/AdGuardHome_linux_mips64le_softfloat.tar.gz", + "download_freebsd_386": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_386.tar.gz", + "download_freebsd_amd64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_amd64.tar.gz", + "download_freebsd_arm": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz", + "download_freebsd_armv5": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv5.tar.gz", + "download_freebsd_armv6": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv6.tar.gz", + "download_freebsd_armv7": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_armv7.tar.gz", + "download_freebsd_arm64": "https://static.adguard.com/adguardhome/beta/AdGuardHome_freebsd_arm64.tar.gz" +}` + + l, lport := startHTTPServer(jsonData) + defer func() { _ = l.Close() }() + + u := NewUpdater(Config{ + Client: &http.Client{}, + VersionURL: fmt.Sprintf("http://127.0.0.1:%d/", lport), + OS: "linux", + Arch: "arm", + VersionString: "v0.103.0-beta1", + }) + + info, err := u.GetVersionResponse(false) + assert.Nil(t, err) + assert.Equal(t, "v0.103.0-beta2", info.NewVersion) + assert.Equal(t, "AdGuard Home v0.103.0-beta2 is now available!", info.Announcement) + assert.Equal(t, "https://github.com/AdguardTeam/AdGuardHome/releases", info.AnnouncementURL) + assert.Equal(t, "v0.0", info.SelfUpdateMinVersion) + assert.True(t, info.CanAutoUpdate) + + _ = l.Close() + + // check cached + _, err = u.GetVersionResponse(false) + assert.Nil(t, err) +} + +func TestUpdate(t *testing.T) { + _ = os.Mkdir("aghtest", 0755) + defer func() { + _ = os.RemoveAll("aghtest") + }() + + // create "current" files + assert.Nil(t, ioutil.WriteFile("aghtest/AdGuardHome", []byte("AdGuardHome"), 0755)) + assert.Nil(t, ioutil.WriteFile("aghtest/README.md", []byte("README.md"), 0644)) + assert.Nil(t, ioutil.WriteFile("aghtest/LICENSE.txt", []byte("LICENSE.txt"), 0644)) + assert.Nil(t, ioutil.WriteFile("aghtest/AdGuardHome.yaml", []byte("AdGuardHome.yaml"), 0644)) + + // start server for returning package file + pkgData, err := ioutil.ReadFile("test/AdGuardHome.tar.gz") + assert.Nil(t, err) + l, lport := startHTTPServer(string(pkgData)) + defer func() { _ = l.Close() }() + + u := NewUpdater(Config{ + Client: &http.Client{}, + PackageURL: fmt.Sprintf("http://127.0.0.1:%d/AdGuardHome.tar.gz", lport), + VersionString: "v0.103.0", + NewVersion: "v0.103.1", + ConfigName: "aghtest/AdGuardHome.yaml", + WorkDir: "aghtest", + }) + + assert.Nil(t, u.prepare()) + u.currentExeName = "aghtest/AdGuardHome" + assert.Nil(t, u.downloadPackageFile(u.PackageURL, u.packageName)) + assert.Nil(t, u.unpack()) + // assert.Nil(t, u.check()) + assert.Nil(t, u.backup()) + assert.Nil(t, u.replace()) + u.clean() + + // check backup files + d, err := ioutil.ReadFile("aghtest/agh-backup/AdGuardHome.yaml") + assert.Nil(t, err) + assert.Equal(t, "AdGuardHome.yaml", string(d)) + + d, err = ioutil.ReadFile("aghtest/agh-backup/AdGuardHome") + assert.Nil(t, err) + assert.Equal(t, "AdGuardHome", string(d)) + + // check updated files + d, err = ioutil.ReadFile("aghtest/AdGuardHome") + assert.Nil(t, err) + assert.Equal(t, "1", string(d)) + + d, err = ioutil.ReadFile("aghtest/README.md") + assert.Nil(t, err) + assert.Equal(t, "2", string(d)) + + d, err = ioutil.ReadFile("aghtest/LICENSE.txt") + assert.Nil(t, err) + assert.Equal(t, "3", string(d)) + + d, err = ioutil.ReadFile("aghtest/AdGuardHome.yaml") + assert.Nil(t, err) + assert.Equal(t, "AdGuardHome.yaml", string(d)) +} + +func TestUpdateWindows(t *testing.T) { + _ = os.Mkdir("aghtest", 0755) + defer func() { + _ = os.RemoveAll("aghtest") + }() + + // create "current" files + assert.Nil(t, ioutil.WriteFile("aghtest/AdGuardHome.exe", []byte("AdGuardHome.exe"), 0755)) + assert.Nil(t, ioutil.WriteFile("aghtest/README.md", []byte("README.md"), 0644)) + assert.Nil(t, ioutil.WriteFile("aghtest/LICENSE.txt", []byte("LICENSE.txt"), 0644)) + assert.Nil(t, ioutil.WriteFile("aghtest/AdGuardHome.yaml", []byte("AdGuardHome.yaml"), 0644)) + + // start server for returning package file + pkgData, err := ioutil.ReadFile("test/AdGuardHome.zip") + assert.Nil(t, err) + l, lport := startHTTPServer(string(pkgData)) + defer func() { _ = l.Close() }() + + u := NewUpdater(Config{ + WorkDir: "aghtest", + Client: &http.Client{}, + PackageURL: fmt.Sprintf("http://127.0.0.1:%d/AdGuardHome.zip", lport), + OS: "windows", + VersionString: "v0.103.0", + NewVersion: "v0.103.1", + ConfigName: "aghtest/AdGuardHome.yaml", + }) + + assert.Nil(t, u.prepare()) + u.currentExeName = "aghtest/AdGuardHome.exe" + assert.Nil(t, u.downloadPackageFile(u.PackageURL, u.packageName)) + assert.Nil(t, u.unpack()) + // assert.Nil(t, u.check()) + assert.Nil(t, u.backup()) + assert.Nil(t, u.replace()) + u.clean() + + // check backup files + d, err := ioutil.ReadFile("aghtest/agh-backup/AdGuardHome.yaml") + assert.Nil(t, err) + assert.Equal(t, "AdGuardHome.yaml", string(d)) + + d, err = ioutil.ReadFile("aghtest/agh-backup/AdGuardHome.exe") + assert.Nil(t, err) + assert.Equal(t, "AdGuardHome.exe", string(d)) + + // check updated files + d, err = ioutil.ReadFile("aghtest/AdGuardHome.exe") + assert.Nil(t, err) + assert.Equal(t, "1", string(d)) + + d, err = ioutil.ReadFile("aghtest/README.md") + assert.Nil(t, err) + assert.Equal(t, "2", string(d)) + + d, err = ioutil.ReadFile("aghtest/LICENSE.txt") + assert.Nil(t, err) + assert.Equal(t, "3", string(d)) + + d, err = ioutil.ReadFile("aghtest/AdGuardHome.yaml") + assert.Nil(t, err) + assert.Equal(t, "AdGuardHome.yaml", string(d)) +} diff --git a/update/updater.go b/update/updater.go new file mode 100644 index 00000000..469544dc --- /dev/null +++ b/update/updater.go @@ -0,0 +1,418 @@ +package update + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/AdguardTeam/AdGuardHome/util" + "github.com/AdguardTeam/golibs/log" +) + +// Updater - Updater +type Updater struct { + Config // Updater configuration + + currentExeName string // current binary executable + updateDir string // "work_dir/agh-update-v0.103.0" + packageName string // "work_dir/agh-update-v0.103.0/pkg_name.tar.gz" + backupDir string // "work_dir/agh-backup" + backupExeName string // "work_dir/agh-backup/AdGuardHome[.exe]" + updateExeName string // "work_dir/agh-update-v0.103.0/AdGuardHome[.exe]" + unpackedFiles []string + + // cached version.json to avoid hammering github.io for each page reload + versionJSON []byte + versionCheckLastTime time.Time +} + +// Config - updater config +type Config struct { + Client *http.Client + + VersionURL string // version.json URL + VersionString string + OS string // GOOS + Arch string // GOARCH + ARMVersion string // ARM version, e.g. "6" + NewVersion string // VersionInfo.NewVersion + PackageURL string // VersionInfo.PackageURL + ConfigName string // current config file ".../AdGuardHome.yaml" + WorkDir string // updater work dir (where backup/upd dirs will be created) +} + +// NewUpdater - creates a new instance of the Updater +func NewUpdater(cfg Config) *Updater { + return &Updater{ + Config: cfg, + } +} + +// DoUpdate - conducts the auto-update +// 1. Downloads the update file +// 2. Unpacks it and checks the contents +// 3. Backups the current version and configuration +// 4. Replaces the old files +func (u *Updater) DoUpdate() error { + err := u.prepare() + if err != nil { + return err + } + + defer u.clean() + + err = u.downloadPackageFile(u.PackageURL, u.packageName) + if err != nil { + return err + } + + err = u.unpack() + if err != nil { + return err + } + + err = u.check() + if err != nil { + u.clean() + return err + } + + err = u.backup() + if err != nil { + return err + } + + err = u.replace() + if err != nil { + return err + } + + return nil +} + +func (u *Updater) prepare() error { + u.updateDir = filepath.Join(u.WorkDir, fmt.Sprintf("agh-update-%s", u.NewVersion)) + + _, pkgNameOnly := filepath.Split(u.PackageURL) + if len(pkgNameOnly) == 0 { + return fmt.Errorf("invalid PackageURL") + } + u.packageName = filepath.Join(u.updateDir, pkgNameOnly) + u.backupDir = filepath.Join(u.WorkDir, "agh-backup") + + exeName := "AdGuardHome" + if u.OS == "windows" { + exeName = "AdGuardHome.exe" + } + + u.backupExeName = filepath.Join(u.backupDir, exeName) + u.updateExeName = filepath.Join(u.updateDir, exeName) + + log.Info("Updating from %s to %s. URL:%s", + u.VersionString, u.NewVersion, u.PackageURL) + + // If the binary file isn't found in working directory, we won't be able to auto-update + // Getting the full path to the current binary file on UNIX and checking write permissions + // is more difficult. + u.currentExeName = filepath.Join(u.WorkDir, exeName) + if !util.FileExists(u.currentExeName) { + return fmt.Errorf("executable file %s doesn't exist", u.currentExeName) + } + return nil +} + +func (u *Updater) unpack() error { + var err error + _, pkgNameOnly := filepath.Split(u.PackageURL) + + log.Debug("updater: unpacking the package") + if strings.HasSuffix(pkgNameOnly, ".zip") { + u.unpackedFiles, err = zipFileUnpack(u.packageName, u.updateDir) + if err != nil { + return fmt.Errorf(".zip unpack failed: %s", err) + } + + } else if strings.HasSuffix(pkgNameOnly, ".tar.gz") { + u.unpackedFiles, err = tarGzFileUnpack(u.packageName, u.updateDir) + if err != nil { + return fmt.Errorf(".tar.gz unpack failed: %s", err) + } + + } else { + return fmt.Errorf("unknown package extension") + } + + return nil +} + +func (u *Updater) check() error { + log.Debug("updater: checking configuration") + err := copyFile(u.ConfigName, filepath.Join(u.updateDir, "AdGuardHome.yaml")) + if err != nil { + return fmt.Errorf("copyFile() failed: %s", err) + } + cmd := exec.Command(u.updateExeName, "--check-config") + err = cmd.Run() + if err != nil || cmd.ProcessState.ExitCode() != 0 { + return fmt.Errorf("exec.Command(): %s %d", err, cmd.ProcessState.ExitCode()) + } + return nil +} + +func (u *Updater) backup() error { + log.Debug("updater: backing up the current configuration") + _ = os.Mkdir(u.backupDir, 0755) + err := copyFile(u.ConfigName, filepath.Join(u.backupDir, "AdGuardHome.yaml")) + if err != nil { + return fmt.Errorf("copyFile() failed: %s", err) + } + + // workdir/README.md -> backup/README.md + err = copySupportingFiles(u.unpackedFiles, u.WorkDir, u.backupDir) + if err != nil { + return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", + u.WorkDir, u.backupDir, err) + } + + return nil +} + +func (u *Updater) replace() error { + // update/README.md -> workdir/README.md + err := copySupportingFiles(u.unpackedFiles, u.updateDir, u.WorkDir) + if err != nil { + return fmt.Errorf("copySupportingFiles(%s, %s) failed: %s", + u.updateDir, u.WorkDir, err) + } + + log.Debug("updater: renaming: %s -> %s", u.currentExeName, u.backupExeName) + err = os.Rename(u.currentExeName, u.backupExeName) + if err != nil { + return err + } + + if u.OS == "windows" { + // rename fails with "File in use" error + err = copyFile(u.updateExeName, u.currentExeName) + } else { + err = os.Rename(u.updateExeName, u.currentExeName) + } + if err != nil { + return err + } + log.Debug("updater: renamed: %s -> %s", u.updateExeName, u.currentExeName) + return nil +} + +func (u *Updater) clean() { + _ = os.RemoveAll(u.updateDir) +} + +// Download package file and save it to disk +func (u *Updater) downloadPackageFile(url string, filename string) error { + resp, err := u.Client.Get(url) + if err != nil { + return fmt.Errorf("HTTP request failed: %s", err) + } + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + log.Debug("updater: reading HTTP body") + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("ioutil.ReadAll() failed: %s", err) + } + + _ = os.Mkdir(u.updateDir, 0755) + + log.Debug("updater: saving package to file") + err = ioutil.WriteFile(filename, body, 0644) + if err != nil { + return fmt.Errorf("ioutil.WriteFile() failed: %s", err) + } + return nil +} + +// Unpack all files from .tar.gz file to the specified directory +// Existing files are overwritten +// All files are created inside 'outdir', subdirectories are not created +// Return the list of files (not directories) written +func tarGzFileUnpack(tarfile, outdir string) ([]string, error) { + f, err := os.Open(tarfile) + if err != nil { + return nil, fmt.Errorf("os.Open(): %s", err) + } + defer func() { + _ = f.Close() + }() + + gzReader, err := gzip.NewReader(f) + if err != nil { + return nil, fmt.Errorf("gzip.NewReader(): %s", err) + } + + var files []string + var err2 error + tarReader := tar.NewReader(gzReader) + for { + header, err := tarReader.Next() + if err == io.EOF { + err2 = nil + break + } + if err != nil { + err2 = fmt.Errorf("tarReader.Next(): %s", err) + break + } + + _, inputNameOnly := filepath.Split(header.Name) + if len(inputNameOnly) == 0 { + continue + } + + outputName := filepath.Join(outdir, inputNameOnly) + + if header.Typeflag == tar.TypeDir { + err = os.Mkdir(outputName, os.FileMode(header.Mode&0777)) + if err != nil && !os.IsExist(err) { + err2 = fmt.Errorf("os.Mkdir(%s): %s", outputName, err) + break + } + log.Debug("updater: created directory %s", outputName) + continue + } else if header.Typeflag != tar.TypeReg { + log.Debug("updater: %s: unknown file type %d, skipping", inputNameOnly, header.Typeflag) + continue + } + + f, err := os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode&0777)) + if err != nil { + err2 = fmt.Errorf("os.OpenFile(%s): %s", outputName, err) + break + } + _, err = io.Copy(f, tarReader) + if err != nil { + _ = f.Close() + err2 = fmt.Errorf("io.Copy(): %s", err) + break + } + err = f.Close() + if err != nil { + err2 = fmt.Errorf("f.Close(): %s", err) + break + } + + log.Debug("updater: created file %s", outputName) + files = append(files, header.Name) + } + + _ = gzReader.Close() + return files, err2 +} + +// Unpack all files from .zip file to the specified directory +// Existing files are overwritten +// All files are created inside 'outdir', subdirectories are not created +// Return the list of files (not directories) written +func zipFileUnpack(zipfile, outdir string) ([]string, error) { + r, err := zip.OpenReader(zipfile) + if err != nil { + return nil, fmt.Errorf("zip.OpenReader(): %s", err) + } + defer r.Close() + + var files []string + var err2 error + var zr io.ReadCloser + for _, zf := range r.File { + zr, err = zf.Open() + if err != nil { + err2 = fmt.Errorf("zip file Open(): %s", err) + break + } + + fi := zf.FileInfo() + inputNameOnly := fi.Name() + if len(inputNameOnly) == 0 { + continue + } + + outputName := filepath.Join(outdir, inputNameOnly) + + if fi.IsDir() { + err = os.Mkdir(outputName, fi.Mode()) + if err != nil && !os.IsExist(err) { + err2 = fmt.Errorf("os.Mkdir(): %s", err) + break + } + log.Tracef("created directory %s", outputName) + continue + } + + f, err := os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode()) + if err != nil { + err2 = fmt.Errorf("os.OpenFile(): %s", err) + break + } + _, err = io.Copy(f, zr) + if err != nil { + _ = f.Close() + err2 = fmt.Errorf("io.Copy(): %s", err) + break + } + err = f.Close() + if err != nil { + err2 = fmt.Errorf("f.Close(): %s", err) + break + } + + log.Tracef("created file %s", outputName) + files = append(files, inputNameOnly) + } + + _ = zr.Close() + return files, err2 +} + +// Copy file on disk +func copyFile(src, dst string) error { + d, e := ioutil.ReadFile(src) + if e != nil { + return e + } + e = ioutil.WriteFile(dst, d, 0644) + if e != nil { + return e + } + return nil +} + +func copySupportingFiles(files []string, srcdir, dstdir string) error { + for _, f := range files { + _, name := filepath.Split(f) + if name == "AdGuardHome" || name == "AdGuardHome.exe" || name == "AdGuardHome.yaml" { + continue + } + + src := filepath.Join(srcdir, name) + dst := filepath.Join(dstdir, name) + + err := copyFile(src, dst) + if err != nil && !os.IsNotExist(err) { + return err + } + + log.Debug("updater: copied: %s -> %s", src, dst) + } + return nil +}