Pull request: all: imp code, decr cyclo

Updates #2646.

Squashed commit of the following:

commit c83c230f3d2c542d7b1a4bc0e1c503d5bbc16cb8
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Mar 25 19:47:11 2021 +0300

    all: imp code, decr cyclo
This commit is contained in:
Ainar Garipov 2021-03-25 20:30:30 +03:00
parent 27f4f05273
commit 8c735d0dd5
13 changed files with 240 additions and 187 deletions

View File

@ -254,7 +254,7 @@ func BlockedSvcKnown(s string) bool {
}
// ApplyBlockedServices - set blocked services settings for this DNS request
func (d *DNSFilter) ApplyBlockedServices(setts *RequestFilteringSettings, list []string, global bool) {
func (d *DNSFilter) ApplyBlockedServices(setts *FilteringSettings, list []string, global bool) {
setts.ServicesRules = []ServiceEntry{}
if global {
d.confLock.RLock()

View File

@ -29,18 +29,18 @@ type ServiceEntry struct {
Rules []*rules.NetworkRule
}
// RequestFilteringSettings is custom filtering settings
type RequestFilteringSettings struct {
FilteringEnabled bool
SafeSearchEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
// FilteringSettings are custom filtering settings for a client.
type FilteringSettings struct {
ClientName string
ClientIP net.IP
ClientTags []string
ServicesRules []ServiceEntry
FilteringEnabled bool
SafeSearchEnabled bool
SafeBrowsingEnabled bool
ParentalEnabled bool
}
// Resolver is the interface for net.Resolver to simplify testing.
@ -99,6 +99,11 @@ type filtersInitializerParams struct {
blockFilters []Filter
}
type hostChecker struct {
check func(host string, qtype uint16, setts *FilteringSettings) (res Result, err error)
name string
}
// DNSFilter matches hostnames and DNS requests against filtering rules.
type DNSFilter struct {
rulesStorage *filterlist.RuleStorage
@ -123,6 +128,8 @@ type DNSFilter struct {
//
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
resolver Resolver
hostCheckers []hostChecker
}
// Filter represents a filter list
@ -216,8 +223,8 @@ func (r Reason) In(reasons ...Reason) bool {
}
// GetConfig - get configuration
func (d *DNSFilter) GetConfig() RequestFilteringSettings {
c := RequestFilteringSettings{}
func (d *DNSFilter) GetConfig() FilteringSettings {
c := FilteringSettings{}
// d.confLock.RLock()
c.SafeSearchEnabled = d.Config.SafeSearchEnabled
c.SafeBrowsingEnabled = d.Config.SafeBrowsingEnabled
@ -372,122 +379,85 @@ func (r Reason) Matched() bool {
}
// CheckHostRules tries to match the host against filtering rules only.
func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *RequestFilteringSettings) (Result, error) {
func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *FilteringSettings) (Result, error) {
if !setts.FilteringEnabled {
return Result{}, nil
}
return d.matchHost(host, qtype, *setts)
return d.matchHost(host, qtype, setts)
}
// CheckHost tries to match the host against filtering rules, then
// safebrowsing and parental control rules, if they are enabled.
func (d *DNSFilter) CheckHost(host string, qtype uint16, setts *RequestFilteringSettings) (Result, error) {
// sometimes DNS clients will try to resolve ".", which is a request to get root servers
// CheckHost tries to match the host against filtering rules, then safebrowsing
// and parental control rules, if they are enabled.
func (d *DNSFilter) CheckHost(
host string,
qtype uint16,
setts *FilteringSettings,
) (res Result, err error) {
// Sometimes clients try to resolve ".", which is a request to get root
// servers.
if host == "" {
return Result{Reason: NotFilteredNotFound}, nil
}
host = strings.ToLower(host)
var result Result
var err error
// first - check rewrites, they have the highest priority
result = d.processRewrites(host, qtype)
if result.Reason == Rewritten {
return result, nil
res = d.processRewrites(host, qtype)
if res.Reason == Rewritten {
return res, nil
}
// Now check the hosts file -- do we have any rules for it?
// just like DNS rewrites, it has higher priority than filtering rules.
if d.Config.AutoHosts != nil {
matched := d.checkAutoHosts(host, qtype, &result)
if matched {
return result, nil
}
}
if setts.FilteringEnabled {
result, err = d.matchHost(host, qtype, *setts)
for _, hc := range d.hostCheckers {
res, err = hc.check(host, qtype, setts)
if err != nil {
return result, err
}
if result.Reason.Matched() {
return result, nil
}
}
// are there any blocked services?
if len(setts.ServicesRules) != 0 {
result = matchBlockedServicesRules(host, setts.ServicesRules)
if result.Reason.Matched() {
return result, nil
}
}
// browsing security web service
if setts.SafeBrowsingEnabled {
result, err = d.checkSafeBrowsing(host)
if err != nil {
log.Info("SafeBrowsing: failed: %v", err)
return Result{}, nil
}
if result.Reason.Matched() {
return result, nil
}
}
// parental control web service
if setts.ParentalEnabled {
result, err = d.checkParental(host)
if err != nil {
log.Printf("Parental: failed: %v", err)
return Result{}, nil
}
if result.Reason.Matched() {
return result, nil
}
}
// apply safe search if needed
if setts.SafeSearchEnabled {
result, err = d.checkSafeSearch(host)
if err != nil {
log.Info("SafeSearch: failed: %v", err)
return Result{}, nil
return Result{}, fmt.Errorf("%s: %w", hc.name, err)
}
if result.Reason.Matched() {
return result, nil
if res.Reason.Matched() {
return res, nil
}
}
return Result{}, nil
}
func (d *DNSFilter) checkAutoHosts(host string, qtype uint16, result *Result) (matched bool) {
// checkAutoHosts compares the host against our autohosts table. The err is
// always nil, it is only there to make this a valid hostChecker function.
func (d *DNSFilter) checkAutoHosts(
host string,
qtype uint16,
_ *FilteringSettings,
) (res Result, err error) {
if d.Config.AutoHosts == nil {
return Result{}, nil
}
ips := d.Config.AutoHosts.Process(host, qtype)
if ips != nil {
result.Reason = RewrittenAutoHosts
result.IPList = ips
res = Result{
Reason: RewrittenAutoHosts,
IPList: ips,
}
return true
return res, nil
}
revHosts := d.Config.AutoHosts.ProcessReverse(host, qtype)
if len(revHosts) != 0 {
result.Reason = RewrittenAutoHosts
// TODO(a.garipov): Optimize this with a buffer.
result.ReverseHosts = make([]string, len(revHosts))
for i := range revHosts {
result.ReverseHosts[i] = revHosts[i] + "."
res = Result{
Reason: RewrittenAutoHosts,
}
return true
// TODO(a.garipov): Optimize this with a buffer.
res.ReverseHosts = make([]string, len(revHosts))
for i := range revHosts {
res.ReverseHosts[i] = revHosts[i] + "."
}
return res, nil
}
return false
return Result{}, nil
}
// Process rewrites table
@ -545,10 +515,20 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
return res
}
func matchBlockedServicesRules(host string, svcs []ServiceEntry) Result {
req := rules.NewRequestForHostname(host)
res := Result{}
// matchBlockedServicesRules checks the host against the blocked services rules
// in settings, if any. The err is always nil, it is only there to make this
// a valid hostChecker function.
func matchBlockedServicesRules(
host string,
_ uint16,
setts *FilteringSettings,
) (res Result, err error) {
svcs := setts.ServicesRules
if len(svcs) == 0 {
return Result{}, nil
}
req := rules.NewRequestForHostname(host)
for _, s := range svcs {
for _, rule := range s.Rules {
if rule.Match(req) {
@ -565,11 +545,12 @@ func matchBlockedServicesRules(host string, svcs []ServiceEntry) Result {
log.Debug("blocked services: matched rule: %s host: %s service: %s",
ruleText, host, s.Name)
return res
return res, nil
}
}
}
return res
return res, nil
}
//
@ -680,7 +661,15 @@ func (d *DNSFilter) matchHostProcessAllowList(host string, dnsres urlfilter.DNSR
// matchHost is a low-level way to check only if hostname is filtered by rules,
// skipping expensive safebrowsing and parental lookups.
func (d *DNSFilter) matchHost(host string, qtype uint16, setts RequestFilteringSettings) (res Result, err error) {
func (d *DNSFilter) matchHost(
host string,
qtype uint16,
setts *FilteringSettings,
) (res Result, err error) {
if !setts.FilteringEnabled {
return Result{}, nil
}
d.engineLock.RLock()
// Keep in mind that this lock must be held no just when calling Match()
// but also while using the rules returned by it.
@ -827,6 +816,26 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
resolver: resolver,
}
d.hostCheckers = []hostChecker{{
check: d.checkAutoHosts,
name: "autohosts",
}, {
check: d.matchHost,
name: "filtering",
}, {
check: matchBlockedServicesRules,
name: "blocked services",
}, {
check: d.checkSafeBrowsing,
name: "safe browsing",
}, {
check: d.checkParental,
name: "parental",
}, {
check: d.checkSafeSearch,
name: "safe search",
}}
err := d.initSecurityServices()
if err != nil {
log.Error("dnsfilter: initialize services: %s", err)

View File

@ -21,7 +21,7 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
var setts RequestFilteringSettings
var setts FilteringSettings
// Helpers.
@ -38,7 +38,7 @@ func purgeCaches() {
}
func newForTest(c *Config, filters []Filter) *DNSFilter {
setts = RequestFilteringSettings{
setts = FilteringSettings{
FilteringEnabled: true,
}
setts.FilteringEnabled = true
@ -699,7 +699,7 @@ func TestWhitelist(t *testing.T) {
// Client Settings.
func applyClientSettings(setts *RequestFilteringSettings) {
func applyClientSettings(setts *FilteringSettings) {
setts.FilteringEnabled = false
setts.ParentalEnabled = false
setts.SafeBrowsingEnabled = true

View File

@ -47,7 +47,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
`
f := newForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
setts := &RequestFilteringSettings{
setts := &FilteringSettings{
FilteringEnabled: true,
}

View File

@ -304,46 +304,70 @@ func check(c *sbCtx, r Result, u upstream.Upstream) (Result, error) {
return Result{}, nil
}
func (d *DNSFilter) checkSafeBrowsing(host string) (Result, error) {
// TODO(a.garipov): Unify with checkParental.
func (d *DNSFilter) checkSafeBrowsing(
host string,
_ uint16,
setts *FilteringSettings,
) (res Result, err error) {
if !setts.SafeBrowsingEnabled {
return Result{}, nil
}
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
}
ctx := &sbCtx{
sctx := &sbCtx{
host: host,
svc: "SafeBrowsing",
cache: gctx.safebrowsingCache,
cacheTime: d.Config.CacheTime,
}
res := Result{
res = Result{
IsFiltered: true,
Reason: FilteredSafeBrowsing,
Rules: []*ResultRule{{
Text: "adguard-malware-shavar",
}},
}
return check(ctx, res, d.safeBrowsingUpstream)
return check(sctx, res, d.safeBrowsingUpstream)
}
func (d *DNSFilter) checkParental(host string) (Result, error) {
// TODO(a.garipov): Unify with checkSafeBrowsing.
func (d *DNSFilter) checkParental(
host string,
_ uint16,
setts *FilteringSettings,
) (res Result, err error) {
if !setts.ParentalEnabled {
return Result{}, nil
}
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("Parental lookup for %s", host)
}
ctx := &sbCtx{
sctx := &sbCtx{
host: host,
svc: "Parental",
cache: gctx.parentalCache,
cacheTime: d.Config.CacheTime,
}
res := Result{
res = Result{
IsFiltered: true,
Reason: FilteredParental,
Rules: []*ResultRule{{
Text: "parental CATEGORY_BLACKLISTED",
}},
}
return check(ctx, res, d.parentalUpstream)
return check(sctx, res, d.parentalUpstream)
}
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {

View File

@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/cache"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -115,11 +116,16 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups)
_, err := d.checkSafeBrowsing("smthng.com")
assert.NotNil(t, err)
setts := &FilteringSettings{
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}
_, err = d.checkParental("smthng.com")
assert.NotNil(t, err)
_, err := d.checkSafeBrowsing("smthng.com", dns.TypeA, setts)
assert.Error(t, err)
_, err = d.checkParental("smthng.com", dns.TypeA, setts)
assert.Error(t, err)
}
func TestSBPC(t *testing.T) {
@ -128,10 +134,15 @@ func TestSBPC(t *testing.T) {
const hostname = "example.org"
setts := &FilteringSettings{
SafeBrowsingEnabled: true,
ParentalEnabled: true,
}
testCases := []struct {
name string
block bool
testFunc func(string) (Result, error)
testFunc func(host string, _ uint16, _ *FilteringSettings) (res Result, err error)
testCache cache.Cache
}{{
name: "sb_no_block",
@ -167,8 +178,9 @@ func TestSBPC(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// Firstly, check the request blocking.
hits := 0
res, err := tc.testFunc(hostname)
require.Nil(t, err)
res, err := tc.testFunc(hostname, dns.TypeA, setts)
require.NoError(t, err)
if tc.block {
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
@ -185,8 +197,9 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, 1, ups.RequestsCount())
// Now make the same request to check the cache was used.
res, err = tc.testFunc(hostname)
require.Nil(t, err)
res, err = tc.testFunc(hostname, dns.TypeA, setts)
require.NoError(t, err)
if tc.block {
assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)

View File

@ -69,7 +69,15 @@ func (d *DNSFilter) SafeSearchDomain(host string) (string, bool) {
return val, ok
}
func (d *DNSFilter) checkSafeSearch(host string) (Result, error) {
func (d *DNSFilter) checkSafeSearch(
host string,
_ uint16,
setts *FilteringSettings,
) (res Result, err error) {
if !setts.SafeSearchEnabled {
return Result{}, nil
}
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("SafeSearch: lookup for %s", host)
@ -88,7 +96,7 @@ func (d *DNSFilter) checkSafeSearch(host string) (Result, error) {
return Result{}, nil
}
res := Result{
res = Result{
IsFiltered: true,
Reason: FilteredSafeSearch,
Rules: []*ResultRule{{}},

View File

@ -25,7 +25,7 @@ type FilteringConfig struct {
// --
// FilterHandler is an optional additional filtering callback.
FilterHandler func(clientAddr net.IP, clientID string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
FilterHandler func(clientAddr net.IP, clientID string, settings *dnsfilter.FilteringSettings) `yaml:"-"`
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
// based on the client IP address. Returns nil if there are no custom upstreams for the client

View File

@ -20,7 +20,7 @@ type dnsContext struct {
srv *Server
proxyCtx *proxy.DNSContext
// setts are the filtering settings for the client.
setts *dnsfilter.RequestFilteringSettings
setts *dnsfilter.FilteringSettings
startTime time.Time
result *dnsfilter.Result
// origResp is the response received from upstream. It is set when the

View File

@ -632,7 +632,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
TCPListenAddrs: []*net.TCPAddr{{}},
FilteringConfig: FilteringConfig{
ProtectionEnabled: true,
FilterHandler: func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) {
FilterHandler: func(_ net.IP, _ string, settings *dnsfilter.FilteringSettings) {
settings.FilteringEnabled = false
},
},

View File

@ -32,7 +32,7 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
// getClientRequestFilteringSettings looks up client filtering settings using
// the client's IP address and ID, if any, from ctx.
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *dnsfilter.RequestFilteringSettings {
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *dnsfilter.FilteringSettings {
setts := s.dnsFilter.GetConfig()
setts.FilteringEnabled = true
if s.conf.FilterHandler != nil {

View File

@ -276,7 +276,7 @@ func getDNSEncryption() (de dnsEncryption) {
// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.RequestFilteringSettings) {
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.FilteringSettings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
if clientAddr == nil {

View File

@ -375,6 +375,64 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
}
}
func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntry) {
var err error
switch key {
case "RCode":
var vToken json.Token
vToken, err = dec.Token()
if err != nil {
if err != io.EOF {
log.Debug("decodeResultDNSRewriteResultKey err: %s", err)
}
return
}
if ent.Result.DNSRewriteResult == nil {
ent.Result.DNSRewriteResult = &dnsfilter.DNSRewriteResult{}
}
if n, ok := vToken.(json.Number); ok {
rcode64, _ := n.Int64()
ent.Result.DNSRewriteResult.RCode = rules.RCode(rcode64)
}
case "Response":
if ent.Result.DNSRewriteResult == nil {
ent.Result.DNSRewriteResult = &dnsfilter.DNSRewriteResult{}
}
if ent.Result.DNSRewriteResult.Response == nil {
ent.Result.DNSRewriteResult.Response = dnsfilter.DNSRewriteResultResponse{}
}
// TODO(a.garipov): I give up. This whole file is a mess.
// Luckily, we can assume that this field is relatively rare and
// just use the normal decoding and correct the values.
err = dec.Decode(&ent.Result.DNSRewriteResult.Response)
if err != nil {
log.Debug("decodeResultDNSRewriteResultKey response err: %s", err)
}
for rrType, rrValues := range ent.Result.DNSRewriteResult.Response {
switch rrType {
case
dns.TypeA,
dns.TypeAAAA:
for i, v := range rrValues {
s, _ := v.(string)
rrValues[i] = net.ParseIP(s)
}
default:
// Go on.
}
}
default:
// Go on.
}
}
func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
for {
keyToken, err := dec.Token()
@ -401,66 +459,7 @@ func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
return
}
// TODO(a.garipov): Refactor this into a separate
// function à la decodeResultRuleKey if we keep this
// code for a longer time than planned.
switch key {
case "RCode":
var vToken json.Token
vToken, err = dec.Token()
if err != nil {
if err != io.EOF {
log.Debug("decodeResultDNSRewriteResult err: %s", err)
}
return
}
if ent.Result.DNSRewriteResult == nil {
ent.Result.DNSRewriteResult = &dnsfilter.DNSRewriteResult{}
}
var n json.Number
if n, ok = vToken.(json.Number); ok {
rcode64, _ := n.Int64()
ent.Result.DNSRewriteResult.RCode = rules.RCode(rcode64)
}
continue
case "Response":
if ent.Result.DNSRewriteResult == nil {
ent.Result.DNSRewriteResult = &dnsfilter.DNSRewriteResult{}
}
if ent.Result.DNSRewriteResult.Response == nil {
ent.Result.DNSRewriteResult.Response = dnsfilter.DNSRewriteResultResponse{}
}
// TODO(a.garipov): I give up. This whole file
// is a mess. Luckily, we can assume that this
// field is relatively rare and just use the
// normal decoding and correct the values.
err = dec.Decode(&ent.Result.DNSRewriteResult.Response)
if err != nil {
log.Debug("decodeResultDNSRewriteResult response err: %s", err)
}
for rrType, rrValues := range ent.Result.DNSRewriteResult.Response {
switch rrType {
case dns.TypeA, dns.TypeAAAA:
for i, v := range rrValues {
s, _ := v.(string)
rrValues[i] = net.ParseIP(s)
}
default:
// Go on.
}
}
continue
default:
// Go on.
}
decodeResultDNSRewriteResultKey(key, dec, ent)
}
}