stun: check high bits in Is, add tests

Also use new stun.TxID type in stunner.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2020-02-26 11:34:01 -08:00
parent 2489ea4268
commit 14abc82033
3 changed files with 32 additions and 8 deletions

View File

@ -218,9 +218,7 @@ func mappedAddress(b []byte) (addr []byte, port uint16, err error) {
// Is reports whether b is a STUN message.
func Is(b []byte) bool {
if len(b) < headerLen {
return false // every STUN message must have a 20-byte header
}
// TODO RFC5389 suggests checking the first 2 bits of the header are zero.
return string(b[4:8]) == magicCookie
return len(b) >= headerLen &&
b[0]&0b11000000 == 0 && // top two bits must be zero
string(b[4:8]) == magicCookie
}

View File

@ -166,3 +166,29 @@ func TestParseResponse(t *testing.T) {
})
}
}
func TestIs(t *testing.T) {
const magicCookie = "\x21\x12\xa4\x42"
tests := []struct {
in string
want bool
}{
{"", false},
{"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
{"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
{"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", true},
{"\x00\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00foo", true},
// high bits set:
{"\xf0\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
{"\x40\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", false},
// first byte non-zero, but not high bits:
{"\x20\x00\x00\x00" + magicCookie + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", true},
}
for i, tt := range tests {
pkt := []byte(tt.in)
got := stun.Is(pkt)
if got != tt.want {
t.Errorf("%d. In(%q (%v)) = %v; want %v", i, pkt, pkt, got, tt.want)
}
}
}

View File

@ -40,7 +40,7 @@ type Stunner struct {
type session struct {
replied chan struct{} // closed when server responds
tIDs [][12]byte // transaction IDs sent to a server
tIDs []stun.TxID // transaction IDs sent to a server
}
// Receive delivers a STUN packet to the stunner.
@ -90,7 +90,7 @@ func (s *Stunner) Run(ctx context.Context) error {
}
for _, server := range s.Servers {
// Generate the transaction IDs for this session.
tIDs := make([][12]byte, len(retryDurations))
tIDs := make([]stun.TxID, len(retryDurations))
for i := range tIDs {
if _, err := rand.Read(tIDs[i][:]); err != nil {
return fmt.Errorf("stunner: rand failed: %v", err)
@ -147,7 +147,7 @@ func (s *Stunner) runServer(ctx context.Context, server string) {
}
}
func (s *Stunner) sendSTUN(ctx context.Context, tID [12]byte, server string) error {
func (s *Stunner) sendSTUN(ctx context.Context, tID stun.TxID, server string) error {
host, port, err := net.SplitHostPort(server)
if err != nil {
return err