dnsforward: imp code

This commit is contained in:
Dimitry Kolyshev 2023-12-19 12:08:05 +02:00
parent 58b53ccd97
commit a8c3299504
2 changed files with 37 additions and 18 deletions

View File

@ -90,7 +90,8 @@ type Config struct {
FallbackDNS []string `yaml:"fallback_dns"` FallbackDNS []string `yaml:"fallback_dns"`
// UpstreamMode determines the logic through which upstreams will be used. // UpstreamMode determines the logic through which upstreams will be used.
// See UpstreamModeType* constants. // Should be [UpstreamModeTypeParallel], [UpstreamModeTypeLoadBalance], or
// [UpstreamModeTypeFastestAddr].
UpstreamMode string `yaml:"upstream_mode"` UpstreamMode string `yaml:"upstream_mode"`
// FastestTimeout replaces the default timeout for dialing IP addresses // FastestTimeout replaces the default timeout for dialing IP addresses

View File

@ -146,7 +146,8 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers) localPTRUpstreams := stringutil.CloneSliceOrEmpty(s.conf.LocalPTRResolvers)
var upstreamMode string var upstreamMode string
// TODO(d.kolyshev): Set 'load_balance' type string instead of nil. // TODO(d.kolyshev): Use 'load_balance' on frontend instead of nil as a
// default value.
switch s.conf.UpstreamMode { switch s.conf.UpstreamMode {
case UpstreamModeTypeParallel: case UpstreamModeTypeParallel:
upstreamMode = "parallel" upstreamMode = "parallel"
@ -224,18 +225,24 @@ func (req *jsonDNSConfig) checkBlockingMode() (err error) {
return validateBlockingMode(*req.BlockingMode, req.BlockingIPv4, req.BlockingIPv6) return validateBlockingMode(*req.BlockingMode, req.BlockingIPv4, req.BlockingIPv6)
} }
// checkUpstreamsMode returns an error if the upstream mode is invalid. // checkUpstreamMode returns an error if the upstream mode is invalid.
func (req *jsonDNSConfig) checkUpstreamsMode() (err error) { func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
if req.UpstreamMode == nil { if req.UpstreamMode == nil {
return nil return nil
} }
mode := *req.UpstreamMode switch um := *req.UpstreamMode; um {
if ok := slices.Contains([]string{"", "fastest_addr", "parallel"}, mode); !ok { case "":
return fmt.Errorf("upstream_mode: incorrect value %q", mode) return nil
case "parallel":
return nil
case "fastest_addr":
return nil
case "load_balance":
return nil
default:
return fmt.Errorf("upstream_mode: incorrect value %q", um)
} }
return nil
} }
// checkBootstrap returns an error if any bootstrap address is invalid. // checkBootstrap returns an error if any bootstrap address is invalid.
@ -299,7 +306,7 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
return err return err
} }
err = req.checkUpstreamsMode() err = req.checkUpstreamMode()
if err != nil { if err != nil {
// Don't wrap the error since it's informative enough as is. // Don't wrap the error since it's informative enough as is.
return err return err
@ -448,14 +455,7 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
} }
if dc.UpstreamMode != nil { if dc.UpstreamMode != nil {
switch *dc.UpstreamMode { s.conf.UpstreamMode = mustParseUpstreamMode(*dc.UpstreamMode)
case "parallel":
s.conf.UpstreamMode = UpstreamModeTypeParallel
case "fastest_addr":
s.conf.UpstreamMode = UpstreamModeTypeFastestAddr
default:
s.conf.UpstreamMode = UpstreamModeTypeLoadBalance
}
} else { } else {
s.conf.UpstreamMode = UpstreamModeTypeLoadBalance s.conf.UpstreamMode = UpstreamModeTypeLoadBalance
} }
@ -470,6 +470,24 @@ func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) {
return s.setConfigRestartable(dc) return s.setConfigRestartable(dc)
} }
// mustParseUpstreamMode returns an upstream mode parsed from string. Panics in
// case of invalid value.
func mustParseUpstreamMode(mode string) (um string) {
switch mode {
case "":
return UpstreamModeTypeLoadBalance
case "load_balance":
return UpstreamModeTypeLoadBalance
case "parallel":
return UpstreamModeTypeParallel
case "fastest_addr":
return UpstreamModeTypeFastestAddr
default:
// Should never happen, since the value should be validated.
panic(fmt.Errorf("unexpected upstream mode: %q", mode))
}
}
// setIfNotNil sets the value pointed at by currentPtr to the value pointed at // setIfNotNil sets the value pointed at by currentPtr to the value pointed at
// by newPtr if newPtr is not nil. currentPtr must not be nil. // by newPtr if newPtr is not nil. currentPtr must not be nil.
func setIfNotNil[T any](currentPtr, newPtr *T) (hasSet bool) { func setIfNotNil[T any](currentPtr, newPtr *T) (hasSet bool) {