332 lines
7.6 KiB
Go
332 lines
7.6 KiB
Go
package client
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"slices"
|
|
"sync"
|
|
|
|
"github.com/AdguardTeam/golibs/container"
|
|
"github.com/AdguardTeam/golibs/errors"
|
|
"github.com/AdguardTeam/golibs/log"
|
|
)
|
|
|
|
// Config is the configuration structure for a client [Storage].
//
// TODO(s.chzhen): Expand.
type Config struct {
	// AllowedTags lists every tag that persistent clients are allowed to
	// carry. NewStorage builds a set of allowed tags from it.
	AllowedTags []string
}
|
|
|
|
// Storage contains information about persistent and runtime clients.
|
|
type Storage struct {
|
|
// allowedTags is a set of all allowed tags.
|
|
allowedTags *container.MapSet[string]
|
|
|
|
// mu protects indexes of persistent and runtime clients.
|
|
mu *sync.Mutex
|
|
|
|
// index contains information about persistent clients.
|
|
index *index
|
|
|
|
// runtimeIndex contains information about runtime clients.
|
|
runtimeIndex *RuntimeIndex
|
|
}
|
|
|
|
// NewStorage returns initialized client storage. conf must not be nil.
|
|
func NewStorage(conf *Config) (s *Storage) {
|
|
allowedTags := container.NewMapSet(conf.AllowedTags...)
|
|
|
|
return &Storage{
|
|
allowedTags: allowedTags,
|
|
mu: &sync.Mutex{},
|
|
index: newIndex(),
|
|
runtimeIndex: NewRuntimeIndex(),
|
|
}
|
|
}
|
|
|
|
// Add stores persistent client information or returns an error.
|
|
func (s *Storage) Add(p *Persistent) (err error) {
|
|
defer func() { err = errors.Annotate(err, "adding client: %w") }()
|
|
|
|
err = p.validate(s.allowedTags)
|
|
if err != nil {
|
|
// Don't wrap the error since there is already an annotation deferred.
|
|
return err
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
err = s.index.clashesUID(p)
|
|
if err != nil {
|
|
// Don't wrap the error since there is already an annotation deferred.
|
|
return err
|
|
}
|
|
|
|
err = s.index.clashes(p)
|
|
if err != nil {
|
|
// Don't wrap the error since there is already an annotation deferred.
|
|
return err
|
|
}
|
|
|
|
s.index.add(p)
|
|
|
|
log.Debug("client storage: added %q: IDs: %q [%d]", p.Name, p.IDs(), s.index.size())
|
|
|
|
return nil
|
|
}
|
|
|
|
// FindByName finds persistent client by name. And returns its shallow copy.
|
|
func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
p, ok = s.index.findByName(name)
|
|
if ok {
|
|
return p.ShallowClone(), ok
|
|
}
|
|
|
|
return nil, false
|
|
}
|
|
|
|
// Find finds persistent client by string representation of the client ID, IP
|
|
// address, or MAC. And returns its shallow copy.
|
|
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
p, ok = s.index.find(id)
|
|
if ok {
|
|
return p.ShallowClone(), ok
|
|
}
|
|
|
|
return nil, false
|
|
}
|
|
|
|
// FindLoose is like [Storage.Find] but it also tries to find a persistent
|
|
// client by IP address without zone. It strips the IPv6 zone index from the
|
|
// stored IP addresses before comparing, because querylog entries don't have it.
|
|
// See TODO on [querylog.logEntry.IP].
|
|
//
|
|
// Note that multiple clients can have the same IP address with different zones.
|
|
// Therefore, the result of this method is indeterminate.
|
|
func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
p, ok = s.index.find(id)
|
|
if ok {
|
|
return p.ShallowClone(), ok
|
|
}
|
|
|
|
p = s.index.findByIPWithoutZone(ip)
|
|
if p != nil {
|
|
return p.ShallowClone(), true
|
|
}
|
|
|
|
return nil, false
|
|
}
|
|
|
|
// FindByMAC finds persistent client by MAC and returns its shallow copy.
|
|
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
p, ok = s.index.findByMAC(mac)
|
|
if ok {
|
|
return p.ShallowClone(), ok
|
|
}
|
|
|
|
return nil, false
|
|
}
|
|
|
|
// RemoveByName removes persistent client information. ok is false if no such
|
|
// client exists by that name.
|
|
func (s *Storage) RemoveByName(name string) (ok bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
p, ok := s.index.findByName(name)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
if err := p.CloseUpstreams(); err != nil {
|
|
log.Error("client storage: removing client %q: %s", p.Name, err)
|
|
}
|
|
|
|
s.index.remove(p)
|
|
|
|
return true
|
|
}
|
|
|
|
// Update finds the stored persistent client by its name and updates its
|
|
// information from p.
|
|
func (s *Storage) Update(name string, p *Persistent) (err error) {
|
|
defer func() { err = errors.Annotate(err, "updating client: %w") }()
|
|
|
|
err = p.validate(s.allowedTags)
|
|
if err != nil {
|
|
// Don't wrap the error since there is already an annotation deferred.
|
|
return err
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
stored, ok := s.index.findByName(name)
|
|
if !ok {
|
|
return fmt.Errorf("client %q is not found", name)
|
|
}
|
|
|
|
// Client p has a newly generated UID, so replace it with the stored one.
|
|
//
|
|
// TODO(s.chzhen): Remove when frontend starts handling UIDs.
|
|
p.UID = stored.UID
|
|
|
|
err = s.index.clashes(p)
|
|
if err != nil {
|
|
// Don't wrap the error since there is already an annotation deferred.
|
|
return err
|
|
}
|
|
|
|
s.index.remove(stored)
|
|
s.index.add(p)
|
|
|
|
return nil
|
|
}
|
|
|
|
// RangeByName calls f for each persistent client sorted by name, unless cont is
|
|
// false.
|
|
func (s *Storage) RangeByName(f func(c *Persistent) (cont bool)) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
s.index.rangeByName(f)
|
|
}
|
|
|
|
// Size returns the number of persistent clients.
|
|
func (s *Storage) Size() (n int) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
return s.index.size()
|
|
}
|
|
|
|
// CloseUpstreams closes upstream configurations of persistent clients.
|
|
func (s *Storage) CloseUpstreams() (err error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
return s.index.closeUpstreams()
|
|
}
|
|
|
|
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
|
|
// client exists, returns nil.
|
|
//
|
|
// TODO(s.chzhen): Use it.
|
|
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
return s.runtimeIndex.Client(ip)
|
|
}
|
|
|
|
// UpdateRuntime updates the stored runtime client with information from rc. If
|
|
// no such client exists, saves the copy of rc in storage. rc must not be nil.
|
|
func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
return s.updateRuntimeLocked(rc)
|
|
}
|
|
|
|
// updateRuntimeLocked updates the stored runtime client with information from
|
|
// rc. rc must not be nil. Storage.mu is expected to be locked.
|
|
func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) {
|
|
stored := s.runtimeIndex.Client(rc.ip)
|
|
if stored == nil {
|
|
s.runtimeIndex.Add(rc.Clone())
|
|
|
|
return true
|
|
}
|
|
|
|
if rc.whois != nil {
|
|
stored.whois = rc.whois.Clone()
|
|
}
|
|
|
|
if rc.arp != nil {
|
|
stored.arp = slices.Clone(rc.arp)
|
|
}
|
|
|
|
if rc.rdns != nil {
|
|
stored.rdns = slices.Clone(rc.rdns)
|
|
}
|
|
|
|
if rc.dhcp != nil {
|
|
stored.dhcp = slices.Clone(rc.dhcp)
|
|
}
|
|
|
|
if rc.hostsFile != nil {
|
|
stored.hostsFile = slices.Clone(rc.hostsFile)
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// BatchUpdateBySource updates the stored runtime clients information from the
|
|
// specified source and returns the number of added and removed clients.
|
|
func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
for _, rc := range s.runtimeIndex.index {
|
|
rc.unset(src)
|
|
}
|
|
|
|
for _, rc := range rcs {
|
|
if s.updateRuntimeLocked(rc) {
|
|
added++
|
|
}
|
|
}
|
|
|
|
for ip, rc := range s.runtimeIndex.index {
|
|
if rc.isEmpty() {
|
|
delete(s.runtimeIndex.index, ip)
|
|
removed++
|
|
}
|
|
}
|
|
|
|
return added, removed
|
|
}
|
|
|
|
// SizeRuntime returns the number of the runtime clients.
|
|
func (s *Storage) SizeRuntime() (n int) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
return s.runtimeIndex.Size()
|
|
}
|
|
|
|
// RangeRuntime calls f for each runtime client in an undefined order.
|
|
func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
s.runtimeIndex.Range(f)
|
|
}
|
|
|
|
// DeleteBySource removes all runtime clients that have information only from
|
|
// the specified source and returns the number of removed clients.
|
|
//
|
|
// TODO(s.chzhen): Use it.
|
|
func (s *Storage) DeleteBySource(src Source) (n int) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
return s.runtimeIndex.DeleteBySource(src)
|
|
}
|