Merge branch 'master' into feature/update_locales

This commit is contained in:
ArtemBaskal 2020-09-14 20:22:55 +03:00
commit a5e189bdf3
22 changed files with 290 additions and 199 deletions

View File

@ -253,7 +253,7 @@
"rate_limit": "Rate limit",
"edns_enable": "Enable EDNS Client Subnet",
"edns_cs_desc": "If enabled, AdGuard Home will be sending clients' subnets to the DNS servers.",
"rate_limit_desc": "The number of requests per second that a single client is allowed to make (0: unlimited)",
"rate_limit_desc": "The number of requests per second that a single client is allowed to make (setting it to 0 means unlimited)",
"blocking_ipv4_desc": "IP address to be returned for a blocked A request",
"blocking_ipv6_desc": "IP address to be returned for a blocked AAAA request",
"blocking_mode_default": "Default: Respond with REFUSED when blocked by Adblock-style rule; respond with the IP address specified in the rule when blocked by /etc/hosts-style rule",
@ -564,10 +564,9 @@
"enter_cache_size": "Enter cache size",
"enter_cache_ttl_min_override": "Enter minimum TTL",
"enter_cache_ttl_max_override": "Enter maximum TTL",
"cache_ttl_min_override_desc": "Override TTL value (minimum) received from upstream server. This value can't larger than 3600 (1 hour)",
"cache_ttl_min_override_desc": "Override TTL value (minimum) received from upstream server",
"cache_ttl_max_override_desc": "Override TTL value (maximum) received from upstream server",
"min_exceeds_max_value": "Minimum value exceeds maximum value",
"value_not_larger_than": "Value can't be larger than {{maximum}}",
"ttl_cache_validation": "Minimum cache TTL value must be less than or equal to the maximum value",
"filter_category_general": "General",
"filter_category_security": "Security",
"filter_category_regional": "Regional",

View File

@ -26,3 +26,10 @@
left: -20px;
width: calc(100% + 20px);
}
@media (max-width: 1279.98px) {
.table__action {
position: absolute;
right: 0;
}
}

View File

