Pull request 1938: AG-24132-rdns-ttl

Squashed commit of the following:

commit ba1e7b12cf7c0dc3ffab508d59c149f6c0930548
Merge: 8a94433ec ed86af582
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Jul 25 13:43:25 2023 +0300

    Merge branch 'master' into AG-24132-rdns-ttl

commit 8a94433ec119d2158c166dd0222f57917908f3ad
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Jul 24 19:30:21 2023 +0300

    all: imp docs

commit 4c1a3676b7be7ac4295c4e28550ddb6eb79a35d4
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Jul 24 13:13:34 2023 +0300

    all: add rdns ttl
This commit is contained in:
Stanislav Chzhen 2023-07-25 14:16:26 +03:00
parent ed86af582a
commit 996c6b3ee3
6 changed files with 97 additions and 29 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"net" "net"
"net/netip" "net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
@ -131,14 +132,14 @@ func (r *Resolver) LookupIP(ctx context.Context, network, host string) (ips []ne
// Exchanger is a fake [rdns.Exchanger] implementation for tests. // Exchanger is a fake [rdns.Exchanger] implementation for tests.
type Exchanger struct { type Exchanger struct {
OnExchange func(ip netip.Addr) (host string, err error) OnExchange func(ip netip.Addr) (host string, ttl time.Duration, err error)
} }
// type check // type check
var _ rdns.Exchanger = (*Exchanger)(nil) var _ rdns.Exchanger = (*Exchanger)(nil)
// Exchange implements [rdns.Exchanger] interface for *Exchanger. // Exchange implements [rdns.Exchanger] interface for *Exchanger.
func (e *Exchanger) Exchange(ip netip.Addr) (host string, err error) { func (e *Exchanger) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) {
return e.OnExchange(ip) return e.OnExchange(ip)
} }

View File

@ -104,8 +104,8 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
panic("not implemented") panic("not implemented")
}, },
Exchanger: &aghtest.Exchanger{ Exchanger: &aghtest.Exchanger{
OnExchange: func(ip netip.Addr) (host string, err error) { OnExchange: func(ip netip.Addr) (host string, ttl time.Duration, err error) {
return tc.host, tc.rdnsErr return tc.host, 0, tc.rdnsErr
}, },
}, },
PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed), PrivateSubnets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
@ -214,7 +214,7 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
return whoisConn, nil return whoisConn, nil
}, },
Exchanger: &aghtest.Exchanger{ Exchanger: &aghtest.Exchanger{
OnExchange: func(_ netip.Addr) (host string, err error) { OnExchange: func(_ netip.Addr) (_ string, _ time.Duration, _ error) {
panic("not implemented") panic("not implemented")
}, },
}, },

View File

@ -316,13 +316,13 @@ const (
var _ rdns.Exchanger = (*Server)(nil) var _ rdns.Exchanger = (*Server)(nil)
// Exchange implements the [rdns.Exchanger] interface for *Server. // Exchange implements the [rdns.Exchanger] interface for *Server.
func (s *Server) Exchange(ip netip.Addr) (host string, err error) { func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error) {
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
arpa, err := netutil.IPToReversedAddr(ip.AsSlice()) arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil { if err != nil {
return "", fmt.Errorf("reversing ip: %w", err) return "", 0, fmt.Errorf("reversing ip: %w", err)
} }
arpa = dns.Fqdn(arpa) arpa = dns.Fqdn(arpa)
@ -348,7 +348,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
var resolver *proxy.Proxy var resolver *proxy.Proxy
if s.privateNets.Contains(ip.AsSlice()) { if s.privateNets.Contains(ip.AsSlice()) {
if !s.conf.UsePrivateRDNS { if !s.conf.UsePrivateRDNS {
return "", nil return "", 0, nil
} }
resolver = s.localResolvers resolver = s.localResolvers
@ -358,31 +358,47 @@ func (s *Server) Exchange(ip netip.Addr) (host string, err error) {
} }
if err = resolver.Resolve(dctx); err != nil { if err = resolver.Resolve(dctx); err != nil {
return "", err return "", 0, err
} }
return hostFromPTR(dctx.Res) return hostFromPTR(dctx.Res)
} }
// hostFromPTR returns domain name from the PTR response or error. // hostFromPTR returns domain name from the PTR response or error.
func hostFromPTR(resp *dns.Msg) (host string, err error) { func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) {
// Distinguish between NODATA response and a failed request. // Distinguish between NODATA response and a failed request.
if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError { if resp.Rcode != dns.RcodeSuccess && resp.Rcode != dns.RcodeNameError {
return "", fmt.Errorf( return "", 0, fmt.Errorf(
"received %s response: %w", "received %s response: %w",
dns.RcodeToString[resp.Rcode], dns.RcodeToString[resp.Rcode],
ErrRDNSFailed, ErrRDNSFailed,
) )
} }
var ttlSec uint32
for _, ans := range resp.Answer { for _, ans := range resp.Answer {
ptr, ok := ans.(*dns.PTR) ptr, ok := ans.(*dns.PTR)
if ok { if !ok {
return strings.TrimSuffix(ptr.Ptr, "."), nil continue
}
if ptr.Hdr.Ttl > ttlSec {
host = ptr.Ptr
ttlSec = ptr.Hdr.Ttl
} }
} }
return "", ErrRDNSNoData if host != "" {
// NOTE: Don't use [aghnet.NormalizeDomain] to retain original letter
// case.
host = strings.TrimSuffix(host, ".")
ttl = time.Duration(ttlSec) * time.Second
return host, ttl, nil
}
return "", 0, ErrRDNSNoData
} }
// Start starts the DNS server. // Start starts the DNS server.

