//go:build linux

package ipset

import (
	"net"
	"strings"
	"testing"
	"time"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/logutil/slogutil"
	"github.com/AdguardTeam/golibs/testutil"
	"github.com/digineo/go-ipset/v2"
	"github.com/mdlayher/netlink"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/ti-mo/netfilter"
)

// testTimeout is a common timeout for tests and contexts.
const testTimeout = 1 * time.Second

// fakeConn is a fake ipsetConn for tests.
type fakeConn struct {
	ipv4Header  *ipset.HeaderPolicy
	ipv4Entries *[]*ipset.Entry
	ipv6Header  *ipset.HeaderPolicy
	ipv6Entries *[]*ipset.Entry
	sets        []props
}

// type check
var _ ipsetConn = (*fakeConn)(nil)

// Add implements the [ipsetConn] interface for *fakeConn.
func (c *fakeConn) Add(name string, entries ...*ipset.Entry) (err error) {
	if strings.Contains(name, "ipv4") {
		*c.ipv4Entries = append(*c.ipv4Entries, entries...)

		return nil
	} else if strings.Contains(name, "ipv6") {
		*c.ipv6Entries = append(*c.ipv6Entries, entries...)

		return nil
	}

	return errors.Error("test: ipset not found")
}

// Close implements the [ipsetConn] interface for *fakeConn.
func (c *fakeConn) Close() (err error) {
	return nil
}

// Header implements the [ipsetConn] interface for *fakeConn.
func (c *fakeConn) Header(_ string) (_ *ipset.HeaderPolicy, _ error) {
	return nil, nil
}

// listAll implements the [ipsetConn] interface for *fakeConn.
func (c *fakeConn) listAll() (sets []props, err error) {
	return c.sets, nil
}

func TestManager_Add(t *testing.T) {
	ipsetList := []string{
		"example.com,example.net/ipv4set",
		"example.org,example.biz/ipv6set",
	}

	var ipv4Entries []*ipset.Entry
	var ipv6Entries []*ipset.Entry

	fakeDial := func(
		pf netfilter.ProtoFamily,
		conf *netlink.Config,
	) (conn ipsetConn, err error) {
		return &fakeConn{
			ipv4Header: &ipset.HeaderPolicy{
				Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv4)),
			},
			ipv4Entries: &ipv4Entries,
			ipv6Header: &ipset.HeaderPolicy{
				Family: ipset.NewUInt8Box(uint8(netfilter.ProtoIPv6)),
			},
			ipv6Entries: &ipv6Entries,
			sets: []props{{
				name:   "ipv4set",
				family: netfilter.ProtoIPv4,
			}, {
				name:   "ipv6set",
				family: netfilter.ProtoIPv6,
			}},
		}, nil
	}

	conf := &Config{
		Logger: slogutil.NewDiscardLogger(),
		Lines:  ipsetList,
	}
	m, err := newManagerWithDialer(testutil.ContextWithTimeout(t, testTimeout), conf, fakeDial)
	require.NoError(t, err)

	ip4 := net.IP{1, 2, 3, 4}
	ip6 := net.IP{
		0x12, 0x34, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x00, 0x00,
		0x00, 0x00, 0x56, 0x78,
	}

	n, err := m.Add(testutil.ContextWithTimeout(t, testTimeout), "example.net", []net.IP{ip4}, nil)
	require.NoError(t, err)

	assert.Equal(t, 1, n)

	require.Len(t, ipv4Entries, 1)

	gotIP4 := ipv4Entries[0].IP.Value
	assert.Equal(t, ip4, gotIP4)

	n, err = m.Add(testutil.ContextWithTimeout(t, testTimeout), "example.biz", nil, []net.IP{ip6})
	require.NoError(t, err)

	assert.Equal(t, 1, n)

	require.Len(t, ipv6Entries, 1)

	gotIP6 := ipv6Entries[0].IP.Value
	assert.Equal(t, ip6, gotIP6)

	err = m.Close()
	assert.NoError(t, err)
}

// ipsetPropsSink is the typed sink for benchmark results.
var ipsetPropsSink []props

func BenchmarkManager_LookupHost(b *testing.B) {
	propsLong := []props{{
		name:   "example.com",
		family: netfilter.ProtoIPv4,
	}}

	propsShort := []props{{
		name:   "example.net",
		family: netfilter.ProtoIPv4,
	}}

	m := &manager{
		domainToIpsets: map[string][]props{
			"":            propsLong,
			"example.net": propsShort,
		},
	}

	b.Run("long", func(b *testing.B) {
		const name = "a.very.long.domain.name.inside.the.domain.example.com"
		for range b.N {
			ipsetPropsSink = m.lookupHost(name)
		}

		require.Equal(b, propsLong, ipsetPropsSink)
	})

	b.Run("short", func(b *testing.B) {
		const name = "example.net"
		for range b.N {
			ipsetPropsSink = m.lookupHost(name)
		}

		require.Equal(b, propsShort, ipsetPropsSink)
	})
}