Pull request: 1558 enable dnsrewrites on disabled protection

Merge in DNS/adguard-home from 1558-always-rewrite to master

Squashed commit of the following:

commit b8508b3b5fb688cad273a9259c09ccfc07948b2f
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Oct 20 19:17:22 2021 +0300

    all: imp log of changes

commit 97e3649b670786a2936e368a9505faf52f8e8804
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Mon Oct 18 13:18:15 2021 +0300

    all: enable dnsrewrites on disabled protection
This commit is contained in:
Eugene Burkov 2021-10-20 19:52:13 +03:00
parent d7aafa7dc6
commit b8c4651dec
10 changed files with 178 additions and 132 deletions

View File

@ -46,6 +46,8 @@ and this project adheres to
### Changed ### Changed
- `$dnsrewrite` rules and other DNS rewrites will now be applied even when the
protection is disabled ([#1558]).
- DHCP gateway address, subnet mask, IP address range, and leases validations - DHCP gateway address, subnet mask, IP address range, and leases validations
([#3529]). ([#3529]).
- The `systemd` service script will now create the `/var/log` directory when it - The `systemd` service script will now create the `/var/log` directory when it
@ -155,6 +157,7 @@ In this release, the schema version has changed from 10 to 12.
- Go 1.15 support. - Go 1.15 support.
[#1381]: https://github.com/AdguardTeam/AdGuardHome/issues/1381 [#1381]: https://github.com/AdguardTeam/AdGuardHome/issues/1381
[#1558]: https://github.com/AdguardTeam/AdGuardHome/issues/1558
[#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691 [#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691
[#1898]: https://github.com/AdguardTeam/AdGuardHome/issues/1898 [#1898]: https://github.com/AdguardTeam/AdGuardHome/issues/1898
[#1992]: https://github.com/AdguardTeam/AdGuardHome/issues/1992 [#1992]: https://github.com/AdguardTeam/AdGuardHome/issues/1992

View File

@ -90,7 +90,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
s.processRestrictLocal, s.processRestrictLocal,
s.processInternalIPAddrs, s.processInternalIPAddrs,
s.processClientID, s.processClientID,
processFilteringBeforeRequest, s.processFilteringBeforeRequest,
s.processLocalPTR, s.processLocalPTR,
s.processUpstream, s.processUpstream,
processDNSSECAfterResponse, processDNSSECAfterResponse,
@ -468,19 +468,18 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
} }
// Apply filtering logic // Apply filtering logic
func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
s := ctx.srv if ctx.proxyCtx.Res != nil {
d := ctx.proxyCtx // Go on since the response is already set.
return resultCodeSuccess
if d.Res != nil {
return resultCodeSuccess // response is already set - nothing to do
} }
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil ctx.protectionEnabled = s.conf.ProtectionEnabled
if !ctx.protectionEnabled {
if s.dnsFilter == nil {
return resultCodeSuccess return resultCodeSuccess
} }
@ -489,8 +488,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
} }
var err error var err error
ctx.result, err = s.filterDNSRequest(ctx) if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
if err != nil {
ctx.err = err ctx.err = err
return resultCodeError return resultCodeError
@ -608,48 +606,50 @@ func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) { func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
s := ctx.srv s := ctx.srv
d := ctx.proxyCtx d := ctx.proxyCtx
res := ctx.result
var err error
switch res.Reason { switch res := ctx.result; res.Reason {
case filtering.Rewritten, case filtering.NotFilteredAllowList:
// Go on.
case
filtering.Rewritten,
filtering.RewrittenRule: filtering.RewrittenRule:
if len(ctx.origQuestion.Name) == 0 { if len(ctx.origQuestion.Name) == 0 {
// origQuestion is set in case we get only CNAME without IP from rewrites table // origQuestion is set in case we get only CNAME without IP from
// rewrites table.
break break
} }
d.Req.Question[0] = ctx.origQuestion d.Req.Question[0], d.Res.Question[0] = ctx.origQuestion, ctx.origQuestion
d.Res.Question[0] = ctx.origQuestion if len(d.Res.Answer) > 0 {
answer := append([]dns.RR{s.genAnswerCNAME(d.Req, res.CanonName)}, d.Res.Answer...)
if len(d.Res.Answer) != 0 {
answer := []dns.RR{}
answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName))
answer = append(answer, d.Res.Answer...)
d.Res.Answer = answer d.Res.Answer = answer
} }
case filtering.NotFilteredAllowList:
// nothing
default: default:
if !ctx.protectionEnabled || // filters are disabled: there's nothing to check for // Check the response only if the it's from an upstream. Don't check
!ctx.responseFromUpstream { // only check response if it's from an upstream server // the response if the protection is disabled since dnsrewrite rules
// aren't applied to it anyway.
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
break break
} }
origResp2 := d.Res
ctx.result, err = s.filterDNSResponse(ctx) origResp := d.Res
result, err := s.filterDNSResponse(ctx)
if err != nil { if err != nil {
ctx.err = err ctx.err = err
return resultCodeError return resultCodeError
} }
if ctx.result != nil {
ctx.origResp = origResp2 // matched by response if result != nil {
} else { ctx.result = result
ctx.result = &filtering.Result{} ctx.origResp = origResp
} }
} }
if ctx.result == nil {
ctx.result = &filtering.Result{}
}
return resultCodeSuccess return resultCodeSuccess
} }

