diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index 9c0964e93..e341186ec 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -449,7 +449,7 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) return } -func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) { +func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) { netMon, err := netmon.New(tb.Logf) if err != nil { tb.Fatal(err) @@ -463,8 +463,9 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa modify(fwd) } - rr := resolverAndDelay{ - name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}, + resolvers := make([]resolverAndDelay, len(ports)) + for i, port := range ports { + resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)} } rpkt := packet{ @@ -476,7 +477,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa rchan := make(chan packet, 1) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) tb.Cleanup(cancel) - err = fwd.forwardWithDestChan(ctx, rpkt, rchan, rr) + err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...) select { case res := <-rchan: return res.bs, err @@ -485,8 +486,62 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa } } -func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte { - resp, err := runTestQuery(tb, port, request, modify) +// makeTestRequest returns a new TypeA request for the given domain. +func makeTestRequest(tb testing.TB, domain string) []byte { + tb.Helper() + name := dns.MustNewName(domain) + builder := dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + }) + request, err := builder.Finish() + if err != nil { + tb.Fatal(err) + } + return request +} + +// makeTestResponse returns a new Type A response for the given domain, +// with the specified status code and zero or more addresses. +func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte { + tb.Helper() + name := dns.MustNewName(domain) + builder := dns.NewBuilder(nil, dns.Header{ + Response: true, + Authoritative: true, + RCode: code, + }) + builder.StartQuestions() + q := dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + } + builder.Question(q) + if len(addrs) > 0 { + builder.StartAnswers() + for _, addr := range addrs { + builder.AResource(dns.ResourceHeader{ + Name: q.Name, + Class: q.Class, + TTL: 120, + }, dns.AResource{ + A: addr.As4(), + }) + } + } + response, err := builder.Finish() + if err != nil { + tb.Fatal(err) + } + return response +} + +func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte { + resp, err := runTestQuery(tb, request, modify, ports...) if err != nil { tb.Fatalf("error making request: %v", err) } @@ -515,7 +570,7 @@ func TestForwarderTCPFallback(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, nil) + resp := mustRunTestQuery(t, request, nil, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -553,7 +608,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, nil) + resp := mustRunTestQuery(t, request, nil, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -584,11 +639,11 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) { + resp := mustRunTestQuery(t, request, func(fwd *forwarder) { // Disable retries for this test. fwd.controlKnobs = &controlknobs.Knobs{} fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true) - }) + }, port) wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...) @@ -612,41 +667,10 @@ func TestForwarderTCPFallbackError(t *testing.T) { const domain = "error-response.tailscale.com." // Our response is a SERVFAIL - response := func() []byte { - name := dns.MustNewName(domain) - - builder := dns.NewBuilder(nil, dns.Header{ - Response: true, - RCode: dns.RCodeServerFailure, - }) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: name, - Type: dns.TypeA, - Class: dns.ClassINET, - }) - response, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return response - }() + response := makeTestResponse(t, domain, dns.RCodeServerFailure) // Our request is a single A query for the domain in the answer, above. - request := func() []byte { - builder := dns.NewBuilder(nil, dns.Header{}) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: dns.MustNewName(domain), - Type: dns.TypeA, - Class: dns.ClassINET, - }) - request, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return request - }() + request := makeTestRequest(t, domain) var sawRequest atomic.Bool port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { @@ -656,7 +680,7 @@ func TestForwarderTCPFallbackError(t *testing.T) { } }) - resp, err := runTestQuery(t, port, request, nil) + resp, err := runTestQuery(t, request, nil, port) if !sawRequest.Load() { t.Error("did not see DNS request") } @@ -673,6 +697,127 @@ func TestForwarderTCPFallbackError(t *testing.T) { } } +// Test to ensure that if we have more than one resolver, and at least one of them +// returns a successful response, we propagate it. +func TestForwarderWithManyResolvers(t *testing.T) { + enableDebug(t) + + const domain = "example.com." + request := makeTestRequest(t, domain) + + tests := []struct { + name string + responses [][]byte // upstream responses + wantResponses [][]byte // we should receive one of these from the forwarder + }{ + { + name: "Success", + responses: [][]byte{ // All upstream servers returned successful, but different, response. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")), + }, + wantResponses: [][]byte{ // We may forward whichever response is received first. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")), + }, + }, + { + name: "ServFail", + responses: [][]byte{ // All upstream servers returned a SERVFAIL. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + wantResponses: [][]byte{ + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + }, + { + name: "ServFail+Success", + responses: [][]byte{ // All upstream servers fail except for one. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + wantResponses: [][]byte{ // We should forward the successful response. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "NXDomain", + responses: [][]byte{ // All upstream servers returned NXDOMAIN. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + }, + wantResponses: [][]byte{ + makeTestResponse(t, domain, dns.RCodeNameError), + }, + }, + { + name: "NXDomain+Success", + responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "Refused", + responses: [][]byte{ // All upstream servers return different failures. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "MixFail", + responses: [][]byte{ // All upstream servers return different failures. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ports := make([]uint16, len(tt.responses)) + for i := range tt.responses { + ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {}) + } + gotResponse, err := runTestQuery(t, request, nil, ports...) + if err != nil { + t.Fatalf("wanted nil, got %v", err) + } + responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool { + return slices.Equal(gotResponse, wantResponse) + }) + if !responseOk { + t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0]) + } + }) + } +} + // mdnsResponder at minimum has an expectation that NXDOMAIN must include the // question, otherwise it will penalize our server (#13511). func TestNXDOMAINIncludesQuestion(t *testing.T) { @@ -718,7 +863,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) { port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { }) - res, err := runTestQuery(t, port, request, nil) + res, err := runTestQuery(t, request, nil, port) if err != nil { t.Fatal(err) }