diff --git a/AGHTechDoc.md b/AGHTechDoc.md index c3d21db2..a1323dde 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -1064,11 +1064,12 @@ When a new DNS request is received and processed, we store information about thi "QT":"...", // question type "QC":"...", // question class "Answer":"...", + "OrigAnswer":"...", "Result":{ "IsFiltered":true, "Reason":3, "Rule":"...", - "FilterID":1 + "FilterID":1, }, "Elapsed":12345, "Upstream":"...", @@ -1121,6 +1122,13 @@ Response: } ... ], + "original_answer":[ // Answer from upstream server (optional) + { + "type":"AAAA", + "value":"::" + } + ... + ], "client":"127.0.0.1", "elapsedMs":"0.098403", "filterId":1, @@ -1131,6 +1139,7 @@ Response: }, "reason":"FilteredBlackList", "rule":"||doubleclick.net^", + "service_name": "...", // set if reason=FilteredBlockedService "status":"NOERROR", "time":"2006-01-02T15:04:05.999999999Z07:00" } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index d83753a4..a9fc04c9 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -164,6 +164,15 @@ func (s *Server) Start(config *ServerConfig) error { // startInternal starts without locking func (s *Server) startInternal(config *ServerConfig) error { + err := s.prepare(config) + if err != nil { + return err + } + return s.dnsProxy.Start() +} + +// Prepare the object +func (s *Server) prepare(config *ServerConfig) error { if s.dnsProxy != nil { return errors.New("DNS server is already started") } @@ -243,7 +252,7 @@ func (s *Server) startInternal(config *ServerConfig) error { // Initialize and start the DNS proxy s.dnsProxy = &proxy.Proxy{Config: proxyConfig} - return s.dnsProxy.Start() + return nil } // Stop stops the DNS server @@ -344,6 +353,7 @@ func (s *Server) beforeRequestHandler(p *proxy.Proxy, d *proxy.DNSContext) (bool } // handleDNSRequest filters the incoming DNS requests and writes them to the query log +// nolint (gocyclo) func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { start := time.Now() @@ -372,6 +382,7 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { return err } + var origResp *dns.Msg if d.Res == nil { answer := []dns.RR{} originalQuestion := d.Req.Question[0] @@ -396,6 +407,18 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { answer = append(answer, d.Res.Answer...) // host -> IP d.Res.Answer = answer } + + } else if res.Reason != dnsfilter.NotFilteredWhiteList { + origResp2 := d.Res + res, err = s.filterResponse(d) + if err != nil { + return err + } + if res != nil { + origResp = origResp2 // matched by response + } else { + res = &dnsfilter.Result{} + } } } @@ -416,11 +439,18 @@ func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { // Synchronize access to s.queryLog and s.stats so they won't be suddenly uninitialized while in use. // This can happen after proxy server has been stopped, but its workers haven't yet exited. if shouldLog && s.queryLog != nil { - upstreamAddr := "" - if d.Upstream != nil { - upstreamAddr = d.Upstream.Address() + p := querylog.AddParams{ + Question: msg, + Answer: d.Res, + OrigAnswer: origResp, + Result: res, + Elapsed: elapsed, + ClientIP: getIP(d.Addr), } - s.queryLog.Add(msg, d.Res, res, elapsed, getIP(d.Addr), upstreamAddr) + if d.Upstream != nil { + p.Upstream = d.Upstream.Address() + } + s.queryLog.Add(p) } s.updateStats(d, elapsed, *res) @@ -538,6 +568,54 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error return &res, err } +// If response contains CNAME, A or AAAA records, we apply filtering to each canonical host name or IP address. +// If this is a match, we set a new response in d.Res and return. +func (s *Server) filterResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) { + for _, a := range d.Res.Answer { + host := "" + + switch v := a.(type) { + case *dns.CNAME: + log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name) + host = strings.TrimSuffix(v.Target, ".") + + case *dns.A: + host = v.A.String() + log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name) + + case *dns.AAAA: + host = v.AAAA.String() + log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name) + + default: + continue + } + + s.RLock() + // Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use. + // This could happen after proxy server has been stopped, but its workers are not yet exited. + if !s.conf.ProtectionEnabled || s.dnsFilter == nil { + s.RUnlock() + continue + } + setts := dnsfilter.RequestFilteringSettings{} + setts.FilteringEnabled = true + res, err := s.dnsFilter.CheckHost(host, d.Req.Question[0].Qtype, &setts) + s.RUnlock() + + if err != nil { + return nil, err + + } else if res.IsFiltered { + d.Res = s.genDNSFilterMessage(d, &res) + log.Debug("DNSFwd: Matched %s by response: %s", d.Req.Question[0].Name, host) + return &res, nil + } + } + + return nil, nil +} + // genDNSFilterMessage generates a DNS message corresponding to the filtering result func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Result) *dns.Msg { m := d.Req diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index f05934f6..7fcb5fb4 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -16,6 +16,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) @@ -246,6 +247,142 @@ func TestBlockedRequest(t *testing.T) { } } +// testUpstream is a mock of real upstream. +// specify fields with necessary values to simulate real upstream behaviour +type testUpstream struct { + cn map[string]string // Map of [name]canonical_name + ipv4 map[string][]net.IP // Map of [name]IPv4 + ipv6 map[string][]net.IP // Map of [name]IPv6 +} + +func (u *testUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) { + resp := dns.Msg{} + resp.SetReply(m) + hasARecord := false + hasAAAARecord := false + + reqType := m.Question[0].Qtype + name := m.Question[0].Name + + // Let's check if we have any CNAME for given name + if cname, ok := u.cn[name]; ok { + cn := dns.CNAME{} + cn.Hdr.Name = name + cn.Hdr.Rrtype = dns.TypeCNAME + cn.Target = cname + resp.Answer = append(resp.Answer, &cn) + } + + // Let's check if we can add some A records to the answer + if ipv4addr, ok := u.ipv4[name]; ok && reqType == dns.TypeA { + hasARecord = true + for _, ipv4 := range ipv4addr { + respA := dns.A{} + respA.Hdr.Rrtype = dns.TypeA + respA.Hdr.Name = name + respA.A = ipv4 + resp.Answer = append(resp.Answer, &respA) + } + } + + // Let's check if we can add some AAAA records to the answer + if u.ipv6 != nil { + if ipv6addr, ok := u.ipv6[name]; ok && reqType == dns.TypeAAAA { + hasAAAARecord = true + for _, ipv6 := range ipv6addr { + respAAAA := dns.A{} + respAAAA.Hdr.Rrtype = dns.TypeAAAA + respAAAA.Hdr.Name = name + respAAAA.A = ipv6 + resp.Answer = append(resp.Answer, &respAAAA) + } + } + } + + if len(resp.Answer) == 0 { + if hasARecord || hasAAAARecord { + // Set No Error RCode if there are some records for given Qname but we didn't apply them + resp.SetRcode(m, dns.RcodeSuccess) + } else { + // Set NXDomain RCode otherwise + resp.SetRcode(m, dns.RcodeNameError) + } + } + + return &resp, nil +} + +func (u *testUpstream) Address() string { + return "test" +} + +func (s *Server) startWithUpstream(u upstream.Upstream) error { + s.Lock() + defer s.Unlock() + err := s.prepare(nil) + if err != nil { + return err + } + s.dnsProxy.Upstreams = []upstream.Upstream{u} + return s.dnsProxy.Start() +} + +// testCNAMEs is a simple map of names and CNAMEs necessary for the testUpstream work +var testCNAMEs = map[string]string{ + "badhost.": "null.example.org.", + "whitelist.example.org.": "null.example.org.", +} + +// testIPv4 is a simple map of names and IPv4s necessary for the testUpstream work +var testIPv4 = map[string][]net.IP{ + "null.example.org.": {{1, 2, 3, 4}}, + "example.org.": {{127, 0, 0, 255}}, +} + +func TestBlockCNAME(t *testing.T) { + s := createTestServer(t) + testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} + err := s.startWithUpstream(testUpstm) + assert.True(t, err == nil) + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + + // 'badhost' has a canonical name 'null.example.org' which is blocked by filters: + // response is blocked + req := dns.Msg{} + req.Id = dns.Id() + req.Question = []dns.Question{ + {Name: "badhost.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + reply, err := dns.Exchange(&req, addr.String()) + assert.True(t, err == nil) + assert.True(t, reply.Rcode == dns.RcodeNameError) + + // 'whitelist.example.org' has a canonical name 'null.example.org' which is blocked by filters + // but 'whitelist.example.org' is in a whitelist: + // response isn't blocked + req = dns.Msg{} + req.Id = dns.Id() + req.Question = []dns.Question{ + {Name: "whitelist.example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + reply, err = dns.Exchange(&req, addr.String()) + assert.True(t, err == nil) + assert.True(t, reply.Rcode == dns.RcodeSuccess) + + // 'example.org' has a canonical name 'cname1' with IP 127.0.0.255 which is blocked by filters: + // response is blocked + req = dns.Msg{} + req.Id = dns.Id() + req.Question = []dns.Question{ + {Name: "example.org.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + reply, err = dns.Exchange(&req, addr.String()) + assert.True(t, err == nil) + assert.True(t, reply.Rcode == dns.RcodeNameError) + + _ = s.Stop() +} + func TestNullBlockedRequest(t *testing.T) { s := createTestServer(t) s.conf.FilteringConfig.BlockingMode = "null_ip" @@ -376,7 +513,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { } func createTestServer(t *testing.T) *Server { - rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n" + rules := "||nxdomain.example.org^\n||null.example.org^\n127.0.0.1 host.example.org\n@@||whitelist.example.org^\n||127.0.0.255\n" filters := map[int]string{} filters[0] = rules c := dnsfilter.Config{} diff --git a/querylog/qlog.go b/querylog/qlog.go index bf585c53..8c48c969 100644 --- a/querylog/qlog.go +++ b/querylog/qlog.go @@ -2,7 +2,6 @@ package querylog import ( "fmt" - "net" "os" "path/filepath" "strconv" @@ -96,52 +95,60 @@ type logEntry struct { QType string `json:"QT"` QClass string `json:"QC"` - Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net + Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net + OrigAnswer []byte `json:",omitempty"` + Result dnsfilter.Result Elapsed time.Duration Upstream string `json:",omitempty"` // if empty, means it was cached } -func (l *queryLog) Add(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, ip net.IP, upstream string) { +func (l *queryLog) Add(params AddParams) { if !l.conf.Enabled { return } - if question == nil || len(question.Question) != 1 || len(question.Question[0].Name) == 0 || - ip == nil { + if params.Question == nil || len(params.Question.Question) != 1 || len(params.Question.Question[0].Name) == 0 || + params.ClientIP == nil { return } - var a []byte - var err error - - if answer != nil { - a, err = answer.Pack() - if err != nil { - log.Printf("failed to pack answer for querylog: %s", err) - return - } - } - - if result == nil { - result = &dnsfilter.Result{} + if params.Result == nil { + params.Result = &dnsfilter.Result{} } now := time.Now() entry := logEntry{ - IP: ip.String(), + IP: params.ClientIP.String(), Time: now, - Answer: a, - Result: *result, - Elapsed: elapsed, - Upstream: upstream, + Result: *params.Result, + Elapsed: params.Elapsed, + Upstream: params.Upstream, } - q := question.Question[0] + q := params.Question.Question[0] entry.QHost = strings.ToLower(q.Name[:len(q.Name)-1]) // remove the last dot entry.QType = dns.Type(q.Qtype).String() entry.QClass = dns.Class(q.Qclass).String() + if params.Answer != nil { + a, err := params.Answer.Pack() + if err != nil { + log.Info("Querylog: Answer.Pack(): %s", err) + return + } + entry.Answer = a + } + + if params.OrigAnswer != nil { + a, err := params.OrigAnswer.Pack() + if err != nil { + log.Info("Querylog: OrigAnswer.Pack(): %s", err) + return + } + entry.OrigAnswer = a + } + l.bufferLock.Lock() l.buffer = append(l.buffer, &entry) needFlush := false @@ -335,6 +342,19 @@ func (l *queryLog) getData(params getDataParams) map[string]interface{} { jsonEntry["answer"] = answers } + if len(entry.OrigAnswer) != 0 { + a := new(dns.Msg) + err := a.Unpack(entry.OrigAnswer) + if err == nil { + answers = answerToMap(a) + if answers != nil { + jsonEntry["original_answer"] = answers + } + } else { + log.Debug("Querylog: a.Unpack(entry.OrigAnswer): %s: %s", err, string(entry.OrigAnswer)) + } + } + data = append(data, jsonEntry) } diff --git a/querylog/querylog.go b/querylog/querylog.go index e7e790e7..7d479d92 100644 --- a/querylog/querylog.go +++ b/querylog/querylog.go @@ -22,7 +22,7 @@ type QueryLog interface { Close() // Add a log entry - Add(question *dns.Msg, answer *dns.Msg, result *dnsfilter.Result, elapsed time.Duration, ip net.IP, upstream string) + Add(params AddParams) // WriteDiskConfig - write configuration WriteDiskConfig(dc *DiskConfig) @@ -42,6 +42,17 @@ type Config struct { HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) } +// AddParams - parameters for Add() +type AddParams struct { + Question *dns.Msg + Answer *dns.Msg // The response we sent to the client (optional) + OrigAnswer *dns.Msg // The response from an upstream server (optional) + Result *dnsfilter.Result // Filtering result (optional) + Elapsed time.Duration // Time spent for processing the request + ClientIP net.IP + Upstream string +} + // New - create a new instance of the query log func New(conf Config) QueryLog { return newQueryLog(conf) diff --git a/querylog/querylog_file.go b/querylog/querylog_file.go index a2250581..e0540f43 100644 --- a/querylog/querylog_file.go +++ b/querylog/querylog_file.go @@ -574,6 +574,8 @@ func decode(ent *logEntry, str string) { case "Answer": ent.Answer, err = base64.StdEncoding.DecodeString(v) + case "OrigAnswer": + ent.OrigAnswer, err = base64.StdEncoding.DecodeString(v) case "IsFiltered": b, err = strconv.ParseBool(v) diff --git a/querylog/querylog_test.go b/querylog/querylog_test.go index 5b2776cd..8c8b9bb4 100644 --- a/querylog/querylog_test.go +++ b/querylog/querylog_test.go @@ -115,7 +115,14 @@ func addEntry(l *queryLog, host, answerStr, client string) { answer.A = net.ParseIP(answerStr) a.Answer = append(a.Answer, answer) res := dnsfilter.Result{} - l.Add(&q, &a, &res, 0, net.ParseIP(client), "upstream") + params := AddParams{ + Question: &q, + Answer: &a, + Result: &res, + ClientIP: net.ParseIP(client), + Upstream: "upstream", + } + l.Add(params) } func checkEntry(t *testing.T, m map[string]interface{}, host, answer, client string) bool {