#include "pch.h"
#include "tunnel.h"

#include <shlobj_core.h>
#include <strsafe.h>
#include <spdlog/sinks/wincolor_sink.h>
#include <spdlog/sinks/rotating_file_sink.h>
#include <spdlog/sinks/dup_filter_sink.h>
#include <spdlog/sinks/daily_file_sink.h>
#include <dbghelp.h>

#include "usrerr.h"
#include "globalcfg.h"
#include "misc.h"
#include "user.h"

#include <shlwapi.h>
#include <winsock2.h>

#pragma comment(lib, "Dbghelp.lib")

// File name of the SDK's INI configuration file; TunnelSDKInitEnv joins it with
// the work directory to build g_globalConfig.cfgPath.
#define CONFIG_FILE_NAME TEXT("tunnelsdk.ini")

// Process-wide SDK configuration. TunnelSDKInitEnv zeroes and populates it;
// other modules reach it through GetGlobalCfgInfo().
// NOTE(review): nothing in this file synchronizes access to it — confirm callers
// do not mutate it concurrently.
static SDK_CONFIG g_globalConfig;

/**
 * Gives other modules access to the process-wide SDK configuration.
 *
 * @return Address of the file-static configuration object; it is the same
 *         pointer on every call and is never null.
 */
PSDK_CONFIG GetGlobalCfgInfo() {
    PSDK_CONFIG pCfg = &g_globalConfig;
    return pCfg;
}

/**
 * Translates the SDK's LOG_LEVEL value into the matching spdlog level.
 *
 * @param level SDK log level.
 * @return The equivalent spdlog level; any value without an explicit mapping
 *         yields spdlog "info".
 */
static spdlog::level::level_enum logLevelToSpdlogLevel(LOG_LEVEL level) {
    using Level = spdlog::level::level_enum;

    Level mapped = Level::info;  // fallback for unmapped values

    switch (level) {
        case LOG_TRACE:    mapped = Level::trace;    break;
        case LOG_DEBUG:    mapped = Level::debug;    break;
        case LOG_INFO:     mapped = Level::info;     break;
        case LOG_WARN:     mapped = Level::warn;     break;
        case LOG_ERROR:    mapped = Level::err;      break;
        case LOG_CRITICAL: mapped = Level::critical; break;
        case LOG_OFF:      mapped = Level::off;      break;
        default:                                     break;
    }

    return mapped;
}

static void InitTunnelSDKLog(const TCHAR *pLogFile, LOG_LEVEL level) {
    TCHAR buf[MAX_PATH] = {0};

    if (pLogFile && strlen(pLogFile) > 0 && !PathIsRelative(pLogFile)) {
        TCHAR tmpPath[MAX_PATH];
        StringCbCopy(tmpPath, MAX_PATH, pLogFile);
        PathRemoveFileSpec(tmpPath);
        MakeSureDirectoryPathExists(tmpPath);
        StringCbCopy(buf, MAX_PATH, pLogFile);
    } else {
        StringCbPrintf(buf, MAX_PATH, TEXT("%s\\tunnelsdklog.log"), g_globalConfig.workDirectory);
    }

    g_globalConfig.enableLog = TRUE;
    g_globalConfig.logLevel  = logLevelToSpdlogLevel(level);

    const auto dupFileFilter = std::make_shared<spdlog::sinks::dup_filter_sink_st>(std::chrono::seconds(5));
    const auto dupStdFilter  = std::make_shared<spdlog::sinks::dup_filter_sink_st>(std::chrono::seconds(5));

    //std::make_shared<spdlog::sinks::rotating_file_sink_mt>(buf, 1024 * 1024 * 5, 10)->
    dupFileFilter->add_sink(std::make_shared<spdlog::sinks::daily_file_sink_mt>(buf, 2, 30));
    //dupFileFilter->add_sink(std::make_shared<spdlog::sinks::rotating_file_sink_mt>(buf, 1024 * 1024 * 5, 10));
    dupStdFilter->add_sink(std::make_shared<spdlog::sinks::wincolor_stdout_sink_mt>());

    std::vector<spdlog::sink_ptr> sinks {dupStdFilter, dupFileFilter};
    auto logger = std::make_shared<spdlog::logger>(TEXT("tunnelSDK"), sinks.begin(), sinks.end());
    spdlog::set_default_logger(logger);

    spdlog::set_level(g_globalConfig.logLevel);
    spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e][%l][%s:%#] %v");
    spdlog::flush_every(std::chrono::seconds(1));

#if 0
    std::cout << "TRACE: " << logger->should_log(spdlog::level::trace) << std::endl;
    std::cout << "DEBUG: " << logger->should_log(spdlog::level::debug) << std::endl;
    std::cout << "INFO: " << logger->should_log(spdlog::level::info) << std::endl;
    std::cout << "WARN: " << logger->should_log(spdlog::level::warn) << std::endl;
    std::cout << "ERROR: " << logger->should_log(spdlog::level::err) << std::endl;
    std::cout << "CRITICAL: " << logger->should_log(spdlog::level::critical) << std::endl;
#endif

    SPDLOG_INFO(TEXT("Log({1}): {0}"), buf, static_cast<int>(level));
}

