diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a3456a3..367a9940 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,11 +10,13 @@ and this project adheres to ## [Unreleased] ### Added +- Client ID support for DNS-over-HTTPS, DNS-over-QUIC, and DNS-over-TLS + ([#1383]). - `$dnsrewrite` modifier for filters ([#2102]). - The host checking API and the query logs API can now return multiple matched rules ([#2102]). @@ -27,6 +29,7 @@ and this project adheres to - HTTP API request body size limit ([#2305]). [#1361]: https://github.com/AdguardTeam/AdGuardHome/issues/1361 +[#1383]: https://github.com/AdguardTeam/AdGuardHome/issues/1383 [#2102]: https://github.com/AdguardTeam/AdGuardHome/issues/2102 [#2302]: https://github.com/AdguardTeam/AdGuardHome/issues/2302 [#2304]: https://github.com/AdguardTeam/AdGuardHome/issues/2304 @@ -35,6 +38,7 @@ and this project adheres to ### Changed +- `workDir` now supports symlinks. - Stopped mounting together the directories `/opt/adguardhome/conf` and `/opt/adguardhome/work` in our Docker images ([#2589]). - When `dns.bogus_nxdomain` option is used, the server will now transform diff --git a/HACKING.md b/HACKING.md index dc291bb7..0a4d102e 100644 --- a/HACKING.md +++ b/HACKING.md @@ -62,9 +62,12 @@ The rules are mostly sorted in the alphabetical order. * Don't use underscores in file and package names, unless they're build tags or for tests. This is to prevent accidental build errors with weird tags. - * Don't write code with more than four (**4**) levels of indentation. Just - like [Linus said], plus an additional level for an occasional error check or - struct initialization. + * Don't write non-test code with more than four (**4**) levels of indentation. + Just like [Linus said], plus an additional level for an occasional error + check or struct initialization. + + The exception proving the rule is the table-driven test code, where an + additional level of indentation is allowed. * Eschew external dependencies, including transitive, unless absolutely necessary. diff --git a/Makefile b/Makefile index d116d4e4..3d8625e6 100644 --- a/Makefile +++ b/Makefile @@ -80,7 +80,10 @@ go-lint: ; $(ENV) "$(SHELL)" ./scripts/make/go-lint.sh go-test: ; $(ENV) "$(SHELL)" ./scripts/make/go-test.sh go-tools: ; $(ENV) "$(SHELL)" ./scripts/make/go-tools.sh +go-check: go-tools go-lint go-test + openapi-lint: ; cd ./openapi/ && $(YARN) test +openapi-show: ; cd ./openapi/ && $(YARN) start # TODO(a.garipov): Remove the legacy targets once the build # infrastructure stops using them. diff --git a/README.md b/README.md index e694452f..6b60d888 100644 --- a/README.md +++ b/README.md @@ -87,12 +87,21 @@ If you're running **Linux**, there's a secure and easy way to install AdGuard Ho ### Guides -* [FAQ](https://github.com/AdguardTeam/AdGuardHome/wiki/FAQ) -* [Configuration](https://github.com/AdguardTeam/AdGuardHome/wiki/Configuration) -* [AdGuard Home as a DNS-over-HTTPS or DNS-over-TLS server](https://github.com/AdguardTeam/AdGuardHome/wiki/Encryption) -* [How to install and run AdGuard Home on Raspberry Pi](https://github.com/AdguardTeam/AdGuardHome/wiki/Raspberry-Pi) -* [How to install and run AdGuard Home on a Virtual Private Server](https://github.com/AdguardTeam/AdGuardHome/wiki/VPS) -* [How to write your own hosts blocklists properly](https://github.com/AdguardTeam/AdGuardHome/wiki/Hosts-Blocklists) +* [Getting Started](https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started) + * [FAQ](https://github.com/AdguardTeam/AdGuardHome/wiki/FAQ) + * [How to Write Hosts Blocklists](https://github.com/AdguardTeam/AdGuardHome/wiki/Hosts-Blocklists) + * [Comparing AdGuard Home to Other Solutions](https://github.com/AdguardTeam/AdGuardHome/wiki/Comparison) +* Configuring AdGuard + * [Configuration](https://github.com/AdguardTeam/AdGuardHome/wiki/Configuration) + * [Configuring AdGuard Home Clients](https://github.com/AdguardTeam/AdGuardHome/wiki/Clients) + * [AdGuard Home as a DoH, DoT, or DoQ Server](https://github.com/AdguardTeam/AdGuardHome/wiki/Encryption) + * [AdGuard Home as a DNSCrypt Server](https://github.com/AdguardTeam/AdGuardHome/wiki/DNSCrypt) + * [AdGuard Home as a DHCP Server](https://github.com/AdguardTeam/AdGuardHome/wiki/DHCP) +* Installing AdGuard Home + * [Docker](https://github.com/AdguardTeam/AdGuardHome/wiki/Docker) + * [How to Install and Run AdGuard Home on a Raspberry Pi](https://github.com/AdguardTeam/AdGuardHome/wiki/Raspberry-Pi) + * [How to Install and Run AdGuard Home on a Virtual Private Server](https://github.com/AdguardTeam/AdGuardHome/wiki/VPS) +* [Verifying Releases](https://github.com/AdguardTeam/AdGuardHome/wiki/Verify-Releases) ### API diff --git a/client/package-lock.json b/client/package-lock.json index f1172e00..56bf51c4 100644 --- a/client/package-lock.json +++ b/client/package-lock.json @@ -3066,12 +3066,6 @@ "pkg-up": "^2.0.0" } }, - "caniuse-lite": { - "version": "1.0.30001062", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001062.tgz", - "integrity": "sha512-ei9ZqeOnN7edDrb24QfJ0OZicpEbsWxv7WusOiQGz/f2SfvBgHHbOEwBJ8HKGVSyx8Z6ndPjxzR6m0NQq+0bfw==", - "dev": true - }, "postcss": { "version": "7.0.30", "resolved": "https://registry.npmjs.org/postcss/-/postcss-7.0.30.tgz", @@ -3928,9 +3922,9 @@ } }, "caniuse-lite": { - "version": "1.0.30001059", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001059.tgz", - "integrity": "sha512-oOrc+jPJWooKIA0IrNZ5sYlsXc7NP7KLhNWrSGEJhnfSzDvDJ0zd3i6HXsslExY9bbu+x0FQ5C61LcqmPt7bOQ==", + "version": "1.0.30001165", + "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001165.tgz", + "integrity": "sha512-8cEsSMwXfx7lWSUMA2s08z9dIgsnR5NAqjXP23stdsU3AUWkCr/rr4s4OFtHXn5XXr6+7kam3QFVoYyXNPdJPA==", "dev": true }, "capture-exit": { diff --git a/client/src/__locales/en.json b/client/src/__locales/en.json index afc2f105..a13f6cb4 100644 --- a/client/src/__locales/en.json +++ b/client/src/__locales/en.json @@ -32,6 +32,7 @@ "form_error_ip_format": "Invalid IP format", "form_error_mac_format": "Invalid MAC format", "form_error_client_id_format": "Invalid client ID format", + "form_error_server_name": "Invalid server name", "form_error_positive": "Must be greater than 0", "form_error_negative": "Must be equal to 0 or greater", "range_end_error": "Must be greater than range start", @@ -250,8 +251,12 @@ "dns_over_https": "DNS-over-HTTPS", "dns_over_tls": "DNS-over-TLS", "dns_over_quic": "DNS-over-QUIC", + "client_id": "Client ID", + "client_id_placeholder": "Enter client ID", + "client_id_desc": "Different clients can be identified by a special client ID. Here you can learn more about how to identify clients.", "download_mobileconfig_doh": "Download .mobileconfig for DNS-over-HTTPS", "download_mobileconfig_dot": "Download .mobileconfig for DNS-over-TLS", + "download_mobileconfig": "Download configuration file", "plain_dns": "Plain DNS", "form_enter_rate_limit": "Enter rate limit", "rate_limit": "Rate limit", @@ -331,7 +336,7 @@ "encryption_config_saved": "Encryption config saved", "encryption_server": "Server name", "encryption_server_enter": "Enter your domain name", - "encryption_server_desc": "In order to use HTTPS, you need to enter the server name that matches your SSL certificate.", + "encryption_server_desc": "In order to use HTTPS, you need to enter the server name that matches your SSL certificate or wildcard certificate. If the field is not set, it will accept TLS connections for any domain.", "encryption_redirect": "Redirect to HTTPS automatically", "encryption_redirect_desc": "If checked, AdGuard Home will automatically redirect you from HTTP to HTTPS addresses.", "encryption_https": "HTTPS port", @@ -387,7 +392,7 @@ "client_edit": "Edit Client", "client_identifier": "Identifier", "ip_address": "IP address", - "client_identifier_desc": "Clients can be identified by the IP address, CIDR, MAC address. Please note that using MAC as identifier is possible only if AdGuard Home is also a <0>DHCP server", + "client_identifier_desc": "Clients can be identified by the IP address, CIDR, MAC address or a special client ID (can be used for DoT/DoH/DoQ). <0>Here you can learn more about how to identify clients.", "form_enter_ip": "Enter IP", "form_enter_mac": "Enter MAC", "form_enter_id": "Enter identifier", @@ -431,6 +436,7 @@ "setup_dns_privacy_other_3": "<0>dnscrypt-proxy supports <1>DNS-over-HTTPS.", "setup_dns_privacy_other_4": "<0>Mozilla Firefox supports <1>DNS-over-HTTPS.", "setup_dns_privacy_other_5": "You will find more implementations <0>here and <1>here.", + "setup_dns_privacy_ioc_mac": "iOS and macOS configuration", "setup_dns_notice": "In order to use <1>DNS-over-HTTPS or <1>DNS-over-TLS, you need to <0>configure Encryption in AdGuard Home settings.", "rewrite_added": "DNS rewrite for \"{{key}}\" successfully added", "rewrite_deleted": "DNS rewrite for \"{{key}}\" successfully deleted", diff --git a/client/src/actions/queryLogs.js b/client/src/actions/queryLogs.js index 1c653fb3..de8af4a3 100644 --- a/client/src/actions/queryLogs.js +++ b/client/src/actions/queryLogs.js @@ -8,7 +8,7 @@ import { import { addErrorToast, addSuccessToast } from './toasts'; const enrichWithClientInfo = async (logs) => { - const clientsParams = getParamsForClientsSearch(logs, 'client'); + const clientsParams = getParamsForClientsSearch(logs, 'client', 'client_id'); if (Object.keys(clientsParams).length > 0) { const clients = await apiClient.findClients(clientsParams); diff --git a/client/src/components/App/index.css b/client/src/components/App/index.css index 0832e790..2b1ee76d 100644 --- a/client/src/components/App/index.css +++ b/client/src/components/App/index.css @@ -81,3 +81,7 @@ body { .ReactModal__Body--open { overflow: hidden; } + +a.btn-success.disabled { + color: #fff; +} diff --git a/client/src/components/Dashboard/Clients.js b/client/src/components/Dashboard/Clients.js index ccac7baf..46edc46f 100644 --- a/client/src/components/Dashboard/Clients.js +++ b/client/src/components/Dashboard/Clients.js @@ -9,7 +9,7 @@ import Card from '../ui/Card'; import Cell from '../ui/Cell'; import { getPercent, sortIp } from '../../helpers/helpers'; -import { BLOCK_ACTIONS, STATUS_COLORS } from '../../helpers/constants'; +import { BLOCK_ACTIONS, R_CLIENT_ID, STATUS_COLORS } from '../../helpers/constants'; import { toggleClientBlock } from '../../actions/access'; import { renderFormattedClientCell } from '../../helpers/renderFormattedClientCell'; import { getStats } from '../../actions/stats'; @@ -35,6 +35,10 @@ const CountCell = (row) => { }; const renderBlockingButton = (ip, disallowed, disallowed_rule) => { + if (R_CLIENT_ID.test(ip)) { + return null; + } + const dispatch = useDispatch(); const { t } = useTranslation(); const processingSet = useSelector((state) => state.access.processingSet); @@ -59,17 +63,19 @@ const renderBlockingButton = (ip, disallowed, disallowed_rule) => { const text = disallowed ? BLOCK_ACTIONS.UNBLOCK : BLOCK_ACTIONS.BLOCK; const isNotInAllowedList = disallowed && disallowed_rule === ''; - return
- -
; + > + {text} + + + ); }; const ClientCell = (row) => { @@ -90,13 +96,14 @@ const Clients = ({ const { t } = useTranslation(); const topClients = useSelector((state) => state.stats.topClients, shallowEqual); - return - + ({ @@ -107,7 +114,7 @@ const Clients = ({ }))} columns={[ { - Header: 'IP', + Header: client_table_header, accessor: 'ip', sortMethod: sortIp, Cell: ClientCell, @@ -134,8 +141,9 @@ const Clients = ({ return disallowed ? { className: 'logs__row--red' } : {}; }} - /> - ; + /> + + ); }; Clients.propTypes = { diff --git a/client/src/components/Logs/Cells/ClientCell.js b/client/src/components/Logs/Cells/ClientCell.js index 56a3440c..863cd726 100644 --- a/client/src/components/Logs/Cells/ClientCell.js +++ b/client/src/components/Logs/Cells/ClientCell.js @@ -16,6 +16,7 @@ import { updateLogs } from '../../../actions/queryLogs'; const ClientCell = ({ client, + client_id, domain, info, info: { @@ -33,12 +34,14 @@ const ClientCell = ({ const autoClient = autoClients.find((autoClient) => autoClient.name === client); const source = autoClient?.source; const whoisAvailable = whois_info && Object.keys(whois_info).length > 0; + const clientName = name || client_id; + const clientInfo = { ...info, name: clientName }; const id = nanoid(); const data = { address: client, - name, + name: clientName, country: whois_info?.country, city: whois_info?.city, network: whois_info?.orgname, @@ -99,13 +102,20 @@ const ClientCell = ({ if (options.length === 0) { return null; } - return <>{options.map(({ name, onClick, disabled }) => )}; + return ( + <> + {options.map(({ name, onClick, disabled }) => ( + + ))} + + ); }; const content = getOptions(BUTTON_OPTIONS); @@ -125,45 +135,70 @@ const ClientCell = ({ 'button-action__container--detailed': isDetailed, }); - return
- - {content && } -
; + > + {t(buttonType)} + + {content && ( + + )} + + ); }; - return
- -
-
- {renderFormattedClientCell(client, info, isDetailed, true)} + return ( +
+ +
+
+ {renderFormattedClientCell(client, clientInfo, isDetailed, true)} +
+ {isDetailed && clientName && !whoisAvailable && ( +
+ {clientName} +
+ )}
- {isDetailed && name && !whoisAvailable - &&
{name}
} + {renderBlockingButton(isFiltered, domain)}
- {renderBlockingButton(isFiltered, domain)} -
; + ); }; ClientCell.propTypes = { client: propTypes.string.isRequired, + client_id: propTypes.string, domain: propTypes.string.isRequired, info: propTypes.oneOfType([ propTypes.string, diff --git a/client/src/components/Logs/Cells/index.js b/client/src/components/Logs/Cells/index.js index bc1fc2ed..5ec64f1f 100644 --- a/client/src/components/Logs/Cells/index.js +++ b/client/src/components/Logs/Cells/index.js @@ -70,6 +70,7 @@ const Row = memo(({ upstream, type, client_proto, + client_id, rules, originalResponse, status, @@ -176,7 +177,7 @@ const Row = memo(({ response_code: status, client_details: 'title', ip_address: client, - name: info?.name, + name: info?.name || client_id, country, city, network, @@ -233,6 +234,7 @@ Row.propTypes = { upstream: propTypes.string.isRequired, type: propTypes.string.isRequired, client_proto: propTypes.string.isRequired, + client_id: propTypes.string, rules: propTypes.arrayOf(propTypes.shape({ text: propTypes.string.isRequired, filter_list_id: propTypes.number.isRequired, diff --git a/client/src/components/Settings/Clients/Form.js b/client/src/components/Settings/Clients/Form.js index 20f1f828..631272d8 100644 --- a/client/src/components/Settings/Clients/Form.js +++ b/client/src/components/Settings/Clients/Form.js @@ -282,7 +282,7 @@ let Form = (props) => {
+ link , ]} diff --git a/client/src/components/Settings/Encryption/CertificateStatus.js b/client/src/components/Settings/Encryption/CertificateStatus.js index a93389f5..911d1d36 100644 --- a/client/src/components/Settings/Encryption/CertificateStatus.js +++ b/client/src/components/Settings/Encryption/CertificateStatus.js @@ -50,7 +50,7 @@ const CertificateStatus = ({ {dnsNames && (
  • encryption_hostnames:  - {dnsNames} + {dnsNames.join(', ')}
  • )} @@ -65,7 +65,7 @@ CertificateStatus.propTypes = { subject: PropTypes.string, issuer: PropTypes.string, notAfter: PropTypes.string, - dnsNames: PropTypes.string, + dnsNames: PropTypes.arrayOf(PropTypes.string), }; export default withTranslation()(CertificateStatus); diff --git a/client/src/components/Settings/Encryption/Form.js b/client/src/components/Settings/Encryption/Form.js index 1619ea3e..58cd181e 100644 --- a/client/src/components/Settings/Encryption/Form.js +++ b/client/src/components/Settings/Encryption/Form.js @@ -12,7 +12,7 @@ import { toNumber, } from '../../../helpers/form'; import { - validateIsSafePort, validatePort, validatePortQuic, validatePortTLS, + validateServerName, validateIsSafePort, validatePort, validatePortQuic, validatePortTLS, } from '../../../helpers/validators'; import i18n from '../../../i18n'; import KeyStatus from './KeyStatus'; @@ -127,6 +127,7 @@ let Form = (props) => { placeholder={t('encryption_server_enter')} onChange={handleChange} disabled={!isEnabled} + validate={validateServerName} />
    encryption_server_desc @@ -413,7 +414,7 @@ Form.propTypes = { valid_key: PropTypes.bool, valid_cert: PropTypes.bool, valid_pair: PropTypes.bool, - dns_names: PropTypes.string, + dns_names: PropTypes.arrayOf(PropTypes.string), key_type: PropTypes.string, issuer: PropTypes.string, subject: PropTypes.string, diff --git a/client/src/components/ui/Guide.js b/client/src/components/ui/Guide/Guide.js similarity index 71% rename from client/src/components/ui/Guide.js rename to client/src/components/ui/Guide/Guide.js index a06af83b..cf1af0b0 100644 --- a/client/src/components/ui/Guide.js +++ b/client/src/components/ui/Guide/Guide.js @@ -3,27 +3,12 @@ import PropTypes from 'prop-types'; import { Trans, useTranslation } from 'react-i18next'; import i18next from 'i18next'; import { useSelector } from 'react-redux'; -import Tabs from './Tabs'; -import Icons from './Icons'; -import { getPathWithQueryString } from '../../helpers/helpers'; -const MOBILE_CONFIG_LINKS = { - DOT: '/apple/dot.mobileconfig', - DOH: '/apple/doh.mobileconfig', -}; -const renderMobileconfigInfo = ({ label, components, server_name }) =>
  • - {label} - -
  • ; +import { MOBILE_CONFIG_LINKS } from '../../../helpers/constants'; + +import Tabs from '../Tabs'; +import Icons from '../Icons'; +import MobileConfigForm from './MobileConfigForm'; const renderLi = ({ label, components }) =>
  • { @@ -41,49 +26,8 @@ const renderLi = ({ label, components }) =>
  • ; -const getDnsPrivacyList = (server_name) => { - const iosList = [ - { - label: 'setup_dns_privacy_ios_2', - components: [ - { - key: 0, - href: 'https://adguard.com/adguard-ios/overview.html', - }, - text, - ], - }, - { - label: 'setup_dns_privacy_ios_1', - components: [ - { - key: 0, - href: 'https://itunes.apple.com/app/id1452162351', - }, - text, - { - key: 2, - href: 'https://dnscrypt.info/stamps', - }, - - ], - }]; - /* Insert second element if can generate .mobileconfig links */ - if (server_name) { - iosList.splice(1, 0, { - label: 'setup_dns_privacy_4', - components: { - highlight: , - }, - renderComponent: ({ label, components }) => renderMobileconfigInfo({ - label, - components, - server_name, - }), - }); - } - - return [{ +const getDnsPrivacyList = () => [ + { title: 'Android', list: [ { @@ -113,7 +57,32 @@ const getDnsPrivacyList = (server_name) => { }, { title: 'iOS', - list: iosList, + list: [ + { + label: 'setup_dns_privacy_ios_2', + components: [ + { + key: 0, + href: 'https://adguard.com/adguard-ios/overview.html', + }, + text, + ], + }, + { + label: 'setup_dns_privacy_ios_1', + components: [ + { + key: 0, + href: 'https://itunes.apple.com/app/id1452162351', + }, + text, + { + key: 2, + href: 'https://dnscrypt.info/stamps', + }, + ], + }, + ], }, { title: 'setup_dns_privacy_other_title', @@ -166,20 +135,20 @@ const getDnsPrivacyList = (server_name) => { }, ], }, - ]; -}; +]; -const renderDnsPrivacyList = ({ title, list }) =>
    - {title} -
      {list.map( - ({ - label, - components, - renderComponent = renderLi, - }) => renderComponent({ label, components }), - )} -
    -
    ; +const renderDnsPrivacyList = ({ title, list }) => ( +
    + + {title} + +
      + {list.map(({ label, components, renderComponent = renderLi }) => ( + renderComponent({ label, components }) + ))} +
    +
    +); const getTabs = ({ tlsAddress, @@ -267,8 +236,8 @@ const getTabs = ({
    )} - {showDnsPrivacyNotice - ?
    + {showDnsPrivacyNotice ? ( +
    - : <> + ) : ( + <>
    text

    ]}> setup_dns_privacy_3
    - {getDnsPrivacyList(server_name).map(renderDnsPrivacyList)} - } + {getDnsPrivacyList().map(renderDnsPrivacyList)} +
    + + + setup_dns_privacy_ioc_mac + + +
    +
    + }}> + setup_dns_privacy_4 + +
    + + + )}
    ; }, }, }); -const renderContent = ({ title, list, getTitle }) =>
    -
    {i18next.t(title)}
    -
    - {getTitle?.()} - {list - &&
      {list.map((item) =>
    1. - {item} -
    2. )} -
    } +const renderContent = ({ title, list, getTitle }) => ( +
    +
    + {i18next.t(title)} +
    +
    + {getTitle?.()} + {list && ( +
      + {list.map((item) => ( +
    1. + {item} +
    2. + ))} +
    + )} +
    -
    ; +); const Guide = ({ dnsAddresses }) => { const { t } = useTranslation(); - const server_name = useSelector((state) => state.encryption.server_name); + const server_name = useSelector((state) => state.encryption?.server_name); const tlsAddress = dnsAddresses?.filter((item) => item.includes('tls://')) ?? ''; const httpsAddress = dnsAddresses?.filter((item) => item.includes('https://')) ?? ''; const showDnsPrivacyNotice = httpsAddress.length < 1 && tlsAddress.length < 1; @@ -332,9 +330,14 @@ const Guide = ({ dnsAddresses }) => { return (
    + + {activeTab} + - {activeTab}
    ); }; @@ -364,6 +367,4 @@ renderLi.propTypes = { components: PropTypes.string, }; -renderMobileconfigInfo.propTypes = renderLi.propTypes; - export default Guide; diff --git a/client/src/components/ui/Guide/MobileConfigForm.js b/client/src/components/ui/Guide/MobileConfigForm.js new file mode 100644 index 00000000..e1726d99 --- /dev/null +++ b/client/src/components/ui/Guide/MobileConfigForm.js @@ -0,0 +1,131 @@ +import React from 'react'; +import PropTypes from 'prop-types'; +import { Trans } from 'react-i18next'; +import { useSelector } from 'react-redux'; +import { Field, reduxForm } from 'redux-form'; +import i18next from 'i18next'; +import cn from 'classnames'; + +import { getPathWithQueryString } from '../../../helpers/helpers'; +import { FORM_NAME, MOBILE_CONFIG_LINKS } from '../../../helpers/constants'; +import { + renderInputField, + renderSelectField, +} from '../../../helpers/form'; +import { + validateClientId, + validateServerName, +} from '../../../helpers/validators'; + +const getDownloadLink = (host, clientId, protocol, invalid) => { + if (!host || invalid) { + return ( + + ); + } + + const linkParams = { host }; + + if (clientId) { + linkParams.client_id = clientId; + } + + return ( + + download_mobileconfig + + ); +}; + +const MobileConfigForm = ({ invalid }) => { + const formValues = useSelector((state) => state.form[FORM_NAME.MOBILE_CONFIG]?.values); + + if (!formValues) { + return null; + } + + const { host, clientId, protocol } = formValues; + + const githubLink = ( + + text + + ); + + return ( +
    e.preventDefault()}> +
    +
    + + +
    +
    + +
    + + client_id_desc + +
    + +
    +
    + + + + + +
    +
    + + {getDownloadLink(host, clientId, protocol, invalid)} +
    + ); +}; + +MobileConfigForm.propTypes = { + invalid: PropTypes.bool.isRequired, +}; + +export default reduxForm({ form: FORM_NAME.MOBILE_CONFIG })(MobileConfigForm); diff --git a/client/src/components/ui/Guide/index.js b/client/src/components/ui/Guide/index.js new file mode 100644 index 00000000..ee660aeb --- /dev/null +++ b/client/src/components/ui/Guide/index.js @@ -0,0 +1 @@ +export { default } from './Guide'; diff --git a/client/src/helpers/constants.js b/client/src/helpers/constants.js index 76be76a6..e05ee9dc 100644 --- a/client/src/helpers/constants.js +++ b/client/src/helpers/constants.js @@ -13,6 +13,8 @@ export const R_MAC = /^((([a-fA-F0-9][a-fA-F0-9]+[-]){5}|([a-fA-F0-9][a-fA-F0-9] export const R_CIDR_IPV6 = /^s*((([0-9A-Fa-f]{1,4}:){7}([0-9A-Fa-f]{1,4}|:))|(([0-9A-Fa-f]{1,4}:){6}(:[0-9A-Fa-f]{1,4}|((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3})|:))|(([0-9A-Fa-f]{1,4}:){5}(((:[0-9A-Fa-f]{1,4}){1,2})|:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3})|:))|(([0-9A-Fa-f]{1,4}:){4}(((:[0-9A-Fa-f]{1,4}){1,3})|((:[0-9A-Fa-f]{1,4})?:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){3}(((:[0-9A-Fa-f]{1,4}){1,4})|((:[0-9A-Fa-f]{1,4}){0,2}:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){2}(((:[0-9A-Fa-f]{1,4}){1,5})|((:[0-9A-Fa-f]{1,4}){0,3}:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){1}(((:[0-9A-Fa-f]{1,4}){1,6})|((:[0-9A-Fa-f]{1,4}){0,4}:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3}))|:))|(:(((:[0-9A-Fa-f]{1,4}){1,7})|((:[0-9A-Fa-f]{1,4}){0,5}:((25[0-5]|2[0-4]d|1dd|[1-9]?d)(.(25[0-5]|2[0-4]d|1dd|[1-9]?d)){3}))|:)))(%.+)?s*(\/(12[0-8]|1[0-1][0-9]|[1-9][0-9]|[0-9]))$/; +export const R_DOMAIN = /^[a-zA-Z0-9][a-zA-Z0-9-]{1,61}[a-zA-Z0-9]\.[a-zA-Z]{2,}$/; + export const R_PATH_LAST_PART = /\/[^/]*$/; // eslint-disable-next-line no-control-regex @@ -21,6 +23,8 @@ export const R_UNIX_ABSOLUTE_PATH = /^(\/[^/\x00]+)+$/; // eslint-disable-next-line no-control-regex export const R_WIN_ABSOLUTE_PATH = /^([a-zA-Z]:)?(\\|\/)(?:[^\\/:*?"<>|\x00]+\\)*[^\\/:*?"<>|\x00]*$/; +export const R_CLIENT_ID = /^[a-z0-9-]{1,64}$/; + export const HTML_PAGES = { INSTALL: '/install.html', LOGIN: '/login.html', @@ -514,6 +518,7 @@ export const FORM_NAME = { INSTALL: 'install', LOGIN: 'login', CACHE: 'cache', + MOBILE_CONFIG: 'mobileConfig', ...DHCP_FORM_NAMES, }; @@ -574,6 +579,7 @@ export const TOAST_TIMEOUTS = { export const ADDRESS_TYPES = { IP: 'IP', CIDR: 'CIDR', + CLIENT_ID: 'CLIENT_ID', UNKNOWN: 'UNKNOWN', }; @@ -585,3 +591,8 @@ export const CACHE_CONFIG_FIELDS = { export const isFirefox = navigator.userAgent.indexOf('Firefox') !== -1; export const COMMENT_LINE_DEFAULT_TOKEN = '#'; + +export const MOBILE_CONFIG_LINKS = { + DOT: '/apple/dot.mobileconfig', + DOH: '/apple/doh.mobileconfig', +}; diff --git a/client/src/helpers/helpers.js b/client/src/helpers/helpers.js index 82f30245..2acf5315 100644 --- a/client/src/helpers/helpers.js +++ b/client/src/helpers/helpers.js @@ -4,7 +4,6 @@ import dateFormat from 'date-fns/format'; import round from 'lodash/round'; import axios from 'axios'; import i18n from 'i18next'; -import uniqBy from 'lodash/uniqBy'; import ipaddr from 'ipaddr.js'; import queryString from 'query-string'; import React from 'react'; @@ -22,6 +21,7 @@ import { DHCP_VALUES_PLACEHOLDERS, FILTERED, FILTERED_STATUS, + R_CLIENT_ID, SERVICES_ID_NAME_MAP, STANDARD_DNS_PORT, STANDARD_HTTPS_PORT, @@ -62,6 +62,7 @@ export const normalizeLogs = (logs) => logs.map((log) => { answer_dnssec, client, client_proto, + client_id, elapsedMs, question, reason, @@ -99,6 +100,7 @@ export const normalizeLogs = (logs) => logs.map((log) => { reason, client, client_proto, + client_id, /* TODO 'filterId' and 'rule' are deprecated, will be removed in 0.106 */ filterId, rule, @@ -414,14 +416,21 @@ export const getPathWithQueryString = (path, params) => { return `${path}?${searchParams.toString()}`; }; -export const getParamsForClientsSearch = (data, param) => { - const uniqueClients = uniqBy(data, param); - return uniqueClients - .reduce((acc, item, idx) => { - const key = `ip${idx}`; - acc[key] = item[param]; - return acc; - }, {}); +export const getParamsForClientsSearch = (data, param, additionalParam) => { + const clients = new Set(); + data.forEach((e) => { + clients.add(e[param]); + if (e[additionalParam]) { + clients.add(e[additionalParam]); + } + }); + const params = {}; + const ids = Array.from(clients.values()); + ids.forEach((id, i) => { + params[`ip${i}`] = id; + }); + + return params; }; /** @@ -534,7 +543,7 @@ export const isIpInCidr = (ip, cidr) => { /** * * @param ipOrCidr - * @returns {'IP' | 'CIDR' | 'UNKNOWN'} + * @returns {'IP' | 'CIDR' | 'CLIENT_ID' | 'UNKNOWN'} * */ export const findAddressType = (address) => { @@ -547,6 +556,9 @@ export const findAddressType = (address) => { if (cidrMaybe && ipaddr.parseCIDR(address)) { return ADDRESS_TYPES.CIDR; } + if (R_CLIENT_ID.test(address)) { + return ADDRESS_TYPES.CLIENT_ID; + } return ADDRESS_TYPES.UNKNOWN; } catch (e) { @@ -567,20 +579,31 @@ export const separateIpsAndCidrs = (ids) => ids.reduce((acc, curr) => { if (addressType === ADDRESS_TYPES.CIDR) { acc.cidrs.push(curr); } + if (addressType === ADDRESS_TYPES.CLIENT_ID) { + acc.clientIds.push(curr); + } return acc; -}, { ips: [], cidrs: [] }); +}, { ips: [], cidrs: [], clientIds: [] }); export const countClientsStatistics = (ids, autoClients) => { - const { ips, cidrs } = separateIpsAndCidrs(ids); + const { ips, cidrs, clientIds } = separateIpsAndCidrs(ids); const ipsCount = ips.reduce((acc, curr) => { const count = autoClients[curr] || 0; return acc + count; }, 0); + const clientIdsCount = clientIds.reduce((acc, curr) => { + const count = autoClients[curr] || 0; + return acc + count; + }, 0); + const cidrsCount = Object.entries(autoClients) .reduce((acc, curr) => { const [id, count] = curr; + if (!ipaddr.isValid(id)) { + return false; + } if (cidrs.some((cidr) => isIpInCidr(id, cidr))) { // eslint-disable-next-line no-param-reassign acc += count; @@ -588,7 +611,7 @@ export const countClientsStatistics = (ids, autoClients) => { return acc; }, 0); - return ipsCount + cidrsCount; + return ipsCount + cidrsCount + clientIdsCount; }; /** diff --git a/client/src/helpers/validators.js b/client/src/helpers/validators.js index 2632df3f..c26bbb8e 100644 --- a/client/src/helpers/validators.js +++ b/client/src/helpers/validators.js @@ -9,6 +9,8 @@ import { R_URL_REQUIRES_PROTOCOL, STANDARD_WEB_PORT, UNSAFE_PORTS, + R_CLIENT_ID, + R_DOMAIN, } from './constants'; import { getLastIpv4Octet, isValidAbsolutePath } from './form'; @@ -71,12 +73,28 @@ export const validateClientId = (value) => { || R_MAC.test(formattedValue) || R_CIDR.test(formattedValue) || R_CIDR_IPV6.test(formattedValue) + || R_CLIENT_ID.test(formattedValue) )) { return 'form_error_client_id_format'; } return undefined; }; +/** + * @param value {string} + * @returns {undefined|string} + */ +export const validateServerName = (value) => { + if (!value) { + return undefined; + } + const formattedValue = value ? value.trim() : value; + if (formattedValue && !R_DOMAIN.test(formattedValue)) { + return 'form_error_server_name'; + } + return undefined; +}; + /** * @param value {string} * @returns {undefined|string} diff --git a/go.mod b/go.mod index f22141c2..a4f86658 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome go 1.14 require ( - github.com/AdguardTeam/dnsproxy v0.33.7 + github.com/AdguardTeam/dnsproxy v0.33.9 github.com/AdguardTeam/golibs v0.4.4 github.com/AdguardTeam/urlfilter v0.14.2 github.com/NYTimes/gziphandler v1.1.1 @@ -17,6 +17,7 @@ require ( github.com/insomniacslk/dhcp v0.0.0-20201112113307-4de412bc85d8 github.com/kardianos/service v1.2.0 github.com/karrick/godirwalk v1.16.1 // indirect + github.com/lucas-clemente/quic-go v0.19.3 github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7 github.com/mdlayher/raw v0.0.0-20191009151244-50f2db8cc065 github.com/miekg/dns v1.1.35 diff --git a/go.sum b/go.sum index 1c488f95..b298768b 100644 --- a/go.sum +++ b/go.sum @@ -18,8 +18,8 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBr dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4= dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU= git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg= -github.com/AdguardTeam/dnsproxy v0.33.7 h1:DXsLTJoBSUejB2ZqVHyMG0/kXD8PzuVPbLCsGKBdaDc= -github.com/AdguardTeam/dnsproxy v0.33.7/go.mod h1:dkI9VWh43XlOzF2XogDm1EmoVl7PANOR4isQV6X9LZs= +github.com/AdguardTeam/dnsproxy v0.33.9 h1:HUwywkhUV/M73E7qWcBAF+SdsNq742s82Lvox4pr/tM= +github.com/AdguardTeam/dnsproxy v0.33.9/go.mod h1:dkI9VWh43XlOzF2XogDm1EmoVl7PANOR4isQV6X9LZs= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o= github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index ee3f6aa0..748b14b2 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -24,11 +24,12 @@ type FilteringConfig struct { // Callbacks for other modules // -- - // Filtering callback function - FilterHandler func(clientAddr net.IP, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` + // FilterHandler is an optional additional filtering callback. + FilterHandler func(clientAddr net.IP, clientID string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` // GetCustomUpstreamByClient - a callback function that returns upstreams configuration // based on the client IP address. Returns nil if there are no custom upstreams for the client + // // TODO(e.burkov): Replace argument type with net.IP. GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"` @@ -109,6 +110,10 @@ type TLSConfig struct { CertificateChainData []byte `yaml:"-" json:"-"` PrivateKeyData []byte `yaml:"-" json:"-"` + // ServerName is the hostname of the server. Currently, it is only + // being used for client ID checking. + ServerName string `yaml:"-" json:"-"` + cert tls.Certificate // DNS names from certificate (SAN) or CN value from Subject dnsNames []string diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 10a965e2..f8e7bff0 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -1,7 +1,10 @@ package dnsforward import ( + "crypto/tls" + "fmt" "net" + "path" "strings" "time" @@ -10,37 +13,64 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/util" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" + "github.com/lucas-clemente/quic-go" "github.com/miekg/dns" ) // To transfer information between modules type dnsContext struct { - srv *Server - proxyCtx *proxy.DNSContext - setts *dnsfilter.RequestFilteringSettings // filtering settings for this client - startTime time.Time - result *dnsfilter.Result - origResp *dns.Msg // response received from upstream servers. Set when response is modified by filtering - origQuestion dns.Question // question received from client. Set when Rewrites are used. - err error // error returned from the module - protectionEnabled bool // filtering is enabled, dnsfilter object is ready - responseFromUpstream bool // response is received from upstream servers - origReqDNSSEC bool // DNSSEC flag in the original request from user + srv *Server + proxyCtx *proxy.DNSContext + // setts are the filtering settings for the client. + setts *dnsfilter.RequestFilteringSettings + startTime time.Time + result *dnsfilter.Result + // origResp is the response received from upstream. It is set when the + // response is modified by filters. + origResp *dns.Msg + // err is the error returned from a processing function. + err error + // clientID is the clientID from DOH, DOQ, or DOT, if provided. + clientID string + // origQuestion is the question received from the client. It is set + // when the request is modified by rewrites. + origQuestion dns.Question + // protectionEnabled shows if the filtering is enabled, and if the + // server's DNS filter is ready. + protectionEnabled bool + // responseFromUpstream shows if the response is received from the + // upstream servers. + responseFromUpstream bool + // origReqDNSSEC shows if the DNSSEC flag in the original request from + // the client is set. + origReqDNSSEC bool } +// resultCode is the result of a request processing function. +type resultCode int + const ( - resultDone = iota // module has completed its job, continue - resultFinish // module has completed its job, exit normally - resultError // an error occurred, exit with an error + // resultCodeSuccess is returned when a handler performed successfully, + // and the next handler must be called. + resultCodeSuccess resultCode = iota + // resultCodeFinish is returned when a handler performed successfully, + // and the processing of the request must be stopped. + resultCodeFinish + // resultCodeError is returned when a handler failed, and the processing + // of the request must be stopped. + resultCodeError ) // handleDNSRequest filters the incoming DNS requests and writes them to the query log func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { - ctx := &dnsContext{srv: s, proxyCtx: d} - ctx.result = &dnsfilter.Result{} - ctx.startTime = time.Now() + ctx := &dnsContext{ + srv: s, + proxyCtx: d, + result: &dnsfilter.Result{}, + startTime: time.Now(), + } - type modProcessFunc func(ctx *dnsContext) int + type modProcessFunc func(ctx *dnsContext) (rc resultCode) // Since (*dnsforward.Server).handleDNSRequest(...) is used as // proxy.(Config).RequestHandler, there is no need for additional index @@ -51,6 +81,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { processInitial, processInternalHosts, processInternalIPAddrs, + processClientID, processFilteringBeforeRequest, processUpstream, processDNSSECAfterResponse, @@ -61,13 +92,13 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { for _, process := range mods { r := process(ctx) switch r { - case resultDone: + case resultCodeSuccess: // continue: call the next filter - case resultFinish: + case resultCodeFinish: return nil - case resultError: + case resultCodeError: return ctx.err } } @@ -79,12 +110,12 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { } // Perform initial checks; process WHOIS & rDNS -func processInitial(ctx *dnsContext) int { +func processInitial(ctx *dnsContext) (rc resultCode) { s := ctx.srv d := ctx.proxyCtx if s.conf.AAAADisabled && d.Req.Question[0].Qtype == dns.TypeAAAA { _ = proxy.CheckDisabledAAAARequest(d, true) - return resultFinish + return resultCodeFinish } if s.conf.OnDNSRequest != nil { @@ -96,10 +127,10 @@ func processInitial(ctx *dnsContext) int { if (d.Req.Question[0].Qtype == dns.TypeA || d.Req.Question[0].Qtype == dns.TypeAAAA) && d.Req.Question[0].Name == "use-application-dns.net." { d.Res = s.genNXDomain(d.Req) - return resultFinish + return resultCodeFinish } - return resultDone + return resultCodeSuccess } // Return TRUE if host names doesn't contain disallowed characters @@ -157,29 +188,29 @@ func (s *Server) onDHCPLeaseChanged(flags int) { } // Respond to A requests if the target host name is associated with a lease from our DHCP server -func processInternalHosts(ctx *dnsContext) int { +func processInternalHosts(ctx *dnsContext) (rc resultCode) { s := ctx.srv req := ctx.proxyCtx.Req if !(req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA) { - return resultDone + return resultCodeSuccess } host := req.Question[0].Name host = strings.ToLower(host) if !strings.HasSuffix(host, ".lan.") { - return resultDone + return resultCodeSuccess } host = strings.TrimSuffix(host, ".lan.") s.tableHostToIPLock.Lock() if s.tableHostToIP == nil { s.tableHostToIPLock.Unlock() - return resultDone + return resultCodeSuccess } ip, ok := s.tableHostToIP[host] s.tableHostToIPLock.Unlock() if !ok { - return resultDone + return resultCodeSuccess } log.Debug("DNS: internal record: %s -> %s", req.Question[0].Name, ip) @@ -200,15 +231,163 @@ func processInternalHosts(ctx *dnsContext) int { } ctx.proxyCtx.Res = resp - return resultDone + return resultCodeSuccess +} + +const maxDomainPartLen = 64 + +// ValidateClientID returns an error if clientID is not a valid client ID. +func ValidateClientID(clientID string) (err error) { + if len(clientID) > maxDomainPartLen { + return fmt.Errorf("client id %q is too long, max: %d", clientID, maxDomainPartLen) + } + + for i, r := range clientID { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' { + continue + } + + return fmt.Errorf("invalid char %q at index %d in client id %q", r, i, clientID) + } + + return nil +} + +// clientIDFromClientServerName extracts and validates a client ID. hostSrvName +// is the server name of the host. cliSrvName is the server name as sent by the +// client. +func clientIDFromClientServerName(hostSrvName, cliSrvName string) (clientID string, err error) { + if hostSrvName == cliSrvName { + return "", nil + } + + if !strings.HasSuffix(cliSrvName, hostSrvName) { + return "", fmt.Errorf("client server name %q doesn't match host server name %q", cliSrvName, hostSrvName) + } + + clientID = cliSrvName[:len(cliSrvName)-len(hostSrvName)-1] + err = ValidateClientID(clientID) + if err != nil { + return "", fmt.Errorf("invalid client id: %w", err) + } + + return clientID, nil +} + +// processClientIDHTTPS extracts the client's ID from the path of the +// client's DNS-over-HTTPS request. +func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) { + pctx := ctx.proxyCtx + r := pctx.HTTPRequest + if r == nil { + ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto) + + return resultCodeError + } + + origPath := r.URL.Path + parts := strings.Split(path.Clean(origPath), "/") + if parts[0] == "" { + parts = parts[1:] + } + + if len(parts) == 0 || parts[0] != "dns-query" { + ctx.err = fmt.Errorf("client id check: invalid path %q", origPath) + + return resultCodeError + } + + clientID := "" + switch len(parts) { + case 1: + // Just /dns-query, no client ID. + return resultCodeSuccess + case 2: + clientID = parts[1] + default: + ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath) + + return resultCodeError + } + + err := ValidateClientID(clientID) + if err != nil { + ctx.err = fmt.Errorf("client id check: invalid client id: %w", err) + + return resultCodeError + } + + ctx.clientID = clientID + + return resultCodeSuccess +} + +// tlsConn is a narrow interface for *tls.Conn to simplify testing. +type tlsConn interface { + ConnectionState() (cs tls.ConnectionState) +} + +// quicSession is a narrow interface for quic.Session to simplify testing. +type quicSession interface { + ConnectionState() (cs quic.ConnectionState) +} + +// processClientID extracts the client's ID from the server name of the client's +// DOT or DOQ request or the path of the client's DOH. +func processClientID(ctx *dnsContext) (rc resultCode) { + pctx := ctx.proxyCtx + proto := pctx.Proto + if proto == proxy.ProtoHTTPS { + return processClientIDHTTPS(ctx) + } else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC { + return resultCodeSuccess + } + + hostSrvName := ctx.srv.conf.TLSConfig.ServerName + if hostSrvName == "" { + return resultCodeSuccess + } + + cliSrvName := "" + if proto == proxy.ProtoTLS { + conn := pctx.Conn + tc, ok := conn.(tlsConn) + if !ok { + ctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn) + + return resultCodeError + } + + cliSrvName = tc.ConnectionState().ServerName + } else if proto == proxy.ProtoQUIC { + qs, ok := pctx.QUICSession.(quicSession) + if !ok { + ctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession) + + return resultCodeError + } + + cliSrvName = qs.ConnectionState().ServerName + } + + clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName) + if err != nil { + ctx.err = fmt.Errorf("client id check: %w", err) + + return resultCodeError + } + + ctx.clientID = clientID + + return resultCodeSuccess } // Respond to PTR requests if the target IP address is leased by our DHCP server -func processInternalIPAddrs(ctx *dnsContext) int { +func processInternalIPAddrs(ctx *dnsContext) (rc resultCode) { s := ctx.srv req := ctx.proxyCtx.Req if req.Question[0].Qtype != dns.TypePTR { - return resultDone + return resultCodeSuccess } arpa := req.Question[0].Name @@ -216,18 +395,18 @@ func processInternalIPAddrs(ctx *dnsContext) int { arpa = strings.ToLower(arpa) ip := util.DNSUnreverseAddr(arpa) if ip == nil { - return resultDone + return resultCodeSuccess } s.tablePTRLock.Lock() if s.tablePTR == nil { s.tablePTRLock.Unlock() - return resultDone + return resultCodeSuccess } host, ok := s.tablePTR[ip.String()] s.tablePTRLock.Unlock() if !ok { - return resultDone + return resultCodeSuccess } log.Debug("DNS: reverse-lookup: %s -> %s", arpa, host) @@ -243,16 +422,16 @@ func processInternalIPAddrs(ctx *dnsContext) int { ptr.Ptr = host + "." resp.Answer = append(resp.Answer, ptr) ctx.proxyCtx.Res = resp - return resultDone + return resultCodeSuccess } // Apply filtering logic -func processFilteringBeforeRequest(ctx *dnsContext) int { +func processFilteringBeforeRequest(ctx *dnsContext) (rc resultCode) { s := ctx.srv d := ctx.proxyCtx if d.Res != nil { - return resultDone // response is already set - nothing to do + return resultCodeSuccess // response is already set - nothing to do } s.RLock() @@ -266,24 +445,24 @@ func processFilteringBeforeRequest(ctx *dnsContext) int { var err error ctx.protectionEnabled = s.conf.ProtectionEnabled && s.dnsFilter != nil if ctx.protectionEnabled { - ctx.setts = s.getClientRequestFilteringSettings(d) + ctx.setts = s.getClientRequestFilteringSettings(ctx) ctx.result, err = s.filterDNSRequest(ctx) } s.RUnlock() if err != nil { ctx.err = err - return resultError + return resultCodeError } - return resultDone + return resultCodeSuccess } // processUpstream passes request to upstream servers and handles the response. -func processUpstream(ctx *dnsContext) int { +func processUpstream(ctx *dnsContext) (rc resultCode) { s := ctx.srv d := ctx.proxyCtx if d.Res != nil { - return resultDone // response is already set - nothing to do + return resultCodeSuccess // response is already set - nothing to do } if d.Addr != nil && s.conf.GetCustomUpstreamByClient != nil { @@ -311,26 +490,26 @@ func processUpstream(ctx *dnsContext) int { err := s.dnsProxy.Resolve(d) if err != nil { ctx.err = err - return resultError + return resultCodeError } ctx.responseFromUpstream = true - return resultDone + return resultCodeSuccess } // Process DNSSEC after response from upstream server -func processDNSSECAfterResponse(ctx *dnsContext) int { +func processDNSSECAfterResponse(ctx *dnsContext) (rc resultCode) { d := ctx.proxyCtx if !ctx.responseFromUpstream || // don't process response if it's not from upstream servers !ctx.srv.conf.EnableDNSSEC { - return resultDone + return resultCodeSuccess } if !ctx.origReqDNSSEC { optResp := d.Res.IsEdns0() if optResp != nil && !optResp.Do() { - return resultDone + return resultCodeSuccess } // Remove RRSIG records from response @@ -361,11 +540,11 @@ func processDNSSECAfterResponse(ctx *dnsContext) int { d.Res.Ns = answers } - return resultDone + return resultCodeSuccess } // Apply filtering logic after we have received response from upstream servers -func processFilteringAfterResponse(ctx *dnsContext) int { +func processFilteringAfterResponse(ctx *dnsContext) (rc resultCode) { s := ctx.srv d := ctx.proxyCtx res := ctx.result @@ -402,7 +581,7 @@ func processFilteringAfterResponse(ctx *dnsContext) int { ctx.result, err = s.filterDNSResponse(ctx) if err != nil { ctx.err = err - return resultError + return resultCodeError } if ctx.result != nil { ctx.origResp = origResp2 // matched by response @@ -411,5 +590,5 @@ func processFilteringAfterResponse(ctx *dnsContext) int { } } - return resultDone + return resultCodeSuccess } diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go new file mode 100644 index 00000000..bd0ef4ab --- /dev/null +++ b/internal/dnsforward/dns_test.go @@ -0,0 +1,235 @@ +package dnsforward + +import ( + "crypto/tls" + "net" + "net/http" + "net/url" + "testing" + + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/lucas-clemente/quic-go" + "github.com/stretchr/testify/assert" +) + +// testTLSConn is a tlsConn for tests. +type testTLSConn struct { + // Conn is embedded here simply to make testTLSConn a net.Conn without + // acctually implementing all methods. + net.Conn + + serverName string +} + +// ConnectionState implements the tlsConn interface for testTLSConn. +func (c testTLSConn) ConnectionState() (cs tls.ConnectionState) { + cs.ServerName = c.serverName + + return cs +} + +// testQUICSession is a quicSession for tests. +type testQUICSession struct { + // Session is embedded here simply to make testQUICSession + // a quic.Session without acctually implementing all methods. + quic.Session + + serverName string +} + +// ConnectionState implements the quicSession interface for testQUICSession. +func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) { + cs.ServerName = c.serverName + + return cs +} + +func TestProcessClientID(t *testing.T) { + testCases := []struct { + name string + proto string + hostSrvName string + cliSrvName string + wantClientID string + wantErrMsg string + wantRes resultCode + }{{ + name: "udp", + proto: proxy.ProtoUDP, + hostSrvName: "", + cliSrvName: "", + wantClientID: "", + wantErrMsg: "", + wantRes: resultCodeSuccess, + }, { + name: "tls_no_client_id", + proto: proxy.ProtoTLS, + hostSrvName: "example.com", + cliSrvName: "example.com", + wantClientID: "", + wantErrMsg: "", + wantRes: resultCodeSuccess, + }, { + name: "tls_client_id", + proto: proxy.ProtoTLS, + hostSrvName: "example.com", + cliSrvName: "cli.example.com", + wantClientID: "cli", + wantErrMsg: "", + wantRes: resultCodeSuccess, + }, { + name: "tls_client_id_hostname_error", + proto: proxy.ProtoTLS, + hostSrvName: "example.com", + cliSrvName: "cli.example.net", + wantClientID: "", + wantErrMsg: `client id check: client server name "cli.example.net" doesn't match host server name "example.com"`, + wantRes: resultCodeError, + }, { + name: "tls_invalid_client_id", + proto: proxy.ProtoTLS, + hostSrvName: "example.com", + cliSrvName: "!!!.example.com", + wantClientID: "", + wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`, + wantRes: resultCodeError, + }, { + name: "tls_client_id_too_long", + proto: proxy.ProtoTLS, + hostSrvName: "example.com", + cliSrvName: "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789.example.com", + wantClientID: "", + wantErrMsg: `client id check: invalid client id: client id "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789" is too long, max: 64`, + wantRes: resultCodeError, + }, { + name: "quic_client_id", + proto: proxy.ProtoQUIC, + hostSrvName: "example.com", + cliSrvName: "cli.example.com", + wantClientID: "cli", + wantErrMsg: "", + wantRes: resultCodeSuccess, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + srv := &Server{ + conf: ServerConfig{ + TLSConfig: TLSConfig{ServerName: tc.hostSrvName}, + }, + } + + var conn net.Conn + if tc.proto == proxy.ProtoTLS { + conn = testTLSConn{ + serverName: tc.cliSrvName, + } + } + + var qs quic.Session + if tc.proto == proxy.ProtoQUIC { + qs = testQUICSession{ + serverName: tc.cliSrvName, + } + } + + dctx := &dnsContext{ + srv: srv, + proxyCtx: &proxy.DNSContext{ + Proto: tc.proto, + Conn: conn, + QUICSession: qs, + }, + } + + res := processClientID(dctx) + assert.Equal(t, tc.wantRes, res) + assert.Equal(t, tc.wantClientID, dctx.clientID) + + if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) { + assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) + } else { + assert.Nil(t, dctx.err) + } + }) + } +} + +func TestProcessClientID_https(t *testing.T) { + testCases := []struct { + name string + path string + wantClientID string + wantErrMsg string + wantRes resultCode + }{{ + name: "no_client_id", + path: "/dns-query", + wantClientID: "", + wantErrMsg: "", + wantRes: resultCodeSuccess, + }, { + name: "no_client_id_slash", + path: "/dns-query/", + wantClientID: "", + wantErrMsg: "", + wantRes: resultCodeSuccess, + }, { + name: "client_id", + path: "/dns-query/cli", + wantClientID: "cli", + wantErrMsg: "", + wantRes: resultCodeSuccess, + }, { + name: "client_id_slash", + path: "/dns-query/cli/", + wantClientID: "cli", + wantErrMsg: "", + wantRes: resultCodeSuccess, + }, { + name: "bad_url", + path: "/foo", + wantClientID: "", + wantErrMsg: `client id check: invalid path "/foo"`, + wantRes: resultCodeError, + }, { + name: "extra", + path: "/dns-query/cli/foo", + wantClientID: "", + wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`, + wantRes: resultCodeError, + }, { + name: "invalid_client_id", + path: "/dns-query/!!!", + wantClientID: "", + wantErrMsg: `client id check: invalid client id: invalid char '!' at index 0 in client id "!!!"`, + wantRes: resultCodeError, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := &http.Request{ + URL: &url.URL{ + Path: tc.path, + }, + } + + dctx := &dnsContext{ + proxyCtx: &proxy.DNSContext{ + Proto: proxy.ProtoHTTPS, + HTTPRequest: r, + }, + } + + res := processClientID(dctx) + assert.Equal(t, tc.wantRes, res) + assert.Equal(t, tc.wantClientID, dctx.clientID) + + if tc.wantErrMsg != "" && assert.NotNil(t, dctx.err) { + assert.Equal(t, tc.wantErrMsg, dctx.err.Error()) + } else { + assert.Nil(t, dctx.err) + } + }) + } +} diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index ec56c9e8..17f4b872 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -473,7 +473,7 @@ func TestBlockCNAME(t *testing.T) { func TestClientRulesForCNAMEMatching(t *testing.T) { s := createTestServer(t) testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} - s.conf.FilterHandler = func(_ net.IP, settings *dnsfilter.RequestFilteringSettings) { + s.conf.FilterHandler = func(_ net.IP, _ string, settings *dnsfilter.RequestFilteringSettings) { settings.FilteringEnabled = false } err := s.startWithUpstream(testUpstm) @@ -1033,8 +1033,7 @@ func TestMatchDNSName(t *testing.T) { assert.False(t, matchDNSName(dnsNames, "*.host2")) } -type testDHCP struct { -} +type testDHCP struct{} func (d *testDHCP) Leases(flags int) []dhcpd.Lease { l := dhcpd.Lease{} diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 4d319288..f3d9f758 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -30,14 +30,15 @@ func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool return true, nil } -// getClientRequestFilteringSettings lookups client filtering settings -// using the client's IP address from the DNSContext -func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilter.RequestFilteringSettings { +// getClientRequestFilteringSettings looks up client filtering settings using +// the client's IP address and ID, if any, from ctx. +func (s *Server) getClientRequestFilteringSettings(ctx *dnsContext) *dnsfilter.RequestFilteringSettings { setts := s.dnsFilter.GetConfig() setts.FilteringEnabled = true if s.conf.FilterHandler != nil { - s.conf.FilterHandler(IPFromAddr(d.Addr), &setts) + s.conf.FilterHandler(IPFromAddr(ctx.proxyCtx.Addr), ctx.clientID, &setts) } + return &setts } diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 3a6d6578..e96c6c3f 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -529,5 +529,5 @@ func (s *Server) registerHandlers() { s.conf.HTTPRegister(http.MethodGet, "/control/access/list", s.handleAccessList) s.conf.HTTPRegister(http.MethodPost, "/control/access/set", s.handleAccessSet) - s.conf.HTTPRegister("", "/dns-query", s.handleDOH) + s.conf.HTTPRegister("", "/dns-query/", s.handleDOH) } diff --git a/internal/dnsforward/ipset.go b/internal/dnsforward/ipset.go index 74a8ecd8..7421edda 100644 --- a/internal/dnsforward/ipset.go +++ b/internal/dnsforward/ipset.go @@ -99,12 +99,12 @@ func (c *ipsetCtx) getIP(rr dns.RR) net.IP { } // Add IP addresses of the specified in configuration domain names to an ipset list -func (c *ipsetCtx) process(ctx *dnsContext) int { +func (c *ipsetCtx) process(ctx *dnsContext) (rc resultCode) { req := ctx.proxyCtx.Req if !(req.Question[0].Qtype == dns.TypeA || req.Question[0].Qtype == dns.TypeAAAA) || !ctx.responseFromUpstream { - return resultDone + return resultCodeSuccess } host := req.Question[0].Name @@ -112,7 +112,7 @@ func (c *ipsetCtx) process(ctx *dnsContext) int { host = strings.ToLower(host) ipsetNames, found := c.ipsetList[host] if !found { - return resultDone + return resultCodeSuccess } log.Debug("IPSET: found ipsets %v for host %s", ipsetNames, host) @@ -138,5 +138,5 @@ func (c *ipsetCtx) process(ctx *dnsContext) int { } } - return resultDone + return resultCodeSuccess } diff --git a/internal/dnsforward/ipset_test.go b/internal/dnsforward/ipset_test.go index 41be83d2..f08e4d08 100644 --- a/internal/dnsforward/ipset_test.go +++ b/internal/dnsforward/ipset_test.go @@ -37,5 +37,5 @@ func TestIPSET(t *testing.T) { }, }, } - assert.Equal(t, resultDone, c.process(ctx)) + assert.Equal(t, resultCodeSuccess, c.process(ctx)) } diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go index be45b0f9..29e6bb86 100644 --- a/internal/dnsforward/stats.go +++ b/internal/dnsforward/stats.go @@ -1,7 +1,6 @@ package dnsforward import ( - "net" "strings" "time" @@ -13,13 +12,13 @@ import ( ) // Write Stats data and logs -func processQueryLogsAndStats(ctx *dnsContext) int { +func processQueryLogsAndStats(ctx *dnsContext) (rc resultCode) { elapsed := time.Since(ctx.startTime) s := ctx.srv - d := ctx.proxyCtx + pctx := ctx.proxyCtx shouldLog := true - msg := d.Req + msg := pctx.Req // don't log ANY request if refuseAny is enabled if len(msg.Question) >= 1 && msg.Question[0].Qtype == dns.TypeANY && s.conf.RefuseAny { @@ -32,65 +31,67 @@ func processQueryLogsAndStats(ctx *dnsContext) int { if shouldLog && s.queryLog != nil { p := querylog.AddParams{ Question: msg, - Answer: d.Res, + Answer: pctx.Res, OrigAnswer: ctx.origResp, Result: ctx.result, Elapsed: elapsed, - ClientIP: IPFromAddr(d.Addr), + ClientIP: IPFromAddr(pctx.Addr), + ClientID: ctx.clientID, } - switch d.Proto { + switch pctx.Proto { case proxy.ProtoHTTPS: p.ClientProto = querylog.ClientProtoDOH case proxy.ProtoQUIC: p.ClientProto = querylog.ClientProtoDOQ case proxy.ProtoTLS: p.ClientProto = querylog.ClientProtoDOT + case proxy.ProtoDNSCrypt: + p.ClientProto = querylog.ClientProtoDNSCrypt default: - // Consider this a plain DNS-over-UDP or DNS-over-TCL + // Consider this a plain DNS-over-UDP or DNS-over-TCP // request. } - if d.Upstream != nil { - p.Upstream = d.Upstream.Address() + if pctx.Upstream != nil { + p.Upstream = pctx.Upstream.Address() } + s.queryLog.Add(p) } - s.updateStats(d, elapsed, *ctx.result) + s.updateStats(ctx, elapsed, *ctx.result) s.RUnlock() - return resultDone + return resultCodeSuccess } -func (s *Server) updateStats(d *proxy.DNSContext, elapsed time.Duration, res dnsfilter.Result) { +func (s *Server) updateStats(ctx *dnsContext, elapsed time.Duration, res dnsfilter.Result) { if s.stats == nil { return } + pctx := ctx.proxyCtx e := stats.Entry{} - e.Domain = strings.ToLower(d.Req.Question[0].Name) + e.Domain = strings.ToLower(pctx.Req.Question[0].Name) e.Domain = e.Domain[:len(e.Domain)-1] // remove last "." - switch addr := d.Addr.(type) { - case *net.UDPAddr: - e.Client = addr.IP - case *net.TCPAddr: - e.Client = addr.IP + + if clientID := ctx.clientID; clientID != "" { + e.Client = clientID + } else if ip := IPFromAddr(pctx.Addr); ip != nil { + e.Client = ip.String() } + e.Time = uint32(elapsed / 1000) e.Result = stats.RNotFiltered switch res.Reason { - case dnsfilter.FilteredSafeBrowsing: e.Result = stats.RSafeBrowsing - case dnsfilter.FilteredParental: e.Result = stats.RParental - case dnsfilter.FilteredSafeSearch: e.Result = stats.RSafeSearch - case dnsfilter.FilteredBlockList: fallthrough case dnsfilter.FilteredInvalid: diff --git a/internal/dnsforward/stats_test.go b/internal/dnsforward/stats_test.go new file mode 100644 index 00000000..3b5981bb --- /dev/null +++ b/internal/dnsforward/stats_test.go @@ -0,0 +1,198 @@ +package dnsforward + +import ( + "net" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" + "github.com/AdguardTeam/AdGuardHome/internal/querylog" + "github.com/AdguardTeam/AdGuardHome/internal/stats" + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +// testQueryLog is a simple querylog.QueryLog implementation for tests. +type testQueryLog struct { + // QueryLog is embedded here simply to make testQueryLog + // a querylog.QueryLog without acctually implementing all methods. + querylog.QueryLog + + lastParams querylog.AddParams +} + +// Add implements the querylog.QueryLog interface for *testQueryLog. +func (l *testQueryLog) Add(p querylog.AddParams) { + l.lastParams = p +} + +// testStats is a simple stats.Stats implementation for tests. +type testStats struct { + // Stats is embedded here simply to make testStats a stats.Stats without + // acctually implementing all methods. + stats.Stats + + lastEntry stats.Entry +} + +// Update implements the stats.Stats interface for *testStats. +func (l *testStats) Update(e stats.Entry) { + l.lastEntry = e +} + +func TestProcessQueryLogsAndStats(t *testing.T) { + testCases := []struct { + name string + proto string + addr net.Addr + clientID string + wantLogProto querylog.ClientProto + wantStatClient string + wantCode resultCode + reason dnsfilter.Reason + wantStatResult stats.Result + }{{ + name: "success_udp", + proto: proxy.ProtoUDP, + addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234}, + clientID: "", + wantLogProto: "", + wantStatClient: "1.2.3.4", + wantCode: resultCodeSuccess, + reason: dnsfilter.NotFilteredNotFound, + wantStatResult: stats.RNotFiltered, + }, { + name: "success_tls_client_id", + proto: proxy.ProtoTLS, + addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234}, + clientID: "cli42", + wantLogProto: querylog.ClientProtoDOT, + wantStatClient: "cli42", + wantCode: resultCodeSuccess, + reason: dnsfilter.NotFilteredNotFound, + wantStatResult: stats.RNotFiltered, + }, { + name: "success_tls", + proto: proxy.ProtoTLS, + addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234}, + clientID: "", + wantLogProto: querylog.ClientProtoDOT, + wantStatClient: "1.2.3.4", + wantCode: resultCodeSuccess, + reason: dnsfilter.NotFilteredNotFound, + wantStatResult: stats.RNotFiltered, + }, { + name: "success_quic", + proto: proxy.ProtoQUIC, + addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234}, + clientID: "", + wantLogProto: querylog.ClientProtoDOQ, + wantStatClient: "1.2.3.4", + wantCode: resultCodeSuccess, + reason: dnsfilter.NotFilteredNotFound, + wantStatResult: stats.RNotFiltered, + }, { + name: "success_https", + proto: proxy.ProtoHTTPS, + addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234}, + clientID: "", + wantLogProto: querylog.ClientProtoDOH, + wantStatClient: "1.2.3.4", + wantCode: resultCodeSuccess, + reason: dnsfilter.NotFilteredNotFound, + wantStatResult: stats.RNotFiltered, + }, { + name: "success_dnscrypt", + proto: proxy.ProtoDNSCrypt, + addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234}, + clientID: "", + wantLogProto: querylog.ClientProtoDNSCrypt, + wantStatClient: "1.2.3.4", + wantCode: resultCodeSuccess, + reason: dnsfilter.NotFilteredNotFound, + wantStatResult: stats.RNotFiltered, + }, { + name: "success_udp_filtered", + proto: proxy.ProtoUDP, + addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234}, + clientID: "", + wantLogProto: "", + wantStatClient: "1.2.3.4", + wantCode: resultCodeSuccess, + reason: dnsfilter.FilteredBlockList, + wantStatResult: stats.RFiltered, + }, { + name: "success_udp_sb", + proto: proxy.ProtoUDP, + addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234}, + clientID: "", + wantLogProto: "", + wantStatClient: "1.2.3.4", + wantCode: resultCodeSuccess, + reason: dnsfilter.FilteredSafeBrowsing, + wantStatResult: stats.RSafeBrowsing, + }, { + name: "success_udp_ss", + proto: proxy.ProtoUDP, + addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234}, + clientID: "", + wantLogProto: "", + wantStatClient: "1.2.3.4", + wantCode: resultCodeSuccess, + reason: dnsfilter.FilteredSafeSearch, + wantStatResult: stats.RSafeSearch, + }, { + name: "success_udp_pc", + proto: proxy.ProtoUDP, + addr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 1234}, + clientID: "", + wantLogProto: "", + wantStatClient: "1.2.3.4", + wantCode: resultCodeSuccess, + reason: dnsfilter.FilteredParental, + wantStatResult: stats.RParental, + }} + + ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{}) + assert.Nil(t, err) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := &dns.Msg{ + Question: []dns.Question{{ + Name: "example.com.", + }}, + } + pctx := &proxy.DNSContext{ + Proto: tc.proto, + Req: req, + Res: &dns.Msg{}, + Addr: tc.addr, + Upstream: ups, + } + + ql := &testQueryLog{} + st := &testStats{} + dctx := &dnsContext{ + srv: &Server{ + queryLog: ql, + stats: st, + }, + proxyCtx: pctx, + startTime: time.Now(), + result: &dnsfilter.Result{ + Reason: tc.reason, + }, + clientID: tc.clientID, + } + + code := processQueryLogsAndStats(dctx) + assert.Equal(t, tc.wantCode, code) + assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto) + assert.Equal(t, tc.wantStatClient, st.lastEntry.Client) + assert.Equal(t, tc.wantStatResult, st.lastEntry.Result) + }) + } +} diff --git a/internal/home/clients.go b/internal/home/clients.go index b631f63e..c3eb366f 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -11,23 +11,21 @@ import ( "sync" "time" - "github.com/AdguardTeam/dnsproxy/proxy" - + "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/util" + "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/utils" ) -const ( - clientsUpdatePeriod = 10 * time.Minute -) +const clientsUpdatePeriod = 10 * time.Minute var webHandlersRegistered = false -// Client information +// Client contains information about persistent clients. type Client struct { IDs []string Tags []string @@ -52,14 +50,13 @@ type Client struct { type clientSource uint -// Client sources +// Client sources. The order determines the priority. const ( - // Priority: etc/hosts > DHCP > ARP > rDNS > WHOIS - ClientSourceWHOIS clientSource = iota // from WHOIS - ClientSourceRDNS // from rDNS - ClientSourceDHCP // from DHCP - ClientSourceARP // from 'arp -a' - ClientSourceHostsFile // from /etc/hosts + ClientSourceWHOIS clientSource = iota + ClientSourceRDNS + ClientSourceDHCP + ClientSourceARP + ClientSourceHostsFile ) // ClientHost information @@ -70,12 +67,12 @@ type ClientHost struct { } type clientsContainer struct { - list map[string]*Client // name -> client - idIndex map[string]*Client // IP -> client - // TODO(e.burkov): Think of a way to not require string conversion for - // IP addresses. - ipHost map[string]*ClientHost // IP -> Hostname - lock sync.Mutex + // TODO(a.garipov): Perhaps use a number of separate indices for + // different types (string, net.IP, and so on). + list map[string]*Client // name -> client + idIndex map[string]*Client // ID -> client + ipHost map[string]*ClientHost // IP -> Hostname + lock sync.Mutex allTags map[string]bool @@ -158,7 +155,7 @@ func (clients *clientsContainer) tagKnown(tag string) bool { func (clients *clientsContainer) addFromConfig(objects []clientObject) { for _, cy := range objects { - cli := Client{ + cli := &Client{ Name: cy.Name, IDs: cy.IDs, UseOwnSettings: !cy.UseGlobalSettings, @@ -174,7 +171,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) { for _, s := range cy.BlockedServices { if !dnsfilter.BlockedSvcKnown(s) { - log.Debug("Clients: skipping unknown blocked-service %q", s) + log.Debug("clients: skipping unknown blocked-service %q", s) continue } cli.BlockedServices = append(cli.BlockedServices, s) @@ -182,7 +179,7 @@ func (clients *clientsContainer) addFromConfig(objects []clientObject) { for _, t := range cy.Tags { if !clients.tagKnown(t) { - log.Debug("Clients: skipping unknown tag %q", t) + log.Debug("clients: skipping unknown tag %q", t) continue } cli.Tags = append(cli.Tags, t) @@ -210,10 +207,10 @@ func (clients *clientsContainer) WriteDiskConfig(objects *[]clientObject) { UseGlobalBlockedServices: !cli.UseOwnBlockedServices, } - cy.Tags = stringArrayDup(cli.Tags) - cy.IDs = stringArrayDup(cli.IDs) - cy.BlockedServices = stringArrayDup(cli.BlockedServices) - cy.Upstreams = stringArrayDup(cli.Upstreams) + cy.Tags = copyStrings(cli.Tags) + cy.IDs = copyStrings(cli.IDs) + cy.BlockedServices = copyStrings(cli.BlockedServices) + cy.Upstreams = copyStrings(cli.Upstreams) *objects = append(*objects, cy) } @@ -240,45 +237,44 @@ func (clients *clientsContainer) onHostsChanged() { clients.addFromHostsFile() } -// Exists checks if client with this IP already exists -func (clients *clientsContainer) Exists(ip net.IP, source clientSource) bool { +// Exists checks if client with this ID already exists. +func (clients *clientsContainer) Exists(id string, source clientSource) (ok bool) { clients.lock.Lock() defer clients.lock.Unlock() - _, ok := clients.findByIP(ip) + _, ok = clients.findLocked(id) if ok { return true } - ch, ok := clients.ipHost[ip.String()] + var ch *ClientHost + ch, ok = clients.ipHost[id] if !ok { return false } - if source > ch.Source { - return false // we're going to overwrite this client's info with a stronger source - } - return true + + // Return false if the new source has higher priority. + return source <= ch.Source } -func stringArrayDup(a []string) []string { - a2 := make([]string, len(a)) - copy(a2, a) - return a2 +func copyStrings(a []string) (b []string) { + return append(b, a...) } -// Find searches for a client by IP -func (clients *clientsContainer) Find(ip net.IP) (Client, bool) { +// Find searches for a client by its ID. +func (clients *clientsContainer) Find(id string) (c *Client, ok bool) { clients.lock.Lock() defer clients.lock.Unlock() - c, ok := clients.findByIP(ip) + c, ok = clients.findLocked(id) if !ok { - return Client{}, false + return nil, false } - c.IDs = stringArrayDup(c.IDs) - c.Tags = stringArrayDup(c.Tags) - c.BlockedServices = stringArrayDup(c.BlockedServices) - c.Upstreams = stringArrayDup(c.Upstreams) + + c.IDs = copyStrings(c.IDs) + c.Tags = copyStrings(c.Tags) + c.BlockedServices = copyStrings(c.BlockedServices) + c.Upstreams = copyStrings(c.Upstreams) return c, true } @@ -289,7 +285,7 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig clients.lock.Lock() defer clients.lock.Unlock() - c, ok := clients.findByIP(net.ParseIP(ip)) + c, ok := clients.findLocked(ip) if !ok { return nil } @@ -308,15 +304,16 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig return c.upstreamConfig } -// Find searches for a client by IP (and does not lock anything) -func (clients *clientsContainer) findByIP(ip net.IP) (Client, bool) { - if ip == nil { - return Client{}, false +// findLocked searches for a client by its ID. For internal use only. +func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) { + c, ok = clients.idIndex[id] + if ok { + return c, true } - c, ok := clients.idIndex[ip.String()] - if ok { - return *c, true + ip := net.ParseIP(id) + if ip == nil { + return nil, false } for _, c = range clients.list { @@ -325,88 +322,96 @@ func (clients *clientsContainer) findByIP(ip net.IP) (Client, bool) { if err != nil { continue } + if ipnet.Contains(ip) { - return *c, true + return c, true } } } if clients.dhcpServer == nil { - return Client{}, false + return nil, false } + macFound := clients.dhcpServer.FindMACbyIP(ip) if macFound == nil { - return Client{}, false + return nil, false } + for _, c = range clients.list { for _, id := range c.IDs { hwAddr, err := net.ParseMAC(id) if err != nil { continue } + if bytes.Equal(hwAddr, macFound) { - return *c, true + return c, true } } } - return Client{}, false + return nil, false } // FindAutoClient - search for an auto-client by IP -func (clients *clientsContainer) FindAutoClient(ip net.IP) (ClientHost, bool) { - if ip == nil { +func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) { + ipAddr := net.ParseIP(ip) + if ipAddr == nil { return ClientHost{}, false } clients.lock.Lock() defer clients.lock.Unlock() - ch, ok := clients.ipHost[ip.String()] + ch, ok := clients.ipHost[ip] if ok { return *ch, true } return ClientHost{}, false } -// Check if Client object's fields are correct -func (clients *clientsContainer) check(c *Client) error { - if len(c.Name) == 0 { - return fmt.Errorf("invalid Name") - } - - if len(c.IDs) == 0 { - return fmt.Errorf("id required") +// check validates the client. +func (clients *clientsContainer) check(c *Client) (err error) { + switch { + case c == nil: + return agherr.Error("client is nil") + case c.Name == "": + return agherr.Error("invalid name") + case len(c.IDs) == 0: + return agherr.Error("id required") + default: + // Go on. } for i, id := range c.IDs { - ip := net.ParseIP(id) - if ip != nil { - c.IDs[i] = ip.String() // normalize IP address - continue + // Normalize structured data. + var ip net.IP + var ipnet *net.IPNet + var mac net.HardwareAddr + if ip = net.ParseIP(id); ip != nil { + c.IDs[i] = ip.String() + } else if ip, ipnet, err = net.ParseCIDR(id); err == nil { + ipnet.IP = ip + c.IDs[i] = ipnet.String() + } else if mac, err = net.ParseMAC(id); err == nil { + c.IDs[i] = mac.String() + } else if err = dnsforward.ValidateClientID(id); err == nil { + c.IDs[i] = id + } else { + return fmt.Errorf("invalid client id at index %d: %q", i, id) } - - _, _, err := net.ParseCIDR(id) - if err == nil { - continue - } - - _, err = net.ParseMAC(id) - if err == nil { - continue - } - - return fmt.Errorf("invalid ID: %s", id) } for _, t := range c.Tags { if !clients.tagKnown(t) { - return fmt.Errorf("invalid tag: %s", t) + return fmt.Errorf("invalid tag: %q", t) } } + sort.Strings(c.Tags) - err := dnsforward.ValidateUpstreams(c.Upstreams) + err = dnsforward.ValidateUpstreams(c.Upstreams) if err != nil { return fmt.Errorf("invalid upstream servers: %w", err) } @@ -414,49 +419,52 @@ func (clients *clientsContainer) check(c *Client) error { return nil } -// Add a new client object -// Return true: success; false: client exists. -func (clients *clientsContainer) Add(c Client) (bool, error) { - e := clients.check(&c) - if e != nil { - return false, e +// Add adds a new client object. ok is false if such client already exists or +// if an error occurred. +func (clients *clientsContainer) Add(c *Client) (ok bool, err error) { + err = clients.check(c) + if err != nil { + return false, err } clients.lock.Lock() defer clients.lock.Unlock() // check Name index - _, ok := clients.list[c.Name] + _, ok = clients.list[c.Name] if ok { return false, nil } // check ID index for _, id := range c.IDs { - c2, ok := clients.idIndex[id] + var c2 *Client + c2, ok = clients.idIndex[id] if ok { - return false, fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name) + return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name) } } // update Name index - clients.list[c.Name] = &c + clients.list[c.Name] = c // update ID index for _, id := range c.IDs { - clients.idIndex[id] = &c + clients.idIndex[id] = c } - log.Debug("Clients: added %q: ID:%v [%d]", c.Name, c.IDs, len(clients.list)) + log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs, len(clients.list)) + return true, nil } -// Del removes a client -func (clients *clientsContainer) Del(name string) bool { +// Del removes a client. ok is false if there is no such client. +func (clients *clientsContainer) Del(name string) (ok bool) { clients.lock.Lock() defer clients.lock.Unlock() - c, ok := clients.list[name] + var c *Client + c, ok = clients.list[name] if !ok { return false } @@ -468,25 +476,28 @@ func (clients *clientsContainer) Del(name string) bool { for _, id := range c.IDs { delete(clients.idIndex, id) } + return true } -// Return TRUE if arrays are equal -func arraysEqual(a, b []string) bool { +// equalStringSlices returns true if the slices are equal. +func equalStringSlices(a, b []string) (ok bool) { if len(a) != len(b) { return false } - for i := 0; i != len(a); i++ { + + for i := range a { if a[i] != b[i] { return false } } + return true } -// Update a client -func (clients *clientsContainer) Update(name string, c Client) error { - err := clients.check(&c) +// Update updates a client by its name. +func (clients *clientsContainer) Update(name string, c *Client) (err error) { + err = clients.check(c) if err != nil { return err } @@ -494,66 +505,69 @@ func (clients *clientsContainer) Update(name string, c Client) error { clients.lock.Lock() defer clients.lock.Unlock() - old, ok := clients.list[name] + prev, ok := clients.list[name] if !ok { - return fmt.Errorf("client not found") + return agherr.Error("client not found") } // check Name index - if old.Name != c.Name { + if prev.Name != c.Name { _, ok = clients.list[c.Name] if ok { - return fmt.Errorf("client already exists") + return agherr.Error("client already exists") } } // check IP index - if !arraysEqual(old.IDs, c.IDs) { + if !equalStringSlices(prev.IDs, c.IDs) { for _, id := range c.IDs { c2, ok := clients.idIndex[id] - if ok && c2 != old { - return fmt.Errorf("another client uses the same ID (%s): %s", id, c2.Name) + if ok && c2 != prev { + return fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name) } } // update ID index - for _, id := range old.IDs { + for _, id := range prev.IDs { delete(clients.idIndex, id) } for _, id := range c.IDs { - clients.idIndex[id] = old + clients.idIndex[id] = prev } } // update Name index - if old.Name != c.Name { - delete(clients.list, old.Name) - clients.list[c.Name] = old + if prev.Name != c.Name { + delete(clients.list, prev.Name) + clients.list[c.Name] = prev } // update upstreams cache c.upstreamConfig = nil - *old = c + *prev = *c + return nil } -// SetWhoisInfo - associate WHOIS information with a client -func (clients *clientsContainer) SetWhoisInfo(ip net.IP, info [][]string) { +// SetWhoisInfo sets the WHOIS information for a client. +// +// TODO(a.garipov): Perhaps replace [][]string with map[string]string. +func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) { clients.lock.Lock() defer clients.lock.Unlock() - _, ok := clients.findByIP(ip) + _, ok := clients.findLocked(ip) if ok { - log.Debug("Clients: client for %s is already created, ignore WHOIS info", ip) + log.Debug("clients: client for %s is already created, ignore whois info", ip) return } - ipStr := ip.String() - ch, ok := clients.ipHost[ipStr] + ch, ok := clients.ipHost[ip] if ok { ch.WhoisInfo = info - log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo) + log.Debug("clients: set whois info for auto-client %s: %q", ch.Host, info) + return } @@ -562,32 +576,34 @@ func (clients *clientsContainer) SetWhoisInfo(ip net.IP, info [][]string) { Source: ClientSourceWHOIS, } ch.WhoisInfo = info - clients.ipHost[ipStr] = ch - log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo) + clients.ipHost[ip] = ch + log.Debug("clients: set whois info for auto-client with IP %s: %q", ip, info) } -// AddHost adds new IP -> Host pair -// Use priority of the source (etc/hosts > ARP > rDNS) -// so we overwrite existing entries with an equal or higher priority -func (clients *clientsContainer) AddHost(ip, host string, source clientSource) (bool, error) { +// AddHost adds a new IP-hostname pairing. The priorities of the sources is +// taken into account. ok is true if the pairing was added. +func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) { clients.lock.Lock() - b := clients.addHost(ip, host, source) + ok = clients.addHostLocked(ip, host, src) clients.lock.Unlock() - return b, nil + + return ok, nil } -func (clients *clientsContainer) addHost(ip, host string, source clientSource) (addedNew bool) { - ch, ok := clients.ipHost[ip] +// addHostLocked adds a new IP-hostname pairing. For internal use only. +func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource) (ok bool) { + var ch *ClientHost + ch, ok = clients.ipHost[ip] if ok { - if ch.Source > source { + if ch.Source > src { return false } - ch.Source = source + ch.Source = src } else { ch = &ClientHost{ Host: host, - Source: source, + Source: src, } clients.ipHost[ip] = ch @@ -598,11 +614,11 @@ func (clients *clientsContainer) addHost(ip, host string, source clientSource) ( return true } -// Remove all entries that match the specified source -func (clients *clientsContainer) rmHosts(source clientSource) { +// rmHostsBySrc removes all entries that match the specified source. +func (clients *clientsContainer) rmHostsBySrc(src clientSource) { n := 0 for k, v := range clients.ipHost { - if v.Source == source { + if v.Source == src { delete(clients.ipHost, k) n++ } @@ -611,19 +627,20 @@ func (clients *clientsContainer) rmHosts(source clientSource) { log.Debug("clients: removed %d client aliases", n) } -// addFromHostsFile fills the clients hosts list from the system's hosts files. +// addFromHostsFile fills the client-hostname pairing index from the system's +// hosts files. func (clients *clientsContainer) addFromHostsFile() { hosts := clients.autoHosts.List() clients.lock.Lock() defer clients.lock.Unlock() - clients.rmHosts(ClientSourceHostsFile) + clients.rmHostsBySrc(ClientSourceHostsFile) n := 0 for ip, names := range hosts { for _, name := range names { - ok := clients.addHost(ip, name, ClientSourceHostsFile) + ok := clients.addHostLocked(ip, name, ClientSourceHostsFile) if ok { n++ } @@ -633,31 +650,31 @@ func (clients *clientsContainer) addFromHostsFile() { log.Debug("Clients: added %d client aliases from system hosts-file", n) } -// Add IP -> Host pairs from the system's `arp -a` command output -// The command's output is: -// HOST (IP) at MAC on IFACE +// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a +// command. func (clients *clientsContainer) addFromSystemARP() { if runtime.GOOS == "windows" { return } cmd := exec.Command("arp", "-a") - log.Tracef("executing %s %v", cmd.Path, cmd.Args) + log.Tracef("executing %q %q", cmd.Path, cmd.Args) data, err := cmd.Output() if err != nil || cmd.ProcessState.ExitCode() != 0 { - log.Debug("command %s has failed: %v code:%d", + log.Debug("command %q has failed: %q code:%d", cmd.Path, err, cmd.ProcessState.ExitCode()) return } clients.lock.Lock() defer clients.lock.Unlock() - clients.rmHosts(ClientSourceARP) + + clients.rmHostsBySrc(ClientSourceARP) n := 0 + // TODO(a.garipov): Rewrite to use bufio.Scanner. lines := strings.Split(string(data), "\n") for _, ln := range lines { - open := strings.Index(ln, " (") close := strings.Index(ln, ") ") if open == -1 || close == -1 || open >= close { @@ -670,16 +687,17 @@ func (clients *clientsContainer) addFromSystemARP() { continue } - ok := clients.addHost(ip, host, ClientSourceARP) + ok := clients.addHostLocked(ip, host, ClientSourceARP) if ok { n++ } } - log.Debug("Clients: added %d client aliases from 'arp -a' command output", n) + log.Debug("clients: added %d client aliases from 'arp -a' command output", n) } -// Add clients from DHCP that have non-empty Hostname property +// addFromDHCP adds the clients that have a non-empty hostname from the DHCP +// server. func (clients *clientsContainer) addFromDHCP() { if clients.dhcpServer == nil { return @@ -688,18 +706,20 @@ func (clients *clientsContainer) addFromDHCP() { clients.lock.Lock() defer clients.lock.Unlock() - clients.rmHosts(ClientSourceDHCP) + clients.rmHostsBySrc(ClientSourceDHCP) leases := clients.dhcpServer.Leases(dhcpd.LeasesAll) n := 0 for _, l := range leases { - if len(l.Hostname) == 0 { + if l.Hostname == "" { continue } - ok := clients.addHost(l.IP.String(), l.Hostname, ClientSourceDHCP) + + ok := clients.addHostLocked(l.IP.String(), l.Hostname, ClientSourceDHCP) if ok { n++ } } - log.Debug("Clients: added %d client aliases from DHCP", n) + + log.Debug("clients: added %d client aliases from dhcp", n) } diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index 94ff8009..a098bf4c 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -18,65 +18,65 @@ func TestClients(t *testing.T) { clients.Init(nil, nil, nil) t.Run("add_success", func(t *testing.T) { - c := Client{ + c := &Client{ IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"}, Name: "client1", } - b, err := clients.Add(c) - assert.True(t, b) + ok, err := clients.Add(c) + assert.True(t, ok) assert.Nil(t, err) - c = Client{ + c = &Client{ IDs: []string{"2.2.2.2"}, Name: "client2", } - b, err = clients.Add(c) - assert.True(t, b) + ok, err = clients.Add(c) + assert.True(t, ok) assert.Nil(t, err) - c, b = clients.Find(net.IPv4(1, 1, 1, 1)) - assert.True(t, b) - assert.Equal(t, c.Name, "client1") + c, ok = clients.Find("1.1.1.1") + assert.True(t, ok) + assert.Equal(t, "client1", c.Name) - c, b = clients.Find(net.ParseIP("1:2:3::4")) - assert.True(t, b) - assert.Equal(t, c.Name, "client1") + c, ok = clients.Find("1:2:3::4") + assert.True(t, ok) + assert.Equal(t, "client1", c.Name) - c, b = clients.Find(net.IPv4(2, 2, 2, 2)) - assert.True(t, b) - assert.Equal(t, c.Name, "client2") + c, ok = clients.Find("2.2.2.2") + assert.True(t, ok) + assert.Equal(t, "client2", c.Name) - assert.False(t, clients.Exists(net.IPv4(1, 2, 3, 4), ClientSourceHostsFile)) - assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile)) - assert.True(t, clients.Exists(net.IPv4(2, 2, 2, 2), ClientSourceHostsFile)) + assert.True(t, !clients.Exists("1.2.3.4", ClientSourceHostsFile)) + assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) + assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile)) }) t.Run("add_fail_name", func(t *testing.T) { - c := Client{ + c := &Client{ IDs: []string{"1.2.3.5"}, Name: "client1", } - b, err := clients.Add(c) - assert.False(t, b) + ok, err := clients.Add(c) + assert.False(t, ok) assert.Nil(t, err) }) t.Run("add_fail_ip", func(t *testing.T) { - c := Client{ + c := &Client{ IDs: []string{"2.2.2.2"}, Name: "client3", } - b, err := clients.Add(c) - assert.False(t, b) + ok, err := clients.Add(c) + assert.False(t, ok) assert.NotNil(t, err) }) t.Run("update_fail_name", func(t *testing.T) { - c := Client{ + c := &Client{ IDs: []string{"1.2.3.0"}, Name: "client3", } @@ -84,7 +84,7 @@ func TestClients(t *testing.T) { err := clients.Update("client3", c) assert.NotNil(t, err) - c = Client{ + c = &Client{ IDs: []string{"1.2.3.0"}, Name: "client2", } @@ -94,7 +94,7 @@ func TestClients(t *testing.T) { }) t.Run("update_fail_ip", func(t *testing.T) { - c := Client{ + c := &Client{ IDs: []string{"2.2.2.2"}, Name: "client1", } @@ -104,7 +104,7 @@ func TestClients(t *testing.T) { }) t.Run("update_success", func(t *testing.T) { - c := Client{ + c := &Client{ IDs: []string{"1.1.1.2"}, Name: "client1", } @@ -112,10 +112,10 @@ func TestClients(t *testing.T) { err := clients.Update("client1", c) assert.Nil(t, err) - assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile)) - assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile)) + assert.True(t, !clients.Exists("1.1.1.1", ClientSourceHostsFile)) + assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) - c = Client{ + c = &Client{ IDs: []string{"1.1.1.2"}, Name: "client1-renamed", UseOwnSettings: true, @@ -124,77 +124,89 @@ func TestClients(t *testing.T) { err = clients.Update("client1", c) assert.Nil(t, err) - c, b := clients.Find(net.IPv4(1, 1, 1, 2)) - assert.True(t, b) + c, ok := clients.Find("1.1.1.2") + assert.True(t, ok) assert.Equal(t, "client1-renamed", c.Name) - assert.Equal(t, "1.1.1.2", c.IDs[0]) assert.True(t, c.UseOwnSettings) assert.Nil(t, clients.list["client1"]) + if assert.Len(t, c.IDs, 1) { + assert.Equal(t, "1.1.1.2", c.IDs[0]) + } }) t.Run("del_success", func(t *testing.T) { - b := clients.Del("client1-renamed") - assert.True(t, b) - assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile)) + ok := clients.Del("client1-renamed") + assert.True(t, ok) + assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) }) t.Run("del_fail", func(t *testing.T) { - b := clients.Del("client3") - assert.False(t, b) + ok := clients.Del("client3") + assert.False(t, ok) }) t.Run("addhost_success", func(t *testing.T) { - b, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP) - assert.True(t, b) + ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP) + assert.True(t, ok) assert.Nil(t, err) - b, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP) - assert.True(t, b) + ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP) + assert.True(t, ok) assert.Nil(t, err) - b, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile) - assert.True(t, b) + ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile) + assert.True(t, ok) assert.Nil(t, err) - assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile)) + assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) }) t.Run("addhost_fail", func(t *testing.T) { - b, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS) - assert.False(t, b) + ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS) + assert.False(t, ok) assert.Nil(t, err) }) } func TestClientsWhois(t *testing.T) { - var c Client + var c *Client clients := clientsContainer{} clients.testing = true clients.Init(nil, nil, nil) whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}} // set whois info on new client - clients.SetWhoisInfo(net.IPv4(1, 1, 1, 255), whois) - assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1]) + clients.SetWhoisInfo("1.1.1.255", whois) + if assert.NotNil(t, clients.ipHost["1.1.1.255"]) { + h := clients.ipHost["1.1.1.255"] + if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) { + assert.Equal(t, "orgname-val", h.WhoisInfo[0][1]) + } + } // set whois info on existing auto-client _, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) - clients.SetWhoisInfo(net.IPv4(1, 1, 1, 1), whois) - assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1]) + clients.SetWhoisInfo("1.1.1.1", whois) + if assert.NotNil(t, clients.ipHost["1.1.1.1"]) { + h := clients.ipHost["1.1.1.1"] + if assert.Len(t, h.WhoisInfo, 2) && assert.Len(t, h.WhoisInfo[0], 2) { + assert.Equal(t, "orgname-val", h.WhoisInfo[0][1]) + } + } // Check that we cannot set whois info on a manually-added client - c = Client{ + c = &Client{ IDs: []string{"1.1.1.2"}, Name: "client1", } _, _ = clients.Add(c) - clients.SetWhoisInfo(net.IPv4(1, 1, 1, 2), whois) + clients.SetWhoisInfo("1.1.1.2", whois) assert.Nil(t, clients.ipHost["1.1.1.2"]) _ = clients.Del("client1") } func TestClientsAddExisting(t *testing.T) { - var c Client + var c *Client clients := clientsContainer{} clients.testing = true clients.Init(nil, nil, nil) @@ -204,7 +216,7 @@ func TestClientsAddExisting(t *testing.T) { testIP := "1.2.3.4" // add a client - c = Client{ + c = &Client{ IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"}, Name: "client1", } @@ -233,7 +245,7 @@ func TestClientsAddExisting(t *testing.T) { assert.Nil(t, err) // add a new client with the same IP as for a client with MAC - c = Client{ + c = &Client{ IDs: []string{testIP}, Name: "client2", } @@ -242,7 +254,7 @@ func TestClientsAddExisting(t *testing.T) { assert.Nil(t, err) // add a new client with the IP from the client1's IP range - c = Client{ + c = &Client{ IDs: []string{"2.2.2.2"}, Name: "client3", } @@ -258,7 +270,7 @@ func TestClientsCustomUpstream(t *testing.T) { clients.Init(nil, nil, nil) // add client with upstreams - client := Client{ + c := &Client{ IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"}, Name: "client1", Upstreams: []string{ @@ -266,7 +278,7 @@ func TestClientsCustomUpstream(t *testing.T) { "[/example.org/]8.8.8.8", }, } - ok, err := clients.Add(client) + ok, err := clients.Add(c) assert.Nil(t, err) assert.True(t, ok) @@ -275,6 +287,6 @@ func TestClientsCustomUpstream(t *testing.T) { config = clients.FindUpstreams("1.1.1.1") assert.NotNil(t, config) - assert.Len(t, config.Upstreams, 1) - assert.Len(t, config.DomainReservedUpstreams, 1) + assert.Equal(t, 1, len(config.Upstreams)) + assert.Equal(t, 1, len(config.DomainReservedUpstreams)) } diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index e39f4767..42d7fa20 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -158,7 +158,7 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. } c := jsonToClient(cj) - ok, err := clients.Add(*c) + ok, err := clients.Add(c) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return @@ -216,7 +216,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht } c := jsonToClient(dj.Data) - err = clients.Update(dj.Name, *c) + err = clients.Update(dj.Name, c) if err != nil { httpError(w, http.StatusBadRequest, "%s", err) return @@ -229,28 +229,28 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() data := []map[string]clientJSON{} - for i := 0; ; i++ { - ipStr := q.Get(fmt.Sprintf("ip%d", i)) - ip := net.ParseIP(ipStr) - if ip == nil { + for i := 0; i < len(q); i++ { + idStr := q.Get(fmt.Sprintf("ip%d", i)) + if idStr == "" { break } - c, ok := clients.Find(ip) + ip := net.ParseIP(idStr) + c, ok := clients.Find(idStr) var cj clientJSON if !ok { var found bool - cj, found = clients.findTemporary(ip) + cj, found = clients.findTemporary(ip, idStr) if !found { continue } } else { - cj = clientToJSON(&c) + cj = clientToJSON(c) cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip) } data = append(data, map[string]clientJSON{ - ipStr: cj, + idStr: cj, }) } @@ -263,10 +263,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http // findTemporary looks up the IP in temporary storages, like autohosts or // blocklists. -func (clients *clientsContainer) findTemporary(ip net.IP) (cj clientJSON, found bool) { - ipStr := ip.String() - ch, ok := clients.FindAutoClient(ip) - if !ok { +func (clients *clientsContainer) findTemporary(ip net.IP, idStr string) (cj clientJSON, found bool) { + ch, ok := clients.FindAutoClient(idStr) + if !ok && ip != nil { // It is still possible that the IP used to be in the runtime // clients list, but then the server was reloaded. So, check // the DNS server's blocked IP list. @@ -278,7 +277,7 @@ func (clients *clientsContainer) findTemporary(ip net.IP) (cj clientJSON, found } cj = clientJSON{ - IDs: []string{ipStr}, + IDs: []string{idStr}, Disallowed: disallowed, DisallowedRule: rule, } @@ -286,8 +285,10 @@ func (clients *clientsContainer) findTemporary(ip net.IP) (cj clientJSON, found return cj, true } - cj = clientHostToJSON(ipStr, ch) - cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip) + cj = clientHostToJSON(idStr, ch) + if ip != nil { + cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip) + } return cj, true } diff --git a/internal/home/config.go b/internal/home/config.go index afe96812..65a9401c 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -1,6 +1,7 @@ package home import ( + "errors" "io/ioutil" "net" "os" @@ -188,7 +189,7 @@ func initConfig() { func (c *configuration) getConfigFilename() string { configFile, err := filepath.EvalSymlinks(Context.configFilename) if err != nil { - if !os.IsNotExist(err) { + if !errors.Is(err, os.ErrNotExist) { log.Error("unexpected error while config file path evaluation: %s", err) } configFile = Context.configFilename diff --git a/internal/home/dns.go b/internal/home/dns.go index 82b844a4..3640fd03 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -3,8 +3,10 @@ package home import ( "fmt" "net" + "net/url" "os" "path/filepath" + "strconv" "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/dnsfilter" @@ -58,7 +60,7 @@ func initDNSServer() error { if config.DNS.BindHost.IsUnspecified() { bindhost = net.IPv4(127, 0, 0, 1) } - filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) + filterConf.ResolverAddress = net.JoinHostPort(bindhost.String(), strconv.Itoa(config.DNS.Port)) filterConf.AutoHosts = &Context.autoHosts filterConf.ConfigModified = onConfigModified filterConf.HTTPRegister = httpRegister @@ -126,6 +128,7 @@ func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) { Context.tls.WriteDiskConfig(&tlsConf) if tlsConf.Enabled { newconfig.TLSConfig = tlsConf.TLSConfig + newconfig.TLSConfig.ServerName = tlsConf.ServerName if tlsConf.PortDNSOverTLS != 0 { newconfig.TLSListenAddr = &net.TCPAddr{ @@ -207,36 +210,42 @@ type dnsEncryption struct { quic string } -func getDNSEncryption() dnsEncryption { - dnsEncryption := dnsEncryption{} - +func getDNSEncryption() (de dnsEncryption) { tlsConf := tlsConfigSettings{} Context.tls.WriteDiskConfig(&tlsConf) if tlsConf.Enabled && len(tlsConf.ServerName) != 0 { - + hostname := tlsConf.ServerName if tlsConf.PortHTTPS != 0 { - addr := tlsConf.ServerName + addr := hostname if tlsConf.PortHTTPS != 443 { - addr = fmt.Sprintf("%s:%d", addr, tlsConf.PortHTTPS) + addr = net.JoinHostPort(addr, strconv.Itoa(tlsConf.PortHTTPS)) } - addr = fmt.Sprintf("https://%s/dns-query", addr) - dnsEncryption.https = addr + + de.https = (&url.URL{ + Scheme: "https", + Host: addr, + Path: "/dns-query", + }).String() } if tlsConf.PortDNSOverTLS != 0 { - addr := fmt.Sprintf("tls://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverTLS) - dnsEncryption.tls = addr + de.tls = (&url.URL{ + Scheme: "tls", + Host: net.JoinHostPort(hostname, strconv.Itoa(tlsConf.PortDNSOverTLS)), + }).String() } if tlsConf.PortDNSOverQUIC != 0 { - addr := fmt.Sprintf("quic://%s:%d", tlsConf.ServerName, tlsConf.PortDNSOverQUIC) - dnsEncryption.quic = addr + de.quic = (&url.URL{ + Scheme: "quic", + Host: net.JoinHostPort(hostname, strconv.Itoa(int(tlsConf.PortDNSOverQUIC))), + }).String() } } - return dnsEncryption + return de } // Get the list of DNS addresses the server is listening on @@ -273,21 +282,26 @@ func getDNSAddresses() []string { return dnsAddresses } -// If a client has his own settings, apply them -func applyAdditionalFiltering(clientAddr net.IP, setts *dnsfilter.RequestFilteringSettings) { +// applyAdditionalFiltering adds additional client information and settings if +// the client has them. +func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.RequestFilteringSettings) { Context.dnsFilter.ApplyBlockedServices(setts, nil, true) if clientAddr == nil { return } + setts.ClientIP = clientAddr - c, ok := Context.clients.Find(clientAddr) + c, ok := Context.clients.Find(clientID) if !ok { - return + c, ok = Context.clients.Find(clientAddr.String()) + if !ok { + return + } } - log.Debug("Using settings for client %s with IP %s", c.Name, clientAddr) + log.Debug("using settings for client %s with ip %s and id %q", c.Name, clientAddr, clientID) if c.UseOwnBlockedServices { Context.dnsFilter.ApplyBlockedServices(setts, c.BlockedServices, false) diff --git a/internal/home/home.go b/internal/home/home.go index 3c068912..1b6312c8 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -5,6 +5,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "io/ioutil" "net" @@ -434,6 +435,10 @@ func initWorkingDir(args options) { } else { Context.workDir = filepath.Dir(execPath) } + + if workDir, err := filepath.EvalSymlinks(Context.workDir); err == nil { + Context.workDir = workDir + } } // configureLogger configures logger level and output @@ -624,7 +629,7 @@ func detectFirstRun() bool { configfile = filepath.Join(Context.workDir, Context.configFilename) } _, err := os.Stat(configfile) - return os.IsNotExist(err) + return errors.Is(err, os.ErrNotExist) } // Connect to a remote server resolving hostname using our own DNS server diff --git a/internal/home/mobileconfig.go b/internal/home/mobileconfig.go index 2f06329e..3953e2e6 100644 --- a/internal/home/mobileconfig.go +++ b/internal/home/mobileconfig.go @@ -4,7 +4,10 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" + "path" + "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/golibs/log" uuid "github.com/satori/go.uuid" "howett.net/plist" @@ -14,6 +17,7 @@ type dnsSettings struct { DNSProtocol string ServerURL string `plist:",omitempty"` ServerName string `plist:",omitempty"` + clientID string } type payloadContent struct { @@ -23,19 +27,19 @@ type payloadContent struct { PayloadIdentifier string PayloadType string PayloadUUID string - PayloadVersion int DNSSettings dnsSettings + PayloadVersion int } type mobileConfig struct { - PayloadContent []payloadContent PayloadDescription string PayloadDisplayName string PayloadIdentifier string - PayloadRemovalDisallowed bool PayloadType string PayloadUUID string + PayloadContent []payloadContent PayloadVersion int + PayloadRemovalDisallowed bool } func genUUIDv4() string { @@ -48,22 +52,35 @@ const ( ) func getMobileConfig(d dnsSettings) ([]byte, error) { - var name string + var dspName string switch d.DNSProtocol { case dnsProtoHTTPS: - name = fmt.Sprintf("%s DoH", d.ServerName) - d.ServerURL = fmt.Sprintf("https://%s/dns-query", d.ServerName) + dspName = fmt.Sprintf("%s DoH", d.ServerName) + + u := &url.URL{ + Scheme: "https", + Host: d.ServerName, + Path: "/dns-query", + } + if d.clientID != "" { + u.Path = path.Join(u.Path, d.clientID) + } + + d.ServerURL = u.String() case dnsProtoTLS: - name = fmt.Sprintf("%s DoT", d.ServerName) + dspName = fmt.Sprintf("%s DoT", d.ServerName) + if d.clientID != "" { + d.ServerName = d.clientID + "." + d.ServerName + } default: return nil, fmt.Errorf("bad dns protocol %q", d.DNSProtocol) } data := mobileConfig{ PayloadContent: []payloadContent{{ - Name: name, + Name: dspName, PayloadDescription: "Configures device to use AdGuard Home", - PayloadDisplayName: name, + PayloadDisplayName: dspName, PayloadIdentifier: fmt.Sprintf("com.apple.dnsSettings.managed.%s", genUUIDv4()), PayloadType: "com.apple.dnsSettings.managed", PayloadUUID: genUUIDv4(), @@ -71,7 +88,7 @@ func getMobileConfig(d dnsSettings) ([]byte, error) { DNSSettings: d, }}, PayloadDescription: "Adds AdGuard Home to Big Sur and iOS 14 or newer systems", - PayloadDisplayName: name, + PayloadDisplayName: dspName, PayloadIdentifier: genUUIDv4(), PayloadRemovalDisallowed: false, PayloadType: "Configuration", @@ -83,7 +100,10 @@ func getMobileConfig(d dnsSettings) ([]byte, error) { } func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) { - host := r.URL.Query().Get("host") + var err error + + q := r.URL.Query() + host := q.Get("host") if host == "" { host = Context.tls.conf.ServerName } @@ -92,7 +112,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) { w.WriteHeader(http.StatusInternalServerError) const msg = "no host in query parameters and no server_name" - err := json.NewEncoder(w).Encode(&jsonError{ + err = json.NewEncoder(w).Encode(&jsonError{ Message: msg, }) if err != nil { @@ -102,9 +122,25 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) { return } + clientID := q.Get("client_id") + err = dnsforward.ValidateClientID(clientID) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + + err = json.NewEncoder(w).Encode(&jsonError{ + Message: err.Error(), + }) + if err != nil { + log.Debug("writing 400 json response: %s", err) + } + + return + } + d := dnsSettings{ DNSProtocol: dnsp, ServerName: host, + clientID: clientID, } mobileconfig, err := getMobileConfig(d) @@ -115,6 +151,7 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) { } w.Header().Set("Content-Type", "application/xml") + _, _ = w.Write(mobileconfig) } diff --git a/internal/home/mobileconfig_test.go b/internal/home/mobileconfig_test.go index 1025fe93..9dcafc97 100644 --- a/internal/home/mobileconfig_test.go +++ b/internal/home/mobileconfig_test.go @@ -73,6 +73,27 @@ func TestHandleMobileConfigDOH(t *testing.T) { handleMobileConfigDOH(w, r) assert.Equal(t, http.StatusInternalServerError, w.Code) }) + + t.Run("client_id", func(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/doh.mobileconfig?host=example.org&client_id=cli42", nil) + assert.Nil(t, err) + + w := httptest.NewRecorder() + + handleMobileConfigDOH(w, r) + assert.Equal(t, http.StatusOK, w.Code) + + var mc mobileConfig + _, err = plist.Unmarshal(w.Body.Bytes(), &mc) + assert.Nil(t, err) + + if assert.Len(t, mc.PayloadContent, 1) { + assert.Equal(t, "example.org DoH", mc.PayloadContent[0].Name) + assert.Equal(t, "example.org DoH", mc.PayloadContent[0].PayloadDisplayName) + assert.Equal(t, "example.org", mc.PayloadContent[0].DNSSettings.ServerName) + assert.Equal(t, "https://example.org/dns-query/cli42", mc.PayloadContent[0].DNSSettings.ServerURL) + } + }) } func TestHandleMobileConfigDOT(t *testing.T) { @@ -137,4 +158,24 @@ func TestHandleMobileConfigDOT(t *testing.T) { handleMobileConfigDOT(w, r) assert.Equal(t, http.StatusInternalServerError, w.Code) }) + + t.Run("client_id", func(t *testing.T) { + r, err := http.NewRequest(http.MethodGet, "https://example.com:12345/apple/dot.mobileconfig?host=example.org&client_id=cli42", nil) + assert.Nil(t, err) + + w := httptest.NewRecorder() + + handleMobileConfigDOT(w, r) + assert.Equal(t, http.StatusOK, w.Code) + + var mc mobileConfig + _, err = plist.Unmarshal(w.Body.Bytes(), &mc) + assert.Nil(t, err) + + if assert.Len(t, mc.PayloadContent, 1) { + assert.Equal(t, "example.org DoT", mc.PayloadContent[0].Name) + assert.Equal(t, "example.org DoT", mc.PayloadContent[0].PayloadDisplayName) + assert.Equal(t, "cli42.example.org", mc.PayloadContent[0].DNSSettings.ServerName) + } + }) } diff --git a/internal/home/rdns.go b/internal/home/rdns.go index c71f3822..dad75e44 100644 --- a/internal/home/rdns.go +++ b/internal/home/rdns.go @@ -57,7 +57,8 @@ func (r *RDNS) Begin(ip net.IP) { binary.BigEndian.PutUint64(expire, now+ttl) _ = r.ipAddrs.Set(ip, expire) - if r.clients.Exists(ip, ClientSourceRDNS) { + id := ip.String() + if r.clients.Exists(id, ClientSourceRDNS) { return } diff --git a/internal/home/whois.go b/internal/home/whois.go index 1d849673..26c674dc 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -11,7 +11,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghio" "github.com/AdguardTeam/AdGuardHome/internal/util" - "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" ) @@ -25,14 +24,16 @@ const ( // Whois - module context type Whois struct { - clients *clientsContainer - ipChan chan net.IP - timeoutMsec uint + clients *clientsContainer + ipChan chan net.IP // Contains IP addresses of clients // An active IP address is resolved once again after it expires. // If IP address couldn't be resolved, it stays here for some time to prevent further attempts to resolve the same IP. ipAddrs cache.Cache + + // TODO(a.garipov): Rewrite to use time.Duration. Like, seriously, why? + timeoutMsec uint } // initWhois creates the Whois module context. @@ -244,6 +245,7 @@ func (w *Whois) workerLoop() { continue } - w.clients.SetWhoisInfo(ip, info) + id := ip.String() + w.clients.SetWhoisInfo(id, info) } } diff --git a/internal/querylog/decode.go b/internal/querylog/decode.go index ad0948ba..3e9a5f33 100644 --- a/internal/querylog/decode.go +++ b/internal/querylog/decode.go @@ -17,6 +17,16 @@ import ( type logEntryHandler (func(t json.Token, ent *logEntry) error) var logEntryHandlers = map[string]logEntryHandler{ + "CID": func(t json.Token, ent *logEntry) error { + v, ok := t.(string) + if !ok { + return nil + } + + ent.ClientID = v + + return nil + }, "IP": func(t json.Token, ent *logEntry) error { v, ok := t.(string) if !ok { diff --git a/internal/querylog/decode_test.go b/internal/querylog/decode_test.go index 40052fea..58ca5851 100644 --- a/internal/querylog/decode_test.go +++ b/internal/querylog/decode_test.go @@ -25,6 +25,7 @@ func TestDecodeLogEntry(t *testing.T) { t.Run("success", func(t *testing.T) { const ansStr = `Qz+BgAABAAEAAAAAAmFuBnlhbmRleAJydQAAAQABwAwAAQABAAAACgAEAAAAAA==` const data = `{"IP":"127.0.0.1",` + + `"CID":"cli42",` + `"T":"2020-11-25T18:55:56.519796+03:00",` + `"QH":"an.yandex.ru",` + `"QT":"A",` + @@ -52,6 +53,7 @@ func TestDecodeLogEntry(t *testing.T) { QHost: "an.yandex.ru", QType: "A", QClass: "IN", + ClientID: "cli42", ClientProto: "", Answer: ans, Result: dnsfilter.Result{ diff --git a/internal/querylog/json.go b/internal/querylog/json.go index 152f6ce2..7bfd98c4 100644 --- a/internal/querylog/json.go +++ b/internal/querylog/json.go @@ -79,6 +79,10 @@ func (l *queryLog) logEntryToJSONEntry(entry *logEntry) (jsonEntry jobject) { }, } + if entry.ClientID != "" { + jsonEntry["client_id"] = entry.ClientID + } + if msg != nil { jsonEntry["status"] = dns.RcodeToString[msg.Rcode] diff --git a/internal/querylog/qlog.go b/internal/querylog/qlog.go index 30054a92..41ce9823 100644 --- a/internal/querylog/qlog.go +++ b/internal/querylog/qlog.go @@ -2,6 +2,7 @@ package querylog import ( + "errors" "fmt" "net" "os" @@ -37,10 +38,11 @@ type ClientProto string // Client protocol names. const ( - ClientProtoDOH ClientProto = "doh" - ClientProtoDOQ ClientProto = "doq" - ClientProtoDOT ClientProto = "dot" - ClientProtoPlain ClientProto = "" + ClientProtoDOH ClientProto = "doh" + ClientProtoDOQ ClientProto = "doq" + ClientProtoDOT ClientProto = "dot" + ClientProtoDNSCrypt ClientProto = "dnscrypt" + ClientProtoPlain ClientProto = "" ) // NewClientProto validates that the client protocol name is valid and returns @@ -68,6 +70,7 @@ type logEntry struct { QType string `json:"QT"` QClass string `json:"QC"` + ClientID string `json:"CID,omitempty"` ClientProto ClientProto `json:"CP"` Answer []byte `json:",omitempty"` // sometimes empty answers happen like binerdunt.top or rev2.globalrootservers.net @@ -119,14 +122,15 @@ func (l *queryLog) clear() { l.flushPending = false l.bufferLock.Unlock() - err := os.Remove(l.logFile + ".1") - if err != nil && !os.IsNotExist(err) { - log.Error("file remove: %s: %s", l.logFile+".1", err) + oldLogFile := l.logFile + ".1" + err := os.Remove(oldLogFile) + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Error("removing old log file %q: %s", oldLogFile, err) } err = os.Remove(l.logFile) - if err != nil && !os.IsNotExist(err) { - log.Error("file remove: %s: %s", l.logFile, err) + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Error("removing log file %q: %s", l.logFile, err) } log.Debug("Query log: cleared") @@ -154,6 +158,7 @@ func (l *queryLog) Add(params AddParams) { Result: *params.Result, Elapsed: params.Elapsed, Upstream: params.Upstream, + ClientID: params.ClientID, ClientProto: params.ClientProto, } q := params.Question.Question[0] diff --git a/internal/querylog/qlogfile.go b/internal/querylog/qlogfile.go index 69a42ed2..3aa56f6f 100644 --- a/internal/querylog/qlogfile.go +++ b/internal/querylog/qlogfile.go @@ -251,7 +251,7 @@ func (q *QLogFile) readNextLine(position int64) (string, int64, error) { // the goal is to read a chunk of file that includes the line with the specified position. func (q *QLogFile) initBuffer(position int64) error { q.bufferStart = int64(0) - if (position - bufferSize) > 0 { + if position > bufferSize { q.bufferStart = position - bufferSize } @@ -264,12 +264,10 @@ func (q *QLogFile) initBuffer(position int64) error { if q.buffer == nil { q.buffer = make([]byte, bufferSize) } - q.bufferLen, err = q.file.Read(q.buffer) - if err != nil { - return err - } - return nil + q.bufferLen, err = q.file.Read(q.buffer) + + return err } // readProbeLine reads a line that includes the specified position @@ -280,7 +278,7 @@ func (q *QLogFile) readProbeLine(position int64) (string, int64, int64, error) { // In order to do this, we'll define the boundaries seekPosition := int64(0) relativePos := position // position relative to the buffer we're going to read - if (position - maxEntrySize) > 0 { + if position > maxEntrySize { seekPosition = position - maxEntrySize relativePos = maxEntrySize } diff --git a/internal/querylog/querylog.go b/internal/querylog/querylog.go index 6a6e0a6c..98b8959d 100644 --- a/internal/querylog/querylog.go +++ b/internal/querylog/querylog.go @@ -46,6 +46,7 @@ type AddParams struct { OrigAnswer *dns.Msg // The response from an upstream server (optional) Result *dnsfilter.Result // Filtering result (optional) Elapsed time.Duration // Time spent for processing the request + ClientID string ClientIP net.IP Upstream string // Upstream server URL ClientProto ClientProto diff --git a/internal/querylog/querylogfile.go b/internal/querylog/querylogfile.go index c6d48235..a0fd165b 100644 --- a/internal/querylog/querylogfile.go +++ b/internal/querylog/querylogfile.go @@ -3,6 +3,7 @@ package querylog import ( "bytes" "encoding/json" + "errors" "os" "time" @@ -87,18 +88,19 @@ func (l *queryLog) rotate() error { from := l.logFile to := l.logFile + ".1" - if _, err := os.Stat(from); os.IsNotExist(err) { - // do nothing, file doesn't exist - return nil - } - err := os.Rename(from, to) if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + log.Error("querylog: failed to rename file: %s", err) + return err } log.Debug("querylog: renamed %s -> %s", from, to) + return nil } diff --git a/internal/querylog/searchcriteria.go b/internal/querylog/searchcriteria.go index 1c2b26e3..4db39bf9 100644 --- a/internal/querylog/searchcriteria.go +++ b/internal/querylog/searchcriteria.go @@ -9,8 +9,13 @@ import ( type criteriaType int const ( - ctDomainOrClient criteriaType = iota // domain name or client IP address - ctFilteringStatus // filtering status + // ctDomainOrClient is for searching by the domain name, the client's IP + // address, or the clinet's ID. + ctDomainOrClient criteriaType = iota + // ctFilteringStatus is for searching by the filtering status. + // + // See (*searchCriteria).ctFilteringStatusCase for details. + ctFilteringStatus ) const ( @@ -38,9 +43,9 @@ var filteringStatusValues = []string{ // searchCriteria - every search request may contain a list of different search criteria // we use each of them to match the query type searchCriteria struct { + value string // search criteria value criteriaType criteriaType // type of the criteria strict bool // should we strictly match (equality) or not (indexOf) - value string // search criteria value } // quickMatch - quickly checks if the log entry matches this search criteria @@ -51,7 +56,8 @@ func (c *searchCriteria) quickMatch(line string) bool { switch c.criteriaType { case ctDomainOrClient: return c.quickMatchJSONValue(line, "QH") || - c.quickMatchJSONValue(line, "IP") + c.quickMatchJSONValue(line, "IP") || + c.quickMatchJSONValue(line, "ID") default: return true } @@ -89,13 +95,14 @@ func (c *searchCriteria) match(entry *logEntry) bool { } func (c *searchCriteria) ctDomainOrClientCase(entry *logEntry) bool { + clientID := strings.ToLower(entry.ClientID) qhost := strings.ToLower(entry.QHost) searchVal := strings.ToLower(c.value) - if c.strict && qhost == searchVal { + if c.strict && (qhost == searchVal || clientID == searchVal) { return true } - if !c.strict && strings.Contains(qhost, searchVal) { + if !c.strict && (strings.Contains(qhost, searchVal) || strings.Contains(clientID, searchVal)) { return true } diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 1addbebd..7ed1d320 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -76,10 +76,14 @@ const ( rLast ) -// Entry - data to add +// Entry is a statistics data entry. type Entry struct { + // Clients is the client's primary ID. + // + // TODO(a.garipov): Make this a {net.IP, string} enum? + Client string + Domain string - Client net.IP Result Result Time uint32 // processing time (msec) } diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index b4be4db0..c4fbe191 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -39,13 +39,13 @@ func TestStats(t *testing.T) { e := Entry{} e.Domain = "domain" - e.Client = net.IP{127, 0, 0, 1} + e.Client = "127.0.0.1" e.Result = RFiltered e.Time = 123456 s.Update(e) e.Domain = "domain" - e.Client = net.IP{127, 0, 0, 1} + e.Client = "127.0.0.1" e.Result = RNotFiltered e.Time = 123456 s.Update(e) @@ -113,9 +113,10 @@ func TestLargeNumbers(t *testing.T) { } for i := 0; i != n; i++ { e.Domain = fmt.Sprintf("domain%d", i) - e.Client = net.IP{127, 0, 0, 1} - e.Client[2] = byte((i & 0xff00) >> 8) - e.Client[3] = byte(i & 0xff) + ip := net.IP{127, 0, 0, 1} + ip[2] = byte((i & 0xff00) >> 8) + ip[3] = byte(i & 0xff) + e.Client = ip.String() e.Result = RNotFiltered e.Time = 123456 s.Update(e) diff --git a/internal/stats/unit.go b/internal/stats/unit.go index 962fe85b..6f31cd5e 100644 --- a/internal/stats/unit.go +++ b/internal/stats/unit.go @@ -223,6 +223,7 @@ func (s *statsCtx) periodicFlush() { s.unitLock.Lock() ptr := s.unit s.unitLock.Unlock() + if ptr == nil { break } @@ -230,6 +231,7 @@ func (s *statsCtx) periodicFlush() { id := s.conf.UnitID() if ptr.id == id { time.Sleep(time.Second) + continue } @@ -243,6 +245,7 @@ func (s *statsCtx) periodicFlush() { if tx == nil { continue } + ok1 := s.flushUnitToDB(tx, u.id, udb) ok2 := s.deleteUnit(tx, id-s.conf.limit) if ok1 || ok2 { @@ -251,6 +254,7 @@ func (s *statsCtx) periodicFlush() { _ = tx.Rollback() } } + log.Tracef("periodicFlush() exited") } @@ -265,7 +269,7 @@ func (s *statsCtx) deleteUnit(tx *bolt.Tx, id uint32) bool { return true } -func convertMapToArray(m map[string]uint64, max int) []countPair { +func convertMapToSlice(m map[string]uint64, max int) []countPair { a := []countPair{} for k, v := range m { pair := countPair{} @@ -283,7 +287,7 @@ func convertMapToArray(m map[string]uint64, max int) []countPair { return a[:max] } -func convertArrayToMap(a []countPair) map[string]uint64 { +func convertSliceToMap(a []countPair) map[string]uint64 { m := map[string]uint64{} for _, it := range a { m[it.Name] = it.Count @@ -301,9 +305,9 @@ func serialize(u *unit) *unitDB { udb.TimeAvg = uint32(u.timeSum / u.nTotal) } - udb.Domains = convertMapToArray(u.domains, maxDomains) - udb.BlockedDomains = convertMapToArray(u.blockedDomains, maxDomains) - udb.Clients = convertMapToArray(u.clients, maxClients) + udb.Domains = convertMapToSlice(u.domains, maxDomains) + udb.BlockedDomains = convertMapToSlice(u.blockedDomains, maxDomains) + udb.Clients = convertMapToSlice(u.clients, maxClients) return &udb } @@ -319,9 +323,9 @@ func deserialize(u *unit, udb *unitDB) { u.nResult[i] = udb.NResult[i] } - u.domains = convertArrayToMap(udb.Domains) - u.blockedDomains = convertArrayToMap(udb.BlockedDomains) - u.clients = convertArrayToMap(udb.Clients) + u.domains = convertSliceToMap(udb.Domains) + u.blockedDomains = convertSliceToMap(udb.BlockedDomains) + u.clients = convertSliceToMap(udb.Clients) u.timeSum = uint64(udb.TimeAvg) * u.nTotal } @@ -372,7 +376,7 @@ func (s *statsCtx) loadUnitFromDB(tx *bolt.Tx, id uint32) *unitDB { return &udb } -func convertTopArray(a []countPair) []map[string]uint64 { +func convertTopSlice(a []countPair) []map[string]uint64 { m := []map[string]uint64{} for _, it := range a { ent := map[string]uint64{} @@ -461,13 +465,20 @@ func (s *statsCtx) getClientIP(ip net.IP) (clientIP net.IP) { func (s *statsCtx) Update(e Entry) { if e.Result == 0 || e.Result >= rLast || - len(e.Domain) == 0 || - !(len(e.Client) == 4 || len(e.Client) == 16) { + e.Domain == "" || + e.Client == "" { return } - client := s.getClientIP(e.Client) + + clientID := e.Client + if ip := net.ParseIP(clientID); ip != nil { + ip = s.getClientIP(ip) + clientID = ip.String() + } s.unitLock.Lock() + defer s.unitLock.Unlock() + u := s.unit u.nResult[e.Result]++ @@ -478,10 +489,9 @@ func (s *statsCtx) Update(e Entry) { u.blockedDomains[e.Domain]++ } - u.clients[client.String()]++ + u.clients[clientID]++ u.timeSum += uint64(e.Time) u.nTotal++ - s.unitLock.Unlock() } func (s *statsCtx) loadUnits(limit uint32) ([]*unitDB, uint32) { @@ -594,8 +604,8 @@ func (s *statsCtx) getData() (statsResponse, bool) { m[it.Name] += it.Count } } - a2 := convertMapToArray(m, max) - return convertTopArray(a2) + a2 := convertMapToSlice(m, max) + return convertTopSlice(a2) } dnsQueries := statsCollector(func(u *unitDB) (num uint64) { return u.NTotal }) @@ -661,7 +671,7 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP { m[it.Name] += it.Count } } - a := convertMapToArray(m, int(maxCount)) + a := convertMapToSlice(m, int(maxCount)) d := []net.IP{} for _, it := range a { d = append(d, net.ParseIP(it.Name)) diff --git a/openapi/CHANGELOG.md b/openapi/CHANGELOG.md index 225d2ead..fab1099e 100644 --- a/openapi/CHANGELOG.md +++ b/openapi/CHANGELOG.md @@ -4,6 +4,11 @@ ## v0.105: API changes +### New `"dnscrypt"` `"client_proto"` value in `GET /querylog` response + +* The field `"client_proto"` can now have the value `"dnscrypt"` when the + request was sent over a DNSCrypt connection. + ### New `"reason"` in `GET /filtering/check_host` and `GET /querylog` * The new `RewriteRule` reason is added to `GET /filtering/check_host` and diff --git a/openapi/openapi.yaml b/openapi/openapi.yaml index 490fd65d..366222af 100644 --- a/openapi/openapi.yaml +++ b/openapi/openapi.yaml @@ -794,11 +794,17 @@ 'tags': - 'clients' 'operationId': 'clientsFind' - 'summary': 'Get information about selected clients by their IP address' + 'summary': > + Get information about clients by their IP addresses or client IDs. 'parameters': - 'name': 'ip0' 'in': 'query' - 'description': 'Filter by IP address' + 'description': > + Filter by IP address or client IDs. Parameters with names `ip1`, + `ip2`, and so on are also accepted and interpreted as "ip0 OR ip1 OR + ip2". + + TODO(a.garipov): Replace with a better query API. 'schema': 'type': 'string' 'responses': @@ -1109,6 +1115,13 @@ 'name': 'host' 'schema': 'type': 'string' + - 'description': > + Client ID. + 'example': 'client-1' + 'in': 'query' + 'name': 'client_id' + 'schema': + 'type': 'string' 'responses': '200': 'description': 'DNS over HTTPS plist file.' @@ -1136,6 +1149,13 @@ 'name': 'host' 'schema': 'type': 'string' + - 'description': > + Client ID. + 'example': 'client-1' + 'in': 'query' + 'name': 'client_id' + 'schema': + 'type': 'string' 'responses': '200': 'description': 'DNS over TLS plist file' @@ -1781,13 +1801,21 @@ 'answer_dnssec': 'type': 'boolean' 'client': - 'type': 'string' + 'description': > + The client's IP address. 'example': '192.168.0.1' + 'type': 'string' + 'client_id': + 'description': > + The client ID, if provided in DOH, DOQ, or DOT. + 'example': 'cli123' + 'type': 'string' 'client_proto': 'enum': - 'dot' - 'doh' - 'doq' + - 'dnscrypt' - '' 'elapsedMs': 'type': 'string' @@ -2094,7 +2122,7 @@ 'type': 'string' 'Client': 'type': 'object' - 'description': 'Client information' + 'description': 'Client information.' 'properties': 'name': 'type': 'string' @@ -2102,7 +2130,7 @@ 'example': 'localhost' 'ids': 'type': 'array' - 'description': 'IP, CIDR or MAC address' + 'description': 'IP, CIDR, MAC, or client ID.' 'items': 'type': 'string' 'use_global_settings': @@ -2157,9 +2185,38 @@ 'type': 'string' 'ClientsFindResponse': 'type': 'array' - 'description': 'Response to clients find operation' + 'description': 'Client search results.' 'items': '$ref': '#/components/schemas/ClientsFindEntry' + 'example': + - 'cli42': + 'name': 'Client 42' + 'ids': ['cli42'] + 'use_global_settings': true + 'filtering_enabled': true + 'parental_enabled': true + 'safebrowsing_enabled': true + 'safesearch_enabled': true + 'use_global_blocked_services': true + 'blocked_services': null + 'upstreams': null + 'whois_info': null + 'disallowed': false + 'disallowed_rule': '' + - '1.2.3.4': + 'name': 'Client 1-2-3-4' + 'ids': ['1.2.3.4'] + 'use_global_settings': true + 'filtering_enabled': true + 'parental_enabled': true + 'safebrowsing_enabled': true + 'safesearch_enabled': true + 'use_global_blocked_services': true + 'blocked_services': null + 'upstreams': null + 'whois_info': null + 'disallowed': false + 'disallowed_rule': '' 'AccessListResponse': '$ref': '#/components/schemas/AccessList' 'AccessSetRequest': @@ -2187,10 +2244,9 @@ 'type': 'object' 'additionalProperties': '$ref': '#/components/schemas/ClientFindSubEntry' - 'example': - '1.2.3.4': 'test' 'ClientFindSubEntry': 'type': 'object' + 'description': 'Client information.' 'properties': 'name': 'type': 'string' @@ -2198,7 +2254,7 @@ 'example': 'localhost' 'ids': 'type': 'array' - 'description': 'IP, CIDR or MAC address' + 'description': 'IP, CIDR, MAC, or client ID.' 'items': 'type': 'string' 'use_global_settings': diff --git a/scripts/make/go-build.sh b/scripts/make/go-build.sh index f9702c35..f8520f21 100644 --- a/scripts/make/go-build.sh +++ b/scripts/make/go-build.sh @@ -54,7 +54,7 @@ esac # TODO(a.garipov): Additional validation? version="$VERSION" -# Set the linker flags accordingly: set the realease channel and the +# Set the linker flags accordingly: set the release channel and the # current version as well as goarm and gomips variable values, if the # variables are set and are not empty. readonly version_pkg='github.com/AdguardTeam/AdGuardHome/internal/version'