From 3e56a2c97d8da44054d165d1365c7bb78824c529 Mon Sep 17 00:00:00 2001 From: Sohaib ul Hassan Date: Thu, 9 Jul 2020 18:51:51 +0000 Subject: [NSS] Implement constant-time GCD and modular inversion The implementation is based on the work by Bernstein and Yang (https://eprint.iacr.org/2019/266) "Fast constant-time gcd computation and modular inversion". It fixes the old mp_gcd and s_mp_invmod_odd_m functions. The patch also fixes mpl_significant_bits s_mp_div_2d and s_mp_mul_2d by having less control flow to reduce side-channel leaks. Co-authored by : Billy Bob Brumley --- security/nss/lib/freebl/mpi/mpi.c | 378 +++++++++++++++++++++++----------- security/nss/lib/freebl/mpi/mpi.h | 1 + security/nss/lib/freebl/mpi/mplogic.c | 45 ++-- 3 files changed, 292 insertions(+), 132 deletions(-) (limited to 'security') diff --git a/security/nss/lib/freebl/mpi/mpi.c b/security/nss/lib/freebl/mpi/mpi.c index 7e96e51ff..1b7b171e7 100644 --- a/security/nss/lib/freebl/mpi/mpi.c +++ b/security/nss/lib/freebl/mpi/mpi.c @@ -8,6 +8,7 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ #include "mpi-priv.h" +#include "mplogic.h" #if defined(OSF1) #include #endif @@ -1688,98 +1689,112 @@ mp_iseven(const mp_int *a) /* {{{ mp_gcd(a, b, c) */ /* - Like the old mp_gcd() function, except computes the GCD using the - binary algorithm due to Josef Stein in 1961 (via Knuth). + Computes the GCD using the constant-time algorithm + by Bernstein and Yang (https://eprint.iacr.org/2019/266) + "Fast constant-time gcd computation and modular inversion" */ mp_err mp_gcd(mp_int *a, mp_int *b, mp_int *c) { mp_err res; - mp_int u, v, t; - mp_size k = 0; + mp_digit cond = 0, mask = 0; + mp_int g, temp, f; + int i, j, m, bit = 1, delta = 1, shifts = 0, last = -1; + mp_size top, flen, glen; + mp_int *clear[3]; ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG); - - if (mp_cmp_z(a) == MP_EQ && mp_cmp_z(b) == MP_EQ) - return MP_RANGE; + /* + Early exit if either of the inputs is zero. + Caller is responsible for the proper handling of inputs. + */ if (mp_cmp_z(a) == MP_EQ) { - return mp_copy(b, c); + res = mp_copy(b, c); + SIGN(c) = ZPOS; + return res; } else if (mp_cmp_z(b) == MP_EQ) { - return mp_copy(a, c); - } - - if ((res = mp_init(&t)) != MP_OKAY) + res = mp_copy(a, c); + SIGN(c) = ZPOS; return res; - if ((res = mp_init_copy(&u, a)) != MP_OKAY) - goto U; - if ((res = mp_init_copy(&v, b)) != MP_OKAY) - goto V; - - SIGN(&u) = ZPOS; - SIGN(&v) = ZPOS; - - /* Divide out common factors of 2 until at least 1 of a, b is even */ - while (mp_iseven(&u) && mp_iseven(&v)) { - s_mp_div_2(&u); - s_mp_div_2(&v); - ++k; } - /* Initialize t */ - if (mp_isodd(&u)) { - if ((res = mp_copy(&v, &t)) != MP_OKAY) - goto CLEANUP; - - /* t = -v */ - if (SIGN(&v) == ZPOS) - SIGN(&t) = NEG; - else - SIGN(&t) = ZPOS; + MP_CHECKOK(mp_init(&temp)); + clear[++last] = &temp; + MP_CHECKOK(mp_init_copy(&g, a)); + clear[++last] = &g; + MP_CHECKOK(mp_init_copy(&f, b)); + clear[++last] = &f; - } else { - if ((res = mp_copy(&u, &t)) != MP_OKAY) - goto CLEANUP; + /* + For even case compute the number of + shared powers of 2 in f and g. + */ + for (i = 0; i < USED(&f) && i < USED(&g); i++) { + mask = ~(DIGIT(&f, i) | DIGIT(&g, i)); + for (j = 0; j < MP_DIGIT_BIT; j++) { + bit &= mask; + shifts += bit; + mask >>= 1; + } } + /* Reduce to the odd case by removing the powers of 2. */ + s_mp_div_2d(&f, shifts); + s_mp_div_2d(&g, shifts); - for (;;) { - while (mp_iseven(&t)) { - s_mp_div_2(&t); - } + /* Allocate to the size of largest mp_int. */ + top = (mp_size)1 + ((USED(&f) >= USED(&g)) ? USED(&f) : USED(&g)); + MP_CHECKOK(s_mp_grow(&f, top)); + MP_CHECKOK(s_mp_grow(&g, top)); + MP_CHECKOK(s_mp_grow(&temp, top)); - if (mp_cmp_z(&t) == MP_GT) { - if ((res = mp_copy(&t, &u)) != MP_OKAY) - goto CLEANUP; + /* Make sure f contains the odd value. */ + MP_CHECKOK(mp_cswap((~DIGIT(&f, 0) & 1), &f, &g, top)); - } else { - if ((res = mp_copy(&t, &v)) != MP_OKAY) - goto CLEANUP; + /* Upper bound for the total iterations. */ + flen = mpl_significant_bits(&f); + glen = mpl_significant_bits(&g); + m = 4 + 3 * ((flen >= glen) ? flen : glen); - /* v = -t */ - if (SIGN(&t) == ZPOS) - SIGN(&v) = NEG; - else - SIGN(&v) = ZPOS; - } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit +#endif - if ((res = mp_sub(&u, &v, &t)) != MP_OKAY) - goto CLEANUP; + for (i = 0; i < m; i++) { + /* Step 1: conditional swap. */ + /* Set cond if delta > 0 and g is odd. */ + cond = (-delta >> (8 * sizeof(delta) - 1)) & DIGIT(&g, 0) & 1; + /* If cond is set replace (delta,f) with (-delta,-f). */ + delta = (-cond & -delta) | ((cond - 1) & delta); + SIGN(&f) ^= cond; + /* If cond is set swap f with g. */ + MP_CHECKOK(mp_cswap(cond, &f, &g, top)); + + /* Step 2: elemination. */ + /* Update delta. */ + delta++; + /* If g is odd, right shift (g+f) else right shift g. */ + MP_CHECKOK(mp_add(&g, &f, &temp)); + MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &g, &temp, top)); + s_mp_div_2(&g); + } + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif - if (s_mp_cmp_d(&t, 0) == MP_EQ) - break; - } + /* GCD is in f, take the absolute value. */ + SIGN(&f) = ZPOS; - s_mp_2expt(&v, k); /* v = 2^k */ - res = mp_mul(&u, &v, c); /* c = u * v */ + /* Add back the removed powers of 2. */ + MP_CHECKOK(s_mp_mul_2d(&f, shifts)); -CLEANUP: - mp_clear(&v); -V: - mp_clear(&u); -U: - mp_clear(&t); + MP_CHECKOK(mp_copy(&f, c)); +CLEANUP: + while (last >= 0) + mp_clear(clear[last--]); return res; - } /* end mp_gcd() */ /* }}} */ @@ -2131,42 +2146,114 @@ CLEANUP: return res; } -/* compute mod inverse using Schroeppel's method, only if m is odd */ +/* + Computes the modular inverse using the constant-time algorithm + by Bernstein and Yang (https://eprint.iacr.org/2019/266) + "Fast constant-time gcd computation and modular inversion" + */ mp_err s_mp_invmod_odd_m(const mp_int *a, const mp_int *m, mp_int *c) { - int k; mp_err res; - mp_int x; + mp_digit cond = 0; + mp_int g, f, v, r, temp; + int i, its, delta = 1, last = -1; + mp_size top, flen, glen; + mp_int *clear[6]; ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG); - - if (mp_cmp_z(a) == 0 || mp_cmp_z(m) == 0) + /* Check for invalid inputs. */ + if (mp_cmp_z(a) == MP_EQ || mp_cmp_d(m, 2) == MP_LT) return MP_RANGE; - if (mp_iseven(m)) + + if (a == m || mp_iseven(m)) return MP_UNDEF; - MP_DIGITS(&x) = 0; + MP_CHECKOK(mp_init(&temp)); + clear[++last] = &temp; + MP_CHECKOK(mp_init(&v)); + clear[++last] = &v; + MP_CHECKOK(mp_init(&r)); + clear[++last] = &r; + MP_CHECKOK(mp_init_copy(&g, a)); + clear[++last] = &g; + MP_CHECKOK(mp_init_copy(&f, m)); + clear[++last] = &f; + + mp_set(&v, 0); + mp_set(&r, 1); + + /* Allocate to the size of largest mp_int. */ + top = (mp_size)1 + ((USED(&f) >= USED(&g)) ? USED(&f) : USED(&g)); + MP_CHECKOK(s_mp_grow(&f, top)); + MP_CHECKOK(s_mp_grow(&g, top)); + MP_CHECKOK(s_mp_grow(&temp, top)); + MP_CHECKOK(s_mp_grow(&v, top)); + MP_CHECKOK(s_mp_grow(&r, top)); + + /* Upper bound for the total iterations. */ + flen = mpl_significant_bits(&f); + glen = mpl_significant_bits(&g); + its = 4 + 3 * ((flen >= glen) ? flen : glen); + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit +#endif - if (a == c) { - if ((res = mp_init_copy(&x, a)) != MP_OKAY) - return res; - if (a == m) - m = &x; - a = &x; - } else if (m == c) { - if ((res = mp_init_copy(&x, m)) != MP_OKAY) - return res; - m = &x; - } else { - MP_DIGITS(&x) = 0; + for (i = 0; i < its; i++) { + /* Step 1: conditional swap. */ + /* Set cond if delta > 0 and g is odd. */ + cond = (-delta >> (8 * sizeof(delta) - 1)) & DIGIT(&g, 0) & 1; + /* If cond is set replace (delta,f,v) with (-delta,-f,-v). */ + delta = (-cond & -delta) | ((cond - 1) & delta); + SIGN(&f) ^= cond; + SIGN(&v) ^= cond; + /* If cond is set swap (f,v) with (g,r). */ + MP_CHECKOK(mp_cswap(cond, &f, &g, top)); + MP_CHECKOK(mp_cswap(cond, &v, &r, top)); + + /* Step 2: elemination. */ + /* Update delta */ + delta++; + /* If g is odd replace r with (r+v). */ + MP_CHECKOK(mp_add(&r, &v, &temp)); + MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &r, &temp, top)); + /* If g is odd, right shift (g+f) else right shift g. */ + MP_CHECKOK(mp_add(&g, &f, &temp)); + MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &g, &temp, top)); + s_mp_div_2(&g); + /* + If r is even, right shift it. + If r is odd, right shift (r+m) which is even because m is odd. + We want the result modulo m so adding in multiples of m here vanish. + */ + MP_CHECKOK(mp_add(&r, m, &temp)); + MP_CHECKOK(mp_cswap((DIGIT(&r, 0) & 1), &r, &temp, top)); + s_mp_div_2(&r); } - MP_CHECKOK(s_mp_almost_inverse(a, m, c)); - k = res; - MP_CHECKOK(s_mp_fixup_reciprocal(c, m, k, c)); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + + /* We have the inverse in v, propagate sign from f. */ + SIGN(&v) ^= SIGN(&f); + /* GCD is in f, take the absolute value. */ + SIGN(&f) = ZPOS; + + /* If gcd != 1, not invertible. */ + if (mp_cmp_d(&f, 1) != MP_EQ) { + res = MP_UNDEF; + goto CLEANUP; + } + + /* Return inverse modulo m. */ + MP_CHECKOK(mp_mod(&v, m, c)); + CLEANUP: - mp_clear(&x); + while (last >= 0) + mp_clear(clear[last--]); return res; } @@ -2218,13 +2305,24 @@ s_mp_invmod_2d(const mp_int *a, mp_size k, mp_int *c) if (mp_iseven(a)) return MP_UNDEF; + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit +#endif if (k <= MP_DIGIT_BIT) { mp_digit i = s_mp_invmod_radix(MP_DIGIT(a, 0)); + /* propagate the sign from mp_int */ + i = (i ^ -(mp_digit)SIGN(a)) + (mp_digit)SIGN(a); if (k < MP_DIGIT_BIT) i &= ((mp_digit)1 << k) - (mp_digit)1; mp_set(c, i); return MP_OKAY; } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + MP_DIGITS(&t0) = 0; MP_DIGITS(&t1) = 0; MP_DIGITS(&val) = 0; @@ -2831,6 +2929,8 @@ s_mp_clamp(mp_int *mp) while (used > 1 && DIGIT(mp, used - 1) == 0) --used; MP_USED(mp) = used; + if (used == 1 && DIGIT(mp, 0) == 0) + MP_SIGN(mp) = ZPOS; } /* end s_mp_clamp() */ /* }}} */ @@ -2908,37 +3008,36 @@ mp_err s_mp_mul_2d(mp_int *mp, mp_digit d) { mp_err res; - mp_digit dshift, bshift; - mp_digit mask; + mp_digit dshift, rshift, mask, x, prev = 0; + mp_digit *pa = NULL; + int i; ARGCHK(mp != NULL, MP_BADARG); dshift = d / MP_DIGIT_BIT; - bshift = d % MP_DIGIT_BIT; + d %= MP_DIGIT_BIT; + /* mp_digit >> rshift is undefined behavior for rshift >= MP_DIGIT_BIT */ + /* mod and corresponding mask logic avoid that when d = 0 */ + rshift = MP_DIGIT_BIT - d; + rshift %= MP_DIGIT_BIT; + /* mask = (2**d - 1) * 2**(w-d) mod 2**w */ + mask = (DIGIT_MAX << rshift) + 1; + mask &= DIGIT_MAX - 1; /* bits to be shifted out of the top word */ - if (bshift) { - mask = (mp_digit)~0 << (MP_DIGIT_BIT - bshift); - mask &= MP_DIGIT(mp, MP_USED(mp) - 1); - } else { - mask = 0; - } + x = MP_DIGIT(mp, MP_USED(mp) - 1) & mask; - if (MP_OKAY != (res = s_mp_pad(mp, MP_USED(mp) + dshift + (mask != 0)))) + if (MP_OKAY != (res = s_mp_pad(mp, MP_USED(mp) + dshift + (x != 0)))) return res; if (dshift && MP_OKAY != (res = s_mp_lshd(mp, dshift))) return res; - if (bshift) { - mp_digit *pa = MP_DIGITS(mp); - mp_digit *alim = pa + MP_USED(mp); - mp_digit prev = 0; + pa = MP_DIGITS(mp) + dshift; - for (pa += dshift; pa < alim;) { - mp_digit x = *pa; - *pa++ = (x << bshift) | prev; - prev = x >> (DIGIT_BIT - bshift); - } + for (i = MP_USED(mp) - dshift; i > 0; i--) { + x = *pa; + *pa++ = (x << d) | prev; + prev = (x & mask) >> rshift; } s_mp_clamp(mp); @@ -3077,18 +3176,20 @@ void s_mp_div_2d(mp_int *mp, mp_digit d) { int ix; - mp_digit save, next, mask; + mp_digit save, next, mask, lshift; s_mp_rshd(mp, d / DIGIT_BIT); d %= DIGIT_BIT; - if (d) { - mask = ((mp_digit)1 << d) - 1; - save = 0; - for (ix = USED(mp) - 1; ix >= 0; ix--) { - next = DIGIT(mp, ix) & mask; - DIGIT(mp, ix) = (DIGIT(mp, ix) >> d) | (save << (DIGIT_BIT - d)); - save = next; - } + /* mp_digit << lshift is undefined behavior for lshift >= MP_DIGIT_BIT */ + /* mod and corresponding mask logic avoid that when d = 0 */ + lshift = DIGIT_BIT - d; + lshift %= DIGIT_BIT; + mask = ((mp_digit)1 << d) - 1; + save = 0; + for (ix = USED(mp) - 1; ix >= 0; ix--) { + next = DIGIT(mp, ix) & mask; + DIGIT(mp, ix) = (save << lshift) | (DIGIT(mp, ix) >> d); + save = next; } s_mp_clamp(mp); @@ -4841,5 +4942,44 @@ mp_to_fixlen_octets(const mp_int *mp, unsigned char *str, mp_size length) } /* end mp_to_fixlen_octets() */ /* }}} */ +/* {{{ mp_cswap(condition, a, b, numdigits) */ +/* performs a conditional swap between mp_int. */ +mp_err +mp_cswap(mp_digit condition, mp_int *a, mp_int *b, mp_size numdigits) +{ + mp_digit x; + unsigned int i; + mp_err res = 0; + + /* if pointers are equal return */ + if (a == b) + return res; + + if (MP_ALLOC(a) < numdigits || MP_ALLOC(b) < numdigits) { + MP_CHECKOK(s_mp_grow(a, numdigits)); + MP_CHECKOK(s_mp_grow(b, numdigits)); + } + + condition = ((~condition & ((condition - 1))) >> (MP_DIGIT_BIT - 1)) - 1; + + x = (USED(a) ^ USED(b)) & condition; + USED(a) ^= x; + USED(b) ^= x; + + x = (SIGN(a) ^ SIGN(b)) & condition; + SIGN(a) ^= x; + SIGN(b) ^= x; + + for (i = 0; i < numdigits; i++) { + x = (DIGIT(a, i) ^ DIGIT(b, i)) & condition; + DIGIT(a, i) ^= x; + DIGIT(b, i) ^= x; + } + +CLEANUP: + return res; +} /* end mp_cswap() */ +/* }}} */ + /*------------------------------------------------------------------------*/ /* HERE THERE BE DRAGONS */ diff --git a/security/nss/lib/freebl/mpi/mpi.h b/security/nss/lib/freebl/mpi/mpi.h index af608b43d..b1a07a61d 100644 --- a/security/nss/lib/freebl/mpi/mpi.h +++ b/security/nss/lib/freebl/mpi/mpi.h @@ -267,6 +267,7 @@ mp_size mp_trailing_zeros(const mp_int *mp); void freebl_cpuid(unsigned long op, unsigned long *eax, unsigned long *ebx, unsigned long *ecx, unsigned long *edx); +mp_err mp_cswap(mp_digit condition, mp_int *a, mp_int *b, mp_size numdigits); #define MP_CHECKOK(x) \ if (MP_OKAY > (res = (x))) \ diff --git a/security/nss/lib/freebl/mpi/mplogic.c b/security/nss/lib/freebl/mpi/mplogic.c index 89fd03ae8..23ddfec1a 100644 --- a/security/nss/lib/freebl/mpi/mplogic.c +++ b/security/nss/lib/freebl/mpi/mplogic.c @@ -409,35 +409,54 @@ mpl_get_bits(const mp_int *a, mp_size lsbNum, mp_size numBits) return (mp_err)mask; } +#define LZCNTLOOP(i) \ + do { \ + x = d >> (i); \ + mask = (0 - x); \ + mask = (0 - (mask >> (MP_DIGIT_BIT - 1))); \ + bits += (i)&mask; \ + d ^= (x ^ d) & mask; \ + } while (0) + /* mpl_significant_bits - returns number of significnant bits in abs(a). + returns number of significant bits in abs(a). + In other words: floor(lg(abs(a))) + 1. returns 1 if value is zero. */ mp_size mpl_significant_bits(const mp_int *a) { - mp_size bits = 0; + /* + start bits at 1. + lg(0) = 0 => bits = 1 by function semantics. + below does a binary search for the _position_ of the top bit set, + which is floor(lg(abs(a))) for a != 0. + */ + mp_size bits = 1; int ix; ARGCHK(a != NULL, MP_BADARG); for (ix = MP_USED(a); ix > 0;) { - mp_digit d; - d = MP_DIGIT(a, --ix); - if (d) { - while (d) { - ++bits; - d >>= 1; - } - break; - } + mp_digit d, x, mask; + if ((d = MP_DIGIT(a, --ix)) == 0) + continue; +#if !defined(MP_USE_UINT_DIGIT) + LZCNTLOOP(32); +#endif + LZCNTLOOP(16); + LZCNTLOOP(8); + LZCNTLOOP(4); + LZCNTLOOP(2); + LZCNTLOOP(1); + break; } bits += ix * MP_DIGIT_BIT; - if (!bits) - bits = 1; return bits; } +#undef LZCNTLOOP + /*------------------------------------------------------------------------*/ /* HERE THERE BE DRAGONS */ -- cgit v1.2.3