From ec7efcc9d6479d26b9ba23b21d6270dfb8c99ff1 Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Tue, 27 Nov 2018 20:39:59 +0300 Subject: [PATCH] Move config upgrade to separate upgrade.go --- app.go | 88 +++--------------------------------- config.go | 6 +-- control.go | 2 +- helpers.go | 2 +- upgrade.go | 128 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 140 insertions(+), 86 deletions(-) create mode 100644 upgrade.go diff --git a/app.go b/app.go index 7818f061..61989816 100644 --- a/app.go +++ b/app.go @@ -114,6 +114,12 @@ func main() { log.Fatal(err) } + // Do the upgrade if necessary + err = upgradeConfig() + if err != nil { + log.Fatal(err) + } + // parse from config file err = parseConfig() if err != nil { @@ -134,14 +140,8 @@ func main() { os.Args = os.Args[:1] } - // Do the upgrade if necessary - err := upgradeConfig() - if err != nil { - log.Fatal(err) - } - // Save the updated config - err = writeConfig() + err := writeConfig() if err != nil { log.Fatal(err) } @@ -260,77 +260,3 @@ func askUsernamePasswordIfPossible() error { config.AuthPass = password return nil } - -// Performs necessary upgrade operations if needed -func upgradeConfig() error { - if config.SchemaVersion == SchemaVersion { - // No upgrade, do nothing - return nil - } - - if config.SchemaVersion > SchemaVersion { - // Unexpected -- the config file is newer than we expect - return fmt.Errorf("configuration file is supposed to be used with a newer version of AdGuard Home, schema=%d", config.SchemaVersion) - } - - // Perform upgrade operations for each consecutive version upgrade - for oldVersion, newVersion := config.SchemaVersion, config.SchemaVersion+1; newVersion <= SchemaVersion; { - err := upgradeConfigSchema(oldVersion, newVersion) - if err != nil { - log.Fatal(err) - } - - // Increment old and new versions - oldVersion++ - newVersion++ - } - - // Save the current schema version - config.SchemaVersion = SchemaVersion - - return nil -} - -// Upgrade from oldVersion to newVersion -func upgradeConfigSchema(oldVersion int, newVersion int) error { - if oldVersion == 0 && newVersion == 1 { - log.Printf("Updating schema from %d to %d", oldVersion, newVersion) - - // The first schema upgrade: - // Added "ID" field to "filter" -- we need to populate this field now - // Added "config.ourDataDir" -- where we will now store filters contents - for i := range config.Filters { - filter := &config.Filters[i] // otherwise we will be operating on a copy - - // Set the filter ID - log.Printf("Seting ID=%d for filter %s", NextFilterId, filter.URL) - filter.ID = NextFilterId - NextFilterId++ - - // Forcibly update the filter - _, err := filter.update(true) - if err != nil { - log.Fatal(err) - } - - // Saving it to the filters dir now - err = filter.save() - if err != nil { - log.Fatal(err) - } - } - - // No more "dnsfilter.txt", filters are now loaded from config.ourDataDir/filters/ - dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt") - _, err := os.Stat(dnsFilterPath) - if !os.IsNotExist(err) { - log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath) - err = os.Remove(dnsFilterPath) - if err != nil { - log.Printf("Cannot remove %s due to %s", dnsFilterPath, err) - } - } - } - - return nil -} diff --git a/config.go b/config.go index f60a9040..340bc969 100644 --- a/config.go +++ b/config.go @@ -16,7 +16,7 @@ import ( // Current schema version. We compare it with the value from // the configuration file and perform necessary upgrade operations if needed -const SchemaVersion = 1 +const CurrentSchemaVersion = 1 // Directory where we'll store all downloaded filters contents const FiltersDir = "filters" @@ -188,7 +188,7 @@ func writeConfig() error { log.Printf("Couldn't generate YAML file: %s", err) return err } - err = writeFileSafe(configFile, yamlText) + err = safeWriteFile(configFile, yamlText) if err != nil { log.Printf("Couldn't save YAML config: %s", err) return err @@ -215,7 +215,7 @@ func writeCoreDNSConfig() error { log.Printf("Couldn't generate DNS config: %s", err) return err } - err = writeFileSafe(coreFile, []byte(configText)) + err = safeWriteFile(coreFile, []byte(configText)) if err != nil { log.Printf("Couldn't save DNS config: %s", err) return err diff --git a/control.go b/control.go index f61272a2..3c4815b6 100644 --- a/control.go +++ b/control.go @@ -664,7 +664,7 @@ func (filter *filter) save() error { filterFilePath := filter.Path() log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) - err := writeFileSafe(filterFilePath, filter.Contents) + err := safeWriteFile(filterFilePath, filter.Contents) if err != nil { return err } diff --git a/helpers.go b/helpers.go index 01fc1bdf..03d1e2f4 100644 --- a/helpers.go +++ b/helpers.go @@ -19,7 +19,7 @@ import ( // ---------------------------------- // Writes data first to a temporary file and then renames it to what's specified in path -func writeFileSafe(path string, data []byte) error { +func safeWriteFile(path string, data []byte) error { dir := filepath.Dir(path) err := os.MkdirAll(dir, 0755) if err != nil { diff --git a/upgrade.go b/upgrade.go new file mode 100644 index 00000000..0bf27c51 --- /dev/null +++ b/upgrade.go @@ -0,0 +1,128 @@ +package main + +import ( + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" + + "gopkg.in/yaml.v2" +) + +// Performs necessary upgrade operations if needed +func upgradeConfig() error { + // read a config file into an interface map, so we can manipulate values without losing any + configFile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) + if _, err := os.Stat(configFile); os.IsNotExist(err) { + log.Printf("config file %s does not exist, nothing to upgrade", configFile) + return nil + } + diskConfig := map[string]interface{}{} + body, err := ioutil.ReadFile(configFile) + if err != nil { + log.Printf("Couldn't read config file '%s': %s", configFile, err) + return err + } + + err = yaml.Unmarshal(body, &diskConfig) + if err != nil { + log.Printf("Couldn't parse config file '%s': %s", configFile, err) + return err + } + + schemaVersionInterface, ok := diskConfig["schema_version"] + trace("schemaVersionInterface = %v, ok = %v", schemaVersionInterface, ok) + if !ok { + // no schema version, set it to 0 + schemaVersionInterface = 0 + } + + schemaVersion, ok := schemaVersionInterface.(int) + if !ok { + err = fmt.Errorf("configuration file contains non-integer schema_version, abort") + log.Println(err) + return err + } + + if schemaVersion == CurrentSchemaVersion { + // do nothing + return nil + } + + return upgradeConfigSchema(schemaVersion, &diskConfig) +} + +// Upgrade from oldVersion to newVersion +func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error { + switch oldVersion { + case 0: + err := upgradeSchema0to1(diskConfig) + if err != nil { + return err + } + default: + err := fmt.Errorf("configuration file contains unknown schema_version, abort") + log.Println(err) + return err + } + + configFile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) + body, err := yaml.Marshal(diskConfig) + if err != nil { + log.Printf("Couldn't generate YAML file: %s", err) + return err + } + + err = safeWriteFile(configFile, body) + if err != nil { + log.Printf("Couldn't save YAML config: %s", err) + return err + } + + return nil +} + +func upgradeSchema0to1(diskConfig *map[string]interface{}) error { + trace("Called") + + // The first schema upgrade: + // Added "ID" field to "filter" -- we need to populate this field now + // Added "config.ourDataDir" -- where we will now store filters contents + for i := range config.Filters { + filter := &config.Filters[i] // otherwise we will be operating on a copy + + // Set the filter ID + log.Printf("Seting ID=%d for filter %s", NextFilterId, filter.URL) + filter.ID = NextFilterId + NextFilterId++ + + // Forcibly update the filter + _, err := filter.update(true) + if err != nil { + log.Fatal(err) + } + + // Saving it to the filters dir now + err = filter.save() + if err != nil { + log.Fatal(err) + } + } + + // No more "dnsfilter.txt", filters are now loaded from config.ourDataDir/filters/ + dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt") + _, err := os.Stat(dnsFilterPath) + if !os.IsNotExist(err) { + log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath) + err = os.Remove(dnsFilterPath) + if err != nil { + log.Printf("Cannot remove %s due to %s", dnsFilterPath, err) + // not fatal, move on + } + } + + (*diskConfig)["schema_version"] = 1 + + return nil +}