AdGuardHome/internal/next/dnssvc/dnssvc_test.go

126 lines
2.9 KiB
Go

package dnssvc_test
import (
"context"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMain is the package's test entry point.  It delegates entirely to
// testutil.DiscardLogOutput, which (judging by its name) silences log output
// before running the tests — confirm against the golibs documentation.
func TestMain(tm *testing.M) {
	testutil.DiscardLogOutput(tm)
}
// testTimeout bounds the waits, contexts, and retries used throughout this
// package's tests.
const testTimeout time.Duration = time.Second
// TestService starts a dnssvc.Service whose bootstrap resolver is a local
// mock UDP server.  It checks that the service accepts a DNS query, and then
// that both the service and the mock server shut down without errors.
func TestService(t *testing.T) {
	const (
		listenAddr    = "127.0.0.1:0"
		bootstrapAddr = "127.0.0.1:0"
		upstreamAddr  = "upstream.example"
	)

	// upstreamSrv is passed to the service as its only bootstrap server (see
	// BootstrapServers below).  It answers every A question with the loopback
	// address taken from bootstrapAddr.
	upstreamErrCh := make(chan error, 1)
	upstreamStartedCh := make(chan struct{})
	upstreamSrv := &dns.Server{
		Addr: bootstrapAddr,
		Net:  "udp",
		Handler: dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
			pt := testutil.PanicT{}

			resp := (&dns.Msg{}).SetReply(req)

			// Answer only A questions, and give the record a complete header.
			// An empty dns.RR_Header leaves the owner name, type, and class
			// unset, and those are the values written to the wire, so the
			// resulting reply is not a valid A answer for the question.
			if len(req.Question) > 0 && req.Question[0].Qtype == dns.TypeA {
				resp.Answer = append(resp.Answer, &dns.A{
					Hdr: dns.RR_Header{
						Name:   req.Question[0].Name,
						Rrtype: dns.TypeA,
						Class:  dns.ClassINET,
					},
					A: netip.MustParseAddrPort(bootstrapAddr).Addr().AsSlice(),
				})
			}

			writeErr := w.WriteMsg(resp)
			require.NoError(pt, writeErr)
		}),
		NotifyStartedFunc: func() { close(upstreamStartedCh) },
	}

	go func() {
		listenErr := upstreamSrv.ListenAndServe()
		if listenErr != nil {
			// Log these immediately to see what happens.
			t.Logf("upstream listen error: %s", listenErr)
		}

		upstreamErrCh <- listenErr
	}()

	// Wait for the server to start, so that PacketConn is set before its
	// local address is read below.
	_, _ = testutil.RequireReceive(t, upstreamStartedCh, testTimeout)

	c := &dnssvc.Config{
		Addresses:           []netip.AddrPort{netip.MustParseAddrPort(listenAddr)},
		BootstrapServers:    []string{upstreamSrv.PacketConn.LocalAddr().String()},
		UpstreamServers:     []string{upstreamAddr},
		DNS64Prefixes:       nil,
		UpstreamTimeout:     testTimeout,
		BootstrapPreferIPv6: false,
		UseDNS64:            false,
	}

	svc, err := dnssvc.New(c)
	require.NoError(t, err)

	err = svc.Start()
	require.NoError(t, err)

	// listenAddr uses port 0, so take the actual address from the service's
	// reported configuration.
	gotConf := svc.Config()
	require.NotNil(t, gotConf)
	require.Len(t, gotConf.Addresses, 1)

	addr := gotConf.Addresses[0]

	t.Run("dns", func(t *testing.T) {
		req := &dns.Msg{
			MsgHdr: dns.MsgHdr{
				Id:               dns.Id(),
				RecursionDesired: true,
			},
			Question: []dns.Question{{
				Name:   "example.com.",
				Qtype:  dns.TypeA,
				Qclass: dns.ClassINET,
			}},
		}

		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
		defer cancel()

		cli := &dns.Client{}

		// Retry until an exchange completes without a transport error or
		// testTimeout elapses.
		var resp *dns.Msg
		require.Eventually(t, func() (ok bool) {
			var excErr error
			resp, _, excErr = cli.ExchangeContext(ctx, req, addr.String())

			return excErr == nil
		}, testTimeout, testTimeout/10)

		assert.NotNil(t, resp)
	})

	ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
	defer cancel()

	err = svc.Shutdown(ctx)
	require.NoError(t, err)

	// Stop the mock server and make sure its ListenAndServe goroutine returned
	// without an error.
	err = upstreamSrv.Shutdown()
	require.NoError(t, err)

	err, ok := testutil.RequireReceive(t, upstreamErrCh, testTimeout)
	require.True(t, ok)
	require.NoError(t, err)
}