net/dns/resolver: unexport Packet, only use it internally.
Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
parent
5fb9e00ecf
commit
f185d62dc8
|
@ -105,7 +105,7 @@ type forwarder struct {
|
|||
logf logger.Logf
|
||||
|
||||
// responses is a channel by which responses are returned.
|
||||
responses chan Packet
|
||||
responses chan packet
|
||||
// closed signals all goroutines to stop.
|
||||
closed chan struct{}
|
||||
// wg signals when all goroutines have stopped.
|
||||
|
@ -126,7 +126,7 @@ func init() {
|
|||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func newForwarder(logf logger.Logf, responses chan Packet) *forwarder {
|
||||
func newForwarder(logf logger.Logf, responses chan packet) *forwarder {
|
||||
return &forwarder{
|
||||
logf: logger.WithPrefix(logf, "forward: "),
|
||||
responses: responses,
|
||||
|
@ -218,14 +218,11 @@ func (f *forwarder) recv(conn *fwdConn) {
|
|||
|
||||
f.mu.Unlock()
|
||||
|
||||
packet := Packet{
|
||||
Payload: out,
|
||||
Addr: record.src,
|
||||
}
|
||||
pkt := packet{out, record.src}
|
||||
select {
|
||||
case <-f.closed:
|
||||
return
|
||||
case f.responses <- packet:
|
||||
case f.responses <- pkt:
|
||||
// continue
|
||||
}
|
||||
}
|
||||
|
@ -258,8 +255,8 @@ func (f *forwarder) cleanMap() {
|
|||
}
|
||||
|
||||
// forward forwards the query to all upstream nameservers and returns the first response.
|
||||
func (f *forwarder) forward(query Packet) error {
|
||||
txid := getTxID(query.Payload)
|
||||
func (f *forwarder) forward(query packet) error {
|
||||
txid := getTxID(query.bs)
|
||||
|
||||
f.mu.Lock()
|
||||
|
||||
|
@ -269,14 +266,14 @@ func (f *forwarder) forward(query Packet) error {
|
|||
return errNoUpstreams
|
||||
}
|
||||
f.txMap[txid] = forwardingRecord{
|
||||
src: query.Addr,
|
||||
src: query.addr,
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
|
||||
f.mu.Unlock()
|
||||
|
||||
for _, upstream := range upstreams {
|
||||
f.send(query.Payload, upstream)
|
||||
f.send(query.bs, upstream)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -44,14 +44,9 @@ var (
|
|||
errNotOurName = errors.New("not a Tailscale DNS name")
|
||||
)
|
||||
|
||||
// Packet represents a DNS payload together with the address of its origin.
|
||||
type Packet struct {
|
||||
// Payload is the application layer DNS payload.
|
||||
// Resolver assumes ownership of the request payload when it is enqueued
|
||||
// and cedes ownership of the response payload when it is returned from NextResponse.
|
||||
Payload []byte
|
||||
// Addr is the source address for a request and the destination address for a response.
|
||||
Addr netaddr.IPPort
|
||||
type packet struct {
|
||||
bs []byte
|
||||
addr netaddr.IPPort // src for a request, dst for a response
|
||||
}
|
||||
|
||||
// Resolver is a DNS resolver for nodes on the Tailscale network,
|
||||
|
@ -66,9 +61,9 @@ type Resolver struct {
|
|||
forwarder *forwarder
|
||||
|
||||
// queue is a buffered channel holding DNS requests queued for resolution.
|
||||
queue chan Packet
|
||||
queue chan packet
|
||||
// responses is an unbuffered channel to which responses are returned.
|
||||
responses chan Packet
|
||||
responses chan packet
|
||||
// errors is an unbuffered channel to which errors are returned.
|
||||
errors chan error
|
||||
// closed signals all goroutines to stop.
|
||||
|
@ -88,8 +83,8 @@ func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) {
|
|||
r := &Resolver{
|
||||
logf: logger.WithPrefix(logf, "dns: "),
|
||||
linkMon: linkMon,
|
||||
queue: make(chan Packet, queueSize),
|
||||
responses: make(chan Packet),
|
||||
queue: make(chan packet, queueSize),
|
||||
responses: make(chan packet),
|
||||
errors: make(chan error),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
@ -153,11 +148,11 @@ func (r *Resolver) SetUpstreams(upstreams []net.Addr) {
|
|||
// EnqueueRequest places the given DNS request in the resolver's queue.
|
||||
// It takes ownership of the payload and does not block.
|
||||
// If the queue is full, the request will be dropped and an error will be returned.
|
||||
func (r *Resolver) EnqueueRequest(request Packet) error {
|
||||
func (r *Resolver) EnqueueRequest(bs []byte, from netaddr.IPPort) error {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return ErrClosed
|
||||
case r.queue <- request:
|
||||
case r.queue <- packet{bs, from}:
|
||||
return nil
|
||||
default:
|
||||
return errFullQueue
|
||||
|
@ -166,14 +161,14 @@ func (r *Resolver) EnqueueRequest(request Packet) error {
|
|||
|
||||
// NextResponse returns a DNS response to a previously enqueued request.
|
||||
// It blocks until a response is available and gives up ownership of the response payload.
|
||||
func (r *Resolver) NextResponse() (Packet, error) {
|
||||
func (r *Resolver) NextResponse() (packet []byte, to netaddr.IPPort, err error) {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return Packet{}, ErrClosed
|
||||
return nil, netaddr.IPPort{}, ErrClosed
|
||||
case resp := <-r.responses:
|
||||
return resp, nil
|
||||
return resp.bs, resp.addr, nil
|
||||
case err := <-r.errors:
|
||||
return Packet{}, err
|
||||
return nil, netaddr.IPPort{}, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -266,19 +261,19 @@ func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) {
|
|||
func (r *Resolver) poll() {
|
||||
defer r.wg.Done()
|
||||
|
||||
var packet Packet
|
||||
var pkt packet
|
||||
for {
|
||||
select {
|
||||
case <-r.closed:
|
||||
return
|
||||
case packet = <-r.queue:
|
||||
case pkt = <-r.queue:
|
||||
// continue
|
||||
}
|
||||
|
||||
out, err := r.respond(packet.Payload)
|
||||
out, err := r.respond(pkt.bs)
|
||||
|
||||
if err == errNotOurName {
|
||||
err = r.forwarder.forward(packet)
|
||||
err = r.forwarder.forward(pkt)
|
||||
if err == nil {
|
||||
// forward will send response into r.responses, nothing to do.
|
||||
continue
|
||||
|
@ -293,11 +288,11 @@ func (r *Resolver) poll() {
|
|||
// continue
|
||||
}
|
||||
} else {
|
||||
packet.Payload = out
|
||||
pkt.bs = out
|
||||
select {
|
||||
case <-r.closed:
|
||||
return
|
||||
case r.responses <- packet:
|
||||
case r.responses <- pkt:
|
||||
// continue
|
||||
}
|
||||
}
|
||||
|
|
|
@ -109,10 +109,9 @@ func unpackResponse(payload []byte) (dnsResponse, error) {
|
|||
}
|
||||
|
||||
func syncRespond(r *Resolver, query []byte) ([]byte, error) {
|
||||
request := Packet{Payload: query}
|
||||
r.EnqueueRequest(request)
|
||||
resp, err := r.NextResponse()
|
||||
return resp.Payload, err
|
||||
r.EnqueueRequest(query, netaddr.IPPort{})
|
||||
payload, _, err := r.NextResponse()
|
||||
return payload, err
|
||||
}
|
||||
|
||||
func mustIP(str string) netaddr.IP {
|
||||
|
@ -418,21 +417,20 @@ func TestDelegateCollision(t *testing.T) {
|
|||
// packets will have the same dns txid.
|
||||
for _, p := range packets {
|
||||
payload := dnspacket(p.qname, p.qtype)
|
||||
req := Packet{Payload: payload, Addr: p.addr}
|
||||
err := r.EnqueueRequest(req)
|
||||
err := r.EnqueueRequest(payload, p.addr)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Despite the txid collision, the answer(s) should still match the query.
|
||||
resp, err := r.NextResponse()
|
||||
resp, addr, err := r.NextResponse()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
var p dns.Parser
|
||||
_, err = p.Start(resp.Payload)
|
||||
_, err = p.Start(resp)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
@ -456,8 +454,8 @@ func TestDelegateCollision(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, p := range packets {
|
||||
if p.qtype == wantType && p.addr != resp.Addr {
|
||||
t.Errorf("addr = %v; want %v", resp.Addr, p.addr)
|
||||
if p.qtype == wantType && p.addr != addr {
|
||||
t.Errorf("addr = %v; want %v", addr, p.addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -433,11 +433,7 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper)
|
|||
// handleDNS is an outbound pre-filter resolving Tailscale domains.
|
||||
func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.Response {
|
||||
if p.Dst.IP == magicDNSIP && p.Dst.Port == magicDNSPort && p.IPProto == ipproto.UDP {
|
||||
request := resolver.Packet{
|
||||
Payload: append([]byte(nil), p.Payload()...),
|
||||
Addr: netaddr.IPPort{IP: p.Src.IP, Port: p.Src.Port},
|
||||
}
|
||||
err := e.resolver.EnqueueRequest(request)
|
||||
err := e.resolver.EnqueueRequest(append([]byte(nil), p.Payload()...), p.Src)
|
||||
if err != nil {
|
||||
e.logf("dns: enqueue: %v", err)
|
||||
}
|
||||
|
@ -449,7 +445,7 @@ func (e *userspaceEngine) handleDNS(p *packet.Parsed, t *tstun.Wrapper) filter.R
|
|||
// pollResolver reads responses from the DNS resolver and injects them inbound.
|
||||
func (e *userspaceEngine) pollResolver() {
|
||||
for {
|
||||
resp, err := e.resolver.NextResponse()
|
||||
bs, to, err := e.resolver.NextResponse()
|
||||
if err == resolver.ErrClosed {
|
||||
return
|
||||
}
|
||||
|
@ -461,17 +457,17 @@ func (e *userspaceEngine) pollResolver() {
|
|||
h := packet.UDP4Header{
|
||||
IP4Header: packet.IP4Header{
|
||||
Src: magicDNSIP,
|
||||
Dst: resp.Addr.IP,
|
||||
Dst: to.IP,
|
||||
},
|
||||
SrcPort: magicDNSPort,
|
||||
DstPort: resp.Addr.Port,
|
||||
DstPort: to.Port,
|
||||
}
|
||||
hlen := h.Len()
|
||||
|
||||
// TODO(dmytro): avoid this allocation without importing tstun quirks into dns.
|
||||
const offset = tstun.PacketStartOffset
|
||||
buf := make([]byte, offset+hlen+len(resp.Payload))
|
||||
copy(buf[offset+hlen:], resp.Payload)
|
||||
buf := make([]byte, offset+hlen+len(bs))
|
||||
copy(buf[offset+hlen:], bs)
|
||||
h.Marshal(buf[offset:])
|
||||
|
||||
e.tundev.InjectInboundDirect(buf, offset)
|
||||
|
|
Loading…
Reference in New Issue