tsnet: be stricter about arguments to Server.Listen

Fixes #6201

Change-Id: I14b2b8ce9bee838344a3fad4f305c78ab775f72e
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-11-11 17:55:14 -08:00 committed by Brad Fitzpatrick
parent 08e110ebc5
commit cbc89830c4
2 changed files with 46 additions and 10 deletions

View File

@ -13,6 +13,7 @@ import (
"fmt"
"io"
"log"
"math"
"net"
"net/http"
"net/netip"
@ -38,6 +39,7 @@ import (
"tailscale.com/net/tsdial"
"tailscale.com/smallzstd"
"tailscale.com/types/logger"
"tailscale.com/util/mak"
"tailscale.com/wgengine"
"tailscale.com/wgengine/monitor"
"tailscale.com/wgengine/netstack"
@ -423,7 +425,7 @@ func (s *Server) printAuthURLLoop() {
func (s *Server) forwardTCP(c net.Conn, port uint16) {
s.mu.Lock()
ln, ok := s.listeners[listenKey{"tcp", "", fmt.Sprint(port)}]
ln, ok := s.listeners[listenKey{"tcp", "", port}]
s.mu.Unlock()
if !ok {
c.Close()
@ -500,16 +502,24 @@ func (s *Server) APIClient() (*tailscale.Client, error) {
// Listen announces only on the Tailscale network.
// It will start the server if it has not been started yet.
func (s *Server) Listen(network, addr string) (net.Listener, error) {
host, port, err := net.SplitHostPort(addr)
switch network {
case "", "tcp", "tcp4", "tcp6":
default:
return nil, errors.New("unsupported network type")
}
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("tsnet: %w", err)
}
port, err := net.LookupPort(network, portStr)
if err != nil || port < 0 || port > math.MaxUint16 {
return nil, fmt.Errorf("invalid port: %w", err)
}
if err := s.Start(); err != nil {
return nil, err
}
key := listenKey{network, host, port}
key := listenKey{network, host, uint16(port)}
ln := &listener{
s: s,
key: key,
@ -518,14 +528,11 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
conn: make(chan net.Conn),
}
s.mu.Lock()
if s.listeners == nil {
s.listeners = map[listenKey]*listener{}
}
if _, ok := s.listeners[key]; ok {
s.mu.Unlock()
return nil, fmt.Errorf("tsnet: listener already open for %s, %s", network, addr)
}
s.listeners[key] = ln
mak.Set(&s.listeners, key, ln)
s.mu.Unlock()
return ln, nil
}
@ -533,7 +540,7 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
type listenKey struct {
network string
host string
port string
port uint16
}
type listener struct {

View File

@ -4,7 +4,10 @@
package tsnet
import "testing"
import (
"errors"
"testing"
)
// TestListener_Server ensures that the listener type always keeps the Server
// method, which is used by some external applications to identify a tsnet.Listener
@ -16,3 +19,29 @@ func TestListener_Server(t *testing.T) {
t.Errorf("listener.Server() returned %v, want %v", ln.Server(), s)
}
}
func TestListenerPort(t *testing.T) {
errNone := errors.New("sentinel start error")
tests := []struct {
network string
addr string
wantErr bool
}{
{"tcp", ":80", false},
{"foo", ":80", true},
{"tcp", ":http", false}, // built-in name to Go; doesn't require cgo, /etc/services
{"tcp", ":https", false}, // built-in name to Go; doesn't require cgo, /etc/services
{"tcp", ":gibberishsdlkfj", true},
{"tcp", ":%!d(string=80)", true}, // issue 6201
}
for _, tt := range tests {
s := &Server{}
s.initOnce.Do(func() { s.initErr = errNone })
_, err := s.Listen(tt.network, tt.addr)
gotErr := err != nil && err != errNone
if gotErr != tt.wantErr {
t.Errorf("Listen(%q, %q) error = %v, want %v", tt.network, tt.addr, gotErr, tt.wantErr)
}
}
}