// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package zstdframe import ( "math/bits" "strconv" "sync" "github.com/klauspost/compress/zstd" "tailscale.com/util/must" ) // Option is an option that can be passed to [AppendEncode] or [AppendDecode]. type Option interface{ isOption() } type encoderLevel int // Constants that implement [Option] and can be passed to [AppendEncode]. const ( FastestCompression = encoderLevel(zstd.SpeedFastest) DefaultCompression = encoderLevel(zstd.SpeedDefault) BetterCompression = encoderLevel(zstd.SpeedBetterCompression) BestCompression = encoderLevel(zstd.SpeedBestCompression) ) func (encoderLevel) isOption() {} // EncoderLevel specifies the compression level when encoding. // // This exists for compatibility with [zstd.EncoderLevel] values. // Most usages should directly use one of the following constants: // - [FastestCompression] // - [DefaultCompression] // - [BetterCompression] // - [BestCompression] // // By default, [DefaultCompression] is chosen. // This option is ignored when decoding. func EncoderLevel(level zstd.EncoderLevel) Option { return encoderLevel(level) } type withChecksum bool func (withChecksum) isOption() {} // WithChecksum specifies whether to produce a checksum when encoding, // or whether to verify the checksum when decoding. // By default, checksums are produced and verified. func WithChecksum(check bool) Option { return withChecksum(check) } type maxDecodedSize uint64 func (maxDecodedSize) isOption() {} type maxDecodedSizeLog2 uint8 // uint8 avoids allocation when storing into interface func (maxDecodedSizeLog2) isOption() {} // MaxDecodedSize specifies the maximum decoded size and // is used to protect against hostile content. // By default, there is no limit. // This option is ignored when encoding. func MaxDecodedSize(maxSize uint64) Option { if bits.OnesCount64(maxSize) == 1 { return maxDecodedSizeLog2(log2(maxSize)) } return maxDecodedSize(maxSize) } type maxWindowSizeLog2 uint8 // uint8 avoids allocation when storing into interface func (maxWindowSizeLog2) isOption() {} // MaxWindowSize specifies the maximum window size, which must be a power-of-two // and be in the range of [[zstd.MinWindowSize], [zstd.MaxWindowSize]]. // // The compression or decompression algorithm will use a LZ77 rolling window // no larger than the specified size. The compression ratio will be // adversely affected, but memory requirements will be lower. // When decompressing, an error is reported if a LZ77 back reference exceeds // the specified maximum window size. // // For decompression, [MaxDecodedSize] is generally more useful. func MaxWindowSize(maxSize uint64) Option { switch { case maxSize < zstd.MinWindowSize: panic("maximum window size cannot be less than " + strconv.FormatUint(zstd.MinWindowSize, 10)) case bits.OnesCount64(maxSize) != 1: panic("maximum window size must be a power-of-two") case maxSize > zstd.MaxWindowSize: panic("maximum window size cannot be greater than " + strconv.FormatUint(zstd.MaxWindowSize, 10)) default: return maxWindowSizeLog2(log2(maxSize)) } } type lowMemory bool func (lowMemory) isOption() {} // LowMemory specifies that the encoder and decoder should aim to use // lower amounts of memory at the cost of speed. // By default, more memory used for better speed. func LowMemory(low bool) Option { return lowMemory(low) } var encoderPools sync.Map // map[encoderOptions]*sync.Pool -> *zstd.Encoder type encoderOptions struct { level zstd.EncoderLevel maxWindowLog2 uint8 checksum bool lowMemory bool } type encoder struct { pool *sync.Pool *zstd.Encoder } func getEncoder(opts ...Option) encoder { eopts := encoderOptions{level: zstd.SpeedDefault, checksum: true} for _, opt := range opts { switch opt := opt.(type) { case encoderLevel: eopts.level = zstd.EncoderLevel(opt) case maxWindowSizeLog2: eopts.maxWindowLog2 = uint8(opt) case withChecksum: eopts.checksum = bool(opt) case lowMemory: eopts.lowMemory = bool(opt) } } vpool, ok := encoderPools.Load(eopts) if !ok { vpool, _ = encoderPools.LoadOrStore(eopts, new(sync.Pool)) } pool := vpool.(*sync.Pool) enc, _ := pool.Get().(*zstd.Encoder) if enc == nil { var noopts int zopts := [...]zstd.EOption{ // Set concurrency=1 to ensure synchronous operation. zstd.WithEncoderConcurrency(1), // In stateless compression, the data is already in a single buffer, // so we might as well encode it as a single segment, // which ensures that the Frame_Content_Size is always populated, // informing decoders up-front the expected decompressed size. zstd.WithSingleSegment(true), // Ensure strict compliance with RFC 8878, section 3.1., // where zstandard "is made up of one or more frames". zstd.WithZeroFrames(true), zstd.WithEncoderLevel(eopts.level), zstd.WithEncoderCRC(eopts.checksum), zstd.WithLowerEncoderMem(eopts.lowMemory), nil, // reserved for zstd.WithWindowSize } if eopts.maxWindowLog2 > 0 { zopts[len(zopts)-noopts-1] = zstd.WithWindowSize(1 << eopts.maxWindowLog2) } else { noopts++ } enc = must.Get(zstd.NewWriter(nil, zopts[:len(zopts)-noopts]...)) } return encoder{pool, enc} } func putEncoder(e encoder) { e.pool.Put(e.Encoder) } var decoderPools sync.Map // map[decoderOptions]*sync.Pool -> *zstd.Decoder type decoderOptions struct { maxSizeLog2 uint8 maxWindowLog2 uint8 checksum bool lowMemory bool } type decoder struct { pool *sync.Pool *zstd.Decoder maxSize uint64 } func getDecoder(opts ...Option) decoder { maxSize := uint64(1 << 63) dopts := decoderOptions{maxSizeLog2: 63, checksum: true} for _, opt := range opts { switch opt := opt.(type) { case maxDecodedSizeLog2: maxSize = 1 << uint8(opt) dopts.maxSizeLog2 = uint8(opt) case maxDecodedSize: maxSize = uint64(opt) dopts.maxSizeLog2 = uint8(log2(maxSize)) case maxWindowSizeLog2: dopts.maxWindowLog2 = uint8(opt) case withChecksum: dopts.checksum = bool(opt) case lowMemory: dopts.lowMemory = bool(opt) } } vpool, ok := decoderPools.Load(dopts) if !ok { vpool, _ = decoderPools.LoadOrStore(dopts, new(sync.Pool)) } pool := vpool.(*sync.Pool) dec, _ := pool.Get().(*zstd.Decoder) if dec == nil { var noopts int zopts := [...]zstd.DOption{ // Set concurrency=1 to ensure synchronous operation. zstd.WithDecoderConcurrency(1), zstd.WithDecoderMaxMemory(1 << min(max(10, dopts.maxSizeLog2), 63)), zstd.IgnoreChecksum(!dopts.checksum), zstd.WithDecoderLowmem(dopts.lowMemory), nil, // reserved for zstd.WithDecoderMaxWindow } if dopts.maxWindowLog2 > 0 { zopts[len(zopts)-noopts-1] = zstd.WithDecoderMaxWindow(1 << dopts.maxWindowLog2) } else { noopts++ } dec = must.Get(zstd.NewReader(nil, zopts[:len(zopts)-noopts]...)) } return decoder{pool, dec, maxSize} } func putDecoder(d decoder) { d.pool.Put(d.Decoder) } func (d decoder) DecodeAll(src, dst []byte) ([]byte, error) { // We only configure DecodeAll to enforce MaxDecodedSize by powers-of-two. // Perform a more fine grain check based on the exact value. dst2, err := d.Decoder.DecodeAll(src, dst) if err == nil && uint64(len(dst2)-len(dst)) > d.maxSize { err = zstd.ErrDecoderSizeExceeded } return dst2, err } // log2 computes log2 of x rounded up to the nearest integer. func log2(x uint64) int { return 64 - bits.LeadingZeros64(x-1) }