#include "pch.h"
#include "usrerr.h"
#include "globalcfg.h"
#include "tunnel.h"
#include "wireguard.h"
#include "misc.h"
// NOTE(review): the names of the three system headers included here were lost in
// this copy of the file (only bare "#include" tokens survived). They are
// reconstructed from the APIs used below (StringCb*, PathFileExists, the Service
// Control Manager) - confirm against version control.
#include <strsafe.h>
#include <shlwapi.h>
#include <winsvc.h>

// WireGuard entry points, resolved from wireguard.dll by InitializeWireGuardLibrary().
// All of them are nullptr while the library is not loaded.
static WIREGUARD_CREATE_ADAPTER_FUNC *WireGuardCreateAdapter;
static WIREGUARD_OPEN_ADAPTER_FUNC *WireGuardOpenAdapter;
static WIREGUARD_CLOSE_ADAPTER_FUNC *WireGuardCloseAdapter;
static WIREGUARD_GET_ADAPTER_LUID_FUNC *WireGuardGetAdapterLUID;
static WIREGUARD_GET_RUNNING_DRIVER_VERSION_FUNC *WireGuardGetRunningDriverVersion;
static WIREGUARD_DELETE_DRIVER_FUNC *WireGuardDeleteDriver;
static WIREGUARD_SET_LOGGER_FUNC *WireGuardSetLogger;
static WIREGUARD_SET_ADAPTER_LOGGING_FUNC *WireGuardSetAdapterLogging;
static WIREGUARD_GET_ADAPTER_STATE_FUNC *WireGuardGetAdapterState;
static WIREGUARD_SET_ADAPTER_STATE_FUNC *WireGuardSetAdapterState;
static WIREGUARD_GET_CONFIGURATION_FUNC *WireGuardGetConfiguration;
static WIREGUARD_SET_CONFIGURATION_FUNC *WireGuardSetConfiguration;

// Fixed-size buffer handed to WireGuardGetConfiguration: one interface record
// followed by one peer and two allowed-IP entries.
typedef struct
{
    WIREGUARD_INTERFACE Interface;
    WIREGUARD_PEER RemoteServer;
    WIREGUARD_ALLOWED_IP Allow1;
    WIREGUARD_ALLOWED_IP Allow2;
} WG_CONFIG_INFO;

// Module handle of the loaded wireguard.dll; nullptr while it is not loaded.
static HMODULE g_WireGarudModule;

namespace
{
// Owns one SCM handle (manager or service) and closes it on scope exit, so
// every early return releases it exactly once.
class ScopedScHandle
{
public:
    explicit ScopedScHandle(SC_HANDLE handle = nullptr) : m_handle(handle) {}
    ~ScopedScHandle() { reset(); }
    ScopedScHandle(const ScopedScHandle &) = delete;
    ScopedScHandle &operator=(const ScopedScHandle &) = delete;

    SC_HANDLE get() const { return m_handle; }
    explicit operator bool() const { return m_handle != nullptr; }

    // Closes the currently owned handle (if any) and takes ownership of `handle`.
    void reset(SC_HANDLE handle = nullptr)
    {
        if (m_handle != nullptr)
        {
            CloseServiceHandle(m_handle);
        }
        m_handle = handle;
    }

private:
    SC_HANDLE m_handle;
};
} // namespace

// Clears every WireGuard entry point so no pointer into an unloaded DLL survives.
static void ResetWireGuardEntryPoints()
{
    WireGuardCreateAdapter = nullptr;
    WireGuardOpenAdapter = nullptr;
    WireGuardCloseAdapter = nullptr;
    WireGuardGetAdapterLUID = nullptr;
    WireGuardGetRunningDriverVersion = nullptr;
    WireGuardDeleteDriver = nullptr;
    WireGuardSetLogger = nullptr;
    WireGuardSetAdapterLogging = nullptr;
    WireGuardGetAdapterState = nullptr;
    WireGuardSetAdapterState = nullptr;
    WireGuardGetConfiguration = nullptr;
    WireGuardSetConfiguration = nullptr;
}

