220 lines
5.6 KiB
Go
220 lines
5.6 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"expvar"
|
|
"log"
|
|
"math/rand/v2"
|
|
"net"
|
|
"net/http"
|
|
"net/netip"
|
|
"strconv"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"tailscale.com/syncs"
|
|
"tailscale.com/util/mak"
|
|
"tailscale.com/util/slicesx"
|
|
)
|
|
|
|
const refreshTimeout = time.Minute
|
|
|
|
type dnsEntryMap struct {
|
|
IPs map[string][]net.IP
|
|
Percent map[string]float64 // "foo.com" => 0.5 for 50%
|
|
}
|
|
|
|
var (
|
|
dnsCache atomic.Pointer[dnsEntryMap]
|
|
dnsCacheBytes syncs.AtomicValue[[]byte] // of JSON
|
|
unpublishedDNSCache atomic.Pointer[dnsEntryMap]
|
|
bootstrapLookupMap syncs.Map[string, bool]
|
|
)
|
|
|
|
var (
|
|
bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests")
|
|
publishedDNSHits = expvar.NewInt("counter_bootstrap_dns_published_hits")
|
|
publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses")
|
|
unpublishedDNSHits = expvar.NewInt("counter_bootstrap_dns_unpublished_hits")
|
|
unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses")
|
|
unpublishedDNSPercentMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_percent_misses")
|
|
)
|
|
|
|
func init() {
|
|
expvar.Publish("counter_bootstrap_dns_queried_domains", expvar.Func(func() any {
|
|
return bootstrapLookupMap.Len()
|
|
}))
|
|
}
|
|
|
|
func refreshBootstrapDNSLoop() {
|
|
if *bootstrapDNS == "" && *unpublishedDNS == "" {
|
|
return
|
|
}
|
|
for {
|
|
refreshBootstrapDNS()
|
|
refreshUnpublishedDNS()
|
|
time.Sleep(10 * time.Minute)
|
|
}
|
|
}
|
|
|
|
func refreshBootstrapDNS() {
|
|
if *bootstrapDNS == "" {
|
|
return
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
|
|
defer cancel()
|
|
dnsEntries := resolveList(ctx, *bootstrapDNS)
|
|
// Randomize the order of the IPs for each name to avoid the client biasing
|
|
// to IPv6
|
|
for _, vv := range dnsEntries.IPs {
|
|
slicesx.Shuffle(vv)
|
|
}
|
|
j, err := json.MarshalIndent(dnsEntries.IPs, "", "\t")
|
|
if err != nil {
|
|
// leave the old values in place
|
|
return
|
|
}
|
|
|
|
dnsCache.Store(dnsEntries)
|
|
dnsCacheBytes.Store(j)
|
|
}
|
|
|
|
func refreshUnpublishedDNS() {
|
|
if *unpublishedDNS == "" {
|
|
return
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
|
|
defer cancel()
|
|
dnsEntries := resolveList(ctx, *unpublishedDNS)
|
|
unpublishedDNSCache.Store(dnsEntries)
|
|
}
|
|
|
|
// resolveList takes a comma-separated list of DNS names to resolve.
|
|
//
|
|
// If an entry contains a slash, it's two DNS names: the first is the one to
|
|
// resolve and the second is that of a TXT recording containing the rollout
|
|
// percentage in range "0".."100". If the TXT record doesn't exist or is
|
|
// malformed, the percentage is 0. If the TXT record is not provided (there's no
|
|
// slash), then the percentage is 100.
|
|
func resolveList(ctx context.Context, list string) *dnsEntryMap {
|
|
ents := strings.Split(list, ",")
|
|
|
|
ret := &dnsEntryMap{}
|
|
|
|
var r net.Resolver
|
|
for _, ent := range ents {
|
|
name, txtName, _ := strings.Cut(ent, "/")
|
|
addrs, err := r.LookupIP(ctx, "ip", name)
|
|
if err != nil {
|
|
log.Printf("bootstrap DNS lookup %q: %v", name, err)
|
|
continue
|
|
}
|
|
mak.Set(&ret.IPs, name, addrs)
|
|
|
|
if txtName == "" {
|
|
mak.Set(&ret.Percent, name, 1.0)
|
|
continue
|
|
}
|
|
vals, err := r.LookupTXT(ctx, txtName)
|
|
if err != nil {
|
|
log.Printf("bootstrap DNS lookup %q: %v", txtName, err)
|
|
continue
|
|
}
|
|
for _, v := range vals {
|
|
if v, err := strconv.Atoi(v); err == nil && v >= 0 && v <= 100 {
|
|
mak.Set(&ret.Percent, name, float64(v)/100)
|
|
}
|
|
}
|
|
}
|
|
return ret
|
|
}
|
|
|
|
func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
|
|
bootstrapDNSRequests.Add(1)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
// Bootstrap DNS requests occur cross-regions, and are randomized per
|
|
// request, so keeping a connection open is pointlessly expensive.
|
|
w.Header().Set("Connection", "close")
|
|
|
|
// Try answering a query from our hidden map first
|
|
if q := r.URL.Query().Get("q"); q != "" {
|
|
bootstrapLookupMap.Store(q, true)
|
|
if bootstrapLookupMap.Len() > 500 { // defensive
|
|
bootstrapLookupMap.Clear()
|
|
}
|
|
if m := unpublishedDNSCache.Load(); m != nil && len(m.IPs[q]) > 0 {
|
|
unpublishedDNSHits.Add(1)
|
|
|
|
percent := m.Percent[q]
|
|
if remoteAddrMatchesPercent(r.RemoteAddr, percent) {
|
|
// Only return the specific query, not everything.
|
|
m := map[string][]net.IP{q: m.IPs[q]}
|
|
j, err := json.MarshalIndent(m, "", "\t")
|
|
if err == nil {
|
|
w.Write(j)
|
|
return
|
|
}
|
|
} else {
|
|
unpublishedDNSPercentMisses.Add(1)
|
|
}
|
|
}
|
|
|
|
// If we have a "q" query for a name in the published cache
|
|
// list, then track whether that's a hit/miss.
|
|
m := dnsCache.Load()
|
|
var inPub bool
|
|
var ips []net.IP
|
|
if m != nil {
|
|
ips, inPub = m.IPs[q]
|
|
}
|
|
if inPub {
|
|
if len(ips) > 0 {
|
|
publishedDNSHits.Add(1)
|
|
} else {
|
|
publishedDNSMisses.Add(1)
|
|
}
|
|
} else {
|
|
// If it wasn't in either cache, treat this as a query
|
|
// for the unpublished cache, and thus a cache miss.
|
|
unpublishedDNSMisses.Add(1)
|
|
}
|
|
}
|
|
|
|
// Fall back to returning the public set of cached DNS names
|
|
j := dnsCacheBytes.Load()
|
|
w.Write(j)
|
|
}
|
|
|
|
// percent is [0.0, 1.0].
|
|
func remoteAddrMatchesPercent(remoteAddr string, percent float64) bool {
|
|
if percent == 0 {
|
|
return false
|
|
}
|
|
if percent == 1 {
|
|
return true
|
|
}
|
|
reqIPStr, _, err := net.SplitHostPort(remoteAddr)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
reqIP, err := netip.ParseAddr(reqIPStr)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
if reqIP.IsLoopback() {
|
|
// For local testing.
|
|
return rand.Float64() < 0.5
|
|
}
|
|
reqIP16 := reqIP.As16()
|
|
rndSrc := rand.NewPCG(binary.LittleEndian.Uint64(reqIP16[:8]), binary.LittleEndian.Uint64(reqIP16[8:]))
|
|
rnd := rand.New(rndSrc)
|
|
return percent > rnd.Float64()
|
|
}
|