diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index 87ec9cb8..bc86721e 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -89,14 +89,14 @@ func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) { // AddressProcessor is a fake [client.AddressProcessor] implementation for // tests. type AddressProcessor struct { - OnProcess func(ip netip.Addr) + OnProcess func(ctx context.Context, ip netip.Addr) OnClose func() (err error) } // Process implements the [client.AddressProcessor] interface for // *AddressProcessor. -func (p *AddressProcessor) Process(ip netip.Addr) { - p.OnProcess(ip) +func (p *AddressProcessor) Process(ctx context.Context, ip netip.Addr) { + p.OnProcess(ctx, ip) } // Close implements the [client.AddressProcessor] interface for @@ -107,13 +107,18 @@ func (p *AddressProcessor) Close() (err error) { // AddressUpdater is a fake [client.AddressUpdater] implementation for tests. type AddressUpdater struct { - OnUpdateAddress func(ip netip.Addr, host string, info *whois.Info) + OnUpdateAddress func(ctx context.Context, ip netip.Addr, host string, info *whois.Info) } // UpdateAddress implements the [client.AddressUpdater] interface for // *AddressUpdater. -func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { - p.OnUpdateAddress(ip, host, info) +func (p *AddressUpdater) UpdateAddress( + ctx context.Context, + ip netip.Addr, + host string, + info *whois.Info, +) { + p.OnUpdateAddress(ctx, ip, host, info) } // Package dnsforward diff --git a/internal/client/addrproc.go b/internal/client/addrproc.go index 35293609..abd2ed69 100644 --- a/internal/client/addrproc.go +++ b/internal/client/addrproc.go @@ -11,7 +11,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" ) @@ -22,7 +21,7 @@ const ErrClosed errors.Error = "use of closed address processor" // AddressProcessor is the interface for types that can process clients. type AddressProcessor interface { - Process(ip netip.Addr) + Process(ctx context.Context, ip netip.Addr) Close() (err error) } @@ -33,7 +32,7 @@ type EmptyAddrProc struct{} var _ AddressProcessor = EmptyAddrProc{} // Process implements the [AddressProcessor] interface for EmptyAddrProc. -func (EmptyAddrProc) Process(_ netip.Addr) {} +func (EmptyAddrProc) Process(_ context.Context, _ netip.Addr) {} // Close implements the [AddressProcessor] interface for EmptyAddrProc. func (EmptyAddrProc) Close() (_ error) { return nil } @@ -90,12 +89,15 @@ type DefaultAddrProcConfig struct { type AddressUpdater interface { // UpdateAddress updates information about an IP address, setting host (if // not empty) and WHOIS information (if not nil). - UpdateAddress(ip netip.Addr, host string, info *whois.Info) + UpdateAddress(ctx context.Context, ip netip.Addr, host string, info *whois.Info) } // DefaultAddrProc processes incoming client addresses with rDNS and WHOIS, if // configured, and updates that information in a client storage. type DefaultAddrProc struct { + // logger is used to log the operation of address processor. + logger *slog.Logger + // clientIPsMu serializes closure of clientIPs and access to isClosed. clientIPsMu *sync.Mutex @@ -142,6 +144,7 @@ const ( // not be nil. func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) { p = &DefaultAddrProc{ + logger: c.BaseLogger.With(slogutil.KeyPrefix, "addrproc"), clientIPsMu: &sync.Mutex{}, clientIPs: make(chan netip.Addr, defaultQueueSize), rdns: &rdns.Empty{}, @@ -164,10 +167,13 @@ func NewDefaultAddrProc(c *DefaultAddrProcConfig) (p *DefaultAddrProc) { p.whois = newWHOIS(c.BaseLogger.With(slogutil.KeyPrefix, "whois"), c.DialContext) } - go p.process(c.CatchPanics) + // TODO(s.chzhen): Pass context. + ctx := context.TODO() + + go p.process(ctx, c.CatchPanics) for _, ip := range c.InitialAddresses { - p.Process(ip) + p.Process(ctx, ip) } return p @@ -210,7 +216,7 @@ func newWHOIS(logger *slog.Logger, dialFunc aghnet.DialContextFunc) (w whois.Int var _ AddressProcessor = (*DefaultAddrProc)(nil) // Process implements the [AddressProcessor] interface for *DefaultAddrProc. -func (p *DefaultAddrProc) Process(ip netip.Addr) { +func (p *DefaultAddrProc) Process(ctx context.Context, ip netip.Addr) { p.clientIPsMu.Lock() defer p.clientIPsMu.Unlock() @@ -222,38 +228,42 @@ func (p *DefaultAddrProc) Process(ip netip.Addr) { case p.clientIPs <- ip: // Go on. default: - log.Debug("clients: ip channel is full; len: %d", len(p.clientIPs)) + p.logger.DebugContext(ctx, "ip channel is full", "len", len(p.clientIPs)) } } // process processes the incoming client IP-address information. It is intended // to be used as a goroutine. Once clientIPs is closed, process exits. -func (p *DefaultAddrProc) process(catchPanics bool) { +func (p *DefaultAddrProc) process(ctx context.Context, catchPanics bool) { if catchPanics { - defer log.OnPanic("addrProcessor.process") + defer slogutil.RecoverAndLog(ctx, p.logger) } - log.Info("clients: processing addresses") - - ctx := context.TODO() + p.logger.InfoContext(ctx, "processing addresses") for ip := range p.clientIPs { host := p.processRDNS(ctx, ip) info := p.processWHOIS(ctx, ip) - p.addrUpdater.UpdateAddress(ip, host, info) + p.addrUpdater.UpdateAddress(ctx, ip, host, info) } - log.Info("clients: finished processing addresses") + p.logger.InfoContext(ctx, "finished processing addresses") } // processRDNS resolves the clients' IP addresses using reverse DNS. host is // empty if there were errors or if the information hasn't changed. func (p *DefaultAddrProc) processRDNS(ctx context.Context, ip netip.Addr) (host string) { start := time.Now() - log.Debug("clients: processing %s with rdns", ip) + p.logger.DebugContext(ctx, "processing rdns", "ip", ip) defer func() { - log.Debug("clients: finished processing %s with rdns in %s", ip, time.Since(start)) + p.logger.DebugContext( + ctx, + "finished processing rdns", + "ip", ip, + "host", host, + "elapsed", time.Since(start), + ) }() ok := p.shouldResolve(ip) @@ -280,9 +290,15 @@ func (p *DefaultAddrProc) shouldResolve(ip netip.Addr) (ok bool) { // hasn't changed. func (p *DefaultAddrProc) processWHOIS(ctx context.Context, ip netip.Addr) (info *whois.Info) { start := time.Now() - log.Debug("clients: processing %s with whois", ip) + p.logger.DebugContext(ctx, "processing whois", "ip", ip) defer func() { - log.Debug("clients: finished processing %s with whois in %s", ip, time.Since(start)) + p.logger.DebugContext( + ctx, + "finished processing whois", + "ip", ip, + "whois", info, + "elapsed", time.Since(start), + ) }() // TODO(s.chzhen): Move the timeout logic from WHOIS configuration to the diff --git a/internal/client/addrproc_test.go b/internal/client/addrproc_test.go index 3df3a5c7..265ebd0d 100644 --- a/internal/client/addrproc_test.go +++ b/internal/client/addrproc_test.go @@ -26,7 +26,8 @@ func TestEmptyAddrProc(t *testing.T) { p := client.EmptyAddrProc{} assert.NotPanics(t, func() { - p.Process(testIP) + ctx := testutil.ContextWithTimeout(t, testTimeout) + p.Process(ctx, testIP) }) assert.NotPanics(t, func() { @@ -120,7 +121,8 @@ func TestDefaultAddrProc_Process_rDNS(t *testing.T) { }) testutil.CleanupAndRequireSuccess(t, p.Close) - p.Process(tc.ip) + ctx := testutil.ContextWithTimeout(t, testTimeout) + p.Process(ctx, tc.ip) if !tc.wantUpd { return @@ -146,8 +148,8 @@ func newOnUpdateAddress( ips chan<- netip.Addr, hosts chan<- string, infos chan<- *whois.Info, -) (f func(ip netip.Addr, host string, info *whois.Info)) { - return func(ip netip.Addr, host string, info *whois.Info) { +) (f func(ctx context.Context, ip netip.Addr, host string, info *whois.Info)) { + return func(ctx context.Context, ip netip.Addr, host string, info *whois.Info) { if !want && (host != "" || info != nil) { panic(fmt.Errorf("got unexpected update for %v with %q and %v", ip, host, info)) } @@ -230,7 +232,8 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) { }) testutil.CleanupAndRequireSuccess(t, p.Close) - p.Process(testIP) + ctx := testutil.ContextWithTimeout(t, testTimeout) + p.Process(ctx, testIP) if !tc.wantUpd { return @@ -251,7 +254,9 @@ func TestDefaultAddrProc_Process_WHOIS(t *testing.T) { func TestDefaultAddrProc_Close(t *testing.T) { t.Parallel() - p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{}) + p := client.NewDefaultAddrProc(&client.DefaultAddrProcConfig{ + BaseLogger: slogutil.NewDiscardLogger(), + }) err := p.Close() assert.NoError(t, err) diff --git a/internal/client/persistent.go b/internal/client/persistent.go index 7a0339b0..1cea335b 100644 --- a/internal/client/persistent.go +++ b/internal/client/persistent.go @@ -1,8 +1,10 @@ package client import ( + "context" "encoding" "fmt" + "log/slog" "net" "net/netip" "slices" @@ -12,7 +14,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/google/uuid" ) @@ -134,7 +136,7 @@ type Persistent struct { // validate returns an error if persistent client information contains errors. // allTags must be sorted. -func (c *Persistent) validate(allTags []string) (err error) { +func (c *Persistent) validate(ctx context.Context, l *slog.Logger, allTags []string) (err error) { switch { case c.Name == "": return errors.Error("empty name") @@ -151,7 +153,7 @@ func (c *Persistent) validate(allTags []string) (err error) { err = conf.Close() if err != nil { - log.Error("client: closing upstream config: %s", err) + l.ErrorContext(ctx, "client: closing upstream config", slogutil.KeyError, err) } for _, t := range c.Tags { diff --git a/internal/client/storage.go b/internal/client/storage.go index da6dda5c..c3820cd8 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -3,6 +3,7 @@ package client import ( "context" "fmt" + "log/slog" "net" "net/netip" "slices" @@ -15,6 +16,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" ) // allowedTags is the list of available client tags. @@ -83,6 +85,10 @@ type HostsContainer interface { // StorageConfig is the client storage configuration structure. type StorageConfig struct { + // Logger is used for logging the operation of the client storage. It must + // not be nil. + Logger *slog.Logger + // DHCP is used to match IPs against MACs of persistent clients and update // [SourceDHCP] runtime client information. It must not be nil. DHCP DHCP @@ -108,6 +114,10 @@ type StorageConfig struct { // Storage contains information about persistent and runtime clients. type Storage struct { + // logger is used for logging the operation of the client storage. It must + // not be nil. + logger *slog.Logger + // mu protects indexes of persistent and runtime clients. mu *sync.Mutex @@ -145,12 +155,12 @@ type Storage struct { } // NewStorage returns initialized client storage. conf must not be nil. -func NewStorage(conf *StorageConfig) (s *Storage, err error) { +func NewStorage(ctx context.Context, conf *StorageConfig) (s *Storage, err error) { tags := slices.Clone(allowedTags) slices.Sort(tags) s = &Storage{ - allowedTags: tags, + logger: conf.Logger, mu: &sync.Mutex{}, index: newIndex(), runtimeIndex: newRuntimeIndex(), @@ -158,18 +168,19 @@ func NewStorage(conf *StorageConfig) (s *Storage, err error) { etcHosts: conf.EtcHosts, arpDB: conf.ARPDB, done: make(chan struct{}), + allowedTags: tags, arpClientsUpdatePeriod: conf.ARPClientsUpdatePeriod, runtimeSourceDHCP: conf.RuntimeSourceDHCP, } for i, p := range conf.InitialClients { - err = s.Add(p) + err = s.Add(ctx, p) if err != nil { return nil, fmt.Errorf("adding client %q at index %d: %w", p.Name, i, err) } } - s.ReloadARP() + s.ReloadARP(ctx) return s, nil } @@ -177,9 +188,9 @@ func NewStorage(conf *StorageConfig) (s *Storage, err error) { // Start starts the goroutines for updating the runtime client information. // // TODO(s.chzhen): Pass context. -func (s *Storage) Start(_ context.Context) (err error) { - go s.periodicARPUpdate() - go s.handleHostsUpdates() +func (s *Storage) Start(ctx context.Context) (err error) { + go s.periodicARPUpdate(ctx) + go s.handleHostsUpdates(ctx) return nil } @@ -195,15 +206,15 @@ func (s *Storage) Shutdown(_ context.Context) (err error) { // periodicARPUpdate periodically reloads runtime clients from ARP. It is // intended to be used as a goroutine. -func (s *Storage) periodicARPUpdate() { - defer log.OnPanic("storage") +func (s *Storage) periodicARPUpdate(ctx context.Context) { + defer slogutil.RecoverAndLog(ctx, s.logger) t := time.NewTicker(s.arpClientsUpdatePeriod) for { select { case <-t.C: - s.ReloadARP() + s.ReloadARP(ctx) case <-s.done: return } @@ -211,28 +222,28 @@ func (s *Storage) periodicARPUpdate() { } // ReloadARP reloads runtime clients from ARP, if configured. -func (s *Storage) ReloadARP() { +func (s *Storage) ReloadARP(ctx context.Context) { if s.arpDB != nil { - s.addFromSystemARP() + s.addFromSystemARP(ctx) } } // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a // command. -func (s *Storage) addFromSystemARP() { +func (s *Storage) addFromSystemARP(ctx context.Context) { s.mu.Lock() defer s.mu.Unlock() if err := s.arpDB.Refresh(); err != nil { s.arpDB = arpdb.Empty{} - log.Error("refreshing arp container: %s", err) + s.logger.ErrorContext(ctx, "refreshing arp container", slogutil.KeyError, err) return } ns := s.arpDB.Neighbors() if len(ns) == 0 { - log.Debug("refreshing arp container: the update is empty") + s.logger.DebugContext(ctx, "refreshing arp container: the update is empty") return } @@ -246,17 +257,22 @@ func (s *Storage) addFromSystemARP() { removed := s.runtimeIndex.removeEmpty() - log.Debug("storage: added %d, removed %d client aliases from arp neighborhood", len(ns), removed) + s.logger.DebugContext( + ctx, + "updating client aliases from arp neighborhood", + "added", len(ns), + "removed", removed, + ) } // handleHostsUpdates receives the updates from the hosts container and adds // them to the clients storage. It is intended to be used as a goroutine. -func (s *Storage) handleHostsUpdates() { +func (s *Storage) handleHostsUpdates(ctx context.Context) { if s.etcHosts == nil { return } - defer log.OnPanic("storage") + defer slogutil.RecoverAndLog(ctx, s.logger) for { select { @@ -265,7 +281,7 @@ func (s *Storage) handleHostsUpdates() { return } - s.addFromHostsFile(upd) + s.addFromHostsFile(ctx, upd) case <-s.done: return } @@ -274,7 +290,7 @@ func (s *Storage) handleHostsUpdates() { // addFromHostsFile fills the client-hostname pairing index from the system's // hosts files. -func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) { +func (s *Storage) addFromHostsFile(ctx context.Context, hosts *hostsfile.DefaultStorage) { s.mu.Lock() defer s.mu.Unlock() @@ -294,14 +310,19 @@ func (s *Storage) addFromHostsFile(hosts *hostsfile.DefaultStorage) { }) removed := s.runtimeIndex.removeEmpty() - log.Debug("storage: added %d, removed %d client aliases from system hosts file", added, removed) + s.logger.DebugContext( + ctx, + "updating client aliases from system hosts file", + "added", added, + "removed", removed, + ) } // type check var _ AddressUpdater = (*Storage)(nil) // UpdateAddress implements the [AddressUpdater] interface for *Storage -func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { +func (s *Storage) UpdateAddress(ctx context.Context, ip netip.Addr, host string, info *whois.Info) { // Common fast path optimization. if host == "" && info == nil { return @@ -315,12 +336,12 @@ func (s *Storage) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { } if info != nil { - s.setWHOISInfo(ip, info) + s.setWHOISInfo(ctx, ip, info) } } // UpdateDHCP updates [SourceDHCP] runtime client information. -func (s *Storage) UpdateDHCP() { +func (s *Storage) UpdateDHCP(ctx context.Context) { if s.dhcp == nil || !s.runtimeSourceDHCP { return } @@ -338,14 +359,23 @@ func (s *Storage) UpdateDHCP() { } removed := s.runtimeIndex.removeEmpty() - log.Debug("storage: added %d, removed %d client aliases from dhcp", added, removed) + s.logger.DebugContext( + ctx, + "updating client aliases from dhcp", + "added", added, + "removed", removed, + ) } // setWHOISInfo sets the WHOIS information for a runtime client. -func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) { +func (s *Storage) setWHOISInfo(ctx context.Context, ip netip.Addr, wi *whois.Info) { _, ok := s.index.findByIP(ip) if ok { - log.Debug("storage: client for %s is already created, ignore whois info", ip) + s.logger.DebugContext( + ctx, + "persistent client is already created, ignore whois info", + "ip", ip, + ) return } @@ -358,14 +388,14 @@ func (s *Storage) setWHOISInfo(ip netip.Addr, wi *whois.Info) { rc.setWHOIS(wi) - log.Debug("storage: set whois info for runtime client with ip %s: %+v", ip, wi) + s.logger.DebugContext(ctx, "set whois info for runtime client", "ip", ip, "whois", wi) } // Add stores persistent client information or returns an error. -func (s *Storage) Add(p *Persistent) (err error) { +func (s *Storage) Add(ctx context.Context, p *Persistent) (err error) { defer func() { err = errors.Annotate(err, "adding client: %w") }() - err = p.validate(s.allowedTags) + err = p.validate(ctx, s.logger, s.allowedTags) if err != nil { // Don't wrap the error since there is already an annotation deferred. return err @@ -388,7 +418,13 @@ func (s *Storage) Add(p *Persistent) (err error) { s.index.add(p) - log.Debug("client storage: added %q: IDs: %q [%d]", p.Name, p.IDs(), s.index.size()) + s.logger.DebugContext( + ctx, + "client added", + "name", p.Name, + "ids", p.IDs(), + "clients_count", s.index.size(), + ) return nil } @@ -490,10 +526,10 @@ func (s *Storage) RemoveByName(name string) (ok bool) { // Update finds the stored persistent client by its name and updates its // information from p. -func (s *Storage) Update(name string, p *Persistent) (err error) { +func (s *Storage) Update(ctx context.Context, name string, p *Persistent) (err error) { defer func() { err = errors.Annotate(err, "updating client: %w") }() - err = p.validate(s.allowedTags) + err = p.validate(ctx, s.logger, s.allowedTags) if err != nil { // Don't wrap the error since there is already an annotation deferred. return err diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 60d766d0..a2101013 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -15,11 +15,25 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/golibs/hostsfile" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// newTestStorage is a helper function that returns initialized storage. +func newTestStorage(tb testing.TB) (s *client.Storage) { + tb.Helper() + + ctx := testutil.ContextWithTimeout(tb, testTimeout) + s, err := client.NewStorage(ctx, &client.StorageConfig{ + Logger: slogutil.NewDiscardLogger(), + }) + require.NoError(tb, err) + + return s +} + // testHostsContainer is a mock implementation of the [client.HostsContainer] // interface. type testHostsContainer struct { @@ -110,7 +124,9 @@ func TestStorage_Add_hostsfile(t *testing.T) { onUpd: func() (updates <-chan *hostsfile.DefaultStorage) { return hostCh }, } - storage, err := client.NewStorage(&client.StorageConfig{ + ctx := testutil.ContextWithTimeout(t, testTimeout) + storage, err := client.NewStorage(ctx, &client.StorageConfig{ + Logger: slogutil.NewDiscardLogger(), DHCP: client.EmptyDHCP{}, EtcHosts: h, ARPClientsUpdatePeriod: testTimeout / 10, @@ -198,7 +214,9 @@ func TestStorage_Add_arp(t *testing.T) { }, } - storage, err := client.NewStorage(&client.StorageConfig{ + ctx := testutil.ContextWithTimeout(t, testTimeout) + storage, err := client.NewStorage(ctx, &client.StorageConfig{ + Logger: slogutil.NewDiscardLogger(), DHCP: client.EmptyDHCP{}, ARPDB: a, ARPClientsUpdatePeriod: testTimeout / 10, @@ -273,8 +291,10 @@ func TestStorage_Add_whois(t *testing.T) { cliName3 = "client_three" ) - storage, err := client.NewStorage(&client.StorageConfig{ - DHCP: client.EmptyDHCP{}, + ctx := testutil.ContextWithTimeout(t, testTimeout) + storage, err := client.NewStorage(ctx, &client.StorageConfig{ + Logger: slogutil.NewDiscardLogger(), + DHCP: client.EmptyDHCP{}, }) require.NoError(t, err) @@ -284,7 +304,7 @@ func TestStorage_Add_whois(t *testing.T) { } t.Run("new_client", func(t *testing.T) { - storage.UpdateAddress(cliIP1, "", whois) + storage.UpdateAddress(ctx, cliIP1, "", whois) cli1 := storage.ClientRuntime(cliIP1) require.NotNil(t, cli1) @@ -292,8 +312,8 @@ func TestStorage_Add_whois(t *testing.T) { }) t.Run("existing_runtime_client", func(t *testing.T) { - storage.UpdateAddress(cliIP2, cliName2, nil) - storage.UpdateAddress(cliIP2, "", whois) + storage.UpdateAddress(ctx, cliIP2, cliName2, nil) + storage.UpdateAddress(ctx, cliIP2, "", whois) cli2 := storage.ClientRuntime(cliIP2) require.NotNil(t, cli2) @@ -304,14 +324,14 @@ func TestStorage_Add_whois(t *testing.T) { }) t.Run("can't_set_persistent_client", func(t *testing.T) { - err = storage.Add(&client.Persistent{ + err = storage.Add(ctx, &client.Persistent{ Name: cliName3, UID: client.MustNewUID(), IPs: []netip.Addr{cliIP3}, }) require.NoError(t, err) - storage.UpdateAddress(cliIP3, "", whois) + storage.UpdateAddress(ctx, cliIP3, "", whois) rc := storage.ClientRuntime(cliIP3) require.Nil(t, rc) }) @@ -364,7 +384,9 @@ func TestClientsDHCP(t *testing.T) { }, } - storage, err := client.NewStorage(&client.StorageConfig{ + ctx := testutil.ContextWithTimeout(t, testTimeout) + storage, err := client.NewStorage(ctx, &client.StorageConfig{ + Logger: slogutil.NewDiscardLogger(), DHCP: d, RuntimeSourceDHCP: true, }) @@ -378,7 +400,7 @@ func TestClientsDHCP(t *testing.T) { }) t.Run("find_persistent", func(t *testing.T) { - err = storage.Add(&client.Persistent{ + err = storage.Add(ctx, &client.Persistent{ Name: prsCliName, UID: client.MustNewUID(), MACs: []net.HardwareAddr{prsCliMAC}, @@ -393,7 +415,7 @@ func TestClientsDHCP(t *testing.T) { t.Run("leases", func(t *testing.T) { delete(ipToHost, cliIP1) - storage.UpdateDHCP() + storage.UpdateDHCP(ctx) cli1 := storage.ClientRuntime(cliIP1) require.Nil(t, cli1) @@ -421,16 +443,19 @@ func TestClientsDHCP(t *testing.T) { } func TestClientsAddExisting(t *testing.T) { + ctx := testutil.ContextWithTimeout(t, testTimeout) + t.Run("simple", func(t *testing.T) { - storage, err := client.NewStorage(&client.StorageConfig{ - DHCP: client.EmptyDHCP{}, + storage, err := client.NewStorage(ctx, &client.StorageConfig{ + Logger: slogutil.NewDiscardLogger(), + DHCP: client.EmptyDHCP{}, }) require.NoError(t, err) ip := netip.MustParseAddr("1.1.1.1") // Add a client. - err = storage.Add(&client.Persistent{ + err = storage.Add(ctx, &client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")}, @@ -440,7 +465,7 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Now add an auto-client with the same IP. - storage.UpdateAddress(ip, "test", nil) + storage.UpdateAddress(ctx, ip, "test", nil) rc := storage.ClientRuntime(ip) assert.True(t, compareRuntimeInfo(rc, client.SourceRDNS, "test")) }) @@ -468,8 +493,9 @@ func TestClientsAddExisting(t *testing.T) { dhcpServer, err := dhcpd.Create(config) require.NoError(t, err) - storage, err := client.NewStorage(&client.StorageConfig{ - DHCP: dhcpServer, + storage, err := client.NewStorage(ctx, &client.StorageConfig{ + Logger: slogutil.NewDiscardLogger(), + DHCP: dhcpServer, }) require.NoError(t, err) @@ -484,7 +510,7 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Add a new client with the same IP as for a client with MAC. - err = storage.Add(&client.Persistent{ + err = storage.Add(ctx, &client.Persistent{ Name: "client2", UID: client.MustNewUID(), IPs: []netip.Addr{ip}, @@ -492,7 +518,7 @@ func TestClientsAddExisting(t *testing.T) { require.NoError(t, err) // Add a new client with the IP from the first client's IP range. - err = storage.Add(&client.Persistent{ + err = storage.Add(ctx, &client.Persistent{ Name: "client3", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")}, @@ -506,14 +532,16 @@ func TestClientsAddExisting(t *testing.T) { func newStorage(tb testing.TB, m []*client.Persistent) (s *client.Storage) { tb.Helper() - s, err := client.NewStorage(&client.StorageConfig{ - DHCP: client.EmptyDHCP{}, + ctx := testutil.ContextWithTimeout(tb, testTimeout) + s, err := client.NewStorage(ctx, &client.StorageConfig{ + Logger: slogutil.NewDiscardLogger(), + DHCP: client.EmptyDHCP{}, }) require.NoError(tb, err) for _, c := range m { c.UID = client.MustNewUID() - require.NoError(tb, s.Add(c)) + require.NoError(tb, s.Add(ctx, c)) } require.Equal(tb, len(m), s.Size()) @@ -555,9 +583,8 @@ func TestStorage_Add(t *testing.T) { UID: existingClientUID, } - s, err := client.NewStorage(&client.StorageConfig{}) - require.NoError(t, err) - + ctx := testutil.ContextWithTimeout(t, testTimeout) + s := newTestStorage(t) tags := s.AllowedTags() require.NotZero(t, len(tags)) require.True(t, slices.IsSorted(tags)) @@ -568,7 +595,7 @@ func TestStorage_Add(t *testing.T) { _, ok = slices.BinarySearch(tags, notAllowedTag) require.False(t, ok) - err = s.Add(existingClient) + err := s.Add(ctx, existingClient) require.NoError(t, err) testCases := []struct { @@ -669,7 +696,7 @@ func TestStorage_Add(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - err = s.Add(tc.cli) + err = s.Add(ctx, tc.cli) testutil.AssertErrorMsg(t, tc.wantErrMsg, err) }) @@ -687,10 +714,9 @@ func TestStorage_RemoveByName(t *testing.T) { UID: client.MustNewUID(), } - s, err := client.NewStorage(&client.StorageConfig{}) - require.NoError(t, err) - - err = s.Add(existingClient) + ctx := testutil.ContextWithTimeout(t, testTimeout) + s := newTestStorage(t) + err := s.Add(ctx, existingClient) require.NoError(t, err) testCases := []struct { @@ -714,10 +740,8 @@ func TestStorage_RemoveByName(t *testing.T) { } t.Run("duplicate_remove", func(t *testing.T) { - s, err = client.NewStorage(&client.StorageConfig{}) - require.NoError(t, err) - - err = s.Add(existingClient) + s = newTestStorage(t) + err = s.Add(ctx, existingClient) require.NoError(t, err) assert.True(t, s.RemoveByName(existingName)) @@ -1080,6 +1104,7 @@ func TestStorage_Update(t *testing.T) { `uses the same ClientID "obstructing_client_id"`, }} + ctx := testutil.ContextWithTimeout(t, testTimeout) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s := newStorage( @@ -1090,7 +1115,7 @@ func TestStorage_Update(t *testing.T) { }, ) - err := s.Update(clientName, tc.cli) + err := s.Update(ctx, clientName, tc.cli) testutil.AssertErrorMsg(t, tc.wantErrMsg, err) }) } diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go index 1a3dbc2e..66baf368 100644 --- a/internal/dnsforward/process.go +++ b/internal/dnsforward/process.go @@ -2,6 +2,7 @@ package dnsforward import ( "cmp" + "context" "encoding/binary" "net" "net/netip" @@ -203,7 +204,8 @@ func (s *Server) processClientIP(addr netip.Addr) { s.serverLock.RLock() defer s.serverLock.RUnlock() - s.addrProc.Process(addr) + // TODO(s.chzhen): Pass context. + s.addrProc.Process(context.TODO(), addr) } // processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index 9781b3b0..2ba42fa2 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -2,6 +2,7 @@ package dnsforward import ( "cmp" + "context" "net" "net/netip" "testing" @@ -90,7 +91,7 @@ func TestServer_ProcessInitial(t *testing.T) { var gotAddr netip.Addr s.addrProc = &aghtest.AddressProcessor{ - OnProcess: func(ip netip.Addr) { gotAddr = ip }, + OnProcess: func(ctx context.Context, ip netip.Addr) { gotAddr = ip }, OnClose: func() (err error) { panic("not implemented") }, } diff --git a/internal/home/clients.go b/internal/home/clients.go index 59b9cbe3..4aff4b25 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -107,7 +107,8 @@ func (clients *clientsContainer) Init( hosts = etcHosts } - clients.storage, err = client.NewStorage(&client.StorageConfig{ + clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{ + Logger: baseLogger.With(slogutil.KeyPrefix, "client_storage"), InitialClients: confClients, DHCP: dhcpServer, EtcHosts: hosts, @@ -417,7 +418,8 @@ func (clients *clientsContainer) UpstreamConfigByID( ) c.UpstreamConfig = conf - err = clients.storage.Update(c.Name, c) + // TODO(s.chzhen): Pass context. + err = clients.storage.Update(context.TODO(), c.Name, c) if err != nil { return nil, fmt.Errorf("setting upstream config: %w", err) } @@ -430,8 +432,13 @@ var _ client.AddressUpdater = (*clientsContainer)(nil) // UpdateAddress implements the [client.AddressUpdater] interface for // *clientsContainer -func (clients *clientsContainer) UpdateAddress(ip netip.Addr, host string, info *whois.Info) { - clients.storage.UpdateAddress(ip, host, info) +func (clients *clientsContainer) UpdateAddress( + ctx context.Context, + ip netip.Addr, + host string, + info *whois.Info, +) { + clients.storage.UpdateAddress(ctx, ip, host, info) } // close gracefully closes all the client-specific upstream configurations of diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 13aed3cd..ad44e2e3 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -40,9 +40,10 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) { func TestClientsCustomUpstream(t *testing.T) { clients := newClientsContainer(t) + ctx := testutil.ContextWithTimeout(t, testTimeout) // Add client with upstreams. - err := clients.storage.Add(&client.Persistent{ + err := clients.storage.Add(ctx, &client.Persistent{ Name: "client1", UID: client.MustNewUID(), IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")}, diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 3cc5a0c6..6d06ae00 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -106,7 +106,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http return true }) - clients.storage.UpdateDHCP() + clients.storage.UpdateDHCP(r.Context()) clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) { src, host := rc.Info() @@ -341,7 +341,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - err = clients.storage.Add(c) + err = clients.storage.Add(r.Context(), c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) @@ -411,7 +411,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - err = clients.storage.Update(dj.Name, c) + err = clients.storage.Update(r.Context(), dj.Name, c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) diff --git a/internal/home/clientshttp_internal_test.go b/internal/home/clientshttp_internal_test.go index c9be42b7..a10ca8d1 100644 --- a/internal/home/clientshttp_internal_test.go +++ b/internal/home/clientshttp_internal_test.go @@ -203,13 +203,14 @@ func TestClientsContainer_HandleAddClient(t *testing.T) { func TestClientsContainer_HandleDelClient(t *testing.T) { clients := newClientsContainer(t) + ctx := testutil.ContextWithTimeout(t, testTimeout) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) - err := clients.storage.Add(clientOne) + err := clients.storage.Add(ctx, clientOne) require.NoError(t, err) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) - err = clients.storage.Add(clientTwo) + err = clients.storage.Add(ctx, clientTwo) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) @@ -265,9 +266,10 @@ func TestClientsContainer_HandleDelClient(t *testing.T) { func TestClientsContainer_HandleUpdateClient(t *testing.T) { clients := newClientsContainer(t) + ctx := testutil.ContextWithTimeout(t, testTimeout) clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) - err := clients.storage.Add(clientOne) + err := clients.storage.Add(ctx, clientOne) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne}) @@ -348,12 +350,14 @@ func TestClientsContainer_HandleFindClient(t *testing.T) { }, } + ctx := testutil.ContextWithTimeout(t, testTimeout) + clientOne := newPersistentClientWithIDs(t, "client1", []string{testClientIP1}) - err := clients.storage.Add(clientOne) + err := clients.storage.Add(ctx, clientOne) require.NoError(t, err) clientTwo := newPersistentClientWithIDs(t, "client2", []string{testClientIP2}) - err = clients.storage.Add(clientTwo) + err = clients.storage.Add(ctx, clientTwo) require.NoError(t, err) assertPersistentClients(t, clients, []*client.Persistent{clientOne, clientTwo}) diff --git a/internal/home/dns_internal_test.go b/internal/home/dns_internal_test.go index d3712890..b04ee648 100644 --- a/internal/home/dns_internal_test.go +++ b/internal/home/dns_internal_test.go @@ -7,6 +7,8 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/schedule" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -18,12 +20,15 @@ var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4}) func newStorage(tb testing.TB, clients []*client.Persistent) (s *client.Storage) { tb.Helper() - s, err := client.NewStorage(&client.StorageConfig{}) + ctx := testutil.ContextWithTimeout(tb, testTimeout) + s, err := client.NewStorage(ctx, &client.StorageConfig{ + Logger: slogutil.NewDiscardLogger(), + }) require.NoError(tb, err) for _, p := range clients { p.UID = client.MustNewUID() - require.NoError(tb, s.Add(p)) + require.NoError(tb, s.Add(ctx, p)) } return s diff --git a/internal/home/home.go b/internal/home/home.go index 0c53a878..b986e972 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -115,15 +115,16 @@ func Main(clientBuildFS fs.FS) { signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) go func() { + ctx := context.Background() for { sig := <-signals log.Info("Received signal %q", sig) switch sig { case syscall.SIGHUP: - Context.clients.storage.ReloadARP() + Context.clients.storage.ReloadARP(ctx) Context.tls.reload() default: - cleanup(context.Background()) + cleanup(ctx) cleanupAlways() close(done) }