From b4d3e2928b6b389ed3ac320143bf863f550ba237 Mon Sep 17 00:00:00 2001
From: Maisem Ali
Date: Mon, 13 Mar 2023 20:32:32 -0700
Subject: [PATCH] tsnet: avoid deadlock on close

tsnet.Server.Close was calling listener.Close with the server mutex
held, but the listener close method tries to grab that mutex, resulting
in a deadlock.

Co-authored-by: David Crawshaw
Signed-off-by: Maisem Ali
---
 tsnet/tsnet.go      | 32 ++++++++++++++++++++++++--------
 tsnet/tsnet_test.go | 24 ++++++++++++++++++++++++
 2 files changed, 48 insertions(+), 8 deletions(-)

diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go
index 792c60bae..0c514d0bf 100644
--- a/tsnet/tsnet.go
+++ b/tsnet/tsnet.go
@@ -118,6 +118,7 @@ type Server struct {
 	mu        sync.Mutex
 	listeners map[listenKey]*listener
 	dialer    *tsdial.Dialer
+	closed    bool
 }
 
 // Dial connects to the address on the tailnet.
@@ -303,6 +304,11 @@ func (s *Server) Up(ctx context.Context) (*ipnstate.Status, error) {
 //
 // It must not be called before or concurrently with Start.
 func (s *Server) Close() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.closed {
+		return fmt.Errorf("tsnet: %w", net.ErrClosed)
+	}
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
 	var wg sync.WaitGroup
@@ -350,14 +356,12 @@ func (s *Server) Close() error {
 		s.loopbackListener.Close()
 	}
 
-	s.mu.Lock()
-	defer s.mu.Unlock()
 	for _, ln := range s.listeners {
-		ln.Close()
+		ln.closeLocked()
 	}
-	s.listeners = nil
 
 	wg.Wait()
+	s.closed = true
 	return nil
 }
 
@@ -1017,10 +1021,11 @@ type listenKey struct {
 }
 
 type listener struct {
-	s    *Server
-	keys []listenKey
-	addr string
-	conn chan net.Conn
+	s      *Server
+	keys   []listenKey
+	addr   string
+	conn   chan net.Conn
+	closed bool // guarded by s.mu
 }
 
 func (ln *listener) Accept() (net.Conn, error) {
@@ -1032,15 +1037,26 @@ func (ln *listener) Accept() (net.Conn, error) {
 }
 
 func (ln *listener) Addr() net.Addr { return addr{ln} }
+
 func (ln *listener) Close() error {
 	ln.s.mu.Lock()
 	defer ln.s.mu.Unlock()
+	return ln.closeLocked()
+}
+
+// closeLocked closes the listener.
+// It must be called with ln.s.mu held.
+func (ln *listener) closeLocked() error {
+	if ln.closed {
+		return fmt.Errorf("tsnet: %w", net.ErrClosed)
+	}
 	for _, key := range ln.keys {
 		if v, ok := ln.s.listeners[key]; ok && v == ln {
 			delete(ln.s.listeners, key)
 		}
 	}
 	close(ln.conn)
+	ln.closed = true
 	return nil
 }
 
diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go
index ab55b7b60..0dab542f0 100644
--- a/tsnet/tsnet_test.go
+++ b/tsnet/tsnet_test.go
@@ -9,6 +9,7 @@ import (
 	"flag"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"net/netip"
@@ -344,3 +345,26 @@ func TestTailscaleIPs(t *testing.T) {
 			sIp4, upIp4, sIp6, upIp6)
 	}
 }
+
+// TestListenerCleanup is a regression test to verify that s.Close doesn't
+// deadlock if a listener is still open.
+func TestListenerCleanup(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	controlURL := startControl(t)
+	s1, _ := startServer(t, ctx, controlURL, "s1")
+
+	ln, err := s1.Listen("tcp", ":8081")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if err := s1.Close(); err != nil {
+		t.Fatal(err)
+	}
+
+	if err := ln.Close(); !errors.Is(err, net.ErrClosed) {
+		t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err)
+	}
+}