diff --git a/py/modstruct.c b/py/modstruct.c index b53dd51836..46d0ad004d 100644 --- a/py/modstruct.c +++ b/py/modstruct.c @@ -34,6 +34,7 @@ #include "objtuple.h" #include "objstr.h" #include "binary.h" +#include "parsenum.h" #if MICROPY_ENABLE_MOD_STRUCT @@ -56,9 +57,26 @@ STATIC char get_fmt_type(const char **fmt) { return t; } +STATIC machine_uint_t get_fmt_num(const char **p) { + const char *num = *p; + uint len = 1; + while (unichar_isdigit(*++num)) { + len++; + } + machine_uint_t val = (machine_uint_t)MP_OBJ_SMALL_INT_VALUE(mp_parse_num_integer(*p, len, 10)); + *p = num; + return val; +} + STATIC uint calcsize_items(const char *fmt) { - // TODO - return strlen(fmt); + uint cnt = 0; + while (*fmt) { + // TODO supports size spec only for "s" + if (!unichar_isdigit(*fmt++)) { + cnt++; + } + } + return cnt; } STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) { @@ -67,9 +85,23 @@ STATIC mp_obj_t struct_calcsize(mp_obj_t fmt_in) { machine_uint_t size; for (size = 0; *fmt; fmt++) { uint align; - int sz = mp_binary_get_size(fmt_type, *fmt, &align); + machine_uint_t cnt = 1; + if (unichar_isdigit(*fmt)) { + cnt = get_fmt_num(&fmt); + } + if (cnt > 1) { + // TODO: count spec support only for string len + assert(*fmt == 's'); + } + + machine_uint_t sz; + if (*fmt == 's') { + sz = cnt; + } else { + sz = (machine_uint_t)mp_binary_get_size(fmt_type, *fmt, &align); + } // TODO - assert(sz != -1); + assert(sz != (machine_uint_t)-1); // Apply alignment size = (size + align - 1) & ~(align - 1); size += sz; @@ -89,7 +121,22 @@ STATIC mp_obj_t struct_unpack(mp_obj_t fmt_in, mp_obj_t data_in) { byte *p = bufinfo.buf; for (uint i = 0; i < size; i++) { - mp_obj_t item = mp_binary_get_val(fmt_type, *fmt++, &p); + machine_uint_t sz = 1; + if (unichar_isdigit(*fmt)) { + sz = get_fmt_num(&fmt); + } + if (sz > 1) { + // TODO: size spec support only for string len + assert(*fmt == 's'); + } + mp_obj_t item; + if (*fmt == 's') { + item = mp_obj_new_bytes(p, sz); + p += sz; + fmt++; + } else { + item = mp_binary_get_val(fmt_type, *fmt++, &p); + } res->items[i] = item; } return res; @@ -106,7 +153,29 @@ STATIC mp_obj_t struct_pack(uint n_args, mp_obj_t *args) { memset(p, 0, size); for (uint i = 1; i < n_args; i++) { - mp_binary_set_val(fmt_type, *fmt++, args[i], &p); + machine_uint_t sz = 1; + if (unichar_isdigit(*fmt)) { + sz = get_fmt_num(&fmt); + } + if (sz > 1) { + // TODO: size spec support only for string len + assert(*fmt == 's'); + } + + if (*fmt == 's') { + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(args[i], &bufinfo, MP_BUFFER_READ); + machine_uint_t to_copy = sz; + if (bufinfo.len < to_copy) { + to_copy = bufinfo.len; + } + memcpy(p, bufinfo.buf, to_copy); + memset(p + to_copy, 0, sz - to_copy); + p += sz; + fmt++; + } else { + mp_binary_set_val(fmt_type, *fmt++, args[i], &p); + } } return res; } diff --git a/tests/basics/struct1.py b/tests/basics/struct1.py index 3a05c85f0b..b114a789b5 100644 --- a/tests/basics/struct1.py +++ b/tests/basics/struct1.py @@ -16,3 +16,8 @@ print(struct.pack(">b", 1)) print(struct.pack("bI", -128, 256)) + +print(struct.calcsize("100sI")) +print(struct.calcsize("97sI")) +print(struct.unpack("<6sH", b"foo\0\0\0\x12\x34")) +print(struct.pack("<6sH", b"foo", 10000))