#include "pch.h"
#include "tunnel.h"
#include "usrerr.h"
#include "misc.h"

#include <bcrypt.h>
#include <shlwapi.h>
#include <spdlog/spdlog.h>

#pragma comment(lib, "Bcrypt.lib")

// NTSTATUS helper: NT_FAILED(s) is true when `s`, cast to NTSTATUS, is negative.
// It is the logical inverse of the (commented-out) NT_SUCCESS test below.
//#define NT_SUCCESS(Status)  (((NTSTATUS)(Status)) >= 0)
#define NT_FAILED(s) (((NTSTATUS)(s)) < 0)
//#define STATUS_UNSUCCESSFUL ((NTSTATUS)0xC0000001L)

// CNG algorithm identifiers, looked up by CalcFileHash() as g_BcryptHash[type]
// where `type` is a HASH_TYPE value. The position of each entry is therefore
// significant.
// NOTE(review): HASH_TYPE is declared outside this file; its enumerators are
// presumably numbered 0..6 in this same order - confirm before reordering or
// inserting entries.
static const LPCWSTR g_BcryptHash[] = {
    BCRYPT_MD2_ALGORITHM,    // index 0
    BCRYPT_MD4_ALGORITHM,    // index 1
    BCRYPT_MD5_ALGORITHM,    // index 2
    BCRYPT_SHA1_ALGORITHM,   // index 3
    BCRYPT_SHA256_ALGORITHM, // index 4
    BCRYPT_SHA384_ALGORITHM, // index 5
    BCRYPT_SHA512_ALGORITHM, // index 6
};

int CalcFileHash(const HASH_TYPE type, const TCHAR *pPath, TCHAR outHash[]) {
    HANDLE             hFile;
    BYTE               rgbFile[1024];
    DWORD              cbRead = 0;
    BCRYPT_ALG_HANDLE  hAlg   = nullptr;
    BCRYPT_HASH_HANDLE hHash  = nullptr;
    NTSTATUS           status;
    DWORD              cbData = 0, cbHash = 0, cbHashObject = 0;
    PBYTE              pbHashObject;
    PBYTE              pbHash;

    if (pPath == nullptr) {
        SPDLOG_ERROR(TEXT("Input pPath params error: {0}"), pPath);
        return -ERR_INPUT_PARAMS;
    }

    if (!PathFileExists(pPath)) {
        SPDLOG_ERROR(TEXT("File \'{0}\' not found."), pPath);
        return -ERR_ITEM_UNEXISTS;
    }

    hFile = CreateFile(pPath,
                       GENERIC_READ,
                       FILE_SHARE_READ,
                       nullptr,
                       OPEN_EXISTING,
                       FILE_FLAG_SEQUENTIAL_SCAN,
                       nullptr);

    if (INVALID_HANDLE_VALUE == hFile) {
        SPDLOG_ERROR(TEXT("Error opening file %s\nError: {0}"), pPath, GetLastError());
        return -ERR_OPEN_FILE;
    }

    //open an algorithm handle
    if (NT_FAILED(status = BCryptOpenAlgorithmProvider(&hAlg, g_BcryptHash[type], nullptr, 0))) {
        SPDLOG_ERROR(TEXT("Error {0} returned by BCryptOpenAlgorithmProvider"), status);
        CloseHandle(hFile);
        return -ERR_BCRYPT_OPEN;
    }

    //calculate the size of the buffer to hold the hash object
    if (NT_FAILED(status = BCryptGetProperty(hAlg,
                                             BCRYPT_OBJECT_LENGTH,
                                             reinterpret_cast<PBYTE>(&cbHashObject),
                                             sizeof(DWORD),
                                             &cbData,
                                             0))) {
        SPDLOG_ERROR(TEXT("Error {0} returned by BCryptGetProperty"), status);
        CloseHandle(hFile);
        BCryptCloseAlgorithmProvider(hAlg, 0);
        return -ERR_BCRYPT_GETPROPERTY;
    }

    //allocate the hash object on the heap
    pbHashObject = static_cast<PBYTE>(HeapAlloc(GetProcessHeap(), 0, cbHashObject));
    if (nullptr == pbHashObject) {
        SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), cbHashObject);
        CloseHandle(hFile);
        BCryptCloseAlgorithmProvider(hAlg, 0);
        return -ERR_MALLOC_MEMORY;
    }

    //calculate the length of the hash
    if (NT_FAILED(status = BCryptGetProperty(hAlg,
                                             BCRYPT_HASH_LENGTH,
                                             reinterpret_cast<PBYTE>(&cbHash),
                                             sizeof(DWORD),
                                             &cbData,
                                             0))) {
        SPDLOG_ERROR(TEXT("Error {0} returned by BCryptGetProperty"), status);
        CloseHandle(hFile);
        BCryptCloseAlgorithmProvider(hAlg, 0);
        HeapFree(GetProcessHeap(), 0, pbHashObject);
        return -ERR_BCRYPT_GETPROPERTY;
    }

    //allocate the hash buffer on the heap
    pbHash = static_cast<PBYTE>(HeapAlloc(GetProcessHeap(), 0, cbHash));
    if (nullptr == pbHash) {
        SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), cbHash);
        CloseHandle(hFile);
        BCryptCloseAlgorithmProvider(hAlg, 0);
        HeapFree(GetProcessHeap(), 0, pbHashObject);
        return -ERR_MALLOC_MEMORY;
    }

    //create a hash
    if (NT_FAILED(status = BCryptCreateHash(hAlg, &hHash, pbHashObject, cbHashObject, nullptr, 0, 0))) {
        SPDLOG_ERROR(TEXT("Error {0} returned by BCryptCreateHash"), status);
        CloseHandle(hFile);
        BCryptCloseAlgorithmProvider(hAlg, 0);
        HeapFree(GetProcessHeap(), 0, pbHashObject);
        HeapFree(GetProcessHeap(), 0, pbHash);
        return -ERR_BCRYPT_CREATEHASH;
    }

    while (ReadFile(hFile, rgbFile, 1024, &cbRead, nullptr)) {
        if (0 == cbRead) {
            break;
        }

        if (NT_FAILED(status = BCryptHashData(hHash, rgbFile, cbRead, 0))) {
            SPDLOG_ERROR(TEXT("Error {0} returned by BCryptHashData"), status);
            CloseHandle(hFile);
            BCryptCloseAlgorithmProvider(hAlg, 0);
            BCryptDestroyHash(hHash);
            HeapFree(GetProcessHeap(), 0, pbHashObject);
            HeapFree(GetProcessHeap(), 0, pbHash);
            return -ERR_BCRYPT_HASHDATA;
        }
    }

    //close the hash
    if (NT_FAILED(status = BCryptFinishHash(hHash, pbHash, cbHash, 0))) {
        SPDLOG_ERROR(TEXT("Error {0} returned by BCryptFinishHash"), status);
        CloseHandle(hFile);
        BCryptCloseAlgorithmProvider(hAlg, 0);
        BCryptDestroyHash(hHash);
        HeapFree(GetProcessHeap(), 0, pbHashObject);
        HeapFree(GetProcessHeap(), 0, pbHash);
        return -ERR_BCRYPT_FINISHHASH;
    }

    binToHexString(outHash, pbHash, cbHash);

    BCryptCloseAlgorithmProvider(hAlg, 0);
    BCryptDestroyHash(hHash);
    HeapFree(GetProcessHeap(), 0, pbHashObject);
    HeapFree(GetProcessHeap(), 0, pbHash);
    CloseHandle(hFile);

    return ERR_SUCCESS;
}