238 lines
6.2 KiB
Go
238 lines
6.2 KiB
Go
package filtering_test
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
|
"github.com/AdguardTeam/golibs/testutil"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type (
	// rewriteJSON is the JSON shape of a single rewrite entry exchanged with
	// the rewrite HTTP handlers in these tests.  The field order matters: it
	// determines the key order produced by json.Marshal.
	//
	// TODO(d.kolyshev): Use [rewrite.Item] instead.
	rewriteJSON struct {
		// Domain is the entry's domain, serialized as "domain".
		Domain string `json:"domain"`

		// Answer is the entry's answer, serialized as "answer".
		Answer string `json:"answer"`
	}

	// rewriteUpdateJSON is the request body sent to the update handler: Target
	// identifies the existing entry and Update is the entry that replaces it.
	rewriteUpdateJSON struct {
		Target rewriteJSON `json:"target"`
		Update rewriteJSON `json:"update"`
	}
)
|
|
|
|
// testTimeout is the common timeout for tests.
const testTimeout = 100 * time.Millisecond

// Paths under which the filtering rewrite HTTP handlers are registered.
const (
	listURL   = "/control/rewrite/list"
	addURL    = "/control/rewrite/add"
	deleteURL = "/control/rewrite/delete"
	updateURL = "/control/rewrite/update"
)

// decodeErrorMsg is the response body the tests expect when the add or delete
// handler receives a JSON string instead of a rewrite entry object.
const decodeErrorMsg = "json.Decode: json: cannot unmarshal string into Go " +
	"value of type filtering.rewriteEntryJSON\n"
