diff --git a/py/builtin.c b/py/builtin.c index 2b94163f16..d29a2bf8c3 100644 --- a/py/builtin.c +++ b/py/builtin.c @@ -16,30 +16,20 @@ mp_obj_t mp_builtin___build_class__(mp_obj_t o_class_fun, mp_obj_t o_class_name) { // we differ from CPython: we set the new __locals__ object here - mp_map_t *old_locals = rt_get_map_locals(); + mp_map_t *old_locals = rt_locals_get(); mp_map_t *class_locals = mp_map_new(MP_MAP_QSTR, 0); - rt_set_map_locals(class_locals); + rt_locals_set(class_locals); // call the class code rt_call_function_1(o_class_fun, (mp_obj_t)0xdeadbeef); // restore old __locals__ object - rt_set_map_locals(old_locals); + rt_locals_set(old_locals); // create and return the new class return mp_obj_new_class(class_locals); } -mp_obj_t mp_builtin___import__(int n, mp_obj_t *args) { - printf("import:\n"); - for (int i = 0; i < n; i++) { - printf(" "); - mp_obj_print(args[i]); - printf("\n"); - } - return mp_const_none; -} - mp_obj_t mp_builtin___repl_print__(mp_obj_t o) { if (o != mp_const_none) { mp_obj_print(o); diff --git a/py/builtinimport.c b/py/builtinimport.c new file mode 100644 index 0000000000..f1479ab123 --- /dev/null +++ b/py/builtinimport.c @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include +#include + +#include "nlr.h" +#include "misc.h" +#include "mpconfig.h" +#include "lexer.h" +#include "lexerunix.h" +#include "parse.h" +#include "compile.h" +#include "obj.h" +#include "runtime0.h" +#include "runtime.h" +#include "map.h" +#include "builtin.h" + +mp_obj_t mp_builtin___import__(int n, mp_obj_t *args) { + /* + printf("import:\n"); + for (int i = 0; i < n; i++) { + printf(" "); + mp_obj_print(args[i]); + printf("\n"); + } + */ + + // find the file to import + qstr mod_name = mp_obj_get_qstr(args[0]); + mp_lexer_t *lex = mp_import_open_file(mod_name); + if (lex == NULL) { + // TODO handle lexer error correctly + return mp_const_none; + } + + // create a new module object + mp_obj_t module_obj = mp_obj_new_module(mp_obj_get_qstr(args[0])); + + // save the old context + mp_map_t *old_locals = rt_locals_get(); + mp_map_t *old_globals = rt_globals_get(); + + // set the new context + rt_locals_set(mp_obj_module_get_globals(module_obj)); + rt_globals_set(mp_obj_module_get_globals(module_obj)); + + // parse the imported script + mp_parse_node_t pn = mp_parse(lex, MP_PARSE_FILE_INPUT); + mp_lexer_free(lex); + + if (pn == MP_PARSE_NODE_NULL) { + // TODO handle parse error correctly + rt_locals_set(old_locals); + rt_globals_set(old_globals); + return mp_const_none; + } + + if (!mp_compile(pn, false)) { + // TODO handle compile error correctly + rt_locals_set(old_locals); + rt_globals_set(old_globals); + return mp_const_none; + } + + // complied successfully, execute it + mp_obj_t module_fun = rt_make_function_from_id(1); // TODO we should return from mp_compile the unique_code_id for the module + nlr_buf_t nlr; + if (nlr_push(&nlr) == 0) { + rt_call_function_0(module_fun); + nlr_pop(); + } else { + // exception; restore context and re-raise same exception + rt_locals_set(old_locals); + rt_globals_set(old_globals); + nlr_jump(nlr.ret_val); + } + rt_locals_set(old_locals); + rt_globals_set(old_globals); + + return module_obj; +} diff --git a/py/lexer.h b/py/lexer.h index f58a38e92b..27244fde96 100644 --- a/py/lexer.h +++ b/py/lexer.h @@ -138,3 +138,6 @@ bool mp_lexer_opt_str(mp_lexer_t *lex, const char *str); */ bool mp_lexer_show_error(mp_lexer_t *lex, const char *msg); bool mp_lexer_show_error_pythonic(mp_lexer_t *lex, const char *msg); + +// used to import a module; must be implemented for a specific port +mp_lexer_t *mp_import_open_file(qstr mod_name); diff --git a/py/lexerunix.c b/py/lexerunix.c index 398cb792a7..14c28c16d9 100644 --- a/py/lexerunix.c +++ b/py/lexerunix.c @@ -58,3 +58,23 @@ mp_lexer_t *mp_lexer_new_from_file(const char *filename) { return mp_lexer_new_from_str_len(filename, data, size, true); } + +/******************************************************************************/ +/* unix implementation of import */ + +// TODO properly! + +static const char *import_base_dir = NULL; + +void mp_import_set_directory(const char *dir) { + import_base_dir = dir; +} + +mp_lexer_t *mp_import_open_file(qstr mod_name) { + vstr_t *vstr = vstr_new(); + if (import_base_dir != NULL) { + vstr_printf(vstr, "%s/", import_base_dir); + } + vstr_printf(vstr, "%s.py", qstr_str(mod_name)); + return mp_lexer_new_from_file(vstr_str(vstr)); // TODO does lexer need to copy the string? can we free it here? +} diff --git a/py/lexerunix.h b/py/lexerunix.h index d86f202d53..b422a43062 100644 --- a/py/lexerunix.h +++ b/py/lexerunix.h @@ -1,2 +1,4 @@ mp_lexer_t *mp_lexer_new_from_str_len(const char *src_name, const char *str, uint len, bool free_str); mp_lexer_t *mp_lexer_new_from_file(const char *filename); + +void mp_import_set_directory(const char *dir); diff --git a/py/map.h b/py/map.h index 8ee8429b52..f8ca886aa4 100644 --- a/py/map.h +++ b/py/map.h @@ -23,10 +23,6 @@ typedef struct _mp_set_t { mp_obj_t *table; } mp_set_t; -// these are defined in runtime.c -mp_map_t *rt_get_map_locals(void); -void rt_set_map_locals(mp_map_t *m); - int get_doubling_prime_greater_or_equal_to(int x); void mp_map_init(mp_map_t *map, mp_map_kind_t kind, int n); mp_map_t *mp_map_new(mp_map_kind_t kind, int n); diff --git a/py/obj.h b/py/obj.h index 6a0cefd915..7b4b0656f2 100644 --- a/py/obj.h +++ b/py/obj.h @@ -215,11 +215,14 @@ mp_obj_t mp_obj_dict_store(mp_obj_t self_in, mp_obj_t key, mp_obj_t value); void mp_obj_set_store(mp_obj_t self_in, mp_obj_t item); // functions -typedef struct _mp_obj_fun_native_t { // need this so we can define static objects +typedef struct _mp_obj_fun_native_t { // need this so we can define const objects (to go in ROM) mp_obj_base_t base; machine_uint_t n_args_min; // inclusive machine_uint_t n_args_max; // inclusive void *fun; + // TODO add mp_map_t *globals + // for const function objects, make an empty, const map + // such functions won't be able to access the global scope, but that's probably okay } mp_obj_fun_native_t; extern const mp_obj_type_t fun_native_type; extern const mp_obj_type_t fun_bc_type; diff --git a/py/objfun.c b/py/objfun.c index cefc9a95fe..e998bd28d2 100644 --- a/py/objfun.c +++ b/py/objfun.c @@ -7,6 +7,7 @@ #include "misc.h" #include "mpconfig.h" #include "obj.h" +#include "map.h" #include "runtime.h" #include "bc.h" @@ -129,9 +130,10 @@ mp_obj_t rt_make_function_var_between(int n_args_min, int n_args_max, mp_fun_var typedef struct _mp_obj_fun_bc_t { mp_obj_base_t base; - int n_args; - uint n_state; - const byte *code; + mp_map_t *globals; // the context within which this function was defined + int n_args; // number of arguments this function takes + uint n_state; // total state size for the executing function (incl args, locals, stack) + const byte *bytecode; // bytecode for the function } mp_obj_fun_bc_t; // args are in reverse order in the array @@ -142,15 +144,17 @@ mp_obj_t fun_bc_call_n(mp_obj_t self_in, int n_args, const mp_obj_t *args) { nlr_jump(mp_obj_new_exception_msg_2_args(rt_q_TypeError, "function takes %d positional arguments but %d were given", (const char*)(machine_int_t)self->n_args, (const char*)(machine_int_t)n_args)); } - return mp_execute_byte_code(self->code, args, n_args, self->n_state); -} - -void mp_obj_fun_bc_get(mp_obj_t self_in, int *n_args, uint *n_state, const byte **code) { - assert(MP_OBJ_IS_TYPE(self_in, &fun_bc_type)); - mp_obj_fun_bc_t *self = self_in; - *n_args = self->n_args; - *n_state = self->n_state; - *code = self->code; + // optimisation: allow the compiler to optimise this tail call for + // the common case when the globals don't need to be changed + mp_map_t *old_globals = rt_globals_get(); + if (self->globals == old_globals) { + return mp_execute_byte_code(self->bytecode, args, n_args, self->n_state); + } else { + rt_globals_set(self->globals); + mp_obj_t result = mp_execute_byte_code(self->bytecode, args, n_args, self->n_state); + rt_globals_set(old_globals); + return result; + } } const mp_obj_type_t fun_bc_type = { @@ -170,12 +174,21 @@ const mp_obj_type_t fun_bc_type = { mp_obj_t mp_obj_new_fun_bc(int n_args, uint n_state, const byte *code) { mp_obj_fun_bc_t *o = m_new_obj(mp_obj_fun_bc_t); o->base.type = &fun_bc_type; + o->globals = rt_globals_get(); o->n_args = n_args; o->n_state = n_state; - o->code = code; + o->bytecode = code; return o; } +void mp_obj_fun_bc_get(mp_obj_t self_in, int *n_args, uint *n_state, const byte **code) { + assert(MP_OBJ_IS_TYPE(self_in, &fun_bc_type)); + mp_obj_fun_bc_t *self = self_in; + *n_args = self->n_args; + *n_state = self->n_state; + *code = self->bytecode; +} + /******************************************************************************/ /* inline assembler functions */ diff --git a/py/runtime.c b/py/runtime.c index 3fae61f6fe..a8e55467bd 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -281,14 +281,6 @@ void rt_assign_inline_asm_code(int unique_code_id, void *fun, uint len, int n_ar #endif } -mp_map_t *rt_get_map_locals(void) { - return map_locals; -} - -void rt_set_map_locals(mp_map_t *m) { - map_locals = m; -} - static bool fit_small_int(mp_small_int_t o) { return true; } @@ -786,6 +778,7 @@ mp_obj_t rt_load_attr(mp_obj_t base, qstr attr) { } else if (MP_OBJ_IS_TYPE(base, &instance_type)) { return mp_obj_instance_load_attr(base, attr); } else if (MP_OBJ_IS_TYPE(base, &module_type)) { + DEBUG_OP_printf("lookup module map %p\n", mp_obj_module_get_globals(base)); mp_map_elem_t *elem = mp_qstr_map_lookup(mp_obj_module_get_globals(base), attr, false); if (elem == NULL) { // TODO what about generic method lookup? @@ -913,6 +906,24 @@ mp_obj_t rt_import_from(mp_obj_t module, qstr name) { return x; } +mp_map_t *rt_locals_get(void) { + return map_locals; +} + +void rt_locals_set(mp_map_t *m) { + DEBUG_OP_printf("rt_locals_set(%p)\n", m); + map_locals = m; +} + +mp_map_t *rt_globals_get(void) { + return map_globals; +} + +void rt_globals_set(mp_map_t *m) { + DEBUG_OP_printf("rt_globals_set(%p)\n", m); + map_globals = m; +} + // these must correspond to the respective enum void *const rt_fun_table[RT_F_NUMBER_OF] = { rt_load_const_dec, diff --git a/py/runtime.h b/py/runtime.h index 37b036852f..cf9180275e 100644 --- a/py/runtime.h +++ b/py/runtime.h @@ -57,3 +57,9 @@ mp_obj_t rt_getiter(mp_obj_t o); mp_obj_t rt_iternext(mp_obj_t o); mp_obj_t rt_import_name(qstr name, mp_obj_t fromlist, mp_obj_t level); mp_obj_t rt_import_from(mp_obj_t module, qstr name); + +struct _mp_map_t; +struct _mp_map_t *rt_locals_get(void); +void rt_locals_set(struct _mp_map_t *m); +struct _mp_map_t *rt_globals_get(void); +void rt_globals_set(struct _mp_map_t *m); diff --git a/py/showbc.c b/py/showbc.c index 15cd056427..a3bfa2833b 100644 --- a/py/showbc.c +++ b/py/showbc.c @@ -142,12 +142,10 @@ void mp_show_byte_code(const byte *ip, int len) { printf("STORE_NAME %s", qstr_str(qstr)); break; - /* case MP_BC_STORE_GLOBAL: DECODE_QSTR; - rt_store_global(qstr, POP()); + printf("STORE_GLOBAL %s", qstr_str(qstr)); break; - */ case MP_BC_STORE_ATTR: DECODE_QSTR; @@ -343,6 +341,16 @@ void mp_show_byte_code(const byte *ip, int len) { printf("YIELD_VALUE"); break; + case MP_BC_IMPORT_NAME: + DECODE_QSTR; + printf("IMPORT NAME %s", qstr_str(qstr)); + break; + + case MP_BC_IMPORT_FROM: + DECODE_QSTR; + printf("IMPORT NAME %s", qstr_str(qstr)); + break; + default: printf("code %p, byte code 0x%02x not implemented\n", ip, op); assert(0); diff --git a/py/vstr.c b/py/vstr.c index 98cf027250..80841b24ca 100644 --- a/py/vstr.c +++ b/py/vstr.c @@ -167,8 +167,12 @@ void vstr_vprintf(vstr_t *vstr, const char *fmt, va_list ap) { while (1) { // try to print in the allocated space + // need to make a copy of the va_list because we may call vsnprintf multiple times int size = vstr->alloc - vstr->len; - int n = vsnprintf(vstr->buf + vstr->len, size, fmt, ap); + va_list ap2; + va_copy(ap2, ap); + int n = vsnprintf(vstr->buf + vstr->len, size, fmt, ap2); + va_end(ap2); // if that worked, return if (n > -1 && n < size) { diff --git a/stm/Makefile b/stm/Makefile index 6868f85ba7..be4ca8b3a0 100644 --- a/stm/Makefile +++ b/stm/Makefile @@ -82,6 +82,7 @@ PY_O = \ objtuple.o \ objtype.o \ builtin.o \ + builtinimport.o \ vm.o \ repl.o \ diff --git a/stm/lexerstm.c b/stm/lexerstm.c index dfb84cca13..661dfb0160 100644 --- a/stm/lexerstm.c +++ b/stm/lexerstm.c @@ -61,3 +61,8 @@ mp_lexer_t *mp_lexer_new_from_file(const char *filename, mp_lexer_file_buf_t *fb fb->pos = 0; return mp_lexer_new(filename, fb, (mp_lexer_stream_next_char_t)file_buf_next_char, (mp_lexer_stream_close_t)file_buf_close); } + +mp_lexer_t *mp_import_open_file(qstr mod_name) { + printf("import not implemented\n"); + return NULL; +} diff --git a/tests/basics/run-tests b/tests/basics/run-tests index 1b027c3e9f..72e69c2d8e 100755 --- a/tests/basics/run-tests +++ b/tests/basics/run-tests @@ -11,7 +11,7 @@ namefailed= for infile in tests/*.py do - basename=`basename $infile .c` + basename=`basename $infile .py` outfile=${basename}.out expfile=${basename}.exp diff --git a/tests/basics/tests/import1a.py b/tests/basics/tests/import1a.py new file mode 100644 index 0000000000..16b2d4d30f --- /dev/null +++ b/tests/basics/tests/import1a.py @@ -0,0 +1,2 @@ +import import1b +print(import1b.var) diff --git a/tests/basics/tests/import1b.py b/tests/basics/tests/import1b.py new file mode 100644 index 0000000000..80479088f0 --- /dev/null +++ b/tests/basics/tests/import1b.py @@ -0,0 +1 @@ +var = 123 diff --git a/unix-cpy/Makefile b/unix-cpy/Makefile index 9399a765c3..0f20fe31ce 100644 --- a/unix-cpy/Makefile +++ b/unix-cpy/Makefile @@ -47,6 +47,7 @@ PY_O = \ objtuple.o \ objtype.o \ builtin.o \ + builtinimport.o \ vm.o \ showbc.o \ repl.o \ diff --git a/unix/Makefile b/unix/Makefile index b8955d11a8..271cf22654 100644 --- a/unix/Makefile +++ b/unix/Makefile @@ -54,6 +54,7 @@ PY_O = \ objtuple.o \ objtype.o \ builtin.o \ + builtinimport.o \ vm.o \ showbc.o \ repl.o \ diff --git a/unix/main.c b/unix/main.c index 376dbc0c04..c23a8e54c4 100644 --- a/unix/main.c +++ b/unix/main.c @@ -105,6 +105,18 @@ static void do_repl(void) { } void do_file(const char *file) { + // hack: set dir for import based on where this file is + { + const char * s = strrchr(file, '/'); + if (s != NULL) { + int len = s - file; + char *dir = m_new(char, len + 1); + memcpy(dir, file, len); + dir[len] = '\0'; + mp_import_set_directory(dir); + } + } + mp_lexer_t *lex = mp_lexer_new_from_file(file); //const char *pysrc = "def f():\n x=x+1\n print(42)\n"; //mp_lexer_t *lex = mp_lexer_from_str_len("<>", pysrc, strlen(pysrc), false);