Pull request 2273: AG-27492-client-storage-runtime-sources
Squashed commit of the following: commit3191224d6d
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Sep 26 18:20:04 2024 +0300 client: imp tests commit6cc4ed53a2
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Sep 26 18:04:36 2024 +0300 client: imp code commit79272b299a
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Sep 26 16:10:06 2024 +0300 all: imp code commit0a001fffbe
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Sep 24 20:05:47 2024 +0300 all: imp tests commit80f7e98d30
Merge:df7492e9d
e338214ad
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Sep 24 19:10:13 2024 +0300 Merge branch 'master' into AG-27492-client-storage-runtime-sources commitdf7492e9de
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Sep 24 19:06:37 2024 +0300 all: imp code commit23896ae5a6
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Sep 19 21:04:34 2024 +0300 client: fix typo commitba0ba2478c
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Sep 19 21:02:13 2024 +0300 all: imp code commitf7315be742
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Sep 12 14:35:38 2024 +0300 home: imp code commitf63d0e80fb
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Sep 12 14:15:49 2024 +0300 all: imp code commit9feda414b6
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Sep 10 17:53:42 2024 +0300 all: imp code commitfafd7cbb52
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Sep 9 21:13:05 2024 +0300 all: imp code commit2d2b8e0216
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Sep 5 20:55:10 2024 +0300 client: add tests commit4d394e6f21
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Aug 29 20:40:38 2024 +0300 all: client storage runtime sources
This commit is contained in:
parent
e338214ad5
commit
d40de33316
|
@ -119,8 +119,8 @@ func (r *Runtime) Info() (cs Source, host string) {
|
||||||
return cs, info[0]
|
return cs, info[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInfo sets a host as a client information from the cs.
|
// setInfo sets a host as a client information from the cs.
|
||||||
func (r *Runtime) SetInfo(cs Source, hosts []string) {
|
func (r *Runtime) setInfo(cs Source, hosts []string) {
|
||||||
// TODO(s.chzhen): Use contract where hosts must contain non-empty host.
|
// TODO(s.chzhen): Use contract where hosts must contain non-empty host.
|
||||||
if len(hosts) == 1 && hosts[0] == "" {
|
if len(hosts) == 1 && hosts[0] == "" {
|
||||||
hosts = []string{}
|
hosts = []string{}
|
||||||
|
@ -138,13 +138,13 @@ func (r *Runtime) SetInfo(cs Source, hosts []string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WHOIS returns a WHOIS client information.
|
// WHOIS returns a copy of WHOIS client information.
|
||||||
func (r *Runtime) WHOIS() (info *whois.Info) {
|
func (r *Runtime) WHOIS() (info *whois.Info) {
|
||||||
return r.whois
|
return r.whois.Clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWHOIS sets a WHOIS client information. info must be non-nil.
|
// setWHOIS sets a WHOIS client information. info must be non-nil.
|
||||||
func (r *Runtime) SetWHOIS(info *whois.Info) {
|
func (r *Runtime) setWHOIS(info *whois.Info) {
|
||||||
r.whois = info
|
r.whois = info
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -178,8 +178,8 @@ func (r *Runtime) Addr() (ip netip.Addr) {
|
||||||
return r.ip
|
return r.ip
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clone returns a deep copy of the runtime client.
|
// clone returns a deep copy of the runtime client.
|
||||||
func (r *Runtime) Clone() (c *Runtime) {
|
func (r *Runtime) clone() (c *Runtime) {
|
||||||
return &Runtime{
|
return &Runtime{
|
||||||
ip: r.ip,
|
ip: r.ip,
|
||||||
whois: r.whois.Clone(),
|
whois: r.whois.Clone(),
|
||||||
|
|
|
@ -13,7 +13,6 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/container"
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
@ -136,7 +135,8 @@ type Persistent struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate returns an error if persistent client information contains errors.
|
// validate returns an error if persistent client information contains errors.
|
||||||
func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
|
// allTags must be sorted.
|
||||||
|
func (c *Persistent) validate(allTags []string) (err error) {
|
||||||
switch {
|
switch {
|
||||||
case c.Name == "":
|
case c.Name == "":
|
||||||
return errors.Error("empty name")
|
return errors.Error("empty name")
|
||||||
|
@ -157,7 +157,8 @@ func (c *Persistent) validate(allTags *container.MapSet[string]) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range c.Tags {
|
for _, t := range c.Tags {
|
||||||
if !allTags.Has(t) {
|
_, ok := slices.BinarySearch(allTags, t)
|
||||||
|
if !ok {
|
||||||
return fmt.Errorf("invalid tag: %q", t)
|
return fmt.Errorf("invalid tag: %q", t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/container"
|
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -125,69 +122,3 @@ func TestPersistent_EqualIDs(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPersistent_Validate(t *testing.T) {
|
|
||||||
const (
|
|
||||||
allowedTag = "allowed_tag"
|
|
||||||
notAllowedTag = "not_allowed_tag"
|
|
||||||
)
|
|
||||||
|
|
||||||
allowedTags := container.NewMapSet(allowedTag)
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
cli *Persistent
|
|
||||||
wantErrMsg string
|
|
||||||
}{{
|
|
||||||
name: "success",
|
|
||||||
cli: &Persistent{
|
|
||||||
Name: "basic",
|
|
||||||
IPs: []netip.Addr{
|
|
||||||
netip.MustParseAddr("1.2.3.4"),
|
|
||||||
},
|
|
||||||
UID: MustNewUID(),
|
|
||||||
},
|
|
||||||
wantErrMsg: "",
|
|
||||||
}, {
|
|
||||||
name: "empty_name",
|
|
||||||
cli: &Persistent{
|
|
||||||
Name: "",
|
|
||||||
},
|
|
||||||
wantErrMsg: "empty name",
|
|
||||||
}, {
|
|
||||||
name: "no_id",
|
|
||||||
cli: &Persistent{
|
|
||||||
Name: "no_id",
|
|
||||||
},
|
|
||||||
wantErrMsg: "id required",
|
|
||||||
}, {
|
|
||||||
name: "no_uid",
|
|
||||||
cli: &Persistent{
|
|
||||||
Name: "no_uid",
|
|
||||||
IPs: []netip.Addr{
|
|
||||||
netip.MustParseAddr("1.2.3.4"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErrMsg: "uid required",
|
|
||||||
}, {
|
|
||||||
name: "not_allowed_tag",
|
|
||||||
cli: &Persistent{
|
|
||||||
Name: "basic",
|
|
||||||
IPs: []netip.Addr{
|
|
||||||
netip.MustParseAddr("1.2.3.4"),
|
|
||||||
},
|
|
||||||
UID: MustNewUID(),
|
|
||||||
Tags: []string{
|
|
||||||
notAllowedTag,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErrMsg: `invalid tag: "` + notAllowedTag + `"`,
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
err := tc.cli.validate(allowedTags)
|
|
||||||
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,39 +2,34 @@ package client
|
||||||
|
|
||||||
import "net/netip"
|
import "net/netip"
|
||||||
|
|
||||||
// RuntimeIndex stores information about runtime clients.
|
// runtimeIndex stores information about runtime clients.
|
||||||
type RuntimeIndex struct {
|
type runtimeIndex struct {
|
||||||
// index maps IP address to runtime client.
|
// index maps IP address to runtime client.
|
||||||
index map[netip.Addr]*Runtime
|
index map[netip.Addr]*Runtime
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRuntimeIndex returns initialized runtime index.
|
// newRuntimeIndex returns initialized runtime index.
|
||||||
func NewRuntimeIndex() (ri *RuntimeIndex) {
|
func newRuntimeIndex() (ri *runtimeIndex) {
|
||||||
return &RuntimeIndex{
|
return &runtimeIndex{
|
||||||
index: map[netip.Addr]*Runtime{},
|
index: map[netip.Addr]*Runtime{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Client returns the saved runtime client by ip. If no such client exists,
|
// client returns the saved runtime client by ip. If no such client exists,
|
||||||
// returns nil.
|
// returns nil.
|
||||||
func (ri *RuntimeIndex) Client(ip netip.Addr) (rc *Runtime) {
|
func (ri *runtimeIndex) client(ip netip.Addr) (rc *Runtime) {
|
||||||
return ri.index[ip]
|
return ri.index[ip]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add saves the runtime client in the index. IP address of a client must be
|
// add saves the runtime client in the index. IP address of a client must be
|
||||||
// unique. See [Runtime.Client]. rc must not be nil.
|
// unique. See [Runtime.Client]. rc must not be nil.
|
||||||
func (ri *RuntimeIndex) Add(rc *Runtime) {
|
func (ri *runtimeIndex) add(rc *Runtime) {
|
||||||
ip := rc.Addr()
|
ip := rc.Addr()
|
||||||
ri.index[ip] = rc
|
ri.index[ip] = rc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size returns the number of the runtime clients.
|
// rangeClients calls f for each runtime client in an undefined order.
|
||||||
func (ri *RuntimeIndex) Size() (n int) {
|
func (ri *runtimeIndex) rangeClients(f func(rc *Runtime) (cont bool)) {
|
||||||
return len(ri.index)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Range calls f for each runtime client in an undefined order.
|
|
||||||
func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
|
|
||||||
for _, rc := range ri.index {
|
for _, rc := range ri.index {
|
||||||
if !f(rc) {
|
if !f(rc) {
|
||||||
return
|
return
|
||||||
|
@ -42,17 +37,31 @@ func (ri *RuntimeIndex) Range(f func(rc *Runtime) (cont bool)) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes the runtime client by ip.
|
// setInfo sets the client information from cs for runtime client stored by ip.
|
||||||
func (ri *RuntimeIndex) Delete(ip netip.Addr) {
|
// If no such client exists, it creates one.
|
||||||
delete(ri.index, ip)
|
func (ri *runtimeIndex) setInfo(ip netip.Addr, cs Source, hosts []string) (rc *Runtime) {
|
||||||
|
rc = ri.index[ip]
|
||||||
|
if rc == nil {
|
||||||
|
rc = NewRuntime(ip)
|
||||||
|
ri.add(rc)
|
||||||
|
}
|
||||||
|
|
||||||
|
rc.setInfo(cs, hosts)
|
||||||
|
|
||||||
|
return rc
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteBySource removes all runtime clients that have information only from
|
// clearSource removes information from the specified source from all clients.
|
||||||
// the specified source and returns the number of removed clients.
|
func (ri *runtimeIndex) clearSource(src Source) {
|
||||||
func (ri *RuntimeIndex) DeleteBySource(src Source) (n int) {
|
for _, rc := range ri.index {
|
||||||
for ip, rc := range ri.index {
|
|
||||||
rc.unset(src)
|
rc.unset(src)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeEmpty removes empty runtime clients and returns the number of removed
|
||||||
|
// clients.
|
||||||
|
func (ri *runtimeIndex) removeEmpty() (n int) {
|
||||||
|
for ip, rc := range ri.index {
|
||||||
if rc.isEmpty() {
|
if rc.isEmpty() {
|
||||||
delete(ri.index, ip)
|
delete(ri.index, ip)
|
||||||
n++
|
n++
|
||||||
|
|
|
@ -1,85 +0,0 @@
|
||||||
package client_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRuntimeIndex(t *testing.T) {
|
|
||||||
const cliSrc = client.SourceARP
|
|
||||||
|
|
||||||
var (
|
|
||||||
ip1 = netip.MustParseAddr("1.1.1.1")
|
|
||||||
ip2 = netip.MustParseAddr("2.2.2.2")
|
|
||||||
ip3 = netip.MustParseAddr("3.3.3.3")
|
|
||||||
)
|
|
||||||
|
|
||||||
ri := client.NewRuntimeIndex()
|
|
||||||
currentSize := 0
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
ip netip.Addr
|
|
||||||
name string
|
|
||||||
hosts []string
|
|
||||||
src client.Source
|
|
||||||
}{{
|
|
||||||
src: cliSrc,
|
|
||||||
ip: ip1,
|
|
||||||
name: "1",
|
|
||||||
hosts: []string{"host1"},
|
|
||||||
}, {
|
|
||||||
src: cliSrc,
|
|
||||||
ip: ip2,
|
|
||||||
name: "2",
|
|
||||||
hosts: []string{"host2"},
|
|
||||||
}, {
|
|
||||||
src: cliSrc,
|
|
||||||
ip: ip3,
|
|
||||||
name: "3",
|
|
||||||
hosts: []string{"host3"},
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
rc := client.NewRuntime(tc.ip)
|
|
||||||
rc.SetInfo(tc.src, tc.hosts)
|
|
||||||
|
|
||||||
ri.Add(rc)
|
|
||||||
currentSize++
|
|
||||||
|
|
||||||
got := ri.Client(tc.ip)
|
|
||||||
assert.Equal(t, rc, got)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("size", func(t *testing.T) {
|
|
||||||
assert.Equal(t, currentSize, ri.Size())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("range", func(t *testing.T) {
|
|
||||||
s := 0
|
|
||||||
|
|
||||||
ri.Range(func(rc *client.Runtime) (cont bool) {
|
|
||||||
s++
|
|
||||||
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
assert.Equal(t, currentSize, s)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete", func(t *testing.T) {
|
|
||||||
ri.Delete(ip1)
|
|
||||||
currentSize--
|
|
||||||
|
|
||||||
assert.Equal(t, currentSize, ri.Size())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete_by_src", func(t *testing.T) {
|
|
||||||
assert.Equal(t, currentSize, ri.DeleteBySource(cliSrc))
|
|
||||||
assert.Equal(t, 0, ri.Size())
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,30 +1,113 @@
|
||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/container"
|
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
"github.com/AdguardTeam/golibs/hostsfile"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config is the client storage configuration structure.
|
// allowedTags is the list of available client tags.
|
||||||
//
|
var allowedTags = []string{
|
||||||
// TODO(s.chzhen): Expand.
|
"device_audio",
|
||||||
type Config struct {
|
"device_camera",
|
||||||
// AllowedTags is a list of all allowed client tags.
|
"device_gameconsole",
|
||||||
AllowedTags []string
|
"device_laptop",
|
||||||
|
"device_nas", // Network-attached Storage
|
||||||
|
"device_other",
|
||||||
|
"device_pc",
|
||||||
|
"device_phone",
|
||||||
|
"device_printer",
|
||||||
|
"device_securityalarm",
|
||||||
|
"device_tablet",
|
||||||
|
"device_tv",
|
||||||
|
|
||||||
|
"os_android",
|
||||||
|
"os_ios",
|
||||||
|
"os_linux",
|
||||||
|
"os_macos",
|
||||||
|
"os_other",
|
||||||
|
"os_windows",
|
||||||
|
|
||||||
|
"user_admin",
|
||||||
|
"user_child",
|
||||||
|
"user_regular",
|
||||||
|
}
|
||||||
|
|
||||||
|
// DHCP is an interface for accessing DHCP lease data the [Storage] needs.
|
||||||
|
type DHCP interface {
|
||||||
|
// Leases returns all the DHCP leases.
|
||||||
|
Leases() (leases []*dhcpsvc.Lease)
|
||||||
|
|
||||||
|
// HostByIP returns the hostname of the DHCP client with the given IP
|
||||||
|
// address. host will be empty if there is no such client, due to an
|
||||||
|
// assumption that a DHCP client must always have a hostname.
|
||||||
|
HostByIP(ip netip.Addr) (host string)
|
||||||
|
|
||||||
|
// MACByIP returns the MAC address for the given IP address leased. It
|
||||||
|
// returns nil if there is no such client, due to an assumption that a DHCP
|
||||||
|
// client must always have a MAC address.
|
||||||
|
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmptyDHCP is the empty [DHCP] implementation that does nothing.
|
||||||
|
type EmptyDHCP struct{}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ DHCP = EmptyDHCP{}
|
||||||
|
|
||||||
|
// Leases implements the [DHCP] interface for emptyDHCP.
|
||||||
|
func (EmptyDHCP) Leases() (leases []*dhcpsvc.Lease) { return nil }
|
||||||
|
|
||||||
|
// HostByIP implements the [DHCP] interface for emptyDHCP.
|
||||||
|
func (EmptyDHCP) HostByIP(_ netip.Addr) (host string) { return "" }
|
||||||
|
|
||||||
|
// MACByIP implements the [DHCP] interface for emptyDHCP.
|
||||||
|
func (EmptyDHCP) MACByIP(_ netip.Addr) (mac net.HardwareAddr) { return nil }
|
||||||
|
|
||||||
|
// HostsContainer is an interface for receiving updates to the system hosts
|
||||||
|
// file.
|
||||||
|
type HostsContainer interface {
|
||||||
|
Upd() (updates <-chan *hostsfile.DefaultStorage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StorageConfig is the client storage configuration structure.
|
||||||
|
type StorageConfig struct {
|
||||||
|
// DHCP is used to match IPs against MACs of persistent clients and update
|
||||||
|
// [SourceDHCP] runtime client information. It must not be nil.
|
||||||
|
DHCP DHCP
|
||||||
|
|
||||||
|
// EtcHosts is used to update [SourceHostsFile] runtime client information.
|
||||||
|
EtcHosts HostsContainer
|
||||||
|
|
||||||
|
// ARPDB is used to update [SourceARP] runtime client information.
|
||||||
|
ARPDB arpdb.Interface
|
||||||
|
|
||||||
|
// InitialClients is a list of persistent clients parsed from the
|
||||||
|
// configuration file. Each client must not be nil.
|
||||||
|
InitialClients []*Persistent
|
||||||
|
|
||||||
|
// ARPClientsUpdatePeriod defines how often [SourceARP] runtime client
|
||||||
|
// information is updated.
|
||||||
|
ARPClientsUpdatePeriod time.Duration
|
||||||
|
|
||||||
|
// RuntimeSourceDHCP specifies whether to update [SourceDHCP] information
|
||||||
|
// of runtime clients.
|
||||||
|
RuntimeSourceDHCP bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Storage contains information about persistent and runtime clients.
|
// Storage contains information about persistent and runtime clients.
|
||||||
type Storage struct {
|
type Storage struct {
|
||||||
// allowedTags is a set of all allowed tags.
|
|
||||||
allowedTags *container.MapSet[string]
|
|
||||||
|
|
||||||
// mu protects indexes of persistent and runtime clients.
|
// mu protects indexes of persistent and runtime clients.
|
||||||
mu *sync.Mutex
|
mu *sync.Mutex
|
||||||
|
|
||||||
|
@ -32,19 +115,250 @@ type Storage struct {
|
||||||
index *index
|
index *index
|
||||||
|
|
||||||
// runtimeIndex contains information about runtime clients.
|
// runtimeIndex contains information about runtime clients.
|
||||||
runtimeIndex *RuntimeIndex
|
runtimeIndex *runtimeIndex
|
||||||
|
|
||||||
|
// dhcp is used to update [SourceDHCP] runtime client information.
|
||||||
|
dhcp DHCP
|
||||||
|
|
||||||
|
// etcHosts is used to update [SourceHostsFile] runtime client information.
|
||||||
|
etcHosts HostsContainer
|
||||||
|
|
||||||
|
// arpDB is used to update [SourceARP] runtime client information.
|
||||||
|
arpDB arpdb.Interface
|
||||||
|
|
||||||
|
// done is the shutdown signaling channel.
|
||||||
|
done chan struct{}
|
||||||
|
|
||||||
|
// allowedTags is a sorted list of all allowed tags. It must not be
|
||||||
|
// modified after initialization.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Use custom type.
|
||||||
|
allowedTags []string
|
||||||
|
|
||||||
|
// arpClientsUpdatePeriod defines how often [SourceARP] runtime client
|
||||||
|
// information is updated. It must be greater than zero.
|
||||||
|
arpClientsUpdatePeriod time.Duration
|
||||||
|
|
||||||
|
// runtimeSourceDHCP specifies whether to update [SourceDHCP] information
|
||||||
|
// of runtime clients.
|
||||||
|
runtimeSourceDHCP bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStorage returns initialized client storage. conf must not be nil.
|
// NewStorage returns initialized client storage. conf must not be nil.
|
||||||
func NewStorage(conf *Config) (s *Storage) {
|
func NewStorage(conf *StorageConfig) (s *Storage, err error) {
|
||||||
allowedTags := container.NewMapSet(conf.AllowedTags...)
|
tags := slices.Clone(allowedTags)
|
||||||
|
slices.Sort(tags)
|
||||||
|
|
||||||
return &Storage{
|
s = &Storage{
|
||||||
allowedTags: allowedTags,
|
allowedTags: tags,
|
||||||
mu: &sync.Mutex{},
|
mu: &sync.Mutex{},
|
||||||
index: newIndex(),
|
index: newIndex(),
|
||||||
runtimeIndex: NewRuntimeIndex(),
|
runtimeIndex: newRuntimeIndex(),
|
||||||
|
dhcp: conf.DHCP,
|
||||||
|
etcHosts: conf.EtcHosts,
|
||||||
|
arpDB: conf.ARPDB,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
arpClientsUpdatePeriod: conf.ARPClientsUpdatePeriod,
|
||||||
|
runtimeSourceDHCP: conf.RuntimeSourceDHCP,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i, p := range conf.InitialClients {
|
||||||
|
err = s.Add(p)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("adding client %q at index %d: %w", p.Name, i, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.ReloadARP()
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the goroutines for updating the runtime client information.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Pass context.
|
||||||
|
func (s *Storage) Start(_ context.Context) (err error) {
|
||||||
|
go s.periodicARPUpdate()
|
||||||
|
go s.handleHostsUpdates()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully stops the client storage.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Pass context.
|
||||||
|
func (s *Storage) Shutdown(_ context.Context) (err error) {
|
||||||
|
close(s.done)
|
||||||
|
|
||||||
|
return s.closeUpstreams()
|
||||||
|
}
|
||||||
|
|
||||||
|
// periodicARPUpdate periodically reloads runtime clients from ARP. It is
|
||||||
|
// intended to be used as a goroutine.
|
||||||
|
func (s *Storage) periodicARPUpdate() {
|
||||||
|
defer log.OnPanic("storage")
|
||||||
|
|
||||||
|
t := time.NewTicker(s.arpClientsUpdatePeriod)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
s.ReloadARP()
|
||||||
|
case <-s.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReloadARP reloads runtime clients from ARP, if configured.
|
||||||
|
func (s *Storage) ReloadARP() {
|
||||||
|
if s.arpDB != nil {
|
||||||
|
s.addFromSystemARP()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||||
|
// command.
|
||||||
|
func (s *Storage) addFromSystemARP() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if err := s.arpDB.Refresh(); err != nil {
|
||||||
|
s.arpDB = arpdb.Empty{}
|
||||||
|
log.Error("refreshing arp container: %s", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ns := s.arpDB.Neighbors()
|
||||||
|
if len(ns) == 0 {
|
||||||
|
log.Debug("refreshing arp container: the update is empty")
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
src := SourceARP
|
||||||
|
s.runtimeIndex.clearSource(src)
|
||||||
|
|
||||||
|
for _, n := range ns {
|
||||||
|
s.runtimeIndex.setInfo(n.IP, src, []string{n.Name})
|
||||||
|
}
|
||||||
|
|
||||||
|
removed := s.runtimeIndex.removeEmpty()
|
||||||
|
|
||||||
|
log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", len(ns), removed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHostsUpdates receives the updates from the hosts container and adds
|
||||||
|
// them to the clients storage. It is intended to be used as a goroutine.
|
||||||
|
func (s *Storage) handleHostsUpdates() {
|
||||||
|
if s.etcHosts == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer log.OnPanic("storage")
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case upd, ok := <-s.etcHosts.Upd():
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.addFromHostsFile(upd)
|
||||||
|
case <-s.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addFromHostsFile fills the client-hostname pairing index from the system's
|
||||||
|
// hosts files.
|
||||||
|
func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
src := SourceHostsFile
|
||||||
|
s.runtimeIndex.clearSource(src)
|
||||||
|
|
||||||
|
added := 0
|
||||||
|
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
||||||
|
// Only the first name of the first record is considered a canonical
|
||||||
|
// hostname for the IP address.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Consider using all the names from all the records.
|
||||||
|
s.runtimeIndex.setInfo(addr, src, []string{names[0]})
|
||||||
|
added++
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
removed := s.runtimeIndex.removeEmpty()
|
||||||
|
log.Debug("storage: added %d, removed %d client aliases from system hosts file", added, removed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ AddressUpdater = (*Storage)(nil)
|
||||||
|
|
||||||
|
// UpdateAddress implements the [AddressUpdater] interface for *Storage
|
||||||
|
func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
|
||||||
|
// Common fast path optimization.
|
||||||
|
if host == "" && info == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if host != "" {
|
||||||
|
s.runtimeIndex.setInfo(ip, SourceRDNS, []string{host})
|
||||||
|
}
|
||||||
|
|
||||||
|
if info != nil {
|
||||||
|
s.setWHOISInfo(ip, info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDHCP updates [SourceDHCP] runtime client information.
|
||||||
|
func (s *Storage) UpdateDHCP() {
|
||||||
|
if s.dhcp == nil || !s.runtimeSourceDHCP {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
src := SourceDHCP
|
||||||
|
s.runtimeIndex.clearSource(src)
|
||||||
|
|
||||||
|
added := 0
|
||||||
|
for _, l := range s.dhcp.Leases() {
|
||||||
|
s.runtimeIndex.setInfo(l.IP, src, []string{l.Hostname})
|
||||||
|
added++
|
||||||
|
}
|
||||||
|
|
||||||
|
removed := s.runtimeIndex.removeEmpty()
|
||||||
|
log.Debug("storage: added %d, removed %d client aliases from dhcp", added, removed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// setWHOISInfo sets the WHOIS information for a runtime client.
|
||||||
|
func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
||||||
|
_, ok := s.index.findByIP(ip)
|
||||||
|
if ok {
|
||||||
|
log.Debug("storage: client for %s is already created, ignore whois info", ip)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rc := s.runtimeIndex.client(ip)
|
||||||
|
if rc == nil {
|
||||||
|
rc = NewRuntime(ip)
|
||||||
|
s.runtimeIndex.add(rc)
|
||||||
|
}
|
||||||
|
|
||||||
|
rc.setWHOIS(wi)
|
||||||
|
|
||||||
|
log.Debug("storage: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add stores persistent client information or returns an error.
|
// Add stores persistent client information or returns an error.
|
||||||
|
@ -94,6 +408,9 @@ func (s *Storage) FindByName(name string) (p *Persistent, ok bool) {
|
||||||
|
|
||||||
// Find finds persistent client by string representation of the client ID, IP
|
// Find finds persistent client by string representation of the client ID, IP
|
||||||
// address, or MAC. And returns its shallow copy.
|
// address, or MAC. And returns its shallow copy.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain
|
||||||
|
// the parsed IP address, if any.
|
||||||
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
@ -103,6 +420,16 @@ func (s *Storage) Find(id string) (p *Persistent, ok bool) {
|
||||||
return p.ShallowClone(), ok
|
return p.ShallowClone(), ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ip, err := netip.ParseAddr(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
foundMAC := s.dhcp.MACByIP(ip)
|
||||||
|
if foundMAC != nil {
|
||||||
|
return s.FindByMAC(foundMAC)
|
||||||
|
}
|
||||||
|
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,11 +457,9 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// FindByMAC finds persistent client by MAC and returns its shallow copy.
|
// FindByMAC finds persistent client by MAC and returns its shallow copy. s.mu
|
||||||
|
// is expected to be locked.
|
||||||
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
|
func (s *Storage) FindByMAC(mac net.HardwareAddr) (p *Persistent, ok bool) {
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
p, ok = s.index.findByMAC(mac)
|
p, ok = s.index.findByMAC(mac)
|
||||||
if ok {
|
if ok {
|
||||||
return p.ShallowClone(), ok
|
return p.ShallowClone(), ok
|
||||||
|
@ -216,8 +541,8 @@ func (s *Storage) Size() (n int) {
|
||||||
return s.index.size()
|
return s.index.size()
|
||||||
}
|
}
|
||||||
|
|
||||||
// CloseUpstreams closes upstream configurations of persistent clients.
|
// closeUpstreams closes upstream configurations of persistent clients.
|
||||||
func (s *Storage) CloseUpstreams() (err error) {
|
func (s *Storage) closeUpstreams() (err error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
@ -226,89 +551,27 @@ func (s *Storage) CloseUpstreams() (err error) {
|
||||||
|
|
||||||
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
|
// ClientRuntime returns a copy of the saved runtime client by ip. If no such
|
||||||
// client exists, returns nil.
|
// client exists, returns nil.
|
||||||
//
|
|
||||||
// TODO(s.chzhen): Use it.
|
|
||||||
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
return s.runtimeIndex.Client(ip)
|
rc = s.runtimeIndex.client(ip)
|
||||||
}
|
if rc != nil {
|
||||||
|
return rc.clone()
|
||||||
// UpdateRuntime updates the stored runtime client with information from rc. If
|
|
||||||
// no such client exists, saves the copy of rc in storage. rc must not be nil.
|
|
||||||
func (s *Storage) UpdateRuntime(rc *Runtime) (added bool) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
return s.updateRuntimeLocked(rc)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateRuntimeLocked updates the stored runtime client with information from
|
|
||||||
// rc. rc must not be nil. Storage.mu is expected to be locked.
|
|
||||||
func (s *Storage) updateRuntimeLocked(rc *Runtime) (added bool) {
|
|
||||||
stored := s.runtimeIndex.Client(rc.ip)
|
|
||||||
if stored == nil {
|
|
||||||
s.runtimeIndex.Add(rc.Clone())
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if rc.whois != nil {
|
if !s.runtimeSourceDHCP {
|
||||||
stored.whois = rc.whois.Clone()
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if rc.arp != nil {
|
host := s.dhcp.HostByIP(ip)
|
||||||
stored.arp = slices.Clone(rc.arp)
|
if host == "" {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if rc.rdns != nil {
|
rc = s.runtimeIndex.setInfo(ip, SourceDHCP, []string{host})
|
||||||
stored.rdns = slices.Clone(rc.rdns)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rc.dhcp != nil {
|
return rc.clone()
|
||||||
stored.dhcp = slices.Clone(rc.dhcp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rc.hostsFile != nil {
|
|
||||||
stored.hostsFile = slices.Clone(rc.hostsFile)
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchUpdateBySource updates the stored runtime clients information from the
|
|
||||||
// specified source and returns the number of added and removed clients.
|
|
||||||
func (s *Storage) BatchUpdateBySource(src Source, rcs []*Runtime) (added, removed int) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
for _, rc := range s.runtimeIndex.index {
|
|
||||||
rc.unset(src)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rc := range rcs {
|
|
||||||
if s.updateRuntimeLocked(rc) {
|
|
||||||
added++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for ip, rc := range s.runtimeIndex.index {
|
|
||||||
if rc.isEmpty() {
|
|
||||||
delete(s.runtimeIndex.index, ip)
|
|
||||||
removed++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return added, removed
|
|
||||||
}
|
|
||||||
|
|
||||||
// SizeRuntime returns the number of the runtime clients.
|
|
||||||
func (s *Storage) SizeRuntime() (n int) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
return s.runtimeIndex.Size()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RangeRuntime calls f for each runtime client in an undefined order.
|
// RangeRuntime calls f for each runtime client in an undefined order.
|
||||||
|
@ -316,16 +579,11 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
s.runtimeIndex.Range(f)
|
s.runtimeIndex.rangeClients(f)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteBySource removes all runtime clients that have information only from
|
// AllowedTags returns the list of available client tags. tags must not be
|
||||||
// the specified source and returns the number of removed clients.
|
// modified.
|
||||||
//
|
func (s *Storage) AllowedTags() (tags []string) {
|
||||||
// TODO(s.chzhen): Use it.
|
return s.allowedTags
|
||||||
func (s *Storage) DeleteBySource(src Source) (n int) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
return s.runtimeIndex.DeleteBySource(src)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,23 +3,513 @@ package client_test
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
|
"github.com/AdguardTeam/golibs/hostsfile"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// testHostsContainer is a mock implementation of the [client.HostsContainer]
|
||||||
|
// interface.
|
||||||
|
type testHostsContainer struct {
|
||||||
|
onUpd func() (updates <-chan *hostsfile.DefaultStorage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ client.HostsContainer = (*testHostsContainer)(nil)
|
||||||
|
|
||||||
|
// Upd implements the [client.HostsContainer] interface for *testHostsContainer.
|
||||||
|
func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) {
|
||||||
|
return c.onUpd()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Interface stores and refreshes the network neighborhood reported by ARP
|
||||||
|
// (Address Resolution Protocol).
|
||||||
|
type Interface interface {
|
||||||
|
// Refresh updates the stored data. It must be safe for concurrent use.
|
||||||
|
Refresh() (err error)
|
||||||
|
|
||||||
|
// Neighbors returnes the last set of data reported by ARP. Both the method
|
||||||
|
// and it's result must be safe for concurrent use.
|
||||||
|
Neighbors() (ns []arpdb.Neighbor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// testARPDB is a mock implementation of the [arpdb.Interface].
|
||||||
|
type testARPDB struct {
|
||||||
|
onRefresh func() (err error)
|
||||||
|
onNeighbors func() (ns []arpdb.Neighbor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ arpdb.Interface = (*testARPDB)(nil)
|
||||||
|
|
||||||
|
// Refresh implements the [arpdb.Interface] interface for *testARP.
|
||||||
|
func (c *testARPDB) Refresh() (err error) {
|
||||||
|
return c.onRefresh()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Neighbors implements the [arpdb.Interface] interface for *testARP.
|
||||||
|
func (c *testARPDB) Neighbors() (ns []arpdb.Neighbor) {
|
||||||
|
return c.onNeighbors()
|
||||||
|
}
|
||||||
|
|
||||||
|
// testDHCP is a mock implementation of the [client.DHCP].
|
||||||
|
type testDHCP struct {
|
||||||
|
OnLeases func() (leases []*dhcpsvc.Lease)
|
||||||
|
OnHostBy func(ip netip.Addr) (host string)
|
||||||
|
OnMACBy func(ip netip.Addr) (mac net.HardwareAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ client.DHCP = (*testDHCP)(nil)
|
||||||
|
|
||||||
|
// Lease implements the [client.DHCP] interface for *testDHCP.
|
||||||
|
func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() }
|
||||||
|
|
||||||
|
// HostByIP implements the [client.DHCP] interface for *testDHCP.
|
||||||
|
func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) }
|
||||||
|
|
||||||
|
// MACByIP implements the [client.DHCP] interface for *testDHCP.
|
||||||
|
func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) }
|
||||||
|
|
||||||
|
// compareRuntimeInfo is a helper function that returns true if the runtime
|
||||||
|
// client has provided info.
|
||||||
|
func compareRuntimeInfo(rc *client.Runtime, src client.Source, host string) (ok bool) {
|
||||||
|
s, h := rc.Info()
|
||||||
|
if s != src {
|
||||||
|
return false
|
||||||
|
} else if h != host {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_Add_hostsfile(t *testing.T) {
|
||||||
|
var (
|
||||||
|
cliIP1 = netip.MustParseAddr("1.1.1.1")
|
||||||
|
cliName1 = "client_one"
|
||||||
|
|
||||||
|
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
||||||
|
cliName2 = "client_two"
|
||||||
|
)
|
||||||
|
|
||||||
|
hostCh := make(chan *hostsfile.DefaultStorage)
|
||||||
|
h := &testHostsContainer{
|
||||||
|
onUpd: func() (updates <-chan *hostsfile.DefaultStorage) { return hostCh },
|
||||||
|
}
|
||||||
|
|
||||||
|
storage, err := client.NewStorage(&client.StorageConfig{
|
||||||
|
DHCP: client.EmptyDHCP{},
|
||||||
|
EtcHosts: h,
|
||||||
|
ARPClientsUpdatePeriod: testTimeout / 10,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = storage.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||||
|
return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add_hosts", func(t *testing.T) {
|
||||||
|
var s *hostsfile.DefaultStorage
|
||||||
|
s, err = hostsfile.NewDefaultStorage()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s.Add(&hostsfile.Record{
|
||||||
|
Addr: cliIP1,
|
||||||
|
Names: []string{cliName1},
|
||||||
|
})
|
||||||
|
|
||||||
|
testutil.RequireSend(t, hostCh, s, testTimeout)
|
||||||
|
|
||||||
|
require.Eventually(t, func() (ok bool) {
|
||||||
|
cli1 := storage.ClientRuntime(cliIP1)
|
||||||
|
if cli1 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, compareRuntimeInfo(cli1, client.SourceHostsFile, cliName1))
|
||||||
|
|
||||||
|
return true
|
||||||
|
}, testTimeout, testTimeout/10)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update_hosts", func(t *testing.T) {
|
||||||
|
var s *hostsfile.DefaultStorage
|
||||||
|
s, err = hostsfile.NewDefaultStorage()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s.Add(&hostsfile.Record{
|
||||||
|
Addr: cliIP2,
|
||||||
|
Names: []string{cliName2},
|
||||||
|
})
|
||||||
|
|
||||||
|
testutil.RequireSend(t, hostCh, s, testTimeout)
|
||||||
|
|
||||||
|
require.Eventually(t, func() (ok bool) {
|
||||||
|
cli2 := storage.ClientRuntime(cliIP2)
|
||||||
|
if cli2 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, compareRuntimeInfo(cli2, client.SourceHostsFile, cliName2))
|
||||||
|
|
||||||
|
cli1 := storage.ClientRuntime(cliIP1)
|
||||||
|
require.Nil(t, cli1)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}, testTimeout, testTimeout/10)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_Add_arp(t *testing.T) {
|
||||||
|
var (
|
||||||
|
mu sync.Mutex
|
||||||
|
neighbors []arpdb.Neighbor
|
||||||
|
|
||||||
|
cliIP1 = netip.MustParseAddr("1.1.1.1")
|
||||||
|
cliName1 = "client_one"
|
||||||
|
|
||||||
|
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
||||||
|
cliName2 = "client_two"
|
||||||
|
)
|
||||||
|
|
||||||
|
a := &testARPDB{
|
||||||
|
onRefresh: func() (err error) { return nil },
|
||||||
|
onNeighbors: func() (ns []arpdb.Neighbor) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
return neighbors
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
storage, err := client.NewStorage(&client.StorageConfig{
|
||||||
|
DHCP: client.EmptyDHCP{},
|
||||||
|
ARPDB: a,
|
||||||
|
ARPClientsUpdatePeriod: testTimeout / 10,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = storage.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||||
|
return storage.Shutdown(testutil.ContextWithTimeout(t, testTimeout))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("add_hosts", func(t *testing.T) {
|
||||||
|
func() {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
neighbors = []arpdb.Neighbor{{
|
||||||
|
Name: cliName1,
|
||||||
|
IP: cliIP1,
|
||||||
|
}}
|
||||||
|
}()
|
||||||
|
|
||||||
|
require.Eventually(t, func() (ok bool) {
|
||||||
|
cli1 := storage.ClientRuntime(cliIP1)
|
||||||
|
if cli1 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, compareRuntimeInfo(cli1, client.SourceARP, cliName1))
|
||||||
|
|
||||||
|
return true
|
||||||
|
}, testTimeout, testTimeout/10)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("update_hosts", func(t *testing.T) {
|
||||||
|
func() {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
neighbors = []arpdb.Neighbor{{
|
||||||
|
Name: cliName2,
|
||||||
|
IP: cliIP2,
|
||||||
|
}}
|
||||||
|
}()
|
||||||
|
|
||||||
|
require.Eventually(t, func() (ok bool) {
|
||||||
|
cli2 := storage.ClientRuntime(cliIP2)
|
||||||
|
if cli2 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.True(t, compareRuntimeInfo(cli2, client.SourceARP, cliName2))
|
||||||
|
|
||||||
|
cli1 := storage.ClientRuntime(cliIP1)
|
||||||
|
require.Nil(t, cli1)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}, testTimeout, testTimeout/10)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStorage_Add_whois(t *testing.T) {
|
||||||
|
var (
|
||||||
|
cliIP1 = netip.MustParseAddr("1.1.1.1")
|
||||||
|
|
||||||
|
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
||||||
|
cliName2 = "client_two"
|
||||||
|
|
||||||
|
cliIP3 = netip.MustParseAddr("3.3.3.3")
|
||||||
|
cliName3 = "client_three"
|
||||||
|
)
|
||||||
|
|
||||||
|
storage, err := client.NewStorage(&client.StorageConfig{
|
||||||
|
DHCP: client.EmptyDHCP{},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
whois := &whois.Info{
|
||||||
|
Country: "AU",
|
||||||
|
Orgname: "Example Org",
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("new_client", func(t *testing.T) {
|
||||||
|
storage.UpdateAddress(cliIP1, "", whois)
|
||||||
|
cli1 := storage.ClientRuntime(cliIP1)
|
||||||
|
require.NotNil(t, cli1)
|
||||||
|
|
||||||
|
assert.Equal(t, whois, cli1.WHOIS())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("existing_runtime_client", func(t *testing.T) {
|
||||||
|
storage.UpdateAddress(cliIP2, cliName2, nil)
|
||||||
|
storage.UpdateAddress(cliIP2, "", whois)
|
||||||
|
|
||||||
|
cli2 := storage.ClientRuntime(cliIP2)
|
||||||
|
require.NotNil(t, cli2)
|
||||||
|
|
||||||
|
assert.True(t, compareRuntimeInfo(cli2, client.SourceRDNS, cliName2))
|
||||||
|
|
||||||
|
assert.Equal(t, whois, cli2.WHOIS())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("can't_set_persistent_client", func(t *testing.T) {
|
||||||
|
err = storage.Add(&client.Persistent{
|
||||||
|
Name: cliName3,
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
IPs: []netip.Addr{cliIP3},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
storage.UpdateAddress(cliIP3, "", whois)
|
||||||
|
rc := storage.ClientRuntime(cliIP3)
|
||||||
|
require.Nil(t, rc)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientsDHCP(t *testing.T) {
|
||||||
|
var (
|
||||||
|
cliIP1 = netip.MustParseAddr("1.1.1.1")
|
||||||
|
cliName1 = "one.dhcp"
|
||||||
|
|
||||||
|
cliIP2 = netip.MustParseAddr("2.2.2.2")
|
||||||
|
cliMAC2 = mustParseMAC("22:22:22:22:22:22")
|
||||||
|
cliName2 = "two.dhcp"
|
||||||
|
|
||||||
|
cliIP3 = netip.MustParseAddr("3.3.3.3")
|
||||||
|
cliMAC3 = mustParseMAC("33:33:33:33:33:33")
|
||||||
|
cliName3 = "three.dhcp"
|
||||||
|
|
||||||
|
prsCliIP = netip.MustParseAddr("4.3.2.1")
|
||||||
|
prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA")
|
||||||
|
prsCliName = "persistent.dhcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
ipToHost := map[netip.Addr]string{
|
||||||
|
cliIP1: cliName1,
|
||||||
|
}
|
||||||
|
ipToMAC := map[netip.Addr]net.HardwareAddr{
|
||||||
|
prsCliIP: prsCliMAC,
|
||||||
|
}
|
||||||
|
|
||||||
|
leases := []*dhcpsvc.Lease{{
|
||||||
|
IP: cliIP2,
|
||||||
|
Hostname: cliName2,
|
||||||
|
HWAddr: cliMAC2,
|
||||||
|
}, {
|
||||||
|
IP: cliIP3,
|
||||||
|
Hostname: cliName3,
|
||||||
|
HWAddr: cliMAC3,
|
||||||
|
}}
|
||||||
|
|
||||||
|
d := &testDHCP{
|
||||||
|
OnLeases: func() (ls []*dhcpsvc.Lease) {
|
||||||
|
return leases
|
||||||
|
},
|
||||||
|
OnHostBy: func(ip netip.Addr) (host string) {
|
||||||
|
return ipToHost[ip]
|
||||||
|
},
|
||||||
|
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) {
|
||||||
|
return ipToMAC[ip]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
storage, err := client.NewStorage(&client.StorageConfig{
|
||||||
|
DHCP: d,
|
||||||
|
RuntimeSourceDHCP: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("find_runtime", func(t *testing.T) {
|
||||||
|
cli1 := storage.ClientRuntime(cliIP1)
|
||||||
|
require.NotNil(t, cli1)
|
||||||
|
|
||||||
|
assert.True(t, compareRuntimeInfo(cli1, client.SourceDHCP, cliName1))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("find_persistent", func(t *testing.T) {
|
||||||
|
err = storage.Add(&client.Persistent{
|
||||||
|
Name: prsCliName,
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
MACs: []net.HardwareAddr{prsCliMAC},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
prsCli, ok := storage.Find(prsCliIP.String())
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
assert.Equal(t, prsCliName, prsCli.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("leases", func(t *testing.T) {
|
||||||
|
delete(ipToHost, cliIP1)
|
||||||
|
storage.UpdateDHCP()
|
||||||
|
|
||||||
|
cli1 := storage.ClientRuntime(cliIP1)
|
||||||
|
require.Nil(t, cli1)
|
||||||
|
|
||||||
|
for i, l := range leases {
|
||||||
|
cli := storage.ClientRuntime(l.IP)
|
||||||
|
require.NotNil(t, cli)
|
||||||
|
|
||||||
|
src, host := cli.Info()
|
||||||
|
assert.Equal(t, client.SourceDHCP, src)
|
||||||
|
assert.Equal(t, leases[i].Hostname, host)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("range", func(t *testing.T) {
|
||||||
|
s := 0
|
||||||
|
storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
|
||||||
|
s++
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, len(leases), s)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientsAddExisting(t *testing.T) {
|
||||||
|
t.Run("simple", func(t *testing.T) {
|
||||||
|
storage, err := client.NewStorage(&client.StorageConfig{
|
||||||
|
DHCP: client.EmptyDHCP{},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ip := netip.MustParseAddr("1.1.1.1")
|
||||||
|
|
||||||
|
// Add a client.
|
||||||
|
err = storage.Add(&client.Persistent{
|
||||||
|
Name: "client1",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
||||||
|
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
|
||||||
|
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Now add an auto-client with the same IP.
|
||||||
|
storage.UpdateAddress(ip, "test", nil)
|
||||||
|
rc := storage.ClientRuntime(ip)
|
||||||
|
assert.True(t, compareRuntimeInfo(rc, client.SourceRDNS, "test"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("complicated", func(t *testing.T) {
|
||||||
|
// TODO(a.garipov): Properly decouple the DHCP server from the client
|
||||||
|
// storage.
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
t.Skip("skipping dhcp test on windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
// First, init a DHCP server with a single static lease.
|
||||||
|
config := &dhcpd.ServerConfig{
|
||||||
|
Enabled: true,
|
||||||
|
DataDir: t.TempDir(),
|
||||||
|
Conf4: dhcpd.V4ServerConf{
|
||||||
|
Enabled: true,
|
||||||
|
GatewayIP: netip.MustParseAddr("1.2.3.1"),
|
||||||
|
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||||
|
RangeStart: netip.MustParseAddr("1.2.3.2"),
|
||||||
|
RangeEnd: netip.MustParseAddr("1.2.3.10"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
dhcpServer, err := dhcpd.Create(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
storage, err := client.NewStorage(&client.StorageConfig{
|
||||||
|
DHCP: dhcpServer,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ip := netip.MustParseAddr("1.2.3.4")
|
||||||
|
|
||||||
|
err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{
|
||||||
|
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||||
|
IP: ip,
|
||||||
|
Hostname: "testhost",
|
||||||
|
Expiry: time.Now().Add(time.Hour),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add a new client with the same IP as for a client with MAC.
|
||||||
|
err = storage.Add(&client.Persistent{
|
||||||
|
Name: "client2",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
IPs: []netip.Addr{ip},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add a new client with the IP from the first client's IP range.
|
||||||
|
err = storage.Add(&client.Persistent{
|
||||||
|
Name: "client3",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// newStorage is a helper function that returns a client storage filled with
|
// newStorage is a helper function that returns a client storage filled with
|
||||||
// persistent clients from the m. It also generates a UID for each client.
|
// persistent clients from the m. It also generates a UID for each client.
|
||||||
func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
|
func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
|
||||||
tb.Helper()
|
tb.Helper()
|
||||||
|
|
||||||
s = client.NewStorage(&client.Config{
|
s, err := client.NewStorage(&client.StorageConfig{
|
||||||
AllowedTags: nil,
|
DHCP: client.EmptyDHCP{},
|
||||||
})
|
})
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
for _, c := range m {
|
for _, c := range m {
|
||||||
c.UID = client.MustNewUID()
|
c.UID = client.MustNewUID()
|
||||||
|
@ -31,14 +521,6 @@ func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) {
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// newRuntimeClient is a helper function that returns a new runtime client.
|
|
||||||
func newRuntimeClient(ip netip.Addr, source client.Source, host string) (rc *client.Runtime) {
|
|
||||||
rc = client.NewRuntime(ip)
|
|
||||||
rc.SetInfo(source, []string{host})
|
|
||||||
|
|
||||||
return rc
|
|
||||||
}
|
|
||||||
|
|
||||||
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
||||||
// error.
|
// error.
|
||||||
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
||||||
|
@ -55,7 +537,7 @@ func TestStorage_Add(t *testing.T) {
|
||||||
existingName = "existing_name"
|
existingName = "existing_name"
|
||||||
existingClientID = "existing_client_id"
|
existingClientID = "existing_client_id"
|
||||||
|
|
||||||
allowedTag = "tag"
|
allowedTag = "user_admin"
|
||||||
notAllowedTag = "not_allowed_tag"
|
notAllowedTag = "not_allowed_tag"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -73,10 +555,20 @@ func TestStorage_Add(t *testing.T) {
|
||||||
UID: existingClientUID,
|
UID: existingClientUID,
|
||||||
}
|
}
|
||||||
|
|
||||||
s := client.NewStorage(&client.Config{
|
s, err := client.NewStorage(&client.StorageConfig{})
|
||||||
AllowedTags: []string{allowedTag},
|
require.NoError(t, err)
|
||||||
})
|
|
||||||
err := s.Add(existingClient)
|
tags := s.AllowedTags()
|
||||||
|
require.NotZero(t, len(tags))
|
||||||
|
require.True(t, slices.IsSorted(tags))
|
||||||
|
|
||||||
|
_, ok := slices.BinarySearch(tags, allowedTag)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
_, ok = slices.BinarySearch(tags, notAllowedTag)
|
||||||
|
require.False(t, ok)
|
||||||
|
|
||||||
|
err = s.Add(existingClient)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
|
@ -136,12 +628,43 @@ func TestStorage_Add(t *testing.T) {
|
||||||
}, {
|
}, {
|
||||||
name: "not_allowed_tag",
|
name: "not_allowed_tag",
|
||||||
cli: &client.Persistent{
|
cli: &client.Persistent{
|
||||||
Name: "nont_allowed_tag",
|
Name: "not_allowed_tag",
|
||||||
Tags: []string{notAllowedTag},
|
Tags: []string{notAllowedTag},
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
|
IPs: []netip.Addr{netip.MustParseAddr("4.4.4.4")},
|
||||||
UID: client.MustNewUID(),
|
UID: client.MustNewUID(),
|
||||||
},
|
},
|
||||||
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
|
wantErrMsg: `adding client: invalid tag: "not_allowed_tag"`,
|
||||||
|
}, {
|
||||||
|
name: "allowed_tag",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "allowed_tag",
|
||||||
|
Tags: []string{allowedTag},
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("5.5.5.5")},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}, {
|
||||||
|
name: "",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "",
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("6.6.6.6")},
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: "adding client: empty name",
|
||||||
|
}, {
|
||||||
|
name: "no_id",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "no_id",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
|
},
|
||||||
|
wantErrMsg: "adding client: id required",
|
||||||
|
}, {
|
||||||
|
name: "no_uid",
|
||||||
|
cli: &client.Persistent{
|
||||||
|
Name: "no_uid",
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr("7.7.7.7")},
|
||||||
|
},
|
||||||
|
wantErrMsg: "adding client: uid required",
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
@ -164,10 +687,10 @@ func TestStorage_RemoveByName(t *testing.T) {
|
||||||
UID: client.MustNewUID(),
|
UID: client.MustNewUID(),
|
||||||
}
|
}
|
||||||
|
|
||||||
s := client.NewStorage(&client.Config{
|
s, err := client.NewStorage(&client.StorageConfig{})
|
||||||
AllowedTags: nil,
|
require.NoError(t, err)
|
||||||
})
|
|
||||||
err := s.Add(existingClient)
|
err = s.Add(existingClient)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
|
@ -191,9 +714,9 @@ func TestStorage_RemoveByName(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("duplicate_remove", func(t *testing.T) {
|
t.Run("duplicate_remove", func(t *testing.T) {
|
||||||
s = client.NewStorage(&client.Config{
|
s, err = client.NewStorage(&client.StorageConfig{})
|
||||||
AllowedTags: nil,
|
require.NoError(t, err)
|
||||||
})
|
|
||||||
err = s.Add(existingClient)
|
err = s.Add(existingClient)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -623,157 +1146,3 @@ func TestStorage_RangeByName(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStorage_UpdateRuntime(t *testing.T) {
|
|
||||||
const (
|
|
||||||
addedARP = "added_arp"
|
|
||||||
addedSecondARP = "added_arp"
|
|
||||||
|
|
||||||
updatedARP = "updated_arp"
|
|
||||||
|
|
||||||
cliCity = "City"
|
|
||||||
cliCountry = "Country"
|
|
||||||
cliOrgname = "Orgname"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ip = netip.MustParseAddr("1.1.1.1")
|
|
||||||
ip2 = netip.MustParseAddr("2.2.2.2")
|
|
||||||
)
|
|
||||||
|
|
||||||
updated := client.NewRuntime(ip)
|
|
||||||
updated.SetInfo(client.SourceARP, []string{updatedARP})
|
|
||||||
|
|
||||||
info := &whois.Info{
|
|
||||||
City: cliCity,
|
|
||||||
Country: cliCountry,
|
|
||||||
Orgname: cliOrgname,
|
|
||||||
}
|
|
||||||
updated.SetWHOIS(info)
|
|
||||||
|
|
||||||
s := client.NewStorage(&client.Config{
|
|
||||||
AllowedTags: nil,
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("add_arp_client", func(t *testing.T) {
|
|
||||||
added := client.NewRuntime(ip)
|
|
||||||
added.SetInfo(client.SourceARP, []string{addedARP})
|
|
||||||
|
|
||||||
require.True(t, s.UpdateRuntime(added))
|
|
||||||
require.Equal(t, 1, s.SizeRuntime())
|
|
||||||
|
|
||||||
got := s.ClientRuntime(ip)
|
|
||||||
source, host := got.Info()
|
|
||||||
assert.Equal(t, client.SourceARP, source)
|
|
||||||
assert.Equal(t, addedARP, host)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("add_second_arp_client", func(t *testing.T) {
|
|
||||||
added := client.NewRuntime(ip2)
|
|
||||||
added.SetInfo(client.SourceARP, []string{addedSecondARP})
|
|
||||||
|
|
||||||
require.True(t, s.UpdateRuntime(added))
|
|
||||||
require.Equal(t, 2, s.SizeRuntime())
|
|
||||||
|
|
||||||
got := s.ClientRuntime(ip2)
|
|
||||||
source, host := got.Info()
|
|
||||||
assert.Equal(t, client.SourceARP, source)
|
|
||||||
assert.Equal(t, addedSecondARP, host)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("update_first_client", func(t *testing.T) {
|
|
||||||
require.False(t, s.UpdateRuntime(updated))
|
|
||||||
got := s.ClientRuntime(ip)
|
|
||||||
require.Equal(t, 2, s.SizeRuntime())
|
|
||||||
|
|
||||||
source, host := got.Info()
|
|
||||||
assert.Equal(t, client.SourceARP, source)
|
|
||||||
assert.Equal(t, updatedARP, host)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("remove_arp_info", func(t *testing.T) {
|
|
||||||
n := s.DeleteBySource(client.SourceARP)
|
|
||||||
require.Equal(t, 1, n)
|
|
||||||
require.Equal(t, 1, s.SizeRuntime())
|
|
||||||
|
|
||||||
got := s.ClientRuntime(ip)
|
|
||||||
source, _ := got.Info()
|
|
||||||
assert.Equal(t, client.SourceWHOIS, source)
|
|
||||||
assert.Equal(t, info, got.WHOIS())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("remove_whois_info", func(t *testing.T) {
|
|
||||||
n := s.DeleteBySource(client.SourceWHOIS)
|
|
||||||
require.Equal(t, 1, n)
|
|
||||||
require.Equal(t, 0, s.SizeRuntime())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStorage_BatchUpdateBySource(t *testing.T) {
|
|
||||||
const (
|
|
||||||
defSrc = client.SourceARP
|
|
||||||
|
|
||||||
cliFirstHost1 = "host1"
|
|
||||||
cliFirstHost2 = "host2"
|
|
||||||
cliUpdatedHost3 = "host3"
|
|
||||||
cliUpdatedHost4 = "host4"
|
|
||||||
cliUpdatedHost5 = "host5"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
cliFirstIP1 = netip.MustParseAddr("1.1.1.1")
|
|
||||||
cliFirstIP2 = netip.MustParseAddr("2.2.2.2")
|
|
||||||
cliUpdatedIP3 = netip.MustParseAddr("3.3.3.3")
|
|
||||||
cliUpdatedIP4 = netip.MustParseAddr("4.4.4.4")
|
|
||||||
cliUpdatedIP5 = netip.MustParseAddr("5.5.5.5")
|
|
||||||
)
|
|
||||||
|
|
||||||
firstClients := []*client.Runtime{
|
|
||||||
newRuntimeClient(cliFirstIP1, defSrc, cliFirstHost1),
|
|
||||||
newRuntimeClient(cliFirstIP2, defSrc, cliFirstHost2),
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedClients := []*client.Runtime{
|
|
||||||
newRuntimeClient(cliUpdatedIP3, defSrc, cliUpdatedHost3),
|
|
||||||
newRuntimeClient(cliUpdatedIP4, defSrc, cliUpdatedHost4),
|
|
||||||
newRuntimeClient(cliUpdatedIP5, defSrc, cliUpdatedHost5),
|
|
||||||
}
|
|
||||||
|
|
||||||
s := client.NewStorage(&client.Config{
|
|
||||||
AllowedTags: nil,
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("populate_storage_with_first_clients", func(t *testing.T) {
|
|
||||||
added, removed := s.BatchUpdateBySource(defSrc, firstClients)
|
|
||||||
require.Equal(t, len(firstClients), added)
|
|
||||||
require.Equal(t, 0, removed)
|
|
||||||
require.Equal(t, len(firstClients), s.SizeRuntime())
|
|
||||||
|
|
||||||
rc := s.ClientRuntime(cliFirstIP1)
|
|
||||||
src, host := rc.Info()
|
|
||||||
assert.Equal(t, defSrc, src)
|
|
||||||
assert.Equal(t, cliFirstHost1, host)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("update_storage", func(t *testing.T) {
|
|
||||||
added, removed := s.BatchUpdateBySource(defSrc, updatedClients)
|
|
||||||
require.Equal(t, len(updatedClients), added)
|
|
||||||
require.Equal(t, len(firstClients), removed)
|
|
||||||
require.Equal(t, len(updatedClients), s.SizeRuntime())
|
|
||||||
|
|
||||||
rc := s.ClientRuntime(cliUpdatedIP3)
|
|
||||||
src, host := rc.Info()
|
|
||||||
assert.Equal(t, defSrc, src)
|
|
||||||
assert.Equal(t, cliUpdatedHost3, host)
|
|
||||||
|
|
||||||
rc = s.ClientRuntime(cliFirstIP1)
|
|
||||||
assert.Nil(t, rc)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("remove_all", func(t *testing.T) {
|
|
||||||
added, removed := s.BatchUpdateBySource(defSrc, []*client.Runtime{})
|
|
||||||
require.Equal(t, 0, added)
|
|
||||||
require.Equal(t, len(updatedClients), removed)
|
|
||||||
require.Equal(t, 0, s.SizeRuntime())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
package home
|
package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -11,7 +11,6 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||||
|
@ -20,47 +19,18 @@ import (
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/hostsfile"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
|
|
||||||
// needs.
|
|
||||||
type DHCP interface {
|
|
||||||
// Leases returns all the DHCP leases.
|
|
||||||
Leases() (leases []*dhcpsvc.Lease)
|
|
||||||
|
|
||||||
// HostByIP returns the hostname of the DHCP client with the given IP
|
|
||||||
// address. The address will be netip.Addr{} if there is no such client,
|
|
||||||
// due to an assumption that a DHCP client must always have a hostname.
|
|
||||||
HostByIP(ip netip.Addr) (host string)
|
|
||||||
|
|
||||||
// MACByIP returns the MAC address for the given IP address leased. It
|
|
||||||
// returns nil if there is no such client, due to an assumption that a DHCP
|
|
||||||
// client must always have a MAC address.
|
|
||||||
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientsContainer is the storage of all runtime and persistent clients.
|
// clientsContainer is the storage of all runtime and persistent clients.
|
||||||
type clientsContainer struct {
|
type clientsContainer struct {
|
||||||
// storage stores information about persistent clients.
|
// storage stores information about persistent clients.
|
||||||
storage *client.Storage
|
storage *client.Storage
|
||||||
|
|
||||||
// dhcp is the DHCP service implementation.
|
|
||||||
dhcp DHCP
|
|
||||||
|
|
||||||
// clientChecker checks if a client is blocked by the current access
|
// clientChecker checks if a client is blocked by the current access
|
||||||
// settings.
|
// settings.
|
||||||
clientChecker BlockedClientChecker
|
clientChecker BlockedClientChecker
|
||||||
|
|
||||||
// etcHosts contains list of rewrite rules taken from the operating system's
|
|
||||||
// hosts database.
|
|
||||||
etcHosts *aghnet.HostsContainer
|
|
||||||
|
|
||||||
// arpDB stores the neighbors retrieved from ARP.
|
|
||||||
arpDB arpdb.Interface
|
|
||||||
|
|
||||||
// lock protects all fields.
|
// lock protects all fields.
|
||||||
//
|
//
|
||||||
// TODO(a.garipov): Use a pointer and describe which fields are protected in
|
// TODO(a.garipov): Use a pointer and describe which fields are protected in
|
||||||
|
@ -92,7 +62,7 @@ type BlockedClientChecker interface {
|
||||||
// Note: this function must be called only once
|
// Note: this function must be called only once
|
||||||
func (clients *clientsContainer) Init(
|
func (clients *clientsContainer) Init(
|
||||||
objects []*clientObject,
|
objects []*clientObject,
|
||||||
dhcpServer DHCP,
|
dhcpServer client.DHCP,
|
||||||
etcHosts *aghnet.HostsContainer,
|
etcHosts *aghnet.HostsContainer,
|
||||||
arpDB arpdb.Interface,
|
arpDB arpdb.Interface,
|
||||||
filteringConf *filtering.Config,
|
filteringConf *filtering.Config,
|
||||||
|
@ -102,26 +72,15 @@ func (clients *clientsContainer) Init(
|
||||||
return errors.Error("clients container already initialized")
|
return errors.Error("clients container already initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
clients.storage = client.NewStorage(&client.Config{
|
confClients := make([]*client.Persistent, 0, len(objects))
|
||||||
AllowedTags: clientTags,
|
for i, o := range objects {
|
||||||
})
|
var p *client.Persistent
|
||||||
|
p, err = o.toPersistent(filteringConf)
|
||||||
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
|
|
||||||
clients.dhcp = dhcpServer
|
|
||||||
|
|
||||||
clients.etcHosts = etcHosts
|
|
||||||
clients.arpDB = arpDB
|
|
||||||
err = clients.addFromConfig(objects, filteringConf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
return fmt.Errorf("init persistent client at index %d: %w", i, err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
confClients = append(confClients, p)
|
||||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
|
||||||
|
|
||||||
if clients.testing {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The clients.etcHosts may be nil even if config.Clients.Sources.HostsFile
|
// The clients.etcHosts may be nil even if config.Clients.Sources.HostsFile
|
||||||
|
@ -130,21 +89,26 @@ func (clients *clientsContainer) Init(
|
||||||
// TODO(e.burkov): The option should probably be returned, since hosts file
|
// TODO(e.burkov): The option should probably be returned, since hosts file
|
||||||
// currently used not only for clients' information enrichment, but also in
|
// currently used not only for clients' information enrichment, but also in
|
||||||
// the filtering module and upstream addresses resolution.
|
// the filtering module and upstream addresses resolution.
|
||||||
if config.Clients.Sources.HostsFile && clients.etcHosts != nil {
|
var hosts client.HostsContainer = etcHosts
|
||||||
go clients.handleHostsUpdates()
|
if !config.Clients.Sources.HostsFile {
|
||||||
|
hosts = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
clients.storage, err = client.NewStorage(&client.StorageConfig{
|
||||||
|
InitialClients: confClients,
|
||||||
|
DHCP: dhcpServer,
|
||||||
|
EtcHosts: hosts,
|
||||||
|
ARPDB: arpDB,
|
||||||
|
ARPClientsUpdatePeriod: arpClientsUpdatePeriod,
|
||||||
|
RuntimeSourceDHCP: config.Clients.Sources.DHCP,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("init client storage: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleHostsUpdates receives the updates from the hosts container and adds
|
|
||||||
// them to the clients container. It is intended to be used as a goroutine.
|
|
||||||
func (clients *clientsContainer) handleHostsUpdates() {
|
|
||||||
for upd := range clients.etcHosts.Upd() {
|
|
||||||
clients.addFromHostsFile(upd)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// webHandlersRegistered prevents a [clientsContainer] from registering its web
|
// webHandlersRegistered prevents a [clientsContainer] from registering its web
|
||||||
// handlers more than once.
|
// handlers more than once.
|
||||||
//
|
//
|
||||||
|
@ -152,7 +116,7 @@ func (clients *clientsContainer) handleHostsUpdates() {
|
||||||
var webHandlersRegistered = false
|
var webHandlersRegistered = false
|
||||||
|
|
||||||
// Start starts the clients container.
|
// Start starts the clients container.
|
||||||
func (clients *clientsContainer) Start() {
|
func (clients *clientsContainer) Start(ctx context.Context) (err error) {
|
||||||
if clients.testing {
|
if clients.testing {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -162,14 +126,7 @@ func (clients *clientsContainer) Start() {
|
||||||
clients.registerWebHandlers()
|
clients.registerWebHandlers()
|
||||||
}
|
}
|
||||||
|
|
||||||
go clients.periodicUpdate()
|
return clients.storage.Start(ctx)
|
||||||
}
|
|
||||||
|
|
||||||
// reloadARP reloads runtime clients from ARP, if configured.
|
|
||||||
func (clients *clientsContainer) reloadARP() {
|
|
||||||
if clients.arpDB != nil {
|
|
||||||
clients.addFromSystemARP()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// clientObject is the YAML representation of a persistent client.
|
// clientObject is the YAML representation of a persistent client.
|
||||||
|
@ -270,28 +227,6 @@ func (o *clientObject) toPersistent(
|
||||||
return cli, nil
|
return cli, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addFromConfig initializes the clients container with objects from the
|
|
||||||
// configuration file.
|
|
||||||
func (clients *clientsContainer) addFromConfig(
|
|
||||||
objects []*clientObject,
|
|
||||||
filteringConf *filtering.Config,
|
|
||||||
) (err error) {
|
|
||||||
for i, o := range objects {
|
|
||||||
var cli *client.Persistent
|
|
||||||
cli, err = o.toPersistent(filteringConf)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = clients.storage.Add(cli)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("adding client %q at index %d: %w", cli.Name, i, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// forConfig returns all currently known persistent clients as objects for the
|
// forConfig returns all currently known persistent clients as objects for the
|
||||||
// configuration file.
|
// configuration file.
|
||||||
func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||||
|
@ -332,39 +267,6 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||||
// arpClientsUpdatePeriod defines how often ARP clients are updated.
|
// arpClientsUpdatePeriod defines how often ARP clients are updated.
|
||||||
const arpClientsUpdatePeriod = 10 * time.Minute
|
const arpClientsUpdatePeriod = 10 * time.Minute
|
||||||
|
|
||||||
func (clients *clientsContainer) periodicUpdate() {
|
|
||||||
defer log.OnPanic("clients container")
|
|
||||||
|
|
||||||
for {
|
|
||||||
clients.reloadARP()
|
|
||||||
time.Sleep(arpClientsUpdatePeriod)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientSource checks if client with this IP address already exists and returns
|
|
||||||
// the source which updated it last. It returns [client.SourceNone] if the
|
|
||||||
// client doesn't exist. Note that it is only used in tests.
|
|
||||||
func (clients *clientsContainer) clientSource(ip netip.Addr) (src client.Source) {
|
|
||||||
clients.lock.Lock()
|
|
||||||
defer clients.lock.Unlock()
|
|
||||||
|
|
||||||
_, ok := clients.findLocked(ip.String())
|
|
||||||
if ok {
|
|
||||||
return client.SourcePersistent
|
|
||||||
}
|
|
||||||
|
|
||||||
rc := clients.storage.ClientRuntime(ip)
|
|
||||||
if rc != nil {
|
|
||||||
src, _ = rc.Info()
|
|
||||||
}
|
|
||||||
|
|
||||||
if src < client.SourceDHCP && clients.dhcp.HostByIP(ip) != "" {
|
|
||||||
src = client.SourceDHCP
|
|
||||||
}
|
|
||||||
|
|
||||||
return src
|
|
||||||
}
|
|
||||||
|
|
||||||
// findMultiple is a wrapper around [clientsContainer.find] to make it a valid
|
// findMultiple is a wrapper around [clientsContainer.find] to make it a valid
|
||||||
// client finder for the query log. c is never nil; if no information about the
|
// client finder for the query log. c is never nil; if no information about the
|
||||||
// client is found, it returns an artificial client record by only setting the
|
// client is found, it returns an artificial client record by only setting the
|
||||||
|
@ -410,7 +312,7 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||||
}, false
|
}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
rc := clients.findRuntimeClient(ip)
|
rc := clients.storage.ClientRuntime(ip)
|
||||||
if rc != nil {
|
if rc != nil {
|
||||||
_, host := rc.Info()
|
_, host := rc.Info()
|
||||||
|
|
||||||
|
@ -425,19 +327,6 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||||
}, true
|
}, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// find returns a shallow copy of the client if there is one found.
|
|
||||||
func (clients *clientsContainer) find(id string) (c *client.Persistent, ok bool) {
|
|
||||||
clients.lock.Lock()
|
|
||||||
defer clients.lock.Unlock()
|
|
||||||
|
|
||||||
c, ok = clients.findLocked(id)
|
|
||||||
if !ok {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
|
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
|
||||||
// valid client information finder for the statistics. If no information about
|
// valid client information finder for the statistics. If no information about
|
||||||
// the client is found, it returns true.
|
// the client is found, it returns true.
|
||||||
|
@ -446,7 +335,7 @@ func (clients *clientsContainer) shouldCountClient(ids []string) (y bool) {
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
client, ok := clients.findLocked(id)
|
client, ok := clients.storage.Find(id)
|
||||||
if ok {
|
if ok {
|
||||||
return !client.IgnoreStatistics
|
return !client.IgnoreStatistics
|
||||||
}
|
}
|
||||||
|
@ -468,7 +357,7 @@ func (clients *clientsContainer) UpstreamConfigByID(
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
c, ok := clients.findLocked(id)
|
c, ok := clients.storage.Find(id)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
} else if c.UpstreamConfig != nil {
|
} else if c.UpstreamConfig != nil {
|
||||||
|
@ -506,198 +395,17 @@ func (clients *clientsContainer) UpstreamConfigByID(
|
||||||
return conf, nil
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// findLocked searches for a client by its ID. clients.lock is expected to be
|
|
||||||
// locked.
|
|
||||||
func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) {
|
|
||||||
c, ok = clients.storage.Find(id)
|
|
||||||
if ok {
|
|
||||||
return c, true
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, err := netip.ParseAddr(id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(e.burkov): Iterate through clients.list only once.
|
|
||||||
return clients.findDHCP(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// findDHCP searches for a client by its MAC, if the DHCP server is active and
|
|
||||||
// there is such client. clients.lock is expected to be locked.
|
|
||||||
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, ok bool) {
|
|
||||||
foundMAC := clients.dhcp.MACByIP(ip)
|
|
||||||
if foundMAC == nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return clients.storage.FindByMAC(foundMAC)
|
|
||||||
}
|
|
||||||
|
|
||||||
// findRuntimeClient finds a runtime client by their IP.
|
|
||||||
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Runtime) {
|
|
||||||
rc = clients.storage.ClientRuntime(ip)
|
|
||||||
host := clients.dhcp.HostByIP(ip)
|
|
||||||
|
|
||||||
if host != "" {
|
|
||||||
if rc == nil {
|
|
||||||
rc = client.NewRuntime(ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
rc.SetInfo(client.SourceDHCP, []string{host})
|
|
||||||
|
|
||||||
return rc
|
|
||||||
}
|
|
||||||
|
|
||||||
return rc
|
|
||||||
}
|
|
||||||
|
|
||||||
// setWHOISInfo sets the WHOIS information for a client. clients.lock is
|
|
||||||
// expected to be locked.
|
|
||||||
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *whois.Info) {
|
|
||||||
_, ok := clients.findLocked(ip.String())
|
|
||||||
if ok {
|
|
||||||
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rc := client.NewRuntime(ip)
|
|
||||||
rc.SetWHOIS(wi)
|
|
||||||
clients.storage.UpdateRuntime(rc)
|
|
||||||
|
|
||||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addHost adds a new IP-hostname pairing. The priorities of the sources are
|
|
||||||
// taken into account. ok is true if the pairing was added.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Only used in internal tests. Consider removing.
|
|
||||||
func (clients *clientsContainer) addHost(
|
|
||||||
ip netip.Addr,
|
|
||||||
host string,
|
|
||||||
src client.Source,
|
|
||||||
) (ok bool) {
|
|
||||||
clients.lock.Lock()
|
|
||||||
defer clients.lock.Unlock()
|
|
||||||
|
|
||||||
return clients.addHostLocked(ip, host, src)
|
|
||||||
}
|
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
var _ client.AddressUpdater = (*clientsContainer)(nil)
|
var _ client.AddressUpdater = (*clientsContainer)(nil)
|
||||||
|
|
||||||
// UpdateAddress implements the [client.AddressUpdater] interface for
|
// UpdateAddress implements the [client.AddressUpdater] interface for
|
||||||
// *clientsContainer
|
// *clientsContainer
|
||||||
func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
|
func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
|
||||||
// Common fast path optimization.
|
clients.storage.UpdateAddress(ip, host, info)
|
||||||
if host == "" && info == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
clients.lock.Lock()
|
|
||||||
defer clients.lock.Unlock()
|
|
||||||
|
|
||||||
if host != "" {
|
|
||||||
ok := clients.addHostLocked(ip, host, client.SourceRDNS)
|
|
||||||
if !ok {
|
|
||||||
log.Debug("clients: host for client %q already set with higher priority source", ip)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if info != nil {
|
|
||||||
clients.setWHOISInfo(ip, info)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
|
|
||||||
// locked.
|
|
||||||
func (clients *clientsContainer) addHostLocked(
|
|
||||||
ip netip.Addr,
|
|
||||||
host string,
|
|
||||||
src client.Source,
|
|
||||||
) (ok bool) {
|
|
||||||
rc := client.NewRuntime(ip)
|
|
||||||
rc.SetInfo(src, []string{host})
|
|
||||||
|
|
||||||
if config.Clients.Sources.DHCP {
|
|
||||||
if dhcpHost := clients.dhcp.HostByIP(ip); dhcpHost != "" {
|
|
||||||
rc.SetInfo(client.SourceDHCP, []string{dhcpHost})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
clients.storage.UpdateRuntime(rc)
|
|
||||||
|
|
||||||
log.Debug(
|
|
||||||
"clients: adding client info %s -> %q %q [%d]",
|
|
||||||
ip,
|
|
||||||
src,
|
|
||||||
host,
|
|
||||||
clients.storage.SizeRuntime(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// addFromHostsFile fills the client-hostname pairing index from the system's
|
|
||||||
// hosts files.
|
|
||||||
func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) {
|
|
||||||
clients.lock.Lock()
|
|
||||||
defer clients.lock.Unlock()
|
|
||||||
|
|
||||||
var rcs []*client.Runtime
|
|
||||||
hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) {
|
|
||||||
// Only the first name of the first record is considered a canonical
|
|
||||||
// hostname for the IP address.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Consider using all the names from all the records.
|
|
||||||
rc := client.NewRuntime(addr)
|
|
||||||
rc.SetInfo(client.SourceHostsFile, []string{names[0]})
|
|
||||||
|
|
||||||
rcs = append(rcs, rc)
|
|
||||||
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
|
|
||||||
added, removed := clients.storage.BatchUpdateBySource(client.SourceHostsFile, rcs)
|
|
||||||
log.Debug("clients: added %d, removed %d client aliases from system hosts file", added, removed)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
|
||||||
// command.
|
|
||||||
func (clients *clientsContainer) addFromSystemARP() {
|
|
||||||
if err := clients.arpDB.Refresh(); err != nil {
|
|
||||||
log.Error("refreshing arp container: %s", err)
|
|
||||||
|
|
||||||
clients.arpDB = arpdb.Empty{}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ns := clients.arpDB.Neighbors()
|
|
||||||
if len(ns) == 0 {
|
|
||||||
log.Debug("refreshing arp container: the update is empty")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
clients.lock.Lock()
|
|
||||||
defer clients.lock.Unlock()
|
|
||||||
|
|
||||||
var rcs []*client.Runtime
|
|
||||||
for _, n := range ns {
|
|
||||||
rc := client.NewRuntime(n.IP)
|
|
||||||
rc.SetInfo(client.SourceARP, []string{n.Name})
|
|
||||||
|
|
||||||
rcs = append(rcs, rc)
|
|
||||||
}
|
|
||||||
|
|
||||||
added, removed := clients.storage.BatchUpdateBySource(client.SourceARP, rcs)
|
|
||||||
log.Debug("clients: added %d, removed %d client aliases from arp neighborhood", added, removed)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// close gracefully closes all the client-specific upstream configurations of
|
// close gracefully closes all the client-specific upstream configurations of
|
||||||
// the persistent clients.
|
// the persistent clients.
|
||||||
func (clients *clientsContainer) close() (err error) {
|
func (clients *clientsContainer) close(ctx context.Context) (err error) {
|
||||||
return clients.storage.CloseUpstreams()
|
return clients.storage.Shutdown(ctx)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,34 +3,14 @@ package home
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testDHCP struct {
|
|
||||||
OnLeases func() (leases []*dhcpsvc.Lease)
|
|
||||||
OnHostBy func(ip netip.Addr) (host string)
|
|
||||||
OnMACBy func(ip netip.Addr) (mac net.HardwareAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lease implements the [DHCP] interface for testDHCP.
|
|
||||||
func (t *testDHCP) Leases() (leases []*dhcpsvc.Lease) { return t.OnLeases() }
|
|
||||||
|
|
||||||
// HostByIP implements the [DHCP] interface for testDHCP.
|
|
||||||
func (t *testDHCP) HostByIP(ip netip.Addr) (host string) { return t.OnHostBy(ip) }
|
|
||||||
|
|
||||||
// MACByIP implements the [DHCP] interface for testDHCP.
|
|
||||||
func (t *testDHCP) MACByIP(ip netip.Addr) (mac net.HardwareAddr) { return t.OnMACBy(ip) }
|
|
||||||
|
|
||||||
// newClientsContainer is a helper that creates a new clients container for
|
// newClientsContainer is a helper that creates a new clients container for
|
||||||
// tests.
|
// tests.
|
||||||
func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||||
|
@ -40,316 +20,11 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||||
testing: true,
|
testing: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
dhcp := &testDHCP{
|
require.NoError(t, c.Init(nil, client.EmptyDHCP{}, nil, nil, &filtering.Config{}))
|
||||||
OnLeases: func() (leases []*dhcpsvc.Lease) { return nil },
|
|
||||||
OnHostBy: func(ip netip.Addr) (host string) { return "" },
|
|
||||||
OnMACBy: func(ip netip.Addr) (mac net.HardwareAddr) { return nil },
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, c.Init(nil, dhcp, nil, nil, &filtering.Config{}))
|
|
||||||
|
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClients(t *testing.T) {
|
|
||||||
clients := newClientsContainer(t)
|
|
||||||
|
|
||||||
t.Run("add_success", func(t *testing.T) {
|
|
||||||
var (
|
|
||||||
cliNone = "1.2.3.4"
|
|
||||||
cli1 = "1.1.1.1"
|
|
||||||
cli2 = "2.2.2.2"
|
|
||||||
|
|
||||||
cli1IP = netip.MustParseAddr(cli1)
|
|
||||||
cli2IP = netip.MustParseAddr(cli2)
|
|
||||||
|
|
||||||
cliIPv6 = netip.MustParseAddr("1:2:3::4")
|
|
||||||
)
|
|
||||||
|
|
||||||
c := &client.Persistent{
|
|
||||||
Name: "client1",
|
|
||||||
UID: client.MustNewUID(),
|
|
||||||
IPs: []netip.Addr{cli1IP, cliIPv6},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := clients.storage.Add(c)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
c = &client.Persistent{
|
|
||||||
Name: "client2",
|
|
||||||
UID: client.MustNewUID(),
|
|
||||||
IPs: []netip.Addr{cli2IP},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = clients.storage.Add(c)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
c, ok := clients.find(cli1)
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
assert.Equal(t, "client1", c.Name)
|
|
||||||
|
|
||||||
c, ok = clients.find("1:2:3::4")
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
assert.Equal(t, "client1", c.Name)
|
|
||||||
|
|
||||||
c, ok = clients.find(cli2)
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
assert.Equal(t, "client2", c.Name)
|
|
||||||
|
|
||||||
_, ok = clients.find(cliNone)
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
assert.Equal(t, clients.clientSource(cli1IP), client.SourcePersistent)
|
|
||||||
assert.Equal(t, clients.clientSource(cli2IP), client.SourcePersistent)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("add_fail_name", func(t *testing.T) {
|
|
||||||
err := clients.storage.Add(&client.Persistent{
|
|
||||||
Name: "client1",
|
|
||||||
UID: client.MustNewUID(),
|
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
|
|
||||||
})
|
|
||||||
require.Error(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("add_fail_ip", func(t *testing.T) {
|
|
||||||
err := clients.storage.Add(&client.Persistent{
|
|
||||||
Name: "client3",
|
|
||||||
UID: client.MustNewUID(),
|
|
||||||
})
|
|
||||||
require.Error(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("update_fail_ip", func(t *testing.T) {
|
|
||||||
err := clients.storage.Update("client1", &client.Persistent{
|
|
||||||
Name: "client1",
|
|
||||||
UID: client.MustNewUID(),
|
|
||||||
})
|
|
||||||
assert.Error(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("update_success", func(t *testing.T) {
|
|
||||||
var (
|
|
||||||
cliOld = "1.1.1.1"
|
|
||||||
cliNew = "1.1.1.2"
|
|
||||||
|
|
||||||
cliNewIP = netip.MustParseAddr(cliNew)
|
|
||||||
)
|
|
||||||
|
|
||||||
prev, ok := clients.storage.FindByName("client1")
|
|
||||||
require.True(t, ok)
|
|
||||||
require.NotNil(t, prev)
|
|
||||||
|
|
||||||
err := clients.storage.Update("client1", &client.Persistent{
|
|
||||||
Name: "client1",
|
|
||||||
UID: prev.UID,
|
|
||||||
IPs: []netip.Addr{cliNewIP},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
_, ok = clients.find(cliOld)
|
|
||||||
assert.False(t, ok)
|
|
||||||
|
|
||||||
assert.Equal(t, clients.clientSource(cliNewIP), client.SourcePersistent)
|
|
||||||
|
|
||||||
prev, ok = clients.storage.FindByName("client1")
|
|
||||||
require.True(t, ok)
|
|
||||||
require.NotNil(t, prev)
|
|
||||||
|
|
||||||
err = clients.storage.Update("client1", &client.Persistent{
|
|
||||||
Name: "client1-renamed",
|
|
||||||
UID: prev.UID,
|
|
||||||
IPs: []netip.Addr{cliNewIP},
|
|
||||||
UseOwnSettings: true,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
c, ok := clients.find(cliNew)
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
assert.Equal(t, "client1-renamed", c.Name)
|
|
||||||
assert.True(t, c.UseOwnSettings)
|
|
||||||
|
|
||||||
nilCli, ok := clients.storage.FindByName("client1")
|
|
||||||
require.False(t, ok)
|
|
||||||
|
|
||||||
assert.Nil(t, nilCli)
|
|
||||||
|
|
||||||
require.Len(t, c.IDs(), 1)
|
|
||||||
|
|
||||||
assert.Equal(t, cliNewIP, c.IPs[0])
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("del_success", func(t *testing.T) {
|
|
||||||
ok := clients.storage.RemoveByName("client1-renamed")
|
|
||||||
require.True(t, ok)
|
|
||||||
|
|
||||||
_, ok = clients.find("1.1.1.2")
|
|
||||||
assert.False(t, ok)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("del_fail", func(t *testing.T) {
|
|
||||||
ok := clients.storage.RemoveByName("client3")
|
|
||||||
assert.False(t, ok)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("addhost_success", func(t *testing.T) {
|
|
||||||
ip := netip.MustParseAddr("1.1.1.1")
|
|
||||||
ok := clients.addHost(ip, "host", client.SourceARP)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
ok = clients.addHost(ip, "host2", client.SourceARP)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
ok = clients.addHost(ip, "host3", client.SourceHostsFile)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
assert.Equal(t, clients.clientSource(ip), client.SourceHostsFile)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
|
||||||
ip := netip.MustParseAddr("1.2.3.4")
|
|
||||||
ok := clients.addHost(ip, "from_arp", client.SourceARP)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, clients.clientSource(ip), client.SourceARP)
|
|
||||||
|
|
||||||
ok = clients.addHost(ip, "from_dhcp", client.SourceDHCP)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, clients.clientSource(ip), client.SourceDHCP)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("addhost_priority", func(t *testing.T) {
|
|
||||||
ip := netip.MustParseAddr("1.1.1.1")
|
|
||||||
ok := clients.addHost(ip, "host1", client.SourceRDNS)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
assert.Equal(t, client.SourceHostsFile, clients.clientSource(ip))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientsWHOIS(t *testing.T) {
|
|
||||||
clients := newClientsContainer(t)
|
|
||||||
whois := &whois.Info{
|
|
||||||
Country: "AU",
|
|
||||||
Orgname: "Example Org",
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("new_client", func(t *testing.T) {
|
|
||||||
ip := netip.MustParseAddr("1.1.1.255")
|
|
||||||
clients.setWHOISInfo(ip, whois)
|
|
||||||
rc := clients.storage.ClientRuntime(ip)
|
|
||||||
require.NotNil(t, rc)
|
|
||||||
|
|
||||||
assert.Equal(t, whois, rc.WHOIS())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("existing_auto-client", func(t *testing.T) {
|
|
||||||
ip := netip.MustParseAddr("1.1.1.1")
|
|
||||||
ok := clients.addHost(ip, "host", client.SourceRDNS)
|
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
clients.setWHOISInfo(ip, whois)
|
|
||||||
rc := clients.storage.ClientRuntime(ip)
|
|
||||||
require.NotNil(t, rc)
|
|
||||||
|
|
||||||
assert.Equal(t, whois, rc.WHOIS())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
|
||||||
ip := netip.MustParseAddr("1.1.1.2")
|
|
||||||
|
|
||||||
err := clients.storage.Add(&client.Persistent{
|
|
||||||
Name: "client1",
|
|
||||||
UID: client.MustNewUID(),
|
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
clients.setWHOISInfo(ip, whois)
|
|
||||||
rc := clients.storage.ClientRuntime(ip)
|
|
||||||
require.Nil(t, rc)
|
|
||||||
|
|
||||||
assert.True(t, clients.storage.RemoveByName("client1"))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientsAddExisting(t *testing.T) {
|
|
||||||
clients := newClientsContainer(t)
|
|
||||||
|
|
||||||
t.Run("simple", func(t *testing.T) {
|
|
||||||
ip := netip.MustParseAddr("1.1.1.1")
|
|
||||||
|
|
||||||
// Add a client.
|
|
||||||
err := clients.storage.Add(&client.Persistent{
|
|
||||||
Name: "client1",
|
|
||||||
UID: client.MustNewUID(),
|
|
||||||
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
|
||||||
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
|
|
||||||
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Now add an auto-client with the same IP.
|
|
||||||
ok := clients.addHost(ip, "test", client.SourceRDNS)
|
|
||||||
assert.True(t, ok)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("complicated", func(t *testing.T) {
|
|
||||||
// TODO(a.garipov): Properly decouple the DHCP server from the client
|
|
||||||
// storage.
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
t.Skip("skipping dhcp test on windows")
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := netip.MustParseAddr("1.2.3.4")
|
|
||||||
|
|
||||||
// First, init a DHCP server with a single static lease.
|
|
||||||
config := &dhcpd.ServerConfig{
|
|
||||||
Enabled: true,
|
|
||||||
DataDir: t.TempDir(),
|
|
||||||
Conf4: dhcpd.V4ServerConf{
|
|
||||||
Enabled: true,
|
|
||||||
GatewayIP: netip.MustParseAddr("1.2.3.1"),
|
|
||||||
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
|
||||||
RangeStart: netip.MustParseAddr("1.2.3.2"),
|
|
||||||
RangeEnd: netip.MustParseAddr("1.2.3.10"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
dhcpServer, err := dhcpd.Create(config)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
clients.dhcp = dhcpServer
|
|
||||||
|
|
||||||
err = dhcpServer.AddStaticLease(&dhcpsvc.Lease{
|
|
||||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
|
||||||
IP: ip,
|
|
||||||
Hostname: "testhost",
|
|
||||||
Expiry: time.Now().Add(time.Hour),
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Add a new client with the same IP as for a client with MAC.
|
|
||||||
err = clients.storage.Add(&client.Persistent{
|
|
||||||
Name: "client2",
|
|
||||||
UID: client.MustNewUID(),
|
|
||||||
IPs: []netip.Addr{ip},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Add a new client with the IP from the first client's IP range.
|
|
||||||
err = clients.storage.Add(&client.Persistent{
|
|
||||||
Name: "client3",
|
|
||||||
UID: client.MustNewUID(),
|
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClientsCustomUpstream(t *testing.T) {
|
func TestClientsCustomUpstream(t *testing.T) {
|
||||||
clients := newClientsContainer(t)
|
clients := newClientsContainer(t)
|
||||||
|
|
||||||
|
|
|
@ -103,6 +103,8 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
|
clients.storage.UpdateDHCP()
|
||||||
|
|
||||||
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
|
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
|
||||||
src, host := rc.Info()
|
src, host := rc.Info()
|
||||||
cj := runtimeClientJSON{
|
cj := runtimeClientJSON{
|
||||||
|
@ -117,20 +119,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
if config.Clients.Sources.DHCP {
|
data.Tags = clients.storage.AllowedTags()
|
||||||
for _, l := range clients.dhcp.Leases() {
|
|
||||||
cj := runtimeClientJSON{
|
|
||||||
Name: l.Hostname,
|
|
||||||
Source: client.SourceDHCP,
|
|
||||||
IP: l.IP,
|
|
||||||
WHOIS: &whois.Info{},
|
|
||||||
}
|
|
||||||
|
|
||||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
data.Tags = clientTags
|
|
||||||
|
|
||||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||||
}
|
}
|
||||||
|
@ -432,7 +421,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, _ := netip.ParseAddr(idStr)
|
ip, _ := netip.ParseAddr(idStr)
|
||||||
c, ok := clients.find(idStr)
|
c, ok := clients.storage.Find(idStr)
|
||||||
var cj *clientJSON
|
var cj *clientJSON
|
||||||
if !ok {
|
if !ok {
|
||||||
cj = clients.findRuntime(ip, idStr)
|
cj = clients.findRuntime(ip, idStr)
|
||||||
|
@ -454,7 +443,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||||
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
||||||
// non-nil.
|
// non-nil.
|
||||||
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
|
func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) {
|
||||||
rc := clients.findRuntimeClient(ip)
|
rc := clients.storage.ClientRuntime(ip)
|
||||||
if rc == nil {
|
if rc == nil {
|
||||||
// It is still possible that the IP used to be in the runtime clients
|
// It is still possible that the IP used to be in the runtime clients
|
||||||
// list, but then the server was reloaded. So, check the DNS server's
|
// list, but then the server was reloaded. So, check the DNS server's
|
||||||
|
|
|
@ -1,27 +0,0 @@
|
||||||
package home
|
|
||||||
|
|
||||||
var clientTags = []string{
|
|
||||||
"device_audio",
|
|
||||||
"device_camera",
|
|
||||||
"device_gameconsole",
|
|
||||||
"device_laptop",
|
|
||||||
"device_nas", // Network-attached Storage
|
|
||||||
"device_other",
|
|
||||||
"device_pc",
|
|
||||||
"device_phone",
|
|
||||||
"device_printer",
|
|
||||||
"device_securityalarm",
|
|
||||||
"device_tablet",
|
|
||||||
"device_tv",
|
|
||||||
|
|
||||||
"os_android",
|
|
||||||
"os_ios",
|
|
||||||
"os_linux",
|
|
||||||
"os_macos",
|
|
||||||
"os_other",
|
|
||||||
"os_windows",
|
|
||||||
|
|
||||||
"user_admin",
|
|
||||||
"user_child",
|
|
||||||
"user_regular",
|
|
||||||
}
|
|
|
@ -1,6 +1,7 @@
|
||||||
package home
|
package home
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
|
@ -414,9 +415,9 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
|
||||||
|
|
||||||
setts.ClientIP = clientIP
|
setts.ClientIP = clientIP
|
||||||
|
|
||||||
c, ok := Context.clients.find(clientID)
|
c, ok := Context.clients.storage.Find(clientID)
|
||||||
if !ok {
|
if !ok {
|
||||||
c, ok = Context.clients.find(clientIP.String())
|
c, ok = Context.clients.storage.Find(clientIP.String())
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
|
log.Debug("%s: no clients with ip %s and clientid %q", pref, clientIP, clientID)
|
||||||
|
|
||||||
|
@ -459,11 +460,15 @@ func startDNSServer() error {
|
||||||
|
|
||||||
Context.filters.EnableFilters(false)
|
Context.filters.EnableFilters(false)
|
||||||
|
|
||||||
Context.clients.Start()
|
// TODO(s.chzhen): Pass context.
|
||||||
|
err := Context.clients.Start(context.TODO())
|
||||||
err := Context.dnsServer.Start()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("couldn't start forwarding DNS server: %w", err)
|
return fmt.Errorf("starting clients container: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = Context.dnsServer.Start()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("starting dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
Context.filters.Start()
|
Context.filters.Start()
|
||||||
|
@ -500,7 +505,7 @@ func stopDNSServer() (err error) {
|
||||||
return fmt.Errorf("stopping forwarding dns server: %w", err)
|
return fmt.Errorf("stopping forwarding dns server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = Context.clients.close()
|
err = Context.clients.close(context.TODO())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("closing clients container: %w", err)
|
return fmt.Errorf("closing clients container: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,9 +18,8 @@ var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||||
func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) {
|
func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) {
|
||||||
tb.Helper()
|
tb.Helper()
|
||||||
|
|
||||||
s = client.NewStorage(&client.Config{
|
s, err := client.NewStorage(&client.StorageConfig{})
|
||||||
AllowedTags: nil,
|
require.NoError(tb, err)
|
||||||
})
|
|
||||||
|
|
||||||
for _, p := range clients {
|
for _, p := range clients {
|
||||||
p.UID = client.MustNewUID()
|
p.UID = client.MustNewUID()
|
||||||
|
|
|
@ -119,7 +119,7 @@ func Main(clientBuildFS fs.FS) {
|
||||||
log.Info("Received signal %q", sig)
|
log.Info("Received signal %q", sig)
|
||||||
switch sig {
|
switch sig {
|
||||||
case syscall.SIGHUP:
|
case syscall.SIGHUP:
|
||||||
Context.clients.reloadARP()
|
Context.clients.storage.ReloadARP()
|
||||||
Context.tls.reload()
|
Context.tls.reload()
|
||||||
default:
|
default:
|
||||||
cleanup(context.Background())
|
cleanup(context.Background())
|
||||||
|
|
Loading…
Reference in New Issue