View File

@ -1300,25 +1300,57 @@ func TestNewServer(t *testing.T) {
} }
} }
// doubleTTL is a helper function that returns a clone of DNS PTR with appended
// copy of first answer record with doubled TTL.
func doubleTTL(msg *dns.Msg) (resp *dns.Msg) {
if msg == nil {
return nil
}
if len(msg.Answer) == 0 {
return msg
}
rec := msg.Answer[0]
ptr, ok := rec.(*dns.PTR)
if !ok {
return msg
}
clone := *ptr
clone.Hdr.Ttl *= 2
msg.Answer = append(msg.Answer, &clone)
return msg
}
func TestServer_Exchange(t *testing.T) { func TestServer_Exchange(t *testing.T) {
const ( const (
onesHost = "one.one.one.one" onesHost = "one.one.one.one"
twosHost = "two.two.two.two"
localDomainHost = "local.domain" localDomainHost = "local.domain"
defaultTTL = time.Second * 60
) )
var ( var (
onesIP = netip.MustParseAddr("1.1.1.1") onesIP = netip.MustParseAddr("1.1.1.1")
twosIP = netip.MustParseAddr("2.2.2.2")
localIP = netip.MustParseAddr("192.168.1.1") localIP = netip.MustParseAddr("192.168.1.1")
) )
revExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice()) onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
require.NoError(t, err)
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
require.NoError(t, err) require.NoError(t, err)
extUpstream := &aghtest.UpstreamMock{ extUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "external.upstream.example" }, OnAddress: func() (addr string) { return "external.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce( return aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, revExtIPv4, onesHost), aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost),
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil ), nil
}, },
@ -1358,47 +1390,61 @@ func TestServer_Exchange(t *testing.T) {
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct { testCases := []struct {
name string req netip.Addr
want string
wantErr error wantErr error
locUpstream upstream.Upstream locUpstream upstream.Upstream
req netip.Addr name string
want string
wantTTL time.Duration
}{{ }{{
name: "external_good", name: "external_good",
want: onesHost, want: onesHost,
wantErr: nil, wantErr: nil,
locUpstream: nil, locUpstream: nil,
req: onesIP, req: onesIP,
wantTTL: defaultTTL,
}, { }, {
name: "local_good", name: "local_good",
want: localDomainHost, want: localDomainHost,
wantErr: nil, wantErr: nil,
locUpstream: locUpstream, locUpstream: locUpstream,
req: localIP, req: localIP,
wantTTL: defaultTTL,
}, { }, {
name: "upstream_error", name: "upstream_error",
want: "", want: "",
wantErr: aghtest.ErrUpstream, wantErr: aghtest.ErrUpstream,
locUpstream: errUpstream, locUpstream: errUpstream,
req: localIP, req: localIP,
wantTTL: 0,
}, { }, {
name: "empty_answer_error", name: "empty_answer_error",
want: "", want: "",
wantErr: ErrRDNSNoData, wantErr: ErrRDNSNoData,
locUpstream: locUpstream, locUpstream: locUpstream,
req: netip.MustParseAddr("192.168.1.2"), req: netip.MustParseAddr("192.168.1.2"),
wantTTL: 0,
}, { }, {
name: "invalid_answer", name: "invalid_answer",
want: "", want: "",
wantErr: ErrRDNSNoData, wantErr: ErrRDNSNoData,
locUpstream: nonPtrUpstream, locUpstream: nonPtrUpstream,
req: localIP, req: localIP,
wantTTL: 0,
}, { }, {
name: "refused", name: "refused",
want: "", want: "",
wantErr: ErrRDNSFailed, wantErr: ErrRDNSFailed,
locUpstream: refusingUpstream, locUpstream: refusingUpstream,
req: localIP, req: localIP,
wantTTL: 0,
}, {
name: "longest_ttl",
want: twosHost,
wantErr: nil,
locUpstream: nil,
req: twosIP,
wantTTL: defaultTTL * 2,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
@ -1412,17 +1458,18 @@ func TestServer_Exchange(t *testing.T) {
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
host, eerr := srv.Exchange(tc.req) host, ttl, eerr := srv.Exchange(tc.req)
require.ErrorIs(t, eerr, tc.wantErr) require.ErrorIs(t, eerr, tc.wantErr)
assert.Equal(t, tc.want, host) assert.Equal(t, tc.want, host)
assert.Equal(t, tc.wantTTL, ttl)
}) })
} }
t.Run("resolving_disabled", func(t *testing.T) { t.Run("resolving_disabled", func(t *testing.T) {
srv.conf.UsePrivateRDNS = false srv.conf.UsePrivateRDNS = false
host, eerr := srv.Exchange(localIP) host, _, eerr := srv.Exchange(localIP)
require.NoError(t, eerr) require.NoError(t, eerr)
assert.Empty(t, host) assert.Empty(t, host)

View File

@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/mathutil"
"github.com/bluele/gcache" "github.com/bluele/gcache"
) )
@ -32,7 +33,7 @@ func (Empty) Process(_ netip.Addr) (host string, changed bool) {
type Exchanger interface { type Exchanger interface {
// Exchange tries to resolve the ip in a suitable way, i.e. either as local // Exchange tries to resolve the ip in a suitable way, i.e. either as local
// or as external. // or as external.
Exchange(ip netip.Addr) (host string, err error) Exchange(ip netip.Addr) (host string, ttl time.Duration, err error)
} }
// Config is the configuration structure for Default. // Config is the configuration structure for Default.
@ -82,13 +83,16 @@ func (r *Default) Process(ip netip.Addr) (host string, changed bool) {
return fromCache, false return fromCache, false
} }
host, err := r.exchanger.Exchange(ip) host, ttl, err := r.exchanger.Exchange(ip)
if err != nil { if err != nil {
log.Debug("rdns: resolving %q: %s", ip, err) log.Debug("rdns: resolving %q: %s", ip, err)
} }
// TODO(s.chzhen): Use built-in function max in Go 1.21.
ttl = mathutil.Max(ttl, r.cacheTTL)
item := &cacheItem{ item := &cacheItem{
expiry: time.Now().Add(r.cacheTTL), expiry: time.Now().Add(ttl),
host: host, host: host,
} }

View File

@ -55,18 +55,18 @@ func TestDefault_Process(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
hit := 0 hit := 0
onExchange := func(ip netip.Addr) (host string, err error) { onExchange := func(ip netip.Addr) (host string, ttl time.Duration, err error) {
hit++ hit++
switch ip { switch ip {
case ip1: case ip1:
return revAddr1, nil return revAddr1, 0, nil
case ip2: case ip2:
return revAddr2, nil return revAddr2, 0, nil
case localIP: case localIP:
return localRevAddr1, nil return localRevAddr1, 0, nil
default: default:
return "", nil return "", 0, nil
} }
} }
exchanger := &aghtest.Exchanger{ exchanger := &aghtest.Exchanger{