diff --git a/py/builtinimport.c b/py/builtinimport.c index e197dc7832..342bad079b 100644 --- a/py/builtinimport.c +++ b/py/builtinimport.c @@ -73,15 +73,8 @@ STATIC mp_import_stat_t mp_import_stat_any(const char *path) { return mp_import_stat(path); } -STATIC mp_import_stat_t stat_dir_or_file(vstr_t *path) { +STATIC mp_import_stat_t stat_file_py_or_mpy(vstr_t *path) { mp_import_stat_t stat = mp_import_stat_any(vstr_null_terminated_str(path)); - DEBUG_printf("stat %s: %d\n", vstr_str(path), stat); - if (stat == MP_IMPORT_STAT_DIR) { - return stat; - } - - vstr_add_str(path, ".py"); - stat = mp_import_stat_any(vstr_null_terminated_str(path)); if (stat == MP_IMPORT_STAT_FILE) { return stat; } @@ -97,6 +90,18 @@ STATIC mp_import_stat_t stat_dir_or_file(vstr_t *path) { return MP_IMPORT_STAT_NO_EXIST; } +STATIC mp_import_stat_t stat_dir_or_file(vstr_t *path) { + mp_import_stat_t stat = mp_import_stat_any(vstr_null_terminated_str(path)); + DEBUG_printf("stat %s: %d\n", vstr_str(path), stat); + if (stat == MP_IMPORT_STAT_DIR) { + return stat; + } + + // not a directory, add .py and try as a file + vstr_add_str(path, ".py"); + return stat_file_py_or_mpy(path); +} + STATIC mp_import_stat_t find_file(const char *file_str, uint file_len, vstr_t *dest) { #if MICROPY_PY_SYS // extract the list of paths @@ -463,7 +468,7 @@ mp_obj_t mp_builtin___import__(size_t n_args, const mp_obj_t *args) { mp_store_attr(module_obj, MP_QSTR___path__, mp_obj_new_str(vstr_str(&path), vstr_len(&path), false)); vstr_add_char(&path, PATH_SEP_CHAR); vstr_add_str(&path, "__init__.py"); - if (mp_import_stat_any(vstr_null_terminated_str(&path)) != MP_IMPORT_STAT_FILE) { + if (stat_file_py_or_mpy(&path) != MP_IMPORT_STAT_FILE) { vstr_cut_tail_bytes(&path, sizeof("/__init__.py") - 1); // cut off /__init__.py //mp_warning("%s is imported as namespace package", vstr_str(&path)); } else {