macssh/gmp/mpn/generic/mul_n.c

1344 lines
28 KiB
C
Executable File

/* mpn_mul_n and helper function -- Multiply/square natural numbers.
THE HELPER FUNCTIONS IN THIS FILE (meaning everything except mpn_mul_n)
ARE INTERNAL FUNCTIONS WITH MUTABLE INTERFACES. IT IS ONLY SAFE TO REACH
THEM THROUGH DOCUMENTED INTERFACES. IN FACT, IT IS ALMOST GUARANTEED
THAT THEY'LL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.
Copyright (C) 1991, 1993, 1994, 1996, 1997, 1998, 1999, 2000 Free Software
Foundation, Inc.
This file is part of the GNU MP Library.
The GNU MP Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.
The GNU MP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
License for more details.
You should have received a copy of the GNU Lesser General Public License
along with the GNU MP Library; see the file COPYING.LIB. If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
MA 02111-1307, USA. */
#include "gmp.h"
#include "gmp-impl.h"
#include "longlong.h"
/* Multiplicative inverse of 3, modulo 2^BITS_PER_MP_LIMB.
0xAAAAAAAB for 32 bits, 0xAAAAAAAAAAAAAAAB for 64 bits.
Multiplying an exact multiple of 3 by this constant divides it by 3
modulo 2^BITS_PER_MP_LIMB (used by interpolate3). */
#define INVERSE_3 ((MP_LIMB_T_MAX / 3) * 2 + 1)
/* On Alpha and MIPS, USE_MORE_MPN stays undefined and the helpers below
(add2Times, evaluate3, interpolate3) use their open-coded limb loops. */
#if !defined (__alpha) && !defined (__mips)
/* For all other machines, we want to call mpn functions for the compound
operations instead of open-coding them. */
#define USE_MORE_MPN
#endif
/*== Function declarations =================================================*/
/* File-local helpers used by the Toom-3 routines; defined further down. */
static void evaluate3 _PROTO ((mp_ptr, mp_ptr, mp_ptr,
mp_ptr, mp_ptr, mp_ptr,
mp_srcptr, mp_srcptr, mp_srcptr,
mp_size_t, mp_size_t));
static void interpolate3 _PROTO ((mp_srcptr,
mp_ptr, mp_ptr, mp_ptr,
mp_srcptr,
mp_ptr, mp_ptr, mp_ptr,
mp_size_t, mp_size_t));
static mp_limb_t add2Times _PROTO ((mp_ptr, mp_srcptr, mp_srcptr, mp_size_t));
/*-- mpn_kara_mul_n ---------------------------------------------------------------*/
/* Multiplies using 3 half-sized mults and so on recursively.
* p[0..2*n-1] := product of a[0..n-1] and b[0..n-1].
* No overlap of p[...] with a[...] or b[...].
* ws is workspace.
*
* With B = 2^BITS_PER_MP_LIMB and a = a_hi*B^k + a_lo (likewise b), the
* product is computed as
*   a_lo*b_lo + (a_lo*b_lo + a_hi*b_hi - (a_lo-a_hi)*(b_lo-b_hi))*B^k
*   + a_hi*b_hi*B^(2k),
* where k = n3 = n2 + 1 for odd n and k = n2 for even n.
* Requires n >= 2 (asserted through n2 > 0).
* The difference product occupies ws[0..n] (odd n) or ws[0..n-1] (even n);
* recursive calls receive the workspace after that.
*/
void
#if __STDC__
mpn_kara_mul_n (mp_ptr p, mp_srcptr a, mp_srcptr b, mp_size_t n, mp_ptr ws)
#else
mpn_kara_mul_n(p, a, b, n, ws)
mp_ptr p;
mp_srcptr a;
mp_srcptr b;
mp_size_t n;
mp_ptr ws;
#endif
{
mp_limb_t i, sign, w, w0, w1;
mp_size_t n2;
mp_srcptr x, y;
n2 = n >> 1;
ASSERT (n2 > 0);
if (n & 1)
{
/* Odd length. */
/* Low halves a[0..n3-1], b[0..n3-1] have n3 = n2 + 1 limbs; high halves
   a[n3..n-1], b[n3..n-1] have n2 limbs. */
mp_size_t n1, n3, nm1;
n3 = n - n2;
/* p[0..n2] := |a_lo - a_hi|, sign := 1 when a_lo < a_hi.
   If a[n2] (the extra top limb of a_lo) is nonzero then a_lo > a_hi, so
   subtract directly and take the borrow from that top limb.  Otherwise
   scan down from the top for the first differing limb to find the larger
   operand. */
sign = 0;
w = a[n2];
if (w != 0)
w -= mpn_sub_n (p, a, a + n3, n2);
else
{
i = n2;
do
{
--i;
w0 = a[i];
w1 = a[n3+i];
}
while (w0 == w1 && i != 0);
if (w0 < w1)
{
x = a + n3;
y = a;
sign = 1;
}
else
{
x = a;
y = a + n3;
}
mpn_sub_n (p, x, y, n2);
}
p[n2] = w;
/* p[n3..n] := |b_lo - b_hi|, using the same method.  sign is flipped when
   b_lo < b_hi, so it ends up set exactly when the product of the two
   signed differences is negative. */
w = b[n2];
if (w != 0)
w -= mpn_sub_n (p + n3, b, b + n3, n2);
else
{
i = n2;
do
{
--i;
w0 = b[i];
w1 = b[n3+i];
}
while (w0 == w1 && i != 0);
if (w0 < w1)
{
x = b + n3;
y = b;
sign ^= 1;
}
else
{
x = b;
y = b + n3;
}
mpn_sub_n (p + n3, x, y, n2);
}
p[n] = w;
n1 = n + 1;
/* Pointwise products:
     ws[0..n]    := |a_lo - a_hi| * |b_lo - b_hi|   (n3 x n3 limbs)
     p[0..n]     := a_lo * b_lo                     (n3 x n3 limbs)
     p[n1..2n-1] := a_hi * b_hi                     (n2 x n2 limbs)
   The differences stored in p[] are read by the first product before the
   second product overwrites them. */
if (n2 < KARATSUBA_MUL_THRESHOLD)
{
if (n3 < KARATSUBA_MUL_THRESHOLD)
{
mpn_mul_basecase (ws, p, n3, p + n3, n3);
mpn_mul_basecase (p, a, n3, b, n3);
}
else
{
mpn_kara_mul_n (ws, p, p + n3, n3, ws + n1);
mpn_kara_mul_n (p, a, b, n3, ws + n1);
}
mpn_mul_basecase (p + n1, a + n3, n2, b + n3, n2);
}
else
{
mpn_kara_mul_n (ws, p, p + n3, n3, ws + n1);
mpn_kara_mul_n (p, a, b, n3, ws + n1);
mpn_kara_mul_n (p + n1, a + n3, b + n3, n2, ws + n1);
}
/* Interpolate.  ws := a_lo*b_lo -/+ |diff product|, adding when the signed
   difference product is negative.  The true middle coefficient
   a_lo*b_hi + a_hi*b_lo is nonnegative and below B^n1, so the borrow or
   carry out of this n1-limb step is deliberately discarded. */
if (sign)
mpn_add_n (ws, p, ws, n1);
else
mpn_sub_n (ws, p, ws, n1);
nm1 = n - 1;
/* ws += a_hi*b_hi (n - 1 limbs).  A carry goes into ws[nm1] and, if that
   wraps, into ws[n]. */
if (mpn_add_n (ws, p + n1, ws, nm1))
{
/* This x is a limb; it shadows the pointer x declared above. */
mp_limb_t x = ws[nm1] + 1;
ws[nm1] = x;
if (x == 0)
++ws[n];
}
/* Add the middle coefficient ws[0..n] at limb offset n3, then ripple any
   carry upward starting at p[n1 + n3]. */
if (mpn_add_n (p + n3, p + n3, ws, n1))
{
mp_limb_t x;
i = n1 + n3;
do
{
x = p[i] + 1;
p[i] = x;
++i;
} while (x == 0);
}
}
else
{
/* Even length. */
/* Both halves have n2 limbs.  p[0..n2-1] := |a_lo - a_hi| and
   p[n2..n-1] := |b_lo - b_hi|; sign ends up set exactly when the product
   of the two signed differences is negative. */
mp_limb_t t;
i = n2;
do
{
--i;
w0 = a[i];
w1 = a[n2+i];
}
while (w0 == w1 && i != 0);
sign = 0;
if (w0 < w1)
{
x = a + n2;
y = a;
sign = 1;
}
else
{
x = a;
y = a + n2;
}
mpn_sub_n (p, x, y, n2);
i = n2;
do
{
--i;
w0 = b[i];
w1 = b[n2+i];
}
while (w0 == w1 && i != 0);
if (w0 < w1)
{
x = b + n2;
y = b;
sign ^= 1;
}
else
{
x = b;
y = b + n2;
}
mpn_sub_n (p + n2, x, y, n2);
/* Pointwise products. */
/* ws[0..n-1] := difference product, p[0..n-1] := a_lo*b_lo,
   p[n..2n-1] := a_hi*b_hi. */
if (n2 < KARATSUBA_MUL_THRESHOLD)
{
mpn_mul_basecase (ws, p, n2, p + n2, n2);
mpn_mul_basecase (p, a, n2, b, n2);
mpn_mul_basecase (p + n, a + n2, n2, b + n2, n2);
}
else
{
mpn_kara_mul_n (ws, p, p + n2, n2, ws + n);
mpn_kara_mul_n (p, a, b, n2, ws + n);
mpn_kara_mul_n (p + n, a + n2, b + n2, n2, ws + n);
}
/* Interpolate. */
/* ws := a_lo*b_lo + a_hi*b_hi -/+ |diff product|, then added into p at
   offset n2.  w accumulates the carries (a borrow counts as -1 through
   unsigned wraparound) out of these n-limb operations; it is added at
   limb n + n2 below. */
if (sign)
w = mpn_add_n (ws, p, ws, n);
else
w = -mpn_sub_n (ws, p, ws, n);
w += mpn_add_n (ws, p + n, ws, n);
w += mpn_add_n (p + n2, p + n2, ws, n);
/* TO DO: could put "if (w) { ... }" here.
* Less work but badly predicted branch.
* No measurable difference in speed on Alpha.
*/
i = n + n2;
t = p[i] + w;
p[i] = t;
if (t < w)
{
do
{
++i;
w = p[i] + 1;
p[i] = w;
}
while (w == 0);
}
}
}
/*-- mpn_kara_sqr_n ---------------------------------------------------------*/
/* Squares using 3 half-sized squarings and so on recursively.
 * p[0..2*n-1] := square of a[0..n-1].
 * No overlap of p[...] with a[...].
 * ws is workspace: ws[0..n] (odd n) or ws[0..n-1] (even n) receives
 * (a_lo - a_hi)^2, and recursive calls get the space after that.
 * Requires n >= 2 (asserted through n2 > 0).
 *
 * With B = 2^BITS_PER_MP_LIMB and a = a_hi*B^k + a_lo, the middle
 * coefficient is a_lo^2 + a_hi^2 - (a_lo - a_hi)^2 = 2*a_lo*a_hi.  The
 * difference is squared, so its sign never matters.  Only one copy of
 * |a_lo - a_hi| is therefore formed, and it is always subtracted.
 */
void
#if __STDC__
mpn_kara_sqr_n (mp_ptr p, mp_srcptr a, mp_size_t n, mp_ptr ws)
#else
mpn_kara_sqr_n (p, a, n, ws)
     mp_ptr p;
     mp_srcptr a;
     mp_size_t n;
     mp_ptr ws;
#endif
{
  mp_limb_t i, w, w0, w1;
  mp_size_t n2;
  mp_srcptr x, y;

  n2 = n >> 1;
  ASSERT (n2 > 0);

  if (n & 1)
    {
      /* Odd length: the low half a[0..n3-1] has n3 = n2 + 1 limbs and the
         high half a[n3..n-1] has n2 limbs. */
      mp_size_t n1, n3, nm1;

      n3 = n - n2;

      /* p[0..n2] := |a_lo - a_hi|.  If a[n2] (the extra top limb of a_lo)
         is nonzero then a_lo > a_hi, so subtract directly and take the
         borrow from that limb.  Otherwise scan down from the top for the
         first differing limb to find the larger operand. */
      w = a[n2];
      if (w != 0)
        w -= mpn_sub_n (p, a, a + n3, n2);
      else
        {
          i = n2;
          do
            {
              --i;
              w0 = a[i];
              w1 = a[n3+i];
            }
          while (w0 == w1 && i != 0);
          if (w0 < w1)
            {
              x = a + n3;
              y = a;
            }
          else
            {
              x = a;
              y = a + n3;
            }
          mpn_sub_n (p, x, y, n2);
        }
      p[n2] = w;

      n1 = n + 1;
      /* Pointwise squares:
           ws[0..n]    := (a_lo - a_hi)^2   (from p[0..n2], read first)
           p[0..n]     := a_lo^2
           p[n1..2n-1] := a_hi^2 */
      if (n2 < KARATSUBA_SQR_THRESHOLD)
        {
          if (n3 < KARATSUBA_SQR_THRESHOLD)
            {
              mpn_sqr_basecase (ws, p, n3);
              mpn_sqr_basecase (p, a, n3);
            }
          else
            {
              mpn_kara_sqr_n (ws, p, n3, ws + n1);
              mpn_kara_sqr_n (p, a, n3, ws + n1);
            }
          mpn_sqr_basecase (p + n1, a + n3, n2);
        }
      else
        {
          mpn_kara_sqr_n (ws, p, n3, ws + n1);
          mpn_kara_sqr_n (p, a, n3, ws + n1);
          mpn_kara_sqr_n (p + n1, a + n3, n2, ws + n1);
        }

      /* Interpolate: ws := a_lo^2 - (a_lo - a_hi)^2 + a_hi^2 = 2*a_lo*a_hi.
         The result is nonnegative and below B^n1, so the intermediate
         borrow is deliberately discarded. */
      mpn_sub_n (ws, p, ws, n1);

      nm1 = n - 1;
      /* ws += a_hi^2 (n - 1 limbs), carrying into ws[nm1] and ws[n]. */
      if (mpn_add_n (ws, p + n1, ws, nm1))
        {
          mp_limb_t top = ws[nm1] + 1;
          ws[nm1] = top;
          if (top == 0)
            ++ws[n];
        }
      /* Add the middle coefficient at limb offset n3 and ripple any carry
         upward from p[n1 + n3]. */
      if (mpn_add_n (p + n3, p + n3, ws, n1))
        {
          mp_limb_t limb;
          i = n1 + n3;
          do
            {
              limb = p[i] + 1;
              p[i] = limb;
              ++i;
            } while (limb == 0);
        }
    }
  else
    {
      /* Even length: both halves have n2 limbs. */
      mp_limb_t t;

      /* p[0..n2-1] := |a_lo - a_hi|. */
      i = n2;
      do
        {
          --i;
          w0 = a[i];
          w1 = a[n2+i];
        }
      while (w0 == w1 && i != 0);
      if (w0 < w1)
        {
          x = a + n2;
          y = a;
        }
      else
        {
          x = a;
          y = a + n2;
        }
      mpn_sub_n (p, x, y, n2);

      /* Pointwise squares: ws[0..n-1] := (a_lo - a_hi)^2 (read from p
         first), p[0..n-1] := a_lo^2, p[n..2n-1] := a_hi^2. */
      if (n2 < KARATSUBA_SQR_THRESHOLD)
        {
          mpn_sqr_basecase (ws, p, n2);
          mpn_sqr_basecase (p, a, n2);
          mpn_sqr_basecase (p + n, a + n2, n2);
        }
      else
        {
          mpn_kara_sqr_n (ws, p, n2, ws + n);
          mpn_kara_sqr_n (p, a, n2, ws + n);
          mpn_kara_sqr_n (p + n, a + n2, n2, ws + n);
        }

      /* Interpolate: ws := a_lo^2 + a_hi^2 - (a_lo - a_hi)^2, then added
         into p at offset n2.  w accumulates the carries (a borrow counts
         as -1 via unsigned wraparound) and is added at limb n + n2. */
      w = -mpn_sub_n (ws, p, ws, n);
      w += mpn_add_n (ws, p + n, ws, n);
      w += mpn_add_n (p + n2, p + n2, ws, n);
      /* TO DO: could put "if (w) { ... }" here.
       * Less work but badly predicted branch.
       * No measurable difference in speed on Alpha.
       */
      i = n + n2;
      t = p[i] + w;
      p[i] = t;
      if (t < w)
        {
          do
            {
              ++i;
              w = p[i] + 1;
              p[i] = w;
            }
          while (w == 0);
        }
    }
}
/*-- add2Times -------------------------------------------------------------*/
/* Computes z[0..n-1] := x[0..n-1] + 2 * y[0..n-1] and returns the carry out
   of the most significant limb.
   z may be the same vector as x. */
