NetTunnelWindows/NetTunnelSDK/UserManager.cpp

311 lines
10 KiB
C++

#include "pch.h"
#include "tunnel.h"
#include "usrerr.h"
#include "globalcfg.h"
#include "misc.h"
#include "protocol.h"
#include "user.h"
#include <shlwapi.h>
#include <strsafe.h>
#include <atomic>
static HANDLE g_HeartTimerQueue = nullptr;
static LPTUNNEL_HEART_ROUTINE g_lpHeartCb = nullptr;
/**
* @brief 启动/停止 隧道控制服务心跳
* @param isStart 启动/停止服务 TRUE 启动服务, FALSE 停止服务
* @param lpHeartCbAddress 心跳服务回调函数 @see PTUNNEL_HEART_ROUTINE
* @return 0: 成功, 小于0 失败 @see USER_ERRNO
* - -ERR_INPUT_PARAMS 输入参数错误
* - -ERR_CREATE_TIMER 创建定时器失败
* - -ERR_DELETE_TIMER 删除定时器失败
* - ERR_SUCCESS 成功
*/
int RemoteHeartControl(bool isStart, LPTUNNEL_HEART_ROUTINE lpHeartCbAddress) {
if (isStart && lpHeartCbAddress == nullptr) {
SPDLOG_ERROR(TEXT("Input lpHeartCbAddress params nullptr"));
return -ERR_INPUT_PARAMS;
}
g_lpHeartCb = lpHeartCbAddress;
if (isStart) {
if (!g_HeartTimerQueue) {
HANDLE hTimer = nullptr;
// Create the timer queue.
g_HeartTimerQueue = CreateTimerQueue();
if (nullptr == g_HeartTimerQueue) {
SPDLOG_ERROR(TEXT("CreateTimerQueue failed ({0})"), GetLastError());
return -ERR_CREATE_TIMER;
}
// Set a timer to call the timer routine in 10 seconds.
if (!CreateTimerQueueTimer(
&hTimer,
g_HeartTimerQueue,
[](PVOID lpParam, BOOLEAN TimerOrWaitFired) {
int ret;
ProtocolRequest<ReqHeartParams> req;
ProtocolResponse<RspHeartParams> rsp;
ret = ProtolPostMessage(SET_CLIENTHEART_PATH, &req, &rsp, false);
if (g_lpHeartCb && ret) {
g_lpHeartCb(rsp.msgContent.message.c_str(), rsp.timeStamp);
}
},
nullptr,
0,
HEART_PERIOD_MS,
WT_EXECUTEDEFAULT)) {
SPDLOG_ERROR(TEXT("CreateTimerQueueTimer failed ({0})"), GetLastError());
return -ERR_CREATE_TIMER;
}
}
} else {
if (g_HeartTimerQueue) {
if (!DeleteTimerQueue(g_HeartTimerQueue)) {
SPDLOG_ERROR(TEXT("DeleteTimerQueue failed ({0})"), GetLastError());
g_HeartTimerQueue = nullptr;
return -ERR_DELETE_TIMER;
}
g_HeartTimerQueue = nullptr;
}
}
return ERR_SUCCESS;
}
int RemoteWireGuardControl(bool isStart) {
int ret;
ProtocolRequest<ReqStartTunnelParams> req;
ProtocolResponse<ResponseStatus> rsp;
req.msgContent.isStart = isStart;
ret = ProtolPostMessage(SET_CLIENTSTART_TUNNEL, &req, &rsp, false);
if (ret != ERR_SUCCESS) {
return ret;
}
if (rsp.msgContent.errCode != ERR_SUCCESS) {
SPDLOG_ERROR(TEXT("Service Response error({0}): {1}"), rsp.msgContent.errCode, rsp.msgContent.errMessage);
return rsp.msgContent.errCode;
}
return ERR_SUCCESS;
}
int SetClientConfige(const TCHAR *pCliPublicKey, const TCHAR *pCliNetwork, const TCHAR *pCliTunnelAddr) {
int ret;
ProtocolRequest<ReqUserSetCliCfgParams> req;
ProtocolResponse<ResponseStatus> rsp;
req.msgContent.cliPublicKey = pCliPublicKey;
req.msgContent.cliNetwork = pCliNetwork;
req.msgContent.cliTunnelAddr = pCliTunnelAddr;
ret = ProtolPostMessage(SET_CLIENTCFG_PATH, &req, &rsp, false);
if (ret != ERR_SUCCESS) {
return ret;
}
if (rsp.msgContent.errCode != ERR_SUCCESS) {
SPDLOG_ERROR(TEXT("Service Response error({0}): {1}"), rsp.msgContent.errCode, rsp.msgContent.errMessage);
return rsp.msgContent.errCode;
}
return ERR_SUCCESS;
}
int GetUserServerConfigure(const TCHAR *pUserName, const TCHAR *pToken, PUSER_SERVER_CONFIG *pSvrCfg) {
int ret;
PUSER_CONFIG pUser = &GetGlobalCfgInfo()->userCfg;
PUSER_SERVER_CONFIG pUserCfg = &pUser->svrConfig;
ProtocolRequest<ReqGetUserCfgParams> req;
ProtocolResponse<RspUserSevrCfgParams> rsp;
if (pSvrCfg == nullptr) {
SPDLOG_ERROR(TEXT("Input pSvrCfg params error"));
return -ERR_INPUT_PARAMS;
}
if (pToken == nullptr || lstrlen(pToken) == 0) {
SPDLOG_ERROR(TEXT("Input pToken params error: {0}"), pToken);
return -ERR_INPUT_PARAMS;
}
if (pUserName && lstrlen(pUserName) > 0) {
memset(pUser->userName, 0, MAX_PATH);
StringCbCopy(pUser->userName, MAX_PATH, pUserName);
}
StringCbCopy(pUser->userToken, MAX_PATH, pToken);
req.msgContent.token = pToken;
req.msgContent.user = pUserName;
ret = ProtolPostMessage(GET_SERVERCFG_PATH, &req, &rsp);
if (ret != ERR_SUCCESS) {
return ret;
}
pUserCfg->svrListenPort = rsp.msgContent.svrListenPort;
StringCbCopy(pUserCfg->svrPrivateKey, 64, rsp.msgContent.svrPrivateKey.c_str());
StringCbCopy(pUserCfg->svrAddress, MAX_IP_LEN, rsp.msgContent.svrAddress.c_str());
*pSvrCfg = pUserCfg;
return ERR_SUCCESS;
}
int GetUserClientConfigure(const TCHAR *pUserName, const TCHAR *pToken, PUSER_CLIENT_CONFIG *pCliCfg) {
PVM_CFG pVm;
PUSER_CONFIG pUser = &GetGlobalCfgInfo()->userCfg;
PUSER_CLIENT_CONFIG pUserCfg = &pUser->cliConfig;
TCHAR userPath[MAX_PATH];
int ret;
unsigned int memSize;
ProtocolRequest<ReqGetUserCfgParams> req;
ProtocolResponse<RspUsrCliConfigParams> rsp;
if (pToken == nullptr || lstrlen(pToken) == 0) {
SPDLOG_ERROR(TEXT("Input pToken params error: {0}"), pToken);
return -ERR_INPUT_PARAMS;
}
if (pCliCfg == nullptr) {
SPDLOG_ERROR(TEXT("Input pCliCfg params error"));
return -ERR_INPUT_PARAMS;
}
StringCbPrintf(userPath, MAX_PATH, "%s\\%s", GetGlobalCfgInfo()->configDirectory, pUserName);
// 如果配置目录不存在则自动创建
if (!PathFileExists(userPath)) {
if (!CreateDirectory(userPath, nullptr)) {
SPDLOG_ERROR(TEXT("Create user {1} directory '{0}' error."), userPath, pUserName);
return -ERR_CREATE_FILE;
}
}
memset(pUser->userName, 0, MAX_PATH);
if (pUserName && lstrlen(pUserName) > 0) {
StringCbCopy(pUser->userName, MAX_PATH, pUserName);
}
StringCbCopy(pUser->userToken, MAX_PATH, pToken);
req.msgContent.token = pToken;
req.msgContent.user = pUserName;
ret = ProtolPostMessage(GET_CLIENTCFG_PATH, &req, &rsp);
if (ret != ERR_SUCCESS) {
return ret;
}
memSize = sizeof(VM_CFG) * static_cast<UINT>(rsp.msgContent.vmConfig.size());
pUserCfg->pVMConfig = static_cast<PVM_CFG>(CoTaskMemAlloc(memSize));
if (pUserCfg->pVMConfig == nullptr) {
SPDLOG_ERROR(TEXT("Error allocating memory {0} bytes"), memSize);
return -ERR_MALLOC_MEMORY;
}
memset(pUserCfg->pVMConfig, 0, memSize);
pUserCfg->scgCtrlAppId = rsp.msgContent.scgCtrlAppId;
pUserCfg->scgTunnelAppId = rsp.msgContent.scgTunnelAppId;
StringCbCopy(pUserCfg->cliPrivateKey, 64, rsp.msgContent.cliPrivateKey.c_str());
StringCbCopy(pUserCfg->cliPublicKey, 64, rsp.msgContent.cliPublicKey.c_str());
StringCbCopy(pUserCfg->cliAddress, MAX_IP_LEN, rsp.msgContent.cliAddress.c_str());
pUserCfg->tolVM = static_cast<int>(rsp.msgContent.vmConfig.size());
pVm = pUserCfg->pVMConfig;
for (auto vm : rsp.msgContent.vmConfig) {
pVm->vmId = vm.vmId;
StringCbCopy(pVm->vmName, MAX_PATH, vm.vmName.c_str());
StringCbCopy(pVm->svrPublicKey, 64, vm.svrPublicKey.c_str());
StringCbCopy(pVm->vmNetwork, MAX_IP_LEN, vm.vmNetwork.c_str());
StringCbCopy(pVm->scgGateWay, MAX_PATH, vm.scgGateway.c_str());
pVm++;
}
*pCliCfg = pUserCfg;
return ERR_SUCCESS;
}
int GetUserConfigFiles(const TCHAR *pUserName, PUSER_CFGFILE *pCfgFile, int *pItems) {
PUSER_CFGFILE pCfg;
FILE_LIST fileList = {nullptr, 0};
TCHAR fnPath[MAX_PATH] = {};
TCHAR cfgVal[MAX_PATH];
bool isSelected = false;
if (pUserName == nullptr || lstrlen(pUserName) == 0) {
SPDLOG_ERROR(TEXT("Input pUserName params error: {0}"), pUserName);
return -ERR_INPUT_PARAMS;
}
if (pCfgFile == nullptr) {
SPDLOG_ERROR(TEXT("Input pCfgFile params error"));
return -ERR_INPUT_PARAMS;
}
if (pItems == nullptr) {
SPDLOG_ERROR(TEXT("Input pItems params error"));
return -ERR_INPUT_PARAMS;
}
GetPrivateProfileString(CFG_WIREGUARD_SECTION,
CFG_WGCFG_PATH,
TEXT(""),
cfgVal,
MAX_PATH,
GetGlobalCfgInfo()->cfgPath);
if (PathFileExists(cfgVal)) {
isSelected = true;
}
StringCbPrintf(fnPath, MAX_PATH, "%s\\%s\\*.conf", GetGlobalCfgInfo()->configDirectory, pUserName);
int ret = FindFile(fnPath, &fileList, false);
if (ret != ERR_SUCCESS) {
SPDLOG_ERROR(TEXT("Find WireGuard user {1} configure file error: {0}"), ret, pUserName);
return ret;
}
pCfg = static_cast<PUSER_CFGFILE>(CoTaskMemAlloc(sizeof(USER_CFGFILE) * fileList.nItems));
if (pCfg == nullptr) {
SPDLOG_ERROR(TEXT("Error allocating memory {0} bytes"), sizeof(USER_CFGFILE) * fileList.nItems);
return -ERR_SUCCESS;
}
memset(pCfg, 0, sizeof(USER_CFGFILE) * fileList.nItems);
*pCfgFile = pCfg;
*pItems = static_cast<int>(fileList.nItems);
for (unsigned int i = 0; fileList.pFilePath && i < fileList.nItems; i++) {
StringCbCopy(pCfg->CfgPath, MAX_PATH, fileList.pFilePath[i].path);
if (isSelected && StrCmp(pCfg->CfgPath, cfgVal) == 0) {
pCfg->isCurrent = true;
} else {
pCfg->isCurrent = false;
}
pCfg++;
}
if (fileList.pFilePath) {
HeapFree(GetProcessHeap(), 0, fileList.pFilePath);
}
return ERR_SUCCESS;
}