Pull request 1883: 951-blocked-services-client-schedule

Updates #951.

Squashed commit of the following:

commit 94e4766932940a99c5265489bccb46d0ed6cec25
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Jun 27 17:21:41 2023 +0300

    chlog: upd docs

commit b4022c33860c258bf29650413f0c972b849a1758
Merge: cfa24ff01 e7e638443
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Jun 27 16:33:20 2023 +0300

    Merge branch 'master' into 951-blocked-services-client-schedule

commit cfa24ff0190b2bc12736700eeff815525fbaf5fe
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Jun 27 15:04:10 2023 +0300

    chlog: imp docs

commit dad27590d5eefde82758d58fc06a20c139492db8
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Jun 26 17:38:08 2023 +0300

    home: imp err msg

commit 7d9ba98c4477000fc2e0f06c3462fe9cd0c65293
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Jun 26 16:58:00 2023 +0300

    all: add tests

commit 8e952fc4e3b3d433b29efe47c88d6b7806e99ff8
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Jun 23 16:36:10 2023 +0300

    schedule: add todo

commit 723573a98d5b930334a5fa125eb12593f4a2430d
Merge: 2151ab2a6 e54fc9b1e
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Jun 23 11:40:03 2023 +0300

    Merge branch 'master' into 951-blocked-services-client-schedule

commit 2151ab2a627b9833ba6cce9621f72b29d326da75
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Jun 23 11:37:49 2023 +0300

    all: add tests

commit 81ab341db3e4053f09b181d8111c0da197bdac05
Merge: aa7ae41a8 66345e855
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Jun 22 17:59:01 2023 +0300

    Merge branch 'master' into 951-blocked-services-client-schedule

commit aa7ae41a868045fe24e390b25f15551fd8821529
Merge: 304389a48 06d465b0d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Jun 21 17:10:11 2023 +0300

    Merge branch 'master' into 951-blocked-services-client-schedule

commit 304389a487f728e8ced293ea811a4e0026a37f0d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Jun 21 17:05:31 2023 +0300

    home: imp err msg

commit 29cfc7ae2a0bbd5ec3205eae3f6f810519787f26
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Jun 20 20:42:59 2023 +0300

    all: imp err handling

commit 8543868eef6442fd30131d9567b66222999101e9
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Jun 20 18:21:50 2023 +0300

    all: upd chlog

commit c5b614d45e5cf25c30c52343f48139fb34c77539
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Jun 20 14:37:47 2023 +0300

    all: add blocked services schedule
This commit is contained in:
Stanislav Chzhen 2023-06-27 18:03:07 +03:00
parent e7e638443f
commit d88181343c
14 changed files with 418 additions and 71 deletions

View File

@ -27,9 +27,9 @@ NOTE: Add new changes BELOW THIS COMMENT.
- The new command-line flag `--web-addr` is the address to serve the web UI on, - The new command-line flag `--web-addr` is the address to serve the web UI on,
in the host:port format. in the host:port format.
- The ability to set inactivity periods for filtering blocked services in the - The ability to set inactivity periods for filtering blocked services, both
configuration file ([#951]). The UI changes are coming in the upcoming globally and per client, in the configuration file ([#951]). The UI changes
releases. are coming in the upcoming releases.
- The ability to edit rewrite rules via `PUT /control/rewrite/update` HTTP API - The ability to edit rewrite rules via `PUT /control/rewrite/update` HTTP API
and the Web UI ([#1577]). and the Web UI ([#1577]).
@ -37,8 +37,42 @@ NOTE: Add new changes BELOW THIS COMMENT.
#### Configuration Changes #### Configuration Changes
In this release, the schema version has changed from 20 to 21. In this release, the schema version has changed from 20 to 22.
- Property `clients.persistent.blocked_services`, which in schema versions 21
and earlier used to be a list containing ids of blocked services, is now an
object containing ids and schedule for blocked services:
```yaml
# BEFORE:
'clients':
'persistent':
- 'name': 'client-name'
'blocked_services':
- id_1
- id_2
# AFTER:
'clients':
'persistent':
- 'name': client-name
'blocked_services':
'ids':
- id_1
- id_2
'schedule':
'time_zone': 'Local'
'sun':
'start': '0s'
'end': '24h'
'mon':
'start': '1h'
'end': '23h'
```
To rollback this change, replace `clients.persistent.blocked_services` object
with the list of ids of blocked services and change the `schema_version` back
to `21`.
- Property `dns.blocked_services`, which in schema versions 20 and earlier used - Property `dns.blocked_services`, which in schema versions 20 and earlier used
to be a list containing ids of blocked services, is now an object containing to be a list containing ids of blocked services, is now an object containing
ids and schedule for blocked services: ids and schedule for blocked services:

View File

@ -2,6 +2,7 @@ package filtering
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"time" "time"
@ -55,11 +56,29 @@ type BlockedServices struct {
IDs []string `yaml:"ids"` IDs []string `yaml:"ids"`
} }
// BlockedSvcKnown returns true if a blocked service ID is known. // Clone returns a deep copy of blocked services.
func BlockedSvcKnown(s string) (ok bool) { func (s *BlockedServices) Clone() (c *BlockedServices) {
_, ok = serviceRules[s] if s == nil {
return nil
}
return ok return &BlockedServices{
Schedule: s.Schedule.Clone(),
IDs: slices.Clone(s.IDs),
}
}
// Validate returns an error if blocked services contain unknown service ID. s
// must not be nil.
func (s *BlockedServices) Validate() (err error) {
for _, id := range s.IDs {
_, ok := serviceRules[id]
if !ok {
return fmt.Errorf("unknown blocked-service %q", id)
}
}
return nil
} }
// ApplyBlockedServices - set blocked services settings for this DNS request // ApplyBlockedServices - set blocked services settings for this DNS request

View File

@ -988,17 +988,11 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) {
} }
if d.BlockedServices != nil { if d.BlockedServices != nil {
bsvcs := []string{} err = d.BlockedServices.Validate()
for _, s := range d.BlockedServices.IDs {
if !BlockedSvcKnown(s) {
log.Debug("skipping unknown blocked-service %q", s)
continue if err != nil {
} return nil, fmt.Errorf("filtering: %w", err)
bsvcs = append(bsvcs, s)
} }
d.BlockedServices.IDs = bsvcs
} }
if blockFilters != nil { if blockFilters != nil {

View File

@ -23,12 +23,14 @@ type Client struct {
safeSearchConf filtering.SafeSearchConfig safeSearchConf filtering.SafeSearchConfig
SafeSearch filtering.SafeSearch SafeSearch filtering.SafeSearch
// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices
Name string Name string
IDs []string IDs []string
Tags []string Tags []string
BlockedServices []string Upstreams []string
Upstreams []string
UseOwnSettings bool UseOwnSettings bool
FilteringEnabled bool FilteringEnabled bool
@ -44,9 +46,9 @@ type Client struct {
func (c *Client) ShallowClone() (sh *Client) { func (c *Client) ShallowClone() (sh *Client) {
clone := *c clone := *c
clone.BlockedServices = c.BlockedServices.Clone()
clone.IDs = stringutil.CloneSlice(c.IDs) clone.IDs = stringutil.CloneSlice(c.IDs)
clone.Tags = stringutil.CloneSlice(c.Tags) clone.Tags = stringutil.CloneSlice(c.Tags)
clone.BlockedServices = stringutil.CloneSlice(c.BlockedServices)
clone.Upstreams = stringutil.CloneSlice(c.Upstreams) clone.Upstreams = stringutil.CloneSlice(c.Upstreams)
return &clone return &clone

View File

@ -96,7 +96,7 @@ func (clients *clientsContainer) Init(
etcHosts *aghnet.HostsContainer, etcHosts *aghnet.HostsContainer,
arpdb aghnet.ARPDB, arpdb aghnet.ARPDB,
filteringConf *filtering.Config, filteringConf *filtering.Config,
) { ) (err error) {
if clients.list != nil { if clients.list != nil {
log.Fatal("clients.list != nil") log.Fatal("clients.list != nil")
} }
@ -110,13 +110,17 @@ func (clients *clientsContainer) Init(
clients.dhcpServer = dhcpServer clients.dhcpServer = dhcpServer
clients.etcHosts = etcHosts clients.etcHosts = etcHosts
clients.arpdb = arpdb clients.arpdb = arpdb
clients.addFromConfig(objects, filteringConf) err = clients.addFromConfig(objects, filteringConf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime) clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
if clients.testing { if clients.testing {
return return nil
} }
if clients.dhcpServer != nil { if clients.dhcpServer != nil {
@ -127,6 +131,8 @@ func (clients *clientsContainer) Init(
if clients.etcHosts != nil { if clients.etcHosts != nil {
go clients.handleHostsUpdates() go clients.handleHostsUpdates()
} }
return nil
} }
func (clients *clientsContainer) handleHostsUpdates() { func (clients *clientsContainer) handleHostsUpdates() {
@ -166,12 +172,14 @@ func (clients *clientsContainer) reloadARP() {
type clientObject struct { type clientObject struct {
SafeSearchConf filtering.SafeSearchConfig `yaml:"safe_search"` SafeSearchConf filtering.SafeSearchConfig `yaml:"safe_search"`
// BlockedServices is the configuration of blocked services of a client.
BlockedServices *filtering.BlockedServices `yaml:"blocked_services"`
Name string `yaml:"name"` Name string `yaml:"name"`
Tags []string `yaml:"tags"` IDs []string `yaml:"ids"`
IDs []string `yaml:"ids"` Tags []string `yaml:"tags"`
BlockedServices []string `yaml:"blocked_services"` Upstreams []string `yaml:"upstreams"`
Upstreams []string `yaml:"upstreams"`
UseGlobalSettings bool `yaml:"use_global_settings"` UseGlobalSettings bool `yaml:"use_global_settings"`
FilteringEnabled bool `yaml:"filtering_enabled"` FilteringEnabled bool `yaml:"filtering_enabled"`
@ -185,7 +193,10 @@ type clientObject struct {
// addFromConfig initializes the clients container with objects from the // addFromConfig initializes the clients container with objects from the
// configuration file. // configuration file.
func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) { func (clients *clientsContainer) addFromConfig(
objects []*clientObject,
filteringConf *filtering.Config,
) (err error) {
for _, o := range objects { for _, o := range objects {
cli := &Client{ cli := &Client{
Name: o.Name, Name: o.Name,
@ -206,7 +217,7 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
if o.SafeSearchConf.Enabled { if o.SafeSearchConf.Enabled {
o.SafeSearchConf.CustomResolver = safeSearchResolver{} o.SafeSearchConf.CustomResolver = safeSearchResolver{}
err := cli.setSafeSearch( err = cli.setSafeSearch(
o.SafeSearchConf, o.SafeSearchConf,
filteringConf.SafeSearchCacheSize, filteringConf.SafeSearchCacheSize,
time.Minute*time.Duration(filteringConf.CacheTime), time.Minute*time.Duration(filteringConf.CacheTime),
@ -218,14 +229,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
} }
} }
for _, s := range o.BlockedServices { err = o.BlockedServices.Validate()
if filtering.BlockedSvcKnown(s) { if err != nil {
cli.BlockedServices = append(cli.BlockedServices, s) return fmt.Errorf("clients: init client blocked services %q: %w", cli.Name, err)
} else {
log.Info("clients: skipping unknown blocked service %q", s)
}
} }
cli.BlockedServices = o.BlockedServices.Clone()
for _, t := range o.Tags { for _, t := range o.Tags {
if clients.allTags.Has(t) { if clients.allTags.Has(t) {
cli.Tags = append(cli.Tags, t) cli.Tags = append(cli.Tags, t)
@ -236,11 +246,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin
slices.Sort(cli.Tags) slices.Sort(cli.Tags)
_, err := clients.Add(cli) _, err = clients.Add(cli)
if err != nil { if err != nil {
log.Error("clients: adding clients %s: %s", cli.Name, err) log.Error("clients: adding clients %s: %s", cli.Name, err)
} }
} }
return nil
} }
// forConfig returns all currently known persistent clients as objects for the // forConfig returns all currently known persistent clients as objects for the
@ -254,10 +266,11 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
o := &clientObject{ o := &clientObject{
Name: cli.Name, Name: cli.Name,
Tags: stringutil.CloneSlice(cli.Tags), BlockedServices: cli.BlockedServices.Clone(),
IDs: stringutil.CloneSlice(cli.IDs),
BlockedServices: stringutil.CloneSlice(cli.BlockedServices), IDs: stringutil.CloneSlice(cli.IDs),
Upstreams: stringutil.CloneSlice(cli.Upstreams), Tags: stringutil.CloneSlice(cli.Tags),
Upstreams: stringutil.CloneSlice(cli.Upstreams),
UseGlobalSettings: !cli.UseOwnSettings, UseGlobalSettings: !cli.UseOwnSettings,
FilteringEnabled: cli.FilteringEnabled, FilteringEnabled: cli.FilteringEnabled,

View File

@ -16,18 +16,19 @@ import (
// newClientsContainer is a helper that creates a new clients container for // newClientsContainer is a helper that creates a new clients container for
// tests. // tests.
func newClientsContainer() (c *clientsContainer) { func newClientsContainer(t *testing.T) (c *clientsContainer) {
c = &clientsContainer{ c = &clientsContainer{
testing: true, testing: true,
} }
c.Init(nil, nil, nil, nil, &filtering.Config{}) err := c.Init(nil, nil, nil, nil, &filtering.Config{})
require.NoError(t, err)
return c return c
} }
func TestClients(t *testing.T) { func TestClients(t *testing.T) {
clients := newClientsContainer() clients := newClientsContainer(t)
t.Run("add_success", func(t *testing.T) { t.Run("add_success", func(t *testing.T) {
var ( var (
@ -198,7 +199,7 @@ func TestClients(t *testing.T) {
} }
func TestClientsWHOIS(t *testing.T) { func TestClientsWHOIS(t *testing.T) {
clients := newClientsContainer() clients := newClientsContainer(t)
whois := &whois.Info{ whois := &whois.Info{
Country: "AU", Country: "AU",
Orgname: "Example Org", Orgname: "Example Org",
@ -244,7 +245,7 @@ func TestClientsWHOIS(t *testing.T) {
} }
func TestClientsAddExisting(t *testing.T) { func TestClientsAddExisting(t *testing.T) {
clients := newClientsContainer() clients := newClientsContainer(t)
t.Run("simple", func(t *testing.T) { t.Run("simple", func(t *testing.T) {
ip := netip.MustParseAddr("1.1.1.1") ip := netip.MustParseAddr("1.1.1.1")
@ -316,7 +317,7 @@ func TestClientsAddExisting(t *testing.T) {
} }
func TestClientsCustomUpstream(t *testing.T) { func TestClientsCustomUpstream(t *testing.T) {
clients := newClientsContainer() clients := newClientsContainer(t)
// Add client with upstreams. // Add client with upstreams.
ok, err := clients.Add(&Client{ ok, err := clients.Add(&Client{

View File

@ -123,10 +123,14 @@ func (clients *clientsContainer) jsonToClient(cj clientJSON, prev *Client) (c *C
Name: cj.Name, Name: cj.Name,
IDs: cj.IDs, BlockedServices: &filtering.BlockedServices{
Tags: cj.Tags, Schedule: prev.BlockedServices.Schedule.Clone(),
BlockedServices: cj.BlockedServices, IDs: cj.BlockedServices,
Upstreams: cj.Upstreams, },
IDs: cj.IDs,
Tags: cj.Tags,
Upstreams: cj.Upstreams,
UseOwnSettings: !cj.UseGlobalSettings, UseOwnSettings: !cj.UseGlobalSettings,
FilteringEnabled: cj.FilteringEnabled, FilteringEnabled: cj.FilteringEnabled,
@ -180,7 +184,8 @@ func clientToJSON(c *Client) (cj *clientJSON) {
SafeBrowsingEnabled: c.SafeBrowsingEnabled, SafeBrowsingEnabled: c.SafeBrowsingEnabled,
UseGlobalBlockedServices: !c.UseOwnBlockedServices, UseGlobalBlockedServices: !c.UseOwnBlockedServices,
BlockedServices: c.BlockedServices,
BlockedServices: c.BlockedServices.IDs,
Upstreams: c.Upstreams, Upstreams: c.Upstreams,

View File

@ -474,9 +474,11 @@ func applyAdditionalFiltering(clientIP net.IP, clientID string, setts *filtering
if c.UseOwnBlockedServices { if c.UseOwnBlockedServices {
// TODO(e.burkov): Get rid of this crutch. // TODO(e.burkov): Get rid of this crutch.
setts.ServicesRules = nil setts.ServicesRules = nil
svcs := c.BlockedServices svcs := c.BlockedServices.IDs
Context.filters.ApplyBlockedServicesList(setts, svcs) if !c.BlockedServices.Schedule.Contains(time.Now()) {
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs) Context.filters.ApplyBlockedServicesList(setts, svcs)
log.Debug("%s: services for client %q set: %s", pref, c.Name, svcs)
}
} }
setts.ClientName = c.Name setts.ClientName = c.Name

View File

@ -6,9 +6,87 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestApplyAdditionalFiltering(t *testing.T) {
var err error
Context.filters, err = filtering.New(&filtering.Config{
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
}, nil)
require.NoError(t, err)
Context.clients.idIndex = map[string]*Client{
"default": {
UseOwnSettings: false,
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
FilteringEnabled: false,
SafeBrowsingEnabled: false,
ParentalEnabled: false,
},
"custom_filtering": {
UseOwnSettings: true,
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true,
SafeBrowsingEnabled: true,
ParentalEnabled: true,
},
"partial_custom_filtering": {
UseOwnSettings: true,
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
FilteringEnabled: true,
SafeBrowsingEnabled: false,
ParentalEnabled: false,
},
}
testCases := []struct {
name string
id string
FilteringEnabled assert.BoolAssertionFunc
SafeSearchEnabled assert.BoolAssertionFunc
SafeBrowsingEnabled assert.BoolAssertionFunc
ParentalEnabled assert.BoolAssertionFunc
}{{
name: "global_settings",
id: "default",
FilteringEnabled: assert.False,
SafeSearchEnabled: assert.False,
SafeBrowsingEnabled: assert.False,
ParentalEnabled: assert.False,
}, {
name: "custom_settings",
id: "custom_filtering",
FilteringEnabled: assert.True,
SafeSearchEnabled: assert.True,
SafeBrowsingEnabled: assert.True,
ParentalEnabled: assert.True,
}, {
name: "partial",
id: "partial_custom_filtering",
FilteringEnabled: assert.True,
SafeSearchEnabled: assert.True,
SafeBrowsingEnabled: assert.False,
ParentalEnabled: assert.False,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
setts := &filtering.Settings{}
applyAdditionalFiltering(net.IP{1, 2, 3, 4}, tc.id, setts)
tc.FilteringEnabled(t, setts.FilteringEnabled)
tc.SafeSearchEnabled(t, setts.SafeSearchEnabled)
tc.SafeBrowsingEnabled(t, setts.SafeBrowsingEnabled)
tc.ParentalEnabled(t, setts.ParentalEnabled)
})
}
}
func TestApplyAdditionalFiltering_blockedServices(t *testing.T) { func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
filtering.InitModule() filtering.InitModule()
@ -29,43 +107,61 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
Context.clients.idIndex = map[string]*Client{ Context.clients.idIndex = map[string]*Client{
"client_1": { "default": {
UseOwnBlockedServices: false, UseOwnBlockedServices: false,
}, },
"client_2": { "no_services": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
},
UseOwnBlockedServices: true, UseOwnBlockedServices: true,
}, },
"client_3": { "services": {
BlockedServices: clientBlockedServices, BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: clientBlockedServices,
},
UseOwnBlockedServices: true, UseOwnBlockedServices: true,
}, },
"client_4": { "invalid_services": {
BlockedServices: invalidBlockedServices, BlockedServices: &filtering.BlockedServices{
Schedule: schedule.EmptyWeekly(),
IDs: invalidBlockedServices,
},
UseOwnBlockedServices: true,
},
"allow_all": {
BlockedServices: &filtering.BlockedServices{
Schedule: schedule.FullWeekly(),
IDs: clientBlockedServices,
},
UseOwnBlockedServices: true, UseOwnBlockedServices: true,
}, },
} }
testCases := []struct { testCases := []struct {
name string name string
ip net.IP
id string id string
setts *filtering.Settings
wantLen int wantLen int
}{{ }{{
name: "global_settings", name: "global_settings",
id: "client_1", id: "default",
wantLen: len(globalBlockedServices), wantLen: len(globalBlockedServices),
}, { }, {
name: "custom_settings", name: "custom_settings",
id: "client_2", id: "no_services",
wantLen: 0, wantLen: 0,
}, { }, {
name: "custom_settings_block", name: "custom_settings_block",
id: "client_3", id: "services",
wantLen: len(clientBlockedServices), wantLen: len(clientBlockedServices),
}, { }, {
name: "custom_settings_invalid", name: "custom_settings_invalid",
id: "client_4", id: "invalid_services",
wantLen: 0,
}, {
name: "custom_settings_inactive_schedule",
id: "allow_all",
wantLen: 0, wantLen: 0,
}} }}

View File

@ -355,13 +355,17 @@ func initContextClients() (err error) {
arpdb = aghnet.NewARPDB() arpdb = aghnet.NewARPDB()
} }
Context.clients.Init( err = Context.clients.Init(
config.Clients.Persistent, config.Clients.Persistent,
Context.dhcpServer, Context.dhcpServer,
Context.etcHosts, Context.etcHosts,
arpdb, arpdb,
config.DNS.DnsfilterConf, config.DNS.DnsfilterConf,
) )
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
return nil return nil
} }

View File

@ -228,7 +228,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
w.Reset() w.Reset()
cc := newClientsContainer() cc := newClientsContainer(t)
ch := make(chan netip.Addr) ch := make(chan netip.Addr)
rdns := &RDNS{ rdns := &RDNS{
exchanger: &rDNSExchanger{ exchanger: &rDNSExchanger{

View File

@ -22,7 +22,7 @@ import (
) )
// currentSchemaVersion is the current schema version. // currentSchemaVersion is the current schema version.
const currentSchemaVersion = 21 const currentSchemaVersion = 22
// These aliases are provided for convenience. // These aliases are provided for convenience.
type ( type (
@ -95,6 +95,7 @@ func upgradeConfigSchema(oldVersion int, diskConf yobj) (err error) {
upgradeSchema18to19, upgradeSchema18to19,
upgradeSchema19to20, upgradeSchema19to20,
upgradeSchema20to21, upgradeSchema20to21,
upgradeSchema21to22,
} }
n := 0 n := 0
@ -1179,6 +1180,82 @@ func upgradeSchema20to21(diskConf yobj) (err error) {
return nil return nil
} }
// upgradeSchema21to22 performs the following changes:
//
// # BEFORE:
// 'persistent':
// - 'name': 'client_name'
// 'blocked_services':
// - 'svc_name'
//
// # AFTER:
// 'persistent':
// - 'name': 'client_name'
// 'blocked_services':
// 'ids':
// - 'svc_name'
// 'schedule':
// 'time_zone': 'Local'
func upgradeSchema21to22(diskConf yobj) (err error) {
log.Println("Upgrade yaml: 21 to 22")
diskConf["schema_version"] = 22
const field = "blocked_services"
clientsVal, ok := diskConf["clients"]
if !ok {
return nil
}
clients, ok := clientsVal.(yobj)
if !ok {
return fmt.Errorf("unexpected type of clients: %T", clientsVal)
}
persistentVal, ok := clients["persistent"]
if !ok {
return nil
}
persistent, ok := persistentVal.([]any)
if !ok {
return fmt.Errorf("unexpected type of persistent clients: %T", persistentVal)
}
for i, val := range persistent {
var c yobj
c, ok = val.(yobj)
if !ok {
return fmt.Errorf("persistent client at index %d: unexpected type %T", i, val)
}
var blockedVal any
blockedVal, ok = c[field]
if !ok {
continue
}
var services yarr
services, ok = blockedVal.(yarr)
if !ok {
return fmt.Errorf(
"persistent client at index %d: unexpected type of blocked services: %T",
i,
blockedVal,
)
}
c[field] = yobj{
"ids": services,
"schedule": yobj{
"time_zone": "Local",
},
}
}
return nil
}
// TODO(a.garipov): Replace with log.Output when we port it to our logging // TODO(a.garipov): Replace with log.Output when we port it to our logging
// package. // package.
func funcName() string { func funcName() string {

View File

@ -1183,3 +1183,73 @@ func TestUpgradeSchema20to21(t *testing.T) {
}) })
} }
} }
func TestUpgradeSchema21to22(t *testing.T) {
const newSchemaVer = 22
testCases := []struct {
in yobj
want yobj
name string
}{{
in: yobj{
"clients": yobj{},
},
want: yobj{
"clients": yobj{},
"schema_version": newSchemaVer,
},
name: "nothing",
}, {
in: yobj{
"clients": yobj{
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{}}},
},
},
want: yobj{
"clients": yobj{
"persistent": []any{yobj{
"name": "localhost",
"blocked_services": yobj{
"ids": yarr{},
"schedule": yobj{
"time_zone": "Local",
},
},
}},
},
"schema_version": newSchemaVer,
},
name: "no_services",
}, {
in: yobj{
"clients": yobj{
"persistent": []any{yobj{"name": "localhost", "blocked_services": yarr{"ok"}}},
},
},
want: yobj{
"clients": yobj{
"persistent": []any{yobj{
"name": "localhost",
"blocked_services": yobj{
"ids": yarr{"ok"},
"schedule": yobj{
"time_zone": "Local",
},
},
}},
},
"schema_version": newSchemaVer,
},
name: "services",
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := upgradeSchema21to22(tc.in)
require.NoError(t, err)
assert.Equal(t, tc.want, tc.in)
})
}
}

View File

@ -28,6 +28,36 @@ func EmptyWeekly() (w *Weekly) {
} }
} }
// FullWeekly creates full weekly schedule with local time zone.
//
// TODO(s.chzhen): Consider moving into tests.
func FullWeekly() (w *Weekly) {
fullDay := dayRange{start: 0, end: maxDayRange}
return &Weekly{
location: time.Local,
days: [7]dayRange{
time.Sunday: fullDay,
time.Monday: fullDay,
time.Tuesday: fullDay,
time.Wednesday: fullDay,
time.Thursday: fullDay,
time.Friday: fullDay,
time.Saturday: fullDay,
},
}
}
// Clone returns a deep copy of a weekly.
func (w *Weekly) Clone() (c *Weekly) {
// NOTE: Do not use time.LoadLocation, because the results will be
// different on time zone database update.
return &Weekly{
location: w.location,
days: w.days,
}
}
// Contains returns true if t is within the corresponding day range of the // Contains returns true if t is within the corresponding day range of the
// schedule in the schedule's time zone. // schedule in the schedule's time zone.
func (w *Weekly) Contains(t time.Time) (ok bool) { func (w *Weekly) Contains(t time.Time) (ok bool) {