package client import ( "fmt" "net" "net/netip" "slices" "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/golibs/errors" "golang.org/x/exp/maps" ) // macKey contains MAC as byte array of 6, 8, or 20 bytes. type macKey any // macToKey converts mac into key of type macKey, which is used as the key of // the [clientIndex.macToUID]. mac must be valid MAC address. func macToKey(mac net.HardwareAddr) (key macKey) { switch len(mac) { case 6: return [6]byte(mac) case 8: return [8]byte(mac) case 20: return [20]byte(mac) default: panic(fmt.Errorf("invalid mac address %#v", mac)) } } // Index stores all information about persistent clients. type Index struct { // nameToUID maps client name to UID. nameToUID map[string]UID // clientIDToUID maps client ID to UID. clientIDToUID map[string]UID // ipToUID maps IP address to UID. ipToUID map[netip.Addr]UID // macToUID maps MAC address to UID. macToUID map[macKey]UID // uidToClient maps UID to the persistent client. uidToClient map[UID]*Persistent // subnetToUID maps subnet to UID. subnetToUID aghalg.SortedMap[netip.Prefix, UID] } // NewIndex initializes the new instance of client index. func NewIndex() (ci *Index) { return &Index{ nameToUID: map[string]UID{}, clientIDToUID: map[string]UID{}, ipToUID: map[netip.Addr]UID{}, subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare), macToUID: map[macKey]UID{}, uidToClient: map[UID]*Persistent{}, } } // Add stores information about a persistent client in the index. c must be // non-nil and contain UID. 
func (ci *Index) Add(c *Persistent) { if (c.UID == UID{}) { panic("client must contain uid") } ci.nameToUID[c.Name] = c.UID for _, id := range c.ClientIDs { ci.clientIDToUID[id] = c.UID } for _, ip := range c.IPs { ci.ipToUID[ip] = c.UID } for _, pref := range c.Subnets { ci.subnetToUID.Set(pref, c.UID) } for _, mac := range c.MACs { k := macToKey(mac) ci.macToUID[k] = c.UID } ci.uidToClient[c.UID] = c } // ClashesUID returns existing persistent client with the same UID as c. Note // that this is only possible when configuration contains duplicate fields. func (ci *Index) ClashesUID(c *Persistent) (err error) { p, ok := ci.uidToClient[c.UID] if ok { return fmt.Errorf("another client %q uses the same uid", p.Name) } return nil } // Clashes returns an error if the index contains a different persistent client // with at least a single identifier contained by c. c must be non-nil. func (ci *Index) Clashes(c *Persistent) (err error) { if p := ci.clashesName(c); p != nil { return fmt.Errorf("another client uses the same name %q", p.Name) } for _, id := range c.ClientIDs { existing, ok := ci.clientIDToUID[id] if ok && existing != c.UID { p := ci.uidToClient[existing] return fmt.Errorf("another client %q uses the same ClientID %q", p.Name, id) } } p, ip := ci.clashesIP(c) if p != nil { return fmt.Errorf("another client %q uses the same IP %q", p.Name, ip) } p, s := ci.clashesSubnet(c) if p != nil { return fmt.Errorf("another client %q uses the same subnet %q", p.Name, s) } p, mac := ci.clashesMAC(c) if p != nil { return fmt.Errorf("another client %q uses the same MAC %q", p.Name, mac) } return nil } // clashesName returns existing persistent client with the same name as c or // nil. c must be non-nil. func (ci *Index) clashesName(c *Persistent) (existing *Persistent) { existing, ok := ci.FindByName(c.Name) if !ok { return nil } if existing.UID != c.UID { return existing } return nil } // clashesIP returns a previous client with the same IP address as c. 
c must be // non-nil. func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) { for _, ip := range c.IPs { existing, ok := ci.ipToUID[ip] if ok && existing != c.UID { return ci.uidToClient[existing], ip } } return nil, netip.Addr{} } // clashesSubnet returns a previous client with the same subnet as c. c must be // non-nil. func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) { for _, s = range c.Subnets { var existing UID var ok bool ci.subnetToUID.Range(func(p netip.Prefix, uid UID) (cont bool) { if s == p { existing = uid ok = true return false } return true }) if ok && existing != c.UID { return ci.uidToClient[existing], s } } return nil, netip.Prefix{} } // clashesMAC returns a previous client with the same MAC address as c. c must // be non-nil. func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) { for _, mac = range c.MACs { k := macToKey(mac) existing, ok := ci.macToUID[k] if ok && existing != c.UID { return ci.uidToClient[existing], mac } } return nil, nil } // Find finds persistent client by string representation of the client ID, IP // address, or MAC. func (ci *Index) Find(id string) (c *Persistent, ok bool) { uid, found := ci.clientIDToUID[id] if found { return ci.uidToClient[uid], true } ip, err := netip.ParseAddr(id) if err == nil { // MAC addresses can be successfully parsed as IP addresses. c, found = ci.findByIP(ip) if found { return c, true } } mac, err := net.ParseMAC(id) if err == nil { return ci.FindByMAC(mac) } return nil, false } // FindByName finds persistent client by name. func (ci *Index) FindByName(name string) (c *Persistent, found bool) { uid, found := ci.nameToUID[name] if found { return ci.uidToClient[uid], true } return nil, false } // findByIP finds persistent client by IP address. 
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) { uid, found := ci.ipToUID[ip] if found { return ci.uidToClient[uid], true } ipWithoutZone := ip.WithZone("") ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) { // Remove zone before checking because prefixes strip zones. if pref.Contains(ipWithoutZone) { uid, found = id, true return false } return true }) if found { return ci.uidToClient[uid], true } return nil, false } // FindByMAC finds persistent client by MAC. func (ci *Index) FindByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { k := macToKey(mac) uid, found := ci.macToUID[k] if found { return ci.uidToClient[uid], true } return nil, false } // FindByIPWithoutZone finds a persistent client by IP address without zone. It // strips the IPv6 zone index from the stored IP addresses before comparing, // because querylog entries don't have it. See TODO on [querylog.logEntry.IP]. // // Note that multiple clients can have the same IP address with different zones. // Therefore, the result of this method is indeterminate. func (ci *Index) FindByIPWithoutZone(ip netip.Addr) (c *Persistent) { if (ip == netip.Addr{}) { return nil } for addr, uid := range ci.ipToUID { if addr.WithZone("") == ip { return ci.uidToClient[uid] } } return nil } // Delete removes information about persistent client from the index. c must be // non-nil. func (ci *Index) Delete(c *Persistent) { delete(ci.nameToUID, c.Name) for _, id := range c.ClientIDs { delete(ci.clientIDToUID, id) } for _, ip := range c.IPs { delete(ci.ipToUID, ip) } for _, pref := range c.Subnets { ci.subnetToUID.Del(pref) } for _, mac := range c.MACs { k := macToKey(mac) delete(ci.macToUID, k) } delete(ci.uidToClient, c.UID) } // Size returns the number of persistent clients. func (ci *Index) Size() (n int) { return len(ci.uidToClient) } // Range calls f for each persistent client, unless cont is false. The order is // undefined. 
func (ci *Index) Range(f func(c *Persistent) (cont bool)) {
	for _, c := range ci.uidToClient {
		if cont := f(c); !cont {
			break
		}
	}
}

// RangeByName is like [Index.Range], but it first sorts the persistent clients
// by name (byte-wise, via strings.Compare) so that the iteration order is
// predictable.
func (ci *Index) RangeByName(f func(c *Persistent) (cont bool)) {
	clients := maps.Values(ci.uidToClient)

	byName := func(a, b *Persistent) (n int) {
		return strings.Compare(a.Name, b.Name)
	}
	slices.SortFunc(clients, byName)

	for i := range clients {
		if !f(clients[i]) {
			return
		}
	}
}

// CloseUpstreams closes the upstream configurations of all persistent clients,
// visiting them in name order.  It doesn't stop at the first failure; all the
// errors collected along the way are combined with errors.Join.
func (ci *Index) CloseUpstreams() (err error) {
	var errs []error
	ci.RangeByName(func(c *Persistent) (cont bool) {
		if closeErr := c.CloseUpstreams(); closeErr != nil {
			errs = append(errs, closeErr)
		}

		return true
	})

	return errors.Join(errs...)
}