tsdns: initial implementation of a Tailscale DNS resolver (#396)
Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
This commit is contained in:
parent
5e1ee4be53
commit
511840b1f6
31
ipn/local.go
31
ipn/local.go
|
@ -26,6 +26,7 @@ import (
|
|||
"tailscale.com/wgengine"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/router"
|
||||
"tailscale.com/wgengine/tsdns"
|
||||
)
|
||||
|
||||
// LocalBackend is the glue between the major pieces of the Tailscale
|
||||
|
@ -311,6 +312,7 @@ func (b *LocalBackend) Start(opts Options) error {
|
|||
|
||||
b.send(Notify{NetMap: newSt.NetMap})
|
||||
b.updateFilter(newSt.NetMap)
|
||||
b.updateDNSMap(newSt.NetMap)
|
||||
if disableDERP {
|
||||
b.e.SetDERPMap(nil)
|
||||
} else {
|
||||
|
@ -427,6 +429,27 @@ func (b *LocalBackend) updateFilter(netMap *controlclient.NetworkMap) {
|
|||
b.e.SetFilter(filter.New(netMap.PacketFilter, localNets, b.e.GetFilter(), b.logf))
|
||||
}
|
||||
|
||||
// updateDNSMap updates the domain map in the DNS resolver in wgengine
|
||||
// based on the given netMap and user preferences.
|
||||
func (b *LocalBackend) updateDNSMap(netMap *controlclient.NetworkMap) {
|
||||
if netMap == nil {
|
||||
return
|
||||
}
|
||||
dnsMap := &tsdns.Map{DomainToIP: make(map[string]netaddr.IP)}
|
||||
for _, peer := range netMap.Peers {
|
||||
if len(peer.Addresses) == 0 {
|
||||
continue
|
||||
}
|
||||
domain := peer.Hostinfo.Hostname
|
||||
// Like PeerStatus.SimpleHostName()
|
||||
domain = strings.TrimSuffix(domain, ".local")
|
||||
domain = strings.TrimSuffix(domain, ".localdomain")
|
||||
domain = domain + ".ipn.dev"
|
||||
dnsMap.DomainToIP[domain] = netaddr.IPFrom16(peer.Addresses[0].IP.Addr)
|
||||
}
|
||||
b.e.SetDNSMap(dnsMap)
|
||||
}
|
||||
|
||||
// readPoller is a goroutine that receives service lists from
|
||||
// b.portpoll and propagates them into the controlclient's HostInfo.
|
||||
func (b *LocalBackend) readPoller() {
|
||||
|
@ -667,6 +690,7 @@ func (b *LocalBackend) SetPrefs(new *Prefs) {
|
|||
}
|
||||
|
||||
b.updateFilter(b.netMapCache)
|
||||
b.updateDNSMap(b.netMapCache)
|
||||
|
||||
if old.WantRunning != new.WantRunning {
|
||||
b.stateMachine()
|
||||
|
@ -799,6 +823,13 @@ func routerConfig(cfg *wgcfg.Config, prefs *Prefs, dnsDomains []string) *router.
|
|||
rs.Routes = append(rs.Routes, wgCIDRToNetaddr(peer.AllowedIPs)...)
|
||||
}
|
||||
|
||||
// The Tailscale DNS IP.
|
||||
// TODO(dmytro): make this configurable.
|
||||
rs.Routes = append(rs.Routes, netaddr.IPPrefix{
|
||||
IP: netaddr.IPv4(100, 100, 100, 100),
|
||||
Bits: 32,
|
||||
})
|
||||
|
||||
return rs
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@
|
|||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -137,7 +136,7 @@ func maybeHexdump(flag RunFlags, b []byte) string {
|
|||
var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3)
|
||||
var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10)
|
||||
|
||||
func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacket, r Response, why string) {
|
||||
func (f *Filter) logRateLimit(runflags RunFlags, q *packet.ParsedPacket, r Response, why string) {
|
||||
var verdict string
|
||||
|
||||
if r == Drop && (runflags&LogDrops) != 0 && dropBucket.Allow() {
|
||||
|
@ -151,36 +150,33 @@ func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.ParsedPacke
|
|||
// Note: it is crucial that q.String() be called only if {accept,drop}Bucket.Allow() passes,
|
||||
// since it causes an allocation.
|
||||
if verdict != "" {
|
||||
var qs string
|
||||
if q == nil {
|
||||
qs = fmt.Sprintf("(%d bytes)", len(b))
|
||||
} else {
|
||||
qs = q.String()
|
||||
}
|
||||
f.logf("%s: %s %d %s\n%s", verdict, qs, len(b), why, maybeHexdump(runflags, b))
|
||||
b := q.Buffer()
|
||||
f.logf("%s: %s %d %s\n%s", verdict, q.String(), len(b), why, maybeHexdump(runflags, b))
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filter) RunIn(b []byte, q *packet.ParsedPacket, rf RunFlags) Response {
|
||||
r := f.pre(b, q, rf)
|
||||
// RunIn determines whether this node is allowed to receive q from a Tailscale peer.
|
||||
func (f *Filter) RunIn(q *packet.ParsedPacket, rf RunFlags) Response {
|
||||
r := f.pre(q, rf)
|
||||
if r == Accept || r == Drop {
|
||||
// already logged
|
||||
return r
|
||||
}
|
||||
|
||||
r, why := f.runIn(q)
|
||||
f.logRateLimit(rf, b, q, r, why)
|
||||
f.logRateLimit(rf, q, r, why)
|
||||
return r
|
||||
}
|
||||
|
||||
func (f *Filter) RunOut(b []byte, q *packet.ParsedPacket, rf RunFlags) Response {
|
||||
r := f.pre(b, q, rf)
|
||||
// RunOut determines whether this node is allowed to send q to a Tailscale peer.
|
||||
func (f *Filter) RunOut(q *packet.ParsedPacket, rf RunFlags) Response {
|
||||
r := f.pre(q, rf)
|
||||
if r == Drop || r == Accept {
|
||||
// already logged
|
||||
return r
|
||||
}
|
||||
r, why := f.runOut(q)
|
||||
f.logRateLimit(rf, b, q, r, why)
|
||||
f.logRateLimit(rf, q, r, why)
|
||||
return r
|
||||
}
|
||||
|
||||
|
@ -251,29 +247,28 @@ func (f *Filter) runOut(q *packet.ParsedPacket) (r Response, why string) {
|
|||
return Accept, "ok out"
|
||||
}
|
||||
|
||||
func (f *Filter) pre(b []byte, q *packet.ParsedPacket, rf RunFlags) Response {
|
||||
if len(b) == 0 {
|
||||
func (f *Filter) pre(q *packet.ParsedPacket, rf RunFlags) Response {
|
||||
if len(q.Buffer()) == 0 {
|
||||
// wireguard keepalive packet, always permit.
|
||||
return Accept
|
||||
}
|
||||
if len(b) < 20 {
|
||||
f.logRateLimit(rf, b, nil, Drop, "too short")
|
||||
if len(q.Buffer()) < 20 {
|
||||
f.logRateLimit(rf, q, Drop, "too short")
|
||||
return Drop
|
||||
}
|
||||
q.Decode(b)
|
||||
|
||||
switch q.IPProto {
|
||||
case packet.Unknown:
|
||||
// Unknown packets are dangerous; always drop them.
|
||||
f.logRateLimit(rf, b, q, Drop, "unknown")
|
||||
f.logRateLimit(rf, q, Drop, "unknown")
|
||||
return Drop
|
||||
case packet.IPv6:
|
||||
f.logRateLimit(rf, b, q, Drop, "ipv6")
|
||||
f.logRateLimit(rf, q, Drop, "ipv6")
|
||||
return Drop
|
||||
case packet.Fragment:
|
||||
// Fragments after the first always need to be passed through.
|
||||
// Very small fragments are considered Junk by ParsedPacket.
|
||||
f.logRateLimit(rf, b, q, Accept, "fragment")
|
||||
f.logRateLimit(rf, q, Accept, "fragment")
|
||||
return Accept
|
||||
}
|
||||
|
||||
|
|
|
@ -144,11 +144,12 @@ func TestNoAllocs(t *testing.T) {
|
|||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got := int(testing.AllocsPerRun(1000, func() {
|
||||
var q ParsedPacket
|
||||
q := &ParsedPacket{}
|
||||
q.Decode(test.packet)
|
||||
if test.in {
|
||||
acl.RunIn(test.packet, &q, 0)
|
||||
acl.RunIn(q, 0)
|
||||
} else {
|
||||
acl.RunOut(test.packet, &q, 0)
|
||||
acl.RunOut(q, 0)
|
||||
}
|
||||
}))
|
||||
|
||||
|
@ -187,12 +188,13 @@ func BenchmarkFilter(b *testing.B) {
|
|||
for _, bench := range benches {
|
||||
b.Run(bench.name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var q ParsedPacket
|
||||
q := &ParsedPacket{}
|
||||
q.Decode(bench.packet)
|
||||
// This branch seems to have no measurable impact on performance.
|
||||
if bench.in {
|
||||
acl.RunIn(bench.packet, &q, 0)
|
||||
acl.RunIn(q, 0)
|
||||
} else {
|
||||
acl.RunOut(bench.packet, &q, 0)
|
||||
acl.RunOut(q, 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -215,7 +217,9 @@ func TestPreFilter(t *testing.T) {
|
|||
}
|
||||
f := NewAllowNone(t.Logf)
|
||||
for _, testPacket := range packets {
|
||||
got := f.pre([]byte(testPacket.b), &ParsedPacket{}, LogDrops|LogAccepts)
|
||||
p := &ParsedPacket{}
|
||||
p.Decode(testPacket.b)
|
||||
got := f.pre(p, LogDrops|LogAccepts)
|
||||
if got != testPacket.want {
|
||||
t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b))
|
||||
}
|
||||
|
|
|
@ -102,7 +102,7 @@ func ipChecksum(b []byte) uint16 {
|
|||
// It extracts only the subprotocol id, IP addresses, and (if any) ports,
|
||||
// and shouldn't need any memory allocation.
|
||||
func (q *ParsedPacket) Decode(b []byte) {
|
||||
q.b = nil
|
||||
q.b = b
|
||||
|
||||
if len(b) < ipHeaderLength {
|
||||
q.IPProto = Unknown
|
||||
|
@ -170,7 +170,6 @@ func (q *ParsedPacket) Decode(b []byte) {
|
|||
}
|
||||
q.SrcPort = 0
|
||||
q.DstPort = 0
|
||||
q.b = b
|
||||
q.dataofs = q.subofs + icmpHeaderLength
|
||||
return
|
||||
case TCP:
|
||||
|
@ -181,7 +180,6 @@ func (q *ParsedPacket) Decode(b []byte) {
|
|||
q.SrcPort = get16(sub[0:2])
|
||||
q.DstPort = get16(sub[2:4])
|
||||
q.TCPFlags = sub[13] & 0x3F
|
||||
q.b = b
|
||||
headerLength := (sub[12] & 0xF0) >> 2
|
||||
q.dataofs = q.subofs + int(headerLength)
|
||||
return
|
||||
|
@ -192,7 +190,6 @@ func (q *ParsedPacket) Decode(b []byte) {
|
|||
}
|
||||
q.SrcPort = get16(sub[0:2])
|
||||
q.DstPort = get16(sub[2:4])
|
||||
q.b = b
|
||||
q.dataofs = q.subofs + udpHeaderLength
|
||||
return
|
||||
default:
|
||||
|
@ -244,6 +241,11 @@ func (q *ParsedPacket) UDPHeader() UDPHeader {
|
|||
}
|
||||
}
|
||||
|
||||
// Buffer returns the entire packet buffer.
|
||||
func (q *ParsedPacket) Buffer() []byte {
|
||||
return q.b
|
||||
}
|
||||
|
||||
// Sub returns the IP subprotocol section.
|
||||
func (q *ParsedPacket) Sub(begin, n int) []byte {
|
||||
return q.b[q.subofs+begin : q.subofs+begin+n]
|
||||
|
|
|
@ -90,6 +90,7 @@ var ipv6PacketBuffer = []byte{
|
|||
}
|
||||
|
||||
var ipv6PacketDecode = ParsedPacket{
|
||||
b: ipv6PacketBuffer,
|
||||
IPProto: IPv6,
|
||||
}
|
||||
|
||||
|
@ -100,6 +101,7 @@ var unknownPacketBuffer = []byte{
|
|||
}
|
||||
|
||||
var unknownPacketDecode = ParsedPacket{
|
||||
b: unknownPacketBuffer,
|
||||
IPProto: Unknown,
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,274 @@
|
|||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package tsdns provides a Resolver struct capable of resolving
|
||||
// domains on a Tailscale network.
|
||||
package tsdns
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
dns "golang.org/x/net/dns/dnsmessage"
|
||||
"inet.af/netaddr"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine/packet"
|
||||
)
|
||||
|
||||
// defaultTTL is the TTL in seconds of all responses from Resolver.
|
||||
const defaultTTL = 600
|
||||
|
||||
var (
|
||||
errMapNotSet = errors.New("domain map not set")
|
||||
errNoSuchDomain = errors.New("domain does not exist")
|
||||
errNotImplemented = errors.New("query type not implemented")
|
||||
errNotOurName = errors.New("not an *.ipn.dev domain")
|
||||
errNotQuery = errors.New("not a DNS query")
|
||||
)
|
||||
|
||||
var (
|
||||
defaultIP = packet.IP(binary.BigEndian.Uint32([]byte{100, 100, 100, 100}))
|
||||
defaultPort = uint16(53)
|
||||
)
|
||||
|
||||
// Map is all the data Resolver needs to resolve DNS queries.
|
||||
type Map struct {
|
||||
// DomainToIP is a mapping of Tailscale domains to their IP addresses.
|
||||
// For example, monitoring.ipn.dev -> 100.64.0.1.
|
||||
DomainToIP map[string]netaddr.IP
|
||||
}
|
||||
|
||||
// Resolver is a DNS resolver for domain names of the form *.ipn.dev
|
||||
// It is intended
|
||||
type Resolver struct {
|
||||
logf logger.Logf
|
||||
|
||||
// ip is the IP on which the resolver is listening.
|
||||
ip packet.IP
|
||||
// port is the port on which the resolver is listening.
|
||||
port uint16
|
||||
|
||||
// mu guards the following fields from being updated while used.
|
||||
mu sync.Mutex
|
||||
// dnsMap is the map most recently received from the control server.
|
||||
dnsMap *Map
|
||||
}
|
||||
|
||||
// NewResolver constructs a resolver with default parameters.
|
||||
func NewResolver(logf logger.Logf) *Resolver {
|
||||
r := &Resolver{
|
||||
logf: logf,
|
||||
ip: defaultIP,
|
||||
port: defaultPort,
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// AcceptsPacket determines if the given packet is
|
||||
// directed to this resolver (by ip and port).
|
||||
// We also require that UDP be used to simplify things for now.
|
||||
func (r *Resolver) AcceptsPacket(in *packet.ParsedPacket) bool {
|
||||
return in.DstIP == r.ip && in.DstPort == r.port && in.IPProto == packet.UDP
|
||||
}
|
||||
|
||||
// SetMap sets the resolver's DNS map.
|
||||
func (r *Resolver) SetMap(m *Map) {
|
||||
r.mu.Lock()
|
||||
r.dnsMap = m
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// Resolve maps a given domain name to the IP address of the host that owns it.
|
||||
func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) {
|
||||
// If not a subdomain of ipn.dev, then we must refuse this query.
|
||||
// We do this before checking the map to distinguish beween nonexistent domains
|
||||
// and misdirected queries.
|
||||
if !strings.HasSuffix(domain, ".ipn.dev") {
|
||||
return netaddr.IP{}, dns.RCodeRefused, errNotOurName
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
if r.dnsMap == nil {
|
||||
r.mu.Unlock()
|
||||
return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
|
||||
}
|
||||
addr, found := r.dnsMap.DomainToIP[domain]
|
||||
r.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
return netaddr.IP{}, dns.RCodeNameError, errNoSuchDomain
|
||||
}
|
||||
return addr, dns.RCodeSuccess, nil
|
||||
}
|
||||
|
||||
type response struct {
|
||||
Header dns.Header
|
||||
ResourceHeader dns.ResourceHeader
|
||||
Question dns.Question
|
||||
// TODO(dmytro): support IPv6.
|
||||
IP netaddr.IP
|
||||
}
|
||||
|
||||
// parseQuery parses the query in given packet into a response struct.
|
||||
func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error {
|
||||
var parser dns.Parser
|
||||
var err error
|
||||
|
||||
resp.Header, err = parser.Start(query.Payload())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Header.Response {
|
||||
return errNotQuery
|
||||
}
|
||||
|
||||
resp.Question, err = parser.Question()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeResponse resolves the question stored in resp and sets the answer fields.
|
||||
func (r *Resolver) makeResponse(resp *response) error {
|
||||
var err error
|
||||
|
||||
name := resp.Question.Name.String()
|
||||
if len(name) > 0 {
|
||||
name = name[:len(name)-1]
|
||||
}
|
||||
|
||||
if resp.Question.Type == dns.TypeA {
|
||||
// Remove final dot from name: *.ipn.dev. -> *.ipn.dev
|
||||
resp.IP, resp.Header.RCode, err = r.Resolve(name)
|
||||
} else {
|
||||
resp.Header.RCode = dns.RCodeNotImplemented
|
||||
err = errNotImplemented
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// marshalAnswer serializes the answer record into an active builder.
|
||||
func marshalAnswer(resp *response, builder *dns.Builder) error {
|
||||
var answer dns.AResource
|
||||
|
||||
err := builder.StartAnswers()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
answerHeader := dns.ResourceHeader{
|
||||
Name: resp.Question.Name,
|
||||
Type: dns.TypeA,
|
||||
Class: dns.ClassINET,
|
||||
TTL: defaultTTL,
|
||||
}
|
||||
ip := resp.IP.As16()
|
||||
copy(answer.A[:], ip[12:])
|
||||
return builder.AResource(answerHeader, answer)
|
||||
}
|
||||
|
||||
// marshalResponse serializes the DNS response into an active builder.
|
||||
func marshalResponse(resp *response, builder *dns.Builder) ([]byte, error) {
|
||||
resp.Header.Response = true
|
||||
resp.Header.Authoritative = true
|
||||
if resp.Header.RecursionDesired {
|
||||
resp.Header.RecursionAvailable = true
|
||||
}
|
||||
|
||||
err := builder.StartQuestions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = builder.Question(resp.Question)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Header.RCode == dns.RCodeSuccess {
|
||||
err = marshalAnswer(resp, builder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return builder.Finish()
|
||||
}
|
||||
|
||||
func marshalResponsePacket(query *packet.ParsedPacket, resp *response, buf []byte) ([]byte, error) {
|
||||
udpHeader := query.UDPHeader()
|
||||
udpHeader.ToResponse()
|
||||
offset := udpHeader.Len()
|
||||
|
||||
// dns.Builder appends to the passed buffer (without reallocation when possible),
|
||||
// so we pass in a zero-length slice starting at the point it should start writing.
|
||||
builder := dns.NewBuilder(buf[offset:offset], resp.Header)
|
||||
|
||||
// rbuf is the response slice with the correct length starting at offset.
|
||||
rbuf, err := marshalResponse(resp, &builder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
end := offset + len(rbuf)
|
||||
err = udpHeader.Marshal(buf[:end])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf[:end], nil
|
||||
}
|
||||
|
||||
// Respond writes a response to query into buf and returns buf trimmed to the response length.
|
||||
// It is assumed that r.AcceptsPacket(query) is true.
|
||||
func (r *Resolver) Respond(query *packet.ParsedPacket, buf []byte) ([]byte, error) {
|
||||
var resp response
|
||||
var err error
|
||||
|
||||
// 0. Verify that contract is upheld.
|
||||
if !r.AcceptsPacket(query) {
|
||||
r.logf("[unexpected] tsdns: Respond called on query not for this resolver")
|
||||
resp.Header.RCode = dns.RCodeServerFailure
|
||||
return marshalResponsePacket(query, &resp, buf)
|
||||
}
|
||||
// A DNS response is at least as long as the query
|
||||
if len(buf) < len(query.Buffer()) {
|
||||
r.logf("[unexpected] tsdns: response buffer is too small")
|
||||
resp.Header.RCode = dns.RCodeServerFailure
|
||||
return marshalResponsePacket(query, &resp, buf)
|
||||
}
|
||||
|
||||
// 1. Parse query packet.
|
||||
err = r.parseQuery(query, &resp)
|
||||
// We will not return this error: it is the sender's fault.
|
||||
if err != nil {
|
||||
r.logf("tsdns: error during query parsing: %v", err)
|
||||
resp.Header.RCode = dns.RCodeFormatError
|
||||
return marshalResponsePacket(query, &resp, buf)
|
||||
}
|
||||
|
||||
// 2. Service the query.
|
||||
err = r.makeResponse(&resp)
|
||||
// We will not return this error: it is the sender's fault.
|
||||
if err != nil {
|
||||
r.logf("tsdns: error during name resolution: %v", err)
|
||||
return marshalResponsePacket(query, &resp, buf)
|
||||
}
|
||||
// For now, we require IPv4 in all cases.
|
||||
// If we somehow came up with a non-IPv4 address, it's our fault.
|
||||
if !resp.IP.Is4() {
|
||||
resp.Header.RCode = dns.RCodeServerFailure
|
||||
r.logf("tsdns: error during name resolution: IPv6 address: %v", resp.IP)
|
||||
}
|
||||
|
||||
// 3. Serialize the response.
|
||||
return marshalResponsePacket(query, &resp, buf)
|
||||
}
|
|
@ -10,6 +10,7 @@ import (
|
|||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/tailscale/wireguard-go/device"
|
||||
|
@ -19,10 +20,12 @@ import (
|
|||
"tailscale.com/wgengine/packet"
|
||||
)
|
||||
|
||||
const (
|
||||
readMaxSize = device.MaxMessageSize
|
||||
readOffset = device.MessageTransportHeaderSize
|
||||
)
|
||||
const maxBufferSize = device.MaxMessageSize
|
||||
|
||||
// PacketStartOffset is the minimal amount of leading space that must exist
|
||||
// before &packet[offset] in a packet passed to Read, Write, or InjectInboundDirect.
|
||||
// This is necessary to avoid reallocation in wireguard-go internals.
|
||||
const PacketStartOffset = device.MessageTransportHeaderSize
|
||||
|
||||
// MaxPacketSize is the maximum size (in bytes)
|
||||
// of a packet that can be injected into a tstun.TUN.
|
||||
|
@ -35,7 +38,15 @@ var (
|
|||
ErrFiltered = errors.New("packet dropped by filter")
|
||||
)
|
||||
|
||||
var errPacketTooBig = errors.New("packet too big")
|
||||
var (
|
||||
errPacketTooBig = errors.New("packet too big")
|
||||
errOffsetTooBig = errors.New("offset larger than buffer length")
|
||||
errOffsetTooSmall = errors.New("offset smaller than PacketStartOffset")
|
||||
)
|
||||
|
||||
// FilterFunc is a packet-filtering function with access to the TUN device.
|
||||
// It must not hold onto the packet struct, as its backing storage will be reused.
|
||||
type FilterFunc func(*packet.ParsedPacket, *TUN) filter.Response
|
||||
|
||||
// TUN wraps a tun.Device from wireguard-go,
|
||||
// augmenting it with filtering and packet injection.
|
||||
|
@ -47,10 +58,14 @@ type TUN struct {
|
|||
tdev tun.Device
|
||||
|
||||
// buffer stores the oldest unconsumed packet from tdev.
|
||||
// It is made a static buffer in order to avoid graticious allocation.
|
||||
buffer [readMaxSize]byte
|
||||
// It is made a static buffer in order to avoid allocations.
|
||||
buffer [maxBufferSize]byte
|
||||
// bufferConsumed synchronizes access to buffer (shared by Read and poll).
|
||||
bufferConsumed chan struct{}
|
||||
// parsedPacketPool holds a pool of ParsedPacket structs for use in filtering.
|
||||
// This is needed because escape analysis cannot see that parsed packets
|
||||
// do not escape through {Pre,Post}Filter{In,Out}.
|
||||
parsedPacketPool sync.Pool // of *packet.ParsedPacket
|
||||
|
||||
// closed signals poll (by closing) when the device is closed.
|
||||
closed chan struct{}
|
||||
|
@ -73,8 +88,19 @@ type TUN struct {
|
|||
// filterFlags control the verbosity of logging packet drops/accepts.
|
||||
filterFlags filter.RunFlags
|
||||
|
||||
// insecure disables all filtering when set. This is useful in tests.
|
||||
insecure bool
|
||||
// PreFilterIn is the inbound filter function that runs before the main filter
|
||||
// and therefore sees the packets that may be later dropped by it.
|
||||
PreFilterIn FilterFunc
|
||||
// PostFilterIn is the inbound filter function that runs after the main filter.
|
||||
PostFilterIn FilterFunc
|
||||
// PreFilterOut is the outbound filter function that runs before the main filter
|
||||
// and therefore sees the packets that may be later dropped by it.
|
||||
PreFilterOut FilterFunc
|
||||
// PostFilterOut is the outbound filter function that runs after the main filter.
|
||||
PostFilterOut FilterFunc
|
||||
|
||||
// disableFilter disables all filtering when set. This should only be used in tests.
|
||||
disableFilter bool
|
||||
}
|
||||
|
||||
func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
|
||||
|
@ -87,8 +113,14 @@ func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
|
|||
closed: make(chan struct{}),
|
||||
errors: make(chan error),
|
||||
outbound: make(chan []byte),
|
||||
filterFlags: filter.LogAccepts | filter.LogDrops,
|
||||
// TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets.
|
||||
filterFlags: filter.LogAccepts | filter.LogDrops,
|
||||
}
|
||||
|
||||
tun.parsedPacketPool.New = func() interface{} {
|
||||
return new(packet.ParsedPacket)
|
||||
}
|
||||
|
||||
go tun.poll()
|
||||
// The buffer starts out consumed.
|
||||
tun.bufferConsumed <- struct{}{}
|
||||
|
@ -140,10 +172,10 @@ func (t *TUN) poll() {
|
|||
// continue
|
||||
}
|
||||
|
||||
// Read may use memory in t.buffer before readOffset for mandatory headers.
|
||||
// Read may use memory in t.buffer before PacketStartOffset for mandatory headers.
|
||||
// This is the rationale behind the tun.TUN.{Read,Write} interfaces
|
||||
// and the reason t.buffer has size MaxMessageSize and not MaxContentSize.
|
||||
n, err := t.tdev.Read(t.buffer[:], readOffset)
|
||||
n, err := t.tdev.Read(t.buffer[:], PacketStartOffset)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-t.closed:
|
||||
|
@ -165,26 +197,41 @@ func (t *TUN) poll() {
|
|||
select {
|
||||
case <-t.closed:
|
||||
return
|
||||
case t.outbound <- t.buffer[readOffset : readOffset+n]:
|
||||
case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]:
|
||||
// continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TUN) filterOut(buf []byte) filter.Response {
|
||||
p := t.parsedPacketPool.Get().(*packet.ParsedPacket)
|
||||
defer t.parsedPacketPool.Put(p)
|
||||
p.Decode(buf)
|
||||
|
||||
if t.PreFilterOut != nil {
|
||||
if t.PreFilterOut(p, t) == filter.Drop {
|
||||
return filter.Drop
|
||||
}
|
||||
}
|
||||
|
||||
filt, _ := t.filter.Load().(*filter.Filter)
|
||||
|
||||
if filt == nil {
|
||||
t.logf("Warning: you forgot to use SetFilter()! Packet dropped.")
|
||||
t.logf("tstun: warning: you forgot to use SetFilter()! Packet dropped.")
|
||||
return filter.Drop
|
||||
}
|
||||
|
||||
var p packet.ParsedPacket
|
||||
if filt.RunOut(buf, &p, t.filterFlags) == filter.Accept {
|
||||
return filter.Accept
|
||||
if filt.RunOut(p, t.filterFlags) != filter.Accept {
|
||||
return filter.Drop
|
||||
}
|
||||
|
||||
return filter.Drop
|
||||
if t.PostFilterOut != nil {
|
||||
if t.PostFilterOut(p, t) == filter.Drop {
|
||||
return filter.Drop
|
||||
}
|
||||
}
|
||||
|
||||
return filter.Accept
|
||||
}
|
||||
|
||||
func (t *TUN) Read(buf []byte, offset int) (int, error) {
|
||||
|
@ -200,12 +247,16 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
|
|||
// t.buffer has a fixed location in memory,
|
||||
// so this is the easiest way to tell when it has been consumed.
|
||||
// &packet[0] can be used because empty packets do not reach t.outbound.
|
||||
if &packet[0] == &t.buffer[readOffset] {
|
||||
if &packet[0] == &t.buffer[PacketStartOffset] {
|
||||
t.bufferConsumed <- struct{}{}
|
||||
} else {
|
||||
// If the packet is not from t.buffer, then it is an injected packet.
|
||||
// In this case, we return eary to bypass filtering
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
if !t.insecure {
|
||||
if !t.disableFilter {
|
||||
response := t.filterOut(buf[offset : offset+n])
|
||||
if response != filter.Accept {
|
||||
// Wireguard considers read errors fatal; pretend nothing was read
|
||||
|
@ -217,35 +268,38 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
|
|||
}
|
||||
|
||||
func (t *TUN) filterIn(buf []byte) filter.Response {
|
||||
p := t.parsedPacketPool.Get().(*packet.ParsedPacket)
|
||||
defer t.parsedPacketPool.Put(p)
|
||||
p.Decode(buf)
|
||||
|
||||
if t.PreFilterIn != nil {
|
||||
if t.PreFilterIn(p, t) == filter.Drop {
|
||||
return filter.Drop
|
||||
}
|
||||
}
|
||||
|
||||
filt, _ := t.filter.Load().(*filter.Filter)
|
||||
|
||||
if filt == nil {
|
||||
t.logf("Warning: you forgot to use SetFilter()! Packet dropped.")
|
||||
t.logf("tstun: warning: you forgot to use SetFilter()! Packet dropped.")
|
||||
return filter.Drop
|
||||
}
|
||||
|
||||
var p packet.ParsedPacket
|
||||
if filt.RunIn(buf, &p, t.filterFlags) == filter.Accept {
|
||||
// Only in fake mode, answer any incoming pings.
|
||||
if p.IsEchoRequest() {
|
||||
ft, ok := t.tdev.(*fakeTUN)
|
||||
if ok {
|
||||
header := p.ICMPHeader()
|
||||
header.ToResponse()
|
||||
packet := packet.Generate(&header, p.Payload())
|
||||
ft.Write(packet, 0)
|
||||
// We already handled it, stop.
|
||||
return filter.Drop
|
||||
}
|
||||
}
|
||||
return filter.Accept
|
||||
if filt.RunIn(p, t.filterFlags) != filter.Accept {
|
||||
return filter.Drop
|
||||
}
|
||||
|
||||
return filter.Drop
|
||||
if t.PostFilterIn != nil {
|
||||
if t.PostFilterIn(p, t) == filter.Drop {
|
||||
return filter.Drop
|
||||
}
|
||||
}
|
||||
|
||||
return filter.Accept
|
||||
}
|
||||
|
||||
func (t *TUN) Write(buf []byte, offset int) (int, error) {
|
||||
if !t.insecure {
|
||||
if !t.disableFilter {
|
||||
response := t.filterIn(buf[offset:])
|
||||
if response != filter.Accept {
|
||||
return 0, ErrFiltered
|
||||
|
@ -264,24 +318,53 @@ func (t *TUN) SetFilter(filt *filter.Filter) {
|
|||
t.filter.Store(filt)
|
||||
}
|
||||
|
||||
// InjectInbound makes the TUN device behave as if a packet
|
||||
// InjectInboundDirect makes the TUN device behave as if a packet
|
||||
// with the given contents was received from the network.
|
||||
// It blocks and does not take ownership of the packet.
|
||||
// Injecting an empty packet is a no-op.
|
||||
func (t *TUN) InjectInbound(packet []byte) error {
|
||||
// The injected packet will not pass through inbound filters.
|
||||
//
|
||||
// The packet contents are to start at &buf[offset].
|
||||
// offset must be greater or equal to PacketStartOffset.
|
||||
// The space before &buf[offset] will be used by Wireguard.
|
||||
func (t *TUN) InjectInboundDirect(buf []byte, offset int) error {
|
||||
if len(buf) > MaxPacketSize {
|
||||
return errPacketTooBig
|
||||
}
|
||||
if len(buf) < offset {
|
||||
return errOffsetTooBig
|
||||
}
|
||||
if offset < PacketStartOffset {
|
||||
return errOffsetTooSmall
|
||||
}
|
||||
|
||||
// Write to the underlying device to skip filters.
|
||||
_, err := t.tdev.Write(buf, offset)
|
||||
return err
|
||||
}
|
||||
|
||||
// InjectInboundCopy takes a packet without leading space,
|
||||
// reallocates it to conform to the InjectInbondDirect interface
|
||||
// and calls InjectInboundDirect on it. Injecting a nil packet is a no-op.
|
||||
func (t *TUN) InjectInboundCopy(packet []byte) error {
|
||||
// We duplicate this check from InjectInboundDirect here
|
||||
// to avoid wasting an allocation on an oversized packet.
|
||||
if len(packet) > MaxPacketSize {
|
||||
return errPacketTooBig
|
||||
}
|
||||
if len(packet) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := t.Write(packet, 0)
|
||||
return err
|
||||
|
||||
buf := make([]byte, PacketStartOffset+len(packet))
|
||||
copy(buf[PacketStartOffset:], packet)
|
||||
|
||||
return t.InjectInboundDirect(buf, PacketStartOffset)
|
||||
}
|
||||
|
||||
// InjectOutbound makes the TUN device behave as if a packet
|
||||
// with the given contents was sent to the network.
|
||||
// It does not block, but takes ownership of the packet.
|
||||
// The injected packet will not pass through outbound filters.
|
||||
// Injecting an empty packet is a no-op.
|
||||
func (t *TUN) InjectOutbound(packet []byte) error {
|
||||
if len(packet) > MaxPacketSize {
|
||||
|
|
|
@ -58,7 +58,7 @@ func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *TUN) {
|
|||
if secure {
|
||||
setfilter(logf, tun)
|
||||
} else {
|
||||
tun.insecure = true
|
||||
tun.disableFilter = true
|
||||
}
|
||||
return chtun, tun
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *TUN) {
|
|||
if secure {
|
||||
setfilter(logf, tun)
|
||||
} else {
|
||||
tun.insecure = true
|
||||
tun.disableFilter = true
|
||||
}
|
||||
return ftun.(*fakeTUN), tun
|
||||
}
|
||||
|
@ -151,7 +151,7 @@ func TestWriteAndInject(t *testing.T) {
|
|||
for _, packet := range injected {
|
||||
go func(packet string) {
|
||||
payload := []byte(packet)
|
||||
err := tun.InjectInbound(payload)
|
||||
err := tun.InjectInboundCopy(payload)
|
||||
if err != nil {
|
||||
t.Errorf("%s: error: %v", packet, err)
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ import (
|
|||
"tailscale.com/wgengine/monitor"
|
||||
"tailscale.com/wgengine/packet"
|
||||
"tailscale.com/wgengine/router"
|
||||
"tailscale.com/wgengine/tsdns"
|
||||
"tailscale.com/wgengine/tstun"
|
||||
)
|
||||
|
||||
|
@ -54,6 +55,7 @@ type userspaceEngine struct {
|
|||
tundev *tstun.TUN
|
||||
wgdev *device.Device
|
||||
router router.Router
|
||||
resolver *tsdns.Resolver
|
||||
magicConn *magicsock.Conn
|
||||
linkMon *monitor.Mon
|
||||
|
||||
|
@ -73,6 +75,28 @@ type userspaceEngine struct {
|
|||
// Lock ordering: wgLock, then mu.
|
||||
}
|
||||
|
||||
// RouterGen is the signature for a function that creates a
|
||||
// router.Router.
|
||||
type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error)
|
||||
|
||||
type EngineConfig struct {
|
||||
// Logf is the logging function used by the engine.
|
||||
Logf logger.Logf
|
||||
// TUN is the tun device used by the engine.
|
||||
TUN tun.Device
|
||||
// RouterGen is the function used to instantiate the router.
|
||||
RouterGen RouterGen
|
||||
// ListenPort is the port on which the engine will listen.
|
||||
ListenPort uint16
|
||||
// EchoRespondToAll determines whether ICMP Echo requests incoming from Tailscale peers
|
||||
// will be intercepted and responded to, regardless of the source host.
|
||||
EchoRespondToAll bool
|
||||
// UseTailscaleDNS determines whether DNS requests for names of the form *.ipn.dev
|
||||
// directed to the designated Taislcale DNS address (see wgengine/tsdns)
|
||||
// will be intercepted and resolved by a tsdns.Resolver.
|
||||
UseTailscaleDNS bool
|
||||
}
|
||||
|
||||
type Loggify struct {
|
||||
f logger.Logf
|
||||
}
|
||||
|
@ -84,8 +108,14 @@ func (l *Loggify) Write(b []byte) (int, error) {
|
|||
|
||||
func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16) (Engine, error) {
|
||||
logf("Starting userspace wireguard engine (FAKE tuntap device).")
|
||||
tundev := tstun.WrapTUN(logf, tstun.NewFakeTUN())
|
||||
return NewUserspaceEngineAdvanced(logf, tundev, router.NewFake, listenPort)
|
||||
conf := EngineConfig{
|
||||
Logf: logf,
|
||||
TUN: tstun.NewFakeTUN(),
|
||||
RouterGen: router.NewFake,
|
||||
ListenPort: listenPort,
|
||||
EchoRespondToAll: true,
|
||||
}
|
||||
return NewUserspaceEngineAdvanced(conf)
|
||||
}
|
||||
|
||||
// NewUserspaceEngine creates the named tun device and returns a
|
||||
|
@ -104,38 +134,53 @@ func NewUserspaceEngine(logf logger.Logf, tunname string, listenPort uint16) (En
|
|||
return nil, err
|
||||
}
|
||||
logf("CreateTUN ok.")
|
||||
tundev := tstun.WrapTUN(logf, tun)
|
||||
|
||||
e, err := NewUserspaceEngineAdvanced(logf, tundev, router.New, listenPort)
|
||||
conf := EngineConfig{
|
||||
Logf: logf,
|
||||
TUN: tun,
|
||||
RouterGen: router.New,
|
||||
ListenPort: listenPort,
|
||||
// TODO(dmytro): plumb this down.
|
||||
UseTailscaleDNS: true,
|
||||
}
|
||||
|
||||
e, err := NewUserspaceEngineAdvanced(conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return e, err
|
||||
}
|
||||
|
||||
// RouterGen is the signature for a function that creates a
|
||||
// router.Router.
|
||||
type RouterGen func(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (router.Router, error)
|
||||
|
||||
// NewUserspaceEngineAdvanced is like NewUserspaceEngine but takes a pre-created TUN device and allows specifing
|
||||
// a custom router constructor and listening port.
|
||||
func NewUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (Engine, error) {
|
||||
return newUserspaceEngineAdvanced(logf, tundev, routerGen, listenPort)
|
||||
// NewUserspaceEngineAdvanced is like NewUserspaceEngine
|
||||
// but provides control over all config fields.
|
||||
func NewUserspaceEngineAdvanced(conf EngineConfig) (Engine, error) {
|
||||
return newUserspaceEngineAdvanced(conf)
|
||||
}
|
||||
|
||||
func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen RouterGen, listenPort uint16) (_ Engine, reterr error) {
|
||||
func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
|
||||
logf := conf.Logf
|
||||
|
||||
e := &userspaceEngine{
|
||||
logf: logf,
|
||||
reqCh: make(chan struct{}, 1),
|
||||
waitCh: make(chan struct{}),
|
||||
tundev: tundev,
|
||||
pingers: make(map[wgcfg.Key]*pinger),
|
||||
logf: logf,
|
||||
reqCh: make(chan struct{}, 1),
|
||||
waitCh: make(chan struct{}),
|
||||
tundev: tstun.WrapTUN(logf, conf.TUN),
|
||||
resolver: tsdns.NewResolver(logf),
|
||||
pingers: make(map[wgcfg.Key]*pinger),
|
||||
}
|
||||
e.linkState, _ = getLinkState()
|
||||
|
||||
// Respond to all pings only in fake mode.
|
||||
if conf.EchoRespondToAll {
|
||||
e.tundev.PostFilterIn = echoRespondToAll
|
||||
}
|
||||
if conf.UseTailscaleDNS {
|
||||
e.tundev.PreFilterOut = e.handleDNS
|
||||
}
|
||||
|
||||
mon, err := monitor.New(logf, func() { e.LinkChange(false) })
|
||||
if err != nil {
|
||||
tundev.Close()
|
||||
e.tundev.Close()
|
||||
return nil, err
|
||||
}
|
||||
e.linkMon = mon
|
||||
|
@ -149,12 +194,12 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
|
|||
}
|
||||
magicsockOpts := magicsock.Options{
|
||||
Logf: logf,
|
||||
Port: listenPort,
|
||||
Port: conf.ListenPort,
|
||||
EndpointsFunc: endpointsFn,
|
||||
}
|
||||
e.magicConn, err = magicsock.NewConn(magicsockOpts)
|
||||
if err != nil {
|
||||
tundev.Close()
|
||||
e.tundev.Close()
|
||||
return nil, fmt.Errorf("wgengine: %v", err)
|
||||
}
|
||||
|
||||
|
@ -211,7 +256,7 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
|
|||
|
||||
// Pass the underlying tun.(*NativeDevice) to the router:
|
||||
// routers do not Read or Write, but do access native interfaces.
|
||||
e.router, err = routerGen(logf, e.wgdev, e.tundev.Unwrap())
|
||||
e.router, err = conf.RouterGen(logf, e.wgdev, e.tundev.Unwrap())
|
||||
if err != nil {
|
||||
e.magicConn.Close()
|
||||
return nil, err
|
||||
|
@ -256,6 +301,37 @@ func newUserspaceEngineAdvanced(logf logger.Logf, tundev *tstun.TUN, routerGen R
|
|||
return e, nil
|
||||
}
|
||||
|
||||
// echoRespondToAll is an inbound post-filter responding to all echo requests.
|
||||
func echoRespondToAll(p *packet.ParsedPacket, t *tstun.TUN) filter.Response {
|
||||
if p.IsEchoRequest() {
|
||||
header := p.ICMPHeader()
|
||||
header.ToResponse()
|
||||
packet := packet.Generate(&header, p.Payload())
|
||||
t.InjectOutbound(packet)
|
||||
// We already handled it, stop.
|
||||
return filter.Drop
|
||||
}
|
||||
return filter.Accept
|
||||
}
|
||||
|
||||
// handleDNS is an outbound pre-filter resolving Tailscale domains.
|
||||
func (e *userspaceEngine) handleDNS(p *packet.ParsedPacket, t *tstun.TUN) filter.Response {
|
||||
if e.resolver.AcceptsPacket(p) {
|
||||
// TODO(dmytro): avoid this allocation without having tsdns know tstun quirks.
|
||||
buf := make([]byte, tstun.MaxPacketSize)
|
||||
offset := tstun.PacketStartOffset
|
||||
response, err := e.resolver.Respond(p, buf[offset:])
|
||||
if err != nil {
|
||||
e.logf("DNS resolver error: %v", err)
|
||||
} else {
|
||||
t.InjectInboundDirect(buf[:offset+len(response)], offset)
|
||||
}
|
||||
// We already handled it, stop.
|
||||
return filter.Drop
|
||||
}
|
||||
return filter.Accept
|
||||
}
|
||||
|
||||
// pinger sends ping packets for a few seconds.
|
||||
//
|
||||
// These generated packets are used to ensure we trigger the spray logic in
|
||||
|
@ -447,6 +523,10 @@ func (e *userspaceEngine) SetFilter(filt *filter.Filter) {
|
|||
e.tundev.SetFilter(filt)
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) SetDNSMap(dm *tsdns.Map) {
|
||||
e.resolver.SetMap(dm)
|
||||
}
|
||||
|
||||
func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/router"
|
||||
"tailscale.com/wgengine/tsdns"
|
||||
)
|
||||
|
||||
// NewWatchdog wraps an Engine and makes sure that all methods complete
|
||||
|
@ -74,6 +75,9 @@ func (e *watchdogEngine) GetFilter() *filter.Filter {
|
|||
func (e *watchdogEngine) SetFilter(filt *filter.Filter) {
|
||||
e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) })
|
||||
}
|
||||
func (e *watchdogEngine) SetDNSMap(dm *tsdns.Map) {
|
||||
e.watchdog("SetDNSMap", func() { e.wrap.SetDNSMap(dm) })
|
||||
}
|
||||
func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) {
|
||||
e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) })
|
||||
}
|
||||
|
|
|
@ -10,9 +10,6 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tailscale.com/wgengine/router"
|
||||
"tailscale.com/wgengine/tstun"
|
||||
)
|
||||
|
||||
func TestWatchdog(t *testing.T) {
|
||||
|
@ -20,8 +17,7 @@ func TestWatchdog(t *testing.T) {
|
|||
|
||||
t.Run("default watchdog does not fire", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN())
|
||||
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0)
|
||||
e, err := NewFakeUserspaceEngine(t.Logf, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -37,8 +33,7 @@ func TestWatchdog(t *testing.T) {
|
|||
|
||||
t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tun := tstun.WrapTUN(t.Logf, tstun.NewFakeTUN())
|
||||
e, err := NewUserspaceEngineAdvanced(t.Logf, tun, router.NewFake, 0)
|
||||
e, err := NewFakeUserspaceEngine(t.Logf, 0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/router"
|
||||
"tailscale.com/wgengine/tsdns"
|
||||
)
|
||||
|
||||
// ByteCount is the number of bytes that have been sent or received.
|
||||
|
@ -65,6 +66,9 @@ type Engine interface {
|
|||
// SetFilter updates the packet filter.
|
||||
SetFilter(*filter.Filter)
|
||||
|
||||
// SetDNSMap updates the DNS map.
|
||||
SetDNSMap(*tsdns.Map)
|
||||
|
||||
// SetStatusCallback sets the function to call when the
|
||||
// WireGuard status changes.
|
||||
SetStatusCallback(StatusCallback)
|
||||
|
|
Loading…
Reference in New Issue