diff options
Diffstat (limited to 'security/nss/lib/freebl/rsapkcs.c')
-rw-r--r-- | security/nss/lib/freebl/rsapkcs.c | 234 |
1 files changed, 152 insertions, 82 deletions
diff --git a/security/nss/lib/freebl/rsapkcs.c b/security/nss/lib/freebl/rsapkcs.c index 577fe1f61..ad18c8b73 100644 --- a/security/nss/lib/freebl/rsapkcs.c +++ b/security/nss/lib/freebl/rsapkcs.c @@ -85,6 +85,25 @@ rsa_modulusLen(SECItem *modulus) return modLen; } +static unsigned int +rsa_modulusBits(SECItem *modulus) +{ + unsigned char byteZero = modulus->data[0]; + unsigned int numBits = (modulus->len - 1) * 8; + + if (byteZero == 0) { + numBits -= 8; + byteZero = modulus->data[1]; + } + + while (byteZero > 0) { + numBits++; + byteZero >>= 1; + } + + return numBits; +} + /* * Format one block of data for public/private key encryption using * the rules defined in PKCS #1. @@ -271,10 +290,12 @@ MGF1(HASH_HashType hashAlg, const SECHashObject *hash; void *hashContext; unsigned char C[4]; + SECStatus rv = SECSuccess; hash = HASH_GetRawHashObject(hashAlg); - if (hash == NULL) + if (hash == NULL) { return SECFailure; + } hashContext = (*hash->create)(); rounds = (maskLen + hash->length - 1) / hash->length; @@ -295,14 +316,19 @@ MGF1(HASH_HashType hashAlg, (*hash->end)(hashContext, tempHash, &digestLen, hash->length); } else { /* we're in the last round and need to cut the hash */ temp = (unsigned char *)PORT_Alloc(hash->length); + if (!temp) { + rv = SECFailure; + goto done; + } (*hash->end)(hashContext, temp, &digestLen, hash->length); PORT_Memcpy(tempHash, temp, maskLen - counter * hash->length); PORT_Free(temp); } } - (*hash->destroy)(hashContext, PR_TRUE); - return SECSuccess; +done: + (*hash->destroy)(hashContext, PR_TRUE); + return rv; } /* XXX Doesn't set error code */ @@ -962,12 +988,11 @@ failure: * We use mHash instead of M as input. * emBits from the RFC is just modBits - 1, see section 8.1.1. * We only support MGF1 as the MGF. - * - * NOTE: this code assumes modBits is a multiple of 8. */ static SECStatus emsa_pss_encode(unsigned char *em, unsigned int emLen, + unsigned int emBits, const unsigned char *mHash, HASH_HashType hashAlg, HASH_HashType maskHashAlg, @@ -1032,7 +1057,7 @@ emsa_pss_encode(unsigned char *em, PORT_Free(dbMask); /* Step 11 */ - em[0] &= 0x7f; + em[0] &= 0xff >> (8 * emLen - emBits); /* Step 12 */ em[emLen - 1] = 0xbc; @@ -1046,13 +1071,12 @@ emsa_pss_encode(unsigned char *em, * We use mHash instead of M as input. * emBits from the RFC is just modBits - 1, see section 8.1.2. * We only support MGF1 as the MGF. - * - * NOTE: this code assumes modBits is a multiple of 8. */ static SECStatus emsa_pss_verify(const unsigned char *mHash, const unsigned char *em, unsigned int emLen, + unsigned int emBits, HASH_HashType hashAlg, HASH_HashType maskHashAlg, unsigned int saltLen) @@ -1063,15 +1087,22 @@ emsa_pss_verify(const unsigned char *mHash, unsigned char *H_; /* H' from the RFC */ unsigned int i; unsigned int dbMaskLen; + unsigned int zeroBits; SECStatus rv; hash = HASH_GetRawHashObject(hashAlg); dbMaskLen = emLen - hash->length - 1; - /* Step 3 + 4 + 6 */ + /* Step 3 + 4 */ if ((emLen < (hash->length + saltLen + 2)) || - (em[emLen - 1] != 0xbc) || - ((em[0] & 0x80) != 0)) { + (em[emLen - 1] != 0xbc)) { + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + return SECFailure; + } + + /* Step 6 */ + zeroBits = 8 * emLen - emBits; + if (em[0] >> (8 - zeroBits)) { PORT_SetError(SEC_ERROR_BAD_SIGNATURE); return SECFailure; } @@ -1091,7 +1122,7 @@ emsa_pss_verify(const unsigned char *mHash, } /* Step 9 */ - db[0] &= 0x7f; + db[0] &= 0xff >> zeroBits; /* Step 10 */ for (i = 0; i < (dbMaskLen - saltLen - 1); i++) { @@ -1156,7 +1187,9 @@ RSA_SignPSS(RSAPrivateKey *key, { SECStatus rv = SECSuccess; unsigned int modulusLen = rsa_modulusLen(&key->modulus); - unsigned char *pssEncoded = NULL; + unsigned int modulusBits = rsa_modulusBits(&key->modulus); + unsigned int emLen = modulusLen; + unsigned char *pssEncoded, *em; if (maxOutputLen < modulusLen) { PORT_SetError(SEC_ERROR_OUTPUT_LEN); @@ -1168,16 +1201,24 @@ RSA_SignPSS(RSAPrivateKey *key, return SECFailure; } - pssEncoded = (unsigned char *)PORT_Alloc(modulusLen); + pssEncoded = em = (unsigned char *)PORT_Alloc(modulusLen); if (pssEncoded == NULL) { PORT_SetError(SEC_ERROR_NO_MEMORY); return SECFailure; } - rv = emsa_pss_encode(pssEncoded, modulusLen, input, hashAlg, + + /* len(em) == ceil((modulusBits - 1) / 8). */ + if (modulusBits % 8 == 1) { + em[0] = 0; + emLen--; + em++; + } + rv = emsa_pss_encode(em, emLen, modulusBits - 1, input, hashAlg, maskHashAlg, salt, saltLength); if (rv != SECSuccess) goto done; + // This sets error codes upon failure. rv = RSA_PrivateKeyOpDoubleChecked(key, output, pssEncoded); *outputLen = modulusLen; @@ -1198,7 +1239,9 @@ RSA_CheckSignPSS(RSAPublicKey *key, { SECStatus rv; unsigned int modulusLen = rsa_modulusLen(&key->modulus); - unsigned char *buffer; + unsigned int modulusBits = rsa_modulusBits(&key->modulus); + unsigned int emLen = modulusLen; + unsigned char *buffer, *em; if (sigLen != modulusLen) { PORT_SetError(SEC_ERROR_BAD_SIGNATURE); @@ -1210,7 +1253,7 @@ RSA_CheckSignPSS(RSAPublicKey *key, return SECFailure; } - buffer = (unsigned char *)PORT_Alloc(modulusLen); + buffer = em = (unsigned char *)PORT_Alloc(modulusLen); if (!buffer) { PORT_SetError(SEC_ERROR_NO_MEMORY); return SECFailure; @@ -1223,14 +1266,18 @@ RSA_CheckSignPSS(RSAPublicKey *key, return SECFailure; } - rv = emsa_pss_verify(hash, buffer, modulusLen, hashAlg, + /* len(em) == ceil((modulusBits - 1) / 8). */ + if (modulusBits % 8 == 1) { + emLen--; + em++; + } + rv = emsa_pss_verify(hash, em, emLen, modulusBits - 1, hashAlg, maskHashAlg, saltLength); - PORT_Free(buffer); + PORT_Free(buffer); return rv; } -/* XXX Doesn't set error code */ SECStatus RSA_Sign(RSAPrivateKey *key, unsigned char *output, @@ -1239,34 +1286,34 @@ RSA_Sign(RSAPrivateKey *key, const unsigned char *input, unsigned int inputLen) { - SECStatus rv = SECSuccess; + SECStatus rv = SECFailure; unsigned int modulusLen = rsa_modulusLen(&key->modulus); - SECItem formatted; - SECItem unformatted; + SECItem formatted = { siBuffer, NULL, 0 }; + SECItem unformatted = { siBuffer, (unsigned char *)input, inputLen }; - if (maxOutputLen < modulusLen) - return SECFailure; + if (maxOutputLen < modulusLen) { + PORT_SetError(SEC_ERROR_OUTPUT_LEN); + goto done; + } - unformatted.len = inputLen; - unformatted.data = (unsigned char *)input; - formatted.data = NULL; rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockPrivate, &unformatted); - if (rv != SECSuccess) + if (rv != SECSuccess) { + PORT_SetError(SEC_ERROR_LIBRARY_FAILURE); goto done; + } + // This sets error codes upon failure. rv = RSA_PrivateKeyOpDoubleChecked(key, output, formatted.data); *outputLen = modulusLen; - goto done; - done: - if (formatted.data != NULL) + if (formatted.data != NULL) { PORT_ZFree(formatted.data, modulusLen); + } return rv; } -/* XXX Doesn't set error code */ SECStatus RSA_CheckSign(RSAPublicKey *key, const unsigned char *sig, @@ -1274,60 +1321,71 @@ RSA_CheckSign(RSAPublicKey *key, const unsigned char *data, unsigned int dataLen) { - SECStatus rv; + SECStatus rv = SECFailure; unsigned int modulusLen = rsa_modulusLen(&key->modulus); unsigned int i; - unsigned char *buffer; + unsigned char *buffer = NULL; + + if (sigLen != modulusLen) { + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + goto done; + } - if (sigLen != modulusLen) - goto failure; /* * 0x00 || BT || Pad || 0x00 || ActualData * * The "3" below is the first octet + the second octet + the 0x00 * octet that always comes just before the ActualData. */ - if (dataLen > modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN)) - goto failure; + if (dataLen > modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN)) { + PORT_SetError(SEC_ERROR_BAD_DATA); + goto done; + } buffer = (unsigned char *)PORT_Alloc(modulusLen + 1); - if (!buffer) - goto failure; + if (!buffer) { + PORT_SetError(SEC_ERROR_NO_MEMORY); + goto done; + } - rv = RSA_PublicKeyOp(key, buffer, sig); - if (rv != SECSuccess) - goto loser; + if (RSA_PublicKeyOp(key, buffer, sig) != SECSuccess) { + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + goto done; + } /* * check the padding that was used */ if (buffer[0] != RSA_BLOCK_FIRST_OCTET || buffer[1] != (unsigned char)RSA_BlockPrivate) { - goto loser; + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + goto done; } for (i = 2; i < modulusLen - dataLen - 1; i++) { - if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) - goto loser; + if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) { + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + goto done; + } + } + if (buffer[i] != RSA_BLOCK_AFTER_PAD_OCTET) { + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + goto done; } - if (buffer[i] != RSA_BLOCK_AFTER_PAD_OCTET) - goto loser; /* * make sure we get the same results */ - if (PORT_Memcmp(buffer + modulusLen - dataLen, data, dataLen) != 0) - goto loser; - - PORT_Free(buffer); - return SECSuccess; + if (PORT_Memcmp(buffer + modulusLen - dataLen, data, dataLen) == 0) { + rv = SECSuccess; + } -loser: - PORT_Free(buffer); -failure: - return SECFailure; +done: + if (buffer) { + PORT_Free(buffer); + } + return rv; } -/* XXX Doesn't set error code */ SECStatus RSA_CheckSignRecover(RSAPublicKey *key, unsigned char *output, @@ -1336,21 +1394,27 @@ RSA_CheckSignRecover(RSAPublicKey *key, const unsigned char *sig, unsigned int sigLen) { - SECStatus rv; + SECStatus rv = SECFailure; unsigned int modulusLen = rsa_modulusLen(&key->modulus); unsigned int i; - unsigned char *buffer; + unsigned char *buffer = NULL; - if (sigLen != modulusLen) - goto failure; + if (sigLen != modulusLen) { + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + goto done; + } buffer = (unsigned char *)PORT_Alloc(modulusLen + 1); - if (!buffer) - goto failure; + if (!buffer) { + PORT_SetError(SEC_ERROR_NO_MEMORY); + goto done; + } + + if (RSA_PublicKeyOp(key, buffer, sig) != SECSuccess) { + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + goto done; + } - rv = RSA_PublicKeyOp(key, buffer, sig); - if (rv != SECSuccess) - goto loser; *outputLen = 0; /* @@ -1358,28 +1422,34 @@ RSA_CheckSignRecover(RSAPublicKey *key, */ if (buffer[0] != RSA_BLOCK_FIRST_OCTET || buffer[1] != (unsigned char)RSA_BlockPrivate) { - goto loser; + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + goto done; } for (i = 2; i < modulusLen; i++) { if (buffer[i] == RSA_BLOCK_AFTER_PAD_OCTET) { *outputLen = modulusLen - i - 1; break; } - if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) - goto loser; + if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) { + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + goto done; + } + } + if (*outputLen == 0) { + PORT_SetError(SEC_ERROR_BAD_SIGNATURE); + goto done; + } + if (*outputLen > maxOutputLen) { + PORT_SetError(SEC_ERROR_OUTPUT_LEN); + goto done; } - if (*outputLen == 0) - goto loser; - if (*outputLen > maxOutputLen) - goto loser; PORT_Memcpy(output, buffer + modulusLen - *outputLen, *outputLen); + rv = SECSuccess; - PORT_Free(buffer); - return SECSuccess; - -loser: - PORT_Free(buffer); -failure: - return SECFailure; +done: + if (buffer) { + PORT_Free(buffer); + } + return rv; } |