/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "ckmk.h"

/* Sigh. For all the talk about 'ease of use', Apple has hidden the interfaces
 * needed to be able to truly use CSSM. These came from their modification
 * to NSS's S/MIME code. The following two functions are currently not
 * part of the SecKey.h interface.
 */
/* Hand-written prototype. This file calls it with a key and an ACL
 * authorization tag and receives the matching access credentials through
 * 'creds'; those credentials are then used to create the CSSM context.
 * NOTE(review): this file always passes 0 for 'type'; the meaning of that
 * argument is not visible here -- confirm against Apple's sources. */
OSStatus
SecKeyGetCredentials(
    SecKeyRef keyRef,
    CSSM_ACL_AUTHORIZATION_TAG authTag,
    int type,
    const CSSM_ACCESS_CREDENTIALS **creds);

/* Hand-written prototype. Hands back, through 'cspHandle', the CSP handle
 * this file uses when creating a CSSM context for the key.
 * This function could be implemented using 'SecKeychainItemCopyKeychain' and
 * 'SecKeychainGetCSPHandle'. */
OSStatus
SecKeyGetCSPHandle(
    SecKeyRef keyRef,
    CSSM_CSP_HANDLE *cspHandle);

/*
 * State for one RSA private-key operation (decrypt or sign) created by
 * ckmk_mdCryptoOperationRSAPriv_Create and released by
 * ckmk_mdCryptoOperationRSAPriv_Destroy.
 */
typedef struct ckmkInternalCryptoOperationRSAPrivStr
    ckmkInternalCryptoOperationRSAPriv;
struct ckmkInternalCryptoOperationRSAPrivStr {
    /* copy of the prototype table; its 'etc' member points back at this
     * structure */
    NSSCKMDCryptoOperation mdOperation;
    /* mechanism the operation was created from */
    NSSCKMDMechanism *mdMechanism;
    /* internal object for the private key in use */
    ckmkInternalObject *iKey;
    /* decrypted data cached by the decrypt GetOperationLength callback;
     * NULL until then, destroyed together with the operation */
    NSSItem *buffer;
    /* CSSM context for the operation; deleted together with the operation */
    CSSM_CC_HANDLE cssmContext;
};

/*
 * Selects which kind of operation ckmk_mdCryptoOperationRSAPriv_Create
 * builds: it decides both the ACL authorization tag requested for the key
 * and the kind of CSSM context that is created.
 */
typedef enum {
    CKMK_DECRYPT, /* asymmetric (PKCS1 padding) decrypt context */
    CKMK_SIGN     /* signature context */
} ckmkRSAOpType;

/*
 * ckmk_mdCryptoOperationRSAPriv_Create
 *
 * Builds the operation object shared by the RSA decrypt and RSA sign paths.
 *
 *   proto       - callback table to copy into the new operation
 *   mdMechanism - mechanism the operation belongs to (stored, not used here)
 *   mdKey       - key object; must be an RSA private key
 *   type        - CKMK_DECRYPT or CKMK_SIGN
 *   pError      - receives the failure code when NULL is returned
 *
 * Returns a pointer to the embedded NSSCKMDCryptoOperation, or NULL with
 * *pError set:
 *   CKR_KEY_TYPE_INCONSISTENT - key is not an RSA private key
 *   CKR_KEY_HANDLE_INVALID    - CSSM key, CSP handle or credentials could
 *                               not be obtained
 *   CKR_GENERAL_ERROR         - unknown 'type' or context creation failed
 *   CKR_HOST_MEMORY           - the operation could not be allocated
 *
 * The returned operation owns the CSSM context; it is released by
 * ckmk_mdCryptoOperationRSAPriv_Destroy.
 */
static NSSCKMDCryptoOperation *
ckmk_mdCryptoOperationRSAPriv_Create(
    const NSSCKMDCryptoOperation *proto,
    NSSCKMDMechanism *mdMechanism,
    NSSCKMDObject *mdKey,
    ckmkRSAOpType type,
    CK_RV *pError)
{
    ckmkInternalObject *iKey = (ckmkInternalObject *)mdKey->etc;
    const NSSItem *classItem = nss_ckmk_FetchAttribute(iKey, CKA_CLASS, pError);
    const NSSItem *keyType = nss_ckmk_FetchAttribute(iKey, CKA_KEY_TYPE, pError);
    ckmkInternalCryptoOperationRSAPriv *iOperation;
    SecKeyRef privateKey;
    OSStatus macErr;
    CSSM_RETURN cssmErr;
    const CSSM_KEY *cssmKey;
    CSSM_CSP_HANDLE cspHandle;
    const CSSM_ACCESS_CREDENTIALS *creds = NULL;
    CSSM_CC_HANDLE cssmContext;
    CSSM_ACL_AUTHORIZATION_TAG authType;

    /* make sure we have the right objects */
    if (((const NSSItem *)NULL == classItem) ||
        (sizeof(CK_OBJECT_CLASS) != classItem->size) ||
        (CKO_PRIVATE_KEY != *(CK_OBJECT_CLASS *)classItem->data) ||
        ((const NSSItem *)NULL == keyType) ||
        (sizeof(CK_KEY_TYPE) != keyType->size) ||
        (CKK_RSA != *(CK_KEY_TYPE *)keyType->data)) {
        *pError = CKR_KEY_TYPE_INCONSISTENT;
        return (NSSCKMDCryptoOperation *)NULL;
    }

    privateKey = (SecKeyRef)iKey->u.item.itemRef;
    macErr = SecKeyGetCSSMKey(privateKey, &cssmKey);
    if (noErr != macErr) {
        CKMK_MACERR("Getting CSSM Key", macErr);
        *pError = CKR_KEY_HANDLE_INVALID;
        return (NSSCKMDCryptoOperation *)NULL;
    }
    macErr = SecKeyGetCSPHandle(privateKey, &cspHandle);
    if (noErr != macErr) {
        CKMK_MACERR("Getting CSP for Key", macErr);
        *pError = CKR_KEY_HANDLE_INVALID;
        return (NSSCKMDCryptoOperation *)NULL;
    }

    /* pick the authorization we need credentials for */
    switch (type) {
        case CKMK_DECRYPT:
            authType = CSSM_ACL_AUTHORIZATION_DECRYPT;
            break;
        case CKMK_SIGN:
            authType = CSSM_ACL_AUTHORIZATION_SIGN;
            break;
        default:
            *pError = CKR_GENERAL_ERROR;
#ifdef DEBUG
            fprintf(stderr, "RSAPriv_Create: bad type = %d\n", type);
#endif
            return (NSSCKMDCryptoOperation *)NULL;
    }

    macErr = SecKeyGetCredentials(privateKey, authType, 0, &creds);
    if (noErr != macErr) {
        CKMK_MACERR("Getting Credentials for Key", macErr);
        *pError = CKR_KEY_HANDLE_INVALID;
        return (NSSCKMDCryptoOperation *)NULL;
    }

    /* create the CSSM context matching the requested operation */
    switch (type) {
        case CKMK_DECRYPT:
            cssmErr = CSSM_CSP_CreateAsymmetricContext(cspHandle, CSSM_ALGID_RSA,
                                                       creds, cssmKey, CSSM_PADDING_PKCS1, &cssmContext);
            break;
        case CKMK_SIGN:
            cssmErr = CSSM_CSP_CreateSignatureContext(cspHandle, CSSM_ALGID_RSA,
                                                      creds, cssmKey, &cssmContext);
            break;
        default:
            /* already rejected by the first switch; kept so that cssmErr is
             * never read uninitialized */
            *pError = CKR_GENERAL_ERROR;
#ifdef DEBUG
            fprintf(stderr, "RSAPriv_Create: bad type = %d\n", type);
#endif
            return (NSSCKMDCryptoOperation *)NULL;
    }
    if (noErr != cssmErr) {
        CKMK_MACERR("Getting Context for Key", cssmErr);
        *pError = CKR_GENERAL_ERROR;
        return (NSSCKMDCryptoOperation *)NULL;
    }

    iOperation = nss_ZNEW(NULL, ckmkInternalCryptoOperationRSAPriv);
    if ((ckmkInternalCryptoOperationRSAPriv *)NULL == iOperation) {
        /* nobody else knows about the context yet, so release it here
         * instead of leaking it */
        CSSM_DeleteContext(cssmContext);
        *pError = CKR_HOST_MEMORY;
        return (NSSCKMDCryptoOperation *)NULL;
    }
    iOperation->mdMechanism = mdMechanism;
    iOperation->iKey = iKey;
    iOperation->cssmContext = cssmContext;

    /* start from the caller's callback table, then hook up our state */
    nsslibc_memcpy(&iOperation->mdOperation,
                   proto, sizeof(NSSCKMDCryptoOperation));
    iOperation->mdOperation.etc = iOperation;

    return &iOperation->mdOperation;
}

