Pull request 2138: AG-27492-client-persistent-storage

Squashed commit of the following:

commit 37e33ec761cfa30164125af2c5bb40789412355e
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Feb 14 15:25:25 2024 +0300

    aghalg: imp code

commit 6b2f09a44298b474ec1bdf3d027fb4941d2f7bea
Merge: b8ea924aa 37736289e
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Feb 14 15:04:59 2024 +0300

    Merge branch 'master' into AG-27492-client-persistent-storage

commit b8ea924aa7ed4c052760a6068f945d83d184e7e3
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Feb 13 19:07:52 2024 +0300

    home: imp tests

commit aa6fec03b1a1ead96bc76919b7ad51ae19626633
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Feb 13 14:54:28 2024 +0300

    home: imp docs

commit 10637fdec47d0b035cf5c7949ddcd9ec564851a3
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Feb 8 20:16:11 2024 +0300

    all: imp code

commit b45c7d868ddb1be73e119b3260e2a866d57baa91
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Feb 7 19:15:11 2024 +0300

    aghalg: add tests

commit 7abe33dbaa7221ddbc8b7d802dbfa7f951d90cf8
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Feb 6 20:50:22 2024 +0300

    all: imp code, tests

commit 4a44e993c9bd393d2cb9853108eae1ad91e64402
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Feb 1 14:59:11 2024 +0300

    all: persistent client index

commit 66b16e216e03e9f3d5e69496a89b18a9d732b564
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Jan 31 15:06:05 2024 +0300

    aghalg: ordered map
This commit is contained in:
Stanislav Chzhen 2024-02-15 14:08:05 +03:00
parent 37736289e2
commit fede297942
8 changed files with 741 additions and 82 deletions

View File

@ -0,0 +1,86 @@
package aghalg
import (
"slices"
)
// SortedMap is a map that keeps elements in order with internal sorting
// function. Must be initialised by the [NewSortedMap].
type SortedMap[K comparable, V any] struct {
vals map[K]V
cmp func(a, b K) (res int)
keys []K
}
// NewSortedMap initializes the new instance of sorted map. cmp is a sort
// function to keep elements in order.
//
// TODO(s.chzhen): Use cmp.Compare in Go 1.21.
func NewSortedMap[K comparable, V any](cmp func(a, b K) (res int)) SortedMap[K, V] {
return SortedMap[K, V]{
vals: map[K]V{},
cmp: cmp,
}
}
// Set adds val with key to the sorted map. It panics if the m is nil.
func (m *SortedMap[K, V]) Set(key K, val V) {
m.vals[key] = val
i, has := slices.BinarySearchFunc(m.keys, key, m.cmp)
if has {
m.keys[i] = key
} else {
m.keys = slices.Insert(m.keys, i, key)
}
}
// Get returns val by key from the sorted map.
func (m *SortedMap[K, V]) Get(key K) (val V, ok bool) {
if m == nil {
return
}
val, ok = m.vals[key]
return val, ok
}
// Del removes the value by key from the sorted map.
func (m *SortedMap[K, V]) Del(key K) {
if m == nil {
return
}
if _, has := m.vals[key]; !has {
return
}
delete(m.vals, key)
i, _ := slices.BinarySearchFunc(m.keys, key, m.cmp)
m.keys = slices.Delete(m.keys, i, i+1)
}
// Clear removes all elements from the sorted map.
func (m *SortedMap[K, V]) Clear() {
if m == nil {
return
}
m.keys = nil
clear(m.vals)
}
// Range calls cb for each element of the map, sorted by m.cmp. If cb returns
// false it stops.
func (m *SortedMap[K, V]) Range(cb func(K, V) (cont bool)) {
if m == nil {
return
}
for _, k := range m.keys {
if !cb(k, m.vals[k]) {
return
}
}
}

View File

