tsnet: be stricter about arguments to Server.Listen
Fixes #6201 Change-Id: I14b2b8ce9bee838344a3fad4f305c78ab775f72e Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
08e110ebc5
commit
cbc89830c4
|
@ -13,6 +13,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
|
@ -38,6 +39,7 @@ import (
|
|||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/smallzstd"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/wgengine"
|
||||
"tailscale.com/wgengine/monitor"
|
||||
"tailscale.com/wgengine/netstack"
|
||||
|
@ -423,7 +425,7 @@ func (s *Server) printAuthURLLoop() {
|
|||
|
||||
func (s *Server) forwardTCP(c net.Conn, port uint16) {
|
||||
s.mu.Lock()
|
||||
ln, ok := s.listeners[listenKey{"tcp", "", fmt.Sprint(port)}]
|
||||
ln, ok := s.listeners[listenKey{"tcp", "", port}]
|
||||
s.mu.Unlock()
|
||||
if !ok {
|
||||
c.Close()
|
||||
|
@ -500,16 +502,24 @@ func (s *Server) APIClient() (*tailscale.Client, error) {
|
|||
// Listen announces only on the Tailscale network.
|
||||
// It will start the server if it has not been started yet.
|
||||
func (s *Server) Listen(network, addr string) (net.Listener, error) {
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
switch network {
|
||||
case "", "tcp", "tcp4", "tcp6":
|
||||
default:
|
||||
return nil, errors.New("unsupported network type")
|
||||
}
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tsnet: %w", err)
|
||||
}
|
||||
|
||||
port, err := net.LookupPort(network, portStr)
|
||||
if err != nil || port < 0 || port > math.MaxUint16 {
|
||||
return nil, fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
if err := s.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := listenKey{network, host, port}
|
||||
key := listenKey{network, host, uint16(port)}
|
||||
ln := &listener{
|
||||
s: s,
|
||||
key: key,
|
||||
|
@ -518,14 +528,11 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
|
|||
conn: make(chan net.Conn),
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.listeners == nil {
|
||||
s.listeners = map[listenKey]*listener{}
|
||||
}
|
||||
if _, ok := s.listeners[key]; ok {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("tsnet: listener already open for %s, %s", network, addr)
|
||||
}
|
||||
s.listeners[key] = ln
|
||||
mak.Set(&s.listeners, key, ln)
|
||||
s.mu.Unlock()
|
||||
return ln, nil
|
||||
}
|
||||
|
@ -533,7 +540,7 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
|
|||
type listenKey struct {
|
||||
network string
|
||||
host string
|
||||
port string
|
||||
port uint16
|
||||
}
|
||||
|
||||
type listener struct {
|
||||
|
|
|
@ -4,7 +4,10 @@
|
|||
|
||||
package tsnet
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestListener_Server ensures that the listener type always keeps the Server
|
||||
// method, which is used by some external applications to identify a tsnet.Listener
|
||||
|
@ -16,3 +19,29 @@ func TestListener_Server(t *testing.T) {
|
|||
t.Errorf("listener.Server() returned %v, want %v", ln.Server(), s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerPort(t *testing.T) {
|
||||
errNone := errors.New("sentinel start error")
|
||||
|
||||
tests := []struct {
|
||||
network string
|
||||
addr string
|
||||
wantErr bool
|
||||
}{
|
||||
{"tcp", ":80", false},
|
||||
{"foo", ":80", true},
|
||||
{"tcp", ":http", false}, // built-in name to Go; doesn't require cgo, /etc/services
|
||||
{"tcp", ":https", false}, // built-in name to Go; doesn't require cgo, /etc/services
|
||||
{"tcp", ":gibberishsdlkfj", true},
|
||||
{"tcp", ":%!d(string=80)", true}, // issue 6201
|
||||
}
|
||||
for _, tt := range tests {
|
||||
s := &Server{}
|
||||
s.initOnce.Do(func() { s.initErr = errNone })
|
||||
_, err := s.Listen(tt.network, tt.addr)
|
||||
gotErr := err != nil && err != errNone
|
||||
if gotErr != tt.wantErr {
|
||||
t.Errorf("Listen(%q, %q) error = %v, want %v", tt.network, tt.addr, gotErr, tt.wantErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue