all: sync with master
This commit is contained in:
parent
fbc0d981ba
commit
6f7bfd6c9c
|
@ -32,11 +32,15 @@ NOTE: Add new changes BELOW THIS COMMENT.
|
||||||
|
|
||||||
### Fixed
|
### Fixed
|
||||||
|
|
||||||
|
- Statistics for 7 days displayed by day on the dashboard graph ([#6712]).
|
||||||
|
- Missing "served from cache" label on long DNS server strings ([#6740]).
|
||||||
- Incorrect tracking of the system hosts file's changes ([#6711]).
|
- Incorrect tracking of the system hosts file's changes ([#6711]).
|
||||||
|
|
||||||
[#5992]: https://github.com/AdguardTeam/AdGuardHome/issues/5992
|
[#5992]: https://github.com/AdguardTeam/AdGuardHome/issues/5992
|
||||||
[#6610]: https://github.com/AdguardTeam/AdGuardHome/issues/6610
|
[#6610]: https://github.com/AdguardTeam/AdGuardHome/issues/6610
|
||||||
[#6711]: https://github.com/AdguardTeam/AdGuardHome/issues/6711
|
[#6711]: https://github.com/AdguardTeam/AdGuardHome/issues/6711
|
||||||
|
[#6712]: https://github.com/AdguardTeam/AdGuardHome/issues/6712
|
||||||
|
[#6740]: https://github.com/AdguardTeam/AdGuardHome/issues/6740
|
||||||
|
|
||||||
<!--
|
<!--
|
||||||
NOTE: Add new changes ABOVE THIS COMMENT.
|
NOTE: Add new changes ABOVE THIS COMMENT.
|
||||||
|
|
|
@ -473,6 +473,9 @@ bug or implementing the feature.
|
||||||
[@kongfl888](https://github.com/kongfl888) (originally by
|
[@kongfl888](https://github.com/kongfl888) (originally by
|
||||||
[@rufengsuixing](https://github.com/rufengsuixing)).
|
[@rufengsuixing](https://github.com/rufengsuixing)).
|
||||||
|
|
||||||
|
* [AdGuardHome sync](https://github.com/bakito/adguardhome-sync) by
|
||||||
|
[@bakito](https://github.com/bakito).
|
||||||
|
|
||||||
* [Terminal-based, real-time traffic monitoring and statistics for your AdGuard Home
|
* [Terminal-based, real-time traffic monitoring and statistics for your AdGuard Home
|
||||||
instance](https://github.com/Lissy93/AdGuardian-Term) by
|
instance](https://github.com/Lissy93/AdGuardian-Term) by
|
||||||
[@Lissy93](https://github.com/Lissy93)
|
[@Lissy93](https://github.com/Lissy93)
|
||||||
|
|
|
@ -122,8 +122,6 @@
|
||||||
# from the release branch and are used to build the release candidate
|
# from the release branch and are used to build the release candidate
|
||||||
# images.
|
# images.
|
||||||
- '^rc-v[0-9]+\.[0-9]+\.[0-9]+':
|
- '^rc-v[0-9]+\.[0-9]+\.[0-9]+':
|
||||||
# Build betas on release branches manually.
|
|
||||||
'triggers': []
|
|
||||||
# Set the default release channel on the release branch to beta, as we
|
# Set the default release channel on the release branch to beta, as we
|
||||||
# may need to build a few of these.
|
# may need to build a few of these.
|
||||||
'variables':
|
'variables':
|
||||||
|
|
|
@ -678,7 +678,7 @@
|
||||||
"use_saved_key": "Use the previously saved key",
|
"use_saved_key": "Use the previously saved key",
|
||||||
"parental_control": "Parental Control",
|
"parental_control": "Parental Control",
|
||||||
"safe_browsing": "Safe Browsing",
|
"safe_browsing": "Safe Browsing",
|
||||||
"served_from_cache": "{{value}} <i>(served from cache)</i>",
|
"served_from_cache_label": "Served from cache",
|
||||||
"form_error_password_length": "Password must be {{min}} to {{max}} characters long",
|
"form_error_password_length": "Password must be {{min}} to {{max}} characters long",
|
||||||
"anonymizer_notification": "<0>Note:</0> IP anonymization is enabled. You can disable it in <1>General settings</1>.",
|
"anonymizer_notification": "<0>Note:</0> IP anonymization is enabled. You can disable it in <1>General settings</1>.",
|
||||||
"confirm_dns_cache_clear": "Are you sure you want to clear DNS cache?",
|
"confirm_dns_cache_clear": "Are you sure you want to clear DNS cache?",
|
||||||
|
|
|
@ -55,6 +55,12 @@ const Dashboard = ({
|
||||||
return t('stats_disabled_short');
|
return t('stats_disabled_short');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const msIn7Days = 604800000;
|
||||||
|
|
||||||
|
if (stats.timeUnits === TIME_UNITS.HOURS && stats.interval === msIn7Days) {
|
||||||
|
return t('for_last_days', { count: msToDays(stats.interval) });
|
||||||
|
}
|
||||||
|
|
||||||
return stats.timeUnits === TIME_UNITS.HOURS
|
return stats.timeUnits === TIME_UNITS.HOURS
|
||||||
? t('for_last_hours', { count: msToHours(stats.interval) })
|
? t('for_last_hours', { count: msToHours(stats.interval) })
|
||||||
: t('for_last_days', { count: msToDays(stats.interval) });
|
: t('for_last_days', { count: msToDays(stats.interval) });
|
||||||
|
|
|
@ -38,9 +38,6 @@ const ResponseCell = ({
|
||||||
|
|
||||||
const statusLabel = t(isBlockedByResponse ? 'blocked_by_cname_or_ip' : FILTERED_STATUS_TO_META_MAP[reason]?.LABEL || reason);
|
const statusLabel = t(isBlockedByResponse ? 'blocked_by_cname_or_ip' : FILTERED_STATUS_TO_META_MAP[reason]?.LABEL || reason);
|
||||||
const boldStatusLabel = <span className="font-weight-bold">{statusLabel}</span>;
|
const boldStatusLabel = <span className="font-weight-bold">{statusLabel}</span>;
|
||||||
const upstreamString = cached
|
|
||||||
? t('served_from_cache', { value: upstream, i: <i /> })
|
|
||||||
: upstream;
|
|
||||||
|
|
||||||
const renderResponses = (responseArr) => {
|
const renderResponses = (responseArr) => {
|
||||||
if (!responseArr || responseArr.length === 0) {
|
if (!responseArr || responseArr.length === 0) {
|
||||||
|
@ -58,7 +55,16 @@ const ResponseCell = ({
|
||||||
|
|
||||||
const COMMON_CONTENT = {
|
const COMMON_CONTENT = {
|
||||||
encryption_status: boldStatusLabel,
|
encryption_status: boldStatusLabel,
|
||||||
install_settings_dns: upstreamString,
|
install_settings_dns: upstream,
|
||||||
|
...(cached
|
||||||
|
&& {
|
||||||
|
served_from_cache_label: (
|
||||||
|
<svg className="icons icon--20 icon--green mb-1">
|
||||||
|
<use xlinkHref="#check" />
|
||||||
|
</svg>
|
||||||
|
),
|
||||||
|
}
|
||||||
|
),
|
||||||
elapsed: formattedElapsedMs,
|
elapsed: formattedElapsedMs,
|
||||||
response_code: status,
|
response_code: status,
|
||||||
...(service_name && services.allServices
|
...(service_name && services.allServices
|
||||||
|
|
|
@ -118,9 +118,6 @@ const Row = memo(({
|
||||||
|
|
||||||
const blockingForClientKey = isFiltered ? 'unblock_for_this_client_only' : 'block_for_this_client_only';
|
const blockingForClientKey = isFiltered ? 'unblock_for_this_client_only' : 'block_for_this_client_only';
|
||||||
const clientNameBlockingFor = getBlockingClientName(clients, client);
|
const clientNameBlockingFor = getBlockingClientName(clients, client);
|
||||||
const upstreamString = cached
|
|
||||||
? t('served_from_cache', { value: upstream, i: <i /> })
|
|
||||||
: upstream;
|
|
||||||
|
|
||||||
const onBlockingForClientClick = () => {
|
const onBlockingForClientClick = () => {
|
||||||
dispatch(toggleBlockingForClient(buttonType, domain, clientNameBlockingFor));
|
dispatch(toggleBlockingForClient(buttonType, domain, clientNameBlockingFor));
|
||||||
|
@ -192,7 +189,16 @@ const Row = memo(({
|
||||||
className="link--green">{sourceData.name}
|
className="link--green">{sourceData.name}
|
||||||
</a>,
|
</a>,
|
||||||
response_details: 'title',
|
response_details: 'title',
|
||||||
install_settings_dns: upstreamString,
|
install_settings_dns: upstream,
|
||||||
|
...(cached
|
||||||
|
&& {
|
||||||
|
served_from_cache_label: (
|
||||||
|
<svg className="icons icon--20 icon--green">
|
||||||
|
<use xlinkHref="#check" />
|
||||||
|
</svg>
|
||||||
|
),
|
||||||
|
}
|
||||||
|
),
|
||||||
elapsed: formattedElapsedMs,
|
elapsed: formattedElapsedMs,
|
||||||
...(rules.length > 0
|
...(rules.length > 0
|
||||||
&& { rule_label: getRulesToFilterList(rules, filters, whitelistFilters) }
|
&& { rule_label: getRulesToFilterList(rules, filters, whitelistFilters) }
|
||||||
|
|
|
@ -245,6 +245,10 @@ const Icons = () => (
|
||||||
<path fillRule="evenodd" clipRule="evenodd" d="M12 13.5C11.1716 13.5 10.5 12.8284 10.5 12C10.5 11.1716 11.1716 10.5 12 10.5C12.8284 10.5 13.5 11.1716 13.5 12C13.5 12.8284 12.8284 13.5 12 13.5Z" fill="currentColor" />
|
<path fillRule="evenodd" clipRule="evenodd" d="M12 13.5C11.1716 13.5 10.5 12.8284 10.5 12C10.5 11.1716 11.1716 10.5 12 10.5C12.8284 10.5 13.5 11.1716 13.5 12C13.5 12.8284 12.8284 13.5 12 13.5Z" fill="currentColor" />
|
||||||
<path fillRule="evenodd" clipRule="evenodd" d="M12 20C11.1716 20 10.5 19.3284 10.5 18.5C10.5 17.6716 11.1716 17 12 17C12.8284 17 13.5 17.6716 13.5 18.5C13.5 19.3284 12.8284 20 12 20Z" fill="currentColor" />
|
<path fillRule="evenodd" clipRule="evenodd" d="M12 20C11.1716 20 10.5 19.3284 10.5 18.5C10.5 17.6716 11.1716 17 12 17C12.8284 17 13.5 17.6716 13.5 18.5C13.5 19.3284 12.8284 20 12 20Z" fill="currentColor" />
|
||||||
</symbol>
|
</symbol>
|
||||||
|
|
||||||
|
<symbol id="check" fill="none" viewBox="0 0 24 24" stroke="currentColor" strokeWidth="2" strokeLinecap="round" strokeLinejoin="round">
|
||||||
|
<path d="M5 11.7665L10.5878 17L19 8" />
|
||||||
|
</symbol>
|
||||||
</svg>
|
</svg>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
|
||||||
go 1.21.8
|
go 1.21.8
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/AdguardTeam/dnsproxy v0.65.2
|
github.com/AdguardTeam/dnsproxy v0.66.0
|
||||||
github.com/AdguardTeam/golibs v0.20.1
|
github.com/AdguardTeam/golibs v0.20.1
|
||||||
github.com/AdguardTeam/urlfilter v0.18.0
|
github.com/AdguardTeam/urlfilter v0.18.0
|
||||||
github.com/NYTimes/gziphandler v1.1.1
|
github.com/NYTimes/gziphandler v1.1.1
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -1,5 +1,5 @@
|
||||||
github.com/AdguardTeam/dnsproxy v0.65.2 h1:D+BMw0Vu2lbQrYpoPctG2Xr+24KdfhgkzZb6QgPZheM=
|
github.com/AdguardTeam/dnsproxy v0.66.0 h1:RyUbyDxRSXBFjVG1l2/4HV3I98DtfIgpnZkgXkgHKnc=
|
||||||
github.com/AdguardTeam/dnsproxy v0.65.2/go.mod h1:8NQTTNZY+qR9O1Fzgz3WQv30knfSgms68SRlzSnX74A=
|
github.com/AdguardTeam/dnsproxy v0.66.0/go.mod h1:ZThEXbMUlP1RxfwtNW30ItPAHE6OF4YFygK8qjU/cvY=
|
||||||
github.com/AdguardTeam/golibs v0.20.1 h1:ol8qLjWGZhU9paMMwN+OLWVTUigGsXa29iVTyd62VKY=
|
github.com/AdguardTeam/golibs v0.20.1 h1:ol8qLjWGZhU9paMMwN+OLWVTUigGsXa29iVTyd62VKY=
|
||||||
github.com/AdguardTeam/golibs v0.20.1/go.mod h1:bgcMgRviCKyU6mkrX+RtT/OsKPFzyppelfRsksMG3KU=
|
github.com/AdguardTeam/golibs v0.20.1/go.mod h1:bgcMgRviCKyU6mkrX+RtT/OsKPFzyppelfRsksMG3KU=
|
||||||
github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ=
|
github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ=
|
||||||
|
|
|
@ -5,9 +5,9 @@ package aghalg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"golang.org/x/exp/constraints"
|
"golang.org/x/exp/constraints"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Coalesce returns the first non-zero value. It is named after function
|
// Coalesce returns the first non-zero value. It is named after function
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
package aghalg_test
|
package aghalg_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// elements is a helper function that returns n elements of the buffer.
|
// elements is a helper function that returns n elements of the buffer.
|
||||||
|
|
|
@ -0,0 +1,86 @@
|
||||||
|
package aghalg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SortedMap is a map that keeps elements in order with internal sorting
|
||||||
|
// function. Must be initialised by the [NewSortedMap].
|
||||||
|
type SortedMap[K comparable, V any] struct {
|
||||||
|
vals map[K]V
|
||||||
|
cmp func(a, b K) (res int)
|
||||||
|
keys []K
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSortedMap initializes the new instance of sorted map. cmp is a sort
|
||||||
|
// function to keep elements in order.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Use cmp.Compare in Go 1.21.
|
||||||
|
func NewSortedMap[K comparable, V any](cmp func(a, b K) (res int)) SortedMap[K, V] {
|
||||||
|
return SortedMap[K, V]{
|
||||||
|
vals: map[K]V{},
|
||||||
|
cmp: cmp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set adds val with key to the sorted map. It panics if the m is nil.
|
||||||
|
func (m *SortedMap[K, V]) Set(key K, val V) {
|
||||||
|
m.vals[key] = val
|
||||||
|
|
||||||
|
i, has := slices.BinarySearchFunc(m.keys, key, m.cmp)
|
||||||
|
if has {
|
||||||
|
m.keys[i] = key
|
||||||
|
} else {
|
||||||
|
m.keys = slices.Insert(m.keys, i, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns val by key from the sorted map.
|
||||||
|
func (m *SortedMap[K, V]) Get(key K) (val V, ok bool) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
val, ok = m.vals[key]
|
||||||
|
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Del removes the value by key from the sorted map.
|
||||||
|
func (m *SortedMap[K, V]) Del(key K) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, has := m.vals[key]; !has {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(m.vals, key)
|
||||||
|
i, _ := slices.BinarySearchFunc(m.keys, key, m.cmp)
|
||||||
|
m.keys = slices.Delete(m.keys, i, i+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all elements from the sorted map.
|
||||||
|
func (m *SortedMap[K, V]) Clear() {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.keys = nil
|
||||||
|
clear(m.vals)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Range calls cb for each element of the map, sorted by m.cmp. If cb returns
|
||||||
|
// false it stops.
|
||||||
|
func (m *SortedMap[K, V]) Range(cb func(K, V) (cont bool)) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, k := range m.keys {
|
||||||
|
if !cb(k, m.vals[k]) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,95 @@
|
||||||
|
package aghalg
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewSortedMap(t *testing.T) {
|
||||||
|
var m SortedMap[string, int]
|
||||||
|
|
||||||
|
letters := []string{}
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
r := string('a' + rune(i))
|
||||||
|
letters = append(letters, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("create_and_fill", func(t *testing.T) {
|
||||||
|
m = NewSortedMap[string, int](strings.Compare)
|
||||||
|
|
||||||
|
nums := []int{}
|
||||||
|
for i, r := range letters {
|
||||||
|
m.Set(r, i)
|
||||||
|
nums = append(nums, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotLetters := []string{}
|
||||||
|
gotNums := []int{}
|
||||||
|
m.Range(func(k string, v int) bool {
|
||||||
|
gotLetters = append(gotLetters, k)
|
||||||
|
gotNums = append(gotNums, v)
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Equal(t, letters, gotLetters)
|
||||||
|
assert.Equal(t, nums, gotNums)
|
||||||
|
|
||||||
|
n, ok := m.Get(letters[0])
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, nums[0], n)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("clear", func(t *testing.T) {
|
||||||
|
lastLetter := letters[len(letters)-1]
|
||||||
|
m.Del(lastLetter)
|
||||||
|
|
||||||
|
_, ok := m.Get(lastLetter)
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
m.Clear()
|
||||||
|
|
||||||
|
gotLetters := []string{}
|
||||||
|
m.Range(func(k string, _ int) bool {
|
||||||
|
gotLetters = append(gotLetters, k)
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Len(t, gotLetters, 0)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSortedMap_nil(t *testing.T) {
|
||||||
|
const (
|
||||||
|
key = "key"
|
||||||
|
val = "val"
|
||||||
|
)
|
||||||
|
|
||||||
|
var m SortedMap[string, string]
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
m.Set(key, val)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
_, ok := m.Get(key)
|
||||||
|
assert.False(t, ok)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
m.Range(func(_, _ string) (cont bool) {
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
m.Del(key)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
m.Clear()
|
||||||
|
})
|
||||||
|
}
|
|
@ -154,8 +154,8 @@ func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleEvents concurrently handles the file system events. It closes the
|
// handleEvents concurrently handles the file system events. It closes the
|
||||||
// update channel of HostsContainer when finishes. It's used to be called
|
// update channel of HostsContainer when finishes. It is intended to be used as
|
||||||
// within a separate goroutine.
|
// a goroutine.
|
||||||
func (hc *HostsContainer) handleEvents() {
|
func (hc *HostsContainer) handleEvents() {
|
||||||
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPrefix))
|
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPrefix))
|
||||||
|
|
||||||
|
|
|
@ -67,6 +67,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
|
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
|
||||||
|
OnStart: func() (_ error) { panic("not implemented") },
|
||||||
OnEvents: onEvents,
|
OnEvents: onEvents,
|
||||||
OnAdd: onAdd,
|
OnAdd: onAdd,
|
||||||
OnClose: func() (err error) { return nil },
|
OnClose: func() (err error) { return nil },
|
||||||
|
@ -93,6 +94,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||||
t.Run("nil_fs", func(t *testing.T) {
|
t.Run("nil_fs", func(t *testing.T) {
|
||||||
require.Panics(t, func() {
|
require.Panics(t, func() {
|
||||||
_, _ = aghnet.NewHostsContainer(nil, &aghtest.FSWatcher{
|
_, _ = aghnet.NewHostsContainer(nil, &aghtest.FSWatcher{
|
||||||
|
OnStart: func() (_ error) { panic("not implemented") },
|
||||||
// Those shouldn't panic.
|
// Those shouldn't panic.
|
||||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||||
OnAdd: func(name string) (err error) { return nil },
|
OnAdd: func(name string) (err error) { return nil },
|
||||||
|
@ -111,6 +113,7 @@ func TestNewHostsContainer(t *testing.T) {
|
||||||
const errOnAdd errors.Error = "error"
|
const errOnAdd errors.Error = "error"
|
||||||
|
|
||||||
errWatcher := &aghtest.FSWatcher{
|
errWatcher := &aghtest.FSWatcher{
|
||||||
|
OnStart: func() (_ error) { panic("not implemented") },
|
||||||
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
|
OnEvents: func() (e <-chan struct{}) { panic("not implemented") },
|
||||||
OnAdd: func(name string) (err error) { return errOnAdd },
|
OnAdd: func(name string) (err error) { return errOnAdd },
|
||||||
OnClose: func() (err error) { return nil },
|
OnClose: func() (err error) { return nil },
|
||||||
|
@ -155,6 +158,7 @@ func TestHostsContainer_refresh(t *testing.T) {
|
||||||
t.Cleanup(func() { close(eventsCh) })
|
t.Cleanup(func() { close(eventsCh) })
|
||||||
|
|
||||||
w := &aghtest.FSWatcher{
|
w := &aghtest.FSWatcher{
|
||||||
|
OnStart: func() (_ error) { panic("not implemented") },
|
||||||
OnEvents: func() (e <-chan event) { return eventsCh },
|
OnEvents: func() (e <-chan event) { return eventsCh },
|
||||||
OnAdd: func(name string) (err error) {
|
OnAdd: func(name string) (err error) {
|
||||||
assert.Equal(t, "dir", name)
|
assert.Equal(t, "dir", name)
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
package aghnet
|
package aghnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/urlfilter"
|
"github.com/AdguardTeam/urlfilter"
|
||||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// IgnoreEngine contains the list of rules for ignoring hostnames and matches
|
// IgnoreEngine contains the list of rules for ignoring hostnames and matches
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/osutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DialContextFunc is the semantic alias for dialing functions, such as
|
// DialContextFunc is the semantic alias for dialing functions, such as
|
||||||
|
@ -32,7 +33,7 @@ var (
|
||||||
netInterfaceAddrs = net.InterfaceAddrs
|
netInterfaceAddrs = net.InterfaceAddrs
|
||||||
|
|
||||||
// rootDirFS is the filesystem pointing to the root directory.
|
// rootDirFS is the filesystem pointing to the root directory.
|
||||||
rootDirFS = aghos.RootDirFS()
|
rootDirFS = osutil.RootDirFS()
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about
|
// ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about
|
||||||
|
|
|
@ -8,6 +8,8 @@ import (
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/osutil"
|
||||||
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,31 +20,38 @@ type event = struct{}
|
||||||
// FSWatcher tracks all the fyle system events and notifies about those.
|
// FSWatcher tracks all the fyle system events and notifies about those.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
|
// TODO(e.burkov, a.garipov): Move into another package like aghfs.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Add tests.
|
||||||
type FSWatcher interface {
|
type FSWatcher interface {
|
||||||
|
// Start starts watching the added files.
|
||||||
|
Start() (err error)
|
||||||
|
|
||||||
|
// Close stops watching the files and closes an update channel.
|
||||||
io.Closer
|
io.Closer
|
||||||
|
|
||||||
// Events should return a read-only channel which notifies about events.
|
// Events returns the channel to notify about the file system events.
|
||||||
Events() (e <-chan event)
|
Events() (e <-chan event)
|
||||||
|
|
||||||
// Add should check if the file named name is accessible and starts tracking
|
// Add starts tracking the file. It returns an error if the file can't be
|
||||||
// it.
|
// tracked. It must not be called after Start.
|
||||||
Add(name string) (err error)
|
Add(name string) (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// osWatcher tracks the file system provided by the OS.
|
// osWatcher tracks the file system provided by the OS.
|
||||||
type osWatcher struct {
|
type osWatcher struct {
|
||||||
// w is the actual notifier that is handled by osWatcher.
|
// watcher is the actual notifier that is handled by osWatcher.
|
||||||
w *fsnotify.Watcher
|
watcher *fsnotify.Watcher
|
||||||
|
|
||||||
// events is the channel to notify.
|
// events is the channel to notify.
|
||||||
events chan event
|
events chan event
|
||||||
|
|
||||||
|
// files is the set of tracked files.
|
||||||
|
files *stringutil.Set
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
|
||||||
// osWatcherPref is a prefix for logging and wrapping errors in osWathcer's
|
// methods.
|
||||||
// methods.
|
const osWatcherPref = "os watcher"
|
||||||
osWatcherPref = "os watcher"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewOSWritesWatcher creates FSWatcher that tracks the real file system of the
|
// NewOSWritesWatcher creates FSWatcher that tracks the real file system of the
|
||||||
// OS and notifies only about writing events.
|
// OS and notifies only about writing events.
|
||||||
|
@ -55,25 +64,27 @@ func NewOSWritesWatcher() (w FSWatcher, err error) {
|
||||||
return nil, fmt.Errorf("creating watcher: %w", err)
|
return nil, fmt.Errorf("creating watcher: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fsw := &osWatcher{
|
return &osWatcher{
|
||||||
w: watcher,
|
watcher: watcher,
|
||||||
events: make(chan event, 1),
|
events: make(chan event, 1),
|
||||||
}
|
files: stringutil.NewSet(),
|
||||||
|
}, nil
|
||||||
go fsw.handleErrors()
|
|
||||||
go fsw.handleEvents()
|
|
||||||
|
|
||||||
return fsw, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleErrors handles accompanying errors. It used to be called in a separate
|
// type check
|
||||||
// goroutine.
|
var _ FSWatcher = (*osWatcher)(nil)
|
||||||
func (w *osWatcher) handleErrors() {
|
|
||||||
defer log.OnPanic(fmt.Sprintf("%s: handling errors", osWatcherPref))
|
|
||||||
|
|
||||||
for err := range w.w.Errors {
|
// Start implements the FSWatcher interface for *osWatcher.
|
||||||
log.Error("%s: %s", osWatcherPref, err)
|
func (w *osWatcher) Start() (err error) {
|
||||||
}
|
go w.handleErrors()
|
||||||
|
go w.handleEvents()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the FSWatcher interface for *osWatcher.
|
||||||
|
func (w *osWatcher) Close() (err error) {
|
||||||
|
return w.watcher.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Events implements the FSWatcher interface for *osWatcher.
|
// Events implements the FSWatcher interface for *osWatcher.
|
||||||
|
@ -81,34 +92,42 @@ func (w *osWatcher) Events() (e <-chan event) {
|
||||||
return w.events
|
return w.events
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add implements the FSWatcher interface for *osWatcher.
|
// Add implements the [FSWatcher] interface for *osWatcher.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Make it accept non-existing files to detect it's creating.
|
// TODO(e.burkov): Make it accept non-existing files to detect it's creating.
|
||||||
func (w *osWatcher) Add(name string) (err error) {
|
func (w *osWatcher) Add(name string) (err error) {
|
||||||
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
defer func() { err = errors.Annotate(err, "%s: %w", osWatcherPref) }()
|
||||||
|
|
||||||
if _, err = fs.Stat(RootDirFS(), name); err != nil {
|
fi, err := fs.Stat(osutil.RootDirFS(), name)
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("checking file %q: %w", name, err)
|
return fmt.Errorf("checking file %q: %w", name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return w.w.Add(filepath.Join("/", name))
|
name = filepath.Join("/", name)
|
||||||
}
|
w.files.Add(name)
|
||||||
|
|
||||||
// Close implements the FSWatcher interface for *osWatcher.
|
// Watch the directory and filter the events by the file name, since the
|
||||||
func (w *osWatcher) Close() (err error) {
|
// common recomendation to the fsnotify package is to watch the directory
|
||||||
return w.w.Close()
|
// instead of the file itself.
|
||||||
|
//
|
||||||
|
// See https://pkg.go.dev/github.com/fsnotify/fsnotify@v1.7.0#readme-watching-a-file-doesn-t-work-well.
|
||||||
|
if !fi.IsDir() {
|
||||||
|
name = filepath.Dir(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return w.watcher.Add(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleEvents notifies about the received file system's event if needed. It
|
// handleEvents notifies about the received file system's event if needed. It
|
||||||
// used to be called in a separate goroutine.
|
// is intended to be used as a goroutine.
|
||||||
func (w *osWatcher) handleEvents() {
|
func (w *osWatcher) handleEvents() {
|
||||||
defer log.OnPanic(fmt.Sprintf("%s: handling events", osWatcherPref))
|
defer log.OnPanic(fmt.Sprintf("%s: handling events", osWatcherPref))
|
||||||
|
|
||||||
defer close(w.events)
|
defer close(w.events)
|
||||||
|
|
||||||
ch := w.w.Events
|
ch := w.watcher.Events
|
||||||
for e := range ch {
|
for e := range ch {
|
||||||
if e.Op&fsnotify.Write == 0 {
|
if e.Op&fsnotify.Write == 0 || !w.files.Has(e.Name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,3 +150,13 @@ func (w *osWatcher) handleEvents() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleErrors handles accompanying errors. It used to be called in a separate
|
||||||
|
// goroutine.
|
||||||
|
func (w *osWatcher) handleErrors() {
|
||||||
|
defer log.OnPanic(fmt.Sprintf("%s: handling errors", osWatcherPref))
|
||||||
|
|
||||||
|
for err := range w.watcher.Errors {
|
||||||
|
log.Error("%s: %s", osWatcherPref, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -7,17 +7,16 @@ import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/fs"
|
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path"
|
"path"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// UnsupportedError is returned by functions and methods when a particular
|
// UnsupportedError is returned by functions and methods when a particular
|
||||||
|
@ -155,13 +154,6 @@ func IsOpenWrt() (ok bool) {
|
||||||
return isOpenWrt()
|
return isOpenWrt()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RootDirFS returns the [fs.FS] rooted at the operating system's root. On
|
|
||||||
// Windows it returns the fs.FS rooted at the volume of the system directory
|
|
||||||
// (usually, C:).
|
|
||||||
func RootDirFS() (fsys fs.FS) {
|
|
||||||
return rootDirFS()
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotifyReconfigureSignal notifies c on receiving reconfigure signals.
|
// NotifyReconfigureSignal notifies c on receiving reconfigure signals.
|
||||||
func NotifyReconfigureSignal(c chan<- os.Signal) {
|
func NotifyReconfigureSignal(c chan<- os.Signal) {
|
||||||
notifyReconfigureSignal(c)
|
notifyReconfigureSignal(c)
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/osutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -40,7 +41,7 @@ func isOpenWrt() (ok bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, !stringutil.ContainsFold(string(data), osNameData), nil
|
return nil, !stringutil.ContainsFold(string(data), osNameData), nil
|
||||||
}).Walk(RootDirFS(), etcReleasePattern)
|
}).Walk(osutil.RootDirFS(), etcReleasePattern)
|
||||||
|
|
||||||
return err == nil && ok
|
return err == nil && ok
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,17 +3,12 @@
|
||||||
package aghos
|
package aghos
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/fs"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
)
|
)
|
||||||
|
|
||||||
func rootDirFS() (fsys fs.FS) {
|
|
||||||
return os.DirFS("/")
|
|
||||||
}
|
|
||||||
|
|
||||||
func notifyReconfigureSignal(c chan<- os.Signal) {
|
func notifyReconfigureSignal(c chan<- os.Signal) {
|
||||||
signal.Notify(c, unix.SIGHUP)
|
signal.Notify(c, unix.SIGHUP)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,29 +3,13 @@
|
||||||
package aghos
|
package aghos
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/fs"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
func rootDirFS() (fsys fs.FS) {
|
|
||||||
// TODO(a.garipov): Use a better way if golang/go#44279 is ever resolved.
|
|
||||||
sysDir, err := windows.GetSystemDirectory()
|
|
||||||
if err != nil {
|
|
||||||
log.Error("aghos: getting root filesystem: %s; using C:", err)
|
|
||||||
|
|
||||||
// Assume that C: is the safe default.
|
|
||||||
return os.DirFS("C:")
|
|
||||||
}
|
|
||||||
|
|
||||||
return os.DirFS(filepath.VolumeName(sysDir))
|
|
||||||
}
|
|
||||||
|
|
||||||
func setRlimit(val uint64) (err error) {
|
func setRlimit(val uint64) (err error) {
|
||||||
return Unsupported("setrlimit")
|
return Unsupported("setrlimit")
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,8 +9,13 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -71,3 +76,49 @@ func StartHTTPServer(t testing.TB, data []byte) (c *http.Client, u *url.URL) {
|
||||||
|
|
||||||
return srv.Client(), u
|
return srv.Client(), u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// testTimeout is a timeout for tests.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Move into agdctest.
|
||||||
|
const testTimeout = 1 * time.Second
|
||||||
|
|
||||||
|
// StartLocalhostUpstream is a test helper that starts a DNS server on
|
||||||
|
// localhost.
|
||||||
|
func StartLocalhostUpstream(t *testing.T, h dns.Handler) (addr *url.URL) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
startCh := make(chan netip.AddrPort)
|
||||||
|
defer close(startCh)
|
||||||
|
errCh := make(chan error)
|
||||||
|
|
||||||
|
srv := &dns.Server{
|
||||||
|
Addr: "127.0.0.1:0",
|
||||||
|
Net: string(proxy.ProtoTCP),
|
||||||
|
Handler: h,
|
||||||
|
ReadTimeout: testTimeout,
|
||||||
|
WriteTimeout: testTimeout,
|
||||||
|
}
|
||||||
|
srv.NotifyStartedFunc = func() {
|
||||||
|
addrPort := srv.Listener.Addr()
|
||||||
|
startCh <- netutil.NetAddrToAddrPort(addrPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() { errCh <- srv.ListenAndServe() }()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case addrPort := <-startCh:
|
||||||
|
addr = &url.URL{
|
||||||
|
Scheme: string(proxy.ProtoTCP),
|
||||||
|
Host: addrPort.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
testutil.CleanupAndRequireSuccess(t, func() (err error) { return <-errCh })
|
||||||
|
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)
|
||||||
|
case err := <-errCh:
|
||||||
|
require.NoError(t, err)
|
||||||
|
case <-time.After(testTimeout):
|
||||||
|
require.FailNow(t, "timeout exceeded")
|
||||||
|
}
|
||||||
|
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||||
|
@ -26,14 +25,25 @@ import (
|
||||||
|
|
||||||
// FSWatcher is a fake [aghos.FSWatcher] implementation for tests.
|
// FSWatcher is a fake [aghos.FSWatcher] implementation for tests.
|
||||||
type FSWatcher struct {
|
type FSWatcher struct {
|
||||||
|
OnStart func() (err error)
|
||||||
|
OnClose func() (err error)
|
||||||
OnEvents func() (e <-chan struct{})
|
OnEvents func() (e <-chan struct{})
|
||||||
OnAdd func(name string) (err error)
|
OnAdd func(name string) (err error)
|
||||||
OnClose func() (err error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
var _ aghos.FSWatcher = (*FSWatcher)(nil)
|
var _ aghos.FSWatcher = (*FSWatcher)(nil)
|
||||||
|
|
||||||
|
// Start implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||||
|
func (w *FSWatcher) Start() (err error) {
|
||||||
|
return w.OnStart()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||||
|
func (w *FSWatcher) Close() (err error) {
|
||||||
|
return w.OnClose()
|
||||||
|
}
|
||||||
|
|
||||||
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
|
// Events implements the [aghos.FSWatcher] interface for *FSWatcher.
|
||||||
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
func (w *FSWatcher) Events() (e <-chan struct{}) {
|
||||||
return w.OnEvents()
|
return w.OnEvents()
|
||||||
|
@ -44,11 +54,6 @@ func (w *FSWatcher) Add(name string) (err error) {
|
||||||
return w.OnAdd(name)
|
return w.OnAdd(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close implements the [aghos.FSWatcher] interface for *FSWatcher.
|
|
||||||
func (w *FSWatcher) Close() (err error) {
|
|
||||||
return w.OnClose()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Package agh
|
// Package agh
|
||||||
|
|
||||||
// ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests.
|
// ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests.
|
||||||
|
@ -88,9 +93,6 @@ type AddressProcessor struct {
|
||||||
OnClose func() (err error)
|
OnClose func() (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ client.AddressProcessor = (*AddressProcessor)(nil)
|
|
||||||
|
|
||||||
// Process implements the [client.AddressProcessor] interface for
|
// Process implements the [client.AddressProcessor] interface for
|
||||||
// *AddressProcessor.
|
// *AddressProcessor.
|
||||||
func (p *AddressProcessor) Process(ip netip.Addr) {
|
func (p *AddressProcessor) Process(ip netip.Addr) {
|
||||||
|
@ -108,9 +110,6 @@ type AddressUpdater struct {
|
||||||
OnUpdateAddress func(ip netip.Addr, host string, info *whois.Info)
|
OnUpdateAddress func(ip netip.Addr, host string, info *whois.Info)
|
||||||
}
|
}
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ client.AddressUpdater = (*AddressUpdater)(nil)
|
|
||||||
|
|
||||||
// UpdateAddress implements the [client.AddressUpdater] interface for
|
// UpdateAddress implements the [client.AddressUpdater] interface for
|
||||||
// *AddressUpdater.
|
// *AddressUpdater.
|
||||||
func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
|
func (p *AddressUpdater) UpdateAddress(ip netip.Addr, host string, info *whois.Info) {
|
||||||
|
|
|
@ -2,6 +2,7 @@ package aghtest_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
)
|
)
|
||||||
|
@ -13,3 +14,13 @@ var _ filtering.Resolver = (*aghtest.Resolver)(nil)
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil)
|
var _ dnsforward.ClientsContainer = (*aghtest.ClientsContainer)(nil)
|
||||||
|
|
||||||
|
// type check
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): It's here to avoid the import cycle. Remove it.
|
||||||
|
var _ client.AddressProcessor = (*aghtest.AddressProcessor)(nil)
|
||||||
|
|
||||||
|
// type check
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): It's here to avoid the import cycle. Remove it.
|
||||||
|
var _ client.AddressUpdater = (*aghtest.AddressUpdater)(nil)
|
||||||
|
|
|
@ -7,13 +7,14 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"golang.org/x/exp/slices"
|
"github.com/AdguardTeam/golibs/osutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Variables and functions to substitute in tests.
|
// Variables and functions to substitute in tests.
|
||||||
|
@ -22,7 +23,7 @@ var (
|
||||||
aghosRunCommand = aghos.RunCommand
|
aghosRunCommand = aghos.RunCommand
|
||||||
|
|
||||||
// rootDirFS is the filesystem pointing to the root directory.
|
// rootDirFS is the filesystem pointing to the root directory.
|
||||||
rootDirFS = aghos.RootDirFS()
|
rootDirFS = osutil.RootDirFS()
|
||||||
)
|
)
|
||||||
|
|
||||||
// Interface stores and refreshes the network neighborhood reported by ARP
|
// Interface stores and refreshes the network neighborhood reported by ARP
|
||||||
|
|
|
@ -0,0 +1,249 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
|
)
|
||||||
|
|
||||||
|
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
|
||||||
|
type macKey any
|
||||||
|
|
||||||
|
// macToKey converts mac into key of type macKey, which is used as the key of
|
||||||
|
// the [clientIndex.macToUID]. mac must be valid MAC address.
|
||||||
|
func macToKey(mac net.HardwareAddr) (key macKey) {
|
||||||
|
switch len(mac) {
|
||||||
|
case 6:
|
||||||
|
return [6]byte(mac)
|
||||||
|
case 8:
|
||||||
|
return [8]byte(mac)
|
||||||
|
case 20:
|
||||||
|
return [20]byte(mac)
|
||||||
|
default:
|
||||||
|
panic(fmt.Errorf("invalid mac address %#v", mac))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Index stores all information about persistent clients.
|
||||||
|
type Index struct {
|
||||||
|
// clientIDToUID maps client ID to UID.
|
||||||
|
clientIDToUID map[string]UID
|
||||||
|
|
||||||
|
// ipToUID maps IP address to UID.
|
||||||
|
ipToUID map[netip.Addr]UID
|
||||||
|
|
||||||
|
// macToUID maps MAC address to UID.
|
||||||
|
macToUID map[macKey]UID
|
||||||
|
|
||||||
|
// uidToClient maps UID to the persistent client.
|
||||||
|
uidToClient map[UID]*Persistent
|
||||||
|
|
||||||
|
// subnetToUID maps subnet to UID.
|
||||||
|
subnetToUID aghalg.SortedMap[netip.Prefix, UID]
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIndex initializes the new instance of client index.
|
||||||
|
func NewIndex() (ci *Index) {
|
||||||
|
return &Index{
|
||||||
|
clientIDToUID: map[string]UID{},
|
||||||
|
ipToUID: map[netip.Addr]UID{},
|
||||||
|
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
|
||||||
|
macToUID: map[macKey]UID{},
|
||||||
|
uidToClient: map[UID]*Persistent{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add stores information about a persistent client in the index. c must be
|
||||||
|
// non-nil and contain UID.
|
||||||
|
func (ci *Index) Add(c *Persistent) {
|
||||||
|
if (c.UID == UID{}) {
|
||||||
|
panic("client must contain uid")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range c.ClientIDs {
|
||||||
|
ci.clientIDToUID[id] = c.UID
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range c.IPs {
|
||||||
|
ci.ipToUID[ip] = c.UID
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pref := range c.Subnets {
|
||||||
|
ci.subnetToUID.Set(pref, c.UID)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, mac := range c.MACs {
|
||||||
|
k := macToKey(mac)
|
||||||
|
ci.macToUID[k] = c.UID
|
||||||
|
}
|
||||||
|
|
||||||
|
ci.uidToClient[c.UID] = c
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clashes returns an error if the index contains a different persistent client
|
||||||
|
// with at least a single identifier contained by c. c must be non-nil.
|
||||||
|
func (ci *Index) Clashes(c *Persistent) (err error) {
|
||||||
|
for _, id := range c.ClientIDs {
|
||||||
|
existing, ok := ci.clientIDToUID[id]
|
||||||
|
if ok && existing != c.UID {
|
||||||
|
p := ci.uidToClient[existing]
|
||||||
|
|
||||||
|
return fmt.Errorf("another client %q uses the same ID %q", p.Name, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p, ip := ci.clashesIP(c)
|
||||||
|
if p != nil {
|
||||||
|
return fmt.Errorf("another client %q uses the same IP %q", p.Name, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, s := ci.clashesSubnet(c)
|
||||||
|
if p != nil {
|
||||||
|
return fmt.Errorf("another client %q uses the same subnet %q", p.Name, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
p, mac := ci.clashesMAC(c)
|
||||||
|
if p != nil {
|
||||||
|
return fmt.Errorf("another client %q uses the same MAC %q", p.Name, mac)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// clashesIP returns a previous client with the same IP address as c. c must be
|
||||||
|
// non-nil.
|
||||||
|
func (ci *Index) clashesIP(c *Persistent) (p *Persistent, ip netip.Addr) {
|
||||||
|
for _, ip := range c.IPs {
|
||||||
|
existing, ok := ci.ipToUID[ip]
|
||||||
|
if ok && existing != c.UID {
|
||||||
|
return ci.uidToClient[existing], ip
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clashesSubnet returns a previous client with the same subnet as c. c must be
|
||||||
|
// non-nil.
|
||||||
|
func (ci *Index) clashesSubnet(c *Persistent) (p *Persistent, s netip.Prefix) {
|
||||||
|
for _, s = range c.Subnets {
|
||||||
|
var existing UID
|
||||||
|
var ok bool
|
||||||
|
|
||||||
|
ci.subnetToUID.Range(func(p netip.Prefix, uid UID) (cont bool) {
|
||||||
|
if s == p {
|
||||||
|
existing = uid
|
||||||
|
ok = true
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if ok && existing != c.UID {
|
||||||
|
return ci.uidToClient[existing], s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, netip.Prefix{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clashesMAC returns a previous client with the same MAC address as c. c must
|
||||||
|
// be non-nil.
|
||||||
|
func (ci *Index) clashesMAC(c *Persistent) (p *Persistent, mac net.HardwareAddr) {
|
||||||
|
for _, mac = range c.MACs {
|
||||||
|
k := macToKey(mac)
|
||||||
|
existing, ok := ci.macToUID[k]
|
||||||
|
if ok && existing != c.UID {
|
||||||
|
return ci.uidToClient[existing], mac
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find finds persistent client by string representation of the client ID, IP
|
||||||
|
// address, or MAC.
|
||||||
|
func (ci *Index) Find(id string) (c *Persistent, ok bool) {
|
||||||
|
uid, found := ci.clientIDToUID[id]
|
||||||
|
if found {
|
||||||
|
return ci.uidToClient[uid], true
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := netip.ParseAddr(id)
|
||||||
|
if err == nil {
|
||||||
|
// MAC addresses can be successfully parsed as IP addresses.
|
||||||
|
c, found = ci.findByIP(ip)
|
||||||
|
if found {
|
||||||
|
return c, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mac, err := net.ParseMAC(id)
|
||||||
|
if err == nil {
|
||||||
|
return ci.findByMAC(mac)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// find finds persistent client by IP address.
|
||||||
|
func (ci *Index) findByIP(ip netip.Addr) (c *Persistent, found bool) {
|
||||||
|
uid, found := ci.ipToUID[ip]
|
||||||
|
if found {
|
||||||
|
return ci.uidToClient[uid], true
|
||||||
|
}
|
||||||
|
|
||||||
|
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
|
||||||
|
if pref.Contains(ip) {
|
||||||
|
uid, found = id, true
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if found {
|
||||||
|
return ci.uidToClient[uid], true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// find finds persistent client by MAC.
|
||||||
|
func (ci *Index) findByMAC(mac net.HardwareAddr) (c *Persistent, found bool) {
|
||||||
|
k := macToKey(mac)
|
||||||
|
uid, found := ci.macToUID[k]
|
||||||
|
if found {
|
||||||
|
return ci.uidToClient[uid], true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes information about persistent client from the index. c must be
|
||||||
|
// non-nil.
|
||||||
|
func (ci *Index) Delete(c *Persistent) {
|
||||||
|
for _, id := range c.ClientIDs {
|
||||||
|
delete(ci.clientIDToUID, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range c.IPs {
|
||||||
|
delete(ci.ipToUID, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pref := range c.Subnets {
|
||||||
|
ci.subnetToUID.Del(pref)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, mac := range c.MACs {
|
||||||
|
k := macToKey(mac)
|
||||||
|
delete(ci.macToUID, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(ci.uidToClient, c.UID)
|
||||||
|
}
|
|
@ -0,0 +1,223 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newIDIndex is a helper function that returns a client index filled with
|
||||||
|
// persistent clients from the m. It also generates a UID for each client.
|
||||||
|
func newIDIndex(m []*Persistent) (ci *Index) {
|
||||||
|
ci = NewIndex()
|
||||||
|
|
||||||
|
for _, c := range m {
|
||||||
|
c.UID = MustNewUID()
|
||||||
|
ci.Add(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ci
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientIndex(t *testing.T) {
|
||||||
|
const (
|
||||||
|
cliIPNone = "1.2.3.4"
|
||||||
|
cliIP1 = "1.1.1.1"
|
||||||
|
cliIP2 = "2.2.2.2"
|
||||||
|
|
||||||
|
cliIPv6 = "1:2:3::4"
|
||||||
|
|
||||||
|
cliSubnet = "2.2.2.0/24"
|
||||||
|
cliSubnetIP = "2.2.2.222"
|
||||||
|
|
||||||
|
cliID = "client-id"
|
||||||
|
cliMAC = "11:11:11:11:11:11"
|
||||||
|
)
|
||||||
|
|
||||||
|
clients := []*Persistent{{
|
||||||
|
Name: "client1",
|
||||||
|
IPs: []netip.Addr{
|
||||||
|
netip.MustParseAddr(cliIP1),
|
||||||
|
netip.MustParseAddr(cliIPv6),
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Name: "client2",
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
||||||
|
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
|
||||||
|
}, {
|
||||||
|
Name: "client_with_mac",
|
||||||
|
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
||||||
|
}, {
|
||||||
|
Name: "client_with_id",
|
||||||
|
ClientIDs: []string{cliID},
|
||||||
|
}}
|
||||||
|
|
||||||
|
ci := newIDIndex(clients)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
want *Persistent
|
||||||
|
name string
|
||||||
|
ids []string
|
||||||
|
}{{
|
||||||
|
name: "ipv4_ipv6",
|
||||||
|
ids: []string{cliIP1, cliIPv6},
|
||||||
|
want: clients[0],
|
||||||
|
}, {
|
||||||
|
name: "ipv4_subnet",
|
||||||
|
ids: []string{cliIP2, cliSubnetIP},
|
||||||
|
want: clients[1],
|
||||||
|
}, {
|
||||||
|
name: "mac",
|
||||||
|
ids: []string{cliMAC},
|
||||||
|
want: clients[2],
|
||||||
|
}, {
|
||||||
|
name: "client_id",
|
||||||
|
ids: []string{cliID},
|
||||||
|
want: clients[3],
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
for _, id := range tc.ids {
|
||||||
|
c, ok := ci.Find(id)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
assert.Equal(t, tc.want, c)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("not_found", func(t *testing.T) {
|
||||||
|
_, ok := ci.Find(cliIPNone)
|
||||||
|
assert.False(t, ok)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientIndex_Clashes(t *testing.T) {
|
||||||
|
const (
|
||||||
|
cliIP1 = "1.1.1.1"
|
||||||
|
cliSubnet = "2.2.2.0/24"
|
||||||
|
cliSubnetIP = "2.2.2.222"
|
||||||
|
cliID = "client-id"
|
||||||
|
cliMAC = "11:11:11:11:11:11"
|
||||||
|
)
|
||||||
|
|
||||||
|
clients := []*Persistent{{
|
||||||
|
Name: "client_with_ip",
|
||||||
|
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
|
||||||
|
}, {
|
||||||
|
Name: "client_with_subnet",
|
||||||
|
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
|
||||||
|
}, {
|
||||||
|
Name: "client_with_mac",
|
||||||
|
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
||||||
|
}, {
|
||||||
|
Name: "client_with_id",
|
||||||
|
ClientIDs: []string{cliID},
|
||||||
|
}}
|
||||||
|
|
||||||
|
ci := newIDIndex(clients)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
client *Persistent
|
||||||
|
name string
|
||||||
|
}{{
|
||||||
|
name: "ipv4",
|
||||||
|
client: clients[0],
|
||||||
|
}, {
|
||||||
|
name: "subnet",
|
||||||
|
client: clients[1],
|
||||||
|
}, {
|
||||||
|
name: "mac",
|
||||||
|
client: clients[2],
|
||||||
|
}, {
|
||||||
|
name: "client_id",
|
||||||
|
client: clients[3],
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
clone := tc.client.ShallowClone()
|
||||||
|
clone.UID = MustNewUID()
|
||||||
|
|
||||||
|
err := ci.Clashes(clone)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
ci.Delete(tc.client)
|
||||||
|
err = ci.Clashes(clone)
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
||||||
|
// error.
|
||||||
|
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
||||||
|
mac, err := net.ParseMAC(s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return mac
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMACToKey(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
want any
|
||||||
|
name string
|
||||||
|
in string
|
||||||
|
}{{
|
||||||
|
name: "column6",
|
||||||
|
in: "00:00:5e:00:53:01",
|
||||||
|
want: [6]byte(mustParseMAC("00:00:5e:00:53:01")),
|
||||||
|
}, {
|
||||||
|
name: "column8",
|
||||||
|
in: "02:00:5e:10:00:00:00:01",
|
||||||
|
want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")),
|
||||||
|
}, {
|
||||||
|
name: "column20",
|
||||||
|
in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
|
||||||
|
want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")),
|
||||||
|
}, {
|
||||||
|
name: "hyphen6",
|
||||||
|
in: "00-00-5e-00-53-01",
|
||||||
|
want: [6]byte(mustParseMAC("00-00-5e-00-53-01")),
|
||||||
|
}, {
|
||||||
|
name: "hyphen8",
|
||||||
|
in: "02-00-5e-10-00-00-00-01",
|
||||||
|
want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")),
|
||||||
|
}, {
|
||||||
|
name: "hyphen20",
|
||||||
|
in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
|
||||||
|
want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")),
|
||||||
|
}, {
|
||||||
|
name: "dot6",
|
||||||
|
in: "0000.5e00.5301",
|
||||||
|
want: [6]byte(mustParseMAC("0000.5e00.5301")),
|
||||||
|
}, {
|
||||||
|
name: "dot8",
|
||||||
|
in: "0200.5e10.0000.0001",
|
||||||
|
want: [8]byte(mustParseMAC("0200.5e10.0000.0001")),
|
||||||
|
}, {
|
||||||
|
name: "dot20",
|
||||||
|
in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
|
||||||
|
want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")),
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
mac := mustParseMAC(tc.in)
|
||||||
|
|
||||||
|
key := macToKey(mac)
|
||||||
|
assert.Equal(t, tc.want, key)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
mac := net.HardwareAddr([]byte{1, 2, 3})
|
||||||
|
_ = macToKey(mac)
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,22 +1,22 @@
|
||||||
package home
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding"
|
"encoding"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering/safesearch"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// UID is the type for the unique IDs of persistent clients.
|
// UID is the type for the unique IDs of persistent clients.
|
||||||
|
@ -30,6 +30,16 @@ func NewUID() (uid UID, err error) {
|
||||||
return UID(uuidv7), err
|
return UID(uuidv7), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MustNewUID is a wrapper around [NewUID] that panics if there is an error.
|
||||||
|
func MustNewUID() (uid UID) {
|
||||||
|
uid, err := NewUID()
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Errorf("unexpected uuidv7 error: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return uid
|
||||||
|
}
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
var _ encoding.TextMarshaler = UID{}
|
var _ encoding.TextMarshaler = UID{}
|
||||||
|
|
||||||
|
@ -46,16 +56,16 @@ func (uid *UID) UnmarshalText(data []byte) error {
|
||||||
return (*uuid.UUID)(uid).UnmarshalText(data)
|
return (*uuid.UUID)(uid).UnmarshalText(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// persistentClient contains information about persistent clients.
|
// Persistent contains information about persistent clients.
|
||||||
type persistentClient struct {
|
type Persistent struct {
|
||||||
// upstreamConfig is the custom upstream configuration for this client. If
|
// UpstreamConfig is the custom upstream configuration for this client. If
|
||||||
// it's nil, it has not been initialized yet. If it's non-nil and empty,
|
// it's nil, it has not been initialized yet. If it's non-nil and empty,
|
||||||
// there are no valid upstreams. If it's non-nil and non-empty, these
|
// there are no valid upstreams. If it's non-nil and non-empty, these
|
||||||
// upstream must be used.
|
// upstream must be used.
|
||||||
upstreamConfig *proxy.CustomUpstreamConfig
|
UpstreamConfig *proxy.CustomUpstreamConfig
|
||||||
|
|
||||||
// TODO(d.kolyshev): Make safeSearchConf a pointer.
|
// TODO(d.kolyshev): Make SafeSearchConf a pointer.
|
||||||
safeSearchConf filtering.SafeSearchConfig
|
SafeSearchConf filtering.SafeSearchConfig
|
||||||
SafeSearch filtering.SafeSearch
|
SafeSearch filtering.SafeSearch
|
||||||
|
|
||||||
// BlockedServices is the configuration of blocked services of a client.
|
// BlockedServices is the configuration of blocked services of a client.
|
||||||
|
@ -87,8 +97,8 @@ type persistentClient struct {
|
||||||
IgnoreStatistics bool
|
IgnoreStatistics bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// setTags sets the tags if they are known, otherwise logs an unknown tag.
|
// SetTags sets the tags if they are known, otherwise logs an unknown tag.
|
||||||
func (c *persistentClient) setTags(tags []string, known *stringutil.Set) {
|
func (c *Persistent) SetTags(tags []string, known *stringutil.Set) {
|
||||||
for _, t := range tags {
|
for _, t := range tags {
|
||||||
if !known.Has(t) {
|
if !known.Has(t) {
|
||||||
log.Info("skipping unknown tag %q", t)
|
log.Info("skipping unknown tag %q", t)
|
||||||
|
@ -102,9 +112,9 @@ func (c *persistentClient) setTags(tags []string, known *stringutil.Set) {
|
||||||
slices.Sort(c.Tags)
|
slices.Sort(c.Tags)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setIDs parses a list of strings into typed fields and returns an error if
|
// SetIDs parses a list of strings into typed fields and returns an error if
|
||||||
// there is one.
|
// there is one.
|
||||||
func (c *persistentClient) setIDs(ids []string) (err error) {
|
func (c *Persistent) SetIDs(ids []string) (err error) {
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
err = c.setID(id)
|
err = c.setID(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -144,7 +154,7 @@ func subnetCompare(x, y netip.Prefix) (cmp int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// setID parses id into typed field if there is no error.
|
// setID parses id into typed field if there is no error.
|
||||||
func (c *persistentClient) setID(id string) (err error) {
|
func (c *Persistent) setID(id string) (err error) {
|
||||||
if id == "" {
|
if id == "" {
|
||||||
return errors.Error("clientid is empty")
|
return errors.Error("clientid is empty")
|
||||||
}
|
}
|
||||||
|
@ -170,7 +180,7 @@ func (c *persistentClient) setID(id string) (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = dnsforward.ValidateClientID(id)
|
err = ValidateClientID(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
// Don't wrap the error, because it's informative enough as is.
|
||||||
return err
|
return err
|
||||||
|
@ -181,9 +191,23 @@ func (c *persistentClient) setID(id string) (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ids returns a list of client ids containing at least one element.
|
// ValidateClientID returns an error if id is not a valid ClientID.
|
||||||
func (c *persistentClient) ids() (ids []string) {
|
//
|
||||||
ids = make([]string, 0, c.idsLen())
|
// TODO(s.chzhen): It's an exact copy of the [dnsforward.ValidateClientID] to
|
||||||
|
// avoid the import cycle. Remove it.
|
||||||
|
func ValidateClientID(id string) (err error) {
|
||||||
|
err = netutil.ValidateHostnameLabel(id)
|
||||||
|
if err != nil {
|
||||||
|
// Replace the domain name label wrapper with our own.
|
||||||
|
return fmt.Errorf("invalid clientid %q: %w", id, errors.Unwrap(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDs returns a list of client IDs containing at least one element.
|
||||||
|
func (c *Persistent) IDs() (ids []string) {
|
||||||
|
ids = make([]string, 0, c.IDsLen())
|
||||||
|
|
||||||
for _, ip := range c.IPs {
|
for _, ip := range c.IPs {
|
||||||
ids = append(ids, ip.String())
|
ids = append(ids, ip.String())
|
||||||
|
@ -200,24 +224,24 @@ func (c *persistentClient) ids() (ids []string) {
|
||||||
return append(ids, c.ClientIDs...)
|
return append(ids, c.ClientIDs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// idsLen returns a length of client ids.
|
// IDsLen returns a length of client ids.
|
||||||
func (c *persistentClient) idsLen() (n int) {
|
func (c *Persistent) IDsLen() (n int) {
|
||||||
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
|
return len(c.IPs) + len(c.Subnets) + len(c.MACs) + len(c.ClientIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// equalIDs returns true if the ids of the current and previous clients are the
|
// EqualIDs returns true if the ids of the current and previous clients are the
|
||||||
// same.
|
// same.
|
||||||
func (c *persistentClient) equalIDs(prev *persistentClient) (equal bool) {
|
func (c *Persistent) EqualIDs(prev *Persistent) (equal bool) {
|
||||||
return slices.Equal(c.IPs, prev.IPs) &&
|
return slices.Equal(c.IPs, prev.IPs) &&
|
||||||
slices.Equal(c.Subnets, prev.Subnets) &&
|
slices.Equal(c.Subnets, prev.Subnets) &&
|
||||||
slices.EqualFunc(c.MACs, prev.MACs, slices.Equal[net.HardwareAddr]) &&
|
slices.EqualFunc(c.MACs, prev.MACs, slices.Equal[net.HardwareAddr]) &&
|
||||||
slices.Equal(c.ClientIDs, prev.ClientIDs)
|
slices.Equal(c.ClientIDs, prev.ClientIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
// shallowClone returns a deep copy of the client, except upstreamConfig,
|
// ShallowClone returns a deep copy of the client, except upstreamConfig,
|
||||||
// safeSearchConf, SafeSearch fields, because it's difficult to copy them.
|
// safeSearchConf, SafeSearch fields, because it's difficult to copy them.
|
||||||
func (c *persistentClient) shallowClone() (clone *persistentClient) {
|
func (c *Persistent) ShallowClone() (clone *Persistent) {
|
||||||
clone = &persistentClient{}
|
clone = &Persistent{}
|
||||||
*clone = *c
|
*clone = *c
|
||||||
|
|
||||||
clone.BlockedServices = c.BlockedServices.Clone()
|
clone.BlockedServices = c.BlockedServices.Clone()
|
||||||
|
@ -232,10 +256,10 @@ func (c *persistentClient) shallowClone() (clone *persistentClient) {
|
||||||
return clone
|
return clone
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeUpstreams closes the client-specific upstream config of c if any.
|
// CloseUpstreams closes the client-specific upstream config of c if any.
|
||||||
func (c *persistentClient) closeUpstreams() (err error) {
|
func (c *Persistent) CloseUpstreams() (err error) {
|
||||||
if c.upstreamConfig != nil {
|
if c.UpstreamConfig != nil {
|
||||||
if err = c.upstreamConfig.Close(); err != nil {
|
if err = c.UpstreamConfig.Close(); err != nil {
|
||||||
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
|
return fmt.Errorf("closing upstreams of client %q: %w", c.Name, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -243,8 +267,8 @@ func (c *persistentClient) closeUpstreams() (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setSafeSearch initializes and sets the safe search filter for this client.
|
// SetSafeSearch initializes and sets the safe search filter for this client.
|
||||||
func (c *persistentClient) setSafeSearch(
|
func (c *Persistent) SetSafeSearch(
|
||||||
conf filtering.SafeSearchConfig,
|
conf filtering.SafeSearchConfig,
|
||||||
cacheSize uint,
|
cacheSize uint,
|
||||||
cacheTTL time.Duration,
|
cacheTTL time.Duration,
|
|
@ -1,4 +1,4 @@
|
||||||
package home
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -27,10 +27,10 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
|
||||||
)
|
)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
|
want assert.BoolAssertionFunc
|
||||||
name string
|
name string
|
||||||
ids []string
|
ids []string
|
||||||
prevIDs []string
|
prevIDs []string
|
||||||
want assert.BoolAssertionFunc
|
|
||||||
}{{
|
}{{
|
||||||
name: "single_ip",
|
name: "single_ip",
|
||||||
ids: []string{ip1},
|
ids: []string{ip1},
|
||||||
|
@ -110,15 +110,15 @@ func TestPersistentClient_EqualIDs(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
c := &persistentClient{}
|
c := &Persistent{}
|
||||||
err := c.setIDs(tc.ids)
|
err := c.SetIDs(tc.ids)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
prev := &persistentClient{}
|
prev := &Persistent{}
|
||||||
err = prev.setIDs(tc.prevIDs)
|
err = prev.SetIDs(tc.prevIDs)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tc.want(t, c.equalIDs(prev))
|
tc.want(t, c.EqualIDs(prev))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -15,7 +16,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/google/renameio/v2/maybe"
|
"github.com/google/renameio/v2/maybe"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -13,7 +14,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
|
@ -19,7 +20,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type v4ServerConfJSON struct {
|
type v4ServerConfJSON struct {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -20,7 +21,6 @@ import (
|
||||||
"github.com/go-ping/ping"
|
"github.com/go-ping/ping"
|
||||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// v4Server is a DHCPv4 server.
|
// v4Server is a DHCPv4 server.
|
||||||
|
|
|
@ -2,11 +2,11 @@ package dhcpsvc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config is the configuration for the DHCP service.
|
// Config is the configuration for the DHCP service.
|
||||||
|
@ -19,6 +19,8 @@ type Config struct {
|
||||||
// clients' hostnames.
|
// clients' hostnames.
|
||||||
LocalDomainName string
|
LocalDomainName string
|
||||||
|
|
||||||
|
// TODO(e.burkov): Add DB path.
|
||||||
|
|
||||||
// ICMPTimeout is the timeout for checking another DHCP server's presence.
|
// ICMPTimeout is the timeout for checking another DHCP server's presence.
|
||||||
ICMPTimeout time.Duration
|
ICMPTimeout time.Duration
|
||||||
|
|
||||||
|
@ -68,12 +70,6 @@ func (conf *Config) Validate() (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// newMustErr returns an error that indicates that valName must be as must
|
|
||||||
// describes.
|
|
||||||
func newMustErr(valName, must string, val fmt.Stringer) (err error) {
|
|
||||||
return fmt.Errorf("%s %s must %s", valName, val, must)
|
|
||||||
}
|
|
||||||
|
|
||||||
// validate returns an error in ic, if any.
|
// validate returns an error in ic, if any.
|
||||||
func (ic *InterfaceConfig) validate() (err error) {
|
func (ic *InterfaceConfig) validate() (err error) {
|
||||||
if ic == nil {
|
if ic == nil {
|
||||||
|
|
|
@ -7,48 +7,16 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Lease is a DHCP lease.
|
// Interface is a DHCP service.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Consider moving it to [agh], since it also may be needed in
|
// TODO(e.burkov): Separate HostByIP, MACByIP, IPByHost into a separate
|
||||||
// [websvc].
|
// interface. This is also applicable to Enabled method.
|
||||||
type Lease struct {
|
//
|
||||||
// IP is the IP address leased to the client.
|
// TODO(e.burkov): Reconsider the requirements for the leases validity.
|
||||||
IP netip.Addr
|
|
||||||
|
|
||||||
// Expiry is the expiration time of the lease.
|
|
||||||
Expiry time.Time
|
|
||||||
|
|
||||||
// Hostname of the client.
|
|
||||||
Hostname string
|
|
||||||
|
|
||||||
// HWAddr is the physical hardware address (MAC address).
|
|
||||||
HWAddr net.HardwareAddr
|
|
||||||
|
|
||||||
// IsStatic defines if the lease is static.
|
|
||||||
IsStatic bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clone returns a deep copy of l.
|
|
||||||
func (l *Lease) Clone() (clone *Lease) {
|
|
||||||
if l == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Lease{
|
|
||||||
Expiry: l.Expiry,
|
|
||||||
Hostname: l.Hostname,
|
|
||||||
HWAddr: slices.Clone(l.HWAddr),
|
|
||||||
IP: l.IP,
|
|
||||||
IsStatic: l.IsStatic,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type Interface interface {
|
type Interface interface {
|
||||||
agh.ServiceWithConfig[*Config]
|
agh.ServiceWithConfig[*Config]
|
||||||
|
|
||||||
|
@ -63,6 +31,8 @@ type Interface interface {
|
||||||
// MACByIP returns the MAC address for the given IP address leased. It
|
// MACByIP returns the MAC address for the given IP address leased. It
|
||||||
// returns nil if there is no such client, due to an assumption that a DHCP
|
// returns nil if there is no such client, due to an assumption that a DHCP
|
||||||
// client must always have a MAC address.
|
// client must always have a MAC address.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Think of a contract for the returned value.
|
||||||
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
|
MACByIP(ip netip.Addr) (mac net.HardwareAddr)
|
||||||
|
|
||||||
// IPByHost returns the IP address of the DHCP client with the given
|
// IPByHost returns the IP address of the DHCP client with the given
|
||||||
|
@ -71,26 +41,29 @@ type Interface interface {
|
||||||
// hostname, either set or generated.
|
// hostname, either set or generated.
|
||||||
IPByHost(host string) (ip netip.Addr)
|
IPByHost(host string) (ip netip.Addr)
|
||||||
|
|
||||||
// Leases returns all the active DHCP leases.
|
// Leases returns all the active DHCP leases. The returned slice should be
|
||||||
|
// a clone.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Consider implementing iterating methods with appropriate
|
// TODO(e.burkov): Consider implementing iterating methods with appropriate
|
||||||
// signatures instead of cloning the whole list.
|
// signatures instead of cloning the whole list.
|
||||||
Leases() (ls []*Lease)
|
Leases() (ls []*Lease)
|
||||||
|
|
||||||
// AddLease adds a new DHCP lease. It returns an error if the lease is
|
// AddLease adds a new DHCP lease. l must be valid. It returns an error if
|
||||||
// invalid or already exists.
|
// l already exists.
|
||||||
AddLease(l *Lease) (err error)
|
AddLease(l *Lease) (err error)
|
||||||
|
|
||||||
// UpdateStaticLease changes an existing DHCP lease. It returns an error if
|
// UpdateStaticLease replaces an existing static DHCP lease. l must be
|
||||||
// there is no lease with such hardware addressor if new values are invalid
|
// valid. It returns an error if the lease with the given hardware address
|
||||||
// or already exist.
|
// doesn't exist or if other values match another existing lease.
|
||||||
UpdateStaticLease(l *Lease) (err error)
|
UpdateStaticLease(l *Lease) (err error)
|
||||||
|
|
||||||
// RemoveLease removes an existing DHCP lease. It returns an error if there
|
// RemoveLease removes an existing DHCP lease. l must be valid. It returns
|
||||||
// is no lease equal to l.
|
// an error if there is no lease equal to l.
|
||||||
RemoveLease(l *Lease) (err error)
|
RemoveLease(l *Lease) (err error)
|
||||||
|
|
||||||
// Reset removes all the DHCP leases.
|
// Reset removes all the DHCP leases.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): If it's really needed?
|
||||||
Reset() (err error)
|
Reset() (err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
package dhcpsvc
|
package dhcpsvc
|
||||||
|
|
||||||
import "github.com/AdguardTeam/golibs/errors"
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// errNilConfig is returned when a nil config met.
|
// errNilConfig is returned when a nil config met.
|
||||||
|
@ -9,3 +13,9 @@ const (
|
||||||
// errNoInterfaces is returned when no interfaces found in configuration.
|
// errNoInterfaces is returned when no interfaces found in configuration.
|
||||||
errNoInterfaces errors.Error = "no interfaces specified"
|
errNoInterfaces errors.Error = "no interfaces specified"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// newMustErr returns an error that indicates that valName must be as must
|
||||||
|
// describes.
|
||||||
|
func newMustErr(valName, must string, val fmt.Stringer) (err error) {
|
||||||
|
return fmt.Errorf("%s %s must %s", valName, val, must)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,66 @@
|
||||||
|
package dhcpsvc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// netInterface is a common part of any network interface within the DHCP
|
||||||
|
// server.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Add other methods as [DHCPServer] evolves.
|
||||||
|
type netInterface struct {
|
||||||
|
// name is the name of the network interface.
|
||||||
|
name string
|
||||||
|
|
||||||
|
// leases is a set of leases sorted by hardware address.
|
||||||
|
leases []*Lease
|
||||||
|
|
||||||
|
// leaseTTL is the default Time-To-Live value for leases.
|
||||||
|
leaseTTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// reset clears all the slices in iface for reuse.
|
||||||
|
func (iface *netInterface) reset() {
|
||||||
|
iface.leases = iface.leases[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertLease inserts the given lease into iface. It returns an error if the
|
||||||
|
// lease can't be inserted.
|
||||||
|
func (iface *netInterface) insertLease(l *Lease) (err error) {
|
||||||
|
i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC)
|
||||||
|
if found {
|
||||||
|
return fmt.Errorf("lease for mac %s already exists", l.HWAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
iface.leases = slices.Insert(iface.leases, i, l)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateLease replaces an existing lease within iface with the given one. It
|
||||||
|
// returns an error if there is no lease with such hardware address.
|
||||||
|
func (iface *netInterface) updateLease(l *Lease) (prev *Lease, err error) {
|
||||||
|
i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC)
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("no lease for mac %s", l.HWAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
prev, iface.leases[i] = iface.leases[i], l
|
||||||
|
|
||||||
|
return prev, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeLease removes an existing lease from iface. It returns an error if
|
||||||
|
// there is no lease equal to l.
|
||||||
|
func (iface *netInterface) removeLease(l *Lease) (err error) {
|
||||||
|
i, found := slices.BinarySearchFunc(iface.leases, l, compareLeaseMAC)
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("no lease for mac %s", l.HWAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
iface.leases = slices.Delete(iface.leases, i, i+1)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
package dhcpsvc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Lease is a DHCP lease.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Consider moving it to [agh], since it also may be needed in
|
||||||
|
// [websvc].
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Add validation method.
|
||||||
|
type Lease struct {
|
||||||
|
// IP is the IP address leased to the client.
|
||||||
|
IP netip.Addr
|
||||||
|
|
||||||
|
// Expiry is the expiration time of the lease.
|
||||||
|
Expiry time.Time
|
||||||
|
|
||||||
|
// Hostname of the client.
|
||||||
|
Hostname string
|
||||||
|
|
||||||
|
// HWAddr is the physical hardware address (MAC address).
|
||||||
|
HWAddr net.HardwareAddr
|
||||||
|
|
||||||
|
// IsStatic defines if the lease is static.
|
||||||
|
IsStatic bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone returns a deep copy of l.
|
||||||
|
func (l *Lease) Clone() (clone *Lease) {
|
||||||
|
if l == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Lease{
|
||||||
|
Expiry: l.Expiry,
|
||||||
|
Hostname: l.Hostname,
|
||||||
|
HWAddr: slices.Clone(l.HWAddr),
|
||||||
|
IP: l.IP,
|
||||||
|
IsStatic: l.IsStatic,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// compareLeaseMAC compares two [Lease]s by hardware address.
|
||||||
|
func compareLeaseMAC(a, b *Lease) (res int) {
|
||||||
|
return bytes.Compare(a.HWAddr, b.HWAddr)
|
||||||
|
}
|
|
@ -0,0 +1,126 @@
|
||||||
|
package dhcpsvc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// leaseIndex is the set of leases indexed by their identifiers for quick
|
||||||
|
// lookup.
|
||||||
|
type leaseIndex struct {
|
||||||
|
// byAddr is a lookup shortcut for leases by their IP addresses.
|
||||||
|
byAddr map[netip.Addr]*Lease
|
||||||
|
|
||||||
|
// byName is a lookup shortcut for leases by their hostnames.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Use a slice of leases with the same hostname?
|
||||||
|
byName map[string]*Lease
|
||||||
|
}
|
||||||
|
|
||||||
|
// newLeaseIndex returns a new index for [Lease]s.
|
||||||
|
func newLeaseIndex() *leaseIndex {
|
||||||
|
return &leaseIndex{
|
||||||
|
byAddr: map[netip.Addr]*Lease{},
|
||||||
|
byName: map[string]*Lease{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// leaseByAddr returns a lease by its IP address.
|
||||||
|
func (idx *leaseIndex) leaseByAddr(addr netip.Addr) (l *Lease, ok bool) {
|
||||||
|
l, ok = idx.byAddr[addr]
|
||||||
|
|
||||||
|
return l, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// leaseByName returns a lease by its hostname.
|
||||||
|
func (idx *leaseIndex) leaseByName(name string) (l *Lease, ok bool) {
|
||||||
|
// TODO(e.burkov): Probably, use a case-insensitive comparison and store in
|
||||||
|
// slice. This would require a benchmark.
|
||||||
|
l, ok = idx.byName[strings.ToLower(name)]
|
||||||
|
|
||||||
|
return l, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear removes all leases from idx.
|
||||||
|
func (idx *leaseIndex) clear() {
|
||||||
|
clear(idx.byAddr)
|
||||||
|
clear(idx.byName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// add adds l into idx and into iface. l must be valid, iface should be
|
||||||
|
// responsible for l's IP. It returns an error if l duplicates at least a
|
||||||
|
// single value of another lease.
|
||||||
|
func (idx *leaseIndex) add(l *Lease, iface *netInterface) (err error) {
|
||||||
|
loweredName := strings.ToLower(l.Hostname)
|
||||||
|
|
||||||
|
if _, ok := idx.byAddr[l.IP]; ok {
|
||||||
|
return fmt.Errorf("lease for ip %s already exists", l.IP)
|
||||||
|
} else if _, ok = idx.byName[loweredName]; ok {
|
||||||
|
return fmt.Errorf("lease for hostname %s already exists", l.Hostname)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = iface.insertLease(l)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
idx.byAddr[l.IP] = l
|
||||||
|
idx.byName[loweredName] = l
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove removes l from idx and from iface. l must be valid, iface should
|
||||||
|
// contain the same lease or the lease itself. It returns an error if the lease
|
||||||
|
// not found.
|
||||||
|
func (idx *leaseIndex) remove(l *Lease, iface *netInterface) (err error) {
|
||||||
|
loweredName := strings.ToLower(l.Hostname)
|
||||||
|
|
||||||
|
if _, ok := idx.byAddr[l.IP]; !ok {
|
||||||
|
return fmt.Errorf("no lease for ip %s", l.IP)
|
||||||
|
} else if _, ok = idx.byName[loweredName]; !ok {
|
||||||
|
return fmt.Errorf("no lease for hostname %s", l.Hostname)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = iface.removeLease(l)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(idx.byAddr, l.IP)
|
||||||
|
delete(idx.byName, loweredName)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// update updates l in idx and in iface. l must be valid, iface should be
|
||||||
|
// responsible for l's IP. It returns an error if l duplicates at least a
|
||||||
|
// single value of another lease, except for the updated lease itself.
|
||||||
|
func (idx *leaseIndex) update(l *Lease, iface *netInterface) (err error) {
|
||||||
|
loweredName := strings.ToLower(l.Hostname)
|
||||||
|
|
||||||
|
existing, ok := idx.byAddr[l.IP]
|
||||||
|
if ok && !slices.Equal(l.HWAddr, existing.HWAddr) {
|
||||||
|
return fmt.Errorf("lease for ip %s already exists", l.IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
existing, ok = idx.byName[loweredName]
|
||||||
|
if ok && !slices.Equal(l.HWAddr, existing.HWAddr) {
|
||||||
|
return fmt.Errorf("lease for hostname %s already exists", l.Hostname)
|
||||||
|
}
|
||||||
|
|
||||||
|
prev, err := iface.updateLease(l)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(idx.byAddr, prev.IP)
|
||||||
|
delete(idx.byName, strings.ToLower(prev.Hostname))
|
||||||
|
|
||||||
|
idx.byAddr[l.IP] = l
|
||||||
|
idx.byName[loweredName] = l
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -2,11 +2,15 @@ package dhcpsvc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// DHCPServer is a DHCP server for both IPv4 and IPv6 address families.
|
// DHCPServer is a DHCP server for both IPv4 and IPv6 address families.
|
||||||
|
@ -15,18 +19,21 @@ type DHCPServer struct {
|
||||||
// information about its clients.
|
// information about its clients.
|
||||||
enabled *atomic.Bool
|
enabled *atomic.Bool
|
||||||
|
|
||||||
// localTLD is the top-level domain name to use for resolving DHCP
|
// localTLD is the top-level domain name to use for resolving DHCP clients'
|
||||||
// clients' hostnames.
|
// hostnames.
|
||||||
localTLD string
|
localTLD string
|
||||||
|
|
||||||
|
// leasesMu protects the leases index as well as leases in the interfaces.
|
||||||
|
leasesMu *sync.RWMutex
|
||||||
|
|
||||||
|
// leases stores the DHCP leases for quick lookups.
|
||||||
|
leases *leaseIndex
|
||||||
|
|
||||||
// interfaces4 is the set of IPv4 interfaces sorted by interface name.
|
// interfaces4 is the set of IPv4 interfaces sorted by interface name.
|
||||||
interfaces4 []*iface4
|
interfaces4 netInterfacesV4
|
||||||
|
|
||||||
// interfaces6 is the set of IPv6 interfaces sorted by interface name.
|
// interfaces6 is the set of IPv6 interfaces sorted by interface name.
|
||||||
interfaces6 []*iface6
|
interfaces6 netInterfacesV6
|
||||||
|
|
||||||
// leases is the set of active DHCP leases.
|
|
||||||
leases []*Lease
|
|
||||||
|
|
||||||
// icmpTimeout is the timeout for checking another DHCP server's presence.
|
// icmpTimeout is the timeout for checking another DHCP server's presence.
|
||||||
icmpTimeout time.Duration
|
icmpTimeout time.Duration
|
||||||
|
@ -42,26 +49,27 @@ func New(conf *Config) (srv *DHCPServer, err error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
ifaces4 := make([]*iface4, len(conf.Interfaces))
|
// TODO(e.burkov): Add validations scoped to the network interfaces set.
|
||||||
ifaces6 := make([]*iface6, len(conf.Interfaces))
|
ifaces4 := make(netInterfacesV4, 0, len(conf.Interfaces))
|
||||||
|
ifaces6 := make(netInterfacesV6, 0, len(conf.Interfaces))
|
||||||
|
|
||||||
ifaceNames := maps.Keys(conf.Interfaces)
|
ifaceNames := maps.Keys(conf.Interfaces)
|
||||||
slices.Sort(ifaceNames)
|
slices.Sort(ifaceNames)
|
||||||
|
|
||||||
var i4 *iface4
|
var i4 *netInterfaceV4
|
||||||
var i6 *iface6
|
var i6 *netInterfaceV6
|
||||||
|
|
||||||
for _, ifaceName := range ifaceNames {
|
for _, ifaceName := range ifaceNames {
|
||||||
iface := conf.Interfaces[ifaceName]
|
iface := conf.Interfaces[ifaceName]
|
||||||
|
|
||||||
i4, err = newIface4(ifaceName, iface.IPv4)
|
i4, err = newNetInterfaceV4(ifaceName, iface.IPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("interface %q: ipv4: %w", ifaceName, err)
|
return nil, fmt.Errorf("interface %q: ipv4: %w", ifaceName, err)
|
||||||
} else if i4 != nil {
|
} else if i4 != nil {
|
||||||
ifaces4 = append(ifaces4, i4)
|
ifaces4 = append(ifaces4, i4)
|
||||||
}
|
}
|
||||||
|
|
||||||
i6 = newIface6(ifaceName, iface.IPv6)
|
i6 = newNetInterfaceV6(ifaceName, iface.IPv6)
|
||||||
if i6 != nil {
|
if i6 != nil {
|
||||||
ifaces6 = append(ifaces6, i6)
|
ifaces6 = append(ifaces6, i6)
|
||||||
}
|
}
|
||||||
|
@ -70,13 +78,19 @@ func New(conf *Config) (srv *DHCPServer, err error) {
|
||||||
enabled := &atomic.Bool{}
|
enabled := &atomic.Bool{}
|
||||||
enabled.Store(conf.Enabled)
|
enabled.Store(conf.Enabled)
|
||||||
|
|
||||||
return &DHCPServer{
|
srv = &DHCPServer{
|
||||||
enabled: enabled,
|
enabled: enabled,
|
||||||
|
localTLD: conf.LocalDomainName,
|
||||||
|
leasesMu: &sync.RWMutex{},
|
||||||
|
leases: newLeaseIndex(),
|
||||||
interfaces4: ifaces4,
|
interfaces4: ifaces4,
|
||||||
interfaces6: ifaces6,
|
interfaces6: ifaces6,
|
||||||
localTLD: conf.LocalDomainName,
|
|
||||||
icmpTimeout: conf.ICMPTimeout,
|
icmpTimeout: conf.ICMPTimeout,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
// TODO(e.burkov): Load leases.
|
||||||
|
|
||||||
|
return srv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// type check
|
// type check
|
||||||
|
@ -91,10 +105,140 @@ func (srv *DHCPServer) Enabled() (ok bool) {
|
||||||
|
|
||||||
// Leases implements the [Interface] interface for *DHCPServer.
|
// Leases implements the [Interface] interface for *DHCPServer.
|
||||||
func (srv *DHCPServer) Leases() (leases []*Lease) {
|
func (srv *DHCPServer) Leases() (leases []*Lease) {
|
||||||
leases = make([]*Lease, 0, len(srv.leases))
|
srv.leasesMu.RLock()
|
||||||
for _, lease := range srv.leases {
|
defer srv.leasesMu.RUnlock()
|
||||||
|
|
||||||
|
for _, iface := range srv.interfaces4 {
|
||||||
|
for _, lease := range iface.leases {
|
||||||
leases = append(leases, lease.Clone())
|
leases = append(leases, lease.Clone())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
for _, iface := range srv.interfaces6 {
|
||||||
|
for _, lease := range iface.leases {
|
||||||
|
leases = append(leases, lease.Clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return leases
|
return leases
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HostByIP implements the [Interface] interface for *DHCPServer.
|
||||||
|
func (srv *DHCPServer) HostByIP(ip netip.Addr) (host string) {
|
||||||
|
srv.leasesMu.RLock()
|
||||||
|
defer srv.leasesMu.RUnlock()
|
||||||
|
|
||||||
|
if l, ok := srv.leases.leaseByAddr(ip); ok {
|
||||||
|
return l.Hostname
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// MACByIP implements the [Interface] interface for *DHCPServer.
|
||||||
|
func (srv *DHCPServer) MACByIP(ip netip.Addr) (mac net.HardwareAddr) {
|
||||||
|
srv.leasesMu.RLock()
|
||||||
|
defer srv.leasesMu.RUnlock()
|
||||||
|
|
||||||
|
if l, ok := srv.leases.leaseByAddr(ip); ok {
|
||||||
|
return l.HWAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPByHost implements the [Interface] interface for *DHCPServer.
|
||||||
|
func (srv *DHCPServer) IPByHost(host string) (ip netip.Addr) {
|
||||||
|
srv.leasesMu.RLock()
|
||||||
|
defer srv.leasesMu.RUnlock()
|
||||||
|
|
||||||
|
if l, ok := srv.leases.leaseByName(host); ok {
|
||||||
|
return l.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.Addr{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset implements the [Interface] interface for *DHCPServer.
|
||||||
|
func (srv *DHCPServer) Reset() (err error) {
|
||||||
|
srv.leasesMu.Lock()
|
||||||
|
defer srv.leasesMu.Unlock()
|
||||||
|
|
||||||
|
for _, iface := range srv.interfaces4 {
|
||||||
|
iface.reset()
|
||||||
|
}
|
||||||
|
for _, iface := range srv.interfaces6 {
|
||||||
|
iface.reset()
|
||||||
|
}
|
||||||
|
srv.leases.clear()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddLease implements the [Interface] interface for *DHCPServer.
|
||||||
|
func (srv *DHCPServer) AddLease(l *Lease) (err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "adding lease: %w") }()
|
||||||
|
|
||||||
|
addr := l.IP
|
||||||
|
iface, err := srv.ifaceForAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.leasesMu.Lock()
|
||||||
|
defer srv.leasesMu.Unlock()
|
||||||
|
|
||||||
|
return srv.leases.add(l, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStaticLease implements the [Interface] interface for *DHCPServer.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Support moving leases between interfaces.
|
||||||
|
func (srv *DHCPServer) UpdateStaticLease(l *Lease) (err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "updating static lease: %w") }()
|
||||||
|
|
||||||
|
addr := l.IP
|
||||||
|
iface, err := srv.ifaceForAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.leasesMu.Lock()
|
||||||
|
defer srv.leasesMu.Unlock()
|
||||||
|
|
||||||
|
return srv.leases.update(l, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveLease implements the [Interface] interface for *DHCPServer.
|
||||||
|
func (srv *DHCPServer) RemoveLease(l *Lease) (err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "removing lease: %w") }()
|
||||||
|
|
||||||
|
addr := l.IP
|
||||||
|
iface, err := srv.ifaceForAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
// Don't wrap the error since there is already an annotation deferred.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.leasesMu.Lock()
|
||||||
|
defer srv.leasesMu.Unlock()
|
||||||
|
|
||||||
|
return srv.leases.remove(l, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ifaceForAddr returns the handled network interface for the given IP address,
|
||||||
|
// or an error if no such interface exists.
|
||||||
|
func (srv *DHCPServer) ifaceForAddr(addr netip.Addr) (iface *netInterface, err error) {
|
||||||
|
var ok bool
|
||||||
|
if addr.Is4() {
|
||||||
|
iface, ok = srv.interfaces4.find(addr)
|
||||||
|
} else {
|
||||||
|
iface, ok = srv.interfaces6.find(addr)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("no interface for ip %s", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return iface, nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,17 +1,67 @@
|
||||||
package dhcpsvc_test
|
package dhcpsvc_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// testLocalTLD is a common local TLD for tests.
|
// testLocalTLD is a common local TLD for tests.
|
||||||
const testLocalTLD = "local"
|
const testLocalTLD = "local"
|
||||||
|
|
||||||
|
// testInterfaceConf is a common set of interface configurations for tests.
|
||||||
|
var testInterfaceConf = map[string]*dhcpsvc.InterfaceConfig{
|
||||||
|
"eth0": {
|
||||||
|
IPv4: &dhcpsvc.IPv4Config{
|
||||||
|
Enabled: true,
|
||||||
|
GatewayIP: netip.MustParseAddr("192.168.0.1"),
|
||||||
|
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||||
|
RangeStart: netip.MustParseAddr("192.168.0.2"),
|
||||||
|
RangeEnd: netip.MustParseAddr("192.168.0.254"),
|
||||||
|
LeaseDuration: 1 * time.Hour,
|
||||||
|
},
|
||||||
|
IPv6: &dhcpsvc.IPv6Config{
|
||||||
|
Enabled: true,
|
||||||
|
RangeStart: netip.MustParseAddr("2001:db8::1"),
|
||||||
|
LeaseDuration: 1 * time.Hour,
|
||||||
|
RAAllowSLAAC: true,
|
||||||
|
RASLAACOnly: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"eth1": {
|
||||||
|
IPv4: &dhcpsvc.IPv4Config{
|
||||||
|
Enabled: true,
|
||||||
|
GatewayIP: netip.MustParseAddr("172.16.0.1"),
|
||||||
|
SubnetMask: netip.MustParseAddr("255.255.255.0"),
|
||||||
|
RangeStart: netip.MustParseAddr("172.16.0.2"),
|
||||||
|
RangeEnd: netip.MustParseAddr("172.16.0.255"),
|
||||||
|
LeaseDuration: 1 * time.Hour,
|
||||||
|
},
|
||||||
|
IPv6: &dhcpsvc.IPv6Config{
|
||||||
|
Enabled: true,
|
||||||
|
RangeStart: netip.MustParseAddr("2001:db9::1"),
|
||||||
|
LeaseDuration: 1 * time.Hour,
|
||||||
|
RAAllowSLAAC: true,
|
||||||
|
RASLAACOnly: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// mustParseMAC parses a hardware address from s and requires no errors.
|
||||||
|
func mustParseMAC(t require.TestingT, s string) (mac net.HardwareAddr) {
|
||||||
|
mac, err := net.ParseMAC(s)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return mac
|
||||||
|
}
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
validIPv4Conf := &dhcpsvc.IPv4Config{
|
validIPv4Conf := &dhcpsvc.IPv4Config{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
|
@ -113,3 +163,433 @@ func TestNew(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDHCPServer_AddLease(t *testing.T) {
|
||||||
|
srv, err := dhcpsvc.New(&dhcpsvc.Config{
|
||||||
|
Enabled: true,
|
||||||
|
LocalDomainName: testLocalTLD,
|
||||||
|
Interfaces: testInterfaceConf,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
const (
|
||||||
|
host1 = "host1"
|
||||||
|
host2 = "host2"
|
||||||
|
host3 = "host3"
|
||||||
|
)
|
||||||
|
|
||||||
|
ip1 := netip.MustParseAddr("192.168.0.2")
|
||||||
|
ip2 := netip.MustParseAddr("192.168.0.3")
|
||||||
|
ip3 := netip.MustParseAddr("2001:db8::2")
|
||||||
|
|
||||||
|
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
|
||||||
|
mac2 := mustParseMAC(t, "06:05:04:03:02:01")
|
||||||
|
mac3 := mustParseMAC(t, "02:03:04:05:06:07")
|
||||||
|
|
||||||
|
require.NoError(t, srv.AddLease(&dhcpsvc.Lease{
|
||||||
|
Hostname: host1,
|
||||||
|
IP: ip1,
|
||||||
|
HWAddr: mac1,
|
||||||
|
IsStatic: true,
|
||||||
|
}))
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
lease *dhcpsvc.Lease
|
||||||
|
wantErrMsg string
|
||||||
|
}{{
|
||||||
|
name: "outside_range",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host2,
|
||||||
|
IP: netip.MustParseAddr("1.2.3.4"),
|
||||||
|
HWAddr: mac2,
|
||||||
|
},
|
||||||
|
wantErrMsg: "adding lease: no interface for ip 1.2.3.4",
|
||||||
|
}, {
|
||||||
|
name: "duplicate_ip",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host2,
|
||||||
|
IP: ip1,
|
||||||
|
HWAddr: mac2,
|
||||||
|
},
|
||||||
|
wantErrMsg: "adding lease: lease for ip " + ip1.String() +
|
||||||
|
" already exists",
|
||||||
|
}, {
|
||||||
|
name: "duplicate_hostname",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host1,
|
||||||
|
IP: ip2,
|
||||||
|
HWAddr: mac2,
|
||||||
|
},
|
||||||
|
wantErrMsg: "adding lease: lease for hostname " + host1 +
|
||||||
|
" already exists",
|
||||||
|
}, {
|
||||||
|
name: "duplicate_hostname_case",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: strings.ToUpper(host1),
|
||||||
|
IP: ip2,
|
||||||
|
HWAddr: mac2,
|
||||||
|
},
|
||||||
|
wantErrMsg: "adding lease: lease for hostname " +
|
||||||
|
strings.ToUpper(host1) + " already exists",
|
||||||
|
}, {
|
||||||
|
name: "duplicate_mac",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host2,
|
||||||
|
IP: ip2,
|
||||||
|
HWAddr: mac1,
|
||||||
|
},
|
||||||
|
wantErrMsg: "adding lease: lease for mac " + mac1.String() +
|
||||||
|
" already exists",
|
||||||
|
}, {
|
||||||
|
name: "valid",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host2,
|
||||||
|
IP: ip2,
|
||||||
|
HWAddr: mac2,
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}, {
|
||||||
|
name: "valid_v6",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host3,
|
||||||
|
IP: ip3,
|
||||||
|
HWAddr: mac3,
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.AddLease(tc.lease))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDHCPServer_index(t *testing.T) {
|
||||||
|
srv, err := dhcpsvc.New(&dhcpsvc.Config{
|
||||||
|
Enabled: true,
|
||||||
|
LocalDomainName: testLocalTLD,
|
||||||
|
Interfaces: testInterfaceConf,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
const (
|
||||||
|
host1 = "host1"
|
||||||
|
host2 = "host2"
|
||||||
|
host3 = "host3"
|
||||||
|
host4 = "host4"
|
||||||
|
host5 = "host5"
|
||||||
|
)
|
||||||
|
|
||||||
|
ip1 := netip.MustParseAddr("192.168.0.2")
|
||||||
|
ip2 := netip.MustParseAddr("192.168.0.3")
|
||||||
|
ip3 := netip.MustParseAddr("172.16.0.3")
|
||||||
|
ip4 := netip.MustParseAddr("172.16.0.4")
|
||||||
|
|
||||||
|
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
|
||||||
|
mac2 := mustParseMAC(t, "06:05:04:03:02:01")
|
||||||
|
mac3 := mustParseMAC(t, "02:03:04:05:06:07")
|
||||||
|
|
||||||
|
leases := []*dhcpsvc.Lease{{
|
||||||
|
Hostname: host1,
|
||||||
|
IP: ip1,
|
||||||
|
HWAddr: mac1,
|
||||||
|
IsStatic: true,
|
||||||
|
}, {
|
||||||
|
Hostname: host2,
|
||||||
|
IP: ip2,
|
||||||
|
HWAddr: mac2,
|
||||||
|
IsStatic: true,
|
||||||
|
}, {
|
||||||
|
Hostname: host3,
|
||||||
|
IP: ip3,
|
||||||
|
HWAddr: mac3,
|
||||||
|
IsStatic: true,
|
||||||
|
}, {
|
||||||
|
Hostname: host4,
|
||||||
|
IP: ip4,
|
||||||
|
HWAddr: mac1,
|
||||||
|
IsStatic: true,
|
||||||
|
}}
|
||||||
|
for _, l := range leases {
|
||||||
|
require.NoError(t, srv.AddLease(l))
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("ip_idx", func(t *testing.T) {
|
||||||
|
assert.Equal(t, ip1, srv.IPByHost(host1))
|
||||||
|
assert.Equal(t, ip2, srv.IPByHost(host2))
|
||||||
|
assert.Equal(t, ip3, srv.IPByHost(host3))
|
||||||
|
assert.Equal(t, ip4, srv.IPByHost(host4))
|
||||||
|
assert.Equal(t, netip.Addr{}, srv.IPByHost(host5))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("name_idx", func(t *testing.T) {
|
||||||
|
assert.Equal(t, host1, srv.HostByIP(ip1))
|
||||||
|
assert.Equal(t, host2, srv.HostByIP(ip2))
|
||||||
|
assert.Equal(t, host3, srv.HostByIP(ip3))
|
||||||
|
assert.Equal(t, host4, srv.HostByIP(ip4))
|
||||||
|
assert.Equal(t, "", srv.HostByIP(netip.Addr{}))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("mac_idx", func(t *testing.T) {
|
||||||
|
assert.Equal(t, mac1, srv.MACByIP(ip1))
|
||||||
|
assert.Equal(t, mac2, srv.MACByIP(ip2))
|
||||||
|
assert.Equal(t, mac3, srv.MACByIP(ip3))
|
||||||
|
assert.Equal(t, mac1, srv.MACByIP(ip4))
|
||||||
|
assert.Nil(t, srv.MACByIP(netip.Addr{}))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDHCPServer_UpdateStaticLease(t *testing.T) {
|
||||||
|
srv, err := dhcpsvc.New(&dhcpsvc.Config{
|
||||||
|
Enabled: true,
|
||||||
|
LocalDomainName: testLocalTLD,
|
||||||
|
Interfaces: testInterfaceConf,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
const (
|
||||||
|
host1 = "host1"
|
||||||
|
host2 = "host2"
|
||||||
|
host3 = "host3"
|
||||||
|
host4 = "host4"
|
||||||
|
host5 = "host5"
|
||||||
|
host6 = "host6"
|
||||||
|
)
|
||||||
|
|
||||||
|
ip1 := netip.MustParseAddr("192.168.0.2")
|
||||||
|
ip2 := netip.MustParseAddr("192.168.0.3")
|
||||||
|
ip3 := netip.MustParseAddr("192.168.0.4")
|
||||||
|
ip4 := netip.MustParseAddr("2001:db8::2")
|
||||||
|
ip5 := netip.MustParseAddr("2001:db8::3")
|
||||||
|
|
||||||
|
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
|
||||||
|
mac2 := mustParseMAC(t, "01:02:03:04:05:07")
|
||||||
|
mac3 := mustParseMAC(t, "06:05:04:03:02:01")
|
||||||
|
mac4 := mustParseMAC(t, "06:05:04:03:02:02")
|
||||||
|
|
||||||
|
leases := []*dhcpsvc.Lease{{
|
||||||
|
Hostname: host1,
|
||||||
|
IP: ip1,
|
||||||
|
HWAddr: mac1,
|
||||||
|
IsStatic: true,
|
||||||
|
}, {
|
||||||
|
Hostname: host2,
|
||||||
|
IP: ip2,
|
||||||
|
HWAddr: mac2,
|
||||||
|
IsStatic: true,
|
||||||
|
}, {
|
||||||
|
Hostname: host4,
|
||||||
|
IP: ip4,
|
||||||
|
HWAddr: mac4,
|
||||||
|
IsStatic: true,
|
||||||
|
}}
|
||||||
|
for _, l := range leases {
|
||||||
|
require.NoError(t, srv.AddLease(l))
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
lease *dhcpsvc.Lease
|
||||||
|
wantErrMsg string
|
||||||
|
}{{
|
||||||
|
name: "outside_range",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host1,
|
||||||
|
IP: netip.MustParseAddr("1.2.3.4"),
|
||||||
|
HWAddr: mac1,
|
||||||
|
},
|
||||||
|
wantErrMsg: "updating static lease: no interface for ip 1.2.3.4",
|
||||||
|
}, {
|
||||||
|
name: "not_found",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host3,
|
||||||
|
IP: ip3,
|
||||||
|
HWAddr: mac3,
|
||||||
|
},
|
||||||
|
wantErrMsg: "updating static lease: no lease for mac " + mac3.String(),
|
||||||
|
}, {
|
||||||
|
name: "duplicate_ip",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host1,
|
||||||
|
IP: ip2,
|
||||||
|
HWAddr: mac1,
|
||||||
|
},
|
||||||
|
wantErrMsg: "updating static lease: lease for ip " + ip2.String() +
|
||||||
|
" already exists",
|
||||||
|
}, {
|
||||||
|
name: "duplicate_hostname",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host2,
|
||||||
|
IP: ip1,
|
||||||
|
HWAddr: mac1,
|
||||||
|
},
|
||||||
|
wantErrMsg: "updating static lease: lease for hostname " + host2 +
|
||||||
|
" already exists",
|
||||||
|
}, {
|
||||||
|
name: "duplicate_hostname_case",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: strings.ToUpper(host2),
|
||||||
|
IP: ip1,
|
||||||
|
HWAddr: mac1,
|
||||||
|
},
|
||||||
|
wantErrMsg: "updating static lease: lease for hostname " +
|
||||||
|
strings.ToUpper(host2) + " already exists",
|
||||||
|
}, {
|
||||||
|
name: "valid",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host3,
|
||||||
|
IP: ip3,
|
||||||
|
HWAddr: mac1,
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}, {
|
||||||
|
name: "valid_v6",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host6,
|
||||||
|
IP: ip5,
|
||||||
|
HWAddr: mac4,
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.UpdateStaticLease(tc.lease))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDHCPServer_RemoveLease(t *testing.T) {
|
||||||
|
srv, err := dhcpsvc.New(&dhcpsvc.Config{
|
||||||
|
Enabled: true,
|
||||||
|
LocalDomainName: testLocalTLD,
|
||||||
|
Interfaces: testInterfaceConf,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
const (
|
||||||
|
host1 = "host1"
|
||||||
|
host2 = "host2"
|
||||||
|
host3 = "host3"
|
||||||
|
)
|
||||||
|
|
||||||
|
ip1 := netip.MustParseAddr("192.168.0.2")
|
||||||
|
ip2 := netip.MustParseAddr("192.168.0.3")
|
||||||
|
ip3 := netip.MustParseAddr("2001:db8::2")
|
||||||
|
|
||||||
|
mac1 := mustParseMAC(t, "01:02:03:04:05:06")
|
||||||
|
mac2 := mustParseMAC(t, "02:03:04:05:06:07")
|
||||||
|
mac3 := mustParseMAC(t, "06:05:04:03:02:01")
|
||||||
|
|
||||||
|
leases := []*dhcpsvc.Lease{{
|
||||||
|
Hostname: host1,
|
||||||
|
IP: ip1,
|
||||||
|
HWAddr: mac1,
|
||||||
|
IsStatic: true,
|
||||||
|
}, {
|
||||||
|
Hostname: host3,
|
||||||
|
IP: ip3,
|
||||||
|
HWAddr: mac3,
|
||||||
|
IsStatic: true,
|
||||||
|
}}
|
||||||
|
for _, l := range leases {
|
||||||
|
require.NoError(t, srv.AddLease(l))
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
lease *dhcpsvc.Lease
|
||||||
|
wantErrMsg string
|
||||||
|
}{{
|
||||||
|
name: "not_found_mac",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host1,
|
||||||
|
IP: ip1,
|
||||||
|
HWAddr: mac2,
|
||||||
|
},
|
||||||
|
wantErrMsg: "removing lease: no lease for mac " + mac2.String(),
|
||||||
|
}, {
|
||||||
|
name: "not_found_ip",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host1,
|
||||||
|
IP: ip2,
|
||||||
|
HWAddr: mac1,
|
||||||
|
},
|
||||||
|
wantErrMsg: "removing lease: no lease for ip " + ip2.String(),
|
||||||
|
}, {
|
||||||
|
name: "not_found_host",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host2,
|
||||||
|
IP: ip1,
|
||||||
|
HWAddr: mac1,
|
||||||
|
},
|
||||||
|
wantErrMsg: "removing lease: no lease for hostname " + host2,
|
||||||
|
}, {
|
||||||
|
name: "valid",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host1,
|
||||||
|
IP: ip1,
|
||||||
|
HWAddr: mac1,
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}, {
|
||||||
|
name: "valid_v6",
|
||||||
|
lease: &dhcpsvc.Lease{
|
||||||
|
Hostname: host3,
|
||||||
|
IP: ip3,
|
||||||
|
HWAddr: mac3,
|
||||||
|
},
|
||||||
|
wantErrMsg: "",
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testutil.AssertErrorMsg(t, tc.wantErrMsg, srv.RemoveLease(tc.lease))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Empty(t, srv.Leases())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDHCPServer_Reset(t *testing.T) {
|
||||||
|
srv, err := dhcpsvc.New(&dhcpsvc.Config{
|
||||||
|
Enabled: true,
|
||||||
|
LocalDomainName: testLocalTLD,
|
||||||
|
Interfaces: testInterfaceConf,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
leases := []*dhcpsvc.Lease{{
|
||||||
|
Hostname: "host1",
|
||||||
|
IP: netip.MustParseAddr("192.168.0.2"),
|
||||||
|
HWAddr: mustParseMAC(t, "01:02:03:04:05:06"),
|
||||||
|
IsStatic: true,
|
||||||
|
}, {
|
||||||
|
Hostname: "host2",
|
||||||
|
IP: netip.MustParseAddr("192.168.0.3"),
|
||||||
|
HWAddr: mustParseMAC(t, "06:05:04:03:02:01"),
|
||||||
|
IsStatic: true,
|
||||||
|
}, {
|
||||||
|
Hostname: "host3",
|
||||||
|
IP: netip.MustParseAddr("2001:db8::2"),
|
||||||
|
HWAddr: mustParseMAC(t, "02:03:04:05:06:07"),
|
||||||
|
IsStatic: true,
|
||||||
|
}, {
|
||||||
|
Hostname: "host4",
|
||||||
|
IP: netip.MustParseAddr("2001:db8::3"),
|
||||||
|
HWAddr: mustParseMAC(t, "06:05:04:03:02:02"),
|
||||||
|
IsStatic: true,
|
||||||
|
}}
|
||||||
|
|
||||||
|
for _, l := range leases {
|
||||||
|
require.NoError(t, srv.AddLease(l))
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Len(t, srv.Leases(), len(leases))
|
||||||
|
|
||||||
|
require.NoError(t, srv.Reset())
|
||||||
|
|
||||||
|
assert.Empty(t, srv.Leases())
|
||||||
|
}
|
||||||
|
|
|
@ -4,12 +4,12 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// IPv4Config is the interface-specific configuration for DHCPv4.
|
// IPv4Config is the interface-specific configuration for DHCPv4.
|
||||||
|
@ -64,69 +64,6 @@ func (conf *IPv4Config) validate() (err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// iface4 is a DHCP interface for IPv4 address family.
|
|
||||||
type iface4 struct {
|
|
||||||
// gateway is the IP address of the network gateway.
|
|
||||||
gateway netip.Addr
|
|
||||||
|
|
||||||
// subnet is the network subnet.
|
|
||||||
subnet netip.Prefix
|
|
||||||
|
|
||||||
// addrSpace is the IPv4 address space allocated for leasing.
|
|
||||||
addrSpace ipRange
|
|
||||||
|
|
||||||
// name is the name of the interface.
|
|
||||||
name string
|
|
||||||
|
|
||||||
// implicitOpts are the options listed in Appendix A of RFC 2131 and
|
|
||||||
// initialized with default values. It must not have intersections with
|
|
||||||
// explicitOpts.
|
|
||||||
implicitOpts layers.DHCPOptions
|
|
||||||
|
|
||||||
// explicitOpts are the user-configured options. It must not have
|
|
||||||
// intersections with implicitOpts.
|
|
||||||
explicitOpts layers.DHCPOptions
|
|
||||||
|
|
||||||
// leaseTTL is the time-to-live of dynamic leases on this interface.
|
|
||||||
leaseTTL time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// newIface4 creates a new DHCP interface for IPv4 address family with the given
|
|
||||||
// configuration. It returns an error if the given configuration can't be used.
|
|
||||||
func newIface4(name string, conf *IPv4Config) (i *iface4, err error) {
|
|
||||||
if !conf.Enabled {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
maskLen, _ := net.IPMask(conf.SubnetMask.AsSlice()).Size()
|
|
||||||
subnet := netip.PrefixFrom(conf.GatewayIP, maskLen)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case !subnet.Contains(conf.RangeStart):
|
|
||||||
return nil, fmt.Errorf("range start %s is not within %s", conf.RangeStart, subnet)
|
|
||||||
case !subnet.Contains(conf.RangeEnd):
|
|
||||||
return nil, fmt.Errorf("range end %s is not within %s", conf.RangeEnd, subnet)
|
|
||||||
}
|
|
||||||
|
|
||||||
addrSpace, err := newIPRange(conf.RangeStart, conf.RangeEnd)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if addrSpace.contains(conf.GatewayIP) {
|
|
||||||
return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace)
|
|
||||||
}
|
|
||||||
|
|
||||||
i = &iface4{
|
|
||||||
name: name,
|
|
||||||
gateway: conf.GatewayIP,
|
|
||||||
subnet: subnet,
|
|
||||||
addrSpace: addrSpace,
|
|
||||||
leaseTTL: conf.LeaseDuration,
|
|
||||||
}
|
|
||||||
i.implicitOpts, i.explicitOpts = conf.options()
|
|
||||||
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// options returns the implicit and explicit options for the interface. The two
|
// options returns the implicit and explicit options for the interface. The two
|
||||||
// lists are disjoint and the implicit options are initialized with default
|
// lists are disjoint and the implicit options are initialized with default
|
||||||
// values.
|
// values.
|
||||||
|
@ -318,3 +255,83 @@ func (conf *IPv4Config) options() (implicit, explicit layers.DHCPOptions) {
|
||||||
func compareV4OptionCodes(a, b layers.DHCPOption) (res int) {
|
func compareV4OptionCodes(a, b layers.DHCPOption) (res int) {
|
||||||
return int(a.Type) - int(b.Type)
|
return int(a.Type) - int(b.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// netInterfaceV4 is a DHCP interface for IPv4 address family.
|
||||||
|
type netInterfaceV4 struct {
|
||||||
|
// gateway is the IP address of the network gateway.
|
||||||
|
gateway netip.Addr
|
||||||
|
|
||||||
|
// subnet is the network subnet.
|
||||||
|
subnet netip.Prefix
|
||||||
|
|
||||||
|
// addrSpace is the IPv4 address space allocated for leasing.
|
||||||
|
addrSpace ipRange
|
||||||
|
|
||||||
|
// implicitOpts are the options listed in Appendix A of RFC 2131 and
|
||||||
|
// initialized with default values. It must not have intersections with
|
||||||
|
// explicitOpts.
|
||||||
|
implicitOpts layers.DHCPOptions
|
||||||
|
|
||||||
|
// explicitOpts are the user-configured options. It must not have
|
||||||
|
// intersections with implicitOpts.
|
||||||
|
explicitOpts layers.DHCPOptions
|
||||||
|
|
||||||
|
// netInterface is embedded here to provide some common network interface
|
||||||
|
// logic.
|
||||||
|
netInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
// newNetInterfaceV4 creates a new DHCP interface for IPv4 address family with
|
||||||
|
// the given configuration. It returns an error if the given configuration
|
||||||
|
// can't be used.
|
||||||
|
func newNetInterfaceV4(name string, conf *IPv4Config) (i *netInterfaceV4, err error) {
|
||||||
|
if !conf.Enabled {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
maskLen, _ := net.IPMask(conf.SubnetMask.AsSlice()).Size()
|
||||||
|
subnet := netip.PrefixFrom(conf.GatewayIP, maskLen)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case !subnet.Contains(conf.RangeStart):
|
||||||
|
return nil, fmt.Errorf("range start %s is not within %s", conf.RangeStart, subnet)
|
||||||
|
case !subnet.Contains(conf.RangeEnd):
|
||||||
|
return nil, fmt.Errorf("range end %s is not within %s", conf.RangeEnd, subnet)
|
||||||
|
}
|
||||||
|
|
||||||
|
addrSpace, err := newIPRange(conf.RangeStart, conf.RangeEnd)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else if addrSpace.contains(conf.GatewayIP) {
|
||||||
|
return nil, fmt.Errorf("gateway ip %s in the ip range %s", conf.GatewayIP, addrSpace)
|
||||||
|
}
|
||||||
|
|
||||||
|
i = &netInterfaceV4{
|
||||||
|
gateway: conf.GatewayIP,
|
||||||
|
subnet: subnet,
|
||||||
|
addrSpace: addrSpace,
|
||||||
|
netInterface: netInterface{
|
||||||
|
name: name,
|
||||||
|
leaseTTL: conf.LeaseDuration,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
i.implicitOpts, i.explicitOpts = conf.options()
|
||||||
|
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// netInterfacesV4 is a slice of network interfaces of IPv4 address family.
|
||||||
|
type netInterfacesV4 []*netInterfaceV4
|
||||||
|
|
||||||
|
// find returns the first network interface within ifaces containing ip. It
|
||||||
|
// returns false if there is no such interface.
|
||||||
|
func (ifaces netInterfacesV4) find(ip netip.Addr) (iface4 *netInterface, ok bool) {
|
||||||
|
i := slices.IndexFunc(ifaces, func(iface *netInterfaceV4) (contains bool) {
|
||||||
|
return iface.subnet.Contains(ip)
|
||||||
|
})
|
||||||
|
if i < 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ifaces[i].netInterface, true
|
||||||
|
}
|
||||||
|
|
|
@ -3,11 +3,12 @@ package dhcpsvc
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// IPv6Config is the interface-specific configuration for DHCPv6.
|
// IPv6Config is the interface-specific configuration for DHCPv6.
|
||||||
|
@ -52,57 +53,6 @@ func (conf *IPv6Config) validate() (err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// iface6 is a DHCP interface for IPv6 address family.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Add options.
|
|
||||||
type iface6 struct {
|
|
||||||
// rangeStart is the first IP address in the range.
|
|
||||||
rangeStart netip.Addr
|
|
||||||
|
|
||||||
// name is the name of the interface.
|
|
||||||
name string
|
|
||||||
|
|
||||||
// implicitOpts are the DHCPv6 options listed in RFC 8415 (and others) and
|
|
||||||
// initialized with default values. It must not have intersections with
|
|
||||||
// explicitOpts.
|
|
||||||
implicitOpts layers.DHCPv6Options
|
|
||||||
|
|
||||||
// explicitOpts are the user-configured options. It must not have
|
|
||||||
// intersections with implicitOpts.
|
|
||||||
explicitOpts layers.DHCPv6Options
|
|
||||||
|
|
||||||
// leaseTTL is the time-to-live of dynamic leases on this interface.
|
|
||||||
leaseTTL time.Duration
|
|
||||||
|
|
||||||
// raSLAACOnly defines if DHCP should send ICMPv6.RA packets without MO
|
|
||||||
// flags.
|
|
||||||
raSLAACOnly bool
|
|
||||||
|
|
||||||
// raAllowSLAAC defines if DHCP should send ICMPv6.RA packets with MO flags.
|
|
||||||
raAllowSLAAC bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// newIface6 creates a new DHCP interface for IPv6 address family with the given
|
|
||||||
// configuration.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Validate properly.
|
|
||||||
func newIface6(name string, conf *IPv6Config) (i *iface6) {
|
|
||||||
if !conf.Enabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
i = &iface6{
|
|
||||||
name: name,
|
|
||||||
rangeStart: conf.RangeStart,
|
|
||||||
leaseTTL: conf.LeaseDuration,
|
|
||||||
raSLAACOnly: conf.RASLAACOnly,
|
|
||||||
raAllowSLAAC: conf.RAAllowSLAAC,
|
|
||||||
}
|
|
||||||
i.implicitOpts, i.explicitOpts = conf.options()
|
|
||||||
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
|
|
||||||
// options returns the implicit and explicit options for the interface. The two
|
// options returns the implicit and explicit options for the interface. The two
|
||||||
// lists are disjoint and the implicit options are initialized with default
|
// lists are disjoint and the implicit options are initialized with default
|
||||||
// values.
|
// values.
|
||||||
|
@ -133,3 +83,79 @@ func (conf *IPv6Config) options() (implicit, explicit layers.DHCPv6Options) {
|
||||||
func compareV6OptionCodes(a, b layers.DHCPv6Option) (res int) {
|
func compareV6OptionCodes(a, b layers.DHCPv6Option) (res int) {
|
||||||
return int(a.Code) - int(b.Code)
|
return int(a.Code) - int(b.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// netInterfaceV6 is a DHCP interface for IPv6 address family.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Add options.
|
||||||
|
type netInterfaceV6 struct {
|
||||||
|
// rangeStart is the first IP address in the range.
|
||||||
|
rangeStart netip.Addr
|
||||||
|
|
||||||
|
// implicitOpts are the DHCPv6 options listed in RFC 8415 (and others) and
|
||||||
|
// initialized with default values. It must not have intersections with
|
||||||
|
// explicitOpts.
|
||||||
|
implicitOpts layers.DHCPv6Options
|
||||||
|
|
||||||
|
// explicitOpts are the user-configured options. It must not have
|
||||||
|
// intersections with implicitOpts.
|
||||||
|
explicitOpts layers.DHCPv6Options
|
||||||
|
|
||||||
|
// netInterface is embedded here to provide some common network interface
|
||||||
|
// logic.
|
||||||
|
netInterface
|
||||||
|
|
||||||
|
// raSLAACOnly defines if DHCP should send ICMPv6.RA packets without MO
|
||||||
|
// flags.
|
||||||
|
raSLAACOnly bool
|
||||||
|
|
||||||
|
// raAllowSLAAC defines if DHCP should send ICMPv6.RA packets with MO flags.
|
||||||
|
raAllowSLAAC bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// newNetInterfaceV6 creates a new DHCP interface for IPv6 address family with
|
||||||
|
// the given configuration.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Validate properly.
|
||||||
|
func newNetInterfaceV6(name string, conf *IPv6Config) (i *netInterfaceV6) {
|
||||||
|
if !conf.Enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
i = &netInterfaceV6{
|
||||||
|
rangeStart: conf.RangeStart,
|
||||||
|
netInterface: netInterface{
|
||||||
|
name: name,
|
||||||
|
leaseTTL: conf.LeaseDuration,
|
||||||
|
},
|
||||||
|
raSLAACOnly: conf.RASLAACOnly,
|
||||||
|
raAllowSLAAC: conf.RAAllowSLAAC,
|
||||||
|
}
|
||||||
|
i.implicitOpts, i.explicitOpts = conf.options()
|
||||||
|
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
// netInterfacesV4 is a slice of network interfaces of IPv4 address family.
|
||||||
|
type netInterfacesV6 []*netInterfaceV6
|
||||||
|
|
||||||
|
// find returns the first network interface within ifaces containing ip. It
|
||||||
|
// returns false if there is no such interface.
|
||||||
|
func (ifaces netInterfacesV6) find(ip netip.Addr) (iface6 *netInterface, ok bool) {
|
||||||
|
// prefLen is the length of prefix to match ip against.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): DHCPv6 inherits the weird behavior of legacy
|
||||||
|
// implementation where the allocated range constrained by the first address
|
||||||
|
// and the first address with last byte set to 0xff. Proper prefixes should
|
||||||
|
// be used instead.
|
||||||
|
const prefLen = netutil.IPv6BitLen - 8
|
||||||
|
|
||||||
|
i := slices.IndexFunc(ifaces, func(iface *netInterfaceV6) (contains bool) {
|
||||||
|
return !ip.Less(iface.rangeStart) &&
|
||||||
|
netip.PrefixFrom(iface.rangeStart, prefLen).Contains(ip)
|
||||||
|
})
|
||||||
|
if i < 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ifaces[i].netInterface, true
|
||||||
|
}
|
||||||
|
|
|
@ -14,6 +14,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateClientID returns an error if id is not a valid ClientID.
|
// ValidateClientID returns an error if id is not a valid ClientID.
|
||||||
|
//
|
||||||
|
// Keep in sync with [client.ValidateClientID].
|
||||||
func ValidateClientID(id string) (err error) {
|
func ValidateClientID(id string) (err error) {
|
||||||
err = netutil.ValidateHostnameLabel(id)
|
err = netutil.ValidateHostnameLabel(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -24,7 +25,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
"github.com/ameshkov/dnscrypt/v2"
|
"github.com/ameshkov/dnscrypt/v2"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClientsContainer provides information about preconfigured DNS clients.
|
// ClientsContainer provides information about preconfigured DNS clients.
|
||||||
|
@ -357,10 +357,6 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||||
conf.DNSCryptResolverCert = c.ResolverCert
|
conf.DNSCryptResolverCert = c.ResolverCert
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.UpstreamConfig == nil || len(conf.UpstreamConfig.Upstreams) == 0 {
|
|
||||||
return nil, errors.Error("no default upstream servers configured")
|
|
||||||
}
|
|
||||||
|
|
||||||
conf, err = prepareCacheConfig(conf,
|
conf, err = prepareCacheConfig(conf,
|
||||||
srvConf.CacheSize,
|
srvConf.CacheSize,
|
||||||
srvConf.CacheMinTTL,
|
srvConf.CacheMinTTL,
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
package dnsforward
|
package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAnyNameMatches(t *testing.T) {
|
func TestAnyNameMatches(t *testing.T) {
|
||||||
|
|
|
@ -2,54 +2,56 @@ package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// upstreamConfigValidator parses the [*proxy.UpstreamConfig] and checks the
|
// upstreamConfigValidator parses each section of an upstream configuration into
|
||||||
// actual DNS availability of each upstream.
|
// a corresponding [*proxy.UpstreamConfig] and checks the actual DNS
|
||||||
|
// availability of each upstream.
|
||||||
type upstreamConfigValidator struct {
|
type upstreamConfigValidator struct {
|
||||||
// general is the general upstream configuration.
|
// generalUpstreamResults contains upstream results of a general section.
|
||||||
general []*upstreamResult
|
generalUpstreamResults map[string]*upstreamResult
|
||||||
|
|
||||||
// fallback is the fallback upstream configuration.
|
// fallbackUpstreamResults contains upstream results of a fallback section.
|
||||||
fallback []*upstreamResult
|
fallbackUpstreamResults map[string]*upstreamResult
|
||||||
|
|
||||||
// private is the private upstream configuration.
|
// privateUpstreamResults contains upstream results of a private section.
|
||||||
private []*upstreamResult
|
privateUpstreamResults map[string]*upstreamResult
|
||||||
|
|
||||||
|
// generalParseResults contains parsing results of a general section.
|
||||||
|
generalParseResults []*parseResult
|
||||||
|
|
||||||
|
// fallbackParseResults contains parsing results of a fallback section.
|
||||||
|
fallbackParseResults []*parseResult
|
||||||
|
|
||||||
|
// privateParseResults contains parsing results of a private section.
|
||||||
|
privateParseResults []*parseResult
|
||||||
}
|
}
|
||||||
|
|
||||||
// upstreamResult is a result of validation of an [upstream.Upstream] within an
|
// upstreamResult is a result of parsing of an [upstream.Upstream] within an
|
||||||
// [proxy.UpstreamConfig].
|
// [proxy.UpstreamConfig].
|
||||||
type upstreamResult struct {
|
type upstreamResult struct {
|
||||||
// server is the parsed upstream. It is nil when there was an error during
|
// server is the parsed upstream.
|
||||||
// parsing.
|
|
||||||
server upstream.Upstream
|
server upstream.Upstream
|
||||||
|
|
||||||
// err is the error either from parsing or from checking the upstream.
|
// err is the upstream check error.
|
||||||
err error
|
err error
|
||||||
|
|
||||||
// original is the piece of configuration that have either been turned to an
|
|
||||||
// upstream or caused an error.
|
|
||||||
original string
|
|
||||||
|
|
||||||
// isSpecific is true if the upstream is domain-specific.
|
// isSpecific is true if the upstream is domain-specific.
|
||||||
isSpecific bool
|
isSpecific bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// compare compares two [upstreamResult]s. It returns 0 if they are equal, -1
|
// parseResult contains a original piece of upstream configuration and a
|
||||||
// if ur should be sorted before other, and 1 otherwise.
|
// corresponding error.
|
||||||
//
|
type parseResult struct {
|
||||||
// TODO(e.burkov): Perhaps it makes sense to sort the results with errors near
|
err *proxy.ParseError
|
||||||
// the end.
|
original string
|
||||||
func (ur *upstreamResult) compare(other *upstreamResult) (res int) {
|
|
||||||
return strings.Compare(ur.original, other.original)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newUpstreamConfigValidator parses the upstream configuration and returns a
|
// newUpstreamConfigValidator parses the upstream configuration and returns a
|
||||||
|
@ -61,97 +63,99 @@ func newUpstreamConfigValidator(
|
||||||
private []string,
|
private []string,
|
||||||
opts *upstream.Options,
|
opts *upstream.Options,
|
||||||
) (cv *upstreamConfigValidator) {
|
) (cv *upstreamConfigValidator) {
|
||||||
cv = &upstreamConfigValidator{}
|
cv = &upstreamConfigValidator{
|
||||||
|
generalUpstreamResults: map[string]*upstreamResult{},
|
||||||
|
fallbackUpstreamResults: map[string]*upstreamResult{},
|
||||||
|
privateUpstreamResults: map[string]*upstreamResult{},
|
||||||
|
}
|
||||||
|
|
||||||
for _, line := range general {
|
conf, err := proxy.ParseUpstreamsConfig(general, opts)
|
||||||
cv.general = cv.insertLineResults(cv.general, line, opts)
|
cv.generalParseResults = collectErrResults(general, err)
|
||||||
}
|
insertConfResults(conf, cv.generalUpstreamResults)
|
||||||
for _, line := range fallback {
|
|
||||||
cv.fallback = cv.insertLineResults(cv.fallback, line, opts)
|
conf, err = proxy.ParseUpstreamsConfig(fallback, opts)
|
||||||
}
|
cv.fallbackParseResults = collectErrResults(fallback, err)
|
||||||
for _, line := range private {
|
insertConfResults(conf, cv.fallbackUpstreamResults)
|
||||||
cv.private = cv.insertLineResults(cv.private, line, opts)
|
|
||||||
}
|
conf, err = proxy.ParseUpstreamsConfig(private, opts)
|
||||||
|
cv.privateParseResults = collectErrResults(private, err)
|
||||||
|
insertConfResults(conf, cv.privateUpstreamResults)
|
||||||
|
|
||||||
return cv
|
return cv
|
||||||
}
|
}
|
||||||
|
|
||||||
// insertLineResults parses line and inserts the result into s. It can insert
|
// collectErrResults parses err and returns parsing results containing the
|
||||||
// multiple results as well as none.
|
// original upstream configuration line and the corresponding error. err can be
|
||||||
func (cv *upstreamConfigValidator) insertLineResults(
|
// nil.
|
||||||
s []*upstreamResult,
|
func collectErrResults(lines []string, err error) (results []*parseResult) {
|
||||||
line string,
|
if err == nil {
|
||||||
opts *upstream.Options,
|
return nil
|
||||||
) (result []*upstreamResult) {
|
|
||||||
upstreams, isSpecific, err := splitUpstreamLine(line)
|
|
||||||
if err != nil {
|
|
||||||
return cv.insert(s, &upstreamResult{
|
|
||||||
err: err,
|
|
||||||
original: line,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, upstreamAddr := range upstreams {
|
// limit is a maximum length for upstream configuration lines.
|
||||||
var res *upstreamResult
|
const limit = 80
|
||||||
if upstreamAddr != "#" {
|
|
||||||
res = cv.parseUpstream(upstreamAddr, opts)
|
wrapper, ok := err.(errors.WrapperSlice)
|
||||||
} else if !isSpecific {
|
if !ok {
|
||||||
res = &upstreamResult{
|
log.Debug("dnsforward: configvalidator: unwrapping: %s", err)
|
||||||
err: errNotDomainSpecific,
|
|
||||||
original: upstreamAddr,
|
return nil
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
|
errs := wrapper.Unwrap()
|
||||||
|
results = make([]*parseResult, 0, len(errs))
|
||||||
|
for i, e := range errs {
|
||||||
|
var parseErr *proxy.ParseError
|
||||||
|
if !errors.As(e, &parseErr) {
|
||||||
|
log.Debug("dnsforward: configvalidator: inserting unexpected error %d: %s", i, err)
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
res.isSpecific = isSpecific
|
idx := parseErr.Idx
|
||||||
s = cv.insert(s, res)
|
line := []rune(lines[idx])
|
||||||
|
if len(line) > limit {
|
||||||
|
line = line[:limit]
|
||||||
|
line[limit-1] = '…'
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
results = append(results, &parseResult{
|
||||||
|
original: string(line),
|
||||||
|
err: parseErr,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert inserts r into slice in a sorted order, except duplicates. slice must
|
// insertConfResults parses conf and inserts the upstream result into results.
|
||||||
// not be nil.
|
// It can insert multiple results as well as none.
|
||||||
func (cv *upstreamConfigValidator) insert(
|
func insertConfResults(conf *proxy.UpstreamConfig, results map[string]*upstreamResult) {
|
||||||
s []*upstreamResult,
|
insertListResults(conf.Upstreams, results, false)
|
||||||
r *upstreamResult,
|
|
||||||
) (result []*upstreamResult) {
|
|
||||||
i, has := slices.BinarySearchFunc(s, r, (*upstreamResult).compare)
|
|
||||||
if has {
|
|
||||||
log.Debug("dnsforward: duplicate configuration %q", r.original)
|
|
||||||
|
|
||||||
return s
|
for _, ups := range conf.DomainReservedUpstreams {
|
||||||
|
insertListResults(ups, results, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
return slices.Insert(s, i, r)
|
for _, ups := range conf.SpecifiedDomainUpstreams {
|
||||||
|
insertListResults(ups, results, true)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseUpstream parses addr and returns the result of parsing. It returns nil
|
// insertListResults constructs upstream results from the upstream list and
|
||||||
// if the specified server points at the default upstream server which is
|
// inserts them into results. It can insert multiple results as well as none.
|
||||||
// validated separately.
|
func insertListResults(ups []upstream.Upstream, results map[string]*upstreamResult, specific bool) {
|
||||||
func (cv *upstreamConfigValidator) parseUpstream(
|
for _, u := range ups {
|
||||||
addr string,
|
addr := u.Address()
|
||||||
opts *upstream.Options,
|
_, ok := results[addr]
|
||||||
) (r *upstreamResult) {
|
if ok {
|
||||||
// Check if the upstream has a valid protocol prefix.
|
continue
|
||||||
//
|
|
||||||
// TODO(e.burkov): Validate the domain name.
|
|
||||||
if proto, _, ok := strings.Cut(addr, "://"); ok {
|
|
||||||
if !slices.Contains(protocols, proto) {
|
|
||||||
return &upstreamResult{
|
|
||||||
err: fmt.Errorf("bad protocol %q", proto),
|
|
||||||
original: addr,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ups, err := upstream.AddressToUpstream(addr, opts)
|
results[addr] = &upstreamResult{
|
||||||
|
server: u,
|
||||||
return &upstreamResult{
|
isSpecific: specific,
|
||||||
server: ups,
|
}
|
||||||
err: err,
|
|
||||||
original: addr,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,35 +191,30 @@ func (cv *upstreamConfigValidator) check() {
|
||||||
}
|
}
|
||||||
|
|
||||||
wg := &sync.WaitGroup{}
|
wg := &sync.WaitGroup{}
|
||||||
wg.Add(len(cv.general) + len(cv.fallback) + len(cv.private))
|
wg.Add(len(cv.generalUpstreamResults) +
|
||||||
|
len(cv.fallbackUpstreamResults) +
|
||||||
|
len(cv.privateUpstreamResults))
|
||||||
|
|
||||||
for _, res := range cv.general {
|
for _, res := range cv.generalUpstreamResults {
|
||||||
go cv.checkSrv(res, wg, commonChecker)
|
go checkSrv(res, wg, commonChecker)
|
||||||
}
|
}
|
||||||
for _, res := range cv.fallback {
|
for _, res := range cv.fallbackUpstreamResults {
|
||||||
go cv.checkSrv(res, wg, commonChecker)
|
go checkSrv(res, wg, commonChecker)
|
||||||
}
|
}
|
||||||
for _, res := range cv.private {
|
for _, res := range cv.privateUpstreamResults {
|
||||||
go cv.checkSrv(res, wg, arpaChecker)
|
go checkSrv(res, wg, arpaChecker)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkSrv runs hc on the server from res, if any, and stores any occurred
|
// checkSrv runs hc on the server from res, if any, and stores any occurred
|
||||||
// error in res. wg is always marked done in the end. It used to be called in
|
// error in res. wg is always marked done in the end. It is intended to be
|
||||||
// a separate goroutine.
|
// used as a goroutine.
|
||||||
func (cv *upstreamConfigValidator) checkSrv(
|
func checkSrv(res *upstreamResult, wg *sync.WaitGroup, hc *healthchecker) {
|
||||||
res *upstreamResult,
|
defer log.OnPanic(fmt.Sprintf("dnsforward: checking upstream %s", res.server.Address()))
|
||||||
wg *sync.WaitGroup,
|
|
||||||
hc *healthchecker,
|
|
||||||
) {
|
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
|
||||||
if res.server == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
res.err = hc.check(res.server)
|
res.err = hc.check(res.server)
|
||||||
if res.err != nil && res.isSpecific {
|
if res.err != nil && res.isSpecific {
|
||||||
res.err = domainSpecificTestError{Err: res.err}
|
res.err = domainSpecificTestError{Err: res.err}
|
||||||
|
@ -225,65 +224,126 @@ func (cv *upstreamConfigValidator) checkSrv(
|
||||||
// close closes all the upstreams that were successfully parsed. It enriches
|
// close closes all the upstreams that were successfully parsed. It enriches
|
||||||
// the results with deferred closing errors.
|
// the results with deferred closing errors.
|
||||||
func (cv *upstreamConfigValidator) close() {
|
func (cv *upstreamConfigValidator) close() {
|
||||||
for _, slice := range [][]*upstreamResult{cv.general, cv.fallback, cv.private} {
|
all := []map[string]*upstreamResult{
|
||||||
for _, r := range slice {
|
cv.generalUpstreamResults,
|
||||||
if r.server != nil {
|
cv.fallbackUpstreamResults,
|
||||||
|
cv.privateUpstreamResults,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range all {
|
||||||
|
for _, r := range m {
|
||||||
r.err = errors.WithDeferred(r.err, r.server.Close())
|
r.err = errors.WithDeferred(r.err, r.server.Close())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sections of the upstream configuration according to the text label of the
|
||||||
|
// localization.
|
||||||
|
//
|
||||||
|
// Keep in sync with client/src/__locales/en.json.
|
||||||
|
//
|
||||||
|
// TODO(s.chzhen): Refactor.
|
||||||
|
const (
|
||||||
|
generalTextLabel = "upstream_dns"
|
||||||
|
fallbackTextLabel = "fallback_dns_title"
|
||||||
|
privateTextLabel = "local_ptr_title"
|
||||||
|
)
|
||||||
|
|
||||||
// status returns all the data collected during parsing, healthcheck, and
|
// status returns all the data collected during parsing, healthcheck, and
|
||||||
// closing of the upstreams. The returned map is keyed by the original upstream
|
// closing of the upstreams. The returned map is keyed by the original upstream
|
||||||
// configuration piece and contains the corresponding error or "OK" if there was
|
// configuration piece and contains the corresponding error or "OK" if there was
|
||||||
// no error.
|
// no error.
|
||||||
func (cv *upstreamConfigValidator) status() (results map[string]string) {
|
func (cv *upstreamConfigValidator) status() (results map[string]string) {
|
||||||
result := map[string]string{}
|
// Names of the upstream configuration sections for logging.
|
||||||
|
const (
|
||||||
|
generalSection = "general"
|
||||||
|
fallbackSection = "fallback"
|
||||||
|
privateSection = "private"
|
||||||
|
)
|
||||||
|
|
||||||
for _, res := range cv.general {
|
results = map[string]string{}
|
||||||
resultToStatus("general", res, result)
|
|
||||||
|
for original, res := range cv.generalUpstreamResults {
|
||||||
|
upstreamResultToStatus(generalSection, string(original), res, results)
|
||||||
}
|
}
|
||||||
for _, res := range cv.fallback {
|
for original, res := range cv.fallbackUpstreamResults {
|
||||||
resultToStatus("fallback", res, result)
|
upstreamResultToStatus(fallbackSection, string(original), res, results)
|
||||||
}
|
}
|
||||||
for _, res := range cv.private {
|
for original, res := range cv.privateUpstreamResults {
|
||||||
resultToStatus("private", res, result)
|
upstreamResultToStatus(privateSection, string(original), res, results)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
parseResultToStatus(generalTextLabel, generalSection, cv.generalParseResults, results)
|
||||||
|
parseResultToStatus(fallbackTextLabel, fallbackSection, cv.fallbackParseResults, results)
|
||||||
|
parseResultToStatus(privateTextLabel, privateSection, cv.privateParseResults, results)
|
||||||
|
|
||||||
|
return results
|
||||||
}
|
}
|
||||||
|
|
||||||
// resultToStatus puts "OK" or an error message from res into resMap. section
|
// upstreamResultToStatus puts "OK" or an error message from res into resMap.
|
||||||
// is the name of the upstream configuration section, i.e. "general",
|
// section is the name of the upstream configuration section, i.e. "general",
|
||||||
// "fallback", or "private", and only used for logging.
|
// "fallback", or "private", and only used for logging.
|
||||||
//
|
//
|
||||||
// TODO(e.burkov): Currently, the HTTP handler expects that all the results are
|
// TODO(e.burkov): Currently, the HTTP handler expects that all the results are
|
||||||
// put together in a single map, which may lead to collisions, see AG-27539.
|
// put together in a single map, which may lead to collisions, see AG-27539.
|
||||||
// Improve the results compilation.
|
// Improve the results compilation.
|
||||||
func resultToStatus(section string, res *upstreamResult, resMap map[string]string) {
|
func upstreamResultToStatus(
|
||||||
|
section string,
|
||||||
|
original string,
|
||||||
|
res *upstreamResult,
|
||||||
|
resMap map[string]string,
|
||||||
|
) {
|
||||||
val := "OK"
|
val := "OK"
|
||||||
if res.err != nil {
|
if res.err != nil {
|
||||||
val = res.err.Error()
|
val = res.err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
prevVal := resMap[res.original]
|
prevVal := resMap[original]
|
||||||
switch prevVal {
|
switch prevVal {
|
||||||
case "":
|
case "":
|
||||||
resMap[res.original] = val
|
resMap[original] = val
|
||||||
case val:
|
case val:
|
||||||
log.Debug("dnsforward: duplicating %s config line %q", section, res.original)
|
log.Debug("dnsforward: duplicating %s config line %q", section, original)
|
||||||
default:
|
default:
|
||||||
log.Debug(
|
log.Debug(
|
||||||
"dnsforward: warning: %s config line %q (%v) had different result %v",
|
"dnsforward: warning: %s config line %q (%v) had different result %v",
|
||||||
section,
|
section,
|
||||||
val,
|
val,
|
||||||
res.original,
|
original,
|
||||||
prevVal,
|
prevVal,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseResultToStatus puts parsing error messages from results into resMap.
|
||||||
|
// section is the name of the upstream configuration section, i.e. "general",
|
||||||
|
// "fallback", or "private", and only used for logging.
|
||||||
|
//
|
||||||
|
// Parsing error message has the following format:
|
||||||
|
//
|
||||||
|
// sectionTextLabel line: parsing error
|
||||||
|
//
|
||||||
|
// Where sectionTextLabel is a section text label of a localization and line is
|
||||||
|
// a line number.
|
||||||
|
func parseResultToStatus(
|
||||||
|
textLabel string,
|
||||||
|
section string,
|
||||||
|
results []*parseResult,
|
||||||
|
resMap map[string]string,
|
||||||
|
) {
|
||||||
|
for _, res := range results {
|
||||||
|
original := res.original
|
||||||
|
_, ok := resMap[original]
|
||||||
|
if ok {
|
||||||
|
log.Debug("dnsforward: duplicating %s parsing error %q", section, original)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resMap[original] = fmt.Sprintf("%s %d: parsing error", textLabel, res.err.Idx+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
// domainSpecificTestError is a wrapper for errors returned by checkDNS to mark
|
||||||
// the tested upstream domain-specific and therefore consider its errors
|
// the tested upstream domain-specific and therefore consider its errors
|
||||||
// non-critical.
|
// non-critical.
|
||||||
|
@ -342,7 +402,7 @@ func (h *healthchecker) check(u upstream.Upstream) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
return fmt.Errorf("couldn't communicate with upstream: %w", err)
|
||||||
} else if h.ansEmpty && len(reply.Answer) > 0 {
|
} else if h.ansEmpty && len(reply.Answer) > 0 {
|
||||||
return errWrongResponse
|
return errors.Error("wrong response")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
@ -101,21 +100,6 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||||
type answerMap = map[uint16][sectionsNum][]dns.RR
|
type answerMap = map[uint16][sectionsNum][]dns.RR
|
||||||
|
|
||||||
pt := testutil.PanicT{}
|
pt := testutil.PanicT{}
|
||||||
newUps := func(answers answerMap) (u upstream.Upstream) {
|
|
||||||
return aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
|
||||||
q := req.Question[0]
|
|
||||||
require.Contains(pt, answers, q.Qtype)
|
|
||||||
|
|
||||||
answer := answers[q.Qtype]
|
|
||||||
|
|
||||||
resp = (&dns.Msg{}).SetReply(req)
|
|
||||||
resp.Answer = answer[sectionAnswer]
|
|
||||||
resp.Ns = answer[sectionAuthority]
|
|
||||||
resp.Extra = answer[sectionAdditional]
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -265,13 +249,16 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||||
}}
|
}}
|
||||||
|
|
||||||
localRR := newRR(t, ptr64Domain, dns.TypePTR, 3600, pointedDomain)
|
localRR := newRR(t, ptr64Domain, dns.TypePTR, 3600, pointedDomain)
|
||||||
localUps := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
|
||||||
require.Equal(pt, req.Question[0].Name, ptr64Domain)
|
require.Len(pt, m.Question, 1)
|
||||||
resp = (&dns.Msg{}).SetReply(req)
|
require.Equal(pt, m.Question[0].Name, ptr64Domain)
|
||||||
resp.Answer = []dns.RR{localRR}
|
resp := (&dns.Msg{
|
||||||
|
Answer: []dns.RR{localRR},
|
||||||
|
}).SetReply(m)
|
||||||
|
|
||||||
return resp, nil
|
require.NoError(t, w.WriteMsg(resp))
|
||||||
})
|
})
|
||||||
|
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||||
|
|
||||||
client := &dns.Client{
|
client := &dns.Client{
|
||||||
Net: "tcp",
|
Net: "tcp",
|
||||||
|
@ -279,10 +266,28 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
// TODO(e.burkov): It seems [proxy.Proxy] isn't intended to be reused
|
tc := tc
|
||||||
// right after stop, due to a data race in [proxy.Proxy.Init] method
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
// when setting an OOB size. As a temporary workaround, recreate the
|
upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||||
// whole server for each test case.
|
q := req.Question[0]
|
||||||
|
require.Contains(pt, tc.upsAns, q.Qtype)
|
||||||
|
|
||||||
|
answer := tc.upsAns[q.Qtype]
|
||||||
|
|
||||||
|
resp := (&dns.Msg{
|
||||||
|
Answer: answer[sectionAnswer],
|
||||||
|
Ns: answer[sectionAuthority],
|
||||||
|
Extra: answer[sectionAdditional],
|
||||||
|
}).SetReply(req)
|
||||||
|
|
||||||
|
require.NoError(pt, w.WriteMsg(resp))
|
||||||
|
})
|
||||||
|
upsAddr := aghtest.StartLocalhostUpstream(t, upsHdlr).String()
|
||||||
|
|
||||||
|
// TODO(e.burkov): It seems [proxy.Proxy] isn't intended to be
|
||||||
|
// reused right after stop, due to a data race in [proxy.Proxy.Init]
|
||||||
|
// method when setting an OOB size. As a temporary workaround,
|
||||||
|
// recreate the whole server for each test case.
|
||||||
s := createTestServer(t, &filtering.Config{
|
s := createTestServer(t, &filtering.Config{
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
}, ServerConfig{
|
}, ServerConfig{
|
||||||
|
@ -292,12 +297,13 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
|
||||||
Config: Config{
|
Config: Config{
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
UpstreamDNS: []string{upsAddr},
|
||||||
},
|
},
|
||||||
|
UsePrivateRDNS: true,
|
||||||
|
LocalPTRResolvers: []string{localUpsAddr},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}, localUps)
|
})
|
||||||
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newUps(tc.upsAns)}
|
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
|
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -30,7 +31,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/netutil/sysresolv"
|
"github.com/AdguardTeam/golibs/netutil/sysresolv"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultTimeout is the default upstream timeout
|
// DefaultTimeout is the default upstream timeout
|
||||||
|
@ -464,7 +464,8 @@ func (s *Server) Start() error {
|
||||||
// startLocked starts the DNS server without locking. s.serverLock is expected
|
// startLocked starts the DNS server without locking. s.serverLock is expected
|
||||||
// to be locked.
|
// to be locked.
|
||||||
func (s *Server) startLocked() error {
|
func (s *Server) startLocked() error {
|
||||||
err := s.dnsProxy.Start()
|
// TODO(e.burkov): Use context properly.
|
||||||
|
err := s.dnsProxy.Start(context.Background())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s.isRunning = true
|
s.isRunning = true
|
||||||
}
|
}
|
||||||
|
@ -518,34 +519,30 @@ func (s *Server) prepareLocalResolvers(
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupLocalResolvers initializes and sets the resolvers for local addresses.
|
// setupLocalResolvers initializes and sets the resolvers for local addresses.
|
||||||
// It assumes s.serverLock is locked or s not running.
|
// It assumes s.serverLock is locked or s not running. It returns the upstream
|
||||||
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (err error) {
|
// configuration used for private PTR resolving, or nil if it's disabled. Note,
|
||||||
uc, err := s.prepareLocalResolvers(boot)
|
// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||||
|
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
|
||||||
|
if !s.conf.UsePrivateRDNS {
|
||||||
|
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
uc, err = s.prepareLocalResolvers(boot)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error because it's informative enough as is.
|
// Don't wrap the error because it's informative enough as is.
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.localResolvers = &proxy.Proxy{
|
s.localResolvers, err = proxy.New(&proxy.Config{
|
||||||
Config: proxy.Config{
|
|
||||||
UpstreamConfig: uc,
|
UpstreamConfig: uc,
|
||||||
},
|
})
|
||||||
}
|
|
||||||
|
|
||||||
err = s.localResolvers.Init()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("initializing proxy: %w", err)
|
return nil, fmt.Errorf("creating local resolvers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(e.burkov): Should we also consider the DNS64 usage?
|
// TODO(e.burkov): Should we also consider the DNS64 usage?
|
||||||
if s.conf.UsePrivateRDNS &&
|
return uc, nil
|
||||||
// Only set the upstream config if there are any upstreams. It's safe
|
|
||||||
// to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
|
|
||||||
len(uc.Upstreams)+len(uc.DomainReservedUpstreams)+len(uc.SpecifiedDomainUpstreams) > 0 {
|
|
||||||
s.dnsProxy.PrivateRDNSUpstreamConfig = uc
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare initializes parameters of s using data from conf. conf must not be
|
// Prepare initializes parameters of s using data from conf. conf must not be
|
||||||
|
@ -586,21 +583,22 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||||
return fmt.Errorf("preparing access: %w", err)
|
return fmt.Errorf("preparing access: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the proxy here because [setupLocalResolvers] sets its values.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
|
||||||
s.dnsProxy = &proxy.Proxy{Config: *proxyConfig}
|
proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
|
||||||
|
|
||||||
err = s.setupLocalResolvers(boot)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting up resolvers: %w", err)
|
return fmt.Errorf("setting up resolvers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = s.setupFallbackDNS()
|
proxyConfig.Fallbacks, err = s.setupFallbackDNS()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting up fallback dns servers: %w", err)
|
return fmt.Errorf("setting up fallback dns servers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.dnsProxy, err = proxy.New(proxyConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating proxy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
s.recDetector.clear()
|
s.recDetector.clear()
|
||||||
|
|
||||||
s.setupAddrProc()
|
s.setupAddrProc()
|
||||||
|
@ -643,26 +641,25 @@ func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupFallbackDNS initializes the fallback DNS servers.
|
// setupFallbackDNS initializes the fallback DNS servers.
|
||||||
func (s *Server) setupFallbackDNS() (err error) {
|
func (s *Server) setupFallbackDNS() (uc *proxy.UpstreamConfig, err error) {
|
||||||
fallbacks := s.conf.FallbackDNS
|
fallbacks := s.conf.FallbackDNS
|
||||||
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty)
|
fallbacks = stringutil.FilterOut(fallbacks, IsCommentOrEmpty)
|
||||||
if len(fallbacks) == 0 {
|
if len(fallbacks) == 0 {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
uc, err := proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{
|
uc, err = proxy.ParseUpstreamsConfig(fallbacks, &upstream.Options{
|
||||||
// TODO(s.chzhen): Investigate if other options are needed.
|
// TODO(s.chzhen): Investigate if other options are needed.
|
||||||
Timeout: s.conf.UpstreamTimeout,
|
Timeout: s.conf.UpstreamTimeout,
|
||||||
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
PreferIPv6: s.conf.BootstrapPreferIPv6,
|
||||||
|
// TODO(e.burkov): Use bootstrap.
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Do not wrap the error because it's informative enough as is.
|
// Do not wrap the error because it's informative enough as is.
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.dnsProxy.Fallbacks = uc
|
return uc, nil
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupAddrProc initializes the address processor. It assumes s.serverLock is
|
// setupAddrProc initializes the address processor. It assumes s.serverLock is
|
||||||
|
@ -730,19 +727,9 @@ func (s *Server) prepareInternalProxy() (err error) {
|
||||||
return fmt.Errorf("invalid upstream mode: %w", err)
|
return fmt.Errorf("invalid upstream mode: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(a.garipov): Make a proper constructor for proxy.Proxy.
|
s.internalProxy, err = proxy.New(conf)
|
||||||
p := &proxy.Proxy{
|
|
||||||
Config: *conf,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = p.Init()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
|
|
||||||
s.internalProxy = p
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the DNS server.
|
// Stop stops the DNS server.
|
||||||
|
@ -761,14 +748,17 @@ func (s *Server) stopLocked() (err error) {
|
||||||
// [upstream.Upstream] implementations.
|
// [upstream.Upstream] implementations.
|
||||||
|
|
||||||
if s.dnsProxy != nil {
|
if s.dnsProxy != nil {
|
||||||
err = s.dnsProxy.Stop()
|
// TODO(e.burkov): Use context properly.
|
||||||
|
err = s.dnsProxy.Shutdown(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("dnsforward: closing primary resolvers: %s", err)
|
log.Error("dnsforward: closing primary resolvers: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
|
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
|
||||||
|
if s.localResolvers != nil {
|
||||||
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
|
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
|
||||||
|
}
|
||||||
|
|
||||||
for _, b := range s.bootResolvers {
|
for _, b := range s.bootResolvers {
|
||||||
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
|
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
|
||||||
|
|
|
@ -5,9 +5,11 @@ import (
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
@ -63,8 +65,7 @@ func startDeferStop(t *testing.T, s *Server) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
err := s.Start()
|
err := s.Start()
|
||||||
require.NoErrorf(t, err, "failed to start server: %s", err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,7 +73,6 @@ func createTestServer(
|
||||||
t *testing.T,
|
t *testing.T,
|
||||||
filterConf *filtering.Config,
|
filterConf *filtering.Config,
|
||||||
forwardConf ServerConfig,
|
forwardConf ServerConfig,
|
||||||
localUps upstream.Upstream,
|
|
||||||
) (s *Server) {
|
) (s *Server) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
|
@ -82,7 +82,8 @@ func createTestServer(
|
||||||
@@||whitelist.example.org^
|
@@||whitelist.example.org^
|
||||||
||127.0.0.255`
|
||127.0.0.255`
|
||||||
filters := []filtering.Filter{{
|
filters := []filtering.Filter{{
|
||||||
ID: 0, Data: []byte(rules),
|
ID: 0,
|
||||||
|
Data: []byte(rules),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
f, err := filtering.New(filterConf, filters)
|
f, err := filtering.New(filterConf, filters)
|
||||||
|
@ -105,19 +106,6 @@ func createTestServer(
|
||||||
err = s.Prepare(&forwardConf)
|
err = s.Prepare(&forwardConf)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s.serverLock.Lock()
|
|
||||||
defer s.serverLock.Unlock()
|
|
||||||
|
|
||||||
// TODO(e.burkov): Try to move it higher.
|
|
||||||
if localUps != nil {
|
|
||||||
ups := []upstream.Upstream{localUps}
|
|
||||||
s.localResolvers.UpstreamConfig.Upstreams = ups
|
|
||||||
s.conf.UsePrivateRDNS = true
|
|
||||||
s.dnsProxy.PrivateRDNSUpstreamConfig = &proxy.UpstreamConfig{
|
|
||||||
Upstreams: ups,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,7 +169,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte)
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}, nil)
|
})
|
||||||
|
|
||||||
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
|
tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem
|
||||||
s.conf.TLSConfig = tlsConf
|
s.conf.TLSConfig = tlsConf
|
||||||
|
@ -310,7 +298,7 @@ func TestServer(t *testing.T) {
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}, nil)
|
})
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
|
@ -410,7 +398,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}, nil)
|
})
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
|
@ -490,7 +478,7 @@ func TestServerRace(t *testing.T) {
|
||||||
ConfigModified: func() {},
|
ConfigModified: func() {},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
s := createTestServer(t, filterConf, forwardConf)
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
|
@ -545,7 +533,7 @@ func TestSafeSearch(t *testing.T) {
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
s := createTestServer(t, filterConf, forwardConf)
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||||
|
@ -628,7 +616,7 @@ func TestInvalidRequest(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}, nil)
|
})
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||||
|
@ -662,7 +650,7 @@ func TestBlockedRequest(t *testing.T) {
|
||||||
s := createTestServer(t, &filtering.Config{
|
s := createTestServer(t, &filtering.Config{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
}, forwardConf, nil)
|
}, forwardConf)
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
@ -698,7 +686,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
|
||||||
}
|
}
|
||||||
s := createTestServer(t, &filtering.Config{
|
s := createTestServer(t, &filtering.Config{
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
}, forwardConf, nil)
|
}, forwardConf)
|
||||||
|
|
||||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||||
atomic.AddUint32(&upsCalledCounter, 1)
|
atomic.AddUint32(&upsCalledCounter, 1)
|
||||||
|
@ -773,7 +761,7 @@ func TestBlockCNAMEProtectionEnabled(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}, nil)
|
})
|
||||||
testUpstm := &aghtest.Upstream{
|
testUpstm := &aghtest.Upstream{
|
||||||
CName: testCNAMEs,
|
CName: testCNAMEs,
|
||||||
IPv4: testIPv4,
|
IPv4: testIPv4,
|
||||||
|
@ -811,7 +799,7 @@ func TestBlockCNAME(t *testing.T) {
|
||||||
s := createTestServer(t, &filtering.Config{
|
s := createTestServer(t, &filtering.Config{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
}, forwardConf, nil)
|
}, forwardConf)
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.Upstream{
|
&aghtest.Upstream{
|
||||||
CName: testCNAMEs,
|
CName: testCNAMEs,
|
||||||
|
@ -886,7 +874,7 @@ func TestClientRulesForCNAMEMatching(t *testing.T) {
|
||||||
}
|
}
|
||||||
s := createTestServer(t, &filtering.Config{
|
s := createTestServer(t, &filtering.Config{
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
}, forwardConf, nil)
|
}, forwardConf)
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||||
&aghtest.Upstream{
|
&aghtest.Upstream{
|
||||||
CName: testCNAMEs,
|
CName: testCNAMEs,
|
||||||
|
@ -933,7 +921,7 @@ func TestNullBlockedRequest(t *testing.T) {
|
||||||
s := createTestServer(t, &filtering.Config{
|
s := createTestServer(t, &filtering.Config{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
BlockingMode: filtering.BlockingModeNullIP,
|
BlockingMode: filtering.BlockingModeNullIP,
|
||||||
}, forwardConf, nil)
|
}, forwardConf)
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
|
@ -1054,7 +1042,7 @@ func TestBlockedByHosts(t *testing.T) {
|
||||||
s := createTestServer(t, &filtering.Config{
|
s := createTestServer(t, &filtering.Config{
|
||||||
ProtectionEnabled: true,
|
ProtectionEnabled: true,
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
}, forwardConf, nil)
|
}, forwardConf)
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
|
@ -1102,7 +1090,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) {
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
s := createTestServer(t, filterConf, forwardConf)
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||||
|
|
||||||
|
@ -1330,6 +1318,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||||
|
|
||||||
var eventsCalledCounter uint32
|
var eventsCalledCounter uint32
|
||||||
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
|
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
|
||||||
|
OnStart: func() (_ error) { panic("not implemented") },
|
||||||
OnEvents: func() (e <-chan struct{}) {
|
OnEvents: func() (e <-chan struct{}) {
|
||||||
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))
|
||||||
|
|
||||||
|
@ -1481,6 +1470,8 @@ func TestServer_Exchange(t *testing.T) {
|
||||||
onesIP = netip.MustParseAddr("1.1.1.1")
|
onesIP = netip.MustParseAddr("1.1.1.1")
|
||||||
twosIP = netip.MustParseAddr("2.2.2.2")
|
twosIP = netip.MustParseAddr("2.2.2.2")
|
||||||
localIP = netip.MustParseAddr("192.168.1.1")
|
localIP = netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
pt = testutil.PanicT{}
|
||||||
)
|
)
|
||||||
|
|
||||||
onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
|
onesRevExtIPv4, err := netutil.IPToReversedAddr(onesIP.AsSlice())
|
||||||
|
@ -1489,72 +1480,73 @@ func TestServer_Exchange(t *testing.T) {
|
||||||
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
|
twosRevExtIPv4, err := netutil.IPToReversedAddr(twosIP.AsSlice())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
extUpstream := &aghtest.UpstreamMock{
|
extUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||||
OnAddress: func() (addr string) { return "external.upstream.example" },
|
resp := aghalg.Coalesce(
|
||||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, dns.Fqdn(onesHost)),
|
||||||
return aghalg.Coalesce(
|
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, dns.Fqdn(twosHost))),
|
||||||
aghtest.MatchedResponse(req, dns.TypePTR, onesRevExtIPv4, onesHost),
|
|
||||||
doubleTTL(aghtest.MatchedResponse(req, dns.TypePTR, twosRevExtIPv4, twosHost)),
|
|
||||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||||
), nil
|
)
|
||||||
},
|
|
||||||
}
|
require.NoError(pt, w.WriteMsg(resp))
|
||||||
|
})
|
||||||
|
upsAddr := aghtest.StartLocalhostUpstream(t, extUpsHdlr).String()
|
||||||
|
|
||||||
revLocIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
|
revLocIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
locUpstream := &aghtest.UpstreamMock{
|
locUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||||
OnAddress: func() (addr string) { return "local.upstream.example" },
|
resp := aghalg.Coalesce(
|
||||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, dns.Fqdn(localDomainHost)),
|
||||||
return aghalg.Coalesce(
|
|
||||||
aghtest.MatchedResponse(req, dns.TypePTR, revLocIPv4, localDomainHost),
|
|
||||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||||
), nil
|
)
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
errUpstream := aghtest.NewErrorUpstream()
|
require.NoError(pt, w.WriteMsg(resp))
|
||||||
nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true)
|
|
||||||
refusingUpstream := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
|
||||||
return new(dns.Msg).SetRcode(req, dns.RcodeRefused), nil
|
|
||||||
})
|
})
|
||||||
zeroTTLUps := &aghtest.UpstreamMock{
|
|
||||||
OnAddress: func() (addr string) { return "zero.ttl.example" },
|
errUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||||
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
|
require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeServerFailure)))
|
||||||
resp = new(dns.Msg).SetReply(req)
|
})
|
||||||
hdr := dns.RR_Header{
|
|
||||||
|
nonPtrHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||||
|
hash := sha256.Sum256([]byte("some-host"))
|
||||||
|
resp := (&dns.Msg{
|
||||||
|
Answer: []dns.RR{&dns.TXT{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: req.Question[0].Name,
|
||||||
|
Rrtype: dns.TypeTXT,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 60,
|
||||||
|
},
|
||||||
|
Txt: []string{hex.EncodeToString(hash[:])},
|
||||||
|
}},
|
||||||
|
}).SetReply(req)
|
||||||
|
|
||||||
|
require.NoError(pt, w.WriteMsg(resp))
|
||||||
|
})
|
||||||
|
refusingHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||||
|
require.NoError(pt, w.WriteMsg(new(dns.Msg).SetRcode(req, dns.RcodeRefused)))
|
||||||
|
})
|
||||||
|
|
||||||
|
zeroTTLHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||||
|
resp := (&dns.Msg{
|
||||||
|
Answer: []dns.RR{&dns.PTR{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
Name: req.Question[0].Name,
|
Name: req.Question[0].Name,
|
||||||
Rrtype: dns.TypePTR,
|
Rrtype: dns.TypePTR,
|
||||||
Class: dns.ClassINET,
|
Class: dns.ClassINET,
|
||||||
Ttl: 0,
|
Ttl: 0,
|
||||||
}
|
},
|
||||||
resp.Answer = []dns.RR{&dns.PTR{
|
Ptr: dns.Fqdn(localDomainHost),
|
||||||
Hdr: hdr,
|
}},
|
||||||
Ptr: localDomainHost,
|
}).SetReply(req)
|
||||||
}}
|
|
||||||
|
|
||||||
return resp, nil
|
require.NoError(pt, w.WriteMsg(resp))
|
||||||
},
|
})
|
||||||
}
|
|
||||||
|
|
||||||
srv := &Server{
|
|
||||||
recDetector: newRecursionDetector(0, 1),
|
|
||||||
internalProxy: &proxy.Proxy{
|
|
||||||
Config: proxy.Config{
|
|
||||||
UpstreamConfig: &proxy.UpstreamConfig{
|
|
||||||
Upstreams: []upstream.Upstream{extUpstream},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
srv.conf.UsePrivateRDNS = true
|
|
||||||
srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
|
||||||
require.NoError(t, srv.internalProxy.Init())
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
req netip.Addr
|
req netip.Addr
|
||||||
wantErr error
|
wantErr error
|
||||||
locUpstream upstream.Upstream
|
locUpstream dns.Handler
|
||||||
name string
|
name string
|
||||||
want string
|
want string
|
||||||
wantTTL time.Duration
|
wantTTL time.Duration
|
||||||
|
@ -1569,35 +1561,35 @@ func TestServer_Exchange(t *testing.T) {
|
||||||
name: "local_good",
|
name: "local_good",
|
||||||
want: localDomainHost,
|
want: localDomainHost,
|
||||||
wantErr: nil,
|
wantErr: nil,
|
||||||
locUpstream: locUpstream,
|
locUpstream: locUpsHdlr,
|
||||||
req: localIP,
|
req: localIP,
|
||||||
wantTTL: defaultTTL,
|
wantTTL: defaultTTL,
|
||||||
}, {
|
}, {
|
||||||
name: "upstream_error",
|
name: "upstream_error",
|
||||||
want: "",
|
want: "",
|
||||||
wantErr: aghtest.ErrUpstream,
|
wantErr: ErrRDNSFailed,
|
||||||
locUpstream: errUpstream,
|
locUpstream: errUpsHdlr,
|
||||||
req: localIP,
|
req: localIP,
|
||||||
wantTTL: 0,
|
wantTTL: 0,
|
||||||
}, {
|
}, {
|
||||||
name: "empty_answer_error",
|
name: "empty_answer_error",
|
||||||
want: "",
|
want: "",
|
||||||
wantErr: ErrRDNSNoData,
|
wantErr: ErrRDNSNoData,
|
||||||
locUpstream: locUpstream,
|
locUpstream: locUpsHdlr,
|
||||||
req: netip.MustParseAddr("192.168.1.2"),
|
req: netip.MustParseAddr("192.168.1.2"),
|
||||||
wantTTL: 0,
|
wantTTL: 0,
|
||||||
}, {
|
}, {
|
||||||
name: "invalid_answer",
|
name: "invalid_answer",
|
||||||
want: "",
|
want: "",
|
||||||
wantErr: ErrRDNSNoData,
|
wantErr: ErrRDNSNoData,
|
||||||
locUpstream: nonPtrUpstream,
|
locUpstream: nonPtrHdlr,
|
||||||
req: localIP,
|
req: localIP,
|
||||||
wantTTL: 0,
|
wantTTL: 0,
|
||||||
}, {
|
}, {
|
||||||
name: "refused",
|
name: "refused",
|
||||||
want: "",
|
want: "",
|
||||||
wantErr: ErrRDNSFailed,
|
wantErr: ErrRDNSFailed,
|
||||||
locUpstream: refusingUpstream,
|
locUpstream: refusingHdlr,
|
||||||
req: localIP,
|
req: localIP,
|
||||||
wantTTL: 0,
|
wantTTL: 0,
|
||||||
}, {
|
}, {
|
||||||
|
@ -1611,23 +1603,28 @@ func TestServer_Exchange(t *testing.T) {
|
||||||
name: "zero_ttl",
|
name: "zero_ttl",
|
||||||
want: localDomainHost,
|
want: localDomainHost,
|
||||||
wantErr: nil,
|
wantErr: nil,
|
||||||
locUpstream: zeroTTLUps,
|
locUpstream: zeroTTLHdlr,
|
||||||
req: localIP,
|
req: localIP,
|
||||||
wantTTL: 0,
|
wantTTL: 0,
|
||||||
}}
|
}}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
pcfg := proxy.Config{
|
localUpsAddr := aghtest.StartLocalhostUpstream(t, tc.locUpstream).String()
|
||||||
UpstreamConfig: &proxy.UpstreamConfig{
|
|
||||||
Upstreams: []upstream.Upstream{tc.locUpstream},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
srv.localResolvers = &proxy.Proxy{
|
|
||||||
Config: pcfg,
|
|
||||||
}
|
|
||||||
require.NoError(t, srv.localResolvers.Init())
|
|
||||||
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
srv := createTestServer(t, &filtering.Config{
|
||||||
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
|
}, ServerConfig{
|
||||||
|
Config: Config{
|
||||||
|
UpstreamDNS: []string{upsAddr},
|
||||||
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
},
|
||||||
|
LocalPTRResolvers: []string{localUpsAddr},
|
||||||
|
UsePrivateRDNS: true,
|
||||||
|
ServePlainDNS: true,
|
||||||
|
})
|
||||||
|
|
||||||
host, ttl, eerr := srv.Exchange(tc.req)
|
host, ttl, eerr := srv.Exchange(tc.req)
|
||||||
|
|
||||||
require.ErrorIs(t, eerr, tc.wantErr)
|
require.ErrorIs(t, eerr, tc.wantErr)
|
||||||
|
@ -1637,8 +1634,17 @@ func TestServer_Exchange(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("resolving_disabled", func(t *testing.T) {
|
t.Run("resolving_disabled", func(t *testing.T) {
|
||||||
srv.conf.UsePrivateRDNS = false
|
srv := createTestServer(t, &filtering.Config{
|
||||||
t.Cleanup(func() { srv.conf.UsePrivateRDNS = true })
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
|
}, ServerConfig{
|
||||||
|
Config: Config{
|
||||||
|
UpstreamDNS: []string{upsAddr},
|
||||||
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
},
|
||||||
|
LocalPTRResolvers: []string{},
|
||||||
|
ServePlainDNS: true,
|
||||||
|
})
|
||||||
|
|
||||||
host, _, eerr := srv.Exchange(localIP)
|
host, _, eerr := srv.Exchange(localIP)
|
||||||
|
|
||||||
|
|
|
@ -42,7 +42,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}, nil)
|
})
|
||||||
|
|
||||||
makeQ := func(qtype rules.RRType) (req *dns.Msg) {
|
makeQ := func(qtype rules.RRType) (req *dns.Msg) {
|
||||||
return &dns.Msg{
|
return &dns.Msg{
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
@ -12,7 +13,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// beforeRequestHandler is the handler that is called before any other
|
// beforeRequestHandler is the handler that is called before any other
|
||||||
|
|
|
@ -6,16 +6,17 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// jsonDNSConfig is the JSON representation of the DNS server configuration.
|
// jsonDNSConfig is the JSON representation of the DNS server configuration.
|
||||||
|
@ -294,7 +295,7 @@ func (req *jsonDNSConfig) checkFallbacks() (err error) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ValidateUpstreams(*req.Fallbacks)
|
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("fallback servers: %w", err)
|
return fmt.Errorf("fallback servers: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -344,7 +345,7 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
||||||
// validateUpstreamDNSServers returns an error if any field of req is invalid.
|
// validateUpstreamDNSServers returns an error if any field of req is invalid.
|
||||||
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
|
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
|
||||||
if req.Upstreams != nil {
|
if req.Upstreams != nil {
|
||||||
err = ValidateUpstreams(*req.Upstreams)
|
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("upstream servers: %w", err)
|
return fmt.Errorf("upstream servers: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -580,9 +581,6 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
req.Upstreams = stringutil.FilterOut(req.Upstreams, IsCommentOrEmpty)
|
|
||||||
req.FallbackDNS = stringutil.FilterOut(req.FallbackDNS, IsCommentOrEmpty)
|
|
||||||
req.PrivateUpstreams = stringutil.FilterOut(req.PrivateUpstreams, IsCommentOrEmpty)
|
|
||||||
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
|
req.BootstrapDNS = stringutil.FilterOut(req.BootstrapDNS, IsCommentOrEmpty)
|
||||||
|
|
||||||
opts := &upstream.Options{
|
opts := &upstream.Options{
|
||||||
|
|
|
@ -83,7 +83,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||||
ConfigModified: func() {},
|
ConfigModified: func() {},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
s := createTestServer(t, filterConf, forwardConf)
|
||||||
s.sysResolvers = &emptySysResolvers{}
|
s.sysResolvers = &emptySysResolvers{}
|
||||||
|
|
||||||
require.NoError(t, s.Start())
|
require.NoError(t, s.Start())
|
||||||
|
@ -164,7 +164,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||||
ConfigModified: func() {},
|
ConfigModified: func() {},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}
|
}
|
||||||
s := createTestServer(t, filterConf, forwardConf, nil)
|
s := createTestServer(t, filterConf, forwardConf)
|
||||||
s.sysResolvers = &emptySysResolvers{}
|
s.sysResolvers = &emptySysResolvers{}
|
||||||
|
|
||||||
defaultConf := s.conf
|
defaultConf := s.conf
|
||||||
|
@ -223,8 +223,9 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||||
wantSet: "",
|
wantSet: "",
|
||||||
}, {
|
}, {
|
||||||
name: "upstream_dns_bad",
|
name: "upstream_dns_bad",
|
||||||
wantSet: `validating dns config: ` +
|
wantSet: `validating dns config: upstream servers: parsing error at index 0: ` +
|
||||||
`upstream servers: validating upstream "!!!": not an ip:port`,
|
`cannot prepare the upstream: invalid address !!!: bad hostname "!!!": ` +
|
||||||
|
`bad top-level domain name label "!!!": bad top-level domain name label rune '!'`,
|
||||||
}, {
|
}, {
|
||||||
name: "bootstraps_bad",
|
name: "bootstraps_bad",
|
||||||
wantSet: `validating dns config: checking bootstrap a: not a bootstrap: ParseAddr("a"): ` +
|
wantSet: `validating dns config: checking bootstrap a: not a bootstrap: ParseAddr("a"): ` +
|
||||||
|
@ -313,98 +314,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateUpstreams(t *testing.T) {
|
|
||||||
const sdnsStamp = `sdns://AQMAAAAAAAAAFDE3Ni4xMDMuMTMwLjEzMDo1NDQzINErR_J` +
|
|
||||||
`S3PLCu_iZEIbq95zkSV2LFsigxDIuUso_OQhzIjIuZG5zY3J5cHQuZGVmYXVsdC5uczE` +
|
|
||||||
`uYWRndWFyZC5jb20`
|
|
||||||
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
wantErr string
|
|
||||||
set []string
|
|
||||||
}{{
|
|
||||||
name: "empty",
|
|
||||||
wantErr: ``,
|
|
||||||
set: nil,
|
|
||||||
}, {
|
|
||||||
name: "comment",
|
|
||||||
wantErr: ``,
|
|
||||||
set: []string{"# comment"},
|
|
||||||
}, {
|
|
||||||
name: "no_default",
|
|
||||||
wantErr: `no default upstreams specified`,
|
|
||||||
set: []string{
|
|
||||||
"[/host.com/]1.1.1.1",
|
|
||||||
"[//]tls://1.1.1.1",
|
|
||||||
"[/www.host.com/]#",
|
|
||||||
"[/host.com/google.com/]8.8.8.8",
|
|
||||||
"[/host/]" + sdnsStamp,
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
name: "with_default",
|
|
||||||
wantErr: ``,
|
|
||||||
set: []string{
|
|
||||||
"[/host.com/]1.1.1.1",
|
|
||||||
"[//]tls://1.1.1.1",
|
|
||||||
"[/www.host.com/]#",
|
|
||||||
"[/host.com/google.com/]8.8.8.8",
|
|
||||||
"[/host/]" + sdnsStamp,
|
|
||||||
"8.8.8.8",
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
name: "invalid",
|
|
||||||
wantErr: `validating upstream "dhcp://fake.dns": bad protocol "dhcp"`,
|
|
||||||
set: []string{"dhcp://fake.dns"},
|
|
||||||
}, {
|
|
||||||
name: "invalid",
|
|
||||||
wantErr: `validating upstream "1.2.3.4.5": not an ip:port`,
|
|
||||||
set: []string{"1.2.3.4.5"},
|
|
||||||
}, {
|
|
||||||
name: "invalid",
|
|
||||||
wantErr: `validating upstream "123.3.7m": not an ip:port`,
|
|
||||||
set: []string{"123.3.7m"},
|
|
||||||
}, {
|
|
||||||
name: "invalid",
|
|
||||||
wantErr: `splitting upstream line "[/host.com]tls://dns.adguard.com": ` +
|
|
||||||
`missing separator`,
|
|
||||||
set: []string{"[/host.com]tls://dns.adguard.com"},
|
|
||||||
}, {
|
|
||||||
name: "invalid",
|
|
||||||
wantErr: `validating upstream "[host.ru]#": not an ip:port`,
|
|
||||||
set: []string{"[host.ru]#"},
|
|
||||||
}, {
|
|
||||||
name: "valid_default",
|
|
||||||
wantErr: ``,
|
|
||||||
set: []string{
|
|
||||||
"1.1.1.1",
|
|
||||||
"tls://1.1.1.1",
|
|
||||||
"https://dns.adguard.com/dns-query",
|
|
||||||
sdnsStamp,
|
|
||||||
"udp://dns.google",
|
|
||||||
"udp://8.8.8.8",
|
|
||||||
"[/host.com/]1.1.1.1",
|
|
||||||
"[//]tls://1.1.1.1",
|
|
||||||
"[/www.host.com/]#",
|
|
||||||
"[/host.com/google.com/]8.8.8.8",
|
|
||||||
"[/host/]" + sdnsStamp,
|
|
||||||
"[/пример.рф/]8.8.8.8",
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
name: "bad_domain",
|
|
||||||
wantErr: `splitting upstream line "[/!/]8.8.8.8": domain at index 0: ` +
|
|
||||||
`bad domain name "!": bad top-level domain name label "!": ` +
|
|
||||||
`bad top-level domain name label rune '!'`,
|
|
||||||
set: []string{"[/!/]8.8.8.8"},
|
|
||||||
}}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
err := ValidateUpstreams(tc.set)
|
|
||||||
testutil.AssertErrorMsg(t, tc.wantErr, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateUpstreamsPrivate(t *testing.T) {
|
func TestValidateUpstreamsPrivate(t *testing.T) {
|
||||||
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
|
||||||
|
|
||||||
|
@ -509,6 +418,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
&aghtest.FSWatcher{
|
&aghtest.FSWatcher{
|
||||||
|
OnStart: func() (_ error) { panic("not implemented") },
|
||||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||||
OnAdd: func(_ string) (err error) { return nil },
|
OnAdd: func(_ string) (err error) { return nil },
|
||||||
OnClose: func() (err error) { return nil },
|
OnClose: func() (err error) { return nil },
|
||||||
|
@ -529,7 +439,7 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}, nil)
|
})
|
||||||
srv.etcHosts = upstream.NewHostsResolver(hc)
|
srv.etcHosts = upstream.NewHostsResolver(hc)
|
||||||
startDeferStop(t, srv)
|
startDeferStop(t, srv)
|
||||||
|
|
||||||
|
|
|
@ -2,13 +2,13 @@ package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// makeResponse creates a DNS response by req and sets necessary flags. It also
|
// makeResponse creates a DNS response by req and sets necessary flags. It also
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/testutil"
|
"github.com/AdguardTeam/golibs/testutil"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
|
@ -87,7 +86,7 @@ func TestServer_ProcessInitial(t *testing.T) {
|
||||||
|
|
||||||
s := createTestServer(t, &filtering.Config{
|
s := createTestServer(t, &filtering.Config{
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
}, c, nil)
|
}, c)
|
||||||
|
|
||||||
var gotAddr netip.Addr
|
var gotAddr netip.Addr
|
||||||
s.addrProc = &aghtest.AddressProcessor{
|
s.addrProc = &aghtest.AddressProcessor{
|
||||||
|
@ -188,7 +187,7 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
||||||
|
|
||||||
s := createTestServer(t, &filtering.Config{
|
s := createTestServer(t, &filtering.Config{
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
}, c, nil)
|
}, c)
|
||||||
|
|
||||||
resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns)
|
resp := newResp(dns.RcodeSuccess, tc.req, tc.respAns)
|
||||||
dctx := &dnsContext{
|
dctx := &dnsContext{
|
||||||
|
@ -248,9 +247,9 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||||
host string
|
host string
|
||||||
want []*dns.SVCB
|
want []*dns.SVCB
|
||||||
wantRes resultCode
|
wantRes resultCode
|
||||||
portDoH int
|
addrsDoH []*net.TCPAddr
|
||||||
portDoT int
|
addrsDoT []*net.TCPAddr
|
||||||
portDoQ int
|
addrsDoQ []*net.UDPAddr
|
||||||
qtype uint16
|
qtype uint16
|
||||||
ddrEnabled bool
|
ddrEnabled bool
|
||||||
}{{
|
}{{
|
||||||
|
@ -259,14 +258,14 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||||
host: testQuestionTarget,
|
host: testQuestionTarget,
|
||||||
qtype: dns.TypeSVCB,
|
qtype: dns.TypeSVCB,
|
||||||
ddrEnabled: true,
|
ddrEnabled: true,
|
||||||
portDoH: 8043,
|
addrsDoH: []*net.TCPAddr{{Port: 8043}},
|
||||||
}, {
|
}, {
|
||||||
name: "pass_qtype",
|
name: "pass_qtype",
|
||||||
wantRes: resultCodeFinish,
|
wantRes: resultCodeFinish,
|
||||||
host: ddrHostFQDN,
|
host: ddrHostFQDN,
|
||||||
qtype: dns.TypeA,
|
qtype: dns.TypeA,
|
||||||
ddrEnabled: true,
|
ddrEnabled: true,
|
||||||
portDoH: 8043,
|
addrsDoH: []*net.TCPAddr{{Port: 8043}},
|
||||||
}, {
|
}, {
|
||||||
name: "pass_disabled_tls",
|
name: "pass_disabled_tls",
|
||||||
wantRes: resultCodeFinish,
|
wantRes: resultCodeFinish,
|
||||||
|
@ -279,7 +278,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||||
host: ddrHostFQDN,
|
host: ddrHostFQDN,
|
||||||
qtype: dns.TypeSVCB,
|
qtype: dns.TypeSVCB,
|
||||||
ddrEnabled: false,
|
ddrEnabled: false,
|
||||||
portDoH: 8043,
|
addrsDoH: []*net.TCPAddr{{Port: 8043}},
|
||||||
}, {
|
}, {
|
||||||
name: "dot",
|
name: "dot",
|
||||||
wantRes: resultCodeFinish,
|
wantRes: resultCodeFinish,
|
||||||
|
@ -287,7 +286,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||||
host: ddrHostFQDN,
|
host: ddrHostFQDN,
|
||||||
qtype: dns.TypeSVCB,
|
qtype: dns.TypeSVCB,
|
||||||
ddrEnabled: true,
|
ddrEnabled: true,
|
||||||
portDoT: 8043,
|
addrsDoT: []*net.TCPAddr{{Port: 8043}},
|
||||||
}, {
|
}, {
|
||||||
name: "doh",
|
name: "doh",
|
||||||
wantRes: resultCodeFinish,
|
wantRes: resultCodeFinish,
|
||||||
|
@ -295,7 +294,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||||
host: ddrHostFQDN,
|
host: ddrHostFQDN,
|
||||||
qtype: dns.TypeSVCB,
|
qtype: dns.TypeSVCB,
|
||||||
ddrEnabled: true,
|
ddrEnabled: true,
|
||||||
portDoH: 8044,
|
addrsDoH: []*net.TCPAddr{{Port: 8044}},
|
||||||
}, {
|
}, {
|
||||||
name: "doq",
|
name: "doq",
|
||||||
wantRes: resultCodeFinish,
|
wantRes: resultCodeFinish,
|
||||||
|
@ -303,7 +302,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||||
host: ddrHostFQDN,
|
host: ddrHostFQDN,
|
||||||
qtype: dns.TypeSVCB,
|
qtype: dns.TypeSVCB,
|
||||||
ddrEnabled: true,
|
ddrEnabled: true,
|
||||||
portDoQ: 8042,
|
addrsDoQ: []*net.UDPAddr{{Port: 8042}},
|
||||||
}, {
|
}, {
|
||||||
name: "dot_doh",
|
name: "dot_doh",
|
||||||
wantRes: resultCodeFinish,
|
wantRes: resultCodeFinish,
|
||||||
|
@ -311,13 +310,35 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||||
host: ddrHostFQDN,
|
host: ddrHostFQDN,
|
||||||
qtype: dns.TypeSVCB,
|
qtype: dns.TypeSVCB,
|
||||||
ddrEnabled: true,
|
ddrEnabled: true,
|
||||||
portDoT: 8043,
|
addrsDoT: []*net.TCPAddr{{Port: 8043}},
|
||||||
portDoH: 8044,
|
addrsDoH: []*net.TCPAddr{{Port: 8044}},
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
_, certPem, keyPem := createServerTLSConfig(t)
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled)
|
s := createTestServer(t, &filtering.Config{
|
||||||
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
|
}, ServerConfig{
|
||||||
|
Config: Config{
|
||||||
|
HandleDDR: tc.ddrEnabled,
|
||||||
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
|
},
|
||||||
|
TLSConfig: TLSConfig{
|
||||||
|
ServerName: ddrTestDomainName,
|
||||||
|
CertificateChainData: certPem,
|
||||||
|
PrivateKeyData: keyPem,
|
||||||
|
TLSListenAddrs: tc.addrsDoT,
|
||||||
|
HTTPSListenAddrs: tc.addrsDoH,
|
||||||
|
QUICListenAddrs: tc.addrsDoQ,
|
||||||
|
},
|
||||||
|
ServePlainDNS: true,
|
||||||
|
})
|
||||||
|
// TODO(e.burkov): Generate a certificate actually containing the
|
||||||
|
// IP addresses.
|
||||||
|
s.conf.hasIPAddrs = true
|
||||||
|
|
||||||
req := createTestMessageWithType(tc.host, tc.qtype)
|
req := createTestMessageWithType(tc.host, tc.qtype)
|
||||||
|
|
||||||
|
@ -358,41 +379,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
s = &Server{
|
|
||||||
dnsFilter: createTestDNSFilter(t),
|
|
||||||
dnsProxy: &proxy.Proxy{
|
|
||||||
Config: proxy.Config{},
|
|
||||||
},
|
|
||||||
conf: ServerConfig{
|
|
||||||
Config: Config{
|
|
||||||
HandleDDR: ddrEnabled,
|
|
||||||
},
|
|
||||||
TLSConfig: TLSConfig{
|
|
||||||
ServerName: ddrTestDomainName,
|
|
||||||
},
|
|
||||||
ServePlainDNS: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
if portDoT > 0 {
|
|
||||||
s.dnsProxy.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}}
|
|
||||||
s.conf.hasIPAddrs = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if portDoQ > 0 {
|
|
||||||
s.dnsProxy.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}}
|
|
||||||
}
|
|
||||||
|
|
||||||
if portDoH > 0 {
|
|
||||||
s.conf.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
func TestServer_ProcessDetermineLocal(t *testing.T) {
|
||||||
s := &Server{
|
s := &Server{
|
||||||
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
|
||||||
|
@ -680,13 +666,16 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||||
intPTRAnswer = "some.local-client."
|
intPTRAnswer = "some.local-client."
|
||||||
)
|
)
|
||||||
|
|
||||||
ups := aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||||
return aghalg.Coalesce(
|
resp := aghalg.Coalesce(
|
||||||
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
|
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
|
||||||
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
|
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
|
||||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||||
), nil
|
)
|
||||||
|
|
||||||
|
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||||
})
|
})
|
||||||
|
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||||
|
|
||||||
s := createTestServer(t, &filtering.Config{
|
s := createTestServer(t, &filtering.Config{
|
||||||
BlockingMode: filtering.BlockingModeDefault,
|
BlockingMode: filtering.BlockingModeDefault,
|
||||||
|
@ -696,12 +685,14 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
|
||||||
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
|
// TODO(s.chzhen): Add tests where EDNSClientSubnet.Enabled is true.
|
||||||
// Improve Config declaration for tests.
|
// Improve Config declaration for tests.
|
||||||
Config: Config{
|
Config: Config{
|
||||||
|
UpstreamDNS: []string{localUpsAddr},
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
},
|
},
|
||||||
|
UsePrivateRDNS: true,
|
||||||
|
LocalPTRResolvers: []string{localUpsAddr},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}, ups)
|
})
|
||||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ups}
|
|
||||||
startDeferStop(t, s)
|
startDeferStop(t, s)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
|
@ -764,6 +755,16 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||||
const locDomain = "some.local."
|
const locDomain = "some.local."
|
||||||
const reqAddr = "1.1.168.192.in-addr.arpa."
|
const reqAddr = "1.1.168.192.in-addr.arpa."
|
||||||
|
|
||||||
|
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
|
||||||
|
resp := aghalg.Coalesce(
|
||||||
|
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
|
||||||
|
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
|
||||||
|
})
|
||||||
|
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
|
||||||
|
|
||||||
s := createTestServer(
|
s := createTestServer(
|
||||||
t,
|
t,
|
||||||
&filtering.Config{
|
&filtering.Config{
|
||||||
|
@ -776,14 +777,10 @@ func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) {
|
||||||
UpstreamMode: UpstreamModeLoadBalance,
|
UpstreamMode: UpstreamModeLoadBalance,
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
},
|
},
|
||||||
|
UsePrivateRDNS: true,
|
||||||
|
LocalPTRResolvers: []string{localUpsAddr},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
},
|
},
|
||||||
aghtest.NewUpstreamMock(func(req *dns.Msg) (resp *dns.Msg, err error) {
|
|
||||||
return aghalg.Coalesce(
|
|
||||||
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
|
|
||||||
new(dns.Msg).SetRcode(req, dns.RcodeNameError),
|
|
||||||
), nil
|
|
||||||
}),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var proxyCtx *proxy.DNSContext
|
var proxyCtx *proxy.DNSContext
|
||||||
|
|
|
@ -21,7 +21,7 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||||
},
|
},
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
}, nil)
|
})
|
||||||
|
|
||||||
req := &dns.Msg{
|
req := &dns.Msg{
|
||||||
Question: []dns.Question{{
|
Question: []dns.Question{{
|
||||||
|
|
|
@ -2,10 +2,9 @@ package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
@ -16,29 +15,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// errNotDomainSpecific is returned when the upstream should be
|
|
||||||
// domain-specific, but isn't.
|
|
||||||
errNotDomainSpecific errors.Error = "not a domain-specific upstream"
|
|
||||||
|
|
||||||
// errMissingSeparator is returned when the domain-specific part of the
|
|
||||||
// upstream configuration line isn't closed.
|
|
||||||
errMissingSeparator errors.Error = "missing separator"
|
|
||||||
|
|
||||||
// errDupSeparator is returned when the domain-specific part of the upstream
|
|
||||||
// configuration line contains more than one ending separator.
|
|
||||||
errDupSeparator errors.Error = "duplicated separator"
|
|
||||||
|
|
||||||
// errNoDefaultUpstreams is returned when there are no default upstreams
|
|
||||||
// specified in the upstream configuration.
|
|
||||||
errNoDefaultUpstreams errors.Error = "no default upstreams specified"
|
|
||||||
|
|
||||||
// errWrongResponse is returned when the checked upstream replies in an
|
|
||||||
// unexpected way.
|
|
||||||
errWrongResponse errors.Error = "wrong response"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// loadUpstreams parses upstream DNS servers from the configured file or from
|
// loadUpstreams parses upstream DNS servers from the configured file or from
|
||||||
|
@ -199,84 +175,12 @@ func IsCommentOrEmpty(s string) (ok bool) {
|
||||||
return len(s) == 0 || s[0] == '#'
|
return len(s) == 0 || s[0] == '#'
|
||||||
}
|
}
|
||||||
|
|
||||||
// newUpstreamConfig validates upstreams and returns an appropriate upstream
|
|
||||||
// configuration or nil if it can't be built.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Perhaps proxy.ParseUpstreamsConfig should validate upstreams
|
|
||||||
// slice already so that this function may be considered useless.
|
|
||||||
func newUpstreamConfig(upstreams []string) (conf *proxy.UpstreamConfig, err error) {
|
|
||||||
// No need to validate comments and empty lines.
|
|
||||||
upstreams = stringutil.FilterOut(upstreams, IsCommentOrEmpty)
|
|
||||||
if len(upstreams) == 0 {
|
|
||||||
// Consider this case valid since it means the default server should be
|
|
||||||
// used.
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err = validateUpstreamConfig(upstreams)
|
|
||||||
if err != nil {
|
|
||||||
// Don't wrap the error since it's informative enough as is.
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conf, err = proxy.ParseUpstreamsConfig(
|
|
||||||
upstreams,
|
|
||||||
&upstream.Options{
|
|
||||||
Bootstrap: net.DefaultResolver,
|
|
||||||
Timeout: DefaultTimeout,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
// Don't wrap the error since it's informative enough as is.
|
|
||||||
return nil, err
|
|
||||||
} else if len(conf.Upstreams) == 0 {
|
|
||||||
return nil, errNoDefaultUpstreams
|
|
||||||
}
|
|
||||||
|
|
||||||
return conf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateUpstreamConfig validates each upstream from the upstream
|
|
||||||
// configuration and returns an error if any upstream is invalid.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
|
|
||||||
func validateUpstreamConfig(conf []string) (err error) {
|
|
||||||
for _, u := range conf {
|
|
||||||
var ups []string
|
|
||||||
var isSpecific bool
|
|
||||||
ups, isSpecific, err = splitUpstreamLine(u)
|
|
||||||
if err != nil {
|
|
||||||
// Don't wrap the error since it's informative enough as is.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, addr := range ups {
|
|
||||||
_, err = validateUpstream(addr, isSpecific)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("validating upstream %q: %w", addr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateUpstreams validates each upstream and returns an error if any
|
|
||||||
// upstream is invalid or if there are no default upstreams specified.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Merge with [upstreamConfigValidator] somehow.
|
|
||||||
func ValidateUpstreams(upstreams []string) (err error) {
|
|
||||||
_, err = newUpstreamConfig(upstreams)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
|
||||||
// upstream is invalid or if there are no default upstreams specified. It also
|
// upstream is invalid or if there are no default upstreams specified. It also
|
||||||
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
// checks each domain of domain-specific upstreams for being ARPA pointing to
|
||||||
// a locally-served network. privateNets must not be nil.
|
// a locally-served network. privateNets must not be nil.
|
||||||
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
|
||||||
conf, err := newUpstreamConfig(upstreams)
|
conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("creating config: %w", err)
|
return fmt.Errorf("creating config: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -308,66 +212,3 @@ func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet)
|
||||||
|
|
||||||
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
|
||||||
}
|
}
|
||||||
|
|
||||||
// protocols are the supported URL schemes for upstreams.
|
|
||||||
var protocols = []string{"h3", "https", "quic", "sdns", "tcp", "tls", "udp"}
|
|
||||||
|
|
||||||
// validateUpstream returns an error if u alongside with domains is not a valid
|
|
||||||
// upstream configuration. useDefault is true if the upstream is
|
|
||||||
// domain-specific and is configured to point at the default upstream server
|
|
||||||
// which is validated separately. The upstream is considered domain-specific
|
|
||||||
// only if domains is at least not nil.
|
|
||||||
func validateUpstream(u string, isSpecific bool) (useDefault bool, err error) {
|
|
||||||
// The special server address '#' means that default server must be used.
|
|
||||||
if useDefault = u == "#" && isSpecific; useDefault {
|
|
||||||
return useDefault, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the upstream has a valid protocol prefix.
|
|
||||||
//
|
|
||||||
// TODO(e.burkov): Validate the domain name.
|
|
||||||
if proto, _, ok := strings.Cut(u, "://"); ok {
|
|
||||||
if !slices.Contains(protocols, proto) {
|
|
||||||
return false, fmt.Errorf("bad protocol %q", proto)
|
|
||||||
}
|
|
||||||
} else if _, err = netip.ParseAddr(u); err == nil {
|
|
||||||
return false, nil
|
|
||||||
} else if _, err = netip.ParseAddrPort(u); err == nil {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// splitUpstreamLine returns the upstreams and the specified domains. domains
|
|
||||||
// is nil when the upstream is not domains-specific. Otherwise it may also be
|
|
||||||
// empty.
|
|
||||||
func splitUpstreamLine(upstreamStr string) (upstreams []string, isSpecific bool, err error) {
|
|
||||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
|
||||||
return []string{upstreamStr}, false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() { err = errors.Annotate(err, "splitting upstream line %q: %w", upstreamStr) }()
|
|
||||||
|
|
||||||
doms, ups, found := strings.Cut(upstreamStr[2:], "/]")
|
|
||||||
if !found {
|
|
||||||
return nil, false, errMissingSeparator
|
|
||||||
} else if strings.Contains(ups, "/]") {
|
|
||||||
return nil, false, errDupSeparator
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, host := range strings.Split(doms, "/") {
|
|
||||||
if host == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = netutil.ValidateDomainName(strings.TrimPrefix(host, "*."))
|
|
||||||
if err != nil {
|
|
||||||
return nil, false, fmt.Errorf("domain at index %d: %w", i, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
isSpecific = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.Fields(ups), isSpecific, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -100,8 +100,7 @@ func TestUpstreamConfigValidator(t *testing.T) {
|
||||||
name: "bad_specification",
|
name: "bad_specification",
|
||||||
general: []string{"[/domain.example/]/]1.2.3.4"},
|
general: []string{"[/domain.example/]/]1.2.3.4"},
|
||||||
want: map[string]string{
|
want: map[string]string{
|
||||||
"[/domain.example/]/]1.2.3.4": `splitting upstream line ` +
|
"[/domain.example/]/]1.2.3.4": generalTextLabel + " 1: parsing error",
|
||||||
`"[/domain.example/]/]1.2.3.4": duplicated separator`,
|
|
||||||
},
|
},
|
||||||
}, {
|
}, {
|
||||||
name: "all_different",
|
name: "all_different",
|
||||||
|
@ -120,23 +119,9 @@ func TestUpstreamConfigValidator(t *testing.T) {
|
||||||
fallback: []string{"[/example/" + goodUps},
|
fallback: []string{"[/example/" + goodUps},
|
||||||
private: []string{"[/example//bad.123/]" + goodUps},
|
private: []string{"[/example//bad.123/]" + goodUps},
|
||||||
want: map[string]string{
|
want: map[string]string{
|
||||||
`[/example/]/]` + goodUps: `splitting upstream line ` +
|
"[/example/]/]" + goodUps: generalTextLabel + " 1: parsing error",
|
||||||
`"[/example/]/]` + goodUps + `": duplicated separator`,
|
"[/example/" + goodUps: fallbackTextLabel + " 1: parsing error",
|
||||||
`[/example/` + goodUps: `splitting upstream line ` +
|
"[/example//bad.123/]" + goodUps: privateTextLabel + " 1: parsing error",
|
||||||
`"[/example/` + goodUps + `": missing separator`,
|
|
||||||
`[/example//bad.123/]` + goodUps: `splitting upstream line ` +
|
|
||||||
`"[/example//bad.123/]` + goodUps + `": domain at index 2: ` +
|
|
||||||
`bad domain name "bad.123": ` +
|
|
||||||
`bad top-level domain name label "123": all octets are numeric`,
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
name: "non-specific_default",
|
|
||||||
general: []string{
|
|
||||||
"#",
|
|
||||||
"[/example/]#",
|
|
||||||
},
|
|
||||||
want: map[string]string{
|
|
||||||
"#": "not a domain-specific upstream",
|
|
||||||
},
|
},
|
||||||
}, {
|
}, {
|
||||||
name: "bad_proto",
|
name: "bad_proto",
|
||||||
|
@ -144,7 +129,15 @@ func TestUpstreamConfigValidator(t *testing.T) {
|
||||||
"bad://1.2.3.4",
|
"bad://1.2.3.4",
|
||||||
},
|
},
|
||||||
want: map[string]string{
|
want: map[string]string{
|
||||||
"bad://1.2.3.4": `bad protocol "bad"`,
|
"bad://1.2.3.4": generalTextLabel + " 1: parsing error",
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "truncated_line",
|
||||||
|
general: []string{
|
||||||
|
"This is a very long line. It will cause a parsing error and will be truncated here.",
|
||||||
|
},
|
||||||
|
want: map[string]string{
|
||||||
|
"This is a very long line. It will cause a parsing error and will be truncated …": "upstream_dns 1: parsing error",
|
||||||
},
|
},
|
||||||
}}
|
}}
|
||||||
|
|
||||||
|
|
|
@ -4,13 +4,13 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// serviceRules maps a service ID to its filtering rules.
|
// serviceRules maps a service ID to its filtering rules.
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -15,7 +16,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// filterDir is the subdirectory of a data directory to store downloaded
|
// filterDir is the subdirectory of a data directory to store downloaded
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
@ -29,7 +30,6 @@ import (
|
||||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// The IDs of built-in filter lists.
|
// The IDs of built-in filter lists.
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -14,7 +15,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
"golang.org/x/net/publicsuffix"
|
"golang.org/x/net/publicsuffix"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package hashprefix
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -12,7 +13,6 @@ import (
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
|
@ -40,6 +40,7 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
watcher := &aghtest.FSWatcher{
|
watcher := &aghtest.FSWatcher{
|
||||||
|
OnStart: func() (_ error) { panic("not implemented") },
|
||||||
OnEvents: func() (e <-chan struct{}) { return nil },
|
OnEvents: func() (e <-chan struct{}) { return nil },
|
||||||
OnAdd: func(name string) (err error) { return nil },
|
OnAdd: func(name string) (err error) { return nil },
|
||||||
OnClose: func() (err error) { return nil },
|
OnClose: func() (err error) { return nil },
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -15,7 +16,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// validateFilterURL validates the filter list URL or file name.
|
// validateFilterURL validates the filter list URL or file name.
|
||||||
|
|
|
@ -3,6 +3,7 @@ package rewrite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
@ -12,7 +13,6 @@ import (
|
||||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Storage is a storage for rewrite rules.
|
// Storage is a storage for rewrite rules.
|
||||||
|
|
|
@ -3,10 +3,10 @@ package filtering
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO(d.kolyshev): Use [rewrite.Item] instead.
|
// TODO(d.kolyshev): Use [rewrite.Item] instead.
|
||||||
|
|
|
@ -3,12 +3,12 @@ package filtering
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Legacy DNS rewrites
|
// Legacy DNS rewrites
|
||||||
|
|
|
@ -6,9 +6,9 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/crc32"
|
"hash/crc32"
|
||||||
"io"
|
"io"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Parser is a filtering-rule parser that collects data, such as the checksum
|
// Parser is a filtering-rule parser that collects data, such as the checksum
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -23,7 +24,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
"github.com/AdguardTeam/golibs/stringutil"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
|
// DHCP is an interface for accessing DHCP lease data the [clientsContainer]
|
||||||
|
@ -47,8 +47,9 @@ type DHCP interface {
|
||||||
type clientsContainer struct {
|
type clientsContainer struct {
|
||||||
// TODO(a.garipov): Perhaps use a number of separate indices for different
|
// TODO(a.garipov): Perhaps use a number of separate indices for different
|
||||||
// types (string, netip.Addr, and so on).
|
// types (string, netip.Addr, and so on).
|
||||||
list map[string]*persistentClient // name -> client
|
list map[string]*client.Persistent // name -> client
|
||||||
idIndex map[string]*persistentClient // ID -> client
|
|
||||||
|
clientIndex *client.Index
|
||||||
|
|
||||||
// ipToRC maps IP addresses to runtime client information.
|
// ipToRC maps IP addresses to runtime client information.
|
||||||
ipToRC map[netip.Addr]*client.Runtime
|
ipToRC map[netip.Addr]*client.Runtime
|
||||||
|
@ -102,10 +103,11 @@ func (clients *clientsContainer) Init(
|
||||||
log.Fatal("clients.list != nil")
|
log.Fatal("clients.list != nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
clients.list = map[string]*persistentClient{}
|
clients.list = map[string]*client.Persistent{}
|
||||||
clients.idIndex = map[string]*persistentClient{}
|
|
||||||
clients.ipToRC = map[netip.Addr]*client.Runtime{}
|
clients.ipToRC = map[netip.Addr]*client.Runtime{}
|
||||||
|
|
||||||
|
clients.clientIndex = client.NewIndex()
|
||||||
|
|
||||||
clients.allTags = stringutil.NewSet(clientTags...)
|
clients.allTags = stringutil.NewSet(clientTags...)
|
||||||
|
|
||||||
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
|
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
|
||||||
|
@ -140,8 +142,7 @@ func (clients *clientsContainer) Init(
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleHostsUpdates receives the updates from the hosts container and adds
|
// handleHostsUpdates receives the updates from the hosts container and adds
|
||||||
// them to the clients container. It's used to be called in a separate
|
// them to the clients container. It is intended to be used as a goroutine.
|
||||||
// goroutine.
|
|
||||||
func (clients *clientsContainer) handleHostsUpdates() {
|
func (clients *clientsContainer) handleHostsUpdates() {
|
||||||
for upd := range clients.etcHosts.Upd() {
|
for upd := range clients.etcHosts.Upd() {
|
||||||
clients.addFromHostsFile(upd)
|
clients.addFromHostsFile(upd)
|
||||||
|
@ -189,7 +190,7 @@ type clientObject struct {
|
||||||
Upstreams []string `yaml:"upstreams"`
|
Upstreams []string `yaml:"upstreams"`
|
||||||
|
|
||||||
// UID is the unique identifier of the persistent client.
|
// UID is the unique identifier of the persistent client.
|
||||||
UID UID `yaml:"uid"`
|
UID client.UID `yaml:"uid"`
|
||||||
|
|
||||||
// UpstreamsCacheSize is the DNS cache size (in bytes).
|
// UpstreamsCacheSize is the DNS cache size (in bytes).
|
||||||
//
|
//
|
||||||
|
@ -213,8 +214,8 @@ type clientObject struct {
|
||||||
func (o *clientObject) toPersistent(
|
func (o *clientObject) toPersistent(
|
||||||
filteringConf *filtering.Config,
|
filteringConf *filtering.Config,
|
||||||
allTags *stringutil.Set,
|
allTags *stringutil.Set,
|
||||||
) (cli *persistentClient, err error) {
|
) (cli *client.Persistent, err error) {
|
||||||
cli = &persistentClient{
|
cli = &client.Persistent{
|
||||||
Name: o.Name,
|
Name: o.Name,
|
||||||
|
|
||||||
Upstreams: o.Upstreams,
|
Upstreams: o.Upstreams,
|
||||||
|
@ -224,7 +225,7 @@ func (o *clientObject) toPersistent(
|
||||||
UseOwnSettings: !o.UseGlobalSettings,
|
UseOwnSettings: !o.UseGlobalSettings,
|
||||||
FilteringEnabled: o.FilteringEnabled,
|
FilteringEnabled: o.FilteringEnabled,
|
||||||
ParentalEnabled: o.ParentalEnabled,
|
ParentalEnabled: o.ParentalEnabled,
|
||||||
safeSearchConf: o.SafeSearchConf,
|
SafeSearchConf: o.SafeSearchConf,
|
||||||
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
|
SafeBrowsingEnabled: o.SafeBrowsingEnabled,
|
||||||
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
|
UseOwnBlockedServices: !o.UseGlobalBlockedServices,
|
||||||
IgnoreQueryLog: o.IgnoreQueryLog,
|
IgnoreQueryLog: o.IgnoreQueryLog,
|
||||||
|
@ -233,13 +234,13 @@ func (o *clientObject) toPersistent(
|
||||||
UpstreamsCacheSize: o.UpstreamsCacheSize,
|
UpstreamsCacheSize: o.UpstreamsCacheSize,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = cli.setIDs(o.IDs)
|
err = cli.SetIDs(o.IDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parsing ids: %w", err)
|
return nil, fmt.Errorf("parsing ids: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cli.UID == UID{}) {
|
if (cli.UID == client.UID{}) {
|
||||||
cli.UID, err = NewUID()
|
cli.UID, err = client.NewUID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("generating uid: %w", err)
|
return nil, fmt.Errorf("generating uid: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -248,7 +249,7 @@ func (o *clientObject) toPersistent(
|
||||||
if o.SafeSearchConf.Enabled {
|
if o.SafeSearchConf.Enabled {
|
||||||
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
o.SafeSearchConf.CustomResolver = safeSearchResolver{}
|
||||||
|
|
||||||
err = cli.setSafeSearch(
|
err = cli.SetSafeSearch(
|
||||||
o.SafeSearchConf,
|
o.SafeSearchConf,
|
||||||
filteringConf.SafeSearchCacheSize,
|
filteringConf.SafeSearchCacheSize,
|
||||||
time.Minute*time.Duration(filteringConf.CacheTime),
|
time.Minute*time.Duration(filteringConf.CacheTime),
|
||||||
|
@ -265,7 +266,7 @@ func (o *clientObject) toPersistent(
|
||||||
|
|
||||||
cli.BlockedServices = o.BlockedServices.Clone()
|
cli.BlockedServices = o.BlockedServices.Clone()
|
||||||
|
|
||||||
cli.setTags(o.Tags, allTags)
|
cli.SetTags(o.Tags, allTags)
|
||||||
|
|
||||||
return cli, nil
|
return cli, nil
|
||||||
}
|
}
|
||||||
|
@ -277,7 +278,7 @@ func (clients *clientsContainer) addFromConfig(
|
||||||
filteringConf *filtering.Config,
|
filteringConf *filtering.Config,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
for i, o := range objects {
|
for i, o := range objects {
|
||||||
var cli *persistentClient
|
var cli *client.Persistent
|
||||||
cli, err = o.toPersistent(filteringConf, clients.allTags)
|
cli, err = o.toPersistent(filteringConf, clients.allTags)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
|
return fmt.Errorf("clients: init persistent client at index %d: %w", i, err)
|
||||||
|
@ -305,7 +306,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||||
|
|
||||||
BlockedServices: cli.BlockedServices.Clone(),
|
BlockedServices: cli.BlockedServices.Clone(),
|
||||||
|
|
||||||
IDs: cli.ids(),
|
IDs: cli.IDs(),
|
||||||
Tags: stringutil.CloneSlice(cli.Tags),
|
Tags: stringutil.CloneSlice(cli.Tags),
|
||||||
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
Upstreams: stringutil.CloneSlice(cli.Upstreams),
|
||||||
|
|
||||||
|
@ -314,7 +315,7 @@ func (clients *clientsContainer) forConfig() (objs []*clientObject) {
|
||||||
UseGlobalSettings: !cli.UseOwnSettings,
|
UseGlobalSettings: !cli.UseOwnSettings,
|
||||||
FilteringEnabled: cli.FilteringEnabled,
|
FilteringEnabled: cli.FilteringEnabled,
|
||||||
ParentalEnabled: cli.ParentalEnabled,
|
ParentalEnabled: cli.ParentalEnabled,
|
||||||
SafeSearchConf: cli.safeSearchConf,
|
SafeSearchConf: cli.SafeSearchConf,
|
||||||
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
|
SafeBrowsingEnabled: cli.SafeBrowsingEnabled,
|
||||||
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
|
UseGlobalBlockedServices: !cli.UseOwnBlockedServices,
|
||||||
IgnoreQueryLog: cli.IgnoreQueryLog,
|
IgnoreQueryLog: cli.IgnoreQueryLog,
|
||||||
|
@ -435,7 +436,7 @@ func (clients *clientsContainer) clientOrArtificial(
|
||||||
}
|
}
|
||||||
|
|
||||||
// find returns a shallow copy of the client if there is one found.
|
// find returns a shallow copy of the client if there is one found.
|
||||||
func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool) {
|
func (clients *clientsContainer) find(id string) (c *client.Persistent, ok bool) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
|
@ -444,7 +445,7 @@ func (clients *clientsContainer) find(id string) (c *persistentClient, ok bool)
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.shallowClone(), true
|
return c.ShallowClone(), true
|
||||||
}
|
}
|
||||||
|
|
||||||
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
|
// shouldCountClient is a wrapper around [clientsContainer.find] to make it a
|
||||||
|
@ -480,8 +481,8 @@ func (clients *clientsContainer) UpstreamConfigByID(
|
||||||
c, ok := clients.findLocked(id)
|
c, ok := clients.findLocked(id)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
} else if c.upstreamConfig != nil {
|
} else if c.UpstreamConfig != nil {
|
||||||
return c.upstreamConfig, nil
|
return c.UpstreamConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty)
|
upstreams := stringutil.FilterOut(c.Upstreams, dnsforward.IsCommentOrEmpty)
|
||||||
|
@ -510,15 +511,15 @@ func (clients *clientsContainer) UpstreamConfigByID(
|
||||||
int(c.UpstreamsCacheSize),
|
int(c.UpstreamsCacheSize),
|
||||||
config.DNS.EDNSClientSubnet.Enabled,
|
config.DNS.EDNSClientSubnet.Enabled,
|
||||||
)
|
)
|
||||||
c.upstreamConfig = conf
|
c.UpstreamConfig = conf
|
||||||
|
|
||||||
return conf, nil
|
return conf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// findLocked searches for a client by its ID. clients.lock is expected to be
|
// findLocked searches for a client by its ID. clients.lock is expected to be
|
||||||
// locked.
|
// locked.
|
||||||
func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) {
|
func (clients *clientsContainer) findLocked(id string) (c *client.Persistent, ok bool) {
|
||||||
c, ok = clients.idIndex[id]
|
c, ok = clients.clientIndex.Find(id)
|
||||||
if ok {
|
if ok {
|
||||||
return c, true
|
return c, true
|
||||||
}
|
}
|
||||||
|
@ -528,21 +529,13 @@ func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, c = range clients.list {
|
|
||||||
for _, subnet := range c.Subnets {
|
|
||||||
if subnet.Contains(ip) {
|
|
||||||
return c, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(e.burkov): Iterate through clients.list only once.
|
// TODO(e.burkov): Iterate through clients.list only once.
|
||||||
return clients.findDHCP(ip)
|
return clients.findDHCP(ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
// findDHCP searches for a client by its MAC, if the DHCP server is active and
|
// findDHCP searches for a client by its MAC, if the DHCP server is active and
|
||||||
// there is such client. clients.lock is expected to be locked.
|
// there is such client. clients.lock is expected to be locked.
|
||||||
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *persistentClient, ok bool) {
|
func (clients *clientsContainer) findDHCP(ip netip.Addr) (c *client.Persistent, ok bool) {
|
||||||
foundMAC := clients.dhcp.MACByIP(ip)
|
foundMAC := clients.dhcp.MACByIP(ip)
|
||||||
if foundMAC == nil {
|
if foundMAC == nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
|
@ -592,13 +585,13 @@ func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *client.Ru
|
||||||
}
|
}
|
||||||
|
|
||||||
// check validates the client. It also sorts the client tags.
|
// check validates the client. It also sorts the client tags.
|
||||||
func (clients *clientsContainer) check(c *persistentClient) (err error) {
|
func (clients *clientsContainer) check(c *client.Persistent) (err error) {
|
||||||
switch {
|
switch {
|
||||||
case c == nil:
|
case c == nil:
|
||||||
return errors.Error("client is nil")
|
return errors.Error("client is nil")
|
||||||
case c.Name == "":
|
case c.Name == "":
|
||||||
return errors.Error("invalid name")
|
return errors.Error("invalid name")
|
||||||
case c.idsLen() == 0:
|
case c.IDsLen() == 0:
|
||||||
return errors.Error("id required")
|
return errors.Error("id required")
|
||||||
default:
|
default:
|
||||||
// Go on.
|
// Go on.
|
||||||
|
@ -613,7 +606,7 @@ func (clients *clientsContainer) check(c *persistentClient) (err error) {
|
||||||
// TODO(s.chzhen): Move to the constructor.
|
// TODO(s.chzhen): Move to the constructor.
|
||||||
slices.Sort(c.Tags)
|
slices.Sort(c.Tags)
|
||||||
|
|
||||||
err = dnsforward.ValidateUpstreams(c.Upstreams)
|
_, err = proxy.ParseUpstreamsConfig(c.Upstreams, &upstream.Options{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid upstream servers: %w", err)
|
return fmt.Errorf("invalid upstream servers: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -623,7 +616,7 @@ func (clients *clientsContainer) check(c *persistentClient) (err error) {
|
||||||
|
|
||||||
// add adds a new client object. ok is false if such client already exists or
|
// add adds a new client object. ok is false if such client already exists or
|
||||||
// if an error occurred.
|
// if an error occurred.
|
||||||
func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
|
func (clients *clientsContainer) add(c *client.Persistent) (ok bool, err error) {
|
||||||
err = clients.check(c)
|
err = clients.check(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
|
@ -639,31 +632,26 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// check ID index
|
// check ID index
|
||||||
ids := c.ids()
|
err = clients.clientIndex.Clashes(c)
|
||||||
for _, id := range ids {
|
if err != nil {
|
||||||
var c2 *persistentClient
|
// Don't wrap the error since it's informative enough as is.
|
||||||
c2, ok = clients.idIndex[id]
|
return false, err
|
||||||
if ok {
|
|
||||||
return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
clients.addLocked(c)
|
clients.addLocked(c)
|
||||||
|
|
||||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list))
|
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.IDs(), len(clients.list))
|
||||||
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addLocked c to the indexes. clients.lock is expected to be locked.
|
// addLocked c to the indexes. clients.lock is expected to be locked.
|
||||||
func (clients *clientsContainer) addLocked(c *persistentClient) {
|
func (clients *clientsContainer) addLocked(c *client.Persistent) {
|
||||||
// update Name index
|
// update Name index
|
||||||
clients.list[c.Name] = c
|
clients.list[c.Name] = c
|
||||||
|
|
||||||
// update ID index
|
// update ID index
|
||||||
for _, id := range c.ids() {
|
clients.clientIndex.Add(c)
|
||||||
clients.idIndex[id] = c
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove removes a client. ok is false if there is no such client.
|
// remove removes a client. ok is false if there is no such client.
|
||||||
|
@ -671,7 +659,7 @@ func (clients *clientsContainer) remove(name string) (ok bool) {
|
||||||
clients.lock.Lock()
|
clients.lock.Lock()
|
||||||
defer clients.lock.Unlock()
|
defer clients.lock.Unlock()
|
||||||
|
|
||||||
var c *persistentClient
|
var c *client.Persistent
|
||||||
c, ok = clients.list[name]
|
c, ok = clients.list[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
|
@ -684,8 +672,8 @@ func (clients *clientsContainer) remove(name string) (ok bool) {
|
||||||
|
|
||||||
// removeLocked removes c from the indexes. clients.lock is expected to be
|
// removeLocked removes c from the indexes. clients.lock is expected to be
|
||||||
// locked.
|
// locked.
|
||||||
func (clients *clientsContainer) removeLocked(c *persistentClient) {
|
func (clients *clientsContainer) removeLocked(c *client.Persistent) {
|
||||||
if err := c.closeUpstreams(); err != nil {
|
if err := c.CloseUpstreams(); err != nil {
|
||||||
log.Error("client container: removing client %s: %s", c.Name, err)
|
log.Error("client container: removing client %s: %s", c.Name, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -693,13 +681,11 @@ func (clients *clientsContainer) removeLocked(c *persistentClient) {
|
||||||
delete(clients.list, c.Name)
|
delete(clients.list, c.Name)
|
||||||
|
|
||||||
// Update the ID index.
|
// Update the ID index.
|
||||||
for _, id := range c.ids() {
|
clients.clientIndex.Delete(c)
|
||||||
delete(clients.idIndex, id)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// update updates a client by its name.
|
// update updates a client by its name.
|
||||||
func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
|
func (clients *clientsContainer) update(prev, c *client.Persistent) (err error) {
|
||||||
err = clients.check(c)
|
err = clients.check(c)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error since it's informative enough as is.
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
@ -717,7 +703,7 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.equalIDs(prev) {
|
if c.EqualIDs(prev) {
|
||||||
clients.removeLocked(prev)
|
clients.removeLocked(prev)
|
||||||
clients.addLocked(c)
|
clients.addLocked(c)
|
||||||
|
|
||||||
|
@ -725,11 +711,10 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the ID index.
|
// Check the ID index.
|
||||||
for _, id := range c.ids() {
|
err = clients.clientIndex.Clashes(c)
|
||||||
existing, ok := clients.idIndex[id]
|
if err != nil {
|
||||||
if ok && existing != prev {
|
// Don't wrap the error since it's informative enough as is.
|
||||||
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
|
return err
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
clients.removeLocked(prev)
|
clients.removeLocked(prev)
|
||||||
|
@ -906,14 +891,14 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||||
// the persistent clients.
|
// the persistent clients.
|
||||||
func (clients *clientsContainer) close() (err error) {
|
func (clients *clientsContainer) close() (err error) {
|
||||||
persistent := maps.Values(clients.list)
|
persistent := maps.Values(clients.list)
|
||||||
slices.SortFunc(persistent, func(a, b *persistentClient) (res int) {
|
slices.SortFunc(persistent, func(a, b *client.Persistent) (res int) {
|
||||||
return strings.Compare(a.Name, b.Name)
|
return strings.Compare(a.Name, b.Name)
|
||||||
})
|
})
|
||||||
|
|
||||||
var errs []error
|
var errs []error
|
||||||
|
|
||||||
for _, cli := range persistent {
|
for _, cli := range persistent {
|
||||||
if err = cli.closeUpstreams(); err != nil {
|
if err = cli.CloseUpstreams(); err != nil {
|
||||||
errs = append(errs, err)
|
errs = append(errs, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,8 +66,9 @@ func TestClients(t *testing.T) {
|
||||||
cliIPv6 = netip.MustParseAddr("1:2:3::4")
|
cliIPv6 = netip.MustParseAddr("1:2:3::4")
|
||||||
)
|
)
|
||||||
|
|
||||||
c := &persistentClient{
|
c := &client.Persistent{
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
IPs: []netip.Addr{cli1IP, cliIPv6},
|
IPs: []netip.Addr{cli1IP, cliIPv6},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -76,8 +77,9 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
c = &persistentClient{
|
c = &client.Persistent{
|
||||||
Name: "client2",
|
Name: "client2",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
IPs: []netip.Addr{cli2IP},
|
IPs: []netip.Addr{cli2IP},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,8 +111,9 @@ func TestClients(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add_fail_name", func(t *testing.T) {
|
t.Run("add_fail_name", func(t *testing.T) {
|
||||||
ok, err := clients.add(&persistentClient{
|
ok, err := clients.add(&client.Persistent{
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
|
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -118,16 +121,18 @@ func TestClients(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add_fail_ip", func(t *testing.T) {
|
t.Run("add_fail_ip", func(t *testing.T) {
|
||||||
ok, err := clients.add(&persistentClient{
|
ok, err := clients.add(&client.Persistent{
|
||||||
Name: "client3",
|
Name: "client3",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
})
|
})
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.False(t, ok)
|
assert.False(t, ok)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("update_fail_ip", func(t *testing.T) {
|
t.Run("update_fail_ip", func(t *testing.T) {
|
||||||
err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{
|
err := clients.update(&client.Persistent{Name: "client1"}, &client.Persistent{
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
})
|
})
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
@ -143,8 +148,9 @@ func TestClients(t *testing.T) {
|
||||||
prev, ok := clients.list["client1"]
|
prev, ok := clients.list["client1"]
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
err := clients.update(prev, &persistentClient{
|
err := clients.update(prev, &client.Persistent{
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
IPs: []netip.Addr{cliNewIP},
|
IPs: []netip.Addr{cliNewIP},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -157,8 +163,9 @@ func TestClients(t *testing.T) {
|
||||||
prev, ok = clients.list["client1"]
|
prev, ok = clients.list["client1"]
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
err = clients.update(prev, &persistentClient{
|
err = clients.update(prev, &client.Persistent{
|
||||||
Name: "client1-renamed",
|
Name: "client1-renamed",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
IPs: []netip.Addr{cliNewIP},
|
IPs: []netip.Addr{cliNewIP},
|
||||||
UseOwnSettings: true,
|
UseOwnSettings: true,
|
||||||
})
|
})
|
||||||
|
@ -175,7 +182,7 @@ func TestClients(t *testing.T) {
|
||||||
|
|
||||||
assert.Nil(t, nilCli)
|
assert.Nil(t, nilCli)
|
||||||
|
|
||||||
require.Len(t, c.ids(), 1)
|
require.Len(t, c.IDs(), 1)
|
||||||
|
|
||||||
assert.Equal(t, cliNewIP, c.IPs[0])
|
assert.Equal(t, cliNewIP, c.IPs[0])
|
||||||
})
|
})
|
||||||
|
@ -258,8 +265,9 @@ func TestClientsWHOIS(t *testing.T) {
|
||||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("1.1.1.2")
|
ip := netip.MustParseAddr("1.1.1.2")
|
||||||
|
|
||||||
ok, err := clients.add(&persistentClient{
|
ok, err := clients.add(&client.Persistent{
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -280,8 +288,9 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("1.1.1.1")
|
ip := netip.MustParseAddr("1.1.1.1")
|
||||||
|
|
||||||
// Add a client.
|
// Add a client.
|
||||||
ok, err := clients.add(&persistentClient{
|
ok, err := clients.add(&client.Persistent{
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
||||||
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
|
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
|
||||||
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
||||||
|
@ -330,16 +339,18 @@ func TestClientsAddExisting(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Add a new client with the same IP as for a client with MAC.
|
// Add a new client with the same IP as for a client with MAC.
|
||||||
ok, err := clients.add(&persistentClient{
|
ok, err := clients.add(&client.Persistent{
|
||||||
Name: "client2",
|
Name: "client2",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
IPs: []netip.Addr{ip},
|
IPs: []netip.Addr{ip},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, ok)
|
assert.True(t, ok)
|
||||||
|
|
||||||
// Add a new client with the IP from the first client's IP range.
|
// Add a new client with the IP from the first client's IP range.
|
||||||
ok, err = clients.add(&persistentClient{
|
ok, err = clients.add(&client.Persistent{
|
||||||
Name: "client3",
|
Name: "client3",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -351,8 +362,9 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||||
clients := newClientsContainer(t)
|
clients := newClientsContainer(t)
|
||||||
|
|
||||||
// Add client with upstreams.
|
// Add client with upstreams.
|
||||||
ok, err := clients.add(&persistentClient{
|
ok, err := clients.add(&client.Persistent{
|
||||||
Name: "client1",
|
Name: "client1",
|
||||||
|
UID: client.MustNewUID(),
|
||||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
|
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
|
||||||
Upstreams: []string{
|
Upstreams: []string{
|
||||||
"1.1.1.1",
|
"1.1.1.1",
|
||||||
|
|
|
@ -131,9 +131,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||||
|
|
||||||
// initPrev initializes the persistent client with the default or previous
|
// initPrev initializes the persistent client with the default or previous
|
||||||
// client properties.
|
// client properties.
|
||||||
func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err error) {
|
func initPrev(cj clientJSON, prev *client.Persistent) (c *client.Persistent, err error) {
|
||||||
var (
|
var (
|
||||||
uid UID
|
uid client.UID
|
||||||
ignoreQueryLog bool
|
ignoreQueryLog bool
|
||||||
ignoreStatistics bool
|
ignoreStatistics bool
|
||||||
upsCacheEnabled bool
|
upsCacheEnabled bool
|
||||||
|
@ -166,14 +166,14 @@ func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err e
|
||||||
return nil, fmt.Errorf("invalid blocked services: %w", err)
|
return nil, fmt.Errorf("invalid blocked services: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (uid == UID{}) {
|
if (uid == client.UID{}) {
|
||||||
uid, err = NewUID()
|
uid, err = client.NewUID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("generating uid: %w", err)
|
return nil, fmt.Errorf("generating uid: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &persistentClient{
|
return &client.Persistent{
|
||||||
BlockedServices: svcs,
|
BlockedServices: svcs,
|
||||||
UID: uid,
|
UID: uid,
|
||||||
IgnoreQueryLog: ignoreQueryLog,
|
IgnoreQueryLog: ignoreQueryLog,
|
||||||
|
@ -187,21 +187,21 @@ func initPrev(cj clientJSON, prev *persistentClient) (c *persistentClient, err e
|
||||||
// errors.
|
// errors.
|
||||||
func (clients *clientsContainer) jsonToClient(
|
func (clients *clientsContainer) jsonToClient(
|
||||||
cj clientJSON,
|
cj clientJSON,
|
||||||
prev *persistentClient,
|
prev *client.Persistent,
|
||||||
) (c *persistentClient, err error) {
|
) (c *client.Persistent, err error) {
|
||||||
c, err = initPrev(cj, prev)
|
c, err = initPrev(cj, prev)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error since it's informative enough as is.
|
// Don't wrap the error since it's informative enough as is.
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.setIDs(cj.IDs)
|
err = c.SetIDs(cj.IDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Don't wrap the error since it's informative enough as is.
|
// Don't wrap the error since it's informative enough as is.
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.safeSearchConf = copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled)
|
c.SafeSearchConf = copySafeSearch(cj.SafeSearchConf, cj.SafeSearchEnabled)
|
||||||
c.Name = cj.Name
|
c.Name = cj.Name
|
||||||
c.Tags = cj.Tags
|
c.Tags = cj.Tags
|
||||||
c.Upstreams = cj.Upstreams
|
c.Upstreams = cj.Upstreams
|
||||||
|
@ -211,9 +211,9 @@ func (clients *clientsContainer) jsonToClient(
|
||||||
c.SafeBrowsingEnabled = cj.SafeBrowsingEnabled
|
c.SafeBrowsingEnabled = cj.SafeBrowsingEnabled
|
||||||
c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices
|
c.UseOwnBlockedServices = !cj.UseGlobalBlockedServices
|
||||||
|
|
||||||
if c.safeSearchConf.Enabled {
|
if c.SafeSearchConf.Enabled {
|
||||||
err = c.setSafeSearch(
|
err = c.SetSafeSearch(
|
||||||
c.safeSearchConf,
|
c.SafeSearchConf,
|
||||||
clients.safeSearchCacheSize,
|
clients.safeSearchCacheSize,
|
||||||
clients.safeSearchCacheTTL,
|
clients.safeSearchCacheTTL,
|
||||||
)
|
)
|
||||||
|
@ -258,7 +258,7 @@ func copySafeSearch(
|
||||||
func copyBlockedServices(
|
func copyBlockedServices(
|
||||||
sch *schedule.Weekly,
|
sch *schedule.Weekly,
|
||||||
svcStrs []string,
|
svcStrs []string,
|
||||||
prev *persistentClient,
|
prev *client.Persistent,
|
||||||
) (svcs *filtering.BlockedServices, err error) {
|
) (svcs *filtering.BlockedServices, err error) {
|
||||||
var weekly *schedule.Weekly
|
var weekly *schedule.Weekly
|
||||||
if sch != nil {
|
if sch != nil {
|
||||||
|
@ -283,15 +283,15 @@ func copyBlockedServices(
|
||||||
}
|
}
|
||||||
|
|
||||||
// clientToJSON converts persistent client object to JSON object.
|
// clientToJSON converts persistent client object to JSON object.
|
||||||
func clientToJSON(c *persistentClient) (cj *clientJSON) {
|
func clientToJSON(c *client.Persistent) (cj *clientJSON) {
|
||||||
// TODO(d.kolyshev): Remove after cleaning the deprecated
|
// TODO(d.kolyshev): Remove after cleaning the deprecated
|
||||||
// [clientJSON.SafeSearchEnabled] field.
|
// [clientJSON.SafeSearchEnabled] field.
|
||||||
cloneVal := c.safeSearchConf
|
cloneVal := c.SafeSearchConf
|
||||||
safeSearchConf := &cloneVal
|
safeSearchConf := &cloneVal
|
||||||
|
|
||||||
return &clientJSON{
|
return &clientJSON{
|
||||||
Name: c.Name,
|
Name: c.Name,
|
||||||
IDs: c.ids(),
|
IDs: c.IDs(),
|
||||||
Tags: c.Tags,
|
Tags: c.Tags,
|
||||||
UseGlobalSettings: !c.UseOwnSettings,
|
UseGlobalSettings: !c.UseOwnSettings,
|
||||||
FilteringEnabled: c.FilteringEnabled,
|
FilteringEnabled: c.FilteringEnabled,
|
||||||
|
@ -397,7 +397,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var prev *persistentClient
|
var prev *client.Persistent
|
||||||
var ok bool
|
var ok bool
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
|
|
|
@ -232,6 +232,10 @@ type dnsConfig struct {
|
||||||
|
|
||||||
// ServePlainDNS defines if plain DNS is allowed for incoming requests.
|
// ServePlainDNS defines if plain DNS is allowed for incoming requests.
|
||||||
ServePlainDNS bool `yaml:"serve_plain_dns"`
|
ServePlainDNS bool `yaml:"serve_plain_dns"`
|
||||||
|
|
||||||
|
// HostsFileEnabled defines whether to use information from the system hosts
|
||||||
|
// file to resolve queries.
|
||||||
|
HostsFileEnabled bool `yaml:"hostsfile_enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type tlsConfigSettings struct {
|
type tlsConfigSettings struct {
|
||||||
|
@ -259,6 +263,10 @@ type tlsConfigSettings struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type queryLogConfig struct {
|
type queryLogConfig struct {
|
||||||
|
// DirPath is the custom directory for logs. If it's empty the default
|
||||||
|
// directory will be used. See [homeContext.getDataDir].
|
||||||
|
DirPath string `yaml:"dir_path"`
|
||||||
|
|
||||||
// Ignored is the list of host names, which should not be written to log.
|
// Ignored is the list of host names, which should not be written to log.
|
||||||
// "." is considered to be the root domain.
|
// "." is considered to be the root domain.
|
||||||
Ignored []string `yaml:"ignored"`
|
Ignored []string `yaml:"ignored"`
|
||||||
|
@ -278,6 +286,10 @@ type queryLogConfig struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type statsConfig struct {
|
type statsConfig struct {
|
||||||
|
// DirPath is the custom directory for statistics. If it's empty the
|
||||||
|
// default directory is used. See [homeContext.getDataDir].
|
||||||
|
DirPath string `yaml:"dir_path"`
|
||||||
|
|
||||||
// Ignored is the list of host names, which should not be counted.
|
// Ignored is the list of host names, which should not be counted.
|
||||||
Ignored []string `yaml:"ignored"`
|
Ignored []string `yaml:"ignored"`
|
||||||
|
|
||||||
|
@ -344,6 +356,7 @@ var config = &configuration{
|
||||||
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
UpstreamTimeout: timeutil.Duration{Duration: dnsforward.DefaultTimeout},
|
||||||
UsePrivateRDNS: true,
|
UsePrivateRDNS: true,
|
||||||
ServePlainDNS: true,
|
ServePlainDNS: true,
|
||||||
|
HostsFileEnabled: true,
|
||||||
},
|
},
|
||||||
TLS: tlsConfigSettings{
|
TLS: tlsConfigSettings{
|
||||||
PortHTTPS: defaultPortHTTPS,
|
PortHTTPS: defaultPortHTTPS,
|
||||||
|
@ -443,20 +456,25 @@ var config = &configuration{
|
||||||
Theme: ThemeAuto,
|
Theme: ThemeAuto,
|
||||||
}
|
}
|
||||||
|
|
||||||
// getConfigFilename returns path to the current config file
|
// configFilePath returns the absolute path to the symlink-evaluated path to the
|
||||||
func (c *configuration) getConfigFilename() string {
|
// current config file.
|
||||||
configFile, err := filepath.EvalSymlinks(Context.configFilename)
|
func configFilePath() (confPath string) {
|
||||||
|
confPath, err := filepath.EvalSymlinks(Context.confFilePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !errors.Is(err, os.ErrNotExist) {
|
confPath = Context.confFilePath
|
||||||
log.Error("unexpected error while config file path evaluation: %s", err)
|
logFunc := log.Error
|
||||||
}
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
configFile = Context.configFilename
|
logFunc = log.Debug
|
||||||
}
|
|
||||||
if !filepath.IsAbs(configFile) {
|
|
||||||
configFile = filepath.Join(Context.workDir, configFile)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return configFile
|
logFunc("evaluating config path: %s; using %q", err, confPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !filepath.IsAbs(confPath) {
|
||||||
|
confPath = filepath.Join(Context.workDir, confPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return confPath
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateBindHosts returns error if any of binding hosts from configuration is
|
// validateBindHosts returns error if any of binding hosts from configuration is
|
||||||
|
@ -497,7 +515,10 @@ func parseConfig() (err error) {
|
||||||
// Don't wrap the error, because it's informative enough as is.
|
// Don't wrap the error, because it's informative enough as is.
|
||||||
return err
|
return err
|
||||||
} else if upgraded {
|
} else if upgraded {
|
||||||
err = maybe.WriteFile(config.getConfigFilename(), config.fileData, 0o644)
|
confPath := configFilePath()
|
||||||
|
log.Debug("writing config file %q after config upgrade", confPath)
|
||||||
|
|
||||||
|
err = maybe.WriteFile(confPath, config.fileData, 0o644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("writing new config: %w", err)
|
return fmt.Errorf("writing new config: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -518,12 +539,8 @@ func parseConfig() (err error) {
|
||||||
config.DNS.UpstreamTimeout = timeutil.Duration{Duration: dnsforward.DefaultTimeout}
|
config.DNS.UpstreamTimeout = timeutil.Duration{Duration: dnsforward.DefaultTimeout}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = setContextTLSCipherIDs()
|
// Do not wrap the error because it's informative enough as is.
|
||||||
if err != nil {
|
return setContextTLSCipherIDs()
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateConfig returns error if the configuration is invalid.
|
// validateConfig returns error if the configuration is invalid.
|
||||||
|
@ -587,11 +604,11 @@ func readConfigFile() (fileData []byte, err error) {
|
||||||
return config.fileData, nil
|
return config.fileData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
name := config.getConfigFilename()
|
confPath := configFilePath()
|
||||||
log.Debug("reading config file: %s", name)
|
log.Debug("reading config file %q", confPath)
|
||||||
|
|
||||||
// Do not wrap the error because it's informative enough as is.
|
// Do not wrap the error because it's informative enough as is.
|
||||||
return os.ReadFile(name)
|
return os.ReadFile(confPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Saves configuration to the YAML file and also saves the user filter contents to a file
|
// Saves configuration to the YAML file and also saves the user filter contents to a file
|
||||||
|
@ -655,8 +672,8 @@ func (c *configuration) write() (err error) {
|
||||||
|
|
||||||
config.Clients.Persistent = Context.clients.forConfig()
|
config.Clients.Persistent = Context.clients.forConfig()
|
||||||
|
|
||||||
configFile := config.getConfigFilename()
|
confPath := configFilePath()
|
||||||
log.Debug("writing config file %q", configFile)
|
log.Debug("writing config file %q", confPath)
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
enc := yaml.NewEncoder(buf)
|
enc := yaml.NewEncoder(buf)
|
||||||
|
@ -667,7 +684,7 @@ func (c *configuration) write() (err error) {
|
||||||
return fmt.Errorf("generating config file: %w", err)
|
return fmt.Errorf("generating config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = maybe.WriteFile(configFile, buf.Bytes(), 0o644)
|
err = maybe.WriteFile(confPath, buf.Bytes(), 0o644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("writing config file: %w", err)
|
return fmt.Errorf("writing config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -144,10 +144,7 @@ func handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
// Make sure that we don't send negative numbers to the frontend,
|
// Make sure that we don't send negative numbers to the frontend,
|
||||||
// since enough time might have passed to make the difference less
|
// since enough time might have passed to make the difference less
|
||||||
// than zero.
|
// than zero.
|
||||||
protectionDisabledDuration = max(
|
protectionDisabledDuration = max(0, time.Until(*protectionDisabledUntil).Milliseconds())
|
||||||
0,
|
|
||||||
time.Until(*protectionDisabledUntil).Milliseconds(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = statusResponse{
|
resp = statusResponse{
|
||||||
|
|
|
@ -46,12 +46,15 @@ func onConfigModified() {
|
||||||
// server and initializes it at last. It also must not be called unless
|
// server and initializes it at last. It also must not be called unless
|
||||||
// [config] and [Context] are initialized.
|
// [config] and [Context] are initialized.
|
||||||
func initDNS() (err error) {
|
func initDNS() (err error) {
|
||||||
baseDir := Context.getDataDir()
|
|
||||||
|
|
||||||
anonymizer := config.anonymizer()
|
anonymizer := config.anonymizer()
|
||||||
|
|
||||||
|
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&Context, config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
statsConf := stats.Config{
|
statsConf := stats.Config{
|
||||||
Filename: filepath.Join(baseDir, "stats.db"),
|
Filename: filepath.Join(statsDir, "stats.db"),
|
||||||
Limit: config.Stats.Interval.Duration,
|
Limit: config.Stats.Interval.Duration,
|
||||||
ConfigModified: onConfigModified,
|
ConfigModified: onConfigModified,
|
||||||
HTTPRegister: httpRegister,
|
HTTPRegister: httpRegister,
|
||||||
|
@ -75,7 +78,7 @@ func initDNS() (err error) {
|
||||||
ConfigModified: onConfigModified,
|
ConfigModified: onConfigModified,
|
||||||
HTTPRegister: httpRegister,
|
HTTPRegister: httpRegister,
|
||||||
FindClient: Context.clients.findMultiple,
|
FindClient: Context.clients.findMultiple,
|
||||||
BaseDir: baseDir,
|
BaseDir: querylogDir,
|
||||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||||
RotationIvl: config.QueryLog.Interval.Duration,
|
RotationIvl: config.QueryLog.Interval.Duration,
|
||||||
MemSize: config.QueryLog.MemSize,
|
MemSize: config.QueryLog.MemSize,
|
||||||
|
@ -424,7 +427,7 @@ func applyAdditionalFiltering(clientIP netip.Addr, clientID string, setts *filte
|
||||||
}
|
}
|
||||||
|
|
||||||
setts.FilteringEnabled = c.FilteringEnabled
|
setts.FilteringEnabled = c.FilteringEnabled
|
||||||
setts.SafeSearchEnabled = c.safeSearchConf.Enabled
|
setts.SafeSearchEnabled = c.SafeSearchConf.Enabled
|
||||||
setts.ClientSafeSearch = c.SafeSearch
|
setts.ClientSafeSearch = c.SafeSearch
|
||||||
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
setts.SafeBrowsingEnabled = c.SafeBrowsingEnabled
|
||||||
setts.ParentalEnabled = c.ParentalEnabled
|
setts.ParentalEnabled = c.ParentalEnabled
|
||||||
|
@ -545,3 +548,50 @@ func (r safeSearchResolver) LookupIP(
|
||||||
|
|
||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkStatsAndQuerylogDirs checks and returns directory paths to store
|
||||||
|
// statistics and query log.
|
||||||
|
func checkStatsAndQuerylogDirs(
|
||||||
|
ctx *homeContext,
|
||||||
|
conf *configuration,
|
||||||
|
) (statsDir, querylogDir string, err error) {
|
||||||
|
baseDir := ctx.getDataDir()
|
||||||
|
|
||||||
|
statsDir = conf.Stats.DirPath
|
||||||
|
if statsDir == "" {
|
||||||
|
statsDir = baseDir
|
||||||
|
} else {
|
||||||
|
err = checkDir(statsDir)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("statistics: custom directory: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
querylogDir = conf.QueryLog.DirPath
|
||||||
|
if querylogDir == "" {
|
||||||
|
querylogDir = baseDir
|
||||||
|
} else {
|
||||||
|
err = checkDir(querylogDir)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("querylog: custom directory: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return statsDir, querylogDir, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkDir checks if the path is a directory. It's used to check for
|
||||||
|
// misconfiguration at startup.
|
||||||
|
func checkDir(path string) (err error) {
|
||||||
|
var fi os.FileInfo
|
||||||
|
if fi, err = os.Stat(path); err != nil {
|
||||||
|
// Don't wrap the error, since it's informative enough as is.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fi.IsDir() {
|
||||||
|
return fmt.Errorf("%q is not a directory", path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
@ -12,6 +13,19 @@ import (
|
||||||
|
|
||||||
var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||||
|
|
||||||
|
// newIDIndex is a helper function that returns a client index filled with
|
||||||
|
// persistent clients from the m. It also generates a UID for each client.
|
||||||
|
func newIDIndex(m []*client.Persistent) (ci *client.Index) {
|
||||||
|
ci = client.NewIndex()
|
||||||
|
|
||||||
|
for _, c := range m {
|
||||||
|
c.UID = client.MustNewUID()
|
||||||
|
ci.Add(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ci
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyAdditionalFiltering(t *testing.T) {
|
func TestApplyAdditionalFiltering(t *testing.T) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
@ -22,29 +36,28 @@ func TestApplyAdditionalFiltering(t *testing.T) {
|
||||||
}, nil)
|
}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
Context.clients.idIndex = map[string]*persistentClient{
|
Context.clients.clientIndex = newIDIndex([]*client.Persistent{{
|
||||||
"default": {
|
ClientIDs: []string{"default"},
|
||||||
UseOwnSettings: false,
|
UseOwnSettings: false,
|
||||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
|
SafeSearchConf: filtering.SafeSearchConfig{Enabled: false},
|
||||||
FilteringEnabled: false,
|
FilteringEnabled: false,
|
||||||
SafeBrowsingEnabled: false,
|
SafeBrowsingEnabled: false,
|
||||||
ParentalEnabled: false,
|
ParentalEnabled: false,
|
||||||
},
|
}, {
|
||||||
"custom_filtering": {
|
ClientIDs: []string{"custom_filtering"},
|
||||||
UseOwnSettings: true,
|
UseOwnSettings: true,
|
||||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||||
FilteringEnabled: true,
|
FilteringEnabled: true,
|
||||||
SafeBrowsingEnabled: true,
|
SafeBrowsingEnabled: true,
|
||||||
ParentalEnabled: true,
|
ParentalEnabled: true,
|
||||||
},
|
}, {
|
||||||
"partial_custom_filtering": {
|
ClientIDs: []string{"partial_custom_filtering"},
|
||||||
UseOwnSettings: true,
|
UseOwnSettings: true,
|
||||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
SafeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||||
FilteringEnabled: true,
|
FilteringEnabled: true,
|
||||||
SafeBrowsingEnabled: false,
|
SafeBrowsingEnabled: false,
|
||||||
ParentalEnabled: false,
|
ParentalEnabled: false,
|
||||||
},
|
}})
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -108,38 +121,37 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
||||||
}, nil)
|
}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
Context.clients.idIndex = map[string]*persistentClient{
|
Context.clients.clientIndex = newIDIndex([]*client.Persistent{{
|
||||||
"default": {
|
ClientIDs: []string{"default"},
|
||||||
UseOwnBlockedServices: false,
|
UseOwnBlockedServices: false,
|
||||||
},
|
}, {
|
||||||
"no_services": {
|
ClientIDs: []string{"no_services"},
|
||||||
BlockedServices: &filtering.BlockedServices{
|
BlockedServices: &filtering.BlockedServices{
|
||||||
Schedule: schedule.EmptyWeekly(),
|
Schedule: schedule.EmptyWeekly(),
|
||||||
},
|
},
|
||||||
UseOwnBlockedServices: true,
|
UseOwnBlockedServices: true,
|
||||||
},
|
}, {
|
||||||
"services": {
|
ClientIDs: []string{"services"},
|
||||||
BlockedServices: &filtering.BlockedServices{
|
BlockedServices: &filtering.BlockedServices{
|
||||||
Schedule: schedule.EmptyWeekly(),
|
Schedule: schedule.EmptyWeekly(),
|
||||||
IDs: clientBlockedServices,
|
IDs: clientBlockedServices,
|
||||||
},
|
},
|
||||||
UseOwnBlockedServices: true,
|
UseOwnBlockedServices: true,
|
||||||
},
|
}, {
|
||||||
"invalid_services": {
|
ClientIDs: []string{"invalid_services"},
|
||||||
BlockedServices: &filtering.BlockedServices{
|
BlockedServices: &filtering.BlockedServices{
|
||||||
Schedule: schedule.EmptyWeekly(),
|
Schedule: schedule.EmptyWeekly(),
|
||||||
IDs: invalidBlockedServices,
|
IDs: invalidBlockedServices,
|
||||||
},
|
},
|
||||||
UseOwnBlockedServices: true,
|
UseOwnBlockedServices: true,
|
||||||
},
|
}, {
|
||||||
"allow_all": {
|
ClientIDs: []string{"allow_all"},
|
||||||
BlockedServices: &filtering.BlockedServices{
|
BlockedServices: &filtering.BlockedServices{
|
||||||
Schedule: schedule.FullWeekly(),
|
Schedule: schedule.FullWeekly(),
|
||||||
IDs: clientBlockedServices,
|
IDs: clientBlockedServices,
|
||||||
},
|
},
|
||||||
UseOwnBlockedServices: true,
|
UseOwnBlockedServices: true,
|
||||||
},
|
}})
|
||||||
}
|
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
@ -39,8 +40,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/netutil"
|
"github.com/AdguardTeam/golibs/netutil"
|
||||||
"github.com/AdguardTeam/golibs/osutil"
|
"github.com/AdguardTeam/golibs/osutil"
|
||||||
"github.com/AdguardTeam/golibs/stringutil"
|
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Global context
|
// Global context
|
||||||
|
@ -68,7 +67,10 @@ type homeContext struct {
|
||||||
// Runtime properties
|
// Runtime properties
|
||||||
// --
|
// --
|
||||||
|
|
||||||
configFilename string // Config filename (can be overridden via the command line arguments)
|
// confFilePath is the configuration file path as set by default or from the
|
||||||
|
// command-line options.
|
||||||
|
confFilePath string
|
||||||
|
|
||||||
workDir string // Location of our directory, used to protect against CWD being somewhere else
|
workDir string // Location of our directory, used to protect against CWD being somewhere else
|
||||||
pidFileName string // PID file name. Empty if no PID file was created.
|
pidFileName string // PID file name. Empty if no PID file was created.
|
||||||
controlLock sync.Mutex
|
controlLock sync.Mutex
|
||||||
|
@ -250,7 +252,7 @@ func setupHostsContainer() (err error) {
|
||||||
return errors.Join(fmt.Errorf("initializing hosts container: %w", err), closeErr)
|
return errors.Join(fmt.Errorf("initializing hosts container: %w", err), closeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return hostsWatcher.Start()
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupOpts sets up command-line options.
|
// setupOpts sets up command-line options.
|
||||||
|
@ -361,7 +363,7 @@ func setupDNSFilteringConf(conf *filtering.Config) (err error) {
|
||||||
|
|
||||||
conf.EtcHosts = Context.etcHosts
|
conf.EtcHosts = Context.etcHosts
|
||||||
// TODO(s.chzhen): Use empty interface.
|
// TODO(s.chzhen): Use empty interface.
|
||||||
if Context.etcHosts == nil {
|
if Context.etcHosts == nil || !config.DNS.HostsFileEnabled {
|
||||||
conf.EtcHosts = nil
|
conf.EtcHosts = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -575,6 +577,9 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||||
Path: path.Join("adguardhome", version.Channel(), "version.json"),
|
Path: path.Join("adguardhome", version.Channel(), "version.json"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
confPath := configFilePath()
|
||||||
|
log.Debug("using config path %q for updater", confPath)
|
||||||
|
|
||||||
upd := updater.NewUpdater(&updater.Config{
|
upd := updater.NewUpdater(&updater.Config{
|
||||||
Client: config.Filtering.HTTPClient,
|
Client: config.Filtering.HTTPClient,
|
||||||
Version: version.Version(),
|
Version: version.Version(),
|
||||||
|
@ -584,7 +589,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}) {
|
||||||
GOARM: version.GOARM(),
|
GOARM: version.GOARM(),
|
||||||
GOMIPS: version.GOMIPS(),
|
GOMIPS: version.GOMIPS(),
|
||||||
WorkDir: Context.workDir,
|
WorkDir: Context.workDir,
|
||||||
ConfName: config.getConfigFilename(),
|
ConfName: confPath,
|
||||||
ExecPath: execPath,
|
ExecPath: execPath,
|
||||||
VersionCheckURL: u.String(),
|
VersionCheckURL: u.String(),
|
||||||
})
|
})
|
||||||
|
@ -748,7 +753,16 @@ func writePIDFile(fn string) bool {
|
||||||
// initConfigFilename sets up context config file path. This file path can be
|
// initConfigFilename sets up context config file path. This file path can be
|
||||||
// overridden by command-line arguments, or is set to default.
|
// overridden by command-line arguments, or is set to default.
|
||||||
func initConfigFilename(opts options) {
|
func initConfigFilename(opts options) {
|
||||||
Context.configFilename = stringutil.Coalesce(opts.confFilename, "AdGuardHome.yaml")
|
confPath := opts.confFilename
|
||||||
|
if confPath == "" {
|
||||||
|
Context.confFilePath = "AdGuardHome.yaml"
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("config path overridden to %q from cmdline", confPath)
|
||||||
|
|
||||||
|
Context.confFilePath = confPath
|
||||||
}
|
}
|
||||||
|
|
||||||
// initWorkingDir initializes the workDir. If no command-line arguments are
|
// initWorkingDir initializes the workDir. If no command-line arguments are
|
||||||
|
@ -906,16 +920,23 @@ func printHTTPAddresses(proto string) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// -------------------
|
// detectFirstRun returns true if this is the first run of AdGuard Home.
|
||||||
// first run / install
|
func detectFirstRun() (ok bool) {
|
||||||
// -------------------
|
confPath := Context.confFilePath
|
||||||
func detectFirstRun() bool {
|
if !filepath.IsAbs(confPath) {
|
||||||
configfile := Context.configFilename
|
confPath = filepath.Join(Context.workDir, Context.confFilePath)
|
||||||
if !filepath.IsAbs(configfile) {
|
|
||||||
configfile = filepath.Join(Context.workDir, Context.configFilename)
|
|
||||||
}
|
}
|
||||||
_, err := os.Stat(configfile)
|
|
||||||
return errors.Is(err, os.ErrNotExist)
|
_, err := os.Stat(confPath)
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
} else if errors.Is(err, os.ErrNotExist) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Error("detecting first run: %s; considering first run", err)
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// jsonError is a generic JSON error response.
|
// jsonError is a generic JSON error response.
|
||||||
|
|
|
@ -75,6 +75,8 @@ func getLogSettings(opts options) (ls *logSettings) {
|
||||||
if opts.verbose {
|
if opts.verbose {
|
||||||
ls.Verbose = true
|
ls.Verbose = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(a.garipov): Use cmp.Or in Go 1.22.
|
||||||
ls.File = stringutil.Coalesce(opts.logFile, ls.File)
|
ls.File = stringutil.Coalesce(opts.logFile, ls.File)
|
||||||
|
|
||||||
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
if opts.runningAsService && ls.File == "" && runtime.GOOS == "windows" {
|
||||||
|
|
|
@ -270,13 +270,15 @@ var cmdLineOpts = []cmdLineOpt{{
|
||||||
log.Info(
|
log.Info(
|
||||||
"warning: --no-etc-hosts flag is deprecated " +
|
"warning: --no-etc-hosts flag is deprecated " +
|
||||||
"and will be removed in the future versions; " +
|
"and will be removed in the future versions; " +
|
||||||
"set clients.runtime_sources.hosts in the configuration file to false instead",
|
"set clients.runtime_sources.hosts and dns.hostsfile_enabled " +
|
||||||
|
"in the configuration file to false instead",
|
||||||
)
|
)
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
},
|
},
|
||||||
serialize: func(o options) (val string, ok bool) { return "", o.noEtcHosts },
|
serialize: func(o options) (val string, ok bool) { return "", o.noEtcHosts },
|
||||||
description: "Deprecated: use clients.runtime_sources.hosts instead. Do not use the OS-provided hosts.",
|
description: "Deprecated: use clients.runtime_sources.hosts and dns.hostsfile_enabled " +
|
||||||
|
"instead. Do not use the OS-provided hosts.",
|
||||||
longName: "no-etc-hosts",
|
longName: "no-etc-hosts",
|
||||||
shortName: "",
|
shortName: "",
|
||||||
}, {
|
}, {
|
||||||
|
|
|
@ -227,12 +227,15 @@ func handleServiceControlAction(
|
||||||
runOpts := opts
|
runOpts := opts
|
||||||
runOpts.serviceControlAction = "run"
|
runOpts.serviceControlAction = "run"
|
||||||
|
|
||||||
|
args := optsToArgs(runOpts)
|
||||||
|
log.Debug("service: using args %q", args)
|
||||||
|
|
||||||
svcConfig := &service.Config{
|
svcConfig := &service.Config{
|
||||||
Name: serviceName,
|
Name: serviceName,
|
||||||
DisplayName: serviceDisplayName,
|
DisplayName: serviceDisplayName,
|
||||||
Description: serviceDescription,
|
Description: serviceDescription,
|
||||||
WorkingDirectory: pwd,
|
WorkingDirectory: pwd,
|
||||||
Arguments: optsToArgs(runOpts),
|
Arguments: args,
|
||||||
}
|
}
|
||||||
configureService(svcConfig)
|
configureService(svcConfig)
|
||||||
|
|
||||||
|
|
|
@ -8,13 +8,13 @@ import (
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
|
"github.com/AdguardTeam/AdGuardHome/internal/configmigrate"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
"github.com/AdguardTeam/AdGuardHome/internal/next/configmgr"
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// options contains all command-line options for the AdGuardHome(.exe) binary.
|
// options contains all command-line options for the AdGuardHome(.exe) binary.
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -20,7 +21,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/golibs/timeutil"
|
"github.com/AdguardTeam/golibs/timeutil"
|
||||||
"github.com/google/renameio/v2/maybe"
|
"github.com/google/renameio/v2/maybe"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -67,8 +67,7 @@ func New(c *Config) (svc *Service, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
svc.bootstrapResolvers = resolvers
|
svc.bootstrapResolvers = resolvers
|
||||||
svc.proxy = &proxy.Proxy{
|
svc.proxy, err = proxy.New(&proxy.Config{
|
||||||
Config: proxy.Config{
|
|
||||||
UDPListenAddr: udpAddrs(c.Addresses),
|
UDPListenAddr: udpAddrs(c.Addresses),
|
||||||
TCPListenAddr: tcpAddrs(c.Addresses),
|
TCPListenAddr: tcpAddrs(c.Addresses),
|
||||||
UpstreamConfig: &proxy.UpstreamConfig{
|
UpstreamConfig: &proxy.UpstreamConfig{
|
||||||
|
@ -76,10 +75,7 @@ func New(c *Config) (svc *Service, err error) {
|
||||||
},
|
},
|
||||||
UseDNS64: c.UseDNS64,
|
UseDNS64: c.UseDNS64,
|
||||||
DNS64Prefs: c.DNS64Prefixes,
|
DNS64Prefs: c.DNS64Prefixes,
|
||||||
},
|
})
|
||||||
}
|
|
||||||
|
|
||||||
err = svc.proxy.Init()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("proxy: %w", err)
|
return nil, fmt.Errorf("proxy: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -174,7 +170,7 @@ func (svc *Service) Start() (err error) {
|
||||||
svc.running.Store(err == nil)
|
svc.running.Store(err == nil)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return svc.proxy.Start()
|
return svc.proxy.Start(context.Background())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
// Shutdown implements the [agh.Service] interface for *Service. svc may be
|
||||||
|
@ -185,7 +181,7 @@ func (svc *Service) Shutdown(ctx context.Context) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
errs := []error{
|
errs := []error{
|
||||||
svc.proxy.Stop(),
|
svc.proxy.Shutdown(ctx),
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, b := range svc.bootstrapResolvers {
|
for _, b := range svc.bootstrapResolvers {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package querylog
|
package querylog
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
@ -9,7 +10,6 @@ import (
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
"golang.org/x/net/idna"
|
"golang.org/x/net/idna"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -3,11 +3,11 @@ package querylog
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/errors"
|
"github.com/AdguardTeam/golibs/errors"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// client finds the client info, if any, by its ClientID and IP address,
|
// client finds the client info, if any, by its ClientID and IP address,
|
||||||
|
@ -49,8 +49,8 @@ func (l *queryLog) client(clientID, ip string, cache clientCache) (c *Client, er
|
||||||
// the total amount of records in the buffer at the moment of searching.
|
// the total amount of records in the buffer at the moment of searching.
|
||||||
// l.confMu is expected to be locked.
|
// l.confMu is expected to be locked.
|
||||||
func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entries []*logEntry, total int) {
|
func (l *queryLog) searchMemory(params *searchParams, cache clientCache) (entries []*logEntry, total int) {
|
||||||
// We use this configuration check because a buffer can contain a single log
|
// Check memory size, as the buffer can contain a single log record. See
|
||||||
// record. See [newQueryLog].
|
// [newQueryLog].
|
||||||
if l.conf.MemSize == 0 {
|
if l.conf.MemSize == 0 {
|
||||||
return nil, 0
|
return nil, 0
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
@ -12,7 +13,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"go.etcd.io/bbolt"
|
"go.etcd.io/bbolt"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -484,7 +484,7 @@ func (s *StatsCtx) fillCollectedStats(data *StatsResp, units []*unitDB, curID ui
|
||||||
data.TimeUnits = timeUnitsHours
|
data.TimeUnits = timeUnitsHours
|
||||||
|
|
||||||
daysCount := size / 24
|
daysCount := size / 24
|
||||||
if daysCount >= 7 {
|
if daysCount > 7 {
|
||||||
size = daysCount
|
size = daysCount
|
||||||
data.TimeUnits = timeUnitsDays
|
data.TimeUnits = timeUnitsDays
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||||
|
@ -12,7 +13,6 @@ import (
|
||||||
"github.com/AdguardTeam/golibs/ioutil"
|
"github.com/AdguardTeam/golibs/ioutil"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO(a.garipov): Make configurable.
|
// TODO(a.garipov): Make configurable.
|
||||||
|
|
|
@ -61,8 +61,6 @@ set -f -u
|
||||||
#
|
#
|
||||||
# TODO(a.garipov): Add golibs/log.
|
# TODO(a.garipov): Add golibs/log.
|
||||||
#
|
#
|
||||||
# TODO(a.garipov): Add "golang.org/x/exp/slices" back after a release.
|
|
||||||
#
|
|
||||||
# TODO(a.garipov): Add deprecated package golang.org/x/exp/maps once all
|
# TODO(a.garipov): Add deprecated package golang.org/x/exp/maps once all
|
||||||
# projects switch to Go 1.22.
|
# projects switch to Go 1.22.
|
||||||
blocklist_imports() {
|
blocklist_imports() {
|
||||||
|
@ -73,6 +71,7 @@ blocklist_imports() {
|
||||||
-e '[[:space:]]"reflect"$'\
|
-e '[[:space:]]"reflect"$'\
|
||||||
-e '[[:space:]]"sort"$'\
|
-e '[[:space:]]"sort"$'\
|
||||||
-e '[[:space:]]"unsafe"$'\
|
-e '[[:space:]]"unsafe"$'\
|
||||||
|
-e '[[:space:]]"golang.org/x/exp/slices"$'\
|
||||||
-e '[[:space:]]"golang.org/x/net/context"$'\
|
-e '[[:space:]]"golang.org/x/net/context"$'\
|
||||||
-n\
|
-n\
|
||||||
-- '*.go'\
|
-- '*.go'\
|
||||||
|
|
Loading…
Reference in New Issue