From 5ec01bf3ce6c01841bfc6d17736b5b35df06d2a3 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 18 Jun 2024 13:44:12 -0700 Subject: [PATCH] wgengine/filter: support FilterRules matching on srcIP node caps [capver 100] See #12542 for background. Updates #12542 Change-Id: Ida312f700affc00d17681dc7551ee9672eeb1789 Signed-off-by: Brad Fitzpatrick --- ipn/ipnlocal/local.go | 33 +++++++- net/tstun/wrap_test.go | 2 +- tailcfg/tailcfg.go | 4 +- wgengine/filter/filter.go | 40 ++++++---- wgengine/filter/filter_test.go | 78 ++++++++++++++++--- wgengine/filter/filtertype/filtertype.go | 25 +++++- .../filter/filtertype/filtertype_clone.go | 2 + wgengine/filter/match.go | 47 +++++++++-- wgengine/filter/tailcfg.go | 37 +++++---- 9 files changed, 212 insertions(+), 56 deletions(-) diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 1e50f6264..0d4a87629 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -250,9 +250,9 @@ type LocalBackend struct { // delta node mutations as they come in (with mu held). The map values can // be given out to callers, but the map itself must not escape the LocalBackend. peers map[tailcfg.NodeID]tailcfg.NodeView - nodeByAddr map[netip.Addr]tailcfg.NodeID - nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil - activeLogin string // last logged LoginName from netMap + nodeByAddr map[netip.Addr]tailcfg.NodeID // by Node.Addresses only (not subnet routes) + nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil + activeLogin string // last logged LoginName from netMap engineStatus ipn.EngineStatus endpoints []tailcfg.Endpoint blocked bool @@ -2021,7 +2021,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P b.setFilter(filter.NewShieldsUpFilter(localNets, logNets, oldFilter, b.logf)) } else { b.logf("[v1] netmap packet filter: %v filters", len(packetFilter)) - b.setFilter(filter.New(packetFilter, localNets, logNets, oldFilter, b.logf)) + b.setFilter(filter.New(packetFilter, b.srcIPHasCapForFilter, localNets, logNets, oldFilter, b.logf)) } // The filter for a jailed node is the exact same as a ShieldsUp filter. oldJailedFilter := b.e.GetJailedFilter() @@ -6839,3 +6839,28 @@ func (b *LocalBackend) startAutoUpdate(logPrefix string) (retErr error) { }() return nil } + +// srcIPHasCapForFilter is called by the packet filter when evaluating firewall +// rules that require a source IP to have a certain node capability. +// +// TODO(bradfitz): optimize this later if/when it matters. +func (b *LocalBackend) srcIPHasCapForFilter(srcIP netip.Addr, cap tailcfg.NodeCapability) bool { + if cap == "" { + // Shouldn't happen, but just in case. + // But the empty cap also shouldn't be found in Node.CapMap. + return false + } + + b.mu.Lock() + defer b.mu.Unlock() + + nodeID, ok := b.nodeByAddr[srcIP] + if !ok { + return false + } + n, ok := b.peers[nodeID] + if !ok { + return false + } + return n.HasCap(cap) +} diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index d6287c652..fb0324989 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -168,7 +168,7 @@ func setfilter(logf logger.Logf, tun *Wrapper) { var sb netipx.IPSetBuilder sb.AddPrefix(netip.MustParsePrefix("1.2.0.0/16")) ipSet, _ := sb.IPSet() - tun.SetFilter(filter.New(matches, ipSet, ipSet, nil, logf)) + tun.SetFilter(filter.New(matches, nil, ipSet, ipSet, nil, logf)) } func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) { diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index a7eb68d37..8ab7baf3c 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -140,7 +140,8 @@ type CapabilityVersion int // - 97: 2024-06-06: Client understands NodeAttrDisableSplitDNSWhenNoCustomResolvers // - 98: 2024-06-13: iOS/tvOS clients may provide serial number as part of posture information // - 99: 2024-06-14: Client understands NodeAttrDisableLocalDNSOverrideViaNRPT -const CurrentCapabilityVersion CapabilityVersion = 99 +// - 100: 2024-06-18: Client supports filtertype.Match.SrcCaps (issue #12542) +const CurrentCapabilityVersion CapabilityVersion = 100 type StableID string @@ -1480,6 +1481,7 @@ type FilterRule struct { // * the string "*" to match everything (both IPv4 & IPv6) // * a CIDR (e.g. "192.168.0.0/16") // * a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") + // * a string "cap:" with NodeCapMap cap name SrcIPs []string // SrcBits is deprecated; it was the old way to specify a CIDR diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index d87a99393..56224ac5d 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -42,6 +42,10 @@ type Filter struct { logIPs4 func(netip.Addr) bool logIPs6 func(netip.Addr) bool + // srcIPHasCap optionally specifies a function that reports + // whether a given source IP address has a given capability. + srcIPHasCap CapTestFunc + // matches4 and matches6 are lists of match->action rules // applied to all packets arriving over tailscale // tunnels. Matches are checked in order, and processing stops @@ -157,12 +161,12 @@ func NewAllowAllForTest(logf logger.Logf) *Filter { sb.AddPrefix(any4) sb.AddPrefix(any6) ipSet, _ := sb.IPSet() - return New(ms, ipSet, ipSet, nil, logf) + return New(ms, nil, ipSet, ipSet, nil, logf) } // NewAllowNone returns a packet filter that rejects everything. func NewAllowNone(logf logger.Logf, logIPs *netipx.IPSet) *Filter { - return New(nil, &netipx.IPSet{}, logIPs, nil, logf) + return New(nil, nil, &netipx.IPSet{}, logIPs, nil, logf) } // NewShieldsUpFilter returns a packet filter that rejects incoming connections. @@ -174,17 +178,20 @@ func NewShieldsUpFilter(localNets *netipx.IPSet, logIPs *netipx.IPSet, shareStat if shareStateWith != nil && !shareStateWith.shieldsUp { shareStateWith = nil } - f := New(nil, localNets, logIPs, shareStateWith, logf) + f := New(nil, nil, localNets, logIPs, shareStateWith, logf) f.shieldsUp = true return f } -// New creates a new packet filter. The filter enforces that incoming -// packets must be destined to an IP in localNets, and must be allowed -// by matches. If shareStateWith is non-nil, the returned filter -// shares state with the previous one, to enable changing rules at -// runtime without breaking existing stateful flows. -func New(matches []Match, localNets, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter { +// New creates a new packet filter. The filter enforces that incoming packets +// must be destined to an IP in localNets, and must be allowed by matches. +// The optional capTest func is used to evaluate a Match that uses capabilities. +// If nil, such matches will always fail. +// +// If shareStateWith is non-nil, the returned filter shares state with the +// previous one, to enable changing rules at runtime without breaking existing +// stateful flows. +func New(matches []Match, capTest CapTestFunc, localNets, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter { var state *filterState if shareStateWith != nil { state = shareStateWith.state @@ -229,6 +236,7 @@ func matchesFamily(ms matches, keep func(netip.Addr) bool) matches { for _, m := range ms { var retm Match retm.IPProto = m.IPProto + retm.SrcCaps = m.SrcCaps for _, src := range m.Srcs { if keep(src.Addr()) { retm.Srcs = append(retm.Srcs, src) @@ -240,7 +248,7 @@ func matchesFamily(ms matches, keep func(netip.Addr) bool) matches { retm.Dsts = append(retm.Dsts, dst) } } - if len(retm.Srcs) > 0 && len(retm.Dsts) > 0 { + if (len(retm.Srcs) > 0 || len(retm.SrcCaps) > 0) && len(retm.Dsts) > 0 { retm.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(retm.Srcs)) ret = append(ret, retm) } @@ -462,7 +470,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { // related to an existing ICMP-Echo, TCP, or UDP // session. return Accept, "icmp response ok" - } else if f.matches4.matchIPsOnly(q) { + } else if f.matches4.matchIPsOnly(q, f.srcIPHasCap) { // If any port is open to an IP, allow ICMP to it. return Accept, "icmp ok" } @@ -478,7 +486,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { if !q.IsTCPSyn() { return Accept, "tcp non-syn" } - if f.matches4.match(q) { + if f.matches4.match(q, f.srcIPHasCap) { return Accept, "tcp ok" } case ipproto.UDP, ipproto.SCTP: @@ -491,7 +499,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { if ok { return Accept, "cached" } - if f.matches4.match(q) { + if f.matches4.match(q, f.srcIPHasCap) { return Accept, "ok" } case ipproto.TSMP: @@ -522,7 +530,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { // related to an existing ICMP-Echo, TCP, or UDP // session. return Accept, "icmp response ok" - } else if f.matches6.matchIPsOnly(q) { + } else if f.matches6.matchIPsOnly(q, f.srcIPHasCap) { // If any port is open to an IP, allow ICMP to it. return Accept, "icmp ok" } @@ -538,7 +546,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { if q.IPProto == ipproto.TCP && !q.IsTCPSyn() { return Accept, "tcp non-syn" } - if f.matches6.match(q) { + if f.matches6.match(q, f.srcIPHasCap) { return Accept, "tcp ok" } case ipproto.UDP, ipproto.SCTP: @@ -551,7 +559,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { if ok { return Accept, "cached" } - if f.matches6.match(q) { + if f.matches6.match(q, f.srcIPHasCap) { return Accept, "ok" } case ipproto.TSMP: diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index 082476371..f2796d71f 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -40,14 +40,31 @@ const ( testDeniedProto ipproto.Proto = 127 // CRUDP, appropriately cruddy ) -func m(srcs []netip.Prefix, dsts []NetPortRange, protos ...ipproto.Proto) Match { - if protos == nil { +// m returnns a Match with the given srcs and dsts. +// +// opts can be ipproto.Proto values (if none, defaultProtos is used) +// or tailcfg.NodeCapability values. Other values panic. +func m(srcs []netip.Prefix, dsts []NetPortRange, opts ...any) Match { + var protos []ipproto.Proto + var caps []tailcfg.NodeCapability + for _, o := range opts { + switch o := o.(type) { + case ipproto.Proto: + protos = append(protos, o) + case tailcfg.NodeCapability: + caps = append(caps, o) + default: + panic(fmt.Sprintf("unknown option type %T", o)) + } + } + if len(protos) == 0 { protos = defaultProtos } return Match{ IPProto: views.SliceOf(protos), Srcs: srcs, SrcsContains: ipset.NewContainsIPFunc(views.SliceOf(srcs)), + SrcCaps: caps, Dsts: dsts, } } @@ -65,6 +82,7 @@ func newFilter(logf logger.Logf) *Filter { m(nets("::/0"), netports("::/0:443")), m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto), m(nets("::/0"), netports("::/0:*"), testAllowedProto), + m(nil, netports("1.2.3.4:22"), tailcfg.NodeCapability("cap-hit-1234-ssh")), } // Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8, @@ -79,11 +97,17 @@ func newFilter(logf logger.Logf) *Filter { localNetsSet, _ := localNets.IPSet() logBSet, _ := logB.IPSet() - return New(matches, localNetsSet, logBSet, nil, logf) + return New(matches, nil, localNetsSet, logBSet, nil, logf) } func TestFilter(t *testing.T) { - acl := newFilter(t.Logf) + filt := newFilter(t.Logf) + + ipWithCap := netip.MustParseAddr("10.0.0.1") + ipWithoutCap := netip.MustParseAddr("10.0.0.2") + filt.srcIPHasCap = func(ip netip.Addr, cap tailcfg.NodeCapability) bool { + return cap == "cap-hit-1234-ssh" && ip == ipWithCap + } type InOut struct { want Response @@ -139,21 +163,27 @@ func TestFilter(t *testing.T) { {Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)}, {Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)}, {Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)}, + + // Test use of a node capability to grant access. + // 10.0.0.1 has the capability; 10.0.0.2 does not (see srcIPHasCap at top of func) + {Accept, parsed(ipproto.TCP, ipWithCap.String(), "1.2.3.4", 30000, 22)}, + {Drop, parsed(ipproto.TCP, ipWithoutCap.String(), "1.2.3.4", 30000, 22)}, } for i, test := range tests { - aclFunc := acl.runIn4 + aclFunc := filt.runIn4 if test.p.IPVersion == 6 { - aclFunc = acl.runIn6 + aclFunc = filt.runIn6 } if got, why := aclFunc(&test.p); test.want != got { t.Errorf("#%d runIn got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p) + continue } if test.p.IPProto == ipproto.TCP { var got Response if test.p.IPVersion == 4 { - got = acl.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port()) + got = filt.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port()) } else { - got = acl.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port()) + got = filt.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port()) } if test.want != got { t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p) @@ -165,7 +195,7 @@ func TestFilter(t *testing.T) { } } // Update UDP state - _, _ = acl.runOut(&test.p) + _, _ = filt.runOut(&test.p) } } @@ -264,13 +294,16 @@ func TestParseIPSet(t *testing.T) { {"*", pfx("0.0.0.0/0", "::/0"), ""}, } for _, tt := range tests { - got, err := parseIPSet(tt.host) + got, gotCap, err := parseIPSet(tt.host) if err != nil { if err.Error() == tt.wantErr { continue } t.Errorf("parseIPSet(%q) error: %v; want error %q", tt.host, err, tt.wantErr) } + if gotCap != "" { + t.Errorf("parseIPSet(%q) cap: %q; want empty", tt.host, gotCap) + } compareIP := cmp.Comparer(func(a, b netip.Addr) bool { return a == b }) compareIPPrefix := cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }) if diff := cmp.Diff(got, tt.want, compareIP, compareIPPrefix); diff != "" { @@ -278,6 +311,27 @@ func TestParseIPSet(t *testing.T) { continue } } + + capTests := []struct { + in string + want tailcfg.NodeCapability + }{ + {"cap:foo", "foo"}, + {"cap:people-in-8.8.8.0/24", "people-in-8.8.8.0/24"}, // test precedence of "/" search + } + for _, tt := range capTests { + pfxes, gotCap, err := parseIPSet(tt.in) + if err != nil { + t.Errorf("parseIPSet(%q) error: %v; want no error", tt.in, err) + continue + } + if gotCap != tt.want { + t.Errorf("parseIPSet(%q) cap: %q; want %q", tt.in, gotCap, tt.want) + } + if len(pfxes) != 0 { + t.Errorf("parseIPSet(%q) pfxes: %v; want empty", tt.in, pfxes) + } + } } func BenchmarkFilter(b *testing.B) { @@ -904,7 +958,7 @@ func TestPeerCaps(t *testing.T) { if err != nil { t.Fatal(err) } - filt := New(mm, nil, nil, nil, t.Logf) + filt := New(mm, nil, nil, nil, nil, t.Logf) tests := []struct { name string src, dst string // IP @@ -1037,7 +1091,7 @@ func benchmarkFile(b *testing.B, file string, opt benchOpt) { logIPs.AddPrefix(tsaddr.CGNATRange()) logIPs.AddPrefix(tsaddr.TailscaleULARange()) - f := New(matches, must.Get(localNets.IPSet()), must.Get(logIPs.IPSet()), nil, logger.Discard) + f := New(matches, nil, must.Get(localNets.IPSet()), must.Get(logIPs.IPSet()), nil, logger.Discard) var srcIP, dstIP netip.Addr if opt.v4 { srcIP = netip.MustParseAddr("1.2.3.4") diff --git a/wgengine/filter/filtertype/filtertype.go b/wgengine/filter/filtertype/filtertype.go index 689a45e7c..212eda43f 100644 --- a/wgengine/filter/filtertype/filtertype.go +++ b/wgengine/filter/filtertype/filtertype.go @@ -66,11 +66,28 @@ type CapMatch struct { // Match matches packets from any IP address in Srcs to any ip:port in // Dsts. type Match struct { - IPProto views.Slice[ipproto.Proto] // required set (no default value at this layer) - Srcs []netip.Prefix + // IPProto is the set of IP protocol numbers for which this match applies. + // It is required. There is no default value at this layer. + // If empty, it doesn't match. + IPProto views.Slice[ipproto.Proto] + + // Srcs is the set of source IP prefixes for which this match applies. A + // Match can match by either its source IP address being in Srcs (which + // SrcsContains tests) or if the source IP is of a known peer self address + // that contains a NodeCapability listed in SrcCaps. + Srcs []netip.Prefix + // SrcsContains is an optimized function that reports whether Addr is in + // Srcs, using the best search method for the size and shape of Srcs. SrcsContains func(netip.Addr) bool `json:"-"` // report whether Addr is in Srcs - Dsts []NetPortRange // optional, if Srcs match - Caps []CapMatch // optional, if Srcs match + + // SrcCaps is an alternative way to match packets. If the peer's source IP + // has one of these capabilities, it's also permitted. The peers are only + // looked up by their self address (Node.Addresses) and not by subnet routes + // they advertise. + SrcCaps []tailcfg.NodeCapability + + Dsts []NetPortRange // optional, if source matches + Caps []CapMatch // optional, if source match } func (m Match) String() string { diff --git a/wgengine/filter/filtertype/filtertype_clone.go b/wgengine/filter/filtertype/filtertype_clone.go index 122f1bbe7..63709188e 100644 --- a/wgengine/filter/filtertype/filtertype_clone.go +++ b/wgengine/filter/filtertype/filtertype_clone.go @@ -23,6 +23,7 @@ func (src *Match) Clone() *Match { *dst = *src dst.IPProto = src.IPProto dst.Srcs = append(src.Srcs[:0:0], src.Srcs...) + dst.SrcCaps = append(src.SrcCaps[:0:0], src.SrcCaps...) dst.Dsts = append(src.Dsts[:0:0], src.Dsts...) if src.Caps != nil { dst.Caps = make([]CapMatch, len(src.Caps)) @@ -38,6 +39,7 @@ var _MatchCloneNeedsRegeneration = Match(struct { IPProto views.Slice[ipproto.Proto] Srcs []netip.Prefix SrcsContains func(netip.Addr) bool + SrcCaps []tailcfg.NodeCapability Dsts []NetPortRange Caps []CapMatch }{}) diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go index 4d93979ea..6292c4971 100644 --- a/wgengine/filter/match.go +++ b/wgengine/filter/match.go @@ -4,19 +4,23 @@ package filter import ( + "net/netip" + "tailscale.com/net/packet" + "tailscale.com/tailcfg" "tailscale.com/types/views" "tailscale.com/wgengine/filter/filtertype" ) type matches []filtertype.Match -func (ms matches) match(q *packet.Parsed) bool { - for _, m := range ms { +func (ms matches) match(q *packet.Parsed, hasCap CapTestFunc) bool { + for i := range ms { + m := &ms[i] if !views.SliceContains(m.IPProto, q.IPProto) { continue } - if !m.SrcsContains(q.Src.Addr()) { + if !srcMatches(m, q.Src.Addr(), hasCap) { continue } for _, dst := range m.Dsts { @@ -32,9 +36,33 @@ func (ms matches) match(q *packet.Parsed) bool { return false } -func (ms matches) matchIPsOnly(q *packet.Parsed) bool { +// srcMatches reports whether srcAddr matche the src requirements in m, either +// by Srcs (using SrcsContains), or by the node having a capability listed +// in SrcCaps using the provided hasCap function. +func srcMatches(m *filtertype.Match, srcAddr netip.Addr, hasCap CapTestFunc) bool { + if m.SrcsContains(srcAddr) { + return true + } + if hasCap != nil { + for _, c := range m.SrcCaps { + if hasCap(srcAddr, c) { + return true + } + } + } + return false +} + +// CapTestFunc is the function signature of a function that tests whether srcIP +// has a given capability. +// +// It it used in the fast path of evaluating filter rules so should be fast. +type CapTestFunc = func(srcIP netip.Addr, cap tailcfg.NodeCapability) bool + +func (ms matches) matchIPsOnly(q *packet.Parsed, hasCap CapTestFunc) bool { + srcAddr := q.Src.Addr() for _, m := range ms { - if !m.SrcsContains(q.Src.Addr()) { + if !m.SrcsContains(srcAddr) { continue } for _, dst := range m.Dsts { @@ -43,6 +71,15 @@ func (ms matches) matchIPsOnly(q *packet.Parsed) bool { } } } + if hasCap != nil { + for _, m := range ms { + for _, c := range m.SrcCaps { + if hasCap(srcAddr, c) { + return true + } + } + } + } return false } diff --git a/wgengine/filter/tailcfg.go b/wgengine/filter/tailcfg.go index ab77ea315..ff81077f7 100644 --- a/wgengine/filter/tailcfg.go +++ b/wgengine/filter/tailcfg.go @@ -58,12 +58,15 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) { } for _, s := range r.SrcIPs { - nets, err := parseIPSet(s) + nets, cap, err := parseIPSet(s) if err != nil && erracc == nil { erracc = err continue } m.Srcs = append(m.Srcs, nets...) + if cap != "" { + m.SrcCaps = append(m.SrcCaps, cap) + } } m.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(m.Srcs)) @@ -71,11 +74,15 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) { if d.Bits != nil { return nil, fmt.Errorf("unexpected DstBits; control plane should not send this to this client version") } - nets, err := parseIPSet(d.IP) + nets, cap, err := parseIPSet(d.IP) if err != nil && erracc == nil { erracc = err continue } + if cap != "" { + erracc = fmt.Errorf("unexpected capability %q in DstPorts", cap) + continue + } for _, net := range nets { m.Dsts = append(m.Dsts, NetPortRange{ Net: net, @@ -120,48 +127,52 @@ var ( // - the string "*" to match everything (both IPv4 & IPv6) // - a CIDR (e.g. "192.168.0.0/16") // - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800") +// - "cap:" to match a peer node capability // // TODO(bradfitz): make this return an IPSet and plumb that all // around, and ultimately use a new version of IPSet.ContainsFunc like // Contains16Func that works in [16]byte address, so we we can match // at runtime without allocating? -func parseIPSet(arg string) ([]netip.Prefix, error) { +func parseIPSet(arg string) (prefixes []netip.Prefix, peerCap tailcfg.NodeCapability, err error) { if arg == "*" { // User explicitly requested wildcard. return []netip.Prefix{ netip.PrefixFrom(zeroIP4, 0), netip.PrefixFrom(zeroIP6, 0), - }, nil + }, "", nil + } + if cap, ok := strings.CutPrefix(arg, "cap:"); ok { + return nil, tailcfg.NodeCapability(cap), nil } if strings.Contains(arg, "/") { pfx, err := netip.ParsePrefix(arg) if err != nil { - return nil, err + return nil, "", err } if pfx != pfx.Masked() { - return nil, fmt.Errorf("%v contains non-network bits set", pfx) + return nil, "", fmt.Errorf("%v contains non-network bits set", pfx) } - return []netip.Prefix{pfx}, nil + return []netip.Prefix{pfx}, "", nil } if strings.Count(arg, "-") == 1 { ip1s, ip2s, _ := strings.Cut(arg, "-") ip1, err := netip.ParseAddr(ip1s) if err != nil { - return nil, err + return nil, "", err } ip2, err := netip.ParseAddr(ip2s) if err != nil { - return nil, err + return nil, "", err } r := netipx.IPRangeFrom(ip1, ip2) if !r.IsValid() { - return nil, fmt.Errorf("invalid IP range %q", arg) + return nil, "", fmt.Errorf("invalid IP range %q", arg) } - return r.Prefixes(), nil + return r.Prefixes(), "", nil } ip, err := netip.ParseAddr(arg) if err != nil { - return nil, fmt.Errorf("invalid IP address %q", arg) + return nil, "", fmt.Errorf("invalid IP address %q", arg) } - return []netip.Prefix{netip.PrefixFrom(ip, ip.BitLen())}, nil + return []netip.Prefix{netip.PrefixFrom(ip, ip.BitLen())}, "", nil }