Pull request 2193: AGDNS-1982 Upd proxy

Closes #6854.Updates #6875.

Squashed commit of the following:

commit b98adbc0cc
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Apr 10 19:21:44 2024 +0300

    dnsforward: upd proxy, imp code, docs

commit 4de1eb2bca
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Apr 10 16:09:58 2024 +0300

    WIP

commit afa9d61e8d
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 9 19:24:09 2024 +0300

    all: log changes

commit c8340676a4
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 9 19:06:10 2024 +0300

    dnsforward: move code

commit 08bb7d43d2
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 9 18:09:46 2024 +0300

    dnsforward: imp code

commit b27547ec80
Merge: b7efca788 6f36ebc06
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 9 17:33:19 2024 +0300

    Merge branch 'master' into AGDNS-1982-upd-proxy

commit b7efca788b
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Tue Apr 9 17:27:14 2024 +0300

    all: upd proxy finally

commit 3e16fa87be
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Fri Apr 5 18:20:13 2024 +0300

    dnsforward: upd proxy

commit f3cdfc8633
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Apr 4 20:37:32 2024 +0300

    all: upd proxy, golibs

commit a79298d6d0
Merge: 9feeba5c7 fd25dcacb
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Apr 4 20:34:01 2024 +0300

    Merge branch 'master' into AGDNS-1982-upd-proxy

commit 9feeba5c7f
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Apr 4 20:25:57 2024 +0300

    all: imp code, docs

commit 6c68d463db
Merge: d8108e651 ee619b2db
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Thu Apr 4 18:46:11 2024 +0300

    Merge branch 'master' into AGDNS-1982-upd-proxy

commit d8108e6516
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Apr 3 19:25:27 2024 +0300

    all: imp code

commit 2046156580
Author: Eugene Burkov <E.Burkov@AdGuard.COM>
Date:   Wed Apr 3 17:10:33 2024 +0300

    all: remove private rdns logic
This commit is contained in:
Eugene Burkov 2024-04-11 14:03:37 +03:00
parent 6f36ebc06c
commit ff7c715c5f
19 changed files with 765 additions and 1365 deletions

View File

