tsdns: initial implementation of a Tailscale DNS resolver (#396)

Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
This commit is contained in:
Dmytro Shynkevych 2020-06-08 18:19:26 -04:00 committed by GitHub
parent 5e1ee4be53
commit 511840b1f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 583 additions and 109 deletions

View File

@ -26,6 +26,7 @@ import (
"tailscale.com/wgengine"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/tsdns"
)
// LocalBackend is the glue between the major pieces of the Tailscale
@ -311,6 +312,7 @@ func (b *LocalBackend) Start(opts Options) error {
b.send(Notify{NetMap: newSt.NetMap})
b.updateFilter(newSt.NetMap)
b.updateDNSMap(newSt.NetMap)
if disableDERP {
b.e.SetDERPMap(nil)
} else {
@ -427,6 +429,27 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap) {
b.e.SetFilter(filter.New(netMap.PacketFilter, localNets, b.e.GetFilter(), b.logf))
}
// updateDNSMap updates the domain map in the DNS resolver in wgengine
// based on the given netMap and user preferences.
func (b *LocalBackend) updateDNSMap(netMap *controlclient.NetworkMap) {
if netMap == nil {
return
}
dnsMap := &tsdns.Map{DomainToIP: make(map[string]netaddr.IP)}
for _, peer := range netMap.Peers {
if len(peer.Addresses) == 0 {
continue
}
domain := peer.Hostinfo.Hostname
// Like PeerStatus.SimpleHostName()
domain = strings.TrimSuffix(domain, ".local")
domain = strings.TrimSuffix(domain, ".localdomain")
domain = domain + ".ipn.dev"
dnsMap.DomainToIP[domain] = netaddr.IPFrom16(peer.Addresses[0].IP.Addr)
}
b.e.SetDNSMap(dnsMap)
}
// readPoller is a goroutine that receives service lists from
// b.portpoll and propagates them into the controlclient's HostInfo.
func (b *LocalBackend) readPoller() {
@ -667,6 +690,7 @@ func (b *LocalBackend) SetPrefs(new *Prefs) {
}
b.updateFilter(b.netMapCache)
b.updateDNSMap(b.netMapCache)
if old.WantRunning != new.WantRunning {
b.stateMachine()
@ -799,6 +823,13 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs, dnsDomains []string) *router.
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...)
}
// The Tailscale DNS IP.
// TODO(dmytro): make this configurable.
rs.Routes = append(rs.Routes, netaddr.IPPrefix{
IP: netaddr.IPv4(100, 100, 100, 100),
Bits: 32,
})
return rs
}

View File

@ -6,7 +6,6 @@
package filter
import (
"fmt"
"sync"
"time"
@ -137,7 +136,7 @@ func maybeHexdump(flag RunFlags, b []byte) string {
var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3)
var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10)
func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacket, r Response, why string) {
func (f *Filter) logRateLimit(runflags RunFlags, q *packet.ParsedPacket, r Response, why string) {
var verdict string
if r == Drop && (runflags&LogDrops) != 0 && dropBucket.Allow() {
@ -151,36 +150,33 @@ func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacke
// Note: it is crucial that q.String() be called only if {accept,drop}Bucket.Allow() passes,
// since it causes an allocation.
if verdict != "" {
var qs string
if q == nil {
qs = fmt.Sprintf("(%d bytes)", len(b))
} else {
qs = q.String()
}
f.logf("%s: %s %d %s\n%s", verdict, qs, len(b), why, maybeHexdump(runflags, b))
b := q.Buffer()
f.logf("%s: %s %d %s\n%s", verdict, q.String(), len(b), why, maybeHexdump(runflags, b))
}
}
func (f *Filter) RunIn(b []byte, q *packet.ParsedPacket, rf RunFlags) Response {
r := f.pre(b, q, rf)
// RunIn determines whether this node is allowed to receive q from a Tailscale peer.
func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response {
r := f.pre(q, rf)
if r == Accept || r == Drop {
// already logged
return r
}
r, why := f.runIn(q)
f.logRateLimit(rf, b, q, r, why)
f.logRateLimit(rf, q, r, why)
return r
}
func (f *Filter) RunOut(b []byte, q *packet.ParsedPacket, rf RunFlags) Response {
r := f.pre(b, q, rf)
// RunOut determines whether this node is allowed to send q to a Tailscale peer.
func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response {
r := f.pre(q, rf)
if r == Drop || r == Accept {
// already logged
return r
}
r, why := f.runOut(q)
f.logRateLimit(rf, b, q, r, why)
f.logRateLimit(rf, q, r, why)
return r
}
@ -251,29 +247,28 @@ func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) {
return Accept, "ok out"
}
func (f *Filter) pre(b []byte, q *packet.ParsedPacket, rf RunFlags) Response {
if len(b) == 0 {
func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags) Response {
if len(q.Buffer()) == 0 {
// wireguard keepalive packet, always permit.
return Accept
}
if len(b) < 20 {
f.logRateLimit(rf, b, nil, Drop, "too short")
if len(q.Buffer()) < 20 {
f.logRateLimit(rf, q, Drop, "too short")
return Drop
}
q.Decode(b)
switch q.IPProto {
case packet.Unknown:
// Unknown packets are dangerous; always drop them.
f.logRateLimit(rf, b, q, Drop, "unknown")
f.logRateLimit(rf, q, Drop, "unknown")
return Drop
case packet.IPv6:
f.logRateLimit(rf, b, q, Drop, "ipv6")
f.logRateLimit(rf, q, Drop, "ipv6")
return Drop
case packet.Fragment:
// Fragments after the first always need to be passed through.
// Very small fragments are considered Junk by ParsedPacket.
f.logRateLimit(rf, b, q, Accept, "fragment")
f.logRateLimit(rf, q, Accept, "fragment")
return Accept
}

View File

@ -144,11 +144,12 @@ func TestNoAllocs(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := int(testing.AllocsPerRun(1000, func() {
var q ParsedPacket
q := &ParsedPacket{}
q.Decode(test.packet)
if test.in {
acl.RunIn(test.packet, &q, 0)
acl.RunIn(q, 0)
} else {
acl.RunOut(test.packet, &q, 0)
acl.RunOut(q, 0)
}
}))
@ -187,12 +188,13 @@ func BenchmarkFilter(b *testing.B) {
for _, bench := range benches {
b.Run(bench.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
var q ParsedPacket
q := &ParsedPacket{}
q.Decode(bench.packet)
// This branch seems to have no measurable impact on performance.
if bench.in {
acl.RunIn(bench.packet, &q, 0)
acl.RunIn(q, 0)
} else {
acl.RunOut(bench.packet, &q, 0)
acl.RunOut(q, 0)
}
}
})
@ -215,7 +217,9 @@ func TestPreFilter(t *testing.T) {
}
f := NewAllowNone(t.Logf)
for _, testPacket := range packets {
got := f.pre([]byte(testPacket.b), &ParsedPacket{}, LogDrops|LogAccepts)
p := &ParsedPacket{}
p.Decode(testPacket.b)
got := f.pre(p, LogDrops|LogAccepts)
if got != testPacket.want {
t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b))
}

View File

@ -102,7 +102,7 @@ func ipChecksum(b []byte) uint16 {
// It extracts only the subprotocol id, IP addresses, and (if any) ports,
// and shouldn't need any memory allocation.
func (q *ParsedPacket) Decode(b []byte) {
q.b = nil
q.b = b
if len(b) < ipHeaderLength {
q.IPProto = Unknown
@ -170,7 +170,6 @@ func (q *ParsedPacket) Decode(b []byte) {
}
q.SrcPort = 0
q.DstPort = 0
q.b = b
q.dataofs = q.subofs + icmpHeaderLength
return
case TCP:
@ -181,7 +180,6 @@ func (q *ParsedPacket) Decode(b []byte) {
q.SrcPort = get16(sub[0:2])
q.DstPort = get16(sub[2:4])
q.TCPFlags = sub[13] & 0x3F
q.b = b
headerLength := (sub[12] & 0xF0) >> 2
q.dataofs = q.subofs + int(headerLength)
return
@ -192,7 +190,6 @@ func (q *ParsedPacket) Decode(b []byte) {
}
q.SrcPort = get16(sub[0:2])
q.DstPort = get16(sub[2:4])
q.b = b
q.dataofs = q.subofs + udpHeaderLength
return
default:
@ -244,6 +241,11 @@ func (q *ParsedPacket) UDPHeader() UDPHeader {
}
}
// Buffer returns the entire packet buffer.
func (q *ParsedPacket) Buffer() []byte {
return q.b
}
// Sub returns the IP subprotocol section.
func (q *ParsedPacket) Sub(begin, n int) []byte {
return q.b[q.subofs+begin : q.subofs+begin+n]

View File

@ -90,6 +90,7 @@ var ipv6PacketBuffer = []byte{
}
var ipv6PacketDecode = ParsedPacket{
b: ipv6PacketBuffer,
IPProto: IPv6,
}
@ -100,6 +101,7 @@ var unknownPacketBuffer = []byte{
}
var unknownPacketDecode = ParsedPacket{
b: unknownPacketBuffer,
IPProto: Unknown,
}

274
wgengine/tsdns/tsdns.go Normal file
View File

@ -0,0 +1,274 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tsdns provides a Resolver struct capable of resolving
// domains on a Tailscale network.
package tsdns
import (
"encoding/binary"
"errors"
"strings"
"sync"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/types/logger"
"tailscale.com/wgengine/packet"
)
// defaultTTL is the TTL in seconds of all responses from Resolver.
const defaultTTL = 600
var (
errMapNotSet = errors.New("domain map not set")
errNoSuchDomain = errors.New("domain does not exist")
errNotImplemented = errors.New("query type not implemented")
errNotOurName = errors.New("not an *.ipn.dev domain")
errNotQuery = errors.New("not a DNS query")
)
var (
defaultIP = packet.IP(binary.BigEndian.Uint32([]byte{100, 100, 100, 100}))
defaultPort = uint16(53)
)
// Map is all the data Resolver needs to resolve DNS queries.
type Map struct {
// DomainToIP is a mapping of Tailscale domains to their IP addresses.
// For example, monitoring.ipn.dev -> 100.64.0.1.
DomainToIP map[string]netaddr.IP
}
// Resolver is a DNS resolver for domain names of the form *.ipn.dev
// It is intended
type Resolver struct {
logf logger.Logf
// ip is the IP on which the resolver is listening.
ip packet.IP
// port is the port on which the resolver is listening.
port uint16
// mu guards the following fields from being updated while used.
mu sync.Mutex
// dnsMap is the map most recently received from the control server.
dnsMap *Map
}
// NewResolver constructs a resolver with default parameters.
func NewResolver(logf logger.Logf) *Resolver {
r := &Resolver{
logf: logf,
ip: defaultIP,
port: defaultPort,
}
return r
}
// AcceptsPacket determines if the given packet is
// directed to this resolver (by ip and port).
// We also require that UDP be used to simplify things for now.
func (r *Resolver) AcceptsPacket(in *packet.ParsedPacket) bool {
return in.DstIP == r.ip && in.DstPort == r.port && in.IPProto == packet.UDP
}
// SetMap sets the resolver's DNS map.
func (r *Resolver) SetMap(m *Map) {
r.mu.Lock()
r.dnsMap = m
r.mu.Unlock()
}
// Resolve maps a given domain name to the IP address of the host that owns it.
func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) {
// If not a subdomain of ipn.dev, then we must refuse this query.
// We do this before checking the map to distinguish beween nonexistent domains
// and misdirected queries.
if !strings.HasSuffix(domain, ".ipn.dev") {
return netaddr.IP{}, dns.RCodeRefused, errNotOurName
}
r.mu.Lock()
if r.dnsMap == nil {
r.mu.Unlock()
return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
}
addr, found := r.dnsMap.DomainToIP[domain]
r.mu.Unlock()
if !found {
return netaddr.IP{}, dns.RCodeNameError, errNoSuchDomain
}
return addr, dns.RCodeSuccess, nil
}
type response struct {
Header dns.Header
ResourceHeader dns.ResourceHeader
Question dns.Question
// TODO(dmytro): support IPv6.
IP netaddr.IP
}
// parseQuery parses the query in given packet into a response struct.
func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error {
var parser dns.Parser
var err error
resp.Header, err = parser.Start(query.Payload())
if err != nil {
return err
}
if resp.Header.Response {
return errNotQuery
}
resp.Question, err = parser.Question()
if err != nil {
return err
}
return nil
}
// makeResponse resolves the question stored in resp and sets the answer fields.
func (r *Resolver) makeResponse(resp *response) error {
var err error
name := resp.Question.Name.String()
if len(name) > 0 {
name = name[:len(name)-1]
}
if resp.Question.Type == dns.TypeA {
// Remove final dot from name: *.ipn.dev. -> *.ipn.dev
resp.IP, resp.Header.RCode, err = r.Resolve(name)
} else {
resp.Header.RCode = dns.RCodeNotImplemented
err = errNotImplemented
}
return err
}
// marshalAnswer serializes the answer record into an active builder.
func marshalAnswer(resp *response, builder *dns.Builder) error {
var answer dns.AResource
err := builder.StartAnswers()
if err != nil {
return err
}
answerHeader := dns.ResourceHeader{
Name: resp.Question.Name,
Type: dns.TypeA,
Class: dns.ClassINET,
TTL: defaultTTL,
}
ip := resp.IP.As16()
copy(answer.A[:], ip[12:])
return builder.AResource(answerHeader, answer)
}
// marshalResponse serializes the DNS response into an active builder.
func marshalResponse(resp *response, builder *dns.Builder) ([]byte, error) {
resp.Header.Response = true
resp.Header.Authoritative = true
if resp.Header.RecursionDesired {
resp.Header.RecursionAvailable = true
}
err := builder.StartQuestions()
if err != nil {
return nil, err
}
err = builder.Question(resp.Question)
if err != nil {
return nil, err
}
if resp.Header.RCode == dns.RCodeSuccess {
err = marshalAnswer(resp, builder)
if err != nil {
return nil, err
}
}
return builder.Finish()
}
func marshalResponsePacket(query *packet.ParsedPacket, resp *response, buf []byte) ([]byte, error) {
udpHeader := query.UDPHeader()
udpHeader.ToResponse()
offset := udpHeader.Len()
// dns.Builder appends to the passed buffer (without reallocation when possible),
// so we pass in a zero-length slice starting at the point it should start writing.
builder := dns.NewBuilder(buf[offset:offset], resp.Header)
// rbuf is the response slice with the correct length starting at offset.
rbuf, err := marshalResponse(resp, &builder)
if err != nil {
return nil, err
}
end := offset + len(rbuf)
err = udpHeader.Marshal(buf[:end])
if err != nil {
return nil, err
}
return buf[:end], nil
}
// Respond writes a response to query into buf and returns buf trimmed to the response length.
// It is assumed that r.AcceptsPacket(query) is true.
func (r *Resolver) Respond(query *packet.ParsedPacket, buf []byte) ([]byte, error) {
var resp response
var err error
// 0. Verify that contract is upheld.
if !r.AcceptsPacket(query) {
r.logf("[unexpected] tsdns: Respond called on query not for this resolver")
resp.Header.RCode = dns.RCodeServerFailure
return marshalResponsePacket(query, &resp, buf)
}
// A DNS response is at least as long as the query
if len(buf) < len(query.Buffer()) {
r.logf("[unexpected] tsdns: response buffer is too small")
resp.Header.RCode = dns.RCodeServerFailure
return marshalResponsePacket(query, &resp, buf)
}
// 1. Parse query packet.
err = r.parseQuery(query, &resp)
// We will not return this error: it is the sender's fault.
if err != nil {
r.logf("tsdns: error during query parsing: %v", err)
resp.Header.RCode = dns.RCodeFormatError
return marshalResponsePacket(query, &resp, buf)
}
// 2. Service the query.
err = r.makeResponse(&resp)
// We will not return this error: it is the sender's fault.
if err != nil {
r.logf("tsdns: error during name resolution: %v", err)
return marshalResponsePacket(query, &resp, buf)
}
// For now, we require IPv4 in all cases.
// If we somehow came up with a non-IPv4 address, it's our fault.
if !resp.IP.Is4() {
resp.Header.RCode = dns.RCodeServerFailure
r.logf("tsdns: error during name resolution: IPv6 address: %v", resp.IP)
}
// 3. Serialize the response.
return marshalResponsePacket(query, &resp, buf)
}

View File

@ -10,6 +10,7 @@ import (
"errors"
"io"
"os"
"sync"
"sync/atomic"
"github.com/tailscale/wireguard-go/device"
@ -19,10 +20,12 @@ import (
"tailscale.com/wgengine/packet"
)
const (
readMaxSize = device.MaxMessageSize
readOffset = device.MessageTransportHeaderSize
)
const maxBufferSize = device.MaxMessageSize
// PacketStartOffset is the minimal amount of leading space that must exist
// before &packet[offset] in a packet passed to Read, Write, or InjectInboundDirect.
// This is necessary to avoid reallocation in wireguard-go internals.
const PacketStartOffset = device.MessageTransportHeaderSize
// MaxPacketSize is the maximum size (in bytes)
// of a packet that can be injected into a tstun.TUN.
@ -35,7 +38,15 @@ var (
ErrFiltered = errors.New("packet dropped by filter")
)
var errPacketTooBig = errors.New("packet too big")
var (
errPacketTooBig = errors.New("packet too big")
errOffsetTooBig = errors.New("offset larger than buffer length")
errOffsetTooSmall = errors.New("offset smaller than PacketStartOffset")
)
// FilterFunc is a packet-filtering function with access to the TUN device.
// It must not hold onto the packet struct, as its backing storage will be reused.
type FilterFunc func(*packet.ParsedPacket, *TUN) filter.Response
// TUN wraps a tun.Device from wireguard-go,
// augmenting it with filtering and packet injection.
@ -47,10 +58,14 @@ type TUN struct {
tdev tun.Device
// buffer stores the oldest unconsumed packet from tdev.
// It is made a static buffer in order to avoid graticious allocation.
buffer [readMaxSize]byte
// It is made a static buffer in order to avoid allocations.
buffer [maxBufferSize]byte
// bufferConsumed synchronizes access to buffer (shared by Read and poll).
bufferConsumed chan struct{}
// parsedPacketPool holds a pool of ParsedPacket structs for use in filtering.
// This is needed because escape analysis cannot see that parsed packets
// do not escape through {Pre,Post}Filter{In,Out}.
parsedPacketPool sync.Pool // of *packet.ParsedPacket
// closed signals poll (by closing) when the device is closed.
closed chan struct{}
@ -73,8 +88,19 @@ type TUN struct {
// filterFlags control the verbosity of logging packet drops/accepts.
filterFlags filter.RunFlags
// insecure disables all filtering when set. This is useful in tests.
insecure bool
// PreFilterIn is the inbound filter function that runs before the main filter
// and therefore sees the packets that may be later dropped by it.
PreFilterIn FilterFunc
// PostFilterIn is the inbound filter function that runs after the main filter.
PostFilterIn FilterFunc
// PreFilterOut is the outbound filter function that runs before the main filter
// and therefore sees the packets that may be later dropped by it.
PreFilterOut FilterFunc
// PostFilterOut is the outbound filter function that runs after the main filter.
PostFilterOut FilterFunc
// disableFilter disables all filtering when set. This should only be used in tests.
disableFilter bool
}
func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
@ -87,8 +113,14 @@ func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
closed: make(chan struct{}),
errors: make(chan error),
outbound: make(chan []byte),
filterFlags: filter.LogAccepts | filter.LogDrops,
// TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets.
filterFlags: filter.LogAccepts | filter.LogDrops,
}
tun.parsedPacketPool.New = func() interface{} {
return new(packet.ParsedPacket)
}
go tun.poll()
// The buffer starts out consumed.
tun.bufferConsumed <- struct{}{}
@ -140,10 +172,10 @@ func (t *TUN) poll() {
// continue
}
// Read may use memory in t.buffer before readOffset for mandatory headers.
// Read may use memory in t.buffer before PacketStartOffset for mandatory headers.
// This is the rationale behind the tun.TUN.{Read,Write} interfaces
// and the reason t.buffer has size MaxMessageSize and not MaxContentSize.
n, err := t.tdev.Read(t.buffer[:], readOffset)
n, err := t.tdev.Read(t.buffer[:], PacketStartOffset)
if err != nil {
select {
case <-t.closed:
@ -165,26 +197,41 @@ func (t *TUN) poll() {
select {
case <-t.closed:
return
case t.outbound <- t.buffer[readOffset : readOffset+n]:
case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]:
// continue
}
}
}
func (t *TUN) filterOut(buf []byte) filter.Response {
p := t.parsedPacketPool.Get().(*packet.ParsedPacket)
defer t.parsedPacketPool.Put(p)
p.Decode(buf)
if t.PreFilterOut != nil {
if t.PreFilterOut(p, t) == filter.Drop {
return filter.Drop
}
}
filt, _ := t.filter.Load().(*filter.Filter)
if filt == nil {
t.logf("Warning: you forgot to use SetFilter()! Packet dropped.")
t.logf("tstun: warning: you forgot to use SetFilter()! Packet dropped.")
return filter.Drop
}
var p packet.ParsedPacket
if filt.RunOut(buf, &p, t.filterFlags) == filter.Accept {
return filter.Accept
if filt.RunOut(p, t.filterFlags) != filter.Accept {
return filter.Drop
}
return filter.Drop
if t.PostFilterOut != nil {
if t.PostFilterOut(p, t) == filter.Drop {
return filter.Drop
}
}
return filter.Accept
}
func (t *TUN) Read(buf []byte, offset int) (int, error) {
@ -200,12 +247,16 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
// t.buffer has a fixed location in memory,
// so this is the easiest way to tell when it has been consumed.
// &packet[0] can be used because empty packets do not reach t.outbound.
if &packet[0] == &t.buffer[readOffset] {
if &packet[0] == &t.buffer[PacketStartOffset] {
t.bufferConsumed <- struct{}{}
} else {
// If the packet is not from t.buffer, then it is an injected packet.
// In this case, we return eary to bypass filtering
return n, nil
}
}
if !t.insecure {
if !t.disableFilter {
response := t.filterOut(buf[offset : offset+n])
if response != filter.Accept {
// Wireguard considers read errors fatal; pretend nothing was read
@ -217,35 +268,38 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
}
func (t *TUN) filterIn(buf []byte) filter.Response {
p := t.parsedPacketPool.Get().(*packet.ParsedPacket)
defer t.parsedPacketPool.Put(p)
p.Decode(buf)
if t.PreFilterIn != nil {
if t.PreFilterIn(p, t) == filter.Drop {
return filter.Drop
}
}
filt, _ := t.filter.Load().(*filter.Filter)
if filt == nil {
t.logf("Warning: you forgot to use SetFilter()! Packet dropped.")
t.logf("tstun: warning: you forgot to use SetFilter()! Packet dropped.")
return filter.Drop
}
var p packet.ParsedPacket
if filt.RunIn(buf, &p, t.filterFlags) == filter.Accept {
// Only in fake mode, answer any incoming pings.
if p.IsEchoRequest() {
ft, ok := t.tdev.(*fakeTUN)
if ok {
header := p.ICMPHeader()
header.ToResponse()
packet := packet.Generate(&header, p.Payload())
ft.Write(packet, 0)
// We already handled it, stop.
return filter.Drop
}
}
return filter.Accept
if filt.RunIn(p, t.filterFlags) != filter.Accept {
return filter.Drop
}
return filter.Drop
if t.PostFilterIn != nil {
if t.PostFilterIn(p, t) == filter.Drop {
return filter.Drop
}
}
return filter.Accept
}
func (t *TUN) Write(buf []byte, offset int) (int, error) {
if !t.insecure {
if !t.disableFilter {
response := t.filterIn(buf[offset:])
if response != filter.Accept {
return 0, ErrFiltered
@ -264,24 +318,53 @@ func (t *TUN) SetFilter(filt *filter.Filter) {
t.filter.Store(filt)
}
// InjectInbound makes the TUN device behave as if a packet
// InjectInboundDirect makes the TUN device behave as if a packet
// with the given contents was received from the network.
// It blocks and does not take ownership of the packet.
// Injecting an empty packet is a no-op.
func (t *TUN) InjectInbound(packet []byte) error {
// The injected packet will not pass through inbound filters.
//
// The packet contents are to start at &buf[offset].
// offset must be greater or equal to PacketStartOffset.
// The space before &buf[offset] will be used by Wireguard.
func (t *TUN) InjectInboundDirect(buf []byte, offset int) error {
if len(buf) > MaxPacketSize {
return errPacketTooBig
}
if len(buf) < offset {
return errOffsetTooBig
}
if offset < PacketStartOffset {
return errOffsetTooSmall
}
// Write to the underlying device to skip filters.
_, err := t.tdev.Write(buf, offset)
return err
}
// InjectInboundCopy takes a packet without leading space,
// reallocates it to conform to the InjectInbondDirect interface
// and calls InjectInboundDirect on it. Injecting a nil packet is a no-op.
func (t *TUN) InjectInboundCopy(packet []byte) error {
// We duplicate this check from InjectInboundDirect here
// to avoid wasting an allocation on an oversized packet.
if len(packet) > MaxPacketSize {
return errPacketTooBig
}
if len(packet) == 0 {
return nil
}
_, err := t.Write(packet, 0)
return err
buf := make([]byte, PacketStartOffset+len(packet))
copy(buf[PacketStartOffset:], packet)
return t.InjectInboundDirect(buf, PacketStartOffset)
}
// InjectOutbound makes the TUN device behave as if a packet
// with the given contents was sent to the network.
// It does not block, but takes ownership of the packet.
// The injected packet will not pass through outbound filters.
// Injecting an empty packet is a no-op.
func (t *TUN) InjectOutbound(packet []byte) error {
if len(packet) > MaxPacketSize {

View File

@ -58,7 +58,7 @@ func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *TUN) {
if secure {
setfilter(logf, tun)
} else {
tun.insecure = true
tun.disableFilter = true
}
return chtun, tun
}
@ -69,7 +69,7 @@ func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *TUN) {
if secure {
setfilter(logf, tun)
} else {
tun.insecure = true
tun.disableFilter = true
}
return ftun.(*fakeTUN), tun
}
@ -151,7 +151,7 @@ func TestWriteAndInject(t *testing.T) {
for _, packet := range injected {
go func(packet string) {
payload := []byte(packet)
err := tun.InjectInbound(payload)
err := tun.InjectInboundCopy(payload)
if err != nil {
t.Errorf("%s: error: %v", packet, err)
}

View File

@ -34,6 +34,7 @@ import (
"tailscale.com/wgengine/monitor"
"tailscale.com/wgengine/packet"
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/tsdns"
"tailscale.com/wgengine/tstun"
)
@ -54,6 +55,7 @@ type userspaceEngine struct {
tundev *tstun.TUN
wgdev *device.Device
router router.Router
resolver *tsdns.Resolver
magicConn *magicsock.Conn
linkMon *monitor.Mon
@ -73,6 +75,28 @@ type userspaceEngine struct {
// Lock ordering: wgLock, then mu.
}
// RouterGen is the signature for a function that creates a
// router.Router.
type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error)
type EngineConfig struct {
// Logf is the logging function used by the engine.
Logf logger.Logf
// TUN is the tun device used by the engine.
TUN tun.Device
// RouterGen is the function used to instantiate the router.
RouterGen RouterGen
// ListenPort is the port on which the engine will listen.
ListenPort uint16
// EchoRespondToAll determines whether ICMP Echo requests incoming from Tailscale peers
// will be intercepted and responded to, regardless of the source host.
EchoRespondToAll bool
// UseTailscaleDNS determines whether DNS requests for names of the form *.ipn.dev
// directed to the designated Taislcale DNS address (see wgengine/tsdns)
// will be intercepted and resolved by a tsdns.Resolver.
UseTailscaleDNS bool
}
type Loggify struct {
f logger.Logf
}
@ -84,8 +108,14 @@ func (l *Loggify) Write(b []byte) (int, error) {
func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) {
logf("Starting userspace wireguard engine (FAKE tuntap device).")
tundev := tstun.WrapTUN(logf, tstun.NewFakeTUN())
return NewUserspaceEngineAdvanced(logf, tundev, router.NewFake, listenPort)
conf := EngineConfig{
Logf: logf,
TUN: tstun.NewFakeTUN(),
RouterGen: router.NewFake,
ListenPort: listenPort,
EchoRespondToAll: true,
}
return NewUserspaceEngineAdvanced(conf)
}
// NewUserspaceEngine creates the named tun device and returns a
@ -104,38 +134,53 @@ func NewUserspaceEngine(logf logger.Logf, tunname string, listenPort uint16) (En
return nil, err
}
logf("CreateTUN ok.")
tundev := tstun.WrapTUN(logf, tun)
e, err := NewUserspaceEngineAdvanced(logf, tundev, router.New, listenPort)
conf := EngineConfig{
Logf: logf,
TUN: tun,
RouterGen: router.New,
ListenPort: listenPort,
// TODO(dmytro): plumb this down.
UseTailscaleDNS: true,
}
e, err := NewUserspaceEngineAdvanced(conf)
if err != nil {
return nil, err
}
return e, err
}
// RouterGen is the signature for a function that creates a
// router.Router.
type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error)
// NewUserspaceEngineAdvanced is like NewUserspaceEngine but takes a pre-created TUN device and allows specifing
// a custom router constructor and listening port.
func NewUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (Engine, error) {
return newUserspaceEngineAdvanced(logf, tundev, routerGen, listenPort)
// NewUserspaceEngineAdvanced is like NewUserspaceEngine
// but provides control over all config fields.
func NewUserspaceEngineAdvanced(conf EngineConfig) (Engine, error) {
return newUserspaceEngineAdvanced(conf)
}
func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (_ Engine, reterr error) {
func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
logf := conf.Logf
e := &userspaceEngine{
logf: logf,
reqCh: make(chan struct{}, 1),
waitCh: make(chan struct{}),
tundev: tundev,
pingers: make(map[wgcfg.Key]*pinger),
logf: logf,
reqCh: make(chan struct{}, 1),
waitCh: make(chan struct{}),
tundev: tstun.WrapTUN(logf, conf.TUN),
resolver: tsdns.NewResolver(logf),
pingers: make(map[wgcfg.Key]*pinger),
}
e.linkState, _ = getLinkState()
// Respond to all pings only in fake mode.
if conf.EchoRespondToAll {
e.tundev.PostFilterIn = echoRespondToAll
}
if conf.UseTailscaleDNS {
e.tundev.PreFilterOut = e.handleDNS
}
mon, err := monitor.New(logf, func() { e.LinkChange(false) })
if err != nil {
tundev.Close()
e.tundev.Close()
return nil, err
}
e.linkMon = mon
@ -149,12 +194,12 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
}
magicsockOpts := magicsock.Options{
Logf: logf,
Port: listenPort,
Port: conf.ListenPort,
EndpointsFunc: endpointsFn,
}
e.magicConn, err = magicsock.NewConn(magicsockOpts)
if err != nil {
tundev.Close()
e.tundev.Close()
return nil, fmt.Errorf("wgengine: %v", err)
}
@ -211,7 +256,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
// Pass the underlying tun.(*NativeDevice) to the router:
// routers do not Read or Write, but do access native interfaces.
e.router, err = routerGen(logf, e.wgdev, e.tundev.Unwrap())
e.router, err = conf.RouterGen(logf, e.wgdev, e.tundev.Unwrap())
if err != nil {
e.magicConn.Close()
return nil, err
@ -256,6 +301,37 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
return e, nil
}
// echoRespondToAll is an inbound post-filter responding to all echo requests.
func echoRespondToAll(p *packet.ParsedPacket, t *tstun.TUN) filter.Response {
if p.IsEchoRequest() {
header := p.ICMPHeader()
header.ToResponse()
packet := packet.Generate(&header, p.Payload())
t.InjectOutbound(packet)
// We already handled it, stop.
return filter.Drop
}
return filter.Accept
}
// handleDNS is an outbound pre-filter resolving Tailscale domains.
func (e *userspaceEngine) handleDNS(p *packet.ParsedPacket, t *tstun.TUN) filter.Response {
if e.resolver.AcceptsPacket(p) {
// TODO(dmytro): avoid this allocation without having tsdns know tstun quirks.
buf := make([]byte, tstun.MaxPacketSize)
offset := tstun.PacketStartOffset
response, err := e.resolver.Respond(p, buf[offset:])
if err != nil {
e.logf("DNS resolver error: %v", err)
} else {
t.InjectInboundDirect(buf[:offset+len(response)], offset)
}
// We already handled it, stop.
return filter.Drop
}
return filter.Accept
}
// pinger sends ping packets for a few seconds.
//
// These generated packets are used to ensure we trigger the spray logic in
@ -447,6 +523,10 @@ func (e *userspaceEngine) SetFilter(filt *filter.Filter) {
e.tundev.SetFilter(filt)
}
func (e *userspaceEngine) SetDNSMap(dm *tsdns.Map) {
e.resolver.SetMap(dm)
}
func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) {
e.mu.Lock()
defer e.mu.Unlock()

View File

@ -15,6 +15,7 @@ import (
"tailscale.com/tailcfg"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/tsdns"
)
// NewWatchdog wraps an Engine and makes sure that all methods complete
@ -74,6 +75,9 @@ func (e *watchdogEngine) GetFilter() *filter.Filter {
func (e *watchdogEngine) SetFilter(filt *filter.Filter) {
e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) })
}
func (e *watchdogEngine) SetDNSMap(dm *tsdns.Map) {
e.watchdog("SetDNSMap", func() { e.wrap.SetDNSMap(dm) })
}
func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) {
e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) })
}

View File

@ -10,9 +10,6 @@ import (
"strings"
"testing"
"time"
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/tstun"
)
func TestWatchdog(t *testing.T) {
@ -20,8 +17,7 @@ func TestWatchdog(t *testing.T) {
t.Run("default watchdog does not fire", func(t *testing.T) {
t.Parallel()
tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN())
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0)
e, err := NewFakeUserspaceEngine(t.Logf, 0)
if err != nil {
t.Fatal(err)
}
@ -37,8 +33,7 @@ func TestWatchdog(t *testing.T) {
t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
t.Parallel()
tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN())
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0)
e, err := NewFakeUserspaceEngine(t.Logf, 0)
if err != nil {
t.Fatal(err)
}

View File

@ -13,6 +13,7 @@ import (
"tailscale.com/tailcfg"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/router"
"tailscale.com/wgengine/tsdns"
)
// ByteCount is the number of bytes that have been sent or received.
@ -65,6 +66,9 @@ type Engine interface {
// SetFilter updates the packet filter.
SetFilter(*filter.Filter)
// SetDNSMap updates the DNS map.
SetDNSMap(*tsdns.Map)
// SetStatusCallback sets the function to call when the
// WireGuard status changes.
SetStatusCallback(StatusCallback)