diff --git a/py/objstr.c b/py/objstr.c index 35a948700c..00586a3b3b 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -186,19 +186,26 @@ wrong_args: // like strstr but with specified length and allows \0 bytes // TODO replace with something more efficient/standard -STATIC const byte *find_subbytes(const byte *haystack, uint hlen, const byte *needle, uint nlen) { +STATIC const byte *find_subbytes(const byte *haystack, machine_uint_t hlen, const byte *needle, machine_uint_t nlen, machine_int_t direction) { if (hlen >= nlen) { - for (uint i = 0; i <= hlen - nlen; i++) { - bool found = true; - for (uint j = 0; j < nlen; j++) { - if (haystack[i + j] != needle[j]) { - found = false; - break; - } + machine_uint_t str_index, str_index_end; + if (direction > 0) { + str_index = 0; + str_index_end = hlen - nlen; + } else { + str_index = hlen - nlen; + str_index_end = 0; + } + for (;;) { + if (memcmp(&haystack[str_index], needle, nlen) == 0) { + //found + return haystack + str_index; } - if (found) { - return haystack + i; + if (str_index == str_index_end) { + //not found + break; } + str_index += direction; } } return NULL; @@ -260,7 +267,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { /* NOTE `a in b` is `b.__contains__(a)` */ if (MP_OBJ_IS_STR(rhs_in)) { GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len); - return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len) != NULL); + return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL); } break; @@ -382,7 +389,7 @@ STATIC mp_obj_t str_split(uint n_args, const mp_obj_t *args) { return res; } -STATIC mp_obj_t str_find(uint n_args, const mp_obj_t *args) { +STATIC mp_obj_t str_finder(uint n_args, const mp_obj_t *args, machine_int_t direction) { assert(2 <= n_args && n_args <= 4); assert(MP_OBJ_IS_STR(args[0])); assert(MP_OBJ_IS_STR(args[1])); @@ -392,7 +399,6 @@ STATIC mp_obj_t str_find(uint n_args, const mp_obj_t *args) { machine_uint_t start = 0; machine_uint_t end = haystack_len; - /* TODO use a non-exception-throwing mp_get_index */ if (n_args >= 3 && args[2] != mp_const_none) { start = mp_get_index(&str_type, haystack_len, args[2], true); } @@ -400,20 +406,24 @@ STATIC mp_obj_t str_find(uint n_args, const mp_obj_t *args) { end = mp_get_index(&str_type, haystack_len, args[3], true); } - const byte *p = find_subbytes(haystack + start, haystack_len - start, needle, needle_len); + const byte *p = find_subbytes(haystack + start, end - start, needle, needle_len, direction); if (p == NULL) { // not found return MP_OBJ_NEW_SMALL_INT(-1); } else { // found - machine_int_t pos = p - haystack; - if (pos + needle_len > end) { - pos = -1; - } - return MP_OBJ_NEW_SMALL_INT(pos); + return MP_OBJ_NEW_SMALL_INT(p - haystack); } } +STATIC mp_obj_t str_find(uint n_args, const mp_obj_t *args) { + return str_finder(n_args, args, 1); +} + +STATIC mp_obj_t str_rfind(uint n_args, const mp_obj_t *args) { + return str_finder(n_args, args, -1); +} + // TODO: (Much) more variety in args STATIC mp_obj_t str_startswith(mp_obj_t self_in, mp_obj_t arg) { GET_STR_DATA_LEN(self_in, str, str_len); @@ -424,15 +434,6 @@ STATIC mp_obj_t str_startswith(mp_obj_t self_in, mp_obj_t arg) { return MP_BOOL(memcmp(str, prefix, prefix_len) == 0); } -STATIC bool chr_in_str(const byte* const str, const machine_uint_t str_len, int c) { - for (machine_uint_t i = 0; i < str_len; i++) { - if (str[i] == c) { - return true; - } - } - return false; -} - STATIC mp_obj_t str_strip(uint n_args, const mp_obj_t *args) { assert(1 <= n_args && n_args <= 2); assert(MP_OBJ_IS_STR(args[0])); @@ -457,7 +458,7 @@ STATIC mp_obj_t str_strip(uint n_args, const mp_obj_t *args) { bool first_good_char_pos_set = false; machine_uint_t last_good_char_pos = 0; for (machine_uint_t i = 0; i < orig_str_len; i++) { - if (!chr_in_str(chars_to_del, chars_to_del_len, orig_str[i])) { + if (find_subbytes(chars_to_del, chars_to_del_len, &orig_str[i], 1, 1) == NULL) { last_good_char_pos = i; if (!first_good_char_pos_set) { first_good_char_pos = i; @@ -547,7 +548,7 @@ STATIC mp_obj_t str_replace(uint n_args, const mp_obj_t *args) { const byte *old_occurrence; const byte *offset_ptr = str; machine_uint_t offset_num = 0; - while ((old_occurrence = find_subbytes(offset_ptr, str_len - offset_num, old, old_len)) != NULL) { + while ((old_occurrence = find_subbytes(offset_ptr, str_len - offset_num, old, old_len, 1)) != NULL) { // copy from just after end of last occurrence of to-be-replaced string to right before start of next occurrence if (data != NULL) { memcpy(data + replaced_str_index, offset_ptr, old_occurrence - offset_ptr); @@ -601,7 +602,6 @@ STATIC mp_obj_t str_count(uint n_args, const mp_obj_t *args) { machine_uint_t start = 0; machine_uint_t end = haystack_len; - /* TODO use a non-exception-throwing mp_get_index */ if (n_args >= 3 && args[2] != mp_const_none) { start = mp_get_index(&str_type, haystack_len, args[2], true); } @@ -648,27 +648,12 @@ STATIC mp_obj_t str_partitioner(mp_obj_t self_in, mp_obj_t arg, machine_int_t di result[2] = self_in; } - if (str_len >= sep_len) { - machine_uint_t str_index, str_index_end; - if (direction > 0) { - str_index = 0; - str_index_end = str_len - sep_len; - } else { - str_index = str_len - sep_len; - str_index_end = 0; - } - for (;;) { - if (memcmp(&str[str_index], sep, sep_len) == 0) { - result[0] = mp_obj_new_str(str, str_index, false); - result[1] = arg; - result[2] = mp_obj_new_str(str + str_index + sep_len, str_len - str_index - sep_len, false); - break; - } - if (str_index == str_index_end) { - break; - } - str_index += direction; - } + const byte *position_ptr = find_subbytes(str, str_len, sep, sep_len, direction); + if (position_ptr != NULL) { + machine_uint_t position = position_ptr - str; + result[0] = mp_obj_new_str(str, position, false); + result[1] = arg; + result[2] = mp_obj_new_str(str + position + sep_len, str_len - position - sep_len, false); } return mp_obj_new_tuple(3, result); @@ -697,6 +682,7 @@ STATIC machine_int_t str_get_buffer(mp_obj_t self_in, buffer_info_t *bufinfo, in } STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_find_obj, 2, 4, str_find); +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_rfind_obj, 2, 4, str_rfind); STATIC MP_DEFINE_CONST_FUN_OBJ_2(str_join_obj, str_join); STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_split_obj, 1, 3, str_split); STATIC MP_DEFINE_CONST_FUN_OBJ_2(str_startswith_obj, str_startswith); @@ -709,6 +695,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(str_rpartition_obj, str_rpartition); STATIC const mp_method_t str_type_methods[] = { { "find", &str_find_obj }, + { "rfind", &str_rfind_obj }, { "join", &str_join_obj }, { "split", &str_split_obj }, { "startswith", &str_startswith_obj }, diff --git a/tests/basics/string_rfind.py b/tests/basics/string_rfind.py new file mode 100644 index 0000000000..4d0e84018f --- /dev/null +++ b/tests/basics/string_rfind.py @@ -0,0 +1,23 @@ +print("hello world".rfind("ll")) +print("hello world".rfind("ll", None)) +print("hello world".rfind("ll", 1)) +print("hello world".rfind("ll", 1, None)) +print("hello world".rfind("ll", None, None)) +print("hello world".rfind("ll", 1, -1)) +print("hello world".rfind("ll", 1, 1)) +print("hello world".rfind("ll", 1, 2)) +print("hello world".rfind("ll", 1, 3)) +print("hello world".rfind("ll", 1, 4)) +print("hello world".rfind("ll", 1, 5)) +print("hello world".rfind("ll", -100)) +print("0000".rfind('0')) +print("0000".rfind('0', 0)) +print("0000".rfind('0', 1)) +print("0000".rfind('0', 2)) +print("0000".rfind('0', 3)) +print("0000".rfind('0', 4)) +print("0000".rfind('0', 5)) +print("0000".rfind('-1', 3)) +print("0000".rfind('1', 3)) +print("0000".rfind('1', 4)) +print("0000".rfind('1', 5))