Pull request: all: add idna handling, imp domain validation
Updates #2915. Squashed commit of the following: commit b907324426c87ee7334edbd61e43c44444ad27a9 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Apr 7 16:26:41 2021 +0300 all: imp docs, upd commit c022f75cac006e077095cad283fea0a91d3a0eea Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Wed Apr 7 15:51:30 2021 +0300 all: add idna handling, imp domain validation
This commit is contained in:
parent
00a61fdea0
commit
c133b01ef7
4
go.mod
4
go.mod
|
@ -35,9 +35,9 @@ require (
|
|||
github.com/u-root/u-root v7.0.0+incompatible
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
|
||||
golang.org/x/sys v0.0.0-20210309074719-68d13333faf2
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44
|
||||
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d // indirect
|
||||
golang.org/x/text v0.3.5 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.0.0
|
||||
|
|
8
go.sum
8
go.sum
|
@ -516,8 +516,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v
|
|||
golang.org/x/net v0.0.0-20201216054612-986b41b23924/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 h1:qWPm9rbaAMKs8Bq/9LRpbMqxWRVUAQwMI9fVrssnTfw=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4 h1:4nGaVu0QrbjT/AK2PRLuQfQuh6DJve+pELhqTdAj3x0=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
|
@ -582,8 +582,8 @@ golang.org/x/sys v0.0.0-20210110051926-789bb1bd4061/go.mod h1:h1NjWce9XRLGQEsW7w
|
|||
golang.org/x/sys v0.0.0-20210123111255-9b0068b26619/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210216163648-f7da38b97c65/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210309074719-68d13333faf2 h1:46ULzRKLh1CwgRq2dC5SlBzEqqNCi8rreOZnNrbqcIY=
|
||||
golang.org/x/sys v0.0.0-20210309074719-68d13333faf2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44 h1:Bli41pIlzTzf3KEY06n+xnzK/BESIg2ze4Pgfh/aI8c=
|
||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d h1:SZxvLBoTP5yHO3Frd4z4vrF+DBX9vMVanchswa69toE=
|
||||
|
|
|
@ -3,8 +3,10 @@ package aghnet
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
// ValidateHardwareAddress returns an error if hwa is not a valid EUI-48,
|
||||
|
@ -21,3 +23,79 @@ func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) {
|
|||
return fmt.Errorf("bad len: %d", l)
|
||||
}
|
||||
}
|
||||
|
||||
// maxDomainLabelLen is the maximum allowed length of a domain name label
|
||||
// according to RFC 1035.
|
||||
const maxDomainLabelLen = 63
|
||||
|
||||
// maxDomainNameLen is the maximum allowed length of a full domain name
|
||||
// according to RFC 1035.
|
||||
//
|
||||
// See https://stackoverflow.com/a/32294443/1892060.
|
||||
const maxDomainNameLen = 253
|
||||
|
||||
const invalidCharMsg = "invalid char %q at index %d in %q"
|
||||
|
||||
// isValidHostFirstRune returns true if r is a valid first rune for a hostname
|
||||
// label.
|
||||
func isValidHostFirstRune(r rune) (ok bool) {
|
||||
return (r >= 'a' && r <= 'z') ||
|
||||
(r >= 'A' && r <= 'Z') ||
|
||||
(r >= '0' && r <= '9')
|
||||
}
|
||||
|
||||
// isValidHostRune returns true if r is a valid rune for a hostname label.
|
||||
func isValidHostRune(r rune) (ok bool) {
|
||||
return r == '-' || isValidHostFirstRune(r)
|
||||
}
|
||||
|
||||
// ValidateDomainNameLabel returns an error if label is not a valid label of
|
||||
// a domain name.
|
||||
func ValidateDomainNameLabel(label string) (err error) {
|
||||
if len(label) > maxDomainLabelLen {
|
||||
return fmt.Errorf("%q is too long, max: %d", label, maxDomainLabelLen)
|
||||
} else if len(label) == 0 {
|
||||
return agherr.Error("label is empty")
|
||||
}
|
||||
|
||||
if r := label[0]; !isValidHostFirstRune(rune(r)) {
|
||||
return fmt.Errorf(invalidCharMsg, r, 0, label)
|
||||
}
|
||||
|
||||
for i, r := range label[1:] {
|
||||
if !isValidHostRune(r) {
|
||||
return fmt.Errorf(invalidCharMsg, r, i+1, label)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateDomainName validates the domain name in accordance to RFC 952, RFC
|
||||
// 1035, and with RFC-1123's inclusion of digits at the start of the host. It
|
||||
// doesn't validate against two or more hyphens to allow punycode and
|
||||
// internationalized domains.
|
||||
//
|
||||
// TODO(a.garipov): After making sure that this works correctly, port this into
|
||||
// module golibs.
|
||||
func ValidateDomainName(name string) (err error) {
|
||||
name, err = idna.ToASCII(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
l := len(name)
|
||||
if l == 0 || l > maxDomainNameLen {
|
||||
return fmt.Errorf("%q is too long, max: %d", name, maxDomainNameLen)
|
||||
}
|
||||
|
||||
labels := strings.Split(name, ".")
|
||||
for i, l := range labels {
|
||||
err = ValidateDomainNameLabel(l)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid domain name label at index %d: %w", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package aghnet
|
|||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -50,6 +51,81 @@ func TestValidateHardwareAddress(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func repeatStr(b *strings.Builder, s string, n int) {
|
||||
for i := 0; i < n; i++ {
|
||||
_, _ = b.WriteString(s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDomainName(t *testing.T) {
|
||||
b := &strings.Builder{}
|
||||
repeatStr(b, "a", 255)
|
||||
longDomainName := b.String()
|
||||
|
||||
b.Reset()
|
||||
repeatStr(b, "a", 64)
|
||||
longLabel := b.String()
|
||||
|
||||
_, _ = b.WriteString(".com")
|
||||
longLabelDomainName := b.String()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
in string
|
||||
wantErrMsg string
|
||||
}{{
|
||||
name: "success",
|
||||
in: "example.com",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "success_idna",
|
||||
in: "пример.рф",
|
||||
wantErrMsg: "",
|
||||
}, {
|
||||
name: "bad_symbol",
|
||||
in: "!!!",
|
||||
wantErrMsg: `invalid domain name label at index 0: ` +
|
||||
`invalid char '!' at index 0 in "!!!"`,
|
||||
}, {
|
||||
name: "bad_length",
|
||||
in: longDomainName,
|
||||
wantErrMsg: `"` + longDomainName + `" is too long, max: 253`,
|
||||
}, {
|
||||
name: "bad_label_length",
|
||||
in: longLabelDomainName,
|
||||
wantErrMsg: `invalid domain name label at index 0: "` + longLabel +
|
||||
`" is too long, max: 63`,
|
||||
}, {
|
||||
name: "bad_label_empty",
|
||||
in: "example..com",
|
||||
wantErrMsg: `invalid domain name label at index 1: label is empty`,
|
||||
}, {
|
||||
name: "bad_label_first_symbol",
|
||||
in: "example.-aa.com",
|
||||
wantErrMsg: `invalid domain name label at index 1:` +
|
||||
` invalid char '-' at index 0 in "-aa"`,
|
||||
}, {
|
||||
name: "bad_label_symbol",
|
||||
in: "example.a!!!.com",
|
||||
wantErrMsg: `invalid domain name label at index 1:` +
|
||||
` invalid char '!' at index 1 in "a!!!"`,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateDomainName(tc.in)
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
})
|
||||
|
|
|
@ -6,33 +6,14 @@ import (
|
|||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
// maxDomainLabelLen is the maximum allowed length of a domain name label
|
||||
// according to RFC 1035.
|
||||
const maxDomainLabelLen = 63
|
||||
|
||||
// validateDomainNameLabel returns an error if label is not a valid label of
|
||||
// a domain name.
|
||||
func validateDomainNameLabel(label string) (err error) {
|
||||
if len(label) > maxDomainLabelLen {
|
||||
return fmt.Errorf("%q is too long, max: %d", label, maxDomainLabelLen)
|
||||
}
|
||||
|
||||
for i, r := range label {
|
||||
if (r < 'a' || r > 'z') && (r < '0' || r > '9') && r != '-' {
|
||||
return fmt.Errorf("invalid char %q at index %d in %q", r, i, label)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateClientID returns an error if clientID is not a valid client ID.
|
||||
func ValidateClientID(clientID string) (err error) {
|
||||
err = validateDomainNameLabel(clientID)
|
||||
err = aghnet.ValidateDomainNameLabel(clientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid client id: %w", err)
|
||||
}
|
||||
|
|
|
@ -238,8 +238,8 @@ func TestProcessClientID_https(t *testing.T) {
|
|||
name: "invalid_client_id",
|
||||
path: "/dns-query/!!!",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id: invalid char '!'` +
|
||||
` at index 0 in "!!!"`,
|
||||
wantErrMsg: `client id check: invalid client id: invalid char '!' ` +
|
||||
`at index 0 in "!!!"`,
|
||||
wantRes: resultCodeError,
|
||||
}}
|
||||
|
||||
|
|
|
@ -114,7 +114,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
|||
if p.AutohostTLD == "" {
|
||||
autohostSuffix = defaultAutohostSuffix
|
||||
} else {
|
||||
err = validateDomainNameLabel(p.AutohostTLD)
|
||||
err = aghnet.ValidateDomainNameLabel(p.AutohostTLD)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("autohost tld: %w", err)
|
||||
}
|
||||
|
|
|
@ -947,145 +947,6 @@ func publicKey(priv interface{}) interface{} {
|
|||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstream(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
upstream string
|
||||
valid bool
|
||||
wantDef bool
|
||||
}{{
|
||||
name: "invalid",
|
||||
upstream: "1.2.3.4.5",
|
||||
valid: false,
|
||||
}, {
|
||||
name: "invalid",
|
||||
upstream: "123.3.7m",
|
||||
valid: false,
|
||||
}, {
|
||||
name: "invalid",
|
||||
upstream: "htttps://google.com/dns-query",
|
||||
valid: false,
|
||||
}, {
|
||||
name: "invalid",
|
||||
upstream: "[/host.com]tls://dns.adguard.com",
|
||||
valid: false,
|
||||
}, {
|
||||
name: "invalid",
|
||||
upstream: "[host.ru]#",
|
||||
valid: false,
|
||||
}, {
|
||||
name: "valid_default",
|
||||
upstream: "1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
}, {
|
||||
name: "valid_default",
|
||||
upstream: "tls://1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
}, {
|
||||
name: "valid_default",
|
||||
upstream: "https://dns.adguard.com/dns-query",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
}, {
|
||||
name: "valid_default",
|
||||
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
}, {
|
||||
name: "valid",
|
||||
upstream: "[/host.com/]1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "valid",
|
||||
upstream: "[//]tls://1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "valid",
|
||||
upstream: "[/www.host.com/]#",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "valid",
|
||||
upstream: "[/host.com/google.com/]8.8.8.8",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "valid",
|
||||
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
defaultUpstream, err := validateUpstream(tc.upstream)
|
||||
require.Equal(t, tc.valid, err == nil)
|
||||
if tc.valid {
|
||||
assert.Equal(t, tc.wantDef, defaultUpstream)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsSet(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg string
|
||||
set []string
|
||||
wantNil bool
|
||||
}{{
|
||||
name: "empty",
|
||||
msg: "empty upstreams array should be valid",
|
||||
set: nil,
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "comment",
|
||||
msg: "comments should not be validated",
|
||||
set: []string{"# comment"},
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "valid_no_default",
|
||||
msg: "there is no default upstream",
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
},
|
||||
wantNil: false,
|
||||
}, {
|
||||
name: "valid_with_default",
|
||||
msg: "upstreams set is valid, but doesn't pass through validation cause: %s",
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"8.8.8.8",
|
||||
},
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "invalid",
|
||||
msg: "there is an invalid upstream in set, but it pass through validation",
|
||||
set: []string{"dhcp://fake.dns"},
|
||||
wantNil: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateUpstreams(tc.set)
|
||||
|
||||
assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPStringFromAddr(t *testing.T) {
|
||||
t.Run("not_nil", func(t *testing.T) {
|
||||
addr := net.UDPAddr{
|
||||
|
|
|
@ -8,10 +8,11 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
|
@ -302,7 +303,7 @@ type upstreamJSON struct {
|
|||
}
|
||||
|
||||
// ValidateUpstreams validates each upstream and returns an error if any upstream is invalid or if there are no default upstreams specified
|
||||
func ValidateUpstreams(upstreams []string) error {
|
||||
func ValidateUpstreams(upstreams []string) (err error) {
|
||||
// No need to validate comments
|
||||
upstreams = filterOutComments(upstreams)
|
||||
|
||||
|
@ -311,7 +312,7 @@ func ValidateUpstreams(upstreams []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
_, err := proxy.ParseUpstreamsConfig(
|
||||
_, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
Bootstrap: []string{},
|
||||
|
@ -345,56 +346,61 @@ func ValidateUpstreams(upstreams []string) error {
|
|||
var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"}
|
||||
|
||||
func validateUpstream(u string) (bool, error) {
|
||||
// Check if user tries to specify upstream for domain
|
||||
u, defaultUpstream, err := separateUpstream(u)
|
||||
// Check if the user tries to specify upstream for domain.
|
||||
u, useDefault, err := separateUpstream(u)
|
||||
if err != nil {
|
||||
return defaultUpstream, err
|
||||
return useDefault, err
|
||||
}
|
||||
|
||||
// The special server address '#' means "use the default servers"
|
||||
if u == "#" && !defaultUpstream {
|
||||
return defaultUpstream, nil
|
||||
if u == "#" && !useDefault {
|
||||
return useDefault, nil
|
||||
}
|
||||
|
||||
// Check if the upstream has a valid protocol prefix
|
||||
for _, proto := range protocols {
|
||||
if strings.HasPrefix(u, proto) {
|
||||
return defaultUpstream, nil
|
||||
return useDefault, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Return error if the upstream contains '://' without any valid protocol
|
||||
if strings.Contains(u, "://") {
|
||||
return defaultUpstream, fmt.Errorf("wrong protocol")
|
||||
return useDefault, fmt.Errorf("wrong protocol")
|
||||
}
|
||||
|
||||
// Check if upstream is valid plain DNS
|
||||
return defaultUpstream, checkPlainDNS(u)
|
||||
return useDefault, checkPlainDNS(u)
|
||||
}
|
||||
|
||||
// separateUpstream returns upstream without specified domains and a bool flag that indicates if no domains were specified
|
||||
// error will be returned if upstream per domain specification is invalid
|
||||
func separateUpstream(upstream string) (string, bool, error) {
|
||||
defaultUpstream := true
|
||||
if strings.HasPrefix(upstream, "[/") {
|
||||
defaultUpstream = false
|
||||
// split domains and upstream string
|
||||
domainsAndUpstream := strings.Split(strings.TrimPrefix(upstream, "[/"), "/]")
|
||||
if len(domainsAndUpstream) != 2 {
|
||||
return "", defaultUpstream, fmt.Errorf("wrong dns upstream per domain specification: %s", upstream)
|
||||
// separateUpstream returns the upstream without the specified domains.
|
||||
// useDefault is true when a default upstream must be used.
|
||||
func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) {
|
||||
defer agherr.Annotate("bad upstream for domain spec %q: %w", &err, upstreamStr)
|
||||
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return upstreamStr, true, nil
|
||||
}
|
||||
|
||||
parts := strings.Split(upstreamStr[2:], "/]")
|
||||
if len(parts) != 2 {
|
||||
return "", false, agherr.Error("duplicated separator")
|
||||
}
|
||||
|
||||
domains := parts[0]
|
||||
upstream = parts[1]
|
||||
for i, host := range strings.Split(domains, "/") {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// split domains list and validate each one
|
||||
for _, host := range strings.Split(domainsAndUpstream[0], "/") {
|
||||
if host != "" {
|
||||
if err := utils.IsValidHostname(host); err != nil {
|
||||
return "", defaultUpstream, err
|
||||
}
|
||||
}
|
||||
err = aghnet.ValidateDomainName(host)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("domain at index %d: %w", i, err)
|
||||
}
|
||||
upstream = domainsAndUpstream[1]
|
||||
}
|
||||
return upstream, defaultUpstream, nil
|
||||
|
||||
return upstream, false, nil
|
||||
}
|
||||
|
||||
// checkPlainDNS checks if host is plain DNS
|
||||
|
@ -462,13 +468,13 @@ func checkDNS(input string, bootstrap []string) error {
|
|||
}
|
||||
|
||||
// separate upstream from domains list
|
||||
input, defaultUpstream, err := separateUpstream(input)
|
||||
input, useDefault, err := separateUpstream(input)
|
||||
if err != nil {
|
||||
return fmt.Errorf("wrong upstream format: %w", err)
|
||||
}
|
||||
|
||||
// No need to check this DNS server
|
||||
if !defaultUpstream {
|
||||
if !useDefault {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -213,3 +213,158 @@ func TestDNSForwardHTTTP_handleSetConfig(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(a.garipov): Rewrite to check the actual error messages.
|
||||
func TestValidateUpstream(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
upstream string
|
||||
valid bool
|
||||
wantDef bool
|
||||
}{{
|
||||
name: "invalid",
|
||||
upstream: "1.2.3.4.5",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "invalid",
|
||||
upstream: "123.3.7m",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "invalid",
|
||||
upstream: "htttps://google.com/dns-query",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "invalid",
|
||||
upstream: "[/host.com]tls://dns.adguard.com",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "invalid",
|
||||
upstream: "[host.ru]#",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "valid_default",
|
||||
upstream: "1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
}, {
|
||||
name: "valid_default",
|
||||
upstream: "tls://1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
}, {
|
||||
name: "valid_default",
|
||||
upstream: "https://dns.adguard.com/dns-query",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
}, {
|
||||
name: "valid_default",
|
||||
upstream: "sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
valid: true,
|
||||
wantDef: true,
|
||||
}, {
|
||||
name: "valid",
|
||||
upstream: "[/host.com/]1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "valid",
|
||||
upstream: "[//]tls://1.1.1.1",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "valid",
|
||||
upstream: "[/www.host.com/]#",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "valid",
|
||||
upstream: "[/host.com/google.com/]8.8.8.8",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "valid",
|
||||
upstream: "[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "idna",
|
||||
upstream: "[/пример.рф/]8.8.8.8",
|
||||
valid: true,
|
||||
wantDef: false,
|
||||
}, {
|
||||
name: "bad_domain",
|
||||
upstream: "[/!/]8.8.8.8",
|
||||
valid: false,
|
||||
wantDef: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
defaultUpstream, err := validateUpstream(tc.upstream)
|
||||
require.Equal(t, tc.valid, err == nil)
|
||||
if tc.valid {
|
||||
assert.Equal(t, tc.wantDef, defaultUpstream)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateUpstreamsSet(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg string
|
||||
set []string
|
||||
wantNil bool
|
||||
}{{
|
||||
name: "empty",
|
||||
msg: "empty upstreams array should be valid",
|
||||
set: nil,
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "comment",
|
||||
msg: "comments should not be validated",
|
||||
set: []string{"# comment"},
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "valid_no_default",
|
||||
msg: "there is no default upstream",
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
},
|
||||
wantNil: false,
|
||||
}, {
|
||||
name: "valid_with_default",
|
||||
msg: "upstreams set is valid, but doesn't pass through validation cause: %s",
|
||||
set: []string{
|
||||
"[/host.com/]1.1.1.1",
|
||||
"[//]tls://1.1.1.1",
|
||||
"[/www.host.com/]#",
|
||||
"[/host.com/google.com/]8.8.8.8",
|
||||
"[/host/]sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_JS3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczEuYWRndWFyZC5jb20",
|
||||
"8.8.8.8",
|
||||
},
|
||||
wantNil: true,
|
||||
}, {
|
||||
name: "invalid",
|
||||
msg: "there is an invalid upstream in set, but it pass through validation",
|
||||
set: []string{"dhcp://fake.dns"},
|
||||
wantNil: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := ValidateUpstreams(tc.set)
|
||||
|
||||
assert.Equalf(t, tc.wantNil, err == nil, tc.msg, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
)
|
||||
|
||||
// IPFromAddr gets IP address from addr.
|
||||
|
@ -58,9 +58,10 @@ func matchDomainWildcard(host, wildcard string) bool {
|
|||
|
||||
// Return TRUE if client's SNI value matches DNS names from certificate
|
||||
func matchDNSName(dnsNames []string, sni string) bool {
|
||||
if utils.IsValidHostname(sni) != nil {
|
||||
if aghnet.ValidateDomainName(sni) != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if findSorted(dnsNames, sni) != -1 {
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
|
@ -20,7 +21,6 @@ import (
|
|||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/utils"
|
||||
)
|
||||
|
||||
const clientsUpdatePeriod = 10 * time.Minute
|
||||
|
@ -751,7 +751,7 @@ func (clients *clientsContainer) addFromSystemARP() {
|
|||
|
||||
host := ln[:open]
|
||||
ip := ln[open+2 : close]
|
||||
if utils.IsValidHostname(host) != nil || net.ParseIP(ip) == nil {
|
||||
if aghnet.ValidateDomainName(host) != nil || net.ParseIP(ip) == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
|
@ -123,18 +123,20 @@ func handleMobileConfig(w http.ResponseWriter, r *http.Request, dnsp string) {
|
|||
}
|
||||
|
||||
clientID := q.Get("client_id")
|
||||
err = dnsforward.ValidateClientID(clientID)
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
err = json.NewEncoder(w).Encode(&jsonError{
|
||||
Message: err.Error(),
|
||||
})
|
||||
if clientID != "" {
|
||||
err = dnsforward.ValidateClientID(clientID)
|
||||
if err != nil {
|
||||
log.Debug("writing 400 json response: %s", err)
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
return
|
||||
err = json.NewEncoder(w).Encode(&jsonError{
|
||||
Message: err.Error(),
|
||||
})
|
||||
if err != nil {
|
||||
log.Debug("writing 400 json response: %s", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
d := dnsSettings{
|
||||
|
|
Loading…
Reference in New Issue