all: sync with master

This commit is contained in:
Ainar Garipov 2022-08-17 18:23:30 +03:00
parent 3c17853344
commit 9200163f85
25 changed files with 740 additions and 326 deletions

View File

@ -222,6 +222,7 @@
"updated_upstream_dns_toast": "Upstream-servere er gemt", "updated_upstream_dns_toast": "Upstream-servere er gemt",
"dns_test_ok_toast": "Angivne DNS-servere fungerer korrekt", "dns_test_ok_toast": "Angivne DNS-servere fungerer korrekt",
"dns_test_not_ok_toast": "Server \"{{key}}\": Kunne ikke bruges. Tjek, at du har angivet den korrekt", "dns_test_not_ok_toast": "Server \"{{key}}\": Kunne ikke bruges. Tjek, at du har angivet den korrekt",
"dns_test_warning_toast": "Upstream \"{{key}}\" svarer ikke på testforespørgsler og fungerer muligvis ikke korrekt",
"unblock": "Afblokering", "unblock": "Afblokering",
"block": "Blokering", "block": "Blokering",
"disallow_this_client": "Afvis denne klient", "disallow_this_client": "Afvis denne klient",

View File

@ -47,7 +47,7 @@
"form_error_server_name": "Nombre de servidor no válido", "form_error_server_name": "Nombre de servidor no válido",
"form_error_subnet": "La subred \"{{cidr}}\" no contiene la dirección IP \"{{ip}}\"", "form_error_subnet": "La subred \"{{cidr}}\" no contiene la dirección IP \"{{ip}}\"",
"form_error_positive": "Debe ser mayor que 0", "form_error_positive": "Debe ser mayor que 0",
"form_error_gateway_ip": "Asignación no puede tener la dirección IP de la puerta de enlace", "form_error_gateway_ip": "La asignación no puede tener la dirección IP de la puerta de enlace",
"out_of_range_error": "Debe estar fuera del rango \"{{start}}\"-\"{{end}}\"", "out_of_range_error": "Debe estar fuera del rango \"{{start}}\"-\"{{end}}\"",
"lower_range_start_error": "Debe ser inferior que el inicio de rango", "lower_range_start_error": "Debe ser inferior que el inicio de rango",
"greater_range_start_error": "Debe ser mayor que el inicio de rango", "greater_range_start_error": "Debe ser mayor que el inicio de rango",
@ -222,7 +222,7 @@
"updated_upstream_dns_toast": "Servidores DNS de subida guardados correctamente", "updated_upstream_dns_toast": "Servidores DNS de subida guardados correctamente",
"dns_test_ok_toast": "Los servidores DNS especificados funcionan correctamente", "dns_test_ok_toast": "Los servidores DNS especificados funcionan correctamente",
"dns_test_not_ok_toast": "Servidor \"{{key}}\": no se puede utilizar, por favor revisa si lo has escrito correctamente", "dns_test_not_ok_toast": "Servidor \"{{key}}\": no se puede utilizar, por favor revisa si lo has escrito correctamente",
"dns_test_warning_toast": "Upstream \"{{key}}\" no responde a las peticiones de prueba y es posible que no funcione correctamente", "dns_test_warning_toast": "DNS de subida \"{{key}}\" no responde a las peticiones de prueba y es posible que no funcione correctamente",
"unblock": "Desbloquear", "unblock": "Desbloquear",
"block": "Bloquear", "block": "Bloquear",
"disallow_this_client": "No permitir a este cliente", "disallow_this_client": "No permitir a este cliente",
@ -364,7 +364,7 @@
"encryption_config_saved": "Configuración de cifrado guardado", "encryption_config_saved": "Configuración de cifrado guardado",
"encryption_server": "Nombre del servidor", "encryption_server": "Nombre del servidor",
"encryption_server_enter": "Ingresa el nombre del dominio", "encryption_server_enter": "Ingresa el nombre del dominio",
"encryption_server_desc": "Si se configura, AdGuard Home detecta los ClientID, responde a las consultas DDR y realiza validaciones de conexión adicionales. Si no se configura, estas funciones están deshabilitadas. Debe coincidir con uno de los nombres DNS del certificado.", "encryption_server_desc": "Si se configura, AdGuard Home detecta los ID de clientes, responde a las consultas DDR y realiza validaciones de conexión adicionales. Si no se configura, estas funciones se deshabilitarán. Debe coincidir con uno de los nombres DNS del certificado.",
"encryption_redirect": "Redireccionar a HTTPS automáticamente", "encryption_redirect": "Redireccionar a HTTPS automáticamente",
"encryption_redirect_desc": "Si está marcado, AdGuard Home redireccionará automáticamente de HTTP a las direcciones HTTPS.", "encryption_redirect_desc": "Si está marcado, AdGuard Home redireccionará automáticamente de HTTP a las direcciones HTTPS.",
"encryption_https": "Puerto HTTPS", "encryption_https": "Puerto HTTPS",

View File

@ -10,6 +10,20 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
// Coalesce returns the first non-zero value. It is named after the function
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
// value.
func Coalesce[T comparable](values ...T) (res T) {
var zero T
for _, v := range values {
if v != zero {
return v
}
}
return zero
}
// UniqChecker allows validating uniqueness of comparable items. // UniqChecker allows validating uniqueness of comparable items.
// //
// TODO(a.garipov): The Ordered constraint is only really necessary in Validate. // TODO(a.garipov): The Ordered constraint is only really necessary in Validate.

View File

