diff --git a/py/obj.h b/py/obj.h index bed5103dbd..ca3ab1af6a 100644 --- a/py/obj.h +++ b/py/obj.h @@ -397,3 +397,4 @@ typedef struct _mp_obj_classmethod_t { void mp_seq_multiply(const void *items, uint item_sz, uint len, uint times, void *dest); bool m_seq_get_fast_slice_indexes(machine_uint_t len, mp_obj_t slice, machine_uint_t *begin, machine_uint_t *end); #define m_seq_copy(dest, src, len, item_sz) memcpy(dest, src, len * sizeof(item_sz)) +bool mp_seq_cmp_bytes(int op, const byte *data1, uint len1, const byte *data2, uint len2); diff --git a/py/objstr.c b/py/objstr.c index 92bd71f3de..50cd31d542 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -169,6 +169,18 @@ mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data); return mp_obj_str_builder_end(s); } + + // These 2 are never passed here, dealt with as a special case in rt_binary_op(). + //case RT_BINARY_OP_EQUAL: + //case RT_BINARY_OP_NOT_EQUAL: + case RT_BINARY_OP_LESS: + case RT_BINARY_OP_LESS_EQUAL: + case RT_BINARY_OP_MORE: + case RT_BINARY_OP_MORE_EQUAL: + if (MP_OBJ_IS_STR(rhs_in)) { + GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len); + return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len)); + } } return MP_OBJ_NULL; // op not supported diff --git a/py/sequence.c b/py/sequence.c index 1e851a9f80..74b4fcfdf8 100644 --- a/py/sequence.c +++ b/py/sequence.c @@ -14,6 +14,8 @@ // Helpers for sequence types +#define SWAP(type, var1, var2) { type t = var2; var2 = var1; var1 = t; } + // Implements backend of sequence * integer operation. Assumes elements are // memory-adjacent in sequence. void mp_seq_multiply(const void *items, uint item_sz, uint len, uint times, void *dest) { @@ -53,3 +55,39 @@ bool m_seq_get_fast_slice_indexes(machine_uint_t len, mp_obj_t slice, machine_ui *end = stop; return true; } + +// Special-case comparison function for sequences of bytes +// Don't pass RT_BINARY_OP_NOT_EQUAL here +bool mp_seq_cmp_bytes(int op, const byte *data1, uint len1, const byte *data2, uint len2) { + // Let's deal only with > & >= + if (op == RT_BINARY_OP_LESS || op == RT_BINARY_OP_LESS_EQUAL) { + SWAP(const byte*, data1, data2); + SWAP(uint, len1, len2); + if (op == RT_BINARY_OP_LESS) { + op = RT_BINARY_OP_MORE; + } else { + op = RT_BINARY_OP_MORE_EQUAL; + } + } + uint min_len = len1 < len2 ? len1 : len2; + int res = memcmp(data1, data2, min_len); + if (res < 0) { + return false; + } + if (res > 0) { + return true; + } + + // If we had tie in the last element... + // ... and we have lists of different lengths... + if (len1 != len2) { + if (len1 < len2) { + // ... then longer list length wins (we deal only with >) + return false; + } + } else if (op == RT_BINARY_OP_MORE) { + // Otherwise, if we have strict relation, equality means failure + return false; + } + return true; +} diff --git a/tests/basics/string-compare.py b/tests/basics/string-compare.py new file mode 100644 index 0000000000..740e1959c8 --- /dev/null +++ b/tests/basics/string-compare.py @@ -0,0 +1,50 @@ +print("" == "") +print("" > "") +print("" < "") +print("" == "1") +print("1" == "") +print("" > "1") +print("1" > "") +print("" < "1") +print("1" < "") +print("" >= "1") +print("1" >= "") +print("" <= "1") +print("1" <= "") + +print("1" == "1") +print("1" != "1") +print("1" == "2") +print("1" == "10") + +print("1" > "1") +print("1" > "2") +print("2" > "1") +print("10" > "1") +print("1/" > "1") +print("1" > "10") +print("1" > "1/") + +print("1" < "1") +print("2" < "1") +print("1" < "2") +print("1" < "10") +print("1" < "1/") +print("10" < "1") +print("1/" < "1") + +print("1" >= "1") +print("1" >= "2") +print("2" >= "1") +print("10" >= "1") +print("1/" >= "1") +print("1" >= "10") +print("1" >= "1/") + +print("1" <= "1") +print("2" <= "1") +print("1" <= "2") +print("1" <= "10") +print("1" <= "1/") +print("10" <= "1") +print("1/" <= "1")