Pull request: 5035-dhcp-hosts-netip-addr
Updates #5035. Squashed commit of the following: commit 3a272842f738da322abb2bc5306aed94da79304b Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Oct 26 20:34:49 2022 +0300 dnsforward: imp docs, tests commit b442ca9b57d730be3af14c68759c706f1742e4c4 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Oct 26 19:51:21 2022 +0300 dnsforward: imp code, tests commit 8fca6de93edb8cfdb0ff5a940d08f8700e12a423 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Oct 26 16:38:27 2022 +0300 dnsforward: mv dhcp hosts to netip.Addr
This commit is contained in:
parent
bf10f157ab
commit
8a935d4ffb
2
go.mod
2
go.mod
|
@ -4,7 +4,7 @@ go 1.18
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/dnsproxy v0.46.2
|
github.com/AdguardTeam/dnsproxy v0.46.2
|
||||||
github.com/AdguardTeam/golibs v0.11.0
|
github.com/AdguardTeam/golibs v0.11.2
|
||||||
github.com/AdguardTeam/urlfilter v0.16.0
|
github.com/AdguardTeam/urlfilter v0.16.0
|
||||||
github.com/NYTimes/gziphandler v1.1.1
|
github.com/NYTimes/gziphandler v1.1.1
|
||||||
github.com/ameshkov/dnscrypt/v2 v2.2.5
|
github.com/ameshkov/dnscrypt/v2 v2.2.5
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -2,8 +2,8 @@ github.com/AdguardTeam/dnsproxy v0.46.2 h1:ZUKM713Ts5meYQqk6cJkUBMCFSWqFPXTgjXkN
|
||||||
github.com/AdguardTeam/dnsproxy v0.46.2/go.mod h1:PAmRzFqls0E92XTglyY2ESAqMAzZJhHKErG1ZpRnpjA=
|
github.com/AdguardTeam/dnsproxy v0.46.2/go.mod h1:PAmRzFqls0E92XTglyY2ESAqMAzZJhHKErG1ZpRnpjA=
|
||||||
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||||
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
|
github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
|
||||||
github.com/AdguardTeam/golibs v0.11.0 h1:fWp5bRLL7N806HWeNiRM7vHJH+wwWQ3Z6kpGPeu2onM=
|
github.com/AdguardTeam/golibs v0.11.2 h1:JbQB1Dg2JWStXgHh1QqBbOLWnP4t9oDjppoBH6TVXSE=
|
||||||
github.com/AdguardTeam/golibs v0.11.0/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
|
github.com/AdguardTeam/golibs v0.11.2/go.mod h1:87bN2x4VsTritptE3XZg9l8T6gznWsIxHBcQ1DeRIXA=
|
||||||
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
||||||
github.com/AdguardTeam/urlfilter v0.16.0 h1:IO29m+ZyQuuOnPLTzHuXj35V1DZOp1Dcryl576P2syg=
|
github.com/AdguardTeam/urlfilter v0.16.0 h1:IO29m+ZyQuuOnPLTzHuXj35V1DZOp1Dcryl576P2syg=
|
||||||
github.com/AdguardTeam/urlfilter v0.16.0/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=
|
github.com/AdguardTeam/urlfilter v0.16.0/go.mod h1:46YZDOV1+qtdRDuhZKVPSSp7JWWes0KayqHrKAFBdEI=
|
||||||
|
|
|
@ -10,6 +10,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// DiscardLogOutput runs tests with discarded logger output.
|
// DiscardLogOutput runs tests with discarded logger output.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Replace with testutil.
|
||||||
func DiscardLogOutput(m *testing.M) {
|
func DiscardLogOutput(m *testing.M) {
|
||||||
// TODO(e.burkov): Refactor code and tests to not use the global mutable
|
// TODO(e.burkov): Refactor code and tests to not use the global mutable
|
||||||
// logger.
|
// logger.
|
||||||
|
|
|
@ -3,6 +3,7 @@ package dnsforward
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -13,11 +14,8 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//lint:file-ignore SA1019 TODO(a.garipov): Replace [*netutil.IPMap].
|
|
||||||
|
|
||||||
// To transfer information between modules
|
// To transfer information between modules
|
||||||
type dnsContext struct {
|
type dnsContext struct {
|
||||||
proxyCtx *proxy.DNSContext
|
proxyCtx *proxy.DNSContext
|
||||||
|
@ -197,7 +195,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||||
s.tableHostToIP = t
|
s.tableHostToIP = t
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) setTableIPToHost(t *netutil.IPMap) {
|
func (s *Server) setTableIPToHost(t ipToHostTable) {
|
||||||
s.tableIPToHostLock.Lock()
|
s.tableIPToHostLock.Lock()
|
||||||
defer s.tableIPToHostLock.Unlock()
|
defer s.tableIPToHostLock.Unlock()
|
||||||
|
|
||||||
|
@ -205,52 +203,54 @@ func (s *Server) setTableIPToHost(t *netutil.IPMap) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) onDHCPLeaseChanged(flags int) {
|
func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||||
var err error
|
|
||||||
|
|
||||||
add := true
|
|
||||||
switch flags {
|
switch flags {
|
||||||
case dhcpd.LeaseChangedAdded,
|
case dhcpd.LeaseChangedAdded,
|
||||||
dhcpd.LeaseChangedAddedStatic,
|
dhcpd.LeaseChangedAddedStatic,
|
||||||
dhcpd.LeaseChangedRemovedStatic:
|
dhcpd.LeaseChangedRemovedStatic:
|
||||||
// Go on.
|
// Go on.
|
||||||
case dhcpd.LeaseChangedRemovedAll:
|
case dhcpd.LeaseChangedRemovedAll:
|
||||||
add = false
|
s.setTableHostToIP(nil)
|
||||||
|
s.setTableIPToHost(nil)
|
||||||
|
|
||||||
|
return
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var hostToIP hostToIPTable
|
|
||||||
var ipToHost *netutil.IPMap
|
|
||||||
if add {
|
|
||||||
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||||
|
hostToIP := make(hostToIPTable, len(ll))
|
||||||
hostToIP = make(hostToIPTable, len(ll))
|
ipToHost := make(ipToHostTable, len(ll))
|
||||||
ipToHost = netutil.NewIPMap(len(ll))
|
|
||||||
|
|
||||||
for _, l := range ll {
|
for _, l := range ll {
|
||||||
// TODO(a.garipov): Remove this after we're finished with the client
|
// TODO(a.garipov): Remove this after we're finished with the client
|
||||||
// hostname validations in the DHCP server code.
|
// hostname validations in the DHCP server code.
|
||||||
err = netutil.ValidateDomainName(l.Hostname)
|
err := netutil.ValidateDomainName(l.Hostname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug(
|
log.Debug("dnsforward: skipping invalid hostname %q from dhcp: %s", l.Hostname, err)
|
||||||
"dns: skipping invalid hostname %q from dhcp: %s",
|
|
||||||
l.Hostname,
|
continue
|
||||||
err,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
|
lowhost := strings.ToLower(l.Hostname + "." + s.localDomainSuffix)
|
||||||
ip := slices.Clone(l.IP)
|
|
||||||
|
|
||||||
ipToHost.Set(ip, lowhost)
|
// Assume that we only process IPv4 now.
|
||||||
hostToIP[lowhost] = ip
|
//
|
||||||
|
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||||
|
ip, err := netutil.IPToAddr(l.IP, netutil.AddrFamilyIPv4)
|
||||||
|
if err != nil {
|
||||||
|
log.Debug("dnsforward: skipping invalid ip %v from dhcp: %s", l.IP, err)
|
||||||
|
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len())
|
ipToHost[ip] = lowhost
|
||||||
|
hostToIP[lowhost] = ip
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setTableHostToIP(hostToIP)
|
s.setTableHostToIP(hostToIP)
|
||||||
s.setTableIPToHost(ipToHost)
|
s.setTableIPToHost(ipToHost)
|
||||||
|
|
||||||
|
log.Debug("dnsforward: added %d a and ptr entries from dhcp", len(ipToHost))
|
||||||
}
|
}
|
||||||
|
|
||||||
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
|
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
|
||||||
|
@ -365,24 +365,13 @@ func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
|
||||||
// dhcpHostToIP tries to get an IP leased by DHCP and returns the copy of
|
// dhcpHostToIP tries to get an IP leased by DHCP and returns the copy of
|
||||||
// address since the data inside the internal table may be changed while request
|
// address since the data inside the internal table may be changed while request
|
||||||
// processing. It's safe for concurrent use.
|
// processing. It's safe for concurrent use.
|
||||||
func (s *Server) dhcpHostToIP(host string) (ip net.IP, ok bool) {
|
func (s *Server) dhcpHostToIP(host string) (ip netip.Addr, ok bool) {
|
||||||
s.tableHostToIPLock.Lock()
|
s.tableHostToIPLock.Lock()
|
||||||
defer s.tableHostToIPLock.Unlock()
|
defer s.tableHostToIPLock.Unlock()
|
||||||
|
|
||||||
if s.tableHostToIP == nil {
|
ip, ok = s.tableHostToIP[host]
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
var ipFromTable net.IP
|
return ip, ok
|
||||||
ipFromTable, ok = s.tableHostToIP[host]
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
ip = make(net.IP, len(ipFromTable))
|
|
||||||
copy(ip, ipFromTable)
|
|
||||||
|
|
||||||
return ip, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processDHCPHosts respond to A requests if the target hostname is known to
|
// processDHCPHosts respond to A requests if the target hostname is known to
|
||||||
|
@ -399,7 +388,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !dctx.isLocalClient {
|
if !dctx.isLocalClient {
|
||||||
log.Debug("dns: %q requests for dhcp host %q", pctx.Addr, reqHost)
|
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, reqHost)
|
||||||
pctx.Res = s.genNXDomain(req)
|
pctx.Res = s.genNXDomain(req)
|
||||||
|
|
||||||
// Do not even put into query log.
|
// Do not even put into query log.
|
||||||
|
@ -410,18 +399,18 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||||
if !ok {
|
if !ok {
|
||||||
// Go on and process them with filters, including dnsrewrite ones, and
|
// Go on and process them with filters, including dnsrewrite ones, and
|
||||||
// possibly route them to a domain-specific upstream.
|
// possibly route them to a domain-specific upstream.
|
||||||
log.Debug("dns: no dhcp record for %q", reqHost)
|
log.Debug("dnsforward: no dhcp record for %q", reqHost)
|
||||||
|
|
||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("dns: dhcp record for %q is %s", reqHost, ip)
|
log.Debug("dnsforward: dhcp record for %q is %s", reqHost, ip)
|
||||||
|
|
||||||
resp := s.makeResponse(req)
|
resp := s.makeResponse(req)
|
||||||
if q.Qtype == dns.TypeA {
|
if q.Qtype == dns.TypeA {
|
||||||
a := &dns.A{
|
a := &dns.A{
|
||||||
Hdr: s.hdr(req, dns.TypeA),
|
Hdr: s.hdr(req, dns.TypeA),
|
||||||
A: ip,
|
A: ip.AsSlice(),
|
||||||
}
|
}
|
||||||
resp.Answer = append(resp.Answer, a)
|
resp.Answer = append(resp.Answer, a)
|
||||||
}
|
}
|
||||||
|
@ -443,7 +432,7 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||||
|
|
||||||
ip, err := netutil.IPFromReversedAddr(q.Name)
|
ip, err := netutil.IPFromReversedAddr(q.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debug("dns: parsing reversed addr: %s", err)
|
log.Debug("dnsforward: parsing reversed addr: %s", err)
|
||||||
|
|
||||||
// DNS-Based Service Discovery uses PTR records having not an ARPA
|
// DNS-Based Service Discovery uses PTR records having not an ARPA
|
||||||
// format of the domain name in question. Those shouldn't be
|
// format of the domain name in question. Those shouldn't be
|
||||||
|
@ -451,12 +440,12 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||||
// RFC 2782.
|
// RFC 2782.
|
||||||
name := strings.TrimSuffix(q.Name, ".")
|
name := strings.TrimSuffix(q.Name, ".")
|
||||||
if err = netutil.ValidateSRVDomainName(name); err != nil {
|
if err = netutil.ValidateSRVDomainName(name); err != nil {
|
||||||
log.Debug("dns: validating service domain: %s", err)
|
log.Debug("dnsforward: validating service domain: %s", err)
|
||||||
|
|
||||||
return resultCodeError
|
return resultCodeError
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("dns: request is for a service domain")
|
log.Debug("dnsforward: request is for a service domain")
|
||||||
|
|
||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
@ -465,13 +454,13 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||||
// assume that all the DHCP leases we give are locally-served or at least
|
// assume that all the DHCP leases we give are locally-served or at least
|
||||||
// don't need to be accessible externally.
|
// don't need to be accessible externally.
|
||||||
if !s.privateNets.Contains(ip) {
|
if !s.privateNets.Contains(ip) {
|
||||||
log.Debug("dns: addr %s is not from locally-served network", ip)
|
log.Debug("dnsforward: addr %s is not from locally-served network", ip)
|
||||||
|
|
||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
if !dctx.isLocalClient {
|
if !dctx.isLocalClient {
|
||||||
log.Debug("dns: %q requests an internal ip", pctx.Addr)
|
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
|
||||||
pctx.Res = s.genNXDomain(req)
|
pctx.Res = s.genNXDomain(req)
|
||||||
|
|
||||||
// Do not even put into query log.
|
// Do not even put into query log.
|
||||||
|
@ -495,27 +484,13 @@ func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
|
||||||
|
|
||||||
// ipToDHCPHost tries to get a hostname leased by DHCP. It's safe for
|
// ipToDHCPHost tries to get a hostname leased by DHCP. It's safe for
|
||||||
// concurrent use.
|
// concurrent use.
|
||||||
func (s *Server) ipToDHCPHost(ip net.IP) (host string, ok bool) {
|
func (s *Server) ipToDHCPHost(ip netip.Addr) (host string, ok bool) {
|
||||||
s.tableIPToHostLock.Lock()
|
s.tableIPToHostLock.Lock()
|
||||||
defer s.tableIPToHostLock.Unlock()
|
defer s.tableIPToHostLock.Unlock()
|
||||||
|
|
||||||
if s.tableIPToHost == nil {
|
host, ok = s.tableIPToHost[ip]
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
var v any
|
return host, ok
|
||||||
v, ok = s.tableIPToHost.Get(ip)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
if host, ok = v.(string); !ok {
|
|
||||||
log.Error("dns: bad type %T in tableIPToHost for %s", v, ip)
|
|
||||||
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
|
|
||||||
return host, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
|
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
|
||||||
|
@ -531,12 +506,20 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
host, ok := s.ipToDHCPHost(ip)
|
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
|
||||||
|
ipAddr, err := netutil.IPToAddrNoMapped(ip)
|
||||||
|
if err != nil {
|
||||||
|
log.Debug("dnsforward: bad reverse ip %v from dhcp: %s", ip, err)
|
||||||
|
|
||||||
|
return resultCodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
host, ok := s.ipToDHCPHost(ipAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("dns: dhcp reverse record for %s is %q", ip, host)
|
log.Debug("dnsforward: dhcp reverse record for %s is %q", ip, host)
|
||||||
|
|
||||||
req := pctx.Req
|
req := pctx.Req
|
||||||
resp := s.makeResponse(req)
|
resp := s.makeResponse(req)
|
||||||
|
@ -641,7 +624,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Route such queries to a custom upstream for the
|
// TODO(a.garipov): Route such queries to a custom upstream for the
|
||||||
// local domain name if there is one.
|
// local domain name if there is one.
|
||||||
log.Debug("dns: dhcp client hostname %q was not filtered", reqHost)
|
log.Debug("dnsforward: dhcp client hostname %q was not filtered", reqHost)
|
||||||
pctx.Res = s.genNXDomain(req)
|
pctx.Res = s.genNXDomain(req)
|
||||||
|
|
||||||
return resultCodeFinish
|
return resultCodeFinish
|
||||||
|
@ -714,13 +697,13 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||||
id := stringutil.Coalesce(clientID, ipStringFromAddr(pctx.Addr))
|
id := stringutil.Coalesce(clientID, ipStringFromAddr(pctx.Addr))
|
||||||
upsConf, err := customUpsByClient(id)
|
upsConf, err := customUpsByClient(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("dns: getting custom upstreams for client %s: %s", id, err)
|
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if upsConf != nil {
|
if upsConf != nil {
|
||||||
log.Debug("dns: using custom upstreams for client %s", id)
|
log.Debug("dnsforward: using custom upstreams for client %s", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pctx.CustomUpstreamConfig = upsConf
|
pctx.CustomUpstreamConfig = upsConf
|
||||||
|
|
|
@ -2,6 +2,7 @@ package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
|
@ -9,6 +10,7 @@ import (
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -230,12 +232,11 @@ func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||||
knownIP := net.IP{1, 2, 3, 4}
|
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
host string
|
host string
|
||||||
wantIP net.IP
|
wantIP netip.Addr
|
||||||
wantRes resultCode
|
wantRes resultCode
|
||||||
isLocalCli bool
|
isLocalCli bool
|
||||||
}{{
|
}{{
|
||||||
|
@ -247,19 +248,19 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||||
}, {
|
}, {
|
||||||
name: "local_client_unknown_host",
|
name: "local_client_unknown_host",
|
||||||
host: "wronghost.lan",
|
host: "wronghost.lan",
|
||||||
wantIP: nil,
|
wantIP: netip.Addr{},
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
isLocalCli: true,
|
isLocalCli: true,
|
||||||
}, {
|
}, {
|
||||||
name: "external_client_known_host",
|
name: "external_client_known_host",
|
||||||
host: "example.lan",
|
host: "example.lan",
|
||||||
wantIP: nil,
|
wantIP: netip.Addr{},
|
||||||
wantRes: resultCodeFinish,
|
wantRes: resultCodeFinish,
|
||||||
isLocalCli: false,
|
isLocalCli: false,
|
||||||
}, {
|
}, {
|
||||||
name: "external_client_unknown_host",
|
name: "external_client_unknown_host",
|
||||||
host: "wronghost.lan",
|
host: "wronghost.lan",
|
||||||
wantIP: nil,
|
wantIP: netip.Addr{},
|
||||||
wantRes: resultCodeFinish,
|
wantRes: resultCodeFinish,
|
||||||
isLocalCli: false,
|
isLocalCli: false,
|
||||||
}}
|
}}
|
||||||
|
@ -304,7 +305,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if tc.wantIP == nil {
|
if tc.wantIP == (netip.Addr{}) {
|
||||||
assert.Nil(t, pctx.Res)
|
assert.Nil(t, pctx.Res)
|
||||||
} else {
|
} else {
|
||||||
require.NotNil(t, pctx.Res)
|
require.NotNil(t, pctx.Res)
|
||||||
|
@ -312,7 +313,12 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||||
ans := pctx.Res.Answer
|
ans := pctx.Res.Answer
|
||||||
require.Len(t, ans, 1)
|
require.Len(t, ans, 1)
|
||||||
|
|
||||||
assert.Equal(t, tc.wantIP, ans[0].(*dns.A).A)
|
a := testutil.RequireTypeAssert[*dns.A](t, ans[0])
|
||||||
|
|
||||||
|
ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantIP, ip)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -324,26 +330,26 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||||
examplelan = "example." + defaultLocalDomainSuffix
|
examplelan = "example." + defaultLocalDomainSuffix
|
||||||
)
|
)
|
||||||
|
|
||||||
knownIP := net.IP{1, 2, 3, 4}
|
knownIP := netip.MustParseAddr("1.2.3.4")
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
host string
|
host string
|
||||||
suffix string
|
suffix string
|
||||||
wantIP net.IP
|
wantIP netip.Addr
|
||||||
wantRes resultCode
|
wantRes resultCode
|
||||||
qtyp uint16
|
qtyp uint16
|
||||||
}{{
|
}{{
|
||||||
name: "success_external",
|
name: "success_external",
|
||||||
host: examplecom,
|
host: examplecom,
|
||||||
suffix: defaultLocalDomainSuffix,
|
suffix: defaultLocalDomainSuffix,
|
||||||
wantIP: nil,
|
wantIP: netip.Addr{},
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
qtyp: dns.TypeA,
|
qtyp: dns.TypeA,
|
||||||
}, {
|
}, {
|
||||||
name: "success_external_non_a",
|
name: "success_external_non_a",
|
||||||
host: examplecom,
|
host: examplecom,
|
||||||
suffix: defaultLocalDomainSuffix,
|
suffix: defaultLocalDomainSuffix,
|
||||||
wantIP: nil,
|
wantIP: netip.Addr{},
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
qtyp: dns.TypeCNAME,
|
qtyp: dns.TypeCNAME,
|
||||||
}, {
|
}, {
|
||||||
|
@ -357,14 +363,14 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||||
name: "success_internal_unknown",
|
name: "success_internal_unknown",
|
||||||
host: "example-new.lan",
|
host: "example-new.lan",
|
||||||
suffix: defaultLocalDomainSuffix,
|
suffix: defaultLocalDomainSuffix,
|
||||||
wantIP: nil,
|
wantIP: netip.Addr{},
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
qtyp: dns.TypeA,
|
qtyp: dns.TypeA,
|
||||||
}, {
|
}, {
|
||||||
name: "success_internal_aaaa",
|
name: "success_internal_aaaa",
|
||||||
host: examplelan,
|
host: examplelan,
|
||||||
suffix: defaultLocalDomainSuffix,
|
suffix: defaultLocalDomainSuffix,
|
||||||
wantIP: nil,
|
wantIP: netip.Addr{},
|
||||||
wantRes: resultCodeSuccess,
|
wantRes: resultCodeSuccess,
|
||||||
qtyp: dns.TypeAAAA,
|
qtyp: dns.TypeAAAA,
|
||||||
}, {
|
}, {
|
||||||
|
@ -423,7 +429,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||||
|
|
||||||
ans := pctx.Res.Answer
|
ans := pctx.Res.Answer
|
||||||
require.Len(t, ans, 0)
|
require.Len(t, ans, 0)
|
||||||
} else if tc.wantIP == nil {
|
} else if tc.wantIP == (netip.Addr{}) {
|
||||||
assert.Nil(t, pctx.Res)
|
assert.Nil(t, pctx.Res)
|
||||||
} else {
|
} else {
|
||||||
require.NotNil(t, pctx.Res)
|
require.NotNil(t, pctx.Res)
|
||||||
|
@ -431,7 +437,12 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||||
ans := pctx.Res.Answer
|
ans := pctx.Res.Answer
|
||||||
require.Len(t, ans, 1)
|
require.Len(t, ans, 1)
|
||||||
|
|
||||||
assert.Equal(t, tc.wantIP, ans[0].(*dns.A).A)
|
a := testutil.RequireTypeAssert[*dns.A](t, ans[0])
|
||||||
|
|
||||||
|
ip, err := netutil.IPToAddr(a.A, netutil.AddrFamilyIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.wantIP, ip)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -26,8 +27,6 @@ import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
//lint:file-ignore SA1019 TODO(a.garipov): Replace [*netutil.IPMap].
|
|
||||||
|
|
||||||
// DefaultTimeout is the default upstream timeout
|
// DefaultTimeout is the default upstream timeout
|
||||||
const DefaultTimeout = 10 * time.Second
|
const DefaultTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
@ -46,8 +45,13 @@ var defaultBlockedHosts = []string{"version.bind", "id.server", "hostname.bind"}
|
||||||
|
|
||||||
var webRegistered bool
|
var webRegistered bool
|
||||||
|
|
||||||
// hostToIPTable is an alias for the type of Server.tableHostToIP.
|
// hostToIPTable is a convenient type alias for tables of host names to an IP
|
||||||
type hostToIPTable = map[string]net.IP
|
// address.
|
||||||
|
type hostToIPTable = map[string]netip.Addr
|
||||||
|
|
||||||
|
// ipToHostTable is a convenient type alias for tables of IP addresses to their
|
||||||
|
// host names. For example, for use with PTR queries.
|
||||||
|
type ipToHostTable = map[netip.Addr]string
|
||||||
|
|
||||||
// Server is the main way to start a DNS server.
|
// Server is the main way to start a DNS server.
|
||||||
//
|
//
|
||||||
|
@ -84,8 +88,7 @@ type Server struct {
|
||||||
tableHostToIP hostToIPTable
|
tableHostToIP hostToIPTable
|
||||||
tableHostToIPLock sync.Mutex
|
tableHostToIPLock sync.Mutex
|
||||||
|
|
||||||
// TODO(e.burkov): Use map[netip.Addr]struct{} instead.
|
tableIPToHost ipToHostTable
|
||||||
tableIPToHost *netutil.IPMap
|
|
||||||
tableIPToHostLock sync.Mutex
|
tableIPToHostLock sync.Mutex
|
||||||
|
|
||||||
// clientIDCache is a temporary storage for ClientIDs that were extracted
|
// clientIDCache is a temporary storage for ClientIDs that were extracted
|
||||||
|
|
|
@ -33,7 +33,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
aghtest.DiscardLogOutput(m)
|
testutil.DiscardLogOutput(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -1061,11 +1061,12 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||||
|
|
||||||
require.Len(t, resp.Answer, 1)
|
require.Len(t, resp.Answer, 1)
|
||||||
|
|
||||||
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
|
ans := resp.Answer[0]
|
||||||
assert.Equal(t, "34.12.168.192.in-addr.arpa.", resp.Answer[0].Header().Name)
|
assert.Equal(t, dns.TypePTR, ans.Header().Rrtype)
|
||||||
|
assert.Equal(t, "34.12.168.192.in-addr.arpa.", ans.Header().Name)
|
||||||
|
|
||||||
|
ptr := testutil.RequireTypeAssert[*dns.PTR](t, ans)
|
||||||
|
|
||||||
ptr, ok := resp.Answer[0].(*dns.PTR)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, dns.Fqdn("myhost."+localDomain), ptr.Ptr)
|
assert.Equal(t, dns.Fqdn("myhost."+localDomain), ptr.Ptr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue