From 4f5131f423f0f16e2d71e97e7eb07b55bd4897e1 Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Thu, 2 Jun 2022 17:55:48 +0300 Subject: [PATCH] all: sync more --- CHANGELOG.md | 130 ++++----- internal/aghnet/subnetdetector.go | 156 ----------- internal/aghnet/subnetdetector_test.go | 252 ------------------ internal/aghtest/upstream.go | 83 ++---- internal/dhcpd/dhcpd.go | 53 ++-- internal/dhcpd/dhcpd_test.go | 22 +- internal/dhcpd/http.go | 15 +- internal/dnsforward/config.go | 5 +- internal/dnsforward/dns.go | 19 +- internal/dnsforward/dns_test.go | 26 +- internal/dnsforward/dnsforward.go | 20 +- internal/dnsforward/dnsforward_test.go | 112 ++++---- internal/dnsforward/filter.go | 41 +-- internal/dnsforward/filter_test.go | 154 +++++++++++ internal/dnsforward/http.go | 239 ++++++++++------- internal/dnsforward/http_test.go | 62 ++++- .../TestDNSForwardHTTP_handleSetConfig.json | 37 +++ internal/filtering/filtering.go | 21 +- internal/home/clients.go | 78 +++--- internal/home/clients_test.go | 26 +- internal/home/config.go | 45 ++-- internal/home/controlinstall.go | 16 ++ internal/home/dns.go | 64 ++++- internal/home/home.go | 23 +- internal/home/options.go | 12 +- internal/home/rdns_test.go | 2 +- internal/home/upgrade.go | 114 +++++++- internal/home/upgrade_test.go | 130 ++++++++- internal/home/web.go | 9 +- openapi/CHANGELOG.md | 4 +- openapi/openapi.yaml | 3 + 31 files changed, 1073 insertions(+), 900 deletions(-) delete mode 100644 internal/aghnet/subnetdetector.go delete mode 100644 internal/aghnet/subnetdetector_test.go create mode 100644 internal/dnsforward/filter_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 55a58c6f..0bdbd9e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,6 @@ and this project adheres to ### Security -- Enforced password strength policy ([#3503]). - Weaker cipher suites that use the CBC (cipher block chaining) mode of operation have been disabled ([#2993]). @@ -25,15 +24,65 @@ and this project adheres to - Support for Discovery of Designated Resolvers (DDR) according to the [RFC draft][ddr-draft-06] ([#4463]). +- `windows/arm64` support ([#3057]). + +### Deprecated + +- Go 1.17 support. v0.109.0 will require at least Go 1.18 to build. + +[#2993]: https://github.com/AdguardTeam/AdGuardHome/issues/2993 +[#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057 + +[ddr-draft-06]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html + + + + + + + +## [v0.107.7] - 2022-06-06 (APPROX.) + +See also the [v0.107.7 GitHub milestone][ms-v0.107.7]. + +### Security + +- Go version was updated to prevent the possibility of exploiting the + [CVE-2022-29526], [CVE-2022-30634], [CVE-2022-30629], [CVE-2022-30580], and + [CVE-2022-29804] vulnerabilities. +- Enforced password strength policy ([#3503]). + +### Added + +- Support for the final DNS-over-QUIC standard, [RFC 9250][rfc-9250] ([#4592]). +- Support upstreams for subdomains of a domain only ([#4503]). - The ability to control each source of runtime clients separately via `clients.runtime_sources` configuration object ([#3020]). - The ability to customize the set of networks that are considered private through the new `dns.private_networks` property in the configuration file ([#3142]). -- `windows/arm64` support ([#3057]). +- EDNS Client-Subnet information in the request details section of a query log + record ([#3978]). +- Support for hostnames for plain UDP upstream servers using the `udp://` scheme + ([#4166]). +- Logs are now collected by default on FreeBSD and OpenBSD when AdGuard Home is + installed as a service ([#4213]). ### Changed +- On OpenBSD, the daemon script now uses the recommended `/bin/ksh` shell + instead of the `/bin/sh` one ([#4533]). To apply this change, backup your + data and run `AdGuardHome -s uninstall && AdGuardHome -s install`. +- The default DNS-over-QUIC port number is now `853` instead of `754` in + accordance with [RFC 9250][rfc-9250] ([#4276]). +- Reverse DNS now has a greater priority as the source of runtime clients' + information than ARP neighborhood. +- Improved detection of runtime clients through more resilient ARP processing + ([#3597]). +- The TTL of responses served from the optimistic cache is now lowered to 10 + seconds. - Domain-specific private reverse DNS upstream servers are now validated to allow only `*.in-addr.arpa` and `*.ip6.arpa` domains pointing to locally-served networks ([#3381]). **Note:** If you already have invalid @@ -41,8 +90,15 @@ and this project adheres to essentially had no effect. - Response filtering is now performed using the record types of the answer section of messages as opposed to the type of the question ([#4238]). +- Instead of adding the build time information, the build scripts now use the + standardized environment variable [`SOURCE_DATE_EPOCH`][repr] to add the date + of the commit from which the binary was built ([#4221]). This should simplify + reproducible builds for package maintainers and those who compile their own + AdGuard Home. - The property `local_domain_name` is now in the `dhcp` object in the configuration file to avoid confusion ([#3367]). +- The `dns.bogus_nxdomain` property in the configuration file now supports CIDR + notation alongside IP addresses ([#1730]). #### Configuration Changes @@ -96,71 +152,9 @@ In this release, the schema version has changed from 12 to 14. ### Deprecated -- The `--no-etc-hosts` option. Its' functionality is now controlled by +- The `--no-etc-hosts` option. Its functionality is now controlled by `clients.runtime_sources.hosts` configuration property. v0.109.0 will remove the flag completely. -- Go 1.17 support. v0.109.0 will require at least Go 1.18 to build. - -[#2993]: https://github.com/AdguardTeam/AdGuardHome/issues/2993 -[#3020]: https://github.com/AdguardTeam/AdGuardHome/issues/3020 -[#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057 -[#3142]: https://github.com/AdguardTeam/AdGuardHome/issues/3142 -[#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367 -[#3381]: https://github.com/AdguardTeam/AdGuardHome/issues/3381 -[#3503]: https://github.com/AdguardTeam/AdGuardHome/issues/3503 -[#4238]: https://github.com/AdguardTeam/AdGuardHome/issues/4238 - -[ddr-draft-06]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html - - - - - - - -## [v0.107.7] - 2022-06-03 (APPROX.) - -See also the [v0.107.7 GitHub milestone][ms-v0.107.7]. - -### Security - -- Go version was updated to prevent the possibility of exploiting the - [CVE-2022-29526], [CVE-2022-30634], [CVE-2022-30629], [CVE-2022-30580], and - [CVE-2022-29804] vulnerabilities. - -### Added - -- Support for the final DNS-over-QUIC standard, [RFC 9250][rfc-9250] ([#4592]). -- Support upstreams for subdomains of a domain only ([#4503]). -- EDNS Client-Subnet information in the request details section of a query log - record ([#3978]). -- Support for hostnames for plain UDP upstream servers using the `udp://` scheme - ([#4166]). -- Logs are now collected by default on FreeBSD and OpenBSD when AdGuard Home is - installed as a service ([#4213]). - -### Changed - -- On OpenBSD, the daemon script now uses the recommended `/bin/ksh` shell - instead of the `/bin/sh` one ([#4533]). To apply this change, backup your - data and run `AdGuardHome -s uninstall && AdGuardHome -s install`. -- The default DNS-over-QUIC port number is now `853` instead of `754` in - accordance with [RFC 9250][rfc-9250] ([#4276]). -- Reverse DNS now has a greater priority as the source of runtime clients' - information than ARP neighborhood. -- Improved detection of runtime clients through more resilient ARP processing - ([#3597]). -- The TTL of responses served from the optimistic cache is now lowered to 10 - seconds. -- Instead of adding the build time information, the build scripts now use the - standardized environment variable [`SOURCE_DATE_EPOCH`][repr] to add the date - of the commit from which the binary was built ([#4221]). This should simplify - reproducible builds for package maintainers and those who compile their own - AdGuard Home. -- The `dns.bogus_nxdomain` property in the configuration file now supports CIDR - notation alongside IP addresses ([#1730]). ### Fixed @@ -172,12 +166,18 @@ See also the [v0.107.7 GitHub milestone][ms-v0.107.7]. - ARP tables refreshing process causing excessive PTR requests ([#3157]). [#1730]: https://github.com/AdguardTeam/AdGuardHome/issues/1730 +[#3020]: https://github.com/AdguardTeam/AdGuardHome/issues/3020 +[#3142]: https://github.com/AdguardTeam/AdGuardHome/issues/3142 [#3157]: https://github.com/AdguardTeam/AdGuardHome/issues/3157 +[#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367 +[#3381]: https://github.com/AdguardTeam/AdGuardHome/issues/3381 +[#3503]: https://github.com/AdguardTeam/AdGuardHome/issues/3503 [#3597]: https://github.com/AdguardTeam/AdGuardHome/issues/3597 [#3978]: https://github.com/AdguardTeam/AdGuardHome/issues/3978 [#4166]: https://github.com/AdguardTeam/AdGuardHome/issues/4166 [#4213]: https://github.com/AdguardTeam/AdGuardHome/issues/4213 [#4221]: https://github.com/AdguardTeam/AdGuardHome/issues/4221 +[#4238]: https://github.com/AdguardTeam/AdGuardHome/issues/4238 [#4273]: https://github.com/AdguardTeam/AdGuardHome/issues/4273 [#4276]: https://github.com/AdguardTeam/AdGuardHome/issues/4276 [#4480]: https://github.com/AdguardTeam/AdGuardHome/issues/4480 diff --git a/internal/aghnet/subnetdetector.go b/internal/aghnet/subnetdetector.go deleted file mode 100644 index e353903b..00000000 --- a/internal/aghnet/subnetdetector.go +++ /dev/null @@ -1,156 +0,0 @@ -package aghnet - -import ( - "net" -) - -// SubnetDetector describes IP address properties. -type SubnetDetector struct { - // spNets is the slice of special-purpose address registries as defined - // by RFC-6890 (https://tools.ietf.org/html/rfc6890). - spNets []*net.IPNet - - // locServedNets is the slice of locally-served networks as defined by - // RFC-6303 (https://tools.ietf.org/html/rfc6303). - locServedNets []*net.IPNet -} - -// NewSubnetDetector returns a new IP detector. -func NewSubnetDetector() (snd *SubnetDetector, err error) { - spNets := []string{ - // "This" network. - "0.0.0.0/8", - // Private-Use Networks. - "10.0.0.0/8", - // Shared Address Space. - "100.64.0.0/10", - // Loopback. - "127.0.0.0/8", - // Link Local. - "169.254.0.0/16", - // Private-Use Networks. - "172.16.0.0/12", - // IETF Protocol Assignments. - "192.0.0.0/24", - // DS-Lite. - "192.0.0.0/29", - // TEST-NET-1 - "192.0.2.0/24", - // 6to4 Relay Anycast. - "192.88.99.0/24", - // Private-Use Networks. - "192.168.0.0/16", - // Network Interconnect Device Benchmark Testing. - "198.18.0.0/15", - // TEST-NET-2. - "198.51.100.0/24", - // TEST-NET-3. - "203.0.113.0/24", - // Reserved for Future Use. - "240.0.0.0/4", - // Limited Broadcast. - "255.255.255.255/32", - - // Loopback. - "::1/128", - // Unspecified. - "::/128", - // IPv4-IPv6 Translation Address. - "64:ff9b::/96", - - // IPv4-Mapped Address. Since this network is used for mapping - // IPv4 addresses, we don't include it. - // "::ffff:0:0/96", - - // Discard-Only Prefix. - "100::/64", - // IETF Protocol Assignments. - "2001::/23", - // TEREDO. - "2001::/32", - // Benchmarking. - "2001:2::/48", - // Documentation. - "2001:db8::/32", - // ORCHID. - "2001:10::/28", - // 6to4. - "2002::/16", - // Unique-Local. - "fc00::/7", - // Linked-Scoped Unicast. - "fe80::/10", - } - - // TODO(e.burkov): It's a subslice of the slice above. Should be done - // smarter. - locServedNets := []string{ - // IPv4. - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - "127.0.0.0/8", - "169.254.0.0/16", - "192.0.2.0/24", - "198.51.100.0/24", - "203.0.113.0/24", - "255.255.255.255/32", - // IPv6. - "::/128", - "::1/128", - "fe80::/10", - "2001:db8::/32", - "fd00::/8", - } - - snd = &SubnetDetector{ - spNets: make([]*net.IPNet, len(spNets)), - locServedNets: make([]*net.IPNet, len(locServedNets)), - } - for i, ipnetStr := range spNets { - var ipnet *net.IPNet - _, ipnet, err = net.ParseCIDR(ipnetStr) - if err != nil { - return nil, err - } - - snd.spNets[i] = ipnet - } - for i, ipnetStr := range locServedNets { - var ipnet *net.IPNet - _, ipnet, err = net.ParseCIDR(ipnetStr) - if err != nil { - return nil, err - } - - snd.locServedNets[i] = ipnet - } - - return snd, nil -} - -// anyNetContains ranges through the given ipnets slice searching for the one -// which contains the ip. For internal use only. -// -// TODO(e.burkov): Think about memoization. -func anyNetContains(ipnets *[]*net.IPNet, ip net.IP) (is bool) { - for _, ipnet := range *ipnets { - if ipnet.Contains(ip) { - return true - } - } - - return false -} - -// IsSpecialNetwork returns true if IP address is contained by any of -// special-purpose IP address registries. It's safe for concurrent use. -func (snd *SubnetDetector) IsSpecialNetwork(ip net.IP) (is bool) { - return anyNetContains(&snd.spNets, ip) -} - -// IsLocallyServedNetwork returns true if IP address is contained by any of -// locally-served IP address registries. It's safe for concurrent use. -func (snd *SubnetDetector) IsLocallyServedNetwork(ip net.IP) (is bool) { - return anyNetContains(&snd.locServedNets, ip) -} diff --git a/internal/aghnet/subnetdetector_test.go b/internal/aghnet/subnetdetector_test.go deleted file mode 100644 index f4b7678c..00000000 --- a/internal/aghnet/subnetdetector_test.go +++ /dev/null @@ -1,252 +0,0 @@ -package aghnet - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSubnetDetector_DetectSpecialNetwork(t *testing.T) { - snd, err := NewSubnetDetector() - require.NoError(t, err) - - testCases := []struct { - name string - ip net.IP - want bool - }{{ - name: "not_specific", - ip: net.ParseIP("8.8.8.8"), - want: false, - }, { - name: "this_host_on_this_network", - ip: net.ParseIP("0.0.0.0"), - want: true, - }, { - name: "private-Use", - ip: net.ParseIP("10.0.0.0"), - want: true, - }, { - name: "shared_address_space", - ip: net.ParseIP("100.64.0.0"), - want: true, - }, { - name: "loopback", - ip: net.ParseIP("127.0.0.0"), - want: true, - }, { - name: "link_local", - ip: net.ParseIP("169.254.0.0"), - want: true, - }, { - name: "private-use", - ip: net.ParseIP("172.16.0.0"), - want: true, - }, { - name: "ietf_protocol_assignments", - ip: net.ParseIP("192.0.0.0"), - want: true, - }, { - name: "ds-lite", - ip: net.ParseIP("192.0.0.0"), - want: true, - }, { - name: "documentation_(test-net-1)", - ip: net.ParseIP("192.0.2.0"), - want: true, - }, { - name: "6to4_relay_anycast", - ip: net.ParseIP("192.88.99.0"), - want: true, - }, { - name: "private-use", - ip: net.ParseIP("192.168.0.0"), - want: true, - }, { - name: "benchmarking", - ip: net.ParseIP("198.18.0.0"), - want: true, - }, { - name: "documentation_(test-net-2)", - ip: net.ParseIP("198.51.100.0"), - want: true, - }, { - name: "documentation_(test-net-3)", - ip: net.ParseIP("203.0.113.0"), - want: true, - }, { - name: "reserved", - ip: net.ParseIP("240.0.0.0"), - want: true, - }, { - name: "limited_broadcast", - ip: net.ParseIP("255.255.255.255"), - want: true, - }, { - name: "loopback_address", - ip: net.ParseIP("::1"), - want: true, - }, { - name: "unspecified_address", - ip: net.ParseIP("::"), - want: true, - }, { - name: "ipv4-ipv6_translation", - ip: net.ParseIP("64:ff9b::"), - want: true, - }, { - name: "discard-only_address_block", - ip: net.ParseIP("100::"), - want: true, - }, { - name: "ietf_protocol_assignments", - ip: net.ParseIP("2001::"), - want: true, - }, { - name: "teredo", - ip: net.ParseIP("2001::"), - want: true, - }, { - name: "benchmarking", - ip: net.ParseIP("2001:2::"), - want: true, - }, { - name: "documentation", - ip: net.ParseIP("2001:db8::"), - want: true, - }, { - name: "orchid", - ip: net.ParseIP("2001:10::"), - want: true, - }, { - name: "6to4", - ip: net.ParseIP("2002::"), - want: true, - }, { - name: "unique-local", - ip: net.ParseIP("fc00::"), - want: true, - }, { - name: "linked-scoped_unicast", - ip: net.ParseIP("fe80::"), - want: true, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, snd.IsSpecialNetwork(tc.ip)) - }) - } -} - -func TestSubnetDetector_DetectLocallyServedNetwork(t *testing.T) { - snd, err := NewSubnetDetector() - require.NoError(t, err) - - testCases := []struct { - name string - ip net.IP - want bool - }{{ - name: "not_specific", - ip: net.ParseIP("8.8.8.8"), - want: false, - }, { - name: "private-Use", - ip: net.ParseIP("10.0.0.0"), - want: true, - }, { - name: "loopback", - ip: net.ParseIP("127.0.0.0"), - want: true, - }, { - name: "link_local", - ip: net.ParseIP("169.254.0.0"), - want: true, - }, { - name: "private-use", - ip: net.ParseIP("172.16.0.0"), - want: true, - }, { - name: "documentation_(test-net-1)", - ip: net.ParseIP("192.0.2.0"), - want: true, - }, { - name: "private-use", - ip: net.ParseIP("192.168.0.0"), - want: true, - }, { - name: "documentation_(test-net-2)", - ip: net.ParseIP("198.51.100.0"), - want: true, - }, { - name: "documentation_(test-net-3)", - ip: net.ParseIP("203.0.113.0"), - want: true, - }, { - name: "limited_broadcast", - ip: net.ParseIP("255.255.255.255"), - want: true, - }, { - name: "loopback_address", - ip: net.ParseIP("::1"), - want: true, - }, { - name: "unspecified_address", - ip: net.ParseIP("::"), - want: true, - }, { - name: "documentation", - ip: net.ParseIP("2001:db8::"), - want: true, - }, { - name: "linked-scoped_unicast", - ip: net.ParseIP("fe80::"), - want: true, - }, { - name: "locally_assigned", - ip: net.ParseIP("fd00::1"), - want: true, - }, { - name: "not_locally_assigned", - ip: net.ParseIP("fc00::1"), - want: false, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, snd.IsLocallyServedNetwork(tc.ip)) - }) - } -} - -func TestSubnetDetector_Detect_parallel(t *testing.T) { - t.Parallel() - - snd, err := NewSubnetDetector() - require.NoError(t, err) - - testFunc := func() { - for _, ip := range []net.IP{ - net.IPv4allrouter, - net.IPv4allsys, - net.IPv4bcast, - net.IPv4zero, - net.IPv6interfacelocalallnodes, - net.IPv6linklocalallnodes, - net.IPv6linklocalallrouters, - net.IPv6loopback, - net.IPv6unspecified, - } { - _ = snd.IsSpecialNetwork(ip) - _ = snd.IsLocallyServedNetwork(ip) - } - } - - const goroutinesNum = 50 - for i := 0; i < goroutinesNum; i++ { - go testFunc() - } -} diff --git a/internal/aghtest/upstream.go b/internal/aghtest/upstream.go index aa364310..95d8f5ad 100644 --- a/internal/aghtest/upstream.go +++ b/internal/aghtest/upstream.go @@ -11,10 +11,10 @@ import ( "github.com/miekg/dns" ) -// TestUpstream is a mock of real upstream. -type TestUpstream struct { +// Upstream is a mock implementation of upstream.Upstream. +type Upstream struct { // CName is a map of hostname to canonical name. - CName map[string]string + CName map[string][]string // IPv4 is a map of hostname to IPv4. IPv4 map[string][]net.IP // IPv6 is a map of hostname to IPv6. @@ -25,78 +25,45 @@ type TestUpstream struct { Addr string } -// Exchange implements upstream.Upstream interface for *TestUpstream. +// Exchange implements the upstream.Upstream interface for *Upstream. // // TODO(a.garipov): Split further into handlers. -func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { - resp = &dns.Msg{} - resp.SetReply(m) +func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { + resp = new(dns.Msg).SetReply(m) if len(m.Question) == 0 { return nil, fmt.Errorf("question should not be empty") } - name := m.Question[0].Name - - if cname, ok := u.CName[name]; ok { - ans := &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: name, - Rrtype: dns.TypeCNAME, - }, + q := m.Question[0] + name := q.Name + for _, cname := range u.CName[name] { + resp.Answer = append(resp.Answer, &dns.CNAME{ + Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME}, Target: cname, - } - - resp.Answer = append(resp.Answer, ans) + }) } - rrType := m.Question[0].Qtype + qtype := q.Qtype hdr := dns.RR_Header{ Name: name, - Rrtype: rrType, + Rrtype: qtype, } - var names []string - var ips []net.IP - switch m.Question[0].Qtype { + switch qtype { case dns.TypeA: - ips = u.IPv4[name] + for _, ip := range u.IPv4[name] { + resp.Answer = append(resp.Answer, &dns.A{Hdr: hdr, A: ip}) + } case dns.TypeAAAA: - ips = u.IPv6[name] + for _, ip := range u.IPv6[name] { + resp.Answer = append(resp.Answer, &dns.AAAA{Hdr: hdr, AAAA: ip}) + } case dns.TypePTR: - names = u.Reverse[name] - } - - for _, ip := range ips { - var ans dns.RR - if rrType == dns.TypeA { - ans = &dns.A{ - Hdr: hdr, - A: ip, - } - - resp.Answer = append(resp.Answer, ans) - - continue + for _, name := range u.Reverse[name] { + resp.Answer = append(resp.Answer, &dns.PTR{Hdr: hdr, Ptr: name}) } - - ans = &dns.AAAA{ - Hdr: hdr, - AAAA: ip, - } - - resp.Answer = append(resp.Answer, ans) } - - for _, n := range names { - ans := &dns.PTR{ - Hdr: hdr, - Ptr: n, - } - - resp.Answer = append(resp.Answer, ans) - } - if len(resp.Answer) == 0 { resp.SetRcode(m, dns.RcodeNameError) } @@ -104,8 +71,8 @@ func (u *TestUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return resp, nil } -// Address implements upstream.Upstream interface for *TestUpstream. -func (u *TestUpstream) Address() string { +// Address implements upstream.Upstream interface for *Upstream. +func (u *Upstream) Address() string { return u.Addr } diff --git a/internal/dhcpd/dhcpd.go b/internal/dhcpd/dhcpd.go index 341aa3d6..55c56c18 100644 --- a/internal/dhcpd/dhcpd.go +++ b/internal/dhcpd/dhcpd.go @@ -119,23 +119,28 @@ func (l *Lease) UnmarshalJSON(data []byte) (err error) { return nil } -// ServerConfig - DHCP server configuration -// field ordering is important -- yaml fields will mirror ordering from here +// ServerConfig is the configuration for the DHCP server. The order of YAML +// fields is important, since the YAML configuration file follows it. type ServerConfig struct { - Enabled bool `yaml:"enabled"` - InterfaceName string `yaml:"interface_name"` - - Conf4 V4ServerConf `yaml:"dhcpv4"` - Conf6 V6ServerConf `yaml:"dhcpv6"` - - WorkDir string `yaml:"-"` - DBFilePath string `yaml:"-"` // path to DB file - // Called when the configuration is changed by HTTP request ConfigModified func() `yaml:"-"` // Register an HTTP handler HTTPRegister func(string, string, func(http.ResponseWriter, *http.Request)) `yaml:"-"` + + Enabled bool `yaml:"enabled"` + InterfaceName string `yaml:"interface_name"` + + // LocalDomainName is the domain name used for DHCP hosts. For example, + // a DHCP client with the hostname "myhost" can be addressed as "myhost.lan" + // when LocalDomainName is "lan". + LocalDomainName string `yaml:"local_domain_name"` + + Conf4 V4ServerConf `yaml:"dhcpv4"` + Conf6 V6ServerConf `yaml:"dhcpv6"` + + WorkDir string `yaml:"-"` + DBFilePath string `yaml:"-"` } // OnLeaseChangedT is a callback for lease changes. @@ -156,7 +161,9 @@ type Server struct { srv4 DHCPServer srv6 DHCPServer - conf ServerConfig + // TODO(a.garipov): Either create a separate type for the internal config or + // just put the config values into Server. + conf *ServerConfig // Called when the leases DB is modified onLeaseChanged []OnLeaseChangedT @@ -181,14 +188,21 @@ type ServerInterface interface { } // Create - create object -func Create(conf ServerConfig) (s *Server, err error) { - s = &Server{} +func Create(conf *ServerConfig) (s *Server, err error) { + s = &Server{ + conf: &ServerConfig{ + ConfigModified: conf.ConfigModified, - s.conf.Enabled = conf.Enabled - s.conf.InterfaceName = conf.InterfaceName - s.conf.HTTPRegister = conf.HTTPRegister - s.conf.ConfigModified = conf.ConfigModified - s.conf.DBFilePath = filepath.Join(conf.WorkDir, dbFilename) + HTTPRegister: conf.HTTPRegister, + + Enabled: conf.Enabled, + InterfaceName: conf.InterfaceName, + + LocalDomainName: conf.LocalDomainName, + + DBFilePath: filepath.Join(conf.WorkDir, dbFilename), + }, + } if !webHandlersRegistered && s.conf.HTTPRegister != nil { if runtime.GOOS == "windows" { @@ -305,6 +319,7 @@ func (s *Server) notify(flags int) { func (s *Server) WriteDiskConfig(c *ServerConfig) { c.Enabled = s.conf.Enabled c.InterfaceName = s.conf.InterfaceName + c.LocalDomainName = s.conf.LocalDomainName s.srv4.WriteDiskConfig4(&c.Conf4) s.srv6.WriteDiskConfig6(&c.Conf6) } diff --git a/internal/dhcpd/dhcpd_test.go b/internal/dhcpd/dhcpd_test.go index b8fc5fa0..b704cbb4 100644 --- a/internal/dhcpd/dhcpd_test.go +++ b/internal/dhcpd/dhcpd_test.go @@ -27,7 +27,7 @@ func testNotify(flags uint32) { func TestDB(t *testing.T) { var err error s := Server{ - conf: ServerConfig{ + conf: &ServerConfig{ DBFilePath: dbFilename, }, } @@ -140,27 +140,27 @@ func TestNormalizeLeases(t *testing.T) { func TestV4Server_badRange(t *testing.T) { testCases := []struct { name string + wantErrMsg string gatewayIP net.IP subnetMask net.IP - wantErrMsg string }{{ - name: "gateway_in_range", - gatewayIP: net.IP{192, 168, 10, 120}, - subnetMask: net.IP{255, 255, 255, 0}, + name: "gateway_in_range", wantErrMsg: "dhcpv4: gateway ip 192.168.10.120 in the ip range: " + "192.168.10.20-192.168.10.200", + gatewayIP: net.IP{192, 168, 10, 120}, + subnetMask: net.IP{255, 255, 255, 0}, }, { - name: "outside_range_start", - gatewayIP: net.IP{192, 168, 10, 1}, - subnetMask: net.IP{255, 255, 255, 240}, + name: "outside_range_start", wantErrMsg: "dhcpv4: range start 192.168.10.20 is outside network " + "192.168.10.1/28", - }, { - name: "outside_range_end", gatewayIP: net.IP{192, 168, 10, 1}, - subnetMask: net.IP{255, 255, 255, 224}, + subnetMask: net.IP{255, 255, 255, 240}, + }, { + name: "outside_range_end", wantErrMsg: "dhcpv4: range end 192.168.10.200 is outside network " + "192.168.10.1/27", + gatewayIP: net.IP{192, 168, 10, 1}, + subnetMask: net.IP{255, 255, 255, 224}, }} for _, tc := range testCases { diff --git a/internal/dhcpd/http.go b/internal/dhcpd/http.go index 78016010..e340addb 100644 --- a/internal/dhcpd/http.go +++ b/internal/dhcpd/http.go @@ -575,12 +575,15 @@ func (s *Server) handleReset(w http.ResponseWriter, r *http.Request) { log.Error("dhcp: removing db: %s", err) } - oldconf := s.conf - s.conf = ServerConfig{ - WorkDir: oldconf.WorkDir, - HTTPRegister: oldconf.HTTPRegister, - ConfigModified: oldconf.ConfigModified, - DBFilePath: oldconf.DBFilePath, + s.conf = &ServerConfig{ + ConfigModified: s.conf.ConfigModified, + + HTTPRegister: s.conf.HTTPRegister, + + LocalDomainName: s.conf.LocalDomainName, + + WorkDir: s.conf.WorkDir, + DBFilePath: s.conf.DBFilePath, } v4conf := V4ServerConf{ diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 1617e339..b44eab1f 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -132,8 +132,9 @@ type FilteringConfig struct { // TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS type TLSConfig struct { - TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"` - QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"` + TLSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"` + QUICListenAddrs []*net.UDPAddr `yaml:"-" json:"-"` + HTTPSListenAddrs []*net.TCPAddr `yaml:"-" json:"-"` // Reject connection if the client uses server name (in SNI) that doesn't match the certificate StrictSNICheck bool `yaml:"strict_sni_check" json:"-"` diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 5e0b8293..d423482a 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -214,9 +214,8 @@ func (s *Server) onDHCPLeaseChanged(flags int) { ipToHost = netutil.NewIPMap(len(ll)) for _, l := range ll { - // TODO(a.garipov): Remove this after we're finished - // with the client hostname validations in the DHCP - // server code. + // TODO(a.garipov): Remove this after we're finished with the client + // hostname validations in the DHCP server code. err = netutil.ValidateDomainName(l.Hostname) if err != nil { log.Debug( @@ -252,7 +251,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) { return rc } - dctx.isLocalClient = s.subnetDetector.IsLocallyServedNetwork(ip) + dctx.isLocalClient = s.privateNets.Contains(ip) return rc } @@ -300,6 +299,8 @@ func (s *Server) processInternalHosts(dctx *dnsContext) (rc resultCode) { } reqHost := strings.ToLower(q.Name) + // TODO(a.garipov): Move everything related to DHCP local domain to the DHCP + // server. host := strings.TrimSuffix(reqHost, s.localDomainSuffix) if host == reqHost { return resultCodeSuccess @@ -372,7 +373,7 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) { // Restrict an access to local addresses for external clients. We also // assume that all the DHCP leases we give are locally-served or at least // don't need to be inaccessible externally. - if !s.subnetDetector.IsLocallyServedNetwork(ip) { + if !s.privateNets.Contains(ip) { log.Debug("dns: addr %s is not from locally-served network", ip) return resultCodeSuccess @@ -479,7 +480,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) { s.serverLock.RLock() defer s.serverLock.RUnlock() - if !s.subnetDetector.IsLocallyServedNetwork(ip) { + if !s.privateNets.Contains(ip) { return resultCodeSuccess } @@ -611,9 +612,9 @@ func (s *Server) processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) d.Res.Answer = answer } default: - // Check the response only if the it's from an upstream. Don't check - // the response if the protection is disabled since dnsrewrite rules - // aren't applied to it anyway. + // Check the response only if it's from an upstream. Don't check the + // response if the protection is disabled since dnsrewrite rules aren't + // applied to it anyway. if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil { break } diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index edf54f51..54104268 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -4,35 +4,41 @@ import ( "net" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestServer_ProcessDetermineLocal(t *testing.T) { - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) s := &Server{ - subnetDetector: snd, + privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), } testCases := []struct { + want assert.BoolAssertionFunc name string cliIP net.IP - want bool }{{ + want: assert.True, name: "local", cliIP: net.IP{192, 168, 0, 1}, - want: true, }, { + want: assert.False, name: "external", cliIP: net.IP{250, 249, 0, 1}, - want: false, + }, { + want: assert.False, + name: "invalid", + cliIP: net.IP{1, 2, 3, 4, 5}, + }, { + want: assert.False, + name: "nil", + cliIP: nil, }} for _, tc := range testCases { @@ -47,7 +53,7 @@ func TestServer_ProcessDetermineLocal(t *testing.T) { } s.processDetermineLocal(dctx) - assert.Equal(t, tc.want, dctx.isLocalClient) + tc.want(t, dctx.isLocalClient) }) } } @@ -261,7 +267,7 @@ func TestServer_ProcessInternalHosts(t *testing.T) { } func TestServer_ProcessRestrictLocal(t *testing.T) { - ups := &aghtest.TestUpstream{ + ups := &aghtest.Upstream{ Reverse: map[string][]string{ "251.252.253.254.in-addr.arpa.": {"host1.example.net."}, "1.1.168.192.in-addr.arpa.": {"some.local-client."}, @@ -339,7 +345,7 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { s := createTestServer(t, &filtering.Config{}, ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, - }, &aghtest.TestUpstream{ + }, &aghtest.Upstream{ Reverse: map[string][]string{ reqAddr: {locDomain}, }, diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 3016f133..c0cd0e55 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -74,7 +74,7 @@ type Server struct { localDomainSuffix string ipset ipsetCtx - subnetDetector *aghnet.SubnetDetector + privateNets netutil.SubnetSet localResolvers *proxy.Proxy sysResolvers aghnet.SystemResolvers recDetector *recursionDetector @@ -111,13 +111,13 @@ const defaultLocalDomainSuffix = ".lan." // DNSCreateParams are parameters to create a new server. type DNSCreateParams struct { - DNSFilter *filtering.DNSFilter - Stats stats.Stats - QueryLog querylog.QueryLog - DHCPServer dhcpd.ServerInterface - SubnetDetector *aghnet.SubnetDetector - Anonymizer *aghnet.IPMut - LocalDomain string + DNSFilter *filtering.DNSFilter + Stats stats.Stats + QueryLog querylog.QueryLog + DHCPServer dhcpd.ServerInterface + PrivateNets netutil.SubnetSet + Anonymizer *aghnet.IPMut + LocalDomain string } // domainNameToSuffix converts a domain name into a local domain suffix. @@ -161,7 +161,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { dnsFilter: p.DNSFilter, stats: p.Stats, queryLog: p.QueryLog, - subnetDetector: p.SubnetDetector, + privateNets: p.PrivateNets, localDomainSuffix: localDomainSuffix, recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum), clientIDCache: cache.New(cache.Config{ @@ -315,7 +315,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) { } var resolver *proxy.Proxy - if s.subnetDetector.IsLocallyServedNetwork(ip) { + if s.privateNets.Contains(ip) { if !s.conf.UsePrivateRDNS { return "", nil } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index d0191c85..36761f41 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -24,6 +24,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/miekg/dns" @@ -69,14 +70,11 @@ func createTestServer( f := filtering.New(filterConf, filters) f.SetEnabled(true) - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - + var err error s, err = NewServer(DNSCreateParams{ - DHCPServer: &testDHCP{}, - DNSFilter: f, - SubnetDetector: snd, + DHCPServer: &testDHCP{}, + DNSFilter: f, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) @@ -89,7 +87,7 @@ func createTestServer( defer s.serverLock.Unlock() if localUps != nil { - s.localResolvers.Config.UpstreamConfig.Upstreams = []upstream.Upstream{localUps} + s.localResolvers.UpstreamConfig.Upstreams = []upstream.Upstream{localUps} s.conf.UsePrivateRDNS = true } @@ -247,7 +245,7 @@ func TestServer(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, }, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -316,7 +314,7 @@ func TestServerWithProtectionDisabled(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, }, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -339,7 +337,7 @@ func TestDoTServer(t *testing.T) { TLSListenAddrs: []*net.TCPAddr{{}}, }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -369,7 +367,7 @@ func TestDoQServer(t *testing.T) { QUICListenAddrs: []*net.UDPAddr{{IP: net.IP{127, 0, 0, 1}}}, }) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -413,7 +411,7 @@ func TestServerRace(t *testing.T) { } s := createTestServer(t, filterConf, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ IPv4: map[string][]net.IP{ "google-public-dns-a.google.com.": {{8, 8, 8, 8}}, }, @@ -552,7 +550,7 @@ func TestServerCustomClientUpstream(t *testing.T) { } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s.conf.GetCustomUpstreamByClient = func(_ string) (conf *proxy.UpstreamConfig, err error) { - ups := &aghtest.TestUpstream{ + ups := &aghtest.Upstream{ IPv4: map[string][]net.IP{ "host.": {{192, 168, 0, 1}}, }, @@ -580,9 +578,9 @@ func TestServerCustomClientUpstream(t *testing.T) { } // testCNAMEs is a map of names and CNAMEs necessary for the TestUpstream work. -var testCNAMEs = map[string]string{ - "badhost.": "NULL.example.org.", - "whitelist.example.org.": "NULL.example.org.", +var testCNAMEs = map[string][]string{ + "badhost.": {"NULL.example.org."}, + "whitelist.example.org.": {"NULL.example.org."}, } // testIPv4 is a map of names and IPv4s necessary for the TestUpstream work. @@ -596,7 +594,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) { UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, }, nil) - testUpstm := &aghtest.TestUpstream{ + testUpstm := &aghtest.Upstream{ CName: testCNAMEs, IPv4: testIPv4, IPv6: nil, @@ -630,7 +628,7 @@ func TestBlockCNAME(t *testing.T) { } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ CName: testCNAMEs, IPv4: testIPv4, }, @@ -640,14 +638,17 @@ func TestBlockCNAME(t *testing.T) { addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() testCases := []struct { + name string host string want bool }{{ + name: "block_request", host: "badhost.", // 'badhost' has a canonical name 'NULL.example.org' which is // blocked by filters: response is blocked. want: true, }, { + name: "allowed", host: "whitelist.example.org.", // 'whitelist.example.org' has a canonical name // 'NULL.example.org' which is blocked by filters @@ -655,6 +656,7 @@ func TestBlockCNAME(t *testing.T) { // response isn't blocked. want: false, }, { + name: "block_response", host: "example.org.", // 'example.org' has a canonical name 'cname1' with IP // 127.0.0.255 which is blocked by filters: response is blocked. @@ -662,9 +664,9 @@ func TestBlockCNAME(t *testing.T) { }} for _, tc := range testCases { - t.Run("block_cname_"+tc.host, func(t *testing.T) { - req := createTestMessage(tc.host) + req := createTestMessage(tc.host) + t.Run(tc.name, func(t *testing.T) { reply, err := dns.Exchange(req, addr) require.NoError(t, err) @@ -674,7 +676,7 @@ func TestBlockCNAME(t *testing.T) { ans := reply.Answer[0] a, ok := ans.(*dns.A) - require.Truef(t, ok, "got %T", ans) + require.True(t, ok) assert.True(t, a.A.IsUnspecified()) } @@ -695,7 +697,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) { } s := createTestServer(t, &filtering.Config{}, forwardConf, nil) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ + &aghtest.Upstream{ CName: testCNAMEs, IPv4: testIPv4, }, @@ -766,16 +768,11 @@ func TestBlockedCustomIP(t *testing.T) { Data: []byte(rules), }} - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - f := filtering.New(&filtering.Config{}, filters) - var s *Server - s, err = NewServer(DNSCreateParams{ - DHCPServer: &testDHCP{}, - DNSFilter: f, - SubnetDetector: snd, + s, err := NewServer(DNSCreateParams{ + DHCPServer: &testDHCP{}, + DNSFilter: f, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) @@ -909,15 +906,10 @@ func TestRewrite(t *testing.T) { f := filtering.New(c, nil) f.SetEnabled(true) - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - - var s *Server - s, err = NewServer(DNSCreateParams{ - DHCPServer: &testDHCP{}, - DNSFilter: f, - SubnetDetector: snd, + s, err := NewServer(DNSCreateParams{ + DHCPServer: &testDHCP{}, + DNSFilter: f, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) @@ -931,9 +923,9 @@ func TestRewrite(t *testing.T) { })) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ - &aghtest.TestUpstream{ - CName: map[string]string{ - "example.org": "somename", + &aghtest.Upstream{ + CName: map[string][]string{ + "example.org": {"somename"}, }, IPv4: map[string][]net.IP{ "example.org.": {{4, 3, 2, 1}}, @@ -1024,15 +1016,10 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) { func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} func TestPTRResponseFromDHCPLeases(t *testing.T) { - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - - var s *Server - s, err = NewServer(DNSCreateParams{ - DNSFilter: filtering.New(&filtering.Config{}, nil), - DHCPServer: &testDHCP{}, - SubnetDetector: snd, + s, err := NewServer(DNSCreateParams{ + DNSFilter: filtering.New(&filtering.Config{}, nil), + DHCPServer: &testDHCP{}, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) @@ -1101,16 +1088,11 @@ func TestPTRResponseFromHosts(t *testing.T) { }, nil) flt.SetEnabled(true) - var snd *aghnet.SubnetDetector - snd, err = aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - var s *Server s, err = NewServer(DNSCreateParams{ - DHCPServer: &testDHCP{}, - DNSFilter: flt, - SubnetDetector: snd, + DHCPServer: &testDHCP{}, + DNSFilter: flt, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) @@ -1193,12 +1175,12 @@ func TestNewServer(t *testing.T) { } func TestServer_Exchange(t *testing.T) { - extUpstream := &aghtest.TestUpstream{ + extUpstream := &aghtest.Upstream{ Reverse: map[string][]string{ "1.1.1.1.in-addr.arpa.": {"one.one.one.one"}, }, } - locUpstream := &aghtest.TestUpstream{ + locUpstream := &aghtest.Upstream{ Reverse: map[string][]string{ "1.1.168.192.in-addr.arpa.": {"local.domain"}, "2.1.168.192.in-addr.arpa.": {}, @@ -1223,9 +1205,7 @@ func TestServer_Exchange(t *testing.T) { srv.conf.ResolveClients = true srv.conf.UsePrivateRDNS = true - var err error - srv.subnetDetector, err = aghnet.NewSubnetDetector() - require.NoError(t, err) + srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) localIP := net.IP{192, 168, 1, 1} testCases := []struct { diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 471b463e..18f12797 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -116,7 +116,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { // checkHostRules checks the host against filters. It is safe for concurrent // use. -func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Settings) ( +func (s *Server) checkHostRules(host string, rrtype uint16, setts *filtering.Settings) ( r *filtering.Result, err error, ) { @@ -128,7 +128,7 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett } var res filtering.Result - res, err = s.dnsFilter.CheckHostRules(host, qtype, setts) + res, err = s.dnsFilter.CheckHostRules(host, rrtype, setts) if err != nil { return nil, err } @@ -136,33 +136,36 @@ func (s *Server) checkHostRules(host string, qtype uint16, setts *filtering.Sett return &res, err } -// If response contains CNAME, A or AAAA records, we apply filtering to each -// canonical host name or IP address. If this is a match, we set a new response -// in d.Res and return. -func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) { +// filterDNSResponse checks each resource record of the response's answer +// section from ctx and returns a non-nil res if at least one of canonnical +// names or IP addresses in it matches the filtering rules. +func (s *Server) filterDNSResponse(ctx *dnsContext) (res *filtering.Result, err error) { d := ctx.proxyCtx + setts := ctx.setts + if !setts.FilteringEnabled { + return nil, nil + } + for _, a := range d.Res.Answer { host := "" - - switch v := a.(type) { + var rrtype uint16 + switch a := a.(type) { case *dns.CNAME: - log.Debug("DNSFwd: Checking CNAME %s for %s", v.Target, v.Hdr.Name) - host = strings.TrimSuffix(v.Target, ".") - + host = strings.TrimSuffix(a.Target, ".") + rrtype = dns.TypeCNAME case *dns.A: - host = v.A.String() - log.Debug("DNSFwd: Checking record A (%s) for %s", host, v.Hdr.Name) - + host = a.A.String() + rrtype = dns.TypeA case *dns.AAAA: - host = v.AAAA.String() - log.Debug("DNSFwd: Checking record AAAA (%s) for %s", host, v.Hdr.Name) - + host = a.AAAA.String() + rrtype = dns.TypeAAAA default: continue } - host = strings.TrimSuffix(host, ".") - res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts) + log.Debug("dnsforward: checking %s %s for %s", dns.Type(rrtype), host, a.Header().Name) + + res, err = s.checkHostRules(host, rrtype, setts) if err != nil { return nil, err } else if res == nil { diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go new file mode 100644 index 00000000..69dcf8f9 --- /dev/null +++ b/internal/dnsforward/filter_test.go @@ -0,0 +1,154 @@ +package dnsforward + +import ( + "net" + "testing" + + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/netutil" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHandleDNSRequest_filterDNSResponse(t *testing.T) { + rules := ` +||blocked.domain^ +@@||allowed.domain^ +||cname.specific^$dnstype=~CNAME +||0.0.0.1^$dnstype=~A +||::1^$dnstype=~AAAA +` + + forwardConf := ServerConfig{ + UDPListenAddrs: []*net.UDPAddr{{}}, + TCPListenAddrs: []*net.TCPAddr{{}}, + FilteringConfig: FilteringConfig{ + ProtectionEnabled: true, + BlockingMode: BlockingModeDefault, + }, + } + filters := []filtering.Filter{{ + ID: 0, Data: []byte(rules), + }} + + f := filtering.New(&filtering.Config{}, filters) + f.SetEnabled(true) + + s, err := NewServer(DNSCreateParams{ + DHCPServer: &testDHCP{}, + DNSFilter: f, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + }) + require.NoError(t, err) + + s.conf = forwardConf + err = s.Prepare(nil) + require.NoError(t, err) + + s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ + &aghtest.Upstream{ + CName: map[string][]string{ + "cname.exception.": {"cname.specific."}, + "should.block.": {"blocked.domain."}, + "allowed.first.": {"allowed.domain.", "blocked.domain."}, + "blocked.first.": {"blocked.domain.", "allowed.domain."}, + }, + IPv4: map[string][]net.IP{ + "a.exception.": {{0, 0, 0, 1}}, + }, + IPv6: map[string][]net.IP{ + "aaaa.exception.": {net.ParseIP("::1")}, + }, + }, + } + startDeferStop(t, s) + + testCases := []struct { + req *dns.Msg + name string + wantAns []dns.RR + }{{ + req: createTestMessage("cname.exception."), + name: "cname_exception", + wantAns: []dns.RR{&dns.CNAME{ + Hdr: dns.RR_Header{ + Name: "cname.exception.", + Rrtype: dns.TypeCNAME, + }, + Target: "cname.specific.", + }}, + }, { + req: createTestMessage("should.block."), + name: "blocked_by_cname", + wantAns: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "should.block.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: netutil.IPv4Zero(), + }}, + }, { + req: createTestMessage("a.exception."), + name: "a_exception", + wantAns: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "a.exception.", + Rrtype: dns.TypeA, + }, + A: net.IP{0, 0, 0, 1}, + }}, + }, { + req: createTestMessageWithType("aaaa.exception.", dns.TypeAAAA), + name: "aaaa_exception", + wantAns: []dns.RR{&dns.AAAA{ + Hdr: dns.RR_Header{ + Name: "aaaa.exception.", + Rrtype: dns.TypeAAAA, + }, + AAAA: net.ParseIP("::1"), + }}, + }, { + req: createTestMessage("allowed.first."), + name: "allowed_first", + wantAns: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "allowed.first.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: netutil.IPv4Zero(), + }}, + }, { + req: createTestMessage("blocked.first."), + name: "blocked_first", + wantAns: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Name: "blocked.first.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + }, + A: netutil.IPv4Zero(), + }}, + }} + + for _, tc := range testCases { + dctx := &proxy.DNSContext{ + Proto: proxy.ProtoUDP, + Req: tc.req, + Addr: &net.UDPAddr{IP: net.IP{127, 0, 0, 1}, Port: 1}, + } + + t.Run(tc.name, func(t *testing.T) { + err = s.handleDNSRequest(nil, dctx) + require.NoError(t, err) + require.NotNil(t, dctx.Res) + + assert.Equal(t, tc.wantAns, dctx.Res.Answer) + }) + } +} diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 6c7f58e7..50ab9643 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "sort" "strings" "time" @@ -41,7 +42,7 @@ type dnsConfig struct { LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"` } -func (s *Server) getDNSConfig() dnsConfig { +func (s *Server) getDNSConfig() (c *dnsConfig) { s.serverLock.RLock() defer s.serverLock.RUnlock() @@ -70,7 +71,7 @@ func (s *Server) getDNSConfig() dnsConfig { upstreamMode = "parallel" } - return dnsConfig{ + return &dnsConfig{ Upstreams: &upstreams, UpstreamsFile: &upstreamFile, Bootstraps: &bootstraps, @@ -106,7 +107,7 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { // since there is no need to omit it while decoding from JSON. DefautLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"` }{ - dnsConfig: s.getDNSConfig(), + dnsConfig: *s.getDNSConfig(), DefautLocalPTRUpstreams: defLocalPTRUps, } @@ -138,39 +139,63 @@ func (req *dnsConfig) checkBlockingMode() bool { } func (req *dnsConfig) checkUpstreamsMode() bool { - if req.UpstreamMode == nil { - return true - } + valid := []string{"", "fastest_addr", "parallel"} - for _, valid := range []string{ - "", - "fastest_addr", - "parallel", - } { - if *req.UpstreamMode == valid { - return true - } - } - - return false + return req.UpstreamMode == nil || stringutil.InSlice(valid, *req.UpstreamMode) } -func (req *dnsConfig) checkBootstrap() (string, error) { +func (req *dnsConfig) checkBootstrap() (err error) { if req.Bootstraps == nil { - return "", nil + return nil } - for _, boot := range *req.Bootstraps { - if boot == "" { - return boot, fmt.Errorf("invalid bootstrap server address: empty") + var b string + defer func() { err = errors.Annotate(err, "checking bootstrap %s: invalid address: %w", b) }() + + for _, b = range *req.Bootstraps { + if b == "" { + return errors.Error("empty") } - if _, err := upstream.NewResolver(boot, nil); err != nil { - return boot, fmt.Errorf("invalid bootstrap server address: %w", err) + if _, err = upstream.NewResolver(b, nil); err != nil { + return err } } - return "", nil + return nil +} + +// validate returns an error if any field of req is invalid. +func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) { + if req.Upstreams != nil { + err = ValidateUpstreams(*req.Upstreams) + if err != nil { + return fmt.Errorf("validating upstream servers: %w", err) + } + } + + if req.LocalPTRUpstreams != nil { + err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets) + if err != nil { + return fmt.Errorf("validating private upstream servers: %w", err) + } + } + + err = req.checkBootstrap() + if err != nil { + return err + } + + switch { + case !req.checkBlockingMode(): + return errors.Error("blocking_mode: incorrect value") + case !req.checkUpstreamsMode(): + return errors.Error("upstream_mode: incorrect value") + case !req.checkCacheTTL(): + return errors.Error("cache_ttl_min must be less or equal than cache_ttl_max") + default: + return nil + } } func (req *dnsConfig) checkCacheTTL() bool { @@ -190,69 +215,33 @@ func (req *dnsConfig) checkCacheTTL() bool { } func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { - req := dnsConfig{} - err := json.NewDecoder(r.Body).Decode(&req) + req := &dnsConfig{} + err := json.NewDecoder(r.Body).Decode(req) if err != nil { - aghhttp.Error(r, w, http.StatusBadRequest, "json Encode: %s", err) + aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err) return } - if req.Upstreams != nil { - if err = ValidateUpstreams(*req.Upstreams); err != nil { - aghhttp.Error(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err) - - return - } - } - - var errBoot string - if errBoot, err = req.checkBootstrap(); err != nil { - aghhttp.Error( - r, - w, - http.StatusBadRequest, - "%s can not be used as bootstrap dns cause: %s", - errBoot, - err, - ) + err = req.validate(s.privateNets) + if err != nil { + aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - switch { - case !req.checkBlockingMode(): - aghhttp.Error(r, w, http.StatusBadRequest, "blocking_mode: incorrect value") - - return - case !req.checkUpstreamsMode(): - aghhttp.Error(r, w, http.StatusBadRequest, "upstream_mode: incorrect value") - - return - case !req.checkCacheTTL(): - aghhttp.Error( - r, - w, - http.StatusBadRequest, - "cache_ttl_min must be less or equal than cache_ttl_max", - ) - - return - default: - // Go on. - } - restart := s.setConfig(req) s.conf.ConfigModified() if restart { - if err = s.Reconfigure(nil); err != nil { + err = s.Reconfigure(nil) + if err != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) } } } -func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) { +func (s *Server) setConfigRestartable(dc *dnsConfig) (restart bool) { if dc.Upstreams != nil { s.conf.UpstreamDNS = *dc.Upstreams restart = true @@ -273,9 +262,9 @@ func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) { restart = true } - if dc.RateLimit != nil { - restart = restart || s.conf.Ratelimit != *dc.RateLimit + if dc.RateLimit != nil && s.conf.Ratelimit != *dc.RateLimit { s.conf.Ratelimit = *dc.RateLimit + restart = true } if dc.EDNSCSEnabled != nil { @@ -306,7 +295,7 @@ func (s *Server) setConfigRestartable(dc dnsConfig) (restart bool) { return restart } -func (s *Server) setConfig(dc dnsConfig) (restart bool) { +func (s *Server) setConfig(dc *dnsConfig) (restart bool) { s.serverLock.Lock() defer s.serverLock.Unlock() @@ -353,52 +342,106 @@ type upstreamJSON struct { PrivateUpstreams []string `json:"private_upstream"` } -// IsCommentOrEmpty returns true of the string starts with a "#" character or is -// an empty string. This function is useful for filtering out non-upstream -// lines from upstream configs. +// IsCommentOrEmpty returns true if s starts with a "#" character or is empty. +// This function is useful for filtering out non-upstream lines from upstream +// configs. func IsCommentOrEmpty(s string) (ok bool) { return len(s) == 0 || s[0] == '#' } +// newUpstreamConfig validates upstreams and returns an appropriate upstream +// configuration or nil if it can't be built. +// +// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams +// slice already so that this function may be considered useless. +func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) { + // No need to validate comments and empty lines. + upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty) + if len(upstreams) == 0 { + // Consider this case valid since it means the default server should be + // used. + return nil, nil + } + + conf, err = proxy.ParseUpstreamsConfig( + upstreams, + &upstream.Options{Bootstrap: []string{}, Timeout: DefaultTimeout}, + ) + if err != nil { + return nil, err + } else if len(conf.Upstreams) == 0 { + return nil, errors.Error("no default upstreams specified") + } + + for _, u := range upstreams { + _, err = validateUpstream(u) + if err != nil { + return nil, err + } + } + + return conf, nil +} + // ValidateUpstreams validates each upstream and returns an error if any // upstream is invalid or if there are no default upstreams specified. // -// TODO(e.burkov): Move into aghnet or even into dnsproxy. +// TODO(e.burkov): Move into aghnet or even into dnsproxy. func ValidateUpstreams(upstreams []string) (err error) { - // No need to validate comments - upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty) + _, err = newUpstreamConfig(upstreams) - // Consider this case valid because defaultDNS will be used - if len(upstreams) == 0 { - return nil + return err +} + +// stringKeysSorted returns the sorted slice of string keys of m. +// +// TODO(e.burkov): Use generics in Go 1.18. Move into golibs. +func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) { + sorted = make([]string, 0, len(m)) + for s := range m { + sorted = append(sorted, s) } - _, err = proxy.ParseUpstreamsConfig( - upstreams, - &upstream.Options{ - Bootstrap: []string{}, - Timeout: DefaultTimeout, - }, - ) + sort.Strings(sorted) + + return sorted +} + +// ValidateUpstreamsPrivate validates each upstream and returns an error if any +// upstream is invalid or if there are no default upstreams specified. It also +// checks each domain of domain-specific upstreams for being ARPA pointing to +// a locally-served network. privateNets must not be nil. +func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) { + conf, err := newUpstreamConfig(upstreams) if err != nil { return err } - var defaultUpstreamFound bool - for _, u := range upstreams { - var useDefault bool - useDefault, err = validateUpstream(u) + if conf == nil { + return nil + } + + var errs []error + + for _, domain := range stringKeysSorted(conf.DomainReservedUpstreams) { + var subnet *net.IPNet + subnet, err = netutil.SubnetFromReversedAddr(domain) if err != nil { - return err + errs = append(errs, err) + + continue } - if !defaultUpstreamFound { - defaultUpstreamFound = useDefault + if !privateNets.Contains(subnet.IP) { + errs = append( + errs, + fmt.Errorf("arpa domain %q should point to a locally-served network", domain), + ) } } - if !defaultUpstreamFound { - return fmt.Errorf("no default upstreams specified") + if len(errs) > 0 { + return errors.List("checking domain-specific upstreams", errs...) } return nil diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index a501f7dc..f468f7ae 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -14,6 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -184,12 +185,11 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { wantSet: "", }, { name: "upstream_dns_bad", - wantSet: `wrong upstreams specification: bad ipport address "!!!": address !!!: ` + - `missing port in address`, + wantSet: `validating upstream servers: bad ipport address "!!!": ` + + `address !!!: missing port in address`, }, { name: "bootstraps_bad", - wantSet: `a can not be used as bootstrap dns cause: ` + - `invalid bootstrap server address: ` + + wantSet: `checking bootstrap a: invalid address: ` + `Resolver a is not eligible to be a bootstrap DNS server`, }, { name: "cache_bad_ttl", @@ -200,6 +200,10 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { }, { name: "local_ptr_upstreams_good", wantSet: "", + }, { + name: "local_ptr_upstreams_bad", + wantSet: `validating private upstream servers: checking domain-specific upstreams: ` + + `bad arpa domain name "non.arpa": not a reversed ip network`, }, { name: "local_ptr_upstreams_null", wantSet: "", @@ -358,7 +362,7 @@ func TestValidateUpstream(t *testing.T) { } } -func TestValidateUpstreamsSet(t *testing.T) { +func TestValidateUpstreams(t *testing.T) { testCases := []struct { name string wantErr string @@ -405,3 +409,51 @@ func TestValidateUpstreamsSet(t *testing.T) { }) } } + +func TestValidateUpstreamsPrivate(t *testing.T) { + ss := netutil.SubnetSetFunc(netutil.IsLocallyServed) + + testCases := []struct { + name string + wantErr string + u string + }{{ + name: "success_address", + wantErr: ``, + u: "[/1.0.0.127.in-addr.arpa/]#", + }, { + name: "success_subnet", + wantErr: ``, + u: "[/127.in-addr.arpa/]#", + }, { + name: "not_arpa_subnet", + wantErr: `checking domain-specific upstreams: ` + + `bad arpa domain name "hello.world": not a reversed ip network`, + u: "[/hello.world/]#", + }, { + name: "non-private_arpa_address", + wantErr: `checking domain-specific upstreams: ` + + `arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`, + u: "[/1.2.3.4.in-addr.arpa/]#", + }, { + name: "non-private_arpa_subnet", + wantErr: `checking domain-specific upstreams: ` + + `arpa domain "128.in-addr.arpa." should point to a locally-served network`, + u: "[/128.in-addr.arpa/]#", + }, { + name: "several_bad", + wantErr: `checking domain-specific upstreams: 2 errors: ` + + `"arpa domain \"1.2.3.4.in-addr.arpa.\" should point to a locally-served network", ` + + `"bad arpa domain name \"non.arpa\": not a reversed ip network"`, + u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#", + }} + + for _, tc := range testCases { + set := []string{"192.168.0.1", tc.u} + + t.Run(tc.name, func(t *testing.T) { + err := ValidateUpstreamsPrivate(set, ss) + testutil.AssertErrorMsg(t, tc.wantErr, err) + }) + } +} diff --git a/internal/dnsforward/testdata/TestDNSForwardHTTP_handleSetConfig.json b/internal/dnsforward/testdata/TestDNSForwardHTTP_handleSetConfig.json index b594029c..830bf491 100644 --- a/internal/dnsforward/testdata/TestDNSForwardHTTP_handleSetConfig.json +++ b/internal/dnsforward/testdata/TestDNSForwardHTTP_handleSetConfig.json @@ -520,6 +520,43 @@ ] } }, + "local_ptr_upstreams_bad": { + "req": { + "local_ptr_upstreams": [ + "123.123.123.123", + "[/non.arpa/]#" + ] + }, + "want": { + "upstream_dns": [ + "8.8.8.8:53", + "8.8.4.4:53" + ], + "upstream_dns_file": "", + "bootstrap_dns": [ + "9.9.9.10", + "149.112.112.10", + "2620:fe::10", + "2620:fe::fe:10" + ], + "protection_enabled": true, + "ratelimit": 0, + "blocking_mode": "", + "blocking_ipv4": "", + "blocking_ipv6": "", + "edns_cs_enabled": false, + "dnssec_enabled": false, + "disable_ipv6": false, + "upstream_mode": "", + "cache_size": 0, + "cache_ttl_min": 0, + "cache_ttl_max": 0, + "cache_optimistic": false, + "resolve_clients": false, + "use_private_ptr_resolvers": false, + "local_ptr_upstreams": [] + } + }, "local_ptr_upstreams_null": { "req": { "local_ptr_upstreams": null diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 4cb4305b..8af49b0c 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -420,14 +420,8 @@ func (r Reason) Matched() bool { } // CheckHostRules tries to match the host against filtering rules only. -func (d *DNSFilter) CheckHostRules(host string, qtype uint16, setts *Settings) (Result, error) { - if !setts.FilteringEnabled { - return Result{}, nil - } - - host = strings.ToLower(host) - - return d.matchHost(host, qtype, setts) +func (d *DNSFilter) CheckHostRules(host string, rrtype uint16, setts *Settings) (Result, error) { + return d.matchHost(strings.ToLower(host), rrtype, setts) } // CheckHost tries to match the host against filtering rules, then safebrowsing @@ -726,8 +720,7 @@ func hostRulesToRules(netRules []*rules.HostRule) (res []rules.Rule) { return res } -// matchHostProcessAllowList processes the allowlist logic of host -// matching. +// matchHostProcessAllowList processes the allowlist logic of host matching. func (d *DNSFilter) matchHostProcessAllowList( host string, dnsres *urlfilter.DNSResult, @@ -798,11 +791,11 @@ func (d *DNSFilter) matchHostProcessDNSResult( return Result{} } -// matchHost is a low-level way to check only if hostname is filtered by rules, +// matchHost is a low-level way to check only if host is filtered by rules, // skipping expensive safebrowsing and parental lookups. func (d *DNSFilter) matchHost( host string, - qtype uint16, + rrtype uint16, setts *Settings, ) (res Result, err error) { if !setts.FilteringEnabled { @@ -815,7 +808,7 @@ func (d *DNSFilter) matchHost( // TODO(e.burkov): Wait for urlfilter update to pass net.IP. ClientIP: setts.ClientIP.String(), ClientName: setts.ClientName, - DNSType: qtype, + DNSType: rrtype, } d.engineLock.RLock() @@ -855,7 +848,7 @@ func (d *DNSFilter) matchHost( return Result{}, nil } - res = d.matchHostProcessDNSResult(qtype, dnsres) + res = d.matchHostProcessDNSResult(rrtype, dnsres) for _, r := range res.Rules { log.Debug( "filtering: found rule %q for host %q, filter list id: %d", diff --git a/internal/home/clients.go b/internal/home/clients.go index aa5f4778..4ba6b884 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -4,8 +4,6 @@ import ( "bytes" "fmt" "net" - "os/exec" - "runtime" "sort" "strings" "sync" @@ -62,6 +60,16 @@ const ( ClientSourceHostsFile ) +// clientSourceConf is used to configure where the runtime clients will be +// obtained from. +type clientSourcesConf struct { + WHOIS bool `yaml:"whois"` + ARP bool `yaml:"arp"` + RDNS bool `yaml:"rdns"` + DHCP bool `yaml:"dhcp"` + HostsFile bool `yaml:"hosts"` +} + // RuntimeClient information type RuntimeClient struct { WHOISInfo *RuntimeClientWHOISInfo @@ -99,6 +107,9 @@ type clientsContainer struct { // hosts database. etcHosts *aghnet.HostsContainer + // arpdb stores the neighbors retrieved from ARP. + arpdb aghnet.ARPDB + testing bool // if TRUE, this object is used for internal tests } @@ -109,6 +120,7 @@ func (clients *clientsContainer) Init( objects []*clientObject, dhcpServer *dhcpd.Server, etcHosts *aghnet.HostsContainer, + arpdb aghnet.ARPDB, ) { if clients.list != nil { log.Fatal("clients.list != nil") @@ -121,6 +133,7 @@ func (clients *clientsContainer) Init( clients.dhcpServer = dhcpServer clients.etcHosts = etcHosts + clients.arpdb = arpdb clients.addFromConfig(objects) if clients.testing { @@ -132,14 +145,14 @@ func (clients *clientsContainer) Init( clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged) } - go clients.handleHostsUpdates() + if clients.etcHosts != nil { + go clients.handleHostsUpdates() + } } func (clients *clientsContainer) handleHostsUpdates() { - if clients.etcHosts != nil { - for upd := range clients.etcHosts.Upd() { - clients.addFromHostsFile(upd) - } + for upd := range clients.etcHosts.Upd() { + clients.addFromHostsFile(upd) } } @@ -156,7 +169,9 @@ func (clients *clientsContainer) Start() { // Reload reloads runtime clients. func (clients *clientsContainer) Reload() { - clients.addFromSystemARP() + if clients.arpdb != nil { + clients.addFromSystemARP() + } } type clientObject struct { @@ -255,6 +270,8 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { } func (clients *clientsContainer) periodicUpdate() { + defer log.OnPanic("clients container") + for { clients.Reload() time.Sleep(clientsUpdatePeriod) @@ -733,6 +750,7 @@ func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clien return false } + rc.Host = host rc.Source = src } else { rc = &RuntimeClient{ @@ -805,16 +823,18 @@ func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) { // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a // command. func (clients *clientsContainer) addFromSystemARP() { - if runtime.GOOS == "windows" { + if err := clients.arpdb.Refresh(); err != nil { + log.Error("refreshing arp container: %s", err) + + clients.arpdb = aghnet.EmptyARPDB{} + return } - cmd := exec.Command("arp", "-a") - log.Tracef("executing %q %q", cmd.Path, cmd.Args) - data, err := cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - log.Debug("command %q has failed: %q code:%d", - cmd.Path, err, cmd.ProcessState.ExitCode()) + ns := clients.arpdb.Neighbors() + if len(ns) == 0 { + log.Debug("refreshing arp container: the update is empty") + return } @@ -823,36 +843,20 @@ func (clients *clientsContainer) addFromSystemARP() { clients.rmHostsBySrc(ClientSourceARP) - n := 0 - // TODO(a.garipov): Rewrite to use bufio.Scanner. - lines := strings.Split(string(data), "\n") - for _, ln := range lines { - lparen := strings.Index(ln, " (") - rparen := strings.Index(ln, ") ") - if lparen == -1 || rparen == -1 || lparen >= rparen { - continue - } - - host := ln[:lparen] - ipStr := ln[lparen+2 : rparen] - ip := net.ParseIP(ipStr) - if netutil.ValidateDomainName(host) != nil || ip == nil { - continue - } - - ok := clients.addHostLocked(ip, host, ClientSourceARP) - if ok { - n++ + added := 0 + for _, n := range ns { + if clients.addHostLocked(n.IP, n.Name, ClientSourceARP) { + added++ } } - log.Debug("clients: added %d client aliases from 'arp -a' command output", n) + log.Debug("clients: added %d client aliases from arp neighborhood", added) } // updateFromDHCP adds the clients that have a non-empty hostname from the DHCP // server. func (clients *clientsContainer) updateFromDHCP(add bool) { - if clients.dhcpServer == nil { + if clients.dhcpServer == nil || !config.Clients.Sources.DHCP { return } diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index 93bf3360..629d7c69 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -3,11 +3,12 @@ package home import ( "net" "os" + "runtime" "testing" "time" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" - + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,7 +17,7 @@ func TestClients(t *testing.T) { clients := clientsContainer{} clients.testing = true - clients.Init(nil, nil, nil) + clients.Init(nil, nil, nil, nil) t.Run("add_success", func(t *testing.T) { c := &Client{ @@ -192,7 +193,7 @@ func TestClientsWHOIS(t *testing.T) { clients := clientsContainer{ testing: true, } - clients.Init(nil, nil, nil) + clients.Init(nil, nil, nil, nil) whois := &RuntimeClientWHOISInfo{ Country: "AU", Orgname: "Example Org", @@ -251,7 +252,7 @@ func TestClientsAddExisting(t *testing.T) { clients := clientsContainer{ testing: true, } - clients.Init(nil, nil, nil) + clients.Init(nil, nil, nil, nil) t.Run("simple", func(t *testing.T) { ip := net.IP{1, 1, 1, 1} @@ -271,12 +272,18 @@ func TestClientsAddExisting(t *testing.T) { }) t.Run("complicated", func(t *testing.T) { + // TODO(a.garipov): Properly decouple the DHCP server from the client + // storage. + if runtime.GOOS == "windows" { + t.Skip("skipping dhcp test on windows") + } + var err error ip := net.IP{1, 2, 3, 4} // First, init a DHCP server with a single static lease. - config := dhcpd.ServerConfig{ + config := &dhcpd.ServerConfig{ Enabled: true, DBFilePath: "leases.db", Conf4: dhcpd.V4ServerConf{ @@ -290,10 +297,9 @@ func TestClientsAddExisting(t *testing.T) { clients.dhcpServer, err = dhcpd.Create(config) require.NoError(t, err) - // TODO(e.burkov): leases.db isn't created on Windows so removing it - // causes an error. Split the test to make it run properly on different - // operating systems. - t.Cleanup(func() { _ = os.Remove("leases.db") }) + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return os.Remove("leases.db") + }) err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{ HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, @@ -325,7 +331,7 @@ func TestClientsCustomUpstream(t *testing.T) { clients := clientsContainer{ testing: true, } - clients.Init(nil, nil, nil) + clients.Init(nil, nil, nil, nil) // Add client with upstreams. ok, err := clients.Add(&Client{ diff --git a/internal/home/config.go b/internal/home/config.go index 0bc1145f..720683a1 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -51,6 +51,13 @@ type osConfig struct { RlimitNoFile uint64 `yaml:"rlimit_nofile"` } +type clientsConfig struct { + // Sources defines the set of sources to fetch the runtime clients from. + Sources *clientSourcesConf `yaml:"runtime_sources"` + // Persistent are the configured clients. + Persistent []*clientObject `yaml:"persistent"` +} + // configuration is loaded from YAML // field ordering is important -- yaml fields will mirror ordering from here type configuration struct { @@ -83,12 +90,12 @@ type configuration struct { WhitelistFilters []filter `yaml:"whitelist_filters"` UserRules []string `yaml:"user_rules"` - DHCP dhcpd.ServerConfig `yaml:"dhcp"` + DHCP *dhcpd.ServerConfig `yaml:"dhcp"` // Clients contains the YAML representations of the persistent clients. // This field is only used for reading and writing persistent client data. // Keep this field sorted to ensure consistent ordering. - Clients []*clientObject `yaml:"clients"` + Clients *clientsConfig `yaml:"clients"` logSettings `yaml:",inline"` @@ -123,13 +130,9 @@ type dnsConfig struct { // UpstreamTimeout is the timeout for querying upstream servers. UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"` - // LocalDomainName is the domain name used for known internal hosts. - // For example, a machine called "myhost" can be addressed as - // "myhost.lan" when LocalDomainName is "lan". - LocalDomainName string `yaml:"local_domain_name"` - - // ResolveClients enables and disables resolving clients with RDNS. - ResolveClients bool `yaml:"resolve_clients"` + // PrivateNets is the set of IP networks for which the private reverse DNS + // resolver should be used. + PrivateNets []string `yaml:"private_networks"` // UsePrivateRDNS defines if the PTR requests for unknown addresses from // locally-served networks should be resolved via private PTR resolvers. @@ -199,8 +202,6 @@ var config = &configuration{ FilteringEnabled: true, // whether or not use filter lists FiltersUpdateIntervalHours: 24, UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout}, - LocalDomainName: "lan", - ResolveClients: true, UsePrivateRDNS: true, }, TLS: tlsConfigSettings{ @@ -208,6 +209,18 @@ var config = &configuration{ PortDNSOverTLS: defaultPortTLS, // needs to be passed through to dnsproxy PortDNSOverQUIC: defaultPortQUIC, }, + DHCP: &dhcpd.ServerConfig{ + LocalDomainName: "lan", + }, + Clients: &clientsConfig{ + Sources: &clientSourcesConf{ + WHOIS: true, + ARP: true, + RDNS: true, + DHCP: true, + HostsFile: true, + }, + }, logSettings: logSettings{ LogCompress: false, LogLocalTime: false, @@ -403,18 +416,16 @@ func (c *configuration) write() error { s.WriteDiskConfig(&c) dns := &config.DNS dns.FilteringConfig = c - dns.LocalPTRResolvers, - dns.ResolveClients, - dns.UsePrivateRDNS = s.RDNSSettings() + dns.LocalPTRResolvers, config.Clients.Sources.RDNS, dns.UsePrivateRDNS = s.RDNSSettings() } if Context.dhcpServer != nil { - c := dhcpd.ServerConfig{} - Context.dhcpServer.WriteDiskConfig(&c) + c := &dhcpd.ServerConfig{} + Context.dhcpServer.WriteDiskConfig(c) config.DHCP = c } - config.Clients = Context.clients.forConfig() + config.Clients.Persistent = Context.clients.forConfig() configFile := config.getConfigFilename() log.Debug("Writing YAML file: %s", configFile) diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 76b8e28d..0435651b 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -13,6 +13,7 @@ import ( "runtime" "strings" "time" + "unicode/utf8" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" @@ -359,6 +360,9 @@ func shutdownSrv(ctx context.Context, srv *http.Server) { } } +// PasswordMinRunes is the minimum length of user's password in runes. +const PasswordMinRunes = 8 + // Apply new configuration, start DNS server, restart Web server func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { req, restartHTTP, err := decodeApplyConfigReq(r.Body) @@ -368,6 +372,18 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { return } + if utf8.RuneCountInString(req.Password) < PasswordMinRunes { + aghhttp.Error( + r, + w, + http.StatusUnprocessableEntity, + "password must be at least %d symbols long", + PasswordMinRunes, + ) + + return + } + err = aghnet.CheckPort("udp", req.DNS.IP, req.DNS.Port) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) diff --git a/internal/home/dns.go b/internal/home/dns.go index ecf7bf19..71ab6b9c 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -77,13 +77,36 @@ func initDNSServer() (err error) { filterConf.HTTPRegister = httpRegister Context.dnsFilter = filtering.New(&filterConf, nil) + var privateNets netutil.SubnetSet + switch len(config.DNS.PrivateNets) { + case 0: + // Use an optimized locally-served matcher. + privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) + case 1: + var n *net.IPNet + n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0]) + if err != nil { + return fmt.Errorf("preparing the set of private subnets: %w", err) + } + + privateNets = n + default: + var nets []*net.IPNet + nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...) + if err != nil { + return fmt.Errorf("preparing the set of private subnets: %w", err) + } + + privateNets = netutil.SliceSubnetSet(nets) + } + p := dnsforward.DNSCreateParams{ - DNSFilter: Context.dnsFilter, - Stats: Context.stats, - QueryLog: Context.queryLog, - SubnetDetector: Context.subnetDetector, - Anonymizer: anonymizer, - LocalDomain: config.DNS.LocalDomainName, + DNSFilter: Context.dnsFilter, + Stats: Context.stats, + QueryLog: Context.queryLog, + PrivateNets: privateNets, + Anonymizer: anonymizer, + LocalDomain: config.DHCP.LocalDomainName, } if Context.dhcpServer != nil { p.DHCPServer = Context.dhcpServer @@ -112,8 +135,13 @@ func initDNSServer() (err error) { return fmt.Errorf("dnsServer.Prepare: %w", err) } - Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS) - Context.whois = initWHOIS(&Context.clients) + if config.Clients.Sources.RDNS { + Context.rdns = NewRDNS(Context.dnsServer, &Context.clients, config.DNS.UsePrivateRDNS) + } + + if config.Clients.Sources.WHOIS { + Context.whois = initWHOIS(&Context.clients) + } Context.filters.Init() return nil @@ -130,10 +158,11 @@ func onDNSRequest(pctx *proxy.DNSContext) { return } - if config.DNS.ResolveClients && !ip.IsLoopback() { + srcs := config.Clients.Sources + if srcs.RDNS && !ip.IsLoopback() { Context.rdns.Begin(ip) } - if !Context.subnetDetector.IsSpecialNetwork(ip) { + if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) { Context.whois.Begin(ip) } } @@ -192,6 +221,10 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { newConf.TLSConfig = tlsConf.TLSConfig newConf.TLSConfig.ServerName = tlsConf.ServerName + if tlsConf.PortHTTPS != 0 { + newConf.HTTPSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortHTTPS) + } + if tlsConf.PortDNSOverTLS != 0 { newConf.TLSListenAddrs = ipsToTCPAddrs(hosts, tlsConf.PortDNSOverTLS) } @@ -217,7 +250,7 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { newConf.FilterHandler = applyAdditionalFiltering newConf.GetCustomUpstreamByClient = Context.clients.findUpstreams - newConf.ResolveClients = dnsConf.ResolveClients + newConf.ResolveClients = config.Clients.Sources.RDNS newConf.UsePrivateRDNS = dnsConf.UsePrivateRDNS newConf.LocalPTRResolvers = dnsConf.LocalPTRResolvers newConf.UpstreamTimeout = dnsConf.UpstreamTimeout.Duration @@ -365,10 +398,15 @@ func startDNSServer() error { const topClientsNumber = 100 // the number of clients to get for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) { - if config.DNS.ResolveClients && !ip.IsLoopback() { + if ip == nil { + continue + } + + srcs := config.Clients.Sources + if srcs.RDNS && !ip.IsLoopback() { Context.rdns.Begin(ip) } - if !Context.subnetDetector.IsSpecialNetwork(ip) { + if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) { Context.whois.Begin(ip) } } diff --git a/internal/home/home.go b/internal/home/home.go index b6d3c223..1b0c7122 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -65,8 +65,6 @@ type homeContext struct { updater *updater.Updater - subnetDetector *aghnet.SubnetDetector - // mux is our custom http.ServeMux. mux *http.ServeMux @@ -175,6 +173,11 @@ func setupContext(args options) { os.Exit(0) } + + if !args.noEtcHosts && config.Clients.Sources.HostsFile { + err = setupHostsContainer() + fatalOnError(err) + } } Context.mux = http.NewServeMux() @@ -182,7 +185,7 @@ func setupContext(args options) { // logIfUnsupported logs a formatted warning if the error is one of the // unsupported errors and returns nil. If err is nil, logIfUnsupported returns -// nil. Otherise, it returns err. +// nil. Otherwise, it returns err. func logIfUnsupported(msg string, err error) (outErr error) { if errors.As(err, new(*aghos.UnsupportedError)) { log.Debug(msg, err) @@ -287,13 +290,12 @@ func setupConfig(args options) (err error) { ConfName: config.getConfigFilename(), }) - if !args.noEtcHosts { - if err = setupHostsContainer(); err != nil { - return err - } + var arpdb aghnet.ARPDB + if config.Clients.Sources.ARP { + arpdb = aghnet.NewARPDB() } - Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts) + Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb) if args.bindPort != 0 { uc := aghalg.UniqChecker{} @@ -466,9 +468,6 @@ func run(args options, clientBuildFS fs.FS) { Context.web, err = initWeb(args, clientBuildFS) fatalOnError(err) - Context.subnetDetector, err = aghnet.NewSubnetDetector() - fatalOnError(err) - if !Context.firstRun { err = initDNSServer() fatalOnError(err) @@ -527,8 +526,10 @@ func checkPermissions() { if err != nil { if errors.Is(err, os.ErrPermission) { log.Fatal(`Permission check failed. + AdGuard Home is not allowed to bind to privileged ports (for instance, port 53). Please note, that this is crucial for a server to be able to use privileged ports. + You have two options: 1. Run AdGuard Home with root privileges 2. On Linux you can grant the CAP_NET_BIND_SERVICE capability: diff --git a/internal/home/options.go b/internal/home/options.go index dc11ca35..6f5a4d8d 100644 --- a/internal/home/options.go +++ b/internal/home/options.go @@ -230,13 +230,19 @@ var helpArg = arg{ } var noEtcHostsArg = arg{ - description: "Do not use the OS-provided hosts.", + description: "Deprecated. Do not use the OS-provided hosts.", longName: "no-etc-hosts", shortName: "", updateWithValue: nil, updateNoValue: func(o options) (options, error) { o.noEtcHosts = true; return o, nil }, - effect: nil, - serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) }, + effect: func(_ options, _ string) (f effect, err error) { + log.Info( + "warning: --no-etc-hosts flag is deprecated and will be removed in the future versions", + ) + + return nil, nil + }, + serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) }, } var localFrontendArg = arg{ diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 202f9f5f..08f4f013 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -167,7 +167,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { w := &bytes.Buffer{} aghtest.ReplaceLogWriter(t, w) - locUpstream := &aghtest.TestUpstream{ + locUpstream := &aghtest.Upstream{ Reverse: map[string][]string{ "192.168.1.1": {"local.domain"}, "2a00:1450:400c:c06::93": {"ipv6.domain"}, diff --git a/internal/home/upgrade.go b/internal/home/upgrade.go index 89d41556..d9611dc9 100644 --- a/internal/home/upgrade.go +++ b/internal/home/upgrade.go @@ -21,9 +21,11 @@ import ( ) // currentSchemaVersion is the current schema version. -const currentSchemaVersion = 12 +const currentSchemaVersion = 14 // These aliases are provided for convenience. +// +// TODO(e.burkov): Remove any after updating to Go 1.18. type ( any = interface{} yarr = []any @@ -85,6 +87,8 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) { upgradeSchema9to10, upgradeSchema10to11, upgradeSchema11to12, + upgradeSchema12to13, + upgradeSchema13to14, } n := 0 @@ -690,6 +694,114 @@ func upgradeSchema11to12(diskConf yobj) (err error) { return nil } +// upgradeSchema12to13 performs the following changes: +// +// # BEFORE: +// 'dns': +// # … +// 'local_domain_name': 'lan' +// +// # AFTER: +// 'dhcp': +// # … +// 'local_domain_name': 'lan' +// +func upgradeSchema12to13(diskConf yobj) (err error) { + log.Printf("Upgrade yaml: 12 to 13") + diskConf["schema_version"] = 13 + + dnsVal, ok := diskConf["dns"] + if !ok { + return nil + } + + var dns yobj + dns, ok = dnsVal.(yobj) + if !ok { + return fmt.Errorf("unexpected type of dns: %T", dnsVal) + } + + dhcpVal, ok := diskConf["dhcp"] + if !ok { + return nil + } + + var dhcp yobj + dhcp, ok = dhcpVal.(yobj) + if !ok { + return fmt.Errorf("unexpected type of dhcp: %T", dhcpVal) + } + + const field = "local_domain_name" + + dhcp[field] = dns[field] + delete(dns, field) + + return nil +} + +// upgradeSchema13to14 performs the following changes: +// +// # BEFORE: +// 'clients': +// - 'name': 'client-name' +// # … +// +// # AFTER: +// 'clients': +// 'persistent': +// - 'name': 'client-name' +// # … +// 'runtime_sources': +// 'whois': true +// 'arp': true +// 'rdns': true +// 'dhcp': true +// 'hosts': true +// +func upgradeSchema13to14(diskConf yobj) (err error) { + log.Printf("Upgrade yaml: 13 to 14") + diskConf["schema_version"] = 14 + + clientsVal, ok := diskConf["clients"] + if !ok { + clientsVal = yarr{} + } + + var rdnsSrc bool + if dnsVal, dok := diskConf["dns"]; dok { + var dnsSettings yobj + dnsSettings, ok = dnsVal.(yobj) + if !ok { + return fmt.Errorf("unexpected type of dns: %T", dnsVal) + } + + var rdnsSrcVal any + rdnsSrcVal, ok = dnsSettings["resolve_clients"] + if ok { + rdnsSrc, ok = rdnsSrcVal.(bool) + if !ok { + return fmt.Errorf("unexpected type of resolve_clients: %T", rdnsSrcVal) + } + + delete(dnsSettings, "resolve_clients") + } + } + + diskConf["clients"] = yobj{ + "persistent": clientsVal, + "runtime_sources": &clientSourcesConf{ + WHOIS: true, + ARP: true, + RDNS: rdnsSrc, + DHCP: true, + HostsFile: true, + }, + } + + return nil +} + // TODO(a.garipov): Replace with log.Output when we port it to our logging // package. func funcName() string { diff --git a/internal/home/upgrade_test.go b/internal/home/upgrade_test.go index 171ce3b2..4c25cba3 100644 --- a/internal/home/upgrade_test.go +++ b/internal/home/upgrade_test.go @@ -55,7 +55,7 @@ func TestUpgradeSchema2to3(t *testing.T) { require.Len(t, v, 1) require.Equal(t, "8.8.8.8:53", v[0]) default: - t.Fatalf("wrong type for bootsrap dns: %T", v) + t.Fatalf("wrong type for bootstrap dns: %T", v) } excludedEntries := []string{"bootstrap_dns"} @@ -511,3 +511,131 @@ func TestUpgradeSchema11to12(t *testing.T) { assert.Equal(t, 90*24*time.Hour, ivlVal.Duration) }) } + +func TestUpgradeSchema12to13(t *testing.T) { + const newSchemaVer = 13 + + testCases := []struct { + in yobj + want yobj + name string + }{{ + in: yobj{}, + want: yobj{"schema_version": newSchemaVer}, + name: "no_dns", + }, { + in: yobj{"dns": yobj{}}, + want: yobj{ + "dns": yobj{}, + "schema_version": newSchemaVer, + }, + name: "no_dhcp", + }, { + in: yobj{ + "dns": yobj{ + "local_domain_name": "lan", + }, + "dhcp": yobj{}, + "schema_version": newSchemaVer - 1, + }, + want: yobj{ + "dns": yobj{}, + "dhcp": yobj{ + "local_domain_name": "lan", + }, + "schema_version": newSchemaVer, + }, + name: "good", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := upgradeSchema12to13(tc.in) + require.NoError(t, err) + + assert.Equal(t, tc.want, tc.in) + }) + } +} + +func TestUpgradeSchema13to14(t *testing.T) { + const newSchemaVer = 14 + + testClient := &clientObject{ + Name: "agh-client", + IDs: []string{"id1"}, + UseGlobalSettings: true, + } + + testCases := []struct { + in yobj + want yobj + name string + }{{ + in: yobj{}, + want: yobj{ + "schema_version": newSchemaVer, + // The clients field will be added anyway. + "clients": yobj{ + "persistent": yarr{}, + "runtime_sources": &clientSourcesConf{ + WHOIS: true, + ARP: true, + RDNS: false, + DHCP: true, + HostsFile: true, + }, + }, + }, + name: "no_clients", + }, { + in: yobj{ + "clients": []*clientObject{testClient}, + }, + want: yobj{ + "schema_version": newSchemaVer, + "clients": yobj{ + "persistent": []*clientObject{testClient}, + "runtime_sources": &clientSourcesConf{ + WHOIS: true, + ARP: true, + RDNS: false, + DHCP: true, + HostsFile: true, + }, + }, + }, + name: "no_dns", + }, { + in: yobj{ + "clients": []*clientObject{testClient}, + "dns": yobj{ + "resolve_clients": true, + }, + }, + want: yobj{ + "schema_version": newSchemaVer, + "clients": yobj{ + "persistent": []*clientObject{testClient}, + "runtime_sources": &clientSourcesConf{ + WHOIS: true, + ARP: true, + RDNS: true, + DHCP: true, + HostsFile: true, + }, + }, + "dns": yobj{}, + }, + name: "good", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := upgradeSchema13to14(tc.in) + require.NoError(t, err) + + assert.Equal(t, tc.want, tc.in) + }) + } +} diff --git a/internal/home/web.go b/internal/home/web.go index 54fb1324..60af60be 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -34,14 +34,13 @@ const ( ) type webConfig struct { + clientFS fs.FS + clientBetaFS fs.FS + BindHost net.IP BindPort int BetaBindPort int PortHTTPS int - firstRun bool - - clientFS fs.FS - clientBetaFS fs.FS // ReadTimeout is an option to pass to http.Server for setting an // appropriate field. @@ -54,6 +53,8 @@ type webConfig struct { // WriteTimeout is an option to pass to http.Server for setting an // appropriate field. WriteTimeout time.Duration + + firstRun bool } // HTTPSServer - HTTPS Server diff --git a/openapi/CHANGELOG.md b/openapi/CHANGELOG.md index 67498be5..b7641352 100644 --- a/openapi/CHANGELOG.md +++ b/openapi/CHANGELOG.md @@ -4,14 +4,14 @@ ## v0.108.0: API changes +## v0.107.7: API changes + ### The new possible status code in `/install/configure` response. * The new status code `422 Unprocessable Entity` in the response for `POST /install/configure` which means that the specified password does not meet the strength requirements. -## v0.107.7: API changes - ### The new optional field `"ecs"` in `QueryLogItem` * The new optional field `"ecs"` in `GET /control/querylog` contains the IP diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index e4ddf6b1..8b21a01f 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -1088,6 +1088,9 @@ 'description': > Failed to parse initial configuration or cannot listen to the specified addresses. + '422': + 'description': > + The specified password does not meet the strength requirements. '500': 'description': 'Cannot start the DNS server' '/login':