diff --git a/internal/aghchan/aghchan.go b/internal/aghchan/aghchan.go deleted file mode 100644 index 1da1790a..00000000 --- a/internal/aghchan/aghchan.go +++ /dev/null @@ -1,33 +0,0 @@ -// Package aghchan contains channel utilities. -package aghchan - -import ( - "fmt" - "time" -) - -// Receive returns an error if it cannot receive a value form c before timeout -// runs out. -func Receive[T any](c <-chan T, timeout time.Duration) (v T, ok bool, err error) { - var zero T - timeoutCh := time.After(timeout) - select { - case <-timeoutCh: - // TODO(a.garipov): Consider implementing [errors.Aser] for - // os.ErrTimeout. - return zero, false, fmt.Errorf("did not receive after %s", timeout) - case v, ok = <-c: - return v, ok, nil - } -} - -// MustReceive panics if it cannot receive a value form c before timeout runs -// out. -func MustReceive[T any](c <-chan T, timeout time.Duration) (v T, ok bool) { - v, ok, err := Receive(c, timeout) - if err != nil { - panic(err) - } - - return v, ok -} diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index e27b115f..76b61b03 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -1,65 +1,23 @@ package aghnet import ( - "context" "fmt" "io" "io/fs" "net/netip" "path" - "strings" "sync/atomic" "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" ) -// DefaultHostsPaths returns the slice of paths default for the operating system -// to files and directories which are containing the hosts database. The result -// is intended to be used within fs.FS so the initial slash is omitted. -func DefaultHostsPaths() (paths []string) { - return defaultHostsPaths() -} - -// MatchAddr returns the records for the IP address. -func (hc *HostsContainer) MatchAddr(ip netip.Addr) (recs []*hostsfile.Record) { - cur := hc.current.Load() - if cur == nil { - return nil - } - - return cur.addrs[ip] -} - -// MatchName returns the records for the hostname. -func (hc *HostsContainer) MatchName(name string) (recs []*hostsfile.Record) { - cur := hc.current.Load() - if cur != nil { - recs = cur.names[name] - } - - return recs -} - // hostsContainerPrefix is a prefix for logging and wrapping errors in // HostsContainer's methods. const hostsContainerPrefix = "hosts container" -// Hosts is a map of IP addresses to the records, as it primarily stored in the -// [HostsContainer]. It should not be accessed for writing since it may be read -// concurrently, users should clone it before modifying. -// -// The order of records for each address is preserved from original files, but -// the order of the addresses, being a map key, is not. -// -// TODO(e.burkov): Probably, this should be a sorted slice of records. -type Hosts map[netip.Addr][]*hostsfile.Record - // HostsContainer stores the relevant hosts database provided by the OS and // processes both A/AAAA and PTR DNS requests for those. type HostsContainer struct { @@ -67,10 +25,10 @@ type HostsContainer struct { done chan struct{} // updates is the channel for receiving updated hosts. - updates chan Hosts + updates chan *hostsfile.DefaultStorage // current is the last set of hosts parsed. - current atomic.Pointer[hostsIndex] + current atomic.Pointer[hostsfile.DefaultStorage] // fsys is the working file system to read hosts files from. fsys fs.FS @@ -111,7 +69,7 @@ func NewHostsContainer( hc = &HostsContainer{ done: make(chan struct{}, 1), - updates: make(chan Hosts, 1), + updates: make(chan *hostsfile.DefaultStorage, 1), fsys: fsys, watcher: w, patterns: patterns, @@ -152,11 +110,25 @@ func (hc *HostsContainer) Close() (err error) { return err } -// Upd returns the channel into which the updates are sent. -func (hc *HostsContainer) Upd() (updates <-chan Hosts) { +// Upd returns the channel into which the updates are sent. The updates +// themselves must not be modified. +func (hc *HostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) { return hc.updates } +// type check +var _ hostsfile.Storage = (*HostsContainer)(nil) + +// ByAddr implements the [hostsfile.Storage] interface for *HostsContainer. +func (hc *HostsContainer) ByAddr(addr netip.Addr) (names []string) { + return hc.current.Load().ByAddr(addr) +} + +// ByName implements the [hostsfile.Storage] interface for *HostsContainer. +func (hc *HostsContainer) ByName(name string) (addrs []netip.Addr) { + return hc.current.Load().ByName(name) +} + // pathsToPatterns converts paths into patterns compatible with fs.Glob. func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error) { for i, p := range paths { @@ -167,7 +139,7 @@ func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error) continue } - // Don't put a filename here since it's already added by fs.Stat. + // Don't put a filename here since it's already added by [fs.Stat]. return nil, fmt.Errorf("path at index %d: %w", i, err) } @@ -209,7 +181,7 @@ func (hc *HostsContainer) handleEvents() { } // sendUpd tries to send the parsed data to the ch. -func (hc *HostsContainer) sendUpd(recs Hosts) { +func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) { log.Debug("%s: sending upd", hostsContainerPrefix) ch := hc.updates @@ -226,67 +198,6 @@ func (hc *HostsContainer) sendUpd(recs Hosts) { } } -// hostsIndex is a [hostsfile.Set] to enumerate all the records. -type hostsIndex struct { - // addrs maps IP addresses to the records. - addrs Hosts - - // names maps hostnames to the records. - names map[string][]*hostsfile.Record -} - -// walk is a file walking function for hostsIndex. -func (idx *hostsIndex) walk(r io.Reader) (patterns []string, cont bool, err error) { - return nil, true, hostsfile.Parse(idx, r, nil) -} - -// type check -var _ hostsfile.Set = (*hostsIndex)(nil) - -// Add implements the [hostsfile.Set] interface for *hostsIndex. -func (idx *hostsIndex) Add(rec *hostsfile.Record) { - idx.addrs[rec.Addr] = append(idx.addrs[rec.Addr], rec) - for _, name := range rec.Names { - idx.names[name] = append(idx.names[name], rec) - } -} - -// type check -var _ hostsfile.HandleSet = (*hostsIndex)(nil) - -// HandleInvalid implements the [hostsfile.HandleSet] interface for *hostsIndex. -func (idx *hostsIndex) HandleInvalid(src string, _ []byte, err error) { - lineErr := &hostsfile.LineError{} - if !errors.As(err, &lineErr) { - // Must not happen if idx passed to [hostsfile.Parse]. - return - } else if errors.Is(lineErr, hostsfile.ErrEmptyLine) { - // Ignore empty lines. - return - } - - log.Info("%s: warning: parsing %q: %s", hostsContainerPrefix, src, lineErr) -} - -// equalRecs is an equality function for [*hostsfile.Record]. -func equalRecs(a, b *hostsfile.Record) (ok bool) { - return a.Addr == b.Addr && a.Source == b.Source && slices.Equal(a.Names, b.Names) -} - -// equalRecSlices is an equality function for slices of [*hostsfile.Record]. -func equalRecSlices(a, b []*hostsfile.Record) (ok bool) { return slices.EqualFunc(a, b, equalRecs) } - -// Equal returns true if indexes are equal. -func (idx *hostsIndex) Equal(other *hostsIndex) (ok bool) { - if idx == nil { - return other == nil - } else if other == nil { - return false - } - - return maps.EqualFunc(idx.addrs, other.addrs, equalRecSlices) -} - // refresh gets the data from specified files and propagates the updates if // needed. // @@ -294,63 +205,22 @@ func (idx *hostsIndex) Equal(other *hostsIndex) (ok bool) { func (hc *HostsContainer) refresh() (err error) { log.Debug("%s: refreshing", hostsContainerPrefix) - var addrLen, nameLen int - last := hc.current.Load() - if last != nil { - addrLen, nameLen = len(last.addrs), len(last.names) - } - idx := &hostsIndex{ - addrs: make(Hosts, addrLen), - names: make(map[string][]*hostsfile.Record, nameLen), - } - - _, err = aghos.FileWalker(idx.walk).Walk(hc.fsys, hc.patterns...) + // The error is always nil here since no readers passed. + strg, _ := hostsfile.NewDefaultStorage() + _, err = aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) { + // Don't wrap the error since it's already informative enough as is. + return nil, true, hostsfile.Parse(strg, r, nil) + }).Walk(hc.fsys, hc.patterns...) if err != nil { // Don't wrap the error since it's informative enough as is. return err } - // TODO(e.burkov): Serialize updates using time. - if !last.Equal(idx) { - hc.current.Store(idx) - hc.sendUpd(idx.addrs) + // TODO(e.burkov): Serialize updates using [time.Time]. + if !hc.current.Load().Equal(strg) { + hc.current.Store(strg) + hc.sendUpd(strg) } return nil } - -// type check -var _ upstream.Resolver = (*HostsContainer)(nil) - -// LookupNetIP implements the [upstream.Resolver] interface for *HostsContainer. -func (hc *HostsContainer) LookupNetIP( - ctx context.Context, - network string, - hostname string, -) (addrs []netip.Addr, err error) { - // TODO(e.burkov): Think of extracting this logic to a golibs function if - // needed anywhere else. - var isDesiredProto func(ip netip.Addr) (ok bool) - switch network { - case "ip4": - isDesiredProto = (netip.Addr).Is4 - case "ip6": - isDesiredProto = (netip.Addr).Is6 - case "ip": - isDesiredProto = func(ip netip.Addr) (ok bool) { return true } - default: - return nil, fmt.Errorf("unsupported network: %q", network) - } - - idx := hc.current.Load() - recs := idx.names[strings.ToLower(hostname)] - - addrs = make([]netip.Addr, 0, len(recs)) - for _, rec := range recs { - if isDesiredProto(rec.Addr) { - addrs = append(addrs, rec.Addr) - } - } - - return slices.Clip(addrs), nil -} diff --git a/internal/aghnet/hostscontainer_linux.go b/internal/aghnet/hostscontainer_linux.go deleted file mode 100644 index 290291e9..00000000 --- a/internal/aghnet/hostscontainer_linux.go +++ /dev/null @@ -1,17 +0,0 @@ -//go:build linux - -package aghnet - -import ( - "github.com/AdguardTeam/AdGuardHome/internal/aghos" -) - -func defaultHostsPaths() (paths []string) { - paths = []string{"etc/hosts"} - - if aghos.IsOpenWrt() { - paths = append(paths, "tmp/hosts") - } - - return paths -} diff --git a/internal/aghnet/hostscontainer_others.go b/internal/aghnet/hostscontainer_others.go deleted file mode 100644 index 61487dc4..00000000 --- a/internal/aghnet/hostscontainer_others.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !(windows || linux) - -package aghnet - -func defaultHostsPaths() (paths []string) { - return []string{"etc/hosts"} -} diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index ac30777f..813b369d 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -3,13 +3,11 @@ package aghnet_test import ( "net/netip" "path" - "path/filepath" "sync/atomic" "testing" "testing/fstest" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghchan" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" @@ -20,139 +18,6 @@ import ( "github.com/stretchr/testify/require" ) -// nl is a newline character. -const nl = "\n" - -// Variables mirroring the etc_hosts file from testdata. -var ( - addr1000 = netip.MustParseAddr("1.0.0.0") - addr1001 = netip.MustParseAddr("1.0.0.1") - addr1002 = netip.MustParseAddr("1.0.0.2") - addr1003 = netip.MustParseAddr("1.0.0.3") - addr1004 = netip.MustParseAddr("1.0.0.4") - addr1357 = netip.MustParseAddr("1.3.5.7") - addr4216 = netip.MustParseAddr("4.2.1.6") - addr7531 = netip.MustParseAddr("7.5.3.1") - - addr0 = netip.MustParseAddr("::") - addr1 = netip.MustParseAddr("::1") - addr2 = netip.MustParseAddr("::2") - addr3 = netip.MustParseAddr("::3") - addr4 = netip.MustParseAddr("::4") - addr42 = netip.MustParseAddr("::42") - addr13 = netip.MustParseAddr("::13") - addr31 = netip.MustParseAddr("::31") - - hostsSrc = "./" + filepath.Join("./testdata", "etc_hosts") - - testHosts = map[netip.Addr][]*hostsfile.Record{ - addr1000: {{ - Addr: addr1000, - Source: hostsSrc, - Names: []string{"hello", "hello.world"}, - }, { - Addr: addr1000, - Source: hostsSrc, - Names: []string{"hello.world.again"}, - }, { - Addr: addr1000, - Source: hostsSrc, - Names: []string{"hello.world"}, - }}, - addr1001: {{ - Addr: addr1001, - Source: hostsSrc, - Names: []string{"simplehost"}, - }, { - Addr: addr1001, - Source: hostsSrc, - Names: []string{"simplehost"}, - }}, - addr1002: {{ - Addr: addr1002, - Source: hostsSrc, - Names: []string{"a.whole", "lot.of", "aliases", "for.testing"}, - }}, - addr1003: {{ - Addr: addr1003, - Source: hostsSrc, - Names: []string{"*"}, - }}, - addr1004: {{ - Addr: addr1004, - Source: hostsSrc, - Names: []string{"*.com"}, - }}, - addr1357: {{ - Addr: addr1357, - Source: hostsSrc, - Names: []string{"domain4", "domain4.alias"}, - }}, - addr7531: {{ - Addr: addr7531, - Source: hostsSrc, - Names: []string{"domain4.alias", "domain4"}, - }}, - addr4216: {{ - Addr: addr4216, - Source: hostsSrc, - Names: []string{"domain", "domain.alias"}, - }}, - addr0: {{ - Addr: addr0, - Source: hostsSrc, - Names: []string{"hello", "hello.world"}, - }, { - Addr: addr0, - Source: hostsSrc, - Names: []string{"hello.world.again"}, - }, { - Addr: addr0, - Source: hostsSrc, - Names: []string{"hello.world"}, - }}, - addr1: {{ - Addr: addr1, - Source: hostsSrc, - Names: []string{"simplehost"}, - }, { - Addr: addr1, - Source: hostsSrc, - Names: []string{"simplehost"}, - }}, - addr2: {{ - Addr: addr2, - Source: hostsSrc, - Names: []string{"a.whole", "lot.of", "aliases", "for.testing"}, - }}, - addr3: {{ - Addr: addr3, - Source: hostsSrc, - Names: []string{"*"}, - }}, - addr4: {{ - Addr: addr4, - Source: hostsSrc, - Names: []string{"*.com"}, - }}, - addr42: {{ - Addr: addr42, - Source: hostsSrc, - Names: []string{"domain.alias", "domain"}, - }}, - addr13: {{ - Addr: addr13, - Source: hostsSrc, - Names: []string{"domain6", "domain6.alias"}, - }}, - addr31: {{ - Addr: addr31, - Source: hostsSrc, - Names: []string{"domain6.alias", "domain6"}, - }}, - } -) - func TestNewHostsContainer(t *testing.T) { const dirname = "dir" const filename = "file1" @@ -267,7 +132,21 @@ func TestHostsContainer_refresh(t *testing.T) { anotherIPStr := "1.2.3.4" anotherIP := netip.MustParseAddr(anotherIPStr) - testFS := fstest.MapFS{"dir/file1": &fstest.MapFile{Data: []byte(ipStr + ` hostname` + nl)}} + r1 := &hostsfile.Record{ + Addr: ip, + Source: "file1", + Names: []string{"hostname"}, + } + r2 := &hostsfile.Record{ + Addr: anotherIP, + Source: "file2", + Names: []string{"alias"}, + } + + r1Data, _ := r1.MarshalText() + r2Data, _ := r2.MarshalText() + + testFS := fstest.MapFS{"dir/file1": &fstest.MapFile{Data: r1Data}} // event is a convenient alias for an empty struct{} to emit test events. type event = struct{} @@ -289,172 +168,47 @@ func TestHostsContainer_refresh(t *testing.T) { require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, hc.Close) - checkRefresh := func(t *testing.T, want aghnet.Hosts) { - t.Helper() - - upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second) - require.True(t, ok) - - assert.Equal(t, want, upd) - } + strg, _ := hostsfile.NewDefaultStorage() + strg.Add(r1) t.Run("initial_refresh", func(t *testing.T) { - checkRefresh(t, aghnet.Hosts{ - ip: {{ - Addr: ip, - Source: "file1", - Names: []string{"hostname"}, - }}, - }) + upd, ok := testutil.RequireReceive(t, hc.Upd(), 1*time.Second) + require.True(t, ok) + + assert.True(t, strg.Equal(upd)) }) + strg.Add(r2) + t.Run("second_refresh", func(t *testing.T) { - testFS["dir/file2"] = &fstest.MapFile{Data: []byte(anotherIPStr + ` alias` + nl)} + testFS["dir/file2"] = &fstest.MapFile{Data: r2Data} eventsCh <- event{} - checkRefresh(t, aghnet.Hosts{ - ip: {{ - Addr: ip, - Source: "file1", - Names: []string{"hostname"}, - }}, - anotherIP: {{ - Addr: anotherIP, - Source: "file2", - Names: []string{"alias"}, - }}, - }) + upd, ok := testutil.RequireReceive(t, hc.Upd(), 1*time.Second) + require.True(t, ok) + + assert.True(t, strg.Equal(upd)) }) t.Run("double_refresh", func(t *testing.T) { // Make a change once. - testFS["dir/file1"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)} + testFS["dir/file1"] = &fstest.MapFile{Data: []byte(ipStr + " alias\n")} eventsCh <- event{} // Require the changes are written. - require.Eventually(t, func() bool { - ips := hc.MatchName("hostname") + current, ok := testutil.RequireReceive(t, hc.Upd(), 1*time.Second) + require.True(t, ok) - return len(ips) == 0 - }, 5*time.Second, time.Second/2) + require.Empty(t, current.ByName("hostname")) // Make a change again. - testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` hostname` + nl)} + testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + " hostname\n")} eventsCh <- event{} // Require the changes are written. - require.Eventually(t, func() bool { - ips := hc.MatchName("hostname") + current, ok = testutil.RequireReceive(t, hc.Upd(), 1*time.Second) + require.True(t, ok) - return len(ips) > 0 - }, 5*time.Second, time.Second/2) - - assert.Len(t, hc.Upd(), 1) + require.NotEmpty(t, current.ByName("hostname")) }) } - -func TestHostsContainer_MatchName(t *testing.T) { - require.NoError(t, fstest.TestFS(testdata, "etc_hosts")) - - stubWatcher := aghtest.FSWatcher{ - OnEvents: func() (e <-chan struct{}) { return nil }, - OnAdd: func(name string) (err error) { return nil }, - OnClose: func() (err error) { return nil }, - } - - testCases := []struct { - req string - name string - want []*hostsfile.Record - }{{ - req: "simplehost", - name: "simple", - want: append(testHosts[addr1001], testHosts[addr1]...), - }, { - req: "hello.world", - name: "hello_alias", - want: []*hostsfile.Record{ - testHosts[addr1000][0], - testHosts[addr1000][2], - testHosts[addr0][0], - testHosts[addr0][2], - }, - }, { - req: "hello.world.again", - name: "other_line_alias", - want: []*hostsfile.Record{ - testHosts[addr1000][1], - testHosts[addr0][1], - }, - }, { - req: "say.hello", - name: "hello_subdomain", - want: nil, - }, { - req: "say.hello.world", - name: "hello_alias_subdomain", - want: nil, - }, { - req: "for.testing", - name: "lots_of_aliases", - want: append(testHosts[addr1002], testHosts[addr2]...), - }, { - req: "nonexistent.example", - name: "non-existing", - want: nil, - }, { - req: "domain", - name: "issue_4216_4_6", - want: append(testHosts[addr4216], testHosts[addr42]...), - }, { - req: "domain4", - name: "issue_4216_4", - want: append(testHosts[addr1357], testHosts[addr7531]...), - }, { - req: "domain6", - name: "issue_4216_6", - want: append(testHosts[addr13], testHosts[addr31]...), - }} - - hc, err := aghnet.NewHostsContainer(testdata, &stubWatcher, "etc_hosts") - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, hc.Close) - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - recs := hc.MatchName(tc.req) - assert.Equal(t, tc.want, recs) - }) - } -} - -func TestHostsContainer_MatchAddr(t *testing.T) { - require.NoError(t, fstest.TestFS(testdata, "etc_hosts")) - - stubWatcher := aghtest.FSWatcher{ - OnEvents: func() (e <-chan struct{}) { return nil }, - OnAdd: func(name string) (err error) { return nil }, - OnClose: func() (err error) { return nil }, - } - - hc, err := aghnet.NewHostsContainer(testdata, &stubWatcher, "etc_hosts") - require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, hc.Close) - - testCases := []struct { - req netip.Addr - name string - want []*hostsfile.Record - }{{ - req: netip.AddrFrom4([4]byte{1, 0, 0, 1}), - name: "reverse", - want: testHosts[addr1001], - }} - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - recs := hc.MatchAddr(tc.req) - assert.Equal(t, tc.want, recs) - }) - } -} diff --git a/internal/aghnet/hostscontainer_windows.go b/internal/aghnet/hostscontainer_windows.go deleted file mode 100644 index 7bbf7ac0..00000000 --- a/internal/aghnet/hostscontainer_windows.go +++ /dev/null @@ -1,32 +0,0 @@ -//go:build windows - -package aghnet - -import ( - "os" - "path" - "path/filepath" - "strings" - - "github.com/AdguardTeam/golibs/log" - "golang.org/x/sys/windows" -) - -func defaultHostsPaths() (paths []string) { - sysDir, err := windows.GetSystemDirectory() - if err != nil { - log.Error("aghnet: getting system directory: %s", err) - - return []string{} - } - - // Split all the elements of the path to join them afterwards. This is - // needed to make the Windows-specific path string returned by - // windows.GetSystemDirectory to be compatible with fs.FS. - pathElems := strings.Split(sysDir, string(os.PathSeparator)) - if len(pathElems) > 0 && pathElems[0] == filepath.VolumeName(sysDir) { - pathElems = pathElems[1:] - } - - return []string{path.Join(append(pathElems, "drivers/etc/hosts")...)} -} diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index 2a77803b..06f65840 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -1,11 +1,9 @@ package aghnet_test import ( - "io/fs" "net" "net/netip" "net/url" - "os" "testing" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" @@ -18,9 +16,6 @@ func TestMain(m *testing.M) { testutil.DiscardLogOutput(m) } -// testdata is the filesystem containing data for testing the package. -var testdata fs.FS = os.DirFS("./testdata") - func TestParseAddrPort(t *testing.T) { const defaultPort = 1 diff --git a/internal/aghnet/testdata/etc_hosts b/internal/aghnet/testdata/etc_hosts deleted file mode 100644 index 7cd0c770..00000000 --- a/internal/aghnet/testdata/etc_hosts +++ /dev/null @@ -1,38 +0,0 @@ -# -# Test /etc/hosts file -# - -1.0.0.1 simplehost -1.0.0.0 hello hello.world - -# See https://github.com/AdguardTeam/AdGuardHome/issues/3846. -1.0.0.2 a.whole lot.of aliases for.testing - -# See https://github.com/AdguardTeam/AdGuardHome/issues/3946. -1.0.0.3 * -1.0.0.4 *.com - -# See https://github.com/AdguardTeam/AdGuardHome/issues/4079. -1.0.0.0 hello.world.again - -# Duplicates of a main host and an alias. -1.0.0.1 simplehost -1.0.0.0 hello.world - -# Same for IPv6. -::1 simplehost -:: hello hello.world -::2 a.whole lot.of aliases for.testing -::3 * -::4 *.com -:: hello.world.again -::1 simplehost -:: hello.world - -# See https://github.com/AdguardTeam/AdGuardHome/issues/4216. -4.2.1.6 domain domain.alias -::42 domain.alias domain -1.3.5.7 domain4 domain4.alias -7.5.3.1 domain4.alias domain4 -::13 domain6 domain6.alias -::31 domain6.alias domain6 diff --git a/internal/aghnet/testdata/ifaces b/internal/aghnet/testdata/ifaces deleted file mode 100644 index b98b0409..00000000 --- a/internal/aghnet/testdata/ifaces +++ /dev/null @@ -1 +0,0 @@ -iface sample_name inet static diff --git a/internal/aghnet/testdata/include-subsources b/internal/aghnet/testdata/include-subsources deleted file mode 100644 index 5391a5b3..00000000 --- a/internal/aghnet/testdata/include-subsources +++ /dev/null @@ -1,5 +0,0 @@ -# The "testdata" part is added here because the test is actually run from the -# parent directory. Real interface files usually contain only absolute paths. - -source ./testdata/ifaces -source ./testdata/* diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 8afbd3df..cdd2c240 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -142,7 +142,7 @@ type Server struct { // PTR resolving. sysResolvers SystemResolvers - // etcHosts contains the data from the system's hosts files. + // etcHosts contains the current data from the system's hosts files. etcHosts upstream.Resolver // bootstrap is the resolver for upstreams' hostnames. @@ -239,6 +239,11 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { p.Anonymizer = aghnet.NewIPMut(nil) } + var etcHosts upstream.Resolver + if p.EtcHosts != nil { + etcHosts = upstream.NewHostsResolver(p.EtcHosts) + } + s = &Server{ dnsFilter: p.DNSFilter, dhcpServer: p.DHCPServer, @@ -247,6 +252,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { privateNets: p.PrivateNets, // TODO(e.burkov): Use some case-insensitive string comparison. localDomainSuffix: strings.ToLower(localDomainSuffix), + etcHosts: etcHosts, recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum), clientIDCache: cache.New(cache.Config{ EnableLRU: true, @@ -257,9 +263,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { ServePlainDNS: true, }, } - if p.EtcHosts != nil { - s.etcHosts = p.EtcHosts - } s.sysResolvers, err = sysresolv.NewSystemResolvers(nil, defaultPlainDNSPort) if err != nil { diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index b7696cdc..4b6987bb 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -20,6 +20,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/httphdr" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" @@ -526,7 +527,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) { }, ServePlainDNS: true, }, nil) - srv.etcHosts = hc + srv.etcHosts = upstream.NewHostsResolver(hc) startDeferStop(t, srv) testCases := []struct { diff --git a/internal/filtering/dnsrewrite.go b/internal/filtering/dnsrewrite.go index 3fd6e778..19b964a2 100644 --- a/internal/filtering/dnsrewrite.go +++ b/internal/filtering/dnsrewrite.go @@ -1,7 +1,6 @@ package filtering import ( - "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" @@ -95,39 +94,3 @@ func (d *DNSFilter) processDNSResultRewrites( return res } - -// appendRewriteResultFromHost appends the rewrite result from rec to vals and -// resRules. -func appendRewriteResultFromHost( - vals []rules.RRValue, - resRules []*ResultRule, - rec *hostsfile.Record, - qtype uint16, -) (updatedVals []rules.RRValue, updatedRules []*ResultRule) { - switch qtype { - case dns.TypeA: - if !rec.Addr.Is4() { - return vals, resRules - } - - vals = append(vals, rec.Addr) - case dns.TypeAAAA: - if !rec.Addr.Is6() { - return vals, resRules - } - - vals = append(vals, rec.Addr) - case dns.TypePTR: - for _, name := range rec.Names { - vals = append(vals, name) - } - } - - recText, _ := rec.MarshalText() - resRules = append(resRules, &ResultRule{ - FilterListID: SysHostsListID, - Text: string(recText), - }) - - return vals, resRules -} diff --git a/internal/filtering/dnsrewrite_test.go b/internal/filtering/dnsrewrite_test.go index 98853c95..2793697c 100644 --- a/internal/filtering/dnsrewrite_test.go +++ b/internal/filtering/dnsrewrite_test.go @@ -220,15 +220,19 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { addrv4 := netip.MustParseAddr("1.2.3.4") addrv6 := netip.MustParseAddr("::1") addrMapped := netip.MustParseAddr("::ffff:1.2.3.4") + addrv4Dup := netip.MustParseAddr("4.3.2.1") data := fmt.Sprintf( ""+ - "%s v4.host.example\n"+ - "%s v6.host.example\n"+ - "%s mapped.host.example\n", + "%[1]s v4.host.example\n"+ + "%[2]s v6.host.example\n"+ + "%[3]s mapped.host.example\n"+ + "%[4]s v4.host.with-dup\n"+ + "%[4]s v4.host.with-dup\n", addrv4, addrv6, addrMapped, + addrv4Dup, ) files := fstest.MapFS{ @@ -343,6 +347,15 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { dtyp: dns.TypeCNAME, wantRules: nil, wantResps: nil, + }, { + name: "v4_dup", + host: "v4.host.with-dup", + dtyp: dns.TypeA, + wantRules: []*ResultRule{{ + Text: "4.3.2.1 v4.host.with-dup", + FilterListID: SysHostsListID, + }}, + wantResps: []rules.RRValue{addrv4Dup}, }} for _, tc := range testCases { diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 703f6c71..d54b0d35 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -18,7 +18,6 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" @@ -100,7 +99,7 @@ type Config struct { // system configuration files (e.g. /etc/hosts). // // TODO(e.burkov): Move it to dnsforward entirely. - EtcHosts *aghnet.HostsContainer `yaml:"-"` + EtcHosts hostsfile.Storage `yaml:"-"` // Called when the configuration is changed by HTTP request ConfigModified func() `yaml:"-"` @@ -482,15 +481,6 @@ func (d *DNSFilter) SetProtectionEnabled(status bool) { d.conf.ProtectionEnabled = status } -// EtcHostsRecords returns the hosts records for the hostname. -func (d *DNSFilter) EtcHostsRecords(hostname string) (recs []*hostsfile.Record) { - if d.conf.EtcHosts != nil { - return d.conf.EtcHosts.MatchName(hostname) - } - - return recs -} - // SetBlockingMode sets blocking mode properties. func (d *DNSFilter) SetBlockingMode(mode BlockingMode, bIPv4, bIPv6 netip.Addr) { d.confMu.Lock() @@ -637,39 +627,10 @@ func (d *DNSFilter) matchSysHosts( ) (res Result, err error) { // TODO(e.burkov): Where else is this checked? if !setts.FilteringEnabled || d.conf.EtcHosts == nil { - return res, nil - } - - var recs []*hostsfile.Record - switch qtype { - case dns.TypeA, dns.TypeAAAA: - recs = d.conf.EtcHosts.MatchName(host) - case dns.TypePTR: - var ip net.IP - ip, err = netutil.IPFromReversedAddr(host) - if err != nil { - log.Debug("filtering: failed to parse PTR record %q: %s", host, err) - - return res, nil - } - - addr, _ := netip.AddrFromSlice(ip) - recs = d.conf.EtcHosts.MatchAddr(addr) - default: - log.Debug("filtering: unsupported query type %s", dns.Type(qtype)) - } - - var vals []rules.RRValue - var resRules []*ResultRule - resRulesLen := 0 - for _, rec := range recs { - vals, resRules = appendRewriteResultFromHost(vals, resRules, rec, qtype) - if len(resRules) > resRulesLen { - resRulesLen = len(resRules) - log.Debug("filtering: matched %s in %q", host, rec.Source) - } + return Result{}, nil } + vals, rs := hostsRewrites(qtype, host, d.conf.EtcHosts) if len(vals) > 0 { res.DNSRewriteResult = &DNSRewriteResult{ Response: DNSRewriteResultResponse{ @@ -677,13 +638,64 @@ func (d *DNSFilter) matchSysHosts( }, RCode: dns.RcodeSuccess, } - res.Rules = resRules - res.Reason = RewrittenRule + res.Rules = rs + res.Reason = RewrittenAutoHosts } return res, nil } +// hostsRewrites returns values and rules matched by qt and host within hs. +func hostsRewrites( + qtype uint16, + host string, + hs hostsfile.Storage, +) (vals []rules.RRValue, rs []*ResultRule) { + var isValidProto func(netip.Addr) (ok bool) + switch qtype { + case dns.TypeA: + isValidProto = netip.Addr.Is4 + case dns.TypeAAAA: + isValidProto = netip.Addr.Is6 + case dns.TypePTR: + // TODO(e.burkov): Add some [netip]-aware alternative to [netutil]. + ip, err := netutil.IPFromReversedAddr(host) + if err != nil { + log.Debug("filtering: failed to parse PTR record %q: %s", host, err) + + return nil, nil + } + + addr, _ := netip.AddrFromSlice(ip) + + for _, name := range hs.ByAddr(addr) { + vals = append(vals, name) + rs = append(rs, &ResultRule{ + Text: fmt.Sprintf("%s %s", addr, name), + FilterListID: SysHostsListID, + }) + } + + return vals, rs + default: + log.Debug("filtering: unsupported qtype %d", qtype) + + return nil, nil + } + + for _, addr := range hs.ByName(host) { + if isValidProto(addr) { + vals = append(vals, addr) + rs = append(rs, &ResultRule{ + Text: fmt.Sprintf("%s %s", addr, host), + FilterListID: SysHostsListID, + }) + } + } + + return vals, rs +} + // processRewrites performs filtering based on the legacy rewrite records. // // Firstly, it finds CNAME rewrites for host. If the CNAME is the same as host, diff --git a/internal/home/clients.go b/internal/home/clients.go index 6d3a6d23..4d52f81b 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -20,6 +20,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" "golang.org/x/exp/maps" @@ -139,6 +140,9 @@ func (clients *clientsContainer) Init( return nil } +// handleHostsUpdates receives the updates from the hosts container and adds +// them to the clients container. It's used to be called in a separate +// goroutine. func (clients *clientsContainer) handleHostsUpdates() { for upd := range clients.etcHosts.Upd() { clients.addFromHostsFile(upd) @@ -870,21 +874,24 @@ func (clients *clientsContainer) rmHostsBySrc(src client.Source) { // addFromHostsFile fills the client-hostname pairing index from the system's // hosts files. -func (clients *clientsContainer) addFromHostsFile(hosts aghnet.Hosts) { +func (clients *clientsContainer) addFromHostsFile(hosts *hostsfile.DefaultStorage) { clients.lock.Lock() defer clients.lock.Unlock() clients.rmHostsBySrc(client.SourceHostsFile) n := 0 - for addr, rec := range hosts { + hosts.RangeNames(func(addr netip.Addr, names []string) (cont bool) { // Only the first name of the first record is considered a canonical // hostname for the IP address. // // TODO(e.burkov): Consider using all the names from all the records. - clients.addHostLocked(addr, rec[0].Names[0], client.SourceHostsFile) - n++ - } + if clients.addHostLocked(addr, names[0], client.SourceHostsFile) { + n++ + } + + return true + }) log.Debug("clients: added %d client aliases from system hosts file", n) } diff --git a/internal/home/home.go b/internal/home/home.go index ab2d83a2..18d6a961 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -35,8 +35,10 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/osutil" "github.com/AdguardTeam/golibs/stringutil" "golang.org/x/exp/slices" ) @@ -231,11 +233,12 @@ func setupHostsContainer() (err error) { return fmt.Errorf("initing hosts watcher: %w", err) } - Context.etcHosts, err = aghnet.NewHostsContainer( - aghos.RootDirFS(), - hostsWatcher, - aghnet.DefaultHostsPaths()..., - ) + paths, err := hostsfile.DefaultHostsPaths() + if err != nil { + return fmt.Errorf("getting default system hosts paths: %w", err) + } + + Context.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...) if err != nil { closeErr := hostsWatcher.Close() if errors.Is(err, aghnet.ErrNoHostsPaths) { diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 4ebcdcff..e52b0dfc 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -206,7 +206,6 @@ run_linter gocognit --over='11'\ run_linter gocognit --over='10'\ ./internal/aghalg/\ - ./internal/aghchan/\ ./internal/aghhttp/\ ./internal/aghrenameio/\ ./internal/aghtest/\ @@ -244,7 +243,6 @@ run_linter nilness ./... # TODO(a.garipov): Enable for all. run_linter fieldalignment \ ./internal/aghalg/\ - ./internal/aghchan/\ ./internal/aghhttp/\ ./internal/aghos/\ ./internal/aghrenameio/\