all: sync with master

This commit is contained in:
Ainar Garipov 2024-05-15 13:34:12 +03:00
parent 6318fc424b
commit 667263a3a8
82 changed files with 2356 additions and 1817 deletions

View File

@ -1,7 +1,7 @@
'name': 'build' 'name': 'build'
'env': 'env':
'GO_VERSION': '1.22.2' 'GO_VERSION': '1.22.3'
'NODE_VERSION': '16' 'NODE_VERSION': '16'
'on': 'on':

View File

@ -1,7 +1,7 @@
'name': 'lint' 'name': 'lint'
'env': 'env':
'GO_VERSION': '1.22.2' 'GO_VERSION': '1.22.3'
'on': 'on':
'push': 'push':

View File

@ -14,7 +14,7 @@ and this project adheres to
<!-- <!--
## [v0.108.0] - TBA ## [v0.108.0] - TBA
## [v0.107.49] - 2024-04-24 (APPROX.) ## [v0.107.49] - 2024-05-20 (APPROX.)
See also the [v0.107.49 GitHub milestone][ms-v0.107.49]. See also the [v0.107.49 GitHub milestone][ms-v0.107.49].
@ -23,6 +23,60 @@ See also the [v0.107.49 GitHub milestone][ms-v0.107.49].
NOTE: Add new changes BELOW THIS COMMENT. NOTE: Add new changes BELOW THIS COMMENT.
--> -->
### Security
- Go version has been updated to prevent the possibility of exploiting the Go
vulnerabilities fixed in [Go 1.22.3][go-1.22.3].
### Added
- Support for comments in the ipset file ([#5345]).
### Changed
- Private rDNS resolution now also affects `SOA` and `NS` requests ([#6882]).
- Rewrite rules mechanics was changed due to improve resolving in safe search.
### Deprecated
- Currently, AdGuard Home skips persistent clients that have duplicate fields
when reading them from the configuration file. This behaviour is deprecated
and will cause errors on startup in a future release.
### Fixed
- Acceptance of duplicate UIDs for persistent clients at startup. See also the
section on client settings on the [Wiki page][wiki-config].
- Domain specifications for top-level domains not considered for requests to
unqualified domains ([#6744]).
- Support for link-local subnets, i.e. `fe80::/16`, as client identifiers
([#6312]).
- Issues with QUIC and HTTP/3 upstreams on older Linux kernel versions
([#6422]).
- YouTube restricted mode is not enforced by HTTPS queries on Firefox.
- Support for link-local subnets, i.e. `fe80::/16`, in the access settings
([#6192]).
- The ability to apply an invalid configuration for private rDNS, which led to
server not starting.
- Ignoring query log for clients with ClientID set ([#5812]).
- Subdomains of `in-addr.arpa` and `ip6.arpa` containing zero-length prefix
incorrectly considered invalid when specified for private rDNS upstream
servers ([#6854]).
- Unspecified IP addresses aren't checked when using "Fastest IP address" mode
([#6875]).
[#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345
[#5812]: https://github.com/AdguardTeam/AdGuardHome/issues/5812
[#6192]: https://github.com/AdguardTeam/AdGuardHome/issues/6192
[#6312]: https://github.com/AdguardTeam/AdGuardHome/issues/6312
[#6422]: https://github.com/AdguardTeam/AdGuardHome/issues/6422
[#6744]: https://github.com/AdguardTeam/AdGuardHome/issues/6744
[#6854]: https://github.com/AdguardTeam/AdGuardHome/issues/6854
[#6875]: https://github.com/AdguardTeam/AdGuardHome/issues/6875
[#6882]: https://github.com/AdguardTeam/AdGuardHome/issues/6882
[go-1.22.3]: https://groups.google.com/g/golang-announce/c/wkkO4P9stm0
<!-- <!--
NOTE: Add new changes ABOVE THIS COMMENT. NOTE: Add new changes ABOVE THIS COMMENT.
--> -->
@ -35,7 +89,7 @@ See also the [v0.107.48 GitHub milestone][ms-v0.107.48].
### Fixed ### Fixed
- Access settings not being applied to encrypted protocols ([#6890]) - Access settings not being applied to encrypted protocols ([#6890]).
[#6890]: https://github.com/AdguardTeam/AdGuardHome/issues/6890 [#6890]: https://github.com/AdguardTeam/AdGuardHome/issues/6890

View File

@ -27,7 +27,7 @@ DIST_DIR = dist
GOAMD64 = v1 GOAMD64 = v1
GOPROXY = https://goproxy.cn|https://proxy.golang.org|direct GOPROXY = https://goproxy.cn|https://proxy.golang.org|direct
GOSUMDB = sum.golang.google.cn GOSUMDB = sum.golang.google.cn
GOTOOLCHAIN = go1.22.2 GOTOOLCHAIN = go1.22.3
GPG_KEY = devteam@adguard.com GPG_KEY = devteam@adguard.com
GPG_KEY_PASSPHRASE = not-a-real-password GPG_KEY_PASSPHRASE = not-a-real-password
NPM = npm NPM = npm

View File

@ -8,7 +8,7 @@
'variables': 'variables':
'channel': 'edge' 'channel': 'edge'
'dockerFrontend': 'adguard/home-js-builder:1.1' 'dockerFrontend': 'adguard/home-js-builder:1.1'
'dockerGo': 'adguard/go-builder:1.22.2--1' 'dockerGo': 'adguard/go-builder:1.22.3--1'
'stages': 'stages':
- 'Build frontend': - 'Build frontend':
@ -249,7 +249,7 @@
'recipients': 'recipients':
- 'webhook': - 'webhook':
'name': 'Build webhook' 'name': 'Build webhook'
'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo?channel=adguard-qa' 'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo?channel=adguard-qa-dns-builds'
'labels': [] 'labels': []
'other': 'other':
@ -266,7 +266,7 @@
'variables': 'variables':
'channel': 'beta' 'channel': 'beta'
'dockerFrontend': 'adguard/home-js-builder:1.1' 'dockerFrontend': 'adguard/home-js-builder:1.1'
'dockerGo': 'adguard/go-builder:1.22.2--1' 'dockerGo': 'adguard/go-builder:1.22.3--1'
# release-vX.Y.Z branches are the branches from which the actual final # release-vX.Y.Z branches are the branches from which the actual final
# release is built. # release is built.
- '^release-v[0-9]+\.[0-9]+\.[0-9]+': - '^release-v[0-9]+\.[0-9]+\.[0-9]+':
@ -282,4 +282,4 @@
'variables': 'variables':
'channel': 'release' 'channel': 'release'
'dockerFrontend': 'adguard/home-js-builder:1.1' 'dockerFrontend': 'adguard/home-js-builder:1.1'
'dockerGo': 'adguard/go-builder:1.22.2--1' 'dockerGo': 'adguard/go-builder:1.22.3--1'

View File

@ -175,7 +175,7 @@
'recipients': 'recipients':
- 'webhook': - 'webhook':
'name': 'Build webhook' 'name': 'Build webhook'
'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo?channel=adguard-qa' 'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo?channel=adguard-qa-dns-builds'
'labels': [] 'labels': []
'other': 'other':

View File

@ -6,7 +6,7 @@
'name': 'AdGuard Home - Build and run tests' 'name': 'AdGuard Home - Build and run tests'
'variables': 'variables':
'dockerFrontend': 'adguard/home-js-builder:1.1' 'dockerFrontend': 'adguard/home-js-builder:1.1'
'dockerGo': 'adguard/go-builder:1.22.2--1' 'dockerGo': 'adguard/go-builder:1.22.3--1'
'channel': 'development' 'channel': 'development'
'stages': 'stages':
@ -195,5 +195,5 @@
# may need to build a few of these. # may need to build a few of these.
'variables': 'variables':
'dockerFrontend': 'adguard/home-js-builder:1.1' 'dockerFrontend': 'adguard/home-js-builder:1.1'
'dockerGo': 'adguard/go-builder:1.22.2--1' 'dockerGo': 'adguard/go-builder:1.22.3--1'
'channel': 'candidate' 'channel': 'candidate'

View File

@ -13,14 +13,14 @@
"fallback_dns_desc": "List of fallback DNS servers used when upstream DNS servers are not responding. The syntax is the same as in the main upstreams field above.", "fallback_dns_desc": "List of fallback DNS servers used when upstream DNS servers are not responding. The syntax is the same as in the main upstreams field above.",
"fallback_dns_placeholder": "Enter one fallback DNS server per line", "fallback_dns_placeholder": "Enter one fallback DNS server per line",
"local_ptr_title": "Private reverse DNS servers", "local_ptr_title": "Private reverse DNS servers",
"local_ptr_desc": "The DNS servers that AdGuard Home uses for local PTR queries. These servers are used to resolve PTR requests for addresses in private IP ranges, for example \"192.168.12.34\", using reverse DNS. If not set, AdGuard Home uses the addresses of the default DNS resolvers of your OS except for the addresses of AdGuard Home itself.", "local_ptr_desc": "DNS servers used by AdGuard Home for private PTR, SOA, and NS requests. A request is considered private if it asks for an ARPA domain containing a subnet within private IP ranges (such as \"192.168.12.34\") and comes from a client with a private IP address. If not set, the default DNS resolvers of your OS will be used, except for the AdGuard Home IP addresses.",
"local_ptr_default_resolver": "By default, AdGuard Home uses the following reverse DNS resolvers: {{ip}}.", "local_ptr_default_resolver": "By default, AdGuard Home uses the following reverse DNS resolvers: {{ip}}.",
"local_ptr_no_default_resolver": "AdGuard Home could not determine suitable private reverse DNS resolvers for this system.", "local_ptr_no_default_resolver": "AdGuard Home could not determine suitable private reverse DNS resolvers for this system.",
"local_ptr_placeholder": "Enter one IP address per line", "local_ptr_placeholder": "Enter one IP address per line",
"resolve_clients_title": "Enable reverse resolving of clients' IP addresses", "resolve_clients_title": "Enable reverse resolving of clients' IP addresses",
"resolve_clients_desc": "Reversely resolve clients' IP addresses into their hostnames by sending PTR queries to corresponding resolvers (private DNS servers for local clients, upstream servers for clients with public IP addresses).", "resolve_clients_desc": "Reversely resolve clients' IP addresses into their hostnames by sending PTR queries to corresponding resolvers (private DNS servers for local clients, upstream servers for clients with public IP addresses).",
"use_private_ptr_resolvers_title": "Use private reverse DNS resolvers", "use_private_ptr_resolvers_title": "Use private reverse DNS resolvers",
"use_private_ptr_resolvers_desc": "Perform reverse DNS lookups for locally served addresses using these upstream servers. If disabled, AdGuard Home responds with NXDOMAIN to all such PTR requests except for clients known from DHCP, /etc/hosts, and so on.", "use_private_ptr_resolvers_desc": "Resolve PTR, SOA, and NS requests for ARPA domains containing private IP addresses through private upstream servers, DHCP, /etc/hosts, etc. If disabled, AdGuard Home will respond to all such requests with NXDOMAIN.",
"check_dhcp_servers": "Check for DHCP servers", "check_dhcp_servers": "Check for DHCP servers",
"save_config": "Save configuration", "save_config": "Save configuration",
"enabled_dhcp": "DHCP server enabled", "enabled_dhcp": "DHCP server enabled",

23
go.mod
View File

@ -1,10 +1,10 @@
module github.com/AdguardTeam/AdGuardHome module github.com/AdguardTeam/AdGuardHome
go 1.22.2 go 1.22.3
require ( require (
github.com/AdguardTeam/dnsproxy v0.67.1-0.20240405111306-032a0534ccd2 github.com/AdguardTeam/dnsproxy v0.71.1
github.com/AdguardTeam/golibs v0.21.0 github.com/AdguardTeam/golibs v0.23.2
github.com/AdguardTeam/urlfilter v0.18.0 github.com/AdguardTeam/urlfilter v0.18.0
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.7 github.com/ameshkov/dnscrypt/v2 v2.2.7
@ -28,14 +28,15 @@ require (
// own code for that. Perhaps, use gopacket. // own code for that. Perhaps, use gopacket.
github.com/mdlayher/raw v0.1.0 github.com/mdlayher/raw v0.1.0
github.com/miekg/dns v1.1.58 github.com/miekg/dns v1.1.58
github.com/quic-go/quic-go v0.41.0 // TODO(a.garipov): Use release version.
github.com/quic-go/quic-go v0.42.1-0.20240424141022-12aa63824c7f
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/ti-mo/netfilter v0.5.1 github.com/ti-mo/netfilter v0.5.1
go.etcd.io/bbolt v1.3.9 go.etcd.io/bbolt v1.3.9
golang.org/x/crypto v0.21.0 golang.org/x/crypto v0.22.0
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8
golang.org/x/net v0.23.0 golang.org/x/net v0.24.0
golang.org/x/sys v0.18.0 golang.org/x/sys v0.19.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
howett.net/plist v1.0.1 howett.net/plist v1.0.1
@ -58,9 +59,9 @@ require (
github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect
go.uber.org/mock v0.4.0 // indirect go.uber.org/mock v0.4.0 // indirect
golang.org/x/mod v0.16.0 // indirect golang.org/x/mod v0.17.0 // indirect
golang.org/x/sync v0.6.0 // indirect golang.org/x/sync v0.7.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.19.0 // indirect golang.org/x/tools v0.20.0 // indirect
gonum.org/v1/gonum v0.14.0 // indirect gonum.org/v1/gonum v0.14.0 // indirect
) )

42
go.sum
View File

@ -1,7 +1,7 @@
github.com/AdguardTeam/dnsproxy v0.67.1-0.20240405111306-032a0534ccd2 h1:XDhWNn1OfmbtLgj3bR52WWIa0/cf0ijanOvuaT75f1I= github.com/AdguardTeam/dnsproxy v0.71.1 h1:R8jKmoE9HwqdTt7bm8irpvrQEOSmD+iGdNXbOg/uM8Y=
github.com/AdguardTeam/dnsproxy v0.67.1-0.20240405111306-032a0534ccd2/go.mod h1:7hAE3du5XPrBkdsqAPJIEGWklsE0ahHZONRlLASPeNI= github.com/AdguardTeam/dnsproxy v0.71.1/go.mod h1:rCaCL4m4n63sgwTOyUVdc7MC42PlUYBt11Fz/UjD+kM=
github.com/AdguardTeam/golibs v0.21.0 h1:0swWyNaHTmT7aMwffKd9d54g4wBd8Oaj0fl+5l/PRdE= github.com/AdguardTeam/golibs v0.23.2 h1:rMjYantwtQ39e8G4zBQ6ZLlm4s3XH30Bc9VxhoOHwao=
github.com/AdguardTeam/golibs v0.21.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI= github.com/AdguardTeam/golibs v0.23.2/go.mod h1:o9i55Sx6v7qogRQeqaBfmLbC/pZqeMBWi015U5PTDY0=
github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ= github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ=
github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s= github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s=
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I= github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
@ -101,8 +101,8 @@ github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k= github.com/quic-go/quic-go v0.42.1-0.20240424141022-12aa63824c7f h1:L7x60Z6AW2giF/SvbDpMglGHJxtmFJV03khPwXLDScU=
github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA= github.com/quic-go/quic-go v0.42.1-0.20240424141022-12aa63824c7f/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4= github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
github.com/shirou/gopsutil/v3 v3.23.7/go.mod h1:c4gnmoRC0hQuaLqvxnx1//VXQ0Ms/X9UnJF8pddY5z4= github.com/shirou/gopsutil/v3 v3.23.7/go.mod h1:c4gnmoRC0hQuaLqvxnx1//VXQ0Ms/X9UnJF8pddY5z4=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
@ -131,26 +131,26 @@ go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 h1:ESSUROHIBHg7USnszlcdmjBEwdMj9VUvU+OPk4yl2mc=
golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc=
golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -161,17 +161,19 @@ golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= golang.org/x/tools v0.20.0 h1:hz/CVckiOxybQvFw6h7b/q80NTr9IUQb4s1IIzW7KNY=
golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= golang.org/x/tools v0.20.0/go.mod h1:WvitBU7JJf6A4jOdg4S1tviW9bhUxkgeCui/0JHctQg=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=

View File

@ -10,29 +10,8 @@ import (
"golang.org/x/exp/constraints" "golang.org/x/exp/constraints"
) )
// Coalesce returns the first non-zero value. It is named after function
// COALESCE in SQL. If values or all its elements are empty, it returns a zero
// value.
//
// T is comparable, because Go currently doesn't have a comparableWithZeroValue
// constraint.
//
// TODO(a.garipov): Think of ways to merge with [CoalesceSlice].
func Coalesce[T comparable](values ...T) (res T) {
var zero T
for _, v := range values {
if v != zero {
return v
}
}
return zero
}
// CoalesceSlice returns the first non-zero value. It is named after function // CoalesceSlice returns the first non-zero value. It is named after function
// COALESCE in SQL. If values or all its elements are empty, it returns nil. // COALESCE in SQL. If values or all its elements are empty, it returns nil.
//
// TODO(a.garipov): Think of ways to merge with [Coalesce].
func CoalesceSlice[E any, S []E](values ...S) (res S) { func CoalesceSlice[E any, S []E](values ...S) (res S) {
for _, v := range values { for _, v := range values {
if v != nil { if v != nil {

View File

@ -33,7 +33,7 @@ func elements(b *aghalg.RingBuffer[int], n uint, reverse bool) (es []int) {
func TestNewRingBuffer(t *testing.T) { func TestNewRingBuffer(t *testing.T) {
t.Run("success_and_clear", func(t *testing.T) { t.Run("success_and_clear", func(t *testing.T) {
b := aghalg.NewRingBuffer[int](5) b := aghalg.NewRingBuffer[int](5)
for i := 0; i < 10; i++ { for i := range 10 {
b.Append(i) b.Append(i)
} }
assert.Equal(t, []int{5, 6, 7, 8, 9}, elements(b, b.Len(), false)) assert.Equal(t, []int{5, 6, 7, 8, 9}, elements(b, b.Len(), false))
@ -44,7 +44,7 @@ func TestNewRingBuffer(t *testing.T) {
t.Run("zero", func(t *testing.T) { t.Run("zero", func(t *testing.T) {
b := aghalg.NewRingBuffer[int](0) b := aghalg.NewRingBuffer[int](0)
for i := 0; i < 10; i++ { for i := range 10 {
b.Append(i) b.Append(i)
bufLen := b.Len() bufLen := b.Len()
assert.EqualValues(t, 0, bufLen) assert.EqualValues(t, 0, bufLen)
@ -55,7 +55,7 @@ func TestNewRingBuffer(t *testing.T) {
t.Run("single", func(t *testing.T) { t.Run("single", func(t *testing.T) {
b := aghalg.NewRingBuffer[int](1) b := aghalg.NewRingBuffer[int](1)
for i := 0; i < 10; i++ { for i := range 10 {
b.Append(i) b.Append(i)
bufLen := b.Len() bufLen := b.Len()
assert.EqualValues(t, 1, bufLen) assert.EqualValues(t, 1, bufLen)
@ -94,7 +94,7 @@ func TestRingBuffer_Range(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
for i := 0; i < tc.count; i++ { for i := range tc.count {
b.Append(i) b.Append(i)
} }

View File

@ -11,7 +11,7 @@ func TestNewSortedMap(t *testing.T) {
var m SortedMap[string, int] var m SortedMap[string, int]
letters := []string{} letters := []string{}
for i := 0; i < 10; i++ { for i := range 10 {
r := string('a' + rune(i)) r := string('a' + rune(i))
letters = append(letters, r) letters = append(letters, r)
} }

View File

@ -97,6 +97,8 @@ func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) {
var filename string var filename string
defer func() { err = errors.Annotate(err, "checking %q: %w", filename) }() defer func() { err = errors.Annotate(err, "checking %q: %w", filename) }()
// TODO(e.burkov): Redo this loop, as it modifies the very same slice it
// iterates over.
for i := 0; i < len(src); i++ { for i := 0; i < len(src); i++ {
var patterns []string var patterns []string
var cont bool var cont bool

View File

@ -159,21 +159,11 @@ func NotifyReconfigureSignal(c chan<- os.Signal) {
notifyReconfigureSignal(c) notifyReconfigureSignal(c)
} }
// NotifyShutdownSignal notifies c on receiving shutdown signals.
func NotifyShutdownSignal(c chan<- os.Signal) {
notifyShutdownSignal(c)
}
// IsReconfigureSignal returns true if sig is a reconfigure signal. // IsReconfigureSignal returns true if sig is a reconfigure signal.
func IsReconfigureSignal(sig os.Signal) (ok bool) { func IsReconfigureSignal(sig os.Signal) (ok bool) {
return isReconfigureSignal(sig) return isReconfigureSignal(sig)
} }
// IsShutdownSignal returns true if sig is a shutdown signal.
func IsShutdownSignal(sig os.Signal) (ok bool) {
return isShutdownSignal(sig)
}
// SendShutdownSignal sends the shutdown signal to the channel. // SendShutdownSignal sends the shutdown signal to the channel.
func SendShutdownSignal(c chan<- os.Signal) { func SendShutdownSignal(c chan<- os.Signal) {
sendShutdownSignal(c) sendShutdownSignal(c)

View File

@ -13,26 +13,10 @@ func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGHUP) signal.Notify(c, unix.SIGHUP)
} }
func notifyShutdownSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
}
func isReconfigureSignal(sig os.Signal) (ok bool) { func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == unix.SIGHUP return sig == unix.SIGHUP
} }
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case
unix.SIGINT,
unix.SIGQUIT,
unix.SIGTERM:
return true
default:
return false
}
}
func sendShutdownSignal(_ chan<- os.Signal) { func sendShutdownSignal(_ chan<- os.Signal) {
// On Unix we are already notified by the system. // On Unix we are already notified by the system.
} }

View File

@ -5,7 +5,6 @@ package aghos
import ( import (
"os" "os"
"os/signal" "os/signal"
"syscall"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
@ -43,25 +42,10 @@ func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, windows.SIGHUP) signal.Notify(c, windows.SIGHUP)
} }
func notifyShutdownSignal(c chan<- os.Signal) {
// syscall.SIGTERM is processed automatically. See go doc os/signal,
// section Windows.
signal.Notify(c, os.Interrupt)
}
func isReconfigureSignal(sig os.Signal) (ok bool) { func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == windows.SIGHUP return sig == windows.SIGHUP
} }
func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig {
case os.Interrupt, syscall.SIGTERM:
return true
default:
return false
}
}
func sendShutdownSignal(c chan<- os.Signal) { func sendShutdownSignal(c chan<- os.Signal) {
c <- os.Interrupt c <- os.Interrupt
} }

View File

@ -78,7 +78,6 @@ func TestWithDeferredCleanup(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -91,8 +91,6 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
@ -186,8 +184,6 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -7,6 +7,7 @@ package client
import ( import (
"encoding" "encoding"
"fmt" "fmt"
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
) )
@ -56,6 +57,9 @@ func (cs Source) MarshalText() (text []byte, err error) {
// Runtime is a client information from different sources. // Runtime is a client information from different sources.
type Runtime struct { type Runtime struct {
// ip is an IP address of a client.
ip netip.Addr
// whois is the filtered WHOIS information of a client. // whois is the filtered WHOIS information of a client.
whois *whois.Info whois *whois.Info
@ -80,6 +84,15 @@ type Runtime struct {
hostsFile []string hostsFile []string
} }
// NewRuntime constructs a new runtime client. ip must be valid IP address.
//
// TODO(s.chzhen): Validate IP address.
func NewRuntime(ip netip.Addr) (r *Runtime) {
return &Runtime{
ip: ip,
}
}
// Info returns a client information from the highest-priority source. // Info returns a client information from the highest-priority source.
func (r *Runtime) Info() (cs Source, host string) { func (r *Runtime) Info() (cs Source, host string) {
info := []string{} info := []string{}
@ -133,8 +146,8 @@ func (r *Runtime) SetWHOIS(info *whois.Info) {
r.whois = info r.whois = info
} }
// Unset clears a cs information. // unset clears a cs information.
func (r *Runtime) Unset(cs Source) { func (r *Runtime) unset(cs Source) {
switch cs { switch cs {
case SourceWHOIS: case SourceWHOIS:
r.whois = nil r.whois = nil
@ -149,11 +162,16 @@ func (r *Runtime) Unset(cs Source) {
} }
} }
// IsEmpty returns true if there is no information from any source. // isEmpty returns true if there is no information from any source.
func (r *Runtime) IsEmpty() (ok bool) { func (r *Runtime) isEmpty() (ok bool) {
return r.whois == nil && return r.whois == nil &&
r.arp == nil && r.arp == nil &&
r.rdns == nil && r.rdns == nil &&
r.dhcp == nil && r.dhcp == nil &&
r.hostsFile == nil r.hostsFile == nil
} }
// Addr returns an IP address of the client.
func (r *Runtime) Addr() (ip netip.Addr) {
return r.ip
}

View File

@ -4,8 +4,12 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"slices"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/exp/maps"
) )
// macKey contains MAC as byte array of 6, 8, or 20 bytes. // macKey contains MAC as byte array of 6, 8, or 20 bytes.
@ -28,6 +32,9 @@ func macToKey(mac net.HardwareAddr) (key macKey) {
// Index stores all information about persistent clients. // Index stores all information about persistent clients.
type Index struct { type Index struct {
// nameToUID maps client name to UID.
nameToUID map[string]UID
// clientIDToUID maps client ID to UID. // clientIDToUID maps client ID to UID.
clientIDToUID map[string]UID clientIDToUID map[string]UID
@ -47,6 +54,7 @@ type Index struct {
// NewIndex initializes the new instance of client index. // NewIndex initializes the new instance of client index.
func NewIndex() (ci *Index) { func NewIndex() (ci *Index) {
return &Index{ return &Index{
nameToUID: map[string]UID{},
clientIDToUID: map[string]UID{}, clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{}, ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare), subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
@ -62,6 +70,8 @@ func (ci *Index) Add(c *Persistent) {
panic("client must contain uid") panic("client must contain uid")
} }
ci.nameToUID[c.Name] = c.UID
for _, id := range c.ClientIDs { for _, id := range c.ClientIDs {
ci.clientIDToUID[id] = c.UID ci.clientIDToUID[id] = c.UID
} }
@ -82,15 +92,30 @@ func (ci *Index) Add(c *Persistent) {
ci.uidToClient[c.UID] = c ci.uidToClient[c.UID] = c
} }
// ClashesUID returns existing persistent client with the same UID as c. Note
// that this is only possible when configuration contains duplicate fields.
func (ci *Index) ClashesUID(c *Persistent) (err error) {
p, ok := ci.uidToClient[c.UID]
if ok {
return fmt.Errorf("another client %q uses the same uid", p.Name)
}
return nil
}
// Clashes returns an error if the index contains a different persistent client // Clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c. c must be non-nil. // with at least a single identifier contained by c. c must be non-nil.
func (ci *Index) Clashes(c *Persistent) (err error) { func (ci *Index) Clashes(c *Persistent) (err error) {
if p := ci.clashesName(c); p != nil {
return fmt.Errorf("another client uses the same name %q", p.Name)
}
for _, id := range c.ClientIDs { for _, id := range c.ClientIDs {
existing, ok := ci.clientIDToUID[id] existing, ok := ci.clientIDToUID[id]
if ok && existing != c.UID { if ok && existing != c.UID {
p := ci.uidToClient[existing] p := ci.uidToClient[existing]
return fmt.Errorf("another client %q uses the same ID %q", p.Name, id) return fmt.Errorf("another client %q uses the same ClientID %q", p.Name, id)
} }
} }
@ -112,6 +137,21 @@ func (ci *Index) Clashes(c *Persistent) (err error) {
return nil return nil
} }
// clashesName returns existing persistent client with the same name as c or
// nil. c must be non-nil.
func (ci *Index) clashesName(c *Persistent) (existing *Persistent) {
existing, ok := ci.FindByName(c.Name)
if !ok {
return nil
}
if existing.UID != c.UID {
return existing
}
return nil
}
// clashesIP returns a previous client with the same IP address as c. c must be // clashesIP returns a previous client with the same IP address as c. c must be
// non-nil. // non-nil.
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) { func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
@ -184,21 +224,33 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) {
mac, err := net.ParseMAC(id) mac, err := net.ParseMAC(id)
if err == nil { if err == nil {
return ci.findByMAC(mac) return ci.FindByMAC(mac)
} }
return nil, false return nil, false
} }
// find finds persistent client by IP address. // FindByName finds persistent client by name.
func (ci *Index) FindByName(name string) (c *Persistent, found bool) {
uid, found := ci.nameToUID[name]
if found {
return ci.uidToClient[uid], true
}
return nil, false
}
// findByIP finds persistent client by IP address.
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) { func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
uid, found := ci.ipToUID[ip] uid, found := ci.ipToUID[ip]
if found { if found {
return ci.uidToClient[uid], true return ci.uidToClient[uid], true
} }
ipWithoutZone := ip.WithZone("")
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) { ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
if pref.Contains(ip) { // Remove zone before checking because prefixes strip zones.
if pref.Contains(ipWithoutZone) {
uid, found = id, true uid, found = id, true
return false return false
@ -214,8 +266,8 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
return nil, false return nil, false
} }
// find finds persistent client by MAC. // FindByMAC finds persistent client by MAC.
func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
k := macToKey(mac) k := macToKey(mac)
uid, found := ci.macToUID[k] uid, found := ci.macToUID[k]
if found { if found {
@ -225,9 +277,31 @@ func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
return nil, false return nil, false
} }
// FindByIPWithoutZone finds a persistent client by IP address without zone. It
// strips the IPv6 zone index from the stored IP addresses before comparing,
// because querylog entries don't have it. See TODO on [querylog.logEntry.IP].
//
// Note that multiple clients can have the same IP address with different zones.
// Therefore, the result of this method is indeterminate.
func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) {
if (ip == netip.Addr{}) {
return nil
}
for addr, uid := range ci.ipToUID {
if addr.WithZone("") == ip {
return ci.uidToClient[uid]
}
}
return nil
}
// Delete removes information about persistent client from the index. c must be // Delete removes information about persistent client from the index. c must be
// non-nil. // non-nil.
func (ci *Index) Delete(c *Persistent) { func (ci *Index) Delete(c *Persistent) {
delete(ci.nameToUID, c.Name)
for _, id := range c.ClientIDs { for _, id := range c.ClientIDs {
delete(ci.clientIDToUID, id) delete(ci.clientIDToUID, id)
} }
@ -247,3 +321,48 @@ func (ci *Index) Delete(c *Persistent) {
delete(ci.uidToClient, c.UID) delete(ci.uidToClient, c.UID)
} }
// Size returns the number of persistent clients.
func (ci *Index) Size() (n int) {
return len(ci.uidToClient)
}
// Range calls f for each persistent client, unless cont is false. The order is
// undefined.
func (ci *Index) Range(f func(c *Persistent) (cont bool)) {
for _, c := range ci.uidToClient {
if !f(c) {
return
}
}
}
// RangeByName is like [Index.Range] but sorts the persistent clients by name
// before iterating ensuring a predictable order.
func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
cs := maps.Values(ci.uidToClient)
slices.SortFunc(cs, func(a, b *Persistent) (n int) {
return strings.Compare(a.Name, b.Name)
})
for _, c := range cs {
if !f(c) {
break
}
}
}
// CloseUpstreams closes upstream configurations of persistent clients.
func (ci *Index) CloseUpstreams() (err error) {
var errs []error
ci.RangeByName(func(c *Persistent) (cont bool) {
err = c.CloseUpstreams()
if err != nil {
errs = append(errs, err)
}
return true
})
return errors.Join(errs...)
}

View File

@ -22,7 +22,7 @@ func newIDIndex(m []*Persistent) (ci *Index) {
return ci return ci
} }
func TestClientIndex(t *testing.T) { func TestClientIndex_Find(t *testing.T) {
const ( const (
cliIPNone = "1.2.3.4" cliIPNone = "1.2.3.4"
cliIP1 = "1.1.1.1" cliIP1 = "1.1.1.1"
@ -35,26 +35,49 @@ func TestClientIndex(t *testing.T) {
cliID = "client-id" cliID = "client-id"
cliMAC = "11:11:11:11:11:11" cliMAC = "11:11:11:11:11:11"
linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0"
linkLocalSubnet = "fe80::/16"
) )
clients := []*Persistent{{ var (
clientWithBothFams = &Persistent{
Name: "client1", Name: "client1",
IPs: []netip.Addr{ IPs: []netip.Addr{
netip.MustParseAddr(cliIP1), netip.MustParseAddr(cliIP1),
netip.MustParseAddr(cliIPv6), netip.MustParseAddr(cliIPv6),
}, },
}, { }
clientWithSubnet = &Persistent{
Name: "client2", Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, { }
clientWithMAC = &Persistent{
Name: "client_with_mac", Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, { }
clientWithID = &Persistent{
Name: "client_with_id", Name: "client_with_id",
ClientIDs: []string{cliID}, ClientIDs: []string{cliID},
}} }
clientLinkLocal = &Persistent{
Name: "client_link_local",
Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)},
}
)
clients := []*Persistent{
clientWithBothFams,
clientWithSubnet,
clientWithMAC,
clientWithID,
clientLinkLocal,
}
ci := newIDIndex(clients) ci := newIDIndex(clients)
testCases := []struct { testCases := []struct {
@ -64,19 +87,23 @@ func TestClientIndex(t *testing.T) {
}{{ }{{
name: "ipv4_ipv6", name: "ipv4_ipv6",
ids: []string{cliIP1, cliIPv6}, ids: []string{cliIP1, cliIPv6},
want: clients[0], want: clientWithBothFams,
}, { }, {
name: "ipv4_subnet", name: "ipv4_subnet",
ids: []string{cliIP2, cliSubnetIP}, ids: []string{cliIP2, cliSubnetIP},
want: clients[1], want: clientWithSubnet,
}, { }, {
name: "mac", name: "mac",
ids: []string{cliMAC}, ids: []string{cliMAC},
want: clients[2], want: clientWithMAC,
}, { }, {
name: "client_id", name: "client_id",
ids: []string{cliID}, ids: []string{cliID},
want: clients[3], want: clientWithID,
}, {
name: "client_link_local_subnet",
ids: []string{linkLocalIP},
want: clientLinkLocal,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
@ -221,3 +248,103 @@ func TestMACToKey(t *testing.T) {
_ = macToKey(mac) _ = macToKey(mac)
}) })
} }
func TestIndex_FindByIPWithoutZone(t *testing.T) {
var (
ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1")
ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2")
)
var (
clientNoZone = &Persistent{
Name: "client",
IPs: []netip.Addr{ip},
}
clientWithZone = &Persistent{
Name: "client_with_zone",
IPs: []netip.Addr{ipWithZone},
}
)
ci := newIDIndex([]*Persistent{
clientNoZone,
clientWithZone,
})
testCases := []struct {
ip netip.Addr
want *Persistent
name string
}{{
name: "without_zone",
ip: ip,
want: clientNoZone,
}, {
name: "with_zone",
ip: ipWithZone,
want: clientWithZone,
}, {
name: "zero_address",
ip: netip.Addr{},
want: nil,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
c := ci.FindByIPWithoutZone(tc.ip.WithZone(""))
require.Equal(t, tc.want, c)
})
}
}
func TestClientIndex_RangeByName(t *testing.T) {
sortedClients := []*Persistent{{
Name: "clientA",
ClientIDs: []string{"A"},
}, {
Name: "clientB",
ClientIDs: []string{"B"},
}, {
Name: "clientC",
ClientIDs: []string{"C"},
}, {
Name: "clientD",
ClientIDs: []string{"D"},
}, {
Name: "clientE",
ClientIDs: []string{"E"},
}}
testCases := []struct {
name string
want []*Persistent
}{{
name: "basic",
want: sortedClients,
}, {
name: "nil",
want: nil,
}, {
name: "one_element",
want: sortedClients[:1],
}, {
name: "two_elements",
want: sortedClients[:2],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ci := newIDIndex(tc.want)
var got []*Persistent
ci.RangeByName(func(c *Persistent) (cont bool) {
got = append(got, c)
return true
})
assert.Equal(t, tc.want, got)
})
}
}

View File

@ -64,8 +64,6 @@ type Persistent struct {
// upstream must be used. // upstream must be used.
UpstreamConfig *proxy.CustomUpstreamConfig UpstreamConfig *proxy.CustomUpstreamConfig
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client. // BlockedServices is the configuration of blocked services of a client.
@ -95,6 +93,9 @@ type Persistent struct {
UseOwnBlockedServices bool UseOwnBlockedServices bool
IgnoreQueryLog bool IgnoreQueryLog bool
IgnoreStatistics bool IgnoreStatistics bool
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
SafeSearchConf filtering.SafeSearchConfig
} }
// SetTags sets the tags if they are known, otherwise logs an unknown tag. // SetTags sets the tags if they are known, otherwise logs an unknown tag.

View File

@ -0,0 +1,63 @@
package client
import "net/netip"
// RuntimeIndex stores information about runtime clients.
type RuntimeIndex struct {
// index maps IP address to runtime client.
index map[netip.Addr]*Runtime
}
// NewRuntimeIndex returns initialized runtime index.
func NewRuntimeIndex() (ri *RuntimeIndex) {
return &RuntimeIndex{
index: map[netip.Addr]*Runtime{},
}
}
// Client returns the saved runtime client by ip. If no such client exists,
// returns nil.
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) {
return ri.index[ip]
}
// Add saves the runtime client in the index. IP address of a client must be
// unique. See [Runtime.Client]. rc must not be nil.
func (ri *RuntimeIndex) Add(rc *Runtime) {
ip := rc.Addr()
ri.index[ip] = rc
}
// Size returns the number of the runtime clients.
func (ri *RuntimeIndex) Size() (n int) {
return len(ri.index)
}
// Range calls f for each runtime client in an undefined order.
func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
for _, rc := range ri.index {
if !f(rc) {
return
}
}
}
// Delete removes the runtime client by ip.
func (ri *RuntimeIndex) Delete(ip netip.Addr) {
delete(ri.index, ip)
}
// DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients.
func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) {
for ip, rc := range ri.index {
rc.unset(src)
if rc.isEmpty() {
delete(ri.index, ip)
n++
}
}
return n
}

View File

@ -0,0 +1,85 @@
package client_test
import (
"net/netip"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/stretchr/testify/assert"
)
func TestRuntimeIndex(t *testing.T) {
const cliSrc = client.SourceARP
var (
ip1 = netip.MustParseAddr("1.1.1.1")
ip2 = netip.MustParseAddr("2.2.2.2")
ip3 = netip.MustParseAddr("3.3.3.3")
)
ri := client.NewRuntimeIndex()
currentSize := 0
testCases := []struct {
ip netip.Addr
name string
hosts []string
src client.Source
}{{
src: cliSrc,
ip: ip1,
name: "1",
hosts: []string{"host1"},
}, {
src: cliSrc,
ip: ip2,
name: "2",
hosts: []string{"host2"},
}, {
src: cliSrc,
ip: ip3,
name: "3",
hosts: []string{"host3"},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rc := client.NewRuntime(tc.ip)
rc.SetInfo(tc.src, tc.hosts)
ri.Add(rc)
currentSize++
got := ri.Client(tc.ip)
assert.Equal(t, rc, got)
})
}
t.Run("size", func(t *testing.T) {
assert.Equal(t, currentSize, ri.Size())
})
t.Run("range", func(t *testing.T) {
s := 0
ri.Range(func(rc *client.Runtime) (cont bool) {
s++
return true
})
assert.Equal(t, currentSize, s)
})
t.Run("delete", func(t *testing.T) {
ri.Delete(ip1)
currentSize--
assert.Equal(t, currentSize, ri.Size())
})
t.Run("delete_by_src", func(t *testing.T) {
assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc))
assert.Equal(t, 0, ri.Size())
})
}

View File

@ -1,5 +1,7 @@
package configmigrate package configmigrate
import "github.com/AdguardTeam/golibs/errors"
// migrateTo15 performs the following changes: // migrateTo15 performs the following changes:
// //
// # BEFORE: // # BEFORE:
@ -43,7 +45,7 @@ func migrateTo15(diskConf yobj) (err error) {
} }
diskConf["querylog"] = qlog diskConf["querylog"] = qlog
return coalesceError( return errors.Join(
moveVal[bool](dns, qlog, "querylog_enabled", "enabled"), moveVal[bool](dns, qlog, "querylog_enabled", "enabled"),
moveVal[bool](dns, qlog, "querylog_file_enabled", "file_enabled"), moveVal[bool](dns, qlog, "querylog_file_enabled", "file_enabled"),
moveVal[any](dns, qlog, "querylog_interval", "interval"), moveVal[any](dns, qlog, "querylog_interval", "interval"),

View File

@ -1,5 +1,7 @@
package configmigrate package configmigrate
import "github.com/AdguardTeam/golibs/errors"
// migrateTo24 performs the following changes: // migrateTo24 performs the following changes:
// //
// # BEFORE: // # BEFORE:
@ -28,7 +30,7 @@ func migrateTo24(diskConf yobj) (err error) {
diskConf["schema_version"] = 24 diskConf["schema_version"] = 24
logObj := yobj{} logObj := yobj{}
err = coalesceError( err = errors.Join(
moveVal[string](diskConf, logObj, "log_file", "file"), moveVal[string](diskConf, logObj, "log_file", "file"),
moveVal[int](diskConf, logObj, "log_max_backups", "max_backups"), moveVal[int](diskConf, logObj, "log_max_backups", "max_backups"),
moveVal[int](diskConf, logObj, "log_max_size", "max_size"), moveVal[int](diskConf, logObj, "log_max_size", "max_size"),

View File

@ -1,5 +1,7 @@
package configmigrate package configmigrate
import "github.com/AdguardTeam/golibs/errors"
// migrateTo26 performs the following changes: // migrateTo26 performs the following changes:
// //
// # BEFORE: // # BEFORE:
@ -78,7 +80,7 @@ func migrateTo26(diskConf yobj) (err error) {
} }
filteringObj := yobj{} filteringObj := yobj{}
err = coalesceError( err = errors.Join(
moveSameVal[bool](dns, filteringObj, "filtering_enabled"), moveSameVal[bool](dns, filteringObj, "filtering_enabled"),
moveSameVal[int](dns, filteringObj, "filters_update_interval"), moveSameVal[int](dns, filteringObj, "filters_update_interval"),
moveSameVal[bool](dns, filteringObj, "parental_enabled"), moveSameVal[bool](dns, filteringObj, "parental_enabled"),

View File

@ -1,5 +1,7 @@
package configmigrate package configmigrate
import "github.com/AdguardTeam/golibs/errors"
// migrateTo7 performs the following changes: // migrateTo7 performs the following changes:
// //
// # BEFORE: // # BEFORE:
@ -37,7 +39,7 @@ func migrateTo7(diskConf yobj) (err error) {
} }
dhcpv4 := yobj{} dhcpv4 := yobj{}
err = coalesceError( err = errors.Join(
moveSameVal[string](dhcp, dhcpv4, "gateway_ip"), moveSameVal[string](dhcp, dhcpv4, "gateway_ip"),
moveSameVal[string](dhcp, dhcpv4, "subnet_mask"), moveSameVal[string](dhcp, dhcpv4, "subnet_mask"),
moveSameVal[string](dhcp, dhcpv4, "range_start"), moveSameVal[string](dhcp, dhcpv4, "range_start"),

View File

@ -50,19 +50,3 @@ func moveVal[T any](src, dst yobj, srcKey, dstKey string) (err error) {
func moveSameVal[T any](src, dst yobj, key string) (err error) { func moveSameVal[T any](src, dst yobj, key string) (err error) {
return moveVal[T](src, dst, key, key) return moveVal[T](src, dst, key, key)
} }
// coalesceError returns the first non-nil error. It is named after function
// COALESCE in SQL. If all errors are nil, it returns nil.
//
// TODO(e.burkov): Replace with [errors.Join].
//
// TODO(a.garipov): Think of ways to merge with [aghalg.Coalesce].
func coalesceError(errors ...error) (res error) {
for _, err := range errors {
if err != nil {
return err
}
}
return nil
}

View File

@ -156,7 +156,10 @@ func (a *accessManager) isBlockedIP(ip netip.Addr) (blocked bool, rule string) {
} }
for _, ipnet := range ipnets { for _, ipnet := range ipnets {
if ipnet.Contains(ip) { // Remove zone before checking because prefixes stip zones.
//
// TODO(d.kolyshev): Cover with tests.
if ipnet.Contains(ip.WithZone("")) {
return blocked, ipnet.String() return blocked, ipnet.String()
} }
} }

View File

@ -0,0 +1,116 @@
package dnsforward
import (
"encoding/binary"
"fmt"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// type check
var _ proxy.BeforeRequestHandler = (*Server)(nil)
// HandleBefore is the handler that is called before any other processing,
// including logs. It performs access checks and puts the client ID, if there
// is one, into the server's cache.
//
// TODO(d.kolyshev): Extract to separate package.
func (s *Server) HandleBefore(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return &proxy.BeforeRequestError{
Err: fmt.Errorf("getting clientid: %w", err),
Response: s.NewMsgSERVFAIL(pctx.Req),
}
}
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
return s.preBlockedResponse(pctx)
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return nil
}
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}
// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil
}
hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return "", nil
}
cliSrvName, err := clientServerName(pctx, proto)
if err != nil {
return "", err
}
clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
}
// errAccessBlocked is a sentinel error returned when a request is blocked by
// access settings.
var errAccessBlocked errors.Error = "blocked by access settings"
// preBlockedResponse returns a protocol-appropriate response for a request that
// was blocked by access settings.
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (err error) {
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
// Return nil so that dnsproxy drops the connection and thus
// prevent DNS amplification attacks.
return errAccessBlocked
}
return &proxy.BeforeRequestError{
Err: errAccessBlocked,
Response: s.makeResponseREFUSED(pctx.Req),
}
}

View File

@ -0,0 +1,299 @@
package dnsforward
import (
"crypto/tls"
"net"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
blockedHost = "blockedhost.org"
testFQDN = "example.org."
dnsClientTimeout = 200 * time.Millisecond
)
func TestServer_HandleBefore_tls(t *testing.T) {
t.Parallel()
const clientID = "client-1"
testCases := []struct {
clientSrvName string
name string
host string
allowedClients []string
disallowedClients []string
blockedHosts []string
wantRCode int
}{{
clientSrvName: tlsServerName,
name: "allow_all",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: "%" + "." + tlsServerName,
name: "invalid_client_id",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeServerFailure,
}, {
clientSrvName: clientID + "." + tlsServerName,
name: "allowed_client_allowed",
host: testFQDN,
allowedClients: []string{clientID},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: "client-2." + tlsServerName,
name: "allowed_client_rejected",
host: testFQDN,
allowedClients: []string{clientID},
disallowedClients: []string{},
blockedHosts: []string{},
wantRCode: dns.RcodeRefused,
}, {
clientSrvName: tlsServerName,
name: "disallowed_client_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{clientID},
blockedHosts: []string{},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: clientID + "." + tlsServerName,
name: "disallowed_client_rejected",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{clientID},
blockedHosts: []string{},
wantRCode: dns.RcodeRefused,
}, {
clientSrvName: tlsServerName,
name: "blocked_hosts_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantRCode: dns.RcodeSuccess,
}, {
clientSrvName: tlsServerName,
name: "blocked_hosts_rejected",
host: dns.Fqdn(blockedHost),
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantRCode: dns.RcodeRefused,
}}
localAns := []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: testFQDN,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 3600,
Rdlength: 4,
},
A: net.IP{1, 2, 3, 4},
}}
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := (&dns.Msg{}).SetReply(req)
resp.Answer = localAns
require.NoError(t, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
s, _ := createTestTLS(t, TLSConfig{
TLSListenAddrs: []*net.TCPAddr{{}},
ServerName: tlsServerName,
})
s.conf.UpstreamDNS = []string{localUpsAddr}
s.conf.AllowedClients = tc.allowedClients
s.conf.DisallowedClients = tc.disallowedClients
s.conf.BlockedHosts = tc.blockedHosts
err := s.Prepare(&s.conf)
require.NoError(t, err)
startDeferStop(t, s)
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
ServerName: tc.clientSrvName,
}
client := &dns.Client{
Net: "tcp-tls",
TLSConfig: tlsConfig,
Timeout: dnsClientTimeout,
}
req := createTestMessage(tc.host)
addr := s.dnsProxy.Addr(proxy.ProtoTLS).String()
reply, _, err := client.Exchange(req, addr)
require.NoError(t, err)
assert.Equal(t, tc.wantRCode, reply.Rcode)
if tc.wantRCode == dns.RcodeSuccess {
assert.Equal(t, localAns, reply.Answer)
} else {
assert.Empty(t, reply.Answer)
}
})
}
}
func TestServer_HandleBefore_udp(t *testing.T) {
t.Parallel()
const (
clientIPv4 = "127.0.0.1"
clientIPv6 = "::1"
)
clientIPs := []string{clientIPv4, clientIPv6}
testCases := []struct {
name string
host string
allowedClients []string
disallowedClients []string
blockedHosts []string
wantTimeout bool
}{{
name: "allow_all",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{},
wantTimeout: false,
}, {
name: "allowed_client_allowed",
host: testFQDN,
allowedClients: clientIPs,
disallowedClients: []string{},
blockedHosts: []string{},
wantTimeout: false,
}, {
name: "allowed_client_rejected",
host: testFQDN,
allowedClients: []string{"1:2:3::4"},
disallowedClients: []string{},
blockedHosts: []string{},
wantTimeout: true,
}, {
name: "disallowed_client_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{"1:2:3::4"},
blockedHosts: []string{},
wantTimeout: false,
}, {
name: "disallowed_client_rejected",
host: testFQDN,
allowedClients: []string{},
disallowedClients: clientIPs,
blockedHosts: []string{},
wantTimeout: true,
}, {
name: "blocked_hosts_allowed",
host: testFQDN,
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantTimeout: false,
}, {
name: "blocked_hosts_rejected",
host: dns.Fqdn(blockedHost),
allowedClients: []string{},
disallowedClients: []string{},
blockedHosts: []string{blockedHost},
wantTimeout: true,
}}
localAns := []dns.RR{&dns.A{
Hdr: dns.RR_Header{
Name: testFQDN,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 3600,
Rdlength: 4,
},
A: net.IP{1, 2, 3, 4},
}}
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := (&dns.Msg{}).SetReply(req)
resp.Answer = localAns
require.NoError(t, w.WriteMsg(resp))
})
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
AllowedClients: tc.allowedClients,
DisallowedClients: tc.disallowedClients,
BlockedHosts: tc.blockedHosts,
UpstreamDNS: []string{localUpsAddr},
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
ServePlainDNS: true,
})
startDeferStop(t, s)
client := &dns.Client{
Net: "udp",
Timeout: dnsClientTimeout,
}
req := createTestMessage(tc.host)
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
reply, _, err := client.Exchange(req, addr)
if tc.wantTimeout {
wantErr := &net.OpError{}
require.ErrorAs(t, err, &wantErr)
assert.True(t, wantErr.Timeout())
assert.Nil(t, reply)
} else {
require.NoError(t, err)
require.NotNil(t, reply)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.Equal(t, localAns, reply.Answer)
}
})
}
}

View File

@ -110,46 +110,6 @@ type quicConnection interface {
ConnectionState() (cs quic.ConnectionState) ConnectionState() (cs quic.ConnectionState)
} }
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}
// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil
}
hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return "", nil
}
cliSrvName, err := clientServerName(pctx, proto)
if err != nil {
return "", err
}
clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
}
// clientServerName returns the TLS server name based on the protocol. For // clientServerName returns the TLS server name based on the protocol. For
// DNS-over-HTTPS requests, it will return the hostname part of the Host header // DNS-over-HTTPS requests, it will return the hostname part of the Host header
// if there is one. // if there is one.

View File

@ -235,9 +235,18 @@ type DNSCryptConfig struct {
// ServerConfig represents server configuration. // ServerConfig represents server configuration.
// The zero ServerConfig is empty and ready for use. // The zero ServerConfig is empty and ready for use.
type ServerConfig struct { type ServerConfig struct {
UDPListenAddrs []*net.UDPAddr // UDP listen address // UDPListenAddrs is the list of addresses to listen for DNS-over-UDP.
TCPListenAddrs []*net.TCPAddr // TCP listen address UDPListenAddrs []*net.UDPAddr
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
// TCPListenAddrs is the list of addresses to listen for DNS-over-TCP.
TCPListenAddrs []*net.TCPAddr
// UpstreamConfig is the general configuration of upstream DNS servers.
UpstreamConfig *proxy.UpstreamConfig
// PrivateRDNSUpstreamConfig is the configuration of upstream DNS servers
// for private reverse DNS.
PrivateRDNSUpstreamConfig *proxy.UpstreamConfig
// AddrProcConf defines the configuration for the client IP processor. // AddrProcConf defines the configuration for the client IP processor.
// If nil, [client.EmptyAddrProc] is used. // If nil, [client.EmptyAddrProc] is used.
@ -317,13 +326,17 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
CacheMaxTTL: srvConf.CacheMaxTTL, CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic, CacheOptimistic: srvConf.CacheOptimistic,
UpstreamConfig: srvConf.UpstreamConfig, UpstreamConfig: srvConf.UpstreamConfig,
BeforeRequestHandler: s.beforeRequestHandler, PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
BeforeRequestHandler: s,
RequestHandler: s.handleDNSRequest, RequestHandler: s.handleDNSRequest,
HTTPSServerName: aghhttp.UserAgent(), HTTPSServerName: aghhttp.UserAgent(),
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled, EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
MaxGoroutines: srvConf.MaxGoroutines, MaxGoroutines: srvConf.MaxGoroutines,
UseDNS64: srvConf.UseDNS64, UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes, DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS,
PrivateSubnets: s.privateNets,
MessageConstructor: s,
} }
if srvConf.EDNSClientSubnet.UseCustom { if srvConf.EDNSClientSubnet.UseCustom {
@ -452,12 +465,33 @@ func (s *Server) prepareIpsetListSettings() (err error) {
} }
ipsets := stringutil.SplitTrimmed(string(data), "\n") ipsets := stringutil.SplitTrimmed(string(data), "\n")
ipsets = stringutil.FilterOut(ipsets, IsCommentOrEmpty)
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn) log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
return s.ipset.init(ipsets) return s.ipset.init(ipsets)
} }
// loadUpstreams parses upstream DNS servers from the configured file or from
// the configuration itself.
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
if conf.UpstreamDNSFileName == "" {
return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil
}
var data []byte
data, err = os.ReadFile(conf.UpstreamDNSFileName)
if err != nil {
return nil, fmt.Errorf("reading upstream from file: %w", err)
}
upstreams = stringutil.SplitTrimmed(string(data), "\n")
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
}
// collectListenAddr adds addrPort to addrs. It also adds its port to // collectListenAddr adds addrPort to addrs. It also adds its port to
// unspecPorts if its address is unspecified. // unspecPorts if its address is unspecified.
func collectListenAddr( func collectListenAddr(
@ -529,8 +563,8 @@ func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr()) return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr())
} }
// filterOut filters out all the upstreams that match um. It returns all the // filterOutAddrs filters out all the upstreams that match um. It returns all
// closing errors joined. // the closing errors joined.
func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) { func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) {
var errs []error var errs []error
delFunc := func(u upstream.Upstream) (ok bool) { delFunc := func(u upstream.Upstream) (ok bool) {

View File

@ -3,7 +3,6 @@ package dnsforward
import ( import (
"net" "net"
"testing" "testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -11,6 +10,7 @@ import (
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -64,6 +64,8 @@ func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns
} }
func TestServer_HandleDNSRequest_dns64(t *testing.T) { func TestServer_HandleDNSRequest_dns64(t *testing.T) {
t.Parallel()
const ( const (
ipv4Domain = "ipv4.only." ipv4Domain = "ipv4.only."
ipv6Domain = "ipv6.only." ipv6Domain = "ipv6.only."
@ -252,33 +254,33 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
require.Len(pt, m.Question, 1) require.Len(pt, m.Question, 1)
require.Equal(pt, m.Question[0].Name, ptr64Domain) require.Equal(pt, m.Question[0].Name, ptr64Domain)
resp := (&dns.Msg{
Answer: []dns.RR{localRR}, resp := (&dns.Msg{}).SetReply(m)
}).SetReply(m) resp.Answer = []dns.RR{localRR}
require.NoError(t, w.WriteMsg(resp)) require.NoError(t, w.WriteMsg(resp))
}) })
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
client := &dns.Client{ client := &dns.Client{
Net: "tcp", Net: string(proxy.ProtoTCP),
Timeout: 1 * time.Second, Timeout: testTimeout,
} }
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel()
upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0] q := req.Question[0]
require.Contains(pt, tc.upsAns, q.Qtype)
require.Contains(pt, tc.upsAns, q.Qtype)
answer := tc.upsAns[q.Qtype] answer := tc.upsAns[q.Qtype]
resp := (&dns.Msg{ resp := (&dns.Msg{}).SetReply(req)
Answer: answer[sectionAnswer], resp.Answer = answer[sectionAnswer]
Ns: answer[sectionAuthority], resp.Ns = answer[sectionAuthority]
Extra: answer[sectionAdditional], resp.Extra = answer[sectionAdditional]
}).SetReply(req)
require.NoError(pt, w.WriteMsg(resp)) require.NoError(pt, w.WriteMsg(resp))
}) })
@ -308,10 +310,54 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype) req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
resp, _, excErr := client.Exchange(req, s.dnsProxy.Addr(proxy.ProtoTCP).String()) resp, _, excErr := client.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
require.NoError(t, excErr) require.NoError(t, excErr)
require.Equal(t, tc.wantAns, resp.Answer) require.Equal(t, tc.wantAns, resp.Answer)
}) })
} }
} }
func TestServer_dns64WithDisabledRDNS(t *testing.T) {
t.Parallel()
// Shouldn't go to upstream at all.
panicHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
panic("not implemented")
})
upsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
localUpsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
UpstreamDNS: []string{upsAddr},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
})
startDeferStop(t, s)
mappedIPv6 := net.ParseIP("64:ff9b::102:304")
arpa, err := netutil.IPToReversedAddr(mappedIPv6)
require.NoError(t, err)
req := (&dns.Msg{}).SetQuestion(dns.Fqdn(arpa), dns.TypePTR)
cli := &dns.Client{
Net: string(proxy.ProtoTCP),
Timeout: testTimeout,
}
resp, _, err := cli.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
require.NoError(t, err)
assert.Equal(t, dns.RcodeNameError, resp.Rcode)
}

View File

@ -2,6 +2,7 @@
package dnsforward package dnsforward
import ( import (
"cmp"
"context" "context"
"fmt" "fmt"
"io" "io"
@ -15,7 +16,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -135,12 +135,6 @@ type Server struct {
// WHOIS, etc. // WHOIS, etc.
addrProc client.AddressProcessor addrProc client.AddressProcessor
// localResolvers is a DNS proxy instance used to resolve PTR records for
// addresses considered private as per the [privateNets].
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
localResolvers *proxy.Proxy
// sysResolvers used to fetch system resolvers to use by default for private // sysResolvers used to fetch system resolvers to use by default for private
// PTR resolving. // PTR resolving.
sysResolvers SystemResolvers sysResolvers SystemResolvers
@ -158,12 +152,6 @@ type Server struct {
// [upstream.Resolver] interface. // [upstream.Resolver] interface.
bootResolvers []*upstream.UpstreamResolver bootResolvers []*upstream.UpstreamResolver
// recDetector is a cache for recursive requests. It is used to detect and
// prevent recursive requests only for private upstreams.
//
// See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135.
recDetector *recursionDetector
// dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major // dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major
// part of DNS64 happens inside the [proxy] package, but there still are // part of DNS64 happens inside the [proxy] package, but there still are
// some places where response mapping is needed (e.g. DHCP). // some places where response mapping is needed (e.g. DHCP).
@ -212,14 +200,6 @@ type DNSCreateParams struct {
LocalDomain string LocalDomain string
} }
const (
// recursionTTL is the time recursive request is cached for.
recursionTTL = 1 * time.Second
// cachedRecurrentReqNum is the maximum number of cached recurrent
// requests.
cachedRecurrentReqNum = 1000
)
// NewServer creates a new instance of the dnsforward.Server // NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once // Note: this function must be called only once
// //
@ -256,7 +236,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// TODO(e.burkov): Use some case-insensitive string comparison. // TODO(e.burkov): Use some case-insensitive string comparison.
localDomainSuffix: strings.ToLower(localDomainSuffix), localDomainSuffix: strings.ToLower(localDomainSuffix),
etcHosts: etcHosts, etcHosts: etcHosts,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{ clientIDCache: cache.New(cache.Config{
EnableLRU: true, EnableLRU: true,
MaxCount: defaultClientIDCacheCount, MaxCount: defaultClientIDCacheCount,
@ -366,6 +345,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
// TODO(e.burkov): Migrate to [netip.Addr] already.
arpa, err := netutil.IPToReversedAddr(ip.AsSlice()) arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil { if err != nil {
return "", 0, fmt.Errorf("reversing ip: %w", err) return "", 0, fmt.Errorf("reversing ip: %w", err)
@ -386,25 +366,23 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
} }
dctx := &proxy.DNSContext{ dctx := &proxy.DNSContext{
Proto: "udp", Proto: proxy.ProtoUDP,
Req: req, Req: req,
IsPrivateClient: true,
} }
var resolver *proxy.Proxy
var errMsg string var errMsg string
if s.privateNets.Contains(ip) { if s.privateNets.Contains(ip) {
if !s.conf.UsePrivateRDNS { if !s.conf.UsePrivateRDNS {
return "", 0, nil return "", 0, nil
} }
resolver = s.localResolvers
errMsg = "resolving a private address: %w" errMsg = "resolving a private address: %w"
s.recDetector.add(*req) dctx.RequestedPrivateRDNS = netip.PrefixFrom(ip, ip.BitLen())
} else { } else {
resolver = s.internalProxy
errMsg = "resolving an address: %w" errMsg = "resolving an address: %w"
} }
if err = resolver.Resolve(dctx); err != nil { if err = s.internalProxy.Resolve(dctx); err != nil {
return "", 0, fmt.Errorf(errMsg, err) return "", 0, fmt.Errorf(errMsg, err)
} }
@ -473,103 +451,6 @@ func (s *Server) startLocked() error {
return err return err
} }
// prepareLocalResolvers initializes the local upstreams configuration using
// boot as bootstrap. It assumes that s.serverLock is locked or s not running.
func (s *Server) prepareLocalResolvers(
boot upstream.Resolver,
) (uc *proxy.UpstreamConfig, err error) {
set, err := s.conf.ourAddrsSet()
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
resolvers := s.conf.LocalPTRResolvers
confNeedsFiltering := len(resolvers) > 0
if confNeedsFiltering {
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
} else {
sysResolvers := slices.DeleteFunc(slices.Clone(s.sysResolvers.Addrs()), set.Has)
resolvers = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
resolvers = append(resolvers, r.String())
}
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
uc, err = s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
Bootstrap: boot,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
})
if err != nil {
return nil, fmt.Errorf("preparing private upstreams: %w", err)
}
if confNeedsFiltering {
err = filterOutAddrs(uc, set)
if err != nil {
return nil, fmt.Errorf("filtering private upstreams: %w", err)
}
}
return uc, nil
}
// LocalResolversError is an error type for errors during local resolvers setup.
// This is only needed to distinguish these errors from errors returned by
// creating the proxy.
type LocalResolversError struct {
Err error
}
// type check
var _ error = (*LocalResolversError)(nil)
// Error implements the error interface for *LocalResolversError.
func (err *LocalResolversError) Error() (s string) {
return fmt.Sprintf("creating local resolvers: %s", err.Err)
}
// type check
var _ errors.Wrapper = (*LocalResolversError)(nil)
// Unwrap implements the [errors.Wrapper] interface for *LocalResolversError.
func (err *LocalResolversError) Unwrap() error {
return err.Err
}
// setupLocalResolvers initializes and sets the resolvers for local addresses.
// It assumes s.serverLock is locked or s not running. It returns the upstream
// configuration used for private PTR resolving, or nil if it's disabled. Note,
// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
return nil, nil
}
uc, err = s.prepareLocalResolvers(boot)
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
localResolvers, err := proxy.New(&proxy.Config{
UpstreamConfig: uc,
})
if err != nil {
return nil, &LocalResolversError{Err: err}
}
s.localResolvers = localResolvers
// TODO(e.burkov): Should we also consider the DNS64 usage?
return uc, nil
}
// Prepare initializes parameters of s using data from conf. conf must not be // Prepare initializes parameters of s using data from conf. conf must not be
// nil. // nil.
func (s *Server) Prepare(conf *ServerConfig) (err error) { func (s *Server) Prepare(conf *ServerConfig) (err error) {
@ -586,7 +467,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.initDefaultSettings() s.initDefaultSettings()
boot, err := s.prepareInternalDNS() err = s.prepareInternalDNS()
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return err return err
@ -608,12 +489,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return fmt.Errorf("preparing access: %w", err) return fmt.Errorf("preparing access: %w", err)
} }
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
if err != nil {
return fmt.Errorf("setting up resolvers: %w", err)
}
proxyConfig.Fallbacks, err = s.setupFallbackDNS() proxyConfig.Fallbacks, err = s.setupFallbackDNS()
if err != nil { if err != nil {
return fmt.Errorf("setting up fallback dns servers: %w", err) return fmt.Errorf("setting up fallback dns servers: %w", err)
@ -626,8 +501,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.dnsProxy = dnsProxy s.dnsProxy = dnsProxy
s.recDetector.clear()
s.setupAddrProc() s.setupAddrProc()
s.registerHandlers() s.registerHandlers()
@ -635,36 +508,127 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return nil return nil
} }
// prepareInternalDNS initializes the internal state of s before initializing // prepareUpstreamSettings sets upstream DNS server settings.
// the primary DNS proxy instance. It assumes s.serverLock is locked or the func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
// Server not running. // Load upstreams either from the file, or from the settings
func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) { var upstreams []string
err = s.prepareIpsetListSettings() upstreams, err = s.conf.loadUpstreams()
if err != nil { if err != nil {
return nil, fmt.Errorf("preparing ipset settings: %w", err) return fmt.Errorf("loading upstreams: %w", err)
} }
s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{ uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Timeout: DefaultTimeout, Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
// Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're
// loading roots manually and pass it here.
//
// See [aghtls.SystemRootCAs].
//
// TODO(a.garipov): Investigate if that's true.
RootCAs: s.conf.TLSv12Roots,
CipherSuites: s.conf.TLSCiphers,
}) })
if err != nil {
return fmt.Errorf("preparing upstream config: %w", err)
}
s.conf.UpstreamConfig = uc
return nil
}
// PrivateRDNSError is returned when the private rDNS upstreams are
// invalid but enabled.
//
// TODO(e.burkov): Consider allowing to use incomplete private rDNS upstreams
// configuration in proxy when the private rDNS function is enabled. In theory,
// proxy supports the case when no upstreams provided to resolve the private
// request, since it already supports this for DNS64-prefixed PTR requests.
type PrivateRDNSError struct {
err error
}
// Error implements the [errors.Error] interface.
func (e *PrivateRDNSError) Error() (s string) {
return e.err.Error()
}
func (e *PrivateRDNSError) Unwrap() (err error) {
return e.err
}
// prepareLocalResolvers initializes the private RDNS upstream configuration
// according to the server's settings. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
return nil, nil
}
var ownAddrs addrPortSet
ownAddrs, err = s.conf.ourAddrsSet()
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return nil, err return nil, err
} }
opts := &upstream.Options{
Bootstrap: s.bootstrap,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
}
addrs := s.conf.LocalPTRResolvers
uc, err = newPrivateConfig(addrs, ownAddrs, s.sysResolvers, s.privateNets, opts)
if err != nil {
return nil, fmt.Errorf("preparing resolvers: %w", err)
}
return uc, nil
}
// prepareInternalDNS initializes the internal state of s before initializing
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareInternalDNS() (err error) {
err = s.prepareIpsetListSettings()
if err != nil {
return fmt.Errorf("preparing ipset settings: %w", err)
}
bootOpts := &upstream.Options{
Timeout: DefaultTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
}
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
err = s.prepareUpstreamSettings(s.bootstrap) err = s.prepareUpstreamSettings(s.bootstrap)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return s.bootstrap, err return err
}
s.conf.PrivateRDNSUpstreamConfig, err = s.prepareLocalResolvers()
if err != nil {
return err
} }
err = s.prepareInternalProxy() err = s.prepareInternalProxy()
if err != nil { if err != nil {
return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err) return fmt.Errorf("preparing internal proxy: %w", err)
} }
return s.bootstrap, nil return nil
} }
// setupFallbackDNS initializes the fallback DNS servers. // setupFallbackDNS initializes the fallback DNS servers.
@ -745,8 +709,14 @@ func (s *Server) prepareInternalProxy() (err error) {
conf := &proxy.Config{ conf := &proxy.Config{
CacheEnabled: true, CacheEnabled: true,
CacheSizeBytes: 4096, CacheSizeBytes: 4096,
PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
UpstreamConfig: srvConf.UpstreamConfig, UpstreamConfig: srvConf.UpstreamConfig,
MaxGoroutines: s.conf.MaxGoroutines, MaxGoroutines: srvConf.MaxGoroutines,
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS,
PrivateSubnets: s.privateNets,
MessageConstructor: s,
} }
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration) err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
@ -782,11 +752,6 @@ func (s *Server) stopLocked() (err error) {
} }
} }
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
if s.localResolvers != nil {
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
}
for _, b := range s.bootResolvers { for _, b := range s.bootResolvers {
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address()) logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
} }
@ -908,5 +873,5 @@ func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool,
blocked = true blocked = true
} }
return blocked, aghalg.Coalesce(rule, clientID) return blocked, cmp.Or(rule, clientID)
} }

