all: persistent client storage
This commit is contained in:
parent
bcda80bee7
commit
bde3baa5da
|
@ -64,7 +64,7 @@ func NewIndex() (ci *Index) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add stores information about a persistent client in the index. c must be
|
// Add stores information about a persistent client in the index. c must be
|
||||||
// non-nil and contain UID.
|
// non-nil, have a UID, and contain at least one identifier.
|
||||||
func (ci *Index) Add(c *Persistent) {
|
func (ci *Index) Add(c *Persistent) {
|
||||||
if (c.UID == UID{}) {
|
if (c.UID == UID{}) {
|
||||||
panic("client must contain uid")
|
panic("client must contain uid")
|
||||||
|
|
|
@ -0,0 +1,106 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
|
"github.com/AdguardTeam/golibs/container"
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Storage contains information about persistent and runtime clients.
|
||||||
|
type Storage struct {
|
||||||
|
// allTags is a set of all client tags.
|
||||||
|
allTags *container.MapSet[string]
|
||||||
|
|
||||||
|
// index contains information about persistent clients.
|
||||||
|
index *Index
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStorage returns initialized client storage.
|
||||||
|
func NewStorage(clientTags []string) (s *Storage) {
|
||||||
|
allTags := container.NewMapSet(clientTags...)
|
||||||
|
|
||||||
|
return &Storage{
|
||||||
|
allTags: allTags,
|
||||||
|
index: NewIndex(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add stores persistent client information or returns an error.
|
||||||
|
func (s *Storage) Add(p *Persistent) (err error) {
|
||||||
|
err = s.check(p)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("adding client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.index.Add(p)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// check returns an error if persistent client information contains errors.
|
||||||
|
func (s *Storage) check(p *Persistent) (err error) {
|
||||||
|
switch {
|
||||||
|
case p == nil:
|
||||||
|
return errors.Error("client is nil")
|
||||||
|
case p.Name == "":
|
||||||
|
return errors.Error("empty name")
|
||||||
|
case p.IDsLen() == 0:
|
||||||
|
return errors.Error("id required")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = proxy.ParseUpstreamsConfig(p.Upstreams, &upstream.Options{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid upstream servers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range p.Tags {
|
||||||
|
if !s.allTags.Has(t) {
|
||||||
|
return fmt.Errorf("invalid tag: %q", t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(s.chzhen): Move to the constructor.
|
||||||
|
slices.Sort(p.Tags)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveByName removes persistent client information. ok is false if no such
|
||||||
|
// client exists by that name.
|
||||||
|
func (s *Storage) RemoveByName(name string) (ok bool) {
|
||||||
|
p, ok := s.index.FindByName(name)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
s.index.Delete(p)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update updates stored persistent client information p with new information n
|
||||||
|
// or returns an error. p and n must have the same UID.
|
||||||
|
func (s *Storage) Update(p, n *Persistent) (err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "updating client: %w") }()
|
||||||
|
|
||||||
|
err = s.check(n)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.index.Clashes(n)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.index.Delete(p)
|
||||||
|
s.index.Add(n)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -310,7 +309,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
objs = make([]*clientObject, 0, clients.clientIndex.Size())
|
objs = make([]*clientObject, 0, clients.clientIndex.Size())
|
||||||
clients.clientIndex.Range(func(cli *client.Persistent) (cont bool) {
|
clients.clientIndex.RangeByName(func(cli *client.Persistent) (cont bool) {
|
||||||
objs = append(objs, &clientObject{
|
objs = append(objs, &clientObject{
|
||||||
Name: cli.Name,
|
Name: cli.Name,
|
||||||
|
|
||||||
|
@ -337,14 +336,6 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
// Maps aren't guaranteed to iterate in the same order each time, so the
|
|
||||||
// above loop can generate different orderings when writing to the config
|
|
||||||
// file: this produces lots of diffs in config files, so sort objects by
|
|
||||||
// name before writing.
|
|
||||||
slices.SortStableFunc(objs, func(a, b *clientObject) (res int) {
|
|
||||||
return strings.Compare(a.Name, b.Name)
|
|
||||||
})
|
|
||||||
|
|
||||||
return objs
|
return objs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue