diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index b4804322..87ec9cb8 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -7,7 +7,6 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/AdGuardHome/internal/whois" @@ -94,9 +93,6 @@ type AddressProcessor struct { OnClose func() (err error) } -// type check -var _ client.AddressProcessor = (*AddressProcessor)(nil) - // Process implements the [client.AddressProcessor] interface for // *AddressProcessor. func (p *AddressProcessor) Process(ip netip.Addr) { @@ -114,9 +110,6 @@ type AddressUpdater struct { OnUpdateAddress func(ip netip.Addr, host string, info *whois.Info) } -// type check -var _ client.AddressUpdater = (*AddressUpdater)(nil) - // UpdateAddress implements the [client.AddressUpdater] interface for // *AddressUpdater. func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { diff --git a/internal/aghtest/interface_test.go b/internal/aghtest/interface_test.go index c1a376ba..f0f55451 100644 --- a/internal/aghtest/interface_test.go +++ b/internal/aghtest/interface_test.go @@ -2,6 +2,7 @@ package aghtest_test import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" + "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" ) @@ -13,3 +14,13 @@ var _ filtering.Resolver = (*aghtest.Resolver)(nil) // type check var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil) + +// type check +// +// TODO(s.chzhen): It's here to avoid the import cycle. Remove it. +var _ client.AddressProcessor = (*aghtest.AddressProcessor)(nil) + +// type check +// +// TODO(s.chzhen): It's here to avoid the import cycle. Remove it. +var _ client.AddressUpdater = (*aghtest.AddressUpdater)(nil) diff --git a/internal/home/clientindex.go b/internal/client/index.go similarity index 77% rename from internal/home/clientindex.go rename to internal/client/index.go index 87d7406a..c6a17cb3 100644 --- a/internal/home/clientindex.go +++ b/internal/client/index.go @@ -1,4 +1,4 @@ -package home +package client import ( "fmt" @@ -26,8 +26,8 @@ func macToKey(mac net.HardwareAddr) (key macKey) { } } -// clientIndex stores all information about persistent clients. -type clientIndex struct { +// Index stores all information about persistent clients. +type Index struct { // clientIDToUID maps client ID to UID. clientIDToUID map[string]UID @@ -38,26 +38,26 @@ type clientIndex struct { macToUID map[macKey]UID // uidToClient maps UID to the persistent client. - uidToClient map[UID]*persistentClient + uidToClient map[UID]*Persistent // subnetToUID maps subnet to UID. subnetToUID aghalg.SortedMap[netip.Prefix, UID] } -// NewClientIndex initializes the new instance of client index. -func NewClientIndex() (ci *clientIndex) { - return &clientIndex{ +// NewIndex initializes the new instance of client index. +func NewIndex() (ci *Index) { + return &Index{ clientIDToUID: map[string]UID{}, ipToUID: map[netip.Addr]UID{}, subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare), macToUID: map[macKey]UID{}, - uidToClient: map[UID]*persistentClient{}, + uidToClient: map[UID]*Persistent{}, } } -// add stores information about a persistent client in the index. c must be +// Add stores information about a persistent client in the index. c must be // non-nil and contain UID. -func (ci *clientIndex) add(c *persistentClient) { +func (ci *Index) Add(c *Persistent) { if (c.UID == UID{}) { panic("client must contain uid") } @@ -82,9 +82,9 @@ func (ci *clientIndex) add(c *persistentClient) { ci.uidToClient[c.UID] = c } -// clashes returns an error if the index contains a different persistent client +// Clashes returns an error if the index contains a different persistent client // with at least a single identifier contained by c. c must be non-nil. -func (ci *clientIndex) clashes(c *persistentClient) (err error) { +func (ci *Index) Clashes(c *Persistent) (err error) { for _, id := range c.ClientIDs { existing, ok := ci.clientIDToUID[id] if ok && existing != c.UID { @@ -114,7 +114,7 @@ func (ci *clientIndex) clashes(c *persistentClient) (err error) { // clashesIP returns a previous client with the same IP address as c. c must be // non-nil. -func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip netip.Addr) { +func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) { for _, ip := range c.IPs { existing, ok := ci.ipToUID[ip] if ok && existing != c.UID { @@ -127,7 +127,7 @@ func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip n // clashesSubnet returns a previous client with the same subnet as c. c must be // non-nil. -func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient, s netip.Prefix) { +func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) { for _, s = range c.Subnets { var existing UID var ok bool @@ -153,7 +153,7 @@ func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient, // clashesMAC returns a previous client with the same MAC address as c. c must // be non-nil. -func (ci *clientIndex) clashesMAC(c *persistentClient) (p *persistentClient, mac net.HardwareAddr) { +func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) { for _, mac = range c.MACs { k := macToKey(mac) existing, ok := ci.macToUID[k] @@ -165,9 +165,9 @@ func (ci *clientIndex) clashesMAC(c *persistentClient) (p *persistentClient, mac return nil, nil } -// find finds persistent client by string representation of the client ID, IP +// Find finds persistent client by string representation of the client ID, IP // address, or MAC. -func (ci *clientIndex) find(id string) (c *persistentClient, ok bool) { +func (ci *Index) Find(id string) (c *Persistent, ok bool) { uid, found := ci.clientIDToUID[id] if found { return ci.uidToClient[uid], true @@ -191,7 +191,7 @@ func (ci *clientIndex) find(id string) (c *persistentClient, ok bool) { } // find finds persistent client by IP address. -func (ci *clientIndex) findByIP(ip netip.Addr) (c *persistentClient, found bool) { +func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) { uid, found := ci.ipToUID[ip] if found { return ci.uidToClient[uid], true @@ -215,7 +215,7 @@ func (ci *clientIndex) findByIP(ip netip.Addr) (c *persistentClient, found bool) } // find finds persistent client by MAC. -func (ci *clientIndex) findByMAC(mac net.HardwareAddr) (c *persistentClient, found bool) { +func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) { k := macToKey(mac) uid, found := ci.macToUID[k] if found { @@ -225,9 +225,9 @@ func (ci *clientIndex) findByMAC(mac net.HardwareAddr) (c *persistentClient, fou return nil, false } -// del removes information about persistent client from the index. c must be +// Delete removes information about persistent client from the index. c must be // non-nil. -func (ci *clientIndex) del(c *persistentClient) { +func (ci *Index) Delete(c *Persistent) { for _, id := range c.ClientIDs { delete(ci.clientIDToUID, id) } diff --git a/internal/home/clientindex_internal_test.go b/internal/client/index_internal_test.go similarity index 87% rename from internal/home/clientindex_internal_test.go rename to internal/client/index_internal_test.go index b89703db..abf38710 100644 --- a/internal/home/clientindex_internal_test.go +++ b/internal/client/index_internal_test.go @@ -1,4 +1,4 @@ -package home +package client import ( "net" @@ -9,6 +9,19 @@ import ( "github.com/stretchr/testify/require" ) +// newIDIndex is a helper function that returns a client index filled with +// persistent clients from the m. It also generates a UID for each client. +func newIDIndex(m []*Persistent) (ci *Index) { + ci = NewIndex() + + for _, c := range m { + c.UID = MustNewUID() + ci.Add(c) + } + + return ci +} + func TestClientIndex(t *testing.T) { const ( cliIPNone = "1.2.3.4" @@ -24,7 +37,7 @@ func TestClientIndex(t *testing.T) { cliMAC = "11:11:11:11:11:11" ) - clients := []*persistentClient{{ + clients := []*Persistent{{ Name: "client1", IPs: []netip.Addr{ netip.MustParseAddr(cliIP1), @@ -45,9 +58,9 @@ func TestClientIndex(t *testing.T) { ci := newIDIndex(clients) testCases := []struct { + want *Persistent name string ids []string - want *persistentClient }{{ name: "ipv4_ipv6", ids: []string{cliIP1, cliIPv6}, @@ -69,7 +82,7 @@ func TestClientIndex(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { for _, id := range tc.ids { - c, ok := ci.find(id) + c, ok := ci.Find(id) require.True(t, ok) assert.Equal(t, tc.want, c) @@ -78,7 +91,7 @@ func TestClientIndex(t *testing.T) { } t.Run("not_found", func(t *testing.T) { - _, ok := ci.find(cliIPNone) + _, ok := ci.Find(cliIPNone) assert.False(t, ok) }) } @@ -92,7 +105,7 @@ func TestClientIndex_Clashes(t *testing.T) { cliMAC = "11:11:11:11:11:11" ) - clients := []*persistentClient{{ + clients := []*Persistent{{ Name: "client_with_ip", IPs: []netip.Addr{netip.MustParseAddr(cliIP1)}, }, { @@ -109,8 +122,8 @@ func TestClientIndex_Clashes(t *testing.T) { ci := newIDIndex(clients) testCases := []struct { + client *Persistent name string - client *persistentClient }{{ name: "ipv4", client: clients[0], @@ -127,14 +140,14 @@ func TestClientIndex_Clashes(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - clone := tc.client.shallowClone() + clone := tc.client.ShallowClone() clone.UID = MustNewUID() - err := ci.clashes(clone) + err := ci.Clashes(clone) require.Error(t, err) - ci.del(tc.client) - err = ci.clashes(clone) + ci.Delete(tc.client) + err = ci.Clashes(clone) require.NoError(t, err) }) } @@ -153,9 +166,9 @@ func mustParseMAC(s string) (mac net.HardwareAddr) { func TestMACToKey(t *testing.T) { testCases := []struct { + want any name string in string - want any }{{ name: "column6", in: "00:00:5e:00:53:01", diff --git a/internal/home/client.go b/internal/client/persistent.go similarity index 73% rename from internal/home/client.go rename to internal/client/persistent.go index f1b4b68f..d70966d7 100644 --- a/internal/home/client.go +++ b/internal/client/persistent.go @@ -1,4 +1,4 @@ -package home +package client import ( "encoding" @@ -9,12 +9,12 @@ import ( "strings" "time" - "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/google/uuid" ) @@ -56,16 +56,16 @@ func (uid *UID) UnmarshalText(data []byte) error { return (*uuid.UUID)(uid).UnmarshalText(data) } -// persistentClient contains information about persistent clients. -type persistentClient struct { - // upstreamConfig is the custom upstream configuration for this client. If +// Persistent contains information about persistent clients. +type Persistent struct { + // UpstreamConfig is the custom upstream configuration for this client. If // it's nil, it has not been initialized yet. If it's non-nil and empty, // there are no valid upstreams. If it's non-nil and non-empty, these // upstream must be used. - upstreamConfig *proxy.CustomUpstreamConfig + UpstreamConfig *proxy.CustomUpstreamConfig - // TODO(d.kolyshev): Make safeSearchConf a pointer. - safeSearchConf filtering.SafeSearchConfig + // TODO(d.kolyshev): Make SafeSearchConf a pointer. + SafeSearchConf filtering.SafeSearchConfig SafeSearch filtering.SafeSearch // BlockedServices is the configuration of blocked services of a client. @@ -97,8 +97,8 @@ type persistentClient struct { IgnoreStatistics bool } -// setTags sets the tags if they are known, otherwise logs an unknown tag. -func (c *persistentClient) setTags(tags []string, known *stringutil.Set) { +// SetTags sets the tags if they are known, otherwise logs an unknown tag. +func (c *Persistent) SetTags(tags []string, known *stringutil.Set) { for _, t := range tags { if !known.Has(t) { log.Info("skipping unknown tag %q", t) @@ -112,9 +112,9 @@ func (c *persistentClient) setTags(tags []string, known *stringutil.Set) { slices.Sort(c.Tags) } -// setIDs parses a list of strings into typed fields and returns an error if +// SetIDs parses a list of strings into typed fields and returns an error if // there is one. -func (c *persistentClient) setIDs(ids []string) (err error) { +func (c *Persistent) SetIDs(ids []string) (err error) { for _, id := range ids { err = c.setID(id) if err != nil { @@ -154,7 +154,7 @@ func subnetCompare(x, y netip.Prefix) (cmp int) { } // setID parses id into typed field if there is no error. -func (c *persistentClient) setID(id string) (err error) { +func (c *Persistent) setID(id string) (err error) { if id == "" { return errors.Error("clientid is empty") } @@ -180,7 +180,7 @@ func (c *persistentClient) setID(id string) (err error) { return nil } - err = dnsforward.ValidateClientID(id) + err = ValidateClientID(id) if err != nil { // Don't wrap the error, because it's informative enough as is. return err @@ -191,9 +191,23 @@ func (c *persistentClient) setID(id string) (err error) { return nil } -// ids returns a list of client ids containing at least one element. -func (c *persistentClient) ids() (ids []string) { - ids = make([]string, 0, c.idsLen()) +// ValidateClientID returns an error if id is not a valid ClientID. +// +// TODO(s.chzhen): It's an exact copy of the [dnsforward.ValidateClientID] to +// avoid the import cycle. Remove it. +func ValidateClientID(id string) (err error) { + err = netutil.ValidateHostnameLabel(id) + if err != nil { + // Replace the domain name label wrapper with our own. + return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err)) + } + + return nil +} + +// IDs returns a list of client IDs containing at least one element. +func (c *Persistent) IDs() (ids []string) { + ids = make([]string, 0, c.IDsLen()) for _, ip := range c.IPs { ids = append(ids, ip.String()) @@ -210,24 +224,24 @@ func (c *persistentClient) ids() (ids []string) { return append(ids, c.ClientIDs...) } -// idsLen returns a length of client ids. -func (c *persistentClient) idsLen() (n int) { +// IDsLen returns a length of client ids. +func (c *Persistent) IDsLen() (n int) { return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs) } -// equalIDs returns true if the ids of the current and previous clients are the +// EqualIDs returns true if the ids of the current and previous clients are the // same. -func (c *persistentClient) equalIDs(prev *persistentClient) (equal bool) { +func (c *Persistent) EqualIDs(prev *Persistent) (equal bool) { return slices.Equal(c.IPs, prev.IPs) && slices.Equal(c.Subnets, prev.Subnets) && slices.EqualFunc(c.MACs, prev.MACs, slices.Equal[net.HardwareAddr]) && slices.Equal(c.ClientIDs, prev.ClientIDs) } -// shallowClone returns a deep copy of the client, except upstreamConfig, +// ShallowClone returns a deep copy of the client, except upstreamConfig, // safeSearchConf, SafeSearch fields, because it's difficult to copy them. -func (c *persistentClient) shallowClone() (clone *persistentClient) { - clone = &persistentClient{} +func (c *Persistent) ShallowClone() (clone *Persistent) { + clone = &Persistent{} *clone = *c clone.BlockedServices = c.BlockedServices.Clone() @@ -242,10 +256,10 @@ func (c *persistentClient) shallowClone() (clone *persistentClient) { return clone } -// closeUpstreams closes the client-specific upstream config of c if any. -func (c *persistentClient) closeUpstreams() (err error) { - if c.upstreamConfig != nil { - if err = c.upstreamConfig.Close(); err != nil { +// CloseUpstreams closes the client-specific upstream config of c if any. +func (c *Persistent) CloseUpstreams() (err error) { + if c.UpstreamConfig != nil { + if err = c.UpstreamConfig.Close(); err != nil { return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err) } } @@ -253,8 +267,8 @@ func (c *persistentClient) closeUpstreams() (err error) { return nil } -// setSafeSearch initializes and sets the safe search filter for this client. -func (c *persistentClient) setSafeSearch( +// SetSafeSearch initializes and sets the safe search filter for this client. +func (c *Persistent) SetSafeSearch( conf filtering.SafeSearchConfig, cacheSize uint, cacheTTL time.Duration, diff --git a/internal/home/client_internal_test.go b/internal/client/persistent_internal_test.go similarity index 94% rename from internal/home/client_internal_test.go rename to internal/client/persistent_internal_test.go index c360cb19..76da1e4b 100644 --- a/internal/home/client_internal_test.go +++ b/internal/client/persistent_internal_test.go @@ -1,4 +1,4 @@ -package home +package client import ( "testing" @@ -27,10 +27,10 @@ func TestPersistentClient_EqualIDs(t *testing.T) { ) testCases := []struct { + want assert.BoolAssertionFunc name string ids []string prevIDs []string - want assert.BoolAssertionFunc }{{ name: "single_ip", ids: []string{ip1}, @@ -110,15 +110,15 @@ func TestPersistentClient_EqualIDs(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c := &persistentClient{} - err := c.setIDs(tc.ids) + c := &Persistent{} + err := c.SetIDs(tc.ids) require.NoError(t, err) - prev := &persistentClient{} - err = prev.setIDs(tc.prevIDs) + prev := &Persistent{} + err = prev.SetIDs(tc.prevIDs) require.NoError(t, err) - tc.want(t, c.equalIDs(prev)) + tc.want(t, c.EqualIDs(prev)) }) } } diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 0e39744c..71ad2eb0 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -14,6 +14,8 @@ import ( ) // ValidateClientID returns an error if id is not a valid ClientID. +// +// Keep in sync with [client.ValidateClientID]. func ValidateClientID(id string) (err error) { err = netutil.ValidateHostnameLabel(id) if err != nil { diff --git a/internal/home/clients.go b/internal/home/clients.go index f429e97e..fb627a2e 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -47,9 +47,9 @@ type DHCP interface { type clientsContainer struct { // TODO(a.garipov): Perhaps use a number of separate indices for different // types (string, netip.Addr, and so on). - list map[string]*persistentClient // name -> client + list map[string]*client.Persistent // name -> client - clientIndex *clientIndex + clientIndex *client.Index // ipToRC maps IP addresses to runtime client information. ipToRC map[netip.Addr]*client.Runtime @@ -103,10 +103,10 @@ func (clients *clientsContainer) Init( log.Fatal("clients.list != nil") } - clients.list = map[string]*persistentClient{} + clients.list = map[string]*client.Persistent{} clients.ipToRC = map[netip.Addr]*client.Runtime{} - clients.clientIndex = NewClientIndex() + clients.clientIndex = client.NewIndex() clients.allTags = stringutil.NewSet(clientTags...) @@ -190,7 +190,7 @@ type clientObject struct { Upstreams []string `yaml:"upstreams"` // UID is the unique identifier of the persistent client. - UID UID `yaml:"uid"` + UID client.UID `yaml:"uid"` // UpstreamsCacheSize is the DNS cache size (in bytes). // @@ -214,8 +214,8 @@ type clientObject struct { func (o *clientObject) toPersistent( filteringConf *filtering.Config, allTags *stringutil.Set, -) (cli *persistentClient, err error) { - cli = &persistentClient{ +) (cli *client.Persistent, err error) { + cli = &client.Persistent{ Name: o.Name, Upstreams: o.Upstreams, @@ -225,7 +225,7 @@ func (o *clientObject) toPersistent( UseOwnSettings: !o.UseGlobalSettings, FilteringEnabled: o.FilteringEnabled, ParentalEnabled: o.ParentalEnabled, - safeSearchConf: o.SafeSearchConf, + SafeSearchConf: o.SafeSearchConf, SafeBrowsingEnabled: o.SafeBrowsingEnabled, UseOwnBlockedServices: !o.UseGlobalBlockedServices, IgnoreQueryLog: o.IgnoreQueryLog, @@ -234,13 +234,13 @@ func (o *clientObject) toPersistent( UpstreamsCacheSize: o.UpstreamsCacheSize, } - err = cli.setIDs(o.IDs) + err = cli.SetIDs(o.IDs) if err != nil { return nil, fmt.Errorf("parsing ids: %w", err) } - if (cli.UID == UID{}) { - cli.UID, err = NewUID() + if (cli.UID == client.UID{}) { + cli.UID, err = client.NewUID() if err != nil { return nil, fmt.Errorf("generating uid: %w", err) } @@ -249,7 +249,7 @@ func (o *clientObject) toPersistent( if o.SafeSearchConf.Enabled { o.SafeSearchConf.CustomResolver = safeSearchResolver{} - err = cli.setSafeSearch( + err = cli.SetSafeSearch( o.SafeSearchConf, filteringConf.SafeSearchCacheSize, time.Minute*time.Duration(filteringConf.CacheTime), @@ -266,7 +266,7 @@ func (o *clientObject) toPersistent( cli.BlockedServices = o.BlockedServices.Clone() - cli.setTags(o.Tags, allTags) + cli.SetTags(o.Tags, allTags) return cli, nil } @@ -278,7 +278,7 @@ func (clients *clientsContainer) addFromConfig( filteringConf *filtering.Config, ) (err error) { for i, o := range objects { - var cli *persistentClient + var cli *client.Persistent cli, err = o.toPersistent(filteringConf, clients.allTags) if err != nil { return fmt.Errorf("clients: init persistent client at index %d: %w", i, err) @@ -306,7 +306,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { BlockedServices: cli.BlockedServices.Clone(), - IDs: cli.ids(), + IDs: cli.IDs(), Tags: stringutil.CloneSlice(cli.Tags), Upstreams: stringutil.CloneSlice(cli.Upstreams), @@ -315,7 +315,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) { UseGlobalSettings: !cli.UseOwnSettings, FilteringEnabled: cli.FilteringEnabled, ParentalEnabled: cli.ParentalEnabled, - SafeSearchConf: cli.safeSearchConf, + SafeSearchConf: cli.SafeSearchConf, SafeBrowsingEnabled: cli.SafeBrowsingEnabled, UseGlobalBlockedServices: !cli.UseOwnBlockedServices, IgnoreQueryLog: cli.IgnoreQueryLog, @@ -436,7 +436,7 @@ func (clients *clientsContainer) clientOrArtificial( } // find returns a shallow copy of the client if there is one found. -func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool) { +func (clients *clientsContainer) find(id string) (c *client.Persistent, ok bool) { clients.lock.Lock() defer clients.lock.Unlock() @@ -445,7 +445,7 @@ func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool) return nil, false } - return c.shallowClone(), true + return c.ShallowClone(), true } // shouldCountClient is a wrapper around [clientsContainer.find] to make it a @@ -481,8 +481,8 @@ func (clients *clientsContainer) UpstreamConfigByID( c, ok := clients.findLocked(id) if !ok { return nil, nil - } else if c.upstreamConfig != nil { - return c.upstreamConfig, nil + } else if c.UpstreamConfig != nil { + return c.UpstreamConfig, nil } upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty) @@ -511,15 +511,15 @@ func (clients *clientsContainer) UpstreamConfigByID( int(c.UpstreamsCacheSize), config.DNS.EDNSClientSubnet.Enabled, ) - c.upstreamConfig = conf + c.UpstreamConfig = conf return conf, nil } // findLocked searches for a client by its ID. clients.lock is expected to be // locked. -func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) { - c, ok = clients.clientIndex.find(id) +func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) { + c, ok = clients.clientIndex.Find(id) if ok { return c, true } @@ -535,7 +535,7 @@ func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok // findDHCP searches for a client by its MAC, if the DHCP server is active and // there is such client. clients.lock is expected to be locked. -func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *persistentClient, ok bool) { +func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, ok bool) { foundMAC := clients.dhcp.MACByIP(ip) if foundMAC == nil { return nil, false @@ -585,13 +585,13 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Ru } // check validates the client. It also sorts the client tags. -func (clients *clientsContainer) check(c *persistentClient) (err error) { +func (clients *clientsContainer) check(c *client.Persistent) (err error) { switch { case c == nil: return errors.Error("client is nil") case c.Name == "": return errors.Error("invalid name") - case c.idsLen() == 0: + case c.IDsLen() == 0: return errors.Error("id required") default: // Go on. @@ -616,7 +616,7 @@ func (clients *clientsContainer) check(c *persistentClient) (err error) { // add adds a new client object. ok is false if such client already exists or // if an error occurred. -func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) { +func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) { err = clients.check(c) if err != nil { return false, err @@ -632,7 +632,7 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) { } // check ID index - err = clients.clientIndex.clashes(c) + err = clients.clientIndex.Clashes(c) if err != nil { // Don't wrap the error since it's informative enough as is. return false, err @@ -640,18 +640,18 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) { clients.addLocked(c) - log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.ids(), len(clients.list)) + log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), len(clients.list)) return true, nil } // addLocked c to the indexes. clients.lock is expected to be locked. -func (clients *clientsContainer) addLocked(c *persistentClient) { +func (clients *clientsContainer) addLocked(c *client.Persistent) { // update Name index clients.list[c.Name] = c // update ID index - clients.clientIndex.add(c) + clients.clientIndex.Add(c) } // remove removes a client. ok is false if there is no such client. @@ -659,7 +659,7 @@ func (clients *clientsContainer) remove(name string) (ok bool) { clients.lock.Lock() defer clients.lock.Unlock() - var c *persistentClient + var c *client.Persistent c, ok = clients.list[name] if !ok { return false @@ -672,8 +672,8 @@ func (clients *clientsContainer) remove(name string) (ok bool) { // removeLocked removes c from the indexes. clients.lock is expected to be // locked. -func (clients *clientsContainer) removeLocked(c *persistentClient) { - if err := c.closeUpstreams(); err != nil { +func (clients *clientsContainer) removeLocked(c *client.Persistent) { + if err := c.CloseUpstreams(); err != nil { log.Error("client container: removing client %s: %s", c.Name, err) } @@ -681,11 +681,11 @@ func (clients *clientsContainer) removeLocked(c *persistentClient) { delete(clients.list, c.Name) // Update the ID index. - clients.clientIndex.del(c) + clients.clientIndex.Delete(c) } // update updates a client by its name. -func (clients *clientsContainer) update(prev, c *persistentClient) (err error) { +func (clients *clientsContainer) update(prev, c *client.Persistent) (err error) { err = clients.check(c) if err != nil { // Don't wrap the error since it's informative enough as is. @@ -703,7 +703,7 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) { } } - if c.equalIDs(prev) { + if c.EqualIDs(prev) { clients.removeLocked(prev) clients.addLocked(c) @@ -711,7 +711,7 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) { } // Check the ID index. - err = clients.clientIndex.clashes(c) + err = clients.clientIndex.Clashes(c) if err != nil { // Don't wrap the error since it's informative enough as is. return err @@ -891,14 +891,14 @@ func (clients *clientsContainer) addFromSystemARP() { // the persistent clients. func (clients *clientsContainer) close() (err error) { persistent := maps.Values(clients.list) - slices.SortFunc(persistent, func(a, b *persistentClient) (res int) { + slices.SortFunc(persistent, func(a, b *client.Persistent) (res int) { return strings.Compare(a.Name, b.Name) }) var errs []error for _, cli := range persistent { - if err = cli.closeUpstreams(); err != nil { + if err = cli.CloseUpstreams(); err != nil { errs = append(errs, err) } } diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 6b384f6b..4f9cb946 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -66,9 +66,9 @@ func TestClients(t *testing.T) { cliIPv6 = netip.MustParseAddr("1:2:3::4") ) - c := &persistentClient{ + c := &client.Persistent{ Name: "client1", - UID: MustNewUID(), + UID: client.MustNewUID(), IPs: []netip.Addr{cli1IP, cliIPv6}, } @@ -77,9 +77,9 @@ func TestClients(t *testing.T) { assert.True(t, ok) - c = &persistentClient{ + c = &client.Persistent{ Name: "client2", - UID: MustNewUID(), + UID: client.MustNewUID(), IPs: []netip.Addr{cli2IP}, } @@ -111,9 +111,9 @@ func TestClients(t *testing.T) { }) t.Run("add_fail_name", func(t *testing.T) { - ok, err := clients.add(&persistentClient{ + ok, err := clients.add(&client.Persistent{ Name: "client1", - UID: MustNewUID(), + UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")}, }) require.NoError(t, err) @@ -121,18 +121,18 @@ func TestClients(t *testing.T) { }) t.Run("add_fail_ip", func(t *testing.T) { - ok, err := clients.add(&persistentClient{ + ok, err := clients.add(&client.Persistent{ Name: "client3", - UID: MustNewUID(), + UID: client.MustNewUID(), }) require.Error(t, err) assert.False(t, ok) }) t.Run("update_fail_ip", func(t *testing.T) { - err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{ + err := clients.update(&client.Persistent{Name: "client1"}, &client.Persistent{ Name: "client1", - UID: MustNewUID(), + UID: client.MustNewUID(), }) assert.Error(t, err) }) @@ -148,9 +148,9 @@ func TestClients(t *testing.T) { prev, ok := clients.list["client1"] require.True(t, ok) - err := clients.update(prev, &persistentClient{ + err := clients.update(prev, &client.Persistent{ Name: "client1", - UID: MustNewUID(), + UID: client.MustNewUID(), IPs: []netip.Addr{cliNewIP}, }) require.NoError(t, err) @@ -163,9 +163,9 @@ func TestClients(t *testing.T) { prev, ok = clients.list["client1"] require.True(t, ok) - err = clients.update(prev, &persistentClient{ + err = clients.update(prev, &client.Persistent{ Name: "client1-renamed", - UID: MustNewUID(), + UID: client.MustNewUID(), IPs: []netip.Addr{cliNewIP}, UseOwnSettings: true, }) @@ -182,7 +182,7 @@ func TestClients(t *testing.T) { assert.Nil(t, nilCli) - require.Len(t, c.ids(), 1) + require.Len(t, c.IDs(), 1) assert.Equal(t, cliNewIP, c.IPs[0]) }) @@ -265,9 +265,9 @@ func TestClientsWHOIS(t *testing.T) { t.Run("can't_set_manually-added", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.2") - ok, err := clients.add(&persistentClient{ + ok, err := clients.add(&client.Persistent{ Name: "client1", - UID: MustNewUID(), + UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")}, }) require.NoError(t, err) @@ -288,9 +288,9 @@ func TestClientsAddExisting(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") // Add a client. - ok, err := clients.add(&persistentClient{ + ok, err := clients.add(&client.Persistent{ Name: "client1", - UID: MustNewUID(), + UID: client.MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")}, MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}}, @@ -339,18 +339,18 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Add a new client with the same IP as for a client with MAC. - ok, err := clients.add(&persistentClient{ + ok, err := clients.add(&client.Persistent{ Name: "client2", - UID: MustNewUID(), + UID: client.MustNewUID(), IPs: []netip.Addr{ip}, }) require.NoError(t, err) assert.True(t, ok) // Add a new client with the IP from the first client's IP range. - ok, err = clients.add(&persistentClient{ + ok, err = clients.add(&client.Persistent{ Name: "client3", - UID: MustNewUID(), + UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, }) require.NoError(t, err) @@ -362,9 +362,9 @@ func TestClientsCustomUpstream(t *testing.T) { clients := newClientsContainer(t) // Add client with upstreams. - ok, err := clients.add(&persistentClient{ + ok, err := clients.add(&client.Persistent{ Name: "client1", - UID: MustNewUID(), + UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, Upstreams: []string{ "1.1.1.1", diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 3f2918ca..b2270416 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -131,9 +131,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http // initPrev initializes the persistent client with the default or previous // client properties. -func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err error) { +func initPrev(cj clientJSON, prev *client.Persistent) (c *client.Persistent, err error) { var ( - uid UID + uid client.UID ignoreQueryLog bool ignoreStatistics bool upsCacheEnabled bool @@ -166,14 +166,14 @@ func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err e return nil, fmt.Errorf("invalid blocked services: %w", err) } - if (uid == UID{}) { - uid, err = NewUID() + if (uid == client.UID{}) { + uid, err = client.NewUID() if err != nil { return nil, fmt.Errorf("generating uid: %w", err) } } - return &persistentClient{ + return &client.Persistent{ BlockedServices: svcs, UID: uid, IgnoreQueryLog: ignoreQueryLog, @@ -187,21 +187,21 @@ func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err e // errors. func (clients *clientsContainer) jsonToClient( cj clientJSON, - prev *persistentClient, -) (c *persistentClient, err error) { + prev *client.Persistent, +) (c *client.Persistent, err error) { c, err = initPrev(cj, prev) if err != nil { // Don't wrap the error since it's informative enough as is. return nil, err } - err = c.setIDs(cj.IDs) + err = c.SetIDs(cj.IDs) if err != nil { // Don't wrap the error since it's informative enough as is. return nil, err } - c.safeSearchConf = copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled) + c.SafeSearchConf = copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled) c.Name = cj.Name c.Tags = cj.Tags c.Upstreams = cj.Upstreams @@ -211,9 +211,9 @@ func (clients *clientsContainer) jsonToClient( c.SafeBrowsingEnabled = cj.SafeBrowsingEnabled c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices - if c.safeSearchConf.Enabled { - err = c.setSafeSearch( - c.safeSearchConf, + if c.SafeSearchConf.Enabled { + err = c.SetSafeSearch( + c.SafeSearchConf, clients.safeSearchCacheSize, clients.safeSearchCacheTTL, ) @@ -258,7 +258,7 @@ func copySafeSearch( func copyBlockedServices( sch *schedule.Weekly, svcStrs []string, - prev *persistentClient, + prev *client.Persistent, ) (svcs *filtering.BlockedServices, err error) { var weekly *schedule.Weekly if sch != nil { @@ -283,15 +283,15 @@ func copyBlockedServices( } // clientToJSON converts persistent client object to JSON object. -func clientToJSON(c *persistentClient) (cj *clientJSON) { +func clientToJSON(c *client.Persistent) (cj *clientJSON) { // TODO(d.kolyshev): Remove after cleaning the deprecated // [clientJSON.SafeSearchEnabled] field. - cloneVal := c.safeSearchConf + cloneVal := c.SafeSearchConf safeSearchConf := &cloneVal return &clientJSON{ Name: c.Name, - IDs: c.ids(), + IDs: c.IDs(), Tags: c.Tags, UseGlobalSettings: !c.UseOwnSettings, FilteringEnabled: c.FilteringEnabled, @@ -397,7 +397,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - var prev *persistentClient + var prev *client.Persistent var ok bool func() { diff --git a/internal/home/dns.go b/internal/home/dns.go index 2b0267e8..a1f841b7 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -427,7 +427,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte } setts.FilteringEnabled = c.FilteringEnabled - setts.SafeSearchEnabled = c.safeSearchConf.Enabled + setts.SafeSearchEnabled = c.SafeSearchConf.Enabled setts.ClientSafeSearch = c.SafeSearch setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled setts.ParentalEnabled = c.ParentalEnabled diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index ff279752..8413e2a3 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -4,6 +4,7 @@ import ( "net/netip" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/stretchr/testify/assert" @@ -14,12 +15,12 @@ var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) // newIDIndex is a helper function that returns a client index filled with // persistent clients from the m. It also generates a UID for each client. -func newIDIndex(m []*persistentClient) (ci *clientIndex) { - ci = NewClientIndex() +func newIDIndex(m []*client.Persistent) (ci *client.Index) { + ci = client.NewIndex() for _, c := range m { - c.UID = MustNewUID() - ci.add(c) + c.UID = client.MustNewUID() + ci.Add(c) } return ci @@ -35,24 +36,24 @@ func TestApplyAdditionalFiltering(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.clientIndex = newIDIndex([]*persistentClient{{ + Context.clients.clientIndex = newIDIndex([]*client.Persistent{{ ClientIDs: []string{"default"}, UseOwnSettings: false, - safeSearchConf: filtering.SafeSearchConfig{Enabled: false}, + SafeSearchConf: filtering.SafeSearchConfig{Enabled: false}, FilteringEnabled: false, SafeBrowsingEnabled: false, ParentalEnabled: false, }, { ClientIDs: []string{"custom_filtering"}, UseOwnSettings: true, - safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, + SafeSearchConf: filtering.SafeSearchConfig{Enabled: true}, FilteringEnabled: true, SafeBrowsingEnabled: true, ParentalEnabled: true, }, { ClientIDs: []string{"partial_custom_filtering"}, UseOwnSettings: true, - safeSearchConf: filtering.SafeSearchConfig{Enabled: true}, + SafeSearchConf: filtering.SafeSearchConfig{Enabled: true}, FilteringEnabled: true, SafeBrowsingEnabled: false, ParentalEnabled: false, @@ -120,7 +121,7 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { }, nil) require.NoError(t, err) - Context.clients.clientIndex = newIDIndex([]*persistentClient{{ + Context.clients.clientIndex = newIDIndex([]*client.Persistent{{ ClientIDs: []string{"default"}, UseOwnBlockedServices: false, }, {