315 lines
8.1 KiB
Go
315 lines
8.1 KiB
Go
|
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package dnscache
|
||
|
|
||
|
import (
|
||
|
"encoding/binary"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/golang/groupcache/lru"
|
||
|
"golang.org/x/net/dns/dnsmessage"
|
||
|
)
|
||
|
|
||
|
// MessageCache is a cache that works at the DNS message layer,
|
||
|
// with its cache keyed on a DNS wire-level question, and capable
|
||
|
// of replying to DNS messages.
|
||
|
//
|
||
|
// Its zero value is ready for use with a default cache size.
|
||
|
// Use SetMaxCacheSize to specify the cache size.
|
||
|
//
|
||
|
// It's safe for concurrent use.
|
||
|
type MessageCache struct {
|
||
|
// Clock is a clock, for testing.
|
||
|
// If nil, time.Now is used.
|
||
|
Clock func() time.Time
|
||
|
|
||
|
mu sync.Mutex
|
||
|
cacheSizeSet int // 0 means default
|
||
|
cache lru.Cache // msgQ => *msgCacheValue
|
||
|
}
|
||
|
|
||
|
func (c *MessageCache) now() time.Time {
|
||
|
if c.Clock != nil {
|
||
|
return c.Clock()
|
||
|
}
|
||
|
return time.Now()
|
||
|
}
|
||
|
|
||
|
// SetMaxCacheSize sets the maximum number of DNS cache entries that
|
||
|
// can be stored.
|
||
|
func (c *MessageCache) SetMaxCacheSize(n int) {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
c.cacheSizeSet = n
|
||
|
c.pruneLocked()
|
||
|
}
|
||
|
|
||
|
// Flush clears the cache.
|
||
|
func (c *MessageCache) Flush() {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
c.cache.Clear()
|
||
|
}
|
||
|
|
||
|
// pruneLocked prunes down the cache size to the configured (or
|
||
|
// default) max size.
|
||
|
func (c *MessageCache) pruneLocked() {
|
||
|
max := c.cacheSizeSet
|
||
|
if max == 0 {
|
||
|
max = 500
|
||
|
}
|
||
|
for c.cache.Len() > max {
|
||
|
c.cache.RemoveOldest()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// msgQ is the MessageCache cache key.
|
||
|
//
|
||
|
// It's basically a golang.org/x/net/dns/dnsmessage#Question but the
|
||
|
// Class is omitted (we only cache ClassINET) and we store a Go string
|
||
|
// instead of a 256 byte dnsmessage.Name array.
|
||
|
type msgQ struct {
|
||
|
Name string
|
||
|
Type dnsmessage.Type // A, AAAA, MX, etc
|
||
|
}
|
||
|
|
||
|
// A *msgCacheValue is the cached value for a msgQ (question) key.
|
||
|
//
|
||
|
// Despite using pointers for storage and methods, the value is
|
||
|
// immutable once placed in the cache.
|
||
|
type msgCacheValue struct {
|
||
|
Expires time.Time
|
||
|
|
||
|
// Answers are the minimum data to reconstruct a DNS response
|
||
|
// message. TTLs are added later when converting to a
|
||
|
// dnsmessage.Resource.
|
||
|
Answers []msgResource
|
||
|
}
|
||
|
|
||
|
type msgResource struct {
|
||
|
Name string
|
||
|
Type dnsmessage.Type // dnsmessage.UnknownResource.Type
|
||
|
Data []byte // dnsmessage.UnknownResource.Data
|
||
|
}
|
||
|
|
||
|
// ErrCacheMiss is the sentinel error that MessageCache.ReplyFromCache
// returns when a request cannot be satisfied from the cache.
var ErrCacheMiss = errors.New("cache miss")
|
||
|
|
||
|
var parserPool = &sync.Pool{
|
||
|
New: func() interface{} { return new(dnsmessage.Parser) },
|
||
|
}
|
||
|
|
||
|
// ReplyFromCache writes a DNS reply to w for the provided DNS query message,
|
||
|
// which must begin with the two ID bytes of a DNS message.
|
||
|
//
|
||
|
// If there's a cache miss, the message is invalid or unexpected,
|
||
|
// ErrCacheMiss is returned. On cache hit, either nil or an error from
|
||
|
// a w.Write call is returned.
|
||
|
func (c *MessageCache) ReplyFromCache(w io.Writer, dnsQueryMessage []byte) error {
|
||
|
cacheKey, txID, ok := getDNSQueryCacheKey(dnsQueryMessage)
|
||
|
if !ok {
|
||
|
return ErrCacheMiss
|
||
|
}
|
||
|
now := c.now()
|
||
|
|
||
|
c.mu.Lock()
|
||
|
cacheEntI, _ := c.cache.Get(cacheKey)
|
||
|
v, ok := cacheEntI.(*msgCacheValue)
|
||
|
if ok && now.After(v.Expires) {
|
||
|
c.cache.Remove(cacheKey)
|
||
|
ok = false
|
||
|
}
|
||
|
c.mu.Unlock()
|
||
|
|
||
|
if !ok {
|
||
|
return ErrCacheMiss
|
||
|
}
|
||
|
|
||
|
ttl := uint32(v.Expires.Sub(now).Seconds())
|
||
|
|
||
|
packedRes, err := packDNSResponse(cacheKey, txID, ttl, v.Answers)
|
||
|
if err != nil {
|
||
|
return ErrCacheMiss
|
||
|
}
|
||
|
_, err = w.Write(packedRes)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// errNotCacheable is returned when a DNS query is not of a form that
// the cache is willing to store an entry for.
var errNotCacheable = errors.New("question not cacheable")
|
||
|
|
||
|
// AddCacheEntry adds a cache entry to the cache.
|
||
|
// It returns an error if the entry could not be cached.
|
||
|
func (c *MessageCache) AddCacheEntry(qPacket, res []byte) error {
|
||
|
cacheKey, qID, ok := getDNSQueryCacheKey(qPacket)
|
||
|
if !ok {
|
||
|
return errNotCacheable
|
||
|
}
|
||
|
now := c.now()
|
||
|
v := &msgCacheValue{}
|
||
|
|
||
|
p := parserPool.Get().(*dnsmessage.Parser)
|
||
|
defer parserPool.Put(p)
|
||
|
|
||
|
resh, err := p.Start(res)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("reading header in response: %w", err)
|
||
|
}
|
||
|
if resh.ID != qID {
|
||
|
return fmt.Errorf("response ID doesn't match query ID")
|
||
|
}
|
||
|
q, err := p.Question()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("reading 1st question in response: %w", err)
|
||
|
}
|
||
|
if _, err := p.Question(); err != dnsmessage.ErrSectionDone {
|
||
|
if err == nil {
|
||
|
return errors.New("unexpected 2nd question in response")
|
||
|
}
|
||
|
return fmt.Errorf("after reading 1st question in response: %w", err)
|
||
|
}
|
||
|
if resName := asciiLowerName(q.Name).String(); resName != cacheKey.Name {
|
||
|
return fmt.Errorf("response question name %q != question name %q", resName, cacheKey.Name)
|
||
|
}
|
||
|
for {
|
||
|
rh, err := p.AnswerHeader()
|
||
|
if err == dnsmessage.ErrSectionDone {
|
||
|
break
|
||
|
}
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("reading answer: %w", err)
|
||
|
}
|
||
|
res, err := p.UnknownResource()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("reading resource: %w", err)
|
||
|
}
|
||
|
if rh.Class != dnsmessage.ClassINET {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
// Set the cache entry's expiration to the soonest
|
||
|
// we've seen. (They should all be the same, though)
|
||
|
expires := now.Add(time.Duration(rh.TTL) * time.Second)
|
||
|
if v.Expires.IsZero() || expires.Before(v.Expires) {
|
||
|
v.Expires = expires
|
||
|
}
|
||
|
v.Answers = append(v.Answers, msgResource{
|
||
|
Name: rh.Name.String(),
|
||
|
Type: rh.Type,
|
||
|
Data: res.Data, // doesn't alias; a copy from dnsmessage.unpackUnknownResource
|
||
|
})
|
||
|
}
|
||
|
c.addCacheValue(cacheKey, v)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *MessageCache) addCacheValue(cacheKey msgQ, v *msgCacheValue) {
|
||
|
c.mu.Lock()
|
||
|
defer c.mu.Unlock()
|
||
|
c.cache.Add(cacheKey, v)
|
||
|
c.pruneLocked()
|
||
|
}
|
||
|
|
||
|
func getDNSQueryCacheKey(msg []byte) (cacheKey msgQ, txID uint16, ok bool) {
|
||
|
p := parserPool.Get().(*dnsmessage.Parser)
|
||
|
defer parserPool.Put(p)
|
||
|
h, err := p.Start(msg)
|
||
|
const dnsHeaderSize = 12
|
||
|
if err != nil || h.OpCode != 0 || h.Response || h.Truncated ||
|
||
|
len(msg) < dnsHeaderSize { // p.Start checks this anyway, but to be explicit for slicing below
|
||
|
return cacheKey, 0, false
|
||
|
}
|
||
|
var (
|
||
|
numQ = binary.BigEndian.Uint16(msg[4:6])
|
||
|
numAns = binary.BigEndian.Uint16(msg[6:8])
|
||
|
numAuth = binary.BigEndian.Uint16(msg[8:10])
|
||
|
numAddn = binary.BigEndian.Uint16(msg[10:12])
|
||
|
)
|
||
|
_ = numAddn // ignore this for now; do client OSes send EDNS additional? assume so, ignore.
|
||
|
if !(numQ == 1 && numAns == 0 && numAuth == 0) {
|
||
|
// Something weird. We don't want to deal with it.
|
||
|
return cacheKey, 0, false
|
||
|
}
|
||
|
q, err := p.Question()
|
||
|
if err != nil {
|
||
|
// Already verified numQ == 1 so shouldn't happen, but:
|
||
|
return cacheKey, 0, false
|
||
|
}
|
||
|
if q.Class != dnsmessage.ClassINET {
|
||
|
// We only cache the Internet class.
|
||
|
return cacheKey, 0, false
|
||
|
}
|
||
|
return msgQ{Name: asciiLowerName(q.Name).String(), Type: q.Type}, h.ID, true
|
||
|
}
|
||
|
|
||
|
func asciiLowerName(n dnsmessage.Name) dnsmessage.Name {
|
||
|
nb := n.Data[:]
|
||
|
if int(n.Length) < len(n.Data) {
|
||
|
nb = nb[:n.Length]
|
||
|
}
|
||
|
for i, b := range nb {
|
||
|
if 'A' <= b && b <= 'Z' {
|
||
|
n.Data[i] += 0x20
|
||
|
}
|
||
|
}
|
||
|
return n
|
||
|
}
|
||
|
|
||
|
// packDNSResponse builds a DNS response for the given question and
|
||
|
// transaction ID. The response resource records will have have the
|
||
|
// same provided TTL.
|
||
|
func packDNSResponse(q msgQ, txID uint16, ttl uint32, answers []msgResource) ([]byte, error) {
|
||
|
var baseMem []byte // TODO: guess a max size based on looping over answers?
|
||
|
b := dnsmessage.NewBuilder(baseMem, dnsmessage.Header{
|
||
|
ID: txID,
|
||
|
Response: true,
|
||
|
OpCode: 0,
|
||
|
Authoritative: false,
|
||
|
Truncated: false,
|
||
|
RCode: dnsmessage.RCodeSuccess,
|
||
|
})
|
||
|
name, err := dnsmessage.NewName(q.Name)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if err := b.StartQuestions(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if err := b.Question(dnsmessage.Question{
|
||
|
Name: name,
|
||
|
Type: q.Type,
|
||
|
Class: dnsmessage.ClassINET,
|
||
|
}); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if err := b.StartAnswers(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
for _, r := range answers {
|
||
|
name, err := dnsmessage.NewName(r.Name)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if err := b.UnknownResource(dnsmessage.ResourceHeader{
|
||
|
Name: name,
|
||
|
Type: r.Type,
|
||
|
Class: dnsmessage.ClassINET,
|
||
|
TTL: ttl,
|
||
|
}, dnsmessage.UnknownResource{
|
||
|
Type: r.Type,
|
||
|
Data: r.Data,
|
||
|
}); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
return b.Finish()
|
||
|
}
|