From 32d4e80c93cba5b3ef4bb96c79f7256b5cb4d913 Mon Sep 17 00:00:00 2001 From: Andrey Meshkov Date: Tue, 30 Oct 2018 02:17:24 +0300 Subject: [PATCH] Fix #371 #421 Filters are now saved to a file Also, they're loaded from the file on startup Filter ID is not passed to the CoreDNS plugin config (server-side AG DNS must be changed accordingly) Some minor refactoring, unused functions removed --- .gitignore | 1 + app.go | 119 +++++++++++++----- config.go | 148 ++++++++++++++++------- control.go | 166 +++++++++++++------------- coredns.go | 6 - coredns_plugin/coredns_plugin.go | 51 ++++++-- coredns_plugin/coredns_plugin_test.go | 1 + helpers.go | 37 ++++-- 8 files changed, 339 insertions(+), 190 deletions(-) diff --git a/.gitignore b/.gitignore index e22df4e9..5cfd4889 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ debug /AdGuardHome /AdGuardHome.yaml +/data/ /build/ /client/node_modules/ /coredns diff --git a/app.go b/app.go index a317aa29..78a995e7 100644 --- a/app.go +++ b/app.go @@ -25,10 +25,18 @@ func main() { if err != nil { panic(err) } - config.ourBinaryDir = filepath.Dir(executable) - } - doConfigRename := true + executableName := filepath.Base(executable) + if executableName == "AdGuardHome" { + // Binary build + config.ourBinaryDir = filepath.Dir(executable) + } else { + // Most likely we're debugging -- using current working directory in this case + workDir, _ := os.Getwd() + config.ourBinaryDir = workDir + } + log.Printf("Current working directory is %s", config.ourBinaryDir) + } // config can be specified, which reads options from there, but other command line flags have to override config values // therefore, we must do it manually instead of using a lib @@ -98,18 +106,9 @@ func main() { } } if configFilename != nil { - // config was manually specified, don't do anything - doConfigRename = false config.ourConfigFilename = *configFilename } - if doConfigRename { - err := renameOldConfigIfNeccessary() - if err != nil { - panic(err) - } - } - err := askUsernamePasswordIfPossible() if err != nil { log.Fatal(err) @@ -128,16 +127,32 @@ func main() { } } - // eat all args so that coredns can start happily + // Eat all args so that coredns can start happily if len(os.Args) > 1 { os.Args = os.Args[:1] } - err := writeConfig() + // Do the upgrade if necessary + err := upgradeConfig() if err != nil { log.Fatal(err) } + // Save the updated config + err = writeConfig() + if err != nil { + log.Fatal(err) + } + + // Load filters from the disk + for i := range config.Filters { + filter := &config.Filters[i] + err = filter.load() + if err != nil { + log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err) + } + } + address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) runFilterRefreshers() @@ -240,27 +255,71 @@ func askUsernamePasswordIfPossible() error { return nil } -func renameOldConfigIfNeccessary() error { - oldConfigFile := filepath.Join(config.ourBinaryDir, "AdguardDNS.yaml") - _, err := os.Stat(oldConfigFile) - if os.IsNotExist(err) { - // do nothing, file doesn't exist - trace("File %s doesn't exist, nothing to do", oldConfigFile) +// Performs necessary upgrade operations if needed +func upgradeConfig() error { + + if config.SchemaVersion == SchemaVersion { + // No upgrade, do nothing return nil } - newConfigFile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) - _, err = os.Stat(newConfigFile) - if !os.IsNotExist(err) { - // do nothing, file doesn't exist - trace("File %s already exists, will not overwrite", newConfigFile) - return nil + if config.SchemaVersion > SchemaVersion { + // Unexpected -- config file is newer than the + return fmt.Errorf("configuration file is supposed to be used with a newer version of AdGuard Home, schema=%d", config.SchemaVersion) } - err = os.Rename(oldConfigFile, newConfigFile) - if err != nil { - log.Printf("Failed to rename %s to %s: %s", oldConfigFile, newConfigFile, err) - return err + // Perform upgrade operations for each consecutive version upgrade + for oldVersion, newVersion := config.SchemaVersion, config.SchemaVersion+1; newVersion <= SchemaVersion; { + + err := upgradeConfigSchema(oldVersion, newVersion) + if err != nil { + log.Fatal(err) + } + + // Increment old and new versions + oldVersion++ + newVersion++ + } + + // Save the current schema version + config.SchemaVersion = SchemaVersion + + return nil +} + +// Upgrade from oldVersion to newVersion +func upgradeConfigSchema(oldVersion int, newVersion int) error { + + if oldVersion == 0 && newVersion == 1 { + log.Printf("Updating schema from %d to %d", oldVersion, newVersion) + + // The first schema upgrade: + // Added "ID" field to "filter" -- we need to populate this field now + // Added "config.ourDataDir" -- where we will now store filters contents + for i := range config.Filters { + + filter := &config.Filters[i] // otherwise we will be operating on a copy + + log.Printf("Seting ID=%d for filter %s", i, filter.URL) + filter.ID = i + 1 // start with ID=1 + + // Forcibly update the filter + _, err := filter.update(true) + if err != nil { + log.Fatal(err) + } + } + + // No more "dnsfilter.txt", filters are now loaded from config.ourDataDir/filters/ + dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt") + _, err := os.Stat(dnsFilterPath) + if !os.IsNotExist(err) { + log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath) + err = os.Remove(dnsFilterPath) + if err != nil { + log.Printf("Cannot remove %s due to %s", dnsFilterPath, err) + } + } } return nil diff --git a/config.go b/config.go index 09e89ee1..5ac55015 100644 --- a/config.go +++ b/config.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "gopkg.in/yaml.v2" "io/ioutil" "log" "os" @@ -10,46 +11,61 @@ import ( "sync" "text/template" "time" - - "gopkg.in/yaml.v2" ) +// Current schema version. We compare it with the value from +// the configuration file and perform necessary upgrade operations if needed +const SchemaVersion = 1 + +// Directory where we'll store all downloaded filters contents +const FiltersDir = "filters" + // configuration is loaded from YAML type configuration struct { ourConfigFilename string ourBinaryDir string + // Directory to store data (i.e. filters contents) + ourDataDir string - BindHost string `yaml:"bind_host"` - BindPort int `yaml:"bind_port"` - AuthName string `yaml:"auth_name"` - AuthPass string `yaml:"auth_pass"` - CoreDNS coreDNSConfig `yaml:"coredns"` - Filters []filter `yaml:"filters"` - UserRules []string `yaml:"user_rules"` + // Schema version of the config file. This value is used when performing the app updates. + SchemaVersion int `yaml:"schema_version"` + BindHost string `yaml:"bind_host"` + BindPort int `yaml:"bind_port"` + AuthName string `yaml:"auth_name"` + AuthPass string `yaml:"auth_pass"` + CoreDNS coreDNSConfig `yaml:"coredns"` + Filters []filter `yaml:"filters"` + UserRules []string `yaml:"user_rules"` sync.RWMutex `yaml:"-"` } +type coreDnsFilter struct { + ID int `yaml:"-"` + Path string `yaml:"-"` +} + type coreDNSConfig struct { binaryFile string coreFile string - FilterFile string `yaml:"-"` - Port int `yaml:"port"` - ProtectionEnabled bool `yaml:"protection_enabled"` - FilteringEnabled bool `yaml:"filtering_enabled"` - SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` - SafeSearchEnabled bool `yaml:"safesearch_enabled"` - ParentalEnabled bool `yaml:"parental_enabled"` - ParentalSensitivity int `yaml:"parental_sensitivity"` - BlockedResponseTTL int `yaml:"blocked_response_ttl"` - QueryLogEnabled bool `yaml:"querylog_enabled"` - Pprof string `yaml:"-"` - Cache string `yaml:"-"` - Prometheus string `yaml:"-"` - UpstreamDNS []string `yaml:"upstream_dns"` + Filters []coreDnsFilter `yaml:"-"` + Port int `yaml:"port"` + ProtectionEnabled bool `yaml:"protection_enabled"` + FilteringEnabled bool `yaml:"filtering_enabled"` + SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` + SafeSearchEnabled bool `yaml:"safesearch_enabled"` + ParentalEnabled bool `yaml:"parental_enabled"` + ParentalSensitivity int `yaml:"parental_sensitivity"` + BlockedResponseTTL int `yaml:"blocked_response_ttl"` + QueryLogEnabled bool `yaml:"querylog_enabled"` + Pprof string `yaml:"-"` + Cache string `yaml:"-"` + Prometheus string `yaml:"-"` + UpstreamDNS []string `yaml:"upstream_dns"` } type filter struct { + ID int `json:"ID"` // auto-assigned when filter is added URL string `json:"url"` Name string `json:"name" yaml:"name"` Enabled bool `json:"enabled"` @@ -63,13 +79,13 @@ var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} // initialize to default values, will be changed later when reading config or parsing command line var config = configuration{ ourConfigFilename: "AdGuardHome.yaml", + ourDataDir: "data", BindPort: 3000, BindHost: "127.0.0.1", CoreDNS: coreDNSConfig{ Port: 53, - binaryFile: "coredns", // only filename, no path - coreFile: "Corefile", // only filename, no path - FilterFile: "dnsfilter.txt", // only filename, no path + binaryFile: "coredns", // only filename, no path + coreFile: "Corefile", // only filename, no path ProtectionEnabled: true, FilteringEnabled: true, SafeBrowsingEnabled: false, @@ -80,13 +96,33 @@ var config = configuration{ Prometheus: "prometheus :9153", }, Filters: []filter{ - {Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt"}, - {Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, - {Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, - {Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, + {ID: 1, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, + {ID: 2, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, + {ID: 3, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, + {ID: 4, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, }, } +// Creates a helper object for working with the user rules +func getUserFilter() filter { + + // TODO: This should be calculated when UserRules are set + contents := []byte{} + for _, rule := range config.UserRules { + contents = append(contents, []byte(rule)...) + contents = append(contents, '\n') + } + + userFilter := filter{ + // User filter always has ID=0 + ID: 0, + contents: contents, + Enabled: true, + } + + return userFilter +} + func parseConfig() error { configfile := filepath.Join(config.ourBinaryDir, config.ourConfigFilename) log.Printf("Reading YAML file: %s", configfile) @@ -117,16 +153,19 @@ func writeConfig() error { log.Printf("Couldn't generate YAML file: %s", err) return err } - err = ioutil.WriteFile(configfile+".tmp", yamlText, 0644) + err = writeFileSafe(configfile, yamlText) if err != nil { - log.Printf("Couldn't write YAML config: %s", err) + log.Printf("Couldn't save YAML config: %s", err) return err } - err = os.Rename(configfile+".tmp", configfile) + + userFilter := getUserFilter() + err = userFilter.save() if err != nil { - log.Printf("Couldn't rename YAML config: %s", err) + log.Printf("Couldn't save the user filter: %s", err) return err } + return nil } @@ -141,15 +180,12 @@ func writeCoreDNSConfig() error { log.Printf("Couldn't generate DNS config: %s", err) return err } - err = ioutil.WriteFile(corefile+".tmp", []byte(configtext), 0644) + err = writeFileSafe(corefile, []byte(configtext)) if err != nil { - log.Printf("Couldn't write DNS config: %s", err) + log.Printf("Couldn't save DNS config: %s", err) + return err } - err = os.Rename(corefile+".tmp", corefile) - if err != nil { - log.Printf("Couldn't rename DNS config: %s", err) - } - return err + return nil } func writeAllConfigs() error { @@ -167,12 +203,17 @@ func writeAllConfigs() error { } const coreDNSConfigTemplate = `.:{{.Port}} { - {{if .ProtectionEnabled}}dnsfilter {{if .FilteringEnabled}}{{.FilterFile}}{{end}} { + {{if .ProtectionEnabled}}dnsfilter { {{if .SafeBrowsingEnabled}}safebrowsing{{end}} {{if .ParentalEnabled}}parental {{.ParentalSensitivity}}{{end}} {{if .SafeSearchEnabled}}safesearch{{end}} {{if .QueryLogEnabled}}querylog{{end}} blocked_ttl {{.BlockedResponseTTL}} + {{if .FilteringEnabled}} + {{range .Filters}} + filter {{.ID}} "{{.Path}}" + {{end}} + {{end}} }{{end}} {{.Pprof}} hosts { @@ -196,7 +237,28 @@ func generateCoreDNSConfigText() (string, error) { var configBytes bytes.Buffer temporaryConfig := config.CoreDNS - temporaryConfig.FilterFile = filepath.Join(config.ourBinaryDir, config.CoreDNS.FilterFile) + + // fill the list of filters + filters := make([]coreDnsFilter, 0) + + // first of all, append the user filter + userFilter := getUserFilter() + + // TODO: Don't add if empty + //if len(userFilter.contents) > 0 { + filters = append(filters, coreDnsFilter{ID: userFilter.ID, Path: userFilter.getFilterFilePath()}) + //} + + // then go through other filters + for i := range config.Filters { + filter := &config.Filters[i] + + if filter.Enabled && len(filter.contents) > 0 { + filters = append(filters, coreDnsFilter{ID: filter.ID, Path: filter.getFilterFilePath()}) + } + } + temporaryConfig.Filters = filters + // run the template err = t.Execute(&configBytes, &temporaryConfig) if err != nil { diff --git a/control.go b/control.go index 1afd4e24..828909be 100644 --- a/control.go +++ b/control.go @@ -16,7 +16,6 @@ import ( "time" coredns_plugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/miekg/dns" "gopkg.in/asaskevich/govalidator.v4" ) @@ -423,7 +422,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { } } - ok, err := filter.update(time.Now()) + ok, err := filter.update(true) if err != nil { errortext := fmt.Sprintf("Couldn't fetch filter from url %s: %s", filter.URL, err) log.Println(errortext) @@ -452,14 +451,9 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { http.Error(w, errortext, http.StatusInternalServerError) return } - err = writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return - } + tellCoreDNSToReload() + _, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount) if err != nil { errortext := fmt.Sprintf("Couldn't write body: %s", err) @@ -468,6 +462,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { } } +// TODO: Start using filter ID func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { parameters, err := parseParametersFromBody(r.Body) if err != nil { @@ -493,19 +488,22 @@ func handleFilteringRemoveURL(w http.ResponseWriter, r *http.Request) { for _, filter := range config.Filters { if filter.URL != url { newFilters = append(newFilters, filter) + } else { + // Remove the filter file + err := os.Remove(filter.getFilterFilePath()) + if err != nil { + errortext := fmt.Sprintf("Couldn't remove the filter file: %s", err) + http.Error(w, errortext, http.StatusInternalServerError) + return + } } } + // Update the configuration after removing filter files config.Filters = newFilters - err = writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return - } httpUpdateConfigReloadDNSReturnOK(w, r) } +// TODO: Start using filter ID func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) { parameters, err := parseParametersFromBody(r.Body) if err != nil { @@ -542,16 +540,10 @@ func handleFilteringEnableURL(w http.ResponseWriter, r *http.Request) { // kick off refresh of rules from new URLs refreshFiltersIfNeccessary() - err = writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return - } httpUpdateConfigReloadDNSReturnOK(w, r) } +// TODO: Start using filter ID func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) { parameters, err := parseParametersFromBody(r.Body) if err != nil { @@ -586,13 +578,6 @@ func handleFilteringDisableURL(w http.ResponseWriter, r *http.Request) { return } - err = writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return - } httpUpdateConfigReloadDNSReturnOK(w, r) } @@ -606,13 +591,6 @@ func handleFilteringSetRules(w http.ResponseWriter, r *http.Request) { } config.UserRules = strings.Split(string(body), "\n") - err = writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - return - } httpUpdateConfigReloadDNSReturnOK(w, r) } @@ -639,7 +617,6 @@ func runFilterRefreshers() { } func refreshFiltersIfNeccessary() int { - now := time.Now() config.Lock() // deduplicate @@ -663,7 +640,7 @@ func refreshFiltersIfNeccessary() int { updateCount := 0 for i := range config.Filters { filter := &config.Filters[i] // otherwise we will be operating on a copy - updated, err := filter.update(now) + updated, err := filter.update(false) if err != nil { log.Printf("Failed to update filter %s: %s\n", filter.URL, err) continue @@ -675,27 +652,25 @@ func refreshFiltersIfNeccessary() int { config.Unlock() if updateCount > 0 { - err := writeFilterFile() - if err != nil { - errortext := fmt.Sprintf("Couldn't write filter file: %s", err) - log.Println(errortext) - } tellCoreDNSToReload() } return updateCount } -func (filter *filter) update(now time.Time) (bool, error) { +// Checks for filters updates +// If "force" is true -- does not check the filter's LastUpdated field +func (filter *filter) update(force bool) (bool, error) { if !filter.Enabled { return false, nil } - elapsed := time.Since(filter.LastUpdated) - if elapsed <= updatePeriod { + if !force && time.Since(filter.LastUpdated) <= updatePeriod { return false, nil } + log.Printf("Downloading update for filter %d", filter.ID) + // use same update period for failed filter downloads to avoid flooding with requests - filter.LastUpdated = now + filter.LastUpdated = time.Now() resp, err := client.Get(filter.URL) if resp != nil && resp.Body != nil { @@ -706,9 +681,15 @@ func (filter *filter) update(now time.Time) (bool, error) { return false, err } - if resp.StatusCode >= 400 { + if resp.StatusCode != 200 { log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL) - return false, fmt.Errorf("Got status code >= 400: %d", resp.StatusCode) + return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode) + } + + contentType := strings.ToLower(resp.Header.Get("content-type")) + if !strings.HasPrefix(contentType, "text/plain") { + log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL) + return false, fmt.Errorf("non-text response %s", contentType) } body, err := ioutil.ReadAll(resp.Body) @@ -717,11 +698,12 @@ func (filter *filter) update(now time.Time) (bool, error) { return false, err } - // extract filter name and count number of rules + // Extract filter name and count number of rules lines := strings.Split(string(body), "\n") rulesCount := 0 seenTitle := false - d := dnsfilter.New() + + // Count lines in the filter for _, line := range lines { line = strings.TrimSpace(line) if len(line) > 0 && line[0] == '!' { @@ -730,61 +712,73 @@ func (filter *filter) update(now time.Time) (bool, error) { seenTitle = true } } else if len(line) != 0 { - err = d.AddRule(line, 0) - if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { - continue - } - if err != nil { - log.Printf("Cannot add rule %s from %s: %s", line, filter.URL, err) - // Just ignore invalid rules - continue - } rulesCount++ } } + + // Check if the filter was really changed if bytes.Equal(filter.contents, body) { return false, nil } + log.Printf("Filter %s updated: %d bytes, %d rules", filter.URL, len(body), rulesCount) filter.RulesCount = rulesCount filter.contents = body + + // Saving it to the filters dir now + err = filter.save() + if err != nil { + return false, nil + } + return true, nil } -// write filter file -func writeFilterFile() error { - filterpath := filepath.Join(config.ourBinaryDir, config.CoreDNS.FilterFile) - log.Printf("Writing filter file: %s", filterpath) - // TODO: check if file contents have modified - data := []byte{} - config.RLock() - filters := config.Filters - for _, filter := range filters { - if !filter.Enabled { - continue - } - data = append(data, filter.contents...) - data = append(data, '\n') - } - for _, rule := range config.UserRules { - data = append(data, []byte(rule)...) - data = append(data, '\n') - } - config.RUnlock() - err := ioutil.WriteFile(filterpath+".tmp", data, 0644) +// saves filter contents to the file in config.ourDataDir +func (filter *filter) save() error { + + filterFilePath := filter.getFilterFilePath() + log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) + + err := writeFileSafe(filterFilePath, filter.contents) if err != nil { - log.Printf("Couldn't write filter file: %s", err) return err } - err = os.Rename(filterpath+".tmp", filterpath) - if err != nil { - log.Printf("Couldn't rename filter file: %s", err) + return nil; +} + +// loads filter contents from the file in config.ourDataDir +func (filter *filter) load() error { + + if !filter.Enabled { + // No need to load a filter that is not enabled + return nil + } + + filterFilePath := filter.getFilterFilePath() + log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath) + + if _, err := os.Stat(filterFilePath); os.IsNotExist(err) { + // do nothing, file doesn't exist return err } + + filterFile, err := ioutil.ReadFile(filterFilePath) + if err != nil { + return err + } + + log.Printf("Filter %d length is %d", filter.ID, len(filterFile)) + filter.contents = filterFile return nil } +// Path to the filter contents +func (filter *filter) getFilterFilePath() string { + return filepath.Join(config.ourBinaryDir, config.ourDataDir, FiltersDir, strconv.Itoa(filter.ID) + ".txt") +} + // ------------ // safebrowsing // ------------ diff --git a/coredns.go b/coredns.go index b6941f2b..5dbe01b4 100644 --- a/coredns.go +++ b/coredns.go @@ -120,12 +120,6 @@ func startDNSServer() error { log.Println(errortext) return errortext } - err = writeFilterFile() - if err != nil { - errortext := fmt.Errorf("Couldn't write filter file: %s", err) - log.Println(errortext) - return errortext - } go coremain.Run() return nil diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go index 31b147cf..e2d3f821 100644 --- a/coredns_plugin/coredns_plugin.go +++ b/coredns_plugin/coredns_plugin.go @@ -51,11 +51,17 @@ var ( lookupCache = map[string]cacheEntry{} ) +type plugFilter struct { + ID uint32 + Path string +} + type plugSettings struct { SafeBrowsingBlockHost string ParentalBlockHost string QueryLogEnabled bool BlockedTTL uint32 // in seconds, default 3600 + Filters []plugFilter } type plug struct { @@ -71,6 +77,7 @@ var defaultPluginSettings = plugSettings{ SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com", ParentalBlockHost: "family.block.dns.adguard.com", BlockedTTL: 3600, // in seconds + Filters: make([]plugFilter, 0), } // @@ -83,14 +90,12 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { d: dnsfilter.New(), } - filterFileNames := []string{} + log.Println("Initializing the CoreDNS plugin") + for c.Next() { - args := c.RemainingArgs() - if len(args) > 0 { - filterFileNames = append(filterFileNames, args...) - } for c.NextBlock() { - switch c.Val() { + blockValue := c.Val() + switch blockValue { case "safebrowsing": p.d.EnableSafeBrowsing() if c.NextArg() { @@ -130,17 +135,38 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { p.settings.BlockedTTL = uint32(blockttl) case "querylog": p.settings.QueryLogEnabled = true + + case "filter": + if !c.NextArg() { + return nil, c.ArgErr() + } + + filterId, err := strconv.Atoi(c.Val()) + if err != nil { + return nil, c.ArgErr() + } + if !c.NextArg() { + return nil, c.ArgErr() + } + filterPath := c.Val() + + // Initialize filter and add it to the list + p.settings.Filters = append(p.settings.Filters, plugFilter{ + ID: uint32(filterId), + Path: filterPath, + }) } } } - log.Printf("filterFileNames = %+v", filterFileNames) + for _, filter := range p.settings.Filters { + log.Printf("Loading rules from %s", filter.Path) - for i, filterFileName := range filterFileNames { - file, err := os.Open(filterFileName) + file, err := os.Open(filter.Path) if err != nil { return nil, err } + //noinspection GoDeferInLoop defer file.Close() count := 0 @@ -148,7 +174,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { for scanner.Scan() { text := scanner.Text() - err = p.d.AddRule(text, uint32(i)) + err = p.d.AddRule(text, filter.ID) if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { continue } @@ -159,7 +185,7 @@ func setupPlugin(c *caddy.Controller) (*plug, error) { } count++ } - log.Printf("Added %d rules from %s", count, filterFileName) + log.Printf("Added %d rules from %d", count, filter.ID) if err = scanner.Err(); err != nil { return nil, err @@ -250,6 +276,7 @@ func (p *plug) onFinalShutdown() error { type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) +//noinspection GoUnusedParameter func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { realch, ok := ch.(chan<- *prometheus.Desc) if !ok { @@ -391,7 +418,7 @@ func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.M func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) { if len(r.Question) != 1 { // google DNS, bind and others do the same - return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("Got DNS request with != 1 questions") + return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("got a DNS request with more than one Question") } for _, question := range r.Question { host := strings.ToLower(strings.TrimSuffix(question.Name, ".")) diff --git a/coredns_plugin/coredns_plugin_test.go b/coredns_plugin/coredns_plugin_test.go index 2f65cf9a..f6b85e9d 100644 --- a/coredns_plugin/coredns_plugin_test.go +++ b/coredns_plugin/coredns_plugin_test.go @@ -15,6 +15,7 @@ import ( "github.com/miekg/dns" ) +// TODO: Change tests -- there's new config template now func TestSetup(t *testing.T) { for i, testcase := range []struct { config string diff --git a/helpers.go b/helpers.go index 6d598224..7ae69b8a 100644 --- a/helpers.go +++ b/helpers.go @@ -5,21 +5,39 @@ import ( "errors" "fmt" "io" + "io/ioutil" "net/http" "os" "path" + "path/filepath" "runtime" "strings" ) -func clamp(value, low, high int) int { - if value < low { - return low +// ---------------------------------- +// helper functions for working with files +// ---------------------------------- + +// Writes data first to a temporary file and then renames it to what's specified in path +func writeFileSafe(path string, data []byte) error { + + dir := filepath.Dir(path) + err := os.MkdirAll(dir, 0755) + if err != nil { + return err } - if value > high { - return high + + tmpPath := path + ".tmp" + err = ioutil.WriteFile(tmpPath, data, 0644) + if err != nil { + return err } - return value + err = os.Rename(tmpPath, path) + if err != nil { + return err + } + + return nil } // ---------------------------------- @@ -117,13 +135,6 @@ func parseParametersFromBody(r io.Reader) (map[string]string, error) { // --------------------- // debug logging helpers // --------------------- -func _Func() string { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - return path.Base(f.Name()) -} - func trace(format string, args ...interface{}) { pc := make([]uintptr, 10) // at least 1 entry needed runtime.Callers(2, pc)