Pull request: all: allow clientid in access settings
Updates #2624. Updates #3162. Squashed commit of the following: commit 68860da717a23a0bfeba14b7fe10b5e4ad38726d Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:41:33 2021 +0300 all: imp types, names commit ebd4ec26636853d0d58c4e331e6a78feede20813 Merge: 239eb72116e5e09c
Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:14:33 2021 +0300 Merge branch 'master' into 2624-clientid-access commit 239eb7215abc47e99a0300a0f4cf56002689b1a9 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:13:10 2021 +0300 all: fix client blocking check commit e6bece3ea8367b3cbe3d90702a3368c870ad4f13 Merge: 9935f2a39d1656b5
Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 13:12:28 2021 +0300 Merge branch 'master' into 2624-clientid-access commit 9935f2a30bcfae2b853f3ef610c0ab7a56a8f448 Author: Ildar Kamalov <ik@adguard.com> Date: Tue Jun 29 11:26:51 2021 +0300 client: show block button for client id commit ed786a6a74a081cd89e9d67df3537a4fadd54831 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 15:56:23 2021 +0300 client: imp i18n commit 4fed21c68473ad408960c08a7d87624cabce1911 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 15:34:09 2021 +0300 all: imp i18n, docs commit 55e65c0d6b939560c53dcb834a4557eb3853d194 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 13:34:01 2021 +0300 all: fix cache, imp code, docs, tests commit c1e5a83e76deb44b1f92729bb9ddfcc6a96ac4a8 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Jun 24 19:27:12 2021 +0300 all: allow clientid in access settings
This commit is contained in:
parent
16e5e09c2e
commit
e08a64ebe4
CHANGELOG.mdHACKING.md
client/src
go.modgo.suminternal
aghnet
dnsforward
access.goaccess_test.goclientid.goclientid_test.goconfig.godns.godnsforward.godnsforward_test.gofilter.gohttp.gostats_test.go
filtering
home
stats
openapi
|
@ -15,6 +15,7 @@ and this project adheres to
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
||||||
|
- Blocking access using client IDs ([#2624], [#3162]).
|
||||||
- `source` directives support in `/etc/network/interfaces` on Linux ([#3257]).
|
- `source` directives support in `/etc/network/interfaces` on Linux ([#3257]).
|
||||||
- RFC 9000 support in DNS-over-QUIC.
|
- RFC 9000 support in DNS-over-QUIC.
|
||||||
- Completely disabling statistics by setting the statistics interval to zero
|
- Completely disabling statistics by setting the statistics interval to zero
|
||||||
|
@ -80,9 +81,11 @@ released by then.
|
||||||
[#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439
|
[#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439
|
||||||
[#2441]: https://github.com/AdguardTeam/AdGuardHome/issues/2441
|
[#2441]: https://github.com/AdguardTeam/AdGuardHome/issues/2441
|
||||||
[#2443]: https://github.com/AdguardTeam/AdGuardHome/issues/2443
|
[#2443]: https://github.com/AdguardTeam/AdGuardHome/issues/2443
|
||||||
|
[#2624]: https://github.com/AdguardTeam/AdGuardHome/issues/2624
|
||||||
[#2763]: https://github.com/AdguardTeam/AdGuardHome/issues/2763
|
[#2763]: https://github.com/AdguardTeam/AdGuardHome/issues/2763
|
||||||
[#3013]: https://github.com/AdguardTeam/AdGuardHome/issues/3013
|
[#3013]: https://github.com/AdguardTeam/AdGuardHome/issues/3013
|
||||||
[#3136]: https://github.com/AdguardTeam/AdGuardHome/issues/3136
|
[#3136]: https://github.com/AdguardTeam/AdGuardHome/issues/3136
|
||||||
|
[#3162]: https://github.com/AdguardTeam/AdGuardHome/issues/3162
|
||||||
[#3166]: https://github.com/AdguardTeam/AdGuardHome/issues/3166
|
[#3166]: https://github.com/AdguardTeam/AdGuardHome/issues/3166
|
||||||
[#3172]: https://github.com/AdguardTeam/AdGuardHome/issues/3172
|
[#3172]: https://github.com/AdguardTeam/AdGuardHome/issues/3172
|
||||||
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
|
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
|
||||||
|
|
|
@ -159,8 +159,10 @@ attributes to make it work in Markdown renderers that strip "id". -->
|
||||||
|
|
||||||
* Minimize scope of variables as much as possible.
|
* Minimize scope of variables as much as possible.
|
||||||
|
|
||||||
* No shadowing, since it can often lead to subtle bugs, especially with
|
* No name shadowing, including of predeclared identifiers, since it can often
|
||||||
errors.
|
lead to subtle bugs, especially with errors. This rule does not apply to
|
||||||
|
struct fields, since they are always used together with the name of the
|
||||||
|
struct value, so there isn't any confusion.
|
||||||
|
|
||||||
* Prefer constants to variables where possible. Avoid global variables. Use
|
* Prefer constants to variables where possible. Avoid global variables. Use
|
||||||
[constant errors] instead of `errors.New`.
|
[constant errors] instead of `errors.New`.
|
||||||
|
|
|
@ -426,9 +426,9 @@
|
||||||
"access_title": "Access settings",
|
"access_title": "Access settings",
|
||||||
"access_desc": "Here you can configure access rules for the AdGuard Home DNS server.",
|
"access_desc": "Here you can configure access rules for the AdGuard Home DNS server.",
|
||||||
"access_allowed_title": "Allowed clients",
|
"access_allowed_title": "Allowed clients",
|
||||||
"access_allowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will accept requests from these IP addresses only.",
|
"access_allowed_desc": "A list of CIDRs, IP addresses, or client IDs. If configured, AdGuard Home will accept requests only from these clients.",
|
||||||
"access_disallowed_title": "Disallowed clients",
|
"access_disallowed_title": "Disallowed clients",
|
||||||
"access_disallowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will drop requests from these IP addresses.",
|
"access_disallowed_desc": "A list of CIDRs, IP addresses, or client IDs. If configured, AdGuard Home will drop requests from these clients. If allowed clients are configured, this field is ignored.",
|
||||||
"access_blocked_title": "Disallowed domains",
|
"access_blocked_title": "Disallowed domains",
|
||||||
"access_blocked_desc": "Not to be confused with filters. AdGuard Home drops DNS queries matching these domains, and these queries don't even appear in the query log. You can specify exact domain names, wildcards, or URL filter rules, e.g. \"example.org\", \"*.example.org\", or \"||example.org^\" correspondingly.",
|
"access_blocked_desc": "Not to be confused with filters. AdGuard Home drops DNS queries matching these domains, and these queries don't even appear in the query log. You can specify exact domain names, wildcards, or URL filter rules, e.g. \"example.org\", \"*.example.org\", or \"||example.org^\" correspondingly.",
|
||||||
"access_settings_saved": "Access settings successfully saved",
|
"access_settings_saved": "Access settings successfully saved",
|
||||||
|
|
|
@ -9,7 +9,7 @@ import Card from '../ui/Card';
|
||||||
import Cell from '../ui/Cell';
|
import Cell from '../ui/Cell';
|
||||||
|
|
||||||
import { getPercent, sortIp } from '../../helpers/helpers';
|
import { getPercent, sortIp } from '../../helpers/helpers';
|
||||||
import { BLOCK_ACTIONS, R_CLIENT_ID, STATUS_COLORS } from '../../helpers/constants';
|
import { BLOCK_ACTIONS, STATUS_COLORS } from '../../helpers/constants';
|
||||||
import { toggleClientBlock } from '../../actions/access';
|
import { toggleClientBlock } from '../../actions/access';
|
||||||
import { renderFormattedClientCell } from '../../helpers/renderFormattedClientCell';
|
import { renderFormattedClientCell } from '../../helpers/renderFormattedClientCell';
|
||||||
import { getStats } from '../../actions/stats';
|
import { getStats } from '../../actions/stats';
|
||||||
|
@ -35,10 +35,6 @@ const CountCell = (row) => {
|
||||||
};
|
};
|
||||||
|
|
||||||
const renderBlockingButton = (ip, disallowed, disallowed_rule) => {
|
const renderBlockingButton = (ip, disallowed, disallowed_rule) => {
|
||||||
if (R_CLIENT_ID.test(ip)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
const dispatch = useDispatch();
|
const dispatch = useDispatch();
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const processingSet = useSelector((state) => state.access.processingSet);
|
const processingSet = useSelector((state) => state.access.processingSet);
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
|
||||||
go 1.16
|
go 1.16
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/dnsproxy v0.37.7
|
github.com/AdguardTeam/dnsproxy v0.38.0
|
||||||
github.com/AdguardTeam/golibs v0.8.0
|
github.com/AdguardTeam/golibs v0.8.0
|
||||||
github.com/AdguardTeam/urlfilter v0.14.6
|
github.com/AdguardTeam/urlfilter v0.14.6
|
||||||
github.com/NYTimes/gziphandler v1.1.1
|
github.com/NYTimes/gziphandler v1.1.1
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -9,8 +9,8 @@ dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D
|
||||||
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
|
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
|
||||||
github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf h1:gc042VRSIRSUzZ+Px6xQCRWNJZTaPkomisDfUZmoFNk=
|
github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf h1:gc042VRSIRSUzZ+Px6xQCRWNJZTaPkomisDfUZmoFNk=
|
||||||
github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
|
github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
|
||||||
github.com/AdguardTeam/dnsproxy v0.37.7 h1:yp0vEVYobf/1l8iY7es9yMqguw8BUEeC74OGA4G2v2A=
|
github.com/AdguardTeam/dnsproxy v0.38.0 h1:7GyyNJOieIVOgdnhu47exqWjHPQro7wQhqzvQjaZt6M=
|
||||||
github.com/AdguardTeam/dnsproxy v0.37.7/go.mod h1:xMfevPAwpK1ULoLO0CARg/OiUsPH92kfyliXhPTc62M=
|
github.com/AdguardTeam/dnsproxy v0.38.0/go.mod h1:xMfevPAwpK1ULoLO0CARg/OiUsPH92kfyliXhPTc62M=
|
||||||
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||||
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||||
github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||||
|
|
|
@ -27,10 +27,9 @@ type EtcHostsContainer struct {
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
// table is the host-to-IPs map.
|
// table is the host-to-IPs map.
|
||||||
table map[string][]net.IP
|
table map[string][]net.IP
|
||||||
// tableReverse is the IP-to-hosts map.
|
// tableReverse is the IP-to-hosts map. The type of the values in the
|
||||||
//
|
// map is []string.
|
||||||
// TODO(a.garipov): Make better use of newtypes. Perhaps a custom map.
|
tableReverse *IPMap
|
||||||
tableReverse map[string][]string
|
|
||||||
|
|
||||||
hostsFn string // path to the main hosts-file
|
hostsFn string // path to the main hosts-file
|
||||||
hostsDirs []string // paths to OS-specific directories with hosts-files
|
hostsDirs []string // paths to OS-specific directories with hosts-files
|
||||||
|
@ -80,7 +79,7 @@ func (ehc *EtcHostsContainer) Init(hostsFn string) {
|
||||||
var err error
|
var err error
|
||||||
ehc.watcher, err = fsnotify.NewWatcher()
|
ehc.watcher, err = fsnotify.NewWatcher()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("etchostscontainer: %s", err)
|
log.Error("etchosts: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -141,7 +140,7 @@ func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP {
|
||||||
copy(ipsCopy, ips)
|
copy(ipsCopy, ips)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("etchostscontainer: answer: %s -> %v", host, ipsCopy)
|
log.Debug("etchosts: answer: %s -> %v", host, ipsCopy)
|
||||||
return ipsCopy
|
return ipsCopy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -151,38 +150,40 @@ func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts [
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipReal := UnreverseAddr(addr)
|
ip := UnreverseAddr(addr)
|
||||||
if ipReal == nil {
|
if ip == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ipStr := ipReal.String()
|
|
||||||
|
|
||||||
ehc.lock.RLock()
|
ehc.lock.RLock()
|
||||||
defer ehc.lock.RUnlock()
|
defer ehc.lock.RUnlock()
|
||||||
|
|
||||||
hosts = ehc.tableReverse[ipStr]
|
v, ok := ehc.tableReverse.Get(ip)
|
||||||
|
if !ok {
|
||||||
if len(hosts) == 0 {
|
return nil
|
||||||
return nil // not found
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("etchostscontainer: reverse-lookup: %s -> %s", addr, hosts)
|
hosts, ok = v.([]string)
|
||||||
|
if !ok {
|
||||||
|
log.Error("etchosts: bad type %T in tableReverse for %s", v, ip)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
} else if len(hosts) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts)
|
||||||
|
|
||||||
return hosts
|
return hosts
|
||||||
}
|
}
|
||||||
|
|
||||||
// List returns an IP-to-hostnames table. It is safe for concurrent use.
|
// List returns an IP-to-hostnames table. The type of the values in the map is
|
||||||
func (ehc *EtcHostsContainer) List() (ipToHosts map[string][]string) {
|
// []string. It is safe for concurrent use.
|
||||||
|
func (ehc *EtcHostsContainer) List() (ipToHosts *IPMap) {
|
||||||
ehc.lock.RLock()
|
ehc.lock.RLock()
|
||||||
defer ehc.lock.RUnlock()
|
defer ehc.lock.RUnlock()
|
||||||
|
|
||||||
ipToHosts = make(map[string][]string, len(ehc.tableReverse))
|
return ehc.tableReverse.ShallowClone()
|
||||||
for k, v := range ehc.tableReverse {
|
|
||||||
ipToHosts[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
return ipToHosts
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// update table
|
// update table
|
||||||
|
@ -205,29 +206,31 @@ func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string
|
||||||
ok = true
|
ok = true
|
||||||
}
|
}
|
||||||
if ok {
|
if ok {
|
||||||
log.Debug("etchostscontainer: added %s -> %s", ipAddr, host)
|
log.Debug("etchosts: added %s -> %s", ipAddr, host)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateTableRev updates the reverse address table.
|
// updateTableRev updates the reverse address table.
|
||||||
func (ehc *EtcHostsContainer) updateTableRev(tableRev map[string][]string, newHost string, ipAddr net.IP) {
|
func (ehc *EtcHostsContainer) updateTableRev(tableRev *IPMap, newHost string, ip net.IP) {
|
||||||
ipStr := ipAddr.String()
|
v, ok := tableRev.Get(ip)
|
||||||
hosts, ok := tableRev[ipStr]
|
|
||||||
if !ok {
|
if !ok {
|
||||||
tableRev[ipStr] = []string{newHost}
|
tableRev.Set(ip, []string{newHost})
|
||||||
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
|
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hosts, _ := v.([]string)
|
||||||
for _, host := range hosts {
|
for _, host := range hosts {
|
||||||
if host == newHost {
|
if host == newHost {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tableRev[ipStr] = append(tableRev[ipStr], newHost)
|
hosts = append(hosts, newHost)
|
||||||
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
|
tableRev.Set(ip, hosts)
|
||||||
|
|
||||||
|
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseHostsLine parses hosts from the fields.
|
// parseHostsLine parses hosts from the fields.
|
||||||
|
@ -255,12 +258,12 @@ func parseHostsLine(fields []string) (hosts []string) {
|
||||||
// line for one IP are supported.
|
// line for one IP are supported.
|
||||||
func (ehc *EtcHostsContainer) load(
|
func (ehc *EtcHostsContainer) load(
|
||||||
table map[string][]net.IP,
|
table map[string][]net.IP,
|
||||||
tableRev map[string][]string,
|
tableRev *IPMap,
|
||||||
fn string,
|
fn string,
|
||||||
) {
|
) {
|
||||||
f, err := os.Open(fn)
|
f, err := os.Open(fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("etchostscontainer: %s", err)
|
log.Error("etchosts: %s", err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -268,11 +271,11 @@ func (ehc *EtcHostsContainer) load(
|
||||||
defer func() {
|
defer func() {
|
||||||
derr := f.Close()
|
derr := f.Close()
|
||||||
if derr != nil {
|
if derr != nil {
|
||||||
log.Error("etchostscontainer: closing file: %s", err)
|
log.Error("etchosts: closing file: %s", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Debug("etchostscontainer: loading hosts from file %s", fn)
|
log.Debug("etchosts: loading hosts from file %s", fn)
|
||||||
|
|
||||||
s := bufio.NewScanner(f)
|
s := bufio.NewScanner(f)
|
||||||
for s.Scan() {
|
for s.Scan() {
|
||||||
|
@ -296,7 +299,7 @@ func (ehc *EtcHostsContainer) load(
|
||||||
|
|
||||||
err = s.Err()
|
err = s.Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("etchostscontainer: %s", err)
|
log.Error("etchosts: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -334,7 +337,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||||
log.Debug("etchostscontainer: modified: %s", event.Name)
|
log.Debug("etchosts: modified: %s", event.Name)
|
||||||
ehc.updateHosts()
|
ehc.updateHosts()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -342,7 +345,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Error("etchostscontainer: %s", err)
|
log.Error("etchosts: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -350,7 +353,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
|
||||||
// updateHosts - loads system hosts
|
// updateHosts - loads system hosts
|
||||||
func (ehc *EtcHostsContainer) updateHosts() {
|
func (ehc *EtcHostsContainer) updateHosts() {
|
||||||
table := make(map[string][]net.IP)
|
table := make(map[string][]net.IP)
|
||||||
tableRev := make(map[string][]string)
|
tableRev := NewIPMap(0)
|
||||||
|
|
||||||
ehc.load(table, tableRev, ehc.hostsFn)
|
ehc.load(table, tableRev, ehc.hostsFn)
|
||||||
|
|
||||||
|
@ -358,7 +361,7 @@ func (ehc *EtcHostsContainer) updateHosts() {
|
||||||
des, err := os.ReadDir(dir)
|
des, err := os.ReadDir(dir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, os.ErrNotExist) {
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
log.Error("etchostscontainer: Opening directory: %q: %s", dir, err)
|
log.Error("etchosts: Opening directory: %q: %s", dir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -70,7 +70,7 @@ func TestEtcHostsContainerResolution(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("hosts_file", func(t *testing.T) {
|
t.Run("hosts_file", func(t *testing.T) {
|
||||||
names, ok := ehc.List()["127.0.0.1"]
|
names, ok := ehc.List().Get(net.IP{127, 0, 0, 1})
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.Equal(t, []string{"host", "localhost"}, names)
|
assert.Equal(t, []string{"host", "localhost"}, names)
|
||||||
})
|
})
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ipArr is a representation of an IP address as an array of bytes.
|
||||||
|
type ipArr [16]byte
|
||||||
|
|
||||||
|
// String implements the fmt.Stringer interface for ipArr.
|
||||||
|
func (a ipArr) String() (s string) {
|
||||||
|
return net.IP(a[:]).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPMap is a map of IP addresses.
|
||||||
|
type IPMap struct {
|
||||||
|
m map[ipArr]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIPMap returns a new empty IP map using hint as a size hint for the
|
||||||
|
// underlying map.
|
||||||
|
func NewIPMap(hint int) (m *IPMap) {
|
||||||
|
return &IPMap{
|
||||||
|
m: make(map[ipArr]interface{}, hint),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipToArr converts a net.IP into an ipArr.
|
||||||
|
//
|
||||||
|
// TODO(a.garipov): Use the slice-to-array conversion in Go 1.17.
|
||||||
|
func ipToArr(ip net.IP) (a ipArr) {
|
||||||
|
copy(a[:], ip.To16())
|
||||||
|
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
// Del deletes ip from the map. Calling Del on a nil *IPMap has no effect, just
|
||||||
|
// like delete on an empty map doesn't.
|
||||||
|
func (m *IPMap) Del(ip net.IP) {
|
||||||
|
if m != nil {
|
||||||
|
delete(m.m, ipToArr(ip))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the value from the map. Calling Get on a nil *IPMap returns nil
|
||||||
|
// and false, just like indexing on an empty map does.
|
||||||
|
func (m *IPMap) Get(ip net.IP) (v interface{}, ok bool) {
|
||||||
|
if m != nil {
|
||||||
|
v, ok = m.m[ipToArr(ip)]
|
||||||
|
|
||||||
|
return v, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the length of the map. A nil *IPMap has a length of zero, just
|
||||||
|
// like an empty map.
|
||||||
|
func (m *IPMap) Len() (n int) {
|
||||||
|
if m == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return len(m.m)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Range calls f for each key and value present in the map in an undefined
|
||||||
|
// order. If cont is false, range stops the iteration. Calling Range on a nil
|
||||||
|
// *IPMap has no effect, just like ranging over a nil map.
|
||||||
|
func (m *IPMap) Range(f func(ip net.IP, v interface{}) (cont bool)) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range m.m {
|
||||||
|
if !f(net.IP(k[:]), v) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets the value. Set panics if the m is a nil *IPMap, just like a nil map
|
||||||
|
// does.
|
||||||
|
func (m *IPMap) Set(ip net.IP, v interface{}) {
|
||||||
|
m.m[ipToArr(ip)] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShallowClone returns a shallow clone of the map.
|
||||||
|
func (m *IPMap) ShallowClone() (sclone *IPMap) {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sclone = NewIPMap(m.Len())
|
||||||
|
m.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||||
|
sclone.Set(ip, v)
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return sclone
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements the fmt.Stringer interface for *IPMap.
|
||||||
|
func (m *IPMap) String() (s string) {
|
||||||
|
if m == nil {
|
||||||
|
return "<nil>"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprint(m.m)
|
||||||
|
}
|
|
@ -0,0 +1,142 @@
|
||||||
|
package aghnet
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIPMap_allocs(t *testing.T) {
|
||||||
|
ip4 := net.IP{1, 2, 3, 4}
|
||||||
|
m := NewIPMap(0)
|
||||||
|
m.Set(ip4, 42)
|
||||||
|
|
||||||
|
t.Run("get", func(t *testing.T) {
|
||||||
|
var v interface{}
|
||||||
|
var ok bool
|
||||||
|
allocs := testing.AllocsPerRun(100, func() {
|
||||||
|
v, ok = m.Get(ip4)
|
||||||
|
})
|
||||||
|
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, 42, v)
|
||||||
|
|
||||||
|
assert.Equal(t, float64(0), allocs)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("len", func(t *testing.T) {
|
||||||
|
var n int
|
||||||
|
allocs := testing.AllocsPerRun(100, func() {
|
||||||
|
n = m.Len()
|
||||||
|
})
|
||||||
|
|
||||||
|
require.Equal(t, 1, n)
|
||||||
|
|
||||||
|
assert.Equal(t, float64(0), allocs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPMap(t *testing.T) {
|
||||||
|
ip4 := net.IP{1, 2, 3, 4}
|
||||||
|
ip6 := net.IP{
|
||||||
|
0x12, 0x34, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x00, 0x00,
|
||||||
|
0x00, 0x00, 0x56, 0x78,
|
||||||
|
}
|
||||||
|
|
||||||
|
val := 42
|
||||||
|
|
||||||
|
t.Run("nil", func(t *testing.T) {
|
||||||
|
var m *IPMap
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
m.Del(ip4)
|
||||||
|
m.Del(ip6)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
v, ok := m.Get(ip4)
|
||||||
|
assert.Nil(t, v)
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
v, ok = m.Get(ip6)
|
||||||
|
assert.Nil(t, v)
|
||||||
|
assert.False(t, ok)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
assert.Equal(t, 0, m.Len())
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
n := 0
|
||||||
|
m.Range(func(_ net.IP, _ interface{}) (cont bool) {
|
||||||
|
n++
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, 0, n)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
m.Set(ip4, val)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
m.Set(ip6, val)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
sclone := m.ShallowClone()
|
||||||
|
assert.Nil(t, sclone)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
testIPMap := func(t *testing.T, ip net.IP, s string) {
|
||||||
|
m := NewIPMap(0)
|
||||||
|
assert.Equal(t, 0, m.Len())
|
||||||
|
|
||||||
|
v, ok := m.Get(ip)
|
||||||
|
assert.Nil(t, v)
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
m.Set(ip, val)
|
||||||
|
v, ok = m.Get(ip)
|
||||||
|
assert.Equal(t, val, v)
|
||||||
|
assert.True(t, ok)
|
||||||
|
|
||||||
|
n := 0
|
||||||
|
m.Range(func(ipKey net.IP, v interface{}) (cont bool) {
|
||||||
|
assert.Equal(t, ip.To16(), ipKey)
|
||||||
|
assert.Equal(t, val, v)
|
||||||
|
|
||||||
|
n++
|
||||||
|
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
assert.Equal(t, 1, n)
|
||||||
|
|
||||||
|
sclone := m.ShallowClone()
|
||||||
|
assert.Equal(t, m, sclone)
|
||||||
|
|
||||||
|
assert.Equal(t, s, m.String())
|
||||||
|
|
||||||
|
m.Del(ip)
|
||||||
|
v, ok = m.Get(ip)
|
||||||
|
assert.Nil(t, v)
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.Equal(t, 0, m.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("ipv4", func(t *testing.T) {
|
||||||
|
testIPMap(t, ip4, "map[1.2.3.4:42]")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ipv6", func(t *testing.T) {
|
||||||
|
testIPMap(t, ip6, "map[1234::5678:42]")
|
||||||
|
})
|
||||||
|
}
|
|
@ -6,138 +6,163 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/urlfilter"
|
"github.com/AdguardTeam/urlfilter"
|
||||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// accessCtx controls IP and client blocking that takes place before all other
|
||||||
|
// processing. An accessCtx is safe for concurrent use.
|
||||||
type accessCtx struct {
|
type accessCtx struct {
|
||||||
lock sync.Mutex
|
allowedIPs *aghnet.IPMap
|
||||||
|
blockedIPs *aghnet.IPMap
|
||||||
|
|
||||||
// allowedClients are the IP addresses of clients in the allowlist.
|
allowedClientIDs *aghstrings.Set
|
||||||
allowedClients *aghstrings.Set
|
blockedClientIDs *aghstrings.Set
|
||||||
|
|
||||||
// disallowedClients are the IP addresses of clients in the blocklist.
|
blockedHostsEng *urlfilter.DNSEngine
|
||||||
disallowedClients *aghstrings.Set
|
|
||||||
|
|
||||||
allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients
|
// TODO(a.garipov): Create a type for a set of IP networks.
|
||||||
disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked
|
// aghnet.IPNetSet?
|
||||||
|
allowedNets []*net.IPNet
|
||||||
blockedHostsEngine *urlfilter.DNSEngine // finds hosts that should be blocked
|
blockedNets []*net.IPNet
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAccessCtx(allowedClients, disallowedClients, blockedHosts []string) (a *accessCtx, err error) {
|
// unit is a convenient alias for struct{}
|
||||||
a = &accessCtx{
|
type unit = struct{}
|
||||||
allowedClients: aghstrings.NewSet(),
|
|
||||||
disallowedClients: aghstrings.NewSet(),
|
|
||||||
}
|
|
||||||
|
|
||||||
err = processIPCIDRArray(a.allowedClients, &a.allowedClientsIPNet, allowedClients)
|
// processAccessClients is a helper for processing a list of client strings,
|
||||||
if err != nil {
|
// which may be an IP address, a CIDR, or a ClientID.
|
||||||
return nil, fmt.Errorf("processing allowed clients: %w", err)
|
func processAccessClients(
|
||||||
}
|
clientStrs []string,
|
||||||
|
ips *aghnet.IPMap,
|
||||||
|
nets *[]*net.IPNet,
|
||||||
|
clientIDs *aghstrings.Set,
|
||||||
|
) (err error) {
|
||||||
|
for i, s := range clientStrs {
|
||||||
|
if ip := net.ParseIP(s); ip != nil {
|
||||||
|
ips.Set(ip, unit{})
|
||||||
|
} else if cidrIP, ipnet, cidrErr := net.ParseCIDR(s); cidrErr == nil {
|
||||||
|
ipnet.IP = cidrIP
|
||||||
|
*nets = append(*nets, ipnet)
|
||||||
|
} else {
|
||||||
|
idErr := ValidateClientID(s)
|
||||||
|
if idErr != nil {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"value %q at index %d: bad ip, cidr, or clientid",
|
||||||
|
s,
|
||||||
|
i,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
err = processIPCIDRArray(a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients)
|
clientIDs.Add(s)
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("processing disallowed clients: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
b := &strings.Builder{}
|
|
||||||
for _, s := range blockedHosts {
|
|
||||||
aghstrings.WriteToBuilder(b, strings.ToLower(s), "\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
listArray := []filterlist.RuleList{}
|
|
||||||
list := &filterlist.StringRuleList{
|
|
||||||
ID: int(0),
|
|
||||||
RulesText: b.String(),
|
|
||||||
IgnoreCosmetic: true,
|
|
||||||
}
|
|
||||||
listArray = append(listArray, list)
|
|
||||||
rulesStorage, err := filterlist.NewRuleStorage(listArray)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
|
|
||||||
}
|
|
||||||
a.blockedHostsEngine = urlfilter.NewDNSEngine(rulesStorage)
|
|
||||||
|
|
||||||
return a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split array of IP or CIDR into 2 containers for fast search
|
|
||||||
func processIPCIDRArray(dst *aghstrings.Set, dstIPNet *[]net.IPNet, src []string) error {
|
|
||||||
for _, s := range src {
|
|
||||||
ip := net.ParseIP(s)
|
|
||||||
if ip != nil {
|
|
||||||
dst.Add(s)
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ipnet, err := net.ParseCIDR(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
*dstIPNet = append(*dstIPNet, *ipnet)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsBlockedIP - return TRUE if this client should be blocked
|
// newAccessCtx creates a new accessCtx.
|
||||||
// Returns the item from the "disallowedClients" list that lead to blocking IP.
|
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
|
||||||
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
|
a = &accessCtx{
|
||||||
// but the ip does not belong to it.
|
allowedIPs: aghnet.NewIPMap(0),
|
||||||
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
|
blockedIPs: aghnet.NewIPMap(0),
|
||||||
ipStr := ip.String()
|
|
||||||
|
|
||||||
a.lock.Lock()
|
allowedClientIDs: aghstrings.NewSet(),
|
||||||
defer a.lock.Unlock()
|
blockedClientIDs: aghstrings.NewSet(),
|
||||||
|
|
||||||
if a.allowedClients.Len() != 0 || len(a.allowedClientsIPNet) != 0 {
|
|
||||||
if a.allowedClients.Has(ipStr) {
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(a.allowedClientsIPNet) != 0 {
|
|
||||||
for _, ipnet := range a.allowedClientsIPNet {
|
|
||||||
if ipnet.Contains(ip) {
|
|
||||||
return false, ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if a.disallowedClients.Has(ipStr) {
|
err = processAccessClients(allowed, a.allowedIPs, &a.allowedNets, a.allowedClientIDs)
|
||||||
return true, ipStr
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("adding allowed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(a.disallowedClientsIPNet) != 0 {
|
err = processAccessClients(blocked, a.blockedIPs, &a.blockedNets, a.blockedClientIDs)
|
||||||
for _, ipnet := range a.disallowedClientsIPNet {
|
if err != nil {
|
||||||
if ipnet.Contains(ip) {
|
return nil, fmt.Errorf("adding blocked: %w", err)
|
||||||
return true, ipnet.String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, ""
|
b := &strings.Builder{}
|
||||||
|
for _, h := range blockedHosts {
|
||||||
|
aghstrings.WriteToBuilder(b, strings.ToLower(h), "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
lists := []filterlist.RuleList{
|
||||||
|
&filterlist.StringRuleList{
|
||||||
|
ID: int(0),
|
||||||
|
RulesText: b.String(),
|
||||||
|
IgnoreCosmetic: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rulesStrg, err := filterlist.NewRuleStorage(lists)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("adding blocked hosts: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
a.blockedHostsEng = urlfilter.NewDNSEngine(rulesStrg)
|
||||||
|
|
||||||
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsBlockedDomain - return TRUE if this domain should be blocked
|
// allowlistMode returns true if this *accessCtx is in the allowlist mode.
|
||||||
func (a *accessCtx) IsBlockedDomain(host string) (ok bool) {
|
func (a *accessCtx) allowlistMode() (ok bool) {
|
||||||
a.lock.Lock()
|
return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
|
||||||
defer a.lock.Unlock()
|
}
|
||||||
|
|
||||||
_, ok = a.blockedHostsEngine.Match(strings.ToLower(host))
|
// isBlockedClientID returns true if the ClientID should be blocked.
|
||||||
|
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
|
||||||
|
allowlistMode := a.allowlistMode()
|
||||||
|
if id == "" {
|
||||||
|
// In allowlist mode, consider requests without client IDs
|
||||||
|
// blocked by default.
|
||||||
|
return allowlistMode
|
||||||
|
}
|
||||||
|
|
||||||
|
if allowlistMode {
|
||||||
|
return !a.allowedClientIDs.Has(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return a.blockedClientIDs.Has(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBlockedHost returns true if host should be blocked.
|
||||||
|
func (a *accessCtx) isBlockedHost(host string) (ok bool) {
|
||||||
|
_, ok = a.blockedHostsEng.Match(strings.ToLower(host))
|
||||||
|
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isBlockedIP returns the status of the IP address blocking as well as the rule
|
||||||
|
// that blocked it.
|
||||||
|
func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
|
||||||
|
blocked = true
|
||||||
|
ips := a.blockedIPs
|
||||||
|
ipnets := a.blockedNets
|
||||||
|
|
||||||
|
if a.allowlistMode() {
|
||||||
|
// Enable allowlist mode and use the allowlist sets.
|
||||||
|
blocked = false
|
||||||
|
ips = a.allowedIPs
|
||||||
|
ipnets = a.allowedNets
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := ips.Get(ip); ok {
|
||||||
|
return blocked, ip.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ipnet := range ipnets {
|
||||||
|
if ipnet.Contains(ip) {
|
||||||
|
return blocked, ipnet.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return !blocked, ""
|
||||||
|
}
|
||||||
|
|
||||||
type accessListJSON struct {
|
type accessListJSON struct {
|
||||||
AllowedClients []string `json:"allowed_clients"`
|
AllowedClients []string `json:"allowed_clients"`
|
||||||
DisallowedClients []string `json:"disallowed_clients"`
|
DisallowedClients []string `json:"disallowed_clients"`
|
||||||
|
@ -161,62 +186,43 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
err := json.NewEncoder(w).Encode(j)
|
err := json.NewEncoder(w).Encode(j)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkIPCIDRArray(src []string) error {
|
|
||||||
for _, s := range src {
|
|
||||||
ip := net.ParseIP(s)
|
|
||||||
if ip != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, err := net.ParseCIDR(s)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||||
j := accessListJSON{}
|
list := accessListJSON{}
|
||||||
err := json.NewDecoder(r.Body).Decode(&j)
|
err := json.NewDecoder(r.Body).Decode(&list)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
httpError(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = checkIPCIDRArray(j.AllowedClients)
|
|
||||||
if err == nil {
|
|
||||||
err = checkIPCIDRArray(j.DisallowedClients)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var a *accessCtx
|
var a *accessCtx
|
||||||
a, err = newAccessCtx(j.AllowedClients, j.DisallowedClients, j.BlockedHosts)
|
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
|
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer log.Debug("Access: updated lists: %d, %d, %d",
|
defer log.Debug(
|
||||||
len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts))
|
"access: updated lists: %d, %d, %d",
|
||||||
|
len(list.AllowedClients),
|
||||||
|
len(list.DisallowedClients),
|
||||||
|
len(list.BlockedHosts),
|
||||||
|
)
|
||||||
|
|
||||||
defer s.conf.ConfigModified()
|
defer s.conf.ConfigModified()
|
||||||
|
|
||||||
s.serverLock.Lock()
|
s.serverLock.Lock()
|
||||||
defer s.serverLock.Unlock()
|
defer s.serverLock.Unlock()
|
||||||
|
|
||||||
s.conf.AllowedClients = j.AllowedClients
|
s.conf.AllowedClients = list.AllowedClients
|
||||||
s.conf.DisallowedClients = j.DisallowedClients
|
s.conf.DisallowedClients = list.DisallowedClients
|
||||||
s.conf.BlockedHosts = j.BlockedHosts
|
s.conf.BlockedHosts = list.BlockedHosts
|
||||||
s.access = a
|
s.access = a
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,99 +8,23 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIsBlockedIP(t *testing.T) {
|
func TestIsBlockedClientID(t *testing.T) {
|
||||||
const (
|
clientID := "client-1"
|
||||||
ip int = iota
|
clients := []string{clientID}
|
||||||
cidr
|
|
||||||
)
|
|
||||||
|
|
||||||
rules := []string{
|
a, err := newAccessCtx(clients, nil, nil)
|
||||||
ip: "1.1.1.1",
|
require.NoError(t, err)
|
||||||
cidr: "2.2.0.0/16",
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
assert.False(t, a.isBlockedClientID(clientID))
|
||||||
name string
|
|
||||||
allowed bool
|
|
||||||
ip net.IP
|
|
||||||
wantDis bool
|
|
||||||
wantRule string
|
|
||||||
}{{
|
|
||||||
name: "allow_ip",
|
|
||||||
allowed: true,
|
|
||||||
ip: net.IPv4(1, 1, 1, 1),
|
|
||||||
wantDis: false,
|
|
||||||
wantRule: "",
|
|
||||||
}, {
|
|
||||||
name: "disallow_ip",
|
|
||||||
allowed: true,
|
|
||||||
ip: net.IPv4(1, 1, 1, 2),
|
|
||||||
wantDis: true,
|
|
||||||
wantRule: "",
|
|
||||||
}, {
|
|
||||||
name: "allow_cidr",
|
|
||||||
allowed: true,
|
|
||||||
ip: net.IPv4(2, 2, 1, 1),
|
|
||||||
wantDis: false,
|
|
||||||
wantRule: "",
|
|
||||||
}, {
|
|
||||||
name: "disallow_cidr",
|
|
||||||
allowed: true,
|
|
||||||
ip: net.IPv4(2, 3, 1, 1),
|
|
||||||
wantDis: true,
|
|
||||||
wantRule: "",
|
|
||||||
}, {
|
|
||||||
name: "allow_ip",
|
|
||||||
allowed: false,
|
|
||||||
ip: net.IPv4(1, 1, 1, 1),
|
|
||||||
wantDis: true,
|
|
||||||
wantRule: rules[ip],
|
|
||||||
}, {
|
|
||||||
name: "disallow_ip",
|
|
||||||
allowed: false,
|
|
||||||
ip: net.IPv4(1, 1, 1, 2),
|
|
||||||
wantDis: false,
|
|
||||||
wantRule: "",
|
|
||||||
}, {
|
|
||||||
name: "allow_cidr",
|
|
||||||
allowed: false,
|
|
||||||
ip: net.IPv4(2, 2, 1, 1),
|
|
||||||
wantDis: true,
|
|
||||||
wantRule: rules[cidr],
|
|
||||||
}, {
|
|
||||||
name: "disallow_cidr",
|
|
||||||
allowed: false,
|
|
||||||
ip: net.IPv4(2, 3, 1, 1),
|
|
||||||
wantDis: false,
|
|
||||||
wantRule: "",
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
a, err = newAccessCtx(nil, clients, nil)
|
||||||
prefix := "allowed_"
|
require.NoError(t, err)
|
||||||
if !tc.allowed {
|
|
||||||
prefix = "disallowed_"
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run(prefix+tc.name, func(t *testing.T) {
|
assert.True(t, a.isBlockedClientID(clientID))
|
||||||
allowedRules := rules
|
|
||||||
var disallowedRules []string
|
|
||||||
|
|
||||||
if !tc.allowed {
|
|
||||||
allowedRules, disallowedRules = disallowedRules, allowedRules
|
|
||||||
}
|
|
||||||
|
|
||||||
aCtx, err := newAccessCtx(allowedRules, disallowedRules, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
disallowed, rule := aCtx.IsBlockedIP(tc.ip)
|
|
||||||
assert.Equal(t, tc.wantDis, disallowed)
|
|
||||||
assert.Equal(t, tc.wantRule, rule)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIsBlockedDomain(t *testing.T) {
|
func TestIsBlockedHost(t *testing.T) {
|
||||||
aCtx, err := newAccessCtx(nil, nil, []string{
|
a, err := newAccessCtx(nil, nil, []string{
|
||||||
"host1",
|
"host1",
|
||||||
"*.host.com",
|
"*.host.com",
|
||||||
"||host3.com^",
|
"||host3.com^",
|
||||||
|
@ -108,50 +32,106 @@ func TestIsBlockedDomain(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
domain string
|
host string
|
||||||
want bool
|
want bool
|
||||||
}{{
|
}{{
|
||||||
name: "plain_match",
|
name: "plain_match",
|
||||||
domain: "host1",
|
host: "host1",
|
||||||
want: true,
|
want: true,
|
||||||
}, {
|
}, {
|
||||||
name: "plain_mismatch",
|
name: "plain_mismatch",
|
||||||
domain: "host2",
|
host: "host2",
|
||||||
want: false,
|
want: false,
|
||||||
}, {
|
}, {
|
||||||
name: "wildcard_type-1_match_short",
|
name: "subdomain_match_short",
|
||||||
domain: "asdf.host.com",
|
host: "asdf.host.com",
|
||||||
want: true,
|
want: true,
|
||||||
}, {
|
}, {
|
||||||
name: "wildcard_type-1_match_long",
|
name: "subdomain_match_long",
|
||||||
domain: "qwer.asdf.host.com",
|
host: "qwer.asdf.host.com",
|
||||||
want: true,
|
want: true,
|
||||||
}, {
|
}, {
|
||||||
name: "wildcard_type-1_mismatch_no-lead",
|
name: "subdomain_mismatch_no_lead",
|
||||||
domain: "host.com",
|
host: "host.com",
|
||||||
want: false,
|
want: false,
|
||||||
}, {
|
}, {
|
||||||
name: "wildcard_type-1_mismatch_bad-asterisk",
|
name: "subdomain_mismatch_bad_asterisk",
|
||||||
domain: "asdf.zhost.com",
|
host: "asdf.zhost.com",
|
||||||
want: false,
|
want: false,
|
||||||
}, {
|
}, {
|
||||||
name: "wildcard_type-2_match_simple",
|
name: "rule_match_simple",
|
||||||
domain: "host3.com",
|
host: "host3.com",
|
||||||
want: true,
|
want: true,
|
||||||
}, {
|
}, {
|
||||||
name: "wildcard_type-2_match_complex",
|
name: "rule_match_complex",
|
||||||
domain: "asdf.host3.com",
|
host: "asdf.host3.com",
|
||||||
want: true,
|
want: true,
|
||||||
}, {
|
}, {
|
||||||
name: "wildcard_type-2_mismatch",
|
name: "rule_mismatch",
|
||||||
domain: ".host3.com",
|
host: ".host3.com",
|
||||||
want: false,
|
want: false,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain))
|
assert.Equal(t, tc.want, a.isBlockedHost(tc.host))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsBlockedIP(t *testing.T) {
|
||||||
|
clients := []string{
|
||||||
|
"1.2.3.4",
|
||||||
|
"5.6.7.8/24",
|
||||||
|
}
|
||||||
|
|
||||||
|
allowCtx, err := newAccessCtx(clients, nil, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
blockCtx, err := newAccessCtx(nil, clients, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
wantRule string
|
||||||
|
ip net.IP
|
||||||
|
wantBlocked bool
|
||||||
|
}{{
|
||||||
|
name: "match_ip",
|
||||||
|
wantRule: "1.2.3.4",
|
||||||
|
ip: net.IP{1, 2, 3, 4},
|
||||||
|
wantBlocked: true,
|
||||||
|
}, {
|
||||||
|
name: "match_cidr",
|
||||||
|
wantRule: "5.6.7.8/24",
|
||||||
|
ip: net.IP{5, 6, 7, 100},
|
||||||
|
wantBlocked: true,
|
||||||
|
}, {
|
||||||
|
name: "no_match_ip",
|
||||||
|
wantRule: "",
|
||||||
|
ip: net.IP{9, 2, 3, 4},
|
||||||
|
wantBlocked: false,
|
||||||
|
}, {
|
||||||
|
name: "no_match_cidr",
|
||||||
|
wantRule: "",
|
||||||
|
ip: net.IP{9, 6, 7, 100},
|
||||||
|
wantBlocked: false,
|
||||||
|
}}
|
||||||
|
|
||||||
|
t.Run("allow", func(t *testing.T) {
|
||||||
|
for _, tc := range testCases {
|
||||||
|
blocked, rule := allowCtx.isBlockedIP(tc.ip)
|
||||||
|
assert.Equal(t, !tc.wantBlocked, blocked)
|
||||||
|
assert.Equal(t, tc.wantRule, rule)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("block", func(t *testing.T) {
|
||||||
|
for _, tc := range testCases {
|
||||||
|
blocked, rule := blockCtx.isBlockedIP(tc.ip)
|
||||||
|
assert.Equal(t, tc.wantBlocked, blocked)
|
||||||
|
assert.Equal(t, tc.wantRule, rule)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -2,6 +2,7 @@ package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -50,15 +51,15 @@ func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (
|
||||||
return clientID, nil
|
return clientID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// processClientIDHTTPS extracts the client's ID from the path of the
|
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
|
||||||
// client's DNS-over-HTTPS request.
|
// client's DNS-over-HTTPS request.
|
||||||
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||||
pctx := ctx.proxyCtx
|
|
||||||
r := pctx.HTTPRequest
|
r := pctx.HTTPRequest
|
||||||
if r == nil {
|
if r == nil {
|
||||||
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
|
return "", fmt.Errorf(
|
||||||
|
"proxy ctx http request of proto %s is nil",
|
||||||
return resultCodeError
|
pctx.Proto,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
origPath := r.URL.Path
|
origPath := r.URL.Path
|
||||||
|
@ -68,34 +69,25 @@ func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(parts) == 0 || parts[0] != "dns-query" {
|
if len(parts) == 0 || parts[0] != "dns-query" {
|
||||||
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
|
return "", fmt.Errorf("client id check: invalid path %q", origPath)
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
}
|
||||||
|
|
||||||
clientID := ""
|
|
||||||
switch len(parts) {
|
switch len(parts) {
|
||||||
case 1:
|
case 1:
|
||||||
// Just /dns-query, no client ID.
|
// Just /dns-query, no client ID.
|
||||||
return resultCodeSuccess
|
return "", nil
|
||||||
case 2:
|
case 2:
|
||||||
clientID = parts[1]
|
clientID = parts[1]
|
||||||
default:
|
default:
|
||||||
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
return "", fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err := ValidateClientID(clientID)
|
err = ValidateClientID(clientID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ctx.err = fmt.Errorf("client id check: %w", err)
|
return "", fmt.Errorf("client id check: %w", err)
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.clientID = clientID
|
return clientID, nil
|
||||||
|
|
||||||
return resultCodeSuccess
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
||||||
|
@ -108,53 +100,73 @@ type quicSession interface {
|
||||||
ConnectionState() (cs quic.ConnectionState)
|
ConnectionState() (cs quic.ConnectionState)
|
||||||
}
|
}
|
||||||
|
|
||||||
// processClientID extracts the client's ID from the server name of the client's
|
// clientIDFromDNSContext extracts the client's ID from the server name of the
|
||||||
// DoT or DoQ request or the path of the client's DoH.
|
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
|
||||||
func processClientID(dctx *dnsContext) (rc resultCode) {
|
// is not one of these, clientID is an empty string and err is nil.
|
||||||
pctx := dctx.proxyCtx
|
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||||
proto := pctx.Proto
|
proto := pctx.Proto
|
||||||
if proto == proxy.ProtoHTTPS {
|
if proto == proxy.ProtoHTTPS {
|
||||||
return processClientIDHTTPS(dctx)
|
return clientIDFromDNSContextHTTPS(pctx)
|
||||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||||
return resultCodeSuccess
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
srvConf := dctx.srv.conf
|
hostSrvName := s.conf.ServerName
|
||||||
hostSrvName := srvConf.TLSConfig.ServerName
|
|
||||||
if hostSrvName == "" {
|
if hostSrvName == "" {
|
||||||
return resultCodeSuccess
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cliSrvName := ""
|
cliSrvName := ""
|
||||||
if proto == proxy.ProtoTLS {
|
switch proto {
|
||||||
|
case proxy.ProtoTLS:
|
||||||
conn := pctx.Conn
|
conn := pctx.Conn
|
||||||
tc, ok := conn.(tlsConn)
|
tc, ok := conn.(tlsConn)
|
||||||
if !ok {
|
if !ok {
|
||||||
dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
|
return "", fmt.Errorf(
|
||||||
|
"proxy ctx conn of proto %s is %T, want *tls.Conn",
|
||||||
return resultCodeError
|
proto,
|
||||||
|
conn,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
cliSrvName = tc.ConnectionState().ServerName
|
cliSrvName = tc.ConnectionState().ServerName
|
||||||
} else if proto == proxy.ProtoQUIC {
|
case proxy.ProtoQUIC:
|
||||||
qs, ok := pctx.QUICSession.(quicSession)
|
qs, ok := pctx.QUICSession.(quicSession)
|
||||||
if !ok {
|
if !ok {
|
||||||
dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
|
return "", fmt.Errorf(
|
||||||
|
"proxy ctx quic session of proto %s is %T, want quic.Session",
|
||||||
return resultCodeError
|
proto,
|
||||||
|
pctx.QUICSession,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
cliSrvName = qs.ConnectionState().TLS.ServerName
|
cliSrvName = qs.ConnectionState().TLS.ServerName
|
||||||
}
|
}
|
||||||
|
|
||||||
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck)
|
clientID, err = clientIDFromClientServerName(
|
||||||
|
hostSrvName,
|
||||||
|
cliSrvName,
|
||||||
|
s.conf.StrictSNICheck,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
dctx.err = fmt.Errorf("client id check: %w", err)
|
return "", fmt.Errorf("client id check: %w", err)
|
||||||
|
|
||||||
return resultCodeError
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dctx.clientID = clientID
|
return clientID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// processClientID puts the clientID into the DNS context, if there is one.
|
||||||
|
func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) {
|
||||||
|
pctx := dctx.proxyCtx
|
||||||
|
|
||||||
|
var key [8]byte
|
||||||
|
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||||
|
clientIDData := s.clientIDCache.Get(key[:])
|
||||||
|
if clientIDData == nil {
|
||||||
|
return resultCodeSuccess
|
||||||
|
}
|
||||||
|
|
||||||
|
dctx.clientID = string(clientIDData)
|
||||||
|
|
||||||
return resultCodeSuccess
|
return resultCodeSuccess
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,15 +45,14 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
|
||||||
return cs
|
return cs
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessClientID(t *testing.T) {
|
func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
proto string
|
proto proxy.Proto
|
||||||
hostSrvName string
|
hostSrvName string
|
||||||
cliSrvName string
|
cliSrvName string
|
||||||
wantClientID string
|
wantClientID string
|
||||||
wantErrMsg string
|
wantErrMsg string
|
||||||
wantRes resultCode
|
|
||||||
strictSNI bool
|
strictSNI bool
|
||||||
}{{
|
}{{
|
||||||
name: "udp",
|
name: "udp",
|
||||||
|
@ -62,7 +61,6 @@ func TestProcessClientID(t *testing.T) {
|
||||||
cliSrvName: "",
|
cliSrvName: "",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
|
||||||
strictSNI: false,
|
strictSNI: false,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_no_client_id",
|
name: "tls_no_client_id",
|
||||||
|
@ -71,7 +69,6 @@ func TestProcessClientID(t *testing.T) {
|
||||||
cliSrvName: "example.com",
|
cliSrvName: "example.com",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
|
||||||
strictSNI: true,
|
strictSNI: true,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_no_client_server_name",
|
name: "tls_no_client_server_name",
|
||||||
|
@ -81,7 +78,6 @@ func TestProcessClientID(t *testing.T) {
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: `client id check: client server name "" ` +
|
wantErrMsg: `client id check: client server name "" ` +
|
||||||
`doesn't match host server name "example.com"`,
|
`doesn't match host server name "example.com"`,
|
||||||
wantRes: resultCodeError,
|
|
||||||
strictSNI: true,
|
strictSNI: true,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_no_client_server_name_no_strict",
|
name: "tls_no_client_server_name_no_strict",
|
||||||
|
@ -90,7 +86,6 @@ func TestProcessClientID(t *testing.T) {
|
||||||
cliSrvName: "",
|
cliSrvName: "",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
|
||||||
strictSNI: false,
|
strictSNI: false,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_client_id",
|
name: "tls_client_id",
|
||||||
|
@ -99,7 +94,6 @@ func TestProcessClientID(t *testing.T) {
|
||||||
cliSrvName: "cli.example.com",
|
cliSrvName: "cli.example.com",
|
||||||
wantClientID: "cli",
|
wantClientID: "cli",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
|
||||||
strictSNI: true,
|
strictSNI: true,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_client_id_hostname_error",
|
name: "tls_client_id_hostname_error",
|
||||||
|
@ -109,7 +103,6 @@ func TestProcessClientID(t *testing.T) {
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: `client id check: client server name "cli.example.net" ` +
|
wantErrMsg: `client id check: client server name "cli.example.net" ` +
|
||||||
`doesn't match host server name "example.com"`,
|
`doesn't match host server name "example.com"`,
|
||||||
wantRes: resultCodeError,
|
|
||||||
strictSNI: true,
|
strictSNI: true,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_invalid_client_id",
|
name: "tls_invalid_client_id",
|
||||||
|
@ -119,7 +112,6 @@ func TestProcessClientID(t *testing.T) {
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||||
`invalid char '!' at index 0`,
|
`invalid char '!' at index 0`,
|
||||||
wantRes: resultCodeError,
|
|
||||||
strictSNI: true,
|
strictSNI: true,
|
||||||
}, {
|
}, {
|
||||||
name: "tls_client_id_too_long",
|
name: "tls_client_id_too_long",
|
||||||
|
@ -131,7 +123,6 @@ func TestProcessClientID(t *testing.T) {
|
||||||
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
|
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
|
||||||
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
|
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
|
||||||
`label is too long, max: 63`,
|
`label is too long, max: 63`,
|
||||||
wantRes: resultCodeError,
|
|
||||||
strictSNI: true,
|
strictSNI: true,
|
||||||
}, {
|
}, {
|
||||||
name: "quic_client_id",
|
name: "quic_client_id",
|
||||||
|
@ -140,7 +131,6 @@ func TestProcessClientID(t *testing.T) {
|
||||||
cliSrvName: "cli.example.com",
|
cliSrvName: "cli.example.com",
|
||||||
wantClientID: "cli",
|
wantClientID: "cli",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
|
||||||
strictSNI: true,
|
strictSNI: true,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
@ -150,6 +140,7 @@ func TestProcessClientID(t *testing.T) {
|
||||||
ServerName: tc.hostSrvName,
|
ServerName: tc.hostSrvName,
|
||||||
StrictSNICheck: tc.strictSNI,
|
StrictSNICheck: tc.strictSNI,
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
conf: ServerConfig{TLSConfig: tlsConf},
|
conf: ServerConfig{TLSConfig: tlsConf},
|
||||||
}
|
}
|
||||||
|
@ -168,79 +159,68 @@ func TestProcessClientID(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dctx := &dnsContext{
|
pctx := &proxy.DNSContext{
|
||||||
srv: srv,
|
Proto: tc.proto,
|
||||||
proxyCtx: &proxy.DNSContext{
|
Conn: conn,
|
||||||
Proto: tc.proto,
|
QUICSession: qs,
|
||||||
Conn: conn,
|
|
||||||
QUICSession: qs,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res := processClientID(dctx)
|
clientID, err := srv.clientIDFromDNSContext(pctx)
|
||||||
assert.Equal(t, tc.wantRes, res)
|
assert.Equal(t, tc.wantClientID, clientID)
|
||||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
|
||||||
|
|
||||||
if tc.wantErrMsg == "" {
|
if tc.wantErrMsg == "" {
|
||||||
assert.NoError(t, dctx.err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
require.Error(t, dctx.err)
|
require.Error(t, err)
|
||||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
|
||||||
|
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessClientID_https(t *testing.T) {
|
func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
path string
|
path string
|
||||||
wantClientID string
|
wantClientID string
|
||||||
wantErrMsg string
|
wantErrMsg string
|
||||||
wantRes resultCode
|
|
||||||
}{{
|
}{{
|
||||||
name: "no_client_id",
|
name: "no_client_id",
|
||||||
path: "/dns-query",
|
path: "/dns-query",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
|
||||||
}, {
|
}, {
|
||||||
name: "no_client_id_slash",
|
name: "no_client_id_slash",
|
||||||
path: "/dns-query/",
|
path: "/dns-query/",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
|
||||||
}, {
|
}, {
|
||||||
name: "client_id",
|
name: "client_id",
|
||||||
path: "/dns-query/cli",
|
path: "/dns-query/cli",
|
||||||
wantClientID: "cli",
|
wantClientID: "cli",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
|
||||||
}, {
|
}, {
|
||||||
name: "client_id_slash",
|
name: "client_id_slash",
|
||||||
path: "/dns-query/cli/",
|
path: "/dns-query/cli/",
|
||||||
wantClientID: "cli",
|
wantClientID: "cli",
|
||||||
wantErrMsg: "",
|
wantErrMsg: "",
|
||||||
wantRes: resultCodeSuccess,
|
|
||||||
}, {
|
}, {
|
||||||
name: "bad_url",
|
name: "bad_url",
|
||||||
path: "/foo",
|
path: "/foo",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: `client id check: invalid path "/foo"`,
|
wantErrMsg: `client id check: invalid path "/foo"`,
|
||||||
wantRes: resultCodeError,
|
|
||||||
}, {
|
}, {
|
||||||
name: "extra",
|
name: "extra",
|
||||||
path: "/dns-query/cli/foo",
|
path: "/dns-query/cli/foo",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
|
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
|
||||||
wantRes: resultCodeError,
|
|
||||||
}, {
|
}, {
|
||||||
name: "invalid_client_id",
|
name: "invalid_client_id",
|
||||||
path: "/dns-query/!!!",
|
path: "/dns-query/!!!",
|
||||||
wantClientID: "",
|
wantClientID: "",
|
||||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||||
`invalid char '!' at index 0`,
|
`invalid char '!' at index 0`,
|
||||||
wantRes: resultCodeError,
|
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
@ -251,23 +231,20 @@ func TestProcessClientID_https(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
dctx := &dnsContext{
|
pctx := &proxy.DNSContext{
|
||||||
proxyCtx: &proxy.DNSContext{
|
Proto: proxy.ProtoHTTPS,
|
||||||
Proto: proxy.ProtoHTTPS,
|
HTTPRequest: r,
|
||||||
HTTPRequest: r,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
res := processClientID(dctx)
|
clientID, err := clientIDFromDNSContextHTTPS(pctx)
|
||||||
assert.Equal(t, tc.wantRes, res)
|
assert.Equal(t, tc.wantClientID, clientID)
|
||||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
|
||||||
|
|
||||||
if tc.wantErrMsg == "" {
|
if tc.wantErrMsg == "" {
|
||||||
assert.NoError(t, dctx.err)
|
assert.NoError(t, err)
|
||||||
} else {
|
} else {
|
||||||
require.Error(t, dctx.err)
|
require.Error(t, err)
|
||||||
|
|
||||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -331,7 +331,7 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||||
upstreams = aghstrings.FilterOut(upstreams, aghstrings.IsCommentOrEmpty)
|
upstreams = aghstrings.FilterOut(upstreams, aghstrings.IsCommentOrEmpty)
|
||||||
upstreamConfig, err := proxy.ParseUpstreamsConfig(
|
upstreamConfig, err := proxy.ParseUpstreamsConfig(
|
||||||
upstreams,
|
upstreams,
|
||||||
upstream.Options{
|
&upstream.Options{
|
||||||
Bootstrap: s.conf.BootstrapDNS,
|
Bootstrap: s.conf.BootstrapDNS,
|
||||||
Timeout: s.conf.UpstreamTimeout,
|
Timeout: s.conf.UpstreamTimeout,
|
||||||
},
|
},
|
||||||
|
@ -342,10 +342,10 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||||
|
|
||||||
if len(upstreamConfig.Upstreams) == 0 {
|
if len(upstreamConfig.Upstreams) == 0 {
|
||||||
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
|
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
|
||||||
var uc proxy.UpstreamConfig
|
var uc *proxy.UpstreamConfig
|
||||||
uc, err = proxy.ParseUpstreamsConfig(
|
uc, err = proxy.ParseUpstreamsConfig(
|
||||||
defaultDNS,
|
defaultDNS,
|
||||||
upstream.Options{
|
&upstream.Options{
|
||||||
Bootstrap: s.conf.BootstrapDNS,
|
Bootstrap: s.conf.BootstrapDNS,
|
||||||
Timeout: s.conf.UpstreamTimeout,
|
Timeout: s.conf.UpstreamTimeout,
|
||||||
},
|
},
|
||||||
|
@ -356,7 +356,8 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||||
upstreamConfig.Upstreams = uc.Upstreams
|
upstreamConfig.Upstreams = uc.Upstreams
|
||||||
}
|
}
|
||||||
|
|
||||||
s.conf.UpstreamConfig = &upstreamConfig
|
s.conf.UpstreamConfig = upstreamConfig
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -89,7 +89,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||||
s.processInternalHosts,
|
s.processInternalHosts,
|
||||||
s.processRestrictLocal,
|
s.processRestrictLocal,
|
||||||
s.processInternalIPAddrs,
|
s.processInternalIPAddrs,
|
||||||
processClientID,
|
s.processClientID,
|
||||||
processFilteringBeforeRequest,
|
processFilteringBeforeRequest,
|
||||||
s.processLocalPTR,
|
s.processLocalPTR,
|
||||||
s.processUpstream,
|
s.processUpstream,
|
||||||
|
@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||||
s.tableHostToIP = t
|
s.tableHostToIP = t
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) setTableIPToHost(t ipToHostTable) {
|
func (s *Server) setTableIPToHost(t *aghnet.IPMap) {
|
||||||
s.tableIPToHostLock.Lock()
|
s.tableIPToHostLock.Lock()
|
||||||
defer s.tableIPToHostLock.Unlock()
|
defer s.tableIPToHostLock.Unlock()
|
||||||
|
|
||||||
|
@ -188,13 +188,13 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var hostToIP hostToIPTable
|
var hostToIP hostToIPTable
|
||||||
var ipToHost ipToHostTable
|
var ipToHost *aghnet.IPMap
|
||||||
if add {
|
if add {
|
||||||
hostToIP = make(hostToIPTable)
|
|
||||||
ipToHost = make(ipToHostTable)
|
|
||||||
|
|
||||||
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||||
|
|
||||||
|
hostToIP = make(hostToIPTable, len(ll))
|
||||||
|
ipToHost = aghnet.NewIPMap(len(ll))
|
||||||
|
|
||||||
for _, l := range ll {
|
for _, l := range ll {
|
||||||
// TODO(a.garipov): Remove this after we're finished
|
// TODO(a.garipov): Remove this after we're finished
|
||||||
// with the client hostname validations in the DHCP
|
// with the client hostname validations in the DHCP
|
||||||
|
@ -210,14 +210,14 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||||
|
|
||||||
lowhost := strings.ToLower(l.Hostname)
|
lowhost := strings.ToLower(l.Hostname)
|
||||||
|
|
||||||
ipToHost[l.IP.String()] = lowhost
|
ipToHost.Set(l.IP, lowhost)
|
||||||
|
|
||||||
ip := make(net.IP, 4)
|
ip := make(net.IP, 4)
|
||||||
copy(ip, l.IP.To4())
|
copy(ip, l.IP.To4())
|
||||||
hostToIP[lowhost] = ip
|
hostToIP[lowhost] = ip
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("dns: added %d A/PTR entries from DHCP", len(ipToHost))
|
log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len())
|
||||||
}
|
}
|
||||||
|
|
||||||
s.setTableHostToIP(hostToIP)
|
s.setTableHostToIP(hostToIP)
|
||||||
|
@ -377,7 +377,15 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
|
||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
host, ok = s.tableIPToHost[ip.String()]
|
var v interface{}
|
||||||
|
v, ok = s.tableIPToHost.Get(ip)
|
||||||
|
|
||||||
|
var typOK bool
|
||||||
|
if host, typOK = v.(string); !typOK {
|
||||||
|
log.Error("dns: bad type %T in tableIPToHost for %s", v, ip)
|
||||||
|
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
return host, ok
|
return host, ok
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
|
"github.com/AdguardTeam/golibs/cache"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
@ -26,6 +27,11 @@ import (
|
||||||
// DefaultTimeout is the default upstream timeout
|
// DefaultTimeout is the default upstream timeout
|
||||||
const DefaultTimeout = 10 * time.Second
|
const DefaultTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// defaultClientIDCacheCount is the default count of items in the LRU client ID
|
||||||
|
// cache. The assumption here is that there won't be more than this many
|
||||||
|
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||||
|
const defaultClientIDCacheCount = 1024
|
||||||
|
|
||||||
const (
|
const (
|
||||||
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||||
parentalBlockHost = "family-block.dns.adguard.com"
|
parentalBlockHost = "family-block.dns.adguard.com"
|
||||||
|
@ -44,12 +50,6 @@ var webRegistered bool
|
||||||
// hostToIPTable is an alias for the type of Server.tableHostToIP.
|
// hostToIPTable is an alias for the type of Server.tableHostToIP.
|
||||||
type hostToIPTable = map[string]net.IP
|
type hostToIPTable = map[string]net.IP
|
||||||
|
|
||||||
// ipToHostTable is an alias for the type of Server.tableIPToHost.
|
|
||||||
//
|
|
||||||
// TODO(a.garipov): Define an IPMap type in aghnet and use here and in other
|
|
||||||
// places?
|
|
||||||
type ipToHostTable = map[string]string
|
|
||||||
|
|
||||||
// Server is the main way to start a DNS server.
|
// Server is the main way to start a DNS server.
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
|
@ -81,9 +81,13 @@ type Server struct {
|
||||||
tableHostToIP hostToIPTable
|
tableHostToIP hostToIPTable
|
||||||
tableHostToIPLock sync.Mutex
|
tableHostToIPLock sync.Mutex
|
||||||
|
|
||||||
tableIPToHost ipToHostTable
|
tableIPToHost *aghnet.IPMap
|
||||||
tableIPToHostLock sync.Mutex
|
tableIPToHostLock sync.Mutex
|
||||||
|
|
||||||
|
// clientIDCache is a temporary storage for clientIDs that were
|
||||||
|
// extracted during the BeforeRequestHandler stage.
|
||||||
|
clientIDCache cache.Cache
|
||||||
|
|
||||||
// DNS proxy instance for internal usage
|
// DNS proxy instance for internal usage
|
||||||
// We don't Start() it and so no listen port is required.
|
// We don't Start() it and so no listen port is required.
|
||||||
internalProxy *proxy.Proxy
|
internalProxy *proxy.Proxy
|
||||||
|
@ -152,6 +156,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||||
subnetDetector: p.SubnetDetector,
|
subnetDetector: p.SubnetDetector,
|
||||||
localDomainSuffix: localDomainSuffix,
|
localDomainSuffix: localDomainSuffix,
|
||||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||||
|
clientIDCache: cache.New(cache.Config{
|
||||||
|
EnableLRU: true,
|
||||||
|
MaxCount: defaultClientIDCacheCount,
|
||||||
|
}),
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(e.burkov): Enable the refresher after the actual implementation
|
// TODO(e.burkov): Enable the refresher after the actual implementation
|
||||||
|
@ -414,19 +422,22 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||||
|
|
||||||
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
|
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
|
||||||
|
|
||||||
var upsConfig proxy.UpstreamConfig
|
var upsConfig *proxy.UpstreamConfig
|
||||||
upsConfig, err = proxy.ParseUpstreamsConfig(localAddrs, upstream.Options{
|
upsConfig, err = proxy.ParseUpstreamsConfig(
|
||||||
Bootstrap: bootstraps,
|
localAddrs,
|
||||||
Timeout: defaultLocalTimeout,
|
&upstream.Options{
|
||||||
// TODO(e.burkov): Should we verify server's ceritificates?
|
Bootstrap: bootstraps,
|
||||||
})
|
Timeout: defaultLocalTimeout,
|
||||||
|
// TODO(e.burkov): Should we verify server's ceritificates?
|
||||||
|
},
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parsing upstreams: %w", err)
|
return fmt.Errorf("parsing upstreams: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.localResolvers = &proxy.Proxy{
|
s.localResolvers = &proxy.Proxy{
|
||||||
Config: proxy.Config{
|
Config: proxy.Config{
|
||||||
UpstreamConfig: &upsConfig,
|
UpstreamConfig: upsConfig,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -577,11 +588,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsBlockedIP - return TRUE if this client should be blocked
|
// IsBlockedClient returns true if the client is blocked by the current access
|
||||||
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
|
// settings.
|
||||||
if ip == nil {
|
func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
|
||||||
return false, ""
|
s.serverLock.RLock()
|
||||||
|
defer s.serverLock.RUnlock()
|
||||||
|
|
||||||
|
allowlistMode := s.access.allowlistMode()
|
||||||
|
blockedByIP, rule := s.access.isBlockedIP(ip)
|
||||||
|
blockedByClientID := s.access.isBlockedClientID(clientID)
|
||||||
|
|
||||||
|
// Allow if at least one of the checks allows in allowlist mode, but
|
||||||
|
// block if at least one of the checks blocks in blocklist mode.
|
||||||
|
if allowlistMode && blockedByIP && blockedByClientID {
|
||||||
|
log.Debug("client %s (id %q) is not in access allowlist", ip, clientID)
|
||||||
|
|
||||||
|
// Return now without substituting the empty rule for the
|
||||||
|
// clientID because the rule can't be empty here.
|
||||||
|
return true, rule
|
||||||
|
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
|
||||||
|
log.Debug("client %s (id %q) is in access blocklist", ip, clientID)
|
||||||
|
|
||||||
|
blocked = true
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.access.IsBlockedIP(ip)
|
if rule == "" {
|
||||||
|
rule = clientID
|
||||||
|
}
|
||||||
|
|
||||||
|
return blocked, rule
|
||||||
}
|
}
|
||||||
|
|
|
@ -257,19 +257,22 @@ func TestServer(t *testing.T) {
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
proto string
|
net string
|
||||||
|
proto proxy.Proto
|
||||||
}{{
|
}{{
|
||||||
name: "message_over_udp",
|
name: "message_over_udp",
|
||||||
|
net: "",
|
||||||
proto: proxy.ProtoUDP,
|
proto: proxy.ProtoUDP,
|
||||||
}, {
|
}, {
|
||||||
name: "message_over_tcp",
|
name: "message_over_tcp",
|
||||||
|
net: "tcp",
|
||||||
proto: proxy.ProtoTCP,
|
proto: proxy.ProtoTCP,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
addr := s.dnsProxy.Addr(tc.proto)
|
addr := s.dnsProxy.Addr(tc.proto)
|
||||||
client := dns.Client{Net: tc.proto}
|
client := dns.Client{Net: tc.net}
|
||||||
|
|
||||||
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
|
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
|
||||||
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||||
|
@ -324,7 +327,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||||
// Message over UDP.
|
// Message over UDP.
|
||||||
req := createGoogleATestMessage()
|
req := createGoogleATestMessage()
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
client := dns.Client{Net: proxy.ProtoUDP}
|
client := &dns.Client{}
|
||||||
|
|
||||||
reply, _, err := client.Exchange(req, addr.String())
|
reply, _, err := client.Exchange(req, addr.String())
|
||||||
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||||
|
@ -376,7 +379,7 @@ func TestDoQServer(t *testing.T) {
|
||||||
|
|
||||||
// Create a DNS-over-QUIC upstream.
|
// Create a DNS-over-QUIC upstream.
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
|
addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
|
||||||
opts := upstream.Options{InsecureSkipVerify: true}
|
opts := &upstream.Options{InsecureSkipVerify: true}
|
||||||
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
|
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -420,7 +423,7 @@ func TestServerRace(t *testing.T) {
|
||||||
|
|
||||||
// Message over UDP.
|
// Message over UDP.
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
|
conn, err := dns.Dial("udp", addr.String())
|
||||||
require.NoErrorf(t, err, "cannot connect to the proxy: %s", err)
|
require.NoErrorf(t, err, "cannot connect to the proxy: %s", err)
|
||||||
|
|
||||||
sendTestMessagesAsync(t, conn)
|
sendTestMessagesAsync(t, conn)
|
||||||
|
@ -445,7 +448,7 @@ func TestSafeSearch(t *testing.T) {
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||||
client := dns.Client{Net: proxy.ProtoUDP}
|
client := &dns.Client{}
|
||||||
|
|
||||||
yandexIP := net.IP{213, 180, 193, 56}
|
yandexIP := net.IP{213, 180, 193, 56}
|
||||||
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||||
|
@ -507,7 +510,6 @@ func TestInvalidRequest(t *testing.T) {
|
||||||
|
|
||||||
// Send a DNS request without question.
|
// Send a DNS request without question.
|
||||||
_, _, err := (&dns.Client{
|
_, _, err := (&dns.Client{
|
||||||
Net: proxy.ProtoUDP,
|
|
||||||
Timeout: 500 * time.Millisecond,
|
Timeout: 500 * time.Millisecond,
|
||||||
}).Exchange(&req, addr)
|
}).Exchange(&req, addr)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package dnsforward
|
package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -11,23 +12,39 @@ import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
// beforeRequestHandler is the handler that is called before any other
|
||||||
ip := aghnet.IPFromAddr(d.Addr)
|
// processing, including logs. It performs access checks and puts the client
|
||||||
disallowed, _ := s.access.IsBlockedIP(ip)
|
// ID, if there is one, into the server's cache.
|
||||||
if disallowed {
|
func (s *Server) beforeRequestHandler(
|
||||||
log.Tracef("Client IP %s is blocked by settings", ip)
|
_ *proxy.Proxy,
|
||||||
|
pctx *proxy.DNSContext,
|
||||||
|
) (reply bool, err error) {
|
||||||
|
ip := aghnet.IPFromAddr(pctx.Addr)
|
||||||
|
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("getting clientid: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
blocked, _ := s.IsBlockedClient(ip, clientID)
|
||||||
|
if blocked {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(d.Req.Question) == 1 {
|
if len(pctx.Req.Question) == 1 {
|
||||||
host := strings.TrimSuffix(d.Req.Question[0].Name, ".")
|
host := strings.TrimSuffix(pctx.Req.Question[0].Name, ".")
|
||||||
if s.access.IsBlockedDomain(host) {
|
if s.access.isBlockedHost(host) {
|
||||||
log.Tracef("domain %s is blocked by access settings", host)
|
log.Debug("host %s is in access blocklist", host)
|
||||||
|
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if clientID != "" {
|
||||||
|
key := [8]byte{}
|
||||||
|
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||||
|
s.clientIDCache.Set(key[:], []byte(clientID))
|
||||||
|
}
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (string, error) {
|
||||||
return boot, fmt.Errorf("invalid bootstrap server address: empty")
|
return boot, fmt.Errorf("invalid bootstrap server address: empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := upstream.NewResolver(boot, upstream.Options{Timeout: 0}); err != nil {
|
if _, err := upstream.NewResolver(boot, nil); err != nil {
|
||||||
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
|
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -348,7 +348,7 @@ func ValidateUpstreams(upstreams []string) (err error) {
|
||||||
|
|
||||||
_, err = proxy.ParseUpstreamsConfig(
|
_, err = proxy.ParseUpstreamsConfig(
|
||||||
upstreams,
|
upstreams,
|
||||||
upstream.Options{
|
&upstream.Options{
|
||||||
Bootstrap: []string{},
|
Bootstrap: []string{},
|
||||||
Timeout: DefaultTimeout,
|
Timeout: DefaultTimeout,
|
||||||
},
|
},
|
||||||
|
@ -546,7 +546,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
|
||||||
|
|
||||||
log.Debug("checking if dns server %q works...", input)
|
log.Debug("checking if dns server %q works...", input)
|
||||||
var u upstream.Upstream
|
var u upstream.Upstream
|
||||||
u, err = upstream.AddressToUpstream(input, upstream.Options{
|
u, err = upstream.AddressToUpstream(input, &upstream.Options{
|
||||||
Bootstrap: bootstrap,
|
Bootstrap: bootstrap,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
})
|
})
|
||||||
|
|
|
@ -46,7 +46,7 @@ func (l *testStats) Update(e stats.Entry) {
|
||||||
func TestProcessQueryLogsAndStats(t *testing.T) {
|
func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
proto string
|
proto proxy.Proto
|
||||||
addr net.Addr
|
addr net.Addr
|
||||||
clientID string
|
clientID string
|
||||||
wantLogProto querylog.ClientProto
|
wantLogProto querylog.ClientProto
|
||||||
|
@ -156,7 +156,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||||
wantStatResult: stats.RParental,
|
wantStatResult: stats.RParental,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
|
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
|
|
@ -49,7 +49,7 @@ func (d *DNSFilter) initSecurityServices() error {
|
||||||
var err error
|
var err error
|
||||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||||
d.parentalServer = defaultParentalServer
|
d.parentalServer = defaultParentalServer
|
||||||
opts := upstream.Options{
|
opts := &upstream.Options{
|
||||||
Timeout: dnsTimeout,
|
Timeout: dnsTimeout,
|
||||||
ServerIPAddrs: []net.IP{
|
ServerIPAddrs: []net.IP{
|
||||||
{94, 140, 14, 15},
|
{94, 140, 14, 15},
|
||||||
|
|
|
@ -78,10 +78,13 @@ type RuntimeClientWHOISInfo struct {
|
||||||
type clientsContainer struct {
|
type clientsContainer struct {
|
||||||
// TODO(a.garipov): Perhaps use a number of separate indices for
|
// TODO(a.garipov): Perhaps use a number of separate indices for
|
||||||
// different types (string, net.IP, and so on).
|
// different types (string, net.IP, and so on).
|
||||||
list map[string]*Client // name -> client
|
list map[string]*Client // name -> client
|
||||||
idIndex map[string]*Client // ID -> client
|
idIndex map[string]*Client // ID -> client
|
||||||
ipToRC map[string]*RuntimeClient // IP -> runtime client
|
|
||||||
lock sync.Mutex
|
// ipToRC is the IP address to *RuntimeClient map.
|
||||||
|
ipToRC *aghnet.IPMap
|
||||||
|
|
||||||
|
lock sync.Mutex
|
||||||
|
|
||||||
allTags *aghstrings.Set
|
allTags *aghstrings.Set
|
||||||
|
|
||||||
|
@ -109,7 +112,7 @@ func (clients *clientsContainer) Init(
|
||||||
}
|
}
|
||||||
clients.list = make(map[string]*Client)
|
clients.list = make(map[string]*Client)
|
||||||
clients.idIndex = make(map[string]*Client)
|
clients.idIndex = make(map[string]*Client)
|
||||||
clients.ipToRC = make(map[string]*RuntimeClient)
|
clients.ipToRC = aghnet.NewIPMap(0)
|
||||||
|
|
||||||
clients.allTags = aghstrings.NewSet(clientTags...)
|
clients.allTags = aghstrings.NewSet(clientTags...)
|
||||||
|
|
||||||
|
@ -250,18 +253,17 @@ func (clients *clientsContainer) onHostsChanged() {
|
||||||
clients.addFromHostsFile()
|
clients.addFromHostsFile()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exists checks if client with this ID already exists.
|
// Exists checks if client with this IP address already exists.
|
||||||
func (clients *clientsContainer) Exists(id string, source clientSource) (ok bool) {
|
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
_, ok = clients.findLocked(id)
|
_, ok = clients.findLocked(ip.String())
|
||||||
if ok {
|
if ok {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
var rc *RuntimeClient
|
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||||
rc, ok = clients.ipToRC[id]
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -288,13 +290,14 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
var name string
|
var name string
|
||||||
whois := &querylog.ClientWHOIS{}
|
whois := &querylog.ClientWHOIS{}
|
||||||
|
ip := net.ParseIP(id)
|
||||||
|
|
||||||
c, ok := clients.Find(id)
|
c, ok := clients.Find(id)
|
||||||
if ok {
|
if ok {
|
||||||
name = c.Name
|
name = c.Name
|
||||||
} else {
|
} else if ip != nil {
|
||||||
var rc RuntimeClient
|
var rc *RuntimeClient
|
||||||
rc, ok = clients.FindRuntimeClient(id)
|
rc, ok = clients.FindRuntimeClient(ip)
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -303,8 +306,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
|
||||||
whois = toQueryLogWHOIS(rc.WHOISInfo)
|
whois = toQueryLogWHOIS(rc.WHOISInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP(id)
|
disallowed, disallowedRule := clients.dnsServer.IsBlockedClient(ip, id)
|
||||||
disallowed, disallowedRule := clients.dnsServer.IsBlockedIP(ip)
|
|
||||||
|
|
||||||
return &querylog.Client{
|
return &querylog.Client{
|
||||||
Name: name,
|
Name: name,
|
||||||
|
@ -356,10 +358,10 @@ func (clients *clientsContainer) findUpstreams(
|
||||||
return c.upstreamConfig, nil
|
return c.upstreamConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var conf proxy.UpstreamConfig
|
var conf *proxy.UpstreamConfig
|
||||||
conf, err = proxy.ParseUpstreamsConfig(
|
conf, err = proxy.ParseUpstreamsConfig(
|
||||||
upstreams,
|
upstreams,
|
||||||
upstream.Options{
|
&upstream.Options{
|
||||||
Bootstrap: config.DNS.BootstrapDNS,
|
Bootstrap: config.DNS.BootstrapDNS,
|
||||||
Timeout: config.DNS.UpstreamTimeout.Duration,
|
Timeout: config.DNS.UpstreamTimeout.Duration,
|
||||||
},
|
},
|
||||||
|
@ -368,9 +370,9 @@ func (clients *clientsContainer) findUpstreams(
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.upstreamConfig = &conf
|
c.upstreamConfig = conf
|
||||||
|
|
||||||
return &conf, nil
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// findLocked searches for a client by its ID. For internal use only.
|
// findLocked searches for a client by its ID. For internal use only.
|
||||||
|
@ -423,22 +425,35 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// findRuntimeClientLocked finds a runtime client by their IP address. For
|
||||||
|
// internal use only.
|
||||||
|
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||||
|
var v interface{}
|
||||||
|
v, ok = clients.ipToRC.Get(ip)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
rc, ok = v.(*RuntimeClient)
|
||||||
|
if !ok {
|
||||||
|
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return rc, true
|
||||||
|
}
|
||||||
|
|
||||||
// FindRuntimeClient finds a runtime client by their IP.
|
// FindRuntimeClient finds a runtime client by their IP.
|
||||||
func (clients *clientsContainer) FindRuntimeClient(ip string) (RuntimeClient, bool) {
|
func (clients *clientsContainer) FindRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||||
ipAddr := net.ParseIP(ip)
|
if ip == nil {
|
||||||
if ipAddr == nil {
|
return nil, false
|
||||||
return RuntimeClient{}, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
rc, ok := clients.ipToRC[ip]
|
return clients.findRuntimeClientLocked(ip)
|
||||||
if ok {
|
|
||||||
return *rc, true
|
|
||||||
}
|
|
||||||
|
|
||||||
return RuntimeClient{}, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// check validates the client.
|
// check validates the client.
|
||||||
|
@ -621,17 +636,17 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetWHOISInfo sets the WHOIS information for a client.
|
// SetWHOISInfo sets the WHOIS information for a client.
|
||||||
func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISInfo) {
|
func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
_, ok := clients.findLocked(ip)
|
_, ok := clients.findLocked(ip.String())
|
||||||
if ok {
|
if ok {
|
||||||
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rc, ok := clients.ipToRC[ip]
|
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||||
if ok {
|
if ok {
|
||||||
rc.WHOISInfo = wi
|
rc.WHOISInfo = wi
|
||||||
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
|
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
|
||||||
|
@ -646,14 +661,15 @@ func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISI
|
||||||
}
|
}
|
||||||
|
|
||||||
rc.WHOISInfo = wi
|
rc.WHOISInfo = wi
|
||||||
clients.ipToRC[ip] = rc
|
|
||||||
|
clients.ipToRC.Set(ip, rc)
|
||||||
|
|
||||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
|
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
|
||||||
// taken into account. ok is true if the pairing was added.
|
// taken into account. ok is true if the pairing was added.
|
||||||
func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) {
|
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
|
@ -663,9 +679,9 @@ func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok
|
||||||
}
|
}
|
||||||
|
|
||||||
// addHostLocked adds a new IP-hostname pairing. For internal use only.
|
// addHostLocked adds a new IP-hostname pairing. For internal use only.
|
||||||
func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource) (ok bool) {
|
func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) {
|
||||||
var rc *RuntimeClient
|
var rc *RuntimeClient
|
||||||
rc, ok = clients.ipToRC[ip]
|
rc, ok = clients.findRuntimeClientLocked(ip)
|
||||||
if ok {
|
if ok {
|
||||||
if rc.Source > src {
|
if rc.Source > src {
|
||||||
return false
|
return false
|
||||||
|
@ -679,10 +695,10 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource
|
||||||
WHOISInfo: &RuntimeClientWHOISInfo{},
|
WHOISInfo: &RuntimeClientWHOISInfo{},
|
||||||
}
|
}
|
||||||
|
|
||||||
clients.ipToRC[ip] = rc
|
clients.ipToRC.Set(ip, rc)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("clients: added %q -> %q [%d]", ip, host, len(clients.ipToRC))
|
log.Debug("clients: added %s -> %q [%d]", ip, host, clients.ipToRC.Len())
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -690,12 +706,21 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource
|
||||||
// rmHostsBySrc removes all entries that match the specified source.
|
// rmHostsBySrc removes all entries that match the specified source.
|
||||||
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||||
n := 0
|
n := 0
|
||||||
for k, v := range clients.ipToRC {
|
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||||
if v.Source == src {
|
rc, ok := v.(*RuntimeClient)
|
||||||
delete(clients.ipToRC, k)
|
if !ok {
|
||||||
|
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if rc.Source == src {
|
||||||
|
clients.ipToRC.Del(ip)
|
||||||
n++
|
n++
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
log.Debug("clients: removed %d client aliases", n)
|
log.Debug("clients: removed %d client aliases", n)
|
||||||
}
|
}
|
||||||
|
@ -715,16 +740,23 @@ func (clients *clientsContainer) addFromHostsFile() {
|
||||||
clients.rmHostsBySrc(ClientSourceHostsFile)
|
clients.rmHostsBySrc(ClientSourceHostsFile)
|
||||||
|
|
||||||
n := 0
|
n := 0
|
||||||
for ip, names := range hosts {
|
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||||
|
names, ok := v.([]string)
|
||||||
|
if !ok {
|
||||||
|
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||||
|
}
|
||||||
|
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
ok := clients.addHostLocked(ip, name, ClientSourceHostsFile)
|
ok = clients.addHostLocked(ip, name, ClientSourceHostsFile)
|
||||||
if ok {
|
if ok {
|
||||||
n++
|
n++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("Clients: added %d client aliases from system hosts-file", n)
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
log.Debug("clients: added %d client aliases from system hosts-file", n)
|
||||||
}
|
}
|
||||||
|
|
||||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||||
|
@ -752,15 +784,16 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||||
// TODO(a.garipov): Rewrite to use bufio.Scanner.
|
// TODO(a.garipov): Rewrite to use bufio.Scanner.
|
||||||
lines := strings.Split(string(data), "\n")
|
lines := strings.Split(string(data), "\n")
|
||||||
for _, ln := range lines {
|
for _, ln := range lines {
|
||||||
open := strings.Index(ln, " (")
|
lparen := strings.Index(ln, " (")
|
||||||
close := strings.Index(ln, ") ")
|
rparen := strings.Index(ln, ") ")
|
||||||
if open == -1 || close == -1 || open >= close {
|
if lparen == -1 || rparen == -1 || lparen >= rparen {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
host := ln[:open]
|
host := ln[:lparen]
|
||||||
ip := ln[open+2 : close]
|
ipStr := ln[lparen+2 : rparen]
|
||||||
if aghnet.ValidateDomainName(host) != nil || net.ParseIP(ip) == nil {
|
ip := net.ParseIP(ipStr)
|
||||||
|
if aghnet.ValidateDomainName(host) != nil || ip == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -796,7 +829,7 @@ func (clients *clientsContainer) updateFromDHCP(add bool) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ok := clients.addHostLocked(l.IP.String(), l.Hostname, ClientSourceDHCP)
|
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
|
||||||
if ok {
|
if ok {
|
||||||
n++
|
n++
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
ok, err := clients.Add(c)
|
ok, err := clients.Add(c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
c = &Client{
|
c = &Client{
|
||||||
|
@ -35,23 +36,27 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
ok, err = clients.Add(c)
|
ok, err = clients.Add(c)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
c, ok = clients.Find("1.1.1.1")
|
c, ok = clients.Find("1.1.1.1")
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, "client1", c.Name)
|
assert.Equal(t, "client1", c.Name)
|
||||||
|
|
||||||
c, ok = clients.Find("1:2:3::4")
|
c, ok = clients.Find("1:2:3::4")
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, "client1", c.Name)
|
assert.Equal(t, "client1", c.Name)
|
||||||
|
|
||||||
c, ok = clients.Find("2.2.2.2")
|
c, ok = clients.Find("2.2.2.2")
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, "client2", c.Name)
|
assert.Equal(t, "client2", c.Name)
|
||||||
|
|
||||||
assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
|
assert.False(t, clients.Exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile))
|
||||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
assert.True(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
|
||||||
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
|
assert.True(t, clients.Exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add_fail_name", func(t *testing.T) {
|
t.Run("add_fail_name", func(t *testing.T) {
|
||||||
|
@ -101,8 +106,8 @@ func TestClients(t *testing.T) {
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
assert.False(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
|
||||||
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
assert.True(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
|
||||||
|
|
||||||
err = clients.Update("client1", &Client{
|
err = clients.Update("client1", &Client{
|
||||||
IDs: []string{"1.1.1.2"},
|
IDs: []string{"1.1.1.2"},
|
||||||
|
@ -113,21 +118,25 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
c, ok := clients.Find("1.1.1.2")
|
c, ok := clients.Find("1.1.1.2")
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
assert.Equal(t, "client1-renamed", c.Name)
|
assert.Equal(t, "client1-renamed", c.Name)
|
||||||
assert.True(t, c.UseOwnSettings)
|
assert.True(t, c.UseOwnSettings)
|
||||||
|
|
||||||
nilCli, ok := clients.list["client1"]
|
nilCli, ok := clients.list["client1"]
|
||||||
require.False(t, ok)
|
require.False(t, ok)
|
||||||
|
|
||||||
assert.Nil(t, nilCli)
|
assert.Nil(t, nilCli)
|
||||||
|
|
||||||
require.Len(t, c.IDs, 1)
|
require.Len(t, c.IDs, 1)
|
||||||
|
|
||||||
assert.Equal(t, "1.1.1.2", c.IDs[0])
|
assert.Equal(t, "1.1.1.2", c.IDs[0])
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("del_success", func(t *testing.T) {
|
t.Run("del_success", func(t *testing.T) {
|
||||||
ok := clients.Del("client1-renamed")
|
ok := clients.Del("client1-renamed")
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
|
||||||
|
assert.False(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("del_fail", func(t *testing.T) {
|
t.Run("del_fail", func(t *testing.T) {
|
||||||
|
@ -136,37 +145,44 @@ func TestClients(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("addhost_success", func(t *testing.T) {
|
t.Run("addhost_success", func(t *testing.T) {
|
||||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
ip := net.IP{1, 1, 1, 1}
|
||||||
|
|
||||||
|
ok, err := clients.AddHost(ip, "host", ClientSourceARP)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
ok, err = clients.AddHost(ip, "host2", ClientSourceARP)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
assert.True(t, clients.Exists(ip, ClientSourceHostsFile))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
||||||
ok, err := clients.AddHost("1.2.3.4", "from_arp", ClientSourceARP)
|
ip := net.IP{1, 2, 3, 4}
|
||||||
|
|
||||||
|
ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
assert.True(t, clients.Exists(ip, ClientSourceARP))
|
||||||
|
|
||||||
assert.True(t, clients.Exists("1.2.3.4", ClientSourceARP))
|
ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
|
||||||
|
|
||||||
ok, err = clients.AddHost("1.2.3.4", "from_dhcp", ClientSourceDHCP)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
|
||||||
|
|
||||||
assert.True(t, clients.Exists("1.2.3.4", ClientSourceDHCP))
|
assert.True(t, ok)
|
||||||
|
assert.True(t, clients.Exists(ip, ClientSourceDHCP))
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("addhost_fail", func(t *testing.T) {
|
t.Run("addhost_fail", func(t *testing.T) {
|
||||||
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
})
|
})
|
||||||
|
@ -183,31 +199,39 @@ func TestClientsWHOIS(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("new_client", func(t *testing.T) {
|
t.Run("new_client", func(t *testing.T) {
|
||||||
clients.SetWHOISInfo("1.1.1.255", whois)
|
ip := net.IP{1, 1, 1, 255}
|
||||||
|
clients.SetWHOISInfo(ip, whois)
|
||||||
|
v, _ := clients.ipToRC.Get(ip)
|
||||||
|
require.NotNil(t, v)
|
||||||
|
|
||||||
require.NotNil(t, clients.ipToRC["1.1.1.255"])
|
rc, ok := v.(*RuntimeClient)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotNil(t, rc)
|
||||||
|
|
||||||
h := clients.ipToRC["1.1.1.255"]
|
assert.Equal(t, rc.WHOISInfo, whois)
|
||||||
require.NotNil(t, h)
|
|
||||||
|
|
||||||
assert.Equal(t, h.WHOISInfo, whois)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("existing_auto-client", func(t *testing.T) {
|
t.Run("existing_auto-client", func(t *testing.T) {
|
||||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
ip := net.IP{1, 1, 1, 1}
|
||||||
|
ok, err := clients.AddHost(ip, "host", ClientSourceRDNS)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
clients.SetWHOISInfo("1.1.1.1", whois)
|
clients.SetWHOISInfo(ip, whois)
|
||||||
|
v, _ := clients.ipToRC.Get(ip)
|
||||||
|
require.NotNil(t, v)
|
||||||
|
|
||||||
require.NotNil(t, clients.ipToRC["1.1.1.1"])
|
rc, ok := v.(*RuntimeClient)
|
||||||
h := clients.ipToRC["1.1.1.1"]
|
require.True(t, ok)
|
||||||
require.NotNil(t, h)
|
require.NotNil(t, rc)
|
||||||
|
|
||||||
assert.Equal(t, h.WHOISInfo, whois)
|
assert.Equal(t, rc.WHOISInfo, whois)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||||
|
ip := net.IP{1, 1, 1, 2}
|
||||||
|
|
||||||
ok, err := clients.Add(&Client{
|
ok, err := clients.Add(&Client{
|
||||||
IDs: []string{"1.1.1.2"},
|
IDs: []string{"1.1.1.2"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
|
@ -215,8 +239,10 @@ func TestClientsWHOIS(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
clients.SetWHOISInfo("1.1.1.2", whois)
|
clients.SetWHOISInfo(ip, whois)
|
||||||
require.Nil(t, clients.ipToRC["1.1.1.2"])
|
v, _ := clients.ipToRC.Get(ip)
|
||||||
|
require.Nil(t, v)
|
||||||
|
|
||||||
assert.True(t, clients.Del("client1"))
|
assert.True(t, clients.Del("client1"))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -228,16 +254,18 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
clients.Init(nil, nil, nil)
|
clients.Init(nil, nil, nil)
|
||||||
|
|
||||||
t.Run("simple", func(t *testing.T) {
|
t.Run("simple", func(t *testing.T) {
|
||||||
|
ip := net.IP{1, 1, 1, 1}
|
||||||
|
|
||||||
// Add a client.
|
// Add a client.
|
||||||
ok, err := clients.Add(&Client{
|
ok, err := clients.Add(&Client{
|
||||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
IDs: []string{ip.String(), "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Now add an auto-client with the same IP.
|
// Now add an auto-client with the same IP.
|
||||||
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
ok, err = clients.AddHost(ip, "test", ClientSourceRDNS)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
})
|
})
|
||||||
|
@ -245,7 +273,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
t.Run("complicated", func(t *testing.T) {
|
t.Run("complicated", func(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
testIP := net.IP{1, 2, 3, 4}
|
ip := net.IP{1, 2, 3, 4}
|
||||||
|
|
||||||
// First, init a DHCP server with a single static lease.
|
// First, init a DHCP server with a single static lease.
|
||||||
config := dhcpd.ServerConfig{
|
config := dhcpd.ServerConfig{
|
||||||
|
@ -267,7 +295,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
|
|
||||||
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
|
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
|
||||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||||
IP: testIP,
|
IP: ip,
|
||||||
Hostname: "testhost",
|
Hostname: "testhost",
|
||||||
Expiry: time.Now().Add(time.Hour),
|
Expiry: time.Now().Add(time.Hour),
|
||||||
})
|
})
|
||||||
|
@ -275,7 +303,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
|
|
||||||
// Add a new client with the same IP as for a client with MAC.
|
// Add a new client with the same IP as for a client with MAC.
|
||||||
ok, err := clients.Add(&Client{
|
ok, err := clients.Add(&Client{
|
||||||
IDs: []string{testIP.String()},
|
IDs: []string{ip.String()},
|
||||||
Name: "client2",
|
Name: "client2",
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
|
@ -5,6 +5,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// clientJSON is a common structure used by several handlers to deal with
|
// clientJSON is a common structure used by several handlers to deal with
|
||||||
|
@ -44,13 +46,13 @@ type clientJSON struct {
|
||||||
type runtimeClientJSON struct {
|
type runtimeClientJSON struct {
|
||||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||||
|
|
||||||
IP string `json:"ip"`
|
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Source string `json:"source"`
|
Source string `json:"source"`
|
||||||
|
IP net.IP `json:"ip"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientListJSON struct {
|
type clientListJSON struct {
|
||||||
Clients []clientJSON `json:"clients"`
|
Clients []*clientJSON `json:"clients"`
|
||||||
RuntimeClients []runtimeClientJSON `json:"auto_clients"`
|
RuntimeClients []runtimeClientJSON `json:"auto_clients"`
|
||||||
Tags []string `json:"supported_tags"`
|
Tags []string `json:"supported_tags"`
|
||||||
}
|
}
|
||||||
|
@ -66,11 +68,20 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
|
||||||
cj := clientToJSON(c)
|
cj := clientToJSON(c)
|
||||||
data.Clients = append(data.Clients, cj)
|
data.Clients = append(data.Clients, cj)
|
||||||
}
|
}
|
||||||
for ip, rc := range clients.ipToRC {
|
|
||||||
|
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||||
|
rc, ok := v.(*RuntimeClient)
|
||||||
|
if !ok {
|
||||||
|
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
cj := runtimeClientJSON{
|
cj := runtimeClientJSON{
|
||||||
IP: ip,
|
|
||||||
Name: rc.Host,
|
|
||||||
WHOISInfo: rc.WHOISInfo,
|
WHOISInfo: rc.WHOISInfo,
|
||||||
|
|
||||||
|
Name: rc.Host,
|
||||||
|
IP: ip,
|
||||||
}
|
}
|
||||||
|
|
||||||
cj.Source = "etc/hosts"
|
cj.Source = "etc/hosts"
|
||||||
|
@ -86,7 +97,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
|
||||||
}
|
}
|
||||||
|
|
||||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||||
}
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
data.Tags = clientTags
|
data.Tags = clientTags
|
||||||
|
|
||||||
|
@ -118,8 +131,8 @@ func jsonToClient(cj clientJSON) (c *Client) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert Client object to JSON
|
// Convert Client object to JSON
|
||||||
func clientToJSON(c *Client) clientJSON {
|
func clientToJSON(c *Client) (cj *clientJSON) {
|
||||||
cj := clientJSON{
|
return &clientJSON{
|
||||||
Name: c.Name,
|
Name: c.Name,
|
||||||
IDs: c.IDs,
|
IDs: c.IDs,
|
||||||
Tags: c.Tags,
|
Tags: c.Tags,
|
||||||
|
@ -134,19 +147,6 @@ func clientToJSON(c *Client) clientJSON {
|
||||||
|
|
||||||
Upstreams: c.Upstreams,
|
Upstreams: c.Upstreams,
|
||||||
}
|
}
|
||||||
|
|
||||||
return cj
|
|
||||||
}
|
|
||||||
|
|
||||||
// runtimeClientToJSON converts a RuntimeClient into a JSON struct.
|
|
||||||
func runtimeClientToJSON(ip string, rc RuntimeClient) (cj clientJSON) {
|
|
||||||
cj = clientJSON{
|
|
||||||
Name: rc.Host,
|
|
||||||
IDs: []string{ip},
|
|
||||||
WHOISInfo: rc.WHOISInfo,
|
|
||||||
}
|
|
||||||
|
|
||||||
return cj
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add a new client
|
// Add a new client
|
||||||
|
@ -230,7 +230,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||||
// Get the list of clients by IP address list
|
// Get the list of clients by IP address list
|
||||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
data := []map[string]clientJSON{}
|
data := []map[string]*clientJSON{}
|
||||||
for i := 0; i < len(q); i++ {
|
for i := 0; i < len(q); i++ {
|
||||||
idStr := q.Get(fmt.Sprintf("ip%d", i))
|
idStr := q.Get(fmt.Sprintf("ip%d", i))
|
||||||
if idStr == "" {
|
if idStr == "" {
|
||||||
|
@ -239,20 +239,16 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||||
|
|
||||||
ip := net.ParseIP(idStr)
|
ip := net.ParseIP(idStr)
|
||||||
c, ok := clients.Find(idStr)
|
c, ok := clients.Find(idStr)
|
||||||
var cj clientJSON
|
var cj *clientJSON
|
||||||
if !ok {
|
if !ok {
|
||||||
var found bool
|
cj = clients.findRuntime(ip, idStr)
|
||||||
cj, found = clients.findRuntime(ip, idStr)
|
|
||||||
if !found {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
cj = clientToJSON(c)
|
cj = clientToJSON(c)
|
||||||
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
|
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||||
}
|
}
|
||||||
|
|
||||||
data = append(data, map[string]clientJSON{
|
data = append(data, map[string]*clientJSON{
|
||||||
idStr: cj,
|
idStr: cj,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -265,39 +261,37 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||||
}
|
}
|
||||||
|
|
||||||
// findRuntime looks up the IP in runtime and temporary storages, like
|
// findRuntime looks up the IP in runtime and temporary storages, like
|
||||||
// /etc/hosts tables, DHCP leases, or blocklists.
|
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
||||||
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj clientJSON, found bool) {
|
// non-nil.
|
||||||
if ip == nil {
|
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
|
||||||
return cj, false
|
rc, ok := clients.FindRuntimeClient(ip)
|
||||||
}
|
|
||||||
|
|
||||||
rc, ok := clients.FindRuntimeClient(idStr)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
// It is still possible that the IP used to be in the runtime
|
// It is still possible that the IP used to be in the runtime
|
||||||
// clients list, but then the server was reloaded. So, check
|
// clients list, but then the server was reloaded. So, check
|
||||||
// the DNS server's blocked IP list.
|
// the DNS server's blocked IP list.
|
||||||
//
|
//
|
||||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||||
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
|
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||||
if rule == "" {
|
cj = &clientJSON{
|
||||||
return clientJSON{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
cj = clientJSON{
|
|
||||||
IDs: []string{idStr},
|
IDs: []string{idStr},
|
||||||
Disallowed: &disallowed,
|
Disallowed: &disallowed,
|
||||||
DisallowedRule: &rule,
|
DisallowedRule: &rule,
|
||||||
WHOISInfo: &RuntimeClientWHOISInfo{},
|
WHOISInfo: &RuntimeClientWHOISInfo{},
|
||||||
}
|
}
|
||||||
|
|
||||||
return cj, true
|
return cj
|
||||||
}
|
}
|
||||||
|
|
||||||
cj = runtimeClientToJSON(idStr, rc)
|
cj = &clientJSON{
|
||||||
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
|
Name: rc.Host,
|
||||||
|
IDs: []string{idStr},
|
||||||
|
WHOISInfo: rc.WHOISInfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||||
|
|
||||||
return cj, true
|
return cj
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterClientsHandlers registers HTTP handlers
|
// RegisterClientsHandlers registers HTTP handlers
|
||||||
|
|
|
@ -105,8 +105,8 @@ func isRunning() bool {
|
||||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||||
}
|
}
|
||||||
|
|
||||||
func onDNSRequest(d *proxy.DNSContext) {
|
func onDNSRequest(pctx *proxy.DNSContext) {
|
||||||
ip := aghnet.IPFromAddr(d.Addr)
|
ip := aghnet.IPFromAddr(pctx.Addr)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
// This would be quite weird if we get here.
|
// This would be quite weird if we get here.
|
||||||
return
|
return
|
||||||
|
|
|
@ -503,7 +503,7 @@ Please note, that this is crucial for a server to be able to use privileged port
|
||||||
You have two options:
|
You have two options:
|
||||||
1. Run AdGuard Home with root privileges
|
1. Run AdGuard Home with root privileges
|
||||||
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
|
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
|
||||||
https://github.com/AdguardTeam/AdGuardHome/internal/wiki/Getting-Started#running-without-superuser`
|
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`
|
||||||
|
|
||||||
log.Fatal(msg)
|
log.Fatal(msg)
|
||||||
}
|
}
|
||||||
|
|
|
@ -102,12 +102,7 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) {
|
||||||
func (r *RDNS) Begin(ip net.IP) {
|
func (r *RDNS) Begin(ip net.IP) {
|
||||||
r.ensurePrivateCache()
|
r.ensurePrivateCache()
|
||||||
|
|
||||||
if r.isCached(ip) {
|
if r.isCached(ip) || r.clients.Exists(ip, ClientSourceRDNS) {
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
id := ip.String()
|
|
||||||
if r.clients.Exists(id, ClientSourceRDNS) {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,6 +133,6 @@ func (r *RDNS) workerLoop() {
|
||||||
|
|
||||||
// Don't handle any errors since AddHost doesn't return non-nil
|
// Don't handle any errors since AddHost doesn't return non-nil
|
||||||
// errors for now.
|
// errors for now.
|
||||||
_, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS)
|
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
|
@ -84,7 +85,7 @@ func TestRDNS_Begin(t *testing.T) {
|
||||||
clients: &clientsContainer{
|
clients: &clientsContainer{
|
||||||
list: map[string]*Client{},
|
list: map[string]*Client{},
|
||||||
idIndex: tc.cliIDIndex,
|
idIndex: tc.cliIDIndex,
|
||||||
ipToRC: map[string]*RuntimeClient{},
|
ipToRC: aghnet.NewIPMap(0),
|
||||||
allTags: aghstrings.NewSet(),
|
allTags: aghstrings.NewSet(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -204,7 +205,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||||
cc := &clientsContainer{
|
cc := &clientsContainer{
|
||||||
list: map[string]*Client{},
|
list: map[string]*Client{},
|
||||||
idIndex: map[string]*Client{},
|
idIndex: map[string]*Client{},
|
||||||
ipToRC: map[string]*RuntimeClient{},
|
ipToRC: aghnet.NewIPMap(0),
|
||||||
allTags: aghstrings.NewSet(),
|
allTags: aghstrings.NewSet(),
|
||||||
}
|
}
|
||||||
ch := make(chan net.IP)
|
ch := make(chan net.IP)
|
||||||
|
@ -236,7 +237,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.True(t, cc.Exists(tc.cliIP.String(), ClientSourceRDNS))
|
assert.True(t, cc.Exists(tc.cliIP, ClientSourceRDNS))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -252,7 +252,6 @@ func (w *WHOIS) workerLoop() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
id := ip.String()
|
w.clients.SetWHOISInfo(ip, info)
|
||||||
w.clients.SetWHOISInfo(id, info)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -720,7 +720,10 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
|
||||||
a := convertMapToSlice(m, int(maxCount))
|
a := convertMapToSlice(m, int(maxCount))
|
||||||
d := []net.IP{}
|
d := []net.IP{}
|
||||||
for _, it := range a {
|
for _, it := range a {
|
||||||
d = append(d, net.ParseIP(it.Name))
|
ip := net.ParseIP(it.Name)
|
||||||
|
if ip != nil {
|
||||||
|
d = append(d, ip)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,11 @@
|
||||||
|
|
||||||
## v0.107: API changes
|
## v0.107: API changes
|
||||||
|
|
||||||
|
### Client IDs in Access Settings
|
||||||
|
|
||||||
|
* The `POST /control/access/set` HTTP API now accepts client IDs in
|
||||||
|
`"allowed_clients"` and `"disallowed_clients"` fields.
|
||||||
|
|
||||||
### The new field `"unicode_name"` in `DNSQuestion`
|
### The new field `"unicode_name"` in `DNSQuestion`
|
||||||
|
|
||||||
* The new optional field `"unicode_name"` is the Unicode representation of
|
* The new optional field `"unicode_name"` is the Unicode representation of
|
||||||
|
@ -17,7 +22,7 @@
|
||||||
|
|
||||||
### Disabling Statistics
|
### Disabling Statistics
|
||||||
|
|
||||||
* The API `POST /control/stats_config` HTTP API allows disabling statistics by
|
* The `POST /control/stats_config` HTTP API allows disabling statistics by
|
||||||
setting `"interval"` to `0`.
|
setting `"interval"` to `0`.
|
||||||
|
|
||||||
### `POST /control/dhcp/reset_leases`
|
### `POST /control/dhcp/reset_leases`
|
||||||
|
|
|
@ -1957,10 +1957,7 @@
|
||||||
'disallowed_rule':
|
'disallowed_rule':
|
||||||
'type': 'string'
|
'type': 'string'
|
||||||
'description': >
|
'description': >
|
||||||
The rule due to which the client is disallowed. If disallowed is
|
The rule due to which the client is allowed or blocked.
|
||||||
set to true, and this string is empty, then the client IP is
|
|
||||||
disallowed by the "allowed IP list", that is it is not included in
|
|
||||||
the allowed list.
|
|
||||||
'name':
|
'name':
|
||||||
'description': >
|
'description': >
|
||||||
Persistent client's name or an empty string if this is a runtime
|
Persistent client's name or an empty string if this is a runtime
|
||||||
|
@ -2352,17 +2349,19 @@
|
||||||
'description': 'Client and host access list'
|
'description': 'Client and host access list'
|
||||||
'properties':
|
'properties':
|
||||||
'allowed_clients':
|
'allowed_clients':
|
||||||
'description': 'Allowlist of clients.'
|
'description': >
|
||||||
|
The allowlist of clients: IP addresses, CIDRs, or client IDs.
|
||||||
'items':
|
'items':
|
||||||
'type': 'string'
|
'type': 'string'
|
||||||
'type': 'array'
|
'type': 'array'
|
||||||
'disallowed_clients':
|
'disallowed_clients':
|
||||||
'description': 'Blocklist of clients.'
|
'description': >
|
||||||
|
The blocklist of clients: IP addresses, CIDRs, or client IDs.
|
||||||
'items':
|
'items':
|
||||||
'type': 'string'
|
'type': 'string'
|
||||||
'type': 'array'
|
'type': 'array'
|
||||||
'blocked_hosts':
|
'blocked_hosts':
|
||||||
'description': 'Blocklist of hosts.'
|
'description': 'The blocklist of hosts.'
|
||||||
'items':
|
'items':
|
||||||
'type': 'string'
|
'type': 'string'
|
||||||
'type': 'array'
|
'type': 'array'
|
||||||
|
|
Loading…
Reference in New Issue