net/dns/resolver: unexport Packet, only use it internally.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson 2021-03-31 23:06:47 -07:00
parent 5fb9e00ecf
commit f185d62dc8
4 changed files with 41 additions and 55 deletions

View File

@ -105,7 +105,7 @@ type forwarder struct {
logf logger.Logf
// responses is a channel by which responses are returned.
responses chan Packet
responses chan packet
// closed signals all goroutines to stop.
closed chan struct{}
// wg signals when all goroutines have stopped.
@ -126,7 +126,7 @@ func init() {
rand.Seed(time.Now().UnixNano())
}
func newForwarder(logf logger.Logf, responses chan Packet) *forwarder {
func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
return &forwarder{
logf: logger.WithPrefix(logf, "forward: "),
responses: responses,
@ -218,14 +218,11 @@ func (f *forwarder) recv(conn *fwdConn) {
f.mu.Unlock()
packet := Packet{
Payload: out,
Addr: record.src,
}
pkt := packet{out, record.src}
select {
case <-f.closed:
return
case f.responses <- packet:
case f.responses <- pkt:
// continue
}
}
@ -258,8 +255,8 @@ func (f *forwarder) cleanMap() {
}
// forward forwards the query to all upstream nameservers and returns the first response.
func (f *forwarder) forward(query Packet) error {
txid := getTxID(query.Payload)
func (f *forwarder) forward(query packet) error {
txid := getTxID(query.bs)
f.mu.Lock()
@ -269,14 +266,14 @@ func (f *forwarder) forward(query Packet) error {
return errNoUpstreams
}
f.txMap[txid] = forwardingRecord{
src: query.Addr,
src: query.addr,
createdAt: time.Now(),
}
f.mu.Unlock()
for _, upstream := range upstreams {
f.send(query.Payload, upstream)
f.send(query.bs, upstream)
}
return nil

View File

@ -44,14 +44,9 @@ var (
errNotOurName = errors.New("not a Tailscale DNS name")
)
// Packet represents a DNS payload together with the address of its origin.
type Packet struct {
// Payload is the application layer DNS payload.
// Resolver assumes ownership of the request payload when it is enqueued
// and cedes ownership of the response payload when it is returned from NextResponse.
Payload []byte
// Addr is the source address for a request and the destination address for a response.
Addr netaddr.IPPort
type packet struct {
bs []byte
addr netaddr.IPPort // src for a request, dst for a response
}
// Resolver is a DNS resolver for nodes on the Tailscale network,
@ -66,9 +61,9 @@ type Resolver struct {
forwarder *forwarder
// queue is a buffered channel holding DNS requests queued for resolution.
queue chan Packet
queue chan packet
// responses is an unbuffered channel to which responses are returned.
responses chan Packet
responses chan packet
// errors is an unbuffered channel to which errors are returned.
errors chan error
// closed signals all goroutines to stop.
@ -88,8 +83,8 @@ func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) {
r := &Resolver{
logf: logger.WithPrefix(logf, "dns: "),
linkMon: linkMon,
queue: make(chan Packet, queueSize),
responses: make(chan Packet),
queue: make(chan packet, queueSize),
responses: make(chan packet),
errors: make(chan error),
closed: make(chan struct{}),
}
@ -153,11 +148,11 @@ func (r *Resolver) SetUpstreams(upstreams []net.Addr) {
// EnqueueRequest places the given DNS request in the resolver's queue.
// It takes ownership of the payload and does not block.
// If the queue is full, the request will be dropped and an error will be returned.
func (r *Resolver) EnqueueRequest(request Packet) error {
func (r *Resolver) EnqueueRequest(bs []byte, from netaddr.IPPort) error {
select {
case <-r.closed:
return ErrClosed
case r.queue <- request:
case r.queue <- packet{bs, from}:
return nil
default:
return errFullQueue
@ -166,14 +161,14 @@ func (r *Resolver) EnqueueRequest(request Packet) error {
// NextResponse returns a DNS response to a previously enqueued request.
// It blocks until a response is available and gives up ownership of the response payload.
func (r *Resolver) NextResponse() (Packet, error) {
func (r *Resolver) NextResponse() (packet []byte, to netaddr.IPPort, err error) {
select {
case <-r.closed:
return Packet{}, ErrClosed
return nil, netaddr.IPPort{}, ErrClosed
case resp := <-r.responses:
return resp, nil
return resp.bs, resp.addr, nil
case err := <-r.errors:
return Packet{}, err
return nil, netaddr.IPPort{}, err
}
}
@ -266,19 +261,19 @@ func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) {
func (r *Resolver) poll() {
defer r.wg.Done()
var packet Packet
var pkt packet
for {
select {
case <-r.closed:
return
case packet = <-r.queue:
case pkt = <-r.queue:
// continue
}
out, err := r.respond(packet.Payload)
out, err := r.respond(pkt.bs)
if err == errNotOurName {
err = r.forwarder.forward(packet)
err = r.forwarder.forward(pkt)
if err == nil {
// forward will send response into r.responses, nothing to do.
continue
@ -293,11 +288,11 @@ func (r *Resolver) poll() {
// continue
}
} else {
packet.Payload = out
pkt.bs = out
select {
case <-r.closed:
return
case r.responses <- packet:
case r.responses <- pkt:
// continue
}
}

View File

@ -109,10 +109,9 @@ func unpackResponse(payload []byte) (dnsResponse, error) {
}
func syncRespond(r *Resolver, query []byte) ([]byte, error) {
request := Packet{Payload: query}
r.EnqueueRequest(request)
resp, err := r.NextResponse()
return resp.Payload, err
r.EnqueueRequest(query, netaddr.IPPort{})
payload, _, err := r.NextResponse()
return payload, err
}
func mustIP(str string) netaddr.IP {
@ -418,21 +417,20 @@ func TestDelegateCollision(t *testing.T) {
// packets will have the same dns txid.
for _, p := range packets {
payload := dnspacket(p.qname, p.qtype)
req := Packet{Payload: payload, Addr: p.addr}
err := r.EnqueueRequest(req)
err := r.EnqueueRequest(payload, p.addr)
if err != nil {
t.Error(err)
}
}
// Despite the txid collision, the answer(s) should still match the query.
resp, err := r.NextResponse()
resp, addr, err := r.NextResponse()
if err != nil {
t.Error(err)
}
var p dns.Parser
_, err = p.Start(resp.Payload)
_, err = p.Start(resp)
if err != nil {
t.Error(err)
}
@ -456,8 +454,8 @@ func TestDelegateCollision(t *testing.T) {
}
for _, p := range packets {
if p.qtype == wantType && p.addr != resp.Addr {
t.Errorf("addr = %v; want %v", resp.Addr, p.addr)
if p.qtype == wantType && p.addr != addr {
t.Errorf("addr = %v; want %v", addr, p.addr)
}
}
}

View File

@ -433,11 +433,7 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper)
// handleDNS is an outbound pre-filter resolving Tailscale domains.
func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.Response {
if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == ipproto.UDP {
request := resolver.Packet{
Payload: append([]byte(nil), p.Payload()...),
Addr: netaddr.IPPort{IP: p.Src.IP, Port: p.Src.Port},
}
err := e.resolver.EnqueueRequest(request)
err := e.resolver.EnqueueRequest(append([]byte(nil), p.Payload()...), p.Src)
if err != nil {
e.logf("dns: enqueue: %v", err)
}
@ -449,7 +445,7 @@ func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.R
// pollResolver reads responses from the DNS resolver and injects them inbound.
func (e *userspaceEngine) pollResolver() {
for {
resp, err := e.resolver.NextResponse()
bs, to, err := e.resolver.NextResponse()
if err == resolver.ErrClosed {
return
}
@ -461,17 +457,17 @@ func (e *userspaceEngine) pollResolver() {
h := packet.UDP4Header{
IP4Header: packet.IP4Header{
Src: magicDNSIP,
Dst: resp.Addr.IP,
Dst: to.IP,
},
SrcPort: magicDNSPort,
DstPort: resp.Addr.Port,
DstPort: to.Port,
}
hlen := h.Len()
// TODO(dmytro): avoid this allocation without importing tstun quirks into dns.
const offset = tstun.PacketStartOffset
buf := make([]byte, offset+hlen+len(resp.Payload))
copy(buf[offset+hlen:], resp.Payload)
buf := make([]byte, offset+hlen+len(bs))
copy(buf[offset+hlen:], bs)
h.Marshal(buf[offset:])
e.tundev.InjectInboundDirect(buf, offset)