all: client runtime storage

This commit is contained in:
Stanislav Chzhen 2024-08-05 15:12:05 +03:00
parent b6ed769652
commit 6489996878
7 changed files with 155 additions and 83 deletions

View File

@ -8,6 +8,7 @@ import (
"encoding" "encoding"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/AdGuardHome/internal/whois"
) )
@ -175,3 +176,15 @@ func (r *Runtime) isEmpty() (ok bool) {
func (r *Runtime) Addr() (ip netip.Addr) { func (r *Runtime) Addr() (ip netip.Addr) {
return r.ip return r.ip
} }
// Clone returns a deep copy of the runtime client.
func (r *Runtime) Clone() (c *Runtime) {
return &Runtime{
ip: r.ip,
whois: r.whois.Clone(),
arp: slices.Clone(r.arp),
rdns: slices.Clone(r.rdns),
dhcp: slices.Clone(r.dhcp),
hostsFile: slices.Clone(r.hostsFile),
}
}

View File

@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"slices"
"sync" "sync"
"github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/container"
@ -31,8 +32,6 @@ type Storage struct {
index *index index *index
// runtimeIndex contains information about runtime clients. // runtimeIndex contains information about runtime clients.
//
// TODO(s.chzhen): Use it.
runtimeIndex *RuntimeIndex runtimeIndex *RuntimeIndex
} }
@ -236,20 +235,68 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
return s.runtimeIndex.Client(ip) return s.runtimeIndex.Client(ip)
} }
// AddRuntime saves the runtime client information in the storage. IP address // UpdateRuntime updates the stored runtime client with information from rc. If
// of a client must be unique. rc must not be nil. // no such client exists, saves the copy of rc in storage. rc must not be nil.
// func (s *Storage) UpdateRuntime(rc *Runtime) {
// TODO(s.chzhen): Use it.
func (s *Storage) AddRuntime(rc *Runtime) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.runtimeIndex.Add(rc) s.updateRuntimeLocked(rc)
}
// updateRuntimeLocked updates the stored runtime client with information from
// rc. rc must not be nil. Storage.mu is expected to be locked.
func (s *Storage) updateRuntimeLocked(rc *Runtime) {
stored := s.runtimeIndex.Client(rc.ip)
if stored == nil {
s.runtimeIndex.Add(rc.Clone())
return
}
if rc.whois != nil {
stored.whois = rc.whois.Clone()
}
if rc.arp != nil {
stored.arp = slices.Clone(rc.arp)
}
if rc.rdns != nil {
stored.rdns = slices.Clone(rc.rdns)
}
if rc.dhcp != nil {
stored.dhcp = slices.Clone(rc.dhcp)
}
if rc.hostsFile != nil {
stored.hostsFile = slices.Clone(rc.hostsFile)
}
}
// BatchUpdateBySource updates the stored runtime clients information from the
// specified source.
func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) {
s.mu.Lock()
defer s.mu.Unlock()
for _, rc := range s.runtimeIndex.index {
rc.unset(src)
}
for _, rc := range rcs {
s.updateRuntimeLocked(rc)
}
for ip, rc := range s.runtimeIndex.index {
if rc.isEmpty() {
delete(s.runtimeIndex.index, ip)
}
}
} }
// SizeRuntime returns the number of the runtime clients. // SizeRuntime returns the number of the runtime clients.
//
// TODO(s.chzhen): Use it.
func (s *Storage) SizeRuntime() (n int) { func (s *Storage) SizeRuntime() (n int) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -258,8 +305,6 @@ func (s *Storage) SizeRuntime() (n int) {
} }
// RangeRuntime calls f for each runtime client in an undefined order. // RangeRuntime calls f for each runtime client in an undefined order.
//
// TODO(s.chzhen): Use it.
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -267,16 +312,6 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
s.runtimeIndex.Range(f) s.runtimeIndex.Range(f)
} }
// DeleteRuntime removes the runtime client by ip.
//
// TODO(s.chzhen): Use it.
func (s *Storage) DeleteRuntime(ip netip.Addr) {
s.mu.Lock()
defer s.mu.Unlock()
s.runtimeIndex.Delete(ip)
}
// DeleteBySource removes all runtime clients that have information only from // DeleteBySource removes all runtime clients that have information only from
// the specified source and returns the number of removed clients. // the specified source and returns the number of removed clients.
// //

