diff --git a/internal/client/storage.go b/internal/client/storage.go index d6de62c3..46597cec 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -2,12 +2,14 @@ package client import ( "fmt" + "net/netip" "slices" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" ) // Storage contains information about persistent and runtime clients. @@ -17,6 +19,9 @@ type Storage struct { // index contains information about persistent clients. index *Index + + // runtimeIndex contains information about runtime clients. + runtimeIndex map[netip.Addr]*Runtime } // NewStorage returns initialized client storage. @@ -24,8 +29,9 @@ func NewStorage(clientTags []string) (s *Storage) { allTags := container.NewMapSet(clientTags...) return &Storage{ - allTags: allTags, - index: NewIndex(), + allTags: allTags, + index: NewIndex(), + runtimeIndex: map[netip.Addr]*Runtime{}, } } @@ -52,11 +58,16 @@ func (s *Storage) check(p *Persistent) (err error) { return errors.Error("id required") } - _, err = proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{}) + conf, err := proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{}) if err != nil { return fmt.Errorf("invalid upstream servers: %w", err) } + err = conf.Close() + if err != nil { + log.Error("client: closing upstream config: %s", err) + } + for _, t := range p.Tags { if !s.allTags.Has(t) { return fmt.Errorf("invalid tag: %q", t) @@ -104,3 +115,50 @@ func (s *Storage) Update(p, n *Persistent) (err error) { return nil } + +// ClientRuntime returns the saved runtime client by ip. If no such client +// exists, returns nil. +func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { + return s.runtimeIndex[ip] +} + +// AddRuntime saves the runtime client information in the storage. IP address +// of a client must be unique. rc must not be nil. +func (s *Storage) AddRuntime(rc *Runtime) { + ip := rc.Addr() + s.runtimeIndex[ip] = rc +} + +// SizeRuntime returns the number of the runtime clients. +func (s *Storage) SizeRuntime() (n int) { + return len(s.runtimeIndex) +} + +// RangeRuntime calls f for each runtime client in an undefined order. +func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { + for _, rc := range s.runtimeIndex { + if !f(rc) { + return + } + } +} + +// DeleteRuntime removes the runtime client by ip. +func (s *Storage) DeleteRuntime(ip netip.Addr) { + delete(s.runtimeIndex, ip) +} + +// DeleteBySource removes all runtime clients that have information only from +// the specified source and returns the number of removed clients. +func (s *Storage) DeleteBySource(src Source) (n int) { + for ip, rc := range s.runtimeIndex { + rc.unset(src) + + if rc.isEmpty() { + delete(s.runtimeIndex, ip) + n++ + } + } + + return n +}