diff --git a/home/clients.go b/home/clients.go index 8fbf72a2..5fb1115b 100644 --- a/home/clients.go +++ b/home/clients.go @@ -37,6 +37,11 @@ type Client struct { BlockedServices []string Upstreams []string // list of upstream servers to be used for the client's requests + // Upstream objects: + // nil: not yet initialized + // not nil, but empty: initialized, no good upstreams + // not nil, not empty: Upstreams ready to be used + upstreamObjects []upstream.Upstream } type clientSource uint @@ -62,12 +67,7 @@ type clientsContainer struct { list map[string]*Client // name -> client idIndex map[string]*Client // IP -> client ipHost map[string]*ClientHost // IP -> Hostname - - // cache for Upstream instances that are used in the case - // when custom DNS servers are configured for a client - upstreamsCache map[string][]upstream.Upstream // name -> []Upstream - - lock sync.Mutex + lock sync.Mutex // dhcpServer is used for looking up clients IP addresses by MAC addresses dhcpServer *dhcpd.Server @@ -84,7 +84,6 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd. clients.list = make(map[string]*Client) clients.idIndex = make(map[string]*Client) clients.ipHost = make(map[string]*ClientHost) - clients.upstreamsCache = make(map[string][]upstream.Upstream) clients.dhcpServer = dhcpServer clients.addFromConfig(objects) @@ -198,6 +197,12 @@ func (clients *clientsContainer) Find(ip string) (Client, bool) { return clients.findByIP(ip) } +func upstreamArrayCopy(a []upstream.Upstream) []upstream.Upstream { + a2 := make([]upstream.Upstream, len(a)) + copy(a2, a) + return a2 +} + // FindUpstreams looks for upstreams configured for the client // If no client found for this IP, or if no custom upstreams are configured, // this method returns nil @@ -210,31 +215,22 @@ func (clients *clientsContainer) FindUpstreams(ip string) []upstream.Upstream { return nil } - if len(c.Upstreams) == 0 { + if c.upstreamObjects == nil { + c.upstreamObjects = make([]upstream.Upstream, 0) + for _, us := range c.Upstreams { + u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout}) + if err != nil { + log.Error("upstream.AddressToUpstream: %s: %s", us, err) + continue + } + c.upstreamObjects = append(c.upstreamObjects, u) + } + } + + if len(c.upstreamObjects) == 0 { return nil } - - upstreams, ok := clients.upstreamsCache[c.Name] - if ok { - return upstreams - } - - for _, us := range c.Upstreams { - u, err := upstream.AddressToUpstream(us, upstream.Options{Timeout: dnsforward.DefaultTimeout}) - if err != nil { - log.Error("upstream.AddressToUpstream: %s: %s", us, err) - continue - } - upstreams = append(upstreams, u) - } - - if len(upstreams) == 0 { - clients.upstreamsCache[c.Name] = nil - } else { - clients.upstreamsCache[c.Name] = upstreams - } - - return upstreams + return upstreamArrayCopy(c.upstreamObjects) } // Find searches for a client by IP (and does not lock anything) @@ -390,9 +386,6 @@ func (clients *clientsContainer) Del(name string) bool { // update Name index delete(clients.list, name) - // update upstreams cache - delete(clients.upstreamsCache, name) - // update ID index for _, id := range c.IDs { delete(clients.idIndex, id) @@ -461,11 +454,7 @@ func (clients *clientsContainer) Update(name string, c Client) error { } // update upstreams cache - if old.Name != c.Name { - delete(clients.upstreamsCache, old.Name) - } else { - delete(clients.upstreamsCache, c.Name) - } + c.upstreamObjects = nil *old = c return nil