diff --git a/internal/client/runtimeindex.go b/internal/client/runtimeindex.go index 10ee0b47..fc994bfa 100644 --- a/internal/client/runtimeindex.go +++ b/internal/client/runtimeindex.go @@ -28,8 +28,8 @@ func (ri *runtimeIndex) add(rc *Runtime) { ri.index[ip] = rc } -// rangeF calls f for each runtime client in an undefined order. -func (ri *runtimeIndex) rangeF(f func(rc *Runtime) (cont bool)) { +// rangeClients calls f for each runtime client in an undefined order. +func (ri *runtimeIndex) rangeClients(f func(rc *Runtime) (cont bool)) { for _, rc := range ri.index { if !f(rc) { return diff --git a/internal/client/storage.go b/internal/client/storage.go index eb7d3625..18d1bdda 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -125,14 +125,6 @@ func NewStorage(conf *Config) (s *Storage, err error) { done: make(chan struct{}), } - // TODO(s.chzhen): Refactor it. - switch v := s.etcHosts.(type) { - case *aghnet.HostsContainer: - if v == nil { - s.etcHosts = nil - } - } - for i, p := range conf.InitialClients { err = s.Add(p) if err != nil { @@ -140,6 +132,12 @@ func NewStorage(conf *Config) (s *Storage, err error) { } } + if hc, ok := s.etcHosts.(*aghnet.HostsContainer); ok && hc == nil { + s.etcHosts = nil + } + + s.ReloadARP() + return s, nil } @@ -163,9 +161,6 @@ func (s *Storage) Shutdown(_ context.Context) (err error) { func (s *Storage) periodicARPUpdate() { defer log.OnPanic("storage") - // Initial ARP refresh. - s.ReloadARP() - t := time.NewTicker(s.arpClientsUpdatePeriod) for { @@ -376,6 +371,9 @@ func (s *Storage) FindByName(name string) (p *Persistent, ok bool) { // Find finds persistent client by string representation of the client ID, IP // address, or MAC. And returns its shallow copy. +// +// TODO(s.chzhen): Accept ClientIDData structure instead, which will contain +// the parsed IP address, if any. func (s *Storage) Find(id string) (p *Persistent, ok bool) { s.mu.Lock() defer s.mu.Unlock() @@ -540,5 +538,5 @@ func (s *Storage) RangeRuntime(f func(rc *Runtime) (cont bool)) { s.mu.Lock() defer s.mu.Unlock() - s.runtimeIndex.rangeF(f) + s.runtimeIndex.rangeClients(f) } diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 7f9ffe60..ad6b7f9b 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -44,22 +44,22 @@ type Interface interface { Neighbors() (ns []arpdb.Neighbor) } -// testARP is a mock implementation of the [arpdb.Interface]. -type testARP struct { +// testARPDB is a mock implementation of the [arpdb.Interface]. +type testARPDB struct { onRefresh func() (err error) onNeighbors func() (ns []arpdb.Neighbor) } // type check -var _ arpdb.Interface = (*testARP)(nil) +var _ arpdb.Interface = (*testARPDB)(nil) // Refresh implements the [arpdb.Interface] interface for *testARP. -func (c *testARP) Refresh() (err error) { +func (c *testARPDB) Refresh() (err error) { return c.onRefresh() } // Neighbors implements the [arpdb.Interface] interface for *testARP. -func (c *testARP) Neighbors() (ns []arpdb.Neighbor) { +func (c *testARPDB) Neighbors() (ns []arpdb.Neighbor) { return c.onNeighbors() } @@ -186,7 +186,7 @@ func TestStorage_Add_arp(t *testing.T) { cliName2 = "client_two" ) - a := &testARP{ + a := &testARPDB{ onRefresh: func() (err error) { return nil }, onNeighbors: func() (ns []arpdb.Neighbor) { mu.Lock() @@ -327,7 +327,7 @@ func TestClientsDHCP(t *testing.T) { prsCliIP = netip.MustParseAddr("4.3.2.1") prsCliMAC = mustParseMAC("AA:AA:AA:AA:AA:AA") - prsCliName = "persitent.dhcp" + prsCliName = "persistent.dhcp" ) ipToHost := map[netip.Addr]string{ diff --git a/internal/home/dns.go b/internal/home/dns.go index a5dc7cad..0159fe29 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -460,12 +460,12 @@ func startDNSServer() error { err := Context.clients.Start(context.TODO()) if err != nil { - return fmt.Errorf("couldn't start clients container: %w", err) + return fmt.Errorf("starting clients container: %w", err) } err = Context.dnsServer.Start() if err != nil { - return fmt.Errorf("couldn't start forwarding DNS server: %w", err) + return fmt.Errorf("starting dns server: %w", err) } Context.filters.Start()