diff --git a/net/socks5/tssocks/tssocks.go b/net/socks5/tssocks/tssocks.go index d8edb8f77..589364fa7 100644 --- a/net/socks5/tssocks/tssocks.go +++ b/net/socks5/tssocks/tssocks.go @@ -27,34 +27,53 @@ import ( // // If ns is non-nil, it is used for dialing when needed. func NewServer(logf logger.Logf, e wgengine.Engine, ns *netstack.Impl) *socks5.Server { - srv := &socks5.Server{ - Logf: logf, + d := &dialer{ns: ns} + e.AddNetworkMapCallback(d.onNewNetmap) + return &socks5.Server{ + Logf: logf, + Dialer: d.DialContext, } - var ( - mu sync.Mutex // guards the following field - dns netstack.DNSMap - ) - e.AddNetworkMapCallback(func(nm *netmap.NetworkMap) { - mu.Lock() - defer mu.Unlock() - dns = netstack.DNSMapFromNetworkMap(nm) - }) - useNetstackForIP := func(ip netaddr.IP) bool { - // TODO(bradfitz): this isn't exactly right. - // We should also support subnets when the - // prefs are configured as such. - return tsaddr.IsTailscaleIP(ip) - } - srv.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) { - ipp, err := dns.Resolve(ctx, addr) - if err != nil { - return nil, err - } - if ns != nil && useNetstackForIP(ipp.IP()) { - return ns.DialContextTCP(ctx, addr) - } - var d net.Dialer - return d.DialContext(ctx, network, ipp.String()) - } - return srv +} + +// dialer is the Tailscale SOCKS5 dialer. +type dialer struct { + ns *netstack.Impl + + mu sync.Mutex + dns netstack.DNSMap +} + +func (d *dialer) onNewNetmap(nm *netmap.NetworkMap) { + d.mu.Lock() + defer d.mu.Unlock() + d.dns = netstack.DNSMapFromNetworkMap(nm) +} + +func (d *dialer) resolve(ctx context.Context, addr string) (netaddr.IPPort, error) { + d.mu.Lock() + dns := d.dns + d.mu.Unlock() + return dns.Resolve(ctx, addr) +} + +func (d *dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + ipp, err := d.resolve(ctx, addr) + if err != nil { + return nil, err + } + if d.ns != nil && d.useNetstackForIP(ipp.IP()) { + return d.ns.DialContextTCP(ctx, ipp.String()) + } + var stdDialer net.Dialer + return stdDialer.DialContext(ctx, network, ipp.String()) +} + +func (d *dialer) useNetstackForIP(ip netaddr.IP) bool { + if d.ns == nil { + return false + } + // TODO(bradfitz): this isn't exactly right. + // We should also support subnets when the + // prefs are configured as such. + return tsaddr.IsTailscaleIP(ip) }