tailscale/net/dnscache/messagecache.go

315 lines
8.1 KiB
Go

// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package dnscache
import (
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"time"
"github.com/golang/groupcache/lru"
"golang.org/x/net/dns/dnsmessage"
)
// MessageCache is a cache that works at the DNS message layer,
// with its cache keyed on a DNS wire-level question, and capable
// of replying to DNS messages.
//
// Its zero value is ready for use with a default cache size.
// Use SetMaxCacheSize to specify the cache size.
//
// It's safe for concurrent use.
type MessageCache struct {
	// Clock is a clock, for testing.
	// If nil, time.Now is used.
	Clock func() time.Time

	// mu guards cacheSizeSet and cache; every method that touches
	// them acquires it first.
	mu sync.Mutex
	// cacheSizeSet is the size requested via SetMaxCacheSize.
	cacheSizeSet int // 0 means default
	// cache is the LRU store of cached responses.
	cache lru.Cache // msgQ => *msgCacheValue
}
// now reports the current time, preferring the Clock hook when one is
// installed and falling back to time.Now otherwise.
func (c *MessageCache) now() time.Time {
	if clock := c.Clock; clock != nil {
		return clock()
	}
	return time.Now()
}
// SetMaxCacheSize configures the upper bound on the number of DNS
// cache entries that may be retained, then prunes the cache down to
// the new bound right away. A value of 0 selects the default size.
func (c *MessageCache) SetMaxCacheSize(n int) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.cacheSizeSet = n
	c.pruneLocked()
}
// Flush drops every entry currently held in the cache. The configured
// maximum size is left as it was.
func (c *MessageCache) Flush() {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.cache.Clear()
}
// pruneLocked evicts the oldest entries until the cache holds no more
// than the configured maximum number of entries.
//
// A configured size of zero or less selects the default maximum.
// Negative sizes must not be used as the limit directly: the loop
// condition Len() > limit would then remain true even for an empty
// cache, and the loop would spin forever while c.mu is held.
//
// c.mu must be held.
func (c *MessageCache) pruneLocked() {
	const defaultMaxEntries = 500

	limit := c.cacheSizeSet
	if limit <= 0 {
		limit = defaultMaxEntries
	}
	for c.cache.Len() > limit {
		c.cache.RemoveOldest()
	}
}
// msgQ is the MessageCache cache key.
//
// It's basically a golang.org/x/net/dns/dnsmessage#Question but the
// Class is omitted (we only cache ClassINET) and we store a Go string
// instead of a 256 byte dnsmessage.Name array.
type msgQ struct {
	// Name is the question name; getDNSQueryCacheKey stores it with
	// ASCII letters lowercased.
	Name string
	Type dnsmessage.Type // A, AAAA, MX, etc
}
// A *msgCacheValue is the cached value for a msgQ (question) key.
//
// Despite using pointers for storage and methods, the value is
// immutable once placed in the cache.
type msgCacheValue struct {
	// Expires is the time after which the entry is no longer served.
	// AddCacheEntry sets it to the soonest TTL expiry among the
	// stored answers.
	Expires time.Time
	// Answers are the minimum data to reconstruct a DNS response
	// message. TTLs are added later when converting to a
	// dnsmessage.Resource.
	Answers []msgResource
}
// msgResource is a single cached answer record: its owner name, its
// record type, and the raw resource data. It carries no TTL or class;
// packDNSResponse supplies those when rebuilding a reply.
type msgResource struct {
	Name string
	Type dnsmessage.Type // dnsmessage.UnknownResource.Type
	Data []byte          // dnsmessage.UnknownResource.Data
}
// ErrCacheMiss is a sentinel error returned by MessageCache.ReplyFromCache
// when the request cannot be satisfied from cache.
var ErrCacheMiss = errors.New("cache miss")
// parserPool holds reusable *dnsmessage.Parser values. Functions in
// this file Get a parser, use it for one message, and Put it back.
var parserPool = &sync.Pool{
	New: func() any { return new(dnsmessage.Parser) },
}
// ReplyFromCache answers the DNS query in dnsQueryMessage from the
// cache, writing the packed DNS response to w. The query must begin
// with the two ID bytes of a DNS message.
//
// ErrCacheMiss is returned when the query is invalid or unexpected,
// when nothing unexpired is cached for it, or when the response cannot
// be packed. On a cache hit the result is whatever w.Write returned:
// nil or its error.
func (c *MessageCache) ReplyFromCache(w io.Writer, dnsQueryMessage []byte) error {
	key, id, cacheable := getDNSQueryCacheKey(dnsQueryMessage)
	if !cacheable {
		return ErrCacheMiss
	}

	t := c.now()

	c.mu.Lock()
	raw, _ := c.cache.Get(key)
	entry, hit := raw.(*msgCacheValue)
	expired := hit && t.After(entry.Expires)
	if expired {
		// Evict eagerly so the stale entry stops occupying a slot.
		c.cache.Remove(key)
	}
	c.mu.Unlock()

	if !hit || expired {
		return ErrCacheMiss
	}

	// Every answer is served with the time remaining until expiry.
	remaining := entry.Expires.Sub(t)
	reply, err := packDNSResponse(key, id, uint32(remaining.Seconds()), entry.Answers)
	if err != nil {
		return ErrCacheMiss
	}
	_, err = w.Write(reply)
	return err
}
var (
	// errNotCacheable reports that AddCacheEntry could not cache the
	// exchange, e.g. because no cache key could be derived from the
	// query packet.
	errNotCacheable = errors.New("question not cacheable")
)
// AddCacheEntry records the answers from the DNS response res under
// the question asked by the DNS query qPacket.
//
// qPacket must be a query that getDNSQueryCacheKey accepts. res must
// carry the same transaction ID and exactly one question, and that
// question must match the query's name (ASCII case-insensitively),
// type and class. Only ClassINET answers are stored; the entry expires
// at the soonest TTL seen among them.
//
// It returns an error if the entry could not be cached. That includes
// a response with no ClassINET answers: such an entry would have no
// expiry and could never be served, yet storing it would overwrite or
// evict usable entries.
func (c *MessageCache) AddCacheEntry(qPacket, res []byte) error {
	cacheKey, qID, ok := getDNSQueryCacheKey(qPacket)
	if !ok {
		return errNotCacheable
	}
	now := c.now()
	v := &msgCacheValue{}

	p := parserPool.Get().(*dnsmessage.Parser)
	defer parserPool.Put(p)

	resh, err := p.Start(res)
	if err != nil {
		return fmt.Errorf("reading header in response: %w", err)
	}
	if resh.ID != qID {
		return errors.New("response ID doesn't match query ID")
	}

	q, err := p.Question()
	if err != nil {
		return fmt.Errorf("reading 1st question in response: %w", err)
	}
	if _, err := p.Question(); err != dnsmessage.ErrSectionDone {
		if err == nil {
			return errors.New("unexpected 2nd question in response")
		}
		return fmt.Errorf("after reading 1st question in response: %w", err)
	}

	// The response must answer the question that was asked. The cache
	// key is (name, type) with ClassINET implied, so all three must
	// agree before the answers are filed under cacheKey.
	if resName := asciiLowerName(q.Name).String(); resName != cacheKey.Name {
		return fmt.Errorf("response question name %q != question name %q", resName, cacheKey.Name)
	}
	if q.Type != cacheKey.Type {
		return fmt.Errorf("response question type %v != question type %v", q.Type, cacheKey.Type)
	}
	if q.Class != dnsmessage.ClassINET {
		return fmt.Errorf("response question class %v is not ClassINET", q.Class)
	}

	for {
		rh, err := p.AnswerHeader()
		if err == dnsmessage.ErrSectionDone {
			break
		}
		if err != nil {
			return fmt.Errorf("reading answer: %w", err)
		}
		// The body must be consumed even for records that are skipped
		// below, so the parser advances to the next answer.
		rr, err := p.UnknownResource()
		if err != nil {
			return fmt.Errorf("reading resource: %w", err)
		}
		if rh.Class != dnsmessage.ClassINET {
			continue
		}

		// Set the cache entry's expiration to the soonest
		// we've seen. (They should all be the same, though)
		expires := now.Add(time.Duration(rh.TTL) * time.Second)
		if v.Expires.IsZero() || expires.Before(v.Expires) {
			v.Expires = expires
		}
		v.Answers = append(v.Answers, msgResource{
			Name: rh.Name.String(),
			Type: rh.Type,
			Data: rr.Data, // doesn't alias; a copy from dnsmessage.unpackUnknownResource
		})
	}

	if len(v.Answers) == 0 {
		return fmt.Errorf("%w: response has no ClassINET answers", errNotCacheable)
	}

	c.addCacheValue(cacheKey, v)
	return nil
}
// addCacheValue stores v under cacheKey and then prunes the cache back
// down to its maximum size. It acquires c.mu itself.
func (c *MessageCache) addCacheValue(cacheKey msgQ, v *msgCacheValue) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.cache.Add(cacheKey, v)
	c.pruneLocked()
}
// getDNSQueryCacheKey derives the cache key and transaction ID from
// the DNS query in msg.
//
// ok is false when msg fails to parse, is not a plain opcode-0 query
// (it is a response or is truncated), does not have exactly one
// question with no answer or authority records, or asks about a class
// other than ClassINET. The additional-record count is not examined.
// The returned key's Name has its ASCII letters lowercased.
func getDNSQueryCacheKey(msg []byte) (cacheKey msgQ, txID uint16, ok bool) {
	const dnsHeaderSize = 12

	parser := parserPool.Get().(*dnsmessage.Parser)
	defer parserPool.Put(parser)

	hdr, err := parser.Start(msg)
	switch {
	case err != nil:
		return msgQ{}, 0, false
	case hdr.OpCode != 0, hdr.Response, hdr.Truncated:
		return msgQ{}, 0, false
	case len(msg) < dnsHeaderSize:
		// parser.Start checks this anyway, but be explicit before
		// slicing into the header below.
		return msgQ{}, 0, false
	}

	// Section counts live in bytes 4..12 of the header. The additional
	// count (bytes 10..12) is ignored for now: client OSes are assumed
	// to send EDNS additional records.
	numQ := binary.BigEndian.Uint16(msg[4:6])
	numAns := binary.BigEndian.Uint16(msg[6:8])
	numAuth := binary.BigEndian.Uint16(msg[8:10])
	if numQ != 1 || numAns != 0 || numAuth != 0 {
		// Something weird. We don't want to deal with it.
		return msgQ{}, 0, false
	}

	question, err := parser.Question()
	if err != nil {
		// numQ == 1 was already verified, so this shouldn't happen.
		return msgQ{}, 0, false
	}
	if question.Class != dnsmessage.ClassINET {
		// Only the Internet class is cached.
		return msgQ{}, 0, false
	}

	key := msgQ{
		Name: asciiLowerName(question.Name).String(),
		Type: question.Type,
	}
	return key, hdr.ID, true
}
// asciiLowerName returns a copy of n in which every ASCII uppercase
// letter within the first n.Length bytes has been lowercased. n is
// passed by value, so the caller's Name is not modified.
func asciiLowerName(n dnsmessage.Name) dnsmessage.Name {
	end := len(n.Data)
	if l := int(n.Length); l < end {
		end = l
	}
	for i := 0; i < end; i++ {
		if ch := n.Data[i]; ch >= 'A' && ch <= 'Z' {
			n.Data[i] = ch + ('a' - 'A')
		}
	}
	return n
}
// packDNSResponse assembles a successful DNS response message that
// echoes the question q under transaction ID txID and carries the
// given answers. Every answer record is emitted as ClassINET with the
// same provided TTL.
//
// It returns the packed message, or the first error reported while
// converting a name or building the message.
func packDNSResponse(q msgQ, txID uint16, ttl uint32, answers []msgResource) ([]byte, error) {
	qName, err := dnsmessage.NewName(q.Name)
	if err != nil {
		return nil, err
	}

	// TODO: guess a max size based on looping over answers and pass a
	// pre-sized buffer instead of nil?
	b := dnsmessage.NewBuilder(nil, dnsmessage.Header{
		ID:       txID,
		Response: true,
		RCode:    dnsmessage.RCodeSuccess,
	})

	if err := b.StartQuestions(); err != nil {
		return nil, err
	}
	question := dnsmessage.Question{
		Name:  qName,
		Type:  q.Type,
		Class: dnsmessage.ClassINET,
	}
	if err := b.Question(question); err != nil {
		return nil, err
	}

	if err := b.StartAnswers(); err != nil {
		return nil, err
	}
	for i := range answers {
		ans := &answers[i]
		rrName, err := dnsmessage.NewName(ans.Name)
		if err != nil {
			return nil, err
		}
		rrHeader := dnsmessage.ResourceHeader{
			Name:  rrName,
			Type:  ans.Type,
			Class: dnsmessage.ClassINET,
			TTL:   ttl,
		}
		rrBody := dnsmessage.UnknownResource{
			Type: ans.Type,
			Data: ans.Data,
		}
		if err := b.UnknownResource(rrHeader, rrBody); err != nil {
			return nil, err
		}
	}
	return b.Finish()
}