#ifdef USE_MORE_MPN
static inline mp_limb_t
#if __STDC__
add2Times (mp_ptr z, mp_srcptr x, mp_srcptr y, mp_size_t n)
#else
add2Times (z, x, y, n)
     mp_ptr z;
     mp_srcptr x;
     mp_srcptr y;
     mp_size_t n;
#endif
{
  mp_ptr twice_y;
  mp_limb_t carry_out;
  TMP_DECL (marker);

  TMP_MARK (marker);

  /* Scratch copy of y shifted left one bit; the bit shifted out of the top
     limb is the first contribution to the carry. */
  twice_y = (mp_ptr) TMP_ALLOC (n * BYTES_PER_MP_LIMB);
  carry_out = mpn_lshift (twice_y, y, n, 1);
  carry_out += mpn_add_n (z, x, twice_y, n);

  TMP_FREE (marker);
  return carry_out;
}
#else
static mp_limb_t
#if __STDC__
add2Times (mp_ptr z, mp_srcptr x, mp_srcptr y, mp_size_t n)
#else
add2Times (z, x, y, n)
     mp_ptr z;
     mp_srcptr x;
     mp_srcptr y;
     mp_size_t n;
#endif
{
  mp_limb_t carry = 0;
  mp_size_t k = 0;

  ASSERT (n > 0);

  /* One limb per step: z[k] = x[k] + carry + 2*y[k].  The outgoing carry
     collects both addition overflows plus the top bit shifted out of
     y[k].  x[k] and y[k] are read before z[k] is written, so z == x is
     safe. */
  do
    {
      mp_limb_t sum = x[k] + carry;
      mp_limb_t doubled = y[k] << 1;

      carry = sum < carry;
      carry += y[k] >> (BITS_PER_MP_LIMB - 1);
      sum += doubled;
      carry += sum < doubled;
      z[k] = sum;
    }
  while (++k < n);

  return carry;
}
#endif
/*-- evaluate3 -------------------------------------------------------------*/
/* Evaluates:
* ph := 4*A+2*B+C
* p1 := A+B+C
* p2 := A+2*B+4*C
* where:
* ph[], p1[], p2[], A[] and B[] all have length len,
* C[] has length len2 with len-len2 = 0, 1 or 2.
* Returns top words (overflow) at pth, pt1 and pt2 respectively.
* The top words are below 7, 3 and 7 respectively (see the ASSERTs).
*/
#ifdef USE_MORE_MPN
static void
#if __STDC__
evaluate3 (mp_ptr ph, mp_ptr p1, mp_ptr p2, mp_ptr pth, mp_ptr pt1, mp_ptr pt2,
mp_srcptr A, mp_srcptr B, mp_srcptr C, mp_size_t len, mp_size_t len2)
#else
evaluate3 (ph, p1, p2, pth, pt1, pt2,
A, B, C, len, len2)
mp_ptr ph;
mp_ptr p1;
mp_ptr p2;
mp_ptr pth;
mp_ptr pt1;
mp_ptr pt2;
mp_srcptr A;
mp_srcptr B;
mp_srcptr C;
mp_size_t len;
mp_size_t len2;
#endif
{
mp_limb_t c, d, e;
ASSERT (len - len2 <= 2);
/* p1 temporarily holds 2*B (e is the bit shifted out of it).  Build
   ph := 4*A + 2*B, then add C over its len2 limbs and ripple that carry
   through the remaining len-len2 limbs. */
e = mpn_lshift (p1, B, len, 1);
c = mpn_lshift (ph, A, len, 2);
c += e + mpn_add_n (ph, ph, p1, len);
d = mpn_add_n (ph, ph, C, len2);
if (len2 == len) c += d; else c += mpn_add_1 (ph + len2, ph + len2, len-len2, d);
ASSERT (c < 7);
*pth = c;
/* p2 := 4*C, zero-extended to len limbs, then + 2*B (still in p1) + A. */
c = mpn_lshift (p2, C, len2, 2);
#if 1
/* When C is shorter, store 4*C's shifted-out bits at p2[len2] and clear
   the top limb (p2[len-1] and p2[len2] coincide when len-len2 == 1). */
if (len2 != len) { p2[len-1] = 0; p2[len2] = c; c = 0; }
c += e + mpn_add_n (p2, p2, p1, len);
#else
d = mpn_add_n (p2, p2, p1, len2);
c += d;
if (len2 != len) c = mpn_add_1 (p2+len2, p1+len2, len-len2, c);
c += e;
#endif
c += mpn_add_n (p2, p2, A, len);
ASSERT (c < 7);
*pt2 = c;
/* Finally p1 := A+B+C, overwriting the 2*B scratch value. */
c = mpn_add_n (p1, A, B, len);
d = mpn_add_n (p1, p1, C, len2);
if (len2 == len) c += d;
else c += mpn_add_1 (p1+len2, p1+len2, len-len2, d);
ASSERT (c < 3);
*pt1 = c;
}
#else
/* Limb-by-limb version: th, t1 and t2 carry each sum's overflow from one
   limb position into the next.  C is treated as zero beyond its ls limbs.
   NOTE(review): the C pointer is still advanced past its ls limbs,
   although it is not dereferenced there. */
