add magic param for specifying server ip
This commit is contained in:
parent
e6f8aeeebe
commit
c1c0c2972a
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
|
@ -597,13 +598,11 @@ func checkDNS(
|
|||
|
||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||
|
||||
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
})
|
||||
us, err := dnssvc.AddressesToUpstreams([]string{upstreamAddr}, bootstrap, timeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
|
||||
}
|
||||
u := us[0]
|
||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||
|
||||
if err = healthCheck(u); err != nil {
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
|
@ -80,7 +82,7 @@ func New(c *Config) (svc *Service, err error) {
|
|||
if len(c.Upstreams) > 0 {
|
||||
upstreams = c.Upstreams
|
||||
} else {
|
||||
upstreams, err = addressesToUpstreams(
|
||||
upstreams, err = AddressesToUpstreams(
|
||||
c.UpstreamServers,
|
||||
c.BootstrapServers,
|
||||
c.UpstreamTimeout,
|
||||
|
@ -108,20 +110,47 @@ func New(c *Config) (svc *Service, err error) {
|
|||
return svc, nil
|
||||
}
|
||||
|
||||
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
|
||||
// AddressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
|
||||
// accepts a slice of addresses and other upstream parameters, and returns a
|
||||
// slice of upstreams.
|
||||
func addressesToUpstreams(
|
||||
func AddressesToUpstreams(
|
||||
upsStrs []string,
|
||||
bootstraps []string,
|
||||
timeout time.Duration,
|
||||
) (upstreams []upstream.Upstream, err error) {
|
||||
upstreams = make([]upstream.Upstream, len(upsStrs))
|
||||
for i, upsStr := range upsStrs {
|
||||
upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{
|
||||
// Here we pre-parse the url to find any magic params.
|
||||
opts := upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: timeout,
|
||||
})
|
||||
}
|
||||
if strings.Contains(upsStr, "://") {
|
||||
var uu *url.URL
|
||||
uu, err = url.Parse(upsStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", upsStr, err)
|
||||
}
|
||||
queries := uu.Query()
|
||||
if serverIPstrs, exists := queries["adguardhome_upstream_ip"]; exists {
|
||||
serverIPs := make([]net.IP, len(serverIPstrs))
|
||||
for i := range serverIPstrs {
|
||||
if t := net.ParseIP(serverIPstrs[i]); t != nil {
|
||||
serverIPs[i] = t
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to parse upstream_ip %s", serverIPstrs[i])
|
||||
}
|
||||
}
|
||||
opts.ServerIPAddrs = serverIPs
|
||||
|
||||
// Remove the magic param to avoid interference with the real server
|
||||
queries.Del("adguardhome_upstream_ip")
|
||||
uu.RawQuery = queries.Encode()
|
||||
upsStr = uu.String()
|
||||
}
|
||||
}
|
||||
|
||||
upstreams[i], err = upstream.AddressToUpstream(upsStr, &opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upstream at index %d: %w", i, err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue