From 243123a404bb5279a27de18391fa58a9d3e6149b Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Thu, 15 Aug 2024 19:15:54 +0300 Subject: [PATCH] all: add tests --- internal/client/storage_test.go | 304 +++++++++++++++++++++++++++++--- internal/home/clients.go | 4 +- 2 files changed, 280 insertions(+), 28 deletions(-) diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 3e966575..4a534580 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,9 +26,19 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { require.NoError(tb, s.Add(c)) } + require.Equal(tb, len(m), s.Size()) + return s } +// newRuntimeClient is a helper function that returns a new runtime client. +func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) { + rc = client.NewRuntime(ip) + rc.SetInfo(source, []string{host}) + + return rc +} + // mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an // error. func mustParseMAC(s string) (mac net.HardwareAddr) { @@ -43,6 +54,9 @@ func TestStorage_Add(t *testing.T) { const ( existingName = "existing_name" existingClientID = "existing_client_id" + + allowedTag = "tag" + notAllowedTag = "not_allowed_tag" ) var ( @@ -60,7 +74,7 @@ func TestStorage_Add(t *testing.T) { } s := client.NewStorage(&client.Config{ - AllowedTags: nil, + AllowedTags: []string{allowedTag}, }) err := s.Add(existingClient) require.NoError(t, err) @@ -119,6 +133,15 @@ func TestStorage_Add(t *testing.T) { }, wantErrMsg: `adding client: another client "existing_name" ` + `uses the same ClientID "existing_client_id"`, + }, { + name: "not_allowed_tag", + cli: &client.Persistent{ + Name: "nont_allowed_tag", + Tags: []string{notAllowedTag}, + IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")}, + UID: client.MustNewUID(), + }, + wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`, }} for _, tc := range testCases { @@ -341,6 +364,127 @@ func TestStorage_FindLoose(t *testing.T) { } } +func TestStorage_FindByName(t *testing.T) { + const ( + cliIP1 = "1.1.1.1" + cliIP2 = "2.2.2.2" + ) + + const ( + clientExistingName = "client_existing" + clientAnotherExistingName = "client_another_existing" + nonExistingClientName = "client_non_existing" + ) + + var ( + clientExisting = &client.Persistent{ + Name: clientExistingName, + IPs: []netip.Addr{netip.MustParseAddr(cliIP1)}, + } + + clientAnotherExisting = &client.Persistent{ + Name: clientAnotherExistingName, + IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, + } + ) + + clients := []*client.Persistent{ + clientExisting, + clientAnotherExisting, + } + s := newStorage(t, clients) + + testCases := []struct { + want *client.Persistent + name string + clientName string + }{{ + name: "existing", + clientName: clientExistingName, + want: clientExisting, + }, { + name: "another_existing", + clientName: clientAnotherExistingName, + want: clientAnotherExisting, + }, { + name: "non_existing", + clientName: nonExistingClientName, + want: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c, ok := s.FindByName(tc.clientName) + if tc.want == nil { + assert.False(t, ok) + + return + } + + assert.True(t, ok) + assert.Equal(t, tc.want, c) + }) + } +} + +func TestStorage_FindByMAC(t *testing.T) { + var ( + cliMAC = mustParseMAC("11:11:11:11:11:11") + cliAnotherMAC = mustParseMAC("22:22:22:22:22:22") + nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33") + ) + + var ( + clientExisting = &client.Persistent{ + Name: "client", + MACs: []net.HardwareAddr{cliMAC}, + } + + clientAnotherExisting = &client.Persistent{ + Name: "another_client", + MACs: []net.HardwareAddr{cliAnotherMAC}, + } + ) + + clients := []*client.Persistent{ + clientExisting, + clientAnotherExisting, + } + s := newStorage(t, clients) + + testCases := []struct { + want *client.Persistent + name string + clientMAC net.HardwareAddr + }{{ + name: "existing", + clientMAC: cliMAC, + want: clientExisting, + }, { + name: "another_existing", + clientMAC: cliAnotherMAC, + want: clientAnotherExisting, + }, { + name: "non_existing", + clientMAC: nonExistingClientMAC, + want: nil, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + c, ok := s.FindByMAC(tc.clientMAC) + if tc.want == nil { + assert.False(t, ok) + + return + } + + assert.True(t, ok) + assert.Equal(t, tc.want, c) + }) + } +} + func TestStorage_Update(t *testing.T) { const ( clientName = "client_name" @@ -482,40 +626,148 @@ func TestStorage_RangeByName(t *testing.T) { func TestStorage_UpdateRuntime(t *testing.T) { const ( - addedARP = "added_arp" + addedARP = "added_arp" + addedSecondARP = "added_arp" - updatedARP = "updated_arp" - updatedHostsFile = "updated_hosts" + updatedARP = "updated_arp" + + cliCity = "City" + cliCountry = "Country" + cliOrgname = "Orgname" ) - ip := netip.MustParseAddr("1.1.1.1") - - added := client.NewRuntime(ip) - added.SetInfo(client.SourceARP, []string{addedARP}) + var ( + ip = netip.MustParseAddr("1.1.1.1") + ip2 = netip.MustParseAddr("2.2.2.2") + ) updated := client.NewRuntime(ip) updated.SetInfo(client.SourceARP, []string{updatedARP}) - updated.SetInfo(client.SourceHostsFile, []string{updatedHostsFile}) - s := newStorage(t, nil) + info := &whois.Info{ + City: cliCity, + Country: cliCountry, + Orgname: cliOrgname, + } + updated.SetWHOIS(info) - s.UpdateRuntime(added) - got := s.ClientRuntime(ip) - source, host := got.Info() - assert.Equal(t, client.SourceARP, source) - assert.Equal(t, addedARP, host) + s := client.NewStorage(&client.Config{ + AllowedTags: nil, + }) - s.UpdateRuntime(updated) - got = s.ClientRuntime(ip) - source, host = got.Info() - assert.Equal(t, client.SourceHostsFile, source) - assert.Equal(t, updatedHostsFile, host) + t.Run("add_arp_client", func(t *testing.T) { + added := client.NewRuntime(ip) + added.SetInfo(client.SourceARP, []string{addedARP}) - n := s.DeleteBySource(client.SourceHostsFile) - require.Equal(t, 0, n) + s.UpdateRuntime(added) + require.Equal(t, 1, s.SizeRuntime()) - got = s.ClientRuntime(ip) - source, host = got.Info() - assert.Equal(t, client.SourceARP, source) - assert.Equal(t, updatedARP, host) + got := s.ClientRuntime(ip) + source, host := got.Info() + assert.Equal(t, client.SourceARP, source) + assert.Equal(t, addedARP, host) + }) + + t.Run("add_second_arp_client", func(t *testing.T) { + added := client.NewRuntime(ip2) + added.SetInfo(client.SourceARP, []string{addedSecondARP}) + + s.UpdateRuntime(added) + require.Equal(t, 2, s.SizeRuntime()) + + got := s.ClientRuntime(ip2) + source, host := got.Info() + assert.Equal(t, client.SourceARP, source) + assert.Equal(t, addedSecondARP, host) + }) + + t.Run("update_first_client", func(t *testing.T) { + s.UpdateRuntime(updated) + got := s.ClientRuntime(ip) + require.Equal(t, 2, s.SizeRuntime()) + + source, host := got.Info() + assert.Equal(t, client.SourceARP, source) + assert.Equal(t, updatedARP, host) + }) + + t.Run("remove_arp_info", func(t *testing.T) { + n := s.DeleteBySource(client.SourceARP) + require.Equal(t, 1, n) + require.Equal(t, 1, s.SizeRuntime()) + + got := s.ClientRuntime(ip) + source, _ := got.Info() + assert.Equal(t, client.SourceWHOIS, source) + assert.Equal(t, info, got.WHOIS()) + }) + + t.Run("remove_whois_info", func(t *testing.T) { + n := s.DeleteBySource(client.SourceWHOIS) + require.Equal(t, 1, n) + require.Equal(t, 0, s.SizeRuntime()) + }) +} + +func TestStorage_BatchUpdateBySource(t *testing.T) { + const ( + defSrc = client.SourceARP + + cliFirstHost1 = "host1" + cliFirstHost2 = "host2" + cliUpdatedHost3 = "host3" + cliUpdatedHost4 = "host4" + cliUpdatedHost5 = "host5" + ) + + var ( + cliFirstIP1 = netip.MustParseAddr("1.1.1.1") + cliFirstIP2 = netip.MustParseAddr("2.2.2.2") + cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3") + cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4") + cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5") + ) + + firstClients := []*client.Runtime{ + newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1), + newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2), + } + + updatedClients := []*client.Runtime{ + newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3), + newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4), + newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5), + } + + s := client.NewStorage(&client.Config{ + AllowedTags: nil, + }) + + t.Run("populate_storage_with_first_clients", func(t *testing.T) { + s.BatchUpdateBySource(defSrc, firstClients) + require.Equal(t, len(firstClients), s.SizeRuntime()) + + rc := s.ClientRuntime(cliFirstIP1) + src, host := rc.Info() + assert.Equal(t, defSrc, src) + assert.Equal(t, cliFirstHost1, host) + }) + + t.Run("update_storage", func(t *testing.T) { + s.BatchUpdateBySource(defSrc, updatedClients) + require.Equal(t, len(updatedClients), s.SizeRuntime()) + + rc := s.ClientRuntime(cliUpdatedIP3) + src, host := rc.Info() + assert.Equal(t, defSrc, src) + assert.Equal(t, cliUpdatedHost3, host) + + rc = s.ClientRuntime(cliFirstIP1) + assert.Nil(t, rc) + }) + + t.Run("remove_all", func(t *testing.T) { + s.BatchUpdateBySource(defSrc, []*client.Runtime{}) + require.Equal(t, 0, s.SizeRuntime()) + }) } diff --git a/internal/home/clients.go b/internal/home/clients.go index a94d5ec1..e64f5d15 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -643,7 +643,7 @@ func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorag defer clients.lock.Unlock() added := 0 - rcs := []*client.Runtime{} + var rcs []*client.Runtime hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { // Only the first name of the first record is considered a canonical // hostname for the IP address. @@ -685,7 +685,7 @@ func (clients *clientsContainer) addFromSystemARP() { defer clients.lock.Unlock() added := 0 - rcs := []*client.Runtime{} + var rcs []*client.Runtime for _, n := range ns { rc := client.NewRuntime(n.IP) rc.SetInfo(client.SourceARP, []string{n.Name})