Pull request 1938: AG-24132-rdns-ttl

Squashed commit of the following:

commit ba1e7b12cf7c0dc3ffab508d59c149f6c0930548
Merge: 8a94433ec ed86af582
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Jul 25 13:43:25 2023 +0300

    Merge branch 'master' into AG-24132-rdns-ttl

commit 8a94433ec119d2158c166dd0222f57917908f3ad
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Jul 24 19:30:21 2023 +0300

    all: imp docs

commit 4c1a3676b7be7ac4295c4e28550ddb6eb79a35d4
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Jul 24 13:13:34 2023 +0300

    all: add rdns ttl
This commit is contained in:
Stanislav Chzhen 2023-07-25 14:16:26 +03:00
parent ed86af582a
commit 996c6b3ee3
6 changed files with 97 additions and 29 deletions

View File

@ -4,6 +4,7 @@ import (
"context"
"net"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/client"
@ -131,14 +132,14 @@ func (r *Resolver) LookupIP(ctx context.Context, network, host string) (ips []ne
// Exchanger is a fake [rdns.Exchanger] implementation for tests.
type Exchanger struct {
OnExchange func(ip netip.Addr) (host string, err error)
OnExchange func(ip netip.Addr) (host string, ttl time.Duration, err error)
}
// type check
var _ rdns.Exchanger = (*Exchanger)(nil)
// Exchange implements [rdns.Exchanger] interface for *Exchanger.
func (e *Exchanger) Exchange(ip netip.Addr) (host string, err error) {
func (e *Exchanger) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) {
return e.OnExchange(ip)
}

View File

@ -104,8 +104,8 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
panic("not implemented")
},
Exchanger: &aghtest.Exchanger{
OnExchange: func(ip netip.Addr) (host string, err error) {
return tc.host, tc.rdnsErr
OnExchange: func(ip netip.Addr) (host string, ttl time.Duration, err error) {
return tc.host, 0, tc.rdnsErr
},
},
PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
@ -214,7 +214,7 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
return whoisConn, nil
},
Exchanger: &aghtest.Exchanger{
OnExchange: func(_ netip.Addr) (host string, err error) {
OnExchange: func(_ netip.Addr) (_ string, _ time.Duration, _ error) {
panic("not implemented")
},
},

View File

@ -316,13 +316,13 @@ const (
var _ rdns.Exchanger = (*Server)(nil)
// Exchange implements the [rdns.Exchanger] interface for *Server.
func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil {
return "", fmt.Errorf("reversing ip: %w", err)
return "", 0, fmt.Errorf("reversing ip: %w", err)
}
arpa = dns.Fqdn(arpa)
@ -348,7 +348,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
var resolver *proxy.Proxy
if s.privateNets.Contains(ip.AsSlice()) {
if !s.conf.UsePrivateRDNS {
return "", nil
return "", 0, nil
}
resolver = s.localResolvers
@ -358,31 +358,47 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
}
if err = resolver.Resolve(dctx); err != nil {
return "", err
return "", 0, err
}
return hostFromPTR(dctx.Res)
}
// hostFromPTR returns domain name from the PTR response or error.
func hostFromPTR(resp *dns.Msg) (host string, err error) {
func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) {
// Distinguish between NODATA response and a failed request.
if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError {
return "", fmt.Errorf(
return "", 0, fmt.Errorf(
"received %s response: %w",
dns.RcodeToString[resp.Rcode],
ErrRDNSFailed,
)
}
var ttlSec uint32
for _, ans := range resp.Answer {
ptr, ok := ans.(*dns.PTR)
if ok {
return strings.TrimSuffix(ptr.Ptr, "."), nil
if !ok {
continue
}
if ptr.Hdr.Ttl > ttlSec {
host = ptr.Ptr
ttlSec = ptr.Hdr.Ttl
}
}
return "", ErrRDNSNoData
if host != "" {
// NOTE: Don't use [aghnet.NormalizeDomain] to retain original letter
// case.
host = strings.TrimSuffix(host, ".")
ttl = time.Duration(ttlSec) * time.Second
return host, ttl, nil
}
return "", 0, ErrRDNSNoData
}
// Start starts the DNS server.

View File

@ -1300,25 +1300,57 @@ func TestNewServer(t *testing.T) {
}
}
// doubleTTL is a helper function that returns a clone of DNS PTR with appended
// copy of first answer record with doubled TTL.
func doubleTTL(msg *dns.Msg) (resp *dns.Msg) {
if msg == nil {
return nil
}
if len(msg.Answer) == 0 {
return msg
}
rec := msg.Answer[0]
ptr, ok := rec.(*dns.PTR)
if !ok {
return msg
}
clone := *ptr
clone.Hdr.Ttl *= 2
msg.Answer = append(msg.Answer, &clone)
return msg
}
func TestServer_Exchange(t *testing.T) {
const (
onesHost = "one.one.one.one"
twosHost = "two.two.two.two"
localDomainHost = "local.domain"
defaultTTL = time.Second * 60
)
var (
onesIP = netip.MustParseAddr("1.1.1.1")
twosIP = netip.MustParseAddr("2.2.2.2")
localIP = netip.MustParseAddr("192.168.1.1")
)
revExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
require.NoError(t, err)
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
require.NoError(t, err)
extUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "external.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, revExtIPv4, onesHost),
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost),
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil
},
@ -1358,47 +1390,61 @@ func TestServer_Exchange(t *testing.T) {
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct {
name string
want string
req netip.Addr
wantErr error
locUpstream upstream.Upstream
req netip.Addr
name string
want string
wantTTL time.Duration
}{{
name: "external_good",
want: onesHost,
wantErr: nil,
locUpstream: nil,
req: onesIP,
wantTTL: defaultTTL,
}, {
name: "local_good",
want: localDomainHost,
wantErr: nil,
locUpstream: locUpstream,
req: localIP,
wantTTL: defaultTTL,
}, {
name: "upstream_error",
want: "",
wantErr: aghtest.ErrUpstream,
locUpstream: errUpstream,
req: localIP,
wantTTL: 0,
}, {
name: "empty_answer_error",
want: "",
wantErr: ErrRDNSNoData,
locUpstream: locUpstream,
req: netip.MustParseAddr("192.168.1.2"),
wantTTL: 0,
}, {
name: "invalid_answer",
want: "",
wantErr: ErrRDNSNoData,
locUpstream: nonPtrUpstream,
req: localIP,
wantTTL: 0,
}, {
name: "refused",
want: "",
wantErr: ErrRDNSFailed,
locUpstream: refusingUpstream,
req: localIP,
wantTTL: 0,
}, {
name: "longest_ttl",
want: twosHost,
wantErr: nil,
locUpstream: nil,
req: twosIP,
wantTTL: defaultTTL * 2,
}}
for _, tc := range testCases {
@ -1412,17 +1458,18 @@ func TestServer_Exchange(t *testing.T) {
}
t.Run(tc.name, func(t *testing.T) {
host, eerr := srv.Exchange(tc.req)
host, ttl, eerr := srv.Exchange(tc.req)
require.ErrorIs(t, eerr, tc.wantErr)
assert.Equal(t, tc.want, host)
assert.Equal(t, tc.wantTTL, ttl)
})
}
t.Run("resolving_disabled", func(t *testing.T) {
srv.conf.UsePrivateRDNS = false
host, eerr := srv.Exchange(localIP)
host, _, eerr := srv.Exchange(localIP)
require.NoError(t, eerr)
assert.Empty(t, host)

View File

@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/bluele/gcache"
)
@ -32,7 +33,7 @@ func (Empty) Process(_ netip.Addr) (host string, changed bool) {
type Exchanger interface {
// Exchange tries to resolve the ip in a suitable way, i.e. either as local
// or as external.
Exchange(ip netip.Addr) (host string, err error)
Exchange(ip netip.Addr) (host string, ttl time.Duration, err error)
}
// Config is the configuration structure for Default.
@ -82,13 +83,16 @@ func (r *Default) Process(ip netip.Addr) (host string, changed bool) {
return fromCache, false
}
host, err := r.exchanger.Exchange(ip)
host, ttl, err := r.exchanger.Exchange(ip)
if err != nil {
log.Debug("rdns: resolving %q: %s", ip, err)
}
// TODO(s.chzhen): Use built-in function max in Go 1.21.
ttl = mathutil.Max(ttl, r.cacheTTL)
item := &cacheItem{
expiry: time.Now().Add(r.cacheTTL),
expiry: time.Now().Add(ttl),
host: host,
}

View File

@ -55,18 +55,18 @@ func TestDefault_Process(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
hit := 0
onExchange := func(ip netip.Addr) (host string, err error) {
onExchange := func(ip netip.Addr) (host string, ttl time.Duration, err error) {
hit++
switch ip {
case ip1:
return revAddr1, nil
return revAddr1, 0, nil
case ip2:
return revAddr2, nil
return revAddr2, 0, nil
case localIP:
return localRevAddr1, nil
return localRevAddr1, 0, nil
default:
return "", nil
return "", 0, nil
}
}
exchanger := &aghtest.Exchanger{