From 1fc43cb45efbd428abaae9eba030f9bea818dfe3 Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Fri, 19 Apr 2024 19:19:48 +0300 Subject: [PATCH] all: add tests --- internal/client/index.go | 22 ++++++++++- internal/client/index_internal_test.go | 1 + internal/client/persistent.go | 4 +- internal/client/persistent_internal_test.go | 2 +- internal/home/clients.go | 42 +++++---------------- internal/home/clients_internal_test.go | 39 +++++++------------ internal/home/clientshttp.go | 8 +--- 7 files changed, 49 insertions(+), 69 deletions(-) diff --git a/internal/client/index.go b/internal/client/index.go index ed76e16f..732802ad 100644 --- a/internal/client/index.go +++ b/internal/client/index.go @@ -92,6 +92,10 @@ func (ci *Index) Add(c *Persistent) { // Clashes returns an error if the index contains a different persistent client // with at least a single identifier contained by c. c must be non-nil. func (ci *Index) Clashes(c *Persistent) (err error) { + if p := ci.clashesName(c); p != nil { + return fmt.Errorf("another client uses the same name %q", p.Name) + } + for _, id := range c.ClientIDs { existing, ok := ci.clientIDToUID[id] if ok && existing != c.UID { @@ -119,6 +123,21 @@ func (ci *Index) Clashes(c *Persistent) (err error) { return nil } +// clashesName returns a previous client with the same name as c. c must be +// non-nil. +func (ci *Index) clashesName(c *Persistent) (existing *Persistent) { + existing, ok := ci.FindByName(c.Name) + if !ok { + return nil + } + + if existing.UID != c.UID { + return existing + } + + return nil +} + // clashesIP returns a previous client with the same IP address as c. c must be // non-nil. func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) { @@ -272,7 +291,8 @@ func (ci *Index) Size() (n int) { return len(ci.uidToClient) } -// Range calls f for each persistent client. +// Range calls f for each persistent client, unless cont is false. The order is +// undefined. func (ci *Index) Range(f func(c *Persistent) (cont bool)) { for _, c := range ci.uidToClient { if !f(c) { diff --git a/internal/client/index_internal_test.go b/internal/client/index_internal_test.go index abf38710..bb485f2e 100644 --- a/internal/client/index_internal_test.go +++ b/internal/client/index_internal_test.go @@ -56,6 +56,7 @@ func TestClientIndex(t *testing.T) { }} ci := newIDIndex(clients) + require.Equal(t, len(clients), ci.Size()) testCases := []struct { want *Persistent diff --git a/internal/client/persistent.go b/internal/client/persistent.go index 06e346f4..ec6259fd 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -229,9 +229,9 @@ func (c *Persistent) IDsLen() (n int) { return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs) } -// EqualIDs returns true if the ids of the current and previous clients are the +// equalIDs returns true if the ids of the current and previous clients are the // same. -func (c *Persistent) EqualIDs(prev *Persistent) (equal bool) { +func (c *Persistent) equalIDs(prev *Persistent) (equal bool) { return slices.Equal(c.IPs, prev.IPs) && slices.Equal(c.Subnets, prev.Subnets) && slices.EqualFunc(c.MACs, prev.MACs, slices.Equal[net.HardwareAddr]) && diff --git a/internal/client/persistent_internal_test.go b/internal/client/persistent_internal_test.go index 76da1e4b..22f978ec 100644 --- a/internal/client/persistent_internal_test.go +++ b/internal/client/persistent_internal_test.go @@ -118,7 +118,7 @@ func TestPersistentClient_EqualIDs(t *testing.T) { err = prev.SetIDs(tc.prevIDs) require.NoError(t, err) - tc.want(t, c.EqualIDs(prev)) + tc.want(t, c.equalIDs(prev)) }) } } diff --git a/internal/home/clients.go b/internal/home/clients.go index 9212ab6d..c924d258 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -96,8 +96,9 @@ func (clients *clientsContainer) Init( arpDB arpdb.Interface, filteringConf *filtering.Config, ) (err error) { + // TODO(s.chzhen): Refactor it. if clients.clientIndex != nil { - log.Fatal("clients.list != nil") + return errors.Error("clients container already initialized") } clients.runtimeIndex = client.NewRuntimeIndex() @@ -280,7 +281,7 @@ func (clients *clientsContainer) addFromConfig( return fmt.Errorf("clients: init persistent client at index %d: %w", i, err) } - _, err = clients.add(cli) + err = clients.add(cli) if err != nil { log.Error("clients: adding client at index %d %s: %s", i, cli.Name, err) } @@ -605,37 +606,28 @@ func (clients *clientsContainer) check(c *client.Persistent) (err error) { return nil } -// add adds a new client object. ok is false if such client already exists or -// if an error occurred. -func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) { +// add adds a persistent client or returns an error. +func (clients *clientsContainer) add(c *client.Persistent) (err error) { err = clients.check(c) if err != nil { - return false, err + // Don't wrap the error since it's informative enough as is. + return err } clients.lock.Lock() defer clients.lock.Unlock() - // check Name index - // - // TODO(s.chzhen): Use [client.Index.Clashes]. - _, ok = clients.clientIndex.FindByName(c.Name) - if ok { - return false, nil - } - - // check ID index err = clients.clientIndex.Clashes(c) if err != nil { // Don't wrap the error since it's informative enough as is. - return false, err + return err } clients.addLocked(c) log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), clients.clientIndex.Size()) - return true, nil + return nil } // addLocked c to the indexes. clients.lock is expected to be locked. @@ -680,22 +672,6 @@ func (clients *clientsContainer) update(prev, c *client.Persistent) (err error) clients.lock.Lock() defer clients.lock.Unlock() - // Check the name index. - if prev.Name != c.Name { - _, ok := clients.clientIndex.FindByName(c.Name) - if ok { - return errors.Error("client already exists") - } - } - - if c.EqualIDs(prev) { - clients.removeLocked(prev) - clients.addLocked(c) - - return nil - } - - // Check the ID index. err = clients.clientIndex.Clashes(c) if err != nil { // Don't wrap the error since it's informative enough as is. diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 9d041994..0eb53284 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -72,23 +72,19 @@ func TestClients(t *testing.T) { IPs: []netip.Addr{cli1IP, cliIPv6}, } - ok, err := clients.add(c) + err := clients.add(c) require.NoError(t, err) - assert.True(t, ok) - c = &client.Persistent{ Name: "client2", UID: client.MustNewUID(), IPs: []netip.Addr{cli2IP}, } - ok, err = clients.add(c) + err = clients.add(c) require.NoError(t, err) - assert.True(t, ok) - - c, ok = clients.find(cli1) + c, ok := clients.find(cli1) require.True(t, ok) assert.Equal(t, "client1", c.Name) @@ -111,22 +107,20 @@ func TestClients(t *testing.T) { }) t.Run("add_fail_name", func(t *testing.T) { - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, }) - require.NoError(t, err) - assert.False(t, ok) + require.Error(t, err) }) t.Run("add_fail_ip", func(t *testing.T) { - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client3", UID: client.MustNewUID(), }) require.Error(t, err) - assert.False(t, ok) }) t.Run("update_fail_ip", func(t *testing.T) { @@ -151,7 +145,7 @@ func TestClients(t *testing.T) { err := clients.update(prev, &client.Persistent{ Name: "client1", - UID: client.MustNewUID(), + UID: prev.UID, IPs: []netip.Addr{cliNewIP}, }) require.NoError(t, err) @@ -167,7 +161,7 @@ func TestClients(t *testing.T) { err = clients.update(prev, &client.Persistent{ Name: "client1-renamed", - UID: client.MustNewUID(), + UID: prev.UID, IPs: []netip.Addr{cliNewIP}, UseOwnSettings: true, }) @@ -267,13 +261,12 @@ func TestClientsWHOIS(t *testing.T) { t.Run("can't_set_manually-added", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.2") - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, }) require.NoError(t, err) - assert.True(t, ok) clients.setWHOISInfo(ip, whois) rc := clients.runtimeIndex.Client(ip) @@ -290,7 +283,7 @@ func TestClientsAddExisting(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") // Add a client. - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, @@ -298,10 +291,9 @@ func TestClientsAddExisting(t *testing.T) { MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, }) require.NoError(t, err) - assert.True(t, ok) // Now add an auto-client with the same IP. - ok = clients.addHost(ip, "test", client.SourceRDNS) + ok := clients.addHost(ip, "test", client.SourceRDNS) assert.True(t, ok) }) @@ -341,22 +333,20 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Add a new client with the same IP as for a client with MAC. - ok, err := clients.add(&client.Persistent{ + err = clients.add(&client.Persistent{ Name: "client2", UID: client.MustNewUID(), IPs: []netip.Addr{ip}, }) require.NoError(t, err) - assert.True(t, ok) // Add a new client with the IP from the first client's IP range. - ok, err = clients.add(&client.Persistent{ + err = clients.add(&client.Persistent{ Name: "client3", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, }) require.NoError(t, err) - assert.True(t, ok) }) } @@ -364,7 +354,7 @@ func TestClientsCustomUpstream(t *testing.T) { clients := newClientsContainer(t) // Add client with upstreams. - ok, err := clients.add(&client.Persistent{ + err := clients.add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, @@ -374,7 +364,6 @@ func TestClientsCustomUpstream(t *testing.T) { }, }) require.NoError(t, err) - assert.True(t, ok) upsConf, err := clients.UpstreamConfigByID("1.2.3.4", net.DefaultResolver) assert.Nil(t, upsConf) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 6fbe1418..5ad08e4c 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -336,19 +336,13 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - ok, err := clients.add(c) + err = clients.add(c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - if !ok { - aghhttp.Error(r, w, http.StatusBadRequest, "Client already exists") - - return - } - onConfigModified() }