static void
#if __STDC__
evaluate3 (mp_ptr ph, mp_ptr p1, mp_ptr p2, mp_ptr pth, mp_ptr pt1, mp_ptr pt2,
mp_srcptr A, mp_srcptr B, mp_srcptr C, mp_size_t l, mp_size_t ls)
#else
evaluate3 (ph, p1, p2, pth, pt1, pt2,
A, B, C, l, ls)
mp_ptr ph;
mp_ptr p1;
mp_ptr p2;
mp_ptr pth;
mp_ptr pt1;
mp_ptr pt2;
mp_srcptr A;
mp_srcptr B;
mp_srcptr C;
mp_size_t l;
mp_size_t ls;
#endif
{
mp_limb_t a,b,c, i, t, th,t1,t2, vh,v1,v2;
ASSERT (l - ls <= 2);
th = t1 = t2 = 0;
for (i = 0; i < l; ++i)
{
a = *A;
b = *B;
c = i < ls ? *C : 0;
/* vh := 4*a + 2*b + c + th.  Each shifted term adds its shifted-out high
   bits and each addition adds its overflow to the new th. */
/* TO DO: choose one of the following alternatives. */
#if 0
t = a << 2;
vh = th + t;
th = vh < t;
th += a >> (BITS_PER_MP_LIMB - 2);
t = b << 1;
vh += t;
th += vh < t;
th += b >> (BITS_PER_MP_LIMB - 1);
vh += c;
th += vh < c;
#else
vh = th + c;
th = vh < c;
t = b << 1;
vh += t;
th += vh < t;
th += b >> (BITS_PER_MP_LIMB - 1);
t = a << 2;
vh += t;
th += vh < t;
th += a >> (BITS_PER_MP_LIMB - 2);
#endif
/* v1 := a + b + c + t1. */
v1 = t1 + a;
t1 = v1 < a;
v1 += b;
t1 += v1 < b;
v1 += c;
t1 += v1 < c;
/* v2 := a + 2*b + 4*c + t2. */
v2 = t2 + a;
t2 = v2 < a;
t = b << 1;
v2 += t;
t2 += v2 < t;
t2 += b >> (BITS_PER_MP_LIMB - 1);
t = c << 2;
v2 += t;
t2 += v2 < t;
t2 += c >> (BITS_PER_MP_LIMB - 2);
*ph = vh;
*p1 = v1;
*p2 = v2;
++A; ++B; ++C;
++ph; ++p1; ++p2;
}
ASSERT (th < 7);
ASSERT (t1 < 3);
ASSERT (t2 < 7);
*pth = th;
*pt1 = t1;
*pt2 = t2;
}
#endif
/*-- interpolate3 ----------------------------------------------------------*/
/* Interpolates B, C, D (in-place) from:
* 16*A+8*B+4*C+2*D+E
* A+B+C+D+E
* A+2*B+4*C+8*D+16*E
* where:
* A[], B[], C[] and D[] all have length l,
* E[] has length ls with l-ls = 0, 2 or 4.
*
* Reads top words (from earlier overflow) from ptb, ptc and ptd,
* and returns new top words there.
*
* On entry B, C and D hold the three combinations above.  On return they
* hold the unknown coefficients x1, x2, x3 (see the comment inside), with
* their top words in *ptb, *ptc and *ptd.
*/
#ifdef USE_MORE_MPN
/* Vector version: whole-vector mpn calls plus a len-limb temporary ws. */
static void
#if __STDC__
interpolate3 (mp_srcptr A, mp_ptr B, mp_ptr C, mp_ptr D, mp_srcptr E,
mp_ptr ptb, mp_ptr ptc, mp_ptr ptd, mp_size_t len, mp_size_t len2)
#else
interpolate3 (A, B, C, D, E,
ptb, ptc, ptd, len, len2)
mp_srcptr A;
mp_ptr B;
mp_ptr C;
mp_ptr D;
mp_srcptr E;
mp_ptr ptb;
mp_ptr ptc;
mp_ptr ptd;
mp_size_t len;
mp_size_t len2;
#endif
{
mp_ptr ws;
mp_limb_t t, tb,tc,td;
TMP_DECL (marker);
TMP_MARK (marker);
ASSERT (len - len2 == 0 || len - len2 == 2 || len - len2 == 4);
/* Let x1, x2, x3 be the values to interpolate. We have:
* b = 16*a + 8*x1 + 4*x2 + 2*x3 + e
* c = a + x1 + x2 + x3 + e
* d = a + 2*x1 + 4*x2 + 8*x3 + 16*e
*/
ws = (mp_ptr) TMP_ALLOC (len * BYTES_PER_MP_LIMB);
/* tb, tc, td are the top words of B, C, D.  They are updated alongside
   each vector operation and may go temporarily "negative" (mod 2^BITS). */
tb = *ptb; tc = *ptc; td = *ptd;
/* b := b - 16*a - e
* c := c - a - e
* d := d - a - 16*e
*/
t = mpn_lshift (ws, A, len, 4);
tb -= t + mpn_sub_n (B, B, ws, len);
t = mpn_sub_n (B, B, E, len2);
if (len2 == len) tb -= t;
else tb -= mpn_sub_1 (B+len2, B+len2, len-len2, t);
tc -= mpn_sub_n (C, C, A, len);
t = mpn_sub_n (C, C, E, len2);
if (len2 == len) tc -= t;
else tc -= mpn_sub_1 (C+len2, C+len2, len-len2, t);
/* ws := 16*e + a over len limbs (t is its top word), then d -= ws. */
t = mpn_lshift (ws, E, len2, 4);
t += mpn_add_n (ws, ws, A, len2);
#if 1
if (len2 != len) t = mpn_add_1 (ws+len2, A+len2, len-len2, t);
td -= t + mpn_sub_n (D, D, ws, len);
#else
t += mpn_sub_n (D, D, ws, len2);
if (len2 != len) {
t = mpn_sub_1 (D+len2, D+len2, len-len2, t);
t += mpn_sub_n (D+len2, D+len2, A+len2, len-len2);
} /* end if/else */
td -= t;
#endif
/* b, d := b + d, b - d */
#ifdef HAVE_MPN_ADD_SUB_N
/* #error TO DO ... */
#else
t = tb + td + mpn_add_n (ws, B, D, len);
td = tb - td - mpn_sub_n (D, B, D, len);
tb = t;
MPN_COPY (B, ws, len);
#endif
/* b := b-8*c */
t = 8 * tc + mpn_lshift (ws, C, len, 3);
tb -= t + mpn_sub_n (B, B, ws, len);
/* c := 2*c - b */
tc = 2 * tc + mpn_lshift (C, C, len, 1);
tc -= tb + mpn_sub_n (C, C, B, len);
/* d := d/3 */
/* NOTE(review): this assumes mpn_divexact_by3's return value is the
   correction that, subtracted from td and multiplied by INVERSE_3, gives
   the quotient's top word; confirm against the mpn_divexact_by3 docs. */
td = (td - mpn_divexact_by3 (D, D, len)) * INVERSE_3;
/* b, d := b + d, b - d */
#ifdef HAVE_MPN_ADD_SUB_N
/* #error TO DO ... */
#else
t = tb + td + mpn_add_n (ws, B, D, len);
td = tb - td - mpn_sub_n (D, B, D, len);
tb = t;
MPN_COPY (B, ws, len);
#endif
/* Now:
* b = 4*x1
* c = 2*x2
* d = 4*x3
*/
/* Divide by 4, 2 and 4 by shifting right.  The low bits of each top word
   move into the top limb of its vector. */
ASSERT(!(*B & 3));
mpn_rshift (B, B, len, 2);
B[len-1] |= tb<<(BITS_PER_MP_LIMB-2);
ASSERT((long)tb >= 0);
tb >>= 2;
ASSERT(!(*C & 1));
mpn_rshift (C, C, len, 1);
C[len-1] |= tc<<(BITS_PER_MP_LIMB-1);
ASSERT((long)tc >= 0);
tc >>= 1;
ASSERT(!(*D & 3));
mpn_rshift (D, D, len, 2);
D[len-1] |= td<<(BITS_PER_MP_LIMB-2);
ASSERT((long)td >= 0);
td >>= 2;
#if WANT_ASSERT
ASSERT (tb < 2);
if (len == len2)
{
ASSERT (tc < 3);
ASSERT (td < 2);
}
else
{
ASSERT (tc < 2);
ASSERT (!td);
}
#endif
*ptb = tb;
*ptc = tc;
*ptd = td;
TMP_FREE (marker);
}
#else
/* Limb-by-limb version: the same steps as above are applied to one limb
   position per iteration.  tb/tc/td are the per-limb overflows, and sb/sc/sd
   carry signed corrections into later positions.  The final shifts
   (>>2, >>1, >>2) are merged with the stores: ob/oc/od hold the previous
   limb's shifted bits and are combined with the current limb's low bits
   when writing B[-1], C[-1], D[-1] (only once i != 0, which is why they
   are not initialized). */
static void
#if __STDC__
interpolate3 (mp_srcptr A, mp_ptr B, mp_ptr C, mp_ptr D, mp_srcptr E,
mp_ptr ptb, mp_ptr ptc, mp_ptr ptd, mp_size_t l, mp_size_t ls)
#else
interpolate3 (A, B, C, D, E,
ptb, ptc, ptd, l, ls)
mp_srcptr A;
mp_ptr B;
mp_ptr C;
mp_ptr D;
mp_srcptr E;
mp_ptr ptb;
mp_ptr ptc;
mp_ptr ptd;
mp_size_t l;
mp_size_t ls;
#endif
{
mp_limb_t a,b,c,d,e,t, i, sb,sc,sd, ob,oc,od;
const mp_limb_t maskOffHalf = (~(mp_limb_t) 0) << (BITS_PER_MP_LIMB >> 1);
#if WANT_ASSERT
t = l - ls;
ASSERT (t == 0 || t == 2 || t == 4);
#endif
sb = sc = sd = 0;
for (i = 0; i < l; ++i)
{
mp_limb_t tb, tc, td, tt;
a = *A;
b = *B;
c = *C;
d = *D;
e = i < ls ? *E : 0;
/* Let x1, x2, x3 be the values to interpolate. We have:
* b = 16*a + 8*x1 + 4*x2 + 2*x3 + e
* c = a + x1 + x2 + x3 + e
* d = a + 2*x1 + 4*x2 + 8*x3 + 16*e
*/
/* b := b - 16*a - e
* c := c - a - e
* d := d - a - 16*e
*/
t = a << 4;
tb = -(a >> (BITS_PER_MP_LIMB - 4)) - (b < t);
b -= t;
tb -= b < e;
b -= e;
tc = -(c < a);
c -= a;
tc -= c < e;
c -= e;
td = -(d < a);
d -= a;
t = e << 4;
td = td - (e >> (BITS_PER_MP_LIMB - 4)) - (d < t);
d -= t;
/* b, d := b + d, b - d */
t = b + d;
tt = tb + td + (t < b);
td = tb - td - (b < d);
d = b - d;
b = t;
tb = tt;
/* b := b-8*c */
t = c << 3;
tb = tb - (tc << 3) - (c >> (BITS_PER_MP_LIMB - 3)) - (b < t);
b -= t;
/* c := 2*c - b */
t = c << 1;
tc = (tc << 1) + (c >> (BITS_PER_MP_LIMB - 1)) - tb - (t < b);
c = t - b;
/* d := d/3 */
/* NOTE(review): limb-local exact division via INVERSE_3; the carry
   correction applied to td is not derived here -- verify before
   changing. */
d *= INVERSE_3;
td = td - (d >> (BITS_PER_MP_LIMB - 1)) - (d*3 < d);
td *= INVERSE_3;
/* b, d := b + d, b - d */
t = b + d;
tt = tb + td + (t < b);
td = tb - td - (b < d);
d = b - d;
b = t;
tb = tt;
/* Now:
* b = 4*x1
* c = 2*x2
* d = 4*x3
*/
/* sb has period 2. */
b += sb;
tb += b < sb;
sb &= maskOffHalf;
sb |= sb >> (BITS_PER_MP_LIMB >> 1);
sb += tb;
/* sc has period 1. */
c += sc;
tc += c < sc;
/* TO DO: choose one of the following alternatives. */
#if 1
sc = (mp_limb_t)((long)sc >> (BITS_PER_MP_LIMB - 1));
sc += tc;
#else
sc = tc - ((long)sc < 0L);
#endif
/* sd has period 2. */
d += sd;
td += d < sd;
sd &= maskOffHalf;
sd |= sd >> (BITS_PER_MP_LIMB >> 1);
sd += td;
if (i != 0)
{
B[-1] = ob | b << (BITS_PER_MP_LIMB - 2);
C[-1] = oc | c << (BITS_PER_MP_LIMB - 1);
D[-1] = od | d << (BITS_PER_MP_LIMB - 2);
}
ob = b >> 2;
oc = c >> 1;
od = d >> 2;
++A; ++B; ++C; ++D; ++E;
}
/* Handle top words. */
/* Apply the same transformation to the incoming top words, add the
   pending corrections, store the last shifted limbs and shift the top
   words themselves. */
b = *ptb;
c = *ptc;
d = *ptd;
t = b + d;
d = b - d;
b = t;
b -= c << 3;
c = (c << 1) - b;
d *= INVERSE_3;
t = b + d;
d = b - d;
b = t;
b += sb;
c += sc;
d += sd;
B[-1] = ob | b << (BITS_PER_MP_LIMB - 2);
C[-1] = oc | c << (BITS_PER_MP_LIMB - 1);
D[-1] = od | d << (BITS_PER_MP_LIMB - 2);
b >>= 2;
c >>= 1;
d >>= 2;
#if WANT_ASSERT
ASSERT (b < 2);
if (l == ls)
{
ASSERT (c < 3);
ASSERT (d < 2);
}
else
{
ASSERT (c < 2);
ASSERT (!d);
}
#endif
*ptb = b;
*ptc = c;
*ptd = d;
}
#endif
/*-- mpn_toom3_mul_n --------------------------------------------------------------*/
/* Multiplies using 5 mults of one third size and so on recursively.
 * p[0..2*n-1] := product of a[0..n-1] and b[0..n-1].
 * No overlap of p[...] with a[...] or b[...].
 * ws is workspace.
 */