@ -27,7 +27,17 @@ NOTE: Add new changes BELOW THIS COMMENT.
- Support for comments in the ipset file ([#5345]). - Support for comments in the ipset file ([#5345]).
### Fixed
- Subdomains of `in-addr.arpa` and `ip6.arpa` containing zero-length prefix
incorrectly considered invalid when specified for private RDNS upstream
servers ([#6854]).
- Unspecified IP addresses aren't checked when using "Fastest IP address" mode
([#6875]).
[#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345 [#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345
[#6854]: https://github.com/AdguardTeam/AdGuardHome/issues/6854
[#6875]: https://github.com/AdguardTeam/AdGuardHome/issues/6875
<!-- <!--
NOTE: Add new changes ABOVE THIS COMMENT. NOTE: Add new changes ABOVE THIS COMMENT.

10
go.mod
View File

@ -3,8 +3,8 @@ module github.com/AdguardTeam/AdGuardHome
go 1.22.2 go 1.22.2
require ( require (
github.com/AdguardTeam/dnsproxy v0.67.0 github.com/AdguardTeam/dnsproxy v0.69.1
github.com/AdguardTeam/golibs v0.21.0 github.com/AdguardTeam/golibs v0.23.0
github.com/AdguardTeam/urlfilter v0.18.0 github.com/AdguardTeam/urlfilter v0.18.0
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.7 github.com/ameshkov/dnscrypt/v2 v2.2.7
@ -28,12 +28,12 @@ require (
// own code for that. Perhaps, use gopacket. // own code for that. Perhaps, use gopacket.
github.com/mdlayher/raw v0.1.0 github.com/mdlayher/raw v0.1.0
github.com/miekg/dns v1.1.58 github.com/miekg/dns v1.1.58
github.com/quic-go/quic-go v0.41.0 github.com/quic-go/quic-go v0.42.0
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.9.0
github.com/ti-mo/netfilter v0.5.1 github.com/ti-mo/netfilter v0.5.1
go.etcd.io/bbolt v1.3.9 go.etcd.io/bbolt v1.3.9
golang.org/x/crypto v0.21.0 golang.org/x/crypto v0.21.0
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8
golang.org/x/net v0.23.0 golang.org/x/net v0.23.0
golang.org/x/sys v0.18.0 golang.org/x/sys v0.18.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/natefinch/lumberjack.v2 v2.2.1

26
go.sum
View File

@ -1,7 +1,7 @@
github.com/AdguardTeam/dnsproxy v0.67.0 h1:7oKfcA8sm9d1N4qvhsNmQWBX4+fs3sX4cAnERmBXEbw= github.com/AdguardTeam/dnsproxy v0.69.1 h1:KiLkKUSrvHeUO/YEf4Bbo/5zyFRIvQstjL7W9G/24pk=
github.com/AdguardTeam/dnsproxy v0.67.0/go.mod h1:XLfD6IpSplUZZ+f5vhWSJW1mp4wm+KkHWiMo9w7U1Ls= github.com/AdguardTeam/dnsproxy v0.69.1/go.mod h1:atO3WeeuyepyhjSt6hC+MF7/IN7TZHfG3/ZwhImHzYs=
github.com/AdguardTeam/golibs v0.21.0 h1:0swWyNaHTmT7aMwffKd9d54g4wBd8Oaj0fl+5l/PRdE= github.com/AdguardTeam/golibs v0.23.0 h1:PHz/QhJhLmoaOokkqrPFUgu9Hw4iVAqLtBP0O3g1D3Q=
github.com/AdguardTeam/golibs v0.21.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI= github.com/AdguardTeam/golibs v0.23.0/go.mod h1:/xZCf6gZZzz7k1qaoJmI+hhxN98kHFr7LJ22j1nLH0c=
github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ= github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ=
github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s= github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s=
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I= github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
@ -101,19 +101,19 @@ github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A=
github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k= github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA= github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4= github.com/shirou/gopsutil/v3 v3.23.7 h1:C+fHO8hfIppoJ1WdsVm1RoI0RwXoNdfTK7yWXV0wVj4=
github.com/shirou/gopsutil/v3 v3.23.7/go.mod h1:c4gnmoRC0hQuaLqvxnx1//VXQ0Ms/X9UnJF8pddY5z4= github.com/shirou/gopsutil/v3 v3.23.7/go.mod h1:c4gnmoRC0hQuaLqvxnx1//VXQ0Ms/X9UnJF8pddY5z4=
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU= github.com/ti-mo/netfilter v0.2.0/go.mod h1:8GbBGsY/8fxtyIdfwy29JiluNcPK4K7wIT+x42ipqUU=
github.com/ti-mo/netfilter v0.5.1 h1:cqamEd1c1zmpfpqvInLOro0Znq/RAfw2QL5wL2rAR/8= github.com/ti-mo/netfilter v0.5.1 h1:cqamEd1c1zmpfpqvInLOro0Znq/RAfw2QL5wL2rAR/8=
github.com/ti-mo/netfilter v0.5.1/go.mod h1:h9UPQ3ZrTZGBitay+LETMxZvNgWGK/efTUcqES2YiLw= github.com/ti-mo/netfilter v0.5.1/go.mod h1:h9UPQ3ZrTZGBitay+LETMxZvNgWGK/efTUcqES2YiLw=
@ -133,8 +133,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8 h1:aAcj0Da7eBAtrTp03QXWvm88pSyOt+UgdZw2BFZ+lEw=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/exp v0.0.0-20240325151524-a685a6edb6d8/go.mod h1:CQ1k9gNrJ50XIzaKCRR2hssIjF07kZFEiieALBM/ARQ=
golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic=
@ -168,6 +168,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw=

View File

@ -0,0 +1,113 @@
package dnsforward
import (
"encoding/binary"
"fmt"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// type check
var _ proxy.BeforeRequestHandler = (*Server)(nil)
// HandleBefore is the handler that is called before any other processing,
// including logs. It performs access checks and puts the client ID, if there
// is one, into the server's cache.
//
// TODO(e.burkov): Write tests.
func (s *Server) HandleBefore(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
return s.preBlockedResponse(pctx)
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return nil
}
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}
// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil
}
hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return "", nil
}
cliSrvName, err := clientServerName(pctx, proto)
if err != nil {
return "", err
}
clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
}
// errAccessBlocked is a sentinel error returned when a request is blocked by
// access settings.
var errAccessBlocked errors.Error = "blocked by access settings"
// preBlockedResponse returns a protocol-appropriate response for a request that
// was blocked by access settings.
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (err error) {
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
// Return nil so that dnsproxy drops the connection and thus
// prevent DNS amplification attacks.
return errAccessBlocked
}
return &proxy.BeforeRequestError{
Err: errAccessBlocked,
Response: s.makeResponseREFUSED(pctx.Req),
}
}

View File

@ -110,46 +110,6 @@ type quicConnection interface {
ConnectionState() (cs quic.ConnectionState) ConnectionState() (cs quic.ConnectionState)
} }
// clientIDFromDNSContext extracts the client's ID from the server name of the
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
// is not one of these, clientID is an empty string and err is nil.
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
proto := pctx.Proto
if proto == proxy.ProtoHTTPS {
clientID, err = clientIDFromDNSContextHTTPS(pctx)
if err != nil {
return "", fmt.Errorf("checking url: %w", err)
} else if clientID != "" {
return clientID, nil
}
// Go on and check the domain name as well.
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
return "", nil
}
hostSrvName := s.conf.ServerName
if hostSrvName == "" {
return "", nil
}
cliSrvName, err := clientServerName(pctx, proto)
if err != nil {
return "", err
}
clientID, err = clientIDFromClientServerName(
hostSrvName,
cliSrvName,
s.conf.StrictSNICheck,
)
if err != nil {
return "", fmt.Errorf("clientid check: %w", err)
}
return clientID, nil
}
// clientServerName returns the TLS server name based on the protocol. For // clientServerName returns the TLS server name based on the protocol. For
// DNS-over-HTTPS requests, it will return the hostname part of the Host header // DNS-over-HTTPS requests, it will return the hostname part of the Host header
// if there is one. // if there is one.

View File

@ -235,9 +235,18 @@ type DNSCryptConfig struct {
// ServerConfig represents server configuration. // ServerConfig represents server configuration.
// The zero ServerConfig is empty and ready for use. // The zero ServerConfig is empty and ready for use.
type ServerConfig struct { type ServerConfig struct {
UDPListenAddrs []*net.UDPAddr // UDP listen address // UDPListenAddrs is the list of addresses to listen for DNS-over-UDP.
TCPListenAddrs []*net.TCPAddr // TCP listen address UDPListenAddrs []*net.UDPAddr
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
// TCPListenAddrs is the list of addresses to listen for DNS-over-TCP.
TCPListenAddrs []*net.TCPAddr
// UpstreamConfig is the general configuration of upstream DNS servers.
UpstreamConfig *proxy.UpstreamConfig
// PrivateRDNSUpstreamConfig is the configuration of upstream DNS servers
// for private reverse DNS.
PrivateRDNSUpstreamConfig *proxy.UpstreamConfig
// AddrProcConf defines the configuration for the client IP processor. // AddrProcConf defines the configuration for the client IP processor.
// If nil, [client.EmptyAddrProc] is used. // If nil, [client.EmptyAddrProc] is used.
@ -306,24 +315,28 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies) trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
conf = &proxy.Config{ conf = &proxy.Config{
HTTP3: srvConf.ServeHTTP3, HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit), Ratelimit: int(srvConf.Ratelimit),
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4, RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6, RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
RatelimitWhitelist: srvConf.RatelimitWhitelist, RatelimitWhitelist: srvConf.RatelimitWhitelist,
RefuseAny: srvConf.RefuseAny, RefuseAny: srvConf.RefuseAny,
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes), TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
CacheMinTTL: srvConf.CacheMinTTL, CacheMinTTL: srvConf.CacheMinTTL,
CacheMaxTTL: srvConf.CacheMaxTTL, CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic, CacheOptimistic: srvConf.CacheOptimistic,
UpstreamConfig: srvConf.UpstreamConfig, UpstreamConfig: srvConf.UpstreamConfig,
BeforeRequestHandler: s.beforeRequestHandler, PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
RequestHandler: s.handleDNSRequest, BeforeRequestHandler: s,
HTTPSServerName: aghhttp.UserAgent(), RequestHandler: s.handleDNSRequest,
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled, HTTPSServerName: aghhttp.UserAgent(),
MaxGoroutines: srvConf.MaxGoroutines, EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
UseDNS64: srvConf.UseDNS64, MaxGoroutines: srvConf.MaxGoroutines,
DNS64Prefs: srvConf.DNS64Prefixes, UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS,
PrivateSubnets: s.privateNets,
MessageConstructor: s,
} }
if srvConf.EDNSClientSubnet.UseCustom { if srvConf.EDNSClientSubnet.UseCustom {
@ -459,6 +472,26 @@ func (s *Server) prepareIpsetListSettings() (err error) {
return s.ipset.init(ipsets) return s.ipset.init(ipsets)
} }
// loadUpstreams parses upstream DNS servers from the configured file or from
// the configuration itself.
func (conf *ServerConfig) loadUpstreams() (upstreams []string, err error) {
if conf.UpstreamDNSFileName == "" {
return stringutil.FilterOut(conf.UpstreamDNS, IsCommentOrEmpty), nil
}
var data []byte
data, err = os.ReadFile(conf.UpstreamDNSFileName)
if err != nil {
return nil, fmt.Errorf("reading upstream from file: %w", err)
}
upstreams = stringutil.SplitTrimmed(string(data), "\n")
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), conf.UpstreamDNSFileName)
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil
}
// collectListenAddr adds addrPort to addrs. It also adds its port to // collectListenAddr adds addrPort to addrs. It also adds its port to
// unspecPorts if its address is unspecified. // unspecPorts if its address is unspecified.
func collectListenAddr( func collectListenAddr(
@ -530,8 +563,8 @@ func (m *combinedAddrPortSet) Has(addrPort netip.AddrPort) (ok bool) {
return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr()) return m.ports.Has(addrPort.Port()) && m.addrs.Has(addrPort.Addr())
} }
// filterOut filters out all the upstreams that match um. It returns all the // filterOutAddrs filters out all the upstreams that match um. It returns all
// closing errors joined. // the closing errors joined.
func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) { func filterOutAddrs(upsConf *proxy.UpstreamConfig, set addrPortSet) (err error) {
var errs []error var errs []error
delFunc := func(u upstream.Upstream) (ok bool) { delFunc := func(u upstream.Upstream) (ok bool) {

View File

@ -3,7 +3,6 @@ package dnsforward
import ( import (
"net" "net"
"testing" "testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -11,6 +10,7 @@ import (
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -64,6 +64,8 @@ func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns
} }
func TestServer_HandleDNSRequest_dns64(t *testing.T) { func TestServer_HandleDNSRequest_dns64(t *testing.T) {
t.Parallel()
const ( const (
ipv4Domain = "ipv4.only." ipv4Domain = "ipv4.only."
ipv6Domain = "ipv6.only." ipv6Domain = "ipv6.only."
@ -252,32 +254,33 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
require.Len(pt, m.Question, 1) require.Len(pt, m.Question, 1)
require.Equal(pt, m.Question[0].Name, ptr64Domain) require.Equal(pt, m.Question[0].Name, ptr64Domain)
resp := (&dns.Msg{
Answer: []dns.RR{localRR}, resp := (&dns.Msg{}).SetReply(m)
}).SetReply(m) resp.Answer = []dns.RR{localRR}
require.NoError(t, w.WriteMsg(resp)) require.NoError(t, w.WriteMsg(resp))
}) })
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
client := &dns.Client{ client := &dns.Client{
Net: "tcp", Net: string(proxy.ProtoTCP),
Timeout: 1 * time.Second, Timeout: testTimeout,
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel()
upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0] q := req.Question[0]
require.Contains(pt, tc.upsAns, q.Qtype)
require.Contains(pt, tc.upsAns, q.Qtype)
answer := tc.upsAns[q.Qtype] answer := tc.upsAns[q.Qtype]
resp := (&dns.Msg{ resp := (&dns.Msg{}).SetReply(req)
Answer: answer[sectionAnswer], resp.Answer = answer[sectionAnswer]
Ns: answer[sectionAuthority], resp.Ns = answer[sectionAuthority]
Extra: answer[sectionAdditional], resp.Extra = answer[sectionAdditional]
}).SetReply(req)
require.NoError(pt, w.WriteMsg(resp)) require.NoError(pt, w.WriteMsg(resp))
}) })
@ -307,10 +310,54 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype) req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
resp, _, excErr := client.Exchange(req, s.dnsProxy.Addr(proxy.ProtoTCP).String()) resp, _, excErr := client.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
require.NoError(t, excErr) require.NoError(t, excErr)
require.Equal(t, tc.wantAns, resp.Answer) require.Equal(t, tc.wantAns, resp.Answer)
}) })
} }
} }
func TestServer_dns64WithDisabledRDNS(t *testing.T) {
t.Parallel()
// Shouldn't go to upstream at all.
panicHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
panic("not implemented")
})
upsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
localUpsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
UpstreamDNS: []string{upsAddr},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
})
startDeferStop(t, s)
mappedIPv6 := net.ParseIP("64:ff9b::102:304")
arpa, err := netutil.IPToReversedAddr(mappedIPv6)
require.NoError(t, err)
req := (&dns.Msg{}).SetQuestion(dns.Fqdn(arpa), dns.TypePTR)
cli := &dns.Client{
Net: string(proxy.ProtoTCP),
Timeout: testTimeout,
}
resp, _, err := cli.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
require.NoError(t, err)
assert.Equal(t, dns.RcodeNameError, resp.Rcode)
}

View File

