all: client runtime storage
This commit is contained in:
parent
b6ed769652
commit
6489996878
|
@ -8,6 +8,7 @@ import (
|
|||
"encoding"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
)
|
||||
|
@ -175,3 +176,15 @@ func (r *Runtime) isEmpty() (ok bool) {
|
|||
func (r *Runtime) Addr() (ip netip.Addr) {
|
||||
return r.ip
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of the runtime client.
|
||||
func (r *Runtime) Clone() (c *Runtime) {
|
||||
return &Runtime{
|
||||
ip: r.ip,
|
||||
whois: r.whois.Clone(),
|
||||
arp: slices.Clone(r.arp),
|
||||
rdns: slices.Clone(r.rdns),
|
||||
dhcp: slices.Clone(r.dhcp),
|
||||
hostsFile: slices.Clone(r.hostsFile),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
|
@ -31,8 +32,6 @@ type Storage struct {
|
|||
index *index
|
||||
|
||||
// runtimeIndex contains information about runtime clients.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
runtimeIndex *RuntimeIndex
|
||||
}
|
||||
|
||||
|
@ -236,20 +235,68 @@ func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
|||
return s.runtimeIndex.Client(ip)
|
||||
}
|
||||
|
||||
// AddRuntime saves the runtime client information in the storage. IP address
|
||||
// of a client must be unique. rc must not be nil.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) AddRuntime(rc *Runtime) {
|
||||
// UpdateRuntime updates the stored runtime client with information from rc. If
|
||||
// no such client exists, saves the copy of rc in storage. rc must not be nil.
|
||||
func (s *Storage) UpdateRuntime(rc *Runtime) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.runtimeIndex.Add(rc)
|
||||
s.updateRuntimeLocked(rc)
|
||||
}
|
||||
|
||||
// updateRuntimeLocked updates the stored runtime client with information from
|
||||
// rc. rc must not be nil. Storage.mu is expected to be locked.
|
||||
func (s *Storage) updateRuntimeLocked(rc *Runtime) {
|
||||
stored := s.runtimeIndex.Client(rc.ip)
|
||||
if stored == nil {
|
||||
s.runtimeIndex.Add(rc.Clone())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if rc.whois != nil {
|
||||
stored.whois = rc.whois.Clone()
|
||||
}
|
||||
|
||||
if rc.arp != nil {
|
||||
stored.arp = slices.Clone(rc.arp)
|
||||
}
|
||||
|
||||
if rc.rdns != nil {
|
||||
stored.rdns = slices.Clone(rc.rdns)
|
||||
}
|
||||
|
||||
if rc.dhcp != nil {
|
||||
stored.dhcp = slices.Clone(rc.dhcp)
|
||||
}
|
||||
|
||||
if rc.hostsFile != nil {
|
||||
stored.hostsFile = slices.Clone(rc.hostsFile)
|
||||
}
|
||||
}
|
||||
|
||||
// BatchUpdateBySource updates the stored runtime clients information from the
|
||||
// specified source.
|
||||
func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, rc := range s.runtimeIndex.index {
|
||||
rc.unset(src)
|
||||
}
|
||||
|
||||
for _, rc := range rcs {
|
||||
s.updateRuntimeLocked(rc)
|
||||
}
|
||||
|
||||
for ip, rc := range s.runtimeIndex.index {
|
||||
if rc.isEmpty() {
|
||||
delete(s.runtimeIndex.index, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SizeRuntime returns the number of the runtime clients.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) SizeRuntime() (n int) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
@ -258,8 +305,6 @@ func (s *Storage) SizeRuntime() (n int) {
|
|||
}
|
||||
|
||||
// RangeRuntime calls f for each runtime client in an undefined order.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
@ -267,16 +312,6 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
|||
s.runtimeIndex.Range(f)
|
||||
}
|
||||
|
||||
// DeleteRuntime removes the runtime client by ip.
|
||||
//
|
||||
// TODO(s.chzhen): Use it.
|
||||
func (s *Storage) DeleteRuntime(ip netip.Addr) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.runtimeIndex.Delete(ip)
|
||||
}
|
||||
|
||||
// DeleteBySource removes all runtime clients that have information only from
|
||||
// the specified source and returns the number of removed clients.
|
||||
//
|
||||
|
|
|
@ -479,3 +479,43 @@ func TestStorage_RangeByName(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_UpdateRuntime(t *testing.T) {
|
||||
const (
|
||||
addedARP = "added_arp"
|
||||
|
||||
updatedARP = "updated_arp"
|
||||
updatedHostsFile = "updated_hosts"
|
||||
)
|
||||
|
||||
ip := netip.MustParseAddr("1.1.1.1")
|
||||
|
||||
added := client.NewRuntime(ip)
|
||||
added.SetInfo(client.SourceARP, []string{addedARP})
|
||||
|
||||
updated := client.NewRuntime(ip)
|
||||
updated.SetInfo(client.SourceARP, []string{updatedARP})
|
||||
updated.SetInfo(client.SourceHostsFile, []string{updatedHostsFile})
|
||||
|
||||
s := newStorage(t, nil)
|
||||
|
||||
s.UpdateRuntime(added)
|
||||
got := s.ClientRuntime(ip)
|
||||
source, host := got.Info()
|
||||
assert.Equal(t, client.SourceARP, source)
|
||||
assert.Equal(t, addedARP, host)
|
||||
|
||||
s.UpdateRuntime(updated)
|
||||
got = s.ClientRuntime(ip)
|
||||
source, host = got.Info()
|
||||
assert.Equal(t, client.SourceHostsFile, source)
|
||||
assert.Equal(t, updatedHostsFile, host)
|
||||
|
||||
n := s.DeleteBySource(client.SourceHostsFile)
|
||||
require.Equal(t, 0, n)
|
||||
|
||||
got = s.ClientRuntime(ip)
|
||||
source, host = got.Info()
|
||||
assert.Equal(t, client.SourceARP, source)
|
||||
assert.Equal(t, updatedARP, host)
|
||||
}
|
||||
|
|
|
@ -47,9 +47,6 @@ type clientsContainer struct {
|
|||
// storage stores information about persistent clients.
|
||||
storage *client.Storage
|
||||
|
||||
// runtimeIndex stores information about runtime clients.
|
||||
runtimeIndex *client.RuntimeIndex
|
||||
|
||||
// dhcp is the DHCP service implementation.
|
||||
dhcp DHCP
|
||||
|
||||
|
@ -105,8 +102,6 @@ func (clients *clientsContainer) Init(
|
|||
return errors.Error("clients container already initialized")
|
||||
}
|
||||
|
||||
clients.runtimeIndex = client.NewRuntimeIndex()
|
||||
|
||||
clients.storage = client.NewStorage(&client.Config{
|
||||
AllowedTags: clientTags,
|
||||
})
|
||||
|
@ -358,7 +353,7 @@ func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source)
|
|||
return client.SourcePersistent
|
||||
}
|
||||
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
rc := clients.storage.ClientRuntime(ip)
|
||||
if rc != nil {
|
||||
src, _ = rc.Info()
|
||||
}
|
||||
|
@ -539,22 +534,9 @@ func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent,
|
|||
return clients.storage.FindByMAC(foundMAC)
|
||||
}
|
||||
|
||||
// runtimeClient returns a runtime client from internal index. Note that it
|
||||
// doesn't include DHCP clients.
|
||||
func (clients *clientsContainer) runtimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||
if ip == (netip.Addr{}) {
|
||||
return nil
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
return clients.runtimeIndex.Client(ip)
|
||||
}
|
||||
|
||||
// findRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
|
||||
rc = clients.runtimeClient(ip)
|
||||
rc = clients.storage.ClientRuntime(ip)
|
||||
host := clients.dhcp.HostByIP(ip)
|
||||
|
||||
if host != "" {
|
||||
|
@ -580,20 +562,11 @@ func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
|||
return
|
||||
}
|
||||
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc == nil {
|
||||
// Create a RuntimeClient implicitly so that we don't do this check
|
||||
// again.
|
||||
rc = client.NewRuntime(ip)
|
||||
clients.runtimeIndex.Add(rc)
|
||||
rc := client.NewRuntime(ip)
|
||||
rc.SetWHOIS(wi)
|
||||
clients.storage.UpdateRuntime(rc)
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
} else {
|
||||
host, _ := rc.Info()
|
||||
log.Debug("clients: set whois info for runtime client %s: %+v", host, wi)
|
||||
}
|
||||
|
||||
rc.SetWHOIS(wi)
|
||||
}
|
||||
|
||||
// addHost adds a new IP-hostname pairing. The priorities of the sources are
|
||||
|
@ -644,26 +617,20 @@ func (clients *clientsContainer) addHostLocked(
|
|||
host string,
|
||||
src client.Source,
|
||||
) (ok bool) {
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
if rc == nil {
|
||||
if src < client.SourceDHCP {
|
||||
if clients.dhcp.HostByIP(ip) != "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
rc = client.NewRuntime(ip)
|
||||
clients.runtimeIndex.Add(rc)
|
||||
}
|
||||
|
||||
rc := client.NewRuntime(ip)
|
||||
rc.SetInfo(src, []string{host})
|
||||
if dhcpHost := clients.dhcp.HostByIP(ip); dhcpHost != "" {
|
||||
rc.SetInfo(client.SourceDHCP, []string{dhcpHost})
|
||||
}
|
||||
|
||||
clients.storage.UpdateRuntime(rc)
|
||||
|
||||
log.Debug(
|
||||
"clients: adding client info %s -> %q %q [%d]",
|
||||
ip,
|
||||
src,
|
||||
host,
|
||||
clients.runtimeIndex.Size(),
|
||||
clients.storage.SizeRuntime(),
|
||||
)
|
||||
|
||||
return true
|
||||
|
@ -675,22 +642,24 @@ func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorag
|
|||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceHostsFile)
|
||||
log.Debug("clients: removed %d client aliases from system hosts file", deleted)
|
||||
|
||||
added := 0
|
||||
rcs := []*client.Runtime{}
|
||||
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
||||
// Only the first name of the first record is considered a canonical
|
||||
// hostname for the IP address.
|
||||
//
|
||||
// TODO(e.burkov): Consider using all the names from all the records.
|
||||
if clients.addHostLocked(addr, names[0], client.SourceHostsFile) {
|
||||
rc := client.NewRuntime(addr)
|
||||
rc.SetInfo(client.SourceHostsFile, []string{names[0]})
|
||||
|
||||
added++
|
||||
}
|
||||
rcs = append(rcs, rc)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs)
|
||||
|
||||
log.Debug("clients: added %d client aliases from system hosts file", added)
|
||||
}
|
||||
|
||||
|
@ -715,15 +684,17 @@ func (clients *clientsContainer) addFromSystemARP() {
|
|||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
deleted := clients.runtimeIndex.DeleteBySource(client.SourceARP)
|
||||
log.Debug("clients: removed %d client aliases from arp neighborhood", deleted)
|
||||
|
||||
added := 0
|
||||
rcs := []*client.Runtime{}
|
||||
for _, n := range ns {
|
||||
if clients.addHostLocked(n.IP, n.Name, client.SourceARP) {
|
||||
rc := client.NewRuntime(n.IP)
|
||||
rc.SetInfo(client.SourceARP, []string{n.Name})
|
||||
|
||||
added++
|
||||
rcs = append(rcs, rc)
|
||||
}
|
||||
}
|
||||
|
||||
clients.storage.BatchUpdateBySource(client.SourceARP, rcs)
|
||||
|
||||
log.Debug("clients: added %d client aliases from arp neighborhood", added)
|
||||
}
|
||||
|
|
|
@ -240,7 +240,7 @@ func TestClientsWHOIS(t *testing.T) {
|
|||
t.Run("new_client", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("1.1.1.255")
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
rc := clients.storage.ClientRuntime(ip)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
|
@ -252,7 +252,7 @@ func TestClientsWHOIS(t *testing.T) {
|
|||
assert.True(t, ok)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
rc := clients.storage.ClientRuntime(ip)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, whois, rc.WHOIS())
|
||||
|
@ -269,7 +269,7 @@ func TestClientsWHOIS(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
clients.setWHOISInfo(ip, whois)
|
||||
rc := clients.runtimeIndex.Client(ip)
|
||||
rc := clients.storage.ClientRuntime(ip)
|
||||
require.Nil(t, rc)
|
||||
|
||||
assert.True(t, clients.storage.RemoveByName("client1"))
|
||||
|
|
|
@ -103,7 +103,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
|||
return true
|
||||
})
|
||||
|
||||
clients.runtimeIndex.Range(func(rc *client.Runtime) (cont bool) {
|
||||
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
|
||||
src, host := rc.Info()
|
||||
cj := runtimeClientJSON{
|
||||
WHOIS: whoisOrEmpty(rc),
|
||||
|
|
|
@ -354,6 +354,19 @@ type Info struct {
|
|||
Orgname string `json:"orgname,omitempty"`
|
||||
}
|
||||
|
||||
// Clone returns a deep copy of the WHOIS info.
|
||||
func (i *Info) Clone() (c *Info) {
|
||||
if i == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &Info{
|
||||
City: i.City,
|
||||
Country: i.Country,
|
||||
Orgname: i.Orgname,
|
||||
}
|
||||
}
|
||||
|
||||
// cacheItem represents an item that we will store in the cache.
|
||||
type cacheItem struct {
|
||||
// expiry is the time when cacheItem will expire.
|
||||
|
|
Loading…
Reference in New Issue