summaryrefslogtreecommitdiffstats
path: root/security/nss/lib/freebl/rsapkcs.c
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/lib/freebl/rsapkcs.c')
-rw-r--r--security/nss/lib/freebl/rsapkcs.c234
1 files changed, 152 insertions, 82 deletions
diff --git a/security/nss/lib/freebl/rsapkcs.c b/security/nss/lib/freebl/rsapkcs.c
index 577fe1f61..ad18c8b73 100644
--- a/security/nss/lib/freebl/rsapkcs.c
+++ b/security/nss/lib/freebl/rsapkcs.c
@@ -85,6 +85,25 @@ rsa_modulusLen(SECItem *modulus)
return modLen;
}
+static unsigned int
+rsa_modulusBits(SECItem *modulus)
+{
+ unsigned char byteZero = modulus->data[0];
+ unsigned int numBits = (modulus->len - 1) * 8;
+
+ if (byteZero == 0) {
+ numBits -= 8;
+ byteZero = modulus->data[1];
+ }
+
+ while (byteZero > 0) {
+ numBits++;
+ byteZero >>= 1;
+ }
+
+ return numBits;
+}
+
/*
* Format one block of data for public/private key encryption using
* the rules defined in PKCS #1.
@@ -271,10 +290,12 @@ MGF1(HASH_HashType hashAlg,
const SECHashObject *hash;
void *hashContext;
unsigned char C[4];
+ SECStatus rv = SECSuccess;
hash = HASH_GetRawHashObject(hashAlg);
- if (hash == NULL)
+ if (hash == NULL) {
return SECFailure;
+ }
hashContext = (*hash->create)();
rounds = (maskLen + hash->length - 1) / hash->length;
@@ -295,14 +316,19 @@ MGF1(HASH_HashType hashAlg,
(*hash->end)(hashContext, tempHash, &digestLen, hash->length);
} else { /* we're in the last round and need to cut the hash */
temp = (unsigned char *)PORT_Alloc(hash->length);
+ if (!temp) {
+ rv = SECFailure;
+ goto done;
+ }
(*hash->end)(hashContext, temp, &digestLen, hash->length);
PORT_Memcpy(tempHash, temp, maskLen - counter * hash->length);
PORT_Free(temp);
}
}
- (*hash->destroy)(hashContext, PR_TRUE);
- return SECSuccess;
+done:
+ (*hash->destroy)(hashContext, PR_TRUE);
+ return rv;
}
/* XXX Doesn't set error code */
@@ -962,12 +988,11 @@ failure:
* We use mHash instead of M as input.
* emBits from the RFC is just modBits - 1, see section 8.1.1.
* We only support MGF1 as the MGF.
- *
- * NOTE: this code assumes modBits is a multiple of 8.
*/
static SECStatus
emsa_pss_encode(unsigned char *em,
unsigned int emLen,
+ unsigned int emBits,
const unsigned char *mHash,
HASH_HashType hashAlg,
HASH_HashType maskHashAlg,
@@ -1032,7 +1057,7 @@ emsa_pss_encode(unsigned char *em,
PORT_Free(dbMask);
/* Step 11 */
- em[0] &= 0x7f;
+ em[0] &= 0xff >> (8 * emLen - emBits);
/* Step 12 */
em[emLen - 1] = 0xbc;
@@ -1046,13 +1071,12 @@ emsa_pss_encode(unsigned char *em,
* We use mHash instead of M as input.
* emBits from the RFC is just modBits - 1, see section 8.1.2.
* We only support MGF1 as the MGF.
- *
- * NOTE: this code assumes modBits is a multiple of 8.
*/
static SECStatus
emsa_pss_verify(const unsigned char *mHash,
const unsigned char *em,
unsigned int emLen,
+ unsigned int emBits,
HASH_HashType hashAlg,
HASH_HashType maskHashAlg,
unsigned int saltLen)
@@ -1063,15 +1087,22 @@ emsa_pss_verify(const unsigned char *mHash,
unsigned char *H_; /* H' from the RFC */
unsigned int i;
unsigned int dbMaskLen;
+ unsigned int zeroBits;
SECStatus rv;
hash = HASH_GetRawHashObject(hashAlg);
dbMaskLen = emLen - hash->length - 1;
- /* Step 3 + 4 + 6 */
+ /* Step 3 + 4 */
if ((emLen < (hash->length + saltLen + 2)) ||
- (em[emLen - 1] != 0xbc) ||
- ((em[0] & 0x80) != 0)) {
+ (em[emLen - 1] != 0xbc)) {
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ return SECFailure;
+ }
+
+ /* Step 6 */
+ zeroBits = 8 * emLen - emBits;
+ if (em[0] >> (8 - zeroBits)) {
PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
return SECFailure;
}
@@ -1091,7 +1122,7 @@ emsa_pss_verify(const unsigned char *mHash,
}
/* Step 9 */
- db[0] &= 0x7f;
+ db[0] &= 0xff >> zeroBits;
/* Step 10 */
for (i = 0; i < (dbMaskLen - saltLen - 1); i++) {
@@ -1156,7 +1187,9 @@ RSA_SignPSS(RSAPrivateKey *key,
{
SECStatus rv = SECSuccess;
unsigned int modulusLen = rsa_modulusLen(&key->modulus);
- unsigned char *pssEncoded = NULL;
+ unsigned int modulusBits = rsa_modulusBits(&key->modulus);
+ unsigned int emLen = modulusLen;
+ unsigned char *pssEncoded, *em;
if (maxOutputLen < modulusLen) {
PORT_SetError(SEC_ERROR_OUTPUT_LEN);
@@ -1168,16 +1201,24 @@ RSA_SignPSS(RSAPrivateKey *key,
return SECFailure;
}
- pssEncoded = (unsigned char *)PORT_Alloc(modulusLen);
+ pssEncoded = em = (unsigned char *)PORT_Alloc(modulusLen);
if (pssEncoded == NULL) {
PORT_SetError(SEC_ERROR_NO_MEMORY);
return SECFailure;
}
- rv = emsa_pss_encode(pssEncoded, modulusLen, input, hashAlg,
+
+ /* len(em) == ceil((modulusBits - 1) / 8). */
+ if (modulusBits % 8 == 1) {
+ em[0] = 0;
+ emLen--;
+ em++;
+ }
+ rv = emsa_pss_encode(em, emLen, modulusBits - 1, input, hashAlg,
maskHashAlg, salt, saltLength);
if (rv != SECSuccess)
goto done;
+ // This sets error codes upon failure.
rv = RSA_PrivateKeyOpDoubleChecked(key, output, pssEncoded);
*outputLen = modulusLen;
@@ -1198,7 +1239,9 @@ RSA_CheckSignPSS(RSAPublicKey *key,
{
SECStatus rv;
unsigned int modulusLen = rsa_modulusLen(&key->modulus);
- unsigned char *buffer;
+ unsigned int modulusBits = rsa_modulusBits(&key->modulus);
+ unsigned int emLen = modulusLen;
+ unsigned char *buffer, *em;
if (sigLen != modulusLen) {
PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
@@ -1210,7 +1253,7 @@ RSA_CheckSignPSS(RSAPublicKey *key,
return SECFailure;
}
- buffer = (unsigned char *)PORT_Alloc(modulusLen);
+ buffer = em = (unsigned char *)PORT_Alloc(modulusLen);
if (!buffer) {
PORT_SetError(SEC_ERROR_NO_MEMORY);
return SECFailure;
@@ -1223,14 +1266,18 @@ RSA_CheckSignPSS(RSAPublicKey *key,
return SECFailure;
}
- rv = emsa_pss_verify(hash, buffer, modulusLen, hashAlg,
+ /* len(em) == ceil((modulusBits - 1) / 8). */
+ if (modulusBits % 8 == 1) {
+ emLen--;
+ em++;
+ }
+ rv = emsa_pss_verify(hash, em, emLen, modulusBits - 1, hashAlg,
maskHashAlg, saltLength);
- PORT_Free(buffer);
+ PORT_Free(buffer);
return rv;
}
-/* XXX Doesn't set error code */
SECStatus
RSA_Sign(RSAPrivateKey *key,
unsigned char *output,
@@ -1239,34 +1286,34 @@ RSA_Sign(RSAPrivateKey *key,
const unsigned char *input,
unsigned int inputLen)
{
- SECStatus rv = SECSuccess;
+ SECStatus rv = SECFailure;
unsigned int modulusLen = rsa_modulusLen(&key->modulus);
- SECItem formatted;
- SECItem unformatted;
+ SECItem formatted = { siBuffer, NULL, 0 };
+ SECItem unformatted = { siBuffer, (unsigned char *)input, inputLen };
- if (maxOutputLen < modulusLen)
- return SECFailure;
+ if (maxOutputLen < modulusLen) {
+ PORT_SetError(SEC_ERROR_OUTPUT_LEN);
+ goto done;
+ }
- unformatted.len = inputLen;
- unformatted.data = (unsigned char *)input;
- formatted.data = NULL;
rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockPrivate,
&unformatted);
- if (rv != SECSuccess)
+ if (rv != SECSuccess) {
+ PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
goto done;
+ }
+ // This sets error codes upon failure.
rv = RSA_PrivateKeyOpDoubleChecked(key, output, formatted.data);
*outputLen = modulusLen;
- goto done;
-
done:
- if (formatted.data != NULL)
+ if (formatted.data != NULL) {
PORT_ZFree(formatted.data, modulusLen);
+ }
return rv;
}
-/* XXX Doesn't set error code */
SECStatus
RSA_CheckSign(RSAPublicKey *key,
const unsigned char *sig,
@@ -1274,60 +1321,71 @@ RSA_CheckSign(RSAPublicKey *key,
const unsigned char *data,
unsigned int dataLen)
{
- SECStatus rv;
+ SECStatus rv = SECFailure;
unsigned int modulusLen = rsa_modulusLen(&key->modulus);
unsigned int i;
- unsigned char *buffer;
+ unsigned char *buffer = NULL;
+
+ if (sigLen != modulusLen) {
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ goto done;
+ }
- if (sigLen != modulusLen)
- goto failure;
/*
* 0x00 || BT || Pad || 0x00 || ActualData
*
* The "3" below is the first octet + the second octet + the 0x00
* octet that always comes just before the ActualData.
*/
- if (dataLen > modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN))
- goto failure;
+ if (dataLen > modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN)) {
+ PORT_SetError(SEC_ERROR_BAD_DATA);
+ goto done;
+ }
buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
- if (!buffer)
- goto failure;
+ if (!buffer) {
+ PORT_SetError(SEC_ERROR_NO_MEMORY);
+ goto done;
+ }
- rv = RSA_PublicKeyOp(key, buffer, sig);
- if (rv != SECSuccess)
- goto loser;
+ if (RSA_PublicKeyOp(key, buffer, sig) != SECSuccess) {
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ goto done;
+ }
/*
* check the padding that was used
*/
if (buffer[0] != RSA_BLOCK_FIRST_OCTET ||
buffer[1] != (unsigned char)RSA_BlockPrivate) {
- goto loser;
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ goto done;
}
for (i = 2; i < modulusLen - dataLen - 1; i++) {
- if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET)
- goto loser;
+ if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) {
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ goto done;
+ }
+ }
+ if (buffer[i] != RSA_BLOCK_AFTER_PAD_OCTET) {
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ goto done;
}
- if (buffer[i] != RSA_BLOCK_AFTER_PAD_OCTET)
- goto loser;
/*
* make sure we get the same results
*/
- if (PORT_Memcmp(buffer + modulusLen - dataLen, data, dataLen) != 0)
- goto loser;
-
- PORT_Free(buffer);
- return SECSuccess;
+ if (PORT_Memcmp(buffer + modulusLen - dataLen, data, dataLen) == 0) {
+ rv = SECSuccess;
+ }
-loser:
- PORT_Free(buffer);
-failure:
- return SECFailure;
+done:
+ if (buffer) {
+ PORT_Free(buffer);
+ }
+ return rv;
}
-/* XXX Doesn't set error code */
SECStatus
RSA_CheckSignRecover(RSAPublicKey *key,
unsigned char *output,
@@ -1336,21 +1394,27 @@ RSA_CheckSignRecover(RSAPublicKey *key,
const unsigned char *sig,
unsigned int sigLen)
{
- SECStatus rv;
+ SECStatus rv = SECFailure;
unsigned int modulusLen = rsa_modulusLen(&key->modulus);
unsigned int i;
- unsigned char *buffer;
+ unsigned char *buffer = NULL;
- if (sigLen != modulusLen)
- goto failure;
+ if (sigLen != modulusLen) {
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ goto done;
+ }
buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
- if (!buffer)
- goto failure;
+ if (!buffer) {
+ PORT_SetError(SEC_ERROR_NO_MEMORY);
+ goto done;
+ }
+
+ if (RSA_PublicKeyOp(key, buffer, sig) != SECSuccess) {
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ goto done;
+ }
- rv = RSA_PublicKeyOp(key, buffer, sig);
- if (rv != SECSuccess)
- goto loser;
*outputLen = 0;
/*
@@ -1358,28 +1422,34 @@ RSA_CheckSignRecover(RSAPublicKey *key,
*/
if (buffer[0] != RSA_BLOCK_FIRST_OCTET ||
buffer[1] != (unsigned char)RSA_BlockPrivate) {
- goto loser;
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ goto done;
}
for (i = 2; i < modulusLen; i++) {
if (buffer[i] == RSA_BLOCK_AFTER_PAD_OCTET) {
*outputLen = modulusLen - i - 1;
break;
}
- if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET)
- goto loser;
+ if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) {
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ goto done;
+ }
+ }
+ if (*outputLen == 0) {
+ PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
+ goto done;
+ }
+ if (*outputLen > maxOutputLen) {
+ PORT_SetError(SEC_ERROR_OUTPUT_LEN);
+ goto done;
}
- if (*outputLen == 0)
- goto loser;
- if (*outputLen > maxOutputLen)
- goto loser;
PORT_Memcpy(output, buffer + modulusLen - *outputLen, *outputLen);
+ rv = SECSuccess;
- PORT_Free(buffer);
- return SECSuccess;
-
-loser:
- PORT_Free(buffer);
-failure:
- return SECFailure;
+done:
+ if (buffer) {
+ PORT_Free(buffer);
+ }
+ return rv;
}