cherry-pick: filtering: restore rewrite behavior with other question types

Updates #4008.

Squashed commit of the following:

commit babbc29331cfc2603c0c3b0987f5ba926690ec3e
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Fri Dec 24 18:46:20 2021 +0300

    filtering: restore rewrite behavior with other question types
This commit is contained in:
Ainar Garipov 2021-12-24 20:14:36 +03:00 committed by Ainar Garipov
parent 9f36e57c1e
commit 84c9085516
5 changed files with 172 additions and 131 deletions

View File

@ -36,9 +36,12 @@ and this project adheres to
### Fixed
- Legacy DNS rewrites responding from upstream when a request other than `A` or
`AAAA` is received ([#4008]).
- Panic on port availability check during installation ([#3987]).
[#3987]: https://github.com/AdguardTeam/AdGuardHome/issues/3987
[#4008]: https://github.com/AdguardTeam/AdGuardHome/issues/4008

View File

@ -507,61 +507,76 @@ func (d *DNSFilter) matchSysHostsIntl(
return res, nil
}
// Process rewrites table
// . Find CNAME for a domain name (exact match or by wildcard)
// . if found and CNAME equals to domain name - this is an exception; exit
// . if found, set domain name to canonical name
// . repeat for the new domain name (Note: we return only the last CNAME)
// . Find A or AAAA record for a domain name (exact match or by wildcard)
// . if found, set IP addresses (IPv4 or IPv6 depending on qtype) in Result.IPList array
// processRewrites performs filtering based on the legacy rewrite records.
//
// Firstly, it finds CNAME rewrites for host. If the CNAME is the same as host,
// this query isn't filtered. If it's different, repeat the process for the new
// CNAME, breaking loops in the process.
//
// Secondly, it finds A or AAAA rewrites for host and, if found, sets res.IPList
// accordingly. If the found rewrite has a special value of "A" or "AAAA", the
// result is an exception.
func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
d.confLock.RLock()
defer d.confLock.RUnlock()
rr := findRewrites(d.Rewrites, host, qtype)
if len(rr) != 0 {
res.Reason = Rewritten
rewrites, matched := findRewrites(d.Rewrites, host, qtype)
if !matched {
return Result{}
}
res.Reason = Rewritten
cnames := stringutil.NewSet()
origHost := host
for len(rr) != 0 && rr[0].Type == dns.TypeCNAME {
log.Debug("rewrite: CNAME for %s is %s", host, rr[0].Answer)
for matched && len(rewrites) > 0 && rewrites[0].Type == dns.TypeCNAME {
rwAns := rewrites[0].Answer
if host == rr[0].Answer { // "host == CNAME" is an exception
log.Debug("rewrite: cname for %s is %s", host, rwAns)
if host == rwAns {
// Rewrite of a domain onto itself is an exception rule.
res.Reason = NotFilteredNotFound
return res
}
host = rr[0].Answer
host = rwAns
if cnames.Has(host) {
log.Info("rewrite: breaking CNAME redirection loop: %s. Question: %s", host, origHost)
log.Info("rewrite: cname loop for %q on %q", origHost, host)
return res
}
cnames.Add(host)
res.CanonName = rr[0].Answer
rr = findRewrites(d.Rewrites, host, qtype)
res.CanonName = host
rewrites, matched = findRewrites(d.Rewrites, host, qtype)
}
for _, r := range rr {
if r.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
if r.IP == nil { // IP exception
res.Reason = NotFilteredNotFound
return res
}
res.IPList = append(res.IPList, r.IP)
log.Debug("rewrite: A/AAAA for %s is %s", host, r.IP)
}
}
setRewriteResult(&res, host, rewrites, qtype)
return res
}
// setRewriteResult sets the Reason or IPList of res if necessary. res must not
// be nil.
func setRewriteResult(res *Result, host string, rewrites []RewriteEntry, qtype uint16) {
for _, rw := range rewrites {
if rw.Type == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
if rw.IP == nil {
// "A"/"AAAA" exception: allow getting from upstream.
res.Reason = NotFilteredNotFound
return
}
res.IPList = append(res.IPList, rw.IP)
log.Debug("rewrite: a/aaaa for %s is %s", host, rw.IP)
}
}
}
// matchBlockedServicesRules checks the host against the blocked services rules
// in settings, if any. The err is always nil, it is only there to make this
// a valid hostChecker function.

View File

@ -18,12 +18,15 @@ import (
type RewriteEntry struct {
// Domain is the domain for which this rewrite should work.
Domain string `yaml:"domain"`
// Answer is the IP address, canonical name, or one of the special
// values: "A" or "AAAA".
Answer string `yaml:"answer"`
// IP is the IP address that should be used in the response if Type is
// A or AAAA.
IP net.IP `yaml:"-"`
// Type is the DNS record type: A, AAAA, or CNAME.
Type uint16 `yaml:"-"`
}
@ -143,39 +146,46 @@ func (d *DNSFilter) prepareRewrites() {
}
}
// findRewrites returns the list of matched rewrite entries. The priority is:
// CNAME, then A and AAAA; exact, then wildcard. If the host is matched
// exactly, wildcard entries aren't returned. If the host matched by wildcards,
// return the most specific for the question type.
func findRewrites(entries []RewriteEntry, host string, qtype uint16) (matched []RewriteEntry) {
rr := rewritesSorted{}
// findRewrites returns the list of matched rewrite entries. If rewrites are
// empty, but matched is true, the domain is found among the rewrite rules but
// not for this question type.
//
// The result priority is: CNAME, then A and AAAA; exact, then wildcard. If the
// host is matched exactly, wildcard entries aren't returned. If the host
// matched by wildcards, return the most specific for the question type.
func findRewrites(
entries []RewriteEntry,
host string,
qtype uint16,
) (rewrites []RewriteEntry, matched bool) {
for _, e := range entries {
if e.Domain != host && !matchDomainWildcard(host, e.Domain) {
continue
}
matched = true
if e.matchesQType(qtype) {
rr = append(rr, e)
rewrites = append(rewrites, e)
}
}
if len(rr) == 0 {
return nil
if len(rewrites) == 0 {
return nil, matched
}
sort.Sort(rr)
sort.Sort(rewritesSorted(rewrites))
for i, r := range rr {
for i, r := range rewrites {
if isWildcard(r.Domain) {
// Don't use rr[:0], because we need to return at least
// one item here.
rr = rr[:max(1, i)]
// Don't use rewrites[:0], because we need to return at least one
// item here.
rewrites = rewrites[:max(1, i)]
break
}
}
return rr
return rewrites, matched
}
func max(a, b int) int {
@ -230,8 +240,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
d.confLock.Lock()
d.Config.Rewrites = append(d.Config.Rewrites, ent)
d.confLock.Unlock()
log.Debug("Rewrites: added element: %s -> %s [%d]",
ent.Domain, ent.Answer, len(d.Config.Rewrites))
log.Debug("rewrite: added element: %s -> %s [%d]", ent.Domain, ent.Answer, len(d.Config.Rewrites))
d.Config.ConfigModified()
}
@ -253,9 +262,11 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
d.confLock.Lock()
for _, ent := range d.Config.Rewrites {
if ent.equal(entDel) {
log.Debug("Rewrites: removed element: %s -> %s", ent.Domain, ent.Answer)
log.Debug("rewrite: removed element: %s -> %s", ent.Domain, ent.Answer)
continue
}
arr = append(arr, ent)
}
d.Config.Rewrites = arr

View File

@ -70,94 +70,101 @@ func TestRewrites(t *testing.T) {
d.prepareRewrites()
testCases := []struct {
name string
host string
wantCName string
wantVals []net.IP
dtyp uint16
name string
host string
wantCName string
wantIPs []net.IP
wantReason Reason
dtyp uint16
}{{
name: "not_filtered_not_found",
host: "hoost.com",
wantCName: "",
wantVals: nil,
dtyp: dns.TypeA,
name: "not_filtered_not_found",
host: "hoost.com",
wantCName: "",
wantIPs: nil,
wantReason: NotFilteredNotFound,
dtyp: dns.TypeA,
}, {
name: "rewritten_a",
host: "www.host.com",
wantCName: "host.com",
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "rewritten_a",
host: "www.host.com",
wantCName: "host.com",
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "rewritten_aaaa",
host: "www.host.com",
wantCName: "host.com",
wantVals: []net.IP{net.ParseIP("1:2:3::4")},
dtyp: dns.TypeAAAA,
name: "rewritten_aaaa",
host: "www.host.com",
wantCName: "host.com",
wantIPs: []net.IP{net.ParseIP("1:2:3::4")},
wantReason: Rewritten,
dtyp: dns.TypeAAAA,
}, {
name: "wildcard_match",
host: "abc.host.com",
wantCName: "",
wantVals: []net.IP{{1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "wildcard_match",
host: "abc.host.com",
wantCName: "",
wantIPs: []net.IP{{1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "wildcard_override",
host: "a.host.com",
wantCName: "",
wantVals: []net.IP{{1, 2, 3, 4}},
dtyp: dns.TypeA,
name: "wildcard_override",
host: "a.host.com",
wantCName: "",
wantIPs: []net.IP{{1, 2, 3, 4}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "wildcard_cname_interaction",
host: "www.host2.com",
wantCName: "host.com",
wantVals: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "wildcard_cname_interaction",
host: "www.host2.com",
wantCName: "host.com",
wantIPs: []net.IP{{1, 2, 3, 4}, {1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "two_cnames",
host: "b.host.com",
wantCName: "somehost.com",
wantVals: []net.IP{{0, 0, 0, 0}},
dtyp: dns.TypeA,
name: "two_cnames",
host: "b.host.com",
wantCName: "somehost.com",
wantIPs: []net.IP{{0, 0, 0, 0}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "two_cnames_and_wildcard",
host: "b.host3.com",
wantCName: "x.host.com",
wantVals: []net.IP{{1, 2, 3, 5}},
dtyp: dns.TypeA,
name: "two_cnames_and_wildcard",
host: "b.host3.com",
wantCName: "x.host.com",
wantIPs: []net.IP{{1, 2, 3, 5}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "issue3343",
host: "www.hostboth.com",
wantCName: "",
wantVals: []net.IP{net.ParseIP("1234::5678")},
dtyp: dns.TypeAAAA,
name: "issue3343",
host: "www.hostboth.com",
wantCName: "",
wantIPs: []net.IP{net.ParseIP("1234::5678")},
wantReason: Rewritten,
dtyp: dns.TypeAAAA,
}, {
name: "issue3351",
host: "bighost.com",
wantCName: "",
wantVals: []net.IP{{1, 2, 3, 7}},
dtyp: dns.TypeA,
name: "issue3351",
host: "bighost.com",
wantCName: "",
wantIPs: []net.IP{{1, 2, 3, 7}},
wantReason: Rewritten,
dtyp: dns.TypeA,
}, {
name: "issue4008",
host: "somehost.com",
wantCName: "",
wantIPs: nil,
wantReason: Rewritten,
dtyp: dns.TypeHTTPS,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
valsNum := len(tc.wantVals)
r := d.processRewrites(tc.host, tc.dtyp)
if valsNum == 0 {
assert.Equal(t, NotFilteredNotFound, r.Reason)
return
}
require.Equalf(t, Rewritten, r.Reason, "got %s", r.Reason)
require.Equalf(t, tc.wantReason, r.Reason, "got %s", r.Reason)
if tc.wantCName != "" {
assert.Equal(t, tc.wantCName, r.CanonName)
}
require.Len(t, r.IPList, valsNum)
for i, ip := range tc.wantVals {
assert.Equal(t, ip, r.IPList[i])
}
assert.Equal(t, tc.wantIPs, r.IPList)
})
}
}
@ -229,15 +236,17 @@ func TestRewritesExceptionCNAME(t *testing.T) {
host string
want net.IP
}{{
name: "match_sub-domain",
name: "match_subdomain",
host: "my.host.com",
want: net.IP{2, 2, 2, 2},
}, {
name: "exception_cname",
host: "sub.host.com",
want: nil,
}, {
name: "exception_wildcard",
host: "my.sub.host.com",
want: nil,
}}
for _, tc := range testCases {

View File

@ -22,16 +22,17 @@ import (
// getAddrsResponse is the response for /install/get_addresses endpoint.
type getAddrsResponse struct {
Interfaces map[string]*aghnet.NetInterface `json:"interfaces"`
WebPort int `json:"web_port"`
DNSPort int `json:"dns_port"`
Interfaces map[string]*aghnet.NetInterface `json:"interfaces"`
}
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data := getAddrsResponse{}
data.WebPort = defaultPortHTTP
data.DNSPort = defaultPortDNS
data := getAddrsResponse{
WebPort: defaultPortHTTP,
DNSPort: defaultPortDNS,
}
ifaces, err := aghnet.GetValidNetInterfacesForWeb()
if err != nil {
@ -61,8 +62,8 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
}
type checkConfigReqEnt struct {
Port int `json:"port"`
IP net.IP `json:"ip"`
Port int `json:"port"`
Autofix bool `json:"autofix"`
}
@ -84,9 +85,9 @@ type staticIPJSON struct {
}
type checkConfigResp struct {
StaticIP staticIPJSON `json:"static_ip"`
Web checkConfigRespEnt `json:"web"`
DNS checkConfigRespEnt `json:"dns"`
StaticIP staticIPJSON `json:"static_ip"`
}
// Check if ports are available, respond with results
@ -298,10 +299,11 @@ func shutdownSrv(ctx context.Context, srv *http.Server) {
err := srv.Shutdown(ctx)
if err != nil {
const msgFmt = "shutting down http server %q: %s"
if errors.Is(err, context.Canceled) {
log.Debug("shutting down http server %q: %s", srv.Addr, err)
log.Debug(msgFmt, srv.Addr, err)
} else {
log.Error("shutting down http server %q: %s", srv.Addr, err)
log.Error(msgFmt, srv.Addr, err)
}
}
}
@ -436,8 +438,8 @@ func (web *Web) registerInstallHandlers() {
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default checkConfigReqEnt.
type checkConfigReqEntBeta struct {
Port int `json:"port"`
IP []net.IP `json:"ip"`
Port int `json:"port"`
Autofix bool `json:"autofix"`
}
@ -474,13 +476,13 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
nonBetaReqData := checkConfigReq{
Web: checkConfigReqEnt{
Port: reqData.Web.Port,
IP: reqData.Web.IP[0],
Port: reqData.Web.Port,
Autofix: reqData.Web.Autofix,
},
DNS: checkConfigReqEnt{
Port: reqData.DNS.Port,
IP: reqData.DNS.IP[0],
Port: reqData.DNS.Port,
Autofix: reqData.DNS.Autofix,
},
SetStaticIP: reqData.SetStaticIP,
@ -589,9 +591,9 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default firstRunData.
type getAddrsResponseBeta struct {
Interfaces []*aghnet.NetInterface `json:"interfaces"`
WebPort int `json:"web_port"`
DNSPort int `json:"dns_port"`
Interfaces []*aghnet.NetInterface `json:"interfaces"`
}
// handleInstallConfigureBeta is a substitution of /install/get_addresses
@ -600,9 +602,10 @@ type getAddrsResponseBeta struct {
// TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallGetAddresses.
func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Request) {
data := getAddrsResponseBeta{}
data.WebPort = defaultPortHTTP
data.DNSPort = defaultPortDNS
data := getAddrsResponseBeta{
WebPort: defaultPortHTTP,
DNSPort: defaultPortDNS,
}
ifaces, err := aghnet.GetValidNetInterfacesForWeb()
if err != nil {