@ -0,0 +1,95 @@
package aghalg
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewSortedMap(t *testing.T) {
var m SortedMap[string, int]
letters := []string{}
for i := 0; i < 10; i++ {
r := string('a' + rune(i))
letters = append(letters, r)
}
t.Run("create_and_fill", func(t *testing.T) {
m = NewSortedMap[string, int](strings.Compare)
nums := []int{}
for i, r := range letters {
m.Set(r, i)
nums = append(nums, i)
}
gotLetters := []string{}
gotNums := []int{}
m.Range(func(k string, v int) bool {
gotLetters = append(gotLetters, k)
gotNums = append(gotNums, v)
return true
})
assert.Equal(t, letters, gotLetters)
assert.Equal(t, nums, gotNums)
n, ok := m.Get(letters[0])
assert.True(t, ok)
assert.Equal(t, nums[0], n)
})
t.Run("clear", func(t *testing.T) {
lastLetter := letters[len(letters)-1]
m.Del(lastLetter)
_, ok := m.Get(lastLetter)
assert.False(t, ok)
m.Clear()
gotLetters := []string{}
m.Range(func(k string, _ int) bool {
gotLetters = append(gotLetters, k)
return true
})
assert.Len(t, gotLetters, 0)
})
}
func TestNewSortedMap_nil(t *testing.T) {
const (
key = "key"
val = "val"
)
var m SortedMap[string, string]
assert.Panics(t, func() {
m.Set(key, val)
})
assert.NotPanics(t, func() {
_, ok := m.Get(key)
assert.False(t, ok)
})
assert.NotPanics(t, func() {
m.Range(func(_, _ string) (cont bool) {
return true
})
})
assert.NotPanics(t, func() {
m.Del(key)
})
assert.NotPanics(t, func() {
m.Clear()
})
}

View File

@ -30,6 +30,16 @@ func NewUID() (uid UID, err error) {
return UID(uuidv7), err return UID(uuidv7), err
} }
// MustNewUID is a wrapper around [NewUID] that panics if there is an error.
func MustNewUID() (uid UID) {
uid, err := NewUID()
if err != nil {
panic(fmt.Errorf("unexpected uuidv7 error: %w", err))
}
return uid
}
// type check // type check
var _ encoding.TextMarshaler = UID{} var _ encoding.TextMarshaler = UID{}

View File