@ -135,12 +135,6 @@ type Server struct {
// WHOIS, etc. // WHOIS, etc.
addrProc client.AddressProcessor addrProc client.AddressProcessor
// localResolvers is a DNS proxy instance used to resolve PTR records for
// addresses considered private as per the [privateNets].
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
localResolvers *proxy.Proxy
// sysResolvers used to fetch system resolvers to use by default for private // sysResolvers used to fetch system resolvers to use by default for private
// PTR resolving. // PTR resolving.
sysResolvers SystemResolvers sysResolvers SystemResolvers
@ -158,12 +152,6 @@ type Server struct {
// [upstream.Resolver] interface. // [upstream.Resolver] interface.
bootResolvers []*upstream.UpstreamResolver bootResolvers []*upstream.UpstreamResolver
// recDetector is a cache for recursive requests. It is used to detect and
// prevent recursive requests only for private upstreams.
//
// See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135.
recDetector *recursionDetector
// dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major // dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major
// part of DNS64 happens inside the [proxy] package, but there still are // part of DNS64 happens inside the [proxy] package, but there still are
// some places where response mapping is needed (e.g. DHCP). // some places where response mapping is needed (e.g. DHCP).
@ -212,14 +200,6 @@ type DNSCreateParams struct {
LocalDomain string LocalDomain string
} }
const (
// recursionTTL is the time recursive request is cached for.
recursionTTL = 1 * time.Second
// cachedRecurrentReqNum is the maximum number of cached recurrent
// requests.
cachedRecurrentReqNum = 1000
)
// NewServer creates a new instance of the dnsforward.Server // NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once // Note: this function must be called only once
// //
@ -256,7 +236,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// TODO(e.burkov): Use some case-insensitive string comparison. // TODO(e.burkov): Use some case-insensitive string comparison.
localDomainSuffix: strings.ToLower(localDomainSuffix), localDomainSuffix: strings.ToLower(localDomainSuffix),
etcHosts: etcHosts, etcHosts: etcHosts,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{ clientIDCache: cache.New(cache.Config{
EnableLRU: true, EnableLRU: true,
MaxCount: defaultClientIDCacheCount, MaxCount: defaultClientIDCacheCount,
@ -366,6 +345,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
// TODO(e.burkov): Migrate to [netip.Addr] already.
arpa, err := netutil.IPToReversedAddr(ip.AsSlice()) arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil { if err != nil {
return "", 0, fmt.Errorf("reversing ip: %w", err) return "", 0, fmt.Errorf("reversing ip: %w", err)
@ -386,25 +366,23 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
} }
dctx := &proxy.DNSContext{ dctx := &proxy.DNSContext{
Proto: "udp", Proto: proxy.ProtoUDP,
Req: req, Req: req,
IsPrivateClient: true,
} }
var resolver *proxy.Proxy
var errMsg string var errMsg string
if s.privateNets.Contains(ip) { if s.privateNets.Contains(ip) {
if !s.conf.UsePrivateRDNS { if !s.conf.UsePrivateRDNS {
return "", 0, nil return "", 0, nil
} }
resolver = s.localResolvers
errMsg = "resolving a private address: %w" errMsg = "resolving a private address: %w"
s.recDetector.add(*req) dctx.RequestedPrivateRDNS = netip.PrefixFrom(ip, ip.BitLen())
} else { } else {
resolver = s.internalProxy
errMsg = "resolving an address: %w" errMsg = "resolving an address: %w"
} }
if err = resolver.Resolve(dctx); err != nil { if err = s.internalProxy.Resolve(dctx); err != nil {
return "", 0, fmt.Errorf(errMsg, err) return "", 0, fmt.Errorf(errMsg, err)
} }
@ -473,103 +451,6 @@ func (s *Server) startLocked() error {
return err return err
} }
// prepareLocalResolvers initializes the local upstreams configuration using
// boot as bootstrap. It assumes that s.serverLock is locked or s not running.
func (s *Server) prepareLocalResolvers(
boot upstream.Resolver,
) (uc *proxy.UpstreamConfig, err error) {
set, err := s.conf.ourAddrsSet()
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
resolvers := s.conf.LocalPTRResolvers
confNeedsFiltering := len(resolvers) > 0
if confNeedsFiltering {
resolvers = stringutil.FilterOut(resolvers, IsCommentOrEmpty)
} else {
sysResolvers := slices.DeleteFunc(slices.Clone(s.sysResolvers.Addrs()), set.Has)
resolvers = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
resolvers = append(resolvers, r.String())
}
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", resolvers)
uc, err = s.prepareUpstreamConfig(resolvers, nil, &upstream.Options{
Bootstrap: boot,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
})
if err != nil {
return nil, fmt.Errorf("preparing private upstreams: %w", err)
}
if confNeedsFiltering {
err = filterOutAddrs(uc, set)
if err != nil {
return nil, fmt.Errorf("filtering private upstreams: %w", err)
}
}
return uc, nil
}
// LocalResolversError is an error type for errors during local resolvers setup.
// This is only needed to distinguish these errors from errors returned by
// creating the proxy.
type LocalResolversError struct {
Err error
}
// type check
var _ error = (*LocalResolversError)(nil)
// Error implements the error interface for *LocalResolversError.
func (err *LocalResolversError) Error() (s string) {
return fmt.Sprintf("creating local resolvers: %s", err.Err)
}
// type check
var _ errors.Wrapper = (*LocalResolversError)(nil)
// Unwrap implements the [errors.Wrapper] interface for *LocalResolversError.
func (err *LocalResolversError) Unwrap() error {
return err.Err
}
// setupLocalResolvers initializes and sets the resolvers for local addresses.
// It assumes s.serverLock is locked or s not running. It returns the upstream
// configuration used for private PTR resolving, or nil if it's disabled. Note,
// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
return nil, nil
}
uc, err = s.prepareLocalResolvers(boot)
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
localResolvers, err := proxy.New(&proxy.Config{
UpstreamConfig: uc,
})
if err != nil {
return nil, &LocalResolversError{Err: err}
}
s.localResolvers = localResolvers
// TODO(e.burkov): Should we also consider the DNS64 usage?
return uc, nil
}
// Prepare initializes parameters of s using data from conf. conf must not be // Prepare initializes parameters of s using data from conf. conf must not be
// nil. // nil.
func (s *Server) Prepare(conf *ServerConfig) (err error) { func (s *Server) Prepare(conf *ServerConfig) (err error) {
@ -586,7 +467,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.initDefaultSettings() s.initDefaultSettings()
boot, err := s.prepareInternalDNS() err = s.prepareInternalDNS()
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return err return err
@ -608,12 +489,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return fmt.Errorf("preparing access: %w", err) return fmt.Errorf("preparing access: %w", err)
} }
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
if err != nil {
return fmt.Errorf("setting up resolvers: %w", err)
}
proxyConfig.Fallbacks, err = s.setupFallbackDNS() proxyConfig.Fallbacks, err = s.setupFallbackDNS()
if err != nil { if err != nil {
return fmt.Errorf("setting up fallback dns servers: %w", err) return fmt.Errorf("setting up fallback dns servers: %w", err)
@ -626,8 +501,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.dnsProxy = dnsProxy s.dnsProxy = dnsProxy
s.recDetector.clear()
s.setupAddrProc() s.setupAddrProc()
s.registerHandlers() s.registerHandlers()
@ -635,36 +508,127 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return nil return nil
} }
// prepareInternalDNS initializes the internal state of s before initializing // prepareUpstreamSettings sets upstream DNS server settings.
// the primary DNS proxy instance. It assumes s.serverLock is locked or the func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) {
// Server not running. // Load upstreams either from the file, or from the settings
func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) { var upstreams []string
err = s.prepareIpsetListSettings() upstreams, err = s.conf.loadUpstreams()
if err != nil { if err != nil {
return nil, fmt.Errorf("preparing ipset settings: %w", err) return fmt.Errorf("loading upstreams: %w", err)
} }
s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{ uc, err := newUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Timeout: DefaultTimeout, Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams), HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
// Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're
// loading roots manually and pass it here.
//
// See [aghtls.SystemRootCAs].
//
// TODO(a.garipov): Investigate if that's true.
RootCAs: s.conf.TLSv12Roots,
CipherSuites: s.conf.TLSCiphers,
}) })
if err != nil {
return fmt.Errorf("preparing upstream config: %w", err)
}
s.conf.UpstreamConfig = uc
return nil
}
// PrivateRDNSError is returned when the private rDNS upstreams are
// invalid but enabled.
//
// TODO(e.burkov): Consider allowing to use incomplete private rDNS upstreams
// configuration in proxy when the private rDNS function is enabled. In theory,
// proxy supports the case when no upstreams provided to resolve the private
// request, since it already supports this for DNS64-prefixed PTR requests.
type PrivateRDNSError struct {
err error
}
// Error implements the [errors.Error] interface.
func (e *PrivateRDNSError) Error() (s string) {
return e.err.Error()
}
func (e *PrivateRDNSError) Unwrap() (err error) {
return e.err
}
// prepareLocalResolvers initializes the private RDNS upstream configuration
// according to the server's settings. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
return nil, nil
}
var ownAddrs addrPortSet
ownAddrs, err = s.conf.ourAddrsSet()
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return nil, err return nil, err
} }
opts := &upstream.Options{
Bootstrap: s.bootstrap,
Timeout: defaultLocalTimeout,
// TODO(e.burkov): Should we verify server's certificates?
PreferIPv6: s.conf.BootstrapPreferIPv6,
}
addrs := s.conf.LocalPTRResolvers
uc, err = newPrivateConfig(addrs, ownAddrs, s.sysResolvers, s.privateNets, opts)
if err != nil {
return nil, fmt.Errorf("preparing resolvers: %w", err)
}
return uc, nil
}
// prepareInternalDNS initializes the internal state of s before initializing
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareInternalDNS() (err error) {
err = s.prepareIpsetListSettings()
if err != nil {
return fmt.Errorf("preparing ipset settings: %w", err)
}
bootOpts := &upstream.Options{
Timeout: DefaultTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
}
s.bootstrap, s.bootResolvers, err = newBootstrap(s.conf.BootstrapDNS, s.etcHosts, bootOpts)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
err = s.prepareUpstreamSettings(s.bootstrap) err = s.prepareUpstreamSettings(s.bootstrap)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return s.bootstrap, err return err
}
s.conf.PrivateRDNSUpstreamConfig, err = s.prepareLocalResolvers()
if err != nil {
return err
} }
err = s.prepareInternalProxy() err = s.prepareInternalProxy()
if err != nil { if err != nil {
return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err) return fmt.Errorf("preparing internal proxy: %w", err)
} }
return s.bootstrap, nil return nil
} }
// setupFallbackDNS initializes the fallback DNS servers. // setupFallbackDNS initializes the fallback DNS servers.
@ -743,10 +707,16 @@ func validateBlockingMode(
func (s *Server) prepareInternalProxy() (err error) { func (s *Server) prepareInternalProxy() (err error) {
srvConf := s.conf srvConf := s.conf
conf := &proxy.Config{ conf := &proxy.Config{
CacheEnabled: true, CacheEnabled: true,
CacheSizeBytes: 4096, CacheSizeBytes: 4096,
UpstreamConfig: srvConf.UpstreamConfig, PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
MaxGoroutines: s.conf.MaxGoroutines, UpstreamConfig: srvConf.UpstreamConfig,
MaxGoroutines: srvConf.MaxGoroutines,
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS,
PrivateSubnets: s.privateNets,
MessageConstructor: s,
} }
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration) err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
@ -782,11 +752,6 @@ func (s *Server) stopLocked() (err error) {
} }
} }
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
if s.localResolvers != nil {
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
}
for _, b := range s.bootResolvers { for _, b := range s.bootResolvers {
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address()) logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
} }

View File

