249 lines
6.2 KiB
Go
249 lines
6.2 KiB
Go
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package tsdns
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"sync"
|
||
|
"testing"
|
||
|
|
||
|
dns "golang.org/x/net/dns/dnsmessage"
|
||
|
"inet.af/netaddr"
|
||
|
"tailscale.com/wgengine/packet"
|
||
|
)
|
||
|
|
||
|
var dnsMap = &Map{
|
||
|
domainToIP: map[string]netaddr.IP{
|
||
|
"test1.ipn.dev": netaddr.IPv4(1, 2, 3, 4),
|
||
|
"test2.ipn.dev": netaddr.IPv4(5, 6, 7, 8),
|
||
|
},
|
||
|
}
|
||
|
|
||
|
func dnspacket(srcip, dstip packet.IP, domain string, tp dns.Type, response bool) *packet.ParsedPacket {
|
||
|
dnsHeader := dns.Header{Response: response}
|
||
|
question := dns.Question{
|
||
|
Name: dns.MustNewName(domain),
|
||
|
Type: tp,
|
||
|
Class: dns.ClassINET,
|
||
|
}
|
||
|
udpHeader := &packet.UDPHeader{
|
||
|
IPHeader: packet.IPHeader{
|
||
|
SrcIP: srcip,
|
||
|
DstIP: dstip,
|
||
|
IPProto: packet.UDP,
|
||
|
},
|
||
|
SrcPort: 1234,
|
||
|
DstPort: 53,
|
||
|
}
|
||
|
|
||
|
builder := dns.NewBuilder(nil, dnsHeader)
|
||
|
builder.StartQuestions()
|
||
|
builder.Question(question)
|
||
|
payload, _ := builder.Finish()
|
||
|
|
||
|
buf := packet.Generate(udpHeader, payload)
|
||
|
|
||
|
pp := new(packet.ParsedPacket)
|
||
|
pp.Decode(buf)
|
||
|
|
||
|
return pp
|
||
|
}
|
||
|
|
||
|
func TestAcceptsPacket(t *testing.T) {
|
||
|
r := NewResolver(t.Logf)
|
||
|
r.SetMap(dnsMap)
|
||
|
|
||
|
src := packet.IP(0x64656667) // 100.101.102.103
|
||
|
dst := packet.IP(0x64646464) // 100.100.100.100
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
request *packet.ParsedPacket
|
||
|
want bool
|
||
|
}{
|
||
|
{"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false), true},
|
||
|
{"invalid", dnspacket(dst, src, "test1.ipn.dev.", dns.TypeA, false), false},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
accepts := r.AcceptsPacket(tt.request)
|
||
|
if accepts != tt.want {
|
||
|
t.Errorf("accepts = %v; want %v", accepts, tt.want)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestResolve(t *testing.T) {
|
||
|
r := NewResolver(t.Logf)
|
||
|
r.SetMap(dnsMap)
|
||
|
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
domain string
|
||
|
ip netaddr.IP
|
||
|
code dns.RCode
|
||
|
iserr bool
|
||
|
}{
|
||
|
{"valid", "test1.ipn.dev", netaddr.IPv4(1, 2, 3, 4), dns.RCodeSuccess, false},
|
||
|
{"nxdomain", "test3.ipn.dev", netaddr.IP{}, dns.RCodeNameError, true},
|
||
|
{"not our domain", "google.com", netaddr.IP{}, dns.RCodeRefused, true},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
ip, code, err := r.Resolve(tt.domain)
|
||
|
if err != nil && !tt.iserr {
|
||
|
t.Errorf("err = %v; want nil", err)
|
||
|
} else if err == nil && tt.iserr {
|
||
|
t.Errorf("err = nil; want non-nil")
|
||
|
}
|
||
|
if code != tt.code {
|
||
|
t.Errorf("code = %v; want %v", code, tt.code)
|
||
|
}
|
||
|
// Only check ip for non-err
|
||
|
if !tt.iserr && ip != tt.ip {
|
||
|
t.Errorf("ip = %v; want %v", ip, tt.ip)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestConcurrentSet(t *testing.T) {
|
||
|
r := NewResolver(t.Logf)
|
||
|
|
||
|
// This is purely to ensure that Resolve does not race with SetMap.
|
||
|
var wg sync.WaitGroup
|
||
|
wg.Add(2)
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
r.SetMap(dnsMap)
|
||
|
}()
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
r.Resolve("test1.ipn.dev")
|
||
|
}()
|
||
|
wg.Wait()
|
||
|
}
|
||
|
|
||
|
// validResponse is the exact packet TestFull expects Respond to produce for an
// A query for test1.ipn.dev. sent from 100.101.102.103:1234 to
// 100.100.100.100:53. It is written out field by field below. Every byte,
// including both checksums, must match Respond's output.
var validResponse = []byte{
	// IPv4 header (20 bytes).
	0x45, 0x00,             // version 4, IHL 5; TOS 0
	0x00, 0x58,             // total length: 88
	0xff, 0xff,             // identification
	0x00, 0x00,             // flags, fragment offset
	0x40, 0x11,             // TTL 64; protocol 17 (UDP)
	0xe7, 0x00,             // header checksum
	0x64, 0x64, 0x64, 0x64, // source: 100.100.100.100
	0x64, 0x65, 0x66, 0x67, // destination: 100.101.102.103

	// UDP header (8 bytes).
	0x00, 0x35, // source port: 53
	0x04, 0xd2, // destination port: 1234
	0x00, 0x44, // length: 68
	0x53, 0xdd, // checksum

	// DNS header (12 bytes).
	0x00, 0x00, // transaction ID: 0
	0x84, 0x00, // flags: response, authoritative, rcode 0 (no error)
	0x00, 0x01, // question count: 1
	0x00, 0x01, // answer count: 1
	0x00, 0x00, // authority count: 0
	0x00, 0x00, // additional count: 0

	// Question: test1.ipn.dev. IN A
	5, 't', 'e', 's', 't', '1', 3, 'i', 'p', 'n', 3, 'd', 'e', 'v', 0,
	0x00, 0x01, // type: A
	0x00, 0x01, // class: IN

	// Answer: test1.ipn.dev. 600 IN A 1.2.3.4
	5, 't', 'e', 's', 't', '1', 3, 'i', 'p', 'n', 3, 'd', 'e', 'v', 0,
	0x00, 0x01,             // type: A
	0x00, 0x01,             // class: IN
	0x00, 0x00, 0x02, 0x58, // TTL: 600 seconds
	0x00, 0x04,             // data length: 4
	1, 2, 3, 4,             // address: 1.2.3.4
}
|
||
|
|
||
|
// nxdomainResponse is the exact packet TestFull expects Respond to produce for
// an A query for test3.ipn.dev. (a name absent from dnsMap) sent from
// 100.101.102.103:1234 to 100.100.100.100:53. It echoes the question and
// carries no answers. Every byte, including both checksums, must match.
var nxdomainResponse = []byte{
	// IPv4 header (20 bytes).
	0x45, 0x00,             // version 4, IHL 5; TOS 0
	0x00, 0x3b,             // total length: 59
	0xff, 0xff,             // identification
	0x00, 0x00,             // flags, fragment offset
	0x40, 0x11,             // TTL 64; protocol 17 (UDP)
	0xe7, 0x1d,             // header checksum
	0x64, 0x64, 0x64, 0x64, // source: 100.100.100.100
	0x64, 0x65, 0x66, 0x67, // destination: 100.101.102.103

	// UDP header (8 bytes).
	0x00, 0x35, // source port: 53
	0x04, 0xd2, // destination port: 1234
	0x00, 0x27, // length: 39
	0x25, 0x33, // checksum

	// DNS header (12 bytes).
	0x00, 0x00, // transaction ID: 0
	0x84, 0x03, // flags: response, authoritative, rcode 3 (NXDOMAIN)
	0x00, 0x01, // question count: 1
	0x00, 0x00, // answer count: 0
	0x00, 0x00, // authority count: 0
	0x00, 0x00, // additional count: 0

	// Question: test3.ipn.dev. IN A
	5, 't', 'e', 's', 't', '3', 3, 'i', 'p', 'n', 3, 'd', 'e', 'v', 0,
	0x00, 0x01, // type: A
	0x00, 0x01, // class: IN
}
|
||
|
|
||
|
func TestFull(t *testing.T) {
|
||
|
r := NewResolver(t.Logf)
|
||
|
r.SetMap(dnsMap)
|
||
|
|
||
|
src := packet.IP(0x64656667) // 100.101.102.103
|
||
|
dst := packet.IP(0x64646464) // 100.100.100.100
|
||
|
// One full packet and one error packet
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
request *packet.ParsedPacket
|
||
|
response []byte
|
||
|
}{
|
||
|
{"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false), validResponse},
|
||
|
{"error", dnspacket(src, dst, "test3.ipn.dev.", dns.TypeA, false), nxdomainResponse},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
buf := make([]byte, 512)
|
||
|
response, err := r.Respond(tt.request, buf)
|
||
|
if err != nil {
|
||
|
t.Errorf("err = %v; want nil", err)
|
||
|
}
|
||
|
if !bytes.Equal(response, tt.response) {
|
||
|
t.Errorf("response = %x; want %x", response, tt.response)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestAllocs(t *testing.T) {
|
||
|
r := NewResolver(t.Logf)
|
||
|
r.SetMap(dnsMap)
|
||
|
|
||
|
src := packet.IP(0x64656667) // 100.101.102.103
|
||
|
dst := packet.IP(0x64646464) // 100.100.100.100
|
||
|
query := dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false)
|
||
|
|
||
|
buf := make([]byte, 512)
|
||
|
allocs := testing.AllocsPerRun(100, func() {
|
||
|
r.Respond(query, buf)
|
||
|
})
|
||
|
|
||
|
if allocs > 0 {
|
||
|
t.Errorf("allocs = %v; want 0", allocs)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func BenchmarkFull(b *testing.B) {
|
||
|
r := NewResolver(b.Logf)
|
||
|
r.SetMap(dnsMap)
|
||
|
|
||
|
src := packet.IP(0x64656667) // 100.101.102.103
|
||
|
dst := packet.IP(0x64646464) // 100.100.100.100
|
||
|
// One full packet and one error packet
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
request *packet.ParsedPacket
|
||
|
}{
|
||
|
{"valid", dnspacket(src, dst, "test1.ipn.dev.", dns.TypeA, false)},
|
||
|
{"nxdomain", dnspacket(src, dst, "test3.ipn.dev.", dns.TypeA, false)},
|
||
|
}
|
||
|
|
||
|
buf := make([]byte, 512)
|
||
|
for _, tt := range tests {
|
||
|
b.Run(tt.name, func(b *testing.B) {
|
||
|
for i := 0; i < b.N; i++ {
|
||
|
r.Respond(tt.request, buf)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|