Merge branch 'master' into 4993-alt-svc

This commit is contained in:
Ainar Garipov 2022-10-10 17:59:12 +03:00
commit e4a42bf233
74 changed files with 2509 additions and 1022 deletions

View File

@ -1,7 +1,7 @@
'name': 'build' 'name': 'build'
'env': 'env':
'GO_VERSION': '1.18.6' 'GO_VERSION': '1.18.7'
'NODE_VERSION': '14' 'NODE_VERSION': '14'
'on': 'on':

View File

@ -1,7 +1,7 @@
'name': 'lint' 'name': 'lint'
'env': 'env':
'GO_VERSION': '1.18.6' 'GO_VERSION': '1.18.7'
'on': 'on':
'push': 'push':

View File

@ -18,18 +18,43 @@ and this project adheres to
## [v0.108.0] - TBA (APPROX.) ## [v0.108.0] - TBA (APPROX.)
--> -->
## Added
- The ability to put [ClientIDs][clientid] into DNS-over-HTTPS hostnames as
opposed to URL paths ([#3418]). Note that AdGuard Home checks the server name
only if the URL does not contain a ClientID.
[#3418]: https://github.com/AdguardTeam/AdGuardHome/issues/3418
[clientid]: https://github.com/AdguardTeam/AdGuardHome/wiki/Clients#clientid
<!-- <!--
## [v0.107.16] - 2022-11-02 (APPROX.) ## [v0.107.17] - 2022-11-02 (APPROX.)
See also the [v0.107.16 GitHub milestone][ms-v0.107.15]. See also the [v0.107.17 GitHub milestone][ms-v0.107.17].
[ms-v0.107.16]: https://github.com/AdguardTeam/AdGuardHome/milestone/52?closed=1 [ms-v0.107.17]: https://github.com/AdguardTeam/AdGuardHome/milestone/52?closed=1
--> -->
## [v0.107.16] - 2022-10-07
This is a security update. There is no GitHub milestone, since no GitHub issues
were resolved.
## Security
- Go version has been updated to prevent the possibility of exploiting the
CVE-2022-2879, CVE-2022-2880, and CVE-2022-41715 Go vulnerabilities fixed in
[Go 1.18.7][go-1.18.7].
[go-1.18.7]: https://groups.google.com/g/golang-announce/c/xtuG5faxtaU
## [v0.107.15] - 2022-10-03 ## [v0.107.15] - 2022-10-03
See also the [v0.107.15 GitHub milestone][ms-v0.107.15]. See also the [v0.107.15 GitHub milestone][ms-v0.107.15].
@ -52,7 +77,7 @@ experimental and may break or change in the future.
explicitly enabled by setting the new property `dns.serve_http3` in the explicitly enabled by setting the new property `dns.serve_http3` in the
configuration file to `true`. configuration file to `true`.
- DNS-over-HTTP upstreams can now upgrade to HTTP/3 if the new configuration - DNS-over-HTTP upstreams can now upgrade to HTTP/3 if the new configuration
file property `use_http3_upstreams` is set to `true`. file property `dns.use_http3_upstreams` is set to `true`.
- Upstreams with forced DNS-over-HTTP/3 and no fallback to prior HTTP versions - Upstreams with forced DNS-over-HTTP/3 and no fallback to prior HTTP versions
using the `h3://` scheme. using the `h3://` scheme.
@ -166,7 +191,7 @@ See also the [v0.107.12 GitHub milestone][ms-v0.107.12].
### Security ### Security
- Go version was updated to prevent the possibility of exploiting the - Go version has been updated to prevent the possibility of exploiting the
CVE-2022-27664 and CVE-2022-32190 Go vulnerabilities fixed in CVE-2022-27664 and CVE-2022-32190 Go vulnerabilities fixed in
[Go 1.18.6][go-1.18.6]. [Go 1.18.6][go-1.18.6].
@ -287,7 +312,7 @@ See also the [v0.107.9 GitHub milestone][ms-v0.107.9].
### Security ### Security
- Go version was updated to prevent the possibility of exploiting the - Go version has been updated to prevent the possibility of exploiting the
CVE-2022-32189 Go vulnerability fixed in [Go 1.18.5][go-1.18.5]. Go 1.17 CVE-2022-32189 Go vulnerability fixed in [Go 1.18.5][go-1.18.5]. Go 1.17
support has also been removed, as it has reached end of life and will not support has also been removed, as it has reached end of life and will not
receive security updates. receive security updates.
@ -330,7 +355,7 @@ See also the [v0.107.8 GitHub milestone][ms-v0.107.8].
### Security ### Security
- Go version was updated to prevent the possibility of exploiting the - Go version has been updated to prevent the possibility of exploiting the
CVE-2022-1705, CVE-2022-32148, CVE-2022-30631, and other Go vulnerabilities CVE-2022-1705, CVE-2022-32148, CVE-2022-30631, and other Go vulnerabilities
fixed in [Go 1.17.12][go-1.17.12]. fixed in [Go 1.17.12][go-1.17.12].
@ -366,7 +391,7 @@ See also the [v0.107.7 GitHub milestone][ms-v0.107.7].
### Security ### Security
- Go version was updated to prevent the possibility of exploiting the - Go version has been updated to prevent the possibility of exploiting the
[CVE-2022-29526], [CVE-2022-30634], [CVE-2022-30629], [CVE-2022-30580], and [CVE-2022-29526], [CVE-2022-30634], [CVE-2022-30629], [CVE-2022-30580], and
[CVE-2022-29804] Go vulnerabilities. [CVE-2022-29804] Go vulnerabilities.
- Enforced password strength policy ([#3503]). - Enforced password strength policy ([#3503]).
@ -523,7 +548,7 @@ See also the [v0.107.6 GitHub milestone][ms-v0.107.6].
### Security ### Security
- `User-Agent` HTTP header removed from outgoing DNS-over-HTTPS requests. - `User-Agent` HTTP header removed from outgoing DNS-over-HTTPS requests.
- Go version was updated to prevent the possibility of exploiting the - Go version has been updated to prevent the possibility of exploiting the
[CVE-2022-24675], [CVE-2022-27536], and [CVE-2022-28327] Go vulnerabilities. [CVE-2022-24675], [CVE-2022-27536], and [CVE-2022-28327] Go vulnerabilities.
### Added ### Added
@ -578,7 +603,7 @@ were resolved.
### Security ### Security
- Go version was updated to prevent the possibility of exploiting the - Go version has been updated to prevent the possibility of exploiting the
[CVE-2022-24921] Go vulnerability. [CVE-2022-24921] Go vulnerability.
[CVE-2022-24921]: https://www.cvedetails.com/cve/CVE-2022-24921 [CVE-2022-24921]: https://www.cvedetails.com/cve/CVE-2022-24921
@ -591,7 +616,7 @@ See also the [v0.107.4 GitHub milestone][ms-v0.107.4].
### Security ### Security
- Go version was updated to prevent the possibility of exploiting the - Go version has been updated to prevent the possibility of exploiting the
[CVE-2022-23806], [CVE-2022-23772], and [CVE-2022-23773] Go vulnerabilities. [CVE-2022-23806], [CVE-2022-23772], and [CVE-2022-23773] Go vulnerabilities.
### Fixed ### Fixed
@ -1328,11 +1353,12 @@ See also the [v0.104.2 GitHub milestone][ms-v0.104.2].
<!-- <!--
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.16...HEAD [Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.17...HEAD
[v0.107.16]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.15...v0.107.15 [v0.107.17]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.16...v0.107.17
--> -->
[Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.15...HEAD [Unreleased]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.16...HEAD
[v0.107.16]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.15...v0.107.16
[v0.107.15]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.14...v0.107.15 [v0.107.15]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.14...v0.107.15
[v0.107.14]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.13...v0.107.14 [v0.107.14]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.13...v0.107.14
[v0.107.13]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.12...v0.107.13 [v0.107.13]: https://github.com/AdguardTeam/AdGuardHome/compare/v0.107.12...v0.107.13

View File

@ -34,7 +34,7 @@ YARN_INSTALL_FLAGS = $(YARN_FLAGS) --network-timeout 120000 --silent\
--ignore-engines --ignore-optional --ignore-platform\ --ignore-engines --ignore-optional --ignore-platform\
--ignore-scripts --ignore-scripts
V1API = 0 NEXTAPI = 0
# Macros for the build-release target. If FRONTEND_PREBUILT is 0, the # Macros for the build-release target. If FRONTEND_PREBUILT is 0, the
# default, the macro $(BUILD_RELEASE_DEPS_$(FRONTEND_PREBUILT)) expands # default, the macro $(BUILD_RELEASE_DEPS_$(FRONTEND_PREBUILT)) expands
@ -63,7 +63,7 @@ ENV = env\
PATH="$${PWD}/bin:$$( "$(GO.MACRO)" env GOPATH )/bin:$${PATH}"\ PATH="$${PWD}/bin:$$( "$(GO.MACRO)" env GOPATH )/bin:$${PATH}"\
RACE='$(RACE)'\ RACE='$(RACE)'\
SIGN='$(SIGN)'\ SIGN='$(SIGN)'\
V1API='$(V1API)'\ NEXTAPI='$(NEXTAPI)'\
VERBOSE='$(VERBOSE)'\ VERBOSE='$(VERBOSE)'\
VERSION='$(VERSION)'\ VERSION='$(VERSION)'\

View File

@ -7,7 +7,7 @@
# Make sure to sync any changes with the branch overrides below. # Make sure to sync any changes with the branch overrides below.
'variables': 'variables':
'channel': 'edge' 'channel': 'edge'
'dockerGo': 'adguard/golang-ubuntu:5.1' 'dockerGo': 'adguard/golang-ubuntu:5.2'
'stages': 'stages':
- 'Build frontend': - 'Build frontend':
@ -322,7 +322,7 @@
# need to build a few of these. # need to build a few of these.
'variables': 'variables':
'channel': 'beta' 'channel': 'beta'
'dockerGo': 'adguard/golang-ubuntu:5.1' 'dockerGo': 'adguard/golang-ubuntu:5.2'
# release-vX.Y.Z branches are the branches from which the actual final release # release-vX.Y.Z branches are the branches from which the actual final release
# is built. # is built.
- '^release-v[0-9]+\.[0-9]+\.[0-9]+': - '^release-v[0-9]+\.[0-9]+\.[0-9]+':
@ -337,4 +337,4 @@
# are the ones that actually get released. # are the ones that actually get released.
'variables': 'variables':
'channel': 'release' 'channel': 'release'
'dockerGo': 'adguard/golang-ubuntu:5.1' 'dockerGo': 'adguard/golang-ubuntu:5.2'

View File

@ -5,7 +5,7 @@
'key': 'AHBRTSPECS' 'key': 'AHBRTSPECS'
'name': 'AdGuard Home - Build and run tests' 'name': 'AdGuard Home - Build and run tests'
'variables': 'variables':
'dockerGo': 'adguard/golang-ubuntu:5.1' 'dockerGo': 'adguard/golang-ubuntu:5.2'
'stages': 'stages':
- 'Tests': - 'Tests':

View File

@ -215,6 +215,7 @@
"example_upstream_udp": "regular DNS (over UDP, hostname);", "example_upstream_udp": "regular DNS (over UDP, hostname);",
"example_upstream_dot": "encrypted <0>DNS-over-TLS</0>;", "example_upstream_dot": "encrypted <0>DNS-over-TLS</0>;",
"example_upstream_doh": "encrypted <0>DNS-over-HTTPS</0>;", "example_upstream_doh": "encrypted <0>DNS-over-HTTPS</0>;",
"example_upstream_doh3": "encrypted DNS-over-HTTPS with forced <0>HTTP/3</0> and no fallback to HTTP/2 or below;",
"example_upstream_doq": "encrypted <0>DNS-over-QUIC</0>;", "example_upstream_doq": "encrypted <0>DNS-over-QUIC</0>;",
"example_upstream_sdns": "<0>DNS Stamps</0> for <1>DNSCrypt</1> or <2>DNS-over-HTTPS</2> resolvers;", "example_upstream_sdns": "<0>DNS Stamps</0> for <1>DNSCrypt</1> or <2>DNS-over-HTTPS</2> resolvers;",
"example_upstream_tcp": "regular DNS (over TCP);", "example_upstream_tcp": "regular DNS (over TCP);",
@ -605,7 +606,7 @@
"blocklist": "Blocklist", "blocklist": "Blocklist",
"milliseconds_abbreviation": "ms", "milliseconds_abbreviation": "ms",
"cache_size": "Cache size", "cache_size": "Cache size",
"cache_size_desc": "DNS cache size (in bytes).", "cache_size_desc": "DNS cache size (in bytes). To disable caching, leave empty.",
"cache_ttl_min_override": "Override minimum TTL", "cache_ttl_min_override": "Override minimum TTL",
"cache_ttl_max_override": "Override maximum TTL", "cache_ttl_max_override": "Override maximum TTL",
"enter_cache_size": "Enter cache size (bytes)", "enter_cache_size": "Enter cache size (bytes)",

View File

@ -121,7 +121,7 @@ const ClientCell = ({
{options.map(({ name, onClick, disabled }) => ( {options.map(({ name, onClick, disabled }) => (
<button <button
key={name} key={name}
className="button-action--arrow-option px-4 py-2" className="button-action--arrow-option px-4 py-1"
onClick={onClick} onClick={onClick}
disabled={disabled} disabled={disabled}
> >

View File

@ -50,9 +50,30 @@
} }
@media (max-width: 1024px) { @media (max-width: 1024px) {
.grid .key-colon, .grid .title--border { .grid .title--border {
margin-bottom: 4px;
font-weight: 600; font-weight: 600;
} }
.grid .key-colon {
margin-right: 4px;
color: var(--gray-8);
}
.grid__row {
display: flex;
align-items: flex-start;
flex-wrap: wrap;
margin-bottom: 2px;
font-size: 14px;
word-break: break-all;
overflow: hidden;
}
.grid__row .filteringRules__filter,
.grid__row .filteringRules {
margin-bottom: 0;
}
} }
@media (max-width: 767.98px) { @media (max-width: 767.98px) {
@ -100,7 +121,7 @@
} }
.title--border { .title--border {
padding-top: 2rem; padding-top: 1rem;
} }
.title--border:before { .title--border:before {
@ -109,7 +130,7 @@
left: 0; left: 0;
border-top: 0.5px solid var(--gray-d8) !important; border-top: 0.5px solid var(--gray-d8) !important;
width: 100%; width: 100%;
margin-top: -1rem; margin-top: -0.5rem;
} }
.icon-cross { .icon-cross {

View File

@ -146,7 +146,7 @@ const Row = memo(({
type="button" type="button"
className={ className={
classNames( classNames(
'button-action--arrow-option', 'button-action--arrow-option mb-1',
{ 'bg--danger': !isBlocked }, { 'bg--danger': !isBlocked },
{ 'bg--green': isFiltered }, { 'bg--green': isFiltered },
)} )}
@ -158,13 +158,13 @@ const Row = memo(({
); );
const blockForClientButton = <button const blockForClientButton = <button
className='text-center font-weight-bold py-2 button-action--arrow-option' className='text-center font-weight-bold py-1 button-action--arrow-option'
onClick={onBlockingForClientClick}> onClick={onBlockingForClientClick}>
{t(blockingForClientKey)} {t(blockingForClientKey)}
</button>; </button>;
const blockClientButton = <button const blockClientButton = <button
className='text-center font-weight-bold py-2 button-action--arrow-option' className='text-center font-weight-bold py-1 button-action--arrow-option'
onClick={onBlockingClientClick} onClick={onBlockingClientClick}
disabled={processingSet || lastRuleInAllowlist}> disabled={processingSet || lastRuleInAllowlist}>
{t(blockingClientKey)} {t(blockingClientKey)}

View File

@ -312,8 +312,8 @@
border: 0; border: 0;
display: block; display: block;
width: 100%; width: 100%;
padding-top: 0.5rem; padding-top: 0.2rem;
padding-bottom: 0.5rem; padding-bottom: 0.2rem;
text-align: center; text-align: center;
font-weight: 700; font-weight: 700;
color: inherit; color: inherit;

View File

@ -47,17 +47,20 @@ const processContent = (data) => Object.entries(data)
keyClass = ''; keyClass = '';
} }
return isHidden ? null : <div key={key}> return isHidden ? null : (
<div <div className="grid__row" key={key}>
<div
className={classNames(`key__${key}`, keyClass, { className={classNames(`key__${key}`, keyClass, {
'font-weight-bold': isBoolean && value === true, 'font-weight-bold': isBoolean && value === true,
})}> })}
<Trans>{isButton ? value : key}</Trans> >
<Trans>{isButton ? value : key}</Trans>
</div>
<div className={`value__${key} text-pre text-truncate`}>
<Trans>{(isTitle || isButton || isBoolean) ? '' : value || '—'}</Trans>
</div>
</div> </div>
<div className={`value__${key} text-pre text-truncate`}> );
<Trans>{(isTitle || isButton || isBoolean) ? '' : value || '—'}</Trans>
</div>
</div>;
}); });
const Logs = () => { const Logs = () => {

View File

@ -57,6 +57,22 @@ const Examples = (props) => (
example_upstream_doh example_upstream_doh
</Trans> </Trans>
</li> </li>
<li>
<code>h3://unfiltered.adguard-dns.com/dns-query</code>: <Trans
components={[
<a
href="https://en.wikipedia.org/wiki/HTTP/3"
target="_blank"
rel="noopener noreferrer"
key="0"
>
HTTP/3
</a>,
]}
>
example_upstream_doh3
</Trans>
</li>
<li> <li>
<code>quic://unfiltered.adguard-dns.com</code>: <Trans <code>quic://unfiltered.adguard-dns.com</code>: <Trans
components={[ components={[

View File

@ -0,0 +1,33 @@
// Package aghchan contains channel utilities.
package aghchan
import (
"fmt"
"time"
)
// Receive returns an error if it cannot receive a value form c before timeout
// runs out.
func Receive[T any](c <-chan T, timeout time.Duration) (v T, ok bool, err error) {
var zero T
timeoutCh := time.After(timeout)
select {
case <-timeoutCh:
// TODO(a.garipov): Consider implementing [errors.Aser] for
// os.ErrTimeout.
return zero, false, fmt.Errorf("did not receive after %s", timeout)
case v, ok = <-c:
return v, ok, nil
}
}
// MustReceive panics if it cannot receive a value form c before timeout runs
// out.
func MustReceive[T any](c <-chan T, timeout time.Duration) (v T, ok bool) {
v, ok, err := Receive(c, timeout)
if err != nil {
panic(err)
}
return v, ok
}

View File

@ -62,9 +62,16 @@ func WriteTextPlainDeprecated(w http.ResponseWriter, r *http.Request) (isPlainTe
} }
// WriteJSONResponse sets the content-type header in w.Header() to // WriteJSONResponse sets the content-type header in w.Header() to
// "application/json", encodes resp to w, calls Error on any returned error, and // "application/json", writes a header with a "200 OK" status, encodes resp to
// returns it as well. // w, calls [Error] on any returned error, and returns it as well.
func WriteJSONResponse(w http.ResponseWriter, r *http.Request, resp any) (err error) { func WriteJSONResponse(w http.ResponseWriter, r *http.Request, resp any) (err error) {
return WriteJSONResponseCode(w, r, http.StatusOK, resp)
}
// WriteJSONResponseCode is like [WriteJSONResponse] but adds the ability to
// redefine the status code.
func WriteJSONResponseCode(w http.ResponseWriter, r *http.Request, code int, resp any) (err error) {
w.WriteHeader(code)
w.Header().Set(HdrNameContentType, HdrValApplicationJSON) w.Header().Set(HdrNameContentType, HdrValApplicationJSON)
err = json.NewEncoder(w).Encode(resp) err = json.NewEncoder(w).Encode(resp)
if err != nil { if err != nil {

View File

@ -10,9 +10,9 @@ import (
"testing/fstest" "testing/fstest"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter"
@ -163,15 +163,9 @@ func TestHostsContainer_refresh(t *testing.T) {
checkRefresh := func(t *testing.T, want *HostsRecord) { checkRefresh := func(t *testing.T, want *HostsRecord) {
t.Helper() t.Helper()
var ok bool upd, ok := aghchan.MustReceive(hc.Upd(), 1*time.Second)
var upd *netutil.IPMap require.True(t, ok)
select { require.NotNil(t, upd)
case upd, ok = <-hc.Upd():
require.True(t, ok)
require.NotNil(t, upd)
case <-time.After(1 * time.Second):
t.Fatal("did not receive after 1s")
}
assert.Equal(t, 1, upd.Len()) assert.Equal(t, 1, upd.Len())

View File

@ -15,11 +15,11 @@ import (
// errFSOpen. // errFSOpen.
type errFS struct{} type errFS struct{}
// errFSOpen is returned from errGlobFS.Open. // errFSOpen is returned from errFS.Open.
const errFSOpen errors.Error = "test open error" const errFSOpen errors.Error = "test open error"
// Open implements the fs.FS interface for *errGlobFS. fsys is always nil and // Open implements the fs.FS interface for *errFS. fsys is always nil and err
// err is always errFSOpen. // is always errFSOpen.
func (efs *errFS) Open(name string) (fsys fs.File, err error) { func (efs *errFS) Open(name string) (fsys fs.File, err error) {
return nil, errFSOpen return nil, errFSOpen
} }

View File

@ -175,11 +175,21 @@ func RootDirFS() (fsys fs.FS) {
return os.DirFS("") return os.DirFS("")
} }
// NotifyReconfigureSignal notifies c on receiving reconfigure signals.
func NotifyReconfigureSignal(c chan<- os.Signal) {
notifyReconfigureSignal(c)
}
// NotifyShutdownSignal notifies c on receiving shutdown signals. // NotifyShutdownSignal notifies c on receiving shutdown signals.
func NotifyShutdownSignal(c chan<- os.Signal) { func NotifyShutdownSignal(c chan<- os.Signal) {
notifyShutdownSignal(c) notifyShutdownSignal(c)
} }
// IsReconfigureSignal returns true if sig is a reconfigure signal.
func IsReconfigureSignal(sig os.Signal) (ok bool) {
return isReconfigureSignal(sig)
}
// IsShutdownSignal returns true if sig is a shutdown signal. // IsShutdownSignal returns true if sig is a shutdown signal.
func IsShutdownSignal(sig os.Signal) (ok bool) { func IsShutdownSignal(sig os.Signal) (ok bool) {
return isShutdownSignal(sig) return isShutdownSignal(sig)

View File

@ -9,10 +9,18 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGHUP)
}
func notifyShutdownSignal(c chan<- os.Signal) { func notifyShutdownSignal(c chan<- os.Signal) {
signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM) signal.Notify(c, unix.SIGINT, unix.SIGQUIT, unix.SIGTERM)
} }
func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == unix.SIGHUP
}
func isShutdownSignal(sig os.Signal) (ok bool) { func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig { switch sig {
case case

View File

@ -39,12 +39,20 @@ func isOpenWrt() (ok bool) {
return false return false
} }
func notifyReconfigureSignal(c chan<- os.Signal) {
signal.Notify(c, windows.SIGHUP)
}
func notifyShutdownSignal(c chan<- os.Signal) { func notifyShutdownSignal(c chan<- os.Signal) {
// syscall.SIGTERM is processed automatically. See go doc os/signal, // syscall.SIGTERM is processed automatically. See go doc os/signal,
// section Windows. // section Windows.
signal.Notify(c, os.Interrupt) signal.Notify(c, os.Interrupt)
} }
func isReconfigureSignal(sig os.Signal) (ok bool) {
return sig == windows.SIGHUP
}
func isShutdownSignal(sig os.Signal) (ok bool) { func isShutdownSignal(sig os.Signal) (ok bool) {
switch sig { switch sig {
case case

View File

@ -1,10 +1,12 @@
package aghtest package aghtest
import ( import (
"context"
"io/fs" "io/fs"
"net" "net"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -15,6 +17,8 @@ import (
// Standard Library // Standard Library
// Package fs
// type check // type check
var _ fs.FS = &FS{} var _ fs.FS = &FS{}
@ -58,6 +62,8 @@ func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) {
return fsys.OnStat(name) return fsys.OnStat(name)
} }
// Package net
// type check // type check
var _ net.Listener = (*Listener)(nil) var _ net.Listener = (*Listener)(nil)
@ -83,31 +89,9 @@ func (l *Listener) Close() (err error) {
return l.OnClose() return l.OnClose()
} }
// Module dnsproxy // Module adguard-home
// type check // Package aghos
var _ upstream.Upstream = (*UpstreamMock)(nil)
// UpstreamMock is a mock [upstream.Upstream] implementation for tests.
//
// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and
// rename it to just Upstream.
type UpstreamMock struct {
OnAddress func() (addr string)
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
}
// Address implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Address() (addr string) {
return u.OnAddress()
}
// Exchange implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
return u.OnExchange(req)
}
// Module AdGuardHome
// type check // type check
var _ aghos.FSWatcher = (*FSWatcher)(nil) var _ aghos.FSWatcher = (*FSWatcher)(nil)
@ -133,3 +117,59 @@ func (w *FSWatcher) Add(name string) (err error) {
func (w *FSWatcher) Close() (err error) { func (w *FSWatcher) Close() (err error) {
return w.OnClose() return w.OnClose()
} }
// Package agh
// type check
var _ agh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil)
// ServiceWithConfig is a mock [agh.ServiceWithConfig] implementation for tests.
type ServiceWithConfig[ConfigType any] struct {
OnStart func() (err error)
OnShutdown func(ctx context.Context) (err error)
OnConfig func() (c ConfigType)
}
// Start implements the [agh.ServiceWithConfig] interface for
// *ServiceWithConfig.
func (s *ServiceWithConfig[_]) Start() (err error) {
return s.OnStart()
}
// Shutdown implements the [agh.ServiceWithConfig] interface for
// *ServiceWithConfig.
func (s *ServiceWithConfig[_]) Shutdown(ctx context.Context) (err error) {
return s.OnShutdown(ctx)
}
// Config implements the [agh.ServiceWithConfig] interface for
// *ServiceWithConfig.
func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) {
return s.OnConfig()
}
// Module dnsproxy
// Package upstream
// type check
var _ upstream.Upstream = (*UpstreamMock)(nil)
// UpstreamMock is a mock [upstream.Upstream] implementation for tests.
//
// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and
// rename it to just Upstream.
type UpstreamMock struct {
OnAddress func() (addr string)
OnExchange func(req *dns.Msg) (resp *dns.Msg, err error)
}
// Address implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Address() (addr string) {
return u.OnAddress()
}
// Exchange implements the [upstream.Upstream] interface for *UpstreamMock.
func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
return u.OnExchange(req)
}

View File

@ -1,9 +1,3 @@
package aghtest_test package aghtest_test
import ( // Put interface checks that cause import cycles here.
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
)
// type check
var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil)

View File

@ -78,18 +78,7 @@ func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
status.Leases = s.Leases(LeasesDynamic) status.Leases = s.Leases(LeasesDynamic)
status.StaticLeases = s.Leases(LeasesStatic) status.StaticLeases = s.Leases(LeasesStatic)
w.Header().Set("Content-Type", "application/json") _ = aghhttp.WriteJSONResponse(w, r, status)
err := json.NewEncoder(w).Encode(status)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to marshal DHCP status json: %s",
err,
)
}
} }
func (s *server) enableDHCP(ifaceName string) (code int, err error) { func (s *server) enableDHCP(ifaceName string) (code int, err error) {
@ -246,22 +235,7 @@ func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
if conf.Enabled != aghalg.NBNull { s.setConfFromJSON(conf, srv4, srv6)
s.conf.Enabled = conf.Enabled == aghalg.NBTrue
}
if conf.InterfaceName != "" {
s.conf.InterfaceName = conf.InterfaceName
}
if srv4 != nil {
s.srv4 = srv4
}
if srv6 != nil {
s.srv6 = srv6
}
s.conf.ConfigModified() s.conf.ConfigModified()
err = s.dbLoad() err = s.dbLoad()
@ -280,6 +254,26 @@ func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
} }
} }
// setConfFromJSON sets configuration parameters in s from the new configuration
// decoded from JSON.
func (s *server) setConfFromJSON(conf *dhcpServerConfigJSON, srv4, srv6 DHCPServer) {
if conf.Enabled != aghalg.NBNull {
s.conf.Enabled = conf.Enabled == aghalg.NBTrue
}
if conf.InterfaceName != "" {
s.conf.InterfaceName = conf.InterfaceName
}
if srv4 != nil {
s.srv4 = srv4
}
if srv6 != nil {
s.srv6 = srv6
}
}
type netInterfaceJSON struct { type netInterfaceJSON struct {
Name string `json:"name"` Name string `json:"name"`
HardwareAddr string `json:"hardware_address"` HardwareAddr string `json:"hardware_address"`

View File

@ -3,11 +3,10 @@
package dhcpd package dhcpd
import ( import (
"encoding/json"
"net/http" "net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/log"
) )
// jsonError is a generic JSON error response. // jsonError is a generic JSON error response.
@ -25,15 +24,9 @@ type jsonError struct {
// TODO(a.garipov): Either take the logger from the server after we've // TODO(a.garipov): Either take the logger from the server after we've
// refactored logging or make this not a method of *Server. // refactored logging or make this not a method of *Server.
func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) { func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") _ = aghhttp.WriteJSONResponseCode(w, r, http.StatusNotImplemented, &jsonError{
w.WriteHeader(http.StatusNotImplemented)
err := json.NewEncoder(w).Encode(&jsonError{
Message: aghos.Unsupported("dhcp").Error(), Message: aghos.Unsupported("dhcp").Error(),
}) })
if err != nil {
log.Debug("writing 501 json response: %s", err)
}
} }
// registerHandlers sets the handlers for DHCP HTTP API that always respond with // registerHandlers sets the handlers for DHCP HTTP API that always respond with

View File

@ -123,7 +123,14 @@ type quicConnection interface {
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) { func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto proto := pctx.Proto
if proto == proxy.ProtoHTTPS { if proto == proxy.ProtoHTTPS {
return clientIDFromDNSContextHTTPS(pctx) clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}
// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC { } else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil return "", nil
} }
@ -133,31 +140,9 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
return "", nil return "", nil
} }
cliSrvName := "" cliSrvName, err := clientServerName(pctx, proto)
switch proto { if err != nil {
case proxy.ProtoTLS: return "", err
conn := pctx.Conn
tc, ok := conn.(tlsConn)
if !ok {
return "", fmt.Errorf(
"proxy ctx conn of proto %s is %T, want *tls.Conn",
proto,
conn,
)
}
cliSrvName = tc.ConnectionState().ServerName
case proxy.ProtoQUIC:
conn, ok := pctx.QUICConnection.(quicConnection)
if !ok {
return "", fmt.Errorf(
"proxy ctx quic conn of proto %s is %T, want quic.Connection",
proto,
pctx.QUICConnection,
)
}
cliSrvName = conn.ConnectionState().TLS.ServerName
} }
clientID, err = clientIDFromClientServerName( clientID, err = clientIDFromClientServerName(
@ -171,3 +156,35 @@ func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string
return clientID, nil return clientID, nil
} }
// clientServerName returns the TLS server name based on the protocol.
func clientServerName(pctx *proxy.DNSContext, proto proxy.Proto) (srvName string, err error) {
switch proto {
case proxy.ProtoHTTPS:
if connState := pctx.HTTPRequest.TLS; connState != nil {
srvName = pctx.HTTPRequest.TLS.ServerName
}
case proxy.ProtoQUIC:
qConn := pctx.QUICConnection
conn, ok := qConn.(quicConnection)
if !ok {
return "", fmt.Errorf(
"proxy ctx quic conn of proto %s is %T, want quic.Connection",
proto,
qConn,
)
}
srvName = conn.ConnectionState().TLS.ServerName
case proxy.ProtoTLS:
conn := pctx.Conn
tc, ok := conn.(tlsConn)
if !ok {
return "", fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
}
srvName = tc.ConnectionState().ServerName
}
return srvName, nil
}

View File

@ -160,6 +160,22 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
wantClientID: "insensitive", wantClientID: "insensitive",
wantErrMsg: ``, wantErrMsg: ``,
strictSNI: true, strictSNI: true,
}, {
name: "https_no_clientid",
proto: proxy.ProtoHTTPS,
hostSrvName: "example.com",
cliSrvName: "example.com",
wantClientID: "",
wantErrMsg: "",
strictSNI: true,
}, {
name: "https_clientid",
proto: proxy.ProtoHTTPS,
hostSrvName: "example.com",
cliSrvName: "cli.example.com",
wantClientID: "cli",
wantErrMsg: "",
strictSNI: true,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
@ -173,16 +189,32 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
conf: ServerConfig{TLSConfig: tlsConf}, conf: ServerConfig{TLSConfig: tlsConf},
} }
var conn net.Conn var (
if tc.proto == proxy.ProtoTLS { conn net.Conn
conn = testTLSConn{ qconn quic.Connection
httpReq *http.Request
)
switch tc.proto {
case proxy.ProtoHTTPS:
u := &url.URL{
Path: "/dns-query",
}
connState := &tls.ConnectionState{
ServerName: tc.cliSrvName,
}
httpReq = &http.Request{
URL: u,
TLS: connState,
}
case proxy.ProtoQUIC:
qconn = testQUICConnection{
serverName: tc.cliSrvName, serverName: tc.cliSrvName,
} }
} case proxy.ProtoTLS:
conn = testTLSConn{
var qconn quic.Connection
if tc.proto == proxy.ProtoQUIC {
qconn = testQUICConnection{
serverName: tc.cliSrvName, serverName: tc.cliSrvName,
} }
} }
@ -190,6 +222,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
pctx := &proxy.DNSContext{ pctx := &proxy.DNSContext{
Proto: tc.proto, Proto: tc.proto,
Conn: conn, Conn: conn,
HTTPRequest: httpReq,
QUICConnection: qconn, QUICConnection: qconn,
} }
@ -205,56 +238,76 @@ func TestClientIDFromDNSContextHTTPS(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
path string path string
cliSrvName string
wantClientID string wantClientID string
wantErrMsg string wantErrMsg string
}{{ }{{
name: "no_clientid", name: "no_clientid",
path: "/dns-query", path: "/dns-query",
cliSrvName: "example.com",
wantClientID: "", wantClientID: "",
wantErrMsg: "", wantErrMsg: "",
}, { }, {
name: "no_clientid_slash", name: "no_clientid_slash",
path: "/dns-query/", path: "/dns-query/",
cliSrvName: "example.com",
wantClientID: "", wantClientID: "",
wantErrMsg: "", wantErrMsg: "",
}, { }, {
name: "clientid", name: "clientid",
path: "/dns-query/cli", path: "/dns-query/cli",
cliSrvName: "example.com",
wantClientID: "cli", wantClientID: "cli",
wantErrMsg: "", wantErrMsg: "",
}, { }, {
name: "clientid_slash", name: "clientid_slash",
path: "/dns-query/cli/", path: "/dns-query/cli/",
cliSrvName: "example.com",
wantClientID: "cli", wantClientID: "cli",
wantErrMsg: "", wantErrMsg: "",
}, { }, {
name: "clientid_case", name: "clientid_case",
path: "/dns-query/InSeNsItIvE", path: "/dns-query/InSeNsItIvE",
cliSrvName: "example.com",
wantClientID: "insensitive", wantClientID: "insensitive",
wantErrMsg: ``, wantErrMsg: ``,
}, { }, {
name: "bad_url", name: "bad_url",
path: "/foo", path: "/foo",
cliSrvName: "example.com",
wantClientID: "", wantClientID: "",
wantErrMsg: `clientid check: invalid path "/foo"`, wantErrMsg: `clientid check: invalid path "/foo"`,
}, { }, {
name: "extra", name: "extra",
path: "/dns-query/cli/foo", path: "/dns-query/cli/foo",
cliSrvName: "example.com",
wantClientID: "", wantClientID: "",
wantErrMsg: `clientid check: invalid path "/dns-query/cli/foo": extra parts`, wantErrMsg: `clientid check: invalid path "/dns-query/cli/foo": extra parts`,
}, { }, {
name: "invalid_clientid", name: "invalid_clientid",
path: "/dns-query/!!!", path: "/dns-query/!!!",
cliSrvName: "example.com",
wantClientID: "", wantClientID: "",
wantErrMsg: `clientid check: invalid clientid "!!!": bad domain name label rune '!'`, wantErrMsg: `clientid check: invalid clientid "!!!": bad domain name label rune '!'`,
}, {
name: "both_ids",
path: "/dns-query/right",
cliSrvName: "wrong.example.com",
wantClientID: "right",
wantErrMsg: "",
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
connState := &tls.ConnectionState{
ServerName: tc.cliSrvName,
}
r := &http.Request{ r := &http.Request{
URL: &url.URL{ URL: &url.URL{
Path: tc.path, Path: tc.path,
}, },
TLS: connState,
} }
pctx := &proxy.DNSContext{ pctx := &proxy.DNSContext{

View File

@ -453,13 +453,7 @@ func (d *DNSFilter) ApplyBlockedServices(setts *Settings, list []string) {
} }
func (d *DNSFilter) handleBlockedServicesAvailableServices(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleBlockedServicesAvailableServices(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") _ = aghhttp.WriteJSONResponse(w, r, serviceIDs)
err := json.NewEncoder(w).Encode(serviceIDs)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding available services: %s", err)
return
}
} }
func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Request) {
@ -467,13 +461,7 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
list := d.Config.BlockedServices list := d.Config.BlockedServices
d.confLock.RUnlock() d.confLock.RUnlock()
w.Header().Set("Content-Type", "application/json") _ = aghhttp.WriteJSONResponse(w, r, list)
err := json.NewEncoder(w).Encode(list)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding services: %s", err)
return
}
} }
func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) {

View File

@ -301,14 +301,7 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
return return
} }
w.Header().Set("Content-Type", "application/json") _ = aghhttp.WriteJSONResponse(w, r, resp)
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
} }
type filterJSON struct { type filterJSON struct {
@ -361,17 +354,7 @@ func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request
resp.UserRules = d.UserRules resp.UserRules = d.UserRules
d.filtersMu.RUnlock() d.filtersMu.RUnlock()
jsonVal, err := json.Marshal(resp) _ = aghhttp.WriteJSONResponse(w, r, resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err)
}
} }
// Set filtering configuration // Set filtering configuration
@ -473,11 +456,7 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
} }
} }
w.Header().Set("Content-Type", "application/json") _ = aghhttp.WriteJSONResponse(w, r, resp)
err = json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "encoding response: %s", err)
}
} }
// RegisterFilteringHandlers - register handlers // RegisterFilteringHandlers - register handlers

View File

@ -240,13 +240,7 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
} }
d.confLock.Unlock() d.confLock.Unlock()
w.Header().Set("Content-Type", "application/json") _ = aghhttp.WriteJSONResponse(w, r, arr)
err := json.NewEncoder(w).Encode(arr)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
return
}
} }
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {

View File

@ -5,7 +5,6 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -381,17 +380,13 @@ func (d *DNSFilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Req
} }
func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") resp := &struct {
err := json.NewEncoder(w).Encode(&struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
}{ }{
Enabled: d.Config.SafeBrowsingEnabled, Enabled: d.Config.SafeBrowsingEnabled,
})
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
return
} }
_ = aghhttp.WriteJSONResponse(w, r, resp)
} }
func (d *DNSFilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) {
@ -405,13 +400,11 @@ func (d *DNSFilter) handleParentalDisable(w http.ResponseWriter, r *http.Request
} }
func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") resp := &struct {
err := json.NewEncoder(w).Encode(&struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
}{ }{
Enabled: d.Config.ParentalEnabled, Enabled: d.Config.ParentalEnabled,
})
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
} }
_ = aghhttp.WriteJSONResponse(w, r, resp)
} }

