Protect against users deleting the filter ID's in the config file.

Incidentally, it also simplifies upgrade schema from 0 to 1.
This commit is contained in:
Eugene Bujak 2018-11-27 21:25:03 +03:00
parent 6cb991fe7f
commit 701fd10c1c
4 changed files with 57 additions and 44 deletions

39
app.go
View File

@ -9,6 +9,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"time"
"github.com/gobuffalo/packr" "github.com/gobuffalo/packr"
"golang.org/x/crypto/ssh/terminal" "golang.org/x/crypto/ssh/terminal"
@ -135,6 +136,34 @@ func main() {
} }
} }
// Load filters from the disk
// And if any filter has zero ID, assign a new one
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we're operating on a copy
if filter.ID == 0 {
filter.ID = assignUniqueFilterID()
}
err := filter.load()
if err != nil {
// This is okay for the first start, the filter will be loaded later
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
// clear LastUpdated so it gets fetched right away
}
if len(filter.Contents) == 0 {
filter.LastUpdated = time.Time{}
}
}
// Update filters we've just loaded right away, don't wait for periodic update timer
go func() {
checkFiltersUpdates(false)
// Save the updated config
err := writeConfig()
if err != nil {
log.Fatal(err)
}
}()
// Eat all args so that coredns can start happily // Eat all args so that coredns can start happily
if len(os.Args) > 1 { if len(os.Args) > 1 {
os.Args = os.Args[:1] os.Args = os.Args[:1]
@ -146,16 +175,6 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
// Load filters from the disk
for i := range config.Filters {
filter := &config.Filters[i]
err = filter.load()
if err != nil {
// This is okay for the first start, the filter will be loaded later
log.Printf("Couldn't load filter %d contents due to %s", filter.ID, err)
}
}
address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort)) address := net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
runFiltersUpdatesTimer() runFiltersUpdatesTimer()

View File

@ -22,7 +22,7 @@ const (
) )
// Just a counter that we use for incrementing the filter ID // Just a counter that we use for incrementing the filter ID
var NextFilterId = time.Now().Unix() var nextFilterID int64 = time.Now().Unix()
// configuration is loaded from YAML // configuration is loaded from YAML
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
@ -74,7 +74,7 @@ type filter struct {
Name string `json:"name" yaml:"name"` Name string `json:"name" yaml:"name"`
RulesCount int `json:"rulesCount" yaml:"-"` RulesCount int `json:"rulesCount" yaml:"-"`
LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"` LastUpdated time.Time `json:"lastUpdated,omitempty" yaml:"last_updated,omitempty"`
ID int64 // auto-assigned when filter is added (see NextFilterId) ID int64 // auto-assigned when filter is added (see nextFilterID)
Contents []byte `json:"-" yaml:"-"` // not in yaml or json Contents []byte `json:"-" yaml:"-"` // not in yaml or json
} }
@ -165,12 +165,7 @@ func parseConfig() error {
config.Filters = config.Filters[:i] config.Filters = config.Filters[:i]
} }
// Set the next filter ID to max(filter.ID) + 1 updateUniqueFilterID(config.Filters)
for i := range config.Filters {
if NextFilterId < config.Filters[i].ID {
NextFilterId = config.Filters[i].ID + 1
}
}
return nil return nil
} }
@ -293,3 +288,18 @@ func generateCoreDNSConfigText() (string, error) {
configText = removeEmptyLines.ReplaceAllString(configText, "\n") configText = removeEmptyLines.ReplaceAllString(configText, "\n")
return configText, nil return configText, nil
} }
// Set the next filter ID to max(filter.ID) + 1
func updateUniqueFilterID(filters []filter) {
for _, filter := range filters {
if nextFilterID < filter.ID {
nextFilterID = filter.ID + 1
}
}
}
func assignUniqueFilterID() int64 {
value := nextFilterID
nextFilterID += 1
return value
}

View File

@ -343,9 +343,8 @@ func handleFilteringAddURL(w http.ResponseWriter, r *http.Request) {
} }
// Set necessary properties // Set necessary properties
filter.ID = NextFilterId filter.ID = assignUniqueFilterID()
filter.Enabled = true filter.Enabled = true
NextFilterId++
// Download the filter contents // Download the filter contents
ok, err := filter.update(true) ok, err := filter.update(true)
@ -550,6 +549,11 @@ func checkFiltersUpdates(force bool) int {
updateCount := 0 updateCount := 0
for i := range config.Filters { for i := range config.Filters {
filter := &config.Filters[i] // otherwise we will be operating on a copy filter := &config.Filters[i] // otherwise we will be operating on a copy
if filter.ID == 0 { // protect against users modifying the yaml and removing the ID
filter.ID = assignUniqueFilterID()
}
updated, err := filter.update(force) updated, err := filter.update(force)
if err != nil { if err != nil {
log.Printf("Failed to update filter %s: %s\n", filter.URL, err) log.Printf("Failed to update filter %s: %s\n", filter.URL, err)
@ -601,6 +605,9 @@ func parseFilterContents(contents []byte) (int, string) {
// If "force" is true -- does not check the filter's LastUpdated field // If "force" is true -- does not check the filter's LastUpdated field
// Call "save" to persist the filter contents // Call "save" to persist the filter contents
func (filter *filter) update(force bool) (bool, error) { func (filter *filter) update(force bool) (bool, error) {
if filter.ID == 0 { // protect against users deleting the ID
filter.ID = assignUniqueFilterID()
}
if !filter.Enabled { if !filter.Enabled {
return false, nil return false, nil
} }

View File

@ -87,30 +87,7 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
trace("Called") trace("Called")
// The first schema upgrade: // The first schema upgrade:
// Added "ID" field to "filter" -- we need to populate this field now // No more "dnsfilter.txt", filters are now kept in data/filters/
// Added "config.ourDataDir" -- where we will now store filters contents
for i := range config.Filters {
filter := &config.Filters[i] // otherwise we will be operating on a copy
// Set the filter ID
log.Printf("Seting ID=%d for filter %s", NextFilterId, filter.URL)
filter.ID = NextFilterId
NextFilterId++
// Forcibly update the filter
_, err := filter.update(true)
if err != nil {
log.Fatal(err)
}
// Saving it to the filters dir now
err = filter.save()
if err != nil {
log.Fatal(err)
}
}
// No more "dnsfilter.txt", filters are now loaded from config.ourDataDir/filters/
dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt") dnsFilterPath := filepath.Join(config.ourBinaryDir, "dnsfilter.txt")
_, err := os.Stat(dnsFilterPath) _, err := os.Stat(dnsFilterPath)
if !os.IsNotExist(err) { if !os.IsNotExist(err) {