diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 70342f0a2..7d615867c 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -35,7 +35,8 @@ import ( // 10: 2021-01-17: client understands MapResponse.PeerSeenChange // 11: 2021-03-03: client understands IPv6, multiple default routes, and goroutine dumping // 12: 2021-03-04: client understands PingRequest -const CurrentMapRequestVersion = 12 +// 13: 2021-03-19: client understands FilterRule.IPProto +const CurrentMapRequestVersion = 13 type StableID string @@ -693,6 +694,17 @@ type FilterRule struct { // DstPorts are the port ranges to allow once a source IP // matches (is in the CIDR described by SrcIPs & SrcBits). DstPorts []NetPortRange + + // IPProto are the IP protocol numbers to match. + // + // As a special case, nil or empty means TCP, UDP, and ICMP. + // + // Numbers outside the uint8 range (below 0 or above 255) are + // reserved for Tailscale's use. Unknown ones are ignored. + // + // Depending on the IPProto values, DstPorts may or may not be + // used. + IPProto []int `json:",omitempty"` } var FilterAllowAll = []FilterRule{ diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index cbb985114..ee11bd44d 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -182,6 +182,7 @@ func matchesFamily(ms matches, keep func(netaddr.IP) bool) matches { var ret matches for _, m := range ms { var retm Match + retm.IPProto = m.IPProto for _, src := range m.Srcs { if keep(src.IP) { retm.Srcs = append(retm.Srcs, src) diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index 9f98761f6..3b0748a8c 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -7,6 +7,7 @@ package filter import ( "encoding/hex" "fmt" + "reflect" "strconv" "strings" "testing" @@ -16,19 +17,27 @@ import ( "inet.af/netaddr" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" "tailscale.com/types/logger" ) func newFilter(logf logger.Logf) *Filter { + m := func(srcs []netaddr.IPPrefix, dsts []NetPortRange) Match { + return Match{ + IPProto: defaultProtos, + Srcs: srcs, + Dsts: dsts, + } + } matches := []Match{ - {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("1.2.3.4:22", "5.6.7.8:23-24")}, - {Srcs: nets("8.1.1.1", "8.2.2.2"), Dsts: netports("5.6.7.8:27-28")}, - {Srcs: nets("2.2.2.2"), Dsts: netports("8.1.1.1:22")}, - {Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")}, - {Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")}, - {Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")}, - {Srcs: nets("::1", "::2"), Dsts: netports("2001::1:22", "2001::2:22")}, - {Srcs: nets("::/0"), Dsts: netports("::/0:443")}, + m(nets("8.1.1.1", "8.2.2.2"), netports("1.2.3.4:22", "5.6.7.8:23-24")), + m(nets("8.1.1.1", "8.2.2.2"), netports("5.6.7.8:27-28")), + m(nets("2.2.2.2"), netports("8.1.1.1:22")), + m(nets("0.0.0.0/0"), netports("100.122.98.50:*")), + m(nets("0.0.0.0/0"), netports("0.0.0.0/0:443")), + m(nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), netports("1.2.3.4:999")), + m(nets("::1", "::2"), netports("2001::1:22", "2001::2:22")), + m(nets("::/0"), netports("::/0:443")), } // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, @@ -89,6 +98,9 @@ func TestFilter(t *testing.T) { // unexpected dst IP. {Drop, parsed(packet.TCP, "8.1.1.1", "16.32.48.64", 0, 443)}, {Drop, parsed(packet.TCP, "1::", "2602::1", 0, 443)}, + + // Don't allow protocols not specified by filter + {Drop, parsed(132 /* SCTP */, "8.1.1.1", "1.2.3.4", 999, 22)}, } for i, test := range tests { aclFunc := acl.runIn4 @@ -707,3 +719,91 @@ func netports(netPorts ...string) (ret []NetPortRange) { } return ret } + +func TestMatchesFromFilterRules(t *testing.T) { + tests := []struct { + name string + in []tailcfg.FilterRule + want []Match + }{ + { + name: "empty", + want: []Match{}, + }, + { + name: "implicit_protos", + in: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.1.1"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "*", + Ports: tailcfg.PortRange{First: 22, Last: 22}, + }}, + }, + }, + want: []Match{ + { + IPProto: []packet.IPProto{ + packet.TCP, + packet.UDP, + packet.ICMPv4, + packet.ICMPv6, + }, + Dsts: []NetPortRange{ + { + Net: netaddr.MustParseIPPrefix("0.0.0.0/0"), + Ports: PortRange{22, 22}, + }, + { + Net: netaddr.MustParseIPPrefix("::0/0"), + Ports: PortRange{22, 22}, + }, + }, + Srcs: []netaddr.IPPrefix{ + netaddr.MustParseIPPrefix("100.64.1.1/32"), + }, + }, + }, + }, + { + name: "explicit_protos", + in: []tailcfg.FilterRule{ + { + IPProto: []int{int(packet.TCP)}, + SrcIPs: []string{"100.64.1.1"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "1.2.0.0/16", + Ports: tailcfg.PortRange{First: 22, Last: 22}, + }}, + }, + }, + want: []Match{ + { + IPProto: []packet.IPProto{ + packet.TCP, + }, + Dsts: []NetPortRange{ + { + Net: netaddr.MustParseIPPrefix("1.2.0.0/16"), + Ports: PortRange{22, 22}, + }, + }, + Srcs: []netaddr.IPPrefix{ + netaddr.MustParseIPPrefix("100.64.1.1/32"), + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := MatchesFromFilterRules(tt.in) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("wrong\n got: %v\nwant: %v\n", got, tt.want) + } + }) + } +} diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index c30c37552..799614104 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -47,11 +47,13 @@ func (npr NetPortRange) String() string { // Match matches packets from any IP address in Srcs to any ip:port in // Dsts. type Match struct { - Dsts []NetPortRange - Srcs []netaddr.IPPrefix + IPProto []packet.IPProto // required set (no default value at this layer) + Dsts []NetPortRange + Srcs []netaddr.IPPrefix } func (m Match) String() string { + // TODO(bradfitz): use strings.Builder, add String tests srcs := []string{} for _, src := range m.Srcs { srcs = append(srcs, src.String()) @@ -72,13 +74,16 @@ func (m Match) String() string { } else { ds = "[" + strings.Join(dsts, ",") + "]" } - return fmt.Sprintf("%v=>%v", ss, ds) + return fmt.Sprintf("%v%v=>%v", m.IPProto, ss, ds) } type matches []Match func (ms matches) match(q *packet.Parsed) bool { for _, m := range ms { + if !protoInList(q.IPProto, m.IPProto) { + continue + } if !ipInList(q.Src.IP, m.Srcs) { continue } @@ -117,3 +122,12 @@ func ipInList(ip netaddr.IP, netlist []netaddr.IPPrefix) bool { } return false } + +func protoInList(proto packet.IPProto, valid []packet.IPProto) bool { + for _, v := range valid { + if proto == v { + return true + } + } + return false +} diff --git a/wgengine/filter/match_clone.go b/wgengine/filter/match_clone.go index 571664bd5..b031e1e3a 100644 --- a/wgengine/filter/match_clone.go +++ b/wgengine/filter/match_clone.go @@ -8,6 +8,7 @@ package filter import ( "inet.af/netaddr" + "tailscale.com/net/packet" ) // Clone makes a deep copy of Match. @@ -18,6 +19,7 @@ func (src *Match) Clone() *Match { } dst := new(Match) *dst = *src + dst.IPProto = append(src.IPProto[:0:0], src.IPProto...) dst.Dsts = append(src.Dsts[:0:0], src.Dsts...) dst.Srcs = append(src.Srcs[:0:0], src.Srcs...) return dst @@ -26,6 +28,7 @@ func (src *Match) Clone() *Match { // A compilation failure here means this code must be regenerated, with command: // tailscale.com/cmd/cloner -type Match var _MatchNeedsRegeneration = Match(struct { - Dsts []NetPortRange - Srcs []netaddr.IPPrefix + IPProto []packet.IPProto + Dsts []NetPortRange + Srcs []netaddr.IPPrefix }{}) diff --git a/wgengine/filter/tailcfg.go b/wgengine/filter/tailcfg.go index 2f20cdb61..26b7ed3da 100644 --- a/wgengine/filter/tailcfg.go +++ b/wgengine/filter/tailcfg.go @@ -9,9 +9,17 @@ import ( "strings" "inet.af/netaddr" + "tailscale.com/net/packet" "tailscale.com/tailcfg" ) +var defaultProtos = []packet.IPProto{ + packet.TCP, + packet.UDP, + packet.ICMPv4, + packet.ICMPv6, +} + // MatchesFromFilterRules converts tailcfg FilterRules into Matches. // If an error is returned, the Matches result is still valid, // containing the rules that were successfully converted. @@ -22,6 +30,17 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) { for _, r := range pf { m := Match{} + if len(r.IPProto) == 0 { + m.IPProto = append([]packet.IPProto(nil), defaultProtos...) + } else { + m.IPProto = make([]packet.IPProto, 0, len(r.IPProto)) + for _, n := range r.IPProto { + if n >= 0 && n <= 0xff { + m.IPProto = append(m.IPProto, packet.IPProto(n)) + } + } + } + for i, s := range r.SrcIPs { var bits *int if len(r.SrcBits) > i { diff --git a/wgengine/tstun/tun_test.go b/wgengine/tstun/tun_test.go index 365d56b4d..555324d19 100644 --- a/wgengine/tstun/tun_test.go +++ b/wgengine/tstun/tun_test.go @@ -106,9 +106,13 @@ func netports(netPorts ...string) (ret []filter.NetPortRange) { } func setfilter(logf logger.Logf, tun *TUN) { + protos := []packet.IPProto{ + packet.TCP, + packet.UDP, + } matches := []filter.Match{ - {Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")}, - {Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")}, + {IPProto: protos, Srcs: nets("5.6.7.8"), Dsts: netports("1.2.3.4:89-90")}, + {IPProto: protos, Srcs: nets("1.2.3.4"), Dsts: netports("5.6.7.8:98")}, } var sb netaddr.IPSetBuilder sb.AddPrefix(netaddr.MustParseIPPrefix("1.2.0.0/16"))