diff --git a/go.mod b/go.mod index a522b6a1..992b4428 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome go 1.18 require ( - github.com/AdguardTeam/dnsproxy v0.45.3 + github.com/AdguardTeam/dnsproxy v0.46.1 github.com/AdguardTeam/golibs v0.10.9 github.com/AdguardTeam/urlfilter v0.16.0 github.com/NYTimes/gziphandler v1.1.1 @@ -18,7 +18,7 @@ require ( github.com/google/uuid v1.3.0 github.com/insomniacslk/dhcp v0.0.0-20220822114210-de18a9d48e84 github.com/kardianos/service v1.2.1 - github.com/lucas-clemente/quic-go v0.29.1 + github.com/lucas-clemente/quic-go v0.29.2 github.com/mdlayher/ethernet v0.0.0-20220221185849-529eae5b6118 github.com/mdlayher/netlink v1.6.0 // TODO(a.garipov): This package is deprecated; find a new one or use @@ -49,8 +49,8 @@ require ( github.com/golang/mock v1.6.0 // indirect github.com/josharian/native v1.0.0 // indirect github.com/marten-seemann/qpack v0.2.1 // indirect - github.com/marten-seemann/qtls-go1-18 v0.1.2 // indirect - github.com/marten-seemann/qtls-go1-19 v0.1.0 // indirect + github.com/marten-seemann/qtls-go1-18 v0.1.3 // indirect + github.com/marten-seemann/qtls-go1-19 v0.1.1 // indirect github.com/mdlayher/packet v1.0.0 // indirect github.com/mdlayher/socket v0.2.3 // indirect github.com/nxadm/tail v1.4.8 // indirect diff --git a/go.sum b/go.sum index 529e49c1..f5eff3fa 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/AdguardTeam/dnsproxy v0.45.3 h1:lvJlifDIVjHFVkVcieBhXyQA357Wl+vmLxeDlaQ8DE8= -github.com/AdguardTeam/dnsproxy v0.45.3/go.mod h1:h+0r4GDvHHY2Wu6r7oqva+O37h00KofYysfzy1TEXFE= +github.com/AdguardTeam/dnsproxy v0.46.1 h1:ej9iRorG+vekaXGYB854waAiS+q8OfswYZ1MQRZolHk= +github.com/AdguardTeam/dnsproxy v0.46.1/go.mod h1:PAmRzFqls0E92XTglyY2ESAqMAzZJhHKErG1ZpRnpjA= github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4= github.com/AdguardTeam/golibs v0.10.4/go.mod h1:rSfQRGHIdgfxriDDNgNJ7HmE5zRoURq8R+VdR81Zuzw= github.com/AdguardTeam/golibs v0.10.9 h1:F9oP2da0dQ9RQDM1lGR7LxUTfUWu8hEFOs4icwAkKM0= @@ -90,14 +90,14 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/lucas-clemente/quic-go v0.29.1 h1:Z+WMJ++qMLhvpFkRZA+jl3BTxUjm415YBmWanXB8zP0= -github.com/lucas-clemente/quic-go v0.29.1/go.mod h1:CTcNfLYJS2UuRNB+zcNlgvkjBhxX6Hm3WUxxAQx2mgE= +github.com/lucas-clemente/quic-go v0.29.2 h1:O8Mt0O6LpvEW+wfC40vZdcw0DngwYzoxq5xULZNzSI8= +github.com/lucas-clemente/quic-go v0.29.2/go.mod h1:g6/h9YMmLuU54tL1gW25uIi3VlBp3uv+sBihplIuskE= github.com/marten-seemann/qpack v0.2.1 h1:jvTsT/HpCn2UZJdP+UUB53FfUUgeOyG5K1ns0OJOGVs= github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= -github.com/marten-seemann/qtls-go1-18 v0.1.2 h1:JH6jmzbduz0ITVQ7ShevK10Av5+jBEKAHMntXmIV7kM= -github.com/marten-seemann/qtls-go1-18 v0.1.2/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= -github.com/marten-seemann/qtls-go1-19 v0.1.0 h1:rLFKD/9mp/uq1SYGYuVZhm83wkmU95pK5df3GufyYYU= -github.com/marten-seemann/qtls-go1-19 v0.1.0/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= +github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= +github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= +github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= +github.com/marten-seemann/qtls-go1-19 v0.1.1/go.mod h1:5HTDWtVudo/WFsHKRNuOhWlbdjrfs5JHrYb0wIJqGpI= github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7/go.mod h1:U6ZQobyTjI/tJyq2HG+i/dfSoFUt8/aZCM+GKtmFk/Y= github.com/mdlayher/ethernet v0.0.0-20220221185849-529eae5b6118 h1:2oDp6OOhLxQ9JBoUuysVz9UZ9uI6oLUbvAZu0x8o+vE= github.com/mdlayher/ethernet v0.0.0-20220221185849-529eae5b6118/go.mod h1:ZFUnHIVchZ9lJoWoEGUg8Q3M4U8aNNWA3CVSUTkW4og= diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index ea919889..66f0211b 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -162,6 +162,7 @@ var _ upstream.Upstream = (*UpstreamMock)(nil) type UpstreamMock struct { OnAddress func() (addr string) OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) + OnClose func() (err error) } // Address implements the [upstream.Upstream] interface for *UpstreamMock. @@ -173,3 +174,8 @@ func (u *UpstreamMock) Address() (addr string) { func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { return u.OnExchange(req) } + +// Close implements the [upstream.Upstream] interface for *UpstreamMock. +func (u *UpstreamMock) Close() (err error) { + return u.OnClose() +} diff --git a/internal/aghtest/upstream.go b/internal/aghtest/upstream.go index 699c14b9..77a2ae1d 100644 --- a/internal/aghtest/upstream.go +++ b/internal/aghtest/upstream.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" "github.com/miekg/dns" "github.com/stretchr/testify/require" @@ -31,6 +32,8 @@ type Upstream struct { Addr string } +var _ upstream.Upstream = (*Upstream)(nil) + // RespondTo returns a response with answer if req has class cl, question type // qt, and target targ. func RespondTo(t testing.TB, req *dns.Msg, cl, qt uint16, targ, answer string) (resp *dns.Msg) { @@ -68,7 +71,7 @@ func RespondTo(t testing.TB, req *dns.Msg, cl, qt uint16, targ, answer string) ( return resp } -// Exchange implements the upstream.Upstream interface for *Upstream. +// Exchange implements the [upstream.Upstream] interface for *Upstream. // // TODO(a.garipov): Split further into handlers. func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { @@ -114,11 +117,16 @@ func (u *Upstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { return resp, nil } -// Address implements upstream.Upstream interface for *Upstream. +// Address implements [upstream.Upstream] interface for *Upstream. func (u *Upstream) Address() string { return u.Addr } +// Close implements [upstream.Upstream] interface for *Upstream. +func (u *Upstream) Close() (err error) { + return nil +} + // NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that // supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is // true, hostname's actual hash is returned, blocking it. Otherwise, it returns diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 2455fff5..f31e28b4 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -518,7 +518,7 @@ func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP) } // prepareInternalProxy initializes the DNS proxy that is used for internal DNS -// queries, such at client PTR resolving and updater hostname resolving. +// queries, such as public clients PTR resolving and updater hostname resolving. func (s *Server) prepareInternalProxy() (err error) { conf := &proxy.Config{ CacheEnabled: true, @@ -558,16 +558,37 @@ func (s *Server) Stop() error { return s.stopLocked() } -// stopLocked stops the DNS server without locking. For internal use only. -func (s *Server) stopLocked() error { +// stopLocked stops the DNS server without locking. For internal use only. +func (s *Server) stopLocked() (err error) { + var errs []error + if s.dnsProxy != nil { - err := s.dnsProxy.Stop() + err = s.dnsProxy.Stop() if err != nil { - return fmt.Errorf("could not stop the DNS server properly: %w", err) + errs = append(errs, fmt.Errorf("could not stop primary resolvers properly: %w", err)) } } - s.isRunning = false + if s.internalProxy != nil && s.internalProxy.UpstreamConfig != nil { + err = s.internalProxy.UpstreamConfig.Close() + if err != nil { + errs = append(errs, fmt.Errorf("could not stop internal resolvers properly: %w", err)) + } + } + + if s.localResolvers != nil && s.localResolvers.UpstreamConfig != nil { + err = s.localResolvers.UpstreamConfig.Close() + if err != nil { + errs = append(errs, fmt.Errorf("could not stop local resolvers properly: %w", err)) + } + } + + if len(errs) > 0 { + return errors.List("stopping DNS server", errs...) + } else { + s.isRunning = false + } + return nil } diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 91e31b70..97749c9c 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -603,6 +603,7 @@ func checkDNS( if err != nil { return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err) } + defer func() { err = errors.WithDeferred(err, u.Close()) }() if err = healthCheck(u); err != nil { err = fmt.Errorf("upstream %q fails to exchange: %w", upstreamAddr, err) diff --git a/internal/home/clients.go b/internal/home/clients.go index ee535ea9..ef1a11a7 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -21,6 +21,8 @@ import ( "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" ) const clientsUpdatePeriod = 10 * time.Minute @@ -50,6 +52,18 @@ type Client struct { UseOwnBlockedServices bool } +// closeUpstreams closes the client-specific upstream config of c if any. +func (c *Client) closeUpstreams() (err error) { + if c.upstreamConfig != nil { + err = c.upstreamConfig.Close() + if err != nil { + return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err) + } + } + + return nil +} + type clientSource uint // Client sources. The order determines the priority. @@ -651,6 +665,10 @@ func (clients *clientsContainer) Del(name string) (ok bool) { return false } + if err := c.closeUpstreams(); err != nil { + log.Error("client container: removing client %s: %s", name, err) + } + // update Name index delete(clients.list, name) @@ -709,7 +727,7 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) { } } - // update ID index + // Update ID index. for _, id := range prev.IDs { delete(clients.idIndex, id) } @@ -718,14 +736,17 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) { } } - // update Name index + // Update name index. if prev.Name != c.Name { delete(clients.list, prev.Name) clients.list[c.Name] = prev } - // update upstreams cache - c.upstreamConfig = nil + // Update upstreams cache. + err = c.closeUpstreams() + if err != nil { + return err + } *prev = *c @@ -910,3 +931,24 @@ func (clients *clientsContainer) updateFromDHCP(add bool) { log.Debug("clients: added %d client aliases from dhcp", n) } + +// Close gracefully closes all the client-specific upstream configurations of +// the persistent clients. +func (clients *clientsContainer) Close() (err error) { + persistent := maps.Values(clients.list) + slices.SortFunc(persistent, func(a, b *Client) (less bool) { return a.Name < b.Name }) + + var errs []error + + for _, cli := range persistent { + if err = cli.closeUpstreams(); err != nil { + errs = append(errs, err) + } + } + + if len(errs) > 0 { + return errors.List("closing client specific upstreams", errs...) + } + + return nil +} diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 59821d0a..4e4d22f2 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -179,6 +179,7 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http. if !clients.Del(cj.Name) { aghhttp.Error(r, w, http.StatusBadRequest, "Client not found") + return } diff --git a/internal/home/dns.go b/internal/home/dns.go index 6c0d6531..b5367b7d 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -431,17 +431,23 @@ func reconfigureDNSServer() (err error) { return nil } -func stopDNSServer() error { +func stopDNSServer() (err error) { if !isRunning() { return nil } - err := Context.dnsServer.Stop() + err = Context.dnsServer.Stop() if err != nil { - return fmt.Errorf("couldn't stop forwarding DNS server: %w", err) + return fmt.Errorf("stopping forwarding dns server: %w", err) + } + + err = Context.clients.Close() + if err != nil { + return fmt.Errorf("closing clients container: %w", err) } closeDNSServer() + return nil } diff --git a/internal/home/home.go b/internal/home/home.go index 5d3582e9..6dbe1600 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -122,7 +122,6 @@ func Main(clientBuildFS fs.FS) { case syscall.SIGHUP: Context.clients.Reload() Context.tls.reload() - default: cleanup(context.Background()) cleanupAlways() diff --git a/internal/next/dnssvc/dnssvc_test.go b/internal/next/dnssvc/dnssvc_test.go index 8205897c..814e3bad 100644 --- a/internal/next/dnssvc/dnssvc_test.go +++ b/internal/next/dnssvc/dnssvc_test.go @@ -9,6 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/errors" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,6 +26,8 @@ func TestService(t *testing.T) { const ( bootstrapAddr = "bootstrap.example" upstreamAddr = "upstream.example" + + closeErr errors.Error = "closing failed" ) ups := &aghtest.UpstreamMock{ @@ -36,6 +39,9 @@ func TestService(t *testing.T) { return resp, nil }, + OnClose: func() (err error) { + return closeErr + }, } c := &dnssvc.Config{ @@ -85,5 +91,5 @@ func TestService(t *testing.T) { defer cancel() err = svc.Shutdown(ctx) - require.NoError(t, err) + require.ErrorIs(t, err, closeErr) }