diff --git a/py/compile.c b/py/compile.c index 1fc5f07227..95fe1d759f 100644 --- a/py/compile.c +++ b/py/compile.c @@ -1933,6 +1933,11 @@ void compile_expr_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) { // optimisation for a, b = c, d; to match CPython's optimisation mp_parse_node_struct_t* pns10 = (mp_parse_node_struct_t*)pns1->nodes[0]; mp_parse_node_struct_t* pns0 = (mp_parse_node_struct_t*)pns->nodes[0]; + if (MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[0], PN_star_expr) + || MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[1], PN_star_expr)) { + // can't optimise when it's a star expression on the lhs + goto no_optimisation; + } compile_node(comp, pns10->nodes[0]); // rhs compile_node(comp, pns10->nodes[1]); // rhs EMIT(rot_two); @@ -1945,6 +1950,12 @@ void compile_expr_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) { // optimisation for a, b, c = d, e, f; to match CPython's optimisation mp_parse_node_struct_t* pns10 = (mp_parse_node_struct_t*)pns1->nodes[0]; mp_parse_node_struct_t* pns0 = (mp_parse_node_struct_t*)pns->nodes[0]; + if (MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[0], PN_star_expr) + || MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[1], PN_star_expr) + || MP_PARSE_NODE_IS_STRUCT_KIND(pns0->nodes[2], PN_star_expr)) { + // can't optimise when it's a star expression on the lhs + goto no_optimisation; + } compile_node(comp, pns10->nodes[0]); // rhs compile_node(comp, pns10->nodes[1]); // rhs compile_node(comp, pns10->nodes[2]); // rhs @@ -1954,6 +1965,7 @@ void compile_expr_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) { c_assign(comp, pns0->nodes[1], ASSIGN_STORE); // lhs store c_assign(comp, pns0->nodes[2], ASSIGN_STORE); // lhs store } else { + no_optimisation: compile_node(comp, pns1->nodes[0]); // rhs c_assign(comp, pns->nodes[0], ASSIGN_STORE); // lhs store } diff --git a/py/obj.h b/py/obj.h index fc99055b6e..77cf7838ee 100644 --- a/py/obj.h +++ b/py/obj.h @@ -440,6 +440,7 @@ mp_obj_t mp_obj_tuple_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const m // list mp_obj_t mp_obj_list_append(mp_obj_t self_in, mp_obj_t arg); void mp_obj_list_get(mp_obj_t self_in, uint *len, mp_obj_t **items); +void mp_obj_list_set_len(mp_obj_t self_in, uint len); void mp_obj_list_store(mp_obj_t self_in, mp_obj_t index, mp_obj_t value); mp_obj_t mp_obj_list_sort(uint n_args, const mp_obj_t *args, mp_map_t *kwargs); diff --git a/py/objlist.c b/py/objlist.c index 620bf2944a..371d1cb26e 100644 --- a/py/objlist.c +++ b/py/objlist.c @@ -378,6 +378,13 @@ void mp_obj_list_get(mp_obj_t self_in, uint *len, mp_obj_t **items) { *items = self->items; } +void mp_obj_list_set_len(mp_obj_t self_in, uint len) { + // trust that the caller knows what it's doing + // TODO realloc if len got much smaller than alloc + mp_obj_list_t *self = self_in; + self->len = len; +} + void mp_obj_list_store(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) { mp_obj_list_t *self = self_in; uint i = mp_get_index(self->base.type, self->len, index, false); diff --git a/py/runtime.c b/py/runtime.c index 44e0ded507..3d1ae72c2f 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -672,6 +672,70 @@ too_long: nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "too many values to unpack (expected %d)", num)); } +// unpacked items are stored in reverse order into the array pointed to by items +void mp_unpack_ex(mp_obj_t seq_in, uint num_in, mp_obj_t *items) { + uint num_left = num_in & 0xff; + uint num_right = (num_in >> 8) & 0xff; + DEBUG_OP_printf("unpack ex %d %d\n", num_left, num_right); + uint seq_len; + if (MP_OBJ_IS_TYPE(seq_in, &mp_type_tuple) || MP_OBJ_IS_TYPE(seq_in, &mp_type_list)) { + mp_obj_t *seq_items; + if (MP_OBJ_IS_TYPE(seq_in, &mp_type_tuple)) { + mp_obj_tuple_get(seq_in, &seq_len, &seq_items); + } else { + if (num_left == 0 && num_right == 0) { + // *a, = b # sets a to b if b is a list + items[0] = seq_in; + return; + } + mp_obj_list_get(seq_in, &seq_len, &seq_items); + } + if (seq_len < num_left + num_right) { + goto too_short; + } + for (uint i = 0; i < num_right; i++) { + items[i] = seq_items[seq_len - 1 - i]; + } + items[num_right] = mp_obj_new_list(seq_len - num_left - num_right, seq_items + num_left); + for (uint i = 0; i < num_left; i++) { + items[num_right + 1 + i] = seq_items[num_left - 1 - i]; + } + } else { + // Generic iterable; this gets a bit messy: we unpack known left length to the + // items destination array, then the rest to a dynamically created list. Once the + // iterable is exhausted, we take from this list for the right part of the items. + // TODO Improve to waste less memory in the dynamically created list. + mp_obj_t iterable = mp_getiter(seq_in); + mp_obj_t item; + for (seq_len = 0; seq_len < num_left; seq_len++) { + item = mp_iternext(iterable); + if (item == MP_OBJ_NULL) { + goto too_short; + } + items[num_left + num_right + 1 - 1 - seq_len] = item; + } + mp_obj_t rest = mp_obj_new_list(0, NULL); + while ((item = mp_iternext(iterable)) != MP_OBJ_NULL) { + mp_obj_list_append(rest, item); + } + uint rest_len; + mp_obj_t *rest_items; + mp_obj_list_get(rest, &rest_len, &rest_items); + if (rest_len < num_right) { + goto too_short; + } + items[num_right] = rest; + for (uint i = 0; i < num_right; i++) { + items[num_right - 1 - i] = rest_items[rest_len - num_right + i]; + } + mp_obj_list_set_len(rest, rest_len - num_right); + } + return; + +too_short: + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_ValueError, "need more than %d values to unpack", seq_len)); +} + mp_obj_t mp_load_attr(mp_obj_t base, qstr attr) { DEBUG_OP_printf("load attr %p.%s\n", base, qstr_str(attr)); // use load_method diff --git a/py/runtime.h b/py/runtime.h index cc76186f4e..ab34be2da9 100644 --- a/py/runtime.h +++ b/py/runtime.h @@ -44,6 +44,7 @@ mp_obj_t mp_call_method_n_kw(uint n_args, uint n_kw, const mp_obj_t *args); mp_obj_t mp_call_method_n_kw_var(bool have_self, uint n_args_n_kw, const mp_obj_t *args); void mp_unpack_sequence(mp_obj_t seq, uint num, mp_obj_t *items); +void mp_unpack_ex(mp_obj_t seq, uint num, mp_obj_t *items); mp_obj_t mp_store_map(mp_obj_t map, mp_obj_t key, mp_obj_t value); mp_obj_t mp_load_attr(mp_obj_t base, qstr attr); void mp_load_method(mp_obj_t base, qstr attr, mp_obj_t *dest); diff --git a/py/vm.c b/py/vm.c index 2e64cd9573..869a9381ad 100644 --- a/py/vm.c +++ b/py/vm.c @@ -653,6 +653,12 @@ unwind_jump: sp += unum - 1; break; + case MP_BC_UNPACK_EX: + DECODE_UINT; + mp_unpack_ex(sp[0], unum, sp); + sp += (unum & 0xff) + ((unum >> 8) & 0xff); + break; + case MP_BC_MAKE_FUNCTION: DECODE_UINT; PUSH(mp_make_function_from_id(unum, MP_OBJ_NULL, MP_OBJ_NULL)); diff --git a/tests/basics/unpack1.py b/tests/basics/unpack1.py new file mode 100644 index 0000000000..b2b2ddb7e4 --- /dev/null +++ b/tests/basics/unpack1.py @@ -0,0 +1,60 @@ +# unpack sequences + +a, = 1, ; print(a) +a, b = 2, 3 ; print(a, b) + +a, b = range(2); print(a, b) + +# with star + +*a, = () ; print(a) +*a, = 4, ; print(a) +*a, = 5, 6 ; print(a) + +*a, b = 7, ; print(a, b) +*a, b = 8, 9 ; print(a, b) +*a, b = 10, 11, 12 ; print(a, b) + +a, *b = 13, ; print(a, b) +a, *b = 14, 15 ; print(a, b) +a, *b = 16, 17, 18 ; print(a, b) + +a, *b, c = 19, 20 ; print(a, b) +a, *b, c = 21, 22, 23 ; print(a, b) +a, *b, c = 24, 25, 26, 27 ; print(a, b) + +a = [28, 29] +*b, = a +print(a, b, a == b) + +try: + a, *b, c = (30,) +except ValueError: + print("ValueError") + +# with star and generic iterator + +*a, = range(5) ; print(a) +*a, b = range(5) ; print(a, b) +*a, b, c = range(5) ; print(a, b, c) +a, *b = range(5) ; print(a, b) +a, *b, c = range(5) ; print(a, b, c) +a, *b, c, d = range(5) ; print(a, b, c, d) +a, b, *c = range(5) ; print(a, b, c) +a, b, *c, d = range(5) ; print(a, b, c, d) +a, b, *c, d, e = range(5) ; print(a, b, c, d, e) + +*a, = [x * 2 for x in [1, 2, 3, 4]] ; print(a) +*a, b = [x * 2 for x in [1, 2, 3, 4]] ; print(a, b) +a, *b = [x * 2 for x in [1, 2, 3, 4]] ; print(a, b) +a, *b, c = [x * 2 for x in [1, 2, 3, 4]]; print(a, b, c) + +try: + a, *b, c = range(0) +except ValueError: + print("ValueError") + +try: + a, *b, c = range(1) +except ValueError: + print("ValueError")