package dnsforward

import (
	"net/netip"
	"testing"

	"github.com/AdguardTeam/urlfilter/rules"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestIsBlockedClientID checks isBlockedClientID for the same client ID when
// the ID is passed to newAccessCtx as the first argument (expected: not
// blocked) and as the second argument (expected: blocked).
func TestIsBlockedClientID(t *testing.T) {
	clientID := "client-1"
	clients := []string{clientID}

	// ClientID is in the first list only.
	a, err := newAccessCtx(clients, nil, nil)
	require.NoError(t, err)

	assert.False(t, a.isBlockedClientID(clientID))

	// ClientID is in the second list only.
	a, err = newAccessCtx(nil, clients, nil)
	require.NoError(t, err)

	assert.True(t, a.isBlockedClientID(clientID))
}

// TestIsBlockedHost checks isBlockedHost against a set of host rules passed
// as the third argument of newAccessCtx: a plain hostname, a wildcard
// pattern, an adblock-style domain rule, and a rule restricted by DNS
// question type.
func TestIsBlockedHost(t *testing.T) {
	a, err := newAccessCtx(nil, nil, []string{
		"host1",
		"*.host.com",
		"||host3.com^",
		"||*^$dnstype=HTTPS",
	})
	require.NoError(t, err)

	testCases := []struct {
		want assert.BoolAssertionFunc
		name string
		host string
		qt   rules.RRType
	}{{
		want: assert.True,
		name: "plain_match",
		host: "host1",
		qt:   dns.TypeA,
	}, {
		want: assert.False,
		name: "plain_mismatch",
		host: "host2",
		qt:   dns.TypeA,
	}, {
		want: assert.True,
		name: "subdomain_match_short",
		host: "asdf.host.com",
		qt:   dns.TypeA,
	}, {
		want: assert.True,
		name: "subdomain_match_long",
		host: "qwer.asdf.host.com",
		qt:   dns.TypeA,
	}, {
		// The wildcard pattern is expected not to match the bare domain.
		want: assert.False,
		name: "subdomain_mismatch_no_lead",
		host: "host.com",
		qt:   dns.TypeA,
	}, {
		// The asterisk is expected not to match a partial label.
		want: assert.False,
		name: "subdomain_mismatch_bad_asterisk",
		host: "asdf.zhost.com",
		qt:   dns.TypeA,
	}, {
		want: assert.True,
		name: "rule_match_simple",
		host: "host3.com",
		qt:   dns.TypeA,
	}, {
		want: assert.True,
		name: "rule_match_complex",
		host: "asdf.host3.com",
		qt:   dns.TypeA,
	}, {
		want: assert.False,
		name: "rule_mismatch",
		host: ".host3.com",
		qt:   dns.TypeA,
	}, {
		// Only the HTTPS question type is expected to be blocked for a host
		// that matches none of the other rules.
		want: assert.True,
		name: "by_qtype",
		host: "site-with-https-record.example",
		qt:   dns.TypeHTTPS,
	}, {
		want: assert.False,
		name: "by_qtype_other",
		host: "site-with-https-record.example",
		qt:   dns.TypeA,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			tc.want(t, a.isBlockedHost(tc.host, tc.qt))
		})
	}
}

// TestIsBlockedIP checks isBlockedIP with a single IP address and a CIDR
// prefix.  The same table is run against two contexts: one that receives the
// list as the first argument of newAccessCtx ("allow") and one that receives
// it as the second argument ("block").
func TestIsBlockedIP(t *testing.T) {
	clients := []string{
		"1.2.3.4",
		"5.6.7.8/24",
	}

	allowCtx, err := newAccessCtx(clients, nil, nil)
	require.NoError(t, err)

	blockCtx, err := newAccessCtx(nil, clients, nil)
	require.NoError(t, err)

	// wantBlocked is written from the point of view of blockCtx; the "allow"
	// subtest inverts it.  wantRule is expected to be the same in both modes.
	testCases := []struct {
		ip          netip.Addr
		name        string
		wantRule    string
		wantBlocked bool
	}{{
		ip:          netip.MustParseAddr("1.2.3.4"),
		name:        "match_ip",
		wantRule:    "1.2.3.4",
		wantBlocked: true,
	}, {
		ip:          netip.MustParseAddr("5.6.7.100"),
		name:        "match_cidr",
		wantRule:    "5.6.7.8/24",
		wantBlocked: true,
	}, {
		ip:          netip.MustParseAddr("9.2.3.4"),
		name:        "no_match_ip",
		wantRule:    "",
		wantBlocked: false,
	}, {
		ip:          netip.MustParseAddr("9.6.7.100"),
		name:        "no_match_cidr",
		wantRule:    "",
		wantBlocked: false,
	}}

	t.Run("allow", func(t *testing.T) {
		for _, tc := range testCases {
			// Run every case as its own named subtest so that a failure
			// can be attributed to a specific table entry.
			t.Run(tc.name, func(t *testing.T) {
				blocked, rule := allowCtx.isBlockedIP(tc.ip)

				// An address that matches the list is expected to pass,
				// and any other address is expected to be blocked.
				assert.Equal(t, !tc.wantBlocked, blocked)
				assert.Equal(t, tc.wantRule, rule)
			})
		}
	})

	t.Run("block", func(t *testing.T) {
		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				blocked, rule := blockCtx.isBlockedIP(tc.ip)

				assert.Equal(t, tc.wantBlocked, blocked)
				assert.Equal(t, tc.wantRule, rule)
			})
		}
	})
}