View File

@ -5,7 +5,6 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"encoding/gob" "encoding/gob"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -146,21 +145,13 @@ func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Reque
} }
func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) { func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") resp := &struct {
err := json.NewEncoder(w).Encode(&struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
}{ }{
Enabled: d.Config.SafeSearchEnabled, Enabled: d.Config.SafeSearchEnabled,
})
if err != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Unable to write response json: %s",
err,
)
} }
_ = aghhttp.WriteJSONResponse(w, r, resp)
} }
var safeSearchDomains = map[string]string{ var safeSearchDomains = map[string]string{

View File

@ -12,16 +12,11 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
func TestNewSessionToken(t *testing.T) { func TestNewSessionToken(t *testing.T) {
// Successful case. // Successful case.
token, err := newSessionToken() token, err := newSessionToken()

View File

@ -97,9 +97,15 @@ var Context homeContext
// Main is the entry point // Main is the entry point
func Main(clientBuildFS fs.FS) { func Main(clientBuildFS fs.FS) {
// config can be specified, which reads options from there, but other command line flags have to override config values initCmdLineOpts()
// therefore, we must do it manually instead of using a lib
args := loadOptions() // The configuration file path can be overridden, but other command-line
// options have to override config values. Therefore, do it manually
// instead of using package flag.
//
// TODO(a.garipov): The comment above is most likely false. Replace with
// package flag.
opts := loadCmdLineOpts()
Context.appSignalChannel = make(chan os.Signal) Context.appSignalChannel = make(chan os.Signal)
signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) signal.Notify(Context.appSignalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT)
@ -120,26 +126,18 @@ func Main(clientBuildFS fs.FS) {
} }
}() }()
if args.serviceControlAction != "" { if opts.serviceControlAction != "" {
handleServiceControlAction(args, clientBuildFS) handleServiceControlAction(opts, clientBuildFS)
return return
} }
// run the protection // run the protection
run(args, clientBuildFS) run(opts, clientBuildFS)
} }
func setupContext(args options) { func setupContext(opts options) {
Context.runningAsService = args.runningAsService setupContextFlags(opts)
Context.disableUpdate = args.disableUpdate ||
version.Channel() == version.ChannelDevelopment
Context.firstRun = detectFirstRun()
if Context.firstRun {
log.Info("This is the first time AdGuard Home is launched")
checkPermissions()
}
switch version.Channel() { switch version.Channel() {
case version.ChannelEdge, version.ChannelDevelopment: case version.ChannelEdge, version.ChannelDevelopment:
@ -174,13 +172,13 @@ func setupContext(args options) {
os.Exit(1) os.Exit(1)
} }
if args.checkConfig { if opts.checkConfig {
log.Info("configuration file is ok") log.Info("configuration file is ok")
os.Exit(0) os.Exit(0)
} }
if !args.noEtcHosts && config.Clients.Sources.HostsFile { if !opts.noEtcHosts && config.Clients.Sources.HostsFile {
err = setupHostsContainer() err = setupHostsContainer()
fatalOnError(err) fatalOnError(err)
} }
@ -189,6 +187,24 @@ func setupContext(args options) {
Context.mux = http.NewServeMux() Context.mux = http.NewServeMux()
} }
// setupContextFlags sets global flags and prints their status to the log.
func setupContextFlags(opts options) {
Context.firstRun = detectFirstRun()
if Context.firstRun {
log.Info("This is the first time AdGuard Home is launched")
checkPermissions()
}
Context.runningAsService = opts.runningAsService
// Don't print the runningAsService flag, since that has already been done
// in [run].
Context.disableUpdate = opts.disableUpdate || version.Channel() == version.ChannelDevelopment
if Context.disableUpdate {
log.Info("AdGuard Home updates are disabled")
}
}
// logIfUnsupported logs a formatted warning if the error is one of the // logIfUnsupported logs a formatted warning if the error is one of the
// unsupported errors and returns nil. If err is nil, logIfUnsupported returns // unsupported errors and returns nil. If err is nil, logIfUnsupported returns
// nil. Otherwise, it returns err. // nil. Otherwise, it returns err.
@ -270,7 +286,7 @@ func setupHostsContainer() (err error) {
return nil return nil
} }
func setupConfig(args options) (err error) { func setupConfig(opts options) (err error) {
config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts config.DNS.DnsfilterConf.EtcHosts = Context.etcHosts
config.DNS.DnsfilterConf.ConfigModified = onConfigModified config.DNS.DnsfilterConf.ConfigModified = onConfigModified
config.DNS.DnsfilterConf.HTTPRegister = httpRegister config.DNS.DnsfilterConf.HTTPRegister = httpRegister
@ -312,9 +328,9 @@ func setupConfig(args options) (err error) {
Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb) Context.clients.Init(config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb)
if args.bindPort != 0 { if opts.bindPort != 0 {
tcpPorts := aghalg.UniqChecker[tcpPort]{} tcpPorts := aghalg.UniqChecker[tcpPort]{}
addPorts(tcpPorts, tcpPort(args.bindPort), tcpPort(config.BetaBindPort)) addPorts(tcpPorts, tcpPort(opts.bindPort), tcpPort(config.BetaBindPort))
udpPorts := aghalg.UniqChecker[udpPort]{} udpPorts := aghalg.UniqChecker[udpPort]{}
addPorts(udpPorts, udpPort(config.DNS.Port)) addPorts(udpPorts, udpPort(config.DNS.Port))
@ -336,23 +352,23 @@ func setupConfig(args options) (err error) {
return fmt.Errorf("validating udp ports: %w", err) return fmt.Errorf("validating udp ports: %w", err)
} }
config.BindPort = args.bindPort config.BindPort = opts.bindPort
} }
// override bind host/port from the console // override bind host/port from the console
if args.bindHost != nil { if opts.bindHost != nil {
config.BindHost = args.bindHost config.BindHost = opts.bindHost
} }
if len(args.pidFile) != 0 && writePIDFile(args.pidFile) { if len(opts.pidFile) != 0 && writePIDFile(opts.pidFile) {
Context.pidFileName = args.pidFile Context.pidFileName = opts.pidFile
} }
return nil return nil
} }
func initWeb(args options, clientBuildFS fs.FS) (web *Web, err error) { func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) {
var clientFS, clientBetaFS fs.FS var clientFS, clientBetaFS fs.FS
if args.localFrontend { if opts.localFrontend {
log.Info("warning: using local frontend files") log.Info("warning: using local frontend files")
clientFS = os.DirFS("build/static") clientFS = os.DirFS("build/static")
@ -400,24 +416,24 @@ func fatalOnError(err error) {
} }
// run configures and starts AdGuard Home. // run configures and starts AdGuard Home.
func run(args options, clientBuildFS fs.FS) { func run(opts options, clientBuildFS fs.FS) {
// configure config filename // configure config filename
initConfigFilename(args) initConfigFilename(opts)
// configure working dir and config path // configure working dir and config path
initWorkingDir(args) initWorkingDir(opts)
// configure log level and output // configure log level and output
configureLogger(args) configureLogger(opts)
// Print the first message after logger is configured. // Print the first message after logger is configured.
log.Info(version.Full()) log.Info(version.Full())
log.Debug("current working directory is %s", Context.workDir) log.Debug("current working directory is %s", Context.workDir)
if args.runningAsService { if opts.runningAsService {
log.Info("AdGuard Home is running as a service") log.Info("AdGuard Home is running as a service")
} }
setupContext(args) setupContext(opts)
err := configureOS(config) err := configureOS(config)
fatalOnError(err) fatalOnError(err)
@ -427,7 +443,7 @@ func run(args options, clientBuildFS fs.FS) {
// but also avoid relying on automatic Go init() function // but also avoid relying on automatic Go init() function
filtering.InitModule() filtering.InitModule()
err = setupConfig(args) err = setupConfig(opts)
fatalOnError(err) fatalOnError(err)
if !Context.firstRun { if !Context.firstRun {
@ -456,7 +472,7 @@ func run(args options, clientBuildFS fs.FS) {
} }
sessFilename := filepath.Join(Context.getDataDir(), "sessions.db") sessFilename := filepath.Join(Context.getDataDir(), "sessions.db")
GLMode = args.glinetMode GLMode = opts.glinetMode
var rateLimiter *authRateLimiter var rateLimiter *authRateLimiter
if config.AuthAttempts > 0 && config.AuthBlockMin > 0 { if config.AuthAttempts > 0 && config.AuthBlockMin > 0 {
rateLimiter = newAuthRateLimiter( rateLimiter = newAuthRateLimiter(
@ -483,7 +499,7 @@ func run(args options, clientBuildFS fs.FS) {
log.Fatalf("Can't initialize TLS module") log.Fatalf("Can't initialize TLS module")
} }
Context.web, err = initWeb(args, clientBuildFS) Context.web, err = initWeb(opts, clientBuildFS)
fatalOnError(err) fatalOnError(err)
if !Context.firstRun { if !Context.firstRun {
@ -575,10 +591,10 @@ func writePIDFile(fn string) bool {
return true return true
} }
func initConfigFilename(args options) { func initConfigFilename(opts options) {
// config file path can be overridden by command-line arguments: // config file path can be overridden by command-line arguments:
if args.configFilename != "" { if opts.confFilename != "" {
Context.configFilename = args.configFilename Context.configFilename = opts.confFilename
} else { } else {
// Default config file name // Default config file name
Context.configFilename = "AdGuardHome.yaml" Context.configFilename = "AdGuardHome.yaml"
@ -587,15 +603,15 @@ func initConfigFilename(args options) {
// initWorkingDir initializes the workDir // initWorkingDir initializes the workDir
// if no command-line arguments specified, we use the directory where our binary file is located // if no command-line arguments specified, we use the directory where our binary file is located
func initWorkingDir(args options) { func initWorkingDir(opts options) {
execPath, err := os.Executable() execPath, err := os.Executable()
if err != nil { if err != nil {
panic(err) panic(err)
} }
if args.workDir != "" { if opts.workDir != "" {
// If there is a custom config file, use it's directory as our working dir // If there is a custom config file, use it's directory as our working dir
Context.workDir = args.workDir Context.workDir = opts.workDir
} else { } else {
Context.workDir = filepath.Dir(execPath) Context.workDir = filepath.Dir(execPath)
} }
@ -609,15 +625,15 @@ func initWorkingDir(args options) {
} }
// configureLogger configures logger level and output // configureLogger configures logger level and output
func configureLogger(args options) { func configureLogger(opts options) {
ls := getLogSettings() ls := getLogSettings()
// command-line arguments can override config settings // command-line arguments can override config settings
if args.verbose || config.Verbose { if opts.verbose || config.Verbose {
ls.Verbose = true ls.Verbose = true
} }
if args.logFile != "" { if opts.logFile != "" {
ls.File = args.logFile ls.File = opts.logFile
} else if config.File != "" { } else if config.File != "" {
ls.File = config.File ls.File = config.File
} }
@ -638,7 +654,7 @@ func configureLogger(args options) {
// happen pretty quickly. // happen pretty quickly.
log.SetFlags(log.LstdFlags | log.Lmicroseconds) log.SetFlags(log.LstdFlags | log.Lmicroseconds)
if args.runningAsService && ls.File == "" && runtime.GOOS == "windows" { if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
// When running as a Windows service, use eventlog by default if nothing // When running as a Windows service, use eventlog by default if nothing
// else is configured. Otherwise, we'll simply lose the log output. // else is configured. Otherwise, we'll simply lose the log output.
ls.File = configSyslog ls.File = configSyslog
@ -728,25 +744,29 @@ func exitWithError() {
os.Exit(64) os.Exit(64)
} }
// loadOptions reads command line arguments and initializes configuration // loadCmdLineOpts reads command line arguments and initializes configuration
func loadOptions() options { // from them. If there is an error or an effect, loadCmdLineOpts processes them
o, f, err := parse(os.Args[0], os.Args[1:]) // and exits.
func loadCmdLineOpts() (opts options) {
opts, eff, err := parseCmdOpts(os.Args[0], os.Args[1:])
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
_ = printHelp(os.Args[0]) printHelp(os.Args[0])
exitWithError() exitWithError()
} else if f != nil { }
err = f()
if eff != nil {
err = eff()
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
exitWithError() exitWithError()
} else {
os.Exit(0)
} }
os.Exit(0)
} }
return o return opts
} }
// printWebAddrs prints addresses built from proto, addr, and an appropriate // printWebAddrs prints addresses built from proto, addr, and an appropriate

View File

@ -0,0 +1,12 @@
package home
import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
initCmdLineOpts()
}

View File

@ -5,30 +5,60 @@ import (
"net" "net"
"os" "os"
"strconv" "strconv"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/stringutil"
) )
// options passed from command-line arguments // TODO(a.garipov): Replace with package flag.
type options struct {
verbose bool // is verbose logging enabled
configFilename string // path to the config file
workDir string // path to the working directory where we will store the filters data and the querylog
bindHost net.IP // host address to bind HTTP server on
bindPort int // port to serve HTTP pages on
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
pidFile string // File name to save PID to
checkConfig bool // Check configuration and exit
disableUpdate bool // If set, don't check for updates
// service control action (see service.ControlAction array + "status" command) // options represents the command-line options.
type options struct {
// confFilename is the path to the configuration file.
confFilename string
// workDir is the path to the working directory where AdGuard Home stores
// filter data, the query log, and other data.
workDir string
// logFile is the path to the log file. If empty, AdGuard Home writes to
// stdout; if "syslog", to syslog.
logFile string
// pidFile is the file name for the file to which the PID is saved.
pidFile string
// serviceControlAction is the service action to perform. See
// [service.ControlAction] and [handleServiceControlAction].
serviceControlAction string serviceControlAction string
// runningAsService flag is set to true when options are passed from the service runner // bindHost is the address on which to serve the HTTP UI.
bindHost net.IP
// bindPort is the port on which to serve the HTTP UI.
bindPort int
// checkConfig is true if the current invocation is only required to check
// the configuration file and exit.
checkConfig bool
// disableUpdate, if set, makes AdGuard Home not check for updates.
disableUpdate bool
// verbose shows if verbose logging is enabled.
verbose bool
// runningAsService flag is set to true when options are passed from the
// service runner
//
// TODO(a.garipov): Perhaps this could be determined by a non-empty
// serviceControlAction?
runningAsService bool runningAsService bool
glinetMode bool // Activate GL-Inet compatibility mode // glinetMode shows if the GL-Inet compatibility mode is enabled.
glinetMode bool
// noEtcHosts flag should be provided when /etc/hosts file shouldn't be // noEtcHosts flag should be provided when /etc/hosts file shouldn't be
// used. // used.
@ -39,88 +69,85 @@ type options struct {
localFrontend bool localFrontend bool
} }
// functions used for their side-effects // initCmdLineOpts completes initialization of the global command-line option
type effect func() error // slice. It must only be called once.
func initCmdLineOpts() {
type arg struct { // The --help option cannot be put directly into cmdLineOpts, because that
description string // a short, English description of the argument // causes initialization cycle due to printHelp referencing cmdLineOpts.
longName string // the name of the argument used after '--' cmdLineOpts = append(cmdLineOpts, cmdLineOpt{
shortName string // the name of the argument used after '-' updateWithValue: nil,
updateNoValue: nil,
// only one of updateWithValue, updateNoValue, and effect should be present effect: func(o options, exec string) (effect, error) {
return func() error { printHelp(exec); exitWithError(); return nil }, nil
updateWithValue func(o options, v string) (options, error) // the mutator for arguments with parameters },
updateNoValue func(o options) (options, error) // the mutator for arguments without parameters serialize: func(o options) (val string, ok bool) { return "", false },
effect func(o options, exec string) (f effect, err error) // the side-effect closure generator description: "Print this help.",
longName: "help",
serialize func(o options) []string // the re-serialization function back to arguments (return nil for omit) shortName: "",
})
} }
// {type}SliceOrNil functions check their parameter of type {type} // effect is the type for functions used for their side-effects.
// against its zero value and return nil if the parameter value is type effect func() (err error)
// zero otherwise they return a string slice of the parameter
func ipSliceOrNil(ip net.IP) []string { // cmdLineOpt contains the data for a single command-line option. Only one of
if ip == nil { // updateWithValue, updateNoValue, and effect must be present.
return nil type cmdLineOpt struct {
} updateWithValue func(o options, v string) (updated options, err error)
updateNoValue func(o options) (updated options, err error)
effect func(o options, exec string) (eff effect, err error)
return []string{ip.String()} // serialize is a function that encodes the option into a slice of
// command-line arguments, if necessary. If ok is false, this option should
// be skipped.
serialize func(o options) (val string, ok bool)
description string
longName string
shortName string
} }
func stringSliceOrNil(s string) []string { // cmdLineOpts are all command-line options of AdGuard Home.
if s == "" { var cmdLineOpts = []cmdLineOpt{{
return nil updateWithValue: func(o options, v string) (options, error) {
} o.confFilename = v
return o, nil
},
updateNoValue: nil,
effect: nil,
serialize: func(o options) (val string, ok bool) {
return o.confFilename, o.confFilename != ""
},
description: "Path to the config file.",
longName: "config",
shortName: "c",
}, {
updateWithValue: func(o options, v string) (options, error) { o.workDir = v; return o, nil },
updateNoValue: nil,
effect: nil,
serialize: func(o options) (val string, ok bool) { return o.workDir, o.workDir != "" },
description: "Path to the working directory.",
longName: "work-dir",
shortName: "w",
}, {
updateWithValue: func(o options, v string) (options, error) {
o.bindHost = net.ParseIP(v)
return o, nil
},
updateNoValue: nil,
effect: nil,
serialize: func(o options) (val string, ok bool) {
if o.bindHost == nil {
return "", false
}
return []string{s} return o.bindHost.String(), true
} },
description: "Host address to bind HTTP server on.",
func intSliceOrNil(i int) []string { longName: "host",
if i == 0 { shortName: "h",
return nil }, {
} updateWithValue: func(o options, v string) (options, error) {
return []string{strconv.Itoa(i)}
}
func boolSliceOrNil(b bool) []string {
if b {
return []string{}
}
return nil
}
var args []arg
var configArg = arg{
"Path to the config file.",
"config", "c",
func(o options, v string) (options, error) { o.configFilename = v; return o, nil },
nil,
nil,
func(o options) []string { return stringSliceOrNil(o.configFilename) },
}
var workDirArg = arg{
"Path to the working directory.",
"work-dir", "w",
func(o options, v string) (options, error) { o.workDir = v; return o, nil }, nil, nil,
func(o options) []string { return stringSliceOrNil(o.workDir) },
}
var hostArg = arg{
"Host address to bind HTTP server on.",
"host", "h",
func(o options, v string) (options, error) { o.bindHost = net.ParseIP(v); return o, nil }, nil, nil,
func(o options) []string { return ipSliceOrNil(o.bindHost) },
}
var portArg = arg{
"Port to serve HTTP pages on.",
"port", "p",
func(o options, v string) (options, error) {
var err error var err error
var p int var p int
minPort, maxPort := 0, 1<<16-1 minPort, maxPort := 0, 1<<16-1
@ -131,108 +158,81 @@ var portArg = arg{
} else { } else {
o.bindPort = p o.bindPort = p
} }
return o, err
}, nil, nil,
func(o options) []string { return intSliceOrNil(o.bindPort) },
}
var serviceArg = arg{ return o, err
"Service control action: status, install, uninstall, start, stop, restart, reload (configuration).", },
"service", "s", updateNoValue: nil,
func(o options, v string) (options, error) { effect: nil,
serialize: func(o options) (val string, ok bool) {
if o.bindPort == 0 {
return "", false
}
return strconv.Itoa(o.bindPort), true
},
description: "Port to serve HTTP pages on.",
longName: "port",
shortName: "p",
}, {
updateWithValue: func(o options, v string) (options, error) {
o.serviceControlAction = v o.serviceControlAction = v
return o, nil return o, nil
}, nil, nil, },
func(o options) []string { return stringSliceOrNil(o.serviceControlAction) }, updateNoValue: nil,
} effect: nil,
serialize: func(o options) (val string, ok bool) {
var logfileArg = arg{ return o.serviceControlAction, o.serviceControlAction != ""
"Path to log file. If empty: write to stdout; if 'syslog': write to system log.", },
"logfile", "l", description: `Service control action: status, install (as a service), ` +
func(o options, v string) (options, error) { o.logFile = v; return o, nil }, nil, nil, `uninstall (as a service), start, stop, restart, reload (configuration).`,
func(o options) []string { return stringSliceOrNil(o.logFile) }, longName: "service",
} shortName: "s",
}, {
var pidfileArg = arg{ updateWithValue: func(o options, v string) (options, error) { o.logFile = v; return o, nil },
"Path to a file where PID is stored.", updateNoValue: nil,
"pidfile", "", effect: nil,
func(o options, v string) (options, error) { o.pidFile = v; return o, nil }, nil, nil, serialize: func(o options) (val string, ok bool) { return o.logFile, o.logFile != "" },
func(o options) []string { return stringSliceOrNil(o.pidFile) }, description: `Path to log file. If empty, write to stdout; ` +
} `if "syslog", write to system log.`,
longName: "logfile",
var checkConfigArg = arg{ shortName: "l",
"Check configuration and exit.", }, {
"check-config", "", updateWithValue: func(o options, v string) (options, error) { o.pidFile = v; return o, nil },
nil, func(o options) (options, error) { o.checkConfig = true; return o, nil }, nil, updateNoValue: nil,
func(o options) []string { return boolSliceOrNil(o.checkConfig) }, effect: nil,
} serialize: func(o options) (val string, ok bool) { return o.pidFile, o.pidFile != "" },
description: "Path to a file where PID is stored.",
var noCheckUpdateArg = arg{ longName: "pidfile",
"Don't check for updates.", shortName: "",
"no-check-update", "", }, {
nil, func(o options) (options, error) { o.disableUpdate = true; return o, nil }, nil, updateWithValue: nil,
func(o options) []string { return boolSliceOrNil(o.disableUpdate) }, updateNoValue: func(o options) (options, error) { o.checkConfig = true; return o, nil },
} effect: nil,
serialize: func(o options) (val string, ok bool) { return "", o.checkConfig },
var disableMemoryOptimizationArg = arg{ description: "Check configuration and exit.",
"Deprecated. Disable memory optimization.", longName: "check-config",
"no-mem-optimization", "", shortName: "",
nil, nil, func(_ options, _ string) (f effect, err error) { }, {
updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.disableUpdate = true; return o, nil },
effect: nil,
serialize: func(o options) (val string, ok bool) { return "", o.disableUpdate },
description: "Don't check for updates.",
longName: "no-check-update",
shortName: "",
}, {
updateWithValue: nil,
updateNoValue: nil,
effect: func(_ options, _ string) (f effect, err error) {
log.Info("warning: using --no-mem-optimization flag has no effect and is deprecated") log.Info("warning: using --no-mem-optimization flag has no effect and is deprecated")
return nil, nil return nil, nil
}, },
func(o options) []string { return nil }, serialize: func(o options) (val string, ok bool) { return "", false },
} description: "Deprecated. Disable memory optimization.",
longName: "no-mem-optimization",
var verboseArg = arg{ shortName: "",
"Enable verbose output.", }, {
"verbose", "v",
nil, func(o options) (options, error) { o.verbose = true; return o, nil }, nil,
func(o options) []string { return boolSliceOrNil(o.verbose) },
}
var glinetArg = arg{
"Run in GL-Inet compatibility mode.",
"glinet", "",
nil, func(o options) (options, error) { o.glinetMode = true; return o, nil }, nil,
func(o options) []string { return boolSliceOrNil(o.glinetMode) },
}
var versionArg = arg{
description: "Show the version and exit. Show more detailed version description with -v.",
longName: "version",
shortName: "",
updateWithValue: nil,
updateNoValue: nil,
effect: func(o options, exec string) (effect, error) {
return func() error {
if o.verbose {
fmt.Println(version.Verbose())
} else {
fmt.Println(version.Full())
}
os.Exit(0)
return nil
}, nil
},
serialize: func(o options) []string { return nil },
}
var helpArg = arg{
"Print this help.",
"help", "",
nil, nil, func(o options, exec string) (effect, error) {
return func() error { _ = printHelp(exec); os.Exit(64); return nil }, nil
},
func(o options) []string { return nil },
}
var noEtcHostsArg = arg{
description: "Deprecated. Do not use the OS-provided hosts.",
longName: "no-etc-hosts",
shortName: "",
updateWithValue: nil, updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.noEtcHosts = true; return o, nil }, updateNoValue: func(o options) (options, error) { o.noEtcHosts = true; return o, nil },
effect: func(_ options, _ string) (f effect, err error) { effect: func(_ options, _ string) (f effect, err error) {
@ -242,146 +242,216 @@ var noEtcHostsArg = arg{
return nil, nil return nil, nil
}, },
serialize: func(o options) []string { return boolSliceOrNil(o.noEtcHosts) }, serialize: func(o options) (val string, ok bool) { return "", o.noEtcHosts },
} description: "Deprecated. Do not use the OS-provided hosts.",
longName: "no-etc-hosts",
var localFrontendArg = arg{ shortName: "",
description: "Use local frontend directories.", }, {
longName: "local-frontend",
shortName: "",
updateWithValue: nil, updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.localFrontend = true; return o, nil }, updateNoValue: func(o options) (options, error) { o.localFrontend = true; return o, nil },
effect: nil, effect: nil,
serialize: func(o options) []string { return boolSliceOrNil(o.localFrontend) }, serialize: func(o options) (val string, ok bool) { return "", o.localFrontend },
} description: "Use local frontend directories.",
longName: "local-frontend",
shortName: "",
}, {
updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.verbose = true; return o, nil },
effect: nil,
serialize: func(o options) (val string, ok bool) { return "", o.verbose },
description: "Enable verbose output.",
longName: "verbose",
shortName: "v",
}, {
updateWithValue: nil,
updateNoValue: func(o options) (options, error) { o.glinetMode = true; return o, nil },
effect: nil,
serialize: func(o options) (val string, ok bool) { return "", o.glinetMode },
description: "Run in GL-Inet compatibility mode.",
longName: "glinet",
shortName: "",
}, {
updateWithValue: nil,
updateNoValue: nil,
effect: func(o options, exec string) (effect, error) {
return func() error {
if o.verbose {
fmt.Println(version.Verbose())
} else {
fmt.Println(version.Full())
}
func init() { os.Exit(0)
args = []arg{
configArg,
workDirArg,
hostArg,
portArg,
serviceArg,
logfileArg,
pidfileArg,
checkConfigArg,
noCheckUpdateArg,
disableMemoryOptimizationArg,
noEtcHostsArg,
localFrontendArg,
verboseArg,
glinetArg,
versionArg,
helpArg,
}
}
func getUsageLines(exec string, args []arg) []string { return nil
usage := []string{ }, nil
"Usage:", },
"", serialize: func(o options) (val string, ok bool) { return "", false },
fmt.Sprintf("%s [options]", exec), description: "Show the version and exit. Show more detailed version description with -v.",
"", longName: "version",
"Options:", shortName: "",
} }}
for _, arg := range args {
// printHelp prints the entire help message. It exits with an error code if
// there are any I/O errors.
func printHelp(exec string) {
b := &strings.Builder{}
stringutil.WriteToBuilder(
b,
"Usage:\n\n",
fmt.Sprintf("%s [options]\n\n", exec),
"Options:\n",
)
var err error
for _, opt := range cmdLineOpts {
val := "" val := ""
if arg.updateWithValue != nil { if opt.updateWithValue != nil {
val = " VALUE" val = " VALUE"
} }
if arg.shortName != "" {
usage = append(usage, fmt.Sprintf(" -%s, %-30s %s", longDesc := opt.longName + val
arg.shortName, if opt.shortName != "" {
"--"+arg.longName+val, _, err = fmt.Fprintf(b, " -%s, --%-28s %s\n", opt.shortName, longDesc, opt.description)
arg.description))
} else { } else {
usage = append(usage, fmt.Sprintf(" %-34s %s", _, err = fmt.Fprintf(b, " --%-32s %s\n", longDesc, opt.description)
"--"+arg.longName+val, }
arg.description))
if err != nil {
// The only error here can be from incorrect Fprintf usage, which is
// a programmer error.
panic(err)
} }
} }
return usage
_, err = fmt.Print(b)
if err != nil {
// Exit immediately, since not being able to print out a help message
// essentially means that the I/O is very broken at the moment.
exitWithError()
}
} }
func printHelp(exec string) error { // parseCmdOpts parses the command-line arguments into options and effects.
for _, line := range getUsageLines(exec, args) { func parseCmdOpts(cmdName string, args []string) (o options, eff effect, err error) {
_, err := fmt.Println(line) // Don't use range since the loop changes the loop variable.
argsLen := len(args)
for i := 0; i < len(args); i++ {
arg := args[i]
isKnown := false
for _, opt := range cmdLineOpts {
isKnown = argMatches(opt, arg)
if !isKnown {
continue
}
if opt.updateWithValue != nil {
i++
if i >= argsLen {
return o, eff, fmt.Errorf("got %s without argument", arg)
}
o, err = opt.updateWithValue(o, args[i])
} else {
o, eff, err = updateOptsNoValue(o, eff, opt, cmdName)
}
if err != nil {
return o, eff, fmt.Errorf("applying option %s: %w", arg, err)
}
break
}
if !isKnown {
return o, eff, fmt.Errorf("unknown option %s", arg)
}
}
return o, eff, err
}
// argMatches returns true if arg matches command-line option opt.
func argMatches(opt cmdLineOpt, arg string) (ok bool) {
if arg == "" || arg[0] != '-' {
return false
}
arg = arg[1:]
if arg == "" {
return false
}
return (opt.shortName != "" && arg == opt.shortName) ||
(arg[0] == '-' && arg[1:] == opt.longName)
}
// updateOptsNoValue sets values or effects from opt into o or prev.
func updateOptsNoValue(
o options,
prev effect,
opt cmdLineOpt,
cmdName string,
) (updated options, chained effect, err error) {
if opt.updateNoValue != nil {
o, err = opt.updateNoValue(o)
if err != nil {
return o, prev, err
}
return o, prev, nil
}
next, err := opt.effect(o, cmdName)
if err != nil {
return o, prev, err
}
chained = chainEffect(prev, next)
return o, chained, nil
}
// chainEffect chans the next effect after the prev one. If prev is nil, eff
// only calls next. If next is nil, eff is prev; if prev is nil, eff is next.
func chainEffect(prev, next effect) (eff effect) {
if prev == nil {
return next
} else if next == nil {
return prev
}
eff = func() (err error) {
err = prev()
if err != nil { if err != nil {
return err return err
} }
return next()
} }
return nil
return eff
} }
func argMatches(a arg, v string) bool { // optsToArgs converts command line options into a list of arguments.
return v == "--"+a.longName || (a.shortName != "" && v == "-"+a.shortName) func optsToArgs(o options) (args []string) {
} for _, opt := range cmdLineOpts {
val, ok := opt.serialize(o)
func parse(exec string, ss []string) (o options, f effect, err error) { if !ok {
for i := 0; i < len(ss); i++ { continue
v := ss[i]
knownParam := false
for _, arg := range args {
if argMatches(arg, v) {
if arg.updateWithValue != nil {
if i+1 >= len(ss) {
return o, f, fmt.Errorf("got %s without argument", v)
}
i++
o, err = arg.updateWithValue(o, ss[i])
if err != nil {
return
}
} else if arg.updateNoValue != nil {
o, err = arg.updateNoValue(o)
if err != nil {
return
}
} else if arg.effect != nil {
var eff effect
eff, err = arg.effect(o, exec)
if err != nil {
return
}
if eff != nil {
prevf := f
f = func() (ferr error) {
if prevf != nil {
ferr = prevf()
}
if ferr == nil {
ferr = eff()
}
return ferr
}
}
}
knownParam = true
break
}
} }
if !knownParam {
return o, f, fmt.Errorf("unknown option %v", v) if opt.shortName != "" {
args = append(args, "-"+opt.shortName)
} else {
args = append(args, "--"+opt.longName)
}
if val != "" {
args = append(args, val)
} }
} }
return return args
}
func shortestFlag(a arg) string {
if a.shortName != "" {
return "-" + a.shortName
}
return "--" + a.longName
}
func serialize(o options) []string {
ss := []string{}
for _, arg := range args {
s := arg.serialize(o)
if s != nil {
ss = append(ss, append([]string{shortestFlag(arg)}, s...)...)
}
}
return ss
} }

View File

@ -12,7 +12,7 @@ import (
func testParseOK(t *testing.T, ss ...string) options { func testParseOK(t *testing.T, ss ...string) options {
t.Helper() t.Helper()
o, _, err := parse("", ss) o, _, err := parseCmdOpts("", ss)
require.NoError(t, err) require.NoError(t, err)
return o return o
@ -21,7 +21,7 @@ func testParseOK(t *testing.T, ss ...string) options {
func testParseErr(t *testing.T, descr string, ss ...string) { func testParseErr(t *testing.T, descr string, ss ...string) {
t.Helper() t.Helper()
_, _, err := parse("", ss) _, _, err := parseCmdOpts("", ss)
require.Error(t, err) require.Error(t, err)
} }
@ -38,11 +38,11 @@ func TestParseVerbose(t *testing.T) {
} }
func TestParseConfigFilename(t *testing.T) { func TestParseConfigFilename(t *testing.T) {
assert.Equal(t, "", testParseOK(t).configFilename, "empty is no config filename") assert.Equal(t, "", testParseOK(t).confFilename, "empty is no config filename")
assert.Equal(t, "path", testParseOK(t, "-c", "path").configFilename, "-c is config filename") assert.Equal(t, "path", testParseOK(t, "-c", "path").confFilename, "-c is config filename")
testParseParamMissing(t, "-c") testParseParamMissing(t, "-c")
assert.Equal(t, "path", testParseOK(t, "--config", "path").configFilename, "--config is config filename") assert.Equal(t, "path", testParseOK(t, "--config", "path").confFilename, "--config is config filename")
testParseParamMissing(t, "--config") testParseParamMissing(t, "--config")
} }
@ -103,7 +103,7 @@ func TestParseDisableUpdate(t *testing.T) {
// TODO(e.burkov): Remove after v0.108.0. // TODO(e.burkov): Remove after v0.108.0.
func TestParseDisableMemoryOptimization(t *testing.T) { func TestParseDisableMemoryOptimization(t *testing.T) {
o, eff, err := parse("", []string{"--no-mem-optimization"}) o, eff, err := parseCmdOpts("", []string{"--no-mem-optimization"})
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, eff) assert.Nil(t, eff)
@ -130,73 +130,73 @@ func TestParseUnknown(t *testing.T) {
testParseErr(t, "unknown dash", "-") testParseErr(t, "unknown dash", "-")
} }
func TestSerialize(t *testing.T) { func TestOptsToArgs(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
args []string
opts options opts options
ss []string
}{{ }{{
name: "empty", name: "empty",
args: []string{},
opts: options{}, opts: options{},
ss: []string{},
}, { }, {
name: "config_filename", name: "config_filename",
opts: options{configFilename: "path"}, args: []string{"-c", "path"},
ss: []string{"-c", "path"}, opts: options{confFilename: "path"},
}, { }, {
name: "work_dir", name: "work_dir",
args: []string{"-w", "path"},
opts: options{workDir: "path"}, opts: options{workDir: "path"},
ss: []string{"-w", "path"},
}, { }, {
name: "bind_host", name: "bind_host",
args: []string{"-h", "1.2.3.4"},
opts: options{bindHost: net.IP{1, 2, 3, 4}}, opts: options{bindHost: net.IP{1, 2, 3, 4}},
ss: []string{"-h", "1.2.3.4"},
}, { }, {
name: "bind_port", name: "bind_port",
args: []string{"-p", "666"},
opts: options{bindPort: 666}, opts: options{bindPort: 666},
ss: []string{"-p", "666"},
}, { }, {
name: "log_file", name: "log_file",
args: []string{"-l", "path"},
opts: options{logFile: "path"}, opts: options{logFile: "path"},
ss: []string{"-l", "path"},
}, { }, {
name: "pid_file", name: "pid_file",
args: []string{"--pidfile", "path"},
opts: options{pidFile: "path"}, opts: options{pidFile: "path"},
ss: []string{"--pidfile", "path"},
}, { }, {
name: "disable_update", name: "disable_update",
args: []string{"--no-check-update"},
opts: options{disableUpdate: true}, opts: options{disableUpdate: true},
ss: []string{"--no-check-update"},
}, { }, {
name: "control_action", name: "control_action",
args: []string{"-s", "run"},
opts: options{serviceControlAction: "run"}, opts: options{serviceControlAction: "run"},
ss: []string{"-s", "run"},
}, { }, {
name: "glinet_mode", name: "glinet_mode",
args: []string{"--glinet"},
opts: options{glinetMode: true}, opts: options{glinetMode: true},
ss: []string{"--glinet"},
}, { }, {
name: "multiple", name: "multiple",
opts: options{ args: []string{
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
},
ss: []string{
"-c", "config", "-c", "config",
"-w", "work", "-w", "work",
"-s", "run", "-s", "run",
"--pidfile", "pid", "--pidfile", "pid",
"--no-check-update", "--no-check-update",
}, },
opts: options{
serviceControlAction: "run",
confFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
},
}} }}
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
result := serialize(tc.opts) result := optsToArgs(tc.opts)
assert.ElementsMatch(t, tc.ss, result) assert.ElementsMatch(t, tc.args, result)
}) })
} }
} }

View File

@ -197,7 +197,7 @@ func handleServiceControlAction(opts options, clientBuildFS fs.FS) {
DisplayName: serviceDisplayName, DisplayName: serviceDisplayName,
Description: serviceDescription, Description: serviceDescription,
WorkingDirectory: pwd, WorkingDirectory: pwd,
Arguments: serialize(runOpts), Arguments: optsToArgs(runOpts),
} }
configureService(svcConfig) configureService(svcConfig)

63
internal/next/agh/agh.go Normal file
View File

@ -0,0 +1,63 @@
// Package agh contains common entities and interfaces of AdGuard Home.
package agh
import "context"
// Service is the interface for API servers.
//
// TODO(a.garipov): Consider adding a context to Start.
//
// TODO(a.garipov): Consider adding a Wait method or making an extension
// interface for that.
type Service interface {
// Start starts the service. It does not block.
Start() (err error)
// Shutdown gracefully stops the service. ctx is used to determine
// a timeout before trying to stop the service less gracefully.
Shutdown(ctx context.Context) (err error)
}
// type check
var _ Service = EmptyService{}
// EmptyService is a [Service] that does nothing.
//
// TODO(a.garipov): Remove if unnecessary.
type EmptyService struct{}
// Start implements the [Service] interface for EmptyService.
func (EmptyService) Start() (err error) { return nil }
// Shutdown implements the [Service] interface for EmptyService.
func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }
// ServiceWithConfig is an extension of the [Service] interface for services
// that can return their configuration.
//
// TODO(a.garipov): Consider removing this generic interface if we figure out
// how to make it testable in a better way.
type ServiceWithConfig[ConfigType any] interface {
Service
Config() (c ConfigType)
}
// type check
var _ ServiceWithConfig[struct{}] = (*EmptyServiceWithConfig[struct{}])(nil)
// EmptyServiceWithConfig is a ServiceWithConfig that does nothing. Its Config
// method returns Conf.
//
// TODO(a.garipov): Remove if unnecessary.
type EmptyServiceWithConfig[ConfigType any] struct {
EmptyService
Conf ConfigType
}
// Config implements the [ServiceWithConfig] interface for
// *EmptyServiceWithConfig.
func (s *EmptyServiceWithConfig[ConfigType]) Config() (conf ConfigType) {
return s.Conf
}

View File

@ -8,39 +8,49 @@ import (
"context" "context"
"io/fs" "io/fs"
"math/rand" "math/rand"
"net/netip" "os"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/v1/websvc" "github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
) )
// Main is the entry point of application. // Main is the entry point of application.
func Main(clientBuildFS fs.FS) { func Main(clientBuildFS fs.FS) {
// # Initial Configuration // Initial Configuration
start := time.Now() start := time.Now()
rand.Seed(start.UnixNano()) rand.Seed(start.UnixNano())
// TODO(a.garipov): Set up logging. // TODO(a.garipov): Set up logging.
// # Web Service log.Info("starting adguard home, version %s, pid %d", version.Version(), os.Getpid())
// Web Service
// TODO(a.garipov): Use in the Web service. // TODO(a.garipov): Use in the Web service.
_ = clientBuildFS _ = clientBuildFS
// TODO(a.garipov): Make configurable. // TODO(a.garipov): Set up configuration file name.
web := websvc.New(&websvc.Config{ const confFile = "AdGuardHome.1.yaml"
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:3001")},
Start: start,
Timeout: 60 * time.Second,
})
err := web.Start() confMgr, err := configmgr.New(confFile, start)
fatalOnError(err)
web := confMgr.Web()
err = web.Start()
fatalOnError(err)
dns := confMgr.DNS()
err = dns.Start()
fatalOnError(err) fatalOnError(err)
sigHdlr := newSignalHandler( sigHdlr := newSignalHandler(
confFile,
start,
web, web,
dns,
) )
go sigHdlr.handle() go sigHdlr.handle()

118
internal/next/cmd/signal.go Normal file
View File

@ -0,0 +1,118 @@
package cmd
import (
"os"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
"github.com/AdguardTeam/golibs/log"
)
// signalHandler processes incoming signals and shuts services down.
type signalHandler struct {
// signal is the channel to which OS signals are sent.
signal chan os.Signal
// confFile is the path to the configuration file.
confFile string
// start is the time at which AdGuard Home has been started.
start time.Time
// services are the services that are shut down before application exiting.
services []agh.Service
}
// handle processes OS signals.
func (h *signalHandler) handle() {
defer log.OnPanic("signalHandler.handle")
for sig := range h.signal {
log.Info("sighdlr: received signal %q", sig)
if aghos.IsReconfigureSignal(sig) {
h.reconfigure()
} else if aghos.IsShutdownSignal(sig) {
status := h.shutdown()
log.Info("sighdlr: exiting with status %d", status)
os.Exit(status)
}
}
}
// reconfigure rereads the configuration file and updates and restarts services.
func (h *signalHandler) reconfigure() {
log.Info("sighdlr: reconfiguring adguard home")
status := h.shutdown()
if status != statusSuccess {
log.Info("sighdlr: reconfiruging: exiting with status %d", status)
os.Exit(status)
}
// TODO(a.garipov): This is a very rough way to do it. Some services can be
// reconfigured without the full shutdown, and the error handling is
// currently not the best.
confMgr, err := configmgr.New(h.confFile, h.start)
fatalOnError(err)
web := confMgr.Web()
err = web.Start()
fatalOnError(err)
dns := confMgr.DNS()
err = dns.Start()
fatalOnError(err)
h.services = []agh.Service{
dns,
web,
}
log.Info("sighdlr: successfully reconfigured adguard home")
}
// Exit status constants.
const (
statusSuccess = 0
statusError = 1
)
// shutdown gracefully shuts down all services.
func (h *signalHandler) shutdown() (status int) {
ctx, cancel := ctxWithDefaultTimeout()
defer cancel()
status = statusSuccess
log.Info("sighdlr: shutting down services")
for i, service := range h.services {
err := service.Shutdown(ctx)
if err != nil {
log.Error("sighdlr: shutting down service at index %d: %s", i, err)
status = statusError
}
}
return status
}
// newSignalHandler returns a new signalHandler that shuts down svcs.
func newSignalHandler(confFile string, start time.Time, svcs ...agh.Service) (h *signalHandler) {
h = &signalHandler{
signal: make(chan os.Signal, 1),
confFile: confFile,
start: start,
services: svcs,
}
aghos.NotifyShutdownSignal(h.signal)
aghos.NotifyReconfigureSignal(h.signal)
return h
}

View File

@ -0,0 +1,40 @@
package configmgr
import (
"net/netip"
"github.com/AdguardTeam/golibs/timeutil"
)
// Configuration Structures
// config is the top-level on-disk configuration structure.
type config struct {
DNS *dnsConfig `yaml:"dns"`
HTTP *httpConfig `yaml:"http"`
// TODO(a.garipov): Use.
SchemaVersion int `yaml:"schema_version"`
// TODO(a.garipov): Use.
DebugPprof bool `yaml:"debug_pprof"`
Verbose bool `yaml:"verbose"`
}
// dnsConfig is the on-disk DNS configuration.
//
// TODO(a.garipov): Validate.
type dnsConfig struct {
Addresses []netip.AddrPort `yaml:"addresses"`
BootstrapDNS []string `yaml:"bootstrap_dns"`
UpstreamDNS []string `yaml:"upstream_dns"`
UpstreamTimeout timeutil.Duration `yaml:"upstream_timeout"`
}
// httpConfig is the on-disk web API configuration.
//
// TODO(a.garipov): Validate.
type httpConfig struct {
Addresses []netip.AddrPort `yaml:"addresses"`
SecureAddresses []netip.AddrPort `yaml:"secure_addresses"`
Timeout timeutil.Duration `yaml:"timeout"`
ForceHTTPS bool `yaml:"force_https"`
}

View File

@ -0,0 +1,205 @@
// Package configmgr defines the AdGuard Home on-disk configuration entities and
// configuration manager.
package configmgr
import (
"context"
"fmt"
"os"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"gopkg.in/yaml.v3"
)
// Configuration Manager
// Manager handles full and partial changes in the configuration, persisting
// them to disk if necessary.
type Manager struct {
// updMu makes sure that at most one reconfiguration is performed at a time.
// updMu protects all fields below.
updMu *sync.RWMutex
// dns is the DNS service.
dns *dnssvc.Service
// Web is the Web API service.
web *websvc.Service
// current is the current configuration.
current *config
// fileName is the name of the configuration file.
fileName string
}
// New creates a new *Manager that persists changes to the file pointed to by
// fileName. It reads the configuration file and populates the service fields.
// start is the startup time of AdGuard Home.
func New(fileName string, start time.Time) (m *Manager, err error) {
defer func() { err = errors.Annotate(err, "reading config") }()
conf := &config{}
f, err := os.Open(fileName)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()
err = yaml.NewDecoder(f).Decode(conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
// TODO(a.garipov): Move into a separate function and add other logging
// settings.
if conf.Verbose {
log.SetLevel(log.DEBUG)
}
// TODO(a.garipov): Validate the configuration structure. Return an error
// if it's incorrect.
m = &Manager{
updMu: &sync.RWMutex{},
current: conf,
fileName: fileName,
}
// TODO(a.garipov): Get the context with the timeout from the arguments?
const assemblyTimeout = 5 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), assemblyTimeout)
defer cancel()
err = m.assemble(ctx, conf, start)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
}
return m, nil
}
// assemble creates all services and puts them into the corresponding fields.
// The fields of conf must not be modified after calling assemble.
func (m *Manager) assemble(ctx context.Context, conf *config, start time.Time) (err error) {
dnsConf := &dnssvc.Config{
Addresses: conf.DNS.Addresses,
BootstrapServers: conf.DNS.BootstrapDNS,
UpstreamServers: conf.DNS.UpstreamDNS,
UpstreamTimeout: conf.DNS.UpstreamTimeout.Duration,
}
err = m.updateDNS(ctx, dnsConf)
if err != nil {
return fmt.Errorf("assembling dnssvc: %w", err)
}
webSvcConf := &websvc.Config{
ConfigManager: m,
// TODO(a.garipov): Fill from config file.
TLS: nil,
Start: start,
Addresses: conf.HTTP.Addresses,
SecureAddresses: conf.HTTP.SecureAddresses,
Timeout: conf.HTTP.Timeout.Duration,
ForceHTTPS: conf.HTTP.ForceHTTPS,
}
err = m.updateWeb(ctx, webSvcConf)
if err != nil {
return fmt.Errorf("assembling websvc: %w", err)
}
return nil
}
// DNS returns the current DNS service. It is safe for concurrent use.
func (m *Manager) DNS() (dns agh.ServiceWithConfig[*dnssvc.Config]) {
m.updMu.RLock()
defer m.updMu.RUnlock()
return m.dns
}
// UpdateDNS implements the [websvc.ConfigManager] interface for *Manager. The
// fields of c must not be modified after calling UpdateDNS.
func (m *Manager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
m.updMu.Lock()
defer m.updMu.Unlock()
// TODO(a.garipov): Update and write the configuration file. Return an
// error if something went wrong.
err = m.updateDNS(ctx, c)
if err != nil {
return fmt.Errorf("reassembling dnssvc: %w", err)
}
return nil
}
// updateDNS recreates the DNS service. m.updMu is expected to be locked.
func (m *Manager) updateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
if prev := m.dns; prev != nil {
err = prev.Shutdown(ctx)
if err != nil {
return fmt.Errorf("shutting down dns svc: %w", err)
}
}
svc, err := dnssvc.New(c)
if err != nil {
return fmt.Errorf("creating dns svc: %w", err)
}
m.dns = svc
return nil
}
// Web returns the current web service. It is safe for concurrent use.
func (m *Manager) Web() (web agh.ServiceWithConfig[*websvc.Config]) {
m.updMu.RLock()
defer m.updMu.RUnlock()
return m.web
}
// UpdateWeb implements the [websvc.ConfigManager] interface for *Manager. The
// fields of c must not be modified after calling UpdateWeb.
func (m *Manager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
m.updMu.Lock()
defer m.updMu.Unlock()
// TODO(a.garipov): Update and write the configuration file. Return an
// error if something went wrong.
err = m.updateWeb(ctx, c)
if err != nil {
return fmt.Errorf("reassembling websvc: %w", err)
}
return nil
}
// updateWeb recreates the web service. m.upd is expected to be locked.
func (m *Manager) updateWeb(ctx context.Context, c *websvc.Config) (err error) {
if prev := m.web; prev != nil {
err = prev.Shutdown(ctx)
if err != nil {
return fmt.Errorf("shutting down web svc: %w", err)
}
}
m.web = websvc.New(c)
return nil
}

