diff --git a/py/mpz.c b/py/mpz.c index e503927d09..6477c3f8d8 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -909,6 +909,37 @@ mp_uint_t mpz_set_from_str(mpz_t *z, const char *str, mp_uint_t len, bool neg, m return cur - str; } +void mpz_set_from_bytes(mpz_t *z, bool big_endian, mp_uint_t len, const byte *buf) { + int delta = 1; + if (big_endian) { + buf += len - 1; + delta = -1; + } + + mpz_need_dig(z, (len * 8 + DIG_SIZE - 1) / DIG_SIZE); + + mpz_dig_t d = 0; + int num_bits = 0; + z->neg = 0; + z->len = 0; + while (len) { + while (len && num_bits < DIG_SIZE) { + d |= *buf << num_bits; + num_bits += 8; + buf += delta; + len--; + } + z->dig[z->len++] = d & DIG_MASK; + // Need this #if because it's C undefined behavior to do: uint32_t >> 32 + #if DIG_SIZE != 8 && DIG_SIZE != 16 && DIG_SIZE != 32 + d >>= DIG_SIZE; + #else + d = 0; + #endif + num_bits -= DIG_SIZE; + } +} + bool mpz_is_zero(const mpz_t *z) { return z->len == 0; } diff --git a/py/mpz.h b/py/mpz.h index 55ef3e15ff..a26cbea5cb 100644 --- a/py/mpz.h +++ b/py/mpz.h @@ -109,6 +109,7 @@ void mpz_set_from_ll(mpz_t *z, long long i, bool is_signed); void mpz_set_from_float(mpz_t *z, mp_float_t src); #endif mp_uint_t mpz_set_from_str(mpz_t *z, const char *str, mp_uint_t len, bool neg, mp_uint_t base); +void mpz_set_from_bytes(mpz_t *z, bool big_endian, mp_uint_t len, const byte *buf); bool mpz_is_zero(const mpz_t *z); int mpz_cmp(const mpz_t *lhs, const mpz_t *rhs);