diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index 430d2f35..341f2fad 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -249,6 +249,39 @@ func TestBlockedRequest(t *testing.T) { } } +func TestServerCustomClientUpstream(t *testing.T) { + s := createTestServer(t) + err := s.Start() + if err != nil { + t.Fatalf("Failed to start server: %s", err) + } + s.conf.GetCustomUpstreamByClient = func(clientAddr string) *proxy.UpstreamConfig { + uc := &proxy.UpstreamConfig{} + u := &testUpstream{} + u.ipv4 = map[string][]net.IP{} + u.ipv4["host."] = []net.IP{net.ParseIP("192.168.0.1")} + uc.Upstreams = append(uc.Upstreams, u) + return uc + } + addr := s.dnsProxy.Addr(proxy.ProtoUDP) + + // Send test request + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "host.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + reply, err := dns.Exchange(&req, addr.String()) + + assert.Nil(t, err) + assert.Equal(t, dns.RcodeSuccess, reply.Rcode) + assert.NotNil(t, reply.Answer) + assert.Equal(t, "192.168.0.1", reply.Answer[0].(*dns.A).A.String()) + assert.Nil(t, s.Stop()) +} + // testUpstream is a mock of real upstream. // specify fields with necessary values to simulate real upstream behaviour type testUpstream struct { diff --git a/dnsforward/helpers.go b/dnsforward/helpers.go deleted file mode 100644 index e7212355..00000000 --- a/dnsforward/helpers.go +++ /dev/null @@ -1,14 +0,0 @@ -package dnsforward - -import "net" - -// GetIPString is a helper function that extracts IP address from net.Addr -func GetIPString(addr net.Addr) string { - switch addr := addr.(type) { - case *net.UDPAddr: - return addr.IP.String() - case *net.TCPAddr: - return addr.IP.String() - } - return "" -} diff --git a/dnsforward/util.go b/dnsforward/util.go index f5c62cb8..f159f43a 100644 --- a/dnsforward/util.go +++ b/dnsforward/util.go @@ -8,6 +8,17 @@ import ( "github.com/AdguardTeam/golibs/utils" ) +// GetIPString is a helper function that extracts IP address from net.Addr +func GetIPString(addr net.Addr) string { + switch addr := addr.(type) { + case *net.UDPAddr: + return addr.IP.String() + case *net.TCPAddr: + return addr.IP.String() + } + return "" +} + func stringArrayDup(a []string) []string { a2 := make([]string, len(a)) copy(a2, a)