diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index a3f9fa73..70abd660 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -17,6 +17,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" + "github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" @@ -277,17 +278,6 @@ func (s *Server) Resolve(host string) ([]net.IPAddr, error) { return s.internalProxy.LookupIPAddr(host) } -// RDNSExchanger is a resolver for clients' addresses. -type RDNSExchanger interface { - // Exchange tries to resolve the ip in a suitable way, i.e. either as local - // or as external. - Exchange(ip net.IP) (host string, err error) - - // ResolvesPrivatePTR returns true if the RDNSExchanger is able to - // resolve PTR requests for locally-served addresses. - ResolvesPrivatePTR() (ok bool) -} - const ( // ErrRDNSNoData is returned by [RDNSExchanger.Exchange] when the answer // section of response is either NODATA or has no PTR records. @@ -299,10 +289,10 @@ const ( ) // type check -var _ RDNSExchanger = (*Server)(nil) +var _ rdns.Exchanger = (*Server)(nil) -// Exchange implements the RDNSExchanger interface for *Server. -func (s *Server) Exchange(ip net.IP) (host string, err error) { +// Exchange implements the [rdns.Exchanger] interface for *Server. +func (s *Server) Exchange(ip netip.Addr) (host string, err error) { s.serverLock.RLock() defer s.serverLock.RUnlock() @@ -310,7 +300,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) { return "", nil } - arpa, err := netutil.IPToReversedAddr(ip) + arpa, err := netutil.IPToReversedAddr(ip.AsSlice()) if err != nil { return "", fmt.Errorf("reversing ip: %w", err) } @@ -335,7 +325,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) { } var resolver *proxy.Proxy - if s.privateNets.Contains(ip) { + if s.isPrivateIP(ip) { if !s.conf.UsePrivateRDNS { return "", nil } @@ -350,8 +340,12 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) { return "", err } + return hostFromPTR(ctx.Res) +} + +// hostFromPTR returns domain name from the PTR response or error. +func hostFromPTR(resp *dns.Msg) (host string, err error) { // Distinguish between NODATA response and a failed request. - resp := ctx.Res if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError { return "", fmt.Errorf( "received %s response: %w", @@ -370,12 +364,25 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) { return "", ErrRDNSNoData } -// ResolvesPrivatePTR implements the RDNSExchanger interface for *Server. -func (s *Server) ResolvesPrivatePTR() (ok bool) { +// isPrivateIP returns true if the ip is private. +func (s *Server) isPrivateIP(ip netip.Addr) (ok bool) { + return s.privateNets.Contains(ip.AsSlice()) +} + +// ShouldResolveClient returns false if ip is a loopback address, or ip is +// private and resolving of private addresses is disabled. +func (s *Server) ShouldResolveClient(ip netip.Addr) (ok bool) { + if ip.IsLoopback() { + return false + } + + isPrivate := s.isPrivateIP(ip) + s.serverLock.RLock() defer s.serverLock.RUnlock() - return s.conf.UsePrivateRDNS + return s.conf.ResolveClients && + (s.conf.UsePrivateRDNS || !isPrivate) } // Start starts the DNS server. diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index f7ff57a3..705227a1 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -1273,11 +1273,11 @@ func TestServer_Exchange(t *testing.T) { ) var ( - onesIP = net.IP{1, 1, 1, 1} - localIP = net.IP{192, 168, 1, 1} + onesIP = netip.MustParseAddr("1.1.1.1") + localIP = netip.MustParseAddr("192.168.1.1") ) - revExtIPv4, err := netutil.IPToReversedAddr(onesIP) + revExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice()) require.NoError(t, err) extUpstream := &aghtest.UpstreamMock{ @@ -1290,7 +1290,7 @@ func TestServer_Exchange(t *testing.T) { }, } - revLocIPv4, err := netutil.IPToReversedAddr(localIP) + revLocIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice()) require.NoError(t, err) locUpstream := &aghtest.UpstreamMock{ @@ -1330,7 +1330,7 @@ func TestServer_Exchange(t *testing.T) { want string wantErr error locUpstream upstream.Upstream - req net.IP + req netip.Addr }{{ name: "external_good", want: onesHost, @@ -1354,7 +1354,7 @@ func TestServer_Exchange(t *testing.T) { want: "", wantErr: ErrRDNSNoData, locUpstream: locUpstream, - req: net.IP{192, 168, 1, 2}, + req: netip.MustParseAddr("192.168.1.2"), }, { name: "invalid_answer", want: "", @@ -1396,3 +1396,57 @@ func TestServer_Exchange(t *testing.T) { assert.Empty(t, host) }) } + +func TestServer_ShouldResolveClient(t *testing.T) { + srv := &Server{ + privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + } + + testCases := []struct { + ip netip.Addr + want require.BoolAssertionFunc + name string + resolve bool + usePrivate bool + }{{ + name: "default", + ip: netip.MustParseAddr("1.1.1.1"), + want: require.True, + resolve: true, + usePrivate: true, + }, { + name: "no_rdns", + ip: netip.MustParseAddr("1.1.1.1"), + want: require.False, + resolve: false, + usePrivate: true, + }, { + name: "loopback", + ip: netip.MustParseAddr("127.0.0.1"), + want: require.False, + resolve: true, + usePrivate: true, + }, { + name: "private_resolve", + ip: netip.MustParseAddr("192.168.0.1"), + want: require.True, + resolve: true, + usePrivate: true, + }, { + name: "private_no_resolve", + ip: netip.MustParseAddr("192.168.0.1"), + want: require.False, + resolve: true, + usePrivate: false, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + srv.conf.ResolveClients = tc.resolve + srv.conf.UsePrivateRDNS = tc.usePrivate + + ok := srv.ShouldResolveClient(tc.ip) + tc.want(t, ok) + }) + } +} diff --git a/internal/home/dns.go b/internal/home/dns.go index 3a37f751..fbbda423 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -17,6 +17,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" + "github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/dnsproxy/proxy" @@ -167,30 +168,77 @@ func initDNSServer( return fmt.Errorf("dnsServer.Prepare: %w", err) } - if config.Clients.Sources.RDNS { - Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS) - } - + initRDNS() initWHOIS() return nil } +const ( + // defaultQueueSize is the size of queue of IPs for rDNS and WHOIS + // processing. + defaultQueueSize = 255 + + // defaultCacheSize is the maximum size of the cache for rDNS and WHOIS + // processing. It must be greater than zero. + defaultCacheSize = 10_000 + + // defaultIPTTL is the Time to Live duration for IP addresses cached by + // rDNS and WHOIS. + defaultIPTTL = 1 * time.Hour +) + +// initRDNS initializes the rDNS. +func initRDNS() { + Context.rdnsCh = make(chan netip.Addr, defaultQueueSize) + + // TODO(s.chzhen): Add ability to disable it on dns server configuration + // update in [dnsforward] package. + r := rdns.New(&rdns.Config{ + Exchanger: Context.dnsServer, + CacheSize: defaultCacheSize, + CacheTTL: defaultIPTTL, + }) + + go processRDNS(r) +} + +// processRDNS processes reverse DNS lookup queries. It is intended to be used +// as a goroutine. +func processRDNS(r rdns.Interface) { + defer log.OnPanic("rdns") + + for ip := range Context.rdnsCh { + ok := Context.dnsServer.ShouldResolveClient(ip) + if !ok { + continue + } + + host, changed := r.Process(ip) + if host == "" || !changed { + continue + } + + ok = Context.clients.AddHost(ip, host, ClientSourceRDNS) + if ok { + continue + } + + log.Debug( + "dns: can't set rdns info for client %q: already set with higher priority source", + ip, + ) + } +} + // initWHOIS initializes the WHOIS. // // TODO(s.chzhen): Consider making configurable. func initWHOIS() { const ( - // defaultQueueSize is the size of queue of IPs for WHOIS processing. - defaultQueueSize = 255 - // defaultTimeout is the timeout for WHOIS requests. defaultTimeout = 5 * time.Second - // defaultCacheSize is the maximum size of the cache. If it's zero, - // cache size is unlimited. - defaultCacheSize = 10_000 - // defaultMaxConnReadSize is an upper limit in bytes for reading from // net.Conn. defaultMaxConnReadSize = 64 * 1024 @@ -200,9 +248,6 @@ func initWHOIS() { // defaultMaxInfoLen is the maximum length of whois.Info fields. defaultMaxInfoLen = 250 - - // defaultIPTTL is the Time to Live duration for cached IP addresses. - defaultIPTTL = 1 * time.Hour ) Context.whoisCh = make(chan netip.Addr, defaultQueueSize) @@ -274,11 +319,7 @@ func onDNSRequest(pctx *proxy.DNSContext) { return } - srcs := config.Clients.Sources - if srcs.RDNS && !ip.IsLoopback() { - Context.rdns.Begin(ip) - } - + Context.rdnsCh <- ip Context.whoisCh <- ip } @@ -517,11 +558,7 @@ func startDNSServer() error { const topClientsNumber = 100 // the number of clients to get for _, ip := range Context.stats.TopClientsIP(topClientsNumber) { - srcs := config.Clients.Sources - if srcs.RDNS && !ip.IsLoopback() { - Context.rdns.Begin(ip) - } - + Context.rdnsCh <- ip Context.whoisCh <- ip } diff --git a/internal/home/home.go b/internal/home/home.go index fe820c3d..c8f218ec 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -56,7 +56,6 @@ type homeContext struct { stats stats.Interface // statistics module queryLog querylog.QueryLog // query log module dnsServer *dnsforward.Server // DNS module - rdns *RDNS // rDNS module dhcpServer dhcpd.Interface // DHCP module auth *Auth // HTTP authentication module filters *filtering.DNSFilter // DNS filtering module @@ -83,6 +82,9 @@ type homeContext struct { client *http.Client appSignalChannel chan os.Signal // Channel for receiving OS signals by the console app + // rdnsCh is the channel for receiving IPs for rDNS processing. + rdnsCh chan netip.Addr + // whoisCh is the channel for receiving IPs for WHOIS processing. whoisCh chan netip.Addr diff --git a/internal/home/rdns.go b/internal/home/rdns.go deleted file mode 100644 index cae7a9c3..00000000 --- a/internal/home/rdns.go +++ /dev/null @@ -1,143 +0,0 @@ -package home - -import ( - "encoding/binary" - "net/netip" - "sync/atomic" - "time" - - "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" - "github.com/AdguardTeam/golibs/cache" - "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" -) - -// RDNS resolves clients' addresses to enrich their metadata. -type RDNS struct { - exchanger dnsforward.RDNSExchanger - clients *clientsContainer - - // ipCh used to pass client's IP to rDNS workerLoop. - ipCh chan netip.Addr - - // ipCache caches the IP addresses to be resolved by rDNS. The resolved - // address stays here while it's inside clients. After leaving clients the - // address will be resolved once again. If the address couldn't be - // resolved, cache prevents further attempts to resolve it for some time. - ipCache cache.Cache - - // usePrivate stores the state of current private reverse-DNS resolving - // settings. - usePrivate atomic.Bool -} - -// Default AdGuard Home reverse DNS values. -const ( - revDNSCacheSize = 10000 - - // TODO(e.burkov): Make these values configurable. - revDNSCacheTTL = 24 * 60 * 60 - revDNSFailureCacheTTL = 1 * 60 * 60 - - revDNSQueueSize = 256 -) - -// NewRDNS creates and returns initialized RDNS. -func NewRDNS( - exchanger dnsforward.RDNSExchanger, - clients *clientsContainer, - usePrivate bool, -) (rDNS *RDNS) { - rDNS = &RDNS{ - exchanger: exchanger, - clients: clients, - ipCache: cache.New(cache.Config{ - EnableLRU: true, - MaxCount: revDNSCacheSize, - }), - ipCh: make(chan netip.Addr, revDNSQueueSize), - } - - rDNS.usePrivate.Store(usePrivate) - - go rDNS.workerLoop() - - return rDNS -} - -// ensurePrivateCache ensures that the state of the RDNS cache is consistent -// with the current private client RDNS resolving settings. -// -// TODO(e.burkov): Clearing cache each time this value changed is not a perfect -// approach since only unresolved locally-served addresses should be removed. -// Implement when improving the cache. -func (r *RDNS) ensurePrivateCache() { - usePrivate := r.exchanger.ResolvesPrivatePTR() - if r.usePrivate.CompareAndSwap(!usePrivate, usePrivate) { - r.ipCache.Clear() - } -} - -// isCached returns true if ip is already cached and not expired yet. It also -// caches it otherwise. -func (r *RDNS) isCached(ip netip.Addr) (ok bool) { - ipBytes := ip.AsSlice() - now := uint64(time.Now().Unix()) - if expire := r.ipCache.Get(ipBytes); len(expire) != 0 { - return binary.BigEndian.Uint64(expire) > now - } - - return false -} - -// cache caches the ip address for ttl seconds. -func (r *RDNS) cache(ip netip.Addr, ttl uint64) { - ipData := ip.AsSlice() - - ttlData := [8]byte{} - binary.BigEndian.PutUint64(ttlData[:], uint64(time.Now().Unix())+ttl) - - r.ipCache.Set(ipData, ttlData[:]) -} - -// Begin adds the ip to the resolving queue if it is not cached or already -// resolved. -func (r *RDNS) Begin(ip netip.Addr) { - r.ensurePrivateCache() - - if r.isCached(ip) || r.clients.clientSource(ip) > ClientSourceRDNS { - return - } - - select { - case r.ipCh <- ip: - log.Debug("rdns: %q added to queue", ip) - default: - log.Debug("rdns: queue is full") - } -} - -// workerLoop handles incoming IP addresses from ipChan and adds it into -// clients. -func (r *RDNS) workerLoop() { - defer log.OnPanic("rdns") - - for ip := range r.ipCh { - ttl := uint64(revDNSCacheTTL) - - host, err := r.exchanger.Exchange(ip.AsSlice()) - if err != nil { - log.Debug("rdns: resolving %q: %s", ip, err) - if errors.Is(err, dnsforward.ErrRDNSFailed) { - // Cache failure for a less time. - ttl = revDNSFailureCacheTTL - } - } - - r.cache(ip, ttl) - - if host != "" { - _ = r.clients.AddHost(ip, host, ClientSourceRDNS) - } - } -} diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go deleted file mode 100644 index 5582bf5b..00000000 --- a/internal/home/rdns_test.go +++ /dev/null @@ -1,264 +0,0 @@ -package home - -import ( - "bytes" - "encoding/binary" - "fmt" - "net" - "net/netip" - "sync" - "testing" - "time" - - "github.com/AdguardTeam/AdGuardHome/internal/aghalg" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" - "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/AdguardTeam/golibs/cache" - "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/netutil" - "github.com/AdguardTeam/golibs/stringutil" - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestRDNS_Begin(t *testing.T) { - aghtest.ReplaceLogLevel(t, log.DEBUG) - w := &bytes.Buffer{} - aghtest.ReplaceLogWriter(t, w) - - ip1234, ip1235 := netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5") - - testCases := []struct { - cliIDIndex map[string]*Client - customChan chan netip.Addr - name string - wantLog string - ip netip.Addr - wantCacheHit int - wantCacheMiss int - }{{ - cliIDIndex: map[string]*Client{}, - customChan: nil, - name: "cached", - wantLog: "", - ip: ip1234, - wantCacheHit: 1, - wantCacheMiss: 0, - }, { - cliIDIndex: map[string]*Client{}, - customChan: nil, - name: "not_cached", - wantLog: "rdns: queue is full", - ip: ip1235, - wantCacheHit: 0, - wantCacheMiss: 1, - }, { - cliIDIndex: map[string]*Client{"1.2.3.5": {}}, - customChan: nil, - name: "already_in_clients", - wantLog: "", - ip: ip1235, - wantCacheHit: 0, - wantCacheMiss: 1, - }, { - cliIDIndex: map[string]*Client{}, - customChan: make(chan netip.Addr, 1), - name: "add_to_queue", - wantLog: `rdns: "1.2.3.5" added to queue`, - ip: ip1235, - wantCacheHit: 0, - wantCacheMiss: 1, - }} - - for _, tc := range testCases { - w.Reset() - - ipCache := cache.New(cache.Config{ - EnableLRU: true, - MaxCount: revDNSCacheSize, - }) - ttl := make([]byte, binary.Size(uint64(0))) - binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix())) - - rdns := &RDNS{ - ipCache: ipCache, - exchanger: &rDNSExchanger{ - ex: aghtest.NewErrorUpstream(), - }, - clients: &clientsContainer{ - list: map[string]*Client{}, - idIndex: tc.cliIDIndex, - ipToRC: map[netip.Addr]*RuntimeClient{}, - allTags: stringutil.NewSet(), - }, - } - ipCache.Clear() - ipCache.Set(net.IP{1, 2, 3, 4}, ttl) - - if tc.customChan != nil { - rdns.ipCh = tc.customChan - defer close(tc.customChan) - } - - t.Run(tc.name, func(t *testing.T) { - rdns.Begin(tc.ip) - assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit) - assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss) - assert.Contains(t, w.String(), tc.wantLog) - }) - } -} - -// rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests. -type rDNSExchanger struct { - ex upstream.Upstream - usePrivate bool -} - -// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger. -func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) { - rev, err := netutil.IPToReversedAddr(ip) - if err != nil { - return "", fmt.Errorf("reversing ip: %w", err) - } - - req := &dns.Msg{ - Question: []dns.Question{{ - Name: dns.Fqdn(rev), - Qclass: dns.ClassINET, - Qtype: dns.TypePTR, - }}, - } - - resp, err := e.ex.Exchange(req) - if err != nil { - return "", err - } - - if len(resp.Answer) == 0 { - return "", nil - } - - return resp.Answer[0].Header().Name, nil -} - -// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger. -func (e *rDNSExchanger) ResolvesPrivatePTR() (ok bool) { - return e.usePrivate -} - -func TestRDNS_ensurePrivateCache(t *testing.T) { - data := []byte{1, 2, 3, 4} - - ipCache := cache.New(cache.Config{ - EnableLRU: true, - MaxCount: revDNSCacheSize, - }) - - ex := &rDNSExchanger{ - ex: aghtest.NewErrorUpstream(), - } - - rdns := &RDNS{ - ipCache: ipCache, - exchanger: ex, - } - - rdns.ipCache.Set(data, data) - require.NotZero(t, rdns.ipCache.Stats().Count) - - ex.usePrivate = !ex.usePrivate - - rdns.ensurePrivateCache() - require.Zero(t, rdns.ipCache.Stats().Count) -} - -func TestRDNS_WorkerLoop(t *testing.T) { - aghtest.ReplaceLogLevel(t, log.DEBUG) - w := &bytes.Buffer{} - aghtest.ReplaceLogWriter(t, w) - - localIP := netip.MustParseAddr("192.168.1.1") - revIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice()) - require.NoError(t, err) - - revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93")) - require.NoError(t, err) - - locUpstream := &aghtest.UpstreamMock{ - OnAddress: func() (addr string) { return "local.upstream.example" }, - OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { - return aghalg.Coalesce( - aghtest.MatchedResponse(req, dns.TypePTR, revIPv4, "local.domain"), - aghtest.MatchedResponse(req, dns.TypePTR, revIPv6, "ipv6.domain"), - new(dns.Msg).SetRcode(req, dns.RcodeNameError), - ), nil - }, - } - - errUpstream := aghtest.NewErrorUpstream() - - testCases := []struct { - ups upstream.Upstream - cliIP netip.Addr - wantLog string - name string - wantClientSource clientSource - }{{ - ups: locUpstream, - cliIP: localIP, - wantLog: "", - name: "all_good", - wantClientSource: ClientSourceRDNS, - }, { - ups: errUpstream, - cliIP: netip.MustParseAddr("192.168.1.2"), - wantLog: `rdns: resolving "192.168.1.2": test upstream error`, - name: "resolve_error", - wantClientSource: ClientSourceNone, - }, { - ups: locUpstream, - cliIP: netip.MustParseAddr("2a00:1450:400c:c06::93"), - wantLog: "", - name: "ipv6_good", - wantClientSource: ClientSourceRDNS, - }} - - for _, tc := range testCases { - w.Reset() - - cc := newClientsContainer(t) - ch := make(chan netip.Addr) - rdns := &RDNS{ - exchanger: &rDNSExchanger{ - ex: tc.ups, - }, - clients: cc, - ipCh: ch, - ipCache: cache.New(cache.Config{ - EnableLRU: true, - MaxCount: revDNSCacheSize, - }), - } - - t.Run(tc.name, func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(1) - go func() { - rdns.workerLoop() - wg.Done() - }() - - ch <- tc.cliIP - close(ch) - wg.Wait() - - if tc.wantLog != "" { - assert.Contains(t, w.String(), tc.wantLog) - } - - assert.Equal(t, tc.wantClientSource, cc.clientSource(tc.cliIP)) - }) - } -} diff --git a/internal/rdns/rdns.go b/internal/rdns/rdns.go new file mode 100644 index 00000000..e352da52 --- /dev/null +++ b/internal/rdns/rdns.go @@ -0,0 +1,132 @@ +// Package rdns processes reverse DNS lookup queries. +package rdns + +import ( + "net/netip" + "time" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "github.com/bluele/gcache" +) + +// Interface processes rDNS queries. +type Interface interface { + // Process makes rDNS request and returns domain name. changed indicates + // that domain name was updated since last request. + Process(ip netip.Addr) (host string, changed bool) +} + +// Empty is an empty [Inteface] implementation which does nothing. +type Empty struct{} + +// type check +var _ Interface = (*Empty)(nil) + +// Process implements the [Interface] interface for Empty. +func (Empty) Process(_ netip.Addr) (host string, changed bool) { + return "", false +} + +// Exchanger is a resolver for clients' addresses. +type Exchanger interface { + // Exchange tries to resolve the ip in a suitable way, i.e. either as local + // or as external. + Exchange(ip netip.Addr) (host string, err error) +} + +// Config is the configuration structure for Default. +type Config struct { + // Exchanger resolves IP addresses to domain names. + Exchanger Exchanger + + // CacheSize is the maximum size of the cache. It must be greater than + // zero. + CacheSize int + + // CacheTTL is the Time to Live duration for cached IP addresses. + CacheTTL time.Duration +} + +// Default is the default rDNS query processor. +type Default struct { + // cache is the cache containing IP addresses of clients. An active IP + // address is resolved once again after it expires. If IP address couldn't + // be resolved, it stays here for some time to prevent further attempts to + // resolve the same IP. + cache gcache.Cache + + // exchanger resolves IP addresses to domain names. + exchanger Exchanger + + // cacheTTL is the Time to Live duration for cached IP addresses. + cacheTTL time.Duration +} + +// New returns a new default rDNS query processor. conf must not be nil. +func New(conf *Config) (r *Default) { + return &Default{ + cache: gcache.New(conf.CacheSize).LRU().Build(), + exchanger: conf.Exchanger, + cacheTTL: conf.CacheTTL, + } +} + +// type check +var _ Interface = (*Default)(nil) + +// Process implements the [Interface] interface for Default. +func (r *Default) Process(ip netip.Addr) (host string, changed bool) { + fromCache, expired := r.findInCache(ip) + if !expired { + return fromCache, false + } + + host, err := r.exchanger.Exchange(ip) + if err != nil { + log.Debug("rdns: resolving %q: %s", ip, err) + } + + item := &cacheItem{ + expiry: time.Now().Add(r.cacheTTL), + host: host, + } + + err = r.cache.Set(ip, item) + if err != nil { + log.Debug("rdns: cache: adding item %q: %s", ip, err) + } + + return host, fromCache == "" || host != fromCache +} + +// findInCache finds domain name in the cache. expired is true if host is not +// valid anymore. +func (r *Default) findInCache(ip netip.Addr) (host string, expired bool) { + val, err := r.cache.Get(ip) + if err != nil { + if !errors.Is(err, gcache.KeyNotFoundError) { + log.Debug("rdns: cache: retrieving %q: %s", ip, err) + } + + return "", true + } + + item, ok := val.(*cacheItem) + if !ok { + log.Debug("rdns: cache: %q bad type %T", ip, val) + + return "", true + } + + return item.host, time.Now().After(item.expiry) +} + +// cacheItem represents an item that we will store in the cache. +type cacheItem struct { + // expiry is the time when cacheItem will expire. + expiry time.Time + + // host is the domain name of a runtime client. + host string +} diff --git a/internal/rdns/rdns_test.go b/internal/rdns/rdns_test.go new file mode 100644 index 00000000..8694eba3 --- /dev/null +++ b/internal/rdns/rdns_test.go @@ -0,0 +1,105 @@ +package rdns_test + +import ( + "net/netip" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/rdns" + "github.com/AdguardTeam/golibs/netutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeRDNSExchanger is a mock [rdns.Exchanger] implementation for tests. +type fakeRDNSExchanger struct { + OnExchange func(ip netip.Addr) (host string, err error) +} + +// type check +var _ rdns.Exchanger = (*fakeRDNSExchanger)(nil) + +// Exchange implements [rdns.Exchanger] interface for *fakeRDNSExchanger. +func (e *fakeRDNSExchanger) Exchange(ip netip.Addr) (host string, err error) { + return e.OnExchange(ip) +} + +func TestDefault_Process(t *testing.T) { + ip1 := netip.MustParseAddr("1.2.3.4") + revAddr1, err := netutil.IPToReversedAddr(ip1.AsSlice()) + require.NoError(t, err) + + ip2 := netip.MustParseAddr("4.3.2.1") + revAddr2, err := netutil.IPToReversedAddr(ip2.AsSlice()) + require.NoError(t, err) + + localIP := netip.MustParseAddr("192.168.0.1") + localRevAddr1, err := netutil.IPToReversedAddr(localIP.AsSlice()) + require.NoError(t, err) + + config := &rdns.Config{ + CacheSize: 100, + CacheTTL: time.Hour, + } + + testCases := []struct { + name string + addr netip.Addr + want string + }{{ + name: "first", + addr: ip1, + want: revAddr1, + }, { + name: "second", + addr: ip2, + want: revAddr2, + }, { + name: "empty", + addr: netip.MustParseAddr("0.0.0.0"), + want: "", + }, { + name: "private", + addr: localIP, + want: localRevAddr1, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + hit := 0 + onExchange := func(ip netip.Addr) (host string, err error) { + hit++ + + switch ip { + case ip1: + return revAddr1, nil + case ip2: + return revAddr2, nil + case localIP: + return localRevAddr1, nil + default: + return "", nil + } + } + exchanger := &fakeRDNSExchanger{ + OnExchange: onExchange, + } + + config.Exchanger = exchanger + r := rdns.New(config) + + got, changed := r.Process(tc.addr) + require.True(t, changed) + + assert.Equal(t, tc.want, got) + assert.Equal(t, 1, hit) + + // From cache. + got, changed = r.Process(tc.addr) + require.False(t, changed) + + assert.Equal(t, tc.want, got) + assert.Equal(t, 1, hit) + }) + } +} diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 3409ed2b..16e589b3 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -177,6 +177,7 @@ run_linter gocognit --over 10\ ./internal/aghhttp/\ ./internal/aghio/\ ./internal/next/\ + ./internal/rdns/\ ./internal/tools/\ ./internal/version/\ ./internal/whois/\