py/mpz: Fix overflow of borrow in mpn_div.

For certain operands to mpn_div, the existing code path for
`DIG_SIZE == MPZ_DBL_DIG_SIZE / 2` had a bug in it where borrow could still
overflow in the `(x >= *n || *n - x <= borrow)` branch, ie
`borrow + x - (mpz_dbl_dig_t)*n` overflows the borrow variable.  In such
cases the subsequent right-shift of borrow would not bring in the overflow
bit, leading to an error in the result.  An example division that had
overflow when MPZ_DIG_SIZE = 16 is `(2 ** 48 - 1) ** 2 // (2 ** 48 - 1)`.

This is fixed in this commit by simplifying the code and handling the low
digits of borrow first, and then the upper bits (to shift down) separately.
There is no longer a distinction between `DIG_SIZE < MPZ_DBL_DIG_SIZE / 2`
and `DIG_SIZE == MPZ_DBL_DIG_SIZE / 2`.

This commit also simplifies the second part of the calculation so that
borrow does not need to be negated (instead the code just works knowing
that borrow is negative and using + instead of - in calculations involving
borrow).

Fixes #6777.

Signed-off-by: Damien George <damien@micropython.org>
This commit is contained in:
Damien George 2021-02-05 00:26:08 +11:00
parent 9dedcf122d
commit 0a59938574
2 changed files with 20 additions and 39 deletions

View File

@ -531,60 +531,37 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
quo /= lead_den_digit; quo /= lead_den_digit;
// Multiply quo by den and subtract from num to get remainder. // Multiply quo by den and subtract from num to get remainder.
// We have different code here to handle different compile-time // Must be careful with overflow of the borrow variable. Both
// configurations of mpz: // borrow and low_digs are signed values and need signed right-shift,
// // but x is unsigned and may take a full-range value.
// 1. DIG_SIZE is stricly less than half the number of bits
// available in mpz_dbl_dig_t. In this case we can use a
// slightly more optimal (in time and space) routine that
// uses the extra bits in mpz_dbl_dig_signed_t to store a
// sign bit.
//
// 2. DIG_SIZE is exactly half the number of bits available in
// mpz_dbl_dig_t. In this (common) case we need to be careful
// not to overflow the borrow variable. And the shifting of
// borrow needs some special logic (it's a shift right with
// round up).
//
const mpz_dig_t *d = den_dig; const mpz_dig_t *d = den_dig;
mpz_dbl_dig_t d_norm = 0; mpz_dbl_dig_t d_norm = 0;
mpz_dbl_dig_t borrow = 0; mpz_dbl_dig_signed_t borrow = 0;
for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) { for (mpz_dig_t *n = num_dig - den_len; n < num_dig; ++n, ++d) {
// Get the next digit in (den).
d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE); d_norm = ((mpz_dbl_dig_t)*d << norm_shift) | (d_norm >> DIG_SIZE);
// Multiply the next digit in (quo * den).
mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK); mpz_dbl_dig_t x = (mpz_dbl_dig_t)quo * (d_norm & DIG_MASK);
#if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2 // Compute the low DIG_MASK bits of the next digit in (num - quo * den)
borrow += (mpz_dbl_dig_t)*n - x; // will overflow if DIG_SIZE >= MPZ_DBL_DIG_SIZE/2 mpz_dbl_dig_signed_t low_digs = (borrow & DIG_MASK) + *n - (x & DIG_MASK);
*n = borrow & DIG_MASK; // Store the digit result for (num).
borrow = (mpz_dbl_dig_signed_t)borrow >> DIG_SIZE; *n = low_digs & DIG_MASK;
#else // DIG_SIZE == MPZ_DBL_DIG_SIZE / 2 // Compute the borrow, shifted right before summing to avoid overflow.
if (x >= *n || *n - x <= borrow) { borrow = (borrow >> DIG_SIZE) - (x >> DIG_SIZE) + (low_digs >> DIG_SIZE);
borrow += x - (mpz_dbl_dig_t)*n;
*n = (-borrow) & DIG_MASK;
borrow = (borrow >> DIG_SIZE) + ((borrow & DIG_MASK) == 0 ? 0 : 1); // shift-right with round-up
} else {
*n = ((mpz_dbl_dig_t)*n - x - borrow) & DIG_MASK;
borrow = 0;
} }
#endif
}
#if DIG_SIZE < MPZ_DBL_DIG_SIZE / 2
// Borrow was negative in the above for-loop, make it positive for next if-block.
borrow = -borrow;
#endif
// At this point we have either: // At this point we have either:
// //
// 1. quo was the correct value and the most-sig-digit of num is exactly // 1. quo was the correct value and the most-sig-digit of num is exactly
// cancelled by borrow (borrow == *num_dig). In this case there is // cancelled by borrow (borrow + *num_dig == 0). In this case there is
// nothing more to do. // nothing more to do.
// //
// 2. quo was too large, we subtracted too many den from num, and the // 2. quo was too large, we subtracted too many den from num, and the
// most-sig-digit of num is 1 less than borrow (borrow == *num_dig + 1). // most-sig-digit of num is less than needed (borrow + *num_dig < 0).
// In this case we must reduce quo and add back den to num until the // In this case we must reduce quo and add back den to num until the
// carry from this operation cancels out the borrow. // carry from this operation cancels out the borrow.
// //
borrow -= *num_dig; borrow += *num_dig;
for (; borrow != 0; --quo) { for (; borrow != 0; --quo) {
d = den_dig; d = den_dig;
d_norm = 0; d_norm = 0;
@ -595,7 +572,7 @@ STATIC void mpn_div(mpz_dig_t *num_dig, size_t *num_len, const mpz_dig_t *den_di
*n = carry & DIG_MASK; *n = carry & DIG_MASK;
carry >>= DIG_SIZE; carry >>= DIG_SIZE;
} }
borrow -= carry; borrow += carry;
} }
// store this digit of the quotient // store this digit of the quotient

View File

@ -8,3 +8,7 @@ x = 0x8000000000000000
print((x + 1) // x) print((x + 1) // x)
x = 0x86c60128feff5330 x = 0x86c60128feff5330
print((x + 1) // x) print((x + 1) // x)
# these check edge cases where borrow overflows
print((2 ** 48 - 1) ** 2 // (2 ** 48 - 1))
print((2 ** 256 - 2 ** 32) ** 2 // (2 ** 256 - 2 ** 32))