diff --git a/internal/client/storage.go b/internal/client/storage.go index 8684fc28..d1e306f9 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -9,7 +9,6 @@ import ( "sync" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/whois" @@ -46,10 +45,14 @@ func (emptyDHCP) HostByIP(_ netip.Addr) (_ string) { return "" } func (emptyDHCP) MACByIP(_ netip.Addr) (_ net.HardwareAddr) { return nil } +type HostsContainer interface { + Upd() (updates <-chan *hostsfile.DefaultStorage) +} + // Config is the client storage configuration structure. type Config struct { DHCP DHCP - EtcHosts *aghnet.HostsContainer + EtcHosts HostsContainer ARPDB arpdb.Interface // AllowedTags is a list of all allowed client tags. @@ -74,7 +77,7 @@ type Storage struct { runtimeIndex *RuntimeIndex dhcp DHCP - etcHosts *aghnet.HostsContainer + etcHosts HostsContainer arpDB arpdb.Interface arpClientsUpdatePeriod time.Duration } @@ -159,6 +162,10 @@ func (s *Storage) addFromSystemARP() { // handleHostsUpdates receives the updates from the hosts container and adds // them to the clients storage. It is intended to be used as a goroutine. func (s *Storage) handleHostsUpdates() { + if s.etcHosts == nil { + return + } + defer log.OnPanic("storage") for upd := range s.etcHosts.Upd() { @@ -425,9 +432,9 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { return nil } + // TODO(s.chzhen): Update runtime index. rc = NewRuntime(ip) rc.SetInfo(SourceDHCP, []string{host}) - s.UpdateRuntime(rc) return rc } diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 34a35636..1f2dd022 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -3,15 +3,422 @@ package client_test import ( "net" "net/netip" + "runtime" + "sync" "testing" + "time" + "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/client" + "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/whois" + "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type testHostsContainer struct { + onUpd func() (updates <-chan *hostsfile.DefaultStorage) +} + +// type check +var _ client.HostsContainer = (*testHostsContainer)(nil) + +func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) { + return c.onUpd() +} + +// Interface stores and refreshes the network neighborhood reported by ARP +// (Address Resolution Protocol). +type Interface interface { + // Refresh updates the stored data. It must be safe for concurrent use. + Refresh() (err error) + + // Neighbors returnes the last set of data reported by ARP. Both the method + // and it's result must be safe for concurrent use. + Neighbors() (ns []arpdb.Neighbor) +} + +type testARP struct { + onRefresh func() (err error) + + onNeighbors func() (ns []arpdb.Neighbor) +} + +func (c *testARP) Refresh() (err error) { + return c.onRefresh() +} + +func (c *testARP) Neighbors() (ns []arpdb.Neighbor) { + return c.onNeighbors() +} + +type testDHCP struct { + OnLeases func() (leases []*dhcpsvc.Lease) + OnHostBy func(ip netip.Addr) (host string) + OnMACBy func(ip netip.Addr) (mac net.HardwareAddr) +} + +// Lease implements the [DHCP] interface for testDHCP. +func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() } + +// HostByIP implements the [DHCP] interface for testDHCP. +func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) } + +// MACByIP implements the [DHCP] interface for testDHCP. +func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) } + +// compareRuntimeInfo is a helper function that returns true if the runtime +// client has provided info. +func compareRuntimeInfo(rc *client.Runtime, src client.Source, host string) (ok bool) { + s, h := rc.Info() + if s != src { + return false + } else if h != host { + return false + } + + return true +} + +func TestStorage_Add_hostsfile(t *testing.T) { + var ( + cliIP1 = netip.MustParseAddr("1.1.1.1") + cliName1 = "client_one" + + cliIP2 = netip.MustParseAddr("2.2.2.2") + cliName2 = "client_two" + ) + + hostCh := make(chan *hostsfile.DefaultStorage) + h := &testHostsContainer{ + onUpd: func() (updates <-chan *hostsfile.DefaultStorage) { return hostCh }, + } + + storage, err := client.NewStorage(&client.Config{ + EtcHosts: h, + }) + require.NoError(t, err) + + storage.Start() + + t.Run("add_hosts", func(t *testing.T) { + var s *hostsfile.DefaultStorage + s, err = hostsfile.NewDefaultStorage() + require.NoError(t, err) + + s.Add(&hostsfile.Record{ + Addr: cliIP1, + Names: []string{cliName1}, + }) + + testutil.RequireSend(t, hostCh, s, testTimeout) + + require.Eventually(t, func() (ok bool) { + cli1 := storage.ClientRuntime(cliIP1) + if cli1 == nil { + return false + } + + assert.True(t, compareRuntimeInfo(cli1, client.SourceHostsFile, cliName1)) + + return true + }, testTimeout, testTimeout/10) + }) + + t.Run("update_hosts", func(t *testing.T) { + var s *hostsfile.DefaultStorage + s, err = hostsfile.NewDefaultStorage() + require.NoError(t, err) + + s.Add(&hostsfile.Record{ + Addr: cliIP2, + Names: []string{cliName2}, + }) + + testutil.RequireSend(t, hostCh, s, testTimeout) + + require.Eventually(t, func() (ok bool) { + cli2 := storage.ClientRuntime(cliIP2) + if cli2 == nil { + return false + } + + assert.True(t, compareRuntimeInfo(cli2, client.SourceHostsFile, cliName2)) + + cli1 := storage.ClientRuntime(cliIP1) + require.Nil(t, cli1) + + return true + }, testTimeout, testTimeout/10) + }) +} + +func TestStorage_Add_arp(t *testing.T) { + var ( + mu sync.Mutex + neighbors []arpdb.Neighbor + + cliIP1 = netip.MustParseAddr("1.1.1.1") + cliName1 = "client_one" + + cliIP2 = netip.MustParseAddr("2.2.2.2") + cliName2 = "client_two" + ) + + a := &testARP{ + onRefresh: func() (err error) { return nil }, + onNeighbors: func() (ns []arpdb.Neighbor) { + mu.Lock() + defer mu.Unlock() + + return neighbors + }, + } + + storage, err := client.NewStorage(&client.Config{ + ARPDB: a, + ARPClientsUpdatePeriod: testTimeout / 10, + }) + require.NoError(t, err) + + storage.Start() + + t.Run("add_hosts", func(t *testing.T) { + func() { + mu.Lock() + defer mu.Unlock() + + neighbors = []arpdb.Neighbor{{ + Name: cliName1, + IP: cliIP1, + }} + }() + + require.Eventually(t, func() (ok bool) { + cli1 := storage.ClientRuntime(cliIP1) + if cli1 == nil { + return false + } + + assert.True(t, compareRuntimeInfo(cli1, client.SourceARP, cliName1)) + + return true + }, testTimeout, testTimeout/10) + }) + + t.Run("update_hosts", func(t *testing.T) { + func() { + mu.Lock() + defer mu.Unlock() + + neighbors = []arpdb.Neighbor{{ + Name: cliName2, + IP: cliIP2, + }} + }() + + require.Eventually(t, func() (ok bool) { + cli2 := storage.ClientRuntime(cliIP2) + if cli2 == nil { + return false + } + + assert.True(t, compareRuntimeInfo(cli2, client.SourceARP, cliName2)) + + cli1 := storage.ClientRuntime(cliIP1) + require.Nil(t, cli1) + + return true + }, testTimeout, testTimeout/10) + }) +} + +func TestStorage_Add_whois(t *testing.T) { + var ( + cliIP1 = netip.MustParseAddr("1.1.1.1") + + cliIP2 = netip.MustParseAddr("2.2.2.2") + cliName2 = "client_two" + + cliIP3 = netip.MustParseAddr("3.3.3.3") + cliName3 = "client_three" + ) + + storage, err := client.NewStorage(&client.Config{}) + require.NoError(t, err) + + whois := &whois.Info{ + Country: "AU", + Orgname: "Example Org", + } + + t.Run("new_client", func(t *testing.T) { + storage.UpdateAddress(cliIP1, "", whois) + cli1 := storage.ClientRuntime(cliIP1) + require.NotNil(t, cli1) + + assert.Equal(t, whois, cli1.WHOIS()) + }) + + t.Run("existing_runtime_client", func(t *testing.T) { + storage.UpdateAddress(cliIP2, cliName2, nil) + storage.UpdateAddress(cliIP2, "", whois) + + cli2 := storage.ClientRuntime(cliIP2) + require.NotNil(t, cli2) + + assert.True(t, compareRuntimeInfo(cli2, client.SourceRDNS, cliName2)) + + assert.Equal(t, whois, cli2.WHOIS()) + }) + + t.Run("can't_set_persistent_client", func(t *testing.T) { + err = storage.Add(&client.Persistent{ + Name: cliName3, + UID: client.MustNewUID(), + IPs: []netip.Addr{cliIP3}, + }) + require.NoError(t, err) + + storage.UpdateAddress(cliIP3, "", whois) + rc := storage.ClientRuntime(cliIP3) + require.Nil(t, rc) + }) +} + +func TestClientsDHCP(t *testing.T) { + var ( + cliIP1 = netip.MustParseAddr("1.1.1.1") + cliName1 = "client_one" + + prsCliIP = netip.MustParseAddr("4.3.2.1") + prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA") + prsCliName = "persitent_client" + ) + + ipToHost := map[netip.Addr]string{ + cliIP1: cliName1, + } + ipToMAC := map[netip.Addr]net.HardwareAddr{ + prsCliIP: prsCliMAC, + } + + d := &testDHCP{ + OnLeases: func() (leases []*dhcpsvc.Lease) { panic("not implemented") }, + OnHostBy: func(ip netip.Addr) (host string) { + return ipToHost[ip] + }, + OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { + return ipToMAC[ip] + }, + } + + storage, err := client.NewStorage(&client.Config{ + DHCP: d, + }) + require.NoError(t, err) + + t.Run("find_runtime", func(t *testing.T) { + cli1 := storage.ClientRuntime(cliIP1) + require.NotNil(t, cli1) + + assert.True(t, compareRuntimeInfo(cli1, client.SourceDHCP, cliName1)) + }) + + t.Run("find_persistent", func(t *testing.T) { + err = storage.Add(&client.Persistent{ + Name: prsCliName, + UID: client.MustNewUID(), + MACs: []net.HardwareAddr{prsCliMAC}, + }) + require.NoError(t, err) + + prsCli, ok := storage.Find(prsCliIP.String()) + require.True(t, ok) + + assert.Equal(t, prsCliName, prsCli.Name) + }) +} + +func TestClientsAddExisting(t *testing.T) { + // First, init a DHCP server with a single static lease. + config := &dhcpd.ServerConfig{ + Enabled: true, + DataDir: t.TempDir(), + Conf4: dhcpd.V4ServerConf{ + Enabled: true, + GatewayIP: netip.MustParseAddr("1.2.3.1"), + SubnetMask: netip.MustParseAddr("255.255.255.0"), + RangeStart: netip.MustParseAddr("1.2.3.2"), + RangeEnd: netip.MustParseAddr("1.2.3.10"), + }, + } + + dhcpServer, err := dhcpd.Create(config) + require.NoError(t, err) + + storage, err := client.NewStorage(&client.Config{ + DHCP: dhcpServer, + }) + require.NoError(t, err) + + t.Run("simple", func(t *testing.T) { + ip := netip.MustParseAddr("1.1.1.1") + + // Add a client. + err = storage.Add(&client.Persistent{ + Name: "client1", + UID: client.MustNewUID(), + IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, + Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")}, + MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, + }) + require.NoError(t, err) + + // Now add an auto-client with the same IP. + storage.UpdateAddress(ip, "test", nil) + rc := storage.ClientRuntime(ip) + assert.True(t, compareRuntimeInfo(rc, client.SourceRDNS, "test")) + }) + + t.Run("complicated", func(t *testing.T) { + // TODO(a.garipov): Properly decouple the DHCP server from the client + // storage. + if runtime.GOOS == "windows" { + t.Skip("skipping dhcp test on windows") + } + + ip := netip.MustParseAddr("1.2.3.4") + + err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{ + HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, + IP: ip, + Hostname: "testhost", + Expiry: time.Now().Add(time.Hour), + }) + require.NoError(t, err) + + // Add a new client with the same IP as for a client with MAC. + err = storage.Add(&client.Persistent{ + Name: "client2", + UID: client.MustNewUID(), + IPs: []netip.Addr{ip}, + }) + require.NoError(t, err) + + // Add a new client with the IP from the first client's IP range. + err = storage.Add(&client.Persistent{ + Name: "client3", + UID: client.MustNewUID(), + IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, + }) + require.NoError(t, err) + }) +} + // newStorage is a helper function that returns a client storage filled with // persistent clients from the m. It also generates a UID for each client. func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {