diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 8879bb1a..d15538a8 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -10,6 +10,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" + "github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" @@ -597,13 +598,11 @@ func checkDNS( log.Debug("dnsforward: checking if upstream %q works", upstreamAddr) - u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{ - Bootstrap: bootstrap, - Timeout: timeout, - }) + us, err := dnssvc.AddressesToUpstreams([]string{upstreamAddr}, bootstrap, timeout) if err != nil { return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err) } + u := us[0] defer func() { err = errors.WithDeferred(err, u.Close()) }() if err = healthCheck(u); err != nil { diff --git a/internal/next/dnssvc/dnssvc.go b/internal/next/dnssvc/dnssvc.go index f25fa294..772f0891 100644 --- a/internal/next/dnssvc/dnssvc.go +++ b/internal/next/dnssvc/dnssvc.go @@ -9,6 +9,8 @@ import ( "fmt" "net" "net/netip" + "net/url" + "strings" "sync/atomic" "time" @@ -80,7 +82,7 @@ func New(c *Config) (svc *Service, err error) { if len(c.Upstreams) > 0 { upstreams = c.Upstreams } else { - upstreams, err = addressesToUpstreams( + upstreams, err = AddressesToUpstreams( c.UpstreamServers, c.BootstrapServers, c.UpstreamTimeout, @@ -108,20 +110,47 @@ func New(c *Config) (svc *Service, err error) { return svc, nil } -// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It +// AddressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It // accepts a slice of addresses and other upstream parameters, and returns a // slice of upstreams. -func addressesToUpstreams( +func AddressesToUpstreams( upsStrs []string, bootstraps []string, timeout time.Duration, ) (upstreams []upstream.Upstream, err error) { upstreams = make([]upstream.Upstream, len(upsStrs)) for i, upsStr := range upsStrs { - upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{ + // Here we pre-parse the url to find any magic params. + opts := upstream.Options{ Bootstrap: bootstraps, Timeout: timeout, - }) + } + if strings.Contains(upsStr, "://") { + var uu *url.URL + uu, err = url.Parse(upsStr) + if err != nil { + return nil, fmt.Errorf("failed to parse %s: %w", upsStr, err) + } + queries := uu.Query() + if serverIPstrs, exists := queries["adguardhome_upstream_ip"]; exists { + serverIPs := make([]net.IP, len(serverIPstrs)) + for i := range serverIPstrs { + if t := net.ParseIP(serverIPstrs[i]); t != nil { + serverIPs[i] = t + } else { + return nil, fmt.Errorf("failed to parse upstream_ip %s", serverIPstrs[i]) + } + } + opts.ServerIPAddrs = serverIPs + + // Remove the magic param to avoid interference with the real server + queries.Del("adguardhome_upstream_ip") + uu.RawQuery = queries.Encode() + upsStr = uu.String() + } + } + + upstreams[i], err = upstream.AddressToUpstream(upsStr, &opts) if err != nil { return nil, fmt.Errorf("upstream at index %d: %w", i, err) }