wgengine/filter: support FilterRules matching on srcIP node caps [capver 100]
See #12542 for background. Updates #12542 Change-Id: Ida312f700affc00d17681dc7551ee9672eeb1789 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
07063bc5c7
commit
5ec01bf3ce
|
@ -250,9 +250,9 @@ type LocalBackend struct {
|
|||
// delta node mutations as they come in (with mu held). The map values can
|
||||
// be given out to callers, but the map itself must not escape the LocalBackend.
|
||||
peers map[tailcfg.NodeID]tailcfg.NodeView
|
||||
nodeByAddr map[netip.Addr]tailcfg.NodeID
|
||||
nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil
|
||||
activeLogin string // last logged LoginName from netMap
|
||||
nodeByAddr map[netip.Addr]tailcfg.NodeID // by Node.Addresses only (not subnet routes)
|
||||
nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil
|
||||
activeLogin string // last logged LoginName from netMap
|
||||
engineStatus ipn.EngineStatus
|
||||
endpoints []tailcfg.Endpoint
|
||||
blocked bool
|
||||
|
@ -2021,7 +2021,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P
|
|||
b.setFilter(filter.NewShieldsUpFilter(localNets, logNets, oldFilter, b.logf))
|
||||
} else {
|
||||
b.logf("[v1] netmap packet filter: %v filters", len(packetFilter))
|
||||
b.setFilter(filter.New(packetFilter, localNets, logNets, oldFilter, b.logf))
|
||||
b.setFilter(filter.New(packetFilter, b.srcIPHasCapForFilter, localNets, logNets, oldFilter, b.logf))
|
||||
}
|
||||
// The filter for a jailed node is the exact same as a ShieldsUp filter.
|
||||
oldJailedFilter := b.e.GetJailedFilter()
|
||||
|
@ -6839,3 +6839,28 @@ func (b *LocalBackend) startAutoUpdate(logPrefix string) (retErr error) {
|
|||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// srcIPHasCapForFilter is called by the packet filter when evaluating firewall
|
||||
// rules that require a source IP to have a certain node capability.
|
||||
//
|
||||
// TODO(bradfitz): optimize this later if/when it matters.
|
||||
func (b *LocalBackend) srcIPHasCapForFilter(srcIP netip.Addr, cap tailcfg.NodeCapability) bool {
|
||||
if cap == "" {
|
||||
// Shouldn't happen, but just in case.
|
||||
// But the empty cap also shouldn't be found in Node.CapMap.
|
||||
return false
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
nodeID, ok := b.nodeByAddr[srcIP]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
n, ok := b.peers[nodeID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return n.HasCap(cap)
|
||||
}
|
||||
|
|
|
@ -168,7 +168,7 @@ func setfilter(logf logger.Logf, tun *Wrapper) {
|
|||
var sb netipx.IPSetBuilder
|
||||
sb.AddPrefix(netip.MustParsePrefix("1.2.0.0/16"))
|
||||
ipSet, _ := sb.IPSet()
|
||||
tun.SetFilter(filter.New(matches, ipSet, ipSet, nil, logf))
|
||||
tun.SetFilter(filter.New(matches, nil, ipSet, ipSet, nil, logf))
|
||||
}
|
||||
|
||||
func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) {
|
||||
|
|
|
@ -140,7 +140,8 @@ type CapabilityVersion int
|
|||
// - 97: 2024-06-06: Client understands NodeAttrDisableSplitDNSWhenNoCustomResolvers
|
||||
// - 98: 2024-06-13: iOS/tvOS clients may provide serial number as part of posture information
|
||||
// - 99: 2024-06-14: Client understands NodeAttrDisableLocalDNSOverrideViaNRPT
|
||||
const CurrentCapabilityVersion CapabilityVersion = 99
|
||||
// - 100: 2024-06-18: Client supports filtertype.Match.SrcCaps (issue #12542)
|
||||
const CurrentCapabilityVersion CapabilityVersion = 100
|
||||
|
||||
type StableID string
|
||||
|
||||
|
@ -1480,6 +1481,7 @@ type FilterRule struct {
|
|||
// * the string "*" to match everything (both IPv4 & IPv6)
|
||||
// * a CIDR (e.g. "192.168.0.0/16")
|
||||
// * a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
|
||||
// * a string "cap:<capability>" with NodeCapMap cap name
|
||||
SrcIPs []string
|
||||
|
||||
// SrcBits is deprecated; it was the old way to specify a CIDR
|
||||
|
|
|
@ -42,6 +42,10 @@ type Filter struct {
|
|||
logIPs4 func(netip.Addr) bool
|
||||
logIPs6 func(netip.Addr) bool
|
||||
|
||||
// srcIPHasCap optionally specifies a function that reports
|
||||
// whether a given source IP address has a given capability.
|
||||
srcIPHasCap CapTestFunc
|
||||
|
||||
// matches4 and matches6 are lists of match->action rules
|
||||
// applied to all packets arriving over tailscale
|
||||
// tunnels. Matches are checked in order, and processing stops
|
||||
|
@ -157,12 +161,12 @@ func NewAllowAllForTest(logf logger.Logf) *Filter {
|
|||
sb.AddPrefix(any4)
|
||||
sb.AddPrefix(any6)
|
||||
ipSet, _ := sb.IPSet()
|
||||
return New(ms, ipSet, ipSet, nil, logf)
|
||||
return New(ms, nil, ipSet, ipSet, nil, logf)
|
||||
}
|
||||
|
||||
// NewAllowNone returns a packet filter that rejects everything.
|
||||
func NewAllowNone(logf logger.Logf, logIPs *netipx.IPSet) *Filter {
|
||||
return New(nil, &netipx.IPSet{}, logIPs, nil, logf)
|
||||
return New(nil, nil, &netipx.IPSet{}, logIPs, nil, logf)
|
||||
}
|
||||
|
||||
// NewShieldsUpFilter returns a packet filter that rejects incoming connections.
|
||||
|
@ -174,17 +178,20 @@ func NewShieldsUpFilter(localNets *netipx.IPSet, logIPs *netipx.IPSet, shareStat
|
|||
if shareStateWith != nil && !shareStateWith.shieldsUp {
|
||||
shareStateWith = nil
|
||||
}
|
||||
f := New(nil, localNets, logIPs, shareStateWith, logf)
|
||||
f := New(nil, nil, localNets, logIPs, shareStateWith, logf)
|
||||
f.shieldsUp = true
|
||||
return f
|
||||
}
|
||||
|
||||
// New creates a new packet filter. The filter enforces that incoming
|
||||
// packets must be destined to an IP in localNets, and must be allowed
|
||||
// by matches. If shareStateWith is non-nil, the returned filter
|
||||
// shares state with the previous one, to enable changing rules at
|
||||
// runtime without breaking existing stateful flows.
|
||||
func New(matches []Match, localNets, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter {
|
||||
// New creates a new packet filter. The filter enforces that incoming packets
|
||||
// must be destined to an IP in localNets, and must be allowed by matches.
|
||||
// The optional capTest func is used to evaluate a Match that uses capabilities.
|
||||
// If nil, such matches will always fail.
|
||||
//
|
||||
// If shareStateWith is non-nil, the returned filter shares state with the
|
||||
// previous one, to enable changing rules at runtime without breaking existing
|
||||
// stateful flows.
|
||||
func New(matches []Match, capTest CapTestFunc, localNets, logIPs *netipx.IPSet, shareStateWith *Filter, logf logger.Logf) *Filter {
|
||||
var state *filterState
|
||||
if shareStateWith != nil {
|
||||
state = shareStateWith.state
|
||||
|
@ -229,6 +236,7 @@ func matchesFamily(ms matches, keep func(netip.Addr) bool) matches {
|
|||
for _, m := range ms {
|
||||
var retm Match
|
||||
retm.IPProto = m.IPProto
|
||||
retm.SrcCaps = m.SrcCaps
|
||||
for _, src := range m.Srcs {
|
||||
if keep(src.Addr()) {
|
||||
retm.Srcs = append(retm.Srcs, src)
|
||||
|
@ -240,7 +248,7 @@ func matchesFamily(ms matches, keep func(netip.Addr) bool) matches {
|
|||
retm.Dsts = append(retm.Dsts, dst)
|
||||
}
|
||||
}
|
||||
if len(retm.Srcs) > 0 && len(retm.Dsts) > 0 {
|
||||
if (len(retm.Srcs) > 0 || len(retm.SrcCaps) > 0) && len(retm.Dsts) > 0 {
|
||||
retm.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(retm.Srcs))
|
||||
ret = append(ret, retm)
|
||||
}
|
||||
|
@ -462,7 +470,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
|
|||
// related to an existing ICMP-Echo, TCP, or UDP
|
||||
// session.
|
||||
return Accept, "icmp response ok"
|
||||
} else if f.matches4.matchIPsOnly(q) {
|
||||
} else if f.matches4.matchIPsOnly(q, f.srcIPHasCap) {
|
||||
// If any port is open to an IP, allow ICMP to it.
|
||||
return Accept, "icmp ok"
|
||||
}
|
||||
|
@ -478,7 +486,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
|
|||
if !q.IsTCPSyn() {
|
||||
return Accept, "tcp non-syn"
|
||||
}
|
||||
if f.matches4.match(q) {
|
||||
if f.matches4.match(q, f.srcIPHasCap) {
|
||||
return Accept, "tcp ok"
|
||||
}
|
||||
case ipproto.UDP, ipproto.SCTP:
|
||||
|
@ -491,7 +499,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) {
|
|||
if ok {
|
||||
return Accept, "cached"
|
||||
}
|
||||
if f.matches4.match(q) {
|
||||
if f.matches4.match(q, f.srcIPHasCap) {
|
||||
return Accept, "ok"
|
||||
}
|
||||
case ipproto.TSMP:
|
||||
|
@ -522,7 +530,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
|
|||
// related to an existing ICMP-Echo, TCP, or UDP
|
||||
// session.
|
||||
return Accept, "icmp response ok"
|
||||
} else if f.matches6.matchIPsOnly(q) {
|
||||
} else if f.matches6.matchIPsOnly(q, f.srcIPHasCap) {
|
||||
// If any port is open to an IP, allow ICMP to it.
|
||||
return Accept, "icmp ok"
|
||||
}
|
||||
|
@ -538,7 +546,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
|
|||
if q.IPProto == ipproto.TCP && !q.IsTCPSyn() {
|
||||
return Accept, "tcp non-syn"
|
||||
}
|
||||
if f.matches6.match(q) {
|
||||
if f.matches6.match(q, f.srcIPHasCap) {
|
||||
return Accept, "tcp ok"
|
||||
}
|
||||
case ipproto.UDP, ipproto.SCTP:
|
||||
|
@ -551,7 +559,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) {
|
|||
if ok {
|
||||
return Accept, "cached"
|
||||
}
|
||||
if f.matches6.match(q) {
|
||||
if f.matches6.match(q, f.srcIPHasCap) {
|
||||
return Accept, "ok"
|
||||
}
|
||||
case ipproto.TSMP:
|
||||
|
|
|
@ -40,14 +40,31 @@ const (
|
|||
testDeniedProto ipproto.Proto = 127 // CRUDP, appropriately cruddy
|
||||
)
|
||||
|
||||
func m(srcs []netip.Prefix, dsts []NetPortRange, protos ...ipproto.Proto) Match {
|
||||
if protos == nil {
|
||||
// m returnns a Match with the given srcs and dsts.
|
||||
//
|
||||
// opts can be ipproto.Proto values (if none, defaultProtos is used)
|
||||
// or tailcfg.NodeCapability values. Other values panic.
|
||||
func m(srcs []netip.Prefix, dsts []NetPortRange, opts ...any) Match {
|
||||
var protos []ipproto.Proto
|
||||
var caps []tailcfg.NodeCapability
|
||||
for _, o := range opts {
|
||||
switch o := o.(type) {
|
||||
case ipproto.Proto:
|
||||
protos = append(protos, o)
|
||||
case tailcfg.NodeCapability:
|
||||
caps = append(caps, o)
|
||||
default:
|
||||
panic(fmt.Sprintf("unknown option type %T", o))
|
||||
}
|
||||
}
|
||||
if len(protos) == 0 {
|
||||
protos = defaultProtos
|
||||
}
|
||||
return Match{
|
||||
IPProto: views.SliceOf(protos),
|
||||
Srcs: srcs,
|
||||
SrcsContains: ipset.NewContainsIPFunc(views.SliceOf(srcs)),
|
||||
SrcCaps: caps,
|
||||
Dsts: dsts,
|
||||
}
|
||||
}
|
||||
|
@ -65,6 +82,7 @@ func newFilter(logf logger.Logf) *Filter {
|
|||
m(nets("::/0"), netports("::/0:443")),
|
||||
m(nets("0.0.0.0/0"), netports("0.0.0.0/0:*"), testAllowedProto),
|
||||
m(nets("::/0"), netports("::/0:*"), testAllowedProto),
|
||||
m(nil, netports("1.2.3.4:22"), tailcfg.NodeCapability("cap-hit-1234-ssh")),
|
||||
}
|
||||
|
||||
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
|
||||
|
@ -79,11 +97,17 @@ func newFilter(logf logger.Logf) *Filter {
|
|||
localNetsSet, _ := localNets.IPSet()
|
||||
logBSet, _ := logB.IPSet()
|
||||
|
||||
return New(matches, localNetsSet, logBSet, nil, logf)
|
||||
return New(matches, nil, localNetsSet, logBSet, nil, logf)
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
acl := newFilter(t.Logf)
|
||||
filt := newFilter(t.Logf)
|
||||
|
||||
ipWithCap := netip.MustParseAddr("10.0.0.1")
|
||||
ipWithoutCap := netip.MustParseAddr("10.0.0.2")
|
||||
filt.srcIPHasCap = func(ip netip.Addr, cap tailcfg.NodeCapability) bool {
|
||||
return cap == "cap-hit-1234-ssh" && ip == ipWithCap
|
||||
}
|
||||
|
||||
type InOut struct {
|
||||
want Response
|
||||
|
@ -139,21 +163,27 @@ func TestFilter(t *testing.T) {
|
|||
{Accept, parsed(testAllowedProto, "2001::1", "2001::2", 0, 0)},
|
||||
{Drop, parsed(testDeniedProto, "1.2.3.4", "5.6.7.8", 0, 0)},
|
||||
{Drop, parsed(testDeniedProto, "2001::1", "2001::2", 0, 0)},
|
||||
|
||||
// Test use of a node capability to grant access.
|
||||
// 10.0.0.1 has the capability; 10.0.0.2 does not (see srcIPHasCap at top of func)
|
||||
{Accept, parsed(ipproto.TCP, ipWithCap.String(), "1.2.3.4", 30000, 22)},
|
||||
{Drop, parsed(ipproto.TCP, ipWithoutCap.String(), "1.2.3.4", 30000, 22)},
|
||||
}
|
||||
for i, test := range tests {
|
||||
aclFunc := acl.runIn4
|
||||
aclFunc := filt.runIn4
|
||||
if test.p.IPVersion == 6 {
|
||||
aclFunc = acl.runIn6
|
||||
aclFunc = filt.runIn6
|
||||
}
|
||||
if got, why := aclFunc(&test.p); test.want != got {
|
||||
t.Errorf("#%d runIn got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
|
||||
continue
|
||||
}
|
||||
if test.p.IPProto == ipproto.TCP {
|
||||
var got Response
|
||||
if test.p.IPVersion == 4 {
|
||||
got = acl.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
||||
got = filt.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
||||
} else {
|
||||
got = acl.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
||||
got = filt.CheckTCP(test.p.Src.Addr(), test.p.Dst.Addr(), test.p.Dst.Port())
|
||||
}
|
||||
if test.want != got {
|
||||
t.Errorf("#%d CheckTCP got=%v want=%v packet:%v", i, got, test.want, test.p)
|
||||
|
@ -165,7 +195,7 @@ func TestFilter(t *testing.T) {
|
|||
}
|
||||
}
|
||||
// Update UDP state
|
||||
_, _ = acl.runOut(&test.p)
|
||||
_, _ = filt.runOut(&test.p)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -264,13 +294,16 @@ func TestParseIPSet(t *testing.T) {
|
|||
{"*", pfx("0.0.0.0/0", "::/0"), ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, err := parseIPSet(tt.host)
|
||||
got, gotCap, err := parseIPSet(tt.host)
|
||||
if err != nil {
|
||||
if err.Error() == tt.wantErr {
|
||||
continue
|
||||
}
|
||||
t.Errorf("parseIPSet(%q) error: %v; want error %q", tt.host, err, tt.wantErr)
|
||||
}
|
||||
if gotCap != "" {
|
||||
t.Errorf("parseIPSet(%q) cap: %q; want empty", tt.host, gotCap)
|
||||
}
|
||||
compareIP := cmp.Comparer(func(a, b netip.Addr) bool { return a == b })
|
||||
compareIPPrefix := cmp.Comparer(func(a, b netip.Prefix) bool { return a == b })
|
||||
if diff := cmp.Diff(got, tt.want, compareIP, compareIPPrefix); diff != "" {
|
||||
|
@ -278,6 +311,27 @@ func TestParseIPSet(t *testing.T) {
|
|||
continue
|
||||
}
|
||||
}
|
||||
|
||||
capTests := []struct {
|
||||
in string
|
||||
want tailcfg.NodeCapability
|
||||
}{
|
||||
{"cap:foo", "foo"},
|
||||
{"cap:people-in-8.8.8.0/24", "people-in-8.8.8.0/24"}, // test precedence of "/" search
|
||||
}
|
||||
for _, tt := range capTests {
|
||||
pfxes, gotCap, err := parseIPSet(tt.in)
|
||||
if err != nil {
|
||||
t.Errorf("parseIPSet(%q) error: %v; want no error", tt.in, err)
|
||||
continue
|
||||
}
|
||||
if gotCap != tt.want {
|
||||
t.Errorf("parseIPSet(%q) cap: %q; want %q", tt.in, gotCap, tt.want)
|
||||
}
|
||||
if len(pfxes) != 0 {
|
||||
t.Errorf("parseIPSet(%q) pfxes: %v; want empty", tt.in, pfxes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFilter(b *testing.B) {
|
||||
|
@ -904,7 +958,7 @@ func TestPeerCaps(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
filt := New(mm, nil, nil, nil, t.Logf)
|
||||
filt := New(mm, nil, nil, nil, nil, t.Logf)
|
||||
tests := []struct {
|
||||
name string
|
||||
src, dst string // IP
|
||||
|
@ -1037,7 +1091,7 @@ func benchmarkFile(b *testing.B, file string, opt benchOpt) {
|
|||
logIPs.AddPrefix(tsaddr.CGNATRange())
|
||||
logIPs.AddPrefix(tsaddr.TailscaleULARange())
|
||||
|
||||
f := New(matches, must.Get(localNets.IPSet()), must.Get(logIPs.IPSet()), nil, logger.Discard)
|
||||
f := New(matches, nil, must.Get(localNets.IPSet()), must.Get(logIPs.IPSet()), nil, logger.Discard)
|
||||
var srcIP, dstIP netip.Addr
|
||||
if opt.v4 {
|
||||
srcIP = netip.MustParseAddr("1.2.3.4")
|
||||
|
|
|
@ -66,11 +66,28 @@ type CapMatch struct {
|
|||
// Match matches packets from any IP address in Srcs to any ip:port in
|
||||
// Dsts.
|
||||
type Match struct {
|
||||
IPProto views.Slice[ipproto.Proto] // required set (no default value at this layer)
|
||||
Srcs []netip.Prefix
|
||||
// IPProto is the set of IP protocol numbers for which this match applies.
|
||||
// It is required. There is no default value at this layer.
|
||||
// If empty, it doesn't match.
|
||||
IPProto views.Slice[ipproto.Proto]
|
||||
|
||||
// Srcs is the set of source IP prefixes for which this match applies. A
|
||||
// Match can match by either its source IP address being in Srcs (which
|
||||
// SrcsContains tests) or if the source IP is of a known peer self address
|
||||
// that contains a NodeCapability listed in SrcCaps.
|
||||
Srcs []netip.Prefix
|
||||
// SrcsContains is an optimized function that reports whether Addr is in
|
||||
// Srcs, using the best search method for the size and shape of Srcs.
|
||||
SrcsContains func(netip.Addr) bool `json:"-"` // report whether Addr is in Srcs
|
||||
Dsts []NetPortRange // optional, if Srcs match
|
||||
Caps []CapMatch // optional, if Srcs match
|
||||
|
||||
// SrcCaps is an alternative way to match packets. If the peer's source IP
|
||||
// has one of these capabilities, it's also permitted. The peers are only
|
||||
// looked up by their self address (Node.Addresses) and not by subnet routes
|
||||
// they advertise.
|
||||
SrcCaps []tailcfg.NodeCapability
|
||||
|
||||
Dsts []NetPortRange // optional, if source matches
|
||||
Caps []CapMatch // optional, if source match
|
||||
}
|
||||
|
||||
func (m Match) String() string {
|
||||
|
|
|
@ -23,6 +23,7 @@ func (src *Match) Clone() *Match {
|
|||
*dst = *src
|
||||
dst.IPProto = src.IPProto
|
||||
dst.Srcs = append(src.Srcs[:0:0], src.Srcs...)
|
||||
dst.SrcCaps = append(src.SrcCaps[:0:0], src.SrcCaps...)
|
||||
dst.Dsts = append(src.Dsts[:0:0], src.Dsts...)
|
||||
if src.Caps != nil {
|
||||
dst.Caps = make([]CapMatch, len(src.Caps))
|
||||
|
@ -38,6 +39,7 @@ var _MatchCloneNeedsRegeneration = Match(struct {
|
|||
IPProto views.Slice[ipproto.Proto]
|
||||
Srcs []netip.Prefix
|
||||
SrcsContains func(netip.Addr) bool
|
||||
SrcCaps []tailcfg.NodeCapability
|
||||
Dsts []NetPortRange
|
||||
Caps []CapMatch
|
||||
}{})
|
||||
|
|
|
@ -4,19 +4,23 @@
|
|||
package filter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/wgengine/filter/filtertype"
|
||||
)
|
||||
|
||||
type matches []filtertype.Match
|
||||
|
||||
func (ms matches) match(q *packet.Parsed) bool {
|
||||
for _, m := range ms {
|
||||
func (ms matches) match(q *packet.Parsed, hasCap CapTestFunc) bool {
|
||||
for i := range ms {
|
||||
m := &ms[i]
|
||||
if !views.SliceContains(m.IPProto, q.IPProto) {
|
||||
continue
|
||||
}
|
||||
if !m.SrcsContains(q.Src.Addr()) {
|
||||
if !srcMatches(m, q.Src.Addr(), hasCap) {
|
||||
continue
|
||||
}
|
||||
for _, dst := range m.Dsts {
|
||||
|
@ -32,9 +36,33 @@ func (ms matches) match(q *packet.Parsed) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (ms matches) matchIPsOnly(q *packet.Parsed) bool {
|
||||
// srcMatches reports whether srcAddr matche the src requirements in m, either
|
||||
// by Srcs (using SrcsContains), or by the node having a capability listed
|
||||
// in SrcCaps using the provided hasCap function.
|
||||
func srcMatches(m *filtertype.Match, srcAddr netip.Addr, hasCap CapTestFunc) bool {
|
||||
if m.SrcsContains(srcAddr) {
|
||||
return true
|
||||
}
|
||||
if hasCap != nil {
|
||||
for _, c := range m.SrcCaps {
|
||||
if hasCap(srcAddr, c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CapTestFunc is the function signature of a function that tests whether srcIP
|
||||
// has a given capability.
|
||||
//
|
||||
// It it used in the fast path of evaluating filter rules so should be fast.
|
||||
type CapTestFunc = func(srcIP netip.Addr, cap tailcfg.NodeCapability) bool
|
||||
|
||||
func (ms matches) matchIPsOnly(q *packet.Parsed, hasCap CapTestFunc) bool {
|
||||
srcAddr := q.Src.Addr()
|
||||
for _, m := range ms {
|
||||
if !m.SrcsContains(q.Src.Addr()) {
|
||||
if !m.SrcsContains(srcAddr) {
|
||||
continue
|
||||
}
|
||||
for _, dst := range m.Dsts {
|
||||
|
@ -43,6 +71,15 @@ func (ms matches) matchIPsOnly(q *packet.Parsed) bool {
|
|||
}
|
||||
}
|
||||
}
|
||||
if hasCap != nil {
|
||||
for _, m := range ms {
|
||||
for _, c := range m.SrcCaps {
|
||||
if hasCap(srcAddr, c) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
|
@ -58,12 +58,15 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
|
|||
}
|
||||
|
||||
for _, s := range r.SrcIPs {
|
||||
nets, err := parseIPSet(s)
|
||||
nets, cap, err := parseIPSet(s)
|
||||
if err != nil && erracc == nil {
|
||||
erracc = err
|
||||
continue
|
||||
}
|
||||
m.Srcs = append(m.Srcs, nets...)
|
||||
if cap != "" {
|
||||
m.SrcCaps = append(m.SrcCaps, cap)
|
||||
}
|
||||
}
|
||||
m.SrcsContains = ipset.NewContainsIPFunc(views.SliceOf(m.Srcs))
|
||||
|
||||
|
@ -71,11 +74,15 @@ func MatchesFromFilterRules(pf []tailcfg.FilterRule) ([]Match, error) {
|
|||
if d.Bits != nil {
|
||||
return nil, fmt.Errorf("unexpected DstBits; control plane should not send this to this client version")
|
||||
}
|
||||
nets, err := parseIPSet(d.IP)
|
||||
nets, cap, err := parseIPSet(d.IP)
|
||||
if err != nil && erracc == nil {
|
||||
erracc = err
|
||||
continue
|
||||
}
|
||||
if cap != "" {
|
||||
erracc = fmt.Errorf("unexpected capability %q in DstPorts", cap)
|
||||
continue
|
||||
}
|
||||
for _, net := range nets {
|
||||
m.Dsts = append(m.Dsts, NetPortRange{
|
||||
Net: net,
|
||||
|
@ -120,48 +127,52 @@ var (
|
|||
// - the string "*" to match everything (both IPv4 & IPv6)
|
||||
// - a CIDR (e.g. "192.168.0.0/16")
|
||||
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
|
||||
// - "cap:<peer-node-capability>" to match a peer node capability
|
||||
//
|
||||
// TODO(bradfitz): make this return an IPSet and plumb that all
|
||||
// around, and ultimately use a new version of IPSet.ContainsFunc like
|
||||
// Contains16Func that works in [16]byte address, so we we can match
|
||||
// at runtime without allocating?
|
||||
func parseIPSet(arg string) ([]netip.Prefix, error) {
|
||||
func parseIPSet(arg string) (prefixes []netip.Prefix, peerCap tailcfg.NodeCapability, err error) {
|
||||
if arg == "*" {
|
||||
// User explicitly requested wildcard.
|
||||
return []netip.Prefix{
|
||||
netip.PrefixFrom(zeroIP4, 0),
|
||||
netip.PrefixFrom(zeroIP6, 0),
|
||||
}, nil
|
||||
}, "", nil
|
||||
}
|
||||
if cap, ok := strings.CutPrefix(arg, "cap:"); ok {
|
||||
return nil, tailcfg.NodeCapability(cap), nil
|
||||
}
|
||||
if strings.Contains(arg, "/") {
|
||||
pfx, err := netip.ParsePrefix(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
if pfx != pfx.Masked() {
|
||||
return nil, fmt.Errorf("%v contains non-network bits set", pfx)
|
||||
return nil, "", fmt.Errorf("%v contains non-network bits set", pfx)
|
||||
}
|
||||
return []netip.Prefix{pfx}, nil
|
||||
return []netip.Prefix{pfx}, "", nil
|
||||
}
|
||||
if strings.Count(arg, "-") == 1 {
|
||||
ip1s, ip2s, _ := strings.Cut(arg, "-")
|
||||
ip1, err := netip.ParseAddr(ip1s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
ip2, err := netip.ParseAddr(ip2s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
r := netipx.IPRangeFrom(ip1, ip2)
|
||||
if !r.IsValid() {
|
||||
return nil, fmt.Errorf("invalid IP range %q", arg)
|
||||
return nil, "", fmt.Errorf("invalid IP range %q", arg)
|
||||
}
|
||||
return r.Prefixes(), nil
|
||||
return r.Prefixes(), "", nil
|
||||
}
|
||||
ip, err := netip.ParseAddr(arg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid IP address %q", arg)
|
||||
return nil, "", fmt.Errorf("invalid IP address %q", arg)
|
||||
}
|
||||
return []netip.Prefix{netip.PrefixFrom(ip, ip.BitLen())}, nil
|
||||
return []netip.Prefix{netip.PrefixFrom(ip, ip.BitLen())}, "", nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue