diff --git a/py/runtime.c b/py/runtime.c index 67534c4b5e..7f28abbf4f 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -714,13 +714,18 @@ void mp_call_prepare_args_n_kw_var(bool have_self, mp_uint_t n_args_n_kw, const } } } else { - // generic mapping - // TODO is calling 'items' on the mapping the correct thing to do here? - mp_obj_t dest[2]; - mp_load_method(kw_dict, MP_QSTR_items, dest); + // generic mapping: + // - call keys() to get an iterable of all keys in the mapping + // - call __getitem__ for each key to get the corresponding value + + // get the keys iterable + mp_obj_t dest[3]; + mp_load_method(kw_dict, MP_QSTR_keys, dest); mp_obj_t iterable = mp_getiter(mp_call_method_n_kw(0, 0, dest)); - mp_obj_t item; - while ((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) { + + mp_obj_t key; + while ((key = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) { + // expand size of args array if needed if (args2_len + 1 >= args2_alloc) { uint new_alloc = args2_alloc * 2; if (new_alloc < 4) { @@ -729,15 +734,20 @@ void mp_call_prepare_args_n_kw_var(bool have_self, mp_uint_t n_args_n_kw, const args2 = m_renew(mp_obj_t, args2, args2_alloc, new_alloc); args2_alloc = new_alloc; } - mp_obj_t *items; - mp_obj_get_array_fixed_n(item, 2, &items); + // the key must be a qstr, so intern it if it's a string - mp_obj_t key = items[0]; if (MP_OBJ_IS_TYPE(key, &mp_type_str)) { key = mp_obj_str_intern(key); } + + // get the value corresponding to the key + mp_load_method(kw_dict, MP_QSTR___getitem__, dest); + dest[2] = key; + mp_obj_t value = mp_call_method_n_kw(1, 0, dest); + + // store the key/value pair in the argument array args2[args2_len++] = key; - args2[args2_len++] = items[1]; + args2[args2_len++] = value; } } diff --git a/tests/basics/fun_calldblstar3.py b/tests/basics/fun_calldblstar3.py new file mode 100644 index 0000000000..4367e68df7 --- /dev/null +++ b/tests/basics/fun_calldblstar3.py @@ -0,0 +1,16 @@ +# test passing a user-defined mapping as the argument to ** + +def foo(**kw): + print(sorted(kw.items())) + +class Mapping: + def keys(self): + return ['a', 'b', 'c'] + + def __getitem__(self, key): + if key == 'a': + return 1 + else: + return 2 + +foo(**Mapping())