From 667263a3a865834b0965b3d4098bdeaedf3937c1 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Wed, 15 May 2024 13:34:12 +0300 Subject: [PATCH] all: sync with master --- .github/workflows/build.yml | 2 +- .github/workflows/lint.yml | 2 +- CHANGELOG.md | 58 ++- Makefile | 2 +- bamboo-specs/release.yaml | 8 +- bamboo-specs/snapcraft.yaml | 2 +- bamboo-specs/test.yaml | 4 +- client/src/__locales/en.json | 4 +- go.mod | 23 +- go.sum | 42 +- internal/aghalg/aghalg.go | 21 - internal/aghalg/ringbuffer_test.go | 8 +- internal/aghalg/sortedmap_test.go | 2 +- internal/aghos/filewalker.go | 2 + internal/aghos/os.go | 10 - internal/aghos/os_unix.go | 16 - internal/aghos/os_windows.go | 16 - internal/aghrenameio/renameio_test.go | 1 - internal/client/addrproc_test.go | 4 - internal/client/client.go | 26 +- internal/client/index.go | 131 +++++- internal/client/index_internal_test.go | 171 +++++++- internal/client/persistent.go | 7 +- internal/client/runtimeindex.go | 63 +++ internal/client/runtimeindex_test.go | 85 ++++ internal/configmigrate/v15.go | 4 +- internal/configmigrate/v24.go | 4 +- internal/configmigrate/v26.go | 4 +- internal/configmigrate/v7.go | 4 +- internal/configmigrate/yaml.go | 16 - internal/dnsforward/access.go | 5 +- internal/dnsforward/beforerequest.go | 116 +++++ .../dnsforward/beforerequest_internal_test.go | 299 +++++++++++++ internal/dnsforward/clientid.go | 40 -- internal/dnsforward/config.go | 80 +++- internal/dnsforward/dns64_test.go | 74 +++- internal/dnsforward/dnsforward.go | 277 ++++++------ internal/dnsforward/dnsforward_test.go | 50 +-- internal/dnsforward/dnsrewrite.go | 2 +- internal/dnsforward/filter.go | 69 +-- internal/dnsforward/http.go | 145 ++++--- internal/dnsforward/http_test.go | 57 +-- internal/dnsforward/msg.go | 130 +++--- internal/dnsforward/process.go | 306 ++------------ internal/dnsforward/process_internal_test.go | 396 ++++++----------- internal/dnsforward/stats.go | 10 +- internal/dnsforward/upstreams.go | 215 ++++------ internal/filtering/filtering.go | 2 + internal/filtering/filtering_test.go | 4 +- .../filtering/idgenerator_internal_test.go | 2 - internal/filtering/rulelist/engine_test.go | 5 +- internal/filtering/rulelist/filter_test.go | 5 +- internal/filtering/rulelist/parser_test.go | 3 +- internal/filtering/safesearch.go | 13 +- internal/filtering/safesearch/safesearch.go | 90 +--- .../safesearch/safesearch_internal_test.go | 46 +- .../filtering/safesearch/safesearch_test.go | 113 ++--- internal/home/clients.go | 213 ++++------ internal/home/clients_internal_test.go | 55 +-- internal/home/clientshttp.go | 46 +- internal/home/clientshttp_internal_test.go | 399 ++++++++++++++++++ internal/home/config.go | 23 +- internal/home/dns.go | 46 +- internal/home/home.go | 1 - internal/home/log.go | 5 +- internal/home/service.go | 13 +- internal/home/service_openbsd.go | 4 +- internal/ipset/ipset_linux_internal_test.go | 4 +- internal/next/cmd/signal.go | 6 +- internal/next/dnssvc/dnssvc_test.go | 9 +- internal/next/websvc/websvc_test.go | 8 +- internal/querylog/decode_test.go | 4 +- internal/querylog/entry.go | 1 + internal/querylog/qlog_test.go | 6 +- internal/querylog/qlogfile_test.go | 2 +- internal/stats/stats_internal_test.go | 6 +- internal/stats/stats_test.go | 2 +- internal/stats/unit.go | 3 +- internal/tools/go.mod | 2 +- internal/whois/whois.go | 6 +- scripts/translations/download.go | 2 +- scripts/translations/main.go | 11 +- 82 files changed, 2356 insertions(+), 1817 deletions(-) create mode 100644 internal/client/runtimeindex.go create mode 100644 internal/client/runtimeindex_test.go create mode 100644 internal/dnsforward/beforerequest.go create mode 100644 internal/dnsforward/beforerequest_internal_test.go create mode 100644 internal/home/clientshttp_internal_test.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f30180e9..fcc514df 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,7 +1,7 @@ 'name': 'build' 'env': - 'GO_VERSION': '1.22.2' + 'GO_VERSION': '1.22.3' 'NODE_VERSION': '16' 'on': diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 2e606520..2fe919d6 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,7 +1,7 @@ 'name': 'lint' 'env': - 'GO_VERSION': '1.22.2' + 'GO_VERSION': '1.22.3' 'on': 'push': diff --git a/CHANGELOG.md b/CHANGELOG.md index 4adc38e8..6205bf9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ and this project adheres to +### Security + +- Go version has been updated to prevent the possibility of exploiting the Go + vulnerabilities fixed in [Go 1.22.3][go-1.22.3]. + +### Added + +- Support for comments in the ipset file ([#5345]). + +### Changed + +- Private rDNS resolution now also affects `SOA` and `NS` requests ([#6882]). +- Rewrite rules mechanics was changed due to improve resolving in safe search. + +### Deprecated + +- Currently, AdGuard Home skips persistent clients that have duplicate fields + when reading them from the configuration file. This behaviour is deprecated + and will cause errors on startup in a future release. + +### Fixed + +- Acceptance of duplicate UIDs for persistent clients at startup. See also the + section on client settings on the [Wiki page][wiki-config]. +- Domain specifications for top-level domains not considered for requests to + unqualified domains ([#6744]). +- Support for link-local subnets, i.e. `fe80::/16`, as client identifiers + ([#6312]). +- Issues with QUIC and HTTP/3 upstreams on older Linux kernel versions + ([#6422]). +- YouTube restricted mode is not enforced by HTTPS queries on Firefox. +- Support for link-local subnets, i.e. `fe80::/16`, in the access settings + ([#6192]). +- The ability to apply an invalid configuration for private rDNS, which led to + server not starting. +- Ignoring query log for clients with ClientID set ([#5812]). +- Subdomains of `in-addr.arpa` and `ip6.arpa` containing zero-length prefix + incorrectly considered invalid when specified for private rDNS upstream + servers ([#6854]). +- Unspecified IP addresses aren't checked when using "Fastest IP address" mode + ([#6875]). + +[#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345 +[#5812]: https://github.com/AdguardTeam/AdGuardHome/issues/5812 +[#6192]: https://github.com/AdguardTeam/AdGuardHome/issues/6192 +[#6312]: https://github.com/AdguardTeam/AdGuardHome/issues/6312 +[#6422]: https://github.com/AdguardTeam/AdGuardHome/issues/6422 +[#6744]: https://github.com/AdguardTeam/AdGuardHome/issues/6744 +[#6854]: https://github.com/AdguardTeam/AdGuardHome/issues/6854 +[#6875]: https://github.com/AdguardTeam/AdGuardHome/issues/6875 +[#6882]: https://github.com/AdguardTeam/AdGuardHome/issues/6882 + +[go-1.22.3]: https://groups.google.com/g/golang-announce/c/wkkO4P9stm0 + @@ -35,7 +89,7 @@ See also the [v0.107.48 GitHub milestone][ms-v0.107.48]. ### Fixed -- Access settings not being applied to encrypted protocols ([#6890]) +- Access settings not being applied to encrypted protocols ([#6890]). [#6890]: https://github.com/AdguardTeam/AdGuardHome/issues/6890 diff --git a/Makefile b/Makefile index 55939ff9..66c387ea 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ DIST_DIR = dist GOAMD64 = v1 GOPROXY = https://goproxy.cn|https://proxy.golang.org|direct GOSUMDB = sum.golang.google.cn -GOTOOLCHAIN = go1.22.2 +GOTOOLCHAIN = go1.22.3 GPG_KEY = devteam@adguard.com GPG_KEY_PASSPHRASE = not-a-real-password NPM = npm diff --git a/bamboo-specs/release.yaml b/bamboo-specs/release.yaml index 9a991bc9..f6313c0a 100644 --- a/bamboo-specs/release.yaml +++ b/bamboo-specs/release.yaml @@ -8,7 +8,7 @@ 'variables': 'channel': 'edge' 'dockerFrontend': 'adguard/home-js-builder:1.1' - 'dockerGo': 'adguard/go-builder:1.22.2--1' + 'dockerGo': 'adguard/go-builder:1.22.3--1' 'stages': - 'Build frontend': @@ -249,7 +249,7 @@ 'recipients': - 'webhook': 'name': 'Build webhook' - 'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo?channel=adguard-qa' + 'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo?channel=adguard-qa-dns-builds' 'labels': [] 'other': @@ -266,7 +266,7 @@ 'variables': 'channel': 'beta' 'dockerFrontend': 'adguard/home-js-builder:1.1' - 'dockerGo': 'adguard/go-builder:1.22.2--1' + 'dockerGo': 'adguard/go-builder:1.22.3--1' # release-vX.Y.Z branches are the branches from which the actual final # release is built. - '^release-v[0-9]+\.[0-9]+\.[0-9]+': @@ -282,4 +282,4 @@ 'variables': 'channel': 'release' 'dockerFrontend': 'adguard/home-js-builder:1.1' - 'dockerGo': 'adguard/go-builder:1.22.2--1' + 'dockerGo': 'adguard/go-builder:1.22.3--1' diff --git a/bamboo-specs/snapcraft.yaml b/bamboo-specs/snapcraft.yaml index 14e9d3df..a36d99d3 100644 --- a/bamboo-specs/snapcraft.yaml +++ b/bamboo-specs/snapcraft.yaml @@ -175,7 +175,7 @@ 'recipients': - 'webhook': 'name': 'Build webhook' - 'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo?channel=adguard-qa' + 'url': 'http://prod.jirahub.service.eu.consul/v1/webhook/bamboo?channel=adguard-qa-dns-builds' 'labels': [] 'other': diff --git a/bamboo-specs/test.yaml b/bamboo-specs/test.yaml index b18ab1c0..b58fdcd6 100644 --- a/bamboo-specs/test.yaml +++ b/bamboo-specs/test.yaml @@ -6,7 +6,7 @@ 'name': 'AdGuard Home - Build and run tests' 'variables': 'dockerFrontend': 'adguard/home-js-builder:1.1' - 'dockerGo': 'adguard/go-builder:1.22.2--1' + 'dockerGo': 'adguard/go-builder:1.22.3--1' 'channel': 'development' 'stages': @@ -195,5 +195,5 @@ # may need to build a few of these. 'variables': 'dockerFrontend': 'adguard/home-js-builder:1.1' - 'dockerGo': 'adguard/go-builder:1.22.2--1' + 'dockerGo': 'adguard/go-builder:1.22.3--1' 'channel': 'candidate' diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index dc4e9271..1ffe7bf1 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -13,14 +13,14 @@ "fallback_dns_desc": "List of fallback DNS servers used when upstream DNS servers are not responding. The syntax is the same as in the main upstreams field above.", "fallback_dns_placeholder": "Enter one fallback DNS server per line", "local_ptr_title": "Private reverse DNS servers", - "local_ptr_desc": "The DNS servers that AdGuard Home uses for local PTR queries. These servers are used to resolve PTR requests for addresses in private IP ranges, for example \"192.168.12.34\", using reverse DNS. If not set, AdGuard Home uses the addresses of the default DNS resolvers of your OS except for the addresses of AdGuard Home itself.", + "local_ptr_desc": "DNS servers used by AdGuard Home for private PTR, SOA, and NS requests. A request is considered private if it asks for an ARPA domain containing a subnet within private IP ranges (such as \"192.168.12.34\") and comes from a client with a private IP address. If not set, the default DNS resolvers of your OS will be used, except for the AdGuard Home IP addresses.", "local_ptr_default_resolver": "By default, AdGuard Home uses the following reverse DNS resolvers: {{ip}}.", "local_ptr_no_default_resolver": "AdGuard Home could not determine suitable private reverse DNS resolvers for this system.", "local_ptr_placeholder": "Enter one IP address per line", "resolve_clients_title": "Enable reverse resolving of clients' IP addresses", "resolve_clients_desc": "Reversely resolve clients' IP addresses into their hostnames by sending PTR queries to corresponding resolvers (private DNS servers for local clients, upstream servers for clients with public IP addresses).", "use_private_ptr_resolvers_title": "Use private reverse DNS resolvers", - "use_private_ptr_resolvers_desc": "Perform reverse DNS lookups for locally served addresses using these upstream servers. If disabled, AdGuard Home responds with NXDOMAIN to all such PTR requests except for clients known from DHCP, /etc/hosts, and so on.", + "use_private_ptr_resolvers_desc": "Resolve PTR, SOA, and NS requests for ARPA domains containing private IP addresses through private upstream servers, DHCP, /etc/hosts, etc. If disabled, AdGuard Home will respond to all such requests with NXDOMAIN.", "check_dhcp_servers": "Check for DHCP servers", "save_config": "Save configuration", "enabled_dhcp": "DHCP server enabled", diff --git a/go.mod b/go.mod index 0d809eb8..ac607fd9 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,10 @@ module github.com/AdguardTeam/AdGuardHome -go 1.22.2 +go 1.22.3 require ( - github.com/AdguardTeam/dnsproxy v0.67.1-0.20240405111306-032a0534ccd2 - github.com/AdguardTeam/golibs v0.21.0 + github.com/AdguardTeam/dnsproxy v0.71.1 + github.com/AdguardTeam/golibs v0.23.2 github.com/AdguardTeam/urlfilter v0.18.0 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.2.7 @@ -28,14 +28,15 @@ require ( // own code for that. Perhaps, use gopacket. github.com/mdlayher/raw v0.1.0 github.com/miekg/dns v1.1.58 - github.com/quic-go/quic-go v0.41.0 + // TODO(a.garipov): Use release version. + github.com/quic-go/quic-go v0.42.1-0.20240424141022-12aa63824c7f github.com/stretchr/testify v1.9.0 github.com/ti-mo/netfilter v0.5.1 go.etcd.io/bbolt v1.3.9 - golang.org/x/crypto v0.21.0 - golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 - golang.org/x/net v0.23.0 - golang.org/x/sys v0.18.0 + golang.org/x/crypto v0.22.0 + golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 + golang.org/x/net v0.24.0 + golang.org/x/sys v0.19.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 howett.net/plist v1.0.1 @@ -58,9 +59,9 @@ require ( github.com/quic-go/qpack v0.4.0 // indirect github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect go.uber.org/mock v0.4.0 // indirect - golang.org/x/mod v0.16.0 // indirect - golang.org/x/sync v0.6.0 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/sync v0.7.0 // indirect golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.19.0 // indirect + golang.org/x/tools v0.20.0 // indirect gonum.org/v1/gonum v0.14.0 // indirect ) diff --git a/go.sum b/go.sum index 5a079513..a0785ca0 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ -github.com/AdguardTeam/dnsproxy v0.67.1-0.20240405111306-032a0534ccd2 h1:XDhWNn1OfmbtLgj3bR52WWIa0/cf0ijanOvuaT75f1I= -github.com/AdguardTeam/dnsproxy v0.67.1-0.20240405111306-032a0534ccd2/go.mod h1:7hAE3du5XPrBkdsqAPJIEGWklsE0ahHZONRlLASPeNI= -github.com/AdguardTeam/golibs v0.21.0 h1:0swWyNaHTmT7aMwffKd9d54g4wBd8Oaj0fl+5l/PRdE= -github.com/AdguardTeam/golibs v0.21.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI= +github.com/AdguardTeam/dnsproxy v0.71.1 h1:R8jKmoE9HwqdTt7bm8irpvrQEOSmD+iGdNXbOg/uM8Y= +github.com/AdguardTeam/dnsproxy v0.71.1/go.mod h1:rCaCL4m4n63sgwTOyUVdc7MC42PlUYBt11Fz/UjD+kM= +github.com/AdguardTeam/golibs v0.23.2 h1:rMjYantwtQ39e8G4zBQ6ZLlm4s3XH30Bc9VxhoOHwao= +github.com/AdguardTeam/golibs v0.23.2/go.mod h1:o9i55Sx6v7qogRQeqaBfmLbC/pZqeMBWi015U5PTDY0= github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ= github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s= github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I= @@ -101,8 +101,8 @@ github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k= -github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA= +github.com/quic-go/quic-go v0.42.1-0.20240424141022-12aa63824c7f h1:L7x60Z6AW2giF/SvbDpMglGHJxtmFJV03khPwXLDScU= +github.com/quic-go/quic-go v0.42.1-0.20240424141022-12aa63824c7f/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M= github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4= github.com/shirou/gopsutil/v3 v3.23.7/go.mod h1:c4gnmoRC0hQuaLqvxnx1//VXQ0Ms/X9UnJF8pddY5z4= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= @@ -131,26 +131,26 @@ go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw= -golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 h1:ESSUROHIBHg7USnszlcdmjBEwdMj9VUvU+OPk4yl2mc= +golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= -golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= -golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= -golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190322080309-f49334f85ddc/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -161,17 +161,19 @@ golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= -golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +golang.org/x/tools v0.20.0 h1:hz/CVckiOxybQvFw6h7b/q80NTr9IUQb4s1IIzW7KNY= +golang.org/x/tools v0.20.0/go.mod h1:WvitBU7JJf6A4jOdg4S1tviW9bhUxkgeCui/0JHctQg= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0= diff --git a/internal/aghalg/aghalg.go b/internal/aghalg/aghalg.go index 93b259ef..b2d84e54 100644 --- a/internal/aghalg/aghalg.go +++ b/internal/aghalg/aghalg.go @@ -10,29 +10,8 @@ import ( "golang.org/x/exp/constraints" ) -// Coalesce returns the first non-zero value. It is named after function -// COALESCE in SQL. If values or all its elements are empty, it returns a zero -// value. -// -// T is comparable, because Go currently doesn't have a comparableWithZeroValue -// constraint. -// -// TODO(a.garipov): Think of ways to merge with [CoalesceSlice]. -func Coalesce[T comparable](values ...T) (res T) { - var zero T - for _, v := range values { - if v != zero { - return v - } - } - - return zero -} - // CoalesceSlice returns the first non-zero value. It is named after function // COALESCE in SQL. If values or all its elements are empty, it returns nil. -// -// TODO(a.garipov): Think of ways to merge with [Coalesce]. func CoalesceSlice[E any, S []E](values ...S) (res S) { for _, v := range values { if v != nil { diff --git a/internal/aghalg/ringbuffer_test.go b/internal/aghalg/ringbuffer_test.go index b86295c3..31ae4d7b 100644 --- a/internal/aghalg/ringbuffer_test.go +++ b/internal/aghalg/ringbuffer_test.go @@ -33,7 +33,7 @@ func elements(b *aghalg.RingBuffer[int], n uint, reverse bool) (es []int) { func TestNewRingBuffer(t *testing.T) { t.Run("success_and_clear", func(t *testing.T) { b := aghalg.NewRingBuffer[int](5) - for i := 0; i < 10; i++ { + for i := range 10 { b.Append(i) } assert.Equal(t, []int{5, 6, 7, 8, 9}, elements(b, b.Len(), false)) @@ -44,7 +44,7 @@ func TestNewRingBuffer(t *testing.T) { t.Run("zero", func(t *testing.T) { b := aghalg.NewRingBuffer[int](0) - for i := 0; i < 10; i++ { + for i := range 10 { b.Append(i) bufLen := b.Len() assert.EqualValues(t, 0, bufLen) @@ -55,7 +55,7 @@ func TestNewRingBuffer(t *testing.T) { t.Run("single", func(t *testing.T) { b := aghalg.NewRingBuffer[int](1) - for i := 0; i < 10; i++ { + for i := range 10 { b.Append(i) bufLen := b.Len() assert.EqualValues(t, 1, bufLen) @@ -94,7 +94,7 @@ func TestRingBuffer_Range(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - for i := 0; i < tc.count; i++ { + for i := range tc.count { b.Append(i) } diff --git a/internal/aghalg/sortedmap_test.go b/internal/aghalg/sortedmap_test.go index 46128ed0..6e563802 100644 --- a/internal/aghalg/sortedmap_test.go +++ b/internal/aghalg/sortedmap_test.go @@ -11,7 +11,7 @@ func TestNewSortedMap(t *testing.T) { var m SortedMap[string, int] letters := []string{} - for i := 0; i < 10; i++ { + for i := range 10 { r := string('a' + rune(i)) letters = append(letters, r) } diff --git a/internal/aghos/filewalker.go b/internal/aghos/filewalker.go index 30c2d718..23296539 100644 --- a/internal/aghos/filewalker.go +++ b/internal/aghos/filewalker.go @@ -97,6 +97,8 @@ func (fw FileWalker) Walk(fsys fs.FS, initial ...string) (ok bool, err error) { var filename string defer func() { err = errors.Annotate(err, "checking %q: %w", filename) }() + // TODO(e.burkov): Redo this loop, as it modifies the very same slice it + // iterates over. for i := 0; i < len(src); i++ { var patterns []string var cont bool diff --git a/internal/aghos/os.go b/internal/aghos/os.go index c357d11d..e04055e4 100644 --- a/internal/aghos/os.go +++ b/internal/aghos/os.go @@ -159,21 +159,11 @@ func NotifyReconfigureSignal(c chan<- os.Signal) { notifyReconfigureSignal(c) } -// NotifyShutdownSignal notifies c on receiving shutdown signals. -func NotifyShutdownSignal(c chan<- os.Signal) { - notifyShutdownSignal(c) -} - // IsReconfigureSignal returns true if sig is a reconfigure signal. func IsReconfigureSignal(sig os.Signal) (ok bool) { return isReconfigureSignal(sig) } -// IsShutdownSignal returns true if sig is a shutdown signal. -func IsShutdownSignal(sig os.Signal) (ok bool) { - return isShutdownSignal(sig) -} - // SendShutdownSignal sends the shutdown signal to the channel. func SendShutdownSignal(c chan<- os.Signal) { sendShutdownSignal(c) diff --git a/internal/aghos/os_unix.go b/internal/aghos/os_unix.go index f52fab02..f2cc4fef 100644 --- a/internal/aghos/os_unix.go +++ b/internal/aghos/os_unix.go @@ -13,26 +13,10 @@ func notifyReconfigureSignal(c chan<- os.Signal) { signal.Notify(c, unix.SIGHUP) } -func notifyShutdownSignal(c chan<- os.Signal) { - signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM) -} - func isReconfigureSignal(sig os.Signal) (ok bool) { return sig == unix.SIGHUP } -func isShutdownSignal(sig os.Signal) (ok bool) { - switch sig { - case - unix.SIGINT, - unix.SIGQUIT, - unix.SIGTERM: - return true - default: - return false - } -} - func sendShutdownSignal(_ chan<- os.Signal) { // On Unix we are already notified by the system. } diff --git a/internal/aghos/os_windows.go b/internal/aghos/os_windows.go index 2c2620eb..b9bf8a4c 100644 --- a/internal/aghos/os_windows.go +++ b/internal/aghos/os_windows.go @@ -5,7 +5,6 @@ package aghos import ( "os" "os/signal" - "syscall" "golang.org/x/sys/windows" ) @@ -43,25 +42,10 @@ func notifyReconfigureSignal(c chan<- os.Signal) { signal.Notify(c, windows.SIGHUP) } -func notifyShutdownSignal(c chan<- os.Signal) { - // syscall.SIGTERM is processed automatically. See go doc os/signal, - // section Windows. - signal.Notify(c, os.Interrupt) -} - func isReconfigureSignal(sig os.Signal) (ok bool) { return sig == windows.SIGHUP } -func isShutdownSignal(sig os.Signal) (ok bool) { - switch sig { - case os.Interrupt, syscall.SIGTERM: - return true - default: - return false - } -} - func sendShutdownSignal(c chan<- os.Signal) { c <- os.Interrupt } diff --git a/internal/aghrenameio/renameio_test.go b/internal/aghrenameio/renameio_test.go index 2aa75b34..2fdc2cdb 100644 --- a/internal/aghrenameio/renameio_test.go +++ b/internal/aghrenameio/renameio_test.go @@ -78,7 +78,6 @@ func TestWithDeferredCleanup(t *testing.T) { }} for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() diff --git a/internal/client/addrproc_test.go b/internal/client/addrproc_test.go index c6b38657..f0d0a8f7 100644 --- a/internal/client/addrproc_test.go +++ b/internal/client/addrproc_test.go @@ -91,8 +91,6 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) { }} for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { t.Parallel() @@ -186,8 +184,6 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) { }} for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { t.Parallel() diff --git a/internal/client/client.go b/internal/client/client.go index d0a75045..d3ead923 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -7,6 +7,7 @@ package client import ( "encoding" "fmt" + "net/netip" "github.com/AdguardTeam/AdGuardHome/internal/whois" ) @@ -56,6 +57,9 @@ func (cs Source) MarshalText() (text []byte, err error) { // Runtime is a client information from different sources. type Runtime struct { + // ip is an IP address of a client. + ip netip.Addr + // whois is the filtered WHOIS information of a client. whois *whois.Info @@ -80,6 +84,15 @@ type Runtime struct { hostsFile []string } +// NewRuntime constructs a new runtime client. ip must be valid IP address. +// +// TODO(s.chzhen): Validate IP address. +func NewRuntime(ip netip.Addr) (r *Runtime) { + return &Runtime{ + ip: ip, + } +} + // Info returns a client information from the highest-priority source. func (r *Runtime) Info() (cs Source, host string) { info := []string{} @@ -133,8 +146,8 @@ func (r *Runtime) SetWHOIS(info *whois.Info) { r.whois = info } -// Unset clears a cs information. -func (r *Runtime) Unset(cs Source) { +// unset clears a cs information. +func (r *Runtime) unset(cs Source) { switch cs { case SourceWHOIS: r.whois = nil @@ -149,11 +162,16 @@ func (r *Runtime) Unset(cs Source) { } } -// IsEmpty returns true if there is no information from any source. -func (r *Runtime) IsEmpty() (ok bool) { +// isEmpty returns true if there is no information from any source. +func (r *Runtime) isEmpty() (ok bool) { return r.whois == nil && r.arp == nil && r.rdns == nil && r.dhcp == nil && r.hostsFile == nil } + +// Addr returns an IP address of the client. +func (r *Runtime) Addr() (ip netip.Addr) { + return r.ip +} diff --git a/internal/client/index.go b/internal/client/index.go index c6a17cb3..63ae690e 100644 --- a/internal/client/index.go +++ b/internal/client/index.go @@ -4,8 +4,12 @@ import ( "fmt" "net" "net/netip" + "slices" + "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" + "github.com/AdguardTeam/golibs/errors" + "golang.org/x/exp/maps" ) // macKey contains MAC as byte array of 6, 8, or 20 bytes. @@ -28,6 +32,9 @@ func macToKey(mac net.HardwareAddr) (key macKey) { // Index stores all information about persistent clients. type Index struct { + // nameToUID maps client name to UID. + nameToUID map[string]UID + // clientIDToUID maps client ID to UID. clientIDToUID map[string]UID @@ -47,6 +54,7 @@ type Index struct { // NewIndex initializes the new instance of client index. func NewIndex() (ci *Index) { return &Index{ + nameToUID: map[string]UID{}, clientIDToUID: map[string]UID{}, ipToUID: map[netip.Addr]UID{}, subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare), @@ -62,6 +70,8 @@ func (ci *Index) Add(c *Persistent) { panic("client must contain uid") } + ci.nameToUID[c.Name] = c.UID + for _, id := range c.ClientIDs { ci.clientIDToUID[id] = c.UID } @@ -82,15 +92,30 @@ func (ci *Index) Add(c *Persistent) { ci.uidToClient[c.UID] = c } +// ClashesUID returns existing persistent client with the same UID as c. Note +// that this is only possible when configuration contains duplicate fields. +func (ci *Index) ClashesUID(c *Persistent) (err error) { + p, ok := ci.uidToClient[c.UID] + if ok { + return fmt.Errorf("another client %q uses the same uid", p.Name) + } + + return nil +} + // Clashes returns an error if the index contains a different persistent client // with at least a single identifier contained by c. c must be non-nil. func (ci *Index) Clashes(c *Persistent) (err error) { + if p := ci.clashesName(c); p != nil { + return fmt.Errorf("another client uses the same name %q", p.Name) + } + for _, id := range c.ClientIDs { existing, ok := ci.clientIDToUID[id] if ok && existing != c.UID { p := ci.uidToClient[existing] - return fmt.Errorf("another client %q uses the same ID %q", p.Name, id) + return fmt.Errorf("another client %q uses the same ClientID %q", p.Name, id) } } @@ -112,6 +137,21 @@ func (ci *Index) Clashes(c *Persistent) (err error) { return nil } +// clashesName returns existing persistent client with the same name as c or +// nil. c must be non-nil. +func (ci *Index) clashesName(c *Persistent) (existing *Persistent) { + existing, ok := ci.FindByName(c.Name) + if !ok { + return nil + } + + if existing.UID != c.UID { + return existing + } + + return nil +} + // clashesIP returns a previous client with the same IP address as c. c must be // non-nil. func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) { @@ -184,21 +224,33 @@ func (ci *Index) Find(id string) (c *Persistent, ok bool) { mac, err := net.ParseMAC(id) if err == nil { - return ci.findByMAC(mac) + return ci.FindByMAC(mac) } return nil, false } -// find finds persistent client by IP address. +// FindByName finds persistent client by name. +func (ci *Index) FindByName(name string) (c *Persistent, found bool) { + uid, found := ci.nameToUID[name] + if found { + return ci.uidToClient[uid], true + } + + return nil, false +} + +// findByIP finds persistent client by IP address. func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) { uid, found := ci.ipToUID[ip] if found { return ci.uidToClient[uid], true } + ipWithoutZone := ip.WithZone("") ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) { - if pref.Contains(ip) { + // Remove zone before checking because prefixes strip zones. + if pref.Contains(ipWithoutZone) { uid, found = id, true return false @@ -214,8 +266,8 @@ func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) { return nil, false } -// find finds persistent client by MAC. -func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { +// FindByMAC finds persistent client by MAC. +func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { k := macToKey(mac) uid, found := ci.macToUID[k] if found { @@ -225,9 +277,31 @@ func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { return nil, false } +// FindByIPWithoutZone finds a persistent client by IP address without zone. It +// strips the IPv6 zone index from the stored IP addresses before comparing, +// because querylog entries don't have it. See TODO on [querylog.logEntry.IP]. +// +// Note that multiple clients can have the same IP address with different zones. +// Therefore, the result of this method is indeterminate. +func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) { + if (ip == netip.Addr{}) { + return nil + } + + for addr, uid := range ci.ipToUID { + if addr.WithZone("") == ip { + return ci.uidToClient[uid] + } + } + + return nil +} + // Delete removes information about persistent client from the index. c must be // non-nil. func (ci *Index) Delete(c *Persistent) { + delete(ci.nameToUID, c.Name) + for _, id := range c.ClientIDs { delete(ci.clientIDToUID, id) } @@ -247,3 +321,48 @@ func (ci *Index) Delete(c *Persistent) { delete(ci.uidToClient, c.UID) } + +// Size returns the number of persistent clients. +func (ci *Index) Size() (n int) { + return len(ci.uidToClient) +} + +// Range calls f for each persistent client, unless cont is false. The order is +// undefined. +func (ci *Index) Range(f func(c *Persistent) (cont bool)) { + for _, c := range ci.uidToClient { + if !f(c) { + return + } + } +} + +// RangeByName is like [Index.Range] but sorts the persistent clients by name +// before iterating ensuring a predictable order. +func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) { + cs := maps.Values(ci.uidToClient) + slices.SortFunc(cs, func(a, b *Persistent) (n int) { + return strings.Compare(a.Name, b.Name) + }) + + for _, c := range cs { + if !f(c) { + break + } + } +} + +// CloseUpstreams closes upstream configurations of persistent clients. +func (ci *Index) CloseUpstreams() (err error) { + var errs []error + ci.RangeByName(func(c *Persistent) (cont bool) { + err = c.CloseUpstreams() + if err != nil { + errs = append(errs, err) + } + + return true + }) + + return errors.Join(errs...) +} diff --git a/internal/client/index_internal_test.go b/internal/client/index_internal_test.go index abf38710..38c0df15 100644 --- a/internal/client/index_internal_test.go +++ b/internal/client/index_internal_test.go @@ -22,7 +22,7 @@ func newIDIndex(m []*Persistent) (ci *Index) { return ci } -func TestClientIndex(t *testing.T) { +func TestClientIndex_Find(t *testing.T) { const ( cliIPNone = "1.2.3.4" cliIP1 = "1.1.1.1" @@ -35,26 +35,49 @@ func TestClientIndex(t *testing.T) { cliID = "client-id" cliMAC = "11:11:11:11:11:11" + + linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0" + linkLocalSubnet = "fe80::/16" ) - clients := []*Persistent{{ - Name: "client1", - IPs: []netip.Addr{ - netip.MustParseAddr(cliIP1), - netip.MustParseAddr(cliIPv6), - }, - }, { - Name: "client2", - IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, - Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, - }, { - Name: "client_with_mac", - MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, - }, { - Name: "client_with_id", - ClientIDs: []string{cliID}, - }} + var ( + clientWithBothFams = &Persistent{ + Name: "client1", + IPs: []netip.Addr{ + netip.MustParseAddr(cliIP1), + netip.MustParseAddr(cliIPv6), + }, + } + clientWithSubnet = &Persistent{ + Name: "client2", + IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, + Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, + } + + clientWithMAC = &Persistent{ + Name: "client_with_mac", + MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, + } + + clientWithID = &Persistent{ + Name: "client_with_id", + ClientIDs: []string{cliID}, + } + + clientLinkLocal = &Persistent{ + Name: "client_link_local", + Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)}, + } + ) + + clients := []*Persistent{ + clientWithBothFams, + clientWithSubnet, + clientWithMAC, + clientWithID, + clientLinkLocal, + } ci := newIDIndex(clients) testCases := []struct { @@ -64,19 +87,23 @@ func TestClientIndex(t *testing.T) { }{{ name: "ipv4_ipv6", ids: []string{cliIP1, cliIPv6}, - want: clients[0], + want: clientWithBothFams, }, { name: "ipv4_subnet", ids: []string{cliIP2, cliSubnetIP}, - want: clients[1], + want: clientWithSubnet, }, { name: "mac", ids: []string{cliMAC}, - want: clients[2], + want: clientWithMAC, }, { name: "client_id", ids: []string{cliID}, - want: clients[3], + want: clientWithID, + }, { + name: "client_link_local_subnet", + ids: []string{linkLocalIP}, + want: clientLinkLocal, }} for _, tc := range testCases { @@ -221,3 +248,103 @@ func TestMACToKey(t *testing.T) { _ = macToKey(mac) }) } + +func TestIndex_FindByIPWithoutZone(t *testing.T) { + var ( + ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1") + ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2") + ) + + var ( + clientNoZone = &Persistent{ + Name: "client", + IPs: []netip.Addr{ip}, + } + + clientWithZone = &Persistent{ + Name: "client_with_zone", + IPs: []netip.Addr{ipWithZone}, + } + ) + + ci := newIDIndex([]*Persistent{ + clientNoZone, + clientWithZone, + }) + + testCases := []struct { + ip netip.Addr + want *Persistent + name string + }{{ + name: "without_zone", + ip: ip, + want: clientNoZone, + }, { + name: "with_zone", + ip: ipWithZone, + want: clientWithZone, + }, { + name: "zero_address", + ip: netip.Addr{}, + want: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := ci.FindByIPWithoutZone(tc.ip.WithZone("")) + require.Equal(t, tc.want, c) + }) + } +} + +func TestClientIndex_RangeByName(t *testing.T) { + sortedClients := []*Persistent{{ + Name: "clientA", + ClientIDs: []string{"A"}, + }, { + Name: "clientB", + ClientIDs: []string{"B"}, + }, { + Name: "clientC", + ClientIDs: []string{"C"}, + }, { + Name: "clientD", + ClientIDs: []string{"D"}, + }, { + Name: "clientE", + ClientIDs: []string{"E"}, + }} + + testCases := []struct { + name string + want []*Persistent + }{{ + name: "basic", + want: sortedClients, + }, { + name: "nil", + want: nil, + }, { + name: "one_element", + want: sortedClients[:1], + }, { + name: "two_elements", + want: sortedClients[:2], + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ci := newIDIndex(tc.want) + + var got []*Persistent + ci.RangeByName(func(c *Persistent) (cont bool) { + got = append(got, c) + + return true + }) + + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/internal/client/persistent.go b/internal/client/persistent.go index 06e346f4..317dc72b 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -64,9 +64,7 @@ type Persistent struct { // upstream must be used. UpstreamConfig *proxy.CustomUpstreamConfig - // TODO(d.kolyshev): Make SafeSearchConf a pointer. - SafeSearchConf filtering.SafeSearchConfig - SafeSearch filtering.SafeSearch + SafeSearch filtering.SafeSearch // BlockedServices is the configuration of blocked services of a client. BlockedServices *filtering.BlockedServices @@ -95,6 +93,9 @@ type Persistent struct { UseOwnBlockedServices bool IgnoreQueryLog bool IgnoreStatistics bool + + // TODO(d.kolyshev): Make SafeSearchConf a pointer. + SafeSearchConf filtering.SafeSearchConfig } // SetTags sets the tags if they are known, otherwise logs an unknown tag. diff --git a/internal/client/runtimeindex.go b/internal/client/runtimeindex.go new file mode 100644 index 00000000..300fdca0 --- /dev/null +++ b/internal/client/runtimeindex.go @@ -0,0 +1,63 @@ +package client + +import "net/netip" + +// RuntimeIndex stores information about runtime clients. +type RuntimeIndex struct { + // index maps IP address to runtime client. + index map[netip.Addr]*Runtime +} + +// NewRuntimeIndex returns initialized runtime index. +func NewRuntimeIndex() (ri *RuntimeIndex) { + return &RuntimeIndex{ + index: map[netip.Addr]*Runtime{}, + } +} + +// Client returns the saved runtime client by ip. If no such client exists, +// returns nil. +func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) { + return ri.index[ip] +} + +// Add saves the runtime client in the index. IP address of a client must be +// unique. See [Runtime.Client]. rc must not be nil. +func (ri *RuntimeIndex) Add(rc *Runtime) { + ip := rc.Addr() + ri.index[ip] = rc +} + +// Size returns the number of the runtime clients. +func (ri *RuntimeIndex) Size() (n int) { + return len(ri.index) +} + +// Range calls f for each runtime client in an undefined order. +func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) { + for _, rc := range ri.index { + if !f(rc) { + return + } + } +} + +// Delete removes the runtime client by ip. +func (ri *RuntimeIndex) Delete(ip netip.Addr) { + delete(ri.index, ip) +} + +// DeleteBySource removes all runtime clients that have information only from +// the specified source and returns the number of removed clients. +func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) { + for ip, rc := range ri.index { + rc.unset(src) + + if rc.isEmpty() { + delete(ri.index, ip) + n++ + } + } + + return n +} diff --git a/internal/client/runtimeindex_test.go b/internal/client/runtimeindex_test.go new file mode 100644 index 00000000..66b975a0 --- /dev/null +++ b/internal/client/runtimeindex_test.go @@ -0,0 +1,85 @@ +package client_test + +import ( + "net/netip" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/stretchr/testify/assert" +) + +func TestRuntimeIndex(t *testing.T) { + const cliSrc = client.SourceARP + + var ( + ip1 = netip.MustParseAddr("1.1.1.1") + ip2 = netip.MustParseAddr("2.2.2.2") + ip3 = netip.MustParseAddr("3.3.3.3") + ) + + ri := client.NewRuntimeIndex() + currentSize := 0 + + testCases := []struct { + ip netip.Addr + name string + hosts []string + src client.Source + }{{ + src: cliSrc, + ip: ip1, + name: "1", + hosts: []string{"host1"}, + }, { + src: cliSrc, + ip: ip2, + name: "2", + hosts: []string{"host2"}, + }, { + src: cliSrc, + ip: ip3, + name: "3", + hosts: []string{"host3"}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rc := client.NewRuntime(tc.ip) + rc.SetInfo(tc.src, tc.hosts) + + ri.Add(rc) + currentSize++ + + got := ri.Client(tc.ip) + assert.Equal(t, rc, got) + }) + } + + t.Run("size", func(t *testing.T) { + assert.Equal(t, currentSize, ri.Size()) + }) + + t.Run("range", func(t *testing.T) { + s := 0 + + ri.Range(func(rc *client.Runtime) (cont bool) { + s++ + + return true + }) + + assert.Equal(t, currentSize, s) + }) + + t.Run("delete", func(t *testing.T) { + ri.Delete(ip1) + currentSize-- + + assert.Equal(t, currentSize, ri.Size()) + }) + + t.Run("delete_by_src", func(t *testing.T) { + assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc)) + assert.Equal(t, 0, ri.Size()) + }) +} diff --git a/internal/configmigrate/v15.go b/internal/configmigrate/v15.go index 85f6d14b..c99adcd3 100644 --- a/internal/configmigrate/v15.go +++ b/internal/configmigrate/v15.go @@ -1,5 +1,7 @@ package configmigrate +import "github.com/AdguardTeam/golibs/errors" + // migrateTo15 performs the following changes: // // # BEFORE: @@ -43,7 +45,7 @@ func migrateTo15(diskConf yobj) (err error) { } diskConf["querylog"] = qlog - return coalesceError( + return errors.Join( moveVal[bool](dns, qlog, "querylog_enabled", "enabled"), moveVal[bool](dns, qlog, "querylog_file_enabled", "file_enabled"), moveVal[any](dns, qlog, "querylog_interval", "interval"), diff --git a/internal/configmigrate/v24.go b/internal/configmigrate/v24.go index f9d781e5..104506dc 100644 --- a/internal/configmigrate/v24.go +++ b/internal/configmigrate/v24.go @@ -1,5 +1,7 @@ package configmigrate +import "github.com/AdguardTeam/golibs/errors" + // migrateTo24 performs the following changes: // // # BEFORE: @@ -28,7 +30,7 @@ func migrateTo24(diskConf yobj) (err error) { diskConf["schema_version"] = 24 logObj := yobj{} - err = coalesceError( + err = errors.Join( moveVal[string](diskConf, logObj, "log_file", "file"), moveVal[int](diskConf, logObj, "log_max_backups", "max_backups"), moveVal[int](diskConf, logObj, "log_max_size", "max_size"), diff --git a/internal/configmigrate/v26.go b/internal/configmigrate/v26.go index a19b9038..4d2c3975 100644 --- a/internal/configmigrate/v26.go +++ b/internal/configmigrate/v26.go @@ -1,5 +1,7 @@ package configmigrate +import "github.com/AdguardTeam/golibs/errors" + // migrateTo26 performs the following changes: // // # BEFORE: @@ -78,7 +80,7 @@ func migrateTo26(diskConf yobj) (err error) { } filteringObj := yobj{} - err = coalesceError( + err = errors.Join( moveSameVal[bool](dns, filteringObj, "filtering_enabled"), moveSameVal[int](dns, filteringObj, "filters_update_interval"), moveSameVal[bool](dns, filteringObj, "parental_enabled"), diff --git a/internal/configmigrate/v7.go b/internal/configmigrate/v7.go index 61ee1e26..b9339ace 100644 --- a/internal/configmigrate/v7.go +++ b/internal/configmigrate/v7.go @@ -1,5 +1,7 @@ package configmigrate +import "github.com/AdguardTeam/golibs/errors" + // migrateTo7 performs the following changes: // // # BEFORE: @@ -37,7 +39,7 @@ func migrateTo7(diskConf yobj) (err error) { } dhcpv4 := yobj{} - err = coalesceError( + err = errors.Join( moveSameVal[string](dhcp, dhcpv4, "gateway_ip"), moveSameVal[string](dhcp, dhcpv4, "subnet_mask"), moveSameVal[string](dhcp, dhcpv4, "range_start"), diff --git a/internal/configmigrate/yaml.go b/internal/configmigrate/yaml.go index c2e2ff08..52dc2704 100644 --- a/internal/configmigrate/yaml.go +++ b/internal/configmigrate/yaml.go @@ -50,19 +50,3 @@ func moveVal[T any](src, dst yobj, srcKey, dstKey string) (err error) { func moveSameVal[T any](src, dst yobj, key string) (err error) { return moveVal[T](src, dst, key, key) } - -// coalesceError returns the first non-nil error. It is named after function -// COALESCE in SQL. If all errors are nil, it returns nil. -// -// TODO(e.burkov): Replace with [errors.Join]. -// -// TODO(a.garipov): Think of ways to merge with [aghalg.Coalesce]. -func coalesceError(errors ...error) (res error) { - for _, err := range errors { - if err != nil { - return err - } - } - - return nil -} diff --git a/internal/dnsforward/access.go b/internal/dnsforward/access.go index c4d6c591..c6c6beab 100644 --- a/internal/dnsforward/access.go +++ b/internal/dnsforward/access.go @@ -156,7 +156,10 @@ func (a *accessManager) isBlockedIP(ip netip.Addr) (blocked bool, rule string) { } for _, ipnet := range ipnets { - if ipnet.Contains(ip) { + // Remove zone before checking because prefixes stip zones. + // + // TODO(d.kolyshev): Cover with tests. + if ipnet.Contains(ip.WithZone("")) { return blocked, ipnet.String() } } diff --git a/internal/dnsforward/beforerequest.go b/internal/dnsforward/beforerequest.go new file mode 100644 index 00000000..8a1b0272 --- /dev/null +++ b/internal/dnsforward/beforerequest.go @@ -0,0 +1,116 @@ +package dnsforward + +import ( + "encoding/binary" + "fmt" + + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" +) + +// type check +var _ proxy.BeforeRequestHandler = (*Server)(nil) + +// HandleBefore is the handler that is called before any other processing, +// including logs. It performs access checks and puts the client ID, if there +// is one, into the server's cache. +// +// TODO(d.kolyshev): Extract to separate package. +func (s *Server) HandleBefore( + _ *proxy.Proxy, + pctx *proxy.DNSContext, +) (err error) { + clientID, err := s.clientIDFromDNSContext(pctx) + if err != nil { + return &proxy.BeforeRequestError{ + Err: fmt.Errorf("getting clientid: %w", err), + Response: s.NewMsgSERVFAIL(pctx.Req), + } + } + + blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID) + if blocked { + return s.preBlockedResponse(pctx) + } + + if len(pctx.Req.Question) == 1 { + q := pctx.Req.Question[0] + qt := q.Qtype + host := aghnet.NormalizeDomain(q.Name) + if s.access.isBlockedHost(host, qt) { + log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host) + + return s.preBlockedResponse(pctx) + } + } + + if clientID != "" { + key := [8]byte{} + binary.BigEndian.PutUint64(key[:], pctx.RequestID) + s.clientIDCache.Set(key[:], []byte(clientID)) + } + + return nil +} + +// clientIDFromDNSContext extracts the client's ID from the server name of the +// client's DoT or DoQ request or the path of the client's DoH. If the protocol +// is not one of these, clientID is an empty string and err is nil. +func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) { + proto := pctx.Proto + if proto == proxy.ProtoHTTPS { + clientID, err = clientIDFromDNSContextHTTPS(pctx) + if err != nil { + return "", fmt.Errorf("checking url: %w", err) + } else if clientID != "" { + return clientID, nil + } + + // Go on and check the domain name as well. + } else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC { + return "", nil + } + + hostSrvName := s.conf.ServerName + if hostSrvName == "" { + return "", nil + } + + cliSrvName, err := clientServerName(pctx, proto) + if err != nil { + return "", err + } + + clientID, err = clientIDFromClientServerName( + hostSrvName, + cliSrvName, + s.conf.StrictSNICheck, + ) + if err != nil { + return "", fmt.Errorf("clientid check: %w", err) + } + + return clientID, nil +} + +// errAccessBlocked is a sentinel error returned when a request is blocked by +// access settings. +var errAccessBlocked errors.Error = "blocked by access settings" + +// preBlockedResponse returns a protocol-appropriate response for a request that +// was blocked by access settings. +func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (err error) { + if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt { + // Return nil so that dnsproxy drops the connection and thus + // prevent DNS amplification attacks. + return errAccessBlocked + } + + return &proxy.BeforeRequestError{ + Err: errAccessBlocked, + Response: s.makeResponseREFUSED(pctx.Req), + } +} diff --git a/internal/dnsforward/beforerequest_internal_test.go b/internal/dnsforward/beforerequest_internal_test.go new file mode 100644 index 00000000..7e0d6e9b --- /dev/null +++ b/internal/dnsforward/beforerequest_internal_test.go @@ -0,0 +1,299 @@ +package dnsforward + +import ( + "crypto/tls" + "net" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + blockedHost = "blockedhost.org" + testFQDN = "example.org." + dnsClientTimeout = 200 * time.Millisecond +) + +func TestServer_HandleBefore_tls(t *testing.T) { + t.Parallel() + + const clientID = "client-1" + + testCases := []struct { + clientSrvName string + name string + host string + allowedClients []string + disallowedClients []string + blockedHosts []string + wantRCode int + }{{ + clientSrvName: tlsServerName, + name: "allow_all", + host: testFQDN, + allowedClients: []string{}, + disallowedClients: []string{}, + blockedHosts: []string{}, + wantRCode: dns.RcodeSuccess, + }, { + clientSrvName: "%" + "." + tlsServerName, + name: "invalid_client_id", + host: testFQDN, + allowedClients: []string{}, + disallowedClients: []string{}, + blockedHosts: []string{}, + wantRCode: dns.RcodeServerFailure, + }, { + clientSrvName: clientID + "." + tlsServerName, + name: "allowed_client_allowed", + host: testFQDN, + allowedClients: []string{clientID}, + disallowedClients: []string{}, + blockedHosts: []string{}, + wantRCode: dns.RcodeSuccess, + }, { + clientSrvName: "client-2." + tlsServerName, + name: "allowed_client_rejected", + host: testFQDN, + allowedClients: []string{clientID}, + disallowedClients: []string{}, + blockedHosts: []string{}, + wantRCode: dns.RcodeRefused, + }, { + clientSrvName: tlsServerName, + name: "disallowed_client_allowed", + host: testFQDN, + allowedClients: []string{}, + disallowedClients: []string{clientID}, + blockedHosts: []string{}, + wantRCode: dns.RcodeSuccess, + }, { + clientSrvName: clientID + "." + tlsServerName, + name: "disallowed_client_rejected", + host: testFQDN, + allowedClients: []string{}, + disallowedClients: []string{clientID}, + blockedHosts: []string{}, + wantRCode: dns.RcodeRefused, + }, { + clientSrvName: tlsServerName, + name: "blocked_hosts_allowed", + host: testFQDN, + allowedClients: []string{}, + disallowedClients: []string{}, + blockedHosts: []string{blockedHost}, + wantRCode: dns.RcodeSuccess, + }, { + clientSrvName: tlsServerName, + name: "blocked_hosts_rejected", + host: dns.Fqdn(blockedHost), + allowedClients: []string{}, + disallowedClients: []string{}, + blockedHosts: []string{blockedHost}, + wantRCode: dns.RcodeRefused, + }} + + localAns := []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: testFQDN, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 3600, + Rdlength: 4, + }, + A: net.IP{1, 2, 3, 4}, + }} + localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + resp := (&dns.Msg{}).SetReply(req) + resp.Answer = localAns + + require.NoError(t, w.WriteMsg(resp)) + }) + localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + s, _ := createTestTLS(t, TLSConfig{ + TLSListenAddrs: []*net.TCPAddr{{}}, + ServerName: tlsServerName, + }) + + s.conf.UpstreamDNS = []string{localUpsAddr} + + s.conf.AllowedClients = tc.allowedClients + s.conf.DisallowedClients = tc.disallowedClients + s.conf.BlockedHosts = tc.blockedHosts + + err := s.Prepare(&s.conf) + require.NoError(t, err) + + startDeferStop(t, s) + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: tc.clientSrvName, + } + + client := &dns.Client{ + Net: "tcp-tls", + TLSConfig: tlsConfig, + Timeout: dnsClientTimeout, + } + + req := createTestMessage(tc.host) + addr := s.dnsProxy.Addr(proxy.ProtoTLS).String() + + reply, _, err := client.Exchange(req, addr) + require.NoError(t, err) + + assert.Equal(t, tc.wantRCode, reply.Rcode) + if tc.wantRCode == dns.RcodeSuccess { + assert.Equal(t, localAns, reply.Answer) + } else { + assert.Empty(t, reply.Answer) + } + }) + } +} + +func TestServer_HandleBefore_udp(t *testing.T) { + t.Parallel() + + const ( + clientIPv4 = "127.0.0.1" + clientIPv6 = "::1" + ) + + clientIPs := []string{clientIPv4, clientIPv6} + + testCases := []struct { + name string + host string + allowedClients []string + disallowedClients []string + blockedHosts []string + wantTimeout bool + }{{ + name: "allow_all", + host: testFQDN, + allowedClients: []string{}, + disallowedClients: []string{}, + blockedHosts: []string{}, + wantTimeout: false, + }, { + name: "allowed_client_allowed", + host: testFQDN, + allowedClients: clientIPs, + disallowedClients: []string{}, + blockedHosts: []string{}, + wantTimeout: false, + }, { + name: "allowed_client_rejected", + host: testFQDN, + allowedClients: []string{"1:2:3::4"}, + disallowedClients: []string{}, + blockedHosts: []string{}, + wantTimeout: true, + }, { + name: "disallowed_client_allowed", + host: testFQDN, + allowedClients: []string{}, + disallowedClients: []string{"1:2:3::4"}, + blockedHosts: []string{}, + wantTimeout: false, + }, { + name: "disallowed_client_rejected", + host: testFQDN, + allowedClients: []string{}, + disallowedClients: clientIPs, + blockedHosts: []string{}, + wantTimeout: true, + }, { + name: "blocked_hosts_allowed", + host: testFQDN, + allowedClients: []string{}, + disallowedClients: []string{}, + blockedHosts: []string{blockedHost}, + wantTimeout: false, + }, { + name: "blocked_hosts_rejected", + host: dns.Fqdn(blockedHost), + allowedClients: []string{}, + disallowedClients: []string{}, + blockedHosts: []string{blockedHost}, + wantTimeout: true, + }} + + localAns := []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: testFQDN, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 3600, + Rdlength: 4, + }, + A: net.IP{1, 2, 3, 4}, + }} + localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { + resp := (&dns.Msg{}).SetReply(req) + resp.Answer = localAns + + require.NoError(t, w.WriteMsg(resp)) + }) + localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, + Config: Config{ + AllowedClients: tc.allowedClients, + DisallowedClients: tc.disallowedClients, + BlockedHosts: tc.blockedHosts, + UpstreamDNS: []string{localUpsAddr}, + UpstreamMode: UpstreamModeLoadBalance, + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + }, + ServePlainDNS: true, + }) + + startDeferStop(t, s) + + client := &dns.Client{ + Net: "udp", + Timeout: dnsClientTimeout, + } + + req := createTestMessage(tc.host) + addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() + + reply, _, err := client.Exchange(req, addr) + if tc.wantTimeout { + wantErr := &net.OpError{} + require.ErrorAs(t, err, &wantErr) + assert.True(t, wantErr.Timeout()) + + assert.Nil(t, reply) + } else { + require.NoError(t, err) + require.NotNil(t, reply) + + assert.Equal(t, dns.RcodeSuccess, reply.Rcode) + assert.Equal(t, localAns, reply.Answer) + } + }) + } +} diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 71ad2eb0..6cda328c 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -110,46 +110,6 @@ type quicConnection interface { ConnectionState() (cs quic.ConnectionState) } -// clientIDFromDNSContext extracts the client's ID from the server name of the -// client's DoT or DoQ request or the path of the client's DoH. If the protocol -// is not one of these, clientID is an empty string and err is nil. -func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) { - proto := pctx.Proto - if proto == proxy.ProtoHTTPS { - clientID, err = clientIDFromDNSContextHTTPS(pctx) - if err != nil { - return "", fmt.Errorf("checking url: %w", err) - } else if clientID != "" { - return clientID, nil - } - - // Go on and check the domain name as well. - } else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC { - return "", nil - } - - hostSrvName := s.conf.ServerName - if hostSrvName == "" { - return "", nil - } - - cliSrvName, err := clientServerName(pctx, proto) - if err != nil { - return "", err - } - - clientID, err = clientIDFromClientServerName( - hostSrvName, - cliSrvName, - s.conf.StrictSNICheck, - ) - if err != nil { - return "", fmt.Errorf("clientid check: %w", err) - } - - return clientID, nil -} - // clientServerName returns the TLS server name based on the protocol. For // DNS-over-HTTPS requests, it will return the hostname part of the Host header // if there is one. diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index db2d50af..4d2924ab 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -235,9 +235,18 @@ type DNSCryptConfig struct { // ServerConfig represents server configuration. // The zero ServerConfig is empty and ready for use. type ServerConfig struct { - UDPListenAddrs []*net.UDPAddr // UDP listen address - TCPListenAddrs []*net.TCPAddr // TCP listen address - UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config + // UDPListenAddrs is the list of addresses to listen for DNS-over-UDP. + UDPListenAddrs []*net.UDPAddr + + // TCPListenAddrs is the list of addresses to listen for DNS-over-TCP. + TCPListenAddrs []*net.TCPAddr + + // UpstreamConfig is the general configuration of upstream DNS servers. + UpstreamConfig *proxy.UpstreamConfig + + // PrivateRDNSUpstreamConfig is the configuration of upstream DNS servers + // for private reverse DNS. + PrivateRDNSUpstreamConfig *proxy.UpstreamConfig // AddrProcConf defines the configuration for the client IP processor. // If nil, [client.EmptyAddrProc] is used. @@ -306,24 +315,28 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) { trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies) conf = &proxy.Config{ - HTTP3: srvConf.ServeHTTP3, - Ratelimit: int(srvConf.Ratelimit), - RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4, - RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6, - RatelimitWhitelist: srvConf.RatelimitWhitelist, - RefuseAny: srvConf.RefuseAny, - TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes), - CacheMinTTL: srvConf.CacheMinTTL, - CacheMaxTTL: srvConf.CacheMaxTTL, - CacheOptimistic: srvConf.CacheOptimistic, - UpstreamConfig: srvConf.UpstreamConfig, - BeforeRequestHandler: s.beforeRequestHandler, - RequestHandler: s.handleDNSRequest, - HTTPSServerName: aghhttp.UserAgent(), - EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled, - MaxGoroutines: srvConf.MaxGoroutines, - UseDNS64: srvConf.UseDNS64, - DNS64Prefs: srvConf.DNS64Prefixes, + HTTP3: srvConf.ServeHTTP3, + Ratelimit: int(srvConf.Ratelimit), + RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4, + RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6, + RatelimitWhitelist: srvConf.RatelimitWhitelist, + RefuseAny: srvConf.RefuseAny, + TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes), + CacheMinTTL: srvConf.CacheMinTTL, + CacheMaxTTL: srvConf.CacheMaxTTL, + CacheOptimistic: srvConf.CacheOptimistic, + UpstreamConfig: srvConf.UpstreamConfig, + PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig, + BeforeRequestHandler: s, + RequestHandler: s.handleDNSRequest, + HTTPSServerName: aghhttp.UserAgent(), + EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled, + MaxGoroutines: srvConf.MaxGoroutines, + UseDNS64: srvConf.UseDNS64, + DNS64Prefs: srvConf.DNS64Prefixes, + UsePrivateRDNS: srvConf.UsePrivateRDNS, + PrivateSubnets: s.privateNets, + MessageConstructor: s, } if srvConf.EDNSClientSubnet.UseCustom { @@ -452,12 +465,33 @@ func (s *Server) prepareIpsetListSettings() (err error) { } ipsets := stringutil.SplitTrimmed(string(data), "\n") + ipsets = stringutil.FilterOut(ipsets, IsCommentOrEmpty) log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn) return s.ipset.init(ipsets) } +// loadUpstreams parses upstream DNS servers from the configured file or from +// the configuration itself. +func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) { + if conf.UpstreamDNSFileName == "" { + return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil + } + + var data []byte + data, err = os.ReadFile(conf.UpstreamDNSFileName) + if err != nil { + return nil, fmt.Errorf("reading upstream from file: %w", err) + } + + upstreams = stringutil.SplitTrimmed(string(data), "\n") + + log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName) + + return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil +} + // collectListenAddr adds addrPort to addrs. It also adds its port to // unspecPorts if its address is unspecified. func collectListenAddr( @@ -529,8 +563,8 @@ func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) { return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr()) } -// filterOut filters out all the upstreams that match um. It returns all the -// closing errors joined. +// filterOutAddrs filters out all the upstreams that match um. It returns all +// the closing errors joined. func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) { var errs []error delFunc := func(u upstream.Upstream) (ok bool) { diff --git a/internal/dnsforward/dns64_test.go b/internal/dnsforward/dns64_test.go index 49e1e4ce..18bc348f 100644 --- a/internal/dnsforward/dns64_test.go +++ b/internal/dnsforward/dns64_test.go @@ -3,7 +3,6 @@ package dnsforward import ( "net" "testing" - "time" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -11,6 +10,7 @@ import ( "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -64,6 +64,8 @@ func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns } func TestServer_HandleDNSRequest_dns64(t *testing.T) { + t.Parallel() + const ( ipv4Domain = "ipv4.only." ipv6Domain = "ipv6.only." @@ -252,33 +254,33 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { require.Len(pt, m.Question, 1) require.Equal(pt, m.Question[0].Name, ptr64Domain) - resp := (&dns.Msg{ - Answer: []dns.RR{localRR}, - }).SetReply(m) + + resp := (&dns.Msg{}).SetReply(m) + resp.Answer = []dns.RR{localRR} require.NoError(t, w.WriteMsg(resp)) }) localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() client := &dns.Client{ - Net: "tcp", - Timeout: 1 * time.Second, + Net: string(proxy.ProtoTCP), + Timeout: testTimeout, } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() + upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { q := req.Question[0] - require.Contains(pt, tc.upsAns, q.Qtype) + require.Contains(pt, tc.upsAns, q.Qtype) answer := tc.upsAns[q.Qtype] - resp := (&dns.Msg{ - Answer: answer[sectionAnswer], - Ns: answer[sectionAuthority], - Extra: answer[sectionAdditional], - }).SetReply(req) + resp := (&dns.Msg{}).SetReply(req) + resp.Answer = answer[sectionAnswer] + resp.Ns = answer[sectionAuthority] + resp.Extra = answer[sectionAdditional] require.NoError(pt, w.WriteMsg(resp)) }) @@ -308,10 +310,54 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) { req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype) - resp, _, excErr := client.Exchange(req, s.dnsProxy.Addr(proxy.ProtoTCP).String()) + resp, _, excErr := client.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String()) require.NoError(t, excErr) require.Equal(t, tc.wantAns, resp.Answer) }) } } + +func TestServer_dns64WithDisabledRDNS(t *testing.T) { + t.Parallel() + + // Shouldn't go to upstream at all. + panicHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { + panic("not implemented") + }) + upsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String() + localUpsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String() + + s := createTestServer(t, &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, + UseDNS64: true, + Config: Config{ + UpstreamMode: UpstreamModeLoadBalance, + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + UpstreamDNS: []string{upsAddr}, + }, + UsePrivateRDNS: false, + LocalPTRResolvers: []string{localUpsAddr}, + ServePlainDNS: true, + }) + startDeferStop(t, s) + + mappedIPv6 := net.ParseIP("64:ff9b::102:304") + arpa, err := netutil.IPToReversedAddr(mappedIPv6) + require.NoError(t, err) + + req := (&dns.Msg{}).SetQuestion(dns.Fqdn(arpa), dns.TypePTR) + + cli := &dns.Client{ + Net: string(proxy.ProtoTCP), + Timeout: testTimeout, + } + + resp, _, err := cli.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String()) + require.NoError(t, err) + + assert.Equal(t, dns.RcodeNameError, resp.Rcode) +} diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index a1d1eede..fda29f0a 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -2,6 +2,7 @@ package dnsforward import ( + "cmp" "context" "fmt" "io" @@ -15,7 +16,6 @@ import ( "sync/atomic" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -135,12 +135,6 @@ type Server struct { // WHOIS, etc. addrProc client.AddressProcessor - // localResolvers is a DNS proxy instance used to resolve PTR records for - // addresses considered private as per the [privateNets]. - // - // TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy. - localResolvers *proxy.Proxy - // sysResolvers used to fetch system resolvers to use by default for private // PTR resolving. sysResolvers SystemResolvers @@ -158,12 +152,6 @@ type Server struct { // [upstream.Resolver] interface. bootResolvers []*upstream.UpstreamResolver - // recDetector is a cache for recursive requests. It is used to detect and - // prevent recursive requests only for private upstreams. - // - // See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135. - recDetector *recursionDetector - // dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major // part of DNS64 happens inside the [proxy] package, but there still are // some places where response mapping is needed (e.g. DHCP). @@ -212,14 +200,6 @@ type DNSCreateParams struct { LocalDomain string } -const ( - // recursionTTL is the time recursive request is cached for. - recursionTTL = 1 * time.Second - // cachedRecurrentReqNum is the maximum number of cached recurrent - // requests. - cachedRecurrentReqNum = 1000 -) - // NewServer creates a new instance of the dnsforward.Server // Note: this function must be called only once // @@ -256,7 +236,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { // TODO(e.burkov): Use some case-insensitive string comparison. localDomainSuffix: strings.ToLower(localDomainSuffix), etcHosts: etcHosts, - recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum), clientIDCache: cache.New(cache.Config{ EnableLRU: true, MaxCount: defaultClientIDCacheCount, @@ -366,6 +345,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er s.serverLock.RLock() defer s.serverLock.RUnlock() + // TODO(e.burkov): Migrate to [netip.Addr] already. arpa, err := netutil.IPToReversedAddr(ip.AsSlice()) if err != nil { return "", 0, fmt.Errorf("reversing ip: %w", err) @@ -386,25 +366,23 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er } dctx := &proxy.DNSContext{ - Proto: "udp", - Req: req, + Proto: proxy.ProtoUDP, + Req: req, + IsPrivateClient: true, } - var resolver *proxy.Proxy var errMsg string if s.privateNets.Contains(ip) { if !s.conf.UsePrivateRDNS { return "", 0, nil } - resolver = s.localResolvers errMsg = "resolving a private address: %w" - s.recDetector.add(*req) + dctx.RequestedPrivateRDNS = netip.PrefixFrom(ip, ip.BitLen()) } else { - resolver = s.internalProxy errMsg = "resolving an address: %w" } - if err = resolver.Resolve(dctx); err != nil { + if err = s.internalProxy.Resolve(dctx); err != nil { return "", 0, fmt.Errorf(errMsg, err) } @@ -473,103 +451,6 @@ func (s *Server) startLocked() error { return err } -// prepareLocalResolvers initializes the local upstreams configuration using -// boot as bootstrap. It assumes that s.serverLock is locked or s not running. -func (s *Server) prepareLocalResolvers( - boot upstream.Resolver, -) (uc *proxy.UpstreamConfig, err error) { - set, err := s.conf.ourAddrsSet() - if err != nil { - // Don't wrap the error because it's informative enough as is. - return nil, err - } - - resolvers := s.conf.LocalPTRResolvers - confNeedsFiltering := len(resolvers) > 0 - if confNeedsFiltering { - resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty) - } else { - sysResolvers := slices.DeleteFunc(slices.Clone(s.sysResolvers.Addrs()), set.Has) - resolvers = make([]string, 0, len(sysResolvers)) - for _, r := range sysResolvers { - resolvers = append(resolvers, r.String()) - } - } - - log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers) - - uc, err = s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{ - Bootstrap: boot, - Timeout: defaultLocalTimeout, - // TODO(e.burkov): Should we verify server's certificates? - PreferIPv6: s.conf.BootstrapPreferIPv6, - }) - if err != nil { - return nil, fmt.Errorf("preparing private upstreams: %w", err) - } - - if confNeedsFiltering { - err = filterOutAddrs(uc, set) - if err != nil { - return nil, fmt.Errorf("filtering private upstreams: %w", err) - } - } - - return uc, nil -} - -// LocalResolversError is an error type for errors during local resolvers setup. -// This is only needed to distinguish these errors from errors returned by -// creating the proxy. -type LocalResolversError struct { - Err error -} - -// type check -var _ error = (*LocalResolversError)(nil) - -// Error implements the error interface for *LocalResolversError. -func (err *LocalResolversError) Error() (s string) { - return fmt.Sprintf("creating local resolvers: %s", err.Err) -} - -// type check -var _ errors.Wrapper = (*LocalResolversError)(nil) - -// Unwrap implements the [errors.Wrapper] interface for *LocalResolversError. -func (err *LocalResolversError) Unwrap() error { - return err.Err -} - -// setupLocalResolvers initializes and sets the resolvers for local addresses. -// It assumes s.serverLock is locked or s not running. It returns the upstream -// configuration used for private PTR resolving, or nil if it's disabled. Note, -// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig]. -func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) { - if !s.conf.UsePrivateRDNS { - // It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig]. - return nil, nil - } - - uc, err = s.prepareLocalResolvers(boot) - if err != nil { - // Don't wrap the error because it's informative enough as is. - return nil, err - } - - localResolvers, err := proxy.New(&proxy.Config{ - UpstreamConfig: uc, - }) - if err != nil { - return nil, &LocalResolversError{Err: err} - } - - s.localResolvers = localResolvers - - // TODO(e.burkov): Should we also consider the DNS64 usage? - return uc, nil -} - // Prepare initializes parameters of s using data from conf. conf must not be // nil. func (s *Server) Prepare(conf *ServerConfig) (err error) { @@ -586,7 +467,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { s.initDefaultSettings() - boot, err := s.prepareInternalDNS() + err = s.prepareInternalDNS() if err != nil { // Don't wrap the error, because it's informative enough as is. return err @@ -608,12 +489,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { return fmt.Errorf("preparing access: %w", err) } - // TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy. - proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot) - if err != nil { - return fmt.Errorf("setting up resolvers: %w", err) - } - proxyConfig.Fallbacks, err = s.setupFallbackDNS() if err != nil { return fmt.Errorf("setting up fallback dns servers: %w", err) @@ -626,8 +501,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { s.dnsProxy = dnsProxy - s.recDetector.clear() - s.setupAddrProc() s.registerHandlers() @@ -635,36 +508,127 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { return nil } -// prepareInternalDNS initializes the internal state of s before initializing -// the primary DNS proxy instance. It assumes s.serverLock is locked or the -// Server not running. -func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) { - err = s.prepareIpsetListSettings() +// prepareUpstreamSettings sets upstream DNS server settings. +func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) { + // Load upstreams either from the file, or from the settings + var upstreams []string + upstreams, err = s.conf.loadUpstreams() if err != nil { - return nil, fmt.Errorf("preparing ipset settings: %w", err) + return fmt.Errorf("loading upstreams: %w", err) } - s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{ - Timeout: DefaultTimeout, + uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{ + Bootstrap: boot, + Timeout: s.conf.UpstreamTimeout, HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), + PreferIPv6: s.conf.BootstrapPreferIPv6, + // Use a customized set of RootCAs, because Go's default mechanism of + // loading TLS roots does not always work properly on some routers so we're + // loading roots manually and pass it here. + // + // See [aghtls.SystemRootCAs]. + // + // TODO(a.garipov): Investigate if that's true. + RootCAs: s.conf.TLSv12Roots, + CipherSuites: s.conf.TLSCiphers, }) + if err != nil { + return fmt.Errorf("preparing upstream config: %w", err) + } + + s.conf.UpstreamConfig = uc + + return nil +} + +// PrivateRDNSError is returned when the private rDNS upstreams are +// invalid but enabled. +// +// TODO(e.burkov): Consider allowing to use incomplete private rDNS upstreams +// configuration in proxy when the private rDNS function is enabled. In theory, +// proxy supports the case when no upstreams provided to resolve the private +// request, since it already supports this for DNS64-prefixed PTR requests. +type PrivateRDNSError struct { + err error +} + +// Error implements the [errors.Error] interface. +func (e *PrivateRDNSError) Error() (s string) { + return e.err.Error() +} + +func (e *PrivateRDNSError) Unwrap() (err error) { + return e.err +} + +// prepareLocalResolvers initializes the private RDNS upstream configuration +// according to the server's settings. It assumes s.serverLock is locked or the +// Server not running. +func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) { + if !s.conf.UsePrivateRDNS { + return nil, nil + } + + var ownAddrs addrPortSet + ownAddrs, err = s.conf.ourAddrsSet() if err != nil { // Don't wrap the error, because it's informative enough as is. return nil, err } + opts := &upstream.Options{ + Bootstrap: s.bootstrap, + Timeout: defaultLocalTimeout, + // TODO(e.burkov): Should we verify server's certificates? + PreferIPv6: s.conf.BootstrapPreferIPv6, + } + + addrs := s.conf.LocalPTRResolvers + uc, err = newPrivateConfig(addrs, ownAddrs, s.sysResolvers, s.privateNets, opts) + if err != nil { + return nil, fmt.Errorf("preparing resolvers: %w", err) + } + + return uc, nil +} + +// prepareInternalDNS initializes the internal state of s before initializing +// the primary DNS proxy instance. It assumes s.serverLock is locked or the +// Server not running. +func (s *Server) prepareInternalDNS() (err error) { + err = s.prepareIpsetListSettings() + if err != nil { + return fmt.Errorf("preparing ipset settings: %w", err) + } + + bootOpts := &upstream.Options{ + Timeout: DefaultTimeout, + HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), + } + + s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + err = s.prepareUpstreamSettings(s.bootstrap) if err != nil { // Don't wrap the error, because it's informative enough as is. - return s.bootstrap, err + return err + } + + s.conf.PrivateRDNSUpstreamConfig, err = s.prepareLocalResolvers() + if err != nil { + return err } err = s.prepareInternalProxy() if err != nil { - return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err) + return fmt.Errorf("preparing internal proxy: %w", err) } - return s.bootstrap, nil + return nil } // setupFallbackDNS initializes the fallback DNS servers. @@ -743,10 +707,16 @@ func validateBlockingMode( func (s *Server) prepareInternalProxy() (err error) { srvConf := s.conf conf := &proxy.Config{ - CacheEnabled: true, - CacheSizeBytes: 4096, - UpstreamConfig: srvConf.UpstreamConfig, - MaxGoroutines: s.conf.MaxGoroutines, + CacheEnabled: true, + CacheSizeBytes: 4096, + PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig, + UpstreamConfig: srvConf.UpstreamConfig, + MaxGoroutines: srvConf.MaxGoroutines, + UseDNS64: srvConf.UseDNS64, + DNS64Prefs: srvConf.DNS64Prefixes, + UsePrivateRDNS: srvConf.UsePrivateRDNS, + PrivateSubnets: s.privateNets, + MessageConstructor: s, } err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration) @@ -782,11 +752,6 @@ func (s *Server) stopLocked() (err error) { } } - logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s") - if s.localResolvers != nil { - logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s") - } - for _, b := range s.bootResolvers { logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address()) } @@ -908,5 +873,5 @@ func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, blocked = true } - return blocked, aghalg.Coalesce(rule, clientID) + return blocked, cmp.Or(rule, clientID) } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index b490b38d..9e4942cc 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -1,7 +1,7 @@ package dnsforward import ( - "context" + "cmp" "crypto/ecdsa" "crypto/rand" "crypto/rsa" @@ -21,7 +21,6 @@ import ( "testing/fstest" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -190,7 +189,7 @@ func newGoogleUpstream() (u upstream.Upstream) { return &aghtest.UpstreamMock{ OnAddress: func() (addr string) { return "google.upstream.example" }, OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { - return aghalg.Coalesce( + return cmp.Or( aghtest.MatchedResponse(req, dns.TypeA, googleDomainName, "8.8.8.8"), new(dns.Msg).SetRcode(req, dns.RcodeNameError), ), nil @@ -253,7 +252,7 @@ func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) { wg := &sync.WaitGroup{} - for i := 0; i < testMessagesCount; i++ { + for range testMessagesCount { msg := createGoogleATestMessage() wg.Add(1) @@ -276,7 +275,7 @@ func sendTestMessagesAsync(t *testing.T, conn *dns.Conn) { func sendTestMessages(t *testing.T, conn *dns.Conn) { t.Helper() - for i := 0; i < testMessagesCount; i++ { + for i := range testMessagesCount { req := createGoogleATestMessage() err := conn.WriteMsg(req) assert.NoErrorf(t, err, "cannot write message #%d: %s", i, err) @@ -491,19 +490,10 @@ func TestServerRace(t *testing.T) { } func TestSafeSearch(t *testing.T) { - resolver := &aghtest.Resolver{ - OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) { - ip4, ip6 := aghtest.HostToIPs(host) - - return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil - }, - } - safeSearchConf := filtering.SafeSearchConfig{ - Enabled: true, - Google: true, - Yandex: true, - CustomResolver: resolver, + Enabled: true, + Google: true, + Yandex: true, } filterConf := &filtering.Config{ @@ -540,7 +530,6 @@ func TestSafeSearch(t *testing.T) { client := &dns.Client{} yandexIP := netip.AddrFrom4([4]byte{213, 180, 193, 56}) - googleIP, _ := aghtest.HostToIPs("forcesafesearch.google.com") testCases := []struct { host string @@ -564,19 +553,19 @@ func TestSafeSearch(t *testing.T) { wantCNAME: "", }, { host: "www.google.com.", - want: googleIP, + want: netip.Addr{}, wantCNAME: "forcesafesearch.google.com.", }, { host: "www.google.com.af.", - want: googleIP, + want: netip.Addr{}, wantCNAME: "forcesafesearch.google.com.", }, { host: "www.google.be.", - want: googleIP, + want: netip.Addr{}, wantCNAME: "forcesafesearch.google.com.", }, { host: "www.google.by.", - want: googleIP, + want: netip.Addr{}, wantCNAME: "forcesafesearch.google.com.", }} @@ -593,12 +582,15 @@ func TestSafeSearch(t *testing.T) { cname := testutil.RequireTypeAssert[*dns.CNAME](t, reply.Answer[0]) assert.Equal(t, tc.wantCNAME, cname.Target) + + a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[1]) + assert.NotEmpty(t, a.A) } else { require.Len(t, reply.Answer, 1) - } - a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[len(reply.Answer)-1]) - assert.Equal(t, net.IP(tc.want.AsSlice()), a.A) + a := testutil.RequireTypeAssert[*dns.A](t, reply.Answer[0]) + assert.Equal(t, net.IP(tc.want.AsSlice()), a.A) + } }) } } @@ -691,7 +683,7 @@ func TestServerCustomClientUpstream(t *testing.T) { ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { atomic.AddUint32(&upsCalledCounter, 1) - return aghalg.Coalesce( + return cmp.Or( aghtest.MatchedResponse(req, dns.TypeA, "host", "192.168.0.1"), new(dns.Msg).SetRcode(req, dns.RcodeNameError), ), nil @@ -1152,7 +1144,7 @@ func TestRewrite(t *testing.T) { })) ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) { - return aghalg.Coalesce( + return cmp.Or( aghtest.MatchedResponse(req, dns.TypeA, "example.org", "4.3.2.1"), new(dns.Msg).SetRcode(req, dns.RcodeNameError), ), nil @@ -1481,7 +1473,7 @@ func TestServer_Exchange(t *testing.T) { require.NoError(t, err) extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { - resp := aghalg.Coalesce( + resp := cmp.Or( aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)), doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))), new(dns.Msg).SetRcode(req, dns.RcodeNameError), @@ -1495,7 +1487,7 @@ func TestServer_Exchange(t *testing.T) { require.NoError(t, err) locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { - resp := aghalg.Coalesce( + resp := cmp.Or( aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)), new(dns.Msg).SetRcode(req, dns.RcodeNameError), ) diff --git a/internal/dnsforward/dnsrewrite.go b/internal/dnsforward/dnsrewrite.go index 8b9a0fb1..7d9fde72 100644 --- a/internal/dnsforward/dnsrewrite.go +++ b/internal/dnsforward/dnsrewrite.go @@ -143,7 +143,7 @@ func (s *Server) filterDNSRewrite( res *filtering.Result, pctx *proxy.DNSContext, ) (err error) { - resp := s.makeResponse(req) + resp := s.replyCompressed(req) dnsrr := res.DNSRewriteResult if dnsrr == nil { return errors.Error("no dns rewrite rule content") diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 46599c11..f6cd319d 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -1,57 +1,17 @@ package dnsforward import ( - "encoding/binary" "fmt" "net" "slices" "strings" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" ) -// beforeRequestHandler is the handler that is called before any other -// processing, including logs. It performs access checks and puts the client -// ID, if there is one, into the server's cache. -func (s *Server) beforeRequestHandler( - _ *proxy.Proxy, - pctx *proxy.DNSContext, -) (reply bool, err error) { - clientID, err := s.clientIDFromDNSContext(pctx) - if err != nil { - return false, fmt.Errorf("getting clientid: %w", err) - } - - blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID) - if blocked { - return s.preBlockedResponse(pctx) - } - - if len(pctx.Req.Question) == 1 { - q := pctx.Req.Question[0] - qt := q.Qtype - host := aghnet.NormalizeDomain(q.Name) - if s.access.isBlockedHost(host, qt) { - log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host) - - return s.preBlockedResponse(pctx) - } - } - - if clientID != "" { - key := [8]byte{} - binary.BigEndian.PutUint64(key[:], pctx.RequestID) - s.clientIDCache.Set(key[:], []byte(clientID)) - } - - return true, nil -} - // clientRequestFilteringSettings looks up client filtering settings using the // client's IP address and ID, if any, from dctx. func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) { @@ -71,6 +31,7 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err req := pctx.Req q := req.Question[0] host := strings.TrimSuffix(q.Name, ".") + resVal, err := s.dnsFilter.CheckHost(host, q.Qtype, dctx.setts) if err != nil { return nil, fmt.Errorf("checking host %q: %w", host, err) @@ -79,22 +40,15 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err // TODO(a.garipov): Make CheckHost return a pointer. res = &resVal switch { - case res.IsFiltered: - log.Debug( - "dnsforward: host %q is filtered, reason: %q; rule: %q", - host, - res.Reason, - res.Rules[0].Text, - ) - pctx.Res = s.genDNSFilterMessage(pctx, res) - case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) && - res.CanonName != "" && - len(res.IPList) == 0: + case isRewrittenCNAME(res): // Resolve the new canonical name, not the original host name. The // original question is readded in processFilteringAfterResponse. dctx.origQuestion = q req.Question[0].Name = dns.Fqdn(res.CanonName) - case res.Reason == filtering.Rewritten: + case res.IsFiltered: + log.Debug("dnsforward: host %q is filtered, reason: %q", host, res.Reason) + pctx.Res = s.genDNSFilterMessage(pctx, res) + case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch): pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName) case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts): if err = s.filterDNSRewrite(req, res, pctx); err != nil { @@ -105,6 +59,17 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err return res, err } +// isRewrittenCNAME returns true if the request considered to be rewritten with +// CNAME and has no resolved IPs. +func isRewrittenCNAME(res *filtering.Result) (ok bool) { + return res.Reason.In( + filtering.Rewritten, + filtering.RewrittenRule, + filtering.FilteredSafeSearch) && + res.CanonName != "" && + len(res.IPList) == 0 +} + // checkHostRules checks the host against filters. It is safe for concurrent // use. func (s *Server) checkHostRules( diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 2d446ac9..76f88edc 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -1,6 +1,7 @@ package dnsforward import ( + "cmp" "encoding/json" "fmt" "io" @@ -261,55 +262,17 @@ func (req *jsonDNSConfig) checkUpstreamMode() (err error) { } } -// checkBootstrap returns an error if any bootstrap address is invalid. -func (req *jsonDNSConfig) checkBootstrap() (err error) { - if req.Bootstraps == nil { - return nil - } - - var b string - defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }() - - for _, b = range *req.Bootstraps { - if b == "" { - return errors.Error("empty") - } - - var resolver *upstream.UpstreamResolver - if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil { - // Don't wrap the error because it's informative enough as is. - return err - } - - if err = resolver.Close(); err != nil { - return fmt.Errorf("closing %s: %w", b, err) - } - } - - return nil -} - -// checkFallbacks returns an error if any fallback address is invalid. -func (req *jsonDNSConfig) checkFallbacks() (err error) { - if req.Fallbacks == nil { - return nil - } - - _, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{}) - if err != nil { - return fmt.Errorf("fallback servers: %w", err) - } - - return nil -} - // validate returns an error if any field of req is invalid. // // TODO(s.chzhen): Parse, don't validate. -func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) { +func (req *jsonDNSConfig) validate( + ownAddrs addrPortSet, + sysResolvers SystemResolvers, + privateNets netutil.SubnetSet, +) (err error) { defer func() { err = errors.Annotate(err, "validating dns config: %w") }() - err = req.validateUpstreamDNSServers(privateNets) + err = req.validateUpstreamDNSServers(ownAddrs, sysResolvers, privateNets) if err != nil { // Don't wrap the error since it's informative enough as is. return err @@ -342,20 +305,77 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) { return nil } +// checkBootstrap returns an error if any bootstrap address is invalid. +func (req *jsonDNSConfig) checkBootstrap() (err error) { + if req.Bootstraps == nil { + return nil + } + + var b string + defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }() + + for _, b = range *req.Bootstraps { + if b == "" { + return errors.Error("empty") + } + + var resolver *upstream.UpstreamResolver + if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil { + // Don't wrap the error because it's informative enough as is. + return err + } + + if err = resolver.Close(); err != nil { + return fmt.Errorf("closing %s: %w", b, err) + } + } + + return nil +} + +// checkPrivateRDNS returns an error if the configuration of the private RDNS is +// not valid. +func (req *jsonDNSConfig) checkPrivateRDNS( + ownAddrs addrPortSet, + sysResolvers SystemResolvers, + privateNets netutil.SubnetSet, +) (err error) { + if (req.UsePrivateRDNS == nil || !*req.UsePrivateRDNS) && req.LocalPTRUpstreams == nil { + return nil + } + + addrs := cmp.Or(req.LocalPTRUpstreams, &[]string{}) + + uc, err := newPrivateConfig(*addrs, ownAddrs, sysResolvers, privateNets, &upstream.Options{}) + err = errors.WithDeferred(err, uc.Close()) + if err != nil { + return fmt.Errorf("private upstream servers: %w", err) + } + + return nil +} + // validateUpstreamDNSServers returns an error if any field of req is invalid. -func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) { +func (req *jsonDNSConfig) validateUpstreamDNSServers( + ownAddrs addrPortSet, + sysResolvers SystemResolvers, + privateNets netutil.SubnetSet, +) (err error) { + var uc *proxy.UpstreamConfig + opts := &upstream.Options{} + if req.Upstreams != nil { - _, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{}) + uc, err = proxy.ParseUpstreamsConfig(*req.Upstreams, opts) + err = errors.WithDeferred(err, uc.Close()) if err != nil { return fmt.Errorf("upstream servers: %w", err) } } - if req.LocalPTRUpstreams != nil { - err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets) - if err != nil { - return fmt.Errorf("private upstream servers: %w", err) - } + err = req.checkPrivateRDNS(ownAddrs, sysResolvers, privateNets) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err } err = req.checkBootstrap() @@ -364,10 +384,12 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS return err } - err = req.checkFallbacks() - if err != nil { - // Don't wrap the error since it's informative enough as is. - return err + if req.Fallbacks != nil { + uc, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, opts) + err = errors.WithDeferred(err, uc.Close()) + if err != nil { + return fmt.Errorf("fallback servers: %w", err) + } } return nil @@ -436,7 +458,16 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { return } - err = req.validate(s.privateNets) + // TODO(e.burkov): Consider prebuilding this set on startup. + ourAddrs, err := s.conf.ourAddrsSet() + if err != nil { + // TODO(e.burkov): Put into openapi. + aghhttp.Error(r, w, http.StatusInternalServerError, "getting our addresses: %s", err) + + return + } + + err = req.validate(ourAddrs, s.sysResolvers, s.privateNets) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) @@ -587,7 +618,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { } var boots []*upstream.UpstreamResolver - opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts) + opts.Bootstrap, boots, err = newBootstrap(req.BootstrapDNS, s.etcHosts, opts) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err) diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index b0145b23..56daa4bf 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -245,9 +245,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { wantSet: "", }, { name: "local_ptr_upstreams_bad", - wantSet: `validating dns config: ` + - `private upstream servers: checking domain-specific upstreams: ` + - `bad arpa domain name "non.arpa.": not a reversed ip network`, + wantSet: `validating dns config: private upstream servers: ` + + `bad arpa domain name "non.arpa": not a reversed ip network`, }, { name: "local_ptr_upstreams_null", wantSet: "", @@ -318,58 +317,6 @@ func TestIsCommentOrEmpty(t *testing.T) { } } -func TestValidateUpstreamsPrivate(t *testing.T) { - ss := netutil.SubnetSetFunc(netutil.IsLocallyServed) - - testCases := []struct { - name string - wantErr string - u string - }{{ - name: "success_address", - wantErr: ``, - u: "[/1.0.0.127.in-addr.arpa/]#", - }, { - name: "success_subnet", - wantErr: ``, - u: "[/127.in-addr.arpa/]#", - }, { - name: "not_arpa_subnet", - wantErr: `checking domain-specific upstreams: ` + - `bad arpa domain name "hello.world.": not a reversed ip network`, - u: "[/hello.world/]#", - }, { - name: "non-private_arpa_address", - wantErr: `checking domain-specific upstreams: ` + - `arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`, - u: "[/1.2.3.4.in-addr.arpa/]#", - }, { - name: "non-private_arpa_subnet", - wantErr: `checking domain-specific upstreams: ` + - `arpa domain "128.in-addr.arpa." should point to a locally-served network`, - u: "[/128.in-addr.arpa/]#", - }, { - name: "several_bad", - wantErr: `checking domain-specific upstreams: ` + - `arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network` + "\n" + - `bad arpa domain name "non.arpa.": not a reversed ip network`, - u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#", - }, { - name: "partial_good", - wantErr: "", - u: "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#", - }} - - for _, tc := range testCases { - set := []string{"192.168.0.1", tc.u} - - t.Run(tc.name, func(t *testing.T) { - err := ValidateUpstreamsPrivate(set, ss) - testutil.AssertErrorMsg(t, tc.wantErr, err) - }) - } -} - func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) { t.Helper() diff --git a/internal/dnsforward/msg.go b/internal/dnsforward/msg.go index 5068cd9a..f645ab90 100644 --- a/internal/dnsforward/msg.go +++ b/internal/dnsforward/msg.go @@ -11,17 +11,21 @@ import ( "github.com/miekg/dns" ) -// makeResponse creates a DNS response by req and sets necessary flags. It also -// guarantees that req.Question will be not empty. -func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) { - resp = &dns.Msg{ - MsgHdr: dns.MsgHdr{ - RecursionAvailable: true, - }, - Compress: true, - } +// TODO(e.burkov): Name all the methods by a [proxy.MessageConstructor] +// template. Also extract all the methods to a separate entity. - resp.SetReply(req) +// reply creates a DNS response for req. +func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) { + resp = (&dns.Msg{}).SetRcode(req, code) + resp.RecursionAvailable = true + + return resp +} + +// replyCompressed creates a DNS response for req and sets the compress flag. +func (s *Server) replyCompressed(req *dns.Msg) (resp *dns.Msg) { + resp = s.reply(req, dns.RcodeSuccess) + resp.Compress = true return resp } @@ -48,10 +52,10 @@ func (s *Server) genDNSFilterMessage( ) (resp *dns.Msg) { req := dctx.Req qt := req.Question[0].Qtype - if qt != dns.TypeA && qt != dns.TypeAAAA { + if qt != dns.TypeA && qt != dns.TypeAAAA && qt != dns.TypeHTTPS { m, _, _ := s.dnsFilter.BlockingMode() if m == filtering.BlockingModeNullIP { - return s.makeResponse(req) + return s.replyCompressed(req) } return s.newMsgNODATA(req) @@ -75,7 +79,7 @@ func (s *Server) genDNSFilterMessage( // getCNAMEWithIPs generates a filtered response to req for with CNAME record // and provided ips. func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) { - resp = s.makeResponse(req) + resp = s.replyCompressed(req) originalName := req.Question[0].Name @@ -121,13 +125,13 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M case filtering.BlockingModeNullIP: return s.makeResponseNullIP(req) case filtering.BlockingModeNXDOMAIN: - return s.genNXDomain(req) + return s.NewMsgNXDOMAIN(req) case filtering.BlockingModeREFUSED: return s.makeResponseREFUSED(req) default: log.Error("dnsforward: invalid blocking mode %q", mode) - return s.makeResponse(req) + return s.replyCompressed(req) } } @@ -148,25 +152,18 @@ func (s *Server) makeResponseCustomIP( // genDNSFilterMessage. log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt)) - return s.makeResponse(req) + return s.replyCompressed(req) } } -func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { - resp := dns.Msg{} - resp.SetRcode(request, dns.RcodeServerFailure) - resp.RecursionAvailable = true - return &resp -} - func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg { - resp := s.makeResponse(request) + resp := s.replyCompressed(request) resp.Answer = append(resp.Answer, s.genAnswerA(request, ip)) return resp } func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg { - resp := s.makeResponse(request) + resp := s.replyCompressed(request) resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip)) return resp } @@ -252,7 +249,7 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M // Go on and return an empty response. } - resp = s.makeResponse(req) + resp = s.replyCompressed(req) resp.Answer = ans return resp @@ -288,7 +285,7 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) { case dns.TypeAAAA: resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()}) default: - resp = s.makeResponse(req) + resp = s.replyCompressed(req) } return resp @@ -298,7 +295,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo if newAddr == "" { log.Info("dnsforward: block host is not specified") - return s.genServerFailure(request) + return s.NewMsgSERVFAIL(request) } ip, err := netip.ParseAddr(newAddr) @@ -321,17 +318,17 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo if prx == nil { log.Debug("dnsforward: %s", srvClosedErr) - return s.genServerFailure(request) + return s.NewMsgSERVFAIL(request) } err = prx.Resolve(newContext) if err != nil { log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err) - return s.genServerFailure(request) + return s.NewMsgSERVFAIL(request) } - resp := s.makeResponse(request) + resp := s.replyCompressed(request) if newContext.Res != nil { for _, answer := range newContext.Res.Answer { answer.Header().Name = request.Question[0].Name @@ -342,48 +339,21 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo return resp } -// preBlockedResponse returns a protocol-appropriate response for a request that -// was blocked by access settings. -func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (reply bool, err error) { - if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt { - // Return nil so that dnsproxy drops the connection and thus - // prevent DNS amplification attacks. - return false, nil - } - - pctx.Res = s.makeResponseREFUSED(pctx.Req) - - // Return true so that dnsproxy responds with the REFUSED message. - return true, nil -} - // Create REFUSED DNS response -func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg { - resp := dns.Msg{} - resp.SetRcode(request, dns.RcodeRefused) - resp.RecursionAvailable = true - return &resp +func (s *Server) makeResponseREFUSED(req *dns.Msg) *dns.Msg { + return s.reply(req, dns.RcodeRefused) } // newMsgNODATA returns a properly initialized NODATA response. // // See https://www.rfc-editor.org/rfc/rfc2308#section-2.2. func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) { - resp = (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess) - resp.RecursionAvailable = true + resp = s.reply(req, dns.RcodeSuccess) resp.Ns = s.genSOA(req) return resp } -func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg { - resp := dns.Msg{} - resp.SetRcode(request, dns.RcodeNameError) - resp.RecursionAvailable = true - resp.Ns = s.genSOA(request) - return &resp -} - func (s *Server) genSOA(request *dns.Msg) []dns.RR { zone := "" if len(request.Question) > 0 { @@ -415,5 +385,43 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR { if len(zone) > 0 && zone[0] != '.' { soa.Mbox += zone } + return []dns.RR{&soa} } + +// type check +var _ proxy.MessageConstructor = (*Server)(nil) + +// NewMsgNXDOMAIN implements the [proxy.MessageConstructor] interface for +// *Server. +func (s *Server) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) { + resp = s.reply(req, dns.RcodeNameError) + resp.Ns = s.genSOA(req) + + return resp +} + +// NewMsgSERVFAIL implements the [proxy.MessageConstructor] interface for +// *Server. +func (s *Server) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) { + return s.reply(req, dns.RcodeServerFailure) +} + +// NewMsgNOTIMPLEMENTED implements the [proxy.MessageConstructor] interface for +// *Server. +func (s *Server) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) { + resp = s.reply(req, dns.RcodeNotImplemented) + + // Most of the Internet and especially the inner core has an MTU of at least + // 1500 octets. Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet + // is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)). + // + // See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17. + const maxUDPPayload = 1452 + + // NOTIMPLEMENTED without EDNS is treated as 'we don't support EDNS', so + // explicitly set it. + resp.SetEdns0(maxUDPPayload, false) + + return resp +} diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go index 8cfe923a..8c66ccf9 100644 --- a/internal/dnsforward/process.go +++ b/internal/dnsforward/process.go @@ -1,20 +1,17 @@ package dnsforward import ( + "cmp" "encoding/binary" "net" "net/netip" - "strconv" "strings" "time" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" - "github.com/AdguardTeam/golibs/stringutil" "github.com/miekg/dns" ) @@ -34,11 +31,6 @@ type dnsContext struct { // response is modified by filters. origResp *dns.Msg - // unreversedReqIP stores an IP address obtained from a PTR request if it - // was parsed successfully and belongs to one of the locally served IP - // ranges. - unreversedReqIP netip.Addr - // err is the error returned from a processing function. err error @@ -63,10 +55,6 @@ type dnsContext struct { // responseAD shows if the response had the AD bit set. responseAD bool - // isLocalClient shows if client's IP address is from locally served - // network. - isLocalClient bool - // isDHCPHost is true if the request for a local domain name and the DHCP is // available for this request. isDHCPHost bool @@ -109,15 +97,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error // (*proxy.Proxy).handleDNSRequest method performs it before calling the // appropriate handler. mods := []modProcessFunc{ - s.processRecursion, s.processInitial, s.processDDRQuery, - s.processDetermineLocal, s.processDHCPHosts, - s.processRestrictLocal, s.processDHCPAddrs, s.processFilteringBeforeRequest, - s.processLocalPTR, s.processUpstream, s.processFilteringAfterResponse, s.ipset.process, @@ -145,24 +129,6 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error return nil } -// processRecursion checks the incoming request and halts its handling by -// answering NXDOMAIN if s has tried to resolve it recently. -func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing recursion") - defer log.Debug("dnsforward: finished processing recursion") - - pctx := dctx.proxyCtx - - if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) { - log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name) - pctx.Res = s.genNXDomain(pctx.Req) - - return resultCodeFinish - } - - return resultCodeSuccess -} - // mozillaFQDN is the domain used to signal the Firefox browser to not use its // own DoH server. // @@ -199,14 +165,14 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) { } if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN { - pctx.Res = s.genNXDomain(pctx.Req) + pctx.Res = s.NewMsgNXDOMAIN(pctx.Req) return resultCodeFinish } if q.Name == healthcheckFQDN { // Generate a NODATA negative response to make nslookup exit with 0. - pctx.Res = s.makeResponse(pctx.Req) + pctx.Res = s.replyCompressed(pctx.Req) return resultCodeFinish } @@ -272,7 +238,7 @@ func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) { // // [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html. func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) { - resp = s.makeResponse(req) + resp = s.replyCompressed(req) if req.Question[0].Qtype != dns.TypeSVCB { return resp } @@ -339,19 +305,6 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) { return resp } -// processDetermineLocal determines if the client's IP address is from locally -// served network and saves the result into the context. -func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing local detection") - defer log.Debug("dnsforward: finished processing local detection") - - rc = resultCodeSuccess - - dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr()) - - return rc -} - // processDHCPHosts respond to A requests if the target hostname is known to // the server. It responds with a mapped IP address if the DNS64 is enabled and // the request is for AAAA. @@ -370,9 +323,9 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - if !dctx.isLocalClient { + if !pctx.IsPrivateClient { log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost) - pctx.Res = s.genNXDomain(req) + pctx.Res = s.NewMsgNXDOMAIN(req) // Do not even put into query log. return resultCodeFinish @@ -389,7 +342,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip) - resp := s.makeResponse(req) + resp := s.replyCompressed(req) switch q.Qtype { case dns.TypeA: a := &dns.A{ @@ -416,141 +369,6 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } -// indexFirstV4Label returns the index at which the reversed IPv4 address -// starts, assuming the domain is pre-validated ARPA domain having in-addr and -// arpa labels removed. -func indexFirstV4Label(domain string) (idx int) { - idx = len(domain) - for labelsNum := 0; labelsNum < net.IPv4len && idx > 0; labelsNum++ { - curIdx := strings.LastIndexByte(domain[:idx-1], '.') + 1 - _, parseErr := strconv.ParseUint(domain[curIdx:idx-1], 10, 8) - if parseErr != nil { - return idx - } - - idx = curIdx - } - - return idx -} - -// indexFirstV6Label returns the index at which the reversed IPv6 address -// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa -// labels removed. -func indexFirstV6Label(domain string) (idx int) { - idx = len(domain) - for labelsNum := 0; labelsNum < net.IPv6len*2 && idx > 0; labelsNum++ { - curIdx := idx - len("a.") - if curIdx > 1 && domain[curIdx-1] != '.' { - return idx - } - - nibble := domain[curIdx] - if (nibble < '0' || nibble > '9') && (nibble < 'a' || nibble > 'f') { - return idx - } - - idx = curIdx - } - - return idx -} - -// extractARPASubnet tries to convert a reversed ARPA address being a part of -// domain to an IP network. domain must be an FQDN. -// -// TODO(e.burkov): Move to golibs. -func extractARPASubnet(domain string) (pref netip.Prefix, err error) { - err = netutil.ValidateDomainName(strings.TrimSuffix(domain, ".")) - if err != nil { - // Don't wrap the error since it's informative enough as is. - return netip.Prefix{}, err - } - - const ( - v4Suffix = "in-addr.arpa." - v6Suffix = "ip6.arpa." - ) - - domain = strings.ToLower(domain) - - var idx int - switch { - case strings.HasSuffix(domain, v4Suffix): - idx = indexFirstV4Label(domain[:len(domain)-len(v4Suffix)]) - case strings.HasSuffix(domain, v6Suffix): - idx = indexFirstV6Label(domain[:len(domain)-len(v6Suffix)]) - default: - return netip.Prefix{}, &netutil.AddrError{ - Err: netutil.ErrNotAReversedSubnet, - Kind: netutil.AddrKindARPA, - Addr: domain, - } - } - - return netutil.PrefixFromReversedAddr(domain[idx:]) -} - -// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses -// in locally served network from external clients. -func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing local restriction") - defer log.Debug("dnsforward: finished processing local restriction") - - pctx := dctx.proxyCtx - req := pctx.Req - q := req.Question[0] - if q.Qtype != dns.TypePTR { - // No need for restriction. - return resultCodeSuccess - } - - subnet, err := extractARPASubnet(q.Name) - if err != nil { - if errors.Is(err, netutil.ErrNotAReversedSubnet) { - log.Debug("dnsforward: request is not for arpa domain") - - return resultCodeSuccess - } - - log.Debug("dnsforward: parsing reversed addr: %s", err) - - return resultCodeError - } - - // Restrict an access to local addresses for external clients. We also - // assume that all the DHCP leases we give are locally served or at least - // shouldn't be accessible externally. - subnetAddr := subnet.Addr() - if !s.privateNets.Contains(subnetAddr) { - return resultCodeSuccess - } - - log.Debug("dnsforward: addr %s is from locally served network", subnetAddr) - - if !dctx.isLocalClient { - log.Debug("dnsforward: %q requests an internal ip", pctx.Addr) - pctx.Res = s.genNXDomain(req) - - // Do not even put into query log. - return resultCodeFinish - } - - // Do not perform unreversing ever again. - dctx.unreversedReqIP = subnetAddr - - // There is no need to filter request from external addresses since this - // code is only executed when the request is for locally served ARPA - // hostname so disable redundant filters. - dctx.setts.ParentalEnabled = false - dctx.setts.SafeBrowsingEnabled = false - dctx.setts.SafeSearchEnabled = false - dctx.setts.ServicesRules = nil - - // Nothing to restrict. - return resultCodeSuccess -} - // processDHCPAddrs responds to PTR requests if the target IP is leased by the // DHCP server. func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { @@ -562,23 +380,27 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - ipAddr := dctx.unreversedReqIP - if ipAddr == (netip.Addr{}) { + req := pctx.Req + q := req.Question[0] + pref := pctx.RequestedPrivateRDNS + // TODO(e.burkov): Consider answering authoritatively for SOA and NS + // queries. + if pref == (netip.Prefix{}) || q.Qtype != dns.TypePTR { return resultCodeSuccess } - host := s.dhcpServer.HostByIP(ipAddr) + addr := pref.Addr() + host := s.dhcpServer.HostByIP(addr) if host == "" { return resultCodeSuccess } - log.Debug("dnsforward: dhcp client %s is %q", ipAddr, host) + log.Debug("dnsforward: dhcp client %s is %q", addr, host) - req := pctx.Req - resp := s.makeResponse(req) + resp := s.replyCompressed(req) ptr := &dns.PTR{ Hdr: dns.RR_Header{ - Name: req.Question[0].Name, + Name: q.Name, Rrtype: dns.TypePTR, // TODO(e.burkov): Use [dhcpsvc.Lease.Expiry]. See // https://github.com/AdguardTeam/AdGuardHome/issues/3932. @@ -593,62 +415,20 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } -// processLocalPTR responds to PTR requests if the target IP is detected to be -// inside the local network and the query was not answered from DHCP. -func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing local ptr") - defer log.Debug("dnsforward: finished processing local ptr") - - pctx := dctx.proxyCtx - if pctx.Res != nil { - return resultCodeSuccess - } - - ip := dctx.unreversedReqIP - if ip == (netip.Addr{}) { - return resultCodeSuccess - } - - s.serverLock.RLock() - defer s.serverLock.RUnlock() - - if s.conf.UsePrivateRDNS { - s.recDetector.add(*pctx.Req) - if err := s.localResolvers.Resolve(pctx); err != nil { - log.Debug("dnsforward: resolving private address: %s", err) - - // Generate the server failure if the private upstream configuration - // is empty. - // - // This is a crutch, see TODO at [Server.localResolvers]. - if errors.Is(err, upstream.ErrNoUpstreams) { - pctx.Res = s.genServerFailure(pctx.Req) - - // Do not even put into query log. - return resultCodeFinish - } - - dctx.err = err - - return resultCodeError - } - } - - if pctx.Res == nil { - pctx.Res = s.genNXDomain(pctx.Req) - - // Do not even put into query log. - return resultCodeFinish - } - - return resultCodeSuccess -} - // Apply filtering logic func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) { log.Debug("dnsforward: started processing filtering before req") defer log.Debug("dnsforward: finished processing filtering before req") + if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) { + // There is no need to filter request for locally served ARPA hostname + // so disable redundant filters. + dctx.setts.ParentalEnabled = false + dctx.setts.SafeBrowsingEnabled = false + dctx.setts.SafeSearchEnabled = false + dctx.setts.ServicesRules = nil + } + if dctx.proxyCtx.Res != nil { // Go on since the response is already set. return resultCodeSuccess @@ -695,7 +475,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) { // local domain name if there is one. name := req.Question[0].Name log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1]) - pctx.Res = s.genNXDomain(req) + pctx.Res = s.NewMsgNXDOMAIN(req) return resultCodeFinish } @@ -712,21 +492,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) { return resultCodeError } - if err := prx.Resolve(pctx); err != nil { - if errors.Is(err, upstream.ErrNoUpstreams) { - // Do not even put into querylog. Currently this happens either - // when the private resolvers enabled and the request is DNS64 PTR, - // or when the client isn't considered local by prx. - // - // TODO(e.burkov): Make proxy detect local client the same way as - // AGH does. - pctx.Res = s.genNXDomain(req) - - return resultCodeFinish - } - - dctx.err = err - + if dctx.err = prx.Resolve(pctx); dctx.err != nil { return resultCodeError } @@ -810,7 +576,7 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) { } // Use the ClientID first, since it has a higher priority. - id := stringutil.Coalesce(clientID, pctx.Addr.Addr().String()) + id := cmp.Or(clientID, pctx.Addr.Addr().String()) upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap) if err != nil { log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err) @@ -835,7 +601,8 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode) return resultCodeSuccess case filtering.Rewritten, - filtering.RewrittenRule: + filtering.RewrittenRule, + filtering.FilteredSafeSearch: if dctx.origQuestion.Name == "" { // origQuestion is set in case we get only CNAME without IP from @@ -845,11 +612,10 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode) pctx := dctx.proxyCtx pctx.Req.Question[0], pctx.Res.Question[0] = dctx.origQuestion, dctx.origQuestion - if len(pctx.Res.Answer) > 0 { - rr := s.genAnswerCNAME(pctx.Req, res.CanonName) - answer := append([]dns.RR{rr}, pctx.Res.Answer...) - pctx.Res.Answer = answer - } + + rr := s.genAnswerCNAME(pctx.Req, res.CanonName) + answer := append([]dns.RR{rr}, pctx.Res.Answer...) + pctx.Res.Answer = answer return resultCodeSuccess default: diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index 5dc4e21b..e47027a5 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -1,14 +1,15 @@ package dnsforward import ( + "cmp" "net" "net/netip" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" @@ -70,8 +71,6 @@ func TestServer_ProcessInitial(t *testing.T) { }} for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { t.Parallel() @@ -171,8 +170,6 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) { }} for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { t.Parallel() @@ -379,44 +376,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) { return f } -func TestServer_ProcessDetermineLocal(t *testing.T) { - s := &Server{ - privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), - } - - testCases := []struct { - want assert.BoolAssertionFunc - name string - cliAddr netip.AddrPort - }{{ - want: assert.True, - name: "local", - cliAddr: netip.MustParseAddrPort("192.168.0.1:1"), - }, { - want: assert.False, - name: "external", - cliAddr: netip.MustParseAddrPort("250.249.0.1:1"), - }, { - want: assert.False, - name: "invalid", - cliAddr: netip.AddrPort{}, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - proxyCtx := &proxy.DNSContext{ - Addr: tc.cliAddr, - } - dctx := &dnsContext{ - proxyCtx: proxyCtx, - } - s.processDetermineLocal(dctx) - - tc.want(t, dctx.isLocalClient) - }) - } -} - func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { const ( localDomainSuffix = "lan" @@ -486,9 +445,9 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { dctx := &dnsContext{ proxyCtx: &proxy.DNSContext{ - Req: req, + Req: req, + IsPrivateClient: tc.isLocalCli, }, - isLocalClient: tc.isLocalCli, } res := s.processDHCPHosts(dctx) @@ -621,9 +580,9 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { dctx := &dnsContext{ proxyCtx: &proxy.DNSContext{ - Req: req, + Req: req, + IsPrivateClient: true, }, - isLocalClient: true, } t.Run(tc.name, func(t *testing.T) { @@ -658,19 +617,28 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { } } -func TestServer_ProcessRestrictLocal(t *testing.T) { +// TODO(e.burkov): Rewrite this test to use the whole server instead of just +// testing the [handleDNSRequest] method. See comment on +// "from_external_for_local" test case. +func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) { + intAddr := netip.MustParseAddr("192.168.1.1") + intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice()) + require.NoError(t, err) + + extAddr := netip.MustParseAddr("254.253.252.1") + extPTRQuestion, err := netutil.IPToReversedAddr(extAddr.AsSlice()) + require.NoError(t, err) + const ( - extPTRQuestion = "251.252.253.254.in-addr.arpa." - extPTRAnswer = "host1.example.net." - intPTRQuestion = "1.1.168.192.in-addr.arpa." - intPTRAnswer = "some.local-client." + extPTRAnswer = "host1.example.net." + intPTRAnswer = "some.local-client." ) localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { - resp := aghalg.Coalesce( + resp := cmp.Or( aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer), - new(dns.Msg).SetRcode(req, dns.RcodeNameError), + (&dns.Msg{}).SetRcode(req, dns.RcodeNameError), ) require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) @@ -696,123 +664,165 @@ func TestServer_ProcessRestrictLocal(t *testing.T) { startDeferStop(t, s) testCases := []struct { - name string - want string - question net.IP - cliAddr netip.AddrPort - wantLen int + name string + question string + wantErr error + wantAns []dns.RR + isPrivate bool }{{ - name: "from_local_to_external", - want: "host1.example.net.", - question: net.IP{254, 253, 252, 251}, - cliAddr: netip.MustParseAddrPort("192.168.10.10:1"), - wantLen: 1, + name: "from_local_for_external", + question: extPTRQuestion, + wantErr: nil, + wantAns: []dns.RR{&dns.PTR{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(extPTRQuestion), + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 60, + Rdlength: uint16(len(extPTRAnswer) + 1), + }, + Ptr: dns.Fqdn(extPTRAnswer), + }}, + isPrivate: true, }, { - name: "from_external_for_local", - want: "", - question: net.IP{192, 168, 1, 1}, - cliAddr: netip.MustParseAddrPort("254.253.252.251:1"), - wantLen: 0, + // In theory this case is not reproducible because [proxy.Proxy] should + // respond to such queries with NXDOMAIN before they reach + // [Server.handleDNSRequest]. + name: "from_external_for_local", + question: intPTRQuestion, + wantErr: upstream.ErrNoUpstreams, + wantAns: nil, + isPrivate: false, }, { name: "from_local_for_local", - want: "some.local-client.", - question: net.IP{192, 168, 1, 1}, - cliAddr: netip.MustParseAddrPort("192.168.1.2:1"), - wantLen: 1, + question: intPTRQuestion, + wantErr: nil, + wantAns: []dns.RR{&dns.PTR{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(intPTRQuestion), + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 60, + Rdlength: uint16(len(intPTRAnswer) + 1), + }, + Ptr: dns.Fqdn(intPTRAnswer), + }}, + isPrivate: true, }, { name: "from_external_for_external", - want: "host1.example.net.", - question: net.IP{254, 253, 252, 251}, - cliAddr: netip.MustParseAddrPort("254.253.252.255:1"), - wantLen: 1, + question: extPTRQuestion, + wantErr: nil, + wantAns: []dns.RR{&dns.PTR{ + Hdr: dns.RR_Header{ + Name: dns.Fqdn(extPTRQuestion), + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: 60, + Rdlength: uint16(len(extPTRAnswer) + 1), + }, + Ptr: dns.Fqdn(extPTRAnswer), + }}, + isPrivate: false, }} for _, tc := range testCases { - reqAddr, err := dns.ReverseAddr(tc.question.String()) - require.NoError(t, err) - req := createTestMessageWithType(reqAddr, dns.TypePTR) + pref, extErr := netutil.ExtractReversedAddr(tc.question) + require.NoError(t, extErr) + req := createTestMessageWithType(dns.Fqdn(tc.question), dns.TypePTR) pctx := &proxy.DNSContext{ - Proto: proxy.ProtoTCP, - Req: req, - Addr: tc.cliAddr, + Req: req, + IsPrivateClient: tc.isPrivate, + } + // TODO(e.burkov): Configure the subnet set properly. + if netutil.IsLocallyServed(pref.Addr()) { + pctx.RequestedPrivateRDNS = pref } t.Run(tc.name, func(t *testing.T) { - err = s.handleDNSRequest(nil, pctx) - require.NoError(t, err) - require.NotNil(t, pctx.Res) - require.Len(t, pctx.Res.Answer, tc.wantLen) + err = s.handleDNSRequest(s.dnsProxy, pctx) + require.ErrorIs(t, err, tc.wantErr) - if tc.wantLen > 0 { - assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr) - } + require.NotNil(t, pctx.Res) + assert.Equal(t, tc.wantAns, pctx.Res.Answer) }) } } -func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { +func TestServer_ProcessUpstream_localPTR(t *testing.T) { const locDomain = "some.local." const reqAddr = "1.1.168.192.in-addr.arpa." localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { - resp := aghalg.Coalesce( + resp := cmp.Or( aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain), - new(dns.Msg).SetRcode(req, dns.RcodeNameError), + (&dns.Msg{}).SetRcode(req, dns.RcodeNameError), ) require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) }) localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() - s := createTestServer( - t, - &filtering.Config{ - BlockingMode: filtering.BlockingModeDefault, - }, - ServerConfig{ - UDPListenAddrs: []*net.UDPAddr{{}}, - TCPListenAddrs: []*net.TCPAddr{{}}, - Config: Config{ - UpstreamMode: UpstreamModeLoadBalance, - EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, - }, - UsePrivateRDNS: true, - LocalPTRResolvers: []string{localUpsAddr}, - ServePlainDNS: true, - }, - ) - - var proxyCtx *proxy.DNSContext - var dnsCtx *dnsContext - setup := func(use bool) { - proxyCtx = &proxy.DNSContext{ - Addr: testClientAddrPort, - Req: createTestMessageWithType(reqAddr, dns.TypePTR), + newPrxCtx := func() (prxCtx *proxy.DNSContext) { + return &proxy.DNSContext{ + Addr: testClientAddrPort, + Req: createTestMessageWithType(reqAddr, dns.TypePTR), + IsPrivateClient: true, + RequestedPrivateRDNS: netip.MustParsePrefix("192.168.1.1/32"), } - dnsCtx = &dnsContext{ - proxyCtx: proxyCtx, - unreversedReqIP: netip.MustParseAddr("192.168.1.1"), - } - s.conf.UsePrivateRDNS = use } t.Run("enabled", func(t *testing.T) { - setup(true) + s := createTestServer( + t, + &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, + ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, + Config: Config{ + UpstreamMode: UpstreamModeLoadBalance, + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + }, + UsePrivateRDNS: true, + LocalPTRResolvers: []string{localUpsAddr}, + ServePlainDNS: true, + }, + ) + pctx := newPrxCtx() - rc := s.processLocalPTR(dnsCtx) + rc := s.processUpstream(&dnsContext{proxyCtx: pctx}) require.Equal(t, resultCodeSuccess, rc) - require.NotEmpty(t, proxyCtx.Res.Answer) + require.NotEmpty(t, pctx.Res.Answer) + ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0]) - assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr) + assert.Equal(t, locDomain, ptr.Ptr) }) t.Run("disabled", func(t *testing.T) { - setup(false) + s := createTestServer( + t, + &filtering.Config{ + BlockingMode: filtering.BlockingModeDefault, + }, + ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, + Config: Config{ + UpstreamMode: UpstreamModeLoadBalance, + EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, + }, + UsePrivateRDNS: false, + LocalPTRResolvers: []string{localUpsAddr}, + ServePlainDNS: true, + }, + ) + pctx := newPrxCtx() - rc := s.processLocalPTR(dnsCtx) - require.Equal(t, resultCodeFinish, rc) - require.Empty(t, proxyCtx.Res.Answer) + rc := s.processUpstream(&dnsContext{proxyCtx: pctx}) + require.Equal(t, resultCodeError, rc) + require.Empty(t, pctx.Res.Answer) }) } @@ -830,129 +840,3 @@ func TestIPStringFromAddr(t *testing.T) { assert.Empty(t, ipStringFromAddr(nil)) }) } - -// TODO(e.burkov): Add fuzzing when moving to golibs. -func TestExtractARPASubnet(t *testing.T) { - const ( - v4Suf = `in-addr.arpa.` - v4Part = `2.1.` + v4Suf - v4Whole = `4.3.` + v4Part - - v6Suf = `ip6.arpa.` - v6Part = `4.3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Suf - v6Whole = `f.e.d.c.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Part - ) - - v4Pref := netip.MustParsePrefix("1.2.3.4/32") - v4PrefPart := netip.MustParsePrefix("1.2.0.0/16") - v6Pref := netip.MustParsePrefix("::1234:0:0:0:cdef/128") - v6PrefPart := netip.MustParsePrefix("0:0:0:1234::/64") - - testCases := []struct { - want netip.Prefix - name string - domain string - wantErr string - }{{ - want: netip.Prefix{}, - name: "not_an_arpa", - domain: "some.domain.name.", - wantErr: `bad arpa domain name "some.domain.name.": ` + - `not a reversed ip network`, - }, { - want: netip.Prefix{}, - name: "bad_domain_name", - domain: "abc.123.", - wantErr: `bad domain name "abc.123": ` + - `bad top-level domain name label "123": all octets are numeric`, - }, { - want: v4Pref, - name: "whole_v4", - domain: v4Whole, - wantErr: "", - }, { - want: v4PrefPart, - name: "partial_v4", - domain: v4Part, - wantErr: "", - }, { - want: v4Pref, - name: "whole_v4_within_domain", - domain: "a." + v4Whole, - wantErr: "", - }, { - want: v4Pref, - name: "whole_v4_additional_label", - domain: "5." + v4Whole, - wantErr: "", - }, { - want: v4PrefPart, - name: "partial_v4_within_domain", - domain: "a." + v4Part, - wantErr: "", - }, { - want: v4PrefPart, - name: "overflow_v4", - domain: "256." + v4Part, - wantErr: "", - }, { - want: v4PrefPart, - name: "overflow_v4_within_domain", - domain: "a.256." + v4Part, - wantErr: "", - }, { - want: netip.Prefix{}, - name: "empty_v4", - domain: v4Suf, - wantErr: `bad arpa domain name "in-addr.arpa": ` + - `not a reversed ip network`, - }, { - want: netip.Prefix{}, - name: "empty_v4_within_domain", - domain: "a." + v4Suf, - wantErr: `bad arpa domain name "in-addr.arpa": ` + - `not a reversed ip network`, - }, { - want: v6Pref, - name: "whole_v6", - domain: v6Whole, - wantErr: "", - }, { - want: v6PrefPart, - name: "partial_v6", - domain: v6Part, - }, { - want: v6Pref, - name: "whole_v6_within_domain", - domain: "g." + v6Whole, - wantErr: "", - }, { - want: v6Pref, - name: "whole_v6_additional_label", - domain: "1." + v6Whole, - wantErr: "", - }, { - want: v6PrefPart, - name: "partial_v6_within_domain", - domain: "label." + v6Part, - wantErr: "", - }, { - want: netip.Prefix{}, - name: "empty_v6", - domain: v6Suf, - wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`, - }, { - want: netip.Prefix{}, - name: "empty_v6_within_domain", - domain: "g." + v6Suf, - wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - subnet, err := extractARPASubnet(tc.domain) - testutil.AssertErrorMsg(t, tc.wantErr, err) - assert.Equal(t, tc.want, subnet) - }) - } -} diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go index 220f151c..ffcbc6ef 100644 --- a/internal/dnsforward/stats.go +++ b/internal/dnsforward/stats.go @@ -29,7 +29,13 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) { log.Debug("dnsforward: client ip for stats and querylog: %s", ipStr) - ids := []string{ipStr, dctx.clientID} + ids := []string{ipStr} + if dctx.clientID != "" { + // Use the ClientID first because it has a higher priority. Filters + // have the same priority, see applyAdditionalFiltering. + ids = []string{dctx.clientID, ipStr} + } + qt, cl := q.Qtype, q.Qclass // Synchronize access to s.queryLog and s.stats so they won't be suddenly @@ -124,7 +130,7 @@ func (s *Server) logQuery(dctx *dnsContext, ip net.IP, processingTime time.Durat s.queryLog.Add(p) } -// updatesStats writes the request into statistics. +// updateStats writes the request data into statistics. func (s *Server) updateStats(dctx *dnsContext, clientIP string, processingTime time.Duration) { pctx := dctx.proxyCtx diff --git a/internal/dnsforward/upstreams.go b/internal/dnsforward/upstreams.go index c7f6677e..0754daae 100644 --- a/internal/dnsforward/upstreams.go +++ b/internal/dnsforward/upstreams.go @@ -2,90 +2,77 @@ package dnsforward import ( "fmt" - "net/netip" - "os" "slices" "time" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" - "golang.org/x/exp/maps" ) -// loadUpstreams parses upstream DNS servers from the configured file or from -// the configuration itself. -func (s *Server) loadUpstreams() (upstreams []string, err error) { - if s.conf.UpstreamDNSFileName == "" { - return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil +// newBootstrap returns a bootstrap resolver based on the configuration of s. +// boots are the upstream resolvers that should be closed after use. r is the +// actual bootstrap resolver, which may include the system hosts. +// +// TODO(e.burkov): This function currently returns a resolver and a slice of +// the upstream resolvers, which are essentially the same. boots are returned +// for being able to close them afterwards, but it introduces an implicit +// contract that r could only be used before that. Anyway, this code should +// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver] +// and be used here. +func newBootstrap( + addrs []string, + etcHosts upstream.Resolver, + opts *upstream.Options, +) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) { + if len(addrs) == 0 { + addrs = defaultBootstrap } - var data []byte - data, err = os.ReadFile(s.conf.UpstreamDNSFileName) + boots, err = aghnet.ParseBootstraps(addrs, opts) if err != nil { - return nil, fmt.Errorf("reading upstream from file: %w", err) + // Don't wrap the error, since it's informative enough as is. + return nil, nil, err } - upstreams = stringutil.SplitTrimmed(string(data), "\n") + var parallel upstream.ParallelResolver + for _, b := range boots { + parallel = append(parallel, upstream.NewCachingResolver(b)) + } - log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName) + if etcHosts != nil { + r = upstream.ConsequentResolver{etcHosts, parallel} + } else { + r = parallel + } - return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil + return r, boots, nil } -// prepareUpstreamSettings sets upstream DNS server settings. -func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) { - // Load upstreams either from the file, or from the settings - var upstreams []string - upstreams, err = s.loadUpstreams() - if err != nil { - return fmt.Errorf("loading upstreams: %w", err) - } - - s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{ - Bootstrap: boot, - Timeout: s.conf.UpstreamTimeout, - HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), - PreferIPv6: s.conf.BootstrapPreferIPv6, - // Use a customized set of RootCAs, because Go's default mechanism of - // loading TLS roots does not always work properly on some routers so we're - // loading roots manually and pass it here. - // - // See [aghtls.SystemRootCAs]. - // - // TODO(a.garipov): Investigate if that's true. - RootCAs: s.conf.TLSv12Roots, - CipherSuites: s.conf.TLSCiphers, - }) - if err != nil { - return fmt.Errorf("preparing upstream config: %w", err) - } - - return nil -} - -// prepareUpstreamConfig returns the upstream configuration based on upstreams -// and configuration of s. -func (s *Server) prepareUpstreamConfig( +// newUpstreamConfig returns the upstream configuration based on upstreams. If +// upstreams slice specifies no default upstreams, defaultUpstreams are used to +// create upstreams with no domain specifications. opts are used when creating +// upstream configuration. +func newUpstreamConfig( upstreams []string, defaultUpstreams []string, opts *upstream.Options, ) (uc *proxy.UpstreamConfig, err error) { uc, err = proxy.ParseUpstreamsConfig(upstreams, opts) if err != nil { - return nil, fmt.Errorf("parsing upstream config: %w", err) + return uc, fmt.Errorf("parsing upstreams: %w", err) } - if len(uc.Upstreams) == 0 && defaultUpstreams != nil { + if len(uc.Upstreams) == 0 && len(defaultUpstreams) > 0 { log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams) + var defaultUpstreamConfig *proxy.UpstreamConfig defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts) if err != nil { - return nil, fmt.Errorf("parsing default upstreams: %w", err) + return uc, fmt.Errorf("parsing default upstreams: %w", err) } uc.Upstreams = defaultUpstreamConfig.Upstreams @@ -94,6 +81,54 @@ func (s *Server) prepareUpstreamConfig( return uc, nil } +// newPrivateConfig creates an upstream configuration for resolving PTR records +// for local addresses. The configuration is built either from the provided +// addresses or from the system resolvers. unwanted filters the resulting +// upstream configuration. +func newPrivateConfig( + addrs []string, + unwanted addrPortSet, + sysResolvers SystemResolvers, + privateNets netutil.SubnetSet, + opts *upstream.Options, +) (uc *proxy.UpstreamConfig, err error) { + confNeedsFiltering := len(addrs) > 0 + if confNeedsFiltering { + addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty) + } else { + sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has) + addrs = make([]string, 0, len(sysResolvers)) + for _, r := range sysResolvers { + addrs = append(addrs, r.String()) + } + } + + log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", addrs) + + uc, err = proxy.ParseUpstreamsConfig(addrs, opts) + if err != nil { + return uc, fmt.Errorf("preparing private upstreams: %w", err) + } + + if !confNeedsFiltering { + return uc, nil + } + + err = filterOutAddrs(uc, unwanted) + if err != nil { + return uc, fmt.Errorf("filtering private upstreams: %w", err) + } + + // Prevalidate the config to catch the exact error before creating proxy. + // See TODO on [PrivateRDNSError]. + err = proxy.ValidatePrivateConfig(uc, privateNets) + if err != nil { + return uc, &PrivateRDNSError{err: err} + } + + return uc, nil +} + // UpstreamHTTPVersions returns the HTTP versions for upstream configuration // depending on configuration. func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) { @@ -130,85 +165,9 @@ func setProxyUpstreamMode( return nil } -// createBootstrap returns a bootstrap resolver based on the configuration of s. -// boots are the upstream resolvers that should be closed after use. r is the -// actual bootstrap resolver, which may include the system hosts. -// -// TODO(e.burkov): This function currently returns a resolver and a slice of -// the upstream resolvers, which are essentially the same. boots are returned -// for being able to close them afterwards, but it introduces an implicit -// contract that r could only be used before that. Anyway, this code should -// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver] -// and be used here. -func (s *Server) createBootstrap( - addrs []string, - opts *upstream.Options, -) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) { - if len(addrs) == 0 { - addrs = defaultBootstrap - } - - boots, err = aghnet.ParseBootstraps(addrs, opts) - if err != nil { - // Don't wrap the error, since it's informative enough as is. - return nil, nil, err - } - - var parallel upstream.ParallelResolver - for _, b := range boots { - parallel = append(parallel, upstream.NewCachingResolver(b)) - } - - if s.etcHosts != nil { - r = upstream.ConsequentResolver{s.etcHosts, parallel} - } else { - r = parallel - } - - return r, boots, nil -} - // IsCommentOrEmpty returns true if s starts with a "#" character or is empty. // This function is useful for filtering out non-upstream lines from upstream // configs. func IsCommentOrEmpty(s string) (ok bool) { return len(s) == 0 || s[0] == '#' } - -// ValidateUpstreamsPrivate validates each upstream and returns an error if any -// upstream is invalid or if there are no default upstreams specified. It also -// checks each domain of domain-specific upstreams for being ARPA pointing to -// a locally-served network. privateNets must not be nil. -func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) { - conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{}) - if err != nil { - return fmt.Errorf("creating config: %w", err) - } - - if conf == nil { - return nil - } - - keys := maps.Keys(conf.DomainReservedUpstreams) - slices.Sort(keys) - - var errs []error - for _, domain := range keys { - var subnet netip.Prefix - subnet, err = extractARPASubnet(domain) - if err != nil { - errs = append(errs, err) - - continue - } - - if !privateNets.Contains(subnet.Addr()) { - errs = append( - errs, - fmt.Errorf("arpa domain %q should point to a locally-served network", domain), - ) - } - } - - return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w") -} diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 4ea57eea..55404b74 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -559,6 +559,8 @@ type Result struct { Reason Reason `json:",omitempty"` // IsFiltered is true if the request is filtered. + // + // TODO(d.kolyshev): Get rid of this flag. IsFiltered bool `json:",omitempty"` } diff --git a/internal/filtering/filtering_test.go b/internal/filtering/filtering_test.go index 83018ab3..db625903 100644 --- a/internal/filtering/filtering_test.go +++ b/internal/filtering/filtering_test.go @@ -200,7 +200,7 @@ func TestParallelSB(t *testing.T) { t.Cleanup(d.Close) t.Run("group", func(t *testing.T) { - for i := 0; i < 100; i++ { + for i := range 100 { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Parallel() d.checkMatch(t, sbBlocked, setts) @@ -670,7 +670,7 @@ func BenchmarkSafeBrowsing(b *testing.B) { }, nil) b.Cleanup(d.Close) - for n := 0; n < b.N; n++ { + for range b.N { res, err := d.CheckHost(sbBlocked, dns.TypeA, setts) require.NoError(b, err) diff --git a/internal/filtering/idgenerator_internal_test.go b/internal/filtering/idgenerator_internal_test.go index 28dc5dea..57af4ad1 100644 --- a/internal/filtering/idgenerator_internal_test.go +++ b/internal/filtering/idgenerator_internal_test.go @@ -63,8 +63,6 @@ func TestIDGenerator_Fix(t *testing.T) { }} for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { g := newIDGenerator(1) g.fix(tc.in) diff --git a/internal/filtering/rulelist/engine_test.go b/internal/filtering/rulelist/engine_test.go index 81ab8bf8..9eeda15b 100644 --- a/internal/filtering/rulelist/engine_test.go +++ b/internal/filtering/rulelist/engine_test.go @@ -1,7 +1,6 @@ package rulelist_test import ( - "context" "net/http" "testing" @@ -28,14 +27,12 @@ func TestEngine_Refresh(t *testing.T) { require.NotNil(t, eng) testutil.CleanupAndRequireSuccess(t, eng.Close) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - t.Cleanup(cancel) - buf := make([]byte, rulelist.DefaultRuleBufSize) cli := &http.Client{ Timeout: testTimeout, } + ctx := testutil.ContextWithTimeout(t, testTimeout) err := eng.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize) require.NoError(t, err) diff --git a/internal/filtering/rulelist/filter_test.go b/internal/filtering/rulelist/filter_test.go index 05c1274c..21709583 100644 --- a/internal/filtering/rulelist/filter_test.go +++ b/internal/filtering/rulelist/filter_test.go @@ -1,7 +1,6 @@ package rulelist_test import ( - "context" "net/http" "net/url" "os" @@ -67,14 +66,12 @@ func TestFilter_Refresh(t *testing.T) { require.NotNil(t, f) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - t.Cleanup(cancel) - buf := make([]byte, rulelist.DefaultRuleBufSize) cli := &http.Client{ Timeout: testTimeout, } + ctx := testutil.ContextWithTimeout(t, testTimeout) res, err := f.Refresh(ctx, buf, cli, cacheDir, rulelist.DefaultMaxRuleListSize) require.NoError(t, err) diff --git a/internal/filtering/rulelist/parser_test.go b/internal/filtering/rulelist/parser_test.go index 45a8e465..f29c6288 100644 --- a/internal/filtering/rulelist/parser_test.go +++ b/internal/filtering/rulelist/parser_test.go @@ -132,7 +132,6 @@ func TestParser_Parse(t *testing.T) { }} for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() @@ -216,7 +215,7 @@ func BenchmarkParser_Parse(b *testing.B) { b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { resSink, errSink = p.Parse(dst, src, buf) dst.Reset() } diff --git a/internal/filtering/safesearch.go b/internal/filtering/safesearch.go index 003b9ee1..39c05140 100644 --- a/internal/filtering/safesearch.go +++ b/internal/filtering/safesearch.go @@ -1,7 +1,5 @@ package filtering -import "github.com/miekg/dns" - // SafeSearch interface describes a service for search engines hosts rewrites. type SafeSearch interface { // CheckHost checks host with safe search filter. CheckHost must be safe @@ -16,9 +14,6 @@ type SafeSearch interface { // SafeSearchConfig is a struct with safe search related settings. type SafeSearchConfig struct { - // CustomResolver is the resolver used by safe search. - CustomResolver Resolver `yaml:"-" json:"-"` - // Enabled indicates if safe search is enabled entirely. Enabled bool `yaml:"enabled" json:"enabled"` @@ -40,13 +35,7 @@ func (d *DNSFilter) checkSafeSearch( qtype uint16, setts *Settings, ) (res Result, err error) { - if !setts.ProtectionEnabled || - !setts.SafeSearchEnabled || - (qtype != dns.TypeA && qtype != dns.TypeAAAA) { - return Result{}, nil - } - - if d.safeSearch == nil { + if d.safeSearch == nil || !setts.ProtectionEnabled || !setts.SafeSearchEnabled { return Result{}, nil } diff --git a/internal/filtering/safesearch/safesearch.go b/internal/filtering/safesearch/safesearch.go index 50c0d187..7ea1e3ad 100644 --- a/internal/filtering/safesearch/safesearch.go +++ b/internal/filtering/safesearch/safesearch.go @@ -3,11 +3,9 @@ package safesearch import ( "bytes" - "context" "encoding/binary" "encoding/gob" "fmt" - "net" "net/netip" "strings" "sync" @@ -67,7 +65,6 @@ type Default struct { engine *urlfilter.DNSEngine cache cache.Cache - resolver filtering.Resolver logPrefix string cacheTTL time.Duration } @@ -80,11 +77,6 @@ func NewDefault( cacheSize uint, cacheTTL time.Duration, ) (ss *Default, err error) { - var resolver filtering.Resolver = net.DefaultResolver - if conf.CustomResolver != nil { - resolver = conf.CustomResolver - } - ss = &Default{ mu: &sync.RWMutex{}, @@ -92,7 +84,6 @@ func NewDefault( EnableLRU: true, MaxSize: cacheSize, }), - resolver: resolver, // Use %s, because the client safe-search names already contain double // quotes. logPrefix: fmt.Sprintf("safesearch %s: ", name), @@ -170,8 +161,11 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res ss.log(log.DEBUG, "lookup for %q finished in %s", host, time.Since(start)) }() - if qtype != dns.TypeA && qtype != dns.TypeAAAA { - return filtering.Result{}, fmt.Errorf("unsupported question type %s", dns.Type(qtype)) + switch qtype { + case dns.TypeA, dns.TypeAAAA, dns.TypeHTTPS: + // Go on. + default: + return filtering.Result{}, nil } // Check cache. Return cached result if it was found @@ -195,6 +189,9 @@ func (ss *Default) CheckHost(host string, qtype rules.RRType) (res filtering.Res } res = *fltRes + + // TODO(a.garipov): Consider switch back to resolving CNAME records IPs and + // saving results to cache. ss.setCacheResult(host, qtype, res) return res, nil @@ -223,20 +220,13 @@ func (ss *Default) searchHost(host string, qtype rules.RRType) (res *rules.DNSRe } // newResult creates Result object from rewrite rule. qtype must be either -// [dns.TypeA] or [dns.TypeAAAA]. If err is nil, res is never nil, so that the -// empty result is converted into a NODATA response. -// -// TODO(a.garipov): Use the main rewrite result mechanism used in -// [dnsforward.Server.filterDNSRequest]. Now we resolve IPs for CNAME to save -// them in the safe search cache. +// [dns.TypeA] or [dns.TypeAAAA], or [dns.TypeHTTPS]. If err is nil, res is +// never nil, so that the empty result is converted into a NODATA response. func (ss *Default) newResult( rewrite *rules.DNSRewrite, qtype rules.RRType, ) (res *filtering.Result, err error) { res = &filtering.Result{ - Rules: []*filtering.ResultRule{{ - FilterListID: rulelist.URLFilterIDSafeSearch, - }}, Reason: filtering.FilteredSafeSearch, IsFiltered: true, } @@ -247,69 +237,19 @@ func (ss *Default) newResult( return nil, fmt.Errorf("expected ip rewrite value, got %T(%[1]v)", rewrite.Value) } - res.Rules[0].IP = ip + res.Rules = []*filtering.ResultRule{{ + FilterListID: rulelist.URLFilterIDSafeSearch, + IP: ip, + }} return res, nil } - host := rewrite.NewCNAME - if host == "" { - return res, nil - } - - res.CanonName = host - - ss.log(log.DEBUG, "resolving %q", host) - - ips, err := ss.resolver.LookupIP(context.Background(), qtypeToProto(qtype), host) - if err != nil { - return nil, fmt.Errorf("resolving cname: %w", err) - } - - ss.log(log.DEBUG, "resolved %s", ips) - - for _, ip := range ips { - // TODO(a.garipov): Remove this filtering once the resolver we use - // actually learns about network. - addr := fitToProto(ip, qtype) - if addr == (netip.Addr{}) { - continue - } - - // TODO(e.burkov): Rules[0]? - res.Rules[0].IP = addr - } + res.CanonName = rewrite.NewCNAME return res, nil } -// qtypeToProto returns "ip4" for [dns.TypeA] and "ip6" for [dns.TypeAAAA]. -// It panics for other types. -func qtypeToProto(qtype rules.RRType) (proto string) { - switch qtype { - case dns.TypeA: - return "ip4" - case dns.TypeAAAA: - return "ip6" - default: - panic(fmt.Errorf("safesearch: unsupported question type %s", dns.Type(qtype))) - } -} - -// fitToProto returns a non-nil IP address if ip is the correct protocol version -// for qtype. qtype is expected to be either [dns.TypeA] or [dns.TypeAAAA]. -func fitToProto(ip net.IP, qtype rules.RRType) (res netip.Addr) { - if ip4 := ip.To4(); qtype == dns.TypeA { - if ip4 != nil { - return netip.AddrFrom4([4]byte(ip4)) - } - } else if ip = ip.To16(); ip != nil && qtype == dns.TypeAAAA { - return netip.AddrFrom16([16]byte(ip)) - } - - return netip.Addr{} -} - // setCacheResult stores data in cache for host. qtype is expected to be either // [dns.TypeA] or [dns.TypeAAAA]. func (ss *Default) setCacheResult(host string, qtype rules.RRType, res filtering.Result) { diff --git a/internal/filtering/safesearch/safesearch_internal_test.go b/internal/filtering/safesearch/safesearch_internal_test.go index ae4e380d..c6790dc3 100644 --- a/internal/filtering/safesearch/safesearch_internal_test.go +++ b/internal/filtering/safesearch/safesearch_internal_test.go @@ -1,13 +1,10 @@ package safesearch import ( - "context" - "net" "net/netip" "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" @@ -79,47 +76,6 @@ func TestSafeSearchCacheYandex(t *testing.T) { assert.Equal(t, cachedValue.Rules[0].IP, yandexIP) } -func TestSafeSearchCacheGoogle(t *testing.T) { - const domain = "www.google.ru" - - ss := newForTest(t, filtering.SafeSearchConfig{Enabled: false}) - - res, err := ss.CheckHost(domain, testQType) - require.NoError(t, err) - - assert.False(t, res.IsFiltered) - assert.Empty(t, res.Rules) - - resolver := &aghtest.Resolver{ - OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) { - ip4, ip6 := aghtest.HostToIPs(host) - - return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil - }, - } - - ss = newForTest(t, defaultSafeSearchConf) - ss.resolver = resolver - - // Lookup for safesearch domain. - rewrite := ss.searchHost(domain, testQType) - - wantIP, _ := aghtest.HostToIPs(rewrite.NewCNAME) - - res, err = ss.CheckHost(domain, testQType) - require.NoError(t, err) - require.Len(t, res.Rules, 1) - - assert.Equal(t, wantIP, res.Rules[0].IP) - - // Check cache. - cachedValue, isFound := ss.getCachedResult(domain, testQType) - require.True(t, isFound) - require.Len(t, cachedValue.Rules, 1) - - assert.Equal(t, wantIP, cachedValue.Rules[0].IP) -} - const googleHost = "www.google.com" var dnsRewriteSink *rules.DNSRewrite @@ -127,7 +83,7 @@ var dnsRewriteSink *rules.DNSRewrite func BenchmarkSafeSearch(b *testing.B) { ss := newForTest(b, defaultSafeSearchConf) - for n := 0; n < b.N; n++ { + for range b.N { dnsRewriteSink = ss.searchHost(googleHost, testQType) } diff --git a/internal/filtering/safesearch/safesearch_test.go b/internal/filtering/safesearch/safesearch_test.go index 127b2ae1..9526c791 100644 --- a/internal/filtering/safesearch/safesearch_test.go +++ b/internal/filtering/safesearch/safesearch_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" @@ -31,8 +30,6 @@ const ( // testConf is the default safe search configuration for tests. var testConf = filtering.SafeSearchConfig{ - CustomResolver: nil, - Enabled: true, Bing: true, @@ -52,61 +49,60 @@ func TestDefault_CheckHost_yandex(t *testing.T) { ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) require.NoError(t, err) - // Check host for each domain. - for _, host := range []string{ + hosts := []string{ "yandex.ru", "yAndeX.ru", "YANdex.COM", "yandex.by", "yandex.kz", "www.yandex.com", - } { - var res filtering.Result - res, err = ss.CheckHost(host, testQType) - require.NoError(t, err) - - assert.True(t, res.IsFiltered) - - require.Len(t, res.Rules, 1) - - assert.Equal(t, yandexIP, res.Rules[0].IP) - assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID) } -} -func TestDefault_CheckHost_yandexAAAA(t *testing.T) { - conf := testConf - ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) - require.NoError(t, err) + testCases := []struct { + want netip.Addr + name string + qt uint16 + }{{ + want: yandexIP, + name: "a", + qt: dns.TypeA, + }, { + want: netip.Addr{}, + name: "aaaa", + qt: dns.TypeAAAA, + }, { + want: netip.Addr{}, + name: "https", + qt: dns.TypeHTTPS, + }} - res, err := ss.CheckHost("www.yandex.ru", dns.TypeAAAA) - require.NoError(t, err) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, host := range hosts { + // Check host for each domain. + var res filtering.Result + res, err = ss.CheckHost(host, tc.qt) + require.NoError(t, err) - assert.True(t, res.IsFiltered) + assert.True(t, res.IsFiltered) + assert.Equal(t, filtering.FilteredSafeSearch, res.Reason) - // TODO(a.garipov): Currently, the safe-search filter returns a single rule - // with a nil IP address. This isn't really necessary and should be changed - // once the TODO in [safesearch.Default.newResult] is resolved. - require.Len(t, res.Rules, 1) + if tc.want == (netip.Addr{}) { + assert.Empty(t, res.Rules) + } else { + require.Len(t, res.Rules, 1) - assert.Empty(t, res.Rules[0].IP) - assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID) + rule := res.Rules[0] + assert.Equal(t, tc.want, rule.IP) + assert.Equal(t, rulelist.URLFilterIDSafeSearch, rule.FilterListID) + } + } + }) + } } func TestDefault_CheckHost_google(t *testing.T) { - resolver := &aghtest.Resolver{ - OnLookupIP: func(_ context.Context, _, host string) (ips []net.IP, err error) { - ip4, ip6 := aghtest.HostToIPs(host) - - return []net.IP{ip4.AsSlice(), ip6.AsSlice()}, nil - }, - } - - wantIP, _ := aghtest.HostToIPs("forcesafesearch.google.com") - - conf := testConf - conf.CustomResolver = resolver - ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) + ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL) require.NoError(t, err) // Check host for each domain. @@ -125,11 +121,9 @@ func TestDefault_CheckHost_google(t *testing.T) { require.NoError(t, err) assert.True(t, res.IsFiltered) - - require.Len(t, res.Rules, 1) - - assert.Equal(t, wantIP, res.Rules[0].IP) - assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID) + assert.Equal(t, filtering.FilteredSafeSearch, res.Reason) + assert.Equal(t, "forcesafesearch.google.com", res.CanonName) + assert.Empty(t, res.Rules) }) } } @@ -154,17 +148,7 @@ func (r *testResolver) LookupIP( } func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) { - conf := testConf - conf.CustomResolver = &testResolver{ - OnLookupIP: func(_ context.Context, network, host string) (ips []net.IP, err error) { - assert.Equal(t, "ip6", network) - assert.Equal(t, "safe.duckduckgo.com", host) - - return nil, nil - }, - } - - ss, err := safesearch.NewDefault(conf, "", testCacheSize, testCacheTTL) + ss, err := safesearch.NewDefault(testConf, "", testCacheSize, testCacheTTL) require.NoError(t, err) // The DuckDuckGo safe-search addresses are resolved through CNAMEs, but @@ -174,14 +158,9 @@ func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) { require.NoError(t, err) assert.True(t, res.IsFiltered) - - // TODO(a.garipov): Currently, the safe-search filter returns a single rule - // with a nil IP address. This isn't really necessary and should be changed - // once the TODO in [safesearch.Default.newResult] is resolved. - require.Len(t, res.Rules, 1) - - assert.Empty(t, res.Rules[0].IP) - assert.Equal(t, rulelist.URLFilterIDSafeSearch, res.Rules[0].FilterListID) + assert.Equal(t, filtering.FilteredSafeSearch, res.Reason) + assert.Equal(t, "safe.duckduckgo.com", res.CanonName) + assert.Empty(t, res.Rules) } func TestDefault_Update(t *testing.T) { diff --git a/internal/home/clients.go b/internal/home/clients.go index 2d5b1231..4f3870ec 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -24,7 +24,6 @@ import ( "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" - "golang.org/x/exp/maps" ) // DHCP is an interface for accessing DHCP lease data the [clientsContainer] @@ -46,22 +45,20 @@ type DHCP interface { // clientsContainer is the storage of all runtime and persistent clients. type clientsContainer struct { - // TODO(a.garipov): Perhaps use a number of separate indices for different - // types (string, netip.Addr, and so on). - list map[string]*client.Persistent // name -> client - + // clientIndex stores information about persistent clients. clientIndex *client.Index - // ipToRC maps IP addresses to runtime client information. - ipToRC map[netip.Addr]*client.Runtime + // runtimeIndex stores information about runtime clients. + runtimeIndex *client.RuntimeIndex allTags *container.MapSet[string] // dhcp is the DHCP service implementation. dhcp DHCP - // dnsServer is used for checking clients IP status access list status - dnsServer *dnsforward.Server + // clientChecker checks if a client is blocked by the current access + // settings. + clientChecker BlockedClientChecker // etcHosts contains list of rewrite rules taken from the operating system's // hosts database. @@ -90,6 +87,12 @@ type clientsContainer struct { testing bool } +// BlockedClientChecker checks if a client is blocked by the current access +// settings. +type BlockedClientChecker interface { + IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string) +} + // Init initializes clients container // dhcpServer: optional // Note: this function must be called only once @@ -100,12 +103,12 @@ func (clients *clientsContainer) Init( arpDB arpdb.Interface, filteringConf *filtering.Config, ) (err error) { - if clients.list != nil { - log.Fatal("clients.list != nil") + // TODO(s.chzhen): Refactor it. + if clients.clientIndex != nil { + return errors.Error("clients container already initialized") } - clients.list = map[string]*client.Persistent{} - clients.ipToRC = map[netip.Addr]*client.Runtime{} + clients.runtimeIndex = client.NewRuntimeIndex() clients.clientIndex = client.NewIndex() @@ -248,8 +251,6 @@ func (o *clientObject) toPersistent( } if o.SafeSearchConf.Enabled { - o.SafeSearchConf.CustomResolver = safeSearchResolver{} - err = cli.SetSafeSearch( o.SafeSearchConf, filteringConf.SafeSearchCacheSize, @@ -285,9 +286,17 @@ func (clients *clientsContainer) addFromConfig( return fmt.Errorf("clients: init persistent client at index %d: %w", i, err) } - _, err = clients.add(cli) + // TODO(s.chzhen): Consider moving to the client index constructor. + err = clients.clientIndex.ClashesUID(cli) if err != nil { - log.Error("clients: adding client at index %d %s: %s", i, cli.Name, err) + return fmt.Errorf("adding client %s at index %d: %w", cli.Name, i, err) + } + + err = clients.add(cli) + if err != nil { + // TODO(s.chzhen): Return an error instead of logging if more + // stringent requirements are implemented. + log.Error("clients: adding client %s at index %d: %s", cli.Name, i, err) } } @@ -300,9 +309,9 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { clients.lock.Lock() defer clients.lock.Unlock() - objs = make([]*clientObject, 0, len(clients.list)) - for _, cli := range clients.list { - o := &clientObject{ + objs = make([]*clientObject, 0, clients.clientIndex.Size()) + clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) { + objs = append(objs, &clientObject{ Name: cli.Name, BlockedServices: cli.BlockedServices.Clone(), @@ -323,10 +332,10 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { IgnoreStatistics: cli.IgnoreStatistics, UpstreamsCacheEnabled: cli.UpstreamsCacheEnabled, UpstreamsCacheSize: cli.UpstreamsCacheSize, - } + }) - objs = append(objs, o) - } + return true + }) // Maps aren't guaranteed to iterate in the same order each time, so the // above loop can generate different orderings when writing to the config @@ -363,8 +372,8 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) return client.SourcePersistent } - rc, ok := clients.ipToRC[ip] - if ok { + rc := clients.runtimeIndex.Client(ip) + if rc != nil { src, _ = rc.Info() } @@ -406,23 +415,26 @@ func (clients *clientsContainer) clientOrArtificial( id string, ) (c *querylog.Client, art bool) { defer func() { - c.Disallowed, c.DisallowedRule = clients.dnsServer.IsBlockedClient(ip, id) + c.Disallowed, c.DisallowedRule = clients.clientChecker.IsBlockedClient(ip, id) if c.WHOIS == nil { c.WHOIS = &whois.Info{} } }() cli, ok := clients.find(id) - if ok { + if !ok { + cli = clients.clientIndex.FindByIPWithoutZone(ip) + } + + if cli != nil { return &querylog.Client{ Name: cli.Name, IgnoreQueryLog: cli.IgnoreQueryLog, }, false } - var rc *client.Runtime - rc, ok = clients.findRuntimeClient(ip) - if ok { + rc := clients.findRuntimeClient(ip) + if rc != nil { _, host := rc.Info() return &querylog.Client{ @@ -542,47 +554,38 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, return nil, false } - for _, c = range clients.list { - _, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr]) - if found { - return c, true - } - } - - return nil, false + return clients.clientIndex.FindByMAC(foundMAC) } // runtimeClient returns a runtime client from internal index. Note that it // doesn't include DHCP clients. -func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) { +func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) { if ip == (netip.Addr{}) { - return nil, false + return nil } clients.lock.Lock() defer clients.lock.Unlock() - rc, ok = clients.ipToRC[ip] - - return rc, ok + return clients.runtimeIndex.Client(ip) } // findRuntimeClient finds a runtime client by their IP. -func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime, ok bool) { - rc, ok = clients.runtimeClient(ip) +func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) { + rc = clients.runtimeClient(ip) host := clients.dhcp.HostByIP(ip) if host != "" { - if !ok { - rc = &client.Runtime{} + if rc == nil { + rc = client.NewRuntime(ip) } rc.SetInfo(client.SourceDHCP, []string{host}) - return rc, true + return rc } - return rc, ok + return rc } // check validates the client. It also sorts the client tags. @@ -615,43 +618,32 @@ func (clients *clientsContainer) check(c *client.Persistent) (err error) { return nil } -// add adds a new client object. ok is false if such client already exists or -// if an error occurred. -func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) { +// add adds a persistent client or returns an error. +func (clients *clientsContainer) add(c *client.Persistent) (err error) { err = clients.check(c) if err != nil { - return false, err + // Don't wrap the error since it's informative enough as is. + return err } clients.lock.Lock() defer clients.lock.Unlock() - // check Name index - _, ok = clients.list[c.Name] - if ok { - return false, nil - } - - // check ID index err = clients.clientIndex.Clashes(c) if err != nil { // Don't wrap the error since it's informative enough as is. - return false, err + return err } clients.addLocked(c) - log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), len(clients.list)) + log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size()) - return true, nil + return nil } // addLocked c to the indexes. clients.lock is expected to be locked. func (clients *clientsContainer) addLocked(c *client.Persistent) { - // update Name index - clients.list[c.Name] = c - - // update ID index clients.clientIndex.Add(c) } @@ -660,8 +652,7 @@ func (clients *clientsContainer) remove(name string) (ok bool) { clients.lock.Lock() defer clients.lock.Unlock() - var c *client.Persistent - c, ok = clients.list[name] + c, ok := clients.clientIndex.FindByName(name) if !ok { return false } @@ -678,9 +669,6 @@ func (clients *clientsContainer) removeLocked(c *client.Persistent) { log.Error("client container: removing client %s: %s", c.Name, err) } - // Update the name index. - delete(clients.list, c.Name) - // Update the ID index. clients.clientIndex.Delete(c) } @@ -696,22 +684,6 @@ func (clients *clientsContainer) update(prev, c *client.Persistent) (err error) clients.lock.Lock() defer clients.lock.Unlock() - // Check the name index. - if prev.Name != c.Name { - _, ok := clients.list[c.Name] - if ok { - return errors.Error("client already exists") - } - } - - if c.EqualIDs(prev) { - clients.removeLocked(prev) - clients.addLocked(c) - - return nil - } - - // Check the ID index. err = clients.clientIndex.Clashes(c) if err != nil { // Don't wrap the error since it's informative enough as is. @@ -734,12 +706,12 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) { return } - rc, ok := clients.ipToRC[ip] - if !ok { + rc := clients.runtimeIndex.Client(ip) + if rc == nil { // Create a RuntimeClient implicitly so that we don't do this check // again. - rc = &client.Runtime{} - clients.ipToRC[ip] = rc + rc = client.NewRuntime(ip) + clients.runtimeIndex.Add(rc) log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) } else { @@ -798,61 +770,54 @@ func (clients *clientsContainer) addHostLocked( host string, src client.Source, ) (ok bool) { - rc, ok := clients.ipToRC[ip] - if !ok { + rc := clients.runtimeIndex.Client(ip) + if rc == nil { if src < client.SourceDHCP { if clients.dhcp.HostByIP(ip) != "" { return false } } - rc = &client.Runtime{} - clients.ipToRC[ip] = rc + rc = client.NewRuntime(ip) + clients.runtimeIndex.Add(rc) } rc.SetInfo(src, []string{host}) - log.Debug("clients: adding client info %s -> %q %q [%d]", ip, src, host, len(clients.ipToRC)) + log.Debug( + "clients: adding client info %s -> %q %q [%d]", + ip, + src, + host, + clients.runtimeIndex.Size(), + ) return true } -// rmHostsBySrc removes all entries that match the specified source. -func (clients *clientsContainer) rmHostsBySrc(src client.Source) { - n := 0 - for ip, rc := range clients.ipToRC { - rc.Unset(src) - if rc.IsEmpty() { - delete(clients.ipToRC, ip) - n++ - } - } - - log.Debug("clients: removed %d client aliases", n) -} - // addFromHostsFile fills the client-hostname pairing index from the system's // hosts files. func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) { clients.lock.Lock() defer clients.lock.Unlock() - clients.rmHostsBySrc(client.SourceHostsFile) + deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile) + log.Debug("clients: removed %d client aliases from system hosts file", deleted) - n := 0 + added := 0 hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { // Only the first name of the first record is considered a canonical // hostname for the IP address. // // TODO(e.burkov): Consider using all the names from all the records. if clients.addHostLocked(addr, names[0], client.SourceHostsFile) { - n++ + added++ } return true }) - log.Debug("clients: added %d client aliases from system hosts file", n) + log.Debug("clients: added %d client aliases from system hosts file", added) } // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a @@ -876,7 +841,8 @@ func (clients *clientsContainer) addFromSystemARP() { clients.lock.Lock() defer clients.lock.Unlock() - clients.rmHostsBySrc(client.SourceARP) + deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP) + log.Debug("clients: removed %d client aliases from arp neighborhood", deleted) added := 0 for _, n := range ns { @@ -891,18 +857,5 @@ func (clients *clientsContainer) addFromSystemARP() { // close gracefully closes all the client-specific upstream configurations of // the persistent clients. func (clients *clientsContainer) close() (err error) { - persistent := maps.Values(clients.list) - slices.SortFunc(persistent, func(a, b *client.Persistent) (res int) { - return strings.Compare(a.Name, b.Name) - }) - - var errs []error - - for _, cli := range persistent { - if err = cli.CloseUpstreams(); err != nil { - errs = append(errs, err) - } - } - - return errors.Join(errs...) + return clients.clientIndex.CloseUpstreams() } diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 4f9cb946..d371df7b 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -41,7 +41,7 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) { } dhcp := &testDHCP{ - OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") }, + OnLeases: func() (leases []*dhcpsvc.Lease) { return nil }, OnHostBy: func(ip netip.Addr) (host string) { return "" }, OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil }, } @@ -72,23 +72,19 @@ func TestClients(t *testing.T) { IPs: []netip.Addr{cli1IP, cliIPv6}, } - ok, err := clients.add(c) + err := clients.add(c) require.NoError(t, err) - assert.True(t, ok) - c = &client.Persistent{ Name: "client2", UID: client.MustNewUID(), IPs: []netip.Addr{cli2IP}, } - ok, err = clients.add(c) + err = clients.add(c) require.NoError(t, err) - assert.True(t, ok) - - c, ok = clients.find(cli1) + c, ok := clients.find(cli1) require.True(t, ok) assert.Equal(t, "client1", c.Name) @@ -111,22 +107,20 @@ func TestClients(t *testing.T) { }) t.Run("add_fail_name", func(t *testing.T) { - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, }) - require.NoError(t, err) - assert.False(t, ok) + require.Error(t, err) }) t.Run("add_fail_ip", func(t *testing.T) { - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client3", UID: client.MustNewUID(), }) require.Error(t, err) - assert.False(t, ok) }) t.Run("update_fail_ip", func(t *testing.T) { @@ -145,12 +139,13 @@ func TestClients(t *testing.T) { cliNewIP = netip.MustParseAddr(cliNew) ) - prev, ok := clients.list["client1"] + prev, ok := clients.clientIndex.FindByName("client1") require.True(t, ok) + require.NotNil(t, prev) err := clients.update(prev, &client.Persistent{ Name: "client1", - UID: client.MustNewUID(), + UID: prev.UID, IPs: []netip.Addr{cliNewIP}, }) require.NoError(t, err) @@ -160,12 +155,13 @@ func TestClients(t *testing.T) { assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent) - prev, ok = clients.list["client1"] + prev, ok = clients.clientIndex.FindByName("client1") require.True(t, ok) + require.NotNil(t, prev) err = clients.update(prev, &client.Persistent{ Name: "client1-renamed", - UID: client.MustNewUID(), + UID: prev.UID, IPs: []netip.Addr{cliNewIP}, UseOwnSettings: true, }) @@ -177,7 +173,7 @@ func TestClients(t *testing.T) { assert.Equal(t, "client1-renamed", c.Name) assert.True(t, c.UseOwnSettings) - nilCli, ok := clients.list["client1"] + nilCli, ok := clients.clientIndex.FindByName("client1") require.False(t, ok) assert.Nil(t, nilCli) @@ -244,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) { t.Run("new_client", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.255") clients.setWHOISInfo(ip, whois) - rc := clients.ipToRC[ip] + rc := clients.runtimeIndex.Client(ip) require.NotNil(t, rc) assert.Equal(t, whois, rc.WHOIS()) @@ -256,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) { assert.True(t, ok) clients.setWHOISInfo(ip, whois) - rc := clients.ipToRC[ip] + rc := clients.runtimeIndex.Client(ip) require.NotNil(t, rc) assert.Equal(t, whois, rc.WHOIS()) @@ -265,16 +261,15 @@ func TestClientsWHOIS(t *testing.T) { t.Run("can't_set_manually-added", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.2") - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, }) require.NoError(t, err) - assert.True(t, ok) clients.setWHOISInfo(ip, whois) - rc := clients.ipToRC[ip] + rc := clients.runtimeIndex.Client(ip) require.Nil(t, rc) assert.True(t, clients.remove("client1")) @@ -288,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") // Add a client. - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, @@ -296,10 +291,9 @@ func TestClientsAddExisting(t *testing.T) { MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, }) require.NoError(t, err) - assert.True(t, ok) // Now add an auto-client with the same IP. - ok = clients.addHost(ip, "test", client.SourceRDNS) + ok := clients.addHost(ip, "test", client.SourceRDNS) assert.True(t, ok) }) @@ -339,22 +333,20 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Add a new client with the same IP as for a client with MAC. - ok, err := clients.add(&client.Persistent{ + err = clients.add(&client.Persistent{ Name: "client2", UID: client.MustNewUID(), IPs: []netip.Addr{ip}, }) require.NoError(t, err) - assert.True(t, ok) // Add a new client with the IP from the first client's IP range. - ok, err = clients.add(&client.Persistent{ + err = clients.add(&client.Persistent{ Name: "client3", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, }) require.NoError(t, err) - assert.True(t, ok) }) } @@ -362,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) { clients := newClientsContainer(t) // Add client with upstreams. - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, @@ -372,7 +364,6 @@ func TestClientsCustomUpstream(t *testing.T) { }, }) require.NoError(t, err) - assert.True(t, ok) upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver) assert.Nil(t, upsConf) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index b2270416..40a91f86 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -96,22 +96,26 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http clients.lock.Lock() defer clients.lock.Unlock() - for _, c := range clients.list { + clients.clientIndex.Range(func(c *client.Persistent) (cont bool) { cj := clientToJSON(c) data.Clients = append(data.Clients, cj) - } - for ip, rc := range clients.ipToRC { + return true + }) + + clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) { src, host := rc.Info() cj := runtimeClientJSON{ WHOIS: whoisOrEmpty(rc), Name: host, Source: src, - IP: ip, + IP: rc.Addr(), } data.RuntimeClients = append(data.RuntimeClients, cj) - } + + return true + }) for _, l := range clients.dhcp.Leases() { cj := runtimeClientJSON{ @@ -332,20 +336,16 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - ok, err := clients.add(c) + err = clients.add(c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - if !ok { - aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists") - - return + if !clients.testing { + onConfigModified() } - - onConfigModified() } // handleDelClient is the handler for POST /control/clients/delete HTTP API. @@ -370,7 +370,9 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http. return } - onConfigModified() + if !clients.testing { + onConfigModified() + } } // updateJSON contains the name and data of the updated persistent client. @@ -404,7 +406,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht clients.lock.Lock() defer clients.lock.Unlock() - prev, ok = clients.list[dj.Name] + prev, ok = clients.clientIndex.FindByName(dj.Name) }() if !ok { @@ -427,14 +429,16 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - onConfigModified() + if !clients.testing { + onConfigModified() + } } // handleFindClient is the handler for GET /control/clients/find HTTP API. func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() data := []map[string]*clientJSON{} - for i := 0; i < len(q); i++ { + for i := range len(q) { idStr := q.Get(fmt.Sprintf("ip%d", i)) if idStr == "" { break @@ -447,7 +451,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http cj = clients.findRuntime(ip, idStr) } else { cj = clientToJSON(c) - disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) + disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr) cj.Disallowed, cj.DisallowedRule = &disallowed, &rule } @@ -463,14 +467,14 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http // /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be // non-nil. func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) { - rc, ok := clients.findRuntimeClient(ip) - if !ok { + rc := clients.findRuntimeClient(ip) + if rc == nil { // It is still possible that the IP used to be in the runtime clients // list, but then the server was reloaded. So, check the DNS server's // blocked IP list. // // See https://github.com/AdguardTeam/AdGuardHome/issues/2428. - disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) + disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr) cj = &clientJSON{ IDs: []string{idStr}, Disallowed: &disallowed, @@ -488,7 +492,7 @@ func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *c WHOIS: whoisOrEmpty(rc), } - disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) + disallowed, rule := clients.clientChecker.IsBlockedClient(ip, idStr) cj.Disallowed, cj.DisallowedRule = &disallowed, &rule return cj diff --git a/internal/home/clientshttp_internal_test.go b/internal/home/clientshttp_internal_test.go new file mode 100644 index 00000000..dc1aa87d --- /dev/null +++ b/internal/home/clientshttp_internal_test.go @@ -0,0 +1,399 @@ +package home + +import ( + "bytes" + "cmp" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "slices" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/AdGuardHome/internal/schedule" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + testClientIP1 = "1.1.1.1" + testClientIP2 = "2.2.2.2" +) + +// testBlockedClientChecker is a mock implementation of the +// [BlockedClientChecker] interface. +type testBlockedClientChecker struct { + onIsBlockedClient func(ip netip.Addr, clientiD string) (blocked bool, rule string) +} + +// type check +var _ BlockedClientChecker = (*testBlockedClientChecker)(nil) + +// IsBlockedClient implements the [BlockedClientChecker] interface for +// *testBlockedClientChecker. +func (c *testBlockedClientChecker) IsBlockedClient( + ip netip.Addr, + clientID string, +) (blocked bool, rule string) { + return c.onIsBlockedClient(ip, clientID) +} + +// newPersistentClient is a helper function that returns a persistent client +// with the specified name and newly generated UID. +func newPersistentClient(name string) (c *client.Persistent) { + return &client.Persistent{ + Name: name, + UID: client.MustNewUID(), + BlockedServices: &filtering.BlockedServices{ + Schedule: &schedule.Weekly{}, + }, + } +} + +// newPersistentClientWithIDs is a helper function that returns a persistent +// client with the specified name and ids. +func newPersistentClientWithIDs(tb testing.TB, name string, ids []string) (c *client.Persistent) { + tb.Helper() + + c = newPersistentClient(name) + err := c.SetIDs(ids) + require.NoError(tb, err) + + return c +} + +// assertClients is a helper function that compares lists of persistent clients. +func assertClients(tb testing.TB, want, got []*client.Persistent) { + tb.Helper() + + require.Len(tb, got, len(want)) + + sortFunc := func(a, b *client.Persistent) (n int) { + return cmp.Compare(a.Name, b.Name) + } + + slices.SortFunc(want, sortFunc) + slices.SortFunc(got, sortFunc) + + slices.CompareFunc(want, got, func(a, b *client.Persistent) (n int) { + assert.True(tb, a.EqualIDs(b), "%q doesn't have the same ids as %q", a.Name, b.Name) + + return 0 + }) +} + +// assertPersistentClients is a helper function that uses HTTP API to check +// whether want persistent clients are the same as the persistent clients stored +// in the clients container. +func assertPersistentClients(tb testing.TB, clients *clientsContainer, want []*client.Persistent) { + tb.Helper() + + rw := httptest.NewRecorder() + clients.handleGetClients(rw, &http.Request{}) + + body, err := io.ReadAll(rw.Body) + require.NoError(tb, err) + + clientList := &clientListJSON{} + err = json.Unmarshal(body, clientList) + require.NoError(tb, err) + + var got []*client.Persistent + for _, cj := range clientList.Clients { + var c *client.Persistent + c, err = clients.jsonToClient(*cj, nil) + require.NoError(tb, err) + + got = append(got, c) + } + + assertClients(tb, want, got) +} + +// assertPersistentClientsData is a helper function that checks whether want +// persistent clients are the same as the persistent clients stored in data. +func assertPersistentClientsData( + tb testing.TB, + clients *clientsContainer, + data []map[string]*clientJSON, + want []*client.Persistent, +) { + tb.Helper() + + var got []*client.Persistent + for _, cm := range data { + for _, cj := range cm { + var c *client.Persistent + c, err := clients.jsonToClient(*cj, nil) + require.NoError(tb, err) + + got = append(got, c) + } + } + + assertClients(tb, want, got) +} + +func TestClientsContainer_HandleAddClient(t *testing.T) { + clients := newClientsContainer(t) + + clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) + clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) + + clientEmptyID := newPersistentClient("empty_client_id") + clientEmptyID.ClientIDs = []string{""} + + testCases := []struct { + name string + client *client.Persistent + wantCode int + wantClient []*client.Persistent + }{{ + name: "add_one", + client: clientOne, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientOne}, + }, { + name: "add_two", + client: clientTwo, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientOne, clientTwo}, + }, { + name: "duplicate_client", + client: clientTwo, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientOne, clientTwo}, + }, { + name: "empty_client_id", + client: clientEmptyID, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientOne, clientTwo}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cj := clientToJSON(tc.client) + + body, err := json.Marshal(cj) + require.NoError(t, err) + + r, err := http.NewRequest(http.MethodPost, "", bytes.NewReader(body)) + require.NoError(t, err) + + rw := httptest.NewRecorder() + clients.handleAddClient(rw, r) + require.NoError(t, err) + require.Equal(t, tc.wantCode, rw.Code) + + assertPersistentClients(t, clients, tc.wantClient) + }) + } +} + +func TestClientsContainer_HandleDelClient(t *testing.T) { + clients := newClientsContainer(t) + + clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) + err := clients.add(clientOne) + require.NoError(t, err) + + clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) + err = clients.add(clientTwo) + require.NoError(t, err) + + assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) + + testCases := []struct { + name string + client *client.Persistent + wantCode int + wantClient []*client.Persistent + }{{ + name: "remove_one", + client: clientOne, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientTwo}, + }, { + name: "duplicate_client", + client: clientOne, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientTwo}, + }, { + name: "empty_client_name", + client: newPersistentClient(""), + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientTwo}, + }, { + name: "remove_two", + client: clientTwo, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cj := clientToJSON(tc.client) + + var body []byte + body, err = json.Marshal(cj) + require.NoError(t, err) + + var r *http.Request + r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body)) + require.NoError(t, err) + + rw := httptest.NewRecorder() + clients.handleDelClient(rw, r) + require.NoError(t, err) + require.Equal(t, tc.wantCode, rw.Code) + + assertPersistentClients(t, clients, tc.wantClient) + }) + } +} + +func TestClientsContainer_HandleUpdateClient(t *testing.T) { + clients := newClientsContainer(t) + + clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) + err := clients.add(clientOne) + require.NoError(t, err) + + assertPersistentClients(t, clients, []*client.Persistent{clientOne}) + + clientModified := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) + + clientEmptyID := newPersistentClient("empty_client_id") + clientEmptyID.ClientIDs = []string{""} + + testCases := []struct { + name string + clientName string + modified *client.Persistent + wantCode int + wantClient []*client.Persistent + }{{ + name: "update_one", + clientName: clientOne.Name, + modified: clientModified, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientModified}, + }, { + name: "empty_name", + clientName: "", + modified: clientOne, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientModified}, + }, { + name: "client_not_found", + clientName: "client_not_found", + modified: clientOne, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientModified}, + }, { + name: "empty_client_id", + clientName: clientModified.Name, + modified: clientEmptyID, + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientModified}, + }, { + name: "no_ids", + clientName: clientModified.Name, + modified: newPersistentClient("no_ids"), + wantCode: http.StatusBadRequest, + wantClient: []*client.Persistent{clientModified}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + uj := updateJSON{ + Name: tc.clientName, + Data: *clientToJSON(tc.modified), + } + + var body []byte + body, err = json.Marshal(uj) + require.NoError(t, err) + + var r *http.Request + r, err = http.NewRequest(http.MethodPost, "", bytes.NewReader(body)) + require.NoError(t, err) + + rw := httptest.NewRecorder() + clients.handleUpdateClient(rw, r) + require.NoError(t, err) + require.Equal(t, tc.wantCode, rw.Code) + + assertPersistentClients(t, clients, tc.wantClient) + }) + } +} + +func TestClientsContainer_HandleFindClient(t *testing.T) { + clients := newClientsContainer(t) + clients.clientChecker = &testBlockedClientChecker{ + onIsBlockedClient: func(ip netip.Addr, clientID string) (ok bool, rule string) { + return false, "" + }, + } + + clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) + err := clients.add(clientOne) + require.NoError(t, err) + + clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) + err = clients.add(clientTwo) + require.NoError(t, err) + + assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) + + testCases := []struct { + name string + query url.Values + wantCode int + wantClient []*client.Persistent + }{{ + name: "single", + query: url.Values{ + "ip0": []string{testClientIP1}, + }, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientOne}, + }, { + name: "multiple", + query: url.Values{ + "ip0": []string{testClientIP1}, + "ip1": []string{testClientIP2}, + }, + wantCode: http.StatusOK, + wantClient: []*client.Persistent{clientOne, clientTwo}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var r *http.Request + r, err = http.NewRequest(http.MethodGet, "", nil) + require.NoError(t, err) + + r.URL.RawQuery = tc.query.Encode() + rw := httptest.NewRecorder() + clients.handleFindClient(rw, r) + require.NoError(t, err) + require.Equal(t, tc.wantCode, rw.Code) + + var body []byte + body, err = io.ReadAll(rw.Body) + require.NoError(t, err) + + clientData := []map[string]*clientJSON{} + err = json.Unmarshal(body, &clientData) + require.NoError(t, err) + + assertPersistentClientsData(t, clients, clientData, tc.wantClient) + }) + } +} diff --git a/internal/home/config.go b/internal/home/config.go index b9200239..f20211d6 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -203,15 +203,24 @@ type dnsConfig struct { // resolver should be used. PrivateNets []netutil.Prefix `yaml:"private_networks"` - // UsePrivateRDNS defines if the PTR requests for unknown addresses from - // locally-served networks should be resolved via private PTR resolvers. + // UsePrivateRDNS enables resolving requests containing a private IP address + // using private reverse DNS resolvers. See PrivateRDNSResolvers. + // + // TODO(e.burkov): Rename in YAML. UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"` - // LocalPTRResolvers is the slice of addresses to be used as upstreams - // for PTR queries for locally-served networks. - LocalPTRResolvers []string `yaml:"local_ptr_upstreams"` + // PrivateRDNSResolvers is the slice of addresses to be used as upstreams + // for private requests. It's only used for PTR, SOA, and NS queries, + // containing an ARPA subdomain, came from the the client with private + // address. The address considered private according to PrivateNets. + // + // If empty, the OS-provided resolvers are used for private requests. + PrivateRDNSResolvers []string `yaml:"local_ptr_upstreams"` - // UseDNS64 defines if DNS64 should be used for incoming requests. + // UseDNS64 defines if DNS64 should be used for incoming requests. Requests + // of type PTR for addresses within the configured prefixes will be resolved + // via [PrivateRDNSResolvers], so those should be valid and UsePrivateRDNS + // be set to true. UseDNS64 bool `yaml:"use_dns64"` // DNS64Prefixes is the list of NAT64 prefixes to be used for DNS64. @@ -658,7 +667,7 @@ func (c *configuration) write() (err error) { dns := &config.DNS dns.Config = c - dns.LocalPTRResolvers = s.LocalPTRResolvers() + dns.PrivateRDNSResolvers = s.LocalPTRResolvers() addrProcConf := s.AddrProcConfig() config.Clients.Sources.RDNS = addrProcConf.UseRDNS diff --git a/internal/home/dns.go b/internal/home/dns.go index 8cab8156..d64effd5 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -1,7 +1,6 @@ package home import ( - "context" "fmt" "net" "net/netip" @@ -18,7 +17,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/stats" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" @@ -150,21 +148,19 @@ func initDNSServer( return fmt.Errorf("dnsforward.NewServer: %w", err) } - Context.clients.dnsServer = Context.dnsServer + Context.clients.clientChecker = Context.dnsServer dnsConf, err := newServerConfig(&config.DNS, config.Clients.Sources, tlsConf, httpReg) if err != nil { return fmt.Errorf("newServerConfig: %w", err) } + // Try to prepare the server with disabled private RDNS resolution if it + // failed to prepare as is. See TODO on [ErrBadPrivateRDNSUpstreams]. err = Context.dnsServer.Prepare(dnsConf) + if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) { + log.Info("WARNING: %s; trying to disable private RDNS resolution", err) - // TODO(e.burkov): Recreate the server with private RDNS disabled. This - // should go away once the private RDNS resolution is moved to the proxy. - var locResErr *dnsforward.LocalResolversError - if errors.As(err, &locResErr) && errors.Is(locResErr.Err, upstream.ErrNoUpstreams) { - log.Info("WARNING: no local resolvers configured while private RDNS " + - "resolution enabled, trying to disable") dnsConf.UsePrivateRDNS = false err = Context.dnsServer.Prepare(dnsConf) } @@ -245,7 +241,7 @@ func newServerConfig( TLSv12Roots: Context.tlsRoots, ConfigModified: onConfigModified, HTTPRegister: httpReg, - LocalPTRResolvers: dnsConf.LocalPTRResolvers, + LocalPTRResolvers: dnsConf.PrivateRDNSResolvers, UseDNS64: dnsConf.UseDNS64, DNS64Prefixes: dnsConf.DNS64Prefixes, UsePrivateRDNS: dnsConf.UsePrivateRDNS, @@ -531,36 +527,6 @@ func closeDNSServer() { log.Debug("all dns modules are closed") } -// safeSearchResolver is a [filtering.Resolver] implementation used for safe -// search. -type safeSearchResolver struct{} - -// type check -var _ filtering.Resolver = safeSearchResolver{} - -// LookupIP implements [filtering.Resolver] interface for safeSearchResolver. -// It returns the slice of net.Addr with IPv4 and IPv6 instances. -func (r safeSearchResolver) LookupIP( - ctx context.Context, - network string, - host string, -) (ips []net.IP, err error) { - addrs, err := Context.dnsServer.Resolve(ctx, network, host) - if err != nil { - return nil, err - } - - if len(addrs) == 0 { - return nil, fmt.Errorf("couldn't lookup host: %s", host) - } - - for _, a := range addrs { - ips = append(ips, a.AsSlice()) - } - - return ips, nil -} - // checkStatsAndQuerylogDirs checks and returns directory paths to store // statistics and query log. func checkStatsAndQuerylogDirs( diff --git a/internal/home/home.go b/internal/home/home.go index 8060467a..f7db7da0 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -439,7 +439,6 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) { conf.ParentalBlockHost = host } - conf.SafeSearchConf.CustomResolver = safeSearchResolver{} conf.SafeSearch, err = safesearch.NewDefault( conf.SafeSearchConf, "default", diff --git a/internal/home/log.go b/internal/home/log.go index c0c79fd5..efc90d3f 100644 --- a/internal/home/log.go +++ b/internal/home/log.go @@ -1,13 +1,13 @@ package home import ( + "cmp" "fmt" "path/filepath" "runtime" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/stringutil" "gopkg.in/natefinch/lumberjack.v2" "gopkg.in/yaml.v3" ) @@ -76,8 +76,7 @@ func getLogSettings(opts options) (ls *logSettings) { ls.Verbose = true } - // TODO(a.garipov): Use cmp.Or in Go 1.22. - ls.File = stringutil.Coalesce(opts.logFile, ls.File) + ls.File = cmp.Or(opts.logFile, ls.File) if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" { // When running as a Windows service, use eventlog by default if diff --git a/internal/home/service.go b/internal/home/service.go index cb8f9b8b..30bef2a7 100644 --- a/internal/home/service.go +++ b/internal/home/service.go @@ -306,7 +306,7 @@ func handleServiceStatusCommand(s service.Service) { } } -// handleServiceStatusCommand handles service "install" command +// handleServiceInstallCommand handles service "install" command. func handleServiceInstallCommand(s service.Service) { err := svcAction(s, "install") if err != nil { @@ -340,7 +340,7 @@ AdGuard Home is now available at the following addresses:`) } } -// handleServiceStatusCommand handles service "uninstall" command +// handleServiceUninstallCommand handles service "uninstall" command. func handleServiceUninstallCommand(s service.Service) { if aghos.IsOpenWrt() { // On OpenWrt it is important to run disable command first @@ -649,11 +649,6 @@ status() { // freeBSDScript is the source of the daemon script for FreeBSD. Keep as close // as possible to the https://github.com/kardianos/service/blob/18c957a3dc1120a2efe77beb401d476bade9e577/service_freebsd.go#L204. -// -// TODO(a.garipov): Don't use .WorkingDirectory here. There are currently no -// guarantees that it will actually be the required directory. -// -// See https://github.com/AdguardTeam/AdGuardHome/issues/2614. const freeBSDScript = `#!/bin/sh # PROVIDE: {{.Name}} # REQUIRE: networking @@ -667,7 +662,9 @@ name="{{.Name}}" pidfile_child="/var/run/${name}.pid" pidfile="/var/run/${name}_daemon.pid" command="/usr/sbin/daemon" -command_args="-P ${pidfile} -p ${pidfile_child} -T ${name} -r {{.WorkingDirectory}}/{{.Name}}" +daemon_args="-P ${pidfile} -p ${pidfile_child} -r -t ${name}" +command_args="${daemon_args} {{.Path}}{{range .Arguments}} {{.}}{{end}}" + run_rc_command "$1" ` diff --git a/internal/home/service_openbsd.go b/internal/home/service_openbsd.go index 4f94f0b4..56f5c428 100644 --- a/internal/home/service_openbsd.go +++ b/internal/home/service_openbsd.go @@ -3,6 +3,7 @@ package home import ( + "cmp" "fmt" "os" "os/signal" @@ -14,7 +15,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" - "github.com/AdguardTeam/golibs/stringutil" "github.com/kardianos/service" ) @@ -76,7 +76,7 @@ func (*openbsdRunComService) Platform() (p string) { // String implements service.Service interface for *openbsdRunComService. func (s *openbsdRunComService) String() string { - return stringutil.Coalesce(s.cfg.DisplayName, s.cfg.Name) + return cmp.Or(s.cfg.DisplayName, s.cfg.Name) } // getBool returns the value of the given name from kv, assuming the value is a diff --git a/internal/ipset/ipset_linux_internal_test.go b/internal/ipset/ipset_linux_internal_test.go index f22d93c1..4d727ee7 100644 --- a/internal/ipset/ipset_linux_internal_test.go +++ b/internal/ipset/ipset_linux_internal_test.go @@ -147,7 +147,7 @@ func BenchmarkManager_LookupHost(b *testing.B) { b.Run("long", func(b *testing.B) { const name = "a.very.long.domain.name.inside.the.domain.example.com" - for i := 0; i < b.N; i++ { + for range b.N { ipsetPropsSink = m.lookupHost(name) } @@ -156,7 +156,7 @@ func BenchmarkManager_LookupHost(b *testing.B) { b.Run("short", func(b *testing.B) { const name = "example.net" - for i := 0; i < b.N; i++ { + for range b.N { ipsetPropsSink = m.lookupHost(name) } diff --git a/internal/next/cmd/signal.go b/internal/next/cmd/signal.go index b3bae338..a9f8543f 100644 --- a/internal/next/cmd/signal.go +++ b/internal/next/cmd/signal.go @@ -8,6 +8,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/osutil" "github.com/google/renameio/v2/maybe" ) @@ -38,7 +39,7 @@ func (h *signalHandler) handle() { if aghos.IsReconfigureSignal(sig) { h.reconfigure() - } else if aghos.IsShutdownSignal(sig) { + } else if osutil.IsShutdownSignal(sig) { status := h.shutdown() h.removePID() @@ -122,7 +123,8 @@ func newSignalHandler( services: svcs, } - aghos.NotifyShutdownSignal(h.signal) + notifier := osutil.DefaultSignalNotifier{} + osutil.NotifyShutdownSignal(notifier, h.signal) aghos.NotifyReconfigureSignal(h.signal) return h diff --git a/internal/next/dnssvc/dnssvc_test.go b/internal/next/dnssvc/dnssvc_test.go index 48f49b8d..2a46d956 100644 --- a/internal/next/dnssvc/dnssvc_test.go +++ b/internal/next/dnssvc/dnssvc_test.go @@ -1,7 +1,6 @@ package dnssvc_test import ( - "context" "net/netip" "testing" "time" @@ -94,10 +93,8 @@ func TestService(t *testing.T) { }}, } - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() - cli := &dns.Client{} + ctx := testutil.ContextWithTimeout(t, testTimeout) var resp *dns.Msg require.Eventually(t, func() (ok bool) { @@ -110,10 +107,8 @@ func TestService(t *testing.T) { assert.NotNil(t, resp) }) - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - defer cancel() + err = svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout)) - err = svc.Shutdown(ctx) require.NoError(t, err) err = upstreamSrv.Shutdown() diff --git a/internal/next/websvc/websvc_test.go b/internal/next/websvc/websvc_test.go index 6a2505d5..cb4c6bc9 100644 --- a/internal/next/websvc/websvc_test.go +++ b/internal/next/websvc/websvc_test.go @@ -109,12 +109,8 @@ func newTestServer( err = svc.Start() require.NoError(t, err) - t.Cleanup(func() { - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) - t.Cleanup(cancel) - - err = svc.Shutdown(ctx) - require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return svc.Shutdown(testutil.ContextWithTimeout(t, testTimeout)) }) c = svc.Config() diff --git a/internal/querylog/decode_test.go b/internal/querylog/decode_test.go index 4fc1d244..1f907e3d 100644 --- a/internal/querylog/decode_test.go +++ b/internal/querylog/decode_test.go @@ -303,7 +303,7 @@ func BenchmarkAnonymizeIP(b *testing.B) { b.Run(bc.name, func(b *testing.B) { b.ReportAllocs() - for i := 0; i < b.N; i++ { + for range b.N { AnonymizeIP(bc.ip) } @@ -313,7 +313,7 @@ func BenchmarkAnonymizeIP(b *testing.B) { b.Run(bc.name+"_slow", func(b *testing.B) { b.ReportAllocs() - for i := 0; i < b.N; i++ { + for range b.N { anonymizeIPSlow(bc.ip) } diff --git a/internal/querylog/entry.go b/internal/querylog/entry.go index c3c800ed..ed3319b0 100644 --- a/internal/querylog/entry.go +++ b/internal/querylog/entry.go @@ -31,6 +31,7 @@ type logEntry struct { Answer []byte `json:",omitempty"` OrigAnswer []byte `json:",omitempty"` + // TODO(s.chzhen): Use netip.Addr. IP net.IP `json:"IP"` Result filtering.Result diff --git a/internal/querylog/qlog_test.go b/internal/querylog/qlog_test.go index 0b2a476b..57d8b68d 100644 --- a/internal/querylog/qlog_test.go +++ b/internal/querylog/qlog_test.go @@ -143,13 +143,13 @@ func TestQueryLogOffsetLimit(t *testing.T) { secondPageDomain = "second.example.org" ) // Add entries to the log. - for i := 0; i < entNum; i++ { + for range entNum { addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } // Write them to the first file. require.NoError(t, l.flushLogBuffer()) // Add more to the in-memory part of log. - for i := 0; i < entNum; i++ { + for range entNum { addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } @@ -215,7 +215,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) { const entNum = 10 // Add entries to the log. - for i := 0; i < entNum; i++ { + for range entNum { addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1)) } // Write them to disk. diff --git a/internal/querylog/qlogfile_test.go b/internal/querylog/qlogfile_test.go index f91d3911..8462e950 100644 --- a/internal/querylog/qlogfile_test.go +++ b/internal/querylog/qlogfile_test.go @@ -37,7 +37,7 @@ func prepareTestFile(t *testing.T, dir string, linesNum int) (name string) { var lineIP uint32 lineTime := time.Date(2020, 2, 18, 19, 36, 35, 920973000, time.UTC) - for i := 0; i < linesNum; i++ { + for range linesNum { lineIP++ lineTime = lineTime.Add(time.Second) diff --git a/internal/stats/stats_internal_test.go b/internal/stats/stats_internal_test.go index 9081dd21..3423c7ad 100644 --- a/internal/stats/stats_internal_test.go +++ b/internal/stats/stats_internal_test.go @@ -68,13 +68,13 @@ func TestStats_races(t *testing.T) { startWG, finWG := &sync.WaitGroup{}, &sync.WaitGroup{} waitCh := make(chan unit) - for i := 0; i < writersNum; i++ { + for i := range writersNum { startWG.Add(1) finWG.Add(1) go writeFunc(startWG, finWG, waitCh, i) } - for i := 0; i < readersNum; i++ { + for range readersNum { startWG.Add(1) finWG.Add(1) go readFunc(startWG, finWG, waitCh) @@ -111,7 +111,7 @@ func TestStatsCtx_FillCollectedStats_daily(t *testing.T) { dailyData := []*unitDB{} - for i := 0; i < daysCount*24; i++ { + for i := range daysCount * 24 { n := uint64(i) nResult := make([]uint64, resultLast) nResult[RFiltered] = n diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index f04bdf11..2f7c526a 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -195,7 +195,7 @@ func TestLargeNumbers(t *testing.T) { for h := 0; h < hoursNum; h++ { atomic.AddUint32(&curHour, 1) - for i := 0; i < cliNumPerHour; i++ { + for i := range cliNumPerHour { ip := net.IP{127, 0, byte((i & 0xff00) >> 8), byte(i & 0xff)} e := &stats.Entry{ Domain: fmt.Sprintf("domain%d.hour%d", i, h), diff --git a/internal/stats/unit.go b/internal/stats/unit.go index e43152a4..621f1cda 100644 --- a/internal/stats/unit.go +++ b/internal/stats/unit.go @@ -525,9 +525,8 @@ func (s *StatsCtx) fillCollectedStatsDaily( hours := countHours(curHour, days) units = units[len(units)-hours:] - for i := 0; i < len(units); i++ { + for i, u := range units { day := i / 24 - u := units[i] data.DNSQueries[day] += u.NTotal data.BlockedFiltering[day] += u.NResult[RFiltered] diff --git a/internal/tools/go.mod b/internal/tools/go.mod index a6e26c14..a64d048d 100644 --- a/internal/tools/go.mod +++ b/internal/tools/go.mod @@ -1,6 +1,6 @@ module github.com/AdguardTeam/AdGuardHome/internal/tools -go 1.22.2 +go 1.22.3 require ( github.com/fzipp/gocyclo v0.6.0 diff --git a/internal/whois/whois.go b/internal/whois/whois.go index 37f1dec8..10f0609b 100644 --- a/internal/whois/whois.go +++ b/internal/whois/whois.go @@ -3,6 +3,7 @@ package whois import ( "bytes" + "cmp" "context" "fmt" "io" @@ -17,7 +18,6 @@ import ( "github.com/AdguardTeam/golibs/ioutil" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" - "github.com/AdguardTeam/golibs/stringutil" "github.com/bluele/gcache" ) @@ -174,7 +174,7 @@ func whoisParse(data []byte, maxLen int) (info map[string]string) { val = trimValue(val, maxLen) case "descr", "netname": key = "orgname" - val = stringutil.Coalesce(orgname, val) + val = cmp.Or(orgname, val) orgname = val case "whois": key = "whois" @@ -232,7 +232,7 @@ func (w *Default) queryAll(ctx context.Context, target string) (info map[string] server := net.JoinHostPort(w.serverAddr, w.portStr) var data []byte - for i := 0; i < w.maxRedirects; i++ { + for range w.maxRedirects { data, err = w.query(ctx, target, server) if err != nil { // Don't wrap the error since it's informative enough as is. diff --git a/scripts/translations/download.go b/scripts/translations/download.go index d83f0bac..a7efc420 100644 --- a/scripts/translations/download.go +++ b/scripts/translations/download.go @@ -48,7 +48,7 @@ func (c *twoskyClient) download() (err error) { failed := &sync.Map{} uriCh := make(chan *url.URL, len(c.langs)) - for i := 0; i < numWorker; i++ { + for range numWorker { wg.Add(1) go downloadWorker(wg, failed, client, uriCh) } diff --git a/scripts/translations/main.go b/scripts/translations/main.go index e03dcb10..c5b1ef1e 100644 --- a/scripts/translations/main.go +++ b/scripts/translations/main.go @@ -5,6 +5,7 @@ package main import ( "bufio" "bytes" + "cmp" "encoding/json" "fmt" "net/url" @@ -204,19 +205,13 @@ type twoskyClient struct { func (t *twoskyConfig) toClient() (cli *twoskyClient, err error) { defer func() { err = errors.Annotate(err, "filling config: %w") }() - uriStr := os.Getenv("TWOSKY_URI") - if uriStr == "" { - uriStr = twoskyURI - } + uriStr := cmp.Or(os.Getenv("TWOSKY_URI"), twoskyURI) uri, err := url.Parse(uriStr) if err != nil { return nil, err } - projectID := os.Getenv("TWOSKY_PROJECT_ID") - if projectID == "" { - projectID = defaultProjectID - } + projectID := cmp.Or(os.Getenv("TWOSKY_PROJECT_ID"), defaultProjectID) baseLang := t.BaseLangcode uLangStr := os.Getenv("UPLOAD_LANGUAGE")