net/netns: thread logf into control functions

So that darwin can log there without panicking during tests.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder 2021-11-18 12:18:02 -08:00 committed by Josh Bleecher Snyder
parent 85184a58ed
commit 758c37b83d
16 changed files with 64 additions and 33 deletions

View File

@ -160,7 +160,7 @@ func NewDirect(opts Options) (*Direct, error) {
UseLastGood: true, UseLastGood: true,
LookupIPFallback: dnsfallback.Lookup, LookupIPFallback: dnsfallback.Lookup,
} }
dialer := netns.NewDialer() dialer := netns.NewDialer(opts.Logf)
tr := http.DefaultTransport.(*http.Transport).Clone() tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Proxy = tshttpproxy.ProxyFromEnvironment tr.Proxy = tshttpproxy.ProxyFromEnvironment
tshttpproxy.SetTransportGetProxyConnectHeader(tr) tshttpproxy.SetTransportGetProxyConnectHeader(tr)

View File

@ -429,7 +429,7 @@ func (c *Client) dialURL(ctx context.Context) (net.Conn, error) {
return c.dialer(ctx, "tcp", net.JoinHostPort(host, urlPort(c.url))) return c.dialer(ctx, "tcp", net.JoinHostPort(host, urlPort(c.url)))
} }
hostOrIP := host hostOrIP := host
dialer := netns.NewDialer() dialer := netns.NewDialer(c.logf)
if c.DNSCache != nil { if c.DNSCache != nil {
ip, _, _, err := c.DNSCache.LookupIP(ctx, host) ip, _, _, err := c.DNSCache.LookupIP(ctx, host)
@ -519,7 +519,7 @@ func (c *Client) DialRegionTLS(ctx context.Context, reg *tailcfg.DERPRegion) (tl
} }
func (c *Client) dialContext(ctx context.Context, proto, addr string) (net.Conn, error) { func (c *Client) dialContext(ctx context.Context, proto, addr string) (net.Conn, error) {
return netns.NewDialer().DialContext(ctx, proto, addr) return netns.NewDialer(c.logf).DialContext(ctx, proto, addr)
} }
// shouldDialProto reports whether an explicitly provided IPv4 or IPv6 // shouldDialProto reports whether an explicitly provided IPv4 or IPv6

View File

@ -587,7 +587,7 @@ func newLogtailTransport(host string) *http.Transport {
// Log whenever we dial: // Log whenever we dial:
tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
nd := netns.FromDialer(&net.Dialer{ nd := netns.FromDialer(log.Printf, &net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: netknob.PlatformTCPKeepAlive(), KeepAlive: netknob.PlatformTCPKeepAlive(),
}) })

View File

@ -342,7 +342,7 @@ func (f *forwarder) getKnownDoHClient(ip netaddr.IP) (urlBase string, c *http.Cl
if f.dohClient == nil { if f.dohClient == nil {
f.dohClient = map[string]*http.Client{} f.dohClient = map[string]*http.Client{}
} }
nsDialer := netns.NewDialer() nsDialer := netns.NewDialer(f.logf)
c = &http.Client{ c = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
IdleConnTimeout: dohTransportTimeout, IdleConnTimeout: dohTransportTimeout,

View File

@ -90,7 +90,7 @@ func Lookup(ctx context.Context, host string) ([]netaddr.IP, error) {
// serverName and serverIP of are, say, "derpN.tailscale.com". // serverName and serverIP of are, say, "derpN.tailscale.com".
// queryName is the name being sought (e.g. "controlplane.tailscale.com"), passed as hint. // queryName is the name being sought (e.g. "controlplane.tailscale.com"), passed as hint.
func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netaddr.IP, queryName string) (dnsMap, error) { func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netaddr.IP, queryName string) (dnsMap, error) {
dialer := netns.NewDialer() dialer := netns.NewDialer(log.Printf)
tr := http.DefaultTransport.(*http.Transport).Clone() tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Proxy = tshttpproxy.ProxyFromEnvironment tr.Proxy = tshttpproxy.ProxyFromEnvironment
tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {

View File

@ -807,7 +807,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (_ *Report,
} }
// Create a UDP4 socket used for sending to our discovered IPv4 address. // Create a UDP4 socket used for sending to our discovered IPv4 address.
rs.pc4Hair, err = netns.Listener().ListenPacket(ctx, "udp4", ":0") rs.pc4Hair, err = netns.Listener(c.logf).ListenPacket(ctx, "udp4", ":0")
if err != nil { if err != nil {
c.logf("udp4: %v", err) c.logf("udp4: %v", err)
return nil, err return nil, err
@ -835,7 +835,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (_ *Report,
if f := c.GetSTUNConn4; f != nil { if f := c.GetSTUNConn4; f != nil {
rs.pc4 = f() rs.pc4 = f()
} else { } else {
u4, err := netns.Listener().ListenPacket(ctx, "udp4", c.udpBindAddr()) u4, err := netns.Listener(c.logf).ListenPacket(ctx, "udp4", c.udpBindAddr())
if err != nil { if err != nil {
c.logf("udp4: %v", err) c.logf("udp4: %v", err)
return nil, err return nil, err
@ -848,7 +848,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap) (_ *Report,
if f := c.GetSTUNConn6; f != nil { if f := c.GetSTUNConn6; f != nil {
rs.pc6 = f() rs.pc6 = f()
} else { } else {
u6, err := netns.Listener().ListenPacket(ctx, "udp6", c.udpBindAddr()) u6, err := netns.Listener(c.logf).ListenPacket(ctx, "udp6", c.udpBindAddr())
if err != nil { if err != nil {
c.logf("udp6: %v", err) c.logf("udp6: %v", err)
} else { } else {

View File

@ -21,6 +21,7 @@ import (
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/net/netknob" "tailscale.com/net/netknob"
"tailscale.com/syncs" "tailscale.com/syncs"
"tailscale.com/types/logger"
) )
var disabled syncs.AtomicBool var disabled syncs.AtomicBool
@ -34,19 +35,19 @@ func SetEnabled(on bool) {
// Listener returns a new net.Listener with its Control hook func // Listener returns a new net.Listener with its Control hook func
// initialized as necessary to run in logical network namespace that // initialized as necessary to run in logical network namespace that
// doesn't route back into Tailscale. // doesn't route back into Tailscale.
func Listener() *net.ListenConfig { func Listener(logf logger.Logf) *net.ListenConfig {
if disabled.Get() { if disabled.Get() {
return new(net.ListenConfig) return new(net.ListenConfig)
} }
return &net.ListenConfig{Control: control} return &net.ListenConfig{Control: control(logf)}
} }
// NewDialer returns a new Dialer using a net.Dialer with its Control // NewDialer returns a new Dialer using a net.Dialer with its Control
// hook func initialized as necessary to run in a logical network // hook func initialized as necessary to run in a logical network
// namespace that doesn't route back into Tailscale. It also handles // namespace that doesn't route back into Tailscale. It also handles
// using a SOCKS if configured in the environment with ALL_PROXY. // using a SOCKS if configured in the environment with ALL_PROXY.
func NewDialer() Dialer { func NewDialer(logf logger.Logf) Dialer {
return FromDialer(&net.Dialer{ return FromDialer(logf, &net.Dialer{
KeepAlive: netknob.PlatformTCPKeepAlive(), KeepAlive: netknob.PlatformTCPKeepAlive(),
}) })
} }
@ -55,11 +56,11 @@ func NewDialer() Dialer {
// network namespace that doesn't route back into Tailscale. It also // network namespace that doesn't route back into Tailscale. It also
// handles using a SOCKS if configured in the environment with // handles using a SOCKS if configured in the environment with
// ALL_PROXY. // ALL_PROXY.
func FromDialer(d *net.Dialer) Dialer { func FromDialer(logf logger.Logf, d *net.Dialer) Dialer {
if disabled.Get() { if disabled.Get() {
return d return d
} }
d.Control = control d.Control = control(logf)
if wrapDialer != nil { if wrapDialer != nil {
return wrapDialer(d) return wrapDialer(d)
} }

View File

@ -11,6 +11,8 @@ import (
"fmt" "fmt"
"sync" "sync"
"syscall" "syscall"
"tailscale.com/types/logger"
) )
var ( var (
@ -44,11 +46,15 @@ func SetAndroidProtectFunc(f func(fd int) error) {
androidProtectFunc = f androidProtectFunc = f
} }
// control marks c as necessary to dial in a separate network namespace. func control(logger.Logf) func(network, address string, c syscall.RawConn) error {
return controlC
}
// controlC marks c as necessary to dial in a separate network namespace.
// //
// It's intentionally the same signature as net.Dialer.Control // It's intentionally the same signature as net.Dialer.Control
// and net.ListenConfig.Control. // and net.ListenConfig.Control.
func control(network, address string, c syscall.RawConn) error { func controlC(network, address string, c syscall.RawConn) error {
var sockErr error var sockErr error
err := c.Control(func(fd uintptr) { err := c.Control(func(fd uintptr) {
androidProtectFuncMu.Lock() androidProtectFuncMu.Lock()

View File

@ -9,26 +9,32 @@ package netns
import ( import (
"fmt" "fmt"
"log"
"strings" "strings"
"syscall" "syscall"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"tailscale.com/net/interfaces" "tailscale.com/net/interfaces"
"tailscale.com/types/logger"
) )
// control marks c as necessary to dial in a separate network namespace. func control(logf logger.Logf) func(network, address string, c syscall.RawConn) error {
return func(network, address string, c syscall.RawConn) error {
return controlLogf(logf, network, address, c)
}
}
// controlLogf marks c as necessary to dial in a separate network namespace.
// //
// It's intentionally the same signature as net.Dialer.Control // It's intentionally the same signature as net.Dialer.Control
// and net.ListenConfig.Control. // and net.ListenConfig.Control.
func control(network, address string, c syscall.RawConn) error { func controlLogf(logf logger.Logf, network, address string, c syscall.RawConn) error {
if strings.HasPrefix(address, "127.") || address == "::1" { if strings.HasPrefix(address, "127.") || address == "::1" {
// Don't bind to an interface for localhost connections. // Don't bind to an interface for localhost connections.
return nil return nil
} }
idx, err := interfaces.DefaultRouteInterfaceIndex() idx, err := interfaces.DefaultRouteInterfaceIndex()
if err != nil { if err != nil {
log.Printf("netns: DefaultRouteInterfaceIndex: %v", err) logf("[unexpected] netns: DefaultRouteInterfaceIndex: %v", err)
return nil return nil
} }
v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6 v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6
@ -47,7 +53,7 @@ func control(network, address string, c syscall.RawConn) error {
return fmt.Errorf("RawConn.Control on %T: %w", c, err) return fmt.Errorf("RawConn.Control on %T: %w", c, err)
} }
if sockErr != nil { if sockErr != nil {
log.Printf("netns: control(%q, %q), v6=%v, index=%v: %v", network, address, v6, idx, sockErr) logf("[unexpected] netns: control(%q, %q), v6=%v, index=%v: %v", network, address, v6, idx, sockErr)
} }
return sockErr return sockErr
} }

View File

@ -7,9 +7,17 @@
package netns package netns
import "syscall" import (
"syscall"
// control does nothing to c. "tailscale.com/types/logger"
func control(network, address string, c syscall.RawConn) error { )
func control(logger.Logf) func(network, address string, c syscall.RawConn) error {
return controlC
}
// controlC does nothing to c.
func controlC(network, address string, c syscall.RawConn) error {
return nil return nil
} }

View File

@ -17,6 +17,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"tailscale.com/net/interfaces" "tailscale.com/net/interfaces"
"tailscale.com/types/logger"
) )
// tailscaleBypassMark is the mark indicating that packets originating // tailscaleBypassMark is the mark indicating that packets originating
@ -82,11 +83,15 @@ func ignoreErrors() bool {
return false return false
} }
// control marks c as necessary to dial in a separate network namespace. func control(logger.Logf) func(network, address string, c syscall.RawConn) error {
return controlC
}
// controlC marks c as necessary to dial in a separate network namespace.
// //
// It's intentionally the same signature as net.Dialer.Control // It's intentionally the same signature as net.Dialer.Control
// and net.ListenConfig.Control. // and net.ListenConfig.Control.
func control(network, address string, c syscall.RawConn) error { func controlC(network, address string, c syscall.RawConn) error {
if isLocalhost(address) { if isLocalhost(address) {
// Don't bind to an interface for localhost connections. // Don't bind to an interface for localhost connections.
return nil return nil

View File

@ -25,7 +25,7 @@ func TestDial(t *testing.T) {
if !*extNetwork { if !*extNetwork {
t.Skip("skipping test without --use-external-network") t.Skip("skipping test without --use-external-network")
} }
d := NewDialer() d := NewDialer(t.Logf)
c, err := d.Dial("tcp", "google.com:80") c, err := d.Dial("tcp", "google.com:80")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -12,6 +12,7 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"tailscale.com/net/interfaces" "tailscale.com/net/interfaces"
"tailscale.com/types/logger"
"tailscale.com/util/endian" "tailscale.com/util/endian"
) )
@ -26,9 +27,13 @@ func interfaceIndex(iface *winipcfg.IPAdapterAddresses) uint32 {
return iface.IfIndex return iface.IfIndex
} }
// control binds c to the Windows interface that holds a default func control(logger.Logf) func(network, address string, c syscall.RawConn) error {
return controlC
}
// controlC binds c to the Windows interface that holds a default
// route, and is not the Tailscale WinTun interface. // route, and is not the Tailscale WinTun interface.
func control(network, address string, c syscall.RawConn) error { func controlC(network, address string, c syscall.RawConn) error {
if strings.HasPrefix(address, "127.") { if strings.HasPrefix(address, "127.") {
// Don't bind to an interface for localhost connections, // Don't bind to an interface for localhost connections,
// otherwise we get: // otherwise we get:

View File

@ -251,7 +251,7 @@ func (c *Client) listenPacket(ctx context.Context, network, addr string) (net.Pa
var lc net.ListenConfig var lc net.ListenConfig
return lc.ListenPacket(ctx, network, addr) return lc.ListenPacket(ctx, network, addr)
} }
return netns.Listener().ListenPacket(ctx, network, addr) return netns.Listener(c.logf).ListenPacket(ctx, network, addr)
} }
func (c *Client) invalidateMappingsLocked(releaseOld bool) { func (c *Client) invalidateMappingsLocked(releaseOld bool) {

View File

@ -219,7 +219,7 @@ func (c *Client) upnpHTTPClientLocked() *http.Client {
if c.uPnPHTTPClient == nil { if c.uPnPHTTPClient == nil {
c.uPnPHTTPClient = &http.Client{ c.uPnPHTTPClient = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: netns.NewDialer().DialContext, DialContext: netns.NewDialer(c.logf).DialContext,
IdleConnTimeout: 2 * time.Second, // LAN is cheap IdleConnTimeout: 2 * time.Second, // LAN is cheap
}, },
} }

View File

@ -2719,7 +2719,7 @@ func (c *Conn) listenPacket(network string, port uint16) (net.PacketConn, error)
if c.testOnlyPacketListener != nil { if c.testOnlyPacketListener != nil {
return c.testOnlyPacketListener.ListenPacket(ctx, network, addr) return c.testOnlyPacketListener.ListenPacket(ctx, network, addr)
} }
return netns.Listener().ListenPacket(ctx, network, addr) return netns.Listener(c.logf).ListenPacket(ctx, network, addr)
} }
// bindSocket initializes rucPtr if necessary and binds a UDP socket to it. // bindSocket initializes rucPtr if necessary and binds a UDP socket to it.