diff --git a/dhcpd/dhcpd.go b/dhcpd/dhcpd.go index f95a9867..0ec83370 100644 --- a/dhcpd/dhcpd.go +++ b/dhcpd/dhcpd.go @@ -92,7 +92,7 @@ type Server struct { conf ServerConfig // Called when the leases DB is modified - onLeaseChanged onLeaseChangedT + onLeaseChanged []onLeaseChangedT } // Print information about the available network interfaces @@ -146,14 +146,16 @@ func (s *Server) Init(config ServerConfig) error { // SetOnLeaseChanged - set callback func (s *Server) SetOnLeaseChanged(onLeaseChanged onLeaseChangedT) { - s.onLeaseChanged = onLeaseChanged + s.onLeaseChanged = append(s.onLeaseChanged, onLeaseChanged) } func (s *Server) notify(flags int) { - if s.onLeaseChanged == nil { + if len(s.onLeaseChanged) == 0 { return } - s.onLeaseChanged(flags) + for _, f := range s.onLeaseChanged { + f(flags) + } } // WriteDiskConfig - write configuration diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index e8d8a6b0..36f5445e 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/querylog" "github.com/AdguardTeam/AdGuardHome/stats" @@ -43,11 +44,15 @@ var webRegistered bool // // The zero Server is empty and ready for use. type Server struct { - dnsProxy *proxy.Proxy // DNS proxy instance - dnsFilter *dnsfilter.Dnsfilter // DNS filter instance - queryLog querylog.QueryLog // Query log instance - stats stats.Stats - access *accessCtx + dnsProxy *proxy.Proxy // DNS proxy instance + dnsFilter *dnsfilter.Dnsfilter // DNS filter instance + dhcpServer *dhcpd.Server // DHCP server instance (optional) + queryLog querylog.QueryLog // Query log instance + stats stats.Stats + access *accessCtx + + tablePTR map[string]string // "IP -> hostname" table for reverse lookup + tablePTRLock sync.Mutex // DNS proxy instance for internal usage // We don't Start() it and so no listen port is required. @@ -59,13 +64,27 @@ type Server struct { conf ServerConfig } +// DNSCreateParams - parameters for NewServer() +type DNSCreateParams struct { + DNSFilter *dnsfilter.Dnsfilter + Stats stats.Stats + QueryLog querylog.QueryLog + DHCPServer *dhcpd.Server +} + // NewServer creates a new instance of the dnsforward.Server // Note: this function must be called only once -func NewServer(dnsFilter *dnsfilter.Dnsfilter, stats stats.Stats, queryLog querylog.QueryLog) *Server { +func NewServer(p DNSCreateParams) *Server { s := &Server{} - s.dnsFilter = dnsFilter - s.stats = stats - s.queryLog = queryLog + s.dnsFilter = p.DNSFilter + s.stats = p.Stats + s.queryLog = p.QueryLog + s.dhcpServer = p.DHCPServer + + if s.dhcpServer != nil { + s.dhcpServer.SetOnLeaseChanged(s.onDHCPLeaseChanged) + s.onDHCPLeaseChanged(dhcpd.LeaseChangedAdded) + } if runtime.GOARCH == "mips" || runtime.GOARCH == "mipsle" { // Use plain DNS on MIPS, encryption is too slow diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 42018265..773769ee 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -15,6 +15,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" @@ -496,7 +497,7 @@ func TestBlockedCustomIP(t *testing.T) { c := dnsfilter.Config{} f := dnsfilter.New(&c, filters) - s := NewServer(f, nil, nil) + s := NewServer(DNSCreateParams{DNSFilter: f}) conf := ServerConfig{} conf.UDPListenAddr = &net.UDPAddr{Port: 0} conf.TCPListenAddr = &net.TCPAddr{Port: 0} @@ -648,7 +649,7 @@ func TestRewrite(t *testing.T) { } f := dnsfilter.New(&c, nil) - s := NewServer(f, nil, nil) + s := NewServer(DNSCreateParams{DNSFilter: f}) conf := ServerConfig{} conf.UDPListenAddr = &net.UDPAddr{Port: 0} conf.TCPListenAddr = &net.TCPAddr{Port: 0} @@ -705,7 +706,7 @@ func createTestServer(t *testing.T) *Server { c.CacheTime = 30 f := dnsfilter.New(&c, filters) - s := NewServer(f, nil, nil) + s := NewServer(DNSCreateParams{DNSFilter: f}) s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} s.conf.UpstreamDNS = []string{"8.8.8.8:53", "8.8.4.4:53"} @@ -1012,3 +1013,39 @@ func TestMatchDNSName(t *testing.T) { assert.True(t, !matchDNSName(dnsNames, "")) assert.True(t, !matchDNSName(dnsNames, "*.host2")) } + +func TestPTRResponse(t *testing.T) { + dhcp := &dhcpd.Server{} + dhcp.IPpool = make(map[[4]byte]net.HardwareAddr) + + c := dnsfilter.Config{} + f := dnsfilter.New(&c, nil) + s := NewServer(DNSCreateParams{DNSFilter: f, DHCPServer: dhcp}) + s.conf.UDPListenAddr = &net.UDPAddr{Port: 0} + s.conf.TCPListenAddr = &net.TCPAddr{Port: 0} + s.conf.UpstreamDNS = []string{"127.0.0.1:53"} + s.conf.FilteringConfig.ProtectionEnabled = true + err := s.Prepare(nil) + assert.True(t, err == nil) + assert.Nil(t, s.Start()) + + l := dhcpd.Lease{} + l.IP = net.ParseIP("127.0.0.1").To4() + l.HWAddr, _ = net.ParseMAC("aa:aa:aa:aa:aa:aa") + l.Hostname = "localhost" + dhcp.AddStaticLease(l) + + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + req := createTestMessage("1.0.0.127.in-addr.arpa.") + req.Question[0].Qtype = dns.TypePTR + + resp, err := dns.Exchange(req, addr.String()) + assert.Nil(t, err) + assert.Equal(t, 1, len(resp.Answer)) + assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) + assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) + ptr := resp.Answer[0].(*dns.PTR) + assert.Equal(t, "localhost.", ptr.Ptr) + + s.Close() +} diff --git a/dnsforward/handle_dns.go b/dnsforward/handle_dns.go index 87230e9b..462f3750 100644 --- a/dnsforward/handle_dns.go +++ b/dnsforward/handle_dns.go @@ -1,9 +1,12 @@ package dnsforward import ( + "strings" "time" + "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" @@ -39,6 +42,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { type modProcessFunc func(ctx *dnsContext) int mods := []modProcessFunc{ processInitial, + processInternalIPAddrs, processFilteringBeforeRequest, processUpstream, processDNSSECAfterResponse, @@ -88,11 +92,82 @@ func processInitial(ctx *dnsContext) int { return resultDone } +func (s *Server) onDHCPLeaseChanged(flags int) { + switch flags { + case dhcpd.LeaseChangedAdded, + dhcpd.LeaseChangedAddedStatic, + dhcpd.LeaseChangedRemovedStatic: + // + default: + return + } + + m := make(map[string]string) + ll := s.dhcpServer.Leases(dhcpd.LeasesAll) + for _, l := range ll { + if len(l.Hostname) == 0 { + continue + } + m[l.IP.String()] = l.Hostname + } + log.Debug("DNS: added %d PTR entries from DHCP", len(m)) + s.tablePTRLock.Lock() + s.tablePTR = m + s.tablePTRLock.Unlock() +} + +// Respond to PTR requests if the target IP address is leased by our DHCP server +func processInternalIPAddrs(ctx *dnsContext) int { + s := ctx.srv + req := ctx.proxyCtx.Req + if req.Question[0].Qtype != dns.TypePTR { + return resultDone + } + + arpa := req.Question[0].Name + arpa = strings.TrimSuffix(arpa, ".") + arpa = strings.ToLower(arpa) + ip := util.DNSUnreverseAddr(arpa) + if ip == nil { + return resultDone + } + + s.tablePTRLock.Lock() + if s.tablePTR == nil { + s.tablePTRLock.Unlock() + return resultDone + } + host, ok := s.tablePTR[ip.String()] + s.tablePTRLock.Unlock() + if !ok { + return resultDone + } + + log.Debug("DNS: reverse-lookup: %s -> %s", arpa, host) + + resp := s.makeResponse(req) + ptr := &dns.PTR{} + ptr.Hdr = dns.RR_Header{ + Name: req.Question[0].Name, + Rrtype: dns.TypePTR, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } + ptr.Ptr = host + "." + resp.Answer = append(resp.Answer, ptr) + ctx.proxyCtx.Res = resp + return resultDone +} + // Apply filtering logic func processFilteringBeforeRequest(ctx *dnsContext) int { s := ctx.srv d := ctx.proxyCtx + if d.Res != nil { + return resultDone // response is already set - nothing to do + } + s.RLock() // Synchronize access to s.dnsFilter so it won't be suddenly uninitialized while in use. // This could happen after proxy server has been stopped, but its workers are not yet exited. diff --git a/home/dns.go b/home/dns.go index db05f99a..a5547627 100644 --- a/home/dns.go +++ b/home/dns.go @@ -61,7 +61,13 @@ func initDNSServer() error { filterConf.HTTPRegister = httpRegister Context.dnsFilter = dnsfilter.New(&filterConf, nil) - Context.dnsServer = dnsforward.NewServer(Context.dnsFilter, Context.stats, Context.queryLog) + p := dnsforward.DNSCreateParams{ + DNSFilter: Context.dnsFilter, + Stats: Context.stats, + QueryLog: Context.queryLog, + DHCPServer: Context.dhcpServer, + } + Context.dnsServer = dnsforward.NewServer(p) dnsConfig := generateServerConfig() err = Context.dnsServer.Prepare(&dnsConfig) if err != nil { diff --git a/home/whois_test.go b/home/whois_test.go index 31e6aba2..3ea73c53 100644 --- a/home/whois_test.go +++ b/home/whois_test.go @@ -9,7 +9,7 @@ import ( func prepareTestDNSServer() error { config.DNS.Port = 1234 - Context.dnsServer = dnsforward.NewServer(nil, nil, nil) + Context.dnsServer = dnsforward.NewServer(dnsforward.DNSCreateParams{}) conf := &dnsforward.ServerConfig{} conf.UpstreamDNS = []string{"8.8.8.8"} return Context.dnsServer.Prepare(conf) diff --git a/util/auto_hosts.go b/util/auto_hosts.go index 34a979da..b12acd81 100644 --- a/util/auto_hosts.go +++ b/util/auto_hosts.go @@ -296,70 +296,6 @@ func (a *AutoHosts) Process(host string, qtype uint16) []net.IP { return ipsCopy } -// convert character to hex number -func charToHex(n byte) int8 { - if n >= '0' && n <= '9' { - return int8(n) - '0' - } else if (n|0x20) >= 'a' && (n|0x20) <= 'f' { - return (int8(n) | 0x20) - 'a' + 10 - } - return -1 -} - -// parse IPv6 reverse address -func ipParseArpa6(s string) net.IP { - if len(s) != 63 { - return nil - } - ip6 := make(net.IP, 16) - - for i := 0; i != 64; i += 4 { - - // parse "0.1." - n := charToHex(s[i]) - n2 := charToHex(s[i+2]) - if s[i+1] != '.' || (i != 60 && s[i+3] != '.') || - n < 0 || n2 < 0 { - return nil - } - - ip6[16-i/4-1] = byte(n2<<4) | byte(n&0x0f) - } - return ip6 -} - -// ipReverse - reverse IP address: 1.0.0.127 -> 127.0.0.1 -func ipReverse(ip net.IP) net.IP { - n := len(ip) - r := make(net.IP, n) - for i := 0; i != n; i++ { - r[i] = ip[n-i-1] - } - return r -} - -// Convert reversed ARPA address to a normal IP address -func dnsUnreverseAddr(s string) net.IP { - const arpaV4 = ".in-addr.arpa" - const arpaV6 = ".ip6.arpa" - - if strings.HasSuffix(s, arpaV4) { - ip := strings.TrimSuffix(s, arpaV4) - ip4 := net.ParseIP(ip).To4() - if ip4 == nil { - return nil - } - - return ipReverse(ip4) - - } else if strings.HasSuffix(s, arpaV6) { - ip := strings.TrimSuffix(s, arpaV6) - return ipParseArpa6(ip) - } - - return nil // unknown suffix -} - // ProcessReverse - process PTR request // Return "" if not found or an error occurred func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string { @@ -367,7 +303,7 @@ func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string { return "" } - ipReal := dnsUnreverseAddr(addr) + ipReal := DNSUnreverseAddr(addr) if ipReal == nil { return "" // invalid IP in question } diff --git a/util/auto_hosts_test.go b/util/auto_hosts_test.go index 322e7e9c..ea2e43ad 100644 --- a/util/auto_hosts_test.go +++ b/util/auto_hosts_test.go @@ -104,11 +104,13 @@ func TestAutoHostsFSNotify(t *testing.T) { } func TestIP(t *testing.T) { - assert.True(t, dnsUnreverseAddr("1.0.0.127.in-addr.arpa").Equal(net.ParseIP("127.0.0.1").To4())) - assert.True(t, dnsUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").Equal(net.ParseIP("::abcd:1234"))) + assert.Equal(t, "127.0.0.1", DNSUnreverseAddr("1.0.0.127.in-addr.arpa").String()) + assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String()) + assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String()) - assert.True(t, dnsUnreverseAddr("1.0.0.127.in-addr.arpa.") == nil) - assert.True(t, dnsUnreverseAddr(".0.0.127.in-addr.arpa") == nil) - assert.True(t, dnsUnreverseAddr(".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa") == nil) - assert.True(t, dnsUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa") == nil) + assert.Nil(t, DNSUnreverseAddr("1.0.0.127.in-addr.arpa.")) + assert.Nil(t, DNSUnreverseAddr(".0.0.127.in-addr.arpa")) + assert.Nil(t, DNSUnreverseAddr(".3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa")) + assert.Nil(t, DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0..ip6.arpa")) + assert.Nil(t, DNSUnreverseAddr("4.3.2.1.d.c.b. .0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa")) } diff --git a/util/dns.go b/util/dns.go new file mode 100644 index 00000000..aaf51d4d --- /dev/null +++ b/util/dns.go @@ -0,0 +1,70 @@ +package util + +import ( + "net" + "strings" +) + +// convert character to hex number +func charToHex(n byte) int8 { + if n >= '0' && n <= '9' { + return int8(n) - '0' + } else if (n|0x20) >= 'a' && (n|0x20) <= 'f' { + return (int8(n) | 0x20) - 'a' + 10 + } + return -1 +} + +// parse IPv6 reverse address +func ipParseArpa6(s string) net.IP { + if len(s) != 63 { + return nil + } + ip6 := make(net.IP, 16) + + for i := 0; i != 64; i += 4 { + + // parse "0.1." + n := charToHex(s[i]) + n2 := charToHex(s[i+2]) + if s[i+1] != '.' || (i != 60 && s[i+3] != '.') || + n < 0 || n2 < 0 { + return nil + } + + ip6[16-i/4-1] = byte(n2<<4) | byte(n&0x0f) + } + return ip6 +} + +// ipReverse - reverse IP address: 1.0.0.127 -> 127.0.0.1 +func ipReverse(ip net.IP) net.IP { + n := len(ip) + r := make(net.IP, n) + for i := 0; i != n; i++ { + r[i] = ip[n-i-1] + } + return r +} + +// DNSUnreverseAddr - convert reversed ARPA address to a normal IP address +func DNSUnreverseAddr(s string) net.IP { + const arpaV4 = ".in-addr.arpa" + const arpaV6 = ".ip6.arpa" + + if strings.HasSuffix(s, arpaV4) { + ip := strings.TrimSuffix(s, arpaV4) + ip4 := net.ParseIP(ip).To4() + if ip4 == nil { + return nil + } + + return ipReverse(ip4) + + } else if strings.HasSuffix(s, arpaV6) { + ip := strings.TrimSuffix(s, arpaV6) + return ipParseArpa6(ip) + } + + return nil // unknown suffix +}