View File

@ -909,6 +909,7 @@ func TestRewrite(t *testing.T) {
}}, }},
} }
f := filtering.New(c, nil) f := filtering.New(c, nil)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector() snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err) require.NoError(t, err)
@ -945,9 +946,10 @@ func TestRewrite(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
subTestFunc := func(t *testing.T) {
req := createTestMessageWithType("test.com.", dns.TypeA) req := createTestMessageWithType("test.com.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String()) reply, eerr := dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, eerr)
require.Len(t, reply.Answer, 1) require.Len(t, reply.Answer, 1)
@ -957,14 +959,14 @@ func TestRewrite(t *testing.T) {
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A)) assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
req = createTestMessageWithType("test.com.", dns.TypeAAAA) req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String()) reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, eerr)
assert.Empty(t, reply.Answer) assert.Empty(t, reply.Answer)
req = createTestMessageWithType("alias.test.com.", dns.TypeA) req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String()) reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, eerr)
require.Len(t, reply.Answer, 2) require.Len(t, reply.Answer, 2)
@ -972,8 +974,8 @@ func TestRewrite(t *testing.T) {
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String()) reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, eerr)
// The original question is restored. // The original question is restored.
require.Len(t, reply.Question, 1) require.Len(t, reply.Question, 1)
@ -984,6 +986,16 @@ func TestRewrite(t *testing.T) {
assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
}
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf.ProtectionEnabled = &val
s.setConfig(conf)
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
}
} }
func publicKey(priv interface{}) interface{} { func publicKey(priv interface{}) interface{} {
@ -1092,9 +1104,10 @@ func TestPTRResponseFromHosts(t *testing.T) {
require.ErrorIs(t, hc.Close(), closeCalled) require.ErrorIs(t, hc.Close(), closeCalled)
}) })
c := filtering.Config{ flt := filtering.New(&filtering.Config{
EtcHosts: hc, EtcHosts: hc,
} }, nil)
flt.SetEnabled(true)
var snd *aghnet.SubnetDetector var snd *aghnet.SubnetDetector
snd, err = aghnet.NewSubnetDetector() snd, err = aghnet.NewSubnetDetector()
@ -1104,7 +1117,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
var s *Server var s *Server
s, err = NewServer(DNSCreateParams{ s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{}, DHCPServer: &testDHCP{},
DNSFilter: filtering.New(&c, nil), DNSFilter: flt,
SubnetDetector: snd, SubnetDetector: snd,
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1112,25 +1125,24 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
err = s.Prepare(nil) err = s.Prepare(nil)
require.NoError(t, err) require.NoError(t, err)
err = s.Start() err = s.Start()
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
s.Close() s.Close()
}) })
subTestFunc := func(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
resp, err := dns.Exchange(req, addr.String()) resp, eerr := dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, eerr)
require.Lenf(t, resp.Answer, 1, "%#v", resp) require.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
@ -1138,6 +1150,16 @@ func TestPTRResponseFromHosts(t *testing.T) {
ptr, ok := resp.Answer[0].(*dns.PTR) ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "host.", ptr.Ptr) assert.Equal(t, "host.", ptr.Ptr)
}
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf.ProtectionEnabled = &val
s.setConfig(conf)
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
}
} }
func TestNewServer(t *testing.T) { func TestNewServer(t *testing.T) {

View File

@ -52,6 +52,7 @@ func (s *Server) beforeRequestHandler(
// the client's IP address and ID, if any, from ctx. // the client's IP address and ID, if any, from ctx.
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings { func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.GetConfig() setts := s.dnsFilter.GetConfig()
setts.ProtectionEnabled = ctx.protectionEnabled
if s.conf.FilterHandler != nil { if s.conf.FilterHandler != nil {
ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr) ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr)
s.conf.FilterHandler(ip, ctx.clientID, &setts) s.conf.FilterHandler(ip, ctx.clientID, &setts)
@ -65,32 +66,31 @@ func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.S
func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
d := ctx.proxyCtx d := ctx.proxyCtx
req := d.Req req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".") q := req.Question[0]
res, err := s.dnsFilter.CheckHost(host, req.Question[0].Qtype, ctx.setts) host := strings.TrimSuffix(q.Name, ".")
if err != nil { res, err := s.dnsFilter.CheckHost(host, q.Qtype, ctx.setts)
// Return immediately if there's an error switch {
return nil, fmt.Errorf("filtering failed to check host %q: %w", host, err) case err != nil:
} else if res.IsFiltered { return nil, fmt.Errorf("failed to check host %q: %w", host, err)
log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rules[0].Text) case res.IsFiltered:
log.Tracef("host %q is filtered, reason %q, rule: %q", host, res.Reason, res.Rules[0].Text)
d.Res = s.genDNSFilterMessage(d, &res) d.Res = s.genDNSFilterMessage(d, &res)
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) && case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
res.CanonName != "" && res.CanonName != "" &&
len(res.IPList) == 0 { len(res.IPList) == 0:
// Resolve the new canonical name, not the original host // Resolve the new canonical name, not the original host name. The
// name. The original question is readded in // original question is readded in processFilteringAfterResponse.
// processFilteringAfterResponse. ctx.origQuestion = q
ctx.origQuestion = req.Question[0]
req.Question[0].Name = dns.Fqdn(res.CanonName) req.Question[0].Name = dns.Fqdn(res.CanonName)
} else if res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0 { case res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0:
resp := s.makeResponse(req) resp := s.makeResponse(req)
for _, h := range res.ReverseHosts {
hdr := dns.RR_Header{ hdr := dns.RR_Header{
Name: req.Question[0].Name, Name: q.Name,
Rrtype: dns.TypePTR, Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL, Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET, Class: dns.ClassINET,
} }
for _, h := range res.ReverseHosts {
ptr := &dns.PTR{ ptr := &dns.PTR{
Hdr: hdr, Hdr: hdr,
Ptr: h, Ptr: h,
@ -100,7 +100,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
} }
d.Res = resp d.Res = resp
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts) { case res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts):
resp := s.makeResponse(req) resp := s.makeResponse(req)
name := host name := host
@ -110,11 +110,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
} }
for _, ip := range res.IPList { for _, ip := range res.IPList {
if req.Question[0].Qtype == dns.TypeA { switch q.Qtype {
case dns.TypeA:
a := s.genAnswerA(req, ip.To4()) a := s.genAnswerA(req, ip.To4())
a.Hdr.Name = dns.Fqdn(name) a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a) resp.Answer = append(resp.Answer, a)
} else if req.Question[0].Qtype == dns.TypeAAAA { case dns.TypeAAAA:
a := s.genAnswerAAAA(req, ip) a := s.genAnswerAAAA(req, ip)
a.Hdr.Name = dns.Fqdn(name) a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a) resp.Answer = append(resp.Answer, a)
@ -122,9 +123,8 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
} }
d.Res = resp d.Res = resp
} else if res.Reason == filtering.RewrittenRule { case res.Reason == filtering.RewrittenRule:
err = s.filterDNSRewrite(req, res, d) if err = s.filterDNSRewrite(req, res, d); err != nil {
if err != nil {
return nil, err return nil, err
} }
} }
@ -179,6 +179,7 @@ func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) {
continue continue
} }
host = strings.TrimSuffix(host, ".")
res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts) res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -38,6 +38,7 @@ type Settings struct {
ServicesRules []ServiceEntry ServicesRules []ServiceEntry
ProtectionEnabled bool
FilteringEnabled bool FilteringEnabled bool
SafeSearchEnabled bool SafeSearchEnabled bool
SafeBrowsingEnabled bool SafeBrowsingEnabled bool
@ -221,12 +222,13 @@ func (r Reason) String() string {
} }
// In returns true if reasons include r. // In returns true if reasons include r.
func (r Reason) In(reasons ...Reason) bool { func (r Reason) In(reasons ...Reason) (ok bool) {
for _, reason := range reasons { for _, reason := range reasons {
if r == reason { if r == reason {
return true return true
} }
} }
return false return false
} }
@ -245,7 +247,7 @@ func (d *DNSFilter) GetConfig() (s Settings) {
defer d.confLock.RUnlock() defer d.confLock.RUnlock()
return Settings{ return Settings{
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) == 1, FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
SafeSearchEnabled: d.Config.SafeSearchEnabled, SafeSearchEnabled: d.Config.SafeSearchEnabled,
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled, SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
ParentalEnabled: d.Config.ParentalEnabled, ParentalEnabled: d.Config.ParentalEnabled,
@ -421,15 +423,17 @@ func (d *DNSFilter) CheckHost(
// Sometimes clients try to resolve ".", which is a request to get root // Sometimes clients try to resolve ".", which is a request to get root
// servers. // servers.
if host == "" { if host == "" {
return Result{Reason: NotFilteredNotFound}, nil return Result{}, nil
} }
host = strings.ToLower(host) host = strings.ToLower(host)
if setts.FilteringEnabled {
res = d.processRewrites(host, qtype) res = d.processRewrites(host, qtype)
if res.Reason == Rewritten { if res.Reason == Rewritten {
return res, nil return res, nil
} }
}
for _, hc := range d.hostCheckers { for _, hc := range d.hostCheckers {
res, err = hc.check(host, qtype, setts) res, err = hc.check(host, qtype, setts)
@ -448,7 +452,7 @@ func (d *DNSFilter) CheckHost(
// matchSysHosts tries to match the host against the operating system's hosts // matchSysHosts tries to match the host against the operating system's hosts
// database. // database.
func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) { func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) {
if d.EtcHosts == nil { if !setts.FilteringEnabled || d.EtcHosts == nil {
return Result{}, nil return Result{}, nil
} }
@ -468,10 +472,8 @@ func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (r
var ips []net.IP var ips []net.IP
var revHosts []string var revHosts []string
for _, nr := range dnsr { for _, nr := range dnsr {
dr := nr.DNSRewrite if nr.DNSRewrite == nil {
if dr == nil {
continue continue
} }
@ -553,6 +555,10 @@ func matchBlockedServicesRules(
_ uint16, _ uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.ProtectionEnabled {
return Result{}, nil
}
svcs := setts.ServicesRules svcs := setts.ServicesRules
if len(svcs) == 0 { if len(svcs) == 0 {
return Result{}, nil return Result{}, nil
@ -784,7 +790,7 @@ func (d *DNSFilter) matchHost(
// TODO(e.burkov): Inspect if the above is true. // TODO(e.burkov): Inspect if the above is true.
defer d.engineLock.RUnlock() defer d.engineLock.RUnlock()
if d.filteringEngineAllow != nil { if setts.ProtectionEnabled && d.filteringEngineAllow != nil {
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq) dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
if ok { if ok {
return d.matchHostProcessAllowList(host, dnsres) return d.matchHostProcessAllowList(host, dnsres)
@ -810,6 +816,11 @@ func (d *DNSFilter) matchHost(
return Result{}, nil return Result{}, nil
} }
if !setts.ProtectionEnabled {
// Don't check non-dnsrewrite filtering results.
return Result{}, nil
}
res = d.matchHostProcessDNSResult(qtype, dnsres) res = d.matchHostProcessDNSResult(qtype, dnsres)
for _, r := range res.Rules { for _, r := range res.Rules {
log.Debug( log.Debug(

View File

@ -21,7 +21,9 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
var setts Settings var setts = Settings{
ProtectionEnabled: true,
}
// Helpers. // Helpers.
@ -39,9 +41,9 @@ func purgeCaches() {
func newForTest(c *Config, filters []Filter) *DNSFilter { func newForTest(c *Config, filters []Filter) *DNSFilter {
setts = Settings{ setts = Settings{
ProtectionEnabled: true,
FilteringEnabled: true, FilteringEnabled: true,
} }
setts.FilteringEnabled = true
if c != nil { if c != nil {
c.SafeBrowsingCacheSize = 10000 c.SafeBrowsingCacheSize = 10000
c.ParentalCacheSize = 10000 c.ParentalCacheSize = 10000
@ -797,7 +799,11 @@ func TestClientSettings(t *testing.T) {
makeTester := func(tc testCase, before bool) func(t *testing.T) { makeTester := func(tc testCase, before bool) func(t *testing.T) {
return func(t *testing.T) { return func(t *testing.T) {
r, _ := d.CheckHost(tc.host, dns.TypeA, &setts) t.Helper()
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
require.NoError(t, err)
if before { if before {
assert.True(t, r.IsFiltered) assert.True(t, r.IsFiltered)
assert.Equal(t, tc.wantReason, r.Reason) assert.Equal(t, tc.wantReason, r.Reason)
@ -808,7 +814,7 @@ func TestClientSettings(t *testing.T) {
} }
// Check behaviour without any per-client settings, then apply per-client // Check behaviour without any per-client settings, then apply per-client
// settings and check behaviour once again. // settings and check behavior once again.
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, makeTester(tc, tc.before)) t.Run(tc.name, makeTester(tc, tc.before))
} }

View File

@ -306,7 +306,7 @@ func (d *DNSFilter) checkSafeBrowsing(
_ uint16, _ uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.SafeBrowsingEnabled { if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled {
return Result{}, nil return Result{}, nil
} }
@ -339,7 +339,7 @@ func (d *DNSFilter) checkParental(
_ uint16, _ uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.ParentalEnabled { if !setts.ProtectionEnabled || !setts.ParentalEnabled {
return Result{}, nil return Result{}, nil
} }

View File

@ -117,6 +117,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
d.SetParentalUpstream(ups) d.SetParentalUpstream(ups)
setts := &Settings{ setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true, SafeBrowsingEnabled: true,
ParentalEnabled: true, ParentalEnabled: true,
} }
@ -135,35 +136,36 @@ func TestSBPC(t *testing.T) {
const hostname = "example.org" const hostname = "example.org"
setts := &Settings{ setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true, SafeBrowsingEnabled: true,
ParentalEnabled: true, ParentalEnabled: true,
} }
testCases := []struct { testCases := []struct {
testCache cache.Cache
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
name string name string
block bool block bool
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
testCache cache.Cache
}{{ }{{
testCache: gctx.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_no_block", name: "sb_no_block",
block: false, block: false,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, { }, {
testCache: gctx.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_block", name: "sb_block",
block: true, block: true,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, { }, {
testCache: gctx.parentalCache,
testFunc: d.checkParental,
name: "pc_no_block", name: "pc_no_block",
block: false, block: false,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}, { }, {
testCache: gctx.parentalCache,
testFunc: d.checkParental,
name: "pc_block", name: "pc_block",
block: true, block: true,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}} }}
for _, tc := range testCases { for _, tc := range testCases {

View File

@ -74,7 +74,7 @@ func (d *DNSFilter) checkSafeSearch(
_ uint16, _ uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.SafeSearchEnabled { if !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
return Result{}, nil return Result{}, nil
} }

View File

@ -404,6 +404,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
setts := Context.dnsFilter.GetConfig() setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true setts.FilteringEnabled = true
setts.ProtectionEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil { if err != nil {