View File

@ -1,7 +1,7 @@
package dnsforward package dnsforward
import ( import (
"context" "cmp"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
@ -21,7 +21,6 @@ import (
"testing/fstest" "testing/fstest"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -190,7 +189,7 @@ func newGoogleUpstream() (u upstream.Upstream) {
return &aghtest.UpstreamMock{ return &aghtest.UpstreamMock{
OnAddress: func() (addr string) { return "google.upstream.example" }, OnAddress: func() (addr string) { return "google.upstream.example" },
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce( return cmp.Or(
aghtest.MatchedResponse(req, dns.TypeA, googleDomainName, "8.8.8.8"), aghtest.MatchedResponse(req, dns.TypeA, googleDomainName, "8.8.8.8"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil ), nil
@ -253,7 +252,7 @@ func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
for i := 0; i < testMessagesCount; i++ { for range testMessagesCount {
msg := createGoogleATestMessage() msg := createGoogleATestMessage()
wg.Add(1) wg.Add(1)
@ -276,7 +275,7 @@ func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) {
func sendTestMessages(t *testing.T, conn *dns.Conn) { func sendTestMessages(t *testing.T, conn *dns.Conn) {
t.Helper() t.Helper()
for i := 0; i < testMessagesCount; i++ { for i := range testMessagesCount {
req := createGoogleATestMessage() req := createGoogleATestMessage()
err := conn.WriteMsg(req) err := conn.WriteMsg(req)
assert.NoErrorf(t, err, "cannot write message #%d: %s", i, err) assert.NoErrorf(t, err, "cannot write message #%d: %s", i, err)
@ -491,19 +490,10 @@ func TestServerRace(t *testing.T) {
} }
func TestSafeSearch(t *testing.T) { func TestSafeSearch(t *testing.T) {
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
safeSearchConf := filtering.SafeSearchConfig{ safeSearchConf := filtering.SafeSearchConfig{
Enabled: true, Enabled: true,
Google: true, Google: true,
Yandex: true, Yandex: true,
CustomResolver: resolver,
} }
filterConf := &filtering.Config{ filterConf := &filtering.Config{
@ -540,7 +530,6 @@ func TestSafeSearch(t *testing.T) {
client := &dns.Client{} client := &dns.Client{}
yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56}) yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56})
googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
testCases := []struct { testCases := []struct {
host string host string
@ -564,19 +553,19 @@ func TestSafeSearch(t *testing.T) {
wantCNAME: "", wantCNAME: "",
}, { }, {
host: "www.google.com.", host: "www.google.com.",
want: googleIP, want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.", wantCNAME: "forcesafesearch.google.com.",
}, { }, {
host: "www.google.com.af.", host: "www.google.com.af.",
want: googleIP, want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.", wantCNAME: "forcesafesearch.google.com.",
}, { }, {
host: "www.google.be.", host: "www.google.be.",
want: googleIP, want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.", wantCNAME: "forcesafesearch.google.com.",
}, { }, {
host: "www.google.by.", host: "www.google.by.",
want: googleIP, want: netip.Addr{},
wantCNAME: "forcesafesearch.google.com.", wantCNAME: "forcesafesearch.google.com.",
}} }}
@ -593,12 +582,15 @@ func TestSafeSearch(t *testing.T) {
cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0]) cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0])
assert.Equal(t, tc.wantCNAME, cname.Target) assert.Equal(t, tc.wantCNAME, cname.Target)
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[1])
assert.NotEmpty(t, a.A)
} else { } else {
require.Len(t, reply.Answer, 1) require.Len(t, reply.Answer, 1)
}
a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[len(reply.Answer)-1]) a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[0])
assert.Equal(t, net.IP(tc.want.AsSlice()), a.A) assert.Equal(t, net.IP(tc.want.AsSlice()), a.A)
}
}) })
} }
} }
@ -691,7 +683,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
atomic.AddUint32(&upsCalledCounter, 1) atomic.AddUint32(&upsCalledCounter, 1)
return aghalg.Coalesce( return cmp.Or(
aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"), aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil ), nil
@ -1152,7 +1144,7 @@ func TestRewrite(t *testing.T) {
})) }))
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
return aghalg.Coalesce( return cmp.Or(
aghtest.MatchedResponse(req, dns.TypeA, "example.org", "4.3.2.1"), aghtest.MatchedResponse(req, dns.TypeA, "example.org", "4.3.2.1"),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), new(dns.Msg).SetRcode(req, dns.RcodeNameError),
), nil ), nil
@ -1481,7 +1473,7 @@ func TestServer_Exchange(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce( resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)), aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)),
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))), doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), new(dns.Msg).SetRcode(req, dns.RcodeNameError),
@ -1495,7 +1487,7 @@ func TestServer_Exchange(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce( resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)), aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), new(dns.Msg).SetRcode(req, dns.RcodeNameError),
) )

View File

@ -143,7 +143,7 @@ func (s *Server) filterDNSRewrite(
res *filtering.Result, res *filtering.Result,
pctx *proxy.DNSContext, pctx *proxy.DNSContext,
) (err error) { ) (err error) {
resp := s.makeResponse(req) resp := s.replyCompressed(req)
dnsrr := res.DNSRewriteResult dnsrr := res.DNSRewriteResult
if dnsrr == nil { if dnsrr == nil {
return errors.Error("no dns rewrite rule content") return errors.Error("no dns rewrite rule content")

View File

@ -1,57 +1,17 @@
package dnsforward package dnsforward
import ( import (
"encoding/binary"
"fmt" "fmt"
"net" "net"
"slices" "slices"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// beforeRequestHandler is the handler that is called before any other
// processing, including logs. It performs access checks and puts the client
// ID, if there is one, into the server's cache.
func (s *Server) beforeRequestHandler(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (reply bool, err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return false, fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
return s.preBlockedResponse(pctx)
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return true, nil
}
// clientRequestFilteringSettings looks up client filtering settings using the // clientRequestFilteringSettings looks up client filtering settings using the
// client's IP address and ID, if any, from dctx. // client's IP address and ID, if any, from dctx.
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) { func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
@ -71,6 +31,7 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
req := pctx.Req req := pctx.Req
q := req.Question[0] q := req.Question[0]
host := strings.TrimSuffix(q.Name, ".") host := strings.TrimSuffix(q.Name, ".")
resVal, err := s.dnsFilter.CheckHost(host, q.Qtype, dctx.setts) resVal, err := s.dnsFilter.CheckHost(host, q.Qtype, dctx.setts)
if err != nil { if err != nil {
return nil, fmt.Errorf("checking host %q: %w", host, err) return nil, fmt.Errorf("checking host %q: %w", host, err)
@ -79,22 +40,15 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
// TODO(a.garipov): Make CheckHost return a pointer. // TODO(a.garipov): Make CheckHost return a pointer.
res = &resVal res = &resVal
switch { switch {
case res.IsFiltered: case isRewrittenCNAME(res):
log.Debug(
"dnsforward: host %q is filtered, reason: %q; rule: %q",
host,
res.Reason,
res.Rules[0].Text,
)
pctx.Res = s.genDNSFilterMessage(pctx, res)
case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
res.CanonName != "" &&
len(res.IPList) == 0:
// Resolve the new canonical name, not the original host name. The // Resolve the new canonical name, not the original host name. The
// original question is readded in processFilteringAfterResponse. // original question is readded in processFilteringAfterResponse.
dctx.origQuestion = q dctx.origQuestion = q
req.Question[0].Name = dns.Fqdn(res.CanonName) req.Question[0].Name = dns.Fqdn(res.CanonName)
case res.Reason == filtering.Rewritten: case res.IsFiltered:
log.Debug("dnsforward: host %q is filtered, reason: %q", host, res.Reason)
pctx.Res = s.genDNSFilterMessage(pctx, res)
case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch):
pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName) pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName)
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts): case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
if err = s.filterDNSRewrite(req, res, pctx); err != nil { if err = s.filterDNSRewrite(req, res, pctx); err != nil {
@ -105,6 +59,17 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
return res, err return res, err
} }
// isRewrittenCNAME returns true if the request considered to be rewritten with
// CNAME and has no resolved IPs.
func isRewrittenCNAME(res *filtering.Result) (ok bool) {
return res.Reason.In(
filtering.Rewritten,
filtering.RewrittenRule,
filtering.FilteredSafeSearch) &&
res.CanonName != "" &&
len(res.IPList) == 0
}
// checkHostRules checks the host against filters. It is safe for concurrent // checkHostRules checks the host against filters. It is safe for concurrent
// use. // use.
func (s *Server) checkHostRules( func (s *Server) checkHostRules(

View File

@ -1,6 +1,7 @@
package dnsforward package dnsforward
import ( import (
"cmp"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -261,55 +262,17 @@ func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
} }
} }
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
}
// checkFallbacks returns an error if any fallback address is invalid.
func (req *jsonDNSConfig) checkFallbacks() (err error) {
if req.Fallbacks == nil {
return nil
}
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
if err != nil {
return fmt.Errorf("fallback servers: %w", err)
}
return nil
}
// validate returns an error if any field of req is invalid. // validate returns an error if any field of req is invalid.
// //
// TODO(s.chzhen): Parse, don't validate. // TODO(s.chzhen): Parse, don't validate.
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) { func (req *jsonDNSConfig) validate(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
defer func() { err = errors.Annotate(err, "validating dns config: %w") }() defer func() { err = errors.Annotate(err, "validating dns config: %w") }()
err = req.validateUpstreamDNSServers(privateNets) err = req.validateUpstreamDNSServers(ownAddrs, sysResolvers, privateNets)
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. // Don't wrap the error since it's informative enough as is.
return err return err
@ -342,20 +305,77 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
return nil return nil
} }
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
}
// checkPrivateRDNS returns an error if the configuration of the private RDNS is
// not valid.
func (req *jsonDNSConfig) checkPrivateRDNS(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
if (req.UsePrivateRDNS == nil || !*req.UsePrivateRDNS) && req.LocalPTRUpstreams == nil {
return nil
}
addrs := cmp.Or(req.LocalPTRUpstreams, &[]string{})
uc, err := newPrivateConfig(*addrs, ownAddrs, sysResolvers, privateNets, &upstream.Options{})
err = errors.WithDeferred(err, uc.Close())
if err != nil {
return fmt.Errorf("private upstream servers: %w", err)
}
return nil
}
// validateUpstreamDNSServers returns an error if any field of req is invalid. // validateUpstreamDNSServers returns an error if any field of req is invalid.
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) { func (req *jsonDNSConfig) validateUpstreamDNSServers(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
var uc *proxy.UpstreamConfig
opts := &upstream.Options{}
if req.Upstreams != nil { if req.Upstreams != nil {
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{}) uc, err = proxy.ParseUpstreamsConfig(*req.Upstreams, opts)
err = errors.WithDeferred(err, uc.Close())
if err != nil { if err != nil {
return fmt.Errorf("upstream servers: %w", err) return fmt.Errorf("upstream servers: %w", err)
} }
} }
if req.LocalPTRUpstreams != nil { err = req.checkPrivateRDNS(ownAddrs, sysResolvers, privateNets)
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets)
if err != nil { if err != nil {
return fmt.Errorf("private upstream servers: %w", err) // Don't wrap the error since it's informative enough as is.
} return err
} }
err = req.checkBootstrap() err = req.checkBootstrap()
@ -364,10 +384,12 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS
return err return err
} }
err = req.checkFallbacks() if req.Fallbacks != nil {
uc, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, opts)
err = errors.WithDeferred(err, uc.Close())
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. return fmt.Errorf("fallback servers: %w", err)
return err }
} }
return nil return nil
@ -436,7 +458,16 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
err = req.validate(s.privateNets) // TODO(e.burkov): Consider prebuilding this set on startup.
ourAddrs, err := s.conf.ourAddrsSet()
if err != nil {
// TODO(e.burkov): Put into openapi.
aghhttp.Error(r, w, http.StatusInternalServerError, "getting our addresses: %s", err)
return
}
err = req.validate(ourAddrs, s.sysResolvers, s.privateNets)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@ -587,7 +618,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
} }
var boots []*upstream.UpstreamResolver var boots []*upstream.UpstreamResolver
opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts) opts.Bootstrap, boots, err = newBootstrap(req.BootstrapDNS, s.etcHosts, opts)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err)

