diff --git a/internal/aghnet/ipset.go b/internal/aghnet/ipset.go new file mode 100644 index 00000000..422d4b4f --- /dev/null +++ b/internal/aghnet/ipset.go @@ -0,0 +1,26 @@ +package aghnet + +import ( + "net" +) + +// IpsetManager is the ipset manager interface. +// +// TODO(a.garipov): Perhaps generalize this into some kind of a NetFilter type, +// since ipset is exclusive to Linux? +type IpsetManager interface { + Add(host string, ip4s, ip6s []net.IP) (n int, err error) + Close() (err error) +} + +// NewIpsetManager returns a new ipset. IPv4 addresses are added to an ipset +// with an ipv4 family; IPv6 addresses, to an ipv6 ipset. ipset must exist. +// +// The syntax of the ipsetConf is: +// +// DOMAIN[,DOMAIN].../IPSET_NAME[,IPSET_NAME]... +// +// The error is of type *aghos.UnsupportedError if the OS is not supported. +func NewIpsetManager(ipsetConf []string) (mgr IpsetManager, err error) { + return newIpsetMgr(ipsetConf) +} diff --git a/internal/aghnet/ipset_linux.go b/internal/aghnet/ipset_linux.go new file mode 100644 index 00000000..095399a3 --- /dev/null +++ b/internal/aghnet/ipset_linux.go @@ -0,0 +1,376 @@ +//go:build linux +// +build linux + +package aghnet + +import ( + "fmt" + "net" + "strings" + "sync" + + "github.com/AdguardTeam/golibs/errors" + "github.com/digineo/go-ipset/v2" + "github.com/mdlayher/netlink" + "github.com/ti-mo/netfilter" +) + +// How to test on a real Linux machine: +// +// 1. Run: +// +// sudo ipset create example_set hash:ip family ipv4 +// +// 2. Run: +// +// sudo ipset list example_set +// +// The Members field should be empty. +// +// 3. Add the line "example.com/example_set" to your AdGuardHome.yaml. +// +// 4. Start AdGuardHome. +// +// 5. Make requests to example.com and its subdomains. +// +// 6. Run: +// +// sudo ipset list example_set +// +// The Members field should contain the resolved IP addresses. + +// newIpsetMgr returns a new Linux ipset manager. +func newIpsetMgr(ipsetConf []string) (set IpsetManager, err error) { + dial := func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error) { + return ipset.Dial(pf, conf) + } + + return newIpsetMgrWithDialer(ipsetConf, dial) +} + +// ipsetConn is the ipset conn interface. +type ipsetConn interface { + Add(name string, entries ...*ipset.Entry) (err error) + Close() (err error) + Header(name string) (p *ipset.HeaderPolicy, err error) +} + +// ipsetDialer creates an ipsetConn. +type ipsetDialer func(pf netfilter.ProtoFamily, conf *netlink.Config) (conn ipsetConn, err error) + +// ipsetProps contains one Linux Netfilter ipset properties. +type ipsetProps struct { + name string + family netfilter.ProtoFamily +} + +// unit is a convenient alias for struct{}. +type unit = struct{} + +// ipsetMgr is the Linux Netfilter ipset manager. +type ipsetMgr struct { + nameToIpset map[string]ipsetProps + domainToIpsets map[string][]ipsetProps + + dial ipsetDialer + + // mu protects all properties below. + mu *sync.Mutex + + // TODO(a.garipov): Currently, the ipset list is static, and we don't + // read the IPs already in sets, so we can assume that all incoming IPs + // are either added to all corresponding ipsets or not. When that stops + // being the case, for example if we add dynamic reconfiguration of + // ipsets, this map will need to become a per-ipset-name one. + addedIPs map[[16]byte]unit + + ipv4Conn ipsetConn + ipv6Conn ipsetConn +} + +// dialNetfilter establishes connections to Linux's netfilter module. +func (m *ipsetMgr) dialNetfilter(conf *netlink.Config) (err error) { + // The kernel API does not actually require two sockets but package + // github.com/digineo/go-ipset does. + // + // TODO(a.garipov): Perhaps we can ditch package ipset altogether and + // just use packages netfilter and netlink. + m.ipv4Conn, err = m.dial(netfilter.ProtoIPv4, conf) + if err != nil { + return fmt.Errorf("dialing v4: %w", err) + } + + m.ipv6Conn, err = m.dial(netfilter.ProtoIPv6, conf) + if err != nil { + return fmt.Errorf("dialing v6: %w", err) + } + + return nil +} + +// parseIpsetConfig parses one ipset configuration string. +func parseIpsetConfig(confStr string) (hosts, ipsetNames []string, err error) { + confStr = strings.TrimSpace(confStr) + hostsAndNames := strings.Split(confStr, "/") + if len(hostsAndNames) != 2 { + return nil, nil, fmt.Errorf("invalid value %q: expected one slash", confStr) + } + + hosts = strings.Split(hostsAndNames[0], ",") + ipsetNames = strings.Split(hostsAndNames[1], ",") + + if len(ipsetNames) == 0 { + return nil, nil, nil + } + + for i := range ipsetNames { + ipsetNames[i] = strings.TrimSpace(ipsetNames[i]) + if len(ipsetNames[i]) == 0 { + return nil, nil, fmt.Errorf("invalid value %q: empty ipset name", confStr) + } + } + + for i := range hosts { + hosts[i] = strings.ToLower(strings.TrimSpace(hosts[i])) + } + + return hosts, ipsetNames, nil +} + +// ipsetProps returns the properties of an ipset with the given name. +func (m *ipsetMgr) ipsetProps(name string) (set ipsetProps, err error) { + // The family doesn't seem to matter when we use a header query, so + // query only the IPv4 one. + // + // TODO(a.garipov): Find out if this is a bug or a feature. + var res *ipset.HeaderPolicy + res, err = m.ipv4Conn.Header(name) + if err != nil { + return set, err + } + + if res == nil || res.Family == nil { + return set, errors.Error("empty response or no family data") + } + + family := netfilter.ProtoFamily(res.Family.Value) + if family != netfilter.ProtoIPv4 && family != netfilter.ProtoIPv6 { + return set, fmt.Errorf("unexpected ipset family %d", family) + } + + return ipsetProps{ + name: name, + family: family, + }, nil +} + +// ipsets returns currently known ipsets. +func (m *ipsetMgr) ipsets(names []string) (sets []ipsetProps, err error) { + for _, name := range names { + set, ok := m.nameToIpset[name] + if ok { + sets = append(sets, set) + + continue + } + + set, err = m.ipsetProps(name) + if err != nil { + return nil, fmt.Errorf("querying ipset %q: %w", name, err) + } + + m.nameToIpset[name] = set + sets = append(sets, set) + } + + return sets, nil +} + +// newIpsetMgrWithDialer returns a new Linux ipset manager using the provided +// dialer. +func newIpsetMgrWithDialer(ipsetConf []string, dial ipsetDialer) (mgr IpsetManager, err error) { + defer func() { err = errors.Annotate(err, "ipset: %w") }() + + m := &ipsetMgr{ + mu: &sync.Mutex{}, + + nameToIpset: make(map[string]ipsetProps), + domainToIpsets: make(map[string][]ipsetProps), + + dial: dial, + + addedIPs: make(map[[16]byte]unit), + } + + err = m.dialNetfilter(&netlink.Config{}) + if err != nil { + return nil, fmt.Errorf("dialing netfilter: %w", err) + } + + for i, confStr := range ipsetConf { + var hosts, ipsetNames []string + hosts, ipsetNames, err = parseIpsetConfig(confStr) + if err != nil { + return nil, fmt.Errorf("config line at idx %d: %w", i, err) + } + + var ipsets []ipsetProps + ipsets, err = m.ipsets(ipsetNames) + if err != nil { + return nil, fmt.Errorf( + "getting ipsets from config line at idx %d: %w", + i, + err, + ) + } + + for _, host := range hosts { + m.domainToIpsets[host] = append(m.domainToIpsets[host], ipsets...) + } + } + + return m, nil +} + +// lookupHost find the ipsets for the host, taking subdomain wildcards into +// account. +func (m *ipsetMgr) lookupHost(host string) (sets []ipsetProps) { + // Search for matching ipset hosts starting with most specific domain. + // We could use a trie here but the simple, inefficient solution isn't + // that expensive: ~10 ns for TLD + SLD vs. ~140 ns for 10 subdomains on + // an AMD Ryzen 7 PRO 4750U CPU; ~120 ns vs. ~ 1500 ns on a Raspberry + // Pi's ARMv7 rev 4 CPU. + for i := 0; ; i++ { + host = host[i:] + sets = m.domainToIpsets[host] + if sets != nil { + return sets + } + + i = strings.Index(host, ".") + if i == -1 { + break + } + } + + // Check the root catch-all one. + return m.domainToIpsets[""] +} + +// addIPs adds the IP addresses for the host to the ipset. set must be same +// family as set's family. +func (m *ipsetMgr) addIPs(host string, set ipsetProps, ips []net.IP) (n int, err error) { + if len(ips) == 0 { + return 0, nil + } + + var entries []*ipset.Entry + var newAddedIPs [][16]byte + for _, ip := range ips { + var iparr [16]byte + copy(iparr[:], ip.To16()) + if _, added := m.addedIPs[iparr]; added { + continue + } + + entries = append(entries, ipset.NewEntry(ipset.EntryIP(ip))) + newAddedIPs = append(newAddedIPs, iparr) + } + + n = len(entries) + if n == 0 { + return 0, nil + } + + var conn ipsetConn + switch set.family { + case netfilter.ProtoIPv4: + conn = m.ipv4Conn + case netfilter.ProtoIPv6: + conn = m.ipv6Conn + default: + return 0, fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name) + } + + err = conn.Add(set.name, entries...) + if err != nil { + return 0, fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err) + } + + // Only add these to the cache once we're sure that all of them were + // actually sent to the ipset. + for _, iparr := range newAddedIPs { + m.addedIPs[iparr] = unit{} + } + + return n, nil +} + +// addToSets adds the IP addresses to the corresponding ipset. +func (m *ipsetMgr) addToSets( + host string, + ip4s []net.IP, + ip6s []net.IP, + sets []ipsetProps, +) (n int, err error) { + for _, set := range sets { + var nn int + switch set.family { + case netfilter.ProtoIPv4: + nn, err = m.addIPs(host, set, ip4s) + if err != nil { + return n, err + } + case netfilter.ProtoIPv6: + nn, err = m.addIPs(host, set, ip6s) + if err != nil { + return n, err + } + default: + return n, fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name) + } + + n += nn + } + + return n, nil +} + +// Add implements the IpsetManager interface for *ipsetMgr +func (m *ipsetMgr) Add(host string, ip4s, ip6s []net.IP) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + + sets := m.lookupHost(host) + if len(sets) == 0 { + return 0, nil + } + + return m.addToSets(host, ip4s, ip6s, sets) +} + +// Close implements the IpsetManager interface for *ipsetMgr. +func (m *ipsetMgr) Close() (err error) { + m.mu.Lock() + defer m.mu.Unlock() + + var errs []error + + // Close both and collect errors so that the errors from closing one + // don't interfere with closing the other. + err = m.ipv4Conn.Close() + if err != nil { + errs = append(errs, err) + } + + err = m.ipv6Conn.Close() + if err != nil { + errs = append(errs, err) + } + + if len(errs) != 0 { + return errors.List("closing ipsets", errs...) + } + + return nil +} diff --git a/internal/aghnet/ipset_linux_test.go b/internal/aghnet/ipset_linux_test.go new file mode 100644 index 00000000..12c842a0 --- /dev/null +++ b/internal/aghnet/ipset_linux_test.go @@ -0,0 +1,155 @@ +//go:build linux +// +build linux + +package aghnet + +import ( + "net" + "strings" + "testing" + + "github.com/AdguardTeam/golibs/errors" + "github.com/digineo/go-ipset/v2" + "github.com/mdlayher/netlink" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/ti-mo/netfilter" +) + +// fakeIpsetConn is a fake ipsetConn for tests. +type fakeIpsetConn struct { + ipv4Header *ipset.HeaderPolicy + ipv4Entries *[]*ipset.Entry + ipv6Header *ipset.HeaderPolicy + ipv6Entries *[]*ipset.Entry +} + +// Add implements the ipsetConn interface for *fakeIpsetConn. +func (c *fakeIpsetConn) Add(name string, entries ...*ipset.Entry) (err error) { + if strings.Contains(name, "ipv4") { + *c.ipv4Entries = append(*c.ipv4Entries, entries...) + + return nil + } else if strings.Contains(name, "ipv6") { + *c.ipv6Entries = append(*c.ipv6Entries, entries...) + + return nil + } + + return errors.Error("test: ipset not found") +} + +// Close implements the ipsetConn interface for *fakeIpsetConn. +func (c *fakeIpsetConn) Close() (err error) { + return nil +} + +// Header implements the ipsetConn interface for *fakeIpsetConn. +func (c *fakeIpsetConn) Header(name string) (p *ipset.HeaderPolicy, err error) { + if strings.Contains(name, "ipv4") { + return c.ipv4Header, nil + } else if strings.Contains(name, "ipv6") { + return c.ipv6Header, nil + } + + return nil, errors.Error("test: ipset not found") +} + +func TestIpsetMgr_Add(t *testing.T) { + ipsetConf := []string{ + "example.com,example.net/ipv4set", + "example.org,example.biz/ipv6set", + } + + var ipv4Entries []*ipset.Entry + var ipv6Entries []*ipset.Entry + + fakeDial := func( + pf netfilter.ProtoFamily, + conf *netlink.Config, + ) (conn ipsetConn, err error) { + return &fakeIpsetConn{ + ipv4Header: &ipset.HeaderPolicy{ + Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv4)), + }, + ipv4Entries: &ipv4Entries, + ipv6Header: &ipset.HeaderPolicy{ + Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv6)), + }, + ipv6Entries: &ipv6Entries, + }, nil + } + + m, err := newIpsetMgrWithDialer(ipsetConf, fakeDial) + require.NoError(t, err) + + ip4 := net.IP{1, 2, 3, 4} + ip6 := net.IP{ + 0x12, 0x34, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x56, 0x78, + } + + n, err := m.Add("example.net", []net.IP{ip4}, nil) + require.NoError(t, err) + + assert.Equal(t, 1, n) + + require.Len(t, ipv4Entries, 1) + + gotIP4 := ipv4Entries[0].IP.Value + assert.Equal(t, ip4, gotIP4) + + n, err = m.Add("example.biz", nil, []net.IP{ip6}) + require.NoError(t, err) + + assert.Equal(t, 1, n) + + require.Len(t, ipv6Entries, 1) + + gotIP6 := ipv6Entries[0].IP.Value + assert.Equal(t, ip6, gotIP6) + + err = m.Close() + assert.NoError(t, err) +} + +var ipsetPropsSink []ipsetProps + +func BenchmarkIpsetMgr_lookupHost(b *testing.B) { + propsLong := []ipsetProps{{ + name: "example.com", + family: netfilter.ProtoIPv4, + }} + + propsShort := []ipsetProps{{ + name: "example.net", + family: netfilter.ProtoIPv4, + }} + + m := &ipsetMgr{ + domainToIpsets: map[string][]ipsetProps{ + "": propsLong, + "example.net": propsShort, + }, + } + + b.Run("long", func(b *testing.B) { + const name = "a.very.long.domain.name.inside.the.domain.example.com" + for i := 0; i < b.N; i++ { + ipsetPropsSink = m.lookupHost(name) + } + + require.Equal(b, propsLong, ipsetPropsSink) + }) + + b.Run("short", func(b *testing.B) { + const name = "example.net" + for i := 0; i < b.N; i++ { + ipsetPropsSink = m.lookupHost(name) + } + + require.Equal(b, propsShort, ipsetPropsSink) + }) +} diff --git a/internal/aghnet/ipset_others.go b/internal/aghnet/ipset_others.go new file mode 100644 index 00000000..814c35be --- /dev/null +++ b/internal/aghnet/ipset_others.go @@ -0,0 +1,12 @@ +//go:build !linux +// +build !linux + +package aghnet + +import ( + "github.com/AdguardTeam/AdGuardHome/internal/aghos" +) + +func newIpsetMgr(_ []string) (mgr IpsetManager, err error) { + return nil, aghos.Unsupported("ipset") +} diff --git a/internal/aghos/os.go b/internal/aghos/os.go index 6807ad69..e3e2fb5e 100644 --- a/internal/aghos/os.go +++ b/internal/aghos/os.go @@ -1,4 +1,6 @@ -// Package aghos contains utilities for functions requiring system calls. +// Package aghos contains utilities for functions requiring system calls and +// other OS-specific APIs. OS-specific network handling should go to aghnet +// instead. package aghos import ( diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 3b28c236..49494422 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -111,10 +111,12 @@ type FilteringConfig struct { EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option MaxGoroutines uint32 `yaml:"max_goroutines"` // Max. number of parallel goroutines for processing incoming requests - // IPSET configuration - add IP addresses of the specified domain names to an ipset list - // Syntax: - // "DOMAIN[,DOMAIN].../IPSET_NAME" - IPSETList []string `yaml:"ipset"` + // IpsetList is the ipset configuration that allows AdGuard Home to add + // IP addresses of the specified domain names to an ipset list. Syntax: + // + // DOMAIN[,DOMAIN].../IPSET_NAME + // + IpsetList []string `yaml:"ipset"` } // TLSConfig is the TLS configuration for HTTPS, DNS-over-HTTPS, and DNS-over-TLS diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index d1d30d42..616ab4f8 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "net/http" - "os" "runtime" "strings" "sync" @@ -203,7 +202,7 @@ func (s *Server) Close() { s.queryLog = nil s.dnsProxy = nil - if err := s.ipset.Close(); err != nil { + if err := s.ipset.close(); err != nil { log.Error("closing ipset: %s", err) } } @@ -451,26 +450,15 @@ func (s *Server) Prepare(config *ServerConfig) error { // -- s.initDefaultSettings() - // Initialize IPSET configuration + // Initialize ipset configuration // -- - err := s.ipset.init(s.conf.IPSETList) + err := s.ipset.init(s.conf.IpsetList) if err != nil { - if !errors.Is(err, os.ErrInvalid) && !errors.Is(err, os.ErrPermission) { - return fmt.Errorf("cannot initialize ipset: %w", err) - } - - // ipset cannot currently be initialized if the server was - // installed from Snap or when the user or the binary doesn't - // have the required permissions, or when the kernel doesn't - // support netfilter. - // - // Log and go on. - // - // TODO(a.garipov): The Snap problem can probably be solved if - // we add the netlink-connector interface plug. - log.Info("warning: cannot initialize ipset: %s", err) + return err } + log.Debug("inited ipset") + // Prepare DNS servers settings // -- err = s.prepareUpstreamSettings() diff --git a/internal/dnsforward/ipset.go b/internal/dnsforward/ipset.go new file mode 100644 index 00000000..2b52be85 --- /dev/null +++ b/internal/dnsforward/ipset.go @@ -0,0 +1,137 @@ +package dnsforward + +import ( + "fmt" + "net" + "os" + "strings" + + "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/log" + "github.com/miekg/dns" +) + +// ipsetCtx is the ipset context. ipsetMgr can be nil. +type ipsetCtx struct { + ipsetMgr aghnet.IpsetManager +} + +// init initializes the ipset context. It is not safe for concurrent use. +// +// TODO(a.garipov): Rewrite into a simple constructor? +func (c *ipsetCtx) init(ipsetConf []string) (err error) { + c.ipsetMgr, err = aghnet.NewIpsetManager(ipsetConf) + if errors.Is(err, os.ErrInvalid) || errors.Is(err, os.ErrPermission) { + // ipset cannot currently be initialized if the server was + // installed from Snap or when the user or the binary doesn't + // have the required permissions, or when the kernel doesn't + // support netfilter. + // + // Log and go on. + // + // TODO(a.garipov): The Snap problem can probably be solved if + // we add the netlink-connector interface plug. + log.Info("warning: cannot initialize ipset: %s", err) + + return nil + } else if unsupErr := (&aghos.UnsupportedError{}); errors.As(err, &unsupErr) { + log.Info("warning: %s", err) + + return nil + } else if err != nil { + return fmt.Errorf("initializing ipset: %w", err) + } + + return nil +} + +// close closes the Linux Netfilter connections. +func (c *ipsetCtx) close() (err error) { + if c.ipsetMgr != nil { + return c.ipsetMgr.Close() + } + + return nil +} + +func (c *ipsetCtx) dctxIsfilled(dctx *dnsContext) (ok bool) { + return dctx != nil && + dctx.responseFromUpstream && + dctx.proxyCtx != nil && + dctx.proxyCtx.Res != nil && + dctx.proxyCtx.Req != nil && + len(dctx.proxyCtx.Req.Question) > 0 +} + +// skipIpsetProcessing returns true when the ipset processing can be skipped for +// this request. +func (c *ipsetCtx) skipIpsetProcessing(dctx *dnsContext) (ok bool) { + if c == nil || c.ipsetMgr == nil || !c.dctxIsfilled(dctx) { + return true + } + + qtype := dctx.proxyCtx.Req.Question[0].Qtype + + return qtype != dns.TypeA && qtype != dns.TypeAAAA && qtype != dns.TypeANY +} + +// ipFromRR returns an IP address from a DNS resource record. +func ipFromRR(rr dns.RR) (ip net.IP) { + switch a := rr.(type) { + case *dns.A: + return a.A + case *dns.AAAA: + return a.AAAA + default: + return nil + } +} + +// ipsFromAnswer returns IPv4 and IPv6 addresses from a DNS answer. +func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) { + for _, rr := range ans { + ip := ipFromRR(rr) + if ip == nil { + continue + } + + if ip.To4() == nil { + ip6s = append(ip6s, ip) + + continue + } + + ip4s = append(ip4s, ip) + } + + return ip4s, ip6s +} + +// process adds the resolved IP addresses to the domain's ipsets, if any. +func (c *ipsetCtx) process(dctx *dnsContext) (rc resultCode) { + if c.skipIpsetProcessing(dctx) { + return resultCodeSuccess + } + + log.Debug("ipset: starting processing") + + req := dctx.proxyCtx.Req + host := req.Question[0].Name + host = strings.TrimSuffix(host, ".") + host = strings.ToLower(host) + + ip4s, ip6s := ipsFromAnswer(dctx.proxyCtx.Res.Answer) + n, err := c.ipsetMgr.Add(host, ip4s, ip6s) + if err != nil { + // Consider ipset errors non-critical to the request. + log.Error("ipset: adding host ips: %s", err) + + return resultCodeSuccess + } + + log.Debug("ipset: added %d new ips", n) + + return resultCodeSuccess +} diff --git a/internal/dnsforward/ipset_linux.go b/internal/dnsforward/ipset_linux.go deleted file mode 100644 index e94c87c2..00000000 --- a/internal/dnsforward/ipset_linux.go +++ /dev/null @@ -1,401 +0,0 @@ -//go:build linux -// +build linux - -package dnsforward - -import ( - "fmt" - "net" - "strings" - "sync" - - "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" - "github.com/digineo/go-ipset/v2" - "github.com/mdlayher/netlink" - "github.com/miekg/dns" - "github.com/ti-mo/netfilter" -) - -// TODO(a.garipov): Cover with unit tests as well as document how to test it -// manually. The original PR by @dsheets on Github contained an integration -// test, but unfortunately I didn't have the time to properly refactor it and -// check it in. -// -// See https://github.com/AdguardTeam/AdGuardHome/issues/2611. - -// ipsetProps contains one Linux Netfilter ipset properties. -type ipsetProps struct { - name string - family netfilter.ProtoFamily -} - -// ipsetCtx is the Linux Netfilter ipset context. -type ipsetCtx struct { - // mu protects all properties below. - mu *sync.Mutex - - nameToIpset map[string]ipsetProps - domainToIpsets map[string][]ipsetProps - - // TODO(a.garipov): Currently, the ipset list is static, and we don't - // read the IPs already in sets, so we can assume that all incoming IPs - // are either added to all corresponding ipsets or not. When that stops - // being the case, for example if we add dynamic reconfiguration of - // ipsets, this map will need to become a per-ipset-name one. - addedIPs map[[16]byte]struct{} - - ipv4Conn *ipset.Conn - ipv6Conn *ipset.Conn -} - -// dialNetfilter establishes connections to Linux's netfilter module. -func (c *ipsetCtx) dialNetfilter(config *netlink.Config) (err error) { - // The kernel API does not actually require two sockets but package - // github.com/digineo/go-ipset does. - // - // TODO(a.garipov): Perhaps we can ditch package ipset altogether and - // just use packages netfilter and netlink. - c.ipv4Conn, err = ipset.Dial(netfilter.ProtoIPv4, config) - if err != nil { - return fmt.Errorf("dialing v4: %w", err) - } - - c.ipv6Conn, err = ipset.Dial(netfilter.ProtoIPv6, config) - if err != nil { - return fmt.Errorf("dialing v6: %w", err) - } - - return nil -} - -// ipsetProps returns the properties of an ipset with the given name. -func (c *ipsetCtx) ipsetProps(name string) (set ipsetProps, err error) { - // The family doesn't seem to matter when we use a header query, so - // query only the IPv4 one. - // - // TODO(a.garipov): Find out if this is a bug or a feature. - var res *ipset.HeaderPolicy - res, err = c.ipv4Conn.Header(name) - if err != nil { - return set, err - } - - if res == nil || res.Family == nil { - return set, errors.Error("empty response or no family data") - } - - family := netfilter.ProtoFamily(res.Family.Value) - if family != netfilter.ProtoIPv4 && family != netfilter.ProtoIPv6 { - return set, fmt.Errorf("unexpected ipset family %s", family) - } - - return ipsetProps{ - name: name, - family: family, - }, nil -} - -// ipsets returns currently known ipsets. -func (c *ipsetCtx) ipsets(names []string) (sets []ipsetProps, err error) { - for _, name := range names { - set, ok := c.nameToIpset[name] - if ok { - sets = append(sets, set) - - continue - } - - set, err = c.ipsetProps(name) - if err != nil { - return nil, fmt.Errorf("querying ipset %q: %w", name, err) - } - - c.nameToIpset[name] = set - sets = append(sets, set) - } - - return sets, nil -} - -// parseIpsetConfig parses one ipset configuration string. -func parseIpsetConfig(cfgStr string) (hosts, ipsetNames []string, err error) { - cfgStr = strings.TrimSpace(cfgStr) - hostsAndNames := strings.Split(cfgStr, "/") - if len(hostsAndNames) != 2 { - return nil, nil, fmt.Errorf("invalid value %q: expected one slash", cfgStr) - } - - hosts = strings.Split(hostsAndNames[0], ",") - ipsetNames = strings.Split(hostsAndNames[1], ",") - - if len(ipsetNames) == 0 { - log.Info("ipset: resolutions for %q will not be stored", hosts) - - return nil, nil, nil - } - - for i := range ipsetNames { - ipsetNames[i] = strings.TrimSpace(ipsetNames[i]) - if len(ipsetNames[i]) == 0 { - return nil, nil, fmt.Errorf("invalid value %q: empty ipset name", cfgStr) - } - } - - for i := range hosts { - hosts[i] = strings.TrimSpace(hosts[i]) - hosts[i] = strings.ToLower(hosts[i]) - if len(hosts[i]) == 0 { - log.Info("ipset: root catchall in %q", ipsetNames) - } - } - - return hosts, ipsetNames, nil -} - -// init initializes the ipset context. It is not safe for concurrent use. -// -// TODO(a.garipov): Rewrite into a simple constructor? -func (c *ipsetCtx) init(ipsetConfig []string) (err error) { - c.mu = &sync.Mutex{} - c.nameToIpset = make(map[string]ipsetProps) - c.domainToIpsets = make(map[string][]ipsetProps) - c.addedIPs = make(map[[16]byte]struct{}) - - err = c.dialNetfilter(&netlink.Config{}) - if err != nil { - return fmt.Errorf("ipset: dialing netfilter: %w", err) - } - - for i, cfgStr := range ipsetConfig { - var hosts, ipsetNames []string - hosts, ipsetNames, err = parseIpsetConfig(cfgStr) - if err != nil { - return fmt.Errorf("ipset: config line at index %d: %w", i, err) - } - - var ipsets []ipsetProps - ipsets, err = c.ipsets(ipsetNames) - if err != nil { - return fmt.Errorf("ipset: getting ipsets config line at index %d: %w", i, err) - } - - for _, host := range hosts { - c.domainToIpsets[host] = append(c.domainToIpsets[host], ipsets...) - } - } - - log.Debug("ipset: added %d domains for %d ipsets", len(c.domainToIpsets), len(c.nameToIpset)) - - return nil -} - -// Close closes the Linux Netfilter connections. -func (c *ipsetCtx) Close() (err error) { - var errs []error - if c.ipv4Conn != nil { - err = c.ipv4Conn.Close() - if err != nil { - errs = append(errs, err) - } - } - - if c.ipv6Conn != nil { - err = c.ipv6Conn.Close() - if err != nil { - errs = append(errs, err) - } - } - - if len(errs) != 0 { - return errors.List("closing ipsets", errs...) - } - - return nil -} - -// ipFromRR returns an IP address from a DNS resource record. -func ipFromRR(rr dns.RR) (ip net.IP) { - switch a := rr.(type) { - case *dns.A: - return a.A - case *dns.AAAA: - return a.AAAA - default: - return nil - } -} - -// lookupHost find the ipsets for the host, taking subdomain wildcards into -// account. -func (c *ipsetCtx) lookupHost(host string) (sets []ipsetProps) { - // Search for matching ipset hosts starting with most specific - // subdomain. We could use a trie here but the simple, inefficient - // solution isn't that expensive. ~75 % for 10 subdomains vs 0, but - // still sub-microsecond on a Core i7. - // - // TODO(a.garipov): Re-add benchmarks from the original PR. - for i := 0; i != -1; i++ { - host = host[i:] - sets = c.domainToIpsets[host] - if sets != nil { - return sets - } - - i = strings.Index(host, ".") - if i == -1 { - break - } - } - - // Check the root catch-all one. - return c.domainToIpsets[""] -} - -// addIPs adds the IP addresses for the host to the ipset. set must be same -// family as set's family. -func (c *ipsetCtx) addIPs(host string, set ipsetProps, ips []net.IP) (err error) { - if len(ips) == 0 { - return - } - - entries := make([]*ipset.Entry, 0, len(ips)) - for _, ip := range ips { - entries = append(entries, ipset.NewEntry(ipset.EntryIP(ip))) - } - - var conn *ipset.Conn - switch set.family { - case netfilter.ProtoIPv4: - conn = c.ipv4Conn - case netfilter.ProtoIPv6: - conn = c.ipv6Conn - default: - return fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name) - } - - err = conn.Add(set.name, entries...) - if err != nil { - return fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err) - } - - log.Debug("ipset: added %s%s to ipset %s", host, ips, set.name) - - return nil -} - -// skipIpsetProcessing returns true when the ipset processing can be skipped for -// this request. -func (c *ipsetCtx) skipIpsetProcessing(ctx *dnsContext) (ok bool) { - if len(c.domainToIpsets) == 0 || ctx == nil || !ctx.responseFromUpstream { - return true - } - - req := ctx.proxyCtx.Req - if req == nil || len(req.Question) == 0 { - return true - } - - qt := req.Question[0].Qtype - return qt != dns.TypeA && qt != dns.TypeAAAA && qt != dns.TypeANY -} - -// process adds the resolved IP addresses to the domain's ipsets, if any. -func (c *ipsetCtx) process(ctx *dnsContext) (rc resultCode) { - var err error - - if c == nil { - return resultCodeSuccess - } - - log.Debug("ipset: starting processing") - - c.mu.Lock() - defer c.mu.Unlock() - - if c.skipIpsetProcessing(ctx) { - log.Debug("ipset: skipped processing for request") - - return resultCodeSuccess - } - - req := ctx.proxyCtx.Req - host := req.Question[0].Name - host = strings.TrimSuffix(host, ".") - host = strings.ToLower(host) - sets := c.lookupHost(host) - if len(sets) == 0 { - log.Debug("ipset: no ipsets for host %s", host) - - return resultCodeSuccess - } - - log.Debug("ipset: found ipsets %+v for host %s", sets, host) - - if ctx.proxyCtx.Res == nil { - return resultCodeSuccess - } - - ans := ctx.proxyCtx.Res.Answer - l := len(ans) - v4s := make([]net.IP, 0, l) - v6s := make([]net.IP, 0, l) - for _, rr := range ans { - ip := ipFromRR(rr) - if ip == nil { - continue - } - - var iparr [16]byte - copy(iparr[:], ip.To16()) - if _, added := c.addedIPs[iparr]; added { - continue - } - - if ip.To4() == nil { - v6s = append(v6s, ip) - - continue - } - - v4s = append(v4s, ip) - } - -setLoop: - for _, set := range sets { - switch set.family { - case netfilter.ProtoIPv4: - err = c.addIPs(host, set, v4s) - if err != nil { - break setLoop - } - case netfilter.ProtoIPv6: - err = c.addIPs(host, set, v6s) - if err != nil { - break setLoop - } - default: - err = fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name) - break setLoop - } - } - if err != nil { - log.Error("ipset: adding host ips: %s", err) - } else { - log.Debug("ipset: processed %d new ips", len(v4s)+len(v6s)) - } - - for _, ip := range v4s { - var iparr [16]byte - copy(iparr[:], ip.To16()) - c.addedIPs[iparr] = struct{}{} - } - - for _, ip := range v6s { - var iparr [16]byte - copy(iparr[:], ip.To16()) - c.addedIPs[iparr] = struct{}{} - } - - return resultCodeSuccess -} diff --git a/internal/dnsforward/ipset_others.go b/internal/dnsforward/ipset_others.go deleted file mode 100644 index fad0b341..00000000 --- a/internal/dnsforward/ipset_others.go +++ /dev/null @@ -1,27 +0,0 @@ -//go:build !linux -// +build !linux - -package dnsforward - -import ( - "github.com/AdguardTeam/golibs/log" -) - -type ipsetCtx struct{} - -// init initializes the ipset context. -func (c *ipsetCtx) init(ipsetConfig []string) (err error) { - if len(ipsetConfig) != 0 { - log.Info("ipset: only available on linux") - } - - return nil -} - -// process adds the resolved IP addresses to the domain's ipsets, if any. -func (c *ipsetCtx) process(_ *dnsContext) (rc resultCode) { - return resultCodeSuccess -} - -// Close closes the Linux Netfilter connections. -func (c *ipsetCtx) Close() (_ error) { return nil } diff --git a/internal/dnsforward/ipset_test.go b/internal/dnsforward/ipset_test.go new file mode 100644 index 00000000..a46deec1 --- /dev/null +++ b/internal/dnsforward/ipset_test.go @@ -0,0 +1,116 @@ +package dnsforward + +import ( + "net" + "testing" + + "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" +) + +// fakeIpsetMgr is a fake aghnet.IpsetManager for tests. +type fakeIpsetMgr struct { + ip4s []net.IP + ip6s []net.IP +} + +// Add implements the aghnet.IpsetManager inteface for *fakeIpsetMgr. +func (m *fakeIpsetMgr) Add(host string, ip4s, ip6s []net.IP) (n int, err error) { + m.ip4s = append(m.ip4s, ip4s...) + m.ip6s = append(m.ip6s, ip6s...) + + return len(ip4s) + len(ip6s), nil +} + +// Close implements the aghnet.IpsetManager interface for *fakeIpsetMgr. +func (*fakeIpsetMgr) Close() (err error) { + return nil +} + +func TestIpsetCtx_process(t *testing.T) { + ip4 := net.IP{1, 2, 3, 4} + ip6 := net.IP{ + 0x12, 0x34, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x56, 0x78, + } + + req4 := createTestMessageWithType("example.com", dns.TypeA) + req6 := createTestMessageWithType("example.com", dns.TypeAAAA) + + resp4 := &dns.Msg{ + Answer: []dns.RR{&dns.A{ + A: ip4, + }}, + } + resp6 := &dns.Msg{ + Answer: []dns.RR{&dns.AAAA{ + AAAA: ip6, + }}, + } + + t.Run("nil", func(t *testing.T) { + dctx := &dnsContext{ + proxyCtx: &proxy.DNSContext{}, + + responseFromUpstream: true, + } + + ictx := &ipsetCtx{} + rc := ictx.process(dctx) + assert.Equal(t, resultCodeSuccess, rc) + + err := ictx.close() + assert.NoError(t, err) + }) + + t.Run("ipv4", func(t *testing.T) { + dctx := &dnsContext{ + proxyCtx: &proxy.DNSContext{ + Req: req4, + Res: resp4, + }, + + responseFromUpstream: true, + } + + m := &fakeIpsetMgr{} + ictx := &ipsetCtx{ + ipsetMgr: m, + } + + rc := ictx.process(dctx) + assert.Equal(t, resultCodeSuccess, rc) + assert.Equal(t, []net.IP{ip4}, m.ip4s) + assert.Empty(t, m.ip6s) + + err := ictx.close() + assert.NoError(t, err) + }) + + t.Run("ipv6", func(t *testing.T) { + dctx := &dnsContext{ + proxyCtx: &proxy.DNSContext{ + Req: req6, + Res: resp6, + }, + + responseFromUpstream: true, + } + + m := &fakeIpsetMgr{} + ictx := &ipsetCtx{ + ipsetMgr: m, + } + + rc := ictx.process(dctx) + assert.Equal(t, resultCodeSuccess, rc) + assert.Empty(t, m.ip4s) + assert.Equal(t, []net.IP{ip6}, m.ip6s) + + err := ictx.close() + assert.NoError(t, err) + }) +}