package client_test import ( "net" "net/netip" "runtime" "sync" "testing" "time" "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type testHostsContainer struct { onUpd func() (updates <-chan *hostsfile.DefaultStorage) } // type check var _ client.HostsContainer = (*testHostsContainer)(nil) func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) { return c.onUpd() } // Interface stores and refreshes the network neighborhood reported by ARP // (Address Resolution Protocol). type Interface interface { // Refresh updates the stored data. It must be safe for concurrent use. Refresh() (err error) // Neighbors returnes the last set of data reported by ARP. Both the method // and it's result must be safe for concurrent use. Neighbors() (ns []arpdb.Neighbor) } type testARP struct { onRefresh func() (err error) onNeighbors func() (ns []arpdb.Neighbor) } func (c *testARP) Refresh() (err error) { return c.onRefresh() } func (c *testARP) Neighbors() (ns []arpdb.Neighbor) { return c.onNeighbors() } type testDHCP struct { OnLeases func() (leases []*dhcpsvc.Lease) OnHostBy func(ip netip.Addr) (host string) OnMACBy func(ip netip.Addr) (mac net.HardwareAddr) } // Lease implements the [DHCP] interface for testDHCP. func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() } // HostByIP implements the [DHCP] interface for testDHCP. func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) } // MACByIP implements the [DHCP] interface for testDHCP. func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) } // compareRuntimeInfo is a helper function that returns true if the runtime // client has provided info. func compareRuntimeInfo(rc *client.Runtime, src client.Source, host string) (ok bool) { s, h := rc.Info() if s != src { return false } else if h != host { return false } return true } func TestStorage_Add_hostsfile(t *testing.T) { var ( cliIP1 = netip.MustParseAddr("1.1.1.1") cliName1 = "client_one" cliIP2 = netip.MustParseAddr("2.2.2.2") cliName2 = "client_two" ) hostCh := make(chan *hostsfile.DefaultStorage) h := &testHostsContainer{ onUpd: func() (updates <-chan *hostsfile.DefaultStorage) { return hostCh }, } storage, err := client.NewStorage(&client.Config{ EtcHosts: h, }) require.NoError(t, err) storage.Start() t.Run("add_hosts", func(t *testing.T) { var s *hostsfile.DefaultStorage s, err = hostsfile.NewDefaultStorage() require.NoError(t, err) s.Add(&hostsfile.Record{ Addr: cliIP1, Names: []string{cliName1}, }) testutil.RequireSend(t, hostCh, s, testTimeout) require.Eventually(t, func() (ok bool) { cli1 := storage.ClientRuntime(cliIP1) if cli1 == nil { return false } assert.True(t, compareRuntimeInfo(cli1, client.SourceHostsFile, cliName1)) return true }, testTimeout, testTimeout/10) }) t.Run("update_hosts", func(t *testing.T) { var s *hostsfile.DefaultStorage s, err = hostsfile.NewDefaultStorage() require.NoError(t, err) s.Add(&hostsfile.Record{ Addr: cliIP2, Names: []string{cliName2}, }) testutil.RequireSend(t, hostCh, s, testTimeout) require.Eventually(t, func() (ok bool) { cli2 := storage.ClientRuntime(cliIP2) if cli2 == nil { return false } assert.True(t, compareRuntimeInfo(cli2, client.SourceHostsFile, cliName2)) cli1 := storage.ClientRuntime(cliIP1) require.Nil(t, cli1) return true }, testTimeout, testTimeout/10) }) } func TestStorage_Add_arp(t *testing.T) { var ( mu sync.Mutex neighbors []arpdb.Neighbor cliIP1 = netip.MustParseAddr("1.1.1.1") cliName1 = "client_one" cliIP2 = netip.MustParseAddr("2.2.2.2") cliName2 = "client_two" ) a := &testARP{ onRefresh: func() (err error) { return nil }, onNeighbors: func() (ns []arpdb.Neighbor) { mu.Lock() defer mu.Unlock() return neighbors }, } storage, err := client.NewStorage(&client.Config{ ARPDB: a, ARPClientsUpdatePeriod: testTimeout / 10, }) require.NoError(t, err) storage.Start() t.Run("add_hosts", func(t *testing.T) { func() { mu.Lock() defer mu.Unlock() neighbors = []arpdb.Neighbor{{ Name: cliName1, IP: cliIP1, }} }() require.Eventually(t, func() (ok bool) { cli1 := storage.ClientRuntime(cliIP1) if cli1 == nil { return false } assert.True(t, compareRuntimeInfo(cli1, client.SourceARP, cliName1)) return true }, testTimeout, testTimeout/10) }) t.Run("update_hosts", func(t *testing.T) { func() { mu.Lock() defer mu.Unlock() neighbors = []arpdb.Neighbor{{ Name: cliName2, IP: cliIP2, }} }() require.Eventually(t, func() (ok bool) { cli2 := storage.ClientRuntime(cliIP2) if cli2 == nil { return false } assert.True(t, compareRuntimeInfo(cli2, client.SourceARP, cliName2)) cli1 := storage.ClientRuntime(cliIP1) require.Nil(t, cli1) return true }, testTimeout, testTimeout/10) }) } func TestStorage_Add_whois(t *testing.T) { var ( cliIP1 = netip.MustParseAddr("1.1.1.1") cliIP2 = netip.MustParseAddr("2.2.2.2") cliName2 = "client_two" cliIP3 = netip.MustParseAddr("3.3.3.3") cliName3 = "client_three" ) storage, err := client.NewStorage(&client.Config{}) require.NoError(t, err) whois := &whois.Info{ Country: "AU", Orgname: "Example Org", } t.Run("new_client", func(t *testing.T) { storage.UpdateAddress(cliIP1, "", whois) cli1 := storage.ClientRuntime(cliIP1) require.NotNil(t, cli1) assert.Equal(t, whois, cli1.WHOIS()) }) t.Run("existing_runtime_client", func(t *testing.T) { storage.UpdateAddress(cliIP2, cliName2, nil) storage.UpdateAddress(cliIP2, "", whois) cli2 := storage.ClientRuntime(cliIP2) require.NotNil(t, cli2) assert.True(t, compareRuntimeInfo(cli2, client.SourceRDNS, cliName2)) assert.Equal(t, whois, cli2.WHOIS()) }) t.Run("can't_set_persistent_client", func(t *testing.T) { err = storage.Add(&client.Persistent{ Name: cliName3, UID: client.MustNewUID(), IPs: []netip.Addr{cliIP3}, }) require.NoError(t, err) storage.UpdateAddress(cliIP3, "", whois) rc := storage.ClientRuntime(cliIP3) require.Nil(t, rc) }) } func TestClientsDHCP(t *testing.T) { var ( cliIP1 = netip.MustParseAddr("1.1.1.1") cliName1 = "client_one" prsCliIP = netip.MustParseAddr("4.3.2.1") prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA") prsCliName = "persitent_client" ) ipToHost := map[netip.Addr]string{ cliIP1: cliName1, } ipToMAC := map[netip.Addr]net.HardwareAddr{ prsCliIP: prsCliMAC, } d := &testDHCP{ OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") }, OnHostBy: func(ip netip.Addr) (host string) { return ipToHost[ip] }, OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return ipToMAC[ip] }, } storage, err := client.NewStorage(&client.Config{ DHCP: d, }) require.NoError(t, err) t.Run("find_runtime", func(t *testing.T) { cli1 := storage.ClientRuntime(cliIP1) require.NotNil(t, cli1) assert.True(t, compareRuntimeInfo(cli1, client.SourceDHCP, cliName1)) }) t.Run("find_persistent", func(t *testing.T) { err = storage.Add(&client.Persistent{ Name: prsCliName, UID: client.MustNewUID(), MACs: []net.HardwareAddr{prsCliMAC}, }) require.NoError(t, err) prsCli, ok := storage.Find(prsCliIP.String()) require.True(t, ok) assert.Equal(t, prsCliName, prsCli.Name) }) } func TestClientsAddExisting(t *testing.T) { // First, init a DHCP server with a single static lease. config := &dhcpd.ServerConfig{ Enabled: true, DataDir: t.TempDir(), Conf4: dhcpd.V4ServerConf{ Enabled: true, GatewayIP: netip.MustParseAddr("1.2.3.1"), SubnetMask: netip.MustParseAddr("255.255.255.0"), RangeStart: netip.MustParseAddr("1.2.3.2"), RangeEnd: netip.MustParseAddr("1.2.3.10"), }, } dhcpServer, err := dhcpd.Create(config) require.NoError(t, err) storage, err := client.NewStorage(&client.Config{ DHCP: dhcpServer, }) require.NoError(t, err) t.Run("simple", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") // Add a client. err = storage.Add(&client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")}, MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, }) require.NoError(t, err) // Now add an auto-client with the same IP. storage.UpdateAddress(ip, "test", nil) rc := storage.ClientRuntime(ip) assert.True(t, compareRuntimeInfo(rc, client.SourceRDNS, "test")) }) t.Run("complicated", func(t *testing.T) { // TODO(a.garipov): Properly decouple the DHCP server from the client // storage. if runtime.GOOS == "windows" { t.Skip("skipping dhcp test on windows") } ip := netip.MustParseAddr("1.2.3.4") err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{ HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, IP: ip, Hostname: "testhost", Expiry: time.Now().Add(time.Hour), }) require.NoError(t, err) // Add a new client with the same IP as for a client with MAC. err = storage.Add(&client.Persistent{ Name: "client2", UID: client.MustNewUID(), IPs: []netip.Addr{ip}, }) require.NoError(t, err) // Add a new client with the IP from the first client's IP range. err = storage.Add(&client.Persistent{ Name: "client3", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, }) require.NoError(t, err) }) } // newStorage is a helper function that returns a client storage filled with // persistent clients from the m. It also generates a UID for each client. func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { tb.Helper() s, err := client.NewStorage(&client.Config{ AllowedTags: nil, }) require.NoError(tb, err) for _, c := range m { c.UID = client.MustNewUID() require.NoError(tb, s.Add(c)) } require.Equal(tb, len(m), s.Size()) return s } // newRuntimeClient is a helper function that returns a new runtime client. func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) { rc = client.NewRuntime(ip) rc.SetInfo(source, []string{host}) return rc } // mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an // error. func mustParseMAC(s string) (mac net.HardwareAddr) { mac, err := net.ParseMAC(s) if err != nil { panic(err) } return mac } func TestStorage_Add(t *testing.T) { const ( existingName = "existing_name" existingClientID = "existing_client_id" allowedTag = "tag" notAllowedTag = "not_allowed_tag" ) var ( existingClientUID = client.MustNewUID() existingIP = netip.MustParseAddr("1.2.3.4") existingSubnet = netip.MustParsePrefix("1.2.3.0/24") ) existingClient := &client.Persistent{ Name: existingName, IPs: []netip.Addr{existingIP}, Subnets: []netip.Prefix{existingSubnet}, ClientIDs: []string{existingClientID}, UID: existingClientUID, } s, err := client.NewStorage(&client.Config{ AllowedTags: []string{allowedTag}, }) require.NoError(t, err) err = s.Add(existingClient) require.NoError(t, err) testCases := []struct { name string cli *client.Persistent wantErrMsg string }{{ name: "basic", cli: &client.Persistent{ Name: "basic", IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, UID: client.MustNewUID(), }, wantErrMsg: "", }, { name: "duplicate_uid", cli: &client.Persistent{ Name: "no_uid", IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, UID: existingClientUID, }, wantErrMsg: `adding client: another client "existing_name" uses the same uid`, }, { name: "duplicate_name", cli: &client.Persistent{ Name: existingName, IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")}, UID: client.MustNewUID(), }, wantErrMsg: `adding client: another client uses the same name "existing_name"`, }, { name: "duplicate_ip", cli: &client.Persistent{ Name: "duplicate_ip", IPs: []netip.Addr{existingIP}, UID: client.MustNewUID(), }, wantErrMsg: `adding client: another client "existing_name" uses the same IP "1.2.3.4"`, }, { name: "duplicate_subnet", cli: &client.Persistent{ Name: "duplicate_subnet", Subnets: []netip.Prefix{existingSubnet}, UID: client.MustNewUID(), }, wantErrMsg: `adding client: another client "existing_name" ` + `uses the same subnet "1.2.3.0/24"`, }, { name: "duplicate_client_id", cli: &client.Persistent{ Name: "duplicate_client_id", ClientIDs: []string{existingClientID}, UID: client.MustNewUID(), }, wantErrMsg: `adding client: another client "existing_name" ` + `uses the same ClientID "existing_client_id"`, }, { name: "not_allowed_tag", cli: &client.Persistent{ Name: "nont_allowed_tag", Tags: []string{notAllowedTag}, IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")}, UID: client.MustNewUID(), }, wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { err = s.Add(tc.cli) testutil.AssertErrorMsg(t, tc.wantErrMsg, err) }) } } func TestStorage_RemoveByName(t *testing.T) { const ( existingName = "existing_name" ) existingClient := &client.Persistent{ Name: existingName, IPs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, UID: client.MustNewUID(), } s, err := client.NewStorage(&client.Config{ AllowedTags: nil, }) require.NoError(t, err) err = s.Add(existingClient) require.NoError(t, err) testCases := []struct { want assert.BoolAssertionFunc name string cliName string }{{ name: "existing_client", cliName: existingName, want: assert.True, }, { name: "non_existing_client", cliName: "non_existing_client", want: assert.False, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { tc.want(t, s.RemoveByName(tc.cliName)) }) } t.Run("duplicate_remove", func(t *testing.T) { s, err = client.NewStorage(&client.Config{ AllowedTags: nil, }) require.NoError(t, err) err = s.Add(existingClient) require.NoError(t, err) assert.True(t, s.RemoveByName(existingName)) assert.False(t, s.RemoveByName(existingName)) }) } func TestStorage_Find(t *testing.T) { const ( cliIPNone = "1.2.3.4" cliIP1 = "1.1.1.1" cliIP2 = "2.2.2.2" cliIPv6 = "1:2:3::4" cliSubnet = "2.2.2.0/24" cliSubnetIP = "2.2.2.222" cliID = "client-id" cliMAC = "11:11:11:11:11:11" linkLocalIP = "fe80::abcd:abcd:abcd:ab%eth0" linkLocalSubnet = "fe80::/16" ) var ( clientWithBothFams = &client.Persistent{ Name: "client1", IPs: []netip.Addr{ netip.MustParseAddr(cliIP1), netip.MustParseAddr(cliIPv6), }, } clientWithSubnet = &client.Persistent{ Name: "client2", IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)}, } clientWithMAC = &client.Persistent{ Name: "client_with_mac", MACs: []net.HardwareAddr{mustParseMAC(cliMAC)}, } clientWithID = &client.Persistent{ Name: "client_with_id", ClientIDs: []string{cliID}, } clientLinkLocal = &client.Persistent{ Name: "client_link_local", Subnets: []netip.Prefix{netip.MustParsePrefix(linkLocalSubnet)}, } ) clients := []*client.Persistent{ clientWithBothFams, clientWithSubnet, clientWithMAC, clientWithID, clientLinkLocal, } s := newStorage(t, clients) testCases := []struct { want *client.Persistent name string ids []string }{{ name: "ipv4_ipv6", ids: []string{cliIP1, cliIPv6}, want: clientWithBothFams, }, { name: "ipv4_subnet", ids: []string{cliIP2, cliSubnetIP}, want: clientWithSubnet, }, { name: "mac", ids: []string{cliMAC}, want: clientWithMAC, }, { name: "client_id", ids: []string{cliID}, want: clientWithID, }, { name: "client_link_local_subnet", ids: []string{linkLocalIP}, want: clientLinkLocal, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { for _, id := range tc.ids { c, ok := s.Find(id) require.True(t, ok) assert.Equal(t, tc.want, c) } }) } t.Run("not_found", func(t *testing.T) { _, ok := s.Find(cliIPNone) assert.False(t, ok) }) } func TestStorage_FindLoose(t *testing.T) { const ( nonExistingClientID = "client_id" ) var ( ip = netip.MustParseAddr("fe80::a098:7654:32ef:ff1") ipWithZone = netip.MustParseAddr("fe80::1ff:fe23:4567:890a%eth2") ) var ( clientNoZone = &client.Persistent{ Name: "client", IPs: []netip.Addr{ip}, } clientWithZone = &client.Persistent{ Name: "client_with_zone", IPs: []netip.Addr{ipWithZone}, } ) s := newStorage( t, []*client.Persistent{ clientNoZone, clientWithZone, }, ) testCases := []struct { ip netip.Addr want assert.BoolAssertionFunc wantCli *client.Persistent name string }{{ name: "without_zone", ip: ip, wantCli: clientNoZone, want: assert.True, }, { name: "with_zone", ip: ipWithZone, wantCli: clientWithZone, want: assert.True, }, { name: "zero_address", ip: netip.Addr{}, wantCli: nil, want: assert.False, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c, ok := s.FindLoose(tc.ip.WithZone(""), nonExistingClientID) assert.Equal(t, tc.wantCli, c) tc.want(t, ok) }) } } func TestStorage_FindByName(t *testing.T) { const ( cliIP1 = "1.1.1.1" cliIP2 = "2.2.2.2" ) const ( clientExistingName = "client_existing" clientAnotherExistingName = "client_another_existing" nonExistingClientName = "client_non_existing" ) var ( clientExisting = &client.Persistent{ Name: clientExistingName, IPs: []netip.Addr{netip.MustParseAddr(cliIP1)}, } clientAnotherExisting = &client.Persistent{ Name: clientAnotherExistingName, IPs: []netip.Addr{netip.MustParseAddr(cliIP2)}, } ) clients := []*client.Persistent{ clientExisting, clientAnotherExisting, } s := newStorage(t, clients) testCases := []struct { want *client.Persistent name string clientName string }{{ name: "existing", clientName: clientExistingName, want: clientExisting, }, { name: "another_existing", clientName: clientAnotherExistingName, want: clientAnotherExisting, }, { name: "non_existing", clientName: nonExistingClientName, want: nil, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c, ok := s.FindByName(tc.clientName) if tc.want == nil { assert.False(t, ok) return } assert.True(t, ok) assert.Equal(t, tc.want, c) }) } } func TestStorage_FindByMAC(t *testing.T) { var ( cliMAC = mustParseMAC("11:11:11:11:11:11") cliAnotherMAC = mustParseMAC("22:22:22:22:22:22") nonExistingClientMAC = mustParseMAC("33:33:33:33:33:33") ) var ( clientExisting = &client.Persistent{ Name: "client", MACs: []net.HardwareAddr{cliMAC}, } clientAnotherExisting = &client.Persistent{ Name: "another_client", MACs: []net.HardwareAddr{cliAnotherMAC}, } ) clients := []*client.Persistent{ clientExisting, clientAnotherExisting, } s := newStorage(t, clients) testCases := []struct { want *client.Persistent name string clientMAC net.HardwareAddr }{{ name: "existing", clientMAC: cliMAC, want: clientExisting, }, { name: "another_existing", clientMAC: cliAnotherMAC, want: clientAnotherExisting, }, { name: "non_existing", clientMAC: nonExistingClientMAC, want: nil, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c, ok := s.FindByMAC(tc.clientMAC) if tc.want == nil { assert.False(t, ok) return } assert.True(t, ok) assert.Equal(t, tc.want, c) }) } } func TestStorage_Update(t *testing.T) { const ( clientName = "client_name" obstructingName = "obstructing_name" obstructingClientID = "obstructing_client_id" ) var ( obstructingIP = netip.MustParseAddr("1.2.3.4") obstructingSubnet = netip.MustParsePrefix("1.2.3.0/24") ) obstructingClient := &client.Persistent{ Name: obstructingName, IPs: []netip.Addr{obstructingIP}, Subnets: []netip.Prefix{obstructingSubnet}, ClientIDs: []string{obstructingClientID}, } clientToUpdate := &client.Persistent{ Name: clientName, IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, } testCases := []struct { name string cli *client.Persistent wantErrMsg string }{{ name: "basic", cli: &client.Persistent{ Name: "basic", IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1")}, UID: client.MustNewUID(), }, wantErrMsg: "", }, { name: "duplicate_name", cli: &client.Persistent{ Name: obstructingName, IPs: []netip.Addr{netip.MustParseAddr("3.3.3.3")}, UID: client.MustNewUID(), }, wantErrMsg: `updating client: another client uses the same name "obstructing_name"`, }, { name: "duplicate_ip", cli: &client.Persistent{ Name: "duplicate_ip", IPs: []netip.Addr{obstructingIP}, UID: client.MustNewUID(), }, wantErrMsg: `updating client: another client "obstructing_name" uses the same IP "1.2.3.4"`, }, { name: "duplicate_subnet", cli: &client.Persistent{ Name: "duplicate_subnet", Subnets: []netip.Prefix{obstructingSubnet}, UID: client.MustNewUID(), }, wantErrMsg: `updating client: another client "obstructing_name" ` + `uses the same subnet "1.2.3.0/24"`, }, { name: "duplicate_client_id", cli: &client.Persistent{ Name: "duplicate_client_id", ClientIDs: []string{obstructingClientID}, UID: client.MustNewUID(), }, wantErrMsg: `updating client: another client "obstructing_name" ` + `uses the same ClientID "obstructing_client_id"`, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s := newStorage( t, []*client.Persistent{ clientToUpdate, obstructingClient, }, ) err := s.Update(clientName, tc.cli) testutil.AssertErrorMsg(t, tc.wantErrMsg, err) }) } } func TestStorage_RangeByName(t *testing.T) { sortedClients := []*client.Persistent{{ Name: "clientA", ClientIDs: []string{"A"}, }, { Name: "clientB", ClientIDs: []string{"B"}, }, { Name: "clientC", ClientIDs: []string{"C"}, }, { Name: "clientD", ClientIDs: []string{"D"}, }, { Name: "clientE", ClientIDs: []string{"E"}, }} testCases := []struct { name string want []*client.Persistent }{{ name: "basic", want: sortedClients, }, { name: "nil", want: nil, }, { name: "one_element", want: sortedClients[:1], }, { name: "two_elements", want: sortedClients[:2], }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s := newStorage(t, tc.want) var got []*client.Persistent s.RangeByName(func(c *client.Persistent) (cont bool) { got = append(got, c) return true }) assert.Equal(t, tc.want, got) }) } } func TestStorage_UpdateRuntime(t *testing.T) { const ( addedARP = "added_arp" addedSecondARP = "added_arp" updatedARP = "updated_arp" cliCity = "City" cliCountry = "Country" cliOrgname = "Orgname" ) var ( ip = netip.MustParseAddr("1.1.1.1") ip2 = netip.MustParseAddr("2.2.2.2") ) updated := client.NewRuntime(ip) updated.SetInfo(client.SourceARP, []string{updatedARP}) info := &whois.Info{ City: cliCity, Country: cliCountry, Orgname: cliOrgname, } updated.SetWHOIS(info) s, err := client.NewStorage(&client.Config{ AllowedTags: nil, }) require.NoError(t, err) t.Run("add_arp_client", func(t *testing.T) { added := client.NewRuntime(ip) added.SetInfo(client.SourceARP, []string{addedARP}) require.True(t, s.UpdateRuntime(added)) require.Equal(t, 1, s.SizeRuntime()) got := s.ClientRuntime(ip) source, host := got.Info() assert.Equal(t, client.SourceARP, source) assert.Equal(t, addedARP, host) }) t.Run("add_second_arp_client", func(t *testing.T) { added := client.NewRuntime(ip2) added.SetInfo(client.SourceARP, []string{addedSecondARP}) require.True(t, s.UpdateRuntime(added)) require.Equal(t, 2, s.SizeRuntime()) got := s.ClientRuntime(ip2) source, host := got.Info() assert.Equal(t, client.SourceARP, source) assert.Equal(t, addedSecondARP, host) }) t.Run("update_first_client", func(t *testing.T) { require.False(t, s.UpdateRuntime(updated)) got := s.ClientRuntime(ip) require.Equal(t, 2, s.SizeRuntime()) source, host := got.Info() assert.Equal(t, client.SourceARP, source) assert.Equal(t, updatedARP, host) }) t.Run("remove_arp_info", func(t *testing.T) { n := s.DeleteBySource(client.SourceARP) require.Equal(t, 1, n) require.Equal(t, 1, s.SizeRuntime()) got := s.ClientRuntime(ip) source, _ := got.Info() assert.Equal(t, client.SourceWHOIS, source) assert.Equal(t, info, got.WHOIS()) }) t.Run("remove_whois_info", func(t *testing.T) { n := s.DeleteBySource(client.SourceWHOIS) require.Equal(t, 1, n) require.Equal(t, 0, s.SizeRuntime()) }) } func TestStorage_BatchUpdateBySource(t *testing.T) { const ( defSrc = client.SourceARP cliFirstHost1 = "host1" cliFirstHost2 = "host2" cliUpdatedHost3 = "host3" cliUpdatedHost4 = "host4" cliUpdatedHost5 = "host5" ) var ( cliFirstIP1 = netip.MustParseAddr("1.1.1.1") cliFirstIP2 = netip.MustParseAddr("2.2.2.2") cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3") cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4") cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5") ) firstClients := []*client.Runtime{ newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1), newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2), } updatedClients := []*client.Runtime{ newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3), newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4), newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5), } s, err := client.NewStorage(&client.Config{ AllowedTags: nil, }) require.NoError(t, err) t.Run("populate_storage_with_first_clients", func(t *testing.T) { added, removed := s.BatchUpdateBySource(defSrc, firstClients) require.Equal(t, len(firstClients), added) require.Equal(t, 0, removed) require.Equal(t, len(firstClients), s.SizeRuntime()) rc := s.ClientRuntime(cliFirstIP1) src, host := rc.Info() assert.Equal(t, defSrc, src) assert.Equal(t, cliFirstHost1, host) }) t.Run("update_storage", func(t *testing.T) { added, removed := s.BatchUpdateBySource(defSrc, updatedClients) require.Equal(t, len(updatedClients), added) require.Equal(t, len(firstClients), removed) require.Equal(t, len(updatedClients), s.SizeRuntime()) rc := s.ClientRuntime(cliUpdatedIP3) src, host := rc.Info() assert.Equal(t, defSrc, src) assert.Equal(t, cliUpdatedHost3, host) rc = s.ClientRuntime(cliFirstIP1) assert.Nil(t, rc) }) t.Run("remove_all", func(t *testing.T) { added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{}) require.Equal(t, 0, added) require.Equal(t, len(updatedClients), removed) require.Equal(t, 0, s.SizeRuntime()) }) }