py: This time, real proper overflow checking of small int power.

Previous overflow test was inadequate.
This commit is contained in:
Damien George 2014-04-04 11:13:51 +00:00
parent 6902eeda25
commit ecf5b77123
8 changed files with 83 additions and 77 deletions

View File

@ -17,7 +17,7 @@
#include "obj.h" #include "obj.h"
#include "compile.h" #include "compile.h"
#include "runtime.h" #include "runtime.h"
#include "intdivmod.h" #include "smallint.h"
// TODO need to mangle __attr names // TODO need to mangle __attr names
@ -143,10 +143,10 @@ mp_parse_node_t fold_constants(mp_parse_node_t pn) {
} else if (MP_PARSE_NODE_IS_TOKEN_KIND(pns->nodes[1], MP_TOKEN_OP_SLASH)) { } else if (MP_PARSE_NODE_IS_TOKEN_KIND(pns->nodes[1], MP_TOKEN_OP_SLASH)) {
; // pass ; // pass
} else if (MP_PARSE_NODE_IS_TOKEN_KIND(pns->nodes[1], MP_TOKEN_OP_PERCENT)) { } else if (MP_PARSE_NODE_IS_TOKEN_KIND(pns->nodes[1], MP_TOKEN_OP_PERCENT)) {
pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, python_modulo(arg0, arg1)); pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, mp_small_int_modulo(arg0, arg1));
} else if (MP_PARSE_NODE_IS_TOKEN_KIND(pns->nodes[1], MP_TOKEN_OP_DBL_SLASH)) { } else if (MP_PARSE_NODE_IS_TOKEN_KIND(pns->nodes[1], MP_TOKEN_OP_DBL_SLASH)) {
if (arg1 != 0) { if (arg1 != 0) {
pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, python_floor_divide(arg0, arg1)); pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, mp_small_int_floor_divide(arg0, arg1));
} }
} else { } else {
// shouldn't happen // shouldn't happen

View File

@ -1,24 +0,0 @@
#include "mpconfig.h"
machine_int_t python_modulo(machine_int_t dividend, machine_int_t divisor) {
machine_int_t lsign = (dividend >= 0) ? 1 :-1;
machine_int_t rsign = (divisor >= 0) ? 1 :-1;
dividend %= divisor;
if (lsign != rsign) {
dividend += divisor;
}
return dividend;
}
machine_int_t python_floor_divide(machine_int_t num, machine_int_t denom) {
machine_int_t lsign = num > 0 ? 1 : -1;
machine_int_t rsign = denom > 0 ? 1 : -1;
if (lsign == -1) {num *= -1;}
if (rsign == -1) {denom *= -1;}
if (lsign != rsign){
return - ( num + denom - 1) / denom;
} else {
return num / denom;
}
}

View File

@ -1,4 +0,0 @@
// Functions for integer modulo and floor division
machine_int_t python_modulo(machine_int_t dividend, machine_int_t divisor);
machine_int_t python_floor_divide(machine_int_t num, machine_int_t denom);

View File

@ -113,7 +113,7 @@ mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
case MP_BINARY_OP_INPLACE_FLOOR_DIVIDE: { case MP_BINARY_OP_INPLACE_FLOOR_DIVIDE: {
mpz_t rem; mpz_init_zero(&rem); mpz_t rem; mpz_init_zero(&rem);
mpz_divmod_inpl(&res->mpz, &rem, zlhs, zrhs); mpz_divmod_inpl(&res->mpz, &rem, zlhs, zrhs);
if (zlhs->neg != zrhs->neg) { if (zlhs->neg != zrhs->neg) {
if (!mpz_is_zero(&rem)) { if (!mpz_is_zero(&rem)) {
mpz_t mpzone; mpz_init_from_int(&mpzone, -1); mpz_t mpzone; mpz_init_from_int(&mpzone, -1);
mpz_add_inpl(&res->mpz, &res->mpz, &mpzone); mpz_add_inpl(&res->mpz, &res->mpz, &mpzone);
@ -127,8 +127,8 @@ mp_obj_t int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
mpz_t quo; mpz_init_zero(&quo); mpz_t quo; mpz_init_zero(&quo);
mpz_divmod_inpl(&quo, &res->mpz, zlhs, zrhs); mpz_divmod_inpl(&quo, &res->mpz, zlhs, zrhs);
mpz_deinit(&quo); mpz_deinit(&quo);
// Check signs and do Python style modulo // Check signs and do Python style modulo
if (zlhs->neg != zrhs->neg) { if (zlhs->neg != zrhs->neg) {
mpz_add_inpl(&res->mpz, &res->mpz, zrhs); mpz_add_inpl(&res->mpz, &res->mpz, zrhs);
} }
break; break;

View File

@ -84,7 +84,7 @@ PY_O_BASENAME = \
vm.o \ vm.o \
showbc.o \ showbc.o \
repl.o \ repl.o \
intdivmod.o \ smallint.o \
pfenv.o \ pfenv.o \
# prepend the build destination prefix to the py object files # prepend the build destination prefix to the py object files

View File

@ -16,7 +16,7 @@
#include "builtin.h" #include "builtin.h"
#include "builtintables.h" #include "builtintables.h"
#include "bc.h" #include "bc.h"
#include "intdivmod.h" #include "smallint.h"
#include "objgenerator.h" #include "objgenerator.h"
#if 0 // print debugging info #if 0 // print debugging info
@ -289,7 +289,7 @@ mp_obj_t mp_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
// If long long type exists and is larger than machine_int_t, then // If long long type exists and is larger than machine_int_t, then
// we can use the following code to perform overflow-checked multiplication. // we can use the following code to perform overflow-checked multiplication.
// Otherwise (eg in x64 case) we must use the branching code below. // Otherwise (eg in x64 case) we must use mp_small_int_mul_overflow.
#if 0 #if 0
// compute result using long long precision // compute result using long long precision
long long res = (long long)lhs_val * (long long)rhs_val; long long res = (long long)lhs_val * (long long)rhs_val;
@ -302,36 +302,14 @@ mp_obj_t mp_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
} }
#endif #endif
if (lhs_val > 0) { // lhs_val is positive if (mp_small_int_mul_overflow(lhs_val, rhs_val)) {
if (rhs_val > 0) { // lhs_val and rhs_val are positive // use higher precision
if (lhs_val > (MP_SMALL_INT_MAX / rhs_val)) { lhs = mp_obj_new_int_from_ll(lhs_val);
goto mul_overflow; goto generic_binary_op;
} } else {
} else { // lhs_val positive, rhs_val nonpositive // use standard precision
if (rhs_val < (MP_SMALL_INT_MIN / lhs_val)) { return MP_OBJ_NEW_SMALL_INT(lhs_val * rhs_val);
goto mul_overflow; }
}
} // lhs_val positive, rhs_val nonpositive
} else { // lhs_val is nonpositive
if (rhs_val > 0) { // lhs_val is nonpositive, rhs_val is positive
if (lhs_val < (MP_SMALL_INT_MIN / rhs_val)) {
goto mul_overflow;
}
} else { // lhs_val and rhs_val are nonpositive
if (lhs_val != 0 && rhs_val < (MP_SMALL_INT_MAX / lhs_val)) {
goto mul_overflow;
}
} // End if lhs_val and rhs_val are nonpositive
} // End if lhs_val is nonpositive
// use standard precision
return MP_OBJ_NEW_SMALL_INT(lhs_val * rhs_val);
mul_overflow:
// use higher precision
lhs = mp_obj_new_int_from_ll(lhs_val);
goto generic_binary_op;
break; break;
} }
case MP_BINARY_OP_FLOOR_DIVIDE: case MP_BINARY_OP_FLOOR_DIVIDE:
@ -339,7 +317,7 @@ mp_obj_t mp_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
if (rhs_val == 0) { if (rhs_val == 0) {
goto zero_division; goto zero_division;
} }
lhs_val = python_floor_divide(lhs_val, rhs_val); lhs_val = mp_small_int_floor_divide(lhs_val, rhs_val);
break; break;
#if MICROPY_ENABLE_FLOAT #if MICROPY_ENABLE_FLOAT
@ -352,11 +330,11 @@ mp_obj_t mp_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
#endif #endif
case MP_BINARY_OP_MODULO: case MP_BINARY_OP_MODULO:
case MP_BINARY_OP_INPLACE_MODULO: case MP_BINARY_OP_INPLACE_MODULO: {
{ lhs_val = mp_small_int_modulo(lhs_val, rhs_val);
lhs_val = python_modulo(lhs_val, rhs_val);
break; break;
} }
case MP_BINARY_OP_POWER: case MP_BINARY_OP_POWER:
case MP_BINARY_OP_INPLACE_POWER: case MP_BINARY_OP_INPLACE_POWER:
if (rhs_val < 0) { if (rhs_val < 0) {
@ -370,21 +348,19 @@ mp_obj_t mp_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
machine_int_t ans = 1; machine_int_t ans = 1;
while (rhs_val > 0) { while (rhs_val > 0) {
if (rhs_val & 1) { if (rhs_val & 1) {
machine_int_t old = ans; if (mp_small_int_mul_overflow(ans, lhs_val)) {
ans *= lhs_val;
if (ans < old) {
goto power_overflow; goto power_overflow;
} }
ans *= lhs_val;
} }
if (rhs_val == 1) { if (rhs_val == 1) {
break; break;
} }
rhs_val /= 2; rhs_val /= 2;
machine_int_t old = lhs_val; if (mp_small_int_mul_overflow(lhs_val, lhs_val)) {
lhs_val *= lhs_val;
if (lhs_val < old) {
goto power_overflow; goto power_overflow;
} }
lhs_val *= lhs_val;
} }
lhs_val = ans; lhs_val = ans;
} }

53
py/smallint.c Normal file
View File

@ -0,0 +1,53 @@
#include "misc.h"
#include "mpconfig.h"
#include "qstr.h"
#include "obj.h"
bool mp_small_int_mul_overflow(machine_int_t x, machine_int_t y) {
// Check for multiply overflow; see CERT INT32-C
if (x > 0) { // x is positive
if (y > 0) { // x and y are positive
if (x > (MP_SMALL_INT_MAX / y)) {
return true;
}
} else { // x positive, y nonpositive
if (y < (MP_SMALL_INT_MIN / x)) {
return true;
}
} // x positive, y nonpositive
} else { // x is nonpositive
if (y > 0) { // x is nonpositive, y is positive
if (x < (MP_SMALL_INT_MIN / y)) {
return true;
}
} else { // x and y are nonpositive
if (x != 0 && y < (MP_SMALL_INT_MAX / x)) {
return true;
}
} // End if x and y are nonpositive
} // End if x is nonpositive
return false;
}
machine_int_t mp_small_int_modulo(machine_int_t dividend, machine_int_t divisor) {
machine_int_t lsign = (dividend >= 0) ? 1 :-1;
machine_int_t rsign = (divisor >= 0) ? 1 :-1;
dividend %= divisor;
if (lsign != rsign) {
dividend += divisor;
}
return dividend;
}
machine_int_t mp_small_int_floor_divide(machine_int_t num, machine_int_t denom) {
machine_int_t lsign = num > 0 ? 1 : -1;
machine_int_t rsign = denom > 0 ? 1 : -1;
if (lsign == -1) {num *= -1;}
if (rsign == -1) {denom *= -1;}
if (lsign != rsign){
return - ( num + denom - 1) / denom;
} else {
return num / denom;
}
}

5
py/smallint.h Normal file
View File

@ -0,0 +1,5 @@
// Functions for small integer arithmetic
bool mp_small_int_mul_overflow(machine_int_t x, machine_int_t y);
machine_int_t mp_small_int_modulo(machine_int_t dividend, machine_int_t divisor);
machine_int_t mp_small_int_floor_divide(machine_int_t num, machine_int_t denom);