filter: prevent escape of QDecode to the heap (#417)
Performance impact: name old time/op new time/op delta Filter/tcp_in-4 70.7ns ± 1% 30.9ns ± 1% -56.30% (p=0.008 n=5+5) Filter/tcp_out-4 58.6ns ± 0% 19.4ns ± 0% -66.87% (p=0.000 n=5+4) Filter/udp_in-4 96.8ns ± 2% 55.5ns ± 0% -42.64% (p=0.016 n=5+4) Filter/udp_out-4 120ns ± 1% 79ns ± 1% -33.87% (p=0.008 n=5+5) Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
This commit is contained in:
parent
83b6b06cc4
commit
73c40c77b0
|
@ -138,16 +138,26 @@ var acceptBucket = rate.NewLimiter(rate.Every(10*time.Second), 3)
|
|||
var dropBucket = rate.NewLimiter(rate.Every(5*time.Second), 10)
|
||||
|
||||
func (f *Filter) logRateLimit(runflags RunFlags, b []byte, q *packet.QDecode, r Response, why string) {
|
||||
var verdict string
|
||||
|
||||
if r == Drop && (runflags&LogDrops) != 0 && dropBucket.Allow() {
|
||||
verdict = "Drop"
|
||||
runflags &= HexdumpDrops
|
||||
} else if r == Accept && (runflags&LogAccepts) != 0 && acceptBucket.Allow() {
|
||||
verdict = "Accept"
|
||||
runflags &= HexdumpAccepts
|
||||
}
|
||||
|
||||
// Note: it is crucial that q.String() be called only if {accept,drop}Bucket.Allow() passes,
|
||||
// since it causes an allocation.
|
||||
if verdict != "" {
|
||||
var qs string
|
||||
if q == nil {
|
||||
qs = fmt.Sprintf("(%d bytes)", len(b))
|
||||
} else {
|
||||
qs = q.String()
|
||||
}
|
||||
f.logf("Drop: %v %v %s\n%s", qs, len(b), why, maybeHexdump(runflags&HexdumpDrops, b))
|
||||
} else if r == Accept && (runflags&LogAccepts) != 0 && acceptBucket.Allow() {
|
||||
f.logf("Accept: %v %v %s\n%s", q, len(b), why, maybeHexdump(runflags&HexdumpAccepts, b))
|
||||
f.logf("%s: %s %d %s\n%s", verdict, qs, len(b), why, maybeHexdump(runflags, b))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -254,7 +264,7 @@ func (f *Filter) pre(b []byte, q *packet.QDecode, rf RunFlags) Response {
|
|||
|
||||
if q.IPProto == packet.Junk {
|
||||
// Junk packets are dangerous; always drop them.
|
||||
f.logRateLimit(rf, b, q, Drop, "junk!")
|
||||
f.logRateLimit(rf, b, q, Drop, "junk")
|
||||
return Drop
|
||||
} else if q.IPProto == packet.Fragment {
|
||||
// Fragments after the first always need to be passed through.
|
||||
|
|
|
@ -7,9 +7,9 @@ package filter
|
|||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/wgengine/packet"
|
||||
)
|
||||
|
||||
|
@ -43,26 +43,29 @@ func netpr(ip IP, bits int, start, end uint16) []NetPortRange {
|
|||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
mm := Matches{
|
||||
{Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: []NetPortRange{
|
||||
NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}},
|
||||
NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}},
|
||||
}},
|
||||
{Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)},
|
||||
{Srcs: nets([]IP{0x02020202}), Dsts: ippr(0x08010101, 22, 22)},
|
||||
{Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)},
|
||||
{Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)},
|
||||
{Srcs: nets([]IP{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)},
|
||||
}
|
||||
var matches = Matches{
|
||||
{Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: []NetPortRange{
|
||||
NetPortRange{Net{0x01020304, Netmask(32)}, PortRange{22, 22}},
|
||||
NetPortRange{Net{0x05060708, Netmask(32)}, PortRange{23, 24}},
|
||||
}},
|
||||
{Srcs: nets([]IP{0x08010101, 0x08020202}), Dsts: ippr(0x05060708, 27, 28)},
|
||||
{Srcs: nets([]IP{0x02020202}), Dsts: ippr(0x08010101, 22, 22)},
|
||||
{Srcs: []Net{NetAny}, Dsts: ippr(0x647a6232, 0, 65535)},
|
||||
{Srcs: []Net{NetAny}, Dsts: netpr(0, 0, 443, 443)},
|
||||
{Srcs: nets([]IP{0x99010101, 0x99010102, 0x99030303}), Dsts: ippr(0x01020304, 999, 999)},
|
||||
}
|
||||
|
||||
func newFilter(logf logger.Logf) *Filter {
|
||||
// Expects traffic to 100.122.98.50, 1.2.3.4, 5.6.7.8,
|
||||
// 102.102.102.102, 119.119.119.119, 8.1.0.0/16
|
||||
localNets := nets([]IP{0x647a6232, 0x01020304, 0x05060708, 0x66666666, 0x77777777})
|
||||
localNets = append(localNets, Net{IP(0x08010000), Netmask(16)})
|
||||
|
||||
acl := New(mm, localNets, nil, t.Logf)
|
||||
return New(matches, localNets, nil, logf)
|
||||
}
|
||||
|
||||
for _, ent := range []Matches{Matches{mm[0]}, mm} {
|
||||
func TestMarshal(t *testing.T) {
|
||||
for _, ent := range []Matches{Matches{matches[0]}, matches} {
|
||||
b, err := json.Marshal(ent)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal: %v", err)
|
||||
|
@ -73,7 +76,10 @@ func TestFilter(t *testing.T) {
|
|||
t.Fatalf("unmarshal: %v (%v)", err, string(b))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
acl := newFilter(t.Logf)
|
||||
// check packet filtering based on the table
|
||||
|
||||
type InOut struct {
|
||||
|
@ -116,6 +122,83 @@ func TestFilter(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestNoAllocs(t *testing.T) {
|
||||
acl := newFilter(t.Logf)
|
||||
|
||||
tcpPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0)
|
||||
udpPacket := rawpacket(UDP, 0x08010101, 0x01020304, 999, 22, 0)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in bool
|
||||
want int
|
||||
packet []byte
|
||||
}{
|
||||
{"tcp_in", true, 0, tcpPacket},
|
||||
{"tcp_out", false, 0, tcpPacket},
|
||||
{"udp_in", true, 0, udpPacket},
|
||||
// One alloc is inevitable (an lru cache update)
|
||||
{"udp_out", false, 1, udpPacket},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
got := int(testing.AllocsPerRun(1000, func() {
|
||||
var q QDecode
|
||||
if test.in {
|
||||
acl.RunIn(test.packet, &q, 0)
|
||||
} else {
|
||||
acl.RunOut(test.packet, &q, 0)
|
||||
}
|
||||
}))
|
||||
|
||||
if got > test.want {
|
||||
t.Errorf("got %d allocs per run; want at most %d", got, test.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFilter(b *testing.B) {
|
||||
acl := newFilter(b.Logf)
|
||||
|
||||
tcpPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0)
|
||||
udpPacket := rawpacket(UDP, 0x08010101, 0x01020304, 999, 22, 0)
|
||||
icmpPacket := rawpacket(ICMP, 0x08010101, 0x01020304, 0, 0, 0)
|
||||
|
||||
tcpSynPacket := rawpacket(TCP, 0x08010101, 0x01020304, 999, 22, 0)
|
||||
// TCP filtering is trivial (Accept) for non-SYN packets.
|
||||
tcpSynPacket[33] = packet.TCPSyn
|
||||
|
||||
benches := []struct {
|
||||
name string
|
||||
in bool
|
||||
packet []byte
|
||||
}{
|
||||
// Non-SYN TCP and ICMP have similar code paths in and out.
|
||||
{"icmp", true, icmpPacket},
|
||||
{"tcp", true, tcpPacket},
|
||||
{"tcp_syn_in", true, tcpSynPacket},
|
||||
{"tcp_syn_out", false, tcpSynPacket},
|
||||
{"udp_in", true, udpPacket},
|
||||
{"udp_out", false, udpPacket},
|
||||
}
|
||||
|
||||
for _, bench := range benches {
|
||||
b.Run(bench.name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var q QDecode
|
||||
// This branch seems to have no measurable impact on performance.
|
||||
if bench.in {
|
||||
acl.RunIn(bench.packet, &q, 0)
|
||||
} else {
|
||||
acl.RunOut(bench.packet, &q, 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreFilter(t *testing.T) {
|
||||
packets := []struct {
|
||||
desc string
|
||||
|
@ -124,11 +207,11 @@ func TestPreFilter(t *testing.T) {
|
|||
}{
|
||||
{"empty", Accept, []byte{}},
|
||||
{"short", Drop, []byte("short")},
|
||||
{"junk", Drop, rawpacket(Junk, 10)},
|
||||
{"fragment", Accept, rawpacket(Fragment, 40)},
|
||||
{"tcp", noVerdict, rawpacket(TCP, 200)},
|
||||
{"udp", noVerdict, rawpacket(UDP, 200)},
|
||||
{"icmp", noVerdict, rawpacket(ICMP, 200)},
|
||||
{"junk", Drop, rawdefault(Junk, 10)},
|
||||
{"fragment", Accept, rawdefault(Fragment, 40)},
|
||||
{"tcp", noVerdict, rawdefault(TCP, 200)},
|
||||
{"udp", noVerdict, rawdefault(UDP, 200)},
|
||||
{"icmp", noVerdict, rawdefault(ICMP, 200)},
|
||||
}
|
||||
f := NewAllowNone(t.Logf)
|
||||
for _, testPacket := range packets {
|
||||
|
@ -150,22 +233,38 @@ func qdecode(proto packet.IPProto, src, dst packet.IP, sport, dport uint16) QDec
|
|||
}
|
||||
}
|
||||
|
||||
func rawpacket(proto packet.IPProto, len uint16) []byte {
|
||||
bl := len
|
||||
if len < 24 {
|
||||
bl = 24
|
||||
// rawpacket generates a packet with given source and destination ports and IPs
|
||||
// and resizes the header to trimLength if it is nonzero.
|
||||
func rawpacket(proto packet.IPProto, src, dst packet.IP, sport, dport uint16, trimLength int) []byte {
|
||||
var headerLength int
|
||||
|
||||
switch proto {
|
||||
case ICMP:
|
||||
headerLength = 24
|
||||
case TCP:
|
||||
headerLength = 40
|
||||
case UDP:
|
||||
headerLength = 28
|
||||
default:
|
||||
headerLength = 24
|
||||
}
|
||||
if trimLength > headerLength {
|
||||
headerLength = trimLength
|
||||
}
|
||||
if trimLength == 0 {
|
||||
trimLength = headerLength
|
||||
}
|
||||
|
||||
bin := binary.BigEndian
|
||||
hdr := make([]byte, bl)
|
||||
hdr := make([]byte, headerLength)
|
||||
hdr[0] = 0x45
|
||||
bin.PutUint16(hdr[2:4], len)
|
||||
bin.PutUint16(hdr[2:4], uint16(trimLength))
|
||||
hdr[8] = 64
|
||||
ip := net.IPv4(8, 8, 8, 8).To4()
|
||||
copy(hdr[12:16], ip)
|
||||
copy(hdr[16:20], ip)
|
||||
bin.PutUint32(hdr[12:16], uint32(src))
|
||||
bin.PutUint32(hdr[16:20], uint32(dst))
|
||||
// ports
|
||||
bin.PutUint16(hdr[20:22], 53)
|
||||
bin.PutUint16(hdr[22:24], 53)
|
||||
bin.PutUint16(hdr[20:22], sport)
|
||||
bin.PutUint16(hdr[22:24], dport)
|
||||
|
||||
switch proto {
|
||||
case ICMP:
|
||||
|
@ -183,8 +282,15 @@ func rawpacket(proto packet.IPProto, len uint16) []byte {
|
|||
panic("unknown protocol")
|
||||
}
|
||||
|
||||
// Truncate the header if requested
|
||||
hdr = hdr[:len]
|
||||
// Trim the header if requested
|
||||
hdr = hdr[:trimLength]
|
||||
|
||||
return hdr
|
||||
}
|
||||
|
||||
// rawdefault calls rawpacket with default ports and IPs.
|
||||
func rawdefault(proto packet.IPProto, trimLength int) []byte {
|
||||
ip := IP(0x08080808) // 8.8.8.8
|
||||
port := uint16(53)
|
||||
return rawpacket(proto, ip, ip, port, port, trimLength)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue