diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b017cf0..aeb1ed2d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,10 @@ and this project adheres to ### Changed +- `$dnsrewrite` rules and other DNS rewrites will now be applied even when the + protection is disabled ([#1558]). +- DHCP gateway address, subnet mask, IP address range, and leases validations + ([#3529]). - The `systemd` service script will now create the `/var/log` directory when it doesn't exist ([#3579]). - Items in allowed clients, disallowed clients, and blocked hosts lists are now @@ -114,6 +118,7 @@ In this release, the schema version has changed from 10 to 12. ### Fixed +- Incorrect assignment of explicitly configured DHCP options ([#3744]). - Occasional panic during shutdown ([#3655]). - Addition of IPs into only one as opposed to all matching ipsets on Linux ([#3638]). @@ -152,6 +157,7 @@ In this release, the schema version has changed from 10 to 12. - Go 1.15 support. [#1381]: https://github.com/AdguardTeam/AdGuardHome/issues/1381 +[#1558]: https://github.com/AdguardTeam/AdGuardHome/issues/1558 [#1691]: https://github.com/AdguardTeam/AdGuardHome/issues/1691 [#1898]: https://github.com/AdguardTeam/AdGuardHome/issues/1898 [#1992]: https://github.com/AdguardTeam/AdGuardHome/issues/1992 @@ -195,6 +201,7 @@ In this release, the schema version has changed from 10 to 12. [#3450]: https://github.com/AdguardTeam/AdGuardHome/issues/3450 [#3457]: https://github.com/AdguardTeam/AdGuardHome/issues/3457 [#3506]: https://github.com/AdguardTeam/AdGuardHome/issues/3506 +[#3529]: https://github.com/AdguardTeam/AdGuardHome/issues/3529 [#3538]: https://github.com/AdguardTeam/AdGuardHome/issues/3538 [#3551]: https://github.com/AdguardTeam/AdGuardHome/issues/3551 [#3564]: https://github.com/AdguardTeam/AdGuardHome/issues/3564 @@ -204,6 +211,7 @@ In this release, the schema version has changed from 10 to 12. [#3607]: https://github.com/AdguardTeam/AdGuardHome/issues/3607 [#3638]: https://github.com/AdguardTeam/AdGuardHome/issues/3638 [#3655]: https://github.com/AdguardTeam/AdGuardHome/issues/3655 +[#3744]: https://github.com/AdguardTeam/AdGuardHome/issues/3744 diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index f3d8abcd..6db195fe 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -37,6 +37,9 @@ "dhcp_ipv6_settings": "DHCP IPv6 Settings", "form_error_required": "Required field", "form_error_ip4_format": "Invalid IPv4 format", + "form_error_ip4_range_start_format": "Invalid range start IPv4 format", + "form_error_ip4_range_end_format": "Invalid range end IPv4 format", + "form_error_ip4_gateway_format": "Invalid gateway IPv4 format", "form_error_ip6_format": "Invalid IPv6 format", "form_error_ip_format": "Invalid IP format", "form_error_mac_format": "Invalid MAC format", @@ -45,7 +48,14 @@ "form_error_subnet": "Subnet \"{{cidr}}\" does not contain the IP address \"{{ip}}\"", "form_error_positive": "Must be greater than 0", "form_error_negative": "Must be equal to 0 or greater", - "range_end_error": "Must be greater than range start", + "out_of_range_error": "Must be out of range \"{{start}}\"-\"{{end}}\"", + "in_range_error": "Must be in range \"{{start}}\"-\"{{end}}\"", + "lower_range_start_error": "Must be lower than range start", + "lower_range_end_error": "Must be lower than range end", + "greater_range_start_error": "Must be greater than range start", + "greater_range_end_error": "Must be greater than range end", + "subnet_error": "Addresses must be in one subnet", + "gateway_or_subnet_invalid": "Subnet mask invalid", "dhcp_form_gateway_input": "Gateway IP", "dhcp_form_subnet_input": "Subnet mask", "dhcp_form_range_title": "Range of IP addresses", diff --git a/client/src/components/Settings/Dhcp/FormDHCPv4.js b/client/src/components/Settings/Dhcp/FormDHCPv4.js index 873e7696..cb371f9f 100644 --- a/client/src/components/Settings/Dhcp/FormDHCPv4.js +++ b/client/src/components/Settings/Dhcp/FormDHCPv4.js @@ -13,6 +13,9 @@ import { validateIpv4, validateRequiredValue, validateIpv4RangeEnd, + validateGatewaySubnetMask, + validateIpForGatewaySubnetMask, + validateNotInRange, } from '../../../helpers/validators'; const FormDHCPv4 = ({ @@ -54,7 +57,11 @@ const FormDHCPv4 = ({ type="text" className="form-control" placeholder={t(ipv4placeholders.gateway_ip)} - validate={[validateIpv4, validateRequired]} + validate={[ + validateIpv4, + validateRequired, + validateNotInRange, + ]} disabled={!isInterfaceIncludesIpv4} /> </div> @@ -66,7 +73,11 @@ const FormDHCPv4 = ({ type="text" className="form-control" placeholder={t(ipv4placeholders.subnet_mask)} - validate={[validateIpv4, validateRequired]} + validate={[ + validateIpv4, + validateRequired, + validateGatewaySubnetMask, + ]} disabled={!isInterfaceIncludesIpv4} /> </div> @@ -84,7 +95,11 @@ const FormDHCPv4 = ({ type="text" className="form-control" placeholder={t(ipv4placeholders.range_start)} - validate={[validateIpv4]} + validate={[ + validateIpv4, + validateGatewaySubnetMask, + validateIpForGatewaySubnetMask, + ]} disabled={!isInterfaceIncludesIpv4} /> </div> @@ -95,7 +110,12 @@ const FormDHCPv4 = ({ type="text" className="form-control" placeholder={t(ipv4placeholders.range_end)} - validate={[validateIpv4, validateIpv4RangeEnd]} + validate={[ + validateIpv4, + validateIpv4RangeEnd, + validateGatewaySubnetMask, + validateIpForGatewaySubnetMask, + ]} disabled={!isInterfaceIncludesIpv4} /> </div> diff --git a/client/src/components/Settings/Dhcp/StaticLeases/Form.js b/client/src/components/Settings/Dhcp/StaticLeases/Form.js index e857144d..7e9c641b 100644 --- a/client/src/components/Settings/Dhcp/StaticLeases/Form.js +++ b/client/src/components/Settings/Dhcp/StaticLeases/Form.js @@ -10,6 +10,7 @@ import { validateMac, validateRequiredValue, validateIpv4InCidr, + validateInRange, } from '../../../../helpers/validators'; import { FORM_NAME } from '../../../../helpers/constants'; import { toggleLeaseModal } from '../../../../actions'; @@ -53,7 +54,12 @@ const Form = ({ type="text" className="form-control" placeholder={t('form_enter_subnet_ip', { cidr })} - validate={[validateRequiredValue, validateIpv4, validateIpv4InCidr]} + validate={[ + validateRequiredValue, + validateIpv4, + validateIpv4InCidr, + validateInRange, + ]} /> </div> <div className="form__group"> diff --git a/client/src/components/Settings/Dhcp/StaticLeases/Modal.js b/client/src/components/Settings/Dhcp/StaticLeases/Modal.js index b65c298e..8ad0f009 100644 --- a/client/src/components/Settings/Dhcp/StaticLeases/Modal.js +++ b/client/src/components/Settings/Dhcp/StaticLeases/Modal.js @@ -11,6 +11,8 @@ const Modal = ({ handleSubmit, processingAdding, cidr, + rangeStart, + rangeEnd, }) => { const dispatch = useDispatch(); @@ -38,10 +40,14 @@ const Modal = ({ ip: '', hostname: '', cidr, + rangeStart, + rangeEnd, }} onSubmit={handleSubmit} processingAdding={processingAdding} cidr={cidr} + rangeStart={rangeStart} + rangeEnd={rangeEnd} /> </div> </ReactModal> @@ -53,6 +59,8 @@ Modal.propTypes = { handleSubmit: PropTypes.func.isRequired, processingAdding: PropTypes.bool.isRequired, cidr: PropTypes.string.isRequired, + rangeStart: PropTypes.string, + rangeEnd: PropTypes.string, }; export default withTranslation()(Modal); diff --git a/client/src/components/Settings/Dhcp/StaticLeases/index.js b/client/src/components/Settings/Dhcp/StaticLeases/index.js index 6e12f30e..bdd7cec5 100644 --- a/client/src/components/Settings/Dhcp/StaticLeases/index.js +++ b/client/src/components/Settings/Dhcp/StaticLeases/index.js @@ -22,6 +22,8 @@ const StaticLeases = ({ processingDeleting, staticLeases, cidr, + rangeStart, + rangeEnd, }) => { const [t] = useTranslation(); const dispatch = useDispatch(); @@ -100,6 +102,8 @@ const StaticLeases = ({ handleSubmit={handleSubmit} processingAdding={processingAdding} cidr={cidr} + rangeStart={rangeStart} + rangeEnd={rangeEnd} /> </> ); @@ -111,6 +115,8 @@ StaticLeases.propTypes = { processingAdding: PropTypes.bool.isRequired, processingDeleting: PropTypes.bool.isRequired, cidr: PropTypes.string.isRequired, + rangeStart: PropTypes.string, + rangeEnd: PropTypes.string, }; cellWrap.propTypes = { diff --git a/client/src/components/Settings/Dhcp/index.js b/client/src/components/Settings/Dhcp/index.js index 844e662e..6208d4a6 100644 --- a/client/src/components/Settings/Dhcp/index.js +++ b/client/src/components/Settings/Dhcp/index.js @@ -275,6 +275,8 @@ const Dhcp = () => { processingAdding={processingAdding} processingDeleting={processingDeleting} cidr={cidr} + rangeStart={dhcp?.values?.v4?.range_start} + rangeEnd={dhcp?.values?.v4?.range_end} /> <div className="btn-list mt-2"> <button diff --git a/client/src/helpers/helpers.js b/client/src/helpers/helpers.js index 19ed0b70..d6eaa061 100644 --- a/client/src/helpers/helpers.js +++ b/client/src/helpers/helpers.js @@ -552,6 +552,20 @@ export const isIpInCidr = (ip, cidr) => { } }; +/** + * + * @param {string} subnetMask + * @returns {IPv4 | null} + */ +export const parseSubnetMask = (subnetMask) => { + try { + return ipaddr.parse(subnetMask).prefixLengthFromSubnetMask(); + } catch (e) { + console.error(e); + return null; + } +}; + /** * * @param {string} subnetMask diff --git a/client/src/helpers/validators.js b/client/src/helpers/validators.js index 7075ca47..4a4125d2 100644 --- a/client/src/helpers/validators.js +++ b/client/src/helpers/validators.js @@ -1,4 +1,5 @@ import i18next from 'i18next'; + import { MAX_PORT, R_CIDR, @@ -14,7 +15,7 @@ import { R_DOMAIN, } from './constants'; import { ip4ToInt, isValidAbsolutePath } from './form'; -import { isIpInCidr } from './helpers'; +import { isIpInCidr, parseSubnetMask } from './helpers'; // Validation functions // https://redux-form.com/8.3.0/examples/fieldlevelvalidation/ @@ -44,7 +45,7 @@ export const validateIpv4RangeEnd = (_, allValues) => { const { range_end, range_start } = allValues.v4; if (ip4ToInt(range_end) <= ip4ToInt(range_start)) { - return 'range_end_error'; + return 'greater_range_start_error'; } return undefined; @@ -61,6 +62,114 @@ export const validateIpv4 = (value) => { return undefined; }; +/** + * @returns {undefined|string} + * @param _ + * @param allValues + */ +export const validateNotInRange = (value, allValues) => { + const { range_start, range_end } = allValues.v4; + + if (range_start && validateIpv4(range_start)) { + return 'form_error_ip4_range_start_format'; + } + + if (range_end && validateIpv4(range_end)) { + return 'form_error_ip4_range_end_format'; + } + + const isAboveMin = range_start && ip4ToInt(value) >= ip4ToInt(range_start); + const isBelowMax = range_end && ip4ToInt(value) <= ip4ToInt(range_end); + + if (isAboveMin && isBelowMax) { + return i18next.t('out_of_range_error', { + start: range_start, + end: range_end, + }); + } + + if (!range_end && isAboveMin) { + return 'lower_range_start_error'; + } + + if (!range_start && isBelowMax) { + return 'greater_range_end_error'; + } + + return undefined; +}; + +/** + * @returns {undefined|string} + * @param _ + * @param allValues + */ +export const validateInRange = (value, allValues) => { + const { rangeStart, rangeEnd } = allValues; + + if (rangeStart && validateIpv4(rangeStart)) { + return 'form_error_ip4_range_start_format'; + } + + if (rangeEnd && validateIpv4(rangeEnd)) { + return 'form_error_ip4_range_end_format'; + } + + const isBelowMin = rangeStart && ip4ToInt(value) < ip4ToInt(rangeStart); + const isAboveMax = rangeEnd && ip4ToInt(value) > ip4ToInt(rangeEnd); + + if (isAboveMax || isBelowMin) { + return i18next.t('in_range_error', { + start: rangeStart, + end: rangeEnd, + }); + } + + return undefined; +}; + +/** + * @returns {undefined|string} + * @param _ + * @param allValues + */ +export const validateGatewaySubnetMask = (_, allValues) => { + if (!allValues || !allValues.v4 || !allValues.v4.subnet_mask || !allValues.v4.gateway_ip) { + return 'gateway_or_subnet_invalid'; + } + + const { subnet_mask, gateway_ip } = allValues.v4; + + if (validateIpv4(gateway_ip)) { + return 'form_error_ip4_gateway_format'; + } + + return parseSubnetMask(subnet_mask) ? undefined : 'gateway_or_subnet_invalid'; +}; + +/** + * @returns {undefined|string} + * @param value + * @param allValues + */ +export const validateIpForGatewaySubnetMask = (value, allValues) => { + if (!allValues || !allValues.v4 || !value) { + return undefined; + } + + const { + gateway_ip, subnet_mask, + } = allValues.v4; + + const subnetPrefix = parseSubnetMask(subnet_mask); + + if (!isIpInCidr(value, `${gateway_ip}/${subnetPrefix}`)) { + return 'subnet_error'; + } + + return undefined; +}; + /** * @param value {string} * @returns {undefined|string} diff --git a/go.mod b/go.mod index 728a0d7f..85738a6e 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.16 require ( github.com/AdguardTeam/dnsproxy v0.39.8 - github.com/AdguardTeam/golibs v0.10.0 + github.com/AdguardTeam/golibs v0.10.2 github.com/AdguardTeam/urlfilter v0.14.6 github.com/NYTimes/gziphandler v1.1.1 github.com/ameshkov/dnscrypt/v2 v2.2.2 diff --git a/go.sum b/go.sum index 6ec3396e..86c53f10 100644 --- a/go.sum +++ b/go.sum @@ -14,8 +14,8 @@ github.com/AdguardTeam/dnsproxy v0.39.8/go.mod h1:eDpJKAdkHORRwAedjuERv+7SWlcz4c github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.9.2/go.mod h1:fCAMwPBJ8S7YMYbTWvYS+eeTLblP5E04IDtNAo7y7IY= -github.com/AdguardTeam/golibs v0.10.0 h1:A7MXRfZ+ItpOyS9tWKtqrLj3vZtE9FJFC+dOVY/LcWs= -github.com/AdguardTeam/golibs v0.10.0/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= +github.com/AdguardTeam/golibs v0.10.2 h1:TAwnS4Y49sSUa4UX1yz/MWNGbIlXHqafrWr9MxdIh9A= +github.com/AdguardTeam/golibs v0.10.2/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU= github.com/AdguardTeam/urlfilter v0.14.6 h1:emqoKZElooHACYehRBYENeKVN1a/rspxiqTIMYLuoIo= github.com/AdguardTeam/urlfilter v0.14.6/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U= diff --git a/internal/dhcpd/dhcpd_test.go b/internal/dhcpd/dhcpd_test.go index 21f63fd1..8609481b 100644 --- a/internal/dhcpd/dhcpd_test.go +++ b/internal/dhcpd/dhcpd_test.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -138,6 +139,49 @@ func TestNormalizeLeases(t *testing.T) { assert.Equal(t, leases[2].HWAddr, dynLeases[1].HWAddr) } +func TestV4Server_badRange(t *testing.T) { + testCases := []struct { + name string + gatewayIP net.IP + subnetMask net.IP + wantErrMsg string + }{{ + name: "gateway_in_range", + gatewayIP: net.IP{192, 168, 10, 120}, + subnetMask: net.IP{255, 255, 255, 0}, + wantErrMsg: "dhcpv4: gateway ip 192.168.10.120 in the ip range: " + + "192.168.10.20-192.168.10.200", + }, { + name: "outside_range_start", + gatewayIP: net.IP{192, 168, 10, 1}, + subnetMask: net.IP{255, 255, 255, 240}, + wantErrMsg: "dhcpv4: range start 192.168.10.20 is outside network " + + "192.168.10.1/28", + }, { + name: "outside_range_end", + gatewayIP: net.IP{192, 168, 10, 1}, + subnetMask: net.IP{255, 255, 255, 224}, + wantErrMsg: "dhcpv4: range end 192.168.10.200 is outside network " + + "192.168.10.1/27", + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + conf := V4ServerConf{ + Enabled: true, + RangeStart: net.IP{192, 168, 10, 20}, + RangeEnd: net.IP{192, 168, 10, 200}, + GatewayIP: tc.gatewayIP, + SubnetMask: tc.subnetMask, + notify: testNotify, + } + + _, err := v4Create(conf) + testutil.AssertErrorMsg(t, tc.wantErrMsg, err) + }) + } +} + // cloneUDPAddr returns a deep copy of a. func cloneUDPAddr(a *net.UDPAddr) (clone *net.UDPAddr) { return &net.UDPAddr{ diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index d7e4dfae..0107eddf 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -293,6 +293,8 @@ func (s *v4Server) addLease(l *Lease) (err error) { offset, inOffset := r.offset(l.IP) if l.IsStatic() { + // TODO(a.garipov, d.seregin): Subnet can be nil when dhcp server is + // disabled. if sn := s.conf.subnet; !sn.Contains(l.IP) { return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP) } @@ -900,9 +902,10 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int { resp.UpdateOption(dhcpv4.OptGeneric(code, configured.Get(code))) } } - // Update the value of Domain Name Server option separately from others - // since its value is set after server's creating. - if requested.Has(dhcpv4.OptionDomainNameServer) { + // Update the value of Domain Name Server option separately from others if + // not assigned yet since its value is set after server's creating. + if requested.Has(dhcpv4.OptionDomainNameServer) && + !resp.Options.Has(dhcpv4.OptionDomainNameServer) { resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...)) } @@ -1124,6 +1127,29 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) { return s, fmt.Errorf("dhcpv4: %w", err) } + if s.conf.ipRange.contains(routerIP) { + return s, fmt.Errorf("dhcpv4: gateway ip %v in the ip range: %v-%v", + routerIP, + conf.RangeStart, + conf.RangeEnd, + ) + } + + if !s.conf.subnet.Contains(conf.RangeStart) { + return s, fmt.Errorf("dhcpv4: range start %v is outside network %v", + conf.RangeStart, + s.conf.subnet, + ) + } + + if !s.conf.subnet.Contains(conf.RangeEnd) { + return s, fmt.Errorf("dhcpv4: range end %v is outside network %v", + conf.RangeEnd, + s.conf.subnet, + ) + } + + // TODO(a.garipov, d.seregin): Check that every lease is inside the IPRange. s.leasedOffsets = newBitSet() if conf.LeaseDuration == 0 { diff --git a/internal/dhcpd/v4_test.go b/internal/dhcpd/v4_test.go index d20e715c..9bed3c60 100644 --- a/internal/dhcpd/v4_test.go +++ b/internal/dhcpd/v4_test.go @@ -5,8 +5,10 @@ package dhcpd import ( "net" + "strings" "testing" + "github.com/AdguardTeam/golibs/stringutil" "github.com/insomniacslk/dhcp/dhcpv4" "github.com/mdlayher/raw" "github.com/stretchr/testify/assert" @@ -16,17 +18,34 @@ import ( func notify4(flags uint32) { } -func TestV4_AddRemove_static(t *testing.T) { - s, err := v4Create(V4ServerConf{ +// defaultV4ServerConf returns the default configuration for *v4Server to use in +// tests. +func defaultV4ServerConf() (conf V4ServerConf) { + return V4ServerConf{ Enabled: true, RangeStart: net.IP{192, 168, 10, 100}, RangeEnd: net.IP{192, 168, 10, 200}, GatewayIP: net.IP{192, 168, 10, 1}, SubnetMask: net.IP{255, 255, 255, 0}, notify: notify4, - }) + } +} + +// defaultSrv prepares the default DHCPServer to use in tests. The underlying +// type of s is *v4Server. +func defaultSrv(t *testing.T) (s DHCPServer) { + t.Helper() + + var err error + s, err = v4Create(defaultV4ServerConf()) require.NoError(t, err) + return s +} + +func TestV4_AddRemove_static(t *testing.T) { + s := defaultSrv(t) + ls := s.GetLeases(LeasesStatic) assert.Empty(t, ls) @@ -37,7 +56,7 @@ func TestV4_AddRemove_static(t *testing.T) { IP: net.IP{192, 168, 10, 150}, } - err = s.AddStaticLease(l) + err := s.AddStaticLease(l) require.NoError(t, err) err = s.AddStaticLease(l) @@ -65,15 +84,7 @@ func TestV4_AddRemove_static(t *testing.T) { } func TestV4_AddReplace(t *testing.T) { - sIface, err := v4Create(V4ServerConf{ - Enabled: true, - RangeStart: net.IP{192, 168, 10, 100}, - RangeEnd: net.IP{192, 168, 10, 200}, - GatewayIP: net.IP{192, 168, 10, 1}, - SubnetMask: net.IP{255, 255, 255, 0}, - notify: notify4, - }) - require.NoError(t, err) + sIface := defaultSrv(t) s, ok := sIface.(*v4Server) require.True(t, ok) @@ -89,7 +100,7 @@ func TestV4_AddReplace(t *testing.T) { }} for i := range dynLeases { - err = s.addLease(&dynLeases[i]) + err := s.addLease(&dynLeases[i]) require.NoError(t, err) } @@ -104,7 +115,7 @@ func TestV4_AddReplace(t *testing.T) { }} for _, l := range stLeases { - err = s.AddStaticLease(l) + err := s.AddStaticLease(l) require.NoError(t, err) } @@ -118,17 +129,80 @@ func TestV4_AddReplace(t *testing.T) { } } -func TestV4StaticLease_Get(t *testing.T) { - var err error - sIface, err := v4Create(V4ServerConf{ - Enabled: true, - RangeStart: net.IP{192, 168, 10, 100}, - RangeEnd: net.IP{192, 168, 10, 200}, - GatewayIP: net.IP{192, 168, 10, 1}, - SubnetMask: net.IP{255, 255, 255, 0}, - notify: notify4, +func TestV4Server_Process_optionsPriority(t *testing.T) { + defaultIP := net.IP{192, 168, 1, 1} + knownIP := net.IP{1, 2, 3, 4} + + // prepareSrv creates a *v4Server and sets the opt6IPs in the initial + // configuration of the server as the value for DHCP option 6. + prepareSrv := func(t *testing.T, opt6IPs []net.IP) (s *v4Server) { + t.Helper() + + conf := defaultV4ServerConf() + if len(opt6IPs) > 0 { + b := &strings.Builder{} + stringutil.WriteToBuilder(b, "6 ips ", opt6IPs[0].String()) + for _, ip := range opt6IPs[1:] { + stringutil.WriteToBuilder(b, ",", ip.String()) + } + conf.Options = []string{b.String()} + } + + ss, err := v4Create(conf) + require.NoError(t, err) + + var ok bool + s, ok = ss.(*v4Server) + require.True(t, ok) + + s.conf.dnsIPAddrs = []net.IP{defaultIP} + + return s + } + + // checkResp creates a discovery message with DHCP option 6 requested amd + // asserts the response to contain wantIPs in this option. + checkResp := func(t *testing.T, s *v4Server, wantIPs []net.IP) { + t.Helper() + + mac := net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA} + req, err := dhcpv4.NewDiscovery(mac, dhcpv4.WithRequestedOptions( + dhcpv4.OptionDomainNameServer, + )) + require.NoError(t, err) + + var resp *dhcpv4.DHCPv4 + resp, err = dhcpv4.NewReplyFromRequest(req) + require.NoError(t, err) + + res := s.process(req, resp) + require.Equal(t, 1, res) + + o := resp.GetOneOption(dhcpv4.OptionDomainNameServer) + require.NotEmpty(t, o) + + wantData := []byte{} + for _, ip := range wantIPs { + wantData = append(wantData, ip...) + } + assert.Equal(t, o, wantData) + } + + t.Run("default", func(t *testing.T) { + s := prepareSrv(t, nil) + + checkResp(t, s, []net.IP{defaultIP}) }) - require.NoError(t, err) + + t.Run("explicitly_configured", func(t *testing.T) { + s := prepareSrv(t, []net.IP{knownIP, knownIP}) + + checkResp(t, s, []net.IP{knownIP, knownIP}) + }) +} + +func TestV4StaticLease_Get(t *testing.T) { + sIface := defaultSrv(t) s, ok := sIface.(*v4Server) require.True(t, ok) @@ -140,7 +214,7 @@ func TestV4StaticLease_Get(t *testing.T) { HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}, IP: net.IP{192, 168, 10, 150}, } - err = s.AddStaticLease(l) + err := s.AddStaticLease(l) require.NoError(t, err) var req, resp *dhcpv4.DHCPv4 @@ -208,19 +282,14 @@ func TestV4StaticLease_Get(t *testing.T) { } func TestV4DynamicLease_Get(t *testing.T) { + conf := defaultV4ServerConf() + conf.Options = []string{ + "81 hex 303132", + "82 ip 1.2.3.4", + } + var err error - sIface, err := v4Create(V4ServerConf{ - Enabled: true, - RangeStart: net.IP{192, 168, 10, 100}, - RangeEnd: net.IP{192, 168, 10, 200}, - GatewayIP: net.IP{192, 168, 10, 1}, - SubnetMask: net.IP{255, 255, 255, 0}, - notify: notify4, - Options: []string{ - "81 hex 303132", - "82 ip 1.2.3.4", - }, - }) + sIface, err := v4Create(conf) require.NoError(t, err) s, ok := sIface.(*v4Server) diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 4f1b8d3a..a0aa4cb1 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -90,7 +90,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { s.processRestrictLocal, s.processInternalIPAddrs, s.processClientID, - processFilteringBeforeRequest, + s.processFilteringBeforeRequest, s.processLocalPTR, s.processUpstream, processDNSSECAfterResponse, @@ -468,19 +468,18 @@ func (s *Server) processLocalPTR(ctx *dnsContext) (rc resultCode) { } // Apply filtering logic -func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { - s := ctx.srv - d := ctx.proxyCtx - - if d.Res != nil { - return resultCodeSuccess // response is already set - nothing to do +func (s *Server) processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { + if ctx.proxyCtx.Res != nil { + // Go on since the response is already set. + return resultCodeSuccess } s.serverLock.RLock() defer s.serverLock.RUnlock() - ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil - if !ctx.protectionEnabled { + ctx.protectionEnabled = s.conf.ProtectionEnabled + + if s.dnsFilter == nil { return resultCodeSuccess } @@ -489,8 +488,7 @@ func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { } var err error - ctx.result, err = s.filterDNSRequest(ctx) - if err != nil { + if ctx.result, err = s.filterDNSRequest(ctx); err != nil { ctx.err = err return resultCodeError @@ -608,48 +606,50 @@ func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) { func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) { s := ctx.srv d := ctx.proxyCtx - res := ctx.result - var err error - switch res.Reason { - case filtering.Rewritten, + switch res := ctx.result; res.Reason { + case filtering.NotFilteredAllowList: + // Go on. + case + filtering.Rewritten, filtering.RewrittenRule: if len(ctx.origQuestion.Name) == 0 { - // origQuestion is set in case we get only CNAME without IP from rewrites table + // origQuestion is set in case we get only CNAME without IP from + // rewrites table. break } - d.Req.Question[0] = ctx.origQuestion - d.Res.Question[0] = ctx.origQuestion - - if len(d.Res.Answer) != 0 { - answer := []dns.RR{} - answer = append(answer, s.genAnswerCNAME(d.Req, res.CanonName)) - answer = append(answer, d.Res.Answer...) + d.Req.Question[0], d.Res.Question[0] = ctx.origQuestion, ctx.origQuestion + if len(d.Res.Answer) > 0 { + answer := append([]dns.RR{s.genAnswerCNAME(d.Req, res.CanonName)}, d.Res.Answer...) d.Res.Answer = answer } - - case filtering.NotFilteredAllowList: - // nothing - default: - if !ctx.protectionEnabled || // filters are disabled: there's nothing to check for - !ctx.responseFromUpstream { // only check response if it's from an upstream server + // Check the response only if the it's from an upstream. Don't check + // the response if the protection is disabled since dnsrewrite rules + // aren't applied to it anyway. + if !ctx.protectionEnabled || !ctx.responseFromUpstream || s.dnsFilter == nil { break } - origResp2 := d.Res - ctx.result, err = s.filterDNSResponse(ctx) + + origResp := d.Res + result, err := s.filterDNSResponse(ctx) if err != nil { ctx.err = err + return resultCodeError } - if ctx.result != nil { - ctx.origResp = origResp2 // matched by response - } else { - ctx.result = &filtering.Result{} + + if result != nil { + ctx.result = result + ctx.origResp = origResp } } + if ctx.result == nil { + ctx.result = &filtering.Result{} + } + return resultCodeSuccess } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 3ab76a40..a3aa68c6 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -909,6 +909,7 @@ func TestRewrite(t *testing.T) { }}, } f := filtering.New(c, nil) + f.SetEnabled(true) snd, err := aghnet.NewSubnetDetector() require.NoError(t, err) @@ -945,45 +946,56 @@ func TestRewrite(t *testing.T) { addr := s.dnsProxy.Addr(proxy.ProtoUDP) - req := createTestMessageWithType("test.com.", dns.TypeA) - reply, err := dns.Exchange(req, addr.String()) - require.NoError(t, err) + subTestFunc := func(t *testing.T) { + req := createTestMessageWithType("test.com.", dns.TypeA) + reply, eerr := dns.Exchange(req, addr.String()) + require.NoError(t, eerr) - require.Len(t, reply.Answer, 1) + require.Len(t, reply.Answer, 1) - a, ok := reply.Answer[0].(*dns.A) - require.True(t, ok) + a, ok := reply.Answer[0].(*dns.A) + require.True(t, ok) - assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A)) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A)) - req = createTestMessageWithType("test.com.", dns.TypeAAAA) - reply, err = dns.Exchange(req, addr.String()) - require.NoError(t, err) + req = createTestMessageWithType("test.com.", dns.TypeAAAA) + reply, eerr = dns.Exchange(req, addr.String()) + require.NoError(t, eerr) - assert.Empty(t, reply.Answer) + assert.Empty(t, reply.Answer) - req = createTestMessageWithType("alias.test.com.", dns.TypeA) - reply, err = dns.Exchange(req, addr.String()) - require.NoError(t, err) + req = createTestMessageWithType("alias.test.com.", dns.TypeA) + reply, eerr = dns.Exchange(req, addr.String()) + require.NoError(t, eerr) - require.Len(t, reply.Answer, 2) + require.Len(t, reply.Answer, 2) - assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) - assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) + assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) + assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A)) - req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) - reply, err = dns.Exchange(req, addr.String()) - require.NoError(t, err) + req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) + reply, eerr = dns.Exchange(req, addr.String()) + require.NoError(t, eerr) - // The original question is restored. - require.Len(t, reply.Question, 1) + // The original question is restored. + require.Len(t, reply.Question, 1) - assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) + assert.Equal(t, "my.alias.example.org.", reply.Question[0].Name) - require.Len(t, reply.Answer, 2) + require.Len(t, reply.Answer, 2) - assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) - assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) + assert.Equal(t, "example.org.", reply.Answer[0].(*dns.CNAME).Target) + assert.Equal(t, dns.TypeA, reply.Answer[1].Header().Rrtype) + } + + for _, protect := range []bool{true, false} { + val := protect + conf := s.getDNSConfig() + conf.ProtectionEnabled = &val + s.setConfig(conf) + + t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc) + } } func publicKey(priv interface{}) interface{} { @@ -1092,9 +1104,10 @@ func TestPTRResponseFromHosts(t *testing.T) { require.ErrorIs(t, hc.Close(), closeCalled) }) - c := filtering.Config{ + flt := filtering.New(&filtering.Config{ EtcHosts: hc, - } + }, nil) + flt.SetEnabled(true) var snd *aghnet.SubnetDetector snd, err = aghnet.NewSubnetDetector() @@ -1104,7 +1117,7 @@ func TestPTRResponseFromHosts(t *testing.T) { var s *Server s, err = NewServer(DNSCreateParams{ DHCPServer: &testDHCP{}, - DNSFilter: filtering.New(&c, nil), + DNSFilter: flt, SubnetDetector: snd, }) require.NoError(t, err) @@ -1112,32 +1125,41 @@ func TestPTRResponseFromHosts(t *testing.T) { s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} - s.conf.FilteringConfig.ProtectionEnabled = true err = s.Prepare(nil) require.NoError(t, err) err = s.Start() require.NoError(t, err) - t.Cleanup(func() { s.Close() }) - addr := s.dnsProxy.Addr(proxy.ProtoUDP) - req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) + subTestFunc := func(t *testing.T) { + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + req := createTestMessageWithType("1.0.0.127.in-addr.arpa.", dns.TypePTR) - resp, err := dns.Exchange(req, addr.String()) - require.NoError(t, err) + resp, eerr := dns.Exchange(req, addr.String()) + require.NoError(t, eerr) - require.Lenf(t, resp.Answer, 1, "%#v", resp) + require.Len(t, resp.Answer, 1) - assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) - assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) + assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype) + assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name) - ptr, ok := resp.Answer[0].(*dns.PTR) - require.True(t, ok) - assert.Equal(t, "host.", ptr.Ptr) + ptr, ok := resp.Answer[0].(*dns.PTR) + require.True(t, ok) + assert.Equal(t, "host.", ptr.Ptr) + } + + for _, protect := range []bool{true, false} { + val := protect + conf := s.getDNSConfig() + conf.ProtectionEnabled = &val + s.setConfig(conf) + + t.Run(fmt.Sprintf("protection_is_%t", val), subTestFunc) + } } func TestNewServer(t *testing.T) { diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 5edca948..7300b43c 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -52,6 +52,7 @@ func (s *Server) beforeRequestHandler( // the client's IP address and ID, if any, from ctx. func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.Settings { setts := s.dnsFilter.GetConfig() + setts.ProtectionEnabled = ctx.protectionEnabled if s.conf.FilterHandler != nil { ip, _ := netutil.IPAndPortFromAddr(ctx.proxyCtx.Addr) s.conf.FilterHandler(ip, ctx.clientID, &setts) @@ -65,32 +66,31 @@ func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *filtering.S func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { d := ctx.proxyCtx req := d.Req - host := strings.TrimSuffix(req.Question[0].Name, ".") - res, err := s.dnsFilter.CheckHost(host, req.Question[0].Qtype, ctx.setts) - if err != nil { - // Return immediately if there's an error - return nil, fmt.Errorf("filtering failed to check host %q: %w", host, err) - } else if res.IsFiltered { - log.Tracef("Host %s is filtered, reason - %q, matched rule: %q", host, res.Reason, res.Rules[0].Text) + q := req.Question[0] + host := strings.TrimSuffix(q.Name, ".") + res, err := s.dnsFilter.CheckHost(host, q.Qtype, ctx.setts) + switch { + case err != nil: + return nil, fmt.Errorf("failed to check host %q: %w", host, err) + case res.IsFiltered: + log.Tracef("host %q is filtered, reason %q, rule: %q", host, res.Reason, res.Rules[0].Text) d.Res = s.genDNSFilterMessage(d, &res) - } else if res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) && + case res.Reason.In(filtering.Rewritten, filtering.RewrittenRule) && res.CanonName != "" && - len(res.IPList) == 0 { - // Resolve the new canonical name, not the original host - // name. The original question is readded in - // processFilteringAfterResponse. - ctx.origQuestion = req.Question[0] + len(res.IPList) == 0: + // Resolve the new canonical name, not the original host name. The + // original question is readded in processFilteringAfterResponse. + ctx.origQuestion = q req.Question[0].Name = dns.Fqdn(res.CanonName) - } else if res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0 { + case res.Reason == filtering.RewrittenAutoHosts && len(res.ReverseHosts) != 0: resp := s.makeResponse(req) + hdr := dns.RR_Header{ + Name: q.Name, + Rrtype: dns.TypePTR, + Ttl: s.conf.BlockedResponseTTL, + Class: dns.ClassINET, + } for _, h := range res.ReverseHosts { - hdr := dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: dns.TypePTR, - Ttl: s.conf.BlockedResponseTTL, - Class: dns.ClassINET, - } - ptr := &dns.PTR{ Hdr: hdr, Ptr: h, @@ -100,7 +100,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { } d.Res = resp - } else if res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts) { + case res.Reason.In(filtering.Rewritten, filtering.RewrittenAutoHosts): resp := s.makeResponse(req) name := host @@ -110,11 +110,12 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { } for _, ip := range res.IPList { - if req.Question[0].Qtype == dns.TypeA { + switch q.Qtype { + case dns.TypeA: a := s.genAnswerA(req, ip.To4()) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a) - } else if req.Question[0].Qtype == dns.TypeAAAA { + case dns.TypeAAAA: a := s.genAnswerAAAA(req, ip) a.Hdr.Name = dns.Fqdn(name) resp.Answer = append(resp.Answer, a) @@ -122,9 +123,8 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*filtering.Result, error) { } d.Res = resp - } else if res.Reason == filtering.RewrittenRule { - err = s.filterDNSRewrite(req, res, d) - if err != nil { + case res.Reason == filtering.RewrittenRule: + if err = s.filterDNSRewrite(req, res, d); err != nil { return nil, err } } @@ -179,6 +179,7 @@ func (s *Server) filterDNSResponse(ctx *dnsContext) (*filtering.Result, error) { continue } + host = strings.TrimSuffix(host, ".") res, err := s.checkHostRules(host, d.Req.Question[0].Qtype, ctx.setts) if err != nil { return nil, err diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index f7a8ebe4..87d0da15 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -38,6 +38,7 @@ type Settings struct { ServicesRules []ServiceEntry + ProtectionEnabled bool FilteringEnabled bool SafeSearchEnabled bool SafeBrowsingEnabled bool @@ -221,12 +222,13 @@ func (r Reason) String() string { } // In returns true if reasons include r. -func (r Reason) In(reasons ...Reason) bool { +func (r Reason) In(reasons ...Reason) (ok bool) { for _, reason := range reasons { if r == reason { return true } } + return false } @@ -245,7 +247,7 @@ func (d *DNSFilter) GetConfig() (s Settings) { defer d.confLock.RUnlock() return Settings{ - FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) == 1, + FilteringEnabled: atomic.LoadUint32(&d.Config.enabled) != 0, SafeSearchEnabled: d.Config.SafeSearchEnabled, SafeBrowsingEnabled: d.Config.SafeBrowsingEnabled, ParentalEnabled: d.Config.ParentalEnabled, @@ -421,14 +423,16 @@ func (d *DNSFilter) CheckHost( // Sometimes clients try to resolve ".", which is a request to get root // servers. if host == "" { - return Result{Reason: NotFilteredNotFound}, nil + return Result{}, nil } host = strings.ToLower(host) - res = d.processRewrites(host, qtype) - if res.Reason == Rewritten { - return res, nil + if setts.FilteringEnabled { + res = d.processRewrites(host, qtype) + if res.Reason == Rewritten { + return res, nil + } } for _, hc := range d.hostCheckers { @@ -448,7 +452,7 @@ func (d *DNSFilter) CheckHost( // matchSysHosts tries to match the host against the operating system's hosts // database. func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (res Result, err error) { - if d.EtcHosts == nil { + if !setts.FilteringEnabled || d.EtcHosts == nil { return Result{}, nil } @@ -468,10 +472,8 @@ func (d *DNSFilter) matchSysHosts(host string, qtype uint16, setts *Settings) (r var ips []net.IP var revHosts []string - for _, nr := range dnsr { - dr := nr.DNSRewrite - if dr == nil { + if nr.DNSRewrite == nil { continue } @@ -553,6 +555,10 @@ func matchBlockedServicesRules( _ uint16, setts *Settings, ) (res Result, err error) { + if !setts.ProtectionEnabled { + return Result{}, nil + } + svcs := setts.ServicesRules if len(svcs) == 0 { return Result{}, nil @@ -784,7 +790,7 @@ func (d *DNSFilter) matchHost( // TODO(e.burkov): Inspect if the above is true. defer d.engineLock.RUnlock() - if d.filteringEngineAllow != nil { + if setts.ProtectionEnabled && d.filteringEngineAllow != nil { dnsres, ok := d.filteringEngineAllow.MatchRequest(ureq) if ok { return d.matchHostProcessAllowList(host, dnsres) @@ -810,6 +816,11 @@ func (d *DNSFilter) matchHost( return Result{}, nil } + if !setts.ProtectionEnabled { + // Don't check non-dnsrewrite filtering results. + return Result{}, nil + } + res = d.matchHostProcessDNSResult(qtype, dnsres) for _, r := range res.Rules { log.Debug( diff --git a/internal/filtering/filtering_test.go b/internal/filtering/filtering_test.go index 746e9ed0..b389b386 100644 --- a/internal/filtering/filtering_test.go +++ b/internal/filtering/filtering_test.go @@ -21,7 +21,9 @@ func TestMain(m *testing.M) { aghtest.DiscardLogOutput(m) } -var setts Settings +var setts = Settings{ + ProtectionEnabled: true, +} // Helpers. @@ -39,9 +41,9 @@ func purgeCaches() { func newForTest(c *Config, filters []Filter) *DNSFilter { setts = Settings{ - FilteringEnabled: true, + ProtectionEnabled: true, + FilteringEnabled: true, } - setts.FilteringEnabled = true if c != nil { c.SafeBrowsingCacheSize = 10000 c.ParentalCacheSize = 10000 @@ -797,7 +799,11 @@ func TestClientSettings(t *testing.T) { makeTester := func(tc testCase, before bool) func(t *testing.T) { return func(t *testing.T) { - r, _ := d.CheckHost(tc.host, dns.TypeA, &setts) + t.Helper() + + r, err := d.CheckHost(tc.host, dns.TypeA, &setts) + require.NoError(t, err) + if before { assert.True(t, r.IsFiltered) assert.Equal(t, tc.wantReason, r.Reason) @@ -808,7 +814,7 @@ func TestClientSettings(t *testing.T) { } // Check behaviour without any per-client settings, then apply per-client - // settings and check behaviour once again. + // settings and check behavior once again. for _, tc := range testCases { t.Run(tc.name, makeTester(tc, tc.before)) } diff --git a/internal/filtering/safebrowsing.go b/internal/filtering/safebrowsing.go index d535a39d..ec626315 100644 --- a/internal/filtering/safebrowsing.go +++ b/internal/filtering/safebrowsing.go @@ -306,7 +306,7 @@ func (d *DNSFilter) checkSafeBrowsing( _ uint16, setts *Settings, ) (res Result, err error) { - if !setts.SafeBrowsingEnabled { + if !setts.ProtectionEnabled || !setts.SafeBrowsingEnabled { return Result{}, nil } @@ -339,7 +339,7 @@ func (d *DNSFilter) checkParental( _ uint16, setts *Settings, ) (res Result, err error) { - if !setts.ParentalEnabled { + if !setts.ProtectionEnabled || !setts.ParentalEnabled { return Result{}, nil } diff --git a/internal/filtering/safebrowsing_test.go b/internal/filtering/safebrowsing_test.go index d513c0b2..c88576f1 100644 --- a/internal/filtering/safebrowsing_test.go +++ b/internal/filtering/safebrowsing_test.go @@ -117,6 +117,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) { d.SetParentalUpstream(ups) setts := &Settings{ + ProtectionEnabled: true, SafeBrowsingEnabled: true, ParentalEnabled: true, } @@ -135,35 +136,36 @@ func TestSBPC(t *testing.T) { const hostname = "example.org" setts := &Settings{ + ProtectionEnabled: true, SafeBrowsingEnabled: true, ParentalEnabled: true, } testCases := []struct { + testCache cache.Cache + testFunc func(host string, _ uint16, _ *Settings) (res Result, err error) name string block bool - testFunc func(host string, _ uint16, _ *Settings) (res Result, err error) - testCache cache.Cache }{{ + testCache: gctx.safebrowsingCache, + testFunc: d.checkSafeBrowsing, name: "sb_no_block", block: false, - testFunc: d.checkSafeBrowsing, - testCache: gctx.safebrowsingCache, }, { + testCache: gctx.safebrowsingCache, + testFunc: d.checkSafeBrowsing, name: "sb_block", block: true, - testFunc: d.checkSafeBrowsing, - testCache: gctx.safebrowsingCache, }, { + testCache: gctx.parentalCache, + testFunc: d.checkParental, name: "pc_no_block", block: false, - testFunc: d.checkParental, - testCache: gctx.parentalCache, }, { + testCache: gctx.parentalCache, + testFunc: d.checkParental, name: "pc_block", block: true, - testFunc: d.checkParental, - testCache: gctx.parentalCache, }} for _, tc := range testCases { diff --git a/internal/filtering/safesearch.go b/internal/filtering/safesearch.go index db9e9ee9..ff89b950 100644 --- a/internal/filtering/safesearch.go +++ b/internal/filtering/safesearch.go @@ -74,7 +74,7 @@ func (d *DNSFilter) checkSafeSearch( _ uint16, setts *Settings, ) (res Result, err error) { - if !setts.SafeSearchEnabled { + if !setts.ProtectionEnabled || !setts.SafeSearchEnabled { return Result{}, nil } diff --git a/internal/home/controlfiltering.go b/internal/home/controlfiltering.go index 550b4b87..ee19fd16 100644 --- a/internal/home/controlfiltering.go +++ b/internal/home/controlfiltering.go @@ -404,6 +404,7 @@ func (f *Filtering) handleCheckHost(w http.ResponseWriter, r *http.Request) { setts := Context.dnsFilter.GetConfig() setts.FilteringEnabled = true + setts.ProtectionEnabled = true Context.dnsFilter.ApplyBlockedServices(&setts, nil, true) result, err := Context.dnsFilter.CheckHost(host, dns.TypeA, &setts) if err != nil {