View File

@ -9,9 +9,10 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sync/atomic"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/v1/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
// TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes // TODO(a.garipov): Add a “dnsproxy proxy” package to shield us from changes
// and replacement of module dnsproxy. // and replacement of module dnsproxy.
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
@ -47,6 +48,14 @@ type Config struct {
// Service is the AdGuard Home DNS service. A nil *Service is a valid // Service is the AdGuard Home DNS service. A nil *Service is a valid
// [agh.Service] that does nothing. // [agh.Service] that does nothing.
type Service struct { type Service struct {
// running is an atomic boolean value. Keep it the first value in the
// struct to ensure atomic alignment. 0 means that the service is not
// running, 1 means that it is running.
//
// TODO(a.garipov): Use [atomic.Bool] in Go 1.19 or get rid of it
// completely.
running uint64
proxy *proxy.Proxy proxy *proxy.Proxy
bootstraps []string bootstraps []string
upstreams []string upstreams []string
@ -160,6 +169,17 @@ func (svc *Service) Start() (err error) {
return nil return nil
} }
defer func() {
// TODO(a.garipov): [proxy.Proxy.Start] doesn't actually have any way to
// tell when all servers are actually up, so at best this is merely an
// assumption.
if err != nil {
atomic.StoreUint64(&svc.running, 0)
} else {
atomic.StoreUint64(&svc.running, 1)
}
}()
return svc.proxy.Start() return svc.proxy.Start()
} }
@ -173,13 +193,27 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
return svc.proxy.Stop() return svc.proxy.Stop()
} }
// Config returns the current configuration of the web service. // Config returns the current configuration of the web service. Config must not
// be called simultaneously with Start. If svc was initialized with ":0"
// addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) Config() (c *Config) { func (svc *Service) Config() (c *Config) {
// TODO(a.garipov): Do we need to get the TCP addresses separately? // TODO(a.garipov): Do we need to get the TCP addresses separately?
udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
addrs := make([]netip.AddrPort, len(udpAddrs)) var addrs []netip.AddrPort
for i, a := range udpAddrs { if atomic.LoadUint64(&svc.running) == 1 {
addrs[i] = a.(*net.UDPAddr).AddrPort() udpAddrs := svc.proxy.Addrs(proxy.ProtoUDP)
addrs = make([]netip.AddrPort, len(udpAddrs))
for i, a := range udpAddrs {
addrs[i] = a.(*net.UDPAddr).AddrPort()
}
} else {
conf := svc.proxy.Config
udpAddrs := conf.UDPListenAddr
addrs = make([]netip.AddrPort, len(udpAddrs))
for i, a := range udpAddrs {
addrs[i] = a.AddrPort()
}
} }
c = &Config{ c = &Config{

View File

@ -7,7 +7,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/v1/dnssvc" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"

View File

@ -0,0 +1,84 @@
package websvc
import (
"encoding/json"
"fmt"
"net/http"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
)
// DNS Settings Handlers
// ReqPatchSettingsDNS describes the request to the PATCH /api/v1/settings/dns
// HTTP API.
type ReqPatchSettingsDNS struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
BootstrapServers []string `json:"bootstrap_servers"`
UpstreamServers []string `json:"upstream_servers"`
UpstreamTimeout JSONDuration `json:"upstream_timeout"`
}
// HTTPAPIDNSSettings are the DNS settings as used by the HTTP API. See the
// DnsSettings object in the OpenAPI specification.
type HTTPAPIDNSSettings struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
BootstrapServers []string `json:"bootstrap_servers"`
UpstreamServers []string `json:"upstream_servers"`
UpstreamTimeout JSONDuration `json:"upstream_timeout"`
}
// handlePatchSettingsDNS is the handler for the PATCH /api/v1/settings/dns HTTP
// API.
func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Request) {
req := &ReqPatchSettingsDNS{
Addresses: []netip.AddrPort{},
BootstrapServers: []string{},
UpstreamServers: []string{},
}
// TODO(a.garipov): Validate nulls and proper JSON patch.
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
writeJSONErrorResponse(w, r, fmt.Errorf("decoding: %w", err))
return
}
newConf := &dnssvc.Config{
Addresses: req.Addresses,
BootstrapServers: req.BootstrapServers,
UpstreamServers: req.UpstreamServers,
UpstreamTimeout: time.Duration(req.UpstreamTimeout),
}
ctx := r.Context()
err = svc.confMgr.UpdateDNS(ctx, newConf)
if err != nil {
writeJSONErrorResponse(w, r, fmt.Errorf("updating: %w", err))
return
}
newSvc := svc.confMgr.DNS()
err = newSvc.Start()
if err != nil {
writeJSONErrorResponse(w, r, fmt.Errorf("starting new service: %w", err))
return
}
writeJSONOKResponse(w, r, &HTTPAPIDNSSettings{
Addresses: newConf.Addresses,
BootstrapServers: newConf.BootstrapServers,
UpstreamServers: newConf.UpstreamServers,
UpstreamTimeout: JSONDuration(newConf.UpstreamTimeout),
})
}