int TunnelSDKInitEnv(const TCHAR *pWorkDir,
                     const TCHAR *pSvrUrl,
                     const TCHAR *pLogFile,
                     LOG_LEVEL    level,
                     bool         isWorkServer) {
    int     ret;
    size_t  length;
    WSADATA WsaData;

    CoInitialize(nullptr);
    CoInitializeSecurity(nullptr,
                         -1,
                         nullptr,
                         nullptr,
                         RPC_C_AUTHN_LEVEL_PKT,
                         RPC_C_IMP_LEVEL_IMPERSONATE,
                         nullptr,
                         EOAC_NONE,
                         nullptr);
    WSAStartup(MAKEWORD(2, 2), &WsaData);

    memset(&g_globalConfig, 0, sizeof(SDK_CONFIG));

    g_globalConfig.isWorkServer       = isWorkServer;
    g_globalConfig.scgProxy.scgGwPort = 0;

    if (pWorkDir == nullptr) {
        // 获取当前文件默认路径
        GetModuleFileName(nullptr, g_globalConfig.workDirectory, MAX_PATH);
        PathRemoveFileSpec(g_globalConfig.workDirectory);
    } else {
        if (StringCbLength(pWorkDir, MAX_PATH, &length) == S_OK && length == 0 || PathIsRelative(pWorkDir)) {
            // 获取当前文件默认路径
            GetModuleFileName(nullptr, g_globalConfig.workDirectory, MAX_PATH);
            PathRemoveFileSpec(g_globalConfig.workDirectory);
        } else {
            MakeSureDirectoryPathExists(pWorkDir);
            StringCbCopy(g_globalConfig.workDirectory, MAX_PATH, pWorkDir);
        }
    }

    // 初始化日志
    InitTunnelSDKLog(pLogFile, level);

    // 创建配置文件存储目录
    if (FAILED(SHGetFolderPath(NULL, CSIDL_APPDATA | CSIDL_FLAG_NO_ALIAS, NULL, 0, g_globalConfig.configDirectory))) {
        SPDLOG_ERROR(TEXT("Get Windows system directory error."));
        return -ERR_SYS_CALL;
    }

    StringCbCat(g_globalConfig.configDirectory, MAX_PATH, "\\NetTunnel");
    SPDLOG_DEBUG(TEXT("Configure directory: {0}."), g_globalConfig.configDirectory);
    SPDLOG_DEBUG(TEXT("Platform Server: {}, Work Module: {}"), pSvrUrl, isWorkServer? TEXT("SERVER") : TEXT("Client"));

    // 如果配置目录不存在则自动创建
    if (!PathFileExists(g_globalConfig.configDirectory)) {
        if (!CreateDirectory(g_globalConfig.configDirectory, nullptr)) {
            SPDLOG_ERROR(TEXT("Create configure directory '{0}' error."), g_globalConfig.configDirectory);
            return -ERR_CREATE_FILE;
        }
    }

    StringCbCopy(g_globalConfig.platformServerUrl, MAX_IP_LEN, pSvrUrl);

    if (FAILED(SHGetFolderPath(NULL, CSIDL_WINDOWS | CSIDL_FLAG_NO_ALIAS, NULL, 0, g_globalConfig.systemDirectory))) {
        SPDLOG_ERROR(TEXT("Get Windows system directory error."));
        return -ERR_SYS_CALL;
    }

    StringCbPrintf(g_globalConfig.cfgPath, MAX_PATH, TEXT("%s\\%s"), g_globalConfig.workDirectory, CONFIG_FILE_NAME);

    if ((ret = InitializeWireGuardLibrary()) == ERR_SUCCESS) {
        return ret;
    }

#if 0
    if (FindWireguardExe(nullptr, 0) != ERR_SUCCESS) {
        SPDLOG_ERROR(TEXT("WireGuard not found, Please install WireGuard first or set the WireGuard Path."));
        return -ERR_ITEM_UNEXISTS;
    }
#endif
    
    return ERR_SUCCESS;
}

/**
 * Tears down SDK state.
 *
 * Calls RemoteWireGuardControl(false) and LocalWireGuardControl(false, false)
 * (presumably switching both controls off — confirm), unloads the WireGuard
 * library, asks COM to free DLLs that are no longer in use, and releases
 * Winsock.
 *
 * NOTE(review): TunnelSDKInitEnv calls CoInitialize() but nothing here calls
 * CoUninitialize(). Balancing it correctly needs the CoInitialize HRESULT (no
 * balance after RPC_E_CHANGED_MODE) and must run on the same thread — confirm
 * before adding it.
 */
void TunnelSDKUnInit() {
    // Controls are switched off before the library behind them is unloaded
    // (ordering presumed required — confirm).
    RemoteWireGuardControl(false);
    LocalWireGuardControl(false, false);
    UnInitializeWireGuardLibrary();
    // Unloads unused COM DLLs only; it does not uninitialize COM.
    CoFreeUnusedLibraries();
    // Matches the WSAStartup() call in TunnelSDKInitEnv.
    WSACleanup();
}

void DisableVerifySignature() {
    memset(g_globalConfig.clientId, 0, MAX_PATH);
    memset(g_globalConfig.clientSecret, 0, MAX_PATH);
}

int EnableVerifySignature(const TCHAR *pClientId, const TCHAR *pClientSecret) {
    if (pClientId == nullptr || lstrlen(pClientId) == 0 || lstrlen(pClientId) >= MAX_PATH) {
        SPDLOG_ERROR(TEXT("Input pClientId params error: {0}"), pClientId);
        return -ERR_INPUT_PARAMS;
    }

    if (pClientSecret == nullptr || lstrlen(pClientSecret) == 0 || lstrlen(pClientSecret) >= MAX_PATH) {
        SPDLOG_ERROR(TEXT("Input pClientSecret params error: {0}"), pClientSecret);
        return -ERR_INPUT_PARAMS;
    }

    DisableVerifySignature();

    StringCbCopy(g_globalConfig.clientId, MAX_PATH, pClientId);
    StringCbCopy(g_globalConfig.clientSecret, MAX_PATH, pClientSecret);

    return ERR_SUCCESS;
}

int EnableSCGProxy(bool isEnable, const TCHAR *pSCGIpAddr, int scgPort) {

    if (pSCGIpAddr == nullptr || lstrlen(pSCGIpAddr) == 0 || lstrlen(pSCGIpAddr) >= MAX_IP_LEN) {
        SPDLOG_ERROR(TEXT("Input pInterfaceName params error: {0}"), pSCGIpAddr);
        return -ERR_INPUT_PARAMS;
    }

    memset(g_globalConfig.scgProxy.scgIpAddr, 0, MAX_IP_LEN);

    if (isEnable) {
        IP_INFO ipInfo;
        int     ret;
        if ((ret = GetIpV4InfoFromHostname(AF_INET, pSCGIpAddr, &ipInfo)) != ERR_SUCCESS) {
            return ret;
        }

        g_globalConfig.userCfg.cliConfig.scgCtrlAppId   = WG_CTRL_SCG_ID;
        g_globalConfig.userCfg.cliConfig.scgTunnelAppId = WG_TUNNEL_SCG_ID;

        StringCbCopy(g_globalConfig.scgProxy.scgIpAddr, MAX_IP_LEN, ipInfo.hostip);
        g_globalConfig.scgProxy.scgGwPort = static_cast<UINT16>(scgPort);
        CreateUDPProxyServer();
    } else {
        StopUDPProxyServer();
        g_globalConfig.scgProxy.scgGwPort = 0;
    }

    return ERR_SUCCESS;
}

/**
 * Tells whether the SCG proxy is in use.
 *
 * @return true when a positive SCG gateway port is currently stored in the
 *         global configuration, false otherwise.
 */
bool UsedSCGProxy() {
    const auto gwPort = g_globalConfig.scgProxy.scgGwPort;
    return 0 < gwPort;
}

void TunnelLogEnable(bool enLog) {
    if (enLog) {
        spdlog::set_level(g_globalConfig.logLevel);
    } else {
        spdlog::set_level(spdlog::level::level_enum::off);
    }
}

int SetProtocolEncryptType(const PROTO_CRYPTO_TYPE type, const TCHAR *pProKey) {
    if (type > CRYPTO_BASE64 && type < CRYPTO_MAX) {
        if (pProKey == nullptr || strlen(pProKey) < MIN_IP_LEN) {
            return -ERR_INPUT_PARAMS;
        }
    }

    g_globalConfig.proCryptoType = type;
    StringCbCopy(g_globalConfig.proKeyBuf, 256, pProKey);

    SPDLOG_DEBUG(TEXT("Protocol crypto type: {0} with key [{1}]"),
                 static_cast<int>(type),
                 pProKey ? pProKey : TEXT(""));

    return ERR_SUCCESS;
}

//int CheckSystemMinDepend(CHECK_FUNCTION chkItem, TCHAR* pErrMsg, errMsg[MAX_PATH], );

/**
 * Runs the SDK's minimum-requirement checks and records one result per check.
 *
 * For every index i in [0, CHK_MAX), chkResult[i] is reset (chk = i,
 * result = true, empty errMsg) and the matching check runs:
 *   - CHK_SYSTEM_INIT:          g_globalConfig.configDirectory is non-empty,
 *                               i.e. TunnelSDKInitEnv has run.
 *   - CHK_WIREGUARD_CONFIG:     the file named by CFG_WGCFG_PATH in the
 *                               CFG_WIREGUARD_SECTION of the SDK INI file exists.
 *   - CHK_WIREGUARD_SERVICE:    IsWireGuardServerInstalled() succeeds and
 *                               reports the service as installed.
 *   - CHK_WG_INTERFACE_PRIVATE: fails when the CHK_WIREGUARD_SERVICE check failed.
 * A failing check sets result = false, writes errMsg and logs it.
 *
 * @param chkResult Caller-provided array of CHK_MAX entries, filled in place.
 * @return Always ERR_SUCCESS; inspect chkResult[i].result for each outcome.
 */
