From 1a1c09135d1548043b7f92d51b7c786da17e8bfb Mon Sep 17 00:00:00 2001 From: Simon Zolin Date: Thu, 16 Apr 2020 18:56:47 +0300 Subject: [PATCH] + auto-hosts: respond to PTR requests Close #1562 Squashed commit of the following: commit d5c6bb0e5f0c8c1618bd0df764ae86a5e62a850b Author: Simon Zolin Date: Mon Apr 13 14:10:10 2020 +0300 + auto-hosts: respond to PTR requests --- AGHTechDoc.md | 5 +- dnsfilter/dnsfilter.go | 18 +++- dnsforward/dnsforward.go | 30 ++++--- util/auto_hosts.go | 176 ++++++++++++++++++++++++++++++++------- util/auto_hosts_test.go | 30 ++++++- 5 files changed, 210 insertions(+), 49 deletions(-) diff --git a/AGHTechDoc.md b/AGHTechDoc.md index 41413046..4c8e0fda 100644 --- a/AGHTechDoc.md +++ b/AGHTechDoc.md @@ -1329,7 +1329,10 @@ This is how DNS requests and responses are filtered by AGH: * 'dnsproxy' module receives DNS request from client and passes control to AGH * AGH applies filtering logic to the host name in DNS Question: - * process Rewrite rules + * process Rewrite rules. + Can set CNAME and a list of IP addresses. + * process /etc/hosts entries. + Can set a list of IP addresses or a hostname (for PTR requests). * match host name against filtering lists * match host name against blocked services rules * process SafeSearch rules diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 4744ef5e..60a32d39 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -273,8 +273,13 @@ type Result struct { FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to // for ReasonRewrite: - CanonName string `json:",omitempty"` // CNAME value - IPList []net.IP `json:",omitempty"` // list of IP addresses + CanonName string `json:",omitempty"` // CNAME value + + // for RewriteEtcHosts: + ReverseHost string `json:",omitempty"` + + // for ReasonRewrite & RewriteEtcHosts: + IPList []net.IP `json:",omitempty"` // list of IP addresses // for FilteredBlockedService: ServiceName string `json:",omitempty"` // Name of the blocked service @@ -312,12 +317,19 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering } if d.Config.AutoHosts != nil { - ips := d.Config.AutoHosts.Process(host) + ips := d.Config.AutoHosts.Process(host, qtype) if ips != nil { result.Reason = RewriteEtcHosts result.IPList = ips return result, nil } + + revHost := d.Config.AutoHosts.ProcessReverse(host, qtype) + if len(revHost) != 0 { + result.Reason = RewriteEtcHosts + result.ReverseHost = revHost + "." + return result, nil + } } // try filter lists first diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index de396e9f..b191a966 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -810,23 +810,16 @@ func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dns e.Client = addr.IP } e.Time = uint32(elapsed / 1000) - switch res.Reason { + e.Result = stats.RNotFiltered - case dnsfilter.NotFilteredNotFound: - fallthrough - case dnsfilter.NotFilteredWhiteList: - fallthrough - case dnsfilter.NotFilteredError: - fallthrough - case dnsfilter.ReasonRewrite: - fallthrough - case dnsfilter.RewriteEtcHosts: - e.Result = stats.RNotFiltered + switch res.Reason { case dnsfilter.FilteredSafeBrowsing: e.Result = stats.RSafeBrowsing + case dnsfilter.FilteredParental: e.Result = stats.RParental + case dnsfilter.FilteredSafeSearch: e.Result = stats.RSafeSearch @@ -837,6 +830,7 @@ func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dns case dnsfilter.FilteredBlockedService: e.Result = stats.RFiltered } + s.stats.Update(e) } @@ -895,6 +889,20 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { ctx.origQuestion = d.Req.Question[0] // resolve canonical name, not the original host name d.Req.Question[0].Name = dns.Fqdn(res.CanonName) + + } else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHost) != 0 { + + resp := s.makeResponse(req) + ptr := &dns.PTR{} + ptr.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypePTR, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + ptr.Ptr = res.ReverseHost + resp.Answer = append(resp.Answer, ptr) + d.Res = resp } return &res, err diff --git a/util/auto_hosts.go b/util/auto_hosts.go index f310528c..046e80e3 100644 --- a/util/auto_hosts.go +++ b/util/auto_hosts.go @@ -12,18 +12,21 @@ import ( "github.com/AdguardTeam/golibs/log" "github.com/fsnotify/fsnotify" + "github.com/miekg/dns" ) type onChangedT func() // AutoHosts - automatic DNS records type AutoHosts struct { - lock sync.Mutex // serialize access to table - table map[string][]net.IP // 'hostname -> IP' table - hostsFn string // path to the main hosts-file - hostsDirs []string // paths to OS-specific directories with hosts-files - watcher *fsnotify.Watcher // file and directory watcher object - updateChan chan bool // signal for 'updateLoop' goroutine + lock sync.Mutex // serialize access to table + table map[string][]net.IP // 'hostname -> IP' table + tableReverse map[string]string // "IP -> hostname" table for reverse lookup + + hostsFn string // path to the main hosts-file + hostsDirs []string // paths to OS-specific directories with hosts-files + watcher *fsnotify.Watcher // file and directory watcher object + updateChan chan bool // signal for 'updateLoop' goroutine onChanged onChangedT // notification to other modules } @@ -95,9 +98,43 @@ func (a *AutoHosts) Close() { _ = a.watcher.Close() } +// update table +func (a *AutoHosts) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) { + ips, ok := table[host] + if ok { + for _, ip := range ips { + if ip.Equal(ipAddr) { + // IP already exists: don't add duplicates + ok = false + break + } + } + if !ok { + ips = append(ips, ipAddr) + table[host] = ips + } + } else { + table[host] = []net.IP{ipAddr} + ok = true + } + if ok { + log.Debug("AutoHosts: added %s -> %s", ipAddr, host) + } +} + +// update "reverse" table +func (a *AutoHosts) updateTableRev(tableRev map[string]string, host string, ipAddr net.IP) { + ipStr := ipAddr.String() + _, ok := tableRev[ipStr] + if !ok { + tableRev[ipStr] = host + log.Debug("AutoHosts: added reverse-address %s -> %s", ipStr, host) + } +} + // Read IP-hostname pairs from file // Multiple hostnames per line (per one IP) is supported. -func (a *AutoHosts) load(table map[string][]net.IP, fn string) { +func (a *AutoHosts) load(table map[string][]net.IP, tableRev map[string]string, fn string) { f, err := os.Open(fn) if err != nil { log.Error("AutoHosts: %s", err) @@ -128,26 +165,8 @@ func (a *AutoHosts) load(table map[string][]net.IP, fn string) { if len(host) == 0 { break } - ips, ok := table[host] - if ok { - for _, ip := range ips { - if ip.Equal(ipAddr) { - // IP already exists: don't add duplicates - ok = false - break - } - } - if !ok { - ips = append(ips, ipAddr) - table[host] = ips - } - } else { - table[host] = []net.IP{ipAddr} - ok = true - } - if ok { - log.Debug("AutoHosts: added %s -> %s", ip, host) - } + a.updateTable(table, host, ipAddr) + a.updateTableRev(tableRev, host, ipAddr) } } } @@ -210,8 +229,9 @@ func (a *AutoHosts) updateLoop() { // updateHosts - loads system hosts func (a *AutoHosts) updateHosts() { table := make(map[string][]net.IP) + tableRev := make(map[string]string) - a.load(table, a.hostsFn) + a.load(table, tableRev, a.hostsFn) for _, dir := range a.hostsDirs { fis, err := ioutil.ReadDir(dir) @@ -223,12 +243,13 @@ func (a *AutoHosts) updateHosts() { } for _, fi := range fis { - a.load(table, dir+"/"+fi.Name()) + a.load(table, tableRev, dir+"/"+fi.Name()) } } a.lock.Lock() a.table = table + a.tableReverse = tableRev a.lock.Unlock() a.notify() @@ -236,7 +257,11 @@ func (a *AutoHosts) updateHosts() { // Process - get the list of IP addresses for the hostname // Return nil if not found -func (a *AutoHosts) Process(host string) []net.IP { +func (a *AutoHosts) Process(host string, qtype uint16) []net.IP { + if qtype == dns.TypePTR { + return nil + } + var ipsCopy []net.IP a.lock.Lock() ips, _ := a.table[host] @@ -245,9 +270,100 @@ func (a *AutoHosts) Process(host string) []net.IP { copy(ipsCopy, ips) } a.lock.Unlock() + + log.Debug("AutoHosts: answer: %s -> %v", host, ipsCopy) return ipsCopy } +// convert character to hex number +func charToHex(n byte) int8 { + if n >= '0' && n <= '9' { + return int8(n) - '0' + } else if (n|0x20) >= 'a' && (n|0x20) <= 'f' { + return (int8(n) | 0x20) - 'a' + 10 + } + return -1 +} + +// parse IPv6 reverse address +func ipParseArpa6(s string) net.IP { + if len(s) != 63 { + return nil + } + ip6 := make(net.IP, 16) + + for i := 0; i != 64; i += 4 { + + // parse "0.1." + n := charToHex(s[i]) + n2 := charToHex(s[i+2]) + if s[i+1] != '.' || (i != 60 && s[i+3] != '.') || + n < 0 || n2 < 0 { + return nil + } + + ip6[16-i/4-1] = byte(n2<<4) | byte(n&0x0f) + } + return ip6 +} + +// ipReverse - reverse IP address: 1.0.0.127 -> 127.0.0.1 +func ipReverse(ip net.IP) net.IP { + n := len(ip) + r := make(net.IP, n) + for i := 0; i != n; i++ { + r[i] = ip[n-i-1] + } + return r +} + +// Convert reversed ARPA address to a normal IP address +func dnsUnreverseAddr(s string) net.IP { + const arpaV4 = ".in-addr.arpa" + const arpaV6 = ".ip6.arpa" + + if strings.HasSuffix(s, arpaV4) { + ip := strings.TrimSuffix(s, arpaV4) + ip4 := net.ParseIP(ip).To4() + if ip4 == nil { + return nil + } + + return ipReverse(ip4) + + } else if strings.HasSuffix(s, arpaV6) { + ip := strings.TrimSuffix(s, arpaV6) + return ipParseArpa6(ip) + } + + return nil // unknown suffix +} + +// ProcessReverse - process PTR request +// Return "" if not found or an error occurred +func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string { + if qtype != dns.TypePTR { + return "" + } + + ipReal := dnsUnreverseAddr(addr) + if ipReal == nil { + return "" // invalid IP in question + } + ipStr := ipReal.String() + + a.lock.Lock() + host := a.tableReverse[ipStr] + a.lock.Unlock() + + if len(host) == 0 { + return "" // not found + } + + log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, host) + return host +} + // List - get the hosts table. Thread-safe. func (a *AutoHosts) List() map[string][]net.IP { table := make(map[string][]net.IP) diff --git a/util/auto_hosts_test.go b/util/auto_hosts_test.go index 4aaa519e..3871b394 100644 --- a/util/auto_hosts_test.go +++ b/util/auto_hosts_test.go @@ -4,9 +4,11 @@ import ( "io/ioutil" "net" "os" + "strings" "testing" "time" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) @@ -28,19 +30,21 @@ func TestAutoHostsResolution(t *testing.T) { defer f.Close() _, _ = f.WriteString(" 127.0.0.1 host localhost \n") + _, _ = f.WriteString(" ::1 localhost \n") + ah.Init(f.Name()) // Update from the hosts file ah.updateHosts() // Existing host - ips := ah.Process("localhost") + ips := ah.Process("localhost", dns.TypeA) assert.NotNil(t, ips) assert.Equal(t, 1, len(ips)) assert.Equal(t, net.ParseIP("127.0.0.1"), ips[0]) // Unknown host - ips = ah.Process("newhost") + ips = ah.Process("newhost", dns.TypeA) assert.Nil(t, ips) // Test hosts file @@ -49,6 +53,14 @@ func TestAutoHostsResolution(t *testing.T) { assert.NotNil(t, ips) assert.Equal(t, 1, len(ips)) assert.Equal(t, "127.0.0.1", ips[0].String()) + + // Test PTR + a, _ := dns.ReverseAddr("127.0.0.1") + a = strings.TrimSuffix(a, ".") + assert.True(t, ah.ProcessReverse(a, dns.TypePTR) == "host") + a, _ = dns.ReverseAddr("::1") + a = strings.TrimSuffix(a, ".") + assert.True(t, ah.ProcessReverse(a, dns.TypePTR) == "localhost") } func TestAutoHostsFSNotify(t *testing.T) { @@ -67,7 +79,7 @@ func TestAutoHostsFSNotify(t *testing.T) { ah.updateHosts() // Unknown host - ips := ah.Process("newhost") + ips := ah.Process("newhost", dns.TypeA) assert.Nil(t, ips) // Stat monitoring for changes @@ -82,8 +94,18 @@ func TestAutoHostsFSNotify(t *testing.T) { time.Sleep(50 * time.Millisecond) // Check if we are notified about changes - ips = ah.Process("newhost") + ips = ah.Process("newhost", dns.TypeA) assert.NotNil(t, ips) assert.Equal(t, 1, len(ips)) assert.Equal(t, "127.0.0.2", ips[0].String()) } + +func TestIP(t *testing.T) { + assert.True(t, dnsUnreverseAddr("1.0.0.127.in-addr.arpa").Equal(net.ParseIP("127.0.0.1").To4())) + assert.True(t, dnsUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").Equal(net.ParseIP("::abcd:1234"))) + + assert.True(t, dnsUnreverseAddr("1.0.0.127.in-addr.arpa.") == nil) + assert.True(t, dnsUnreverseAddr(".0.0.127.in-addr.arpa") == nil) + assert.True(t, dnsUnreverseAddr(".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa") == nil) + assert.True(t, dnsUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa") == nil) +}