diff --git a/py/objtype.c b/py/objtype.c index 07aa56434a..d10d6cbd5a 100644 --- a/py/objtype.c +++ b/py/objtype.c @@ -666,6 +666,25 @@ STATIC mp_obj_t instance_getiter(mp_obj_t self_in) { return mp_call_function_n_kw(meth, 0, 0, NULL); } +STATIC mp_int_t instance_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_uint_t flags) { + mp_obj_instance_t *self = self_in; + mp_obj_t member[2] = {MP_OBJ_NULL}; + struct class_lookup_data lookup = { + .obj = self, + .attr = MP_QSTR_, // don't actually look for a method + .meth_offset = offsetof(mp_obj_type_t, buffer_p.get_buffer), + .dest = member, + .is_type = false, + }; + mp_obj_class_lookup(&lookup, self->base.type); + if (member[0] == MP_OBJ_SENTINEL) { + mp_obj_type_t *type = mp_obj_get_type(self->subobj[0]); + return type->buffer_p.get_buffer(self->subobj[0], bufinfo, flags); + } else { + return 1; // object does not support buffer protocol + } +} + /******************************************************************************/ // type object // - the struct is mp_obj_type_t and is defined in obj.h so const types can be made @@ -807,13 +826,16 @@ mp_obj_t mp_obj_new_type(qstr name, mp_obj_t bases_tuple, mp_obj_t locals_dict) o->name = name; o->print = instance_print; o->make_new = instance_make_new; + o->call = mp_obj_instance_call; o->unary_op = instance_unary_op; o->binary_op = instance_binary_op; o->load_attr = mp_obj_instance_load_attr; o->store_attr = mp_obj_instance_store_attr; o->subscr = instance_subscr; - o->call = mp_obj_instance_call; o->getiter = instance_getiter; + //o->iternext = ; not implemented + o->buffer_p.get_buffer = instance_get_buffer; + //o->stream_p = ; not implemented o->bases_tuple = bases_tuple; o->locals_dict = locals_dict; diff --git a/tests/basics/subclass_native_buffer.py b/tests/basics/subclass_native_buffer.py new file mode 100644 index 0000000000..43c3819657 --- /dev/null +++ b/tests/basics/subclass_native_buffer.py @@ -0,0 +1,16 @@ +# test when we subclass a type with the buffer protocol + +class my_bytes(bytes): + pass + +b1 = my_bytes([0, 1]) +b2 = my_bytes([2, 3]) +b3 = bytes([4, 5]) + +# addition will use the buffer protocol on the RHS +print(b1 + b2) +print(b1 + b3) +print(b3 + b1) + +# bytearray construction will use the buffer protocol +print(bytearray(b1))