tailcfg: implement text encoding for ProtoPortRange
Updates tailscale/corp#15043 Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
parent
96f01a73b1
commit
4abd470322
|
@ -13,6 +13,11 @@ import (
|
||||||
"tailscale.com/util/vizerror"
|
"tailscale.com/util/vizerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errEmptyProtocol = errors.New("empty protocol")
|
||||||
|
errEmptyString = errors.New("empty string")
|
||||||
|
)
|
||||||
|
|
||||||
// ProtoPortRange is used to encode "proto:port" format.
|
// ProtoPortRange is used to encode "proto:port" format.
|
||||||
// The following formats are supported:
|
// The following formats are supported:
|
||||||
//
|
//
|
||||||
|
@ -30,6 +35,28 @@ type ProtoPortRange struct {
|
||||||
Ports PortRange
|
Ports PortRange
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalText implements the encoding.TextUnmarshaler interface. See
|
||||||
|
// ProtoPortRange for the format.
|
||||||
|
func (ppr *ProtoPortRange) UnmarshalText(text []byte) error {
|
||||||
|
ppr2, err := parseProtoPortRange(string(text))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*ppr = *ppr2
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalText implements the encoding.TextMarshaler interface. See
|
||||||
|
// ProtoPortRange for the format.
|
||||||
|
func (ppr *ProtoPortRange) MarshalText() ([]byte, error) {
|
||||||
|
if ppr.Proto == 0 && ppr.Ports == (PortRange{}) {
|
||||||
|
return []byte{}, nil
|
||||||
|
}
|
||||||
|
return []byte(ppr.String()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements the stringer interface. See ProtoPortRange for the
|
||||||
|
// format.
|
||||||
func (ppr ProtoPortRange) String() string {
|
func (ppr ProtoPortRange) String() string {
|
||||||
if ppr.Proto == 0 {
|
if ppr.Proto == 0 {
|
||||||
if ppr.Ports == PortRangeAny {
|
if ppr.Ports == PortRangeAny {
|
||||||
|
@ -69,7 +96,7 @@ func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) {
|
||||||
|
|
||||||
func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) {
|
func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) {
|
||||||
if ipProtoPort == "" {
|
if ipProtoPort == "" {
|
||||||
return nil, errors.New("empty string")
|
return nil, errEmptyString
|
||||||
}
|
}
|
||||||
if ipProtoPort == "*" {
|
if ipProtoPort == "*" {
|
||||||
return &ProtoPortRange{Ports: PortRangeAny}, nil
|
return &ProtoPortRange{Ports: PortRangeAny}, nil
|
||||||
|
@ -82,7 +109,7 @@ func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if protoStr == "" {
|
if protoStr == "" {
|
||||||
return nil, errors.New("empty protocol")
|
return nil, errEmptyProtocol
|
||||||
}
|
}
|
||||||
|
|
||||||
ppr := &ProtoPortRange{
|
ppr := &ProtoPortRange{
|
||||||
|
|
|
@ -4,12 +4,15 @@
|
||||||
package tailcfg
|
package tailcfg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"encoding"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"tailscale.com/types/ipproto"
|
"tailscale.com/types/ipproto"
|
||||||
|
"tailscale.com/util/vizerror"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var _ encoding.TextUnmarshaler = (*ProtoPortRange)(nil)
|
||||||
|
|
||||||
func TestProtoPortRangeParsing(t *testing.T) {
|
func TestProtoPortRangeParsing(t *testing.T) {
|
||||||
pr := func(s, e uint16) PortRange {
|
pr := func(s, e uint16) PortRange {
|
||||||
return PortRange{First: s, Last: e}
|
return PortRange{First: s, Last: e}
|
||||||
|
@ -26,30 +29,28 @@ func TestProtoPortRangeParsing(t *testing.T) {
|
||||||
{in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}},
|
{in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}},
|
||||||
{
|
{
|
||||||
in: "tcp:",
|
in: "tcp:",
|
||||||
err: errors.New(`invalid port list: ""`),
|
err: vizerror.Errorf("invalid port list: %#v", ""),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
in: ":80",
|
in: ":80",
|
||||||
err: errors.New(`empty protocol`),
|
err: errEmptyProtocol,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
in: "",
|
in: "",
|
||||||
err: errors.New(`empty string`),
|
err: errEmptyString,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.in, func(t *testing.T) {
|
t.Run(tc.in, func(t *testing.T) {
|
||||||
ppr, err := parseProtoPortRange(tc.in)
|
var ppr ProtoPortRange
|
||||||
if gotErr, wantErr := err != nil, tc.err != nil; gotErr != wantErr {
|
err := ppr.UnmarshalText([]byte(tc.in))
|
||||||
t.Fatalf("got err %v; want %v", err, tc.err)
|
if tc.err != err {
|
||||||
} else if gotErr {
|
if err == nil || tc.err.Error() != err.Error() {
|
||||||
if err.Error() != tc.err.Error() {
|
t.Fatalf("want err=%v, got %v", tc.err, err)
|
||||||
t.Fatalf("got err %q; want %q", err, tc.err)
|
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if *ppr != tc.out {
|
if ppr != tc.out {
|
||||||
t.Fatalf("got %v; want %v", ppr, tc.out)
|
t.Fatalf("got %v; want %v", ppr, tc.out)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -88,3 +89,43 @@ func TestProtoPortRangeString(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProtoPortRangeRoundTrip(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input ProtoPortRange
|
||||||
|
text string
|
||||||
|
}{
|
||||||
|
{ProtoPortRange{Ports: PortRangeAny}, "*"},
|
||||||
|
{ProtoPortRange{Ports: PortRange{23, 23}}, "23"},
|
||||||
|
{ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"},
|
||||||
|
{ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"},
|
||||||
|
{ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"},
|
||||||
|
{ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"},
|
||||||
|
{ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"},
|
||||||
|
{ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"},
|
||||||
|
{ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"},
|
||||||
|
{ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"},
|
||||||
|
{ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"},
|
||||||
|
{ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"},
|
||||||
|
{ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
out, err := tc.input.MarshalText()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("MarshalText for %v: %v", tc.input, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if got := string(out); got != tc.text {
|
||||||
|
t.Errorf("MarshalText for %#v: got %q, want %q", tc.input, got, tc.text)
|
||||||
|
}
|
||||||
|
var ppr ProtoPortRange
|
||||||
|
if err := ppr.UnmarshalText(out); err != nil {
|
||||||
|
t.Errorf("UnmarshalText for %q: err=%v", tc.text, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ppr != tc.input {
|
||||||
|
t.Errorf("round trip error for %q: got %v, want %#v", tc.text, ppr, tc.input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue