From b84786631026b78e3d04ed3ab42555aabbc9b88c Mon Sep 17 00:00:00 2001 From: Eugene Bujak Date: Wed, 5 Dec 2018 19:18:58 +0300 Subject: [PATCH] Remove unused code. Goodbye CoreDNS. --- coredns_plugin/coredns_plugin.go | 558 --------------------- coredns_plugin/coredns_plugin_test.go | 131 ----- coredns_plugin/coredns_stats.go | 410 --------------- coredns_plugin/querylog.go | 239 --------- coredns_plugin/querylog_file.go | 291 ----------- coredns_plugin/querylog_top.go | 386 -------------- coredns_plugin/ratelimit/ratelimit.go | 182 ------- coredns_plugin/ratelimit/ratelimit_test.go | 80 --- coredns_plugin/refuseany/refuseany.go | 91 ---- coredns_plugin/reload.go | 36 -- go.mod | 16 - go.sum | 50 -- upstream/dns_upstream.go | 105 ---- upstream/helpers.go | 98 ---- upstream/https_upstream.go | 128 ----- upstream/persistent.go | 210 -------- upstream/setup.go | 81 --- upstream/setup_test.go | 29 -- upstream/upstream.go | 57 --- upstream/upstream_test.go | 187 ------- 20 files changed, 3365 deletions(-) delete mode 100644 coredns_plugin/coredns_plugin.go delete mode 100644 coredns_plugin/coredns_plugin_test.go delete mode 100644 coredns_plugin/coredns_stats.go delete mode 100644 coredns_plugin/querylog.go delete mode 100644 coredns_plugin/querylog_file.go delete mode 100644 coredns_plugin/querylog_top.go delete mode 100644 coredns_plugin/ratelimit/ratelimit.go delete mode 100644 coredns_plugin/ratelimit/ratelimit_test.go delete mode 100644 coredns_plugin/refuseany/refuseany.go delete mode 100644 coredns_plugin/reload.go delete mode 100644 upstream/dns_upstream.go delete mode 100644 upstream/helpers.go delete mode 100644 upstream/https_upstream.go delete mode 100644 upstream/persistent.go delete mode 100644 upstream/setup.go delete mode 100644 upstream/setup_test.go delete mode 100644 upstream/upstream.go delete mode 100644 upstream/upstream_test.go diff --git a/coredns_plugin/coredns_plugin.go b/coredns_plugin/coredns_plugin.go deleted file mode 100644 index 8d302fd3..00000000 --- a/coredns_plugin/coredns_plugin.go +++ /dev/null @@ -1,558 +0,0 @@ -package dnsfilter - -import ( - "bufio" - "errors" - "fmt" - "log" - "net" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/metrics" - "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/plugin/pkg/upstream" - "github.com/coredns/coredns/request" - "github.com/mholt/caddy" - "github.com/miekg/dns" - "github.com/prometheus/client_golang/prometheus" - "golang.org/x/net/context" -) - -var defaultSOA = &dns.SOA{ - // values copied from verisign's nonexistent .com domain - // their exact values are not important in our use case because they are used for domain transfers between primary/secondary DNS servers - Refresh: 1800, - Retry: 900, - Expire: 604800, - Minttl: 86400, -} - -func init() { - caddy.RegisterPlugin("dnsfilter", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -type plugFilter struct { - ID int64 - Path string -} - -type plugSettings struct { - SafeBrowsingBlockHost string - ParentalBlockHost string - QueryLogEnabled bool - BlockedTTL uint32 // in seconds, default 3600 - Filters []plugFilter -} - -type plug struct { - d *dnsfilter.Dnsfilter - Next plugin.Handler - upstream upstream.Upstream - settings plugSettings - - sync.RWMutex -} - -var defaultPluginSettings = plugSettings{ - SafeBrowsingBlockHost: "safebrowsing.block.dns.adguard.com", - ParentalBlockHost: "family.block.dns.adguard.com", - BlockedTTL: 3600, // in seconds - Filters: make([]plugFilter, 0), -} - -// -// coredns handling functions -// -func setupPlugin(c *caddy.Controller) (*plug, error) { - // create new Plugin and copy default values - p := &plug{ - settings: defaultPluginSettings, - d: dnsfilter.New(nil), - } - - log.Println("Initializing the CoreDNS plugin") - - for c.Next() { - for c.NextBlock() { - blockValue := c.Val() - switch blockValue { - case "safebrowsing": - log.Println("Browsing security service is enabled") - p.d.SafeBrowsingEnabled = true - if c.NextArg() { - if len(c.Val()) == 0 { - return nil, c.ArgErr() - } - p.d.SetSafeBrowsingServer(c.Val()) - } - case "safesearch": - log.Println("Safe search is enabled") - p.d.SafeSearchEnabled = true - case "parental": - if !c.NextArg() { - return nil, c.ArgErr() - } - sensitivity, err := strconv.Atoi(c.Val()) - if err != nil { - return nil, c.ArgErr() - } - - log.Println("Parental control is enabled") - if !dnsfilter.IsParentalSensitivityValid(sensitivity) { - return nil, dnsfilter.ErrInvalidParental - } - p.d.ParentalEnabled = true - p.d.ParentalSensitivity = sensitivity - if c.NextArg() { - if len(c.Val()) == 0 { - return nil, c.ArgErr() - } - p.settings.ParentalBlockHost = c.Val() - } - case "blocked_ttl": - if !c.NextArg() { - return nil, c.ArgErr() - } - blockedTtl, err := strconv.ParseUint(c.Val(), 10, 32) - if err != nil { - return nil, c.ArgErr() - } - log.Printf("Blocked request TTL is %d", blockedTtl) - p.settings.BlockedTTL = uint32(blockedTtl) - case "querylog": - log.Println("Query log is enabled") - p.settings.QueryLogEnabled = true - case "filter": - if !c.NextArg() { - return nil, c.ArgErr() - } - - filterId, err := strconv.ParseInt(c.Val(), 10, 64) - if err != nil { - return nil, c.ArgErr() - } - if !c.NextArg() { - return nil, c.ArgErr() - } - filterPath := c.Val() - - // Initialize filter and add it to the list - p.settings.Filters = append(p.settings.Filters, plugFilter{ - ID: filterId, - Path: filterPath, - }) - } - } - } - - for _, filter := range p.settings.Filters { - log.Printf("Loading rules from %s", filter.Path) - - file, err := os.Open(filter.Path) - if err != nil { - return nil, err - } - defer file.Close() - - count := 0 - scanner := bufio.NewScanner(file) - for scanner.Scan() { - text := scanner.Text() - - err = p.d.AddRule(text, filter.ID) - if err == dnsfilter.ErrAlreadyExists || err == dnsfilter.ErrInvalidSyntax { - continue - } - if err != nil { - log.Printf("Cannot add rule %s: %s", text, err) - // Just ignore invalid rules - continue - } - count++ - } - log.Printf("Added %d rules from filter ID=%d", count, filter.ID) - - if err = scanner.Err(); err != nil { - return nil, err - } - } - - log.Printf("Loading stats from querylog") - err := fillStatsFromQueryLog() - if err != nil { - log.Printf("Failed to load stats from querylog: %s", err) - return nil, err - } - - if p.settings.QueryLogEnabled { - onceQueryLog.Do(func() { - go periodicQueryLogRotate() - go periodicHourlyTopRotate() - go statsRotator() - }) - } - - onceHook.Do(func() { - caddy.RegisterEventHook("dnsfilter-reload", hook) - }) - - p.upstream, err = upstream.New(nil) - if err != nil { - return nil, err - } - - return p, nil -} - -func setup(c *caddy.Controller) error { - p, err := setupPlugin(c) - if err != nil { - return err - } - config := dnsserver.GetConfig(c) - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnStartup(func() error { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return nil - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(requests) - x.MustRegister(filtered) - x.MustRegister(filteredLists) - x.MustRegister(filteredSafebrowsing) - x.MustRegister(filteredParental) - x.MustRegister(whitelisted) - x.MustRegister(safesearch) - x.MustRegister(errorsTotal) - x.MustRegister(elapsedTime) - x.MustRegister(p) - } - return nil - }) - c.OnShutdown(p.onShutdown) - c.OnFinalShutdown(p.onFinalShutdown) - - return nil -} - -func (p *plug) onShutdown() error { - p.Lock() - p.d.Destroy() - p.d = nil - p.Unlock() - return nil -} - -func (p *plug) onFinalShutdown() error { - logBufferLock.Lock() - err := flushToFile(logBuffer) - if err != nil { - log.Printf("failed to flush to file: %s", err) - return err - } - logBufferLock.Unlock() - return nil -} - -type statsFunc func(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) - -func doDesc(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { - realch, ok := ch.(chan<- *prometheus.Desc) - if !ok { - log.Printf("Couldn't convert ch to chan<- *prometheus.Desc\n") - return - } - realch <- prometheus.NewDesc(name, text, nil, nil) -} - -func doMetric(ch interface{}, name string, text string, value float64, valueType prometheus.ValueType) { - realch, ok := ch.(chan<- prometheus.Metric) - if !ok { - log.Printf("Couldn't convert ch to chan<- prometheus.Metric\n") - return - } - desc := prometheus.NewDesc(name, text, nil, nil) - realch <- prometheus.MustNewConstMetric(desc, valueType, value) -} - -func gen(ch interface{}, doFunc statsFunc, name string, text string, value float64, valueType prometheus.ValueType) { - doFunc(ch, name, text, value, valueType) -} - -func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *dnsfilter.LookupStats) { - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_requests", name), fmt.Sprintf("Number of %s HTTP requests that were sent", name), float64(lookupstats.Requests), prometheus.CounterValue) - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_cachehits", name), fmt.Sprintf("Number of %s lookups that didn't need HTTP requests", name), float64(lookupstats.CacheHits), prometheus.CounterValue) - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending", name), fmt.Sprintf("Number of currently pending %s HTTP requests", name), float64(lookupstats.Pending), prometheus.GaugeValue) - gen(ch, doFunc, fmt.Sprintf("coredns_dnsfilter_%s_pending_max", name), fmt.Sprintf("Maximum number of pending %s HTTP requests", name), float64(lookupstats.PendingMax), prometheus.GaugeValue) -} - -func (p *plug) doStats(ch interface{}, doFunc statsFunc) { - p.RLock() - stats := p.d.GetStats() - doStatsLookup(ch, doFunc, "safebrowsing", &stats.Safebrowsing) - doStatsLookup(ch, doFunc, "parental", &stats.Parental) - p.RUnlock() -} - -// Describe is called by prometheus handler to know stat types -func (p *plug) Describe(ch chan<- *prometheus.Desc) { - p.doStats(ch, doDesc) -} - -// Collect is called by prometheus handler to collect stats -func (p *plug) Collect(ch chan<- prometheus.Metric) { - p.doStats(ch, doMetric) -} - -func (p *plug) replaceHostWithValAndReply(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, host string, val string, question dns.Question) (int, error) { - // check if it's a domain name or IP address - addr := net.ParseIP(val) - var records []dns.RR - // log.Println("Will give", val, "instead of", host) // debug logging - if addr != nil { - // this is an IP address, return it - result, err := dns.NewRR(fmt.Sprintf("%s %d A %s", host, p.settings.BlockedTTL, val)) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) - } - records = append(records, result) - } else { - // this is a domain name, need to look it up - req := new(dns.Msg) - req.SetQuestion(dns.Fqdn(val), question.Qtype) - req.RecursionDesired = true - reqstate := request.Request{W: w, Req: req, Context: ctx} - result, err := p.upstream.Lookup(reqstate, dns.Fqdn(val), reqstate.QType()) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) - } - if result != nil { - for _, answer := range result.Answer { - answer.Header().Name = question.Name - } - records = result.Answer - } - } - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true - m.Answer = append(m.Answer, records...) - state := request.Request{W: w, Req: r, Context: ctx} - state.SizeAndDo(m) - err := state.W.WriteMsg(m) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, fmt.Errorf("plugin/dnsfilter: %s", err) - } - return dns.RcodeSuccess, nil -} - -// generate SOA record that makes DNS clients cache NXdomain results -// the only value that is important is TTL in header, other values like refresh, retry, expire and minttl are irrelevant -func (p *plug) genSOA(r *dns.Msg) []dns.RR { - zone := r.Question[0].Name - header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: p.settings.BlockedTTL, Class: dns.ClassINET} - - Mbox := "hostmaster." - if zone[0] != '.' { - Mbox += zone - } - Ns := "fake-for-negative-caching.adguard.com." - - soa := *defaultSOA - soa.Hdr = header - soa.Mbox = Mbox - soa.Ns = Ns - soa.Serial = 100500 // faster than uint32(time.Now().Unix()) - return []dns.RR{&soa} -} - -func (p *plug) writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := request.Request{W: w, Req: r, Context: ctx} - m := new(dns.Msg) - m.SetRcode(state.Req, dns.RcodeNameError) - m.Authoritative, m.RecursionAvailable, m.Compress = true, true, true - m.Ns = p.genSOA(r) - - state.SizeAndDo(m) - err := state.W.WriteMsg(m) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, err - } - return dns.RcodeNameError, nil -} - -func (p *plug) serveDNSInternal(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, dnsfilter.Result, error) { - if len(r.Question) != 1 { - // google DNS, bind and others do the same - return dns.RcodeFormatError, dnsfilter.Result{}, fmt.Errorf("got a DNS request with more than one Question") - } - for _, question := range r.Question { - host := strings.ToLower(strings.TrimSuffix(question.Name, ".")) - // is it a safesearch domain? - p.RLock() - if val, ok := p.d.SafeSearchDomain(host); ok { - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - p.RUnlock() - return rcode, dnsfilter.Result{}, err - } - p.RUnlock() - return rcode, dnsfilter.Result{Reason: dnsfilter.FilteredSafeSearch}, err - } - p.RUnlock() - - // needs to be filtered instead - p.RLock() - result, err := p.d.CheckHost(host) - if err != nil { - log.Printf("plugin/dnsfilter: %s\n", err) - p.RUnlock() - return dns.RcodeServerFailure, dnsfilter.Result{}, fmt.Errorf("plugin/dnsfilter: %s", err) - } - p.RUnlock() - - if result.IsFiltered { - switch result.Reason { - case dnsfilter.FilteredSafeBrowsing: - // return cname safebrowsing.block.dns.adguard.com - val := p.settings.SafeBrowsingBlockHost - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - case dnsfilter.FilteredParental: - // return cname family.block.dns.adguard.com - val := p.settings.ParentalBlockHost - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, val, question) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - case dnsfilter.FilteredBlackList: - - if result.Ip == nil { - // return NXDomain - rcode, err := p.writeNXdomain(ctx, w, r) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - } else { - // This is a hosts-syntax rule - rcode, err := p.replaceHostWithValAndReply(ctx, w, r, host, result.Ip.String(), question) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - } - case dnsfilter.FilteredInvalid: - // return NXdomain - rcode, err := p.writeNXdomain(ctx, w, r) - if err != nil { - return rcode, dnsfilter.Result{}, err - } - return rcode, result, err - default: - log.Printf("SHOULD NOT HAPPEN -- got unknown reason for filtering host \"%s\": %v, %+v", host, result.Reason, result) - } - } else { - switch result.Reason { - case dnsfilter.NotFilteredWhiteList: - rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) - return rcode, result, err - case dnsfilter.NotFilteredNotFound: - // do nothing, pass through to lower code - default: - log.Printf("SHOULD NOT HAPPEN -- got unknown reason for not filtering host \"%s\": %v, %+v", host, result.Reason, result) - } - } - } - rcode, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) - return rcode, dnsfilter.Result{}, err -} - -// ServeDNS handles the DNS request and refuses if it's in filterlists -func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - start := time.Now() - requests.Inc() - state := request.Request{W: w, Req: r} - ip := state.IP() - - // capture the written answer - rrw := dnstest.NewRecorder(w) - rcode, result, err := p.serveDNSInternal(ctx, rrw, r) - if rcode > 0 { - // actually send the answer if we have one - answer := new(dns.Msg) - answer.SetRcode(r, rcode) - state.SizeAndDo(answer) - err = w.WriteMsg(answer) - if err != nil { - return dns.RcodeServerFailure, err - } - } - - // increment counters - switch { - case err != nil: - errorsTotal.Inc() - case result.Reason == dnsfilter.FilteredBlackList: - filtered.Inc() - filteredLists.Inc() - case result.Reason == dnsfilter.FilteredSafeBrowsing: - filtered.Inc() - filteredSafebrowsing.Inc() - case result.Reason == dnsfilter.FilteredParental: - filtered.Inc() - filteredParental.Inc() - case result.Reason == dnsfilter.FilteredInvalid: - filtered.Inc() - filteredInvalid.Inc() - case result.Reason == dnsfilter.FilteredSafeSearch: - // the request was passsed through but not filtered, don't increment filtered - safesearch.Inc() - case result.Reason == dnsfilter.NotFilteredWhiteList: - whitelisted.Inc() - case result.Reason == dnsfilter.NotFilteredNotFound: - // do nothing - case result.Reason == dnsfilter.NotFilteredError: - text := "SHOULD NOT HAPPEN: got DNSFILTER_NOTFILTERED_ERROR without err != nil!" - log.Println(text) - err = errors.New(text) - rcode = dns.RcodeServerFailure - } - - // log - elapsed := time.Since(start) - elapsedTime.Observe(elapsed.Seconds()) - if p.settings.QueryLogEnabled { - logRequest(r, rrw.Msg, result, time.Since(start), ip) - } - return rcode, err -} - -// Name returns name of the plugin as seen in Corefile and plugin.cfg -func (p *plug) Name() string { return "dnsfilter" } - -var onceHook sync.Once -var onceQueryLog sync.Once diff --git a/coredns_plugin/coredns_plugin_test.go b/coredns_plugin/coredns_plugin_test.go deleted file mode 100644 index 1733fd6f..00000000 --- a/coredns_plugin/coredns_plugin_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package dnsfilter - -import ( - "context" - "fmt" - "io/ioutil" - "net" - "os" - "testing" - - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/plugin/test" - "github.com/mholt/caddy" - "github.com/miekg/dns" -) - -func TestSetup(t *testing.T) { - for i, testcase := range []struct { - config string - failing bool - }{ - {`dnsfilter`, false}, - {`dnsfilter { - filter 0 /dev/nonexistent/abcdef - }`, true}, - {`dnsfilter { - filter 0 ../tests/dns.txt - }`, false}, - {`dnsfilter { - safebrowsing - filter 0 ../tests/dns.txt - }`, false}, - {`dnsfilter { - parental - filter 0 ../tests/dns.txt - }`, true}, - } { - c := caddy.NewTestController("dns", testcase.config) - err := setup(c) - if err != nil { - if !testcase.failing { - t.Fatalf("Test #%d expected no errors, but got: %v", i, err) - } - continue - } - if testcase.failing { - t.Fatalf("Test #%d expected to fail but it didn't", i) - } - } -} - -func TestEtcHostsFilter(t *testing.T) { - text := []byte("127.0.0.1 doubleclick.net\n" + "127.0.0.1 example.org example.net www.example.org www.example.net") - tmpfile, err := ioutil.TempFile("", "") - if err != nil { - t.Fatal(err) - } - if _, err = tmpfile.Write(text); err != nil { - t.Fatal(err) - } - if err = tmpfile.Close(); err != nil { - t.Fatal(err) - } - - defer os.Remove(tmpfile.Name()) - - configText := fmt.Sprintf("dnsfilter {\nfilter 0 %s\n}", tmpfile.Name()) - c := caddy.NewTestController("dns", configText) - p, err := setupPlugin(c) - if err != nil { - t.Fatal(err) - } - - p.Next = zeroTTLBackend() - - ctx := context.TODO() - - for _, testcase := range []struct { - host string - filtered bool - }{ - {"www.doubleclick.net", false}, - {"doubleclick.net", true}, - {"www2.example.org", false}, - {"www2.example.net", false}, - {"test.www.example.org", false}, - {"test.www.example.net", false}, - {"example.org", true}, - {"example.net", true}, - {"www.example.org", true}, - {"www.example.net", true}, - } { - req := new(dns.Msg) - req.SetQuestion(testcase.host+".", dns.TypeA) - - resp := test.ResponseWriter{} - rrw := dnstest.NewRecorder(&resp) - rcode, err := p.ServeDNS(ctx, rrw, req) - if err != nil { - t.Fatalf("ServeDNS returned error: %s", err) - } - if rcode != rrw.Rcode { - t.Fatalf("ServeDNS return value for host %s has rcode %d that does not match captured rcode %d", testcase.host, rcode, rrw.Rcode) - } - A, ok := rrw.Msg.Answer[0].(*dns.A) - if !ok { - t.Fatalf("Host %s expected to have result A", testcase.host) - } - ip := net.IPv4(127, 0, 0, 1) - filtered := ip.Equal(A.A) - if testcase.filtered && testcase.filtered != filtered { - t.Fatalf("Host %s expected to be filtered, instead it is not filtered", testcase.host) - } - if !testcase.filtered && testcase.filtered != filtered { - t.Fatalf("Host %s expected to be not filtered, instead it is filtered", testcase.host) - } - } -} - -func zeroTTLBackend() plugin.Handler { - return plugin.HandlerFunc(func(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - m := new(dns.Msg) - m.SetReply(r) - m.Response, m.RecursionAvailable = true, true - - m.Answer = []dns.RR{test.A("example.org. 0 IN A 127.0.0.53")} - w.WriteMsg(m) - return dns.RcodeSuccess, nil - }) -} diff --git a/coredns_plugin/coredns_stats.go b/coredns_plugin/coredns_stats.go deleted file mode 100644 index b138911e..00000000 --- a/coredns_plugin/coredns_stats.go +++ /dev/null @@ -1,410 +0,0 @@ -package dnsfilter - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - "sync" - "time" - - "github.com/coredns/coredns/plugin" - "github.com/prometheus/client_golang/prometheus" -) - -var ( - requests = newDNSCounter("requests_total", "Count of requests seen by dnsfilter.") - filtered = newDNSCounter("filtered_total", "Count of requests filtered by dnsfilter.") - filteredLists = newDNSCounter("filtered_lists_total", "Count of requests filtered by dnsfilter using lists.") - filteredSafebrowsing = newDNSCounter("filtered_safebrowsing_total", "Count of requests filtered by dnsfilter using safebrowsing.") - filteredParental = newDNSCounter("filtered_parental_total", "Count of requests filtered by dnsfilter using parental.") - filteredInvalid = newDNSCounter("filtered_invalid_total", "Count of requests filtered by dnsfilter because they were invalid.") - whitelisted = newDNSCounter("whitelisted_total", "Count of requests not filtered by dnsfilter because they are whitelisted.") - safesearch = newDNSCounter("safesearch_total", "Count of requests replaced by dnsfilter safesearch.") - errorsTotal = newDNSCounter("errors_total", "Count of requests that dnsfilter couldn't process because of transitive errors.") - elapsedTime = newDNSHistogram("request_duration", "Histogram of the time (in seconds) each request took.") -) - -// entries for single time period (for example all per-second entries) -type statsEntries map[string][statsHistoryElements]float64 - -// how far back to keep the stats -const statsHistoryElements = 60 + 1 // +1 for calculating delta - -// each periodic stat is a map of arrays -type periodicStats struct { - Entries statsEntries - period time.Duration // how long one entry lasts - LastRotate time.Time // last time this data was rotated - - sync.RWMutex -} - -type stats struct { - PerSecond periodicStats - PerMinute periodicStats - PerHour periodicStats - PerDay periodicStats -} - -// per-second/per-minute/per-hour/per-day stats -var statistics stats - -func initPeriodicStats(periodic *periodicStats, period time.Duration) { - periodic.Entries = statsEntries{} - periodic.LastRotate = time.Now() - periodic.period = period -} - -func init() { - purgeStats() -} - -func purgeStats() { - initPeriodicStats(&statistics.PerSecond, time.Second) - initPeriodicStats(&statistics.PerMinute, time.Minute) - initPeriodicStats(&statistics.PerHour, time.Hour) - initPeriodicStats(&statistics.PerDay, time.Hour*24) -} - -func (p *periodicStats) Inc(name string, when time.Time) { - // calculate how many periods ago this happened - elapsed := int64(time.Since(when) / p.period) - // trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed) - if elapsed >= statsHistoryElements { - return // outside of our timeframe - } - p.Lock() - currentValues := p.Entries[name] - currentValues[elapsed]++ - p.Entries[name] = currentValues - p.Unlock() -} - -func (p *periodicStats) Observe(name string, when time.Time, value float64) { - // calculate how many periods ago this happened - elapsed := int64(time.Since(when) / p.period) - // trace("%s: %v as %v -> [%v]", name, time.Since(when), p.period, elapsed) - if elapsed >= statsHistoryElements { - return // outside of our timeframe - } - p.Lock() - { - countname := name + "_count" - currentValues := p.Entries[countname] - value := currentValues[elapsed] - // trace("Will change p.Entries[%s][%d] from %v to %v", countname, elapsed, value, value+1) - value += 1 - currentValues[elapsed] = value - p.Entries[countname] = currentValues - } - { - totalname := name + "_sum" - currentValues := p.Entries[totalname] - currentValues[elapsed] += value - p.Entries[totalname] = currentValues - } - p.Unlock() -} - -func (p *periodicStats) statsRotate(now time.Time) { - p.Lock() - rotations := int64(now.Sub(p.LastRotate) / p.period) - if rotations > statsHistoryElements { - rotations = statsHistoryElements - } - // calculate how many times we should rotate - for r := int64(0); r < rotations; r++ { - for key, values := range p.Entries { - newValues := [statsHistoryElements]float64{} - for i := 1; i < len(values); i++ { - newValues[i] = values[i-1] - } - p.Entries[key] = newValues - } - } - if rotations > 0 { - p.LastRotate = now - } - p.Unlock() -} - -func statsRotator() { - for range time.Tick(time.Second) { - now := time.Now() - statistics.PerSecond.statsRotate(now) - statistics.PerMinute.statsRotate(now) - statistics.PerHour.statsRotate(now) - statistics.PerDay.statsRotate(now) - } -} - -// counter that wraps around prometheus Counter but also adds to periodic stats -type counter struct { - name string // used as key in periodic stats - value int64 - prom prometheus.Counter -} - -func newDNSCounter(name string, help string) *counter { - // trace("called") - c := &counter{} - c.prom = prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "dnsfilter", - Name: name, - Help: help, - }) - c.name = name - - return c -} - -func (c *counter) IncWithTime(when time.Time) { - statistics.PerSecond.Inc(c.name, when) - statistics.PerMinute.Inc(c.name, when) - statistics.PerHour.Inc(c.name, when) - statistics.PerDay.Inc(c.name, when) - c.value++ - c.prom.Inc() -} - -func (c *counter) Inc() { - c.IncWithTime(time.Now()) -} - -func (c *counter) Describe(ch chan<- *prometheus.Desc) { - c.prom.Describe(ch) -} - -func (c *counter) Collect(ch chan<- prometheus.Metric) { - c.prom.Collect(ch) -} - -type histogram struct { - name string // used as key in periodic stats - count int64 - total float64 - prom prometheus.Histogram -} - -func newDNSHistogram(name string, help string) *histogram { - // trace("called") - h := &histogram{} - h.prom = prometheus.NewHistogram(prometheus.HistogramOpts{ - Namespace: plugin.Namespace, - Subsystem: "dnsfilter", - Name: name, - Help: help, - }) - h.name = name - - return h -} - -func (h *histogram) ObserveWithTime(value float64, when time.Time) { - statistics.PerSecond.Observe(h.name, when, value) - statistics.PerMinute.Observe(h.name, when, value) - statistics.PerHour.Observe(h.name, when, value) - statistics.PerDay.Observe(h.name, when, value) - h.count++ - h.total += value - h.prom.Observe(value) -} - -func (h *histogram) Observe(value float64) { - h.ObserveWithTime(value, time.Now()) -} - -func (h *histogram) Describe(ch chan<- *prometheus.Desc) { - h.prom.Describe(ch) -} - -func (h *histogram) Collect(ch chan<- prometheus.Metric) { - h.prom.Collect(ch) -} - -// ----- -// stats -// ----- -func HandleStats(w http.ResponseWriter, r *http.Request) { - const numHours = 24 - histrical := generateMapFromStats(&statistics.PerHour, 0, numHours) - // sum them up - summed := map[string]interface{}{} - for key, values := range histrical { - summedValue := 0.0 - floats, ok := values.([]float64) - if !ok { - continue - } - for _, v := range floats { - summedValue += v - } - summed[key] = summedValue - } - // don't forget to divide by number of elements in returned slice - if val, ok := summed["avg_processing_time"]; ok { - if flval, flok := val.(float64); flok { - flval /= numHours - summed["avg_processing_time"] = flval - } - } - - summed["stats_period"] = "24 hours" - - json, err := json.Marshal(summed) - if err != nil { - errortext := fmt.Sprintf("Unable to marshal status json: %s", err) - log.Println(errortext) - http.Error(w, errortext, 500) - return - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(json) - if err != nil { - errortext := fmt.Sprintf("Unable to write response json: %s", err) - log.Println(errortext) - http.Error(w, errortext, 500) - return - } -} - -func generateMapFromStats(stats *periodicStats, start int, end int) map[string]interface{} { - // clamp - start = clamp(start, 0, statsHistoryElements) - end = clamp(end, 0, statsHistoryElements) - - avgProcessingTime := make([]float64, 0) - - count := getReversedSlice(stats.Entries[elapsedTime.name+"_count"], start, end) - sum := getReversedSlice(stats.Entries[elapsedTime.name+"_sum"], start, end) - for i := 0; i < len(count); i++ { - var avg float64 - if count[i] != 0 { - avg = sum[i] / count[i] - avg *= 1000 - } - avgProcessingTime = append(avgProcessingTime, avg) - } - - result := map[string]interface{}{ - "dns_queries": getReversedSlice(stats.Entries[requests.name], start, end), - "blocked_filtering": getReversedSlice(stats.Entries[filtered.name], start, end), - "replaced_safebrowsing": getReversedSlice(stats.Entries[filteredSafebrowsing.name], start, end), - "replaced_safesearch": getReversedSlice(stats.Entries[safesearch.name], start, end), - "replaced_parental": getReversedSlice(stats.Entries[filteredParental.name], start, end), - "avg_processing_time": avgProcessingTime, - } - return result -} - -func HandleStatsHistory(w http.ResponseWriter, r *http.Request) { - // handle time unit and prepare our time window size - now := time.Now() - timeUnitString := r.URL.Query().Get("time_unit") - var stats *periodicStats - var timeUnit time.Duration - switch timeUnitString { - case "seconds": - timeUnit = time.Second - stats = &statistics.PerSecond - case "minutes": - timeUnit = time.Minute - stats = &statistics.PerMinute - case "hours": - timeUnit = time.Hour - stats = &statistics.PerHour - case "days": - timeUnit = time.Hour * 24 - stats = &statistics.PerDay - default: - http.Error(w, "Must specify valid time_unit parameter", 400) - return - } - - // parse start and end time - startTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("start_time")) - if err != nil { - errortext := fmt.Sprintf("Must specify valid start_time parameter: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) - return - } - endTime, err := time.Parse(time.RFC3339, r.URL.Query().Get("end_time")) - if err != nil { - errortext := fmt.Sprintf("Must specify valid end_time parameter: %s", err) - log.Println(errortext) - http.Error(w, errortext, 400) - return - } - - // check if start and time times are within supported time range - timeRange := timeUnit * statsHistoryElements - if startTime.Add(timeRange).Before(now) { - http.Error(w, "start_time parameter is outside of supported range", 501) - return - } - if endTime.Add(timeRange).Before(now) { - http.Error(w, "end_time parameter is outside of supported range", 501) - return - } - - // calculate start and end of our array - // basically it's how many hours/minutes/etc have passed since now - start := int(now.Sub(endTime) / timeUnit) - end := int(now.Sub(startTime) / timeUnit) - - // swap them around if they're inverted - if start > end { - start, end = end, start - } - - data := generateMapFromStats(stats, start, end) - json, err := json.Marshal(data) - if err != nil { - errortext := fmt.Sprintf("Unable to marshal status json: %s", err) - log.Println(errortext) - http.Error(w, errortext, 500) - return - } - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(json) - if err != nil { - errortext := fmt.Sprintf("Unable to write response json: %s", err) - log.Println(errortext) - http.Error(w, errortext, 500) - return - } -} - -func HandleStatsReset(w http.ResponseWriter, r *http.Request) { - purgeStats() - _, err := fmt.Fprintf(w, "OK\n") - if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - } -} - -func clamp(value, low, high int) int { - if value < low { - return low - } - if value > high { - return high - } - return value -} - -// -------------------------- -// helper functions for stats -// -------------------------- -func getReversedSlice(input [statsHistoryElements]float64, start int, end int) []float64 { - output := make([]float64, 0) - for i := start; i <= end; i++ { - output = append([]float64{input[i]}, output...) - } - return output -} diff --git a/coredns_plugin/querylog.go b/coredns_plugin/querylog.go deleted file mode 100644 index 92ba2d1d..00000000 --- a/coredns_plugin/querylog.go +++ /dev/null @@ -1,239 +0,0 @@ -package dnsfilter - -import ( - "encoding/json" - "fmt" - "log" - "net/http" - "os" - "path" - "runtime" - "strconv" - "strings" - "sync" - "time" - - "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/coredns/coredns/plugin/pkg/response" - "github.com/miekg/dns" -) - -const ( - logBufferCap = 5000 // maximum capacity of logBuffer before it's flushed to disk - queryLogTimeLimit = time.Hour * 24 // how far in the past we care about querylogs - queryLogRotationPeriod = time.Hour * 24 // rotate the log every 24 hours - queryLogFileName = "querylog.json" // .gz added during compression - queryLogSize = 5000 // maximum API response for /querylog - queryLogTopSize = 500 // Keep in memory only top N values -) - -var ( - logBufferLock sync.RWMutex - logBuffer []*logEntry - - queryLogCache []*logEntry - queryLogLock sync.RWMutex -) - -type logEntry struct { - Question []byte - Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net - Result dnsfilter.Result - Time time.Time - Elapsed time.Duration - IP string -} - -func logRequest(question *dns.Msg, answer *dns.Msg, result dnsfilter.Result, elapsed time.Duration, ip string) { - var q []byte - var a []byte - var err error - - if question != nil { - q, err = question.Pack() - if err != nil { - log.Printf("failed to pack question for querylog: %s", err) - return - } - } - if answer != nil { - a, err = answer.Pack() - if err != nil { - log.Printf("failed to pack answer for querylog: %s", err) - return - } - } - - now := time.Now() - entry := logEntry{ - Question: q, - Answer: a, - Result: result, - Time: now, - Elapsed: elapsed, - IP: ip, - } - var flushBuffer []*logEntry - - logBufferLock.Lock() - logBuffer = append(logBuffer, &entry) - if len(logBuffer) >= logBufferCap { - flushBuffer = logBuffer - logBuffer = nil - } - logBufferLock.Unlock() - queryLogLock.Lock() - queryLogCache = append(queryLogCache, &entry) - if len(queryLogCache) > queryLogSize { - toremove := len(queryLogCache) - queryLogSize - queryLogCache = queryLogCache[toremove:] - } - queryLogLock.Unlock() - - // add it to running top - err = runningTop.addEntry(&entry, question, now) - if err != nil { - log.Printf("Failed to add entry to running top: %s", err) - // don't do failure, just log - } - - // if buffer needs to be flushed to disk, do it now - if len(flushBuffer) > 0 { - // write to file - // do it in separate goroutine -- we are stalling DNS response this whole time - go flushToFile(flushBuffer) - } -} - -func HandleQueryLog(w http.ResponseWriter, r *http.Request) { - queryLogLock.RLock() - values := make([]*logEntry, len(queryLogCache)) - copy(values, queryLogCache) - queryLogLock.RUnlock() - - // reverse it so that newest is first - for left, right := 0, len(values)-1; left < right; left, right = left+1, right-1 { - values[left], values[right] = values[right], values[left] - } - - var data = []map[string]interface{}{} - for _, entry := range values { - var q *dns.Msg - var a *dns.Msg - - if len(entry.Question) > 0 { - q = new(dns.Msg) - if err := q.Unpack(entry.Question); err != nil { - // ignore, log and move on - log.Printf("Failed to unpack dns message question: %s", err) - q = nil - } - } - if len(entry.Answer) > 0 { - a = new(dns.Msg) - if err := a.Unpack(entry.Answer); err != nil { - // ignore, log and move on - log.Printf("Failed to unpack dns message question: %s", err) - a = nil - } - } - - jsonEntry := map[string]interface{}{ - "reason": entry.Result.Reason.String(), - "elapsedMs": strconv.FormatFloat(entry.Elapsed.Seconds()*1000, 'f', -1, 64), - "time": entry.Time.Format(time.RFC3339), - "client": entry.IP, - } - if q != nil { - jsonEntry["question"] = map[string]interface{}{ - "host": strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")), - "type": dns.Type(q.Question[0].Qtype).String(), - "class": dns.Class(q.Question[0].Qclass).String(), - } - } - - if a != nil { - status, _ := response.Typify(a, time.Now().UTC()) - jsonEntry["status"] = status.String() - } - if len(entry.Result.Rule) > 0 { - jsonEntry["rule"] = entry.Result.Rule - jsonEntry["filterId"] = entry.Result.FilterID - } - - if a != nil && len(a.Answer) > 0 { - var answers = []map[string]interface{}{} - for _, k := range a.Answer { - header := k.Header() - answer := map[string]interface{}{ - "type": dns.TypeToString[header.Rrtype], - "ttl": header.Ttl, - } - // try most common record types - switch v := k.(type) { - case *dns.A: - answer["value"] = v.A - case *dns.AAAA: - answer["value"] = v.AAAA - case *dns.MX: - answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx) - case *dns.CNAME: - answer["value"] = v.Target - case *dns.NS: - answer["value"] = v.Ns - case *dns.SPF: - answer["value"] = v.Txt - case *dns.TXT: - answer["value"] = v.Txt - case *dns.PTR: - answer["value"] = v.Ptr - case *dns.SOA: - answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v", v.Ns, v.Mbox, v.Serial, v.Refresh, v.Retry, v.Expire, v.Minttl) - case *dns.CAA: - answer["value"] = fmt.Sprintf("%v %v \"%v\"", v.Flag, v.Tag, v.Value) - case *dns.HINFO: - answer["value"] = fmt.Sprintf("\"%v\" \"%v\"", v.Cpu, v.Os) - case *dns.RRSIG: - answer["value"] = fmt.Sprintf("%v %v %v %v %v %v %v %v %v", dns.TypeToString[v.TypeCovered], v.Algorithm, v.Labels, v.OrigTtl, v.Expiration, v.Inception, v.KeyTag, v.SignerName, v.Signature) - default: - // type unknown, marshall it as-is - answer["value"] = v - } - answers = append(answers, answer) - } - jsonEntry["answer"] = answers - } - - data = append(data, jsonEntry) - } - - jsonVal, err := json.Marshal(data) - if err != nil { - errorText := fmt.Sprintf("Couldn't marshal data into json: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - _, err = w.Write(jsonVal) - if err != nil { - errorText := fmt.Sprintf("Unable to write response json: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - } -} - -func trace(format string, args ...interface{}) { - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := runtime.FuncForPC(pc[0]) - var buf strings.Builder - buf.WriteString(fmt.Sprintf("%s(): ", path.Base(f.Name()))) - text := fmt.Sprintf(format, args...) - buf.WriteString(text) - if len(text) == 0 || text[len(text)-1] != '\n' { - buf.WriteRune('\n') - } - fmt.Fprint(os.Stderr, buf.String()) -} diff --git a/coredns_plugin/querylog_file.go b/coredns_plugin/querylog_file.go deleted file mode 100644 index a36812c2..00000000 --- a/coredns_plugin/querylog_file.go +++ /dev/null @@ -1,291 +0,0 @@ -package dnsfilter - -import ( - "bytes" - "compress/gzip" - "encoding/json" - "fmt" - "log" - "os" - "sync" - "time" - - "github.com/go-test/deep" -) - -var ( - fileWriteLock sync.Mutex -) - -const enableGzip = false - -func flushToFile(buffer []*logEntry) error { - if len(buffer) == 0 { - return nil - } - start := time.Now() - - var b bytes.Buffer - e := json.NewEncoder(&b) - for _, entry := range buffer { - err := e.Encode(entry) - if err != nil { - log.Printf("Failed to marshal entry: %s", err) - return err - } - } - - elapsed := time.Since(start) - log.Printf("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer))) - - err := checkBuffer(buffer, b) - if err != nil { - log.Printf("failed to check buffer: %s", err) - return err - } - - var zb bytes.Buffer - filename := queryLogFileName - - // gzip enabled? - if enableGzip { - filename += ".gz" - - zw := gzip.NewWriter(&zb) - zw.Name = queryLogFileName - zw.ModTime = time.Now() - - _, err = zw.Write(b.Bytes()) - if err != nil { - log.Printf("Couldn't compress to gzip: %s", err) - zw.Close() - return err - } - - if err = zw.Close(); err != nil { - log.Printf("Couldn't close gzip writer: %s", err) - return err - } - } else { - zb = b - } - - fileWriteLock.Lock() - defer fileWriteLock.Unlock() - f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) - if err != nil { - log.Printf("failed to create file \"%s\": %s", filename, err) - return err - } - defer f.Close() - - n, err := f.Write(zb.Bytes()) - if err != nil { - log.Printf("Couldn't write to file: %s", err) - return err - } - - log.Printf("ok \"%s\": %v bytes written", filename, n) - - return nil -} - -func checkBuffer(buffer []*logEntry, b bytes.Buffer) error { - l := len(buffer) - d := json.NewDecoder(&b) - - i := 0 - for d.More() { - entry := &logEntry{} - err := d.Decode(entry) - if err != nil { - log.Printf("Failed to decode: %s", err) - return err - } - if diff := deep.Equal(entry, buffer[i]); diff != nil { - log.Printf("decoded buffer differs: %s", diff) - return fmt.Errorf("decoded buffer differs: %s", diff) - } - i++ - } - if i != l { - err := fmt.Errorf("check fail: %d vs %d entries", l, i) - log.Print(err) - return err - } - log.Printf("check ok: %d entries", i) - - return nil -} - -func rotateQueryLog() error { - from := queryLogFileName - to := queryLogFileName + ".1" - - if enableGzip { - from = queryLogFileName + ".gz" - to = queryLogFileName + ".gz.1" - } - - if _, err := os.Stat(from); os.IsNotExist(err) { - // do nothing, file doesn't exist - return nil - } - - err := os.Rename(from, to) - if err != nil { - log.Printf("Failed to rename querylog: %s", err) - return err - } - - log.Printf("Rotated from %s to %s successfully", from, to) - - return nil -} - -func periodicQueryLogRotate() { - for range time.Tick(queryLogRotationPeriod) { - err := rotateQueryLog() - if err != nil { - log.Printf("Failed to rotate querylog: %s", err) - // do nothing, continue rotating - } - } -} - -func genericLoader(onEntry func(entry *logEntry) error, needMore func() bool, timeWindow time.Duration) error { - now := time.Now() - // read from querylog files, try newest file first - files := []string{} - - if enableGzip { - files = []string{ - queryLogFileName + ".gz", - queryLogFileName + ".gz.1", - } - } else { - files = []string{ - queryLogFileName, - queryLogFileName + ".1", - } - } - - // read from all files - for _, file := range files { - if !needMore() { - break - } - if _, err := os.Stat(file); os.IsNotExist(err) { - // do nothing, file doesn't exist - continue - } - - f, err := os.Open(file) - if err != nil { - log.Printf("Failed to open file \"%s\": %s", file, err) - // try next file - continue - } - defer f.Close() - - var d *json.Decoder - - if enableGzip { - trace("Creating gzip reader") - zr, err := gzip.NewReader(f) - if err != nil { - log.Printf("Failed to create gzip reader: %s", err) - continue - } - defer zr.Close() - - trace("Creating json decoder") - d = json.NewDecoder(zr) - } else { - d = json.NewDecoder(f) - } - - i := 0 - over := 0 - max := 10000 * time.Second - var sum time.Duration - // entries on file are in oldest->newest order - // we want maxLen newest - for d.More() { - if !needMore() { - break - } - var entry logEntry - err := d.Decode(&entry) - if err != nil { - log.Printf("Failed to decode: %s", err) - // next entry can be fine, try more - continue - } - - if now.Sub(entry.Time) > timeWindow { - // trace("skipping entry") // debug logging - continue - } - - if entry.Elapsed > max { - over++ - } else { - sum += entry.Elapsed - } - - i++ - err = onEntry(&entry) - if err != nil { - return err - } - } - elapsed := time.Since(now) - var perunit time.Duration - var avg time.Duration - if i > 0 { - perunit = elapsed / time.Duration(i) - avg = sum / time.Duration(i) - } - log.Printf("file \"%s\": read %d entries in %v, %v/entry, %v over %v, %v avg", file, i, elapsed, perunit, over, max, avg) - } - return nil -} - -func appendFromLogFile(values []*logEntry, maxLen int, timeWindow time.Duration) []*logEntry { - a := []*logEntry{} - - onEntry := func(entry *logEntry) error { - a = append(a, entry) - if len(a) > maxLen { - toskip := len(a) - maxLen - a = a[toskip:] - } - return nil - } - - needMore := func() bool { - return true - } - - err := genericLoader(onEntry, needMore, timeWindow) - if err != nil { - log.Printf("Failed to load entries from querylog: %s", err) - return values - } - - // now that we've read all eligible entries, reverse the slice to make it go from newest->oldest - for left, right := 0, len(a)-1; left < right; left, right = left+1, right-1 { - a[left], a[right] = a[right], a[left] - } - - // append it to values - values = append(values, a...) - - // then cut off of it is bigger than maxLen - if len(values) > maxLen { - values = values[:maxLen] - } - - return values -} diff --git a/coredns_plugin/querylog_top.go b/coredns_plugin/querylog_top.go deleted file mode 100644 index d4cc6e0d..00000000 --- a/coredns_plugin/querylog_top.go +++ /dev/null @@ -1,386 +0,0 @@ -package dnsfilter - -import ( - "bytes" - "fmt" - "log" - "net/http" - "os" - "path" - "runtime" - "sort" - "strconv" - "strings" - "sync" - "time" - - "github.com/AdguardTeam/AdGuardHome/dnsfilter" - "github.com/bluele/gcache" - "github.com/miekg/dns" -) - -type hourTop struct { - domains gcache.Cache - blocked gcache.Cache - clients gcache.Cache - - mutex sync.RWMutex -} - -func (top *hourTop) init() { - top.domains = gcache.New(queryLogTopSize).LRU().Build() - top.blocked = gcache.New(queryLogTopSize).LRU().Build() - top.clients = gcache.New(queryLogTopSize).LRU().Build() -} - -type dayTop struct { - hours []*hourTop - hoursLock sync.RWMutex // writelock this lock ONLY WHEN rotating or intializing hours! - - loaded bool - loadedLock sync.Mutex -} - -var runningTop dayTop - -func init() { - runningTop.hoursWriteLock() - for i := 0; i < 24; i++ { - hour := hourTop{} - hour.init() - runningTop.hours = append(runningTop.hours, &hour) - } - runningTop.hoursWriteUnlock() -} - -func rotateHourlyTop() { - log.Printf("Rotating hourly top") - hour := &hourTop{} - hour.init() - runningTop.hoursWriteLock() - runningTop.hours = append([]*hourTop{hour}, runningTop.hours...) - runningTop.hours = runningTop.hours[:24] - runningTop.hoursWriteUnlock() -} - -func periodicHourlyTopRotate() { - t := time.Hour - for range time.Tick(t) { - rotateHourlyTop() - } -} - -func (top *hourTop) incrementValue(key string, cache gcache.Cache) error { - top.Lock() - defer top.Unlock() - ivalue, err := cache.Get(key) - if err == gcache.KeyNotFoundError { - // we just set it and we're done - err = cache.Set(key, 1) - if err != nil { - log.Printf("Failed to set hourly top value: %s", err) - return err - } - return nil - } - - if err != nil { - log.Printf("gcache encountered an error during get: %s", err) - return err - } - - cachedValue, ok := ivalue.(int) - if !ok { - err = fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue) - log.Println(err) - return err - } - - err = cache.Set(key, cachedValue+1) - if err != nil { - log.Printf("Failed to set hourly top value: %s", err) - return err - } - return nil -} - -func (top *hourTop) incrementDomains(key string) error { - return top.incrementValue(key, top.domains) -} - -func (top *hourTop) incrementBlocked(key string) error { - return top.incrementValue(key, top.blocked) -} - -func (top *hourTop) incrementClients(key string) error { - return top.incrementValue(key, top.clients) -} - -// if does not exist -- return 0 -func (top *hourTop) lockedGetValue(key string, cache gcache.Cache) (int, error) { - ivalue, err := cache.Get(key) - if err == gcache.KeyNotFoundError { - return 0, nil - } - - if err != nil { - log.Printf("gcache encountered an error during get: %s", err) - return 0, err - } - - value, ok := ivalue.(int) - if !ok { - err := fmt.Errorf("SHOULD NOT HAPPEN: gcache has non-int as value: %v", ivalue) - log.Println(err) - return 0, err - } - - return value, nil -} - -func (top *hourTop) lockedGetDomains(key string) (int, error) { - return top.lockedGetValue(key, top.domains) -} - -func (top *hourTop) lockedGetBlocked(key string) (int, error) { - return top.lockedGetValue(key, top.blocked) -} - -func (top *hourTop) lockedGetClients(key string) (int, error) { - return top.lockedGetValue(key, top.clients) -} - -func (r *dayTop) addEntry(entry *logEntry, q *dns.Msg, now time.Time) error { - // figure out which hour bucket it belongs to - hour := int(now.Sub(entry.Time).Hours()) - if hour >= 24 { - log.Printf("t %v is >24 hours ago, ignoring", entry.Time) - return nil - } - - hostname := strings.ToLower(strings.TrimSuffix(q.Question[0].Name, ".")) - - // get value, if not set, crate one - runningTop.hoursReadLock() - defer runningTop.hoursReadUnlock() - err := runningTop.hours[hour].incrementDomains(hostname) - if err != nil { - log.Printf("Failed to increment value: %s", err) - return err - } - - if entry.Result.IsFiltered { - err := runningTop.hours[hour].incrementBlocked(hostname) - if err != nil { - log.Printf("Failed to increment value: %s", err) - return err - } - } - - if len(entry.IP) > 0 { - err := runningTop.hours[hour].incrementClients(entry.IP) - if err != nil { - log.Printf("Failed to increment value: %s", err) - return err - } - } - - return nil -} - -func fillStatsFromQueryLog() error { - now := time.Now() - runningTop.loadedWriteLock() - defer runningTop.loadedWriteUnlock() - if runningTop.loaded { - return nil - } - onEntry := func(entry *logEntry) error { - if len(entry.Question) == 0 { - log.Printf("entry question is absent, skipping") - return nil - } - - if entry.Time.After(now) { - log.Printf("t %v vs %v is in the future, ignoring", entry.Time, now) - return nil - } - - q := new(dns.Msg) - if err := q.Unpack(entry.Question); err != nil { - log.Printf("failed to unpack dns message question: %s", err) - return err - } - - if len(q.Question) != 1 { - log.Printf("malformed dns message, has no questions, skipping") - return nil - } - - err := runningTop.addEntry(entry, q, now) - if err != nil { - log.Printf("Failed to add entry to running top: %s", err) - return err - } - - queryLogLock.Lock() - queryLogCache = append(queryLogCache, entry) - if len(queryLogCache) > queryLogSize { - toremove := len(queryLogCache) - queryLogSize - queryLogCache = queryLogCache[toremove:] - } - queryLogLock.Unlock() - - requests.IncWithTime(entry.Time) - if entry.Result.IsFiltered { - filtered.IncWithTime(entry.Time) - } - switch entry.Result.Reason { - case dnsfilter.NotFilteredWhiteList: - whitelisted.IncWithTime(entry.Time) - case dnsfilter.NotFilteredError: - errorsTotal.IncWithTime(entry.Time) - case dnsfilter.FilteredBlackList: - filteredLists.IncWithTime(entry.Time) - case dnsfilter.FilteredSafeBrowsing: - filteredSafebrowsing.IncWithTime(entry.Time) - case dnsfilter.FilteredParental: - filteredParental.IncWithTime(entry.Time) - case dnsfilter.FilteredInvalid: - // do nothing - case dnsfilter.FilteredSafeSearch: - safesearch.IncWithTime(entry.Time) - } - elapsedTime.ObserveWithTime(entry.Elapsed.Seconds(), entry.Time) - - return nil - } - - needMore := func() bool { return true } - err := genericLoader(onEntry, needMore, queryLogTimeLimit) - if err != nil { - log.Printf("Failed to load entries from querylog: %s", err) - return err - } - - runningTop.loaded = true - - return nil -} - -func HandleStatsTop(w http.ResponseWriter, r *http.Request) { - domains := map[string]int{} - blocked := map[string]int{} - clients := map[string]int{} - - do := func(keys []interface{}, getter func(key string) (int, error), result map[string]int) { - for _, ikey := range keys { - key, ok := ikey.(string) - if !ok { - continue - } - value, err := getter(key) - if err != nil { - log.Printf("Failed to get top domains value for %v: %s", key, err) - return - } - result[key] += value - } - } - - runningTop.hoursReadLock() - for hour := 0; hour < 24; hour++ { - runningTop.hours[hour].RLock() - do(runningTop.hours[hour].domains.Keys(), runningTop.hours[hour].lockedGetDomains, domains) - do(runningTop.hours[hour].blocked.Keys(), runningTop.hours[hour].lockedGetBlocked, blocked) - do(runningTop.hours[hour].clients.Keys(), runningTop.hours[hour].lockedGetClients, clients) - runningTop.hours[hour].RUnlock() - } - runningTop.hoursReadUnlock() - - // use manual json marshalling because we want maps to be sorted by value - json := bytes.Buffer{} - json.WriteString("{\n") - - gen := func(json *bytes.Buffer, name string, top map[string]int, addComma bool) { - json.WriteString(" ") - json.WriteString(fmt.Sprintf("%q", name)) - json.WriteString(": {\n") - sorted := sortByValue(top) - // no more than 50 entries - if len(sorted) > 50 { - sorted = sorted[:50] - } - for i, key := range sorted { - json.WriteString(" ") - json.WriteString(fmt.Sprintf("%q", key)) - json.WriteString(": ") - json.WriteString(strconv.Itoa(top[key])) - if i+1 != len(sorted) { - json.WriteByte(',') - } - json.WriteByte('\n') - } - json.WriteString(" }") - if addComma { - json.WriteByte(',') - } - json.WriteByte('\n') - } - gen(&json, "top_queried_domains", domains, true) - gen(&json, "top_blocked_domains", blocked, true) - gen(&json, "top_clients", clients, true) - json.WriteString(" \"stats_period\": \"24 hours\"\n") - json.WriteString("}\n") - - w.Header().Set("Content-Type", "application/json") - _, err := w.Write(json.Bytes()) - if err != nil { - errortext := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errortext) - http.Error(w, errortext, http.StatusInternalServerError) - } -} - -// helper function for querylog API -func sortByValue(m map[string]int) []string { - type kv struct { - k string - v int - } - var ss []kv - for k, v := range m { - ss = append(ss, kv{k, v}) - } - sort.Slice(ss, func(l, r int) bool { - return ss[l].v > ss[r].v - }) - - sorted := []string{} - for _, v := range ss { - sorted = append(sorted, v.k) - } - return sorted -} - -func (d *dayTop) hoursWriteLock() { tracelock(); d.hoursLock.Lock() } -func (d *dayTop) hoursWriteUnlock() { tracelock(); d.hoursLock.Unlock() } -func (d *dayTop) hoursReadLock() { tracelock(); d.hoursLock.RLock() } -func (d *dayTop) hoursReadUnlock() { tracelock(); d.hoursLock.RUnlock() } -func (d *dayTop) loadedWriteLock() { tracelock(); d.loadedLock.Lock() } -func (d *dayTop) loadedWriteUnlock() { tracelock(); d.loadedLock.Unlock() } - -func (h *hourTop) Lock() { tracelock(); h.mutex.Lock() } -func (h *hourTop) RLock() { tracelock(); h.mutex.RLock() } -func (h *hourTop) RUnlock() { tracelock(); h.mutex.RUnlock() } -func (h *hourTop) Unlock() { tracelock(); h.mutex.Unlock() } - -func tracelock() { - if false { // not commented out to make code checked during compilation - pc := make([]uintptr, 10) // at least 1 entry needed - runtime.Callers(2, pc) - f := path.Base(runtime.FuncForPC(pc[1]).Name()) - lockf := path.Base(runtime.FuncForPC(pc[0]).Name()) - fmt.Fprintf(os.Stderr, "%s(): %s\n", f, lockf) - } -} diff --git a/coredns_plugin/ratelimit/ratelimit.go b/coredns_plugin/ratelimit/ratelimit.go deleted file mode 100644 index 8d3eeecc..00000000 --- a/coredns_plugin/ratelimit/ratelimit.go +++ /dev/null @@ -1,182 +0,0 @@ -package ratelimit - -import ( - "errors" - "log" - "sort" - "strconv" - "time" - - // ratelimiting and per-ip buckets - "github.com/beefsack/go-rate" - "github.com/patrickmn/go-cache" - - // coredns plugin - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/metrics" - "github.com/coredns/coredns/plugin/pkg/dnstest" - "github.com/coredns/coredns/request" - "github.com/mholt/caddy" - "github.com/miekg/dns" - "github.com/prometheus/client_golang/prometheus" - "golang.org/x/net/context" -) - -const defaultRatelimit = 30 -const defaultResponseSize = 1000 - -var ( - tokenBuckets = cache.New(time.Hour, time.Hour) -) - -// ServeDNS handles the DNS request and refuses if it's an beyind specified ratelimit -func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - state := request.Request{W: w, Req: r} - ip := state.IP() - allow, err := p.allowRequest(ip) - if err != nil { - return 0, err - } - if !allow { - ratelimited.Inc() - return 0, nil - } - - // Record response to get status code and size of the reply. - rw := dnstest.NewRecorder(w) - status, err := plugin.NextOrFailure(p.Name(), p.Next, ctx, rw, r) - - size := rw.Len - - if size > defaultResponseSize && state.Proto() == "udp" { - // For large UDP responses we call allowRequest more times - // The exact number of times depends on the response size - for i := 0; i < size/defaultResponseSize; i++ { - p.allowRequest(ip) - } - } - - return status, err -} - -func (p *plug) allowRequest(ip string) (bool, error) { - if len(p.whitelist) > 0 { - i := sort.SearchStrings(p.whitelist, ip) - - if i < len(p.whitelist) && p.whitelist[i] == ip { - return true, nil - } - } - - if _, found := tokenBuckets.Get(ip); !found { - tokenBuckets.Set(ip, rate.New(p.ratelimit, time.Second), time.Hour) - } - - value, found := tokenBuckets.Get(ip) - if !found { - // should not happen since we've just inserted it - text := "SHOULD NOT HAPPEN: just-inserted ratelimiter disappeared" - log.Println(text) - err := errors.New(text) - return true, err - } - - rl, ok := value.(*rate.RateLimiter) - if !ok { - text := "SHOULD NOT HAPPEN: non-bool entry found in safebrowsing lookup cache" - log.Println(text) - err := errors.New(text) - return true, err - } - - allow, _ := rl.Try() - return allow, nil -} - -// -// helper functions -// -func init() { - caddy.RegisterPlugin("ratelimit", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -type plug struct { - Next plugin.Handler - - // configuration for creating above - ratelimit int // in requests per second per IP - whitelist []string // a list of whitelisted IP addresses -} - -func setupPlugin(c *caddy.Controller) (*plug, error) { - p := &plug{ratelimit: defaultRatelimit} - - for c.Next() { - args := c.RemainingArgs() - if len(args) > 0 { - ratelimit, err := strconv.Atoi(args[0]) - if err != nil { - return nil, c.ArgErr() - } - p.ratelimit = ratelimit - } - for c.NextBlock() { - switch c.Val() { - case "whitelist": - p.whitelist = c.RemainingArgs() - - if len(p.whitelist) > 0 { - sort.Strings(p.whitelist) - } - } - } - } - - return p, nil -} - -func setup(c *caddy.Controller) error { - p, err := setupPlugin(c) - if err != nil { - return err - } - - config := dnsserver.GetConfig(c) - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnStartup(func() error { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return nil - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(ratelimited) - } - return nil - }) - - return nil -} - -func newDNSCounter(name string, help string) prometheus.Counter { - return prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "ratelimit", - Name: name, - Help: help, - }) -} - -var ( - ratelimited = newDNSCounter("dropped_total", "Count of requests that have been dropped because of rate limit") -) - -// Name returns name of the plugin as seen in Corefile and plugin.cfg -func (p *plug) Name() string { return "ratelimit" } diff --git a/coredns_plugin/ratelimit/ratelimit_test.go b/coredns_plugin/ratelimit/ratelimit_test.go deleted file mode 100644 index b426f2eb..00000000 --- a/coredns_plugin/ratelimit/ratelimit_test.go +++ /dev/null @@ -1,80 +0,0 @@ -package ratelimit - -import ( - "testing" - - "github.com/mholt/caddy" -) - -func TestSetup(t *testing.T) { - for i, testcase := range []struct { - config string - failing bool - }{ - {`ratelimit`, false}, - {`ratelimit 100`, false}, - {`ratelimit { - whitelist 127.0.0.1 - }`, false}, - {`ratelimit 50 { - whitelist 127.0.0.1 176.103.130.130 - }`, false}, - {`ratelimit test`, true}, - } { - c := caddy.NewTestController("dns", testcase.config) - err := setup(c) - if err != nil { - if !testcase.failing { - t.Fatalf("Test #%d expected no errors, but got: %v", i, err) - } - continue - } - if testcase.failing { - t.Fatalf("Test #%d expected to fail but it didn't", i) - } - } -} - -func TestRatelimiting(t *testing.T) { - // rate limit is 1 per sec - c := caddy.NewTestController("dns", `ratelimit 1`) - p, err := setupPlugin(c) - - if err != nil { - t.Fatal("Failed to initialize the plugin") - } - - allowed, err := p.allowRequest("127.0.0.1") - - if err != nil || !allowed { - t.Fatal("First request must have been allowed") - } - - allowed, err = p.allowRequest("127.0.0.1") - - if err != nil || allowed { - t.Fatal("Second request must have been ratelimited") - } -} - -func TestWhitelist(t *testing.T) { - // rate limit is 1 per sec - c := caddy.NewTestController("dns", `ratelimit 1 { whitelist 127.0.0.2 127.0.0.1 127.0.0.125 }`) - p, err := setupPlugin(c) - - if err != nil { - t.Fatal("Failed to initialize the plugin") - } - - allowed, err := p.allowRequest("127.0.0.1") - - if err != nil || !allowed { - t.Fatal("First request must have been allowed") - } - - allowed, err = p.allowRequest("127.0.0.1") - - if err != nil || !allowed { - t.Fatal("Second request must have been allowed due to whitelist") - } -} diff --git a/coredns_plugin/refuseany/refuseany.go b/coredns_plugin/refuseany/refuseany.go deleted file mode 100644 index 92d5d508..00000000 --- a/coredns_plugin/refuseany/refuseany.go +++ /dev/null @@ -1,91 +0,0 @@ -package refuseany - -import ( - "fmt" - "log" - - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/coredns/coredns/plugin/metrics" - "github.com/coredns/coredns/request" - "github.com/mholt/caddy" - "github.com/miekg/dns" - "github.com/prometheus/client_golang/prometheus" - "golang.org/x/net/context" -) - -type plug struct { - Next plugin.Handler -} - -// ServeDNS handles the DNS request and refuses if it's an ANY request -func (p *plug) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - if len(r.Question) != 1 { - // google DNS, bind and others do the same - return dns.RcodeFormatError, fmt.Errorf("Got DNS request with != 1 questions") - } - - q := r.Question[0] - if q.Qtype == dns.TypeANY { - state := request.Request{W: w, Req: r, Context: ctx} - rcode := dns.RcodeNotImplemented - - m := new(dns.Msg) - m.SetRcode(r, rcode) - state.SizeAndDo(m) - err := state.W.WriteMsg(m) - if err != nil { - log.Printf("Got error %s\n", err) - return dns.RcodeServerFailure, err - } - return rcode, nil - } - - return plugin.NextOrFailure(p.Name(), p.Next, ctx, w, r) -} - -func init() { - caddy.RegisterPlugin("refuseany", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -func setup(c *caddy.Controller) error { - p := &plug{} - config := dnsserver.GetConfig(c) - - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnStartup(func() error { - m := dnsserver.GetConfig(c).Handler("prometheus") - if m == nil { - return nil - } - if x, ok := m.(*metrics.Metrics); ok { - x.MustRegister(ratelimited) - } - return nil - }) - - return nil -} - -func newDNSCounter(name string, help string) prometheus.Counter { - return prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: plugin.Namespace, - Subsystem: "refuseany", - Name: name, - Help: help, - }) -} - -var ( - ratelimited = newDNSCounter("refusedany_total", "Count of ANY requests that have been dropped") -) - -// Name returns name of the plugin as seen in Corefile and plugin.cfg -func (p *plug) Name() string { return "refuseany" } diff --git a/coredns_plugin/reload.go b/coredns_plugin/reload.go deleted file mode 100644 index 880a3acc..00000000 --- a/coredns_plugin/reload.go +++ /dev/null @@ -1,36 +0,0 @@ -package dnsfilter - -import ( - "log" - - "github.com/mholt/caddy" -) - -var Reload = make(chan bool) - -func hook(event caddy.EventName, info interface{}) error { - if event != caddy.InstanceStartupEvent { - return nil - } - - // this should be an instance. ok to panic if not - instance := info.(*caddy.Instance) - - go func() { - for range Reload { - corefile, err := caddy.LoadCaddyfile(instance.Caddyfile().ServerType()) - if err != nil { - continue - } - _, err = instance.Restart(corefile) - if err != nil { - log.Printf("Corefile changed but reload failed: %s", err) - continue - } - // hook will be called again from new instance - return - } - }() - - return nil -} diff --git a/go.mod b/go.mod index 1b8d78e6..166e3cce 100644 --- a/go.mod +++ b/go.mod @@ -3,35 +3,19 @@ module github.com/AdguardTeam/AdGuardHome require ( github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f // indirect github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 - github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 // indirect github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7 - github.com/coredns/coredns v1.2.6 - github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11 // indirect - github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710 // indirect - github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect github.com/go-ole/go-ole v1.2.1 // indirect github.com/go-test/deep v1.0.1 github.com/gobuffalo/packr v1.19.0 - github.com/google/uuid v1.0.0 // indirect - github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 // indirect github.com/joomcode/errorx v0.1.0 - github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect - github.com/mholt/caddy v0.11.0 github.com/miekg/dns v1.0.15 - github.com/opentracing/opentracing-go v1.0.2 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/pkg/errors v0.8.0 - github.com/prometheus/client_golang v0.9.0-pre1 - github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 // indirect - github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949 // indirect - github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d // indirect github.com/shirou/gopsutil v2.18.10+incompatible github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 // indirect go.uber.org/goleak v0.10.0 golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd golang.org/x/net v0.0.0-20181108082009-03003ca0c849 golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8 // indirect - google.golang.org/grpc v1.16.0 // indirect gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477 gopkg.in/yaml.v2 v2.2.1 ) diff --git a/go.sum b/go.sum index 4ecb93be..af10df24 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,11 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f h1:5ZfJxyXo8KyX8DgGXC5B7ILL8y51fci/qYz2B4j8iLY= github.com/StackExchange/wmi v0.0.0-20180725035823-b12b22c5341f/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6 h1:KXlsf+qt/X5ttPGEjR0tPH1xaWWoKBEg9Q1THAj2h3I= github.com/beefsack/go-rate v0.0.0-20180408011153-efa7637bb9b6/go.mod h1:6YNgTHLutezwnBvyneBbwvB8C82y3dcoOj5EQJIdGXA= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0= -github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7 h1:NpQ+gkFOH27AyDypSCJ/LdsIi/b4rdnEb1N5+IpFfYs= github.com/bluele/gcache v0.0.0-20171010155617-472614239ac7/go.mod h1:8c4/i2VlovMO2gBnHGQPN5EJw+H0lx1u/5p+cgsXtCk= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/coredns/coredns v1.2.6 h1:QIAOkBqVE44Zx0ttrFqgE5YhCEn64XPIngU60JyuTGM= -github.com/coredns/coredns v1.2.6/go.mod h1:zASH/MVDgR6XZTbxvOnsZfffS+31vg6Ackf/wo1+AM0= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11 h1:m8nX8hsUghn853BJ5qB0lX+VvS6LTJPksWyILFZRYN4= -github.com/dnstap/golang-dnstap v0.0.0-20170829151710-2cf77a2b5e11/go.mod h1:s1PfVYYVmTMgCSPtho4LKBDecEHJWtiVDPNv78Z985U= -github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710 h1:QdyRyGZWLEvJG5Kw3VcVJvhXJ5tZ1MkRgqpJOEZSySM= -github.com/farsightsec/golang-framestream v0.0.0-20181102145529-8a0cb8ba8710/go.mod h1:eNde4IQyEiA5br02AouhEHCu3p3UzrCdFR4LuQHklMI= -github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= -github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/go-ole/go-ole v1.2.1 h1:2lOsA72HgjxAuMlKpFiCbHTvu44PIVkZ5hqm3RSdI/E= github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8= github.com/go-test/deep v1.0.1 h1:UQhStjbkDClarlmv0am7OXXO4/GaPdCGiUiMTvi28sg= @@ -28,46 +16,21 @@ github.com/gobuffalo/packd v0.0.0-20181031195726-c82734870264 h1:roWyi0eEdiFreSq github.com/gobuffalo/packd v0.0.0-20181031195726-c82734870264/go.mod h1:Yf2toFaISlyQrr5TfO3h6DB9pl9mZRmyvBGQb/aQ/pI= github.com/gobuffalo/packr v1.19.0 h1:3UDmBDxesCOPF8iZdMDBBWKfkBoYujIMIZePnobqIUI= github.com/gobuffalo/packr v1.19.0/go.mod h1:MstrNkfCQhd5o+Ct4IJ0skWlxN8emOq8DsoT1G98VIU= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA= -github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645 h1:MJG/KsmcqMwFAkh8mTnAwhyKoB+sTAnY4CACC110tbU= -github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645/go.mod h1:6iZfnjpejD4L/4DwD7NryNaJyCQdzwWwH2MWhCA90Kw= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/joomcode/errorx v0.1.0 h1:QmJMiI1DE1UFje2aI1ZWO/VMT5a32qBoXUclGOt8vsc= github.com/joomcode/errorx v0.1.0/go.mod h1:kgco15ekB6cs+4Xjzo7SPeXzx38PbJzBwbnu9qfVNHQ= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4 h1:Mlji5gkcpzkqTROyE4ZxZ8hN7osunMb2RuGVrbvMvCc= github.com/markbates/oncer v0.0.0-20181014194634-05fccaae8fc4/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= -github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= -github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mholt/caddy v0.11.0 h1:cuhEyR7So/SBBRiAaiRBe9BoccDu6uveIPuM9FMMavg= -github.com/mholt/caddy v0.11.0/go.mod h1:Wb1PlT4DAYSqOEd03MsqkdkXnTxA8v9pKjdpxbqM1kY= github.com/miekg/dns v1.0.15 h1:9+UupePBQCG6zf1q/bGmTO1vumoG13jsrbWOSX1W6Tw= github.com/miekg/dns v1.0.15/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/opentracing/opentracing-go v1.0.2 h1:3jA2P6O1F9UOrWVpwrIo17pu01KWvNWg4X946/Y5Zwg= -github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v0.9.0-pre1 h1:AWTOhsOI9qxeirTuA0A4By/1Es1+y9EcCGY6bBZ2fhM= -github.com/prometheus/client_golang v0.9.0-pre1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8= -github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= -github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949 h1:MVbUQq1a49hMEISI29UcAUjywT3FyvDwx5up90OvVa4= -github.com/prometheus/common v0.0.0-20181109100915-0b1957f9d949/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d h1:GoAlyOgbOEIFdaDqxJVlbOQ1DtGmZWs/Qau0hIlk+WQ= -github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/shirou/gopsutil v2.18.10+incompatible h1:cy84jW6EVRPa5g9HAHrlbxMSIjBhDSX0OFYyMYminYs= github.com/shirou/gopsutil v2.18.10+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4 h1:udFKJ0aHUL60LboW/A+DfgoHVedieIzIXE8uylPue0U= @@ -82,29 +45,16 @@ go.uber.org/goleak v0.10.0 h1:G3eWbSNIskeRqtsN/1uI5B+eP73y3JUuBsv9AZjehb4= go.uber.org/goleak v0.10.0/go.mod h1:VCZuO8V8mFPlL0F5J5GK1rtHV3DrFcQ1R8ryq7FK0aI= golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd h1:VtIkGDhk0ph3t+THbvXHfMZ8QHgsBO39Nh52+74pq7w= golang.org/x/crypto v0.0.0-20181106171534-e4dc69e5b2fd/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181102091132-c10e9556a7bc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181108082009-03003ca0c849 h1:FSqE2GGG7wzsYUsWiQ8MZrvEd1EOyU3NCF0AW3Wtltg= golang.org/x/net v0.0.0-20181108082009-03003ca0c849/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8 h1:YoY1wS6JYVRpIfFngRf2HHo9R9dAne3xbkGOQ5rJXjU= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8 h1:Nw54tB0rB7hY/N0NQvRW8DG4Yk3Q6T9cu9RcFQDu1tc= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/grpc v1.16.0 h1:dz5IJGuC2BB7qXR5AyHNwAUBhZscK2xVez7mznh72sY= -google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio= gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477 h1:5xUJw+lg4zao9W4HIDzlFbMYgSgtvNVHh00MEHvbGpQ= gopkg.in/asaskevich/govalidator.v4 v4.0.0-20160518190739-766470278477/go.mod h1:QDV1vrFSrowdoOba0UM8VJPUZONT7dnfdLsM+GG53Z8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.1 h1:mUhvW9EsL+naU5Q3cakzfE91YhliOondGd6ZrsDBHQE= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/upstream/dns_upstream.go b/upstream/dns_upstream.go deleted file mode 100644 index 171f6362..00000000 --- a/upstream/dns_upstream.go +++ /dev/null @@ -1,105 +0,0 @@ -package upstream - -import ( - "crypto/tls" - "time" - - "github.com/miekg/dns" - "golang.org/x/net/context" -) - -// DnsUpstream is a very simple upstream implementation for plain DNS -type DnsUpstream struct { - endpoint string // IP:port - timeout time.Duration // Max read and write timeout - proto string // Protocol (tcp, tcp-tls, or udp) - transport *Transport // Persistent connections cache -} - -// NewDnsUpstream creates a new DNS upstream -func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) { - u := &DnsUpstream{ - endpoint: endpoint, - timeout: defaultTimeout, - proto: proto, - } - - var tlsConfig *tls.Config - - if proto == "tcp-tls" { - tlsConfig = new(tls.Config) - tlsConfig.ServerName = tlsServerName - } - - // Initialize the connections cache - u.transport = NewTransport(endpoint) - u.transport.tlsConfig = tlsConfig - u.transport.Start() - - return u, nil -} - -// Exchange provides an implementation for the Upstream interface -func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - resp, err := u.exchange(u.proto, query) - - // Retry over TCP if response is truncated - if err == dns.ErrTruncated && u.proto == "udp" { - resp, err = u.exchange("tcp", query) - } else if err == dns.ErrTruncated && resp != nil { - // Reassemble something to be sent to client - m := new(dns.Msg) - m.SetReply(query) - m.Truncated = true - m.Authoritative = true - m.Rcode = dns.RcodeSuccess - return m, nil - } - - if err != nil { - resp = &dns.Msg{} - resp.SetRcode(resp, dns.RcodeServerFailure) - } - - return resp, err -} - -// Clear resources -func (u *DnsUpstream) Close() error { - // Close active connections - u.transport.Stop() - return nil -} - -// Performs a synchronous query. It sends the message m via the conn -// c and waits for a reply. The conn c is not closed. -func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) { - // Establish a connection if needed (or reuse cached) - conn, err := u.transport.Dial(proto) - if err != nil { - return nil, err - } - - // Write the request with a timeout - conn.SetWriteDeadline(time.Now().Add(u.timeout)) - if err = conn.WriteMsg(query); err != nil { - conn.Close() // Not giving it back - return nil, err - } - - // Write response with a timeout - conn.SetReadDeadline(time.Now().Add(u.timeout)) - r, err = conn.ReadMsg() - if err != nil { - conn.Close() // Not giving it back - } else if err == nil && r.Id != query.Id { - err = dns.ErrId - conn.Close() // Not giving it back - } - - if err == nil { - // Return it back to the connections cache if there were no errors - u.transport.Yield(conn) - } - return r, err -} diff --git a/upstream/helpers.go b/upstream/helpers.go deleted file mode 100644 index 520a7a8b..00000000 --- a/upstream/helpers.go +++ /dev/null @@ -1,98 +0,0 @@ -package upstream - -import ( - "net" - "strings" - - "github.com/miekg/dns" - "golang.org/x/net/context" -) - -// Detects the upstream type from the specified url and creates a proper Upstream object -func NewUpstream(url string, bootstrap string) (Upstream, error) { - proto := "udp" - prefix := "" - - switch { - case strings.HasPrefix(url, "tcp://"): - proto = "tcp" - prefix = "tcp://" - case strings.HasPrefix(url, "tls://"): - proto = "tcp-tls" - prefix = "tls://" - case strings.HasPrefix(url, "https://"): - return NewHttpsUpstream(url, bootstrap) - } - - hostname := strings.TrimPrefix(url, prefix) - - host, port, err := net.SplitHostPort(hostname) - if err != nil { - // Set port depending on the protocol - switch proto { - case "udp": - port = "53" - case "tcp": - port = "53" - case "tcp-tls": - port = "853" - } - - // Set host = hostname - host = hostname - } - - // Try to resolve the host address (or check if it's an IP address) - bootstrapResolver := CreateResolver(bootstrap) - ips, err := bootstrapResolver.LookupIPAddr(context.Background(), host) - - if err != nil || len(ips) == 0 { - return nil, err - } - - addr := ips[0].String() - endpoint := net.JoinHostPort(addr, port) - tlsServerName := "" - - if proto == "tcp-tls" && host != addr { - // Check if we need to specify TLS server name - tlsServerName = host - } - - return NewDnsUpstream(endpoint, proto, tlsServerName) -} - -func CreateResolver(bootstrap string) *net.Resolver { - bootstrapResolver := net.DefaultResolver - - if bootstrap != "" { - bootstrapResolver = &net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - var d net.Dialer - return d.DialContext(ctx, network, bootstrap) - }, - } - } - - return bootstrapResolver -} - -// Performs a simple health-check of the specified upstream -func IsAlive(u Upstream) (bool, error) { - // Using ipv4only.arpa. domain as it is a part of DNS64 RFC and it should exist everywhere - ping := new(dns.Msg) - ping.SetQuestion("ipv4only.arpa.", dns.TypeA) - - resp, err := u.Exchange(context.Background(), ping) - - // If we got a header, we're alright, basically only care about I/O errors 'n stuff. - if err != nil && resp != nil { - // Silly check, something sane came back. - if resp.Rcode != dns.RcodeServerFailure { - err = nil - } - } - - return err == nil, err -} diff --git a/upstream/https_upstream.go b/upstream/https_upstream.go deleted file mode 100644 index d7d7bdde..00000000 --- a/upstream/https_upstream.go +++ /dev/null @@ -1,128 +0,0 @@ -package upstream - -import ( - "bytes" - "crypto/tls" - "fmt" - "io/ioutil" - "log" - "net" - "net/http" - "net/url" - "time" - - "github.com/miekg/dns" - "github.com/pkg/errors" - "golang.org/x/net/context" - "golang.org/x/net/http2" -) - -const ( - dnsMessageContentType = "application/dns-message" - defaultKeepAlive = 30 * time.Second -) - -// HttpsUpstream is the upstream implementation for DNS-over-HTTPS -type HttpsUpstream struct { - client *http.Client - endpoint *url.URL -} - -// NewHttpsUpstream creates a new DNS-over-HTTPS upstream from the specified url -func NewHttpsUpstream(endpoint string, bootstrap string) (Upstream, error) { - u, err := url.Parse(endpoint) - if err != nil { - return nil, err - } - - // Initialize bootstrap resolver - bootstrapResolver := CreateResolver(bootstrap) - dialer := &net.Dialer{ - Timeout: defaultTimeout, - KeepAlive: defaultKeepAlive, - DualStack: true, - Resolver: bootstrapResolver, - } - - // Update TLS and HTTP client configuration - tlsConfig := &tls.Config{ServerName: u.Hostname()} - transport := &http.Transport{ - TLSClientConfig: tlsConfig, - DisableCompression: true, - MaxIdleConns: 1, - DialContext: dialer.DialContext, - } - http2.ConfigureTransport(transport) - - client := &http.Client{ - Timeout: defaultTimeout, - Transport: transport, - } - - return &HttpsUpstream{client: client, endpoint: u}, nil -} - -// Exchange provides an implementation for the Upstream interface -func (u *HttpsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) { - queryBuf, err := query.Pack() - if err != nil { - return nil, errors.Wrap(err, "failed to pack DNS query") - } - - // No content negotiation for now, use DNS wire format - buf, backendErr := u.exchangeWireformat(queryBuf) - if backendErr == nil { - response := &dns.Msg{} - if err := response.Unpack(buf); err != nil { - return nil, errors.Wrap(err, "failed to unpack DNS response from body") - } - - response.Id = query.Id - return response, nil - } - - log.Printf("failed to connect to an HTTPS backend %q due to %s", u.endpoint, backendErr) - return nil, backendErr -} - -// Perform message exchange with the default UDP wireformat defined in current draft -// https://tools.ietf.org/html/draft-ietf-doh-dns-over-https-10 -func (u *HttpsUpstream) exchangeWireformat(msg []byte) ([]byte, error) { - req, err := http.NewRequest("POST", u.endpoint.String(), bytes.NewBuffer(msg)) - if err != nil { - return nil, errors.Wrap(err, "failed to create an HTTPS request") - } - - req.Header.Add("Content-Type", dnsMessageContentType) - req.Header.Add("Accept", dnsMessageContentType) - req.Host = u.endpoint.Hostname() - - resp, err := u.client.Do(req) - if err != nil { - return nil, errors.Wrap(err, "failed to perform an HTTPS request") - } - - // Check response status code - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("returned status code %d", resp.StatusCode) - } - - contentType := resp.Header.Get("Content-Type") - if contentType != dnsMessageContentType { - return nil, fmt.Errorf("return wrong content type %s", contentType) - } - - // Read application/dns-message response from the body - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, errors.Wrap(err, "failed to read the response body") - } - - return buf, nil -} - -// Clear resources -func (u *HttpsUpstream) Close() error { - return nil -} diff --git a/upstream/persistent.go b/upstream/persistent.go deleted file mode 100644 index 91cc9094..00000000 --- a/upstream/persistent.go +++ /dev/null @@ -1,210 +0,0 @@ -package upstream - -import ( - "crypto/tls" - "net" - "sort" - "sync/atomic" - "time" - - "github.com/miekg/dns" -) - -// Persistent connections cache -- almost similar to the same used in the CoreDNS forward plugin - -const ( - defaultExpire = 10 * time.Second - minDialTimeout = 100 * time.Millisecond - maxDialTimeout = 30 * time.Second - defaultDialTimeout = 30 * time.Second - cumulativeAvgWeight = 4 -) - -// a persistConn hold the dns.Conn and the last used time. -type persistConn struct { - c *dns.Conn - used time.Time -} - -// Transport hold the persistent cache. -type Transport struct { - avgDialTime int64 // kind of average time of dial time - conns map[string][]*persistConn // Buckets for udp, tcp and tcp-tls. - expire time.Duration // After this duration a connection is expired. - addr string - tlsConfig *tls.Config - - dial chan string - yield chan *dns.Conn - ret chan *dns.Conn - stop chan bool -} - -// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. -func (t *Transport) Dial(proto string) (*dns.Conn, error) { - // If tls has been configured; use it. - if t.tlsConfig != nil { - proto = "tcp-tls" - } - - t.dial <- proto - c := <-t.ret - - if c != nil { - return c, nil - } - - reqTime := time.Now() - timeout := t.dialTimeout() - if proto == "tcp-tls" { - conn, err := dns.DialTimeoutWithTLS(proto, t.addr, t.tlsConfig, timeout) - t.updateDialTimeout(time.Since(reqTime)) - return conn, err - } - conn, err := dns.DialTimeout(proto, t.addr, timeout) - t.updateDialTimeout(time.Since(reqTime)) - return conn, err -} - -// Yield return the connection to transport for reuse. -func (t *Transport) Yield(c *dns.Conn) { t.yield <- c } - -// Start starts the transport's connection manager. -func (t *Transport) Start() { go t.connManager() } - -// Stop stops the transport's connection manager. -func (t *Transport) Stop() { close(t.stop) } - -// SetExpire sets the connection expire time in transport. -func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } - -// SetTLSConfig sets the TLS config in transport. -func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } - -func NewTransport(addr string) *Transport { - t := &Transport{ - avgDialTime: int64(defaultDialTimeout / 2), - conns: make(map[string][]*persistConn), - expire: defaultExpire, - addr: addr, - dial: make(chan string), - yield: make(chan *dns.Conn), - ret: make(chan *dns.Conn), - stop: make(chan bool), - } - return t -} - -func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { - dt := time.Duration(atomic.LoadInt64(currentAvg)) - atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) -} - -func (t *Transport) dialTimeout() time.Duration { - return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) -} - -func (t *Transport) updateDialTimeout(newDialTime time.Duration) { - averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) -} - -// limitTimeout is a utility function to auto-tune timeout values -// average observed time is moved towards the last observed delay moderated by a weight -// next timeout to use will be the double of the computed average, limited by min and max frame. -func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { - rt := time.Duration(atomic.LoadInt64(currentAvg)) - if rt < minValue { - return minValue - } - if rt < maxValue/2 { - return 2 * rt - } - return maxValue -} - -// connManagers manages the persistent connection cache for UDP and TCP. -func (t *Transport) connManager() { - ticker := time.NewTicker(t.expire) -Wait: - for { - select { - case proto := <-t.dial: - // take the last used conn - complexity O(1) - if stack := t.conns[proto]; len(stack) > 0 { - pc := stack[len(stack)-1] - if time.Since(pc.used) < t.expire { - // Found one, remove from pool and return this conn. - t.conns[proto] = stack[:len(stack)-1] - t.ret <- pc.c - continue Wait - } - // clear entire cache if the last conn is expired - t.conns[proto] = nil - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack) - } - - t.ret <- nil - - case conn := <-t.yield: - - // no proto here, infer from config and conn - if _, ok := conn.Conn.(*net.UDPConn); ok { - t.conns["udp"] = append(t.conns["udp"], &persistConn{conn, time.Now()}) - continue Wait - } - - if t.tlsConfig == nil { - t.conns["tcp"] = append(t.conns["tcp"], &persistConn{conn, time.Now()}) - continue Wait - } - - t.conns["tcp-tls"] = append(t.conns["tcp-tls"], &persistConn{conn, time.Now()}) - - case <-ticker.C: - t.cleanup(false) - - case <-t.stop: - t.cleanup(true) - close(t.ret) - return - } - } -} - -// closeConns closes connections. -func closeConns(conns []*persistConn) { - for _, pc := range conns { - pc.c.Close() - } -} - -// cleanup removes connections from cache. -func (t *Transport) cleanup(all bool) { - staleTime := time.Now().Add(-t.expire) - for proto, stack := range t.conns { - if len(stack) == 0 { - continue - } - if all { - t.conns[proto] = nil - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack) - continue - } - if stack[0].used.After(staleTime) { - continue - } - - // connections in stack are sorted by "used" - good := sort.Search(len(stack), func(i int) bool { - return stack[i].used.After(staleTime) - }) - t.conns[proto] = stack[good:] - // now, the connections being passed to closeConns() are not reachable from - // transport methods anymore. So, it's safe to close them in a separate goroutine - go closeConns(stack[:good]) - } -} diff --git a/upstream/setup.go b/upstream/setup.go deleted file mode 100644 index 4aed6bcf..00000000 --- a/upstream/setup.go +++ /dev/null @@ -1,81 +0,0 @@ -package upstream - -import ( - "log" - - "github.com/coredns/coredns/core/dnsserver" - "github.com/coredns/coredns/plugin" - "github.com/mholt/caddy" -) - -func init() { - caddy.RegisterPlugin("upstream", caddy.Plugin{ - ServerType: "dns", - Action: setup, - }) -} - -// Read the configuration and initialize upstreams -func setup(c *caddy.Controller) error { - p, err := setupPlugin(c) - if err != nil { - return err - } - config := dnsserver.GetConfig(c) - config.AddPlugin(func(next plugin.Handler) plugin.Handler { - p.Next = next - return p - }) - - c.OnShutdown(p.onShutdown) - return nil -} - -// Read the configuration -func setupPlugin(c *caddy.Controller) (*UpstreamPlugin, error) { - p := New() - - log.Println("Initializing the Upstream plugin") - - bootstrap := "" - upstreamUrls := []string{} - for c.Next() { - args := c.RemainingArgs() - if len(args) > 0 { - upstreamUrls = append(upstreamUrls, args...) - } - for c.NextBlock() { - switch c.Val() { - case "bootstrap": - if !c.NextArg() { - return nil, c.ArgErr() - } - bootstrap = c.Val() - } - } - } - - for _, url := range upstreamUrls { - u, err := NewUpstream(url, bootstrap) - if err != nil { - log.Printf("Cannot initialize upstream %s", url) - return nil, err - } - - p.Upstreams = append(p.Upstreams, u) - } - - return p, nil -} - -func (p *UpstreamPlugin) onShutdown() error { - for i := range p.Upstreams { - u := p.Upstreams[i] - err := u.Close() - if err != nil { - log.Printf("Error while closing the upstream: %s", err) - } - } - - return nil -} diff --git a/upstream/setup_test.go b/upstream/setup_test.go deleted file mode 100644 index 82b8ab5c..00000000 --- a/upstream/setup_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package upstream - -import ( - "testing" - - "github.com/mholt/caddy" -) - -func TestSetup(t *testing.T) { - var tests = []struct { - config string - }{ - {`upstream 8.8.8.8`}, - {`upstream 8.8.8.8 { - bootstrap 8.8.8.8:53 -}`}, - {`upstream tls://1.1.1.1 8.8.8.8 { - bootstrap 1.1.1.1 -}`}, - } - - for _, test := range tests { - c := caddy.NewTestController("dns", test.config) - err := setup(c) - if err != nil { - t.Fatalf("Test failed") - } - } -} diff --git a/upstream/upstream.go b/upstream/upstream.go deleted file mode 100644 index faef224e..00000000 --- a/upstream/upstream.go +++ /dev/null @@ -1,57 +0,0 @@ -package upstream - -import ( - "time" - - "github.com/coredns/coredns/plugin" - "github.com/miekg/dns" - "github.com/pkg/errors" - "golang.org/x/net/context" -) - -const ( - defaultTimeout = 5 * time.Second -) - -// Upstream is a simplified interface for proxy destination -type Upstream interface { - Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) - Close() error -} - -// UpstreamPlugin is a simplified DNS proxy using a generic upstream interface -type UpstreamPlugin struct { - Upstreams []Upstream - Next plugin.Handler -} - -// Initialize the upstream plugin -func New() *UpstreamPlugin { - p := &UpstreamPlugin{ - Upstreams: []Upstream{}, - } - - return p -} - -// ServeDNS implements interface for CoreDNS plugin -func (p *UpstreamPlugin) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { - var reply *dns.Msg - var backendErr error - - for i := range p.Upstreams { - upstream := p.Upstreams[i] - reply, backendErr = upstream.Exchange(ctx, r) - if backendErr == nil { - w.WriteMsg(reply) - return 0, nil - } - } - - return dns.RcodeServerFailure, errors.Wrap(backendErr, "failed to contact any of the upstreams") -} - -// Name implements interface for CoreDNS plugin -func (p *UpstreamPlugin) Name() string { - return "upstream" -} diff --git a/upstream/upstream_test.go b/upstream/upstream_test.go deleted file mode 100644 index 9221e6f5..00000000 --- a/upstream/upstream_test.go +++ /dev/null @@ -1,187 +0,0 @@ -package upstream - -import ( - "net" - "testing" - - "github.com/miekg/dns" - "golang.org/x/net/context" -) - -func TestDnsUpstreamIsAlive(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"8.8.8.8:53", "8.8.8.8:53"}, - {"1.1.1.1", ""}, - {"tcp://1.1.1.1:53", ""}, - {"176.103.130.130:5353", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS upstream") - } - - testUpstreamIsAlive(t, u) - } -} - -func TestHttpsUpstreamIsAlive(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, - {"https://dns.google.com/experimental", "8.8.8.8:53"}, - {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-HTTPS upstream") - } - - testUpstreamIsAlive(t, u) - } -} - -func TestDnsOverTlsIsAlive(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"tls://1.1.1.1", ""}, - {"tls://9.9.9.9:853", ""}, - {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, - {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-TLS upstream") - } - - testUpstreamIsAlive(t, u) - } -} - -func TestDnsUpstream(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"8.8.8.8:53", "8.8.8.8:53"}, - {"1.1.1.1", ""}, - {"tcp://1.1.1.1:53", ""}, - {"176.103.130.130:5353", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS upstream") - } - - testUpstream(t, u) - } -} - -func TestHttpsUpstream(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"https://cloudflare-dns.com/dns-query", "8.8.8.8:53"}, - {"https://dns.google.com/experimental", "8.8.8.8:53"}, - {"https://doh.cleanbrowsing.org/doh/security-filter/", ""}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-HTTPS upstream") - } - - testUpstream(t, u) - } -} - -func TestDnsOverTlsUpstream(t *testing.T) { - var tests = []struct { - url string - bootstrap string - }{ - {"tls://1.1.1.1", ""}, - {"tls://9.9.9.9:853", ""}, - {"tls://security-filter-dns.cleanbrowsing.org", "8.8.8.8:53"}, - {"tls://adult-filter-dns.cleanbrowsing.org:853", "8.8.8.8:53"}, - } - - for _, test := range tests { - u, err := NewUpstream(test.url, test.bootstrap) - - if err != nil { - t.Errorf("cannot create a DNS-over-TLS upstream") - } - - testUpstream(t, u) - } -} - -func testUpstreamIsAlive(t *testing.T, u Upstream) { - alive, err := IsAlive(u) - if !alive || err != nil { - t.Errorf("Upstream is not alive") - } - - u.Close() -} - -func testUpstream(t *testing.T, u Upstream) { - var tests = []struct { - name string - expected net.IP - }{ - {"google-public-dns-a.google.com.", net.IPv4(8, 8, 8, 8)}, - {"google-public-dns-b.google.com.", net.IPv4(8, 8, 4, 4)}, - } - - for _, test := range tests { - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: test.name, Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - resp, err := u.Exchange(context.Background(), &req) - - if err != nil { - t.Fatalf("error while making an upstream request: %s", err) - } - - if len(resp.Answer) != 1 { - t.Fatalf("no answer section in the response") - } - if answer, ok := resp.Answer[0].(*dns.A); ok { - if !test.expected.Equal(answer.A) { - t.Errorf("wrong IP in the response: %v", answer.A) - } - } - } - - err := u.Close() - if err != nil { - t.Errorf("Error while closing the upstream: %s", err) - } -}