@ -143,7 +143,7 @@ func (s *Server) filterDNSRewrite(
res *filtering.Result, res *filtering.Result,
pctx *proxy.DNSContext, pctx *proxy.DNSContext,
) (err error) { ) (err error) {
resp := s.makeResponse(req) resp := s.replyCompressed(req)
dnsrr := res.DNSRewriteResult dnsrr := res.DNSRewriteResult
if dnsrr == nil { if dnsrr == nil {
return errors.Error("no dns rewrite rule content") return errors.Error("no dns rewrite rule content")

View File

@ -1,57 +1,17 @@
package dnsforward package dnsforward
import ( import (
"encoding/binary"
"fmt" "fmt"
"net" "net"
"slices" "slices"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// beforeRequestHandler is the handler that is called before any other
// processing, including logs. It performs access checks and puts the client
// ID, if there is one, into the server's cache.
func (s *Server) beforeRequestHandler(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (reply bool, err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return false, fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
return s.preBlockedResponse(pctx)
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return true, nil
}
// clientRequestFilteringSettings looks up client filtering settings using the // clientRequestFilteringSettings looks up client filtering settings using the
// client's IP address and ID, if any, from dctx. // client's IP address and ID, if any, from dctx.
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) { func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {

View File

@ -261,55 +261,17 @@ func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
} }
} }
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
}
// checkFallbacks returns an error if any fallback address is invalid.
func (req *jsonDNSConfig) checkFallbacks() (err error) {
if req.Fallbacks == nil {
return nil
}
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
if err != nil {
return fmt.Errorf("fallback servers: %w", err)
}
return nil
}
// validate returns an error if any field of req is invalid. // validate returns an error if any field of req is invalid.
// //
// TODO(s.chzhen): Parse, don't validate. // TODO(s.chzhen): Parse, don't validate.
func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) { func (req *jsonDNSConfig) validate(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
defer func() { err = errors.Annotate(err, "validating dns config: %w") }() defer func() { err = errors.Annotate(err, "validating dns config: %w") }()
err = req.validateUpstreamDNSServers(privateNets) err = req.validateUpstreamDNSServers(ownAddrs, sysResolvers, privateNets)
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. // Don't wrap the error since it's informative enough as is.
return err return err
@ -342,17 +304,54 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
return nil return nil
} }
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
}
// validateUpstreamDNSServers returns an error if any field of req is invalid. // validateUpstreamDNSServers returns an error if any field of req is invalid.
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) { func (req *jsonDNSConfig) validateUpstreamDNSServers(
ownAddrs addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
) (err error) {
var uc *proxy.UpstreamConfig
opts := &upstream.Options{}
if req.Upstreams != nil { if req.Upstreams != nil {
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{}) uc, err = proxy.ParseUpstreamsConfig(*req.Upstreams, opts)
err = errors.WithDeferred(err, uc.Close())
if err != nil { if err != nil {
return fmt.Errorf("upstream servers: %w", err) return fmt.Errorf("upstream servers: %w", err)
} }
} }
if req.LocalPTRUpstreams != nil { if addrs := req.LocalPTRUpstreams; addrs != nil {
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets) uc, err = newPrivateConfig(*addrs, ownAddrs, sysResolvers, privateNets, opts)
err = errors.WithDeferred(err, uc.Close())
if err != nil { if err != nil {
return fmt.Errorf("private upstream servers: %w", err) return fmt.Errorf("private upstream servers: %w", err)
} }
@ -364,10 +363,12 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS
return err return err
} }
err = req.checkFallbacks() if req.Fallbacks != nil {
if err != nil { uc, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, opts)
// Don't wrap the error since it's informative enough as is. err = errors.WithDeferred(err, uc.Close())
return err if err != nil {
return fmt.Errorf("fallback servers: %w", err)
}
} }
return nil return nil
@ -436,7 +437,16 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
return return
} }
err = req.validate(s.privateNets) // TODO(e.burkov): Consider prebuilding this set on startup.
ourAddrs, err := s.conf.ourAddrsSet()
if err != nil {
// TODO(e.burkov): !! Put into openapi
aghhttp.Error(r, w, http.StatusInternalServerError, "getting our addresses: %s", err)
return
}
err = req.validate(ourAddrs, s.sysResolvers, s.privateNets)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@ -587,7 +597,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
} }
var boots []*upstream.UpstreamResolver var boots []*upstream.UpstreamResolver
opts.Bootstrap, boots, err = s.createBootstrap(req.BootstrapDNS, opts) opts.Bootstrap, boots, err = newBootstrap(req.BootstrapDNS, s.etcHosts, opts)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err) aghhttp.Error(r, w, http.StatusBadRequest, "Failed to parse bootstrap servers: %s", err)

View File

@ -245,9 +245,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "", wantSet: "",
}, { }, {
name: "local_ptr_upstreams_bad", name: "local_ptr_upstreams_bad",
wantSet: `validating dns config: ` + wantSet: `validating dns config: private upstream servers: ` +
`private upstream servers: checking domain-specific upstreams: ` + `bad arpa domain name "non.arpa": not a reversed ip network`,
`bad arpa domain name "non.arpa.": not a reversed ip network`,
}, { }, {
name: "local_ptr_upstreams_null", name: "local_ptr_upstreams_null",
wantSet: "", wantSet: "",
@ -318,58 +317,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
} }
} }
func TestValidateUpstreamsPrivate(t *testing.T) {
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct {
name string
wantErr string
u string
}{{
name: "success_address",
wantErr: ``,
u: "[/1.0.0.127.in-addr.arpa/]#",
}, {
name: "success_subnet",
wantErr: ``,
u: "[/127.in-addr.arpa/]#",
}, {
name: "not_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`bad arpa domain name "hello.world.": not a reversed ip network`,
u: "[/hello.world/]#",
}, {
name: "non-private_arpa_address",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`,
u: "[/1.2.3.4.in-addr.arpa/]#",
}, {
name: "non-private_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "128.in-addr.arpa." should point to a locally-served network`,
u: "[/128.in-addr.arpa/]#",
}, {
name: "several_bad",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network` + "\n" +
`bad arpa domain name "non.arpa.": not a reversed ip network`,
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
}, {
name: "partial_good",
wantErr: "",
u: "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#",
}}
for _, tc := range testCases {
set := []string{"192.168.0.1", tc.u}
t.Run(tc.name, func(t *testing.T) {
err := ValidateUpstreamsPrivate(set, ss)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}
}
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) { func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
t.Helper() t.Helper()

View File

@ -11,17 +11,21 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// makeResponse creates a DNS response by req and sets necessary flags. It also // TODO(e.burkov): Name all the methods by a [proxy.MessageConstructor]
// guarantees that req.Question will be not empty. // template. Also extract all the methods to a separate entity.
func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
resp = &dns.Msg{
MsgHdr: dns.MsgHdr{
RecursionAvailable: true,
},
Compress: true,
}
resp.SetReply(req) // reply creates a DNS response for req.
func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, code)
resp.RecursionAvailable = true
return resp
}
// replyCompressed creates a DNS response for req and sets the compress flag.
func (s *Server) replyCompressed(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeSuccess)
resp.Compress = true
return resp return resp
} }
@ -51,7 +55,7 @@ func (s *Server) genDNSFilterMessage(
if qt != dns.TypeA && qt != dns.TypeAAAA { if qt != dns.TypeA && qt != dns.TypeAAAA {
m, _, _ := s.dnsFilter.BlockingMode() m, _, _ := s.dnsFilter.BlockingMode()
if m == filtering.BlockingModeNullIP { if m == filtering.BlockingModeNullIP {
return s.makeResponse(req) return s.replyCompressed(req)
} }
return s.newMsgNODATA(req) return s.newMsgNODATA(req)
@ -75,7 +79,7 @@ func (s *Server) genDNSFilterMessage(
// getCNAMEWithIPs generates a filtered response to req for with CNAME record // getCNAMEWithIPs generates a filtered response to req for with CNAME record
// and provided ips. // and provided ips.
func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) { func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) {
resp = s.makeResponse(req) resp = s.replyCompressed(req)
originalName := req.Question[0].Name originalName := req.Question[0].Name
@ -121,13 +125,13 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
case filtering.BlockingModeNullIP: case filtering.BlockingModeNullIP:
return s.makeResponseNullIP(req) return s.makeResponseNullIP(req)
case filtering.BlockingModeNXDOMAIN: case filtering.BlockingModeNXDOMAIN:
return s.genNXDomain(req) return s.NewMsgNXDOMAIN(req)
case filtering.BlockingModeREFUSED: case filtering.BlockingModeREFUSED:
return s.makeResponseREFUSED(req) return s.makeResponseREFUSED(req)
default: default:
log.Error("dnsforward: invalid blocking mode %q", mode) log.Error("dnsforward: invalid blocking mode %q", mode)
return s.makeResponse(req) return s.replyCompressed(req)
} }
} }
@ -148,25 +152,18 @@ func (s *Server) makeResponseCustomIP(
// genDNSFilterMessage. // genDNSFilterMessage.
log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt)) log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
return s.makeResponse(req) return s.replyCompressed(req)
} }
} }
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeServerFailure)
resp.RecursionAvailable = true
return &resp
}
func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg { func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request) resp := s.replyCompressed(request)
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip)) resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
return resp return resp
} }
func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg { func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request) resp := s.replyCompressed(request)
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip)) resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
return resp return resp
} }
@ -252,7 +249,7 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
// Go on and return an empty response. // Go on and return an empty response.
} }
resp = s.makeResponse(req) resp = s.replyCompressed(req)
resp.Answer = ans resp.Answer = ans
return resp return resp
@ -288,7 +285,7 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
case dns.TypeAAAA: case dns.TypeAAAA:
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()}) resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
default: default:
resp = s.makeResponse(req) resp = s.replyCompressed(req)
} }
return resp return resp
@ -298,7 +295,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
if newAddr == "" { if newAddr == "" {
log.Info("dnsforward: block host is not specified") log.Info("dnsforward: block host is not specified")
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
ip, err := netip.ParseAddr(newAddr) ip, err := netip.ParseAddr(newAddr)
@ -321,17 +318,17 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
if prx == nil { if prx == nil {
log.Debug("dnsforward: %s", srvClosedErr) log.Debug("dnsforward: %s", srvClosedErr)
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
err = prx.Resolve(newContext) err = prx.Resolve(newContext)
if err != nil { if err != nil {
log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err) log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err)
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
resp := s.makeResponse(request) resp := s.replyCompressed(request)
if newContext.Res != nil { if newContext.Res != nil {
for _, answer := range newContext.Res.Answer { for _, answer := range newContext.Res.Answer {
answer.Header().Name = request.Question[0].Name answer.Header().Name = request.Question[0].Name
@ -342,47 +339,21 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
return resp return resp
} }
// preBlockedResponse returns a protocol-appropriate response for a request that
// was blocked by access settings.
func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (reply bool, err error) {
if pctx.Proto == proxy.ProtoUDP || pctx.Proto == proxy.ProtoDNSCrypt {
// Return nil so that dnsproxy drops the connection and thus
// prevent DNS amplification attacks.
return false, nil
}
pctx.Res = s.makeResponseREFUSED(pctx.Req)
return true, nil
}
// Create REFUSED DNS response // Create REFUSED DNS response
func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg { func (s *Server) makeResponseREFUSED(req *dns.Msg) *dns.Msg {
resp := dns.Msg{} return s.reply(req, dns.RcodeRefused)
resp.SetRcode(request, dns.RcodeRefused)
resp.RecursionAvailable = true
return &resp
} }
// newMsgNODATA returns a properly initialized NODATA response. // newMsgNODATA returns a properly initialized NODATA response.
// //
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2. // See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) { func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess) resp = s.reply(req, dns.RcodeSuccess)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(req) resp.Ns = s.genSOA(req)
return resp return resp
} }
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNameError)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(request)
return &resp
}
func (s *Server) genSOA(request *dns.Msg) []dns.RR { func (s *Server) genSOA(request *dns.Msg) []dns.RR {
zone := "" zone := ""
if len(request.Question) > 0 { if len(request.Question) > 0 {
@ -414,5 +385,43 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
if len(zone) > 0 && zone[0] != '.' { if len(zone) > 0 && zone[0] != '.' {
soa.Mbox += zone soa.Mbox += zone
} }
return []dns.RR{&soa} return []dns.RR{&soa}
} }
// type check
var _ proxy.MessageConstructor = (*Server)(nil)
// NewMsgNXDOMAIN implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeNameError)
resp.Ns = s.genSOA(req)
return resp
}
// NewMsgSERVFAIL implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
return s.reply(req, dns.RcodeServerFailure)
}
// NewMsgNOTIMPLEMENTED implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeNotImplemented)
// Most of the Internet and especially the inner core has an MTU of at least
// 1500 octets. Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet
// is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)).
//
// See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17.
const maxUDPPayload = 1452
// NOTIMPLEMENTED without EDNS is treated as 'we don't support EDNS', so
// explicitly set it.
resp.SetEdns0(maxUDPPayload, false)
return resp
}

View File

@ -1,20 +1,17 @@
package dnsforward package dnsforward
import ( import (
"cmp"
"encoding/binary" "encoding/binary"
"net" "net"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
@ -34,11 +31,6 @@ type dnsContext struct {
// response is modified by filters. // response is modified by filters.
origResp *dns.Msg origResp *dns.Msg
// unreversedReqIP stores an IP address obtained from a PTR request if it
// was parsed successfully and belongs to one of the locally served IP
// ranges.
unreversedReqIP netip.Addr
// err is the error returned from a processing function. // err is the error returned from a processing function.
err error err error
@ -63,10 +55,6 @@ type dnsContext struct {
// responseAD shows if the response had the AD bit set. // responseAD shows if the response had the AD bit set.
responseAD bool responseAD bool
// isLocalClient shows if client's IP address is from locally served
// network.
isLocalClient bool
// isDHCPHost is true if the request for a local domain name and the DHCP is // isDHCPHost is true if the request for a local domain name and the DHCP is
// available for this request. // available for this request.
isDHCPHost bool isDHCPHost bool
@ -109,15 +97,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
// (*proxy.Proxy).handleDNSRequest method performs it before calling the // (*proxy.Proxy).handleDNSRequest method performs it before calling the
// appropriate handler. // appropriate handler.
mods := []modProcessFunc{ mods := []modProcessFunc{
s.processRecursion,
s.processInitial, s.processInitial,
s.processDDRQuery, s.processDDRQuery,
s.processDetermineLocal,
s.processDHCPHosts, s.processDHCPHosts,
s.processRestrictLocal,
s.processDHCPAddrs, s.processDHCPAddrs,
s.processFilteringBeforeRequest, s.processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream, s.processUpstream,
s.processFilteringAfterResponse, s.processFilteringAfterResponse,
s.ipset.process, s.ipset.process,
@ -145,24 +129,6 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
return nil return nil
} }
// processRecursion checks the incoming request and halts its handling by
// answering NXDOMAIN if s has tried to resolve it recently.
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing recursion")
defer log.Debug("dnsforward: finished processing recursion")
pctx := dctx.proxyCtx
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name)
pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish
}
return resultCodeSuccess
}
// mozillaFQDN is the domain used to signal the Firefox browser to not use its // mozillaFQDN is the domain used to signal the Firefox browser to not use its
// own DoH server. // own DoH server.
// //
@ -199,14 +165,14 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
} }
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN { if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
pctx.Res = s.genNXDomain(pctx.Req) pctx.Res = s.NewMsgNXDOMAIN(pctx.Req)
return resultCodeFinish return resultCodeFinish
} }
if q.Name == healthcheckFQDN { if q.Name == healthcheckFQDN {
// Generate a NODATA negative response to make nslookup exit with 0. // Generate a NODATA negative response to make nslookup exit with 0.
pctx.Res = s.makeResponse(pctx.Req) pctx.Res = s.replyCompressed(pctx.Req)
return resultCodeFinish return resultCodeFinish
} }
@ -272,7 +238,7 @@ func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
// //
// [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html. // [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) { func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
resp = s.makeResponse(req) resp = s.replyCompressed(req)
if req.Question[0].Qtype != dns.TypeSVCB { if req.Question[0].Qtype != dns.TypeSVCB {
return resp return resp
} }
@ -339,19 +305,6 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
return resp return resp
} }
// processDetermineLocal determines if the client's IP address is from locally
// served network and saves the result into the context.
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local detection")
defer log.Debug("dnsforward: finished processing local detection")
rc = resultCodeSuccess
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr())
return rc
}
// processDHCPHosts respond to A requests if the target hostname is known to // processDHCPHosts respond to A requests if the target hostname is known to
// the server. It responds with a mapped IP address if the DNS64 is enabled and // the server. It responds with a mapped IP address if the DNS64 is enabled and
// the request is for AAAA. // the request is for AAAA.
@ -370,9 +323,9 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
if !dctx.isLocalClient { if !pctx.IsPrivateClient {
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost) log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
pctx.Res = s.genNXDomain(req) pctx.Res = s.NewMsgNXDOMAIN(req)
// Do not even put into query log. // Do not even put into query log.
return resultCodeFinish return resultCodeFinish
@ -389,7 +342,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip) log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip)
resp := s.makeResponse(req) resp := s.replyCompressed(req)
switch q.Qtype { switch q.Qtype {
case dns.TypeA: case dns.TypeA:
a := &dns.A{ a := &dns.A{
@ -416,141 +369,6 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
// indexFirstV4Label returns the index at which the reversed IPv4 address
// starts, assuming the domain is pre-validated ARPA domain having in-addr and
// arpa labels removed.
func indexFirstV4Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv4len && idx > 0; labelsNum++ {
curIdx := strings.LastIndexByte(domain[:idx-1], '.') + 1
_, parseErr := strconv.ParseUint(domain[curIdx:idx-1], 10, 8)
if parseErr != nil {
return idx
}
idx = curIdx
}
return idx
}
// indexFirstV6Label returns the index at which the reversed IPv6 address
// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa
// labels removed.
func indexFirstV6Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv6len*2 && idx > 0; labelsNum++ {
curIdx := idx - len("a.")
if curIdx > 1 && domain[curIdx-1] != '.' {
return idx
}
nibble := domain[curIdx]
if (nibble < '0' || nibble > '9') && (nibble < 'a' || nibble > 'f') {
return idx
}
idx = curIdx
}
return idx
}
// extractARPASubnet tries to convert a reversed ARPA address being a part of
// domain to an IP network. domain must be an FQDN.
//
// TODO(e.burkov): Move to golibs.
func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
err = netutil.ValidateDomainName(strings.TrimSuffix(domain, "."))
if err != nil {
// Don't wrap the error since it's informative enough as is.
return netip.Prefix{}, err
}
const (
v4Suffix = "in-addr.arpa."
v6Suffix = "ip6.arpa."
)
domain = strings.ToLower(domain)
var idx int
switch {
case strings.HasSuffix(domain, v4Suffix):
idx = indexFirstV4Label(domain[:len(domain)-len(v4Suffix)])
case strings.HasSuffix(domain, v6Suffix):
idx = indexFirstV6Label(domain[:len(domain)-len(v6Suffix)])
default:
return netip.Prefix{}, &netutil.AddrError{
Err: netutil.ErrNotAReversedSubnet,
Kind: netutil.AddrKindARPA,
Addr: domain,
}
}
return netutil.PrefixFromReversedAddr(domain[idx:])
}
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
// in locally served network from external clients.
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local restriction")
defer log.Debug("dnsforward: finished processing local restriction")
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
if q.Qtype != dns.TypePTR {
// No need for restriction.
return resultCodeSuccess
}
subnet, err := extractARPASubnet(q.Name)
if err != nil {
if errors.Is(err, netutil.ErrNotAReversedSubnet) {
log.Debug("dnsforward: request is not for arpa domain")
return resultCodeSuccess
}
log.Debug("dnsforward: parsing reversed addr: %s", err)
return resultCodeError
}
// Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally served or at least
// shouldn't be accessible externally.
subnetAddr := subnet.Addr()
if !s.privateNets.Contains(subnetAddr) {
return resultCodeSuccess
}
log.Debug("dnsforward: addr %s is from locally served network", subnetAddr)
if !dctx.isLocalClient {
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
pctx.Res = s.genNXDomain(req)
// Do not even put into query log.
return resultCodeFinish
}
// Do not perform unreversing ever again.
dctx.unreversedReqIP = subnetAddr
// There is no need to filter request from external addresses since this
// code is only executed when the request is for locally served ARPA
// hostname so disable redundant filters.
dctx.setts.ParentalEnabled = false
dctx.setts.SafeBrowsingEnabled = false
dctx.setts.SafeSearchEnabled = false
dctx.setts.ServicesRules = nil
// Nothing to restrict.
return resultCodeSuccess
}
// processDHCPAddrs responds to PTR requests if the target IP is leased by the // processDHCPAddrs responds to PTR requests if the target IP is leased by the
// DHCP server. // DHCP server.
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
@ -562,20 +380,21 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
ipAddr := dctx.unreversedReqIP pref := pctx.RequestedPrivateRDNS
if ipAddr == (netip.Addr{}) { if pref == (netip.Prefix{}) {
return resultCodeSuccess return resultCodeSuccess
} }
host := s.dhcpServer.HostByIP(ipAddr) addr := pref.Addr()
host := s.dhcpServer.HostByIP(addr)
if host == "" { if host == "" {
return resultCodeSuccess return resultCodeSuccess
} }
log.Debug("dnsforward: dhcp client %s is %q", ipAddr, host) log.Debug("dnsforward: dhcp client %s is %q", addr, host)
req := pctx.Req req := pctx.Req
resp := s.makeResponse(req) resp := s.replyCompressed(req)
ptr := &dns.PTR{ ptr := &dns.PTR{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: req.Question[0].Name, Name: req.Question[0].Name,
@ -593,62 +412,20 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
// processLocalPTR responds to PTR requests if the target IP is detected to be
// inside the local network and the query was not answered from DHCP.
func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local ptr")
defer log.Debug("dnsforward: finished processing local ptr")
pctx := dctx.proxyCtx
if pctx.Res != nil {
return resultCodeSuccess
}
ip := dctx.unreversedReqIP
if ip == (netip.Addr{}) {
return resultCodeSuccess
}
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if s.conf.UsePrivateRDNS {
s.recDetector.add(*pctx.Req)
if err := s.localResolvers.Resolve(pctx); err != nil {
log.Debug("dnsforward: resolving private address: %s", err)
// Generate the server failure if the private upstream configuration
// is empty.
//
// This is a crutch, see TODO at [Server.localResolvers].
if errors.Is(err, upstream.ErrNoUpstreams) {
pctx.Res = s.genServerFailure(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
dctx.err = err
return resultCodeError
}
}
if pctx.Res == nil {
pctx.Res = s.genNXDomain(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
return resultCodeSuccess
}
// Apply filtering logic // Apply filtering logic
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) { func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing filtering before req") log.Debug("dnsforward: started processing filtering before req")
defer log.Debug("dnsforward: finished processing filtering before req") defer log.Debug("dnsforward: finished processing filtering before req")
if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) {
// There is no need to filter request for locally served ARPA hostname
// so disable redundant filters.
dctx.setts.ParentalEnabled = false
dctx.setts.SafeBrowsingEnabled = false
dctx.setts.SafeSearchEnabled = false
dctx.setts.ServicesRules = nil
}
if dctx.proxyCtx.Res != nil { if dctx.proxyCtx.Res != nil {
// Go on since the response is already set. // Go on since the response is already set.
return resultCodeSuccess return resultCodeSuccess
@ -695,7 +472,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
// local domain name if there is one. // local domain name if there is one.
name := req.Question[0].Name name := req.Question[0].Name
log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1]) log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1])
pctx.Res = s.genNXDomain(req) pctx.Res = s.NewMsgNXDOMAIN(req)
return resultCodeFinish return resultCodeFinish
} }
@ -712,21 +489,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
return resultCodeError return resultCodeError
} }
if err := prx.Resolve(pctx); err != nil { if dctx.err = prx.Resolve(pctx); dctx.err != nil {
if errors.Is(err, upstream.ErrNoUpstreams) {
// Do not even put into querylog. Currently this happens either
// when the private resolvers enabled and the request is DNS64 PTR,
// or when the client isn't considered local by prx.
//
// TODO(e.burkov): Make proxy detect local client the same way as
// AGH does.
pctx.Res = s.genNXDomain(req)
return resultCodeFinish
}
dctx.err = err
return resultCodeError return resultCodeError
} }
@ -810,7 +573,7 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
} }
// Use the ClientID first, since it has a higher priority. // Use the ClientID first, since it has a higher priority.
id := stringutil.Coalesce(clientID, pctx.Addr.Addr().String()) id := cmp.Or(clientID, pctx.Addr.Addr().String())
upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap) upsConf, err := s.conf.ClientsContainer.UpstreamConfigByID(id, s.bootstrap)
if err != nil { if err != nil {
log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err) log.Error("dnsforward: getting custom upstreams for client %s: %s", id, err)

View File

@ -9,6 +9,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
@ -375,44 +376,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
return f return f
} }
func TestServer_ProcessDetermineLocal(t *testing.T) {
s := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliAddr netip.AddrPort
}{{
want: assert.True,
name: "local",
cliAddr: netip.MustParseAddrPort("192.168.0.1:1"),
}, {
want: assert.False,
name: "external",
cliAddr: netip.MustParseAddrPort("250.249.0.1:1"),
}, {
want: assert.False,
name: "invalid",
cliAddr: netip.AddrPort{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
proxyCtx := &proxy.DNSContext{
Addr: tc.cliAddr,
}
dctx := &dnsContext{
proxyCtx: proxyCtx,
}
s.processDetermineLocal(dctx)
tc.want(t, dctx.isLocalClient)
})
}
}
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
const ( const (
localDomainSuffix = "lan" localDomainSuffix = "lan"
@ -482,9 +445,9 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
dctx := &dnsContext{ dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{ proxyCtx: &proxy.DNSContext{
Req: req, Req: req,
IsPrivateClient: tc.isLocalCli,
}, },
isLocalClient: tc.isLocalCli,
} }
res := s.processDHCPHosts(dctx) res := s.processDHCPHosts(dctx)
@ -617,9 +580,9 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
dctx := &dnsContext{ dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{ proxyCtx: &proxy.DNSContext{
Req: req, Req: req,
IsPrivateClient: true,
}, },
isLocalClient: true,
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -654,19 +617,28 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
} }
} }
func TestServer_ProcessRestrictLocal(t *testing.T) { // TODO(e.burkov): Rewrite this test to use the whole server instead of just
// testing the [handleDNSRequest] method. See comment on
// "from_external_for_local" test case.
func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
intAddr := netip.MustParseAddr("192.168.1.1")
intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice())
require.NoError(t, err)
extAddr := netip.MustParseAddr("254.253.252.1")
extPTRQuestion, err := netutil.IPToReversedAddr(extAddr.AsSlice())
require.NoError(t, err)
const ( const (
extPTRQuestion = "251.252.253.254.in-addr.arpa." extPTRAnswer = "host1.example.net."
extPTRAnswer = "host1.example.net." intPTRAnswer = "some.local-client."
intPTRQuestion = "1.1.168.192.in-addr.arpa."
intPTRAnswer = "some.local-client."
) )
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := cmp.Or( resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), (&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
) )
require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
@ -692,123 +664,165 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
startDeferStop(t, s) startDeferStop(t, s)
testCases := []struct { testCases := []struct {
name string name string
want string question string
question net.IP wantErr error
cliAddr netip.AddrPort wantAns []dns.RR
wantLen int isPrivate bool
}{{ }{{
name: "from_local_to_external", name: "from_local_for_external",
want: "host1.example.net.", question: extPTRQuestion,
question: net.IP{254, 253, 252, 251}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("192.168.10.10:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(extPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(extPTRAnswer) + 1),
},
Ptr: dns.Fqdn(extPTRAnswer),
}},
isPrivate: true,
}, { }, {
name: "from_external_for_local", // In theory this case is not reproducible because [proxy.Proxy] should
want: "", // respond to such queries with NXDOMAIN before they reach
question: net.IP{192, 168, 1, 1}, // [Server.handleDNSRequest].
cliAddr: netip.MustParseAddrPort("254.253.252.251:1"), name: "from_external_for_local",
wantLen: 0, question: intPTRQuestion,
wantErr: upstream.ErrNoUpstreams,
wantAns: nil,
isPrivate: false,
}, { }, {
name: "from_local_for_local", name: "from_local_for_local",
want: "some.local-client.", question: intPTRQuestion,
question: net.IP{192, 168, 1, 1}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("192.168.1.2:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(intPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(intPTRAnswer) + 1),
},
Ptr: dns.Fqdn(intPTRAnswer),
}},
isPrivate: true,
}, { }, {
name: "from_external_for_external", name: "from_external_for_external",
want: "host1.example.net.", question: extPTRQuestion,
question: net.IP{254, 253, 252, 251}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("254.253.252.255:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(extPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(extPTRAnswer) + 1),
},
Ptr: dns.Fqdn(extPTRAnswer),
}},
isPrivate: false,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
reqAddr, err := dns.ReverseAddr(tc.question.String()) pref, extErr := netutil.ExtractReversedAddr(tc.question)
require.NoError(t, err) require.NoError(t, extErr)
req := createTestMessageWithType(reqAddr, dns.TypePTR)
req := createTestMessageWithType(dns.Fqdn(tc.question), dns.TypePTR)
pctx := &proxy.DNSContext{ pctx := &proxy.DNSContext{
Proto: proxy.ProtoTCP, Req: req,
Req: req, IsPrivateClient: tc.isPrivate,
Addr: tc.cliAddr, }
// TODO(e.burkov): Configure the subnet set properly.
if netutil.IsLocallyServed(pref.Addr()) {
pctx.RequestedPrivateRDNS = pref
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err = s.handleDNSRequest(nil, pctx) err = s.handleDNSRequest(s.dnsProxy, pctx)
require.NoError(t, err) require.ErrorIs(t, err, tc.wantErr)
require.NotNil(t, pctx.Res)
require.Len(t, pctx.Res.Answer, tc.wantLen)
if tc.wantLen > 0 { require.NotNil(t, pctx.Res)
assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr) assert.Equal(t, tc.wantAns, pctx.Res.Answer)
}
}) })
} }
} }
func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { func TestServer_ProcessUpstream_localPTR(t *testing.T) {
const locDomain = "some.local." const locDomain = "some.local."
const reqAddr = "1.1.168.192.in-addr.arpa." const reqAddr = "1.1.168.192.in-addr.arpa."
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := cmp.Or( resp := cmp.Or(
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain), aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), (&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
) )
require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
}) })
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
s := createTestServer( newPrxCtx := func() (prxCtx *proxy.DNSContext) {
t, return &proxy.DNSContext{
&filtering.Config{ Addr: testClientAddrPort,
BlockingMode: filtering.BlockingModeDefault, Req: createTestMessageWithType(reqAddr, dns.TypePTR),
}, IsPrivateClient: true,
ServerConfig{ RequestedPrivateRDNS: netip.MustParsePrefix("192.168.1.1/32"),
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
var proxyCtx *proxy.DNSContext
var dnsCtx *dnsContext
setup := func(use bool) {
proxyCtx = &proxy.DNSContext{
Addr: testClientAddrPort,
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
} }
dnsCtx = &dnsContext{
proxyCtx: proxyCtx,
unreversedReqIP: netip.MustParseAddr("192.168.1.1"),
}
s.conf.UsePrivateRDNS = use
} }
t.Run("enabled", func(t *testing.T) { t.Run("enabled", func(t *testing.T) {
setup(true) s := createTestServer(
t,
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
pctx := newPrxCtx()
rc := s.processLocalPTR(dnsCtx) rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeSuccess, rc) require.Equal(t, resultCodeSuccess, rc)
require.NotEmpty(t, proxyCtx.Res.Answer) require.NotEmpty(t, pctx.Res.Answer)
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr) assert.Equal(t, locDomain, ptr.Ptr)
}) })
t.Run("disabled", func(t *testing.T) { t.Run("disabled", func(t *testing.T) {
setup(false) s := createTestServer(
t,
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
pctx := newPrxCtx()
rc := s.processLocalPTR(dnsCtx) rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeFinish, rc) require.Equal(t, resultCodeError, rc)
require.Empty(t, proxyCtx.Res.Answer) require.Empty(t, pctx.Res.Answer)
}) })
} }
@ -826,129 +840,3 @@ func TestIPStringFromAddr(t *testing.T) {
assert.Empty(t, ipStringFromAddr(nil)) assert.Empty(t, ipStringFromAddr(nil))
}) })
} }
// TODO(e.burkov): Add fuzzing when moving to golibs.
func TestExtractARPASubnet(t *testing.T) {
const (
v4Suf = `in-addr.arpa.`
v4Part = `2.1.` + v4Suf
v4Whole = `4.3.` + v4Part
v6Suf = `ip6.arpa.`
v6Part = `4.3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Suf
v6Whole = `f.e.d.c.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Part
)
v4Pref := netip.MustParsePrefix("1.2.3.4/32")
v4PrefPart := netip.MustParsePrefix("1.2.0.0/16")
v6Pref := netip.MustParsePrefix("::1234:0:0:0:cdef/128")
v6PrefPart := netip.MustParsePrefix("0:0:0:1234::/64")
testCases := []struct {
want netip.Prefix
name string
domain string
wantErr string
}{{
want: netip.Prefix{},
name: "not_an_arpa",
domain: "some.domain.name.",
wantErr: `bad arpa domain name "some.domain.name.": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "bad_domain_name",
domain: "abc.123.",
wantErr: `bad domain name "abc.123": ` +
`bad top-level domain name label "123": all octets are numeric`,
}, {
want: v4Pref,
name: "whole_v4",
domain: v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4",
domain: v4Part,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_within_domain",
domain: "a." + v4Whole,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_additional_label",
domain: "5." + v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4_within_domain",
domain: "a." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4",
domain: "256." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4_within_domain",
domain: "a.256." + v4Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v4",
domain: v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v4_within_domain",
domain: "a." + v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: v6Pref,
name: "whole_v6",
domain: v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6",
domain: v6Part,
}, {
want: v6Pref,
name: "whole_v6_within_domain",
domain: "g." + v6Whole,
wantErr: "",
}, {
want: v6Pref,
name: "whole_v6_additional_label",
domain: "1." + v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6_within_domain",
domain: "label." + v6Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v6",
domain: v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v6_within_domain",
domain: "g." + v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
subnet, err := extractARPASubnet(tc.domain)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.want, subnet)
})
}
}

View File

@ -1,115 +0,0 @@
package dnsforward
import (
"bytes"
"encoding/binary"
"time"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// uint* sizes in bytes to improve readability.
//
// TODO(e.burkov): Remove when there will be a more regardful way to define
// those. See https://github.com/golang/go/issues/29982.
const (
uint16sz = 2
uint64sz = 8
)
// recursionDetector detects recursion in DNS forwarding.
type recursionDetector struct {
recentRequests cache.Cache
ttl time.Duration
}
// check checks if the passed req was already sent by the server.
func (rd *recursionDetector) check(msg dns.Msg) (ok bool) {
if len(msg.Question) == 0 {
return false
}
key := msgToSignature(msg)
expireData := rd.recentRequests.Get(key)
if expireData == nil {
return false
}
expire := time.Unix(0, int64(binary.BigEndian.Uint64(expireData)))
return time.Now().Before(expire)
}
// add caches the msg if it has anything in the questions section.
func (rd *recursionDetector) add(msg dns.Msg) {
now := time.Now()
if len(msg.Question) == 0 {
return
}
key := msgToSignature(msg)
expire64 := uint64(now.Add(rd.ttl).UnixNano())
expire := make([]byte, uint64sz)
binary.BigEndian.PutUint64(expire, expire64)
rd.recentRequests.Set(key, expire)
}
// clear clears the recent requests cache.
func (rd *recursionDetector) clear() {
rd.recentRequests.Clear()
}
// newRecursionDetector returns the initialized *recursionDetector.
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
return &recursionDetector{
recentRequests: cache.New(cache.Config{
EnableLRU: true,
MaxCount: suspectsNum,
}),
ttl: ttl,
}
}
// msgToSignature converts msg into it's signature represented in bytes.
func msgToSignature(msg dns.Msg) (sig []byte) {
sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen)
// The binary.BigEndian byte order is used everywhere except when the real
// machine's endianness is needed.
byteOrder := binary.BigEndian
byteOrder.PutUint16(sig[0:], msg.Id)
q := msg.Question[0]
byteOrder.PutUint16(sig[uint16sz:], q.Qtype)
copy(sig[2*uint16sz:], []byte(q.Name))
return sig
}
// msgToSignatureSlow converts msg into it's signature represented in bytes in
// the less efficient way.
//
// See BenchmarkMsgToSignature.
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
type msgSignature struct {
name [netutil.MaxDomainNameLen]byte
id uint16
qtype uint16
}
b := bytes.NewBuffer(sig)
q := msg.Question[0]
signature := msgSignature{
id: msg.Id,
qtype: q.Qtype,
}
copy(signature.name[:], q.Name)
if err := binary.Write(b, binary.BigEndian, signature); err != nil {
log.Debug("writing message signature: %s", err)
}
return b.Bytes()
}

