Pull request 1803: 5685-fix-safe-search

Updates #5685.

Squashed commit of the following:

commit 5312147abfa0914c896acbf1e88f8c8f1af90f2b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Apr 6 14:09:44 2023 +0300

    safesearch: imp tests, logs

commit 298b5d24ce292c5f83ebe33d1e92329e4b3c1acc
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 5 20:36:16 2023 +0300

    safesearch: fix filters, logging

commit 63d6ca5d694d45705473f2f0410e9e0b49cf7346
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 5 20:24:47 2023 +0300

    all: dry; fix logs

commit fdbf2f364fd0484b47b3161bf6f4581856fdf47b
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 5 20:01:08 2023 +0300

    all: fix safe search update
This commit is contained in:
Ainar Garipov 2023-04-06 14:12:50 +03:00
parent a55cbbe79c
commit 61b4043775
19 changed files with 527 additions and 265 deletions

View File

@ -25,6 +25,7 @@ NOTE: Add new changes BELOW THIS COMMENT.
### Added ### Added
- IPv6 support in Safe Search for some services.
- The ability to make bootstrap DNS lookups prefer IPv6 addresses to IPv4 ones - The ability to make bootstrap DNS lookups prefer IPv6 addresses to IPv4 ones
using the new `dns.bootstrap_prefer_ipv6` configuration file property using the new `dns.bootstrap_prefer_ipv6` configuration file property
([#4262]). ([#4262]).

View File

@ -453,8 +453,9 @@ func TestSafeSearch(t *testing.T) {
SafeSearchCacheSize: 1000, SafeSearchCacheSize: 1000,
CacheTime: 30, CacheTime: 30,
} }
safeSearch, err := safesearch.NewDefaultSafeSearch( safeSearch, err := safesearch.NewDefault(
safeSearchConf, safeSearchConf,
"",
filterConf.SafeSearchCacheSize, filterConf.SafeSearchCacheSize,
time.Minute*time.Duration(filterConf.CacheTime), time.Minute*time.Duration(filterConf.CacheTime),
) )

View File

@ -1,17 +1,17 @@
package filtering package filtering
import ( import "github.com/miekg/dns"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
)
// SafeSearch interface describes a service for search engines hosts rewrites. // SafeSearch interface describes a service for search engines hosts rewrites.
type SafeSearch interface { type SafeSearch interface {
// SearchHost returns a replacement address for the search engine host. // CheckHost checks host with safe search filter. CheckHost must be safe
SearchHost(host string, qtype uint16) (res *rules.DNSRewrite) // for concurrent use. qtype must be either [dns.TypeA] or [dns.TypeAAAA].
// CheckHost checks host with safe search engine.
CheckHost(host string, qtype uint16) (res Result, err error) CheckHost(host string, qtype uint16) (res Result, err error)
// Update updates the configuration of the safe search filter. Update must
// be safe for concurrent use. An implementation of Update may ignore some
// fields, but it must document which.
Update(conf SafeSearchConfig) (err error)
} }
// SafeSearchConfig is a struct with safe search related settings. // SafeSearchConfig is a struct with safe search related settings.
@ -37,10 +37,12 @@ type SafeSearchConfig struct {
// [hostChecker.check]. // [hostChecker.check].
func (d *DNSFilter) checkSafeSearch( func (d *DNSFilter) checkSafeSearch(
host string, host string,
_ uint16, qtype uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.ProtectionEnabled || !setts.SafeSearchEnabled { if !setts.ProtectionEnabled ||
!setts.SafeSearchEnabled ||
(qtype != dns.TypeA && qtype != dns.TypeAAAA) {
return Result{}, nil return Result{}, nil
} }
@ -50,8 +52,8 @@ func (d *DNSFilter) checkSafeSearch(
clientSafeSearch := setts.ClientSafeSearch clientSafeSearch := setts.ClientSafeSearch
if clientSafeSearch != nil { if clientSafeSearch != nil {
return clientSafeSearch.CheckHost(host, dns.TypeA) return clientSafeSearch.CheckHost(host, qtype)
} }
return d.safeSearch.CheckHost(host, dns.TypeA) return d.safeSearch.CheckHost(host, qtype)
} }

View File

@ -1 +1 @@
|www.bing.com^$dnsrewrite=NOERROR;CNAME;strict.bing.com |www.bing.com^$dnsrewrite=NOERROR;CNAME;strict.bing.com

View File

@ -1,3 +1,3 @@
|duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com |duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com
|start.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com |start.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com
|www.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com |www.duckduckgo.com^$dnsrewrite=NOERROR;CNAME;safe.duckduckgo.com

View File

@ -188,4 +188,4 @@
|www.google.tt^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com |www.google.tt^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|www.google.vg^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com |www.google.vg^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|www.google.vu^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com |www.google.vu^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com
|www.google.ws^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com |www.google.ws^$dnsrewrite=NOERROR;CNAME;forcesafesearch.google.com

View File

@ -1 +1 @@
|pixabay.com^$dnsrewrite=NOERROR;CNAME;safesearch.pixabay.com |pixabay.com^$dnsrewrite=NOERROR;CNAME;safesearch.pixabay.com

View File

@ -49,4 +49,4 @@
|yandex.ru^$dnsrewrite=NOERROR;A;213.180.193.56 |yandex.ru^$dnsrewrite=NOERROR;A;213.180.193.56
|yandex.tj^$dnsrewrite=NOERROR;A;213.180.193.56 |yandex.tj^$dnsrewrite=NOERROR;A;213.180.193.56
|yandex.tm^$dnsrewrite=NOERROR;A;213.180.193.56 |yandex.tm^$dnsrewrite=NOERROR;A;213.180.193.56
|yandex.uz^$dnsrewrite=NOERROR;A;213.180.193.56 |yandex.uz^$dnsrewrite=NOERROR;A;213.180.193.56

View File

@ -2,4 +2,4 @@
|m.youtube.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com |m.youtube.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|youtubei.googleapis.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com |youtubei.googleapis.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|youtube.googleapis.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com |youtube.googleapis.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com
|www.youtube-nocookie.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com |www.youtube-nocookie.com^$dnsrewrite=NOERROR;CNAME;restrictmoderate.youtube.com

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -53,44 +54,85 @@ func isServiceProtected(s filtering.SafeSearchConfig, service Service) (ok bool)
} }
} }
// DefaultSafeSearch is the default safesearch struct. // Default is the default safe search filter that uses filtering rules with the
type DefaultSafeSearch struct { // dnsrewrite modifier.
engine *urlfilter.DNSEngine type Default struct {
safeSearchCache cache.Cache // mu protects engine.
resolver filtering.Resolver mu *sync.RWMutex
cacheTime time.Duration
// engine is the filtering engine that contains the DNS rewrite rules.
// engine may be nil, which means that this safe search filter is disabled.
engine *urlfilter.DNSEngine
cache cache.Cache
resolver filtering.Resolver
logPrefix string
cacheTTL time.Duration
} }
// NewDefaultSafeSearch returns new safesearch struct. CacheTime is an element // NewDefault returns an initialized default safe search filter. name is used
// TTL (in minutes). // for logging.
func NewDefaultSafeSearch( func NewDefault(
conf filtering.SafeSearchConfig, conf filtering.SafeSearchConfig,
name string,
cacheSize uint, cacheSize uint,
cacheTime time.Duration, cacheTTL time.Duration,
) (ss *DefaultSafeSearch, err error) { ) (ss *Default, err error) {
engine, err := newEngine(filtering.SafeSearchListID, conf)
if err != nil {
return nil, err
}
var resolver filtering.Resolver = net.DefaultResolver var resolver filtering.Resolver = net.DefaultResolver
if conf.CustomResolver != nil { if conf.CustomResolver != nil {
resolver = conf.CustomResolver resolver = conf.CustomResolver
} }
return &DefaultSafeSearch{ ss = &Default{
engine: engine, mu: &sync.RWMutex{},
safeSearchCache: cache.New(cache.Config{
cache: cache.New(cache.Config{
EnableLRU: true, EnableLRU: true,
MaxSize: cacheSize, MaxSize: cacheSize,
}), }),
cacheTime: cacheTime, resolver: resolver,
resolver: resolver, // Use %s, because the client safe-search names already contain double
}, nil // quotes.
logPrefix: fmt.Sprintf("safesearch %s: ", name),
cacheTTL: cacheTTL,
}
err = ss.resetEngine(filtering.SafeSearchListID, conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return ss, nil
} }
// newEngine creates new engine for provided safe search configuration. // log is a helper for logging that includes the name of the safe search
func newEngine(listID int, conf filtering.SafeSearchConfig) (engine *urlfilter.DNSEngine, err error) { // filter. level must be one of [log.DEBUG], [log.INFO], and [log.ERROR].
func (ss *Default) log(level log.Level, msg string, args ...any) {
switch level {
case log.DEBUG:
log.Debug(ss.logPrefix+msg, args...)
case log.INFO:
log.Info(ss.logPrefix+msg, args...)
case log.ERROR:
log.Error(ss.logPrefix+msg, args...)
default:
panic(fmt.Errorf("safesearch: unsupported logging level %d", level))
}
}
// resetEngine creates new engine for provided safe search configuration and
// sets it in ss.
func (ss *Default) resetEngine(
listID int,
conf filtering.SafeSearchConfig,
) (err error) {
if !conf.Enabled {
ss.log(log.INFO, "disabled")
return nil
}
var sb strings.Builder var sb strings.Builder
for service, serviceRules := range safeSearchRules { for service, serviceRules := range safeSearchRules {
if isServiceProtected(conf, service) { if isServiceProtected(conf, service) {
@ -106,20 +148,73 @@ func newEngine(listID int, conf filtering.SafeSearchConfig) (engine *urlfilter.D
rs, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList}) rs, err := filterlist.NewRuleStorage([]filterlist.RuleList{strList})
if err != nil { if err != nil {
return nil, fmt.Errorf("creating rule storage: %w", err) return fmt.Errorf("creating rule storage: %w", err)
} }
engine = urlfilter.NewDNSEngine(rs) ss.engine = urlfilter.NewDNSEngine(rs)
log.Info("safesearch: filter %d: reset %d rules", listID, engine.RulesCount)
return engine, nil ss.log(log.INFO, "reset %d rules", ss.engine.RulesCount)
return nil
} }
// type check // type check
var _ filtering.SafeSearch = (*DefaultSafeSearch)(nil) var _ filtering.SafeSearch = (*Default)(nil)
// CheckHost implements the [filtering.SafeSearch] interface for
// *DefaultSafeSearch.
func (ss *Default) CheckHost(
host string,
qtype rules.RRType,
) (res filtering.Result, err error) {
start := time.Now()
defer func() {
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
}()
if qtype != dns.TypeA && qtype != dns.TypeAAAA {
return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype))
}
// Check cache. Return cached result if it was found
cachedValue, isFound := ss.getCachedResult(host, qtype)
if isFound {
ss.log(log.DEBUG, "found in cache: %q", host)
return cachedValue, nil
}
rewrite := ss.searchHost(host, qtype)
if rewrite == nil {
return filtering.Result{}, nil
}
fltRes, err := ss.newResult(rewrite, qtype)
if err != nil {
ss.log(log.DEBUG, "looking up addresses for %q: %s", host, err)
return filtering.Result{}, err
}
if fltRes != nil {
res = *fltRes
ss.setCacheResult(host, qtype, res)
return res, nil
}
return filtering.Result{}, fmt.Errorf("no ipv4 addresses for %q", host)
}
// searchHost looks up DNS rewrites in the internal DNS filtering engine.
func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRewrite) {
ss.mu.RLock()
defer ss.mu.RUnlock()
if ss.engine == nil {
return nil
}
// SearchHost implements the [filtering.SafeSearch] interface for *DefaultSafeSearch.
func (ss *DefaultSafeSearch) SearchHost(host string, qtype uint16) (res *rules.DNSRewrite) {
r, _ := ss.engine.MatchRequest(&urlfilter.DNSRequest{ r, _ := ss.engine.MatchRequest(&urlfilter.DNSRequest{
Hostname: strings.ToLower(host), Hostname: strings.ToLower(host),
DNSType: qtype, DNSType: qtype,
@ -133,51 +228,11 @@ func (ss *DefaultSafeSearch) SearchHost(host string, qtype uint16) (res *rules.D
return nil return nil
} }
// CheckHost implements the [filtering.SafeSearch] interface for // newResult creates Result object from rewrite rule. qtype must be either
// *DefaultSafeSearch. // [dns.TypeA] or [dns.TypeAAAA].
func (ss *DefaultSafeSearch) CheckHost( func (ss *Default) newResult(
host string,
qtype uint16,
) (res filtering.Result, err error) {
if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer()
defer timer.LogElapsed("safesearch: lookup for %s", host)
}
// Check cache. Return cached result if it was found
cachedValue, isFound := ss.getCachedResult(host)
if isFound {
log.Debug("safesearch: found in cache: %s", host)
return cachedValue, nil
}
rewrite := ss.SearchHost(host, qtype)
if rewrite == nil {
return filtering.Result{}, nil
}
dRes, err := ss.newResult(rewrite, qtype)
if err != nil {
log.Debug("safesearch: failed to lookup addresses for %s: %s", host, err)
return filtering.Result{}, err
}
if dRes != nil {
res = *dRes
ss.setCacheResult(host, res)
return res, nil
}
return filtering.Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", host)
}
// newResult creates Result object from rewrite rule.
func (ss *DefaultSafeSearch) newResult(
rewrite *rules.DNSRewrite, rewrite *rules.DNSRewrite,
qtype uint16, qtype rules.RRType,
) (res *filtering.Result, err error) { ) (res *filtering.Result, err error) {
res = &filtering.Result{ res = &filtering.Result{
Rules: []*filtering.ResultRule{{ Rules: []*filtering.ResultRule{{
@ -187,7 +242,7 @@ func (ss *DefaultSafeSearch) newResult(
IsFiltered: true, IsFiltered: true,
} }
if rewrite.RRType == qtype && (qtype == dns.TypeA || qtype == dns.TypeAAAA) { if rewrite.RRType == qtype {
ip, ok := rewrite.Value.(net.IP) ip, ok := rewrite.Value.(net.IP)
if !ok || ip == nil { if !ok || ip == nil {
return nil, nil return nil, nil
@ -198,17 +253,25 @@ func (ss *DefaultSafeSearch) newResult(
return res, nil return res, nil
} }
if rewrite.NewCNAME == "" { host := rewrite.NewCNAME
if host == "" {
return nil, nil return nil, nil
} }
ips, err := ss.resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME) ss.log(log.DEBUG, "resolving %q", host)
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ss.log(log.DEBUG, "resolved %s", ips)
for _, ip := range ips { for _, ip := range ips {
if ip = ip.To4(); ip == nil { // TODO(a.garipov): Remove this filtering once the resolver we use
// actually learns about network.
ip = fitToProto(ip, qtype)
if ip == nil {
continue continue
} }
@ -220,38 +283,71 @@ func (ss *DefaultSafeSearch) newResult(
return nil, nil return nil, nil
} }
// setCacheResult stores data in cache for host. // qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
func (ss *DefaultSafeSearch) setCacheResult(host string, res filtering.Result) { // It panics for other types.
expire := uint32(time.Now().Add(ss.cacheTime).Unix()) func qtypeToProto(qtype rules.RRType) (proto string) {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype)))
}
}
// fitToProto returns a non-nil IP address if ip is the correct protocol version
// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA].
func fitToProto(ip net.IP, qtype rules.RRType) (res net.IP) {
ip4 := ip.To4()
if qtype == dns.TypeA {
return ip4
}
if ip4 == nil {
return ip
}
return nil
}
// setCacheResult stores data in cache for host. qtype is expected to be either
// [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {
expire := uint32(time.Now().Add(ss.cacheTTL).Unix())
exp := make([]byte, 4) exp := make([]byte, 4)
binary.BigEndian.PutUint32(exp, expire) binary.BigEndian.PutUint32(exp, expire)
buf := bytes.NewBuffer(exp) buf := bytes.NewBuffer(exp)
err := gob.NewEncoder(buf).Encode(res) err := gob.NewEncoder(buf).Encode(res)
if err != nil { if err != nil {
log.Error("safesearch: cache encoding: %s", err) ss.log(log.ERROR, "cache encoding: %s", err)
return return
} }
val := buf.Bytes() val := buf.Bytes()
_ = ss.safeSearchCache.Set([]byte(host), val) _ = ss.cache.Set([]byte(dns.Type(qtype).String()+" "+host), val)
log.Debug("safesearch: stored in cache: %s (%d bytes)", host, len(val)) ss.log(log.DEBUG, "stored in cache: %q, %d bytes", host, len(val))
} }
// getCachedResult returns stored data from cache for host. // getCachedResult returns stored data from cache for host. qtype is expected
func (ss *DefaultSafeSearch) getCachedResult(host string) (res filtering.Result, ok bool) { // to be either [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) getCachedResult(
host string,
qtype rules.RRType,
) (res filtering.Result, ok bool) {
res = filtering.Result{} res = filtering.Result{}
data := ss.safeSearchCache.Get([]byte(host)) data := ss.cache.Get([]byte(dns.Type(qtype).String() + " " + host))
if data == nil { if data == nil {
return res, false return res, false
} }
exp := binary.BigEndian.Uint32(data[:4]) exp := binary.BigEndian.Uint32(data[:4])
if exp <= uint32(time.Now().Unix()) { if exp <= uint32(time.Now().Unix()) {
ss.safeSearchCache.Del([]byte(host)) ss.cache.Del([]byte(host))
return res, false return res, false
} }
@ -260,10 +356,27 @@ func (ss *DefaultSafeSearch) getCachedResult(host string) (res filtering.Result,
err := gob.NewDecoder(buf).Decode(&res) err := gob.NewDecoder(buf).Decode(&res)
if err != nil { if err != nil {
log.Debug("safesearch: cache decoding: %s", err) ss.log(log.ERROR, "cache decoding: %s", err)
return filtering.Result{}, false return filtering.Result{}, false
} }
return res, true return res, true
} }
// Update implements the [filtering.SafeSearch] interface for *Default. Update
// ignores the CustomResolver and Enabled fields.
func (ss *Default) Update(conf filtering.SafeSearchConfig) (err error) {
ss.mu.Lock()
defer ss.mu.Unlock()
err = ss.resetEngine(filtering.SafeSearchListID, conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
ss.cache.Clear()
return nil
}

View File

@ -0,0 +1,137 @@
package safesearch
import (
"context"
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TODO(a.garipov): Move as much of this as possible into proper external tests.
const (
// TODO(a.garipov): Add IPv6 tests.
testQType = dns.TypeA
testCacheSize = 5000
testCacheTTL = 30 * time.Minute
)
var defaultSafeSearchConf = filtering.SafeSearchConfig{
Enabled: true,
Bing: true,
DuckDuckGo: true,
Google: true,
Pixabay: true,
Yandex: true,
YouTube: true,
}
var yandexIP = net.IPv4(213, 180, 193, 56)
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *Default) {
ss, err := NewDefault(ssConf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
return ss
}
func TestSafeSearch(t *testing.T) {
ss := newForTest(t, defaultSafeSearchConf)
val := ss.searchHost("www.google.com", testQType)
assert.Equal(t, &rules.DNSRewrite{NewCNAME: "forcesafesearch.google.com"}, val)
}
func TestSafeSearchCacheYandex(t *testing.T) {
const domain = "yandex.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
// Check host with disabled safesearch.
res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
ss = newForTest(t, defaultSafeSearchConf)
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
// For yandex we already know valid IP.
require.Len(t, res.Rules, 1)
assert.Equal(t, res.Rules[0].IP, yandexIP)
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.TestResolver{}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.searchHost(domain, testQType)
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
require.NoError(t, err)
var foundIP net.IP
for _, ip := range ips {
if ip.To4() != nil {
foundIP = ip
break
}
}
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(foundIP))
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
}
const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite
func BenchmarkSafeSearch(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
for n := 0; n < b.N; n++ {
dnsRewriteSink = ss.searchHost(googleHost, testQType)
}
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME)
}

View File

@ -1,26 +1,37 @@
package safesearch package safesearch_test
import ( import (
"context"
"net" "net"
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) {
testutil.DiscardLogOutput(m)
}
// Common test constants.
const ( const (
safeSearchCacheSize = 5000 // TODO(a.garipov): Add IPv6 tests.
cacheTime = 30 * time.Minute testQType = dns.TypeA
testCacheSize = 5000
testCacheTTL = 30 * time.Minute
) )
var defaultSafeSearchConf = filtering.SafeSearchConfig{ // testConf is the default safe search configuration for tests.
Enabled: true, var testConf = filtering.SafeSearchConfig{
CustomResolver: nil,
Enabled: true,
Bing: true, Bing: true,
DuckDuckGo: true, DuckDuckGo: true,
Google: true, Google: true,
@ -29,25 +40,15 @@ var defaultSafeSearchConf = filtering.SafeSearchConfig{
YouTube: true, YouTube: true,
} }
// yandexIP is the expected IP address of Yandex safe search results. Keep in
// sync with the rules data.
var yandexIP = net.IPv4(213, 180, 193, 56) var yandexIP = net.IPv4(213, 180, 193, 56)
func newForTest(t testing.TB, ssConf filtering.SafeSearchConfig) (ss *DefaultSafeSearch) { func TestDefault_CheckHost_yandex(t *testing.T) {
ss, err := NewDefaultSafeSearch(ssConf, safeSearchCacheSize, cacheTime) conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err) require.NoError(t, err)
return ss
}
func TestSafeSearch(t *testing.T) {
ss := newForTest(t, defaultSafeSearchConf)
val := ss.SearchHost("www.google.com", dns.TypeA)
assert.Equal(t, &rules.DNSRewrite{NewCNAME: "forcesafesearch.google.com"}, val)
}
func TestCheckHostSafeSearchYandex(t *testing.T) {
ss := newForTest(t, defaultSafeSearchConf)
// Check host for each domain. // Check host for each domain.
for _, host := range []string{ for _, host := range []string{
"yandex.ru", "yandex.ru",
@ -57,7 +58,8 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
"yandex.kz", "yandex.kz",
"www.yandex.com", "www.yandex.com",
} { } {
res, err := ss.CheckHost(host, dns.TypeA) var res filtering.Result
res, err = ss.CheckHost(host, testQType)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@ -69,12 +71,14 @@ func TestCheckHostSafeSearchYandex(t *testing.T) {
} }
} }
func TestCheckHostSafeSearchGoogle(t *testing.T) { func TestDefault_CheckHost_google(t *testing.T) {
resolver := &aghtest.TestResolver{} resolver := &aghtest.TestResolver{}
ip, _ := resolver.HostToIPs("forcesafesearch.google.com") ip, _ := resolver.HostToIPs("forcesafesearch.google.com")
ss := newForTest(t, defaultSafeSearchConf) conf := testConf
ss.resolver = resolver conf.CustomResolver = resolver
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
// Check host for each domain. // Check host for each domain.
for _, host := range []string{ for _, host := range []string{
@ -87,7 +91,8 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
"www.google.je", "www.google.je",
} { } {
t.Run(host, func(t *testing.T) { t.Run(host, func(t *testing.T) {
res, err := ss.CheckHost(host, dns.TypeA) var res filtering.Result
res, err = ss.CheckHost(host, testQType)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
@ -100,103 +105,35 @@ func TestCheckHostSafeSearchGoogle(t *testing.T) {
} }
} }
func TestSafeSearchCacheYandex(t *testing.T) { func TestDefault_Update(t *testing.T) {
const domain = "yandex.ru" conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
// Check host with disabled safesearch.
res, err := ss.CheckHost(domain, dns.TypeA)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, res.IsFiltered) res, err := ss.CheckHost("www.yandex.com", testQType)
assert.Empty(t, res.Rules)
ss = newForTest(t, defaultSafeSearchConf)
res, err = ss.CheckHost(domain, dns.TypeA)
require.NoError(t, err) require.NoError(t, err)
// For yandex we already know valid IP. assert.True(t, res.IsFiltered)
require.Len(t, res.Rules, 1)
assert.Equal(t, res.Rules[0].IP, yandexIP) err = ss.Update(filtering.SafeSearchConfig{
Enabled: true,
// Check cache. Google: false,
cachedValue, isFound := ss.getCachedResult(domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
}
func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
res, err := ss.CheckHost(domain, dns.TypeA)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.TestResolver{}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.SearchHost(domain, dns.TypeA)
ips, err := resolver.LookupIP(context.Background(), "ip", rewrite.NewCNAME)
require.NoError(t, err)
var foundIP net.IP
for _, ip := range ips {
if ip.To4() != nil {
foundIP = ip
break
}
}
res, err = ss.CheckHost(domain, dns.TypeA)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.True(t, res.Rules[0].IP.Equal(foundIP))
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.True(t, cachedValue.Rules[0].IP.Equal(foundIP))
}
const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite
func BenchmarkSafeSearch(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
for n := 0; n < b.N; n++ {
dnsRewriteSink = ss.SearchHost(googleHost, dns.TypeA)
}
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteSink.NewCNAME)
}
var dnsRewriteParallelSink *rules.DNSRewrite
func BenchmarkSafeSearch_parallel(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
dnsRewriteParallelSink = ss.SearchHost(googleHost, dns.TypeA)
}
}) })
require.NoError(t, err)
assert.Equal(b, "forcesafesearch.google.com", dnsRewriteParallelSink.NewCNAME) res, err = ss.CheckHost("www.yandex.com", testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
err = ss.Update(filtering.SafeSearchConfig{
Enabled: false,
Google: true,
})
require.NoError(t, err)
res, err = ss.CheckHost("www.yandex.com", testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
} }

View File

@ -50,11 +50,19 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ
return return
} }
conf := *req
err = d.safeSearch.Update(conf)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err)
return
}
func() { func() {
d.confLock.Lock() d.confLock.Lock()
defer d.confLock.Unlock() defer d.confLock.Unlock()
d.Config.SafeSearchConf = *req d.Config.SafeSearchConf = conf
}() }()
d.Config.ConfigModified() d.Config.ConfigModified()

View File

@ -3,8 +3,10 @@ package home
import ( import (
"encoding" "encoding"
"fmt" "fmt"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
) )
@ -45,6 +47,23 @@ func (c *Client) closeUpstreams() (err error) {
return nil return nil
} }
// setSafeSearch initializes and sets the safe search filter for this client.
func (c *Client) setSafeSearch(
conf filtering.SafeSearchConfig,
cacheSize uint,
cacheTTL time.Duration,
) (err error) {
ss, err := safesearch.NewDefault(conf, fmt.Sprintf("client %q", c.Name), cacheSize, cacheTTL)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
c.SafeSearch = ss
return nil
}
// clientSource represents the source from which the information about the // clientSource represents the source from which the information about the
// client has been obtained. // client has been obtained.
type clientSource uint type clientSource uint

View File

@ -13,7 +13,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
@ -55,6 +54,14 @@ type clientsContainer struct {
// more detail. // more detail.
lock sync.Mutex lock sync.Mutex
// safeSearchCacheSize is the size of the safe search cache to use for
// persistent clients.
safeSearchCacheSize uint
// safeSearchCacheTTL is the TTL of the safe search cache to use for
// persistent clients.
safeSearchCacheTTL time.Duration
// testing is a flag that disables some features for internal tests. // testing is a flag that disables some features for internal tests.
// //
// TODO(a.garipov): Awful. Remove. // TODO(a.garipov): Awful. Remove.
@ -74,6 +81,7 @@ func (clients *clientsContainer) Init(
if clients.list != nil { if clients.list != nil {
log.Fatal("clients.list != nil") log.Fatal("clients.list != nil")
} }
clients.list = make(map[string]*Client) clients.list = make(map[string]*Client)
clients.idIndex = make(map[string]*Client) clients.idIndex = make(map[string]*Client)
clients.ipToRC = map[netip.Addr]*RuntimeClient{} clients.ipToRC = map[netip.Addr]*RuntimeClient{}
@ -85,6 +93,9 @@ func (clients *clientsContainer) Init(
clients.arpdb = arpdb clients.arpdb = arpdb
clients.addFromConfig(objects, filteringConf) clients.addFromConfig(objects, filteringConf)
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
if clients.testing { if clients.testing {
return return
} }
@ -171,18 +182,16 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
if o.SafeSearchConf.Enabled { if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{} o.SafeSearchConf.CustomResolver = safeSearchResolver{}
ss, err := safesearch.NewDefaultSafeSearch( err := cli.setSafeSearch(
o.SafeSearchConf, o.SafeSearchConf,
filteringConf.SafeSearchCacheSize, filteringConf.SafeSearchCacheSize,
time.Minute*time.Duration(filteringConf.CacheTime), time.Minute*time.Duration(filteringConf.CacheTime),
) )
if err != nil { if err != nil {
log.Error("clients: init client safesearch %s: %s", cli.Name, err) log.Error("clients: init client safesearch %q: %s", cli.Name, err)
continue continue
} }
cli.SafeSearch = ss
} }
for _, s := range o.BlockedServices { for _, s := range o.BlockedServices {

View File

@ -9,17 +9,27 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestClients(t *testing.T) { // newClientsContainer is a helper that creates a new clients container for
clients := clientsContainer{} // tests.
clients.testing = true func newClientsContainer() (c *clientsContainer) {
c = &clientsContainer{
testing: true,
}
clients.Init(nil, nil, nil, nil, nil) c.Init(nil, nil, nil, nil, &filtering.Config{})
return c
}
func TestClients(t *testing.T) {
clients := newClientsContainer()
t.Run("add_success", func(t *testing.T) { t.Run("add_success", func(t *testing.T) {
var ( var (
@ -198,10 +208,7 @@ func TestClients(t *testing.T) {
} }
func TestClientsWHOIS(t *testing.T) { func TestClientsWHOIS(t *testing.T) {
clients := clientsContainer{ clients := newClientsContainer()
testing: true,
}
clients.Init(nil, nil, nil, nil, nil)
whois := &RuntimeClientWHOISInfo{ whois := &RuntimeClientWHOISInfo{
Country: "AU", Country: "AU",
Orgname: "Example Org", Orgname: "Example Org",
@ -247,10 +254,7 @@ func TestClientsWHOIS(t *testing.T) {
} }
func TestClientsAddExisting(t *testing.T) { func TestClientsAddExisting(t *testing.T) {
clients := clientsContainer{ clients := newClientsContainer()
testing: true,
}
clients.Init(nil, nil, nil, nil, nil)
t.Run("simple", func(t *testing.T) { t.Run("simple", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1") ip := netip.MustParseAddr("1.1.1.1")
@ -325,10 +329,7 @@ func TestClientsAddExisting(t *testing.T) {
} }
func TestClientsCustomUpstream(t *testing.T) { func TestClientsCustomUpstream(t *testing.T) {
clients := clientsContainer{ clients := newClientsContainer()
testing: true,
}
clients.Init(nil, nil, nil, nil, nil)
// Add client with upstreams. // Add client with upstreams.
ok, err := clients.Add(&Client{ ok, err := clients.Add(&Client{

View File

@ -49,8 +49,8 @@ type clientJSON struct {
type runtimeClientJSON struct { type runtimeClientJSON struct {
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"` WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
Name string `json:"name"`
IP netip.Addr `json:"ip"` IP netip.Addr `json:"ip"`
Name string `json:"name"`
Source clientSource `json:"source"` Source clientSource `json:"source"`
} }
@ -90,14 +90,16 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
} }
// jsonToClient converts JSON object to Client object. // jsonToClient converts JSON object to Client object.
func jsonToClient(cj clientJSON) (c *Client) { func (clients *clientsContainer) jsonToClient(cj clientJSON) (c *Client, err error) {
var safeSearchConf filtering.SafeSearchConfig var safeSearchConf filtering.SafeSearchConfig
if cj.SafeSearchConf != nil { if cj.SafeSearchConf != nil {
safeSearchConf = *cj.SafeSearchConf safeSearchConf = *cj.SafeSearchConf
} else { } else {
// TODO(d.kolyshev): Remove after cleaning the deprecated // TODO(d.kolyshev): Remove after cleaning the deprecated
// [clientJSON.SafeSearchEnabled] field. // [clientJSON.SafeSearchEnabled] field.
safeSearchConf = filtering.SafeSearchConfig{Enabled: cj.SafeSearchEnabled} safeSearchConf = filtering.SafeSearchConfig{
Enabled: cj.SafeSearchEnabled,
}
// Set default service flags for enabled safesearch. // Set default service flags for enabled safesearch.
if safeSearchConf.Enabled { if safeSearchConf.Enabled {
@ -110,20 +112,35 @@ func jsonToClient(cj clientJSON) (c *Client) {
} }
} }
return &Client{ c = &Client{
Name: cj.Name, safeSearchConf: safeSearchConf,
IDs: cj.IDs,
Tags: cj.Tags, Name: cj.Name,
IDs: cj.IDs,
Tags: cj.Tags,
BlockedServices: cj.BlockedServices,
Upstreams: cj.Upstreams,
UseOwnSettings: !cj.UseGlobalSettings, UseOwnSettings: !cj.UseGlobalSettings,
FilteringEnabled: cj.FilteringEnabled, FilteringEnabled: cj.FilteringEnabled,
ParentalEnabled: cj.ParentalEnabled, ParentalEnabled: cj.ParentalEnabled,
SafeBrowsingEnabled: cj.SafeBrowsingEnabled, SafeBrowsingEnabled: cj.SafeBrowsingEnabled,
safeSearchConf: safeSearchConf,
UseOwnBlockedServices: !cj.UseGlobalBlockedServices, UseOwnBlockedServices: !cj.UseGlobalBlockedServices,
BlockedServices: cj.BlockedServices,
Upstreams: cj.Upstreams,
} }
if safeSearchConf.Enabled {
err = c.setSafeSearch(
safeSearchConf,
clients.safeSearchCacheSize,
clients.safeSearchCacheTTL,
)
if err != nil {
return nil, fmt.Errorf("creating safesearch for client %q: %w", c.Name, err)
}
}
return c, nil
} }
// clientToJSON converts Client object to JSON. // clientToJSON converts Client object to JSON.
@ -161,7 +178,13 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
return return
} }
c := jsonToClient(cj) c, err := clients.jsonToClient(cj)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
ok, err := clients.Add(c) ok, err := clients.Add(c)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@ -224,7 +247,13 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return return
} }
c := jsonToClient(dj.Data) c, err := clients.jsonToClient(dj.Data)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return
}
err = clients.Update(dj.Name, c) err = clients.Update(dj.Name, c)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)

View File

@ -545,6 +545,8 @@ var _ filtering.Resolver = safeSearchResolver{}
// LookupIP implements [filtering.Resolver] interface for safeSearchResolver. // LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
// It returns the slice of net.IP with IPv4 and IPv6 instances. // It returns the slice of net.IP with IPv4 and IPv6 instances.
//
// TODO(a.garipov): Support network.
func (r safeSearchResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) { func (r safeSearchResolver) LookupIP(_ context.Context, _, host string) (ips []net.IP, err error) {
addrs, err := Context.dnsServer.Resolve(host) addrs, err := Context.dnsServer.Resolve(host)
if err != nil { if err != nil {

View File

@ -297,8 +297,9 @@ func setupConfig(opts options) (err error) {
config.DNS.DnsfilterConf.HTTPClient = Context.client config.DNS.DnsfilterConf.HTTPClient = Context.client
config.DNS.DnsfilterConf.SafeSearchConf.CustomResolver = safeSearchResolver{} config.DNS.DnsfilterConf.SafeSearchConf.CustomResolver = safeSearchResolver{}
config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefaultSafeSearch( config.DNS.DnsfilterConf.SafeSearch, err = safesearch.NewDefault(
config.DNS.DnsfilterConf.SafeSearchConf, config.DNS.DnsfilterConf.SafeSearchConf,
"default",
config.DNS.DnsfilterConf.SafeSearchCacheSize, config.DNS.DnsfilterConf.SafeSearchCacheSize,
time.Minute*time.Duration(config.DNS.DnsfilterConf.CacheTime), time.Minute*time.Duration(config.DNS.DnsfilterConf.CacheTime),
) )
@ -869,8 +870,10 @@ func detectFirstRun() bool {
// Connect to a remote server resolving hostname using our own DNS server. // Connect to a remote server resolving hostname using our own DNS server.
// //
// TODO(e.burkov): This messy logic should be decomposed and clarified. // TODO(e.burkov): This messy logic should be decomposed and clarified.
//
// TODO(a.garipov): Support network.
func customDialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { func customDialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
log.Tracef("network:%v addr:%v", network, addr) log.Debug("home: customdial: dialing addr %q for network %s", addr, network)
host, port, err := net.SplitHostPort(addr) host, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {