diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index 65c9d3c4..15875ccc 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -198,7 +198,7 @@ func (hc *HostsContainer) Close() (err error) { } // Upd returns the channel into which the updates are sent. The receivable -// map's values are guaranteed to be of type of *stringutil.Set. +// map's values are guaranteed to be of type of *HostsRecord. func (hc *HostsContainer) Upd() (updates <-chan *netutil.IPMap) { return hc.updates } @@ -290,7 +290,7 @@ func (hp *hostsParser) parseFile(r io.Reader) (patterns []string, cont bool, err continue } - hp.addPairs(ip, hosts) + hp.addRecord(ip, hosts) } return nil, true, s.Err() @@ -335,39 +335,66 @@ func (hp *hostsParser) parseLine(line string) (ip net.IP, hosts []string) { return ip, hosts } -// addPair puts the pair of ip and host to the rules builder if needed. For -// each ip the first member of hosts will become the main one. -func (hp *hostsParser) addPairs(ip net.IP, hosts []string) { +// HostsRecord represents a single hosts file record. +type HostsRecord struct { + Aliases *stringutil.Set + Canonical string +} + +// Equal returns true if all fields of rec are equal to field in other or they +// both are nil. +func (rec *HostsRecord) Equal(other *HostsRecord) (ok bool) { + if rec == nil { + return other == nil + } + + return rec.Canonical == other.Canonical && rec.Aliases.Equal(other.Aliases) +} + +// addRecord puts the record for the IP address to the rules builder if needed. +// The first host is considered to be the canonical name for the IP address. +// hosts must have at least one name. +func (hp *hostsParser) addRecord(ip net.IP, hosts []string) { + line := strings.Join(append([]string{ip.String()}, hosts...), " ") + + var rec *HostsRecord v, ok := hp.table.Get(ip) if !ok { - // This ip is added at the first time. - v = stringutil.NewSet() - hp.table.Set(ip, v) + rec = &HostsRecord{ + Aliases: stringutil.NewSet(), + } + + rec.Canonical, hosts = hosts[0], hosts[1:] + hp.addRules(ip, rec.Canonical, line) + hp.table.Set(ip, rec) + } else { + rec, ok = v.(*HostsRecord) + if !ok { + log.Error("%s: adding pairs: unexpected type %T", hostsContainerPref, v) + + return + } } - var set *stringutil.Set - set, ok = v.(*stringutil.Set) - if !ok { - log.Debug("%s: adding pairs: unexpected value type %T", hostsContainerPref, v) - - return - } - - processed := strings.Join(append([]string{ip.String()}, hosts...), " ") - for _, h := range hosts { - if set.Has(h) { + for _, host := range hosts { + if rec.Canonical == host || rec.Aliases.Has(host) { continue } - set.Add(h) + rec.Aliases.Add(host) - rule, rulePtr := hp.writeRules(h, ip) - hp.translations[rule], hp.translations[rulePtr] = processed, processed - - log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, h) + hp.addRules(ip, host, line) } } +// addRules adds rules and rule translations for the line. +func (hp *hostsParser) addRules(ip net.IP, host, line string) { + rule, rulePtr := hp.writeRules(host, ip) + hp.translations[rule], hp.translations[rulePtr] = line, line + + log.Debug("%s: added ip-host pair %q-%q", hostsContainerPref, ip, host) +} + // writeRules writes the actual rule for the qtype and the PTR for the host-ip // pair into internal builders. func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) { @@ -417,6 +444,7 @@ func (hp *hostsParser) writeRules(host string, ip net.IP) (rule, rulePtr string) } // equalSet returns true if the internal hosts table just parsed equals target. +// target's values must be of type *HostsRecord. func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) { if target == nil { // hp.table shouldn't appear nil since it's initialized on each refresh. @@ -427,22 +455,35 @@ func (hp *hostsParser) equalSet(target *netutil.IPMap) (ok bool) { return false } - hp.table.Range(func(ip net.IP, b interface{}) (cont bool) { - // ok is set to true if the target doesn't contain ip or if the - // appropriate hosts set isn't equal to the checked one. - if a, hasIP := target.Get(ip); !hasIP { - ok = true - } else if hosts, aok := a.(*stringutil.Set); aok { - ok = !hosts.Equal(b.(*stringutil.Set)) + hp.table.Range(func(ip net.IP, recVal interface{}) (cont bool) { + var targetVal interface{} + targetVal, ok = target.Get(ip) + if !ok { + return false } - // Continue only if maps has no discrepancies. - return !ok + var rec *HostsRecord + rec, ok = recVal.(*HostsRecord) + if !ok { + log.Error("%s: comparing: unexpected type %T", hostsContainerPref, recVal) + + return false + } + + var targetRec *HostsRecord + targetRec, ok = targetVal.(*HostsRecord) + if !ok { + log.Error("%s: comparing: target: unexpected type %T", hostsContainerPref, targetVal) + + return false + } + + ok = rec.Equal(targetRec) + + return ok }) - // Return true if every value from the IP map has no discrepancies with the - // appropriate one from the target. - return !ok + return ok } // sendUpd tries to send the parsed data to the ch. diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index 42202f43..019c713e 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter" @@ -159,31 +160,47 @@ func TestHostsContainer_refresh(t *testing.T) { require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, hc.Close) - checkRefresh := func(t *testing.T, wantHosts *stringutil.Set) { - upd, ok := <-hc.Upd() - require.True(t, ok) - require.NotNil(t, upd) + checkRefresh := func(t *testing.T, want *HostsRecord) { + t.Helper() + + var ok bool + var upd *netutil.IPMap + select { + case upd, ok = <-hc.Upd(): + require.True(t, ok) + require.NotNil(t, upd) + case <-time.After(1 * time.Second): + t.Fatal("did not receive after 1s") + } assert.Equal(t, 1, upd.Len()) v, ok := upd.Get(ip) require.True(t, ok) - var set *stringutil.Set - set, ok = v.(*stringutil.Set) - require.True(t, ok) + require.IsType(t, (*HostsRecord)(nil), v) - assert.True(t, set.Equal(wantHosts)) + rec, _ := v.(*HostsRecord) + require.NotNil(t, rec) + + assert.Truef(t, rec.Equal(want), "%+v != %+v", rec, want) } t.Run("initial_refresh", func(t *testing.T) { - checkRefresh(t, stringutil.NewSet("hostname")) + checkRefresh(t, &HostsRecord{ + Aliases: stringutil.NewSet(), + Canonical: "hostname", + }) }) t.Run("second_refresh", func(t *testing.T) { testFS["dir/file2"] = &fstest.MapFile{Data: []byte(ipStr + ` alias` + nl)} eventsCh <- event{} - checkRefresh(t, stringutil.NewSet("hostname", "alias")) + + checkRefresh(t, &HostsRecord{ + Aliases: stringutil.NewSet("alias"), + Canonical: "hostname", + }) }) t.Run("double_refresh", func(t *testing.T) { @@ -363,10 +380,15 @@ func TestHostsContainer(t *testing.T) { require.NoError(t, fstest.TestFS(testdata, "etc_hosts")) testCases := []struct { - want []*rules.DNSRewrite - name string req *urlfilter.DNSRequest + name string + want []*rules.DNSRewrite }{{ + req: &urlfilter.DNSRequest{ + Hostname: "simplehost", + DNSType: dns.TypeA, + }, + name: "simple", want: []*rules.DNSRewrite{{ RCode: dns.RcodeSuccess, Value: net.IPv4(1, 0, 0, 1), @@ -376,27 +398,12 @@ func TestHostsContainer(t *testing.T) { Value: net.ParseIP("::1"), RRType: dns.TypeAAAA, }}, - name: "simple", - req: &urlfilter.DNSRequest{ - Hostname: "simplehost", - DNSType: dns.TypeA, - }, }, { - want: []*rules.DNSRewrite{{ - RCode: dns.RcodeSuccess, - Value: net.IPv4(1, 0, 0, 0), - RRType: dns.TypeA, - }, { - RCode: dns.RcodeSuccess, - Value: net.ParseIP("::"), - RRType: dns.TypeAAAA, - }}, - name: "hello_alias", req: &urlfilter.DNSRequest{ Hostname: "hello.world", DNSType: dns.TypeA, }, - }, { + name: "hello_alias", want: []*rules.DNSRewrite{{ RCode: dns.RcodeSuccess, Value: net.IPv4(1, 0, 0, 0), @@ -406,26 +413,41 @@ func TestHostsContainer(t *testing.T) { Value: net.ParseIP("::"), RRType: dns.TypeAAAA, }}, - name: "other_line_alias", + }, { req: &urlfilter.DNSRequest{ Hostname: "hello.world.again", DNSType: dns.TypeA, }, + name: "other_line_alias", + want: []*rules.DNSRewrite{{ + RCode: dns.RcodeSuccess, + Value: net.IPv4(1, 0, 0, 0), + RRType: dns.TypeA, + }, { + RCode: dns.RcodeSuccess, + Value: net.ParseIP("::"), + RRType: dns.TypeAAAA, + }}, }, { - want: []*rules.DNSRewrite{}, - name: "hello_subdomain", req: &urlfilter.DNSRequest{ Hostname: "say.hello", DNSType: dns.TypeA, }, - }, { + name: "hello_subdomain", want: []*rules.DNSRewrite{}, - name: "hello_alias_subdomain", + }, { req: &urlfilter.DNSRequest{ Hostname: "say.hello.world", DNSType: dns.TypeA, }, + name: "hello_alias_subdomain", + want: []*rules.DNSRewrite{}, }, { + req: &urlfilter.DNSRequest{ + Hostname: "for.testing", + DNSType: dns.TypeA, + }, + name: "lots_of_aliases", want: []*rules.DNSRewrite{{ RCode: dns.RcodeSuccess, RRType: dns.TypeA, @@ -435,37 +457,37 @@ func TestHostsContainer(t *testing.T) { RRType: dns.TypeAAAA, Value: net.ParseIP("::2"), }}, - name: "lots_of_aliases", - req: &urlfilter.DNSRequest{ - Hostname: "for.testing", - DNSType: dns.TypeA, - }, }, { + req: &urlfilter.DNSRequest{ + Hostname: "1.0.0.1.in-addr.arpa", + DNSType: dns.TypePTR, + }, + name: "reverse", want: []*rules.DNSRewrite{{ RCode: dns.RcodeSuccess, RRType: dns.TypePTR, Value: "simplehost.", }}, - name: "reverse", - req: &urlfilter.DNSRequest{ - Hostname: "1.0.0.1.in-addr.arpa", - DNSType: dns.TypePTR, - }, }, { - want: []*rules.DNSRewrite{}, - name: "non-existing", req: &urlfilter.DNSRequest{ Hostname: "nonexisting", DNSType: dns.TypeA, }, + name: "non-existing", + want: []*rules.DNSRewrite{}, }, { - want: nil, - name: "bad_type", req: &urlfilter.DNSRequest{ Hostname: "1.0.0.1.in-addr.arpa", DNSType: dns.TypeSRV, }, + name: "bad_type", + want: nil, }, { + req: &urlfilter.DNSRequest{ + Hostname: "domain", + DNSType: dns.TypeA, + }, + name: "issue_4216_4_6", want: []*rules.DNSRewrite{{ RCode: dns.RcodeSuccess, RRType: dns.TypeA, @@ -475,12 +497,12 @@ func TestHostsContainer(t *testing.T) { RRType: dns.TypeAAAA, Value: net.ParseIP("::42"), }}, - name: "issue_4216_4_6", + }, { req: &urlfilter.DNSRequest{ - Hostname: "domain", + Hostname: "domain4", DNSType: dns.TypeA, }, - }, { + name: "issue_4216_4", want: []*rules.DNSRewrite{{ RCode: dns.RcodeSuccess, RRType: dns.TypeA, @@ -490,12 +512,12 @@ func TestHostsContainer(t *testing.T) { RRType: dns.TypeA, Value: net.IPv4(1, 3, 5, 7), }}, - name: "issue_4216_4", - req: &urlfilter.DNSRequest{ - Hostname: "domain4", - DNSType: dns.TypeA, - }, }, { + req: &urlfilter.DNSRequest{ + Hostname: "domain6", + DNSType: dns.TypeAAAA, + }, + name: "issue_4216_6", want: []*rules.DNSRewrite{{ RCode: dns.RcodeSuccess, RRType: dns.TypeAAAA, @@ -505,11 +527,6 @@ func TestHostsContainer(t *testing.T) { RRType: dns.TypeAAAA, Value: net.ParseIP("::31"), }}, - name: "issue_4216_6", - req: &urlfilter.DNSRequest{ - Hostname: "domain6", - DNSType: dns.TypeAAAA, - }, }} stubWatcher := aghtest.FSWatcher{ diff --git a/internal/aghnet/systemresolvers.go b/internal/aghnet/systemresolvers.go index 13fbeb32..5ca8e9be 100644 --- a/internal/aghnet/systemresolvers.go +++ b/internal/aghnet/systemresolvers.go @@ -19,7 +19,7 @@ type SystemResolvers interface { } // NewSystemResolvers returns a SystemResolvers with the cache refresh rate -// defined by refreshIvl. It disables auto-resfreshing if refreshIvl is 0. If +// defined by refreshIvl. It disables auto-refreshing if refreshIvl is 0. If // nil is passed for hostGenFunc, the default generator will be used. func NewSystemResolvers( hostGenFunc HostGenFunc, diff --git a/internal/home/clients.go b/internal/home/clients.go index 4ba6b884..74545a67 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -743,8 +743,7 @@ func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSourc // addHostLocked adds a new IP-hostname pairing. For internal use only. func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) { - var rc *RuntimeClient - rc, ok = clients.findRuntimeClientLocked(ip) + rc, ok := clients.findRuntimeClientLocked(ip) if ok { if rc.Source > src { return false @@ -799,25 +798,20 @@ func (clients *clientsContainer) addFromHostsFile(hosts *netutil.IPMap) { n := 0 hosts.Range(func(ip net.IP, v interface{}) (cont bool) { - hosts, ok := v.(*stringutil.Set) + rec, ok := v.(*aghnet.HostsRecord) if !ok { log.Error("dns: bad type %T in ipToRC for %s", v, ip) return true } - hosts.Range(func(name string) (cont bool) { - if clients.addHostLocked(ip, name, ClientSourceHostsFile) { - n++ - } - - return true - }) + clients.addHostLocked(ip, rec.Canonical, ClientSourceHostsFile) + n++ return true }) - log.Debug("clients: added %d client aliases from system hosts-file", n) + log.Debug("clients: added %d client aliases from system hosts file", n) } // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a