diff --git a/CHANGELOG.md b/CHANGELOG.md index 48aa24d7..d663d78e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to ### Added +- The ability to customize the set of networks considered private through the + new `private_networks` setting ([#3142]). - EDNS Client-Subnet information in the request details section of a query log record ([#3978]). - Support for hostnames for plain UDP upstream servers using the `udp://` scheme @@ -88,6 +90,7 @@ In this release, the schema version has changed from 12 to 13. [#1730]: https://github.com/AdguardTeam/AdGuardHome/issues/1730 [#2993]: https://github.com/AdguardTeam/AdGuardHome/issues/2993 [#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057 +[#3142]: https://github.com/AdguardTeam/AdGuardHome/issues/3142 [#3367]: https://github.com/AdguardTeam/AdGuardHome/issues/3367 [#3381]: https://github.com/AdguardTeam/AdGuardHome/issues/3381 [#3503]: https://github.com/AdguardTeam/AdGuardHome/issues/3503 diff --git a/go.mod b/go.mod index 1aedfc38..8cfd5475 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.17 require ( github.com/AdguardTeam/dnsproxy v0.41.4 - github.com/AdguardTeam/golibs v0.10.6 + github.com/AdguardTeam/golibs v0.10.8 github.com/AdguardTeam/urlfilter v0.15.2 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.2.3 diff --git a/go.sum b/go.sum index 716b6955..5e5ef775 100644 --- a/go.sum +++ b/go.sum @@ -12,8 +12,9 @@ github.com/AdguardTeam/dnsproxy v0.41.4/go.mod h1:GCdEbTw683vBqksJIccPSYzBg2yIFb github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= -github.com/AdguardTeam/golibs v0.10.6 h1:6UG6LxWFnG7TfjNzeApw+T68Kqqov0fcDYk9RjhTdhc= github.com/AdguardTeam/golibs v0.10.6/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= +github.com/AdguardTeam/golibs v0.10.8 h1:diU9gP9qG1qeLbAkzIwfUerpHSqzR6zaBgzvRMR/m6Q= +github.com/AdguardTeam/golibs v0.10.8/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/urlfilter v0.15.2 h1:LZGgrm4l4Ys9eAqB+UUmZfiC6vHlDlYFhx0WXqo6LtQ= github.com/AdguardTeam/urlfilter v0.15.2/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI= diff --git a/internal/aghnet/subnetdetector.go b/internal/aghnet/subnetdetector.go deleted file mode 100644 index fd338aaa..00000000 --- a/internal/aghnet/subnetdetector.go +++ /dev/null @@ -1,158 +0,0 @@ -package aghnet - -import ( - "net" -) - -// SubnetDetector describes IP address properties. -type SubnetDetector struct { - // spNets is the collection of special-purpose address registries as defined - // by RFC 6890. - spNets []*net.IPNet - - // locServedNets is the collection of locally-served networks as defined by - // RFC 6303. - locServedNets []*net.IPNet -} - -// NewSubnetDetector returns a new IP detector. -// -// TODO(a.garipov): Decide whether an error is actually needed. -func NewSubnetDetector() (snd *SubnetDetector, err error) { - spNets := []string{ - // "This" network. - "0.0.0.0/8", - // Private-Use Networks. - "10.0.0.0/8", - // Shared Address Space. - "100.64.0.0/10", - // Loopback. - "127.0.0.0/8", - // Link Local. - "169.254.0.0/16", - // Private-Use Networks. - "172.16.0.0/12", - // IETF Protocol Assignments. - "192.0.0.0/24", - // DS-Lite. - "192.0.0.0/29", - // TEST-NET-1 - "192.0.2.0/24", - // 6to4 Relay Anycast. - "192.88.99.0/24", - // Private-Use Networks. - "192.168.0.0/16", - // Network Interconnect Device Benchmark Testing. - "198.18.0.0/15", - // TEST-NET-2. - "198.51.100.0/24", - // TEST-NET-3. - "203.0.113.0/24", - // Reserved for Future Use. - "240.0.0.0/4", - // Limited Broadcast. - "255.255.255.255/32", - - // Loopback. - "::1/128", - // Unspecified. - "::/128", - // IPv4-IPv6 Translation Address. - "64:ff9b::/96", - - // IPv4-Mapped Address. Since this network is used for mapping - // IPv4 addresses, we don't include it. - // "::ffff:0:0/96", - - // Discard-Only Prefix. - "100::/64", - // IETF Protocol Assignments. - "2001::/23", - // TEREDO. - "2001::/32", - // Benchmarking. - "2001:2::/48", - // Documentation. - "2001:db8::/32", - // ORCHID. - "2001:10::/28", - // 6to4. - "2002::/16", - // Unique-Local. - "fc00::/7", - // Linked-Scoped Unicast. - "fe80::/10", - } - - // TODO(e.burkov): It's a subslice of the slice above. Should be done - // smarter. - locServedNets := []string{ - // IPv4. - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - "127.0.0.0/8", - "169.254.0.0/16", - "192.0.2.0/24", - "198.51.100.0/24", - "203.0.113.0/24", - "255.255.255.255/32", - // IPv6. - "::/128", - "::1/128", - "fe80::/10", - "2001:db8::/32", - "fd00::/8", - } - - snd = &SubnetDetector{ - spNets: make([]*net.IPNet, len(spNets)), - locServedNets: make([]*net.IPNet, len(locServedNets)), - } - for i, ipnetStr := range spNets { - var ipnet *net.IPNet - _, ipnet, err = net.ParseCIDR(ipnetStr) - if err != nil { - return nil, err - } - - snd.spNets[i] = ipnet - } - for i, ipnetStr := range locServedNets { - var ipnet *net.IPNet - _, ipnet, err = net.ParseCIDR(ipnetStr) - if err != nil { - return nil, err - } - - snd.locServedNets[i] = ipnet - } - - return snd, nil -} - -// anyNetContains ranges through the given ipnets slice searching for the one -// which contains the ip. For internal use only. -// -// TODO(e.burkov): Think about memoization. -func anyNetContains(ipnets *[]*net.IPNet, ip net.IP) (is bool) { - for _, ipnet := range *ipnets { - if ipnet.Contains(ip) { - return true - } - } - - return false -} - -// IsSpecialNetwork returns true if IP address is contained by any of -// special-purpose IP address registries. It's safe for concurrent use. -func (snd *SubnetDetector) IsSpecialNetwork(ip net.IP) (is bool) { - return anyNetContains(&snd.spNets, ip) -} - -// IsLocallyServedNetwork returns true if IP address is contained by any of -// locally-served IP address registries. It's safe for concurrent use. -func (snd *SubnetDetector) IsLocallyServedNetwork(ip net.IP) (is bool) { - return anyNetContains(&snd.locServedNets, ip) -} diff --git a/internal/aghnet/subnetdetector_test.go b/internal/aghnet/subnetdetector_test.go deleted file mode 100644 index f4b7678c..00000000 --- a/internal/aghnet/subnetdetector_test.go +++ /dev/null @@ -1,252 +0,0 @@ -package aghnet - -import ( - "net" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSubnetDetector_DetectSpecialNetwork(t *testing.T) { - snd, err := NewSubnetDetector() - require.NoError(t, err) - - testCases := []struct { - name string - ip net.IP - want bool - }{{ - name: "not_specific", - ip: net.ParseIP("8.8.8.8"), - want: false, - }, { - name: "this_host_on_this_network", - ip: net.ParseIP("0.0.0.0"), - want: true, - }, { - name: "private-Use", - ip: net.ParseIP("10.0.0.0"), - want: true, - }, { - name: "shared_address_space", - ip: net.ParseIP("100.64.0.0"), - want: true, - }, { - name: "loopback", - ip: net.ParseIP("127.0.0.0"), - want: true, - }, { - name: "link_local", - ip: net.ParseIP("169.254.0.0"), - want: true, - }, { - name: "private-use", - ip: net.ParseIP("172.16.0.0"), - want: true, - }, { - name: "ietf_protocol_assignments", - ip: net.ParseIP("192.0.0.0"), - want: true, - }, { - name: "ds-lite", - ip: net.ParseIP("192.0.0.0"), - want: true, - }, { - name: "documentation_(test-net-1)", - ip: net.ParseIP("192.0.2.0"), - want: true, - }, { - name: "6to4_relay_anycast", - ip: net.ParseIP("192.88.99.0"), - want: true, - }, { - name: "private-use", - ip: net.ParseIP("192.168.0.0"), - want: true, - }, { - name: "benchmarking", - ip: net.ParseIP("198.18.0.0"), - want: true, - }, { - name: "documentation_(test-net-2)", - ip: net.ParseIP("198.51.100.0"), - want: true, - }, { - name: "documentation_(test-net-3)", - ip: net.ParseIP("203.0.113.0"), - want: true, - }, { - name: "reserved", - ip: net.ParseIP("240.0.0.0"), - want: true, - }, { - name: "limited_broadcast", - ip: net.ParseIP("255.255.255.255"), - want: true, - }, { - name: "loopback_address", - ip: net.ParseIP("::1"), - want: true, - }, { - name: "unspecified_address", - ip: net.ParseIP("::"), - want: true, - }, { - name: "ipv4-ipv6_translation", - ip: net.ParseIP("64:ff9b::"), - want: true, - }, { - name: "discard-only_address_block", - ip: net.ParseIP("100::"), - want: true, - }, { - name: "ietf_protocol_assignments", - ip: net.ParseIP("2001::"), - want: true, - }, { - name: "teredo", - ip: net.ParseIP("2001::"), - want: true, - }, { - name: "benchmarking", - ip: net.ParseIP("2001:2::"), - want: true, - }, { - name: "documentation", - ip: net.ParseIP("2001:db8::"), - want: true, - }, { - name: "orchid", - ip: net.ParseIP("2001:10::"), - want: true, - }, { - name: "6to4", - ip: net.ParseIP("2002::"), - want: true, - }, { - name: "unique-local", - ip: net.ParseIP("fc00::"), - want: true, - }, { - name: "linked-scoped_unicast", - ip: net.ParseIP("fe80::"), - want: true, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, snd.IsSpecialNetwork(tc.ip)) - }) - } -} - -func TestSubnetDetector_DetectLocallyServedNetwork(t *testing.T) { - snd, err := NewSubnetDetector() - require.NoError(t, err) - - testCases := []struct { - name string - ip net.IP - want bool - }{{ - name: "not_specific", - ip: net.ParseIP("8.8.8.8"), - want: false, - }, { - name: "private-Use", - ip: net.ParseIP("10.0.0.0"), - want: true, - }, { - name: "loopback", - ip: net.ParseIP("127.0.0.0"), - want: true, - }, { - name: "link_local", - ip: net.ParseIP("169.254.0.0"), - want: true, - }, { - name: "private-use", - ip: net.ParseIP("172.16.0.0"), - want: true, - }, { - name: "documentation_(test-net-1)", - ip: net.ParseIP("192.0.2.0"), - want: true, - }, { - name: "private-use", - ip: net.ParseIP("192.168.0.0"), - want: true, - }, { - name: "documentation_(test-net-2)", - ip: net.ParseIP("198.51.100.0"), - want: true, - }, { - name: "documentation_(test-net-3)", - ip: net.ParseIP("203.0.113.0"), - want: true, - }, { - name: "limited_broadcast", - ip: net.ParseIP("255.255.255.255"), - want: true, - }, { - name: "loopback_address", - ip: net.ParseIP("::1"), - want: true, - }, { - name: "unspecified_address", - ip: net.ParseIP("::"), - want: true, - }, { - name: "documentation", - ip: net.ParseIP("2001:db8::"), - want: true, - }, { - name: "linked-scoped_unicast", - ip: net.ParseIP("fe80::"), - want: true, - }, { - name: "locally_assigned", - ip: net.ParseIP("fd00::1"), - want: true, - }, { - name: "not_locally_assigned", - ip: net.ParseIP("fc00::1"), - want: false, - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, snd.IsLocallyServedNetwork(tc.ip)) - }) - } -} - -func TestSubnetDetector_Detect_parallel(t *testing.T) { - t.Parallel() - - snd, err := NewSubnetDetector() - require.NoError(t, err) - - testFunc := func() { - for _, ip := range []net.IP{ - net.IPv4allrouter, - net.IPv4allsys, - net.IPv4bcast, - net.IPv4zero, - net.IPv6interfacelocalallnodes, - net.IPv6linklocalallnodes, - net.IPv6linklocalallrouters, - net.IPv6loopback, - net.IPv6unspecified, - } { - _ = snd.IsSpecialNetwork(ip) - _ = snd.IsLocallyServedNetwork(ip) - } - } - - const goroutinesNum = 50 - for i := 0; i < goroutinesNum; i++ { - go testFunc() - } -} diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index dd5a4dd5..63f694bd 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -252,7 +252,7 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) { return rc } - dctx.isLocalClient = s.subnetDetector.IsLocallyServedNetwork(ip) + dctx.isLocalClient = s.privateNets.Contains(ip) return rc } @@ -374,7 +374,7 @@ func (s *Server) processRestrictLocal(ctx *dnsContext) (rc resultCode) { // Restrict an access to local addresses for external clients. We also // assume that all the DHCP leases we give are locally-served or at least // don't need to be inaccessible externally. - if !s.subnetDetector.IsLocallyServedNetwork(ip) { + if !s.privateNets.Contains(ip) { log.Debug("dns: addr %s is not from locally-served network", ip) return resultCodeSuccess @@ -481,7 +481,7 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) { s.serverLock.RLock() defer s.serverLock.RUnlock() - if !s.subnetDetector.IsLocallyServedNetwork(ip) { + if !s.privateNets.Contains(ip) { return resultCodeSuccess } diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index 4fc87ccf..54104268 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -4,35 +4,41 @@ import ( "net" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestServer_ProcessDetermineLocal(t *testing.T) { - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) s := &Server{ - subnetDetector: snd, + privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), } testCases := []struct { + want assert.BoolAssertionFunc name string cliIP net.IP - want bool }{{ + want: assert.True, name: "local", cliIP: net.IP{192, 168, 0, 1}, - want: true, }, { + want: assert.False, name: "external", cliIP: net.IP{250, 249, 0, 1}, - want: false, + }, { + want: assert.False, + name: "invalid", + cliIP: net.IP{1, 2, 3, 4, 5}, + }, { + want: assert.False, + name: "nil", + cliIP: nil, }} for _, tc := range testCases { @@ -47,7 +53,7 @@ func TestServer_ProcessDetermineLocal(t *testing.T) { } s.processDetermineLocal(dctx) - assert.Equal(t, tc.want, dctx.isLocalClient) + tc.want(t, dctx.isLocalClient) }) } } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 2d32cfd2..48e344b3 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -74,7 +74,7 @@ type Server struct { localDomainSuffix string ipset ipsetCtx - subnetDetector *aghnet.SubnetDetector + privateNets netutil.SubnetSet localResolvers *proxy.Proxy sysResolvers aghnet.SystemResolvers recDetector *recursionDetector @@ -111,13 +111,13 @@ const defaultLocalDomainSuffix = ".lan." // DNSCreateParams are parameters to create a new server. type DNSCreateParams struct { - DNSFilter *filtering.DNSFilter - Stats stats.Stats - QueryLog querylog.QueryLog - DHCPServer dhcpd.ServerInterface - SubnetDetector *aghnet.SubnetDetector - Anonymizer *aghnet.IPMut - LocalDomain string + DNSFilter *filtering.DNSFilter + Stats stats.Stats + QueryLog querylog.QueryLog + DHCPServer dhcpd.ServerInterface + PrivateNets netutil.SubnetSet + Anonymizer *aghnet.IPMut + LocalDomain string } // domainNameToSuffix converts a domain name into a local domain suffix. @@ -161,7 +161,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { dnsFilter: p.DNSFilter, stats: p.Stats, queryLog: p.QueryLog, - subnetDetector: p.SubnetDetector, + privateNets: p.PrivateNets, localDomainSuffix: localDomainSuffix, recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum), clientIDCache: cache.New(cache.Config{ @@ -315,7 +315,7 @@ func (s *Server) Exchange(ip net.IP) (host string, err error) { } resolver := s.internalProxy - if s.subnetDetector.IsLocallyServedNetwork(ip) { + if s.privateNets.Contains(ip) { if !s.conf.UsePrivateRDNS { return "", nil } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index bc90d760..36761f41 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -24,6 +24,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/miekg/dns" @@ -69,14 +70,11 @@ func createTestServer( f := filtering.New(filterConf, filters) f.SetEnabled(true) - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - + var err error s, err = NewServer(DNSCreateParams{ - DHCPServer: &testDHCP{}, - DNSFilter: f, - SubnetDetector: snd, + DHCPServer: &testDHCP{}, + DNSFilter: f, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) @@ -770,16 +768,11 @@ func TestBlockedCustomIP(t *testing.T) { Data: []byte(rules), }} - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - f := filtering.New(&filtering.Config{}, filters) - var s *Server - s, err = NewServer(DNSCreateParams{ - DHCPServer: &testDHCP{}, - DNSFilter: f, - SubnetDetector: snd, + s, err := NewServer(DNSCreateParams{ + DHCPServer: &testDHCP{}, + DNSFilter: f, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) @@ -913,15 +906,10 @@ func TestRewrite(t *testing.T) { f := filtering.New(c, nil) f.SetEnabled(true) - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - - var s *Server - s, err = NewServer(DNSCreateParams{ - DHCPServer: &testDHCP{}, - DNSFilter: f, - SubnetDetector: snd, + s, err := NewServer(DNSCreateParams{ + DHCPServer: &testDHCP{}, + DNSFilter: f, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) @@ -1028,15 +1016,10 @@ func (d *testDHCP) Leases(flags dhcpd.GetLeasesFlags) (leases []*dhcpd.Lease) { func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {} func TestPTRResponseFromDHCPLeases(t *testing.T) { - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - - var s *Server - s, err = NewServer(DNSCreateParams{ - DNSFilter: filtering.New(&filtering.Config{}, nil), - DHCPServer: &testDHCP{}, - SubnetDetector: snd, + s, err := NewServer(DNSCreateParams{ + DNSFilter: filtering.New(&filtering.Config{}, nil), + DHCPServer: &testDHCP{}, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) @@ -1105,16 +1088,11 @@ func TestPTRResponseFromHosts(t *testing.T) { }, nil) flt.SetEnabled(true) - var snd *aghnet.SubnetDetector - snd, err = aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - var s *Server s, err = NewServer(DNSCreateParams{ - DHCPServer: &testDHCP{}, - DNSFilter: flt, - SubnetDetector: snd, + DHCPServer: &testDHCP{}, + DNSFilter: flt, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) @@ -1227,9 +1205,7 @@ func TestServer_Exchange(t *testing.T) { srv.conf.ResolveClients = true srv.conf.UsePrivateRDNS = true - var err error - srv.subnetDetector, err = aghnet.NewSubnetDetector() - require.NoError(t, err) + srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) localIP := net.IP{192, 168, 1, 1} testCases := []struct { diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go index 84570bce..69dcf8f9 100644 --- a/internal/dnsforward/filter_test.go +++ b/internal/dnsforward/filter_test.go @@ -4,7 +4,6 @@ import ( "net" "testing" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" @@ -39,14 +38,10 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) { f := filtering.New(&filtering.Config{}, filters) f.SetEnabled(true) - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) - require.NotNil(t, snd) - s, err := NewServer(DNSCreateParams{ - DHCPServer: &testDHCP{}, - DNSFilter: f, - SubnetDetector: snd, + DHCPServer: &testDHCP{}, + DNSFilter: f, + PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), }) require.NoError(t, err) diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index b7fb66a6..2b7cfd13 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -10,7 +10,6 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" @@ -167,7 +166,7 @@ func (req *dnsConfig) checkBootstrap() (err error) { } // validate returns an error if any field of req is invalid. -func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) { +func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) { if req.Upstreams != nil { err = ValidateUpstreams(*req.Upstreams) if err != nil { @@ -176,7 +175,7 @@ func (req *dnsConfig) validate(snd *aghnet.SubnetDetector) (err error) { } if req.LocalPTRUpstreams != nil { - err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, snd) + err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets) if err != nil { return fmt.Errorf("validating private upstream servers: %w", err) } @@ -224,7 +223,7 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { return } - err = req.validate(s.subnetDetector) + err = req.validate(s.privateNets) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) @@ -350,17 +349,6 @@ func IsCommentOrEmpty(s string) (ok bool) { return len(s) == 0 || s[0] == '#' } -// LocalNetChecker is used to check if the IP address belongs to a local -// network. -type LocalNetChecker interface { - // IsLocallyServedNetwork returns true if ip is contained in any of address - // registries defined by RFC 6303. - IsLocallyServedNetwork(ip net.IP) (ok bool) -} - -// type check -var _ LocalNetChecker = (*aghnet.SubnetDetector)(nil) - // newUpstreamConfig validates upstreams and returns an appropriate upstream // configuration or nil if it can't be built. // @@ -422,8 +410,8 @@ func stringKeysSorted(m map[string][]upstream.Upstream) (sorted []string) { // ValidateUpstreamsPrivate validates each upstream and returns an error if any // upstream is invalid or if there are no default upstreams specified. It also // checks each domain of domain-specific upstreams for being ARPA pointing to -// a locally-served network. lnc must not be nil. -func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err error) { +// a locally-served network. privateNets must not be nil. +func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) { conf, err := newUpstreamConfig(upstreams) if err != nil { return err @@ -444,7 +432,7 @@ func ValidateUpstreamsPrivate(upstreams []string, lnc LocalNetChecker) (err erro continue } - if !lnc.IsLocallyServedNetwork(subnet.IP) { + if !privateNets.Contains(subnet.IP) { errs = append( errs, fmt.Errorf("arpa domain %q should point to a locally-served network", domain), diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index 6e28ab41..f468f7ae 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -14,6 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -410,8 +411,7 @@ func TestValidateUpstreams(t *testing.T) { } func TestValidateUpstreamsPrivate(t *testing.T) { - snd, err := aghnet.NewSubnetDetector() - require.NoError(t, err) + ss := netutil.SubnetSetFunc(netutil.IsLocallyServed) testCases := []struct { name string @@ -452,7 +452,7 @@ func TestValidateUpstreamsPrivate(t *testing.T) { set := []string{"192.168.0.1", tc.u} t.Run(tc.name, func(t *testing.T) { - err = ValidateUpstreamsPrivate(set, snd) + err := ValidateUpstreamsPrivate(set, ss) testutil.AssertErrorMsg(t, tc.wantErr, err) }) } diff --git a/internal/home/config.go b/internal/home/config.go index 6ecb4ccb..c018f16e 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -126,6 +126,10 @@ type dnsConfig struct { // ResolveClients enables and disables resolving clients with RDNS. ResolveClients bool `yaml:"resolve_clients"` + // PrivateNets is the set of IP networks for which the private reverse DNS + // resolver should be used. + PrivateNets []string `yaml:"private_networks"` + // UsePrivateRDNS defines if the PTR requests for unknown addresses from // locally-served networks should be resolved via private PTR resolvers. UsePrivateRDNS bool `yaml:"use_private_ptr_resolvers"` diff --git a/internal/home/dns.go b/internal/home/dns.go index 4c27abd9..d676f6af 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -77,13 +77,36 @@ func initDNSServer() (err error) { filterConf.HTTPRegister = httpRegister Context.dnsFilter = filtering.New(&filterConf, nil) + var privateNets netutil.SubnetSet + switch len(config.DNS.PrivateNets) { + case 0: + // Use an optimized locally-served matcher. + privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) + case 1: + var n *net.IPNet + n, err = netutil.ParseSubnet(config.DNS.PrivateNets[0]) + if err != nil { + return fmt.Errorf("preparing the set of private subnets: %w", err) + } + + privateNets = n + default: + var nets []*net.IPNet + nets, err = netutil.ParseSubnets(config.DNS.PrivateNets...) + if err != nil { + return fmt.Errorf("preparing the set of private subnets: %w", err) + } + + privateNets = netutil.SliceSubnetSet(nets) + } + p := dnsforward.DNSCreateParams{ - DNSFilter: Context.dnsFilter, - Stats: Context.stats, - QueryLog: Context.queryLog, - SubnetDetector: Context.subnetDetector, - Anonymizer: anonymizer, - LocalDomain: config.DHCP.LocalDomainName, + DNSFilter: Context.dnsFilter, + Stats: Context.stats, + QueryLog: Context.queryLog, + PrivateNets: privateNets, + Anonymizer: anonymizer, + LocalDomain: config.DHCP.LocalDomainName, } if Context.dhcpServer != nil { p.DHCPServer = Context.dhcpServer @@ -133,7 +156,7 @@ func onDNSRequest(pctx *proxy.DNSContext) { if config.DNS.ResolveClients && !ip.IsLoopback() { Context.rdns.Begin(ip) } - if !Context.subnetDetector.IsSpecialNetwork(ip) { + if !netutil.IsSpecialPurpose(ip) { Context.whois.Begin(ip) } } @@ -360,10 +383,14 @@ func startDNSServer() error { const topClientsNumber = 100 // the number of clients to get for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) { + if ip == nil { + continue + } + if config.DNS.ResolveClients && !ip.IsLoopback() { Context.rdns.Begin(ip) } - if !Context.subnetDetector.IsSpecialNetwork(ip) { + if !netutil.IsSpecialPurpose(ip) { Context.whois.Begin(ip) } } diff --git a/internal/home/home.go b/internal/home/home.go index 114ba4b6..7096be78 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -66,8 +66,6 @@ type homeContext struct { updater *updater.Updater - subnetDetector *aghnet.SubnetDetector - // mux is our custom http.ServeMux. mux *http.ServeMux @@ -477,9 +475,6 @@ func run(args options, clientBuildFS fs.FS) { Context.web, err = initWeb(args, clientBuildFS) fatalOnError(err) - Context.subnetDetector, err = aghnet.NewSubnetDetector() - fatalOnError(err) - if !Context.firstRun { err = initDNSServer() fatalOnError(err)