#include "pch.h"
#include "tunnel.h"
#include "usrerr.h"
#include <strsafe.h>
#include <tchar.h>
#include <shlwapi.h>

#include "globalcfg.h"
#include "misc.h"

#pragma comment(lib, "Shlwapi.lib")

// Initial size of the buffer used to read the PATH environment variable;
// the buffer is reallocated when GetEnvironmentVariable reports that PATH
// needs more room.
constexpr auto WINENVBUF_SIZE = (4096);

// Section and key names used in the SDK ini file (GetGlobalCfgInfo()->cfgPath).
#define CFG_WIREGUARD_SECTION TEXT("WireGuard")     // [WireGuard] section
#define CFG_WIREGUARD_PATH    TEXT("WireGuardExe")  // full path of wireguard.exe
#define CFG_WGCFG_PATH        TEXT("WgCfgPath")     // path of the generated tunnel .conf file
#define CFG_WG_PATH           TEXT("WgExe")         // full path of wg.exe (command-line tools)

int WireGuardInstallServerService(bool bInstall) {
    TCHAR cfgVal[MAX_PATH];
    TCHAR cmdBuf[MAX_PATH];

    GetPrivateProfileString(CFG_WIREGUARD_SECTION,
                            CFG_WGCFG_PATH,
                            TEXT(""),
                            cfgVal,
                            MAX_PATH,
                            GetGlobalCfgInfo()->cfgPath);

    if (lstrlen(cfgVal) > 0) {
        WIN32_FIND_DATA FindFileData;
        const HANDLE    hFind = FindFirstFile(cfgVal, &FindFileData);

        if (hFind != INVALID_HANDLE_VALUE) {
            int ret;

            if (bInstall) {
                // 安装服务
                StringCbPrintf(cmdBuf,
                               MAX_PATH,
                               TEXT("\"%s\" /installtunnelservice \"%s\""),
                               GetGlobalCfgInfo()->wireguardCfg.wireguardPath,
                               cfgVal);
            } else {
                // 卸载服务
                TCHAR svrName[MAX_PATH];

                StringCbCopy(svrName, MAX_PATH, cfgVal);
                PathStripPath(svrName);
                PathRemoveExtension(svrName);

                StringCbPrintf(cmdBuf,
                               MAX_PATH,
                               TEXT("\"%s\" /uninstalltunnelservice %s"),
                               GetGlobalCfgInfo()->wireguardCfg.wireguardPath,
                               svrName);
            }

            if ((ret = RunCommand(cmdBuf, nullptr, 0)) != ERR_SUCCESS) {
                SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret);
                return -ERR_CALL_SHELL;
            }

            SPDLOG_DEBUG("Run command [{0}]", cmdBuf);

            return ERR_SUCCESS;
        } else {
            SPDLOG_ERROR("WireGuard configure file [{0}] not found", cfgVal);
            return -ERR_FILE_NOT_EXISTS;
        }
    } else {
        SPDLOG_ERROR("Configure [{0}] not found", CFG_WGCFG_PATH);
        return -ERR_ITEM_UNEXISTS;
    }
}

int WireGuardCreateClientConfig(const PWGCLIENT_CONFIG pWgConfig) {
    const size_t bufSize     = 4096 * sizeof(TCHAR);
    const TCHAR  cfgFormat[] = TEXT(
        "[Interface]\nPrivateKey = %s\nAddress = %s\n\n[Peer]\nPublicKey = %s\nAllowedIPs = %s\nEndpoint = "
         "%s\nPersistentKeepalive = 30\n");
    TCHAR  cfgPath[MAX_PATH];
    size_t length;
    HANDLE hFile;
    TCHAR *pBuf;

#pragma region
    if (pWgConfig == nullptr) {
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->Name, 64, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Name error: {0}", pWgConfig->Name);
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->Address, 32, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Address error: {0}", pWgConfig->Address);
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->PrivateKey, 64, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Private key error: {0}", pWgConfig->PrivateKey);
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->SvrPubKey, 64, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Server Public key error: {0}", pWgConfig->SvrPubKey);
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->AllowNet, 256, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Allow Client Network error: {0}", pWgConfig->AllowNet);
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->ServerURL, 256, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Server Network error: {0}", pWgConfig->ServerURL);
        return -ERR_INPUT_PARAMS;
    }
#pragma endregion 参数检查

    pBuf = static_cast<TCHAR *>(malloc(bufSize));

    if (pBuf == nullptr) {
        SPDLOG_ERROR("Malloc {1} bytes memory error: {0}", GetLastError(), bufSize);
        return -ERR_MALLOC_MEMORY;
    }

    memset(pBuf, 0, bufSize);

    StringCbPrintf(cfgPath, MAX_PATH, "%s\\%s.conf", GetGlobalCfgInfo()->workDirectory, pWgConfig->Name);

    hFile = CreateFile(cfgPath,                         // name of the write
                       GENERIC_WRITE | GENERIC_READ,    // open for writing
                       FILE_SHARE_READ,                 // do not share
                       nullptr,                         // default security
                       CREATE_ALWAYS,                   // create new file only
                       FILE_ATTRIBUTE_NORMAL,           // normal file
                       nullptr);                        // no attr. template

    if (hFile == INVALID_HANDLE_VALUE) {
        SPDLOG_ERROR("CreatFile [{0}] error: {1}", cfgPath, GetLastError());
        free(pBuf);
        return -ERR_OPEN_FILE;
    }

    // 保存到配置文件中
    WritePrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WGCFG_PATH, cfgPath, GetGlobalCfgInfo()->cfgPath);

    if (FAILED(StringCbPrintf(pBuf,
                              bufSize,
                              cfgFormat,
                              pWgConfig->PrivateKey,
                              pWgConfig->Address,
                              pWgConfig->SvrPubKey,
                              pWgConfig->AllowNet,
                              pWgConfig->ServerURL))) {
        SPDLOG_ERROR("Format string error: {0}", GetLastError());
        free(pBuf);
        ::CloseHandle(hFile);
        return -ERR_MEMORY_STR;
    }

    if (FAILED(StringCbLength(pBuf, bufSize, &length))) {
        SPDLOG_ERROR("Get string \'{0}\' length error: {1}", pBuf, GetLastError());
        free(pBuf);
        ::CloseHandle(hFile);
        return -ERR_MEMORY_STR;
    }

    SPDLOG_DEBUG("WG Client Configure:\n{0}", pBuf);

    if (!WriteFile(hFile,                         // open file handle
                   pBuf,                          // start of data to write
                   static_cast<DWORD>(length),    // number of bytes to write
                   nullptr,                       // number of bytes that were written
                   nullptr)) {
        SPDLOG_ERROR("WriteFile [{0}] error: {1}", cfgPath, GetLastError());
        free(pBuf);
        ::CloseHandle(hFile);
        return -ERR_OPEN_FILE;
    }

    ::CloseHandle(hFile);
    return ERR_SUCCESS;
}

int WireGuardCreateServerConfig(const PWGSERVER_CONFIG pWgConfig) {
    const size_t bufSize     = 4096 * sizeof(TCHAR);
    const TCHAR  cfgFormat[] = TEXT(
        "[Interface]\nAddress = %s\nListenPort = %d\nPrivateKey = %s\n\n[Peer]\nPublicKey = %s\nAllowedIPs = %s\n");
    TCHAR  cfgPath[MAX_PATH];
    size_t length;
    HANDLE hFile;
    TCHAR *pBuf;

#pragma region
    if (pWgConfig == nullptr) {
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->Name, 64, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Name error: {0}", pWgConfig->Name);
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->Address, 32, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Address error: {0}", pWgConfig->Address);
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->PrivateKey, 64, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Private key error: {0}", pWgConfig->PrivateKey);
        return -ERR_INPUT_PARAMS;
    }

    if (pWgConfig->ListenPort <= 1024 || pWgConfig->ListenPort >= 65535) {
        SPDLOG_ERROR("WireGuard Listen port error: {0}, should be in arrange (1024, 65535)", pWgConfig->ListenPort);
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->CliPubKey, 64, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Client Public key error: {0}", pWgConfig->CliPubKey);
        return -ERR_INPUT_PARAMS;
    }

    if (FAILED(StringCbLength(pWgConfig->AllowNet, 256, &length)) || 0 == length) {
        SPDLOG_ERROR("WireGuard Allow Client Network error: {0}", pWgConfig->AllowNet);
        return -ERR_INPUT_PARAMS;
    }
#pragma endregion 参数检查

    pBuf = static_cast<TCHAR *>(malloc(bufSize));

    if (pBuf == nullptr) {
        SPDLOG_ERROR("Malloc {1} bytes memory error: {0}", GetLastError(), bufSize);
        return -ERR_MALLOC_MEMORY;
    }

    memset(pBuf, 0, bufSize);

    StringCbPrintf(cfgPath, MAX_PATH, "%s\\%s.conf", GetGlobalCfgInfo()->workDirectory, pWgConfig->Name);

    hFile = CreateFile(cfgPath,                         // name of the write
                       GENERIC_WRITE | GENERIC_READ,    // open for writing
                       FILE_SHARE_READ,                 // do not share
                       nullptr,                         // default security
                       CREATE_ALWAYS,                   // create new file only
                       FILE_ATTRIBUTE_NORMAL,           // normal file
                       nullptr);                        // no attr. template

    if (hFile == INVALID_HANDLE_VALUE) {
        SPDLOG_ERROR("CreatFile [{0}] error: {1}", cfgPath, GetLastError());
        free(pBuf);
        return -ERR_OPEN_FILE;
    }

    WritePrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WGCFG_PATH, cfgPath, GetGlobalCfgInfo()->cfgPath);

    if (FAILED(StringCbPrintf(pBuf,
                              bufSize,
                              cfgFormat,
                              pWgConfig->Address,
                              pWgConfig->ListenPort,
                              pWgConfig->PrivateKey,
                              pWgConfig->CliPubKey,
                              pWgConfig->AllowNet))) {
        SPDLOG_ERROR("Format string error: {0}", GetLastError());
        free(pBuf);
        ::CloseHandle(hFile);
        return -ERR_MEMORY_STR;
    }

    if (FAILED(StringCbLength(pBuf, bufSize, &length))) {
        SPDLOG_ERROR("Get string \'{0}\' length error: {1}", pBuf, GetLastError());
        free(pBuf);
        ::CloseHandle(hFile);
        return -ERR_MEMORY_STR;
    }

    SPDLOG_DEBUG("WG Server Configure:\n{0}", pBuf);

    if (FALSE ==
        WriteFile(hFile,                         // open file handle
                  pBuf,                          // start of data to write
                  static_cast<DWORD>(length),    // number of bytes to write
                  nullptr,                       // number of bytes that were written
                  nullptr))                      // no overlapped structure)
    {
        SPDLOG_ERROR("WriteFile [{0}] error: {1}", cfgPath, GetLastError());
        free(pBuf);
        ::CloseHandle(hFile);
        return -ERR_OPEN_FILE;
    }

    ::CloseHandle(hFile);

    StringCbCopy(GetGlobalCfgInfo()->wgServerCfg.wgName, 260, pWgConfig->Name);
    StringCbCopy(GetGlobalCfgInfo()->wgServerCfg.wgIpaddr, MAX_IP_LEN, pWgConfig->Address);
    return ERR_SUCCESS;
}

/**
 * @brief Generate a WireGuard key pair with the `wg` command-line tool.
 *
 * Runs `wg genkey` to obtain a private key, then `echo <key> | wg pubkey` to
 * derive the matching public key. Trailing spaces, tabs and CR/LF are stripped
 * from both outputs before they are copied to the caller.
 *
 * @param pPubKey     Output buffer for the public key.
 * @param pubkeySize  Size of pPubKey in bytes (StringCbCopy semantics).
 * @param pPrivKey    Output buffer for the private key.
 * @param privKeySize Size of pPrivKey in bytes.
 * @return ERR_SUCCESS;
 *         -ERR_INPUT_PARAMS  for null buffers or non-positive sizes;
 *         -ERR_ITEM_UNEXISTS when wireguard.exe / wg.exe were not located;
 *         -ERR_MEMORY_STR    when a command line or key does not fit its buffer;
 *         -ERR_CALL_SHELL    when a command fails or prints nothing.
 *
 * NOTE(review): the private key is passed on a command line, which other local
 * processes can observe. Piping it to `wg pubkey` via stdin would avoid this,
 * but that needs support from RunCommand().
 */
