Pull request: 2508 ip conversion vol.2

Merge in DNS/adguard-home from 2508-ip-conversion-vol2 to master

Closes #2508.

Squashed commit of the following:

commit 5b9d33f9cd352756831f63e34c4aea48674628c1
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Jan 20 17:15:17 2021 +0300

    util: replace net.IPNet with pointer

commit 680126de7d59464077f9edf1bbaa925dd3fcee19
Merge: d3ba6a6c 5a50efad
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Wed Jan 20 17:02:41 2021 +0300

    Merge branch 'master' into 2508-ip-conversion-vol2

commit d3ba6a6cdd01c0aa736418fdb86ed40120169fe9
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 19 18:29:54 2021 +0300

    all: remove last conversion

commit 88b63f11a6c3f8705d7fa0c448c50dd646cc9214
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Tue Jan 19 14:12:45 2021 +0300

    all: improve code quality

commit 71af60c70a0dbaf55e2221023d6d2e4993c9e9a7
Merge: 98af3784 9f75725d
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Jan 18 17:13:27 2021 +0300

    Merge branch 'master' into 2508-ip-conversion-vol2

commit 98af3784ce44d0993d171653c13d6e83bb8d1e6a
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Jan 18 16:32:53 2021 +0300

    all: log changes

commit e99595a172bae1e844019d344544be84ddd65e4e
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Mon Jan 18 16:06:49 2021 +0300

    all: fix or remove remaining net.IP <-> string conversions

commit 7fd0634ce945f7e4c9b856684c5199f8a84a543e
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Fri Jan 15 15:36:17 2021 +0300

    all: remove redundant net.IP <-> string converions

commit 5df8af030421237d41b67ed659f83526cc258199
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Jan 14 16:35:25 2021 +0300

    stats: remove redundant net.IP <-> string conversion

commit fbe4e3fc015e6898063543a90c04401d76dbb18f
Author: Eugene Burkov <e.burkov@adguard.com>
Date:   Thu Jan 14 16:20:35 2021 +0300

    querylog: remove redundant net.IP <-> string conversion
This commit is contained in:
Eugene Burkov 2021-01-20 17:27:53 +03:00
parent 5a50efadb2
commit 7fab31beae
45 changed files with 324 additions and 302 deletions

View File

@ -66,6 +66,7 @@ and this project adheres to
### Fixed ### Fixed
- Unnecessary conversions from `string` to `net.IP`, and vice versa ([#2508]).
- Inability to set DNS cache TTL limits ([#2459]). - Inability to set DNS cache TTL limits ([#2459]).
- Possible freezes on slower machines ([#2225]). - Possible freezes on slower machines ([#2225]).
- A mitigation against records being shown in the wrong order on the query log - A mitigation against records being shown in the wrong order on the query log
@ -79,9 +80,13 @@ and this project adheres to
[#2345]: https://github.com/AdguardTeam/AdGuardHome/issues/2345 [#2345]: https://github.com/AdguardTeam/AdGuardHome/issues/2345
[#2355]: https://github.com/AdguardTeam/AdGuardHome/issues/2355 [#2355]: https://github.com/AdguardTeam/AdGuardHome/issues/2355
[#2459]: https://github.com/AdguardTeam/AdGuardHome/issues/2459 [#2459]: https://github.com/AdguardTeam/AdGuardHome/issues/2459
[#2508]: https://github.com/AdguardTeam/AdGuardHome/issues/2508
### Removed ### Removed
- The undocumented ability to use hostnames as any of `bind_host` values in
configuration. Documentation requires them to be valid IP addresses, and now
the implementation makes sure that that is the case ([#2508]).
- `Dockerfile` ([#2276]). Replaced with the script - `Dockerfile` ([#2276]). Replaced with the script
`scripts/make/build-docker.sh` which uses `scripts/make/Dockerfile`. `scripts/make/build-docker.sh` which uses `scripts/make/Dockerfile`.
- Support for pre-v0.99.3 format of query logs ([#2102]). - Support for pre-v0.99.3 format of query logs ([#2102]).

View File

@ -297,9 +297,6 @@ func parseOptionString(s string) (uint8, []byte) {
return 0, nil return 0, nil
} }
val = ip val = ip
if ip.To4() != nil {
val = ip.To4()
}
default: default:
return 0, nil return 0, nil

View File

@ -61,11 +61,11 @@ func TestDB(t *testing.T) {
ll := s.srv4.GetLeases(LeasesAll) ll := s.srv4.GetLeases(LeasesAll)
assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:bb", ll[0].HWAddr.String())
assert.Equal(t, "192.168.10.101", ll[0].IP.String()) assert.True(t, net.IP{192, 168, 10, 101}.Equal(ll[0].IP))
assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix()) assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ll[1].HWAddr.String())
assert.Equal(t, "192.168.10.100", ll[1].IP.String()) assert.True(t, net.IP{192, 168, 10, 100}.Equal(ll[1].IP))
assert.Equal(t, exp1.Unix(), ll[1].Expiry.Unix()) assert.Equal(t, exp1.Unix(), ll[1].Expiry.Unix())
_ = os.Remove("leases.db") _ = os.Remove("leases.db")
@ -117,7 +117,7 @@ func TestOptions(t *testing.T) {
code, val = parseOptionString("123 ip 1.2.3.4") code, val = parseOptionString("123 ip 1.2.3.4")
assert.EqualValues(t, 123, code) assert.EqualValues(t, 123, code)
assert.Equal(t, "1.2.3.4", net.IP(string(val)).String()) assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(val)))
code, _ = parseOptionString("256 ip 1.1.1.1") code, _ = parseOptionString("256 ip 1.1.1.1")
assert.EqualValues(t, 0, code) assert.EqualValues(t, 0, code)

View File

@ -40,7 +40,7 @@ func v4JSONToServerConf(j v4ServerConfJSON) V4ServerConf {
} }
type v6ServerConfJSON struct { type v6ServerConfJSON struct {
RangeStart string `json:"range_start"` RangeStart net.IP `json:"range_start"`
LeaseDuration uint32 `json:"lease_duration"` LeaseDuration uint32 `json:"lease_duration"`
} }
@ -331,7 +331,7 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
result.V4.StaticIP.Error = err.Error() result.V4.StaticIP.Error = err.Error()
} else if !isStaticIP { } else if !isStaticIP {
result.V4.StaticIP.Static = "no" result.V4.StaticIP.Static = "no"
result.V4.StaticIP.IP = util.GetSubnet(interfaceName) result.V4.StaticIP.IP = util.GetSubnet(interfaceName).String()
} }
if found4 { if found4 {

View File

@ -79,7 +79,7 @@ type V6ServerConf struct {
// The first IP address for dynamic leases // The first IP address for dynamic leases
// The last allowed IP address ends with 0xff byte // The last allowed IP address ends with 0xff byte
RangeStart string `yaml:"range_start" json:"range_start"` RangeStart net.IP `yaml:"range_start"`
LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds LeaseDuration uint32 `yaml:"lease_duration" json:"lease_duration"` // in seconds

View File

@ -40,7 +40,7 @@ func TestV4StaticLeaseAddRemove(t *testing.T) {
// check // check
ls = s.GetLeases(LeasesStatic) ls = s.GetLeases(LeasesStatic)
assert.Len(t, ls, 1) assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.150", ls[0].IP.String()) assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP))
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
@ -102,11 +102,11 @@ func TestV4StaticLeaseAddReplaceDynamic(t *testing.T) {
ls := s.GetLeases(LeasesStatic) ls := s.GetLeases(LeasesStatic)
assert.Len(t, ls, 2) assert.Len(t, ls, 2)
assert.Equal(t, "192.168.10.150", ls[0].IP.String()) assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP))
assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "33:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix())
assert.Equal(t, "192.168.10.152", ls[1].IP.String()) assert.True(t, net.IP{192, 168, 10, 152}.Equal(ls[1].IP))
assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String()) assert.Equal(t, "22:aa:aa:aa:aa:aa", ls[1].HWAddr.String())
assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix()) assert.EqualValues(t, leaseExpireStatic, ls[1].Expiry.Unix())
} }
@ -139,10 +139,10 @@ func TestV4StaticLeaseGet(t *testing.T) {
// check "Offer" // check "Offer"
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String()) assert.True(t, net.IP{192, 168, 10, 150}.Equal(resp.YourIPAddr))
assert.Equal(t, "192.168.10.1", resp.Router()[0].String()) assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0]))
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String()) assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier()))
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String()) assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask())))
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
// "Request" // "Request"
@ -153,20 +153,20 @@ func TestV4StaticLeaseGet(t *testing.T) {
// check "Ack" // check "Ack"
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType()) assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.150", resp.YourIPAddr.String()) assert.True(t, net.IP{192, 168, 10, 150}.Equal(resp.YourIPAddr))
assert.Equal(t, "192.168.10.1", resp.Router()[0].String()) assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0]))
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String()) assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier()))
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String()) assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask())))
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
dnsAddrs := resp.DNS() dnsAddrs := resp.DNS()
assert.Len(t, dnsAddrs, 1) assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String()) assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0]))
// check lease // check lease
ls := s.GetLeases(LeasesStatic) ls := s.GetLeases(LeasesStatic)
assert.Len(t, ls, 1) assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.150", ls[0].IP.String()) assert.True(t, net.IP{192, 168, 10, 150}.Equal(ls[0].IP))
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
} }
@ -197,13 +197,13 @@ func TestV4DynamicLeaseGet(t *testing.T) {
// check "Offer" // check "Offer"
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType()) assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String()) assert.True(t, net.IP{192, 168, 10, 100}.Equal(resp.YourIPAddr))
assert.Equal(t, "192.168.10.1", resp.Router()[0].String()) assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0]))
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String()) assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier()))
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String()) assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask())))
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)]) assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)])
assert.Equal(t, "1.2.3.4", net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)]).String()) assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)])))
// "Request" // "Request"
req, _ = dhcpv4.NewRequestFromOffer(resp) req, _ = dhcpv4.NewRequestFromOffer(resp)
@ -213,20 +213,20 @@ func TestV4DynamicLeaseGet(t *testing.T) {
// check "Ack" // check "Ack"
assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType()) assert.Equal(t, dhcpv4.MessageTypeAck, resp.MessageType())
assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", resp.ClientHWAddr.String())
assert.Equal(t, "192.168.10.100", resp.YourIPAddr.String()) assert.True(t, net.IP{192, 168, 10, 100}.Equal(resp.YourIPAddr))
assert.Equal(t, "192.168.10.1", resp.Router()[0].String()) assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.Router()[0]))
assert.Equal(t, "192.168.10.1", resp.ServerIdentifier().String()) assert.True(t, net.IP{192, 168, 10, 1}.Equal(resp.ServerIdentifier()))
assert.Equal(t, "255.255.255.0", net.IP(resp.SubnetMask()).String()) assert.True(t, net.IP{255, 255, 255, 0}.Equal(net.IP(resp.SubnetMask())))
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
dnsAddrs := resp.DNS() dnsAddrs := resp.DNS()
assert.Len(t, dnsAddrs, 1) assert.Len(t, dnsAddrs, 1)
assert.Equal(t, "192.168.10.1", dnsAddrs[0].String()) assert.True(t, net.IP{192, 168, 10, 1}.Equal(dnsAddrs[0]))
// check lease // check lease
ls := s.GetLeases(LeasesDynamic) ls := s.GetLeases(LeasesDynamic)
assert.Len(t, ls, 1) assert.Len(t, ls, 1)
assert.Equal(t, "192.168.10.100", ls[0].IP.String()) assert.True(t, net.IP{192, 168, 10, 100}.Equal(ls[0].IP))
assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String()) assert.Equal(t, "aa:aa:aa:aa:aa:aa", ls[0].HWAddr.String())
start := net.IP{192, 168, 10, 100} start := net.IP{192, 168, 10, 100}

