Pull request: dhcpd: imp static lease validation

Closes #2838.
Updates #2834.

Squashed commit of the following:

commit 608dce28cf6bcbaf5a7f0bf499889ec25777e121
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Thu Mar 18 16:49:20 2021 +0300

    dhcpd: fix windows; imp code

commit 5e56eebf6ab85ca5fd0a0278c312674d921a3077
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Wed Mar 17 18:47:54 2021 +0300

    dhcpd: imp static lease validation
This commit is contained in:
Ainar Garipov 2021-03-18 17:07:13 +03:00
parent ffa0afae27
commit eb9526cc92
15 changed files with 392 additions and 121 deletions

View File

@ -20,6 +20,8 @@ and this project adheres to
### Changed ### Changed
- Stricter validation of the IP addresses of static leases in the DHCP server
with regards to the netmask ([#2838]).
- Stricter validation of `$dnsrewrite` filter modifier parameters ([#2498]). - Stricter validation of `$dnsrewrite` filter modifier parameters ([#2498]).
- New, more correct versioning scheme ([#2412]). - New, more correct versioning scheme ([#2412]).
@ -42,6 +44,7 @@ and this project adheres to
[#2533]: https://github.com/AdguardTeam/AdGuardHome/issues/2533 [#2533]: https://github.com/AdguardTeam/AdGuardHome/issues/2533
[#2541]: https://github.com/AdguardTeam/AdGuardHome/issues/2541 [#2541]: https://github.com/AdguardTeam/AdGuardHome/issues/2541
[#2835]: https://github.com/AdguardTeam/AdGuardHome/issues/2835 [#2835]: https://github.com/AdguardTeam/AdGuardHome/issues/2835
[#2838]: https://github.com/AdguardTeam/AdGuardHome/issues/2838

1
go.mod
View File

@ -32,7 +32,6 @@ require (
github.com/stretchr/testify v1.6.1 github.com/stretchr/testify v1.6.1
github.com/ti-mo/netfilter v0.4.0 github.com/ti-mo/netfilter v0.4.0
github.com/u-root/u-root v7.0.0+incompatible github.com/u-root/u-root v7.0.0+incompatible
github.com/willf/bitset v1.1.11
go.etcd.io/bbolt v1.3.5 go.etcd.io/bbolt v1.3.5
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83 golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110

2
go.sum
View File

@ -425,8 +425,6 @@ github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGr
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/willf/bitset v1.1.11 h1:N7Z7E9UvjW+sGsEl7k/SJrvY2reP1A07MrGuCjIOjRE=
github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=

52
internal/dhcpd/bitset.go Normal file
View File

@ -0,0 +1,52 @@
package dhcpd
const bitsPerWord = 64
// bitSet is a sparse bitSet. A nil *bitSet is an empty bitSet.
type bitSet struct {
words map[uint64]uint64
}
// newBitSet returns a new bitset.
func newBitSet() (s *bitSet) {
return &bitSet{
words: map[uint64]uint64{},
}
}
// isSet returns true if the bit n is set.
func (s *bitSet) isSet(n uint64) (ok bool) {
if s == nil {
return false
}
wordIdx := n / bitsPerWord
bitIdx := n % bitsPerWord
var word uint64
word, ok = s.words[wordIdx]
if !ok {
return false
}
return word&(1<<bitIdx) != 0
}
// set sets or unsets a bit.
func (s *bitSet) set(n uint64, ok bool) {
if s == nil {
return
}
wordIdx := n / bitsPerWord
bitIdx := n % bitsPerWord
word := s.words[wordIdx]
if ok {
word |= 1 << bitIdx
} else {
word &^= 1 << bitIdx
}
s.words[wordIdx] = word
}

View File

@ -0,0 +1,90 @@
package dhcpd
import (
"math"
"testing"
"testing/quick"
"github.com/stretchr/testify/assert"
)
func TestBitSet(t *testing.T) {
t.Run("nil", func(t *testing.T) {
var s *bitSet
ok := s.isSet(0)
assert.False(t, ok)
assert.NotPanics(t, func() {
s.set(0, true)
})
ok = s.isSet(0)
assert.False(t, ok)
assert.NotPanics(t, func() {
s.set(0, false)
})
ok = s.isSet(0)
assert.False(t, ok)
})
t.Run("non_nil", func(t *testing.T) {
s := newBitSet()
ok := s.isSet(0)
assert.False(t, ok)
s.set(0, true)
ok = s.isSet(0)
assert.True(t, ok)
s.set(0, false)
ok = s.isSet(0)
assert.False(t, ok)
})
t.Run("non_nil_long", func(t *testing.T) {
s := newBitSet()
s.set(0, true)
s.set(math.MaxUint64, true)
assert.Len(t, s.words, 2)
ok := s.isSet(0)
assert.True(t, ok)
ok = s.isSet(math.MaxUint64)
assert.True(t, ok)
})
t.Run("compare_to_map", func(t *testing.T) {
m := map[uint64]struct{}{}
s := newBitSet()
mapFunc := func(setNew, checkOld, delOld uint64) (ok bool) {
m[setNew] = struct{}{}
delete(m, delOld)
_, ok = m[checkOld]
return ok
}
setFunc := func(setNew, checkOld, delOld uint64) (ok bool) {
s.set(setNew, true)
s.set(delOld, false)
ok = s.isSet(checkOld)
return ok
}
err := quick.CheckEqual(mapFunc, setFunc, &quick.Config{
MaxCount: 10_000,
MaxCountScale: 10,
})
assert.NoError(t, err)
})
}

View File

@ -128,7 +128,7 @@ func normalizeLeases(staticLeases, dynLeases []*Lease) []*Lease {
func (s *Server) dbStore() { func (s *Server) dbStore() {
var leases []leaseJSON var leases []leaseJSON
leases4 := s.srv4.GetLeasesRef() leases4 := s.srv4.getLeasesRef()
for _, l := range leases4 { for _, l := range leases4 {
if l.Expiry.Unix() == 0 { if l.Expiry.Unix() == 0 {
continue continue
@ -143,7 +143,7 @@ func (s *Server) dbStore() {
} }
if s.srv6 != nil { if s.srv6 != nil {
leases6 := s.srv6.GetLeasesRef() leases6 := s.srv6.getLeasesRef()
for _, l := range leases6 { for _, l := range leases6 {
if l.Expiry.Unix() == 0 { if l.Expiry.Unix() == 0 {
continue continue

View File

@ -36,16 +36,23 @@ type Lease struct {
Expiry time.Time `json:"expires"` Expiry time.Time `json:"expires"`
} }
// IsStatic returns true if the lease is static.
//
// TODO(a.garipov): Just make it a boolean field.
func (l *Lease) IsStatic() (ok bool) {
return l != nil && l.Expiry.Unix() == leaseExpireStatic
}
// MarshalJSON implements the json.Marshaler interface for *Lease. // MarshalJSON implements the json.Marshaler interface for *Lease.
func (l *Lease) MarshalJSON() ([]byte, error) { func (l *Lease) MarshalJSON() ([]byte, error) {
var expiryStr string var expiryStr string
if expiry := l.Expiry; expiry.Unix() != leaseExpireStatic { if !l.IsStatic() {
// The front-end is waiting for RFC 3999 format of the time // The front-end is waiting for RFC 3999 format of the time
// value. It also shouldn't got an Expiry field for static // value. It also shouldn't got an Expiry field for static
// leases. // leases.
// //
// See https://github.com/AdguardTeam/AdGuardHome/issues/2692. // See https://github.com/AdguardTeam/AdGuardHome/issues/2692.
expiryStr = expiry.Format(time.RFC3339) expiryStr = l.Expiry.Format(time.RFC3339)
} }
type lease Lease type lease Lease
@ -203,6 +210,7 @@ func Create(conf ServerConfig) *Server {
func (s *Server) onNotify(flags uint32) { func (s *Server) onNotify(flags uint32) {
if flags == LeaseChangedDBStore { if flags == LeaseChangedDBStore {
s.dbStore() s.dbStore()
return return
} }
@ -218,6 +226,7 @@ func (s *Server) notify(flags int) {
if len(s.onLeaseChanged) == 0 { if len(s.onLeaseChanged) == 0 {
return return
} }
for _, f := range s.onLeaseChanged { for _, f := range s.onLeaseChanged {
f(flags) f(flags)
} }

View File

@ -54,7 +54,8 @@ func TestDB(t *testing.T) {
srv4, ok := s.srv4.(*v4Server) srv4, ok := s.srv4.(*v4Server)
require.True(t, ok) require.True(t, ok)
srv4.addLease(&leases[0]) err = srv4.addLease(&leases[0])
require.Nil(t, err)
require.Nil(t, s.srv4.AddStaticLease(leases[1])) require.Nil(t, s.srv4.AddStaticLease(leases[1]))
s.dbStore() s.dbStore()
@ -69,7 +70,7 @@ func TestDB(t *testing.T) {
assert.Equal(t, leases[1].HWAddr, ll[0].HWAddr) assert.Equal(t, leases[1].HWAddr, ll[0].HWAddr)
assert.Equal(t, leases[1].IP, ll[0].IP) assert.Equal(t, leases[1].IP, ll[0].IP)
assert.EqualValues(t, leaseExpireStatic, ll[0].Expiry.Unix()) assert.True(t, ll[0].IsStatic())
assert.Equal(t, leases[0].HWAddr, ll[1].HWAddr) assert.Equal(t, leases[0].HWAddr, ll[1].HWAddr)
assert.Equal(t, leases[0].IP, ll[1].IP) assert.Equal(t, leases[0].IP, ll[1].IP)

View File

@ -93,7 +93,7 @@ func (r *ipRange) find(p ipPredicate) (ip net.IP) {
// offset returns the offset of ip from the beginning of r. It returns 0 and // offset returns the offset of ip from the beginning of r. It returns 0 and
// false if ip is not in r. // false if ip is not in r.
func (r *ipRange) offset(ip net.IP) (offset uint, ok bool) { func (r *ipRange) offset(ip net.IP) (offset uint64, ok bool) {
if r == nil { if r == nil {
return 0, false return 0, false
} }
@ -108,5 +108,5 @@ func (r *ipRange) offset(ip net.IP) (offset uint, ok bool) {
// Assume that the range was checked against maxRangeLen during // Assume that the range was checked against maxRangeLen during
// construction. // construction.
return uint(offsetInt.Uint64()), true return offsetInt.Uint64(), true
} }

View File

@ -66,10 +66,10 @@ func TestNewIPRange(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
r, err := newIPRange(tc.start, tc.end) r, err := newIPRange(tc.start, tc.end)
if tc.wantErrMsg == "" { if tc.wantErrMsg == "" {
assert.Nil(t, err) assert.NoError(t, err)
assert.NotNil(t, r) assert.NotNil(t, r)
} else { } else {
require.NotNil(t, err) require.Error(t, err)
assert.Equal(t, tc.wantErrMsg, err.Error()) assert.Equal(t, tc.wantErrMsg, err.Error())
} }
}) })
@ -79,7 +79,7 @@ func TestNewIPRange(t *testing.T) {
func TestIPRange_Contains(t *testing.T) { func TestIPRange_Contains(t *testing.T) {
start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 3} start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 3}
r, err := newIPRange(start, end) r, err := newIPRange(start, end)
require.Nil(t, err) require.NoError(t, err)
assert.True(t, r.contains(start)) assert.True(t, r.contains(start))
assert.True(t, r.contains(net.IP{0, 0, 0, 2})) assert.True(t, r.contains(net.IP{0, 0, 0, 2}))
@ -92,7 +92,7 @@ func TestIPRange_Contains(t *testing.T) {
func TestIPRange_Find(t *testing.T) { func TestIPRange_Find(t *testing.T) {
start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 5} start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 5}
r, err := newIPRange(start, end) r, err := newIPRange(start, end)
require.Nil(t, err) require.NoError(t, err)
want := net.IPv4(0, 0, 0, 2) want := net.IPv4(0, 0, 0, 2)
got := r.find(func(ip net.IP) (ok bool) { got := r.find(func(ip net.IP) (ok bool) {
@ -110,12 +110,12 @@ func TestIPRange_Find(t *testing.T) {
func TestIPRange_Offset(t *testing.T) { func TestIPRange_Offset(t *testing.T) {
start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 5} start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 5}
r, err := newIPRange(start, end) r, err := newIPRange(start, end)
require.Nil(t, err) require.NoError(t, err)
testCases := []struct { testCases := []struct {
name string name string
in net.IP in net.IP
wantOffset uint wantOffset uint64
wantOK bool wantOK bool
}{{ }{{
name: "in", name: "in",

View File

@ -11,8 +11,6 @@ type DHCPServer interface {
ResetLeases(leases []*Lease) ResetLeases(leases []*Lease)
// GetLeases - get leases // GetLeases - get leases
GetLeases(flags int) []Lease GetLeases(flags int) []Lease
// GetLeasesRef - get reference to leases array
GetLeasesRef() []*Lease
// AddStaticLease - add a static lease // AddStaticLease - add a static lease
AddStaticLease(lease Lease) error AddStaticLease(lease Lease) error
// RemoveStaticLease - remove a static lease // RemoveStaticLease - remove a static lease
@ -29,6 +27,8 @@ type DHCPServer interface {
Start() error Start() error
// Stop - stop server // Stop - stop server
Stop() Stop()
getLeasesRef() []*Lease
} }
// V4ServerConf - server configuration // V4ServerConf - server configuration
@ -68,7 +68,12 @@ type V4ServerConf struct {
subnetMask net.IPMask // value for Option SubnetMask subnetMask net.IPMask // value for Option SubnetMask
options []dhcpOption options []dhcpOption
// Server calls this function when leases data changes // notify is a way to signal to other components that leases have
// change. notify must be called outside of locked sections, since the
// clients might want to get the new data.
//
// TODO(a.garipov): This is utter madness and must be refactored. It
// just begs for deadlock bugs and other nastiness.
notify func(uint32) notify func(uint32)
} }

View File

@ -9,11 +9,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/go-ping/ping" "github.com/go-ping/ping"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4" "github.com/insomniacslk/dhcp/dhcpv4/server4"
"github.com/willf/bitset"
) )
// v4Server is a DHCPv4 server. // v4Server is a DHCPv4 server.
@ -25,7 +25,7 @@ type v4Server struct {
// leasedOffsets contains offsets from conf.ipRange.start that have been // leasedOffsets contains offsets from conf.ipRange.start that have been
// leased. // leased.
leasedOffsets *bitset.BitSet leasedOffsets *bitSet
// leases contains all dynamic and static leases. // leases contains all dynamic and static leases.
leases []*Lease leases []*Lease
@ -47,50 +47,88 @@ func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) {
func (s *v4Server) ResetLeases(leases []*Lease) { func (s *v4Server) ResetLeases(leases []*Lease) {
s.leases = nil s.leases = nil
r := s.conf.ipRange
for _, l := range leases { for _, l := range leases {
if l.Expiry.Unix() != leaseExpireStatic && !s.conf.ipRange.contains(l.IP) { if !l.IsStatic() && !r.contains(l.IP) {
log.Debug("dhcpv4: skipping a lease with ip %v: not within current ip range", l.IP) log.Debug(
"dhcpv4: skipping lease %s (%s): not within current ip range",
l.IP,
l.HWAddr,
)
continue continue
} }
s.addLease(l) err := s.addLease(l)
if err != nil {
// TODO(a.garipov): Better error handling.
log.Error("dhcpv4: adding a lease for %s (%s): %s", l.IP, l.HWAddr, err)
continue
}
} }
} }
// GetLeasesRef - get leases // getLeasesRef returns the actual leases slice. For internal use only.
func (s *v4Server) GetLeasesRef() []*Lease { func (s *v4Server) getLeasesRef() []*Lease {
return s.leases return s.leases
} }
// Return TRUE if this lease holds a blacklisted IP // isBlocklisted returns true if this lease holds a blocklisted IP.
func (s *v4Server) blacklisted(l *Lease) bool { //
return l.HWAddr.String() == "00:00:00:00:00:00" // TODO(a.garipov): Make a method of *Lease?
} func (s *v4Server) isBlocklisted(l *Lease) (ok bool) {
if len(l.HWAddr) == 0 {
return false
}
// GetLeases returns the list of current DHCP leases (thread-safe) ok = true
func (s *v4Server) GetLeases(flags int) []Lease { for _, b := range l.HWAddr {
// The function shouldn't return nil value because zero-length slice if b != 0 {
// behaves differently in cases like marshalling. Our front-end also ok = false
// requires non-nil value in the response.
result := []Lease{}
now := time.Now().Unix()
s.leasesLock.Lock() break
for _, lease := range s.leases {
if ((flags&LeasesDynamic) != 0 && lease.Expiry.Unix() > now && !s.blacklisted(lease)) ||
((flags&LeasesStatic) != 0 && lease.Expiry.Unix() == leaseExpireStatic) {
result = append(result, *lease)
} }
} }
s.leasesLock.Unlock()
return result return ok
}
// GetLeases returns the list of current DHCP leases. It is safe for concurrent
// use.
func (s *v4Server) GetLeases(flags int) (res []Lease) {
// The function shouldn't return nil, because zero-length slice behaves
// differently in cases like marshalling. Our front-end also requires
// a non-nil value in the response.
res = []Lease{}
// TODO(a.garipov): Remove the silly bit twiddling and make GetLeases
// accept booleans. Seriously, this doesn't even save stack space.
getDynamic := flags&LeasesDynamic != 0
getStatic := flags&LeasesStatic != 0
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
now := time.Now()
for _, l := range s.leases {
if getDynamic && l.Expiry.After(now) && !s.isBlocklisted(l) {
res = append(res, *l)
continue
}
if getStatic && l.IsStatic() {
res = append(res, *l)
}
}
return res
} }
// FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases // FindMACbyIP - find a MAC address by IP address in the currently active DHCP leases
func (s *v4Server) FindMACbyIP(ip net.IP) net.HardwareAddr { func (s *v4Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
now := time.Now().Unix() now := time.Now()
s.leasesLock.Lock() s.leasesLock.Lock()
defer s.leasesLock.Unlock() defer s.leasesLock.Unlock()
@ -102,12 +140,12 @@ func (s *v4Server) FindMACbyIP(ip net.IP) net.HardwareAddr {
for _, l := range s.leases { for _, l := range s.leases {
if l.IP.Equal(ip4) { if l.IP.Equal(ip4) {
unix := l.Expiry.Unix() if l.Expiry.After(now) || l.IsStatic() {
if unix > now || unix == leaseExpireStatic {
return l.HWAddr return l.HWAddr
} }
} }
} }
return nil return nil
} }
@ -142,7 +180,7 @@ func (s *v4Server) rmLeaseByIndex(i int) {
r := s.conf.ipRange r := s.conf.ipRange
offset, ok := r.offset(l.IP) offset, ok := r.offset(l.IP)
if ok { if ok {
s.leasedOffsets.Clear(offset) s.leasedOffsets.set(offset, false)
} }
log.Debug("dhcpv4: removed lease %s (%s)", l.IP, l.HWAddr) log.Debug("dhcpv4: removed lease %s (%s)", l.IP, l.HWAddr)
@ -150,12 +188,12 @@ func (s *v4Server) rmLeaseByIndex(i int) {
// Remove a dynamic lease with the same properties // Remove a dynamic lease with the same properties
// Return error if a static lease is found // Return error if a static lease is found
func (s *v4Server) rmDynamicLease(lease Lease) error { func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
for i := 0; i < len(s.leases); i++ { for i := 0; i < len(s.leases); i++ {
l := s.leases[i] l := s.leases[i]
if bytes.Equal(l.HWAddr, lease.HWAddr) { if bytes.Equal(l.HWAddr, lease.HWAddr) {
if l.Expiry.Unix() == leaseExpireStatic { if l.IsStatic() {
return fmt.Errorf("static lease already exists") return fmt.Errorf("static lease already exists")
} }
@ -168,40 +206,70 @@ func (s *v4Server) rmDynamicLease(lease Lease) error {
} }
if net.IP.Equal(l.IP, lease.IP) { if net.IP.Equal(l.IP, lease.IP) {
if l.Expiry.Unix() == leaseExpireStatic { if l.IsStatic() {
return fmt.Errorf("static lease already exists") return fmt.Errorf("static lease already exists")
} }
s.rmLeaseByIndex(i) s.rmLeaseByIndex(i)
} }
} }
return nil return nil
} }
// addLease adds a lease. func (s *v4Server) addStaticLease(l *Lease) (err error) {
func (s *v4Server) addLease(l *Lease) { subnet := &net.IPNet{
r := s.conf.ipRange IP: s.conf.routerIP,
offset, ok := r.offset(l.IP) Mask: s.conf.subnetMask,
if !ok { }
// TODO(a.garipov): Better error handling.
log.Debug("dhcpv4: lease %s (%s) out of range, not adding", l.IP, l.HWAddr)
return if !subnet.Contains(l.IP) {
return fmt.Errorf("subnet %s does not contain the ip %q", subnet, l.IP)
} }
s.leases = append(s.leases, l) s.leases = append(s.leases, l)
s.leasedOffsets.Set(offset)
log.Debug("dhcpv4: added lease %s (%s)", l.IP, l.HWAddr) r := s.conf.ipRange
offset, ok := r.offset(l.IP)
if ok {
s.leasedOffsets.set(offset, true)
}
return nil
}
func (s *v4Server) addDynamicLease(l *Lease) (err error) {
r := s.conf.ipRange
offset, ok := r.offset(l.IP)
if !ok {
return fmt.Errorf("lease %s (%s) out of range, not adding", l.IP, l.HWAddr)
}
s.leases = append(s.leases, l)
s.leasedOffsets.set(offset, true)
return nil
}
// addLease adds a dynamic or static lease.
func (s *v4Server) addLease(l *Lease) (err error) {
if l.IsStatic() {
return s.addStaticLease(l)
}
return s.addDynamicLease(l)
} }
// Remove a lease with the same properties // Remove a lease with the same properties
func (s *v4Server) rmLease(lease Lease) error { func (s *v4Server) rmLease(lease Lease) error {
if len(s.leases) == 0 {
return nil
}
for i, l := range s.leases { for i, l := range s.leases {
if l.IP.Equal(lease.IP) { if l.IP.Equal(lease.IP) {
if !bytes.Equal(l.HWAddr, lease.HWAddr) || if !bytes.Equal(l.HWAddr, lease.HWAddr) || l.Hostname != lease.Hostname {
l.Hostname != lease.Hostname { return fmt.Errorf("lease for ip %s is different: %+v", lease.IP, l)
return fmt.Errorf("lease not found")
} }
s.rmLeaseByIndex(i) s.rmLeaseByIndex(i)
@ -210,30 +278,55 @@ func (s *v4Server) rmLease(lease Lease) error {
} }
} }
return fmt.Errorf("lease not found") return agherr.Error("lease not found")
} }
// AddStaticLease adds a static lease (thread-safe) // AddStaticLease adds a static lease. It is safe for concurrent use.
func (s *v4Server) AddStaticLease(lease Lease) error { func (s *v4Server) AddStaticLease(l Lease) (err error) {
if len(lease.IP) != 4 { defer agherr.Annotate("dhcpv4: %w", &err)
return fmt.Errorf("invalid IP")
}
if len(lease.HWAddr) != 6 {
return fmt.Errorf("invalid MAC")
}
lease.Expiry = time.Unix(leaseExpireStatic, 0)
s.leasesLock.Lock() if ip4 := l.IP.To4(); ip4 == nil {
err := s.rmDynamicLease(lease) return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
}
if len(l.HWAddr) != 6 {
return fmt.Errorf("invalid mac %q, only EUI-48 is supported", l.HWAddr)
}
l.Expiry = time.Unix(leaseExpireStatic, 0)
// Perform the following actions in an anonymous function to make sure
// that the lock gets unlocked before the notification step.
func() {
s.leasesLock.Lock()
defer s.leasesLock.Unlock()
err = s.rmDynamicLease(&l)
if err != nil {
err = fmt.Errorf(
"removing dynamic leases for %s (%s): %w",
l.IP,
l.HWAddr,
err,
)
return
}
err = s.addLease(&l)
if err != nil {
err = fmt.Errorf("adding static lease for %s (%s): %w", l.IP, l.HWAddr, err)
return
}
}()
if err != nil { if err != nil {
s.leasesLock.Unlock()
return err return err
} }
s.addLease(&lease)
s.conf.notify(LeaseChangedDBStore)
s.leasesLock.Unlock()
s.conf.notify(LeaseChangedDBStore)
s.conf.notify(LeaseChangedAddedStatic) s.conf.notify(LeaseChangedAddedStatic)
return nil return nil
} }
@ -250,12 +343,14 @@ func (s *v4Server) RemoveStaticLease(l Lease) error {
err := s.rmLease(l) err := s.rmLease(l)
if err != nil { if err != nil {
s.leasesLock.Unlock() s.leasesLock.Unlock()
return err return err
} }
s.conf.notify(LeaseChangedDBStore)
s.leasesLock.Unlock() s.leasesLock.Unlock()
s.conf.notify(LeaseChangedDBStore)
s.conf.notify(LeaseChangedRemovedStatic) s.conf.notify(LeaseChangedRemovedStatic)
return nil return nil
} }
@ -318,7 +413,7 @@ func (s *v4Server) nextIP() (ip net.IP) {
return false return false
} }
return !s.leasedOffsets.Test(offset) return !s.leasedOffsets.isSet(offset)
}) })
return ip.To4() return ip.To4()
@ -326,19 +421,19 @@ func (s *v4Server) nextIP() (ip net.IP) {
// Find an expired lease and return its index or -1 // Find an expired lease and return its index or -1
func (s *v4Server) findExpiredLease() int { func (s *v4Server) findExpiredLease() int {
now := time.Now().Unix() now := time.Now()
for i, lease := range s.leases { for i, lease := range s.leases {
if lease.Expiry.Unix() != leaseExpireStatic && if !lease.IsStatic() && lease.Expiry.Before(now) {
lease.Expiry.Unix() <= now {
return i return i
} }
} }
return -1 return -1
} }
// reserveLease reserves a lease for a client by its MAC-address. It returns // reserveLease reserves a lease for a client by its MAC-address. It returns
// nil if it couldn't allocate a new lease. // nil if it couldn't allocate a new lease.
func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease) { func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease, err error) {
l = &Lease{ l = &Lease{
HWAddr: make([]byte, len(mac)), HWAddr: make([]byte, len(mac)),
} }
@ -349,17 +444,20 @@ func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease) {
if l.IP == nil { if l.IP == nil {
i := s.findExpiredLease() i := s.findExpiredLease()
if i < 0 { if i < 0 {
return nil return nil, nil
} }
copy(s.leases[i].HWAddr, mac) copy(s.leases[i].HWAddr, mac)
return s.leases[i] return s.leases[i], nil
} }
s.addLease(l) err = s.addLease(l)
if err != nil {
return nil, err
}
return l return l, nil
} }
func (s *v4Server) commitLease(l *Lease) { func (s *v4Server) commitLease(l *Lease) {
@ -373,46 +471,55 @@ func (s *v4Server) commitLease(l *Lease) {
} }
// Process Discover request and return lease // Process Discover request and return lease
func (s *v4Server) processDiscover(req, resp *dhcpv4.DHCPv4) *Lease { func (s *v4Server) processDiscover(req, resp *dhcpv4.DHCPv4) (l *Lease, err error) {
mac := req.ClientHWAddr mac := req.ClientHWAddr
s.leasesLock.Lock() s.leasesLock.Lock()
defer s.leasesLock.Unlock() defer s.leasesLock.Unlock()
lease := s.findLease(mac) // TODO(a.garipov): Refactor this mess.
if lease == nil { l = s.findLease(mac)
if l == nil {
toStore := false toStore := false
for lease == nil { for l == nil {
lease = s.reserveLease(mac) l, err = s.reserveLease(mac)
if lease == nil { if err != nil {
log.Debug("dhcpv4: No more IP addresses") return nil, fmt.Errorf("reserving a lease: %w", err)
}
if l == nil {
log.Debug("dhcpv4: no more ip addresses")
if toStore { if toStore {
s.conf.notify(LeaseChangedDBStore) s.conf.notify(LeaseChangedDBStore)
} }
return nil
// TODO(a.garipov): Return a special error?
return nil, nil
} }
toStore = true toStore = true
if !s.addrAvailable(lease.IP) { if !s.addrAvailable(l.IP) {
s.blocklistLease(lease) s.blocklistLease(l)
lease = nil l = nil
continue continue
} }
break break
} }
s.conf.notify(LeaseChangedDBStore) s.conf.notify(LeaseChangedDBStore)
} else { } else {
reqIP := req.RequestedIPAddress() reqIP := req.RequestedIPAddress()
if len(reqIP) != 0 && !reqIP.Equal(lease.IP) { if len(reqIP) != 0 && !reqIP.Equal(l.IP) {
log.Debug("dhcpv4: different RequestedIP: %v != %v", reqIP, lease.IP) log.Debug("dhcpv4: different RequestedIP: %s != %s", reqIP, l.IP)
} }
} }
resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeOffer)) resp.UpdateOption(dhcpv4.OptMessageType(dhcpv4.MessageTypeOffer))
return lease
return l, nil
} }
type optFQDN struct { type optFQDN struct {
@ -490,7 +597,7 @@ func (s *v4Server) processRequest(req, resp *dhcpv4.DHCPv4) (*Lease, bool) {
return nil, true return nil, true
} }
if lease.Expiry.Unix() != leaseExpireStatic { if !lease.IsStatic() {
lease.Hostname = req.HostName() lease.Hostname = req.HostName()
s.commitLease(lease) s.commitLease(lease)
} else if len(lease.Hostname) != 0 { } else if len(lease.Hostname) != 0 {
@ -515,22 +622,27 @@ func (s *v4Server) processRequest(req, resp *dhcpv4.DHCPv4) (*Lease, bool) {
// Return 0: error; reply with Nak // Return 0: error; reply with Nak
// Return -1: error; don't reply // Return -1: error; don't reply
func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int { func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
var lease *Lease var err error
resp.UpdateOption(dhcpv4.OptServerIdentifier(s.conf.dnsIPAddrs[0])) resp.UpdateOption(dhcpv4.OptServerIdentifier(s.conf.dnsIPAddrs[0]))
var l *Lease
switch req.MessageType() { switch req.MessageType() {
case dhcpv4.MessageTypeDiscover: case dhcpv4.MessageTypeDiscover:
lease = s.processDiscover(req, resp) l, err = s.processDiscover(req, resp)
if lease == nil { if err != nil {
log.Error("dhcpv4: processing discover: %s", err)
return 0 return 0
} }
if l == nil {
return 0
}
case dhcpv4.MessageTypeRequest: case dhcpv4.MessageTypeRequest:
var toReply bool var toReply bool
lease, toReply = s.processRequest(req, resp) l, toReply = s.processRequest(req, resp)
if lease == nil { if l == nil {
if toReply { if toReply {
return 0 return 0
} }
@ -539,7 +651,7 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
} }
resp.YourIPAddr = make([]byte, 4) resp.YourIPAddr = make([]byte, 4)
copy(resp.YourIPAddr, lease.IP) copy(resp.YourIPAddr, l.IP)
resp.UpdateOption(dhcpv4.OptIPAddressLeaseTime(s.conf.leaseTime)) resp.UpdateOption(dhcpv4.OptIPAddressLeaseTime(s.conf.leaseTime))
resp.UpdateOption(dhcpv4.OptRouter(s.conf.routerIP)) resp.UpdateOption(dhcpv4.OptRouter(s.conf.routerIP))
@ -549,6 +661,7 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int {
for _, opt := range s.conf.options { for _, opt := range s.conf.options {
resp.Options[opt.code] = opt.data resp.Options[opt.code] = opt.data
} }
return 1 return 1
} }
@ -683,7 +796,7 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
return s, fmt.Errorf("dhcpv4: %w", err) return s, fmt.Errorf("dhcpv4: %w", err)
} }
s.leasedOffsets = &bitset.BitSet{} s.leasedOffsets = newBitSet()
if conf.LeaseDuration == 0 { if conf.LeaseDuration == 0 {
s.conf.leaseTime = time.Hour * 24 s.conf.leaseTime = time.Hour * 24

View File

@ -10,7 +10,7 @@ type winServer struct{}
func (s *winServer) ResetLeases(leases []*Lease) {} func (s *winServer) ResetLeases(leases []*Lease) {}
func (s *winServer) GetLeases(flags int) []Lease { return nil } func (s *winServer) GetLeases(flags int) []Lease { return nil }
func (s *winServer) GetLeasesRef() []*Lease { return nil } func (s *winServer) getLeasesRef() []*Lease { return nil }
func (s *winServer) AddStaticLease(lease Lease) error { return nil } func (s *winServer) AddStaticLease(lease Lease) error { return nil }
func (s *winServer) RemoveStaticLease(l Lease) error { return nil } func (s *winServer) RemoveStaticLease(l Lease) error { return nil }
func (s *winServer) FindMACbyIP(ip net.IP) net.HardwareAddr { return nil } func (s *winServer) FindMACbyIP(ip net.IP) net.HardwareAddr { return nil }

View File

@ -40,7 +40,7 @@ func TestV4_AddRemove_static(t *testing.T) {
require.Len(t, ls, 1) require.Len(t, ls, 1)
assert.True(t, l.IP.Equal(ls[0].IP)) assert.True(t, l.IP.Equal(ls[0].IP))
assert.Equal(t, l.HWAddr, ls[0].HWAddr) assert.Equal(t, l.HWAddr, ls[0].HWAddr)
assert.EqualValues(t, leaseExpireStatic, ls[0].Expiry.Unix()) assert.True(t, ls[0].IsStatic())
// Try to remove static lease. // Try to remove static lease.
assert.NotNil(t, s.RemoveStaticLease(Lease{ assert.NotNil(t, s.RemoveStaticLease(Lease{
@ -77,7 +77,8 @@ func TestV4_AddReplace(t *testing.T) {
}} }}
for i := range dynLeases { for i := range dynLeases {
s.addLease(&dynLeases[i]) err = s.addLease(&dynLeases[i])
require.Nil(t, err)
} }
stLeases := []Lease{{ stLeases := []Lease{{
@ -98,7 +99,7 @@ func TestV4_AddReplace(t *testing.T) {
for i, l := range ls { for i, l := range ls {
assert.True(t, stLeases[i].IP.Equal(l.IP)) assert.True(t, stLeases[i].IP.Equal(l.IP))
assert.Equal(t, stLeases[i].HWAddr, l.HWAddr) assert.Equal(t, stLeases[i].HWAddr, l.HWAddr)
assert.EqualValues(t, leaseExpireStatic, l.Expiry.Unix()) assert.True(t, l.IsStatic())
} }
} }

View File

@ -92,8 +92,8 @@ func (s *v6Server) GetLeases(flags int) []Lease {
return result return result
} }
// GetLeasesRef - get leases // getLeasesRef returns the actual leases slice. For internal use only.
func (s *v6Server) GetLeasesRef() []*Lease { func (s *v6Server) getLeasesRef() []*Lease {
return s.leases return s.leases
} }