int GenerateWireguardKeyPairs(TCHAR *pPubKey, int pubkeySize, TCHAR *pPrivKey, int privKeySize) {
    // Strip trailing whitespace / line breaks in place. `wg` presumably ends its
    // output with a newline. Left in place, that newline would end up in the
    // returned key and would split the `echo ... | wg pubkey` command line.
    // This is a no-op if RunCommand() already trims its output.
    const auto trimTrailingSpace = [](TCHAR *pStr) {
        TCHAR *pEnd = pStr + lstrlen(pStr);
        while (pEnd > pStr && (pEnd[-1] == TEXT(' ') || pEnd[-1] == TEXT('\t') || pEnd[-1] == TEXT('\r') ||
                               pEnd[-1] == TEXT('\n'))) {
            *--pEnd = TEXT('\0');
        }
    };

    int         ret;
    TCHAR       cmdBuffer[2 * MAX_PATH];
    TCHAR       cmdResult[MAX_PATH];
    PSDK_CONFIG pCfg = GetGlobalCfgInfo();

    if (pPubKey == nullptr || pPrivKey == nullptr || pubkeySize <= 0 || privKeySize <= 0) {
        return -ERR_INPUT_PARAMS;
    }

    // WireGuard is missing or its directory has not been configured.
    if (!pCfg->wireguardCfg.wgExists || !pCfg->wireguardCfg.wireguardExists) {
        return -ERR_ITEM_UNEXISTS;
    }

    memset(cmdBuffer, 0, sizeof(cmdBuffer));
    memset(cmdResult, 0, sizeof(cmdResult));

    if (FAILED(StringCbPrintf(cmdBuffer,
                              sizeof(cmdBuffer),
                              TEXT("cmd.exe /C \"%s\" genkey"),
                              pCfg->wireguardCfg.wgPath))) {
        SPDLOG_ERROR("Build genkey command line error: {0}", pCfg->wireguardCfg.wgPath);
        return -ERR_MEMORY_STR;
    }

    if ((ret = RunCommand(cmdBuffer, cmdResult, MAX_PATH)) != ERR_SUCCESS) {
        SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuffer, ret);
        return -ERR_CALL_SHELL;
    }

    trimTrailingSpace(cmdResult);
    if (cmdResult[0] == TEXT('\0')) {
        SPDLOG_ERROR("Run command [{0}] returned an empty key", cmdBuffer);
        return -ERR_CALL_SHELL;
    }

    // The private key itself is deliberately not logged.
    SPDLOG_DEBUG("Run command [{0}] succeeded", cmdBuffer);

    if (FAILED(StringCbCopy(pPrivKey, static_cast<size_t>(privKeySize), cmdResult))) {
        SPDLOG_ERROR("Private key buffer too small: {0} bytes", privKeySize);
        return -ERR_MEMORY_STR;
    }

    memset(cmdBuffer, 0, sizeof(cmdBuffer));
    if (FAILED(StringCbPrintf(cmdBuffer,
                              sizeof(cmdBuffer),
                              TEXT("cmd.exe /C echo %s | \"%s\" pubkey"),
                              cmdResult,
                              pCfg->wireguardCfg.wgPath))) {
        SPDLOG_ERROR("Build pubkey command line error: {0}", pCfg->wireguardCfg.wgPath);
        return -ERR_MEMORY_STR;
    }

    memset(cmdResult, 0, sizeof(cmdResult));
    if ((ret = RunCommand(cmdBuffer, cmdResult, MAX_PATH)) != ERR_SUCCESS) {
        // cmdBuffer embeds the private key, so it is not logged here.
        SPDLOG_ERROR("Run command [wg pubkey] error: {0}", ret);
        return -ERR_CALL_SHELL;
    }

    trimTrailingSpace(cmdResult);
    if (cmdResult[0] == TEXT('\0')) {
        SPDLOG_ERROR("Run command [wg pubkey] returned an empty key");
        return -ERR_CALL_SHELL;
    }

    if (FAILED(StringCbCopy(pPubKey, static_cast<size_t>(pubkeySize), cmdResult))) {
        SPDLOG_ERROR("Public key buffer too small: {0} bytes", pubkeySize);
        return -ERR_MEMORY_STR;
    }

    SPDLOG_DEBUG("Run command [wg pubkey] result \'{0}\'", cmdResult);

    return ERR_SUCCESS;
}

