From dc3faea0405dea803828f5a2be314734b8c166b6 Mon Sep 17 00:00:00 2001 From: Damien George Date: Sun, 8 May 2016 21:38:43 +0100 Subject: [PATCH] py/mpz: Fix bug with overflowing C-shift in division routine. When DIG_SIZE=32, a uint32_t is used to store limbs, and no normalisation is needed because the MSB is already set, then there will be left and right shifts (in C) by 32 of a 32-bit variable, leading to undefined behaviour. This patch fixes this bug. --- py/mpz.c | 8 ++++---- tests/basics/int_big_div.py | 7 +++++++ tests/basics/int_big_mod.py | 7 +++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/py/mpz.c b/py/mpz.c index 2c02699811..100d2832cc 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -491,7 +491,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, for (mpz_dig_t *den = den_dig, carry = 0; den < den_dig + den_len; ++den) { mpz_dig_t d = *den; *den = ((d << norm_shift) | carry) & DIG_MASK; - carry = d >> (DIG_SIZE - norm_shift); + carry = (mpz_dbl_dig_t)d >> (DIG_SIZE - norm_shift); } // now need to shift numerator by same amount as denominator @@ -501,7 +501,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, for (mpz_dig_t *num = num_dig, carry = 0; num < num_dig + *num_len; ++num) { mpz_dig_t n = *num; *num = ((n << norm_shift) | carry) & DIG_MASK; - carry = n >> (DIG_SIZE - norm_shift); + carry = (mpz_dbl_dig_t)n >> (DIG_SIZE - norm_shift); } // cache the leading digit of the denominator @@ -618,14 +618,14 @@ STATIC void mpn_div(mpz_dig_t *num_dig, mp_uint_t *num_len, mpz_dig_t *den_dig, for (mpz_dig_t *den = den_dig + den_len - 1, carry = 0; den >= den_dig; --den) { mpz_dig_t d = *den; *den = ((d >> norm_shift) | carry) & DIG_MASK; - carry = d << (DIG_SIZE - norm_shift); + carry = (mpz_dbl_dig_t)d << (DIG_SIZE - norm_shift); } // unnormalise numerator (remainder now) for (mpz_dig_t *num = orig_num_dig + *num_len - 1, carry = 0; num >= orig_num_dig; --num) { mpz_dig_t n = *num; *num = ((n >> norm_shift) | carry) & DIG_MASK; - carry = n << (DIG_SIZE - norm_shift); + carry = (mpz_dbl_dig_t)n << (DIG_SIZE - norm_shift); } // strip trailing zeros diff --git a/tests/basics/int_big_div.py b/tests/basics/int_big_div.py index 8dacf495db..642f051d41 100644 --- a/tests/basics/int_big_div.py +++ b/tests/basics/int_big_div.py @@ -1,3 +1,10 @@ for lhs in (1000000000000000000000000, 10000000000100000000000000, 10012003400000000000000007, 12349083434598210349871029923874109871234789): for rhs in range(1, 555): print(lhs // rhs) + +# these check an edge case on 64-bit machines where two mpz limbs +# are used and the most significant one has the MSB set +x = 0x8000000000000000 +print((x + 1) // x) +x = 0x86c60128feff5330 +print((x + 1) // x) diff --git a/tests/basics/int_big_mod.py b/tests/basics/int_big_mod.py index 77c0ffc468..f383553c18 100644 --- a/tests/basics/int_big_mod.py +++ b/tests/basics/int_big_mod.py @@ -8,3 +8,10 @@ for i in range(11): y = delta * (j)# - 5) # TODO reinstate negative number test when % is working with sign correctly if y != 0: print(x % y) + +# these check an edge case on 64-bit machines where two mpz limbs +# are used and the most significant one has the MSB set +x = 0x8000000000000000 +print((x + 1) % x) +x = 0x86c60128feff5330 +print((x + 1) % x)