/*
 * ckmk_mdCryptoOperationRSAPriv_Destroy
 *
 * Releases everything held by an RSA private-key operation: the cached
 * output buffer (if any), the CSSM context (if any) and finally the
 * operation structure itself. The framework/instance arguments are unused.
 */
static void
ckmk_mdCryptoOperationRSAPriv_Destroy(
    NSSCKMDCryptoOperation *mdOperation,
    NSSCKFWCryptoOperation *fwOperation,
    NSSCKMDInstance *mdInstance,
    NSSCKFWInstance *fwInstance)
{
    ckmkInternalCryptoOperationRSAPriv *op =
        (ckmkInternalCryptoOperationRSAPriv *)mdOperation->etc;

    if ((NSSItem *)NULL != op->buffer) {
        nssItem_Destroy(op->buffer);
    }

    if (0 != op->cssmContext) {
        CSSM_DeleteContext(op->cssmContext);
    }

    nss_ZFreeIf(op);
}

/*
 * ckmk_mdCryptoOperationRSA_GetFinalLength
 *
 * Reports the output length of the operation as the size of the key's
 * CKA_MODULUS attribute.
 *
 * Returns the modulus size, or 0 with *pError set to CKR_GENERAL_ERROR when
 * the modulus attribute cannot be fetched (previously this case
 * dereferenced a NULL pointer). Only mdOperation and pError are used.
 */
static CK_ULONG
ckmk_mdCryptoOperationRSA_GetFinalLength(
    NSSCKMDCryptoOperation *mdOperation,
    NSSCKFWCryptoOperation *fwOperation,
    NSSCKMDSession *mdSession,
    NSSCKFWSession *fwSession,
    NSSCKMDToken *mdToken,
    NSSCKFWToken *fwToken,
    NSSCKMDInstance *mdInstance,
    NSSCKFWInstance *fwInstance,
    CK_RV *pError)
{
    ckmkInternalCryptoOperationRSAPriv *iOperation =
        (ckmkInternalCryptoOperationRSAPriv *)mdOperation->etc;
    const NSSItem *modulus =
        nss_ckmk_FetchAttribute(iOperation->iKey, CKA_MODULUS, pError);

    if ((const NSSItem *)NULL == modulus) {
        /* The attribute lookup can come back empty (the same helper's result
         * is NULL-checked when the operation is created). Make sure the
         * caller sees a failure rather than a crash. */
        *pError = CKR_GENERAL_ERROR;
        return 0;
    }

    return modulus->size;
}

/*
 * ckmk_mdCryptoOperationRSADecrypt_GetOperationLength
 *
 * We won't know the length until we actually decrypt the input block.
 * Since we go to all the work to decrypt the block, we save the result in
 * iOperation->buffer for when the block is asked for. Calling this again
 * after a successful decrypt just reports the size of the saved result.
 *
 * Returns the decrypted length, or 0 with *pError set:
 *   CKR_DATA_INVALID - CSSM_DecryptData failed
 *   CKR_HOST_MEMORY  - a buffer could not be allocated
 *
 * Only mdOperation, input and pError are used.
 */
static CK_ULONG
ckmk_mdCryptoOperationRSADecrypt_GetOperationLength(
    NSSCKMDCryptoOperation *mdOperation,
    NSSCKFWCryptoOperation *fwOperation,
    NSSCKMDSession *mdSession,
    NSSCKFWSession *fwSession,
    NSSCKMDToken *mdToken,
    NSSCKFWToken *fwToken,
    NSSCKMDInstance *mdInstance,
    NSSCKFWInstance *fwInstance,
    const NSSItem *input,
    CK_RV *pError)
{
    ckmkInternalCryptoOperationRSAPriv *iOperation =
        (ckmkInternalCryptoOperationRSAPriv *)mdOperation->etc;
    CSSM_DATA cssmInput;
    CSSM_DATA cssmOutput = { 0, NULL };
    /* CSSM_DecryptData stores through a CSSM_SIZE *, which is wider than
     * PRUint32 on 64-bit builds; the variable must have that exact type */
    CSSM_SIZE bytesDecrypted = 0;
    CSSM_DATA remainder = { 0, NULL };
    CSSM_SIZE totalLength;
    NSSItem output;
    CSSM_RETURN cssmErr;

    /* already decrypted on an earlier call */
    if (iOperation->buffer) {
        return iOperation->buffer->size;
    }

    cssmInput.Data = input->data;
    cssmInput.Length = input->size;

    cssmErr = CSSM_DecryptData(iOperation->cssmContext,
                               &cssmInput, 1, &cssmOutput, 1,
                               &bytesDecrypted, &remainder);
    if (CSSM_OK != cssmErr) {
        CKMK_MACERR("Decrypt Failed", cssmErr);
        *pError = CKR_DATA_INVALID;
        return 0;
    }

    /* we didn't supply any buffers, so it should all be in remainder;
     * handle output data as well in case some was returned there */
    totalLength = bytesDecrypted + remainder.Length;
    output.data = nss_ZNEWARRAY(NULL, char, (PRUint32)totalLength);
    if (NULL == output.data) {
        free(cssmOutput.Data);
        free(remainder.Data);
        *pError = CKR_HOST_MEMORY;
        return 0;
    }
    output.size = (PRUint32)totalLength;

    if (0 != bytesDecrypted) {
        nsslibc_memcpy(output.data, cssmOutput.Data, bytesDecrypted);
    }
    if (0 != remainder.Length) {
        nsslibc_memcpy(((char *)output.data) + bytesDecrypted,
                       remainder.Data, remainder.Length);
    }
    /* Release the CSSM buffers regardless of the reported lengths so that
     * a zero-length but allocated buffer is not leaked; free(NULL) is a
     * no-op. */
    free(cssmOutput.Data);
    free(remainder.Data);

    iOperation->buffer = nssItem_Duplicate(&output, NULL, NULL);
    nss_ZFreeIf(output.data);
    if ((NSSItem *)NULL == iOperation->buffer) {
        *pError = CKR_HOST_MEMORY;
        return 0;
    }

    return iOperation->buffer->size;
}

/*
 * ckmk_mdCryptoOperationRSADecrypt_UpdateFinal
 *
 * Copies the plaintext saved by
 * ckmk_mdCryptoOperationRSADecrypt_GetOperationLength into 'output' and
 * sets output->size to the number of bytes copied.
 *
 * NOTE: ckmk_mdCryptoOperationRSADecrypt_GetOperationLength is presumed to
 * have been called previously; 'input' is not looked at here.
 *
 * Returns:
 *   CKR_OK               - data copied
 *   CKR_GENERAL_ERROR    - no saved plaintext (the length call never
 *                          succeeded)
 *   CKR_BUFFER_TOO_SMALL - output->size is smaller than the plaintext; this
 *                          matches what the sign UpdateFinal in this file
 *                          does instead of writing past the caller's buffer
 */
static CK_RV
ckmk_mdCryptoOperationRSADecrypt_UpdateFinal(
    NSSCKMDCryptoOperation *mdOperation,
    NSSCKFWCryptoOperation *fwOperation,
    NSSCKMDSession *mdSession,
    NSSCKFWSession *fwSession,
    NSSCKMDToken *mdToken,
    NSSCKFWToken *fwToken,
    NSSCKMDInstance *mdInstance,
    NSSCKFWInstance *fwInstance,
    const NSSItem *input,
    NSSItem *output)
{
    ckmkInternalCryptoOperationRSAPriv *iOperation =
        (ckmkInternalCryptoOperationRSAPriv *)mdOperation->etc;
    NSSItem *buffer = iOperation->buffer;

    if ((NSSItem *)NULL == buffer) {
        return CKR_GENERAL_ERROR;
    }
    if (buffer->size > output->size) {
        return CKR_BUFFER_TOO_SMALL;
    }
    nsslibc_memcpy(output->data, buffer->data, buffer->size);
    output->size = buffer->size;
    return CKR_OK;
}

/*
 * ckmk_mdCryptoOperationRSASign_UpdateFinal
 *
 * One-shot sign: hands 'input' to CSSM_SignData on the operation's context
 * and copies the resulting signature into 'output', updating output->size.
 *
 * Returns:
 *   CKR_OK               - signature copied into 'output'
 *   CKR_FUNCTION_FAILED  - CSSM_SignData reported an error
 *   CKR_BUFFER_TOO_SMALL - the signature does not fit in output->size
 */
static CK_RV
ckmk_mdCryptoOperationRSASign_UpdateFinal(
    NSSCKMDCryptoOperation *mdOperation,
    NSSCKFWCryptoOperation *fwOperation,
    NSSCKMDSession *mdSession,
    NSSCKFWSession *fwSession,
    NSSCKMDToken *mdToken,
    NSSCKFWToken *fwToken,
    NSSCKMDInstance *mdInstance,
    NSSCKFWInstance *fwInstance,
    const NSSItem *input,
    NSSItem *output)
{
    ckmkInternalCryptoOperationRSAPriv *op =
        (ckmkInternalCryptoOperationRSAPriv *)mdOperation->etc;
    CSSM_DATA toSign;
    CSSM_DATA signature = { 0, NULL };
    CSSM_RETURN status;
    CK_RV rv = CKR_OK;

    toSign.Length = input->size;
    toSign.Data = input->data;

    status = CSSM_SignData(op->cssmContext, &toSign, 1,
                           CSSM_ALGID_NONE, &signature);
    if (CSSM_OK != status) {
        CKMK_MACERR("Signed Failed", status);
        return CKR_FUNCTION_FAILED;
    }

    if (signature.Length <= output->size) {
        nsslibc_memcpy(output->data, signature.Data, signature.Length);
        output->size = signature.Length;
    } else {
        rv = CKR_BUFFER_TOO_SMALL;
    }

    /* the signature buffer is ours to release on both outcomes */
    free(signature.Data);
    return rv;
}

/*
 * Callback table copied into every RSA decrypt operation by
 * ckmk_mdCryptoOperationRSAPriv_Create (via ckmk_mdMechanismRSA_DecryptInit).
 * The decrypt itself happens in GetOperationLength; UpdateFinal only hands
 * back the saved result.
 */
NSS_IMPLEMENT_DATA const NSSCKMDCryptoOperation
    ckmk_mdCryptoOperationRSADecrypt_proto = {
        NULL, /* etc */
        ckmk_mdCryptoOperationRSAPriv_Destroy,
        NULL, /* GetFinalLength - not needed for one shot Decrypt/Encrypt */
        ckmk_mdCryptoOperationRSADecrypt_GetOperationLength,
        NULL, /* Final - not needed for one shot operation */
        NULL, /* Update - not needed for one shot operation */
        NULL, /* DigestUpdate - not needed for one shot operation */
        ckmk_mdCryptoOperationRSADecrypt_UpdateFinal,
        NULL,        /* UpdateCombo - not needed for one shot operation */
        NULL,        /* DigestKey - not needed for one shot operation */
        (void *)NULL /* null terminator */
    };

/*
 * Callback table copied into every RSA sign operation by
 * ckmk_mdCryptoOperationRSAPriv_Create (via ckmk_mdMechanismRSA_SignInit).
 * The output length comes from the key's modulus size (GetFinalLength) and
 * the signature is produced in UpdateFinal.
 */
NSS_IMPLEMENT_DATA const NSSCKMDCryptoOperation
    ckmk_mdCryptoOperationRSASign_proto = {
        NULL, /* etc */
        ckmk_mdCryptoOperationRSAPriv_Destroy,
        ckmk_mdCryptoOperationRSA_GetFinalLength,
        NULL, /* GetOperationLength - not needed for one shot Sign/Verify */
        NULL, /* Final - not needed for one shot operation */
        NULL, /* Update - not needed for one shot operation */
        NULL, /* DigestUpdate - not needed for one shot operation */
        ckmk_mdCryptoOperationRSASign_UpdateFinal,
        NULL,        /* UpdateCombo - not needed for one shot operation */
        NULL,        /* DigestKey - not needed for one shot operation */
        (void *)NULL /* null terminator */
    };

/********** NSSCKMDMechanism functions ***********************/
/*
 * ckmk_mdMechanismRSA_Destroy
 *
 * Frees the framework mechanism object handed in as 'fwMechanism' (when it
 * is non-NULL). The remaining arguments are not used.
 */
static void
ckmk_mdMechanismRSA_Destroy(
    NSSCKMDMechanism *mdMechanism,
    NSSCKFWMechanism *fwMechanism,
    NSSCKMDInstance *mdInstance,
    NSSCKFWInstance *fwInstance)
{
    (void)mdMechanism;
    (void)mdInstance;
    (void)fwInstance;

    nss_ZFreeIf(fwMechanism);
    return;
}

/*
 * ckmk_mdMechanismRSA_GetMinKeySize
 *
 * Reports the smallest RSA key size this mechanism advertises. The value
 * is a fixed constant; no argument is consulted and *pError is left alone.
 */
static CK_ULONG
ckmk_mdMechanismRSA_GetMinKeySize(
    NSSCKMDMechanism *mdMechanism,
    NSSCKFWMechanism *fwMechanism,
    NSSCKMDToken *mdToken,
    NSSCKFWToken *fwToken,
    NSSCKMDInstance *mdInstance,
    NSSCKFWInstance *fwInstance,
    CK_RV *pError)
{
    const CK_ULONG smallestKeySize = 384;

    return smallestKeySize;
}

/*
 * ckmk_mdMechanismRSA_GetMaxKeySize
 *
 * Reports the largest RSA key size this mechanism advertises. The value
 * is a fixed constant; no argument is consulted and *pError is left alone.
 */
static CK_ULONG
ckmk_mdMechanismRSA_GetMaxKeySize(
    NSSCKMDMechanism *mdMechanism,
    NSSCKFWMechanism *fwMechanism,
    NSSCKMDToken *mdToken,
    NSSCKFWToken *fwToken,
    NSSCKMDInstance *mdInstance,
    NSSCKFWInstance *fwInstance,
    CK_RV *pError)
{
    const CK_ULONG largestKeySize = 16384;

    return largestKeySize;
}

/*
 * ckmk_mdMechanismRSA_DecryptInit
 *
 * Starts an RSA decrypt: creates a private-key operation of type
 * CKMK_DECRYPT from the decrypt callback table. Only mdMechanism, mdKey
 * and pError are forwarded; the result (NULL on failure, with *pError set
 * by the create routine) is returned unchanged.
 */
static NSSCKMDCryptoOperation *
ckmk_mdMechanismRSA_DecryptInit(
    NSSCKMDMechanism *mdMechanism,
    NSSCKFWMechanism *fwMechanism,
    CK_MECHANISM *pMechanism,
    NSSCKMDSession *mdSession,
    NSSCKFWSession *fwSession,
    NSSCKMDToken *mdToken,
    NSSCKFWToken *fwToken,
    NSSCKMDInstance *mdInstance,
    NSSCKFWInstance *fwInstance,
    NSSCKMDObject *mdKey,
    NSSCKFWObject *fwKey,
    CK_RV *pError)
{
    NSSCKMDCryptoOperation *decryptOperation;

    decryptOperation = ckmk_mdCryptoOperationRSAPriv_Create(
        &ckmk_mdCryptoOperationRSADecrypt_proto,
        mdMechanism,
        mdKey,
        CKMK_DECRYPT,
        pError);

    return decryptOperation;
}

/*
 * ckmk_mdMechanismRSA_SignInit
 *
 * Starts an RSA sign: creates a private-key operation of type CKMK_SIGN
 * from the sign callback table. Only mdMechanism, mdKey and pError are
 * forwarded; the result (NULL on failure, with *pError set by the create
 * routine) is returned unchanged.
 */
static NSSCKMDCryptoOperation *
ckmk_mdMechanismRSA_SignInit(
    NSSCKMDMechanism *mdMechanism,
    NSSCKFWMechanism *fwMechanism,
    CK_MECHANISM *pMechanism,
    NSSCKMDSession *mdSession,
    NSSCKFWSession *fwSession,
    NSSCKMDToken *mdToken,
    NSSCKFWToken *fwToken,
    NSSCKMDInstance *mdInstance,
    NSSCKFWInstance *fwInstance,
    NSSCKMDObject *mdKey,
    NSSCKFWObject *fwKey,
    CK_RV *pError)
{
    NSSCKMDCryptoOperation *signOperation;

    signOperation = ckmk_mdCryptoOperationRSAPriv_Create(
        &ckmk_mdCryptoOperationRSASign_proto,
        mdMechanism,
        mdKey,
        CKMK_SIGN,
        pError);

    return signOperation;
}

/*
 * Mechanism table for RSA as implemented in this file: only decrypt and
 * sign are provided. Note that the SignRecoverInit slot reuses
 * ckmk_mdMechanismRSA_SignInit. Slots left NULL fall back to the
 * defaults named in the per-entry comments.
 */
NSS_IMPLEMENT_DATA const NSSCKMDMechanism
    nss_ckmk_mdMechanismRSA = {
        (void *)NULL, /* etc */
        ckmk_mdMechanismRSA_Destroy,
        ckmk_mdMechanismRSA_GetMinKeySize,
        ckmk_mdMechanismRSA_GetMaxKeySize,
        NULL, /* GetInHardware - default false */
        NULL, /* EncryptInit - default errs */
        ckmk_mdMechanismRSA_DecryptInit,
        NULL, /* DigestInit - default errs*/
        ckmk_mdMechanismRSA_SignInit,
        NULL,                         /* VerifyInit - default errs */
        ckmk_mdMechanismRSA_SignInit, /* SignRecoverInit */
        NULL,                         /* VerifyRecoverInit - default errs */
        NULL,                         /* GenerateKey - default errs */
        NULL,                         /* GenerateKeyPair - default errs */
        NULL,                         /* GetWrapKeyLength - default errs */
        NULL,                         /* WrapKey - default errs */
        NULL,                         /* UnwrapKey - default errs */
        NULL,                         /* DeriveKey - default errs */
        (void *)NULL                  /* null terminator */
    };