View File

@ -1,148 +0,0 @@
package dnsforward
import (
"encoding/binary"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func TestRecursionDetector_Check(t *testing.T) {
rd := newRecursionDetector(0, 2)
const (
recID = 1234
recTTL = time.Hour * 100
)
const nonRecID = recID * 2
sampleQuestion := dns.Question{
Name: "some.domain",
Qtype: dns.TypeAAAA,
}
sampleMsg := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: recID,
},
Question: []dns.Question{sampleQuestion},
}
// Manually add the message with big ttl.
key := msgToSignature(sampleMsg)
expire := make([]byte, uint64sz)
binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano()))
rd.recentRequests.Set(key, expire)
// Add an expired message.
sampleMsg.Id = nonRecID
rd.add(sampleMsg)
testCases := []struct {
name string
questions []dns.Question
id uint16
want bool
}{{
name: "recurrent",
questions: []dns.Question{sampleQuestion},
id: recID,
want: true,
}, {
name: "not_suspected",
questions: []dns.Question{sampleQuestion},
id: recID + 1,
want: false,
}, {
name: "expired",
questions: []dns.Question{sampleQuestion},
id: nonRecID,
want: false,
}, {
name: "empty",
questions: []dns.Question{},
id: nonRecID,
want: false,
}}
for _, tc := range testCases {
sampleMsg.Id = tc.id
sampleMsg.Question = tc.questions
t.Run(tc.name, func(t *testing.T) {
detected := rd.check(sampleMsg)
assert.Equal(t, tc.want, detected)
})
}
}
func TestRecursionDetector_Suspect(t *testing.T) {
rd := newRecursionDetector(0, 1)
testCases := []struct {
name string
msg dns.Msg
want int
}{{
name: "simple",
msg: dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: "some.domain",
Qtype: dns.TypeA,
}},
},
want: 1,
}, {
name: "unencumbered",
msg: dns.Msg{},
want: 0,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(rd.clear)
rd.add(tc.msg)
assert.Equal(t, tc.want, rd.recentRequests.Stats().Count)
})
}
}
var sink []byte
func BenchmarkMsgToSignature(b *testing.B) {
const name = "some.not.very.long.host.name"
msg := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: name,
Qtype: dns.TypeAAAA,
}},
}
b.Run("efficient", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sink = msgToSignature(msg)
}
assert.NotEmpty(b, sink)
})
b.Run("inefficient", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sink = msgToSignatureSlow(msg)
}
assert.NotEmpty(b, sink)
})
}

