diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index a84e9af6..10789d8e 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/netip" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/client" @@ -131,14 +132,14 @@ func (r *Resolver) LookupIP(ctx context.Context, network, host string) (ips []ne // Exchanger is a fake [rdns.Exchanger] implementation for tests. type Exchanger struct { - OnExchange func(ip netip.Addr) (host string, err error) + OnExchange func(ip netip.Addr) (host string, ttl time.Duration, err error) } // type check var _ rdns.Exchanger = (*Exchanger)(nil) // Exchange implements [rdns.Exchanger] interface for *Exchanger. -func (e *Exchanger) Exchange(ip netip.Addr) (host string, err error) { +func (e *Exchanger) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) { return e.OnExchange(ip) } diff --git a/internal/client/addrproc_test.go b/internal/client/addrproc_test.go index 5690be37..c6b847cd 100644 --- a/internal/client/addrproc_test.go +++ b/internal/client/addrproc_test.go @@ -104,8 +104,8 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) { panic("not implemented") }, Exchanger: &aghtest.Exchanger{ - OnExchange: func(ip netip.Addr) (host string, err error) { - return tc.host, tc.rdnsErr + OnExchange: func(ip netip.Addr) (host string, ttl time.Duration, err error) { + return tc.host, 0, tc.rdnsErr }, }, PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed), @@ -214,7 +214,7 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) { return whoisConn, nil }, Exchanger: &aghtest.Exchanger{ - OnExchange: func(_ netip.Addr) (host string, err error) { + OnExchange: func(_ netip.Addr) (_ string, _ time.Duration, _ error) { panic("not implemented") }, }, diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index a2f77950..730e88f8 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -316,13 +316,13 @@ const ( var _ rdns.Exchanger = (*Server)(nil) // Exchange implements the [rdns.Exchanger] interface for *Server. -func (s *Server) Exchange(ip netip.Addr) (host string, err error) { +func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) { s.serverLock.RLock() defer s.serverLock.RUnlock() arpa, err := netutil.IPToReversedAddr(ip.AsSlice()) if err != nil { - return "", fmt.Errorf("reversing ip: %w", err) + return "", 0, fmt.Errorf("reversing ip: %w", err) } arpa = dns.Fqdn(arpa) @@ -348,7 +348,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) { var resolver *proxy.Proxy if s.privateNets.Contains(ip.AsSlice()) { if !s.conf.UsePrivateRDNS { - return "", nil + return "", 0, nil } resolver = s.localResolvers @@ -358,31 +358,47 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) { } if err = resolver.Resolve(dctx); err != nil { - return "", err + return "", 0, err } return hostFromPTR(dctx.Res) } // hostFromPTR returns domain name from the PTR response or error. -func hostFromPTR(resp *dns.Msg) (host string, err error) { +func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) { // Distinguish between NODATA response and a failed request. if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError { - return "", fmt.Errorf( + return "", 0, fmt.Errorf( "received %s response: %w", dns.RcodeToString[resp.Rcode], ErrRDNSFailed, ) } + var ttlSec uint32 + for _, ans := range resp.Answer { ptr, ok := ans.(*dns.PTR) - if ok { - return strings.TrimSuffix(ptr.Ptr, "."), nil + if !ok { + continue + } + + if ptr.Hdr.Ttl > ttlSec { + host = ptr.Ptr + ttlSec = ptr.Hdr.Ttl } } - return "", ErrRDNSNoData + if host != "" { + // NOTE: Don't use [aghnet.NormalizeDomain] to retain original letter + // case. + host = strings.TrimSuffix(host, ".") + ttl = time.Duration(ttlSec) * time.Second + + return host, ttl, nil + } + + return "", 0, ErrRDNSNoData } // Start starts the DNS server. diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index eb077aff..775a97b5 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -1300,25 +1300,57 @@ func TestNewServer(t *testing.T) { } } +// doubleTTL is a helper function that returns a clone of DNS PTR with appended +// copy of first answer record with doubled TTL. +func doubleTTL(msg *dns.Msg) (resp *dns.Msg) { + if msg == nil { + return nil + } + + if len(msg.Answer) == 0 { + return msg + } + + rec := msg.Answer[0] + ptr, ok := rec.(*dns.PTR) + if !ok { + return msg + } + + clone := *ptr + clone.Hdr.Ttl *= 2 + msg.Answer = append(msg.Answer, &clone) + + return msg +} + func TestServer_Exchange(t *testing.T) { const ( onesHost = "one.one.one.one" + twosHost = "two.two.two.two" localDomainHost = "local.domain" + + defaultTTL = time.Second * 60 ) var ( onesIP = netip.MustParseAddr("1.1.1.1") + twosIP = netip.MustParseAddr("2.2.2.2") localIP = netip.MustParseAddr("192.168.1.1") ) - revExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice()) + onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice()) + require.NoError(t, err) + + twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice()) require.NoError(t, err) extUpstream := &aghtest.UpstreamMock{ OnAddress: func() (addr string) { return "external.upstream.example" }, OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { return aghalg.Coalesce( - aghtest.MatchedResponse(req, dns.TypePTR, revExtIPv4, onesHost), + aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost), + doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)), new(dns.Msg).SetRcode(req, dns.RcodeNameError), ), nil }, @@ -1358,47 +1390,61 @@ func TestServer_Exchange(t *testing.T) { srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) testCases := []struct { - name string - want string + req netip.Addr wantErr error locUpstream upstream.Upstream - req netip.Addr + name string + want string + wantTTL time.Duration }{{ name: "external_good", want: onesHost, wantErr: nil, locUpstream: nil, req: onesIP, + wantTTL: defaultTTL, }, { name: "local_good", want: localDomainHost, wantErr: nil, locUpstream: locUpstream, req: localIP, + wantTTL: defaultTTL, }, { name: "upstream_error", want: "", wantErr: aghtest.ErrUpstream, locUpstream: errUpstream, req: localIP, + wantTTL: 0, }, { name: "empty_answer_error", want: "", wantErr: ErrRDNSNoData, locUpstream: locUpstream, req: netip.MustParseAddr("192.168.1.2"), + wantTTL: 0, }, { name: "invalid_answer", want: "", wantErr: ErrRDNSNoData, locUpstream: nonPtrUpstream, req: localIP, + wantTTL: 0, }, { name: "refused", want: "", wantErr: ErrRDNSFailed, locUpstream: refusingUpstream, req: localIP, + wantTTL: 0, + }, { + name: "longest_ttl", + want: twosHost, + wantErr: nil, + locUpstream: nil, + req: twosIP, + wantTTL: defaultTTL * 2, }} for _, tc := range testCases { @@ -1412,17 +1458,18 @@ func TestServer_Exchange(t *testing.T) { } t.Run(tc.name, func(t *testing.T) { - host, eerr := srv.Exchange(tc.req) + host, ttl, eerr := srv.Exchange(tc.req) require.ErrorIs(t, eerr, tc.wantErr) assert.Equal(t, tc.want, host) + assert.Equal(t, tc.wantTTL, ttl) }) } t.Run("resolving_disabled", func(t *testing.T) { srv.conf.UsePrivateRDNS = false - host, eerr := srv.Exchange(localIP) + host, _, eerr := srv.Exchange(localIP) require.NoError(t, eerr) assert.Empty(t, host) diff --git a/internal/rdns/rdns.go b/internal/rdns/rdns.go index ff1998a5..b33e212c 100644 --- a/internal/rdns/rdns.go +++ b/internal/rdns/rdns.go @@ -7,6 +7,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/mathutil" "github.com/bluele/gcache" ) @@ -32,7 +33,7 @@ func (Empty) Process(_ netip.Addr) (host string, changed bool) { type Exchanger interface { // Exchange tries to resolve the ip in a suitable way, i.e. either as local // or as external. - Exchange(ip netip.Addr) (host string, err error) + Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) } // Config is the configuration structure for Default. @@ -82,13 +83,16 @@ func (r *Default) Process(ip netip.Addr) (host string, changed bool) { return fromCache, false } - host, err := r.exchanger.Exchange(ip) + host, ttl, err := r.exchanger.Exchange(ip) if err != nil { log.Debug("rdns: resolving %q: %s", ip, err) } + // TODO(s.chzhen): Use built-in function max in Go 1.21. + ttl = mathutil.Max(ttl, r.cacheTTL) + item := &cacheItem{ - expiry: time.Now().Add(r.cacheTTL), + expiry: time.Now().Add(ttl), host: host, } diff --git a/internal/rdns/rdns_test.go b/internal/rdns/rdns_test.go index 642e4567..61130ec5 100644 --- a/internal/rdns/rdns_test.go +++ b/internal/rdns/rdns_test.go @@ -55,18 +55,18 @@ func TestDefault_Process(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { hit := 0 - onExchange := func(ip netip.Addr) (host string, err error) { + onExchange := func(ip netip.Addr) (host string, ttl time.Duration, err error) { hit++ switch ip { case ip1: - return revAddr1, nil + return revAddr1, 0, nil case ip2: - return revAddr2, nil + return revAddr2, 0, nil case localIP: - return localRevAddr1, nil + return localRevAddr1, 0, nil default: - return "", nil + return "", 0, nil } } exchanger := &aghtest.Exchanger{