diff --git a/CHANGELOG.md b/CHANGELOG.md index 5607b6e7..2025e26f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to ### Added +- Blocking access using client IDs ([#2624], [#3162]). - `source` directives support in `/etc/network/interfaces` on Linux ([#3257]). - RFC 9000 support in DNS-over-QUIC. - Completely disabling statistics by setting the statistics interval to zero @@ -80,9 +81,11 @@ released by then. [#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439 [#2441]: https://github.com/AdguardTeam/AdGuardHome/issues/2441 [#2443]: https://github.com/AdguardTeam/AdGuardHome/issues/2443 +[#2624]: https://github.com/AdguardTeam/AdGuardHome/issues/2624 [#2763]: https://github.com/AdguardTeam/AdGuardHome/issues/2763 [#3013]: https://github.com/AdguardTeam/AdGuardHome/issues/3013 [#3136]: https://github.com/AdguardTeam/AdGuardHome/issues/3136 +[#3162]: https://github.com/AdguardTeam/AdGuardHome/issues/3162 [#3166]: https://github.com/AdguardTeam/AdGuardHome/issues/3166 [#3172]: https://github.com/AdguardTeam/AdGuardHome/issues/3172 [#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184 diff --git a/HACKING.md b/HACKING.md index bde2ab90..4b46ebc0 100644 --- a/HACKING.md +++ b/HACKING.md @@ -159,8 +159,10 @@ attributes to make it work in Markdown renderers that strip "id". --> * Minimize scope of variables as much as possible. - * No shadowing, since it can often lead to subtle bugs, especially with - errors. + * No name shadowing, including of predeclared identifiers, since it can often + lead to subtle bugs, especially with errors. This rule does not apply to + struct fields, since they are always used together with the name of the + struct value, so there isn't any confusion. * Prefer constants to variables where possible. Avoid global variables. Use [constant errors] instead of `errors.New`. diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index 48cfe096..c8accead 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -426,9 +426,9 @@ "access_title": "Access settings", "access_desc": "Here you can configure access rules for the AdGuard Home DNS server.", "access_allowed_title": "Allowed clients", - "access_allowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will accept requests from these IP addresses only.", + "access_allowed_desc": "A list of CIDRs, IP addresses, or client IDs. If configured, AdGuard Home will accept requests only from these clients.", "access_disallowed_title": "Disallowed clients", - "access_disallowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will drop requests from these IP addresses.", + "access_disallowed_desc": "A list of CIDRs, IP addresses, or client IDs. If configured, AdGuard Home will drop requests from these clients. If allowed clients are configured, this field is ignored.", "access_blocked_title": "Disallowed domains", "access_blocked_desc": "Not to be confused with filters. AdGuard Home drops DNS queries matching these domains, and these queries don't even appear in the query log. You can specify exact domain names, wildcards, or URL filter rules, e.g. \"example.org\", \"*.example.org\", or \"||example.org^\" correspondingly.", "access_settings_saved": "Access settings successfully saved", diff --git a/client/src/components/Dashboard/Clients.js b/client/src/components/Dashboard/Clients.js index 46edc46f..cc11b915 100644 --- a/client/src/components/Dashboard/Clients.js +++ b/client/src/components/Dashboard/Clients.js @@ -9,7 +9,7 @@ import Card from '../ui/Card'; import Cell from '../ui/Cell'; import { getPercent, sortIp } from '../../helpers/helpers'; -import { BLOCK_ACTIONS, R_CLIENT_ID, STATUS_COLORS } from '../../helpers/constants'; +import { BLOCK_ACTIONS, STATUS_COLORS } from '../../helpers/constants'; import { toggleClientBlock } from '../../actions/access'; import { renderFormattedClientCell } from '../../helpers/renderFormattedClientCell'; import { getStats } from '../../actions/stats'; @@ -35,10 +35,6 @@ const CountCell = (row) => { }; const renderBlockingButton = (ip, disallowed, disallowed_rule) => { - if (R_CLIENT_ID.test(ip)) { - return null; - } - const dispatch = useDispatch(); const { t } = useTranslation(); const processingSet = useSelector((state) => state.access.processingSet); diff --git a/go.mod b/go.mod index e19a7896..2336f16a 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome go 1.16 require ( - github.com/AdguardTeam/dnsproxy v0.37.7 + github.com/AdguardTeam/dnsproxy v0.38.0 github.com/AdguardTeam/golibs v0.8.0 github.com/AdguardTeam/urlfilter v0.14.6 github.com/NYTimes/gziphandler v1.1.1 diff --git a/go.sum b/go.sum index 0d1b378a..66a44363 100644 --- a/go.sum +++ b/go.sum @@ -9,8 +9,8 @@ dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf h1:gc042VRSIRSUzZ+Px6xQCRWNJZTaPkomisDfUZmoFNk= github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI= -github.com/AdguardTeam/dnsproxy v0.37.7 h1:yp0vEVYobf/1l8iY7es9yMqguw8BUEeC74OGA4G2v2A= -github.com/AdguardTeam/dnsproxy v0.37.7/go.mod h1:xMfevPAwpK1ULoLO0CARg/OiUsPH92kfyliXhPTc62M= +github.com/AdguardTeam/dnsproxy v0.38.0 h1:7GyyNJOieIVOgdnhu47exqWjHPQro7wQhqzvQjaZt6M= +github.com/AdguardTeam/dnsproxy v0.38.0/go.mod h1:xMfevPAwpK1ULoLO0CARg/OiUsPH92kfyliXhPTc62M= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= diff --git a/internal/aghnet/etchostscontainer.go b/internal/aghnet/etchostscontainer.go index 9b6e07ef..f95896ce 100644 --- a/internal/aghnet/etchostscontainer.go +++ b/internal/aghnet/etchostscontainer.go @@ -27,10 +27,9 @@ type EtcHostsContainer struct { lock sync.RWMutex // table is the host-to-IPs map. table map[string][]net.IP - // tableReverse is the IP-to-hosts map. - // - // TODO(a.garipov): Make better use of newtypes. Perhaps a custom map. - tableReverse map[string][]string + // tableReverse is the IP-to-hosts map. The type of the values in the + // map is []string. + tableReverse *IPMap hostsFn string // path to the main hosts-file hostsDirs []string // paths to OS-specific directories with hosts-files @@ -80,7 +79,7 @@ func (ehc *EtcHostsContainer) Init(hostsFn string) { var err error ehc.watcher, err = fsnotify.NewWatcher() if err != nil { - log.Error("etchostscontainer: %s", err) + log.Error("etchosts: %s", err) } } @@ -141,7 +140,7 @@ func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP { copy(ipsCopy, ips) } - log.Debug("etchostscontainer: answer: %s -> %v", host, ipsCopy) + log.Debug("etchosts: answer: %s -> %v", host, ipsCopy) return ipsCopy } @@ -151,38 +150,40 @@ func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts [ return nil } - ipReal := UnreverseAddr(addr) - if ipReal == nil { + ip := UnreverseAddr(addr) + if ip == nil { return nil } - ipStr := ipReal.String() - ehc.lock.RLock() defer ehc.lock.RUnlock() - hosts = ehc.tableReverse[ipStr] - - if len(hosts) == 0 { - return nil // not found + v, ok := ehc.tableReverse.Get(ip) + if !ok { + return nil } - log.Debug("etchostscontainer: reverse-lookup: %s -> %s", addr, hosts) + hosts, ok = v.([]string) + if !ok { + log.Error("etchosts: bad type %T in tableReverse for %s", v, ip) + + return nil + } else if len(hosts) == 0 { + return nil + } + + log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts) return hosts } -// List returns an IP-to-hostnames table. It is safe for concurrent use. -func (ehc *EtcHostsContainer) List() (ipToHosts map[string][]string) { +// List returns an IP-to-hostnames table. The type of the values in the map is +// []string. It is safe for concurrent use. +func (ehc *EtcHostsContainer) List() (ipToHosts *IPMap) { ehc.lock.RLock() defer ehc.lock.RUnlock() - ipToHosts = make(map[string][]string, len(ehc.tableReverse)) - for k, v := range ehc.tableReverse { - ipToHosts[k] = v - } - - return ipToHosts + return ehc.tableReverse.ShallowClone() } // update table @@ -205,29 +206,31 @@ func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string ok = true } if ok { - log.Debug("etchostscontainer: added %s -> %s", ipAddr, host) + log.Debug("etchosts: added %s -> %s", ipAddr, host) } } // updateTableRev updates the reverse address table. -func (ehc *EtcHostsContainer) updateTableRev(tableRev map[string][]string, newHost string, ipAddr net.IP) { - ipStr := ipAddr.String() - hosts, ok := tableRev[ipStr] +func (ehc *EtcHostsContainer) updateTableRev(tableRev *IPMap, newHost string, ip net.IP) { + v, ok := tableRev.Get(ip) if !ok { - tableRev[ipStr] = []string{newHost} - log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost) + tableRev.Set(ip, []string{newHost}) + log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost) return } + hosts, _ := v.([]string) for _, host := range hosts { if host == newHost { return } } - tableRev[ipStr] = append(tableRev[ipStr], newHost) - log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost) + hosts = append(hosts, newHost) + tableRev.Set(ip, hosts) + + log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost) } // parseHostsLine parses hosts from the fields. @@ -255,12 +258,12 @@ func parseHostsLine(fields []string) (hosts []string) { // line for one IP are supported. func (ehc *EtcHostsContainer) load( table map[string][]net.IP, - tableRev map[string][]string, + tableRev *IPMap, fn string, ) { f, err := os.Open(fn) if err != nil { - log.Error("etchostscontainer: %s", err) + log.Error("etchosts: %s", err) return } @@ -268,11 +271,11 @@ func (ehc *EtcHostsContainer) load( defer func() { derr := f.Close() if derr != nil { - log.Error("etchostscontainer: closing file: %s", err) + log.Error("etchosts: closing file: %s", err) } }() - log.Debug("etchostscontainer: loading hosts from file %s", fn) + log.Debug("etchosts: loading hosts from file %s", fn) s := bufio.NewScanner(f) for s.Scan() { @@ -296,7 +299,7 @@ func (ehc *EtcHostsContainer) load( err = s.Err() if err != nil { - log.Error("etchostscontainer: %s", err) + log.Error("etchosts: %s", err) } } @@ -334,7 +337,7 @@ func (ehc *EtcHostsContainer) watcherLoop() { } if event.Op&fsnotify.Write == fsnotify.Write { - log.Debug("etchostscontainer: modified: %s", event.Name) + log.Debug("etchosts: modified: %s", event.Name) ehc.updateHosts() } @@ -342,7 +345,7 @@ func (ehc *EtcHostsContainer) watcherLoop() { if !ok { return } - log.Error("etchostscontainer: %s", err) + log.Error("etchosts: %s", err) } } } @@ -350,7 +353,7 @@ func (ehc *EtcHostsContainer) watcherLoop() { // updateHosts - loads system hosts func (ehc *EtcHostsContainer) updateHosts() { table := make(map[string][]net.IP) - tableRev := make(map[string][]string) + tableRev := NewIPMap(0) ehc.load(table, tableRev, ehc.hostsFn) @@ -358,7 +361,7 @@ func (ehc *EtcHostsContainer) updateHosts() { des, err := os.ReadDir(dir) if err != nil { if !errors.Is(err, os.ErrNotExist) { - log.Error("etchostscontainer: Opening directory: %q: %s", dir, err) + log.Error("etchosts: Opening directory: %q: %s", dir, err) } continue diff --git a/internal/aghnet/etchostscontainer_test.go b/internal/aghnet/etchostscontainer_test.go index 74e4f46f..b9e7d8da 100644 --- a/internal/aghnet/etchostscontainer_test.go +++ b/internal/aghnet/etchostscontainer_test.go @@ -70,7 +70,7 @@ func TestEtcHostsContainerResolution(t *testing.T) { }) t.Run("hosts_file", func(t *testing.T) { - names, ok := ehc.List()["127.0.0.1"] + names, ok := ehc.List().Get(net.IP{127, 0, 0, 1}) require.True(t, ok) assert.Equal(t, []string{"host", "localhost"}, names) }) diff --git a/internal/aghnet/ipmap.go b/internal/aghnet/ipmap.go new file mode 100644 index 00000000..78b3025d --- /dev/null +++ b/internal/aghnet/ipmap.go @@ -0,0 +1,112 @@ +package aghnet + +import ( + "fmt" + "net" +) + +// ipArr is a representation of an IP address as an array of bytes. +type ipArr [16]byte + +// String implements the fmt.Stringer interface for ipArr. +func (a ipArr) String() (s string) { + return net.IP(a[:]).String() +} + +// IPMap is a map of IP addresses. +type IPMap struct { + m map[ipArr]interface{} +} + +// NewIPMap returns a new empty IP map using hint as a size hint for the +// underlying map. +func NewIPMap(hint int) (m *IPMap) { + return &IPMap{ + m: make(map[ipArr]interface{}, hint), + } +} + +// ipToArr converts a net.IP into an ipArr. +// +// TODO(a.garipov): Use the slice-to-array conversion in Go 1.17. +func ipToArr(ip net.IP) (a ipArr) { + copy(a[:], ip.To16()) + + return a +} + +// Del deletes ip from the map. Calling Del on a nil *IPMap has no effect, just +// like delete on an empty map doesn't. +func (m *IPMap) Del(ip net.IP) { + if m != nil { + delete(m.m, ipToArr(ip)) + } +} + +// Get returns the value from the map. Calling Get on a nil *IPMap returns nil +// and false, just like indexing on an empty map does. +func (m *IPMap) Get(ip net.IP) (v interface{}, ok bool) { + if m != nil { + v, ok = m.m[ipToArr(ip)] + + return v, ok + } + + return nil, false +} + +// Len returns the length of the map. A nil *IPMap has a length of zero, just +// like an empty map. +func (m *IPMap) Len() (n int) { + if m == nil { + return 0 + } + + return len(m.m) +} + +// Range calls f for each key and value present in the map in an undefined +// order. If cont is false, range stops the iteration. Calling Range on a nil +// *IPMap has no effect, just like ranging over a nil map. +func (m *IPMap) Range(f func(ip net.IP, v interface{}) (cont bool)) { + if m == nil { + return + } + + for k, v := range m.m { + if !f(net.IP(k[:]), v) { + break + } + } +} + +// Set sets the value. Set panics if the m is a nil *IPMap, just like a nil map +// does. +func (m *IPMap) Set(ip net.IP, v interface{}) { + m.m[ipToArr(ip)] = v +} + +// ShallowClone returns a shallow clone of the map. +func (m *IPMap) ShallowClone() (sclone *IPMap) { + if m == nil { + return nil + } + + sclone = NewIPMap(m.Len()) + m.Range(func(ip net.IP, v interface{}) (cont bool) { + sclone.Set(ip, v) + + return true + }) + + return sclone +} + +// String implements the fmt.Stringer interface for *IPMap. +func (m *IPMap) String() (s string) { + if m == nil { + return "" + } + + return fmt.Sprint(m.m) +} diff --git a/internal/aghnet/ipmap_test.go b/internal/aghnet/ipmap_test.go new file mode 100644 index 00000000..3d3e765d --- /dev/null +++ b/internal/aghnet/ipmap_test.go @@ -0,0 +1,142 @@ +package aghnet + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIPMap_allocs(t *testing.T) { + ip4 := net.IP{1, 2, 3, 4} + m := NewIPMap(0) + m.Set(ip4, 42) + + t.Run("get", func(t *testing.T) { + var v interface{} + var ok bool + allocs := testing.AllocsPerRun(100, func() { + v, ok = m.Get(ip4) + }) + + require.True(t, ok) + require.Equal(t, 42, v) + + assert.Equal(t, float64(0), allocs) + }) + + t.Run("len", func(t *testing.T) { + var n int + allocs := testing.AllocsPerRun(100, func() { + n = m.Len() + }) + + require.Equal(t, 1, n) + + assert.Equal(t, float64(0), allocs) + }) +} + +func TestIPMap(t *testing.T) { + ip4 := net.IP{1, 2, 3, 4} + ip6 := net.IP{ + 0x12, 0x34, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x56, 0x78, + } + + val := 42 + + t.Run("nil", func(t *testing.T) { + var m *IPMap + + assert.NotPanics(t, func() { + m.Del(ip4) + m.Del(ip6) + }) + + assert.NotPanics(t, func() { + v, ok := m.Get(ip4) + assert.Nil(t, v) + assert.False(t, ok) + + v, ok = m.Get(ip6) + assert.Nil(t, v) + assert.False(t, ok) + }) + + assert.NotPanics(t, func() { + assert.Equal(t, 0, m.Len()) + }) + + assert.NotPanics(t, func() { + n := 0 + m.Range(func(_ net.IP, _ interface{}) (cont bool) { + n++ + + return true + }) + + assert.Equal(t, 0, n) + }) + + assert.Panics(t, func() { + m.Set(ip4, val) + }) + + assert.Panics(t, func() { + m.Set(ip6, val) + }) + + assert.NotPanics(t, func() { + sclone := m.ShallowClone() + assert.Nil(t, sclone) + }) + }) + + testIPMap := func(t *testing.T, ip net.IP, s string) { + m := NewIPMap(0) + assert.Equal(t, 0, m.Len()) + + v, ok := m.Get(ip) + assert.Nil(t, v) + assert.False(t, ok) + + m.Set(ip, val) + v, ok = m.Get(ip) + assert.Equal(t, val, v) + assert.True(t, ok) + + n := 0 + m.Range(func(ipKey net.IP, v interface{}) (cont bool) { + assert.Equal(t, ip.To16(), ipKey) + assert.Equal(t, val, v) + + n++ + + return false + }) + assert.Equal(t, 1, n) + + sclone := m.ShallowClone() + assert.Equal(t, m, sclone) + + assert.Equal(t, s, m.String()) + + m.Del(ip) + v, ok = m.Get(ip) + assert.Nil(t, v) + assert.False(t, ok) + assert.Equal(t, 0, m.Len()) + } + + t.Run("ipv4", func(t *testing.T) { + testIPMap(t, ip4, "map[1.2.3.4:42]") + }) + + t.Run("ipv6", func(t *testing.T) { + testIPMap(t, ip6, "map[1234::5678:42]") + }) +} diff --git a/internal/dnsforward/access.go b/internal/dnsforward/access.go index 73b40bd7..c53d6e63 100644 --- a/internal/dnsforward/access.go +++ b/internal/dnsforward/access.go @@ -6,138 +6,163 @@ import ( "net" "net/http" "strings" - "sync" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter/filterlist" ) +// accessCtx controls IP and client blocking that takes place before all other +// processing. An accessCtx is safe for concurrent use. type accessCtx struct { - lock sync.Mutex + allowedIPs *aghnet.IPMap + blockedIPs *aghnet.IPMap - // allowedClients are the IP addresses of clients in the allowlist. - allowedClients *aghstrings.Set + allowedClientIDs *aghstrings.Set + blockedClientIDs *aghstrings.Set - // disallowedClients are the IP addresses of clients in the blocklist. - disallowedClients *aghstrings.Set + blockedHostsEng *urlfilter.DNSEngine - allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients - disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked - - blockedHostsEngine *urlfilter.DNSEngine // finds hosts that should be blocked + // TODO(a.garipov): Create a type for a set of IP networks. + // aghnet.IPNetSet? + allowedNets []*net.IPNet + blockedNets []*net.IPNet } -func newAccessCtx(allowedClients, disallowedClients, blockedHosts []string) (a *accessCtx, err error) { - a = &accessCtx{ - allowedClients: aghstrings.NewSet(), - disallowedClients: aghstrings.NewSet(), - } +// unit is a convenient alias for struct{} +type unit = struct{} - err = processIPCIDRArray(a.allowedClients, &a.allowedClientsIPNet, allowedClients) - if err != nil { - return nil, fmt.Errorf("processing allowed clients: %w", err) - } +// processAccessClients is a helper for processing a list of client strings, +// which may be an IP address, a CIDR, or a ClientID. +func processAccessClients( + clientStrs []string, + ips *aghnet.IPMap, + nets *[]*net.IPNet, + clientIDs *aghstrings.Set, +) (err error) { + for i, s := range clientStrs { + if ip := net.ParseIP(s); ip != nil { + ips.Set(ip, unit{}) + } else if cidrIP, ipnet, cidrErr := net.ParseCIDR(s); cidrErr == nil { + ipnet.IP = cidrIP + *nets = append(*nets, ipnet) + } else { + idErr := ValidateClientID(s) + if idErr != nil { + return fmt.Errorf( + "value %q at index %d: bad ip, cidr, or clientid", + s, + i, + ) + } - err = processIPCIDRArray(a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients) - if err != nil { - return nil, fmt.Errorf("processing disallowed clients: %w", err) - } - - b := &strings.Builder{} - for _, s := range blockedHosts { - aghstrings.WriteToBuilder(b, strings.ToLower(s), "\n") - } - - listArray := []filterlist.RuleList{} - list := &filterlist.StringRuleList{ - ID: int(0), - RulesText: b.String(), - IgnoreCosmetic: true, - } - listArray = append(listArray, list) - rulesStorage, err := filterlist.NewRuleStorage(listArray) - if err != nil { - return nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err) - } - a.blockedHostsEngine = urlfilter.NewDNSEngine(rulesStorage) - - return a, nil -} - -// Split array of IP or CIDR into 2 containers for fast search -func processIPCIDRArray(dst *aghstrings.Set, dstIPNet *[]net.IPNet, src []string) error { - for _, s := range src { - ip := net.ParseIP(s) - if ip != nil { - dst.Add(s) - - continue + clientIDs.Add(s) } - - _, ipnet, err := net.ParseCIDR(s) - if err != nil { - return err - } - - *dstIPNet = append(*dstIPNet, *ipnet) } return nil } -// IsBlockedIP - return TRUE if this client should be blocked -// Returns the item from the "disallowedClients" list that lead to blocking IP. -// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty, -// but the ip does not belong to it. -func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) { - ipStr := ip.String() +// newAccessCtx creates a new accessCtx. +func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) { + a = &accessCtx{ + allowedIPs: aghnet.NewIPMap(0), + blockedIPs: aghnet.NewIPMap(0), - a.lock.Lock() - defer a.lock.Unlock() - - if a.allowedClients.Len() != 0 || len(a.allowedClientsIPNet) != 0 { - if a.allowedClients.Has(ipStr) { - return false, "" - } - - if len(a.allowedClientsIPNet) != 0 { - for _, ipnet := range a.allowedClientsIPNet { - if ipnet.Contains(ip) { - return false, "" - } - } - } - - return true, "" + allowedClientIDs: aghstrings.NewSet(), + blockedClientIDs: aghstrings.NewSet(), } - if a.disallowedClients.Has(ipStr) { - return true, ipStr + err = processAccessClients(allowed, a.allowedIPs, &a.allowedNets, a.allowedClientIDs) + if err != nil { + return nil, fmt.Errorf("adding allowed: %w", err) } - if len(a.disallowedClientsIPNet) != 0 { - for _, ipnet := range a.disallowedClientsIPNet { - if ipnet.Contains(ip) { - return true, ipnet.String() - } - } + err = processAccessClients(blocked, a.blockedIPs, &a.blockedNets, a.blockedClientIDs) + if err != nil { + return nil, fmt.Errorf("adding blocked: %w", err) } - return false, "" + b := &strings.Builder{} + for _, h := range blockedHosts { + aghstrings.WriteToBuilder(b, strings.ToLower(h), "\n") + } + + lists := []filterlist.RuleList{ + &filterlist.StringRuleList{ + ID: int(0), + RulesText: b.String(), + IgnoreCosmetic: true, + }, + } + + rulesStrg, err := filterlist.NewRuleStorage(lists) + if err != nil { + return nil, fmt.Errorf("adding blocked hosts: %w", err) + } + + a.blockedHostsEng = urlfilter.NewDNSEngine(rulesStrg) + + return a, nil } -// IsBlockedDomain - return TRUE if this domain should be blocked -func (a *accessCtx) IsBlockedDomain(host string) (ok bool) { - a.lock.Lock() - defer a.lock.Unlock() +// allowlistMode returns true if this *accessCtx is in the allowlist mode. +func (a *accessCtx) allowlistMode() (ok bool) { + return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0 +} - _, ok = a.blockedHostsEngine.Match(strings.ToLower(host)) +// isBlockedClientID returns true if the ClientID should be blocked. +func (a *accessCtx) isBlockedClientID(id string) (ok bool) { + allowlistMode := a.allowlistMode() + if id == "" { + // In allowlist mode, consider requests without client IDs + // blocked by default. + return allowlistMode + } + + if allowlistMode { + return !a.allowedClientIDs.Has(id) + } + + return a.blockedClientIDs.Has(id) +} + +// isBlockedHost returns true if host should be blocked. +func (a *accessCtx) isBlockedHost(host string) (ok bool) { + _, ok = a.blockedHostsEng.Match(strings.ToLower(host)) return ok } +// isBlockedIP returns the status of the IP address blocking as well as the rule +// that blocked it. +func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) { + blocked = true + ips := a.blockedIPs + ipnets := a.blockedNets + + if a.allowlistMode() { + // Enable allowlist mode and use the allowlist sets. + blocked = false + ips = a.allowedIPs + ipnets = a.allowedNets + } + + if _, ok := ips.Get(ip); ok { + return blocked, ip.String() + } + + for _, ipnet := range ipnets { + if ipnet.Contains(ip) { + return blocked, ipnet.String() + } + } + + return !blocked, "" +} + type accessListJSON struct { AllowedClients []string `json:"allowed_clients"` DisallowedClients []string `json:"disallowed_clients"` @@ -161,62 +186,43 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(j) if err != nil { - httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err) + httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err) + return } } -func checkIPCIDRArray(src []string) error { - for _, s := range src { - ip := net.ParseIP(s) - if ip != nil { - continue - } - - _, _, err := net.ParseCIDR(s) - if err != nil { - return err - } - } - - return nil -} - func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) { - j := accessListJSON{} - err := json.NewDecoder(r.Body).Decode(&j) + list := accessListJSON{} + err := json.NewDecoder(r.Body).Decode(&list) if err != nil { - httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err) - return - } + httpError(r, w, http.StatusBadRequest, "decoding request: %s", err) - err = checkIPCIDRArray(j.AllowedClients) - if err == nil { - err = checkIPCIDRArray(j.DisallowedClients) - } - if err != nil { - httpError(r, w, http.StatusBadRequest, "%s", err) return } var a *accessCtx - a, err = newAccessCtx(j.AllowedClients, j.DisallowedClients, j.BlockedHosts) + a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts) if err != nil { httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err) return } - defer log.Debug("Access: updated lists: %d, %d, %d", - len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts)) + defer log.Debug( + "access: updated lists: %d, %d, %d", + len(list.AllowedClients), + len(list.DisallowedClients), + len(list.BlockedHosts), + ) defer s.conf.ConfigModified() s.serverLock.Lock() defer s.serverLock.Unlock() - s.conf.AllowedClients = j.AllowedClients - s.conf.DisallowedClients = j.DisallowedClients - s.conf.BlockedHosts = j.BlockedHosts + s.conf.AllowedClients = list.AllowedClients + s.conf.DisallowedClients = list.DisallowedClients + s.conf.BlockedHosts = list.BlockedHosts s.access = a } diff --git a/internal/dnsforward/access_test.go b/internal/dnsforward/access_test.go index eec5c511..7f9c4e79 100644 --- a/internal/dnsforward/access_test.go +++ b/internal/dnsforward/access_test.go @@ -8,99 +8,23 @@ import ( "github.com/stretchr/testify/require" ) -func TestIsBlockedIP(t *testing.T) { - const ( - ip int = iota - cidr - ) +func TestIsBlockedClientID(t *testing.T) { + clientID := "client-1" + clients := []string{clientID} - rules := []string{ - ip: "1.1.1.1", - cidr: "2.2.0.0/16", - } + a, err := newAccessCtx(clients, nil, nil) + require.NoError(t, err) - testCases := []struct { - name string - allowed bool - ip net.IP - wantDis bool - wantRule string - }{{ - name: "allow_ip", - allowed: true, - ip: net.IPv4(1, 1, 1, 1), - wantDis: false, - wantRule: "", - }, { - name: "disallow_ip", - allowed: true, - ip: net.IPv4(1, 1, 1, 2), - wantDis: true, - wantRule: "", - }, { - name: "allow_cidr", - allowed: true, - ip: net.IPv4(2, 2, 1, 1), - wantDis: false, - wantRule: "", - }, { - name: "disallow_cidr", - allowed: true, - ip: net.IPv4(2, 3, 1, 1), - wantDis: true, - wantRule: "", - }, { - name: "allow_ip", - allowed: false, - ip: net.IPv4(1, 1, 1, 1), - wantDis: true, - wantRule: rules[ip], - }, { - name: "disallow_ip", - allowed: false, - ip: net.IPv4(1, 1, 1, 2), - wantDis: false, - wantRule: "", - }, { - name: "allow_cidr", - allowed: false, - ip: net.IPv4(2, 2, 1, 1), - wantDis: true, - wantRule: rules[cidr], - }, { - name: "disallow_cidr", - allowed: false, - ip: net.IPv4(2, 3, 1, 1), - wantDis: false, - wantRule: "", - }} + assert.False(t, a.isBlockedClientID(clientID)) - for _, tc := range testCases { - prefix := "allowed_" - if !tc.allowed { - prefix = "disallowed_" - } + a, err = newAccessCtx(nil, clients, nil) + require.NoError(t, err) - t.Run(prefix+tc.name, func(t *testing.T) { - allowedRules := rules - var disallowedRules []string - - if !tc.allowed { - allowedRules, disallowedRules = disallowedRules, allowedRules - } - - aCtx, err := newAccessCtx(allowedRules, disallowedRules, nil) - require.NoError(t, err) - - disallowed, rule := aCtx.IsBlockedIP(tc.ip) - assert.Equal(t, tc.wantDis, disallowed) - assert.Equal(t, tc.wantRule, rule) - }) - } + assert.True(t, a.isBlockedClientID(clientID)) } -func TestIsBlockedDomain(t *testing.T) { - aCtx, err := newAccessCtx(nil, nil, []string{ +func TestIsBlockedHost(t *testing.T) { + a, err := newAccessCtx(nil, nil, []string{ "host1", "*.host.com", "||host3.com^", @@ -108,50 +32,106 @@ func TestIsBlockedDomain(t *testing.T) { require.NoError(t, err) testCases := []struct { - name string - domain string - want bool + name string + host string + want bool }{{ - name: "plain_match", - domain: "host1", - want: true, + name: "plain_match", + host: "host1", + want: true, }, { - name: "plain_mismatch", - domain: "host2", - want: false, + name: "plain_mismatch", + host: "host2", + want: false, }, { - name: "wildcard_type-1_match_short", - domain: "asdf.host.com", - want: true, + name: "subdomain_match_short", + host: "asdf.host.com", + want: true, }, { - name: "wildcard_type-1_match_long", - domain: "qwer.asdf.host.com", - want: true, + name: "subdomain_match_long", + host: "qwer.asdf.host.com", + want: true, }, { - name: "wildcard_type-1_mismatch_no-lead", - domain: "host.com", - want: false, + name: "subdomain_mismatch_no_lead", + host: "host.com", + want: false, }, { - name: "wildcard_type-1_mismatch_bad-asterisk", - domain: "asdf.zhost.com", - want: false, + name: "subdomain_mismatch_bad_asterisk", + host: "asdf.zhost.com", + want: false, }, { - name: "wildcard_type-2_match_simple", - domain: "host3.com", - want: true, + name: "rule_match_simple", + host: "host3.com", + want: true, }, { - name: "wildcard_type-2_match_complex", - domain: "asdf.host3.com", - want: true, + name: "rule_match_complex", + host: "asdf.host3.com", + want: true, }, { - name: "wildcard_type-2_mismatch", - domain: ".host3.com", - want: false, + name: "rule_mismatch", + host: ".host3.com", + want: false, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain)) + assert.Equal(t, tc.want, a.isBlockedHost(tc.host)) }) } } + +func TestIsBlockedIP(t *testing.T) { + clients := []string{ + "1.2.3.4", + "5.6.7.8/24", + } + + allowCtx, err := newAccessCtx(clients, nil, nil) + require.NoError(t, err) + + blockCtx, err := newAccessCtx(nil, clients, nil) + require.NoError(t, err) + + testCases := []struct { + name string + wantRule string + ip net.IP + wantBlocked bool + }{{ + name: "match_ip", + wantRule: "1.2.3.4", + ip: net.IP{1, 2, 3, 4}, + wantBlocked: true, + }, { + name: "match_cidr", + wantRule: "5.6.7.8/24", + ip: net.IP{5, 6, 7, 100}, + wantBlocked: true, + }, { + name: "no_match_ip", + wantRule: "", + ip: net.IP{9, 2, 3, 4}, + wantBlocked: false, + }, { + name: "no_match_cidr", + wantRule: "", + ip: net.IP{9, 6, 7, 100}, + wantBlocked: false, + }} + + t.Run("allow", func(t *testing.T) { + for _, tc := range testCases { + blocked, rule := allowCtx.isBlockedIP(tc.ip) + assert.Equal(t, !tc.wantBlocked, blocked) + assert.Equal(t, tc.wantRule, rule) + } + }) + + t.Run("block", func(t *testing.T) { + for _, tc := range testCases { + blocked, rule := blockCtx.isBlockedIP(tc.ip) + assert.Equal(t, tc.wantBlocked, blocked) + assert.Equal(t, tc.wantRule, rule) + } + }) +} diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 995be08e..01301611 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -2,6 +2,7 @@ package dnsforward import ( "crypto/tls" + "encoding/binary" "fmt" "path" "strings" @@ -50,15 +51,15 @@ func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) ( return clientID, nil } -// processClientIDHTTPS extracts the client's ID from the path of the +// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the // client's DNS-over-HTTPS request. -func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) { - pctx := ctx.proxyCtx +func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) { r := pctx.HTTPRequest if r == nil { - ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto) - - return resultCodeError + return "", fmt.Errorf( + "proxy ctx http request of proto %s is nil", + pctx.Proto, + ) } origPath := r.URL.Path @@ -68,34 +69,25 @@ func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) { } if len(parts) == 0 || parts[0] != "dns-query" { - ctx.err = fmt.Errorf("client id check: invalid path %q", origPath) - - return resultCodeError + return "", fmt.Errorf("client id check: invalid path %q", origPath) } - clientID := "" switch len(parts) { case 1: // Just /dns-query, no client ID. - return resultCodeSuccess + return "", nil case 2: clientID = parts[1] default: - ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath) - - return resultCodeError + return "", fmt.Errorf("client id check: invalid path %q: extra parts", origPath) } - err := ValidateClientID(clientID) + err = ValidateClientID(clientID) if err != nil { - ctx.err = fmt.Errorf("client id check: %w", err) - - return resultCodeError + return "", fmt.Errorf("client id check: %w", err) } - ctx.clientID = clientID - - return resultCodeSuccess + return clientID, nil } // tlsConn is a narrow interface for *tls.Conn to simplify testing. @@ -108,53 +100,73 @@ type quicSession interface { ConnectionState() (cs quic.ConnectionState) } -// processClientID extracts the client's ID from the server name of the client's -// DoT or DoQ request or the path of the client's DoH. -func processClientID(dctx *dnsContext) (rc resultCode) { - pctx := dctx.proxyCtx +// clientIDFromDNSContext extracts the client's ID from the server name of the +// client's DoT or DoQ request or the path of the client's DoH. If the protocol +// is not one of these, clientID is an empty string and err is nil. +func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) { proto := pctx.Proto if proto == proxy.ProtoHTTPS { - return processClientIDHTTPS(dctx) + return clientIDFromDNSContextHTTPS(pctx) } else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC { - return resultCodeSuccess + return "", nil } - srvConf := dctx.srv.conf - hostSrvName := srvConf.TLSConfig.ServerName + hostSrvName := s.conf.ServerName if hostSrvName == "" { - return resultCodeSuccess + return "", nil } cliSrvName := "" - if proto == proxy.ProtoTLS { + switch proto { + case proxy.ProtoTLS: conn := pctx.Conn tc, ok := conn.(tlsConn) if !ok { - dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn) - - return resultCodeError + return "", fmt.Errorf( + "proxy ctx conn of proto %s is %T, want *tls.Conn", + proto, + conn, + ) } cliSrvName = tc.ConnectionState().ServerName - } else if proto == proxy.ProtoQUIC { + case proxy.ProtoQUIC: qs, ok := pctx.QUICSession.(quicSession) if !ok { - dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession) - - return resultCodeError + return "", fmt.Errorf( + "proxy ctx quic session of proto %s is %T, want quic.Session", + proto, + pctx.QUICSession, + ) } cliSrvName = qs.ConnectionState().TLS.ServerName } - clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck) + clientID, err = clientIDFromClientServerName( + hostSrvName, + cliSrvName, + s.conf.StrictSNICheck, + ) if err != nil { - dctx.err = fmt.Errorf("client id check: %w", err) - - return resultCodeError + return "", fmt.Errorf("client id check: %w", err) } - dctx.clientID = clientID + return clientID, nil +} + +// processClientID puts the clientID into the DNS context, if there is one. +func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) { + pctx := dctx.proxyCtx + + var key [8]byte + binary.BigEndian.PutUint64(key[:], pctx.RequestID) + clientIDData := s.clientIDCache.Get(key[:]) + if clientIDData == nil { + return resultCodeSuccess + } + + dctx.clientID = string(clientIDData) return resultCodeSuccess } diff --git a/internal/dnsforward/clientid_test.go b/internal/dnsforward/clientid_test.go index 9394a9ab..b4adf8de 100644 --- a/internal/dnsforward/clientid_test.go +++ b/internal/dnsforward/clientid_test.go @@ -45,15 +45,14 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) { return cs } -func TestProcessClientID(t *testing.T) { +func TestServer_clientIDFromDNSContext(t *testing.T) { testCases := []struct { name string - proto string + proto proxy.Proto hostSrvName string cliSrvName string wantClientID string wantErrMsg string - wantRes resultCode strictSNI bool }{{ name: "udp", @@ -62,7 +61,6 @@ func TestProcessClientID(t *testing.T) { cliSrvName: "", wantClientID: "", wantErrMsg: "", - wantRes: resultCodeSuccess, strictSNI: false, }, { name: "tls_no_client_id", @@ -71,7 +69,6 @@ func TestProcessClientID(t *testing.T) { cliSrvName: "example.com", wantClientID: "", wantErrMsg: "", - wantRes: resultCodeSuccess, strictSNI: true, }, { name: "tls_no_client_server_name", @@ -81,7 +78,6 @@ func TestProcessClientID(t *testing.T) { wantClientID: "", wantErrMsg: `client id check: client server name "" ` + `doesn't match host server name "example.com"`, - wantRes: resultCodeError, strictSNI: true, }, { name: "tls_no_client_server_name_no_strict", @@ -90,7 +86,6 @@ func TestProcessClientID(t *testing.T) { cliSrvName: "", wantClientID: "", wantErrMsg: "", - wantRes: resultCodeSuccess, strictSNI: false, }, { name: "tls_client_id", @@ -99,7 +94,6 @@ func TestProcessClientID(t *testing.T) { cliSrvName: "cli.example.com", wantClientID: "cli", wantErrMsg: "", - wantRes: resultCodeSuccess, strictSNI: true, }, { name: "tls_client_id_hostname_error", @@ -109,7 +103,6 @@ func TestProcessClientID(t *testing.T) { wantClientID: "", wantErrMsg: `client id check: client server name "cli.example.net" ` + `doesn't match host server name "example.com"`, - wantRes: resultCodeError, strictSNI: true, }, { name: "tls_invalid_client_id", @@ -119,7 +112,6 @@ func TestProcessClientID(t *testing.T) { wantClientID: "", wantErrMsg: `client id check: invalid client id "!!!": ` + `invalid char '!' at index 0`, - wantRes: resultCodeError, strictSNI: true, }, { name: "tls_client_id_too_long", @@ -131,7 +123,6 @@ func TestProcessClientID(t *testing.T) { wantErrMsg: `client id check: invalid client id "abcdefghijklmno` + `pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` + `label is too long, max: 63`, - wantRes: resultCodeError, strictSNI: true, }, { name: "quic_client_id", @@ -140,7 +131,6 @@ func TestProcessClientID(t *testing.T) { cliSrvName: "cli.example.com", wantClientID: "cli", wantErrMsg: "", - wantRes: resultCodeSuccess, strictSNI: true, }} @@ -150,6 +140,7 @@ func TestProcessClientID(t *testing.T) { ServerName: tc.hostSrvName, StrictSNICheck: tc.strictSNI, } + srv := &Server{ conf: ServerConfig{TLSConfig: tlsConf}, } @@ -168,79 +159,68 @@ func TestProcessClientID(t *testing.T) { } } - dctx := &dnsContext{ - srv: srv, - proxyCtx: &proxy.DNSContext{ - Proto: tc.proto, - Conn: conn, - QUICSession: qs, - }, + pctx := &proxy.DNSContext{ + Proto: tc.proto, + Conn: conn, + QUICSession: qs, } - res := processClientID(dctx) - assert.Equal(t, tc.wantRes, res) - assert.Equal(t, tc.wantClientID, dctx.clientID) + clientID, err := srv.clientIDFromDNSContext(pctx) + assert.Equal(t, tc.wantClientID, clientID) if tc.wantErrMsg == "" { - assert.NoError(t, dctx.err) + assert.NoError(t, err) } else { - require.Error(t, dctx.err) - assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) + require.Error(t, err) + + assert.Equal(t, tc.wantErrMsg, err.Error()) } }) } } -func TestProcessClientID_https(t *testing.T) { +func TestClientIDFromDNSContextHTTPS(t *testing.T) { testCases := []struct { name string path string wantClientID string wantErrMsg string - wantRes resultCode }{{ name: "no_client_id", path: "/dns-query", wantClientID: "", wantErrMsg: "", - wantRes: resultCodeSuccess, }, { name: "no_client_id_slash", path: "/dns-query/", wantClientID: "", wantErrMsg: "", - wantRes: resultCodeSuccess, }, { name: "client_id", path: "/dns-query/cli", wantClientID: "cli", wantErrMsg: "", - wantRes: resultCodeSuccess, }, { name: "client_id_slash", path: "/dns-query/cli/", wantClientID: "cli", wantErrMsg: "", - wantRes: resultCodeSuccess, }, { name: "bad_url", path: "/foo", wantClientID: "", wantErrMsg: `client id check: invalid path "/foo"`, - wantRes: resultCodeError, }, { name: "extra", path: "/dns-query/cli/foo", wantClientID: "", wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`, - wantRes: resultCodeError, }, { name: "invalid_client_id", path: "/dns-query/!!!", wantClientID: "", wantErrMsg: `client id check: invalid client id "!!!": ` + `invalid char '!' at index 0`, - wantRes: resultCodeError, }} for _, tc := range testCases { @@ -251,23 +231,20 @@ func TestProcessClientID_https(t *testing.T) { }, } - dctx := &dnsContext{ - proxyCtx: &proxy.DNSContext{ - Proto: proxy.ProtoHTTPS, - HTTPRequest: r, - }, + pctx := &proxy.DNSContext{ + Proto: proxy.ProtoHTTPS, + HTTPRequest: r, } - res := processClientID(dctx) - assert.Equal(t, tc.wantRes, res) - assert.Equal(t, tc.wantClientID, dctx.clientID) + clientID, err := clientIDFromDNSContextHTTPS(pctx) + assert.Equal(t, tc.wantClientID, clientID) if tc.wantErrMsg == "" { - assert.NoError(t, dctx.err) + assert.NoError(t, err) } else { - require.Error(t, dctx.err) + require.Error(t, err) - assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) + assert.Equal(t, tc.wantErrMsg, err.Error()) } }) } diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 04c133bb..213f6221 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -331,7 +331,7 @@ func (s *Server) prepareUpstreamSettings() error { upstreams = aghstrings.FilterOut(upstreams, aghstrings.IsCommentOrEmpty) upstreamConfig, err := proxy.ParseUpstreamsConfig( upstreams, - upstream.Options{ + &upstream.Options{ Bootstrap: s.conf.BootstrapDNS, Timeout: s.conf.UpstreamTimeout, }, @@ -342,10 +342,10 @@ func (s *Server) prepareUpstreamSettings() error { if len(upstreamConfig.Upstreams) == 0 { log.Info("warning: no default upstream servers specified, using %v", defaultDNS) - var uc proxy.UpstreamConfig + var uc *proxy.UpstreamConfig uc, err = proxy.ParseUpstreamsConfig( defaultDNS, - upstream.Options{ + &upstream.Options{ Bootstrap: s.conf.BootstrapDNS, Timeout: s.conf.UpstreamTimeout, }, @@ -356,7 +356,8 @@ func (s *Server) prepareUpstreamSettings() error { upstreamConfig.Upstreams = uc.Upstreams } - s.conf.UpstreamConfig = &upstreamConfig + s.conf.UpstreamConfig = upstreamConfig + return nil } diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 3e487531..657afe6b 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -89,7 +89,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { s.processInternalHosts, s.processRestrictLocal, s.processInternalIPAddrs, - processClientID, + s.processClientID, processFilteringBeforeRequest, s.processLocalPTR, s.processUpstream, @@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) { s.tableHostToIP = t } -func (s *Server) setTableIPToHost(t ipToHostTable) { +func (s *Server) setTableIPToHost(t *aghnet.IPMap) { s.tableIPToHostLock.Lock() defer s.tableIPToHostLock.Unlock() @@ -188,13 +188,13 @@ func (s *Server) onDHCPLeaseChanged(flags int) { } var hostToIP hostToIPTable - var ipToHost ipToHostTable + var ipToHost *aghnet.IPMap if add { - hostToIP = make(hostToIPTable) - ipToHost = make(ipToHostTable) - ll := s.dhcpServer.Leases(dhcpd.LeasesAll) + hostToIP = make(hostToIPTable, len(ll)) + ipToHost = aghnet.NewIPMap(len(ll)) + for _, l := range ll { // TODO(a.garipov): Remove this after we're finished // with the client hostname validations in the DHCP @@ -210,14 +210,14 @@ func (s *Server) onDHCPLeaseChanged(flags int) { lowhost := strings.ToLower(l.Hostname) - ipToHost[l.IP.String()] = lowhost + ipToHost.Set(l.IP, lowhost) ip := make(net.IP, 4) copy(ip, l.IP.To4()) hostToIP[lowhost] = ip } - log.Debug("dns: added %d A/PTR entries from DHCP", len(ipToHost)) + log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len()) } s.setTableHostToIP(hostToIP) @@ -377,7 +377,15 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) { return "", false } - host, ok = s.tableIPToHost[ip.String()] + var v interface{} + v, ok = s.tableIPToHost.Get(ip) + + var typOK bool + if host, typOK = v.(string); !typOK { + log.Error("dns: bad type %T in tableIPToHost for %s", v, ip) + + return "", false + } return host, ok } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 616ab4f8..a28142ba 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -18,6 +18,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" @@ -26,6 +27,11 @@ import ( // DefaultTimeout is the default upstream timeout const DefaultTimeout = 10 * time.Second +// defaultClientIDCacheCount is the default count of items in the LRU client ID +// cache. The assumption here is that there won't be more than this many +// requests between the BeforeRequestHandler stage and the actual processing. +const defaultClientIDCacheCount = 1024 + const ( safeBrowsingBlockHost = "standard-block.dns.adguard.com" parentalBlockHost = "family-block.dns.adguard.com" @@ -44,12 +50,6 @@ var webRegistered bool // hostToIPTable is an alias for the type of Server.tableHostToIP. type hostToIPTable = map[string]net.IP -// ipToHostTable is an alias for the type of Server.tableIPToHost. -// -// TODO(a.garipov): Define an IPMap type in aghnet and use here and in other -// places? -type ipToHostTable = map[string]string - // Server is the main way to start a DNS server. // // Example: @@ -81,9 +81,13 @@ type Server struct { tableHostToIP hostToIPTable tableHostToIPLock sync.Mutex - tableIPToHost ipToHostTable + tableIPToHost *aghnet.IPMap tableIPToHostLock sync.Mutex + // clientIDCache is a temporary storage for clientIDs that were + // extracted during the BeforeRequestHandler stage. + clientIDCache cache.Cache + // DNS proxy instance for internal usage // We don't Start() it and so no listen port is required. internalProxy *proxy.Proxy @@ -152,6 +156,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { subnetDetector: p.SubnetDetector, localDomainSuffix: localDomainSuffix, recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum), + clientIDCache: cache.New(cache.Config{ + EnableLRU: true, + MaxCount: defaultClientIDCacheCount, + }), } // TODO(e.burkov): Enable the refresher after the actual implementation @@ -414,19 +422,22 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) { log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs) - var upsConfig proxy.UpstreamConfig - upsConfig, err = proxy.ParseUpstreamsConfig(localAddrs, upstream.Options{ - Bootstrap: bootstraps, - Timeout: defaultLocalTimeout, - // TODO(e.burkov): Should we verify server's ceritificates? - }) + var upsConfig *proxy.UpstreamConfig + upsConfig, err = proxy.ParseUpstreamsConfig( + localAddrs, + &upstream.Options{ + Bootstrap: bootstraps, + Timeout: defaultLocalTimeout, + // TODO(e.burkov): Should we verify server's ceritificates? + }, + ) if err != nil { return fmt.Errorf("parsing upstreams: %w", err) } s.localResolvers = &proxy.Proxy{ Config: proxy.Config{ - UpstreamConfig: &upsConfig, + UpstreamConfig: upsConfig, }, } @@ -577,11 +588,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -// IsBlockedIP - return TRUE if this client should be blocked -func (s *Server) IsBlockedIP(ip net.IP) (bool, string) { - if ip == nil { - return false, "" +// IsBlockedClient returns true if the client is blocked by the current access +// settings. +func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) { + s.serverLock.RLock() + defer s.serverLock.RUnlock() + + allowlistMode := s.access.allowlistMode() + blockedByIP, rule := s.access.isBlockedIP(ip) + blockedByClientID := s.access.isBlockedClientID(clientID) + + // Allow if at least one of the checks allows in allowlist mode, but + // block if at least one of the checks blocks in blocklist mode. + if allowlistMode && blockedByIP && blockedByClientID { + log.Debug("client %s (id %q) is not in access allowlist", ip, clientID) + + // Return now without substituting the empty rule for the + // clientID because the rule can't be empty here. + return true, rule + } else if !allowlistMode && (blockedByIP || blockedByClientID) { + log.Debug("client %s (id %q) is in access blocklist", ip, clientID) + + blocked = true } - return s.access.IsBlockedIP(ip) + if rule == "" { + rule = clientID + } + + return blocked, rule } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index a7ed4fc4..e42ad30c 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -257,19 +257,22 @@ func TestServer(t *testing.T) { testCases := []struct { name string - proto string + net string + proto proxy.Proto }{{ name: "message_over_udp", + net: "", proto: proxy.ProtoUDP, }, { name: "message_over_tcp", + net: "tcp", proto: proxy.ProtoTCP, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { addr := s.dnsProxy.Addr(tc.proto) - client := dns.Client{Net: tc.proto} + client := dns.Client{Net: tc.net} reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String()) require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err) @@ -324,7 +327,7 @@ func TestServerWithProtectionDisabled(t *testing.T) { // Message over UDP. req := createGoogleATestMessage() addr := s.dnsProxy.Addr(proxy.ProtoUDP) - client := dns.Client{Net: proxy.ProtoUDP} + client := &dns.Client{} reply, _, err := client.Exchange(req, addr.String()) require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err) @@ -376,7 +379,7 @@ func TestDoQServer(t *testing.T) { // Create a DNS-over-QUIC upstream. addr := s.dnsProxy.Addr(proxy.ProtoQUIC) - opts := upstream.Options{InsecureSkipVerify: true} + opts := &upstream.Options{InsecureSkipVerify: true} u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts) require.NoError(t, err) @@ -420,7 +423,7 @@ func TestServerRace(t *testing.T) { // Message over UDP. addr := s.dnsProxy.Addr(proxy.ProtoUDP) - conn, err := dns.Dial(proxy.ProtoUDP, addr.String()) + conn, err := dns.Dial("udp", addr.String()) require.NoErrorf(t, err, "cannot connect to the proxy: %s", err) sendTestMessagesAsync(t, conn) @@ -445,7 +448,7 @@ func TestSafeSearch(t *testing.T) { startDeferStop(t, s) addr := s.dnsProxy.Addr(proxy.ProtoUDP).String() - client := dns.Client{Net: proxy.ProtoUDP} + client := &dns.Client{} yandexIP := net.IP{213, 180, 193, 56} googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com") @@ -507,7 +510,6 @@ func TestInvalidRequest(t *testing.T) { // Send a DNS request without question. _, _, err := (&dns.Client{ - Net: proxy.ProtoUDP, Timeout: 500 * time.Millisecond, }).Exchange(&req, addr) diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index baabded2..d7510eeb 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -1,6 +1,7 @@ package dnsforward import ( + "encoding/binary" "fmt" "strings" @@ -11,23 +12,39 @@ import ( "github.com/miekg/dns" ) -func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) { - ip := aghnet.IPFromAddr(d.Addr) - disallowed, _ := s.access.IsBlockedIP(ip) - if disallowed { - log.Tracef("Client IP %s is blocked by settings", ip) +// beforeRequestHandler is the handler that is called before any other +// processing, including logs. It performs access checks and puts the client +// ID, if there is one, into the server's cache. +func (s *Server) beforeRequestHandler( + _ *proxy.Proxy, + pctx *proxy.DNSContext, +) (reply bool, err error) { + ip := aghnet.IPFromAddr(pctx.Addr) + clientID, err := s.clientIDFromDNSContext(pctx) + if err != nil { + return false, fmt.Errorf("getting clientid: %w", err) + } + + blocked, _ := s.IsBlockedClient(ip, clientID) + if blocked { return false, nil } - if len(d.Req.Question) == 1 { - host := strings.TrimSuffix(d.Req.Question[0].Name, ".") - if s.access.IsBlockedDomain(host) { - log.Tracef("domain %s is blocked by access settings", host) + if len(pctx.Req.Question) == 1 { + host := strings.TrimSuffix(pctx.Req.Question[0].Name, ".") + if s.access.isBlockedHost(host) { + log.Debug("host %s is in access blocklist", host) return false, nil } } + if clientID != "" { + key := [8]byte{} + binary.BigEndian.PutUint64(key[:], pctx.RequestID) + s.clientIDCache.Set(key[:], []byte(clientID)) + } + return true, nil } diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 4dc74f97..3f2ccf55 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (string, error) { return boot, fmt.Errorf("invalid bootstrap server address: empty") } - if _, err := upstream.NewResolver(boot, upstream.Options{Timeout: 0}); err != nil { + if _, err := upstream.NewResolver(boot, nil); err != nil { return boot, fmt.Errorf("invalid bootstrap server address: %w", err) } } @@ -348,7 +348,7 @@ func ValidateUpstreams(upstreams []string) (err error) { _, err = proxy.ParseUpstreamsConfig( upstreams, - upstream.Options{ + &upstream.Options{ Bootstrap: []string{}, Timeout: DefaultTimeout, }, @@ -546,7 +546,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun log.Debug("checking if dns server %q works...", input) var u upstream.Upstream - u, err = upstream.AddressToUpstream(input, upstream.Options{ + u, err = upstream.AddressToUpstream(input, &upstream.Options{ Bootstrap: bootstrap, Timeout: timeout, }) diff --git a/internal/dnsforward/stats_test.go b/internal/dnsforward/stats_test.go index 2ae7913b..64171288 100644 --- a/internal/dnsforward/stats_test.go +++ b/internal/dnsforward/stats_test.go @@ -46,7 +46,7 @@ func (l *testStats) Update(e stats.Entry) { func TestProcessQueryLogsAndStats(t *testing.T) { testCases := []struct { name string - proto string + proto proxy.Proto addr net.Addr clientID string wantLogProto querylog.ClientProto @@ -156,7 +156,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) { wantStatResult: stats.RParental, }} - ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{}) + ups, err := upstream.AddressToUpstream("1.1.1.1", nil) require.Nil(t, err) for _, tc := range testCases { diff --git a/internal/filtering/safebrowsing.go b/internal/filtering/safebrowsing.go index f14d20d9..e2c86c5e 100644 --- a/internal/filtering/safebrowsing.go +++ b/internal/filtering/safebrowsing.go @@ -49,7 +49,7 @@ func (d *DNSFilter) initSecurityServices() error { var err error d.safeBrowsingServer = defaultSafebrowsingServer d.parentalServer = defaultParentalServer - opts := upstream.Options{ + opts := &upstream.Options{ Timeout: dnsTimeout, ServerIPAddrs: []net.IP{ {94, 140, 14, 15}, diff --git a/internal/home/clients.go b/internal/home/clients.go index 2358003e..c33edbab 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -78,10 +78,13 @@ type RuntimeClientWHOISInfo struct { type clientsContainer struct { // TODO(a.garipov): Perhaps use a number of separate indices for // different types (string, net.IP, and so on). - list map[string]*Client // name -> client - idIndex map[string]*Client // ID -> client - ipToRC map[string]*RuntimeClient // IP -> runtime client - lock sync.Mutex + list map[string]*Client // name -> client + idIndex map[string]*Client // ID -> client + + // ipToRC is the IP address to *RuntimeClient map. + ipToRC *aghnet.IPMap + + lock sync.Mutex allTags *aghstrings.Set @@ -109,7 +112,7 @@ func (clients *clientsContainer) Init( } clients.list = make(map[string]*Client) clients.idIndex = make(map[string]*Client) - clients.ipToRC = make(map[string]*RuntimeClient) + clients.ipToRC = aghnet.NewIPMap(0) clients.allTags = aghstrings.NewSet(clientTags...) @@ -250,18 +253,17 @@ func (clients *clientsContainer) onHostsChanged() { clients.addFromHostsFile() } -// Exists checks if client with this ID already exists. -func (clients *clientsContainer) Exists(id string, source clientSource) (ok bool) { +// Exists checks if client with this IP address already exists. +func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) { clients.lock.Lock() defer clients.lock.Unlock() - _, ok = clients.findLocked(id) + _, ok = clients.findLocked(ip.String()) if ok { return true } - var rc *RuntimeClient - rc, ok = clients.ipToRC[id] + rc, ok := clients.findRuntimeClientLocked(ip) if !ok { return false } @@ -288,13 +290,14 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client, for _, id := range ids { var name string whois := &querylog.ClientWHOIS{} + ip := net.ParseIP(id) c, ok := clients.Find(id) if ok { name = c.Name - } else { - var rc RuntimeClient - rc, ok = clients.FindRuntimeClient(id) + } else if ip != nil { + var rc *RuntimeClient + rc, ok = clients.FindRuntimeClient(ip) if !ok { continue } @@ -303,8 +306,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client, whois = toQueryLogWHOIS(rc.WHOISInfo) } - ip := net.ParseIP(id) - disallowed, disallowedRule := clients.dnsServer.IsBlockedIP(ip) + disallowed, disallowedRule := clients.dnsServer.IsBlockedClient(ip, id) return &querylog.Client{ Name: name, @@ -356,10 +358,10 @@ func (clients *clientsContainer) findUpstreams( return c.upstreamConfig, nil } - var conf proxy.UpstreamConfig + var conf *proxy.UpstreamConfig conf, err = proxy.ParseUpstreamsConfig( upstreams, - upstream.Options{ + &upstream.Options{ Bootstrap: config.DNS.BootstrapDNS, Timeout: config.DNS.UpstreamTimeout.Duration, }, @@ -368,9 +370,9 @@ func (clients *clientsContainer) findUpstreams( return nil, err } - c.upstreamConfig = &conf + c.upstreamConfig = conf - return &conf, nil + return conf, nil } // findLocked searches for a client by its ID. For internal use only. @@ -423,22 +425,35 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) { return nil, false } +// findRuntimeClientLocked finds a runtime client by their IP address. For +// internal use only. +func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) { + var v interface{} + v, ok = clients.ipToRC.Get(ip) + if !ok { + return nil, false + } + + rc, ok = v.(*RuntimeClient) + if !ok { + log.Error("clients: bad type %T in ipToRC for %s", v, ip) + + return nil, false + } + + return rc, true +} + // FindRuntimeClient finds a runtime client by their IP. -func (clients *clientsContainer) FindRuntimeClient(ip string) (RuntimeClient, bool) { - ipAddr := net.ParseIP(ip) - if ipAddr == nil { - return RuntimeClient{}, false +func (clients *clientsContainer) FindRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) { + if ip == nil { + return nil, false } clients.lock.Lock() defer clients.lock.Unlock() - rc, ok := clients.ipToRC[ip] - if ok { - return *rc, true - } - - return RuntimeClient{}, false + return clients.findRuntimeClientLocked(ip) } // check validates the client. @@ -621,17 +636,17 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) { } // SetWHOISInfo sets the WHOIS information for a client. -func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISInfo) { +func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) { clients.lock.Lock() defer clients.lock.Unlock() - _, ok := clients.findLocked(ip) + _, ok := clients.findLocked(ip.String()) if ok { log.Debug("clients: client for %s is already created, ignore whois info", ip) return } - rc, ok := clients.ipToRC[ip] + rc, ok := clients.findRuntimeClientLocked(ip) if ok { rc.WHOISInfo = wi log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi) @@ -646,14 +661,15 @@ func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISI } rc.WHOISInfo = wi - clients.ipToRC[ip] = rc + + clients.ipToRC.Set(ip, rc) log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) } // AddHost adds a new IP-hostname pairing. The priorities of the sources is // taken into account. ok is true if the pairing was added. -func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) { +func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) { clients.lock.Lock() defer clients.lock.Unlock() @@ -663,9 +679,9 @@ func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok } // addHostLocked adds a new IP-hostname pairing. For internal use only. -func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource) (ok bool) { +func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) { var rc *RuntimeClient - rc, ok = clients.ipToRC[ip] + rc, ok = clients.findRuntimeClientLocked(ip) if ok { if rc.Source > src { return false @@ -679,10 +695,10 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource WHOISInfo: &RuntimeClientWHOISInfo{}, } - clients.ipToRC[ip] = rc + clients.ipToRC.Set(ip, rc) } - log.Debug("clients: added %q -> %q [%d]", ip, host, len(clients.ipToRC)) + log.Debug("clients: added %s -> %q [%d]", ip, host, clients.ipToRC.Len()) return true } @@ -690,12 +706,21 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource // rmHostsBySrc removes all entries that match the specified source. func (clients *clientsContainer) rmHostsBySrc(src clientSource) { n := 0 - for k, v := range clients.ipToRC { - if v.Source == src { - delete(clients.ipToRC, k) + clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) { + rc, ok := v.(*RuntimeClient) + if !ok { + log.Error("clients: bad type %T in ipToRC for %s", v, ip) + + return true + } + + if rc.Source == src { + clients.ipToRC.Del(ip) n++ } - } + + return true + }) log.Debug("clients: removed %d client aliases", n) } @@ -715,16 +740,23 @@ func (clients *clientsContainer) addFromHostsFile() { clients.rmHostsBySrc(ClientSourceHostsFile) n := 0 - for ip, names := range hosts { + hosts.Range(func(ip net.IP, v interface{}) (cont bool) { + names, ok := v.([]string) + if !ok { + log.Error("dns: bad type %T in ipToRC for %s", v, ip) + } + for _, name := range names { - ok := clients.addHostLocked(ip, name, ClientSourceHostsFile) + ok = clients.addHostLocked(ip, name, ClientSourceHostsFile) if ok { n++ } } - } - log.Debug("Clients: added %d client aliases from system hosts-file", n) + return true + }) + + log.Debug("clients: added %d client aliases from system hosts-file", n) } // addFromSystemARP adds the IP-hostname pairings from the output of the arp -a @@ -752,15 +784,16 @@ func (clients *clientsContainer) addFromSystemARP() { // TODO(a.garipov): Rewrite to use bufio.Scanner. lines := strings.Split(string(data), "\n") for _, ln := range lines { - open := strings.Index(ln, " (") - close := strings.Index(ln, ") ") - if open == -1 || close == -1 || open >= close { + lparen := strings.Index(ln, " (") + rparen := strings.Index(ln, ") ") + if lparen == -1 || rparen == -1 || lparen >= rparen { continue } - host := ln[:open] - ip := ln[open+2 : close] - if aghnet.ValidateDomainName(host) != nil || net.ParseIP(ip) == nil { + host := ln[:lparen] + ipStr := ln[lparen+2 : rparen] + ip := net.ParseIP(ipStr) + if aghnet.ValidateDomainName(host) != nil || ip == nil { continue } @@ -796,7 +829,7 @@ func (clients *clientsContainer) updateFromDHCP(add bool) { continue } - ok := clients.addHostLocked(l.IP.String(), l.Hostname, ClientSourceDHCP) + ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP) if ok { n++ } diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index 18f4f662..392a7545 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -26,6 +26,7 @@ func TestClients(t *testing.T) { ok, err := clients.Add(c) require.NoError(t, err) + assert.True(t, ok) c = &Client{ @@ -35,23 +36,27 @@ func TestClients(t *testing.T) { ok, err = clients.Add(c) require.NoError(t, err) + assert.True(t, ok) c, ok = clients.Find("1.1.1.1") require.True(t, ok) + assert.Equal(t, "client1", c.Name) c, ok = clients.Find("1:2:3::4") require.True(t, ok) + assert.Equal(t, "client1", c.Name) c, ok = clients.Find("2.2.2.2") require.True(t, ok) + assert.Equal(t, "client2", c.Name) - assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile)) - assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) - assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile)) + assert.False(t, clients.Exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile)) }) t.Run("add_fail_name", func(t *testing.T) { @@ -101,8 +106,8 @@ func TestClients(t *testing.T) { }) require.NoError(t, err) - assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) - assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) + assert.False(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile)) + assert.True(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile)) err = clients.Update("client1", &Client{ IDs: []string{"1.1.1.2"}, @@ -113,21 +118,25 @@ func TestClients(t *testing.T) { c, ok := clients.Find("1.1.1.2") require.True(t, ok) + assert.Equal(t, "client1-renamed", c.Name) assert.True(t, c.UseOwnSettings) nilCli, ok := clients.list["client1"] require.False(t, ok) + assert.Nil(t, nilCli) require.Len(t, c.IDs, 1) + assert.Equal(t, "1.1.1.2", c.IDs[0]) }) t.Run("del_success", func(t *testing.T) { ok := clients.Del("client1-renamed") require.True(t, ok) - assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) + + assert.False(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile)) }) t.Run("del_fail", func(t *testing.T) { @@ -136,37 +145,44 @@ func TestClients(t *testing.T) { }) t.Run("addhost_success", func(t *testing.T) { - ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP) + ip := net.IP{1, 1, 1, 1} + + ok, err := clients.AddHost(ip, "host", ClientSourceARP) require.NoError(t, err) + assert.True(t, ok) - ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP) + ok, err = clients.AddHost(ip, "host2", ClientSourceARP) require.NoError(t, err) + assert.True(t, ok) - ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile) + ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile) require.NoError(t, err) + assert.True(t, ok) - assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) + assert.True(t, clients.Exists(ip, ClientSourceHostsFile)) }) t.Run("dhcp_replaces_arp", func(t *testing.T) { - ok, err := clients.AddHost("1.2.3.4", "from_arp", ClientSourceARP) + ip := net.IP{1, 2, 3, 4} + + ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP) require.NoError(t, err) + assert.True(t, ok) + assert.True(t, clients.Exists(ip, ClientSourceARP)) - assert.True(t, clients.Exists("1.2.3.4", ClientSourceARP)) - - ok, err = clients.AddHost("1.2.3.4", "from_dhcp", ClientSourceDHCP) + ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP) require.NoError(t, err) - assert.True(t, ok) - assert.True(t, clients.Exists("1.2.3.4", ClientSourceDHCP)) + assert.True(t, ok) + assert.True(t, clients.Exists(ip, ClientSourceDHCP)) }) t.Run("addhost_fail", func(t *testing.T) { - ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS) + ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS) require.NoError(t, err) assert.False(t, ok) }) @@ -183,31 +199,39 @@ func TestClientsWHOIS(t *testing.T) { } t.Run("new_client", func(t *testing.T) { - clients.SetWHOISInfo("1.1.1.255", whois) + ip := net.IP{1, 1, 1, 255} + clients.SetWHOISInfo(ip, whois) + v, _ := clients.ipToRC.Get(ip) + require.NotNil(t, v) - require.NotNil(t, clients.ipToRC["1.1.1.255"]) + rc, ok := v.(*RuntimeClient) + require.True(t, ok) + require.NotNil(t, rc) - h := clients.ipToRC["1.1.1.255"] - require.NotNil(t, h) - - assert.Equal(t, h.WHOISInfo, whois) + assert.Equal(t, rc.WHOISInfo, whois) }) t.Run("existing_auto-client", func(t *testing.T) { - ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) + ip := net.IP{1, 1, 1, 1} + ok, err := clients.AddHost(ip, "host", ClientSourceRDNS) require.NoError(t, err) + assert.True(t, ok) - clients.SetWHOISInfo("1.1.1.1", whois) + clients.SetWHOISInfo(ip, whois) + v, _ := clients.ipToRC.Get(ip) + require.NotNil(t, v) - require.NotNil(t, clients.ipToRC["1.1.1.1"]) - h := clients.ipToRC["1.1.1.1"] - require.NotNil(t, h) + rc, ok := v.(*RuntimeClient) + require.True(t, ok) + require.NotNil(t, rc) - assert.Equal(t, h.WHOISInfo, whois) + assert.Equal(t, rc.WHOISInfo, whois) }) t.Run("can't_set_manually-added", func(t *testing.T) { + ip := net.IP{1, 1, 1, 2} + ok, err := clients.Add(&Client{ IDs: []string{"1.1.1.2"}, Name: "client1", @@ -215,8 +239,10 @@ func TestClientsWHOIS(t *testing.T) { require.NoError(t, err) assert.True(t, ok) - clients.SetWHOISInfo("1.1.1.2", whois) - require.Nil(t, clients.ipToRC["1.1.1.2"]) + clients.SetWHOISInfo(ip, whois) + v, _ := clients.ipToRC.Get(ip) + require.Nil(t, v) + assert.True(t, clients.Del("client1")) }) } @@ -228,16 +254,18 @@ func TestClientsAddExisting(t *testing.T) { clients.Init(nil, nil, nil) t.Run("simple", func(t *testing.T) { + ip := net.IP{1, 1, 1, 1} + // Add a client. ok, err := clients.Add(&Client{ - IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"}, + IDs: []string{ip.String(), "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"}, Name: "client1", }) require.NoError(t, err) assert.True(t, ok) // Now add an auto-client with the same IP. - ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS) + ok, err = clients.AddHost(ip, "test", ClientSourceRDNS) require.NoError(t, err) assert.True(t, ok) }) @@ -245,7 +273,7 @@ func TestClientsAddExisting(t *testing.T) { t.Run("complicated", func(t *testing.T) { var err error - testIP := net.IP{1, 2, 3, 4} + ip := net.IP{1, 2, 3, 4} // First, init a DHCP server with a single static lease. config := dhcpd.ServerConfig{ @@ -267,7 +295,7 @@ func TestClientsAddExisting(t *testing.T) { err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{ HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, - IP: testIP, + IP: ip, Hostname: "testhost", Expiry: time.Now().Add(time.Hour), }) @@ -275,7 +303,7 @@ func TestClientsAddExisting(t *testing.T) { // Add a new client with the same IP as for a client with MAC. ok, err := clients.Add(&Client{ - IDs: []string{testIP.String()}, + IDs: []string{ip.String()}, Name: "client2", }) require.NoError(t, err) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 0ac15a4c..412ff002 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -5,6 +5,8 @@ import ( "fmt" "net" "net/http" + + "github.com/AdguardTeam/golibs/log" ) // clientJSON is a common structure used by several handlers to deal with @@ -44,13 +46,13 @@ type clientJSON struct { type runtimeClientJSON struct { WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"` - IP string `json:"ip"` Name string `json:"name"` Source string `json:"source"` + IP net.IP `json:"ip"` } type clientListJSON struct { - Clients []clientJSON `json:"clients"` + Clients []*clientJSON `json:"clients"` RuntimeClients []runtimeClientJSON `json:"auto_clients"` Tags []string `json:"supported_tags"` } @@ -66,11 +68,20 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http cj := clientToJSON(c) data.Clients = append(data.Clients, cj) } - for ip, rc := range clients.ipToRC { + + clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) { + rc, ok := v.(*RuntimeClient) + if !ok { + log.Error("dns: bad type %T in ipToRC for %s", v, ip) + + return true + } + cj := runtimeClientJSON{ - IP: ip, - Name: rc.Host, WHOISInfo: rc.WHOISInfo, + + Name: rc.Host, + IP: ip, } cj.Source = "etc/hosts" @@ -86,7 +97,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http } data.RuntimeClients = append(data.RuntimeClients, cj) - } + + return true + }) data.Tags = clientTags @@ -118,8 +131,8 @@ func jsonToClient(cj clientJSON) (c *Client) { } // Convert Client object to JSON -func clientToJSON(c *Client) clientJSON { - cj := clientJSON{ +func clientToJSON(c *Client) (cj *clientJSON) { + return &clientJSON{ Name: c.Name, IDs: c.IDs, Tags: c.Tags, @@ -134,19 +147,6 @@ func clientToJSON(c *Client) clientJSON { Upstreams: c.Upstreams, } - - return cj -} - -// runtimeClientToJSON converts a RuntimeClient into a JSON struct. -func runtimeClientToJSON(ip string, rc RuntimeClient) (cj clientJSON) { - cj = clientJSON{ - Name: rc.Host, - IDs: []string{ip}, - WHOISInfo: rc.WHOISInfo, - } - - return cj } // Add a new client @@ -230,7 +230,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht // Get the list of clients by IP address list func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() - data := []map[string]clientJSON{} + data := []map[string]*clientJSON{} for i := 0; i < len(q); i++ { idStr := q.Get(fmt.Sprintf("ip%d", i)) if idStr == "" { @@ -239,20 +239,16 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http ip := net.ParseIP(idStr) c, ok := clients.Find(idStr) - var cj clientJSON + var cj *clientJSON if !ok { - var found bool - cj, found = clients.findRuntime(ip, idStr) - if !found { - continue - } + cj = clients.findRuntime(ip, idStr) } else { cj = clientToJSON(c) - disallowed, rule := clients.dnsServer.IsBlockedIP(ip) + disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) cj.Disallowed, cj.DisallowedRule = &disallowed, &rule } - data = append(data, map[string]clientJSON{ + data = append(data, map[string]*clientJSON{ idStr: cj, }) } @@ -265,39 +261,37 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http } // findRuntime looks up the IP in runtime and temporary storages, like -// /etc/hosts tables, DHCP leases, or blocklists. -func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj clientJSON, found bool) { - if ip == nil { - return cj, false - } - - rc, ok := clients.FindRuntimeClient(idStr) +// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be +// non-nil. +func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) { + rc, ok := clients.FindRuntimeClient(ip) if !ok { // It is still possible that the IP used to be in the runtime // clients list, but then the server was reloaded. So, check // the DNS server's blocked IP list. // // See https://github.com/AdguardTeam/AdGuardHome/issues/2428. - disallowed, rule := clients.dnsServer.IsBlockedIP(ip) - if rule == "" { - return clientJSON{}, false - } - - cj = clientJSON{ + disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) + cj = &clientJSON{ IDs: []string{idStr}, Disallowed: &disallowed, DisallowedRule: &rule, WHOISInfo: &RuntimeClientWHOISInfo{}, } - return cj, true + return cj } - cj = runtimeClientToJSON(idStr, rc) - disallowed, rule := clients.dnsServer.IsBlockedIP(ip) + cj = &clientJSON{ + Name: rc.Host, + IDs: []string{idStr}, + WHOISInfo: rc.WHOISInfo, + } + + disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr) cj.Disallowed, cj.DisallowedRule = &disallowed, &rule - return cj, true + return cj } // RegisterClientsHandlers registers HTTP handlers diff --git a/internal/home/dns.go b/internal/home/dns.go index a61afc27..c531d1c2 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -105,8 +105,8 @@ func isRunning() bool { return Context.dnsServer != nil && Context.dnsServer.IsRunning() } -func onDNSRequest(d *proxy.DNSContext) { - ip := aghnet.IPFromAddr(d.Addr) +func onDNSRequest(pctx *proxy.DNSContext) { + ip := aghnet.IPFromAddr(pctx.Addr) if ip == nil { // This would be quite weird if we get here. return diff --git a/internal/home/home.go b/internal/home/home.go index 683c267c..a6ee025e 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -503,7 +503,7 @@ Please note, that this is crucial for a server to be able to use privileged port You have two options: 1. Run AdGuard Home with root privileges 2. On Linux you can grant the CAP_NET_BIND_SERVICE capability: -https://github.com/AdguardTeam/AdGuardHome/internal/wiki/Getting-Started#running-without-superuser` +https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser` log.Fatal(msg) } diff --git a/internal/home/rdns.go b/internal/home/rdns.go index d19af7c8..cba748af 100644 --- a/internal/home/rdns.go +++ b/internal/home/rdns.go @@ -102,12 +102,7 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) { func (r *RDNS) Begin(ip net.IP) { r.ensurePrivateCache() - if r.isCached(ip) { - return - } - - id := ip.String() - if r.clients.Exists(id, ClientSourceRDNS) { + if r.isCached(ip) || r.clients.Exists(ip, ClientSourceRDNS) { return } @@ -138,6 +133,6 @@ func (r *RDNS) workerLoop() { // Don't handle any errors since AddHost doesn't return non-nil // errors for now. - _, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS) + _, _ = r.clients.AddHost(ip, host, ClientSourceRDNS) } } diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 3655b6b4..a4a37c2c 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/dnsproxy/upstream" @@ -84,7 +85,7 @@ func TestRDNS_Begin(t *testing.T) { clients: &clientsContainer{ list: map[string]*Client{}, idIndex: tc.cliIDIndex, - ipToRC: map[string]*RuntimeClient{}, + ipToRC: aghnet.NewIPMap(0), allTags: aghstrings.NewSet(), }, } @@ -204,7 +205,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { cc := &clientsContainer{ list: map[string]*Client{}, idIndex: map[string]*Client{}, - ipToRC: map[string]*RuntimeClient{}, + ipToRC: aghnet.NewIPMap(0), allTags: aghstrings.NewSet(), } ch := make(chan net.IP) @@ -236,7 +237,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { return } - assert.True(t, cc.Exists(tc.cliIP.String(), ClientSourceRDNS)) + assert.True(t, cc.Exists(tc.cliIP, ClientSourceRDNS)) }) } } diff --git a/internal/home/whois.go b/internal/home/whois.go index b80dd794..040e7210 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -252,7 +252,6 @@ func (w *WHOIS) workerLoop() { continue } - id := ip.String() - w.clients.SetWHOISInfo(id, info) + w.clients.SetWHOISInfo(ip, info) } } diff --git a/internal/stats/unit.go b/internal/stats/unit.go index 8997da28..ca4199f6 100644 --- a/internal/stats/unit.go +++ b/internal/stats/unit.go @@ -720,7 +720,10 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP { a := convertMapToSlice(m, int(maxCount)) d := []net.IP{} for _, it := range a { - d = append(d, net.ParseIP(it.Name)) + ip := net.ParseIP(it.Name) + if ip != nil { + d = append(d, ip) + } } return d } diff --git a/openapi/CHANGELOG.md b/openapi/CHANGELOG.md index 4744f720..faf4c6b3 100644 --- a/openapi/CHANGELOG.md +++ b/openapi/CHANGELOG.md @@ -4,6 +4,11 @@ ## v0.107: API changes +### Client IDs in Access Settings + +* The `POST /control/access/set` HTTP API now accepts client IDs in + `"allowed_clients"` and `"disallowed_clients"` fields. + ### The new field `"unicode_name"` in `DNSQuestion` * The new optional field `"unicode_name"` is the Unicode representation of @@ -17,7 +22,7 @@ ### Disabling Statistics -* The API `POST /control/stats_config` HTTP API allows disabling statistics by +* The `POST /control/stats_config` HTTP API allows disabling statistics by setting `"interval"` to `0`. ### `POST /control/dhcp/reset_leases` diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 5410427b..06db336e 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -1957,10 +1957,7 @@ 'disallowed_rule': 'type': 'string' 'description': > - The rule due to which the client is disallowed. If disallowed is - set to true, and this string is empty, then the client IP is - disallowed by the "allowed IP list", that is it is not included in - the allowed list. + The rule due to which the client is allowed or blocked. 'name': 'description': > Persistent client's name or an empty string if this is a runtime @@ -2352,17 +2349,19 @@ 'description': 'Client and host access list' 'properties': 'allowed_clients': - 'description': 'Allowlist of clients.' + 'description': > + The allowlist of clients: IP addresses, CIDRs, or client IDs. 'items': 'type': 'string' 'type': 'array' 'disallowed_clients': - 'description': 'Blocklist of clients.' + 'description': > + The blocklist of clients: IP addresses, CIDRs, or client IDs. 'items': 'type': 'string' 'type': 'array' 'blocked_hosts': - 'description': 'Blocklist of hosts.' + 'description': 'The blocklist of hosts.' 'items': 'type': 'string' 'type': 'array'