View File

@ -479,3 +479,43 @@ func TestStorage_RangeByName(t *testing.T) {
}) })
} }
} }
func TestStorage_UpdateRuntime(t *testing.T) {
const (
addedARP = "added_arp"
updatedARP = "updated_arp"
updatedHostsFile = "updated_hosts"
)
ip := netip.MustParseAddr("1.1.1.1")
added := client.NewRuntime(ip)
added.SetInfo(client.SourceARP, []string{addedARP})
updated := client.NewRuntime(ip)
updated.SetInfo(client.SourceARP, []string{updatedARP})
updated.SetInfo(client.SourceHostsFile, []string{updatedHostsFile})
s := newStorage(t, nil)
s.UpdateRuntime(added)
got := s.ClientRuntime(ip)
source, host := got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, addedARP, host)
s.UpdateRuntime(updated)
got = s.ClientRuntime(ip)
source, host = got.Info()
assert.Equal(t, client.SourceHostsFile, source)
assert.Equal(t, updatedHostsFile, host)
n := s.DeleteBySource(client.SourceHostsFile)
require.Equal(t, 0, n)
got = s.ClientRuntime(ip)
source, host = got.Info()
assert.Equal(t, client.SourceARP, source)
assert.Equal(t, updatedARP, host)
}

View File

@ -47,9 +47,6 @@ type clientsContainer struct {
// storage stores information about persistent clients. // storage stores information about persistent clients.
storage *client.Storage storage *client.Storage
// runtimeIndex stores information about runtime clients.
runtimeIndex *client.RuntimeIndex
// dhcp is the DHCP service implementation. // dhcp is the DHCP service implementation.
dhcp DHCP dhcp DHCP
@ -105,8 +102,6 @@ func (clients *clientsContainer) Init(
return errors.Error("clients container already initialized") return errors.Error("clients container already initialized")
} }
clients.runtimeIndex = client.NewRuntimeIndex()
clients.storage = client.NewStorage(&client.Config{ clients.storage = client.NewStorage(&client.Config{
AllowedTags: clientTags, AllowedTags: clientTags,
}) })
@ -358,7 +353,7 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
return client.SourcePersistent return client.SourcePersistent
} }
rc := clients.runtimeIndex.Client(ip) rc := clients.storage.ClientRuntime(ip)
if rc != nil { if rc != nil {
src, _ = rc.Info() src, _ = rc.Info()
} }
@ -539,22 +534,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
return clients.storage.FindByMAC(foundMAC) return clients.storage.FindByMAC(foundMAC)
} }
// runtimeClient returns a runtime client from internal index. Note that it
// doesn't include DHCP clients.
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
if ip == (netip.Addr{}) {
return nil
}
clients.lock.Lock()
defer clients.lock.Unlock()
return clients.runtimeIndex.Client(ip)
}
// findRuntimeClient finds a runtime client by their IP. // findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) { func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
rc = clients.runtimeClient(ip) rc = clients.storage.ClientRuntime(ip)
host := clients.dhcp.HostByIP(ip) host := clients.dhcp.HostByIP(ip)
if host != "" { if host != "" {
@ -580,20 +562,11 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
return return
} }
rc := clients.runtimeIndex.Client(ip) rc := client.NewRuntime(ip)
if rc == nil { rc.SetWHOIS(wi)
// Create a RuntimeClient implicitly so that we don't do this check clients.storage.UpdateRuntime(rc)
// again.
rc = client.NewRuntime(ip)
clients.runtimeIndex.Add(rc)
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
} else {
host, _ := rc.Info()
log.Debug("clients: set whois info for runtime client %s: %+v", host, wi)
}
rc.SetWHOIS(wi)
} }
// addHost adds a new IP-hostname pairing. The priorities of the sources are // addHost adds a new IP-hostname pairing. The priorities of the sources are
@ -644,26 +617,20 @@ func (clients *clientsContainer) addHostLocked(
host string, host string,
src client.Source, src client.Source,
) (ok bool) { ) (ok bool) {
rc := clients.runtimeIndex.Client(ip) rc := client.NewRuntime(ip)
if rc == nil {
if src < client.SourceDHCP {
if clients.dhcp.HostByIP(ip) != "" {
return false
}
}
rc = client.NewRuntime(ip)
clients.runtimeIndex.Add(rc)
}
rc.SetInfo(src, []string{host}) rc.SetInfo(src, []string{host})
if dhcpHost := clients.dhcp.HostByIP(ip); dhcpHost != "" {
rc.SetInfo(client.SourceDHCP, []string{dhcpHost})
}
clients.storage.UpdateRuntime(rc)
log.Debug( log.Debug(
"clients: adding client info %s -> %q %q [%d]", "clients: adding client info %s -> %q %q [%d]",
ip, ip,
src, src,
host, host,
clients.runtimeIndex.Size(), clients.storage.SizeRuntime(),
) )
return true return true
@ -675,22 +642,24 @@ func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorag
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
log.Debug("clients: removed %d client aliases from system hosts file", deleted)
added := 0 added := 0
rcs := []*client.Runtime{}
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
// Only the first name of the first record is considered a canonical // Only the first name of the first record is considered a canonical
// hostname for the IP address. // hostname for the IP address.
// //
// TODO(e.burkov): Consider using all the names from all the records. // TODO(e.burkov): Consider using all the names from all the records.
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) { rc := client.NewRuntime(addr)
rc.SetInfo(client.SourceHostsFile, []string{names[0]})
added++ added++
} rcs = append(rcs, rc)
return true return true
}) })
clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs)
log.Debug("clients: added %d client aliases from system hosts file", added) log.Debug("clients: added %d client aliases from system hosts file", added)
} }
@ -715,15 +684,17 @@ func (clients *clientsContainer) addFromSystemARP() {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)
added := 0 added := 0
rcs := []*client.Runtime{}
for _, n := range ns { for _, n := range ns {
if clients.addHostLocked(n.IP, n.Name, client.SourceARP) { rc := client.NewRuntime(n.IP)
rc.SetInfo(client.SourceARP, []string{n.Name})
added++ added++
rcs = append(rcs, rc)
} }
}
clients.storage.BatchUpdateBySource(client.SourceARP, rcs)
log.Debug("clients: added %d client aliases from arp neighborhood", added) log.Debug("clients: added %d client aliases from arp neighborhood", added)
} }

