641 lines
19 KiB
Go
641 lines
19 KiB
Go
|
// Copyright (c) Tailscale Inc & AUTHORS
|
||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||
|
|
||
|
// Package recursive implements a simple recursive DNS resolver.
|
||
|
package recursive
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"net/netip"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"github.com/miekg/dns"
|
||
|
"golang.org/x/exp/constraints"
|
||
|
"golang.org/x/exp/slices"
|
||
|
"tailscale.com/envknob"
|
||
|
"tailscale.com/net/netns"
|
||
|
"tailscale.com/types/logger"
|
||
|
"tailscale.com/util/dnsname"
|
||
|
"tailscale.com/util/mak"
|
||
|
"tailscale.com/util/multierr"
|
||
|
"tailscale.com/util/slicesx"
|
||
|
)
|
||
|
|
||
|
const (
|
||
|
// maxDepth is how deep from the root nameservers we'll recurse when
|
||
|
// resolving; passing this limit will instead return an error.
|
||
|
//
|
||
|
// maxDepth must be at least 20 to resolve "console.aws.amazon.com",
|
||
|
// which is a domain with a moderately complicated DNS setup. The
|
||
|
// current value of 30 was chosen semi-arbitrarily to ensure that we
|
||
|
// have about 50% headroom.
|
||
|
maxDepth = 30
|
||
|
// numStartingServers is the number of root nameservers that we use as
|
||
|
// initial candidates for our recursion.
|
||
|
numStartingServers = 3
|
||
|
// udpQueryTimeout is the amount of time we wait for a UDP response
|
||
|
// from a nameserver before falling back to a TCP connection.
|
||
|
udpQueryTimeout = 5 * time.Second
|
||
|
|
||
|
// These constants aren't typed in the DNS package, so we create typed
|
||
|
// versions here to avoid having to do repeated type casts.
|
||
|
qtypeA dns.Type = dns.Type(dns.TypeA)
|
||
|
qtypeAAAA dns.Type = dns.Type(dns.TypeAAAA)
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
// ErrMaxDepth is returned when recursive resolving exceeds the maximum
|
||
|
// depth limit for this package.
|
||
|
ErrMaxDepth = fmt.Errorf("exceeded max depth %d when resolving", maxDepth)
|
||
|
|
||
|
// ErrAuthoritativeNoResponses is the error returned when an
|
||
|
// authoritative nameserver indicates that there are no responses to
|
||
|
// the given query.
|
||
|
ErrAuthoritativeNoResponses = errors.New("authoritative server returned no responses")
|
||
|
|
||
|
// ErrNoResponses is returned when our resolution process completes
|
||
|
// with no valid responses from any nameserver, but no authoritative
|
||
|
// server explicitly returned NXDOMAIN.
|
||
|
ErrNoResponses = errors.New("no responses to query")
|
||
|
)
|
||
|
|
||
|
var rootServersV4 = []netip.Addr{
|
||
|
netip.MustParseAddr("198.41.0.4"), // a.root-servers.net
|
||
|
netip.MustParseAddr("199.9.14.201"), // b.root-servers.net
|
||
|
netip.MustParseAddr("192.33.4.12"), // c.root-servers.net
|
||
|
netip.MustParseAddr("199.7.91.13"), // d.root-servers.net
|
||
|
netip.MustParseAddr("192.203.230.10"), // e.root-servers.net
|
||
|
netip.MustParseAddr("192.5.5.241"), // f.root-servers.net
|
||
|
netip.MustParseAddr("192.112.36.4"), // g.root-servers.net
|
||
|
netip.MustParseAddr("198.97.190.53"), // h.root-servers.net
|
||
|
netip.MustParseAddr("192.36.148.17"), // i.root-servers.net
|
||
|
netip.MustParseAddr("192.58.128.30"), // j.root-servers.net
|
||
|
netip.MustParseAddr("193.0.14.129"), // k.root-servers.net
|
||
|
netip.MustParseAddr("199.7.83.42"), // l.root-servers.net
|
||
|
netip.MustParseAddr("202.12.27.33"), // m.root-servers.net
|
||
|
}
|
||
|
|
||
|
var rootServersV6 = []netip.Addr{
|
||
|
netip.MustParseAddr("2001:503:ba3e::2:30"), // a.root-servers.net
|
||
|
netip.MustParseAddr("2001:500:200::b"), // b.root-servers.net
|
||
|
netip.MustParseAddr("2001:500:2::c"), // c.root-servers.net
|
||
|
netip.MustParseAddr("2001:500:2d::d"), // d.root-servers.net
|
||
|
netip.MustParseAddr("2001:500:a8::e"), // e.root-servers.net
|
||
|
netip.MustParseAddr("2001:500:2f::f"), // f.root-servers.net
|
||
|
netip.MustParseAddr("2001:500:12::d0d"), // g.root-servers.net
|
||
|
netip.MustParseAddr("2001:500:1::53"), // h.root-servers.net
|
||
|
netip.MustParseAddr("2001:7fe::53"), // i.root-servers.net
|
||
|
netip.MustParseAddr("2001:503:c27::2:30"), // j.root-servers.net
|
||
|
netip.MustParseAddr("2001:7fd::1"), // k.root-servers.net
|
||
|
netip.MustParseAddr("2001:500:9f::42"), // l.root-servers.net
|
||
|
netip.MustParseAddr("2001:dc3::35"), // m.root-servers.net
|
||
|
}
|
||
|
|
||
|
var debug = envknob.RegisterBool("TS_DEBUG_RECURSIVE_DNS")
|
||
|
|
||
|
// Resolver is a recursive DNS resolver that is designed for looking up A and AAAA records.
|
||
|
type Resolver struct {
|
||
|
// Dialer is used to create outbound connections. If nil, a zero
|
||
|
// net.Dialer will be used instead.
|
||
|
Dialer netns.Dialer
|
||
|
|
||
|
// Logf is the logging function to use; if none is specified, then logs
|
||
|
// will be dropped.
|
||
|
Logf logger.Logf
|
||
|
|
||
|
// NoIPv6, if set, will prevent this package from querying for AAAA
|
||
|
// records and will avoid contacting nameservers over IPv6.
|
||
|
NoIPv6 bool
|
||
|
|
||
|
// Test mocks
|
||
|
testQueryHook func(name dnsname.FQDN, nameserver netip.Addr, protocol string, qtype dns.Type) (*dns.Msg, error)
|
||
|
testExchangeHook func(nameserver netip.Addr, network string, msg *dns.Msg) (*dns.Msg, error)
|
||
|
rootServers []netip.Addr
|
||
|
timeNow func() time.Time
|
||
|
|
||
|
// Caching
|
||
|
// NOTE(andrew): if we make resolution parallel, this needs a mutex
|
||
|
queryCache map[dnsQuery]dnsMsgWithExpiry
|
||
|
|
||
|
// Possible future additions:
|
||
|
// - Additional nameservers? From the system maybe?
|
||
|
// - NoIPv4 for IPv4
|
||
|
// - DNS-over-HTTPS or DNS-over-TLS support
|
||
|
}
|
||
|
|
||
|
// queryState stores all state during the course of a single query
|
||
|
type queryState struct {
|
||
|
// rootServers are the root nameservers to start from
|
||
|
rootServers []netip.Addr
|
||
|
|
||
|
// TODO: metrics?
|
||
|
}
|
||
|
|
||
|
type dnsQuery struct {
|
||
|
nameserver netip.Addr
|
||
|
name dnsname.FQDN
|
||
|
qtype dns.Type
|
||
|
}
|
||
|
|
||
|
func (q dnsQuery) String() string {
|
||
|
return fmt.Sprintf("dnsQuery{nameserver:%q,name:%q,qtype:%v}", q.nameserver.String(), q.name, q.qtype)
|
||
|
}
|
||
|
|
||
|
type dnsMsgWithExpiry struct {
|
||
|
*dns.Msg
|
||
|
expiresAt time.Time
|
||
|
}
|
||
|
|
||
|
func (r *Resolver) now() time.Time {
|
||
|
if r.timeNow != nil {
|
||
|
return r.timeNow()
|
||
|
}
|
||
|
return time.Now()
|
||
|
}
|
||
|
|
||
|
func (r *Resolver) logf(format string, args ...any) {
|
||
|
if r.Logf == nil {
|
||
|
return
|
||
|
}
|
||
|
r.Logf(format, args...)
|
||
|
}
|
||
|
|
||
|
func (r *Resolver) dlogf(format string, args ...any) {
|
||
|
if r.Logf == nil || !debug() {
|
||
|
return
|
||
|
}
|
||
|
r.Logf(format, args...)
|
||
|
}
|
||
|
|
||
|
func (r *Resolver) depthlogf(depth int, format string, args ...any) {
|
||
|
if r.Logf == nil || !debug() {
|
||
|
return
|
||
|
}
|
||
|
prefix := fmt.Sprintf("[%d] %s", depth, strings.Repeat(" ", depth))
|
||
|
r.Logf(prefix+format, args...)
|
||
|
}
|
||
|
|
||
|
var defaultDialer net.Dialer
|
||
|
|
||
|
func (r *Resolver) dialer() netns.Dialer {
|
||
|
if r.Dialer != nil {
|
||
|
return r.Dialer
|
||
|
}
|
||
|
|
||
|
return &defaultDialer
|
||
|
}
|
||
|
|
||
|
func (r *Resolver) newState() *queryState {
|
||
|
var rootServers []netip.Addr
|
||
|
if len(r.rootServers) > 0 {
|
||
|
rootServers = r.rootServers
|
||
|
} else {
|
||
|
// Select a random subset of root nameservers to start from, since if
|
||
|
// we don't get responses from those, something else has probably gone
|
||
|
// horribly wrong.
|
||
|
roots4 := slices.Clone(rootServersV4)
|
||
|
slicesx.Shuffle(roots4)
|
||
|
roots4 = roots4[:numStartingServers]
|
||
|
|
||
|
var roots6 []netip.Addr
|
||
|
if !r.NoIPv6 {
|
||
|
roots6 = slices.Clone(rootServersV6)
|
||
|
slicesx.Shuffle(roots6)
|
||
|
roots6 = roots6[:numStartingServers]
|
||
|
}
|
||
|
|
||
|
// Interleave the root servers so that we try to contact them over
|
||
|
// IPv4, then IPv6, IPv4, IPv6, etc.
|
||
|
rootServers = slicesx.Interleave(roots4, roots6)
|
||
|
}
|
||
|
|
||
|
return &queryState{
|
||
|
rootServers: rootServers,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Resolve will perform a recursive DNS resolution for the provided name,
|
||
|
// starting at a randomly-chosen root DNS server, and return the A and AAAA
|
||
|
// responses as a slice of netip.Addrs along with the minimum TTL for the
|
||
|
// returned records.
|
||
|
func (r *Resolver) Resolve(ctx context.Context, name string) (addrs []netip.Addr, minTTL time.Duration, err error) {
|
||
|
dnsName, err := dnsname.ToFQDN(name)
|
||
|
if err != nil {
|
||
|
return nil, 0, err
|
||
|
}
|
||
|
|
||
|
qstate := r.newState()
|
||
|
|
||
|
r.logf("querying IPv4 addresses for: %q", name)
|
||
|
addrs4, minTTL4, err4 := r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeA)
|
||
|
|
||
|
var (
|
||
|
addrs6 []netip.Addr
|
||
|
minTTL6 time.Duration
|
||
|
err6 error
|
||
|
)
|
||
|
if !r.NoIPv6 {
|
||
|
r.logf("querying IPv6 addresses for: %q", name)
|
||
|
addrs6, minTTL6, err6 = r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeAAAA)
|
||
|
}
|
||
|
|
||
|
if err4 != nil && err6 != nil {
|
||
|
if err4 == err6 {
|
||
|
return nil, 0, err4
|
||
|
}
|
||
|
|
||
|
return nil, 0, multierr.New(err4, err6)
|
||
|
}
|
||
|
if err4 != nil {
|
||
|
return addrs6, minTTL6, nil
|
||
|
} else if err6 != nil {
|
||
|
return addrs4, minTTL4, nil
|
||
|
}
|
||
|
|
||
|
minTTL = minTTL4
|
||
|
if minTTL6 < minTTL {
|
||
|
minTTL = minTTL6
|
||
|
}
|
||
|
|
||
|
addrs = append(addrs4, addrs6...)
|
||
|
if len(addrs) == 0 {
|
||
|
return nil, 0, ErrNoResponses
|
||
|
}
|
||
|
|
||
|
slicesx.Shuffle(addrs)
|
||
|
return addrs, minTTL, nil
|
||
|
}
|
||
|
|
||
|
func (r *Resolver) resolveRecursiveFromRoot(
|
||
|
ctx context.Context,
|
||
|
qstate *queryState,
|
||
|
depth int,
|
||
|
name dnsname.FQDN, // what we're querying
|
||
|
qtype dns.Type,
|
||
|
) ([]netip.Addr, time.Duration, error) {
|
||
|
r.depthlogf(depth, "resolving %q from root (type: %v)", name, qtype)
|
||
|
|
||
|
var depthError bool
|
||
|
for _, server := range qstate.rootServers {
|
||
|
addrs, minTTL, err := r.resolveRecursive(ctx, qstate, depth, name, server, qtype)
|
||
|
if err == nil {
|
||
|
return addrs, minTTL, err
|
||
|
} else if errors.Is(err, ErrAuthoritativeNoResponses) {
|
||
|
return nil, 0, ErrAuthoritativeNoResponses
|
||
|
} else if errors.Is(err, ErrMaxDepth) {
|
||
|
depthError = true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if depthError {
|
||
|
return nil, 0, ErrMaxDepth
|
||
|
}
|
||
|
return nil, 0, ErrNoResponses
|
||
|
}
|
||
|
|
||
|
func (r *Resolver) resolveRecursive(
|
||
|
ctx context.Context,
|
||
|
qstate *queryState,
|
||
|
depth int,
|
||
|
name dnsname.FQDN, // what we're querying
|
||
|
nameserver netip.Addr,
|
||
|
qtype dns.Type,
|
||
|
) ([]netip.Addr, time.Duration, error) {
|
||
|
if depth == maxDepth {
|
||
|
r.depthlogf(depth, "not recursing past maximum depth")
|
||
|
return nil, 0, ErrMaxDepth
|
||
|
}
|
||
|
|
||
|
// Ask this nameserver for an answer.
|
||
|
resp, err := r.queryNameserver(ctx, depth, name, nameserver, qtype)
|
||
|
if err != nil {
|
||
|
return nil, 0, err
|
||
|
}
|
||
|
|
||
|
// If we get an actual answer from the nameserver, then return it.
|
||
|
var (
|
||
|
answers []netip.Addr
|
||
|
cnames []dnsname.FQDN
|
||
|
minTTL = 24 * 60 * 60 // 24 hours in seconds
|
||
|
)
|
||
|
for _, answer := range resp.Answer {
|
||
|
if crec, ok := answer.(*dns.CNAME); ok {
|
||
|
cnameFQDN, err := dnsname.ToFQDN(crec.Target)
|
||
|
if err != nil {
|
||
|
r.logf("bad CNAME %q returned: %v", crec.Target, err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
cnames = append(cnames, cnameFQDN)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
addr := addrFromRecord(answer)
|
||
|
if !addr.IsValid() {
|
||
|
r.logf("[unexpected] invalid record in %T answer", answer)
|
||
|
} else if addr.Is4() && qtype != qtypeA {
|
||
|
r.logf("[unexpected] got IPv4 answer but qtype=%v", qtype)
|
||
|
} else if addr.Is6() && qtype != qtypeAAAA {
|
||
|
r.logf("[unexpected] got IPv6 answer but qtype=%v", qtype)
|
||
|
} else {
|
||
|
answers = append(answers, addr)
|
||
|
minTTL = min(minTTL, int(answer.Header().Ttl))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(answers) > 0 {
|
||
|
r.depthlogf(depth, "got answers for %q: %v", name, answers)
|
||
|
return answers, time.Duration(minTTL) * time.Second, nil
|
||
|
}
|
||
|
|
||
|
r.depthlogf(depth, "no answers for %q", name)
|
||
|
|
||
|
// If we have a non-zero number of CNAMEs, then try resolving those
|
||
|
// (from the root again) and return the first one that succeeds.
|
||
|
//
|
||
|
// TODO: return the union of all responses?
|
||
|
// TODO: parallelism?
|
||
|
if len(cnames) > 0 {
|
||
|
r.depthlogf(depth, "got CNAME responses for %q: %v", name, cnames)
|
||
|
}
|
||
|
var cnameDepthError bool
|
||
|
for _, cname := range cnames {
|
||
|
answers, minTTL, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, cname, qtype)
|
||
|
if err == nil {
|
||
|
return answers, minTTL, nil
|
||
|
} else if errors.Is(err, ErrAuthoritativeNoResponses) {
|
||
|
return nil, 0, ErrAuthoritativeNoResponses
|
||
|
} else if errors.Is(err, ErrMaxDepth) {
|
||
|
cnameDepthError = true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// If this is an authoritative response, then we know that continuing
|
||
|
// to look further is not going to result in any answers and we should
|
||
|
// bail out.
|
||
|
if resp.MsgHdr.Authoritative {
|
||
|
// If we failed to recurse into a CNAME due to a depth limit,
|
||
|
// propagate that here.
|
||
|
if cnameDepthError {
|
||
|
return nil, 0, ErrMaxDepth
|
||
|
}
|
||
|
|
||
|
r.depthlogf(depth, "got authoritative response with no answers; stopping")
|
||
|
return nil, 0, ErrAuthoritativeNoResponses
|
||
|
}
|
||
|
|
||
|
r.depthlogf(depth, "got %d NS responses and %d ADDITIONAL responses for %q", len(resp.Ns), len(resp.Extra), name)
|
||
|
|
||
|
// No CNAMEs and no answers; see if we got any AUTHORITY responses,
|
||
|
// which indicate which nameservers to query next.
|
||
|
var authorities []dnsname.FQDN
|
||
|
for _, rr := range resp.Ns {
|
||
|
ns, ok := rr.(*dns.NS)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
nsName, err := dnsname.ToFQDN(ns.Ns)
|
||
|
if err != nil {
|
||
|
r.logf("unexpected bad NS name %q: %v", ns.Ns, err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
authorities = append(authorities, nsName)
|
||
|
}
|
||
|
|
||
|
// Also check for "glue" records, which are IP addresses provided by
|
||
|
// the DNS server for authority responses; these are required when the
|
||
|
// authority server is a subdomain of what's being resolved.
|
||
|
glueRecords := make(map[dnsname.FQDN][]netip.Addr)
|
||
|
for _, rr := range resp.Extra {
|
||
|
name, err := dnsname.ToFQDN(rr.Header().Name)
|
||
|
if err != nil {
|
||
|
r.logf("unexpected bad Name %q in Extra addr: %v", rr.Header().Name, err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
if addr := addrFromRecord(rr); addr.IsValid() {
|
||
|
glueRecords[name] = append(glueRecords[name], addr)
|
||
|
} else {
|
||
|
r.logf("unexpected bad Extra %T addr", rr)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Try authorities with glue records first, to minimize the number of
|
||
|
// additional DNS queries that we need to make.
|
||
|
authoritiesGlue, authoritiesNoGlue := slicesx.Partition(authorities, func(aa dnsname.FQDN) bool {
|
||
|
return len(glueRecords[aa]) > 0
|
||
|
})
|
||
|
|
||
|
authorityDepthError := false
|
||
|
|
||
|
r.depthlogf(depth, "authorities with glue records for recursion: %v", authoritiesGlue)
|
||
|
for _, authority := range authoritiesGlue {
|
||
|
for _, nameserver := range glueRecords[authority] {
|
||
|
answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype)
|
||
|
if err == nil {
|
||
|
return answers, minTTL, nil
|
||
|
} else if errors.Is(err, ErrAuthoritativeNoResponses) {
|
||
|
return nil, 0, ErrAuthoritativeNoResponses
|
||
|
} else if errors.Is(err, ErrMaxDepth) {
|
||
|
authorityDepthError = true
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
r.depthlogf(depth, "authorities with no glue records for recursion: %v", authoritiesNoGlue)
|
||
|
for _, authority := range authoritiesNoGlue {
|
||
|
// First, resolve the IP for the authority server from the
|
||
|
// root, querying for both IPv4 and IPv6 addresses regardless
|
||
|
// of what the current question type is.
|
||
|
//
|
||
|
// TODO: check for infinite recursion; it'll get caught by our
|
||
|
// recursion depth, but we want to bail early.
|
||
|
for _, authorityQtype := range []dns.Type{qtypeAAAA, qtypeA} {
|
||
|
answers, _, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, authority, authorityQtype)
|
||
|
if err != nil {
|
||
|
r.depthlogf(depth, "error querying authority %q: %v", authority, err)
|
||
|
continue
|
||
|
}
|
||
|
r.depthlogf(depth, "resolved authority %q (type %v) to: %v", authority, authorityQtype, answers)
|
||
|
|
||
|
// Now, query this authority for the final address.
|
||
|
for _, nameserver := range answers {
|
||
|
answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype)
|
||
|
if err == nil {
|
||
|
return answers, minTTL, nil
|
||
|
} else if errors.Is(err, ErrAuthoritativeNoResponses) {
|
||
|
return nil, 0, ErrAuthoritativeNoResponses
|
||
|
} else if errors.Is(err, ErrMaxDepth) {
|
||
|
authorityDepthError = true
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if authorityDepthError {
|
||
|
return nil, 0, ErrMaxDepth
|
||
|
}
|
||
|
return nil, 0, ErrNoResponses
|
||
|
}
|
||
|
|
||
|
func min[T constraints.Ordered](a, b T) T {
|
||
|
if a < b {
|
||
|
return a
|
||
|
}
|
||
|
return b
|
||
|
}
|
||
|
|
||
|
// queryNameserver sends a query for "name" to the nameserver "nameserver" for
|
||
|
// records of type "qtype", trying both UDP and TCP connections as
|
||
|
// appropriate.
|
||
|
func (r *Resolver) queryNameserver(
|
||
|
ctx context.Context,
|
||
|
depth int,
|
||
|
name dnsname.FQDN, // what we're querying
|
||
|
nameserver netip.Addr, // destination of query
|
||
|
qtype dns.Type,
|
||
|
) (*dns.Msg, error) {
|
||
|
// TODO(andrew): we should QNAME minimisation here to avoid sending the
|
||
|
// full name to intermediate/root nameservers. See:
|
||
|
// https://www.rfc-editor.org/rfc/rfc7816
|
||
|
|
||
|
// Handle the case where UDP is blocked by adding an explicit timeout
|
||
|
// for the UDP portion of this query.
|
||
|
udpCtx, udpCtxCancel := context.WithTimeout(ctx, udpQueryTimeout)
|
||
|
defer udpCtxCancel()
|
||
|
|
||
|
msg, err := r.queryNameserverProto(udpCtx, depth, name, nameserver, "udp", qtype)
|
||
|
if err == nil {
|
||
|
return msg, nil
|
||
|
}
|
||
|
|
||
|
msg, err2 := r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype)
|
||
|
if err2 == nil {
|
||
|
return msg, nil
|
||
|
}
|
||
|
|
||
|
return nil, multierr.New(err, err2)
|
||
|
}
|
||
|
|
||
|
// queryNameserverProto sends a query for "name" to the nameserver "nameserver"
|
||
|
// for records of type "qtype" over the provided protocol (either "udp"
|
||
|
// or "tcp"), and returns the DNS response or an error.
|
||
|
func (r *Resolver) queryNameserverProto(
|
||
|
ctx context.Context,
|
||
|
depth int,
|
||
|
name dnsname.FQDN, // what we're querying
|
||
|
nameserver netip.Addr, // destination of query
|
||
|
protocol string,
|
||
|
qtype dns.Type,
|
||
|
) (resp *dns.Msg, err error) {
|
||
|
if r.testQueryHook != nil {
|
||
|
return r.testQueryHook(name, nameserver, protocol, qtype)
|
||
|
}
|
||
|
|
||
|
now := r.now()
|
||
|
nameserverStr := nameserver.String()
|
||
|
|
||
|
cacheKey := dnsQuery{
|
||
|
nameserver: nameserver,
|
||
|
name: name,
|
||
|
qtype: qtype,
|
||
|
}
|
||
|
cacheEntry, ok := r.queryCache[cacheKey]
|
||
|
if ok && cacheEntry.expiresAt.Before(now) {
|
||
|
r.depthlogf(depth, "using cached response from %s about %q (type: %v)", nameserverStr, name, qtype)
|
||
|
return cacheEntry.Msg, nil
|
||
|
}
|
||
|
|
||
|
var network string
|
||
|
if nameserver.Is4() {
|
||
|
network = protocol + "4"
|
||
|
} else {
|
||
|
network = protocol + "6"
|
||
|
}
|
||
|
|
||
|
// Prepare a message asking for an appropriately-typed record
|
||
|
// for the name we're querying.
|
||
|
m := new(dns.Msg)
|
||
|
m.SetQuestion(name.WithTrailingDot(), uint16(qtype))
|
||
|
|
||
|
// Allow mocking out the network components with our exchange hook.
|
||
|
if r.testExchangeHook != nil {
|
||
|
resp, err = r.testExchangeHook(nameserver, network, m)
|
||
|
} else {
|
||
|
// Dial the current nameserver using our dialer.
|
||
|
var nconn net.Conn
|
||
|
nconn, err = r.dialer().DialContext(ctx, network, net.JoinHostPort(nameserverStr, "53"))
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
var c dns.Client // TODO: share?
|
||
|
conn := &dns.Conn{
|
||
|
Conn: nconn,
|
||
|
UDPSize: c.UDPSize,
|
||
|
}
|
||
|
|
||
|
// Send the DNS request to the current nameserver.
|
||
|
//
|
||
|
// TODO(andrew): use ExchangeWithConnContext after this upstream PR is
|
||
|
// merged:
|
||
|
// https://github.com/miekg/dns/pull/1459
|
||
|
r.depthlogf(depth, "asking %s over %s about %q (type: %v)", nameserverStr, protocol, name, qtype)
|
||
|
resp, _, err = c.ExchangeWithConn(m, conn)
|
||
|
}
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// If the message was truncated and we're using UDP, re-run with TCP.
|
||
|
if resp.MsgHdr.Truncated && protocol == "udp" {
|
||
|
r.depthlogf(depth, "response message truncated; re-running query with TCP")
|
||
|
resp, err = r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Find minimum expiry for all records in this message.
|
||
|
var minTTL int
|
||
|
for _, rr := range resp.Answer {
|
||
|
minTTL = min(minTTL, int(rr.Header().Ttl))
|
||
|
}
|
||
|
for _, rr := range resp.Ns {
|
||
|
minTTL = min(minTTL, int(rr.Header().Ttl))
|
||
|
}
|
||
|
for _, rr := range resp.Extra {
|
||
|
minTTL = min(minTTL, int(rr.Header().Ttl))
|
||
|
}
|
||
|
|
||
|
mak.Set(&r.queryCache, cacheKey, dnsMsgWithExpiry{
|
||
|
Msg: resp,
|
||
|
expiresAt: now.Add(time.Duration(minTTL) * time.Second),
|
||
|
})
|
||
|
return resp, nil
|
||
|
}
|
||
|
|
||
|
func addrFromRecord(rr dns.RR) netip.Addr {
|
||
|
switch v := rr.(type) {
|
||
|
case *dns.A:
|
||
|
ip, ok := netip.AddrFromSlice(v.A)
|
||
|
if !ok || !ip.Is4() {
|
||
|
return netip.Addr{}
|
||
|
}
|
||
|
return ip
|
||
|
case *dns.AAAA:
|
||
|
ip, ok := netip.AddrFromSlice(v.AAAA)
|
||
|
if !ok || !ip.Is6() {
|
||
|
return netip.Addr{}
|
||
|
}
|
||
|
return ip
|
||
|
}
|
||
|
return netip.Addr{}
|
||
|
}
|