@ -0,0 +1,249 @@
package home
import (
"fmt"
"net"
"net/netip"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
)
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
type macKey any
// macToKey converts mac into key of type macKey, which is used as the key of
// the [clientIndex.macToUID]. mac must be valid MAC address.
func macToKey(mac net.HardwareAddr) (key macKey) {
switch len(mac) {
case 6:
return [6]byte(mac)
case 8:
return [8]byte(mac)
case 20:
return [20]byte(mac)
default:
panic(fmt.Errorf("invalid mac address %#v", mac))
}
}
// clientIndex stores all information about persistent clients.
type clientIndex struct {
// clientIDToUID maps client ID to UID.
clientIDToUID map[string]UID
// ipToUID maps IP address to UID.
ipToUID map[netip.Addr]UID
// macToUID maps MAC address to UID.
macToUID map[macKey]UID
// uidToClient maps UID to the persistent client.
uidToClient map[UID]*persistentClient
// subnetToUID maps subnet to UID.
subnetToUID aghalg.SortedMap[netip.Prefix, UID]
}
// NewClientIndex initializes the new instance of client index.
func NewClientIndex() (ci *clientIndex) {
return &clientIndex{
clientIDToUID: map[string]UID{},
ipToUID: map[netip.Addr]UID{},
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
macToUID: map[macKey]UID{},
uidToClient: map[UID]*persistentClient{},
}
}
// add stores information about a persistent client in the index. c must be
// non-nil and contain UID.
func (ci *clientIndex) add(c *persistentClient) {
if (c.UID == UID{}) {
panic("client must contain uid")
}
for _, id := range c.ClientIDs {
ci.clientIDToUID[id] = c.UID
}
for _, ip := range c.IPs {
ci.ipToUID[ip] = c.UID
}
for _, pref := range c.Subnets {
ci.subnetToUID.Set(pref, c.UID)
}
for _, mac := range c.MACs {
k := macToKey(mac)
ci.macToUID[k] = c.UID
}
ci.uidToClient[c.UID] = c
}
// clashes returns an error if the index contains a different persistent client
// with at least a single identifier contained by c. c must be non-nil.
func (ci *clientIndex) clashes(c *persistentClient) (err error) {
for _, id := range c.ClientIDs {
existing, ok := ci.clientIDToUID[id]
if ok && existing != c.UID {
p := ci.uidToClient[existing]
return fmt.Errorf("another client %q uses the same ID %q", p.Name, id)
}
}
p, ip := ci.clashesIP(c)
if p != nil {
return fmt.Errorf("another client %q uses the same IP %q", p.Name, ip)
}
p, s := ci.clashesSubnet(c)
if p != nil {
return fmt.Errorf("another client %q uses the same subnet %q", p.Name, s)
}
p, mac := ci.clashesMAC(c)
if p != nil {
return fmt.Errorf("another client %q uses the same MAC %q", p.Name, mac)
}
return nil
}
// clashesIP returns a previous client with the same IP address as c. c must be
// non-nil.
func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip netip.Addr) {
for _, ip := range c.IPs {
existing, ok := ci.ipToUID[ip]
if ok && existing != c.UID {
return ci.uidToClient[existing], ip
}
}
return nil, netip.Addr{}
}
// clashesSubnet returns a previous client with the same subnet as c. c must be
// non-nil.
func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient, s netip.Prefix) {
for _, s = range c.Subnets {
var existing UID
var ok bool
ci.subnetToUID.Range(func(p netip.Prefix, uid UID) (cont bool) {
if s == p {
existing = uid
ok = true
return false
}
return true
})
if ok && existing != c.UID {
return ci.uidToClient[existing], s
}
}
return nil, netip.Prefix{}
}
// clashesMAC returns a previous client with the same MAC address as c. c must
// be non-nil.
func (ci *clientIndex) clashesMAC(c *persistentClient) (p *persistentClient, mac net.HardwareAddr) {
for _, mac = range c.MACs {
k := macToKey(mac)
existing, ok := ci.macToUID[k]
if ok && existing != c.UID {
return ci.uidToClient[existing], mac
}
}
return nil, nil
}
// find finds persistent client by string representation of the client ID, IP
// address, or MAC.
func (ci *clientIndex) find(id string) (c *persistentClient, ok bool) {
uid, found := ci.clientIDToUID[id]
if found {
return ci.uidToClient[uid], true
}
ip, err := netip.ParseAddr(id)
if err == nil {
// MAC addresses can be successfully parsed as IP addresses.
c, found = ci.findByIP(ip)
if found {
return c, true
}
}
mac, err := net.ParseMAC(id)
if err == nil {
return ci.findByMAC(mac)
}
return nil, false
}
// find finds persistent client by IP address.
func (ci *clientIndex) findByIP(ip netip.Addr) (c *persistentClient, found bool) {
uid, found := ci.ipToUID[ip]
if found {
return ci.uidToClient[uid], true
}
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
if pref.Contains(ip) {
uid, found = id, true
return false
}
return true
})
if found {
return ci.uidToClient[uid], true
}
return nil, false
}
// find finds persistent client by MAC.
func (ci *clientIndex) findByMAC(mac net.HardwareAddr) (c *persistentClient, found bool) {
k := macToKey(mac)
uid, found := ci.macToUID[k]
if found {
return ci.uidToClient[uid], true
}
return nil, false
}
// del removes information about persistent client from the index. c must be
// non-nil.
func (ci *clientIndex) del(c *persistentClient) {
for _, id := range c.ClientIDs {
delete(ci.clientIDToUID, id)
}
for _, ip := range c.IPs {
delete(ci.ipToUID, ip)
}
for _, pref := range c.Subnets {
ci.subnetToUID.Del(pref)
}
for _, mac := range c.MACs {
k := macToKey(mac)
delete(ci.macToUID, k)
}
delete(ci.uidToClient, c.UID)
}

View File

@ -0,0 +1,210 @@
package home
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClientIndex(t *testing.T) {
const (
cliIPNone = "1.2.3.4"
cliIP1 = "1.1.1.1"
cliIP2 = "2.2.2.2"
cliIPv6 = "1:2:3::4"
cliSubnet = "2.2.2.0/24"
cliSubnetIP = "2.2.2.222"
cliID = "client-id"
cliMAC = "11:11:11:11:11:11"
)
clients := []*persistentClient{{
Name: "client1",
IPs: []netip.Addr{
netip.MustParseAddr(cliIP1),
netip.MustParseAddr(cliIPv6),
},
}, {
Name: "client2",
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, {
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, {
Name: "client_with_id",
ClientIDs: []string{cliID},
}}
ci := newIDIndex(clients)
testCases := []struct {
name string
ids []string
want *persistentClient
}{{
name: "ipv4_ipv6",
ids: []string{cliIP1, cliIPv6},
want: clients[0],
}, {
name: "ipv4_subnet",
ids: []string{cliIP2, cliSubnetIP},
want: clients[1],
}, {
name: "mac",
ids: []string{cliMAC},
want: clients[2],
}, {
name: "client_id",
ids: []string{cliID},
want: clients[3],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
for _, id := range tc.ids {
c, ok := ci.find(id)
require.True(t, ok)
assert.Equal(t, tc.want, c)
}
})
}
t.Run("not_found", func(t *testing.T) {
_, ok := ci.find(cliIPNone)
assert.False(t, ok)
})
}
func TestClientIndex_Clashes(t *testing.T) {
const (
cliIP1 = "1.1.1.1"
cliSubnet = "2.2.2.0/24"
cliSubnetIP = "2.2.2.222"
cliID = "client-id"
cliMAC = "11:11:11:11:11:11"
)
clients := []*persistentClient{{
Name: "client_with_ip",
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
}, {
Name: "client_with_subnet",
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
}, {
Name: "client_with_mac",
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
}, {
Name: "client_with_id",
ClientIDs: []string{cliID},
}}
ci := newIDIndex(clients)
testCases := []struct {
name string
client *persistentClient
}{{
name: "ipv4",
client: clients[0],
}, {
name: "subnet",
client: clients[1],
}, {
name: "mac",
client: clients[2],
}, {
name: "client_id",
client: clients[3],
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
clone := tc.client.shallowClone()
clone.UID = MustNewUID()
err := ci.clashes(clone)
require.Error(t, err)
ci.del(tc.client)
err = ci.clashes(clone)
require.NoError(t, err)
})
}
}
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
// error.
func mustParseMAC(s string) (mac net.HardwareAddr) {
mac, err := net.ParseMAC(s)
if err != nil {
panic(err)
}
return mac
}
func TestMACToKey(t *testing.T) {
testCases := []struct {
name string
in string
want any
}{{
name: "column6",
in: "00:00:5e:00:53:01",
want: [6]byte(mustParseMAC("00:00:5e:00:53:01")),
}, {
name: "column8",
in: "02:00:5e:10:00:00:00:01",
want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")),
}, {
name: "column20",
in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")),
}, {
name: "hyphen6",
in: "00-00-5e-00-53-01",
want: [6]byte(mustParseMAC("00-00-5e-00-53-01")),
}, {
name: "hyphen8",
in: "02-00-5e-10-00-00-00-01",
want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")),
}, {
name: "hyphen20",
in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")),
}, {
name: "dot6",
in: "0000.5e00.5301",
want: [6]byte(mustParseMAC("0000.5e00.5301")),
}, {
name: "dot8",
in: "0200.5e10.0000.0001",
want: [8]byte(mustParseMAC("0200.5e10.0000.0001")),
}, {
name: "dot20",
in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mac := mustParseMAC(tc.in)
key := macToKey(mac)
assert.Equal(t, tc.want, key)
})
}
assert.Panics(t, func() {
mac := net.HardwareAddr([]byte{1, 2, 3})
_ = macToKey(mac)
})
}

View File

@ -47,8 +47,9 @@ type DHCP interface {
type clientsContainer struct { type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for different // TODO(a.garipov): Perhaps use a number of separate indices for different
// types (string, netip.Addr, and so on). // types (string, netip.Addr, and so on).
list map[string]*persistentClient // name -> client list map[string]*persistentClient // name -> client
idIndex map[string]*persistentClient // ID -> client
clientIndex *clientIndex
// ipToRC maps IP addresses to runtime client information. // ipToRC maps IP addresses to runtime client information.
ipToRC map[netip.Addr]*client.Runtime ipToRC map[netip.Addr]*client.Runtime
@ -103,9 +104,10 @@ func (clients *clientsContainer) Init(
} }
clients.list = map[string]*persistentClient{} clients.list = map[string]*persistentClient{}
clients.idIndex = map[string]*persistentClient{}
clients.ipToRC = map[netip.Addr]*client.Runtime{} clients.ipToRC = map[netip.Addr]*client.Runtime{}
clients.clientIndex = NewClientIndex()
clients.allTags = stringutil.NewSet(clientTags...) clients.allTags = stringutil.NewSet(clientTags...)
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready. // TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
@ -517,7 +519,7 @@ func (clients *clientsContainer) UpstreamConfigByID(
// findLocked searches for a client by its ID. clients.lock is expected to be // findLocked searches for a client by its ID. clients.lock is expected to be
// locked. // locked.
func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) { func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) {
c, ok = clients.idIndex[id] c, ok = clients.clientIndex.find(id)
if ok { if ok {
return c, true return c, true
} }
@ -527,14 +529,6 @@ func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok
return nil, false return nil, false
} }
for _, c = range clients.list {
for _, subnet := range c.Subnets {
if subnet.Contains(ip) {
return c, true
}
}
}
// TODO(e.burkov): Iterate through clients.list only once. // TODO(e.burkov): Iterate through clients.list only once.
return clients.findDHCP(ip) return clients.findDHCP(ip)
} }
@ -638,18 +632,15 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
} }
// check ID index // check ID index
ids := c.ids() err = clients.clientIndex.clashes(c)
for _, id := range ids { if err != nil {
var c2 *persistentClient // Don't wrap the error since it's informative enough as is.
c2, ok = clients.idIndex[id] return false, err
if ok {
return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
}
} }
clients.addLocked(c) clients.addLocked(c)
log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list)) log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.ids(), len(clients.list))
return true, nil return true, nil
} }
@ -660,9 +651,7 @@ func (clients *clientsContainer) addLocked(c *persistentClient) {
clients.list[c.Name] = c clients.list[c.Name] = c
// update ID index // update ID index
for _, id := range c.ids() { clients.clientIndex.add(c)
clients.idIndex[id] = c
}
} }
// remove removes a client. ok is false if there is no such client. // remove removes a client. ok is false if there is no such client.
@ -692,9 +681,7 @@ func (clients *clientsContainer) removeLocked(c *persistentClient) {
delete(clients.list, c.Name) delete(clients.list, c.Name)
// Update the ID index. // Update the ID index.
for _, id := range c.ids() { clients.clientIndex.del(c)
delete(clients.idIndex, id)
}
} }
// update updates a client by its name. // update updates a client by its name.
@ -724,11 +711,10 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
} }
// Check the ID index. // Check the ID index.
for _, id := range c.ids() { err = clients.clientIndex.clashes(c)
existing, ok := clients.idIndex[id] if err != nil {
if ok && existing != prev { // Don't wrap the error since it's informative enough as is.
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name) return err
}
} }
clients.removeLocked(prev) clients.removeLocked(prev)

View File

@ -68,6 +68,7 @@ func TestClients(t *testing.T) {
c := &persistentClient{ c := &persistentClient{
Name: "client1", Name: "client1",
UID: MustNewUID(),
IPs: []netip.Addr{cli1IP, cliIPv6}, IPs: []netip.Addr{cli1IP, cliIPv6},
} }
@ -78,6 +79,7 @@ func TestClients(t *testing.T) {
c = &persistentClient{ c = &persistentClient{
Name: "client2", Name: "client2",
UID: MustNewUID(),
IPs: []netip.Addr{cli2IP}, IPs: []netip.Addr{cli2IP},
} }
@ -111,6 +113,7 @@ func TestClients(t *testing.T) {
t.Run("add_fail_name", func(t *testing.T) { t.Run("add_fail_name", func(t *testing.T) {
ok, err := clients.add(&persistentClient{ ok, err := clients.add(&persistentClient{
Name: "client1", Name: "client1",
UID: MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -120,6 +123,7 @@ func TestClients(t *testing.T) {
t.Run("add_fail_ip", func(t *testing.T) { t.Run("add_fail_ip", func(t *testing.T) {
ok, err := clients.add(&persistentClient{ ok, err := clients.add(&persistentClient{
Name: "client3", Name: "client3",
UID: MustNewUID(),
}) })
require.Error(t, err) require.Error(t, err)
assert.False(t, ok) assert.False(t, ok)
@ -128,6 +132,7 @@ func TestClients(t *testing.T) {
t.Run("update_fail_ip", func(t *testing.T) { t.Run("update_fail_ip", func(t *testing.T) {
err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{ err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{
Name: "client1", Name: "client1",
UID: MustNewUID(),
}) })
assert.Error(t, err) assert.Error(t, err)
}) })
@ -145,6 +150,7 @@ func TestClients(t *testing.T) {
err := clients.update(prev, &persistentClient{ err := clients.update(prev, &persistentClient{
Name: "client1", Name: "client1",
UID: MustNewUID(),
IPs: []netip.Addr{cliNewIP}, IPs: []netip.Addr{cliNewIP},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -159,6 +165,7 @@ func TestClients(t *testing.T) {
err = clients.update(prev, &persistentClient{ err = clients.update(prev, &persistentClient{
Name: "client1-renamed", Name: "client1-renamed",
UID: MustNewUID(),
IPs: []netip.Addr{cliNewIP}, IPs: []netip.Addr{cliNewIP},
UseOwnSettings: true, UseOwnSettings: true,
}) })
@ -260,6 +267,7 @@ func TestClientsWHOIS(t *testing.T) {
ok, err := clients.add(&persistentClient{ ok, err := clients.add(&persistentClient{
Name: "client1", Name: "client1",
UID: MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -282,6 +290,7 @@ func TestClientsAddExisting(t *testing.T) {
// Add a client. // Add a client.
ok, err := clients.add(&persistentClient{ ok, err := clients.add(&persistentClient{
Name: "client1", Name: "client1",
UID: MustNewUID(),
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")}, Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
@ -332,6 +341,7 @@ func TestClientsAddExisting(t *testing.T) {
// Add a new client with the same IP as for a client with MAC. // Add a new client with the same IP as for a client with MAC.
ok, err := clients.add(&persistentClient{ ok, err := clients.add(&persistentClient{
Name: "client2", Name: "client2",
UID: MustNewUID(),
IPs: []netip.Addr{ip}, IPs: []netip.Addr{ip},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -340,6 +350,7 @@ func TestClientsAddExisting(t *testing.T) {
// Add a new client with the IP from the first client's IP range. // Add a new client with the IP from the first client's IP range.
ok, err = clients.add(&persistentClient{ ok, err = clients.add(&persistentClient{
Name: "client3", Name: "client3",
UID: MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
}) })
require.NoError(t, err) require.NoError(t, err)
@ -353,6 +364,7 @@ func TestClientsCustomUpstream(t *testing.T) {
// Add client with upstreams. // Add client with upstreams.
ok, err := clients.add(&persistentClient{ ok, err := clients.add(&persistentClient{
Name: "client1", Name: "client1",
UID: MustNewUID(),
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
Upstreams: []string{ Upstreams: []string{
"1.1.1.1", "1.1.1.1",

View File

@ -12,6 +12,19 @@ import (
var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
// newIDIndex is a helper function that returns a client index filled with
// persistent clients from the m. It also generates a UID for each client.
func newIDIndex(m []*persistentClient) (ci *clientIndex) {
ci = NewClientIndex()
for _, c := range m {
c.UID = MustNewUID()
ci.add(c)
}
return ci
}
func TestApplyAdditionalFiltering(t *testing.T) { func TestApplyAdditionalFiltering(t *testing.T) {
var err error var err error
@ -22,29 +35,28 @@ func TestApplyAdditionalFiltering(t *testing.T) {
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
Context.clients.idIndex = map[string]*persistentClient{ Context.clients.clientIndex = newIDIndex([]*persistentClient{{
"default": { ClientIDs: []string{"default"},
UseOwnSettings: false, UseOwnSettings: false,
safeSearchConf: filtering.SafeSearchConfig{Enabled: false}, safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
FilteringEnabled: false, FilteringEnabled: false,
SafeBrowsingEnabled: false, SafeBrowsingEnabled: false,
ParentalEnabled: false, ParentalEnabled: false,
}, }, {
"custom_filtering": { ClientIDs: []string{"custom_filtering"},
UseOwnSettings: true, UseOwnSettings: true,
safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true, FilteringEnabled: true,
SafeBrowsingEnabled: true, SafeBrowsingEnabled: true,
ParentalEnabled: true, ParentalEnabled: true,
}, }, {
"partial_custom_filtering": { ClientIDs: []string{"partial_custom_filtering"},
UseOwnSettings: true, UseOwnSettings: true,
safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true, FilteringEnabled: true,
SafeBrowsingEnabled: false, SafeBrowsingEnabled: false,
ParentalEnabled: false, ParentalEnabled: false,
}, }})
}
testCases := []struct { testCases := []struct {
name string name string
@ -108,38 +120,37 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
Context.clients.idIndex = map[string]*persistentClient{ Context.clients.clientIndex = newIDIndex([]*persistentClient{{
"default": { ClientIDs: []string{"default"},
UseOwnBlockedServices: false, UseOwnBlockedServices: false,
}, {
ClientIDs: []string{"no_services"},
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
}, },
"no_services": { UseOwnBlockedServices: true,
BlockedServices: &filtering.BlockedServices{ }, {
Schedule: schedule.EmptyWeekly(), ClientIDs: []string{"services"},
}, BlockedServices: &filtering.BlockedServices{
UseOwnBlockedServices: true, Schedule: schedule.EmptyWeekly(),
IDs: clientBlockedServices,
}, },
"services": { UseOwnBlockedServices: true,
BlockedServices: &filtering.BlockedServices{ }, {
Schedule: schedule.EmptyWeekly(), ClientIDs: []string{"invalid_services"},
IDs: clientBlockedServices, BlockedServices: &filtering.BlockedServices{
}, Schedule: schedule.EmptyWeekly(),
UseOwnBlockedServices: true, IDs: invalidBlockedServices,
}, },
"invalid_services": { UseOwnBlockedServices: true,
BlockedServices: &filtering.BlockedServices{ }, {
Schedule: schedule.EmptyWeekly(), ClientIDs: []string{"allow_all"},
IDs: invalidBlockedServices, BlockedServices: &filtering.BlockedServices{
}, Schedule: schedule.FullWeekly(),
UseOwnBlockedServices: true, IDs: clientBlockedServices,
}, },
"allow_all": { UseOwnBlockedServices: true,
BlockedServices: &filtering.BlockedServices{ }})
Schedule: schedule.FullWeekly(),
IDs: clientBlockedServices,
},
UseOwnBlockedServices: true,
},
}
testCases := []struct { testCases := []struct {
name string name string