// Returns a string that is always safe to hand to the logger: fmt treats a null
// C-string argument as a format error, so the log line would be lost.
static const TCHAR *SafeLogStr(const TCHAR *pStr)
{
    return pStr != nullptr ? pStr : TEXT("");
}

// Builds the SCM service name "WireGuardTunnel$<tunnel>" into pSvrName, whose
// size cbSvrName is given in BYTES. Fails (instead of truncating) if it does not fit.
static HRESULT FormatTunnelServiceName(const TCHAR *pTunnelName, TCHAR *pSvrName, size_t cbSvrName)
{
    return StringCbPrintf(pSvrName, cbSvrName, TEXT("WireGuardTunnel$%s"), pTunnelName);
}

/**
 * Loads "<workDirectory>\wireguard.dll" and resolves every WireGuard* export used
 * by this module into the file-level function pointers.
 *
 * @return ERR_SUCCESS on success;
 *         -ERR_ITEM_UNEXISTS if the DLL path does not fit in MAX_PATH or the file is missing;
 *         -ERR_LOAD_LIBRARY  if LoadLibraryEx fails;
 *         -ERR_MAP_LIBRARY   if any export is missing (the DLL is unloaded again and
 *                            all entry points are cleared).
 */
int InitializeWireGuardLibrary()
{
    TCHAR dllPath[MAX_PATH];

    // StringCbPrintf expects the destination size in BYTES, hence sizeof(dllPath);
    // MAX_PATH is only correct when TCHAR is a single byte.
    if (FAILED(StringCbPrintf(dllPath, sizeof(dllPath), TEXT("%s\\wireguard.dll"),
                              GetGlobalCfgInfo()->workDirectory)))
    {
        SPDLOG_ERROR(TEXT("WireGuard DLL path too long: {0}"), dllPath);
        return -ERR_ITEM_UNEXISTS;
    }

    if (!PathFileExists(dllPath))
    {
        SPDLOG_ERROR(TEXT("WireGuard DLL Not Found: {0}"), dllPath);
        return -ERR_ITEM_UNEXISTS;
    }

    g_WireGarudModule = LoadLibraryEx(dllPath, nullptr,
                                      LOAD_LIBRARY_SEARCH_APPLICATION_DIR | LOAD_LIBRARY_SEARCH_SYSTEM32);
    if (!g_WireGarudModule)
    {
        DWORD errCode = GetLastError();
        SPDLOG_ERROR(TEXT("LoadLibraryEx WireGuard DLL error: {0}"), errCode);
        return -ERR_LOAD_LIBRARY;
    }

    // X(Name) stores the export into the same-named pointer and yields true when
    // it is missing. || stops at the first missing export, so GetLastError()
    // below still describes that lookup.
#define X(Name) ((*(FARPROC *)&(Name) = GetProcAddress(g_WireGarudModule, #Name)) == nullptr)
    const bool mapFailed =
        X(WireGuardCreateAdapter) || X(WireGuardOpenAdapter) || X(WireGuardCloseAdapter) ||
        X(WireGuardGetAdapterLUID) || X(WireGuardGetRunningDriverVersion) || X(WireGuardDeleteDriver) ||
        X(WireGuardSetLogger) || X(WireGuardSetAdapterLogging) || X(WireGuardGetAdapterState) ||
        X(WireGuardSetAdapterState) || X(WireGuardGetConfiguration) || X(WireGuardSetConfiguration);
#undef X

    if (mapFailed)
    {
        DWORD errCode = GetLastError();
        SPDLOG_ERROR(TEXT("Map WireGuard DLL EntryPoint error: {0}"), errCode);
        FreeLibrary(g_WireGarudModule);
        // Without this, UnInitializeWireGuardLibrary() would free the module a second
        // time and the pointers resolved so far would point into unloaded code.
        g_WireGarudModule = nullptr;
        ResetWireGuardEntryPoints();
        return -ERR_MAP_LIBRARY;
    }

    return ERR_SUCCESS;
}

/**
 * Unloads wireguard.dll (if loaded) and clears all entry points. Safe to call
 * more than once.
 */
void UnInitializeWireGuardLibrary()
{
    if (g_WireGarudModule)
    {
        FreeLibrary(g_WireGarudModule);
        g_WireGarudModule = nullptr;
    }
    ResetWireGuardEntryPoints();
}

/**
 * Opens the WireGuard adapter named pTunnelName and reads its configuration.
 *
 * NOTE(review): the configuration read here is not returned to the caller, and a
 * missing adapter or a failed read still yields ERR_SUCCESS (kept for
 * compatibility); only the failed read is logged.
 *
 * @return ERR_SUCCESS; -ERR_LOAD_LIBRARY if the WireGuard library is not loaded;
 *         or the error returned by TCharToWideChar.
 */
int GetWireGuradTunnelInfo(const TCHAR *pTunnelName)
{
    // Calling through an unresolved entry point would dereference nullptr.
    if (g_WireGarudModule == nullptr)
    {
        SPDLOG_ERROR(TEXT("WireGuard library is not loaded"));
        return -ERR_LOAD_LIBRARY;
    }

    WCHAR wstrName[MAX_PATH];
    int ret = TCharToWideChar(pTunnelName, wstrName, MAX_PATH);
    if (ret != ERR_SUCCESS)
    {
        return ret;
    }

    WIREGUARD_ADAPTER_HANDLE Adapter = WireGuardOpenAdapter(wstrName);
    if (Adapter == nullptr)
    {
        return ERR_SUCCESS;
    }

    WG_CONFIG_INFO config;
    DWORD Bytes = sizeof(WG_CONFIG_INFO);
    if (!WireGuardGetConfiguration(Adapter, &config.Interface, &Bytes))
    {
        SPDLOG_ERROR("Failed to get configuration: {0}", GetLastError());
    }

    // Release the adapter handle; it used to be leaked on every call.
    WireGuardCloseAdapter(Adapter);
    return ERR_SUCCESS;
}

/**
 * Reports whether the tunnel service "WireGuardTunnel$<pTunnelName>" is installed.
 *
 * NOTE(review): *pIsRunning becomes true whenever the service can be opened,
 * i.e. when it EXISTS, regardless of its current state. This is the original
 * behavior; query the service status if "running" is actually meant.
 *
 * @param pTunnelName non-empty tunnel name.
 * @param pIsRunning  receives the result; set to false on every path past validation.
 * @return ERR_SUCCESS, -ERR_INPUT_PARAMS or -ERR_OPEN_SCM.
 */
int GetWireGuardServiceStatus(const TCHAR *pTunnelName, bool *pIsRunning)
{
    TCHAR svrName[MAX_PATH];

    if (pTunnelName == nullptr || lstrlen(pTunnelName) == 0)
    {
        SPDLOG_ERROR(TEXT("Input pTunnelName error: {0}"), SafeLogStr(pTunnelName));
        return -ERR_INPUT_PARAMS;
    }

    if (pIsRunning == nullptr)
    {
        SPDLOG_ERROR(TEXT("Input pIsRunning params error"));
        return -ERR_INPUT_PARAMS;
    }

    *pIsRunning = false;

    if (FAILED(FormatTunnelServiceName(pTunnelName, svrName, sizeof(svrName))))
    {
        SPDLOG_ERROR(TEXT("Tunnel name too long: {0}"), pTunnelName);
        return -ERR_INPUT_PARAMS;
    }

    // Connect/query rights are enough to test for existence; ALL_ACCESS needlessly
    // required administrator rights.
    ScopedScHandle scm(OpenSCManager(nullptr, nullptr, SC_MANAGER_CONNECT));
    if (!scm)
    {
        SPDLOG_ERROR(TEXT("OpenSCManager failed ({0})"), GetLastError());
        return -ERR_OPEN_SCM;
    }

    ScopedScHandle service(OpenService(scm.get(), svrName, SERVICE_QUERY_STATUS));
    if (service)
    {
        *pIsRunning = true;
    }
    else
    {
        DWORD errCode = GetLastError();
        // A missing service is the normal "not installed" answer; anything else is worth logging.
        if (errCode != ERROR_SERVICE_DOES_NOT_EXIST)
        {
            SPDLOG_ERROR(TEXT("Open Service {1} failed ({0})"), errCode, svrName);
        }
    }

    return ERR_SUCCESS;
}

