diff --git a/util/zstdframe/options.go b/util/zstdframe/options.go new file mode 100644 index 000000000..0a8665c84 --- /dev/null +++ b/util/zstdframe/options.go @@ -0,0 +1,183 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package zstdframe + +import ( + "math/bits" + "sync" + + "github.com/klauspost/compress/zstd" + "tailscale.com/util/must" +) + +// Option is an option that can be passed to [AppendEncode] or [AppendDecode]. +type Option interface{ isOption() } + +type encoderLevel int + +// Constants that implement [Option] and can be passed to [AppendEncode]. +const ( + FastestCompression = encoderLevel(zstd.SpeedFastest) + DefaultCompression = encoderLevel(zstd.SpeedDefault) + BetterCompression = encoderLevel(zstd.SpeedBetterCompression) + BestCompression = encoderLevel(zstd.SpeedBestCompression) +) + +func (encoderLevel) isOption() {} + +// EncoderLevel specifies the compression level when encoding. +// +// This exists for compatibility with [zstd.EncoderLevel] values. +// Most usages should directly use one of the following constants: +// - [FastestCompression] +// - [DefaultCompression] +// - [BetterCompression] +// - [BestCompression] +// +// By default, [DefaultCompression] is chosen. +// This option is ignored when decoding. +func EncoderLevel(level zstd.EncoderLevel) Option { return encoderLevel(level) } + +type withChecksum bool + +func (withChecksum) isOption() {} + +// WithChecksum specifies whether to produce a checksum when encoding, +// or whether to verify the checksum when decoding. +// By default, checksums are produced and verified. +func WithChecksum(check bool) Option { return withChecksum(check) } + +type maxDecodedSize uint64 + +func (maxDecodedSize) isOption() {} + +// MaxDecodedSize specifies the maximum decoded size and +// is used to protect against hostile content. +// By default, there is no limit. +// This option is ignored when encoding. 
+func MaxDecodedSize(maxSize uint64) Option { + return maxDecodedSize(maxSize) +} + +type lowMemory bool + +func (lowMemory) isOption() {} + +// LowMemory specifies that the encoder and decoder should aim to use +// lower amounts of memory at the cost of speed. +// By default, more memory used for better speed. +func LowMemory(low bool) Option { return lowMemory(low) } + +var encoderPools sync.Map // map[encoderOptions]*sync.Pool -> *zstd.Encoder + +type encoderOptions struct { + level zstd.EncoderLevel + checksum bool + lowMemory bool +} + +type encoder struct { + pool *sync.Pool + *zstd.Encoder +} + +func getEncoder(opts ...Option) encoder { + eopts := encoderOptions{level: zstd.SpeedDefault, checksum: true} + for _, opt := range opts { + switch opt := opt.(type) { + case encoderLevel: + eopts.level = zstd.EncoderLevel(opt) + case withChecksum: + eopts.checksum = bool(opt) + case lowMemory: + eopts.lowMemory = bool(opt) + } + } + + vpool, ok := encoderPools.Load(eopts) + if !ok { + vpool, _ = encoderPools.LoadOrStore(eopts, new(sync.Pool)) + } + pool := vpool.(*sync.Pool) + enc, _ := pool.Get().(*zstd.Encoder) + if enc == nil { + enc = must.Get(zstd.NewWriter(nil, + // Set concurrency=1 to ensure synchronous operation. + zstd.WithEncoderConcurrency(1), + // In stateless compression, the data is already in a single buffer, + // so we might as well encode it as a single segment, + // which ensures that the Frame_Content_Size is always populated, + // informing decoders up-front the expected decompressed size. + zstd.WithSingleSegment(true), + // Ensure strict compliance with RFC 8878, section 3.1., + // where zstandard "is made up of one or more frames". 
+ zstd.WithZeroFrames(true), + zstd.WithEncoderLevel(eopts.level), + zstd.WithEncoderCRC(eopts.checksum), + zstd.WithLowerEncoderMem(eopts.lowMemory))) + } + return encoder{pool, enc} +} + +func putEncoder(e encoder) { e.pool.Put(e.Encoder) } + +var decoderPools sync.Map // map[decoderOptions]*sync.Pool -> *zstd.Decoder + +type decoderOptions struct { + maxSizeLog2 int + checksum bool + lowMemory bool +} + +type decoder struct { + pool *sync.Pool + *zstd.Decoder + + maxSize uint64 +} + +func getDecoder(opts ...Option) decoder { + maxSize := uint64(1 << 63) + dopts := decoderOptions{maxSizeLog2: 63, checksum: true} + for _, opt := range opts { + switch opt := opt.(type) { + case maxDecodedSize: + maxSize = uint64(opt) + dopts.maxSizeLog2 = 64 - bits.LeadingZeros64(maxSize-1) + dopts.maxSizeLog2 = min(max(10, dopts.maxSizeLog2), 63) + case withChecksum: + dopts.checksum = bool(opt) + case lowMemory: + dopts.lowMemory = bool(opt) + } + } + + vpool, ok := decoderPools.Load(dopts) + if !ok { + vpool, _ = decoderPools.LoadOrStore(dopts, new(sync.Pool)) + } + pool := vpool.(*sync.Pool) + dec, _ := pool.Get().(*zstd.Decoder) + if dec == nil { + dec = must.Get(zstd.NewReader(nil, + // Set concurrency=1 to ensure synchronous operation. + zstd.WithDecoderConcurrency(1), + zstd.WithDecoderMaxMemory(1<<dopts.maxSizeLog2), + zstd.IgnoreChecksum(!dopts.checksum), + zstd.WithDecoderLowmem(dopts.lowMemory))) + } + return decoder{pool, dec, maxSize} +} + +func putDecoder(d decoder) { d.pool.Put(d.Decoder) } + +// DecodeAll is like [zstd.Decoder.DecodeAll], but additionally enforces the +// exact maximum decoded size, since the pooled zstd.Decoder is only configured +// with the closest power-of-two limit. +func (d decoder) DecodeAll(src, dst []byte) ([]byte, error) { + dst2, err := d.Decoder.DecodeAll(src, dst) + if uint64(len(dst2)-len(dst)) > d.maxSize { + err = zstd.ErrDecoderSizeExceeded + } + return dst2, err +} diff --git a/util/zstdframe/zstd.go b/util/zstdframe/zstd.go new file mode 100644 index 000000000..b20798418 --- /dev/null +++ b/util/zstdframe/zstd.go @@ -0,0 +1,127 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package zstdframe provides functionality for encoding and decoding +// independently compressed zstandard frames. +package zstdframe + +import ( + "encoding/binary" + "io" + + "github.com/klauspost/compress/zstd" +) + +// The Go zstd API surface is not ergonomic: +// +// - Options are set via NewReader and NewWriter and immutable once set. 
+// +// - Stateless operations like EncodeAll and DecodeAll are methods on +// the Encoder and Decoder types, which implies that options cannot be +// changed without allocating an entirely new Encoder or Decoder. +// +// This is further strange as Encoder and Decoder types are either +// stateful or stateless objects depending on semantic context. +// +// - By default, the zstd package tries to be overly clever by spawning off +// multiple goroutines to do work, which can lead to both excessive fanout +// of resources and also subtle race conditions. Also, each Encoder/Decoder +// never relinquishes resources, which makes it unsuitable for lower memory. +// We work around the zstd defaults by setting concurrency=1 on each coder +// and pool individual coders, allowing the Go GC to reclaim unused coders. +// +// See https://github.com/klauspost/compress/issues/264 +// See https://github.com/klauspost/compress/issues/479 +// +// - The EncodeAll and DecodeAll functions append to a user-provided buffer, +// but use a signature opposite of most append-like functions in Go, +// where the output buffer is the second argument, leading to footguns. +// The zstdframe package provides AppendEncode and AppendDecode functions +// that follow the Go convention of the first argument being the output buffer +// similar to how the builtin append function operates. +// +// See https://github.com/klauspost/compress/issues/648 +// +// - The zstd package is oddly inconsistent about naming. For example, +// IgnoreChecksum vs WithEncoderCRC, or +// WithDecoderLowmem vs WithLowerEncoderMem. +// Most options have a WithDecoder or WithEncoder prefix, but some do not. +// +// The zstdframe package wraps the zstd package and presents a more ergonomic API +// by providing stateless functions that take in variadic options. +// Pooling of resources is handled by this package to avoid each caller +// redundantly performing the same pooling at different call sites. 
+ +// TODO: Since compression is CPU bound, +// should we have a semaphore ensure at most one operation per CPU? + +// AppendEncode appends the zstandard encoded content of src to dst. +// It emits exactly one frame as a single segment. +func AppendEncode(dst, src []byte, opts ...Option) []byte { + enc := getEncoder(opts...) + defer putEncoder(enc) + return enc.EncodeAll(src, dst) +} + +// AppendDecode appends the zstandard decoded content of src to dst. +// The input may consist of zero or more frames. +// Any call that handles untrusted input should specify [MaxDecodedSize]. +func AppendDecode(dst, src []byte, opts ...Option) ([]byte, error) { + dec := getDecoder(opts...) + defer putDecoder(dec) + return dec.DecodeAll(src, dst) +} + +// NextSize parses the next frame (regardless of whether it is a +// data frame or a metadata frame) and returns the total size of the frame. +// The frame can be skipped by slicing n bytes from b (e.g., b[n:]). +// It reports [io.ErrUnexpectedEOF] if the frame is incomplete. +func NextSize(b []byte) (n int, err error) { + // Parse the frame header (RFC 8878, section 3.1.1.). + var frame zstd.Header + if err := frame.Decode(b); err != nil { + return n, err + } + n += frame.HeaderSize + + if frame.Skippable { + // Handle skippable frame (RFC 8878, section 3.1.2.). + if len(b[n:]) < int(frame.SkippableSize) { + return n, io.ErrUnexpectedEOF + } + n += int(frame.SkippableSize) + } else { + // Handle one or more Data_Blocks (RFC 8878, section 3.1.1.2.). + for { + if len(b[n:]) < 3 { + return n, io.ErrUnexpectedEOF + } + blockHeader := binary.LittleEndian.Uint32(b[n-1:]) >> 8 // load uint24 + lastBlock := (blockHeader >> 0) & ((1 << 1) - 1) + blockType := (blockHeader >> 1) & ((1 << 2) - 1) + blockSize := (blockHeader >> 3) & ((1 << 21) - 1) + n += 3 + if blockType == 1 { + // For RLE_Block (RFC 8878, section 3.1.1.2.2.), + // the Block_Content is only a single byte. 
+ blockSize = 1 + } + if len(b[n:]) < int(blockSize) { + return n, io.ErrUnexpectedEOF + } + n += int(blockSize) + if lastBlock != 0 { + break + } + } + + // Handle optional Content_Checksum (RFC 8878, section 3.1.1.). + if frame.HasCheckSum { + if len(b[n:]) < 4 { + return n, io.ErrUnexpectedEOF + } + n += 4 + } + } + return n, nil +} diff --git a/util/zstdframe/zstd_test.go b/util/zstdframe/zstd_test.go new file mode 100644 index 000000000..db7b7801f --- /dev/null +++ b/util/zstdframe/zstd_test.go @@ -0,0 +1,209 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package zstdframe + +import ( + "math/bits" + "math/rand/v2" + "os" + "runtime" + "strings" + "sync" + "testing" + + "github.com/klauspost/compress/zstd" + "tailscale.com/util/must" +) + +// Use the concatenation of all Go source files in zstdframe as testdata. +var src = func() (out []byte) { + for _, de := range must.Get(os.ReadDir(".")) { + if strings.HasSuffix(de.Name(), ".go") { + out = append(out, must.Get(os.ReadFile(de.Name()))...) + } + } + return out +}() +var dst []byte +var dsts [][]byte + +// zstdEnc is identical to getEncoder without options, +// except it relies on concurrency managed by the zstd package itself. +var zstdEnc = must.Get(zstd.NewWriter(nil, + zstd.WithEncoderConcurrency(runtime.NumCPU()), + zstd.WithSingleSegment(true), + zstd.WithZeroFrames(true), + zstd.WithEncoderLevel(zstd.SpeedDefault), + zstd.WithEncoderCRC(true), + zstd.WithLowerEncoderMem(false))) + +// zstdDec is identical to getDecoder without options, +// except it relies on concurrency managed by the zstd package itself. 
+var zstdDec = must.Get(zstd.NewReader(nil, + zstd.WithDecoderConcurrency(runtime.NumCPU()), + zstd.WithDecoderMaxMemory(1<<63), + zstd.IgnoreChecksum(false), + zstd.WithDecoderLowmem(false))) + +var coders = []struct { + name string + appendEncode func([]byte, []byte) []byte + appendDecode func([]byte, []byte) ([]byte, error) +}{{ + name: "zstd", + appendEncode: func(dst, src []byte) []byte { return zstdEnc.EncodeAll(src, dst) }, + appendDecode: func(dst, src []byte) ([]byte, error) { return zstdDec.DecodeAll(src, dst) }, +}, { + name: "zstdframe", + appendEncode: func(dst, src []byte) []byte { return AppendEncode(dst, src) }, + appendDecode: func(dst, src []byte) ([]byte, error) { return AppendDecode(dst, src) }, +}} + +func TestDecodeMaxSize(t *testing.T) { + var enc, dec []byte + zeros := make([]byte, 1<<16, 2<<16) + check := func(encSize, maxDecSize int) { + var gotErr, wantErr error + enc = AppendEncode(enc[:0], zeros[:encSize]) + + // Directly calling zstd.Decoder.DecodeAll may not trigger size check + // since it only operates on closest power-of-two. + dec, gotErr = func() ([]byte, error) { + d := getDecoder(MaxDecodedSize(uint64(maxDecSize))) + defer putDecoder(d) + return d.Decoder.DecodeAll(enc, dec[:0]) // directly call zstd.Decoder.DecodeAll + }() + if encSize > 1<<(64-bits.LeadingZeros64(uint64(maxDecSize)-1)) { + wantErr = zstd.ErrDecoderSizeExceeded + } + if gotErr != wantErr { + t.Errorf("DecodeAll(AppendEncode(%d), %d) error = %v, want %v", encSize, maxDecSize, gotErr, wantErr) + } + + // Calling AppendDecode should perform the exact size check. 
+ dec, gotErr = AppendDecode(dec[:0], enc, MaxDecodedSize(uint64(maxDecSize))) + if encSize > maxDecSize { + wantErr = zstd.ErrDecoderSizeExceeded + } + if gotErr != wantErr { + t.Errorf("AppendDecode(AppendEncode(%d), %d) error = %v, want %v", encSize, maxDecSize, gotErr, wantErr) + } + } + + rn := rand.New(rand.NewPCG(0, 0)) + for n := 1 << 10; n <= len(zeros); n <<= 1 { + nl := rn.IntN(n + 1) + check(nl, nl) + check(nl, nl-1) + check(nl, (n+nl)/2) + check(nl, n) + check((n+nl)/2, n) + check(n-1, n-1) + check(n-1, n) + check(n-1, n+1) + check(n, n-1) + check(n, n) + check(n, n+1) + check(n+1, n-1) + check(n+1, n) + check(n+1, n+1) + } +} + +func BenchmarkEncode(b *testing.B) { + options := []struct { + name string + opts []Option + }{ + {name: "Best", opts: []Option{BestCompression}}, + {name: "Better", opts: []Option{BetterCompression}}, + {name: "Default", opts: []Option{DefaultCompression}}, + {name: "Fastest", opts: []Option{FastestCompression}}, + {name: "FastestLowMemory", opts: []Option{FastestCompression, LowMemory(true)}}, + {name: "FastestNoChecksum", opts: []Option{FastestCompression, WithChecksum(false)}}, + } + for _, bb := range options { + b.Run(bb.name, func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(src))) + for i := 0; i < b.N; i++ { + dst = AppendEncode(dst[:0], src, bb.opts...) 
+ } + }) + if testing.Verbose() { + ratio := float64(len(src)) / float64(len(dst)) + b.Logf("ratio: %0.3fx", ratio) + } + } +} + +func BenchmarkDecode(b *testing.B) { + options := []struct { + name string + opts []Option + }{ + {name: "Checksum", opts: []Option{WithChecksum(true)}}, + {name: "NoChecksum", opts: []Option{WithChecksum(false)}}, + {name: "LowMemory", opts: []Option{LowMemory(true)}}, + } + src := AppendEncode(nil, src) + for _, bb := range options { + b.Run(bb.name, func(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(src))) + for i := 0; i < b.N; i++ { + dst = must.Get(AppendDecode(dst[:0], src, bb.opts...)) + } + }) + } +} + +func BenchmarkEncodeParallel(b *testing.B) { + numCPU := runtime.NumCPU() + for _, coder := range coders { + dsts = dsts[:0] + for i := 0; i < numCPU; i++ { + dsts = append(dsts, coder.appendEncode(nil, src)) + } + b.Run(coder.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var group sync.WaitGroup + for j := 0; j < numCPU; j++ { + group.Add(1) + go func(j int) { + defer group.Done() + dsts[j] = coder.appendEncode(dsts[j][:0], src) + }(j) + } + group.Wait() + } + }) + } +} + +func BenchmarkDecodeParallel(b *testing.B) { + numCPU := runtime.NumCPU() + for _, coder := range coders { + dsts = dsts[:0] + src := AppendEncode(nil, src) + for i := 0; i < numCPU; i++ { + dsts = append(dsts, must.Get(coder.appendDecode(nil, src))) + } + b.Run(coder.name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + var group sync.WaitGroup + for j := 0; j < numCPU; j++ { + group.Add(1) + go func(j int) { + defer group.Done() + dsts[j] = must.Get(coder.appendDecode(dsts[j][:0], src)) + }(j) + } + group.Wait() + } + }) + } +}