/* TO DO: If TOOM3_MUL_THRESHOLD is much bigger than KARATSUBA_MUL_THRESHOLD then the
 * recursion in mpn_toom3_mul_n() will always bottom out with mpn_kara_mul_n()
 * because the "n < KARATSUBA_MUL_THRESHOLD" test here will always be false.
 */
/* TOOM3_MUL_REC: p[0..2n-1] := a[0..n-1] * b[0..n-1], dispatching on n to
 * the schoolbook, Karatsuba or Toom-3 routine.  Every argument is
 * parenthesized, and n is evaluated exactly once.
 */
#define TOOM3_MUL_REC(p, a, b, n, ws)                                   \
  do {                                                                  \
    mp_size_t toom3_mul_rec_n_ = (n);                                   \
    if (toom3_mul_rec_n_ < KARATSUBA_MUL_THRESHOLD)                     \
      mpn_mul_basecase ((p), (a), toom3_mul_rec_n_,                     \
                        (b), toom3_mul_rec_n_);                         \
    else if (toom3_mul_rec_n_ < TOOM3_MUL_THRESHOLD)                    \
      mpn_kara_mul_n ((p), (a), (b), toom3_mul_rec_n_, (ws));           \
    else                                                                \
      mpn_toom3_mul_n ((p), (a), (b), toom3_mul_rec_n_, (ws));          \
  } while (0)
