From dc7aa98b768bf82017aa5cc82a62dd4d685f811d Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 9 Sep 2023 09:55:57 -0700 Subject: [PATCH] all: use set.Set consistently instead of map[T]struct{} I didn't clean up the more idiomatic map[T]bool with true values, at least yet. I just converted the relatively awkward struct{}-valued maps. Updates #cleanup Change-Id: I758abebd2bb1f64bc7a9d0f25c32298f4679c14f Signed-off-by: Brad Fitzpatrick --- health/health.go | 4 ++-- health/health_test.go | 4 +++- net/dns/nrpt_windows.go | 7 ++++--- net/tstun/wrap.go | 3 ++- tka/aum.go | 7 ++++--- tka/tka.go | 7 ++++--- util/set/set.go | 3 +++ util/winutil/svcdiag_windows.go | 7 ++++--- wgengine/magicsock/magicsock.go | 7 ++++--- wgengine/magicsock/magicsock_test.go | 5 +++-- wgengine/router/router_openbsd.go | 7 ++++--- wgengine/userspace.go | 5 +++-- 12 files changed, 40 insertions(+), 26 deletions(-) diff --git a/health/health.go b/health/health.go index 8b95c9194..e0881d810 100644 --- a/health/health.go +++ b/health/health.go @@ -27,7 +27,7 @@ var ( sysErr = map[Subsystem]error{} // error key => err (or nil for no error) watchers = set.HandleSet[func(Subsystem, error)]{} // opt func to run if error state changes - warnables = map[*Warnable]struct{}{} // set of warnables + warnables = set.Set[*Warnable]{} timer *time.Timer debugHandler = map[string]http.Handler{} @@ -84,7 +84,7 @@ func NewWarnable(opts ...WarnableOpt) *Warnable { } mu.Lock() defer mu.Unlock() - warnables[w] = struct{}{} + warnables.Add(w) return w } diff --git a/health/health_test.go b/health/health_test.go index 221a37dec..78d1422a2 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -8,6 +8,8 @@ import ( "fmt" "reflect" "testing" + + "tailscale.com/util/set" ) func TestAppendWarnableDebugFlags(t *testing.T) { @@ -35,5 +37,5 @@ func TestAppendWarnableDebugFlags(t *testing.T) { func resetWarnables() { mu.Lock() defer mu.Unlock() - warnables = make(map[*Warnable]struct{}) + warnables = set.Set[*Warnable]{} } diff --git a/net/dns/nrpt_windows.go b/net/dns/nrpt_windows.go index f81cdb42f..78a702616 100644 --- a/net/dns/nrpt_windows.go +++ b/net/dns/nrpt_windows.go @@ -13,6 +13,7 @@ import ( "golang.org/x/sys/windows/registry" "tailscale.com/types/logger" "tailscale.com/util/dnsname" + "tailscale.com/util/set" "tailscale.com/util/winutil" ) @@ -158,14 +159,14 @@ func (db *nrptRuleDatabase) detectWriteAsGP() { } // Add *all* rules from the GP subkey into a set. - gpSubkeyMap := make(map[string]struct{}, len(gpSubkeyNames)) + gpSubkeyMap := make(set.Set[string], len(gpSubkeyNames)) for _, gpSubkey := range gpSubkeyNames { - gpSubkeyMap[strings.ToUpper(gpSubkey)] = struct{}{} + gpSubkeyMap.Add(strings.ToUpper(gpSubkey)) } // Remove *our* rules from the set. for _, ourRuleID := range db.ruleIDs { - delete(gpSubkeyMap, strings.ToUpper(ourRuleID)) + gpSubkeyMap.Delete(strings.ToUpper(ourRuleID)) } // Any leftover rules do not belong to us. When group policy is being used diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 20f54744d..d2c5b32f6 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -35,6 +35,7 @@ import ( "tailscale.com/types/views" "tailscale.com/util/clientmetric" "tailscale.com/util/mak" + "tailscale.com/util/set" "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/wgcfg" @@ -589,7 +590,7 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config) *natV4Config { var ( rt table.RoutingTableBuilder dstMasqAddrs map[key.NodePublic]netip.Addr - listenAddrs map[netip.Addr]struct{} + listenAddrs set.Set[netip.Addr] ) // When using an exit node that requires masquerading, we need to diff --git a/tka/aum.go b/tka/aum.go index a4e2ff55f..d1f079398 100644 --- a/tka/aum.go +++ b/tka/aum.go @@ -14,6 +14,7 @@ import ( "github.com/fxamacker/cbor/v2" "golang.org/x/crypto/blake2s" "tailscale.com/types/tkatype" + "tailscale.com/util/set" ) // AUMHash represents the BLAKE2s digest of an Authority Update Message (AUM). @@ -326,7 +327,7 @@ func (a *AUM) Weight(state State) uint { // Despite the wire encoding being []byte, all KeyIDs are // 32 bytes. As such, we use that as the key for the map, // because map keys cannot be slices. - seenKeys := make(map[[32]byte]struct{}, 6) + seenKeys := make(set.Set[[32]byte], 6) for _, sig := range a.Signatures { if len(sig.KeyID) != 32 { panic("unexpected: keyIDs are 32 bytes") @@ -344,12 +345,12 @@ func (a *AUM) Weight(state State) uint { } panic(err) } - if _, seen := seenKeys[keyID]; seen { + if seenKeys.Contains(keyID) { continue } weight += key.Votes - seenKeys[keyID] = struct{}{} + seenKeys.Add(keyID) } return weight diff --git a/tka/tka.go b/tka/tka.go index 293ed2f67..61bee804b 100644 --- a/tka/tka.go +++ b/tka/tka.go @@ -14,6 +14,7 @@ import ( "github.com/fxamacker/cbor/v2" "tailscale.com/types/key" "tailscale.com/types/tkatype" + "tailscale.com/util/set" ) // Strict settings for the CBOR decoder. @@ -260,13 +261,13 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) var ( curs = topAUM state State - path = make(map[AUMHash]struct{}, 32) // 32 chosen arbitrarily. + path = make(set.Set[AUMHash], 32) // 32 chosen arbitrarily. ) for i := 0; true; i++ { if i > maxIter { return State{}, fmt.Errorf("iteration limit exceeded (%d)", maxIter) } - path[curs.Hash()] = struct{}{} + path.Add(curs.Hash()) // Checkpoints encapsulate the state at that point, dope. if curs.MessageKind == AUMCheckpoint { @@ -307,7 +308,7 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) // such, we use a custom advancer here. advancer := func(state State, candidates []AUM) (next *AUM, out State, err error) { for _, c := range candidates { - if _, inPath := path[c.Hash()]; inPath { + if path.Contains(c.Hash()) { if state, err = state.applyVerifiedAUM(c); err != nil { return nil, State{}, fmt.Errorf("advancing state: %v", err) } diff --git a/util/set/set.go b/util/set/set.go index 6adb5182f..e6f3ef1f0 100644 --- a/util/set/set.go +++ b/util/set/set.go @@ -10,6 +10,9 @@ type Set[T comparable] map[T]struct{} // Add adds e to the set. func (s Set[T]) Add(e T) { s[e] = struct{}{} } +// Delete removes e from the set. +func (s Set[T]) Delete(e T) { delete(s, e) } + // Contains reports whether s contains e. func (s Set[T]) Contains(e T) bool { _, ok := s[e] diff --git a/util/winutil/svcdiag_windows.go b/util/winutil/svcdiag_windows.go index ce8706a06..cd7c150aa 100644 --- a/util/winutil/svcdiag_windows.go +++ b/util/winutil/svcdiag_windows.go @@ -14,6 +14,7 @@ import ( "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" "tailscale.com/types/logger" + "tailscale.com/util/set" ) // LogSvcState obtains the state of the Windows service named rootSvcName and @@ -78,7 +79,7 @@ func walkServices(rootSvcName string, callback walkSvcFunc) error { } }() - seen := make(map[string]struct{}) + seen := set.Set[string]{} for err == nil && len(deps) > 0 { err = func() error { @@ -87,7 +88,7 @@ func walkServices(rootSvcName string, callback walkSvcFunc) error { deps = deps[:len(deps)-1] - seen[curSvc.Name] = struct{}{} + seen.Add(curSvc.Name) curCfg, err := curSvc.Config() if err != nil { @@ -97,7 +98,7 @@ func walkServices(rootSvcName string, callback walkSvcFunc) error { callback(curSvc, curCfg) for _, depName := range curCfg.Dependencies { - if _, ok := seen[depName]; ok { + if seen.Contains(depName) { continue } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 4acb8e09d..36373cb1f 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -54,6 +54,7 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/mak" "tailscale.com/util/ringbuffer" + "tailscale.com/util/set" "tailscale.com/util/uniq" "tailscale.com/wgengine/capture" ) @@ -229,7 +230,7 @@ type Conn struct { // WireGuard. These are not used to filter inbound or outbound // traffic at all, but only to track what state can be cleaned up // in other maps below that are keyed by peer public key. - peerSet map[key.NodePublic]struct{} + peerSet set.Set[key.NodePublic] // nodeOfDisco tracks the networkmap Node entity for each peer // discovery key. @@ -1708,7 +1709,7 @@ func (c *Conn) SetPrivateKey(privateKey key.NodePrivate) error { // then removes any state for old peers. // // The caller passes ownership of newPeers map to UpdatePeers. -func (c *Conn) UpdatePeers(newPeers map[key.NodePublic]struct{}) { +func (c *Conn) UpdatePeers(newPeers set.Set[key.NodePublic]) { c.mu.Lock() defer c.mu.Unlock() @@ -1718,7 +1719,7 @@ func (c *Conn) UpdatePeers(newPeers map[key.NodePublic]struct{}) { // Clean up any key.NodePublic-keyed maps for peers that no longer // exist. for peer := range oldPeers { - if _, ok := newPeers[peer]; !ok { + if !newPeers.Contains(peer) { delete(c.derpRoute, peer) delete(c.peerLastDerp, peer) } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index ac6a89aa1..15dad0372 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -58,6 +58,7 @@ import ( "tailscale.com/types/ptr" "tailscale.com/util/cibuild" "tailscale.com/util/racebuild" + "tailscale.com/util/set" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg/nmcfg" @@ -306,9 +307,9 @@ func meshStacks(logf logger.Logf, mutateNetmap func(idx int, nm *netmap.NetworkM for i, m := range ms { nm := buildNetmapLocked(i) m.conn.SetNetworkMap(nm) - peerSet := make(map[key.NodePublic]struct{}, len(nm.Peers)) + peerSet := make(set.Set[key.NodePublic], len(nm.Peers)) for _, peer := range nm.Peers { - peerSet[peer.Key()] = struct{}{} + peerSet.Add(peer.Key()) } m.conn.UpdatePeers(peerSet) wg, err := nmcfg.WGCfg(nm, logf, netmap.AllowSingleHosts, "") diff --git a/wgengine/router/router_openbsd.go b/wgengine/router/router_openbsd.go index c23d37e47..b85992779 100644 --- a/wgengine/router/router_openbsd.go +++ b/wgengine/router/router_openbsd.go @@ -14,6 +14,7 @@ import ( "go4.org/netipx" "tailscale.com/net/netmon" "tailscale.com/types/logger" + "tailscale.com/util/set" ) // For now this router only supports the WireGuard userspace implementation. @@ -26,7 +27,7 @@ type openbsdRouter struct { tunname string local4 netip.Prefix local6 netip.Prefix - routes map[netip.Prefix]struct{} + routes set.Set[netip.Prefix] } func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor) (Router, error) { @@ -173,9 +174,9 @@ func (r *openbsdRouter) Set(cfg *Config) error { } } - newRoutes := make(map[netip.Prefix]struct{}) + newRoutes := set.Set[netip.Prefix]{} for _, route := range cfg.Routes { - newRoutes[route] = struct{}{} + newRoutes.Add(route) } for route := range r.routes { if _, keep := newRoutes[route]; !keep { diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 0fcf01324..967c6ea3d 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -44,6 +44,7 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/deephash" "tailscale.com/util/mak" + "tailscale.com/util/set" "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" @@ -782,12 +783,12 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, e.tundev.SetWGConfig(cfg) e.lastDNSConfig = dnsCfg - peerSet := make(map[key.NodePublic]struct{}, len(cfg.Peers)) + peerSet := make(set.Set[key.NodePublic], len(cfg.Peers)) e.mu.Lock() e.peerSequence = e.peerSequence[:0] for _, p := range cfg.Peers { e.peerSequence = append(e.peerSequence, p.PublicKey) - peerSet[p.PublicKey] = struct{}{} + peerSet.Add(p.PublicKey) } nm := e.netMap e.mu.Unlock()