/**
 * Stops (best effort) and deletes the tunnel service "WireGuardTunnel$<pTunnelName>".
 *
 * @param pTunnelName non-empty tunnel name.
 * @param bIsWaitStop when true, poll up to kStopPollCount times (kStopPollIntervalMs
 *                    apart) for SERVICE_STOPPED before deleting; deletion proceeds
 *                    even if the service has not stopped by then.
 * @return ERR_SUCCESS if the service is gone or marked for deletion (including when it
 *         never existed); otherwise -ERR_INPUT_PARAMS, -ERR_OPEN_SCM, -ERR_STOP_SERVICE
 *         or -ERR_DELETE_SERVICE.
 */
int RemoveGuardService(const TCHAR *pTunnelName, bool bIsWaitStop)
{
    const int kStopPollCount = 10;
    const DWORD kStopPollIntervalMs = 1000;

    TCHAR svrName[MAX_PATH];
    SERVICE_STATUS svrStatus;

    if (pTunnelName == nullptr || lstrlen(pTunnelName) == 0)
    {
        SPDLOG_ERROR(TEXT("Input pTunnelName error: {0}"), SafeLogStr(pTunnelName));
        return -ERR_INPUT_PARAMS;
    }

    // A truncated name could match (and delete) a different tunnel's service.
    if (FAILED(FormatTunnelServiceName(pTunnelName, svrName, sizeof(svrName))))
    {
        SPDLOG_ERROR(TEXT("Tunnel name too long: {0}"), pTunnelName);
        return -ERR_INPUT_PARAMS;
    }

    ScopedScHandle scm(OpenSCManager(nullptr, nullptr, SC_MANAGER_ALL_ACCESS));
    if (!scm)
    {
        SPDLOG_ERROR(TEXT("OpenSCManager failed ({0})"), GetLastError());
        return -ERR_OPEN_SCM;
    }

    ScopedScHandle service(OpenService(scm.get(), svrName, SERVICE_ALL_ACCESS));
    if (!service)
    {
        DWORD errCode = GetLastError();
        // If the service does not exist there is nothing to remove.
        if (errCode == ERROR_SERVICE_DOES_NOT_EXIST || errCode == ERROR_SERVICE_MARKED_FOR_DELETE)
        {
            return ERR_SUCCESS;
        }
        // Any other failure (e.g. access denied) means the service was NOT removed.
        SPDLOG_ERROR(TEXT("Open Service {1} failed ({0})"), errCode, svrName);
        return -ERR_DELETE_SERVICE;
    }

    if (!ControlService(service.get(), SERVICE_CONTROL_STOP, &svrStatus))
    {
        DWORD errCode = GetLastError();
        // "Not running" and "cannot accept STOP right now" are tolerated: the
        // service is deleted below either way.
        if (errCode != ERROR_SERVICE_CANNOT_ACCEPT_CTRL && errCode != ERROR_SERVICE_NOT_ACTIVE)
        {
            SPDLOG_ERROR(TEXT("Stop Service {1} failed ({0})"), errCode, svrName);
            return -ERR_STOP_SERVICE;
        }
    }

    for (int attempt = 0; bIsWaitStop && attempt < kStopPollCount; attempt++)
    {
        SERVICE_STATUS_PROCESS ssStatus;
        DWORD dwBytesNeeded = 0;

        // The service has already stopped.
        if (QueryServiceStatusEx(service.get(), SC_STATUS_PROCESS_INFO, reinterpret_cast<LPBYTE>(&ssStatus),
                                 sizeof(SERVICE_STATUS_PROCESS), &dwBytesNeeded) &&
            ssStatus.dwCurrentState == SERVICE_STOPPED)
        {
            break;
        }

        Sleep(kStopPollIntervalMs);
    }

    if (!DeleteService(service.get()))
    {
        DWORD errCode = GetLastError();
        // Already marked for deletion counts as success.
        if (errCode != ERROR_SERVICE_MARKED_FOR_DELETE)
        {
            SPDLOG_ERROR(TEXT("Delete Service {1} failed ({0})"), errCode, svrName);
            return -ERR_DELETE_SERVICE;
        }
    }

    return ERR_SUCCESS;
}

