diff --git a/internal/aghalg/sortedmap.go b/internal/aghalg/sortedmap.go new file mode 100644 index 00000000..e983c44d --- /dev/null +++ b/internal/aghalg/sortedmap.go @@ -0,0 +1,86 @@ +package aghalg + +import ( + "slices" +) + +// SortedMap is a map that keeps elements in order with internal sorting +// function. Must be initialised by the [NewSortedMap]. +type SortedMap[K comparable, V any] struct { + vals map[K]V + cmp func(a, b K) (res int) + keys []K +} + +// NewSortedMap initializes the new instance of sorted map. cmp is a sort +// function to keep elements in order. +// +// TODO(s.chzhen): Use cmp.Compare in Go 1.21. +func NewSortedMap[K comparable, V any](cmp func(a, b K) (res int)) SortedMap[K, V] { + return SortedMap[K, V]{ + vals: map[K]V{}, + cmp: cmp, + } +} + +// Set adds val with key to the sorted map. It panics if the m is nil. +func (m *SortedMap[K, V]) Set(key K, val V) { + m.vals[key] = val + + i, has := slices.BinarySearchFunc(m.keys, key, m.cmp) + if has { + m.keys[i] = key + } else { + m.keys = slices.Insert(m.keys, i, key) + } +} + +// Get returns val by key from the sorted map. +func (m *SortedMap[K, V]) Get(key K) (val V, ok bool) { + if m == nil { + return + } + + val, ok = m.vals[key] + + return val, ok +} + +// Del removes the value by key from the sorted map. +func (m *SortedMap[K, V]) Del(key K) { + if m == nil { + return + } + + if _, has := m.vals[key]; !has { + return + } + + delete(m.vals, key) + i, _ := slices.BinarySearchFunc(m.keys, key, m.cmp) + m.keys = slices.Delete(m.keys, i, i+1) +} + +// Clear removes all elements from the sorted map. +func (m *SortedMap[K, V]) Clear() { + if m == nil { + return + } + + m.keys = nil + clear(m.vals) +} + +// Range calls cb for each element of the map, sorted by m.cmp. If cb returns +// false it stops. +func (m *SortedMap[K, V]) Range(cb func(K, V) (cont bool)) { + if m == nil { + return + } + + for _, k := range m.keys { + if !cb(k, m.vals[k]) { + return + } + } +} diff --git a/internal/aghalg/sortedmap_test.go b/internal/aghalg/sortedmap_test.go new file mode 100644 index 00000000..46128ed0 --- /dev/null +++ b/internal/aghalg/sortedmap_test.go @@ -0,0 +1,95 @@ +package aghalg + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewSortedMap(t *testing.T) { + var m SortedMap[string, int] + + letters := []string{} + for i := 0; i < 10; i++ { + r := string('a' + rune(i)) + letters = append(letters, r) + } + + t.Run("create_and_fill", func(t *testing.T) { + m = NewSortedMap[string, int](strings.Compare) + + nums := []int{} + for i, r := range letters { + m.Set(r, i) + nums = append(nums, i) + } + + gotLetters := []string{} + gotNums := []int{} + m.Range(func(k string, v int) bool { + gotLetters = append(gotLetters, k) + gotNums = append(gotNums, v) + + return true + }) + + assert.Equal(t, letters, gotLetters) + assert.Equal(t, nums, gotNums) + + n, ok := m.Get(letters[0]) + assert.True(t, ok) + assert.Equal(t, nums[0], n) + }) + + t.Run("clear", func(t *testing.T) { + lastLetter := letters[len(letters)-1] + m.Del(lastLetter) + + _, ok := m.Get(lastLetter) + assert.False(t, ok) + + m.Clear() + + gotLetters := []string{} + m.Range(func(k string, _ int) bool { + gotLetters = append(gotLetters, k) + + return true + }) + + assert.Len(t, gotLetters, 0) + }) +} + +func TestNewSortedMap_nil(t *testing.T) { + const ( + key = "key" + val = "val" + ) + + var m SortedMap[string, string] + + assert.Panics(t, func() { + m.Set(key, val) + }) + + assert.NotPanics(t, func() { + _, ok := m.Get(key) + assert.False(t, ok) + }) + + assert.NotPanics(t, func() { + m.Range(func(_, _ string) (cont bool) { + return true + }) + }) + + assert.NotPanics(t, func() { + m.Del(key) + }) + + assert.NotPanics(t, func() { + m.Clear() + }) +} diff --git a/internal/home/client.go b/internal/home/client.go index 742754c7..f1b4b68f 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -30,6 +30,16 @@ func NewUID() (uid UID, err error) { return UID(uuidv7), err } +// MustNewUID is a wrapper around [NewUID] that panics if there is an error. +func MustNewUID() (uid UID) { + uid, err := NewUID() + if err != nil { + panic(fmt.Errorf("unexpected uuidv7 error: %w", err)) + } + + return uid +} + // type check var _ encoding.TextMarshaler = UID{} diff --git a/internal/home/clientindex.go b/internal/home/clientindex.go new file mode 100644 index 00000000..87d7406a --- /dev/null +++ b/internal/home/clientindex.go @@ -0,0 +1,249 @@ +package home + +import ( + "fmt" + "net" + "net/netip" + + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" +) + +// macKey contains MAC as byte array of 6, 8, or 20 bytes. +type macKey any + +// macToKey converts mac into key of type macKey, which is used as the key of +// the [clientIndex.macToUID]. mac must be valid MAC address. +func macToKey(mac net.HardwareAddr) (key macKey) { + switch len(mac) { + case 6: + return [6]byte(mac) + case 8: + return [8]byte(mac) + case 20: + return [20]byte(mac) + default: + panic(fmt.Errorf("invalid mac address %#v", mac)) + } +} + +// clientIndex stores all information about persistent clients. +type clientIndex struct { + // clientIDToUID maps client ID to UID. + clientIDToUID map[string]UID + + // ipToUID maps IP address to UID. + ipToUID map[netip.Addr]UID + + // macToUID maps MAC address to UID. + macToUID map[macKey]UID + + // uidToClient maps UID to the persistent client. + uidToClient map[UID]*persistentClient + + // subnetToUID maps subnet to UID. + subnetToUID aghalg.SortedMap[netip.Prefix, UID] +} + +// NewClientIndex initializes the new instance of client index. +func NewClientIndex() (ci *clientIndex) { + return &clientIndex{ + clientIDToUID: map[string]UID{}, + ipToUID: map[netip.Addr]UID{}, + subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare), + macToUID: map[macKey]UID{}, + uidToClient: map[UID]*persistentClient{}, + } +} + +// add stores information about a persistent client in the index. c must be +// non-nil and contain UID. +func (ci *clientIndex) add(c *persistentClient) { + if (c.UID == UID{}) { + panic("client must contain uid") + } + + for _, id := range c.ClientIDs { + ci.clientIDToUID[id] = c.UID + } + + for _, ip := range c.IPs { + ci.ipToUID[ip] = c.UID + } + + for _, pref := range c.Subnets { + ci.subnetToUID.Set(pref, c.UID) + } + + for _, mac := range c.MACs { + k := macToKey(mac) + ci.macToUID[k] = c.UID + } + + ci.uidToClient[c.UID] = c +} + +// clashes returns an error if the index contains a different persistent client +// with at least a single identifier contained by c. c must be non-nil. +func (ci *clientIndex) clashes(c *persistentClient) (err error) { + for _, id := range c.ClientIDs { + existing, ok := ci.clientIDToUID[id] + if ok && existing != c.UID { + p := ci.uidToClient[existing] + + return fmt.Errorf("another client %q uses the same ID %q", p.Name, id) + } + } + + p, ip := ci.clashesIP(c) + if p != nil { + return fmt.Errorf("another client %q uses the same IP %q", p.Name, ip) + } + + p, s := ci.clashesSubnet(c) + if p != nil { + return fmt.Errorf("another client %q uses the same subnet %q", p.Name, s) + } + + p, mac := ci.clashesMAC(c) + if p != nil { + return fmt.Errorf("another client %q uses the same MAC %q", p.Name, mac) + } + + return nil +} + +// clashesIP returns a previous client with the same IP address as c. c must be +// non-nil. +func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip netip.Addr) { + for _, ip := range c.IPs { + existing, ok := ci.ipToUID[ip] + if ok && existing != c.UID { + return ci.uidToClient[existing], ip + } + } + + return nil, netip.Addr{} +} + +// clashesSubnet returns a previous client with the same subnet as c. c must be +// non-nil. +func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient, s netip.Prefix) { + for _, s = range c.Subnets { + var existing UID + var ok bool + + ci.subnetToUID.Range(func(p netip.Prefix, uid UID) (cont bool) { + if s == p { + existing = uid + ok = true + + return false + } + + return true + }) + + if ok && existing != c.UID { + return ci.uidToClient[existing], s + } + } + + return nil, netip.Prefix{} +} + +// clashesMAC returns a previous client with the same MAC address as c. c must +// be non-nil. +func (ci *clientIndex) clashesMAC(c *persistentClient) (p *persistentClient, mac net.HardwareAddr) { + for _, mac = range c.MACs { + k := macToKey(mac) + existing, ok := ci.macToUID[k] + if ok && existing != c.UID { + return ci.uidToClient[existing], mac + } + } + + return nil, nil +} + +// find finds persistent client by string representation of the client ID, IP +// address, or MAC. +func (ci *clientIndex) find(id string) (c *persistentClient, ok bool) { + uid, found := ci.clientIDToUID[id] + if found { + return ci.uidToClient[uid], true + } + + ip, err := netip.ParseAddr(id) + if err == nil { + // MAC addresses can be successfully parsed as IP addresses. + c, found = ci.findByIP(ip) + if found { + return c, true + } + } + + mac, err := net.ParseMAC(id) + if err == nil { + return ci.findByMAC(mac) + } + + return nil, false +} + +// find finds persistent client by IP address. +func (ci *clientIndex) findByIP(ip netip.Addr) (c *persistentClient, found bool) { + uid, found := ci.ipToUID[ip] + if found { + return ci.uidToClient[uid], true + } + + ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) { + if pref.Contains(ip) { + uid, found = id, true + + return false + } + + return true + }) + + if found { + return ci.uidToClient[uid], true + } + + return nil, false +} + +// find finds persistent client by MAC. +func (ci *clientIndex) findByMAC(mac net.HardwareAddr) (c *persistentClient, found bool) { + k := macToKey(mac) + uid, found := ci.macToUID[k] + if found { + return ci.uidToClient[uid], true + } + + return nil, false +} + +// del removes information about persistent client from the index. c must be +// non-nil. +func (ci *clientIndex) del(c *persistentClient) { + for _, id := range c.ClientIDs { + delete(ci.clientIDToUID, id) + } + + for _, ip := range c.IPs { + delete(ci.ipToUID, ip) + } + + for _, pref := range c.Subnets { + ci.subnetToUID.Del(pref) + } + + for _, mac := range c.MACs { + k := macToKey(mac) + delete(ci.macToUID, k) + } + + delete(ci.uidToClient, c.UID) +} diff --git a/internal/home/clientindex_internal_test.go b/internal/home/clientindex_internal_test.go new file mode 100644 index 00000000..b89703db --- /dev/null +++ b/internal/home/clientindex_internal_test.go @@ -0,0 +1,210 @@ +package home + +import ( + "net" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClientIndex(t *testing.T) { + const ( + cliIPNone = "1.2.3.4" + cliIP1 = "1.1.1.1" + cliIP2 = "2.2.2.2" + + cliIPv6 = "1:2:3::4" + + cliSubnet = "2.2.2.0/24" + cliSubnetIP = "2.2.2.222" + + cliID = "client-id" + cliMAC = "11:11:11:11:11:11" + ) + + clients := []*persistentClient{{ + Name: "client1", + IPs: []netip.Addr{ + netip.MustParseAddr(cliIP1), + netip.MustParseAddr(cliIPv6), + }, + }, { + Name: "client2", + IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, + Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, + }, { + Name: "client_with_mac", + MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, + }, { + Name: "client_with_id", + ClientIDs: []string{cliID}, + }} + + ci := newIDIndex(clients) + + testCases := []struct { + name string + ids []string + want *persistentClient + }{{ + name: "ipv4_ipv6", + ids: []string{cliIP1, cliIPv6}, + want: clients[0], + }, { + name: "ipv4_subnet", + ids: []string{cliIP2, cliSubnetIP}, + want: clients[1], + }, { + name: "mac", + ids: []string{cliMAC}, + want: clients[2], + }, { + name: "client_id", + ids: []string{cliID}, + want: clients[3], + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, id := range tc.ids { + c, ok := ci.find(id) + require.True(t, ok) + + assert.Equal(t, tc.want, c) + } + }) + } + + t.Run("not_found", func(t *testing.T) { + _, ok := ci.find(cliIPNone) + assert.False(t, ok) + }) +} + +func TestClientIndex_Clashes(t *testing.T) { + const ( + cliIP1 = "1.1.1.1" + cliSubnet = "2.2.2.0/24" + cliSubnetIP = "2.2.2.222" + cliID = "client-id" + cliMAC = "11:11:11:11:11:11" + ) + + clients := []*persistentClient{{ + Name: "client_with_ip", + IPs: []netip.Addr{netip.MustParseAddr(cliIP1)}, + }, { + Name: "client_with_subnet", + Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, + }, { + Name: "client_with_mac", + MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, + }, { + Name: "client_with_id", + ClientIDs: []string{cliID}, + }} + + ci := newIDIndex(clients) + + testCases := []struct { + name string + client *persistentClient + }{{ + name: "ipv4", + client: clients[0], + }, { + name: "subnet", + client: clients[1], + }, { + name: "mac", + client: clients[2], + }, { + name: "client_id", + client: clients[3], + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + clone := tc.client.shallowClone() + clone.UID = MustNewUID() + + err := ci.clashes(clone) + require.Error(t, err) + + ci.del(tc.client) + err = ci.clashes(clone) + require.NoError(t, err) + }) + } +} + +// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an +// error. +func mustParseMAC(s string) (mac net.HardwareAddr) { + mac, err := net.ParseMAC(s) + if err != nil { + panic(err) + } + + return mac +} + +func TestMACToKey(t *testing.T) { + testCases := []struct { + name string + in string + want any + }{{ + name: "column6", + in: "00:00:5e:00:53:01", + want: [6]byte(mustParseMAC("00:00:5e:00:53:01")), + }, { + name: "column8", + in: "02:00:5e:10:00:00:00:01", + want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")), + }, { + name: "column20", + in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01", + want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")), + }, { + name: "hyphen6", + in: "00-00-5e-00-53-01", + want: [6]byte(mustParseMAC("00-00-5e-00-53-01")), + }, { + name: "hyphen8", + in: "02-00-5e-10-00-00-00-01", + want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")), + }, { + name: "hyphen20", + in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01", + want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")), + }, { + name: "dot6", + in: "0000.5e00.5301", + want: [6]byte(mustParseMAC("0000.5e00.5301")), + }, { + name: "dot8", + in: "0200.5e10.0000.0001", + want: [8]byte(mustParseMAC("0200.5e10.0000.0001")), + }, { + name: "dot20", + in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001", + want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mac := mustParseMAC(tc.in) + + key := macToKey(mac) + assert.Equal(t, tc.want, key) + }) + } + + assert.Panics(t, func() { + mac := net.HardwareAddr([]byte{1, 2, 3}) + _ = macToKey(mac) + }) +} diff --git a/internal/home/clients.go b/internal/home/clients.go index 5aa2c81e..f429e97e 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -47,8 +47,9 @@ type DHCP interface { type clientsContainer struct { // TODO(a.garipov): Perhaps use a number of separate indices for different // types (string, netip.Addr, and so on). - list map[string]*persistentClient // name -> client - idIndex map[string]*persistentClient // ID -> client + list map[string]*persistentClient // name -> client + + clientIndex *clientIndex // ipToRC maps IP addresses to runtime client information. ipToRC map[netip.Addr]*client.Runtime @@ -103,9 +104,10 @@ func (clients *clientsContainer) Init( } clients.list = map[string]*persistentClient{} - clients.idIndex = map[string]*persistentClient{} clients.ipToRC = map[netip.Addr]*client.Runtime{} + clients.clientIndex = NewClientIndex() + clients.allTags = stringutil.NewSet(clientTags...) // TODO(e.burkov): Use [dhcpsvc] implementation when it's ready. @@ -517,7 +519,7 @@ func (clients *clientsContainer) UpstreamConfigByID( // findLocked searches for a client by its ID. clients.lock is expected to be // locked. func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) { - c, ok = clients.idIndex[id] + c, ok = clients.clientIndex.find(id) if ok { return c, true } @@ -527,14 +529,6 @@ func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok return nil, false } - for _, c = range clients.list { - for _, subnet := range c.Subnets { - if subnet.Contains(ip) { - return c, true - } - } - } - // TODO(e.burkov): Iterate through clients.list only once. return clients.findDHCP(ip) } @@ -638,18 +632,15 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) { } // check ID index - ids := c.ids() - for _, id := range ids { - var c2 *persistentClient - c2, ok = clients.idIndex[id] - if ok { - return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name) - } + err = clients.clientIndex.clashes(c) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return false, err } clients.addLocked(c) - log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list)) + log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.ids(), len(clients.list)) return true, nil } @@ -660,9 +651,7 @@ func (clients *clientsContainer) addLocked(c *persistentClient) { clients.list[c.Name] = c // update ID index - for _, id := range c.ids() { - clients.idIndex[id] = c - } + clients.clientIndex.add(c) } // remove removes a client. ok is false if there is no such client. @@ -692,9 +681,7 @@ func (clients *clientsContainer) removeLocked(c *persistentClient) { delete(clients.list, c.Name) // Update the ID index. - for _, id := range c.ids() { - delete(clients.idIndex, id) - } + clients.clientIndex.del(c) } // update updates a client by its name. @@ -724,11 +711,10 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) { } // Check the ID index. - for _, id := range c.ids() { - existing, ok := clients.idIndex[id] - if ok && existing != prev { - return fmt.Errorf("id %q is used by client with name %q", id, existing.Name) - } + err = clients.clientIndex.clashes(c) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err } clients.removeLocked(prev) diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 07332ecf..6b384f6b 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -68,6 +68,7 @@ func TestClients(t *testing.T) { c := &persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{cli1IP, cliIPv6}, } @@ -78,6 +79,7 @@ func TestClients(t *testing.T) { c = &persistentClient{ Name: "client2", + UID: MustNewUID(), IPs: []netip.Addr{cli2IP}, } @@ -111,6 +113,7 @@ func TestClients(t *testing.T) { t.Run("add_fail_name", func(t *testing.T) { ok, err := clients.add(&persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, }) require.NoError(t, err) @@ -120,6 +123,7 @@ func TestClients(t *testing.T) { t.Run("add_fail_ip", func(t *testing.T) { ok, err := clients.add(&persistentClient{ Name: "client3", + UID: MustNewUID(), }) require.Error(t, err) assert.False(t, ok) @@ -128,6 +132,7 @@ func TestClients(t *testing.T) { t.Run("update_fail_ip", func(t *testing.T) { err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{ Name: "client1", + UID: MustNewUID(), }) assert.Error(t, err) }) @@ -145,6 +150,7 @@ func TestClients(t *testing.T) { err := clients.update(prev, &persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{cliNewIP}, }) require.NoError(t, err) @@ -159,6 +165,7 @@ func TestClients(t *testing.T) { err = clients.update(prev, &persistentClient{ Name: "client1-renamed", + UID: MustNewUID(), IPs: []netip.Addr{cliNewIP}, UseOwnSettings: true, }) @@ -260,6 +267,7 @@ func TestClientsWHOIS(t *testing.T) { ok, err := clients.add(&persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, }) require.NoError(t, err) @@ -282,6 +290,7 @@ func TestClientsAddExisting(t *testing.T) { // Add a client. ok, err := clients.add(&persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")}, MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, @@ -332,6 +341,7 @@ func TestClientsAddExisting(t *testing.T) { // Add a new client with the same IP as for a client with MAC. ok, err := clients.add(&persistentClient{ Name: "client2", + UID: MustNewUID(), IPs: []netip.Addr{ip}, }) require.NoError(t, err) @@ -340,6 +350,7 @@ func TestClientsAddExisting(t *testing.T) { // Add a new client with the IP from the first client's IP range. ok, err = clients.add(&persistentClient{ Name: "client3", + UID: MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, }) require.NoError(t, err) @@ -353,6 +364,7 @@ func TestClientsCustomUpstream(t *testing.T) { // Add client with upstreams. ok, err := clients.add(&persistentClient{ Name: "client1", + UID: MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, Upstreams: []string{ "1.1.1.1", diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index 820b22a6..ff279752 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -12,6 +12,19 @@ import ( var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) +// newIDIndex is a helper function that returns a client index filled with +// persistent clients from the m. It also generates a UID for each client. +func newIDIndex(m []*persistentClient) (ci *clientIndex) { + ci = NewClientIndex() + + for _, c := range m { + c.UID = MustNewUID() + ci.add(c) + } + + return ci +} + func TestApplyAdditionalFiltering(t *testing.T) { var err error @@ -22,29 +35,28 @@ func TestApplyAdditionalFiltering(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.idIndex = map[string]*persistentClient{ - "default": { - UseOwnSettings: false, - safeSearchConf: filtering.SafeSearchConfig{Enabled: false}, - FilteringEnabled: false, - SafeBrowsingEnabled: false, - ParentalEnabled: false, - }, - "custom_filtering": { - UseOwnSettings: true, - safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, - FilteringEnabled: true, - SafeBrowsingEnabled: true, - ParentalEnabled: true, - }, - "partial_custom_filtering": { - UseOwnSettings: true, - safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, - FilteringEnabled: true, - SafeBrowsingEnabled: false, - ParentalEnabled: false, - }, - } + Context.clients.clientIndex = newIDIndex([]*persistentClient{{ + ClientIDs: []string{"default"}, + UseOwnSettings: false, + safeSearchConf: filtering.SafeSearchConfig{Enabled: false}, + FilteringEnabled: false, + SafeBrowsingEnabled: false, + ParentalEnabled: false, + }, { + ClientIDs: []string{"custom_filtering"}, + UseOwnSettings: true, + safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, + FilteringEnabled: true, + SafeBrowsingEnabled: true, + ParentalEnabled: true, + }, { + ClientIDs: []string{"partial_custom_filtering"}, + UseOwnSettings: true, + safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, + FilteringEnabled: true, + SafeBrowsingEnabled: false, + ParentalEnabled: false, + }}) testCases := []struct { name string @@ -108,38 +120,37 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.idIndex = map[string]*persistentClient{ - "default": { - UseOwnBlockedServices: false, + Context.clients.clientIndex = newIDIndex([]*persistentClient{{ + ClientIDs: []string{"default"}, + UseOwnBlockedServices: false, + }, { + ClientIDs: []string{"no_services"}, + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), }, - "no_services": { - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - }, - UseOwnBlockedServices: true, + UseOwnBlockedServices: true, + }, { + ClientIDs: []string{"services"}, + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + IDs: clientBlockedServices, }, - "services": { - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - IDs: clientBlockedServices, - }, - UseOwnBlockedServices: true, + UseOwnBlockedServices: true, + }, { + ClientIDs: []string{"invalid_services"}, + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.EmptyWeekly(), + IDs: invalidBlockedServices, }, - "invalid_services": { - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.EmptyWeekly(), - IDs: invalidBlockedServices, - }, - UseOwnBlockedServices: true, + UseOwnBlockedServices: true, + }, { + ClientIDs: []string{"allow_all"}, + BlockedServices: &filtering.BlockedServices{ + Schedule: schedule.FullWeekly(), + IDs: clientBlockedServices, }, - "allow_all": { - BlockedServices: &filtering.BlockedServices{ - Schedule: schedule.FullWeekly(), - IDs: clientBlockedServices, - }, - UseOwnBlockedServices: true, - }, - } + UseOwnBlockedServices: true, + }}) testCases := []struct { name string