From a8c329950423b59098d1f2b16d1da7100dd54f8d Mon Sep 17 00:00:00 2001 From: Dimitry Kolyshev Date: Tue, 19 Dec 2023 12:08:05 +0200 Subject: [PATCH] dnsforward: imp code --- internal/dnsforward/config.go | 3 +- internal/dnsforward/http.go | 52 +++++++++++++++++++++++------------ 2 files changed, 37 insertions(+), 18 deletions(-) diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 7fa98e80..2a494e5f 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -90,7 +90,8 @@ type Config struct { FallbackDNS []string `yaml:"fallback_dns"` // UpstreamMode determines the logic through which upstreams will be used. - // See UpstreamModeType* constants. + // Should be [UpstreamModeTypeParallel], [UpstreamModeTypeLoadBalance], or + // [UpstreamModeTypeFastestAddr]. UpstreamMode string `yaml:"upstream_mode"` // FastestTimeout replaces the default timeout for dialing IP addresses diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 0746b572..096c3d73 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -146,7 +146,8 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) { localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers) var upstreamMode string - // TODO(d.kolyshev): Set 'load_balance' type string instead of nil. + // TODO(d.kolyshev): Use 'load_balance' on frontend instead of nil as a + // default value. switch s.conf.UpstreamMode { case UpstreamModeTypeParallel: upstreamMode = "parallel" @@ -224,18 +225,24 @@ func (req *jsonDNSConfig) checkBlockingMode() (err error) { return validateBlockingMode(*req.BlockingMode, req.BlockingIPv4, req.BlockingIPv6) } -// checkUpstreamsMode returns an error if the upstream mode is invalid. -func (req *jsonDNSConfig) checkUpstreamsMode() (err error) { +// checkUpstreamMode returns an error if the upstream mode is invalid. +func (req *jsonDNSConfig) checkUpstreamMode() (err error) { if req.UpstreamMode == nil { return nil } - mode := *req.UpstreamMode - if ok := slices.Contains([]string{"", "fastest_addr", "parallel"}, mode); !ok { - return fmt.Errorf("upstream_mode: incorrect value %q", mode) + switch um := *req.UpstreamMode; um { + case "": + return nil + case "parallel": + return nil + case "fastest_addr": + return nil + case "load_balance": + return nil + default: + return fmt.Errorf("upstream_mode: incorrect value %q", um) } - - return nil } // checkBootstrap returns an error if any bootstrap address is invalid. @@ -299,7 +306,7 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) { return err } - err = req.checkUpstreamsMode() + err = req.checkUpstreamMode() if err != nil { // Don't wrap the error since it's informative enough as is. return err @@ -448,14 +455,7 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) { } if dc.UpstreamMode != nil { - switch *dc.UpstreamMode { - case "parallel": - s.conf.UpstreamMode = UpstreamModeTypeParallel - case "fastest_addr": - s.conf.UpstreamMode = UpstreamModeTypeFastestAddr - default: - s.conf.UpstreamMode = UpstreamModeTypeLoadBalance - } + s.conf.UpstreamMode = mustParseUpstreamMode(*dc.UpstreamMode) } else { s.conf.UpstreamMode = UpstreamModeTypeLoadBalance } @@ -470,6 +470,24 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) { return s.setConfigRestartable(dc) } +// mustParseUpstreamMode returns an upstream mode parsed from string. Panics in +// case of invalid value. +func mustParseUpstreamMode(mode string) (um string) { + switch mode { + case "": + return UpstreamModeTypeLoadBalance + case "load_balance": + return UpstreamModeTypeLoadBalance + case "parallel": + return UpstreamModeTypeParallel + case "fastest_addr": + return UpstreamModeTypeFastestAddr + default: + // Should never happen, since the value should be validated. + panic(fmt.Errorf("unexpected upstream mode: %q", mode)) + } +} + // setIfNotNil sets the value pointed at by currentPtr to the value pointed at // by newPtr if newPtr is not nil. currentPtr must not be nil. func setIfNotNil[T any](currentPtr, newPtr *T) (hasSet bool) {