View File

@ -0,0 +1,69 @@
package websvc_test
import (
"context"
"encoding/json"
"net/http"
"net/netip"
"net/url"
"sync/atomic"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_HandlePatchSettingsDNS(t *testing.T) {
wantDNS := &websvc.HTTPAPIDNSSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:53")},
BootstrapServers: []string{"1.0.0.1"},
UpstreamServers: []string{"1.1.1.1"},
UpstreamTimeout: websvc.JSONDuration(2 * time.Second),
}
// TODO(a.garipov): Use [atomic.Bool] in Go 1.19.
var numStarted uint64
confMgr := newConfigManager()
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
return &aghtest.ServiceWithConfig[*dnssvc.Config]{
OnStart: func() (err error) {
atomic.AddUint64(&numStarted, 1)
return nil
},
OnShutdown: func(_ context.Context) (err error) { panic("not implemented") },
OnConfig: func() (c *dnssvc.Config) { panic("not implemented") },
}
}
confMgr.onUpdateDNS = func(ctx context.Context, c *dnssvc.Config) (err error) {
return nil
}
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SettingsDNS,
}
req := jobj{
"addresses": wantDNS.Addresses,
"bootstrap_servers": wantDNS.BootstrapServers,
"upstream_servers": wantDNS.UpstreamServers,
"upstream_timeout": wantDNS.UpstreamTimeout,
}
respBody := httpPatch(t, u, req, http.StatusOK)
resp := &websvc.HTTPAPIDNSSettings{}
err := json.Unmarshal(respBody, resp)
require.NoError(t, err)
assert.Equal(t, uint64(1), numStarted)
assert.Equal(t, wantDNS, resp)
assert.Equal(t, wantDNS, resp)
}

View File

@ -0,0 +1,110 @@
package websvc
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/golibs/log"
)
// HTTP Settings Handlers
// ReqPatchSettingsHTTP describes the request to the PATCH /api/v1/settings/http
// HTTP API.
type ReqPatchSettingsHTTP struct {
// TODO(a.garipov): Add more as we go.
//
// TODO(a.garipov): Add wait time.
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout JSONDuration `json:"timeout"`
}
// HTTPAPIHTTPSettings are the HTTP settings as used by the HTTP API. See the
// HttpSettings object in the OpenAPI specification.
type HTTPAPIHTTPSettings struct {
// TODO(a.garipov): Add more as we go.
Addresses []netip.AddrPort `json:"addresses"`
SecureAddresses []netip.AddrPort `json:"secure_addresses"`
Timeout JSONDuration `json:"timeout"`
ForceHTTPS bool `json:"force_https"`
}
// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http
// HTTP API.
func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) {
req := &ReqPatchSettingsHTTP{}
// TODO(a.garipov): Validate nulls and proper JSON patch.
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
writeJSONErrorResponse(w, r, fmt.Errorf("decoding: %w", err))
return
}
newConf := &Config{
ConfigManager: svc.confMgr,
TLS: svc.tls,
Addresses: req.Addresses,
SecureAddresses: req.SecureAddresses,
Timeout: time.Duration(req.Timeout),
ForceHTTPS: svc.forceHTTPS,
}
writeJSONOKResponse(w, r, &HTTPAPIHTTPSettings{
Addresses: newConf.Addresses,
SecureAddresses: newConf.SecureAddresses,
Timeout: JSONDuration(newConf.Timeout),
ForceHTTPS: newConf.ForceHTTPS,
})
cancelUpd := func() {}
updCtx := context.Background()
ctx := r.Context()
if deadline, ok := ctx.Deadline(); ok {
updCtx, cancelUpd = context.WithDeadline(updCtx, deadline)
}
// Launch the new HTTP service in a separate goroutine to let this handler
// finish and thus, this server to shutdown.
go func() {
defer cancelUpd()
updErr := svc.confMgr.UpdateWeb(updCtx, newConf)
if updErr != nil {
writeJSONErrorResponse(w, r, fmt.Errorf("updating: %w", updErr))
return
}
// TODO(a.garipov): Consider better ways to do this.
const maxUpdDur = 10 * time.Second
updStart := time.Now()
var newSvc agh.ServiceWithConfig[*Config]
for newSvc = svc.confMgr.Web(); newSvc == svc; {
if time.Since(updStart) >= maxUpdDur {
log.Error("websvc: failed to update svc after %s", maxUpdDur)
return
}
log.Debug("websvc: waiting for new websvc to be configured")
time.Sleep(1 * time.Second)
}
updErr = newSvc.Start()
if updErr != nil {
log.Error("websvc: new svc failed to start with error: %s", updErr)
}
}()
}

View File

@ -0,0 +1,63 @@
package websvc_test
import (
"context"
"crypto/tls"
"encoding/json"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_HandlePatchSettingsHTTP(t *testing.T) {
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.1.1:443")},
Timeout: websvc.JSONDuration(10 * time.Second),
ForceHTTPS: false,
}
confMgr := newConfigManager()
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
return websvc.New(&websvc.Config{
TLS: &tls.Config{
Certificates: []tls.Certificate{{}},
},
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Timeout: 5 * time.Second,
ForceHTTPS: true,
})
}
confMgr.onUpdateWeb = func(ctx context.Context, c *websvc.Config) (err error) {
return nil
}
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SettingsHTTP,
}
req := jobj{
"addresses": wantWeb.Addresses,
"secure_addresses": wantWeb.SecureAddresses,
"timeout": wantWeb.Timeout,
"force_https": wantWeb.ForceHTTPS,
}
respBody := httpPatch(t, u, req, http.StatusOK)
resp := &websvc.HTTPAPIHTTPSettings{}
err := json.Unmarshal(respBody, resp)
require.NoError(t, err)
assert.Equal(t, wantWeb, resp)
}

View File

@ -0,0 +1,143 @@
package websvc
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/log"
)
// JSON Utilities
// nsecPerMsec is the number of nanoseconds in a millisecond.
const nsecPerMsec = float64(time.Millisecond / time.Nanosecond)
// JSONDuration is a time.Duration that can be decoded from JSON and encoded
// into JSON according to our API conventions.
type JSONDuration time.Duration
// type check
var _ json.Marshaler = JSONDuration(0)
// MarshalJSON implements the json.Marshaler interface for JSONDuration. err is
// always nil.
func (d JSONDuration) MarshalJSON() (b []byte, err error) {
msec := float64(time.Duration(d)) / nsecPerMsec
b = strconv.AppendFloat(nil, msec, 'f', -1, 64)
return b, nil
}
// type check
var _ json.Unmarshaler = (*JSONDuration)(nil)
// UnmarshalJSON implements the json.Marshaler interface for *JSONDuration.
func (d *JSONDuration) UnmarshalJSON(b []byte) (err error) {
if d == nil {
return fmt.Errorf("json duration is nil")
}
msec, err := strconv.ParseFloat(string(b), 64)
if err != nil {
return fmt.Errorf("parsing json time: %w", err)
}
*d = JSONDuration(int64(msec * nsecPerMsec))
return nil
}
// JSONTime is a time.Time that can be decoded from JSON and encoded into JSON
// according to our API conventions.
type JSONTime time.Time
// type check
var _ json.Marshaler = JSONTime{}
// MarshalJSON implements the json.Marshaler interface for JSONTime. err is
// always nil.
func (t JSONTime) MarshalJSON() (b []byte, err error) {
msec := float64(time.Time(t).UnixNano()) / nsecPerMsec
b = strconv.AppendFloat(nil, msec, 'f', -1, 64)
return b, nil
}
// type check
var _ json.Unmarshaler = (*JSONTime)(nil)
// UnmarshalJSON implements the json.Marshaler interface for *JSONTime.
func (t *JSONTime) UnmarshalJSON(b []byte) (err error) {
if t == nil {
return fmt.Errorf("json time is nil")
}
msec, err := strconv.ParseFloat(string(b), 64)
if err != nil {
return fmt.Errorf("parsing json time: %w", err)
}
*t = JSONTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC())
return nil
}
// writeJSONOKResponse writes headers with the code 200 OK, encodes v into w,
// and logs any errors it encounters. r is used to get additional information
// from the request.
func writeJSONOKResponse(w http.ResponseWriter, r *http.Request, v any) {
writeJSONResponse(w, r, v, http.StatusOK)
}
// writeJSONResponse writes headers with code, encodes v into w, and logs any
// errors it encounters. r is used to get additional information from the
// request.
func writeJSONResponse(w http.ResponseWriter, r *http.Request, v any, code int) {
// TODO(a.garipov): Put some of these to a middleware.
h := w.Header()
h.Set(aghhttp.HdrNameContentType, aghhttp.HdrValApplicationJSON)
h.Set(aghhttp.HdrNameServer, aghhttp.UserAgent())
w.WriteHeader(code)
err := json.NewEncoder(w).Encode(v)
if err != nil {
log.Error("websvc: writing resp to %s %s: %s", r.Method, r.URL.Path, err)
}
}
// ErrorCode is the error code as used by the HTTP API. See the ErrorCode
// definition in the OpenAPI specification.
type ErrorCode string
// ErrorCode constants.
//
// TODO(a.garipov): Expand and document codes.
const (
// ErrorCodeTMP000 is the temporary error code used for all errors.
ErrorCodeTMP000 = ""
)
// HTTPAPIErrorResp is the error response as used by the HTTP API. See the
// BadRequestResp, InternalServerErrorResp, and similar objects in the OpenAPI
// specification.
type HTTPAPIErrorResp struct {
Code ErrorCode `json:"code"`
Msg string `json:"msg"`
}
// writeJSONErrorResponse encodes err as a JSON error into w, and logs any
// errors it encounters. r is used to get additional information from the
// request.
func writeJSONErrorResponse(w http.ResponseWriter, r *http.Request, err error) {
log.Error("websvc: %s %s: %s", r.Method, r.URL.Path, err)
writeJSONResponse(w, r, &HTTPAPIErrorResp{
Code: ErrorCodeTMP000,
Msg: err.Error(),
}, http.StatusUnprocessableEntity)
}

View File

@ -0,0 +1,114 @@
package websvc_test
import (
"encoding/json"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testJSONTime is the JSON time for tests.
var testJSONTime = websvc.JSONTime(time.Unix(1_234_567_890, 123_456_000).UTC())
// testJSONTimeStr is the string with the JSON encoding of testJSONTime.
const testJSONTimeStr = "1234567890123.456"
func TestJSONTime_MarshalJSON(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
in websvc.JSONTime
want []byte
}{{
name: "unix_zero",
wantErrMsg: "",
in: websvc.JSONTime(time.Unix(0, 0)),
want: []byte("0"),
}, {
name: "empty",
wantErrMsg: "",
in: websvc.JSONTime{},
want: []byte("-6795364578871.345"),
}, {
name: "time",
wantErrMsg: "",
in: testJSONTime,
want: []byte(testJSONTimeStr),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := tc.in.MarshalJSON()
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
t.Run("json", func(t *testing.T) {
in := &struct {
A websvc.JSONTime
}{
A: testJSONTime,
}
got, err := json.Marshal(in)
require.NoError(t, err)
assert.Equal(t, []byte(`{"A":`+testJSONTimeStr+`}`), got)
})
}
func TestJSONTime_UnmarshalJSON(t *testing.T) {
testCases := []struct {
name string
wantErrMsg string
want websvc.JSONTime
data []byte
}{{
name: "time",
wantErrMsg: "",
want: testJSONTime,
data: []byte(testJSONTimeStr),
}, {
name: "bad",
wantErrMsg: `parsing json time: strconv.ParseFloat: parsing "{}": ` +
`invalid syntax`,
want: websvc.JSONTime{},
data: []byte(`{}`),
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var got websvc.JSONTime
err := got.UnmarshalJSON(tc.data)
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
assert.Equal(t, tc.want, got)
})
}
t.Run("nil", func(t *testing.T) {
err := (*websvc.JSONTime)(nil).UnmarshalJSON([]byte("0"))
require.Error(t, err)
msg := err.Error()
assert.Equal(t, "json time is nil", msg)
})
t.Run("json", func(t *testing.T) {
want := testJSONTime
var got struct {
A websvc.JSONTime
}
err := json.Unmarshal([]byte(`{"A":`+testJSONTimeStr+`}`), &got)
require.NoError(t, err)
assert.Equal(t, want, got.A)
})
}

View File

@ -0,0 +1,11 @@
package websvc
// Path constants
const (
PathHealthCheck = "/health-check"
PathV1SettingsAll = "/api/v1/settings/all"
PathV1SettingsDNS = "/api/v1/settings/dns"
PathV1SettingsHTTP = "/api/v1/settings/http"
PathV1SystemInfo = "/api/v1/system/info"
)

View File

@ -0,0 +1,42 @@
package websvc
import (
"net/http"
)
// All Settings Handlers
// RespGetV1SettingsAll describes the response of the GET /api/v1/settings/all
// HTTP API.
type RespGetV1SettingsAll struct {
// TODO(a.garipov): Add more as we go.
DNS *HTTPAPIDNSSettings `json:"dns"`
HTTP *HTTPAPIHTTPSettings `json:"http"`
}
// handleGetSettingsAll is the handler for the GET /api/v1/settings/all HTTP
// API.
func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request) {
dnsSvc := svc.confMgr.DNS()
dnsConf := dnsSvc.Config()
webSvc := svc.confMgr.Web()
httpConf := webSvc.Config()
// TODO(a.garipov): Add all currently supported parameters.
writeJSONOKResponse(w, r, &RespGetV1SettingsAll{
DNS: &HTTPAPIDNSSettings{
Addresses: dnsConf.Addresses,
BootstrapServers: dnsConf.BootstrapServers,
UpstreamServers: dnsConf.UpstreamServers,
UpstreamTimeout: JSONDuration(dnsConf.UpstreamTimeout),
},
HTTP: &HTTPAPIHTTPSettings{
Addresses: httpConf.Addresses,
SecureAddresses: httpConf.SecureAddresses,
Timeout: JSONDuration(httpConf.Timeout),
ForceHTTPS: httpConf.ForceHTTPS,
},
})
}

View File

@ -0,0 +1,75 @@
package websvc_test
import (
"crypto/tls"
"encoding/json"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestService_HandleGetSettingsAll(t *testing.T) {
// TODO(a.garipov): Add all currently supported parameters.
wantDNS := &websvc.HTTPAPIDNSSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:53")},
BootstrapServers: []string{"94.140.14.140", "94.140.14.141"},
UpstreamServers: []string{"94.140.14.14", "1.1.1.1"},
UpstreamTimeout: websvc.JSONDuration(1 * time.Second),
}
wantWeb := &websvc.HTTPAPIHTTPSettings{
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:80")},
SecureAddresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:443")},
Timeout: websvc.JSONDuration(5 * time.Second),
ForceHTTPS: true,
}
confMgr := newConfigManager()
confMgr.onDNS = func() (s agh.ServiceWithConfig[*dnssvc.Config]) {
c, err := dnssvc.New(&dnssvc.Config{
Addresses: wantDNS.Addresses,
UpstreamServers: wantDNS.UpstreamServers,
BootstrapServers: wantDNS.BootstrapServers,
UpstreamTimeout: time.Duration(wantDNS.UpstreamTimeout),
})
require.NoError(t, err)
return c
}
confMgr.onWeb = func() (s agh.ServiceWithConfig[*websvc.Config]) {
return websvc.New(&websvc.Config{
TLS: &tls.Config{
Certificates: []tls.Certificate{{}},
},
Addresses: wantWeb.Addresses,
SecureAddresses: wantWeb.SecureAddresses,
Timeout: time.Duration(wantWeb.Timeout),
ForceHTTPS: true,
})
}
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathV1SettingsAll,
}
body := httpGet(t, u, http.StatusOK)
resp := &websvc.RespGetV1SettingsAll{}
err := json.Unmarshal(body, resp)
require.NoError(t, err)
assert.Equal(t, wantDNS, resp.DNS)
assert.Equal(t, wantWeb, resp.HTTP)
}

View File

@ -16,20 +16,20 @@ type RespGetV1SystemInfo struct {
Channel string `json:"channel"` Channel string `json:"channel"`
OS string `json:"os"` OS string `json:"os"`
NewVersion string `json:"new_version,omitempty"` NewVersion string `json:"new_version,omitempty"`
Start jsonTime `json:"start"` Start JSONTime `json:"start"`
Version string `json:"version"` Version string `json:"version"`
} }
// handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP // handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP
// API. // API.
func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) { func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) {
writeJSONResponse(w, r, &RespGetV1SystemInfo{ writeJSONOKResponse(w, r, &RespGetV1SystemInfo{
Arch: runtime.GOARCH, Arch: runtime.GOARCH,
Channel: version.Channel(), Channel: version.Channel(),
OS: runtime.GOOS, OS: runtime.GOOS,
// TODO(a.garipov): Fill this when we have an updater. // TODO(a.garipov): Fill this when we have an updater.
NewVersion: "", NewVersion: "",
Start: jsonTime(svc.start), Start: JSONTime(svc.start),
Version: version.Version(), Version: version.Version(),
}) })
} }

View File

@ -8,16 +8,17 @@ import (
"testing" "testing"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/v1/websvc" "github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestService_handleGetV1SystemInfo(t *testing.T) { func TestService_handleGetV1SystemInfo(t *testing.T) {
_, addr := newTestServer(t) confMgr := newConfigManager()
_, addr := newTestServer(t, confMgr)
u := &url.URL{ u := &url.URL{
Scheme: "http", Scheme: "http",
Host: addr, Host: addr.String(),
Path: websvc.PathV1SystemInfo, Path: websvc.PathV1SystemInfo,
} }

View File

@ -0,0 +1,31 @@
package websvc
import (
"net"
"sync"
)
// Wait Listener
// waitListener is a wrapper around a listener that also calls wg.Done() on the
// first call to Accept. It is useful in situations where it is important to
// catch the precise moment of the first call to Accept, for example when
// starting an HTTP server.
//
// TODO(a.garipov): Move to aghnet?
type waitListener struct {
net.Listener
firstAcceptWG *sync.WaitGroup
firstAcceptOnce sync.Once
}
// type check
var _ net.Listener = (*waitListener)(nil)
// Accept implements the [net.Listener] interface for *waitListener.
func (l *waitListener) Accept() (conn net.Conn, err error) {
l.firstAcceptOnce.Do(l.firstAcceptWG.Done)
return l.Listener.Accept()
}

View File

@ -0,0 +1,46 @@
package websvc
import (
"net"
"sync"
"sync/atomic"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghchan"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
)
func TestWaitListener_Accept(t *testing.T) {
// TODO(a.garipov): use atomic.Bool in Go 1.19.
var numAcceptCalls uint32
var l net.Listener = &aghtest.Listener{
OnAccept: func() (conn net.Conn, err error) {
atomic.AddUint32(&numAcceptCalls, 1)
return nil, nil
},
OnAddr: func() (addr net.Addr) { panic("not implemented") },
OnClose: func() (err error) { panic("not implemented") },
}
wg := &sync.WaitGroup{}
wg.Add(1)
done := make(chan struct{})
go aghchan.MustReceive(done, testTimeout)
go func() {
var wrapper net.Listener = &waitListener{
Listener: l,
firstAcceptWG: wg,
}
_, _ = wrapper.Accept()
}()
wg.Wait()
close(done)
assert.Equal(t, uint32(1), atomic.LoadUint32(&numAcceptCalls))
}

View File

@ -1,4 +1,7 @@
// Package websvc contains the AdGuard Home web service. // Package websvc contains the AdGuard Home HTTP API service.
//
// NOTE: Packages other than cmd must not import this package, as it imports
// most other packages.
// //
// TODO(a.garipov): Add tests. // TODO(a.garipov): Add tests.
package websvc package websvc
@ -14,18 +17,35 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/v1/agh" "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
httptreemux "github.com/dimfeld/httptreemux/v5" httptreemux "github.com/dimfeld/httptreemux/v5"
) )
// ConfigManager is the configuration manager interface.
type ConfigManager interface {
DNS() (svc agh.ServiceWithConfig[*dnssvc.Config])
Web() (svc agh.ServiceWithConfig[*Config])
UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error)
UpdateWeb(ctx context.Context, c *Config) (err error)
}
// Config is the AdGuard Home web service configuration structure. // Config is the AdGuard Home web service configuration structure.
type Config struct { type Config struct {
// ConfigManager is used to show information about services as well as
// dynamically reconfigure them.
ConfigManager ConfigManager
// TLS is the optional TLS configuration. If TLS is not nil, // TLS is the optional TLS configuration. If TLS is not nil,
// SecureAddresses must not be empty. // SecureAddresses must not be empty.
TLS *tls.Config TLS *tls.Config
// Start is the time of start of AdGuard Home.
Start time.Time
// Addresses are the addresses on which to serve the plain HTTP API. // Addresses are the addresses on which to serve the plain HTTP API.
Addresses []netip.AddrPort Addresses []netip.AddrPort
@ -33,40 +53,48 @@ type Config struct {
// SecureAddresses is not empty, TLS must not be nil. // SecureAddresses is not empty, TLS must not be nil.
SecureAddresses []netip.AddrPort SecureAddresses []netip.AddrPort
// Start is the time of start of AdGuard Home.
Start time.Time
// Timeout is the timeout for all server operations. // Timeout is the timeout for all server operations.
Timeout time.Duration Timeout time.Duration
// ForceHTTPS tells if all requests to Addresses should be redirected to a
// secure address instead.
//
// TODO(a.garipov): Use; define rules, which address to redirect to.
ForceHTTPS bool
} }
// Service is the AdGuard Home web service. A nil *Service is a valid // Service is the AdGuard Home web service. A nil *Service is a valid
// [agh.Service] that does nothing. // [agh.Service] that does nothing.
type Service struct { type Service struct {
tls *tls.Config confMgr ConfigManager
servers []*http.Server tls *tls.Config
start time.Time start time.Time
timeout time.Duration servers []*http.Server
timeout time.Duration
forceHTTPS bool
} }
// New returns a new properly initialized *Service. If c is nil, svc is a nil // New returns a new properly initialized *Service. If c is nil, svc is a nil
// *Service that does nothing. // *Service that does nothing. The fields of c must not be modified after
// calling New.
func New(c *Config) (svc *Service) { func New(c *Config) (svc *Service) {
if c == nil { if c == nil {
return nil return nil
} }
svc = &Service{ svc = &Service{
tls: c.TLS, confMgr: c.ConfigManager,
start: c.Start, tls: c.TLS,
timeout: c.Timeout, start: c.Start,
timeout: c.Timeout,
forceHTTPS: c.ForceHTTPS,
} }
mux := newMux(svc) mux := newMux(svc)
for _, a := range c.Addresses { for _, a := range c.Addresses {
addr := a.String() addr := a.String()
errLog := log.StdLog("websvc: http: "+addr, log.ERROR) errLog := log.StdLog("websvc: plain http: "+addr, log.ERROR)
svc.servers = append(svc.servers, &http.Server{ svc.servers = append(svc.servers, &http.Server{
Addr: addr, Addr: addr,
Handler: mux, Handler: mux,
@ -111,6 +139,21 @@ func newMux(svc *Service) (mux *httptreemux.ContextMux) {
method: http.MethodGet, method: http.MethodGet,
path: PathHealthCheck, path: PathHealthCheck,
isJSON: false, isJSON: false,
}, {
handler: svc.handleGetSettingsAll,
method: http.MethodGet,
path: PathV1SettingsAll,
isJSON: true,
}, {
handler: svc.handlePatchSettingsDNS,
method: http.MethodPatch,
path: PathV1SettingsDNS,
isJSON: true,
}, {
handler: svc.handlePatchSettingsHTTP,
method: http.MethodPatch,
path: PathV1SettingsHTTP,
isJSON: true,
}, { }, {
handler: svc.handleGetV1SystemInfo, handler: svc.handleGetV1SystemInfo,
method: http.MethodGet, method: http.MethodGet,
@ -119,29 +162,41 @@ func newMux(svc *Service) (mux *httptreemux.ContextMux) {
}} }}
for _, r := range routes { for _, r := range routes {
var h http.HandlerFunc
if r.isJSON { if r.isJSON {
// TODO(a.garipov): Consider using httptreemux's MiddlewareFunc. mux.Handle(r.method, r.path, jsonMw(r.handler))
h = jsonMw(r.handler)
} else { } else {
h = r.handler mux.Handle(r.method, r.path, r.handler)
} }
mux.Handle(r.method, r.path, h)
} }
return mux return mux
} }
// Addrs returns all addresses on which this server serves the HTTP API. Addrs // addrs returns all addresses on which this server serves the HTTP API. addrs
// must not be called until Start returns. // must not be called simultaneously with Start. If svc was initialized with
func (svc *Service) Addrs() (addrs []string) { // ":0" addresses, addrs will not return the actual bound ports until Start is
addrs = make([]string, 0, len(svc.servers)) // finished.
func (svc *Service) addrs() (addrs, secureAddrs []netip.AddrPort) {
for _, srv := range svc.servers { for _, srv := range svc.servers {
addrs = append(addrs, srv.Addr) addrPort, err := netip.ParseAddrPort(srv.Addr)
if err != nil {
// Technically shouldn't happen, since all servers must have a valid
// address.
panic(fmt.Errorf("websvc: server %q: bad address: %w", srv.Addr, err))
}
// srv.Serve will set TLSConfig to an almost empty value, so, instead of
// relying only on the nilness of TLSConfig, check the length of the
// certificates field as well.
if srv.TLSConfig == nil || len(srv.TLSConfig.Certificates) == 0 {
addrs = append(addrs, addrPort)
} else {
secureAddrs = append(secureAddrs, addrPort)
}
} }
return addrs return addrs, secureAddrs
} }
// handleGetHealthCheck is the handler for the GET /health-check HTTP API. // handleGetHealthCheck is the handler for the GET /health-check HTTP API.
@ -149,9 +204,6 @@ func (svc *Service) handleGetHealthCheck(w http.ResponseWriter, _ *http.Request)
_, _ = io.WriteString(w, "OK") _, _ = io.WriteString(w, "OK")
} }
// unit is a convenient alias for struct{}.
type unit = struct{}
// type check // type check
var _ agh.Service = (*Service)(nil) var _ agh.Service = (*Service)(nil)
@ -163,11 +215,9 @@ func (svc *Service) Start() (err error) {
return nil return nil
} }
srvs := svc.servers
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(len(srvs)) wg.Add(len(svc.servers))
for _, srv := range srvs { for _, srv := range svc.servers {
go serve(srv, wg) go serve(srv, wg)
} }
@ -181,11 +231,14 @@ func serve(srv *http.Server, wg *sync.WaitGroup) {
addr := srv.Addr addr := srv.Addr
defer log.OnPanic(addr) defer log.OnPanic(addr)
var proto string
var l net.Listener var l net.Listener
var err error var err error
if srv.TLSConfig == nil { if srv.TLSConfig == nil {
proto = "http"
l, err = net.Listen("tcp", addr) l, err = net.Listen("tcp", addr)
} else { } else {
proto = "https"
l, err = tls.Listen("tcp", addr, srv.TLSConfig) l, err = tls.Listen("tcp", addr, srv.TLSConfig)
} }
if err != nil { if err != nil {
@ -196,8 +249,12 @@ func serve(srv *http.Server, wg *sync.WaitGroup) {
// would mean that a random available port was automatically chosen. // would mean that a random available port was automatically chosen.
srv.Addr = l.Addr().String() srv.Addr = l.Addr().String()
log.Info("websvc: starting srv http://%s", srv.Addr) log.Info("websvc: starting srv %s://%s", proto, srv.Addr)
wg.Done()
l = &waitListener{
Listener: l,
firstAcceptWG: wg,
}
err = srv.Serve(l) err = srv.Serve(l)
if err != nil && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
@ -221,8 +278,28 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
} }
if len(errs) > 0 { if len(errs) > 0 {
return errors.List("shutting down") return errors.List("shutting down", errs...)
} }
return nil return nil
} }
// Config returns the current configuration of the web service. Config must not
// be called simultaneously with Start. If svc was initialized with ":0"
// addresses, addrs will not return the actual bound ports until Start is
// finished.
func (svc *Service) Config() (c *Config) {
c = &Config{
ConfigManager: svc.confMgr,
TLS: svc.tls,
// Leave Addresses and SecureAddresses empty and get the actual
// addresses that include the :0 ones later.
Start: svc.start,
Timeout: svc.timeout,
ForceHTTPS: svc.forceHTTPS,
}
c.Addresses, c.SecureAddresses = svc.addrs()
return c
}

View File

@ -0,0 +1,6 @@
package websvc
import "time"
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second

View File

@ -0,0 +1,188 @@
package websvc_test
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/AdGuardHome/internal/next/websvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMain(m *testing.M) {
aghtest.DiscardLogOutput(m)
}
// testTimeout is the common timeout for tests.
const testTimeout = 1 * time.Second
// testStart is the server start value for tests.
var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
// type check
var _ websvc.ConfigManager = (*configManager)(nil)
// configManager is a [websvc.ConfigManager] for tests.
type configManager struct {
onDNS func() (svc agh.ServiceWithConfig[*dnssvc.Config])
onWeb func() (svc agh.ServiceWithConfig[*websvc.Config])
onUpdateDNS func(ctx context.Context, c *dnssvc.Config) (err error)
onUpdateWeb func(ctx context.Context, c *websvc.Config) (err error)
}
// DNS implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) DNS() (svc agh.ServiceWithConfig[*dnssvc.Config]) {
return m.onDNS()
}
// Web implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) Web() (svc agh.ServiceWithConfig[*websvc.Config]) {
return m.onWeb()
}
// UpdateDNS implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) UpdateDNS(ctx context.Context, c *dnssvc.Config) (err error) {
return m.onUpdateDNS(ctx, c)
}
// UpdateWeb implements the [websvc.ConfigManager] interface for *configManager.
func (m *configManager) UpdateWeb(ctx context.Context, c *websvc.Config) (err error) {
return m.onUpdateWeb(ctx, c)
}
// newConfigManager returns a *configManager all methods of which panic.
func newConfigManager() (m *configManager) {
return &configManager{
onDNS: func() (svc agh.ServiceWithConfig[*dnssvc.Config]) { panic("not implemented") },
onWeb: func() (svc agh.ServiceWithConfig[*websvc.Config]) { panic("not implemented") },
onUpdateDNS: func(_ context.Context, _ *dnssvc.Config) (err error) {
panic("not implemented")
},
onUpdateWeb: func(_ context.Context, _ *websvc.Config) (err error) {
panic("not implemented")
},
}
}
// newTestServer creates and starts a new web service instance as well as its
// sole address. It also registers a cleanup procedure, which shuts the
// instance down.
//
// TODO(a.garipov): Use svc or remove it.
func newTestServer(
t testing.TB,
confMgr websvc.ConfigManager,
) (svc *websvc.Service, addr netip.AddrPort) {
t.Helper()
c := &websvc.Config{
ConfigManager: confMgr,
TLS: nil,
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
SecureAddresses: nil,
Timeout: testTimeout,
Start: testStart,
ForceHTTPS: false,
}
svc = websvc.New(c)
err := svc.Start()
require.NoError(t, err)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
err = svc.Shutdown(ctx)
require.NoError(t, err)
})
c = svc.Config()
require.NotNil(t, c)
require.Len(t, c.Addresses, 1)
return svc, c.Addresses[0]
}
// jobj is a utility alias for JSON objects.
type jobj map[string]any
// httpGet is a helper that performs an HTTP GET request and returns the body of
// the response as well as checks that the status code is correct.
//
// TODO(a.garipov): Add helpers for other methods.
func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) {
t.Helper()
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
require.NoErrorf(t, err, "creating req")
httpCli := &http.Client{
Timeout: testTimeout,
}
resp, err := httpCli.Do(req)
require.NoErrorf(t, err, "performing req")
require.Equal(t, wantCode, resp.StatusCode)
testutil.CleanupAndRequireSuccess(t, resp.Body.Close)
body, err = io.ReadAll(resp.Body)
require.NoErrorf(t, err, "reading body")
return body
}
// httpPatch is a helper that performs an HTTP PATCH request with JSON-encoded
// reqBody as the request body and returns the body of the response as well as
// checks that the status code is correct.
//
// TODO(a.garipov): Add helpers for other methods.
func httpPatch(t testing.TB, u *url.URL, reqBody any, wantCode int) (body []byte) {
t.Helper()
b, err := json.Marshal(reqBody)
require.NoErrorf(t, err, "marshaling reqBody")
req, err := http.NewRequest(http.MethodPatch, u.String(), bytes.NewReader(b))
require.NoErrorf(t, err, "creating req")
httpCli := &http.Client{
Timeout: testTimeout,
}
resp, err := httpCli.Do(req)
require.NoErrorf(t, err, "performing req")
require.Equal(t, wantCode, resp.StatusCode)
testutil.CleanupAndRequireSuccess(t, resp.Body.Close)
body, err = io.ReadAll(resp.Body)
require.NoErrorf(t, err, "reading body")
return body
}
func TestService_Start_getHealthCheck(t *testing.T) {
confMgr := newConfigManager()
_, addr := newTestServer(t, confMgr)
u := &url.URL{
Scheme: "http",
Host: addr.String(),
Path: websvc.PathHealthCheck,
}
body := httpGet(t, u, http.StatusOK)
assert.Equal(t, []byte("OK"), body)
}

