diff --git a/py/obj.h b/py/obj.h index 4f32a808f3..7bf00d4dee 100644 --- a/py/obj.h +++ b/py/obj.h @@ -74,7 +74,7 @@ typedef struct _mp_obj_base_t mp_obj_base_t; #define MP_OBJ_IS_TYPE(o, t) (MP_OBJ_IS_OBJ(o) && (((mp_obj_base_t*)(o))->type == (t))) // this does not work for checking int, str or fun; use below macros for that #define MP_OBJ_IS_INT(o) (MP_OBJ_IS_SMALL_INT(o) || MP_OBJ_IS_TYPE(o, &mp_type_int)) #define MP_OBJ_IS_STR(o) (MP_OBJ_IS_QSTR(o) || MP_OBJ_IS_TYPE(o, &mp_type_str)) -#define MP_OBJ_IS_STR_OR_BYTES(o) (MP_OBJ_IS_STR(o) || MP_OBJ_IS_TYPE(o, &mp_type_bytes)) +#define MP_OBJ_IS_STR_OR_BYTES(o) (MP_OBJ_IS_QSTR(o) || (MP_OBJ_IS_OBJ(o) && ((mp_obj_base_t*)(o))->type->binary_op == mp_obj_str_binary_op)) #define MP_OBJ_IS_FUN(o) (MP_OBJ_IS_OBJ(o) && (((mp_obj_base_t*)(o))->type->binary_op == mp_obj_fun_binary_op)) #define MP_OBJ_SMALL_INT_VALUE(o) (((mp_int_t)(o)) >> 1) diff --git a/py/objint.c b/py/objint.c index a771383502..c190c1800b 100644 --- a/py/objint.c +++ b/py/objint.c @@ -38,6 +38,7 @@ #include "smallint.h" #include "mpz.h" #include "objint.h" +#include "objstr.h" #include "runtime0.h" #include "runtime.h" diff --git a/py/objstr.c b/py/objstr.c index f38532963c..366ba88163 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -247,6 +247,9 @@ STATIC const byte *find_subbytes(const byte *haystack, mp_uint_t hlen, const byt return NULL; } +// Note: this function is used to check if an object is a str or bytes, which +// works because both those types use it as their binary_op method. Revisit +// MP_OBJ_IS_STR_OR_BYTES if this fact changes. mp_obj_t mp_obj_str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len); mp_obj_type_t *lhs_type = mp_obj_get_type(lhs_in);