Pull request: aghnet: imp host validation for system resolvers

Updates #3022.

Squashed commit of the following:

commit 2f63b4e1765d9c9bfeadafcfa42c9d8741b628e1
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 28 21:29:28 2021 +0300

    aghnet: fix doc

commit efdc1bb2c8959a9f888d558c32c415e6f3678b0c
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 28 21:19:54 2021 +0300

    all: doc changes

commit 8154797095874771bcf04d109644e6ae33fcb470
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Apr 28 21:15:42 2021 +0300

    aghnet: imp host validation for system resolvers
This commit is contained in:
Ainar Garipov 2021-04-28 21:34:18 +03:00
parent 5b8081169e
commit f5adf15c8c
4 changed files with 63 additions and 14 deletions

View File

@ -17,6 +17,14 @@ and this project adheres to
## [v0.106.1] - 2021-05-17 (APPROX.) ## [v0.106.1] - 2021-05-17 (APPROX.)
--> -->
### Fixed
- Validation of IPv6 addresses with zones in system resolvers ([#3022]).
[#3022]: https://github.com/AdguardTeam/AdGuardHome/issues/3022
## [v0.106.0] - 2021-04-28 ## [v0.106.0] - 2021-04-28
### Added ### Added

View File

@ -26,11 +26,15 @@ type SystemResolvers interface {
} }
const ( const (
// fakeDialErr is an error which dialFunc is expected to return. // errBadAddrPassed is returned when dialFunc can't parse an IP address.
fakeDialErr agherr.Error = "this error signals the successful dialFunc work" errBadAddrPassed agherr.Error = "the passed string is not a valid IP address"
// badAddrPassedErr is returned when dialFunc can't parse an IP address. // errFakeDial is an error which dialFunc is expected to return.
badAddrPassedErr agherr.Error = "the passed string is not a valid IP address" errFakeDial agherr.Error = "this error signals the successful dialFunc work"
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
// more than one percent sign.
errUnexpectedHostFormat agherr.Error = "unexpected host format"
) )
// refreshWithTicker refreshes the cache of sr after each tick form tickCh. // refreshWithTicker refreshes the cache of sr after each tick form tickCh.

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
@ -35,7 +36,7 @@ func (sr *systemResolvers) refresh() (err error) {
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc()) _, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
dnserr := &net.DNSError{} dnserr := &net.DNSError{}
if errors.As(err, &dnserr) && dnserr.Err == fakeDialErr.Error() { if errors.As(err, &dnserr) && dnserr.Err == errFakeDial.Error() {
return nil return nil
} }
@ -58,19 +59,43 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S
return s return s
} }
// validateDialedHost validated the host used by resolvers in dialFunc.
func validateDialedHost(host string) (err error) {
defer agherr.Annotate("parsing %q: %w", &err, host)
var ipStr string
parts := strings.Split(host, "%")
switch len(parts) {
case 1:
ipStr = host
case 2:
// Remove the zone and check the IP address part.
ipStr = parts[0]
default:
return errUnexpectedHostFormat
}
if net.ParseIP(ipStr) == nil {
return errBadAddrPassed
}
return nil
}
// dialFunc gets the resolver's address and puts it into internal cache. // dialFunc gets the resolver's address and puts it into internal cache.
func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) { func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net.Conn, err error) {
// Just validate the passed address is a valid IP. // Just validate the passed address is a valid IP.
var host string var host string
host, err = SplitHost(address) host, err = SplitHost(address)
if err != nil { if err != nil {
// TODO(e.burkov): Maybe use a structured badAddrPassedErr to // TODO(e.burkov): Maybe use a structured errBadAddrPassed to
// allow unwrapping of the real error. // allow unwrapping of the real error.
return nil, fmt.Errorf("%s: %w", err, badAddrPassedErr) return nil, fmt.Errorf("%s: %w", err, errBadAddrPassed)
} }
if net.ParseIP(host) == nil { err = validateDialedHost(host)
return nil, fmt.Errorf("parsing %q: %w", host, badAddrPassedErr) if err != nil {
return nil, fmt.Errorf("validating dialed host: %w", err)
} }
sr.addrsLock.Lock() sr.addrsLock.Lock()
@ -78,7 +103,7 @@ func (sr *systemResolvers) dialFunc(_ context.Context, _, address string) (_ net
sr.addrs.Add(host) sr.addrs.Add(host)
return nil, fakeDialErr return nil, errFakeDial
} }
func (sr *systemResolvers) Get() (rs []string) { func (sr *systemResolvers) Get() (rs []string) {

View File

@ -46,21 +46,33 @@ func TestSystemResolvers_DialFunc(t *testing.T) {
imp := createTestSystemResolversImp(t, 0, nil) imp := createTestSystemResolversImp(t, 0, nil)
testCases := []struct { testCases := []struct {
want error
name string name string
address string address string
want error
}{{ }{{
want: errFakeDial,
name: "valid", name: "valid",
address: "127.0.0.1", address: "127.0.0.1",
want: fakeDialErr,
}, { }, {
want: errFakeDial,
name: "valid_ipv6_port",
address: "[::1]:53",
}, {
want: errFakeDial,
name: "valid_ipv6_zone_port",
address: "[::1%lo0]:53",
}, {
want: errBadAddrPassed,
name: "invalid_split_host", name: "invalid_split_host",
address: "127.0.0.1::123", address: "127.0.0.1::123",
want: badAddrPassedErr,
}, { }, {
want: errUnexpectedHostFormat,
name: "invalid_ipv6_zone_port",
address: "[::1%%lo0]:53",
}, {
want: errBadAddrPassed,
name: "invalid_parse_ip", name: "invalid_parse_ip",
address: "not-ip", address: "not-ip",
want: badAddrPassedErr,
}} }}
for _, tc := range testCases { for _, tc := range testCases {