View File

@ -245,9 +245,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "", wantSet: "",
}, { }, {
name: "local_ptr_upstreams_bad", name: "local_ptr_upstreams_bad",
wantSet: `validating dns config: ` + wantSet: `validating dns config: private upstream servers: ` +
`private upstream servers: checking domain-specific upstreams: ` + `bad arpa domain name "non.arpa": not a reversed ip network`,
`bad arpa domain name "non.arpa.": not a reversed ip network`,
}, { }, {
name: "local_ptr_upstreams_null", name: "local_ptr_upstreams_null",
wantSet: "", wantSet: "",
@ -318,58 +317,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
} }
} }
func TestValidateUpstreamsPrivate(t *testing.T) {
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct {
name string
wantErr string
u string
}{{
name: "success_address",
wantErr: ``,
u: "[/1.0.0.127.in-addr.arpa/]#",
}, {
name: "success_subnet",
wantErr: ``,
u: "[/127.in-addr.arpa/]#",
}, {
name: "not_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`bad arpa domain name "hello.world.": not a reversed ip network`,
u: "[/hello.world/]#",
}, {
name: "non-private_arpa_address",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`,
u: "[/1.2.3.4.in-addr.arpa/]#",
}, {
name: "non-private_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "128.in-addr.arpa." should point to a locally-served network`,
u: "[/128.in-addr.arpa/]#",
}, {
name: "several_bad",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network` + "\n" +
`bad arpa domain name "non.arpa.": not a reversed ip network`,
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
}, {
name: "partial_good",
wantErr: "",
u: "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#",
}}
for _, tc := range testCases {
set := []string{"192.168.0.1", tc.u}
t.Run(tc.name, func(t *testing.T) {
err := ValidateUpstreamsPrivate(set, ss)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}
}
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) { func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
t.Helper() t.Helper()

View File

@ -11,17 +11,21 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// makeResponse creates a DNS response by req and sets necessary flags. It also // TODO(e.burkov): Name all the methods by a [proxy.MessageConstructor]
// guarantees that req.Question will be not empty. // template. Also extract all the methods to a separate entity.
func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
resp = &dns.Msg{ // reply creates a DNS response for req.
MsgHdr: dns.MsgHdr{ func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {
RecursionAvailable: true, resp = (&dns.Msg{}).SetRcode(req, code)
}, resp.RecursionAvailable = true
Compress: true,
return resp
} }
resp.SetReply(req) // replyCompressed creates a DNS response for req and sets the compress flag.
func (s *Server) replyCompressed(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeSuccess)
resp.Compress = true
return resp return resp
} }
@ -48,10 +52,10 @@ func (s *Server) genDNSFilterMessage(
) (resp *dns.Msg) { ) (resp *dns.Msg) {
req := dctx.Req req := dctx.Req
qt := req.Question[0].Qtype qt := req.Question[0].Qtype
if qt != dns.TypeA && qt != dns.TypeAAAA { if qt != dns.TypeA && qt != dns.TypeAAAA && qt != dns.TypeHTTPS {
m, _, _ := s.dnsFilter.BlockingMode() m, _, _ := s.dnsFilter.BlockingMode()
if m == filtering.BlockingModeNullIP { if m == filtering.BlockingModeNullIP {
return s.makeResponse(req) return s.replyCompressed(req)
} }
return s.newMsgNODATA(req) return s.newMsgNODATA(req)
@ -75,7 +79,7 @@ func (s *Server) genDNSFilterMessage(
// getCNAMEWithIPs generates a filtered response to req for with CNAME record // getCNAMEWithIPs generates a filtered response to req for with CNAME record
// and provided ips. // and provided ips.
func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) { func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) {
resp = s.makeResponse(req) resp = s.replyCompressed(req)
originalName := req.Question[0].Name originalName := req.Question[0].Name
@ -121,13 +125,13 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
case filtering.BlockingModeNullIP: case filtering.BlockingModeNullIP:
return s.makeResponseNullIP(req) return s.makeResponseNullIP(req)
case filtering.BlockingModeNXDOMAIN: case filtering.BlockingModeNXDOMAIN:
return s.genNXDomain(req) return s.NewMsgNXDOMAIN(req)
case filtering.BlockingModeREFUSED: case filtering.BlockingModeREFUSED:
return s.makeResponseREFUSED(req) return s.makeResponseREFUSED(req)
default: default:
log.Error("dnsforward: invalid blocking mode %q", mode) log.Error("dnsforward: invalid blocking mode %q", mode)
return s.makeResponse(req) return s.replyCompressed(req)
} }
} }
@ -148,25 +152,18 @@ func (s *Server) makeResponseCustomIP(
// genDNSFilterMessage. // genDNSFilterMessage.
log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt)) log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
return s.makeResponse(req) return s.replyCompressed(req)
} }
} }
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeServerFailure)
resp.RecursionAvailable = true
return &resp
}
func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg { func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request) resp := s.replyCompressed(request)
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip)) resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
return resp return resp
} }
func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg { func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request) resp := s.replyCompressed(request)
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip)) resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
return resp return resp
} }
@ -252,7 +249,7 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
// Go on and return an empty response. // Go on and return an empty response.
} }
resp = s.makeResponse(req) resp = s.replyCompressed(req)
resp.Answer = ans resp.Answer = ans
return resp return resp
@ -288,7 +285,7 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
case dns.TypeAAAA: case dns.TypeAAAA:
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()}) resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
default: default:
resp = s.makeResponse(req) resp = s.replyCompressed(req)
} }
return resp return resp
@ -298,7 +295,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
if newAddr == "" { if newAddr == "" {
log.Info("dnsforward: block host is not specified") log.Info("dnsforward: block host is not specified")
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
ip, err := netip.ParseAddr(newAddr) ip, err := netip.ParseAddr(newAddr)
@ -321,17 +318,17 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
if prx == nil { if prx == nil {
log.Debug("dnsforward: %s", srvClosedErr) log.Debug("dnsforward: %s", srvClosedErr)
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
err = prx.Resolve(newContext) err = prx.Resolve(newContext)
if err != nil { if err != nil {
log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err) log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err)
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
resp := s.makeResponse(request) resp := s.replyCompressed(request)
if newContext.Res != nil { if newContext.Res != nil {
for _, answer := range newContext.Res.Answer { for _, answer := range newContext.Res.Answer {
answer.Header().Name = request.Question[0].Name answer.Header().Name = request.Question[0].Name
@ -342,48 +339,21 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
return resp return resp
} }
// preBlockedResponse returns a protocol-appropriate response for a request that
// was blocked by access settings.
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (reply bool, err error) {
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
// Return nil so that dnsproxy drops the connection and thus
// prevent DNS amplification attacks.
return false, nil
}
pctx.Res = s.makeResponseREFUSED(pctx.Req)
// Return true so that dnsproxy responds with the REFUSED message.
return true, nil
}
// Create REFUSED DNS response // Create REFUSED DNS response
func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg { func (s *Server) makeResponseREFUSED(req *dns.Msg) *dns.Msg {
resp := dns.Msg{} return s.reply(req, dns.RcodeRefused)
resp.SetRcode(request, dns.RcodeRefused)
resp.RecursionAvailable = true
return &resp
} }
// newMsgNODATA returns a properly initialized NODATA response. // newMsgNODATA returns a properly initialized NODATA response.
// //
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2. // See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) { func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess) resp = s.reply(req, dns.RcodeSuccess)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(req) resp.Ns = s.genSOA(req)
return resp return resp
} }
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNameError)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(request)
return &resp
}
func (s *Server) genSOA(request *dns.Msg) []dns.RR { func (s *Server) genSOA(request *dns.Msg) []dns.RR {
zone := "" zone := ""
if len(request.Question) > 0 { if len(request.Question) > 0 {
@ -415,5 +385,43 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
if len(zone) > 0 && zone[0] != '.' { if len(zone) > 0 && zone[0] != '.' {
soa.Mbox += zone soa.Mbox += zone
} }
return []dns.RR{&soa} return []dns.RR{&soa}
} }
// type check
var _ proxy.MessageConstructor = (*Server)(nil)
// NewMsgNXDOMAIN implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeNameError)
resp.Ns = s.genSOA(req)
return resp
}
// NewMsgSERVFAIL implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
return s.reply(req, dns.RcodeServerFailure)
}
// NewMsgNOTIMPLEMENTED implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeNotImplemented)
// Most of the Internet and especially the inner core has an MTU of at least
// 1500 octets. Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet
// is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)).
//
// See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17.
const maxUDPPayload = 1452
// NOTIMPLEMENTED without EDNS is treated as 'we don't support EDNS', so
// explicitly set it.
resp.SetEdns0(maxUDPPayload, false)
return resp
}

View File

@ -1,20 +1,17 @@
package dnsforward package dnsforward
import ( import (
"cmp"
"encoding/binary" "encoding/binary"
"net" "net"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -34,11 +31,6 @@ type dnsContext struct {
// response is modified by filters. // response is modified by filters.
origResp *dns.Msg origResp *dns.Msg
// unreversedReqIP stores an IP address obtained from a PTR request if it
// was parsed successfully and belongs to one of the locally served IP
// ranges.
unreversedReqIP netip.Addr
// err is the error returned from a processing function. // err is the error returned from a processing function.
err error err error
@ -63,10 +55,6 @@ type dnsContext struct {
// responseAD shows if the response had the AD bit set. // responseAD shows if the response had the AD bit set.
responseAD bool responseAD bool
// isLocalClient shows if client's IP address is from locally served
// network.
isLocalClient bool
// isDHCPHost is true if the request for a local domain name and the DHCP is // isDHCPHost is true if the request for a local domain name and the DHCP is
// available for this request. // available for this request.
isDHCPHost bool isDHCPHost bool
@ -109,15 +97,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
// (*proxy.Proxy).handleDNSRequest method performs it before calling the // (*proxy.Proxy).handleDNSRequest method performs it before calling the
// appropriate handler. // appropriate handler.
mods := []modProcessFunc{ mods := []modProcessFunc{
s.processRecursion,
s.processInitial, s.processInitial,
s.processDDRQuery, s.processDDRQuery,
s.processDetermineLocal,
s.processDHCPHosts, s.processDHCPHosts,
s.processRestrictLocal,
s.processDHCPAddrs, s.processDHCPAddrs,
s.processFilteringBeforeRequest, s.processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream, s.processUpstream,
s.processFilteringAfterResponse, s.processFilteringAfterResponse,
s.ipset.process, s.ipset.process,
@ -145,24 +129,6 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
return nil return nil
} }
// processRecursion checks the incoming request and halts its handling by
// answering NXDOMAIN if s has tried to resolve it recently.
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing recursion")
defer log.Debug("dnsforward: finished processing recursion")
pctx := dctx.proxyCtx
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name)
pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish
}
return resultCodeSuccess
}
// mozillaFQDN is the domain used to signal the Firefox browser to not use its // mozillaFQDN is the domain used to signal the Firefox browser to not use its
// own DoH server. // own DoH server.
// //
@ -199,14 +165,14 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
} }
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN { if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
pctx.Res = s.genNXDomain(pctx.Req) pctx.Res = s.NewMsgNXDOMAIN(pctx.Req)
return resultCodeFinish return resultCodeFinish
} }
if q.Name == healthcheckFQDN { if q.Name == healthcheckFQDN {
// Generate a NODATA negative response to make nslookup exit with 0. // Generate a NODATA negative response to make nslookup exit with 0.
pctx.Res = s.makeResponse(pctx.Req) pctx.Res = s.replyCompressed(pctx.Req)
return resultCodeFinish return resultCodeFinish
} }
@ -272,7 +238,7 @@ func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
// //
// [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html. // [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) { func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
resp = s.makeResponse(req) resp = s.replyCompressed(req)
if req.Question[0].Qtype != dns.TypeSVCB { if req.Question[0].Qtype != dns.TypeSVCB {
return resp return resp
} }
@ -339,19 +305,6 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
return resp return resp
} }
// processDetermineLocal determines if the client's IP address is from locally
// served network and saves the result into the context.
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local detection")
defer log.Debug("dnsforward: finished processing local detection")
rc = resultCodeSuccess
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr())
return rc
}
// processDHCPHosts respond to A requests if the target hostname is known to // processDHCPHosts respond to A requests if the target hostname is known to
// the server. It responds with a mapped IP address if the DNS64 is enabled and // the server. It responds with a mapped IP address if the DNS64 is enabled and
// the request is for AAAA. // the request is for AAAA.
@ -370,9 +323,9 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
if !dctx.isLocalClient { if !pctx.IsPrivateClient {
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost) log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
pctx.Res = s.genNXDomain(req) pctx.Res = s.NewMsgNXDOMAIN(req)
// Do not even put into query log. // Do not even put into query log.
return resultCodeFinish return resultCodeFinish
@ -389,7 +342,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip) log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip)
resp := s.makeResponse(req) resp := s.replyCompressed(req)
switch q.Qtype { switch q.Qtype {
case dns.TypeA: case dns.TypeA:
a := &dns.A{ a := &dns.A{
@ -416,141 +369,6 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
// indexFirstV4Label returns the index at which the reversed IPv4 address
// starts, assuming the domain is pre-validated ARPA domain having in-addr and
// arpa labels removed.
func indexFirstV4Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv4len && idx > 0; labelsNum++ {
curIdx := strings.LastIndexByte(domain[:idx-1], '.') + 1
_, parseErr := strconv.ParseUint(domain[curIdx:idx-1], 10, 8)
if parseErr != nil {
return idx
}
idx = curIdx
}
return idx
}
// indexFirstV6Label returns the index at which the reversed IPv6 address
// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa
// labels removed.
func indexFirstV6Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv6len*2 && idx > 0; labelsNum++ {
curIdx := idx - len("a.")
if curIdx > 1 && domain[curIdx-1] != '.' {
return idx
}
nibble := domain[curIdx]
if (nibble < '0' || nibble > '9') && (nibble < 'a' || nibble > 'f') {
return idx
}
idx = curIdx
}
return idx
}
// extractARPASubnet tries to convert a reversed ARPA address being a part of
// domain to an IP network. domain must be an FQDN.
//
// TODO(e.burkov): Move to golibs.
func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
err = netutil.ValidateDomainName(strings.TrimSuffix(domain, "."))
if err != nil {
// Don't wrap the error since it's informative enough as is.
return netip.Prefix{}, err
}
const (
v4Suffix = "in-addr.arpa."
v6Suffix = "ip6.arpa."
)
domain = strings.ToLower(domain)
var idx int
switch {
case strings.HasSuffix(domain, v4Suffix):
idx = indexFirstV4Label(domain[:len(domain)-len(v4Suffix)])
case strings.HasSuffix(domain, v6Suffix):
idx = indexFirstV6Label(domain[:len(domain)-len(v6Suffix)])
default:
return netip.Prefix{}, &netutil.AddrError{
Err: netutil.ErrNotAReversedSubnet,
Kind: netutil.AddrKindARPA,
Addr: domain,
}
}
return netutil.PrefixFromReversedAddr(domain[idx:])
}
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
// in locally served network from external clients.
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local restriction")
defer log.Debug("dnsforward: finished processing local restriction")
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
if q.Qtype != dns.TypePTR {
// No need for restriction.
return resultCodeSuccess
}
subnet, err := extractARPASubnet(q.Name)
if err != nil {
if errors.Is(err, netutil.ErrNotAReversedSubnet) {
log.Debug("dnsforward: request is not for arpa domain")
return resultCodeSuccess
}
log.Debug("dnsforward: parsing reversed addr: %s", err)
return resultCodeError
}
// Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally served or at least
// shouldn't be accessible externally.
subnetAddr := subnet.Addr()
if !s.privateNets.Contains(subnetAddr) {
return resultCodeSuccess
}
log.Debug("dnsforward: addr %s is from locally served network", subnetAddr)
if !dctx.isLocalClient {
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
pctx.Res = s.genNXDomain(req)
// Do not even put into query log.
return resultCodeFinish
}
// Do not perform unreversing ever again.
dctx.unreversedReqIP = subnetAddr
// There is no need to filter request from external addresses since this
// code is only executed when the request is for locally served ARPA
// hostname so disable redundant filters.
dctx.setts.ParentalEnabled = false
dctx.setts.SafeBrowsingEnabled = false
dctx.setts.SafeSearchEnabled = false
dctx.setts.ServicesRules = nil
// Nothing to restrict.
return resultCodeSuccess
}
// processDHCPAddrs responds to PTR requests if the target IP is leased by the // processDHCPAddrs responds to PTR requests if the target IP is leased by the
// DHCP server. // DHCP server.
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
@ -562,23 +380,27 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
ipAddr := dctx.unreversedReqIP req := pctx.Req
if ipAddr == (netip.Addr{}) { q := req.Question[0]
pref := pctx.RequestedPrivateRDNS
// TODO(e.burkov): Consider answering authoritatively for SOA and NS
// queries.
if pref == (netip.Prefix{}) || q.Qtype != dns.TypePTR {
return resultCodeSuccess return resultCodeSuccess
} }
host := s.dhcpServer.HostByIP(ipAddr) addr := pref.Addr()
host := s.dhcpServer.HostByIP(addr)
if host == "" { if host == "" {
return resultCodeSuccess return resultCodeSuccess
} }
log.Debug("dnsforward: dhcp client %s is %q", ipAddr, host) log.Debug("dnsforward: dhcp client %s is %q", addr, host)
req := pctx.Req resp := s.replyCompressed(req)
resp := s.makeResponse(req)
ptr := &dns.PTR{ ptr := &dns.PTR{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: req.Question[0].Name, Name: q.Name,
Rrtype: dns.TypePTR, Rrtype: dns.TypePTR,
// TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See // TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See
// https://github.com/AdguardTeam/AdGuardHome/issues/3932. // https://github.com/AdguardTeam/AdGuardHome/issues/3932.
@ -593,62 +415,20 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
// processLocalPTR responds to PTR requests if the target IP is detected to be
// inside the local network and the query was not answered from DHCP.
func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local ptr")
defer log.Debug("dnsforward: finished processing local ptr")
pctx := dctx.proxyCtx
if pctx.Res != nil {
return resultCodeSuccess
}
ip := dctx.unreversedReqIP
if ip == (netip.Addr{}) {
return resultCodeSuccess
}
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if s.conf.UsePrivateRDNS {
s.recDetector.add(*pctx.Req)
if err := s.localResolvers.Resolve(pctx); err != nil {
log.Debug("dnsforward: resolving private address: %s", err)
// Generate the server failure if the private upstream configuration
// is empty.
//
// This is a crutch, see TODO at [Server.localResolvers].
if errors.Is(err, upstream.ErrNoUpstreams) {
pctx.Res = s.genServerFailure(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
dctx.err = err
return resultCodeError
}
}
if pctx.Res == nil {
pctx.Res = s.genNXDomain(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
return resultCodeSuccess
}
// Apply filtering logic // Apply filtering logic
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) { func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing filtering before req") log.Debug("dnsforward: started processing filtering before req")
defer log.Debug("dnsforward: finished processing filtering before req") defer log.Debug("dnsforward: finished processing filtering before req")
if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) {
// There is no need to filter request for locally served ARPA hostname
// so disable redundant filters.
dctx.setts.ParentalEnabled = false
dctx.setts.SafeBrowsingEnabled = false
dctx.setts.SafeSearchEnabled = false
dctx.setts.ServicesRules = nil
}
if dctx.proxyCtx.Res != nil { if dctx.proxyCtx.Res != nil {
// Go on since the response is already set. // Go on since the response is already set.
return resultCodeSuccess return resultCodeSuccess
@ -695,7 +475,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
// local domain name if there is one. // local domain name if there is one.
name := req.Question[0].Name name := req.Question[0].Name
log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1]) log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1])
pctx.Res = s.genNXDomain(req) pctx.Res = s.NewMsgNXDOMAIN(req)
return resultCodeFinish return resultCodeFinish
} }
@ -712,21 +492,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
return resultCodeError return resultCodeError
} }
if err := prx.Resolve(pctx); err != nil { if dctx.err = prx.Resolve(pctx); dctx.err != nil {
if errors.Is(err, upstream.ErrNoUpstreams) {
// Do not even put into querylog. Currently this happens either
// when the private resolvers enabled and the request is DNS64 PTR,
// or when the client isn't considered local by prx.
//
// TODO(e.burkov): Make proxy detect local client the same way as
// AGH does.
pctx.Res = s.genNXDomain(req)
return resultCodeFinish
}
dctx.err = err
return resultCodeError return resultCodeError
} }
@ -810,7 +576,7 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
} }
// Use the ClientID first, since it has a higher priority. // Use the ClientID first, since it has a higher priority.
id := stringutil.Coalesce(clientID, pctx.Addr.Addr().String()) id := cmp.Or(clientID, pctx.Addr.Addr().String())
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap) upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
if err != nil { if err != nil {
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err) log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
@ -835,7 +601,8 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
return resultCodeSuccess return resultCodeSuccess
case case
filtering.Rewritten, filtering.Rewritten,
filtering.RewrittenRule: filtering.RewrittenRule,
filtering.FilteredSafeSearch:
if dctx.origQuestion.Name == "" { if dctx.origQuestion.Name == "" {
// origQuestion is set in case we get only CNAME without IP from // origQuestion is set in case we get only CNAME without IP from
@ -845,11 +612,10 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
pctx := dctx.proxyCtx pctx := dctx.proxyCtx
pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion
if len(pctx.Res.Answer) > 0 {
rr := s.genAnswerCNAME(pctx.Req, res.CanonName) rr := s.genAnswerCNAME(pctx.Req, res.CanonName)
answer := append([]dns.RR{rr}, pctx.Res.Answer...) answer := append([]dns.RR{rr}, pctx.Res.Answer...)
pctx.Res.Answer = answer pctx.Res.Answer = answer
}
return resultCodeSuccess return resultCodeSuccess
default: default:

View File

@ -1,14 +1,15 @@
package dnsforward package dnsforward
import ( import (
"cmp"
"net" "net"
"net/netip" "net/netip"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
@ -70,8 +71,6 @@ func TestServer_ProcessInitial(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
@ -171,8 +170,6 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
@ -379,44 +376,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
return f return f
} }
func TestServer_ProcessDetermineLocal(t *testing.T) {
s := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliAddr netip.AddrPort
}{{
want: assert.True,
name: "local",
cliAddr: netip.MustParseAddrPort("192.168.0.1:1"),
}, {
want: assert.False,
name: "external",
cliAddr: netip.MustParseAddrPort("250.249.0.1:1"),
}, {
want: assert.False,
name: "invalid",
cliAddr: netip.AddrPort{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
proxyCtx := &proxy.DNSContext{
Addr: tc.cliAddr,
}
dctx := &dnsContext{
proxyCtx: proxyCtx,
}
s.processDetermineLocal(dctx)
tc.want(t, dctx.isLocalClient)
})
}
}
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
const ( const (
localDomainSuffix = "lan" localDomainSuffix = "lan"
@ -487,8 +446,8 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
dctx := &dnsContext{ dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{ proxyCtx: &proxy.DNSContext{
Req: req, Req: req,
IsPrivateClient: tc.isLocalCli,
}, },
isLocalClient: tc.isLocalCli,
} }
res := s.processDHCPHosts(dctx) res := s.processDHCPHosts(dctx)
@ -622,8 +581,8 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
dctx := &dnsContext{ dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{ proxyCtx: &proxy.DNSContext{
Req: req, Req: req,
IsPrivateClient: true,
}, },
isLocalClient: true,
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -658,19 +617,28 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
} }
} }
func TestServer_ProcessRestrictLocal(t *testing.T) { // TODO(e.burkov): Rewrite this test to use the whole server instead of just
// testing the [handleDNSRequest] method. See comment on
// "from_external_for_local" test case.
func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
intAddr := netip.MustParseAddr("192.168.1.1")
intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice())
require.NoError(t, err)
extAddr := netip.MustParseAddr("254.253.252.1")
extPTRQuestion, err := netutil.IPToReversedAddr(extAddr.AsSlice())
require.NoError(t, err)
const ( const (
extPTRQuestion = "251.252.253.254.in-addr.arpa."
extPTRAnswer = "host1.example.net." extPTRAnswer = "host1.example.net."
intPTRQuestion = "1.1.168.192.in-addr.arpa."
intPTRAnswer = "some.local-client." intPTRAnswer = "some.local-client."
) )
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce( resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), (&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
) )
require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
@ -697,74 +665,114 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
want string question string
question net.IP wantErr error
cliAddr netip.AddrPort wantAns []dns.RR
wantLen int isPrivate bool
}{{ }{{
name: "from_local_to_external", name: "from_local_for_external",
want: "host1.example.net.", question: extPTRQuestion,
question: net.IP{254, 253, 252, 251}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("192.168.10.10:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(extPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(extPTRAnswer) + 1),
},
Ptr: dns.Fqdn(extPTRAnswer),
}},
isPrivate: true,
}, { }, {
// In theory this case is not reproducible because [proxy.Proxy] should
// respond to such queries with NXDOMAIN before they reach
// [Server.handleDNSRequest].
name: "from_external_for_local", name: "from_external_for_local",
want: "", question: intPTRQuestion,
question: net.IP{192, 168, 1, 1}, wantErr: upstream.ErrNoUpstreams,
cliAddr: netip.MustParseAddrPort("254.253.252.251:1"), wantAns: nil,
wantLen: 0, isPrivate: false,
}, { }, {
name: "from_local_for_local", name: "from_local_for_local",
want: "some.local-client.", question: intPTRQuestion,
question: net.IP{192, 168, 1, 1}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("192.168.1.2:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(intPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(intPTRAnswer) + 1),
},
Ptr: dns.Fqdn(intPTRAnswer),
}},
isPrivate: true,
}, { }, {
name: "from_external_for_external", name: "from_external_for_external",
want: "host1.example.net.", question: extPTRQuestion,
question: net.IP{254, 253, 252, 251}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("254.253.252.255:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(extPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(extPTRAnswer) + 1),
},
Ptr: dns.Fqdn(extPTRAnswer),
}},
isPrivate: false,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
reqAddr, err := dns.ReverseAddr(tc.question.String()) pref, extErr := netutil.ExtractReversedAddr(tc.question)
require.NoError(t, err) require.NoError(t, extErr)
req := createTestMessageWithType(reqAddr, dns.TypePTR)
req := createTestMessageWithType(dns.Fqdn(tc.question), dns.TypePTR)
pctx := &proxy.DNSContext{ pctx := &proxy.DNSContext{
Proto: proxy.ProtoTCP,
Req: req, Req: req,
Addr: tc.cliAddr, IsPrivateClient: tc.isPrivate,
}
// TODO(e.burkov): Configure the subnet set properly.
if netutil.IsLocallyServed(pref.Addr()) {
pctx.RequestedPrivateRDNS = pref
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err = s.handleDNSRequest(nil, pctx) err = s.handleDNSRequest(s.dnsProxy, pctx)
require.NoError(t, err) require.ErrorIs(t, err, tc.wantErr)
require.NotNil(t, pctx.Res)
require.Len(t, pctx.Res.Answer, tc.wantLen)
if tc.wantLen > 0 { require.NotNil(t, pctx.Res)
assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr) assert.Equal(t, tc.wantAns, pctx.Res.Answer)
}
}) })
} }
} }
func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { func TestServer_ProcessUpstream_localPTR(t *testing.T) {
const locDomain = "some.local." const locDomain = "some.local."
const reqAddr = "1.1.168.192.in-addr.arpa." const reqAddr = "1.1.168.192.in-addr.arpa."
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce( resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain), aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), (&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
) )
require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
}) })
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
newPrxCtx := func() (prxCtx *proxy.DNSContext) {
return &proxy.DNSContext{
Addr: testClientAddrPort,
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
IsPrivateClient: true,
RequestedPrivateRDNS: netip.MustParsePrefix("192.168.1.1/32"),
}
}
t.Run("enabled", func(t *testing.T) {
s := createTestServer( s := createTestServer(
t, t,
&filtering.Config{ &filtering.Config{
@ -782,37 +790,39 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
ServePlainDNS: true, ServePlainDNS: true,
}, },
) )
pctx := newPrxCtx()
var proxyCtx *proxy.DNSContext rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
var dnsCtx *dnsContext
setup := func(use bool) {
proxyCtx = &proxy.DNSContext{
Addr: testClientAddrPort,
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
}
dnsCtx = &dnsContext{
proxyCtx: proxyCtx,
unreversedReqIP: netip.MustParseAddr("192.168.1.1"),
}
s.conf.UsePrivateRDNS = use
}
t.Run("enabled", func(t *testing.T) {
setup(true)
rc := s.processLocalPTR(dnsCtx)
require.Equal(t, resultCodeSuccess, rc) require.Equal(t, resultCodeSuccess, rc)
require.NotEmpty(t, proxyCtx.Res.Answer) require.NotEmpty(t, pctx.Res.Answer)
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr) assert.Equal(t, locDomain, ptr.Ptr)
}) })
t.Run("disabled", func(t *testing.T) { t.Run("disabled", func(t *testing.T) {
setup(false) s := createTestServer(
t,
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
pctx := newPrxCtx()
rc := s.processLocalPTR(dnsCtx) rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeFinish, rc) require.Equal(t, resultCodeError, rc)
require.Empty(t, proxyCtx.Res.Answer) require.Empty(t, pctx.Res.Answer)
}) })
} }
@ -830,129 +840,3 @@ func TestIPStringFromAddr(t *testing.T) {
assert.Empty(t, ipStringFromAddr(nil)) assert.Empty(t, ipStringFromAddr(nil))
}) })
} }
// TODO(e.burkov): Add fuzzing when moving to golibs.
func TestExtractARPASubnet(t *testing.T) {
const (
v4Suf = `in-addr.arpa.`
v4Part = `2.1.` + v4Suf
v4Whole = `4.3.` + v4Part
v6Suf = `ip6.arpa.`
v6Part = `4.3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Suf
v6Whole = `f.e.d.c.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Part
)
v4Pref := netip.MustParsePrefix("1.2.3.4/32")
v4PrefPart := netip.MustParsePrefix("1.2.0.0/16")
v6Pref := netip.MustParsePrefix("::1234:0:0:0:cdef/128")
v6PrefPart := netip.MustParsePrefix("0:0:0:1234::/64")
testCases := []struct {
want netip.Prefix
name string
domain string
wantErr string
}{{
want: netip.Prefix{},
name: "not_an_arpa",
domain: "some.domain.name.",
wantErr: `bad arpa domain name "some.domain.name.": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "bad_domain_name",
domain: "abc.123.",
wantErr: `bad domain name "abc.123": ` +
`bad top-level domain name label "123": all octets are numeric`,
}, {
want: v4Pref,
name: "whole_v4",
domain: v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4",
domain: v4Part,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_within_domain",
domain: "a." + v4Whole,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_additional_label",
domain: "5." + v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4_within_domain",
domain: "a." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4",
domain: "256." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4_within_domain",
domain: "a.256." + v4Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v4",
domain: v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v4_within_domain",
domain: "a." + v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: v6Pref,
name: "whole_v6",
domain: v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6",
domain: v6Part,
}, {
want: v6Pref,
name: "whole_v6_within_domain",
domain: "g." + v6Whole,
wantErr: "",
}, {
want: v6Pref,
name: "whole_v6_additional_label",
domain: "1." + v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6_within_domain",
domain: "label." + v6Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v6",
domain: v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v6_within_domain",
domain: "g." + v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
subnet, err := extractARPASubnet(tc.domain)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.want, subnet)
})
}
}

