net/dns/resolver: remove Start method, fully spin up in New instead.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson 2021-03-31 22:32:07 -07:00
parent 075fb93e69
commit 5fb9e00ecf
3 changed files with 47 additions and 53 deletions

View File

@ -84,7 +84,7 @@ type Resolver struct {
// New returns a new resolver. // New returns a new resolver.
// linkMon optionally specifies a link monitor to use for socket rebinding. // linkMon optionally specifies a link monitor to use for socket rebinding.
func New(logf logger.Logf, linkMon *monitor.Mon) *Resolver { func New(logf logger.Logf, linkMon *monitor.Mon) (*Resolver, error) {
r := &Resolver{ r := &Resolver{
logf: logger.WithPrefix(logf, "dns: "), logf: logger.WithPrefix(logf, "dns: "),
linkMon: linkMon, linkMon: linkMon,
@ -98,20 +98,14 @@ func New(logf logger.Logf, linkMon *monitor.Mon) *Resolver {
r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange) r.unregLinkMon = r.linkMon.RegisterChangeCallback(r.onLinkMonitorChange)
} }
return r if err := r.forwarder.Start(); err != nil {
} return nil, err
func (r *Resolver) Start() error {
if r.forwarder != nil {
if err := r.forwarder.Start(); err != nil {
return err
}
} }
r.wg.Add(1) r.wg.Add(1)
go r.poll() go r.poll()
return nil return r, nil
} }
// Close shuts down the resolver and ensures poll goroutines have exited. // Close shuts down the resolver and ensures poll goroutines have exited.

View File

@ -194,14 +194,14 @@ func TestRDNSNameToIPv6(t *testing.T) {
} }
func TestResolve(t *testing.T) { func TestResolve(t *testing.T) {
r := New(t.Logf, nil) r, err := New(t.Logf, nil)
r.SetMap(dnsMap) if err != nil {
if err := r.Start(); err != nil {
t.Fatalf("start: %v", err) t.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap)
tests := []struct { tests := []struct {
name string name string
qname string qname string
@ -240,14 +240,14 @@ func TestResolve(t *testing.T) {
} }
func TestResolveReverse(t *testing.T) { func TestResolveReverse(t *testing.T) {
r := New(t.Logf, nil) r, err := New(t.Logf, nil)
r.SetMap(dnsMap) if err != nil {
if err := r.Start(); err != nil {
t.Fatalf("start: %v", err) t.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap)
tests := []struct { tests := []struct {
name string name string
ip netaddr.IP ip netaddr.IP
@ -318,18 +318,18 @@ func TestDelegate(t *testing.T) {
return return
} }
r := New(t.Logf, nil) r, err := New(t.Logf, nil)
if err != nil {
t.Fatalf("start: %v", err)
}
defer r.Close()
r.SetMap(dnsMap) r.SetMap(dnsMap)
r.SetUpstreams([]net.Addr{ r.SetUpstreams([]net.Addr{
v4server.PacketConn.LocalAddr(), v4server.PacketConn.LocalAddr(),
v6server.PacketConn.LocalAddr(), v6server.PacketConn.LocalAddr(),
}) })
if err := r.Start(); err != nil {
t.Fatalf("start: %v", err)
}
defer r.Close()
tests := []struct { tests := []struct {
title string title string
query []byte query []byte
@ -397,15 +397,15 @@ func TestDelegateCollision(t *testing.T) {
} }
defer server.Shutdown() defer server.Shutdown()
r := New(t.Logf, nil) r, err := New(t.Logf, nil)
r.SetMap(dnsMap) if err != nil {
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
if err := r.Start(); err != nil {
t.Fatalf("start: %v", err) t.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap)
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
packets := []struct { packets := []struct {
qname string qname string
qtype dns.Type qtype dns.Type
@ -463,9 +463,8 @@ func TestDelegateCollision(t *testing.T) {
} }
func TestConcurrentSetMap(t *testing.T) { func TestConcurrentSetMap(t *testing.T) {
r := New(t.Logf, nil) r, err := New(t.Logf, nil)
if err != nil {
if err := r.Start(); err != nil {
t.Fatalf("start: %v", err) t.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
@ -499,14 +498,14 @@ func TestConcurrentSetUpstreams(t *testing.T) {
} }
defer server.Shutdown() defer server.Shutdown()
r := New(t.Logf, nil) r, err := New(t.Logf, nil)
r.SetMap(dnsMap) if err != nil {
if err := r.Start(); err != nil {
t.Fatalf("start: %v", err) t.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap)
packet := dnspacket("test.site.", dns.TypeA) packet := dnspacket("test.site.", dns.TypeA)
// This is purely to ensure that delegation does not race with SetUpstreams. // This is purely to ensure that delegation does not race with SetUpstreams.
var wg sync.WaitGroup var wg sync.WaitGroup
@ -670,14 +669,14 @@ var emptyResponse = []byte{
} }
func TestFull(t *testing.T) { func TestFull(t *testing.T) {
r := New(t.Logf, nil) r, err := New(t.Logf, nil)
r.SetMap(dnsMap) if err != nil {
if err := r.Start(); err != nil {
t.Fatalf("start: %v", err) t.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap)
// One full packet and one error packet // One full packet and one error packet
tests := []struct { tests := []struct {
name string name string
@ -709,13 +708,12 @@ func TestFull(t *testing.T) {
} }
func TestAllocs(t *testing.T) { func TestAllocs(t *testing.T) {
r := New(t.Logf, nil) r, err := New(t.Logf, nil)
r.SetMap(dnsMap) if err != nil {
if err := r.Start(); err != nil {
t.Fatalf("start: %v", err) t.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap)
// It is seemingly pointless to test allocs in the delegate path, // It is seemingly pointless to test allocs in the delegate path,
// as dialer.Dial -> Read -> Write alone comprise 12 allocs. // as dialer.Dial -> Read -> Write alone comprise 12 allocs.
@ -778,15 +776,15 @@ func BenchmarkFull(b *testing.B) {
} }
defer server.Shutdown() defer server.Shutdown()
r := New(b.Logf, nil) r, err := New(b.Logf, nil)
r.SetMap(dnsMap) if err != nil {
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
if err := r.Start(); err != nil {
b.Fatalf("start: %v", err) b.Fatalf("start: %v", err)
} }
defer r.Close() defer r.Close()
r.SetMap(dnsMap)
r.SetUpstreams([]net.Addr{server.PacketConn.LocalAddr()})
tests := []struct { tests := []struct {
name string name string
request []byte request []byte

View File

@ -219,7 +219,11 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
e.linkMonOwned = true e.linkMonOwned = true
} }
e.resolver = resolver.New(logf, e.linkMon) var err error
e.resolver, err = resolver.New(logf, e.linkMon)
if err != nil {
return nil, err
}
logf("link state: %+v", e.linkMon.InterfaceState()) logf("link state: %+v", e.linkMon.InterfaceState())
@ -246,7 +250,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
NoteRecvActivity: e.noteReceiveActivity, NoteRecvActivity: e.noteReceiveActivity,
LinkMonitor: e.linkMon, LinkMonitor: e.linkMon,
} }
var err error
e.magicConn, err = magicsock.NewConn(magicsockOpts) e.magicConn, err = magicsock.NewConn(magicsockOpts)
if err != nil { if err != nil {
return nil, fmt.Errorf("wgengine: %v", err) return nil, fmt.Errorf("wgengine: %v", err)
@ -374,8 +378,6 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
e.logf("Starting magicsock...") e.logf("Starting magicsock...")
e.magicConn.Start() e.magicConn.Start()
e.logf("Starting resolver...")
e.resolver.Start()
go e.pollResolver() go e.pollResolver()
e.logf("Engine created.") e.logf("Engine created.")