diff --git a/internal/client/client.go b/internal/client/client.go index d3ead923..780415e6 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -8,6 +8,7 @@ import ( "encoding" "fmt" "net/netip" + "slices" "github.com/AdguardTeam/AdGuardHome/internal/whois" ) @@ -175,3 +176,15 @@ func (r *Runtime) isEmpty() (ok bool) { func (r *Runtime) Addr() (ip netip.Addr) { return r.ip } + +// Clone returns a deep copy of the runtime client. +func (r *Runtime) Clone() (c *Runtime) { + return &Runtime{ + ip: r.ip, + whois: r.whois.Clone(), + arp: slices.Clone(r.arp), + rdns: slices.Clone(r.rdns), + dhcp: slices.Clone(r.dhcp), + hostsFile: slices.Clone(r.hostsFile), + } +} diff --git a/internal/client/storage.go b/internal/client/storage.go index 2053bdf9..66601812 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "net/netip" + "slices" "sync" "github.com/AdguardTeam/golibs/container" @@ -31,8 +32,6 @@ type Storage struct { index *index // runtimeIndex contains information about runtime clients. - // - // TODO(s.chzhen): Use it. runtimeIndex *RuntimeIndex } @@ -236,20 +235,68 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { return s.runtimeIndex.Client(ip) } -// AddRuntime saves the runtime client information in the storage. IP address -// of a client must be unique. rc must not be nil. -// -// TODO(s.chzhen): Use it. -func (s *Storage) AddRuntime(rc *Runtime) { +// UpdateRuntime updates the stored runtime client with information from rc. If +// no such client exists, saves the copy of rc in storage. rc must not be nil. +func (s *Storage) UpdateRuntime(rc *Runtime) { s.mu.Lock() defer s.mu.Unlock() - s.runtimeIndex.Add(rc) + s.updateRuntimeLocked(rc) +} + +// updateRuntimeLocked updates the stored runtime client with information from +// rc. rc must not be nil. Storage.mu is expected to be locked. +func (s *Storage) updateRuntimeLocked(rc *Runtime) { + stored := s.runtimeIndex.Client(rc.ip) + if stored == nil { + s.runtimeIndex.Add(rc.Clone()) + + return + } + + if rc.whois != nil { + stored.whois = rc.whois.Clone() + } + + if rc.arp != nil { + stored.arp = slices.Clone(rc.arp) + } + + if rc.rdns != nil { + stored.rdns = slices.Clone(rc.rdns) + } + + if rc.dhcp != nil { + stored.dhcp = slices.Clone(rc.dhcp) + } + + if rc.hostsFile != nil { + stored.hostsFile = slices.Clone(rc.hostsFile) + } +} + +// BatchUpdateBySource updates the stored runtime clients information from the +// specified source. +func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) { + s.mu.Lock() + defer s.mu.Unlock() + + for _, rc := range s.runtimeIndex.index { + rc.unset(src) + } + + for _, rc := range rcs { + s.updateRuntimeLocked(rc) + } + + for ip, rc := range s.runtimeIndex.index { + if rc.isEmpty() { + delete(s.runtimeIndex.index, ip) + } + } } // SizeRuntime returns the number of the runtime clients. -// -// TODO(s.chzhen): Use it. func (s *Storage) SizeRuntime() (n int) { s.mu.Lock() defer s.mu.Unlock() @@ -258,8 +305,6 @@ func (s *Storage) SizeRuntime() (n int) { } // RangeRuntime calls f for each runtime client in an undefined order. -// -// TODO(s.chzhen): Use it. func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { s.mu.Lock() defer s.mu.Unlock() @@ -267,16 +312,6 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { s.runtimeIndex.Range(f) } -// DeleteRuntime removes the runtime client by ip. -// -// TODO(s.chzhen): Use it. -func (s *Storage) DeleteRuntime(ip netip.Addr) { - s.mu.Lock() - defer s.mu.Unlock() - - s.runtimeIndex.Delete(ip) -} - // DeleteBySource removes all runtime clients that have information only from // the specified source and returns the number of removed clients. // diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index abfc6d62..3e966575 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -479,3 +479,43 @@ func TestStorage_RangeByName(t *testing.T) { }) } } + +func TestStorage_UpdateRuntime(t *testing.T) { + const ( + addedARP = "added_arp" + + updatedARP = "updated_arp" + updatedHostsFile = "updated_hosts" + ) + + ip := netip.MustParseAddr("1.1.1.1") + + added := client.NewRuntime(ip) + added.SetInfo(client.SourceARP, []string{addedARP}) + + updated := client.NewRuntime(ip) + updated.SetInfo(client.SourceARP, []string{updatedARP}) + updated.SetInfo(client.SourceHostsFile, []string{updatedHostsFile}) + + s := newStorage(t, nil) + + s.UpdateRuntime(added) + got := s.ClientRuntime(ip) + source, host := got.Info() + assert.Equal(t, client.SourceARP, source) + assert.Equal(t, addedARP, host) + + s.UpdateRuntime(updated) + got = s.ClientRuntime(ip) + source, host = got.Info() + assert.Equal(t, client.SourceHostsFile, source) + assert.Equal(t, updatedHostsFile, host) + + n := s.DeleteBySource(client.SourceHostsFile) + require.Equal(t, 0, n) + + got = s.ClientRuntime(ip) + source, host = got.Info() + assert.Equal(t, client.SourceARP, source) + assert.Equal(t, updatedARP, host) +} diff --git a/internal/home/clients.go b/internal/home/clients.go index aee32f92..a94d5ec1 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -47,9 +47,6 @@ type clientsContainer struct { // storage stores information about persistent clients. storage *client.Storage - // runtimeIndex stores information about runtime clients. - runtimeIndex *client.RuntimeIndex - // dhcp is the DHCP service implementation. dhcp DHCP @@ -105,8 +102,6 @@ func (clients *clientsContainer) Init( return errors.Error("clients container already initialized") } - clients.runtimeIndex = client.NewRuntimeIndex() - clients.storage = client.NewStorage(&client.Config{ AllowedTags: clientTags, }) @@ -358,7 +353,7 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) return client.SourcePersistent } - rc := clients.runtimeIndex.Client(ip) + rc := clients.storage.ClientRuntime(ip) if rc != nil { src, _ = rc.Info() } @@ -539,22 +534,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, return clients.storage.FindByMAC(foundMAC) } -// runtimeClient returns a runtime client from internal index. Note that it -// doesn't include DHCP clients. -func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) { - if ip == (netip.Addr{}) { - return nil - } - - clients.lock.Lock() - defer clients.lock.Unlock() - - return clients.runtimeIndex.Client(ip) -} - // findRuntimeClient finds a runtime client by their IP. func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) { - rc = clients.runtimeClient(ip) + rc = clients.storage.ClientRuntime(ip) host := clients.dhcp.HostByIP(ip) if host != "" { @@ -580,20 +562,11 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) { return } - rc := clients.runtimeIndex.Client(ip) - if rc == nil { - // Create a RuntimeClient implicitly so that we don't do this check - // again. - rc = client.NewRuntime(ip) - clients.runtimeIndex.Add(rc) - - log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) - } else { - host, _ := rc.Info() - log.Debug("clients: set whois info for runtime client %s: %+v", host, wi) - } - + rc := client.NewRuntime(ip) rc.SetWHOIS(wi) + clients.storage.UpdateRuntime(rc) + + log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) } // addHost adds a new IP-hostname pairing. The priorities of the sources are @@ -644,26 +617,20 @@ func (clients *clientsContainer) addHostLocked( host string, src client.Source, ) (ok bool) { - rc := clients.runtimeIndex.Client(ip) - if rc == nil { - if src < client.SourceDHCP { - if clients.dhcp.HostByIP(ip) != "" { - return false - } - } - - rc = client.NewRuntime(ip) - clients.runtimeIndex.Add(rc) + rc := client.NewRuntime(ip) + rc.SetInfo(src, []string{host}) + if dhcpHost := clients.dhcp.HostByIP(ip); dhcpHost != "" { + rc.SetInfo(client.SourceDHCP, []string{dhcpHost}) } - rc.SetInfo(src, []string{host}) + clients.storage.UpdateRuntime(rc) log.Debug( "clients: adding client info %s -> %q %q [%d]", ip, src, host, - clients.runtimeIndex.Size(), + clients.storage.SizeRuntime(), ) return true @@ -675,22 +642,24 @@ func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorag clients.lock.Lock() defer clients.lock.Unlock() - deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile) - log.Debug("clients: removed %d client aliases from system hosts file", deleted) - added := 0 + rcs := []*client.Runtime{} hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { // Only the first name of the first record is considered a canonical // hostname for the IP address. // // TODO(e.burkov): Consider using all the names from all the records. - if clients.addHostLocked(addr, names[0], client.SourceHostsFile) { - added++ - } + rc := client.NewRuntime(addr) + rc.SetInfo(client.SourceHostsFile, []string{names[0]}) + + added++ + rcs = append(rcs, rc) return true }) + clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs) + log.Debug("clients: added %d client aliases from system hosts file", added) } @@ -715,16 +684,18 @@ func (clients *clientsContainer) addFromSystemARP() { clients.lock.Lock() defer clients.lock.Unlock() - deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP) - log.Debug("clients: removed %d client aliases from arp neighborhood", deleted) - added := 0 + rcs := []*client.Runtime{} for _, n := range ns { - if clients.addHostLocked(n.IP, n.Name, client.SourceARP) { - added++ - } + rc := client.NewRuntime(n.IP) + rc.SetInfo(client.SourceARP, []string{n.Name}) + + added++ + rcs = append(rcs, rc) } + clients.storage.BatchUpdateBySource(client.SourceARP, rcs) + log.Debug("clients: added %d client aliases from arp neighborhood", added) } diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 2c90a1e0..f47676f0 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -240,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) { t.Run("new_client", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.255") clients.setWHOISInfo(ip, whois) - rc := clients.runtimeIndex.Client(ip) + rc := clients.storage.ClientRuntime(ip) require.NotNil(t, rc) assert.Equal(t, whois, rc.WHOIS()) @@ -252,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) { assert.True(t, ok) clients.setWHOISInfo(ip, whois) - rc := clients.runtimeIndex.Client(ip) + rc := clients.storage.ClientRuntime(ip) require.NotNil(t, rc) assert.Equal(t, whois, rc.WHOIS()) @@ -269,7 +269,7 @@ func TestClientsWHOIS(t *testing.T) { require.NoError(t, err) clients.setWHOISInfo(ip, whois) - rc := clients.runtimeIndex.Client(ip) + rc := clients.storage.ClientRuntime(ip) require.Nil(t, rc) assert.True(t, clients.storage.RemoveByName("client1")) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index aae8c34a..eba22ddb 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -103,7 +103,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http return true }) - clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) { + clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) { src, host := rc.Info() cj := runtimeClientJSON{ WHOIS: whoisOrEmpty(rc), diff --git a/internal/whois/whois.go b/internal/whois/whois.go index 10f0609b..3f48894e 100644 --- a/internal/whois/whois.go +++ b/internal/whois/whois.go @@ -354,6 +354,19 @@ type Info struct { Orgname string `json:"orgname,omitempty"` } +// Clone returns a deep copy of the WHOIS info. +func (i *Info) Clone() (c *Info) { + if i == nil { + return nil + } + + return &Info{ + City: i.City, + Country: i.Country, + Orgname: i.Orgname, + } +} + // cacheItem represents an item that we will store in the cache. type cacheItem struct { // expiry is the time when cacheItem will expire.