//
// Symmetric encryption helpers (3DES / AES) built on the OpenSSL EVP API.
// Created by xajhu on 2021/7/13.
//
#include <limits.h>
#include <stdlib.h>
#include <string.h>
#include <openssl/aes.h>
#include <openssl/crypto.h>
#include <openssl/evp.h>
#include <openssl/sha.h>

#include "crypto.h"
#include "user_errno.h"

/*
 * Derive a 16-byte AES-128 key from a passphrase as the first 16 bytes of
 * SHA1(SHA1(pKey)).
 *
 * NOTE(review): this appears intended to reproduce the key a Java
 * KeyGenerator("AES") seeded via SecureRandom("SHA1PRNG") would produce;
 * confirm against the peer implementation before relying on that.
 *
 * pKey        NUL-terminated passphrase (must not be NULL).
 * pShaPrng16  output buffer of at least 16 bytes.
 *
 * Returns ERR_SUCCESS, or the error code reported by hash_digest_mem.
 * Intermediate digests are wiped before their memory is released.
 */
static int sha1prng_for_aes_key(const char *pKey, unsigned char *pShaPrng16) {
    enum { AES128_KEY_BYTES = 16 };

    unsigned int   outSize                 = EVP_MAX_MD_SIZE;
    unsigned char  data[SHA_DIGEST_LENGTH] = {0};
    unsigned char *pHashValue              = NULL;
    int            ret;

    /* Round 1: data = SHA1(passphrase). */
    ret = hash_digest_mem(HASH_TYPE_SHA1, (const unsigned char *)pKey, strlen(pKey), &pHashValue, &outSize);
    if (ret != ERR_SUCCESS) {
        free(pHashValue); /* free(NULL) is a no-op */
        return ret;
    }

    memcpy(data, pHashValue, SHA_DIGEST_LENGTH);
    OPENSSL_cleanse(pHashValue, SHA_DIGEST_LENGTH);
    free(pHashValue);
    /* Must not keep the dangling pointer: if the next call fails without
     * writing it, the error path below would otherwise free it twice. */
    pHashValue = NULL;

    /* Round 2: key = first 16 bytes of SHA1(data). */
    outSize = EVP_MAX_MD_SIZE;
    ret     = hash_digest_mem(HASH_TYPE_SHA1, data, SHA_DIGEST_LENGTH, &pHashValue, &outSize);
    OPENSSL_cleanse(data, sizeof(data));
    if (ret != ERR_SUCCESS) {
        free(pHashValue);
        return ret;
    }

    memcpy(pShaPrng16, pHashValue, AES128_KEY_BYTES);
    OPENSSL_cleanse(pHashValue, SHA_DIGEST_LENGTH);
    free(pHashValue);

    return ERR_SUCCESS;
}

#define EVP_ENV_INIT(algorithmType, pCipher, keyBuf, pKey)                     \
    do {                                                                       \
        switch (algorithmType) {                                               \
            case DES3_ECB_PKCS7PADDING:                                        \
                pCipher = EVP_des_ede3();                                      \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case DES3_CBC_PKCS7PADDING:                                        \
                pCipher = EVP_des_ede3_cbc();                                  \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case DES3_OFB_PKCS7PADDING:                                        \
                pCipher = EVP_des_ede3_ofb();                                  \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case AES128_ECB_PKCS7PADDING:                                      \
                pCipher = EVP_aes_128_ecb();                                   \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case AES128_ECB_PKCS7PADDING_SHA1PRNG:                             \
                if (sha1prng_for_aes_key(pKey, keyBuf) != ERR_SUCCESS) {       \
                    EVP_CIPHER_CTX_cleanup(pCtx);                              \
                    return -ERR_AES128_KEYGEN;                                 \
                }                                                              \
                pCipher = EVP_aes_128_ecb();                                   \
                break;                                                         \
            case AES128_CBC_PKCS7PADDING:                                      \
                pCipher = EVP_aes_128_cbc();                                   \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case AES128_OFB_PKCS7PADDING:                                      \
                pCipher = EVP_aes_128_ofb();                                   \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case AES192_ECB_PKCS7PADDING:                                      \
                pCipher = EVP_aes_192_ecb();                                   \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case AES192_CBC_PKCS7PADDING:                                      \
                pCipher = EVP_aes_192_cbc();                                   \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case AES192_OFB_PKCS7PADDING:                                      \
                pCipher = EVP_aes_192_ofb();                                   \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case AES256_ECB_PKCS7PADDING:                                      \
                pCipher = EVP_aes_256_ecb();                                   \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case AES256_CBC_PKCS7PADDING:                                      \
                pCipher = EVP_aes_256_cbc();                                   \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            case AES256_OFB_PKCS7PADDING:                                      \
                pCipher = EVP_aes_256_ofb();                                   \
                strncpy((char *)keyBuf, pKey, EVP_CIPHER_key_length(pCipher)); \
                break;                                                         \
            default:                                                           \
                EVP_CIPHER_CTX_cleanup(pCtx);                                  \
                return -ERR_UNSUP_EVP_TYPE;                                    \
        }                                                                      \
    } while (0)

/*
 * Decrypt pInBuf with the cipher/mode selected by algorithmType.
 *
 * pInBuf, inSize  ciphertext; must be non-NULL and non-empty, and small enough
 *                 that inSize + EVP_MAX_BLOCK_LENGTH fits in an int (the EVP
 *                 API takes int lengths).
 * pOutBuf         on success receives a heap buffer holding the plaintext,
 *                 which the caller must free(). Set to NULL on every failure
 *                 after argument validation.
 * pOutSize        on success receives the plaintext length; 0 otherwise.
 * pKey            NUL-terminated key string, turned into raw key bytes by
 *                 EVP_ENV_INIT.
 *
 * Returns ERR_SUCCESS or a negative ERR_* code. Key material and any partial
 * plaintext are wiped before their memory is released.
 */
int symmetric_decrypto(AES_TYPE        algorithmType,
                       unsigned char  *pInBuf,
                       unsigned int    inSize,
                       unsigned char **pOutBuf,
                       int            *pOutSize,
                       const char     *pKey) {
    int               ret      = ERR_SUCCESS;
    int               outLen   = 0;
    int               finalLen = 0;
    size_t            bufSize  = (size_t)inSize + EVP_MAX_BLOCK_LENGTH;
    EVP_CIPHER_CTX   *pCtx;
    unsigned char    *pAesBuf  = NULL;
    unsigned char     keyBuf[EVP_MAX_KEY_LENGTH] = {0};
    unsigned char     iv[EVP_MAX_IV_LENGTH];
    const EVP_CIPHER *pCipher  = NULL;

    if (!pInBuf || !pOutBuf || !pOutSize || !pKey || inSize == 0) {
        return -ERR_INPUT_PARAMS;
    }

    /* EVP lengths are int; reject inputs whose size (plus padding slack)
     * would not fit rather than passing a negative length. */
    if (inSize > (unsigned int)(INT_MAX - EVP_MAX_BLOCK_LENGTH)) {
        return -ERR_INPUT_PARAMS;
    }

    *pOutBuf  = NULL;
    *pOutSize = 0;

    pCtx = EVP_CIPHER_CTX_new(); /* already initialised; no _init needed */
    if (!pCtx) {
        return -ERR_EVP_CREATE_CTX;
    }

    /* The IV is ASCII '0' (0x30) bytes, not zero bytes. Existing ciphertexts
     * and peers depend on this exact value, so it must not be "fixed". */
    memset(iv, '0', sizeof(iv));

    /* Sets pCipher and fills keyBuf; on failure it releases pCtx and returns. */
    EVP_ENV_INIT(algorithmType, pCipher, keyBuf, pKey);

    if (EVP_DecryptInit_ex(pCtx, pCipher, NULL, keyBuf, iv) == 0) {
        ret = -ERR_EVP_INIT_KEY;
        goto cleanup;
    }

    pAesBuf = calloc(1, bufSize);
    if (pAesBuf == NULL) {
        ret = -ERR_MALLOC_MEMORY;
        goto cleanup;
    }

    if (EVP_DecryptUpdate(pCtx, pAesBuf, &outLen, pInBuf, (int)inSize) == 0) {
        ret = -ERR_EVP_UPDATE;
        goto cleanup;
    }

    /* Final fails on bad padding or a wrong key; the buffer is then discarded. */
    if (EVP_DecryptFinal_ex(pCtx, pAesBuf + outLen, &finalLen) == 0) {
        ret = -ERR_EVP_FINALE;
        goto cleanup;
    }

    /* Success: hand the buffer to the caller. */
    *pOutBuf  = pAesBuf;
    *pOutSize = outLen + finalLen;
    pAesBuf   = NULL;

cleanup:
    if (pAesBuf != NULL) {
        OPENSSL_cleanse(pAesBuf, bufSize);
        free(pAesBuf);
    }
    OPENSSL_cleanse(keyBuf, sizeof(keyBuf));
    EVP_CIPHER_CTX_free(pCtx);
    return ret;
}

