net/dns/resolver: add Windows ExitDNS service support, using net package
Updates #1713 Updates #835 Change-Id: Ia71e96d0632c2d617b401695ad68301b07c1c2ec Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
cab5c46481
commit
cced414c7d
|
@ -1797,10 +1797,9 @@ func (b *LocalBackend) peerAPIServicesLocked() (ret []tailcfg.Service) {
|
|||
})
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "linux", "freebsd", "openbsd", "illumos", "darwin":
|
||||
case "linux", "freebsd", "openbsd", "illumos", "darwin", "windows":
|
||||
// These are the platforms currently supported by
|
||||
// net/dns/resolver/tsdns.go:Resolver.HandleExitNodeDNSQuery.
|
||||
// TODO(bradfitz): add windows once it's done there.
|
||||
ret = append(ret, tailcfg.Service{
|
||||
Proto: tailcfg.PeerAPIDNS,
|
||||
Port: 1, // version
|
||||
|
|
|
@ -360,7 +360,8 @@ func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from ne
|
|||
case "windows":
|
||||
// TODO: use DnsQueryEx and write to ch.
|
||||
// See https://docs.microsoft.com/en-us/windows/win32/api/windns/nf-windns-dnsqueryex.
|
||||
return nil, errors.New("TODO: windows exit node suport")
|
||||
// For now just use the net package:
|
||||
return handleExitNodeDNSQueryWithNetPkg(ctx, nil, resp)
|
||||
case "darwin":
|
||||
// /etc/resolv.conf is a lie and only says one upstream DNS
|
||||
// but for now that's probably good enough. Later we'll
|
||||
|
@ -404,6 +405,106 @@ func (r *Resolver) HandleExitNodeDNSQuery(ctx context.Context, q []byte, from ne
|
|||
}
|
||||
}
|
||||
|
||||
// handleExitNodeDNSQueryWithNetPkg takes a DNS query message in q and
|
||||
// return a reply (for the ExitDNS DoH service) using the net package's
|
||||
// native APIs. This is only used on Windows for now.
|
||||
//
|
||||
// If resolver is nil, the net.Resolver zero value is used.
|
||||
//
|
||||
// response contains the pre-serialized response, which notably
|
||||
// includes the original question and its header.
|
||||
func handleExitNodeDNSQueryWithNetPkg(ctx context.Context, resolver *net.Resolver, resp *response) (res []byte, err error) {
|
||||
if resp.Question.Class != dns.ClassINET {
|
||||
return nil, errors.New("unsupported class")
|
||||
}
|
||||
|
||||
r := resolver
|
||||
if r == nil {
|
||||
r = new(net.Resolver)
|
||||
}
|
||||
name := resp.Question.Name.String()
|
||||
|
||||
handleError := func(err error) (res []byte, _ error) {
|
||||
if isGoNoSuchHostError(err) {
|
||||
resp.Header.RCode = dns.RCodeNameError
|
||||
return marshalResponse(resp)
|
||||
}
|
||||
// TODO: map other errors to RCodeServerFailure?
|
||||
// Or I guess our caller should do that?
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.Header.RCode = dns.RCodeSuccess // unless changed below
|
||||
|
||||
switch resp.Question.Type {
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
network := "ip4"
|
||||
if resp.Question.Type == dns.TypeAAAA {
|
||||
network = "ip6"
|
||||
}
|
||||
ips, err := r.LookupIP(ctx, network, name)
|
||||
if err != nil {
|
||||
return handleError(err)
|
||||
}
|
||||
for _, stdIP := range ips {
|
||||
if ip, ok := netaddr.FromStdIP(stdIP); ok {
|
||||
resp.IPs = append(resp.IPs, ip)
|
||||
}
|
||||
}
|
||||
case dns.TypeTXT:
|
||||
strs, err := r.LookupTXT(ctx, name)
|
||||
if err != nil {
|
||||
return handleError(err)
|
||||
}
|
||||
resp.TXT = strs
|
||||
case dns.TypePTR:
|
||||
ipStr, ok := unARPA(name)
|
||||
if !ok {
|
||||
// TODO: is this RCodeFormatError?
|
||||
return nil, errors.New("bogus PTR name")
|
||||
}
|
||||
addrs, err := r.LookupAddr(ctx, ipStr)
|
||||
if err != nil {
|
||||
return handleError(err)
|
||||
}
|
||||
if len(addrs) > 0 {
|
||||
resp.Name, _ = dnsname.ToFQDN(addrs[0])
|
||||
}
|
||||
case dns.TypeCNAME:
|
||||
cname, err := r.LookupCNAME(ctx, name)
|
||||
if err != nil {
|
||||
return handleError(err)
|
||||
}
|
||||
resp.CNAME = cname
|
||||
case dns.TypeSRV:
|
||||
// Thanks, Go: "To accommodate services publishing SRV
|
||||
// records under non-standard names, if both service
|
||||
// and proto are empty strings, LookupSRV looks up
|
||||
// name directly."
|
||||
_, srvs, err := r.LookupSRV(ctx, "", "", name)
|
||||
if err != nil {
|
||||
return handleError(err)
|
||||
}
|
||||
resp.SRVs = srvs
|
||||
case dns.TypeNS:
|
||||
nss, err := r.LookupNS(ctx, name)
|
||||
if err != nil {
|
||||
return handleError(err)
|
||||
}
|
||||
resp.NSs = nss
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported record type %v", resp.Question.Type)
|
||||
}
|
||||
return marshalResponse(resp)
|
||||
}
|
||||
|
||||
func isGoNoSuchHostError(err error) bool {
|
||||
if de, ok := err.(*net.DNSError); ok {
|
||||
return de.IsNotFound
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type resolvConfCache struct {
|
||||
mod time.Time
|
||||
size int64
|
||||
|
@ -604,10 +705,27 @@ func (r *Resolver) handleQuery(pkt packet) {
|
|||
type response struct {
|
||||
Header dns.Header
|
||||
Question dns.Question
|
||||
|
||||
// Name is the response to a PTR query.
|
||||
Name dnsname.FQDN
|
||||
// IP is the response to an A, AAAA, or ALL query.
|
||||
IP netaddr.IP
|
||||
|
||||
// IP and IPs are the responses to an A, AAAA, or ALL query.
|
||||
// Either/both/neither can be populated.
|
||||
IP netaddr.IP
|
||||
IPs []netaddr.IP
|
||||
|
||||
// TXT is the response to a TXT query.
|
||||
// Each one is its own RR with one string.
|
||||
TXT []string
|
||||
|
||||
// CNAME is the response to a CNAME query.
|
||||
CNAME string
|
||||
|
||||
// SRVs are the responses to a SRV query.
|
||||
SRVs []*net.SRV
|
||||
|
||||
// NSs are the responses to an NS query.
|
||||
NSs []*net.NS
|
||||
}
|
||||
|
||||
var dnsParserPool = &sync.Pool{
|
||||
|
@ -683,6 +801,16 @@ func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error
|
|||
return builder.AAAAResource(answerHeader, answer)
|
||||
}
|
||||
|
||||
func marshalIP(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
|
||||
if ip.Is4() {
|
||||
return marshalARecord(name, ip, builder)
|
||||
}
|
||||
if ip.Is6() {
|
||||
return marshalAAAARecord(name, ip, builder)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// marshalPTRRecord serializes a PTR record into an active builder.
|
||||
// The caller may continue using the builder following the call.
|
||||
func marshalPTRRecord(queryName dns.Name, name dnsname.FQDN, builder *dns.Builder) error {
|
||||
|
@ -702,6 +830,83 @@ func marshalPTRRecord(queryName dns.Name, name dnsname.FQDN, builder *dns.Builde
|
|||
return builder.PTRResource(answerHeader, answer)
|
||||
}
|
||||
|
||||
func marshalTXT(queryName dns.Name, txts []string, builder *dns.Builder) error {
|
||||
for _, txt := range txts {
|
||||
if err := builder.TXTResource(dns.ResourceHeader{
|
||||
Name: queryName,
|
||||
Type: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
TTL: uint32(defaultTTL / time.Second),
|
||||
}, dns.TXTResource{
|
||||
TXT: []string{txt},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func marshalCNAME(queryName dns.Name, cname string, builder *dns.Builder) error {
|
||||
if cname == "" {
|
||||
return nil
|
||||
}
|
||||
name, err := dns.NewName(cname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return builder.CNAMEResource(dns.ResourceHeader{
|
||||
Name: queryName,
|
||||
Type: dns.TypeCNAME,
|
||||
Class: dns.ClassINET,
|
||||
TTL: uint32(defaultTTL / time.Second),
|
||||
}, dns.CNAMEResource{
|
||||
CNAME: name,
|
||||
})
|
||||
}
|
||||
|
||||
func marshalNS(queryName dns.Name, nss []*net.NS, builder *dns.Builder) error {
|
||||
for _, ns := range nss {
|
||||
name, err := dns.NewName(ns.Host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = builder.NSResource(dns.ResourceHeader{
|
||||
Name: queryName,
|
||||
Type: dns.TypeNS,
|
||||
Class: dns.ClassINET,
|
||||
TTL: uint32(defaultTTL / time.Second),
|
||||
}, dns.NSResource{NS: name})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func marshalSRV(queryName dns.Name, srvs []*net.SRV, builder *dns.Builder) error {
|
||||
for _, s := range srvs {
|
||||
srvName, err := dns.NewName(s.Target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = builder.SRVResource(dns.ResourceHeader{
|
||||
Name: queryName,
|
||||
Type: dns.TypeSRV,
|
||||
Class: dns.ClassINET,
|
||||
TTL: uint32(defaultTTL / time.Second),
|
||||
}, dns.SRVResource{
|
||||
Target: srvName,
|
||||
Priority: s.Priority,
|
||||
Port: s.Port,
|
||||
Weight: s.Weight,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// marshalResponse serializes the DNS response into a new buffer.
|
||||
func marshalResponse(resp *response) ([]byte, error) {
|
||||
resp.Header.Response = true
|
||||
|
@ -712,6 +917,14 @@ func marshalResponse(resp *response) ([]byte, error) {
|
|||
|
||||
builder := dns.NewBuilder(nil, resp.Header)
|
||||
|
||||
// TODO(bradfitz): I'm not sure why this wasn't enabled
|
||||
// before, but for now (2021-12-09) enable it at least when
|
||||
// there's more than 1 record (which was never the case
|
||||
// before), where it really helps.
|
||||
if len(resp.IPs) > 1 {
|
||||
builder.EnableCompression()
|
||||
}
|
||||
|
||||
isSuccess := resp.Header.RCode == dns.RCodeSuccess
|
||||
|
||||
if resp.Question.Type != 0 || isSuccess {
|
||||
|
@ -738,13 +951,24 @@ func marshalResponse(resp *response) ([]byte, error) {
|
|||
|
||||
switch resp.Question.Type {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypeALL:
|
||||
if resp.IP.Is4() {
|
||||
err = marshalARecord(resp.Question.Name, resp.IP, &builder)
|
||||
} else if resp.IP.Is6() {
|
||||
err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder)
|
||||
if err := marshalIP(resp.Question.Name, resp.IP, &builder); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, ip := range resp.IPs {
|
||||
if err := marshalIP(resp.Question.Name, ip, &builder); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
case dns.TypePTR:
|
||||
err = marshalPTRRecord(resp.Question.Name, resp.Name, &builder)
|
||||
case dns.TypeTXT:
|
||||
err = marshalTXT(resp.Question.Name, resp.TXT, &builder)
|
||||
case dns.TypeCNAME:
|
||||
err = marshalCNAME(resp.Question.Name, resp.CNAME, &builder)
|
||||
case dns.TypeSRV:
|
||||
err = marshalSRV(resp.Question.Name, resp.SRVs, &builder)
|
||||
case dns.TypeNS:
|
||||
err = marshalNS(resp.Question.Name, resp.NSs, &builder)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -926,6 +1150,37 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
|
|||
return marshalResponse(resp)
|
||||
}
|
||||
|
||||
// unARPA maps from "4.4.8.8.in-addr.arpa." to "8.8.4.4", etc.
|
||||
func unARPA(a string) (ipStr string, ok bool) {
|
||||
const suf4 = ".in-addr.arpa."
|
||||
if strings.HasSuffix(a, suf4) {
|
||||
s := strings.TrimSuffix(a, suf4)
|
||||
// Parse and reverse octets.
|
||||
ip, err := netaddr.ParseIP(s)
|
||||
if err != nil || !ip.Is4() {
|
||||
return "", false
|
||||
}
|
||||
a4 := ip.As4()
|
||||
return netaddr.IPv4(a4[3], a4[2], a4[1], a4[0]).String(), true
|
||||
}
|
||||
const suf6 = ".ip6.arpa."
|
||||
if len(a) == len("e.0.0.2.0.0.0.0.0.0.0.0.0.0.0.0.b.0.8.0.a.0.0.4.0.b.8.f.7.0.6.2.ip6.arpa.") &&
|
||||
strings.HasSuffix(a, suf6) {
|
||||
var hx [32]byte
|
||||
var a16 [16]byte
|
||||
for i := range hx {
|
||||
hx[31-i] = a[i*2]
|
||||
if a[i*2+1] != '.' {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
hex.Decode(a16[:], hx[:])
|
||||
return netaddr.IPFrom16(a16).String(), true
|
||||
}
|
||||
return "", false
|
||||
|
||||
}
|
||||
|
||||
var (
|
||||
metricDNSQueryLocal = clientmetric.NewCounter("dns_query_local")
|
||||
metricDNSQueryErrorClosed = clientmetric.NewCounter("dns_query_local_error_closed")
|
||||
|
|
|
@ -6,6 +6,7 @@ package resolver
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
|
@ -179,6 +180,129 @@ var resolveToNXDOMAIN = dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg)
|
|||
w.WriteMsg(m)
|
||||
})
|
||||
|
||||
// weirdoGoCNAMEHandler returns a DNS handler that satisfies
|
||||
// Go's weird Resolver.LookupCNAME (read its godoc carefully!).
|
||||
//
|
||||
// This doesn't even return a CNAME record, because that's not
|
||||
// what Go looks for.
|
||||
func weirdoGoCNAMEHandler(target string) dns.HandlerFunc {
|
||||
return func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
question := req.Question[0]
|
||||
|
||||
switch question.Qtype {
|
||||
case dns.TypeA:
|
||||
m.Answer = append(m.Answer, &dns.CNAME{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: target,
|
||||
Rrtype: dns.TypeCNAME,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 600,
|
||||
},
|
||||
Target: target,
|
||||
})
|
||||
case dns.TypeAAAA:
|
||||
m.Answer = append(m.Answer, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: target,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 600,
|
||||
},
|
||||
AAAA: net.ParseIP("1::2"),
|
||||
})
|
||||
}
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
}
|
||||
|
||||
// dnsHandler returns a handler that replies with the answers/options
|
||||
// provided.
|
||||
//
|
||||
// Types supported: netaddr.IP.
|
||||
func dnsHandler(answers ...interface{}) dns.HandlerFunc {
|
||||
return func(w dns.ResponseWriter, req *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(req)
|
||||
if len(req.Question) != 1 {
|
||||
panic("not a single-question request")
|
||||
}
|
||||
m.RecursionAvailable = true // to stop net package's errLameReferral on empty replies
|
||||
|
||||
question := req.Question[0]
|
||||
for _, a := range answers {
|
||||
switch a := a.(type) {
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported dnsHandler arg %T", a))
|
||||
case netaddr.IP:
|
||||
ip := a
|
||||
if ip.Is4() {
|
||||
m.Answer = append(m.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
A: ip.IPAddr().IP,
|
||||
})
|
||||
} else if ip.Is6() {
|
||||
m.Answer = append(m.Answer, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeAAAA,
|
||||
Class: dns.ClassINET,
|
||||
},
|
||||
AAAA: ip.IPAddr().IP,
|
||||
})
|
||||
}
|
||||
case dns.PTR:
|
||||
ptr := a
|
||||
ptr.Hdr = dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
m.Answer = append(m.Answer, &ptr)
|
||||
case dns.CNAME:
|
||||
c := a
|
||||
c.Hdr = dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeCNAME,
|
||||
Class: dns.ClassINET,
|
||||
Ttl: 600,
|
||||
}
|
||||
m.Answer = append(m.Answer, &c)
|
||||
case dns.TXT:
|
||||
txt := a
|
||||
txt.Hdr = dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeTXT,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
m.Answer = append(m.Answer, &txt)
|
||||
case dns.SRV:
|
||||
srv := a
|
||||
srv.Hdr = dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeSRV,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
m.Answer = append(m.Answer, &srv)
|
||||
case dns.NS:
|
||||
rr := a
|
||||
rr.Hdr = dns.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: dns.TypeNS,
|
||||
Class: dns.ClassINET,
|
||||
}
|
||||
m.Answer = append(m.Answer, &rr)
|
||||
}
|
||||
}
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
}
|
||||
|
||||
func serveDNS(tb testing.TB, addr string, records ...interface{}) *dns.Server {
|
||||
if len(records)%2 != 0 {
|
||||
panic("must have an even number of record values")
|
||||
|
|
|
@ -6,16 +6,22 @@ package resolver
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
miekdns "github.com/miekg/dns"
|
||||
"golang.org/x/net/dns/dnsmessage"
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/net/tsdial"
|
||||
|
@ -43,6 +49,8 @@ var dnsCfg = Config{
|
|||
|
||||
const noEdns = 0
|
||||
|
||||
const dnsHeaderLen = 12
|
||||
|
||||
func dnspacket(domain dnsname.FQDN, tp dns.Type, ednsSize uint16) []byte {
|
||||
var dnsHeader dns.Header
|
||||
question := dns.Question{
|
||||
|
@ -1093,3 +1101,383 @@ func TestForwardLinkSelection(t *testing.T) {
|
|||
type linkSelFunc func(ip netaddr.IP) string
|
||||
|
||||
func (f linkSelFunc) PickLink(ip netaddr.IP) string { return f(ip) }
|
||||
|
||||
func TestHandleExitNodeDNSQueryWithNetPkg(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("skipping test on Windows; waiting for golang.org/issue/33097")
|
||||
}
|
||||
|
||||
records := []interface{}{
|
||||
"no-records.test.",
|
||||
dnsHandler(),
|
||||
|
||||
"one-a.test.",
|
||||
dnsHandler(netaddr.MustParseIP("1.2.3.4")),
|
||||
|
||||
"two-a.test.",
|
||||
dnsHandler(netaddr.MustParseIP("1.2.3.4"), netaddr.MustParseIP("5.6.7.8")),
|
||||
|
||||
"one-aaaa.test.",
|
||||
dnsHandler(netaddr.MustParseIP("1::2")),
|
||||
|
||||
"two-aaaa.test.",
|
||||
dnsHandler(netaddr.MustParseIP("1::2"), netaddr.MustParseIP("3::4")),
|
||||
|
||||
"nx-domain.test.",
|
||||
resolveToNXDOMAIN,
|
||||
|
||||
"4.3.2.1.in-addr.arpa.",
|
||||
dnsHandler(miekdns.PTR{Ptr: "foo.com."}),
|
||||
|
||||
"cname.test.",
|
||||
weirdoGoCNAMEHandler("the-target.foo."),
|
||||
|
||||
"txt.test.",
|
||||
dnsHandler(
|
||||
miekdns.TXT{Txt: []string{"txt1=one"}},
|
||||
miekdns.TXT{Txt: []string{"txt2=two"}},
|
||||
miekdns.TXT{Txt: []string{"txt3=three"}},
|
||||
),
|
||||
|
||||
"srv.test.",
|
||||
dnsHandler(
|
||||
miekdns.SRV{
|
||||
Priority: 1,
|
||||
Weight: 2,
|
||||
Port: 3,
|
||||
Target: "foo.com.",
|
||||
},
|
||||
miekdns.SRV{
|
||||
Priority: 4,
|
||||
Weight: 5,
|
||||
Port: 6,
|
||||
Target: "bar.com.",
|
||||
},
|
||||
),
|
||||
|
||||
"ns.test.",
|
||||
dnsHandler(miekdns.NS{Ns: "ns1.foo."}, miekdns.NS{Ns: "ns2.bar."}),
|
||||
}
|
||||
v4server := serveDNS(t, "127.0.0.1:0", records...)
|
||||
defer v4server.Shutdown()
|
||||
|
||||
// backendResolver is the resolver between
|
||||
// handleExitNodeDNSQueryWithNetPkg and its upstream resolver,
|
||||
// which in this test's case is the miekg/dns test DNS server
|
||||
// (v4server).
|
||||
backResolver := &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "udp", v4server.PacketConn.LocalAddr().String())
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("no_such_host", func(t *testing.T) {
|
||||
res, err := handleExitNodeDNSQueryWithNetPkg(context.Background(), backResolver, &response{
|
||||
Header: dnsmessage.Header{
|
||||
ID: 123,
|
||||
Response: true,
|
||||
OpCode: 0, // query
|
||||
},
|
||||
Question: dnsmessage.Question{
|
||||
Name: dnsmessage.MustNewName("nx-domain.test."),
|
||||
Type: dnsmessage.TypeA,
|
||||
Class: dnsmessage.ClassINET,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(res) < dnsHeaderLen {
|
||||
t.Fatal("short reply")
|
||||
}
|
||||
rcode := dns.RCode(res[3] & 0x0f)
|
||||
if rcode != dns.RCodeNameError {
|
||||
t.Errorf("RCode = %v; want dns.RCodeNameError", rcode)
|
||||
t.Logf("Response was: %q", res)
|
||||
}
|
||||
})
|
||||
|
||||
matchPacked := func(want string) func(t testing.TB, got []byte) {
|
||||
return func(t testing.TB, got []byte) {
|
||||
if string(got) == want {
|
||||
return
|
||||
}
|
||||
t.Errorf("unexpected reply.\n got: %q\nwant: %q\n", got, want)
|
||||
t.Errorf("\nin hex:\n got: % 2x\nwant: % 2x\n", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
Type dnsmessage.Type
|
||||
Name string
|
||||
Check func(t testing.TB, got []byte)
|
||||
}{
|
||||
{
|
||||
Type: dnsmessage.TypeA,
|
||||
Name: "one-a.test.",
|
||||
Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x05one-a\x04test\x00\x00\x01\x00\x01\x05one-a\x04test\x00\x00\x01\x00\x01\x00\x00\x02X\x00\x04\x01\x02\x03\x04"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypeA,
|
||||
Name: "two-a.test.",
|
||||
Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\x05two-a\x04test\x00\x00\x01\x00\x01\xc0\f\x00\x01\x00\x01\x00\x00\x02X\x00\x04\x01\x02\x03\x04\xc0\f\x00\x01\x00\x01\x00\x00\x02X\x00\x04\x05\x06\a\b"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Name: "one-aaaa.test.",
|
||||
Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\bone-aaaa\x04test\x00\x00\x1c\x00\x01\bone-aaaa\x04test\x00\x00\x1c\x00\x01\x00\x00\x02X\x00\x10\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Name: "two-aaaa.test.",
|
||||
Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\btwo-aaaa\x04test\x00\x00\x1c\x00\x01\xc0\f\x00\x1c\x00\x01\x00\x00\x02X\x00\x10\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc0\f\x00\x1c\x00\x01\x00\x00\x02X\x00\x10\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypePTR,
|
||||
Name: "4.3.2.1.in-addr.arpa.",
|
||||
Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x014\x013\x012\x011\ain-addr\x04arpa\x00\x00\f\x00\x01\x014\x013\x012\x011\ain-addr\x04arpa\x00\x00\f\x00\x01\x00\x00\x02X\x00\t\x03foo\x03com\x00"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypeCNAME,
|
||||
Name: "cname.test.",
|
||||
Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x01\x00\x00\x00\x00\x05cname\x04test\x00\x00\x05\x00\x01\x05cname\x04test\x00\x00\x05\x00\x01\x00\x00\x02X\x00\x10\nthe-target\x03foo\x00"),
|
||||
},
|
||||
|
||||
// No records of various types
|
||||
{
|
||||
Type: dnsmessage.TypeA,
|
||||
Name: "no-records.test.",
|
||||
Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00\x01\x00\x01"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypeAAAA,
|
||||
Name: "no-records.test.",
|
||||
Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00\x1c\x00\x01"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypeCNAME,
|
||||
Name: "no-records.test.",
|
||||
Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00\x05\x00\x01"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypeSRV,
|
||||
Name: "no-records.test.",
|
||||
Check: matchPacked("\x00{\x84\x03\x00\x01\x00\x00\x00\x00\x00\x00\nno-records\x04test\x00\x00!\x00\x01"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypeTXT,
|
||||
Name: "txt.test.",
|
||||
Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x03\x00\x00\x00\x00\x03txt\x04test\x00\x00\x10\x00\x01\x03txt\x04test\x00\x00\x10\x00\x01\x00\x00\x02X\x00\t\btxt1=one\x03txt\x04test\x00\x00\x10\x00\x01\x00\x00\x02X\x00\t\btxt2=two\x03txt\x04test\x00\x00\x10\x00\x01\x00\x00\x02X\x00\v\ntxt3=three"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypeSRV,
|
||||
Name: "srv.test.",
|
||||
Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\x03srv\x04test\x00\x00!\x00\x01\x03srv\x04test\x00\x00!\x00\x01\x00\x00\x02X\x00\x0f\x00\x01\x00\x02\x00\x03\x03foo\x03com\x00\x03srv\x04test\x00\x00!\x00\x01\x00\x00\x02X\x00\x0f\x00\x04\x00\x05\x00\x06\x03bar\x03com\x00"),
|
||||
},
|
||||
{
|
||||
Type: dnsmessage.TypeNS,
|
||||
Name: "ns.test.",
|
||||
Check: matchPacked("\x00{\x84\x00\x00\x01\x00\x02\x00\x00\x00\x00\x02ns\x04test\x00\x00\x02\x00\x01\x02ns\x04test\x00\x00\x02\x00\x01\x00\x00\x02X\x00\t\x03ns1\x03foo\x00\x02ns\x04test\x00\x00\x02\x00\x01\x00\x00\x02X\x00\t\x03ns2\x03bar\x00"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("%v_%v", tt.Type, strings.Trim(tt.Name, ".")), func(t *testing.T) {
|
||||
got, err := handleExitNodeDNSQueryWithNetPkg(context.Background(), backResolver, &response{
|
||||
Header: dnsmessage.Header{
|
||||
ID: 123,
|
||||
Response: true,
|
||||
OpCode: 0, // query
|
||||
},
|
||||
Question: dnsmessage.Question{
|
||||
Name: dnsmessage.MustNewName(tt.Name),
|
||||
Type: tt.Type,
|
||||
Class: dnsmessage.ClassINET,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(got) < dnsHeaderLen {
|
||||
t.Errorf("short record")
|
||||
}
|
||||
if tt.Check != nil {
|
||||
tt.Check(t, got)
|
||||
if t.Failed() {
|
||||
t.Errorf("Got: %q\nIn hex: % 02x", got, got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
wrapRes := newWrapResolver(backResolver)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("wrap_ip_a", func(t *testing.T) {
|
||||
ips, err := wrapRes.LookupIP(ctx, "ip", "two-a.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := ips, []net.IP{
|
||||
net.ParseIP("1.2.3.4").To4(),
|
||||
net.ParseIP("5.6.7.8").To4(),
|
||||
}; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("LookupIP = %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrap_ip_aaaa", func(t *testing.T) {
|
||||
ips, err := wrapRes.LookupIP(ctx, "ip", "two-aaaa.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := ips, []net.IP{
|
||||
net.ParseIP("1::2"),
|
||||
net.ParseIP("3::4"),
|
||||
}; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("LookupIP(v6) = %v; want %v", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrap_ip_nx", func(t *testing.T) {
|
||||
ips, err := wrapRes.LookupIP(ctx, "ip", "nx-domain.test.")
|
||||
if !isGoNoSuchHostError(err) {
|
||||
t.Errorf("no NX domain = (%v, %v); want no host error", ips, err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrap_srv", func(t *testing.T) {
|
||||
_, srvs, err := wrapRes.LookupSRV(ctx, "", "", "srv.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := srvs, []*net.SRV{
|
||||
{
|
||||
Target: "foo.com.",
|
||||
Priority: 1,
|
||||
Weight: 2,
|
||||
Port: 3,
|
||||
},
|
||||
{
|
||||
Target: "bar.com.",
|
||||
Priority: 4,
|
||||
Weight: 5,
|
||||
Port: 6,
|
||||
},
|
||||
}; !reflect.DeepEqual(got, want) {
|
||||
jgot, _ := json.Marshal(got)
|
||||
jwant, _ := json.Marshal(want)
|
||||
t.Errorf("SRV = %s; want %s", jgot, jwant)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrap_txt", func(t *testing.T) {
|
||||
txts, err := wrapRes.LookupTXT(ctx, "txt.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := txts, []string{"txt1=one", "txt2=two", "txt3=three"}; !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("TXT = %q; want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wrap_ns", func(t *testing.T) {
|
||||
nss, err := wrapRes.LookupNS(ctx, "ns.test.")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := nss, []*net.NS{
|
||||
{Host: "ns1.foo."},
|
||||
{Host: "ns2.bar."},
|
||||
}; !reflect.DeepEqual(got, want) {
|
||||
jgot, _ := json.Marshal(got)
|
||||
jwant, _ := json.Marshal(want)
|
||||
t.Errorf("NS = %s; want %s", jgot, jwant)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// newWrapResolver returns a resolver that uses r (via handleExitNodeDNSQueryWithNetPkg)
|
||||
// to make DNS requests.
|
||||
func newWrapResolver(r *net.Resolver) *net.Resolver {
|
||||
if runtime.GOOS == "windows" {
|
||||
panic("doesn't work on Windows") // golang.org/issue/33097
|
||||
}
|
||||
return &net.Resolver{
|
||||
PreferGo: true,
|
||||
Dial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return &wrapResolverConn{ctx: ctx, r: r}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type wrapResolverConn struct {
|
||||
ctx context.Context
|
||||
r *net.Resolver
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
var _ net.PacketConn = (*wrapResolverConn)(nil)
|
||||
|
||||
func (*wrapResolverConn) Close() error { return nil }
|
||||
func (*wrapResolverConn) LocalAddr() net.Addr { return fakeAddr{} }
|
||||
func (*wrapResolverConn) RemoteAddr() net.Addr { return fakeAddr{} }
|
||||
func (*wrapResolverConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (*wrapResolverConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (*wrapResolverConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (a *wrapResolverConn) Read(p []byte) (n int, err error) {
|
||||
n, _, err = a.ReadFrom(p)
|
||||
return
|
||||
}
|
||||
|
||||
func (a *wrapResolverConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
n, err = a.buf.Read(p)
|
||||
return n, fakeAddr{}, err
|
||||
}
|
||||
|
||||
func (a *wrapResolverConn) Write(packet []byte) (n int, err error) {
|
||||
return a.WriteTo(packet, fakeAddr{})
|
||||
}
|
||||
|
||||
func (a *wrapResolverConn) WriteTo(q []byte, _ net.Addr) (n int, err error) {
|
||||
resp := parseExitNodeQuery(q)
|
||||
if resp == nil {
|
||||
return 0, errors.New("bad query")
|
||||
}
|
||||
res, err := handleExitNodeDNSQueryWithNetPkg(context.Background(), a.r, resp)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
a.buf.Write(res)
|
||||
return len(q), nil
|
||||
}
|
||||
|
||||
type fakeAddr struct{}
|
||||
|
||||
func (fakeAddr) Network() string { return "unused" }
|
||||
func (fakeAddr) String() string { return "unused-todoAddr" }
|
||||
|
||||
func TestUnARPA(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, want string
|
||||
}{
|
||||
{"", ""},
|
||||
{"bad", ""},
|
||||
{"4.4.8.8.in-addr.arpa.", "8.8.4.4"},
|
||||
{".in-addr.arpa.", ""},
|
||||
{"e.0.0.2.0.0.0.0.0.0.0.0.0.0.0.0.b.0.8.0.a.0.0.4.0.b.8.f.7.0.6.2.ip6.arpa.", "2607:f8b0:400a:80b::200e"},
|
||||
{".ip6.arpa.", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, ok := unARPA(tt.in)
|
||||
if ok != (got != "") {
|
||||
t.Errorf("inconsistent results for %q: (%q, %v)", tt.in, got, ok)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("unARPA(%q) = %q; want %q", tt.in, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue