home: imp docs

This commit is contained in:
Stanislav Chzhen 2024-01-15 15:50:28 +03:00
parent 4d4b359991
commit 512edaf2dc
3 changed files with 27 additions and 32 deletions

View File

@ -1,7 +1,6 @@
package home
import (
"bytes"
"encoding"
"fmt"
"net"
@ -117,10 +116,7 @@ func (c *persistentClient) setIDs(ids []string) (err error) {
// TODO(s.chzhen): Use netip.PrefixCompare in Go 1.23.
slices.SortFunc(c.Subnets, subnetCompare)
slices.SortFunc(c.MACs, func(a, b net.HardwareAddr) int {
return bytes.Compare(a, b)
})
slices.SortFunc(c.MACs, slices.Compare[net.HardwareAddr])
slices.Sort(c.ClientIDs)
return nil
@ -185,7 +181,7 @@ func (c *persistentClient) setID(id string) (err error) {
return nil
}
// ids returns a list of client ids.
// ids returns a list of client ids containing at least one element.
func (c *persistentClient) ids() (ids []string) {
ids = make([]string, 0, c.idsLen())

View File

@ -30,82 +30,82 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
name string
ids []string
prevIDs []string
want bool
want assert.BoolAssertionFunc
}{{
name: "single_ip",
ids: []string{ip1},
prevIDs: []string{ip1},
want: true,
want: assert.True,
}, {
name: "single_ip_not_equal",
ids: []string{ip1},
prevIDs: []string{ip2},
want: false,
want: assert.False,
}, {
name: "ips_not_equal",
ids: []string{ip1, ip2},
prevIDs: []string{ip1, ip},
want: false,
want: assert.False,
}, {
name: "ips_mixed_equal",
ids: []string{ip1, ip2},
prevIDs: []string{ip2, ip1},
want: true,
want: assert.True,
}, {
name: "single_subnet",
ids: []string{cidr1},
prevIDs: []string{cidr1},
want: true,
want: assert.True,
}, {
name: "subnets_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2},
prevIDs: []string{ip1, ip2, cidr1, cidr},
want: false,
want: assert.False,
}, {
name: "subnets_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2},
prevIDs: []string{cidr2, cidr1, ip2, ip1},
want: true,
want: assert.True,
}, {
name: "single_mac",
ids: []string{mac1},
prevIDs: []string{mac1},
want: true,
want: assert.True,
}, {
name: "single_mac_not_equal",
ids: []string{mac1},
prevIDs: []string{mac2},
want: false,
want: assert.False,
}, {
name: "macs_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2},
prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac},
want: false,
want: assert.False,
}, {
name: "macs_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2},
prevIDs: []string{mac2, mac1, cidr2, cidr1, ip2, ip1},
want: true,
want: assert.True,
}, {
name: "single_client_id",
ids: []string{cli1},
prevIDs: []string{cli1},
want: true,
want: assert.True,
}, {
name: "single_client_id_not_equal",
ids: []string{cli1},
prevIDs: []string{cli2},
want: false,
want: assert.False,
}, {
name: "client_ids_not_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2},
prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli},
want: false,
want: assert.False,
}, {
name: "client_ids_mixed_equal",
ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2},
prevIDs: []string{cli2, cli1, mac2, mac1, cidr2, cidr1, ip2, ip1},
want: true,
want: assert.True,
}}
for _, tc := range testCases {
@ -118,8 +118,7 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
err = prev.setIDs(tc.prevIDs)
require.NoError(t, err)
equal := c.equalIDs(prev)
assert.Equal(t, tc.want, equal)
tc.want(t, c.equalIDs(prev))
})
}
}

View File

@ -1,7 +1,6 @@
package home
import (
"bytes"
"fmt"
"net"
"net/netip"
@ -550,10 +549,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *persistentClient, o
}
for _, c = range clients.list {
for _, mac := range c.MACs {
if bytes.Equal(mac, foundMAC) {
return c, true
}
_, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr])
if found {
return c, true
}
}
@ -593,7 +591,7 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Ru
return rc, ok
}
// check validates the client.
// check validates the client. It also sorts the client tags.
func (clients *clientsContainer) check(c *persistentClient) (err error) {
switch {
case c == nil:
@ -612,6 +610,7 @@ func (clients *clientsContainer) check(c *persistentClient) (err error) {
}
}
// TODO(s.chzhen): Move to the constructor.
slices.Sort(c.Tags)
err = dnsforward.ValidateUpstreams(c.Upstreams)
@ -640,7 +639,8 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
}
// check ID index
for _, id := range c.ids() {
ids := c.ids()
for _, id := range ids {
var c2 *persistentClient
c2, ok = clients.idIndex[id]
if ok {
@ -650,7 +650,7 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
clients.addLocked(c)
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.ids(), len(clients.list))
log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list))
return true, nil
}