/*
 * Encrypt pInBuf with the cipher/mode selected by algorithmType.
 *
 * pInBuf, inSize  plaintext; must be non-NULL and non-empty, and small enough
 *                 that inSize + EVP_MAX_BLOCK_LENGTH fits in an int (the EVP
 *                 API takes int lengths).
 * pOutBuf         on success receives a heap buffer holding the ciphertext,
 *                 which the caller must free(). Set to NULL on every failure
 *                 after argument validation.
 * pOutSize        on success receives the ciphertext length; 0 otherwise.
 * pKey            NUL-terminated key string, turned into raw key bytes by
 *                 EVP_ENV_INIT.
 *
 * Returns ERR_SUCCESS or a negative ERR_* code. The derived key bytes are
 * wiped before returning.
 */
int symmetric_encrypto(AES_TYPE        algorithmType,
                       unsigned char  *pInBuf,
                       unsigned int    inSize,
                       unsigned char **pOutBuf,
                       int            *pOutSize,
                       const char     *pKey) {
    int               ret      = ERR_SUCCESS;
    int               outLen   = 0;
    int               finalLen = 0;
    size_t            bufSize  = (size_t)inSize + EVP_MAX_BLOCK_LENGTH;
    EVP_CIPHER_CTX   *pCtx;
    unsigned char    *pAesBuf  = NULL;
    unsigned char     keyBuf[EVP_MAX_KEY_LENGTH] = {0};
    unsigned char     iv[EVP_MAX_IV_LENGTH];
    const EVP_CIPHER *pCipher  = NULL;

    if (!pInBuf || !pOutBuf || !pOutSize || !pKey || inSize == 0) {
        return -ERR_INPUT_PARAMS;
    }

    /* EVP lengths are int; reject inputs whose size (plus padding slack)
     * would not fit rather than passing a negative length. */
    if (inSize > (unsigned int)(INT_MAX - EVP_MAX_BLOCK_LENGTH)) {
        return -ERR_INPUT_PARAMS;
    }

    *pOutBuf  = NULL;
    *pOutSize = 0;

    pCtx = EVP_CIPHER_CTX_new(); /* already initialised; no _init needed */
    if (!pCtx) {
        return -ERR_EVP_CREATE_CTX;
    }

    /* The IV is ASCII '0' (0x30) bytes, not zero bytes. Existing ciphertexts
     * and peers depend on this exact value, so it must not be "fixed". */
    memset(iv, '0', sizeof(iv));

    /* Sets pCipher and fills keyBuf; on failure it releases pCtx and returns. */
    EVP_ENV_INIT(algorithmType, pCipher, keyBuf, pKey);

    if (EVP_EncryptInit_ex(pCtx, pCipher, NULL, keyBuf, iv) == 0) {
        ret = -ERR_EVP_INIT_KEY;
        goto cleanup;
    }

    /* Update + Final together emit at most inSize + one block of output. */
    pAesBuf = calloc(1, bufSize);
    if (pAesBuf == NULL) {
        ret = -ERR_MALLOC_MEMORY;
        goto cleanup;
    }

    if (EVP_EncryptUpdate(pCtx, pAesBuf, &outLen, pInBuf, (int)inSize) == 0) {
        ret = -ERR_EVP_UPDATE;
        goto cleanup;
    }

    if (EVP_EncryptFinal_ex(pCtx, pAesBuf + outLen, &finalLen) == 0) {
        ret = -ERR_EVP_FINALE;
        goto cleanup;
    }

    /* Success: hand the buffer to the caller. */
    *pOutBuf  = pAesBuf;
    *pOutSize = outLen + finalLen;
    pAesBuf   = NULL;

cleanup:
    free(pAesBuf);
    OPENSSL_cleanse(keyBuf, sizeof(keyBuf));
    EVP_CIPHER_CTX_free(pCtx);
    return ret;
}