View File

@ -29,7 +29,13 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: client ip for stats and querylog: %s", ipStr) log.Debug("dnsforward: client ip for stats and querylog: %s", ipStr)
ids := []string{ipStr, dctx.clientID} ids := []string{ipStr}
if dctx.clientID != "" {
// Use the ClientID first because it has a higher priority. Filters
// have the same priority, see applyAdditionalFiltering.
ids = []string{dctx.clientID, ipStr}
}
qt, cl := q.Qtype, q.Qclass qt, cl := q.Qtype, q.Qclass
// Synchronize access to s.queryLog and s.stats so they won't be suddenly // Synchronize access to s.queryLog and s.stats so they won't be suddenly
@ -124,7 +130,7 @@ func (s *Server) logQuery(dctx *dnsContext, ip net.IP, processingTime time.Durat
s.queryLog.Add(p) s.queryLog.Add(p)
} }
// updatesStats writes the request into statistics. // updateStats writes the request data into statistics.
func (s *Server) updateStats(dctx *dnsContext, clientIP string, processingTime time.Duration) { func (s *Server) updateStats(dctx *dnsContext, clientIP string, processingTime time.Duration) {
pctx := dctx.proxyCtx pctx := dctx.proxyCtx

View File

@ -2,90 +2,77 @@ package dnsforward
import ( import (
"fmt" "fmt"
"net/netip"
"os"
"slices" "slices"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
) )
// loadUpstreams parses upstream DNS servers from the configured file or from // newBootstrap returns a bootstrap resolver based on the configuration of s.
// the configuration itself. // boots are the upstream resolvers that should be closed after use. r is the
func (s *Server) loadUpstreams() (upstreams []string, err error) { // actual bootstrap resolver, which may include the system hosts.
if s.conf.UpstreamDNSFileName == "" {
return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil
}
var data []byte
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
if err != nil {
return nil, fmt.Errorf("reading upstream from file: %w", err)
}
upstreams = stringutil.SplitTrimmed(string(data), "\n")
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName)
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
}
// prepareUpstreamSettings sets upstream DNS server settings.
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
// Load upstreams either from the file, or from the settings
var upstreams []string
upstreams, err = s.loadUpstreams()
if err != nil {
return fmt.Errorf("loading upstreams: %w", err)
}
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
// Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're
// loading roots manually and pass it here.
// //
// See [aghtls.SystemRootCAs]. // TODO(e.burkov): This function currently returns a resolver and a slice of
// // the upstream resolvers, which are essentially the same. boots are returned
// TODO(a.garipov): Investigate if that's true. // for being able to close them afterwards, but it introduces an implicit
RootCAs: s.conf.TLSv12Roots, // contract that r could only be used before that. Anyway, this code should
CipherSuites: s.conf.TLSCiphers, // improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
}) // and be used here.
func newBootstrap(
addrs []string,
etcHosts upstream.Resolver,
opts *upstream.Options,
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
if len(addrs) == 0 {
addrs = defaultBootstrap
}
boots, err = aghnet.ParseBootstraps(addrs, opts)
if err != nil { if err != nil {
return fmt.Errorf("preparing upstream config: %w", err) // Don't wrap the error, since it's informative enough as is.
return nil, nil, err
} }
return nil var parallel upstream.ParallelResolver
for _, b := range boots {
parallel = append(parallel, upstream.NewCachingResolver(b))
} }
// prepareUpstreamConfig returns the upstream configuration based on upstreams if etcHosts != nil {
// and configuration of s. r = upstream.ConsequentResolver{etcHosts, parallel}
func (s *Server) prepareUpstreamConfig( } else {
r = parallel
}
return r, boots, nil
}
// newUpstreamConfig returns the upstream configuration based on upstreams. If
// upstreams slice specifies no default upstreams, defaultUpstreams are used to
// create upstreams with no domain specifications. opts are used when creating
// upstream configuration.
func newUpstreamConfig(
upstreams []string, upstreams []string,
defaultUpstreams []string, defaultUpstreams []string,
opts *upstream.Options, opts *upstream.Options,
) (uc *proxy.UpstreamConfig, err error) { ) (uc *proxy.UpstreamConfig, err error) {
uc, err = proxy.ParseUpstreamsConfig(upstreams, opts) uc, err = proxy.ParseUpstreamsConfig(upstreams, opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing upstream config: %w", err) return uc, fmt.Errorf("parsing upstreams: %w", err)
} }
if len(uc.Upstreams) == 0 && defaultUpstreams != nil { if len(uc.Upstreams) == 0 && len(defaultUpstreams) > 0 {
log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams) log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)
var defaultUpstreamConfig *proxy.UpstreamConfig var defaultUpstreamConfig *proxy.UpstreamConfig
defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts) defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing default upstreams: %w", err) return uc, fmt.Errorf("parsing default upstreams: %w", err)
} }
uc.Upstreams = defaultUpstreamConfig.Upstreams uc.Upstreams = defaultUpstreamConfig.Upstreams
@ -94,6 +81,54 @@ func (s *Server) prepareUpstreamConfig(
return uc, nil return uc, nil
} }
// newPrivateConfig creates an upstream configuration for resolving PTR records
// for local addresses. The configuration is built either from the provided
// addresses or from the system resolvers. unwanted filters the resulting
// upstream configuration.
func newPrivateConfig(
addrs []string,
unwanted addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
opts *upstream.Options,
) (uc *proxy.UpstreamConfig, err error) {
confNeedsFiltering := len(addrs) > 0
if confNeedsFiltering {
addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty)
} else {
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
addrs = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
addrs = append(addrs, r.String())
}
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", addrs)
uc, err = proxy.ParseUpstreamsConfig(addrs, opts)
if err != nil {
return uc, fmt.Errorf("preparing private upstreams: %w", err)
}
if !confNeedsFiltering {
return uc, nil
}
err = filterOutAddrs(uc, unwanted)
if err != nil {
return uc, fmt.Errorf("filtering private upstreams: %w", err)
}
// Prevalidate the config to catch the exact error before creating proxy.
// See TODO on [PrivateRDNSError].
err = proxy.ValidatePrivateConfig(uc, privateNets)
if err != nil {
return uc, &PrivateRDNSError{err: err}
}
return uc, nil
}
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration // UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration. // depending on configuration.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) { func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
@ -130,85 +165,9 @@ func setProxyUpstreamMode(
return nil return nil
} }
// createBootstrap returns a bootstrap resolver based on the configuration of s.
// boots are the upstream resolvers that should be closed after use. r is the
// actual bootstrap resolver, which may include the system hosts.
//
// TODO(e.burkov): This function currently returns a resolver and a slice of
// the upstream resolvers, which are essentially the same. boots are returned
// for being able to close them afterwards, but it introduces an implicit
// contract that r could only be used before that. Anyway, this code should
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
// and be used here.
func (s *Server) createBootstrap(
addrs []string,
opts *upstream.Options,
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
if len(addrs) == 0 {
addrs = defaultBootstrap
}
boots, err = aghnet.ParseBootstraps(addrs, opts)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return nil, nil, err
}
var parallel upstream.ParallelResolver
for _, b := range boots {
parallel = append(parallel, upstream.NewCachingResolver(b))
}
if s.etcHosts != nil {
r = upstream.ConsequentResolver{s.etcHosts, parallel}
} else {
r = parallel
}
return r, boots, nil
}
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty. // IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream // This function is useful for filtering out non-upstream lines from upstream
// configs. // configs.
func IsCommentOrEmpty(s string) (ok bool) { func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#' return len(s) == 0 || s[0] == '#'
} }
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("creating config: %w", err)
}
if conf == nil {
return nil
}
keys := maps.Keys(conf.DomainReservedUpstreams)
slices.Sort(keys)
var errs []error
for _, domain := range keys {
var subnet netip.Prefix
subnet, err = extractARPASubnet(domain)
if err != nil {
errs = append(errs, err)
continue
}
if !privateNets.Contains(subnet.Addr()) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
)
}
}
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
}

