diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index 67aad623..a0b10520 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -43,6 +43,9 @@ type HostsContainer struct { // engine serves rulesStrg. engine *urlfilter.DNSEngine + // done is the channel to sign closing the container. + done chan struct{} + // updates is the channel for receiving updated hosts. updates chan *netutil.IPMap // last is the set of hosts that was cached within last detected change. @@ -57,12 +60,12 @@ type HostsContainer struct { patterns []string } -// errNoPaths is returned when there are no paths to watch passed to the -// HostsContainer. -const errNoPaths errors.Error = "hosts paths are empty" +// ErrNoHostsPaths is returned when there are no valid paths to watch passed to +// the HostsContainer. +const ErrNoHostsPaths errors.Error = "no valid paths to hosts files provided" // NewHostsContainer creates a container of hosts, that watches the paths with -// w. paths shouldn't be empty and each of them should locate either a file or +// w. paths shouldn't be empty and each of paths should locate either a file or // a directory in fsys. fsys and w must be non-nil. func NewHostsContainer( fsys fs.FS, @@ -72,16 +75,20 @@ func NewHostsContainer( defer func() { err = errors.Annotate(err, "%s: %w", hostsContainerPref) }() if len(paths) == 0 { - return nil, errNoPaths + return nil, ErrNoHostsPaths } - patterns, err := pathsToPatterns(fsys, paths) + var patterns []string + patterns, err = pathsToPatterns(fsys, paths) if err != nil { return nil, err + } else if len(patterns) == 0 { + return nil, ErrNoHostsPaths } hc = &HostsContainer{ engLock: &sync.RWMutex{}, + done: make(chan struct{}, 1), updates: make(chan *netutil.IPMap, 1), last: &netutil.IPMap{}, fsys: fsys, @@ -97,16 +104,13 @@ func NewHostsContainer( } for _, p := range paths { - err = w.Add(p) - if err == nil { - continue - } else if errors.Is(err, fs.ErrNotExist) { + if err = w.Add(p); err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return nil, fmt.Errorf("adding path: %w", err) + } + log.Debug("%s: file %q expected to exist but doesn't", hostsContainerPref, p) - - continue } - - return nil, fmt.Errorf("adding path: %w", err) } go hc.handleEvents() @@ -130,16 +134,17 @@ func (hc *HostsContainer) MatchRequest( hc.engLock.RLock() defer hc.engLock.RUnlock() - res, ok = hc.engine.MatchRequest(req) - - return res, ok + return hc.engine.MatchRequest(req) } -// Close implements the io.Closer interface for *HostsContainer. +// Close implements the io.Closer interface for *HostsContainer. Close must +// only be called once. The returned err is always nil. func (hc *HostsContainer) Close() (err error) { log.Debug("%s: closing", hostsContainerPref) - return errors.Annotate(hc.w.Close(), "%s: closing: %w", hostsContainerPref) + close(hc.done) + + return nil } // Upd returns the channel into which the updates are sent. The receivable @@ -152,8 +157,14 @@ func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) { func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error) { for i, p := range paths { var fi fs.FileInfo - if fi, err = fs.Stat(fsys, p); err != nil { - return nil, fmt.Errorf("%q at index %d: %w", p, i, err) + fi, err = fs.Stat(fsys, p) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + continue + } + + // Don't put a filename here since it's already added by fs.Stat. + return nil, fmt.Errorf("path at index %d: %w", i, err) } if fi.IsDir() { @@ -173,9 +184,21 @@ func (hc *HostsContainer) handleEvents() { defer close(hc.updates) - for range hc.w.Events() { - if err := hc.refresh(); err != nil { - log.Error("%s: %s", hostsContainerPref, err) + ok, eventsCh := true, hc.w.Events() + for ok { + select { + case _, ok = <-eventsCh: + if !ok { + log.Debug("%s: watcher closed the events channel", hostsContainerPref) + + continue + } + + if err := hc.refresh(); err != nil { + log.Error("%s: %s", hostsContainerPref, err) + } + case _, ok = <-hc.done: + // Go on. } } } @@ -373,6 +396,8 @@ func (hp *hostsParser) newStrg() (s *filterlist.RuleStorage, err error) { // refresh gets the data from specified files and propagates the updates if // needed. +// +// TODO(e.burkov): Accept a parameter to specify the files to refresh. func (hc *HostsContainer) refresh() (err error) { log.Debug("%s: refreshing", hostsContainerPref) diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index b30790a3..5a08b24f 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -24,14 +24,6 @@ const ( sp = " " ) -const closeCalled errors.Error = "close method called" - -// fsWatcherOnCloseStub is a stub implementation of the Close method of -// aghos.FSWatcher. -func fsWatcherOnCloseStub() (err error) { - return closeCalled -} - func TestNewHostsContainer(t *testing.T) { const dirname = "dir" const filename = "file1" @@ -43,30 +35,25 @@ func TestNewHostsContainer(t *testing.T) { } testCases := []struct { - name string - paths []string - wantErr error - wantPatterns []string + wantErr error + name string + paths []string }{{ - name: "one_file", - paths: []string{p}, - wantErr: nil, - wantPatterns: []string{p}, + wantErr: nil, + name: "one_file", + paths: []string{p}, }, { - name: "no_files", - paths: []string{}, - wantErr: errNoPaths, - wantPatterns: nil, + wantErr: ErrNoHostsPaths, + name: "no_files", + paths: []string{}, }, { - name: "non-existent_file", - paths: []string{path.Join(dirname, filename+"2")}, - wantErr: fs.ErrNotExist, - wantPatterns: nil, + wantErr: ErrNoHostsPaths, + name: "non-existent_file", + paths: []string{path.Join(dirname, filename+"2")}, }, { - name: "whole_dir", - paths: []string{dirname}, - wantErr: nil, - wantPatterns: []string{path.Join(dirname, "*")}, + wantErr: nil, + name: "whole_dir", + paths: []string{dirname}, }} for _, tc := range testCases { @@ -88,7 +75,7 @@ func TestNewHostsContainer(t *testing.T) { hc, err := NewHostsContainer(testFS, &aghtest.FSWatcher{ OnEvents: onEvents, OnAdd: onAdd, - OnClose: fsWatcherOnCloseStub, + OnClose: func() (err error) { panic("not implemented") }, }, tc.paths...) if tc.wantErr != nil { require.ErrorIs(t, err, tc.wantErr) @@ -99,13 +86,8 @@ func TestNewHostsContainer(t *testing.T) { } require.NoError(t, err) - t.Cleanup(func() { - require.ErrorIs(t, hc.Close(), closeCalled) - }) - require.NotNil(t, hc) - assert.Equal(t, tc.wantPatterns, hc.patterns) assert.NotNil(t, <-hc.Upd()) eventsCh <- struct{}{} @@ -178,12 +160,11 @@ func TestHostsContainer_Refresh(t *testing.T) { return nil }, - OnClose: fsWatcherOnCloseStub, + OnClose: func() (err error) { panic("not implemented") }, } hc, err := NewHostsContainer(testFS, w, dirname) require.NoError(t, err) - t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) }) checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) { upd, ok := <-hc.Upd() @@ -257,10 +238,9 @@ func TestHostsContainer_MatchRequest(t *testing.T) { return nil }, - OnClose: fsWatcherOnCloseStub, + OnClose: func() (err error) { panic("not implemented") }, }, filename) require.NoError(t, err) - t.Cleanup(func() { require.ErrorIs(t, hc.Close(), closeCalled) }) testCase := []struct { name string @@ -398,7 +378,7 @@ func TestHostsContainer_PathsToPatterns(t *testing.T) { paths: []string{fp1, path.Join(dir0, dir1)}, }, { name: "non-existing", - wantErr: fs.ErrNotExist, + wantErr: nil, want: nil, paths: []string{path.Join(dir0, "file_3")}, }} @@ -417,6 +397,19 @@ func TestHostsContainer_PathsToPatterns(t *testing.T) { assert.Equal(t, tc.want, patterns) }) } + + t.Run("bad_file", func(t *testing.T) { + const errStat errors.Error = "bad file" + + badFS := &aghtest.StatFS{ + OnStat: func(name string) (fs.FileInfo, error) { + return nil, errStat + }, + } + + _, err := pathsToPatterns(badFS, []string{""}) + assert.ErrorIs(t, err, errStat) + }) } func TestUniqueRules_AddPair(t *testing.T) { diff --git a/internal/aghos/fswatcher.go b/internal/aghos/fswatcher.go index a113610f..8f5d1d60 100644 --- a/internal/aghos/fswatcher.go +++ b/internal/aghos/fswatcher.go @@ -82,6 +82,8 @@ func (w *osWatcher) Events() (e <-chan event) { } // Add implements the FSWatcher interface for *osWatcher. +// +// TODO(e.burkov): Make it accept non-existing files to detect it's creating. func (w *osWatcher) Add(name string) (err error) { defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }() diff --git a/internal/aghtest/testfs.go b/internal/aghtest/testfs.go new file mode 100644 index 00000000..88203fec --- /dev/null +++ b/internal/aghtest/testfs.go @@ -0,0 +1,46 @@ +package aghtest + +import "io/fs" + +// type check +var _ fs.FS = &FS{} + +// FS is a mock fs.FS implementation to use in tests. +type FS struct { + OnOpen func(name string) (fs.File, error) +} + +// Open implements the fs.FS interface for *FS. +func (fsys *FS) Open(name string) (fs.File, error) { + return fsys.OnOpen(name) +} + +// type check +var _ fs.StatFS = &StatFS{} + +// StatFS is a mock fs.StatFS implementation to use in tests. +type StatFS struct { + // FS is embedded here to avoid implementing all it's methods. + FS + OnStat func(name string) (fs.FileInfo, error) +} + +// Stat implements the fs.StatFS interface for *StatFS. +func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) { + return fsys.OnStat(name) +} + +// type check +var _ fs.GlobFS = &GlobFS{} + +// GlobFS is a mock fs.GlobFS implementation to use in tests. +type GlobFS struct { + // FS is embedded here to avoid implementing all it's methods. + FS + OnGlob func(pattern string) ([]string, error) +} + +// Glob implements the fs.GlobFS interface for *GlobFS. +func (fsys *GlobFS) Glob(pattern string) ([]string, error) { + return fsys.OnGlob(pattern) +} diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 701d1c5d..6f0ccc70 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -1077,8 +1077,6 @@ func TestPTRResponseFromHosts(t *testing.T) { `)}, } - const closeCalled errors.Error = "close method called" - var eventsCalledCounter uint32 hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{ OnEvents: func() (e <-chan struct{}) { @@ -1091,13 +1089,11 @@ func TestPTRResponseFromHosts(t *testing.T) { return nil }, - OnClose: func() (err error) { return closeCalled }, + OnClose: func() (err error) { panic("not implemented") }, }, hostsFilename) require.NoError(t, err) t.Cleanup(func() { assert.Equal(t, uint32(1), atomic.LoadUint32(&eventsCalledCounter)) - - require.ErrorIs(t, hc.Close(), closeCalled) }) flt := filtering.New(&filtering.Config{ diff --git a/internal/home/home.go b/internal/home/home.go index c5a6a5e6..1148ed6a 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -59,7 +59,10 @@ type homeContext struct { // etcHosts is an IP-hostname pairs set taken from system configuration // (e.g. /etc/hosts) files. etcHosts *aghnet.HostsContainer - updater *updater.Updater + // hostsWatcher is the watcher to detect changes in the hosts files. + hostsWatcher aghos.FSWatcher + + updater *updater.Updater subnetDetector *aghnet.SubnetDetector @@ -232,6 +235,33 @@ func configureOS(conf *configuration) (err error) { return nil } +// setupHostsContainer initializes the structures to keep up-to-date the hosts +// provided by the OS. +func setupHostsContainer() (err error) { + Context.hostsWatcher, err = aghos.NewOSWritesWatcher() + if err != nil { + return fmt.Errorf("initing hosts watcher: %w", err) + } + + Context.etcHosts, err = aghnet.NewHostsContainer( + aghos.RootDirFS(), + Context.hostsWatcher, + aghnet.DefaultHostsPaths()..., + ) + if err != nil { + cerr := Context.hostsWatcher.Close() + if errors.Is(err, aghnet.ErrNoHostsPaths) && cerr == nil { + log.Info("warning: initing hosts container: %s", err) + + return nil + } + + return errors.WithDeferred(fmt.Errorf("initing hosts container: %w", err), cerr) + } + + return nil +} + func setupConfig(args options) (err error) { config.DHCP.WorkDir = Context.workDir config.DHCP.HTTPRegister = httpRegister @@ -259,19 +289,8 @@ func setupConfig(args options) (err error) { }) if !args.noEtcHosts { - var osWritesWatcher aghos.FSWatcher - osWritesWatcher, err = aghos.NewOSWritesWatcher() - if err != nil { - return fmt.Errorf("initing os watcher: %w", err) - } - - Context.etcHosts, err = aghnet.NewHostsContainer( - aghos.RootDirFS(), - osWritesWatcher, - aghnet.DefaultHostsPaths()..., - ) - if err != nil { - return fmt.Errorf("initing hosts container: %w", err) + if err = setupHostsContainer(); err != nil { + return err } } Context.clients.Init(config.Clients, Context.dhcpServer, Context.etcHosts) @@ -661,8 +680,15 @@ func cleanup(ctx context.Context) { } if Context.etcHosts != nil { + // Currently Context.hostsWatcher is only used in Context.etcHosts and + // needs closing only in case of the successful initialization of + // Context.etcHosts. + if err = Context.hostsWatcher.Close(); err != nil { + log.Error("closing hosts watcher: %s", err) + } + if err = Context.etcHosts.Close(); err != nil { - log.Error("stopping hosts container: %s", err) + log.Error("closing hosts container: %s", err) } }