288 lines
7.0 KiB
Go
288 lines
7.0 KiB
Go
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package stunner
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"math/rand"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"tailscale.com/net/dnscache"
|
|
"tailscale.com/stun"
|
|
)
|
|
|
|
// Stunner sends a STUN request to several servers and handles a response.
|
|
//
|
|
// It is designed to used on a connection owned by other code and so does
|
|
// not directly reference a net.Conn of any sort. Instead, the user should
|
|
// provide Send function to send packets, and call Receive when a new
|
|
// STUN response is received.
|
|
//
|
|
// In response, a Stunner will call Endpoint with any endpoints determined
|
|
// for the connection. (An endpoint may be reported multiple times if
|
|
// multiple servers are provided.)
|
|
type Stunner struct {
|
|
// Send sends a packet.
|
|
// It will typically be a PacketConn.WriteTo method value.
|
|
Send func([]byte, net.Addr) (int, error) // sends a packet
|
|
|
|
// Endpoint is called whenever a STUN response is received.
|
|
// The server is the STUN server that replied, endpoint is the ip:port
|
|
// from the STUN response, and d is the duration that the STUN request
|
|
// took on the wire (not including DNS lookup time.
|
|
Endpoint func(server, endpoint string, d time.Duration)
|
|
|
|
// onPacket is the internal version of Endpoint that does de-dup.
|
|
// It's set by Run.
|
|
onPacket func(server, endpoint string, d time.Duration)
|
|
|
|
Servers []string // STUN servers to contact
|
|
|
|
// DNSCache optionally specifies a DNSCache to use.
|
|
// If nil, a DNS cache is not used.
|
|
DNSCache *dnscache.Resolver
|
|
|
|
// Logf optionally specifies a log function. If nil, logging is disabled.
|
|
Logf func(format string, args ...interface{})
|
|
|
|
// OnlyIPv6 controls whether IPv6 is exclusively used.
|
|
// If false, only IPv4 is used. There is currently no mixed mode.
|
|
OnlyIPv6 bool
|
|
|
|
mu sync.Mutex
|
|
inFlight map[stun.TxID]request
|
|
}
|
|
|
|
func (s *Stunner) addTX(tx stun.TxID, server string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if _, dup := s.inFlight[tx]; dup {
|
|
panic("unexpected duplicate STUN TransactionID")
|
|
}
|
|
s.inFlight[tx] = request{sent: time.Now(), server: server}
|
|
}
|
|
|
|
func (s *Stunner) removeTX(tx stun.TxID) (request, bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.inFlight == nil {
|
|
return request{}, false
|
|
}
|
|
r, ok := s.inFlight[tx]
|
|
if ok {
|
|
delete(s.inFlight, tx)
|
|
} else {
|
|
s.logf("stunner: got STUN packet for unknown TxID %x", tx)
|
|
}
|
|
return r, ok
|
|
}
|
|
|
|
// request describes one outstanding STUN request.
type request struct {
	sent   time.Time // when the request was recorded by addTX
	server string    // the STUN server the request was addressed to
}
|
|
|
|
func (s *Stunner) logf(format string, args ...interface{}) {
|
|
if s.Logf != nil {
|
|
s.Logf(format, args...)
|
|
}
|
|
}
|
|
|
|
// Receive delivers a STUN packet to the stunner.
|
|
func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) {
|
|
if !stun.Is(p) {
|
|
s.logf("[unexpected] stunner: received non-STUN packet")
|
|
return
|
|
}
|
|
now := time.Now()
|
|
tx, addr, port, err := stun.ParseResponse(p)
|
|
if err != nil {
|
|
s.logf("stunner: received bad STUN response: %v", err)
|
|
return
|
|
}
|
|
r, ok := s.removeTX(tx)
|
|
if !ok {
|
|
return
|
|
}
|
|
d := now.Sub(r.sent)
|
|
|
|
host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port))
|
|
s.onPacket(r.server, host, d)
|
|
}
|
|
|
|
func (s *Stunner) resolver() *net.Resolver {
|
|
return net.DefaultResolver
|
|
}
|
|
|
|
// cleanUpPostRun zeros out some fields, mostly for debugging (so
|
|
// things crash or race+fail if there's a sender still running.)
|
|
func (s *Stunner) cleanUpPostRun() {
|
|
s.mu.Lock()
|
|
s.inFlight = nil
|
|
s.mu.Unlock()
|
|
}
|
|
|
|
// Run starts a Stunner and blocks until all servers either respond
|
|
// or are tried multiple times and timeout.
|
|
// It can not be called concurrently with itself.
|
|
func (s *Stunner) Run(ctx context.Context) error {
|
|
for _, server := range s.Servers {
|
|
if _, _, err := net.SplitHostPort(server); err != nil {
|
|
return fmt.Errorf("Stunner.Run: invalid server %q (in Server list %q)", server, s.Servers)
|
|
}
|
|
}
|
|
if len(s.Servers) == 0 {
|
|
return errors.New("stunner: no Servers")
|
|
}
|
|
|
|
s.inFlight = make(map[stun.TxID]request)
|
|
defer s.cleanUpPostRun()
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
type sender struct {
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
var (
|
|
needMu sync.Mutex
|
|
need = make(map[string]sender) // keyed by server; deleted when done
|
|
allDone = make(chan struct{}) // closed when need is empty
|
|
)
|
|
s.onPacket = func(server, endpoint string, d time.Duration) {
|
|
needMu.Lock()
|
|
defer needMu.Unlock()
|
|
sender, ok := need[server]
|
|
if !ok {
|
|
return
|
|
}
|
|
sender.cancel()
|
|
delete(need, server)
|
|
s.Endpoint(server, endpoint, d)
|
|
if len(need) == 0 {
|
|
close(allDone)
|
|
}
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
for _, server := range s.Servers {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
need[server] = sender{ctx, cancel}
|
|
}
|
|
for server, sender := range need {
|
|
wg.Add(1)
|
|
server, ctx := server, sender.ctx
|
|
go func() {
|
|
defer wg.Done()
|
|
s.sendPackets(ctx, server)
|
|
}()
|
|
}
|
|
var err error
|
|
select {
|
|
case <-ctx.Done():
|
|
err = ctx.Err()
|
|
case <-allDone:
|
|
cancel()
|
|
}
|
|
wg.Wait()
|
|
|
|
var missing []string
|
|
needMu.Lock()
|
|
for server := range need {
|
|
missing = append(missing, server)
|
|
}
|
|
needMu.Unlock()
|
|
|
|
if len(missing) == 0 || err == nil {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("got STUN error: %w; missing replies from: %v", err, strings.Join(missing, ", "))
|
|
}
|
|
|
|
func (s *Stunner) serverAddr(ctx context.Context, server string) (*net.UDPAddr, error) {
|
|
hostStr, portStr, err := net.SplitHostPort(server)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
addrPort, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("port: %v", err)
|
|
}
|
|
if addrPort == 0 {
|
|
addrPort = 3478
|
|
}
|
|
addr := &net.UDPAddr{Port: addrPort}
|
|
|
|
var ipAddrs []net.IPAddr
|
|
if s.DNSCache != nil {
|
|
ip, err := s.DNSCache.LookupIP(ctx, hostStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ipAddrs = []net.IPAddr{{IP: ip}}
|
|
} else {
|
|
ipAddrs, err = s.resolver().LookupIPAddr(ctx, hostStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("lookup ip addr (%q): %v", hostStr, err)
|
|
}
|
|
}
|
|
|
|
for _, ipAddr := range ipAddrs {
|
|
ip4 := ipAddr.IP.To4()
|
|
if ip4 != nil {
|
|
if s.OnlyIPv6 {
|
|
continue
|
|
}
|
|
addr.IP = ip4
|
|
break
|
|
} else if s.OnlyIPv6 {
|
|
addr.IP = ipAddr.IP
|
|
addr.Zone = ipAddr.Zone
|
|
}
|
|
}
|
|
if addr.IP == nil {
|
|
if s.OnlyIPv6 {
|
|
return nil, fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs)
|
|
}
|
|
return nil, fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs)
|
|
}
|
|
return addr, nil
|
|
}
|
|
|
|
func (s *Stunner) sendPackets(ctx context.Context, server string) error {
|
|
addr, err := s.serverAddr(ctx, server)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
const maxSend = 2
|
|
for i := 0; i < maxSend; i++ {
|
|
txID := stun.NewTxID()
|
|
req := stun.Request(txID)
|
|
s.addTX(txID, server)
|
|
_, err = s.Send(req, addr)
|
|
if err != nil {
|
|
return fmt.Errorf("send: %v", err)
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
// Ignore error. The caller deals with handling contexts.
|
|
// We only use it to dermine when to stop spraying STUN packets.
|
|
return nil
|
|
case <-time.After(time.Millisecond * time.Duration(50+rand.Intn(200))):
|
|
}
|
|
}
|
|
return nil
|
|
}
|