/**
 * (Re)creates and starts the tunnel service "WireGuardTunnel$<pInterfaceName>".
 * Its command line is:
 *   "<workDirectory>\NetTunnelSvr.exe" /service "<pWGConfigFilePath>"
 * An existing service with the same name is stopped and deleted first.
 *
 * Reference layout for a WireGuard tunnel service:
 *   Service Name: "WireGuardTunnel$SomeTunnelName"
 *   Display Name: "Some Service Name"
 *   Service Type: SERVICE_WIN32_OWN_PROCESS
 *   Start Type:   StartAutomatic
 *   Error Control: ErrorNormal
 *   Dependencies: [ "Nsi", "TcpIp" ]
 *   Sid Type:     SERVICE_SID_TYPE_UNRESTRICTED
 *   Executable:   "C:\path\to\example\vpnclient.exe /service configfile.conf"
 * NOTE(review): this code creates the service with SERVICE_DEMAND_START, not
 * automatic start. If configuring or starting fails after CreateService, the
 * partially configured service is left installed.
 *
 * @return ERR_SUCCESS; or -ERR_INPUT_PARAMS, -ERR_ITEM_UNEXISTS, -ERR_OPEN_SCM,
 *         -ERR_CREATE_SERVICE, -ERR_CONFIG_SERVICE, -ERR_START_SERVICE, or any
 *         error returned by RemoveGuardService.
 */