@ -470,7 +470,7 @@ func TestHostsContainer(t *testing.T) {
}}, }},
}, { }, {
req: &urlfilter.DNSRequest{ req: &urlfilter.DNSRequest{
Hostname: "nonexisting", Hostname: "nonexistent.example",
DNSType: dns.TypeA, DNSType: dns.TypeA,
}, },
name: "non-existing", name: "non-existing",

View File

@ -154,10 +154,13 @@ func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) {
return netIfaces, nil return netIfaces, nil
} }
// GetInterfaceByIP returns the name of interface containing provided ip. // InterfaceByIP returns the name of the interface bound to ip.
// //
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb. // TODO(a.garipov, e.burkov): This function is technically incorrect, since one
func GetInterfaceByIP(ip net.IP) string { // IP address can be shared by multiple interfaces in some configurations.
//
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func InterfaceByIP(ip net.IP) (ifaceName string) {
ifaces, err := GetValidNetInterfacesForWeb() ifaces, err := GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
return "" return ""
@ -177,7 +180,7 @@ func GetInterfaceByIP(ip net.IP) string {
// GetSubnet returns pointer to net.IPNet for the specified interface or nil if // GetSubnet returns pointer to net.IPNet for the specified interface or nil if
// the search fails. // the search fails.
// //
// TODO(e.burkov): See TODO on GetValidInterfacesForWeb. // TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func GetSubnet(ifaceName string) *net.IPNet { func GetSubnet(ifaceName string) *net.IPNet {
netIfaces, err := GetValidNetInterfacesForWeb() netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil { if err != nil {

View File

@ -132,7 +132,7 @@ func TestGatewayIP(t *testing.T) {
} }
} }
func TestGetInterfaceByIP(t *testing.T) { func TestInterfaceByIP(t *testing.T) {
ifaces, err := GetValidNetInterfacesForWeb() ifaces, err := GetValidNetInterfacesForWeb()
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, ifaces) require.NotEmpty(t, ifaces)
@ -142,7 +142,7 @@ func TestGetInterfaceByIP(t *testing.T) {
require.NotEmpty(t, iface.Addresses) require.NotEmpty(t, iface.Addresses)
for _, ip := range iface.Addresses { for _, ip := range iface.Addresses {
ifaceName := GetInterfaceByIP(ip) ifaceName := InterfaceByIP(ip)
require.Equal(t, iface.Name, ifaceName) require.Equal(t, iface.Name, ifaceName)
} }
}) })

View File

@ -1,4 +1,4 @@
package aghos package aghos_test
import ( import (
"testing" "testing"

View File

@ -0,0 +1,57 @@
package aghos
import (
"io/fs"
"path"
"testing"
"testing/fstest"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// errFS is an fs.FS implementation, method Open of which always returns
// errFSOpen.
type errFS struct{}
// errFSOpen is returned from errGlobFS.Open.
const errFSOpen errors.Error = "test open error"
// Open implements the fs.FS interface for *errGlobFS. fsys is always nil and
// err is always errFSOpen.
func (efs *errFS) Open(name string) (fsys fs.File, err error) {
return nil, errFSOpen
}
func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}
t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)
assert.True(t, ok)
})
t.Run("invalid_argument", func(t *testing.T) {
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errFSOpen)
assert.False(t, ok)
})
t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"
testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}
patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)
assert.Empty(t, patterns)
assert.True(t, ok)
})
}

View File

@ -1,13 +1,13 @@
package aghos package aghos_test
import ( import (
"bufio" "bufio"
"io" "io"
"io/fs"
"path" "path"
"testing" "testing"
"testing/fstest" "testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -16,7 +16,7 @@ import (
func TestFileWalker_Walk(t *testing.T) { func TestFileWalker_Walk(t *testing.T) {
const attribute = `000` const attribute = `000`
makeFileWalker := func(_ string) (fw FileWalker) { makeFileWalker := func(_ string) (fw aghos.FileWalker) {
return func(r io.Reader) (patterns []string, cont bool, err error) { return func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
for s.Scan() { for s.Scan() {
@ -113,7 +113,7 @@ func TestFileWalker_Walk(t *testing.T) {
f := fstest.MapFS{ f := fstest.MapFS{
filename: &fstest.MapFile{Data: []byte("[]")}, filename: &fstest.MapFile{Data: []byte("[]")},
} }
ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) { ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) {
s := bufio.NewScanner(r) s := bufio.NewScanner(r)
for s.Scan() { for s.Scan() {
patterns = append(patterns, s.Text()) patterns = append(patterns, s.Text())
@ -134,7 +134,7 @@ func TestFileWalker_Walk(t *testing.T) {
"mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)}, "mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)},
} }
ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) { ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) {
return nil, true, rerr return nil, true, rerr
}).Walk(f, "*") }).Walk(f, "*")
require.ErrorIs(t, err, rerr) require.ErrorIs(t, err, rerr)
@ -142,45 +142,3 @@ func TestFileWalker_Walk(t *testing.T) {
assert.False(t, ok) assert.False(t, ok)
}) })
} }
type errFS struct {
fs.GlobFS
}
const errErrFSOpen errors.Error = "this error is always returned"
func (efs *errFS) Open(name string) (fs.File, error) {
return nil, errErrFSOpen
}
func TestWalkerFunc_CheckFile(t *testing.T) {
emptyFS := fstest.MapFS{}
t.Run("non-existing", func(t *testing.T) {
_, ok, err := checkFile(emptyFS, nil, "lol")
require.NoError(t, err)
assert.True(t, ok)
})
t.Run("invalid_argument", func(t *testing.T) {
_, ok, err := checkFile(&errFS{}, nil, "")
require.ErrorIs(t, err, errErrFSOpen)
assert.False(t, ok)
})
t.Run("ignore_dirs", func(t *testing.T) {
const dirName = "dir"
testFS := fstest.MapFS{
path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}},
}
patterns, ok, err := checkFile(testFS, nil, dirName)
require.NoError(t, err)
assert.Empty(t, patterns)
assert.True(t, ok)
})
}

View File

@ -1,20 +0,0 @@
package aghtest
import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
// Exchanger is a mock aghnet.Exchanger implementation for tests.
type Exchanger struct {
Ups upstream.Upstream
}
// Exchange implements aghnet.Exchanger interface for *Exchanger.
func (e *Exchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
if e.Ups == nil {
e.Ups = &TestErrUpstream{}
}
return e.Ups.Exchange(req)
}

View File

@ -1,23 +0,0 @@
package aghtest
// FSWatcher is a mock aghos.FSWatcher implementation to use in tests.
type FSWatcher struct {
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// Events implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
}
// Add implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the aghos.FSWatcher interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}

View File

@ -0,0 +1,135 @@
package aghtest
import (
"io/fs"
"net"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)
// Interface Mocks
//
// Keep entities in this file in alphabetic order.
// Standard Library
// type check
var _ fs.FS = &FS{}
// FS is a mock [fs.FS] implementation for tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// Open implements the [fs.FS] interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.GlobFS = &GlobFS{}
// GlobFS is a mock [fs.GlobFS] implementation for tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the [fs.GlobFS] interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}
// type check
var _ fs.StatFS = &StatFS{}
// StatFS is a mock [fs.StatFS] implementation for tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the [fs.StatFS] interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// type check
var _ net.Listener = (*Listener)(nil)
// Listener is a mock [net.Listener] implementation for tests.
type Listener struct {
OnAccept func() (conn net.Conn, err error)
OnAddr func() (addr net.Addr)
OnClose func() (err error)
}
// Accept implements the [net.Listener] interface for *Listener.
func (l *Listener) Accept() (conn net.Conn, err error) {
return l.OnAccept()
}
// Addr implements the [net.Listener] interface for *Listener.
func (l *Listener) Addr() (addr net.Addr) {
return l.OnAddr()
}
// Close implements the [net.Listener] interface for *Listener.
func (l *Listener) Close() (err error) {
return l.OnClose()
}
// Module dnsproxy
// type check
var _ upstream.Upstream = (*UpstreamMock)(nil)
// UpstreamMock is a mock [upstream.Upstream] implementation for tests.
//
// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and
// rename it to just Upstream.
type UpstreamMock struct {
OnAddress func() (addr string)
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
}
// Address implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Address() (addr string) {
return u.OnAddress()
}
// Exchange implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
return u.OnExchange(req)
}
// Module AdGuardHome
// type check
var _ aghos.FSWatcher = (*FSWatcher)(nil)
// FSWatcher is a mock [aghos.FSWatcher] implementation for tests.
type FSWatcher struct {
OnEvents func() (e <-chan struct{})
OnAdd func(name string) (err error)
OnClose func() (err error)
}
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Events() (e <-chan struct{}) {
return w.OnEvents()
}
// Add implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Add(name string) (err error) {
return w.OnAdd(name)
}
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
func (w *FSWatcher) Close() (err error) {
return w.OnClose()
}

View File

@ -0,0 +1,9 @@
package aghtest_test
import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
)
// type check
var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil)

View File

@ -1,46 +0,0 @@
package aghtest
import "io/fs"
// type check
var _ fs.FS = &FS{}
// FS is a mock fs.FS implementation to use in tests.
type FS struct {
OnOpen func(name string) (fs.File, error)
}
// Open implements the fs.FS interface for *FS.
func (fsys *FS) Open(name string) (fs.File, error) {
return fsys.OnOpen(name)
}
// type check
var _ fs.StatFS = &StatFS{}
// StatFS is a mock fs.StatFS implementation to use in tests.
type StatFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnStat func(name string) (fs.FileInfo, error)
}
// Stat implements the fs.StatFS interface for *StatFS.
func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name)
}
// type check
var _ fs.GlobFS = &GlobFS{}
// GlobFS is a mock fs.GlobFS implementation to use in tests.
type GlobFS struct {
// FS is embedded here to avoid implementing all it's methods.
FS
OnGlob func(pattern string) ([]string, error)
}
// Glob implements the fs.GlobFS interface for *GlobFS.
func (fsys *GlobFS) Glob(pattern string) ([]string, error) {
return fsys.OnGlob(pattern)
}

View File

