diff --git a/internal/ipset/ipset_linux.go b/internal/ipset/ipset_linux.go index 68f53dcb..6bc9add1 100644 --- a/internal/ipset/ipset_linux.go +++ b/internal/ipset/ipset_linux.go @@ -101,6 +101,7 @@ func (qc *queryConn) listAll() (sets []props, err error) { type ipsetConn interface { Add(name string, entries ...*ipset.Entry) (err error) Close() (err error) + Header(name string) (p *ipset.HeaderPolicy, err error) listAll() (sets []props, err error) } @@ -112,6 +113,9 @@ type props struct { // name of the ipset. name string + // typeName of the ipset. + typeName string + // family of the IP addresses in the ipset. family netfilter.ProtoFamily @@ -148,6 +152,8 @@ func (p *props) parseAttribute(a netfilter.Attribute) { case ipset.AttrSetName: // Trim the null character. p.name = string(bytes.Trim(a.Data, "\x00")) + case ipset.AttrTypeName: + p.typeName = string(bytes.Trim(a.Data, "\x00")) case ipset.AttrFamily: p.family = netfilter.ProtoFamily(a.Data[0]) default: @@ -288,6 +294,34 @@ func (m *manager) parseIpsetConfig(ipsetConf []string) (err error) { return nil } +// ipsetProps returns the properties of an ipset with the given name. +// +// Additional header data query. See https://github.com/AdguardTeam/AdGuardHome/issues/6420. +func (m *manager) ipsetProps(p props) (err error) { + // The family doesn't seem to matter when we use a header query, so + // query only the IPv4 one. + // + // TODO(a.garipov): Find out if this is a bug or a feature. + var res *ipset.HeaderPolicy + res, err = m.ipv4Conn.Header(p.name) + if err != nil { + return err + } + + if res == nil || res.Family == nil { + return errors.Error("empty response or no family data") + } + + family := netfilter.ProtoFamily(res.Family.Value) + if family != netfilter.ProtoIPv4 && family != netfilter.ProtoIPv6 { + return fmt.Errorf("unexpected ipset family %q", family) + } + + p.family = family + + return nil +} + // ipsets returns currently known ipsets. func (m *manager) ipsets(names []string) (sets []props, err error) { for _, n := range names { @@ -297,7 +331,16 @@ func (m *manager) ipsets(names []string) (sets []props, err error) { } if p.family != netfilter.ProtoIPv4 && p.family != netfilter.ProtoIPv6 { - return nil, fmt.Errorf("%q unexpected ipset family %q", p.name, p.family) + log.Debug("ipset: getting properties: %q %q unexpected ipset family %q", + p.name, + p.typeName, + p.family, + ) + + err = m.ipsetProps(p) + if err != nil { + return nil, fmt.Errorf("%q %q making header query: %w", p.name, p.typeName, err) + } } sets = append(sets, p) @@ -340,6 +383,8 @@ func newManagerWithDialer(ipsetConf []string, dial dialer) (mgr Manager, err err return nil, fmt.Errorf("getting ipsets: %w", err) } + log.Debug("ipset: initialized") + return m, nil } @@ -408,7 +453,7 @@ func (m *manager) addIPs(host string, set props, ips []net.IP) (n int, err error err = conn.Add(set.name, entries...) if err != nil { - return 0, fmt.Errorf("adding %q%s to ipset %q: %w", host, ips, set.name, err) + return 0, fmt.Errorf("adding %q%s to %q %q: %w", host, ips, set.name, set.typeName, err) } // Only add these to the cache once we're sure that all of them were @@ -444,10 +489,10 @@ func (m *manager) addToSets( return n, err } default: - return n, fmt.Errorf("unexpected family %s for ipset %q", set.family, set.name) + return n, fmt.Errorf("%q %q unexpected family %q", set.name, set.typeName, set.family) } - log.Debug("ipset: added %d ips to set %s", nn, set.name) + log.Debug("ipset: added %d ips to set %q %q", nn, set.name, set.typeName) n += nn } diff --git a/internal/ipset/ipset_linux_internal_test.go b/internal/ipset/ipset_linux_internal_test.go index 84e25650..f22d93c1 100644 --- a/internal/ipset/ipset_linux_internal_test.go +++ b/internal/ipset/ipset_linux_internal_test.go @@ -47,6 +47,11 @@ func (c *fakeConn) Close() (err error) { return nil } +// Header implements the [ipsetConn] interface for *fakeConn. +func (c *fakeConn) Header(_ string) (_ *ipset.HeaderPolicy, _ error) { + return nil, nil +} + // listAll implements the [ipsetConn] interface for *fakeConn. func (c *fakeConn) listAll() (sets []props, err error) { return c.sets, nil