int CheckSystemMinRequired(CHK_RESULT chkResult[CHK_MAX]) {
    // CHK_WG_INTERFACE_PRIVATE reads the CHK_WIREGUARD_SERVICE entry, which is
    // only meaningful once that check has run earlier in the loop.
    static_assert(CHK_WIREGUARD_SERVICE < CHK_WG_INTERFACE_PRIVATE,
                  "CHK_WIREGUARD_SERVICE must be evaluated before CHK_WG_INTERFACE_PRIVATE");

    for (int i = 0; i < CHK_MAX; i++) {
        const PCHK_RESULT pChk = &chkResult[i];

        pChk->chk    = static_cast<CHECK_FUNCTION>(i);
        pChk->result = true;
        memset(pChk->errMsg, 0, sizeof(pChk->errMsg));
        switch (pChk->chk) {
            case CHK_SYSTEM_INIT:
                // configDirectory is only populated by TunnelSDKInitEnv.
                if (lstrlen(g_globalConfig.configDirectory) == 0) {
                    pChk->result = false;
                    StringCbCopy(pChk->errMsg,
                                 sizeof(pChk->errMsg),
                                 TEXT("错误: SDK 未初始化,请先调用 TunnelSDKInitEnv 接口执行初始化操作。"));
                    SPDLOG_ERROR(pChk->errMsg);
                }
                break;
            case CHK_WIREGUARD_CONFIG: {
                TCHAR cfgVal[MAX_PATH] = {0};
                GetPrivateProfileString(CFG_WIREGUARD_SECTION,
                                        CFG_WGCFG_PATH,
                                        TEXT(""),
                                        cfgVal,
                                        MAX_PATH,
                                        GetGlobalCfgInfo()->cfgPath);

                if (!PathFileExists(cfgVal)) {
                    pChk->result = false;
                    StringCbCopy(pChk->errMsg,
                                 sizeof(pChk->errMsg),
                                 TEXT("错误: 未找到 WireGuard 配置文件,请先调用 WireGuardInstallServerService 或者 "
                                      "WireGuardCreateClientConfig 接口创建合法的 WireGuard 配置文件。"));
                    SPDLOG_ERROR(pChk->errMsg);
                }
            } break;
            case CHK_WIREGUARD_SERVICE: {
                // Initialised so a failed query never leaves it indeterminate.
                bool      bInstall = false;
                const int ret      = IsWireGuardServerInstalled(&bInstall);
                if (ret != ERR_SUCCESS) {
                    pChk->result = false;
                    // printf-style placeholder: StringCbPrintf does not understand {0}.
                    StringCbPrintf(pChk->errMsg,
                                   sizeof(pChk->errMsg),
                                   TEXT("错误: 获取系统 WireGuard 服务安装状态异常, 错误码:%d。"),
                                   ret);
                    SPDLOG_ERROR(pChk->errMsg);
                } else if (!bInstall) {
                    // Only meaningful when the query succeeded; keeps the query
                    // error message from being overwritten.
                    pChk->result = false;
                    StringCbCopy(pChk->errMsg,
                                 sizeof(pChk->errMsg),
                                 TEXT("错误:系统 WireGuard 服务未安装,请调用 WireGuardInstallServerService 接口安装 "
                                      "WireGuard 服务。"));
                    SPDLOG_ERROR(pChk->errMsg);
                }
            } break;
            case CHK_WG_INTERFACE_PRIVATE:
                // Depends on the outcome (not the id) of the service check.
                if (!chkResult[CHK_WIREGUARD_SERVICE].result) {
                    pChk->result = false;
                    StringCbCopy(pChk->errMsg,
                                 sizeof(pChk->errMsg),
                                 TEXT("错误:系统 WireGuard 服务未安装,请调用 WireGuardInstallServerService 接口安装 "
                                      "WireGuard 服务。"));
                    SPDLOG_ERROR(pChk->errMsg);
                }
                break;
            case CHK_MAX:
                break;
        }
    }

    return ERR_SUCCESS;
}