From d4bfe34ba77b9061c0ce97d62ede45817580bc98 Mon Sep 17 00:00:00 2001 From: Joe Tsai Date: Thu, 21 Mar 2024 11:39:20 -0700 Subject: [PATCH] util/zstdframe: add package for stateless zstd compression (#11481) The Go zstd package is not friendly for stateless zstd compression. Passing around multiple zstd.Encoder just for stateless compression is a waste of memory since the memory is never freed and seldom used if no compression operations are happening. For performance, we pool the relevant Encoder/Decoder with the specific options set. Functionally, this package is a wrapper over the Go zstd package with a more ergonomic API for stateless operations. This package can be used to cleanup various pre-existing zstd.Encoder pools or one-off handlers spread throughout our codebases. Performance: BenchmarkEncode/Best 1690 610926 ns/op 25.78 MB/s 1 B/op 0 allocs/op zstd_test.go:137: memory: 50.336 MiB zstd_test.go:138: ratio: 3.269x BenchmarkEncode/Better 10000 100939 ns/op 156.04 MB/s 0 B/op 0 allocs/op zstd_test.go:137: memory: 20.399 MiB zstd_test.go:138: ratio: 3.131x BenchmarkEncode/Default 15775 74976 ns/op 210.08 MB/s 105 B/op 0 allocs/op zstd_test.go:137: memory: 1.586 MiB zstd_test.go:138: ratio: 3.064x BenchmarkEncode/Fastest 23222 53977 ns/op 291.81 MB/s 26 B/op 0 allocs/op zstd_test.go:137: memory: 599.458 KiB zstd_test.go:138: ratio: 2.898x BenchmarkEncode/FastestLowMemory 23361 50789 ns/op 310.13 MB/s 15 B/op 0 allocs/op zstd_test.go:137: memory: 334.458 KiB zstd_test.go:138: ratio: 2.898x BenchmarkEncode/FastestNoChecksum 23086 50253 ns/op 313.44 MB/s 26 B/op 0 allocs/op zstd_test.go:137: memory: 599.458 KiB zstd_test.go:138: ratio: 2.900x BenchmarkDecode/Checksum 70794 17082 ns/op 300.96 MB/s 4 B/op 0 allocs/op zstd_test.go:163: memory: 316.438 KiB BenchmarkDecode/NoChecksum 74935 15990 ns/op 321.51 MB/s 4 B/op 0 allocs/op zstd_test.go:163: memory: 316.438 KiB BenchmarkDecode/LowMemory 71043 16739 ns/op 307.13 MB/s 0 B/op 0 allocs/op zstd_test.go:163: memory: 79.347 KiB We can see that the options are taking effect where compression ratio improves with higher levels and compression speed diminishes. We can also see that LowMemory takes effect where the pooled coder object references less memory than other cases. We can see that the pooling is taking effect as there are 0 amortized allocations. Additional performance: BenchmarkEncodeParallel/zstd-24 1857 619264 ns/op 1796 B/op 49 allocs/op BenchmarkEncodeParallel/zstdframe-24 1954 532023 ns/op 4293 B/op 49 allocs/op BenchmarkDecodeParallel/zstd-24 5288 197281 ns/op 2516 B/op 49 allocs/op BenchmarkDecodeParallel/zstdframe-24 6441 196254 ns/op 2513 B/op 49 allocs/op In concurrent usage, handling the pooling in this package has a marginal benefit over the zstd package, which relies on a Go channel as the pooling mechanism. In particular, coders can be freed by the GC when not in use. Coders can be shared throughout the program if they use this package instead of multiple independent pools doing the same thing. The allocations are unrelated to pooling as they're caused by the spawning of goroutines. Updates #cleanup Updates tailscale/corp#18514 Updates tailscale/corp#17653 Updates tailscale/corp#18005 Signed-off-by: Joe Tsai --- util/zstdframe/options.go | 183 +++++++++++++++++++++++++++++++ util/zstdframe/zstd.go | 127 ++++++++++++++++++++++ util/zstdframe/zstd_test.go | 209 ++++++++++++++++++++++++++++++++++++ 3 files changed, 519 insertions(+) create mode 100644 util/zstdframe/options.go create mode 100644 util/zstdframe/zstd.go create mode 100644 util/zstdframe/zstd_test.go diff --git a/util/zstdframe/options.go b/util/zstdframe/options.go new file mode 100644 index 000000000..0a8665c84 --- /dev/null +++ b/util/zstdframe/options.go @@ -0,0 +1,183 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package zstdframe + +import ( + "math/bits" + "sync" + + "github.com/klauspost/compress/zstd" + "tailscale.com/util/must" +) + +// Option is an option that can be passed to [AppendEncode] or [AppendDecode]. +type Option interface{ isOption() } + +type encoderLevel int + +// Constants that implement [Option] and can be passed to [AppendEncode]. +const ( + FastestCompression = encoderLevel(zstd.SpeedFastest) + DefaultCompression = encoderLevel(zstd.SpeedDefault) + BetterCompression = encoderLevel(zstd.SpeedBetterCompression) + BestCompression = encoderLevel(zstd.SpeedBestCompression) +) + +func (encoderLevel) isOption() {} + +// EncoderLevel specifies the compression level when encoding. +// +// This exists for compatibility with [zstd.EncoderLevel] values. +// Most usages should directly use one of the following constants: +// - [FastestCompression] +// - [DefaultCompression] +// - [BetterCompression] +// - [BestCompression] +// +// By default, [DefaultCompression] is chosen. +// This option is ignored when decoding. +func EncoderLevel(level zstd.EncoderLevel) Option { return encoderLevel(level) } + +type withChecksum bool + +func (withChecksum) isOption() {} + +// WithChecksum specifies whether to produce a checksum when encoding, +// or whether to verify the checksum when decoding. +// By default, checksums are produced and verified. +func WithChecksum(check bool) Option { return withChecksum(check) } + +type maxDecodedSize uint64 + +func (maxDecodedSize) isOption() {} + +// MaxDecodedSize specifies the maximum decoded size and +// is used to protect against hostile content. +// By default, there is no limit. +// This option is ignored when encoding. +func MaxDecodedSize(maxSize uint64) Option { + return maxDecodedSize(maxSize) +} + +type lowMemory bool + +func (lowMemory) isOption() {} + +// LowMemory specifies that the encoder and decoder should aim to use +// lower amounts of memory at the cost of speed. +// By default, more memory used for better speed. +func LowMemory(low bool) Option { return lowMemory(low) } + +var encoderPools sync.Map // map[encoderOptions]*sync.Pool -> *zstd.Encoder + +type encoderOptions struct { + level zstd.EncoderLevel + checksum bool + lowMemory bool +} + +type encoder struct { + pool *sync.Pool + *zstd.Encoder +} + +func getEncoder(opts ...Option) encoder { + eopts := encoderOptions{level: zstd.SpeedDefault, checksum: true} + for _, opt := range opts { + switch opt := opt.(type) { + case encoderLevel: + eopts.level = zstd.EncoderLevel(opt) + case withChecksum: + eopts.checksum = bool(opt) + case lowMemory: + eopts.lowMemory = bool(opt) + } + } + + vpool, ok := encoderPools.Load(eopts) + if !ok { + vpool, _ = encoderPools.LoadOrStore(eopts, new(sync.Pool)) + } + pool := vpool.(*sync.Pool) + enc, _ := pool.Get().(*zstd.Encoder) + if enc == nil { + enc = must.Get(zstd.NewWriter(nil, + // Set concurrency=1 to ensure synchronous operation. + zstd.WithEncoderConcurrency(1), + // In stateless compression, the data is already in a single buffer, + // so we might as well encode it as a single segment, + // which ensures that the Frame_Content_Size is always populated, + // informing decoders up-front the expected decompressed size. + zstd.WithSingleSegment(true), + // Ensure strict compliance with RFC 8878, section 3.1., + // where zstandard "is made up of one or more frames". + zstd.WithZeroFrames(true), + zstd.WithEncoderLevel(eopts.level), + zstd.WithEncoderCRC(eopts.checksum), + zstd.WithLowerEncoderMem(eopts.lowMemory))) + } + return encoder{pool, enc} +} + +func putEncoder(e encoder) { e.pool.Put(e.Encoder) } + +var decoderPools sync.Map // map[decoderOptions]*sync.Pool -> *zstd.Decoder + +type decoderOptions struct { + maxSizeLog2 int + checksum bool + lowMemory bool +} + +type decoder struct { + pool *sync.Pool + *zstd.Decoder + + maxSize uint64 +} + +func getDecoder(opts ...Option) decoder { + maxSize := uint64(1 << 63) + dopts := decoderOptions{maxSizeLog2: 63, checksum: true} + for _, opt := range opts { + switch opt := opt.(type) { + case maxDecodedSize: + maxSize = uint64(opt) + dopts.maxSizeLog2 = 64 - bits.LeadingZeros64(maxSize-1) + dopts.maxSizeLog2 = min(max(10, dopts.maxSizeLog2), 63) + case withChecksum: + dopts.checksum = bool(opt) + case lowMemory: + dopts.lowMemory = bool(opt) + } + } + + vpool, ok := decoderPools.Load(dopts) + if !ok { + vpool, _ = decoderPools.LoadOrStore(dopts, new(sync.Pool)) + } + pool := vpool.(*sync.Pool) + dec, _ := pool.Get().(*zstd.Decoder) + if dec == nil { + dec = must.Get(zstd.NewReader(nil, + // Set concurrency=1 to ensure synchronous operation. + zstd.WithDecoderConcurrency(1), + zstd.WithDecoderMaxMemory(1< d.maxSize { + err = zstd.ErrDecoderSizeExceeded + } + return dst2, err +} diff --git a/util/zstdframe/zstd.go b/util/zstdframe/zstd.go new file mode 100644 index 000000000..b20798418 --- /dev/null +++ b/util/zstdframe/zstd.go @@ -0,0 +1,127 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package zstdframe provides functionality for encoding and decoding +// independently compressed zstandard frames. +package zstdframe + +import ( + "encoding/binary" + "io" + + "github.com/klauspost/compress/zstd" +) + +// The Go zstd API surface is not ergonomic: +// +// - Options are set via NewReader and NewWriter and immutable once set. +// +// - Stateless operations like EncodeAll and DecodeAll are methods on +// the Encoder and Decoder types, which implies that options cannot be +// changed without allocating an entirely new Encoder or Decoder. +// +// This is further strange as Encoder and Decoder types are either +// stateful or stateless objects depending on semantic context. +// +// - By default, the zstd package tries to be overly clever by spawning off +// multiple goroutines to do work, which can lead to both excessive fanout +// of resources and also subtle race conditions. Also, each Encoder/Decoder +// never relinquish resources, which makes it unsuitable for lower memory. +// We work around the zstd defaults by setting concurrency=1 on each coder +// and pool individual coders, allowing the Go GC to reclaim unused coders. +// +// See https://github.com/klauspost/compress/issues/264 +// See https://github.com/klauspost/compress/issues/479 +// +// - The EncodeAll and DecodeAll functions appends to a user-provided buffer, +// but uses a signature opposite of most append-like functions in Go, +// where the output buffer is the second argument, leading to footguns. +// The zstdframe package provides AppendEncode and AppendDecode functions +// that follows Go convention of the first argument being the output buffer +// similar to how the builtin append function operates. +// +// See https://github.com/klauspost/compress/issues/648 +// +// - The zstd package is oddly inconsistent about naming. For example, +// IgnoreChecksum vs WithEncoderCRC, or +// WithDecoderLowmem vs WithLowerEncoderMem. +// Most options have a WithDecoder or WithEncoder prefix, but some do not. +// +// The zstdframe package wraps the zstd package and presents a more ergonomic API +// by providing stateless functions that take in variadic options. +// Pooling of resources is handled by this package to avoid each caller +// redundantly performing the same pooling at different call sites. + +// TODO: Since compression is CPU bound, +// should we have a semaphore ensure at most one operation per CPU? + +// AppendEncode appends the zstandard encoded content of src to dst. +// It emits exactly one frame as a single segment. +func AppendEncode(dst, src []byte, opts ...Option) []byte { + enc := getEncoder(opts...) + defer putEncoder(enc) + return enc.EncodeAll(src, dst) +} + +// AppendDecode appends the zstandard decoded content of src to dst. +// The input may consist of zero or more frames. +// Any call that handles untrusted input should specify [MaxDecodedSize]. +func AppendDecode(dst, src []byte, opts ...Option) ([]byte, error) { + dec := getDecoder(opts...) + defer putDecoder(dec) + return dec.DecodeAll(src, dst) +} + +// NextSize parses the next frame (regardless of whether it is a +// data frame or a metadata frame) and returns the total size of the frame. +// The frame can be skipped by slicing n bytes from b (e.g., b[n:]). +// It report [io.ErrUnexpectedEOF] if the frame is incomplete. +func NextSize(b []byte) (n int, err error) { + // Parse the frame header (RFC 8878, section 3.1.1.). + var frame zstd.Header + if err := frame.Decode(b); err != nil { + return n, err + } + n += frame.HeaderSize + + if frame.Skippable { + // Handle skippable frame (RFC 8878, section 3.1.2.). + if len(b[n:]) < int(frame.SkippableSize) { + return n, io.ErrUnexpectedEOF + } + n += int(frame.SkippableSize) + } else { + // Handle one or more Data_Blocks (RFC 8878, section 3.1.1.2.). + for { + if len(b[n:]) < 3 { + return n, io.ErrUnexpectedEOF + } + blockHeader := binary.LittleEndian.Uint32(b[n-1:]) >> 8 // load uint24 + lastBlock := (blockHeader >> 0) & ((1 << 1) - 1) + blockType := (blockHeader >> 1) & ((1 << 2) - 1) + blockSize := (blockHeader >> 3) & ((1 << 21) - 1) + n += 3 + if blockType == 1 { + // For RLE_Block (RFC 8878, section 3.1.1.2.2.), + // the Block_Content is only a single byte. + blockSize = 1 + } + if len(b[n:]) < int(blockSize) { + return n, io.ErrUnexpectedEOF + } + n += int(blockSize) + if lastBlock != 0 { + break + } + } + + // Handle optional Content_Checksum (RFC 8878, section 3.1.1.). + if frame.HasCheckSum { + if len(b[n:]) < 4 { + return n, io.ErrUnexpectedEOF + } + n += 4 + } + } + return n, nil +} diff --git a/util/zstdframe/zstd_test.go b/util/zstdframe/zstd_test.go new file mode 100644 index 000000000..db7b7801f --- /dev/null +++ b/util/zstdframe/zstd_test.go @@ -0,0 +1,209 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package zstdframe + +import ( + "math/bits" + "math/rand/v2" + "os" + "runtime" + "strings" + "sync" + "testing" + + "github.com/klauspost/compress/zstd" + "tailscale.com/util/must" +) + +// Use the concatenation of all Go source files in zstdframe as testdata. +var src = func() (out []byte) { + for _, de := range must.Get(os.ReadDir(".")) { + if strings.HasSuffix(de.Name(), ".go") { + out = append(out, must.Get(os.ReadFile(de.Name()))...) + } + } + return out +}() +var dst []byte +var dsts [][]byte + +// zstdEnc is identical to getEncoder without options, +// except it relies on concurrency managed by the zstd package itself. +var zstdEnc = must.Get(zstd.NewWriter(nil, + zstd.WithEncoderConcurrency(runtime.NumCPU()), + zstd.WithSingleSegment(true), + zstd.WithZeroFrames(true), + zstd.WithEncoderLevel(zstd.SpeedDefault), + zstd.WithEncoderCRC(true), + zstd.WithLowerEncoderMem(false))) + +// zstdDec is identical to getDecoder without options, +// except it relies on concurrency managed by the zstd package itself. +var zstdDec = must.Get(zstd.NewReader(nil, + zstd.WithDecoderConcurrency(runtime.NumCPU()), + zstd.WithDecoderMaxMemory(1<<63), + zstd.IgnoreChecksum(false), + zstd.WithDecoderLowmem(false))) + +var coders = []struct { + name string + appendEncode func([]byte, []byte) []byte + appendDecode func([]byte, []byte) ([]byte, error) +}{{ + name: "zstd", + appendEncode: func(dst, src []byte) []byte { return zstdEnc.EncodeAll(src, dst) }, + appendDecode: func(dst, src []byte) ([]byte, error) { return zstdDec.DecodeAll(src, dst) }, +}, { + name: "zstdframe", + appendEncode: func(dst, src []byte) []byte { return AppendEncode(dst, src) }, + appendDecode: func(dst, src []byte) ([]byte, error) { return AppendDecode(dst, src) }, +}} + +func TestDecodeMaxSize(t *testing.T) { + var enc, dec []byte + zeros := make([]byte, 1<<16, 2<<16) + check := func(encSize, maxDecSize int) { + var gotErr, wantErr error + enc = AppendEncode(enc[:0], zeros[:encSize]) + + // Directly calling zstd.Decoder.DecodeAll may not trigger size check + // since it only operates on closest power-of-two. + dec, gotErr = func() ([]byte, error) { + d := getDecoder(MaxDecodedSize(uint64(maxDecSize))) + defer putDecoder(d) + return d.Decoder.DecodeAll(enc, dec[:0]) // directly call zstd.Decoder.DecodeAll + }() + if encSize > 1<<(64-bits.LeadingZeros64(uint64(maxDecSize)-1)) { + wantErr = zstd.ErrDecoderSizeExceeded + } + if gotErr != wantErr { + t.Errorf("DecodeAll(AppendEncode(%d), %d) error = %v, want %v", encSize, maxDecSize, gotErr, wantErr) + } + + // Calling AppendDecode should perform the exact size check. + dec, gotErr = AppendDecode(dec[:0], enc, MaxDecodedSize(uint64(maxDecSize))) + if encSize > maxDecSize { + wantErr = zstd.ErrDecoderSizeExceeded + } + if gotErr != wantErr { + t.Errorf("AppendDecode(AppendEncode(%d), %d) error = %v, want %v", encSize, maxDecSize, gotErr, wantErr) + } + } + + rn := rand.New(rand.NewPCG(0, 0)) + for n := 1 << 10; n <= len(zeros); n <<= 1 { + nl := rn.IntN(n + 1) + check(nl, nl) + check(nl, nl-1) + check(nl, (n+nl)/2) + check(nl, n) + check((n+nl)/2, n) + check(n-1, n-1) + check(n-1, n) + check(n-1, n+1) + check(n, n-1) + check(n, n) + check(n, n+1) + check(n+1, n-1) + check(n+1, n) + check(n+1, n+1) + } +} + +func BenchmarkEncode(b *testing.B) { + options := []struct { + name string + opts []Option + }{ + {name: "Best", opts: []Option{BestCompression}}, + {name: "Better", opts: []Option{BetterCompression}}, + {name: "Default", opts: []Option{DefaultCompression}}, + {name: "Fastest", opts: []Option{FastestCompression}}, + {name: "FastestLowMemory", opts: []Option{FastestCompression, LowMemory(true)}}, + {name: "FastestNoChecksum", opts: []Option{FastestCompression, WithChecksum(false)}}, + } + for _, bb := range options { + b.Run(bb.name, func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(src))) + for i := 0; i < b.N; i++ { + dst = AppendEncode(dst[:0], src, bb.opts...) + } + }) + if testing.Verbose() { + ratio := float64(len(src)) / float64(len(dst)) + b.Logf("ratio: %0.3fx", ratio) + } + } +} + +func BenchmarkDecode(b *testing.B) { + options := []struct { + name string + opts []Option + }{ + {name: "Checksum", opts: []Option{WithChecksum(true)}}, + {name: "NoChecksum", opts: []Option{WithChecksum(false)}}, + {name: "LowMemory", opts: []Option{LowMemory(true)}}, + } + src := AppendEncode(nil, src) + for _, bb := range options { + b.Run(bb.name, func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(src))) + for i := 0; i < b.N; i++ { + dst = must.Get(AppendDecode(dst[:0], src, bb.opts...)) + } + }) + } +} + +func BenchmarkEncodeParallel(b *testing.B) { + numCPU := runtime.NumCPU() + for _, coder := range coders { + dsts = dsts[:0] + for i := 0; i < numCPU; i++ { + dsts = append(dsts, coder.appendEncode(nil, src)) + } + b.Run(coder.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var group sync.WaitGroup + for j := 0; j < numCPU; j++ { + group.Add(1) + go func(j int) { + defer group.Done() + dsts[j] = coder.appendEncode(dsts[j][:0], src) + }(j) + } + group.Wait() + } + }) + } +} + +func BenchmarkDecodeParallel(b *testing.B) { + numCPU := runtime.NumCPU() + for _, coder := range coders { + dsts = dsts[:0] + src := AppendEncode(nil, src) + for i := 0; i < numCPU; i++ { + dsts = append(dsts, must.Get(coder.appendDecode(nil, src))) + } + b.Run(coder.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var group sync.WaitGroup + for j := 0; j < numCPU; j++ { + group.Add(1) + go func(j int) { + defer group.Done() + dsts[j] = must.Get(coder.appendDecode(dsts[j][:0], src)) + }(j) + } + group.Wait() + } + }) + } +}