/* mpn_toom3_mul_n: p[0..2*n-1] := a[0..n-1] * b[0..n-1] using five
 * pointwise products of about n/3 limbs (Toom-3).  p must not overlap a or
 * b.  ws is workspace; the pointwise products B and D occupy ws[0..4l-1]
 * (l = ceil(n/3)), and ws + 4l is passed down to the recursive calls.
 * Requires n >= TOOM3_MUL_THRESHOLD.
 */
void
#if __STDC__
mpn_toom3_mul_n (mp_ptr p, mp_srcptr a, mp_srcptr b, mp_size_t n, mp_ptr ws)
#else
mpn_toom3_mul_n (p, a, b, n, ws)
     mp_ptr p;
     mp_srcptr a;
     mp_srcptr b;
     mp_size_t n;
     mp_ptr ws;
#endif
{
  mp_limb_t cB,cC,cD, dB,dC,dD, tB,tC,tD;
  mp_limb_t *A,*B,*C,*D,*E, *W;
  mp_size_t l,l2,l3,l4,l5,ls;

  /* Break n words into chunks of size l, l and ls.
   * n = 3*k   => l = k,   ls = k
   * n = 3*k+1 => l = k+1, ls = k-1
   * n = 3*k+2 => l = k+1, ls = k
   *
   * Placement of the pointwise products (2*l limbs each, E has 2*ls):
   *   A = p[0..2l-1], C = p[2l..4l-1], E = p[4l..4l+2ls-1],
   *   B = ws[0..2l-1], D = ws[2l..4l-1], recursion workspace W = ws + 4l.
   */
  {
    mp_size_t m;

    ASSERT (n >= TOOM3_MUL_THRESHOLD);
    l = ls = n / 3;
    m = n - l * 3;
    if (m != 0)
      ++l;
    if (m == 1)
      --ls;

    l2 = l * 2;
    l3 = l * 3;
    l4 = l * 4;
    l5 = l * 5;
    A = p;
    B = ws;
    C = p + l2;
    D = ws + l2;
    E = p + l4;
    W = ws + l4;
  }

  /** First stage: evaluation at points 0, 1/2, 1, 2, oo. **/
  /* The evaluations of a go to A, B, C and those of b to A + l, B + l,
     C + l.  Their overflow words go to cB/cC/cD and dB/dC/dD. */
  evaluate3 (A, B, C, &cB, &cC, &cD, a, a + l, a + l2, l, ls);
  evaluate3 (A + l, B + l, C + l, &dB, &dC, &dD, b, b + l, b + l2, l, ls);

  /** Second stage: pointwise multiplies. **/
  /* Each (c*B^l + X) * (d*B^l + Y) is formed as X*Y, plus c*Y and d*X
     added at offset l, with c*d starting the top word.  The products are
     done in an order that only overwrites evaluations already consumed. */
  TOOM3_MUL_REC(D, C, C + l, l, W);
  tD = cD*dD;
  if (cD) tD += mpn_addmul_1 (D + l, C + l, l, cD);
  if (dD) tD += mpn_addmul_1 (D + l, C, l, dD);
  ASSERT (tD < 49);

  TOOM3_MUL_REC(C, B, B + l, l, W);
  tC = cC*dC;
  /* TO DO: choose one of the following alternatives. */
#if 0
  if (cC) tC += mpn_addmul_1 (C + l, B + l, l, cC);
  if (dC) tC += mpn_addmul_1 (C + l, B, l, dC);
#else
  if (cC)
    {
      if (cC == 1) tC += mpn_add_n (C + l, C + l, B + l, l);
      else tC += add2Times (C + l, C + l, B + l, l);
    }
  if (dC)
    {
      if (dC == 1) tC += mpn_add_n (C + l, C + l, B, l);
      else tC += add2Times (C + l, C + l, B, l);
    }
#endif
  ASSERT (tC < 9);

  TOOM3_MUL_REC(B, A, A + l, l, W);
  tB = cB*dB;
  if (cB) tB += mpn_addmul_1 (B + l, A + l, l, cB);
  if (dB) tB += mpn_addmul_1 (B + l, A, l, dB);
  ASSERT (tB < 49);

  TOOM3_MUL_REC(A, a, b, l, W);
  TOOM3_MUL_REC(E, a + l2, b + l2, ls, W);

  /** Third stage: interpolation. **/
  interpolate3 (A, B, C, D, E, &tB, &tC, &tD, l2, ls << 1);

  /** Final stage: add up the coefficients. **/
  /* C already sits at its final place p + 2l.  B and D are added at p + l
     and p + 3l, and the top words are propagated into the limbs above. */
  tB += mpn_add_n (p + l, p + l, B, l2);
  tD += mpn_add_n (p + l3, p + l3, D, l2);
  mpn_incr_u (p + l3, tB);
  mpn_incr_u (p + l4, tC);
  mpn_incr_u (p + l5, tD);
}
/*-- mpn_toom3_sqr_n --------------------------------------------------------------*/
/* Like previous function but for squaring */
/* TOOM3_SQR_REC: p[0..2n-1] := a[0..n-1]^2, dispatching on n to the
 * schoolbook, Karatsuba or Toom-3 squaring routine.  Every argument is
 * parenthesized, and n is evaluated exactly once.
 */
#define TOOM3_SQR_REC(p, a, n, ws)                                      \
  do {                                                                  \
    mp_size_t toom3_sqr_rec_n_ = (n);                                   \
    if (toom3_sqr_rec_n_ < KARATSUBA_SQR_THRESHOLD)                     \
      mpn_sqr_basecase ((p), (a), toom3_sqr_rec_n_);                    \
    else if (toom3_sqr_rec_n_ < TOOM3_SQR_THRESHOLD)                    \
      mpn_kara_sqr_n ((p), (a), toom3_sqr_rec_n_, (ws));                \
    else                                                                \
      mpn_toom3_sqr_n ((p), (a), toom3_sqr_rec_n_, (ws));               \
  } while (0)
/* mpn_toom3_sqr_n: p[0..2*n-1] := a[0..n-1]^2 using five pointwise squares
 * of about n/3 limbs (Toom-3).  p must not overlap a.  ws is workspace, laid
 * out as in mpn_toom3_mul_n.  Requires n >= TOOM3_SQR_THRESHOLD, the
 * threshold that TOOM3_SQR_REC uses to select this routine.
 */