@ -8,10 +8,9 @@ import {
renderInputField,
toNumber,
} from '../../../helpers/form';
import { FORM_NAME } from '../../../helpers/constants';
import { FORM_NAME, UINT32_RANGE } from '../../../helpers/constants';
import {
validateIpv4,
validateIsPositiveValue,
validateRequiredValue,
validateIpv4RangeEnd,
} from '../../../helpers/validators';
@ -110,9 +109,10 @@ const FormDHCPv4 = ({
type="number"
className="form-control"
placeholder={t(ipv4placeholders.lease_duration)}
validate={[validateIsPositiveValue, validateRequired]}
validate={validateRequired}
normalize={toNumber}
min={0}
min={1}
max={UINT32_RANGE.MAX}
disabled={!isInterfaceIncludesIpv4}
/>
</div>

View File

@ -8,12 +8,8 @@ import {
renderInputField,
toNumber,
} from '../../../helpers/form';
import { FORM_NAME } from '../../../helpers/constants';
import {
validateIpv6,
validateIsPositiveValue,
validateRequiredValue,
} from '../../../helpers/validators';
import { FORM_NAME, UINT32_RANGE } from '../../../helpers/constants';
import { validateIpv6, validateRequiredValue } from '../../../helpers/validators';
const FormDHCPv6 = ({
handleSubmit,
@ -86,9 +82,10 @@ const FormDHCPv6 = ({
type="number"
className="form-control"
placeholder={t(ipv6placeholders.lease_duration)}
validate={[validateIsPositiveValue, validateRequired]}
validate={validateRequired}
normalizeOnBlur={toNumber}
min={0}
min={1}
max={UINT32_RANGE.MAX}
disabled={!isInterfaceIncludesIpv6}
/>
</div>

View File

@ -4,32 +4,30 @@ import { Field, reduxForm } from 'redux-form';
import { Trans, useTranslation } from 'react-i18next';
import { shallowEqual, useSelector } from 'react-redux';
import { renderInputField, toNumber } from '../../../../helpers/form';
import { validateBiggerOrEqualZeroValue, getMaxValueValidator, validateRequiredValue } from '../../../../helpers/validators';
import { FORM_NAME, SECONDS_IN_HOUR } from '../../../../helpers/constants';
import { validateRequiredValue } from '../../../../helpers/validators';
import { CACHE_CONFIG_FIELDS, FORM_NAME, UINT32_RANGE } from '../../../../helpers/constants';
const validateMaxValue3600 = getMaxValueValidator(SECONDS_IN_HOUR);
const getInputFields = ({ validateRequiredValue, validateMaxValue3600 }) => [{
name: 'cache_size',
title: 'cache_size',
description: 'cache_size_desc',
placeholder: 'enter_cache_size',
validate: validateRequiredValue,
},
{
name: 'cache_ttl_min',
title: 'cache_ttl_min_override',
description: 'cache_ttl_min_override_desc',
placeholder: 'enter_cache_ttl_min_override',
max: SECONDS_IN_HOUR,
validate: validateMaxValue3600,
},
{
name: 'cache_ttl_max',
title: 'cache_ttl_max_override',
description: 'cache_ttl_max_override_desc',
placeholder: 'enter_cache_ttl_max_override',
}];
const getInputFields = (validateRequiredValue) => [
{
name: CACHE_CONFIG_FIELDS.cache_size,
title: 'cache_size',
description: 'cache_size_desc',
placeholder: 'enter_cache_size',
validate: validateRequiredValue,
},
{
name: CACHE_CONFIG_FIELDS.cache_ttl_min,
title: 'cache_ttl_min_override',
description: 'cache_ttl_min_override_desc',
placeholder: 'enter_cache_ttl_min_override',
},
{
name: CACHE_CONFIG_FIELDS.cache_ttl_max,
title: 'cache_ttl_max_override',
description: 'cache_ttl_max_override_desc',
placeholder: 'enter_cache_ttl_max_override',
},
];
const Form = ({
handleSubmit, submitting, invalid,
@ -41,17 +39,16 @@ const Form = ({
cache_ttl_max, cache_ttl_min,
} = useSelector((state) => state.form[FORM_NAME.CACHE].values, shallowEqual);
const minExceedsMax = cache_ttl_min > cache_ttl_max;
const minExceedsMax = typeof cache_ttl_min === 'number'
&& typeof cache_ttl_max === 'number'
&& cache_ttl_min > cache_ttl_max;
const INPUTS_FIELDS = getInputFields({
validateRequiredValue,
validateMaxValue3600,
});
const INPUTS_FIELDS = getInputFields(validateRequiredValue);
return <form onSubmit={handleSubmit}>
<div className="row">
{INPUTS_FIELDS.map(({
name, title, description, placeholder, validate, max,
name, title, description, placeholder, validate, min = 0, max = UINT32_RANGE.MAX,
}) => <div className="col-12" key={name}>
<div className="col-12 col-md-7 p-0">
<div className="form__group form__group--settings">
@ -66,15 +63,15 @@ const Form = ({
disabled={processingSetConfig}
normalize={toNumber}
className="form-control"
validate={[validateBiggerOrEqualZeroValue].concat(validate || [])}
min={0}
validate={validate}
min={min}
max={max}
/>
</div>
</div>
</div>)}
{minExceedsMax
&& <span className="text-danger pl-3 pb-3">{t('min_exceeds_max_value')}</span>}
&& <span className="text-danger pl-3 pb-3">{t('ttl_cache_validation')}</span>}
</div>
<button
type="submit"

View File

@ -1,10 +1,12 @@
import React from 'react';
import { useTranslation } from 'react-i18next';
import { shallowEqual, useDispatch, useSelector } from 'react-redux';
import { change } from 'redux-form';
import Card from '../../../ui/Card';
import Form from './Form';
import { setDnsConfig } from '../../../../actions/dnsConfig';
import { selectCompletedFields } from '../../../../helpers/helpers';
import { CACHE_CONFIG_FIELDS, FORM_NAME } from '../../../../helpers/constants';
const CacheConfig = () => {
const { t } = useTranslation();
@ -15,6 +17,15 @@ const CacheConfig = () => {
const handleFormSubmit = (values) => {
const completedFields = selectCompletedFields(values);
Object.entries(completedFields).forEach(([k, v]) => {
if ((k === CACHE_CONFIG_FIELDS.cache_ttl_min
|| k === CACHE_CONFIG_FIELDS.cache_ttl_max)
&& v === 0) {
dispatch(change(FORM_NAME.CACHE, k, ''));
}
});
dispatch(setDnsConfig(completedFields));
};

View File

@ -10,12 +10,11 @@ import {
toNumber,
} from '../../../../helpers/form';
import {
validateBiggerOrEqualZeroValue,
validateIpv4,
validateIpv6,
validateRequiredValue,
} from '../../../../helpers/validators';
import { BLOCKING_MODES, FORM_NAME } from '../../../../helpers/constants';
import { BLOCKING_MODES, FORM_NAME, UINT32_RANGE } from '../../../../helpers/constants';
const checkboxes = [
{
@ -87,7 +86,9 @@ const Form = ({
className="form-control"
placeholder={t('form_enter_rate_limit')}
normalize={toNumber}
validate={[validateRequiredValue, validateBiggerOrEqualZeroValue]}
validate={validateRequiredValue}
min={UINT32_RANGE.MIN}
max={UINT32_RANGE.MAX}
/>
</div>
</div>

View File

@ -506,9 +506,12 @@ export const FORM_NAME = {
export const SMALL_SCREEN_SIZE = 767;
export const MEDIUM_SCREEN_SIZE = 1023;
export const SECONDS_IN_HOUR = 60 * 60;
export const SECONDS_IN_DAY = 60 * 60 * 24;
export const SECONDS_IN_DAY = SECONDS_IN_HOUR * 24;
export const UINT32_RANGE = {
MIN: 0,
MAX: 4294967295,
};
export const DHCP_VALUES_PLACEHOLDERS = {
ipv4: {
@ -559,3 +562,9 @@ export const ADDRESS_TYPES = {
CIDR: 'CIDR',
UNKNOWN: 'UNKNOWN',
};
export const CACHE_CONFIG_FIELDS = {
cache_size: 'cache_size',
cache_ttl_min: 'cache_ttl_min',
cache_ttl_max: 'cache_ttl_max',
};

View File

@ -1,4 +1,3 @@
import i18next from 'i18next';
import {
MAX_PORT,
R_CIDR,
@ -28,17 +27,6 @@ export const validateRequiredValue = (value) => {
return 'form_error_required';
};
/**
* @param maximum {number}
* @returns {(value:number) => undefined|string}
*/
export const getMaxValueValidator = (maximum) => (value) => {
if (value && value > maximum) {
return i18next.t('value_not_larger_than', { maximum });
}
return undefined;
};
/**
* @param value {string}
* @returns {undefined|string}
@ -122,17 +110,6 @@ export const validateMac = (value) => {
return undefined;
};
/**
* @param value {number}
* @returns {undefined|string}
*/
export const validateIsPositiveValue = (value) => {
if ((value || value === 0) && value <= 0) {
return 'form_error_positive';
}
return undefined;
};
/**
* @param value {number}
* @returns {boolean|*}

View File

@ -214,6 +214,9 @@ func (s *Server) initDefaultSettings() {
if s.conf.TCPListenAddr == nil {
s.conf.TCPListenAddr = defaultValues.TCPListenAddr
}
if len(s.conf.BlockedHosts) == 0 {
s.conf.BlockedHosts = defaultBlockedHosts
}
}
// prepareUpstreamSettings - prepares upstream DNS server settings

View File

@ -31,6 +31,9 @@ var defaultDNS = []string{
}
var defaultBootstrap = []string{"9.9.9.10", "149.112.112.10", "2620:fe::10", "2620:fe::fe:10"}
// Often requested by all kinds of DNS probes
var defaultBlockedHosts = []string{"version.bind", "id.server", "hostname.bind"}
var webRegistered bool
// Server is the main way to start a DNS server.

View File

@ -9,13 +9,17 @@ import (
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io/ioutil"
"math/big"
"net"
"os"
"sort"
"sync"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/util"
"github.com/AdguardTeam/AdGuardHome/dhcpd"
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
"github.com/AdguardTeam/dnsproxy/proxy"
@ -664,17 +668,17 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
func TestRewrite(t *testing.T) {
c := dnsfilter.Config{}
c.Rewrites = []dnsfilter.RewriteEntry{
dnsfilter.RewriteEntry{
{
Domain: "test.com",
Answer: "1.2.3.4",
Type: dns.TypeA,
},
dnsfilter.RewriteEntry{
{
Domain: "alias.test.com",
Answer: "test.com",
Type: dns.TypeCNAME,
},
dnsfilter.RewriteEntry{
{
Domain: "my.alias.example.org",
Answer: "example.org",
Type: dns.TypeCNAME,
@ -1066,7 +1070,7 @@ func (d *testDHCP) SetOnLeaseChanged(onLeaseChanged dhcpd.OnLeaseChangedT) {
return
}
func TestPTRResponse(t *testing.T) {
func TestPTRResponseFromDHCPLeases(t *testing.T) {
dhcp := &testDHCP{}
c := dnsfilter.Config{}
@ -1094,3 +1098,45 @@ func TestPTRResponse(t *testing.T) {
s.Close()
}
func TestPTRResponseFromHosts(t *testing.T) {
c := dnsfilter.Config{
AutoHosts: &util.AutoHosts{},
}
// Prepare test hosts file
hf, _ := ioutil.TempFile("", "")
defer func() { _ = os.Remove(hf.Name()) }()
defer hf.Close()
_, _ = hf.WriteString(" 127.0.0.1 host # comment \n")
_, _ = hf.WriteString(" ::1 localhost#comment \n")
// Init auto hosts
c.AutoHosts.Init(hf.Name())
defer c.AutoHosts.Close()
f := dnsfilter.New(&c, nil)
s := NewServer(DNSCreateParams{DNSFilter: f})
s.conf.UDPListenAddr = &net.UDPAddr{Port: 0}
s.conf.TCPListenAddr = &net.TCPAddr{Port: 0}
s.conf.UpstreamDNS = []string{"127.0.0.1:53"}
s.conf.FilteringConfig.ProtectionEnabled = true
err := s.Prepare(nil)
assert.True(t, err == nil)
assert.Nil(t, s.Start())
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessage("1.0.0.127.in-addr.arpa.")
req.Question[0].Qtype = dns.TypePTR
resp, err := dns.Exchange(req, addr.String())
assert.Nil(t, err)
assert.Equal(t, 1, len(resp.Answer))
assert.Equal(t, dns.TypePTR, resp.Answer[0].Header().Rrtype)
assert.Equal(t, "1.0.0.127.in-addr.arpa.", resp.Answer[0].Header().Name)
ptr := resp.Answer[0].(*dns.PTR)
assert.Equal(t, "host.", ptr.Ptr)
s.Close()
}

View File

@ -51,7 +51,7 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
return nil, errorx.Decorate(err, "dnsfilter failed to check host '%s'", host)
} else if res.IsFiltered {
// log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule)
log.Tracef("Host %s is filtered, reason - '%s', matched rule: '%s'", host, res.Reason, res.Rule)
d.Res = s.genDNSFilterMessage(d, &res)
} else if res.Reason == dnsfilter.ReasonRewrite && len(res.CanonName) != 0 && len(res.IPList) == 0 {
@ -59,6 +59,19 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
// resolve canonical name, not the original host name
d.Req.Question[0].Name = dns.Fqdn(res.CanonName)
} else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHost) != 0 {
resp := s.makeResponse(req)
ptr := &dns.PTR{}
ptr.Hdr = dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
ptr.Ptr = res.ReverseHost
resp.Answer = append(resp.Answer, ptr)
d.Res = resp
} else if res.Reason == dnsfilter.ReasonRewrite || res.Reason == dnsfilter.RewriteEtcHosts {
resp := s.makeResponse(req)
@ -81,20 +94,6 @@ func (s *Server) filterDNSRequest(ctx *dnsContext) (*dnsfilter.Result, error) {
}
d.Res = resp
} else if res.Reason == dnsfilter.RewriteEtcHosts && len(res.ReverseHost) != 0 {
resp := s.makeResponse(req)
ptr := &dns.PTR{}
ptr.Hdr = dns.RR_Header{
Name: req.Question[0].Name,
Rrtype: dns.TypePTR,
Ttl: s.conf.BlockedResponseTTL,
Class: dns.ClassINET,
}
ptr.Ptr = res.ReverseHost
resp.Answer = append(resp.Answer, ptr)
d.Res = resp
}
return &res, err

2
go.mod
View File

@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
go 1.14
require (
github.com/AdguardTeam/dnsproxy v0.32.1
github.com/AdguardTeam/dnsproxy v0.32.5
github.com/AdguardTeam/golibs v0.4.2
github.com/AdguardTeam/urlfilter v0.12.2
github.com/NYTimes/gziphandler v1.1.1

4
go.sum
View File

@ -7,8 +7,8 @@ dmitri.shuralyov.com/html/belt v0.0.0-20180602232347-f7d459c86be0/go.mod h1:JLBr
dmitri.shuralyov.com/service/change v0.0.0-20181023043359-a85b471d5412/go.mod h1:a1inKt/atXimZ4Mv927x+r7UpyzRUf4emIoiiSC2TN4=
dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D6DFvNNtx+9ybjezNCa8XF0xaYcETyp6rHWU=
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
github.com/AdguardTeam/dnsproxy v0.32.1 h1:UoiFt/aT8YCBFUGe7hG8ehLRXyvoIf22mOQqeIQxhWI=
github.com/AdguardTeam/dnsproxy v0.32.1/go.mod h1:ZLDrKIypYxBDz2N9FQHgeehuHrwTbuhZXdGwNySshbw=
github.com/AdguardTeam/dnsproxy v0.32.5 h1:UiExd/uHt2UOL4tYg1+WfXXUlkxmlpnMnQiTs63PekQ=
github.com/AdguardTeam/dnsproxy v0.32.5/go.mod h1:ZLDrKIypYxBDz2N9FQHgeehuHrwTbuhZXdGwNySshbw=
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2 h1:7M28oTZFoFwNmp8eGPb3ImmYbxGaJLyQXeIFVHjME0o=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=

View File

@ -161,6 +161,9 @@ func run(args options) {
// configure log level and output
configureLogger(args)
// Go memory hacks
memoryUsage(args)
// print the first message after logger is configured
log.Println(version())
log.Debug("Current working directory is %s", Context.workDir)

44
home/memory.go Normal file
View File

@ -0,0 +1,44 @@
package home
import (
"os"
"runtime/debug"
"time"
"github.com/AdguardTeam/golibs/log"
)
// memoryUsage implements a couple of not really beautiful hacks which purpose is to
// make OS reclaim the memory freed by AdGuard Home as soon as possible.
// See this for the details on the performance hits & gains:
// https://github.com/AdguardTeam/AdGuardHome/issues/2044#issuecomment-687042211
func memoryUsage(args options) {
if args.disableMemoryOptimization {
log.Info("Memory optimization is disabled")
return
}
// Makes Go allocate heap at a slower pace
// By default we keep it at 50%
debug.SetGCPercent(50)
// madvdontneed: setting madvdontneed=1 will use MADV_DONTNEED
// instead of MADV_FREE on Linux when returning memory to the
// kernel. This is less efficient, but causes RSS numbers to drop
// more quickly.
_ = os.Setenv("GODEBUG", "madvdontneed=1")
// periodically call "debug.FreeOSMemory" so
// that the OS could reclaim the free memory
go func() {
ticker := time.NewTicker(5 * time.Minute)
for {
select {
case t := <-ticker.C:
t.Second()
log.Debug("Free OS memory")
debug.FreeOSMemory()
}
}
}()
}

View File

@ -24,7 +24,11 @@ type options struct {
// runningAsService flag is set to true when options are passed from the service runner
runningAsService bool
glinetMode bool // Activate GL-Inet mode
// disableMemoryOptimization - disables memory optimization hacks
// see memoryUsage() function for the details
disableMemoryOptimization bool
glinetMode bool // Activate GL-Inet compatibility mode
}
// functions used for their side-effects
@ -151,6 +155,13 @@ var noCheckUpdateArg = arg{
func(o options) []string { return boolSliceOrNil(o.disableUpdate) },
}
var disableMemoryOptimizationArg = arg{
"Disable memory optimization",
"no-mem-optimization", "",
nil, func(o options) (options, error) { o.disableMemoryOptimization = true; return o, nil }, nil,
func(o options) []string { return boolSliceOrNil(o.disableMemoryOptimization) },
}
var verboseArg = arg{
"Enable verbose output",
"verbose", "v",
@ -194,6 +205,7 @@ func init() {
pidfileArg,
checkConfigArg,
noCheckUpdateArg,
disableMemoryOptimizationArg,
verboseArg,
glinetArg,
versionArg,

View File

@ -140,6 +140,15 @@ func TestParseDisableUpdate(t *testing.T) {
}
}
func TestParseDisableMemoryOptimization(t *testing.T) {
if testParseOk(t).disableMemoryOptimization {
t.Fatal("empty is not disable update")
}
if !testParseOk(t, "--no-mem-optimization").disableMemoryOptimization {
t.Fatal("--no-mem-optimization is disable update")
}
}
func TestParseService(t *testing.T) {
if testParseOk(t).serviceControlAction != "" {
t.Fatal("empty is no service command")
@ -226,12 +235,17 @@ func TestSerializeGLInet(t *testing.T) {
testSerialize(t, options{glinetMode: true}, "--glinet")
}
func TestSerializeDisableMemoryOptimization(t *testing.T) {
testSerialize(t, options{disableMemoryOptimization: true}, "--no-mem-optimization")
}
func TestSerializeMultiple(t *testing.T) {
testSerialize(t, options{
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
}, "-c", "config", "-w", "work", "-s", "run", "--pidfile", "pid", "--no-check-update")
serviceControlAction: "run",
configFilename: "config",
workDir: "work",
pidFile: "pid",
disableUpdate: true,
disableMemoryOptimization: true,
}, "-c", "config", "-w", "work", "-s", "run", "--pidfile", "pid", "--no-check-update", "--no-mem-optimization")
}

31
main.go
View File

@ -4,10 +4,6 @@
package main
import (
"os"
"runtime/debug"
"time"
"github.com/AdguardTeam/AdGuardHome/home"
)
@ -21,32 +17,5 @@ var channel = "release"
var goarm = ""
func main() {
memoryUsage()
home.Main(version, channel, goarm)
}
// memoryUsage implements a couple of not really beautiful hacks which purpose is to
// make OS reclaim the memory freed by AdGuard Home as soon as possible.
func memoryUsage() {
debug.SetGCPercent(10)
// madvdontneed: setting madvdontneed=1 will use MADV_DONTNEED
// instead of MADV_FREE on Linux when returning memory to the
// kernel. This is less efficient, but causes RSS numbers to drop
// more quickly.
_ = os.Setenv("GODEBUG", "madvdontneed=1")
// periodically call "debug.FreeOSMemory" so
// that the OS could reclaim the free memory
go func() {
ticker := time.NewTicker(15 * time.Second)
for {
select {
case t := <-ticker.C:
t.Second()
debug.FreeOSMemory()
}
}
}()
}

View File

@ -10,9 +10,10 @@ import (
"strings"
"sync"
"github.com/miekg/dns"
"github.com/AdguardTeam/golibs/log"
"github.com/fsnotify/fsnotify"
"github.com/miekg/dns"
)
type onChangedT func()
@ -62,6 +63,9 @@ func (a *AutoHosts) Init(hostsFn string) {
a.hostsDirs = append(a.hostsDirs, "/tmp/hosts") // OpenWRT: "/tmp/hosts/dhcp.cfg01411c"
}
// Load hosts initially
a.updateHosts()
var err error
a.watcher, err = fsnotify.NewWatcher()
if err != nil {
@ -102,6 +106,62 @@ func (a *AutoHosts) Close() {
}
}
// Process - get the list of IP addresses for the hostname
// Return nil if not found
func (a *AutoHosts) Process(host string, qtype uint16) []net.IP {
if qtype == dns.TypePTR {
return nil
}
var ipsCopy []net.IP
a.lock.Lock()
ips, _ := a.table[host]
if len(ips) != 0 {
ipsCopy = make([]net.IP, len(ips))
copy(ipsCopy, ips)
}
a.lock.Unlock()
log.Debug("AutoHosts: answer: %s -> %v", host, ipsCopy)
return ipsCopy
}
// ProcessReverse - process PTR request
// Return "" if not found or an error occurred
func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string {
if qtype != dns.TypePTR {
return ""
}
ipReal := DNSUnreverseAddr(addr)
if ipReal == nil {
return "" // invalid IP in question
}
ipStr := ipReal.String()
a.lock.Lock()
host := a.tableReverse[ipStr]
a.lock.Unlock()
if len(host) == 0 {
return "" // not found
}
log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, host)
return host
}
// List - get "IP -> hostname" table. Thread-safe.
func (a *AutoHosts) List() map[string]string {
table := make(map[string]string)
a.lock.Lock()
for k, v := range a.tableReverse {
table[k] = v
}
a.lock.Unlock()
return table
}
// update table
func (a *AutoHosts) updateTable(table map[string][]net.IP, host string, ipAddr net.IP) {
ips, ok := table[host]
@ -275,59 +335,3 @@ func (a *AutoHosts) updateHosts() {
a.notify()
}
// Process - get the list of IP addresses for the hostname
// Return nil if not found
func (a *AutoHosts) Process(host string, qtype uint16) []net.IP {
if qtype == dns.TypePTR {
return nil
}
var ipsCopy []net.IP
a.lock.Lock()
ips, _ := a.table[host]
if len(ips) != 0 {
ipsCopy = make([]net.IP, len(ips))
copy(ipsCopy, ips)
}
a.lock.Unlock()
log.Debug("AutoHosts: answer: %s -> %v", host, ipsCopy)
return ipsCopy
}
// ProcessReverse - process PTR request
// Return "" if not found or an error occurred
func (a *AutoHosts) ProcessReverse(addr string, qtype uint16) string {
if qtype != dns.TypePTR {
return ""
}
ipReal := DNSUnreverseAddr(addr)
if ipReal == nil {
return "" // invalid IP in question
}
ipStr := ipReal.String()
a.lock.Lock()
host := a.tableReverse[ipStr]
a.lock.Unlock()
if len(host) == 0 {
return "" // not found
}
log.Debug("AutoHosts: reverse-lookup: %s -> %s", addr, host)
return host
}
// List - get "IP -> hostname" table. Thread-safe.
func (a *AutoHosts) List() map[string]string {
table := make(map[string]string)
a.lock.Lock()
for k, v := range a.tableReverse {
table[k] = v
}
a.lock.Unlock()
return table
}

View File

@ -34,9 +34,6 @@ func TestAutoHostsResolution(t *testing.T) {
ah.Init(f.Name())
// Update from the hosts file
ah.updateHosts()
// Existing host
ips := ah.Process("localhost", dns.TypeA)
assert.NotNil(t, ips)
@ -79,7 +76,6 @@ func TestAutoHostsFSNotify(t *testing.T) {
// Init
_, _ = f.WriteString(" 127.0.0.1 host localhost \n")
ah.Init(f.Name())
ah.updateHosts()
// Unknown host
ips := ah.Process("newhost", dns.TypeA)