diff --git a/internal/home/client.go b/internal/home/client.go index a3054656..64e9b677 100644 --- a/internal/home/client.go +++ b/internal/home/client.go @@ -3,13 +3,20 @@ package home import ( "encoding" "fmt" + "net" + "net/netip" + "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" "github.com/google/uuid" + "golang.org/x/exp/slices" ) // UID is the type for the unique IDs of persistent clients. @@ -56,10 +63,15 @@ type persistentClient struct { Name string - IDs []string Tags []string Upstreams []string + IPs []netip.Addr + // TODO(s.chzhen): Use netutil.Prefix. + Subnets []netip.Prefix + MACs []net.HardwareAddr + ClientIDs []string + // UID is the unique identifier of the persistent client. UID UID @@ -75,17 +87,149 @@ type persistentClient struct { IgnoreStatistics bool } -// ShallowClone returns a deep copy of the client, except upstreamConfig, +// setTags sets the tags if they are known, otherwise logs an unknown tag. +func (c *persistentClient) setTags(tags []string, known *stringutil.Set) { + for _, t := range tags { + if !known.Has(t) { + log.Info("skipping unknown tag %q", t) + + continue + } + + c.Tags = append(c.Tags, t) + } + + slices.Sort(c.Tags) +} + +// setIDs parses a list of strings into typed fields and returns an error if +// there is one. +func (c *persistentClient) setIDs(ids []string) (err error) { + for _, id := range ids { + err = c.setID(id) + if err != nil { + return err + } + } + + slices.SortFunc(c.IPs, netip.Addr.Compare) + + // TODO(s.chzhen): Use netip.PrefixCompare in Go 1.23. + slices.SortFunc(c.Subnets, subnetCompare) + slices.SortFunc(c.MACs, slices.Compare[net.HardwareAddr]) + slices.Sort(c.ClientIDs) + + return nil +} + +// subnetCompare is a comparison function for the two subnets. It returns -1 if +// x sorts before y, 1 if x sorts after y, and 0 if their relative sorting +// position is the same. +func subnetCompare(x, y netip.Prefix) (cmp int) { + if x == y { + return 0 + } + + xAddr, xBits := x.Addr(), x.Bits() + yAddr, yBits := y.Addr(), y.Bits() + if xBits == yBits { + return xAddr.Compare(yAddr) + } + + if xBits > yBits { + return -1 + } else { + return 1 + } +} + +// setID parses id into typed field if there is no error. +func (c *persistentClient) setID(id string) (err error) { + if id == "" { + return errors.Error("clientid is empty") + } + + var ip netip.Addr + if ip, err = netip.ParseAddr(id); err == nil { + c.IPs = append(c.IPs, ip) + + return nil + } + + var subnet netip.Prefix + if subnet, err = netip.ParsePrefix(id); err == nil { + c.Subnets = append(c.Subnets, subnet) + + return nil + } + + var mac net.HardwareAddr + if mac, err = net.ParseMAC(id); err == nil { + c.MACs = append(c.MACs, mac) + + return nil + } + + err = dnsforward.ValidateClientID(id) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } + + c.ClientIDs = append(c.ClientIDs, strings.ToLower(id)) + + return nil +} + +// ids returns a list of client ids containing at least one element. +func (c *persistentClient) ids() (ids []string) { + ids = make([]string, 0, c.idsLen()) + + for _, ip := range c.IPs { + ids = append(ids, ip.String()) + } + + for _, subnet := range c.Subnets { + ids = append(ids, subnet.String()) + } + + for _, mac := range c.MACs { + ids = append(ids, mac.String()) + } + + return append(ids, c.ClientIDs...) +} + +// idsLen returns a length of client ids. +func (c *persistentClient) idsLen() (n int) { + return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs) +} + +// equalIDs returns true if the ids of the current and previous clients are the +// same. +func (c *persistentClient) equalIDs(prev *persistentClient) (equal bool) { + return slices.Equal(c.IPs, prev.IPs) && + slices.Equal(c.Subnets, prev.Subnets) && + slices.EqualFunc(c.MACs, prev.MACs, slices.Equal[net.HardwareAddr]) && + slices.Equal(c.ClientIDs, prev.ClientIDs) +} + +// shallowClone returns a deep copy of the client, except upstreamConfig, // safeSearchConf, SafeSearch fields, because it's difficult to copy them. -func (c *persistentClient) ShallowClone() (sh *persistentClient) { - clone := *c +func (c *persistentClient) shallowClone() (clone *persistentClient) { + clone = &persistentClient{} + *clone = *c clone.BlockedServices = c.BlockedServices.Clone() - clone.IDs = stringutil.CloneSlice(c.IDs) - clone.Tags = stringutil.CloneSlice(c.Tags) - clone.Upstreams = stringutil.CloneSlice(c.Upstreams) + clone.Tags = slices.Clone(c.Tags) + clone.Upstreams = slices.Clone(c.Upstreams) - return &clone + clone.IPs = slices.Clone(c.IPs) + clone.Subnets = slices.Clone(c.Subnets) + clone.MACs = slices.Clone(c.MACs) + clone.ClientIDs = slices.Clone(c.ClientIDs) + + return clone } // closeUpstreams closes the client-specific upstream config of c if any. diff --git a/internal/home/client_internal_test.go b/internal/home/client_internal_test.go new file mode 100644 index 00000000..c360cb19 --- /dev/null +++ b/internal/home/client_internal_test.go @@ -0,0 +1,124 @@ +package home + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPersistentClient_EqualIDs(t *testing.T) { + const ( + ip = "0.0.0.0" + ip1 = "1.1.1.1" + ip2 = "2.2.2.2" + + cidr = "0.0.0.0/0" + cidr1 = "1.1.1.1/11" + cidr2 = "2.2.2.2/22" + + mac = "00-00-00-00-00-00" + mac1 = "11-11-11-11-11-11" + mac2 = "22-22-22-22-22-22" + + cli = "client0" + cli1 = "client1" + cli2 = "client2" + ) + + testCases := []struct { + name string + ids []string + prevIDs []string + want assert.BoolAssertionFunc + }{{ + name: "single_ip", + ids: []string{ip1}, + prevIDs: []string{ip1}, + want: assert.True, + }, { + name: "single_ip_not_equal", + ids: []string{ip1}, + prevIDs: []string{ip2}, + want: assert.False, + }, { + name: "ips_not_equal", + ids: []string{ip1, ip2}, + prevIDs: []string{ip1, ip}, + want: assert.False, + }, { + name: "ips_mixed_equal", + ids: []string{ip1, ip2}, + prevIDs: []string{ip2, ip1}, + want: assert.True, + }, { + name: "single_subnet", + ids: []string{cidr1}, + prevIDs: []string{cidr1}, + want: assert.True, + }, { + name: "subnets_not_equal", + ids: []string{ip1, ip2, cidr1, cidr2}, + prevIDs: []string{ip1, ip2, cidr1, cidr}, + want: assert.False, + }, { + name: "subnets_mixed_equal", + ids: []string{ip1, ip2, cidr1, cidr2}, + prevIDs: []string{cidr2, cidr1, ip2, ip1}, + want: assert.True, + }, { + name: "single_mac", + ids: []string{mac1}, + prevIDs: []string{mac1}, + want: assert.True, + }, { + name: "single_mac_not_equal", + ids: []string{mac1}, + prevIDs: []string{mac2}, + want: assert.False, + }, { + name: "macs_not_equal", + ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2}, + prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac}, + want: assert.False, + }, { + name: "macs_mixed_equal", + ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2}, + prevIDs: []string{mac2, mac1, cidr2, cidr1, ip2, ip1}, + want: assert.True, + }, { + name: "single_client_id", + ids: []string{cli1}, + prevIDs: []string{cli1}, + want: assert.True, + }, { + name: "single_client_id_not_equal", + ids: []string{cli1}, + prevIDs: []string{cli2}, + want: assert.False, + }, { + name: "client_ids_not_equal", + ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2}, + prevIDs: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli}, + want: assert.False, + }, { + name: "client_ids_mixed_equal", + ids: []string{ip1, ip2, cidr1, cidr2, mac1, mac2, cli1, cli2}, + prevIDs: []string{cli2, cli1, mac2, mac1, cidr2, cidr1, ip2, ip1}, + want: assert.True, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c := &persistentClient{} + err := c.setIDs(tc.ids) + require.NoError(t, err) + + prev := &persistentClient{} + err = prev.setIDs(tc.prevIDs) + require.NoError(t, err) + + tc.want(t, c.equalIDs(prev)) + }) + } +} diff --git a/internal/home/clients.go b/internal/home/clients.go index e7a8fb1c..07256fea 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -1,7 +1,6 @@ package home import ( - "bytes" "fmt" "net" "net/netip" @@ -218,7 +217,6 @@ func (o *clientObject) toPersistent( cli = &persistentClient{ Name: o.Name, - IDs: o.IDs, Upstreams: o.Upstreams, UID: o.UID, @@ -235,6 +233,11 @@ func (o *clientObject) toPersistent( UpstreamsCacheSize: o.UpstreamsCacheSize, } + err = cli.setIDs(o.IDs) + if err != nil { + return nil, fmt.Errorf("parsing ids: %w", err) + } + if (cli.UID == UID{}) { cli.UID, err = NewUID() if err != nil { @@ -262,15 +265,7 @@ func (o *clientObject) toPersistent( cli.BlockedServices = o.BlockedServices.Clone() - for _, t := range o.Tags { - if allTags.Has(t) { - cli.Tags = append(cli.Tags, t) - } else { - log.Info("skipping unknown tag %q", t) - } - } - - slices.Sort(cli.Tags) + cli.setTags(o.Tags, allTags) return cli, nil } @@ -310,7 +305,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { BlockedServices: cli.BlockedServices.Clone(), - IDs: stringutil.CloneSlice(cli.IDs), + IDs: cli.ids(), Tags: stringutil.CloneSlice(cli.Tags), Upstreams: stringutil.CloneSlice(cli.Upstreams), @@ -449,7 +444,7 @@ func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool) return nil, false } - return c.ShallowClone(), true + return c.shallowClone(), true } // shouldCountClient is a wrapper around [clientsContainer.find] to make it a @@ -534,13 +529,7 @@ func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok } for _, c = range clients.list { - for _, id := range c.IDs { - var subnet netip.Prefix - subnet, err = netip.ParsePrefix(id) - if err != nil { - continue - } - + for _, subnet := range c.Subnets { if subnet.Contains(ip) { return c, true } @@ -560,15 +549,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *persistentClient, o } for _, c = range clients.list { - for _, id := range c.IDs { - mac, err := net.ParseMAC(id) - if err != nil { - continue - } - - if bytes.Equal(mac, foundMAC) { - return c, true - } + _, found := slices.BinarySearchFunc(c.MACs, foundMAC, slices.Compare[net.HardwareAddr]) + if found { + return c, true } } @@ -608,35 +591,26 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Ru return rc, ok } -// check validates the client. +// check validates the client. It also sorts the client tags. func (clients *clientsContainer) check(c *persistentClient) (err error) { switch { case c == nil: return errors.Error("client is nil") case c.Name == "": return errors.Error("invalid name") - case len(c.IDs) == 0: + case c.idsLen() == 0: return errors.Error("id required") default: // Go on. } - for i, id := range c.IDs { - var norm string - norm, err = normalizeClientIdentifier(id) - if err != nil { - return fmt.Errorf("client at index %d: %w", i, err) - } - - c.IDs[i] = norm - } - for _, t := range c.Tags { if !clients.allTags.Has(t) { return fmt.Errorf("invalid tag: %q", t) } } + // TODO(s.chzhen): Move to the constructor. slices.Sort(c.Tags) err = dnsforward.ValidateUpstreams(c.Upstreams) @@ -647,35 +621,6 @@ func (clients *clientsContainer) check(c *persistentClient) (err error) { return nil } -// normalizeClientIdentifier returns a normalized version of idStr. If idStr -// cannot be normalized, it returns an error. -func normalizeClientIdentifier(idStr string) (norm string, err error) { - if idStr == "" { - return "", errors.Error("clientid is empty") - } - - var ip netip.Addr - if ip, err = netip.ParseAddr(idStr); err == nil { - return ip.String(), nil - } - - var subnet netip.Prefix - if subnet, err = netip.ParsePrefix(idStr); err == nil { - return subnet.String(), nil - } - - var mac net.HardwareAddr - if mac, err = net.ParseMAC(idStr); err == nil { - return mac.String(), nil - } - - if err = dnsforward.ValidateClientID(idStr); err == nil { - return strings.ToLower(idStr), nil - } - - return "", fmt.Errorf("bad client identifier %q", idStr) -} - // add adds a new client object. ok is false if such client already exists or // if an error occurred. func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) { @@ -694,7 +639,8 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) { } // check ID index - for _, id := range c.IDs { + ids := c.ids() + for _, id := range ids { var c2 *persistentClient c2, ok = clients.idIndex[id] if ok { @@ -704,7 +650,7 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) { clients.addLocked(c) - log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list)) + log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list)) return true, nil } @@ -715,7 +661,7 @@ func (clients *clientsContainer) addLocked(c *persistentClient) { clients.list[c.Name] = c // update ID index - for _, id := range c.IDs { + for _, id := range c.ids() { clients.idIndex[id] = c } } @@ -747,7 +693,7 @@ func (clients *clientsContainer) removeLocked(c *persistentClient) { delete(clients.list, c.Name) // Update the ID index. - for _, id := range c.IDs { + for _, id := range c.ids() { delete(clients.idIndex, id) } } @@ -771,13 +717,18 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) { } } + if c.equalIDs(prev) { + clients.removeLocked(prev) + clients.addLocked(c) + + return nil + } + // Check the ID index. - if !slices.Equal(prev.IDs, c.IDs) { - for _, id := range c.IDs { - existing, ok := clients.idIndex[id] - if ok && existing != prev { - return fmt.Errorf("id %q is used by client with name %q", id, existing.Name) - } + for _, id := range c.ids() { + existing, ok := clients.idIndex[id] + if ok && existing != prev { + return fmt.Errorf("id %q is used by client with name %q", id, existing.Name) } } diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 7acf2ac9..07332ecf 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -62,11 +62,13 @@ func TestClients(t *testing.T) { cli1IP = netip.MustParseAddr(cli1) cli2IP = netip.MustParseAddr(cli2) + + cliIPv6 = netip.MustParseAddr("1:2:3::4") ) c := &persistentClient{ - IDs: []string{cli1, "1:2:3::4", "aa:aa:aa:aa:aa:aa"}, Name: "client1", + IPs: []netip.Addr{cli1IP, cliIPv6}, } ok, err := clients.add(c) @@ -75,8 +77,8 @@ func TestClients(t *testing.T) { assert.True(t, ok) c = &persistentClient{ - IDs: []string{cli2}, Name: "client2", + IPs: []netip.Addr{cli2IP}, } ok, err = clients.add(c) @@ -108,8 +110,8 @@ func TestClients(t *testing.T) { t.Run("add_fail_name", func(t *testing.T) { ok, err := clients.add(&persistentClient{ - IDs: []string{"1.2.3.5"}, Name: "client1", + IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, }) require.NoError(t, err) assert.False(t, ok) @@ -117,7 +119,6 @@ func TestClients(t *testing.T) { t.Run("add_fail_ip", func(t *testing.T) { ok, err := clients.add(&persistentClient{ - IDs: []string{"2.2.2.2"}, Name: "client3", }) require.Error(t, err) @@ -126,7 +127,6 @@ func TestClients(t *testing.T) { t.Run("update_fail_ip", func(t *testing.T) { err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{ - IDs: []string{"2.2.2.2"}, Name: "client1", }) assert.Error(t, err) @@ -144,8 +144,8 @@ func TestClients(t *testing.T) { require.True(t, ok) err := clients.update(prev, &persistentClient{ - IDs: []string{cliNew}, Name: "client1", + IPs: []netip.Addr{cliNewIP}, }) require.NoError(t, err) @@ -158,8 +158,8 @@ func TestClients(t *testing.T) { require.True(t, ok) err = clients.update(prev, &persistentClient{ - IDs: []string{cliNew}, Name: "client1-renamed", + IPs: []netip.Addr{cliNewIP}, UseOwnSettings: true, }) require.NoError(t, err) @@ -175,9 +175,9 @@ func TestClients(t *testing.T) { assert.Nil(t, nilCli) - require.Len(t, c.IDs, 1) + require.Len(t, c.ids(), 1) - assert.Equal(t, cliNew, c.IDs[0]) + assert.Equal(t, cliNewIP, c.IPs[0]) }) t.Run("del_success", func(t *testing.T) { @@ -259,8 +259,8 @@ func TestClientsWHOIS(t *testing.T) { ip := netip.MustParseAddr("1.1.1.2") ok, err := clients.add(&persistentClient{ - IDs: []string{"1.1.1.2"}, Name: "client1", + IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, }) require.NoError(t, err) assert.True(t, ok) @@ -281,8 +281,10 @@ func TestClientsAddExisting(t *testing.T) { // Add a client. ok, err := clients.add(&persistentClient{ - IDs: []string{ip.String(), "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"}, - Name: "client1", + Name: "client1", + IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, + Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")}, + MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, }) require.NoError(t, err) assert.True(t, ok) @@ -329,16 +331,16 @@ func TestClientsAddExisting(t *testing.T) { // Add a new client with the same IP as for a client with MAC. ok, err := clients.add(&persistentClient{ - IDs: []string{ip.String()}, Name: "client2", + IPs: []netip.Addr{ip}, }) require.NoError(t, err) assert.True(t, ok) // Add a new client with the IP from the first client's IP range. ok, err = clients.add(&persistentClient{ - IDs: []string{"2.2.2.2"}, Name: "client3", + IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, }) require.NoError(t, err) assert.True(t, ok) @@ -350,8 +352,8 @@ func TestClientsCustomUpstream(t *testing.T) { // Add client with upstreams. ok, err := clients.add(&persistentClient{ - IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"}, Name: "client1", + IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, Upstreams: []string{ "1.1.1.1", "[/example.org/]8.8.8.8", diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index bab70235..3f2918ca 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -195,9 +195,14 @@ func (clients *clientsContainer) jsonToClient( return nil, err } + err = c.setIDs(cj.IDs) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return nil, err + } + c.safeSearchConf = copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled) c.Name = cj.Name - c.IDs = cj.IDs c.Tags = cj.Tags c.Upstreams = cj.Upstreams c.UseOwnSettings = !cj.UseGlobalSettings @@ -286,7 +291,7 @@ func clientToJSON(c *persistentClient) (cj *clientJSON) { return &clientJSON{ Name: c.Name, - IDs: c.IDs, + IDs: c.ids(), Tags: c.Tags, UseGlobalSettings: !c.UseOwnSettings, FilteringEnabled: c.FilteringEnabled,