all: imp code

This commit is contained in:
Dimitry Kolyshev 2024-07-09 11:31:12 +03:00
parent 7e8d3b50e7
commit 4db7cdc0c4
5 changed files with 37 additions and 17 deletions

View File

@ -205,7 +205,10 @@ type DNSCreateParams struct {
PrivateNets netutil.SubnetSet PrivateNets netutil.SubnetSet
Anonymizer *aghnet.IPMut Anonymizer *aghnet.IPMut
EtcHosts *aghnet.HostsContainer EtcHosts *aghnet.HostsContainer
Logger *slog.Logger
// Logger is used as a base logger. It must not be nil.
Logger *slog.Logger
LocalDomain string LocalDomain string
} }
@ -236,18 +239,13 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
etcHosts = upstream.NewHostsResolver(p.EtcHosts) etcHosts = upstream.NewHostsResolver(p.EtcHosts)
} }
l := p.Logger
if l == nil {
l = slog.Default()
}
s = &Server{ s = &Server{
dnsFilter: p.DNSFilter, dnsFilter: p.DNSFilter,
dhcpServer: p.DHCPServer, dhcpServer: p.DHCPServer,
stats: p.Stats, stats: p.Stats,
queryLog: p.QueryLog, queryLog: p.QueryLog,
privateNets: p.PrivateNets, privateNets: p.PrivateNets,
logger: l.With(slogutil.KeyPrefix, "dnsforward"), logger: p.Logger.With(slogutil.KeyPrefix, "dnsforward"),
// TODO(e.burkov): Use some case-insensitive string comparison. // TODO(e.burkov): Use some case-insensitive string comparison.
localDomainSuffix: strings.ToLower(localDomainSuffix), localDomainSuffix: strings.ToLower(localDomainSuffix),
etcHosts: etcHosts, etcHosts: etcHosts,

View File

@ -28,6 +28,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil" "github.com/AdguardTeam/golibs/timeutil"
@ -99,6 +100,7 @@ func createTestServer(
DHCPServer: dhcp, DHCPServer: dhcp,
DNSFilter: f, DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
Logger: slogutil.NewDiscardLogger(),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -339,7 +341,10 @@ func TestServer_timeout(t *testing.T) {
ServePlainDNS: true, ServePlainDNS: true,
} }
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)}) s, err := NewServer(DNSCreateParams{
DNSFilter: createTestDNSFilter(t),
Logger: slogutil.NewDiscardLogger(),
})
require.NoError(t, err) require.NoError(t, err)
err = s.Prepare(srvConf) err = s.Prepare(srvConf)
@ -349,7 +354,10 @@ func TestServer_timeout(t *testing.T) {
}) })
t.Run("default", func(t *testing.T) { t.Run("default", func(t *testing.T) {
s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)}) s, err := NewServer(DNSCreateParams{
DNSFilter: createTestDNSFilter(t),
Logger: slogutil.NewDiscardLogger(),
})
require.NoError(t, err) require.NoError(t, err)
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
@ -376,7 +384,9 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
ServePlainDNS: true, ServePlainDNS: true,
} }
s, err := NewServer(DNSCreateParams{}) s, err := NewServer(DNSCreateParams{
Logger: slogutil.NewDiscardLogger(),
})
require.NoError(t, err) require.NoError(t, err)
err = s.Prepare(srvConf) err = s.Prepare(srvConf)
@ -962,6 +972,7 @@ func TestBlockedCustomIP(t *testing.T) {
DHCPServer: dhcp, DHCPServer: dhcp,
DNSFilter: f, DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
Logger: slogutil.NewDiscardLogger(),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1127,6 +1138,7 @@ func TestRewrite(t *testing.T) {
DHCPServer: dhcp, DHCPServer: dhcp,
DNSFilter: f, DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
Logger: slogutil.NewDiscardLogger(),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1256,6 +1268,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
}, },
}, },
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
Logger: slogutil.NewDiscardLogger(),
LocalDomain: localDomain, LocalDomain: localDomain,
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1341,6 +1354,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
DHCPServer: dhcp, DHCPServer: dhcp,
DNSFilter: flt, DNSFilter: flt,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
Logger: slogutil.NewDiscardLogger(),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -1392,24 +1406,29 @@ func TestNewServer(t *testing.T) {
in DNSCreateParams in DNSCreateParams
wantErrMsg string wantErrMsg string
}{{ }{{
name: "success", name: "success",
in: DNSCreateParams{}, in: DNSCreateParams{
Logger: slogutil.NewDiscardLogger(),
},
wantErrMsg: "", wantErrMsg: "",
}, { }, {
name: "success_local_tld", name: "success_local_tld",
in: DNSCreateParams{ in: DNSCreateParams{
Logger: slogutil.NewDiscardLogger(),
LocalDomain: "mynet", LocalDomain: "mynet",
}, },
wantErrMsg: "", wantErrMsg: "",
}, { }, {
name: "success_local_domain", name: "success_local_domain",
in: DNSCreateParams{ in: DNSCreateParams{
Logger: slogutil.NewDiscardLogger(),
LocalDomain: "my.local.net", LocalDomain: "my.local.net",
}, },
wantErrMsg: "", wantErrMsg: "",
}, { }, {
name: "bad_local_domain", name: "bad_local_domain",
in: DNSCreateParams{ in: DNSCreateParams{
Logger: slogutil.NewDiscardLogger(),
LocalDomain: "!!!", LocalDomain: "!!!",
}, },
wantErrMsg: `local domain: bad domain name "!!!": ` + wantErrMsg: `local domain: bad domain name "!!!": ` +

View File

@ -9,6 +9,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -57,6 +58,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
}, },
DNSFilter: f, DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
Logger: slogutil.NewDiscardLogger(),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -229,6 +231,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
DHCPServer: &testDHCP{}, DHCPServer: &testDHCP{},
DNSFilter: f, DNSFilter: f,
PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
Logger: slogutil.NewDiscardLogger(),
}) })
require.NoError(t, err) require.NoError(t, err)

View File

@ -559,7 +559,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
fatalOnError(err) fatalOnError(err)
// TODO(a.garipov): Use slog everywhere. // TODO(a.garipov): Use slog everywhere.
l := initLogger(ls) slogLogger := newSlogLogger(ls)
// Print the first message after logger is configured. // Print the first message after logger is configured.
log.Info(version.Full()) log.Info(version.Full())
@ -568,7 +568,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
log.Info("AdGuard Home is running as a service") log.Info("AdGuard Home is running as a service")
} }
err = setupContext(opts, l) err = setupContext(opts, slogLogger)
fatalOnError(err) fatalOnError(err)
err = configureOS(config) err = configureOS(config)

View File

@ -18,8 +18,8 @@ import (
// for logger output. // for logger output.
const configSyslog = "syslog" const configSyslog = "syslog"
// initLogger returns new [*slog.Logger] configured with the given settings. // newSlogLogger returns new [*slog.Logger] configured with the given settings.
func initLogger(ls *logSettings) (l *slog.Logger) { func newSlogLogger(ls *logSettings) (l *slog.Logger) {
if !ls.Enabled { if !ls.Enabled {
return slogutil.NewDiscardLogger() return slogutil.NewDiscardLogger()
} }
@ -46,7 +46,7 @@ func configureLogger(ls *logSettings) (err error) {
// Write logs to stdout by default. // Write logs to stdout by default.
if ls.File == "" { if ls.File == "" {
return err return nil
} }
if ls.File == configSyslog { if ls.File == configSyslog {