py/runtime: Properly handle passing user mappings to ** keyword args.

This commit is contained in:
Damien George 2016-05-07 22:02:46 +01:00
parent 12dd8df375
commit 470c429ee1
2 changed files with 36 additions and 10 deletions

View File

@ -714,13 +714,18 @@ void mp_call_prepare_args_n_kw_var(bool have_self, mp_uint_t n_args_n_kw, const
} }
} }
} else { } else {
// generic mapping // generic mapping:
// TODO is calling 'items' on the mapping the correct thing to do here? // - call keys() to get an iterable of all keys in the mapping
mp_obj_t dest[2]; // - call __getitem__ for each key to get the corresponding value
mp_load_method(kw_dict, MP_QSTR_items, dest);
// get the keys iterable
mp_obj_t dest[3];
mp_load_method(kw_dict, MP_QSTR_keys, dest);
mp_obj_t iterable = mp_getiter(mp_call_method_n_kw(0, 0, dest)); mp_obj_t iterable = mp_getiter(mp_call_method_n_kw(0, 0, dest));
mp_obj_t item;
while ((item = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) { mp_obj_t key;
while ((key = mp_iternext(iterable)) != MP_OBJ_STOP_ITERATION) {
// expand size of args array if needed
if (args2_len + 1 >= args2_alloc) { if (args2_len + 1 >= args2_alloc) {
uint new_alloc = args2_alloc * 2; uint new_alloc = args2_alloc * 2;
if (new_alloc < 4) { if (new_alloc < 4) {
@ -729,15 +734,20 @@ void mp_call_prepare_args_n_kw_var(bool have_self, mp_uint_t n_args_n_kw, const
args2 = m_renew(mp_obj_t, args2, args2_alloc, new_alloc); args2 = m_renew(mp_obj_t, args2, args2_alloc, new_alloc);
args2_alloc = new_alloc; args2_alloc = new_alloc;
} }
mp_obj_t *items;
mp_obj_get_array_fixed_n(item, 2, &items);
// the key must be a qstr, so intern it if it's a string // the key must be a qstr, so intern it if it's a string
mp_obj_t key = items[0];
if (MP_OBJ_IS_TYPE(key, &mp_type_str)) { if (MP_OBJ_IS_TYPE(key, &mp_type_str)) {
key = mp_obj_str_intern(key); key = mp_obj_str_intern(key);
} }
// get the value corresponding to the key
mp_load_method(kw_dict, MP_QSTR___getitem__, dest);
dest[2] = key;
mp_obj_t value = mp_call_method_n_kw(1, 0, dest);
// store the key/value pair in the argument array
args2[args2_len++] = key; args2[args2_len++] = key;
args2[args2_len++] = items[1]; args2[args2_len++] = value;
} }
} }

View File

@ -0,0 +1,16 @@
# test passing a user-defined mapping as the argument to **
def foo(**kw):
print(sorted(kw.items()))
class Mapping:
def keys(self):
return ['a', 'b', 'c']
def __getitem__(self, key):
if key == 'a':
return 1
else:
return 2
foo(**Mapping())