View File

@ -660,7 +660,7 @@ func v6Create(conf V6ServerConf) (DHCPServer, error) {
return s, nil return s, nil
} }
s.conf.ipStart = net.ParseIP(conf.RangeStart) s.conf.ipStart = conf.RangeStart
if s.conf.ipStart == nil || s.conf.ipStart.To16() == nil { if s.conf.ipStart == nil || s.conf.ipStart.To16() == nil {
return s, fmt.Errorf("dhcpv6: invalid range-start IP: %s", conf.RangeStart) return s, fmt.Errorf("dhcpv6: invalid range-start IP: %s", conf.RangeStart)
} }

View File

@ -17,7 +17,7 @@ func notify6(flags uint32) {
func TestV6StaticLeaseAddRemove(t *testing.T) { func TestV6StaticLeaseAddRemove(t *testing.T) {
conf := V6ServerConf{ conf := V6ServerConf{
Enabled: true, Enabled: true,
RangeStart: "2001::1", RangeStart: net.ParseIP("2001::1"),
notify: notify6, notify: notify6,
} }
s, err := v6Create(conf) s, err := v6Create(conf)
@ -60,7 +60,7 @@ func TestV6StaticLeaseAddRemove(t *testing.T) {
func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) { func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
conf := V6ServerConf{ conf := V6ServerConf{
Enabled: true, Enabled: true,
RangeStart: "2001::1", RangeStart: net.ParseIP("2001::1"),
notify: notify6, notify: notify6,
} }
sIface, err := v6Create(conf) sIface, err := v6Create(conf)
@ -109,7 +109,7 @@ func TestV6StaticLeaseAddReplaceDynamic(t *testing.T) {
func TestV6GetLease(t *testing.T) { func TestV6GetLease(t *testing.T) {
conf := V6ServerConf{ conf := V6ServerConf{
Enabled: true, Enabled: true,
RangeStart: "2001::1", RangeStart: net.ParseIP("2001::1"),
notify: notify6, notify: notify6,
} }
sIface, err := v6Create(conf) sIface, err := v6Create(conf)
@ -169,7 +169,7 @@ func TestV6GetLease(t *testing.T) {
func TestV6GetDynamicLease(t *testing.T) { func TestV6GetDynamicLease(t *testing.T) {
conf := V6ServerConf{ conf := V6ServerConf{
Enabled: true, Enabled: true,
RangeStart: "2001::2", RangeStart: net.ParseIP("2001::2"),
notify: notify6, notify: notify6,
} }
sIface, err := v6Create(conf) sIface, err := v6Create(conf)

View File

@ -36,7 +36,7 @@ type RequestFilteringSettings struct {
ParentalEnabled bool ParentalEnabled bool
ClientName string ClientName string
ClientIP string ClientIP net.IP
ClientTags []string ClientTags []string
ServicesRules []ServiceEntry ServicesRules []ServiceEntry
@ -676,9 +676,10 @@ func (d *DNSFilter) matchHost(host string, qtype uint16, setts RequestFilteringS
ureq := urlfilter.DNSRequest{ ureq := urlfilter.DNSRequest{
Hostname: host, Hostname: host,
SortedClientTags: setts.ClientTags, SortedClientTags: setts.ClientTags,
ClientIP: setts.ClientIP, // TODO(e.burkov): Wait for urlfilter update to pass net.IP.
ClientName: setts.ClientName, ClientIP: setts.ClientIP.String(),
DNSType: qtype, ClientName: setts.ClientName,
DNSType: qtype,
} }
if d.filteringEngineAllow != nil { if d.filteringEngineAllow != nil {

View File

@ -117,19 +117,19 @@ func TestRewritesLevels(t *testing.T) {
r := d.processRewrites("host.com", dns.TypeA) r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1) assert.Len(t, r.IPList, 1)
assert.Equal(t, "1.1.1.1", r.IPList[0].String()) assert.True(t, net.IP{1, 1, 1, 1}.Equal(r.IPList[0]))
// match L2 // match L2
r = d.processRewrites("sub.host.com", dns.TypeA) r = d.processRewrites("sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1) assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String()) assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match L3 // match L3
r = d.processRewrites("my.sub.host.com", dns.TypeA) r = d.processRewrites("my.sub.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1) assert.Len(t, r.IPList, 1)
assert.Equal(t, "3.3.3.3", r.IPList[0].String()) assert.True(t, net.IP{3, 3, 3, 3}.Equal(r.IPList[0]))
} }
func TestRewritesExceptionCNAME(t *testing.T) { func TestRewritesExceptionCNAME(t *testing.T) {
@ -145,7 +145,7 @@ func TestRewritesExceptionCNAME(t *testing.T) {
r := d.processRewrites("my.host.com", dns.TypeA) r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1) assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String()) assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match sub-domain, but handle exception // match sub-domain, but handle exception
r = d.processRewrites("sub.host.com", dns.TypeA) r = d.processRewrites("sub.host.com", dns.TypeA)
@ -165,7 +165,7 @@ func TestRewritesExceptionWC(t *testing.T) {
r := d.processRewrites("my.host.com", dns.TypeA) r := d.processRewrites("my.host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1) assert.Len(t, r.IPList, 1)
assert.Equal(t, "2.2.2.2", r.IPList[0].String()) assert.True(t, net.IP{2, 2, 2, 2}.Equal(r.IPList[0]))
// match sub-domain, but handle exception // match sub-domain, but handle exception
r = d.processRewrites("my.sub.host.com", dns.TypeA) r = d.processRewrites("my.sub.host.com", dns.TypeA)
@ -188,7 +188,7 @@ func TestRewritesExceptionIP(t *testing.T) {
r := d.processRewrites("host.com", dns.TypeA) r := d.processRewrites("host.com", dns.TypeA)
assert.Equal(t, Rewritten, r.Reason) assert.Equal(t, Rewritten, r.Reason)
assert.Len(t, r.IPList, 1) assert.Len(t, r.IPList, 1)
assert.Equal(t, "1.2.3.4", r.IPList[0].String()) assert.True(t, net.IP{1, 2, 3, 4}.Equal(r.IPList[0]))
// match exception // match exception
r = d.processRewrites("host.com", dns.TypeAAAA) r = d.processRewrites("host.com", dns.TypeAAAA)

View File

@ -83,20 +83,21 @@ func processIPCIDRArray(dst *map[string]bool, dstIPNet *[]net.IPNet, src []strin
// Returns the item from the "disallowedClients" list that lead to blocking IP. // Returns the item from the "disallowedClients" list that lead to blocking IP.
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty, // If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
// but the ip does not belong to it. // but the ip does not belong to it.
func (a *accessCtx) IsBlockedIP(ip string) (bool, string) { func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
ipStr := ip.String()
a.lock.Lock() a.lock.Lock()
defer a.lock.Unlock() defer a.lock.Unlock()
if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 { if len(a.allowedClients) != 0 || len(a.allowedClientsIPNet) != 0 {
_, ok := a.allowedClients[ip] _, ok := a.allowedClients[ipStr]
if ok { if ok {
return false, "" return false, ""
} }
if len(a.allowedClientsIPNet) != 0 { if len(a.allowedClientsIPNet) != 0 {
ipAddr := net.ParseIP(ip)
for _, ipnet := range a.allowedClientsIPNet { for _, ipnet := range a.allowedClientsIPNet {
if ipnet.Contains(ipAddr) { if ipnet.Contains(ip) {
return false, "" return false, ""
} }
} }
@ -105,15 +106,14 @@ func (a *accessCtx) IsBlockedIP(ip string) (bool, string) {
return true, "" return true, ""
} }
_, ok := a.disallowedClients[ip] _, ok := a.disallowedClients[ipStr]
if ok { if ok {
return true, ip return true, ipStr
} }
if len(a.disallowedClientsIPNet) != 0 { if len(a.disallowedClientsIPNet) != 0 {
ipAddr := net.ParseIP(ip)
for _, ipnet := range a.disallowedClientsIPNet { for _, ipnet := range a.disallowedClientsIPNet {
if ipnet.Contains(ipAddr) { if ipnet.Contains(ip) {
return true, ipnet.String() return true, ipnet.String()
} }
} }

View File

@ -1,6 +1,7 @@
package dnsforward package dnsforward
import ( import (
"net"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -10,19 +11,19 @@ func TestIsBlockedIPAllowed(t *testing.T) {
a := &accessCtx{} a := &accessCtx{}
assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil)) assert.Nil(t, a.Init([]string{"1.1.1.1", "2.2.0.0/16"}, nil, nil))
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1") disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
assert.False(t, disallowed) assert.False(t, disallowed)
assert.Empty(t, disallowedRule) assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2") disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
assert.True(t, disallowed) assert.True(t, disallowed)
assert.Empty(t, disallowedRule) assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1") disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
assert.False(t, disallowed) assert.False(t, disallowed)
assert.Empty(t, disallowedRule) assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1") disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
assert.True(t, disallowed) assert.True(t, disallowed)
assert.Empty(t, disallowedRule) assert.Empty(t, disallowedRule)
} }
@ -31,19 +32,19 @@ func TestIsBlockedIPDisallowed(t *testing.T) {
a := &accessCtx{} a := &accessCtx{}
assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil)) assert.Nil(t, a.Init(nil, []string{"1.1.1.1", "2.2.0.0/16"}, nil))
disallowed, disallowedRule := a.IsBlockedIP("1.1.1.1") disallowed, disallowedRule := a.IsBlockedIP(net.IPv4(1, 1, 1, 1))
assert.True(t, disallowed) assert.True(t, disallowed)
assert.Equal(t, "1.1.1.1", disallowedRule) assert.Equal(t, "1.1.1.1", disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("1.1.1.2") disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(1, 1, 1, 2))
assert.False(t, disallowed) assert.False(t, disallowed)
assert.Empty(t, disallowedRule) assert.Empty(t, disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.2.1.1") disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 2, 1, 1))
assert.True(t, disallowed) assert.True(t, disallowed)
assert.Equal(t, "2.2.0.0/16", disallowedRule) assert.Equal(t, "2.2.0.0/16", disallowedRule)
disallowed, disallowedRule = a.IsBlockedIP("2.3.1.1") disallowed, disallowedRule = a.IsBlockedIP(net.IPv4(2, 3, 1, 1))
assert.False(t, disallowed) assert.False(t, disallowed)
assert.Empty(t, disallowedRule) assert.Empty(t, disallowedRule)
} }

View File

@ -25,11 +25,11 @@ type FilteringConfig struct {
// -- // --
// Filtering callback function // Filtering callback function
FilterHandler func(clientAddr string, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"` FilterHandler func(clientAddr net.IP, settings *dnsfilter.RequestFilteringSettings) `yaml:"-"`
// GetCustomUpstreamByClient - a callback function that returns upstreams configuration // GetCustomUpstreamByClient - a callback function that returns upstreams configuration
// based on the client IP address. Returns nil if there are no custom upstreams for the client // based on the client IP address. Returns nil if there are no custom upstreams for the client
// TODO(e.burkov): replace argument type with net.IP. // TODO(e.burkov): Replace argument type with net.IP.
GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"` GetCustomUpstreamByClient func(clientAddr string) *proxy.UpstreamConfig `yaml:"-"`
// Protection configuration // Protection configuration

View File

@ -298,6 +298,6 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// IsBlockedIP - return TRUE if this client should be blocked // IsBlockedIP - return TRUE if this client should be blocked
func (s *Server) IsBlockedIP(ip string) (bool, string) { func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
return s.access.IsBlockedIP(ip) return s.access.IsBlockedIP(ip)
} }

View File

@ -322,7 +322,7 @@ func TestServerCustomClientUpstream(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, dns.RcodeSuccess, reply.Rcode) assert.Equal(t, dns.RcodeSuccess, reply.Rcode)
assert.NotNil(t, reply.Answer) assert.NotNil(t, reply.Answer)
assert.Equal(t, "192.168.0.1", reply.Answer[0].(*dns.A).A.String()) assert.True(t, net.IP{192, 168, 0, 1}.Equal(reply.Answer[0].(*dns.A).A))
assert.Nil(t, s.Stop()) assert.Nil(t, s.Stop())
} }
@ -473,7 +473,7 @@ func TestBlockCNAME(t *testing.T) {
func TestClientRulesForCNAMEMatching(t *testing.T) { func TestClientRulesForCNAMEMatching(t *testing.T) {
s := createTestServer(t) s := createTestServer(t)
testUpstm := &testUpstream{testCNAMEs, testIPv4, nil} testUpstm := &testUpstream{testCNAMEs, testIPv4, nil}
s.conf.FilterHandler = func(_ string, settings *dnsfilter.RequestFilteringSettings) { s.conf.FilterHandler = func(_ net.IP, settings *dnsfilter.RequestFilteringSettings) {
settings.FilteringEnabled = false settings.FilteringEnabled = false
} }
err := s.startWithUpstream(testUpstm) err := s.startWithUpstream(testUpstm)
@ -568,7 +568,7 @@ func TestBlockedCustomIP(t *testing.T) {
assert.Len(t, reply.Answer, 1) assert.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "0.0.0.1", a.A.String()) assert.True(t, net.IP{0, 0, 0, 1}.Equal(a.A))
req = createTestMessageWithType("null.example.org.", dns.TypeAAAA) req = createTestMessageWithType("null.example.org.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
@ -713,7 +713,7 @@ func TestRewrite(t *testing.T) {
assert.Len(t, reply.Answer, 1) assert.Len(t, reply.Answer, 1)
a, ok := reply.Answer[0].(*dns.A) a, ok := reply.Answer[0].(*dns.A)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, "1.2.3.4", a.A.String()) assert.True(t, net.IP{1, 2, 3, 4}.Equal(a.A))
req = createTestMessageWithType("test.com.", dns.TypeAAAA) req = createTestMessageWithType("test.com.", dns.TypeAAAA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())
@ -725,7 +725,7 @@ func TestRewrite(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, reply.Answer, 2) assert.Len(t, reply.Answer, 2)
assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target) assert.Equal(t, "test.com.", reply.Answer[0].(*dns.CNAME).Target)
assert.Equal(t, "1.2.3.4", reply.Answer[1].(*dns.A).A.String()) assert.True(t, net.IP{1, 2, 3, 4}.Equal(reply.Answer[1].(*dns.A).A))
req = createTestMessageWithType("my.alias.example.org.", dns.TypeA) req = createTestMessageWithType("my.alias.example.org.", dns.TypeA)
reply, err = dns.Exchange(req, addr.String()) reply, err = dns.Exchange(req, addr.String())

View File

@ -12,7 +12,7 @@ import (
) )
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) { func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
ip := IPStringFromAddr(d.Addr) ip := IPFromAddr(d.Addr)
disallowed, _ := s.access.IsBlockedIP(ip) disallowed, _ := s.access.IsBlockedIP(ip)
if disallowed { if disallowed {
log.Tracef("Client IP %s is blocked by settings", ip) log.Tracef("Client IP %s is blocked by settings", ip)
@ -36,8 +36,7 @@ func (s *Server) getClientRequestFilteringSettings(d *proxy.DNSContext) *dnsfilt
setts := s.dnsFilter.GetConfig() setts := s.dnsFilter.GetConfig()
setts.FilteringEnabled = true setts.FilteringEnabled = true
if s.conf.FilterHandler != nil { if s.conf.FilterHandler != nil {
clientAddr := IPStringFromAddr(d.Addr) s.conf.FilterHandler(IPFromAddr(d.Addr), &setts)
s.conf.FilterHandler(clientAddr, &setts)
} }
return &setts return &setts
} }

View File

@ -36,7 +36,7 @@ func processQueryLogsAndStats(ctx *dnsContext) int {
OrigAnswer: ctx.origResp, OrigAnswer: ctx.origResp,
Result: ctx.result, Result: ctx.result,
Elapsed: elapsed, Elapsed: elapsed,
ClientIP: ipFromAddr(d.Addr), ClientIP: IPFromAddr(d.Addr),
} }
switch d.Proto { switch d.Proto {

View File

@ -8,8 +8,8 @@ import (
"github.com/AdguardTeam/golibs/utils" "github.com/AdguardTeam/golibs/utils"
) )
// ipFromAddr gets IP address from addr. // IPFromAddr gets IP address from addr.
func ipFromAddr(addr net.Addr) (ip net.IP) { func IPFromAddr(addr net.Addr) (ip net.IP) {
switch addr := addr.(type) { switch addr := addr.(type) {
case *net.UDPAddr: case *net.UDPAddr:
return addr.IP return addr.IP
@ -22,8 +22,8 @@ func ipFromAddr(addr net.Addr) (ip net.IP) {
// IPStringFromAddr extracts IP address from net.Addr. // IPStringFromAddr extracts IP address from net.Addr.
// Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone: // Note: we can't use net.SplitHostPort(a.String()) because of IPv6 zone:
// https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261 // https://github.com/AdguardTeam/AdGuardHome/internal/issues/1261
func IPStringFromAddr(addr net.Addr) (ipstr string) { func IPStringFromAddr(addr net.Addr) (ipStr string) {
if ip := ipFromAddr(addr); ip != nil { if ip := IPFromAddr(addr); ip != nil {
return ip.String() return ip.String()
} }

View File

@ -70,10 +70,12 @@ type ClientHost struct {
} }
type clientsContainer struct { type clientsContainer struct {
list map[string]*Client // name -> client list map[string]*Client // name -> client
idIndex map[string]*Client // IP -> client idIndex map[string]*Client // IP -> client
ipHost map[string]*ClientHost // IP -> Hostname // TODO(e.burkov): Think of a way to not require string conversion for
lock sync.Mutex // IP addresses.
ipHost map[string]*ClientHost // IP -> Hostname
lock sync.Mutex
allTags map[string]bool allTags map[string]bool
@ -239,7 +241,7 @@ func (clients *clientsContainer) onHostsChanged() {
} }
// Exists checks if client with this IP already exists // Exists checks if client with this IP already exists
func (clients *clientsContainer) Exists(ip string, source clientSource) bool { func (clients *clientsContainer) Exists(ip net.IP, source clientSource) bool {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -248,7 +250,7 @@ func (clients *clientsContainer) Exists(ip string, source clientSource) bool {
return true return true
} }
ch, ok := clients.ipHost[ip] ch, ok := clients.ipHost[ip.String()]
if !ok { if !ok {
return false return false
} }
@ -265,7 +267,7 @@ func stringArrayDup(a []string) []string {
} }
// Find searches for a client by IP // Find searches for a client by IP
func (clients *clientsContainer) Find(ip string) (Client, bool) { func (clients *clientsContainer) Find(ip net.IP) (Client, bool) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -287,7 +289,7 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
c, ok := clients.findByIP(ip) c, ok := clients.findByIP(net.ParseIP(ip))
if !ok { if !ok {
return nil return nil
} }
@ -307,13 +309,12 @@ func (clients *clientsContainer) FindUpstreams(ip string) *proxy.UpstreamConfig
} }
// Find searches for a client by IP (and does not lock anything) // Find searches for a client by IP (and does not lock anything)
func (clients *clientsContainer) findByIP(ip string) (Client, bool) { func (clients *clientsContainer) findByIP(ip net.IP) (Client, bool) {
ipAddr := net.ParseIP(ip) if ip == nil {
if ipAddr == nil {
return Client{}, false return Client{}, false
} }
c, ok := clients.idIndex[ip] c, ok := clients.idIndex[ip.String()]
if ok { if ok {
return *c, true return *c, true
} }
@ -324,7 +325,7 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
if err != nil { if err != nil {
continue continue
} }
if ipnet.Contains(ipAddr) { if ipnet.Contains(ip) {
return *c, true return *c, true
} }
} }
@ -333,7 +334,7 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
if clients.dhcpServer == nil { if clients.dhcpServer == nil {
return Client{}, false return Client{}, false
} }
macFound := clients.dhcpServer.FindMACbyIP(ipAddr) macFound := clients.dhcpServer.FindMACbyIP(ip)
if macFound == nil { if macFound == nil {
return Client{}, false return Client{}, false
} }
@ -353,16 +354,15 @@ func (clients *clientsContainer) findByIP(ip string) (Client, bool) {
} }
// FindAutoClient - search for an auto-client by IP // FindAutoClient - search for an auto-client by IP
func (clients *clientsContainer) FindAutoClient(ip string) (ClientHost, bool) { func (clients *clientsContainer) FindAutoClient(ip net.IP) (ClientHost, bool) {
ipAddr := net.ParseIP(ip) if ip == nil {
if ipAddr == nil {
return ClientHost{}, false return ClientHost{}, false
} }
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
ch, ok := clients.ipHost[ip] ch, ok := clients.ipHost[ip.String()]
if ok { if ok {
return *ch, true return *ch, true
} }
@ -539,7 +539,7 @@ func (clients *clientsContainer) Update(name string, c Client) error {
} }
// SetWhoisInfo - associate WHOIS information with a client // SetWhoisInfo - associate WHOIS information with a client
func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) { func (clients *clientsContainer) SetWhoisInfo(ip net.IP, info [][]string) {
clients.lock.Lock() clients.lock.Lock()
defer clients.lock.Unlock() defer clients.lock.Unlock()
@ -549,7 +549,8 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
return return
} }
ch, ok := clients.ipHost[ip] ipStr := ip.String()
ch, ok := clients.ipHost[ipStr]
if ok { if ok {
ch.WhoisInfo = info ch.WhoisInfo = info
log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo) log.Debug("Clients: set WHOIS info for auto-client %s: %v", ch.Host, ch.WhoisInfo)
@ -561,7 +562,7 @@ func (clients *clientsContainer) SetWhoisInfo(ip string, info [][]string) {
Source: ClientSourceWHOIS, Source: ClientSourceWHOIS,
} }
ch.WhoisInfo = info ch.WhoisInfo = info
clients.ipHost[ip] = ch clients.ipHost[ipStr] = ch
log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo) log.Debug("Clients: set WHOIS info for auto-client with IP %s: %v", ip, ch.WhoisInfo)
} }

View File

@ -36,21 +36,21 @@ func TestClients(t *testing.T) {
assert.True(t, b) assert.True(t, b)
assert.Nil(t, err) assert.Nil(t, err)
c, b = clients.Find("1.1.1.1") c, b = clients.Find(net.IPv4(1, 1, 1, 1))
assert.True(t, b) assert.True(t, b)
assert.Equal(t, c.Name, "client1") assert.Equal(t, c.Name, "client1")
c, b = clients.Find("1:2:3::4") c, b = clients.Find(net.ParseIP("1:2:3::4"))
assert.True(t, b) assert.True(t, b)
assert.Equal(t, c.Name, "client1") assert.Equal(t, c.Name, "client1")
c, b = clients.Find("2.2.2.2") c, b = clients.Find(net.IPv4(2, 2, 2, 2))
assert.True(t, b) assert.True(t, b)
assert.Equal(t, c.Name, "client2") assert.Equal(t, c.Name, "client2")
assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile)) assert.False(t, clients.Exists(net.IPv4(1, 2, 3, 4), ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile)) assert.True(t, clients.Exists(net.IPv4(2, 2, 2, 2), ClientSourceHostsFile))
}) })
t.Run("add_fail_name", func(t *testing.T) { t.Run("add_fail_name", func(t *testing.T) {
@ -112,8 +112,8 @@ func TestClients(t *testing.T) {
err := clients.Update("client1", c) err := clients.Update("client1", c)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile))
c = Client{ c = Client{
IDs: []string{"1.1.1.2"}, IDs: []string{"1.1.1.2"},
@ -124,7 +124,7 @@ func TestClients(t *testing.T) {
err = clients.Update("client1", c) err = clients.Update("client1", c)
assert.Nil(t, err) assert.Nil(t, err)
c, b := clients.Find("1.1.1.2") c, b := clients.Find(net.IPv4(1, 1, 1, 2))
assert.True(t, b) assert.True(t, b)
assert.Equal(t, "client1-renamed", c.Name) assert.Equal(t, "client1-renamed", c.Name)
assert.Equal(t, "1.1.1.2", c.IDs[0]) assert.Equal(t, "1.1.1.2", c.IDs[0])
@ -135,7 +135,7 @@ func TestClients(t *testing.T) {
t.Run("del_success", func(t *testing.T) { t.Run("del_success", func(t *testing.T) {
b := clients.Del("client1-renamed") b := clients.Del("client1-renamed")
assert.True(t, b) assert.True(t, b)
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile)) assert.False(t, clients.Exists(net.IPv4(1, 1, 1, 2), ClientSourceHostsFile))
}) })
t.Run("del_fail", func(t *testing.T) { t.Run("del_fail", func(t *testing.T) {
@ -156,7 +156,7 @@ func TestClients(t *testing.T) {
assert.True(t, b) assert.True(t, b)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile)) assert.True(t, clients.Exists(net.IPv4(1, 1, 1, 1), ClientSourceHostsFile))
}) })
t.Run("addhost_fail", func(t *testing.T) { t.Run("addhost_fail", func(t *testing.T) {
@ -174,12 +174,12 @@ func TestClientsWhois(t *testing.T) {
whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}} whois := [][]string{{"orgname", "orgname-val"}, {"country", "country-val"}}
// set whois info on new client // set whois info on new client
clients.SetWhoisInfo("1.1.1.255", whois) clients.SetWhoisInfo(net.IPv4(1, 1, 1, 255), whois)
assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1]) assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.255"].WhoisInfo[0][1])
// set whois info on existing auto-client // set whois info on existing auto-client
_, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS) _, _ = clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
clients.SetWhoisInfo("1.1.1.1", whois) clients.SetWhoisInfo(net.IPv4(1, 1, 1, 1), whois)
assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1]) assert.Equal(t, "orgname-val", clients.ipHost["1.1.1.1"].WhoisInfo[0][1])
// Check that we cannot set whois info on a manually-added client // Check that we cannot set whois info on a manually-added client
@ -188,7 +188,7 @@ func TestClientsWhois(t *testing.T) {
Name: "client1", Name: "client1",
} }
_, _ = clients.Add(c) _, _ = clients.Add(c)
clients.SetWhoisInfo("1.1.1.2", whois) clients.SetWhoisInfo(net.IPv4(1, 1, 1, 2), whois)
assert.Nil(t, clients.ipHost["1.1.1.2"]) assert.Nil(t, clients.ipHost["1.1.1.2"])
_ = clients.Del("client1") _ = clients.Del("client1")
} }

View File

@ -3,6 +3,7 @@ package home
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"net/http" "net/http"
) )
@ -229,8 +230,9 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
q := r.URL.Query() q := r.URL.Query()
data := []map[string]interface{}{} data := []map[string]interface{}{}
for i := 0; ; i++ { for i := 0; ; i++ {
ip := q.Get(fmt.Sprintf("ip%d", i)) ipStr := q.Get(fmt.Sprintf("ip%d", i))
if len(ip) == 0 { ip := net.ParseIP(ipStr)
if ip == nil {
break break
} }
@ -248,7 +250,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip) cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
} }
el[ip] = cj el[ipStr] = cj
data = append(data, el) data = append(data, el)
} }
@ -267,7 +269,8 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
// findTemporary looks up the IP in temporary storages, like autohosts or // findTemporary looks up the IP in temporary storages, like autohosts or
// blocklists. // blocklists.
func (clients *clientsContainer) findTemporary(ip string) (cj clientJSON, found bool) { func (clients *clientsContainer) findTemporary(ip net.IP) (cj clientJSON, found bool) {
ipStr := ip.String()
ch, ok := clients.FindAutoClient(ip) ch, ok := clients.FindAutoClient(ip)
if !ok { if !ok {
// It is still possible that the IP used to be in the runtime // It is still possible that the IP used to be in the runtime
@ -281,7 +284,7 @@ func (clients *clientsContainer) findTemporary(ip string) (cj clientJSON, found
} }
cj = clientJSON{ cj = clientJSON{
IDs: []string{ip}, IDs: []string{ipStr},
Disallowed: disallowed, Disallowed: disallowed,
DisallowedRule: rule, DisallowedRule: rule,
} }
@ -289,7 +292,7 @@ func (clients *clientsContainer) findTemporary(ip string) (cj clientJSON, found
return cj, true return cj, true
} }
cj = clientHostToJSON(ip, ch) cj = clientHostToJSON(ipStr, ch)
cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip) cj.Disallowed, cj.DisallowedRule = clients.dnsServer.IsBlockedIP(ip)
return cj, true return cj, true

View File

@ -2,6 +2,7 @@ package home
import ( import (
"io/ioutil" "io/ioutil"
"net"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
@ -40,7 +41,7 @@ type configuration struct {
// It's reset after config is parsed // It's reset after config is parsed
fileData []byte fileData []byte
BindHost string `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to BindHost net.IP `yaml:"bind_host"` // BindHost is the IP address of the HTTP server to bind to
BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server BindPort int `yaml:"bind_port"` // BindPort is the port the HTTP server
BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client BetaBindPort int `yaml:"beta_bind_port"` // BetaBindPort is the port for new client
Users []User `yaml:"users"` // Users that can access HTTP server Users []User `yaml:"users"` // Users that can access HTTP server
@ -74,7 +75,7 @@ type configuration struct {
// field ordering is important -- yaml fields will mirror ordering from here // field ordering is important -- yaml fields will mirror ordering from here
type dnsConfig struct { type dnsConfig struct {
BindHost string `yaml:"bind_host"` BindHost net.IP `yaml:"bind_host"`
Port int `yaml:"port"` Port int `yaml:"port"`
// time interval for statistics (in days) // time interval for statistics (in days)
@ -121,9 +122,9 @@ type tlsConfigSettings struct {
var config = configuration{ var config = configuration{
BindPort: 3000, BindPort: 3000,
BetaBindPort: 0, BetaBindPort: 0,
BindHost: "0.0.0.0", BindHost: net.IP{0, 0, 0, 0},
DNS: dnsConfig{ DNS: dnsConfig{
BindHost: "0.0.0.0", BindHost: net.IP{0, 0, 0, 0},
Port: 53, Port: 53,
StatsInterval: 1, StatsInterval: 1,
FilteringConfig: dnsforward.FilteringConfig{ FilteringConfig: dnsforward.FilteringConfig{

View File

@ -36,11 +36,12 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
// --------------- // ---------------
// dns run control // dns run control
// --------------- // ---------------
func addDNSAddress(dnsAddresses *[]string, addr string) { func addDNSAddress(dnsAddresses *[]string, addr net.IP) {
hostport := addr.String()
if config.DNS.Port != 53 { if config.DNS.Port != 53 {
addr = fmt.Sprintf("%s:%d", addr, config.DNS.Port) hostport = net.JoinHostPort(hostport, strconv.Itoa(config.DNS.Port))
} }
*dnsAddresses = append(*dnsAddresses, addr) *dnsAddresses = append(*dnsAddresses, hostport)
} }
func handleStatus(w http.ResponseWriter, _ *http.Request) { func handleStatus(w http.ResponseWriter, _ *http.Request) {

View File

@ -31,7 +31,7 @@ type netInterfaceJSON struct {
Name string `json:"name"` Name string `json:"name"`
MTU int `json:"mtu"` MTU int `json:"mtu"`
HardwareAddr string `json:"hardware_address"` HardwareAddr string `json:"hardware_address"`
Addresses []string `json:"ip_addresses"` Addresses []net.IP `json:"ip_addresses"`
Flags string `json:"flags"` Flags string `json:"flags"`
} }
@ -69,7 +69,7 @@ func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request
type checkConfigReqEnt struct { type checkConfigReqEnt struct {
Port int `json:"port"` Port int `json:"port"`
IP string `json:"ip"` IP net.IP `json:"ip"`
Autofix bool `json:"autofix"` Autofix bool `json:"autofix"`
} }
@ -138,7 +138,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
if err != nil { if err != nil {
respData.DNS.Status = err.Error() respData.DNS.Status = err.Error()
} else if reqData.DNS.IP != "0.0.0.0" { } else if !reqData.DNS.IP.IsUnspecified() {
respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP) respData.StaticIP = handleStaticIP(reqData.DNS.IP, reqData.SetStaticIP)
} }
} }
@ -154,7 +154,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request)
// handleStaticIP - handles static IP request // handleStaticIP - handles static IP request
// It either checks if we have a static IP // It either checks if we have a static IP
// Or if set=true, it tries to set it // Or if set=true, it tries to set it
func handleStaticIP(ip string, set bool) staticIPJSON { func handleStaticIP(ip net.IP, set bool) staticIPJSON {
resp := staticIPJSON{} resp := staticIPJSON{}
interfaceName := util.GetInterfaceByIP(ip) interfaceName := util.GetInterfaceByIP(ip)
@ -186,7 +186,7 @@ func handleStaticIP(ip string, set bool) staticIPJSON {
if isStaticIP { if isStaticIP {
resp.Static = "yes" resp.Static = "yes"
} }
resp.IP = util.GetSubnet(interfaceName) resp.IP = util.GetSubnet(interfaceName).String()
} }
return resp return resp
} }
@ -262,7 +262,7 @@ func disableDNSStubListener() error {
} }
type applyConfigReqEnt struct { type applyConfigReqEnt struct {
IP string `json:"ip"` IP net.IP `json:"ip"`
Port int `json:"port"` Port int `json:"port"`
} }
@ -297,7 +297,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
} }
restartHTTP := true restartHTTP := true
if config.BindHost == newSettings.Web.IP && config.BindPort == newSettings.Web.Port { if config.BindHost.Equal(newSettings.Web.IP) && config.BindPort == newSettings.Web.Port {
// no need to rebind // no need to rebind
restartHTTP = false restartHTTP = false
} }
@ -307,7 +307,7 @@ func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port) err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
if err != nil { if err != nil {
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s", httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err) net.JoinHostPort(newSettings.Web.IP.String(), strconv.Itoa(newSettings.Web.Port)), err)
return return
} }
@ -388,18 +388,18 @@ func (web *Web) registerInstallHandlers() {
// checkConfigReqEntBeta is a struct representing new client's config check // checkConfigReqEntBeta is a struct representing new client's config check
// request entry. It supports multiple IP values unlike the checkConfigReqEnt. // request entry. It supports multiple IP values unlike the checkConfigReqEnt.
// //
// TODO(e.burkov): this should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default checkConfigReqEnt. // functionality will appear in default checkConfigReqEnt.
type checkConfigReqEntBeta struct { type checkConfigReqEntBeta struct {
Port int `json:"port"` Port int `json:"port"`
IP []string `json:"ip"` IP []net.IP `json:"ip"`
Autofix bool `json:"autofix"` Autofix bool `json:"autofix"`
} }
// checkConfigReqBeta is a struct representing new client's config check request // checkConfigReqBeta is a struct representing new client's config check request
// body. It uses checkConfigReqEntBeta instead of checkConfigReqEnt. // body. It uses checkConfigReqEntBeta instead of checkConfigReqEnt.
// //
// TODO(e.burkov): this should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default checkConfigReq. // functionality will appear in default checkConfigReq.
type checkConfigReqBeta struct { type checkConfigReqBeta struct {
Web checkConfigReqEntBeta `json:"web"` Web checkConfigReqEntBeta `json:"web"`
@ -410,7 +410,7 @@ type checkConfigReqBeta struct {
// handleInstallCheckConfigBeta is a substitution of /install/check_config // handleInstallCheckConfigBeta is a substitution of /install/check_config
// handler for new client. // handler for new client.
// //
// TODO(e.burkov): this should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallCheckConfig. // functionality will appear in default handleInstallCheckConfig.
func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Request) { func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Request) {
reqData := checkConfigReqBeta{} reqData := checkConfigReqBeta{}
@ -456,17 +456,17 @@ func (web *Web) handleInstallCheckConfigBeta(w http.ResponseWriter, r *http.Requ
// applyConfigReqEntBeta is a struct representing new client's config setting // applyConfigReqEntBeta is a struct representing new client's config setting
// request entry. It supports multiple IP values unlike the applyConfigReqEnt. // request entry. It supports multiple IP values unlike the applyConfigReqEnt.
// //
// TODO(e.burkov): this should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default applyConfigReqEnt. // functionality will appear in default applyConfigReqEnt.
type applyConfigReqEntBeta struct { type applyConfigReqEntBeta struct {
IP []string `json:"ip"` IP []net.IP `json:"ip"`
Port int `json:"port"` Port int `json:"port"`
} }
// applyConfigReqBeta is a struct representing new client's config setting // applyConfigReqBeta is a struct representing new client's config setting
// request body. It uses applyConfigReqEntBeta instead of applyConfigReqEnt. // request body. It uses applyConfigReqEntBeta instead of applyConfigReqEnt.
// //
// TODO(e.burkov): this should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default applyConfigReq. // functionality will appear in default applyConfigReq.
type applyConfigReqBeta struct { type applyConfigReqBeta struct {
Web applyConfigReqEntBeta `json:"web"` Web applyConfigReqEntBeta `json:"web"`
@ -478,7 +478,7 @@ type applyConfigReqBeta struct {
// handleInstallConfigureBeta is a substitution of /install/configure handler // handleInstallConfigureBeta is a substitution of /install/configure handler
// for new client. // for new client.
// //
// TODO(e.burkov): this should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallConfigure. // functionality will appear in default handleInstallConfigure.
func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Request) { func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Request) {
reqData := applyConfigReqBeta{} reqData := applyConfigReqBeta{}
@ -523,7 +523,7 @@ func (web *Web) handleInstallConfigureBeta(w http.ResponseWriter, r *http.Reques
// firstRunDataBeta is a struct representing new client's getting addresses // firstRunDataBeta is a struct representing new client's getting addresses
// request body. It uses array of structs instead of map. // request body. It uses array of structs instead of map.
// //
// TODO(e.burkov): this should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default firstRunData. // functionality will appear in default firstRunData.
type firstRunDataBeta struct { type firstRunDataBeta struct {
WebPort int `json:"web_port"` WebPort int `json:"web_port"`
@ -534,7 +534,7 @@ type firstRunDataBeta struct {
// handleInstallConfigureBeta is a substitution of /install/get_addresses // handleInstallConfigureBeta is a substitution of /install/get_addresses
// handler for new client. // handler for new client.
// //
// TODO(e.burkov): this should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handleInstallGetAddresses. // functionality will appear in default handleInstallGetAddresses.
func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Request) { func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Request) {
data := firstRunDataBeta{} data := firstRunDataBeta{}
@ -570,7 +570,7 @@ func (web *Web) handleInstallGetAddressesBeta(w http.ResponseWriter, r *http.Req
// registerBetaInstallHandlers registers the install handlers for new client // registerBetaInstallHandlers registers the install handlers for new client
// with the structures it supports. // with the structures it supports.
// //
// TODO(e.burkov): this should removed with the API v1 when the appropriate // TODO(e.burkov): This should removed with the API v1 when the appropriate
// functionality will appear in default handlers. // functionality will appear in default handlers.
func (web *Web) registerBetaInstallHandlers() { func (web *Web) registerBetaInstallHandlers() {
Context.mux.HandleFunc("/control/install/get_addresses_beta", preInstall(ensureGET(web.handleInstallGetAddressesBeta))) Context.mux.HandleFunc("/control/install/get_addresses_beta", preInstall(ensureGET(web.handleInstallGetAddressesBeta)))

View File

@ -55,8 +55,8 @@ func initDNSServer() error {
filterConf := config.DNS.DnsfilterConf filterConf := config.DNS.DnsfilterConf
bindhost := config.DNS.BindHost bindhost := config.DNS.BindHost
if config.DNS.BindHost == "0.0.0.0" { if config.DNS.BindHost.IsUnspecified() {
bindhost = "127.0.0.1" bindhost = net.IPv4(127, 0, 0, 1)
} }
filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) filterConf.ResolverAddress = fmt.Sprintf("%s:%d", bindhost, config.DNS.Port)
filterConf.AutoHosts = &Context.autoHosts filterConf.AutoHosts = &Context.autoHosts
@ -98,26 +98,24 @@ func isRunning() bool {
} }
func onDNSRequest(d *proxy.DNSContext) { func onDNSRequest(d *proxy.DNSContext) {
ip := dnsforward.IPStringFromAddr(d.Addr) ip := dnsforward.IPFromAddr(d.Addr)
if ip == "" { if ip == nil {
// This would be quite weird if we get here // This would be quite weird if we get here
return return
} }
ipAddr := net.ParseIP(ip) if !ip.IsLoopback() {
if !ipAddr.IsLoopback() {
Context.rdns.Begin(ip) Context.rdns.Begin(ip)
} }
if !Context.ipDetector.detectSpecialNetwork(ipAddr) { if !Context.ipDetector.detectSpecialNetwork(ip) {
Context.whois.Begin(ip) Context.whois.Begin(ip)
} }
} }
func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) { func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) {
bindHost := net.ParseIP(config.DNS.BindHost)
newconfig = dnsforward.ServerConfig{ newconfig = dnsforward.ServerConfig{
UDPListenAddr: &net.UDPAddr{IP: bindHost, Port: config.DNS.Port}, UDPListenAddr: &net.UDPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port},
TCPListenAddr: &net.TCPAddr{IP: bindHost, Port: config.DNS.Port}, TCPListenAddr: &net.TCPAddr{IP: config.DNS.BindHost, Port: config.DNS.Port},
FilteringConfig: config.DNS.FilteringConfig, FilteringConfig: config.DNS.FilteringConfig,
ConfigModified: onConfigModified, ConfigModified: onConfigModified,
HTTPRegister: httpRegister, HTTPRegister: httpRegister,
@ -131,20 +129,20 @@ func generateServerConfig() (newconfig dnsforward.ServerConfig, err error) {
if tlsConf.PortDNSOverTLS != 0 { if tlsConf.PortDNSOverTLS != 0 {
newconfig.TLSListenAddr = &net.TCPAddr{ newconfig.TLSListenAddr = &net.TCPAddr{
IP: bindHost, IP: config.DNS.BindHost,
Port: tlsConf.PortDNSOverTLS, Port: tlsConf.PortDNSOverTLS,
} }
} }
if tlsConf.PortDNSOverQUIC != 0 { if tlsConf.PortDNSOverQUIC != 0 {
newconfig.QUICListenAddr = &net.UDPAddr{ newconfig.QUICListenAddr = &net.UDPAddr{
IP: bindHost, IP: config.DNS.BindHost,
Port: int(tlsConf.PortDNSOverQUIC), Port: int(tlsConf.PortDNSOverQUIC),
} }
} }
if tlsConf.PortDNSCrypt != 0 { if tlsConf.PortDNSCrypt != 0 {
newconfig.DNSCryptConfig, err = newDNSCrypt(bindHost, tlsConf) newconfig.DNSCryptConfig, err = newDNSCrypt(config.DNS.BindHost, tlsConf)
if err != nil { if err != nil {
// Don't wrap the error, because it's already // Don't wrap the error, because it's already
// wrapped by newDNSCrypt. // wrapped by newDNSCrypt.
@ -245,7 +243,7 @@ func getDNSEncryption() dnsEncryption {
func getDNSAddresses() []string { func getDNSAddresses() []string {
dnsAddresses := []string{} dnsAddresses := []string{}
if config.DNS.BindHost == "0.0.0.0" { if config.DNS.BindHost.IsUnspecified() {
ifaces, e := util.GetValidNetInterfacesForWeb() ifaces, e := util.GetValidNetInterfacesForWeb()
if e != nil { if e != nil {
log.Error("Couldn't get network interfaces: %v", e) log.Error("Couldn't get network interfaces: %v", e)
@ -276,10 +274,10 @@ func getDNSAddresses() []string {
} }
// If a client has his own settings, apply them // If a client has his own settings, apply them
func applyAdditionalFiltering(clientAddr string, setts *dnsfilter.RequestFilteringSettings) { func applyAdditionalFiltering(clientAddr net.IP, setts *dnsfilter.RequestFilteringSettings) {
Context.dnsFilter.ApplyBlockedServices(setts, nil, true) Context.dnsFilter.ApplyBlockedServices(setts, nil, true)
if len(clientAddr) == 0 { if clientAddr == nil {
return return
} }
setts.ClientIP = clientAddr setts.ClientIP = clientAddr
@ -328,13 +326,11 @@ func startDNSServer() error {
Context.queryLog.Start() Context.queryLog.Start()
const topClientsNumber = 100 // the number of clients to get const topClientsNumber = 100 // the number of clients to get
topClients := Context.stats.GetTopClientsIP(topClientsNumber) for _, ip := range Context.stats.GetTopClientsIP(topClientsNumber) {
for _, ip := range topClients { if !ip.IsLoopback() {
ipAddr := net.ParseIP(ip)
if !ipAddr.IsLoopback() {
Context.rdns.Begin(ip) Context.rdns.Begin(ip)
} }
if !Context.ipDetector.detectSpecialNetwork(ipAddr) { if !Context.ipDetector.detectSpecialNetwork(ip) {
Context.whois.Begin(ip) Context.whois.Begin(ip)
} }
} }

View File

@ -206,7 +206,7 @@ func setupConfig(args options) {
} }
// override bind host/port from the console // override bind host/port from the console
if args.bindHost != "" { if args.bindHost != nil {
config.BindHost = args.bindHost config.BindHost = args.bindHost
} }
if args.bindPort != 0 { if args.bindPort != 0 {
@ -575,36 +575,40 @@ func printHTTPAddresses(proto string) {
port = strconv.Itoa(tlsConf.PortHTTPS) port = strconv.Itoa(tlsConf.PortHTTPS)
} }
var hostStr string
if proto == "https" && tlsConf.ServerName != "" { if proto == "https" && tlsConf.ServerName != "" {
if tlsConf.PortHTTPS == 443 { if tlsConf.PortHTTPS == 443 {
log.Printf("Go to https://%s", tlsConf.ServerName) log.Printf("Go to https://%s", tlsConf.ServerName)
} else { } else {
log.Printf("Go to https://%s:%s", tlsConf.ServerName, port) log.Printf("Go to https://%s:%s", tlsConf.ServerName, port)
} }
} else if config.BindHost == "0.0.0.0" { } else if config.BindHost.IsUnspecified() {
log.Println("AdGuard Home is available on the following addresses:") log.Println("AdGuard Home is available on the following addresses:")
ifaces, err := util.GetValidNetInterfacesForWeb() ifaces, err := util.GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
// That's weird, but we'll ignore it // That's weird, but we'll ignore it
log.Printf("Go to %s://%s", proto, net.JoinHostPort(config.BindHost, port)) hostStr = config.BindHost.String()
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, port))
if config.BetaBindPort != 0 { if config.BetaBindPort != 0 {
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(config.BindHost, strconv.Itoa(config.BetaBindPort))) log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
} }
return return
} }
for _, iface := range ifaces { for _, iface := range ifaces {
for _, addr := range iface.Addresses { for _, addr := range iface.Addresses {
log.Printf("Go to %s://%s", proto, net.JoinHostPort(addr, strconv.Itoa(config.BindPort))) hostStr = addr.String()
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BindPort)))
if config.BetaBindPort != 0 { if config.BetaBindPort != 0 {
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(addr, strconv.Itoa(config.BetaBindPort))) log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
} }
} }
} }
} else { } else {
log.Printf("Go to %s://%s", proto, net.JoinHostPort(config.BindHost, port)) hostStr = config.BindHost.String()
log.Printf("Go to %s://%s", proto, net.JoinHostPort(hostStr, port))
if config.BetaBindPort != 0 { if config.BetaBindPort != 0 {
log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(config.BindHost, strconv.Itoa(config.BetaBindPort))) log.Printf("Go to %s://%s (BETA)", proto, net.JoinHostPort(hostStr, strconv.Itoa(config.BetaBindPort)))
} }
} }
} }

View File

@ -1,6 +1,6 @@
// +build !race // +build !race
// TODO(e.burkov): remove this weird buildtag. // TODO(e.burkov): Remove this weird buildtag.
package home package home

View File

@ -2,6 +2,7 @@ package home
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"strconv" "strconv"
@ -13,7 +14,7 @@ type options struct {
verbose bool // is verbose logging enabled verbose bool // is verbose logging enabled
configFilename string // path to the config file configFilename string // path to the config file
workDir string // path to the working directory where we will store the filters data and the querylog workDir string // path to the working directory where we will store the filters data and the querylog
bindHost string // host address to bind HTTP server on bindHost net.IP // host address to bind HTTP server on
bindPort int // port to serve HTTP pages on bindPort int // port to serve HTTP pages on
logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog logFile string // Path to the log file. If empty, write to stdout. If "syslog", writes to syslog
pidFile string // File name to save PID to pidFile string // File name to save PID to
@ -54,10 +55,19 @@ type arg struct {
// against its zero value and return nil if the parameter value is // against its zero value and return nil if the parameter value is
// zero otherwise they return a string slice of the parameter // zero otherwise they return a string slice of the parameter
func ipSliceOrNil(ip net.IP) []string {
if ip == nil {
return nil
}
return []string{ip.String()}
}
func stringSliceOrNil(s string) []string { func stringSliceOrNil(s string) []string {
if s == "" { if s == "" {
return nil return nil
} }
return []string{s} return []string{s}
} }
@ -65,6 +75,7 @@ func intSliceOrNil(i int) []string {
if i == 0 { if i == 0 {
return nil return nil
} }
return []string{strconv.Itoa(i)} return []string{strconv.Itoa(i)}
} }
@ -72,6 +83,7 @@ func boolSliceOrNil(b bool) []string {
if b { if b {
return []string{} return []string{}
} }
return nil return nil
} }
@ -96,8 +108,8 @@ var workDirArg = arg{
var hostArg = arg{ var hostArg = arg{
"Host address to bind HTTP server on", "Host address to bind HTTP server on",
"host", "h", "host", "h",
func(o options, v string) (options, error) { o.bindHost = v; return o, nil }, nil, nil, func(o options, v string) (options, error) { o.bindHost = net.ParseIP(v); return o, nil }, nil, nil,
func(o options) []string { return stringSliceOrNil(o.bindHost) }, func(o options) []string { return ipSliceOrNil(o.bindHost) },
} }
var portArg = arg{ var portArg = arg{

View File

@ -2,6 +2,7 @@ package home
import ( import (
"fmt" "fmt"
"net"
"testing" "testing"
) )
@ -65,14 +66,14 @@ func TestParseWorkDir(t *testing.T) {
} }
func TestParseBindHost(t *testing.T) { func TestParseBindHost(t *testing.T) {
if testParseOk(t).bindHost != "" { if testParseOk(t).bindHost != nil {
t.Fatal("empty is no host") t.Fatal("empty is no host")
} }
if testParseOk(t, "-h", "addr").bindHost != "addr" { if !testParseOk(t, "-h", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
t.Fatal("-h is host") t.Fatal("-h is host")
} }
testParseParamMissing(t, "-h") testParseParamMissing(t, "-h")
if testParseOk(t, "--host", "addr").bindHost != "addr" { if !testParseOk(t, "--host", "1.2.3.4").bindHost.Equal(net.IP{1, 2, 3, 4}) {
t.Fatal("--host is host") t.Fatal("--host is host")
} }
testParseParamMissing(t, "--host") testParseParamMissing(t, "--host")
@ -204,7 +205,7 @@ func TestSerializeWorkDir(t *testing.T) {
} }
func TestSerializeBindHost(t *testing.T) { func TestSerializeBindHost(t *testing.T) {
testSerialize(t, options{bindHost: "addr"}, "-h", "addr") testSerialize(t, options{bindHost: net.IP{1, 2, 3, 4}}, "-h", "1.2.3.4")
} }
func TestSerializeBindPort(t *testing.T) { func TestSerializeBindPort(t *testing.T) {

View File

@ -2,6 +2,7 @@ package home
import ( import (
"encoding/binary" "encoding/binary"
"net"
"strings" "strings"
"time" "time"
@ -15,7 +16,7 @@ import (
type RDNS struct { type RDNS struct {
dnsServer *dnsforward.Server dnsServer *dnsforward.Server
clients *clientsContainer clients *clientsContainer
ipChannel chan string // pass data from DNS request handling thread to rDNS thread ipChannel chan net.IP // pass data from DNS request handling thread to rDNS thread
// Contains IP addresses of clients to be resolved by rDNS // Contains IP addresses of clients to be resolved by rDNS
// If IP address is resolved, it stays here while it's inside Clients. // If IP address is resolved, it stays here while it's inside Clients.
@ -35,15 +36,15 @@ func InitRDNS(dnsServer *dnsforward.Server, clients *clientsContainer) *RDNS {
cconf.MaxCount = 10000 cconf.MaxCount = 10000
r.ipAddrs = cache.New(cconf) r.ipAddrs = cache.New(cconf)
r.ipChannel = make(chan string, 256) r.ipChannel = make(chan net.IP, 256)
go r.workerLoop() go r.workerLoop()
return &r return &r
} }
// Begin - add IP address to rDNS queue // Begin - add IP address to rDNS queue
func (r *RDNS) Begin(ip string) { func (r *RDNS) Begin(ip net.IP) {
now := uint64(time.Now().Unix()) now := uint64(time.Now().Unix())
expire := r.ipAddrs.Get([]byte(ip)) expire := r.ipAddrs.Get(ip)
if len(expire) != 0 { if len(expire) != 0 {
exp := binary.BigEndian.Uint64(expire) exp := binary.BigEndian.Uint64(expire)
if exp > now { if exp > now {
@ -54,7 +55,7 @@ func (r *RDNS) Begin(ip string) {
expire = make([]byte, 8) expire = make([]byte, 8)
const ttl = 1 * 60 * 60 const ttl = 1 * 60 * 60
binary.BigEndian.PutUint64(expire, now+ttl) binary.BigEndian.PutUint64(expire, now+ttl)
_ = r.ipAddrs.Set([]byte(ip), expire) _ = r.ipAddrs.Set(ip, expire)
if r.clients.Exists(ip, ClientSourceRDNS) { if r.clients.Exists(ip, ClientSourceRDNS) {
return return
@ -70,7 +71,7 @@ func (r *RDNS) Begin(ip string) {
} }
// Use rDNS to get hostname by IP address // Use rDNS to get hostname by IP address
func (r *RDNS) resolve(ip string) string { func (r *RDNS) resolve(ip net.IP) string {
log.Tracef("Resolving host for %s", ip) log.Tracef("Resolving host for %s", ip)
req := dns.Msg{} req := dns.Msg{}
@ -83,7 +84,7 @@ func (r *RDNS) resolve(ip string) string {
}, },
} }
var err error var err error
req.Question[0].Name, err = dns.ReverseAddr(ip) req.Question[0].Name, err = dns.ReverseAddr(ip.String())
if err != nil { if err != nil {
log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err) log.Debug("Error while calling dns.ReverseAddr(%s): %s", ip, err)
return "" return ""
@ -123,6 +124,6 @@ func (r *RDNS) workerLoop() {
continue continue
} }
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS) _, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS)
} }
} }

View File

@ -1,6 +1,7 @@
package home package home
import ( import (
"net"
"testing" "testing"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
@ -16,6 +17,6 @@ func TestResolveRDNS(t *testing.T) {
clients := &clientsContainer{} clients := &clientsContainer{}
rdns := InitRDNS(dns, clients) rdns := InitRDNS(dns, clients)
r := rdns.resolve("1.1.1.1") r := rdns.resolve(net.IP{1, 1, 1, 1})
assert.Equal(t, "one.one.one.one", r, r) assert.Equal(t, "one.one.one.one", r, r)
} }

View File

@ -31,7 +31,7 @@ const (
type webConfig struct { type webConfig struct {
firstRun bool firstRun bool
BindHost string BindHost net.IP
BindPort int BindPort int
BetaBindPort int BetaBindPort int
PortHTTPS int PortHTTPS int
@ -161,10 +161,11 @@ func (web *Web) Start() {
printHTTPAddresses("http") printHTTPAddresses("http")
errs := make(chan error, 2) errs := make(chan error, 2)
hostStr := web.conf.BindHost.String()
// we need to have new instance, because after Shutdown() the Server is not usable // we need to have new instance, because after Shutdown() the Server is not usable
web.httpServer = &http.Server{ web.httpServer = &http.Server{
ErrorLog: log.StdLog("web: http", log.DEBUG), ErrorLog: log.StdLog("web: http", log.DEBUG),
Addr: net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BindPort)), Addr: net.JoinHostPort(hostStr, strconv.Itoa(web.conf.BindPort)),
Handler: withMiddlewares(Context.mux, limitRequestBody), Handler: withMiddlewares(Context.mux, limitRequestBody),
ReadTimeout: web.conf.ReadTimeout, ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
@ -177,7 +178,7 @@ func (web *Web) Start() {
if web.conf.BetaBindPort != 0 { if web.conf.BetaBindPort != 0 {
web.httpServerBeta = &http.Server{ web.httpServerBeta = &http.Server{
ErrorLog: log.StdLog("web: http", log.DEBUG), ErrorLog: log.StdLog("web: http", log.DEBUG),
Addr: net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.BetaBindPort)), Addr: net.JoinHostPort(hostStr, strconv.Itoa(web.conf.BetaBindPort)),
Handler: withMiddlewares(Context.mux, limitRequestBody, web.wrapIndexBeta), Handler: withMiddlewares(Context.mux, limitRequestBody, web.wrapIndexBeta),
ReadTimeout: web.conf.ReadTimeout, ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
@ -236,7 +237,7 @@ func (web *Web) tlsServerLoop() {
web.httpsServer.cond.L.Unlock() web.httpsServer.cond.L.Unlock()
// prepare HTTPS server // prepare HTTPS server
address := net.JoinHostPort(web.conf.BindHost, strconv.Itoa(web.conf.PortHTTPS)) address := net.JoinHostPort(web.conf.BindHost.String(), strconv.Itoa(web.conf.PortHTTPS))
web.httpsServer.server = &http.Server{ web.httpsServer.server = &http.Server{
ErrorLog: log.StdLog("web: https", log.DEBUG), ErrorLog: log.StdLog("web: https", log.DEBUG),
Addr: address, Addr: address,

View File

@ -26,7 +26,7 @@ const (
// Whois - module context // Whois - module context
type Whois struct { type Whois struct {
clients *clientsContainer clients *clientsContainer
ipChan chan string ipChan chan net.IP
timeoutMsec uint timeoutMsec uint
// Contains IP addresses of clients // Contains IP addresses of clients
@ -46,7 +46,7 @@ func initWhois(clients *clientsContainer) *Whois {
cconf.MaxCount = 10000 cconf.MaxCount = 10000
w.ipAddrs = cache.New(cconf) w.ipAddrs = cache.New(cconf)
w.ipChan = make(chan string, 255) w.ipChan = make(chan net.IP, 255)
go w.workerLoop() go w.workerLoop()
return &w return &w
} }
@ -183,9 +183,9 @@ func (w *Whois) queryAll(target string) (string, error) {
} }
// Request WHOIS information // Request WHOIS information
func (w *Whois) process(ip string) [][]string { func (w *Whois) process(ip net.IP) [][]string {
data := [][]string{} data := [][]string{}
resp, err := w.queryAll(ip) resp, err := w.queryAll(ip.String())
if err != nil { if err != nil {
log.Debug("Whois: error: %s IP:%s", err, ip) log.Debug("Whois: error: %s IP:%s", err, ip)
return data return data
@ -209,7 +209,7 @@ func (w *Whois) process(ip string) [][]string {
} }
// Begin - begin requesting WHOIS info // Begin - begin requesting WHOIS info
func (w *Whois) Begin(ip string) { func (w *Whois) Begin(ip net.IP) {
now := uint64(time.Now().Unix()) now := uint64(time.Now().Unix())
expire := w.ipAddrs.Get([]byte(ip)) expire := w.ipAddrs.Get([]byte(ip))
if len(expire) != 0 { if len(expire) != 0 {

View File

@ -22,9 +22,11 @@ var logEntryHandlers = map[string]logEntryHandler{
if !ok { if !ok {
return nil return nil
} }
if len(ent.IP) == 0 {
ent.IP = v if ent.IP == nil {
ent.IP = net.ParseIP(v)
} }
return nil return nil
}, },
"T": func(t json.Token, ent *logEntry) error { "T": func(t json.Token, ent *logEntry) error {

View File

@ -47,7 +47,7 @@ func TestDecodeLogEntry(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
want := &logEntry{ want := &logEntry{
IP: "127.0.0.1", IP: net.IPv4(127, 0, 0, 1),
Time: time.Date(2020, 11, 25, 15, 55, 56, 519796000, time.UTC), Time: time.Date(2020, 11, 25, 15, 55, 56, 519796000, time.UTC),
QHost: "an.yandex.ru", QHost: "an.yandex.ru",
QType: "A", QType: "A",

View File

@ -14,22 +14,19 @@ import (
// TODO(a.garipov): Use a proper structured approach here. // TODO(a.garipov): Use a proper structured approach here.
// Get Client IP address // Get Client IP address
func (l *queryLog) getClientIP(clientIP string) string { func (l *queryLog) getClientIP(ip net.IP) (clientIP net.IP) {
if l.conf.AnonymizeClientIP { if l.conf.AnonymizeClientIP && ip != nil {
ip := net.ParseIP(clientIP) const AnonymizeClientIPv4Mask = 16
if ip != nil { const AnonymizeClientIPv6Mask = 112
ip4 := ip.To4()
const AnonymizeClientIP4Mask = 16 if ip.To4() != nil {
const AnonymizeClientIP6Mask = 112 return ip.Mask(net.CIDRMask(AnonymizeClientIPv4Mask, 32))
if ip4 != nil {
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
} else {
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
}
} }
return ip.Mask(net.CIDRMask(AnonymizeClientIPv6Mask, 128))
} }
return clientIP return ip
} }
// jobject is a JSON object alias. // jobject is a JSON object alias.
@ -153,9 +150,9 @@ func answerToMap(a *dns.Msg) (answers []jobject) {
// try most common record types // try most common record types
switch v := k.(type) { switch v := k.(type) {
case *dns.A: case *dns.A:
answer["value"] = v.A.String() answer["value"] = v.A
case *dns.AAAA: case *dns.AAAA:
answer["value"] = v.AAAA.String() answer["value"] = v.AAAA
case *dns.MX: case *dns.MX:
answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx) answer["value"] = fmt.Sprintf("%v %v", v.Preference, v.Mx)
case *dns.CNAME: case *dns.CNAME:

View File

@ -3,6 +3,7 @@ package querylog
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@ -60,7 +61,7 @@ func NewClientProto(s string) (cp ClientProto, err error) {
// logEntry - represents a single log entry // logEntry - represents a single log entry
type logEntry struct { type logEntry struct {
IP string `json:"IP"` // Client IP IP net.IP `json:"IP"` // Client IP
Time time.Time `json:"T"` Time time.Time `json:"T"`
QHost string `json:"QH"` QHost string `json:"QH"`
@ -147,7 +148,7 @@ func (l *queryLog) Add(params AddParams) {
now := time.Now() now := time.Now()
entry := logEntry{ entry := logEntry{
IP: l.getClientIP(params.ClientIP.String()), IP: l.getClientIP(params.ClientIP),
Time: now, Time: now,
Result: *params.Result, Result: *params.Result,

View File

@ -40,27 +40,27 @@ func TestQueryLog(t *testing.T) {
l := newQueryLog(conf) l := newQueryLog(conf)
// add disk entries // add disk entries
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// write to disk (first file) // write to disk (first file)
_ = l.flushLogBuffer(true) _ = l.flushLogBuffer(true)
// start writing to the second file // start writing to the second file
_ = l.rotate() _ = l.rotate()
// add disk entries // add disk entries
addEntry(l, "example.org", "1.1.1.2", "2.2.2.2") addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// write to disk // write to disk
_ = l.flushLogBuffer(true) _ = l.flushLogBuffer(true)
// add memory entries // add memory entries
addEntry(l, "test.example.org", "1.1.1.3", "2.2.2.3") addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
addEntry(l, "example.com", "1.1.1.4", "2.2.2.4") addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
// get all entries // get all entries
params := newSearchParams() params := newSearchParams()
entries, _ := l.search(params) entries, _ := l.search(params)
assert.Len(t, entries, 4) assert.Len(t, entries, 4)
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") assertLogEntry(t, entries[0], "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[1], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") assertLogEntry(t, entries[2], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1") assertLogEntry(t, entries[3], "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// search by domain (strict) // search by domain (strict)
params = newSearchParams() params = newSearchParams()
@ -71,7 +71,7 @@ func TestQueryLog(t *testing.T) {
}) })
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Len(t, entries, 1) assert.Len(t, entries, 1)
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[0], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
// search by domain (not strict) // search by domain (not strict)
params = newSearchParams() params = newSearchParams()
@ -82,9 +82,9 @@ func TestQueryLog(t *testing.T) {
}) })
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Len(t, entries, 3) assert.Len(t, entries, 3)
assertLogEntry(t, entries[0], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[0], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
assertLogEntry(t, entries[1], "example.org", "1.1.1.2", "2.2.2.2") assertLogEntry(t, entries[1], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
assertLogEntry(t, entries[2], "example.org", "1.1.1.1", "2.2.2.1") assertLogEntry(t, entries[2], "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// search by client IP (strict) // search by client IP (strict)
params = newSearchParams() params = newSearchParams()
@ -95,7 +95,7 @@ func TestQueryLog(t *testing.T) {
}) })
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Len(t, entries, 1) assert.Len(t, entries, 1)
assertLogEntry(t, entries[0], "example.org", "1.1.1.2", "2.2.2.2") assertLogEntry(t, entries[0], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
// search by client IP (part of) // search by client IP (part of)
params = newSearchParams() params = newSearchParams()
@ -106,10 +106,10 @@ func TestQueryLog(t *testing.T) {
}) })
entries, _ = l.search(params) entries, _ = l.search(params)
assert.Len(t, entries, 4) assert.Len(t, entries, 4)
assertLogEntry(t, entries[0], "example.com", "1.1.1.4", "2.2.2.4") assertLogEntry(t, entries[0], "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))
assertLogEntry(t, entries[1], "test.example.org", "1.1.1.3", "2.2.2.3") assertLogEntry(t, entries[1], "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
assertLogEntry(t, entries[2], "example.org", "1.1.1.2", "2.2.2.2") assertLogEntry(t, entries[2], "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
assertLogEntry(t, entries[3], "example.org", "1.1.1.1", "2.2.2.1") assertLogEntry(t, entries[3], "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
func TestQueryLogOffsetLimit(t *testing.T) { func TestQueryLogOffsetLimit(t *testing.T) {
@ -124,13 +124,13 @@ func TestQueryLogOffsetLimit(t *testing.T) {
// add 10 entries to the log // add 10 entries to the log
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
addEntry(l, "second.example.org", "1.1.1.1", "2.2.2.1") addEntry(l, "second.example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// write them to disk (first file) // write them to disk (first file)
_ = l.flushLogBuffer(true) _ = l.flushLogBuffer(true)
// add 10 more entries to the log (memory) // add 10 more entries to the log (memory)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
addEntry(l, "first.example.org", "1.1.1.1", "2.2.2.1") addEntry(l, "first.example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// First page // First page
@ -178,7 +178,7 @@ func TestQueryLogMaxFileScanEntries(t *testing.T) {
// add 10 entries to the log // add 10 entries to the log
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
addEntry(l, "example.org", "1.1.1.1", "2.2.2.1") addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
} }
// write them to disk (first file) // write them to disk (first file)
_ = l.flushLogBuffer(true) _ = l.flushLogBuffer(true)
@ -204,9 +204,9 @@ func TestQueryLogFileDisabled(t *testing.T) {
defer func() { _ = os.RemoveAll(conf.BaseDir) }() defer func() { _ = os.RemoveAll(conf.BaseDir) }()
l := newQueryLog(conf) l := newQueryLog(conf)
addEntry(l, "example1.org", "1.1.1.1", "2.2.2.1") addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
addEntry(l, "example2.org", "1.1.1.1", "2.2.2.1") addEntry(l, "example2.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
addEntry(l, "example3.org", "1.1.1.1", "2.2.2.1") addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
// the oldest entry is now removed from mem buffer // the oldest entry is now removed from mem buffer
params := newSearchParams() params := newSearchParams()
@ -216,7 +216,7 @@ func TestQueryLogFileDisabled(t *testing.T) {
assert.Equal(t, "example2.org", ll[1].QHost) assert.Equal(t, "example2.org", ll[1].QHost)
} }
func addEntry(l *queryLog, host, answerStr, client string) { func addEntry(l *queryLog, host string, answerStr, client net.IP) {
q := dns.Msg{} q := dns.Msg{}
q.Question = append(q.Question, dns.Question{ q.Question = append(q.Question, dns.Question{
Name: host + ".", Name: host + ".",
@ -232,7 +232,7 @@ func addEntry(l *queryLog, host, answerStr, client string) {
Rrtype: dns.TypeA, Rrtype: dns.TypeA,
Class: dns.ClassINET, Class: dns.ClassINET,
} }
answer.A = net.ParseIP(answerStr) answer.A = answerStr
a.Answer = append(a.Answer, answer) a.Answer = append(a.Answer, answer)
res := dnsfilter.Result{ res := dnsfilter.Result{
IsFiltered: true, IsFiltered: true,
@ -248,13 +248,13 @@ func addEntry(l *queryLog, host, answerStr, client string) {
Answer: &a, Answer: &a,
OrigAnswer: &a, OrigAnswer: &a,
Result: &res, Result: &res,
ClientIP: net.ParseIP(client), ClientIP: client,
Upstream: "upstream", Upstream: "upstream",
} }
l.Add(params) l.Add(params)
} }
func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string) bool { func assertLogEntry(t *testing.T, entry *logEntry, host string, answer, client net.IP) bool {
assert.Equal(t, host, entry.QHost) assert.Equal(t, host, entry.QHost)
assert.Equal(t, client, entry.IP) assert.Equal(t, client, entry.IP)
assert.Equal(t, "A", entry.QType) assert.Equal(t, "A", entry.QType)
@ -263,9 +263,9 @@ func assertLogEntry(t *testing.T, entry *logEntry, host, answer, client string)
msg := new(dns.Msg) msg := new(dns.Msg)
assert.Nil(t, msg.Unpack(entry.Answer)) assert.Nil(t, msg.Unpack(entry.Answer))
assert.Len(t, msg.Answer, 1) assert.Len(t, msg.Answer, 1)
ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]) ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]).To16()
assert.NotNil(t, ip) assert.NotNil(t, ip)
assert.Equal(t, answer, ip.String()) assert.Equal(t, answer, ip)
return true return true
} }

View File

@ -94,16 +94,20 @@ func (c *searchCriteria) ctDomainOrClientCase(entry *logEntry) bool {
if c.strict && qhost == searchVal { if c.strict && qhost == searchVal {
return true return true
} }
if !c.strict && strings.Contains(qhost, searchVal) { if !c.strict && strings.Contains(qhost, searchVal) {
return true return true
} }
if c.strict && entry.IP == c.value { ipStr := entry.IP.String()
if c.strict && ipStr == c.value {
return true return true
} }
if !c.strict && strings.Contains(entry.IP, c.value) {
if !c.strict && strings.Contains(ipStr, c.value) {
return true return true
} }
return false return false
} }

View File

@ -48,7 +48,7 @@ type Stats interface {
Update(e Entry) Update(e Entry)
// Get IP addresses of the clients with the most number of requests // Get IP addresses of the clients with the most number of requests
GetTopClientsIP(limit uint) []string GetTopClientsIP(limit uint) []net.IP
// WriteDiskConfig - write configuration // WriteDiskConfig - write configuration
WriteDiskConfig(dc *DiskConfig) WriteDiskConfig(dc *DiskConfig)

View File

@ -80,7 +80,7 @@ func TestStats(t *testing.T) {
assert.EqualValues(t, 0.123456, d["avg_processing_time"].(float64)) assert.EqualValues(t, 0.123456, d["avg_processing_time"].(float64))
topClients := s.GetTopClientsIP(2) topClients := s.GetTopClientsIP(2)
assert.Equal(t, "127.0.0.1", topClients[0]) assert.True(t, net.IP{127, 0, 0, 1}.Equal(topClients[0]))
s.clear() s.clear()
s.Close() s.Close()

View File

@ -443,22 +443,19 @@ func (s *statsCtx) clear() {
} }
// Get Client IP address // Get Client IP address
func (s *statsCtx) getClientIP(clientIP string) string { func (s *statsCtx) getClientIP(ip net.IP) (clientIP net.IP) {
if s.conf.AnonymizeClientIP { if s.conf.AnonymizeClientIP && ip != nil {
ip := net.ParseIP(clientIP) const AnonymizeClientIP4Mask = 16
if ip != nil { const AnonymizeClientIP6Mask = 112
ip4 := ip.To4()
const AnonymizeClientIP4Mask = 16 if ip.To4() != nil {
const AnonymizeClientIP6Mask = 112 return ip.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32))
if ip4 != nil {
clientIP = ip4.Mask(net.CIDRMask(AnonymizeClientIP4Mask, 32)).String()
} else {
clientIP = ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128)).String()
}
} }
return ip.Mask(net.CIDRMask(AnonymizeClientIP6Mask, 128))
} }
return clientIP return ip
} }
func (s *statsCtx) Update(e Entry) { func (s *statsCtx) Update(e Entry) {
@ -468,7 +465,7 @@ func (s *statsCtx) Update(e Entry) {
!(len(e.Client) == 4 || len(e.Client) == 16) { !(len(e.Client) == 4 || len(e.Client) == 16) {
return return
} }
client := s.getClientIP(e.Client.String()) client := s.getClientIP(e.Client)
s.unitLock.Lock() s.unitLock.Lock()
u := s.unit u := s.unit
@ -481,7 +478,7 @@ func (s *statsCtx) Update(e Entry) {
u.blockedDomains[e.Domain]++ u.blockedDomains[e.Domain]++
} }
u.clients[client]++ u.clients[client.String()]++
u.timeSum += uint64(e.Time) u.timeSum += uint64(e.Time)
u.nTotal++ u.nTotal++
s.unitLock.Unlock() s.unitLock.Unlock()
@ -658,7 +655,7 @@ func (s *statsCtx) getData() map[string]interface{} {
return d return d
} }
func (s *statsCtx) GetTopClientsIP(maxCount uint) []string { func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
units, _ := s.loadUnits(s.conf.limit) units, _ := s.loadUnits(s.conf.limit)
if units == nil { if units == nil {
return nil return nil
@ -672,9 +669,9 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []string {
} }
} }
a := convertMapToArray(m, int(maxCount)) a := convertMapToArray(m, int(maxCount))
d := []string{} d := []net.IP{}
for _, it := range a { for _, it := range a {
d = append(d, it.Name) d = append(d, net.ParseIP(it.Name))
} }
return d return d
} }

View File

@ -119,17 +119,13 @@ func ifacesStaticConfig(r io.Reader, ifaceName string) (has bool, err error) {
} }
func ifaceSetStaticIP(ifaceName string) (err error) { func ifaceSetStaticIP(ifaceName string) (err error) {
ip := util.GetSubnet(ifaceName) ipNet := util.GetSubnet(ifaceName)
if len(ip) == 0 { if ipNet.IP == nil {
return errors.New("can't get IP address") return errors.New("can't get IP address")
} }
ip4, _, err := net.ParseCIDR(ip)
if err != nil {
return err
}
gatewayIP := GatewayIP(ifaceName) gatewayIP := GatewayIP(ifaceName)
add := updateStaticIPdhcpcdConf(ifaceName, ip, gatewayIP, ip4) add := updateStaticIPdhcpcdConf(ifaceName, ipNet.String(), gatewayIP, ipNet.IP)
body, err := ioutil.ReadFile("/etc/dhcpcd.conf") body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
if err != nil { if err != nil {

View File

@ -108,11 +108,11 @@ func TestAutoHostsFSNotify(t *testing.T) {
ips = ah.Process("newhost", dns.TypeA) ips = ah.Process("newhost", dns.TypeA)
assert.NotNil(t, ips) assert.NotNil(t, ips)
assert.Len(t, ips, 1) assert.Len(t, ips, 1)
assert.Equal(t, "127.0.0.2", ips[0].String()) assert.True(t, net.IP{127, 0, 0, 2}.Equal(ips[0]))
} }
func TestIP(t *testing.T) { func TestIP(t *testing.T) {
assert.Equal(t, "127.0.0.1", DNSUnreverseAddr("1.0.0.127.in-addr.arpa").String()) assert.True(t, net.IP{127, 0, 0, 1}.Equal(DNSUnreverseAddr("1.0.0.127.in-addr.arpa")))
assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String()) assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.b.a.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String())
assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String()) assert.Equal(t, "::abcd:1234", DNSUnreverseAddr("4.3.2.1.d.c.B.A.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.ip6.arpa").String())

View File

@ -15,12 +15,12 @@ import (
// NetInterface represents a list of network interfaces // NetInterface represents a list of network interfaces
type NetInterface struct { type NetInterface struct {
Name string // Network interface name Name string // Network interface name
MTU int // MTU MTU int // MTU
HardwareAddr string // Hardware address HardwareAddr string // Hardware address
Addresses []string // Array with the network interface addresses Addresses []net.IP // Array with the network interface addresses
Subnets []string // Array with CIDR addresses of this network interface Subnets []*net.IPNet // Array with CIDR addresses of this network interface
Flags string // Network interface flags (up, broadcast, etc) Flags string // Network interface flags (up, broadcast, etc)
} }
// GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP // GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP
@ -78,8 +78,8 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
if ipNet.IP.IsLinkLocalUnicast() { if ipNet.IP.IsLinkLocalUnicast() {
continue continue
} }
netIface.Addresses = append(netIface.Addresses, ipNet.IP.String()) netIface.Addresses = append(netIface.Addresses, ipNet.IP)
netIface.Subnets = append(netIface.Subnets, ipNet.String()) netIface.Subnets = append(netIface.Subnets, ipNet)
} }
// Discard interfaces with no addresses // Discard interfaces with no addresses
@ -91,8 +91,8 @@ func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
return netInterfaces, nil return netInterfaces, nil
} }
// GetInterfaceByIP - Get interface name by its IP address. // GetInterfaceByIP returns the name of interface containing provided ip.
func GetInterfaceByIP(ip string) string { func GetInterfaceByIP(ip net.IP) string {
ifaces, err := GetValidNetInterfacesForWeb() ifaces, err := GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
return "" return ""
@ -100,7 +100,7 @@ func GetInterfaceByIP(ip string) string {
for _, iface := range ifaces { for _, iface := range ifaces {
for _, addr := range iface.Addresses { for _, addr := range iface.Addresses {
if ip == addr { if ip.Equal(addr) {
return iface.Name return iface.Name
} }
} }
@ -109,13 +109,13 @@ func GetInterfaceByIP(ip string) string {
return "" return ""
} }
// GetSubnet - Get IP address with netmask for the specified interface // GetSubnet returns pointer to net.IPNet for the specified interface or nil if
// Returns an empty string if it fails to find it // the search fails.
func GetSubnet(ifaceName string) string { func GetSubnet(ifaceName string) *net.IPNet {
netIfaces, err := GetValidNetInterfacesForWeb() netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil { if err != nil {
log.Error("Could not get network interfaces info: %v", err) log.Error("Could not get network interfaces info: %v", err)
return "" return nil
} }
for _, netIface := range netIfaces { for _, netIface := range netIfaces {
@ -124,12 +124,12 @@ func GetSubnet(ifaceName string) string {
} }
} }
return "" return nil
} }
// CheckPortAvailable - check if TCP port is available // CheckPortAvailable - check if TCP port is available
func CheckPortAvailable(host string, port int) error { func CheckPortAvailable(host net.IP, port int) error {
ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port))) ln, err := net.Listen("tcp", net.JoinHostPort(host.String(), strconv.Itoa(port)))
if err != nil { if err != nil {
return err return err
} }
@ -142,8 +142,8 @@ func CheckPortAvailable(host string, port int) error {
} }
// CheckPacketPortAvailable - check if UDP port is available // CheckPacketPortAvailable - check if UDP port is available
func CheckPacketPortAvailable(host string, port int) error { func CheckPacketPortAvailable(host net.IP, port int) error {
ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port))) ln, err := net.ListenPacket("udp", net.JoinHostPort(host.String(), strconv.Itoa(port)))
if err != nil { if err != nil {
return err return err
} }