package home

import (
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestAuthRateLimiter_Cleanup(t *testing.T) {
	const key = "some-key"
	now := time.Now()

	testCases := []struct {
		name    string
		att     failedAuth
		wantExp bool
	}{{
		name: "expired",
		att: failedAuth{
			until: now.Add(-100 * time.Hour),
		},
		wantExp: true,
	}, {
		name: "nope_yet",
		att: failedAuth{
			until: now.Add(failedAuthTTL / 2),
		},
		wantExp: false,
	}, {
		name: "blocked",
		att: failedAuth{
			until: now.Add(100 * time.Hour),
		},
		wantExp: false,
	}}

	for _, tc := range testCases {
		ab := &authRateLimiter{
			failedAuths: map[string]failedAuth{
				key: tc.att,
			},
		}
		t.Run(tc.name, func(t *testing.T) {
			ab.cleanupLocked(now)
			if tc.wantExp {
				assert.Empty(t, ab.failedAuths)

				return
			}

			require.Len(t, ab.failedAuths, 1)

			_, ok := ab.failedAuths[key]
			require.True(t, ok)
		})
	}
}

// TestAuthRateLimiter_Check verifies the sign of the duration returned by
// check for a key whose record is seeded with different attempt counts and
// until times, and that a key without a record yields a zero duration.
func TestAuthRateLimiter_Check(t *testing.T) {
	const attemptLimit = 1

	usrID := string(net.IP{127, 0, 0, 1})
	ts := time.Now()

	for _, tc := range []struct {
		until       time.Time
		name        string
		num         uint
		wantBlocked bool
	}{{
		until:       ts.Add(-100 * time.Hour),
		name:        "expired",
		num:         0,
		wantBlocked: false,
	}, {
		until:       ts.Add(failedAuthTTL),
		name:        "not_blocked_but_tracked",
		num:         0,
		wantBlocked: false,
	}, {
		until:       ts,
		name:        "expired_but_stayed",
		num:         2,
		wantBlocked: false,
	}, {
		until:       ts.Add(100 * time.Hour),
		name:        "blocked",
		num:         2,
		wantBlocked: true,
	}} {
		t.Run(tc.name, func(t *testing.T) {
			rl := &authRateLimiter{
				maxAttempts: attemptLimit,
				failedAuths: map[string]failedAuth{
					usrID: {
						num:   tc.num,
						until: tc.until,
					},
				},
			}

			left := rl.check(usrID)

			// A strictly positive duration is expected only for the
			// blocked case; every other case expects zero or less.
			if tc.wantBlocked {
				assert.Greater(t, left, time.Duration(0))

				return
			}

			assert.LessOrEqual(t, left, time.Duration(0))
		})
	}

	t.Run("non-existent", func(t *testing.T) {
		// Only a different key is present in the map.
		rl := &authRateLimiter{
			failedAuths: map[string]failedAuth{
				usrID + "smthng": {},
			},
		}

		assert.Zero(t, rl.check(usrID))
	})
}

// TestAuthRateLimiter_Inc checks that inc raises the stored attempt counter
// for a key and that the resulting until time is not earlier than the expected
// one (compared at one-second granularity).  It also checks that inc creates a
// record with a counter of 1 for a key that was not tracked before.
func TestAuthRateLimiter_Inc(t *testing.T) {
	const (
		attemptLimit = 2
		blockFor     = 15 * time.Minute
	)

	addr := net.IP{127, 0, 0, 1}
	usrID := string(addr)
	ts := time.Now()

	for _, tc := range []struct {
		until     time.Time
		wantUntil time.Time
		name      string
		num       uint
		wantNum   uint
	}{{
		name:      "only_inc",
		until:     ts,
		wantUntil: ts,
		num:       attemptLimit - 1,
		wantNum:   attemptLimit,
	}, {
		name:      "inc_and_block",
		until:     ts,
		wantUntil: ts.Add(failedAuthTTL),
		num:       attemptLimit,
		wantNum:   attemptLimit + 1,
	}} {
		t.Run(tc.name, func(t *testing.T) {
			rl := &authRateLimiter{
				blockDur:    blockFor,
				maxAttempts: attemptLimit,
				failedAuths: map[string]failedAuth{
					usrID: {
						num:   tc.num,
						until: tc.until,
					},
				},
			}

			rl.inc(usrID)

			got, found := rl.failedAuths[usrID]
			require.True(t, found)

			assert.Equal(t, tc.wantNum, got.num)

			// Only a lower bound is asserted, since the timestamp the
			// limiter uses is not controlled by this test.
			assert.LessOrEqual(t, tc.wantUntil.Unix(), got.until.Unix())
		})
	}

	t.Run("non-existent", func(t *testing.T) {
		rl := &authRateLimiter{
			blockDur:    blockFor,
			maxAttempts: attemptLimit,
			failedAuths: map[string]failedAuth{},
		}

		rl.inc(usrID)

		got, found := rl.failedAuths[usrID]
		require.True(t, found)
		assert.EqualValues(t, 1, got.num)
	})
}

func TestAuthRateLimiter_Remove(t *testing.T) {
	const key = "some-key"

	failedAuths := map[string]failedAuth{
		key: {},
	}
	ab := &authRateLimiter{
		failedAuths: failedAuths,
	}

	ab.remove(key)

	assert.Empty(t, ab.failedAuths)
}