View File

@ -240,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
t.Run("new_client", func(t *testing.T) { t.Run("new_client", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.255") ip := netip.MustParseAddr("1.1.1.255")
clients.setWHOISInfo(ip, whois) clients.setWHOISInfo(ip, whois)
rc := clients.runtimeIndex.Client(ip) rc := clients.storage.ClientRuntime(ip)
require.NotNil(t, rc) require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS()) assert.Equal(t, whois, rc.WHOIS())
@ -252,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) {
assert.True(t, ok) assert.True(t, ok)
clients.setWHOISInfo(ip, whois) clients.setWHOISInfo(ip, whois)
rc := clients.runtimeIndex.Client(ip) rc := clients.storage.ClientRuntime(ip)
require.NotNil(t, rc) require.NotNil(t, rc)
assert.Equal(t, whois, rc.WHOIS()) assert.Equal(t, whois, rc.WHOIS())
@ -269,7 +269,7 @@ func TestClientsWHOIS(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
clients.setWHOISInfo(ip, whois) clients.setWHOISInfo(ip, whois)
rc := clients.runtimeIndex.Client(ip) rc := clients.storage.ClientRuntime(ip)
require.Nil(t, rc) require.Nil(t, rc)
assert.True(t, clients.storage.RemoveByName("client1")) assert.True(t, clients.storage.RemoveByName("client1"))

View File

@ -103,7 +103,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
return true return true
}) })
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) { clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
src, host := rc.Info() src, host := rc.Info()
cj := runtimeClientJSON{ cj := runtimeClientJSON{
WHOIS: whoisOrEmpty(rc), WHOIS: whoisOrEmpty(rc),

View File

@ -354,6 +354,19 @@ type Info struct {
Orgname string `json:"orgname,omitempty"` Orgname string `json:"orgname,omitempty"`
} }
// Clone returns a deep copy of the WHOIS info.
func (i *Info) Clone() (c *Info) {
if i == nil {
return nil
}
return &Info{
City: i.City,
Country: i.Country,
Orgname: i.Orgname,
}
}
// cacheItem represents an item that we will store in the cache. // cacheItem represents an item that we will store in the cache.
type cacheItem struct { type cacheItem struct {
// expiry is the time when cacheItem will expire. // expiry is the time when cacheItem will expire.