// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package syncs import ( "sync" "golang.org/x/sys/cpu" ) // ShardedMap is a synchronized map[K]V, internally sharded by a user-defined // K-sharding function. // // The zero value is not safe for use; use NewShardedMap. type ShardedMap[K comparable, V any] struct { shardFunc func(K) int shards []mapShard[K, V] } type mapShard[K comparable, V any] struct { mu sync.Mutex m map[K]V _ cpu.CacheLinePad // avoid false sharing of neighboring shards' mutexes } // NewShardedMap returns a new ShardedMap with the given number of shards and // sharding function. // // The shard func must return a integer in the range [0, shards) purely // deterministically based on the provided K. func NewShardedMap[K comparable, V any](shards int, shard func(K) int) *ShardedMap[K, V] { m := &ShardedMap[K, V]{ shardFunc: shard, shards: make([]mapShard[K, V], shards), } for i := range m.shards { m.shards[i].m = make(map[K]V) } return m } func (m *ShardedMap[K, V]) shard(key K) *mapShard[K, V] { return &m.shards[m.shardFunc(key)] } // GetOk returns m[key] and whether it was present. func (m *ShardedMap[K, V]) GetOk(key K) (value V, ok bool) { shard := m.shard(key) shard.mu.Lock() defer shard.mu.Unlock() value, ok = shard.m[key] return } // Get returns m[key] or the zero value of V if key is not present. func (m *ShardedMap[K, V]) Get(key K) (value V) { value, _ = m.GetOk(key) return } // Mutate atomically mutates m[k] by calling mutator. // // The mutator function is called with the old value (or its zero value) and // whether it existed in the map and it returns the new value and whether it // should be set in the map (true) or deleted from the map (false). // // It returns the change in size of the map as a result of the mutation, one of // -1 (delete), 0 (change), or 1 (addition). func (m *ShardedMap[K, V]) Mutate(key K, mutator func(oldValue V, oldValueExisted bool) (newValue V, keep bool)) (sizeDelta int) { shard := m.shard(key) shard.mu.Lock() defer shard.mu.Unlock() oldV, oldOK := shard.m[key] newV, newOK := mutator(oldV, oldOK) if newOK { shard.m[key] = newV if oldOK { return 0 } return 1 } delete(shard.m, key) if oldOK { return -1 } return 0 } // Set sets m[key] = value. // // present in m). func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) { shard := m.shard(key) shard.mu.Lock() defer shard.mu.Unlock() s0 := len(shard.m) shard.m[key] = value return len(shard.m) > s0 } // Delete removes key from m. // // It reports whether the map size shrunk (that is, whether key was present in // the map). func (m *ShardedMap[K, V]) Delete(key K) (shrunk bool) { shard := m.shard(key) shard.mu.Lock() defer shard.mu.Unlock() s0 := len(shard.m) delete(shard.m, key) return len(shard.m) < s0 } // Contains reports whether m contains key. func (m *ShardedMap[K, V]) Contains(key K) bool { shard := m.shard(key) shard.mu.Lock() defer shard.mu.Unlock() _, ok := shard.m[key] return ok } // Len returns the number of elements in m. // // It does so by locking shards one at a time, so it's not particularly cheap, // nor does it give a consistent snapshot of the map. It's mostly intended for // metrics or testing. func (m *ShardedMap[K, V]) Len() int { n := 0 for i := range m.shards { shard := &m.shards[i] shard.mu.Lock() n += len(shard.m) shard.mu.Unlock() } return n }