net/dns{., resolver}: time out DNS queries after 10 seconds (#4690)

Fixes https://github.com/tailscale/corp/issues/5198

The upstream forwarder will block indefinitely on `udpconn.ReadFrom` if no
reply is recieved, due to the lack of deadline on the connection object.

There still isn't a deadline on the connection object, but the automatic closing
of the context on deadline expiry will close the connection via `closeOnCtxDone`,
unblocking the read and resulting in a normal teardown.

Signed-off-by: Tom DNetto <tom@tailscale.com>
This commit is contained in:
Tom 2022-05-18 10:40:04 -07:00 committed by GitHub
parent ec4c49a338
commit acfe5bd33b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 2 deletions

View File

@ -331,6 +331,9 @@ func (m *Manager) NextPacket() ([]byte, error) {
return buf, nil
}
// Query executes a DNS query recieved from the given address. The query is
// provided in bs as a wire-encoded DNS query without any transport header.
// This method is called for requests arriving over UDP and TCP.
func (m *Manager) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]byte, error) {
select {
case <-m.ctx.Done():
@ -460,7 +463,7 @@ func (m *Manager) HandleTCPConn(conn net.Conn, srcAddr netaddr.IPPort) {
responses: make(chan []byte),
readClosing: make(chan struct{}),
}
s.ctx, s.closeCtx = context.WithCancel(context.Background())
s.ctx, s.closeCtx = context.WithCancel(m.ctx)
go s.handleReads()
s.handleWrites()
}

View File

@ -256,6 +256,12 @@ func (r *Resolver) Close() {
r.forwarder.Close()
}
// dnsQueryTimeout is not intended to be user-visible (the users
// DNS resolver will retry well before that), just put an upper
// bound on per-query resource usage.
const dnsQueryTimeout = 10 * time.Second
func (r *Resolver) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([]byte, error) {
metricDNSQueryLocal.Add(1)
select {
@ -268,7 +274,7 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, from netaddr.IPPort) ([
out, err := r.respond(bs)
if err == errNotOurName {
responses := make(chan packet, 1)
ctx, cancel := context.WithCancel(ctx)
ctx, cancel := context.WithTimeout(ctx, dnsQueryTimeout)
defer close(responses)
defer cancel()
err = r.forwarder.forwardWithDestChan(ctx, packet{bs, from}, responses)