tailscale/net/dns/manager_tcp_test.go

137 lines
3.1 KiB
Go

// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dns
import (
"encoding/binary"
"io"
"net"
"testing"
"github.com/google/go-cmp/cmp"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/net/tsdial"
"tailscale.com/util/dnsname"
)
// mkDNSRequest builds a wire-format DNS query for domain with question type
// tp and class IN. The message's additional section carries an EDNS(0) OPT
// pseudo-record with no options.
//
// It panics if any step of building the message fails; as a test helper,
// that indicates bad test input rather than a condition to recover from.
func mkDNSRequest(domain dnsname.FQDN, tp dns.Type) []byte {
	var dnsHeader dns.Header
	question := dns.Question{
		Name:  dns.MustNewName(domain.WithTrailingDot()),
		Type:  tp,
		Class: dns.ClassINET,
	}
	builder := dns.NewBuilder(nil, dnsHeader)
	if err := builder.StartQuestions(); err != nil {
		panic(err)
	}
	if err := builder.Question(question); err != nil {
		panic(err)
	}
	if err := builder.StartAdditionals(); err != nil {
		panic(err)
	}
	// For an OPT pseudo-record the CLASS field is not a real DNS class: it
	// advertises the requestor's UDP payload size (RFC 6891 section 6.1.2),
	// here 4095 bytes.
	ednsHeader := dns.ResourceHeader{
		Name:  dns.MustNewName("."),
		Type:  dns.TypeOPT,
		Class: dns.Class(4095),
	}
	if err := builder.OPTResource(ednsHeader, dns.OPTResource{}); err != nil {
		panic(err)
	}
	// Check Finish like every other builder step; silently returning a
	// nil/partial payload would only surface later as a confusing failure.
	payload, err := builder.Finish()
	if err != nil {
		panic(err)
	}
	return payload
}
// TestDNSOverTCP sends A queries for two configured hosts over a stream
// connection handled by Manager.HandleTCPConn, and checks the answers. Each
// message on the stream is framed by a 2-byte big-endian length prefix,
// which is how the test both writes requests and reads responses.
func TestDNSOverTCP(t *testing.T) {
	f := fakeOSConfigurator{
		SplitDNS: true,
		BaseConfig: OSConfig{
			Nameservers:   mustIPs("8.8.8.8"),
			SearchDomains: fqdns("coffee.shop"),
		},
	}
	m := NewManager(t.Logf, &f, nil, new(tsdial.Dialer), nil)
	m.resolver.TestOnlySetHook(f.SetResolver)
	// Fail fast on a bad config instead of letting it show up later as an
	// unexplained wrong-results diff.
	if err := m.Set(Config{
		Hosts: hosts(
			"dave.ts.com.", "1.2.3.4",
			"bradfitz.ts.com.", "2.3.4.5"),
		Routes:        upstreams("ts.com", ""),
		SearchDomains: fqdns("tailscale.com", "universe.tf"),
	}); err != nil {
		t.Fatalf("m.Set: %v", err)
	}
	defer m.Down()

	c, s := net.Pipe()
	defer s.Close()
	go m.HandleTCPConn(s, netaddr.IPPort{})
	defer c.Close()

	wantResults := map[dnsname.FQDN]string{
		"dave.ts.com.":     "1.2.3.4",
		"bradfitz.ts.com.": "2.3.4.5",
	}
	for domain := range wantResults {
		b := mkDNSRequest(domain, dns.TypeA)
		if err := binary.Write(c, binary.BigEndian, uint16(len(b))); err != nil {
			t.Fatalf("writing len: %v", err)
		}
		if _, err := c.Write(b); err != nil {
			t.Fatalf("writing data: %v", err)
		}
	}

	// Responses may arrive in any order, so key them by the question name.
	results := map[dnsname.FQDN]string{}
	for i := 0; i < len(wantResults); i++ {
		var respLength uint16
		if err := binary.Read(c, binary.BigEndian, &respLength); err != nil {
			t.Fatalf("reading len: %v", err)
		}
		resp := make([]byte, int(respLength))
		if _, err := io.ReadFull(c, resp); err != nil {
			t.Fatalf("reading data: %v", err)
		}

		var parser dns.Parser
		if _, err := parser.Start(resp); err != nil {
			t.Errorf("parser.Start() failed: %v", err)
			continue
		}
		q, err := parser.Question()
		if err != nil {
			t.Errorf("parser.Question(): %v", err)
			continue
		}
		if err := parser.SkipAllQuestions(); err != nil {
			t.Errorf("parser.SkipAllQuestions(): %v", err)
			continue
		}
		ah, err := parser.AnswerHeader()
		if err != nil {
			t.Errorf("parser.AnswerHeader(): %v", err)
			continue
		}
		if ah.Type != dns.TypeA {
			t.Errorf("unexpected answer type: got %v, want %v", ah.Type, dns.TypeA)
			continue
		}
		res, err := parser.AResource()
		if err != nil {
			t.Errorf("parser.AResource(): %v", err)
			continue
		}
		results[dnsname.FQDN(q.Name.String())] = net.IP(res.A[:]).String()
	}
	c.Close()

	// cmp.Diff(x, y) marks entries only in x with "-" and only in y with "+";
	// x is wantResults here, so the label is (-want +got).
	if diff := cmp.Diff(wantResults, results); diff != "" {
		t.Errorf("wrong results (-want +got)\n%s", diff)
	}
}