View File

@ -559,6 +559,8 @@ type Result struct {
Reason Reason `json:",omitempty"` Reason Reason `json:",omitempty"`
// IsFiltered is true if the request is filtered. // IsFiltered is true if the request is filtered.
//
// TODO(d.kolyshev): Get rid of this flag.
IsFiltered bool `json:",omitempty"` IsFiltered bool `json:",omitempty"`
} }

View File

@ -200,7 +200,7 @@ func TestParallelSB(t *testing.T) {
t.Cleanup(d.Close) t.Cleanup(d.Close)
t.Run("group", func(t *testing.T) { t.Run("group", func(t *testing.T) {
for i := 0; i < 100; i++ { for i := range 100 {
t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) {
t.Parallel() t.Parallel()
d.checkMatch(t, sbBlocked, setts) d.checkMatch(t, sbBlocked, setts)
@ -670,7 +670,7 @@ func BenchmarkSafeBrowsing(b *testing.B) {
}, nil) }, nil)
b.Cleanup(d.Close) b.Cleanup(d.Close)
for n := 0; n < b.N; n++ { for range b.N {
res, err := d.CheckHost(sbBlocked, dns.TypeA, setts) res, err := d.CheckHost(sbBlocked, dns.TypeA, setts)
require.NoError(b, err) require.NoError(b, err)

View File

@ -63,8 +63,6 @@ func TestIDGenerator_Fix(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
g := newIDGenerator(1) g := newIDGenerator(1)
g.fix(tc.in) g.fix(tc.in)

View File

@ -1,7 +1,6 @@
package rulelist_test package rulelist_test
import ( import (
"context"
"net/http" "net/http"
"testing" "testing"
@ -28,14 +27,12 @@ func TestEngine_Refresh(t *testing.T) {
require.NotNil(t, eng) require.NotNil(t, eng)
testutil.CleanupAndRequireSuccess(t, eng.Close) testutil.CleanupAndRequireSuccess(t, eng.Close)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
buf := make([]byte, rulelist.DefaultRuleBufSize) buf := make([]byte, rulelist.DefaultRuleBufSize)
cli := &http.Client{ cli := &http.Client{
Timeout: testTimeout, Timeout: testTimeout,
} }
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := eng.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize) err := eng.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize)
require.NoError(t, err) require.NoError(t, err)

View File

@ -1,7 +1,6 @@
package rulelist_test package rulelist_test
import ( import (
"context"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -67,14 +66,12 @@ func TestFilter_Refresh(t *testing.T) {
require.NotNil(t, f) require.NotNil(t, f)
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
buf := make([]byte, rulelist.DefaultRuleBufSize) buf := make([]byte, rulelist.DefaultRuleBufSize)
cli := &http.Client{ cli := &http.Client{
Timeout: testTimeout, Timeout: testTimeout,
} }
ctx := testutil.ContextWithTimeout(t, testTimeout)
res, err := f.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize) res, err := f.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize)
require.NoError(t, err) require.NoError(t, err)

View File

@ -132,7 +132,6 @@ func TestParser_Parse(t *testing.T) {
}} }}
for _, tc := range testCases { for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
@ -216,7 +215,7 @@ func BenchmarkParser_Parse(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for range b.N {
resSink, errSink = p.Parse(dst, src, buf) resSink, errSink = p.Parse(dst, src, buf)
dst.Reset() dst.Reset()
} }

View File

@ -1,7 +1,5 @@
package filtering package filtering
import "github.com/miekg/dns"
// SafeSearch interface describes a service for search engines hosts rewrites. // SafeSearch interface describes a service for search engines hosts rewrites.
type SafeSearch interface { type SafeSearch interface {
// CheckHost checks host with safe search filter. CheckHost must be safe // CheckHost checks host with safe search filter. CheckHost must be safe
@ -16,9 +14,6 @@ type SafeSearch interface {
// SafeSearchConfig is a struct with safe search related settings. // SafeSearchConfig is a struct with safe search related settings.
type SafeSearchConfig struct { type SafeSearchConfig struct {
// CustomResolver is the resolver used by safe search.
CustomResolver Resolver `yaml:"-" json:"-"`
// Enabled indicates if safe search is enabled entirely. // Enabled indicates if safe search is enabled entirely.
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
@ -40,13 +35,7 @@ func (d *DNSFilter) checkSafeSearch(
qtype uint16, qtype uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.ProtectionEnabled || if d.safeSearch == nil || !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
!setts.SafeSearchEnabled ||
(qtype != dns.TypeA && qtype != dns.TypeAAAA) {
return Result{}, nil
}
if d.safeSearch == nil {
return Result{}, nil return Result{}, nil
} }

View File

@ -3,11 +3,9 @@ package safesearch
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"encoding/gob" "encoding/gob"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strings" "strings"
"sync" "sync"
@ -67,7 +65,6 @@ type Default struct {
engine *urlfilter.DNSEngine engine *urlfilter.DNSEngine
cache cache.Cache cache cache.Cache
resolver filtering.Resolver
logPrefix string logPrefix string
cacheTTL time.Duration cacheTTL time.Duration
} }
@ -80,11 +77,6 @@ func NewDefault(
cacheSize uint, cacheSize uint,
cacheTTL time.Duration, cacheTTL time.Duration,
) (ss *Default, err error) { ) (ss *Default, err error) {
var resolver filtering.Resolver = net.DefaultResolver
if conf.CustomResolver != nil {
resolver = conf.CustomResolver
}
ss = &Default{ ss = &Default{
mu: &sync.RWMutex{}, mu: &sync.RWMutex{},
@ -92,7 +84,6 @@ func NewDefault(
EnableLRU: true, EnableLRU: true,
MaxSize: cacheSize, MaxSize: cacheSize,
}), }),
resolver: resolver,
// Use %s, because the client safe-search names already contain double // Use %s, because the client safe-search names already contain double
// quotes. // quotes.
logPrefix: fmt.Sprintf("safesearch %s: ", name), logPrefix: fmt.Sprintf("safesearch %s: ", name),
@ -170,8 +161,11 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start)) ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start))
}() }()
if qtype != dns.TypeA && qtype != dns.TypeAAAA { switch qtype {
return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype)) case dns.TypeA, dns.TypeAAAA, dns.TypeHTTPS:
// Go on.
default:
return filtering.Result{}, nil
} }
// Check cache. Return cached result if it was found // Check cache. Return cached result if it was found
@ -195,6 +189,9 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res
} }
res = *fltRes res = *fltRes
// TODO(a.garipov): Consider switch back to resolving CNAME records IPs and
// saving results to cache.
ss.setCacheResult(host, qtype, res) ss.setCacheResult(host, qtype, res)
return res, nil return res, nil
@ -223,20 +220,13 @@ func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRe
} }
// newResult creates Result object from rewrite rule. qtype must be either // newResult creates Result object from rewrite rule. qtype must be either
// [dns.TypeA] or [dns.TypeAAAA]. If err is nil, res is never nil, so that the // [dns.TypeA] or [dns.TypeAAAA], or [dns.TypeHTTPS]. If err is nil, res is
// empty result is converted into a NODATA response. // never nil, so that the empty result is converted into a NODATA response.
//
// TODO(a.garipov): Use the main rewrite result mechanism used in
// [dnsforward.Server.filterDNSRequest]. Now we resolve IPs for CNAME to save
// them in the safe search cache.
func (ss *Default) newResult( func (ss *Default) newResult(
rewrite *rules.DNSRewrite, rewrite *rules.DNSRewrite,
qtype rules.RRType, qtype rules.RRType,
) (res *filtering.Result, err error) { ) (res *filtering.Result, err error) {
res = &filtering.Result{ res = &filtering.Result{
Rules: []*filtering.ResultRule{{
FilterListID: rulelist.URLFilterIDSafeSearch,
}},
Reason: filtering.FilteredSafeSearch, Reason: filtering.FilteredSafeSearch,
IsFiltered: true, IsFiltered: true,
} }
@ -247,69 +237,19 @@ func (ss *Default) newResult(
return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", rewrite.Value) return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", rewrite.Value)
} }
res.Rules[0].IP = ip res.Rules = []*filtering.ResultRule{{
FilterListID: rulelist.URLFilterIDSafeSearch,
IP: ip,
}}
return res, nil return res, nil
} }
host := rewrite.NewCNAME res.CanonName = rewrite.NewCNAME
if host == "" {
return res, nil
}
res.CanonName = host
ss.log(log.DEBUG, "resolving %q", host)
ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host)
if err != nil {
return nil, fmt.Errorf("resolving cname: %w", err)
}
ss.log(log.DEBUG, "resolved %s", ips)
for _, ip := range ips {
// TODO(a.garipov): Remove this filtering once the resolver we use
// actually learns about network.
addr := fitToProto(ip, qtype)
if addr == (netip.Addr{}) {
continue
}
// TODO(e.burkov): Rules[0]?
res.Rules[0].IP = addr
}
return res, nil return res, nil
} }
// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA].
// It panics for other types.
func qtypeToProto(qtype rules.RRType) (proto string) {
switch qtype {
case dns.TypeA:
return "ip4"
case dns.TypeAAAA:
return "ip6"
default:
panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype)))
}
}
// fitToProto returns a non-nil IP address if ip is the correct protocol version
// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA].
func fitToProto(ip net.IP, qtype rules.RRType) (res netip.Addr) {
if ip4 := ip.To4(); qtype == dns.TypeA {
if ip4 != nil {
return netip.AddrFrom4([4]byte(ip4))
}
} else if ip = ip.To16(); ip != nil && qtype == dns.TypeAAAA {
return netip.AddrFrom16([16]byte(ip))
}
return netip.Addr{}
}
// setCacheResult stores data in cache for host. qtype is expected to be either // setCacheResult stores data in cache for host. qtype is expected to be either
// [dns.TypeA] or [dns.TypeAAAA]. // [dns.TypeA] or [dns.TypeAAAA].
func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) { func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) {

View File

@ -1,13 +1,10 @@
package safesearch package safesearch
import ( import (
"context"
"net"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -79,47 +76,6 @@ func TestSafeSearchCacheYandex(t *testing.T) {
assert.Equal(t, cachedValue.Rules[0].IP, yandexIP) assert.Equal(t, cachedValue.Rules[0].IP, yandexIP)
} }
func TestSafeSearchCacheGoogle(t *testing.T) {
const domain = "www.google.ru"
ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false})
res, err := ss.CheckHost(domain, testQType)
require.NoError(t, err)
assert.False(t, res.IsFiltered)
assert.Empty(t, res.Rules)
resolver := &aghtest.Resolver{
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
ss = newForTest(t, defaultSafeSearchConf)
ss.resolver = resolver
// Lookup for safesearch domain.
rewrite := ss.searchHost(domain, testQType)
wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME)
res, err = ss.CheckHost(domain, testQType)
require.NoError(t, err)
require.Len(t, res.Rules, 1)
assert.Equal(t, wantIP, res.Rules[0].IP)
// Check cache.
cachedValue, isFound := ss.getCachedResult(domain, testQType)
require.True(t, isFound)
require.Len(t, cachedValue.Rules, 1)
assert.Equal(t, wantIP, cachedValue.Rules[0].IP)
}
const googleHost = "www.google.com" const googleHost = "www.google.com"
var dnsRewriteSink *rules.DNSRewrite var dnsRewriteSink *rules.DNSRewrite
@ -127,7 +83,7 @@ var dnsRewriteSink *rules.DNSRewrite
func BenchmarkSafeSearch(b *testing.B) { func BenchmarkSafeSearch(b *testing.B) {
ss := newForTest(b, defaultSafeSearchConf) ss := newForTest(b, defaultSafeSearchConf)
for n := 0; n < b.N; n++ { for range b.N {
dnsRewriteSink = ss.searchHost(googleHost, testQType) dnsRewriteSink = ss.searchHost(googleHost, testQType)
} }

View File

@ -7,7 +7,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
@ -31,8 +30,6 @@ const (
// testConf is the default safe search configuration for tests. // testConf is the default safe search configuration for tests.
var testConf = filtering.SafeSearchConfig{ var testConf = filtering.SafeSearchConfig{
CustomResolver: nil,
Enabled: true, Enabled: true,
Bing: true, Bing: true,
@ -52,61 +49,60 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err) require.NoError(t, err)
// Check host for each domain. hosts := []string{
for _, host := range []string{
"yandex.ru", "yandex.ru",
"yAndeX.ru", "yAndeX.ru",
"YANdex.COM", "YANdex.COM",
"yandex.by", "yandex.by",
"yandex.kz", "yandex.kz",
"www.yandex.com", "www.yandex.com",
} { }
testCases := []struct {
want netip.Addr
name string
qt uint16
}{{
want: yandexIP,
name: "a",
qt: dns.TypeA,
}, {
want: netip.Addr{},
name: "aaaa",
qt: dns.TypeAAAA,
}, {
want: netip.Addr{},
name: "https",
qt: dns.TypeHTTPS,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, host := range hosts {
// Check host for each domain.
var res filtering.Result var res filtering.Result
res, err = ss.CheckHost(host, testQType) res, err = ss.CheckHost(host, tc.qt)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
if tc.want == (netip.Addr{}) {
assert.Empty(t, res.Rules)
} else {
require.Len(t, res.Rules, 1) require.Len(t, res.Rules, 1)
assert.Equal(t, yandexIP, res.Rules[0].IP) rule := res.Rules[0]
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID) assert.Equal(t, tc.want, rule.IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, rule.FilterListID)
} }
} }
})
func TestDefault_CheckHost_yandexAAAA(t *testing.T) { }
conf := testConf
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err)
res, err := ss.CheckHost("www.yandex.ru", dns.TypeAAAA)
require.NoError(t, err)
assert.True(t, res.IsFiltered)
// TODO(a.garipov): Currently, the safe-search filter returns a single rule
// with a nil IP address. This isn't really necessary and should be changed
// once the TODO in [safesearch.Default.newResult] is resolved.
require.Len(t, res.Rules, 1)
assert.Empty(t, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
} }
func TestDefault_CheckHost_google(t *testing.T) { func TestDefault_CheckHost_google(t *testing.T) {
resolver := &aghtest.Resolver{ ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) {
ip4, ip6 := aghtest.HostToIPs(host)
return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil
},
}
wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com")
conf := testConf
conf.CustomResolver = resolver
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err) require.NoError(t, err)
// Check host for each domain. // Check host for each domain.
@ -125,11 +121,9 @@ func TestDefault_CheckHost_google(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
require.Len(t, res.Rules, 1) assert.Equal(t, "forcesafesearch.google.com", res.CanonName)
assert.Empty(t, res.Rules)
assert.Equal(t, wantIP, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
}) })
} }
} }
@ -154,17 +148,7 @@ func (r *testResolver) LookupIP(
} }
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) { func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
conf := testConf ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL)
conf.CustomResolver = &testResolver{
OnLookupIP: func(_ context.Context, network, host string) (ips []net.IP, err error) {
assert.Equal(t, "ip6", network)
assert.Equal(t, "safe.duckduckgo.com", host)
return nil, nil
},
}
ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL)
require.NoError(t, err) require.NoError(t, err)
// The DuckDuckGo safe-search addresses are resolved through CNAMEs, but // The DuckDuckGo safe-search addresses are resolved through CNAMEs, but
@ -174,14 +158,9 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.True(t, res.IsFiltered) assert.True(t, res.IsFiltered)
assert.Equal(t, filtering.FilteredSafeSearch, res.Reason)
// TODO(a.garipov): Currently, the safe-search filter returns a single rule assert.Equal(t, "safe.duckduckgo.com", res.CanonName)
// with a nil IP address. This isn't really necessary and should be changed assert.Empty(t, res.Rules)
// once the TODO in [safesearch.Default.newResult] is resolved.
require.Len(t, res.Rules, 1)
assert.Empty(t, res.Rules[0].IP)
assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID)
} }
func TestDefault_Update(t *testing.T) { func TestDefault_Update(t *testing.T) {

View File

@ -24,7 +24,6 @@ import (
"github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
) )
// DHCP is an interface for accessing DHCP lease data the [clientsContainer] // DHCP is an interface for accessing DHCP lease data the [clientsContainer]
@ -46,22 +45,20 @@ type DHCP interface {
// clientsContainer is the storage of all runtime and persistent clients. // clientsContainer is the storage of all runtime and persistent clients.
type clientsContainer struct { type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for different // clientIndex stores information about persistent clients.
// types (string, netip.Addr, and so on).
list map[string]*client.Persistent // name -> client
clientIndex *client.Index clientIndex *client.Index
// ipToRC maps IP addresses to runtime client information. // runtimeIndex stores information about runtime clients.
ipToRC map[netip.Addr]*client.Runtime runtimeIndex *client.RuntimeIndex
allTags *container.MapSet[string] allTags *container.MapSet[string]
// dhcp is the DHCP service implementation. // dhcp is the DHCP service implementation.
dhcp DHCP dhcp DHCP
// dnsServer is used for checking clients IP status access list status // clientChecker checks if a client is blocked by the current access
dnsServer *dnsforward.Server // settings.
clientChecker BlockedClientChecker
// etcHosts contains list of rewrite rules taken from the operating system's // etcHosts contains list of rewrite rules taken from the operating system's
// hosts database. // hosts database.
@ -90,6 +87,12 @@ type clientsContainer struct {
testing bool testing bool
} }
// BlockedClientChecker checks if a client is blocked by the current access
// settings.
type BlockedClientChecker interface {
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
}
// Init initializes clients container // Init initializes clients container
// dhcpServer: optional // dhcpServer: optional
// Note: this function must be called only once // Note: this function must be called only once
@ -100,12 +103,12 @@ func (clients *clientsContainer) Init(
arpDB arpdb.Interface, arpDB arpdb.Interface,
filteringConf *filtering.Config, filteringConf *filtering.Config,
) (err error) { ) (err error) {
if clients.list != nil { // TODO(s.chzhen): Refactor it.
log.Fatal("clients.list != nil") if clients.clientIndex != nil {
return errors.Error("clients container already initialized")
} }
clients.list = map[string]*client.Persistent{} clients.runtimeIndex = client.NewRuntimeIndex()
clients.ipToRC = map[netip.Addr]*client.Runtime{}
clients.clientIndex = client.NewIndex() clients.clientIndex = client.NewIndex()
@ -248,8 +251,6 @@ func (o *clientObject) toPersistent(
} }
if o.SafeSearchConf.Enabled { if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
err = cli.SetSafeSearch( err = cli.SetSafeSearch(
o.SafeSearchConf, o.SafeSearchConf,
filteringConf.SafeSearchCacheSize, filteringConf.SafeSearchCacheSize,
@ -285,9 +286,17 @@ func (clients *clientsContainer) addFromConfig(
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err) return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
} }
_, err = clients.add(cli) // TODO(s.chzhen): Consider moving to the client index constructor.
err = clients.clientIndex.ClashesUID(cli)
if err != nil { if err != nil {
log.Error("clients: adding client at index %d %s: %s", i, cli.Name, err) return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err)
}
err = clients.add(cli)
if err != nil {
// TODO(s.chzhen): Return an error instead of logging if more
// stringent requirements are implemented.
log.Error("clients: adding client %s at index %d: %s", cli.Name, i, err)
} }
} }
@ -300,9 +309,9 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
objs = make([]*clientObject, 0, len(clients.list)) objs = make([]*clientObject, 0, clients.clientIndex.Size())
for _, cli := range clients.list { clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) {
o := &clientObject{ objs = append(objs, &clientObject{
Name: cli.Name, Name: cli.Name,
BlockedServices: cli.BlockedServices.Clone(), BlockedServices: cli.BlockedServices.Clone(),
@ -323,10 +332,10 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
IgnoreStatistics: cli.IgnoreStatistics, IgnoreStatistics: cli.IgnoreStatistics,
UpstreamsCacheEnabled: cli.UpstreamsCacheEnabled, UpstreamsCacheEnabled: cli.UpstreamsCacheEnabled,
UpstreamsCacheSize: cli.UpstreamsCacheSize, UpstreamsCacheSize: cli.UpstreamsCacheSize,
} })
objs = append(objs, o) return true
} })
// Maps aren't guaranteed to iterate in the same order each time, so the // Maps aren't guaranteed to iterate in the same order each time, so the
// above loop can generate different orderings when writing to the config // above loop can generate different orderings when writing to the config
@ -363,8 +372,8 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
return client.SourcePersistent return client.SourcePersistent
} }
rc, ok := clients.ipToRC[ip] rc := clients.runtimeIndex.Client(ip)
if ok { if rc != nil {
src, _ = rc.Info() src, _ = rc.Info()
} }
@ -406,23 +415,26 @@ func (clients *clientsContainer) clientOrArtificial(
id string, id string,
) (c *querylog.Client, art bool) { ) (c *querylog.Client, art bool) {
defer func() { defer func() {
c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id) c.Disallowed, c.DisallowedRule = clients.clientChecker.IsBlockedClient(ip, id)
if c.WHOIS == nil { if c.WHOIS == nil {
c.WHOIS = &whois.Info{} c.WHOIS = &whois.Info{}
} }
}() }()
cli, ok := clients.find(id) cli, ok := clients.find(id)
if ok { if !ok {
cli = clients.clientIndex.FindByIPWithoutZone(ip)
}
if cli != nil {
return &querylog.Client{ return &querylog.Client{
Name: cli.Name, Name: cli.Name,
IgnoreQueryLog: cli.IgnoreQueryLog, IgnoreQueryLog: cli.IgnoreQueryLog,
}, false }, false
} }
var rc *client.Runtime rc := clients.findRuntimeClient(ip)
rc, ok = clients.findRuntimeClient(ip) if rc != nil {
if ok {
_, host := rc.Info() _, host := rc.Info()
return &querylog.Client{ return &querylog.Client{
@ -542,47 +554,38 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
return nil, false return nil, false
} }
for _, c = range clients.list { return clients.clientIndex.FindByMAC(foundMAC)
_, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr])
if found {
return c, true
}
}
return nil, false
} }
// runtimeClient returns a runtime client from internal index. Note that it // runtimeClient returns a runtime client from internal index. Note that it
// doesn't include DHCP clients. // doesn't include DHCP clients.
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) { func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
if ip == (netip.Addr{}) { if ip == (netip.Addr{}) {
return nil, false return nil
} }
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
rc, ok = clients.ipToRC[ip] return clients.runtimeIndex.Client(ip)
return rc, ok
} }
// findRuntimeClient finds a runtime client by their IP. // findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) { func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
rc, ok = clients.runtimeClient(ip) rc = clients.runtimeClient(ip)
host := clients.dhcp.HostByIP(ip) host := clients.dhcp.HostByIP(ip)
if host != "" { if host != "" {
if !ok { if rc == nil {
rc = &client.Runtime{} rc = client.NewRuntime(ip)
} }
rc.SetInfo(client.SourceDHCP, []string{host}) rc.SetInfo(client.SourceDHCP, []string{host})
return rc, true return rc
} }
return rc, ok return rc
} }
// check validates the client. It also sorts the client tags. // check validates the client. It also sorts the client tags.
@ -615,43 +618,32 @@ func (clients *clientsContainer) check(c *client.Persistent) (err error) {
return nil return nil
} }
// add adds a new client object. ok is false if such client already exists or // add adds a persistent client or returns an error.
// if an error occurred. func (clients *clientsContainer) add(c *client.Persistent) (err error) {
func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) {
err = clients.check(c) err = clients.check(c)
if err != nil { if err != nil {
return false, err // Don't wrap the error since it's informative enough as is.
return err
} }
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
// check Name index
_, ok = clients.list[c.Name]
if ok {
return false, nil
}
// check ID index
err = clients.clientIndex.Clashes(c) err = clients.clientIndex.Clashes(c)
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. // Don't wrap the error since it's informative enough as is.
return false, err return err
} }
clients.addLocked(c) clients.addLocked(c)
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), len(clients.list)) log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size())
return true, nil return nil
} }
// addLocked c to the indexes. clients.lock is expected to be locked. // addLocked c to the indexes. clients.lock is expected to be locked.
func (clients *clientsContainer) addLocked(c *client.Persistent) { func (clients *clientsContainer) addLocked(c *client.Persistent) {
// update Name index
clients.list[c.Name] = c
// update ID index
clients.clientIndex.Add(c) clients.clientIndex.Add(c)
} }
@ -660,8 +652,7 @@ func (clients *clientsContainer) remove(name string) (ok bool) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
var c *client.Persistent c, ok := clients.clientIndex.FindByName(name)
c, ok = clients.list[name]
if !ok { if !ok {
return false return false
} }
@ -678,9 +669,6 @@ func (clients *clientsContainer) removeLocked(c *client.Persistent) {
log.Error("client container: removing client %s: %s", c.Name, err) log.Error("client container: removing client %s: %s", c.Name, err)
} }
// Update the name index.
delete(clients.list, c.Name)
// Update the ID index. // Update the ID index.
clients.clientIndex.Delete(c) clients.clientIndex.Delete(c)
} }
@ -696,22 +684,6 @@ func (clients *clientsContainer) update(prev, c *client.Persistent) (err error)
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
// Check the name index.
if prev.Name != c.Name {
_, ok := clients.list[c.Name]
if ok {
return errors.Error("client already exists")
}
}
if c.EqualIDs(prev) {
clients.removeLocked(prev)
clients.addLocked(c)
return nil
}
// Check the ID index.
err = clients.clientIndex.Clashes(c) err = clients.clientIndex.Clashes(c)
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. // Don't wrap the error since it's informative enough as is.
@ -734,12 +706,12 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
return return
} }
rc, ok := clients.ipToRC[ip] rc := clients.runtimeIndex.Client(ip)
if !ok { if rc == nil {
// Create a RuntimeClient implicitly so that we don't do this check // Create a RuntimeClient implicitly so that we don't do this check
// again. // again.
rc = &client.Runtime{} rc = client.NewRuntime(ip)
clients.ipToRC[ip] = rc clients.runtimeIndex.Add(rc)
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
} else { } else {
@ -798,61 +770,54 @@ func (clients *clientsContainer) addHostLocked(
host string, host string,
src client.Source, src client.Source,
) (ok bool) { ) (ok bool) {
rc, ok := clients.ipToRC[ip] rc := clients.runtimeIndex.Client(ip)
if !ok { if rc == nil {
if src < client.SourceDHCP { if src < client.SourceDHCP {
if clients.dhcp.HostByIP(ip) != "" { if clients.dhcp.HostByIP(ip) != "" {
return false return false
} }
} }
rc = &client.Runtime{} rc = client.NewRuntime(ip)
clients.ipToRC[ip] = rc clients.runtimeIndex.Add(rc)
} }
rc.SetInfo(src, []string{host}) rc.SetInfo(src, []string{host})
log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, len(clients.ipToRC)) log.Debug(
"clients: adding client info %s -> %q %q [%d]",
ip,
src,
host,
clients.runtimeIndex.Size(),
)
return true return true
} }
// rmHostsBySrc removes all entries that match the specified source.
func (clients *clientsContainer) rmHostsBySrc(src client.Source) {
n := 0
for ip, rc := range clients.ipToRC {
rc.Unset(src)
if rc.IsEmpty() {
delete(clients.ipToRC, ip)
n++
}
}
log.Debug("clients: removed %d client aliases", n)
}
// addFromHostsFile fills the client-hostname pairing index from the system's // addFromHostsFile fills the client-hostname pairing index from the system's
// hosts files. // hosts files.
func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) { func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
clients.rmHostsBySrc(client.SourceHostsFile) deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
log.Debug("clients: removed %d client aliases from system hosts file", deleted)
n := 0 added := 0
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
// Only the first name of the first record is considered a canonical // Only the first name of the first record is considered a canonical
// hostname for the IP address. // hostname for the IP address.
// //
// TODO(e.burkov): Consider using all the names from all the records. // TODO(e.burkov): Consider using all the names from all the records.
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) { if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
n++ added++
} }
return true return true
}) })
log.Debug("clients: added %d client aliases from system hosts file", n) log.Debug("clients: added %d client aliases from system hosts file", added)
} }
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
@ -876,7 +841,8 @@ func (clients *clientsContainer) addFromSystemARP() {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
clients.rmHostsBySrc(client.SourceARP) deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)
added := 0 added := 0
for _, n := range ns { for _, n := range ns {
@ -891,18 +857,5 @@ func (clients *clientsContainer) addFromSystemARP() {
// close gracefully closes all the client-specific upstream configurations of // close gracefully closes all the client-specific upstream configurations of
// the persistent clients. // the persistent clients.
func (clients *clientsContainer) close() (err error) { func (clients *clientsContainer) close() (err error) {
persistent := maps.Values(clients.list) return clients.clientIndex.CloseUpstreams()
slices.SortFunc(persistent, func(a, b *client.Persistent) (res int) {
return strings.Compare(a.Name, b.Name)
})
var errs []error
for _, cli := range persistent {
if err = cli.CloseUpstreams(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
} }

View File

@ -41,7 +41,7 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
} }
dhcp := &testDHCP{ dhcp := &testDHCP{
OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") }, OnLeases: func() (leases []*dhcpsvc.Lease) { return nil },
OnHostBy: func(ip netip.Addr) (host string) { return "" }, OnHostBy: func(ip netip.Addr) (host string) { return "" },
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil }, OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
} }
@ -72,23 +72,19 @@ func TestClients(t *testing.T) {
IPs: []netip.Addr{cli1IP, cliIPv6}, IPs: []netip.Addr{cli1IP, cliIPv6},
} }
ok, err := clients.add(c) err := clients.add(c)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok)
c = &client.Persistent{ c = &client.Persistent{
Name: "client2", Name: "client2",
UID: client.MustNewUID(), UID: client.MustNewUID(),
IPs: []netip.Addr{cli2IP}, IPs: []netip.Addr{cli2IP},
} }
ok, err = clients.add(c) err = clients.add(c)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok) c, ok := clients.find(cli1)
c, ok = clients.find(cli1)
require.True(t, ok) require.True(t, ok)
assert.Equal(t, "client1", c.Name) assert.Equal(t, "client1", c.Name)
@ -111,22 +107,20 @@ func TestClients(t *testing.T) {
}) })
t.Run("add_fail_name", func(t *testing.T) { t.Run("add_fail_name", func(t *testing.T) {
ok, err := clients.add(&client.Persistent{ err := clients.add(&client.Persistent{
Name: "client1", Name: "client1",
UID: client.MustNewUID(), UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
}) })
require.NoError(t, err) require.Error(t, err)
assert.False(t, ok)
}) })
t.Run("add_fail_ip", func(t *testing.T) { t.Run("add_fail_ip", func(t *testing.T) {
ok, err := clients.add(&client.Persistent{ err := clients.add(&client.Persistent{
Name: "client3", Name: "client3",
UID: client.MustNewUID(), UID: client.MustNewUID(),
}) })
require.Error(t, err) require.Error(t, err)
assert.False(t, ok)
}) })
t.Run("update_fail_ip", func(t *testing.T) { t.Run("update_fail_ip", func(t *testing.T) {
@ -145,12 +139,13 @@ func TestClients(t *testing.T) {
cliNewIP = netip.MustParseAddr(cliNew) cliNewIP = netip.MustParseAddr(cliNew)
) )
prev, ok := clients.list["client1"] prev, ok := clients.clientIndex.FindByName("client1")
require.True(t, ok) require.True(t, ok)
require.NotNil(t, prev)
err := clients.update(prev, &client.Persistent{ err := clients.update(prev, &client.Persistent{
Name: "client1", Name: "client1",
UID: client.MustNewUID(), UID: prev.UID,
IPs: []netip.Addr{cliNewIP}, IPs: []netip.Addr{cliNewIP},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -160,12 +155,13 @@ func TestClients(t *testing.T) {
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent) assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
prev, ok = clients.list["client1"] prev, ok = clients.clientIndex.FindByName("client1")
require.True(t, ok) require.True(t, ok)
require.NotNil(t, prev)
err = clients.update(prev, &client.Persistent{ err = clients.update(prev, &client.Persistent{
Name: "client1-renamed", Name: "client1-renamed",
UID: client.MustNewUID(), UID: prev.UID,
IPs: []netip.Addr{cliNewIP}, IPs: []netip.Addr{cliNewIP},
UseOwnSettings: true, UseOwnSettings: true,
}) })
@ -177,7 +173,7 @@ func TestClients(t *testing.T) {
assert.Equal(t, "client1-renamed", c.Name) assert.Equal(t, "client1-renamed", c.Name)
assert.True(t, c.UseOwnSettings) assert.True(t, c.UseOwnSettings)
nilCli, ok := clients.list["client1"] nilCli, ok := clients.clientIndex.FindByName("client1")
require.False(t, ok) require.False(t, ok)
assert.Nil(t, nilCli) assert.Nil(t, nilCli)
@ -244,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("new_client", func(t *testing.T) { t.Run("new_client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.255") ip := netip.MustParseAddr("1.1.1.255")
clients.setWHOISInfo(ip, whois) clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip] rc := clients.runtimeIndex.Client(ip)
require.NotNil(t, rc) require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS()) assert.Equal(t, whois, rc.WHOIS())
@ -256,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
clients.setWHOISInfo(ip, whois) clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip] rc := clients.runtimeIndex.Client(ip)
require.NotNil(t, rc) require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS()) assert.Equal(t, whois, rc.WHOIS())
@ -265,16 +261,15 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("can't_set_manually-added", func(t *testing.T) { t.Run("can't_set_manually-added", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.2") ip := netip.MustParseAddr("1.1.1.2")
ok, err := clients.add(&client.Persistent{ err := clients.add(&client.Persistent{
Name: "client1", Name: "client1",
UID: client.MustNewUID(), UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
}) })
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok)
clients.setWHOISInfo(ip, whois) clients.setWHOISInfo(ip, whois)
rc := clients.ipToRC[ip] rc := clients.runtimeIndex.Client(ip)
require.Nil(t, rc) require.Nil(t, rc)
assert.True(t, clients.remove("client1")) assert.True(t, clients.remove("client1"))
@ -288,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1") ip := netip.MustParseAddr("1.1.1.1")
// Add a client. // Add a client.
ok, err := clients.add(&client.Persistent{ err := clients.add(&client.Persistent{
Name: "client1", Name: "client1",
UID: client.MustNewUID(), UID: client.MustNewUID(),
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
@ -296,10 +291,9 @@ func TestClientsAddExisting(t *testing.T) {
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
}) })
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok)
// Now add an auto-client with the same IP. // Now add an auto-client with the same IP.
ok = clients.addHost(ip, "test", client.SourceRDNS) ok := clients.addHost(ip, "test", client.SourceRDNS)
assert.True(t, ok) assert.True(t, ok)
}) })
@ -339,22 +333,20 @@ func TestClientsAddExisting(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Add a new client with the same IP as for a client with MAC. // Add a new client with the same IP as for a client with MAC.
ok, err := clients.add(&client.Persistent{ err = clients.add(&client.Persistent{
Name: "client2", Name: "client2",
UID: client.MustNewUID(), UID: client.MustNewUID(),
IPs: []netip.Addr{ip}, IPs: []netip.Addr{ip},
}) })
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok)
// Add a new client with the IP from the first client's IP range. // Add a new client with the IP from the first client's IP range.
ok, err = clients.add(&client.Persistent{ err = clients.add(&client.Persistent{
Name: "client3", Name: "client3",
UID: client.MustNewUID(), UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
}) })
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok)
}) })
} }
@ -362,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer(t) clients := newClientsContainer(t)
// Add client with upstreams. // Add client with upstreams.
ok, err := clients.add(&client.Persistent{ err := clients.add(&client.Persistent{
Name: "client1", Name: "client1",
UID: client.MustNewUID(), UID: client.MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
@ -372,7 +364,6 @@ func TestClientsCustomUpstream(t *testing.T) {
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
assert.True(t, ok)
upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver) upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver)
assert.Nil(t, upsConf) assert.Nil(t, upsConf)

View File

@ -96,22 +96,26 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
for _, c := range clients.list { clients.clientIndex.Range(func(c *client.Persistent) (cont bool) {
cj := clientToJSON(c) cj := clientToJSON(c)
data.Clients = append(data.Clients, cj) data.Clients = append(data.Clients, cj)
}
for ip, rc := range clients.ipToRC { return true
})
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) {
src, host := rc.Info() src, host := rc.Info()
cj := runtimeClientJSON{ cj := runtimeClientJSON{
WHOIS: whoisOrEmpty(rc), WHOIS: whoisOrEmpty(rc),
Name: host, Name: host,
Source: src, Source: src,
IP: ip, IP: rc.Addr(),
} }
data.RuntimeClients = append(data.RuntimeClients, cj) data.RuntimeClients = append(data.RuntimeClients, cj)
}
return true
})
for _, l := range clients.dhcp.Leases() { for _, l := range clients.dhcp.Leases() {
cj := runtimeClientJSON{ cj := runtimeClientJSON{
@ -332,21 +336,17 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
return return
} }
ok, err := clients.add(c) err = clients.add(c)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
return return
} }
if !ok { if !clients.testing {
aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists")
return
}
onConfigModified() onConfigModified()
} }
}
// handleDelClient is the handler for POST /control/clients/delete HTTP API. // handleDelClient is the handler for POST /control/clients/delete HTTP API.
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) { func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
@ -370,8 +370,10 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
return return
} }
if !clients.testing {
onConfigModified() onConfigModified()
} }
}
// updateJSON contains the name and data of the updated persistent client. // updateJSON contains the name and data of the updated persistent client.
type updateJSON struct { type updateJSON struct {
@ -404,7 +406,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
prev, ok = clients.list[dj.Name] prev, ok = clients.clientIndex.FindByName(dj.Name)
}() }()
if !ok { if !ok {
@ -427,14 +429,16 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
return return
} }
if !clients.testing {
onConfigModified() onConfigModified()
} }
}
// handleFindClient is the handler for GET /control/clients/find HTTP API. // handleFindClient is the handler for GET /control/clients/find HTTP API.
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) { func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query() q := r.URL.Query()
data := []map[string]*clientJSON{} data := []map[string]*clientJSON{}
for i := 0; i < len(q); i++ { for i := range len(q) {
idStr := q.Get(fmt.Sprintf("ip%d", i)) idStr := q.Get(fmt.Sprintf("ip%d", i))
if idStr == "" { if idStr == "" {
break break
@ -447,7 +451,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
cj = clients.findRuntime(ip, idStr) cj = clients.findRuntime(ip, idStr)
} else { } else {
cj = clientToJSON(c) cj = clientToJSON(c)
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
} }
@ -463,14 +467,14 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be // /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
// non-nil. // non-nil.
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) { func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
rc, ok := clients.findRuntimeClient(ip) rc := clients.findRuntimeClient(ip)
if !ok { if rc == nil {
// It is still possible that the IP used to be in the runtime clients // It is still possible that the IP used to be in the runtime clients
// list, but then the server was reloaded. So, check the DNS server's // list, but then the server was reloaded. So, check the DNS server's
// blocked IP list. // blocked IP list.
// //
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428. // See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
cj = &clientJSON{ cj = &clientJSON{
IDs: []string{idStr}, IDs: []string{idStr},
Disallowed: &disallowed, Disallowed: &disallowed,
@ -488,7 +492,7 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c
WHOIS: whoisOrEmpty(rc), WHOIS: whoisOrEmpty(rc),
} }
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr)
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
return cj return cj

View File

@ -0,0 +1,399 @@
package home
import (
"bytes"
"cmp"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"net/url"
"slices"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
testClientIP1 = "1.1.1.1"
testClientIP2 = "2.2.2.2"
)
// testBlockedClientChecker is a mock implementation of the
// [BlockedClientChecker] interface.
type testBlockedClientChecker struct {
onIsBlockedClient func(ip netip.Addr, clientiD string) (blocked bool, rule string)
}
// type check
var _ BlockedClientChecker = (*testBlockedClientChecker)(nil)
// IsBlockedClient implements the [BlockedClientChecker] interface for
// *testBlockedClientChecker.
func (c *testBlockedClientChecker) IsBlockedClient(
ip netip.Addr,
clientID string,
) (blocked bool, rule string) {
return c.onIsBlockedClient(ip, clientID)
}
// newPersistentClient is a helper function that returns a persistent client
// with the specified name and newly generated UID.
func newPersistentClient(name string) (c *client.Persistent) {
return &client.Persistent{
Name: name,
UID: client.MustNewUID(),
BlockedServices: &filtering.BlockedServices{
Schedule: &schedule.Weekly{},
},
}
}
// newPersistentClientWithIDs is a helper function that returns a persistent
// client with the specified name and ids.
func newPersistentClientWithIDs(tb testing.TB, name string, ids []string) (c *client.Persistent) {
tb.Helper()
c = newPersistentClient(name)
err := c.SetIDs(ids)
require.NoError(tb, err)
return c
}
// assertClients is a helper function that compares lists of persistent clients.
func assertClients(tb testing.TB, want, got []*client.Persistent) {
tb.Helper()
require.Len(tb, got, len(want))
sortFunc := func(a, b *client.Persistent) (n int) {
return cmp.Compare(a.Name, b.Name)
}
slices.SortFunc(want, sortFunc)
slices.SortFunc(got, sortFunc)
slices.CompareFunc(want, got, func(a, b *client.Persistent) (n int) {
assert.True(tb, a.EqualIDs(b), "%q doesn't have the same ids as %q", a.Name, b.Name)
return 0
})
}
// assertPersistentClients is a helper function that uses HTTP API to check
// whether want persistent clients are the same as the persistent clients stored
// in the clients container.
func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*client.Persistent) {
tb.Helper()
rw := httptest.NewRecorder()
clients.handleGetClients(rw, &http.Request{})
body, err := io.ReadAll(rw.Body)
require.NoError(tb, err)
clientList := &clientListJSON{}
err = json.Unmarshal(body, clientList)
require.NoError(tb, err)
var got []*client.Persistent
for _, cj := range clientList.Clients {
var c *client.Persistent
c, err = clients.jsonToClient(*cj, nil)
require.NoError(tb, err)
got = append(got, c)
}
assertClients(tb, want, got)
}
// assertPersistentClientsData is a helper function that checks whether want
// persistent clients are the same as the persistent clients stored in data.
func assertPersistentClientsData(
tb testing.TB,
clients *clientsContainer,
data []map[string]*clientJSON,
want []*client.Persistent,
) {
tb.Helper()
var got []*client.Persistent
for _, cm := range data {
for _, cj := range cm {
var c *client.Persistent
c, err := clients.jsonToClient(*cj, nil)
require.NoError(tb, err)
got = append(got, c)
}
}
assertClients(tb, want, got)
}
func TestClientsContainer_HandleAddClient(t *testing.T) {
clients := newClientsContainer(t)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
clientEmptyID := newPersistentClient("empty_client_id")
clientEmptyID.ClientIDs = []string{""}
testCases := []struct {
name string
client *client.Persistent
wantCode int
wantClient []*client.Persistent
}{{
name: "add_one",
client: clientOne,
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientOne},
}, {
name: "add_two",
client: clientTwo,
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientOne, clientTwo},
}, {
name: "duplicate_client",
client: clientTwo,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientOne, clientTwo},
}, {
name: "empty_client_id",
client: clientEmptyID,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientOne, clientTwo},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cj := clientToJSON(tc.client)
body, err := json.Marshal(cj)
require.NoError(t, err)
r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
require.NoError(t, err)
rw := httptest.NewRecorder()
clients.handleAddClient(rw, r)
require.NoError(t, err)
require.Equal(t, tc.wantCode, rw.Code)
assertPersistentClients(t, clients, tc.wantClient)
})
}
}
func TestClientsContainer_HandleDelClient(t *testing.T) {
clients := newClientsContainer(t)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.add(clientOne)
require.NoError(t, err)
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
err = clients.add(clientTwo)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
testCases := []struct {
name string
client *client.Persistent
wantCode int
wantClient []*client.Persistent
}{{
name: "remove_one",
client: clientOne,
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientTwo},
}, {
name: "duplicate_client",
client: clientOne,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientTwo},
}, {
name: "empty_client_name",
client: newPersistentClient(""),
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientTwo},
}, {
name: "remove_two",
client: clientTwo,
wantCode: http.StatusOK,
wantClient: []*client.Persistent{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cj := clientToJSON(tc.client)
var body []byte
body, err = json.Marshal(cj)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
require.NoError(t, err)
rw := httptest.NewRecorder()
clients.handleDelClient(rw, r)
require.NoError(t, err)
require.Equal(t, tc.wantCode, rw.Code)
assertPersistentClients(t, clients, tc.wantClient)
})
}
}
func TestClientsContainer_HandleUpdateClient(t *testing.T) {
clients := newClientsContainer(t)
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.add(clientOne)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne})
clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
clientEmptyID := newPersistentClient("empty_client_id")
clientEmptyID.ClientIDs = []string{""}
testCases := []struct {
name string
clientName string
modified *client.Persistent
wantCode int
wantClient []*client.Persistent
}{{
name: "update_one",
clientName: clientOne.Name,
modified: clientModified,
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientModified},
}, {
name: "empty_name",
clientName: "",
modified: clientOne,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientModified},
}, {
name: "client_not_found",
clientName: "client_not_found",
modified: clientOne,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientModified},
}, {
name: "empty_client_id",
clientName: clientModified.Name,
modified: clientEmptyID,
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientModified},
}, {
name: "no_ids",
clientName: clientModified.Name,
modified: newPersistentClient("no_ids"),
wantCode: http.StatusBadRequest,
wantClient: []*client.Persistent{clientModified},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
uj := updateJSON{
Name: tc.clientName,
Data: *clientToJSON(tc.modified),
}
var body []byte
body, err = json.Marshal(uj)
require.NoError(t, err)
var r *http.Request
r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body))
require.NoError(t, err)
rw := httptest.NewRecorder()
clients.handleUpdateClient(rw, r)
require.NoError(t, err)
require.Equal(t, tc.wantCode, rw.Code)
assertPersistentClients(t, clients, tc.wantClient)
})
}
}
func TestClientsContainer_HandleFindClient(t *testing.T) {
clients := newClientsContainer(t)
clients.clientChecker = &testBlockedClientChecker{
onIsBlockedClient: func(ip netip.Addr, clientID string) (ok bool, rule string) {
return false, ""
},
}
clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1})
err := clients.add(clientOne)
require.NoError(t, err)
clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2})
err = clients.add(clientTwo)
require.NoError(t, err)
assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo})
testCases := []struct {
name string
query url.Values
wantCode int
wantClient []*client.Persistent
}{{
name: "single",
query: url.Values{
"ip0": []string{testClientIP1},
},
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientOne},
}, {
name: "multiple",
query: url.Values{
"ip0": []string{testClientIP1},
"ip1": []string{testClientIP2},
},
wantCode: http.StatusOK,
wantClient: []*client.Persistent{clientOne, clientTwo},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var r *http.Request
r, err = http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)
r.URL.RawQuery = tc.query.Encode()
rw := httptest.NewRecorder()
clients.handleFindClient(rw, r)
require.NoError(t, err)
require.Equal(t, tc.wantCode, rw.Code)
var body []byte
body, err = io.ReadAll(rw.Body)
require.NoError(t, err)
clientData := []map[string]*clientJSON{}
err = json.Unmarshal(body, &clientData)
require.NoError(t, err)
assertPersistentClientsData(t, clients, clientData, tc.wantClient)
})
}
}

