Merge branch 'master' into 3745-cleanup-lists

This commit is contained in:
Eugene Burkov 2021-10-21 14:30:18 +03:00
commit 8388b3d5bc
23 changed files with 549 additions and 184 deletions

View File

@ -46,6 +46,10 @@ and this project adheres to
### Changed ### Changed
- `$dnsrewrite` rules and other DNS rewrites will now be applied even when the
protection is disabled ([#1558]).
- DHCP gateway address, subnet mask, IP address range, and leases validations
([#3529]).
- The `systemd` service script will now create the `/var/log` directory when it - The `systemd` service script will now create the `/var/log` directory when it
doesn't exist ([#3579]). doesn't exist ([#3579]).
- Items in allowed clients, disallowed clients, and blocked hosts lists are now - Items in allowed clients, disallowed clients, and blocked hosts lists are now
@ -114,6 +118,7 @@ In this release, the schema version has changed from 10 to 12.
### Fixed ### Fixed
- Incorrect assignment of explicitly configured DHCP options ([#3744]).
- Occasional panic during shutdown ([#3655]). - Occasional panic during shutdown ([#3655]).
- Addition of IPs into only one as opposed to all matching ipsets on Linux - Addition of IPs into only one as opposed to all matching ipsets on Linux
([#3638]). ([#3638]).
@ -152,6 +157,7 @@ In this release, the schema version has changed from 10 to 12.
- Go 1.15 support. - Go 1.15 support.
[#1381]: https://github.com/AdguardTeam/AdGuardHome/issues/1381 [#1381]: https://github.com/AdguardTeam/AdGuardHome/issues/1381
[#1558]: https://github.com/AdguardTeam/AdGuardHome/issues/1558
[#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691 [#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691
[#1898]: https://github.com/AdguardTeam/AdGuardHome/issues/1898 [#1898]: https://github.com/AdguardTeam/AdGuardHome/issues/1898
[#1992]: https://github.com/AdguardTeam/AdGuardHome/issues/1992 [#1992]: https://github.com/AdguardTeam/AdGuardHome/issues/1992
@ -195,6 +201,7 @@ In this release, the schema version has changed from 10 to 12.
[#3450]: https://github.com/AdguardTeam/AdGuardHome/issues/3450 [#3450]: https://github.com/AdguardTeam/AdGuardHome/issues/3450
[#3457]: https://github.com/AdguardTeam/AdGuardHome/issues/3457 [#3457]: https://github.com/AdguardTeam/AdGuardHome/issues/3457
[#3506]: https://github.com/AdguardTeam/AdGuardHome/issues/3506 [#3506]: https://github.com/AdguardTeam/AdGuardHome/issues/3506
[#3529]: https://github.com/AdguardTeam/AdGuardHome/issues/3529
[#3538]: https://github.com/AdguardTeam/AdGuardHome/issues/3538 [#3538]: https://github.com/AdguardTeam/AdGuardHome/issues/3538
[#3551]: https://github.com/AdguardTeam/AdGuardHome/issues/3551 [#3551]: https://github.com/AdguardTeam/AdGuardHome/issues/3551
[#3564]: https://github.com/AdguardTeam/AdGuardHome/issues/3564 [#3564]: https://github.com/AdguardTeam/AdGuardHome/issues/3564
@ -204,6 +211,7 @@ In this release, the schema version has changed from 10 to 12.
[#3607]: https://github.com/AdguardTeam/AdGuardHome/issues/3607 [#3607]: https://github.com/AdguardTeam/AdGuardHome/issues/3607
[#3638]: https://github.com/AdguardTeam/AdGuardHome/issues/3638 [#3638]: https://github.com/AdguardTeam/AdGuardHome/issues/3638
[#3655]: https://github.com/AdguardTeam/AdGuardHome/issues/3655 [#3655]: https://github.com/AdguardTeam/AdGuardHome/issues/3655
[#3744]: https://github.com/AdguardTeam/AdGuardHome/issues/3744

View File

@ -37,6 +37,9 @@
"dhcp_ipv6_settings": "DHCP IPv6 Settings", "dhcp_ipv6_settings": "DHCP IPv6 Settings",
"form_error_required": "Required field", "form_error_required": "Required field",
"form_error_ip4_format": "Invalid IPv4 format", "form_error_ip4_format": "Invalid IPv4 format",
"form_error_ip4_range_start_format": "Invalid range start IPv4 format",
"form_error_ip4_range_end_format": "Invalid range end IPv4 format",
"form_error_ip4_gateway_format": "Invalid gateway IPv4 format",
"form_error_ip6_format": "Invalid IPv6 format", "form_error_ip6_format": "Invalid IPv6 format",
"form_error_ip_format": "Invalid IP format", "form_error_ip_format": "Invalid IP format",
"form_error_mac_format": "Invalid MAC format", "form_error_mac_format": "Invalid MAC format",
@ -45,7 +48,14 @@
"form_error_subnet": "Subnet \"{{cidr}}\" does not contain the IP address \"{{ip}}\"", "form_error_subnet": "Subnet \"{{cidr}}\" does not contain the IP address \"{{ip}}\"",
"form_error_positive": "Must be greater than 0", "form_error_positive": "Must be greater than 0",
"form_error_negative": "Must be equal to 0 or greater", "form_error_negative": "Must be equal to 0 or greater",
"range_end_error": "Must be greater than range start", "out_of_range_error": "Must be out of range \"{{start}}\"-\"{{end}}\"",
"in_range_error": "Must be in range \"{{start}}\"-\"{{end}}\"",
"lower_range_start_error": "Must be lower than range start",
"lower_range_end_error": "Must be lower than range end",
"greater_range_start_error": "Must be greater than range start",
"greater_range_end_error": "Must be greater than range end",
"subnet_error": "Addresses must be in one subnet",
"gateway_or_subnet_invalid": "Subnet mask invalid",
"dhcp_form_gateway_input": "Gateway IP", "dhcp_form_gateway_input": "Gateway IP",
"dhcp_form_subnet_input": "Subnet mask", "dhcp_form_subnet_input": "Subnet mask",
"dhcp_form_range_title": "Range of IP addresses", "dhcp_form_range_title": "Range of IP addresses",

View File

@ -13,6 +13,9 @@ import {
validateIpv4, validateIpv4,
validateRequiredValue, validateRequiredValue,
validateIpv4RangeEnd, validateIpv4RangeEnd,
validateGatewaySubnetMask,
validateIpForGatewaySubnetMask,
validateNotInRange,
} from '../../../helpers/validators'; } from '../../../helpers/validators';
const FormDHCPv4 = ({ const FormDHCPv4 = ({
@ -54,7 +57,11 @@ const FormDHCPv4 = ({
type="text" type="text"
className="form-control" className="form-control"
placeholder={t(ipv4placeholders.gateway_ip)} placeholder={t(ipv4placeholders.gateway_ip)}
validate={[validateIpv4, validateRequired]} validate={[
validateIpv4,
validateRequired,
validateNotInRange,
]}
disabled={!isInterfaceIncludesIpv4} disabled={!isInterfaceIncludesIpv4}
/> />
</div> </div>
@ -66,7 +73,11 @@ const FormDHCPv4 = ({
type="text" type="text"
className="form-control" className="form-control"
placeholder={t(ipv4placeholders.subnet_mask)} placeholder={t(ipv4placeholders.subnet_mask)}
validate={[validateIpv4, validateRequired]} validate={[
validateIpv4,
validateRequired,
validateGatewaySubnetMask,
]}
disabled={!isInterfaceIncludesIpv4} disabled={!isInterfaceIncludesIpv4}
/> />
</div> </div>
@ -84,7 +95,11 @@ const FormDHCPv4 = ({
type="text" type="text"
className="form-control" className="form-control"
placeholder={t(ipv4placeholders.range_start)} placeholder={t(ipv4placeholders.range_start)}
validate={[validateIpv4]} validate={[
validateIpv4,
validateGatewaySubnetMask,
validateIpForGatewaySubnetMask,
]}
disabled={!isInterfaceIncludesIpv4} disabled={!isInterfaceIncludesIpv4}
/> />
</div> </div>
@ -95,7 +110,12 @@ const FormDHCPv4 = ({
type="text" type="text"
className="form-control" className="form-control"
placeholder={t(ipv4placeholders.range_end)} placeholder={t(ipv4placeholders.range_end)}
validate={[validateIpv4, validateIpv4RangeEnd]} validate={[
validateIpv4,
validateIpv4RangeEnd,
validateGatewaySubnetMask,
validateIpForGatewaySubnetMask,
]}
disabled={!isInterfaceIncludesIpv4} disabled={!isInterfaceIncludesIpv4}
/> />
</div> </div>

View File

@ -10,6 +10,7 @@ import {
validateMac, validateMac,
validateRequiredValue, validateRequiredValue,
validateIpv4InCidr, validateIpv4InCidr,
validateInRange,
} from '../../../../helpers/validators'; } from '../../../../helpers/validators';
import { FORM_NAME } from '../../../../helpers/constants'; import { FORM_NAME } from '../../../../helpers/constants';
import { toggleLeaseModal } from '../../../../actions'; import { toggleLeaseModal } from '../../../../actions';
@ -53,7 +54,12 @@ const Form = ({
type="text" type="text"
className="form-control" className="form-control"
placeholder={t('form_enter_subnet_ip', { cidr })} placeholder={t('form_enter_subnet_ip', { cidr })}
validate={[validateRequiredValue, validateIpv4, validateIpv4InCidr]} validate={[
validateRequiredValue,
validateIpv4,
validateIpv4InCidr,
validateInRange,
]}
/> />
</div> </div>
<div className="form__group"> <div className="form__group">

View File

@ -11,6 +11,8 @@ const Modal = ({
handleSubmit, handleSubmit,
processingAdding, processingAdding,
cidr, cidr,
rangeStart,
rangeEnd,
}) => { }) => {
const dispatch = useDispatch(); const dispatch = useDispatch();
@ -38,10 +40,14 @@ const Modal = ({
ip: '', ip: '',
hostname: '', hostname: '',
cidr, cidr,
rangeStart,
rangeEnd,
}} }}
onSubmit={handleSubmit} onSubmit={handleSubmit}
processingAdding={processingAdding} processingAdding={processingAdding}
cidr={cidr} cidr={cidr}
rangeStart={rangeStart}
rangeEnd={rangeEnd}
/> />
</div> </div>
</ReactModal> </ReactModal>
@ -53,6 +59,8 @@ Modal.propTypes = {
handleSubmit: PropTypes.func.isRequired, handleSubmit: PropTypes.func.isRequired,
processingAdding: PropTypes.bool.isRequired, processingAdding: PropTypes.bool.isRequired,
cidr: PropTypes.string.isRequired, cidr: PropTypes.string.isRequired,
rangeStart: PropTypes.string,
rangeEnd: PropTypes.string,
}; };
export default withTranslation()(Modal); export default withTranslation()(Modal);

View File

@ -22,6 +22,8 @@ const StaticLeases = ({
processingDeleting, processingDeleting,
staticLeases, staticLeases,
cidr, cidr,
rangeStart,
rangeEnd,
}) => { }) => {
const [t] = useTranslation(); const [t] = useTranslation();
const dispatch = useDispatch(); const dispatch = useDispatch();
@ -100,6 +102,8 @@ const StaticLeases = ({
handleSubmit={handleSubmit} handleSubmit={handleSubmit}
processingAdding={processingAdding} processingAdding={processingAdding}
cidr={cidr} cidr={cidr}
rangeStart={rangeStart}
rangeEnd={rangeEnd}
/> />
</> </>
); );
@ -111,6 +115,8 @@ StaticLeases.propTypes = {
processingAdding: PropTypes.bool.isRequired, processingAdding: PropTypes.bool.isRequired,
processingDeleting: PropTypes.bool.isRequired, processingDeleting: PropTypes.bool.isRequired,
cidr: PropTypes.string.isRequired, cidr: PropTypes.string.isRequired,
rangeStart: PropTypes.string,
rangeEnd: PropTypes.string,
}; };
cellWrap.propTypes = { cellWrap.propTypes = {

View File

@ -275,6 +275,8 @@ const Dhcp = () => {
processingAdding={processingAdding} processingAdding={processingAdding}
processingDeleting={processingDeleting} processingDeleting={processingDeleting}
cidr={cidr} cidr={cidr}
rangeStart={dhcp?.values?.v4?.range_start}
rangeEnd={dhcp?.values?.v4?.range_end}
/> />
<div className="btn-list mt-2"> <div className="btn-list mt-2">
<button <button

View File

@ -552,6 +552,20 @@ export const isIpInCidr = (ip, cidr) => {
} }
}; };
/**
*
* @param {string} subnetMask
* @returns {IPv4 | null}
*/
export const parseSubnetMask = (subnetMask) => {
try {
return ipaddr.parse(subnetMask).prefixLengthFromSubnetMask();
} catch (e) {
console.error(e);
return null;
}
};
/** /**
* *
* @param {string} subnetMask * @param {string} subnetMask

View File

@ -1,4 +1,5 @@
import i18next from 'i18next'; import i18next from 'i18next';
import { import {
MAX_PORT, MAX_PORT,
R_CIDR, R_CIDR,
@ -14,7 +15,7 @@ import {
R_DOMAIN, R_DOMAIN,
} from './constants'; } from './constants';
import { ip4ToInt, isValidAbsolutePath } from './form'; import { ip4ToInt, isValidAbsolutePath } from './form';
import { isIpInCidr } from './helpers'; import { isIpInCidr, parseSubnetMask } from './helpers';
// Validation functions // Validation functions
// https://redux-form.com/8.3.0/examples/fieldlevelvalidation/ // https://redux-form.com/8.3.0/examples/fieldlevelvalidation/
@ -44,7 +45,7 @@ export const validateIpv4RangeEnd = (_, allValues) => {
const { range_end, range_start } = allValues.v4; const { range_end, range_start } = allValues.v4;
if (ip4ToInt(range_end) <= ip4ToInt(range_start)) { if (ip4ToInt(range_end) <= ip4ToInt(range_start)) {
return 'range_end_error'; return 'greater_range_start_error';
} }
return undefined; return undefined;
@ -61,6 +62,114 @@ export const validateIpv4 = (value) => {
return undefined; return undefined;
}; };
/**
* @returns {undefined|string}
* @param _
* @param allValues
*/
export const validateNotInRange = (value, allValues) => {
const { range_start, range_end } = allValues.v4;
if (range_start && validateIpv4(range_start)) {
return 'form_error_ip4_range_start_format';
}
if (range_end && validateIpv4(range_end)) {
return 'form_error_ip4_range_end_format';
}
const isAboveMin = range_start && ip4ToInt(value) >= ip4ToInt(range_start);
const isBelowMax = range_end && ip4ToInt(value) <= ip4ToInt(range_end);
if (isAboveMin && isBelowMax) {
return i18next.t('out_of_range_error', {
start: range_start,
end: range_end,
});
}
if (!range_end && isAboveMin) {
return 'lower_range_start_error';
}
if (!range_start && isBelowMax) {
return 'greater_range_end_error';
}
return undefined;
};
/**
* @returns {undefined|string}
* @param _
* @param allValues
*/
export const validateInRange = (value, allValues) => {
const { rangeStart, rangeEnd } = allValues;
if (rangeStart && validateIpv4(rangeStart)) {
return 'form_error_ip4_range_start_format';
}
if (rangeEnd && validateIpv4(rangeEnd)) {
return 'form_error_ip4_range_end_format';
}
const isBelowMin = rangeStart && ip4ToInt(value) < ip4ToInt(rangeStart);
const isAboveMax = rangeEnd && ip4ToInt(value) > ip4ToInt(rangeEnd);
if (isAboveMax || isBelowMin) {
return i18next.t('in_range_error', {
start: rangeStart,
end: rangeEnd,
});
}
return undefined;
};
/**
* @returns {undefined|string}
* @param _
* @param allValues
*/
export const validateGatewaySubnetMask = (_, allValues) => {
if (!allValues || !allValues.v4 || !allValues.v4.subnet_mask || !allValues.v4.gateway_ip) {
return 'gateway_or_subnet_invalid';
}
const { subnet_mask, gateway_ip } = allValues.v4;
if (validateIpv4(gateway_ip)) {
return 'form_error_ip4_gateway_format';
}
return parseSubnetMask(subnet_mask) ? undefined : 'gateway_or_subnet_invalid';
};
/**
* @returns {undefined|string}
* @param value
* @param allValues
*/
export const validateIpForGatewaySubnetMask = (value, allValues) => {
if (!allValues || !allValues.v4 || !value) {
return undefined;
}
const {
gateway_ip, subnet_mask,
} = allValues.v4;
const subnetPrefix = parseSubnetMask(subnet_mask);
if (!isIpInCidr(value, `${gateway_ip}/${subnetPrefix}`)) {
return 'subnet_error';
}
return undefined;
};
/** /**
* @param value {string} * @param value {string}
* @returns {undefined|string} * @returns {undefined|string}

2
go.mod
View File

@ -4,7 +4,7 @@ go 1.16
require ( require (
github.com/AdguardTeam/dnsproxy v0.39.8 github.com/AdguardTeam/dnsproxy v0.39.8
github.com/AdguardTeam/golibs v0.10.0 github.com/AdguardTeam/golibs v0.10.2
github.com/AdguardTeam/urlfilter v0.14.6 github.com/AdguardTeam/urlfilter v0.14.6
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.2 github.com/ameshkov/dnscrypt/v2 v2.2.2

4
go.sum
View File

@ -14,8 +14,8 @@ github.com/AdguardTeam/dnsproxy v0.39.8/go.mod h1:eDpJKAdkHORRwAedjuERv+7SWlcz4c
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.9.2/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY= github.com/AdguardTeam/golibs v0.9.2/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY=
github.com/AdguardTeam/golibs v0.10.0 h1:A7MXRfZ+ItpOyS9tWKtqrLj3vZtE9FJFC+dOVY/LcWs= github.com/AdguardTeam/golibs v0.10.2 h1:TAwnS4Y49sSUa4UX1yz/MWNGbIlXHqafrWr9MxdIh9A=
github.com/AdguardTeam/golibs v0.10.0/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/golibs v0.10.2/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.14.6 h1:emqoKZElooHACYehRBYENeKVN1a/rspxiqTIMYLuoIo= github.com/AdguardTeam/urlfilter v0.14.6 h1:emqoKZElooHACYehRBYENeKVN1a/rspxiqTIMYLuoIo=
github.com/AdguardTeam/urlfilter v0.14.6/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U= github.com/AdguardTeam/urlfilter v0.14.6/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U=

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -138,6 +139,49 @@ func TestNormalizeLeases(t *testing.T) {
assert.Equal(t, leases[2].HWAddr, dynLeases[1].HWAddr) assert.Equal(t, leases[2].HWAddr, dynLeases[1].HWAddr)
} }
func TestV4Server_badRange(t *testing.T) {
testCases := []struct {
name string
gatewayIP net.IP
subnetMask net.IP
wantErrMsg string
}{{
name: "gateway_in_range",
gatewayIP: net.IP{192, 168, 10, 120},
subnetMask: net.IP{255, 255, 255, 0},
wantErrMsg: "dhcpv4: gateway ip 192.168.10.120 in the ip range: " +
"192.168.10.20-192.168.10.200",
}, {
name: "outside_range_start",
gatewayIP: net.IP{192, 168, 10, 1},
subnetMask: net.IP{255, 255, 255, 240},
wantErrMsg: "dhcpv4: range start 192.168.10.20 is outside network " +
"192.168.10.1/28",
}, {
name: "outside_range_end",
gatewayIP: net.IP{192, 168, 10, 1},
subnetMask: net.IP{255, 255, 255, 224},
wantErrMsg: "dhcpv4: range end 192.168.10.200 is outside network " +
"192.168.10.1/27",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conf := V4ServerConf{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 20},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: tc.gatewayIP,
SubnetMask: tc.subnetMask,
notify: testNotify,
}
_, err := v4Create(conf)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}
}
// cloneUDPAddr returns a deep copy of a. // cloneUDPAddr returns a deep copy of a.
func cloneUDPAddr(a *net.UDPAddr) (clone *net.UDPAddr) { func cloneUDPAddr(a *net.UDPAddr) (clone *net.UDPAddr) {
return &net.UDPAddr{ return &net.UDPAddr{

View File

@ -293,6 +293,8 @@ func (s *v4Server) addLease(l *Lease) (err error) {
offset, inOffset := r.offset(l.IP) offset, inOffset := r.offset(l.IP)
if l.IsStatic() { if l.IsStatic() {
// TODO(a.garipov, d.seregin): Subnet can be nil when dhcp server is
// disabled.
if sn := s.conf.subnet; !sn.Contains(l.IP) { if sn := s.conf.subnet; !sn.Contains(l.IP) {
return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP) return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP)
} }
@ -900,9 +902,10 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
resp.UpdateOption(dhcpv4.OptGeneric(code, configured.Get(code))) resp.UpdateOption(dhcpv4.OptGeneric(code, configured.Get(code)))
} }
} }
// Update the value of Domain Name Server option separately from others // Update the value of Domain Name Server option separately from others if
// since its value is set after server's creating. // not assigned yet since its value is set after server's creating.
if requested.Has(dhcpv4.OptionDomainNameServer) { if requested.Has(dhcpv4.OptionDomainNameServer) &&
!resp.Options.Has(dhcpv4.OptionDomainNameServer) {
resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...)) resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...))
} }
@ -1124,6 +1127,29 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
return s, fmt.Errorf("dhcpv4: %w", err) return s, fmt.Errorf("dhcpv4: %w", err)
} }
if s.conf.ipRange.contains(routerIP) {
return s, fmt.Errorf("dhcpv4: gateway ip %v in the ip range: %v-%v",
routerIP,
conf.RangeStart,
conf.RangeEnd,
)
}
if !s.conf.subnet.Contains(conf.RangeStart) {
return s, fmt.Errorf("dhcpv4: range start %v is outside network %v",
conf.RangeStart,
s.conf.subnet,
)
}
if !s.conf.subnet.Contains(conf.RangeEnd) {
return s, fmt.Errorf("dhcpv4: range end %v is outside network %v",
conf.RangeEnd,
s.conf.subnet,
)
}
// TODO(a.garipov, d.seregin): Check that every lease is inside the IPRange.
s.leasedOffsets = newBitSet() s.leasedOffsets = newBitSet()
if conf.LeaseDuration == 0 { if conf.LeaseDuration == 0 {

View File

@ -5,8 +5,10 @@ package dhcpd
import ( import (
"net" "net"
"strings"
"testing" "testing"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
"github.com/mdlayher/raw" "github.com/mdlayher/raw"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -16,17 +18,34 @@ import (
func notify4(flags uint32) { func notify4(flags uint32) {
} }
func TestV4_AddRemove_static(t *testing.T) { // defaultV4ServerConf returns the default configuration for *v4Server to use in
s, err := v4Create(V4ServerConf{ // tests.
func defaultV4ServerConf() (conf V4ServerConf) {
return V4ServerConf{
Enabled: true, Enabled: true,
RangeStart: net.IP{192, 168, 10, 100}, RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200}, RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1}, GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0}, SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4, notify: notify4,
}) }
}
// defaultSrv prepares the default DHCPServer to use in tests. The underlying
// type of s is *v4Server.
func defaultSrv(t *testing.T) (s DHCPServer) {
t.Helper()
var err error
s, err = v4Create(defaultV4ServerConf())
require.NoError(t, err) require.NoError(t, err)
return s
}
func TestV4_AddRemove_static(t *testing.T) {
s := defaultSrv(t)
ls := s.GetLeases(LeasesStatic) ls := s.GetLeases(LeasesStatic)
assert.Empty(t, ls) assert.Empty(t, ls)
@ -37,7 +56,7 @@ func TestV4_AddRemove_static(t *testing.T) {
IP: net.IP{192, 168, 10, 150}, IP: net.IP{192, 168, 10, 150},
} }
err = s.AddStaticLease(l) err := s.AddStaticLease(l)
require.NoError(t, err) require.NoError(t, err)
err = s.AddStaticLease(l) err = s.AddStaticLease(l)
@ -65,15 +84,7 @@ func TestV4_AddRemove_static(t *testing.T) {
} }
func TestV4_AddReplace(t *testing.T) { func TestV4_AddReplace(t *testing.T) {
sIface, err := v4Create(V4ServerConf{ sIface := defaultSrv(t)
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
})
require.NoError(t, err)
s, ok := sIface.(*v4Server) s, ok := sIface.(*v4Server)
require.True(t, ok) require.True(t, ok)
@ -89,7 +100,7 @@ func TestV4_AddReplace(t *testing.T) {
}} }}
for i := range dynLeases { for i := range dynLeases {
err = s.addLease(&dynLeases[i]) err := s.addLease(&dynLeases[i])
require.NoError(t, err) require.NoError(t, err)
} }
@ -104,7 +115,7 @@ func TestV4_AddReplace(t *testing.T) {
}} }}
for _, l := range stLeases { for _, l := range stLeases {
err = s.AddStaticLease(l) err := s.AddStaticLease(l)
require.NoError(t, err) require.NoError(t, err)
} }
@ -118,18 +129,81 @@ func TestV4_AddReplace(t *testing.T) {
} }
} }
func TestV4StaticLease_Get(t *testing.T) { func TestV4Server_Process_optionsPriority(t *testing.T) {
var err error defaultIP := net.IP{192, 168, 1, 1}
sIface, err := v4Create(V4ServerConf{ knownIP := net.IP{1, 2, 3, 4}
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100}, // prepareSrv creates a *v4Server and sets the opt6IPs in the initial
RangeEnd: net.IP{192, 168, 10, 200}, // configuration of the server as the value for DHCP option 6.
GatewayIP: net.IP{192, 168, 10, 1}, prepareSrv := func(t *testing.T, opt6IPs []net.IP) (s *v4Server) {
SubnetMask: net.IP{255, 255, 255, 0}, t.Helper()
notify: notify4,
}) conf := defaultV4ServerConf()
if len(opt6IPs) > 0 {
b := &strings.Builder{}
stringutil.WriteToBuilder(b, "6 ips ", opt6IPs[0].String())
for _, ip := range opt6IPs[1:] {
stringutil.WriteToBuilder(b, ",", ip.String())
}
conf.Options = []string{b.String()}
}
ss, err := v4Create(conf)
require.NoError(t, err) require.NoError(t, err)
var ok bool
s, ok = ss.(*v4Server)
require.True(t, ok)
s.conf.dnsIPAddrs = []net.IP{defaultIP}
return s
}
// checkResp creates a discovery message with DHCP option 6 requested amd
// asserts the response to contain wantIPs in this option.
checkResp := func(t *testing.T, s *v4Server, wantIPs []net.IP) {
t.Helper()
mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}
req, err := dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions(
dhcpv4.OptionDomainNameServer,
))
require.NoError(t, err)
var resp *dhcpv4.DHCPv4
resp, err = dhcpv4.NewReplyFromRequest(req)
require.NoError(t, err)
res := s.process(req, resp)
require.Equal(t, 1, res)
o := resp.GetOneOption(dhcpv4.OptionDomainNameServer)
require.NotEmpty(t, o)
wantData := []byte{}
for _, ip := range wantIPs {
wantData = append(wantData, ip...)
}
assert.Equal(t, o, wantData)
}
t.Run("default", func(t *testing.T) {
s := prepareSrv(t, nil)
checkResp(t, s, []net.IP{defaultIP})
})
t.Run("explicitly_configured", func(t *testing.T) {
s := prepareSrv(t, []net.IP{knownIP, knownIP})
checkResp(t, s, []net.IP{knownIP, knownIP})
})
}
func TestV4StaticLease_Get(t *testing.T) {
sIface := defaultSrv(t)
s, ok := sIface.(*v4Server) s, ok := sIface.(*v4Server)
require.True(t, ok) require.True(t, ok)
@ -140,7 +214,7 @@ func TestV4StaticLease_Get(t *testing.T) {
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
IP: net.IP{192, 168, 10, 150}, IP: net.IP{192, 168, 10, 150},
} }
err = s.AddStaticLease(l) err := s.AddStaticLease(l)
require.NoError(t, err) require.NoError(t, err)
var req, resp *dhcpv4.DHCPv4 var req, resp *dhcpv4.DHCPv4
@ -208,19 +282,14 @@ func TestV4StaticLease_Get(t *testing.T) {
} }
func TestV4DynamicLease_Get(t *testing.T) { func TestV4DynamicLease_Get(t *testing.T) {
var err error conf := defaultV4ServerConf()
sIface, err := v4Create(V4ServerConf{ conf.Options = []string{
Enabled: true,
RangeStart: net.IP{192, 168, 10, 100},
RangeEnd: net.IP{192, 168, 10, 200},
GatewayIP: net.IP{192, 168, 10, 1},
SubnetMask: net.IP{255, 255, 255, 0},
notify: notify4,
Options: []string{
"81 hex 303132", "81 hex 303132",
"82 ip 1.2.3.4", "82 ip 1.2.3.4",
}, }
})
var err error
sIface, err := v4Create(conf)
require.NoError(t, err) require.NoError(t, err)
s, ok := sIface.(*v4Server) s, ok := sIface.(*v4Server)

View File

@ -90,7 +90,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
s.processRestrictLocal, s.processRestrictLocal,
s.processInternalIPAddrs, s.processInternalIPAddrs,
s.processClientID, s.processClientID,
processFilteringBeforeRequest, s.processFilteringBeforeRequest,
s.processLocalPTR, s.processLocalPTR,
s.processUpstream, s.processUpstream,
processDNSSECAfterResponse, processDNSSECAfterResponse,
@ -468,19 +468,18 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) {
} }
// Apply filtering logic // Apply filtering logic
func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
s := ctx.srv if ctx.proxyCtx.Res != nil {
d := ctx.proxyCtx // Go on since the response is already set.
return resultCodeSuccess
if d.Res != nil {
return resultCodeSuccess // response is already set - nothing to do
} }
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil ctx.protectionEnabled = s.conf.ProtectionEnabled
if !ctx.protectionEnabled {
if s.dnsFilter == nil {
return resultCodeSuccess return resultCodeSuccess
} }
@ -489,8 +488,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) {
} }
var err error var err error
ctx.result, err = s.filterDNSRequest(ctx) if ctx.result, err = s.filterDNSRequest(ctx); err != nil {
if err != nil {
ctx.err = err ctx.err = err
return resultCodeError return resultCodeError
@ -608,48 +606,50 @@ func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) {
func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) { func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) {
s := ctx.srv s := ctx.srv
d := ctx.proxyCtx d := ctx.proxyCtx
res := ctx.result
var err error
switch res.Reason { switch res := ctx.result; res.Reason {
case filtering.Rewritten, case filtering.NotFilteredAllowList:
// Go on.
case
filtering.Rewritten,
filtering.RewrittenRule: filtering.RewrittenRule:
if len(ctx.origQuestion.Name) == 0 { if len(ctx.origQuestion.Name) == 0 {
// origQuestion is set in case we get only CNAME without IP from rewrites table // origQuestion is set in case we get only CNAME without IP from
// rewrites table.
break break
} }
d.Req.Question[0] = ctx.origQuestion d.Req.Question[0], d.Res.Question[0] = ctx.origQuestion, ctx.origQuestion
d.Res.Question[0] = ctx.origQuestion if len(d.Res.Answer) > 0 {
answer := append([]dns.RR{s.genAnswerCNAME(d.Req, res.CanonName)}, d.Res.Answer...)
if len(d.Res.Answer) != 0 {
answer := []dns.RR{}
answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName))
answer = append(answer, d.Res.Answer...)
d.Res.Answer = answer d.Res.Answer = answer
} }
case filtering.NotFilteredAllowList:
// nothing
default: default:
if !ctx.protectionEnabled || // filters are disabled: there's nothing to check for // Check the response only if the it's from an upstream. Don't check
!ctx.responseFromUpstream { // only check response if it's from an upstream server // the response if the protection is disabled since dnsrewrite rules
// aren't applied to it anyway.
if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil {
break break
} }
origResp2 := d.Res
ctx.result, err = s.filterDNSResponse(ctx) origResp := d.Res
result, err := s.filterDNSResponse(ctx)
if err != nil { if err != nil {
ctx.err = err ctx.err = err
return resultCodeError return resultCodeError
} }
if ctx.result != nil {
ctx.origResp = origResp2 // matched by response if result != nil {
} else { ctx.result = result
ctx.result = &filtering.Result{} ctx.origResp = origResp
} }
} }
if ctx.result == nil {
ctx.result = &filtering.Result{}
}
return resultCodeSuccess return resultCodeSuccess
} }

View File

@ -909,6 +909,7 @@ func TestRewrite(t *testing.T) {
}}, }},
} }
f := filtering.New(c, nil) f := filtering.New(c, nil)
f.SetEnabled(true)
snd, err := aghnet.NewSubnetDetector() snd, err := aghnet.NewSubnetDetector()
require.NoError(t, err) require.NoError(t, err)
@ -945,9 +946,10 @@ func TestRewrite(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
subTestFunc := func(t *testing.T) {
req := createTestMessageWithType("test.com.", dns.TypeA) req := createTestMessageWithType("test.com.", dns.TypeA)
reply, err := dns.Exchange(req, addr.String()) reply, eerr := dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, eerr)
require.Len(t, reply.Answer, 1) require.Len(t, reply.Answer, 1)
@ -957,14 +959,14 @@ func TestRewrite(t *testing.T) {
assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A)) assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
req = createTestMessageWithType("test.com.", dns.TypeAAAA) req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String()) reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, eerr)
assert.Empty(t, reply.Answer) assert.Empty(t, reply.Answer)
req = createTestMessageWithType("alias.test.com.", dns.TypeA) req = createTestMessageWithType("alias.test.com.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String()) reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, eerr)
require.Len(t, reply.Answer, 2) require.Len(t, reply.Answer, 2)
@ -972,8 +974,8 @@ func TestRewrite(t *testing.T) {
assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String()) reply, eerr = dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, eerr)
// The original question is restored. // The original question is restored.
require.Len(t, reply.Question, 1) require.Len(t, reply.Question, 1)
@ -986,6 +988,16 @@ func TestRewrite(t *testing.T) {
assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype)
} }
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf.ProtectionEnabled = &val
s.setConfig(conf)
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
}
}
func publicKey(priv interface{}) interface{} { func publicKey(priv interface{}) interface{} {
switch k := priv.(type) { switch k := priv.(type) {
case *rsa.PrivateKey: case *rsa.PrivateKey:
@ -1092,9 +1104,10 @@ func TestPTRResponseFromHosts(t *testing.T) {
require.ErrorIs(t, hc.Close(), closeCalled) require.ErrorIs(t, hc.Close(), closeCalled)
}) })
c := filtering.Config{ flt := filtering.New(&filtering.Config{
EtcHosts: hc, EtcHosts: hc,
} }, nil)
flt.SetEnabled(true)
var snd *aghnet.SubnetDetector var snd *aghnet.SubnetDetector
snd, err = aghnet.NewSubnetDetector() snd, err = aghnet.NewSubnetDetector()
@ -1104,7 +1117,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
var s *Server var s *Server
s, err = NewServer(DNSCreateParams{ s, err = NewServer(DNSCreateParams{
DHCPServer: &testDHCP{}, DHCPServer: &testDHCP{},
DNSFilter: filtering.New(&c, nil), DNSFilter: flt,
SubnetDetector: snd, SubnetDetector: snd,
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1112,25 +1125,24 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.UDPListenAddrs = []*net.UDPAddr{{}}
s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"} s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
err = s.Prepare(nil) err = s.Prepare(nil)
require.NoError(t, err) require.NoError(t, err)
err = s.Start() err = s.Start()
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
s.Close() s.Close()
}) })
subTestFunc := func(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP) addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR)
resp, err := dns.Exchange(req, addr.String()) resp, eerr := dns.Exchange(req, addr.String())
require.NoError(t, err) require.NoError(t, eerr)
require.Lenf(t, resp.Answer, 1, "%#v", resp) require.Len(t, resp.Answer, 1)
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
@ -1140,6 +1152,16 @@ func TestPTRResponseFromHosts(t *testing.T) {
assert.Equal(t, "host.", ptr.Ptr) assert.Equal(t, "host.", ptr.Ptr)
} }
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf.ProtectionEnabled = &val
s.setConfig(conf)
t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc)
}
}
func TestNewServer(t *testing.T) { func TestNewServer(t *testing.T) {
// TODO(a.garipov): Consider moving away from the text-based error // TODO(a.garipov): Consider moving away from the text-based error
// checks and onto a more structured approach. // checks and onto a more structured approach.

View File

@ -52,6 +52,7 @@ func (s *Server) beforeRequestHandler(
// the client's IP address and ID, if any, from ctx. // the client's IP address and ID, if any, from ctx.
func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings { func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings {
setts := s.dnsFilter.GetConfig() setts := s.dnsFilter.GetConfig()
setts.ProtectionEnabled = ctx.protectionEnabled
if s.conf.FilterHandler != nil { if s.conf.FilterHandler != nil {
ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr) ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr)
s.conf.FilterHandler(ip, ctx.clientID, &setts) s.conf.FilterHandler(ip, ctx.clientID, &setts)
@ -65,32 +66,31 @@ func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.S
func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
d := ctx.proxyCtx d := ctx.proxyCtx
req := d.Req req := d.Req
host := strings.TrimSuffix(req.Question[0].Name, ".") q := req.Question[0]
res, err := s.dnsFilter.CheckHost(host, req.Question[0].Qtype, ctx.setts) host := strings.TrimSuffix(q.Name, ".")
if err != nil { res, err := s.dnsFilter.CheckHost(host, q.Qtype, ctx.setts)
// Return immediately if there's an error switch {
return nil, fmt.Errorf("filtering failed to check host %q: %w", host, err) case err != nil:
} else if res.IsFiltered { return nil, fmt.Errorf("failed to check host %q: %w", host, err)
log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rules[0].Text) case res.IsFiltered:
log.Tracef("host %q is filtered, reason %q, rule: %q", host, res.Reason, res.Rules[0].Text)
d.Res = s.genDNSFilterMessage(d, &res) d.Res = s.genDNSFilterMessage(d, &res)
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) && case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) &&
res.CanonName != "" && res.CanonName != "" &&
len(res.IPList) == 0 { len(res.IPList) == 0:
// Resolve the new canonical name, not the original host // Resolve the new canonical name, not the original host name. The
// name. The original question is readded in // original question is readded in processFilteringAfterResponse.
// processFilteringAfterResponse. ctx.origQuestion = q
ctx.origQuestion = req.Question[0]
req.Question[0].Name = dns.Fqdn(res.CanonName) req.Question[0].Name = dns.Fqdn(res.CanonName)
} else if res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0 { case res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0:
resp := s.makeResponse(req) resp := s.makeResponse(req)
for _, h := range res.ReverseHosts {
hdr := dns.RR_Header{ hdr := dns.RR_Header{
Name: req.Question[0].Name, Name: q.Name,
Rrtype: dns.TypePTR, Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL, Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET, Class: dns.ClassINET,
} }
for _, h := range res.ReverseHosts {
ptr := &dns.PTR{ ptr := &dns.PTR{
Hdr: hdr, Hdr: hdr,
Ptr: h, Ptr: h,
@ -100,7 +100,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
} }
d.Res = resp d.Res = resp
} else if res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts) { case res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts):
resp := s.makeResponse(req) resp := s.makeResponse(req)
name := host name := host
@ -110,11 +110,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
} }
for _, ip := range res.IPList { for _, ip := range res.IPList {
if req.Question[0].Qtype == dns.TypeA { switch q.Qtype {
case dns.TypeA:
a := s.genAnswerA(req, ip.To4()) a := s.genAnswerA(req, ip.To4())
a.Hdr.Name = dns.Fqdn(name) a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a) resp.Answer = append(resp.Answer, a)
} else if req.Question[0].Qtype == dns.TypeAAAA { case dns.TypeAAAA:
a := s.genAnswerAAAA(req, ip) a := s.genAnswerAAAA(req, ip)
a.Hdr.Name = dns.Fqdn(name) a.Hdr.Name = dns.Fqdn(name)
resp.Answer = append(resp.Answer, a) resp.Answer = append(resp.Answer, a)
@ -122,9 +123,8 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) {
} }
d.Res = resp d.Res = resp
} else if res.Reason == filtering.RewrittenRule { case res.Reason == filtering.RewrittenRule:
err = s.filterDNSRewrite(req, res, d) if err = s.filterDNSRewrite(req, res, d); err != nil {
if err != nil {
return nil, err return nil, err
} }
} }
@ -179,6 +179,7 @@ func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) {
continue continue
} }
host = strings.TrimSuffix(host, ".")
res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts) res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -38,6 +38,7 @@ type Settings struct {
ServicesRules []ServiceEntry ServicesRules []ServiceEntry
ProtectionEnabled bool
FilteringEnabled bool FilteringEnabled bool
SafeSearchEnabled bool SafeSearchEnabled bool
SafeBrowsingEnabled bool SafeBrowsingEnabled bool
@ -221,12 +222,13 @@ func (r Reason) String() string {
} }
// In returns true if reasons include r. // In returns true if reasons include r.
func (r Reason) In(reasons ...Reason) bool { func (r Reason) In(reasons ...Reason) (ok bool) {
for _, reason := range reasons { for _, reason := range reasons {
if r == reason { if r == reason {
return true return true
} }
} }
return false return false
} }
@ -245,7 +247,7 @@ func (d *DNSFilter) GetConfig() (s Settings) {
defer d.confLock.RUnlock() defer d.confLock.RUnlock()
return Settings{ return Settings{
FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) == 1, FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0,
SafeSearchEnabled: d.Config.SafeSearchEnabled, SafeSearchEnabled: d.Config.SafeSearchEnabled,
SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled, SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled,
ParentalEnabled: d.Config.ParentalEnabled, ParentalEnabled: d.Config.ParentalEnabled,
@ -421,15 +423,17 @@ func (d *DNSFilter) CheckHost(
// Sometimes clients try to resolve ".", which is a request to get root // Sometimes clients try to resolve ".", which is a request to get root
// servers. // servers.
if host == "" { if host == "" {
return Result{Reason: NotFilteredNotFound}, nil return Result{}, nil
} }
host = strings.ToLower(host) host = strings.ToLower(host)
if setts.FilteringEnabled {
res = d.processRewrites(host, qtype) res = d.processRewrites(host, qtype)
if res.Reason == Rewritten { if res.Reason == Rewritten {
return res, nil return res, nil
} }
}
for _, hc := range d.hostCheckers { for _, hc := range d.hostCheckers {
res, err = hc.check(host, qtype, setts) res, err = hc.check(host, qtype, setts)
@ -448,7 +452,7 @@ func (d *DNSFilter) CheckHost(
// matchSysHosts tries to match the host against the operating system's hosts // matchSysHosts tries to match the host against the operating system's hosts
// database. // database.
func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) { func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) {
if d.EtcHosts == nil { if !setts.FilteringEnabled || d.EtcHosts == nil {
return Result{}, nil return Result{}, nil
} }
@ -468,10 +472,8 @@ func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (r
var ips []net.IP var ips []net.IP
var revHosts []string var revHosts []string
for _, nr := range dnsr { for _, nr := range dnsr {
dr := nr.DNSRewrite if nr.DNSRewrite == nil {
if dr == nil {
continue continue
} }
@ -553,6 +555,10 @@ func matchBlockedServicesRules(
_ uint16, _ uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.ProtectionEnabled {
return Result{}, nil
}
svcs := setts.ServicesRules svcs := setts.ServicesRules
if len(svcs) == 0 { if len(svcs) == 0 {
return Result{}, nil return Result{}, nil
@ -784,7 +790,7 @@ func (d *DNSFilter) matchHost(
// TODO(e.burkov): Inspect if the above is true. // TODO(e.burkov): Inspect if the above is true.
defer d.engineLock.RUnlock() defer d.engineLock.RUnlock()
if d.filteringEngineAllow != nil { if setts.ProtectionEnabled && d.filteringEngineAllow != nil {
dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq) dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq)
if ok { if ok {
return d.matchHostProcessAllowList(host, dnsres) return d.matchHostProcessAllowList(host, dnsres)
@ -810,6 +816,11 @@ func (d *DNSFilter) matchHost(
return Result{}, nil return Result{}, nil
} }
if !setts.ProtectionEnabled {
// Don't check non-dnsrewrite filtering results.
return Result{}, nil
}
res = d.matchHostProcessDNSResult(qtype, dnsres) res = d.matchHostProcessDNSResult(qtype, dnsres)
for _, r := range res.Rules { for _, r := range res.Rules {
log.Debug( log.Debug(

View File

@ -21,7 +21,9 @@ func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m) aghtest.DiscardLogOutput(m)
} }
var setts Settings var setts = Settings{
ProtectionEnabled: true,
}
// Helpers. // Helpers.
@ -39,9 +41,9 @@ func purgeCaches() {
func newForTest(c *Config, filters []Filter) *DNSFilter { func newForTest(c *Config, filters []Filter) *DNSFilter {
setts = Settings{ setts = Settings{
ProtectionEnabled: true,
FilteringEnabled: true, FilteringEnabled: true,
} }
setts.FilteringEnabled = true
if c != nil { if c != nil {
c.SafeBrowsingCacheSize = 10000 c.SafeBrowsingCacheSize = 10000
c.ParentalCacheSize = 10000 c.ParentalCacheSize = 10000
@ -797,7 +799,11 @@ func TestClientSettings(t *testing.T) {
makeTester := func(tc testCase, before bool) func(t *testing.T) { makeTester := func(tc testCase, before bool) func(t *testing.T) {
return func(t *testing.T) { return func(t *testing.T) {
r, _ := d.CheckHost(tc.host, dns.TypeA, &setts) t.Helper()
r, err := d.CheckHost(tc.host, dns.TypeA, &setts)
require.NoError(t, err)
if before { if before {
assert.True(t, r.IsFiltered) assert.True(t, r.IsFiltered)
assert.Equal(t, tc.wantReason, r.Reason) assert.Equal(t, tc.wantReason, r.Reason)
@ -808,7 +814,7 @@ func TestClientSettings(t *testing.T) {
} }
// Check behaviour without any per-client settings, then apply per-client // Check behaviour without any per-client settings, then apply per-client
// settings and check behaviour once again. // settings and check behavior once again.
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, makeTester(tc, tc.before)) t.Run(tc.name, makeTester(tc, tc.before))
} }

View File

@ -306,7 +306,7 @@ func (d *DNSFilter) checkSafeBrowsing(
_ uint16, _ uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.SafeBrowsingEnabled { if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled {
return Result{}, nil return Result{}, nil
} }
@ -339,7 +339,7 @@ func (d *DNSFilter) checkParental(
_ uint16, _ uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.ParentalEnabled { if !setts.ProtectionEnabled || !setts.ParentalEnabled {
return Result{}, nil return Result{}, nil
} }

View File

@ -117,6 +117,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) {
d.SetParentalUpstream(ups) d.SetParentalUpstream(ups)
setts := &Settings{ setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true, SafeBrowsingEnabled: true,
ParentalEnabled: true, ParentalEnabled: true,
} }
@ -135,35 +136,36 @@ func TestSBPC(t *testing.T) {
const hostname = "example.org" const hostname = "example.org"
setts := &Settings{ setts := &Settings{
ProtectionEnabled: true,
SafeBrowsingEnabled: true, SafeBrowsingEnabled: true,
ParentalEnabled: true, ParentalEnabled: true,
} }
testCases := []struct { testCases := []struct {
testCache cache.Cache
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
name string name string
block bool block bool
testFunc func(host string, _ uint16, _ *Settings) (res Result, err error)
testCache cache.Cache
}{{ }{{
testCache: gctx.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_no_block", name: "sb_no_block",
block: false, block: false,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, { }, {
testCache: gctx.safebrowsingCache,
testFunc: d.checkSafeBrowsing,
name: "sb_block", name: "sb_block",
block: true, block: true,
testFunc: d.checkSafeBrowsing,
testCache: gctx.safebrowsingCache,
}, { }, {
testCache: gctx.parentalCache,
testFunc: d.checkParental,
name: "pc_no_block", name: "pc_no_block",
block: false, block: false,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}, { }, {
testCache: gctx.parentalCache,
testFunc: d.checkParental,
name: "pc_block", name: "pc_block",
block: true, block: true,
testFunc: d.checkParental,
testCache: gctx.parentalCache,
}} }}
for _, tc := range testCases { for _, tc := range testCases {

View File

@ -74,7 +74,7 @@ func (d *DNSFilter) checkSafeSearch(
_ uint16, _ uint16,
setts *Settings, setts *Settings,
) (res Result, err error) { ) (res Result, err error) {
if !setts.SafeSearchEnabled { if !setts.ProtectionEnabled || !setts.SafeSearchEnabled {
return Result{}, nil return Result{}, nil
} }

View File

@ -404,6 +404,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) {
setts := Context.dnsFilter.GetConfig() setts := Context.dnsFilter.GetConfig()
setts.FilteringEnabled = true setts.FilteringEnabled = true
setts.ProtectionEnabled = true
Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) Context.dnsFilter.ApplyBlockedServices(&setts, nil, true)
result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts)
if err != nil { if err != nil {