add magic param for specifying server ip

This commit is contained in:
wesley800 2022-11-30 22:55:28 +08:00
parent e6f8aeeebe
commit c1c0c2972a
2 changed files with 37 additions and 9 deletions

View File

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
@ -597,13 +598,11 @@ func checkDNS(
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr) log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{ us, err := dnssvc.AddressesToUpstreams([]string{upstreamAddr}, bootstrap, timeout)
Bootstrap: bootstrap,
Timeout: timeout,
})
if err != nil { if err != nil {
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err) return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
} }
u := us[0]
defer func() { err = errors.WithDeferred(err, u.Close()) }() defer func() { err = errors.WithDeferred(err, u.Close()) }()
if err = healthCheck(u); err != nil { if err = healthCheck(u); err != nil {

View File

@ -9,6 +9,8 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"net/url"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
@ -80,7 +82,7 @@ func New(c *Config) (svc *Service, err error) {
if len(c.Upstreams) > 0 { if len(c.Upstreams) > 0 {
upstreams = c.Upstreams upstreams = c.Upstreams
} else { } else {
upstreams, err = addressesToUpstreams( upstreams, err = AddressesToUpstreams(
c.UpstreamServers, c.UpstreamServers,
c.BootstrapServers, c.BootstrapServers,
c.UpstreamTimeout, c.UpstreamTimeout,
@ -108,20 +110,47 @@ func New(c *Config) (svc *Service, err error) {
return svc, nil return svc, nil
} }
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It // AddressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
// accepts a slice of addresses and other upstream parameters, and returns a // accepts a slice of addresses and other upstream parameters, and returns a
// slice of upstreams. // slice of upstreams.
func addressesToUpstreams( func AddressesToUpstreams(
upsStrs []string, upsStrs []string,
bootstraps []string, bootstraps []string,
timeout time.Duration, timeout time.Duration,
) (upstreams []upstream.Upstream, err error) { ) (upstreams []upstream.Upstream, err error) {
upstreams = make([]upstream.Upstream, len(upsStrs)) upstreams = make([]upstream.Upstream, len(upsStrs))
for i, upsStr := range upsStrs { for i, upsStr := range upsStrs {
upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{ // Here we pre-parse the url to find any magic params.
opts := upstream.Options{
Bootstrap: bootstraps, Bootstrap: bootstraps,
Timeout: timeout, Timeout: timeout,
}) }
if strings.Contains(upsStr, "://") {
var uu *url.URL
uu, err = url.Parse(upsStr)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", upsStr, err)
}
queries := uu.Query()
if serverIPstrs, exists := queries["adguardhome_upstream_ip"]; exists {
serverIPs := make([]net.IP, len(serverIPstrs))
for i := range serverIPstrs {
if t := net.ParseIP(serverIPstrs[i]); t != nil {
serverIPs[i] = t
} else {
return nil, fmt.Errorf("failed to parse upstream_ip %s", serverIPstrs[i])
}
}
opts.ServerIPAddrs = serverIPs
// Remove the magic param to avoid interference with the real server
queries.Del("adguardhome_upstream_ip")
uu.RawQuery = queries.Encode()
upsStr = uu.String()
}
}
upstreams[i], err = upstream.AddressToUpstream(upsStr, &opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("upstream at index %d: %w", i, err) return nil, fmt.Errorf("upstream at index %d: %w", i, err)
} }