View File

@ -1,7 +1,6 @@
package querylog package querylog
import ( import (
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@ -48,24 +47,7 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
// convert log entries to JSON // convert log entries to JSON
data := l.entriesToJSON(entries, oldest) data := l.entriesToJSON(entries, oldest)
jsonVal, err := json.Marshal(data) _ = aghhttp.WriteJSONResponse(w, r, data)
if err != nil {
aghhttp.Error(
r,
w,
http.StatusInternalServerError,
"Couldn't marshal data into json: %s",
err,
)
return
}
w.Header().Set("Content-Type", "application/json")
_, err = w.Write(jsonVal)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "Unable to write response json: %s", err)
}
} }
func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) { func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
@ -74,23 +56,13 @@ func (l *queryLog) handleQueryLogClear(_ http.ResponseWriter, _ *http.Request) {
// Get configuration // Get configuration
func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) { func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
resp := qlogConfig{} resp := qlogConfig{
resp.Enabled = l.conf.Enabled Enabled: l.conf.Enabled,
resp.Interval = l.conf.RotationIvl.Hours() / 24 Interval: l.conf.RotationIvl.Hours() / 24,
resp.AnonymizeClientIP = l.conf.AnonymizeClientIP AnonymizeClientIP: l.conf.AnonymizeClientIP,
jsonVal, err := json.Marshal(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
return
} }
w.Header().Set("Content-Type", "application/json") _ = aghhttp.WriteJSONResponse(w, r, resp)
_, err = w.Write(jsonVal)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "http write: %s", err)
}
} }
// AnonymizeIP masks ip to anonymize the client if the ip is a valid one. // AnonymizeIP masks ip to anonymize the client if the ip is a valid one.

View File

@ -55,12 +55,7 @@ func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
return return
} }
w.Header().Set("Content-Type", "application/json") _ = aghhttp.WriteJSONResponse(w, r, resp)
err := json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
}
} }
// configResp is the response to the GET /control/stats_info. // configResp is the response to the GET /control/stats_info.
@ -71,13 +66,7 @@ type configResp struct {
// handleStatsInfo handles requests to the GET /control/stats_info endpoint. // handleStatsInfo handles requests to the GET /control/stats_info endpoint.
func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) { func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
resp := configResp{IntervalDays: atomic.LoadUint32(&s.limitHours) / 24} resp := configResp{IntervalDays: atomic.LoadUint32(&s.limitHours) / 24}
_ = aghhttp.WriteJSONResponse(w, r, resp)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(resp)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "json encode: %s", err)
}
} }
// handleStatsConfig handles requests to the POST /control/stats_config // handleStatsConfig handles requests to the POST /control/stats_config

View File

@ -1,33 +0,0 @@
// Package agh contains common entities and interfaces of AdGuard Home.
//
// TODO(a.garipov): Move to the upper-level internal/.
package agh
import "context"
// Service is the interface for API servers.
//
// TODO(a.garipov): Consider adding a context to Start.
//
// TODO(a.garipov): Consider adding a Wait method or making an extension
// interface for that.
type Service interface {
// Start starts the service. It does not block.
Start() (err error)
// Shutdown gracefully stops the service. ctx is used to determine
// a timeout before trying to stop the service less gracefully.
Shutdown(ctx context.Context) (err error)
}
// type check
var _ Service = EmptyService{}
// EmptyService is a Service that does nothing.
type EmptyService struct{}
// Start implements the Service interface for EmptyService.
func (EmptyService) Start() (err error) { return nil }
// Shutdown implements the Service interface for EmptyService.
func (EmptyService) Shutdown(_ context.Context) (err error) { return nil }

View File

@ -1,70 +0,0 @@
package cmd
import (
"os"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/v1/agh"
"github.com/AdguardTeam/golibs/log"
)
// signalHandler processes incoming signals and shuts services down.
type signalHandler struct {
signal chan os.Signal
// services are the services that are shut down before application
// exiting.
services []agh.Service
}
// handle processes OS signals.
func (h *signalHandler) handle() {
defer log.OnPanic("signalHandler.handle")
for sig := range h.signal {
log.Info("sighdlr: received signal %q", sig)
if aghos.IsShutdownSignal(sig) {
h.shutdown()
}
}
}
// Exit status constants.
const (
statusSuccess = 0
statusError = 1
)
// shutdown gracefully shuts down all services.
func (h *signalHandler) shutdown() {
ctx, cancel := ctxWithDefaultTimeout()
defer cancel()
status := statusSuccess
log.Info("sighdlr: shutting down services")
for i, service := range h.services {
err := service.Shutdown(ctx)
if err != nil {
log.Error("sighdlr: shutting down service at index %d: %s", i, err)
status = statusError
}
}
log.Info("sighdlr: shutting down adguard home")
os.Exit(status)
}
// newSignalHandler returns a new signalHandler that shuts down svcs.
func newSignalHandler(svcs ...agh.Service) (h *signalHandler) {
h = &signalHandler{
signal: make(chan os.Signal, 1),
services: svcs,
}
aghos.NotifyShutdownSignal(h.signal)
return h
}

View File

@ -1,61 +0,0 @@
package websvc
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"time"
"github.com/AdguardTeam/golibs/log"
)
// JSON Utilities
// jsonTime is a time.Time that can be decoded from JSON and encoded into JSON
// according to our API conventions.
type jsonTime time.Time
// type check
var _ json.Marshaler = jsonTime{}
// nsecPerMsec is the number of nanoseconds in a millisecond.
const nsecPerMsec = float64(time.Millisecond / time.Nanosecond)
// MarshalJSON implements the json.Marshaler interface for jsonTime. err is
// always nil.
func (t jsonTime) MarshalJSON() (b []byte, err error) {
msec := float64(time.Time(t).UnixNano()) / nsecPerMsec
b = strconv.AppendFloat(nil, msec, 'f', 3, 64)
return b, nil
}
// type check
var _ json.Unmarshaler = (*jsonTime)(nil)
// UnmarshalJSON implements the json.Marshaler interface for *jsonTime.
func (t *jsonTime) UnmarshalJSON(b []byte) (err error) {
if t == nil {
return fmt.Errorf("json time is nil")
}
msec, err := strconv.ParseFloat(string(b), 64)
if err != nil {
return fmt.Errorf("parsing json time: %w", err)
}
*t = jsonTime(time.Unix(0, int64(msec*nsecPerMsec)).UTC())
return nil
}
// writeJSONResponse encodes v into w and logs any errors it encounters. r is
// used to get additional information from the request.
func writeJSONResponse(w io.Writer, r *http.Request, v any) {
err := json.NewEncoder(w).Encode(v)
if err != nil {
log.Error("websvc: writing resp to %s %s: %s", r.Method, r.URL.Path, err)
}
}

View File

@ -1,8 +0,0 @@
package websvc
// Path constants
const (
PathHealthCheck = "/health-check"
PathV1SystemInfo = "/api/v1/system/info"
)

View File

@ -1,93 +0,0 @@
package websvc_test
import (
"context"
"io"
"net/http"
"net/netip"
"net/url"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/v1/websvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const testTimeout = 1 * time.Second
// testStart is the server start value for tests.
var testStart = time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)
// newTestServer creates and starts a new web service instance as well as its
// sole address. It also registers a cleanup procedure, which shuts the
// instance down.
//
// TODO(a.garipov): Use svc or remove it.
func newTestServer(t testing.TB) (svc *websvc.Service, addr string) {
t.Helper()
c := &websvc.Config{
TLS: nil,
Addresses: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:0")},
SecureAddresses: nil,
Timeout: testTimeout,
Start: testStart,
}
svc = websvc.New(c)
err := svc.Start()
require.NoError(t, err)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
t.Cleanup(cancel)
err = svc.Shutdown(ctx)
require.NoError(t, err)
})
addrs := svc.Addrs()
require.Len(t, addrs, 1)
return svc, addrs[0]
}
// httpGet is a helper that performs an HTTP GET request and returns the body of
// the response as well as checks that the status code is correct.
//
// TODO(a.garipov): Add helpers for other methods.
func httpGet(t testing.TB, u *url.URL, wantCode int) (body []byte) {
t.Helper()
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
require.NoErrorf(t, err, "creating req")
httpCli := &http.Client{
Timeout: testTimeout,
}
resp, err := httpCli.Do(req)
require.NoErrorf(t, err, "performing req")
require.Equal(t, wantCode, resp.StatusCode)
testutil.CleanupAndRequireSuccess(t, resp.Body.Close)
body, err = io.ReadAll(resp.Body)
require.NoErrorf(t, err, "reading body")
return body
}
func TestService_Start_getHealthCheck(t *testing.T) {
_, addr := newTestServer(t)
u := &url.URL{
Scheme: "http",
Host: addr,
Path: websvc.PathHealthCheck,
}
body := httpGet(t, u, http.StatusOK)
assert.Equal(t, []byte("OK"), body)
}

View File

@ -63,14 +63,6 @@ func Version() (v string) {
return version return version
} }
// Constants defining the format of module information string.
const (
modInfoAtSep = "@"
modInfoDevSep = " "
modInfoSumLeft = " (sum: "
modInfoSumRight = ")"
)
// fmtModule returns formatted information about module. The result looks like: // fmtModule returns formatted information about module. The result looks like:
// //
// github.com/Username/module@v1.2.3 (sum: someHASHSUM=) // github.com/Username/module@v1.2.3 (sum: someHASHSUM=)
@ -87,14 +79,16 @@ func fmtModule(m *debug.Module) (formatted string) {
stringutil.WriteToBuilder(b, m.Path) stringutil.WriteToBuilder(b, m.Path)
if ver := m.Version; ver != "" { if ver := m.Version; ver != "" {
sep := modInfoAtSep sep := "@"
if ver == "(devel)" { if ver == "(devel)" {
sep = modInfoDevSep sep = " "
} }
stringutil.WriteToBuilder(b, sep, ver) stringutil.WriteToBuilder(b, sep, ver)
} }
if sum := m.Sum; sum != "" { if sum := m.Sum; sum != "" {
stringutil.WriteToBuilder(b, modInfoSumLeft, sum, modInfoSumRight) stringutil.WriteToBuilder(b, "(sum: ", sum, ")")
} }
return b.String() return b.String()

View File

@ -1,5 +1,5 @@
//go:build !v1 //go:build !next
// +build !v1 // +build !next
package main package main

View File

@ -1,12 +1,12 @@
//go:build v1 //go:build next
// +build v1 // +build next
package main package main
import ( import (
"embed" "embed"
"github.com/AdguardTeam/AdGuardHome/internal/v1/cmd" "github.com/AdguardTeam/AdGuardHome/internal/next/cmd"
) )
// Embed the prebuilt client here since we strive to keep .go files inside the // Embed the prebuilt client here since we strive to keep .go files inside the

View File

@ -2289,7 +2289,7 @@
'upstream_servers': 'upstream_servers':
- '1.1.1.1' - '1.1.1.1'
- '8.8.8.8' - '8.8.8.8'
'upstream_timeout': '1s' 'upstream_timeout': 1000
'required': 'required':
- 'addresses' - 'addresses'
- 'blocking_mode' - 'blocking_mode'
@ -2397,8 +2397,9 @@
'type': 'array' 'type': 'array'
'upstream_timeout': 'upstream_timeout':
'description': > 'description': >
Upstream request timeout, as a human readable duration. Upstream request timeout, in milliseconds.
'type': 'string' 'format': 'double'
'type': 'number'
'type': 'object' 'type': 'object'
'DnsType': 'DnsType':
@ -3505,14 +3506,16 @@
'addresses': 'addresses':
- '127.0.0.1:80' - '127.0.0.1:80'
- '192.168.1.1:80' - '192.168.1.1:80'
'force_https': true
'secure_addresses': 'secure_addresses':
- '127.0.0.1:443' - '127.0.0.1:443'
- '192.168.1.1:443' - '192.168.1.1:443'
'force_https': true 'timeout': 10000
'required': 'required':
- 'addresses' - 'addresses'
- 'secure_addresses'
- 'force_https' - 'force_https'
- 'secure_addresses'
- 'timeout'
'HttpSettingsPatch': 'HttpSettingsPatch':
'description': > 'description': >
@ -3539,6 +3542,11 @@
'items': 'items':
'type': 'string' 'type': 'string'
'type': 'array' 'type': 'array'
'timeout':
'description': >
HTTP request timeout, in milliseconds.
'format': 'double'
'type': 'number'
'type': 'object' 'type': 'object'
'InternalServerErrorResp': 'InternalServerErrorResp':

View File

@ -124,11 +124,11 @@ GO111MODULE='on'
export CGO_ENABLED GO111MODULE export CGO_ENABLED GO111MODULE
# Build the new binary if requested. # Build the new binary if requested.
if [ "${V1API:-0}" -eq '0' ] if [ "${NEXTAPI:-0}" -eq '0' ]
then then
tags_flags='--tags=' tags_flags='--tags='
else else
tags_flags='--tags=v1' tags_flags='--tags=next'
fi fi
readonly tags_flags readonly tags_flags

View File

@ -136,11 +136,11 @@ underscores() {
-e '_freebsd.go'\ -e '_freebsd.go'\
-e '_linux.go'\ -e '_linux.go'\
-e '_little.go'\ -e '_little.go'\
-e '_next.go'\
-e '_openbsd.go'\ -e '_openbsd.go'\
-e '_others.go'\ -e '_others.go'\
-e '_test.go'\ -e '_test.go'\
-e '_unix.go'\ -e '_unix.go'\
-e '_v1.go'\
-e '_windows.go' \ -e '_windows.go' \
-v\ -v\
| sed -e 's/./\t\0/' | sed -e 's/./\t\0/'
@ -223,13 +223,12 @@ govulncheck ./...
# Apply more lax standards to the code we haven't properly refactored yet. # Apply more lax standards to the code we haven't properly refactored yet.
gocyclo --over 17 ./internal/querylog/ gocyclo --over 17 ./internal/querylog/
gocyclo --over 15 ./internal/home/ ./internal/dhcpd gocyclo --over 13 ./internal/dhcpd ./internal/filtering/ ./internal/home/
gocyclo --over 13 ./internal/filtering/
# Apply stricter standards to new or somewhat refactored code. # Apply stricter standards to new or somewhat refactored code.
gocyclo --over 10 ./internal/aghio/ ./internal/aghnet/ ./internal/aghos/\ gocyclo --over 10 ./internal/aghio/ ./internal/aghnet/ ./internal/aghos/\
./internal/aghtest/ ./internal/dnsforward/ ./internal/stats/\ ./internal/aghtest/ ./internal/dnsforward/ ./internal/stats/\
./internal/tools/ ./internal/updater/ ./internal/v1/ ./internal/version/\ ./internal/tools/ ./internal/updater/ ./internal/next/ ./internal/version/\
./main.go ./main.go
ineffassign ./... ineffassign ./...