View File

@ -203,15 +203,24 @@ type dnsConfig struct {
// resolver should be used. // resolver should be used.
PrivateNets []netutil.Prefix `yaml:"private_networks"` PrivateNets []netutil.Prefix `yaml:"private_networks"`
// UsePrivateRDNS defines if the PTR requests for unknown addresses from // UsePrivateRDNS enables resolving requests containing a private IP address
// locally-served networks should be resolved via private PTR resolvers. // using private reverse DNS resolvers. See PrivateRDNSResolvers.
//
// TODO(e.burkov): Rename in YAML.
UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"` UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"`
// LocalPTRResolvers is the slice of addresses to be used as upstreams // PrivateRDNSResolvers is the slice of addresses to be used as upstreams
// for PTR queries for locally-served networks. // for private requests. It's only used for PTR, SOA, and NS queries,
LocalPTRResolvers []string `yaml:"local_ptr_upstreams"` // containing an ARPA subdomain, came from the the client with private
// address. The address considered private according to PrivateNets.
//
// If empty, the OS-provided resolvers are used for private requests.
PrivateRDNSResolvers []string `yaml:"local_ptr_upstreams"`
// UseDNS64 defines if DNS64 should be used for incoming requests. // UseDNS64 defines if DNS64 should be used for incoming requests. Requests
// of type PTR for addresses within the configured prefixes will be resolved
// via [PrivateRDNSResolvers], so those should be valid and UsePrivateRDNS
// be set to true.
UseDNS64 bool `yaml:"use_dns64"` UseDNS64 bool `yaml:"use_dns64"`
// DNS64Prefixes is the list of NAT64 prefixes to be used for DNS64. // DNS64Prefixes is the list of NAT64 prefixes to be used for DNS64.
@ -658,7 +667,7 @@ func (c *configuration) write() (err error) {
dns := &config.DNS dns := &config.DNS
dns.Config = c dns.Config = c
dns.LocalPTRResolvers = s.LocalPTRResolvers() dns.PrivateRDNSResolvers = s.LocalPTRResolvers()
addrProcConf := s.AddrProcConfig() addrProcConf := s.AddrProcConfig()
config.Clients.Sources.RDNS = addrProcConf.UseRDNS config.Clients.Sources.RDNS = addrProcConf.UseRDNS

View File

@ -1,7 +1,6 @@
package home package home
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@ -18,7 +17,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@ -150,21 +148,19 @@ func initDNSServer(
return fmt.Errorf("dnsforward.NewServer: %w", err) return fmt.Errorf("dnsforward.NewServer: %w", err)
} }
Context.clients.dnsServer = Context.dnsServer Context.clients.clientChecker = Context.dnsServer
dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg) dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg)
if err != nil { if err != nil {
return fmt.Errorf("newServerConfig: %w", err) return fmt.Errorf("newServerConfig: %w", err)
} }
// Try to prepare the server with disabled private RDNS resolution if it
// failed to prepare as is. See TODO on [ErrBadPrivateRDNSUpstreams].
err = Context.dnsServer.Prepare(dnsConf) err = Context.dnsServer.Prepare(dnsConf)
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
// TODO(e.burkov): Recreate the server with private RDNS disabled. This
// should go away once the private RDNS resolution is moved to the proxy.
var locResErr *dnsforward.LocalResolversError
if errors.As(err, &locResErr) && errors.Is(locResErr.Err, upstream.ErrNoUpstreams) {
log.Info("WARNING: no local resolvers configured while private RDNS " +
"resolution enabled, trying to disable")
dnsConf.UsePrivateRDNS = false dnsConf.UsePrivateRDNS = false
err = Context.dnsServer.Prepare(dnsConf) err = Context.dnsServer.Prepare(dnsConf)
} }
@ -245,7 +241,7 @@ func newServerConfig(
TLSv12Roots: Context.tlsRoots, TLSv12Roots: Context.tlsRoots,
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
HTTPRegister: httpReg, HTTPRegister: httpReg,
LocalPTRResolvers: dnsConf.LocalPTRResolvers, LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
UseDNS64: dnsConf.UseDNS64, UseDNS64: dnsConf.UseDNS64,
DNS64Prefixes: dnsConf.DNS64Prefixes, DNS64Prefixes: dnsConf.DNS64Prefixes,
UsePrivateRDNS: dnsConf.UsePrivateRDNS, UsePrivateRDNS: dnsConf.UsePrivateRDNS,
@ -531,36 +527,6 @@ func closeDNSServer() {
log.Debug("all dns modules are closed") log.Debug("all dns modules are closed")
} }
// safeSearchResolver is a [filtering.Resolver] implementation used for safe
// search.
type safeSearchResolver struct{}
// type check
var _ filtering.Resolver = safeSearchResolver{}
// LookupIP implements [filtering.Resolver] interface for safeSearchResolver.
// It returns the slice of net.Addr with IPv4 and IPv6 instances.
func (r safeSearchResolver) LookupIP(
ctx context.Context,
network string,
host string,
) (ips []net.IP, err error) {
addrs, err := Context.dnsServer.Resolve(ctx, network, host)
if err != nil {
return nil, err
}
if len(addrs) == 0 {
return nil, fmt.Errorf("couldn't lookup host: %s", host)
}
for _, a := range addrs {
ips = append(ips, a.AsSlice())
}
return ips, nil
}
// checkStatsAndQuerylogDirs checks and returns directory paths to store // checkStatsAndQuerylogDirs checks and returns directory paths to store
// statistics and query log. // statistics and query log.
func checkStatsAndQuerylogDirs( func checkStatsAndQuerylogDirs(

View File

@ -439,7 +439,6 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
conf.ParentalBlockHost = host conf.ParentalBlockHost = host
} }
conf.SafeSearchConf.CustomResolver = safeSearchResolver{}
conf.SafeSearch, err = safesearch.NewDefault( conf.SafeSearch, err = safesearch.NewDefault(
conf.SafeSearchConf, conf.SafeSearchConf,
"default", "default",

View File

@ -1,13 +1,13 @@
package home package home
import ( import (
"cmp"
"fmt" "fmt"
"path/filepath" "path/filepath"
"runtime" "runtime"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"gopkg.in/natefinch/lumberjack.v2" "gopkg.in/natefinch/lumberjack.v2"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -76,8 +76,7 @@ func getLogSettings(opts options) (ls *logSettings) {
ls.Verbose = true ls.Verbose = true
} }
// TODO(a.garipov): Use cmp.Or in Go 1.22. ls.File = cmp.Or(opts.logFile, ls.File)
ls.File = stringutil.Coalesce(opts.logFile, ls.File)
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" { if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if // When running as a Windows service, use eventlog by default if

View File

@ -306,7 +306,7 @@ func handleServiceStatusCommand(s service.Service) {
} }
} }
// handleServiceStatusCommand handles service "install" command // handleServiceInstallCommand handles service "install" command.
func handleServiceInstallCommand(s service.Service) { func handleServiceInstallCommand(s service.Service) {
err := svcAction(s, "install") err := svcAction(s, "install")
if err != nil { if err != nil {
@ -340,7 +340,7 @@ AdGuard Home is now available at the following addresses:`)
} }
} }
// handleServiceStatusCommand handles service "uninstall" command // handleServiceUninstallCommand handles service "uninstall" command.
func handleServiceUninstallCommand(s service.Service) { func handleServiceUninstallCommand(s service.Service) {
if aghos.IsOpenWrt() { if aghos.IsOpenWrt() {
// On OpenWrt it is important to run disable command first // On OpenWrt it is important to run disable command first
@ -649,11 +649,6 @@ status() {
// freeBSDScript is the source of the daemon script for FreeBSD. Keep as close // freeBSDScript is the source of the daemon script for FreeBSD. Keep as close
// as possible to the https://github.com/kardianos/service/blob/18c957a3dc1120a2efe77beb401d476bade9e577/service_freebsd.go#L204. // as possible to the https://github.com/kardianos/service/blob/18c957a3dc1120a2efe77beb401d476bade9e577/service_freebsd.go#L204.
//
// TODO(a.garipov): Don't use .WorkingDirectory here. There are currently no
// guarantees that it will actually be the required directory.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2614.
const freeBSDScript = `#!/bin/sh const freeBSDScript = `#!/bin/sh
# PROVIDE: {{.Name}} # PROVIDE: {{.Name}}
# REQUIRE: networking # REQUIRE: networking
@ -667,7 +662,9 @@ name="{{.Name}}"
pidfile_child="/var/run/${name}.pid" pidfile_child="/var/run/${name}.pid"
pidfile="/var/run/${name}_daemon.pid" pidfile="/var/run/${name}_daemon.pid"
command="/usr/sbin/daemon" command="/usr/sbin/daemon"
command_args="-P ${pidfile} -p ${pidfile_child} -T ${name} -r {{.WorkingDirectory}}/{{.Name}}" daemon_args="-P ${pidfile} -p ${pidfile_child} -r -t ${name}"
command_args="${daemon_args} {{.Path}}{{range .Arguments}} {{.}}{{end}}"
run_rc_command "$1" run_rc_command "$1"
` `

View File

@ -3,6 +3,7 @@
package home package home
import ( import (
"cmp"
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
@ -14,7 +15,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/kardianos/service" "github.com/kardianos/service"
) )
@ -76,7 +76,7 @@ func (*openbsdRunComService) Platform() (p string) {
// String implements service.Service interface for *openbsdRunComService. // String implements service.Service interface for *openbsdRunComService.
func (s *openbsdRunComService) String() string { func (s *openbsdRunComService) String() string {
return stringutil.Coalesce(s.cfg.DisplayName, s.cfg.Name) return cmp.Or(s.cfg.DisplayName, s.cfg.Name)
} }
// getBool returns the value of the given name from kv, assuming the value is a // getBool returns the value of the given name from kv, assuming the value is a

View File

@ -147,7 +147,7 @@ func BenchmarkManager_LookupHost(b *testing.B) {
b.Run("long", func(b *testing.B) { b.Run("long", func(b *testing.B) {
const name = "a.very.long.domain.name.inside.the.domain.example.com" const name = "a.very.long.domain.name.inside.the.domain.example.com"
for i := 0; i < b.N; i++ { for range b.N {
ipsetPropsSink = m.lookupHost(name) ipsetPropsSink = m.lookupHost(name)
} }
@ -156,7 +156,7 @@ func BenchmarkManager_LookupHost(b *testing.B) {
b.Run("short", func(b *testing.B) { b.Run("short", func(b *testing.B) {
const name = "example.net" const name = "example.net"
for i := 0; i < b.N; i++ { for range b.N {
ipsetPropsSink = m.lookupHost(name) ipsetPropsSink = m.lookupHost(name)
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil"
"github.com/google/renameio/v2/maybe" "github.com/google/renameio/v2/maybe"
) )
@ -38,7 +39,7 @@ func (h *signalHandler) handle() {
if aghos.IsReconfigureSignal(sig) { if aghos.IsReconfigureSignal(sig) {
h.reconfigure() h.reconfigure()
} else if aghos.IsShutdownSignal(sig) { } else if osutil.IsShutdownSignal(sig) {
status := h.shutdown() status := h.shutdown()
h.removePID() h.removePID()
@ -122,7 +123,8 @@ func newSignalHandler(
services: svcs, services: svcs,
} }
aghos.NotifyShutdownSignal(h.signal) notifier := osutil.DefaultSignalNotifier{}
osutil.NotifyShutdownSignal(notifier, h.signal)
aghos.NotifyReconfigureSignal(h.signal) aghos.NotifyReconfigureSignal(h.signal)
return h return h

View File

@ -1,7 +1,6 @@
package dnssvc_test package dnssvc_test
import ( import (
"context"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
@ -94,10 +93,8 @@ func TestService(t *testing.T) {
}}, }},
} }
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
cli := &dns.Client{} cli := &dns.Client{}
ctx := testutil.ContextWithTimeout(t, testTimeout)
var resp *dns.Msg var resp *dns.Msg
require.Eventually(t, func() (ok bool) { require.Eventually(t, func() (ok bool) {
@ -110,10 +107,8 @@ func TestService(t *testing.T) {
assert.NotNil(t, resp) assert.NotNil(t, resp)
}) })
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) err = svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
defer cancel()
err = svc.Shutdown(ctx)
require.NoError(t, err) require.NoError(t, err)
err = upstreamSrv.Shutdown() err = upstreamSrv.Shutdown()

View File

@ -109,12 +109,8 @@ func newTestServer(
err = svc.Start() err = svc.Start()
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { testutil.CleanupAndRequireSuccess(t, func() (err error) {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
t.Cleanup(cancel)
err = svc.Shutdown(ctx)
require.NoError(t, err)
}) })
c = svc.Config() c = svc.Config()

View File

@ -303,7 +303,7 @@ func BenchmarkAnonymizeIP(b *testing.B) {
b.Run(bc.name, func(b *testing.B) { b.Run(bc.name, func(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for range b.N {
AnonymizeIP(bc.ip) AnonymizeIP(bc.ip)
} }
@ -313,7 +313,7 @@ func BenchmarkAnonymizeIP(b *testing.B) {
b.Run(bc.name+"_slow", func(b *testing.B) { b.Run(bc.name+"_slow", func(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
for i := 0; i < b.N; i++ { for range b.N {
anonymizeIPSlow(bc.ip) anonymizeIPSlow(bc.ip)
} }

View File

@ -31,6 +31,7 @@ type logEntry struct {
Answer []byte `json:",omitempty"` Answer []byte `json:",omitempty"`
OrigAnswer []byte `json:",omitempty"` OrigAnswer []byte `json:",omitempty"`
// TODO(s.chzhen): Use netip.Addr.
IP net.IP `json:"IP"` IP net.IP `json:"IP"`
Result filtering.Result Result filtering.Result

View File

@ -143,13 +143,13 @@ func TestQueryLogOffsetLimit(t *testing.T) {
secondPageDomain = "second.example.org" secondPageDomain = "second.example.org"
) )
// Add entries to the log. // Add entries to the log.
for i := 0; i < entNum; i++ { for range entNum {
addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// Write them to the first file. // Write them to the first file.
require.NoError(t, l.flushLogBuffer()) require.NoError(t, l.flushLogBuffer())
// Add more to the in-memory part of log. // Add more to the in-memory part of log.
for i := 0; i < entNum; i++ { for range entNum {
addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
@ -215,7 +215,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
const entNum = 10 const entNum = 10
// Add entries to the log. // Add entries to the log.
for i := 0; i < entNum; i++ { for range entNum {
addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// Write them to disk. // Write them to disk.

View File

@ -37,7 +37,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) {
var lineIP uint32 var lineIP uint32
lineTime := time.Date(2020, 2, 18, 19, 36, 35, 920973000, time.UTC) lineTime := time.Date(2020, 2, 18, 19, 36, 35, 920973000, time.UTC)
for i := 0; i < linesNum; i++ { for range linesNum {
lineIP++ lineIP++
lineTime = lineTime.Add(time.Second) lineTime = lineTime.Add(time.Second)

View File

@ -68,13 +68,13 @@ func TestStats_races(t *testing.T) {
startWG, finWG := &sync.WaitGroup{}, &sync.WaitGroup{} startWG, finWG := &sync.WaitGroup{}, &sync.WaitGroup{}
waitCh := make(chan unit) waitCh := make(chan unit)
for i := 0; i < writersNum; i++ { for i := range writersNum {
startWG.Add(1) startWG.Add(1)
finWG.Add(1) finWG.Add(1)
go writeFunc(startWG, finWG, waitCh, i) go writeFunc(startWG, finWG, waitCh, i)
} }
for i := 0; i < readersNum; i++ { for range readersNum {
startWG.Add(1) startWG.Add(1)
finWG.Add(1) finWG.Add(1)
go readFunc(startWG, finWG, waitCh) go readFunc(startWG, finWG, waitCh)
@ -111,7 +111,7 @@ func TestStatsCtx_FillCollectedStats_daily(t *testing.T) {
dailyData := []*unitDB{} dailyData := []*unitDB{}
for i := 0; i < daysCount*24; i++ { for i := range daysCount * 24 {
n := uint64(i) n := uint64(i)
nResult := make([]uint64, resultLast) nResult := make([]uint64, resultLast)
nResult[RFiltered] = n nResult[RFiltered] = n

View File

@ -195,7 +195,7 @@ func TestLargeNumbers(t *testing.T) {
for h := 0; h < hoursNum; h++ { for h := 0; h < hoursNum; h++ {
atomic.AddUint32(&curHour, 1) atomic.AddUint32(&curHour, 1)
for i := 0; i < cliNumPerHour; i++ { for i := range cliNumPerHour {
ip := net.IP{127, 0, byte((i & 0xff00) >> 8), byte(i & 0xff)} ip := net.IP{127, 0, byte((i & 0xff00) >> 8), byte(i & 0xff)}
e := &stats.Entry{ e := &stats.Entry{
Domain: fmt.Sprintf("domain%d.hour%d", i, h), Domain: fmt.Sprintf("domain%d.hour%d", i, h),

View File

@ -525,9 +525,8 @@ func (s *StatsCtx) fillCollectedStatsDaily(
hours := countHours(curHour, days) hours := countHours(curHour, days)
units = units[len(units)-hours:] units = units[len(units)-hours:]
for i := 0; i < len(units); i++ { for i, u := range units {
day := i / 24 day := i / 24
u := units[i]
data.DNSQueries[day] += u.NTotal data.DNSQueries[day] += u.NTotal
data.BlockedFiltering[day] += u.NResult[RFiltered] data.BlockedFiltering[day] += u.NResult[RFiltered]

View File

@ -1,6 +1,6 @@
module github.com/AdguardTeam/AdGuardHome/internal/tools module github.com/AdguardTeam/AdGuardHome/internal/tools
go 1.22.2 go 1.22.3
require ( require (
github.com/fzipp/gocyclo v0.6.0 github.com/fzipp/gocyclo v0.6.0

View File

@ -3,6 +3,7 @@ package whois
import ( import (
"bytes" "bytes"
"cmp"
"context" "context"
"fmt" "fmt"
"io" "io"
@ -17,7 +18,6 @@ import (
"github.com/AdguardTeam/golibs/ioutil" "github.com/AdguardTeam/golibs/ioutil"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/bluele/gcache" "github.com/bluele/gcache"
) )
@ -174,7 +174,7 @@ func whoisParse(data []byte, maxLen int) (info map[string]string) {
val = trimValue(val, maxLen) val = trimValue(val, maxLen)
case "descr", "netname": case "descr", "netname":
key = "orgname" key = "orgname"
val = stringutil.Coalesce(orgname, val) val = cmp.Or(orgname, val)
orgname = val orgname = val
case "whois": case "whois":
key = "whois" key = "whois"
@ -232,7 +232,7 @@ func (w *Default) queryAll(ctx context.Context, target string) (info map[string]
server := net.JoinHostPort(w.serverAddr, w.portStr) server := net.JoinHostPort(w.serverAddr, w.portStr)
var data []byte var data []byte
for i := 0; i < w.maxRedirects; i++ { for range w.maxRedirects {
data, err = w.query(ctx, target, server) data, err = w.query(ctx, target, server)
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. // Don't wrap the error since it's informative enough as is.

View File

@ -48,7 +48,7 @@ func (c *twoskyClient) download() (err error) {
failed := &sync.Map{} failed := &sync.Map{}
uriCh := make(chan *url.URL, len(c.langs)) uriCh := make(chan *url.URL, len(c.langs))
for i := 0; i < numWorker; i++ { for range numWorker {
wg.Add(1) wg.Add(1)
go downloadWorker(wg, failed, client, uriCh) go downloadWorker(wg, failed, client, uriCh)
} }

View File

@ -5,6 +5,7 @@ package main
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"cmp"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url" "net/url"
@ -204,19 +205,13 @@ type twoskyClient struct {
func (t *twoskyConfig) toClient() (cli *twoskyClient, err error) { func (t *twoskyConfig) toClient() (cli *twoskyClient, err error) {
defer func() { err = errors.Annotate(err, "filling config: %w") }() defer func() { err = errors.Annotate(err, "filling config: %w") }()
uriStr := os.Getenv("TWOSKY_URI") uriStr := cmp.Or(os.Getenv("TWOSKY_URI"), twoskyURI)
if uriStr == "" {
uriStr = twoskyURI
}
uri, err := url.Parse(uriStr) uri, err := url.Parse(uriStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
projectID := os.Getenv("TWOSKY_PROJECT_ID") projectID := cmp.Or(os.Getenv("TWOSKY_PROJECT_ID"), defaultProjectID)
if projectID == "" {
projectID = defaultProjectID
}
baseLang := t.BaseLangcode baseLang := t.BaseLangcode
uLangStr := os.Getenv("UPLOAD_LANGUAGE") uLangStr := os.Getenv("UPLOAD_LANGUAGE")