diff --git a/.gitignore b/.gitignore index db6d4a1d..20902175 100644 --- a/.gitignore +++ b/.gitignore @@ -1,15 +1,11 @@ .DS_Store -.vscode -.idea -debug +/.vscode +/.idea /AdGuardHome /AdGuardHome.yaml /data/ /build/ /client/node_modules/ -/coredns -/Corefile -/dnsfilter.txt /querylog.json /querylog.json.1 /scripts/translations/node_modules diff --git a/Makefile b/Makefile index 038602fc..9e9ae505 100644 --- a/Makefile +++ b/Makefile @@ -19,9 +19,12 @@ client/node_modules: client/package.json client/package-lock.json $(STATIC): $(JSFILES) client/node_modules npm --prefix client run build-prod -$(TARGET): $(STATIC) *.go coredns_plugin/*.go dnsfilter/*.go - GOPATH=$(GOPATH) GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) GO111MODULE=off go get -v github.com/gobuffalo/packr/... - GOPATH=$(GOPATH) PATH=$(GOPATH)/bin:$(PATH) packr build -ldflags="-X main.VersionString=$(GIT_VERSION)" -asmflags="-trimpath=$(PWD)" -gcflags="-trimpath=$(PWD)" -o $(TARGET) +$(TARGET): $(STATIC) *.go dnsfilter/*.go dnsforward/*.go + go get -d . + GOOS=$(NATIVE_GOOS) GOARCH=$(NATIVE_GOARCH) GO111MODULE=off go get -v github.com/gobuffalo/packr/... + PATH=$(GOPATH)/bin:$(PATH) packr -z + CGO_ENABLED=0 go build -ldflags="-s -w -X main.VersionString=$(GIT_VERSION)" -asmflags="-trimpath=$(PWD)" -gcflags="-trimpath=$(PWD)" + PATH=$(GOPATH)/bin:$(PATH) packr clean clean: $(MAKE) cleanfast diff --git a/README.md b/README.md index 17cec80c..438c3eb1 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ Now open the browser and navigate to http://localhost:3000/ to control your AdGu You can run AdGuard Home without superuser privileges, but you need to instruct it to use a different port rather than 53. You can do that by editing `AdGuardHome.yaml` and finding these two lines: ```yaml -coredns: +dns: port: 53 ``` @@ -104,25 +104,32 @@ Upon the first execution, a file named `AdGuardHome.yaml` will be created, with Settings are stored in [YAML format](https://en.wikipedia.org/wiki/YAML), possible parameters that you can configure are listed below: - * `bind_host` — Web interface IP address to listen on - * `bind_port` — Web interface IP port to listen on - * `auth_name` — Web interface optional authorization username - * `auth_pass` — Web interface optional authorization password - * `coredns` — CoreDNS configuration section - * `port` — DNS server port to listen on - * `filtering_enabled` — Filtering of DNS requests based on filter lists - * `safebrowsing_enabled` — Filtering of DNS requests based on safebrowsing - * `safesearch_enabled` — Enforcing "Safe search" option for search engines, when possible - * `parental_enabled` — Parental control-based DNS requests filtering - * `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17 - * `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistic purposes) - * `bootstrap_dns` — DNS server used for initial hostnames resolution in case if upstream is DoH or DoT with a hostname - * `upstream_dns` — List of upstream DNS servers + * `bind_host` — Web interface IP address to listen on. + * `bind_port` — Web interface IP port to listen on. + * `auth_name` — Web interface optional authorization username. + * `auth_pass` — Web interface optional authorization password. + * `dns` — DNS configuration section. + * `port` — DNS server port to listen on. + * `protection_enabled` — Whether any kind of filtering and protection should be done, when off it works as a plain dns forwarder. + * `filtering_enabled` — Filtering of DNS requests based on filter lists. + * `blocked_response_ttl` — For how many seconds the clients should cache a filtered response. Low values are useful on LAN if you change filters very often, high values are useful to increase performance and save traffic. + * `querylog_enabled` — Query logging (also used to calculate top 50 clients, blocked domains and requested domains for statistical purposes). + * `ratelimit` — DDoS protection, specifies in how many packets per second a client should receive. Anything above that is silently dropped. To disable set 0, default is 20. Safe to disable if DNS server is not available from internet. + * `ratelimit_whitelist` — If you want exclude some IP addresses from ratelimiting but keep ratelimiting on for others, put them here. + * `refuse_any` — Another DDoS protection mechanism. Requests of type ANY are rarely needed, so refusing to serve them mitigates against attackers trying to use your DNS as a reflection. Safe to disable if DNS server is not available from internet. + * `bootstrap_dns` — DNS server used for initial hostname resolution in case if upstream server name is a hostname. + * `parental_sensitivity` — Age group for parental control-based filtering, must be either 3, 10, 13 or 17 if enabled. + * `parental_enabled` — Parental control-based DNS requests filtering. + * `safesearch_enabled` — Enforcing "Safe search" option for search engines, when possible. + * `safebrowsing_enabled` — Filtering of DNS requests based on safebrowsing. + * `upstream_dns` — List of upstream DNS servers. * `filters` — List of filters, each filter has the following values: - * `ID` - filter ID (must be unique) - * `url` — URL pointing to the filter contents (filtering rules) - * `enabled` — Current filter's status (enabled/disabled) - * `user_rules` — User-specified filtering rules + * `enabled` — Current filter's status (enabled/disabled). + * `url` — URL pointing to the filter contents (filtering rules). + * `name` — Name of the filter. If it's an adguard syntax filter it will get updated automatically, otherwise it stays unchanged. + * `last_updated` — Time when the filter was last updated from server. + * `ID` - filter ID (must be unique). + * `user_rules` — User-specified filtering rules. Removing an entry from settings file will reset it to the default value. Deleting the file will reset all settings to the default values. @@ -151,7 +158,15 @@ cd AdGuardHome make ``` -## How to update translations +## Contributing + +You are welcome to fork this repository, make your changes and submit a pull request — https://github.com/AdguardTeam/AdGuardHome/pulls + +### How to update translations + +If you want to help with AdGuard Home translations, please learn more about translating AdGuard products here: https://kb.adguard.com/en/general/adguard-translations + +Here is a direct link to AdGuard Home project: http://translate.adguard.com/collaboration/project?id=153384 Before updating translations you need to install dependencies: ``` @@ -181,14 +196,6 @@ node upload.js node download.js ``` -## Contributing - -You are welcome to fork this repository, make your changes and submit a pull request — https://github.com/AdguardTeam/AdGuardHome/pulls - -If you want to help with AdGuard Home translations, please learn more about translating AdGuard products here: https://kb.adguard.com/en/general/adguard-translations - -Here is a direct link to AdGuard Home project: http://translate.adguard.com/collaboration/project?id=153384 - ## Reporting issues If you run into any problem or have a suggestion, head to [this page](https://github.com/AdguardTeam/AdGuardHome/issues) and click on the `New issue` button. @@ -198,7 +205,6 @@ If you run into any problem or have a suggestion, head to [this page](https://gi This software wouldn't have been possible without: * [Go](https://golang.org/dl/) and it's libraries: - * [CoreDNS](https://coredns.io) * [packr](https://github.com/gobuffalo/packr) * [gcache](https://github.com/bluele/gcache) * [miekg's dns](https://github.com/miekg/dns) @@ -209,4 +215,6 @@ This software wouldn't have been possible without: * And many more node.js packages. * [whotracks.me data](https://github.com/cliqz-oss/whotracks.me) +You might have seen that [CoreDNS](https://coredns.io) was mentioned here before — we've stopped using it in AdGuardHome. While we still use it on our servers for [AdGuard DNS](https://adguard.com/adguard-dns/overview.html) service, it seemed like an overkill for Home as it impeded with Home features that we plan to implement. + For a full list of all node.js packages in use, please take a look at [client/package.json](https://github.com/AdguardTeam/AdGuardHome/blob/master/client/package.json) file. diff --git a/app.go b/app.go index 85cb96ae..bbe36359 100644 --- a/app.go +++ b/app.go @@ -7,8 +7,10 @@ import ( "net" "net/http" "os" + "os/signal" "path/filepath" "strconv" + "syscall" "time" "github.com/gobuffalo/packr" @@ -149,7 +151,7 @@ func main() { log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err) // clear LastUpdated so it gets fetched right away } - if len(filter.Contents) == 0 { + if len(filter.Rules) == 0 { filter.LastUpdated = time.Time{} } } @@ -164,10 +166,13 @@ func main() { } }() - // Eat all args so that coredns can start happily - if len(os.Args) > 1 { - os.Args = os.Args[:1] - } + signalChannel := make(chan os.Signal) + signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) + go func() { + <-signalChannel + cleanup() + os.Exit(0) + }() // Save the updated config err := config.write() @@ -192,6 +197,13 @@ func main() { log.Fatal(http.ListenAndServe(address, nil)) } +func cleanup() { + err := stopDNSServer() + if err != nil { + log.Printf("Couldn't stop DNS server: %s", err) + } +} + func getInput() (string, error) { scanner := bufio.NewScanner(os.Stdin) scanner.Scan() diff --git a/config.go b/config.go index d141706b..1f7464b2 100644 --- a/config.go +++ b/config.go @@ -1,43 +1,36 @@ package main import ( - "bytes" "io/ioutil" "log" "os" "path/filepath" - "regexp" "sync" - "text/template" - "time" + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/dnsforward" "gopkg.in/yaml.v2" ) const ( - currentSchemaVersion = 1 // used for upgrading from old configs to new config - dataDir = "data" // data storage - filterDir = "filters" // cache location for downloaded filters, it's under DataDir - userFilterID = 0 // special filter ID, always 0 + dataDir = "data" // data storage + filterDir = "filters" // cache location for downloaded filters, it's under DataDir ) -// Just a counter that we use for incrementing the filter ID -var nextFilterID int64 = time.Now().Unix() - // configuration is loaded from YAML // field ordering is important -- yaml fields will mirror ordering from here type configuration struct { ourConfigFilename string // Config filename (can be overriden via the command line arguments) ourBinaryDir string // Location of our directory, used to protect against CWD being somewhere else - BindHost string `yaml:"bind_host"` - BindPort int `yaml:"bind_port"` - AuthName string `yaml:"auth_name"` - AuthPass string `yaml:"auth_pass"` - Language string `yaml:"language"` // two-letter ISO 639-1 language code - CoreDNS coreDNSConfig `yaml:"coredns"` - Filters []filter `yaml:"filters"` - UserRules []string `yaml:"user_rules,omitempty"` + BindHost string `yaml:"bind_host"` + BindPort int `yaml:"bind_port"` + AuthName string `yaml:"auth_name"` + AuthPass string `yaml:"auth_pass"` + Language string `yaml:"language"` // two-letter ISO 639-1 language code + DNS dnsConfig `yaml:"dns"` + Filters []filter `yaml:"filters"` + UserRules []string `yaml:"user_rules"` sync.RWMutex `yaml:"-"` @@ -45,38 +38,12 @@ type configuration struct { } // field ordering is important -- yaml fields will mirror ordering from here -type coreDNSConfig struct { - binaryFile string - coreFile string - Filters []filter `yaml:"-"` - Port int `yaml:"port"` - ProtectionEnabled bool `yaml:"protection_enabled"` - FilteringEnabled bool `yaml:"filtering_enabled"` - SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` - SafeSearchEnabled bool `yaml:"safesearch_enabled"` - ParentalEnabled bool `yaml:"parental_enabled"` - ParentalSensitivity int `yaml:"parental_sensitivity"` - BlockedResponseTTL int `yaml:"blocked_response_ttl"` - QueryLogEnabled bool `yaml:"querylog_enabled"` - Ratelimit int `yaml:"ratelimit"` - RefuseAny bool `yaml:"refuse_any"` - Pprof string `yaml:"-"` - Cache string `yaml:"-"` - Prometheus string `yaml:"-"` - BootstrapDNS string `yaml:"bootstrap_dns"` - UpstreamDNS []string `yaml:"upstream_dns"` -} +type dnsConfig struct { + Port int `yaml:"port"` -// field ordering is important -- yaml fields will mirror ordering from here -type filter struct { - Enabled bool `json:"enabled"` - URL string `json:"url"` - Name string `json:"name" yaml:"name"` - RulesCount int `json:"rulesCount" yaml:"-"` - LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"` - ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase + dnsforward.FilteringConfig `yaml:",inline"` - Contents []byte `json:"-" yaml:"-"` // not in yaml or json + UpstreamDNS []string `yaml:"upstream_dns"` } var defaultDNS = []string{"tls://1.1.1.1", "tls://1.0.0.1"} @@ -86,47 +53,26 @@ var config = configuration{ ourConfigFilename: "AdGuardHome.yaml", BindPort: 3000, BindHost: "127.0.0.1", - CoreDNS: coreDNSConfig{ - Port: 53, - binaryFile: "coredns", // only filename, no path - coreFile: "Corefile", // only filename, no path - ProtectionEnabled: true, - FilteringEnabled: true, - SafeBrowsingEnabled: false, - BlockedResponseTTL: 10, // in seconds - QueryLogEnabled: true, - Ratelimit: 20, - RefuseAny: true, - BootstrapDNS: "8.8.8.8:53", - UpstreamDNS: defaultDNS, - Cache: "cache", - Prometheus: "prometheus :9153", + DNS: dnsConfig{ + Port: 53, + FilteringConfig: dnsforward.FilteringConfig{ + ProtectionEnabled: true, // whether or not use any of dnsfilter features + FilteringEnabled: true, // whether or not use filter lists + BlockedResponseTTL: 10, // in seconds + QueryLogEnabled: true, + Ratelimit: 20, + RefuseAny: true, + BootstrapDNS: "8.8.8.8:53", + }, + UpstreamDNS: defaultDNS, }, Filters: []filter{ - {ID: 1, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, - {ID: 2, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, - {ID: 3, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, - {ID: 4, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, + {Filter: dnsfilter.Filter{ID: 1}, Enabled: true, URL: "https://adguardteam.github.io/AdGuardSDNSFilter/Filters/filter.txt", Name: "AdGuard Simplified Domain Names filter"}, + {Filter: dnsfilter.Filter{ID: 2}, Enabled: false, URL: "https://adaway.org/hosts.txt", Name: "AdAway"}, + {Filter: dnsfilter.Filter{ID: 3}, Enabled: false, URL: "https://hosts-file.net/ad_servers.txt", Name: "hpHosts - Ad and Tracking servers only"}, + {Filter: dnsfilter.Filter{ID: 4}, Enabled: false, URL: "http://www.malwaredomainlist.com/hostslist/hosts.txt", Name: "MalwareDomainList.com Hosts List"}, }, -} - -// Creates a helper object for working with the user rules -func userFilter() filter { - // TODO: This should be calculated when UserRules are set - var contents []byte - for _, rule := range config.UserRules { - contents = append(contents, []byte(rule)...) - contents = append(contents, '\n') - } - - userFilter := filter{ - // User filter always has constant ID=0 - ID: userFilterID, - Contents: contents, - Enabled: true, - } - - return userFilter + SchemaVersion: currentSchemaVersion, } // Loads configuration from the YAML file @@ -150,20 +96,7 @@ func parseConfig() error { } // Deduplicate filters - { - i := 0 // output index, used for deletion later - urls := map[string]bool{} - for _, filter := range config.Filters { - if _, ok := urls[filter.URL]; !ok { - // we didn't see it before, keep it - urls[filter.URL] = true // remember the URL - config.Filters[i] = filter - i++ - } - } - // all entries we want to keep are at front, delete the rest - config.Filters = config.Filters[:i] - } + deduplicateFilters() updateUniqueFilterID(config.Filters) @@ -187,6 +120,16 @@ func (c *configuration) write() error { return err } + return nil +} + +func writeAllConfigs() error { + err := config.write() + if err != nil { + log.Printf("Couldn't write config: %s", err) + return err + } + userFilter := userFilter() err = userFilter.save() if err != nil { @@ -196,112 +139,3 @@ func (c *configuration) write() error { return nil } - -// -------------- -// coredns config -// -------------- -func writeCoreDNSConfig() error { - coreFile := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile) - log.Printf("Writing DNS config: %s", coreFile) - configText, err := generateCoreDNSConfigText() - if err != nil { - log.Printf("Couldn't generate DNS config: %s", err) - return err - } - err = safeWriteFile(coreFile, []byte(configText)) - if err != nil { - log.Printf("Couldn't save DNS config: %s", err) - return err - } - return nil -} - -func writeAllConfigs() error { - err := config.write() - if err != nil { - log.Printf("Couldn't write our config: %s", err) - return err - } - err = writeCoreDNSConfig() - if err != nil { - log.Printf("Couldn't write DNS config: %s", err) - return err - } - return nil -} - -const coreDNSConfigTemplate = `.:{{.Port}} { - {{if .ProtectionEnabled}}dnsfilter { - {{if .SafeBrowsingEnabled}}safebrowsing{{end}} - {{if .ParentalEnabled}}parental {{.ParentalSensitivity}}{{end}} - {{if .SafeSearchEnabled}}safesearch{{end}} - {{if .QueryLogEnabled}}querylog{{end}} - blocked_ttl {{.BlockedResponseTTL}} - {{if .FilteringEnabled}}{{range .Filters}}{{if and .Enabled .Contents}} - filter {{.ID}} "{{.Path}}" - {{end}}{{end}}{{end}} - }{{end}} - {{.Pprof}} - {{if .RefuseAny}}refuseany{{end}} - {{if gt .Ratelimit 0}}ratelimit {{.Ratelimit}}{{end}} - hosts { - fallthrough - } - {{if .UpstreamDNS}}upstream {{range .UpstreamDNS}}{{.}} {{end}} { bootstrap {{.BootstrapDNS}} }{{end}} - {{.Cache}} - {{.Prometheus}} -} -` - -var removeEmptyLines = regexp.MustCompile("([\t ]*\n)+") - -// generate CoreDNS config text -func generateCoreDNSConfigText() (string, error) { - t, err := template.New("config").Parse(coreDNSConfigTemplate) - if err != nil { - log.Printf("Couldn't generate DNS config: %s", err) - return "", err - } - - var configBytes bytes.Buffer - temporaryConfig := config.CoreDNS - - // generate temporary filter list, needed to put userfilter in coredns config - filters := []filter{} - - // first of all, append the user filter - userFilter := userFilter() - - filters = append(filters, userFilter) - - // then go through other filters - filters = append(filters, config.Filters...) - temporaryConfig.Filters = filters - - // run the template - err = t.Execute(&configBytes, &temporaryConfig) - if err != nil { - log.Printf("Couldn't generate DNS config: %s", err) - return "", err - } - configText := configBytes.String() - - // remove empty lines from generated config - configText = removeEmptyLines.ReplaceAllString(configText, "\n") - return configText, nil -} - -// Set the next filter ID to max(filter.ID) + 1 -func updateUniqueFilterID(filters []filter) { - for _, filter := range filters { - if nextFilterID < filter.ID { - nextFilterID = filter.ID + 1 - } - } -} - -func assignUniqueFilterID() int64 { - value := nextFilterID - nextFilterID += 1 - return value -} diff --git a/control.go b/control.go index 3e4dcabf..2674585c 100644 --- a/control.go +++ b/control.go @@ -1,29 +1,25 @@ package main import ( - "bytes" "encoding/json" "fmt" "io/ioutil" "log" + "net" "net/http" "os" - "path/filepath" - "regexp" "strconv" "strings" "time" - "github.com/AdguardTeam/AdGuardHome/upstream" + "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/miekg/dns" - corednsplugin "github.com/AdguardTeam/AdGuardHome/coredns_plugin" "gopkg.in/asaskevich/govalidator.v4" ) const updatePeriod = time.Minute * 30 -var filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) - // cached version.json to avoid hammering github.io for each page reload var versionCheckJSON []byte var versionCheckLastTime time.Time @@ -36,24 +32,20 @@ var client = &http.Client{ } // ------------------- -// coredns run control +// dns run control // ------------------- -func tellCoreDNSToReload() { - corednsplugin.Reload <- true -} - -func writeAllConfigsAndReloadCoreDNS() error { +func writeAllConfigsAndReloadDNS() error { err := writeAllConfigs() if err != nil { log.Printf("Couldn't write all configs: %s", err) return err } - tellCoreDNSToReload() + reconfigureDNSServer() return nil } func httpUpdateConfigReloadDNSReturnOK(w http.ResponseWriter, r *http.Request) { - err := writeAllConfigsAndReloadCoreDNS() + err := writeAllConfigsAndReloadDNS() if err != nil { errortext := fmt.Sprintf("Couldn't write config file: %s", err) log.Println(errortext) @@ -75,12 +67,12 @@ func returnOK(w http.ResponseWriter, r *http.Request) { func handleStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ "dns_address": config.BindHost, - "dns_port": config.CoreDNS.Port, - "protection_enabled": config.CoreDNS.ProtectionEnabled, - "querylog_enabled": config.CoreDNS.QueryLogEnabled, + "dns_port": config.DNS.Port, + "protection_enabled": config.DNS.ProtectionEnabled, + "querylog_enabled": config.DNS.QueryLogEnabled, "running": isRunning(), - "bootstrap_dns": config.CoreDNS.BootstrapDNS, - "upstream_dns": config.CoreDNS.UpstreamDNS, + "bootstrap_dns": config.DNS.BootstrapDNS, + "upstream_dns": config.DNS.UpstreamDNS, "version": VersionString, "language": config.Language, } @@ -103,12 +95,12 @@ func handleStatus(w http.ResponseWriter, r *http.Request) { } func handleProtectionEnable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.ProtectionEnabled = true + config.DNS.ProtectionEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleProtectionDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.ProtectionEnabled = false + config.DNS.ProtectionEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } @@ -116,12 +108,12 @@ func handleProtectionDisable(w http.ResponseWriter, r *http.Request) { // stats // ----- func handleQueryLogEnable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.QueryLogEnabled = true + config.DNS.QueryLogEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleQueryLogDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.QueryLogEnabled = false + config.DNS.QueryLogEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } @@ -143,9 +135,9 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { hosts := strings.Fields(string(body)) if len(hosts) == 0 { - config.CoreDNS.UpstreamDNS = defaultDNS + config.DNS.UpstreamDNS = defaultDNS } else { - config.CoreDNS.UpstreamDNS = hosts + config.DNS.UpstreamDNS = hosts } err = writeAllConfigs() @@ -155,7 +147,7 @@ func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { http.Error(w, errorText, http.StatusInternalServerError) return } - tellCoreDNSToReload() + reconfigureDNSServer() _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) if err != nil { errorText := fmt.Sprintf("Couldn't write body: %s", err) @@ -211,23 +203,32 @@ func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { } func checkDNS(input string) error { - u, err := upstream.NewUpstream(input, config.CoreDNS.BootstrapDNS) - + log.Printf("Checking if DNS %s works...", input) + u, err := dnsforward.AddressToUpstream(input, "") if err != nil { - return err + return fmt.Errorf("Failed to choose upstream for %s: %s", input, err) } - defer u.Close() - - alive, err := upstream.IsAlive(u) + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + reply, err := u.Exchange(&req) if err != nil { return fmt.Errorf("couldn't communicate with DNS server %s: %s", input, err) } - - if !alive { - return fmt.Errorf("DNS server has not passed the healthcheck: %s", input) + if len(reply.Answer) != 1 { + return fmt.Errorf("DNS server %s returned wrong answer", input) + } + if t, ok := reply.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(t.A) { + return fmt.Errorf("DNS server %s returned wrong answer: %v", input, t.A) + } } + log.Printf("DNS %s works OK", input) return nil } @@ -242,7 +243,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { resp, err := client.Get(versionCheckURL) if err != nil { - errortext := fmt.Sprintf("Couldn't get querylog from coredns: %T %s\n", err, err) + errortext := fmt.Sprintf("Couldn't get version check json from %s: %T %s\n", versionCheckURL, err, err) log.Println(errortext) http.Error(w, errortext, http.StatusBadGateway) return @@ -254,7 +255,7 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { // read the body entirely body, err := ioutil.ReadAll(resp.Body) if err != nil { - errortext := fmt.Sprintf("Couldn't read response body: %s", err) + errortext := fmt.Sprintf("Couldn't read response body from %s: %s", versionCheckURL, err) log.Println(errortext) http.Error(w, errortext, http.StatusBadGateway) return @@ -277,18 +278,18 @@ func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { // --------- func handleFilteringEnable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.FilteringEnabled = true + config.DNS.FilteringEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleFilteringDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.FilteringEnabled = false + config.DNS.FilteringEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleFilteringStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ - "enabled": config.CoreDNS.FilteringEnabled, + "enabled": config.DNS.FilteringEnabled, } config.RLock() @@ -376,7 +377,8 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { return } - // URL is deemed valid, append it to filters, update config, write new filter file and tell coredns to reload it + // URL is deemed valid, append it to filters, update config, write new filter file and tell dns to reload it + // TODO: since we directly feed filters in-memory, revisit if writing configs is always neccessary config.Filters = append(config.Filters, filter) err = writeAllConfigs() if err != nil { @@ -386,7 +388,7 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) { return } - tellCoreDNSToReload() + reconfigureDNSServer() _, err = fmt.Fprintf(w, "OK %d rules\n", filter.RulesCount) if err != nil { @@ -531,199 +533,23 @@ func handleFilteringRefresh(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "OK %d filters updated\n", updated) } -// Sets up a timer that will be checking for filters updates periodically -func periodicallyRefreshFilters() { - for range time.Tick(time.Minute) { - refreshFiltersIfNeccessary(false) - } -} - -// Checks filters updates if necessary -// If force is true, it ignores the filter.LastUpdated field value -func refreshFiltersIfNeccessary(force bool) int { - config.Lock() - - // fetch URLs - updateCount := 0 - for i := range config.Filters { - filter := &config.Filters[i] // otherwise we will be operating on a copy - - if filter.ID == 0 { // protect against users modifying the yaml and removing the ID - filter.ID = assignUniqueFilterID() - } - - updated, err := filter.update(force) - if err != nil { - log.Printf("Failed to update filter %s: %s\n", filter.URL, err) - continue - } - if updated { - // Saving it to the filters dir now - err = filter.save() - if err != nil { - log.Printf("Failed to save the updated filter %d: %s", filter.ID, err) - continue - } - - updateCount++ - } - } - config.Unlock() - - if updateCount > 0 { - tellCoreDNSToReload() - } - return updateCount -} - -// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any) -func parseFilterContents(contents []byte) (int, string) { - lines := strings.Split(string(contents), "\n") - rulesCount := 0 - name := "" - seenTitle := false - - // Count lines in the filter - for _, line := range lines { - line = strings.TrimSpace(line) - if len(line) > 0 && line[0] == '!' { - if m := filterTitleRegexp.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle { - name = m[0][1] - seenTitle = true - } - } else if len(line) != 0 { - rulesCount++ - } - } - - return rulesCount, name -} - -// Checks for filters updates -// If "force" is true -- does not check the filter's LastUpdated field -// Call "save" to persist the filter contents -func (filter *filter) update(force bool) (bool, error) { - if filter.ID == 0 { // protect against users deleting the ID - filter.ID = assignUniqueFilterID() - } - if !filter.Enabled { - return false, nil - } - if !force && time.Since(filter.LastUpdated) <= updatePeriod { - return false, nil - } - - log.Printf("Downloading update for filter %d from %s", filter.ID, filter.URL) - - // use the same update period for failed filter downloads to avoid flooding with requests - filter.LastUpdated = time.Now() - - resp, err := client.Get(filter.URL) - if resp != nil && resp.Body != nil { - defer resp.Body.Close() - } - if err != nil { - log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err) - return false, err - } - - if resp.StatusCode != 200 { - log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL) - return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode) - } - - contentType := strings.ToLower(resp.Header.Get("content-type")) - if !strings.HasPrefix(contentType, "text/plain") { - log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL) - return false, fmt.Errorf("non-text response %s", contentType) - } - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err) - return false, err - } - - // Extract filter name and count number of rules - rulesCount, filterName := parseFilterContents(body) - - if filterName != "" { - filter.Name = filterName - } - - // Check if the filter has been really changed - if bytes.Equal(filter.Contents, body) { - log.Printf("The filter %d text has not changed", filter.ID) - return false, nil - } - - log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount) - filter.RulesCount = rulesCount - filter.Contents = body - - return true, nil -} - -// saves filter contents to the file in dataDir -func (filter *filter) save() error { - filterFilePath := filter.Path() - log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) - - return safeWriteFile(filterFilePath, filter.Contents) -} - -// loads filter contents from the file in dataDir -func (filter *filter) load() error { - if !filter.Enabled { - // No need to load a filter that is not enabled - return nil - } - - filterFilePath := filter.Path() - log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath) - - if _, err := os.Stat(filterFilePath); os.IsNotExist(err) { - // do nothing, file doesn't exist - return err - } - - filterFileContents, err := ioutil.ReadFile(filterFilePath) - if err != nil { - return err - } - - log.Printf("Filter %d length is %d", filter.ID, len(filterFileContents)) - filter.Contents = filterFileContents - - // Now extract the rules count - rulesCount, _ := parseFilterContents(filter.Contents) - filter.RulesCount = rulesCount - - return nil -} - -// Path to the filter contents -func (filter *filter) Path() string { - return filepath.Join(config.ourBinaryDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt") -} - // ------------ // safebrowsing // ------------ func handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.SafeBrowsingEnabled = true + config.DNS.SafeBrowsingEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.SafeBrowsingEnabled = false + config.DNS.SafeBrowsingEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ - "enabled": config.CoreDNS.SafeBrowsingEnabled, + "enabled": config.DNS.SafeBrowsingEnabled, } jsonVal, err := json.Marshal(data) if err != nil { @@ -786,22 +612,22 @@ func handleParentalEnable(w http.ResponseWriter, r *http.Request) { http.Error(w, "Sensitivity must be set to valid value", 400) return } - config.CoreDNS.ParentalSensitivity = i - config.CoreDNS.ParentalEnabled = true + config.DNS.ParentalSensitivity = i + config.DNS.ParentalEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleParentalDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.ParentalEnabled = false + config.DNS.ParentalEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleParentalStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ - "enabled": config.CoreDNS.ParentalEnabled, + "enabled": config.DNS.ParentalEnabled, } - if config.CoreDNS.ParentalEnabled { - data["sensitivity"] = config.CoreDNS.ParentalSensitivity + if config.DNS.ParentalEnabled { + data["sensitivity"] = config.DNS.ParentalSensitivity } jsonVal, err := json.Marshal(data) if err != nil { @@ -826,18 +652,18 @@ func handleParentalStatus(w http.ResponseWriter, r *http.Request) { // ------------ func handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.SafeSearchEnabled = true + config.DNS.SafeSearchEnabled = true httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { - config.CoreDNS.SafeSearchEnabled = false + config.DNS.SafeSearchEnabled = false httpUpdateConfigReloadDNSReturnOK(w, r) } func handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { data := map[string]interface{}{ - "enabled": config.CoreDNS.SafeSearchEnabled, + "enabled": config.DNS.SafeSearchEnabled, } jsonVal, err := json.Marshal(data) if err != nil { @@ -861,17 +687,17 @@ func registerControlHandlers() { http.HandleFunc("/control/status", optionalAuth(ensureGET(handleStatus))) http.HandleFunc("/control/enable_protection", optionalAuth(ensurePOST(handleProtectionEnable))) http.HandleFunc("/control/disable_protection", optionalAuth(ensurePOST(handleProtectionDisable))) - http.HandleFunc("/control/querylog", optionalAuth(ensureGET(corednsplugin.HandleQueryLog))) + http.HandleFunc("/control/querylog", optionalAuth(ensureGET(dnsforward.HandleQueryLog))) http.HandleFunc("/control/querylog_enable", optionalAuth(ensurePOST(handleQueryLogEnable))) http.HandleFunc("/control/querylog_disable", optionalAuth(ensurePOST(handleQueryLogDisable))) http.HandleFunc("/control/set_upstream_dns", optionalAuth(ensurePOST(handleSetUpstreamDNS))) http.HandleFunc("/control/test_upstream_dns", optionalAuth(ensurePOST(handleTestUpstreamDNS))) http.HandleFunc("/control/i18n/change_language", optionalAuth(ensurePOST(handleI18nChangeLanguage))) http.HandleFunc("/control/i18n/current_language", optionalAuth(ensureGET(handleI18nCurrentLanguage))) - http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(corednsplugin.HandleStatsTop))) - http.HandleFunc("/control/stats", optionalAuth(ensureGET(corednsplugin.HandleStats))) - http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(corednsplugin.HandleStatsHistory))) - http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(corednsplugin.HandleStatsReset))) + http.HandleFunc("/control/stats_top", optionalAuth(ensureGET(dnsforward.HandleStatsTop))) + http.HandleFunc("/control/stats", optionalAuth(ensureGET(dnsforward.HandleStats))) + http.HandleFunc("/control/stats_history", optionalAuth(ensureGET(dnsforward.HandleStatsHistory))) + http.HandleFunc("/control/stats_reset", optionalAuth(ensurePOST(dnsforward.HandleStatsReset))) http.HandleFunc("/control/version.json", optionalAuth(handleGetVersionJSON)) http.HandleFunc("/control/filtering/enable", optionalAuth(ensurePOST(handleFilteringEnable))) http.HandleFunc("/control/filtering/disable", optionalAuth(ensurePOST(handleFilteringDisable))) diff --git a/coredns.go b/coredns.go deleted file mode 100644 index 376e6210..00000000 --- a/coredns.go +++ /dev/null @@ -1,132 +0,0 @@ -package main - -import ( - "fmt" - "log" - "os" - "path/filepath" - "sync" // Include all plugins. - - _ "github.com/AdguardTeam/AdGuardHome/coredns_plugin" - _ "github.com/AdguardTeam/AdGuardHome/coredns_plugin/ratelimit" - _ "github.com/AdguardTeam/AdGuardHome/coredns_plugin/refuseany" - _ "github.com/AdguardTeam/AdGuardHome/upstream" - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/coremain" - _ "github.com/coredns/coredns/plugin/auto" - _ "github.com/coredns/coredns/plugin/autopath" - _ "github.com/coredns/coredns/plugin/bind" - _ "github.com/coredns/coredns/plugin/cache" - _ "github.com/coredns/coredns/plugin/chaos" - _ "github.com/coredns/coredns/plugin/debug" - _ "github.com/coredns/coredns/plugin/dnssec" - _ "github.com/coredns/coredns/plugin/dnstap" - _ "github.com/coredns/coredns/plugin/erratic" - _ "github.com/coredns/coredns/plugin/errors" - _ "github.com/coredns/coredns/plugin/file" - _ "github.com/coredns/coredns/plugin/forward" - _ "github.com/coredns/coredns/plugin/health" - _ "github.com/coredns/coredns/plugin/hosts" - _ "github.com/coredns/coredns/plugin/loadbalance" - _ "github.com/coredns/coredns/plugin/log" - _ "github.com/coredns/coredns/plugin/loop" - _ "github.com/coredns/coredns/plugin/metadata" - _ "github.com/coredns/coredns/plugin/metrics" - _ "github.com/coredns/coredns/plugin/nsid" - _ "github.com/coredns/coredns/plugin/pprof" - _ "github.com/coredns/coredns/plugin/proxy" - _ "github.com/coredns/coredns/plugin/reload" - _ "github.com/coredns/coredns/plugin/rewrite" - _ "github.com/coredns/coredns/plugin/root" - _ "github.com/coredns/coredns/plugin/secondary" - _ "github.com/coredns/coredns/plugin/template" - _ "github.com/coredns/coredns/plugin/tls" - _ "github.com/coredns/coredns/plugin/whoami" - _ "github.com/mholt/caddy/onevent" -) - -// Directives are registered in the order they should be -// executed. -// -// Ordering is VERY important. Every plugin will -// feel the effects of all other plugin below -// (after) them during a request, but they must not -// care what plugin above them are doing. - -var directives = []string{ - "metadata", - "tls", - "reload", - "nsid", - "root", - "bind", - "debug", - "health", - "pprof", - "prometheus", - "errors", - "log", - "refuseany", - "ratelimit", - "dnsfilter", - "dnstap", - "chaos", - "loadbalance", - "cache", - "rewrite", - "dnssec", - "autopath", - "template", - "hosts", - "file", - "auto", - "secondary", - "loop", - "forward", - "proxy", - "upstream", - "erratic", - "whoami", - "on", -} - -func init() { - dnsserver.Directives = directives -} - -var ( - isCoreDNSRunningLock sync.Mutex - isCoreDNSRunning = false -) - -func isRunning() bool { - isCoreDNSRunningLock.Lock() - value := isCoreDNSRunning - isCoreDNSRunningLock.Unlock() - return value -} - -func startDNSServer() error { - isCoreDNSRunningLock.Lock() - if isCoreDNSRunning { - isCoreDNSRunningLock.Unlock() - return fmt.Errorf("Unable to start coreDNS: Already running") - } - isCoreDNSRunning = true - isCoreDNSRunningLock.Unlock() - - configpath := filepath.Join(config.ourBinaryDir, config.CoreDNS.coreFile) - os.Args = os.Args[:1] - os.Args = append(os.Args, "-conf") - os.Args = append(os.Args, configpath) - - err := writeCoreDNSConfig() - if err != nil { - errortext := fmt.Errorf("Unable to write coredns config: %s", err) - log.Println(errortext) - return errortext - } - - go coremain.Run() - return nil -} diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go deleted file mode 100644 index f3a946dd..00000000 --- a/coredns_plugin/coredns_plugin.go +++ /dev/null @@ -1,557 +0,0 @@ -package dnsfilter - -import ( - "bufio" - "errors" - "fmt" - "log" - "net" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/metrics" - "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/plugin/pkg/upstream" - "github.com/coredns/coredns/request" - "github.com/mholt/caddy" - "github.com/miekg/dns" - "github.com/prometheus/client_golang/prometheus" - "golang.org/x/net/context" -) - -var defaultSOA = &dns.SOA{ - // values copied from verisign's nonexistent .com domain - // their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers - Refresh: 1800, - Retry: 900, - Expire: 604800, - Minttl: 86400, -} - -func init() { - caddy.RegisterPlugin("dnsfilter", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -type plugFilter struct { - ID int64 - Path string -} - -type plugSettings struct { - SafeBrowsingBlockHost string - ParentalBlockHost string - QueryLogEnabled bool - BlockedTTL uint32 // in seconds, default 3600 - Filters []plugFilter -} - -type plug struct { - d *dnsfilter.Dnsfilter - Next plugin.Handler - upstream upstream.Upstream - settings plugSettings - - sync.RWMutex -} - -var defaultPluginSettings = plugSettings{ - SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com", - ParentalBlockHost: "family.block.dns.adguard.com", - BlockedTTL: 3600, // in seconds - Filters: make([]plugFilter, 0), -} - -// -// coredns handling functions -// -func setupPlugin(c *caddy.Controller) (*plug, error) { - // create new Plugin and copy default values - p := &plug{ - settings: defaultPluginSettings, - d: dnsfilter.New(), - } - - log.Println("Initializing the CoreDNS plugin") - - for c.Next() { - for c.NextBlock() { - blockValue := c.Val() - switch blockValue { - case "safebrowsing": - log.Println("Browsing security service is enabled") - p.d.EnableSafeBrowsing() - if c.NextArg() { - if len(c.Val()) == 0 { - return nil, c.ArgErr() - } - p.d.SetSafeBrowsingServer(c.Val()) - } - case "safesearch": - log.Println("Safe search is enabled") - p.d.EnableSafeSearch() - case "parental": - if !c.NextArg() { - return nil, c.ArgErr() - } - sensitivity, err := strconv.Atoi(c.Val()) - if err != nil { - return nil, c.ArgErr() - } - - log.Println("Parental control is enabled") - err = p.d.EnableParental(sensitivity) - if err != nil { - return nil, c.ArgErr() - } - if c.NextArg() { - if len(c.Val()) == 0 { - return nil, c.ArgErr() - } - p.settings.ParentalBlockHost = c.Val() - } - case "blocked_ttl": - if !c.NextArg() { - return nil, c.ArgErr() - } - blockedTtl, err := strconv.ParseUint(c.Val(), 10, 32) - if err != nil { - return nil, c.ArgErr() - } - log.Printf("Blocked request TTL is %d", blockedTtl) - p.settings.BlockedTTL = uint32(blockedTtl) - case "querylog": - log.Println("Query log is enabled") - p.settings.QueryLogEnabled = true - case "filter": - if !c.NextArg() { - return nil, c.ArgErr() - } - - filterId, err := strconv.ParseInt(c.Val(), 10, 64) - if err != nil { - return nil, c.ArgErr() - } - if !c.NextArg() { - return nil, c.ArgErr() - } - filterPath := c.Val() - - // Initialize filter and add it to the list - p.settings.Filters = append(p.settings.Filters, plugFilter{ - ID: filterId, - Path: filterPath, - }) - } - } - } - - for _, filter := range p.settings.Filters { - log.Printf("Loading rules from %s", filter.Path) - - file, err := os.Open(filter.Path) - if err != nil { - return nil, err - } - defer file.Close() - - count := 0 - scanner := bufio.NewScanner(file) - for scanner.Scan() { - text := scanner.Text() - - err = p.d.AddRule(text, filter.ID) - if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { - continue - } - if err != nil { - log.Printf("Cannot add rule %s: %s", text, err) - // Just ignore invalid rules - continue - } - count++ - } - log.Printf("Added %d rules from filter ID=%d", count, filter.ID) - - if err = scanner.Err(); err != nil { - return nil, err - } - } - - log.Printf("Loading stats from querylog") - err := fillStatsFromQueryLog() - if err != nil { - log.Printf("Failed to load stats from querylog: %s", err) - return nil, err - } - - if p.settings.QueryLogEnabled { - onceQueryLog.Do(func() { - go periodicQueryLogRotate() - go periodicHourlyTopRotate() - go statsRotator() - }) - } - - onceHook.Do(func() { - caddy.RegisterEventHook("dnsfilter-reload", hook) - }) - - p.upstream, err = upstream.New(nil) - if err != nil { - return nil, err - } - - return p, nil -} - -func setup(c *caddy.Controller) error { - p, err := setupPlugin(c) - if err != nil { - return err - } - config := dnsserver.GetConfig(c) - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnStartup(func() error { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return nil - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(requests) - x.MustRegister(filtered) - x.MustRegister(filteredLists) - x.MustRegister(filteredSafebrowsing) - x.MustRegister(filteredParental) - x.MustRegister(whitelisted) - x.MustRegister(safesearch) - x.MustRegister(errorsTotal) - x.MustRegister(elapsedTime) - x.MustRegister(p) - } - return nil - }) - c.OnShutdown(p.onShutdown) - c.OnFinalShutdown(p.onFinalShutdown) - - return nil -} - -func (p *plug) onShutdown() error { - p.Lock() - p.d.Destroy() - p.d = nil - p.Unlock() - return nil -} - -func (p *plug) onFinalShutdown() error { - logBufferLock.Lock() - err := flushToFile(logBuffer) - if err != nil { - log.Printf("failed to flush to file: %s", err) - return err - } - logBufferLock.Unlock() - return nil -} - -type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) - -func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { - realch, ok := ch.(chan<- *prometheus.Desc) - if !ok { - log.Printf("Couldn't convert ch to chan<- *prometheus.Desc\n") - return - } - realch <- prometheus.NewDesc(name, text, nil, nil) -} - -func doMetric(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { - realch, ok := ch.(chan<- prometheus.Metric) - if !ok { - log.Printf("Couldn't convert ch to chan<- prometheus.Metric\n") - return - } - desc := prometheus.NewDesc(name, text, nil, nil) - realch <- prometheus.MustNewConstMetric(desc, valueType, value) -} - -func gen(ch interface{}, doFunc statsFunc, name string, text string, value float64, valueType prometheus.ValueType) { - doFunc(ch, name, text, value, valueType) -} - -func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *dnsfilter.LookupStats) { - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_requests", name), fmt.Sprintf("Number of %s HTTP requests that were sent", name), float64(lookupstats.Requests), prometheus.CounterValue) - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_cachehits", name), fmt.Sprintf("Number of %s lookups that didn't need HTTP requests", name), float64(lookupstats.CacheHits), prometheus.CounterValue) - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending", name), fmt.Sprintf("Number of currently pending %s HTTP requests", name), float64(lookupstats.Pending), prometheus.GaugeValue) - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending_max", name), fmt.Sprintf("Maximum number of pending %s HTTP requests", name), float64(lookupstats.PendingMax), prometheus.GaugeValue) -} - -func (p *plug) doStats(ch interface{}, doFunc statsFunc) { - p.RLock() - stats := p.d.GetStats() - doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing) - doStatsLookup(ch, doFunc, "parental", &stats.Parental) - p.RUnlock() -} - -// Describe is called by prometheus handler to know stat types -func (p *plug) Describe(ch chan<- *prometheus.Desc) { - p.doStats(ch, doDesc) -} - -// Collect is called by prometheus handler to collect stats -func (p *plug) Collect(ch chan<- prometheus.Metric) { - p.doStats(ch, doMetric) -} - -func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, host string, val string, question dns.Question) (int, error) { - // check if it's a domain name or IP address - addr := net.ParseIP(val) - var records []dns.RR - // log.Println("Will give", val, "instead of", host) // debug logging - if addr != nil { - // this is an IP address, return it - result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.settings.BlockedTTL, val)) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) - } - records = append(records, result) - } else { - // this is a domain name, need to look it up - req := new(dns.Msg) - req.SetQuestion(dns.Fqdn(val), question.Qtype) - req.RecursionDesired = true - reqstate := request.Request{W: w, Req: req, Context: ctx} - result, err := p.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType()) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) - } - if result != nil { - for _, answer := range result.Answer { - answer.Header().Name = question.Name - } - records = result.Answer - } - } - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true - m.Answer = append(m.Answer, records...) - state := request.Request{W: w, Req: r, Context: ctx} - state.SizeAndDo(m) - err := state.W.WriteMsg(m) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) - } - return dns.RcodeSuccess, nil -} - -// generate SOA record that makes DNS clients cache NXdomain results -// the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant -func (p *plug) genSOA(r *dns.Msg) []dns.RR { - zone := r.Question[0].Name - header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.settings.BlockedTTL, Class: dns.ClassINET} - - Mbox := "hostmaster." - if zone[0] != '.' { - Mbox += zone - } - Ns := "fake-for-negative-caching.adguard.com." - - soa := *defaultSOA - soa.Hdr = header - soa.Mbox = Mbox - soa.Ns = Ns - soa.Serial = 100500 // faster than uint32(time.Now().Unix()) - return []dns.RR{&soa} -} - -func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := request.Request{W: w, Req: r, Context: ctx} - m := new(dns.Msg) - m.SetRcode(state.Req, dns.RcodeNameError) - m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true - m.Ns = p.genSOA(r) - - state.SizeAndDo(m) - err := state.W.WriteMsg(m) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, err - } - return dns.RcodeNameError, nil -} - -func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) { - if len(r.Question) != 1 { - // google DNS, bind and others do the same - return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("got a DNS request with more than one Question") - } - for _, question := range r.Question { - host := strings.ToLower(strings.TrimSuffix(question.Name, ".")) - // is it a safesearch domain? - p.RLock() - if val, ok := p.d.SafeSearchDomain(host); ok { - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - p.RUnlock() - return rcode, dnsfilter.Result{}, err - } - p.RUnlock() - return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err - } - p.RUnlock() - - // needs to be filtered instead - p.RLock() - result, err := p.d.CheckHost(host) - if err != nil { - log.Printf("plugin/dnsfilter: %s\n", err) - p.RUnlock() - return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err) - } - p.RUnlock() - - if result.IsFiltered { - switch result.Reason { - case dnsfilter.FilteredSafeBrowsing: - // return cname safebrowsing.block.dns.adguard.com - val := p.settings.SafeBrowsingBlockHost - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - case dnsfilter.FilteredParental: - // return cname family.block.dns.adguard.com - val := p.settings.ParentalBlockHost - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - case dnsfilter.FilteredBlackList: - - if result.Ip == nil { - // return NXDomain - rcode, err := p.writeNXdomain(ctx, w, r) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - } else { - // This is a hosts-syntax rule - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, result.Ip.String(), question) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - } - case dnsfilter.FilteredInvalid: - // return NXdomain - rcode, err := p.writeNXdomain(ctx, w, r) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - default: - log.Printf("SHOULD NOT HAPPEN -- got unknown reason for filtering host \"%s\": %v, %+v", host, result.Reason, result) - } - } else { - switch result.Reason { - case dnsfilter.NotFilteredWhiteList: - rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) - return rcode, result, err - case dnsfilter.NotFilteredNotFound: - // do nothing, pass through to lower code - default: - log.Printf("SHOULD NOT HAPPEN -- got unknown reason for not filtering host \"%s\": %v, %+v", host, result.Reason, result) - } - } - } - rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) - return rcode, dnsfilter.Result{}, err -} - -// ServeDNS handles the DNS request and refuses if it's in filterlists -func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - start := time.Now() - requests.Inc() - state := request.Request{W: w, Req: r} - ip := state.IP() - - // capture the written answer - rrw := dnstest.NewRecorder(w) - rcode, result, err := p.serveDNSInternal(ctx, rrw, r) - if rcode > 0 { - // actually send the answer if we have one - answer := new(dns.Msg) - answer.SetRcode(r, rcode) - state.SizeAndDo(answer) - err = w.WriteMsg(answer) - if err != nil { - return dns.RcodeServerFailure, err - } - } - - // increment counters - switch { - case err != nil: - errorsTotal.Inc() - case result.Reason == dnsfilter.FilteredBlackList: - filtered.Inc() - filteredLists.Inc() - case result.Reason == dnsfilter.FilteredSafeBrowsing: - filtered.Inc() - filteredSafebrowsing.Inc() - case result.Reason == dnsfilter.FilteredParental: - filtered.Inc() - filteredParental.Inc() - case result.Reason == dnsfilter.FilteredInvalid: - filtered.Inc() - filteredInvalid.Inc() - case result.Reason == dnsfilter.FilteredSafeSearch: - // the request was passsed through but not filtered, don't increment filtered - safesearch.Inc() - case result.Reason == dnsfilter.NotFilteredWhiteList: - whitelisted.Inc() - case result.Reason == dnsfilter.NotFilteredNotFound: - // do nothing - case result.Reason == dnsfilter.NotFilteredError: - text := "SHOULD NOT HAPPEN: got DNSFILTER_NOTFILTERED_ERROR without err != nil!" - log.Println(text) - err = errors.New(text) - rcode = dns.RcodeServerFailure - } - - // log - elapsed := time.Since(start) - elapsedTime.Observe(elapsed.Seconds()) - if p.settings.QueryLogEnabled { - logRequest(r, rrw.Msg, result, time.Since(start), ip) - } - return rcode, err -} - -// Name returns name of the plugin as seen in Corefile and plugin.cfg -func (p *plug) Name() string { return "dnsfilter" } - -var onceHook sync.Once -var onceQueryLog sync.Once diff --git a/coredns_plugin/coredns_plugin_test.go b/coredns_plugin/coredns_plugin_test.go deleted file mode 100644 index 1733fd6f..00000000 --- a/coredns_plugin/coredns_plugin_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package dnsfilter - -import ( - "context" - "fmt" - "io/ioutil" - "net" - "os" - "testing" - - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/plugin/test" - "github.com/mholt/caddy" - "github.com/miekg/dns" -) - -func TestSetup(t *testing.T) { - for i, testcase := range []struct { - config string - failing bool - }{ - {`dnsfilter`, false}, - {`dnsfilter { - filter 0 /dev/nonexistent/abcdef - }`, true}, - {`dnsfilter { - filter 0 ../tests/dns.txt - }`, false}, - {`dnsfilter { - safebrowsing - filter 0 ../tests/dns.txt - }`, false}, - {`dnsfilter { - parental - filter 0 ../tests/dns.txt - }`, true}, - } { - c := caddy.NewTestController("dns", testcase.config) - err := setup(c) - if err != nil { - if !testcase.failing { - t.Fatalf("Test #%d expected no errors, but got: %v", i, err) - } - continue - } - if testcase.failing { - t.Fatalf("Test #%d expected to fail but it didn't", i) - } - } -} - -func TestEtcHostsFilter(t *testing.T) { - text := []byte("127.0.0.1 doubleclick.net\n" + "127.0.0.1 example.org example.net www.example.org www.example.net") - tmpfile, err := ioutil.TempFile("", "") - if err != nil { - t.Fatal(err) - } - if _, err = tmpfile.Write(text); err != nil { - t.Fatal(err) - } - if err = tmpfile.Close(); err != nil { - t.Fatal(err) - } - - defer os.Remove(tmpfile.Name()) - - configText := fmt.Sprintf("dnsfilter {\nfilter 0 %s\n}", tmpfile.Name()) - c := caddy.NewTestController("dns", configText) - p, err := setupPlugin(c) - if err != nil { - t.Fatal(err) - } - - p.Next = zeroTTLBackend() - - ctx := context.TODO() - - for _, testcase := range []struct { - host string - filtered bool - }{ - {"www.doubleclick.net", false}, - {"doubleclick.net", true}, - {"www2.example.org", false}, - {"www2.example.net", false}, - {"test.www.example.org", false}, - {"test.www.example.net", false}, - {"example.org", true}, - {"example.net", true}, - {"www.example.org", true}, - {"www.example.net", true}, - } { - req := new(dns.Msg) - req.SetQuestion(testcase.host+".", dns.TypeA) - - resp := test.ResponseWriter{} - rrw := dnstest.NewRecorder(&resp) - rcode, err := p.ServeDNS(ctx, rrw, req) - if err != nil { - t.Fatalf("ServeDNS returned error: %s", err) - } - if rcode != rrw.Rcode { - t.Fatalf("ServeDNS return value for host %s has rcode %d that does not match captured rcode %d", testcase.host, rcode, rrw.Rcode) - } - A, ok := rrw.Msg.Answer[0].(*dns.A) - if !ok { - t.Fatalf("Host %s expected to have result A", testcase.host) - } - ip := net.IPv4(127, 0, 0, 1) - filtered := ip.Equal(A.A) - if testcase.filtered && testcase.filtered != filtered { - t.Fatalf("Host %s expected to be filtered, instead it is not filtered", testcase.host) - } - if !testcase.filtered && testcase.filtered != filtered { - t.Fatalf("Host %s expected to be not filtered, instead it is filtered", testcase.host) - } - } -} - -func zeroTTLBackend() plugin.Handler { - return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - m := new(dns.Msg) - m.SetReply(r) - m.Response, m.RecursionAvailable = true, true - - m.Answer = []dns.RR{test.A("example.org. 0 IN A 127.0.0.53")} - w.WriteMsg(m) - return dns.RcodeSuccess, nil - }) -} diff --git a/coredns_plugin/ratelimit/ratelimit.go b/coredns_plugin/ratelimit/ratelimit.go deleted file mode 100644 index 8d3eeecc..00000000 --- a/coredns_plugin/ratelimit/ratelimit.go +++ /dev/null @@ -1,182 +0,0 @@ -package ratelimit - -import ( - "errors" - "log" - "sort" - "strconv" - "time" - - // ratelimiting and per-ip buckets - "github.com/beefsack/go-rate" - "github.com/patrickmn/go-cache" - - // coredns plugin - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/metrics" - "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/request" - "github.com/mholt/caddy" - "github.com/miekg/dns" - "github.com/prometheus/client_golang/prometheus" - "golang.org/x/net/context" -) - -const defaultRatelimit = 30 -const defaultResponseSize = 1000 - -var ( - tokenBuckets = cache.New(time.Hour, time.Hour) -) - -// ServeDNS handles the DNS request and refuses if it's an beyind specified ratelimit -func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := request.Request{W: w, Req: r} - ip := state.IP() - allow, err := p.allowRequest(ip) - if err != nil { - return 0, err - } - if !allow { - ratelimited.Inc() - return 0, nil - } - - // Record response to get status code and size of the reply. - rw := dnstest.NewRecorder(w) - status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r) - - size := rw.Len - - if size > defaultResponseSize && state.Proto() == "udp" { - // For large UDP responses we call allowRequest more times - // The exact number of times depends on the response size - for i := 0; i < size/defaultResponseSize; i++ { - p.allowRequest(ip) - } - } - - return status, err -} - -func (p *plug) allowRequest(ip string) (bool, error) { - if len(p.whitelist) > 0 { - i := sort.SearchStrings(p.whitelist, ip) - - if i < len(p.whitelist) && p.whitelist[i] == ip { - return true, nil - } - } - - if _, found := tokenBuckets.Get(ip); !found { - tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour) - } - - value, found := tokenBuckets.Get(ip) - if !found { - // should not happen since we've just inserted it - text := "SHOULD NOT HAPPEN: just-inserted ratelimiter disappeared" - log.Println(text) - err := errors.New(text) - return true, err - } - - rl, ok := value.(*rate.RateLimiter) - if !ok { - text := "SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache" - log.Println(text) - err := errors.New(text) - return true, err - } - - allow, _ := rl.Try() - return allow, nil -} - -// -// helper functions -// -func init() { - caddy.RegisterPlugin("ratelimit", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -type plug struct { - Next plugin.Handler - - // configuration for creating above - ratelimit int // in requests per second per IP - whitelist []string // a list of whitelisted IP addresses -} - -func setupPlugin(c *caddy.Controller) (*plug, error) { - p := &plug{ratelimit: defaultRatelimit} - - for c.Next() { - args := c.RemainingArgs() - if len(args) > 0 { - ratelimit, err := strconv.Atoi(args[0]) - if err != nil { - return nil, c.ArgErr() - } - p.ratelimit = ratelimit - } - for c.NextBlock() { - switch c.Val() { - case "whitelist": - p.whitelist = c.RemainingArgs() - - if len(p.whitelist) > 0 { - sort.Strings(p.whitelist) - } - } - } - } - - return p, nil -} - -func setup(c *caddy.Controller) error { - p, err := setupPlugin(c) - if err != nil { - return err - } - - config := dnsserver.GetConfig(c) - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnStartup(func() error { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return nil - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(ratelimited) - } - return nil - }) - - return nil -} - -func newDNSCounter(name string, help string) prometheus.Counter { - return prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "ratelimit", - Name: name, - Help: help, - }) -} - -var ( - ratelimited = newDNSCounter("dropped_total", "Count of requests that have been dropped because of rate limit") -) - -// Name returns name of the plugin as seen in Corefile and plugin.cfg -func (p *plug) Name() string { return "ratelimit" } diff --git a/coredns_plugin/ratelimit/ratelimit_test.go b/coredns_plugin/ratelimit/ratelimit_test.go deleted file mode 100644 index b426f2eb..00000000 --- a/coredns_plugin/ratelimit/ratelimit_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package ratelimit - -import ( - "testing" - - "github.com/mholt/caddy" -) - -func TestSetup(t *testing.T) { - for i, testcase := range []struct { - config string - failing bool - }{ - {`ratelimit`, false}, - {`ratelimit 100`, false}, - {`ratelimit { - whitelist 127.0.0.1 - }`, false}, - {`ratelimit 50 { - whitelist 127.0.0.1 176.103.130.130 - }`, false}, - {`ratelimit test`, true}, - } { - c := caddy.NewTestController("dns", testcase.config) - err := setup(c) - if err != nil { - if !testcase.failing { - t.Fatalf("Test #%d expected no errors, but got: %v", i, err) - } - continue - } - if testcase.failing { - t.Fatalf("Test #%d expected to fail but it didn't", i) - } - } -} - -func TestRatelimiting(t *testing.T) { - // rate limit is 1 per sec - c := caddy.NewTestController("dns", `ratelimit 1`) - p, err := setupPlugin(c) - - if err != nil { - t.Fatal("Failed to initialize the plugin") - } - - allowed, err := p.allowRequest("127.0.0.1") - - if err != nil || !allowed { - t.Fatal("First request must have been allowed") - } - - allowed, err = p.allowRequest("127.0.0.1") - - if err != nil || allowed { - t.Fatal("Second request must have been ratelimited") - } -} - -func TestWhitelist(t *testing.T) { - // rate limit is 1 per sec - c := caddy.NewTestController("dns", `ratelimit 1 { whitelist 127.0.0.2 127.0.0.1 127.0.0.125 }`) - p, err := setupPlugin(c) - - if err != nil { - t.Fatal("Failed to initialize the plugin") - } - - allowed, err := p.allowRequest("127.0.0.1") - - if err != nil || !allowed { - t.Fatal("First request must have been allowed") - } - - allowed, err = p.allowRequest("127.0.0.1") - - if err != nil || !allowed { - t.Fatal("Second request must have been allowed due to whitelist") - } -} diff --git a/coredns_plugin/refuseany/refuseany.go b/coredns_plugin/refuseany/refuseany.go deleted file mode 100644 index 92d5d508..00000000 --- a/coredns_plugin/refuseany/refuseany.go +++ /dev/null @@ -1,91 +0,0 @@ -package refuseany - -import ( - "fmt" - "log" - - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/metrics" - "github.com/coredns/coredns/request" - "github.com/mholt/caddy" - "github.com/miekg/dns" - "github.com/prometheus/client_golang/prometheus" - "golang.org/x/net/context" -) - -type plug struct { - Next plugin.Handler -} - -// ServeDNS handles the DNS request and refuses if it's an ANY request -func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - if len(r.Question) != 1 { - // google DNS, bind and others do the same - return dns.RcodeFormatError, fmt.Errorf("Got DNS request with != 1 questions") - } - - q := r.Question[0] - if q.Qtype == dns.TypeANY { - state := request.Request{W: w, Req: r, Context: ctx} - rcode := dns.RcodeNotImplemented - - m := new(dns.Msg) - m.SetRcode(r, rcode) - state.SizeAndDo(m) - err := state.W.WriteMsg(m) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, err - } - return rcode, nil - } - - return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) -} - -func init() { - caddy.RegisterPlugin("refuseany", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -func setup(c *caddy.Controller) error { - p := &plug{} - config := dnsserver.GetConfig(c) - - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnStartup(func() error { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return nil - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(ratelimited) - } - return nil - }) - - return nil -} - -func newDNSCounter(name string, help string) prometheus.Counter { - return prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "refuseany", - Name: name, - Help: help, - }) -} - -var ( - ratelimited = newDNSCounter("refusedany_total", "Count of ANY requests that have been dropped") -) - -// Name returns name of the plugin as seen in Corefile and plugin.cfg -func (p *plug) Name() string { return "refuseany" } diff --git a/coredns_plugin/reload.go b/coredns_plugin/reload.go deleted file mode 100644 index 880a3acc..00000000 --- a/coredns_plugin/reload.go +++ /dev/null @@ -1,36 +0,0 @@ -package dnsfilter - -import ( - "log" - - "github.com/mholt/caddy" -) - -var Reload = make(chan bool) - -func hook(event caddy.EventName, info interface{}) error { - if event != caddy.InstanceStartupEvent { - return nil - } - - // this should be an instance. ok to panic if not - instance := info.(*caddy.Instance) - - go func() { - for range Reload { - corefile, err := caddy.LoadCaddyfile(instance.Caddyfile().ServerType()) - if err != nil { - continue - } - _, err = instance.Restart(corefile) - if err != nil { - log.Printf("Corefile changed but reload failed: %s", err) - continue - } - // hook will be called again from new instance - return - } - }() - - return nil -} diff --git a/dns.go b/dns.go new file mode 100644 index 00000000..42894336 --- /dev/null +++ b/dns.go @@ -0,0 +1,89 @@ +package main + +import ( + "fmt" + "log" + "net" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/joomcode/errorx" +) + +var dnsServer = dnsforward.Server{} + +func isRunning() bool { + return dnsServer.IsRunning() +} + +func generateServerConfig() dnsforward.ServerConfig { + filters := []dnsfilter.Filter{} + userFilter := userFilter() + filters = append(filters, dnsfilter.Filter{ + ID: userFilter.ID, + Rules: userFilter.Rules, + }) + for _, filter := range config.Filters { + filters = append(filters, dnsfilter.Filter{ + ID: filter.ID, + Rules: filter.Rules, + }) + } + + newconfig := dnsforward.ServerConfig{ + UDPListenAddr: &net.UDPAddr{Port: config.DNS.Port}, + FilteringConfig: config.DNS.FilteringConfig, + Filters: filters, + } + + for _, u := range config.DNS.UpstreamDNS { + upstream, err := dnsforward.AddressToUpstream(u, config.DNS.BootstrapDNS) + if err != nil { + log.Printf("Couldn't get upstream: %s", err) + // continue, just ignore the upstream + continue + } + newconfig.Upstreams = append(newconfig.Upstreams, upstream) + } + return newconfig +} + +func startDNSServer() error { + if isRunning() { + return fmt.Errorf("Unable to start forwarding DNS server: Already running") + } + + newconfig := generateServerConfig() + err := dnsServer.Start(&newconfig) + if err != nil { + return errorx.Decorate(err, "Couldn't start forwarding DNS server") + } + + return nil +} + +func reconfigureDNSServer() error { + if !isRunning() { + return fmt.Errorf("Refusing to reconfigure forwarding DNS server: not running") + } + + err := dnsServer.Reconfigure(generateServerConfig()) + if err != nil { + return errorx.Decorate(err, "Couldn't start forwarding DNS server") + } + + return nil +} + +func stopDNSServer() error { + if !isRunning() { + return fmt.Errorf("Refusing to stop forwarding DNS server: not running") + } + + err := dnsServer.Stop() + if err != nil { + return errorx.Decorate(err, "Couldn't stop forwarding DNS server") + } + + return nil +} diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 3153f69f..cd408a4d 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -38,21 +38,22 @@ var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax") // ErrInvalidSyntax is returned by AddRule when the rule was already added to the filter var ErrAlreadyExists = errors.New("dnsfilter: rule was already added") -// ErrInvalidParental is returned by EnableParental when sensitivity is not a valid value -var ErrInvalidParental = errors.New("dnsfilter: invalid parental sensitivity, must be either 3, 10, 13 or 17") - const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot const enableFastLookup = true // flag for debugging, must be true in production for faster performance const enableDelayedCompilation = true // flag for debugging, must be true in production for faster performance -type config struct { - parentalServer string - parentalSensitivity int // must be either 3, 10, 13 or 17 - parentalEnabled bool - safeSearchEnabled bool - safeBrowsingEnabled bool - safeBrowsingServer string +// Config allows you to configure DNS filtering with New() or just change variables directly. +type Config struct { + ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17 + ParentalEnabled bool `yaml:"parental_enabled"` + SafeSearchEnabled bool `yaml:"safesearch_enabled"` + SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"` +} + +type privateConfig struct { + parentalServer string // access via methods + safeBrowsingServer string // access via methods } type rule struct { @@ -110,7 +111,13 @@ type Dnsfilter struct { client http.Client // handle for http client -- single instance as recommended by docs transport *http.Transport // handle for http transport used by http client - config config + Config // for direct access by library users, even a = assignment + privateConfig +} + +type Filter struct { + ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase + Rules []string `json:"-" yaml:"-"` // not in yaml or json } //go:generate stringer -type=Reason @@ -171,7 +178,7 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) { } // check safebrowsing if no match - if d.config.safeBrowsingEnabled { + if d.SafeBrowsingEnabled { result, err = d.checkSafeBrowsing(host) if err != nil { // failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache @@ -184,7 +191,7 @@ func (d *Dnsfilter) CheckHost(host string) (Result, error) { } // check parental if no match - if d.config.parentalEnabled { + if d.ParentalEnabled { result, err = d.checkParental(host) if err != nil { // failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache @@ -569,11 +576,11 @@ func hostnameToHashParam(host string, addslash bool) (string, map[string]bool) { func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { // prevent recursion -- checking the host of safebrowsing server makes no sense - if host == d.config.safeBrowsingServer { + if host == d.safeBrowsingServer { return Result{}, nil } format := func(hashparam string) string { - url := fmt.Sprintf(defaultSafebrowsingURL, d.config.safeBrowsingServer, hashparam) + url := fmt.Sprintf(defaultSafebrowsingURL, d.safeBrowsingServer, hashparam) return url } handleBody := func(body []byte, hashes map[string]bool) (Result, error) { @@ -610,11 +617,11 @@ func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) { func (d *Dnsfilter) checkParental(host string) (Result, error) { // prevent recursion -- checking the host of parental safety server makes no sense - if host == d.config.parentalServer { + if host == d.parentalServer { return Result{}, nil } format := func(hashparam string) string { - url := fmt.Sprintf(defaultParentalURL, d.config.parentalServer, hashparam, d.config.parentalSensitivity) + url := fmt.Sprintf(defaultParentalURL, d.parentalServer, hashparam, d.ParentalSensitivity) return url } handleBody := func(body []byte, hashes map[string]bool) (Result, error) { @@ -727,6 +734,24 @@ func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gc // Adding rule and matching against the rules // +// AddRules is a convinience function to add an array of filters in one call +func (d *Dnsfilter) AddRules(filters []Filter) error { + for _, f := range filters { + for _, rule := range f.Rules { + err := d.AddRule(rule, f.ID) + if err == ErrAlreadyExists || err == ErrInvalidSyntax { + continue + } + if err != nil { + log.Printf("Cannot add rule %s: %s", rule, err) + // Just ignore invalid rules + continue + } + } + } + return nil +} + // AddRule adds a rule, checking if it is a valid rule first and if it wasn't added already func (d *Dnsfilter) AddRule(input string, filterListID int64) error { input = strings.TrimSpace(input) @@ -846,7 +871,7 @@ func (d *Dnsfilter) matchHost(host string) (Result, error) { // // New creates properly initialized DNS Filter that is ready to be used -func New() *Dnsfilter { +func New(c *Config) *Dnsfilter { d := new(Dnsfilter) d.storage = make(map[string]bool) @@ -867,8 +892,11 @@ func New() *Dnsfilter { Transport: d.transport, Timeout: defaultHTTPTimeout, } - d.config.safeBrowsingServer = defaultSafebrowsingServer - d.config.parentalServer = defaultParentalServer + d.safeBrowsingServer = defaultSafebrowsingServer + d.parentalServer = defaultParentalServer + if c != nil { + d.Config = *c + } return d } @@ -885,35 +913,21 @@ func (d *Dnsfilter) Destroy() { // config manipulation helpers // -// EnableSafeBrowsing turns on checking hostnames in malware/phishing database -func (d *Dnsfilter) EnableSafeBrowsing() { - d.config.safeBrowsingEnabled = true -} - -// EnableParental turns on checking hostnames for containing adult content -func (d *Dnsfilter) EnableParental(sensitivity int) error { +// IsParentalSensitivityValid checks if sensitivity is valid value +func IsParentalSensitivityValid(sensitivity int) bool { switch sensitivity { case 3, 10, 13, 17: - d.config.parentalSensitivity = sensitivity - d.config.parentalEnabled = true - return nil - default: - return ErrInvalidParental + return true } -} - -// EnableSafeSearch turns on enforcing safesearch in search engines -// only used in coredns plugin and requires caller to use SafeSearchDomain() -func (d *Dnsfilter) EnableSafeSearch() { - d.config.safeSearchEnabled = true + return false } // SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup func (d *Dnsfilter) SetSafeBrowsingServer(host string) { if len(host) == 0 { - d.config.safeBrowsingServer = defaultSafebrowsingServer + d.safeBrowsingServer = defaultSafebrowsingServer } else { - d.config.safeBrowsingServer = host + d.safeBrowsingServer = host } } @@ -929,7 +943,7 @@ func (d *Dnsfilter) ResetHTTPTimeout() { // SafeSearchDomain returns replacement address for search engine func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) { - if d.config.safeSearchEnabled { + if d.SafeSearchEnabled { val, ok := safeSearchDomains[host] return val, ok } diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 39b33a44..a93fadfc 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -338,7 +338,7 @@ func mustLoadTestRules(d *Dnsfilter) { } func NewForTest() *Dnsfilter { - d := New() + d := New(nil) purgeCaches() return d } @@ -542,7 +542,7 @@ func TestSafeBrowsing(t *testing.T) { t.Run(fmt.Sprintf("%s in %s", tc, _Func()), func(t *testing.T) { d := NewForTest() defer d.Destroy() - d.EnableSafeBrowsing() + d.SafeBrowsingEnabled = true stats.Safebrowsing.Requests = 0 d.checkMatch(t, "wmconvirus.narod.ru") d.checkMatch(t, "wmconvirus.narod.ru") @@ -570,7 +570,7 @@ func TestSafeBrowsing(t *testing.T) { func TestParallelSB(t *testing.T) { d := NewForTest() defer d.Destroy() - d.EnableSafeBrowsing() + d.SafeBrowsingEnabled = true t.Run("group", func(t *testing.T) { for i := 0; i < 100; i++ { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { @@ -597,7 +597,7 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) { defer ts.Close() address := ts.Listener.Addr().String() - d.EnableSafeBrowsing() + d.SafeBrowsingEnabled = true d.SetHTTPTimeout(time.Second * 5) d.SetSafeBrowsingServer(address) // this will ensure that test fails d.checkMatchEmpty(t, "wmconvirus.narod.ru") @@ -606,7 +606,8 @@ func TestSafeBrowsingCustomServerFail(t *testing.T) { func TestParentalControl(t *testing.T) { d := NewForTest() defer d.Destroy() - d.EnableParental(3) + d.ParentalEnabled = true + d.ParentalSensitivity = 3 d.checkMatch(t, "pornhub.com") d.checkMatch(t, "pornhub.com") if stats.Parental.Requests != 1 { @@ -637,7 +638,7 @@ func TestSafeSearch(t *testing.T) { if ok { t.Errorf("Expected safesearch to error when disabled") } - d.EnableSafeSearch() + d.SafeSearchEnabled = true val, ok := d.SafeSearchDomain("www.google.com") if !ok { t.Errorf("Expected safesearch to find result for www.google.com") @@ -924,7 +925,7 @@ func BenchmarkLotsOfRulesLotsOfHostsParallel(b *testing.B) { func BenchmarkSafeBrowsing(b *testing.B) { d := NewForTest() defer d.Destroy() - d.EnableSafeBrowsing() + d.SafeBrowsingEnabled = true for n := 0; n < b.N; n++ { hostname := "wmconvirus.narod.ru" ret, err := d.CheckHost(hostname) @@ -940,7 +941,7 @@ func BenchmarkSafeBrowsing(b *testing.B) { func BenchmarkSafeBrowsingParallel(b *testing.B) { d := NewForTest() defer d.Destroy() - d.EnableSafeBrowsing() + d.SafeBrowsingEnabled = true b.RunParallel(func(pb *testing.PB) { for pb.Next() { hostname := "wmconvirus.narod.ru" @@ -958,7 +959,7 @@ func BenchmarkSafeBrowsingParallel(b *testing.B) { func BenchmarkSafeSearch(b *testing.B) { d := NewForTest() defer d.Destroy() - d.EnableSafeSearch() + d.SafeSearchEnabled = true for n := 0; n < b.N; n++ { val, ok := d.SafeSearchDomain("www.google.com") if !ok { @@ -973,7 +974,7 @@ func BenchmarkSafeSearch(b *testing.B) { func BenchmarkSafeSearchParallel(b *testing.B) { d := NewForTest() defer d.Destroy() - d.EnableSafeSearch() + d.SafeSearchEnabled = true b.RunParallel(func(pb *testing.PB) { for pb.Next() { val, ok := d.SafeSearchDomain("www.google.com") @@ -1009,17 +1010,3 @@ func _Func() string { f := runtime.FuncForPC(pc[0]) return path.Base(f.Name()) } - -func trace(format string, args ...interface{}) { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - var buf strings.Builder - buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) - text := fmt.Sprintf(format, args...) - buf.WriteString(text) - if len(text) == 0 || text[len(text)-1] != '\n' { - buf.WriteRune('\n') - } - fmt.Print(buf.String()) -} diff --git a/dnsfilter/helpers.go b/dnsfilter/helpers.go index 68d4ba26..8152f402 100644 --- a/dnsfilter/helpers.go +++ b/dnsfilter/helpers.go @@ -1,6 +1,10 @@ package dnsfilter import ( + "fmt" + "os" + "path" + "runtime" "strings" "sync/atomic" ) @@ -58,3 +62,17 @@ func updateMax(valuePtr *int64, maxPtr *int64) { // swapping failed because value has changed after reading, try again } } + +func trace(format string, args ...interface{}) { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + var buf strings.Builder + buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) + text := fmt.Sprintf(format, args...) + buf.WriteString(text) + if len(text) == 0 || text[len(text)-1] != '\n' { + buf.WriteRune('\n') + } + fmt.Fprint(os.Stderr, buf.String()) +} diff --git a/dnsforward/bootstrap.go b/dnsforward/bootstrap.go new file mode 100644 index 00000000..2d263871 --- /dev/null +++ b/dnsforward/bootstrap.go @@ -0,0 +1,107 @@ +package dnsforward + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/url" + "strings" + "sync" + + "github.com/joomcode/errorx" +) + +type bootstrapper struct { + address string // in form of "tls://one.one.one.one:853" + resolver *net.Resolver // resolver to use to resolve hostname, if neccessary + resolved string // in form "IP:port" + resolvedConfig *tls.Config + sync.Mutex +} + +func toBoot(address, bootstrapAddr string) bootstrapper { + var resolver *net.Resolver + if bootstrapAddr != "" { + resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, network, bootstrapAddr) + }, + } + } + return bootstrapper{ + address: address, + resolver: resolver, + } +} + +// will get usable IP address from Address field, and caches the result +func (n *bootstrapper) get() (string, *tls.Config, error) { + // TODO: RLock() here but atomically upgrade to Lock() if fast path doesn't work + n.Lock() + if n.resolved != "" { // fast path + retval, tlsconfig := n.resolved, n.resolvedConfig + n.Unlock() + return retval, tlsconfig, nil + } + + // + // slow path + // + + defer n.Unlock() + + justHostPort := n.address + if strings.Contains(n.address, "://") { + url, err := url.Parse(n.address) + if err != nil { + return "", nil, errorx.Decorate(err, "Failed to parse %s", n.address) + } + + justHostPort = url.Host + } + + // convert host to IP if neccessary, we know that it's scheme://hostname:port/ + + // get a host without port + host, port, err := net.SplitHostPort(justHostPort) + if err != nil { + return "", nil, fmt.Errorf("bootstrapper requires port in address %s", n.address) + } + + // if it's an IP + ip := net.ParseIP(host) + if ip != nil { + n.resolved = justHostPort + return n.resolved, nil, nil + } + + // + // if it's a hostname + // + + resolver := n.resolver // no need to check for nil resolver -- documented that nil is default resolver + addrs, err := resolver.LookupIPAddr(context.TODO(), host) + if err != nil { + return "", nil, errorx.Decorate(err, "Failed to lookup %s", host) + } + for _, addr := range addrs { + // TODO: support ipv6, support multiple ipv4 + if addr.IP.To4() == nil { + continue + } + ip = addr.IP + break + } + + if ip == nil { + // couldn't find any suitable IP address + return "", nil, fmt.Errorf("Couldn't find any suitable IP address for host %s", host) + } + + n.resolved = net.JoinHostPort(ip.String(), port) + n.resolvedConfig = &tls.Config{ServerName: host} + return n.resolved, n.resolvedConfig, nil +} diff --git a/dnsforward/cache.go b/dnsforward/cache.go new file mode 100644 index 00000000..568f284c --- /dev/null +++ b/dnsforward/cache.go @@ -0,0 +1,225 @@ +package dnsforward + +import ( + "encoding/binary" + "log" + "math" + "strings" + "sync" + "time" + + "github.com/miekg/dns" +) + +type item struct { + m *dns.Msg + when time.Time +} + +type cache struct { + items map[string]item + + sync.RWMutex +} + +func (c *cache) Get(request *dns.Msg) (*dns.Msg, bool) { + if request == nil { + return nil, false + } + ok, key := key(request) + if !ok { + log.Printf("Get(): key returned !ok") + return nil, false + } + + c.RLock() + item, ok := c.items[key] + c.RUnlock() + if !ok { + return nil, false + } + // get item's TTL + ttl := findLowestTTL(item.m) + // zero TTL? delete and don't serve it + if ttl == 0 { + c.Lock() + delete(c.items, key) + c.Unlock() + return nil, false + } + // too much time has passed? delete and don't serve it + if time.Since(item.when) >= time.Duration(ttl)*time.Second { + c.Lock() + delete(c.items, key) + c.Unlock() + return nil, false + } + response := item.fromItem(request) + return response, true +} + +func (c *cache) Set(m *dns.Msg) { + if m == nil { + return // no-op + } + if !isRequestCacheable(m) { + return + } + if !isResponseCacheable(m) { + return + } + ok, key := key(m) + if !ok { + return + } + + i := toItem(m) + + c.Lock() + if c.items == nil { + c.items = map[string]item{} + } + c.items[key] = i + c.Unlock() +} + +// check only request fields +func isRequestCacheable(m *dns.Msg) bool { + // truncated messages aren't valid + if m.Truncated { + log.Printf("Refusing to cache truncated message") + return false + } + + // if has wrong number of questions, also don't cache + if len(m.Question) != 1 { + log.Printf("Refusing to cache message with wrong number of questions") + return false + } + + // only OK or NXdomain replies are cached + switch m.Rcode { + case dns.RcodeSuccess: + case dns.RcodeNameError: // that's an NXDomain + case dns.RcodeServerFailure: + return false // quietly refuse, don't log + default: + log.Printf("%s: Refusing to cache message with rcode: %s", m.Question[0].Name, dns.RcodeToString[m.Rcode]) + return false + } + + return true +} + +func isResponseCacheable(m *dns.Msg) bool { + ttl := findLowestTTL(m) + if ttl == 0 { + return false + } + + return true +} + +func findLowestTTL(m *dns.Msg) uint32 { + var ttl uint32 = math.MaxUint32 + found := false + + if m.Answer != nil { + for _, r := range m.Answer { + if r.Header().Ttl < ttl { + ttl = r.Header().Ttl + found = true + } + } + } + + if m.Ns != nil { + for _, r := range m.Ns { + if r.Header().Ttl < ttl { + ttl = r.Header().Ttl + found = true + } + } + } + + if m.Extra != nil { + for _, r := range m.Extra { + if r.Header().Rrtype == dns.TypeOPT { + continue // OPT records use TTL for other purposes + } + if r.Header().Ttl < ttl { + ttl = r.Header().Ttl + found = true + } + } + } + + if found == false { + return 0 + } + + return ttl +} + +// key is binary little endian in sequence: +// uint16(qtype) then uint16(qclass) then name +func key(m *dns.Msg) (bool, string) { + if len(m.Question) != 1 { + log.Printf("got msg with len(m.Question) != 1: %d", len(m.Question)) + return false, "" + } + + bb := strings.Builder{} + b := make([]byte, 2) + binary.LittleEndian.PutUint16(b, m.Question[0].Qtype) + bb.Write(b) + binary.LittleEndian.PutUint16(b, m.Question[0].Qclass) + bb.Write(b) + name := strings.ToLower(m.Question[0].Name) + bb.WriteString(name) + return true, bb.String() +} + +func toItem(m *dns.Msg) item { + return item{ + m: m, + when: time.Now(), + } +} + +func (i *item) fromItem(request *dns.Msg) *dns.Msg { + response := &dns.Msg{} + response.SetReply(request) + + response.Authoritative = false + response.AuthenticatedData = i.m.AuthenticatedData + response.RecursionAvailable = i.m.RecursionAvailable + response.Rcode = i.m.Rcode + + ttl := findLowestTTL(i.m) + timeleft := math.Round(float64(ttl) - time.Since(i.when).Seconds()) + var newttl uint32 + if timeleft > 0 { + newttl = uint32(timeleft) + } + for _, r := range i.m.Answer { + answer := dns.Copy(r) + answer.Header().Ttl = newttl + response.Answer = append(response.Answer, answer) + } + for _, r := range i.m.Ns { + ns := dns.Copy(r) + ns.Header().Ttl = newttl + response.Ns = append(response.Ns, ns) + } + for _, r := range i.m.Extra { + // don't return OPT records as these are hop-by-hop + if r.Header().Rrtype == dns.TypeOPT { + continue + } + extra := dns.Copy(r) + extra.Header().Ttl = newttl + response.Extra = append(response.Extra, extra) + } + return response +} diff --git a/dnsforward/cache_test.go b/dnsforward/cache_test.go new file mode 100644 index 00000000..c9f4577e --- /dev/null +++ b/dnsforward/cache_test.go @@ -0,0 +1,144 @@ +package dnsforward + +import ( + "strings" + "testing" + + "github.com/go-test/deep" + "github.com/miekg/dns" +) + +func RR(rr string) dns.RR { + r, err := dns.NewRR(rr) + if err != nil { + panic(err) + } + return r +} + +// deepEqual is same as deep.Equal, except: +// * ignores Id when comparing +// * question names are not case sensetive +func deepEqualMsg(left *dns.Msg, right *dns.Msg) []string { + temp := *left + temp.Id = right.Id + for i := range left.Question { + left.Question[i].Name = strings.ToLower(left.Question[i].Name) + } + for i := range right.Question { + right.Question[i].Name = strings.ToLower(right.Question[i].Name) + } + return deep.Equal(&temp, right) +} + +func TestCacheSanity(t *testing.T) { + cache := cache{} + request := dns.Msg{} + request.SetQuestion("google.com.", dns.TypeA) + _, ok := cache.Get(&request) + if ok { + t.Fatal("empty cache replied with positive response") + } +} + +type tests struct { + cache []testEntry + cases []testCase +} + +type testEntry struct { + q string + t uint16 + a []dns.RR +} + +type testCase struct { + q string + t uint16 + a []dns.RR + ok bool +} + +func TestCache(t *testing.T) { + tests := tests{ + cache: []testEntry{ + {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}}, + }, + cases: []testCase{ + {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + }, + } + runTests(t, tests) +} + +func TestCacheMixedCase(t *testing.T) { + tests := tests{ + cache: []testEntry{ + {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}}, + }, + cases: []testCase{ + {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "google.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "GOOGLE.COM.", t: dns.TypeA, a: []dns.RR{RR("google.com. 3600 IN A 8.8.8.8")}, ok: true}, + {q: "gOOgle.com.", t: dns.TypeMX, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + {q: "GOOGLE.COM.", t: dns.TypeMX, ok: false}, + }, + } + runTests(t, tests) +} + +func TestZeroTTL(t *testing.T) { + tests := tests{ + cache: []testEntry{ + {q: "gOOgle.com.", t: dns.TypeA, a: []dns.RR{RR("google.com. 0 IN A 8.8.8.8")}}, + }, + cases: []testCase{ + {q: "google.com.", t: dns.TypeA, ok: false}, + {q: "google.com.", t: dns.TypeA, ok: false}, + {q: "google.com.", t: dns.TypeA, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + {q: "google.com.", t: dns.TypeMX, ok: false}, + }, + } + runTests(t, tests) +} + +func runTests(t *testing.T, tests tests) { + t.Helper() + cache := cache{} + for _, tc := range tests.cache { + reply := dns.Msg{} + reply.SetQuestion(tc.q, tc.t) + reply.Response = true + reply.Answer = tc.a + cache.Set(&reply) + } + for _, tc := range tests.cases { + request := dns.Msg{} + request.SetQuestion(tc.q, tc.t) + val, ok := cache.Get(&request) + if diff := deep.Equal(ok, tc.ok); diff != nil { + t.Error(diff) + } + if tc.a != nil { + if ok == false { + continue + } + reply := dns.Msg{} + reply.SetQuestion(tc.q, tc.t) + reply.Response = true + reply.Answer = tc.a + cache.Set(&reply) + if diff := deepEqualMsg(val, &reply); diff != nil { + t.Error(diff) + } else { + if diff := deep.Equal(val, reply); diff == nil { + t.Error("different message ID were not caught") + } + } + } + } +} diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go new file mode 100644 index 00000000..404bbfb3 --- /dev/null +++ b/dnsforward/dnsforward.go @@ -0,0 +1,594 @@ +package dnsforward + +import ( + "fmt" + "log" + "net" + "reflect" + "strings" + "sync" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/joomcode/errorx" + "github.com/miekg/dns" + gocache "github.com/patrickmn/go-cache" +) + +// Server is the main way to start a DNS server. +// +// Example: +// s := dnsforward.Server{} +// err := s.Start(nil) // will start a DNS server listening on default port 53, in a goroutine +// err := s.Reconfigure(ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) // will reconfigure running DNS server to listen on UDP port 53535 +// err := s.Stop() // will stop listening on port 53535 and cancel all goroutines +// err := s.Start(nil) // will start listening again, on port 53535, in a goroutine +// +// The zero Server is empty and ready for use. +type Server struct { + udpListen *net.UDPConn + + dnsFilter *dnsfilter.Dnsfilter + + cache cache + + ratelimitBuckets *gocache.Cache // where the ratelimiters are stored, per IP + + sync.RWMutex + ServerConfig +} + +// uncomment this block to have tracing of locks +/* +func (s *Server) Lock() { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + file, line := f.FileLine(pc[0]) + fmt.Fprintf(os.Stderr, "%s:%d %s() -> Lock() -> in progress\n", path.Base(file), line, path.Base(f.Name())) + s.RWMutex.Lock() + fmt.Fprintf(os.Stderr, "%s:%d %s() -> Lock() -> done\n", path.Base(file), line, path.Base(f.Name())) +} +func (s *Server) RLock() { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + file, line := f.FileLine(pc[0]) + fmt.Fprintf(os.Stderr, "%s:%d %s() -> RLock() -> in progress\n", path.Base(file), line, path.Base(f.Name())) + s.RWMutex.RLock() + fmt.Fprintf(os.Stderr, "%s:%d %s() -> RLock() -> done\n", path.Base(file), line, path.Base(f.Name())) +} +func (s *Server) Unlock() { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + file, line := f.FileLine(pc[0]) + fmt.Fprintf(os.Stderr, "%s:%d %s() -> Unlock() -> in progress\n", path.Base(file), line, path.Base(f.Name())) + s.RWMutex.Unlock() + fmt.Fprintf(os.Stderr, "%s:%d %s() -> Unlock() -> done\n", path.Base(file), line, path.Base(f.Name())) +} +func (s *Server) RUnlock() { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + file, line := f.FileLine(pc[0]) + fmt.Fprintf(os.Stderr, "%s:%d %s() -> RUnlock() -> in progress\n", path.Base(file), line, path.Base(f.Name())) + s.RWMutex.RUnlock() + fmt.Fprintf(os.Stderr, "%s:%d %s() -> RUnlock() -> done\n", path.Base(file), line, path.Base(f.Name())) +} +*/ + +type FilteringConfig struct { + ProtectionEnabled bool `yaml:"protection_enabled"` + FilteringEnabled bool `yaml:"filtering_enabled"` + BlockedResponseTTL uint32 `yaml:"blocked_response_ttl"` // if 0, then default is used (3600) + QueryLogEnabled bool `yaml:"querylog_enabled"` + Ratelimit int `yaml:"ratelimit"` + RatelimitWhitelist []string `yaml:"ratelimit_whitelist"` + RefuseAny bool `yaml:"refuse_any"` + BootstrapDNS string `yaml:"bootstrap_dns"` + + dnsfilter.Config `yaml:",inline"` +} + +// The zero ServerConfig is empty and ready for use. +type ServerConfig struct { + UDPListenAddr *net.UDPAddr // if nil, then default is is used (port 53 on *) + Upstreams []Upstream + Filters []dnsfilter.Filter + + FilteringConfig +} + +// if any of ServerConfig values are zero, then default values from below are used +var defaultValues = ServerConfig{ + UDPListenAddr: &net.UDPAddr{Port: 53}, + FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, + Upstreams: []Upstream{ + //// dns over HTTPS + // &dnsOverHTTPS{boot: toBoot("https://1.1.1.1/dns-query", "")}, + // &dnsOverHTTPS{boot: toBoot("https://dns.google.com/experimental", "")}, + // &dnsOverHTTPS{boot: toBoot("https://doh.cleanbrowsing.org/doh/security-filter/", "")}, + // &dnsOverHTTPS{boot: toBoot("https://dns10.quad9.net/dns-query", "")}, + // &dnsOverHTTPS{boot: toBoot("https://doh.powerdns.org", "")}, + // &dnsOverHTTPS{boot: toBoot("https://doh.securedns.eu/dns-query", "")}, + + //// dns over TLS + // &dnsOverTLS{boot: toBoot("tls://8.8.8.8:853", "")}, + // &dnsOverTLS{boot: toBoot("tls://8.8.4.4:853", "")}, + // &dnsOverTLS{boot: toBoot("tls://1.1.1.1:853", "")}, + // &dnsOverTLS{boot: toBoot("tls://1.0.0.1:853", "")}, + + //// plainDNS + &plainDNS{boot: toBoot("8.8.8.8:53", "")}, + &plainDNS{boot: toBoot("8.8.4.4:53", "")}, + &plainDNS{boot: toBoot("1.1.1.1:53", "")}, + &plainDNS{boot: toBoot("1.0.0.1:53", "")}, + }, +} + +// +// packet loop +// +func (s *Server) packetLoop() { + log.Printf("Entering packet handle loop") + b := make([]byte, dns.MaxMsgSize) + for { + s.RLock() + conn := s.udpListen + s.RUnlock() + if conn == nil { + log.Printf("udp socket has disappeared, exiting loop") + break + } + n, addr, err := conn.ReadFrom(b) + // documentation says to handle the packet even if err occurs, so do that first + if n > 0 { + // make a copy of all bytes because ReadFrom() will overwrite contents of b on next call + // we need the contents to survive the call because we're handling them in goroutine + p := make([]byte, n) + copy(p, b) + go s.handlePacket(p, addr, conn) // ignore errors + } + if err != nil { + if isConnClosed(err) { + log.Printf("ReadFrom() returned because we're reading from a closed connection, exiting loop") + // don't try to nullify s.udpListen here, because s.udpListen could be already re-bound to listen + break + } + log.Printf("Got error when reading from udp listen: %s", err) + } + } +} + +// +// Control functions +// + +func (s *Server) Start(config *ServerConfig) error { + s.Lock() + defer s.Unlock() + if config != nil { + s.ServerConfig = *config + } + // TODO: handle being called Start() second time after Stop() + if s.udpListen == nil { + log.Printf("Creating UDP socket") + var err error + addr := s.UDPListenAddr + if addr == nil { + addr = defaultValues.UDPListenAddr + } + s.udpListen, err = net.ListenUDP("udp", addr) + if err != nil { + s.udpListen = nil + return errorx.Decorate(err, "Couldn't listen to UDP socket") + } + log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr) + } + + if s.dnsFilter == nil { + log.Printf("Creating dnsfilter") + s.dnsFilter = dnsfilter.New(&s.Config) + // add rules only if they are enabled + if s.FilteringEnabled { + s.dnsFilter.AddRules(s.Filters) + } + } + + log.Printf("Loading stats from querylog") + err := fillStatsFromQueryLog() + if err != nil { + log.Printf("Failed to load stats from querylog: %s", err) + return err + } + + once.Do(func() { + go periodicQueryLogRotate() + go periodicHourlyTopRotate() + go statsRotator() + }) + + go s.packetLoop() + + return nil +} + +func (s *Server) Stop() error { + s.Lock() + defer s.Unlock() + if s.udpListen != nil { + err := s.udpListen.Close() + s.udpListen = nil + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") + } + } + + // flush remainder to file + logBufferLock.Lock() + flushBuffer := logBuffer + logBuffer = nil + logBufferLock.Unlock() + err := flushToFile(flushBuffer) + if err != nil { + log.Printf("Saving querylog to file failed: %s", err) + return err + } + + return nil +} + +func (s *Server) IsRunning() bool { + s.RLock() + isRunning := true + if s.udpListen == nil { + isRunning = false + } + s.RUnlock() + return isRunning +} + +// +// Server reconfigure +// + +func (s *Server) reconfigureListenAddr(new ServerConfig) error { + oldAddr := s.UDPListenAddr + if oldAddr == nil { + oldAddr = defaultValues.UDPListenAddr + } + newAddr := new.UDPListenAddr + if newAddr == nil { + newAddr = defaultValues.UDPListenAddr + } + if newAddr.Port == 0 { + return errorx.IllegalArgument.New("new port cannot be 0") + } + if reflect.DeepEqual(oldAddr, newAddr) { + // do nothing, the addresses are exactly the same + log.Printf("Not going to rebind because addresses are same: %v -> %v", oldAddr, newAddr) + return nil + } + + // rebind, using a strategy: + // * if ports are different, bind new first, then close old + // * if ports are same, close old first, then bind new + var newListen *net.UDPConn + var err error + if oldAddr.Port != newAddr.Port { + log.Printf("Rebinding -- ports are different so bind first then close") + newListen, err = net.ListenUDP("udp", newAddr) + if err != nil { + return errorx.Decorate(err, "Couldn't bind to %v", newAddr) + } + s.Lock() + if s.udpListen != nil { + err = s.udpListen.Close() + s.udpListen = nil + } + s.Unlock() + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") + } + } else { + log.Printf("Rebinding -- ports are same so close first then bind") + s.Lock() + if s.udpListen != nil { + err = s.udpListen.Close() + s.udpListen = nil + } + s.Unlock() + if err != nil { + return errorx.Decorate(err, "Couldn't close UDP listening socket") + } + newListen, err = net.ListenUDP("udp", newAddr) + if err != nil { + return errorx.Decorate(err, "Couldn't bind to %v", newAddr) + } + } + s.Lock() + s.udpListen = newListen + s.UDPListenAddr = new.UDPListenAddr + s.Unlock() + log.Println(s.udpListen.LocalAddr(), s.UDPListenAddr) + + go s.packetLoop() // the old one has quit, use new one + + return nil +} + +func (s *Server) reconfigureBlockedResponseTTL(new ServerConfig) { + newVal := new.BlockedResponseTTL + if newVal == 0 { + newVal = defaultValues.BlockedResponseTTL + } + oldVal := s.BlockedResponseTTL + if oldVal == 0 { + oldVal = defaultValues.BlockedResponseTTL + } + if newVal != oldVal { + s.BlockedResponseTTL = new.BlockedResponseTTL + } +} + +func (s *Server) reconfigureUpstreams(new ServerConfig) { + newVal := new.Upstreams + if len(newVal) == 0 { + newVal = defaultValues.Upstreams + } + oldVal := s.Upstreams + if len(oldVal) == 0 { + oldVal = defaultValues.Upstreams + } + if reflect.DeepEqual(newVal, oldVal) { + // they're exactly the same, do nothing + return + } + s.Upstreams = new.Upstreams +} + +func (s *Server) reconfigureFiltering(new ServerConfig) { + newFilters := new.Filters + if len(newFilters) == 0 { + newFilters = defaultValues.Filters + } + oldFilters := s.Filters + if len(oldFilters) == 0 { + oldFilters = defaultValues.Filters + } + + needUpdate := false + if !reflect.DeepEqual(newFilters, oldFilters) { + needUpdate = true + } + + if !reflect.DeepEqual(new.FilteringConfig, s.FilteringConfig) { + needUpdate = true + } + + if !needUpdate { + // nothing to do, everything is same + return + } + + // TODO: instead of creating new dnsfilter, change existing one's settings and filters + dnsFilter := dnsfilter.New(&new.Config) // sets safebrowsing, safesearch and parental + + // add rules only if they are enabled + if new.FilteringEnabled { + dnsFilter.AddRules(newFilters) + } + + s.Lock() + oldDNSFilter := s.dnsFilter + s.dnsFilter = dnsFilter + s.FilteringConfig = new.FilteringConfig + s.Unlock() + + oldDNSFilter.Destroy() +} + +func (s *Server) Reconfigure(new ServerConfig) error { + s.reconfigureBlockedResponseTTL(new) + s.reconfigureUpstreams(new) + s.reconfigureFiltering(new) + + err := s.reconfigureListenAddr(new) + if err != nil { + return errorx.Decorate(err, "Couldn't reconfigure to new listening address %+v", new.UDPListenAddr) + } + return nil +} + +// +// packet handling functions +// + +// handlePacketInternal processes the incoming packet bytes and returns with an optional response packet. +// +// If an empty dns.Msg is returned, do not try to send anything back to client, otherwise send contents of dns.Msg. +// +// If an error is returned, log it, don't try to generate data based on that error. +func (s *Server) handlePacketInternal(msg *dns.Msg, addr net.Addr, conn *net.UDPConn) (*dns.Msg, *dnsfilter.Result, Upstream, error) { + // log.Printf("Got packet %d bytes from %s: %v", len(p), addr, p) + // + // DNS packet byte format is valid + // + // any errors below here require a response to client + // log.Printf("Unpacked: %v", msg.String()) + if len(msg.Question) != 1 { + log.Printf("Got invalid number of questions: %v", len(msg.Question)) + return s.genServerFailure(msg), nil, nil, nil + } + + if msg.Question[0].Qtype == dns.TypeANY && s.RefuseAny { + return s.genNotImpl(msg), nil, nil, nil + } + + // use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise + host := strings.TrimSuffix(msg.Question[0].Name, ".") + res, err := s.dnsFilter.CheckHost(host) + if err != nil { + log.Printf("dnsfilter failed to check host '%s': %s", host, err) + return s.genServerFailure(msg), &res, nil, err + } else if res.IsFiltered { + log.Printf("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) + return s.genNXDomain(msg), &res, nil, nil + } + + { + val, ok := s.cache.Get(msg) + if ok && val != nil { + return val, &res, nil, nil + } + } + + // TODO: replace with single-socket implementation + upstream := s.chooseUpstream() + reply, err := upstream.Exchange(msg) + if err != nil { + log.Printf("talking to upstream failed for host '%s': %s", host, err) + return s.genServerFailure(msg), &res, upstream, err + } + if reply == nil { + log.Printf("SHOULD NOT HAPPEN upstream returned empty message for host '%s'. Request is %v", host, msg.String()) + return s.genServerFailure(msg), &res, upstream, nil + } + + s.cache.Set(reply) + + return reply, &res, upstream, nil +} + +func (s *Server) handlePacket(p []byte, addr net.Addr, conn *net.UDPConn) { + start := time.Now() + ip, _, err := net.SplitHostPort(addr.String()) + if err != nil { + log.Printf("Failed to split %v into host/port: %s", addr, err) + // not a fatal error, move on + } + + // ratelimit based on IP only, protects CPU cycles and outbound connections + if s.isRatelimited(ip) { + // log.Printf("Ratelimiting %s based on IP only", ip) + return // do nothing, don't reply, we got ratelimited + } + + msg := &dns.Msg{} + err = msg.Unpack(p) + if err != nil { + log.Printf("got invalid DNS packet: %s", err) + return // do nothing + } + + reply, result, upstream, err := s.handlePacketInternal(msg, addr, conn) + + if reply != nil { + // ratelimit based on reply size now + replysize := reply.Len() + if s.isRatelimitedForReply(ip, replysize) { + log.Printf("Ratelimiting %s based on IP and size %d", ip, replysize) + return // do nothing, don't reply, we got ratelimited + } + + // we're good to respond + rerr := s.respond(reply, addr, conn) + if rerr != nil { + log.Printf("Couldn't respond to UDP packet: %s", err) + } + } + + // query logging and stats counters + if s.QueryLogEnabled { + elapsed := time.Since(start) + upstreamAddr := "" + if upstream != nil { + upstreamAddr = upstream.Address() + } + logRequest(msg, reply, result, elapsed, ip, upstreamAddr) + } +} + +// +// packet sending functions +// + +func (s *Server) respond(resp *dns.Msg, addr net.Addr, conn *net.UDPConn) error { + // log.Printf("Replying to %s with %s", addr, resp) + resp.Compress = true + bytes, err := resp.Pack() + if err != nil { + return errorx.Decorate(err, "Couldn't convert message into wire format") + } + n, err := conn.WriteTo(bytes, addr) + if n == 0 && isConnClosed(err) { + return err + } + if n != len(bytes) { + return fmt.Errorf("WriteTo() returned with %d != %d", n, len(bytes)) + } + if err != nil { + return errorx.Decorate(err, "WriteTo() returned error") + } + return nil +} + +func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeServerFailure) + resp.RecursionAvailable = true + return &resp +} + +func (s *Server) genNotImpl(request *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeNotImplemented) + resp.RecursionAvailable = true + resp.SetEdns0(1452, false) // NOTIMPL without EDNS is treated as 'we don't support EDNS', so explicitly set it + return &resp +} + +func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetRcode(request, dns.RcodeNameError) + resp.RecursionAvailable = true + resp.Ns = s.genSOA(request) + return &resp +} + +func (s *Server) genSOA(request *dns.Msg) []dns.RR { + zone := "" + if len(request.Question) > 0 { + zone = request.Question[0].Name + } + + soa := dns.SOA{ + // values copied from verisign's nonexistent .com domain + // their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers + Refresh: 1800, + Retry: 900, + Expire: 604800, + Minttl: 86400, + // copied from AdGuard DNS + Ns: "fake-for-negative-caching.adguard.com.", + Serial: 100500, + // rest is request-specific + Hdr: dns.RR_Header{ + Name: zone, + Rrtype: dns.TypeSOA, + Ttl: s.BlockedResponseTTL, + Class: dns.ClassINET, + }, + Mbox: "hostmaster.", // zone will be appended later if it's not empty or "." + } + if soa.Hdr.Ttl == 0 { + soa.Hdr.Ttl = defaultValues.BlockedResponseTTL + } + if len(zone) > 0 && zone[0] != '.' { + soa.Mbox += zone + } + return []dns.RR{&soa} +} + +var once sync.Once diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go new file mode 100644 index 00000000..26dabb4b --- /dev/null +++ b/dnsforward/dnsforward_test.go @@ -0,0 +1,49 @@ +package dnsforward + +import ( + "net" + "testing" + + "github.com/miekg/dns" +) + +func TestServer(t *testing.T) { + s := Server{} + s.UDPListenAddr = &net.UDPAddr{Port: 0} + err := s.Start(nil) + if err != nil { + t.Fatalf("Failed to start server: %s", err) + } + if s.udpListen == nil { + t.Fatal("Started server has nil udpListen") + } + + // server is running, send a message + addr := s.udpListen.LocalAddr() + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + reply, err := dns.Exchange(&req, addr.String()) + if err != nil { + t.Fatalf("Couldn't talk to server %s: %s", addr, err) + } + if len(reply.Answer) != 1 { + t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) + } + if a, ok := reply.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(a.A) { + t.Fatalf("DNS server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A) + } + } else { + t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) + } + + err = s.Stop() + if err != nil { + t.Fatalf("DNS server %s failed to stop: %s", addr, err) + } +} diff --git a/dnsforward/helpers.go b/dnsforward/helpers.go new file mode 100644 index 00000000..52b65c87 --- /dev/null +++ b/dnsforward/helpers.go @@ -0,0 +1,50 @@ +package dnsforward + +import ( + "fmt" + "net" + "os" + "path" + "runtime" + "strings" +) + +func isConnClosed(err error) bool { + if err == nil { + return false + } + nerr, ok := err.(*net.OpError) + if !ok { + return false + } + + if strings.Contains(nerr.Err.Error(), "use of closed network connection") { + return true + } + + return false +} + +// --------------------- +// debug logging helpers +// --------------------- +func _Func() string { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + return path.Base(f.Name()) +} + +func trace(format string, args ...interface{}) { + pc := make([]uintptr, 10) // at least 1 entry needed + runtime.Callers(2, pc) + f := runtime.FuncForPC(pc[0]) + var buf strings.Builder + buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) + text := fmt.Sprintf(format, args...) + buf.WriteString(text) + if len(text) == 0 || text[len(text)-1] != '\n' { + buf.WriteRune('\n') + } + fmt.Fprint(os.Stderr, buf.String()) +} diff --git a/coredns_plugin/querylog.go b/dnsforward/querylog.go similarity index 88% rename from coredns_plugin/querylog.go rename to dnsforward/querylog.go index 92ba2d1d..d449990d 100644 --- a/coredns_plugin/querylog.go +++ b/dnsforward/querylog.go @@ -1,20 +1,16 @@ -package dnsfilter +package dnsforward import ( "encoding/json" "fmt" "log" "net/http" - "os" - "path" - "runtime" "strconv" "strings" "sync" "time" "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/coredns/coredns/plugin/pkg/response" "github.com/miekg/dns" ) @@ -42,9 +38,10 @@ type logEntry struct { Time time.Time Elapsed time.Duration IP string + Upstream string `json:",omitempty"` // if empty, means it was cached } -func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, elapsed time.Duration, ip string) { +func logRequest(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, ip string, upstream string) { var q []byte var a []byte var err error @@ -64,14 +61,19 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela } } + if result == nil { + result = &dnsfilter.Result{} + } + now := time.Now() entry := logEntry{ Question: q, Answer: a, - Result: result, + Result: *result, Time: now, Elapsed: elapsed, IP: ip, + Upstream: upstream, } var flushBuffer []*logEntry @@ -97,6 +99,8 @@ func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, ela // don't do failure, just log } + incrementCounters(&entry) + // if buffer needs to be flushed to disk, do it now if len(flushBuffer) > 0 { // write to file @@ -153,8 +157,7 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) { } if a != nil { - status, _ := response.Typify(a, time.Now().UTC()) - jsonEntry["status"] = status.String() + jsonEntry["status"] = dns.RcodeToString[a.Rcode] } if len(entry.Result.Rule) > 0 { jsonEntry["rule"] = entry.Result.Rule @@ -223,17 +226,3 @@ func HandleQueryLog(w http.ResponseWriter, r *http.Request) { http.Error(w, errorText, http.StatusInternalServerError) } } - -func trace(format string, args ...interface{}) { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - var buf strings.Builder - buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) - text := fmt.Sprintf(format, args...) - buf.WriteString(text) - if len(text) == 0 || text[len(text)-1] != '\n' { - buf.WriteRune('\n') - } - fmt.Fprint(os.Stderr, buf.String()) -} diff --git a/coredns_plugin/querylog_file.go b/dnsforward/querylog_file.go similarity index 84% rename from coredns_plugin/querylog_file.go rename to dnsforward/querylog_file.go index a36812c2..9ea8ef95 100644 --- a/coredns_plugin/querylog_file.go +++ b/dnsforward/querylog_file.go @@ -1,4 +1,4 @@ -package dnsfilter +package dnsforward import ( "bytes" @@ -251,41 +251,3 @@ func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, ti } return nil } - -func appendFromLogFile(values []*logEntry, maxLen int, timeWindow time.Duration) []*logEntry { - a := []*logEntry{} - - onEntry := func(entry *logEntry) error { - a = append(a, entry) - if len(a) > maxLen { - toskip := len(a) - maxLen - a = a[toskip:] - } - return nil - } - - needMore := func() bool { - return true - } - - err := genericLoader(onEntry, needMore, timeWindow) - if err != nil { - log.Printf("Failed to load entries from querylog: %s", err) - return values - } - - // now that we've read all eligible entries, reverse the slice to make it go from newest->oldest - for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 { - a[left], a[right] = a[right], a[left] - } - - // append it to values - values = append(values, a...) - - // then cut off of it is bigger than maxLen - if len(values) > maxLen { - values = values[:maxLen] - } - - return values -} diff --git a/coredns_plugin/querylog_top.go b/dnsforward/querylog_top.go similarity index 91% rename from coredns_plugin/querylog_top.go rename to dnsforward/querylog_top.go index d4cc6e0d..b78dea79 100644 --- a/coredns_plugin/querylog_top.go +++ b/dnsforward/querylog_top.go @@ -1,4 +1,4 @@ -package dnsfilter +package dnsforward import ( "bytes" @@ -14,7 +14,6 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/bluele/gcache" "github.com/miekg/dns" ) @@ -231,27 +230,7 @@ func fillStatsFromQueryLog() error { } queryLogLock.Unlock() - requests.IncWithTime(entry.Time) - if entry.Result.IsFiltered { - filtered.IncWithTime(entry.Time) - } - switch entry.Result.Reason { - case dnsfilter.NotFilteredWhiteList: - whitelisted.IncWithTime(entry.Time) - case dnsfilter.NotFilteredError: - errorsTotal.IncWithTime(entry.Time) - case dnsfilter.FilteredBlackList: - filteredLists.IncWithTime(entry.Time) - case dnsfilter.FilteredSafeBrowsing: - filteredSafebrowsing.IncWithTime(entry.Time) - case dnsfilter.FilteredParental: - filteredParental.IncWithTime(entry.Time) - case dnsfilter.FilteredInvalid: - // do nothing - case dnsfilter.FilteredSafeSearch: - safesearch.IncWithTime(entry.Time) - } - elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) + incrementCounters(entry) return nil } diff --git a/dnsforward/ratelimit.go b/dnsforward/ratelimit.go new file mode 100644 index 00000000..9ea8d216 --- /dev/null +++ b/dnsforward/ratelimit.go @@ -0,0 +1,80 @@ +package dnsforward + +import ( + "log" + "sort" + "time" + + "github.com/beefsack/go-rate" + gocache "github.com/patrickmn/go-cache" +) + +func (s *Server) limiterForIP(ip string) interface{} { + if s.ratelimitBuckets == nil { + s.ratelimitBuckets = gocache.New(time.Hour, time.Hour) + } + + // check if ratelimiter for that IP already exists, if not, create + value, found := s.ratelimitBuckets.Get(ip) + if !found { + value = rate.New(s.Ratelimit, time.Second) + s.ratelimitBuckets.Set(ip, value, time.Hour) + } + + return value +} + +func (s *Server) isRatelimited(ip string) bool { + if s.Ratelimit == 0 { // 0 -- disabled + return false + } + if len(s.RatelimitWhitelist) > 0 { + i := sort.SearchStrings(s.RatelimitWhitelist, ip) + + if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip { + // found, don't ratelimit + return false + } + } + + value := s.limiterForIP(ip) + rl, ok := value.(*rate.RateLimiter) + if !ok { + log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache") + return false + } + + allow, _ := rl.Try() + return !allow +} + +func (s *Server) isRatelimitedForReply(ip string, size int) bool { + if s.Ratelimit == 0 { // 0 -- disabled + return false + } + if len(s.RatelimitWhitelist) > 0 { + i := sort.SearchStrings(s.RatelimitWhitelist, ip) + + if i < len(s.RatelimitWhitelist) && s.RatelimitWhitelist[i] == ip { + // found, don't ratelimit + return false + } + } + + value := s.limiterForIP(ip) + rl, ok := value.(*rate.RateLimiter) + if !ok { + log.Println("SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache") + return false + } + + // For large UDP responses we try more times, effectively limiting per bandwidth + // The exact number of times depends on the response size + for i := 0; i < size/1000; i++ { + allow, _ := rl.Try() + if !allow { // not allowed -> ratelimited + return true + } + } + return false +} diff --git a/dnsforward/ratelimit_test.go b/dnsforward/ratelimit_test.go new file mode 100644 index 00000000..ed6f5ce9 --- /dev/null +++ b/dnsforward/ratelimit_test.go @@ -0,0 +1,42 @@ +package dnsforward + +import ( + "testing" +) + +func TestRatelimiting(t *testing.T) { + // rate limit is 1 per sec + p := Server{} + p.Ratelimit = 1 + + limited := p.isRatelimited("127.0.0.1") + + if limited { + t.Fatal("First request must have been allowed") + } + + limited = p.isRatelimited("127.0.0.1") + + if !limited { + t.Fatal("Second request must have been ratelimited") + } +} + +func TestWhitelist(t *testing.T) { + // rate limit is 1 per sec with whitelist + p := Server{} + p.Ratelimit = 1 + p.RatelimitWhitelist = []string{"127.0.0.1", "127.0.0.2", "127.0.0.125"} + + limited := p.isRatelimited("127.0.0.1") + + if limited { + t.Fatal("First request must have been allowed") + } + + limited = p.isRatelimited("127.0.0.1") + + if limited { + t.Fatal("Second request must have been allowed due to whitelist") + } +} diff --git a/dnsforward/standalone/.gitignore b/dnsforward/standalone/.gitignore new file mode 100644 index 00000000..5f81988c --- /dev/null +++ b/dnsforward/standalone/.gitignore @@ -0,0 +1 @@ +/standalone \ No newline at end of file diff --git a/dnsforward/standalone/standalone.go b/dnsforward/standalone/standalone.go new file mode 100644 index 00000000..ae3e6d13 --- /dev/null +++ b/dnsforward/standalone/standalone.go @@ -0,0 +1,51 @@ +package main + +import ( + "log" + "net" + "net/http" + _ "net/http/pprof" + "os" + "os/signal" + "runtime" + "syscall" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsforward" +) + +// +// main function +// +func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + go func() { + for range time.Tick(time.Second) { + log.Printf("goroutines = %d", runtime.NumGoroutine()) + } + }() + s := dnsforward.Server{} + err := s.Start(nil) + if err != nil { + panic(err) + } + time.Sleep(time.Second) + err = s.Stop() + if err != nil { + panic(err) + } + err = s.Start(&dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53535}}) + if err != nil { + panic(err) + } + err = s.Reconfigure(dnsforward.ServerConfig{UDPListenAddr: &net.UDPAddr{Port: 53, IP: net.ParseIP("0.0.0.0")}}) + if err != nil { + panic(err) + } + log.Printf("Now serving DNS") + signal_channel := make(chan os.Signal) + signal.Notify(signal_channel, syscall.SIGINT, syscall.SIGTERM) + <-signal_channel +} diff --git a/coredns_plugin/coredns_stats.go b/dnsforward/stats.go similarity index 81% rename from coredns_plugin/coredns_stats.go rename to dnsforward/stats.go index b138911e..9cfe5f58 100644 --- a/coredns_plugin/coredns_stats.go +++ b/dnsforward/stats.go @@ -1,4 +1,4 @@ -package dnsfilter +package dnsforward import ( "encoding/json" @@ -8,21 +8,20 @@ import ( "sync" "time" - "github.com/coredns/coredns/plugin" - "github.com/prometheus/client_golang/prometheus" + "github.com/AdguardTeam/AdGuardHome/dnsfilter" ) var ( - requests = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.") - filtered = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.") - filteredLists = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.") - filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.") - filteredParental = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.") - filteredInvalid = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.") - whitelisted = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.") - safesearch = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.") - errorsTotal = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.") - elapsedTime = newDNSHistogram("request_duration", "Histogram of the time (in seconds) each request took.") + requests = newDNSCounter("requests_total") + filtered = newDNSCounter("filtered_total") + filteredLists = newDNSCounter("filtered_lists_total") + filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total") + filteredParental = newDNSCounter("filtered_parental_total") + filteredInvalid = newDNSCounter("filtered_invalid_total") + whitelisted = newDNSCounter("whitelisted_total") + safesearch = newDNSCounter("safesearch_total") + errorsTotal = newDNSCounter("errors_total") + elapsedTime = newDNSHistogram("request_duration") ) // entries for single time period (for example all per-second entries) @@ -143,21 +142,13 @@ func statsRotator() { type counter struct { name string // used as key in periodic stats value int64 - prom prometheus.Counter } -func newDNSCounter(name string, help string) *counter { +func newDNSCounter(name string) *counter { // trace("called") - c := &counter{} - c.prom = prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "dnsfilter", - Name: name, - Help: help, - }) - c.name = name - - return c + return &counter{ + name: name, + } } func (c *counter) IncWithTime(when time.Time) { @@ -166,40 +157,22 @@ func (c *counter) IncWithTime(when time.Time) { statistics.PerHour.Inc(c.name, when) statistics.PerDay.Inc(c.name, when) c.value++ - c.prom.Inc() } func (c *counter) Inc() { c.IncWithTime(time.Now()) } -func (c *counter) Describe(ch chan<- *prometheus.Desc) { - c.prom.Describe(ch) -} - -func (c *counter) Collect(ch chan<- prometheus.Metric) { - c.prom.Collect(ch) -} - type histogram struct { name string // used as key in periodic stats count int64 total float64 - prom prometheus.Histogram } -func newDNSHistogram(name string, help string) *histogram { - // trace("called") - h := &histogram{} - h.prom = prometheus.NewHistogram(prometheus.HistogramOpts{ - Namespace: plugin.Namespace, - Subsystem: "dnsfilter", - Name: name, - Help: help, - }) - h.name = name - - return h +func newDNSHistogram(name string) *histogram { + return &histogram{ + name: name, + } } func (h *histogram) ObserveWithTime(value float64, when time.Time) { @@ -209,24 +182,40 @@ func (h *histogram) ObserveWithTime(value float64, when time.Time) { statistics.PerDay.Observe(h.name, when, value) h.count++ h.total += value - h.prom.Observe(value) } func (h *histogram) Observe(value float64) { h.ObserveWithTime(value, time.Now()) } -func (h *histogram) Describe(ch chan<- *prometheus.Desc) { - h.prom.Describe(ch) -} - -func (h *histogram) Collect(ch chan<- prometheus.Metric) { - h.prom.Collect(ch) -} - // ----- // stats // ----- +func incrementCounters(entry *logEntry) { + requests.IncWithTime(entry.Time) + if entry.Result.IsFiltered { + filtered.IncWithTime(entry.Time) + } + + switch entry.Result.Reason { + case dnsfilter.NotFilteredWhiteList: + whitelisted.IncWithTime(entry.Time) + case dnsfilter.NotFilteredError: + errorsTotal.IncWithTime(entry.Time) + case dnsfilter.FilteredBlackList: + filteredLists.IncWithTime(entry.Time) + case dnsfilter.FilteredSafeBrowsing: + filteredSafebrowsing.IncWithTime(entry.Time) + case dnsfilter.FilteredParental: + filteredParental.IncWithTime(entry.Time) + case dnsfilter.FilteredInvalid: + // do nothing + case dnsfilter.FilteredSafeSearch: + safesearch.IncWithTime(entry.Time) + } + elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) +} + func HandleStats(w http.ResponseWriter, r *http.Request) { const numHours = 24 histrical := generateMapFromStats(&statistics.PerHour, 0, numHours) diff --git a/dnsforward/upstream.go b/dnsforward/upstream.go new file mode 100644 index 00000000..89016951 --- /dev/null +++ b/dnsforward/upstream.go @@ -0,0 +1,239 @@ +package dnsforward + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "math/rand" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/joomcode/errorx" + "github.com/miekg/dns" +) + +const defaultTimeout = time.Second * 10 + +type Upstream interface { + Exchange(m *dns.Msg) (*dns.Msg, error) + Address() string +} + +// +// plain DNS +// +type plainDNS struct { + boot bootstrapper + preferTCP bool +} + +var defaultUDPClient = dns.Client{ + Timeout: defaultTimeout, + UDPSize: dns.MaxMsgSize, +} + +var defaultTCPClient = dns.Client{ + Net: "tcp", + UDPSize: dns.MaxMsgSize, + Timeout: defaultTimeout, +} + +// Address returns the original address that we've put in initially, not resolved one +func (p *plainDNS) Address() string { return p.boot.address } + +func (p *plainDNS) Exchange(m *dns.Msg) (*dns.Msg, error) { + addr, _, err := p.boot.get() + if err != nil { + return nil, err + } + if p.preferTCP { + reply, _, err := defaultTCPClient.Exchange(m, addr) + return reply, err + } + + reply, _, err := defaultUDPClient.Exchange(m, addr) + if err != nil && reply != nil && reply.Truncated { + log.Printf("Truncated message was received, retrying over TCP, question: %s", m.Question[0].String()) + reply, _, err = defaultTCPClient.Exchange(m, addr) + } + + return reply, err +} + +// +// DNS-over-TLS +// +type dnsOverTLS struct { + boot bootstrapper + pool *TLSPool + + sync.RWMutex // protects pool +} + +func (p *dnsOverTLS) Address() string { return p.boot.address } + +func (p *dnsOverTLS) Exchange(m *dns.Msg) (*dns.Msg, error) { + var pool *TLSPool + p.RLock() + pool = p.pool + p.RUnlock() + if pool == nil { + p.Lock() + // lazy initialize it + p.pool = &TLSPool{boot: &p.boot} + p.Unlock() + } + + p.RLock() + poolConn, err := p.pool.Get() + p.RUnlock() + if err != nil { + return nil, errorx.Decorate(err, "Failed to get a connection from TLSPool to %s", p.Address()) + } + c := dns.Conn{Conn: poolConn} + err = c.WriteMsg(m) + if err != nil { + poolConn.Close() + return nil, errorx.Decorate(err, "Failed to send a request to %s", p.Address()) + } + + reply, err := c.ReadMsg() + if err != nil { + poolConn.Close() + return nil, errorx.Decorate(err, "Failed to read a request from %s", p.Address()) + } + p.RLock() + p.pool.Put(poolConn) + p.RUnlock() + return reply, nil +} + +// +// DNS-over-https +// +type dnsOverHTTPS struct { + boot bootstrapper +} + +func (p *dnsOverHTTPS) Address() string { return p.boot.address } + +func (p *dnsOverHTTPS) Exchange(m *dns.Msg) (*dns.Msg, error) { + addr, tlsConfig, err := p.boot.get() + if err != nil { + return nil, errorx.Decorate(err, "Couldn't bootstrap %s", p.boot.address) + } + + buf, err := m.Pack() + if err != nil { + return nil, errorx.Decorate(err, "Couldn't pack request msg") + } + bb := bytes.NewBuffer(buf) + + // set up a custom request with custom URL + url, err := url.Parse(p.boot.address) + if err != nil { + return nil, errorx.Decorate(err, "Couldn't parse URL %s", p.boot.address) + } + req := http.Request{ + Method: "POST", + URL: url, + Body: ioutil.NopCloser(bb), + Header: make(http.Header), + Host: url.Host, + } + url.Host = addr + req.Header.Set("Content-Type", "application/dns-message") + client := http.Client{ + Transport: &http.Transport{TLSClientConfig: tlsConfig}, + } + resp, err := client.Do(&req) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + if err != nil { + return nil, errorx.Decorate(err, "Couldn't do a POST request to '%s'", addr) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, errorx.Decorate(err, "Couldn't read body contents for '%s'", addr) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Got an unexpected HTTP status code %d from '%s'", resp.StatusCode, addr) + } + if len(body) == 0 { + return nil, fmt.Errorf("Got an unexpected empty body from '%s'", addr) + } + response := dns.Msg{} + err = response.Unpack(body) + if err != nil { + return nil, errorx.Decorate(err, "Couldn't unpack DNS response from '%s': body is %s", addr, string(body)) + } + return &response, nil +} + +func (s *Server) chooseUpstream() Upstream { + upstreams := s.Upstreams + if upstreams == nil { + upstreams = defaultValues.Upstreams + } + if len(upstreams) == 0 { + panic("SHOULD NOT HAPPEN: no default upstreams specified") + } + if len(upstreams) == 1 { + return upstreams[0] + } + n := rand.Intn(len(upstreams)) + upstream := upstreams[n] + return upstream +} + +func AddressToUpstream(address string, bootstrap string) (Upstream, error) { + if strings.Contains(address, "://") { + url, err := url.Parse(address) + if err != nil { + return nil, errorx.Decorate(err, "Failed to parse %s", address) + } + switch url.Scheme { + case "dns": + if url.Port() == "" { + url.Host += ":53" + } + return &plainDNS{boot: toBoot(url.Host, bootstrap)}, nil + case "tcp": + if url.Port() == "" { + url.Host += ":53" + } + return &plainDNS{boot: toBoot(url.Host, bootstrap), preferTCP: true}, nil + case "tls": + if url.Port() == "" { + url.Host += ":853" + } + return &dnsOverTLS{boot: toBoot(url.String(), bootstrap)}, nil + case "https": + if url.Port() == "" { + url.Host += ":443" + } + return &dnsOverHTTPS{boot: toBoot(url.String(), bootstrap)}, nil + default: + // assume it's plain DNS + if url.Port() == "" { + url.Host += ":53" + } + return &plainDNS{boot: toBoot(url.String(), bootstrap)}, nil + } + } + + // we don't have scheme in the url, so it's just a plain DNS host:port + _, _, err := net.SplitHostPort(address) + if err != nil { + // doesn't have port, default to 53 + address = net.JoinHostPort(address, "53") + } + return &plainDNS{boot: toBoot(address, bootstrap)}, nil +} diff --git a/dnsforward/upstream_pool.go b/dnsforward/upstream_pool.go new file mode 100644 index 00000000..ca597808 --- /dev/null +++ b/dnsforward/upstream_pool.go @@ -0,0 +1,74 @@ +package dnsforward + +import ( + "crypto/tls" + "net" + "sync" + + "github.com/joomcode/errorx" +) + +// Upstream TLS pool. +// +// Example: +// pool := TLSPool{Address: "tls://1.1.1.1:853"} +// netConn, err := pool.Get() +// if err != nil {panic(err)} +// c := dns.Conn{Conn: netConn} +// q := dns.Msg{} +// q.SetQuestion("google.com.", dns.TypeA) +// log.Println(q) +// err = c.WriteMsg(&q) +// if err != nil {panic(err)} +// r, err := c.ReadMsg() +// if err != nil {panic(err)} +// log.Println(r) +// pool.Put(c.Conn) +type TLSPool struct { + boot *bootstrapper + + // connections + conns []net.Conn + connsMutex sync.Mutex // protects conns +} + +func (n *TLSPool) Get() (net.Conn, error) { + address, tlsConfig, err := n.boot.get() + if err != nil { + return nil, err + } + + // get the connection from the slice inside the lock + var c net.Conn + n.connsMutex.Lock() + num := len(n.conns) + if num > 0 { + last := num - 1 + c = n.conns[last] + n.conns = n.conns[:last] + } + n.connsMutex.Unlock() + + // if we got connection from the slice, return it + if c != nil { + // log.Printf("Returning existing connection to %s", host) + return c, nil + } + + // we'll need a new connection, dial now + // log.Printf("Dialing to %s", address) + conn, err := tls.Dial("tcp", address, tlsConfig) + if err != nil { + return nil, errorx.Decorate(err, "Failed to connect to %s", address) + } + return conn, nil +} + +func (n *TLSPool) Put(c net.Conn) { + if c == nil { + return + } + n.connsMutex.Lock() + n.conns = append(n.conns, c) + n.connsMutex.Unlock() +} diff --git a/dnsforward/upstream_test.go b/dnsforward/upstream_test.go new file mode 100644 index 00000000..0b83670f --- /dev/null +++ b/dnsforward/upstream_test.go @@ -0,0 +1,96 @@ +package dnsforward + +import ( + "net" + "testing" + + "github.com/miekg/dns" +) + +func TestUpstreams(t *testing.T) { + upstreams := []struct { + address string + bootstrap string + }{ + { + address: "8.8.8.8:53", + bootstrap: "8.8.8.8:53", + }, + { + address: "1.1.1.1", + bootstrap: "", + }, + { + address: "tcp://1.1.1.1:53", + bootstrap: "", + }, + { + address: "176.103.130.130:5353", + bootstrap: "", + }, + { + address: "tls://1.1.1.1", + bootstrap: "", + }, + { + address: "tls://9.9.9.9:853", + bootstrap: "", + }, + { + address: "tls://security-filter-dns.cleanbrowsing.org", + bootstrap: "8.8.8.8:53", + }, + { + address: "tls://adult-filter-dns.cleanbrowsing.org:853", + bootstrap: "8.8.8.8:53", + }, + { + address: "https://cloudflare-dns.com/dns-query", + bootstrap: "8.8.8.8:53", + }, + { + address: "https://dns.google.com/experimental", + bootstrap: "8.8.8.8:53", + }, + { + address: "https://doh.cleanbrowsing.org/doh/security-filter/", + bootstrap: "", + }, + } + for _, test := range upstreams { + t.Run(test.address, func(t *testing.T) { + u, err := AddressToUpstream(test.address, test.bootstrap) + if err != nil { + t.Fatalf("Failed to generate upstream from address %s: %s", test.address, err) + } + + checkUpstream(t, u, test.address) + }) + } +} + +func checkUpstream(t *testing.T, u Upstream, addr string) { + t.Helper() + + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + reply, err := u.Exchange(&req) + if err != nil { + t.Fatalf("Couldn't talk to upstream %s: %s", addr, err) + } + if len(reply.Answer) != 1 { + t.Fatalf("DNS upstream %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) + } + if a, ok := reply.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(a.A) { + t.Fatalf("DNS upstream %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A) + } + } else { + t.Fatalf("DNS upstream %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) + } +} diff --git a/filter.go b/filter.go new file mode 100644 index 00000000..1150d292 --- /dev/null +++ b/filter.go @@ -0,0 +1,251 @@ +package main + +import ( + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" + "reflect" + "regexp" + "strconv" + "strings" + "time" + + "github.com/AdguardTeam/AdGuardHome/dnsfilter" +) + +var ( + nextFilterID = time.Now().Unix() // semi-stable way to generate an unique ID + filterTitleRegexp = regexp.MustCompile(`^! Title: +(.*)$`) +) + +// field ordering is important -- yaml fields will mirror ordering from here +type filter struct { + Enabled bool `json:"enabled"` + URL string `json:"url"` + Name string `json:"name" yaml:"name"` + RulesCount int `json:"rulesCount" yaml:"-"` + LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"` + + dnsfilter.Filter `yaml:",inline"` +} + +// Creates a helper object for working with the user rules +func userFilter() filter { + return filter{ + // User filter always has constant ID=0 + Enabled: true, + Filter: dnsfilter.Filter{ + Rules: config.UserRules, + }, + } +} + +func deduplicateFilters() { + // Deduplicate filters + i := 0 // output index, used for deletion later + urls := map[string]bool{} + for _, filter := range config.Filters { + if _, ok := urls[filter.URL]; !ok { + // we didn't see it before, keep it + urls[filter.URL] = true // remember the URL + config.Filters[i] = filter + i++ + } + } + + // all entries we want to keep are at front, delete the rest + config.Filters = config.Filters[:i] +} + +// Set the next filter ID to max(filter.ID) + 1 +func updateUniqueFilterID(filters []filter) { + for _, filter := range filters { + if nextFilterID < filter.ID { + nextFilterID = filter.ID + 1 + } + } +} + +func assignUniqueFilterID() int64 { + value := nextFilterID + nextFilterID += 1 + return value +} + +// Sets up a timer that will be checking for filters updates periodically +func periodicallyRefreshFilters() { + for range time.Tick(time.Minute) { + refreshFiltersIfNeccessary(false) + } +} + +// Checks filters updates if necessary +// If force is true, it ignores the filter.LastUpdated field value +func refreshFiltersIfNeccessary(force bool) int { + config.Lock() + + // fetch URLs + updateCount := 0 + for i := range config.Filters { + filter := &config.Filters[i] // otherwise we will be operating on a copy + + if filter.ID == 0 { // protect against users modifying the yaml and removing the ID + filter.ID = assignUniqueFilterID() + } + + updated, err := filter.update(force) + if err != nil { + log.Printf("Failed to update filter %s: %s\n", filter.URL, err) + continue + } + if updated { + // Saving it to the filters dir now + err = filter.save() + if err != nil { + log.Printf("Failed to save the updated filter %d: %s", filter.ID, err) + continue + } + + updateCount++ + } + } + config.Unlock() + + if updateCount > 0 { + reconfigureDNSServer() + } + return updateCount +} + +// A helper function that parses filter contents and returns a number of rules and a filter name (if there's any) +func parseFilterContents(contents []byte) (int, string, []string) { + lines := strings.Split(string(contents), "\n") + rulesCount := 0 + name := "" + seenTitle := false + + // Count lines in the filter + for _, line := range lines { + line = strings.TrimSpace(line) + if len(line) > 0 && line[0] == '!' { + if m := filterTitleRegexp.FindAllStringSubmatch(line, -1); len(m) > 0 && len(m[0]) >= 2 && !seenTitle { + name = m[0][1] + seenTitle = true + } + } else if len(line) != 0 { + rulesCount++ + } + } + + return rulesCount, name, lines +} + +// Checks for filters updates +// If "force" is true -- does not check the filter's LastUpdated field +// Call "save" to persist the filter contents +func (filter *filter) update(force bool) (bool, error) { + if filter.ID == 0 { // protect against users deleting the ID + filter.ID = assignUniqueFilterID() + } + if !filter.Enabled { + return false, nil + } + if !force && time.Since(filter.LastUpdated) <= updatePeriod { + return false, nil + } + + log.Printf("Downloading update for filter %d from %s", filter.ID, filter.URL) + + // use the same update period for failed filter downloads to avoid flooding with requests + filter.LastUpdated = time.Now() + + resp, err := client.Get(filter.URL) + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + if err != nil { + log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err) + return false, err + } + + if resp.StatusCode != 200 { + log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL) + return false, fmt.Errorf("got status code != 200: %d", resp.StatusCode) + } + + contentType := strings.ToLower(resp.Header.Get("content-type")) + if !strings.HasPrefix(contentType, "text/plain") { + log.Printf("Non-text response %s from %s, skipping", contentType, filter.URL) + return false, fmt.Errorf("non-text response %s", contentType) + } + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Printf("Couldn't fetch filter contents from URL %s, skipping: %s", filter.URL, err) + return false, err + } + + // Extract filter name and count number of rules + rulesCount, filterName, rules := parseFilterContents(body) + + if filterName != "" { + filter.Name = filterName + } + + // Check if the filter has been really changed + if reflect.DeepEqual(filter.Rules, rules) { + log.Printf("Filter #%d at URL %s hasn't changed, not updating it", filter.ID, filter.URL) + return false, nil + } + + log.Printf("Filter %d has been updated: %d bytes, %d rules", filter.ID, len(body), rulesCount) + filter.RulesCount = rulesCount + filter.Rules = rules + + return true, nil +} + +// saves filter contents to the file in dataDir +func (filter *filter) save() error { + filterFilePath := filter.Path() + log.Printf("Saving filter %d contents to: %s", filter.ID, filterFilePath) + body := []byte(strings.Join(filter.Rules, "\n")) + + return safeWriteFile(filterFilePath, body) +} + +// loads filter contents from the file in dataDir +func (filter *filter) load() error { + if !filter.Enabled { + // No need to load a filter that is not enabled + return nil + } + + filterFilePath := filter.Path() + log.Printf("Loading filter %d contents to: %s", filter.ID, filterFilePath) + + if _, err := os.Stat(filterFilePath); os.IsNotExist(err) { + // do nothing, file doesn't exist + return err + } + + filterFileContents, err := ioutil.ReadFile(filterFilePath) + if err != nil { + return err + } + + log.Printf("File %s, id %d, length %d", filterFilePath, filter.ID, len(filterFileContents)) + rulesCount, _, rules := parseFilterContents(filterFileContents) + + filter.RulesCount = rulesCount + filter.Rules = rules + + return nil +} + +// Path to the filter contents +func (filter *filter) Path() string { + return filepath.Join(config.ourBinaryDir, dataDir, filterDir, strconv.FormatInt(filter.ID, 10)+".txt") +} diff --git a/go.mod b/go.mod index dae96b71..166e3cce 100644 --- a/go.mod +++ b/go.mod @@ -3,34 +3,19 @@ module github.com/AdguardTeam/AdGuardHome require ( github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f // indirect github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 - github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 // indirect github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7 - github.com/coredns/coredns v1.2.6 - github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11 // indirect - github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710 // indirect - github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect github.com/go-ole/go-ole v1.2.1 // indirect github.com/go-test/deep v1.0.1 github.com/gobuffalo/packr v1.19.0 - github.com/google/uuid v1.0.0 // indirect - github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect - github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect - github.com/mholt/caddy v0.11.0 + github.com/joomcode/errorx v0.1.0 github.com/miekg/dns v1.0.15 - github.com/opentracing/opentracing-go v1.0.2 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/pkg/errors v0.8.0 - github.com/prometheus/client_golang v0.9.0-pre1 - github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 // indirect - github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949 // indirect - github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d // indirect github.com/shirou/gopsutil v2.18.10+incompatible github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect go.uber.org/goleak v0.10.0 golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd golang.org/x/net v0.0.0-20181108082009-03003ca0c849 golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8 // indirect - google.golang.org/grpc v1.16.0 // indirect gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477 gopkg.in/yaml.v2 v2.2.1 ) diff --git a/go.sum b/go.sum index 06efaa9e..af10df24 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,11 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f h1:5ZfJxyXo8KyX8DgGXC5B7ILL8y51fci/qYz2B4j8iLY= github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 h1:KXlsf+qt/X5ttPGEjR0tPH1xaWWoKBEg9Q1THAj2h3I= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7 h1:NpQ+gkFOH27AyDypSCJ/LdsIi/b4rdnEb1N5+IpFfYs= github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7/go.mod h1:8c4/i2VlovMO2gBnHGQPN5EJw+H0lx1u/5p+cgsXtCk= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/coredns/coredns v1.2.6 h1:QIAOkBqVE44Zx0ttrFqgE5YhCEn64XPIngU60JyuTGM= -github.com/coredns/coredns v1.2.6/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/wo1+AM0= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11 h1:m8nX8hsUghn853BJ5qB0lX+VvS6LTJPksWyILFZRYN4= -github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11/go.mod h1:s1PfVYYVmTMgCSPtho4LKBDecEHJWtiVDPNv78Z985U= -github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710 h1:QdyRyGZWLEvJG5Kw3VcVJvhXJ5tZ1MkRgqpJOEZSySM= -github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710/go.mod h1:eNde4IQyEiA5br02AouhEHCu3p3UzrCdFR4LuQHklMI= -github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= -github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/go-ole/go-ole v1.2.1 h1:2lOsA72HgjxAuMlKpFiCbHTvu44PIVkZ5hqm3RSdI/E= github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8= github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg= @@ -28,44 +16,21 @@ github.com/gobuffalo/packd v0.0.0-20181031195726-c82734870264 h1:roWyi0eEdiFreSq github.com/gobuffalo/packd v0.0.0-20181031195726-c82734870264/go.mod h1:Yf2toFaISlyQrr5TfO3h6DB9pl9mZRmyvBGQb/aQ/pI= github.com/gobuffalo/packr v1.19.0 h1:3UDmBDxesCOPF8iZdMDBBWKfkBoYujIMIZePnobqIUI= github.com/gobuffalo/packr v1.19.0/go.mod h1:MstrNkfCQhd5o+Ct4IJ0skWlxN8emOq8DsoT1G98VIU= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= -github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1:MJG/KsmcqMwFAkh8mTnAwhyKoB+sTAnY4CACC110tbU= -github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go.mod h1:6iZfnjpejD4L/4DwD7NryNaJyCQdzwWwH2MWhCA90Kw= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/joomcode/errorx v0.1.0 h1:QmJMiI1DE1UFje2aI1ZWO/VMT5a32qBoXUclGOt8vsc= +github.com/joomcode/errorx v0.1.0/go.mod h1:kgco15ekB6cs+4Xjzo7SPeXzx38PbJzBwbnu9qfVNHQ= github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4 h1:Mlji5gkcpzkqTROyE4ZxZ8hN7osunMb2RuGVrbvMvCc= github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= -github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mholt/caddy v0.11.0 h1:cuhEyR7So/SBBRiAaiRBe9BoccDu6uveIPuM9FMMavg= -github.com/mholt/caddy v0.11.0/go.mod h1:Wb1PlT4DAYSqOEd03MsqkdkXnTxA8v9pKjdpxbqM1kY= github.com/miekg/dns v1.0.15 h1:9+UupePBQCG6zf1q/bGmTO1vumoG13jsrbWOSX1W6Tw= github.com/miekg/dns v1.0.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg= -github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.9.0-pre1 h1:AWTOhsOI9qxeirTuA0A4By/1Es1+y9EcCGY6bBZ2fhM= -github.com/prometheus/client_golang v0.9.0-pre1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949 h1:MVbUQq1a49hMEISI29UcAUjywT3FyvDwx5up90OvVa4= -github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFdaDqxJVlbOQ1DtGmZWs/Qau0hIlk+WQ= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBhDSX0OFYyMYminYs= github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U= @@ -80,29 +45,16 @@ go.uber.org/goleak v0.10.0 h1:G3eWbSNIskeRqtsN/1uI5B+eP73y3JUuBsv9AZjehb4= go.uber.org/goleak v0.10.0/go.mod h1:VCZuO8V8mFPlL0F5J5GK1rtHV3DrFcQ1R8ryq7FK0aI= golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd h1:VtIkGDhk0ph3t+THbvXHfMZ8QHgsBO39Nh52+74pq7w= golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181102091132-c10e9556a7bc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181108082009-03003ca0c849 h1:FSqE2GGG7wzsYUsWiQ8MZrvEd1EOyU3NCF0AW3Wtltg= golang.org/x/net v0.0.0-20181108082009-03003ca0c849/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8 h1:YoY1wS6JYVRpIfFngRf2HHo9R9dAne3xbkGOQ5rJXjU= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/grpc v1.16.0 h1:dz5IJGuC2BB7qXR5AyHNwAUBhZscK2xVez7mznh72sY= -google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477 h1:5xUJw+lg4zao9W4HIDzlFbMYgSgtvNVHh00MEHvbGpQ= gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477/go.mod h1:QDV1vrFSrowdoOba0UM8VJPUZONT7dnfdLsM+GG53Z8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/upgrade.go b/upgrade.go index 4154ee03..21d7686d 100644 --- a/upgrade.go +++ b/upgrade.go @@ -10,6 +10,8 @@ import ( "gopkg.in/yaml.v2" ) +const currentSchemaVersion = 2 // used for upgrading from old configs to new config + // Performs necessary upgrade operations if needed func upgradeConfig() error { // read a config file into an interface map, so we can manipulate values without losing any @@ -57,7 +59,12 @@ func upgradeConfig() error { func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) error { switch oldVersion { case 0: - err := upgradeSchema0to1(diskConfig) + err := upgradeSchema0to2(diskConfig) + if err != nil { + return err + } + case 1: + err := upgradeSchema1to2(diskConfig) if err != nil { return err } @@ -83,14 +90,13 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err return nil } +// The first schema upgrade: +// No more "dnsfilter.txt", filters are now kept in data/filters/ func upgradeSchema0to1(diskConfig *map[string]interface{}) error { log.Printf("%s(): called", _Func()) - // The first schema upgrade: - // No more "dnsfilter.txt", filters are now kept in data/filters/ dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt") - _, err := os.Stat(dnsFilterPath) - if !os.IsNotExist(err) { + if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) { log.Printf("Deleting %s as we don't need it anymore", dnsFilterPath) err = os.Remove(dnsFilterPath) if err != nil { @@ -103,3 +109,38 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error { return nil } + +// Second schema upgrade: +// coredns is now dns in config +// delete 'Corefile', since we don't use that anymore +func upgradeSchema1to2(diskConfig *map[string]interface{}) error { + log.Printf("%s(): called", _Func()) + + coreFilePath := filepath.Join(config.ourBinaryDir, "Corefile") + if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) { + log.Printf("Deleting %s as we don't need it anymore", coreFilePath) + err = os.Remove(coreFilePath) + if err != nil { + log.Printf("Cannot remove %s due to %s", coreFilePath, err) + // not fatal, move on + } + } + + if _, ok := (*diskConfig)["dns"]; !ok { + (*diskConfig)["dns"] = (*diskConfig)["coredns"] + delete((*diskConfig), "coredns") + } + (*diskConfig)["schema_version"] = 2 + + return nil +} + +// jump two schemas at once -- this time we just do it sequentially +func upgradeSchema0to2(diskConfig *map[string]interface{}) error { + err := upgradeSchema0to1(diskConfig) + if err != nil { + return err + } + + return upgradeSchema1to2(diskConfig) +} diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go deleted file mode 100644 index 171f6362..00000000 --- a/upstream/dns_upstream.go +++ /dev/null @@ -1,105 +0,0 @@ -package upstream - -import ( - "crypto/tls" - "time" - - "github.com/miekg/dns" - "golang.org/x/net/context" -) - -// DnsUpstream is a very simple upstream implementation for plain DNS -type DnsUpstream struct { - endpoint string // IP:port - timeout time.Duration // Max read and write timeout - proto string // Protocol (tcp, tcp-tls, or udp) - transport *Transport // Persistent connections cache -} - -// NewDnsUpstream creates a new DNS upstream -func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) { - u := &DnsUpstream{ - endpoint: endpoint, - timeout: defaultTimeout, - proto: proto, - } - - var tlsConfig *tls.Config - - if proto == "tcp-tls" { - tlsConfig = new(tls.Config) - tlsConfig.ServerName = tlsServerName - } - - // Initialize the connections cache - u.transport = NewTransport(endpoint) - u.transport.tlsConfig = tlsConfig - u.transport.Start() - - return u, nil -} - -// Exchange provides an implementation for the Upstream interface -func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - resp, err := u.exchange(u.proto, query) - - // Retry over TCP if response is truncated - if err == dns.ErrTruncated && u.proto == "udp" { - resp, err = u.exchange("tcp", query) - } else if err == dns.ErrTruncated && resp != nil { - // Reassemble something to be sent to client - m := new(dns.Msg) - m.SetReply(query) - m.Truncated = true - m.Authoritative = true - m.Rcode = dns.RcodeSuccess - return m, nil - } - - if err != nil { - resp = &dns.Msg{} - resp.SetRcode(resp, dns.RcodeServerFailure) - } - - return resp, err -} - -// Clear resources -func (u *DnsUpstream) Close() error { - // Close active connections - u.transport.Stop() - return nil -} - -// Performs a synchronous query. It sends the message m via the conn -// c and waits for a reply. The conn c is not closed. -func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) { - // Establish a connection if needed (or reuse cached) - conn, err := u.transport.Dial(proto) - if err != nil { - return nil, err - } - - // Write the request with a timeout - conn.SetWriteDeadline(time.Now().Add(u.timeout)) - if err = conn.WriteMsg(query); err != nil { - conn.Close() // Not giving it back - return nil, err - } - - // Write response with a timeout - conn.SetReadDeadline(time.Now().Add(u.timeout)) - r, err = conn.ReadMsg() - if err != nil { - conn.Close() // Not giving it back - } else if err == nil && r.Id != query.Id { - err = dns.ErrId - conn.Close() // Not giving it back - } - - if err == nil { - // Return it back to the connections cache if there were no errors - u.transport.Yield(conn) - } - return r, err -} diff --git a/upstream/helpers.go b/upstream/helpers.go deleted file mode 100644 index 520a7a8b..00000000 --- a/upstream/helpers.go +++ /dev/null @@ -1,98 +0,0 @@ -package upstream - -import ( - "net" - "strings" - - "github.com/miekg/dns" - "golang.org/x/net/context" -) - -// Detects the upstream type from the specified url and creates a proper Upstream object -func NewUpstream(url string, bootstrap string) (Upstream, error) { - proto := "udp" - prefix := "" - - switch { - case strings.HasPrefix(url, "tcp://"): - proto = "tcp" - prefix = "tcp://" - case strings.HasPrefix(url, "tls://"): - proto = "tcp-tls" - prefix = "tls://" - case strings.HasPrefix(url, "https://"): - return NewHttpsUpstream(url, bootstrap) - } - - hostname := strings.TrimPrefix(url, prefix) - - host, port, err := net.SplitHostPort(hostname) - if err != nil { - // Set port depending on the protocol - switch proto { - case "udp": - port = "53" - case "tcp": - port = "53" - case "tcp-tls": - port = "853" - } - - // Set host = hostname - host = hostname - } - - // Try to resolve the host address (or check if it's an IP address) - bootstrapResolver := CreateResolver(bootstrap) - ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host) - - if err != nil || len(ips) == 0 { - return nil, err - } - - addr := ips[0].String() - endpoint := net.JoinHostPort(addr, port) - tlsServerName := "" - - if proto == "tcp-tls" && host != addr { - // Check if we need to specify TLS server name - tlsServerName = host - } - - return NewDnsUpstream(endpoint, proto, tlsServerName) -} - -func CreateResolver(bootstrap string) *net.Resolver { - bootstrapResolver := net.DefaultResolver - - if bootstrap != "" { - bootstrapResolver = &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - var d net.Dialer - return d.DialContext(ctx, network, bootstrap) - }, - } - } - - return bootstrapResolver -} - -// Performs a simple health-check of the specified upstream -func IsAlive(u Upstream) (bool, error) { - // Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere - ping := new(dns.Msg) - ping.SetQuestion("ipv4only.arpa.", dns.TypeA) - - resp, err := u.Exchange(context.Background(), ping) - - // If we got a header, we're alright, basically only care about I/O errors 'n stuff. - if err != nil && resp != nil { - // Silly check, something sane came back. - if resp.Rcode != dns.RcodeServerFailure { - err = nil - } - } - - return err == nil, err -} diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go deleted file mode 100644 index d7d7bdde..00000000 --- a/upstream/https_upstream.go +++ /dev/null @@ -1,128 +0,0 @@ -package upstream - -import ( - "bytes" - "crypto/tls" - "fmt" - "io/ioutil" - "log" - "net" - "net/http" - "net/url" - "time" - - "github.com/miekg/dns" - "github.com/pkg/errors" - "golang.org/x/net/context" - "golang.org/x/net/http2" -) - -const ( - dnsMessageContentType = "application/dns-message" - defaultKeepAlive = 30 * time.Second -) - -// HttpsUpstream is the upstream implementation for DNS-over-HTTPS -type HttpsUpstream struct { - client *http.Client - endpoint *url.URL -} - -// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url -func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) { - u, err := url.Parse(endpoint) - if err != nil { - return nil, err - } - - // Initialize bootstrap resolver - bootstrapResolver := CreateResolver(bootstrap) - dialer := &net.Dialer{ - Timeout: defaultTimeout, - KeepAlive: defaultKeepAlive, - DualStack: true, - Resolver: bootstrapResolver, - } - - // Update TLS and HTTP client configuration - tlsConfig := &tls.Config{ServerName: u.Hostname()} - transport := &http.Transport{ - TLSClientConfig: tlsConfig, - DisableCompression: true, - MaxIdleConns: 1, - DialContext: dialer.DialContext, - } - http2.ConfigureTransport(transport) - - client := &http.Client{ - Timeout: defaultTimeout, - Transport: transport, - } - - return &HttpsUpstream{client: client, endpoint: u}, nil -} - -// Exchange provides an implementation for the Upstream interface -func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - queryBuf, err := query.Pack() - if err != nil { - return nil, errors.Wrap(err, "failed to pack DNS query") - } - - // No content negotiation for now, use DNS wire format - buf, backendErr := u.exchangeWireformat(queryBuf) - if backendErr == nil { - response := &dns.Msg{} - if err := response.Unpack(buf); err != nil { - return nil, errors.Wrap(err, "failed to unpack DNS response from body") - } - - response.Id = query.Id - return response, nil - } - - log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr) - return nil, backendErr -} - -// Perform message exchange with the default UDP wireformat defined in current draft -// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10 -func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) { - req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg)) - if err != nil { - return nil, errors.Wrap(err, "failed to create an HTTPS request") - } - - req.Header.Add("Content-Type", dnsMessageContentType) - req.Header.Add("Accept", dnsMessageContentType) - req.Host = u.endpoint.Hostname() - - resp, err := u.client.Do(req) - if err != nil { - return nil, errors.Wrap(err, "failed to perform an HTTPS request") - } - - // Check response status code - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("returned status code %d", resp.StatusCode) - } - - contentType := resp.Header.Get("Content-Type") - if contentType != dnsMessageContentType { - return nil, fmt.Errorf("return wrong content type %s", contentType) - } - - // Read application/dns-message response from the body - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, errors.Wrap(err, "failed to read the response body") - } - - return buf, nil -} - -// Clear resources -func (u *HttpsUpstream) Close() error { - return nil -} diff --git a/upstream/persistent.go b/upstream/persistent.go deleted file mode 100644 index 91cc9094..00000000 --- a/upstream/persistent.go +++ /dev/null @@ -1,210 +0,0 @@ -package upstream - -import ( - "crypto/tls" - "net" - "sort" - "sync/atomic" - "time" - - "github.com/miekg/dns" -) - -// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin - -const ( - defaultExpire = 10 * time.Second - minDialTimeout = 100 * time.Millisecond - maxDialTimeout = 30 * time.Second - defaultDialTimeout = 30 * time.Second - cumulativeAvgWeight = 4 -) - -// a persistConn hold the dns.Conn and the last used time. -type persistConn struct { - c *dns.Conn - used time.Time -} - -// Transport hold the persistent cache. -type Transport struct { - avgDialTime int64 // kind of average time of dial time - conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. - expire time.Duration // After this duration a connection is expired. - addr string - tlsConfig *tls.Config - - dial chan string - yield chan *dns.Conn - ret chan *dns.Conn - stop chan bool -} - -// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. -func (t *Transport) Dial(proto string) (*dns.Conn, error) { - // If tls has been configured; use it. - if t.tlsConfig != nil { - proto = "tcp-tls" - } - - t.dial <- proto - c := <-t.ret - - if c != nil { - return c, nil - } - - reqTime := time.Now() - timeout := t.dialTimeout() - if proto == "tcp-tls" { - conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout) - t.updateDialTimeout(time.Since(reqTime)) - return conn, err - } - conn, err := dns.DialTimeout(proto, t.addr, timeout) - t.updateDialTimeout(time.Since(reqTime)) - return conn, err -} - -// Yield return the connection to transport for reuse. -func (t *Transport) Yield(c *dns.Conn) { t.yield <- c } - -// Start starts the transport's connection manager. -func (t *Transport) Start() { go t.connManager() } - -// Stop stops the transport's connection manager. -func (t *Transport) Stop() { close(t.stop) } - -// SetExpire sets the connection expire time in transport. -func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } - -// SetTLSConfig sets the TLS config in transport. -func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } - -func NewTransport(addr string) *Transport { - t := &Transport{ - avgDialTime: int64(defaultDialTimeout / 2), - conns: make(map[string][]*persistConn), - expire: defaultExpire, - addr: addr, - dial: make(chan string), - yield: make(chan *dns.Conn), - ret: make(chan *dns.Conn), - stop: make(chan bool), - } - return t -} - -func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { - dt := time.Duration(atomic.LoadInt64(currentAvg)) - atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) -} - -func (t *Transport) dialTimeout() time.Duration { - return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) -} - -func (t *Transport) updateDialTimeout(newDialTime time.Duration) { - averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) -} - -// limitTimeout is a utility function to auto-tune timeout values -// average observed time is moved towards the last observed delay moderated by a weight -// next timeout to use will be the double of the computed average, limited by min and max frame. -func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { - rt := time.Duration(atomic.LoadInt64(currentAvg)) - if rt < minValue { - return minValue - } - if rt < maxValue/2 { - return 2 * rt - } - return maxValue -} - -// connManagers manages the persistent connection cache for UDP and TCP. -func (t *Transport) connManager() { - ticker := time.NewTicker(t.expire) -Wait: - for { - select { - case proto := <-t.dial: - // take the last used conn - complexity O(1) - if stack := t.conns[proto]; len(stack) > 0 { - pc := stack[len(stack)-1] - if time.Since(pc.used) < t.expire { - // Found one, remove from pool and return this conn. - t.conns[proto] = stack[:len(stack)-1] - t.ret <- pc.c - continue Wait - } - // clear entire cache if the last conn is expired - t.conns[proto] = nil - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack) - } - - t.ret <- nil - - case conn := <-t.yield: - - // no proto here, infer from config and conn - if _, ok := conn.Conn.(*net.UDPConn); ok { - t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()}) - continue Wait - } - - if t.tlsConfig == nil { - t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()}) - continue Wait - } - - t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) - - case <-ticker.C: - t.cleanup(false) - - case <-t.stop: - t.cleanup(true) - close(t.ret) - return - } - } -} - -// closeConns closes connections. -func closeConns(conns []*persistConn) { - for _, pc := range conns { - pc.c.Close() - } -} - -// cleanup removes connections from cache. -func (t *Transport) cleanup(all bool) { - staleTime := time.Now().Add(-t.expire) - for proto, stack := range t.conns { - if len(stack) == 0 { - continue - } - if all { - t.conns[proto] = nil - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack) - continue - } - if stack[0].used.After(staleTime) { - continue - } - - // connections in stack are sorted by "used" - good := sort.Search(len(stack), func(i int) bool { - return stack[i].used.After(staleTime) - }) - t.conns[proto] = stack[good:] - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack[:good]) - } -} diff --git a/upstream/setup.go b/upstream/setup.go deleted file mode 100644 index 4aed6bcf..00000000 --- a/upstream/setup.go +++ /dev/null @@ -1,81 +0,0 @@ -package upstream - -import ( - "log" - - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/mholt/caddy" -) - -func init() { - caddy.RegisterPlugin("upstream", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -// Read the configuration and initialize upstreams -func setup(c *caddy.Controller) error { - p, err := setupPlugin(c) - if err != nil { - return err - } - config := dnsserver.GetConfig(c) - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnShutdown(p.onShutdown) - return nil -} - -// Read the configuration -func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) { - p := New() - - log.Println("Initializing the Upstream plugin") - - bootstrap := "" - upstreamUrls := []string{} - for c.Next() { - args := c.RemainingArgs() - if len(args) > 0 { - upstreamUrls = append(upstreamUrls, args...) - } - for c.NextBlock() { - switch c.Val() { - case "bootstrap": - if !c.NextArg() { - return nil, c.ArgErr() - } - bootstrap = c.Val() - } - } - } - - for _, url := range upstreamUrls { - u, err := NewUpstream(url, bootstrap) - if err != nil { - log.Printf("Cannot initialize upstream %s", url) - return nil, err - } - - p.Upstreams = append(p.Upstreams, u) - } - - return p, nil -} - -func (p *UpstreamPlugin) onShutdown() error { - for i := range p.Upstreams { - u := p.Upstreams[i] - err := u.Close() - if err != nil { - log.Printf("Error while closing the upstream: %s", err) - } - } - - return nil -} diff --git a/upstream/setup_test.go b/upstream/setup_test.go deleted file mode 100644 index 82b8ab5c..00000000 --- a/upstream/setup_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package upstream - -import ( - "testing" - - "github.com/mholt/caddy" -) - -func TestSetup(t *testing.T) { - var tests = []struct { - config string - }{ - {`upstream 8.8.8.8`}, - {`upstream 8.8.8.8 { - bootstrap 8.8.8.8:53 -}`}, - {`upstream tls://1.1.1.1 8.8.8.8 { - bootstrap 1.1.1.1 -}`}, - } - - for _, test := range tests { - c := caddy.NewTestController("dns", test.config) - err := setup(c) - if err != nil { - t.Fatalf("Test failed") - } - } -} diff --git a/upstream/upstream.go b/upstream/upstream.go deleted file mode 100644 index faef224e..00000000 --- a/upstream/upstream.go +++ /dev/null @@ -1,57 +0,0 @@ -package upstream - -import ( - "time" - - "github.com/coredns/coredns/plugin" - "github.com/miekg/dns" - "github.com/pkg/errors" - "golang.org/x/net/context" -) - -const ( - defaultTimeout = 5 * time.Second -) - -// Upstream is a simplified interface for proxy destination -type Upstream interface { - Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) - Close() error -} - -// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface -type UpstreamPlugin struct { - Upstreams []Upstream - Next plugin.Handler -} - -// Initialize the upstream plugin -func New() *UpstreamPlugin { - p := &UpstreamPlugin{ - Upstreams: []Upstream{}, - } - - return p -} - -// ServeDNS implements interface for CoreDNS plugin -func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - var reply *dns.Msg - var backendErr error - - for i := range p.Upstreams { - upstream := p.Upstreams[i] - reply, backendErr = upstream.Exchange(ctx, r) - if backendErr == nil { - w.WriteMsg(reply) - return 0, nil - } - } - - return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams") -} - -// Name implements interface for CoreDNS plugin -func (p *UpstreamPlugin) Name() string { - return "upstream" -} diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go deleted file mode 100644 index 9221e6f5..00000000 --- a/upstream/upstream_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package upstream - -import ( - "net" - "testing" - - "github.com/miekg/dns" - "golang.org/x/net/context" -) - -func TestDnsUpstreamIsAlive(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"8.8.8.8:53", "8.8.8.8:53"}, - {"1.1.1.1", ""}, - {"tcp://1.1.1.1:53", ""}, - {"176.103.130.130:5353", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS upstream") - } - - testUpstreamIsAlive(t, u) - } -} - -func TestHttpsUpstreamIsAlive(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, - {"https://dns.google.com/experimental", "8.8.8.8:53"}, - {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-HTTPS upstream") - } - - testUpstreamIsAlive(t, u) - } -} - -func TestDnsOverTlsIsAlive(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"tls://1.1.1.1", ""}, - {"tls://9.9.9.9:853", ""}, - {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, - {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-TLS upstream") - } - - testUpstreamIsAlive(t, u) - } -} - -func TestDnsUpstream(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"8.8.8.8:53", "8.8.8.8:53"}, - {"1.1.1.1", ""}, - {"tcp://1.1.1.1:53", ""}, - {"176.103.130.130:5353", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS upstream") - } - - testUpstream(t, u) - } -} - -func TestHttpsUpstream(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, - {"https://dns.google.com/experimental", "8.8.8.8:53"}, - {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-HTTPS upstream") - } - - testUpstream(t, u) - } -} - -func TestDnsOverTlsUpstream(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"tls://1.1.1.1", ""}, - {"tls://9.9.9.9:853", ""}, - {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, - {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-TLS upstream") - } - - testUpstream(t, u) - } -} - -func testUpstreamIsAlive(t *testing.T, u Upstream) { - alive, err := IsAlive(u) - if !alive || err != nil { - t.Errorf("Upstream is not alive") - } - - u.Close() -} - -func testUpstream(t *testing.T, u Upstream) { - var tests = []struct { - name string - expected net.IP - }{ - {"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)}, - {"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)}, - } - - for _, test := range tests { - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - resp, err := u.Exchange(context.Background(), &req) - - if err != nil { - t.Fatalf("error while making an upstream request: %s", err) - } - - if len(resp.Answer) != 1 { - t.Fatalf("no answer section in the response") - } - if answer, ok := resp.Answer[0].(*dns.A); ok { - if !test.expected.Equal(answer.A) { - t.Errorf("wrong IP in the response: %v", answer.A) - } - } - } - - err := u.Close() - if err != nil { - t.Errorf("Error while closing the upstream: %s", err) - } -}