package aghio import ( "fmt" "io" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestLimitReader(t *testing.T) { testCases := []struct { want error name string n int64 }{{ want: nil, name: "positive", n: 1, }, { want: nil, name: "zero", n: 0, }, { want: fmt.Errorf("aghio: invalid n in LimitReader: -1"), name: "negative", n: -1, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { _, err := LimitReader(nil, tc.n) assert.Equal(t, tc.want, err) }) } } func TestLimitedReader_Read(t *testing.T) { testCases := []struct { err error name string rStr string limit int64 want int }{{ err: nil, name: "perfectly_match", rStr: "abc", limit: 3, want: 3, }, { err: io.EOF, name: "eof", rStr: "", limit: 3, want: 0, }, { err: &LimitReachedError{ Limit: 0, }, name: "limit_reached", rStr: "abc", limit: 0, want: 0, }, { err: nil, name: "truncated", rStr: "abc", limit: 2, want: 2, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { readCloser := io.NopCloser(strings.NewReader(tc.rStr)) buf := make([]byte, tc.limit+1) lreader, err := LimitReader(readCloser, tc.limit) require.NoError(t, err) n, err := lreader.Read(buf) require.Equal(t, tc.err, err) assert.Equal(t, tc.want, n) }) } } func TestLimitedReader_LimitReachedError(t *testing.T) { testCases := []struct { err error name string want string }{{ err: &LimitReachedError{ Limit: 0, }, name: "simplest", want: "attempted to read more than 0 bytes", }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { assert.Equal(t, tc.want, tc.err.Error()) }) } }