int SetWireguardPath(const TCHAR *pPath) {
    WIN32_FIND_DATA FindFileData;
    HANDLE          hFind;

    if (pPath == nullptr) {
        return -ERR_INPUT_PARAMS;
    }

    hFind = FindFirstFile(pPath, &FindFileData);

    if (hFind != INVALID_HANDLE_VALUE) {
        TCHAR wgPath[MAX_PATH];

        SPDLOG_DEBUG(TEXT("Used configure file:{0}"), GetGlobalCfgInfo()->cfgPath);

        WritePrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WIREGUARD_PATH, pPath, GetGlobalCfgInfo()->cfgPath);
        SPDLOG_DEBUG(TEXT("Save configure: {1} --> {0}"), pPath, CFG_WIREGUARD_PATH);

        StringCbCopy(wgPath, MAX_PATH, pPath);

        if (TCHAR *pIndex = _tcsrchr(wgPath, '\\')) {
            *pIndex = 0;
            StringCbCat(wgPath, MAX_PATH, "\\wg.exe");
            WritePrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WG_PATH, wgPath, GetGlobalCfgInfo()->cfgPath);
            SPDLOG_DEBUG(TEXT("Save configure: {1} --> {0}"), wgPath, CFG_WG_PATH);
        }

        return ERR_SUCCESS;
    } else {
        SPDLOG_ERROR(TEXT("WireGuard not found: {0}"), pPath);
        return -ERR_ITEM_UNEXISTS;
    }
}

/**
 * @brief Locate wireguard.exe (and wg.exe) and cache them in the global SDK config.
 *
 * Lookup order:
 *   1. The [WireGuard] WireGuardExe path stored in the SDK ini file. When it
 *      exists, the [WireGuard] WgExe ini path is also checked for wg.exe.
 *   2. Every directory listed in the PATH environment variable. The first
 *      directory containing wireguard.exe is persisted via SetWireguardPath().
 *      wg.exe in that same directory is recorded when present.
 *
 * On success, GetGlobalCfgInfo()->wireguardCfg.wireguardPath and
 * wireguardExists are set. wgPath and wgExists are set only when wg.exe is
 * found as well.
 *
 * @param pFullPath Optional output buffer for the wireguard.exe path (may be nullptr).
 * @param maxSize   Size of pFullPath in bytes (StringCbCopy semantics); ignored if <= 0.
 * @return ERR_SUCCESS if wireguard.exe was found;
 *         -ERR_FILE_NOT_EXISTS if it was not found or PATH could not be read;
 *         -ERR_MALLOC_MEMORY on allocation failure.
 */
int FindWireguardExe(TCHAR *pFullPath, int maxSize) {
    // Existence test via FindFirstFile. The search handle must be released with
    // FindClose. Calling CloseHandle on a find handle is invalid, and was likely
    // the cause of the exception seen when this was called from C#.
    const auto fileExists = [](const TCHAR *pFile) -> bool {
        if (pFile == nullptr || pFile[0] == TEXT('\0')) {
            return false;
        }

        WIN32_FIND_DATA findData;
        const HANDLE    hFind = FindFirstFile(pFile, &findData);

        if (hFind == INVALID_HANDLE_VALUE) {
            return false;
        }

        FindClose(hFind);
        return true;
    };

    const auto pCfg = GetGlobalCfgInfo();
    TCHAR      iniPath[MAX_PATH];

    // 1. Location previously saved in the ini file.
    GetPrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WIREGUARD_PATH, TEXT(""), iniPath, MAX_PATH, pCfg->cfgPath);

    if (fileExists(iniPath)) {
        if (pFullPath && maxSize > 0) {
            StringCbCopy(pFullPath, maxSize, iniPath);
        }

        StringCbCopy(pCfg->wireguardCfg.wireguardPath, MAX_PATH, iniPath);
        pCfg->wireguardCfg.wireguardExists = TRUE;

        SPDLOG_DEBUG(TEXT("Ini found WireGuard at: {0}"), iniPath);

        GetPrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WG_PATH, TEXT(""), iniPath, MAX_PATH, pCfg->cfgPath);

        if (fileExists(iniPath)) {
            StringCbCopy(pCfg->wireguardCfg.wgPath, MAX_PATH, iniPath);
            pCfg->wireguardCfg.wgExists = TRUE;
            SPDLOG_DEBUG(TEXT("Ini found WireGuard Tools at: {0}"), iniPath);
        }

        return ERR_SUCCESS;
    }

    // 2. Search every directory listed in PATH. The buffer size is in TCHARs.
    DWORD  envChars = WINENVBUF_SIZE;
    LPTSTR pEnvBuf  = static_cast<LPTSTR>(malloc(envChars * sizeof(TCHAR)));

    if (nullptr == pEnvBuf) {
        SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), envChars * sizeof(TCHAR));
        return -ERR_MALLOC_MEMORY;
    }

    DWORD dwRet = GetEnvironmentVariable(TEXT("path"), pEnvBuf, envChars);

    if (dwRet > envChars) {
        // Buffer too small: dwRet is the required size including the terminator.
        LPTSTR pBuf = static_cast<LPTSTR>(realloc(pEnvBuf, dwRet * sizeof(TCHAR)));
        if (nullptr == pBuf) {
            SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), dwRet * sizeof(TCHAR));
            free(pEnvBuf);
            return -ERR_MALLOC_MEMORY;
        }

        pEnvBuf  = pBuf;
        envChars = dwRet;
        dwRet    = GetEnvironmentVariable(TEXT("path"), pEnvBuf, envChars);
    }

    // 0 means missing, empty or failed. >= envChars means PATH grew between the
    // two calls. In both cases the buffer holds nothing safe to tokenize.
    if (0 == dwRet || dwRet >= envChars) {
        const DWORD dwErr = GetLastError();
        if (0 == dwRet && ERROR_ENVVAR_NOT_FOUND == dwErr) {
            SPDLOG_DEBUG(TEXT("Environment variable path does not exist."));
        } else {
            SPDLOG_ERROR(TEXT("GetEnvironmentVariable failed ({0})"), dwErr);
        }
        free(pEnvBuf);
        return -ERR_FILE_NOT_EXISTS;
    }

    TCHAR *pContext = nullptr;

    for (TCHAR *pDir = _tcstok_s(pEnvBuf, TEXT(";"), &pContext); pDir != nullptr;
         pDir        = _tcstok_s(nullptr, TEXT(";"), &pContext)) {
        TCHAR exePath[MAX_PATH];

        // Skip entries too long for MAX_PATH rather than probing a truncated path.
        if (FAILED(StringCbPrintf(exePath, sizeof(exePath), TEXT("%s\\wireguard.exe"), pDir)) ||
            !fileExists(exePath)) {
            continue;
        }

        if (pFullPath && maxSize > 0) {
            StringCbCopy(pFullPath, maxSize, exePath);
        }

        // Save the path to the ini file.
        SetWireguardPath(exePath);
        SPDLOG_DEBUG(TEXT("Path Environment found WireGuard at: {0}"), exePath);

        // Cache the path that was actually found. The ini value is invalid
        // at this point, otherwise step 1 would have returned.
        StringCbCopy(pCfg->wireguardCfg.wireguardPath, MAX_PATH, exePath);
        pCfg->wireguardCfg.wireguardExists = TRUE;

        TCHAR toolPath[MAX_PATH];

        if (SUCCEEDED(StringCbPrintf(toolPath, sizeof(toolPath), TEXT("%s\\wg.exe"), pDir))) {
            SPDLOG_DEBUG(TEXT("Find WireGuard tools at: {0}"), toolPath);

            if (fileExists(toolPath)) {
                StringCbCopy(pCfg->wireguardCfg.wgPath, MAX_PATH, toolPath);
                pCfg->wireguardCfg.wgExists = TRUE;

                SPDLOG_DEBUG(TEXT("Path Environment found WireGuard tools at: {0}"), toolPath);
            }
        }

        // pDir points into pEnvBuf, so free it only after the last use.
        free(pEnvBuf);
        return ERR_SUCCESS;
    }

    free(pEnvBuf);
    return -ERR_FILE_NOT_EXISTS;
}