add magic param for specifying server ip
This commit is contained in:
parent
e6f8aeeebe
commit
c1c0c2972a
|
@ -10,6 +10,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/next/dnssvc"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
@ -597,13 +598,11 @@ func checkDNS(
|
||||||
|
|
||||||
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
log.Debug("dnsforward: checking if upstream %q works", upstreamAddr)
|
||||||
|
|
||||||
u, err := upstream.AddressToUpstream(upstreamAddr, &upstream.Options{
|
us, err := dnssvc.AddressesToUpstreams([]string{upstreamAddr}, bootstrap, timeout)
|
||||||
Bootstrap: bootstrap,
|
|
||||||
Timeout: timeout,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
|
return fmt.Errorf("failed to choose upstream for %q: %w", upstreamAddr, err)
|
||||||
}
|
}
|
||||||
|
u := us[0]
|
||||||
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
defer func() { err = errors.WithDeferred(err, u.Close()) }()
|
||||||
|
|
||||||
if err = healthCheck(u); err != nil {
|
if err = healthCheck(u); err != nil {
|
||||||
|
|
|
@ -9,6 +9,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -80,7 +82,7 @@ func New(c *Config) (svc *Service, err error) {
|
||||||
if len(c.Upstreams) > 0 {
|
if len(c.Upstreams) > 0 {
|
||||||
upstreams = c.Upstreams
|
upstreams = c.Upstreams
|
||||||
} else {
|
} else {
|
||||||
upstreams, err = addressesToUpstreams(
|
upstreams, err = AddressesToUpstreams(
|
||||||
c.UpstreamServers,
|
c.UpstreamServers,
|
||||||
c.BootstrapServers,
|
c.BootstrapServers,
|
||||||
c.UpstreamTimeout,
|
c.UpstreamTimeout,
|
||||||
|
@ -108,20 +110,47 @@ func New(c *Config) (svc *Service, err error) {
|
||||||
return svc, nil
|
return svc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
|
// AddressesToUpstreams is a wrapper around [upstream.AddressToUpstream]. It
|
||||||
// accepts a slice of addresses and other upstream parameters, and returns a
|
// accepts a slice of addresses and other upstream parameters, and returns a
|
||||||
// slice of upstreams.
|
// slice of upstreams.
|
||||||
func addressesToUpstreams(
|
func AddressesToUpstreams(
|
||||||
upsStrs []string,
|
upsStrs []string,
|
||||||
bootstraps []string,
|
bootstraps []string,
|
||||||
timeout time.Duration,
|
timeout time.Duration,
|
||||||
) (upstreams []upstream.Upstream, err error) {
|
) (upstreams []upstream.Upstream, err error) {
|
||||||
upstreams = make([]upstream.Upstream, len(upsStrs))
|
upstreams = make([]upstream.Upstream, len(upsStrs))
|
||||||
for i, upsStr := range upsStrs {
|
for i, upsStr := range upsStrs {
|
||||||
upstreams[i], err = upstream.AddressToUpstream(upsStr, &upstream.Options{
|
// Here we pre-parse the url to find any magic params.
|
||||||
|
opts := upstream.Options{
|
||||||
Bootstrap: bootstraps,
|
Bootstrap: bootstraps,
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
})
|
}
|
||||||
|
if strings.Contains(upsStr, "://") {
|
||||||
|
var uu *url.URL
|
||||||
|
uu, err = url.Parse(upsStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse %s: %w", upsStr, err)
|
||||||
|
}
|
||||||
|
queries := uu.Query()
|
||||||
|
if serverIPstrs, exists := queries["adguardhome_upstream_ip"]; exists {
|
||||||
|
serverIPs := make([]net.IP, len(serverIPstrs))
|
||||||
|
for i := range serverIPstrs {
|
||||||
|
if t := net.ParseIP(serverIPstrs[i]); t != nil {
|
||||||
|
serverIPs[i] = t
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("failed to parse upstream_ip %s", serverIPstrs[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
opts.ServerIPAddrs = serverIPs
|
||||||
|
|
||||||
|
// Remove the magic param to avoid interference with the real server
|
||||||
|
queries.Del("adguardhome_upstream_ip")
|
||||||
|
uu.RawQuery = queries.Encode()
|
||||||
|
upsStr = uu.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
upstreams[i], err = upstream.AddressToUpstream(upsStr, &opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("upstream at index %d: %w", i, err)
|
return nil, fmt.Errorf("upstream at index %d: %w", i, err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue