// File: NetTunnelWindows/NetTunnelSDK/tunnel/WireGuardService.cpp
// (source-viewer metadata: 335 lines, 13 KiB, C++)

#include "pch.h"
#include "usrerr.h"
#include "globalcfg.h"
#include "tunnel.h"
#include "wireguard.h"
#include "misc.h"
#include <shlwapi.h>
#include <strsafe.h>
#include <spdlog/spdlog.h>
// Entry points exported by wireguard.dll.
// They are filled in at runtime by InitializeWireGuardLibrary() through
// GetProcAddress and hold nullptr until that has succeeded.
static WIREGUARD_CREATE_ADAPTER_FUNC *WireGuardCreateAdapter = nullptr;
static WIREGUARD_OPEN_ADAPTER_FUNC *WireGuardOpenAdapter = nullptr;
static WIREGUARD_CLOSE_ADAPTER_FUNC *WireGuardCloseAdapter = nullptr;
static WIREGUARD_GET_ADAPTER_LUID_FUNC *WireGuardGetAdapterLUID = nullptr;
static WIREGUARD_GET_RUNNING_DRIVER_VERSION_FUNC *WireGuardGetRunningDriverVersion = nullptr;
static WIREGUARD_DELETE_DRIVER_FUNC *WireGuardDeleteDriver = nullptr;
static WIREGUARD_SET_LOGGER_FUNC *WireGuardSetLogger = nullptr;
static WIREGUARD_SET_ADAPTER_LOGGING_FUNC *WireGuardSetAdapterLogging = nullptr;
static WIREGUARD_GET_ADAPTER_STATE_FUNC *WireGuardGetAdapterState = nullptr;
static WIREGUARD_SET_ADAPTER_STATE_FUNC *WireGuardSetAdapterState = nullptr;
static WIREGUARD_GET_CONFIGURATION_FUNC *WireGuardGetConfiguration = nullptr;
static WIREGUARD_SET_CONFIGURATION_FUNC *WireGuardSetConfiguration = nullptr;
// Fixed-size buffer handed to WireGuardGetConfiguration(). It holds the interface
// record followed by one peer and two allowed-IP entries, in that order.
// NOTE(review): configurations with more peers or allowed IPs will not fit in
// this buffer — confirm it is sufficient for every tunnel this SDK creates.
struct WG_CONFIG_INFO {
    WIREGUARD_INTERFACE Interface;
    WIREGUARD_PEER RemoteServer;
    WIREGUARD_ALLOWED_IP Allow1;
    WIREGUARD_ALLOWED_IP Allow2;
};

// Module handle returned by LoadLibraryEx in InitializeWireGuardLibrary()
// (nullptr until then). The misspelled name is kept as-is because other code
// in this file refers to it.
static HMODULE g_WireGarudModule = nullptr;
int InitializeWireGuardLibrary() {
TCHAR dllPath[MAX_PATH];
StringCbPrintf(dllPath, MAX_PATH, TEXT("%s\\wireguard.dll"), GetGlobalCfgInfo()->workDirectory);
if (!PathFileExists(dllPath)) {
SPDLOG_ERROR(TEXT("WireGuard DLL Not Found: {0}"), dllPath);
return -ERR_ITEM_UNEXISTS;
}
g_WireGarudModule = LoadLibraryEx(dllPath,
nullptr,
LOAD_LIBRARY_SEARCH_APPLICATION_DIR | LOAD_LIBRARY_SEARCH_SYSTEM32);
if (!g_WireGarudModule) {
DWORD errCode = GetLastError();
SPDLOG_ERROR(TEXT("LoadLibraryEx WireGuard DLL error: {0}"), errCode);
return -ERR_LOAD_LIBRARY;
}
#define X(Name) ((*(FARPROC *)&(Name) = GetProcAddress(g_WireGarudModule, #Name)) == nullptr)
if (X(WireGuardCreateAdapter) || X(WireGuardOpenAdapter) || X(WireGuardCloseAdapter) ||
X(WireGuardGetAdapterLUID) || X(WireGuardGetRunningDriverVersion) || X(WireGuardDeleteDriver) ||
X(WireGuardSetLogger) || X(WireGuardSetAdapterLogging) || X(WireGuardGetAdapterState) ||
X(WireGuardSetAdapterState) || X(WireGuardGetConfiguration) || X(WireGuardSetConfiguration))
#undef X
{
SPDLOG_ERROR(TEXT("Map WireGuard DLL EntryPoint error: {0}"), GetLastError());
FreeLibrary(g_WireGarudModule);
return -ERR_MAP_LIBRARY;
}
return ERR_SUCCESS;
}
void UnInitializeWireGuardLibrary() {
if (g_WireGarudModule) {
FreeLibrary(g_WireGarudModule);
}
}
int GetWireGuradTunnelInfo(const TCHAR *pTunnelName) {
WIREGUARD_ADAPTER_HANDLE Adapter;
int ret;
WCHAR wstrName[MAX_PATH];
if ((ret = TCharToWideChar(pTunnelName, wstrName, MAX_PATH)) != ERR_SUCCESS) {
return ret;
}
Adapter = WireGuardOpenAdapter(wstrName);
if (Adapter) {
WG_CONFIG_INFO config;
DWORD Bytes = sizeof(WG_CONFIG_INFO);
if (!WireGuardGetConfiguration(Adapter, &config.Interface, &Bytes)) {
SPDLOG_ERROR("Failed to get configuration: {0}", GetLastError());
}
}
return ERR_SUCCESS;
}
int GetWireGuardServiceStatus(const TCHAR *pTunnelName, bool *pIsRunning) {
SC_HANDLE schSCManager;
SC_HANDLE schService;
TCHAR svrName[MAX_PATH];
if (pTunnelName == nullptr || lstrlen(pTunnelName) == 0) {
SPDLOG_ERROR(TEXT("Input pTunnelName error: {0}"), pTunnelName);
return -ERR_INPUT_PARAMS;
}
if (pIsRunning == nullptr) {
SPDLOG_ERROR(TEXT("Input pIsRunning params error"));
return -ERR_INPUT_PARAMS;
}
*pIsRunning = false;
StringCbPrintf(svrName, MAX_PATH, TEXT("WireGuardTunnel$%s"), pTunnelName);
// Get a handle to the SCM database.
schSCManager = OpenSCManager(nullptr, // local computer
nullptr, // ServicesActive database
SC_MANAGER_ALL_ACCESS); // full access rights
if (nullptr == schSCManager) {
SPDLOG_ERROR(TEXT("OpenSCManager failed ({0})"), GetLastError());
return -ERR_OPEN_SCM;
}
// Get a handle to the service.
schService = OpenService(schSCManager, // SCM database
svrName, // name of service
SERVICE_ALL_ACCESS); // full access
CloseServiceHandle(schService);
// 如果服务不存在则直接返回
if (schService != nullptr) {
*pIsRunning = true;
}
return ERR_SUCCESS;
}
int RemoveGuardService(const TCHAR *pTunnelName, bool bIsWaitStop) {
SC_HANDLE schSCManager;
SC_HANDLE schService;
TCHAR svrName[MAX_PATH];
SERVICE_STATUS svrStatus;
if (pTunnelName == nullptr || lstrlen(pTunnelName) == 0) {
SPDLOG_ERROR(TEXT("Input pTunnelName error: {0}"), pTunnelName);
return -ERR_INPUT_PARAMS;
}
StringCbPrintf(svrName, MAX_PATH, TEXT("WireGuardTunnel$%s"), pTunnelName);
// Get a handle to the SCM database.
schSCManager = OpenSCManager(nullptr, // local computer
nullptr, // ServicesActive database
SC_MANAGER_ALL_ACCESS); // full access rights
if (nullptr == schSCManager) {
SPDLOG_ERROR(TEXT("OpenSCManager failed ({0})"), GetLastError());
return -ERR_OPEN_SCM;
}
// Get a handle to the service.
schService = OpenService(schSCManager, // SCM database
svrName, // name of service
SERVICE_ALL_ACCESS); // full access
// 如果服务不存在则直接返回
if (schService == nullptr) {
CloseServiceHandle(schSCManager);
return ERR_SUCCESS;
}
if (ControlService(schService, SERVICE_CONTROL_STOP, &svrStatus) == 0) {
DWORD errCode = GetLastError();
if (errCode != ERROR_SERVICE_CANNOT_ACCEPT_CTRL && errCode != ERROR_SERVICE_NOT_ACTIVE) {
SPDLOG_ERROR(TEXT("Stop Service {1} failed ({0})"), errCode, svrName);
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return -ERR_STOP_SERVICE;
}
}
for (int i = 0; bIsWaitStop && i < 10; i++) {
SERVICE_STATUS_PROCESS ssStatus;
DWORD dwBytesNeeded = 0;
if (QueryServiceStatusEx(schService, // handle to service
SC_STATUS_PROCESS_INFO, // information level
reinterpret_cast<LPBYTE>(&ssStatus), // address of structure
sizeof(SERVICE_STATUS_PROCESS), // size of structure
&dwBytesNeeded)) // size needed if buffer is too small
{
// 服务已经停止
if (ssStatus.dwCurrentState == SERVICE_STOPPED) {
break;
}
}
//SPDLOG_ERROR(TEXT("Stop Service {1} retry times ({0})"), i + 1, svrName);
Sleep(1000);
}
if (!DeleteService(schService) && GetLastError() != ERROR_SERVICE_MARKED_FOR_DELETE) {
SPDLOG_ERROR(TEXT("Delete Service {1} failed ({0})"), GetLastError(), svrName);
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return -ERR_DELETE_SERVICE;
}
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return ERR_SUCCESS;
}
int CreateWireGuardService(const TCHAR *pInterfaceName, const TCHAR *pWGConfigFilePath) {
//Service Name: "WireGuardTunnel$SomeTunnelName"
//Display Name: "Some Service Name"
//Service Type: SERVICE_WIN32_OWN_PROCESS
//Start Type: StartAutomatic
//Error Control: ErrorNormal,
//Dependencies: [ "Nsi", "TcpIp" ]
//Sid Type: SERVICE_SID_TYPE_UNRESTRICTED
//Executable: "C:\path\to\example\vpnclient.exe /service configfile.conf"
SC_HANDLE schSCManager;
SC_HANDLE schService;
TCHAR svrName[MAX_PATH];
TCHAR displayName[MAX_PATH];
TCHAR svrPath[MAX_PATH];
TCHAR execParams[MAX_PATH * 2];
if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) {
SPDLOG_ERROR(TEXT("Input pInterfaceName error: {0}"), pInterfaceName);
return -ERR_INPUT_PARAMS;
}
if (pWGConfigFilePath == nullptr || lstrlen(pWGConfigFilePath) == 0) {
SPDLOG_ERROR(TEXT("Input pGUID error: {0}"), pWGConfigFilePath);
return -ERR_INPUT_PARAMS;
}
if (!PathFileExists(pWGConfigFilePath)) {
SPDLOG_ERROR(TEXT("WireGuard Configure File Not Exist: {0}"), pWGConfigFilePath);
return -ERR_ITEM_UNEXISTS;
}
StringCbPrintf(svrName, MAX_PATH, TEXT("WireGuardTunnel$%s"), pInterfaceName);
StringCbPrintf(displayName, MAX_PATH, TEXT("WireGuard Tunnel Service %s"), pInterfaceName);
StringCbPrintf(svrPath, MAX_PATH, TEXT("%s\\NetTunnelSvr.exe"), GetGlobalCfgInfo()->workDirectory);
StringCbPrintf(execParams, MAX_PATH * 2, TEXT("\"%s\" /service \"%s\""), svrPath, pWGConfigFilePath);
//SPDLOG_DEBUG(TEXT("Params: {0}"), execParams);
if (!PathFileExists(svrPath)) {
SPDLOG_ERROR(TEXT("WireGuard Service Not Exist: {0}"), svrPath);
return -ERR_ITEM_UNEXISTS;
}
// Get a handle to the SCM database.
schSCManager = OpenSCManager(nullptr, // local computer
nullptr, // ServicesActive database
SC_MANAGER_ALL_ACCESS); // full access rights
if (nullptr == schSCManager) {
SPDLOG_ERROR(TEXT("OpenSCManager failed ({0})"), GetLastError());
return -ERR_OPEN_SCM;
}
// Get a handle to the service.
schService = OpenService(schSCManager, // SCM database
svrName, // name of service
SERVICE_ALL_ACCESS); // full access
// 如果服务已经存在则关闭
if (schService != nullptr) {
int ret;
CloseServiceHandle(schSCManager);
CloseServiceHandle(schService);
ret = RemoveGuardService(pInterfaceName, true);
if (ret != ERR_SUCCESS) {
return ret;
}
}
schService = CreateService(schSCManager,
svrName,
displayName,
SC_MANAGER_ALL_ACCESS,
SERVICE_WIN32_OWN_PROCESS,
SERVICE_DEMAND_START,
SERVICE_ERROR_NORMAL,
execParams,
nullptr,
nullptr,
TEXT("Nsi\0TcpIp"),
nullptr,
nullptr);
if (schService == nullptr) {
SPDLOG_ERROR(TEXT("Create Service {1} failed ({0})"), GetLastError(), svrName);
CloseServiceHandle(schSCManager);
return -ERR_CREATE_SERVICE;
} else {
SERVICE_SID_INFO info;
SERVICE_DESCRIPTIONA desc;
info.dwServiceSidType = SERVICE_SID_TYPE_UNRESTRICTED;
if (!ChangeServiceConfig2(schService, SERVICE_CONFIG_SERVICE_SID_INFO, &info)) {
SPDLOG_ERROR(TEXT("Change Service {1} SERVICE_CONFIG_SERVICE_SID_INFO Configure failed ({0})"),
GetLastError(),
svrName);
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return -ERR_CONFIG_SERVICE;
}
desc.lpDescription = TEXT(const_cast<LPSTR>("SCC Tunnel Service over WireGuard"));
if (!ChangeServiceConfig2(schService, SERVICE_CONFIG_DESCRIPTION, &desc)) {
SPDLOG_ERROR(TEXT("Change Service {1} SERVICE_CONFIG_DESCRIPTION Configure failed ({0})"),
GetLastError(),
svrName);
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return -ERR_CONFIG_SERVICE;
}
if (!StartService(schService, 0, nullptr)) {
SPDLOG_ERROR(TEXT("Start Service {1} failed ({0})"), GetLastError(), svrName);
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return -ERR_START_SERVICE;
}
}
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return ERR_SUCCESS;
}