|
|
|
|
func TestDNSFilter_handleRewriteHTTP(t *testing.T) {
|
|
confModCh := make(chan struct{})
|
|
reqCh := make(chan struct{})
|
|
testRewrites := []*rewriteJSON{
|
|
{Domain: "example.local", Answer: "example.rewrite"},
|
|
{Domain: "one.local", Answer: "one.rewrite"},
|
|
}
|
|
|
|
testRewritesJSON, mErr := json.Marshal(testRewrites)
|
|
require.NoError(t, mErr)
|
|
|
|
testCases := []struct {
|
|
reqData any
|
|
name string
|
|
url string
|
|
method string
|
|
wantList []*rewriteJSON
|
|
wantBody string
|
|
wantConfMod bool
|
|
wantStatus int
|
|
}{{
|
|
name: "list",
|
|
url: listURL,
|
|
method: http.MethodGet,
|
|
reqData: nil,
|
|
wantConfMod: false,
|
|
wantStatus: http.StatusOK,
|
|
wantBody: string(testRewritesJSON) + "\n",
|
|
wantList: testRewrites,
|
|
}, {
|
|
name: "add",
|
|
url: addURL,
|
|
method: http.MethodPost,
|
|
reqData: rewriteJSON{Domain: "add.local", Answer: "add.rewrite"},
|
|
wantConfMod: true,
|
|
wantStatus: http.StatusOK,
|
|
wantBody: "",
|
|
wantList: append(
|
|
testRewrites,
|
|
&rewriteJSON{Domain: "add.local", Answer: "add.rewrite"},
|
|
),
|
|
}, {
|
|
name: "add_error",
|
|
url: addURL,
|
|
method: http.MethodPost,
|
|
reqData: "invalid_json",
|
|
wantConfMod: false,
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: decodeErrorMsg,
|
|
wantList: testRewrites,
|
|
}, {
|
|
name: "delete",
|
|
url: deleteURL,
|
|
method: http.MethodPost,
|
|
reqData: rewriteJSON{Domain: "one.local", Answer: "one.rewrite"},
|
|
wantConfMod: true,
|
|
wantStatus: http.StatusOK,
|
|
wantBody: "",
|
|
wantList: []*rewriteJSON{{Domain: "example.local", Answer: "example.rewrite"}},
|
|
}, {
|
|
name: "delete_error",
|
|
url: deleteURL,
|
|
method: http.MethodPost,
|
|
reqData: "invalid_json",
|
|
wantConfMod: false,
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: decodeErrorMsg,
|
|
wantList: testRewrites,
|
|
}, {
|
|
name: "update",
|
|
url: updateURL,
|
|
method: http.MethodPut,
|
|
reqData: rewriteUpdateJSON{
|
|
Target: rewriteJSON{Domain: "one.local", Answer: "one.rewrite"},
|
|
Update: rewriteJSON{Domain: "upd.local", Answer: "upd.rewrite"},
|
|
},
|
|
wantConfMod: true,
|
|
wantStatus: http.StatusOK,
|
|
wantBody: "",
|
|
wantList: []*rewriteJSON{
|
|
{Domain: "example.local", Answer: "example.rewrite"},
|
|
{Domain: "upd.local", Answer: "upd.rewrite"},
|
|
},
|
|
}, {
|
|
name: "update_error",
|
|
url: updateURL,
|
|
method: http.MethodPut,
|
|
reqData: "invalid_json",
|
|
wantConfMod: false,
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: "json.Decode: json: cannot unmarshal string into Go value of type" +
|
|
" filtering.rewriteUpdateJSON\n",
|
|
wantList: testRewrites,
|
|
}, {
|
|
name: "update_error_target",
|
|
url: updateURL,
|
|
method: http.MethodPut,
|
|
reqData: rewriteUpdateJSON{
|
|
Target: rewriteJSON{Domain: "inv.local", Answer: "inv.rewrite"},
|
|
Update: rewriteJSON{Domain: "upd.local", Answer: "upd.rewrite"},
|
|
},
|
|
wantConfMod: false,
|
|
wantStatus: http.StatusBadRequest,
|
|
wantBody: "target rule not found\n",
|
|
wantList: testRewrites,
|
|
}}
|
|
|
|
for _, tc := range testCases {
|
|
onConfModified := func() {
|
|
if !tc.wantConfMod {
|
|
panic("config modified has been fired")
|
|
}
|
|
|
|
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
|
|
}
|
|
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
handlers := make(map[string]http.Handler)
|
|
|
|
d, err := filtering.New(&filtering.Config{
|
|
ConfigModified: onConfModified,
|
|
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
|
handlers[url] = handler
|
|
},
|
|
Rewrites: rewriteEntriesToLegacyRewrites(testRewrites),
|
|
}, nil)
|
|
require.NoError(t, err)
|
|
t.Cleanup(d.Close)
|
|
|
|
d.RegisterFilteringHandlers()
|
|
require.NotEmpty(t, handlers)
|
|
require.Contains(t, handlers, listURL)
|
|
require.Contains(t, handlers, tc.url)
|
|
|
|
var body io.Reader
|
|
if tc.reqData != nil {
|
|
data, rErr := json.Marshal(tc.reqData)
|
|
require.NoError(t, rErr)
|
|
|
|
body = bytes.NewReader(data)
|
|
}
|
|
|
|
r := httptest.NewRequest(tc.method, tc.url, body)
|
|
w := httptest.NewRecorder()
|
|
|
|
go func() {
|
|
handlers[tc.url].ServeHTTP(w, r)
|
|
|
|
testutil.RequireSend(testutil.PanicT{}, reqCh, struct{}{}, testTimeout)
|
|
}()
|
|
|
|
if tc.wantConfMod {
|
|
testutil.RequireReceive(t, confModCh, testTimeout)
|
|
}
|
|
|
|
testutil.RequireReceive(t, reqCh, testTimeout)
|
|
assert.Equal(t, tc.wantStatus, w.Code)
|
|
|
|
respBody, err := io.ReadAll(w.Body)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, []byte(tc.wantBody), respBody)
|
|
|
|
assertRewritesList(t, handlers[listURL], tc.wantList)
|
|
})
|
|
}
|
|
}
|
|
|
|
// assertRewritesList checks if rewrites list equals the list received from the
|
|
// handler by listURL.
|
|
func assertRewritesList(t *testing.T, handler http.Handler, wantList []*rewriteJSON) {
|
|
t.Helper()
|
|
|
|
r := httptest.NewRequest(http.MethodGet, listURL, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(w, r)
|
|
require.Equal(t, http.StatusOK, w.Code)
|
|
|
|
var actual []*rewriteJSON
|
|
err := json.NewDecoder(w.Body).Decode(&actual)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, wantList, actual)
|
|
}
|
|
|
|
// rewriteEntriesToLegacyRewrites gets legacy rewrites from json entries.
|
|
func rewriteEntriesToLegacyRewrites(entries []*rewriteJSON) (rw []*filtering.LegacyRewrite) {
|
|
for _, entry := range entries {
|
|
rw = append(rw, &filtering.LegacyRewrite{
|
|
Domain: entry.Domain,
|
|
Answer: entry.Answer,
|
|
})
|
|
}
|
|
|
|
return rw
|
|
}
|