add magic param for specifying server ip

This commit is contained in:
wesley800 2022-11-30 22:55:28 +08:00
parent e6f8aeeebe
commit c1c0c2972a
2 changed files with 37 additions and 9 deletions

View File

@ -10,6 +10,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
@ -597,13 +598,11 @@ func checkDNS(
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
Bootstrap: bootstrap,
Timeout: timeout,
})
us, err := dnssvc.AddressesToUpstreams([]string{upstreamAddr}, bootstrap, timeout)
if err != nil {
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
}
u := us[0]
defer func() { err = errors.WithDeferred(err, u.Close()) }()
if err = healthCheck(u); err != nil {

View File

@ -9,6 +9,8 @@ import (
"fmt"
"net"
"net/netip"
"net/url"
"strings"
"sync/atomic"
"time"
@ -80,7 +82,7 @@ func New(c *Config) (svc *Service, err error) {
if len(c.Upstreams) > 0 {
upstreams = c.Upstreams
} else {
upstreams, err = addressesToUpstreams(
upstreams, err = AddressesToUpstreams(
c.UpstreamServers,
c.BootstrapServers,
c.UpstreamTimeout,
@ -108,20 +110,47 @@ func New(c *Config) (svc *Service, err error) {
return svc, nil
}
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
// AddressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
// accepts a slice of addresses and other upstream parameters, and returns a
// slice of upstreams.
func addressesToUpstreams(
func AddressesToUpstreams(
upsStrs []string,
bootstraps []string,
timeout time.Duration,
) (upstreams []upstream.Upstream, err error) {
upstreams = make([]upstream.Upstream, len(upsStrs))
for i, upsStr := range upsStrs {
upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{
// Here we pre-parse the url to find any magic params.
opts := upstream.Options{
Bootstrap: bootstraps,
Timeout: timeout,
})
}
if strings.Contains(upsStr, "://") {
var uu *url.URL
uu, err = url.Parse(upsStr)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", upsStr, err)
}
queries := uu.Query()
if serverIPstrs, exists := queries["adguardhome_upstream_ip"]; exists {
serverIPs := make([]net.IP, len(serverIPstrs))
for i := range serverIPstrs {
if t := net.ParseIP(serverIPstrs[i]); t != nil {
serverIPs[i] = t
} else {
return nil, fmt.Errorf("failed to parse upstream_ip %s", serverIPstrs[i])
}
}
opts.ServerIPAddrs = serverIPs
// Remove the magic param to avoid interference with the real server
queries.Del("adguardhome_upstream_ip")
uu.RawQuery = queries.Encode()
upsStr = uu.String()
}
}
upstreams[i], err = upstream.AddressToUpstream(upsStr, &opts)
if err != nil {
return nil, fmt.Errorf("upstream at index %d: %w", i, err)
}