diff --git a/go.mod b/go.mod index a24211a9..d7f9b34f 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.18 require ( github.com/AdguardTeam/dnsproxy v0.46.2 - github.com/AdguardTeam/golibs v0.11.0 + github.com/AdguardTeam/golibs v0.11.2 github.com/AdguardTeam/urlfilter v0.16.0 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.2.5 diff --git a/go.sum b/go.sum index cc16cac8..0e866628 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/AdguardTeam/dnsproxy v0.46.2 h1:ZUKM713Ts5meYQqk6cJkUBMCFSWqFPXTgjXkN github.com/AdguardTeam/dnsproxy v0.46.2/go.mod h1:PAmRzFqls0E92XTglyY2ESAqMAzZJhHKErG1ZpRnpjA= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= -github.com/AdguardTeam/golibs v0.11.0 h1:fWp5bRLL7N806HWeNiRM7vHJH+wwWQ3Z6kpGPeu2onM= -github.com/AdguardTeam/golibs v0.11.0/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA= +github.com/AdguardTeam/golibs v0.11.2 h1:JbQB1Dg2JWStXgHh1QqBbOLWnP4t9oDjppoBH6TVXSE= +github.com/AdguardTeam/golibs v0.11.2/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/urlfilter v0.16.0 h1:IO29m+ZyQuuOnPLTzHuXj35V1DZOp1Dcryl576P2syg= github.com/AdguardTeam/urlfilter v0.16.0/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI= diff --git a/internal/aghtest/aghtest.go b/internal/aghtest/aghtest.go index 878ef178..31274718 100644 --- a/internal/aghtest/aghtest.go +++ b/internal/aghtest/aghtest.go @@ -10,6 +10,8 @@ import ( ) // DiscardLogOutput runs tests with discarded logger output. +// +// TODO(a.garipov): Replace with testutil. func DiscardLogOutput(m *testing.M) { // TODO(e.burkov): Refactor code and tests to not use the global mutable // logger. diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index e535dbc3..ecf41709 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -3,6 +3,7 @@ package dnsforward import ( "encoding/binary" "net" + "net/netip" "strings" "time" @@ -13,11 +14,8 @@ import ( "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/miekg/dns" - "golang.org/x/exp/slices" ) -//lint:file-ignore SA1019 TODO(a.garipov): Replace [*netutil.IPMap]. - // To transfer information between modules type dnsContext struct { proxyCtx *proxy.DNSContext @@ -197,7 +195,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) { s.tableHostToIP = t } -func (s *Server) setTableIPToHost(t *netutil.IPMap) { +func (s *Server) setTableIPToHost(t ipToHostTable) { s.tableIPToHostLock.Lock() defer s.tableIPToHostLock.Unlock() @@ -205,52 +203,54 @@ func (s *Server) setTableIPToHost(t *netutil.IPMap) { } func (s *Server) onDHCPLeaseChanged(flags int) { - var err error - - add := true switch flags { case dhcpd.LeaseChangedAdded, dhcpd.LeaseChangedAddedStatic, dhcpd.LeaseChangedRemovedStatic: // Go on. case dhcpd.LeaseChangedRemovedAll: - add = false + s.setTableHostToIP(nil) + s.setTableIPToHost(nil) + + return default: return } - var hostToIP hostToIPTable - var ipToHost *netutil.IPMap - if add { - ll := s.dhcpServer.Leases(dhcpd.LeasesAll) + ll := s.dhcpServer.Leases(dhcpd.LeasesAll) + hostToIP := make(hostToIPTable, len(ll)) + ipToHost := make(ipToHostTable, len(ll)) - hostToIP = make(hostToIPTable, len(ll)) - ipToHost = netutil.NewIPMap(len(ll)) + for _, l := range ll { + // TODO(a.garipov): Remove this after we're finished with the client + // hostname validations in the DHCP server code. + err := netutil.ValidateDomainName(l.Hostname) + if err != nil { + log.Debug("dnsforward: skipping invalid hostname %q from dhcp: %s", l.Hostname, err) - for _, l := range ll { - // TODO(a.garipov): Remove this after we're finished with the client - // hostname validations in the DHCP server code. - err = netutil.ValidateDomainName(l.Hostname) - if err != nil { - log.Debug( - "dns: skipping invalid hostname %q from dhcp: %s", - l.Hostname, - err, - ) - } - - lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix) - ip := slices.Clone(l.IP) - - ipToHost.Set(ip, lowhost) - hostToIP[lowhost] = ip + continue } - log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len()) + lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix) + + // Assume that we only process IPv4 now. + // + // TODO(a.garipov): Remove once we switch to netip.Addr more fully. + ip, err := netutil.IPToAddr(l.IP, netutil.AddrFamilyIPv4) + if err != nil { + log.Debug("dnsforward: skipping invalid ip %v from dhcp: %s", l.IP, err) + + continue + } + + ipToHost[ip] = lowhost + hostToIP[lowhost] = ip } s.setTableHostToIP(hostToIP) s.setTableIPToHost(ipToHost) + + log.Debug("dnsforward: added %d a and ptr entries from dhcp", len(ipToHost)) } // processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB @@ -365,24 +365,13 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) { // dhcpHostToIP tries to get an IP leased by DHCP and returns the copy of // address since the data inside the internal table may be changed while request // processing. It's safe for concurrent use. -func (s *Server) dhcpHostToIP(host string) (ip net.IP, ok bool) { +func (s *Server) dhcpHostToIP(host string) (ip netip.Addr, ok bool) { s.tableHostToIPLock.Lock() defer s.tableHostToIPLock.Unlock() - if s.tableHostToIP == nil { - return nil, false - } + ip, ok = s.tableHostToIP[host] - var ipFromTable net.IP - ipFromTable, ok = s.tableHostToIP[host] - if !ok { - return nil, false - } - - ip = make(net.IP, len(ipFromTable)) - copy(ip, ipFromTable) - - return ip, true + return ip, ok } // processDHCPHosts respond to A requests if the target hostname is known to @@ -399,7 +388,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { } if !dctx.isLocalClient { - log.Debug("dns: %q requests for dhcp host %q", pctx.Addr, reqHost) + log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, reqHost) pctx.Res = s.genNXDomain(req) // Do not even put into query log. @@ -410,18 +399,18 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { if !ok { // Go on and process them with filters, including dnsrewrite ones, and // possibly route them to a domain-specific upstream. - log.Debug("dns: no dhcp record for %q", reqHost) + log.Debug("dnsforward: no dhcp record for %q", reqHost) return resultCodeSuccess } - log.Debug("dns: dhcp record for %q is %s", reqHost, ip) + log.Debug("dnsforward: dhcp record for %q is %s", reqHost, ip) resp := s.makeResponse(req) if q.Qtype == dns.TypeA { a := &dns.A{ Hdr: s.hdr(req, dns.TypeA), - A: ip, + A: ip.AsSlice(), } resp.Answer = append(resp.Answer, a) } @@ -443,7 +432,7 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) { ip, err := netutil.IPFromReversedAddr(q.Name) if err != nil { - log.Debug("dns: parsing reversed addr: %s", err) + log.Debug("dnsforward: parsing reversed addr: %s", err) // DNS-Based Service Discovery uses PTR records having not an ARPA // format of the domain name in question. Those shouldn't be @@ -451,12 +440,12 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) { // RFC 2782. name := strings.TrimSuffix(q.Name, ".") if err = netutil.ValidateSRVDomainName(name); err != nil { - log.Debug("dns: validating service domain: %s", err) + log.Debug("dnsforward: validating service domain: %s", err) return resultCodeError } - log.Debug("dns: request is for a service domain") + log.Debug("dnsforward: request is for a service domain") return resultCodeSuccess } @@ -465,13 +454,13 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) { // assume that all the DHCP leases we give are locally-served or at least // don't need to be accessible externally. if !s.privateNets.Contains(ip) { - log.Debug("dns: addr %s is not from locally-served network", ip) + log.Debug("dnsforward: addr %s is not from locally-served network", ip) return resultCodeSuccess } if !dctx.isLocalClient { - log.Debug("dns: %q requests an internal ip", pctx.Addr) + log.Debug("dnsforward: %q requests an internal ip", pctx.Addr) pctx.Res = s.genNXDomain(req) // Do not even put into query log. @@ -495,27 +484,13 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) { // ipToDHCPHost tries to get a hostname leased by DHCP. It's safe for // concurrent use. -func (s *Server) ipToDHCPHost(ip net.IP) (host string, ok bool) { +func (s *Server) ipToDHCPHost(ip netip.Addr) (host string, ok bool) { s.tableIPToHostLock.Lock() defer s.tableIPToHostLock.Unlock() - if s.tableIPToHost == nil { - return "", false - } + host, ok = s.tableIPToHost[ip] - var v any - v, ok = s.tableIPToHost.Get(ip) - if !ok { - return "", false - } - - if host, ok = v.(string); !ok { - log.Error("dns: bad type %T in tableIPToHost for %s", v, ip) - - return "", false - } - - return host, true + return host, ok } // processDHCPAddrs responds to PTR requests if the target IP is leased by the @@ -531,12 +506,20 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - host, ok := s.ipToDHCPHost(ip) + // TODO(a.garipov): Remove once we switch to netip.Addr more fully. + ipAddr, err := netutil.IPToAddrNoMapped(ip) + if err != nil { + log.Debug("dnsforward: bad reverse ip %v from dhcp: %s", ip, err) + + return resultCodeSuccess + } + + host, ok := s.ipToDHCPHost(ipAddr) if !ok { return resultCodeSuccess } - log.Debug("dns: dhcp reverse record for %s is %q", ip, host) + log.Debug("dnsforward: dhcp reverse record for %s is %q", ip, host) req := pctx.Req resp := s.makeResponse(req) @@ -641,7 +624,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) { // // TODO(a.garipov): Route such queries to a custom upstream for the // local domain name if there is one. - log.Debug("dns: dhcp client hostname %q was not filtered", reqHost) + log.Debug("dnsforward: dhcp client hostname %q was not filtered", reqHost) pctx.Res = s.genNXDomain(req) return resultCodeFinish @@ -714,13 +697,13 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) { id := stringutil.Coalesce(clientID, ipStringFromAddr(pctx.Addr)) upsConf, err := customUpsByClient(id) if err != nil { - log.Error("dns: getting custom upstreams for client %s: %s", id, err) + log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err) return } if upsConf != nil { - log.Debug("dns: using custom upstreams for client %s", id) + log.Debug("dnsforward: using custom upstreams for client %s", id) } pctx.CustomUpstreamConfig = upsConf diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index da7c8ae6..915455d2 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -2,6 +2,7 @@ package dnsforward import ( "net" + "net/netip" "testing" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" @@ -9,6 +10,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -230,12 +232,11 @@ func TestServer_ProcessDetermineLocal(t *testing.T) { } func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { - knownIP := net.IP{1, 2, 3, 4} - + knownIP := netip.MustParseAddr("1.2.3.4") testCases := []struct { name string host string - wantIP net.IP + wantIP netip.Addr wantRes resultCode isLocalCli bool }{{ @@ -247,19 +248,19 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { }, { name: "local_client_unknown_host", host: "wronghost.lan", - wantIP: nil, + wantIP: netip.Addr{}, wantRes: resultCodeSuccess, isLocalCli: true, }, { name: "external_client_known_host", host: "example.lan", - wantIP: nil, + wantIP: netip.Addr{}, wantRes: resultCodeFinish, isLocalCli: false, }, { name: "external_client_unknown_host", host: "wronghost.lan", - wantIP: nil, + wantIP: netip.Addr{}, wantRes: resultCodeFinish, isLocalCli: false, }} @@ -304,7 +305,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { return } - if tc.wantIP == nil { + if tc.wantIP == (netip.Addr{}) { assert.Nil(t, pctx.Res) } else { require.NotNil(t, pctx.Res) @@ -312,7 +313,12 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { ans := pctx.Res.Answer require.Len(t, ans, 1) - assert.Equal(t, tc.wantIP, ans[0].(*dns.A).A) + a := testutil.RequireTypeAssert[*dns.A](t, ans[0]) + + ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4) + require.NoError(t, err) + + assert.Equal(t, tc.wantIP, ip) } }) } @@ -324,26 +330,26 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { examplelan = "example." + defaultLocalDomainSuffix ) - knownIP := net.IP{1, 2, 3, 4} + knownIP := netip.MustParseAddr("1.2.3.4") testCases := []struct { name string host string suffix string - wantIP net.IP + wantIP netip.Addr wantRes resultCode qtyp uint16 }{{ name: "success_external", host: examplecom, suffix: defaultLocalDomainSuffix, - wantIP: nil, + wantIP: netip.Addr{}, wantRes: resultCodeSuccess, qtyp: dns.TypeA, }, { name: "success_external_non_a", host: examplecom, suffix: defaultLocalDomainSuffix, - wantIP: nil, + wantIP: netip.Addr{}, wantRes: resultCodeSuccess, qtyp: dns.TypeCNAME, }, { @@ -357,14 +363,14 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { name: "success_internal_unknown", host: "example-new.lan", suffix: defaultLocalDomainSuffix, - wantIP: nil, + wantIP: netip.Addr{}, wantRes: resultCodeSuccess, qtyp: dns.TypeA, }, { name: "success_internal_aaaa", host: examplelan, suffix: defaultLocalDomainSuffix, - wantIP: nil, + wantIP: netip.Addr{}, wantRes: resultCodeSuccess, qtyp: dns.TypeAAAA, }, { @@ -423,7 +429,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { ans := pctx.Res.Answer require.Len(t, ans, 0) - } else if tc.wantIP == nil { + } else if tc.wantIP == (netip.Addr{}) { assert.Nil(t, pctx.Res) } else { require.NotNil(t, pctx.Res) @@ -431,7 +437,12 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { ans := pctx.Res.Answer require.Len(t, ans, 1) - assert.Equal(t, tc.wantIP, ans[0].(*dns.A).A) + a := testutil.RequireTypeAssert[*dns.A](t, ans[0]) + + ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4) + require.NoError(t, err) + + assert.Equal(t, tc.wantIP, ip) } }) } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index f6507b03..6d107153 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "net/netip" "runtime" "strings" "sync" @@ -26,8 +27,6 @@ import ( "github.com/miekg/dns" ) -//lint:file-ignore SA1019 TODO(a.garipov): Replace [*netutil.IPMap]. - // DefaultTimeout is the default upstream timeout const DefaultTimeout = 10 * time.Second @@ -46,8 +45,13 @@ var defaultBlockedHosts = []string{"version.bind", "id.server", "hostname.bind"} var webRegistered bool -// hostToIPTable is an alias for the type of Server.tableHostToIP. -type hostToIPTable = map[string]net.IP +// hostToIPTable is a convenient type alias for tables of host names to an IP +// address. +type hostToIPTable = map[string]netip.Addr + +// ipToHostTable is a convenient type alias for tables of IP addresses to their +// host names. For example, for use with PTR queries. +type ipToHostTable = map[netip.Addr]string // Server is the main way to start a DNS server. // @@ -84,8 +88,7 @@ type Server struct { tableHostToIP hostToIPTable tableHostToIPLock sync.Mutex - // TODO(e.burkov): Use map[netip.Addr]struct{} instead. - tableIPToHost *netutil.IPMap + tableIPToHost ipToHostTable tableIPToHostLock sync.Mutex // clientIDCache is a temporary storage for ClientIDs that were extracted diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 7d1ae199..06dc7260 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -33,7 +33,7 @@ import ( ) func TestMain(m *testing.M) { - aghtest.DiscardLogOutput(m) + testutil.DiscardLogOutput(m) } const ( @@ -1061,11 +1061,12 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { require.Len(t, resp.Answer, 1) - assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) - assert.Equal(t, "34.12.168.192.in-addr.arpa.", resp.Answer[0].Header().Name) + ans := resp.Answer[0] + assert.Equal(t, dns.TypePTR, ans.Header().Rrtype) + assert.Equal(t, "34.12.168.192.in-addr.arpa.", ans.Header().Name) + + ptr := testutil.RequireTypeAssert[*dns.PTR](t, ans) - ptr, ok := resp.Answer[0].(*dns.PTR) - require.True(t, ok) assert.Equal(t, dns.Fqdn("myhost."+localDomain), ptr.Ptr) }