Pull request: all: imp code, decr cyclo
Updates #2646. Squashed commit of the following: commit c83c230f3d2c542d7b1a4bc0e1c503d5bbc16cb8 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Mar 25 19:47:11 2021 +0300 all: imp code, decr cyclo
This commit is contained in:
parent
27f4f05273
commit
8c735d0dd5
|
@ -254,7 +254,7 @@ func BlockedSvcKnown(s string) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyBlockedServices - set blocked services settings for this DNS request
|
// ApplyBlockedServices - set blocked services settings for this DNS request
|
||||||
func (d *DNSFilter) ApplyBlockedServices(setts *RequestFilteringSettings, list []string, global bool) {
|
func (d *DNSFilter) ApplyBlockedServices(setts *FilteringSettings, list []string, global bool) {
|
||||||
setts.ServicesRules = []ServiceEntry{}
|
setts.ServicesRules = []ServiceEntry{}
|
||||||
if global {
|
if global {
|
||||||
d.confLock.RLock()
|
d.confLock.RLock()
|
||||||
|
|
|
@ -29,18 +29,18 @@ type ServiceEntry struct {
|
||||||
Rules []*rules.NetworkRule
|
Rules []*rules.NetworkRule
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequestFilteringSettings is custom filtering settings
|
// FilteringSettings are custom filtering settings for a client.
|
||||||
type RequestFilteringSettings struct {
|
type FilteringSettings struct {
|
||||||
FilteringEnabled bool
|
|
||||||
SafeSearchEnabled bool
|
|
||||||
SafeBrowsingEnabled bool
|
|
||||||
ParentalEnabled bool
|
|
||||||
|
|
||||||
ClientName string
|
ClientName string
|
||||||
ClientIP net.IP
|
ClientIP net.IP
|
||||||
ClientTags []string
|
ClientTags []string
|
||||||
|
|
||||||
ServicesRules []ServiceEntry
|
ServicesRules []ServiceEntry
|
||||||
|
|
||||||
|
FilteringEnabled bool
|
||||||
|
SafeSearchEnabled bool
|
||||||
|
SafeBrowsingEnabled bool
|
||||||
|
ParentalEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolver is the interface for net.Resolver to simplify testing.
|
// Resolver is the interface for net.Resolver to simplify testing.
|
||||||
|
@ -99,6 +99,11 @@ type filtersInitializerParams struct {
|
||||||
blockFilters []Filter
|
blockFilters []Filter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type hostChecker struct {
|
||||||
|
check func(host string, qtype uint16, setts *FilteringSettings) (res Result, err error)
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
// DNSFilter matches hostnames and DNS requests against filtering rules.
|
// DNSFilter matches hostnames and DNS requests against filtering rules.
|
||||||
type DNSFilter struct {
|
type DNSFilter struct {
|
||||||
rulesStorage *filterlist.RuleStorage
|
rulesStorage *filterlist.RuleStorage
|
||||||
|
@ -123,6 +128,8 @@ type DNSFilter struct {
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
|
// TODO(e.burkov): Use upstream that configured in dnsforward instead.
|
||||||
resolver Resolver
|
resolver Resolver
|
||||||
|
|
||||||
|
hostCheckers []hostChecker
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter represents a filter list
|
// Filter represents a filter list
|
||||||
|
@ -216,8 +223,8 @@ func (r Reason) In(reasons ...Reason) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConfig - get configuration
|
// GetConfig - get configuration
|
||||||
func (d *DNSFilter) GetConfig() RequestFilteringSettings {
|
func (d *DNSFilter) GetConfig() FilteringSettings {
|
||||||
c := RequestFilteringSettings{}
|
c := FilteringSettings{}
|
||||||
// d.confLock.RLock()
|
// d.confLock.RLock()
|
||||||
c.SafeSearchEnabled = d.Config.SafeSearchEnabled
|
c.SafeSearchEnabled = d.Config.SafeSearchEnabled
|
||||||
c.SafeBrowsingEnabled = d.Config.SafeBrowsingEnabled
|
c.SafeBrowsingEnabled = d.Config.SafeBrowsingEnabled
|
||||||
|
@ -372,122 +379,85 @@ func (r Reason) Matched() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckHostRules tries to match the host against filtering rules only.
|
// CheckHostRules tries to match the host against filtering rules only.
|
||||||
func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *RequestFilteringSettings) (Result, error) {
|
func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *FilteringSettings) (Result, error) {
|
||||||
if !setts.FilteringEnabled {
|
if !setts.FilteringEnabled {
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return d.matchHost(host, qtype, *setts)
|
return d.matchHost(host, qtype, setts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckHost tries to match the host against filtering rules, then
|
// CheckHost tries to match the host against filtering rules, then safebrowsing
|
||||||
// safebrowsing and parental control rules, if they are enabled.
|
// and parental control rules, if they are enabled.
|
||||||
func (d *DNSFilter) CheckHost(host string, qtype uint16, setts *RequestFilteringSettings) (Result, error) {
|
func (d *DNSFilter) CheckHost(
|
||||||
// sometimes DNS clients will try to resolve ".", which is a request to get root servers
|
host string,
|
||||||
|
qtype uint16,
|
||||||
|
setts *FilteringSettings,
|
||||||
|
) (res Result, err error) {
|
||||||
|
// Sometimes clients try to resolve ".", which is a request to get root
|
||||||
|
// servers.
|
||||||
if host == "" {
|
if host == "" {
|
||||||
return Result{Reason: NotFilteredNotFound}, nil
|
return Result{Reason: NotFilteredNotFound}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
host = strings.ToLower(host)
|
host = strings.ToLower(host)
|
||||||
|
|
||||||
var result Result
|
res = d.processRewrites(host, qtype)
|
||||||
var err error
|
if res.Reason == Rewritten {
|
||||||
|
return res, nil
|
||||||
// first - check rewrites, they have the highest priority
|
|
||||||
result = d.processRewrites(host, qtype)
|
|
||||||
if result.Reason == Rewritten {
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now check the hosts file -- do we have any rules for it?
|
for _, hc := range d.hostCheckers {
|
||||||
// just like DNS rewrites, it has higher priority than filtering rules.
|
res, err = hc.check(host, qtype, setts)
|
||||||
if d.Config.AutoHosts != nil {
|
|
||||||
matched := d.checkAutoHosts(host, qtype, &result)
|
|
||||||
if matched {
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if setts.FilteringEnabled {
|
|
||||||
result, err = d.matchHost(host, qtype, *setts)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return result, err
|
return Result{}, fmt.Errorf("%s: %w", hc.name, err)
|
||||||
}
|
|
||||||
if result.Reason.Matched() {
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// are there any blocked services?
|
|
||||||
if len(setts.ServicesRules) != 0 {
|
|
||||||
result = matchBlockedServicesRules(host, setts.ServicesRules)
|
|
||||||
if result.Reason.Matched() {
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// browsing security web service
|
|
||||||
if setts.SafeBrowsingEnabled {
|
|
||||||
result, err = d.checkSafeBrowsing(host)
|
|
||||||
if err != nil {
|
|
||||||
log.Info("SafeBrowsing: failed: %v", err)
|
|
||||||
return Result{}, nil
|
|
||||||
}
|
|
||||||
if result.Reason.Matched() {
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// parental control web service
|
|
||||||
if setts.ParentalEnabled {
|
|
||||||
result, err = d.checkParental(host)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Parental: failed: %v", err)
|
|
||||||
return Result{}, nil
|
|
||||||
}
|
|
||||||
if result.Reason.Matched() {
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// apply safe search if needed
|
|
||||||
if setts.SafeSearchEnabled {
|
|
||||||
result, err = d.checkSafeSearch(host)
|
|
||||||
if err != nil {
|
|
||||||
log.Info("SafeSearch: failed: %v", err)
|
|
||||||
return Result{}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.Reason.Matched() {
|
if res.Reason.Matched() {
|
||||||
return result, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) checkAutoHosts(host string, qtype uint16, result *Result) (matched bool) {
|
// checkAutoHosts compares the host against our autohosts table. The err is
|
||||||
|
// always nil, it is only there to make this a valid hostChecker function.
|
||||||
|
func (d *DNSFilter) checkAutoHosts(
|
||||||
|
host string,
|
||||||
|
qtype uint16,
|
||||||
|
_ *FilteringSettings,
|
||||||
|
) (res Result, err error) {
|
||||||
|
if d.Config.AutoHosts == nil {
|
||||||
|
return Result{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
ips := d.Config.AutoHosts.Process(host, qtype)
|
ips := d.Config.AutoHosts.Process(host, qtype)
|
||||||
if ips != nil {
|
if ips != nil {
|
||||||
result.Reason = RewrittenAutoHosts
|
res = Result{
|
||||||
result.IPList = ips
|
Reason: RewrittenAutoHosts,
|
||||||
|
IPList: ips,
|
||||||
|
}
|
||||||
|
|
||||||
return true
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
revHosts := d.Config.AutoHosts.ProcessReverse(host, qtype)
|
revHosts := d.Config.AutoHosts.ProcessReverse(host, qtype)
|
||||||
if len(revHosts) != 0 {
|
if len(revHosts) != 0 {
|
||||||
result.Reason = RewrittenAutoHosts
|
res = Result{
|
||||||
|
Reason: RewrittenAutoHosts,
|
||||||
// TODO(a.garipov): Optimize this with a buffer.
|
|
||||||
result.ReverseHosts = make([]string, len(revHosts))
|
|
||||||
for i := range revHosts {
|
|
||||||
result.ReverseHosts[i] = revHosts[i] + "."
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
// TODO(a.garipov): Optimize this with a buffer.
|
||||||
|
res.ReverseHosts = make([]string, len(revHosts))
|
||||||
|
for i := range revHosts {
|
||||||
|
res.ReverseHosts[i] = revHosts[i] + "."
|
||||||
|
}
|
||||||
|
|
||||||
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process rewrites table
|
// Process rewrites table
|
||||||
|
@ -545,10 +515,20 @@ func (d *DNSFilter) processRewrites(host string, qtype uint16) (res Result) {
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func matchBlockedServicesRules(host string, svcs []ServiceEntry) Result {
|
// matchBlockedServicesRules checks the host against the blocked services rules
|
||||||
req := rules.NewRequestForHostname(host)
|
// in settings, if any. The err is always nil, it is only there to make this
|
||||||
res := Result{}
|
// a valid hostChecker function.
|
||||||
|
func matchBlockedServicesRules(
|
||||||
|
host string,
|
||||||
|
_ uint16,
|
||||||
|
setts *FilteringSettings,
|
||||||
|
) (res Result, err error) {
|
||||||
|
svcs := setts.ServicesRules
|
||||||
|
if len(svcs) == 0 {
|
||||||
|
return Result{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := rules.NewRequestForHostname(host)
|
||||||
for _, s := range svcs {
|
for _, s := range svcs {
|
||||||
for _, rule := range s.Rules {
|
for _, rule := range s.Rules {
|
||||||
if rule.Match(req) {
|
if rule.Match(req) {
|
||||||
|
@ -565,11 +545,12 @@ func matchBlockedServicesRules(host string, svcs []ServiceEntry) Result {
|
||||||
log.Debug("blocked services: matched rule: %s host: %s service: %s",
|
log.Debug("blocked services: matched rule: %s host: %s service: %s",
|
||||||
ruleText, host, s.Name)
|
ruleText, host, s.Name)
|
||||||
|
|
||||||
return res
|
return res, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return res
|
|
||||||
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -680,7 +661,15 @@ func (d *DNSFilter) matchHostProcessAllowList(host string, dnsres urlfilter.DNSR
|
||||||
|
|
||||||
// matchHost is a low-level way to check only if hostname is filtered by rules,
|
// matchHost is a low-level way to check only if hostname is filtered by rules,
|
||||||
// skipping expensive safebrowsing and parental lookups.
|
// skipping expensive safebrowsing and parental lookups.
|
||||||
func (d *DNSFilter) matchHost(host string, qtype uint16, setts RequestFilteringSettings) (res Result, err error) {
|
func (d *DNSFilter) matchHost(
|
||||||
|
host string,
|
||||||
|
qtype uint16,
|
||||||
|
setts *FilteringSettings,
|
||||||
|
) (res Result, err error) {
|
||||||
|
if !setts.FilteringEnabled {
|
||||||
|
return Result{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
d.engineLock.RLock()
|
d.engineLock.RLock()
|
||||||
// Keep in mind that this lock must be held no just when calling Match()
|
// Keep in mind that this lock must be held no just when calling Match()
|
||||||
// but also while using the rules returned by it.
|
// but also while using the rules returned by it.
|
||||||
|
@ -827,6 +816,26 @@ func New(c *Config, blockFilters []Filter) *DNSFilter {
|
||||||
resolver: resolver,
|
resolver: resolver,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
d.hostCheckers = []hostChecker{{
|
||||||
|
check: d.checkAutoHosts,
|
||||||
|
name: "autohosts",
|
||||||
|
}, {
|
||||||
|
check: d.matchHost,
|
||||||
|
name: "filtering",
|
||||||
|
}, {
|
||||||
|
check: matchBlockedServicesRules,
|
||||||
|
name: "blocked services",
|
||||||
|
}, {
|
||||||
|
check: d.checkSafeBrowsing,
|
||||||
|
name: "safe browsing",
|
||||||
|
}, {
|
||||||
|
check: d.checkParental,
|
||||||
|
name: "parental",
|
||||||
|
}, {
|
||||||
|
check: d.checkSafeSearch,
|
||||||
|
name: "safe search",
|
||||||
|
}}
|
||||||
|
|
||||||
err := d.initSecurityServices()
|
err := d.initSecurityServices()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("dnsfilter: initialize services: %s", err)
|
log.Error("dnsfilter: initialize services: %s", err)
|
||||||
|
|
|
@ -21,7 +21,7 @@ func TestMain(m *testing.M) {
|
||||||
aghtest.DiscardLogOutput(m)
|
aghtest.DiscardLogOutput(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
var setts RequestFilteringSettings
|
var setts FilteringSettings
|
||||||
|
|
||||||
// Helpers.
|
// Helpers.
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ func purgeCaches() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func newForTest(c *Config, filters []Filter) *DNSFilter {
|
func newForTest(c *Config, filters []Filter) *DNSFilter {
|
||||||
setts = RequestFilteringSettings{
|
setts = FilteringSettings{
|
||||||
FilteringEnabled: true,
|
FilteringEnabled: true,
|
||||||
}
|
}
|
||||||
setts.FilteringEnabled = true
|
setts.FilteringEnabled = true
|
||||||
|
@ -699,7 +699,7 @@ func TestWhitelist(t *testing.T) {
|
||||||
|
|
||||||
// Client Settings.
|
// Client Settings.
|
||||||
|
|
||||||
func applyClientSettings(setts *RequestFilteringSettings) {
|
func applyClientSettings(setts *FilteringSettings) {
|
||||||
setts.FilteringEnabled = false
|
setts.FilteringEnabled = false
|
||||||
setts.ParentalEnabled = false
|
setts.ParentalEnabled = false
|
||||||
setts.SafeBrowsingEnabled = true
|
setts.SafeBrowsingEnabled = true
|
||||||
|
|
|
@ -47,7 +47,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
|
||||||
`
|
`
|
||||||
|
|
||||||
f := newForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
|
f := newForTest(nil, []Filter{{ID: 0, Data: []byte(text)}})
|
||||||
setts := &RequestFilteringSettings{
|
setts := &FilteringSettings{
|
||||||
FilteringEnabled: true,
|
FilteringEnabled: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -304,46 +304,70 @@ func check(c *sbCtx, r Result, u upstream.Upstream) (Result, error) {
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) checkSafeBrowsing(host string) (Result, error) {
|
// TODO(a.garipov): Unify with checkParental.
|
||||||
|
func (d *DNSFilter) checkSafeBrowsing(
|
||||||
|
host string,
|
||||||
|
_ uint16,
|
||||||
|
setts *FilteringSettings,
|
||||||
|
) (res Result, err error) {
|
||||||
|
if !setts.SafeBrowsingEnabled {
|
||||||
|
return Result{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
if log.GetLevel() >= log.DEBUG {
|
if log.GetLevel() >= log.DEBUG {
|
||||||
timer := log.StartTimer()
|
timer := log.StartTimer()
|
||||||
defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
|
defer timer.LogElapsed("SafeBrowsing lookup for %s", host)
|
||||||
}
|
}
|
||||||
ctx := &sbCtx{
|
|
||||||
|
sctx := &sbCtx{
|
||||||
host: host,
|
host: host,
|
||||||
svc: "SafeBrowsing",
|
svc: "SafeBrowsing",
|
||||||
cache: gctx.safebrowsingCache,
|
cache: gctx.safebrowsingCache,
|
||||||
cacheTime: d.Config.CacheTime,
|
cacheTime: d.Config.CacheTime,
|
||||||
}
|
}
|
||||||
res := Result{
|
|
||||||
|
res = Result{
|
||||||
IsFiltered: true,
|
IsFiltered: true,
|
||||||
Reason: FilteredSafeBrowsing,
|
Reason: FilteredSafeBrowsing,
|
||||||
Rules: []*ResultRule{{
|
Rules: []*ResultRule{{
|
||||||
Text: "adguard-malware-shavar",
|
Text: "adguard-malware-shavar",
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
return check(ctx, res, d.safeBrowsingUpstream)
|
|
||||||
|
return check(sctx, res, d.safeBrowsingUpstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) checkParental(host string) (Result, error) {
|
// TODO(a.garipov): Unify with checkSafeBrowsing.
|
||||||
|
func (d *DNSFilter) checkParental(
|
||||||
|
host string,
|
||||||
|
_ uint16,
|
||||||
|
setts *FilteringSettings,
|
||||||
|
) (res Result, err error) {
|
||||||
|
if !setts.ParentalEnabled {
|
||||||
|
return Result{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
if log.GetLevel() >= log.DEBUG {
|
if log.GetLevel() >= log.DEBUG {
|
||||||
timer := log.StartTimer()
|
timer := log.StartTimer()
|
||||||
defer timer.LogElapsed("Parental lookup for %s", host)
|
defer timer.LogElapsed("Parental lookup for %s", host)
|
||||||
}
|
}
|
||||||
ctx := &sbCtx{
|
|
||||||
|
sctx := &sbCtx{
|
||||||
host: host,
|
host: host,
|
||||||
svc: "Parental",
|
svc: "Parental",
|
||||||
cache: gctx.parentalCache,
|
cache: gctx.parentalCache,
|
||||||
cacheTime: d.Config.CacheTime,
|
cacheTime: d.Config.CacheTime,
|
||||||
}
|
}
|
||||||
res := Result{
|
|
||||||
|
res = Result{
|
||||||
IsFiltered: true,
|
IsFiltered: true,
|
||||||
Reason: FilteredParental,
|
Reason: FilteredParental,
|
||||||
Rules: []*ResultRule{{
|
Rules: []*ResultRule{{
|
||||||
Text: "parental CATEGORY_BLACKLISTED",
|
Text: "parental CATEGORY_BLACKLISTED",
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
return check(ctx, res, d.parentalUpstream)
|
|
||||||
|
return check(sctx, res, d.parentalUpstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
func httpError(r *http.Request, w http.ResponseWriter, code int, format string, args ...interface{}) {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/AdguardTeam/golibs/cache"
|
"github.com/AdguardTeam/golibs/cache"
|
||||||
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -115,11 +116,16 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
|
||||||
d.SetSafeBrowsingUpstream(ups)
|
d.SetSafeBrowsingUpstream(ups)
|
||||||
d.SetParentalUpstream(ups)
|
d.SetParentalUpstream(ups)
|
||||||
|
|
||||||
_, err := d.checkSafeBrowsing("smthng.com")
|
setts := &FilteringSettings{
|
||||||
assert.NotNil(t, err)
|
SafeBrowsingEnabled: true,
|
||||||
|
ParentalEnabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
_, err = d.checkParental("smthng.com")
|
_, err := d.checkSafeBrowsing("smthng.com", dns.TypeA, setts)
|
||||||
assert.NotNil(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
_, err = d.checkParental("smthng.com", dns.TypeA, setts)
|
||||||
|
assert.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSBPC(t *testing.T) {
|
func TestSBPC(t *testing.T) {
|
||||||
|
@ -128,10 +134,15 @@ func TestSBPC(t *testing.T) {
|
||||||
|
|
||||||
const hostname = "example.org"
|
const hostname = "example.org"
|
||||||
|
|
||||||
|
setts := &FilteringSettings{
|
||||||
|
SafeBrowsingEnabled: true,
|
||||||
|
ParentalEnabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
block bool
|
block bool
|
||||||
testFunc func(string) (Result, error)
|
testFunc func(host string, _ uint16, _ *FilteringSettings) (res Result, err error)
|
||||||
testCache cache.Cache
|
testCache cache.Cache
|
||||||
}{{
|
}{{
|
||||||
name: "sb_no_block",
|
name: "sb_no_block",
|
||||||
|
@ -167,8 +178,9 @@ func TestSBPC(t *testing.T) {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
// Firstly, check the request blocking.
|
// Firstly, check the request blocking.
|
||||||
hits := 0
|
hits := 0
|
||||||
res, err := tc.testFunc(hostname)
|
res, err := tc.testFunc(hostname, dns.TypeA, setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if tc.block {
|
if tc.block {
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
@ -185,8 +197,9 @@ func TestSBPC(t *testing.T) {
|
||||||
assert.Equal(t, 1, ups.RequestsCount())
|
assert.Equal(t, 1, ups.RequestsCount())
|
||||||
|
|
||||||
// Now make the same request to check the cache was used.
|
// Now make the same request to check the cache was used.
|
||||||
res, err = tc.testFunc(hostname)
|
res, err = tc.testFunc(hostname, dns.TypeA, setts)
|
||||||
require.Nil(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if tc.block {
|
if tc.block {
|
||||||
assert.True(t, res.IsFiltered)
|
assert.True(t, res.IsFiltered)
|
||||||
require.Len(t, res.Rules, 1)
|
require.Len(t, res.Rules, 1)
|
||||||
|
|
|
@ -69,7 +69,15 @@ func (d *DNSFilter) SafeSearchDomain(host string) (string, bool) {
|
||||||
return val, ok
|
return val, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DNSFilter) checkSafeSearch(host string) (Result, error) {
|
func (d *DNSFilter) checkSafeSearch(
|
||||||
|
host string,
|
||||||
|
_ uint16,
|
||||||
|
setts *FilteringSettings,
|
||||||
|
) (res Result, err error) {
|
||||||
|
if !setts.SafeSearchEnabled {
|
||||||
|
return Result{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
if log.GetLevel() >= log.DEBUG {
|
if log.GetLevel() >= log.DEBUG {
|
||||||
timer := log.StartTimer()
|
timer := log.StartTimer()
|
||||||
defer timer.LogElapsed("SafeSearch: lookup for %s", host)
|
defer timer.LogElapsed("SafeSearch: lookup for %s", host)
|
||||||
|
@ -88,7 +96,7 @@ func (d *DNSFilter) checkSafeSearch(host string) (Result, error) {
|
||||||
return Result{}, nil
|
return Result{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
res := Result{
|
res = Result{
|
||||||
IsFiltered: true,
|
IsFiltered: true,
|
||||||
Reason: FilteredSafeSearch,
|
Reason: FilteredSafeSearch,
|
||||||
Rules: []*ResultRule{{}},
|
Rules: []*ResultRule{{}},
|
||||||
|
|
|
@ -25,7 +25,7 @@ type FilteringConfig struct {
|
||||||
// --
|
// --
|
||||||
|
|
||||||
// FilterHandler is an optional additional filtering callback.
|
// FilterHandler is an optional additional filtering callback.
|
||||||
FilterHandler func(clientAddr net.IP, clientID string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
|
FilterHandler func(clientAddr net.IP, clientID string, settings *dnsfilter.FilteringSettings) `yaml:"-"`
|
||||||
|
|
||||||
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
|
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration
|
||||||
// based on the client IP address. Returns nil if there are no custom upstreams for the client
|
// based on the client IP address. Returns nil if there are no custom upstreams for the client
|
||||||
|
|
|
@ -20,7 +20,7 @@ type dnsContext struct {
|
||||||
srv *Server
|
srv *Server
|
||||||
proxyCtx *proxy.DNSContext
|
proxyCtx *proxy.DNSContext
|
||||||
// setts are the filtering settings for the client.
|
// setts are the filtering settings for the client.
|
||||||
setts *dnsfilter.RequestFilteringSettings
|
setts *dnsfilter.FilteringSettings
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
result *dnsfilter.Result
|
result *dnsfilter.Result
|
||||||
// origResp is the response received from upstream. It is set when the
|
// origResp is the response received from upstream. It is set when the
|
||||||
|
|
|
@ -632,7 +632,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||||
FilteringConfig: FilteringConfig{
|
FilteringConfig: FilteringConfig{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
FilterHandler: func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) {
|
FilterHandler: func(_ net.IP, _ string, settings *dnsfilter.FilteringSettings) {
|
||||||
settings.FilteringEnabled = false
|
settings.FilteringEnabled = false
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
@ -32,7 +32,7 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool
|
||||||
|
|
||||||
// getClientRequestFilteringSettings looks up client filtering settings using
|
// getClientRequestFilteringSettings looks up client filtering settings using
|
||||||
// the client's IP address and ID, if any, from ctx.
|
// the client's IP address and ID, if any, from ctx.
|
||||||
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *dnsfilter.RequestFilteringSettings {
|
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *dnsfilter.FilteringSettings {
|
||||||
setts := s.dnsFilter.GetConfig()
|
setts := s.dnsFilter.GetConfig()
|
||||||
setts.FilteringEnabled = true
|
setts.FilteringEnabled = true
|
||||||
if s.conf.FilterHandler != nil {
|
if s.conf.FilterHandler != nil {
|
||||||
|
|
|
@ -276,7 +276,7 @@ func getDNSEncryption() (de dnsEncryption) {
|
||||||
|
|
||||||
// applyAdditionalFiltering adds additional client information and settings if
|
// applyAdditionalFiltering adds additional client information and settings if
|
||||||
// the client has them.
|
// the client has them.
|
||||||
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.RequestFilteringSettings) {
|
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.FilteringSettings) {
|
||||||
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
|
||||||
|
|
||||||
if clientAddr == nil {
|
if clientAddr == nil {
|
||||||
|
|
|
@ -375,6 +375,64 @@ func decodeResultIPList(dec *json.Decoder, ent *logEntry) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeResultDNSRewriteResultKey(key string, dec *json.Decoder, ent *logEntry) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch key {
|
||||||
|
case "RCode":
|
||||||
|
var vToken json.Token
|
||||||
|
vToken, err = dec.Token()
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
log.Debug("decodeResultDNSRewriteResultKey err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ent.Result.DNSRewriteResult == nil {
|
||||||
|
ent.Result.DNSRewriteResult = &dnsfilter.DNSRewriteResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if n, ok := vToken.(json.Number); ok {
|
||||||
|
rcode64, _ := n.Int64()
|
||||||
|
ent.Result.DNSRewriteResult.RCode = rules.RCode(rcode64)
|
||||||
|
}
|
||||||
|
case "Response":
|
||||||
|
if ent.Result.DNSRewriteResult == nil {
|
||||||
|
ent.Result.DNSRewriteResult = &dnsfilter.DNSRewriteResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ent.Result.DNSRewriteResult.Response == nil {
|
||||||
|
ent.Result.DNSRewriteResult.Response = dnsfilter.DNSRewriteResultResponse{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(a.garipov): I give up. This whole file is a mess.
|
||||||
|
// Luckily, we can assume that this field is relatively rare and
|
||||||
|
// just use the normal decoding and correct the values.
|
||||||
|
err = dec.Decode(&ent.Result.DNSRewriteResult.Response)
|
||||||
|
if err != nil {
|
||||||
|
log.Debug("decodeResultDNSRewriteResultKey response err: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for rrType, rrValues := range ent.Result.DNSRewriteResult.Response {
|
||||||
|
switch rrType {
|
||||||
|
case
|
||||||
|
dns.TypeA,
|
||||||
|
dns.TypeAAAA:
|
||||||
|
for i, v := range rrValues {
|
||||||
|
s, _ := v.(string)
|
||||||
|
rrValues[i] = net.ParseIP(s)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Go on.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Go on.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
|
func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
|
||||||
for {
|
for {
|
||||||
keyToken, err := dec.Token()
|
keyToken, err := dec.Token()
|
||||||
|
@ -401,66 +459,7 @@ func decodeResultDNSRewriteResult(dec *json.Decoder, ent *logEntry) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(a.garipov): Refactor this into a separate
|
decodeResultDNSRewriteResultKey(key, dec, ent)
|
||||||
// function à la decodeResultRuleKey if we keep this
|
|
||||||
// code for a longer time than planned.
|
|
||||||
switch key {
|
|
||||||
case "RCode":
|
|
||||||
var vToken json.Token
|
|
||||||
vToken, err = dec.Token()
|
|
||||||
if err != nil {
|
|
||||||
if err != io.EOF {
|
|
||||||
log.Debug("decodeResultDNSRewriteResult err: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if ent.Result.DNSRewriteResult == nil {
|
|
||||||
ent.Result.DNSRewriteResult = &dnsfilter.DNSRewriteResult{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var n json.Number
|
|
||||||
if n, ok = vToken.(json.Number); ok {
|
|
||||||
rcode64, _ := n.Int64()
|
|
||||||
ent.Result.DNSRewriteResult.RCode = rules.RCode(rcode64)
|
|
||||||
}
|
|
||||||
|
|
||||||
continue
|
|
||||||
case "Response":
|
|
||||||
if ent.Result.DNSRewriteResult == nil {
|
|
||||||
ent.Result.DNSRewriteResult = &dnsfilter.DNSRewriteResult{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ent.Result.DNSRewriteResult.Response == nil {
|
|
||||||
ent.Result.DNSRewriteResult.Response = dnsfilter.DNSRewriteResultResponse{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(a.garipov): I give up. This whole file
|
|
||||||
// is a mess. Luckily, we can assume that this
|
|
||||||
// field is relatively rare and just use the
|
|
||||||
// normal decoding and correct the values.
|
|
||||||
err = dec.Decode(&ent.Result.DNSRewriteResult.Response)
|
|
||||||
if err != nil {
|
|
||||||
log.Debug("decodeResultDNSRewriteResult response err: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for rrType, rrValues := range ent.Result.DNSRewriteResult.Response {
|
|
||||||
switch rrType {
|
|
||||||
case dns.TypeA, dns.TypeAAAA:
|
|
||||||
for i, v := range rrValues {
|
|
||||||
s, _ := v.(string)
|
|
||||||
rrValues[i] = net.ParseIP(s)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// Go on.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
continue
|
|
||||||
default:
|
|
||||||
// Go on.
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue