diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index d05a0d44..9ae6fc69 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -205,7 +205,10 @@ type DNSCreateParams struct { PrivateNets netutil.SubnetSet Anonymizer *aghnet.IPMut EtcHosts *aghnet.HostsContainer - Logger *slog.Logger + + // Logger is used as a base logger. It must not be nil. + Logger *slog.Logger + LocalDomain string } @@ -236,18 +239,13 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { etcHosts = upstream.NewHostsResolver(p.EtcHosts) } - l := p.Logger - if l == nil { - l = slog.Default() - } - s = &Server{ dnsFilter: p.DNSFilter, dhcpServer: p.DHCPServer, stats: p.Stats, queryLog: p.QueryLog, privateNets: p.PrivateNets, - logger: l.With(slogutil.KeyPrefix, "dnsforward"), + logger: p.Logger.With(slogutil.KeyPrefix, "dnsforward"), // TODO(e.burkov): Use some case-insensitive string comparison. localDomainSuffix: strings.ToLower(localDomainSuffix), etcHosts: etcHosts, diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 9e4942cc..c326f8aa 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -28,6 +28,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" @@ -99,6 +100,7 @@ func createTestServer( DHCPServer: dhcp, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) @@ -339,7 +341,10 @@ func TestServer_timeout(t *testing.T) { ServePlainDNS: true, } - s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)}) + s, err := NewServer(DNSCreateParams{ + DNSFilter: createTestDNSFilter(t), + Logger: slogutil.NewDiscardLogger(), + }) require.NoError(t, err) err = s.Prepare(srvConf) @@ -349,7 +354,10 @@ func TestServer_timeout(t *testing.T) { }) t.Run("default", func(t *testing.T) { - s, err := NewServer(DNSCreateParams{DNSFilter: createTestDNSFilter(t)}) + s, err := NewServer(DNSCreateParams{ + DNSFilter: createTestDNSFilter(t), + Logger: slogutil.NewDiscardLogger(), + }) require.NoError(t, err) s.conf.Config.UpstreamMode = UpstreamModeLoadBalance @@ -376,7 +384,9 @@ func TestServer_Prepare_fallbacks(t *testing.T) { ServePlainDNS: true, } - s, err := NewServer(DNSCreateParams{}) + s, err := NewServer(DNSCreateParams{ + Logger: slogutil.NewDiscardLogger(), + }) require.NoError(t, err) err = s.Prepare(srvConf) @@ -962,6 +972,7 @@ func TestBlockedCustomIP(t *testing.T) { DHCPServer: dhcp, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) @@ -1127,6 +1138,7 @@ func TestRewrite(t *testing.T) { DHCPServer: dhcp, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) @@ -1256,6 +1268,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { }, }, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), LocalDomain: localDomain, }) require.NoError(t, err) @@ -1341,6 +1354,7 @@ func TestPTRResponseFromHosts(t *testing.T) { DHCPServer: dhcp, DNSFilter: flt, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) @@ -1392,24 +1406,29 @@ func TestNewServer(t *testing.T) { in DNSCreateParams wantErrMsg string }{{ - name: "success", - in: DNSCreateParams{}, + name: "success", + in: DNSCreateParams{ + Logger: slogutil.NewDiscardLogger(), + }, wantErrMsg: "", }, { name: "success_local_tld", in: DNSCreateParams{ + Logger: slogutil.NewDiscardLogger(), LocalDomain: "mynet", }, wantErrMsg: "", }, { name: "success_local_domain", in: DNSCreateParams{ + Logger: slogutil.NewDiscardLogger(), LocalDomain: "my.local.net", }, wantErrMsg: "", }, { name: "bad_local_domain", in: DNSCreateParams{ + Logger: slogutil.NewDiscardLogger(), LocalDomain: "!!!", }, wantErrMsg: `local domain: bad domain name "!!!": ` + diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go index 9e172a32..57d265f7 100644 --- a/internal/dnsforward/filter_test.go +++ b/internal/dnsforward/filter_test.go @@ -9,6 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -57,6 +58,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) { }, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) @@ -229,6 +231,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) { DHCPServer: &testDHCP{}, DNSFilter: f, PrivateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), + Logger: slogutil.NewDiscardLogger(), }) require.NoError(t, err) diff --git a/internal/home/home.go b/internal/home/home.go index 74ca6583..82a50083 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -559,7 +559,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) { fatalOnError(err) // TODO(a.garipov): Use slog everywhere. - l := initLogger(ls) + slogLogger := newSlogLogger(ls) // Print the first message after logger is configured. log.Info(version.Full()) @@ -568,7 +568,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) { log.Info("AdGuard Home is running as a service") } - err = setupContext(opts, l) + err = setupContext(opts, slogLogger) fatalOnError(err) err = configureOS(config) diff --git a/internal/home/log.go b/internal/home/log.go index 13b3aa53..0b3a14a8 100644 --- a/internal/home/log.go +++ b/internal/home/log.go @@ -18,8 +18,8 @@ import ( // for logger output. const configSyslog = "syslog" -// initLogger returns new [*slog.Logger] configured with the given settings. -func initLogger(ls *logSettings) (l *slog.Logger) { +// newSlogLogger returns new [*slog.Logger] configured with the given settings. +func newSlogLogger(ls *logSettings) (l *slog.Logger) { if !ls.Enabled { return slogutil.NewDiscardLogger() } @@ -46,7 +46,7 @@ func configureLogger(ls *logSettings) (err error) { // Write logs to stdout by default. if ls.File == "" { - return err + return nil } if ls.File == configSyslog {