int CreateWireGuardService(const TCHAR *pInterfaceName, const TCHAR *pWGConfigFilePath)
{
    TCHAR svrName[MAX_PATH];
    TCHAR displayName[MAX_PATH];
    TCHAR svrPath[MAX_PATH];
    TCHAR execParams[MAX_PATH * 2];

    if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0)
    {
        SPDLOG_ERROR(TEXT("Input pInterfaceName error: {0}"), SafeLogStr(pInterfaceName));
        return -ERR_INPUT_PARAMS;
    }

    if (pWGConfigFilePath == nullptr || lstrlen(pWGConfigFilePath) == 0)
    {
        SPDLOG_ERROR(TEXT("Input pWGConfigFilePath error: {0}"), SafeLogStr(pWGConfigFilePath));
        return -ERR_INPUT_PARAMS;
    }

    if (!PathFileExists(pWGConfigFilePath))
    {
        SPDLOG_ERROR(TEXT("WireGuard Configure File Not Exist: {0}"), pWGConfigFilePath);
        return -ERR_ITEM_UNEXISTS;
    }

    // StringCbPrintf sizes are in BYTES. A truncated service name or command line
    // would silently target the wrong service or file, so treat truncation as an error.
    if (FAILED(FormatTunnelServiceName(pInterfaceName, svrName, sizeof(svrName))) ||
        FAILED(StringCbPrintf(displayName, sizeof(displayName), TEXT("WireGuard Tunnel Service %s"),
                              pInterfaceName)) ||
        FAILED(StringCbPrintf(svrPath, sizeof(svrPath), TEXT("%s\\NetTunnelSvr.exe"),
                              GetGlobalCfgInfo()->workDirectory)) ||
        FAILED(StringCbPrintf(execParams, sizeof(execParams), TEXT("\"%s\" /service \"%s\""), svrPath,
                              pWGConfigFilePath)))
    {
        SPDLOG_ERROR(TEXT("WireGuard Service parameters too long: {0}"), pInterfaceName);
        return -ERR_INPUT_PARAMS;
    }

    if (!PathFileExists(svrPath))
    {
        SPDLOG_ERROR(TEXT("WireGuard Service Not Exist: {0}"), svrPath);
        return -ERR_ITEM_UNEXISTS;
    }

    ScopedScHandle scm(OpenSCManager(nullptr, nullptr, SC_MANAGER_ALL_ACCESS));
    if (!scm)
    {
        SPDLOG_ERROR(TEXT("OpenSCManager failed ({0})"), GetLastError());
        return -ERR_OPEN_SCM;
    }

    // If the service already exists, remove it first. Only the service handle is
    // released here: the SCM handle is still needed by CreateService below (it
    // used to be closed and then reused).
    SC_HANDLE hExisting = OpenService(scm.get(), svrName, SERVICE_QUERY_STATUS);
    if (hExisting != nullptr)
    {
        // Close our handle before deleting: SCM deletion completes only once all handles are closed.
        CloseServiceHandle(hExisting);
        int ret = RemoveGuardService(pInterfaceName, true);
        if (ret != ERR_SUCCESS)
        {
            return ret;
        }
    }

    // lpDependencies is a NUL-separated list terminated by an EXTRA NUL. The
    // literal's implicit terminator supplies only one, hence the trailing "\0".
    ScopedScHandle service(CreateService(scm.get(), svrName, displayName, SERVICE_ALL_ACCESS,
                                         SERVICE_WIN32_OWN_PROCESS, SERVICE_DEMAND_START, SERVICE_ERROR_NORMAL,
                                         execParams, nullptr, nullptr, TEXT("Nsi\0TcpIp\0"), nullptr, nullptr));
    if (!service)
    {
        SPDLOG_ERROR(TEXT("Create Service {1} failed ({0})"), GetLastError(), svrName);
        return -ERR_CREATE_SERVICE;
    }

    SERVICE_SID_INFO info;
    info.dwServiceSidType = SERVICE_SID_TYPE_UNRESTRICTED;
    if (!ChangeServiceConfig2(service.get(), SERVICE_CONFIG_SERVICE_SID_INFO, &info))
    {
        SPDLOG_ERROR(TEXT("Change Service {1} SERVICE_CONFIG_SERVICE_SID_INFO Configure failed ({0})"),
                     GetLastError(), svrName);
        return -ERR_CONFIG_SERVICE;
    }

    // SERVICE_DESCRIPTION and TEXT() both follow the TCHAR build setting, so they
    // match whichever ChangeServiceConfig2 A/W variant is actually called.
    SERVICE_DESCRIPTION desc;
    desc.lpDescription = const_cast<LPTSTR>(TEXT("SCC Tunnel Service over WireGuard"));
    if (!ChangeServiceConfig2(service.get(), SERVICE_CONFIG_DESCRIPTION, &desc))
    {
        SPDLOG_ERROR(TEXT("Change Service {1} SERVICE_CONFIG_DESCRIPTION Configure failed ({0})"),
                     GetLastError(), svrName);
        return -ERR_CONFIG_SERVICE;
    }

    if (!StartService(service.get(), 0, nullptr))
    {
        SPDLOG_ERROR(TEXT("Start Service {1} failed ({0})"), GetLastError(), svrName);
        return -ERR_START_SERVICE;
    }

    return ERR_SUCCESS;
}