From 9d68e9ccdd4d7f4ecb7a8765ca694e355753d686 Mon Sep 17 00:00:00 2001 From: Damien George <damien.p.george@gmail.com> Date: Wed, 12 Mar 2014 15:38:15 +0000 Subject: [PATCH] py: Implement integer overflow checking for * and << ops. If operation will overflow, a multi-precision integer is created. --- py/obj.h | 6 +- py/objfloat.c | 2 - py/objint.c | 6 ++ py/objint_mpz.c | 2 +- py/runtime.c | 147 +++++++++++++++++++++++++++++++++++++++--------- 5 files changed, 131 insertions(+), 32 deletions(-) diff --git a/py/obj.h b/py/obj.h index 4d93c7afad..d41db37c08 100644 --- a/py/obj.h +++ b/py/obj.h @@ -29,6 +29,8 @@ typedef struct _mp_obj_base_t mp_obj_base_t; // - xxxx...xx00: a pointer to an mp_obj_base_t // In SMALL_INT, next-to-highest bits is used as sign, so both must match for value in range +#define MP_SMALL_INT_MIN ((mp_small_int_t)(((machine_int_t)WORD_MSBIT_HIGH) >> 1)) +#define MP_SMALL_INT_MAX ((mp_small_int_t)(~(MP_SMALL_INT_MIN))) #define MP_OBJ_FITS_SMALL_INT(n) ((((n) ^ ((n) << 1)) & WORD_MSBIT_HIGH) == 0) #define MP_OBJ_IS_SMALL_INT(o) ((((mp_small_int_t)(o)) & 1) != 0) #define MP_OBJ_IS_QSTR(o) ((((mp_small_int_t)(o)) & 3) == 2) @@ -218,9 +220,7 @@ mp_obj_t mp_obj_new_cell(mp_obj_t obj); mp_obj_t mp_obj_new_int(machine_int_t value); mp_obj_t mp_obj_new_int_from_uint(machine_uint_t value); mp_obj_t mp_obj_new_int_from_long_str(const char *s); -#if MICROPY_LONGINT_IMPL != MICROPY_LONGINT_IMPL_NONE -mp_obj_t mp_obj_new_int_from_ll(long long val); -#endif +mp_obj_t mp_obj_new_int_from_ll(long long val); // this must return a multi-precision integer object (or raise an overflow exception) mp_obj_t mp_obj_new_str(const byte* data, uint len, bool make_qstr_if_not_already); mp_obj_t mp_obj_new_bytes(const byte* data, uint len); #if MICROPY_ENABLE_FLOAT diff --git a/py/objfloat.c b/py/objfloat.c index 91d669ad58..04d1278014 100644 --- a/py/objfloat.c +++ b/py/objfloat.c @@ -17,8 +17,6 @@ #include "formatfloat.h" #endif -mp_obj_t mp_obj_new_float(mp_float_t value); - STATIC void float_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj_t o_in, mp_print_kind_t kind) { mp_obj_float_t *o = o_in; #if MICROPY_FLOAT_IMPL == MICROPY_FLOAT_IMPL_FLOAT diff --git a/py/objint.c b/py/objint.c index 490b4340bb..7a9b0366db 100644 --- a/py/objint.c +++ b/py/objint.c @@ -71,6 +71,12 @@ mp_obj_t mp_obj_new_int_from_long_str(const char *s) { return mp_const_none; } +// This is called when an integer larger than a SMALL_INT is needed (although val might still fit in a SMALL_INT) +mp_obj_t mp_obj_new_int_from_ll(long long val) { + nlr_jump(mp_obj_new_exception_msg(&mp_type_OverflowError, "small int overflow")); + return mp_const_none; +} + mp_obj_t mp_obj_new_int_from_uint(machine_uint_t value) { // SMALL_INT accepts only signed numbers, of one bit less size // then word size, which totals 2 bits less for unsigned numbers. diff --git a/py/objint_mpz.c b/py/objint_mpz.c index 5cd4fb7bac..e8e8b85472 100644 --- a/py/objint_mpz.c +++ b/py/objint_mpz.c @@ -161,7 +161,7 @@ mp_obj_t mp_obj_new_int(machine_int_t value) { mp_obj_t mp_obj_new_int_from_ll(long long val) { mp_obj_int_t *o = mp_obj_int_new_mpz(); - mpz_set_from_int(&o->mpz, val); + mpz_set_from_ll(&o->mpz, val); return o; } diff --git a/py/runtime.c b/py/runtime.c index 31cbb660ad..bd6f2289de 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -455,16 +455,23 @@ mp_obj_t rt_unary_op(int op, mp_obj_t arg) { if (MP_OBJ_IS_SMALL_INT(arg)) { mp_small_int_t val = MP_OBJ_SMALL_INT_VALUE(arg); switch (op) { - case RT_UNARY_OP_BOOL: return MP_BOOL(val != 0); - case RT_UNARY_OP_POSITIVE: break; - case RT_UNARY_OP_NEGATIVE: val = -val; break; - case RT_UNARY_OP_INVERT: val = ~val; break; - default: assert(0); val = 0; + case RT_UNARY_OP_BOOL: + return MP_BOOL(val != 0); + case RT_UNARY_OP_POSITIVE: + return arg; + case RT_UNARY_OP_NEGATIVE: + // check for overflow + if (val == MP_SMALL_INT_MIN) { + return mp_obj_new_int(-val); + } else { + return MP_OBJ_NEW_SMALL_INT(-val); + } + case RT_UNARY_OP_INVERT: + return MP_OBJ_NEW_SMALL_INT(~val); + default: + assert(0); + return arg; } - if (MP_OBJ_FITS_SMALL_INT(val)) { - return MP_OBJ_NEW_SMALL_INT(val); - } - return mp_obj_new_int(val); } else { mp_obj_type_t *type = mp_obj_get_type(arg); if (type->unary_op != NULL) { @@ -532,6 +539,15 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { mp_small_int_t lhs_val = MP_OBJ_SMALL_INT_VALUE(lhs); if (MP_OBJ_IS_SMALL_INT(rhs)) { mp_small_int_t rhs_val = MP_OBJ_SMALL_INT_VALUE(rhs); + // This is a binary operation: lhs_val op rhs_val + // We need to be careful to handle overflow; see CERT INT32-C + // Operations that can overflow: + // + result always fits in machine_int_t, then handled by SMALL_INT check + // - result always fits in machine_int_t, then handled by SMALL_INT check + // * checked explicitly + // / if lhs=MIN and rhs=-1; result always fits in machine_int_t, then handled by SMALL_INT check + // % if lhs=MIN and rhs=-1; result always fits in machine_int_t, then handled by SMALL_INT check + // << checked explicitly switch (op) { case RT_BINARY_OP_OR: case RT_BINARY_OP_INPLACE_OR: lhs_val |= rhs_val; break; @@ -540,41 +556,117 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { case RT_BINARY_OP_AND: case RT_BINARY_OP_INPLACE_AND: lhs_val &= rhs_val; break; case RT_BINARY_OP_LSHIFT: - case RT_BINARY_OP_INPLACE_LSHIFT: lhs_val <<= rhs_val; break; + case RT_BINARY_OP_INPLACE_LSHIFT: { + if (rhs_val < 0) { + // negative shift not allowed + nlr_jump(mp_obj_new_exception_msg(&mp_type_ValueError, "negative shift count")); + } else if (rhs_val >= BITS_PER_WORD || lhs_val > (MP_SMALL_INT_MAX >> rhs_val) || lhs_val < (MP_SMALL_INT_MIN >> rhs_val)) { + // left-shift will overflow, so use higher precision integer + lhs = mp_obj_new_int_from_ll(lhs_val); + goto generic_binary_op; + } else { + // use standard precision + lhs_val <<= rhs_val; + } + break; + } case RT_BINARY_OP_RSHIFT: - case RT_BINARY_OP_INPLACE_RSHIFT: lhs_val >>= rhs_val; break; + case RT_BINARY_OP_INPLACE_RSHIFT: + if (rhs_val < 0) { + // negative shift not allowed + nlr_jump(mp_obj_new_exception_msg(&mp_type_ValueError, "negative shift count")); + } else { + // standard precision is enough for right-shift + lhs_val >>= rhs_val; + } + break; case RT_BINARY_OP_ADD: case RT_BINARY_OP_INPLACE_ADD: lhs_val += rhs_val; break; case RT_BINARY_OP_SUBTRACT: case RT_BINARY_OP_INPLACE_SUBTRACT: lhs_val -= rhs_val; break; case RT_BINARY_OP_MULTIPLY: - case RT_BINARY_OP_INPLACE_MULTIPLY: lhs_val *= rhs_val; break; + case RT_BINARY_OP_INPLACE_MULTIPLY: { + + // If long long type exists and is larger than machine_int_t, then + // we can use the following code to perform overflow-checked multiplication. + // Otherwise (eg in x64 case) we must use the branching code below. + #if 0 + // compute result using long long precision + long long res = (long long)lhs_val * (long long)rhs_val; + if (res > MP_SMALL_INT_MAX || res < MP_SMALL_INT_MIN) { + // result overflowed SMALL_INT, so return higher precision integer + return mp_obj_new_int_from_ll(res); + } else { + // use standard precision + lhs_val = (mp_small_int_t)res; + } + #endif + + if (lhs_val > 0) { // lhs_val is positive + if (rhs_val > 0) { // lhs_val and rhs_val are positive + if (lhs_val > (MP_SMALL_INT_MAX / rhs_val)) { + goto mul_overflow; + } + } else { // lhs_val positive, rhs_val nonpositive + if (rhs_val < (MP_SMALL_INT_MIN / lhs_val)) { + goto mul_overflow; + } + } // lhs_val positive, rhs_val nonpositive + } else { // lhs_val is nonpositive + if (rhs_val > 0) { // lhs_val is nonpositive, rhs_val is positive + if (lhs_val < (MP_SMALL_INT_MIN / rhs_val)) { + goto mul_overflow; + } + } else { // lhs_val and rhs_val are nonpositive + if (lhs_val != 0 && rhs_val < (MP_SMALL_INT_MAX / lhs_val)) { + goto mul_overflow; + } + } // End if lhs_val and rhs_val are nonpositive + } // End if lhs_val is nonpositive + + // use standard precision + return MP_OBJ_NEW_SMALL_INT(lhs_val * rhs_val); + + mul_overflow: + // use higher precision + lhs = mp_obj_new_int_from_ll(lhs_val); + goto generic_binary_op; + + break; + } case RT_BINARY_OP_FLOOR_DIVIDE: case RT_BINARY_OP_INPLACE_FLOOR_DIVIDE: lhs_val /= rhs_val; break; - #if MICROPY_ENABLE_FLOAT + #if MICROPY_ENABLE_FLOAT case RT_BINARY_OP_TRUE_DIVIDE: case RT_BINARY_OP_INPLACE_TRUE_DIVIDE: return mp_obj_new_float((mp_float_t)lhs_val / (mp_float_t)rhs_val); - #endif + #endif // TODO implement modulo as specified by Python case RT_BINARY_OP_MODULO: case RT_BINARY_OP_INPLACE_MODULO: lhs_val %= rhs_val; break; - // TODO check for negative power, and overflow case RT_BINARY_OP_POWER: case RT_BINARY_OP_INPLACE_POWER: - { - int ans = 1; - while (rhs_val > 0) { - if (rhs_val & 1) { - ans *= lhs_val; + if (rhs_val < 0) { + #if MICROPY_ENABLE_FLOAT + lhs = mp_obj_new_float(lhs_val); + goto generic_binary_op; + #else + nlr_jump(mp_obj_new_exception_msg(&mp_type_ValueError, "negative power with no float support")); + #endif + } else { + // TODO check for overflow + machine_int_t ans = 1; + while (rhs_val > 0) { + if (rhs_val & 1) { + ans *= lhs_val; + } + lhs_val *= lhs_val; + rhs_val /= 2; } - lhs_val *= lhs_val; - rhs_val /= 2; + lhs_val = ans; } - lhs_val = ans; break; - } case RT_BINARY_OP_LESS: return MP_BOOL(lhs_val < rhs_val); break; case RT_BINARY_OP_MORE: return MP_BOOL(lhs_val > rhs_val); break; case RT_BINARY_OP_LESS_EQUAL: return MP_BOOL(lhs_val <= rhs_val); break; @@ -585,8 +677,9 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { // TODO: We just should make mp_obj_new_int() inline and use that if (MP_OBJ_FITS_SMALL_INT(lhs_val)) { return MP_OBJ_NEW_SMALL_INT(lhs_val); + } else { + return mp_obj_new_int(lhs_val); } - return mp_obj_new_int(lhs_val); #if MICROPY_ENABLE_FLOAT } else if (MP_OBJ_IS_TYPE(rhs, &mp_type_float)) { return mp_obj_float_binary_op(op, lhs_val, rhs); @@ -628,7 +721,9 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { } // generic binary_op supplied by type - mp_obj_type_t *type = mp_obj_get_type(lhs); + mp_obj_type_t *type; +generic_binary_op: + type = mp_obj_get_type(lhs); if (type->binary_op != NULL) { mp_obj_t result = type->binary_op(op, lhs, rhs); if (result != MP_OBJ_NULL) {