py: Allow +, in, and compare ops between bytes and bytearray/array.
Eg b"123" + bytearray(2) now works. This patch actually decreases code size while adding functionality: 32-bit unix down by 128 bytes, stmhal down by 84 bytes.
This commit is contained in:
parent
346aacf27f
commit
a65c03c6c0
146
py/objstr.c
146
py/objstr.c
|
@ -285,61 +285,8 @@ STATIC const byte *find_subbytes(const byte *haystack, mp_uint_t hlen, const byt
|
||||||
// works because both those types use it as their binary_op method. Revisit
|
// works because both those types use it as their binary_op method. Revisit
|
||||||
// MP_OBJ_IS_STR_OR_BYTES if this fact changes.
|
// MP_OBJ_IS_STR_OR_BYTES if this fact changes.
|
||||||
mp_obj_t mp_obj_str_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
|
mp_obj_t mp_obj_str_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
|
||||||
GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len);
|
// check for modulo
|
||||||
mp_obj_type_t *lhs_type = mp_obj_get_type(lhs_in);
|
if (op == MP_BINARY_OP_MODULO) {
|
||||||
mp_obj_type_t *rhs_type = mp_obj_get_type(rhs_in);
|
|
||||||
switch (op) {
|
|
||||||
case MP_BINARY_OP_ADD:
|
|
||||||
case MP_BINARY_OP_INPLACE_ADD:
|
|
||||||
if (lhs_type == rhs_type) {
|
|
||||||
// add 2 strings or bytes
|
|
||||||
|
|
||||||
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
|
|
||||||
mp_uint_t alloc_len = lhs_len + rhs_len;
|
|
||||||
|
|
||||||
/* code for making qstr
|
|
||||||
byte *q_ptr;
|
|
||||||
byte *val = qstr_build_start(alloc_len, &q_ptr);
|
|
||||||
memcpy(val, lhs_data, lhs_len);
|
|
||||||
memcpy(val + lhs_len, rhs_data, rhs_len);
|
|
||||||
return MP_OBJ_NEW_QSTR(qstr_build_end(q_ptr));
|
|
||||||
*/
|
|
||||||
|
|
||||||
// code for non-qstr
|
|
||||||
byte *data;
|
|
||||||
mp_obj_t s = mp_obj_str_builder_start(lhs_type, alloc_len, &data);
|
|
||||||
memcpy(data, lhs_data, lhs_len);
|
|
||||||
memcpy(data + lhs_len, rhs_data, rhs_len);
|
|
||||||
return mp_obj_str_builder_end(s);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case MP_BINARY_OP_IN:
|
|
||||||
/* NOTE `a in b` is `b.__contains__(a)` */
|
|
||||||
if (lhs_type == rhs_type) {
|
|
||||||
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
|
|
||||||
return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL);
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case MP_BINARY_OP_MULTIPLY: {
|
|
||||||
mp_int_t n;
|
|
||||||
if (!mp_obj_get_int_maybe(rhs_in, &n)) {
|
|
||||||
return MP_OBJ_NULL; // op not supported
|
|
||||||
}
|
|
||||||
if (n <= 0) {
|
|
||||||
if (lhs_type == &mp_type_str) {
|
|
||||||
return MP_OBJ_NEW_QSTR(MP_QSTR_); // empty str
|
|
||||||
}
|
|
||||||
n = 0;
|
|
||||||
}
|
|
||||||
byte *data;
|
|
||||||
mp_obj_t s = mp_obj_str_builder_start(lhs_type, lhs_len * n, &data);
|
|
||||||
mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data);
|
|
||||||
return mp_obj_str_builder_end(s);
|
|
||||||
}
|
|
||||||
|
|
||||||
case MP_BINARY_OP_MODULO: {
|
|
||||||
mp_obj_t *args;
|
mp_obj_t *args;
|
||||||
mp_uint_t n_args;
|
mp_uint_t n_args;
|
||||||
mp_obj_t dict = MP_OBJ_NULL;
|
mp_obj_t dict = MP_OBJ_NULL;
|
||||||
|
@ -357,28 +304,89 @@ mp_obj_t mp_obj_str_binary_op(mp_uint_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
|
||||||
return str_modulo_format(lhs_in, n_args, args, dict);
|
return str_modulo_format(lhs_in, n_args, args, dict);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// from now on we need lhs type and data, so extract them
|
||||||
|
mp_obj_type_t *lhs_type = mp_obj_get_type(lhs_in);
|
||||||
|
GET_STR_DATA_LEN(lhs_in, lhs_data, lhs_len);
|
||||||
|
|
||||||
|
// check for multiply
|
||||||
|
if (op == MP_BINARY_OP_MULTIPLY) {
|
||||||
|
mp_int_t n;
|
||||||
|
if (!mp_obj_get_int_maybe(rhs_in, &n)) {
|
||||||
|
return MP_OBJ_NULL; // op not supported
|
||||||
|
}
|
||||||
|
if (n <= 0) {
|
||||||
|
if (lhs_type == &mp_type_str) {
|
||||||
|
return MP_OBJ_NEW_QSTR(MP_QSTR_); // empty str
|
||||||
|
} else {
|
||||||
|
return mp_const_empty_bytes;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
byte *data;
|
||||||
|
mp_obj_t s = mp_obj_str_builder_start(lhs_type, lhs_len * n, &data);
|
||||||
|
mp_seq_multiply(lhs_data, sizeof(*lhs_data), lhs_len, n, data);
|
||||||
|
return mp_obj_str_builder_end(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
// From now on all operations allow:
|
||||||
|
// - str with str
|
||||||
|
// - bytes with bytes
|
||||||
|
// - bytes with bytearray
|
||||||
|
// - bytes with array.array
|
||||||
|
// To do this efficiently we use the buffer protocol to extract the raw
|
||||||
|
// data for the rhs, but only if the lhs is a bytes object.
|
||||||
|
//
|
||||||
|
// NOTE: CPython does not allow comparison between bytes ard array.array
|
||||||
|
// (even if the array is of type 'b'), even though it allows addition of
|
||||||
|
// such types. We are not compatible with this (we do allow comparison
|
||||||
|
// of bytes with anything that has the buffer protocol). It would be
|
||||||
|
// easy to "fix" this with a bit of extra logic below, but it costs code
|
||||||
|
// size and execution time so we don't.
|
||||||
|
|
||||||
|
const byte *rhs_data;
|
||||||
|
mp_uint_t rhs_len;
|
||||||
|
if (lhs_type == mp_obj_get_type(rhs_in)) {
|
||||||
|
GET_STR_DATA_LEN(rhs_in, rhs_data_, rhs_len_);
|
||||||
|
rhs_data = rhs_data_;
|
||||||
|
rhs_len = rhs_len_;
|
||||||
|
} else if (lhs_type == &mp_type_bytes) {
|
||||||
|
mp_buffer_info_t bufinfo;
|
||||||
|
if (!mp_get_buffer(rhs_in, &bufinfo, MP_BUFFER_READ)) {
|
||||||
|
goto incompatible;
|
||||||
|
}
|
||||||
|
rhs_data = bufinfo.buf;
|
||||||
|
rhs_len = bufinfo.len;
|
||||||
|
} else {
|
||||||
|
// incompatible types
|
||||||
|
incompatible:
|
||||||
|
if (op == MP_BINARY_OP_EQUAL) {
|
||||||
|
return mp_const_false; // can check for equality against every type
|
||||||
|
}
|
||||||
|
return MP_OBJ_NULL; // op not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case MP_BINARY_OP_ADD:
|
||||||
|
case MP_BINARY_OP_INPLACE_ADD: {
|
||||||
|
mp_uint_t alloc_len = lhs_len + rhs_len;
|
||||||
|
byte *data;
|
||||||
|
mp_obj_t s = mp_obj_str_builder_start(lhs_type, alloc_len, &data);
|
||||||
|
memcpy(data, lhs_data, lhs_len);
|
||||||
|
memcpy(data + lhs_len, rhs_data, rhs_len);
|
||||||
|
return mp_obj_str_builder_end(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
case MP_BINARY_OP_IN:
|
||||||
|
/* NOTE `a in b` is `b.__contains__(a)` */
|
||||||
|
return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL);
|
||||||
|
|
||||||
//case MP_BINARY_OP_NOT_EQUAL: // This is never passed here
|
//case MP_BINARY_OP_NOT_EQUAL: // This is never passed here
|
||||||
case MP_BINARY_OP_EQUAL: // This will be passed only for bytes, str is dealt with in mp_obj_equal()
|
case MP_BINARY_OP_EQUAL: // This will be passed only for bytes, str is dealt with in mp_obj_equal()
|
||||||
case MP_BINARY_OP_LESS:
|
case MP_BINARY_OP_LESS:
|
||||||
case MP_BINARY_OP_LESS_EQUAL:
|
case MP_BINARY_OP_LESS_EQUAL:
|
||||||
case MP_BINARY_OP_MORE:
|
case MP_BINARY_OP_MORE:
|
||||||
case MP_BINARY_OP_MORE_EQUAL:
|
case MP_BINARY_OP_MORE_EQUAL:
|
||||||
if (lhs_type == rhs_type) {
|
|
||||||
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
|
|
||||||
return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len));
|
return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, rhs_data, rhs_len));
|
||||||
}
|
}
|
||||||
if (lhs_type == &mp_type_bytes) {
|
|
||||||
mp_buffer_info_t bufinfo;
|
|
||||||
if (!mp_get_buffer(rhs_in, &bufinfo, MP_BUFFER_READ)) {
|
|
||||||
goto uncomparable;
|
|
||||||
}
|
|
||||||
return MP_BOOL(mp_seq_cmp_bytes(op, lhs_data, lhs_len, bufinfo.buf, bufinfo.len));
|
|
||||||
}
|
|
||||||
uncomparable:
|
|
||||||
if (op == MP_BINARY_OP_EQUAL) {
|
|
||||||
return mp_const_false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return MP_OBJ_NULL; // op not supported
|
return MP_OBJ_NULL; // op not supported
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
# test bytes + other
|
||||||
|
|
||||||
|
print(b"123" + b"456")
|
||||||
|
print(b"123" + bytearray(2))
|
||||||
|
|
||||||
|
import array
|
||||||
|
|
||||||
|
print(b"123" + array.array('i', [1]))
|
||||||
|
print(b"\x01\x02" + array.array('b', [1, 2]))
|
|
@ -1,7 +1,12 @@
|
||||||
import array
|
|
||||||
|
|
||||||
print(b"1" == 1)
|
print(b"1" == 1)
|
||||||
print(b"123" == bytearray(b"123"))
|
print(b"123" == bytearray(b"123"))
|
||||||
print(b"123" == "123")
|
print(b"123" == "123")
|
||||||
# CPyhon gives False here
|
print(b'123' < bytearray(b"124"))
|
||||||
|
print(b'123' > bytearray(b"122"))
|
||||||
|
print(bytearray(b"23") in b"1234")
|
||||||
|
|
||||||
|
import array
|
||||||
|
|
||||||
|
print(array.array('b', [1, 2]) in b'\x01\x02\x03')
|
||||||
|
# CPython gives False here
|
||||||
#print(b"\x01\x02\x03" == array.array("B", [1, 2, 3]))
|
#print(b"\x01\x02\x03" == array.array("B", [1, 2, 3]))
|
||||||
|
|
Loading…
Reference in New Issue