diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 50a1bd009..a3fdec84a 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -105,7 +105,7 @@ type forwarder struct { logf logger.Logf // responses is a channel by which responses are returned. - responses chan Packet + responses chan packet // closed signals all goroutines to stop. closed chan struct{} // wg signals when all goroutines have stopped. @@ -126,7 +126,7 @@ func init() { rand.Seed(time.Now().UnixNano()) } -func newForwarder(logf logger.Logf, responses chan Packet) *forwarder { +func newForwarder(logf logger.Logf, responses chan packet) *forwarder { return &forwarder{ logf: logger.WithPrefix(logf, "forward: "), responses: responses, @@ -218,14 +218,11 @@ func (f *forwarder) recv(conn *fwdConn) { f.mu.Unlock() - packet := Packet{ - Payload: out, - Addr: record.src, - } + pkt := packet{out, record.src} select { case <-f.closed: return - case f.responses <- packet: + case f.responses <- pkt: // continue } } @@ -258,8 +255,8 @@ func (f *forwarder) cleanMap() { } // forward forwards the query to all upstream nameservers and returns the first response. -func (f *forwarder) forward(query Packet) error { - txid := getTxID(query.Payload) +func (f *forwarder) forward(query packet) error { + txid := getTxID(query.bs) f.mu.Lock() @@ -269,14 +266,14 @@ func (f *forwarder) forward(query Packet) error { return errNoUpstreams } f.txMap[txid] = forwardingRecord{ - src: query.Addr, + src: query.addr, createdAt: time.Now(), } f.mu.Unlock() for _, upstream := range upstreams { - f.send(query.Payload, upstream) + f.send(query.bs, upstream) } return nil diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 09214b43b..4834a3582 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -44,14 +44,9 @@ var ( errNotOurName = errors.New("not a Tailscale DNS name") ) -// Packet represents a DNS payload together with the address of its origin. -type Packet struct { - // Payload is the application layer DNS payload. - // Resolver assumes ownership of the request payload when it is enqueued - // and cedes ownership of the response payload when it is returned from NextResponse. - Payload []byte - // Addr is the source address for a request and the destination address for a response. - Addr netaddr.IPPort +type packet struct { + bs []byte + addr netaddr.IPPort // src for a request, dst for a response } // Resolver is a DNS resolver for nodes on the Tailscale network, @@ -66,9 +61,9 @@ type Resolver struct { forwarder *forwarder // queue is a buffered channel holding DNS requests queued for resolution. - queue chan Packet + queue chan packet // responses is an unbuffered channel to which responses are returned. - responses chan Packet + responses chan packet // errors is an unbuffered channel to which errors are returned. errors chan error // closed signals all goroutines to stop. @@ -88,8 +83,8 @@ func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) { r := &Resolver{ logf: logger.WithPrefix(logf, "dns: "), linkMon: linkMon, - queue: make(chan Packet, queueSize), - responses: make(chan Packet), + queue: make(chan packet, queueSize), + responses: make(chan packet), errors: make(chan error), closed: make(chan struct{}), } @@ -153,11 +148,11 @@ func (r *Resolver) SetUpstreams(upstreams []net.Addr) { // EnqueueRequest places the given DNS request in the resolver's queue. // It takes ownership of the payload and does not block. // If the queue is full, the request will be dropped and an error will be returned. -func (r *Resolver) EnqueueRequest(request Packet) error { +func (r *Resolver) EnqueueRequest(bs []byte, from netaddr.IPPort) error { select { case <-r.closed: return ErrClosed - case r.queue <- request: + case r.queue <- packet{bs, from}: return nil default: return errFullQueue @@ -166,14 +161,14 @@ func (r *Resolver) EnqueueRequest(request Packet) error { // NextResponse returns a DNS response to a previously enqueued request. // It blocks until a response is available and gives up ownership of the response payload. -func (r *Resolver) NextResponse() (Packet, error) { +func (r *Resolver) NextResponse() (packet []byte, to netaddr.IPPort, err error) { select { case <-r.closed: - return Packet{}, ErrClosed + return nil, netaddr.IPPort{}, ErrClosed case resp := <-r.responses: - return resp, nil + return resp.bs, resp.addr, nil case err := <-r.errors: - return Packet{}, err + return nil, netaddr.IPPort{}, err } } @@ -266,19 +261,19 @@ func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) { func (r *Resolver) poll() { defer r.wg.Done() - var packet Packet + var pkt packet for { select { case <-r.closed: return - case packet = <-r.queue: + case pkt = <-r.queue: // continue } - out, err := r.respond(packet.Payload) + out, err := r.respond(pkt.bs) if err == errNotOurName { - err = r.forwarder.forward(packet) + err = r.forwarder.forward(pkt) if err == nil { // forward will send response into r.responses, nothing to do. continue @@ -293,11 +288,11 @@ func (r *Resolver) poll() { // continue } } else { - packet.Payload = out + pkt.bs = out select { case <-r.closed: return - case r.responses <- packet: + case r.responses <- pkt: // continue } } diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index 9ee7d11ac..3bf9bea6f 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -109,10 +109,9 @@ func unpackResponse(payload []byte) (dnsResponse, error) { } func syncRespond(r *Resolver, query []byte) ([]byte, error) { - request := Packet{Payload: query} - r.EnqueueRequest(request) - resp, err := r.NextResponse() - return resp.Payload, err + r.EnqueueRequest(query, netaddr.IPPort{}) + payload, _, err := r.NextResponse() + return payload, err } func mustIP(str string) netaddr.IP { @@ -418,21 +417,20 @@ func TestDelegateCollision(t *testing.T) { // packets will have the same dns txid. for _, p := range packets { payload := dnspacket(p.qname, p.qtype) - req := Packet{Payload: payload, Addr: p.addr} - err := r.EnqueueRequest(req) + err := r.EnqueueRequest(payload, p.addr) if err != nil { t.Error(err) } } // Despite the txid collision, the answer(s) should still match the query. - resp, err := r.NextResponse() + resp, addr, err := r.NextResponse() if err != nil { t.Error(err) } var p dns.Parser - _, err = p.Start(resp.Payload) + _, err = p.Start(resp) if err != nil { t.Error(err) } @@ -456,8 +454,8 @@ func TestDelegateCollision(t *testing.T) { } for _, p := range packets { - if p.qtype == wantType && p.addr != resp.Addr { - t.Errorf("addr = %v; want %v", resp.Addr, p.addr) + if p.qtype == wantType && p.addr != addr { + t.Errorf("addr = %v; want %v", addr, p.addr) } } } diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 28f79501b..a8d270c0b 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -433,11 +433,7 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) // handleDNS is an outbound pre-filter resolving Tailscale domains. func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.Response { if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == ipproto.UDP { - request := resolver.Packet{ - Payload: append([]byte(nil), p.Payload()...), - Addr: netaddr.IPPort{IP: p.Src.IP, Port: p.Src.Port}, - } - err := e.resolver.EnqueueRequest(request) + err := e.resolver.EnqueueRequest(append([]byte(nil), p.Payload()...), p.Src) if err != nil { e.logf("dns: enqueue: %v", err) } @@ -449,7 +445,7 @@ func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.R // pollResolver reads responses from the DNS resolver and injects them inbound. func (e *userspaceEngine) pollResolver() { for { - resp, err := e.resolver.NextResponse() + bs, to, err := e.resolver.NextResponse() if err == resolver.ErrClosed { return } @@ -461,17 +457,17 @@ func (e *userspaceEngine) pollResolver() { h := packet.UDP4Header{ IP4Header: packet.IP4Header{ Src: magicDNSIP, - Dst: resp.Addr.IP, + Dst: to.IP, }, SrcPort: magicDNSPort, - DstPort: resp.Addr.Port, + DstPort: to.Port, } hlen := h.Len() // TODO(dmytro): avoid this allocation without importing tstun quirks into dns. const offset = tstun.PacketStartOffset - buf := make([]byte, offset+hlen+len(resp.Payload)) - copy(buf[offset+hlen:], resp.Payload) + buf := make([]byte, offset+hlen+len(bs)) + copy(buf[offset+hlen:], bs) h.Marshal(buf[offset:]) e.tundev.InjectInboundDirect(buf, offset)