View File

@ -2,90 +2,77 @@ package dnsforward
import ( import (
"fmt" "fmt"
"net/netip"
"os"
"slices" "slices"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
) )
// loadUpstreams parses upstream DNS servers from the configured file or from // newBootstrap returns a bootstrap resolver based on the configuration of s.
// the configuration itself. // boots are the upstream resolvers that should be closed after use. r is the
func (s *Server) loadUpstreams() (upstreams []string, err error) { // actual bootstrap resolver, which may include the system hosts.
if s.conf.UpstreamDNSFileName == "" { //
return stringutil.FilterOut(s.conf.UpstreamDNS, IsCommentOrEmpty), nil // TODO(e.burkov): This function currently returns a resolver and a slice of
// the upstream resolvers, which are essentially the same. boots are returned
// for being able to close them afterwards, but it introduces an implicit
// contract that r could only be used before that. Anyway, this code should
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
// and be used here.
func newBootstrap(
addrs []string,
etcHosts upstream.Resolver,
opts *upstream.Options,
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
if len(addrs) == 0 {
addrs = defaultBootstrap
} }
var data []byte boots, err = aghnet.ParseBootstraps(addrs, opts)
data, err = os.ReadFile(s.conf.UpstreamDNSFileName)
if err != nil { if err != nil {
return nil, fmt.Errorf("reading upstream from file: %w", err) // Don't wrap the error, since it's informative enough as is.
return nil, nil, err
} }
upstreams = stringutil.SplitTrimmed(string(data), "\n") var parallel upstream.ParallelResolver
for _, b := range boots {
parallel = append(parallel, upstream.NewCachingResolver(b))
}
log.Debug("dnsforward: got %d upstreams in %q", len(upstreams), s.conf.UpstreamDNSFileName) if etcHosts != nil {
r = upstream.ConsequentResolver{etcHosts, parallel}
} else {
r = parallel
}
return stringutil.FilterOut(upstreams, IsCommentOrEmpty), nil return r, boots, nil
} }
// prepareUpstreamSettings sets upstream DNS server settings. // newUpstreamConfig returns the upstream configuration based on upstreams. If
func (s *Server) prepareUpstreamSettings(boot upstream.Resolver) (err error) { // upstreams slice specifies no default upstreams, defaultUpstreams are used to
// Load upstreams either from the file, or from the settings // create upstreams with no domain specifications. opts are used when creating
var upstreams []string // upstream configuration.
upstreams, err = s.loadUpstreams() func newUpstreamConfig(
if err != nil {
return fmt.Errorf("loading upstreams: %w", err)
}
s.conf.UpstreamConfig, err = s.prepareUpstreamConfig(upstreams, defaultDNS, &upstream.Options{
Bootstrap: boot,
Timeout: s.conf.UpstreamTimeout,
HTTPVersions: UpstreamHTTPVersions(s.conf.UseHTTP3Upstreams),
PreferIPv6: s.conf.BootstrapPreferIPv6,
// Use a customized set of RootCAs, because Go's default mechanism of
// loading TLS roots does not always work properly on some routers so we're
// loading roots manually and pass it here.
//
// See [aghtls.SystemRootCAs].
//
// TODO(a.garipov): Investigate if that's true.
RootCAs: s.conf.TLSv12Roots,
CipherSuites: s.conf.TLSCiphers,
})
if err != nil {
return fmt.Errorf("preparing upstream config: %w", err)
}
return nil
}
// prepareUpstreamConfig returns the upstream configuration based on upstreams
// and configuration of s.
func (s *Server) prepareUpstreamConfig(
upstreams []string, upstreams []string,
defaultUpstreams []string, defaultUpstreams []string,
opts *upstream.Options, opts *upstream.Options,
) (uc *proxy.UpstreamConfig, err error) { ) (uc *proxy.UpstreamConfig, err error) {
uc, err = proxy.ParseUpstreamsConfig(upstreams, opts) uc, err = proxy.ParseUpstreamsConfig(upstreams, opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing upstream config: %w", err) return uc, fmt.Errorf("parsing upstreams: %w", err)
} }
if len(uc.Upstreams) == 0 && defaultUpstreams != nil { if len(uc.Upstreams) == 0 && len(defaultUpstreams) > 0 {
log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams) log.Info("dnsforward: warning: no default upstreams specified, using %v", defaultUpstreams)
var defaultUpstreamConfig *proxy.UpstreamConfig var defaultUpstreamConfig *proxy.UpstreamConfig
defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts) defaultUpstreamConfig, err = proxy.ParseUpstreamsConfig(defaultUpstreams, opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("parsing default upstreams: %w", err) return uc, fmt.Errorf("parsing default upstreams: %w", err)
} }
uc.Upstreams = defaultUpstreamConfig.Upstreams uc.Upstreams = defaultUpstreamConfig.Upstreams
@ -94,6 +81,54 @@ func (s *Server) prepareUpstreamConfig(
return uc, nil return uc, nil
} }
// newPrivateConfig creates an upstream configuration for resolving PTR records
// for local addresses. The configuration is built either from the provided
// addresses or from the system resolvers. unwanted filters the resulting
// upstream configuration.
func newPrivateConfig(
addrs []string,
unwanted addrPortSet,
sysResolvers SystemResolvers,
privateNets netutil.SubnetSet,
opts *upstream.Options,
) (uc *proxy.UpstreamConfig, err error) {
confNeedsFiltering := len(addrs) > 0
if confNeedsFiltering {
addrs = stringutil.FilterOut(addrs, IsCommentOrEmpty)
} else {
sysResolvers := slices.DeleteFunc(slices.Clone(sysResolvers.Addrs()), unwanted.Has)
addrs = make([]string, 0, len(sysResolvers))
for _, r := range sysResolvers {
addrs = append(addrs, r.String())
}
}
log.Debug("dnsforward: upstreams to resolve ptr for local addresses: %v", addrs)
uc, err = proxy.ParseUpstreamsConfig(addrs, opts)
if err != nil {
return uc, fmt.Errorf("preparing private upstreams: %w", err)
}
if !confNeedsFiltering {
return uc, nil
}
err = filterOutAddrs(uc, unwanted)
if err != nil {
return uc, fmt.Errorf("filtering private upstreams: %w", err)
}
// Prevalidate the config to catch the exact error before creating proxy.
// See TODO on [PrivateRDNSError].
err = proxy.ValidatePrivateConfig(uc, privateNets)
if err != nil {
return uc, &PrivateRDNSError{err: err}
}
return uc, nil
}
// UpstreamHTTPVersions returns the HTTP versions for upstream configuration // UpstreamHTTPVersions returns the HTTP versions for upstream configuration
// depending on configuration. // depending on configuration.
func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) { func UpstreamHTTPVersions(http3 bool) (v []upstream.HTTPVersion) {
@ -130,85 +165,9 @@ func setProxyUpstreamMode(
return nil return nil
} }
// createBootstrap returns a bootstrap resolver based on the configuration of s.
// boots are the upstream resolvers that should be closed after use. r is the
// actual bootstrap resolver, which may include the system hosts.
//
// TODO(e.burkov): This function currently returns a resolver and a slice of
// the upstream resolvers, which are essentially the same. boots are returned
// for being able to close them afterwards, but it introduces an implicit
// contract that r could only be used before that. Anyway, this code should
// improve when the [proxy.UpstreamConfig] will become an [upstream.Resolver]
// and be used here.
func (s *Server) createBootstrap(
addrs []string,
opts *upstream.Options,
) (r upstream.Resolver, boots []*upstream.UpstreamResolver, err error) {
if len(addrs) == 0 {
addrs = defaultBootstrap
}
boots, err = aghnet.ParseBootstraps(addrs, opts)
if err != nil {
// Don't wrap the error, since it's informative enough as is.
return nil, nil, err
}
var parallel upstream.ParallelResolver
for _, b := range boots {
parallel = append(parallel, upstream.NewCachingResolver(b))
}
if s.etcHosts != nil {
r = upstream.ConsequentResolver{s.etcHosts, parallel}
} else {
r = parallel
}
return r, boots, nil
}
// IsCommentOrEmpty returns true if s starts with a "#" character or is empty. // IsCommentOrEmpty returns true if s starts with a "#" character or is empty.
// This function is useful for filtering out non-upstream lines from upstream // This function is useful for filtering out non-upstream lines from upstream
// configs. // configs.
func IsCommentOrEmpty(s string) (ok bool) { func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#' return len(s) == 0 || s[0] == '#'
} }
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("creating config: %w", err)
}
if conf == nil {
return nil
}
keys := maps.Keys(conf.DomainReservedUpstreams)
slices.Sort(keys)
var errs []error
for _, domain := range keys {
var subnet netip.Prefix
subnet, err = extractARPASubnet(domain)
if err != nil {
errs = append(errs, err)
continue
}
if !privateNets.Contains(subnet.Addr()) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
)
}
}
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
}

View File

@ -18,7 +18,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@ -157,14 +156,12 @@ func initDNSServer(
return fmt.Errorf("newServerConfig: %w", err) return fmt.Errorf("newServerConfig: %w", err)
} }
// Try to prepare the server with disabled private RDNS resolution if it
// failed to prepare as is. See TODO on [ErrBadPrivateRDNSUpstreams].
err = Context.dnsServer.Prepare(dnsConf) err = Context.dnsServer.Prepare(dnsConf)
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
// TODO(e.burkov): Recreate the server with private RDNS disabled. This
// should go away once the private RDNS resolution is moved to the proxy.
var locResErr *dnsforward.LocalResolversError
if errors.As(err, &locResErr) && errors.Is(locResErr.Err, upstream.ErrNoUpstreams) {
log.Info("WARNING: no local resolvers configured while private RDNS " +
"resolution enabled, trying to disable")
dnsConf.UsePrivateRDNS = false dnsConf.UsePrivateRDNS = false
err = Context.dnsServer.Prepare(dnsConf) err = Context.dnsServer.Prepare(dnsConf)
} }