@ -6,12 +6,18 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync" "testing"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/require"
) )
// Additional Upstream Testing Utilities
// Upstream is a mock implementation of upstream.Upstream. // Upstream is a mock implementation of upstream.Upstream.
//
// TODO(a.garipov): Replace with UpstreamMock and rename it to just Upstream.
type Upstream struct { type Upstream struct {
// CName is a map of hostname to canonical name. // CName is a map of hostname to canonical name.
CName map[string][]string CName map[string][]string
@ -25,6 +31,43 @@ type Upstream struct {
Addr string Addr string
} }
// RespondTo returns a response with answer if req has class cl, question type
// qt, and target targ.
func RespondTo(t testing.TB, req *dns.Msg, cl, qt uint16, targ, answer string) (resp *dns.Msg) {
t.Helper()
require.NotNil(t, req)
require.Len(t, req.Question, 1)
q := req.Question[0]
targ = dns.Fqdn(targ)
if q.Qclass != cl || q.Qtype != qt || q.Name != targ {
return nil
}
respHdr := dns.RR_Header{
Name: targ,
Rrtype: qt,
Class: cl,
Ttl: 60,
}
resp = new(dns.Msg).SetReply(req)
switch qt {
case dns.TypePTR:
resp.Answer = []dns.RR{
&dns.PTR{
Hdr: respHdr,
Ptr: answer,
},
}
default:
t.Fatalf("unsupported question type: %s", dns.Type(qt))
}
return resp
}
// Exchange implements the upstream.Upstream interface for *Upstream. // Exchange implements the upstream.Upstream interface for *Upstream.
// //
// TODO(a.garipov): Split further into handlers. // TODO(a.garipov): Split further into handlers.
@ -76,74 +119,57 @@ func (u *Upstream) Address() string {
return u.Addr return u.Addr
} }
// TestBlockUpstream implements upstream.Upstream interface for replacing real // NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that
// upstream in tests. // supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is
type TestBlockUpstream struct { // true, hostname's actual hash is returned, blocking it. Otherwise, it returns
Hostname string // a different hash.
func NewBlockUpstream(hostname string, shouldBlock bool) (u *UpstreamMock) {
// lock protects reqNum. hash := sha256.Sum256([]byte(hostname))
lock sync.RWMutex hashStr := hex.EncodeToString(hash[:])
reqNum int if !shouldBlock {
hashStr = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
Block bool
}
// Exchange returns a message unique for TestBlockUpstream's Hostname-Block
// pair.
func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) {
u.lock.Lock()
defer u.lock.Unlock()
u.reqNum++
hash := sha256.Sum256([]byte(u.Hostname))
hashToReturn := hex.EncodeToString(hash[:])
if !u.Block {
hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28)
} }
m := &dns.Msg{} ans := &dns.TXT{
m.SetReply(r) Hdr: dns.RR_Header{
m.Answer = []dns.RR{ Name: "",
&dns.TXT{ Rrtype: dns.TypeTXT,
Hdr: dns.RR_Header{ Class: dns.ClassINET,
Name: r.Question[0].Name, Ttl: 60,
}, },
Txt: []string{ Txt: []string{hashStr},
hashToReturn, }
}, respTmpl := &dns.Msg{
Answer: []dns.RR{ans},
}
return &UpstreamMock{
OnAddress: func() (addr string) {
return "sbpc.upstream.example"
},
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = respTmpl.Copy()
resp.SetReply(req)
resp.Answer[0].(*dns.TXT).Hdr.Name = req.Question[0].Name
return resp, nil
}, },
} }
return m, nil
} }
// Address always returns an empty string. // ErrUpstream is the error returned from the [*UpstreamMock] created by
func (u *TestBlockUpstream) Address() string { // [NewErrorUpstream].
return "" const ErrUpstream errors.Error = "test upstream error"
}
// RequestsCount returns the number of handled requests. It's safe for // NewErrorUpstream returns an [*UpstreamMock] that returns [ErrUpstream] from
// concurrent use. // its Exchange method.
func (u *TestBlockUpstream) RequestsCount() int { func NewErrorUpstream() (u *UpstreamMock) {
u.lock.Lock() return &UpstreamMock{
defer u.lock.Unlock() OnAddress: func() (addr string) {
return "error.upstream.example"
return u.reqNum },
} OnExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) {
return nil, errors.Error("test upstream error")
// TestErrUpstream implements upstream.Upstream interface for replacing real },
// upstream in tests. }
type TestErrUpstream struct {
// The error returned by Exchange may be unwrapped to the Err.
Err error
}
// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
return nil, fmt.Errorf("errupstream: %w", u.Err)
}
// Address always returns an empty string.
func (u *TestErrUpstream) Address() string {
return ""
} }

View File

