From 83ec7c54ef1e689a1f887e78e3055522539222d5 Mon Sep 17 00:00:00 2001 From: Eugene Burkov Date: Mon, 8 Jul 2024 18:56:21 +0300 Subject: [PATCH] dhcpsvc: add db --- internal/dhcpsvc/config.go | 8 +- internal/dhcpsvc/config_test.go | 9 + internal/dhcpsvc/db.go | 197 ++++++++++++++++++ internal/dhcpsvc/db_test.go | 55 +++++ internal/dhcpsvc/leaseindex.go | 15 ++ internal/dhcpsvc/server.go | 54 ++++- internal/dhcpsvc/server_test.go | 25 ++- .../TestServer_loadDatabase/leases.json | 10 + internal/dhcpsvc/v4.go | 1 + 9 files changed, 364 insertions(+), 10 deletions(-) create mode 100644 internal/dhcpsvc/db.go create mode 100644 internal/dhcpsvc/db_test.go create mode 100644 internal/dhcpsvc/testdata/TestServer_loadDatabase/leases.json diff --git a/internal/dhcpsvc/config.go b/internal/dhcpsvc/config.go index c1d7910d..2a35a977 100644 --- a/internal/dhcpsvc/config.go +++ b/internal/dhcpsvc/config.go @@ -3,6 +3,7 @@ package dhcpsvc import ( "fmt" "log/slog" + "os" "time" "github.com/AdguardTeam/golibs/errors" @@ -23,7 +24,8 @@ type Config struct { // clients' hostnames. LocalDomainName string - // TODO(e.burkov): Add DB path. + // DBFilePath is the path to the database file containing the DHCP leases. + DBFilePath string // ICMPTimeout is the timeout for checking another DHCP server's presence. ICMPTimeout time.Duration @@ -64,6 +66,10 @@ func (conf *Config) Validate() (err error) { errs = append(errs, err) } + if _, err = os.Stat(conf.DBFilePath); err != nil && !errors.Is(err, os.ErrNotExist) { + errs = append(errs, fmt.Errorf("db file path %q: %w", conf.DBFilePath, err)) + } + if len(conf.Interfaces) == 0 { errs = append(errs, errNoInterfaces) diff --git a/internal/dhcpsvc/config_test.go b/internal/dhcpsvc/config_test.go index aa87b0d6..85dab4a9 100644 --- a/internal/dhcpsvc/config_test.go +++ b/internal/dhcpsvc/config_test.go @@ -1,6 +1,7 @@ package dhcpsvc_test import ( + "path/filepath" "testing" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" @@ -8,6 +9,8 @@ import ( ) func TestConfig_Validate(t *testing.T) { + leasesPath := filepath.Join(t.TempDir(), "leases.json") + testCases := []struct { name string conf *dhcpsvc.Config @@ -25,6 +28,7 @@ func TestConfig_Validate(t *testing.T) { conf: &dhcpsvc.Config{ Enabled: true, Interfaces: testInterfaceConf, + DBFilePath: leasesPath, }, wantErrMsg: `bad domain name "": domain name is empty`, }, { @@ -32,6 +36,7 @@ func TestConfig_Validate(t *testing.T) { Enabled: true, LocalDomainName: testLocalTLD, Interfaces: nil, + DBFilePath: leasesPath, }, name: "no_interfaces", wantErrMsg: "no interfaces specified", @@ -40,6 +45,7 @@ func TestConfig_Validate(t *testing.T) { Enabled: true, LocalDomainName: testLocalTLD, Interfaces: nil, + DBFilePath: leasesPath, }, name: "no_interfaces", wantErrMsg: "no interfaces specified", @@ -50,6 +56,7 @@ func TestConfig_Validate(t *testing.T) { Interfaces: map[string]*dhcpsvc.InterfaceConfig{ "eth0": nil, }, + DBFilePath: leasesPath, }, name: "nil_interface", wantErrMsg: `interface "eth0": config is nil`, @@ -63,6 +70,7 @@ func TestConfig_Validate(t *testing.T) { IPv6: &dhcpsvc.IPv6Config{Enabled: false}, }, }, + DBFilePath: leasesPath, }, name: "nil_ipv4", wantErrMsg: `interface "eth0": ipv4: config is nil`, @@ -76,6 +84,7 @@ func TestConfig_Validate(t *testing.T) { IPv6: nil, }, }, + DBFilePath: leasesPath, }, name: "nil_ipv6", wantErrMsg: `interface "eth0": ipv6: config is nil`, diff --git a/internal/dhcpsvc/db.go b/internal/dhcpsvc/db.go new file mode 100644 index 00000000..5fb4b23f --- /dev/null +++ b/internal/dhcpsvc/db.go @@ -0,0 +1,197 @@ +package dhcpsvc + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "slices" + "strings" + "time" + + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/google/renameio/v2/maybe" +) + +// dataVersion is the current version of the stored DHCP leases structure. +const dataVersion = 1 + +// dataLeases is the structure of the stored DHCP leases. +type dataLeases struct { + // Leases is the list containing stored DHCP leases. + Leases []*dbLease `json:"leases"` + + // Version is the current version of the structure. + Version int `json:"version"` +} + +// dbLease is the structure of stored lease. +type dbLease struct { + Expiry string `json:"expires"` + IP netip.Addr `json:"ip"` + Hostname string `json:"hostname"` + HWAddr string `json:"mac"` + IsStatic bool `json:"static"` +} + +// compareNames returns the result of comparing the hostnames of dl and other +// lexicographically. +func (dl *dbLease) compareNames(other *dbLease) (res int) { + return strings.Compare(dl.Hostname, other.Hostname) +} + +// fromLease converts *Lease to *dbLease. +func fromLease(l *Lease) (dl *dbLease) { + var expiryStr string + if !l.IsStatic { + // The front-end is waiting for RFC 3999 format of the time value. It + // also shouldn't got an Expiry field for static leases. + // + // See https://github.com/AdguardTeam/AdGuardHome/issues/2692. + expiryStr = l.Expiry.Format(time.RFC3339) + } + + return &dbLease{ + Expiry: expiryStr, + Hostname: l.Hostname, + HWAddr: l.HWAddr.String(), + IP: l.IP, + IsStatic: l.IsStatic, + } +} + +// toLease converts dl to *Lease. +func (dl *dbLease) toLease() (l *Lease, err error) { + mac, err := net.ParseMAC(dl.HWAddr) + if err != nil { + return nil, fmt.Errorf("parsing hardware address: %w", err) + } + + expiry := time.Time{} + if !dl.IsStatic { + expiry, err = time.Parse(time.RFC3339, dl.Expiry) + if err != nil { + return nil, fmt.Errorf("parsing expiry time: %w", err) + } + } + + return &Lease{ + Expiry: expiry, + IP: dl.IP, + Hostname: dl.Hostname, + HWAddr: mac, + IsStatic: dl.IsStatic, + }, nil +} + +// dbLoad loads stored leases. It must only be called before the service has +// been started. +func (srv *DHCPServer) dbLoad(ctx context.Context) (err error) { + defer func() { err = errors.Annotate(err, "loading db: %w") }() + + file, err := os.Open(srv.dbFilePath) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("reading db: %w", err) + } + + srv.logger.DebugContext(ctx, "no db file found") + + return nil + } + + dl := &dataLeases{} + err = json.NewDecoder(file).Decode(dl) + if err != nil { + return fmt.Errorf("decoding db: %w", err) + } + + srv.resetLeases() + srv.addDBLeases(ctx, dl.Leases) + + return nil +} + +// addDBLeases adds leases to the server. +func (srv *DHCPServer) addDBLeases(ctx context.Context, leases []*dbLease) { + const logMsg = "loading lease" + + var v4, v6 uint + for i, l := range leases { + var lease *Lease + lease, err := l.toLease() + if err != nil { + srv.logger.DebugContext(ctx, logMsg, "idx", i, slogutil.KeyError, err) + + continue + } + + addr := l.IP + iface, err := srv.ifaceForAddr(addr) + if err != nil { + srv.logger.DebugContext(ctx, logMsg, "idx", i, slogutil.KeyError, err) + + continue + } + + err = srv.leases.add(lease, iface) + if err != nil { + srv.logger.DebugContext(ctx, logMsg, "idx", i, slogutil.KeyError, err) + + continue + } + + if lease.IP.Is4() { + v4++ + } else { + v6++ + } + } + + srv.logger.InfoContext( + ctx, + "loaded leases", + "v4", v4, + "v6", v6, + "total", len(leases), + ) +} + +// writeDB writes leases to the database file. It expects the +// [DHCPServer.leasesMu] to be locked. +func (srv *DHCPServer) dbStore(ctx context.Context) (err error) { + defer func() { err = errors.Annotate(err, "writing db: %w") }() + + dl := &dataLeases{ + // Avoid writing "null" into the database file if there are no leases. + Leases: make([]*dbLease, 0, srv.leases.len()), + Version: dataVersion, + } + + srv.leases.rangeLeases(func(l *Lease) (cont bool) { + lease := fromLease(l) + i, _ := slices.BinarySearchFunc(dl.Leases, lease, (*dbLease).compareNames) + dl.Leases = slices.Insert(dl.Leases, i, lease) + + return true + }) + + buf, err := json.Marshal(dl) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + err = maybe.WriteFile(srv.dbFilePath, buf, 0o644) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return err + } + + srv.logger.InfoContext(ctx, "stored leases", "num", len(dl.Leases), "file", srv.dbFilePath) + + return nil +} diff --git a/internal/dhcpsvc/db_test.go b/internal/dhcpsvc/db_test.go new file mode 100644 index 00000000..f3a94766 --- /dev/null +++ b/internal/dhcpsvc/db_test.go @@ -0,0 +1,55 @@ +package dhcpsvc_test + +import ( + "net/netip" + "path/filepath" + "testing" + "time" + + "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" + "github.com/AdguardTeam/golibs/testutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestServer_loadDatabase(t *testing.T) { + leasesPath := filepath.Join("testdata", t.Name(), "leases.json") + + ipv4Conf := &dhcpsvc.IPv4Config{ + Enabled: true, + GatewayIP: netip.MustParseAddr("192.168.0.1"), + SubnetMask: netip.MustParseAddr("255.255.255.0"), + RangeStart: netip.MustParseAddr("192.168.0.2"), + RangeEnd: netip.MustParseAddr("192.168.0.254"), + LeaseDuration: 1 * time.Hour, + } + conf := &dhcpsvc.Config{ + Enabled: true, + LocalDomainName: "local", + Interfaces: map[string]*dhcpsvc.InterfaceConfig{ + "eth0": { + IPv4: ipv4Conf, + IPv6: &dhcpsvc.IPv6Config{Enabled: false}, + }, + }, + DBFilePath: leasesPath, + Logger: discardLog, + } + + ctx := testutil.ContextWithTimeout(t, testTimeout) + + srv, err := dhcpsvc.New(ctx, conf) + require.NoError(t, err) + + expiry, err := time.Parse(time.RFC3339, "2042-01-02T03:04:05Z") + require.NoError(t, err) + + wantLeases := []*dhcpsvc.Lease{{ + Expiry: expiry, + IP: netip.MustParseAddr("192.168.0.3"), + Hostname: "example.host", + HWAddr: mustParseMAC(t, "AA:AA:AA:AA:AA:AA"), + IsStatic: false, + }} + assert.Equal(t, wantLeases, srv.Leases()) +} diff --git a/internal/dhcpsvc/leaseindex.go b/internal/dhcpsvc/leaseindex.go index c9487b75..855d6b84 100644 --- a/internal/dhcpsvc/leaseindex.go +++ b/internal/dhcpsvc/leaseindex.go @@ -124,3 +124,18 @@ func (idx *leaseIndex) update(l *Lease, iface *netInterface) (err error) { return nil } + +// rangeLeases calls f for each lease in idx in an unspecified order until f +// returns false. +func (idx *leaseIndex) rangeLeases(f func(l *Lease) (cont bool)) { + for _, l := range idx.byName { + if !f(l) { + break + } + } +} + +// len returns the number of leases in idx. +func (idx *leaseIndex) len() (l uint) { + return uint(len(idx.byAddr)) +} diff --git a/internal/dhcpsvc/server.go b/internal/dhcpsvc/server.go index cd1e93b2..19e5c645 100644 --- a/internal/dhcpsvc/server.go +++ b/internal/dhcpsvc/server.go @@ -27,6 +27,12 @@ type DHCPServer struct { // hostnames. localTLD string + // dbFilePath is the path to the database file containing the DHCP leases. + // + // TODO(e.burkov): Perhaps, extract leases and database into a separate + // type. + dbFilePath string + // leasesMu protects the leases index as well as leases in the interfaces. leasesMu *sync.RWMutex @@ -93,9 +99,16 @@ func New(ctx context.Context, conf *Config) (srv *DHCPServer, err error) { interfaces4: ifaces4, interfaces6: ifaces6, icmpTimeout: conf.ICMPTimeout, + dbFilePath: conf.DBFilePath, } - // TODO(e.burkov): Load leases. + // TODO(e.burkov): !! migrate? + + err = srv.dbLoad(ctx) + if err != nil { + // Don't wrap the error since it's informative enough as is. + return nil, err + } return srv, nil } @@ -167,9 +180,26 @@ func (srv *DHCPServer) IPByHost(host string) (ip netip.Addr) { // Reset implements the [Interface] interface for *DHCPServer. func (srv *DHCPServer) Reset(ctx context.Context) (err error) { + defer func() { err = errors.Annotate(err, "resetting leases: %w") }() + srv.leasesMu.Lock() defer srv.leasesMu.Unlock() + srv.resetLeases() + err = srv.dbStore(ctx) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + + srv.logger.DebugContext(ctx, "reset leases") + + return nil +} + +// resetLeases resets the leases for all network interfaces of the server. It +// expects the DHCPServer.leasesMu to be locked. +func (srv *DHCPServer) resetLeases() { for _, iface := range srv.interfaces4 { iface.reset() } @@ -177,10 +207,6 @@ func (srv *DHCPServer) Reset(ctx context.Context) (err error) { iface.reset() } srv.leases.clear() - - srv.logger.DebugContext(ctx, "reset leases") - - return nil } // AddLease implements the [Interface] interface for *DHCPServer. @@ -203,6 +229,12 @@ func (srv *DHCPServer) AddLease(ctx context.Context, l *Lease) (err error) { return err } + err = srv.dbStore(ctx) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + iface.logger.DebugContext( ctx, "added lease", "hostname", l.Hostname, @@ -236,6 +268,12 @@ func (srv *DHCPServer) UpdateStaticLease(ctx context.Context, l *Lease) (err err return err } + err = srv.dbStore(ctx) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + iface.logger.DebugContext( ctx, "updated lease", "hostname", l.Hostname, @@ -267,6 +305,12 @@ func (srv *DHCPServer) RemoveLease(ctx context.Context, l *Lease) (err error) { return err } + err = srv.dbStore(ctx) + if err != nil { + // Don't wrap the error since there is already an annotation deferred. + return err + } + iface.logger.DebugContext( ctx, "removed lease", "hostname", l.Hostname, diff --git a/internal/dhcpsvc/server_test.go b/internal/dhcpsvc/server_test.go index 5f6f002f..3ce9c0a9 100644 --- a/internal/dhcpsvc/server_test.go +++ b/internal/dhcpsvc/server_test.go @@ -3,6 +3,7 @@ package dhcpsvc_test import ( "net" "net/netip" + "path/filepath" "strings" "testing" "time" @@ -14,6 +15,8 @@ import ( "github.com/stretchr/testify/require" ) +// TODO(e.burkov): !! Adjust tests and check the leases db. + // testLocalTLD is a common local TLD for tests. const testLocalTLD = "local" @@ -103,6 +106,8 @@ func TestNew(t *testing.T) { RASLAACOnly: true, } + leasesPath := filepath.Join(t.TempDir(), "leases.json") + testCases := []struct { conf *dhcpsvc.Config name string @@ -118,6 +123,7 @@ func TestNew(t *testing.T) { IPv6: validIPv6Conf, }, }, + DBFilePath: leasesPath, }, name: "valid", wantErrMsg: "", @@ -132,6 +138,7 @@ func TestNew(t *testing.T) { IPv6: &dhcpsvc.IPv6Config{Enabled: false}, }, }, + DBFilePath: leasesPath, }, name: "disabled_interfaces", wantErrMsg: "", @@ -146,6 +153,7 @@ func TestNew(t *testing.T) { IPv6: validIPv6Conf, }, }, + DBFilePath: leasesPath, }, name: "gateway_within_range", wantErrMsg: `interface "eth0": ipv4: ` + @@ -161,6 +169,7 @@ func TestNew(t *testing.T) { IPv6: validIPv6Conf, }, }, + DBFilePath: leasesPath, }, name: "bad_start", wantErrMsg: `interface "eth0": ipv4: ` + @@ -185,6 +194,7 @@ func TestDHCPServer_AddLease(t *testing.T) { Logger: discardLog, LocalDomainName: testLocalTLD, Interfaces: testInterfaceConf, + DBFilePath: filepath.Join(t.TempDir(), "leases.json"), }) require.NoError(t, err) @@ -290,6 +300,7 @@ func TestDHCPServer_index(t *testing.T) { Logger: discardLog, LocalDomainName: testLocalTLD, Interfaces: testInterfaceConf, + DBFilePath: filepath.Join(t.TempDir(), "leases.json"), }) require.NoError(t, err) @@ -368,6 +379,7 @@ func TestDHCPServer_UpdateStaticLease(t *testing.T) { Logger: discardLog, LocalDomainName: testLocalTLD, Interfaces: testInterfaceConf, + DBFilePath: filepath.Join(t.TempDir(), "leases.json"), }) require.NoError(t, err) @@ -491,6 +503,7 @@ func TestDHCPServer_RemoveLease(t *testing.T) { Logger: discardLog, LocalDomainName: testLocalTLD, Interfaces: testInterfaceConf, + DBFilePath: filepath.Join(t.TempDir(), "leases.json"), }) require.NoError(t, err) @@ -579,14 +592,17 @@ func TestDHCPServer_RemoveLease(t *testing.T) { } func TestDHCPServer_Reset(t *testing.T) { - ctx := testutil.ContextWithTimeout(t, testTimeout) - - srv, err := dhcpsvc.New(ctx, &dhcpsvc.Config{ + leasesPath := filepath.Join(t.TempDir(), "leases.json") + conf := &dhcpsvc.Config{ Enabled: true, Logger: discardLog, LocalDomainName: testLocalTLD, Interfaces: testInterfaceConf, - }) + DBFilePath: leasesPath, + } + + ctx := testutil.ContextWithTimeout(t, testTimeout) + srv, err := dhcpsvc.New(ctx, conf) require.NoError(t, err) leases := []*dhcpsvc.Lease{{ @@ -619,5 +635,6 @@ func TestDHCPServer_Reset(t *testing.T) { require.NoError(t, srv.Reset(ctx)) + assert.FileExists(t, leasesPath) assert.Empty(t, srv.Leases()) } diff --git a/internal/dhcpsvc/testdata/TestServer_loadDatabase/leases.json b/internal/dhcpsvc/testdata/TestServer_loadDatabase/leases.json new file mode 100644 index 00000000..069f3595 --- /dev/null +++ b/internal/dhcpsvc/testdata/TestServer_loadDatabase/leases.json @@ -0,0 +1,10 @@ +{ + "leases": [{ + "expires": "2042-01-02T03:04:05Z", + "ip": "192.168.0.3", + "hostname": "example.host", + "mac": "AA:AA:AA:AA:AA:AA", + "static": false + }], + "version": 1 +} diff --git a/internal/dhcpsvc/v4.go b/internal/dhcpsvc/v4.go index 09df8013..10624105 100644 --- a/internal/dhcpsvc/v4.go +++ b/internal/dhcpsvc/v4.go @@ -120,6 +120,7 @@ func newNetInterfaceV4( keyInterface, name, keyFamily, netutil.AddrFamilyIPv4, ) + if !conf.Enabled { l.DebugContext(ctx, "disabled")