void
#if __STDC__
mpn_toom3_sqr_n (mp_ptr p, mp_srcptr a, mp_size_t n, mp_ptr ws)
#else
mpn_toom3_sqr_n (p, a, n, ws)
     mp_ptr p;
     mp_srcptr a;
     mp_size_t n;
     mp_ptr ws;
#endif
{
  mp_limb_t cB,cC,cD, tB,tC,tD;
  mp_limb_t *A,*B,*C,*D,*E, *W;
  mp_size_t l,l2,l3,l4,l5,ls;

  /* Break n words into chunks of size l, l and ls.
   * n = 3*k   => l = k,   ls = k
   * n = 3*k+1 => l = k+1, ls = k-1
   * n = 3*k+2 => l = k+1, ls = k
   *
   * Placement of the pointwise squares (2*l limbs each, E has 2*ls):
   *   A = p[0..2l-1], C = p[2l..4l-1], E = p[4l..4l+2ls-1],
   *   B = ws[0..2l-1], D = ws[2l..4l-1], recursion workspace W = ws + 4l.
   */
  {
    mp_size_t m;

    /* The squaring dispatch uses TOOM3_SQR_THRESHOLD, so checking against
       the multiplication threshold could fire spuriously. */
    ASSERT (n >= TOOM3_SQR_THRESHOLD);
    l = ls = n / 3;
    m = n - l * 3;
    if (m != 0)
      ++l;
    if (m == 1)
      --ls;

    l2 = l * 2;
    l3 = l * 3;
    l4 = l * 4;
    l5 = l * 5;
    A = p;
    B = ws;
    C = p + l2;
    D = ws + l2;
    E = p + l4;
    W = ws + l4;
  }

  /** First stage: evaluation at points 0, 1/2, 1, 2, oo. **/
  evaluate3 (A, B, C, &cB, &cC, &cD, a, a + l, a + l2, l, ls);

  /** Second stage: pointwise multiplies. **/
  /* Each (c*B^l + X)^2 is formed as X^2, plus 2*c*X added at offset l,
     with c*c starting the top word. */
  TOOM3_SQR_REC(D, C, l, W);
  tD = cD*cD;
  if (cD) tD += mpn_addmul_1 (D + l, C, l, 2*cD);
  ASSERT (tD < 49);

  TOOM3_SQR_REC(C, B, l, W);
  tC = cC*cC;
  /* TO DO: choose one of the following alternatives. */
#if 0
  if (cC) tC += mpn_addmul_1 (C + l, B, l, 2*cC);
#else
  /* cC is at most 2 here: add 2*X once per unit of cC. */
  if (cC >= 1)
    {
      tC += add2Times (C + l, C + l, B, l);
      if (cC == 2)
        tC += add2Times (C + l, C + l, B, l);
    }
#endif
  ASSERT (tC < 9);

  TOOM3_SQR_REC(B, A, l, W);
  tB = cB*cB;
  if (cB) tB += mpn_addmul_1 (B + l, A, l, 2*cB);
  ASSERT (tB < 49);

  TOOM3_SQR_REC(A, a, l, W);
  TOOM3_SQR_REC(E, a + l2, ls, W);

  /** Third stage: interpolation. **/
  interpolate3 (A, B, C, D, E, &tB, &tC, &tD, l2, ls << 1);

  /** Final stage: add up the coefficients. **/
  /* C already sits at p + 2l.  Add B at p + l and D at p + 3l, then
     propagate the top words. */
  tB += mpn_add_n (p + l, p + l, B, l2);
  tD += mpn_add_n (p + l3, p + l3, D, l2);
  mpn_incr_u (p + l3, tB);
  mpn_incr_u (p + l4, tC);
  mpn_incr_u (p + l5, tD);
}
/*-- mpn_mul_n -------------------------------------------------------------*/
/* p[0..2*n-1] := product of a[0..n-1] and b[0..n-1].
 * No overlap of p[...] with a[...] or b[...].
 * Chooses schoolbook, Karatsuba, Toom-3 or (when configured) FFT
 * multiplication according to the tuned thresholds.
 */
void
#if __STDC__
mpn_mul_n (mp_ptr p, mp_srcptr a, mp_srcptr b, mp_size_t n)
#else
mpn_mul_n (p, a, b, n)
     mp_ptr p;
     mp_srcptr a;
     mp_srcptr b;
     mp_size_t n;
#endif
{
  if (n < KARATSUBA_MUL_THRESHOLD)
    mpn_mul_basecase (p, a, n, b, n);
  else if (n < TOOM3_MUL_THRESHOLD)
    {
      /* Allocate workspace of fixed size on stack: fast! */
#if TUNE_PROGRAM_BUILD
      mp_limb_t ws[2 * (TOOM3_MUL_THRESHOLD_LIMIT-1) + 2 * BITS_PER_MP_LIMB];
#else
      mp_limb_t ws[2 * (TOOM3_MUL_THRESHOLD-1) + 2 * BITS_PER_MP_LIMB];
#endif
      mpn_kara_mul_n (p, a, b, n, ws);
    }
#if WANT_FFT || TUNE_PROGRAM_BUILD
  else if (n < FFT_MUL_THRESHOLD)
#else
  else
#endif
    {
      /* Use workspace of unknown size in heap, as stack space may
       * be limited. Since n is at least TOOM3_MUL_THRESHOLD, the
       * multiplication will take much longer than malloc()/free(). */
      /* The length is a limb count, so it uses the size type rather than
         mp_limb_t.  The byte count is computed once, for both the
         allocation and the free. */
      mp_size_t ws_len;
      size_t ws_bytes;
      mp_ptr ws;

      ws_len = 2 * n + 3 * BITS_PER_MP_LIMB;
      ws_bytes = (size_t) ws_len * sizeof (mp_limb_t);
      ws = (mp_ptr) (*_mp_allocate_func) (ws_bytes);
      mpn_toom3_mul_n (p, a, b, n, ws);
      (*_mp_free_func) (ws, ws_bytes);
    }
#if WANT_FFT || TUNE_PROGRAM_BUILD
  else
    {
      mpn_mul_fft_full (p, a, n, b, n);
    }
#endif
}