diff --git a/py/objarray.c b/py/objarray.c index a17ae276e2..58285c8660 100644 --- a/py/objarray.c +++ b/py/objarray.c @@ -33,6 +33,7 @@ #include "py/runtime0.h" #include "py/runtime.h" #include "py/binary.h" +#include "py/objstr.h" #if MICROPY_PY_ARRAY || MICROPY_PY_BUILTINS_BYTEARRAY || MICROPY_PY_BUILTINS_MEMORYVIEW @@ -283,6 +284,29 @@ STATIC mp_obj_t array_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) return lhs_in; } + case MP_BINARY_OP_IN: { + /* NOTE `a in b` is `b.__contains__(a)` */ + mp_buffer_info_t lhs_bufinfo; + mp_buffer_info_t rhs_bufinfo; + + // Can search string only in bytearray + if (mp_get_buffer(rhs_in, &rhs_bufinfo, MP_BUFFER_READ)) { + if (!MP_OBJ_IS_TYPE(lhs_in, &mp_type_bytearray)) { + return mp_const_false; + } + array_get_buffer(lhs_in, &lhs_bufinfo, MP_BUFFER_READ); + return mp_obj_new_bool( + find_subbytes(lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len, 1) != NULL); + } + + // Otherwise, can only look for a scalar numeric value in an array + if (MP_OBJ_IS_INT(rhs_in) || mp_obj_is_float(rhs_in)) { + mp_not_implemented(""); + } + + return mp_const_false; + } + case MP_BINARY_OP_EQUAL: { mp_buffer_info_t lhs_bufinfo; mp_buffer_info_t rhs_bufinfo; diff --git a/py/objstr.c b/py/objstr.c index 60f65d8439..0c2d904035 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -245,7 +245,7 @@ wrong_args: // like strstr but with specified length and allows \0 bytes // TODO replace with something more efficient/standard -STATIC const byte *find_subbytes(const byte *haystack, mp_uint_t hlen, const byte *needle, mp_uint_t nlen, mp_int_t direction) { +const byte *find_subbytes(const byte *haystack, mp_uint_t hlen, const byte *needle, mp_uint_t nlen, mp_int_t direction) { if (hlen >= nlen) { mp_uint_t str_index, str_index_end; if (direction > 0) { diff --git a/py/objstr.h b/py/objstr.h index 6179a74afc..6b8ad97ec2 100644 --- a/py/objstr.h +++ b/py/objstr.h @@ -71,6 +71,7 @@ mp_int_t mp_obj_str_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_u const byte *str_index_to_ptr(const mp_obj_type_t *type, const byte *self_data, size_t self_len, mp_obj_t index, bool is_slice); +const byte *find_subbytes(const byte *haystack, mp_uint_t hlen, const byte *needle, mp_uint_t nlen, mp_int_t direction); MP_DECLARE_CONST_FUN_OBJ(str_encode_obj); MP_DECLARE_CONST_FUN_OBJ(str_find_obj);