diff --git a/dnsfilter/dnsfilter.go b/dnsfilter/dnsfilter.go index 3e2b8f0c..4744ef5e 100644 --- a/dnsfilter/dnsfilter.go +++ b/dnsfilter/dnsfilter.go @@ -10,6 +10,7 @@ import ( "strings" "sync" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" @@ -53,6 +54,9 @@ type Config struct { // Per-client settings can override this configuration. BlockedServices []string `yaml:"blocked_services"` + // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files + AutoHosts *util.AutoHosts `yaml:"-"` + // Called when the configuration is changed by HTTP request ConfigModified func() `yaml:"-"` @@ -139,6 +143,9 @@ const ( // ReasonRewrite - rewrite rule was applied ReasonRewrite + + // RewriteEtcHosts - rewrite by /etc/hosts rule + RewriteEtcHosts ) var reasonNames = []string{ @@ -154,6 +161,7 @@ var reasonNames = []string{ "FilteredBlockedService", "Rewrite", + "RewriteEtcHosts", } func (r Reason) String() string { @@ -303,6 +311,15 @@ func (d *Dnsfilter) CheckHost(host string, qtype uint16, setts *RequestFiltering return result, nil } + if d.Config.AutoHosts != nil { + ips := d.Config.AutoHosts.Process(host) + if ips != nil { + result.Reason = RewriteEtcHosts + result.IPList = ips + return result, nil + } + } + // try filter lists first if setts.FilteringEnabled { result, err = d.matchHost(host, qtype, setts.ClientTags) diff --git a/dnsfilter/dnsfilter_test.go b/dnsfilter/dnsfilter_test.go index 963926a9..532e1c3c 100644 --- a/dnsfilter/dnsfilter_test.go +++ b/dnsfilter/dnsfilter_test.go @@ -3,6 +3,7 @@ package dnsfilter import ( "fmt" "net" + "os" "path" "runtime" "testing" @@ -621,6 +622,13 @@ func TestRewrites(t *testing.T) { assert.True(t, r.IPList[0].Equal(net.ParseIP("1.2.3.4"))) } +func prepareTestDir() string { + const dir = "./agh-test" + _ = os.RemoveAll(dir) + _ = os.MkdirAll(dir, 0755) + return dir +} + // BENCHMARKS func BenchmarkSafeBrowsing(b *testing.B) { diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 646f61ff..9fe3b200 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -665,7 +665,11 @@ func processFilteringAfterResponse(ctx *dnsContext) int { res := ctx.result var err error - if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 { + switch res.Reason { + case dnsfilter.ReasonRewrite: + if len(res.CanonName) == 0 { + break + } d.Req.Question[0] = ctx.origQuestion d.Res.Question[0] = ctx.origQuestion @@ -676,7 +680,14 @@ func processFilteringAfterResponse(ctx *dnsContext) int { d.Res.Answer = answer } - } else if res.Reason != dnsfilter.NotFilteredWhiteList && ctx.protectionEnabled { + case dnsfilter.RewriteEtcHosts: + case dnsfilter.NotFilteredWhiteList: + // nothing + + default: + if !ctx.protectionEnabled { + break + } origResp2 := d.Res ctx.result, err = s.filterDNSResponse(ctx) if err != nil { @@ -845,7 +856,8 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) { // log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule) d.Res = s.genDNSFilterMessage(d, &res) - } else if res.Reason == dnsfilter.ReasonRewrite && len(res.IPList) != 0 { + } else if (res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteEtcHosts) && + len(res.IPList) != 0 { resp := s.makeResponse(req) name := host diff --git a/go.mod b/go.mod index 18e8de6e..1ed9bedc 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/AdguardTeam/urlfilter v0.9.1 github.com/NYTimes/gziphandler v1.1.1 github.com/etcd-io/bbolt v1.3.3 + github.com/fsnotify/fsnotify v1.4.9 github.com/go-test/deep v1.0.4 // indirect github.com/gobuffalo/packr v1.19.0 github.com/joomcode/errorx v1.0.0 diff --git a/go.sum b/go.sum index 837f305e..5861d329 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/etcd-io/bbolt v1.3.3 h1:gSJmxrs37LgTqR/oyJBWok6k6SvXEUerFTbltIhXkBM= github.com/etcd-io/bbolt v1.3.3/go.mod h1:ZF2nL25h33cCyBtcyWeZ2/I3HQOfTP+0PIEvHjkjCrw= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= @@ -181,6 +183,7 @@ golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191002091554-b397fe3ad8ed h1:5TJcLJn2a55mJjzYk0yOoqN8X1OdvBDUnaZaKKyQtkY= golang.org/x/sys v0.0.0-20191002091554-b397fe3ad8ed/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8 h1:JA8d3MPx/IToSyXZG/RhwYEtfrKO1Fxrqe8KrkiLXKM= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= diff --git a/home/clients.go b/home/clients.go index c7571a5f..b943d25d 100644 --- a/home/clients.go +++ b/home/clients.go @@ -3,9 +3,7 @@ package home import ( "bytes" "fmt" - "io/ioutil" "net" - "os" "os/exec" "runtime" "sort" @@ -16,6 +14,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/dhcpd" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" + "github.com/AdguardTeam/AdGuardHome/util" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/utils" @@ -79,12 +78,14 @@ type clientsContainer struct { // dhcpServer is used for looking up clients IP addresses by MAC addresses dhcpServer *dhcpd.Server + autoHosts *util.AutoHosts // get entries from system hosts-files + testing bool // if TRUE, this object is used for internal tests } // Init initializes clients container // Note: this function must be called only once -func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.Server) { +func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd.Server, autoHosts *util.AutoHosts) { if clients.list != nil { log.Fatal("clients.list != nil") } @@ -98,11 +99,13 @@ func (clients *clientsContainer) Init(objects []clientObject, dhcpServer *dhcpd. } clients.dhcpServer = dhcpServer + clients.autoHosts = autoHosts clients.addFromConfig(objects) if !clients.testing { clients.addFromDHCP() clients.dhcpServer.SetOnLeaseChanged(clients.onDHCPLeaseChanged) + clients.autoHosts.SetOnChanged(clients.onHostsChanged) } } @@ -120,7 +123,6 @@ func (clients *clientsContainer) Start() { // Reload - reload auto-clients func (clients *clientsContainer) Reload() { - clients.addFromHostsFile() clients.addFromSystemARP() } @@ -225,6 +227,10 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) { } } +func (clients *clientsContainer) onHostsChanged() { + clients.addFromHostsFile() +} + // Exists checks if client with this IP already exists func (clients *clientsContainer) Exists(ip string, source clientSource) bool { clients.lock.Lock() @@ -605,46 +611,28 @@ func (clients *clientsContainer) rmHosts(source clientSource) int { return n } -// Parse system 'hosts' file and fill clients array +// Fill clients array from system hosts-file func (clients *clientsContainer) addFromHostsFile() { - hostsFn := "/etc/hosts" - if runtime.GOOS == "windows" { - hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts") - } - - d, e := ioutil.ReadFile(hostsFn) - if e != nil { - log.Info("Can't read file %s: %v", hostsFn, e) - return - } + hosts := clients.autoHosts.List() clients.lock.Lock() defer clients.lock.Unlock() _ = clients.rmHosts(ClientSourceHostsFile) - lines := strings.Split(string(d), "\n") n := 0 - for _, ln := range lines { - ln = strings.TrimSpace(ln) - if len(ln) == 0 || ln[0] == '#' { - continue - } - - fields := strings.Fields(ln) - if len(fields) < 2 { - continue - } - - ok, e := clients.addHost(fields[0], fields[1], ClientSourceHostsFile) - if e != nil { - log.Tracef("%s", e) - } - if ok { - n++ + for ip, names := range hosts { + for _, name := range names { + ok, err := clients.addHost(ip, name.String(), ClientSourceHostsFile) + if err != nil { + log.Debug("Clients: %s", err) + } + if ok { + n++ + } } } - log.Debug("Clients: added %d client aliases from %s", n, hostsFn) + log.Debug("Clients: added %d client aliases from system hosts-file", n) } // Add IP -> Host pairs from the system's `arp -a` command output diff --git a/home/clients_test.go b/home/clients_test.go index 4468de35..50b96121 100644 --- a/home/clients_test.go +++ b/home/clients_test.go @@ -18,7 +18,7 @@ func TestClients(t *testing.T) { clients := clientsContainer{} clients.testing = true - clients.Init(nil, nil) + clients.Init(nil, nil, nil) // add c = Client{ @@ -156,7 +156,7 @@ func TestClientsWhois(t *testing.T) { var c Client clients := clientsContainer{} clients.testing = true - clients.Init(nil, nil) + clients.Init(nil, nil, nil) whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}} // set whois info on new client @@ -183,7 +183,7 @@ func TestClientsAddExisting(t *testing.T) { var c Client clients := clientsContainer{} clients.testing = true - clients.Init(nil, nil) + clients.Init(nil, nil, nil) // some test variables mac, _ := net.ParseMAC("aa:aa:aa:aa:aa:aa") diff --git a/home/dns.go b/home/dns.go index c9cfb513..e19079ce 100644 --- a/home/dns.go +++ b/home/dns.go @@ -56,6 +56,7 @@ func initDNSServer() error { bindhost = "127.0.0.1" } filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) + filterConf.AutoHosts = &Context.autoHosts filterConf.ConfigModified = onConfigModified filterConf.HTTPRegister = httpRegister Context.dnsFilter = dnsfilter.New(&filterConf, nil) diff --git a/home/home.go b/home/home.go index 17d05d41..69447bc6 100644 --- a/home/home.go +++ b/home/home.go @@ -68,6 +68,7 @@ type homeContext struct { filters Filtering // DNS filtering module web *Web // Web (HTTP, HTTPS) module tls *TLSMod // TLS module + autoHosts util.AutoHosts // IP-hostname pairs taken from system configuration (e.g. /etc/hosts) files // Runtime properties // -- @@ -210,7 +211,8 @@ func run(args options) { if Context.dhcpServer == nil { os.Exit(1) } - Context.clients.Init(config.Clients, Context.dhcpServer) + Context.autoHosts.Init("") + Context.clients.Init(config.Clients, Context.dhcpServer, &Context.autoHosts) config.Clients = nil if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && @@ -270,6 +272,7 @@ func run(args options) { log.Fatalf("%s", err) } Context.tls.Start() + Context.autoHosts.Start() go func() { err := startDNSServer() @@ -434,6 +437,8 @@ func cleanup() { log.Error("Couldn't stop DHCP server: %s", err) } + Context.autoHosts.Close() + if Context.tls != nil { Context.tls.Close() Context.tls = nil diff --git a/home/service.go b/home/service.go index 2cb31ecb..05804d63 100644 --- a/home/service.go +++ b/home/service.go @@ -196,7 +196,7 @@ func handleServiceInstallCommand(s service.Service) { log.Fatal(err) } - if isOpenWrt() { + if util.IsOpenWrt() { // On OpenWrt it is important to run enable after the service installation // Otherwise, the service won't start on the system startup _, err := runInitdCommand("enable") @@ -223,7 +223,7 @@ Click on the link below and follow the Installation Wizard steps to finish setup // handleServiceStatusCommand handles service "uninstall" command func handleServiceUninstallCommand(s service.Service) { - if isOpenWrt() { + if util.IsOpenWrt() { // On OpenWrt it is important to run disable command first // as it will remove the symlink _, err := runInitdCommand("disable") @@ -270,7 +270,7 @@ func configureService(c *service.Config) { c.Option["SysvScript"] = sysvScript // On OpenWrt we're using a different type of sysvScript - if isOpenWrt() { + if util.IsOpenWrt() { c.Option["SysvScript"] = openWrtScript } } @@ -283,20 +283,6 @@ func runInitdCommand(action string) (int, error) { return code, err } -// isOpenWrt checks if OS is OpenWRT -func isOpenWrt() bool { - if runtime.GOOS != "linux" { - return false - } - - body, err := ioutil.ReadFile("/etc/os-release") - if err != nil { - return false - } - - return strings.Contains(string(body), "OpenWrt") -} - // Basically the same template as the one defined in github.com/kardianos/service // but with two additional keys - StandardOutPath and StandardErrorPath var launchdConfig = ` diff --git a/util/auto_hosts.go b/util/auto_hosts.go new file mode 100644 index 00000000..a2ea56ae --- /dev/null +++ b/util/auto_hosts.go @@ -0,0 +1,246 @@ +package util + +import ( + "bufio" + "io" + "io/ioutil" + "net" + "os" + "runtime" + "strings" + "sync" + + "github.com/AdguardTeam/golibs/log" + "github.com/fsnotify/fsnotify" +) + +type onChangedT func() + +// AutoHosts - automatic DNS records +type AutoHosts struct { + lock sync.Mutex // serialize access to table + table map[string][]net.IP // 'hostname -> IP' table + hostsFn string // path to the main hosts-file + hostsDirs []string // paths to OS-specific directories with hosts-files + watcher *fsnotify.Watcher // file and directory watcher object + updateChan chan bool // signal for 'update' goroutine + + onChanged onChangedT // notification to other modules +} + +// SetOnChanged - set callback function that will be called when the data is changed +func (a *AutoHosts) SetOnChanged(onChanged onChangedT) { + a.onChanged = onChanged +} + +// Notify other modules +func (a *AutoHosts) notify() { + if a.onChanged == nil { + return + } + a.onChanged() +} + +// Init - initialize +// hostsFn: Override default name for the hosts-file (optional) +func (a *AutoHosts) Init(hostsFn string) { + a.table = make(map[string][]net.IP) + a.updateChan = make(chan bool, 2) + + a.hostsFn = "/etc/hosts" + if runtime.GOOS == "windows" { + a.hostsFn = os.ExpandEnv("$SystemRoot\\system32\\drivers\\etc\\hosts") + } + if len(hostsFn) != 0 { + a.hostsFn = hostsFn + } + + if IsOpenWrt() { + a.hostsDirs = append(a.hostsDirs, "/tmp/hosts") // OpenWRT: "/tmp/hosts/dhcp.cfg01411c" + } + + var err error + a.watcher, err = fsnotify.NewWatcher() + if err != nil { + log.Error("AutoHosts: %s", err) + } +} + +// Start - start module +func (a *AutoHosts) Start() { + go a.update() + a.updateChan <- true + + go a.watcherLoop() + + err := a.watcher.Add(a.hostsFn) + if err != nil { + log.Error("AutoHosts: %s", err) + } + + for _, dir := range a.hostsDirs { + err = a.watcher.Add(dir) + if err != nil { + log.Error("AutoHosts: %s", err) + } + } +} + +// Close - close module +func (a *AutoHosts) Close() { + a.updateChan <- false + a.watcher.Close() +} + +// Read IP-hostname pairs from file +// Multiple hostnames per line (per one IP) is supported. +func (a *AutoHosts) load(table map[string][]net.IP, fn string) { + f, err := os.Open(fn) + if err != nil { + log.Error("AutoHosts: %s", err) + return + } + defer f.Close() + r := bufio.NewReader(f) + log.Debug("AutoHosts: loading hosts from file %s", fn) + + finish := false + for !finish { + line, err := r.ReadString('\n') + if err == io.EOF { + finish = true + } else if err != nil { + log.Error("AutoHosts: %s", err) + return + } + line = strings.TrimSpace(line) + + ip := SplitNext(&line, ' ') + ipAddr := net.ParseIP(ip) + if ipAddr == nil { + continue + } + for { + host := SplitNext(&line, ' ') + if len(host) == 0 { + break + } + ips, ok := table[host] + if ok { + for _, ip := range ips { + if ip.Equal(ipAddr) { + // IP already exists: don't add duplicates + ok = false + break + } + } + if !ok { + ips = append(ips, ipAddr) + table[host] = ips + } + } else { + table[host] = []net.IP{ipAddr} + ok = true + } + if ok { + log.Debug("AutoHosts: added %s -> %s", ip, host) + } + } + } +} + +// Receive notifications from fsnotify package +func (a *AutoHosts) watcherLoop() { + for { + select { + + case event, ok := <-a.watcher.Events: + if !ok { + return + } + + // skip duplicate events + repeat := true + for repeat { + select { + case _ = <-a.watcher.Events: + // skip this event + default: + repeat = false + } + } + + if event.Op&fsnotify.Write == fsnotify.Write { + log.Debug("AutoHosts: modified: %s", event.Name) + select { + case a.updateChan <- true: + // sent a signal to 'update' goroutine + default: + // queue is full + } + } + + case err, ok := <-a.watcher.Errors: + if !ok { + return + } + log.Error("AutoHosts: %s", err) + } + } +} + +// Read static hosts from system files +func (a *AutoHosts) update() { + for { + select { + case ok := <-a.updateChan: + if !ok { + return + } + + table := make(map[string][]net.IP) + + a.load(table, a.hostsFn) + + for _, dir := range a.hostsDirs { + fis, err := ioutil.ReadDir(dir) + if err != nil { + if !os.IsNotExist(err) { + log.Error("AutoHosts: Opening directory: %s: %s", dir, err) + } + continue + } + + for _, fi := range fis { + a.load(table, dir+"/"+fi.Name()) + } + } + + a.lock.Lock() + a.table = table + a.lock.Unlock() + a.notify() + } + } +} + +// Process - get the list of IP addresses for the hostname +func (a *AutoHosts) Process(host string) []net.IP { + a.lock.Lock() + ips, _ := a.table[host] + ipsCopy := make([]net.IP, len(ips)) + copy(ipsCopy, ips) + a.lock.Unlock() + return ipsCopy +} + +// List - get the hosts table. Thread-safe. +func (a *AutoHosts) List() map[string][]net.IP { + table := make(map[string][]net.IP) + a.lock.Lock() + for k, v := range a.table { + table[k] = v + } + a.lock.Unlock() + return table +} diff --git a/util/auto_hosts_test.go b/util/auto_hosts_test.go new file mode 100644 index 00000000..6243b5ca --- /dev/null +++ b/util/auto_hosts_test.go @@ -0,0 +1,54 @@ +package util + +import ( + "io/ioutil" + "net" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func prepareTestDir() string { + const dir = "./agh-test" + _ = os.RemoveAll(dir) + _ = os.MkdirAll(dir, 0755) + return dir +} + +func TestAutoHosts(t *testing.T) { + ah := AutoHosts{} + + dir := prepareTestDir() + defer func() { _ = os.RemoveAll(dir) }() + + f, _ := ioutil.TempFile(dir, "") + defer os.Remove(f.Name()) + defer f.Close() + + _, _ = f.WriteString(" 127.0.0.1 host localhost \n") + + ah.Init(f.Name()) + ah.Start() + // wait until we parse the file + time.Sleep(50 * time.Millisecond) + + ips := ah.Process("localhost") + assert.True(t, ips[0].Equal(net.ParseIP("127.0.0.1"))) + ips = ah.Process("newhost") + assert.True(t, len(ips) == 0) + + table := ah.List() + ips, _ = table["host"] + assert.True(t, ips[0].String() == "127.0.0.1") + + _, _ = f.WriteString("127.0.0.2 newhost\n") + // wait until fsnotify has triggerred and processed the file-modification event + time.Sleep(50 * time.Millisecond) + + ips = ah.Process("newhost") + assert.True(t, ips[0].Equal(net.ParseIP("127.0.0.2"))) + + ah.Close() +} diff --git a/util/helpers.go b/util/helpers.go index 45b3311e..27ac4d71 100644 --- a/util/helpers.go +++ b/util/helpers.go @@ -2,6 +2,7 @@ package util import ( "fmt" + "io/ioutil" "os" "os/exec" "path" @@ -37,6 +38,7 @@ func FuncName() string { } // SplitNext - split string by a byte and return the first chunk +// Skip empty chunks // Whitespace is trimmed func SplitNext(str *string, splitBy byte) string { i := strings.IndexByte(*str, splitBy) @@ -44,6 +46,14 @@ func SplitNext(str *string, splitBy byte) string { if i != -1 { s = (*str)[0:i] *str = (*str)[i+1:] + k := 0 + ch := rune(0) + for k, ch = range *str { + if byte(ch) != splitBy { + break + } + } + *str = (*str)[k:] } else { s = *str *str = "" @@ -58,3 +68,17 @@ func MinInt(a, b int) int { } return b } + +// IsOpenWrt checks if OS is OpenWRT +func IsOpenWrt() bool { + if runtime.GOOS != "linux" { + return false + } + + body, err := ioutil.ReadFile("/etc/os-release") + if err != nil { + return false + } + + return strings.Contains(string(body), "OpenWrt") +}