From 2c7efa46099d9b8ffe297ce247aff0aa8f45dff7 Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Tue, 18 Jun 2024 20:14:34 +0300 Subject: [PATCH] client: add tests --- internal/client/storage.go | 30 +++++++++++++++--- internal/client/storage_test.go | 55 +++++++++++++++++++++++++++++++-- 2 files changed, 79 insertions(+), 6 deletions(-) diff --git a/internal/client/storage.go b/internal/client/storage.go index 23ab770b..df8ad2f4 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -55,12 +55,17 @@ func (s *Storage) Add(p *Persistent) (err error) { } // Find finds persistent client by string representation of the client ID, IP -// address, or MAC. +// address, or MAC. And returns it shallow copy. func (s *Storage) Find(id string) (p *Persistent, ok bool) { s.mu.Lock() defer s.mu.Unlock() - return s.index.Find(id) + p, ok = s.index.Find(id) + if ok { + return p.ShallowClone(), ok + } + + return nil, false } // FindLoose is like [Storage.Find] but it also tries to find a persistent @@ -76,12 +81,12 @@ func (s *Storage) FindLoose(ip netip.Addr, id string) (p *Persistent, ok bool) { p, ok = s.index.Find(id) if ok { - return p, ok + return p.ShallowClone(), ok } p = s.index.FindByIPWithoutZone(ip) if p != nil { - return p, true + return p.ShallowClone(), true } return nil, false @@ -133,6 +138,23 @@ func (s *Storage) Update(name string, n *Persistent) (err error) { return nil } +// RangeByName calls f for each persistent client sorted by name, unless cont is +// false. +func (s *Storage) RangeByName(f func(c *Persistent) (cont bool)) { + s.mu.Lock() + defer s.mu.Unlock() + + s.index.RangeByName(f) +} + +// CloseUpstreams closes upstream configurations of persistent clients. +func (s *Storage) CloseUpstreams() (err error) { + s.mu.Lock() + defer s.mu.Unlock() + + return s.index.CloseUpstreams() +} + // ClientRuntime returns the saved runtime client by ip. If no such client // exists, returns nil. func (s *Storage) ClientRuntime(ip netip.Addr) (rc *Runtime) { diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index fde65cc4..b00b1f69 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -294,7 +294,7 @@ func TestStorage_FindLoose(t *testing.T) { } ) - ci := newStorage( + s := newStorage( t, []*client.Persistent{ clientNoZone, @@ -326,7 +326,7 @@ func TestStorage_FindLoose(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - c, ok := ci.FindLoose(tc.ip.WithZone(""), nonExistingClientID) + c, ok := s.FindLoose(tc.ip.WithZone(""), nonExistingClientID) assert.Equal(t, tc.wantCli, c) tc.want(t, ok) }) @@ -416,3 +416,54 @@ func TestStorage_Update(t *testing.T) { }) } } + +func TestStorage_RangeByName(t *testing.T) { + sortedClients := []*client.Persistent{{ + Name: "clientA", + ClientIDs: []string{"A"}, + }, { + Name: "clientB", + ClientIDs: []string{"B"}, + }, { + Name: "clientC", + ClientIDs: []string{"C"}, + }, { + Name: "clientD", + ClientIDs: []string{"D"}, + }, { + Name: "clientE", + ClientIDs: []string{"E"}, + }} + + testCases := []struct { + name string + want []*client.Persistent + }{{ + name: "basic", + want: sortedClients, + }, { + name: "nil", + want: nil, + }, { + name: "one_element", + want: sortedClients[:1], + }, { + name: "two_elements", + want: sortedClients[:2], + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := newStorage(t, tc.want) + + var got []*client.Persistent + s.RangeByName(func(c *client.Persistent) (cont bool) { + got = append(got, c) + + return true + }) + + assert.Equal(t, tc.want, got) + }) + } +}