@ -121,6 +121,7 @@ type FilteringConfig struct {
EnableDNSSEC bool `yaml:"enable_dnssec"` // Set AD flag in outcoming DNS request EnableDNSSEC bool `yaml:"enable_dnssec"` // Set AD flag in outcoming DNS request
EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option
MaxGoroutines uint32 `yaml:"max_goroutines"` // Max. number of parallel goroutines for processing incoming requests MaxGoroutines uint32 `yaml:"max_goroutines"` // Max. number of parallel goroutines for processing incoming requests
HandleDDR bool `yaml:"handle_ddr"` // Handle DDR requests
// IpsetList is the ipset configuration that allows AdGuard Home to add // IpsetList is the ipset configuration that allows AdGuard Home to add
// IP addresses of the specified domain names to an ipset list. Syntax: // IP addresses of the specified domain names to an ipset list. Syntax:
@ -151,7 +152,7 @@ type TLSConfig struct {
PrivateKeyData []byte `yaml:"-" json:"-"` PrivateKeyData []byte `yaml:"-" json:"-"`
// ServerName is the hostname of the server. Currently, it is only being // ServerName is the hostname of the server. Currently, it is only being
// used for ClientID checking. // used for ClientID checking and Discovery of Designated Resolvers (DDR).
ServerName string `yaml:"-" json:"-"` ServerName string `yaml:"-" json:"-"`
cert tls.Certificate cert tls.Certificate

View File

@ -76,6 +76,10 @@ const (
resultCodeError resultCodeError
) )
// ddrHostFQDN is the FQDN used in Discovery of Designated Resolvers (DDR) requests.
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
const ddrHostFQDN = "_dns.resolver.arpa."
// handleDNSRequest filters the incoming DNS requests and writes them to the query log // handleDNSRequest filters the incoming DNS requests and writes them to the query log
func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
ctx := &dnsContext{ ctx := &dnsContext{
@ -94,6 +98,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
mods := []modProcessFunc{ mods := []modProcessFunc{
s.processRecursion, s.processRecursion,
s.processInitial, s.processInitial,
s.processDDRQuery,
s.processDetermineLocal, s.processDetermineLocal,
s.processDHCPHosts, s.processDHCPHosts,
s.processRestrictLocal, s.processRestrictLocal,
@ -239,6 +244,98 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
s.setTableIPToHost(ipToHost) s.setTableIPToHost(ipToHost)
} }
// processDDRQuery responds to SVCB query for a special use domain name
// _dns.resolver.arpa. The result contains different types of encryption
// supported by current user configuration.
//
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
func (s *Server) processDDRQuery(ctx *dnsContext) (rc resultCode) {
d := ctx.proxyCtx
question := d.Req.Question[0]
if !s.conf.HandleDDR {
return resultCodeSuccess
}
if question.Name == ddrHostFQDN {
if s.dnsProxy.TLSListenAddr == nil && s.conf.HTTPSListenAddrs == nil &&
s.dnsProxy.QUICListenAddr == nil || question.Qtype != dns.TypeSVCB {
d.Res = s.makeResponse(d.Req)
return resultCodeFinish
}
d.Res = s.makeDDRResponse(d.Req)
return resultCodeFinish
}
return resultCodeSuccess
}
// makeDDRResponse creates DDR answer according to server configuration. The
// contructed SVCB resource records have the priority of 1 for each entry,
// similar to examples provided by https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html.
//
// TODO(a.meshkov): Consider setting the priority values based on the protocol.
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
resp = s.makeResponse(req)
// TODO(e.burkov): Think about storing the FQDN version of the server's
// name somewhere.
domainName := dns.Fqdn(s.conf.ServerName)
for _, addr := range s.conf.HTTPSListenAddrs {
values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"h2"}},
&dns.SVCBPort{Port: uint16(addr.Port)},
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
}
ans := &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: 1,
Target: domainName,
Value: values,
}
resp.Answer = append(resp.Answer, ans)
}
for _, addr := range s.dnsProxy.TLSListenAddr {
values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"dot"}},
&dns.SVCBPort{Port: uint16(addr.Port)},
}
ans := &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: 1,
Target: domainName,
Value: values,
}
resp.Answer = append(resp.Answer, ans)
}
for _, addr := range s.dnsProxy.QUICListenAddr {
values := []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"doq"}},
&dns.SVCBPort{Port: uint16(addr.Port)},
}
ans := &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: 1,
Target: domainName,
Value: values,
}
resp.Answer = append(resp.Answer, ans)
}
return resp
}
// processDetermineLocal determines if the client's IP address is from // processDetermineLocal determines if the client's IP address is from
// locally-served network and saves the result into the context. // locally-served network and saves the result into the context.
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) { func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {

View File

@ -14,6 +14,177 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
const (
ddrTestDomainName = "dns.example.net"
ddrTestFQDN = ddrTestDomainName + "."
)
func TestServer_ProcessDDRQuery(t *testing.T) {
dohSVCB := &dns.SVCB{
Priority: 1,
Target: ddrTestFQDN,
Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"h2"}},
&dns.SVCBPort{Port: 8044},
&dns.SVCBDoHPath{Template: "/dns-query?dns"},
},
}
dotSVCB := &dns.SVCB{
Priority: 1,
Target: ddrTestFQDN,
Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"dot"}},
&dns.SVCBPort{Port: 8043},
},
}
doqSVCB := &dns.SVCB{
Priority: 1,
Target: ddrTestFQDN,
Value: []dns.SVCBKeyValue{
&dns.SVCBAlpn{Alpn: []string{"doq"}},
&dns.SVCBPort{Port: 8042},
},
}
testCases := []struct {
name string
host string
want []*dns.SVCB
wantRes resultCode
portDoH int
portDoT int
portDoQ int
qtype uint16
ddrEnabled bool
}{{
name: "pass_host",
wantRes: resultCodeSuccess,
host: "example.net.",
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoH: 8043,
}, {
name: "pass_qtype",
wantRes: resultCodeFinish,
host: ddrHostFQDN,
qtype: dns.TypeA,
ddrEnabled: true,
portDoH: 8043,
}, {
name: "pass_disabled_tls",
wantRes: resultCodeFinish,
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
}, {
name: "pass_disabled_ddr",
wantRes: resultCodeSuccess,
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: false,
portDoH: 8043,
}, {
name: "dot",
wantRes: resultCodeFinish,
want: []*dns.SVCB{dotSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoT: 8043,
}, {
name: "doh",
wantRes: resultCodeFinish,
want: []*dns.SVCB{dohSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoH: 8044,
}, {
name: "doq",
wantRes: resultCodeFinish,
want: []*dns.SVCB{doqSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoQ: 8042,
}, {
name: "dot_doh",
wantRes: resultCodeFinish,
want: []*dns.SVCB{dotSVCB, dohSVCB},
host: ddrHostFQDN,
qtype: dns.TypeSVCB,
ddrEnabled: true,
portDoT: 8043,
portDoH: 8044,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled)
req := createTestMessageWithType(tc.host, tc.qtype)
dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{
Req: req,
},
}
res := s.processDDRQuery(dctx)
require.Equal(t, tc.wantRes, res)
if tc.wantRes != resultCodeFinish {
return
}
msg := dctx.proxyCtx.Res
require.NotNil(t, msg)
for _, v := range tc.want {
v.Hdr = s.hdr(req, dns.TypeSVCB)
}
assert.ElementsMatch(t, tc.want, msg.Answer)
})
}
}
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
t.Helper()
proxyConf := proxy.Config{}
if portDoT > 0 {
proxyConf.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
}
if portDoQ > 0 {
proxyConf.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
}
s = &Server{
dnsProxy: &proxy.Proxy{
Config: proxyConf,
},
conf: ServerConfig{
FilteringConfig: FilteringConfig{
HandleDDR: ddrEnabled,
},
TLSConfig: TLSConfig{
ServerName: ddrTestDomainName,
},
},
}
if portDoH > 0 {
s.conf.TLSConfig.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
}
return s
}
func TestServer_ProcessDetermineLocal(t *testing.T) { func TestServer_ProcessDetermineLocal(t *testing.T) {
s := &Server{ s := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),

View File

@ -17,13 +17,13 @@ import (
"testing/fstest" "testing/fstest"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
@ -853,10 +853,7 @@ func TestBlockedByHosts(t *testing.T) {
func TestBlockedBySafeBrowsing(t *testing.T) { func TestBlockedBySafeBrowsing(t *testing.T) {
const hostname = "wmconvirus.narod.ru" const hostname = "wmconvirus.narod.ru"
sbUps := &aghtest.TestBlockUpstream{ sbUps := aghtest.NewBlockUpstream(hostname, true)
Hostname: hostname,
Block: true,
}
ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname) ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname)
filterConf := &filtering.Config{ filterConf := &filtering.Config{
@ -1029,7 +1026,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true s.conf.ProtectionEnabled = true
err = s.Prepare(nil) err = s.Prepare(nil)
require.NoError(t, err) require.NoError(t, err)
@ -1177,25 +1174,48 @@ func TestNewServer(t *testing.T) {
} }
func TestServer_Exchange(t *testing.T) { func TestServer_Exchange(t *testing.T) {
extUpstream := &aghtest.Upstream{ const (
Reverse: map[string][]string{ onesHost = "one.one.one.one"
"1.1.1.1.in-addr.arpa.": {"one.one.one.one"}, localDomainHost = "local.domain"
)
var (
onesIP = net.IP{1, 1, 1, 1}
localIP = net.IP{192, 168, 1, 1}
)
revExtIPv4, err := netutil.IPToReversedAddr(onesIP)
require.NoError(t, err)
extUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "external.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revExtIPv4, onesHost),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
}, },
} }
locUpstream := &aghtest.Upstream{
Reverse: map[string][]string{ revLocIPv4, err := netutil.IPToReversedAddr(localIP)
"1.1.168.192.in-addr.arpa.": {"local.domain"}, require.NoError(t, err)
"2.1.168.192.in-addr.arpa.": {},
locUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "local.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revLocIPv4, localDomainHost),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
}, },
} }
upstreamErr := errors.Error("upstream error")
errUpstream := &aghtest.TestErrUpstream{ errUpstream := aghtest.NewErrorUpstream()
Err: upstreamErr, nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
}
nonPtrUpstream := &aghtest.TestBlockUpstream{
Hostname: "some-host",
Block: true,
}
srv := NewCustomServer(&proxy.Proxy{ srv := NewCustomServer(&proxy.Proxy{
Config: proxy.Config{ Config: proxy.Config{
@ -1209,7 +1229,6 @@ func TestServer_Exchange(t *testing.T) {
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
localIP := net.IP{192, 168, 1, 1}
testCases := []struct { testCases := []struct {
name string name string
want string want string
@ -1218,20 +1237,20 @@ func TestServer_Exchange(t *testing.T) {
req net.IP req net.IP
}{{ }{{
name: "external_good", name: "external_good",
want: "one.one.one.one", want: onesHost,
wantErr: nil, wantErr: nil,
locUpstream: nil, locUpstream: nil,
req: net.IP{1, 1, 1, 1}, req: onesIP,
}, { }, {
name: "local_good", name: "local_good",
want: "local.domain", want: localDomainHost,
wantErr: nil, wantErr: nil,
locUpstream: locUpstream, locUpstream: locUpstream,
req: localIP, req: localIP,
}, { }, {
name: "upstream_error", name: "upstream_error",
want: "", want: "",
wantErr: upstreamErr, wantErr: aghtest.ErrUpstream,
locUpstream: errUpstream, locUpstream: errUpstream,
req: localIP, req: localIP,
}, { }, {

View File

@ -21,6 +21,11 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
const (
sbBlocked = "wmconvirus.narod.ru"
pcBlocked = "pornhub.com"
)
var setts = Settings{ var setts = Settings{
ProtectionEnabled: true, ProtectionEnabled: true,
} }
@ -173,43 +178,37 @@ func TestSafeBrowsing(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})
d.checkMatch(t, matching)
require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching) d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
d.checkMatch(t, sbBlocked)
d.checkMatch(t, "test."+matching) require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked))
d.checkMatch(t, "test."+sbBlocked)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "pornhub.com") d.checkMatchEmpty(t, pcBlocked)
// Cached result. // Cached result.
d.safeBrowsingServer = "127.0.0.1" d.safeBrowsingServer = "127.0.0.1"
d.checkMatch(t, matching) d.checkMatch(t, sbBlocked)
d.checkMatchEmpty(t, "pornhub.com") d.checkMatchEmpty(t, pcBlocked)
d.safeBrowsingServer = defaultSafebrowsingServer d.safeBrowsingServer = defaultSafebrowsingServer
} }
func TestParallelSB(t *testing.T) { func TestParallelSB(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
const matching = "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
Hostname: matching,
Block: true,
})
t.Run("group", func(t *testing.T) { t.Run("group", func(t *testing.T) {
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
t.Parallel() t.Parallel()
d.checkMatch(t, matching) d.checkMatch(t, sbBlocked)
d.checkMatch(t, "test."+matching) d.checkMatch(t, "test."+sbBlocked)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "pornhub.com") d.checkMatchEmpty(t, pcBlocked)
}) })
} }
}) })
@ -382,23 +381,19 @@ func TestParentalControl(t *testing.T) {
d := newForTest(t, &Config{ParentalEnabled: true}, nil) d := newForTest(t, &Config{ParentalEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
const matching = "pornhub.com"
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
Hostname: matching,
Block: true,
})
d.checkMatch(t, matching) d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
require.Contains(t, logOutput.String(), "Parental lookup for "+matching) d.checkMatch(t, pcBlocked)
require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked))
d.checkMatch(t, "www."+matching) d.checkMatch(t, "www."+pcBlocked)
d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "www.yandex.ru")
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
d.checkMatchEmpty(t, "api.jquery.com") d.checkMatchEmpty(t, "api.jquery.com")
// Test cached result. // Test cached result.
d.parentalServer = "127.0.0.1" d.parentalServer = "127.0.0.1"
d.checkMatch(t, matching) d.checkMatch(t, pcBlocked)
d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "yandex.ru")
} }
@ -445,7 +440,7 @@ func TestMatching(t *testing.T) {
}, { }, {
name: "sanity", name: "sanity",
rules: "||doubleclick.net^", rules: "||doubleclick.net^",
host: "wmconvirus.narod.ru", host: sbBlocked,
wantIsFiltered: false, wantIsFiltered: false,
wantReason: NotFilteredNotFound, wantReason: NotFilteredNotFound,
wantDNSType: dns.TypeA, wantDNSType: dns.TypeA,
@ -765,14 +760,9 @@ func TestClientSettings(t *testing.T) {
}}, }},
) )
t.Cleanup(d.Close) t.Cleanup(d.Close)
d.SetParentalUpstream(&aghtest.TestBlockUpstream{
Hostname: "pornhub.com", d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true))
Block: true, d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
})
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{
Hostname: "wmconvirus.narod.ru",
Block: true,
})
type testCase struct { type testCase struct {
name string name string
@ -787,12 +777,12 @@ func TestClientSettings(t *testing.T) {
wantReason: FilteredBlockList, wantReason: FilteredBlockList,
}, { }, {
name: "parental", name: "parental",
host: "pornhub.com", host: pcBlocked,
before: true, before: true,
wantReason: FilteredParental, wantReason: FilteredParental,
}, { }, {
name: "safebrowsing", name: "safebrowsing",
host: "wmconvirus.narod.ru", host: sbBlocked,
before: false, before: false,
wantReason: FilteredSafeBrowsing, wantReason: FilteredSafeBrowsing,
}, { }, {
@ -836,33 +826,29 @@ func TestClientSettings(t *testing.T) {
func BenchmarkSafeBrowsing(b *testing.B) { func BenchmarkSafeBrowsing(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
Hostname: blocked,
Block: true,
})
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
res, err := d.CheckHost(blocked, dns.TypeA, &setts) res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
require.NoError(b, err) require.NoError(b, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
} }
} }
func BenchmarkSafeBrowsingParallel(b *testing.B) { func BenchmarkSafeBrowsingParallel(b *testing.B) {
d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
blocked := "wmconvirus.narod.ru"
d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true))
Hostname: blocked,
Block: true,
})
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
res, err := d.CheckHost(blocked, dns.TypeA, &setts) res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts)
require.NoError(b, err) require.NoError(b, err)
assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked)
} }
}) })
} }

View File

@ -314,7 +314,7 @@ func (d *DNSFilter) checkSafeBrowsing(
if log.GetLevel() >= log.DEBUG { if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer() timer := log.StartTimer()
defer timer.LogElapsed("SafeBrowsing lookup for %s", host) defer timer.LogElapsed("safebrowsing lookup for %q", host)
} }
sctx := &sbCtx{ sctx := &sbCtx{
@ -348,7 +348,7 @@ func (d *DNSFilter) checkParental(
if log.GetLevel() >= log.DEBUG { if log.GetLevel() >= log.DEBUG {
timer := log.StartTimer() timer := log.StartTimer()
defer timer.LogElapsed("Parental lookup for %s", host) defer timer.LogElapsed("parental lookup for %q", host)
} }
sctx := &sbCtx{ sctx := &sbCtx{

View File

@ -74,21 +74,20 @@ func TestSafeBrowsingCache(t *testing.T) {
c.hashToHost[hash] = "sub.host.com" c.hashToHost[hash] = "sub.host.com"
assert.Equal(t, -1, c.getCached()) assert.Equal(t, -1, c.getCached())
// match "sub.host.com" from cache, // Match "sub.host.com" from cache. Another hash for "host.example" is not
// but another hash for "nonexisting.com" is not in cache // in the cache, so get data for it from the server.
// which means that we must get data from server for it
c.hashToHost = make(map[[32]byte]string) c.hashToHost = make(map[[32]byte]string)
hash = sha256.Sum256([]byte("sub.host.com")) hash = sha256.Sum256([]byte("sub.host.com"))
c.hashToHost[hash] = "sub.host.com" c.hashToHost[hash] = "sub.host.com"
hash = sha256.Sum256([]byte("nonexisting.com")) hash = sha256.Sum256([]byte("host.example"))
c.hashToHost[hash] = "nonexisting.com" c.hashToHost[hash] = "host.example"
assert.Empty(t, c.getCached()) assert.Empty(t, c.getCached())
hash = sha256.Sum256([]byte("sub.host.com")) hash = sha256.Sum256([]byte("sub.host.com"))
_, ok := c.hashToHost[hash] _, ok := c.hashToHost[hash]
assert.False(t, ok) assert.False(t, ok)
hash = sha256.Sum256([]byte("nonexisting.com")) hash = sha256.Sum256([]byte("host.example"))
_, ok = c.hashToHost[hash] _, ok = c.hashToHost[hash]
assert.True(t, ok) assert.True(t, ok)
@ -111,8 +110,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil)
t.Cleanup(d.Close) t.Cleanup(d.Close)
ups := &aghtest.TestErrUpstream{} ups := aghtest.NewErrorUpstream()
d.SetSafeBrowsingUpstream(ups) d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups) d.SetParentalUpstream(ups)
@ -170,10 +168,16 @@ func TestSBPC(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
// Prepare the upstream. // Prepare the upstream.
ups := &aghtest.TestBlockUpstream{ ups := aghtest.NewBlockUpstream(hostname, tc.block)
Hostname: hostname,
Block: tc.block, var numReq int
onExchange := ups.OnExchange
ups.OnExchange = func(req *dns.Msg) (resp *dns.Msg, err error) {
numReq++
return onExchange(req)
} }
d.SetSafeBrowsingUpstream(ups) d.SetSafeBrowsingUpstream(ups)
d.SetParentalUpstream(ups) d.SetParentalUpstream(ups)
@ -196,7 +200,7 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, hits, tc.testCache.Stats().Hit) assert.Equal(t, hits, tc.testCache.Stats().Hit)
// There was one request to an upstream. // There was one request to an upstream.
assert.Equal(t, 1, ups.RequestsCount()) assert.Equal(t, 1, numReq)
// Now make the same request to check the cache was used. // Now make the same request to check the cache was used.
res, err = tc.testFunc(hostname, dns.TypeA, setts) res, err = tc.testFunc(hostname, dns.TypeA, setts)
@ -214,7 +218,7 @@ func TestSBPC(t *testing.T) {
assert.Equal(t, hits+1, tc.testCache.Stats().Hit) assert.Equal(t, hits+1, tc.testCache.Stats().Hit)
// Check that there were no additional requests. // Check that there were no additional requests.
assert.Equal(t, 1, ups.RequestsCount()) assert.Equal(t, 1, numReq)
}) })
purgeCaches(d) purgeCaches(d)

View File

@ -209,6 +209,7 @@ var config = &configuration{
Ratelimit: 20, Ratelimit: 20,
RefuseAny: true, RefuseAny: true,
AllServers: false, AllServers: false,
HandleDDR: true,
FastestTimeout: timeutil.Duration{ FastestTimeout: timeutil.Duration{
Duration: fastip.DefaultPingWaitTimeout, Duration: fastip.DefaultPingWaitTimeout,
}, },

View File

@ -216,7 +216,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
func handleStaticIP(ip net.IP, set bool) staticIPJSON { func handleStaticIP(ip net.IP, set bool) staticIPJSON {
resp := staticIPJSON{} resp := staticIPJSON{}
interfaceName := aghnet.GetInterfaceByIP(ip) interfaceName := aghnet.InterfaceByIP(ip)
resp.Static = "no" resp.Static = "no"
if len(interfaceName) == 0 { if len(interfaceName) == 0 {

View File

@ -3,15 +3,16 @@ package home
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt"
"net" "net"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
@ -80,8 +81,10 @@ func TestRDNS_Begin(t *testing.T) {
binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix())) binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix()))
rdns := &RDNS{ rdns := &RDNS{
ipCache: ipCache, ipCache: ipCache,
exchanger: &rDNSExchanger{}, exchanger: &rDNSExchanger{
ex: aghtest.NewErrorUpstream(),
},
clients: &clientsContainer{ clients: &clientsContainer{
list: map[string]*Client{}, list: map[string]*Client{},
idIndex: tc.cliIDIndex, idIndex: tc.cliIDIndex,
@ -108,16 +111,22 @@ func TestRDNS_Begin(t *testing.T) {
// rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests. // rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests.
type rDNSExchanger struct { type rDNSExchanger struct {
ex aghtest.Exchanger ex upstream.Upstream
usePrivate bool usePrivate bool
} }
// Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger. // Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger.
func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) { func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) {
rev, err := netutil.IPToReversedAddr(ip)
if err != nil {
return "", fmt.Errorf("reversing ip: %w", err)
}
req := &dns.Msg{ req := &dns.Msg{
Question: []dns.Question{{ Question: []dns.Question{{
Name: ip.String(), Name: dns.Fqdn(rev),
Qtype: dns.TypePTR, Qclass: dns.ClassINET,
Qtype: dns.TypePTR,
}}, }},
} }
@ -146,7 +155,9 @@ func TestRDNS_ensurePrivateCache(t *testing.T) {
MaxCount: defaultRDNSCacheSize, MaxCount: defaultRDNSCacheSize,
}) })
ex := &rDNSExchanger{} ex := &rDNSExchanger{
ex: aghtest.NewErrorUpstream(),
}
rdns := &RDNS{ rdns := &RDNS{
ipCache: ipCache, ipCache: ipCache,
@ -167,15 +178,27 @@ func TestRDNS_WorkerLoop(t *testing.T) {
w := &bytes.Buffer{} w := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, w) aghtest.ReplaceLogWriter(t, w)
locUpstream := &aghtest.Upstream{ localIP := net.IP{192, 168, 1, 1}
Reverse: map[string][]string{ revIPv4, err := netutil.IPToReversedAddr(localIP)
"192.168.1.1": {"local.domain"}, require.NoError(t, err)
"2a00:1450:400c:c06::93": {"ipv6.domain"},
revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93"))
require.NoError(t, err)
locUpstream := &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "local.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
resp = aghalg.Coalesce(
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv4, "local.domain"),
aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv6, "ipv6.domain"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
)
return resp, nil
}, },
} }
errUpstream := &aghtest.TestErrUpstream{
Err: errors.Error("1234"), errUpstream := aghtest.NewErrorUpstream()
}
testCases := []struct { testCases := []struct {
ups upstream.Upstream ups upstream.Upstream
@ -186,10 +209,10 @@ func TestRDNS_WorkerLoop(t *testing.T) {
ups: locUpstream, ups: locUpstream,
wantLog: "", wantLog: "",
name: "all_good", name: "all_good",
cliIP: net.IP{192, 168, 1, 1}, cliIP: localIP,
}, { }, {
ups: errUpstream, ups: errUpstream,
wantLog: `rdns: resolving "192.168.1.2": errupstream: 1234`, wantLog: `rdns: resolving "192.168.1.2": test upstream error`,
name: "resolve_error", name: "resolve_error",
cliIP: net.IP{192, 168, 1, 2}, cliIP: net.IP{192, 168, 1, 2},
}, { }, {
@ -211,9 +234,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
ch := make(chan net.IP) ch := make(chan net.IP)
rdns := &RDNS{ rdns := &RDNS{
exchanger: &rDNSExchanger{ exchanger: &rDNSExchanger{
ex: aghtest.Exchanger{ ex: tc.ups,
Ups: tc.ups,
},
}, },
clients: cc, clients: cc,
ipCh: ch, ipCh: ch,