commit 699beabd78df7b71217a883bd9816453bb0b4c32 Author: 黄昕 Date: Tue Aug 22 15:12:52 2023 +0800 1. 初始化工程 diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..a96422f --- /dev/null +++ b/.clang-format @@ -0,0 +1,220 @@ +# ClangFormatConfigureSource: 'clang-format-file://D:/development/c/daemon_agent/.clang-format' +Language: Cpp +AccessModifierOffset: -4 +InsertBraces: true +AlignArrayOfStructures: Left +AlignAfterOpenBracket: Align +AlignConsecutiveMacros: + Enabled: true + AcrossEmptyLines: true + AcrossComments: true +AlignConsecutiveAssignments: + Enabled: true + AcrossEmptyLines: false + AcrossComments: true + PadOperators: true + AlignCompound: true +AlignConsecutiveBitFields: None +AlignConsecutiveDeclarations: + Enabled: true + AcrossEmptyLines: false + AcrossComments: true + PadOperators: true + AlignCompound: true +AlignEscapedNewlines: Left +AlignOperands: DontAlign +AlignTrailingComments: true +AllowAllArgumentsOnNextLine: false +AllowAllConstructorInitializersOnNextLine: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortEnumsOnASingleLine: false +AllowShortBlocksOnASingleLine: Always +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: None +AllowShortLambdasOnASingleLine: All +AllowShortIfStatementsOnASingleLine: Never +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: MultiLine +AttributeMacros: + - __capability + - __unused +BinPackArguments: true +BinPackParameters: false +BraceWrapping: + AfterCaseLabel: false + AfterClass: false + AfterControlStatement: Never + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: true + BeforeElse: false + BeforeLambdaBody: false + BeforeWhile: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakBeforeBinaryOperators: None +BreakBeforeConceptDeclarations: true +BreakBeforeBraces: Custom +BreakBeforeInheritanceComma: false +BreakInheritanceList: BeforeColon +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: true +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DeriveLineEnding: true +DerivePointerAlignment: false +DisableFormat: false +EmptyLineBeforeAccessModifier: LogicalBlock +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +StatementAttributeLikeMacros: + - Q_EMIT +IncludeBlocks: Regroup +IncludeCategories: + - Regex: '^' + Priority: 2 + SortPriority: 0 + CaseSensitive: false + - Regex: '^<.*\.h>' + Priority: 1 + SortPriority: 0 + CaseSensitive: false + - Regex: '^<.*' + Priority: 2 + SortPriority: 0 + CaseSensitive: false + - Regex: '.*' + Priority: 3 + SortPriority: 0 + CaseSensitive: false +IncludeIsMainRegex: '([-_](test|unittest))?$' +IncludeIsMainSourceRegex: '' +IndentCaseLabels: true +IndentCaseBlocks: false +IndentGotoLabels: true +IndentPPDirectives: None +IndentExternBlock: AfterExternBlock +IndentRequiresClause: false +IndentWidth: 4 +IndentWrappedFunctionNames: false +InsertTrailingCommas: None +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Never +ObjCBlockIndentWidth: 2 +ObjCBreakBeforeNestedBlockParam: true +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 1000 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 140 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PenaltyIndentedWhitespace: 0 +PointerAlignment: Right +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' + BasedOnStyle: google + - Language: TextProto + Delimiters: + - pb + - PB + - proto + - PROTO + EnclosingFunctions: + - EqualsProto + - EquivToProto + - PARSE_PARTIAL_TEXT_PROTO + - PARSE_TEST_PROTO + - PARSE_TEXT_PROTO + - ParseTextOrDie + - ParseTextProtoOrDie + - ParseTestProto + - ParsePartialTestProto + CanonicalDelimiter: '' + BasedOnStyle: google +ReflowComments: false +SortIncludes: Never +SortJavaStaticImport: Before +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeCaseColon: false +SpaceBeforeCpp11BracedList: true +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceAroundPointerQualifiers: Default +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 4 +SpacesInAngles: false +SpacesInConditionalStatement: false +SpacesInContainerLiterals: false +SpacesInCStyleCastParentheses: false +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParentheses: false +SpacesInSquareBrackets: false +SpaceBeforeSquareBrackets: false +BitFieldColonSpacing: Both +Standard: Auto +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 4 +UseCRLF: false +UseTab: Never +SeparateDefinitionBlocks: Always +WhitespaceSensitiveMacros: + - STRINGIZE + - PP_STRINGIZE + - BOOST_PP_STRINGIZE + - NS_SWIFT_NAME + - CF_SWIFT_NAME +TypenameMacros: + - CONFIG_ITEM + - PCONFIG_ITEM diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..14536fe --- /dev/null +++ b/.gitignore @@ -0,0 +1,133 @@ +### C++ template +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +### CMake template +CMakeLists.txt.user +CMakeCache.txt +CMakeFiles +CMakeScripts +Testing +Makefile +cmake_install.cmake +install_manifest.txt +compile_commands.json +CTestTestfile.cmake +_deps + +### Example user template template +### Example user template + +# IntelliJ project files +.idea +*.iml +out +gen +### CLion template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..e702b79 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,9 @@ +cmake_minimum_required(VERSION 3.22) +project(scc) + +set(CMAKE_CXX_STANDARD 23) +add_definitions(-D_UNICODE) + +ADD_SUBDIRECTORY(NetTunnelSvr) +ADD_SUBDIRECTORY(NetTunnelServerApp) +ADD_SUBDIRECTORY(NetTunnelSDK) diff --git a/NetTunnelSDK/CMakeLists.txt b/NetTunnelSDK/CMakeLists.txt new file mode 100644 index 0000000..a12937f --- /dev/null +++ b/NetTunnelSDK/CMakeLists.txt @@ -0,0 +1,44 @@ +cmake_minimum_required(VERSION 3.22) +project(NetTunnelSDK) + +set(CMAKE_CXX_STANDARD 23) + +find_path(CPPCODEC_INCLUDE_DIRS "cppcodec/base32_crockford.hpp") +INCLUDE_DIRECTORIES(include ./include/json ./include/httplib ../depends/WireGuardNT/include ${CPPCODEC_INCLUDE_DIRS}) +FILE(GLOB CPP_HEADS ./include/*.h ./include/json/AIGCJson.hpp ./include/httplib/httplib.h ../depends/WireGuardNT/include/*.h ${CPPCODEC_INCLUDE_DIRS}/*.hpp) + +ADD_DEFINITIONS(-DNETTUNNELSDK_EXPORTS) +AUX_SOURCE_DIRECTORY(tunnel CPP_SRC) +AUX_SOURCE_DIRECTORY(crypto CPP_SRC) +AUX_SOURCE_DIRECTORY(misc CPP_SRC) +AUX_SOURCE_DIRECTORY(network CPP_SRC) +AUX_SOURCE_DIRECTORY(protocol CPP_SRC) +AUX_SOURCE_DIRECTORY(user CPP_SRC) + +find_package(spdlog CONFIG REQUIRED) +find_package(magic_enum CONFIG REQUIRED) +find_package(OpenSSL REQUIRED) +find_package(RapidJSON CONFIG REQUIRED) + +ADD_LIBRARY(NetTunnelSDK SHARED dllmain.cpp ${CPP_SRC} ${CPP_HEADS}) +target_link_libraries(NetTunnelSDK PRIVATE spdlog::spdlog) +target_link_libraries(NetTunnelSDK PRIVATE magic_enum::magic_enum) +target_link_libraries(NetTunnelSDK PRIVATE OpenSSL::SSL OpenSSL::Crypto) +target_link_libraries(NetTunnelSDK PRIVATE rapidjson) + +ADD_CUSTOM_COMMAND(TARGET NetTunnelSDK + PRE_BUILD + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + COMMENT "!!!!!! Notice: Clearup SDK includes." + COMMAND ${CMAKE_COMMAND} -E make_directory "${PROJECT_SOURCE_DIR}/sdk" + COMMAND ../scripts/cleansdk.bat) + +ADD_CUSTOM_COMMAND(TARGET NetTunnelSDK + POST_BUILD + WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} + COMMENT "!!!!!! Notice: Create SDK includes." + COMMAND ${CMAKE_COMMAND} -E make_directory "${PROJECT_SOURCE_DIR}/sdk" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${PROJECT_SOURCE_DIR}/include/sccsdk.h" "${PROJECT_SOURCE_DIR}/sdk" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${PROJECT_SOURCE_DIR}/include/common.h" "${PROJECT_SOURCE_DIR}/sdk" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${PROJECT_SOURCE_DIR}/include/usrerr.h" "${PROJECT_SOURCE_DIR}/sdk" + COMMAND ../scripts/gensdk.bat) \ No newline at end of file diff --git a/NetTunnelSDK/crypto/HashDigest.cpp b/NetTunnelSDK/crypto/HashDigest.cpp new file mode 100644 index 0000000..d276b41 --- /dev/null +++ b/NetTunnelSDK/crypto/HashDigest.cpp @@ -0,0 +1,288 @@ +#include "pch.h" +#include "tunnel.h" +#include "usrerr.h" +#include "misc.h" + +#include +#include +#include +#include +#include +#include + +#pragma comment(lib, "Bcrypt.lib") +#pragma comment(lib, "Crypt32.lib") + +#define NT_FAILED(s) (((NTSTATUS)(s)) < 0) + +static const LPCWSTR g_BcryptHash[] = { + BCRYPT_MD2_ALGORITHM, + BCRYPT_MD4_ALGORITHM, + BCRYPT_MD5_ALGORITHM, + BCRYPT_SHA1_ALGORITHM, + BCRYPT_SHA256_ALGORITHM, + BCRYPT_SHA384_ALGORITHM, + BCRYPT_SHA512_ALGORITHM, +}; + +int CalcFileHash(HASH_TYPE type, const TCHAR *pPath, TCHAR outHash[]) { + HANDLE hFile; + BYTE rgbFile[1024]; + DWORD cbRead = 0; + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_HASH_HANDLE hHash = nullptr; + NTSTATUS status; + DWORD cbData = 0, cbHash = 0, cbHashObject = 0; + PBYTE pbHashObject; + PBYTE pbHash; + + if (pPath == nullptr) { + SPDLOG_ERROR(TEXT("Input pPath params error: {0}"), pPath); + return -ERR_INPUT_PARAMS; + } + + if (!PathFileExists(pPath)) { + SPDLOG_ERROR(TEXT("File \'{0}\' not found."), pPath); + return -ERR_ITEM_UNEXISTS; + } + + hFile = CreateFile(pPath, + GENERIC_READ, + FILE_SHARE_READ, + nullptr, + OPEN_EXISTING, + FILE_FLAG_SEQUENTIAL_SCAN, + nullptr); + + if (INVALID_HANDLE_VALUE == hFile) { + SPDLOG_ERROR(TEXT("Error opening file %s\nError: {0}"), pPath, GetLastError()); + return -ERR_OPEN_FILE; + } + + //open an algorithm handle + if (NT_FAILED(status = BCryptOpenAlgorithmProvider(&hAlg, g_BcryptHash[type], nullptr, 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptOpenAlgorithmProvider"), status); + CloseHandle(hFile); + return -ERR_BCRYPT_OPEN; + } + + //calculate the size of the buffer to hold the hash object + if (NT_FAILED(status = BCryptGetProperty(hAlg, + BCRYPT_OBJECT_LENGTH, + reinterpret_cast(&cbHashObject), + sizeof(DWORD), + &cbData, + 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptGetProperty"), status); + CloseHandle(hFile); + BCryptCloseAlgorithmProvider(hAlg, 0); + return -ERR_BCRYPT_GETPROPERTY; + } + + //allocate the hash object on the heap + pbHashObject = static_cast(HeapAlloc(GetProcessHeap(), 0, cbHashObject)); + if (nullptr == pbHashObject) { + SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), cbHashObject); + CloseHandle(hFile); + BCryptCloseAlgorithmProvider(hAlg, 0); + return -ERR_MALLOC_MEMORY; + } + + //calculate the length of the hash + if (NT_FAILED(status = BCryptGetProperty(hAlg, + BCRYPT_HASH_LENGTH, + reinterpret_cast(&cbHash), + sizeof(DWORD), + &cbData, + 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptGetProperty"), status); + CloseHandle(hFile); + BCryptCloseAlgorithmProvider(hAlg, 0); + HeapFree(GetProcessHeap(), 0, pbHashObject); + return -ERR_BCRYPT_GETPROPERTY; + } + + //allocate the hash buffer on the heap + pbHash = static_cast(HeapAlloc(GetProcessHeap(), 0, cbHash)); + if (nullptr == pbHash) { + SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), cbHash); + CloseHandle(hFile); + BCryptCloseAlgorithmProvider(hAlg, 0); + HeapFree(GetProcessHeap(), 0, pbHashObject); + return -ERR_MALLOC_MEMORY; + } + + //create a hash + if (NT_FAILED(status = BCryptCreateHash(hAlg, &hHash, pbHashObject, cbHashObject, nullptr, 0, 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptCreateHash"), status); + CloseHandle(hFile); + BCryptCloseAlgorithmProvider(hAlg, 0); + HeapFree(GetProcessHeap(), 0, pbHashObject); + HeapFree(GetProcessHeap(), 0, pbHash); + return -ERR_BCRYPT_CREATEHASH; + } + + while (ReadFile(hFile, rgbFile, 1024, &cbRead, nullptr)) { + if (0 == cbRead) { + break; + } + + if (NT_FAILED(status = BCryptHashData(hHash, rgbFile, cbRead, 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptHashData"), status); + CloseHandle(hFile); + BCryptCloseAlgorithmProvider(hAlg, 0); + BCryptDestroyHash(hHash); + HeapFree(GetProcessHeap(), 0, pbHashObject); + HeapFree(GetProcessHeap(), 0, pbHash); + return -ERR_BCRYPT_HASHDATA; + } + } + + //close the hash + if (NT_FAILED(status = BCryptFinishHash(hHash, pbHash, cbHash, 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptFinishHash"), status); + CloseHandle(hFile); + BCryptCloseAlgorithmProvider(hAlg, 0); + BCryptDestroyHash(hHash); + HeapFree(GetProcessHeap(), 0, pbHashObject); + HeapFree(GetProcessHeap(), 0, pbHash); + return -ERR_BCRYPT_FINISHHASH; + } + + binToHexString(outHash, pbHash, cbHash); + + BCryptCloseAlgorithmProvider(hAlg, 0); + BCryptDestroyHash(hHash); + HeapFree(GetProcessHeap(), 0, pbHashObject); + HeapFree(GetProcessHeap(), 0, pbHash); + CloseHandle(hFile); + + return ERR_SUCCESS; +} + +/** + * @brief 计算 HMAC HASH 值 + * @param[in] type Hash 类型 @see HASH_TYPE + * @param[in] pHashData 需要计算 Hash 值的数据 + * @param[in] inSize 需要计算 Hash 值的数据大小(字节数) + * @param[in] pKey HMAC Hash 秘钥 + * @param[in] keySize HMAC Hash 秘钥大小(字节数) + * @param[out] outHash 计算结果 + * @param[in] outBase64 是否以 BASE64 字符串输出 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 文件不存在 + * - -ERR_OPEN_FILE 打开文件失败 + * - -ERR_BCRYPT_OPEN 创建加解密算法失败 + * - -ERR_BCRYPT_GETPROPERTY 获取加解密算法属性失败 + * - -ERR_BCRYPT_CREATEHASH 创建 Hash 算法失败 + * - -ERR_BCRYPT_HASHDATA 计算 Hash 数据失败 + * - -ERR_BCRYPT_FINISHHASH 计算 Hash 结果失败 + * - ERR_SUCCESS 成功 + */ +int CalcHmacHash(HASH_TYPE type, + PUCHAR pHashData, + int inSize, + PUCHAR pKey, + int keySize, + TCHAR outHash[], + bool outBase64) { + BCRYPT_ALG_HANDLE hAlg = nullptr; + BCRYPT_HASH_HANDLE hHash = nullptr; + NTSTATUS status; + DWORD cbData = 0, cbHash = 0, cbHashObject = 0; + PBYTE pbHashObject; + PBYTE pbHash; + + //open an algorithm handle + if (NT_FAILED( + status = BCryptOpenAlgorithmProvider(&hAlg, g_BcryptHash[type], nullptr, BCRYPT_ALG_HANDLE_HMAC_FLAG))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptOpenAlgorithmProvider"), status); + return -ERR_BCRYPT_OPEN; + } + + //calculate the size of the buffer to hold the hash object + if (NT_FAILED(status = BCryptGetProperty(hAlg, + BCRYPT_OBJECT_LENGTH, + reinterpret_cast(&cbHashObject), + sizeof(DWORD), + &cbData, + 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptGetProperty"), status); + BCryptCloseAlgorithmProvider(hAlg, 0); + return -ERR_BCRYPT_GETPROPERTY; + } + + //allocate the hash object on the heap + pbHashObject = static_cast(HeapAlloc(GetProcessHeap(), 0, cbHashObject)); + if (nullptr == pbHashObject) { + SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), cbHashObject); + BCryptCloseAlgorithmProvider(hAlg, 0); + return -ERR_MALLOC_MEMORY; + } + + //calculate the length of the hash + if (NT_FAILED(status = BCryptGetProperty(hAlg, + BCRYPT_HASH_LENGTH, + reinterpret_cast(&cbHash), + sizeof(DWORD), + &cbData, + 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptGetProperty"), status); + BCryptCloseAlgorithmProvider(hAlg, 0); + HeapFree(GetProcessHeap(), 0, pbHashObject); + return -ERR_BCRYPT_GETPROPERTY; + } + + //allocate the hash buffer on the heap + pbHash = static_cast(HeapAlloc(GetProcessHeap(), 0, cbHash)); + if (nullptr == pbHash) { + SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), cbHash); + BCryptCloseAlgorithmProvider(hAlg, 0); + HeapFree(GetProcessHeap(), 0, pbHashObject); + return -ERR_MALLOC_MEMORY; + } + + //create a hash + if (NT_FAILED(status = BCryptCreateHash(hAlg, &hHash, pbHashObject, cbHashObject, pKey, keySize, 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptCreateHash"), status); + BCryptCloseAlgorithmProvider(hAlg, 0); + HeapFree(GetProcessHeap(), 0, pbHashObject); + HeapFree(GetProcessHeap(), 0, pbHash); + return -ERR_BCRYPT_CREATEHASH; + } + + //hash some data + if (NT_FAILED(status = BCryptHashData(hHash, pHashData, inSize, 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptHashData"), status); + BCryptCloseAlgorithmProvider(hAlg, 0); + BCryptDestroyHash(hHash); + HeapFree(GetProcessHeap(), 0, pbHashObject); + HeapFree(GetProcessHeap(), 0, pbHash); + return -ERR_BCRYPT_HASHDATA; + } + + //close the hash + if (NT_FAILED(status = BCryptFinishHash(hHash, pbHash, cbHash, 0))) { + SPDLOG_ERROR(TEXT("Error {0} returned by BCryptFinishHash"), status); + BCryptCloseAlgorithmProvider(hAlg, 0); + BCryptDestroyHash(hHash); + HeapFree(GetProcessHeap(), 0, pbHashObject); + HeapFree(GetProcessHeap(), 0, pbHash); + return -ERR_BCRYPT_FINISHHASH; + } + + if (outBase64) { + using base64 = cppcodec::base64_url_unpadded; + StringCbCopy(outHash, 256, base64::encode(pbHash, cbHash).c_str()); + } else { + binToHexString(outHash, pbHash, cbHash); + } + + BCryptCloseAlgorithmProvider(hAlg, 0); + BCryptDestroyHash(hHash); + HeapFree(GetProcessHeap(), 0, pbHashObject); + HeapFree(GetProcessHeap(), 0, pbHash); + + return ERR_SUCCESS; +} \ No newline at end of file diff --git a/NetTunnelSDK/crypto/crypto.cpp b/NetTunnelSDK/crypto/crypto.cpp new file mode 100644 index 0000000..1989e1c --- /dev/null +++ b/NetTunnelSDK/crypto/crypto.cpp @@ -0,0 +1,9 @@ +#include "pch.h" + +#include "tunnel.h" +#include "usrerr.h" +#include "misc.h" + +#include +#include +#include diff --git a/NetTunnelSDK/dllmain.cpp b/NetTunnelSDK/dllmain.cpp new file mode 100644 index 0000000..9448f1c --- /dev/null +++ b/NetTunnelSDK/dllmain.cpp @@ -0,0 +1,17 @@ +// +// Created by HuangXin on 2023/8/22. +// +#include + +BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call, LPVOID lpReserved) { + switch (ul_reason_for_call) { + case DLL_PROCESS_ATTACH: + case DLL_THREAD_ATTACH: + case DLL_THREAD_DETACH: + case DLL_PROCESS_DETACH: + break; + default: + break; + } + return TRUE; +} \ No newline at end of file diff --git a/NetTunnelSDK/include/ProtocolBase.h b/NetTunnelSDK/include/ProtocolBase.h new file mode 100644 index 0000000..c96bece --- /dev/null +++ b/NetTunnelSDK/include/ProtocolBase.h @@ -0,0 +1,56 @@ +#pragma once +#include "AIGCJson.hpp" + +#define USER_REAL_PLATFORM (0) + +class ProtocolBase { +public: + ProtocolBase() { + this->ver = 1; + this->timeStamp = static_cast(time(nullptr)); + this->cryptoType = 0; + } + + unsigned int ver; + unsigned int cryptoType; + unsigned int timeStamp; + + AIGC_JSON_HELPER(ver, cryptoType, timeStamp) + + void SetVersion(unsigned int versino) { + this->ver = versino; + } + + void SetTimeStamp(unsigned int ts) { + this->timeStamp = ts; + } + + void SetCryptoType(unsigned int crypto) { + this->cryptoType = crypto; + } +}; + +class ResponseStatus { +public: + int errCode; + std::string errMessage; + + AIGC_JSON_HELPER(errCode, errMessage) +}; + +template class ProtocolRequest : public ProtocolBase { +public: + T msgContent; + + AIGC_JSON_HELPER(msgContent) + AIGC_JSON_HELPER_BASE((ProtocolBase *)this) +}; + +template class ProtocolResponse : public ProtocolBase { +public: + int code; + T msgContent; + + AIGC_JSON_HELPER(code, msgContent) + AIGC_JSON_HELPER_BASE((ProtocolBase *)this) +}; \ No newline at end of file diff --git a/NetTunnelSDK/include/common.h b/NetTunnelSDK/include/common.h new file mode 100644 index 0000000..387c1dc --- /dev/null +++ b/NetTunnelSDK/include/common.h @@ -0,0 +1,142 @@ +#pragma once + +#define USED_PORTMAP_TUNNEL (1) +/** + * @brief WireGuard key 最大长度 + */ +#define WG_KEY_MAX (64) + +/** + * @brief 操作系统最大网卡数 + */ +#define NET_CARD_MAX (32) + +/** + * @brief IP 字符串最大长度(支持IPv6) + */ +#define MAX_IP_LEN (48) + +/** + * @brief IP 字符串最小长度 + */ +#define MIN_IP_LEN (7) + +/** + * @brief 网卡名称字符串最大长度(支持IPv6) + */ +#define MAX_NETCARD_NAME (64) + +/** + * @brief SCG 服务 ID + * + */ +typedef enum { + WG_TUNNEL_SCG_ID = 3, ///< 隧道服务 + WG_CTRL_SCG_ID = 4 ///< 隧道控制服务 +} SCG_SVR_ID; + +/** + * @brief 协议加密类型 + * + */ +typedef enum { + CRYPTO_NONE = 0, ///< 不加密 + CRYPTO_BASE64 = 1, ///< BASE64 字符串编码 + CRYPTO_AES128 = 2, ///< AES 128位秘钥 加密 + CRYPTO_3DES = 3, ///< 3DES 加密 + CRYPTO_AES256 = 4, ///< AES 256 位秘钥加密 + CRYPTO_MAX, +} PROTO_CRYPTO_TYPE; + +/** + * @brief 网络连接状态 + * + */ +typedef enum { + STATUS_DISCONNECTED = 0, ///< 连接已断开连接 + STATUS_CONNECTING, ///< 连接正在进行连接 + STATUS_CONNECTED, ///< 连接处于连接状态 + STATUS_DISCONNECTING, ///< 连接正在断开连接 + STATUS_HARDWARE_NOT_PRESENT, ///< 连接的硬件(例如网络接口卡 (NIC) )不存在 + STATUS_HARDWARE_DISABLED, ///< 连接的硬件存在,但未启用 + STATUS_HARDWARE_MALFUNCTION, ///< 连接的硬件中发生了故障 + STATUS_MEDIA_DISCONNECTED, ///< 媒体(例如网络电缆)断开连接 + STATUS_AUTHENTICATING, ///< 连接正在等待身份验证发生 + STATUS_AUTHENTICATION_SUCCEEDED, ///< 身份验证在此连接上成功 + STATUS_AUTHENTICATION_FAILED, ///< 此连接上身份验证失败 + STATUS_INVALID_ADDRESS, ///< 地址无效 + STATUS_CREDENTIALS_REQUIRED, ///< 需要安全凭据 + STATUS_ACTION_REQUIRED, ///< 连接需要其它动作 + STATUS_ACTION_REQUIRED_RETRY, ///< 重试连接其它动作 + STATUS_CONNECT_FAILED, ///< 连接失败 +} NET_CONNECT_STATUS; + +/** + * @brief 日志等级 + * + */ +enum LOG_LEVEL { + LOG_TRACE = 0, ///< TRACE 日志等级 + LOG_DEBUG, ///< DEBUG 日志等级 + LOG_INFO, ///< INFO 日志等级 + LOG_WARN, ///< WARN 日志等级 + LOG_ERROR, ///< ERROR 日志等级 + LOG_CRITICAL, ///< CRITICAL 日志等级 + LOG_OFF ///< 关闭日志 +}; + +/** + * @brief Hash 算法类型 + * + */ +typedef enum { + HASH_MD2 = 0, ///< MD2 HASH 算法 + HASH_MD4, ///< MD4 HASH 算法 + HASH_MD5, ///< MD5 HASH 算法 + HASH_SHA1, ///< SHA1 HASH 算法 + HASH_SHA256, ///< SHA256 HASH 算法 + HASH_SHA384, ///< SHA384 HASH 算法 + HASH_SHA512 ///< SHA512 HASH 算法 +} HASH_TYPE; + +/** + * @brief 网络共享模式 + * + */ +typedef enum { + ICS_SHARE_MODE = 0, ///< Internet Share Mode(ICS) 模式 + NAT_SHARE_MODE = 1 ///< Net Address Translation(NAT) 模式 +} NET_SHARE_MODE; + +/** + * + * @brief 虚拟主机配置信息 + */ +typedef struct { + int vmId; ///< 用户虚拟机 ID + TCHAR vmName[MAX_PATH]; ///< 用户虚拟机名称 + TCHAR svrPublicKey[64]; ///< 用户服务端公钥 + TCHAR vmNetwork[MAX_IP_LEN]; ///< 用户虚拟机网络地址 + TCHAR scgGateWay[MAX_PATH]; ///< 用户服务端接入网关 + TCHAR scgTunnelGw[MAX_PATH]; ///< 用户隧道接入网关 +} VM_CFG, *PVM_CFG; + +/** + * + * @brief 客户端用户相关配置信息 + */ +typedef struct { + int scgCtrlAppId; ///< 用户接入网关控制 ID + int scgTunnelAppId; ///< 用户接入网关隧道 ID + TCHAR cliPrivateKey[64]; ///< 用户客户端私钥 + TCHAR cliPublicKey[64]; ///< 用户客户端公钥 + TCHAR cliAddress[MAX_IP_LEN]; ///< 用户客户端隧道IP地址 + PVM_CFG pVMConfig; ///< 用户虚拟机配置列表 + int tolVM; ///< 用户虚拟机配置最大数 +} USER_CLIENT_CONFIG, *PUSER_CLIENT_CONFIG; + +typedef struct { + int svrListenPort; ///< 用户服务端监听端口 + TCHAR svrPrivateKey[64]; ///< 用户服务端公钥 + TCHAR svrAddress[MAX_IP_LEN]; ///< 用户服务端隧道 IP 地址 +} USER_SERVER_CONFIG, *PUSER_SERVER_CONFIG; \ No newline at end of file diff --git a/NetTunnelSDK/include/globalcfg.h b/NetTunnelSDK/include/globalcfg.h new file mode 100644 index 0000000..0f3598c --- /dev/null +++ b/NetTunnelSDK/include/globalcfg.h @@ -0,0 +1,84 @@ +#pragma once +#include +#include "common.h" + +#include + +#if 0 +/** + * @brief WireGuard 配置项 + */ +typedef struct { + TCHAR wireguardPath[MAX_PATH]; ///< wireguard.exe 路径 + BOOL wireguardExists; ///< wireguard.exe 是否存在 + TCHAR wgPath[MAX_PATH]; ///< wg.exe 路径 + BOOL wgExists; ///< wg.exe 是否存在 +} WIREGUARD_CFG, *PWIREGUARD_CFG; +#endif + +/** + * @brief WireGuard 网络接口配置项 + */ +typedef struct { + TCHAR wgName[260]; ///< 网卡名称, Windows标识为 UUID + TCHAR wgIpaddr[MAX_IP_LEN]; ///< 网卡 IP 地址 + TCHAR wgNetmask[MAX_IP_LEN]; ///< 网卡子网掩码 + TCHAR wgCfgPath[MAX_PATH]; ///< 配置文件路径 +} WGINTERFACE_CFG, *PWGINTERFACE_CFG; + +/** + * @brief 用户信息配置 + */ +typedef struct { + TCHAR userName[MAX_PATH]; ///< 用户名 + TCHAR userToken[MAX_PATH]; ///< 用户访问令牌 + USER_CLIENT_CONFIG cliConfig; ///< 用户客户端配置 + USER_SERVER_CONFIG svrConfig; ///< 用户服务端配置 +} USER_CONFIG, *PUSER_CONFIG; + +typedef struct { + TCHAR targetIp[MAX_IP_LEN]; + UINT16 targetPort; + UINT16 proxyPort; + UINT16 scgGwPort; + TCHAR scgIpAddr[MAX_IP_LEN]; + SOCKET udpProxySock; + SOCKET scgGwSock; + HANDLE hProxyTunnelThread; + HANDLE hProxySCGThread; + bool exitNow; +} SCG_PROXY_INFO, *PSCG_PROXY_INFO; + +/** + * @brief SDK 全局配置项 + */ +typedef struct { + TCHAR platformServerUrl[MAX_PATH]; ///< 管理平台IP地址 + TCHAR configDirectory[MAX_PATH]; ///< 配置存放目录 + TCHAR systemDirectory[MAX_PATH]; ///< 操作系统目录 + TCHAR workDirectory[MAX_PATH]; ///< SDK 当前工作目录 + bool isWorkServer; ///< SDK 当前模式 客户端/服务端 + PROTO_CRYPTO_TYPE proCryptoType; ///< 协议加密类型 + TCHAR proKeyBuf[256]; ///< 协议加密秘钥 + BOOL enableLog; ///< 是否启用日志 + spdlog::level::level_enum logLevel; ///< 日志等级 + TCHAR cfgPath[MAX_PATH]; ///< 配置文件路径 + WGINTERFACE_CFG wgServerCfg; ///< wireguard 服务端网络接口配置 + WGINTERFACE_CFG wgClientCfg; ///< wireguard 客户端网络接口配置 + USER_CONFIG userCfg; ///< 用户配置项 + SCG_PROXY_INFO scgProxy; ///< SCG UDP 代理信息 + int curConnVmId; ///< 当前连接的VM + TCHAR clientId[MAX_PATH]; ///< 客户端验证签名 ID + TCHAR clientSecret[MAX_PATH]; ///< 客户端验证签名秘钥 +} SDK_CONFIG, *PSDK_CONFIG; + +#ifdef __cplusplus // If used by C++ code, +extern "C" { +// we need to export the C interface +#endif + +PSDK_CONFIG GetGlobalCfgInfo(); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/NetTunnelSDK/include/httplib/httplib.h b/NetTunnelSDK/include/httplib/httplib.h new file mode 100644 index 0000000..60af391 --- /dev/null +++ b/NetTunnelSDK/include/httplib/httplib.h @@ -0,0 +1,9746 @@ +// +// httplib.h +// +// Copyright (c) 2023 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +#define CPPHTTPLIB_VERSION "0.13.0" + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND +#define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_USECOND +#ifdef _WIN32 +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 10000 +#else +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 0 +#endif +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_HEADER_MAX_LENGTH +#define CPPHTTPLIB_HEADER_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT +#define CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#endif + +#ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_TCP_NODELAY +#define CPPHTTPLIB_TCP_NODELAY false +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#endif + +#ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ +#define CPPHTTPLIB_COMPRESSION_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 ? std::thread::hardware_concurrency() - 1 : 0)) +#endif + +#ifndef CPPHTTPLIB_RECV_FLAGS +#define CPPHTTPLIB_RECV_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_SEND_FLAGS +#define CPPHTTPLIB_SEND_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_LISTEN_BACKLOG +#define CPPHTTPLIB_LISTEN_BACKLOG 5 +#endif + +/* + * Headers + */ + +#ifdef _WIN32 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#if _MSC_VER < 1900 +#error Sorry, Visual Studio versions prior to 2015 are not supported +#endif + +#pragma comment(lib, "ws2_32.lib") + +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = long; +#endif +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +#ifndef strcasecmp +#define strcasecmp _stricmp +#endif // strcasecmp + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + +#include +#if !defined(_AIX) && !defined(__MVS__) +#include +#endif +#ifdef __MVS__ +#include +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif +#endif +#include +#include +#include +#ifdef __linux__ +#include +#endif +#include +#ifdef CPPHTTPLIB_USE_POLL +#include +#endif +#include +#include +#include +#include +#include +#include + +using socket_t = int; +#ifndef INVALID_SOCKET +#define INVALID_SOCKET (-1) +#endif +#endif //_WIN32 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +#include + +// these are defined in wincrypt.h and it breaks compilation if BoringSSL is +// used +#undef X509_NAME +#undef X509_CERT_PAIR +#undef X509_EXTENSIONS +#undef PKCS7_SIGNER_INFO + +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#pragma comment(lib, "cryptui.lib") +#endif +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#include +#if TARGET_OS_OSX +#include +#include +#endif // TARGET_OS_OSX +#endif // _WIN32 + +#include +#include +#include +#include + +#if defined(_WIN32) && defined(OPENSSL_USE_APPLINK) +#include +#endif + +#include +#include + +#if OPENSSL_VERSION_NUMBER < 0x1010100fL +#error Sorry, OpenSSL versions prior to 1.1.1 are not supported +#elif OPENSSL_VERSION_NUMBER < 0x30000000L +#define SSL_get1_peer_certificate SSL_get_peer_certificate +#endif + +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +#include +#include +#endif + +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +typedef void (*POSTSOCKETCONNECTCB)(socket_t sock); + +/* + * Backport std::make_unique from C++14. + * + * NOTE: This code came up with the following stackoverflow post: + * https://stackoverflow.com/questions/10149840/c-arrays-and-make-unique + * + */ + +template +typename std::enable_if::value, std::unique_ptr>::type make_unique(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +template +typename std::enable_if::value, std::unique_ptr>::type make_unique(std::size_t n) { + typedef typename std::remove_extent::type RT; + return std::unique_ptr(new RT[n]); +} + +struct ci { + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), + s1.end(), + s2.begin(), + s2.end(), + [](unsigned char c1, unsigned char c2) { return ::tolower(c1) < ::tolower(c2); }); + } +}; + +// This is based on +// "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189". + +struct scope_exit { + explicit scope_exit(std::function &&f) : exit_function(std::move(f)), execute_on_destruction {true} { + } + + scope_exit(scope_exit &&rhs) + : exit_function(std::move(rhs.exit_function)), + execute_on_destruction {rhs.execute_on_destruction} { + rhs.release(); + } + + ~scope_exit() { + if (execute_on_destruction) { + this->exit_function(); + } + } + + void release() { + this->execute_on_destruction = false; + } + +private: + scope_exit(const scope_exit &) = delete; + void operator=(const scope_exit &) = delete; + scope_exit &operator=(scope_exit &&) = delete; + + std::function exit_function; + bool execute_on_destruction; +}; + +} // namespace detail + +using Headers = std::multimap; + +using Params = std::multimap; +using Match = std::smatch; + +using Progress = std::function; + +struct Response; +using ResponseHandler = std::function; + +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; + +using MultipartFormDataItems = std::vector; +using MultipartFormDataMap = std::multimap; + +class DataSink { +public: + DataSink() : os(&sb_), sb_(*this) { + } + + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function done; + std::function done_with_trailer; + std::ostream os; + +private: + class data_sink_streambuf : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) : sink_(sink) { + } + + protected: + std::streamsize xsputn(const char *s, std::streamsize n) { + sink_.write(s, static_cast(n)); + return n; + } + + private: + DataSink &sink_; + }; + + data_sink_streambuf sb_; +}; + +using ContentProvider = std::function; + +using ContentProviderWithoutLength = std::function; + +using ContentProviderResourceReleaser = std::function; + +struct MultipartFormDataProvider { + std::string name; + ContentProviderWithoutLength provider; + std::string filename; + std::string content_type; +}; + +using MultipartFormDataProviderItems = std::vector; + +using ContentReceiverWithProgress = std::function< + bool(const char *data, size_t data_length, uint64_t offset, uint64_t total_length)>; + +using ContentReceiver = std::function; + +using MultipartContentHeader = std::function; + +class ContentReader { +public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader multipart_reader) + : reader_(std::move(reader)), + multipart_reader_(std::move(multipart_reader)) { + } + + bool operator()(MultipartContentHeader header, ContentReceiver receiver) const { + return multipart_reader_(std::move(header), std::move(receiver)); + } + + bool operator()(ContentReceiver receiver) const { + return reader_(std::move(receiver)); + } + + Reader reader_; + MultipartReader multipart_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; + +struct Request { + std::string method; + std::string path; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + std::string local_addr; + int local_port = -1; + + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + std::unordered_map path_params; + + // for client + ResponseHandler response_handler; + ContentReceiverWithProgress content_receiver; + Progress progress; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl = nullptr; +#endif + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, size_t id = 0) const; + template T get_header_value(const std::string &key, size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + bool has_param(const std::string &key) const; + std::string get_param_value(const std::string &key, size_t id = 0) const; + size_t get_param_value_count(const std::string &key) const; + + bool is_multipart_form_data() const; + + bool has_file(const std::string &key) const; + MultipartFormData get_file_value(const std::string &key) const; + std::vector get_file_values(const std::string &key) const; + + // private members... + size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; + size_t content_length_ = 0; + ContentProvider content_provider_; + bool is_chunked_content_provider_ = false; + size_t authorization_count_ = 0; +}; + +struct Response { + std::string version; + int status = -1; + std::string reason; + Headers headers; + std::string body; + std::string location; // Redirect location + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, size_t id = 0) const; + template T get_header_value(const std::string &key, size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + void set_redirect(const std::string &url, int status = 302); + void set_content(const char *s, size_t n, const std::string &content_type); + void set_content(const std::string &s, const std::string &content_type); + + void set_content_provider(size_t length, + const std::string &content_type, + ContentProvider provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_content_provider(const std::string &content_type, + ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_chunked_content_provider(const std::string &content_type, + ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(content_provider_success_); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + ContentProviderResourceReleaser content_provider_resource_releaser_; + bool is_chunked_content_provider_ = false; + bool content_provider_success_ = false; +}; + +class Stream { +public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; + virtual socket_t socket() const = 0; + + template ssize_t write_format(const char *fmt, const Args &...args); + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { +public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual void enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle() { + } +}; + +class ThreadPool : public TaskQueue { +public: + explicit ThreadPool(size_t n) : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + void enqueue(std::function fn) override { + { + std::unique_lock lock(mutex_); + jobs_.push_back(std::move(fn)); + } + + cond_.notify_one(); + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + +private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) { + } + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait(lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { + break; + } + + fn = std::move(pool_.jobs_.front()); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function; + +using SocketOptions = std::function; + +void default_socket_options(socket_t sock); + +namespace detail { + +class MatcherBase { +public: + virtual ~MatcherBase() = default; + + // Match request path and populate its matches and + virtual bool match(Request &request) const = 0; +}; + +/** + * Captures parameters in request path and stores them in Request::path_params + * + * Capture name is a substring of a pattern from : to /. + * The rest of the pattern is matched agains the request path directly + * Parameters are captured starting from the next character after + * the end of the last matched static pattern fragment until the next /. + * + * Example pattern: + * "/path/fragments/:capture/more/fragments/:second_capture" + * Static fragments: + * "/path/fragments/", "more/fragments/" + * + * Given the following request path: + * "/path/fragments/:1/more/fragments/:2" + * the resulting capture will be + * {{"capture", "1"}, {"second_capture", "2"}} + */ +class PathParamsMatcher : public MatcherBase { +public: + PathParamsMatcher(const std::string &pattern); + + bool match(Request &request) const override; + +private: + static constexpr char marker = ':'; + // Treat segment separators as the end of path parameter capture + // Does not need to handle query parameters as they are parsed before path + // matching + static constexpr char separator = '/'; + + // Contains static path fragments to match against, excluding the '/' after + // path params + // Fragments are separated by path params + std::vector static_fragments_; + // Stores the names of the path parameters to be used as keys in the + // Request::path_params map + std::vector param_names_; +}; + +/** + * Performs std::regex_match on request path + * and stores the result in Request::matches + * + * Note that regex match is performed directly on the whole request. + * This means that wildcard patterns may match multiple path segments with /: + * "/begin/(.*)/end" will match both "/begin/middle/end" and "/begin/1/2/end". + */ +class RegexMatcher : public MatcherBase { +public: + RegexMatcher(const std::string &pattern) : regex_(pattern) { + } + + bool match(Request &request) const override; + +private: + std::regex regex_; +}; + +} // namespace detail + +class Server { +public: + using Handler = std::function; + + using ExceptionHandler = std::function; + + enum class HandlerResponse { + Handled, + Unhandled, + }; + using HandlerWithResponse = std::function; + + using HandlerWithContentReader = std::function< + void(const Request &, Response &, const ContentReader &content_reader)>; + + using Expect100ContinueHandler = std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, HandlerWithContentReader handler); + Server &Put(const std::string &pattern, Handler handler); + Server &Put(const std::string &pattern, HandlerWithContentReader handler); + Server &Patch(const std::string &pattern, Handler handler); + Server &Patch(const std::string &pattern, HandlerWithContentReader handler); + Server &Delete(const std::string &pattern, Handler handler); + Server &Delete(const std::string &pattern, HandlerWithContentReader handler); + Server &Options(const std::string &pattern, Handler handler); + + bool set_base_dir(const std::string &dir, const std::string &mount_point = std::string()); + bool set_mount_point(const std::string &mount_point, const std::string &dir, Headers headers = Headers()); + bool remove_mount_point(const std::string &mount_point); + Server &set_file_extension_and_mimetype_mapping(const std::string &ext, const std::string &mime); + Server &set_file_request_handler(Handler handler); + + Server &set_error_handler(HandlerWithResponse handler); + Server &set_error_handler(Handler handler); + Server &set_exception_handler(ExceptionHandler handler); + Server &set_pre_routing_handler(HandlerWithResponse handler); + Server &set_post_routing_handler(Handler handler); + + Server &set_expect_100_continue_handler(Expect100ContinueHandler handler); + Server &set_logger(Logger logger); + + Server &set_address_family(int family); + Server &set_tcp_nodelay(bool on); + Server &set_socket_options(SocketOptions socket_options); + + Server &set_default_headers(Headers headers); + + Server &set_keep_alive_max_count(size_t count); + Server &set_keep_alive_timeout(time_t sec); + + Server &set_read_timeout(time_t sec, time_t usec = 0); + template Server &set_read_timeout(const std::chrono::duration &duration); + + Server &set_write_timeout(time_t sec, time_t usec = 0); + template Server &set_write_timeout(const std::chrono::duration &duration); + + Server &set_idle_interval(time_t sec, time_t usec = 0); + template Server &set_idle_interval(const std::chrono::duration &duration); + + Server &set_payload_max_length(size_t length); + + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); + int bind_to_any_port(const std::string &host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const std::string &host, int port, int socket_flags = 0); + + bool is_running() const; + void wait_until_ready() const; + void stop(); + + std::function new_task_queue; + +protected: + bool process_request(Stream &strm, + bool close_connection, + bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_ {INVALID_SOCKET}; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + +private: + using Handlers = std::vector, Handler>>; + using HandlersForContentReader = std::vector< + std::pair, HandlerWithContentReader>>; + + static std::unique_ptr make_matcher(const std::string &pattern); + + socket_t create_server_socket(const std::string &host, + int port, + int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const std::string &host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(const Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, const Handlers &handlers); + bool dispatch_request_for_content_reader(Request &req, + Response &res, + ContentReader content_reader, + const HandlersForContentReader &handlers); + + bool parse_request_line(const char *s, Request &req); + void apply_ranges(const Request &req, Response &res, std::string &content_type, std::string &boundary); + bool write_response(Stream &strm, bool close_connection, const Request &req, Response &res); + bool write_response_with_content(Stream &strm, bool close_connection, const Request &req, Response &res); + bool write_response_core(Stream &strm, + bool close_connection, + const Request &req, + Response &res, + bool need_apply_ranges); + bool write_content_with_provider(Stream &strm, + const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool read_content_with_content_receiver(Stream &strm, + Request &req, + Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, + Request &req, + Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_ {false}; + std::atomic done_ {false}; + + struct MountPointEntry { + std::string mount_point; + std::string base_dir; + Headers headers; + }; + + std::vector base_dirs_; + std::map file_extension_and_mimetype_map_; + Handler file_request_handler_; + + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + + HandlerWithResponse error_handler_; + ExceptionHandler exception_handler_; + HandlerWithResponse pre_routing_handler_; + Handler post_routing_handler_; + Expect100ContinueHandler expect_100_continue_handler_; + + Logger logger_; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + SocketOptions socket_options_ = default_socket_options; + + Headers default_headers_; +}; + +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(const Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + +class Result { +public: + Result() = default; + + Result(std::unique_ptr &&res, Error err, Headers &&request_headers = Headers {}) + : res_(std::move(res)), + err_(err), + request_headers_(std::move(request_headers)) { + } + + // Response + operator bool() const { + return res_ != nullptr; + } + + bool operator==(std::nullptr_t) const { + return res_ == nullptr; + } + + bool operator!=(std::nullptr_t) const { + return res_ != nullptr; + } + + const Response &value() const { + return *res_; + } + + Response &value() { + return *res_; + } + + const Response &operator*() const { + return *res_; + } + + Response &operator*() { + return *res_; + } + + const Response *operator->() const { + return res_.get(); + } + + Response *operator->() { + return res_.get(); + } + + // Error + Error error() const { + return err_; + } + + // Request Headers + bool has_request_header(const std::string &key) const; + std::string get_request_header_value(const std::string &key, size_t id = 0) const; + template T get_request_header_value(const std::string &key, size_t id = 0) const; + size_t get_request_header_value_count(const std::string &key) const; + +private: + std::unique_ptr res_; + Error err_ = Error::Unknown; + Headers request_headers_; +}; + +class ClientImpl { +public: + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, + int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver); + Result Get(const std::string &path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); + + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, + const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, + const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Post(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Post(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary); + Result Post(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Put(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type); + Result Put(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Put(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Put(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary); + Result Put(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template void set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template void set_write_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); + void set_post_connect_callback(detail::POSTSOCKETCONNECTCB cb); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path = std::string()); + void set_ca_cert_store(X509_STORE *ca_cert_store); + X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); +#endif + + void set_logger(Logger logger); + +protected: + struct Socket { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { + return sock != INVALID_SOCKET; + } + }; + + virtual bool create_and_connect_socket(Socket &socket, Error &error); + + // All of: + // shutdown_ssl + // shutdown_socket + // close_socket + // should ONLY be called when socket_mutex_ is locked. + // Also, shutdown_ssl and close_socket should also NOT be called concurrently + // with a DIFFERENT thread sending requests using that socket. + virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully); + void shutdown_socket(Socket &socket); + void close_socket(Socket &socket); + + bool process_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); + + bool write_content_with_provider(Stream &strm, const Request &req, Error &error); + + void copy_settings(const ClientImpl &rhs); + + // Socket endpoint information + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; + + // These are all protected under socket_mutex + size_t socket_requests_in_flight_ = 0; + std::thread::id socket_requests_are_from_thread_ = std::thread::id(); + bool socket_should_be_closed_when_request_is_done_ = false; + + // Hostname-IP map + std::map addr_map_; + + // Default headers + Headers default_headers_; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool keep_alive_ = false; + bool follow_location_ = false; + + bool url_encode_ = true; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + SocketOptions socket_options_ = nullptr; + + bool compress_ = false; + bool decompress_ = true; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_ = -1; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + + X509_STORE *ca_cert_store_ = nullptr; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool server_certificate_verification_ = true; +#endif + + Logger logger_; + + detail::POSTSOCKETCONNECTCB _postConnCb = nullptr; + +private: + bool send_(Request &req, Response &res, Error &error); + Result send_(Request &&req); + + socket_t create_client_socket(Error &error) const; + bool read_response_line(Stream &strm, const Request &req, Response &res); + bool write_request(Stream &strm, Request &req, bool close_connection, Error &error); + bool redirect(Request &req, Response &res, Error &error); + bool handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); + std::unique_ptr send_with_content_provider(Request &req, + const char *body, + size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, + Error &error); + Result send_with_content_provider(const std::string &method, + const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type); + ContentProviderWithoutLength get_multipart_content_provider(const std::string &boundary, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + std::string adjust_host_string(const std::string &host) const; + + virtual bool process_socket(const Socket &socket, std::function callback); + virtual bool is_ssl() const; +}; + +class Client { +public: + // Universal interface + explicit Client(const std::string &scheme_host_port); + + explicit Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, + int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + Client(Client &&) = default; + + ~Client(); + + bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver); + Result Get(const std::string &path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); + + Result Get(const std::string &path, const Params ¶ms, const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, + const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, + const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Post(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const std::string &body, const std::string &content_type); + Result Post(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, const Params ¶ms); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Post(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary); + Result Post(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Put(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const std::string &body, const std::string &content_type); + Result Put(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type); + Result Put(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type); + Result Put(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, const Params ¶ms); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items); + Result Put(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary); + Result Put(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Patch(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, const std::string &content_type); + Result Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template void set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template void set_write_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); + void set_post_connect_cb(detail::POSTSOCKETCONNECTCB cb); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); +#endif + + void set_logger(Logger logger); + + // SSL +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; +#endif + +private: + std::unique_ptr cli_; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_ = false; +#endif +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLServer : public Server { +public: + SSLServer(const char *cert_path, + const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr, + const char *private_key_password = nullptr); + + SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store = nullptr); + + SSLServer(const std::function &setup_ssl_ctx_callback); + + ~SSLServer() override; + + bool is_valid() const override; + + SSL_CTX *ssl_context() const; + +private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; +}; + +class SSLClient : public ClientImpl { +public: + explicit SSLClient(const std::string &host); + + explicit SSLClient(const std::string &host, int port); + + explicit SSLClient(const std::string &host, + int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + explicit SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; + +private: + bool create_and_connect_socket(Socket &socket, Error &error) override; + void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; + void shutdown_ssl_impl(Socket &socket, bool shutdown_socket); + + bool process_socket(const Socket &socket, std::function callback) override; + bool is_ssl() const override; + + bool connect_with_proxy(Socket &sock, Response &res, bool &success, Error &error); + bool initialize_ssl(Socket &socket, Error &error); + + bool load_certs(); + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; + + std::vector host_components_; + + long verify_result_ = 0; + + friend class ClientImpl; +}; +#endif + +/* + * Implementation of template methods. + */ + +namespace detail { + +template inline void duration_to_sec_and_usec(const T &duration, U callback) { + auto sec = std::chrono::duration_cast(duration).count(); + auto usec = std::chrono::duration_cast(duration - std::chrono::seconds(sec)).count(); + callback(static_cast(sec), static_cast(usec)); +} + +template +inline T get_header_value(const Headers & /*headers*/, + const std::string & /*key*/, + size_t /*id*/ = 0, + uint64_t /*def*/ = 0) { +} + +template<> +inline uint64_t get_header_value(const Headers &headers, const std::string &key, size_t id, uint64_t def) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; +} + +} // namespace detail + +template inline T Request::get_header_value(const std::string &key, size_t id) const { + return detail::get_header_value(headers, key, id, 0); +} + +template inline T Response::get_header_value(const std::string &key, size_t id) const { + return detail::get_header_value(headers, key, id, 0); +} + +template inline ssize_t Stream::write_format(const char *fmt, const Args &...args) { + const auto bufsiz = 2048; + std::array buf {}; + + auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); + if (sn <= 0) { + return sn; + } + + auto n = static_cast(sn); + + if (n >= buf.size() - 1) { + std::vector glowable_buf(buf.size()); + + while (n >= glowable_buf.size() - 1) { + glowable_buf.resize(glowable_buf.size() * 2); + n = static_cast(snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); + } + return write(&glowable_buf[0], n); + } else { + return write(buf.data(), n); + } +} + +inline void default_socket_options(socket_t sock) { + int yes = 1; +#ifdef _WIN32 + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, reinterpret_cast(&yes), sizeof(yes)); +#else +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), sizeof(yes)); +#else + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), sizeof(yes)); +#endif +#endif +} + +template +inline Server &Server::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); + return *this; +} + +template +inline Server &Server::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); + return *this; +} + +template +inline Server &Server::set_idle_interval(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_idle_interval(sec, usec); }); + return *this; +} + +inline std::string to_string(const Error error) { + switch (error) { + case Error::Success: + return "Success (no error)"; + case Error::Connection: + return "Could not establish connection"; + case Error::BindIPAddress: + return "Failed to bind IP address"; + case Error::Read: + return "Failed to read connection"; + case Error::Write: + return "Failed to write connection"; + case Error::ExceedRedirectCount: + return "Maximum redirect count exceeded"; + case Error::Canceled: + return "Connection handling canceled"; + case Error::SSLConnection: + return "SSL connection failed"; + case Error::SSLLoadingCerts: + return "SSL certificate loading failed"; + case Error::SSLServerVerification: + return "SSL server verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: + return "Compression failed"; + case Error::ConnectionTimeout: + return "Connection timed out"; + case Error::Unknown: + return "Unknown"; + default: + break; + } + + return "Invalid"; +} + +inline std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + +template inline T Result::get_request_header_value(const std::string &key, size_t id) const { + return detail::get_header_value(request_headers_, key, id, 0); +} + +template +inline void ClientImpl::set_connection_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_connection_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); +} + +template +inline void Client::set_connection_timeout(const std::chrono::duration &duration) { + cli_->set_connection_timeout(duration); +} + +template +inline void Client::set_read_timeout(const std::chrono::duration &duration) { + cli_->set_read_timeout(duration); +} + +template +inline void Client::set_write_timeout(const std::chrono::duration &duration) { + cli_->set_write_timeout(duration); +} + +/* + * Forward declarations and types that will be part of the .h file if split into + * .h + .cc. + */ + +std::string hosted_at(const std::string &hostname); + +void hosted_at(const std::string &hostname, std::vector &addrs); + +std::string append_query_params(const std::string &path, const Params ¶ms); + +std::pair make_range_header(Ranges ranges); + +std::pair make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false); + +namespace detail { + +std::string encode_query_param(const std::string &value); + +std::string decode_url(const std::string &s, bool convert_plus_to_space); + +void read_file(const std::string &path, std::string &out); + +std::string trim_copy(const std::string &s); + +void split(const char *b, const char *e, char d, std::function fn); + +bool process_client_socket(socket_t sock, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + std::function callback); + +socket_t create_client_socket(const std::string &host, + const std::string &ip, + int port, + int address_family, + bool tcp_nodelay, + SocketOptions socket_options, + time_t connection_timeout_sec, + time_t connection_timeout_usec, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + const std::string &intf, + Error &error); + +const char *get_header_value(const Headers &headers, const std::string &key, size_t id = 0, const char *def = nullptr); + +std::string params_to_query_str(const Params ¶ms); + +void parse_query_text(const std::string &s, Params ¶ms); + +bool parse_multipart_boundary(const std::string &content_type, std::string &boundary); + +bool parse_range_header(const std::string &s, Ranges &ranges); + +int close_socket(socket_t sock); + +ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); + +ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); + +enum class EncodingType { + None = 0, + Gzip, + Brotli +}; + +EncodingType encoding_type(const Request &req, const Response &res); + +class BufferStream : public Stream { +public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + + const std::string &get_buffer() const; + +private: + std::string buffer; + size_t position = 0; +}; + +class compressor { +public: + virtual ~compressor() = default; + + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, Callback callback) = 0; +}; + +class decompressor { +public: + virtual ~decompressor() = default; + + virtual bool is_valid() const = 0; + + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, Callback callback) = 0; +}; + +class nocompressor : public compressor { +public: + virtual ~nocompressor() = default; + + bool compress(const char *data, size_t data_length, bool /*last*/, Callback callback) override; +}; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +class gzip_compressor : public compressor { +public: + gzip_compressor(); + ~gzip_compressor(); + + bool compress(const char *data, size_t data_length, bool last, Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; + +class gzip_decompressor : public decompressor { +public: + gzip_decompressor(); + ~gzip_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +class brotli_compressor : public compressor { +public: + brotli_compressor(); + ~brotli_compressor(); + + bool compress(const char *data, size_t data_length, bool last, Callback callback) override; + +private: + BrotliEncoderState *state_ = nullptr; +}; + +class brotli_decompressor : public decompressor { +public: + brotli_decompressor(); + ~brotli_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, Callback callback) override; + +private: + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; +}; +#endif + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. +class stream_line_reader { +public: + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size); + const char *ptr() const; + size_t size() const; + bool end_with_crlf() const; + bool getline(); + +private: + void append(char c); + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; +}; + +} // namespace detail + +// ---------------------------------------------------------------------------- + +/* + * Implementation that will be part of the .cc file if split into .h + .cc. + */ + +namespace detail { + +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, int &val) { + if (i >= s.size()) { + return false; + } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { + return false; + } + auto v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + static const auto charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + auto val = 0; + auto valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { + out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); + } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_file(const std::string &path) { +#ifdef _WIN32 + return _access_s(path.c_str(), 0) == 0; +#else + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +#endif +} + +inline bool is_dir(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { + return false; + } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline std::string encode_query_param(const std::string &value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (auto c : value) { + if (std::isalnum(static_cast(c)) || c == '-' || c == '_' || c == '.' || c == '!' || c == '~' || + c == '*' || c == '\'' || c == '(' || c == ')') { + escaped << c; + } else { + escaped << std::uppercase; + escaped << '%' << std::setw(2) << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} + +inline std::string encode_url(const std::string &s) { + std::string result; + result.reserve(s.size()); + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': + result += "%20"; + break; + case '+': + result += "%2B"; + break; + case '\r': + result += "%0D"; + break; + case '\n': + result += "%0A"; + break; + case '\'': + result += "%27"; + break; + case ',': + result += "%2C"; + break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': + result += "%3B"; + break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s, bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + auto val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { + result.append(buff, len); + } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + auto val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { + return m[1].str(); + } + return std::string(); +} + +inline bool is_space_or_tab(char c) { + return c == ' ' || c == '\t'; +} + +inline std::pair trim(const char *b, const char *e, size_t left, size_t right) { + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); +} + +inline std::string trim_copy(const std::string &s) { + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); +} + +inline void split(const char *b, const char *e, char d, std::function fn) { + size_t i = 0; + size_t beg = 0; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + beg = i + 1; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + fn(&b[r.first], &b[r.second]); + } + } +} + +inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), + fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) { +} + +inline const char *stream_line_reader::ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } +} + +inline size_t stream_line_reader::size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } +} + +inline bool stream_line_reader::end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; +} + +inline bool stream_line_reader::getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + + if (byte == '\n') { + break; + } + } + + return true; +} + +inline void stream_line_reader::append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } +} + +inline int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +template inline ssize_t handle_EINTR(T fn) { + ssize_t res = 0; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + continue; + } + break; + } + return res; +} + +inline ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return recv(sock, +#ifdef _WIN32 + static_cast(ptr), + static_cast(size), +#else + ptr, + size, +#endif + flags); + }); +} + +inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return send(sock, +#ifdef _WIN32 + static_cast(ptr), + static_cast(size), +#else + ptr, + size, +#endif + flags); + }); +} + +inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { + return 1; + } +#endif + + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); }); +#endif +} + +inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { + return 1; + } +#endif + + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); }); +#endif +} + +inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + + if (poll_res == 0) { + return Error::ConnectionTimeout; + } + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + + return Error::Connection; +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { + return Error::Connection; + } +#endif + + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + auto ret = handle_EINTR([&]() { return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); }); + + if (ret == 0) { + return Error::ConnectionTimeout; + } + + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + return Error::Connection; +#endif +} + +inline bool is_socket_alive(socket_t sock) { + const auto val = detail::select_read(sock, 0, 0); + if (val == 0) { + return true; + } else if (val < 0 && errno == EBADF) { + return false; + } + char buf[1]; + return detail::read_socket(sock, &buf[0], sizeof(buf), MSG_PEEK) > 0; +} + +class SocketStream : public Stream { +public: + SocketStream(socket_t sock, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec); + ~SocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + +private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + + std::vector read_buff_; + size_t read_buff_off_ = 0; + size_t read_buff_content_size_ = 0; + + static const size_t read_buff_size_ = 1024 * 4; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream : public Stream { +public: + SSLSocketStream(socket_t sock, + SSL *ssl, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + +private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; +}; +#endif + +inline bool keep_alive(socket_t sock, time_t keep_alive_timeout_sec) { + using namespace std::chrono; + auto start = steady_clock::now(); + while (true) { + auto val = select_read(sock, 0, 10000); + if (val < 0) { + return false; + } else if (val == 0) { + auto current = steady_clock::now(); + auto duration = duration_cast(current - start); + auto timeout = keep_alive_timeout_sec * 1000; + if (duration.count() > timeout) { + return false; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } else { + return true; + } + } +} + +template +inline bool process_server_socket_core(const std::atomic &svr_sock, + socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, + T callback) { + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (svr_sock != INVALID_SOCKET && count > 0 && keep_alive(sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { + break; + } + count--; + } + return ret; +} + +template +inline bool process_server_socket(const std::atomic &svr_sock, + socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + T callback) { + return process_server_socket_core( + svr_sock, + sock, + keep_alive_max_count, + keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +inline bool process_client_socket(socket_t sock, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + std::function callback) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); + return callback(strm); +} + +inline int shutdown_socket(socket_t sock) { +#ifdef _WIN32 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif +} + +template +socket_t create_socket(const std::string &host, + const std::string &ip, + int port, + int address_family, + int socket_flags, + bool tcp_nodelay, + SocketOptions socket_options, + BindOrConnect bind_or_connect) { + // Get address info + const char *node = nullptr; + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (!ip.empty()) { + node = ip.c_str(); + // Ask getaddrinfo to convert IP in c-string to address + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + } else { + if (!host.empty()) { + node = host.c_str(); + } + hints.ai_family = address_family; + hints.ai_flags = socket_flags; + } + +#ifndef _WIN32 + if (hints.ai_family == AF_UNIX) { + const auto addrlen = host.length(); + if (addrlen > sizeof(sockaddr_un::sun_path)) { + return INVALID_SOCKET; + } + + auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); + if (sock != INVALID_SOCKET) { + sockaddr_un addr {}; + addr.sun_family = AF_UNIX; + std::copy(host.begin(), host.end(), addr.sun_path); + + hints.ai_addr = reinterpret_cast(&addr); + hints.ai_addrlen = static_cast(sizeof(addr) - sizeof(addr.sun_path) + addrlen); + + fcntl(sock, F_SETFD, FD_CLOEXEC); + if (socket_options) { + socket_options(sock); + } + + if (!bind_or_connect(sock, hints)) { + close_socket(sock); + sock = INVALID_SOCKET; + } + } + return sock; + } +#endif + + auto service = std::to_string(port); + + if (getaddrinfo(node, service.c_str(), &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return INVALID_SOCKET; + } + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = WSASocketW(rp->ai_family, + rp->ai_socktype, + rp->ai_protocol, + nullptr, + 0, + WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + if (sock == INVALID_SOCKET) { + continue; + } + +#ifndef _WIN32 + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + close_socket(sock); + continue; + } +#endif + + if (tcp_nodelay) { + auto yes = 1; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&yes), sizeof(yes)); +#else + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast(&yes), sizeof(yes)); +#endif + } + + if (socket_options) { + socket_options(sock); + } + + if (rp->ai_family == AF_INET6) { + auto no = 0; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&no), sizeof(no)); +#else + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, reinterpret_cast(&no), sizeof(no)); +#endif + } + + // bind or connect + if (bind_or_connect(sock, *rp)) { + freeaddrinfo(result); + return sock; + } + + close_socket(sock); + } + + freeaddrinfo(result); + return INVALID_SOCKET; +} + +inline void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN32 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif +} + +inline bool is_connection_error() { +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif +} + +inline bool bind_ip_address(socket_t sock, const std::string &host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host.c_str(), "0", &hints, &result)) { + return false; + } + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + freeaddrinfo(result); + return ret; +} + +#if !defined _WIN32 && !defined ANDROID && !defined _AIX && !defined __MVS__ +#define USE_IF2IP +#endif + +#ifdef USE_IF2IP +inline std::string if2ip(int address_family, const std::string &ifn) { + struct ifaddrs *ifap; + getifaddrs(&ifap); + std::string addr_candidate; + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name && + (AF_UNSPEC == address_family || ifa->ifa_addr->sa_family == address_family)) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + auto sa = reinterpret_cast(ifa->ifa_addr); + if (!IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) { + char buf[INET6_ADDRSTRLEN] = {}; + if (inet_ntop(AF_INET6, &sa->sin6_addr, buf, INET6_ADDRSTRLEN)) { + // equivalent to mac's IN6_IS_ADDR_UNIQUE_LOCAL + auto s6_addr_head = sa->sin6_addr.s6_addr[0]; + if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { + addr_candidate = std::string(buf, INET6_ADDRSTRLEN); + } else { + freeifaddrs(ifap); + return std::string(buf, INET6_ADDRSTRLEN); + } + } + } + } + } + } + freeifaddrs(ifap); + return addr_candidate; +} +#endif + +inline socket_t create_client_socket(const std::string &host, + const std::string &ip, + int port, + int address_family, + bool tcp_nodelay, + SocketOptions socket_options, + time_t connection_timeout_sec, + time_t connection_timeout_usec, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + const std::string &intf, + Error &error) { + auto sock = create_socket( + host, + ip, + port, + address_family, + 0, + tcp_nodelay, + std::move(socket_options), + [&](socket_t sock2, struct addrinfo &ai) -> bool { + if (!intf.empty()) { +#ifdef USE_IF2IP + auto ip_from_if = if2ip(address_family, intf); + if (ip_from_if.empty()) { + ip_from_if = intf; + } + if (!bind_ip_address(sock2, ip_from_if.c_str())) { + error = Error::BindIPAddress; + return false; + } +#endif + } + + set_nonblocking(sock2, true); + + auto ret = ::connect(sock2, ai.ai_addr, static_cast(ai.ai_addrlen)); + + if (ret < 0) { + if (is_connection_error()) { + error = Error::Connection; + return false; + } + error = wait_until_socket_is_ready(sock2, connection_timeout_sec, connection_timeout_usec); + if (error != Error::Success) { + return false; + } + } + + set_nonblocking(sock2, false); + + { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec * 1000 + read_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec); + tv.tv_usec = static_cast(read_timeout_usec); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + { + +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec * 1000 + write_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec); + tv.tv_usec = static_cast(write_timeout_usec); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + + error = Error::Success; + return true; + }); + + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { + error = Error::Connection; + } + } + + return sock; +} + +inline bool get_ip_and_port(const struct sockaddr_storage &addr, socklen_t addr_len, std::string &ip, int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return false; + } + + std::array ipstr {}; + if (getnameinfo(reinterpret_cast(&addr), + addr_len, + ipstr.data(), + static_cast(ipstr.size()), + nullptr, + 0, + NI_NUMERICHOST)) { + return false; + } + + ip = ipstr.data(); + return true; +} + +inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (!getsockname(sock, reinterpret_cast(&addr), &addr_len)) { + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), &addr_len)) { +#ifndef _WIN32 + if (addr.ss_family == AF_UNIX) { +#if defined(__linux__) + struct ucred ucred; + socklen_t len = sizeof(ucred); + if (getsockopt(sock, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == 0) { + port = ucred.pid; + } +#elif defined(SOL_LOCAL) && defined(SO_PEERPID) // __APPLE__ + pid_t pid; + socklen_t len = sizeof(pid); + if (getsockopt(sock, SOL_LOCAL, SO_PEERPID, &pid, &len) == 0) { + port = pid; + } +#endif + return; + } +#endif + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline constexpr unsigned int str2tag_core(const char *s, size_t l, unsigned int h) { + return (l == 0) + ? h + : str2tag_core(s + 1, + l - 1, + // Unsets the 6 high bits of h, therefore no overflow happens + (((std::numeric_limits::max)() >> 6) & h * 33) ^ static_cast(*s)); +} + +inline unsigned int str2tag(const std::string &s) { + return str2tag_core(s.data(), s.size(), 0); +} + +namespace udl { + +inline constexpr unsigned int operator"" _t(const char *s, size_t l) { + return str2tag_core(s, l, 0); +} + +} // namespace udl + +inline const char *find_content_type(const std::string &path, const std::map &user_data) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { + return it->second.c_str(); + } + + using udl::operator""_t; + + switch (str2tag(ext)) { + default: + return nullptr; + case "css"_t: + return "text/css"; + case "csv"_t: + return "text/csv"; + case "htm"_t: + case "html"_t: + return "text/html"; + case "js"_t: + case "mjs"_t: + return "text/javascript"; + case "txt"_t: + return "text/plain"; + case "vtt"_t: + return "text/vtt"; + + case "apng"_t: + return "image/apng"; + case "avif"_t: + return "image/avif"; + case "bmp"_t: + return "image/bmp"; + case "gif"_t: + return "image/gif"; + case "png"_t: + return "image/png"; + case "svg"_t: + return "image/svg+xml"; + case "webp"_t: + return "image/webp"; + case "ico"_t: + return "image/x-icon"; + case "tif"_t: + return "image/tiff"; + case "tiff"_t: + return "image/tiff"; + case "jpg"_t: + case "jpeg"_t: + return "image/jpeg"; + + case "mp4"_t: + return "video/mp4"; + case "mpeg"_t: + return "video/mpeg"; + case "webm"_t: + return "video/webm"; + + case "mp3"_t: + return "audio/mp3"; + case "mpga"_t: + return "audio/mpeg"; + case "weba"_t: + return "audio/webm"; + case "wav"_t: + return "audio/wave"; + + case "otf"_t: + return "font/otf"; + case "ttf"_t: + return "font/ttf"; + case "woff"_t: + return "font/woff"; + case "woff2"_t: + return "font/woff2"; + + case "7z"_t: + return "application/x-7z-compressed"; + case "atom"_t: + return "application/atom+xml"; + case "pdf"_t: + return "application/pdf"; + case "json"_t: + return "application/json"; + case "rss"_t: + return "application/rss+xml"; + case "tar"_t: + return "application/x-tar"; + case "xht"_t: + case "xhtml"_t: + return "application/xhtml+xml"; + case "xslt"_t: + return "application/xslt+xml"; + case "xml"_t: + return "application/xml"; + case "gz"_t: + return "application/gzip"; + case "zip"_t: + return "application/zip"; + case "wasm"_t: + return "application/wasm"; + } +} + +inline const char *status_message(int status) { + switch (status) { + case 100: + return "Continue"; + case 101: + return "Switching Protocol"; + case 102: + return "Processing"; + case 103: + return "Early Hints"; + case 200: + return "OK"; + case 201: + return "Created"; + case 202: + return "Accepted"; + case 203: + return "Non-Authoritative Information"; + case 204: + return "No Content"; + case 205: + return "Reset Content"; + case 206: + return "Partial Content"; + case 207: + return "Multi-Status"; + case 208: + return "Already Reported"; + case 226: + return "IM Used"; + case 300: + return "Multiple Choice"; + case 301: + return "Moved Permanently"; + case 302: + return "Found"; + case 303: + return "See Other"; + case 304: + return "Not Modified"; + case 305: + return "Use Proxy"; + case 306: + return "unused"; + case 307: + return "Temporary Redirect"; + case 308: + return "Permanent Redirect"; + case 400: + return "Bad Request"; + case 401: + return "Unauthorized"; + case 402: + return "Payment Required"; + case 403: + return "Forbidden"; + case 404: + return "Not Found"; + case 405: + return "Method Not Allowed"; + case 406: + return "Not Acceptable"; + case 407: + return "Proxy Authentication Required"; + case 408: + return "Request Timeout"; + case 409: + return "Conflict"; + case 410: + return "Gone"; + case 411: + return "Length Required"; + case 412: + return "Precondition Failed"; + case 413: + return "Payload Too Large"; + case 414: + return "URI Too Long"; + case 415: + return "Unsupported Media Type"; + case 416: + return "Range Not Satisfiable"; + case 417: + return "Expectation Failed"; + case 418: + return "I'm a teapot"; + case 421: + return "Misdirected Request"; + case 422: + return "Unprocessable Entity"; + case 423: + return "Locked"; + case 424: + return "Failed Dependency"; + case 425: + return "Too Early"; + case 426: + return "Upgrade Required"; + case 428: + return "Precondition Required"; + case 429: + return "Too Many Requests"; + case 431: + return "Request Header Fields Too Large"; + case 451: + return "Unavailable For Legal Reasons"; + case 501: + return "Not Implemented"; + case 502: + return "Bad Gateway"; + case 503: + return "Service Unavailable"; + case 504: + return "Gateway Timeout"; + case 505: + return "HTTP Version Not Supported"; + case 506: + return "Variant Also Negotiates"; + case 507: + return "Insufficient Storage"; + case 508: + return "Loop Detected"; + case 510: + return "Not Extended"; + case 511: + return "Network Authentication Required"; + + default: + case 500: + return "Internal Server Error"; + } +} + +inline bool can_compress_content_type(const std::string &content_type) { + using udl::operator""_t; + + auto tag = str2tag(content_type); + + switch (tag) { + case "image/svg+xml"_t: + case "application/javascript"_t: + case "application/json"_t: + case "application/xml"_t: + case "application/protobuf"_t: + case "application/xhtml+xml"_t: + return true; + + default: + return !content_type.rfind("text/", 0) && tag != "text/event-stream"_t; + } +} + +inline EncodingType encoding_type(const Request &req, const Response &res) { + auto ret = detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { + return EncodingType::None; + } + + const auto &s = req.get_header_value("Accept-Encoding"); + (void)(s); + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { + return EncodingType::Brotli; + } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { + return EncodingType::Gzip; + } +#endif + + return EncodingType::None; +} + +inline bool nocompressor::compress(const char *data, size_t data_length, bool /*last*/, Callback callback) { + if (!data_length) { + return true; + } + return callback(data, data_length); +} + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline gzip_compressor::gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, Z_DEFAULT_STRATEGY) == Z_OK; +} + +inline gzip_compressor::~gzip_compressor() { + deflateEnd(&strm_); +} + +inline bool gzip_compressor::compress(const char *data, size_t data_length, bool last, Callback callback) { + assert(is_valid_); + + do { + constexpr size_t max_avail_in = (std::numeric_limits::max)(); + + strm_.avail_in = static_cast((std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + auto flush = (last && data_length == 0) ? Z_FINISH : Z_NO_FLUSH; + auto ret = Z_OK; + + std::array buff {}; + do { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = deflate(&strm_, flush); + if (ret == Z_STREAM_ERROR) { + return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); + + assert((flush == Z_FINISH && ret == Z_STREAM_END) || (flush == Z_NO_FLUSH && ret == Z_OK)); + assert(strm_.avail_in == 0); + } while (data_length > 0); + + return true; +} + +inline gzip_decompressor::gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; +} + +inline gzip_decompressor::~gzip_decompressor() { + inflateEnd(&strm_); +} + +inline bool gzip_decompressor::is_valid() const { + return is_valid_; +} + +inline bool gzip_decompressor::decompress(const char *data, size_t data_length, Callback callback) { + assert(is_valid_); + + auto ret = Z_OK; + + do { + constexpr size_t max_avail_in = (std::numeric_limits::max)(); + + strm_.avail_in = static_cast((std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + std::array buff {}; + while (strm_.avail_in > 0) { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + auto prev_avail_in = strm_.avail_in; + + ret = inflate(&strm_, Z_NO_FLUSH); + + if (prev_avail_in - strm_.avail_in == 0) { + return false; + } + + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: + inflateEnd(&strm_); + return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } + + if (ret != Z_OK && ret != Z_STREAM_END) { + return false; + } + + } while (data_length > 0); + + return true; +} +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +inline brotli_compressor::brotli_compressor() { + state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); +} + +inline brotli_compressor::~brotli_compressor() { + BrotliEncoderDestroyInstance(state_); +} + +inline bool brotli_compressor::compress(const char *data, size_t data_length, bool last, Callback callback) { + std::array buff {}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { + break; + } + } else { + if (!available_in) { + break; + } + } + + auto available_out = buff.size(); + auto next_out = buff.data(); + + if (!BrotliEncoderCompressStream(state_, + operation, + &available_in, + &next_in, + &available_out, + &next_out, + nullptr)) { + return false; + } + + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } + + return true; +} + +inline brotli_decompressor::brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT : BROTLI_DECODER_RESULT_ERROR; +} + +inline brotli_decompressor::~brotli_decompressor() { + if (decoder_s) { + BrotliDecoderDestroyInstance(decoder_s); + } +} + +inline bool brotli_decompressor::is_valid() const { + return decoder_s; +} + +inline bool brotli_decompressor::decompress(const char *data, size_t data_length, Callback callback) { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } + + auto next_in = reinterpret_cast(data); + size_t avail_in = data_length; + size_t total_out; + + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + + std::array buff {}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); + + decoder_r = BrotliDecoderDecompressStream(decoder_s, + &avail_in, + &next_in, + &avail_out, + reinterpret_cast(&next_out), + &total_out); + + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return false; + } + + if (!callback(buff.data(), buff.size() - avail_out)) { + return false; + } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; +} +#endif + +inline bool has_header(const Headers &headers, const std::string &key) { + return headers.find(key) != headers.end(); +} + +inline const char *get_header_value(const Headers &headers, const std::string &key, size_t id, const char *def) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second.c_str(); + } + return def; +} + +inline bool compare_case_ignore(const std::string &a, const std::string &b) { + if (a.size() != b.size()) { + return false; + } + for (size_t i = 0; i < b.size(); i++) { + if (::tolower(a[i]) != ::tolower(b[i])) { + return false; + } + } + return true; +} + +template inline bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + if (p == end) { + return false; + } + + auto key_end = p; + + if (*p++ != ':') { + return false; + } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p < end) { + auto key = std::string(beg, key_end); + auto val = compare_case_ignore(key, "Location") ? std::string(p, end) : decode_url(std::string(p, end), false); + fn(std::move(key), std::move(val)); + return true; + } + + return false; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { + return false; + } + + // Check if the line ends with CRLF. + auto line_terminator_len = 2; + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { + break; + } +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + } else { + // Blank line indicates end of headers. + if (line_reader.size() == 1) { + break; + } + line_terminator_len = 1; + } +#else + } else { + continue; // Skip invalid line. + } +#endif + + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return false; + } + + // Exclude line terminator + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + parse_header(line_reader.ptr(), end, [&](std::string &&key, std::string &&val) { + headers.emplace(std::move(key), std::move(val)); + }); + } + + return true; +} + +inline bool read_content_with_length(Stream &strm, uint64_t len, Progress progress, ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return false; + } + + if (!out(buf, static_cast(n), r, len)) { + return false; + } + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { + return false; + } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { + return; + } + r += static_cast(n); + } +} + +inline bool read_content_without_length(Stream &strm, ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } else if (n == 0) { + return true; + } + + if (!out(buf, static_cast(n), r, 0)) { + return false; + } + r += static_cast(n); + } + + return true; +} + +template inline bool read_content_chunked(Stream &strm, T &x, ContentReceiverWithProgress out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { + return false; + } + + unsigned long chunk_len; + while (true) { + char *end_ptr; + + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + + if (end_ptr == line_reader.ptr()) { + return false; + } + if (chunk_len == ULONG_MAX) { + return false; + } + + if (chunk_len == 0) { + break; + } + + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { + return false; + } + + if (strcmp(line_reader.ptr(), "\r\n")) { + return false; + } + + if (!line_reader.getline()) { + return false; + } + } + + assert(chunk_len == 0); + + // Trailer + if (!line_reader.getline()) { + return false; + } + + while (strcmp(line_reader.ptr(), "\r\n")) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return false; + } + + // Exclude line terminator + constexpr auto line_terminator_len = 2; + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + parse_header(line_reader.ptr(), end, [&](std::string &&key, std::string &&val) { + x.headers.emplace(std::move(key), std::move(val)); + }); + + if (!line_reader.getline()) { + return false; + } + } + + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), "chunked"); +} + +template +bool prepare_content_receiver(T &x, int &status, ContentReceiverWithProgress receiver, bool decompress, U callback) { + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#else + status = 415; + return false; +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#else + status = 415; + return false; +#endif + } + + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off, uint64_t len) { + return decompressor->decompress(buf, n, [&](const char *buf2, size_t n2) { + return receiver(buf2, n2, off, len); + }); + }; + return callback(std::move(out)); + } else { + status = 500; + return false; + } + } + } + + ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off, uint64_t len) { + return receiver(buf, n, off, len); + }; + return callback(std::move(out)); +} + +template +bool read_content(Stream &strm, + T &x, + size_t payload_max_length, + int &status, + Progress progress, + ContentReceiverWithProgress receiver, + bool decompress) { + return prepare_content_receiver(x, + status, + std::move(receiver), + decompress, + [&](const ContentReceiverWithProgress &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, x, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value(x.headers, "Content-Length"); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, std::move(progress), out); + } + } + + if (!ret) { + status = exceed_payload_max_length ? 413 : 400; + } + return ret; + }); +} // namespace detail + +inline ssize_t write_headers(Stream &strm, const Headers &headers) { + ssize_t write_len = 0; + for (const auto &x : headers) { + auto len = strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { + return len; + } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { + return len; + } + write_len += len; + return write_len; +} + +inline bool write_data(Stream &strm, const char *d, size_t l) { + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { + return false; + } + offset += static_cast(length); + } + return true; +} + +template +inline bool write_content(Stream &strm, + const ContentProvider &content_provider, + size_t offset, + size_t length, + T is_shutting_down, + Error &error) { + size_t end_offset = offset + length; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + if (strm.is_writable() && write_data(strm, d, l)) { + offset += l; + } else { + ok = false; + } + } + return ok; + }; + + while (offset < end_offset && !is_shutting_down()) { + if (!strm.is_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, end_offset - offset, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content(Stream &strm, + const ContentProvider &content_provider, + size_t offset, + size_t length, + const T &is_shutting_down) { + auto error = Error::Success; + return write_content(strm, content_provider, offset, length, is_shutting_down, error); +} + +template +inline bool write_content_without_length(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + offset += l; + if (!strm.is_writable() || !write_data(strm, d, l)) { + ok = false; + } + } + return ok; + }; + + data_sink.done = [&](void) { data_available = false; }; + + while (data_available && !is_shutting_down()) { + if (!strm.is_writable()) { + return false; + } else if (!content_provider(offset, 0, data_sink)) { + return false; + } else if (!ok) { + return false; + } + } + return true; +} + +template +inline bool write_content_chunked(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down, + U &compressor, + Error &error) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + data_available = l > 0; + offset += l; + + std::string payload; + if (compressor.compress(d, l, false, [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!strm.is_writable() || !write_data(strm, chunk.data(), chunk.size())) { + ok = false; + } + } + } else { + ok = false; + } + } + return ok; + }; + + auto done_with_trailer = [&](const Headers *trailer) { + if (!ok) { + return; + } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!strm.is_writable() || !write_data(strm, chunk.data(), chunk.size())) { + ok = false; + return; + } + } + + static const std::string done_marker("0\r\n"); + if (!write_data(strm, done_marker.data(), done_marker.size())) { + ok = false; + } + + // Trailer + if (trailer) { + for (const auto &kv : *trailer) { + std::string field_line = kv.first + ": " + kv.second + "\r\n"; + if (!write_data(strm, field_line.data(), field_line.size())) { + ok = false; + } + } + } + + static const std::string crlf("\r\n"); + if (!write_data(strm, crlf.data(), crlf.size())) { + ok = false; + } + }; + + data_sink.done = [&](void) { done_with_trailer(nullptr); }; + + data_sink.done_with_trailer = [&](const Headers &trailer) { done_with_trailer(&trailer); }; + + while (data_available && !is_shutting_down()) { + if (!strm.is_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, 0, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content_chunked(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down, + U &compressor) { + auto error = Error::Success; + return write_content_chunked(strm, content_provider, is_shutting_down, compressor, error); +} + +template +inline bool redirect(T &cli, + Request &req, + Response &res, + const std::string &path, + const std::string &location, + Error &error) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count_ -= 1; + + if (res.status == 303 && (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } + + Response new_res; + + auto ret = cli.send(new_req, new_res, error); + if (ret) { + req = new_req; + res = new_res; + + if (res.location.empty()) { + res.location = location; + } + } + return ret; +} + +inline std::string params_to_query_str(const Params ¶ms) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { + query += "&"; + } + query += it->first; + query += "="; + query += encode_query_param(it->second); + } + return query; +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { + std::set cache; + split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { + return; + } + cache.insert(kv); + + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + + if (!key.empty()) { + params.emplace(decode_url(key, true), decode_url(val, true)); + } + }); +} + +inline bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { + auto boundary_keyword = "boundary="; + auto pos = content_type.find(boundary_keyword); + if (pos == std::string::npos) { + return false; + } + auto end = content_type.find(';', pos); + auto beg = pos + strlen(boundary_keyword); + boundary = content_type.substr(beg, end - beg); + if (boundary.length() >= 2 && boundary.front() == '"' && boundary.back() == '"') { + boundary = boundary.substr(1, boundary.size() - 2); + } + return !boundary.empty(); +} + +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +inline bool parse_range_header(const std::string &s, Ranges &ranges) { +#else +inline bool parse_range_header(const std::string &s, Ranges &ranges) try { +#endif + static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = static_cast(m.position(1)); + auto len = static_cast(m.length(1)); + bool all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) { + return; + } + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch cm; + if (std::regex_match(b, e, cm, re_another_range)) { + ssize_t first = -1; + if (!cm.str(1).empty()) { + first = static_cast(std::stoll(cm.str(1))); + } + + ssize_t last = -1; + if (!cm.str(2).empty()) { + last = static_cast(std::stoll(cm.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +} +#else +} +catch (...) { + return false; +} +#endif + +class MultipartFormDataParser { +public: + MultipartFormDataParser() = default; + + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + dash_boundary_crlf_ = dash_ + boundary_ + crlf_; + crlf_dash_boundary_ = crlf_ + dash_ + boundary_; + } + + bool is_valid() const { + return is_valid_; + } + + bool parse(const char *buf, + size_t n, + const ContentReceiver &content_callback, + const MultipartContentHeader &header_callback) { + + // TODO: support 'filename*' + static const std::regex re_content_disposition( + R"~(^Content-Disposition:\s*form-data;\s*name="(.*?)"(?:;\s*filename="(.*?)")?(?:;\s*filename\*=\S+)?\s*$)~", + std::regex_constants::icase); + + buf_append(buf, n); + + while (buf_size() > 0) { + switch (state_) { + case 0: { // Initial boundary + buf_erase(buf_find(dash_boundary_crlf_)); + if (dash_boundary_crlf_.size() > buf_size()) { + return true; + } + if (!buf_start_with(dash_boundary_crlf_)) { + return false; + } + buf_erase(dash_boundary_crlf_.size()); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_find(crlf_); + if (pos > CPPHTTPLIB_HEADER_MAX_LENGTH) { + return false; + } + while (pos < buf_size()) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_erase(crlf_.size()); + state_ = 3; + break; + } + + static const std::string header_name = "content-type:"; + const auto header = buf_head(pos); + if (start_with_case_ignore(header, header_name)) { + file_.content_type = trim_copy(header.substr(header_name.size())); + } else { + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + file_.name = m[1]; + file_.filename = m[2]; + } else { + is_valid_ = false; + return false; + } + } + buf_erase(pos + crlf_.size()); + pos = buf_find(crlf_); + } + if (state_ != 3) { + return true; + } + break; + } + case 3: { // Body + if (crlf_dash_boundary_.size() > buf_size()) { + return true; + } + auto pos = buf_find(crlf_dash_boundary_); + if (pos < buf_size()) { + if (!content_callback(buf_data(), pos)) { + is_valid_ = false; + return false; + } + buf_erase(pos + crlf_dash_boundary_.size()); + state_ = 4; + } else { + auto len = buf_size() - crlf_dash_boundary_.size(); + if (len > 0) { + if (!content_callback(buf_data(), len)) { + is_valid_ = false; + return false; + } + buf_erase(len); + } + return true; + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_size()) { + return true; + } + if (buf_start_with(crlf_)) { + buf_erase(crlf_.size()); + state_ = 1; + } else { + if (dash_crlf_.size() > buf_size()) { + return true; + } + if (buf_start_with(dash_crlf_)) { + buf_erase(dash_crlf_.size()); + is_valid_ = true; + buf_erase(buf_size()); // Remove epilogue + } else { + return true; + } + } + break; + } + } + } + + return true; + } + +private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + bool start_with_case_ignore(const std::string &a, const std::string &b) const { + if (a.size() < b.size()) { + return false; + } + for (size_t i = 0; i < b.size(); i++) { + if (::tolower(a[i]) != ::tolower(b[i])) { + return false; + } + } + return true; + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + const std::string dash_crlf_ = "--\r\n"; + std::string boundary_; + std::string dash_boundary_crlf_; + std::string crlf_dash_boundary_; + + size_t state_ = 0; + bool is_valid_ = false; + MultipartFormData file_; + + // Buffer + bool start_with(const std::string &a, size_t spos, size_t epos, const std::string &b) const { + if (epos - spos < b.size()) { + return false; + } + for (size_t i = 0; i < b.size(); i++) { + if (a[i + spos] != b[i]) { + return false; + } + } + return true; + } + + size_t buf_size() const { + return buf_epos_ - buf_spos_; + } + + const char *buf_data() const { + return &buf_[buf_spos_]; + } + + std::string buf_head(size_t l) const { + return buf_.substr(buf_spos_, l); + } + + bool buf_start_with(const std::string &s) const { + return start_with(buf_, buf_spos_, buf_epos_, s); + } + + size_t buf_find(const std::string &s) const { + auto c = s.front(); + + size_t off = buf_spos_; + while (off < buf_epos_) { + auto pos = off; + while (true) { + if (pos == buf_epos_) { + return buf_size(); + } + if (buf_[pos] == c) { + break; + } + pos++; + } + + auto remaining_size = buf_epos_ - pos; + if (s.size() > remaining_size) { + return buf_size(); + } + + if (start_with(buf_, pos, buf_epos_, s)) { + return pos - buf_spos_; + } + + off = pos + 1; + } + + return buf_size(); + } + + void buf_append(const char *data, size_t n) { + auto remaining_size = buf_size(); + if (remaining_size > 0 && buf_spos_ > 0) { + for (size_t i = 0; i < remaining_size; i++) { + buf_[i] = buf_[buf_spos_ + i]; + } + } + buf_spos_ = 0; + buf_epos_ = remaining_size; + + if (remaining_size + n > buf_.size()) { + buf_.resize(remaining_size + n); + } + + for (size_t i = 0; i < n; i++) { + buf_[buf_epos_ + i] = data[i]; + } + buf_epos_ += n; + } + + void buf_erase(size_t size) { + buf_spos_ += size; + } + + std::string buf_; + size_t buf_spos_ = 0; + size_t buf_epos_ = 0; +}; + +inline std::string to_lower(const char *beg, const char *end) { + std::string out; + auto it = beg; + while (it != end) { + out += static_cast(::tolower(*it)); + it++; + } + return out; +} + +inline std::string make_multipart_data_boundary() { + static const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + std::random_device seed_gen; + + // Request 128 bits of entropy for initialization + std::seed_seq seed_sequence {seed_gen(), seed_gen(), seed_gen(), seed_gen()}; + std::mt19937 engine(seed_sequence); + + std::string result = "--cpp-httplib-multipart-data-"; + + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + + return result; +} + +inline bool is_multipart_boundary_chars_valid(const std::string &boundary) { + auto valid = true; + for (size_t i = 0; i < boundary.size(); i++) { + auto c = boundary[i]; + if (!std::isalnum(c) && c != '-' && c != '_') { + valid = false; + break; + } + } + return valid; +} + +template +inline std::string serialize_multipart_formdata_item_begin(const T &item, const std::string &boundary) { + std::string body = "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + + return body; +} + +inline std::string serialize_multipart_formdata_item_end() { + return "\r\n"; +} + +inline std::string serialize_multipart_formdata_finish(const std::string &boundary) { + return "--" + boundary + "--\r\n"; +} + +inline std::string serialize_multipart_formdata_get_content_type(const std::string &boundary) { + return "multipart/form-data; boundary=" + boundary; +} + +inline std::string serialize_multipart_formdata(const MultipartFormDataItems &items, + const std::string &boundary, + bool finish = true) { + std::string body; + + for (const auto &item : items) { + body += serialize_multipart_formdata_item_begin(item, boundary); + body += item.content + serialize_multipart_formdata_item_end(); + } + + if (finish) { + body += serialize_multipart_formdata_finish(boundary); + } + + return body; +} + +inline std::pair get_range_offset_and_length(const Request &req, size_t content_length, size_t index) { + auto r = req.ranges[index]; + + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } + + auto slen = static_cast(content_length); + + if (r.first == -1) { + r.first = (std::max)(static_cast(0), slen - r.second); + r.second = slen - 1; + } + + if (r.second == -1) { + r.second = slen - 1; + } + return std::make_pair(r.first, static_cast(r.second - r.first) + 1); +} + +inline std::string make_content_range_header_field(size_t offset, size_t length, size_t content_length) { + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type, + SToken stoken, + CToken ctoken, + Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offsets = get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { + return false; + } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); + + return true; +} + +inline bool make_multipart_ranges_data(const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type, + std::string &data) { + return process_multipart_ranges_data( + req, + res, + boundary, + content_type, + [&](const std::string &token) { data += token; }, + [&](const std::string &token) { data += token; }, + [&](size_t offset, size_t length) { + if (offset < res.body.size()) { + data += res.body.substr(offset, length); + return true; + } + return false; + }); +} + +inline size_t get_multipart_ranges_data_length(const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, + res, + boundary, + content_type, + [&](const std::string &token) { data_length += token.size(); }, + [&](const std::string &token) { data_length += token.size(); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +template +inline bool write_multipart_ranges_data(Stream &strm, + const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type, + const T &is_shutting_down) { + return process_multipart_ranges_data( + req, + res, + boundary, + content_type, + [&](const std::string &token) { strm.write(token); }, + [&](const std::string &token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, is_shutting_down); + }); +} + +inline std::pair get_range_offset_and_length(const Request &req, const Response &res, size_t index) { + auto r = req.ranges[index]; + + if (r.second == -1) { + r.second = static_cast(res.content_length_) - 1; + } + + return std::make_pair(r.first, r.second - r.first + 1); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || req.method == "PRI" || + req.method == "DELETE") { + return true; + } + // TODO: check if Content-Length is set + return false; +} + +inline bool has_crlf(const std::string &s) { + auto p = s.c_str(); + while (*p) { + if (*p == '\r' || *p == '\n') { + return true; + } + p++; + } + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr(EVP_MD_CTX_new(), EVP_MD_CTX_free); + + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; + + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') << static_cast(hash[i]); + } + + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + return message_digest(s, EVP_md5()); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest(s, EVP_sha256()); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest(s, EVP_sha512()); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store +inline bool load_system_certs_on_windows(X509_STORE *store) { + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); + if (!hStore) { + return false; + } + + auto result = false; + PCCERT_CONTEXT pContext = NULL; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != nullptr) { + auto encoded_cert = static_cast(pContext->pbCertEncoded); + + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + + return result; +} +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX +template using CFObjectPtr = std::unique_ptr::type, void (*)(CFTypeRef)>; + +inline void cf_object_ptr_deleter(CFTypeRef obj) { + if (obj) { + CFRelease(obj); + } +} + +inline bool retrieve_certs_from_keychain(CFObjectPtr &certs) { + CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; + CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, kCFBooleanTrue}; + + CFObjectPtr query(CFDictionaryCreate(nullptr, + reinterpret_cast(keys), + values, + sizeof(keys) / sizeof(keys[0]), + &kCFTypeDictionaryKeyCallBacks, + &kCFTypeDictionaryValueCallBacks), + cf_object_ptr_deleter); + + if (!query) { + return false; + } + + CFTypeRef security_items = nullptr; + if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || + CFArrayGetTypeID() != CFGetTypeID(security_items)) { + return false; + } + + certs.reset(reinterpret_cast(security_items)); + return true; +} + +inline bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { + CFArrayRef root_security_items = nullptr; + if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { + return false; + } + + certs.reset(root_security_items); + return true; +} + +inline bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { + auto result = false; + for (auto i = 0; i < CFArrayGetCount(certs); ++i) { + const auto cert = reinterpret_cast(CFArrayGetValueAtIndex(certs, i)); + + if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { + continue; + } + + CFDataRef cert_data = nullptr; + if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != errSecSuccess) { + continue; + } + + CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); + + auto encoded_cert = static_cast(CFDataGetBytePtr(cert_data_ptr.get())); + + auto x509 = d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); + + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + return result; +} + +inline bool load_system_certs_on_macos(X509_STORE *store) { + auto result = false; + CFObjectPtr certs(nullptr, cf_object_ptr_deleter); + if (retrieve_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store); + } + + if (retrieve_root_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store) || result; + } + + return result; +} +#endif // TARGET_OS_OSX +#endif // _WIN32 +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef _WIN32 +class WSInit { +public: + WSInit() { + WSADATA wsaData; + if (WSAStartup(0x0002, &wsaData) == 0) { + is_valid_ = true; + } + } + + ~WSInit() { + if (is_valid_) { + WSACleanup(); + } + } + + bool is_valid_ = false; +}; + +static WSInit wsinit_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, + const std::map &auth, + size_t cnonce_count, + const std::string &cnonce, + const std::string &username, + const std::string &password, + bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { + algo = auth.at("algorithm"); + } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { + A2 += ":" + H(req.body); + } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + + auto field = "Digest username=\"" + username + "\", realm=\"" + auth.at("realm") + "\", nonce=\"" + + auth.at("nonce") + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + cnonce + "\", response=\"") + + response + "\"" + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} +#endif + +inline bool parse_www_authenticate(const Response &res, std::map &auth, bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + auto m = *i; + auto key = s.substr(static_cast(m.position(1)), static_cast(m.length(1))); + auto val = m.length(2) > 0 + ? s.substr(static_cast(m.position(2)), static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +// https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 +inline std::string random_string(size_t length) { + auto randchar = []() -> char { + const char charset + [] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[static_cast(std::rand()) % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; +} + +class ContentProviderAdapter { +public: + explicit ContentProviderAdapter(ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) { + } + + bool operator()(size_t offset, size_t, DataSink &sink) { + return content_provider_(offset, sink); + } + +private: + ContentProviderWithoutLength content_provider_; +}; + +} // namespace detail + +inline std::string hosted_at(const std::string &hostname) { + std::vector addrs; + hosted_at(hostname, addrs); + if (addrs.empty()) { + return std::string(); + } + return addrs[0]; +} + +inline void hosted_at(const std::string &hostname, std::vector &addrs) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(hostname.c_str(), nullptr, &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return; + } + + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &addr = *reinterpret_cast(rp->ai_addr); + std::string ip; + auto dummy = -1; + if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, dummy)) { + addrs.push_back(ip); + } + } + + freeaddrinfo(result); +} + +inline std::string append_query_params(const std::string &path, const Params ¶ms) { + std::string path_with_query = path; + const static std::regex re("[^?]+\\?.*"); + auto delm = std::regex_match(path, re) ? '&' : '?'; + path_with_query += delm + detail::params_to_query_str(params); + return path_with_query; +} + +// Header utilities +inline std::pair make_range_header(Ranges ranges) { + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { + field += ", "; + } + if (r.first != -1) { + field += std::to_string(r.first); + } + field += '-'; + if (r.second != -1) { + field += std::to_string(r.second); + } + i++; + } + return std::make_pair("Range", std::move(field)); +} + +inline std::pair make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +inline std::pair make_bearer_token_authentication_header(const std::string &token, + bool is_proxy = false) { + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +// Request implementation +inline bool Request::has_header(const std::string &key) const { + return detail::has_header(headers, key); +} + +inline std::string Request::get_header_value(const std::string &key, size_t id) const { + return detail::get_header_value(headers, key, id, ""); +} + +inline size_t Request::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Request::set_header(const std::string &key, const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline bool Request::has_param(const std::string &key) const { + return params.find(key) != params.end(); +} + +inline std::string Request::get_param_value(const std::string &key, size_t id) const { + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return it->second; + } + return std::string(); +} + +inline size_t Request::get_param_value_count(const std::string &key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.rfind("multipart/form-data", 0); +} + +inline bool Request::has_file(const std::string &key) const { + return files.find(key) != files.end(); +} + +inline MultipartFormData Request::get_file_value(const std::string &key) const { + auto it = files.find(key); + if (it != files.end()) { + return it->second; + } + return MultipartFormData(); +} + +inline std::vector Request::get_file_values(const std::string &key) const { + std::vector values; + auto rng = files.equal_range(key); + for (auto it = rng.first; it != rng.second; it++) { + values.push_back(it->second); + } + return values; +} + +// Response implementation +inline bool Response::has_header(const std::string &key) const { + return headers.find(key) != headers.end(); +} + +inline std::string Response::get_header_value(const std::string &key, size_t id) const { + return detail::get_header_value(headers, key, id, ""); +} + +inline size_t Response::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_header(const std::string &key, const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline void Response::set_redirect(const std::string &url, int stat) { + if (!detail::has_crlf(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = 302; + } + } +} + +inline void Response::set_content(const char *s, size_t n, const std::string &content_type) { + body.assign(s, n); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(const std::string &s, const std::string &content_type) { + set_content(s.data(), s.size(), content_type); +} + +inline void Response::set_content_provider(size_t in_length, + const std::string &content_type, + ContentProvider provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = in_length; + if (in_length > 0) { + content_provider_ = std::move(provider); + } + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider_ = false; +} + +inline void Response::set_content_provider(const std::string &content_type, + ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider_ = false; +} + +inline void Response::set_chunked_content_provider(const std::string &content_type, + ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = resource_releaser; + is_chunked_content_provider_ = true; +} + +// Result implementation +inline bool Result::has_request_header(const std::string &key) const { + return request_headers_.find(key) != request_headers_.end(); +} + +inline std::string Result::get_request_header_value(const std::string &key, size_t id) const { + return detail::get_header_value(request_headers_, key, id, ""); +} + +inline size_t Result::get_request_header_value_count(const std::string &key) const { + auto r = request_headers_.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Stream implementation +inline ssize_t Stream::write(const char *ptr) { + return write(ptr, strlen(ptr)); +} + +inline ssize_t Stream::write(const std::string &s) { + return write(s.data(), s.size()); +} + +namespace detail { + +// Socket stream implementation +inline SocketStream::SocketStream(socket_t sock, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec) + : sock_(sock), + read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + read_buff_(read_buff_size_, 0) { +} + +inline SocketStream::~SocketStream() { +} + +inline bool SocketStream::is_readable() const { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SocketStream::is_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_); +} + +inline ssize_t SocketStream::read(char *ptr, size_t size) { +#ifdef _WIN32 + size = (std::min)(size, static_cast((std::numeric_limits::max)())); +#else + size = (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + if (read_buff_off_ < read_buff_content_size_) { + auto remaining_size = read_buff_content_size_ - read_buff_off_; + if (size <= remaining_size) { + memcpy(ptr, read_buff_.data() + read_buff_off_, size); + read_buff_off_ += size; + return static_cast(size); + } else { + memcpy(ptr, read_buff_.data() + read_buff_off_, remaining_size); + read_buff_off_ += remaining_size; + return static_cast(remaining_size); + } + } + + if (!is_readable()) { + return -1; + } + + read_buff_off_ = 0; + read_buff_content_size_ = 0; + + if (size < read_buff_size_) { + auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + return n; + } else if (n <= static_cast(size)) { + memcpy(ptr, read_buff_.data(), static_cast(n)); + return n; + } else { + memcpy(ptr, read_buff_.data(), size); + read_buff_off_ = size; + read_buff_content_size_ = static_cast(n); + return static_cast(size); + } + } else { + return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + } +} + +inline ssize_t SocketStream::write(const char *ptr, size_t size) { + if (!is_writable()) { + return -1; + } + +#if defined(_WIN32) && !defined(_WIN64) + size = (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + return send_socket(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); +} + +inline void SocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SocketStream::get_local_ip_and_port(std::string &ip, int &port) const { + return detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SocketStream::socket() const { + return sock_; +} + +// Buffer stream implementation +inline bool BufferStream::is_readable() const { + return true; +} + +inline bool BufferStream::is_writable() const { + return true; +} + +inline ssize_t BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1910 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); +} + +inline ssize_t BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, int & /*port*/) const { +} + +inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, int & /*port*/) const { +} + +inline socket_t BufferStream::socket() const { + return 0; +} + +inline const std::string &BufferStream::get_buffer() const { + return buffer; +} + +inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { + // One past the last ending position of a path param substring + std::size_t last_param_end = 0; + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + // Needed to ensure that parameter names are unique during matcher + // construction + // If exceptions are disabled, only last duplicate path + // parameter will be set + std::unordered_set param_name_set; +#endif + + while (true) { + const auto marker_pos = pattern.find(marker, last_param_end); + if (marker_pos == std::string::npos) { + break; + } + + static_fragments_.push_back(pattern.substr(last_param_end, marker_pos - last_param_end)); + + const auto param_name_start = marker_pos + 1; + + auto sep_pos = pattern.find(separator, param_name_start); + if (sep_pos == std::string::npos) { + sep_pos = pattern.length(); + } + + auto param_name = pattern.substr(param_name_start, sep_pos - param_name_start); + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (param_name_set.find(param_name) != param_name_set.cend()) { + std::string msg = "Encountered path parameter '" + param_name + "' multiple times in route pattern '" + + pattern + "'."; + throw std::invalid_argument(msg); + } +#endif + + param_names_.push_back(std::move(param_name)); + + last_param_end = sep_pos + 1; + } + + if (last_param_end < pattern.length()) { + static_fragments_.push_back(pattern.substr(last_param_end)); + } +} + +inline bool PathParamsMatcher::match(Request &request) const { + request.matches = {}; + request.path_params.clear(); + request.path_params.reserve(param_names_.size()); + + // One past the position at which the path matched the pattern last time + std::size_t starting_pos = 0; + for (size_t i = 0; i < static_fragments_.size(); ++i) { + const auto &fragment = static_fragments_[i]; + + if (starting_pos + fragment.length() > request.path.length()) { + return false; + } + + // Avoid unnecessary allocation by using strncmp instead of substr + + // comparison + if (std::strncmp(request.path.c_str() + starting_pos, fragment.c_str(), fragment.length()) != 0) { + return false; + } + + starting_pos += fragment.length(); + + // Should only happen when we have a static fragment after a param + // Example: '/users/:id/subscriptions' + // The 'subscriptions' fragment here does not have a corresponding param + if (i >= param_names_.size()) { + continue; + } + + auto sep_pos = request.path.find(separator, starting_pos); + if (sep_pos == std::string::npos) { + sep_pos = request.path.length(); + } + + const auto ¶m_name = param_names_[i]; + + request.path_params.emplace(param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); + + // Mark everythin up to '/' as matched + starting_pos = sep_pos + 1; + } + // Returns false if the path is longer than the pattern + return starting_pos >= request.path.length(); +} + +inline bool RegexMatcher::match(Request &request) const { + request.path_params.clear(); + return std::regex_match(request.path, request.matches, regex_); +} + +} // namespace detail + +// HTTP server implementation +inline Server::Server() : new_task_queue([] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { +#ifndef _WIN32 + signal(SIGPIPE, SIG_IGN); +#endif +} + +inline Server::~Server() { +} + +inline std::unique_ptr Server::make_matcher(const std::string &pattern) { + if (pattern.find("/:") != std::string::npos) { + return detail::make_unique(pattern); + } else { + return detail::make_unique(pattern); + } +} + +inline Server &Server::Get(const std::string &pattern, Handler handler) { + get_handlers_.push_back(std::make_pair(make_matcher(pattern), std::move(handler))); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, Handler handler) { + post_handlers_.push_back(std::make_pair(make_matcher(pattern), std::move(handler))); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, HandlerWithContentReader handler) { + post_handlers_for_content_reader_.push_back(std::make_pair(make_matcher(pattern), std::move(handler))); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, Handler handler) { + put_handlers_.push_back(std::make_pair(make_matcher(pattern), std::move(handler))); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, HandlerWithContentReader handler) { + put_handlers_for_content_reader_.push_back(std::make_pair(make_matcher(pattern), std::move(handler))); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, Handler handler) { + patch_handlers_.push_back(std::make_pair(make_matcher(pattern), std::move(handler))); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.push_back(std::make_pair(make_matcher(pattern), std::move(handler))); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, Handler handler) { + delete_handlers_.push_back(std::make_pair(make_matcher(pattern), std::move(handler))); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, HandlerWithContentReader handler) { + delete_handlers_for_content_reader_.push_back(std::make_pair(make_matcher(pattern), std::move(handler))); + return *this; +} + +inline Server &Server::Options(const std::string &pattern, Handler handler) { + options_handlers_.push_back(std::make_pair(make_matcher(pattern), std::move(handler))); + return *this; +} + +inline bool Server::set_base_dir(const std::string &dir, const std::string &mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const std::string &mount_point, const std::string &dir, Headers headers) { + if (detail::is_dir(dir)) { + std::string mnt = !mount_point.empty() ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.push_back({mnt, dir, std::move(headers)}); + return true; + } + } + return false; +} + +inline bool Server::remove_mount_point(const std::string &mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->mount_point == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline Server &Server::set_file_extension_and_mimetype_mapping(const std::string &ext, const std::string &mime) { + file_extension_and_mimetype_map_[ext] = mime; + return *this; +} + +inline Server &Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler(HandlerWithResponse handler) { + error_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler(Handler handler) { + error_handler_ = [handler](const Request &req, Response &res) { + handler(req, res); + return HandlerResponse::Handled; + }; + return *this; +} + +inline Server &Server::set_exception_handler(ExceptionHandler handler) { + exception_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_pre_routing_handler(HandlerWithResponse handler) { + pre_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_post_routing_handler(Handler handler) { + post_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_logger(Logger logger) { + logger_ = std::move(logger); + return *this; +} + +inline Server &Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_address_family(int family) { + address_family_ = family; + return *this; +} + +inline Server &Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; + return *this; +} + +inline Server &Server::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); + return *this; +} + +inline Server &Server::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); + return *this; +} + +inline Server &Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + return *this; +} + +inline Server &Server::set_keep_alive_timeout(time_t sec) { + keep_alive_timeout_sec_ = sec; + return *this; +} + +inline Server &Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_idle_interval(time_t sec, time_t usec) { + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; + return *this; +} + +inline Server &Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; + return *this; +} + +inline bool Server::bind_to_port(const std::string &host, int port, int socket_flags) { + if (bind_internal(host, port, socket_flags) < 0) { + return false; + } + return true; +} + +inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { + return bind_internal(host, 0, socket_flags); +} + +inline bool Server::listen_after_bind() { + auto se = detail::scope_exit([&]() { done_ = true; }); + return listen_internal(); +} + +inline bool Server::listen(const std::string &host, int port, int socket_flags) { + auto se = detail::scope_exit([&]() { done_ = true; }); + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { + return is_running_; +} + +inline void Server::wait_until_ready() const { + while (!is_running() && !done_) { + std::this_thread::sleep_for(std::chrono::milliseconds {1}); + } +} + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } +} + +inline bool Server::parse_request_line(const char *s, Request &req) { + auto len = strlen(s); + if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { + return false; + } + len -= 2; + + { + size_t count = 0; + + detail::split(s, s + len, ' ', [&](const char *b, const char *e) { + switch (count) { + case 0: + req.method = std::string(b, e); + break; + case 1: + req.target = std::string(b, e); + break; + case 2: + req.version = std::string(b, e); + break; + default: + break; + } + count++; + }); + + if (count != 3) { + return false; + } + } + + static const std::set + methods {"GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; + + if (methods.find(req.method) == methods.end()) { + return false; + } + + if (req.version != "HTTP/1.1" && req.version != "HTTP/1.0") { + return false; + } + + { + // Skip URL fragment + for (size_t i = 0; i < req.target.size(); i++) { + if (req.target[i] == '#') { + req.target.erase(i); + break; + } + } + + size_t count = 0; + + detail::split(req.target.data(), req.target.data() + req.target.size(), '?', [&](const char *b, const char *e) { + switch (count) { + case 0: + req.path = detail::decode_url(std::string(b, e), false); + break; + case 1: { + if (e - b > 0) { + detail::parse_query_text(std::string(b, e), req.params); + } + break; + } + default: + break; + } + count++; + }); + + if (count > 2) { + return false; + } + } + + return true; +} + +inline bool Server::write_response(Stream &strm, bool close_connection, const Request &req, Response &res) { + return write_response_core(strm, close_connection, req, res, false); +} + +inline bool Server::write_response_with_content(Stream &strm, + bool close_connection, + const Request &req, + Response &res) { + return write_response_core(strm, close_connection, req, res, true); +} + +inline bool Server::write_response_core(Stream &strm, + bool close_connection, + const Request &req, + Response &res, + bool need_apply_ranges) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_ && error_handler_(req, res) == HandlerResponse::Handled) { + need_apply_ranges = true; + } + + std::string content_type; + std::string boundary; + if (need_apply_ranges) { + apply_ranges(req, res, content_type, boundary); + } + + // Prepare additional headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::stringstream ss; + ss << "timeout=" << keep_alive_timeout_sec_ << ", max=" << keep_alive_max_count_; + res.set_header("Keep-Alive", ss.str()); + } + + if (!res.has_header("Content-Type") && (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { + res.set_header("Content-Type", "text/plain"); + } + + if (!res.has_header("Content-Length") && res.body.empty() && !res.content_length_ && !res.content_provider_) { + res.set_header("Content-Length", "0"); + } + + if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { + res.set_header("Accept-Ranges", "bytes"); + } + + if (post_routing_handler_) { + post_routing_handler_(req, res); + } + + // Response line and headers + { + detail::BufferStream bstrm; + + if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, detail::status_message(res.status))) { + return false; + } + + if (!detail::write_headers(bstrm, res.headers)) { + return false; + } + + // Flush buffer + auto &data = bstrm.get_buffer(); + detail::write_data(strm, data.data(), data.size()); + } + + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!detail::write_data(strm, res.body.data(), res.body.size())) { + ret = false; + } + } else if (res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, content_type)) { + res.content_provider_success_ = true; + } else { + res.content_provider_success_ = false; + ret = false; + } + } + } + + // Log + if (logger_) { + logger_(req, res); + } + + return ret; +} + +inline bool Server::write_content_with_provider(Stream &strm, + const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type) { + auto is_shutting_down = [this]() { return this->svr_sock_ == INVALID_SOCKET; }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + return detail::write_content(strm, res.content_provider_, 0, res.content_length_, is_shutting_down); + } else if (req.ranges.size() == 1) { + auto offsets = detail::get_range_offset_and_length(req, res.content_length_, 0); + auto offset = offsets.first; + auto length = offsets.second; + return detail::write_content(strm, res.content_provider_, offset, length, is_shutting_down); + } else { + return detail::write_multipart_ranges_data(strm, req, res, boundary, content_type, is_shutting_down); + } + } else { + if (res.is_chunked_content_provider_) { + auto type = detail::encoding_type(req, res); + + std::unique_ptr compressor; + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); +#endif + } else { + compressor = detail::make_unique(); + } + assert(compressor != nullptr); + + return detail::write_content_chunked(strm, res.content_provider_, is_shutting_down, *compressor); + } else { + return detail::write_content_without_length(strm, res.content_provider_, is_shutting_down); + } + } +} + +inline bool Server::read_content(Stream &strm, Request &req, Response &res) { + MultipartFormDataMap::iterator cur; + auto file_count = 0; + if (read_content_core( + strm, + req, + res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { + return false; + } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + if (file_count++ == CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT) { + return false; + } + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { + return false; + } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) { + res.status = 413; // NOTE: should be 414? + return false; + } + detail::parse_query_text(req.body, req.params); + } + return true; + } + return false; +} + +inline bool Server::read_content_with_content_receiver(Stream &strm, + Request &req, + Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, + req, + res, + std::move(receiver), + std::move(multipart_header), + std::move(multipart_receiver)); +} + +inline bool Server::read_content_core(Stream &strm, + Request &req, + Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiverWithProgress out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = 400; + return false; + } + + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { + /* For debug + size_t pos = 0; + while (pos < n) { + auto read_size = (std::min)(1, n - pos); + auto ret = multipart_form_data_parser.parse( + buf + pos, read_size, multipart_receiver, multipart_header); + if (!ret) { return false; } + pos += read_size; + } + return true; + */ + return multipart_form_data_parser.parse(buf, n, multipart_receiver, multipart_header); + }; + } else { + out = [receiver](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { return receiver(buf, n); }; + } + + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, out, true)) { + return false; + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = 400; + return false; + } + } + + return true; +} + +inline bool Server::handle_file_request(const Request &req, Response &res, bool head) { + for (const auto &entry : base_dirs_) { + // Prefix match + if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { + std::string sub_path = "/" + req.path.substr(entry.mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = entry.base_dir + sub_path; + if (path.back() == '/') { + path += "index.html"; + } + + if (detail::is_file(path)) { + detail::read_file(path, res.body); + auto type = detail::find_content_type(path, file_extension_and_mimetype_map_); + if (type) { + res.set_header("Content-Type", type); + } + for (const auto &kv : entry.headers) { + res.set_header(kv.first.c_str(), kv.second); + } + res.status = req.has_header("Range") ? 206 : 200; + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + return true; + } + } + } + } + return false; +} + +inline socket_t Server::create_server_socket(const std::string &host, + int port, + int socket_flags, + SocketOptions socket_options) const { + return detail::create_socket(host, + std::string(), + port, + address_family_, + socket_flags, + tcp_nodelay_, + std::move(socket_options), + [](socket_t sock, struct addrinfo &ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, CPPHTTPLIB_LISTEN_BACKLOG)) { + return false; + } + return true; + }); +} + +inline int Server::bind_internal(const std::string &host, int port, int socket_flags) { + if (!is_valid()) { + return -1; + } + + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { + return -1; + } + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + auto ret = true; + is_running_ = true; + auto se = detail::scope_exit([&]() { is_running_ = false; }); + + { + std::unique_ptr task_queue(new_task_queue()); + + while (svr_sock_ != INVALID_SOCKET) { +#ifndef _WIN32 + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { +#endif + auto val = detail::select_read(svr_sock_, idle_interval_sec_, idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } +#ifndef _WIN32 + } +#endif + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } else if (errno == EINTR || errno == EAGAIN) { + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } + + { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec_ * 1000 + read_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec_); + tv.tv_usec = static_cast(read_timeout_usec_); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + { + +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec_ * 1000 + write_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec_); + tv.tv_usec = static_cast(write_timeout_usec_); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, reinterpret_cast(&tv), sizeof(tv)); +#endif + } + + task_queue->enqueue([this, sock]() { process_and_close_socket(sock); }); + } + + task_queue->shutdown(); + } + + return ret; +} + +inline bool Server::routing(Request &req, Response &res, Stream &strm) { + if (pre_routing_handler_ && pre_routing_handler_(req, res) == HandlerResponse::Handled) { + return true; + } + + // File handler + auto is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && handle_file_request(req, res, is_head_request)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, std::move(receiver), nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, + req, + res, + nullptr, + std::move(header), + std::move(receiver)); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader(req, + res, + std::move(reader), + post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader(req, + res, + std::move(reader), + put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader(req, + res, + std::move(reader), + patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader(req, + res, + std::move(reader), + delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { + return false; + } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = 400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, const Handlers &handlers) { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res); + return true; + } + } + return false; +} + +inline void Server::apply_ranges(const Request &req, Response &res, std::string &content_type, std::string &boundary) { + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); + + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + res.set_header("Content-Type", "multipart/byteranges; boundary=" + boundary); + } + + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offsets = detail::get_range_offset_and_length(req, res.content_length_, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field(offset, length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider_) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } + } + } + } + } else { + if (req.ranges.empty()) { + ; + } else if (req.ranges.size() == 1) { + auto offsets = detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field(offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + if (offset < res.body.size()) { + res.body = res.body.substr(offset, length); + } else { + res.body.clear(); + res.status = 416; + } + } else { + std::string data; + if (detail::make_multipart_ranges_data(req, res, boundary, content_type, data)) { + res.body.swap(data); + } else { + res.body.clear(); + res.status = 416; + } + } + + if (type != detail::EncodingType::None) { + std::unique_ptr compressor; + std::string content_encoding; + + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); + content_encoding = "gzip"; +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); + content_encoding = "br"; +#endif + } + + if (compressor) { + std::string compressed; + if (compressor->compress(res.body.data(), + res.body.size(), + true, + [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + res.body.swap(compressed); + res.set_header("Content-Encoding", content_encoding); + } + } + } + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } +} + +inline bool Server::dispatch_request_for_content_reader(Request &req, + Response &res, + ContentReader content_reader, + const HandlersForContentReader &handlers) { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + +inline bool Server::process_request(Stream &strm, + bool close_connection, + bool &connection_closed, + const std::function &setup_request) { + std::array buf {}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { + return false; + } + + Request req; + Response res; + + res.version = "HTTP/1.1"; + + for (const auto &header : default_headers_) { + if (res.headers.find(header.first) == res.headers.end()) { + res.headers.insert(header); + } + } + +#ifdef _WIN32 + // TODO: Increase FD_SETSIZE statically (libzmq), dynamically (MySQL). +#else +#ifndef CPPHTTPLIB_USE_POLL + // Socket file descriptor exceeded FD_SETSIZE... + if (strm.socket() >= FD_SETSIZE) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 500; + return write_response(strm, close_connection, req, res); + } +#endif +#endif + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; + return write_response(strm, close_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, close_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } + + if (req.version == "HTTP/1.0" && req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } + + strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + + strm.get_local_ip_and_port(req.local_addr, req.local_port); + req.set_header("LOCAL_ADDR", req.local_addr); + req.set_header("LOCAL_PORT", std::to_string(req.local_port)); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + res.status = 416; + return write_response(strm, close_connection, req, res); + } + } + + if (setup_request) { + setup_request(req); + } + + if (req.get_header_value("Expect") == "100-continue") { + auto status = 100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case 100: + case 417: + strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, detail::status_message(status)); + break; + default: + return write_response(strm, close_connection, req, res); + } + } + + // Rounting + bool routed = false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + routed = routing(req, res, strm); +#else + try { + routed = routing(req, res, strm); + } + catch (std::exception &e) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = 500; + std::string val; + auto s = e.what(); + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case '\r': + val += "\\r"; + break; + case '\n': + val += "\\n"; + break; + default: + val += s[i]; + break; + } + } + res.set_header("EXCEPTION_WHAT", val); + } + } + catch (...) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = 500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + } +#endif + + if (routed) { + if (res.status == -1) { + res.status = req.ranges.empty() ? 200 : 206; + } + return write_response_with_content(strm, close_connection, req, res); + } else { + if (res.status == -1) { + res.status = 404; + } + return write_response(strm, close_connection, req, res); + } +} + +inline bool Server::is_valid() const { + return true; +} + +inline bool Server::process_and_close_socket(socket_t sock) { + auto ret = detail::process_server_socket( + svr_sock_, + sock, + keep_alive_max_count_, + keep_alive_timeout_sec_, + read_timeout_sec_, + read_timeout_usec_, + write_timeout_sec_, + write_timeout_usec_, + [this](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, nullptr); + }); + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// HTTP client implementation +inline ClientImpl::ClientImpl(const std::string &host) : ClientImpl(host, 80, std::string(), std::string()) { +} + +inline ClientImpl::ClientImpl(const std::string &host, int port) + : ClientImpl(host, port, std::string(), std::string()) { +} + +inline ClientImpl::ClientImpl(const std::string &host, + int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : host_(host), + port_(port), + host_and_port_(adjust_host_string(host) + ":" + std::to_string(port)), + client_cert_path_(client_cert_path), + client_key_path_(client_key_path) { +} + +inline ClientImpl::~ClientImpl() { + std::lock_guard guard(socket_mutex_); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline bool ClientImpl::is_valid() const { + return true; +} + +inline void ClientImpl::copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + url_encode_ = rhs.url_encode_; + address_family_ = rhs.address_family_; + tcp_nodelay_ = rhs.tcp_nodelay_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + ca_cert_store_ = rhs.ca_cert_store_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + server_certificate_verification_ = rhs.server_certificate_verification_; +#endif + logger_ = rhs.logger_; +} + +inline socket_t ClientImpl::create_client_socket(Error &error) const { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket(proxy_host_, + std::string(), + proxy_port_, + address_family_, + tcp_nodelay_, + socket_options_, + connection_timeout_sec_, + connection_timeout_usec_, + read_timeout_sec_, + read_timeout_usec_, + write_timeout_sec_, + write_timeout_usec_, + interface_, + error); + } + + // Check is custom IP specified for host_ + std::string ip; + auto it = addr_map_.find(host_); + if (it != addr_map_.end()) { + ip = it->second; + } + + return detail::create_client_socket(host_, + ip, + port_, + address_family_, + tcp_nodelay_, + socket_options_, + connection_timeout_sec_, + connection_timeout_usec_, + read_timeout_sec_, + read_timeout_usec_, + write_timeout_sec_, + write_timeout_usec_, + interface_, + error); +} + +inline bool ClientImpl::create_and_connect_socket(Socket &socket, Error &error) { + auto sock = create_client_socket(error); + if (sock == INVALID_SOCKET) { + return false; + } + + if (_postConnCb) { + _postConnCb(sock); + } + socket.sock = sock; + return true; +} + +inline void ClientImpl::shutdown_ssl(Socket & /*socket*/, bool /*shutdown_gracefully*/) { + // If there are any requests in flight from threads other than us, then it's + // a thread-unsafe race because individual ssl* objects are not thread-safe. + assert(socket_requests_in_flight_ == 0 || socket_requests_are_from_thread_ == std::this_thread::get_id()); +} + +inline void ClientImpl::shutdown_socket(Socket &socket) { + if (socket.sock == INVALID_SOCKET) { + return; + } + detail::shutdown_socket(socket.sock); +} + +inline void ClientImpl::close_socket(Socket &socket) { + // If there are requests in flight in another thread, usually closing + // the socket will be fine and they will simply receive an error when + // using the closed socket, but it is still a bug since rarely the OS + // may reassign the socket id to be used for a new socket, and then + // suddenly they will be operating on a live socket that is different + // than the one they intended! + assert(socket_requests_in_flight_ == 0 || socket_requests_are_from_thread_ == std::this_thread::get_id()); + + // It is also a bug if this happens while SSL is still active +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + assert(socket.ssl == nullptr); +#endif + if (socket.sock == INVALID_SOCKET) { + return; + } + detail::close_socket(socket.sock); + socket.sock = INVALID_SOCKET; +} + +inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, Response &res) { + std::array buf {}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { + return false; + } + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); +#else + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); +#endif + + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return req.method == "CONNECT"; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + + // Ignore '100 Continue' + while (res.status == 100) { + if (!line_reader.getline()) { + return false; + } // CRLF + if (!line_reader.getline()) { + return false; + } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { + return false; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; +} + +inline bool ClientImpl::send(Request &req, Response &res, Error &error) { + std::lock_guard request_mutex_guard(request_mutex_); + auto ret = send_(req, res, error); + if (error == Error::SSLPeerCouldBeClosed_) { + assert(!ret); + ret = send_(req, res, error); + } + return ret; +} + +inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { + { + std::lock_guard guard(socket_mutex_); + + // Set this to false immediately - if it ever gets set to true by the end of + // the request, we know another thread instructed us to close the socket. + socket_should_be_closed_when_request_is_done_ = false; + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); + if (!is_alive) { + // Attempt to avoid sigpipe by shutting down nongracefully if it seems + // like the other side has already closed the connection Also, there + // cannot be any requests in flight from other threads since we locked + // request_mutex_, so safe to close everything immediately + const bool shutdown_gracefully = false; + shutdown_ssl(socket_, shutdown_gracefully); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!create_and_connect_socket(socket_, error)) { + return false; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + auto success = false; + if (!scli.connect_with_proxy(socket_, res, success, error)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_, error)) { + return false; + } + } +#endif + } + + // Mark the current socket as being in use so that it cannot be closed by + // anyone else while this request is ongoing, even though we will be + // releasing the mutex. + if (socket_requests_in_flight_ > 1) { + assert(socket_requests_are_from_thread_ == std::this_thread::get_id()); + } + socket_requests_in_flight_ += 1; + socket_requests_are_from_thread_ = std::this_thread::get_id(); + } + + for (const auto &header : default_headers_) { + if (req.headers.find(header.first) == req.headers.end()) { + req.headers.insert(header); + } + } + + auto ret = false; + auto close_connection = !keep_alive_; + + auto se = detail::scope_exit([&]() { + // Briefly lock mutex in order to mark that a request is no longer ongoing + std::lock_guard guard(socket_mutex_); + socket_requests_in_flight_ -= 1; + if (socket_requests_in_flight_ <= 0) { + assert(socket_requests_in_flight_ == 0); + socket_requests_are_from_thread_ = std::thread::id(); + } + + if (socket_should_be_closed_when_request_is_done_ || close_connection || !ret) { + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + }); + + ret = process_socket(socket_, + [&](Stream &strm) { return handle_request(strm, req, res, close_connection, error); }); + + if (!ret) { + if (error == Error::Success) { + error = Error::Unknown; + } + } + + return ret; +} + +inline Result ClientImpl::send(const Request &req) { + auto req2 = req; + return send_(std::move(req2)); +} + +inline Result ClientImpl::send_(Request &&req) { + auto res = detail::make_unique(); + auto error = Error::Success; + auto ret = send(req, *res, error); + return Result {ret ? std::move(res) : nullptr, error, std::move(req.headers)}; +} + +inline bool ClientImpl::handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error) { + if (req.path.empty()) { + error = Error::Connection; + return false; + } + + auto req_save = req; + + bool ret; + + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection, error); + req = req2; + req.path = req_save.path; + } else { + ret = process_request(strm, req, res, close_connection, error); + } + + if (!ret) { + return false; + } + + if (300 < res.status && res.status < 400 && follow_location_) { + req = req_save; + ret = redirect(req, res, error); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if ((res.status == 401 || res.status == 407) && req.authorization_count_ < 5) { + auto is_proxy = res.status == 407; + const auto &username = is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + new_req.headers.erase(is_proxy ? "Proxy-Authorization" : "Authorization"); + new_req.headers.insert(detail::make_digest_authentication_header(req, + auth, + new_req.authorization_count_, + detail::random_string(10), + username, + password, + is_proxy)); + + Response new_res; + + ret = send(new_req, new_res, error); + if (ret) { + res = new_res; + } + } + } + } +#endif + + return ret; +} + +inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { + if (req.redirect_count_ == 0) { + error = Error::ExceedRedirectCount; + return false; + } + + auto location = res.get_header_value("location"); + if (location.empty()) { + return false; + } + + const static std::regex re( + R"((?:(https?):)?(?://(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + + std::smatch m; + if (!std::regex_match(location, m, re)) { + return false; + } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + if (next_host.empty()) { + next_host = m[3].str(); + } + auto port_str = m[4].str(); + auto next_path = m[5].str(); + auto next_query = m[6].str(); + + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } + + if (next_scheme.empty()) { + next_scheme = scheme; + } + if (next_host.empty()) { + next_host = host_; + } + if (next_path.empty()) { + next_path = "/"; + } + + auto path = detail::decode_url(next_path, true) + next_query; + + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, path, location, error); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + if (ca_cert_store_) { + cli.set_ca_cert_store(ca_cert_store_); + } + return detail::redirect(cli, req, res, path, location, error); +#else + return false; +#endif + } else { + ClientImpl cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, path, location, error); + } + } +} + +inline bool ClientImpl::write_content_with_provider(Stream &strm, const Request &req, Error &error) { + auto is_shutting_down = []() { return false; }; + + if (req.is_chunked_content_provider_) { + // TODO: Brotli support + std::unique_ptr compressor; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + compressor = detail::make_unique(); + } else +#endif + { + compressor = detail::make_unique(); + } + + return detail::write_content_chunked(strm, req.content_provider_, is_shutting_down, *compressor, error); + } else { + return detail::write_content(strm, req.content_provider_, 0, req.content_length_, is_shutting_down, error); + } +} + +inline bool ClientImpl::write_request(Stream &strm, Request &req, bool close_connection, Error &error) { + // Prepare additional headers + if (close_connection) { + if (!req.has_header("Connection")) { + req.set_header("Connection", "close"); + } + } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } else { + if (port_ == 80) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { + req.set_header("Accept", "*/*"); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!req.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.set_header("User-Agent", agent); + } +#endif + + if (req.body.empty()) { + if (req.content_provider_) { + if (!req.is_chunked_content_provider_) { + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.content_length_); + req.set_header("Content-Length", length); + } + } + } else { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH") { + req.set_header("Content-Length", "0"); + } + } + } else { + if (!req.has_header("Content-Type")) { + req.set_header("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + req.set_header("Content-Length", length); + } + } + + if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_basic_authentication_header(basic_auth_username_, basic_auth_password_, false)); + } + } + + if (!proxy_basic_auth_username_.empty() && !proxy_basic_auth_password_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert( + make_basic_authentication_header(proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + } + + if (!bearer_token_auth_token_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_bearer_token_authentication_header(bearer_token_auth_token_, false)); + } + } + + if (!proxy_bearer_token_auth_token_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_bearer_token_authentication_header(proxy_bearer_token_auth_token_, true)); + } + } + + // Request line and headers + { + detail::BufferStream bstrm; + + const auto &path = url_encode_ ? detail::encode_url(req.path) : req.path; + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + + detail::write_headers(bstrm, req.headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error = Error::Write; + return false; + } + } + + // Body + if (req.body.empty()) { + return write_content_with_provider(strm, req, error); + } + + if (!detail::write_data(strm, req.body.data(), req.body.size())) { + error = Error::Write; + return false; + } + + return true; +} + +inline std::unique_ptr ClientImpl::send_with_content_provider( + Request &req, + const char *body, + size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, + Error &error) { + if (!content_type.empty()) { + req.set_header("Content-Type", content_type); + } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + req.set_header("Content-Encoding", "gzip"); + } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_ && !content_provider_without_length) { + // TODO: Brotli support + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + DataSink data_sink; + + data_sink.write = [&](const char *data, size_t data_len) -> bool { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = compressor.compress(data, + data_len, + last, + [&](const char *compressed_data, size_t compressed_data_len) { + req.body.append(compressed_data, compressed_data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + return ok; + }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body, content_length, true, [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + error = Error::Compression; + return nullptr; + } + } + } else +#endif + { + if (content_provider) { + req.content_length_ = content_length; + req.content_provider_ = std::move(content_provider); + req.is_chunked_content_provider_ = false; + } else if (content_provider_without_length) { + req.content_length_ = 0; + req.content_provider_ = detail::ContentProviderAdapter(std::move(content_provider_without_length)); + req.is_chunked_content_provider_ = true; + req.set_header("Transfer-Encoding", "chunked"); + } else { + req.body.assign(body, content_length); + } + } + + auto res = detail::make_unique(); + return send(req, *res, error) ? std::move(res) : nullptr; +} + +inline Result ClientImpl::send_with_content_provider(const std::string &method, + const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + + auto error = Error::Success; + + auto res = send_with_content_provider(req, + body, + content_length, + std::move(content_provider), + std::move(content_provider_without_length), + content_type, + error); + + return Result {std::move(res), error, std::move(req.headers)}; +} + +inline std::string ClientImpl::adjust_host_string(const std::string &host) const { + if (host.find(':') != std::string::npos) { + return "[" + host + "]"; + } + return host; +} + +inline bool ClientImpl::process_request(Stream &strm, + Request &req, + Response &res, + bool close_connection, + Error &error) { + // Send request + if (!write_request(strm, req, close_connection, error)) { + return false; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; + if (!is_proxy_enabled) { + char buf[1]; + if (SSL_peek(socket_.ssl, buf, 1) == 0 && SSL_get_error(socket_.ssl, 0) == SSL_ERROR_ZERO_RETURN) { + error = Error::SSLPeerCouldBeClosed_; + return false; + } + } + } +#endif + + // Receive response and headers + if (!read_response_line(strm, req, res) || !detail::read_headers(strm, res.headers)) { + error = Error::Read; + return false; + } + + // Body + if ((res.status != 204) && req.method != "HEAD" && req.method != "CONNECT") { + auto redirect = 300 < res.status && res.status < 400 && follow_location_; + + if (req.response_handler && !redirect) { + if (!req.response_handler(res)) { + error = Error::Canceled; + return false; + } + } + + auto out = req.content_receiver + ? static_cast([&](const char *buf, size_t n, uint64_t off, uint64_t len) { + if (redirect) { + return true; + } + auto ret = req.content_receiver(buf, n, off, len); + if (!ret) { + error = Error::Canceled; + } + return ret; + }) + : static_cast( + [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { + if (res.body.size() + n > res.body.max_size()) { + return false; + } + res.body.append(buf, n); + return true; + }); + + auto progress = [&](uint64_t current, uint64_t total) { + if (!req.progress || redirect) { + return true; + } + auto ret = req.progress(current, total); + if (!ret) { + error = Error::Canceled; + } + return ret; + }; + + int dummy_status; + if (!detail::read_content(strm, + res, + (std::numeric_limits::max)(), + dummy_status, + std::move(progress), + std::move(out), + decompress_)) { + if (error != Error::Canceled) { + error = Error::Read; + } + return false; + } + } + + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + // TODO this requires a not-entirely-obvious chain of calls to be correct + // for this to be safe. Maybe a code refactor (such as moving this out to + // the send function and getting rid of the recursiveness of the mutex) + // could make this more obvious. + + // This is safe to call because process_request is only called by + // handle_request which is only called by send, which locks the request + // mutex during the process. It would be a bug to call it from a different + // thread since it's a thread-safety issue to do these things to the socket + // if another thread is using the socket. + std::lock_guard guard(socket_mutex_); + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + + // Log + if (logger_) { + logger_(req, res); + } + + return true; +} + +inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider( + const std::string &boundary, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + size_t cur_item = 0, cur_start = 0; + // cur_item and cur_start are copied to within the std::function and maintain + // state between successive calls + return [&, cur_item, cur_start](size_t offset, DataSink &sink) mutable -> bool { + if (!offset && items.size()) { + sink.os << detail::serialize_multipart_formdata(items, boundary, false); + return true; + } else if (cur_item < provider_items.size()) { + if (!cur_start) { + const auto &begin = detail::serialize_multipart_formdata_item_begin(provider_items[cur_item], boundary); + offset += begin.size(); + cur_start = offset; + sink.os << begin; + } + + DataSink cur_sink; + bool has_data = true; + cur_sink.write = sink.write; + cur_sink.done = [&]() { has_data = false; }; + + if (!provider_items[cur_item].provider(offset - cur_start, cur_sink)) { + return false; + } + + if (!has_data) { + sink.os << detail::serialize_multipart_formdata_item_end(); + cur_item++; + cur_start = 0; + } + return true; + } else { + sink.os << detail::serialize_multipart_formdata_finish(boundary); + sink.done(); + return true; + } + }; +} + +inline bool ClientImpl::process_socket(const Socket &socket, std::function callback) { + return detail::process_client_socket(socket.sock, + read_timeout_sec_, + read_timeout_usec_, + write_timeout_sec_, + write_timeout_usec_, + std::move(callback)); +} + +inline bool ClientImpl::is_ssl() const { + return false; +} + +inline Result ClientImpl::Get(const std::string &path) { + return Get(path, Headers(), Progress()); +} + +inline Result ClientImpl::Get(const std::string &path, Progress progress) { + return Get(path, Headers(), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers) { + return Get(path, headers, Progress()); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, ContentReceiver content_receiver, Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, Headers(), std::move(response_handler), std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), std::move(response_handler), std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = + [content_receiver](const char *data, size_t data_length, uint64_t /*offset*/, uint64_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.progress = std::move(progress); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, + const Params ¶ms, + const Headers &headers, + Progress progress) { + if (params.empty()) { + return Get(path, headers); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query.c_str(), headers, progress); +} + +inline Result ClientImpl::Get(const std::string &path, + const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, params, headers, nullptr, content_receiver, progress); +} + +inline Result ClientImpl::Get(const std::string &path, + const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + if (params.empty()) { + return Get(path, headers, response_handler, content_receiver, progress); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query.c_str(), headers, response_handler, content_receiver, progress); +} + +inline Result ClientImpl::Head(const std::string &path) { + return Head(path, Headers()); +} + +inline Result ClientImpl::Head(const std::string &path, const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + + return send_(std::move(req)); +} + +inline Result ClientImpl::Post(const std::string &path) { + return Post(path, std::string(), std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers) { + return Post(path, headers, nullptr, 0, std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, + const char *body, + size_t content_length, + const std::string &content_type) { + return Post(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body, content_length, nullptr, nullptr, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, const std::string &content_type) { + return Post(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body.data(), body.size(), nullptr, nullptr, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Params ¶ms) { + return Post(path, Headers(), params); +} + +inline Result ClientImpl::Post(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Post(path, Headers(), content_length, std::move(content_provider), content_type); +} + +inline Result ClientImpl::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Post(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("POST", + path, + headers, + nullptr, + content_length, + std::move(content_provider), + nullptr, + content_type); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("POST", + path, + headers, + nullptr, + 0, + nullptr, + std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Post(const std::string &path, const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type.c_str()); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result {nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type.c_str()); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider("POST", + path, + headers, + nullptr, + 0, + nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type); +} + +inline Result ClientImpl::Put(const std::string &path) { + return Put(path, std::string(), std::string()); +} + +inline Result ClientImpl::Put(const std::string &path, + const char *body, + size_t content_length, + const std::string &content_type) { + return Put(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body, content_length, nullptr, nullptr, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, const std::string &content_type) { + return Put(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body.data(), body.size(), nullptr, nullptr, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Put(path, Headers(), content_length, std::move(content_provider), content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Put(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("PUT", + path, + headers, + nullptr, + content_length, + std::move(content_provider), + nullptr, + content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("PUT", + path, + headers, + nullptr, + 0, + nullptr, + std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Params ¶ms) { + return Put(path, Headers(), params); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Put(const std::string &path, const MultipartFormDataItems &items) { + return Put(path, Headers(), items); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result {nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider("PUT", + path, + headers, + nullptr, + 0, + nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type); +} + +inline Result ClientImpl::Patch(const std::string &path) { + return Patch(path, std::string(), std::string()); +} + +inline Result ClientImpl::Patch(const std::string &path, + const char *body, + size_t content_length, + const std::string &content_type) { + return Patch(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, body, content_length, nullptr, nullptr, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const std::string &body, const std::string &content_type) { + return Patch(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, body.data(), body.size(), nullptr, nullptr, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Patch(path, Headers(), content_length, std::move(content_provider), content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Patch(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("PATCH", + path, + headers, + nullptr, + content_length, + std::move(content_provider), + nullptr, + content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("PATCH", + path, + headers, + nullptr, + 0, + nullptr, + std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Delete(const std::string &path) { + return Delete(path, Headers(), std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, const Headers &headers) { + return Delete(path, headers, std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, + const char *body, + size_t content_length, + const std::string &content_type) { + return Delete(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + + if (!content_type.empty()) { + req.set_header("Content-Type", content_type); + } + req.body.assign(body, content_length); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Delete(const std::string &path, const std::string &body, const std::string &content_type) { + return Delete(path, Headers(), body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return Delete(path, headers, body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Options(const std::string &path) { + return Options(path, Headers()); +} + +inline Result ClientImpl::Options(const std::string &path, const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.headers = headers; + req.path = path; + + return send_(std::move(req)); +} + +inline void ClientImpl::stop() { + std::lock_guard guard(socket_mutex_); + + // If there is anything ongoing right now, the ONLY thread-safe thing we can + // do is to shutdown_socket, so that threads using this socket suddenly + // discover they can't read/write any more and error out. Everything else + // (closing the socket, shutting ssl down) is unsafe because these actions are + // not thread-safe. + if (socket_requests_in_flight_ > 0) { + shutdown_socket(socket_); + + // Aside from that, we set a flag for the socket to be closed when we're + // done. + socket_should_be_closed_when_request_is_done_ = true; + return; + } + + // Otherwise, still holding the mutex, we can shut everything down ourselves + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline std::string ClientImpl::host() const { + return host_; +} + +inline int ClientImpl::port() const { + return port_; +} + +inline size_t ClientImpl::is_socket_open() const { + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); +} + +inline socket_t ClientImpl::socket() const { + return socket_.sock; +} + +inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; +} + +inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +inline void ClientImpl::set_basic_auth(const std::string &username, const std::string &password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +inline void ClientImpl::set_bearer_token_auth(const std::string &token) { + bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_digest_auth(const std::string &username, const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void ClientImpl::set_keep_alive(bool on) { + keep_alive_ = on; +} + +inline void ClientImpl::set_follow_location(bool on) { + follow_location_ = on; +} + +inline void ClientImpl::set_url_encode(bool on) { + url_encode_ = on; +} + +inline void ClientImpl::set_hostname_addr_map(std::map addr_map) { + addr_map_ = std::move(addr_map); +} + +inline void ClientImpl::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); +} + +inline void ClientImpl::set_address_family(int family) { + address_family_ = family; +} + +inline void ClientImpl::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; +} + +inline void ClientImpl::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); +} + +inline void ClientImpl::set_compress(bool on) { + compress_ = on; +} + +inline void ClientImpl::set_decompress(bool on) { + decompress_ = on; +} + +inline void ClientImpl::set_interface(const std::string &intf) { + interface_ = intf; +} + +inline void ClientImpl::set_proxy(const std::string &host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void ClientImpl::set_post_connect_callback(detail::POSTSOCKETCONNECTCB cb) { + _postConnCb = cb; +} + +inline void ClientImpl::set_proxy_basic_auth(const std::string &username, const std::string &password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +inline void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { + proxy_bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_proxy_digest_auth(const std::string &username, const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} + +inline void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path) { + ca_cert_file_path_ = ca_cert_file_path; + ca_cert_dir_path_ = ca_cert_dir_path; +} + +inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store && ca_cert_store != ca_cert_store_) { + ca_cert_store_ = ca_cert_store; + } +} + +inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, std::size_t size) { + auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); + if (!mem) { + return nullptr; + } + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { + BIO_free_all(mem); + return nullptr; + } + + auto cts = X509_STORE_new(); + if (cts) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { + continue; + } + + if (itmp->x509) { + X509_STORE_add_cert(cts, itmp->x509); + } + if (itmp->crl) { + X509_STORE_add_crl(cts, itmp->crl); + } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + BIO_free_all(mem); + return cts; +} + +inline void ClientImpl::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} +#endif + +inline void ClientImpl::set_logger(Logger logger) { + logger_ = std::move(logger); +} + +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup) { + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (ssl) { + set_nonblocking(sock, true); + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + BIO_set_nbio(bio, 1); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + set_nonblocking(sock, false); + return nullptr; + } + BIO_set_nbio(bio, 0); + set_nonblocking(sock, false); + } + + return ssl; +} + +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, bool shutdown_gracefully) { + // sometimes we may want to skip this to try to avoid SIGPIPE if we know + // the remote has closed the network connection + // Note that it is not always possible to avoid SIGPIPE, this is merely a + // best-efforts. + if (shutdown_gracefully) { + SSL_shutdown(ssl); + } + + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); +} + +template +bool ssl_connect_or_accept_nonblocking(socket_t sock, + SSL *ssl, + U ssl_connect_or_accept, + time_t timeout_sec, + time_t timeout_usec) { + auto res = 0; + while ((res = ssl_connect_or_accept(ssl)) != 1) { + auto err = SSL_get_error(ssl, res); + switch (err) { + case SSL_ERROR_WANT_READ: + if (select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + case SSL_ERROR_WANT_WRITE: + if (select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + default: + break; + } + return false; + } + return true; +} + +template +inline bool process_server_socket_ssl(const std::atomic &svr_sock, + SSL *ssl, + socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + T callback) { + return process_server_socket_core( + svr_sock, + sock, + keep_alive_max_count, + keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +inline bool process_client_socket_ssl(SSL *ssl, + socket_t sock, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + T callback) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, write_timeout_sec, write_timeout_usec); + return callback(strm); +} + +class SSLInit { +public: + SSLInit() { + OPENSSL_init_ssl(OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); + } +}; + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream(socket_t sock, + SSL *ssl, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec) + : sock_(sock), + ssl_(ssl), + read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec) { + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); +} + +inline SSLSocketStream::~SSLSocketStream() { +} + +inline bool SSLSocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SSLSocketStream::is_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_); +} + +inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && + (err == SSL_ERROR_WANT_READ || (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { +#endif + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret >= 0) { + return ret; + } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { + auto handle_size = static_cast(std::min(size, (std::numeric_limits::max)())); + + auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && + (err == SSL_ERROR_WANT_WRITE || (err == SSL_ERROR_SYSCALL && WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { +#endif + if (is_writable()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret >= 0) { + return ret; + } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SSLSocketStream::socket() const { + return sock_; +} + +static SSLInit sslinit_; + +} // namespace detail + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char *cert_path, + const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, + const char *private_key_password) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); + + // add default password callback before opening encrypted private key + if (private_key_password != nullptr && (private_key_password[0] != '\0')) { + SSL_CTX_set_default_passwd_cb_userdata(ctx_, + reinterpret_cast(const_cast(private_key_password))); + } + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, client_ca_cert_dir_path); + + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, X509_STORE *client_ca_cert_store) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, SSL_OP_NO_COMPRESSION | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(const std::function &setup_ssl_ctx_callback) { + ctx_ = SSL_CTX_new(TLS_method()); + if (ctx_) { + if (!setup_ssl_ctx_callback(*ctx_)) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { + SSL_CTX_free(ctx_); + } +} + +inline bool SSLServer::is_valid() const { + return ctx_; +} + +inline SSL_CTX *SSLServer::ssl_context() const { + return ctx_; +} + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + auto ssl = detail::ssl_new( + sock, + ctx_, + ctx_mutex_, + [&](SSL *ssl2) { + return detail::ssl_connect_or_accept_nonblocking(sock, + ssl2, + SSL_accept, + read_timeout_sec_, + read_timeout_usec_); + }, + [](SSL * /*ssl2*/) { return true; }); + + auto ret = false; + if (ssl) { + ret = detail::process_server_socket_ssl( + svr_sock_, + ssl, + sock, + keep_alive_max_count_, + keep_alive_timeout_sec_, + read_timeout_sec_, + read_timeout_usec_, + write_timeout_sec_, + write_timeout_usec_, + [this, ssl](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, [&](Request &req) { req.ssl = ssl; }); + }); + + // Shutdown gracefully if the result seemed successful, non-gracefully if + // the connection appeared to be closed. + const bool shutdown_gracefully = ret; + detail::ssl_delete(ctx_mutex_, ssl, shutdown_gracefully); + } + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const std::string &host) : SSLClient(host, 443, std::string(), std::string()) { +} + +inline SSLClient::SSLClient(const std::string &host, int port) : SSLClient(host, port, std::string(), std::string()) { +} + +inline SSLClient::SSLClient(const std::string &host, + int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : ClientImpl(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::SSLClient(const std::string &host, int port, X509 *client_cert, EVP_PKEY *client_key) + : ClientImpl(host, port) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + + if (client_cert != nullptr && client_key != nullptr) { + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::~SSLClient() { + if (ctx_) { + SSL_CTX_free(ctx_); + } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +inline bool SSLClient::is_valid() const { + return ctx_; +} + +inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { + if (ctx_) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { + // Free memory allocated for old cert and use new store `ca_cert_store` + SSL_CTX_set_cert_store(ctx_, ca_cert_store); + } + } else { + X509_STORE_free(ca_cert_store); + } + } +} + +inline void SSLClient::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); +} + +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const { + return ctx_; +} + +inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + return is_valid() && ClientImpl::create_and_connect_socket(socket, error); +} + +// Assumes that socket_mutex_ is locked and that there are no requests in flight +inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, bool &success, Error &error) { + success = true; + Response res2; + if (!detail::process_client_socket(socket.sock, + read_timeout_sec_, + read_timeout_usec_, + write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res2, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (res2.status == 407) { + if (!proxy_digest_auth_username_.empty() && !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res2, auth, true)) { + Response res3; + if (!detail::process_client_socket(socket.sock, + read_timeout_sec_, + read_timeout_usec_, + write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, + auth, + 1, + detail::random_string(10), + proxy_digest_auth_username_, + proxy_digest_auth_password_, + true)); + return process_request(strm, req3, res3, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } else { + res = res2; + return false; + } + } + + return true; +} + +inline bool SSLClient::load_certs() { + bool ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), nullptr)) { + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, ca_cert_dir_path_.c_str())) { + ret = false; + } + } else { + auto loaded = false; +#ifdef _WIN32 + loaded = detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX + loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); +#endif // TARGET_OS_OSX +#endif // _WIN32 + if (!loaded) { + SSL_CTX_set_default_verify_paths(ctx_); + } + } + }); + + return ret; +} + +inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + auto ssl = detail::ssl_new( + socket.sock, + ctx_, + ctx_mutex_, + [&](SSL *ssl2) { + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); + } + + if (!detail::ssl_connect_or_accept_nonblocking(socket.sock, + ssl2, + SSL_connect, + connection_timeout_sec_, + connection_timeout_usec_)) { + error = Error::SSLConnection; + return false; + } + + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl2); + + if (verify_result_ != X509_V_OK) { + error = Error::SSLServerVerification; + return false; + } + + auto server_cert = SSL_get1_peer_certificate(ssl2); + + if (server_cert == nullptr) { + error = Error::SSLServerVerification; + return false; + } + + if (!verify_host(server_cert)) { + X509_free(server_cert); + error = Error::SSLServerVerification; + return false; + } + X509_free(server_cert); + } + + return true; + }, + [&](SSL *ssl2) { + // NOTE: With -Wold-style-cast, this can produce a warning, since + // SSL_set_tlsext_host_name is a macro (in OpenSSL), which contains + // an old style cast. Short of doing compiler specific pragma's + // here, we can't get rid of this warning. :'( + SSL_set_tlsext_host_name(ssl2, host_.c_str()); + return true; + }); + + if (ssl) { + socket.ssl = ssl; + return true; + } + + shutdown_socket(socket); + close_socket(socket); + return false; +} + +inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +inline void SSLClient::shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, shutdown_gracefully); + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +inline bool SSLClient::process_socket(const Socket &socket, std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl(socket.ssl, + socket.sock, + read_timeout_sec_, + read_timeout_usec_, + write_timeout_sec_, + write_timeout_usec_, + std::move(callback)); +} + +inline bool SSLClient::is_ssl() const { + return true; +} + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || verify_host_with_common_name(server_cert); +} + +inline bool SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: + dsn_matched = check_host_name(name, name_len); + break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + } + + if (dsn_matched || ip_matched) { + ret = true; + } + } + + GENERAL_NAMES_free( + const_cast(reinterpret_cast(alt_names))); + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { + return true; + } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', [&](const char *b, const char *e) { + pattern_components.emplace_back(std::string(b, e)); + }); + + if (host_components_.size() != pattern_components.size()) { + return false; + } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && !p.compare(0, p.size() - 1, h)); + if (!partial_match) { + return false; + } + } + ++itr; + } + + return true; +} +#endif + +// Universal client implementation +inline Client::Client(const std::string &scheme_host_port) : Client(scheme_host_port, std::string(), std::string()) { +} + +inline Client::Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path) { + const static std::regex re(R"((?:([a-z]+):\/\/)?(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + + std::smatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { +#else + if (!scheme.empty() && scheme != "http") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "https"; + + auto host = m[2].str(); + if (host.empty()) { + host = m[3].str(); + } + + auto port_str = m[4].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + + if (is_ssl) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); + is_ssl_ = is_ssl; +#endif + } else { + cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); + } + } else { + cli_ = detail::make_unique(scheme_host_port, 80, client_cert_path, client_key_path); + } +} + +inline Client::Client(const std::string &host, int port) : cli_(detail::make_unique(host, port)) { +} + +inline Client::Client(const std::string &host, + int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : cli_(detail::make_unique(host, port, client_cert_path, client_key_path)) { +} + +inline Client::~Client() { +} + +inline bool Client::is_valid() const { + return cli_ != nullptr && cli_->is_valid(); +} + +inline Result Client::Get(const std::string &path) { + return cli_->Get(path); +} + +inline Result Client::Get(const std::string &path, const Headers &headers) { + return cli_->Get(path, headers); +} + +inline Result Client::Get(const std::string &path, Progress progress) { + return cli_->Get(path, std::move(progress)); +} + +inline Result Client::Get(const std::string &path, const Headers &headers, Progress progress) { + return cli_->Get(path, headers, std::move(progress)); +} + +inline Result Client::Get(const std::string &path, ContentReceiver content_receiver) { + return cli_->Get(path, std::move(content_receiver)); +} + +inline Result Client::Get(const std::string &path, const Headers &headers, ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(content_receiver)); +} + +inline Result Client::Get(const std::string &path, ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, std::move(content_receiver), std::move(progress)); +} + +inline Result Client::Get(const std::string &path, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return cli_->Get(path, headers, std::move(content_receiver), std::move(progress)); +} + +inline Result Client::Get(const std::string &path, ResponseHandler response_handler, ContentReceiver content_receiver) { + return cli_->Get(path, std::move(response_handler), std::move(content_receiver)); +} + +inline Result Client::Get(const std::string &path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(response_handler), std::move(content_receiver)); +} + +inline Result Client::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + return cli_->Get(path, std::move(response_handler), std::move(content_receiver), std::move(progress)); +} + +inline Result Client::Get(const std::string &path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + return cli_->Get(path, headers, std::move(response_handler), std::move(content_receiver), std::move(progress)); +} + +inline Result Client::Get(const std::string &path, const Params ¶ms, const Headers &headers, Progress progress) { + return cli_->Get(path, params, headers, progress); +} + +inline Result Client::Get(const std::string &path, + const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return cli_->Get(path, params, headers, content_receiver, progress); +} + +inline Result Client::Get(const std::string &path, + const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + return cli_->Get(path, params, headers, response_handler, content_receiver, progress); +} + +inline Result Client::Head(const std::string &path) { + return cli_->Head(path); +} + +inline Result Client::Head(const std::string &path, const Headers &headers) { + return cli_->Head(path, headers); +} + +inline Result Client::Post(const std::string &path) { + return cli_->Post(path); +} + +inline Result Client::Post(const std::string &path, const Headers &headers) { + return cli_->Post(path, headers); +} + +inline Result Client::Post(const std::string &path, + const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Post(path, body, content_length, content_type); +} + +inline Result Client::Post(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_length, content_type); +} + +inline Result Client::Post(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Post(path, body, content_type); +} + +inline Result Client::Post(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_type); +} + +inline Result Client::Post(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Post(path, content_length, std::move(content_provider), content_type); +} + +inline Result Client::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Post(path, std::move(content_provider), content_type); +} + +inline Result Client::Post(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Post(path, headers, content_length, std::move(content_provider), content_type); +} + +inline Result Client::Post(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Post(path, headers, std::move(content_provider), content_type); +} + +inline Result Client::Post(const std::string &path, const Params ¶ms) { + return cli_->Post(path, params); +} + +inline Result Client::Post(const std::string &path, const Headers &headers, const Params ¶ms) { + return cli_->Post(path, headers, params); +} + +inline Result Client::Post(const std::string &path, const MultipartFormDataItems &items) { + return cli_->Post(path, items); +} + +inline Result Client::Post(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + return cli_->Post(path, headers, items); +} + +inline Result Client::Post(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Post(path, headers, items, boundary); +} + +inline Result Client::Post(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Post(path, headers, items, provider_items); +} + +inline Result Client::Put(const std::string &path) { + return cli_->Put(path); +} + +inline Result Client::Put(const std::string &path, + const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Put(path, body, content_length, content_type); +} + +inline Result Client::Put(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_length, content_type); +} + +inline Result Client::Put(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Put(path, body, content_type); +} + +inline Result Client::Put(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_type); +} + +inline Result Client::Put(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Put(path, content_length, std::move(content_provider), content_type); +} + +inline Result Client::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Put(path, std::move(content_provider), content_type); +} + +inline Result Client::Put(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Put(path, headers, content_length, std::move(content_provider), content_type); +} + +inline Result Client::Put(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Put(path, headers, std::move(content_provider), content_type); +} + +inline Result Client::Put(const std::string &path, const Params ¶ms) { + return cli_->Put(path, params); +} + +inline Result Client::Put(const std::string &path, const Headers &headers, const Params ¶ms) { + return cli_->Put(path, headers, params); +} + +inline Result Client::Put(const std::string &path, const MultipartFormDataItems &items) { + return cli_->Put(path, items); +} + +inline Result Client::Put(const std::string &path, const Headers &headers, const MultipartFormDataItems &items) { + return cli_->Put(path, headers, items); +} + +inline Result Client::Put(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Put(path, headers, items, boundary); +} + +inline Result Client::Put(const std::string &path, + const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Put(path, headers, items, provider_items); +} + +inline Result Client::Patch(const std::string &path) { + return cli_->Patch(path); +} + +inline Result Client::Patch(const std::string &path, + const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, body, content_length, content_type); +} + +inline Result Client::Patch(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_length, content_type); +} + +inline Result Client::Patch(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Patch(path, body, content_type); +} + +inline Result Client::Patch(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_type); +} + +inline Result Client::Patch(const std::string &path, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Patch(path, content_length, std::move(content_provider), content_type); +} + +inline Result Client::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Patch(path, std::move(content_provider), content_type); +} + +inline Result Client::Patch(const std::string &path, + const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), content_type); +} + +inline Result Client::Patch(const std::string &path, + const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Patch(path, headers, std::move(content_provider), content_type); +} + +inline Result Client::Delete(const std::string &path) { + return cli_->Delete(path); +} + +inline Result Client::Delete(const std::string &path, const Headers &headers) { + return cli_->Delete(path, headers); +} + +inline Result Client::Delete(const std::string &path, + const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, body, content_length, content_type); +} + +inline Result Client::Delete(const std::string &path, + const Headers &headers, + const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_length, content_type); +} + +inline Result Client::Delete(const std::string &path, const std::string &body, const std::string &content_type) { + return cli_->Delete(path, body, content_type); +} + +inline Result Client::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_type); +} + +inline Result Client::Options(const std::string &path) { + return cli_->Options(path); +} + +inline Result Client::Options(const std::string &path, const Headers &headers) { + return cli_->Options(path, headers); +} + +inline bool Client::send(Request &req, Response &res, Error &error) { + return cli_->send(req, res, error); +} + +inline Result Client::send(const Request &req) { + return cli_->send(req); +} + +inline void Client::stop() { + cli_->stop(); +} + +inline std::string Client::host() const { + return cli_->host(); +} + +inline int Client::port() const { + return cli_->port(); +} + +inline size_t Client::is_socket_open() const { + return cli_->is_socket_open(); +} + +inline socket_t Client::socket() const { + return cli_->socket(); +} + +inline void Client::set_hostname_addr_map(std::map addr_map) { + cli_->set_hostname_addr_map(std::move(addr_map)); +} + +inline void Client::set_default_headers(Headers headers) { + cli_->set_default_headers(std::move(headers)); +} + +inline void Client::set_address_family(int family) { + cli_->set_address_family(family); +} + +inline void Client::set_tcp_nodelay(bool on) { + cli_->set_tcp_nodelay(on); +} + +inline void Client::set_socket_options(SocketOptions socket_options) { + cli_->set_socket_options(std::move(socket_options)); +} + +inline void Client::set_connection_timeout(time_t sec, time_t usec) { + cli_->set_connection_timeout(sec, usec); +} + +inline void Client::set_read_timeout(time_t sec, time_t usec) { + cli_->set_read_timeout(sec, usec); +} + +inline void Client::set_write_timeout(time_t sec, time_t usec) { + cli_->set_write_timeout(sec, usec); +} + +inline void Client::set_basic_auth(const std::string &username, const std::string &password) { + cli_->set_basic_auth(username, password); +} + +inline void Client::set_bearer_token_auth(const std::string &token) { + cli_->set_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const std::string &username, const std::string &password) { + cli_->set_digest_auth(username, password); +} +#endif + +inline void Client::set_keep_alive(bool on) { + cli_->set_keep_alive(on); +} + +inline void Client::set_follow_location(bool on) { + cli_->set_follow_location(on); +} + +inline void Client::set_url_encode(bool on) { + cli_->set_url_encode(on); +} + +inline void Client::set_compress(bool on) { + cli_->set_compress(on); +} + +inline void Client::set_decompress(bool on) { + cli_->set_decompress(on); +} + +inline void Client::set_interface(const std::string &intf) { + cli_->set_interface(intf); +} + +inline void Client::set_proxy(const std::string &host, int port) { + cli_->set_proxy(host, port); +} + +inline void Client::set_post_connect_cb(detail::POSTSOCKETCONNECTCB cb) { + cli_->set_post_connect_callback(cb); +} + +inline void Client::set_proxy_basic_auth(const std::string &username, const std::string &password) { + cli_->set_proxy_basic_auth(username, password); +} + +inline void Client::set_proxy_bearer_token_auth(const std::string &token) { + cli_->set_proxy_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const std::string &username, const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} +#endif + +inline void Client::set_logger(Logger logger) { + cli_->set_logger(std::move(logger)); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path) { + cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); +} + +inline void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } else { + cli_->set_ca_cert_store(ca_cert_store); + } +} + +inline void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); +} + +inline long Client::get_openssl_verify_result() const { + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? +} + +inline SSL_CTX *Client::ssl_context() const { + if (is_ssl_) { + return static_cast(*cli_).ssl_context(); + } + return nullptr; +} +#endif + +// ---------------------------------------------------------------------------- + +} // namespace httplib + +#if defined(_WIN32) && defined(CPPHTTPLIB_USE_POLL) +#undef poll +#endif + +#endif // CPPHTTPLIB_HTTPLIB_H \ No newline at end of file diff --git a/NetTunnelSDK/include/json/AIGCJson.hpp b/NetTunnelSDK/include/json/AIGCJson.hpp new file mode 100644 index 0000000..270f6ad --- /dev/null +++ b/NetTunnelSDK/include/json/AIGCJson.hpp @@ -0,0 +1,1000 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace aigc { + +/****************************************************** + * Register class or struct members + * eg: + * struct Test + * { + * string A; + * string B; + * AIGC_JSON_HELPER(A, B) + * }; + ******************************************************/ +#define AIGC_JSON_HELPER(...) \ + std::map __aigcDefaultValues; \ + bool AIGCJsonToObject(aigc::JsonHelperPrivate &handle, \ + rapidjson::Value &jsonValue, \ + std::vector &names) { \ + std::vector standardNames = handle.GetMembersNames(#__VA_ARGS__); \ + if (names.size() <= standardNames.size()) { \ + for (int i = static_cast(names.size()); i < static_cast(standardNames.size()); i++) \ + names.push_back(standardNames[i]); \ + } \ + return handle.SetMembers(names, 0, jsonValue, __aigcDefaultValues, __VA_ARGS__); \ + } \ + bool AIGCObjectToJson(aigc::JsonHelperPrivate &handle, \ + rapidjson::Value &jsonValue, \ + rapidjson::Document::AllocatorType &allocator, \ + std::vector &names) { \ + std::vector standardNames = handle.GetMembersNames(#__VA_ARGS__); \ + if (names.size() <= standardNames.size()) { \ + for (int i = static_cast(names.size()); i < static_cast(standardNames.size()); i++) \ + names.push_back(standardNames[i]); \ + } \ + return handle.GetMembers(names, 0, jsonValue, allocator, __VA_ARGS__); \ + } + +/****************************************************** + * Rename members + * eg: + * struct Test + * { + * string A; + * string B; + * AIGC_JSON_HELPER(A, B) + * AIGC_JSON_HELPER_RENAME("a", "b") + * }; + ******************************************************/ +#define AIGC_JSON_HELPER_RENAME(...) \ + std::vector AIGCRenameMembers(aigc::JsonHelperPrivate &handle) { \ + return handle.GetMembersNames(#__VA_ARGS__); \ + } + +/****************************************************** + * Register base-class + * eg: + * struct Base + * { + * string name; + * AIGC_JSON_HELPER(name) + * }; + * struct Test : Base + * { + * string A; + * string B; + * AIGC_JSON_HELPER(A, B) + * AIGC_JSON_HELPER_BASE((Base*)this) + * }; + ******************************************************/ +#define AIGC_JSON_HELPER_BASE(...) \ + bool AIGCBaseJsonToObject(aigc::JsonHelperPrivate &handle, rapidjson::Value &jsonValue) { \ + return handle.SetBase(jsonValue, __VA_ARGS__); \ + } \ + bool AIGCBaseObjectToJson(aigc::JsonHelperPrivate &handle, \ + rapidjson::Value &jsonValue, \ + rapidjson::Document::AllocatorType &allocator) { \ + return handle.GetBase(jsonValue, allocator, __VA_ARGS__); \ + } + +/****************************************************** + * Set default value + * eg: + * struct Base + * { + * string name; + * int age; + * AIGC_JSON_HELPER(name, age) + * AIGC_JSON_HELPER_DEFAULT(age=18) + * }; + ******************************************************/ +#define AIGC_JSON_HELPER_DEFAULT(...) \ + void AIGCDefaultValues(aigc::JsonHelperPrivate &handle) { \ + __aigcDefaultValues = handle.GetMembersValueMap(#__VA_ARGS__); \ + } + +class JsonHelperPrivate { +public: + /****************************************************** + * + * enable_if + * + ******************************************************/ + template struct enable_if {}; + + template struct enable_if { + typedef TYPE type; + }; + +public: + /****************************************************** + * + * Check Interface + * If class or struct add AIGC_JSON_HELPER\AIGC_JSON_HELPER_RENAME\AIGC_JSON_HELPER_BASE, + * it will go to the correct conver function. + * + ******************************************************/ + template struct HasConverFunction { + template static char func(decltype(&TT::AIGCJsonToObject)); + + template static int func(...); + + const static bool has = (sizeof(func(NULL)) == sizeof(char)); + }; + + template struct HasRenameFunction { + template static char func(decltype(&TT::AIGCRenameMembers)); + template static int func(...); + const static bool has = (sizeof(func(NULL)) == sizeof(char)); + }; + + template struct HasBaseConverFunction { + template static char func(decltype(&TT::AIGCBaseJsonToObject)); + template static int func(...); + const static bool has = (sizeof(func(NULL)) == sizeof(char)); + }; + + template struct HasDefaultValueFunction { + template static char func(decltype(&TT::AIGCDefaultValues)); + template static int func(...); + const static bool has = (sizeof(func(NULL)) == sizeof(char)); + }; + +public: + /****************************************************** + * + * Interface of JsonToObject\ObjectToJson + * + ******************************************************/ + template::has, int>::type = 0> + bool JsonToObject(T &obj, rapidjson::Value &jsonValue) { + if (!BaseConverJsonToObject(obj, jsonValue)) { + return false; + } + + LoadDefaultValuesMap(obj); + std::vector names = LoadRenameArray(obj); + return obj.AIGCJsonToObject(*this, jsonValue, names); + } + + template::has, int>::type = 0> + bool JsonToObject(T &obj, rapidjson::Value &jsonValue) { + if (std::is_enum::value) { + int ivalue; + if (!JsonToObject(ivalue, jsonValue)) { + return false; + } + + obj = static_cast(ivalue); + return true; + } + + m_message = "unsupported this type."; + return false; + } + + template::has, int>::type = 0> + bool ObjectToJson(T &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + if (jsonValue.IsNull()) { + jsonValue.SetObject(); + } + if (!BaseConverObjectToJson(obj, jsonValue, allocator)) { + return false; + } + std::vector names = LoadRenameArray(obj); + return obj.AIGCObjectToJson(*this, jsonValue, allocator, names); + } + + template::has, int>::type = 0> + bool ObjectToJson(T &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + if (std::is_enum::value) { + int ivalue = static_cast(obj); + return ObjectToJson(ivalue, jsonValue, allocator); + } + + m_message = "unsupported this type."; + return false; + } + + /****************************************************** + * + * Interface of LoadRenameArray + * + ******************************************************/ + template::has, int>::type = 0> + std::vector LoadRenameArray(T &obj) { + return obj.AIGCRenameMembers(*this); + } + + template::has, int>::type = 0> + std::vector LoadRenameArray(T &obj) { + return std::vector(); + } + + /****************************************************** + * + * Interface of BaseConverJsonToObject\BaseConverObjectToJson + * + ******************************************************/ + template::has, int>::type = 0> + bool BaseConverJsonToObject(T &obj, rapidjson::Value &jsonValue) { + return obj.AIGCBaseJsonToObject(*this, jsonValue); + } + + template::has, int>::type = 0> + bool BaseConverJsonToObject(T &obj, rapidjson::Value &jsonValue) { + return true; + } + + template::has, int>::type = 0> + bool BaseConverObjectToJson(T &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + return obj.AIGCBaseObjectToJson(*this, jsonValue, allocator); + } + + template::has, int>::type = 0> + bool BaseConverObjectToJson(T &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + return true; + } + + /****************************************************** + * + * Interface of Default value + * + ******************************************************/ + template::has, int>::type = 0> + void LoadDefaultValuesMap(T &obj) { + obj.AIGCDefaultValues(*this); + } + + template::has, int>::type = 0> + void LoadDefaultValuesMap(T &obj) { + (void)obj; + } + +public: + /****************************************************** + * + * Tool function + * + ******************************************************/ + static std::vector StringSplit(const std::string &str, char sep = ',') { + std::vector array; + std::string::size_type pos1, pos2; + pos1 = 0; + pos2 = str.find(sep); + while (std::string::npos != pos2) { + array.push_back(str.substr(pos1, pos2 - pos1)); + pos1 = pos2 + 1; + pos2 = str.find(sep, pos1); + } + if (pos1 != str.length()) { + array.push_back(str.substr(pos1)); + } + + return array; + } + + static std::string StringTrim(std::string key) { + std::string newStr = key; + if (!newStr.empty()) { + newStr.erase(0, newStr.find_first_not_of(" ")); + newStr.erase(newStr.find_last_not_of(" ") + 1); + } + if (!newStr.empty()) { + newStr.erase(0, newStr.find_first_not_of("\"")); + newStr.erase(newStr.find_last_not_of("\"") + 1); + } + return newStr; + } + + static void StringTrim(std::vector &array) { + for (int i = 0; i < (int)array.size(); i++) { + array[i] = StringTrim(array[i]); + } + } + + /** + * Get json value type + */ + static std::string GetJsonValueTypeName(rapidjson::Value &jsonValue) { + switch (jsonValue.GetType()) { + case rapidjson::Type::kArrayType: + return "array"; + case rapidjson::Type::kFalseType: + case rapidjson::Type::kTrueType: + return "bool"; + case rapidjson::Type::kObjectType: + return "object"; + case rapidjson::Type::kStringType: + return "string"; + case rapidjson::Type::kNumberType: + return "number"; + default: + return "string"; + } + } + + static std::string GetStringFromJsonValue(rapidjson::Value &jsonValue, bool isPrettyWriter = false) { + rapidjson::StringBuffer buffer; + + if (isPrettyWriter) { + rapidjson::PrettyWriter writer(buffer); + jsonValue.Accept(writer); + } else { + rapidjson::Writer writer(buffer); + jsonValue.Accept(writer); + } + + std::string ret = std::string(buffer.GetString()); + return ret; + } + + static std::string FindStringFromMap(std::string name, std::map &stringMap) { + std::map::iterator iter = stringMap.find(name); + if (iter == stringMap.end()) { + return ""; + } + return iter->second; + } + +public: + /****************************************************** + * + * Set class/struct members value + * + ******************************************************/ + std::vector GetMembersNames(const std::string membersStr) { + std::vector array = StringSplit(membersStr); + StringTrim(array); + return array; + } + + std::map GetMembersValueMap(const std::string valueStr) { + std::vector array = StringSplit(valueStr); + std::map ret; + for (int i = 0; i < array.size(); i++) { + std::vector keyValue = StringSplit(array[i], '='); + if (keyValue.size() != 2) { + continue; + } + + std::string key = StringTrim(keyValue[0]); + std::string value = StringTrim(keyValue[1]); + if (ret.find(key) != ret.end()) { + continue; + } + ret.insert(std::pair(key, value)); + } + return ret; + } + + template + bool SetMembers(const std::vector &names, + int index, + rapidjson::Value &jsonValue, + std::map defaultValues, + TYPE &arg, + TYPES &...args) { + if (!SetMembers(names, index, jsonValue, defaultValues, arg)) { + return false; + } + + return SetMembers(names, ++index, jsonValue, defaultValues, args...); + } + + template + bool SetMembers(const std::vector &names, + int index, + rapidjson::Value &jsonValue, + std::map defaultValues, + TYPE &arg) { + if (jsonValue.IsNull()) { + return true; + } + + const char *key = names[index].c_str(); + if (!jsonValue.IsObject()) { + return false; + } + if (!jsonValue.HasMember(key)) { + std::string defaultV = FindStringFromMap(names[index], defaultValues); + if (!defaultV.empty()) { + StringToObject(arg, defaultV); + } + return true; + } + + if (!JsonToObject(arg, jsonValue[key])) { + m_message = "[" + names[index] + "] " + m_message; + return false; + } + return true; + } + + /****************************************************** + * + * Get class/struct members value + * + ******************************************************/ + template + bool GetMembers(const std::vector &names, + int index, + rapidjson::Value &jsonValue, + rapidjson::Document::AllocatorType &allocator, + TYPE &arg, + TYPES &...args) { + if (!GetMembers(names, index, jsonValue, allocator, arg)) { + return false; + } + return GetMembers(names, ++index, jsonValue, allocator, args...); + } + + template + bool GetMembers(const std::vector &names, + int index, + rapidjson::Value &jsonValue, + rapidjson::Document::AllocatorType &allocator, + TYPE &arg) { + rapidjson::Value item; + bool check = ObjectToJson(arg, item, allocator); + if (!check) { + m_message = "[" + names[index] + "] " + m_message; + return false; + } + + if (jsonValue.HasMember(names[index].c_str())) { + jsonValue.RemoveMember(names[index].c_str()); + } + + rapidjson::Value key; + key.SetString(names[index].c_str(), static_cast(names[index].length()), allocator); + jsonValue.AddMember(key, item, allocator); + return true; + } + +public: + /****************************************************** + * + * Set base class value + * + ******************************************************/ + template bool SetBase(rapidjson::Value &jsonValue, TYPE *arg, TYPES *...args) { + if (!SetBase(jsonValue, arg)) { + return false; + } + return SetBase(jsonValue, args...); + } + + template bool SetBase(rapidjson::Value &jsonValue, TYPE *arg) { + return JsonToObject(*arg, jsonValue); + } + + /****************************************************** + * + * Get base class value + * + ******************************************************/ + template + bool GetBase(rapidjson::Value &jsonValue, + rapidjson::Document::AllocatorType &allocator, + TYPE *arg, + TYPES *...args) { + if (!GetBase(jsonValue, allocator, arg)) { + return false; + } + return GetBase(jsonValue, allocator, args...); + } + + template + bool GetBase(rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator, TYPE *arg) { + return ObjectToJson(*arg, jsonValue, allocator); + } + +public: + /****************************************************** + * Conver base-type : string to base-type + * Contain: int\uint、int64_t\uint64_t、bool、float + * double、string + * + ******************************************************/ + template void StringToObject(TYPE &obj, std::string &value) { + return; + } + + void StringToObject(std::string &obj, std::string &value) { + obj = value; + } + + void StringToObject(int &obj, std::string &value) { + obj = atoi(value.c_str()); + } + + void StringToObject(unsigned int &obj, std::string &value) { + char *end; + obj = strtoul(value.c_str(), &end, 10); + } + + void StringToObject(int64_t &obj, std::string &value) { + char *end; + obj = strtoll(value.c_str(), &end, 10); + } + + void StringToObject(uint64_t &obj, std::string &value) { + char *end; + obj = strtoull(value.c_str(), &end, 10); + } + + void StringToObject(bool &obj, std::string &value) { + obj = (value == "true"); + } + + void StringToObject(float &obj, std::string &value) { + obj = strtof(value.c_str(), nullptr); + } + + void StringToObject(double &obj, std::string &value) { + obj = strtod(value.c_str(), nullptr); + } + +public: + /****************************************************** + * Conver base-type : Json string to base-type + * Contain: int\uint、int64_t\uint64_t、bool、float + * double、string、vector、list、map + * + ******************************************************/ + bool JsonToObject(int &obj, rapidjson::Value &jsonValue) { + if (!jsonValue.IsInt()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is int."; + return false; + } + obj = jsonValue.GetInt(); + return true; + } + + bool JsonToObject(unsigned int &obj, rapidjson::Value &jsonValue) { + if (!jsonValue.IsUint()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is unsigned int."; + return false; + } + obj = jsonValue.GetUint(); + return true; + } + + bool JsonToObject(short &obj, rapidjson::Value &jsonValue) { + if (!jsonValue.IsInt()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is short."; + return false; + } + obj = jsonValue.GetInt(); + return true; + } + + bool JsonToObject(unsigned short &obj, rapidjson::Value &jsonValue) { + if (!jsonValue.IsUint()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is unsigned short."; + return false; + } + obj = jsonValue.GetUint(); + return true; + } + + bool JsonToObject(int64_t &obj, rapidjson::Value &jsonValue) { + if (!jsonValue.IsInt64()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is int64_t."; + return false; + } + obj = jsonValue.GetInt64(); + return true; + } + + bool JsonToObject(uint64_t &obj, rapidjson::Value &jsonValue) { + if (!jsonValue.IsUint64()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is uint64_t."; + return false; + } + obj = jsonValue.GetUint64(); + return true; + } + + bool JsonToObject(bool &obj, rapidjson::Value &jsonValue) { + if (!jsonValue.IsBool()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is bool."; + return false; + } + obj = jsonValue.GetBool(); + return true; + } + + bool JsonToObject(float &obj, rapidjson::Value &jsonValue) { + if (!jsonValue.IsNumber()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is float."; + return false; + } + obj = jsonValue.GetFloat(); + return true; + } + + bool JsonToObject(double &obj, rapidjson::Value &jsonValue) { + if (!jsonValue.IsNumber()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is double."; + return false; + } + obj = jsonValue.GetDouble(); + return true; + } + + bool JsonToObject(std::string &obj, rapidjson::Value &jsonValue) { + obj = ""; + if (jsonValue.IsNull()) { + return true; + } + //object or number conver to string + else if (jsonValue.IsObject() || jsonValue.IsNumber()) { + obj = GetStringFromJsonValue(jsonValue); + } else if (!jsonValue.IsString()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is string."; + return false; + } else { + obj = jsonValue.GetString(); + } + + return true; + } + + template bool JsonToObject(std::vector &obj, rapidjson::Value &jsonValue) { + obj.clear(); + if (!jsonValue.IsArray()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is std::vector."; + return false; + } + + auto array = jsonValue.GetArray(); + for (int i = 0; i < array.Size(); i++) { + TYPE item; + if (!JsonToObject(item, array[i])) { + return false; + } + obj.push_back(item); + } + return true; + } + + template bool JsonToObject(std::list &obj, rapidjson::Value &jsonValue) { + obj.clear(); + if (!jsonValue.IsArray()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + " but object is std::list."; + return false; + } + + auto array = jsonValue.GetArray(); + for (int i = 0; i < static_cast(array.Size()); i++) { + TYPE item; + if (!JsonToObject(item, array[i])) { + return false; + } + obj.push_back(item); + } + return true; + } + + template bool JsonToObject(std::map &obj, rapidjson::Value &jsonValue) { + obj.clear(); + if (!jsonValue.IsObject()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + + " but object is std::map."; + return false; + } + + for (auto iter = jsonValue.MemberBegin(); iter != jsonValue.MemberEnd(); ++iter) { + auto key = (iter->name).GetString(); + auto &value = jsonValue[key]; + + TYPE item; + if (!JsonToObject(item, value)) { + return false; + } + + obj.insert(std::pair(key, item)); + } + return true; + } + + template bool JsonToObject(std::unordered_map &obj, rapidjson::Value &jsonValue) { + obj.clear(); + if (!jsonValue.IsObject()) { + m_message = "json-value is " + GetJsonValueTypeName(jsonValue) + + " but object is std::unordered_map."; + return false; + } + + for (auto iter = jsonValue.MemberBegin(); iter != jsonValue.MemberEnd(); ++iter) { + auto key = (iter->name).GetString(); + auto &value = jsonValue[key]; + + TYPE item; + if (!JsonToObject(item, value)) { + return false; + } + + obj.insert(std::pair(key, item)); + } + return true; + } + +public: + /****************************************************** + * Conver base-type : base-type to json string + * Contain: int\uint、int64_t\uint64_t、bool、float + * double、string、vector、list、map + * + ******************************************************/ + bool ObjectToJson(int &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetInt(obj); + return true; + } + + bool ObjectToJson(unsigned int &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetUint(obj); + return true; + } + + bool ObjectToJson(short &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetInt((int)obj); + return true; + } + + bool ObjectToJson(unsigned short &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetUint((unsigned int)obj); + return true; + } + + bool ObjectToJson(int64_t &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetInt64(obj); + return true; + } + + bool ObjectToJson(uint64_t &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetUint64(obj); + return true; + } + + bool ObjectToJson(bool &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetBool(obj); + return true; + } + + bool ObjectToJson(float &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetFloat(obj); + return true; + } + + bool ObjectToJson(double &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetDouble(obj); + return true; + } + + bool ObjectToJson(std::string &obj, rapidjson::Value &jsonValue, rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetString(obj.c_str(), static_cast(obj.length()), allocator); + return true; + } + + template + bool ObjectToJson(std::vector &obj, + rapidjson::Value &jsonValue, + rapidjson::Document::AllocatorType &allocator) { + rapidjson::Value array(rapidjson::Type::kArrayType); + for (int i = 0; i < obj.size(); i++) { + rapidjson::Value item; + if (!ObjectToJson(obj[i], item, allocator)) { + return false; + } + + array.PushBack(item, allocator); + } + + jsonValue = array; + return true; + } + + template + bool ObjectToJson(std::list &obj, + rapidjson::Value &jsonValue, + rapidjson::Document::AllocatorType &allocator) { + rapidjson::Value array(rapidjson::Type::kArrayType); + for (auto i = obj.begin(); i != obj.end(); i++) { + rapidjson::Value item; + if (!ObjectToJson(*i, item, allocator)) { + return false; + } + + array.PushBack(item, allocator); + } + + jsonValue = array; + return true; + } + + template + bool ObjectToJson(std::map &obj, + rapidjson::Value &jsonValue, + rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetObject(); + for (auto iter = obj.begin(); iter != obj.end(); ++iter) { + auto key = iter->first; + TYPE value = iter->second; + + rapidjson::Value jsonitem; + if (!ObjectToJson(value, jsonitem, allocator)) { + return false; + } + + rapidjson::Value jsonkey; + jsonkey.SetString(key.c_str(), key.length(), allocator); + + jsonValue.AddMember(jsonkey, jsonitem, allocator); + } + return true; + } + + template + bool ObjectToJson(std::unordered_map &obj, + rapidjson::Value &jsonValue, + rapidjson::Document::AllocatorType &allocator) { + jsonValue.SetObject(); + for (auto iter = obj.begin(); iter != obj.end(); ++iter) { + auto key = iter->first; + TYPE value = iter->second; + + rapidjson::Value jsonitem; + if (!ObjectToJson(value, jsonitem, allocator)) { + return false; + } + + rapidjson::Value jsonkey; + jsonkey.SetString(key.c_str(), key.length(), allocator); + + jsonValue.AddMember(jsonkey, jsonitem, allocator); + } + return true; + } + +public: + std::string m_message; +}; + +class JsonHelper { +public: + /** + * @brief conver json string to class | struct + * @param obj : class or struct + * @param jsonStr : json string + * @param keys : the path of the object + * @param message : printf err message when conver failed + */ + template + static bool JsonToObject(T &obj, + const std::string &jsonStr, + const std::vector keys = {}, + std::string *message = NULL) { + //Parse json string + rapidjson::Document root; + root.Parse(jsonStr.c_str()); + if (root.IsNull()) { + if (message) { + *message = "Json string can't parse."; + } + return false; + } + + //Go to the key-path + std::string path; + rapidjson::Value &value = root; + for (int i = 0; i < (int)keys.size(); i++) { + const char *find = keys[i].c_str(); + if (!path.empty()) { + path += "->"; + } + path += keys[i]; + + if (!value.IsObject() || !value.HasMember(find)) { + if (message) { + *message = "Can't parse the path [" + path + "]."; + } + return false; + } + value = value[find]; + } + + //Conver + JsonHelperPrivate handle; + if (!handle.JsonToObject(obj, value)) { + if (message) { + *message = handle.m_message; + } + return false; + } + return true; + } + + /** + * @brief conver json string to class | struct + * @param jsonStr : json string + * @param defaultT : default value + * @param keys : the path of the object + * @param message : printf err message when conver failed + */ + template + static T Get(const std::string &jsonStr, + T defaultT, + const std::vector keys = {}, + std::string *message = NULL) { + T obj; + if (JsonToObject(obj, jsonStr, keys, message)) { + return obj; + } + + return defaultT; + } + + /** + * @brief conver class | struct to json string + * @param errMessage : printf err message when conver failed + * @param isPretty Output pretty format json string + * @param obj : class or struct + * @param jsonStr : json string + */ + template static bool ObjectToJson(T &obj, std::string &jsonStr, bool isPretty, std::string *message) { + rapidjson::Document root; + root.SetObject(); + rapidjson::Document::AllocatorType &allocator = root.GetAllocator(); + + //Conver + JsonHelperPrivate handle; + if (!handle.ObjectToJson(obj, root, allocator)) { + if (message) { + *message = handle.m_message; + } + return false; + } + + jsonStr = handle.GetStringFromJsonValue(root, isPretty); + return true; + } + + /** + * @brief conver class | struct to json string + * @param errMessage : printf err message when conver failed + * @param obj : class or struct + * @param isPretty Output pretty format json string + * @param jsonStr : json string + */ + template static bool ObjectToJson(T &obj, std::string &jsonStr, bool isPretty = false) { + rapidjson::Document root; + root.SetObject(); + rapidjson::Document::AllocatorType &allocator = root.GetAllocator(); + + //Conver + JsonHelperPrivate handle; + if (!handle.ObjectToJson(obj, root, allocator)) { + return false; + } + + jsonStr = handle.GetStringFromJsonValue(root, isPretty); + return true; + } +}; + +} // namespace aigc \ No newline at end of file diff --git a/NetTunnelSDK/include/misc.h b/NetTunnelSDK/include/misc.h new file mode 100644 index 0000000..d341c5d --- /dev/null +++ b/NetTunnelSDK/include/misc.h @@ -0,0 +1,119 @@ +#pragma once + +#include "common.h" + +#include + +#define CFG_WIREGUARD_SECTION TEXT("WireGuard") +#define CFG_WIREGUARD_PATH TEXT("WireGuardExe") +#define CFG_WGCFG_PATH TEXT("WgCfgPath") +#define CFG_WG_PATH TEXT("WgExe") + +typedef struct { + TCHAR path[MAX_PATH]; +} FILE_PATH, *PFILE_PATH; + +typedef struct { + PFILE_PATH pFilePath; + unsigned int nItems; +} FILE_LIST, *PFILE_LIST; + +/** + * @brief IPv4 网络信息 + */ +typedef struct { + unsigned int prefix; ///< 网络前缀 + TCHAR hostip[MAX_IP_LEN]; ///< IP 地址 + TCHAR ip[MAX_IP_LEN]; ///< IP 地址 + TCHAR network[MAX_IP_LEN]; ///< 网络地址 + TCHAR broadcast[MAX_IP_LEN]; ///< 网络广播地址 + TCHAR netmask[MAX_IP_LEN]; ///< 网络子网掩码 + TCHAR hosts[64]; ///< number of hosts in text + TCHAR hostmin[MAX_IP_LEN]; ///< 最小网络主机 IP + TCHAR hostmax[MAX_IP_LEN]; ///< 最大网络主机 IP +} IP_INFO, *PIP_INFO; + +#ifdef __cplusplus // If used by C++ code, +extern "C" { +// we need to export the C interface +#endif + +void RemoveTailLineBreak(TCHAR *pInputStr, int strSize); +int RunCommand(TCHAR *pszCmd, TCHAR *pszResultBuffer, int dwResultBufferSize, unsigned long *pRetCode); + +/** + * @brief IPv4 子网掩码转 CIDR 掩码 + * @param[in] pNetMask IPv4 子网掩码字符串 + * @return IPv4 CIDR 掩码 + */ +int __cdecl NetmaskToCIDR(const TCHAR *pNetMask); + +/** + * @brief CIDR 掩码转 IPv4 子网掩码 + * @param[in] cidr CIDR 掩码 + * @return CIDR 对应的子网掩码 + */ +const TCHAR *CIDRToNetmask(const UINT8 cidr); + +void ShowWindowsErrorMessage(const TCHAR *pMsgHead); +void StringReplaceAll(TCHAR *pOrigin, const TCHAR *pOldStr, const TCHAR *pNewStr); +void StringRemoveAll(TCHAR *pOrigin, const TCHAR *pString); +TCHAR *binToHexString(TCHAR *p, const unsigned char *cp, unsigned int count); +int GetWindowsServiceStatus(const TCHAR *pSvrName, PDWORD pStatus); +/** + * @brief Unicode 宽字符串转 TCHAR 字符串 + * @param[in] pWStr 需要转换的宽字符串 + * @param[out] pOutStr 转换后的 TCHAR 字符串 + * @param[in] maxOutLen pOutStr 最大字节数 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - ERR_SUCCESS 成功 + */ +int WideCharToTChar(const WCHAR *pWStr, TCHAR *pOutStr, int maxOutLen); + +/** + * @brief TCHAR 字符串 字符串转 Unicode 字符串 + * @param[in] pTStr 需要转换的 TCHAR 字符串 + * @param[out] pOutStr 转换后的 WCHAR 字符串 + * @param[in] maxOutLen pOutStr 最大字节数 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - ERR_SUCCESS 成功 + */ +int TCharToWideChar(const TCHAR *pTStr, WCHAR *pOutStr, int maxOutLen); + +int FindFile(const TCHAR *pPath, PFILE_LIST pFileList, const bool exitWhenMatchOne); +/** + * @brief 计算 IPv4 网络信息 + * @param[in] pIpStr IPv4 地址 + * @param[in] pNetmask IPv4子网掩码 + * @param[out] pInfo 计算结果 + * @return 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_UN_SUPPORT 不支持的格式转换 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - ERR_SUCCESS 成功 + */ +int GetIpV4InfoFromNetmask(const TCHAR *pIpStr, const TCHAR *pNetmask, PIP_INFO pInfo); +/** + * @brief 计算 IPv4 网络信息 + * @param[in] pIpStr IPv4 网络信息 '/' 分割,支持CIDR以及子网掩码 example: 192.168.1.32/24, 192.168.1.32/255.255.255.0 + * @param[out] pInfo 计算结果 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_UN_SUPPORT 不支持的格式转换 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - ERR_SUCCESS 成功 + */ +int GetIpV4InfoFromCIDR(const TCHAR *pIpStr, PIP_INFO pInfo); + +int GetIpV4InfoFromHostname(int family, const char *host, PIP_INFO pInfo); + +int InitializeWireGuardLibrary(); +void UnInitializeWireGuardLibrary(); + +void StopUDPProxyServer(); +int CreateUDPProxyServer(); +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/NetTunnelSDK/include/network.h b/NetTunnelSDK/include/network.h new file mode 100644 index 0000000..1d8e3a6 --- /dev/null +++ b/NetTunnelSDK/include/network.h @@ -0,0 +1,332 @@ +#pragma once + +#include "sccsdk.h" + +#ifdef __cplusplus // If used by C++ code, +extern "C" { +// we need to export the C interface +#endif + +/** + * @brief 根据网卡 IP地址 获取网卡索引 + * @param[in] pIpAddr 网卡IP地址 + * @param[out] pIfIndex 网卡索引编号 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_ITEM_UNEXISTS 找不到合适的网卡 + * - -ERR_UN_SUPPORT 系统不支持该操作 + * - ERR_SUCCESS 成功 + */ +int GetInterfaceIfIndexByIpAddr(const TCHAR *pIpAddr, ULONG *pIfIndex); + +/** + * @brief 根据网卡 GUDI 获取网卡名称 + * @param[in] pGUID 网卡 GUID + * @param[out] ifName 网卡名称 + * @param[out] pConnStatus 网卡连接状态 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_CREATE_COMMOBJECT 创建 COM 对象失败 + * - -ERR_SYS_CALL 调用 COM 接口失败 + * - -ERR_ITEM_UNEXISTS GUID 不存在 + * - ERR_SUCCESS 成功 + */ +int GetInterfaceNameByGUID(const TCHAR *pGUID, TCHAR ifName[MAX_NETCARD_NAME], int* pConnStatus); + +/** + * @brief 根据网卡名获取网卡索引 + * @param[in] pInterfaceName 网卡名称 + * @param[out] pIfIndex 网卡 Index + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 网卡不存在 + * - -ERR_SYS_CALL 获取操作系统网卡适配器失败 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - ERR_SUCCESS 成功 + */ +int GetInterfaceIfIndexByName(const TCHAR *pInterfaceName, int *pIfIndex); + +/** + * @brief 根据网卡名获取网卡 GUID + * @param[in] ifIndex 网卡索引 + * @param[out] pGuid 网卡 GUID + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 网卡不存在 + * - -ERR_MEMORY_STR 字符串转 GUID 结构体失败 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - ERR_SUCCESS 成功 + */ +int GetInterfaceGUIDByIfIndex(const int ifIndex, GUID *pGuid); + +/** + * @brief 根据网卡名获取网卡 GUID + * @param[in] pInterfaceName 网卡名称 + * @param[out] pGuid 网卡 GUID + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 网卡不存在 + * - -ERR_MEMORY_STR 字符串转 GUID 结构体失败 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - ERR_SUCCESS 成功 + */ +int GetInterfaceGUIDByName(const TCHAR *pInterfaceName, GUID *pGuid); + +int WaitNetAdapterConnected(const TCHAR *pInterfaceName, int timeOutOfMs); + +/** + * @brief 获取网卡 NetworkCategory 是否设置为 Private 模式 + * @param[in] pInterfaceName 网卡名称 + * @param[out] pIsPrivate 网卡属性 + * - TRUE Private 模式 + * - FALSE Public 模式 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_CREATE_COMMOBJECT 创建 COM 对象失败 + * - -ERR_SYS_CALL 调用 COM 接口失败 + * - -ERR_ITEM_UNEXISTS 当前网络接口不存在不存在 + * - -ERR_MEMORY_STR 字符集转换失败 + * - ERR_SUCCESS 成功 + */ +int GetNetConnectionNetworkCategory(const TCHAR *pInterfaceName, bool *pIsPrivate); + +/** + * @brief 启动/停止 Windows 网络共享服务 + * @param[in] ifIndex 网卡索引 + * @param[in] isEnable 启动/停止 Windows 网络共享服务 + * - TRUE 启动服务 + * - FALSE 停止服务 + * @param[in] isSetPrivate 共享连接属性 + * -TRUE 私有连接 + * -FALSE 公共连接 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_CREATE_COMMOBJECT 创建 COM 对象失败 + * - -ERR_SYS_CALL 调用 COM 接口失败 + * - -ERR_ITEM_UNEXISTS GUID 不存在 + * - -ERR_NET_UNCONNECT 网络未连接 + * - ERR_SUCCESS 成功 + */ +int SetNetIntelnetConnectionSharing(int ifIndex, bool isEnable, bool isSetPrivate); + +/** + * @brief 获取当前网络共享服务状态 + * @param[in] ifIndex 网卡名称索引 + * @param[out] pIsEnable 当前网络共享服务状态 + * - TRUE 启动 + * - FALSE 停止 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_CREATE_COMMOBJECT 创建 COM 对象失败 + * - -ERR_SYS_CALL 调用 COM 接口失败 + * - -ERR_ITEM_UNEXISTS GUID 不存在 + * - -ERR_NET_UNCONNECT 网络未连接 + * - -ERR_CALL_COMMOBJECT 获取网络共享状态失败 + * - ERR_SUCCESS 成功 + */ +int GetNetIntelnetConnectionSharing(int ifIndex, bool *pIsEnable); +/** + * @brief 设置网卡为 Private/Public 模式 + * @param[in] pInterfaceName pInterfaceName 网卡名称 + * @param[in] isPrivate 网卡 Category 模式 + * - TRUE Private 模式 + * - FALSE Public 模式 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - -ERR_PROCESS_RETURN 系统调用执行结束返回失败 + * - ERR_SUCCESS 成功 + */ +int SetNetConnectionNetworkCategory(const TCHAR *pInterfaceName, const bool isPrivate); + +/** + * @brief 添加系统路由表项 + * @param[in] pIP 目的 IP 地址 + * @param[in] pMask 目的子网掩码 + * @param[in] pGateway 路由网关 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_UN_SUPPORT IP地址转网络字节序网络地址失败 + * - -ERR_NET_ADD_ROUTE 添加路由表项失败 + * - -ERR_NET_REMOVE_ROUTE 删除路由表项失败 + * - ERR_SUCCESS 成功 + */ +int AddRouteTable(const char *pIP, const char *pMask, const char *pGateway); + +/** + * @brief 开启 Windows WireGuard NAT 转发功能 + * @param[in] pInterfaceName 网卡名称 + * @param[in] pCidrIpaddr CIDR 网络接口地址 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - -ERR_PROCESS_RETURN 系统调用执行结束返回失败 + * - ERR_SUCCESS 成功 + */ +int SetNATRule(const TCHAR *pInterfaceName, const TCHAR *pCidrIpaddr); + +/** + * @brief 关闭 Windows WireGuard NAT 转发功能 + * @param pInterfaceName 网卡名称 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - -ERR_PROCESS_RETURN 系统调用执行结束返回失败 + * - ERR_SUCCESS 成功 + */ +int RemoveNATRule(const TCHAR *pInterfaceName); + +#if 0 +/** + * @brief 设置网络接口IP地址 + * @param[in] pInterfaceName 网卡名称 + * @param[in] pIpaddr IP 地址 + * @param[in] pNetmask 子网掩码 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - ERR_SUCCESS 成功 + */ + int SetInterfaceIpAddress(const TCHAR *pInterfaceName, const TCHAR *pIpaddr, const TCHAR *pNetmask); + +/** + * @brief 获取Windows Hyper-V 虚拟机状态, 必须开启后才能开启NAT转发功能 + * @param[out] pEnabled 当前 Hyper-V 虚拟机状态, TRUE 表示开启, FALSE 表示关闭 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - ERR_SUCCESS 成功 + */ + int GetWindowsHyperVStatus(int *pEnabled); + +/** + * @brief 启用/禁用 Windows Hyper-V 功能 + * @param[in] enabled TRUE 启用, FALSE 关闭 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - -ERR_PROCESS_RETURN 系统调用执行结束返回失败 + * - ERR_SUCCESS 成功 + */ + int EnableWindowsHyperV(bool enabled); + +/** + * @brief 设置网卡为 Private/Public 模式 + * @param[in] pInterfaceName 网卡名称 + * @param[in] isPrivate + * - TRUE Private 模式 + * - FALSE Public 模式 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - -ERR_PROCESS_RETURN 系统调用执行结束返回失败 + * - ERR_SUCCESS 成功 + */ + int SetInterfacePrivate(const TCHAR *pInterfaceName, bool isPrivate); + +/** + * @brief 获取网卡 NetworkCategory 是否设置为 Private 模式 + * @param[in] pInterfaceName 网卡名称 + * @param[out] pIsPrivateMode 网卡属性 + * - TRUE Private 模式 + * - FALSE Public 模式 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_ITEM_UNEXISTS 设备不存在 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - -ERR_PROCESS_RETURN 系统调用执行结束返回失败 + * - ERR_SUCCESS 成功 + */ + int IsInterfacePrivate(const TCHAR *pInterfaceName, bool *pIsPrivateMode); + +/** + * @brief 获取网卡接口索引编号 + * @param[in] pInterfaceName 网卡名称 + * @param[out] pIndex 网卡索引 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - ERR_SUCCESS 成功 + */ + int GetInterfaceIndexByName(const TCHAR *pInterfaceName, int *pIndex); + +/** + * @brief 删除接口网络地址 + * @param[in] pInterfaceName 网卡名称 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - ERR_SUCCESS 成功 + */ + int RemoveInterfaceIpAddress(const TCHAR *pInterfaceName); + +/** + * @brief 设置网络接口IP地址 + * @param[in] pInterfaceName 网卡名称 + * @param[in] pCidrIpaddr CIDR类型IP地址 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - ERR_SUCCESS 成功 + */ + int SetInterfaceIpAddressFromCIDR(const TCHAR *pInterfaceName, const TCHAR *pCidrIpaddr); + +/** + * @brief 设置网络接口IP地址 + * @param[in] pInterfaceName 网卡名称 + * @param[in] pIpaddr IP 地址 + * @param[in] pNetmask 子网掩码 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - ERR_SUCCESS 成功 + */ + int SetInterfaceIpAddress(const TCHAR *pInterfaceName, const TCHAR *pIpaddr, const TCHAR *pNetmask); + + +/** + * @brief 获取网络连接NAT功能是否开启 + * @param[in] pInterfaceName 网卡名称 + * @param[out] pIsEnabled 网卡NAT当前是否开启 + * - TRUE NAT 开启 + * - FALSE NAT 关闭 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_ITEM_UNEXISTS 设备不存在 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - -ERR_PROCESS_RETURN 系统调用执行结束返回失败 + * - ERR_SUCCESS 成功 + */ + int IsNetConnectionSharingEnabled(const TCHAR *pInterfaceName, bool *pIsEnabled); + +/** + * @brief WireGuard 服务 Windows PowerShell 自定义命令是否安装 + * @return + * - TRUE 已经安装 + * - FALSE 未安装 + */ + bool IsCustomNatPSCmdInstalled(); +#endif +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/NetTunnelSDK/include/pch.h b/NetTunnelSDK/include/pch.h new file mode 100644 index 0000000..dbb2e77 --- /dev/null +++ b/NetTunnelSDK/include/pch.h @@ -0,0 +1,10 @@ +// +// Created by HuangXin on 2023/8/22. +// +#pragma once + +#define WIN32_LEAN_AND_MEAN // 从 Windows 头文件中排除极少使用的内容 +#define CPPHTTPLIB_OPENSSL_SUPPORT + +// Windows 头文件 +#include diff --git a/NetTunnelSDK/include/protocol.h b/NetTunnelSDK/include/protocol.h new file mode 100644 index 0000000..d96eb53 --- /dev/null +++ b/NetTunnelSDK/include/protocol.h @@ -0,0 +1,290 @@ +#pragma once +#include "ProtocolBase.h" +#include "common.h" + +#if !USER_REAL_PLATFORM +class PlatformReqServerCfgParms { +public: + std::string vmIp; + AIGC_JSON_HELPER(vmIp) +}; + +class PlatformReqClientCfgParms { +public: + std::string userName; + std::string token; + AIGC_JSON_HELPER(userName) +}; + +class PlatformRspUserSvrCfgParams { +public: + PlatformRspUserSvrCfgParams() { + this->svrHost = TEXT(""); + this->svrPort = 0; + this->svrPriKey = TEXT(""); + } + + int svrPort; + std::string svrPriKey; + std::string svrHost; + + AIGC_JSON_HELPER(svrPort, svrPriKey, svrHost) +}; + +class VitrualMathineInfo { +public: + VitrualMathineInfo() { + this->vmId = 0; + this->scgPort = 0; + this->vmName = TEXT(""); + this->scgIp = TEXT(""); + this->vmNetwork = TEXT(""); + this->svrPubKey = TEXT(""); +#if USED_PORTMAP_TUNNEL + this->portMapIp = TEXT(""); + this->portMapPort = 0; +#endif + } + + int vmId; + std::string vmName; + std::string svrPubKey; + std::string vmNetwork; + std::string scgIp; + int scgPort; +#if USED_PORTMAP_TUNNEL + std::string portMapIp; + int portMapPort; + AIGC_JSON_HELPER(vmId, vmName, svrPubKey, vmNetwork, scgIp, scgPort, portMapIp, portMapPort) +#else + AIGC_JSON_HELPER(vmId, vmName, svrPubKey, vmNetwork, scgIp, scgPort) +#endif +}; + +class PlatformRspUserClientCfgParams { +public: + PlatformRspUserClientCfgParams() { + this->scgTunnelAppId = WG_TUNNEL_SCG_ID; + this->scgCtrlAppId = WG_CTRL_SCG_ID; + this->cliHost = TEXT(""); + } + + int scgCtrlAppId; + int scgTunnelAppId; + std::string cliPriKey; + std::string cliPubKey; + std::string cliHost; + std::list vmInfoList; + + AIGC_JSON_HELPER(scgCtrlAppId, scgTunnelAppId, cliPriKey, cliPubKey, vmInfoList, cliHost) +}; + +class PlatformRspServerCfgParams { +public: + std::string code; + std::string message; + PlatformRspUserSvrCfgParams data; + + AIGC_JSON_HELPER(code, data) +}; + +class PlatformRspClientCfgParams { +public: + std::string code; + std::string message; + PlatformRspUserClientCfgParams data; + + AIGC_JSON_HELPER(code, data) +}; + +#endif + +class ReqClientCfgParams { +public: + std::string identifier; + + AIGC_JSON_HELPER(identifier) +}; + +class ReqHeartParams { +public: + std::string message; + + AIGC_JSON_HELPER(message) + AIGC_JSON_HELPER_DEFAULT(message = TEXT("PING")) +}; + +class RspHeartParams : public ResponseStatus { +public: + std::string message; + + AIGC_JSON_HELPER(message) + AIGC_JSON_HELPER_BASE((ResponseStatus *)this) + AIGC_JSON_HELPER_DEFAULT(message = TEXT("PONG")) +}; + +class ReqGetUserCfgParams { +public: + std::string user; + std::string token; + + AIGC_JSON_HELPER(user, token) +}; + +class RspUserSevrCfgParams { +public: + RspUserSevrCfgParams() { + this->svrAddress = TEXT(""); + this->svrListenPort = 0; + this->svrPrivateKey = TEXT(""); + } + + int svrListenPort; + std::string svrPrivateKey; + std::string svrAddress; + + AIGC_JSON_HELPER(svrListenPort, svrPrivateKey, svrAddress) +}; + +class ReqStartTunnelParams { +public: + bool isStart; + AIGC_JSON_HELPER(isStart) +}; + +class ReqUserSetCliCfgParams { +public: + std::string cliPublicKey; + std::string cliNetwork; + std::string cliTunnelAddr; + + AIGC_JSON_HELPER(cliPublicKey, cliNetwork, cliTunnelAddr) +}; + +class RspUserSetCliCfgParams : public ResponseStatus { +public: + std::string svrNetwork; + + AIGC_JSON_HELPER(svrNetwork) + AIGC_JSON_HELPER_BASE((ResponseStatus *)this) +}; + +class VitrualMathineCfg { +public: + VitrualMathineCfg() { + this->vmId = 0; + this->vmName = TEXT(""); + this->scgGateway = TEXT(""); + this->vmNetwork = TEXT(""); + this->svrPublicKey = TEXT(""); +#if USED_PORTMAP_TUNNEL + this->portMapIp = TEXT(""); + this->portMapPort = 0; +#endif + } + + int vmId; + std::string vmName; + std::string svrPublicKey; + std::string vmNetwork; + std::string scgGateway; +#if USED_PORTMAP_TUNNEL + std::string portMapIp; + int portMapPort; + AIGC_JSON_HELPER(vmId, vmName, svrPublicKey, vmNetwork, scgGateway, portMapIp, portMapPort) +#else + AIGC_JSON_HELPER(vmId, vmName, svrPublicKey, vmNetwork, scgGateway) +#endif +}; + +class RspUsrCliConfigParams { +public: + int scgCtrlAppId; + int scgTunnelAppId; + std::string cliPrivateKey; + std::string cliPublicKey; + std::string cliAddress; + std::list vmConfig; + + AIGC_JSON_HELPER(scgCtrlAppId, scgTunnelAppId, cliPrivateKey, cliPublicKey, cliAddress, vmConfig) +}; + +#if USER_REAL_PLATFORM +#define GET_CLIENTCFG_PATH TEXT("/tunnel/getuserconfig") +#define GET_SERVERCFG_PATH TEXT("/tunnel/getserverconfig") +#else +#define GET_CLIENTCFG_PATH TEXT("/sc/open-portal/openapi/scc/cliTunnelCfg") +#define GET_SERVERCFG_PATH TEXT("/sc/open-portal/openapi/scc/svrTunnelCfg") +#endif + +#define SET_CLIENTCFG_PATH TEXT("/tunnel/setconfig") +#define SET_CLIENTSTART_TUNNEL TEXT("/tunnel/start") +#define SET_CLIENTHEART_PATH TEXT("/tunnel/heart") + +int InitControlServer(const TCHAR *pUserSvrUrl); + +/** + * @brief 调用 RESTful POST 接口并获取服务端返回数据 + * @param[in] pUrlPath 服务端 URL 路径 + * @param[in] pReq 请求消息 + * @param[in] pRsp 返回消息 + * @param[in] platformServer 访问平台还是访问控制服务 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_SYSTEM_UNINITIALIZE 服务端URL未初始化 + * - -ERR_JSON_CREATE 创建 JSON 字符串失败 + * - -ERR_HTTP_POST_DATA 调用 POST 方法失败 + * - -ERR_HTTP_SERVER_RSP 服务端返回失败(非200) + * - -ERR_READ_FILE 服务端返回空数据 + * - -ERR_JSON_DECODE 解析 JSON 数据失败 + * - ERR_SUCCESS 成功 + */ +template +int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +extern template int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +extern template int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +extern template int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +extern template int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +extern template int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +#if !USER_REAL_PLATFORM +template int PlatformProtolPostMessage(const TCHAR *pUrlPath, T1 *pReq, T2 *pRsp); + +extern template int PlatformProtolPostMessage(const TCHAR *pUrlPath, + PlatformReqServerCfgParms *pReq, + PlatformRspServerCfgParams *pRsp); + +extern template int PlatformProtolPostMessage(const TCHAR *pUrlPath, + PlatformReqClientCfgParms *pReq, + PlatformRspClientCfgParams *pRsp); + +#if 0 +template int PlatformProtolGetMessage(const TCHAR *pUrlPath, T1 *pRsp); + +extern template int PlatformProtolGetMessage(const TCHAR *pUrlPath, + PlatformRspUserClientCfgParams *pRsp); +#endif +#endif \ No newline at end of file diff --git a/NetTunnelSDK/include/sccsdk.h b/NetTunnelSDK/include/sccsdk.h new file mode 100644 index 0000000..d4fcef3 --- /dev/null +++ b/NetTunnelSDK/include/sccsdk.h @@ -0,0 +1,303 @@ +#pragma once +#include +#include "common.h" +#include "usrerr.h" + +#ifdef NETTUNNELSDK_EXPORTS +#define SCCSDK_API __declspec(dllexport) +#else +#define SCCSDK_API __declspec(dllimport) +#endif + +typedef void (*PTUNNEL_HEART_ROUTINE)(const TCHAR *pMessage, unsigned int timeStampOfSeconds); +typedef PTUNNEL_HEART_ROUTINE LPTUNNEL_HEART_ROUTINE; + +/** + * + * @brief 本地计算机网卡信息 + */ +typedef struct { + int InterfaceIndex; ///< 网卡索引 + NET_CONNECT_STATUS netConnStatus; ///< 网卡状态 @see NET_CONNECT_STATUS + TCHAR NetCardUUID[260]; ///< 网卡名称, Windows标识为 UUID + TCHAR NetCardName[MAX_NETCARD_NAME]; ///< 网卡名称 + TCHAR NetCardDescription[132]; ///< 网卡描述 + TCHAR NetCardIpaddr[MAX_IP_LEN]; ///< 网卡 IP 地址 + TCHAR NetCardNetmask[MAX_IP_LEN]; ///< 网卡子网掩码 + TCHAR NetCardGateway[MAX_IP_LEN]; ///< 网卡网关 + TCHAR NetCardMacAddr[20]; ///< 网卡 MAC 地址 +} NIC_CONTENT, *PNIC_CONTENT; + +#ifdef __cplusplus // If used by C++ code, +extern "C" { +// we need to export the C interface +#endif + +/** + * @brief 初始化 SDK 运行环境 + * @param[in] pWorkDir 程序工作路径,如果不设置系统自动获取 + * @param[in] pSvrUrl 管理平台 URL 地址 example: http://localhost:2313, https://localhost:2313 + * @param[in] pLogFile 日志文件名称/完整路径 + * @param[in] level 日志最低有效等级 + * @param[in] isWorkServer SDK 工作模式 + * - TRUE 服务端 + * - FALSE 客户端 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_ITEM_EXISTS 未找到 WireGuard 程序 + * - -ERR_SYS_CALL 获取配置本地配置文件存储目录失败 + * - -ERR_CREATE_FILE 创建用户配置文件目录失败 + * - -ERR_ITEM_UNEXISTS WireGuard 未安装 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl TunnelSDKInitEnv(const TCHAR *pWorkDir, + const TCHAR *pSvrUrl, + const TCHAR *pLogFile, + LOG_LEVEL level, + bool isWorkServer); + +/** + * @brief 清理 SDK 运行资源 + */ +SCCSDK_API void __cdecl TunnelSDKUnInit(); + +/** + * @brief 打开/关闭 SDK 日志开关 + * @param enLog 日志开关 + * - TRUE 打开日志 + * - FALSE 关闭日志 + */ +SCCSDK_API void __cdecl TunnelLogEnable(bool enLog); + +/** + * @brief 获取当前 WireGuard 服务隧道是否正则运行 + * @param pTunnelName 隧道服务名 + * @param pIsRunning pIsRunning WireGuard 服务隧道运行状态 + * - TRUE 已经安装 + * - FALSE 未安装 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_OPEN_SCM 打开服务管理器设备失败 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl GetWireGuardServiceStatus(const TCHAR *pTunnelName, bool *pIsRunning); + +/** + * @brief 判断当前网络服务工作模式 客户端/服务端 + * @param[out] pIsWorkServer 工作模式 + * - TRUE 服务端 + * - FALSE 客户端 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl GetWireGuardWorkMode(bool *pIsWorkServer); + +/** + * @brief 获取当前 WireGuard 服务隧道是否正则运行 + * @param[in] pIfName WireGuard 隧道网络接口名称 + * @param[out] pIsRunning WireGuard 服务隧道运行状态 + * - TRUE 已经安装 + * - FALSE 未安装 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl IsWireGuardServerRunning(const TCHAR *pIfName, bool *pIsRunning); + +/** + * @brief SCG 代理服务开关 + * @param isEnable TRUE: 启动 SCG 代理, FALSE: 禁用 SCG 代理 + * @param pSCGIpAddr SCG 代理 IP + * @param scgPort SCG 代理端口 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl EnableSCGProxy(bool isEnable, const TCHAR *pSCGIpAddr, int scgPort); + +/** + * @brief 获取当前 SCG 代理服务状态 + * @return TRUE: SCG 代理启动, FALSE: SCG 代理禁用 + */ +SCCSDK_API bool __cdecl UsedSCGProxy(); + +/** + * @brief 获取当前网络共享模式 + * @return 当前网络共享模式 @see NET_SHARE_MODE + */ +SCCSDK_API NET_SHARE_MODE __cdecl GetCurrentNetShareMode(); + +/** + * @brief 设置获取当前网络共享模式 + * @param shareMode 网络共享模式 @see NET_SHARE_MODE + */ +SCCSDK_API void __cdecl SetCurrentNetShareMode(NET_SHARE_MODE shareMode); + +/** + * @brief 获取本机网卡信息 + * @param[out] pInfo 网卡信息 @see NIC_CONTENT + * @param[out] pItemCounts 计算机当前操作系统中网卡总数 最大值 32 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl GetAllNICInfo(PNIC_CONTENT *pInfo, int *pItemCounts); + +/** + * @brief 获取当前 Internet 网卡名 + * @param[out] pIfIndex 网卡索引 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 未找到具有 Internet 连接的网卡 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl GetInternetIfIndex(int *pIfIndex); + +/** + * @brief 判断当前网络适配器是否拥有 Internet 连接 + * @param[in] ifIndex 网卡适配器索引 + * @param[in] pRet 连接属性 + * - TRUE 网卡适配器具有 Internet 连接 + * - FALSE 网卡适配器无 Internet 连接 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_GET_IPFOWARDTBL 获取防火墙信息失败 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl IsInternetConnectAdapter(int ifIndex, bool *pRet); + +/** + * @brief 打开平台接口签名验证功能 + * @param pClientId 客户端 ID + * @param pClientSecret 客户端秘钥 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl EnableVerifySignature(const TCHAR *pClientId, const TCHAR *pClientSecret); + +/** + * @brief 关闭平台接口签名验证功能 + */ +SCCSDK_API void __cdecl DisableVerifySignature(); + +/** + * @brief 云电脑服务端创建控制服务 + * @param pSvr 服务端配置信息 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_CREATE_THREAD 创建线程失败 + * - -ERR_SOCKET_BIND_PORT 绑定端口失败 + * - -ERR_ITEM_EXISTS 服务线程状态异常 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl CreateControlService(PUSER_SERVER_CONFIG pSvr); + +/** + * @brief 获取用户服务端配置信息 + * @param[in] pUserName 用户名 + * @param[in] pToken 用户访问令牌 + * @param[out] pSvrCfg 服务端用户配置信息 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CREATE_FILE 创建用户配置目录失败 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl GetUserServerConfigure(const TCHAR *pUserName, + const TCHAR *pToken, + PUSER_SERVER_CONFIG *pSvrCfg); + +/** + * @brief 获取用户客户端配置信息 + * @param[in] pUserName 用户名 + * @param[in] pToken 用户访问令牌 + * @param[out] pCliCfg 客户端用户配置信息 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_CREATE_FILE 创建用户配置目录失败 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl GetUserClientConfigure(const TCHAR *pUserName, + const TCHAR *pToken, + PUSER_CLIENT_CONFIG *pCliCfg); + +/** + * @brief 停止云电脑服务端服务 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERROR_TIMEOUT 等待超时 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl StopControlService(); + +/** + * @brief 连接远程控制服务配置隧道参数 + * @param[in] vmId 需要连接的虚拟机ID编号 + * @param[in] pCliNetwork 需要共享的本地网络地址 + * @return 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_SYSTEM_UNINITIALIZE 服务端参数未初始化 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_OPEN_FILE 打开文件失败 + * - -ERR_MEMORY_STR 字符串处理失败 + * - -ERR_UN_SUPPORT 不支持的格式转换 + * - -ERR_JSON_CREATE 创建 JSON 字符串失败 + * - -ERR_HTTP_POST_DATA 调用 POST 方法失败 + * - -ERR_HTTP_SERVER_RSP 服务端返回失败(非200) + * - -ERR_READ_FILE 服务端返回空数据 + * - -ERR_JSON_DECODE 解析 JSON 数据失败 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl RemoteCtrlSvrCfgUserTunnel(int vmId, const TCHAR *pCliNetwork); + +/** + * @brief 启动/停止远程云电脑中的 WireGuard 隧道服务 + * @param[in] isStart 启动/停止服务 TRUE 启动服务, FALSE 停止服务 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_SYSTEM_UNINITIALIZE 未初始化远程服务 URL 地址 + * - -ERR_CREATE_FILE 创建用户配置目录失败 + * - -ERR_HTTP_POST_DATA POST 数据到服务端失败 + * - -ERR_HTTP_SERVER_RSP HTTP 服务器返回错误 + * - -ERR_READ_FILE 服务端未返回正确的消息 + * - -ERR_JSON_DECODE JSON 字符串解码失败 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl RemoteWireGuardControl(bool isStart); + +/** + * @brief 启动/停止 本地 WireGuard 隧道服务 + * @param[in] isStart 启动/停止服务 TRUE 启动服务, FALSE 停止服务 + * @param[in] setPrivateMode 是否设置网卡工作模式为 专用网络模式(Private) + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_NET_CATEGORY_MODE 网卡工作模式错误 + * - -ERR_UN_SUPPORT 不支持的网络共享类型 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl LocalWireGuardControl(bool isStart, bool setPrivateMode); + +/** + * @brief 启动/停止 隧道控制服务心跳 + * @param isStart 启动/停止服务 TRUE 启动服务, FALSE 停止服务 + * @param lpHeartCbAddress 心跳服务回调函数 @see PTUNNEL_HEART_ROUTINE + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_CREATE_TIMER 创建定时器失败 + * - -ERR_DELETE_TIMER 删除定时器失败 + * - ERR_SUCCESS 成功 + */ +SCCSDK_API int __cdecl RemoteHeartControl(bool isStart, LPTUNNEL_HEART_ROUTINE lpHeartCbAddress); + +/** + * @brief 获取用户错误码字符串 + * @param err 用户错误码 + * @return 用户错误码对应的字符串, "UNKNOWN": 未知错误 + */ +SCCSDK_API const CHAR* __cdecl GetSDKErrorMessage(USER_ERRNO err); +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/NetTunnelSDK/include/tunnel.h b/NetTunnelSDK/include/tunnel.h new file mode 100644 index 0000000..8004a47 --- /dev/null +++ b/NetTunnelSDK/include/tunnel.h @@ -0,0 +1,200 @@ +#pragma once +#include "sccsdk.h" + +typedef enum { + CHK_SYSTEM_INIT, + CHK_WIREGUARD_CONFIG, + CHK_WIREGUARD_SERVICE, + CHK_WG_INTERFACE_PRIVATE, + CHK_MAX +} CHECK_FUNCTION; + +typedef struct { + CHECK_FUNCTION chk; + bool result; + TCHAR errMsg[MAX_PATH]; +} CHK_RESULT, *PCHK_RESULT; + +/** + * @brief WireGuard 服务端配置信息 + */ +typedef struct { + TCHAR Name[64]; ///< WireGuard 网卡名称 + TCHAR Address[32]; ///< WireGuard 本地网络IP地址 + TCHAR PrivateKey[64]; ///< WireGuard 本机私钥 + int ListenPort; ///< WireGuard 服务端监听端口 + + // 根据系统设计,不支持多个客户端同时连接 + TCHAR CliPubKey[64]; ///< WireGuard 客户端公钥 + TCHAR AllowNet[256]; ///< WireGuard 允许对端访问本地网络的配置 +} WGSERVER_CONFIG, *PWGSERVER_CONFIG; + +/** + * @brief WireGuard 客户端配置信息 + */ +typedef struct { + TCHAR Name[64]; ///< WireGuard 网卡名称 + TCHAR PrivateKey[64]; ///< WireGuard 本机私钥 + TCHAR Address[32]; ///< WireGuard 本地网络IP地址 + + // Peer Server + TCHAR SvrPubKey[64]; ///< WireGuard 服务端公钥 + TCHAR AllowNet[256]; ///< WireGuard 允许对端访问本地网络的配置 + TCHAR ServerURL[256]; ///< WireGuard 服务端IP地址和端口 +} WGCLIENT_CONFIG, *PWGCLIENT_CONFIG; + +#ifdef __cplusplus // If used by C++ code, +extern "C" { +// we need to export the C interface +#endif +/** + * @brief 设置传输协议加密方式,默认 CRYPTO_NONE + * @param[in] type 协议加密类型 @see PROTO_CRYPTO_TYPE + * @param[in] pProKey 加密秘钥,CRYPTO_NONE, CRYPTO_BASE64 无效忽略 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - ERR_SUCCESS 成功 + */ +int SetProtocolEncryptType(const PROTO_CRYPTO_TYPE type, const TCHAR *pProKey); + +/** + * @brief 创建 WireGuard 服务端配置文件 + * @param[in] pWgConfig 配置文件相关配置项 @see WGSERVER_CONFIG + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_OPEN_FILE 打开文件失败 + * - -ERR_MEMORY_STR 字符串处理失败 + * - ERR_SUCCESS 成功 + */ +int WireGuardCreateServerConfig(const PWGSERVER_CONFIG pWgConfig); + +/** + * @brief 创建 WireGuard 客户端配置文件 + * @param[in] pWgConfig 配置文件相关配置项 @see WGCLIENT_CONFIG + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_OPEN_FILE 打开文件失败 + * - -ERR_MEMORY_STR 字符串处理失败 + * - ERR_SUCCESS 成功 + */ +int WireGuardCreateClientConfig(const PWGCLIENT_CONFIG pWgConfig); + +/** + * @brief 通过 WireGuard 配置文件安装隧道服务 + * @param pInterfaceName 隧道服务名 + * @param pWGConfigFilePath 配置文件完整路径 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 找不掉对应的服务 + * - -ERR_OPEN_SCM 打开服务管理器设备失败 + * - -ERR_OPEN_SERVICE 打开服务失败 + * - -ERR_CREATE_SERVICE 创建服务失败 + * - -ERR_CONFIG_SERVICE 配置服务失败 + * - -ERR_START_SERVICE 停止服务失败 + * - ERR_SUCCESS 成功 + */ +int CreateWireGuardService(const TCHAR *pInterfaceName, const TCHAR *pWGConfigFilePath); + +int GetWireGuradTunnelInfo(const TCHAR *pTunnelName); + +/** + * @brief 移除 WireGuard 隧道服务 + * @param pTunnelName 隧道服务名 + * @param bIsWaitStop 是否等待隧道服务结束 TRUE: 等待服务结束, FALSE: 不等待直接结束 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_OPEN_SCM 打开服务管理器设备失败 + * - -ERR_OPEN_SERVICE 打开服务失败 + * - -ERR_STOP_SERVICE 停止服务失败 + * - -ERR_DELETE_SERVICE 删除服务失败 + * - ERR_SUCCESS 成功 + */ +int RemoveGuardService(const TCHAR *pTunnelName, bool bIsWaitStop); + +/** + * @brief 安装/卸载 WireGuard 服务 + * @param[in] bInstall TRUE 安装服务, FALSE 卸载服务 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + */ +int WireGuardInstallDefaultServerService(bool bInstall); + +/** + * @brief 通过 WireGuard 配置文件安装隧道服务 + * @param[in] pTunnelCfgPath 配置文件完整路径 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 配置文件不存在 + * - -ERR_CALL_SHELL 调用 WireGuard 外部服务失败 + * - ERR_SUCCESS 成功 + */ +int WireGuardInstallServerService(const TCHAR *pTunnelCfgPath); + +/** + * @brief 卸载 WireGuard 隧道服务 + * @param[in] pTunnelName 隧道服务名 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 配置文件不存在 + * - -ERR_CALL_SHELL 调用 WireGuard 外部服务失败 + * - ERR_SUCCESS 成功 + */ +int WireGuardUnInstallServerService(const TCHAR *pTunnelName); + +/** + * @brief 获取当前 WireGuard 服务是否安装 + * @param[out] pIsInstalled WireGuard 服务安装状态 + * - TRUE 已经安装 + * - FALSE 未安装 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_OPEN_SCM, 打开服务管理器设备 + * - -ERR_OPEN_SERVICE, 打开服务失败 + * - -ERR_GET_SERVICESSTATUS, 获取服务状态失败 + * - ERR_SUCCESS 成功 + */ +int IsWireGuardServerInstalled(bool *pIsInstalled); + +/** + * @brief 计算文件 Hash + * @param[in] type Hash 类型 @see HASH_TYPE + * @param[in] pPath 需要计算 Hash 值的文件路径 + * @param[out] outHash 计算结果 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 文件不存在 + * - -ERR_OPEN_FILE 打开文件失败 + * - -ERR_BCRYPT_OPEN 创建加解密算法失败 + * - -ERR_BCRYPT_GETPROPERTY 获取加解密算法属性失败 + * - -ERR_BCRYPT_CREATEHASH 创建 Hash 算法失败 + * - -ERR_BCRYPT_HASHDATA 计算 Hash 数据失败 + * - -ERR_BCRYPT_FINISHHASH 计算 Hash 结果失败 + * - ERR_SUCCESS 成功 + */ +int CalcFileHash(HASH_TYPE type, const TCHAR *pPath, TCHAR outHash[]); + +/** + * @brief 计算 HMAC HASH 值 + * @param[in] type Hash 类型 @see HASH_TYPE + * @param[in] pHashData 需要计算 Hash 值的数据 + * @param[in] inSize 需要计算 Hash 值的数据大小(字节数) + * @param[in] pKey HMAC Hash 秘钥 + * @param[in] keySize HMAC Hash 秘钥大小(字节数) + * @param[out] outHash 计算结果 + * @param[in] outBase64 是否以 BASE64 字符串输出 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 文件不存在 + * - -ERR_OPEN_FILE 打开文件失败 + * - -ERR_BCRYPT_OPEN 创建加解密算法失败 + * - -ERR_BCRYPT_GETPROPERTY 获取加解密算法属性失败 + * - -ERR_BCRYPT_CREATEHASH 创建 Hash 算法失败 + * - -ERR_BCRYPT_HASHDATA 计算 Hash 数据失败 + * - -ERR_BCRYPT_FINISHHASH 计算 Hash 结果失败 + * - ERR_SUCCESS 成功 + */ +int CalcHmacHash(HASH_TYPE type, PUCHAR pHashData, int inSize, PUCHAR pKey, int keySize, TCHAR outHash[], bool outBase64); +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/NetTunnelSDK/include/user.h b/NetTunnelSDK/include/user.h new file mode 100644 index 0000000..1e1b80d --- /dev/null +++ b/NetTunnelSDK/include/user.h @@ -0,0 +1,54 @@ +#pragma once +#include "sccsdk.h" + +#define HEART_PERIOD_MS (3000) + +/** + * + * @brief 本地计算机网卡信息 + */ +typedef struct { + int isCurrent; ///< 网卡 MAC 地址 + TCHAR CfgPath[260]; ///< 配置文件路径 +} USER_CFGFILE, *PUSER_CFGFILE; + +#ifdef __cplusplus // If used by C++ code, +extern "C" { +// we need to export the C interface +#endif +/** + * @brief 连接到服务端控制服务 + * @param pUserSvrUrl 服务端控制服务 URL 地址 + */ +void ConnectServerControlService(const TCHAR *pUserSvrUrl); + +/** + * @brief 设置本地 WireGuard 隧道配置 + * @param[in] pCliPrivateKey 隧道私钥 + * @param[in] pSvrPublicKey 隧道服务端公钥 + * @param[in] pSvrNetwork 可访问隧道服务的云电脑网络 + * @param[in] pCliNetwork 客户端共享网络地址 + * @param[in] pSvrTunnelAddr 服务端隧道网络 + * @param[in] pSvrEndPoint 隧道服务端接入地址 + * @return 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_SYSTEM_UNINITIALIZE 服务端参数未初始化 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_OPEN_FILE 打开文件失败 + * - -ERR_MEMORY_STR 字符串处理失败 + * - -ERR_UN_SUPPORT 不支持的格式转换 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - ERR_SUCCESS 成功 + */ +int SetTunnelConfigure(const TCHAR *pCliPrivateKey, + const TCHAR *pSvrPublicKey, + const TCHAR *pSvrNetwork, + const TCHAR *pCliNetwork, + const TCHAR *pSvrTunnelAddr, + const TCHAR *pSvrEndPoint); + + int GetUserConfigFiles(const TCHAR *pUserName, PUSER_CFGFILE* pCfgFile, int *pItems); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/NetTunnelSDK/include/usrerr.h b/NetTunnelSDK/include/usrerr.h new file mode 100644 index 0000000..27ef62e --- /dev/null +++ b/NetTunnelSDK/include/usrerr.h @@ -0,0 +1,69 @@ +#pragma once + +/** + * @brief SDK 常用错误码 + */ +enum USER_ERRNO { + ERR_SUCCESS = 0, ///< 成功 + ERR_INPUT_PARAMS, ///< 输入参数错误 + ERR_UN_SUPPORT, ///< 不支持的操作 + ERR_CALL_SHELL, ///< 调用Shell命令失败 + ERR_ITEM_EXISTS, ///< 该内容已经存在 + ERR_ITEM_UNEXISTS, ///< 该内容不存在 + ERR_SYS_INIT, ///< 系统中断 + ERR_SYS_CALL, ///< 系统调用 + ERR_LOAD_LIBRARY, ///< 加载系统库失败 + ERR_MAP_LIBRARY, ///< 加载系统库接口失败 + ERR_SYS_TIMEOUT, ///< 系统超时 + ERR_SYSTEM_UNINITIALIZE, ///< 系统未初始化 + ERR_CREATE_FILE, ///< 创建文件/目录失败 + ERR_OPEN_FILE, ///< 打开文件失败 + ERR_READ_FILE, ///< 读取文件失败 + ERR_WRITE_FILE, ///< 写入文件失败 + ERR_FILE_NOT_EXISTS, ///< 文件不存在 + ERR_FILE_LOCKED, ///< 文件被锁定 + ERR_GET_FILE_SIZE, ///< 获取文件大小失败 + ERR_FIND_FILE, ///< 查找文件失败 + ERR_COPY_FILE, ///< 复制文件失败 + ERR_OPEN_SCM, ///< 打开服务管理器设备失败 + ERR_OPEN_SERVICE, ///< 打开服务失败 + ERR_CREATE_SERVICE, ///< 创建服务失败 + ERR_START_SERVICE, ///< 启动服务失败 + ERR_STOP_SERVICE, ///< 停止服务失败 + ERR_DELETE_SERVICE, ///< 删除服务失败 + ERR_CONFIG_SERVICE, ///< 修改服务配置失败 + ERR_GET_SERVICESSTATUS, ///< 获取服务状态失败 + ERR_MALLOC_MEMORY, ///< 分配内存失败 + ERR_MMAP_MEMORY, ///< 共享内存失败 + ERR_MEMORY_STR, ///< 字符串操作失败 + ERR_CREATE_PROCESS, ///< 创建进程失败 + ERR_PROCESS_RETURN, ///< 进程调用返回失败 + ERR_CREATE_THREAD, ///< 创建线程失败 + ERR_CREATE_TIMER, ///< 创建定时器失败 + ERR_DELETE_TIMER, ///< 销毁定时器失败 + ERR_SOCKET_CREATE, ///< 创建 SOCKET 失败 + ERR_SOCKET_BIND, ///< 绑定 SOCKET 端口失败 + ERR_SOCKET_CONNECT, ///< 连接 TCP SOCKET 服务器失败 + ERR_SOCKET_LISTEN, ///< TCP SOCKET 服务监听失败 + ERR_SOCKET_BIND_PORT, ///< 绑定端口失败 + ERR_SOCKET_SET_OPT, ///< 设置 SOCKET 参数失败 + ERR_SOCKET_GET_OPT, ///< 读取 SOCKET 参数失败 + ERR_BCRYPT_OPEN = 100, ///< 创建加密算法失败 + ERR_BCRYPT_GETPROPERTY, ///< 获取加密算法属性失败 + ERR_BCRYPT_CREATEHASH, ///< 创建 Hash 算法失败 + ERR_BCRYPT_HASHDATA, ///< 计算 Hash 数据失败 + ERR_BCRYPT_FINISHHASH, ///< 计算 Hash 结果失败 + ERR_NET_UNCONNECT = 200, ///< 网络未连接 + ERR_NET_CATEGORY_MODE, ///< 网络工作模式 + ERR_NET_INTELNEL_ICS, ///< 共享 Intelnet 网络 ICS 共享失败 + ERR_NET_WIREGUARD_ICS, ///< 共享 WireGuard 网络 ICS 共享失败 + ERR_GET_IPFOWARDTBL = 300, ///< 获取系统 IP 转发表失败 + ERR_CREATE_COMMOBJECT = 400, ///< 创建 COM 对象失败 + ERR_CALL_COMMOBJECT, ///< 调用 COM 对象失败 + ERR_JSON_CREATE = 500, ///< 创建 JSON 对象失败 + ERR_JSON_DECODE, ///< 从 JSON 反序列化对象失败 + ERR_HTTP_SERVER_RSP = 600, ///< HTTP 服务端返回错误 + ERR_HTTP_POST_DATA, ///< 发送 POST 请求失败 + ERR_NET_ADD_ROUTE, ///< 添加路由失败 + ERR_NET_REMOVE_ROUTE, ///< 删除路由失败 +}; \ No newline at end of file diff --git a/NetTunnelSDK/misc/ipcalc.cpp b/NetTunnelSDK/misc/ipcalc.cpp new file mode 100644 index 0000000..6bc6646 --- /dev/null +++ b/NetTunnelSDK/misc/ipcalc.cpp @@ -0,0 +1,538 @@ +#include "pch.h" +#include +#include +#include +#include +#include + +#include "usrerr.h" +#include "misc.h" +#include "tunnel.h" + +static const TCHAR *p2_table(unsigned pow) { + static const TCHAR *pow2[] = { + TEXT("1"), + TEXT("2"), + TEXT("4"), + TEXT("8"), + TEXT("16"), + TEXT("32"), + TEXT("64"), + TEXT("128"), + TEXT("256"), + TEXT("512"), + TEXT("1024"), + TEXT("2048"), + TEXT("4096"), + TEXT("8192"), + TEXT("16384"), + TEXT("32768"), + TEXT("65536"), + TEXT("131072"), + TEXT("262144"), + TEXT("524288"), + TEXT("1048576"), + TEXT("2097152"), + TEXT("4194304"), + TEXT("8388608"), + TEXT("16777216"), + TEXT("33554432"), + TEXT("67108864"), + TEXT("134217728"), + TEXT("268435456"), + TEXT("536870912"), + TEXT("1073741824"), + TEXT("2147483648"), + TEXT("4294967296"), + TEXT("8589934592"), + TEXT("17179869184"), + TEXT("34359738368"), + TEXT("68719476736"), + TEXT("137438953472"), + TEXT("274877906944"), + TEXT("549755813888"), + TEXT("1099511627776"), + TEXT("2199023255552"), + TEXT("4398046511104"), + TEXT("8796093022208"), + TEXT("17592186044416"), + TEXT("35184372088832"), + TEXT("70368744177664"), + TEXT("140737488355328"), + TEXT("281474976710656"), + TEXT("562949953421312"), + TEXT("1125899906842624"), + TEXT("2251799813685248"), + TEXT("4503599627370496"), + TEXT("9007199254740992"), + TEXT("18014398509481984"), + TEXT("36028797018963968"), + TEXT("72057594037927936"), + TEXT("144115188075855872"), + TEXT("288230376151711744"), + TEXT("576460752303423488"), + TEXT("1152921504606846976"), + TEXT("2305843009213693952"), + TEXT("4611686018427387904"), + TEXT("9223372036854775808"), + TEXT("18446744073709551616"), + TEXT("36893488147419103232"), + TEXT("73786976294838206464"), + TEXT("147573952589676412928"), + TEXT("295147905179352825856"), + TEXT("590295810358705651712"), + TEXT("1180591620717411303424"), + TEXT("2361183241434822606848"), + TEXT("4722366482869645213696"), + TEXT("9444732965739290427392"), + TEXT("18889465931478580854784"), + TEXT("37778931862957161709568"), + TEXT("75557863725914323419136"), + TEXT("151115727451828646838272"), + TEXT("302231454903657293676544"), + TEXT("604462909807314587353088"), + TEXT("1208925819614629174706176"), + TEXT("2417851639229258349412352"), + TEXT("4835703278458516698824704"), + TEXT("9671406556917033397649408"), + TEXT("19342813113834066795298816"), + TEXT("38685626227668133590597632"), + TEXT("77371252455336267181195264"), + TEXT("154742504910672534362390528"), + TEXT("309485009821345068724781056"), + TEXT("618970019642690137449562112"), + TEXT("1237940039285380274899124224"), + TEXT("2475880078570760549798248448"), + TEXT("4951760157141521099596496896"), + TEXT("9903520314283042199192993792"), + TEXT("19807040628566084398385987584"), + TEXT("39614081257132168796771975168"), + TEXT("79228162514264337593543950336"), + TEXT("158456325028528675187087900672"), + TEXT("316912650057057350374175801344"), + TEXT("633825300114114700748351602688"), + TEXT("1267650600228229401496703205376"), + TEXT("2535301200456458802993406410752"), + TEXT("5070602400912917605986812821504"), + TEXT("10141204801825835211973625643008"), + TEXT("20282409603651670423947251286016"), + TEXT("40564819207303340847894502572032"), + TEXT("81129638414606681695789005144064"), + TEXT("162259276829213363391578010288128"), + TEXT("324518553658426726783156020576256"), + TEXT("649037107316853453566312041152512"), + TEXT("1298074214633706907132624082305024"), + TEXT("2596148429267413814265248164610048"), + TEXT("5192296858534827628530496329220096"), + TEXT("10384593717069655257060992658440192"), + TEXT("20769187434139310514121985316880384"), + TEXT("41538374868278621028243970633760768"), + TEXT("83076749736557242056487941267521536"), + TEXT("166153499473114484112975882535043072"), + TEXT("332306998946228968225951765070086144"), + TEXT("664613997892457936451903530140172288"), + TEXT("1329227995784915872903807060280344576"), + TEXT("2658455991569831745807614120560689152"), + TEXT("5316911983139663491615228241121378304"), + TEXT("10633823966279326983230456482242756608"), + TEXT("21267647932558653966460912964485513216"), + TEXT("42535295865117307932921825928971026432"), + TEXT("85070591730234615865843651857942052864"), + TEXT("170141183460469231731687303715884105728"), + }; + if (pow <= 127) { + return pow2[pow]; + } + return TEXT(""); +} + +static int vasprintf(TCHAR **strp, const TCHAR *fmt, va_list ap) { + // _vscprintf tells you how big the buffer needs to be + const int len = _vscprintf(fmt, ap); + if (len == -1) { + return -1; + } + const size_t size = static_cast(len) + 1; + const auto str = static_cast(malloc(size)); + if (!str) { + return -1; + } + // _vsprintf_s is the "secure" version of vsprintf + const int r = vsprintf_s(str, len + 1, fmt, ap); + if (r == -1) { + free(str); + return -1; + } + *strp = str; + return r; +} + +static int asprintf(TCHAR **strp, const TCHAR *fmt, ...) { + va_list ap; + va_start(ap, fmt); + const int r = vasprintf(strp, fmt, ap); + va_end(ap); + return r; +} + +static int bit_count(unsigned int i) { + int c = 0; + unsigned int seen_one = 0; + + while (i > 0) { + if (i & 1) { + seen_one = 1; + c++; + } else { + if (seen_one) { + return -1; + } + } + i >>= 1; + } + + return c; +} + +/** + * @brief creates a netmask from a specified number of bits + * This function converts a prefix length to a netmask. As CIDR (classless + * internet domain internet domain routing) has taken off, more an more IP + * addresses are being specified in the format address/prefix + * (i.e. 192.168.2.3/24, with a corresponding netmask 255.255.255.0). If you + * need to see what netmask corresponds to the prefix part of the address, this + * is the function. See also @ref mask2prefix. + * @param prefix prefix is the number of bits to create a mask for. + * @return a network mask, in network byte order. + */ +unsigned int prefix2mask(int prefix) { + if (prefix) { + return htonl(~((1 << (32 - prefix)) - 1)); + } else { + return htonl(0); + } +} + +/** +* @brief calculates the number of bits masked off by a netmask. +* This function calculates the significant bits in an IP address as specified by +* a netmask. See also @ref prefix2mask. +* @param mask is the netmask, specified as an struct in_addr in network byte order. +* @return the number of significant bits. +*/ +int mask2prefix(IN_ADDR mask) { + return bit_count(ntohl(mask.s_addr)); +} + +static int ipv4_mask_to_int(const char *prefix) { + int ret; + IN_ADDR in; + + ret = inet_pton(AF_INET, prefix, &in); + if (ret == 0) { + return -1; + } + + return mask2prefix(in); +} + +/** +* @brief calculate broadcast address given an IP address and a prefix length. +* @param addr an IP address in network byte order. +* @param prefix a prefix length. +* @return the calculated broadcast address for the network, in network byte order. +*/ +static IN_ADDR calc_broadcast(IN_ADDR addr, int prefix) { + IN_ADDR mask; + IN_ADDR broadcast; + + mask.s_addr = prefix2mask(prefix); + + memset(&broadcast, 0, sizeof(broadcast)); + broadcast.s_addr = (addr.s_addr & mask.s_addr) | ~mask.s_addr; + return broadcast; +} + +/** +* @brief calculates the network address for a specified address and prefix. +* @param addr an IP address, in network byte order +* @param prefix the network prefix +* @return the base address of the network that addr is associated with, in +* network byte order. +*/ +static IN_ADDR calc_network(IN_ADDR addr, int prefix) { + IN_ADDR mask; + IN_ADDR network; + + mask.s_addr = prefix2mask(prefix); + + memset(&network, 0, sizeof(network)); + network.s_addr = addr.s_addr & mask.s_addr; + return network; +} + +static TCHAR *ipv4_prefix_to_hosts(TCHAR *hosts, unsigned hosts_size, unsigned prefix) { + if (prefix >= 31) { + StringCbPrintf(hosts, hosts_size, TEXT("%s"), p2_table(32 - prefix)); + } else { + unsigned int tmp; + tmp = (1 << (32 - prefix)) - 2; + StringCbPrintf(hosts, hosts_size, TEXT("%u"), tmp); + } + return hosts; +} + +static int str_to_prefix(int *ipv6, const char *prefixStr, unsigned fix) { + int prefix; + + if (!(*ipv6) && strchr(prefixStr, '.')) { /* prefix is 255.x.x.x */ + prefix = ipv4_mask_to_int(prefixStr); + } else { + prefix = strtol(prefixStr, nullptr, 10); + } + + if (fix && (prefix > 32 && !(*ipv6))) { + *ipv6 = 1; + } + + if (prefix < 0 || (((*ipv6) && prefix > 128) || (!(*ipv6) && prefix > 32))) { + return -1; + } + return prefix; +} + +static int GetIpV4Info(const TCHAR *pIpStr, int prefix, PIP_INFO pInfo, unsigned int flags) { + IN_ADDR ip, netmask, network, broadcast, minhost, maxhost; + TCHAR namebuf[INET_ADDRSTRLEN + 1]; + TCHAR *ipStr = _strdup(pIpStr); + + memset(pInfo, 0, sizeof(*pInfo)); + + if (inet_pton(AF_INET, ipStr, &ip) <= 0) { + SPDLOG_ERROR(TEXT("ipcalc: bad IPv4 address: {0}"), ipStr); + free(ipStr); + return -ERR_UN_SUPPORT; + } + + /* Handle CIDR entries such as 172/8 */ + if (prefix >= 0) { + auto tmp = const_cast(ipStr); + int i; + + for (i = 3; i > 0; i--) { + tmp = strchr(tmp, '.'); + if (!tmp) { + break; + } else { + tmp++; + } + } + + tmp = nullptr; + for (; i > 0; i--) { + if (asprintf(&tmp, "%s.0", ipStr) == -1) { + SPDLOG_ERROR(TEXT("Memory allocation failure")); + free(ipStr); + return -ERR_MALLOC_MEMORY; + } + ipStr = tmp; + } + } else { // assume good old days classful Internet + prefix = 32; + } + + if (prefix > 32) { + SPDLOG_ERROR(TEXT("ipcalc: bad IPv4 prefix: {0}"), prefix); + free(ipStr); + return -ERR_UN_SUPPORT; + } + + if (inet_ntop(AF_INET, &ip, namebuf, sizeof(namebuf)) == 0) { + SPDLOG_ERROR(TEXT("ipcalc: error calculating the IPv4 network")); + free(ipStr); + return -ERR_UN_SUPPORT; + } + StringCbCopy(pInfo->ip, MAX_IP_LEN, namebuf); + + netmask.s_addr = prefix2mask(prefix); + memset(namebuf, '\0', sizeof(namebuf)); + + if (inet_ntop(AF_INET, &netmask, namebuf, INET_ADDRSTRLEN) == nullptr) { + SPDLOG_ERROR(TEXT("inet_ntop error")); + free(ipStr); + return -ERR_UN_SUPPORT; + } + StringCbCopy(pInfo->netmask, MAX_IP_LEN, namebuf); + + pInfo->prefix = prefix; + + broadcast = calc_broadcast(ip, prefix); + + memset(namebuf, '\0', sizeof(namebuf)); + if (inet_ntop(AF_INET, &broadcast, namebuf, INET_ADDRSTRLEN) == nullptr) { + SPDLOG_ERROR(TEXT("inet_ntop error")); + free(ipStr); + return -ERR_UN_SUPPORT; + } + + StringCbCopy(pInfo->broadcast, MAX_IP_LEN, namebuf); + + network = calc_network(ip, prefix); + + memset(namebuf, '\0', sizeof(namebuf)); + if (inet_ntop(AF_INET, &network, namebuf, INET_ADDRSTRLEN) == nullptr) { + SPDLOG_ERROR(TEXT("inet_ntop error")); + free(ipStr); + return -ERR_UN_SUPPORT; + } + + StringCbCopy(pInfo->network, MAX_IP_LEN, namebuf); + + if (prefix < 32) { + memcpy(&minhost, &network, sizeof(minhost)); + + if (prefix <= 30) { + minhost.s_addr = htonl(ntohl(minhost.s_addr) | 1); + } + if (inet_ntop(AF_INET, &minhost, namebuf, INET_ADDRSTRLEN) == nullptr) { + SPDLOG_ERROR(TEXT("inet_ntop error")); + free(ipStr); + return -ERR_UN_SUPPORT; + } + StringCbCopy(pInfo->hostmin, MAX_IP_LEN, namebuf); + + memcpy(&maxhost, &network, sizeof(minhost)); + maxhost.s_addr |= ~netmask.s_addr; + if (prefix <= 30) { + maxhost.s_addr = htonl(ntohl(maxhost.s_addr) - 1); + } + if (inet_ntop(AF_INET, &maxhost, namebuf, sizeof(namebuf)) == 0) { + SPDLOG_ERROR(TEXT("ipcalc: error calculating the IPv4 network")); + free(ipStr); + return -ERR_UN_SUPPORT; + } + + StringCbCopy(pInfo->hostmax, MAX_IP_LEN, namebuf); + } else { + StringCbCopy(pInfo->hostmin, MAX_IP_LEN, pInfo->network); + StringCbCopy(pInfo->hostmax, MAX_IP_LEN, pInfo->network); + } + + ipv4_prefix_to_hosts(pInfo->hosts, sizeof(pInfo->hosts), prefix); + + free(ipStr); + return ERR_SUCCESS; +} + +/** + * @brief 计算 IPv4 网络信息 + * @param[in] pIpStr IPv4 网络信息 '/' 分割,支持CIDR以及子网掩码 example: 192.168.1.32/24, 192.168.1.32/255.255.255.0 + * @param[out] pInfo 计算结果 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_UN_SUPPORT 不支持的格式转换 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - ERR_SUCCESS 成功 + */ +int GetIpV4InfoFromCIDR(const TCHAR *pIpStr, PIP_INFO pInfo) { + int ret, prefix, familyIPv6 = 0; + TCHAR *prefixStr; + TCHAR *ipStr = _strdup(pIpStr); + + if (pIpStr == nullptr || lstrlen(pIpStr) < MIN_IP_LEN || lstrlen(pIpStr) >= MAX_IP_LEN) { + SPDLOG_ERROR(TEXT("Input pIpStr format error: {}."), pIpStr); + return -ERR_INPUT_PARAMS; + } + + if (pInfo == nullptr) { + SPDLOG_ERROR(TEXT("Input pInfo is NULL.")); + return -ERR_INPUT_PARAMS; + } + + if (strchr(ipStr, '/') != nullptr) { + prefixStr = static_cast(strchr(ipStr, '/')); + *prefixStr = '\0'; /* fix up ipStr */ + prefixStr++; + } else { + SPDLOG_ERROR(TEXT("Input pIpStr isn't CIDR format: {}."), pIpStr); + free(ipStr); + return -ERR_INPUT_PARAMS; + } + + if (strchr(prefixStr, '.') != nullptr) { + prefix = ipv4_mask_to_int(prefixStr); + } else { + prefix = str_to_prefix(&familyIPv6, prefixStr, 0); + } + + ret = GetIpV4Info(ipStr, prefix, pInfo, 0); + + free(ipStr); + return ret; +} + +/** + * @brief 计算 IPv4 网络信息 + * @param[in] pIpStr IPv4 地址 + * @param[in] pNetmask IPv4子网掩码 + * @param[out] pInfo 计算结果 + * @return 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_UN_SUPPORT 不支持的格式转换 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - ERR_SUCCESS 成功 + */ +int GetIpV4InfoFromNetmask(const TCHAR *pIpStr, const TCHAR *pNetmask, PIP_INFO pInfo) { + int prefix; + + if (pIpStr == nullptr || lstrlen(pIpStr) < MIN_IP_LEN || lstrlen(pIpStr) >= MAX_IP_LEN) { + SPDLOG_ERROR(TEXT("Input pIpStr format error: {}."), pIpStr); + return -ERR_INPUT_PARAMS; + } + + if (pNetmask == nullptr || lstrlen(pNetmask) < MIN_IP_LEN || lstrlen(pNetmask) >= MAX_IP_LEN) { + SPDLOG_ERROR(TEXT("Input pNetmask format error: {}."), pNetmask); + return -ERR_INPUT_PARAMS; + } + + if (pInfo == nullptr) { + SPDLOG_ERROR(TEXT("Input pInfo is NULL.")); + return -ERR_INPUT_PARAMS; + } + + prefix = ipv4_mask_to_int(pNetmask); + return GetIpV4Info(pIpStr, prefix, pInfo, 0); +} + +int GetIpV4InfoFromHostname(int family, const char *host, PIP_INFO pInfo) { + addrinfo *res, *rp; + addrinfo hints {}; + int err; + static char ipname[64]; + void *addr; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = family; + + err = getaddrinfo(host, nullptr, &hints, &res); + if (err != 0) { + return -ERR_INPUT_PARAMS; + } + + for (rp = res; rp != nullptr; rp = rp->ai_next) { + if (rp->ai_family == AF_INET) { + addr = (&reinterpret_cast(rp->ai_addr)->sin_addr); + } else { + addr = (&reinterpret_cast(rp->ai_addr)->sin6_addr); + } + + if (inet_ntop(rp->ai_family, addr, ipname, sizeof(ipname)) != nullptr) { + freeaddrinfo(res); + StringCbCopy(pInfo->hostip, MAX_IP_LEN, ipname); + return ERR_SUCCESS; + } + } + + freeaddrinfo(res); + return ERR_ITEM_EXISTS; +} \ No newline at end of file diff --git a/NetTunnelSDK/misc/misc.cpp b/NetTunnelSDK/misc/misc.cpp new file mode 100644 index 0000000..f40e8b0 --- /dev/null +++ b/NetTunnelSDK/misc/misc.cpp @@ -0,0 +1,335 @@ +#include "pch.h" +#include "misc.h" + +#include "sccsdk.h" + +#include +#include +#include +#include + +TCHAR *binToHexString(TCHAR *p, const unsigned char *cp, unsigned int count) { + static const TCHAR hex_asc[] = TEXT("0123456789abcdef"); + while (count) { + const unsigned char c = *cp++; + /* put lowercase hex digits */ + *p++ = static_cast(0x20 | hex_asc[c >> 4]); + *p++ = static_cast(0x20 | hex_asc[c & 0xf]); + count--; + } + + return p; +} + +void RemoveTailLineBreak(TCHAR *pInputStr, int strSize) { + size_t length; + if (pInputStr) { + if (StringCbLength(pInputStr, strSize, &length) == S_OK && length > 0) { + if (pInputStr[length - 2] == '\r' && pInputStr[length - 1] == '\n') { + pInputStr[length - 2] = pInputStr[length - 1] = 0; + } else if (pInputStr[length - 1] == '\n') { + pInputStr[length - 1] = 0; + } + } + } +} + +int RunCommand(TCHAR *pszCmd, TCHAR *pszResultBuffer, int dwResultBufferSize, unsigned long *pRetCode) { + BOOL bRet; + HANDLE hReadPipe = nullptr; + HANDLE hWritePipe = nullptr; + DWORD retCode; + STARTUPINFO si; + PROCESS_INFORMATION pi; + SECURITY_ATTRIBUTES securityAttributes; + + if (pszCmd == nullptr) { + SPDLOG_ERROR(TEXT("Input params Error: [{0}]"), pszCmd); + return -ERR_INPUT_PARAMS; + } + + if (pszResultBuffer && dwResultBufferSize > 0) { + memset(pszResultBuffer, 0, dwResultBufferSize); + } + + memset(&si, 0, sizeof(STARTUPINFO)); + memset(&pi, 0, sizeof(PROCESS_INFORMATION)); + + // 设定管道的安全属性 + securityAttributes.bInheritHandle = TRUE; + securityAttributes.nLength = sizeof(securityAttributes); + securityAttributes.lpSecurityDescriptor = nullptr; + + // 创建匿名管道 + bRet = ::CreatePipe(&hReadPipe, &hWritePipe, &securityAttributes, 0); + if (FALSE == bRet) { + SPDLOG_ERROR(TEXT("CreatePipe Error")); + return -ERR_SYS_CALL; + } + + // 设置新进程参数 + si.cb = sizeof(si); + si.hStdError = hWritePipe; + si.hStdOutput = hWritePipe; + si.wShowWindow = SW_HIDE; + si.dwFlags = STARTF_USESHOWWINDOW | STARTF_USESTDHANDLES; + + // 创建新进程执行命令, 将执行结果写入匿名管道中 + bRet = ::CreateProcess(nullptr, (pszCmd), nullptr, nullptr, TRUE, 0, nullptr, nullptr, &si, &pi); + if (FALSE == bRet) { + SPDLOG_ERROR(TEXT("CreateProcess Error")); + return -ERR_CREATE_PROCESS; + } + + // 等待命令执行结束 + //::WaitForSingleObject(pi.hThread, INFINITE); + ::WaitForSingleObject(pi.hThread, 3000); + ::WaitForSingleObject(pi.hProcess, 3000); + + if (pszResultBuffer) { + // 从匿名管道中读取结果到输出缓冲区 + ::RtlZeroMemory(pszResultBuffer, dwResultBufferSize); + ::ReadFile(hReadPipe, pszResultBuffer, dwResultBufferSize, nullptr, nullptr); + } + + // 获取调用程序返回值 + if (pRetCode) { + if (GetExitCodeProcess(pi.hProcess, &retCode)) { + *pRetCode = retCode; + } + } + + // 关闭句柄, 释放内存 + ::CloseHandle(pi.hThread); + ::CloseHandle(pi.hProcess); + ::CloseHandle(hWritePipe); + ::CloseHandle(hReadPipe); + + RemoveTailLineBreak(pszResultBuffer, dwResultBufferSize); + //pszResultBuffer[dwResultBufferSize - 1] = 0; + return ERR_SUCCESS; +} + +void ShowWindowsErrorMessage(const TCHAR *pMsgHead) { + LPVOID buf; + + if (FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_IGNORE_INSERTS | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_MAX_WIDTH_MASK, + nullptr, + GetLastError(), + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPSTR)&buf, + 0, + nullptr)) { + SPDLOG_ERROR(TEXT("{0} Error({1}): {2}"), pMsgHead, GetLastError(), buf); + LocalFree(buf); + } else { + SPDLOG_ERROR(TEXT("{0} Unknown Error{1}."), pMsgHead, GetLastError()); + } +} + +void StringReplaceAll(TCHAR *pOrigin, const TCHAR *pOldStr, const TCHAR *pNewStr) { + using namespace std; + const int maxSize = lstrlen(pOrigin); + auto src = string(pOrigin); + for (string::size_type pos(0); pos != string::npos; pos += lstrlen(pNewStr)) { + if ((pos = src.find(pOldStr, pos)) != string::npos) { + src.replace(pos, lstrlen(pOldStr), pNewStr); + } else { + break; + } + } + + memset(pOrigin, 0, maxSize); + StringCbCopyA(pOrigin, maxSize, src.c_str()); +} + +void StringRemoveAll(TCHAR *pOrigin, const TCHAR *pString) { + StringReplaceAll(pOrigin, pString, ""); +} + +int FindFile(const TCHAR *pPath, PFILE_LIST pFileList, const bool exitWhenMatchOne) { + std::vector pathList; + TCHAR rootPath[MAX_PATH]; + HANDLE hFind; + WIN32_FIND_DATA ffd; + + hFind = FindFirstFile(pPath, &ffd); + + if (INVALID_HANDLE_VALUE == hFind) { + return ERR_ITEM_UNEXISTS; + } + + StringCbCopy(rootPath, MAX_PATH, pPath); + PathRemoveFileSpec(rootPath); + + do { + if (ffd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) { + if (strcmp(ffd.cFileName, ".") == 0 || strcmp(ffd.cFileName, "..") == 0) { + continue; + } + } else { + TCHAR tmp[MAX_PATH]; + PathCombine(tmp, rootPath, ffd.cFileName); + + pathList.push_back(_strdup(tmp)); + if (exitWhenMatchOne) { + break; + } + } + } while (FindNextFile(hFind, &ffd) != 0); + + FindClose(hFind); + + if (GetLastError() != ERROR_NO_MORE_FILES) { + return -ERR_FIND_FILE; + } + + pFileList->pFilePath = static_cast(HeapAlloc(GetProcessHeap(), 0, pathList.size() * sizeof(FILE_PATH))); + + if (pFileList->pFilePath == nullptr) { + SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), pathList.size() * sizeof(FILE_PATH)); + return -ERR_MALLOC_MEMORY; + } + + pFileList->nItems = static_cast(pathList.size()); + + memset(pFileList->pFilePath, 0, pathList.size() * sizeof(FILE_PATH)); + + for (size_t i = 0; i < pathList.size(); i++) { + StringCbCopy(pFileList->pFilePath[i].path, MAX_PATH, pathList.at(i)); + } + + for (auto iter = pathList.begin(); iter != pathList.end(); ++iter) { + if (*iter != nullptr) { + free(*iter); + (*iter) = nullptr; + } + } + + pathList.clear(); + std::vector tmpSwapVector; + tmpSwapVector.swap(pathList); + + return ERR_SUCCESS; +} + +int GetWindowsServiceStatus(const TCHAR *pSvrName, PDWORD pStatus) { + SC_HANDLE schSCManager; + SC_HANDLE schService; + SERVICE_STATUS_PROCESS ssStatus; + DWORD dwBytesNeeded = 0; + + if (pSvrName == nullptr || lstrlen(pSvrName) == 0) { + SPDLOG_ERROR(TEXT("Input pSvrName params error")); + return -ERR_INPUT_PARAMS; + } + + if (pStatus == nullptr) { + SPDLOG_ERROR(TEXT("Input pStatus params error")); + return -ERR_INPUT_PARAMS; + } + + // Get a handle to the SCM database. + schSCManager = OpenSCManager(nullptr, // local computer + nullptr, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (nullptr == schSCManager) { + SPDLOG_ERROR(TEXT("OpenSCManager failed {0}"), GetLastError()); + return -ERR_OPEN_SCM; + } + + // Get a handle to the service. + schService = OpenService(schSCManager, // SCM database + pSvrName, // name of service + SERVICE_QUERY_STATUS | SERVICE_ENUMERATE_DEPENDENTS); // full access + + if (schService == nullptr) { + SPDLOG_ERROR(TEXT("OpenService failed {0}"), GetLastError()); + CloseServiceHandle(schSCManager); + return -ERR_OPEN_SERVICE; + } + + // Check the status in case the service is not stopped. + if (!QueryServiceStatusEx(schService, // handle to service + SC_STATUS_PROCESS_INFO, // information level + reinterpret_cast(&ssStatus), // address of structure + sizeof(SERVICE_STATUS_PROCESS), // size of structure + &dwBytesNeeded)) // size needed if buffer is too small + { + SPDLOG_ERROR(TEXT("QueryServiceStatusEx failed {0}"), GetLastError()); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return -ERR_GET_SERVICESSTATUS; + } else { + *pStatus = ssStatus.dwCurrentState; + } + + return ERR_SUCCESS; +} + +int WideCharToTChar(const WCHAR *pWStr, TCHAR *pOutStr, int maxOutLen) { + if constexpr (sizeof(TCHAR) == sizeof(WCHAR)) { + if (wcslen(pWStr) * sizeof(WCHAR) >= maxOutLen) { + SPDLOG_ERROR(TEXT("Output buffer is to short: {0} need {1}"), maxOutLen, wcslen(pWStr) * sizeof(WCHAR)); + return -ERR_INPUT_PARAMS; + } + memcpy(pOutStr, pWStr, wcslen(pWStr)); + } else { + int len = WideCharToMultiByte(CP_ACP, 0, pWStr, static_cast(wcslen(pWStr)), nullptr, 0, nullptr, nullptr); + + if (len >= maxOutLen) { + SPDLOG_ERROR(TEXT("Output buffer is to short: {0} need {1}"), maxOutLen, len); + return -ERR_INPUT_PARAMS; + } + + WideCharToMultiByte(CP_ACP, 0, pWStr, static_cast(wcslen(pWStr)), pOutStr, len, nullptr, nullptr); + pOutStr[len] = 0; + + return ERR_SUCCESS; + } +} + +int TCharToWideChar(const TCHAR *pTStr, WCHAR *pOutStr, int maxOutLen) { + if constexpr (sizeof(TCHAR) == sizeof(WCHAR)) { + if (lstrlen(pTStr) * sizeof(WCHAR) >= maxOutLen) { + SPDLOG_ERROR(TEXT("Output buffer is to short: {0} need {1}"), maxOutLen, lstrlen(pTStr) * sizeof(WCHAR)); + return -ERR_INPUT_PARAMS; + } + memcpy(pOutStr, pTStr, lstrlen(pTStr)); + } else { + int len = MultiByteToWideChar(CP_ACP, 0, pTStr, lstrlen(pTStr), nullptr, 0); + //int len = WideCharToMultiByte(CP_ACP, 0, pWStr, static_cast(wcslen(pWStr)), nullptr, 0, nullptr, nullptr); + + if (len >= maxOutLen) { + SPDLOG_ERROR(TEXT("Output buffer is to short: {0} need {1}"), maxOutLen, len); + return -ERR_INPUT_PARAMS; + } + + MultiByteToWideChar(CP_ACP, 0, pTStr, lstrlen(pTStr), pOutStr, len); + pOutStr[len] = 0; + + return ERR_SUCCESS; + } +} + +static std::unordered_map g_UserErrorMap; + +const CHAR *GetSDKErrorMessage(USER_ERRNO err) { + std::unordered_map::iterator iter; + + if (g_UserErrorMap.empty()) { + constexpr auto color_entries = magic_enum::enum_entries(); + + for (auto colorEntry : color_entries) { + g_UserErrorMap.emplace(colorEntry.first, std::string(colorEntry.second)); + } + } + + if ((iter = g_UserErrorMap.find(err)) != g_UserErrorMap.end()) { + return iter->second.c_str(); + } else { + return "UNKNOWN"; + } +} \ No newline at end of file diff --git a/NetTunnelSDK/network/ControlService.cpp b/NetTunnelSDK/network/ControlService.cpp new file mode 100644 index 0000000..7289867 --- /dev/null +++ b/NetTunnelSDK/network/ControlService.cpp @@ -0,0 +1,338 @@ +#include "pch.h" +#include "tunnel.h" +#include +#include + +#include "globalcfg.h" +#include "httplib.h" +#include "misc.h" +#include "network.h" +#include "protocol.h" +#include "usrerr.h" +#include "user.h" + +#include + +static HANDLE g_ControlSvrThread = nullptr; +static httplib::Server g_httpServer; +static USER_SERVER_CONFIG g_UserSvrCfg; + +/** + * @brief 连接到服务端控制服务 + * @param pUserSvrUrl 服务端控制服务 URL 地址 + */ +void ConnectServerControlService(const TCHAR *pUserSvrUrl) { + InitControlServer(pUserSvrUrl); +} + +static void HttpResponseError(httplib::Response &pRes, int errCode, const TCHAR *pErrMessage) { + ProtocolResponse rsp; + std::string json; + + if (errCode != ERR_SUCCESS) { + rsp.msgContent.errCode = errCode; + } + + if (pErrMessage && lstrlen(pErrMessage) > 0) { + rsp.msgContent.errMessage = pErrMessage; + } else { + if (errCode == ERR_SUCCESS) { + rsp.msgContent.errMessage = TEXT("OK"); + } + } + + if (aigc::JsonHelper::ObjectToJson(rsp, json)) { + pRes.set_content(json, TEXT("application/json")); + } else { + SPDLOG_ERROR(TEXT("ProtocolResponse to json error")); + } +} + +int CreateControlService(PUSER_SERVER_CONFIG pSvr) { + static TCHAR g_CliNetwork[MAX_IP_LEN] = {}; + static WGSERVER_CONFIG g_curCliConfig = {}; + static std::mutex g_InterfaceMutex; + DWORD dwStatus = 0; + + // HTTP 服务已经运行 + if (g_httpServer.is_running()) { + return ERR_SUCCESS; + } + + // 线程已经运行 + if (g_ControlSvrThread && GetExitCodeThread(g_ControlSvrThread, &dwStatus) && dwStatus == STILL_ACTIVE) { + return -ERR_ITEM_EXISTS; + } + + if (pSvr == nullptr) { + SPDLOG_ERROR(TEXT("Input pSvr params error")); + return -ERR_INPUT_PARAMS; + } + + if (pSvr->svrListenPort <= 0 || pSvr->svrListenPort >= 65535) { + SPDLOG_ERROR(TEXT("Input svrListenPort params error {0}"), pSvr->svrListenPort); + return -ERR_INPUT_PARAMS; + } + + if (lstrlen(pSvr->svrPrivateKey) != lstrlen(TEXT("4PPcnW3wYewNpoXjNoY3hQuCnzTNq/E9hhfU9/U6QmY="))) { + SPDLOG_ERROR(TEXT("Input svrPrivateKey params length error {0}"), pSvr->svrPrivateKey); + return -ERR_INPUT_PARAMS; + } + + if (lstrlen(pSvr->svrAddress) == 0) { + SPDLOG_ERROR(TEXT("Input svrAddress params error {0}"), pSvr->svrAddress); + return -ERR_INPUT_PARAMS; + } + + // 保存参数 + memcpy(&g_UserSvrCfg, pSvr, sizeof(USER_SERVER_CONFIG)); + + g_httpServer.set_exception_handler([](const auto &req, auto &res, std::exception_ptr ep) { + const auto fmt = TEXT("

Error 500

%s

"); + char buf[BUFSIZ]; + try { + std::rethrow_exception(ep); + } + catch (std::exception &e) { + StringCbPrintf(buf, BUFSIZ, fmt, e.what()); + } + catch (...) { // See the following NOTE + StringCbPrintf(buf, BUFSIZ, fmt, TEXT("Unknown Exception")); + } + res.set_content(buf, TEXT("text/html")); + res.status = 500; + }); + + g_httpServer.set_error_handler([](const auto &req, auto &res) { + const auto fmt = TEXT("

Error Status: %d

"); + char buf[BUFSIZ]; + StringCbPrintf(buf, BUFSIZ, fmt, res.status); + res.set_content(buf, TEXT("text/html")); + }); + + g_httpServer.Post(SET_CLIENTHEART_PATH, [](const httplib::Request &req, httplib::Response &res) { + ProtocolResponse rsp; + std::string json; + + rsp.msgContent.errCode = ERR_SUCCESS; + rsp.msgContent.errMessage = TEXT("OK"); + + if (aigc::JsonHelper::ObjectToJson(rsp, json)) { + res.set_content(json, TEXT("application/json")); + } else { + SPDLOG_ERROR(TEXT("ProtocolResponse to json error")); + } + }); + + g_httpServer.Post(SET_CLIENTSTART_TUNNEL, [](const httplib::Request &req, httplib::Response &res) { + ProtocolRequest reqData; + + if (aigc::JsonHelper::JsonToObject(reqData, req.body)) { + int ret; + bool isSvrStart = false; + + g_InterfaceMutex.lock(); + // Because of COM return CO_E_FIRST + CoInitialize(nullptr); + + // 判断先前是否启动过服务 + if ((ret = IsWireGuardServerRunning(GetGlobalCfgInfo()->userCfg.userName, &isSvrStart)) != ERR_SUCCESS) { + // 返回获取系统服务错误,是否未安装 + HttpResponseError(res, ret, TEXT("Not found WireGuard application in system")); + SPDLOG_ERROR(TEXT("IsWireGuardServerInstalled error: {0}"), ret); + g_InterfaceMutex.unlock(); + return; + } + + // 当前服务状态和需要执行的操作不同 + if (isSvrStart != reqData.msgContent.isStart) { + if (reqData.msgContent.isStart) { + IP_INFO cliInfo; + IP_INFO tunnelInfo; + int retry = 3; + + // 启动服务 + ret = WireGuardInstallDefaultServerService(true); + if (ret != ERR_SUCCESS) { + // 返回启动服务失败 + SPDLOG_ERROR(TEXT("WireGuardInstallDefaultServerService error: {0}"), ret); + HttpResponseError(res, ret, TEXT("Start WireGuard Tunnel Service error.")); + g_InterfaceMutex.unlock(); + return; + } + // 添加路由 + if ((ret = GetIpV4InfoFromCIDR(g_CliNetwork, &cliInfo)) != ERR_SUCCESS) { + // 返回启动服务失败 + SPDLOG_ERROR(TEXT("GetIpV4InfoFromCIDR ({1}) error: {0}"), ret, g_CliNetwork); + HttpResponseError(res, ret, TEXT("Parse IpAddress error.")); + g_InterfaceMutex.unlock(); + return; + } + + if ((ret = GetIpV4InfoFromCIDR(g_UserSvrCfg.svrAddress, &tunnelInfo)) != ERR_SUCCESS) { + // 返回启动服务失败 + SPDLOG_ERROR(TEXT("GetIpV4InfoFromCIDR ({1}) error: {0}"), ret, g_UserSvrCfg.svrAddress); + HttpResponseError(res, ret, TEXT("Parse tunnel ip address error.")); + g_InterfaceMutex.unlock(); + return; + } + + do { + ret = AddRouteTable(cliInfo.ip, cliInfo.netmask, tunnelInfo.ip); + Sleep(1000); + } while (ret != ERR_SUCCESS && retry--); + + if (ret != ERR_SUCCESS) { + // 返回启动服务失败 + SPDLOG_ERROR(TEXT("Add Route {1}/{2} gateway {3} error: {0}"), + ret, + cliInfo.ip, + cliInfo.netmask, + tunnelInfo.ip); + HttpResponseError(res, ret, TEXT("Parse tunnel ip address error.")); + g_InterfaceMutex.unlock(); + return; + } + } else { + if ((ret = WireGuardUnInstallServerService(GetGlobalCfgInfo()->userCfg.userName)) != ERR_SUCCESS) { + // 返回停止服务失败 + HttpResponseError(res, ret, TEXT("Stop pre running WireGuard service error")); + SPDLOG_ERROR(TEXT("WireGuardUnInstallServerService error: {0}"), ret); + g_InterfaceMutex.unlock(); + return; + } + } + } + + if (reqData.msgContent.isStart) { + SPDLOG_INFO(TEXT("Tunnel Service Start Now ......: {0}"), GetGlobalCfgInfo()->userCfg.userName); + } else { + SPDLOG_INFO(TEXT("Tunnel Service Stoped: {0}"), GetGlobalCfgInfo()->userCfg.userName); + } + + HttpResponseError(res, ERR_SUCCESS, nullptr); + g_InterfaceMutex.unlock(); + } + }); + + g_httpServer.Post(SET_CLIENTCFG_PATH, [](const httplib::Request &req, httplib::Response &res) { + ProtocolRequest reqData; + + if (aigc::JsonHelper::JsonToObject(reqData, req.body)) { + int ret; + bool isSvrStart = false; + ProtocolResponse rsp; + std::string json; + + g_InterfaceMutex.lock(); + // Because of COM return CO_E_FIRST + CoInitialize(nullptr); + + // 判断先前是否启动过服务 + if ((ret = IsWireGuardServerRunning(GetGlobalCfgInfo()->userCfg.userName, &isSvrStart)) != ERR_SUCCESS) { + // 返回获取系统服务错误,是否未安装 + HttpResponseError(res, ret, TEXT("Not found WireGuard application in system")); + SPDLOG_ERROR(TEXT("IsWireGuardServerInstalled error: {0}"), ret); + g_InterfaceMutex.unlock(); + return; + } + + if (isSvrStart) { + SPDLOG_DEBUG(TEXT("WireGuardUnInstallServerService: {0}"), GetGlobalCfgInfo()->userCfg.userName); + if ((ret = WireGuardUnInstallServerService(GetGlobalCfgInfo()->userCfg.userName)) != ERR_SUCCESS) { + // 返回停止服务失败 + HttpResponseError(res, ret, TEXT("Stop pre running WireGuard service error")); + SPDLOG_ERROR(TEXT("WireGuardUnInstallServerService error: {0}"), ret); + g_InterfaceMutex.unlock(); + return; + } + } + + memset(&g_curCliConfig, 0, sizeof(WGSERVER_CONFIG)); + g_curCliConfig.ListenPort = g_UserSvrCfg.svrListenPort - 1; + StringCbCopy(g_curCliConfig.Name, 64, GetGlobalCfgInfo()->userCfg.userName); + StringCbCopy(g_curCliConfig.Address, 32, g_UserSvrCfg.svrAddress); + StringCbCopy(g_curCliConfig.PrivateKey, 64, g_UserSvrCfg.svrPrivateKey); + StringCbCopy(g_curCliConfig.CliPubKey, 64, reqData.msgContent.cliPublicKey.c_str()); + StringCbCopy(g_CliNetwork, MAX_IP_LEN, reqData.msgContent.cliNetwork.c_str()); + StringCbPrintf(g_curCliConfig.AllowNet, + 256, + TEXT("%s,%s"), + reqData.msgContent.cliNetwork.c_str(), + reqData.msgContent.cliTunnelAddr.c_str()); + + // 创建 WireGuard 配置文件 + ret = WireGuardCreateServerConfig(&g_curCliConfig); + if (ret != ERR_SUCCESS) { + // 返回写入 WireGuard 配置文件错误 + HttpResponseError(res, ret, TEXT("Create WireGuard service configure file error")); + SPDLOG_ERROR(TEXT("WireGuardCreateServerConfig error: {0}"), ret); + g_InterfaceMutex.unlock(); + return; + } + + // 返回当前隧道信息 + rsp.msgContent.errCode = ERR_SUCCESS; + rsp.msgContent.errMessage = TEXT("OK"); + rsp.msgContent.svrNetwork = g_UserSvrCfg.svrAddress; + + if (aigc::JsonHelper::ObjectToJson(rsp, json)) { + res.set_content(json, TEXT("application/json")); + } else { + SPDLOG_ERROR(TEXT("ProtocolResponse to json error")); + HttpResponseError(res, ERR_JSON_CREATE, TEXT("ProtocolResponse to json error")); + } + g_InterfaceMutex.unlock(); + } + }); + + SPDLOG_DEBUG(TEXT("Start HTTP Service at {0}"), pSvr->svrListenPort); + if (!g_httpServer.bind_to_port(TEXT("0.0.0.0"), pSvr->svrListenPort)) { + SPDLOG_ERROR(TEXT("Start HTTP Service at {0} error"), pSvr->svrListenPort); + return -ERR_SOCKET_BIND_PORT; + } + + g_ControlSvrThread = CreateThread( + nullptr, // Thread attributes + 0, // Stack size (0 = use default) + [](LPVOID lpParameter) { + if (!g_httpServer.listen_after_bind()) { + SPDLOG_ERROR(TEXT("Start HTTP Service at {0} error")); + } + + SPDLOG_DEBUG(TEXT("Http service exit.....")); + + return static_cast(0); + }, // Thread start address + nullptr, // Parameter to pass to the thread + 0, // Creation flags + nullptr); // Thread id + + if (g_ControlSvrThread == nullptr) { + // Thread creation failed. + // More details can be retrieved by calling GetLastError() + return -ERR_CREATE_THREAD; + } + + g_httpServer.wait_until_ready(); + + return ERR_SUCCESS; +} + +int StopControlService() { + if (g_httpServer.is_running()) { + g_httpServer.stop(); + } + + if (g_ControlSvrThread) { + // Wait for thread to finish execution + if (WaitForSingleObject(g_ControlSvrThread, 10 * 1000) == WAIT_TIMEOUT) { + SPDLOG_ERROR(TEXT("Waitting HTTP Service clost timeout")); + return -ERROR_TIMEOUT; + } + CloseHandle(g_ControlSvrThread); + g_ControlSvrThread = nullptr; + } + + return ERR_SUCCESS; +} \ No newline at end of file diff --git a/NetTunnelSDK/network/ProxyService.cpp b/NetTunnelSDK/network/ProxyService.cpp new file mode 100644 index 0000000..c81990d --- /dev/null +++ b/NetTunnelSDK/network/ProxyService.cpp @@ -0,0 +1,285 @@ +#include "pch.h" + +#include +#include +#include "usrerr.h" +#include "misc.h" + +#if !USED_PORTMAP_TUNNEL +#include "globalcfg.h" +#include +#include +#include + +#define SCG_UDP_HEAD_SIZE (11) + +#endif + + +void StopUDPProxyServer() { +#if !USED_PORTMAP_TUNNEL + const PSCG_PROXY_INFO pProxy = &GetGlobalCfgInfo()->scgProxy; + pProxy->exitNow = true; + + if (pProxy->hProxyTunnelThread) { + if (WaitForSingleObject(pProxy->hProxyTunnelThread, 10 * 1000) == WAIT_TIMEOUT) { + SPDLOG_ERROR(TEXT("Waitting HTTP Service clost timeout")); + } + + closesocket(pProxy->udpProxySock); + } + + if (pProxy->hProxySCGThread) { + if (WaitForSingleObject(pProxy->hProxySCGThread, 10 * 1000) == WAIT_TIMEOUT) { + SPDLOG_ERROR(TEXT("Waitting HTTP Service clost timeout")); + } + + closesocket(pProxy->scgGwSock); + } +#endif +} + +#if !USED_PORTMAP_TUNNEL +static DWORD UDPProxvRemoteThread(LPVOID lpParameter) { + const auto pPeerSock = static_cast(lpParameter); + const PSCG_PROXY_INFO pProxy = &GetGlobalCfgInfo()->scgProxy; + + while (pPeerSock && !pProxy->exitNow) { + sockaddr_in remoteWgAddr {}; + TCHAR ipAddr[MAX_IP_LEN]; + + int addrSize = sizeof(SOCKADDR); + char recvBuf[1500]; + int iRecvBytes; + + // 代理服务 In + iRecvBytes = recvfrom(pProxy->scgGwSock, + recvBuf, + 1500, + 0, + reinterpret_cast(&remoteWgAddr), + &addrSize); + + memset(ipAddr, 0, MAX_IP_LEN); + InetNtop(AF_INET, &remoteWgAddr.sin_addr.s_addr, ipAddr, MAX_IP_LEN); + SPDLOG_TRACE(TEXT(">>> Scoket In {1} Recv {0} bytes from {2}:{3}"), + iRecvBytes, + pProxy->scgGwSock, + ipAddr, + ntohs(remoteWgAddr.sin_port)); + + if (iRecvBytes != SOCKET_ERROR) { + int sendBytes = sendto(pProxy->udpProxySock, + recvBuf, + iRecvBytes, + 0, + reinterpret_cast(pPeerSock), + sizeof(SOCKADDR)); + memset(ipAddr, 0, MAX_IP_LEN); + InetNtop(AF_INET, &pPeerSock->sin_addr.s_addr, ipAddr, MAX_IP_LEN); + SPDLOG_TRACE(TEXT("<<< Scoket In Send {0} bytes to {2}:{3}"), + sendBytes, + pProxy->udpProxySock, + ipAddr, + ntohs(pPeerSock->sin_port)); + } else { + SPDLOG_ERROR(TEXT(">>> Scoket In {1} Recv {0} bytes from {2}:{3} error: {4}"), + iRecvBytes, + pProxy->scgGwSock, + ipAddr, + ntohs(remoteWgAddr.sin_port), + WSAGetLastError()); + } + + Sleep(100); + } + return 0; +} + +static DWORD UDPProxyRecvThread(LPVOID lpParameter) { + bool isRemoteInit = false; + sockaddr_in localWgAddr {}; + sockaddr_in scgAddr {}; + unsigned char recvBuf[1500 + SCG_UDP_HEAD_SIZE]; + std::array arr; + const PSCG_PROXY_INFO pProxy = &GetGlobalCfgInfo()->scgProxy; + const auto svrId = static_cast(GetGlobalCfgInfo()->userCfg.cliConfig.scgTunnelAppId); + char *pRecBuf = reinterpret_cast(&recvBuf[SCG_UDP_HEAD_SIZE]); + + scgAddr.sin_family = AF_INET; + scgAddr.sin_port = htons(pProxy->scgGwPort); + InetPton(AF_INET, pProxy->scgIpAddr, &scgAddr.sin_addr.s_addr); + + // 构建 SCG UDP 包头 + recvBuf[0] = 0x01; // VERSION + recvBuf[1] = 0x09; // Length + recvBuf[2] = 0xF0; // ++++++ INFO[0] TYPE + recvBuf[3] = 0x04; // INFO[0] LENGTH + recvBuf[4] = 0; // INFO[0] VMID[0] + recvBuf[5] = 0; // INFO[0] VMID[1] + recvBuf[6] = 0; // INFO[0] VMID[2] + recvBuf[7] = 0; // INFO[0] VMID[3] + recvBuf[8] = 0xF1; // INFO[1] TYPE + recvBuf[9] = 0x01; // INFO[1] LENGTH + recvBuf[10] = svrId; // ------ INFO[1] SCG Service ID + + pProxy->exitNow = false; + while (!pProxy->exitNow) { + TCHAR ipAddr[MAX_IP_LEN]; + int addrSize = sizeof(SOCKADDR); + int iRecvBytes; + + // 代理服务 Out + iRecvBytes = recvfrom(pProxy->udpProxySock, + pRecBuf, + 1500, + 0, + reinterpret_cast(&localWgAddr), + &addrSize); + + InetNtop(AF_INET, &localWgAddr.sin_addr.s_addr, ipAddr, MAX_IP_LEN); + SPDLOG_TRACE(TEXT(">>> Scoket Out {1} Recv {0} bytes from {2}:{3}"), + iRecvBytes, + pProxy->udpProxySock, + ipAddr, + ntohs(localWgAddr.sin_port)); + + if (iRecvBytes >= (1450 - SCG_UDP_HEAD_SIZE)) { + SPDLOG_WARN(TEXT("!Maybe MTU overflow: Current package {0} bytes, UDP MTU 1450, SCG Head Used {1} bytes"), + iRecvBytes, + SCG_UDP_HEAD_SIZE); + } + + if (iRecvBytes != SOCKET_ERROR) { + int sendBytes; + const unsigned int id = htonl(GetGlobalCfgInfo()->curConnVmId); + unsigned char vmid[4]; + memcpy(vmid, &id, 4); + + if (!isRemoteInit) { + HANDLE handle; + isRemoteInit = true; + // 创建远端接收线程 + handle = CreateThread(nullptr, // Thread attributes + 0, // Stack size (0 = use default) + UDPProxvRemoteThread, // Thread start address + &localWgAddr, // Parameter to pass to the thread + 0, // Creation flags + nullptr); // Thread id + + if (handle == nullptr) { + SPDLOG_ERROR("Create Thread failed with error = {0}", GetLastError()); + closesocket(pProxy->udpProxySock); + closesocket(pProxy->scgGwSock); + pProxy->exitNow = true; + return -ERR_CREATE_THREAD; + } + + pProxy->hProxySCGThread = handle; + } + + recvBuf[4] = vmid[0]; // INFO[0] VMID[0] + recvBuf[5] = vmid[1]; // INFO[0] VMID[1] + recvBuf[6] = vmid[2]; // INFO[0] VMID[2] + recvBuf[7] = vmid[3]; // INFO[0] VMID[3] + + // 增加SCG包头数据长度 + iRecvBytes += 11; + + if (GetGlobalCfgInfo()->logLevel == spdlog::level::trace) { + const auto start = std::begin(recvBuf); + std::copy_n(start, iRecvBytes, std::begin(arr)); + SPDLOG_TRACE(TEXT("UDP Proxy SCG({1}/0x{2:X}) Payload: {0:Xa}"), + spdlog::to_hex(start, start + iRecvBytes, 16), + svrId, + id); + } + + sendBytes = sendto(pProxy->scgGwSock, + reinterpret_cast(recvBuf), + iRecvBytes, + 0, + reinterpret_cast(&scgAddr), + sizeof(SOCKADDR)); + memset(ipAddr, 0, MAX_IP_LEN); + InetNtop(AF_INET, &scgAddr.sin_addr.s_addr, ipAddr, MAX_IP_LEN); + SPDLOG_TRACE(TEXT("<<< Scoket Out Send {0} bytes to {2}:{3}"), + sendBytes, + pProxy->scgGwSock, + ipAddr, + ntohs(scgAddr.sin_port)); + } + + Sleep(100); + } + return 0; +} +#endif +int CreateUDPProxyServer() { +#if !USED_PORTMAP_TUNNEL + HANDLE handle; + int ret; + int addrSize = sizeof(sockaddr_in); + sockaddr_in server {}; + sockaddr_in bindAddr {}; + SOCKET sock; + const PSCG_PROXY_INFO pProxy = &GetGlobalCfgInfo()->scgProxy; + + // 创建本地 SOCKET 代理服务器 + sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + if (sock == INVALID_SOCKET) { + SPDLOG_ERROR("Create UDP Socket failed with error = {0}", WSAGetLastError()); + return -ERR_SOCKET_CREATE; + } + + server.sin_family = AF_INET; + server.sin_addr.s_addr = htonl(INADDR_ANY); + server.sin_port = htons(0); + + if (bind(sock, reinterpret_cast(&server), sizeof(SOCKADDR)) == SOCKET_ERROR) { + closesocket(sock); + SPDLOG_ERROR("Bind local UDP Socket failed with error = {0}", WSAGetLastError()); + return -ERR_SOCKET_BIND; + } + + if ((ret = getsockname(sock, reinterpret_cast(&bindAddr), &addrSize)) != 0) { + closesocket(sock); + SPDLOG_ERROR("Get UDP Socket bind port failed with error = {0}, {1}", WSAGetLastError(), ret); + return -ERR_SOCKET_BIND; + } + + // 保存 UDP 代理服务器信息 + pProxy->udpProxySock = sock; + pProxy->proxyPort = ntohs(bindAddr.sin_port); + + SPDLOG_DEBUG(TEXT("Proxy Server socket {0} bind {1} prot"), sock, pProxy->proxyPort); + + // 创建SCG SOCKET 连接客户端服务 + sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + if (sock == INVALID_SOCKET) { + SPDLOG_ERROR("Create UDP Socket failed with error = {0}", WSAGetLastError()); + closesocket(pProxy->udpProxySock); + return -ERR_SOCKET_CREATE; + } + + pProxy->scgGwSock = sock; + + // 创建代理服务发送线程 + handle = CreateThread(nullptr, // Thread attributes + 0, // Stack size (0 = use default) + UDPProxyRecvThread, // Thread start address + nullptr, // Parameter to pass to the thread + 0, // Creation flags + nullptr); // Thread id + + if (handle == nullptr) { + SPDLOG_ERROR("Create Thread failed with error = {0}", GetLastError()); + closesocket(pProxy->udpProxySock); + closesocket(pProxy->scgGwSock); + return -ERR_CREATE_THREAD; + } + + pProxy->hProxyTunnelThread = handle; +#endif + return ERR_SUCCESS; +} \ No newline at end of file diff --git a/NetTunnelSDK/network/network.cpp b/NetTunnelSDK/network/network.cpp new file mode 100644 index 0000000..8dd9254 --- /dev/null +++ b/NetTunnelSDK/network/network.cpp @@ -0,0 +1,1723 @@ +#include "pch.h" + +#include "usrerr.h" +#include +#include +#include +#include +#include +#include +#include + +#include "globalcfg.h" +#include "misc.h" +#include "network.h" + +#include +#include + +#pragma comment(lib, "Iphlpapi.lib") +#pragma comment(lib, "Ws2_32.lib") + +static NIC_CONTENT g_NetAdapterInfo[NET_CARD_MAX]; + +int GetInterfaceIfIndexByIpAddr(const TCHAR *pIpAddr, ULONG *pIfIndex) { + PIP_ADAPTER_INFO pAdapterInfo; + DWORD dwRetVal; + ULONG ulOutBufLen; + + if (pIpAddr == nullptr || lstrlen(pIpAddr) == 0) { + SPDLOG_ERROR(TEXT("Input pIpAddr error: {0}"), pIpAddr); + return -ERR_INPUT_PARAMS; + } + + if (pIfIndex == nullptr) { + SPDLOG_ERROR(TEXT("Input pIfIndex params error")); + return -ERR_INPUT_PARAMS; + } + + ulOutBufLen = sizeof(IP_ADAPTER_INFO); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, sizeof(IP_ADAPTER_INFO))); + + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo")); + return -ERR_MALLOC_MEMORY; + } + + if (GetAdaptersInfo(pAdapterInfo, &ulOutBufLen) == ERROR_BUFFER_OVERFLOW) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, ulOutBufLen)); + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo\n")); + return -ERR_MALLOC_MEMORY; + } + } + + *pIfIndex = -1; + + if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) { + PIP_ADAPTER_INFO pAdapter = pAdapterInfo; + while (pAdapter) { + PIP_ADDR_STRING ipAddressListPointer = &(pAdapter->IpAddressList); + + while (ipAddressListPointer != nullptr) { + if (StrCmp((ipAddressListPointer->IpAddress).String, pIpAddr) == 0) { + *pIfIndex = pAdapter->Index; + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return ERR_SUCCESS; + } else { + ipAddressListPointer = ipAddressListPointer->Next; + } + } + pAdapter = pAdapter->Next; + } + } else { + SPDLOG_ERROR(TEXT("GetAdaptersInfo failed with error: {0}\n"), dwRetVal); + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return -ERR_SYS_CALL; + } + + if (pAdapterInfo) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + } + + return -ERR_ITEM_UNEXISTS; +} + +int GetInterfaceIfIndexByGUID(const TCHAR *pGUID, int *pIfIndex) { + PIP_ADAPTER_INFO pAdapterInfo; + DWORD dwRetVal; + + if (pGUID == nullptr || lstrlen(pGUID) == 0) { + SPDLOG_ERROR(TEXT("Input pGUID error: {0}"), pGUID); + return -ERR_INPUT_PARAMS; + } + + if (pIfIndex == nullptr) { + SPDLOG_ERROR(TEXT("Input pIfIndex params error")); + return -ERR_INPUT_PARAMS; + } + + ULONG ulOutBufLen = sizeof(IP_ADAPTER_INFO); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, sizeof(IP_ADAPTER_INFO))); + + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo")); + return -ERR_MALLOC_MEMORY; + } + + if (GetAdaptersInfo(pAdapterInfo, &ulOutBufLen) == ERROR_BUFFER_OVERFLOW) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, ulOutBufLen)); + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo\n")); + return -ERR_MALLOC_MEMORY; + } + } + + *pIfIndex = -1; + + if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) { + PIP_ADAPTER_INFO pAdapter = pAdapterInfo; + while (pAdapter) { + if (StrCmp(pAdapter->AdapterName, pGUID) == 0) { + *pIfIndex = static_cast(pAdapter->Index); + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return ERR_SUCCESS; + } + pAdapter = pAdapter->Next; + } + } else { + SPDLOG_ERROR(TEXT("GetAdaptersInfo failed with error: {0}\n"), dwRetVal); + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return -ERR_SYS_CALL; + } + + if (pAdapterInfo) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + } + + return -ERR_ITEM_UNEXISTS; +} + +int GetInterfaceNameByGUID(const TCHAR *pGUID, TCHAR ifName[MAX_NETCARD_NAME], int *pConnStatus) { + VARIANT v; + INetConnection *pNC = nullptr; + IEnumVARIANT *pEV = nullptr; + IUnknown *pUnk = nullptr; + INetSharingEveryConnectionCollection *pNSECC = nullptr; + INetSharingManager *pNSM; + HRESULT hr; + + if (pGUID == nullptr || lstrlen(pGUID) == 0) { + SPDLOG_ERROR(TEXT("Input pGUID params error: {0}"), pGUID); + return -ERR_INPUT_PARAMS; + } + + hr = ::CoCreateInstance(CLSID_NetSharingManager, + nullptr, + CLSCTX_ALL, + IID_INetSharingManager, + reinterpret_cast(&pNSM)); + + if (hr != S_OK || pNSM == nullptr) { + CoInitialize(nullptr); + CoInitializeSecurity(nullptr, + -1, + nullptr, + nullptr, + RPC_C_AUTHN_LEVEL_PKT, + RPC_C_IMP_LEVEL_IMPERSONATE, + nullptr, + EOAC_NONE, + nullptr); + + hr = ::CoCreateInstance(CLSID_NetSharingManager, + nullptr, + CLSCTX_ALL, + IID_INetSharingManager, + reinterpret_cast(&pNSM)); + } + + if (hr != S_OK || pNSM == nullptr) { + SPDLOG_ERROR(TEXT("CoCreateInstance NetSharingManager failed: {0}"), hr); + return -ERR_CREATE_COMMOBJECT; + } + + VariantInit(&v); + hr = pNSM->get_EnumEveryConnection(&pNSECC); + + if (hr != S_OK || !pNSECC) { + SPDLOG_ERROR(TEXT("INetSharingManager get_EnumEveryConnection failed: {0}."), hr); + return -ERR_SYS_CALL; + } + + hr = pNSECC->get__NewEnum(&pUnk); + pNSECC->Release(); + + if (hr != S_OK || !pUnk) { + SPDLOG_ERROR(TEXT("INetSharingManager get_EnumEveryConnection failed: {0}."), hr); + return -ERR_SYS_CALL; + } + + hr = pUnk->QueryInterface(IID_IEnumVARIANT, reinterpret_cast(&pEV)); + pUnk->Release(); + if (hr != S_OK || !pUnk) { + SPDLOG_ERROR(TEXT("INetSharingManager get_EnumEveryConnection failed: {0}."), hr); + return -ERR_SYS_CALL; + } + + while (S_OK == pEV->Next(1, &v, nullptr)) { + if (V_VT(&v) == VT_UNKNOWN) { + V_UNKNOWN(&v)->QueryInterface(IID_INetConnection, reinterpret_cast(&pNC)); + if (pNC) { + int ret; + TCHAR strGuid[MAX_PATH] = {}; + NETCON_PROPERTIES *pNP; + pNC->GetProperties(&pNP); + + StringCbPrintf(strGuid, + MAX_PATH, + TEXT("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}"), + pNP->guidId.Data1, + pNP->guidId.Data2, + pNP->guidId.Data3, + pNP->guidId.Data4[0], + pNP->guidId.Data4[1], + pNP->guidId.Data4[2], + pNP->guidId.Data4[3], + pNP->guidId.Data4[4], + pNP->guidId.Data4[5], + pNP->guidId.Data4[6], + pNP->guidId.Data4[7]); + + // ҵӦд + if (StrCmp(pGUID, strGuid) != 0) { + continue; + } + + memset(ifName, 0, MAX_NETCARD_NAME); + ret = WideCharToTChar(pNP->pszwName, ifName, MAX_NETCARD_NAME); + + // ִַת + if (ret != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Convert Unicode wide char to TCHAR failed: {0}."), ret); + return ret; + } + + if (pConnStatus) { + *pConnStatus = pNP->Status; + } + return ERR_SUCCESS; + } + } + } + + return ERR_ITEM_UNEXISTS; +} + +int GetInternetIfIndex(int *pIfIndex) { + PIP_ADAPTER_INFO pAdapterInfo; + DWORD dwRetVal; + ULONG ulOutBufLen; + + if (pIfIndex == nullptr) { + SPDLOG_ERROR(TEXT("Input pIfIndex params error")); + return -ERR_INPUT_PARAMS; + } + + ulOutBufLen = sizeof(IP_ADAPTER_INFO); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, sizeof(IP_ADAPTER_INFO))); + + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo")); + return -ERR_MALLOC_MEMORY; + } + + if (GetAdaptersInfo(pAdapterInfo, &ulOutBufLen) == ERROR_BUFFER_OVERFLOW) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, ulOutBufLen)); + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo\n")); + return -ERR_MALLOC_MEMORY; + } + } + + *pIfIndex = -1; + + if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) { + PIP_ADAPTER_INFO pAdapter = pAdapterInfo; + while (pAdapter) { + bool bIsInternel; + int index = static_cast(pAdapter->Index); + + int ret = IsInternetConnectAdapter(index, &bIsInternel); + + if (ret != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("IsInternetConnectAdapter {0} : {1}\n"), index, ret); + } + + if (ret == ERR_SUCCESS && bIsInternel) { + *pIfIndex = index; + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return ERR_SUCCESS; + } + + pAdapter = pAdapter->Next; + } + } else { + SPDLOG_ERROR(TEXT("GetAdaptersInfo failed with error: {0}\n"), dwRetVal); + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return -ERR_SYS_CALL; + } + + if (pAdapterInfo) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + } + + return -ERR_ITEM_UNEXISTS; +} + +int GetInterfaceGUIDByIfIndex(const int ifIndex, GUID *pGuid) { + PIP_ADAPTER_INFO pAdapterInfo; + DWORD dwRetVal; + ULONG ulOutBufLen; + if (pGuid == nullptr) { + SPDLOG_ERROR(TEXT("Input pGuid error.")); + return -ERR_INPUT_PARAMS; + } + + ulOutBufLen = sizeof(IP_ADAPTER_INFO); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, sizeof(IP_ADAPTER_INFO))); + + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo")); + return -ERR_MALLOC_MEMORY; + } + + if (GetAdaptersInfo(pAdapterInfo, &ulOutBufLen) == ERROR_BUFFER_OVERFLOW) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, ulOutBufLen)); + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo")); + return -ERR_MALLOC_MEMORY; + } + } + + if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) { + PIP_ADAPTER_INFO pAdapter = pAdapterInfo; + while (pAdapter) { + if (ifIndex == static_cast(pAdapter->Index)) { + int ret; + WCHAR strGuid[MAX_PATH]; + + if ((ret = TCharToWideChar(pAdapter->AdapterName, strGuid, MAX_PATH)) != ERR_SUCCESS) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return ret; + } + + if (CLSIDFromString(strGuid, pGuid) != NOERROR) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return -ERR_MEMORY_STR; + } + + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return ERR_SUCCESS; + } + pAdapter = pAdapter->Next; + } + } else { + SPDLOG_ERROR(TEXT("GetAdaptersInfo failed with error: {0}"), dwRetVal); + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return -ERR_SYS_CALL; + } + + if (pAdapterInfo) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + } + + return -ERR_ITEM_UNEXISTS; +} + +/** + * @brief ȡ + * @param[in] pInterfaceName + * @param[out] pIfIndex Index + * @return ִн 0: ɹ С0 ʧ @see USER_ERRNO + * - -ERR_INPUT_PARAMS + * - -ERR_ITEM_UNEXISTS + * - -ERR_SYS_CALL ȡϵͳʧ + * - -ERR_MALLOC_MEMORY ڴʧ + * - ERR_SUCCESS ɹ + */ +int GetInterfaceIfIndexByName(const TCHAR *pInterfaceName, int *pIfIndex) { + PIP_ADAPTER_INFO pAdapterInfo; + DWORD dwRetVal; + ULONG ulOutBufLen; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR(TEXT("Input pInterfaceName error: {0}"), pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pIfIndex == nullptr) { + SPDLOG_ERROR(TEXT("Input pIfIndex params error.")); + return -ERR_INPUT_PARAMS; + } + + ulOutBufLen = sizeof(IP_ADAPTER_INFO); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, sizeof(IP_ADAPTER_INFO))); + + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo")); + return -ERR_MALLOC_MEMORY; + } + + if (GetAdaptersInfo(pAdapterInfo, &ulOutBufLen) == ERROR_BUFFER_OVERFLOW) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, ulOutBufLen)); + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo\n")); + return -ERR_MALLOC_MEMORY; + } + } + + *pIfIndex = -1; + + if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) { + PIP_ADAPTER_INFO pAdapter = pAdapterInfo; + while (pAdapter) { + TCHAR NetCardName[MAX_NETCARD_NAME] = {}; + GetInterfaceNameByGUID(pAdapter->AdapterName, NetCardName, nullptr); + + if (StrCmp(pInterfaceName, NetCardName) == 0) { + *pIfIndex = static_cast(pAdapter->Index); + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return ERR_SUCCESS; + } + + pAdapter = pAdapter->Next; + } + } else { + SPDLOG_ERROR(TEXT("GetAdaptersInfo failed with error: {0}\n"), dwRetVal); + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return -ERR_SYS_CALL; + } + + if (pAdapterInfo) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + } + + return -ERR_ITEM_UNEXISTS; +} + +int GetInterfaceGUIDByName(const TCHAR *pInterfaceName, GUID *pGuid) { + PIP_ADAPTER_INFO pAdapterInfo; + DWORD dwRetVal; + ULONG ulOutBufLen; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR(TEXT("Input pInterfaceName error: {0}"), pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pGuid == nullptr) { + SPDLOG_ERROR(TEXT("Input pGuid params error")); + return -ERR_INPUT_PARAMS; + } + + ulOutBufLen = sizeof(IP_ADAPTER_INFO); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, sizeof(IP_ADAPTER_INFO))); + + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo")); + return -ERR_MALLOC_MEMORY; + } + + if (GetAdaptersInfo(pAdapterInfo, &ulOutBufLen) == ERROR_BUFFER_OVERFLOW) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, ulOutBufLen)); + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo\n")); + return -ERR_MALLOC_MEMORY; + } + } + + if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) { + PIP_ADAPTER_INFO pAdapter = pAdapterInfo; + while (pAdapter) { + int ret; + TCHAR NetCardName[MAX_NETCARD_NAME] = {}; + GetInterfaceNameByGUID(pAdapter->AdapterName, NetCardName, nullptr); + + if (StrCmp(pInterfaceName, NetCardName) == 0) { + WCHAR strGuid[MAX_PATH]; + + if ((ret = TCharToWideChar(pAdapter->AdapterName, strGuid, MAX_PATH)) != ERR_SUCCESS) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return ret; + } + + if (CLSIDFromString(strGuid, pGuid) != NOERROR) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return -ERR_MEMORY_STR; + } + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return ERR_SUCCESS; + } + + pAdapter = pAdapter->Next; + } + } else { + SPDLOG_ERROR(TEXT("GetAdaptersInfo failed with error: {0}\n"), dwRetVal); + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return -ERR_SYS_CALL; + } + + if (pAdapterInfo) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + } + + return -ERR_ITEM_UNEXISTS; +} + +int WaitNetAdapterConnected(const TCHAR *pInterfaceName, int timeOutOfMs) { + INetworkListManager *pNLM; + IEnumNetworkConnections *pEnumConns; + INetwork *pINet; + INetworkConnection *pIConn; + HRESULT hr; + GUID guid; + const DWORD startTime = timeGetTime(); + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR(TEXT("Input pInterfaceName params error: {0}"), pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + hr = ::CoCreateInstance(CLSID_NetworkListManager, + nullptr, + CLSCTX_ALL, + IID_INetworkListManager, + reinterpret_cast(&pNLM)); + + if (hr != S_OK || pNLM == nullptr) { + SPDLOG_ERROR(TEXT("CoCreateInstance NetworkListManager failed: {0}."), hr); + return -ERR_CREATE_COMMOBJECT; + } + + do { + + hr = pNLM->GetNetworkConnections(&pEnumConns); + + if (hr != S_OK || pEnumConns == nullptr) { + SPDLOG_ERROR(TEXT("NetworkListManager GetNetworks failed: {0}."), hr); + continue; + } + + while (S_OK == pEnumConns->Next(1, &pIConn, nullptr)) { + GUID adpterGuid; + + pIConn->GetAdapterId(&adpterGuid); + pIConn->GetNetwork(&pINet); + + if (pINet) { + BSTR sName = {}; + TCHAR ifName[MAX_PATH]; + + pINet->GetName(&sName); + + if (WideCharToTChar(sName, ifName, MAX_PATH) != ERR_SUCCESS) { + SysFreeString(sName); + return -ERR_MEMORY_STR; + } + + SysFreeString(sName); + + if (StrNCmp(pInterfaceName, ifName, lstrlen(pInterfaceName)) == 0) { + + int ret = GetInterfaceGUIDByName(pInterfaceName, &guid); + if (ret != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Get Interface {0} GUID error: {1}"), pInterfaceName, ret); + continue; + } + + /*SPDLOG_DEBUG(TEXT("Match Interface {0} --> {1}, Guid {2:x} --> {3:x}"), + ifName, + pInterfaceName, + adpterGuid.Data1, + guid.Data1);*/ + + if (memcmp(&adpterGuid, &guid, sizeof(GUID)) == 0) { + SPDLOG_DEBUG(TEXT("Interface {0}({1}) network connected now..."), ifName, pInterfaceName); + return ERR_SUCCESS; + } + } + } + } + + Sleep(1000); + } while (timeGetTime() - startTime <= static_cast(timeOutOfMs)); + + return -ERR_SYS_TIMEOUT; +} + +int GetAllNICInfo(PNIC_CONTENT *pInfo, int *pItemCounts) { + PNIC_CONTENT pNic; + PIP_ADAPTER_INFO pAdapterInfo; + DWORD dwRetVal; + + if (pItemCounts == nullptr) { + SPDLOG_ERROR(TEXT("Input pItemCounts params error")); + return -ERR_INPUT_PARAMS; + } + + if (pInfo == nullptr) { + SPDLOG_ERROR(TEXT("Input pInfo params error")); + return -ERR_INPUT_PARAMS; + } + + ULONG ulOutBufLen = sizeof(IP_ADAPTER_INFO); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, sizeof(IP_ADAPTER_INFO))); + + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo")); + return -ERR_MALLOC_MEMORY; + } + + if (GetAdaptersInfo(pAdapterInfo, &ulOutBufLen) == ERROR_BUFFER_OVERFLOW) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + pAdapterInfo = static_cast(HeapAlloc(GetProcessHeap(), 0, ulOutBufLen)); + if (pAdapterInfo == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory needed to call GetAdaptersinfo\n")); + return -ERR_MALLOC_MEMORY; + } + } + + if ((dwRetVal = GetAdaptersInfo(pAdapterInfo, &ulOutBufLen)) == NO_ERROR) { + int id = 0; + int ncStatus; + PIP_ADAPTER_INFO pAdapter = pAdapterInfo; + while (pAdapter && id < NET_CARD_MAX) { + // + g_NetAdapterInfo[id].InterfaceIndex = static_cast(pAdapter->Index); + // + StringCbCopy(g_NetAdapterInfo[id].NetCardUUID, MAX_ADAPTER_NAME_LENGTH, pAdapter->AdapterName); + // ϸ + StringCbCopy(g_NetAdapterInfo[id].NetCardDescription, + MAX_ADAPTER_DESCRIPTION_LENGTH, + pAdapter->Description); + // IP ַ + StringCbCopy(g_NetAdapterInfo[id].NetCardIpaddr, MAX_IP_LEN - 1, pAdapter->IpAddressList.IpAddress.String); + // + StringCbCopy(g_NetAdapterInfo[id].NetCardNetmask, MAX_IP_LEN - 1, pAdapter->IpAddressList.IpMask.String); + // صַ + StringCbCopy(g_NetAdapterInfo[id].NetCardGateway, MAX_IP_LEN - 1, pAdapter->GatewayList.IpAddress.String); + // MAC ַ + StringCbPrintf(g_NetAdapterInfo[id].NetCardMacAddr, + 20 - 1, + TEXT("%02X:%02X:%02X:%02X:%02X:%02X"), + pAdapter->Address[0], + pAdapter->Address[1], + pAdapter->Address[2], + pAdapter->Address[3], + pAdapter->Address[4], + pAdapter->Address[5]); + + if (GetInterfaceNameByGUID(pAdapter->AdapterName, g_NetAdapterInfo[id].NetCardName, &ncStatus) == + ERR_SUCCESS) { + g_NetAdapterInfo[id].netConnStatus = static_cast(ncStatus); + } + + id++; + pAdapter = pAdapter->Next; + } + + *pItemCounts = id; + } else { + SPDLOG_ERROR(TEXT("GetAdaptersInfo failed with error: {0}\n"), dwRetVal); + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + return -ERR_SYS_CALL; + } + + if (pAdapterInfo) { + HeapFree(GetProcessHeap(), 0, pAdapterInfo); + } + + pNic = static_cast(CoTaskMemAlloc(sizeof(NIC_CONTENT) * (*pItemCounts))); + + if (pNic == nullptr) { + *pItemCounts = 0; + SPDLOG_ERROR(TEXT("Error allocating memory {0} bytes"), sizeof(NIC_CONTENT) * (*pItemCounts)); + return -ERR_MALLOC_MEMORY; + } + + memset(pNic, 0, sizeof(NIC_CONTENT) * (*pItemCounts)); + memcpy(pNic, g_NetAdapterInfo, sizeof(NIC_CONTENT) * *pItemCounts); + *pInfo = pNic; + + return ERR_SUCCESS; +} + +int IsInternetConnectAdapter(int ifIndex, bool *pRet) { + DWORD dwSize = 0; + DWORD dwRetVal; + PMIB_IPFORWARDTABLE pIpForwardTable; + + if (ifIndex < 0 || ifIndex > 255) { + SPDLOG_ERROR(TEXT("Input ifIndex params error: {0}"), ifIndex); + return -ERR_INPUT_PARAMS; + } + + if (pRet == nullptr) { + SPDLOG_ERROR(TEXT("Input pRet params error")); + return -ERR_INPUT_PARAMS; + } + + pIpForwardTable = static_cast(HeapAlloc(GetProcessHeap(), 0, sizeof(MIB_IPFORWARDTABLE))); + + if (pIpForwardTable == nullptr) { + SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), sizeof(MIB_IPFORWARDTABLE)); + return -ERR_MALLOC_MEMORY; + } + + if (GetIpForwardTable(pIpForwardTable, &dwSize, 0) == ERROR_INSUFFICIENT_BUFFER) { + HeapFree(GetProcessHeap(), 0, pIpForwardTable); + pIpForwardTable = static_cast(HeapAlloc(GetProcessHeap(), 0, dwSize)); + if (pIpForwardTable == nullptr) { + SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), dwSize); + return -ERR_MALLOC_MEMORY; + } + } + + *pRet = false; + + if ((dwRetVal = GetIpForwardTable(pIpForwardTable, &dwSize, 0)) == NO_ERROR) { + for (DWORD i = 0; i < pIpForwardTable->dwNumEntries; i++) { + TCHAR ipStr[24] = {}; + TCHAR maskStr[24] = {}; + + if (static_cast(pIpForwardTable->table[i].dwForwardIfIndex) != ifIndex) { + continue; + } + + if (InetNtop(AF_INET, &pIpForwardTable->table[i].dwForwardDest, ipStr, 24) == nullptr) { + continue; + } + + if (InetNtop(AF_INET, &pIpForwardTable->table[i].dwForwardMask, maskStr, 24) == nullptr) { + continue; + } + + if (StrCmp(ipStr, TEXT("0.0.0.0")) == 0 && StrCmp(maskStr, TEXT("0.0.0.0")) == 0) { + *pRet = true; + break; + } + } + HeapFree(GetProcessHeap(), 0, pIpForwardTable); + return ERR_SUCCESS; + } else { + SPDLOG_ERROR(TEXT("GetIpForwardTable failed: {0}."), dwRetVal); + HeapFree(GetProcessHeap(), 0, pIpForwardTable); + return ERR_GET_IPFOWARDTBL; + } +} + +int SetNetConnectionNetworkCategory(const TCHAR *pInterfaceName, const bool isPrivate) { + INetworkListManager *pNLM; + IEnumNetworkConnections *pEnumConns; + INetwork *pINet; + INetworkConnection *pIConn; + HRESULT hr; + GUID guid; + int ret; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR(TEXT("Input pInterfaceName params error: {0}"), pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if ((ret = GetInterfaceGUIDByName(pInterfaceName, &guid)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Get NetCard [{0}] GUID error: {1}"), pInterfaceName, ret); + return ret; + } + + hr = ::CoCreateInstance(CLSID_NetworkListManager, + nullptr, + CLSCTX_ALL, + IID_INetworkListManager, + reinterpret_cast(&pNLM)); + + if (hr != S_OK || pNLM == nullptr) { + SPDLOG_ERROR(TEXT("CoCreateInstance NetworkListManager failed: {0}."), hr); + return -ERR_CREATE_COMMOBJECT; + } + + hr = pNLM->GetNetworkConnections(&pEnumConns); + + if (hr != S_OK || pEnumConns == nullptr) { + SPDLOG_ERROR(TEXT("NetworkListManager GetNetworks failed: {0}."), hr); + return -ERR_CREATE_COMMOBJECT; + } + + while (S_OK == pEnumConns->Next(1, &pIConn, nullptr)) { + GUID adpterGuid; + pIConn->GetNetwork(&pINet); + pIConn->GetAdapterId(&adpterGuid); + + if (pINet) { + if (memcmp(&adpterGuid, &guid, sizeof(GUID)) == 0) { + pINet->SetCategory(isPrivate ? NLM_NETWORK_CATEGORY_PRIVATE : NLM_NETWORK_CATEGORY_PUBLIC); + return ERR_SUCCESS; + } + } + } + + return -ERR_ITEM_UNEXISTS; +} + +int GetNetConnectionNetworkCategory(const TCHAR *pInterfaceName, bool *pIsPrivate) { + INetworkListManager *pNLM; + IEnumNetworkConnections *pEnumConns; + INetwork *pINet; + INetworkConnection *pIConn; + HRESULT hr; + GUID guid; + int ret; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR(TEXT("Input pInterfaceName params error: {0}"), pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pIsPrivate == nullptr) { + SPDLOG_ERROR(TEXT("Input pIsPrivate params error")); + return -ERR_INPUT_PARAMS; + } + + if ((ret = GetInterfaceGUIDByName(pInterfaceName, &guid)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Get NetCard [{0}] GUID error: {1}"), pInterfaceName, ret); + return ret; + } + + hr = ::CoCreateInstance(CLSID_NetworkListManager, + nullptr, + CLSCTX_ALL, + IID_INetworkListManager, + reinterpret_cast(&pNLM)); + + if (hr != S_OK || pNLM == nullptr) { + SPDLOG_ERROR(TEXT("CoCreateInstance NetworkListManager failed: {0}."), hr); + return -ERR_CREATE_COMMOBJECT; + } + + hr = pNLM->GetNetworkConnections(&pEnumConns); + + if (hr != S_OK || pEnumConns == nullptr) { + SPDLOG_ERROR(TEXT("NetworkListManager GetNetworks failed: {0}."), hr); + return -ERR_CREATE_COMMOBJECT; + } + + while (S_OK == pEnumConns->Next(1, &pIConn, nullptr)) { + GUID adpterGuid; + pIConn->GetNetwork(&pINet); + pIConn->GetAdapterId(&adpterGuid); + + if (pINet) { + if (memcmp(&adpterGuid, &guid, sizeof(GUID)) == 0) { + NLM_NETWORK_CATEGORY cat; + pINet->GetCategory(&cat); + *pIsPrivate = (cat == NLM_NETWORK_CATEGORY_PRIVATE) ? true : false; + return ERR_SUCCESS; + } + } + } + + return -ERR_ITEM_UNEXISTS; +} + +int SetNetIntelnetConnectionSharing(int ifIndex, bool isEnable, bool isSetPrivate) { + VARIANT v; + INetConnection *pNC = nullptr; + INetSharingConfiguration *pNSC = nullptr; + IEnumVARIANT *pEV = nullptr; + IUnknown *pUnk = nullptr; + INetSharingEveryConnectionCollection *pNSECC = nullptr; + INetSharingManager *pNSM; + HRESULT hr; + GUID ifGuid; + int ret; + + if ((ret = GetInterfaceGUIDByIfIndex(ifIndex, &ifGuid)) != ERR_SUCCESS) { + return ret; + } + + hr = ::CoCreateInstance(CLSID_NetSharingManager, + nullptr, + CLSCTX_ALL, + IID_INetSharingManager, + reinterpret_cast(&pNSM)); + + if (hr != S_OK || pNSM == nullptr) { + SPDLOG_ERROR(TEXT("CoCreateInstance NetSharingManager failed: {0}."), hr); + return -ERR_CREATE_COMMOBJECT; + } + + VariantInit(&v); + hr = pNSM->get_EnumEveryConnection(&pNSECC); + + if (hr != S_OK || !pNSECC) { + SPDLOG_ERROR(TEXT("INetSharingManager get_EnumEveryConnection failed: {0}."), hr); + return -ERR_SYS_CALL; + } + + hr = pNSECC->get__NewEnum(&pUnk); + pNSECC->Release(); + + if (hr != S_OK || !pUnk) { + SPDLOG_ERROR(TEXT("INetSharingManager get_EnumEveryConnection failed: {0}."), hr); + return -ERR_SYS_CALL; + } + + hr = pUnk->QueryInterface(IID_IEnumVARIANT, reinterpret_cast(&pEV)); + pUnk->Release(); + if (hr != S_OK || !pUnk) { + SPDLOG_ERROR(TEXT("INetSharingManager get_EnumEveryConnection failed: {0}."), hr); + return -ERR_SYS_CALL; + } + + while (S_OK == pEV->Next(1, &v, nullptr)) { + if (V_VT(&v) == VT_UNKNOWN) { + V_UNKNOWN(&v)->QueryInterface(IID_INetConnection, reinterpret_cast(&pNC)); + if (pNC) { + NETCON_PROPERTIES *pNP; + pNC->GetProperties(&pNP); + + // ҵӦд + if (memcmp(&ifGuid, &pNP->guidId, sizeof(GUID)) != 0) { + continue; + } + + // δӵ + if (pNP->Status != NCS_CONNECTED) { + return -ERR_NET_UNCONNECT; + } + + hr = pNSM->get_INetSharingConfigurationForINetConnection(pNC, &pNSC); + + if (hr != S_OK || !pNSC) { + continue; + } + + if (pNSC) { + if (isEnable) { + hr = pNSC->DisableSharing(); + + if (hr != S_OK) { + SPDLOG_ERROR(TEXT("INetSharingManager DisableSharing failed: {0}."), hr); + pNSC->Release(); + return -ERR_CALL_COMMOBJECT; + } + Sleep(500); + hr = pNSC->EnableSharing(isSetPrivate ? ICSSHARINGTYPE_PRIVATE : ICSSHARINGTYPE_PUBLIC); + + if (hr != S_OK) { + SPDLOG_ERROR(TEXT("INetSharingManager EnableSharing failed: {0}."), hr); + pNSC->Release(); + return -ERR_CALL_COMMOBJECT; + } + } else { + hr = pNSC->DisableSharing(); + + if (hr != S_OK) { + SPDLOG_ERROR(TEXT("INetSharingManager DisableSharing failed: {0}."), hr); + pNSC->Release(); + return -ERR_CALL_COMMOBJECT; + } + } + + pNSC->Release(); + return ERR_SUCCESS; + } + } + } + } + + return ERR_ITEM_UNEXISTS; +} + +int GetNetIntelnetConnectionSharing(int ifIndex, bool *pIsEnable) { + VARIANT v; + INetConnection *pNC = nullptr; + INetSharingConfiguration *pNSC = nullptr; + IEnumVARIANT *pEV = nullptr; + IUnknown *pUnk = nullptr; + INetSharingEveryConnectionCollection *pNSECC = nullptr; + INetSharingManager *pNSM; + HRESULT hr; + GUID ifGuid; + int ret; + + if (pIsEnable == nullptr) { + SPDLOG_ERROR(TEXT("Input pIsEnable params error")); + return -ERR_INPUT_PARAMS; + } + + // ȡGUID + if ((ret = GetInterfaceGUIDByIfIndex(ifIndex, &ifGuid)) != ERR_SUCCESS) { + return ret; + } + + hr = ::CoCreateInstance(CLSID_NetSharingManager, + nullptr, + CLSCTX_ALL, + IID_INetSharingManager, + reinterpret_cast(&pNSM)); + + if (hr != S_OK || pNSM == nullptr) { + SPDLOG_ERROR(TEXT("CoCreateInstance NetSharingManager failed: {0}."), hr); + return -ERR_CREATE_COMMOBJECT; + } + + VariantInit(&v); + hr = pNSM->get_EnumEveryConnection(&pNSECC); + + if (hr != S_OK || !pNSECC) { + SPDLOG_ERROR(TEXT("INetSharingManager get_EnumEveryConnection failed: {0}."), hr); + return -ERR_SYS_CALL; + } + + hr = pNSECC->get__NewEnum(&pUnk); + pNSECC->Release(); + + if (hr != S_OK || !pUnk) { + SPDLOG_ERROR(TEXT("INetSharingManager get_EnumEveryConnection failed: {0}."), hr); + return -ERR_SYS_CALL; + } + + hr = pUnk->QueryInterface(IID_IEnumVARIANT, reinterpret_cast(&pEV)); + pUnk->Release(); + if (hr != S_OK || !pUnk) { + SPDLOG_ERROR(TEXT("INetSharingManager get_EnumEveryConnection failed: {0}."), hr); + return -ERR_SYS_CALL; + } + + while (S_OK == pEV->Next(1, &v, nullptr)) { + if (V_VT(&v) == VT_UNKNOWN) { + V_UNKNOWN(&v)->QueryInterface(IID_INetConnection, reinterpret_cast(&pNC)); + if (pNC) { + NETCON_PROPERTIES *pNP; + pNC->GetProperties(&pNP); + + // ҵӦд + if (memcmp(&ifGuid, &pNP->guidId, sizeof(GUID)) != 0) { + continue; + } + + //if (StrCmp(pInterfaceName, ifName) != 0) { + // continue; + //} + + hr = pNSM->get_INetSharingConfigurationForINetConnection(pNC, &pNSC); + + if (hr != S_OK || !pNSC) { + continue; + } + + if (pNSC) { + VARIANT_BOOL bRet = false; + hr = pNSC->get_SharingEnabled(&bRet); + pNSC->Release(); + + if (hr != S_OK) { + SPDLOG_ERROR(TEXT("INetSharingManager DisableSharing failed: {0}."), hr); + return -ERR_CALL_COMMOBJECT; + } + + *pIsEnable = bRet; + + return ERR_SUCCESS; + } + } + } + } + + return ERR_ITEM_UNEXISTS; +} + +int AddRouteTable(const char *pIP, const char *pMask, const char *pGateway) { + PMIB_IPFORWARDTABLE pIpForwardTable = nullptr; + PMIB_IPFORWARDROW pRow = nullptr; + DWORD dwSize = 0; + DWORD dwStatus; + DWORD dwDestIp; + int ret; + IP_INFO ipInfo; + + if (pIP == nullptr || lstrlen(pIP) < MIN_IP_LEN) { + SPDLOG_ERROR(TEXT("Input pIP params error: {0}"), pIP); + return -ERR_INPUT_PARAMS; + } + + if (pMask == nullptr || lstrlen(pMask) < MIN_IP_LEN) { + SPDLOG_ERROR(TEXT("Input pMask params error: {0}"), pMask); + return -ERR_INPUT_PARAMS; + } + + if (pGateway == nullptr || lstrlen(pGateway) < MIN_IP_LEN) { + SPDLOG_ERROR(TEXT("Input pGateway params error: {0}"), pGateway); + return -ERR_INPUT_PARAMS; + } + + if ((ret = GetIpV4InfoFromNetmask(pIP, pMask, &ipInfo)) != ERR_SUCCESS) { + return ret; + } + + if (inet_pton(AF_INET, ipInfo.network, &dwDestIp) <= 0) { + SPDLOG_ERROR(TEXT("Convert {0} to network ipaddress error."), pIP); + return -ERR_UN_SUPPORT; + } + + // Find out how big our buffer needs to be. + dwStatus = GetIpForwardTable(pIpForwardTable, &dwSize, FALSE); + if (dwStatus == ERROR_INSUFFICIENT_BUFFER) { + // Allocate the memory for the table + pIpForwardTable = static_cast(malloc(dwSize)); + if (!pIpForwardTable) { + SPDLOG_ERROR(TEXT("Malloc failed. Out of memory.")); + return -ERR_MALLOC_MEMORY; + } + // Now get the table. + dwStatus = GetIpForwardTable(pIpForwardTable, &dwSize, FALSE); + } + + if (dwStatus != ERROR_SUCCESS) { + SPDLOG_ERROR(TEXT("getIpForwardTable failed.")); + if (pIpForwardTable) { + free(pIpForwardTable); + } + return -ERR_SYS_CALL; + } + + for (DWORD i = 0; i < pIpForwardTable->dwNumEntries; i++) { + if (pIpForwardTable->table[i].dwForwardDest == 0) { + if (!pRow) { + pRow = static_cast(malloc(sizeof(MIB_IPFORWARDROW))); + if (!pRow) { + SPDLOG_ERROR(TEXT("Malloc failed. Out of memory.")); + free(pIpForwardTable); + free(pRow); + return -ERR_MALLOC_MEMORY; + } + // Copy the row + memcpy(pRow, &(pIpForwardTable->table[i]), sizeof(MIB_IPFORWARDROW)); + } + } else if (pIpForwardTable->table[i].dwForwardDest == dwDestIp) { + // ɾܴڵľɵ·Ϣ + dwStatus = DeleteIpForwardEntry(&(pIpForwardTable->table[i])); + + if (dwStatus != ERROR_SUCCESS) { + SPDLOG_ERROR(TEXT("Could not delete old gateway")); + return -ERR_NET_REMOVE_ROUTE; + } + } + } + + free(pIpForwardTable); + + pRow->dwForwardDest = dwDestIp; + + if (inet_pton(AF_INET, ipInfo.netmask, &pRow->dwForwardMask) <= 0) { + SPDLOG_ERROR(TEXT("Convert {0} to network ipaddress error."), pMask); + return -ERR_UN_SUPPORT; + } + + if (inet_pton(AF_INET, pGateway, &pRow->dwForwardNextHop) <= 0) { + SPDLOG_ERROR(TEXT("Convert {0} to network ipaddress error."), pGateway); + free(pRow); + return -ERR_UN_SUPPORT; + } + + if ((ret = GetInterfaceIfIndexByIpAddr(pGateway, &pRow->dwForwardIfIndex)) != ERR_SUCCESS) { + free(pRow); + return ret; + } + + if ((dwStatus = CreateIpForwardEntry(pRow)) != NO_ERROR) { + SPDLOG_ERROR(TEXT("Add Route {1} netmask {2} gateway {3} error: {0}."), dwStatus, pIP, pMask, pGateway); + free(pRow); + return -ERR_NET_ADD_ROUTE; + } + + free(pRow); + + return ERR_SUCCESS; +} + +int SetNATRule(const TCHAR *pInterfaceName, const TCHAR *pCidrIpaddr) { + int ret; + TCHAR cmdBuf[1024]; + DWORD retCode; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR("Input pInterfaceName params error: {0}", pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pCidrIpaddr == nullptr || lstrlen(pCidrIpaddr) == 0) { + SPDLOG_ERROR("Input pCidrIpaddr params error: {0}", pCidrIpaddr); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbPrintf( + cmdBuf, + 1024, + TEXT("PowerShell -NoProfile -NonInteractive -WindowStyle Hidden -ExecutionPolicy Bypass Invoke-Command " + "-ScriptBlock { Get-NetNat -Name %s_nat -ErrorAction Ignore | Remove-NetNat -Confirm:$false; " + "New-NetNat -Name %s_nat -InternalIPInterfaceAddressPrefix %s }"), + pInterfaceName, + pInterfaceName, + pCidrIpaddr))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + + if ((ret = RunCommand(cmdBuf, nullptr, 0, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + SPDLOG_DEBUG("Run Command({1}): {0}", cmdBuf, retCode); + + if (retCode != 0) { + SPDLOG_ERROR("PowerShell return error({1}): {0}", cmdBuf, retCode); + return -ERR_PROCESS_RETURN; + } + + return ERR_SUCCESS; +} + +int RemoveNATRule(const TCHAR *pInterfaceName) { + int ret; + TCHAR cmdBuf[1024]; + DWORD retCode; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR("Input pInterfaceName params error: {0}", pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbPrintf( + cmdBuf, + 1024, + TEXT("PowerShell -NoProfile -NonInteractive -WindowStyle Hidden -ExecutionPolicy Bypass Invoke-Command " + "-ScriptBlock { Get-NetNat -Name %s_nat -ErrorAction Ignore | Remove-NetNat -Confirm:$false}"), + pInterfaceName))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + + if ((ret = RunCommand(cmdBuf, nullptr, 0, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + SPDLOG_DEBUG("Run Command({1}): {0}", cmdBuf, retCode); + + return ERR_SUCCESS; +} + +#if 0 +int SetInterfaceIpAddress(const TCHAR *pInterfaceName, const TCHAR *pIpaddr, const TCHAR *pNetmask) { + int ret; + TCHAR cmdBuf[MAX_PATH]; + //int cidr; + DWORD retCode; + IP_INFO ipInfo {}; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR("Input pInterfaceName params error: {0}", pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pIpaddr == nullptr || lstrlen(pIpaddr) == 0) { + SPDLOG_ERROR("Input pIpaddr params error: {0}", pIpaddr); + return -ERR_INPUT_PARAMS; + } + + if (pNetmask == nullptr || lstrlen(pNetmask) == 0) { + SPDLOG_WARN("Input pNetmask params error: {0}", pNetmask); + return -ERR_INPUT_PARAMS; + } + + GetIpV4InfoFromNetmask(pIpaddr, pNetmask, &ipInfo); + //cidr = NetmaskToCIDR(pNetmask); + + if (FAILED(StringCbPrintf( + cmdBuf, + MAX_PATH, + "PowerShell -Command \"& {New-NetIPAddress -InterfaceAlias %s -AddressFamily IPv4 -IPAddress %s -PrefixLength %d}\"", + pInterfaceName, + pIpaddr, + ipInfo.prefix))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + + if ((ret = RunCommand(cmdBuf, nullptr, 0, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + SPDLOG_DEBUG("Run Set IP Command({1}): {0}", cmdBuf, retCode); + + return ERR_SUCCESS; +} + +int GetWindowsHyperVStatus(int *pEnabled) { + int ret; + TCHAR cmdBuf[MAX_PATH]; + TCHAR cmdResult[2048]; + DWORD retCode; + + if (pEnabled == nullptr) { + SPDLOG_ERROR(TEXT("Input pEnabled params error")); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbPrintf(cmdBuf, + MAX_PATH, + TEXT("PowerShell -Command \"& {Get-WindowsOptionalFeature -FeatureName Microsoft-Hyper-V-All " + "-Online | Format-List -Property State}\"")))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + + if ((ret = RunCommand(cmdBuf, cmdResult, 2048, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + if (StrStr(cmdResult, TEXT("Enabled")) != nullptr) { + *pEnabled = TRUE; + } else { + *pEnabled = FALSE; + } + + SPDLOG_DEBUG(TEXT("Run Get Windows Hyper-V status Command({1}): {0} result: {2}\n"), cmdBuf, retCode, cmdResult); + + return ERR_SUCCESS; +} + +int EnableWindowsHyperV(bool enabled) { + int ret; + TCHAR cmdBuf[MAX_PATH]; + DWORD retCode; + + if (enabled) { + if (FAILED(StringCbPrintf(cmdBuf, + MAX_PATH, + TEXT("PowerShell -Command \"& {Enable-WindowsOptionalFeature -Online -FeatureName " + "Microsoft-Hyper-V -All}\"")))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + } else { + if (FAILED(StringCbPrintf(cmdBuf, + MAX_PATH, + TEXT("PowerShell -Command \"& {Disable-WindowsOptionalFeature -Online -FeatureName " + "Microsoft-Hyper-V-Hypervisor}\"")))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + } + + if ((ret = RunCommand(cmdBuf, nullptr, 0, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Run command [{0}] error: {1}"), cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + if (retCode != 0) { + SPDLOG_ERROR(TEXT("PowerShell return error({1}): {0}"), cmdBuf, retCode); + return -ERR_PROCESS_RETURN; + } + + return ERR_SUCCESS; +} + +int GetInterfaceIndexByName(const TCHAR *pInterfaceName, int *pIndex) { + int ret; + DWORD retCode; + TCHAR cmdBuf[MAX_PATH]; + TCHAR cmdResult[MAX_PATH] = {}; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR(TEXT("Input pInterfaceName params error: {0}"), pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pIndex == nullptr) { + SPDLOG_ERROR(TEXT("Input pIndex params error")); + return -ERR_INPUT_PARAMS; + } + + if (FAILED( + StringCbPrintf(cmdBuf, + MAX_PATH, + TEXT("PowerShell -Command \"& {Get-NetAdapter -Name %s | Format-List -Property InterfaceIndex}\""), + pInterfaceName))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + + if ((ret = RunCommand(cmdBuf, cmdResult, MAX_PATH, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Run command [{0}] error: {1}"), cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + SPDLOG_DEBUG(TEXT("Run command [{0}] resutl \'{1}\' return {2}"), cmdBuf, cmdResult, retCode); + + // Ƴ + StringRemoveAll(cmdResult, TEXT("\r\n")); + StringRemoveAll(cmdResult, TEXT(" ")); + StringRemoveAll(cmdResult, TEXT("InterfaceIndex:")); + + *pIndex = strtol(cmdResult, nullptr, 10); + + return ERR_SUCCESS; +} + +int RemoveInterfaceIpAddress(const TCHAR *pInterfaceName) { + int ret; + TCHAR cmdBuf[MAX_PATH]; + DWORD retCode; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR("Input pInterfaceName params error: {0}", pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbPrintf(cmdBuf, + MAX_PATH, + "PowerShell -Command \"& {Remove-NetIPAddress -InterfaceAlias %s -Confirm:$false}\"", + pInterfaceName))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + + if ((ret = RunCommand(cmdBuf, nullptr, 0, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + SPDLOG_DEBUG("Run Set IP Command({1}): {0}", cmdBuf, retCode); + + return ERR_SUCCESS; +} + +int SetInterfaceIpAddress(const TCHAR *pInterfaceName, const TCHAR *pIpaddr, const TCHAR *pNetmask) { + int ret; + TCHAR cmdBuf[MAX_PATH]; + int cidr; + DWORD retCode; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR("Input pInterfaceName params error: {0}", pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pIpaddr == nullptr || lstrlen(pIpaddr) == 0) { + SPDLOG_ERROR("Input pIpaddr params error: {0}", pIpaddr); + return -ERR_INPUT_PARAMS; + } + + if (pNetmask == nullptr || lstrlen(pNetmask) == 0) { + SPDLOG_WARN("Input pNetmask params error: {0}", pNetmask); + return -ERR_INPUT_PARAMS; + } + + cidr = NetmaskToCIDR(pNetmask); + + if (FAILED(StringCbPrintf( + cmdBuf, + MAX_PATH, + "PowerShell -Command \"& {New-NetIPAddress -InterfaceAlias %s -IPAddress %s -PrefixLength %d}\"", + pInterfaceName, + pIpaddr, + cidr))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + + if ((ret = RunCommand(cmdBuf, nullptr, 0, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + SPDLOG_DEBUG("Run Set IP Command({1}): {0}", cmdBuf, retCode); + + return ERR_SUCCESS; +} + +int SetInterfaceIpAddressFromCIDR(const TCHAR *pInterfaceName, const TCHAR *pCidrIpaddr) { + TCHAR ipAddr[MAX_IP_LEN]; + TCHAR ip[MAX_IP_LEN]; + int ret; + TCHAR cmdBuf[MAX_PATH]; + DWORD retCode; + + TCHAR *token, *p = nullptr; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR("Input pInterfaceName params error: {0}", pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pCidrIpaddr == nullptr || lstrlen(pCidrIpaddr) == 0) { + SPDLOG_ERROR("Input pCidrIpaddr params error: {0}", pCidrIpaddr); + return -ERR_INPUT_PARAMS; + } + + StringCbCopy(ipAddr, MAX_IP_LEN, pCidrIpaddr); + + // ȡǰIPַ + token = strtok_s(ipAddr, TEXT("/"), &p); + if (token == nullptr) { + SPDLOG_ERROR("CIDR IpAddress string format error: {0}", pCidrIpaddr); + return -ERR_INPUT_PARAMS; + } + StringCbCopy(ip, MAX_IP_LEN, token); + + // ȡ + token = strtok_s(nullptr, TEXT("/"), &p); + if (token == nullptr) { + SPDLOG_ERROR("CIDR IpAddress string format error: {0}", pCidrIpaddr); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbPrintf( + cmdBuf, + MAX_PATH, + "PowerShell -Command \"& {New-NetIPAddress -InterfaceAlias %s -IPAddress %s -PrefixLength %s}\"", + pInterfaceName, + ip, + token))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + + if ((ret = RunCommand(cmdBuf, nullptr, 0, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + SPDLOG_DEBUG("Run Set IP Command({1}): {0}", cmdBuf, retCode); + + return ERR_SUCCESS; +} + +int IsInterfacePrivate(const TCHAR *pInterfaceName, bool *pIsPrivateMode) { + int ret; + DWORD retCode; + TCHAR cmdBuf[MAX_PATH]; + TCHAR cmdResult[MAX_PATH] = {}; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR("Input pInterfaceName params error: {0}", pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pIsPrivateMode == nullptr) { + SPDLOG_ERROR("Input pIsPrivateMode params error"); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbPrintf( + cmdBuf, + MAX_PATH, + "PowerShell -Command \"& {Get-NetConnectionProfile -Name %s | Format-List -Property NetworkCategory}\"", + pInterfaceName))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + + if ((ret = RunCommand(cmdBuf, cmdResult, MAX_PATH, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + SPDLOG_DEBUG("Run command [{0}] resutl \'{1}\' return {2}", cmdBuf, cmdResult, retCode); + + if (StrStr(cmdResult, TEXT("Private")) != nullptr) { + *pIsPrivateMode = true; + return ERR_SUCCESS; + } else { + if (StrStr(cmdResult, TEXT("Public")) != nullptr) { + *pIsPrivateMode = false; + return ERR_SUCCESS; + } + } + + return -ERR_ITEM_UNEXISTS; +} + +int SetInterfacePrivate(const TCHAR *pInterfaceName, bool isPrivate) { + int ret; + TCHAR cmdBuf[MAX_PATH]; + DWORD retCode; + + if (isPrivate) { + if (FAILED(StringCbPrintf(cmdBuf, + MAX_PATH, + "PowerShell -Command \"& {Get-NetConnectionProfile -InterfaceAlias %s | " + "Set-NetConnectionProfile -NetworkCategory Private}\"", + pInterfaceName))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + } else { + if (FAILED(StringCbPrintf(cmdBuf, + MAX_PATH, + "PowerShell -Command \"& {Get-NetConnectionProfile -InterfaceAlias %s | " + "Set-NetConnectionProfile -NetworkCategory Public}\"", + pInterfaceName))) { + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + } + + if ((ret = RunCommand(cmdBuf, nullptr, 0, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + if (retCode != 0) { + SPDLOG_ERROR("PowerShell return error({1}): {0}", cmdBuf, retCode); + return -ERR_PROCESS_RETURN; + } + + return ERR_SUCCESS; +} + +bool IsCustomNatPSCmdInstalled() { + TCHAR psCmdPath[MAX_PATH]; + + StringCbPrintf(psCmdPath, + MAX_PATH, + TEXT("%s\\system32\\WindowsPowerShell\\v1.0\\Modules\\wireguard"), + GetGlobalCfgInfo()->systemDirectory); + + // ж WireGuard NAT Ƿװ + if (!PathFileExists(psCmdPath)) { + if (!CreateDirectory(psCmdPath, nullptr)) { + return false; + } + } + + StringCbCat(psCmdPath, MAX_PATH, "\\wireguard.psm1"); + + if (PathFileExists(psCmdPath)) { + // ļ˵Ѿװ + return true; + } + + return false; +} + +int IsNetConnectionSharingEnabled(const TCHAR *pInterfaceName, bool *pIsEnabled) { + int ret; + DWORD retCode; + TCHAR cmdResult[MAX_PATH] = {}; + TCHAR cmdBuf[512]; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR("Input pInterfaceName params error: {0}", pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pIsEnabled == nullptr) { + SPDLOG_ERROR("Input pIsEnabled params error"); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbPrintf( + cmdBuf, + 512, + TEXT("PowerShell -NoProfile -NonInteractive -WindowStyle Hidden -ExecutionPolicy Bypass Invoke-Command " + "-ArgumentList '%s' -ScriptBlock {param($IFNAME);$netShare = New-Object -ComObject HNetCfg.HNetShare;" + "$privateConnection = $netShare.EnumEveryConnection |? { $netShare.NetConnectionProps.Invoke($_).Name " + "-eq " + "'wg_cli' };$privateConfig = " + "$netShare.INetSharingConfigurationForINetConnection.Invoke($privateConnection);" + "Write-Output $privateConfig}"), + pInterfaceName))) { + + SPDLOG_ERROR("Format String Error"); + return -ERR_MEMORY_STR; + } + + if ((ret = RunCommand(cmdBuf, nullptr, 0, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + if (retCode != 0) { + SPDLOG_ERROR("PowerShell return error({1}): {0}", cmdBuf, retCode); + return -ERR_PROCESS_RETURN; + } + + SPDLOG_DEBUG("Run command [{0}] resutl \'{1}\' return {2}", cmdBuf, cmdResult, retCode); + + if (StrStr(cmdResult, TEXT("False")) != nullptr) { + *pIsEnabled = false; + return ERR_SUCCESS; + } else { + if (StrStr(cmdResult, TEXT("True")) != nullptr) { + *pIsEnabled = true; + return ERR_SUCCESS; + } + } + + return -ERR_ITEM_UNEXISTS; +} +#endif \ No newline at end of file diff --git a/NetTunnelSDK/protocol/protocol.cpp b/NetTunnelSDK/protocol/protocol.cpp new file mode 100644 index 0000000..10f6252 --- /dev/null +++ b/NetTunnelSDK/protocol/protocol.cpp @@ -0,0 +1,444 @@ +#include "pch.h" + +#include "tunnel.h" +#include "protocol.h" + +#include "globalcfg.h" +#include "httplib.h" +#include "misc.h" +#include "usrerr.h" + +#include +#include + +#define HTTP_JSON_CONTENT TEXT("application/json") + +static httplib::Client *g_httpCtx = nullptr; +static httplib::Client *g_tunnelHttpCtx = nullptr; + +int InitControlServer(const TCHAR *pUserSvrUrl) { + + if (g_tunnelHttpCtx) { + delete g_tunnelHttpCtx; + g_tunnelHttpCtx = nullptr; + } + + if (UsedSCGProxy()) { + TCHAR scgProxyUrl[MAX_PATH]; + StringCbPrintf(scgProxyUrl, MAX_PATH, TEXT("http://127.0.0.1:%d"), GetGlobalCfgInfo()->scgProxy.scgGwPort); + + SPDLOG_DEBUG(TEXT("Control Server Used Proxy: {0} --> {1}"), pUserSvrUrl, scgProxyUrl); + g_tunnelHttpCtx = new httplib::Client(scgProxyUrl); + } else { + g_tunnelHttpCtx = new httplib::Client(pUserSvrUrl); + SPDLOG_DEBUG(TEXT("Control Server Unused Proxy: {0}"), pUserSvrUrl); + } + + if (g_tunnelHttpCtx) { + g_tunnelHttpCtx->set_connection_timeout(0, 1000000); // 1 second + g_tunnelHttpCtx->set_read_timeout(5, 0); // 5 seconds + g_tunnelHttpCtx->set_write_timeout(5, 0); // 5 seconds + g_tunnelHttpCtx->set_keep_alive(true); + g_tunnelHttpCtx->set_post_connect_cb([](socket_t sock) { + if (UsedSCGProxy()) { + int ret; + unsigned char vmid[4]; + unsigned char *p; + const unsigned int id = htonl(GetGlobalCfgInfo()->curConnVmId); + const auto svrId = static_cast(GetGlobalCfgInfo()->userCfg.cliConfig.scgCtrlAppId); + unsigned char scgProxy[] = {0x01, // VERSION + 0x09, // Length + 0xF0, // ++++++ INFO[0] TYPE + 0x04, // INFO[0] LENGTH + 0, // INFO[0] VMID[0] + 0, // INFO[0] VMID[1] + 0, // INFO[0] VMID[2] + 0, // INFO[0] VMID[3] + 0xF1, // INFO[1] TYPE + 0x01, // INFO[1] LENGTH + svrId}; // ------ INFO[1] SCG Service ID + + p = scgProxy; + memcpy(vmid, &id, 4); + scgProxy[4] = vmid[0]; + scgProxy[5] = vmid[1]; + scgProxy[6] = vmid[2]; + scgProxy[7] = vmid[3]; + + if (GetGlobalCfgInfo()->logLevel == spdlog::level::trace) { + std::array arr; + std::copy(std::begin(scgProxy), std::end(scgProxy), std::begin(arr)); + SPDLOG_DEBUG(TEXT("TCP Proxy SCG Payload: {0:Xa}"), spdlog::to_hex(arr)); + } + + ret = send(sock, reinterpret_cast(p), sizeof(scgProxy), 0); + + while (ret < static_cast(sizeof(scgProxy))) { + p += ret; + ret += send(sock, reinterpret_cast(p), sizeof(scgProxy), 0); + } + + SPDLOG_DEBUG(TEXT("Service Connected To SCG Server({1}/{2}): {0}"), + sock, + GetGlobalCfgInfo()->curConnVmId, + svrId); + } + }); + } + + SPDLOG_DEBUG(TEXT("Connect to Tunnel Control Service: {0}"), pUserSvrUrl); + + return ERR_SUCCESS; +} + +template int CreateProtocolRequest(T *pReqParams, TCHAR **pOutJson) { + std::string json; + + if (!g_httpCtx && lstrlen(GetGlobalCfgInfo()->platformServerUrl) > 0) { + g_httpCtx = new httplib::Client(GetGlobalCfgInfo()->platformServerUrl); + if (g_httpCtx) { + g_httpCtx->set_connection_timeout(0, 300000); // 300 milliseconds + g_httpCtx->set_read_timeout(5, 0); // 5 seconds + g_httpCtx->set_write_timeout(5, 0); // 5 seconds + g_httpCtx->set_keep_alive(true); + g_httpCtx->enable_server_certificate_verification(false); + } + } + + if (aigc::JsonHelper::ObjectToJson(*pReqParams, json)) { + *pOutJson = _strdup(json.c_str()); + return ERR_SUCCESS; + } + + return -ERR_JSON_CREATE; +} + +template int DecodeProtocolResponse(T *pResponse, const TCHAR *pJson) { + if (aigc::JsonHelper::JsonToObject(*pResponse, pJson)) { + return ERR_SUCCESS; + } + + return -ERR_JSON_DECODE; +} + +template +int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer) { + int ret; + httplib::Result res; + TCHAR *pJson = nullptr; + + if (lstrlen(GetGlobalCfgInfo()->platformServerUrl) == 0) { + SPDLOG_ERROR(TEXT("Platform Server URL uninitialize.")); + return -ERR_SYSTEM_UNINITIALIZE; + } + + if (pReq == nullptr) { + SPDLOG_ERROR(TEXT("Input pToken params error")); + SPDLOG_ERROR(TEXT("Input ProtocolRequest *pReq params error")); + return -ERR_INPUT_PARAMS; + } + + if (pRsp == nullptr) { + SPDLOG_ERROR(TEXT("Input ProtocolResponse *pRsp params error")); + return -ERR_INPUT_PARAMS; + } + + ret = CreateProtocolRequest(pReq, &pJson); + + if (ret != ERR_SUCCESS) { + if (pJson) { + free(pJson); + } + return ret; + } + + if (platformServer) { + std::string timestamp = std::to_string(time(nullptr)) + "000"; + TCHAR hashValeu[MAX_PATH] = {0}; + TCHAR hashBuf[1024] = {}; + + StringCbPrintf(hashBuf, + 1024, + TEXT("%s|%s|%s|%s"), + GetGlobalCfgInfo()->clientId, + GetGlobalCfgInfo()->clientSecret, + timestamp.c_str(), + pJson); + + if (lstrlen(GetGlobalCfgInfo()->clientSecret) > 0 && + CalcHmacHash(HASH_SHA256, + reinterpret_cast(hashBuf), + lstrlen(hashBuf), + reinterpret_cast(GetGlobalCfgInfo()->clientSecret), + lstrlen(GetGlobalCfgInfo()->clientSecret), + hashValeu, + true) == ERR_SUCCESS) { + const httplib::Headers headers = { + {"gzs-client-id", GetGlobalCfgInfo()->clientId}, + {"gzs-sign", hashValeu }, + {"gzs-timestamp", timestamp }, + }; + res = g_httpCtx->Post(pUrlPath, headers, pJson, HTTP_JSON_CONTENT); + } else { + res = g_httpCtx->Post(pUrlPath, pJson, HTTP_JSON_CONTENT); + } + } else { + if (g_tunnelHttpCtx == nullptr) { + free(pJson); + SPDLOG_ERROR(TEXT("Server Control Service don't connected(g_tunnelHttpCtx is not initialize).")); + return -ERR_SYSTEM_UNINITIALIZE; + } + res = g_tunnelHttpCtx->Post(pUrlPath, pJson, HTTP_JSON_CONTENT); + } + + if (res.error() != httplib::Error::Success) { + SPDLOG_ERROR(TEXT("[{0}]:Post Data {1} error: {2}"), pUrlPath, pJson, httplib::to_string(res.error())); + free(pJson); + return -ERR_HTTP_POST_DATA; + } + + if (res->status != 200) { + SPDLOG_ERROR(TEXT("[{0}]:Post Data {1} server return HTTP error: {2}"), pUrlPath, pJson, res->status); + free(pJson); + return -ERR_HTTP_SERVER_RSP; + } + + SPDLOG_DEBUG(TEXT("+++++ Http Request {0}\n---- Http Response {1}"), pJson, res->body.c_str()); + + free(pJson); + + if (lstrlen(res->body.c_str()) == 0) { + SPDLOG_ERROR(TEXT("Server response empty message")); + return -ERR_READ_FILE; + } + + if (DecodeProtocolResponse(pRsp, res->body.c_str()) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Decode JSON {0} to ProtocolResponse<{1}> error"), res->body, typeid(T2).name()); + return -ERR_JSON_DECODE; + } + + return ERR_SUCCESS; +} + +#if 0 +template int PlatformProtolGetMessage(const TCHAR *pUrlPath, T1 *pRsp) { + httplib::Result res; + TCHAR *pJson = nullptr; + std::string timestamp = std::to_string(time(nullptr)) + "000"; + TCHAR hashValeu[MAX_PATH] = {0}; + TCHAR hashBuf[1024] = {}; + + if (lstrlen(GetGlobalCfgInfo()->platformServerUrl) == 0) { + SPDLOG_ERROR(TEXT("Platform Server URL uninitialize.")); + return -ERR_SYSTEM_UNINITIALIZE; + } + + if (pRsp == nullptr) { + SPDLOG_ERROR(TEXT("Input ProtocolResponse *pRsp params error")); + return -ERR_INPUT_PARAMS; + } + + StringCbPrintf(hashBuf, + 1024, + TEXT("%s|%s|%s|%s"), + GetGlobalCfgInfo()->clientId, + GetGlobalCfgInfo()->clientSecret, + timestamp.c_str(), + pJson); + + if (lstrlen(GetGlobalCfgInfo()->clientSecret) > 0 && + CalcHmacHash(HASH_SHA256, + reinterpret_cast(hashBuf), + lstrlen(hashBuf), + reinterpret_cast(GetGlobalCfgInfo()->clientSecret), + lstrlen(GetGlobalCfgInfo()->clientSecret), + hashValeu, + true) == ERR_SUCCESS) { + const httplib::Headers headers = { + {"gzs-client-id", GetGlobalCfgInfo()->clientId}, + {"gzs-sign", hashValeu }, + {"gzs-timestamp", timestamp }, + }; + res = g_httpCtx->Get(pUrlPath, headers); + } else { + res = g_httpCtx->Get(pUrlPath); + } + + if (res.error() != httplib::Error::Success) { + SPDLOG_ERROR(TEXT("[{0}]:Post Data {1} error: {2}"), pUrlPath, pJson, httplib::to_string(res.error())); + free(pJson); + return -ERR_HTTP_POST_DATA; + } + + if (res->status != 200) { + SPDLOG_ERROR(TEXT("[{0}]:Post Data {1} server return HTTP error: {2}"), pUrlPath, pJson, res->status); + free(pJson); + return -ERR_HTTP_SERVER_RSP; + } + + SPDLOG_DEBUG(TEXT("+++++ Http Request {0}\n---- Http Response {1}"), pJson, res->body.c_str()); + + free(pJson); + + if (lstrlen(res->body.c_str()) == 0) { + SPDLOG_ERROR(TEXT("Server response empty message")); + return -ERR_READ_FILE; + } + + if (DecodeProtocolResponse(pRsp, res->body.c_str()) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Decode JSON {0} to ProtocolResponse<{1}> error"), res->body, typeid(T1).name()); + return -ERR_JSON_DECODE; + } + + return ERR_SUCCESS; +} +#endif + +template int PlatformProtolPostMessage(const TCHAR *pUrlPath, T1 *pReq, T2 *pRsp) { + int ret; + httplib::Result res; + TCHAR *pJson = nullptr; + std::string timestamp = std::to_string(time(nullptr)) + "000"; + TCHAR hashValeu[MAX_PATH] = {0}; + TCHAR hashBuf[1024] = {}; + + if (lstrlen(GetGlobalCfgInfo()->platformServerUrl) == 0) { + SPDLOG_ERROR(TEXT("Platform Server URL uninitialize.")); + return -ERR_SYSTEM_UNINITIALIZE; + } + + if (pReq == nullptr) { + SPDLOG_ERROR(TEXT("Input pToken params error")); + SPDLOG_ERROR(TEXT("Input ProtocolRequest *pReq params error")); + return -ERR_INPUT_PARAMS; + } + + if (pRsp == nullptr) { + SPDLOG_ERROR(TEXT("Input ProtocolResponse *pRsp params error")); + return -ERR_INPUT_PARAMS; + } + + ret = CreateProtocolRequest(pReq, &pJson); + + if (ret != ERR_SUCCESS) { + if (pJson) { + free(pJson); + } + return ret; + } + + StringCbPrintf(hashBuf, + 1024, + TEXT("%s|%s|%s|%s"), + GetGlobalCfgInfo()->clientId, + GetGlobalCfgInfo()->clientSecret, + timestamp.c_str(), + pJson); + + if (lstrlen(GetGlobalCfgInfo()->clientSecret) > 0 && + CalcHmacHash(HASH_SHA256, + reinterpret_cast(hashBuf), + lstrlen(hashBuf), + reinterpret_cast(GetGlobalCfgInfo()->clientSecret), + lstrlen(GetGlobalCfgInfo()->clientSecret), + hashValeu, + true) == ERR_SUCCESS) { + if (typeid(T1) == typeid(PlatformReqClientCfgParms)) { + const auto *p = reinterpret_cast(pReq); + const httplib::Headers headers = { + {"gzs-client-id", GetGlobalCfgInfo()->clientId }, + {"gzs-sign", hashValeu }, + {"gzs-timestamp", timestamp }, + {"Authorization", ("Bearer " + p->token).c_str()}, + }; + + res = g_httpCtx->Post(pUrlPath, headers, pJson, HTTP_JSON_CONTENT); + } else { + const httplib::Headers headers = { + {"gzs-client-id", GetGlobalCfgInfo()->clientId}, + {"gzs-sign", hashValeu }, + {"gzs-timestamp", timestamp }, + }; + + res = g_httpCtx->Post(pUrlPath, headers, pJson, HTTP_JSON_CONTENT); + } + } else { + if (typeid(T1) == typeid(PlatformReqClientCfgParms)) { + const auto *p = reinterpret_cast(pReq); + const httplib::Headers headers = { + {"Authorization", ("Bearer " + p->token).c_str()}, + }; + res = g_httpCtx->Post(pUrlPath, headers, pJson, HTTP_JSON_CONTENT); + } else { + res = g_httpCtx->Post(pUrlPath, pJson, HTTP_JSON_CONTENT); + } + } + + SPDLOG_DEBUG(TEXT("+++++ Http Request {0}\n---- Http Response {1}"), pJson, res->body.c_str()); + + if (res.error() != httplib::Error::Success) { + SPDLOG_ERROR(TEXT("[{0}]:Post Data {1} error: {2}"), pUrlPath, pJson, httplib::to_string(res.error())); + free(pJson); + return -ERR_HTTP_POST_DATA; + } + + if (res->status != 200) { + SPDLOG_ERROR(TEXT("[{0}]:Post Data {1} server return HTTP error: {2}"), pUrlPath, pJson, res->status); + free(pJson); + return -ERR_HTTP_SERVER_RSP; + } + + free(pJson); + + if (lstrlen(res->body.c_str()) == 0) { + SPDLOG_ERROR(TEXT("Server response empty message")); + return -ERR_READ_FILE; + } + + if (DecodeProtocolResponse(pRsp, res->body.c_str()) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Decode JSON {0} to ProtocolResponse<{1}> error"), res->body, typeid(T2).name()); + return -ERR_JSON_DECODE; + } + + return ERR_SUCCESS; +} + +template int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +template int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +template int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +template int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +template int ProtolPostMessage(const TCHAR *pUrlPath, + ProtocolRequest *pReq, + ProtocolResponse *pRsp, + bool platformServer); + +#if !USER_REAL_PLATFORM +template int PlatformProtolPostMessage(const TCHAR *pUrlPath, + PlatformReqServerCfgParms *pReq, + PlatformRspServerCfgParams *pRsp); + +template int PlatformProtolPostMessage(const TCHAR *pUrlPath, + PlatformReqClientCfgParms *pReq, + PlatformRspClientCfgParams *pRsp); + +//template int PlatformProtolGetMessage(const TCHAR *pUrlPath, PlatformRspUserClientCfgParams *pRsp); +#endif \ No newline at end of file diff --git a/NetTunnelSDK/tunnel/WireGuardService.cpp b/NetTunnelSDK/tunnel/WireGuardService.cpp new file mode 100644 index 0000000..3f51e2e --- /dev/null +++ b/NetTunnelSDK/tunnel/WireGuardService.cpp @@ -0,0 +1,335 @@ +#include "pch.h" +#include "usrerr.h" +#include "globalcfg.h" +#include "tunnel.h" +#include "wireguard.h" +#include "misc.h" + +#include +#include +#include + +static WIREGUARD_CREATE_ADAPTER_FUNC *WireGuardCreateAdapter; +static WIREGUARD_OPEN_ADAPTER_FUNC *WireGuardOpenAdapter; +static WIREGUARD_CLOSE_ADAPTER_FUNC *WireGuardCloseAdapter; +static WIREGUARD_GET_ADAPTER_LUID_FUNC *WireGuardGetAdapterLUID; +static WIREGUARD_GET_RUNNING_DRIVER_VERSION_FUNC *WireGuardGetRunningDriverVersion; +static WIREGUARD_DELETE_DRIVER_FUNC *WireGuardDeleteDriver; +static WIREGUARD_SET_LOGGER_FUNC *WireGuardSetLogger; +static WIREGUARD_SET_ADAPTER_LOGGING_FUNC *WireGuardSetAdapterLogging; +static WIREGUARD_GET_ADAPTER_STATE_FUNC *WireGuardGetAdapterState; +static WIREGUARD_SET_ADAPTER_STATE_FUNC *WireGuardSetAdapterState; +static WIREGUARD_GET_CONFIGURATION_FUNC *WireGuardGetConfiguration; +static WIREGUARD_SET_CONFIGURATION_FUNC *WireGuardSetConfiguration; + +typedef struct { + WIREGUARD_INTERFACE Interface; + WIREGUARD_PEER RemoteServer; + WIREGUARD_ALLOWED_IP Allow1; + WIREGUARD_ALLOWED_IP Allow2; +} WG_CONFIG_INFO; + +static HMODULE g_WireGarudModule; + +int InitializeWireGuardLibrary() { + TCHAR dllPath[MAX_PATH]; + + StringCbPrintf(dllPath, MAX_PATH, TEXT("%s\\wireguard.dll"), GetGlobalCfgInfo()->workDirectory); + if (!PathFileExists(dllPath)) { + SPDLOG_ERROR(TEXT("WireGuard DLL Not Found: {0}"), dllPath); + return -ERR_ITEM_UNEXISTS; + } + + g_WireGarudModule = LoadLibraryEx(dllPath, + nullptr, + LOAD_LIBRARY_SEARCH_APPLICATION_DIR | LOAD_LIBRARY_SEARCH_SYSTEM32); + if (!g_WireGarudModule) { + DWORD errCode = GetLastError(); + SPDLOG_ERROR(TEXT("LoadLibraryEx WireGuard DLL error: {0}"), errCode); + return -ERR_LOAD_LIBRARY; + } +#define X(Name) ((*(FARPROC *)&(Name) = GetProcAddress(g_WireGarudModule, #Name)) == nullptr) + if (X(WireGuardCreateAdapter) || X(WireGuardOpenAdapter) || X(WireGuardCloseAdapter) || + X(WireGuardGetAdapterLUID) || X(WireGuardGetRunningDriverVersion) || X(WireGuardDeleteDriver) || + X(WireGuardSetLogger) || X(WireGuardSetAdapterLogging) || X(WireGuardGetAdapterState) || + X(WireGuardSetAdapterState) || X(WireGuardGetConfiguration) || X(WireGuardSetConfiguration)) +#undef X + { + SPDLOG_ERROR(TEXT("Map WireGuard DLL EntryPoint error: {0}"), GetLastError()); + FreeLibrary(g_WireGarudModule); + return -ERR_MAP_LIBRARY; + } + return ERR_SUCCESS; +} + +void UnInitializeWireGuardLibrary() { + if (g_WireGarudModule) { + FreeLibrary(g_WireGarudModule); + } +} + +int GetWireGuradTunnelInfo(const TCHAR *pTunnelName) { + WIREGUARD_ADAPTER_HANDLE Adapter; + int ret; + WCHAR wstrName[MAX_PATH]; + + if ((ret = TCharToWideChar(pTunnelName, wstrName, MAX_PATH)) != ERR_SUCCESS) { + return ret; + } + + Adapter = WireGuardOpenAdapter(wstrName); + + if (Adapter) { + WG_CONFIG_INFO config; + + DWORD Bytes = sizeof(WG_CONFIG_INFO); + if (!WireGuardGetConfiguration(Adapter, &config.Interface, &Bytes)) { + SPDLOG_ERROR("Failed to get configuration: {0}", GetLastError()); + } + } + + return ERR_SUCCESS; +} + +int GetWireGuardServiceStatus(const TCHAR *pTunnelName, bool *pIsRunning) { + SC_HANDLE schSCManager; + SC_HANDLE schService; + TCHAR svrName[MAX_PATH]; + + if (pTunnelName == nullptr || lstrlen(pTunnelName) == 0) { + SPDLOG_ERROR(TEXT("Input pTunnelName error: {0}"), pTunnelName); + return -ERR_INPUT_PARAMS; + } + + if (pIsRunning == nullptr) { + SPDLOG_ERROR(TEXT("Input pIsRunning params error")); + return -ERR_INPUT_PARAMS; + } + + *pIsRunning = false; + + StringCbPrintf(svrName, MAX_PATH, TEXT("WireGuardTunnel$%s"), pTunnelName); + // Get a handle to the SCM database. + schSCManager = OpenSCManager(nullptr, // local computer + nullptr, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (nullptr == schSCManager) { + SPDLOG_ERROR(TEXT("OpenSCManager failed ({0})"), GetLastError()); + return -ERR_OPEN_SCM; + } + + // Get a handle to the service. + schService = OpenService(schSCManager, // SCM database + svrName, // name of service + SERVICE_ALL_ACCESS); // full access + + CloseServiceHandle(schService); + + // 如果服务不存在则直接返回 + if (schService != nullptr) { + *pIsRunning = true; + } + + return ERR_SUCCESS; +} + +int RemoveGuardService(const TCHAR *pTunnelName, bool bIsWaitStop) { + SC_HANDLE schSCManager; + SC_HANDLE schService; + TCHAR svrName[MAX_PATH]; + SERVICE_STATUS svrStatus; + + if (pTunnelName == nullptr || lstrlen(pTunnelName) == 0) { + SPDLOG_ERROR(TEXT("Input pTunnelName error: {0}"), pTunnelName); + return -ERR_INPUT_PARAMS; + } + + StringCbPrintf(svrName, MAX_PATH, TEXT("WireGuardTunnel$%s"), pTunnelName); + // Get a handle to the SCM database. + schSCManager = OpenSCManager(nullptr, // local computer + nullptr, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (nullptr == schSCManager) { + SPDLOG_ERROR(TEXT("OpenSCManager failed ({0})"), GetLastError()); + return -ERR_OPEN_SCM; + } + + // Get a handle to the service. + schService = OpenService(schSCManager, // SCM database + svrName, // name of service + SERVICE_ALL_ACCESS); // full access + + // 如果服务不存在则直接返回 + if (schService == nullptr) { + CloseServiceHandle(schSCManager); + return ERR_SUCCESS; + } + + if (ControlService(schService, SERVICE_CONTROL_STOP, &svrStatus) == 0) { + DWORD errCode = GetLastError(); + + if (errCode != ERROR_SERVICE_CANNOT_ACCEPT_CTRL && errCode != ERROR_SERVICE_NOT_ACTIVE) { + SPDLOG_ERROR(TEXT("Stop Service {1} failed ({0})"), errCode, svrName); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return -ERR_STOP_SERVICE; + } + } + + for (int i = 0; bIsWaitStop && i < 10; i++) { + SERVICE_STATUS_PROCESS ssStatus; + DWORD dwBytesNeeded = 0; + + if (QueryServiceStatusEx(schService, // handle to service + SC_STATUS_PROCESS_INFO, // information level + reinterpret_cast(&ssStatus), // address of structure + sizeof(SERVICE_STATUS_PROCESS), // size of structure + &dwBytesNeeded)) // size needed if buffer is too small + { + // 服务已经停止 + if (ssStatus.dwCurrentState == SERVICE_STOPPED) { + break; + } + } + + //SPDLOG_ERROR(TEXT("Stop Service {1} retry times ({0})"), i + 1, svrName); + Sleep(1000); + } + + if (!DeleteService(schService) && GetLastError() != ERROR_SERVICE_MARKED_FOR_DELETE) { + SPDLOG_ERROR(TEXT("Delete Service {1} failed ({0})"), GetLastError(), svrName); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return -ERR_DELETE_SERVICE; + } + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return ERR_SUCCESS; +} + +int CreateWireGuardService(const TCHAR *pInterfaceName, const TCHAR *pWGConfigFilePath) { + //Service Name: "WireGuardTunnel$SomeTunnelName" + //Display Name: "Some Service Name" + //Service Type: SERVICE_WIN32_OWN_PROCESS + //Start Type: StartAutomatic + //Error Control: ErrorNormal, + //Dependencies: [ "Nsi", "TcpIp" ] + //Sid Type: SERVICE_SID_TYPE_UNRESTRICTED + //Executable: "C:\path\to\example\vpnclient.exe /service configfile.conf" + SC_HANDLE schSCManager; + SC_HANDLE schService; + TCHAR svrName[MAX_PATH]; + TCHAR displayName[MAX_PATH]; + TCHAR svrPath[MAX_PATH]; + TCHAR execParams[MAX_PATH * 2]; + + if (pInterfaceName == nullptr || lstrlen(pInterfaceName) == 0) { + SPDLOG_ERROR(TEXT("Input pInterfaceName error: {0}"), pInterfaceName); + return -ERR_INPUT_PARAMS; + } + + if (pWGConfigFilePath == nullptr || lstrlen(pWGConfigFilePath) == 0) { + SPDLOG_ERROR(TEXT("Input pGUID error: {0}"), pWGConfigFilePath); + return -ERR_INPUT_PARAMS; + } + + if (!PathFileExists(pWGConfigFilePath)) { + SPDLOG_ERROR(TEXT("WireGuard Configure File Not Exist: {0}"), pWGConfigFilePath); + return -ERR_ITEM_UNEXISTS; + } + + StringCbPrintf(svrName, MAX_PATH, TEXT("WireGuardTunnel$%s"), pInterfaceName); + StringCbPrintf(displayName, MAX_PATH, TEXT("WireGuard Tunnel Service %s"), pInterfaceName); + StringCbPrintf(svrPath, MAX_PATH, TEXT("%s\\NetTunnelSvr.exe"), GetGlobalCfgInfo()->workDirectory); + StringCbPrintf(execParams, MAX_PATH * 2, TEXT("\"%s\" /service \"%s\""), svrPath, pWGConfigFilePath); + //SPDLOG_DEBUG(TEXT("Params: {0}"), execParams); + + if (!PathFileExists(svrPath)) { + SPDLOG_ERROR(TEXT("WireGuard Service Not Exist: {0}"), svrPath); + return -ERR_ITEM_UNEXISTS; + } + + // Get a handle to the SCM database. + schSCManager = OpenSCManager(nullptr, // local computer + nullptr, // ServicesActive database + SC_MANAGER_ALL_ACCESS); // full access rights + + if (nullptr == schSCManager) { + SPDLOG_ERROR(TEXT("OpenSCManager failed ({0})"), GetLastError()); + return -ERR_OPEN_SCM; + } + + // Get a handle to the service. + schService = OpenService(schSCManager, // SCM database + svrName, // name of service + SERVICE_ALL_ACCESS); // full access + // 如果服务已经存在则关闭 + if (schService != nullptr) { + int ret; + CloseServiceHandle(schSCManager); + CloseServiceHandle(schService); + ret = RemoveGuardService(pInterfaceName, true); + + if (ret != ERR_SUCCESS) { + return ret; + } + } + + schService = CreateService(schSCManager, + svrName, + displayName, + SC_MANAGER_ALL_ACCESS, + SERVICE_WIN32_OWN_PROCESS, + SERVICE_DEMAND_START, + SERVICE_ERROR_NORMAL, + execParams, + nullptr, + nullptr, + TEXT("Nsi\0TcpIp"), + nullptr, + nullptr); + + if (schService == nullptr) { + SPDLOG_ERROR(TEXT("Create Service {1} failed ({0})"), GetLastError(), svrName); + CloseServiceHandle(schSCManager); + return -ERR_CREATE_SERVICE; + } else { + SERVICE_SID_INFO info; + SERVICE_DESCRIPTIONA desc; + info.dwServiceSidType = SERVICE_SID_TYPE_UNRESTRICTED; + + if (!ChangeServiceConfig2(schService, SERVICE_CONFIG_SERVICE_SID_INFO, &info)) { + SPDLOG_ERROR(TEXT("Change Service {1} SERVICE_CONFIG_SERVICE_SID_INFO Configure failed ({0})"), + GetLastError(), + svrName); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return -ERR_CONFIG_SERVICE; + } + + desc.lpDescription = TEXT(const_cast("SCC Tunnel Service over WireGuard")); + + if (!ChangeServiceConfig2(schService, SERVICE_CONFIG_DESCRIPTION, &desc)) { + SPDLOG_ERROR(TEXT("Change Service {1} SERVICE_CONFIG_DESCRIPTION Configure failed ({0})"), + GetLastError(), + svrName); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return -ERR_CONFIG_SERVICE; + } + + if (!StartService(schService, 0, nullptr)) { + SPDLOG_ERROR(TEXT("Start Service {1} failed ({0})"), GetLastError(), svrName); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return -ERR_START_SERVICE; + } + } + + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return ERR_SUCCESS; +} \ No newline at end of file diff --git a/NetTunnelSDK/tunnel/tunnel.cpp b/NetTunnelSDK/tunnel/tunnel.cpp new file mode 100644 index 0000000..4bfcec6 --- /dev/null +++ b/NetTunnelSDK/tunnel/tunnel.cpp @@ -0,0 +1,326 @@ +#include "pch.h" +#include "tunnel.h" + +#include +#include +#include +#include +#include +#include + +#include "globalcfg.h" +#include "misc.h" +#include "user.h" + +#include +#include + +#pragma comment(lib, "Dbghelp.lib") + +#define CONFIG_FILE_NAME TEXT("tunnelsdk.ini") + +static SDK_CONFIG g_globalConfig; + +PSDK_CONFIG GetGlobalCfgInfo() { + return &g_globalConfig; +} + +static spdlog::level::level_enum logLevelToSpdlogLevel(LOG_LEVEL level) { + switch (level) { + case LOG_TRACE: + return spdlog::level::level_enum::trace; + case LOG_DEBUG: + return spdlog::level::level_enum::debug; + case LOG_INFO: + return spdlog::level::level_enum::info; + case LOG_WARN: + return spdlog::level::level_enum::warn; + case LOG_ERROR: + return spdlog::level::level_enum::err; + case LOG_CRITICAL: + return spdlog::level::level_enum::critical; + case LOG_OFF: + return spdlog::level::level_enum::off; + } + + return spdlog::level::level_enum::info; +} + +static void InitTunnelSDKLog(const TCHAR *pLogFile, LOG_LEVEL level) { + TCHAR buf[MAX_PATH] = {0}; + + if (pLogFile && strlen(pLogFile) > 0 && !PathIsRelative(pLogFile)) { + TCHAR tmpPath[MAX_PATH]; + StringCbCopy(tmpPath, MAX_PATH, pLogFile); + PathRemoveFileSpec(tmpPath); + MakeSureDirectoryPathExists(tmpPath); + StringCbCopy(buf, MAX_PATH, pLogFile); + } else { + StringCbPrintf(buf, MAX_PATH, TEXT("%s\\tunnelsdklog.log"), g_globalConfig.workDirectory); + } + + g_globalConfig.enableLog = TRUE; + g_globalConfig.logLevel = logLevelToSpdlogLevel(level); + + const auto dupFileFilter = std::make_shared(std::chrono::seconds(5)); + const auto dupStdFilter = std::make_shared(std::chrono::seconds(5)); + + //std::make_shared(buf, 1024 * 1024 * 5, 10)-> + dupFileFilter->add_sink(std::make_shared(buf, 2, 30)); + //dupFileFilter->add_sink(std::make_shared(buf, 1024 * 1024 * 5, 10)); + dupStdFilter->add_sink(std::make_shared()); + + std::vector sinks {dupStdFilter, dupFileFilter}; + auto logger = std::make_shared(TEXT("tunnelSDK"), sinks.begin(), sinks.end()); + spdlog::set_default_logger(logger); + + spdlog::set_level(g_globalConfig.logLevel); + spdlog::set_pattern("[%Y-%m-%d %H:%M:%S.%e][%l][%s:%#] %v"); + spdlog::flush_every(std::chrono::seconds(1)); + +#if 0 + std::cout << "TRACE: " << logger->should_log(spdlog::level::trace) << std::endl; + std::cout << "DEBUG: " << logger->should_log(spdlog::level::debug) << std::endl; + std::cout << "INFO: " << logger->should_log(spdlog::level::info) << std::endl; + std::cout << "WARN: " << logger->should_log(spdlog::level::warn) << std::endl; + std::cout << "ERROR: " << logger->should_log(spdlog::level::err) << std::endl; + std::cout << "CRITICAL: " << logger->should_log(spdlog::level::critical) << std::endl; +#endif + + SPDLOG_INFO(TEXT("Log({1}): {0}"), buf, static_cast(level)); +} + +int TunnelSDKInitEnv(const TCHAR *pWorkDir, + const TCHAR *pSvrUrl, + const TCHAR *pLogFile, + LOG_LEVEL level, + bool isWorkServer) { + int ret; + size_t length; + WSADATA WsaData; + + CoInitialize(nullptr); + CoInitializeSecurity(nullptr, -1, nullptr, nullptr, RPC_C_AUTHN_LEVEL_PKT, RPC_C_IMP_LEVEL_IMPERSONATE, nullptr, + EOAC_NONE, nullptr); + WSAStartup(MAKEWORD(2, 2), &WsaData); + + memset(&g_globalConfig, 0, sizeof(SDK_CONFIG)); + + g_globalConfig.isWorkServer = isWorkServer; + g_globalConfig.scgProxy.scgGwPort = 0; + + if (pWorkDir == nullptr) { + // 获取当前文件默认路径 + GetModuleFileName(nullptr, g_globalConfig.workDirectory, MAX_PATH); + PathRemoveFileSpec(g_globalConfig.workDirectory); + } else { + if (StringCbLength(pWorkDir, MAX_PATH, &length) == S_OK && length == 0 || PathIsRelative(pWorkDir)) { + // 获取当前文件默认路径 + GetModuleFileName(nullptr, g_globalConfig.workDirectory, MAX_PATH); + PathRemoveFileSpec(g_globalConfig.workDirectory); + } else { + MakeSureDirectoryPathExists(pWorkDir); + StringCbCopy(g_globalConfig.workDirectory, MAX_PATH, pWorkDir); + } + } + + // 初始化日志 + InitTunnelSDKLog(pLogFile, level); + + // 创建配置文件存储目录 + if (FAILED(SHGetFolderPath(NULL, CSIDL_APPDATA | CSIDL_FLAG_NO_ALIAS, NULL, 0, g_globalConfig.configDirectory))) { + SPDLOG_ERROR(TEXT("Get Windows system directory error.")); + return -ERR_SYS_CALL; + } + + StringCbCat(g_globalConfig.configDirectory, MAX_PATH, "\\NetTunnel"); + SPDLOG_DEBUG(TEXT("Configure directory: {0}."), g_globalConfig.configDirectory); + SPDLOG_DEBUG(TEXT("Platform Server: {}, Work Module: {}"), pSvrUrl, isWorkServer ? TEXT("SERVER") : TEXT("Client")); + + // 如果配置目录不存在则自动创建 + if (!PathFileExists(g_globalConfig.configDirectory)) { + if (!CreateDirectory(g_globalConfig.configDirectory, nullptr)) { + SPDLOG_ERROR(TEXT("Create configure directory '{0}' error."), g_globalConfig.configDirectory); + return -ERR_CREATE_FILE; + } + } + + StringCbCopy(g_globalConfig.platformServerUrl, MAX_IP_LEN, pSvrUrl); + + if (FAILED(SHGetFolderPath(NULL, CSIDL_WINDOWS | CSIDL_FLAG_NO_ALIAS, NULL, 0, g_globalConfig.systemDirectory))) { + SPDLOG_ERROR(TEXT("Get Windows system directory error.")); + return -ERR_SYS_CALL; + } + + StringCbPrintf(g_globalConfig.cfgPath, MAX_PATH, TEXT("%s\\%s"), g_globalConfig.workDirectory, CONFIG_FILE_NAME); + + if ((ret = InitializeWireGuardLibrary()) == ERR_SUCCESS) { + return ret; + } + +#if 0 + if (FindWireguardExe(nullptr, 0) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("WireGuard not found, Please install WireGuard first or set the WireGuard Path.")); + return -ERR_ITEM_UNEXISTS; + } +#endif + + return ERR_SUCCESS; +} + +void TunnelSDKUnInit() { + RemoteWireGuardControl(false); + LocalWireGuardControl(false, false); + UnInitializeWireGuardLibrary(); + CoFreeUnusedLibraries(); + WSACleanup(); +} + +void DisableVerifySignature() { + memset(g_globalConfig.clientId, 0, MAX_PATH); + memset(g_globalConfig.clientSecret, 0, MAX_PATH); +} + +int EnableVerifySignature(const TCHAR *pClientId, const TCHAR *pClientSecret) { + if (pClientId == nullptr || lstrlen(pClientId) == 0 || lstrlen(pClientId) >= MAX_PATH) { + SPDLOG_ERROR(TEXT("Input pClientId params error: {0}"), pClientId); + return -ERR_INPUT_PARAMS; + } + + if (pClientSecret == nullptr || lstrlen(pClientSecret) == 0 || lstrlen(pClientSecret) >= MAX_PATH) { + SPDLOG_ERROR(TEXT("Input pClientSecret params error: {0}"), pClientSecret); + return -ERR_INPUT_PARAMS; + } + + DisableVerifySignature(); + + StringCbCopy(g_globalConfig.clientId, MAX_PATH, pClientId); + StringCbCopy(g_globalConfig.clientSecret, MAX_PATH, pClientSecret); + + return ERR_SUCCESS; +} + +int EnableSCGProxy(bool isEnable, const TCHAR *pSCGIpAddr, int scgPort) { + + if (pSCGIpAddr == nullptr || lstrlen(pSCGIpAddr) == 0 || lstrlen(pSCGIpAddr) >= MAX_IP_LEN) { + SPDLOG_ERROR(TEXT("Input pInterfaceName params error: {0}"), pSCGIpAddr); + return -ERR_INPUT_PARAMS; + } + + memset(g_globalConfig.scgProxy.scgIpAddr, 0, MAX_IP_LEN); + + if (isEnable) { + IP_INFO ipInfo; + int ret; + if ((ret = GetIpV4InfoFromHostname(AF_INET, pSCGIpAddr, &ipInfo)) != ERR_SUCCESS) { + return ret; + } + + g_globalConfig.userCfg.cliConfig.scgCtrlAppId = WG_CTRL_SCG_ID; + g_globalConfig.userCfg.cliConfig.scgTunnelAppId = WG_TUNNEL_SCG_ID; + + StringCbCopy(g_globalConfig.scgProxy.scgIpAddr, MAX_IP_LEN, ipInfo.hostip); + g_globalConfig.scgProxy.scgGwPort = static_cast(scgPort); + CreateUDPProxyServer(); + } else { + StopUDPProxyServer(); + g_globalConfig.scgProxy.scgGwPort = 0; + } + + return ERR_SUCCESS; +} + +bool UsedSCGProxy() { + return (g_globalConfig.scgProxy.scgGwPort > 0); +} + +void TunnelLogEnable(bool enLog) { + if (enLog) { + spdlog::set_level(g_globalConfig.logLevel); + } else { + spdlog::set_level(spdlog::level::level_enum::off); + } +} + +int SetProtocolEncryptType(const PROTO_CRYPTO_TYPE type, const TCHAR *pProKey) { + if (type > CRYPTO_BASE64 && type < CRYPTO_MAX) { + if (pProKey == nullptr || strlen(pProKey) < MIN_IP_LEN) { + return -ERR_INPUT_PARAMS; + } + } + + g_globalConfig.proCryptoType = type; + StringCbCopy(g_globalConfig.proKeyBuf, 256, pProKey); + + SPDLOG_DEBUG(TEXT("Protocol crypto type: {0} with key [{1}]"), static_cast(type), + pProKey ? pProKey : TEXT("")); + + return ERR_SUCCESS; +} + +//int CheckSystemMinDepend(CHECK_FUNCTION chkItem, TCHAR* pErrMsg, errMsg[MAX_PATH], ); + +int CheckSystemMinRequired(CHK_RESULT chkResult[CHK_MAX]) { + for (int i = 0; i < CHK_MAX; i++) { + const PCHK_RESULT pChk = &chkResult[i]; + + pChk->chk = static_cast(i); + pChk->result = true; + memset(pChk->errMsg, 0, MAX_PATH); + switch (pChk->chk) { + case CHK_SYSTEM_INIT: + if (lstrlen(g_globalConfig.configDirectory) == 0) { + pChk->result = false; + StringCbCopy(pChk->errMsg, MAX_PATH, + TEXT("错误: SDK 未初始化,请先调用 TunnelSDKInitEnv 接口执行初始化操作。")); + SPDLOG_ERROR(pChk->errMsg); + } + break; + case CHK_WIREGUARD_CONFIG: { + TCHAR cfgVal[MAX_PATH]; + GetPrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WGCFG_PATH, TEXT(""), cfgVal, MAX_PATH, + GetGlobalCfgInfo()->cfgPath); + + if (!PathFileExists(cfgVal)) { + pChk->result = false; + StringCbCopy(pChk->errMsg, MAX_PATH, + TEXT("错误: 未找到 WireGuard 配置文件,请先调用 WireGuardInstallServerService 或者 " + "WireGuardCreateClientConfig 接口创建合法的 WireGuard 配置文件。")); + SPDLOG_ERROR(pChk->errMsg); + } + } break; + case CHK_WIREGUARD_SERVICE: { + int ret; + bool bInstall; + ret = IsWireGuardServerInstalled(&bInstall); + if (ret != ERR_SUCCESS) { + pChk->result = false; + StringCbPrintf(pChk->errMsg, MAX_PATH, + TEXT("错误: 获取系统 WireGuard 服务安装状态异常, 错误码:{0}。"), ret); + SPDLOG_ERROR(pChk->errMsg); + } + + if (!bInstall) { + pChk->result = false; + StringCbCopy(pChk->errMsg, MAX_PATH, + TEXT("错误:系统 WireGuard 服务未安装,请调用 WireGuardInstallServerService 接口安装 " + "WireGuard 服务。")); + SPDLOG_ERROR(pChk->errMsg); + } + } break; + case CHK_WG_INTERFACE_PRIVATE: + if (!chkResult[CHK_WIREGUARD_SERVICE].chk) { + pChk->result = false; + StringCbCopy(pChk->errMsg, MAX_PATH, + TEXT("错误:系统 WireGuard 服务未安装,请调用 WireGuardInstallServerService 接口安装 " + "WireGuard 服务。")); + SPDLOG_ERROR(pChk->errMsg); + } + break; + case CHK_MAX: + break; + } + } + + return ERR_SUCCESS; +} \ No newline at end of file diff --git a/NetTunnelSDK/tunnel/wireguard.cpp b/NetTunnelSDK/tunnel/wireguard.cpp new file mode 100644 index 0000000..dfcfe01 --- /dev/null +++ b/NetTunnelSDK/tunnel/wireguard.cpp @@ -0,0 +1,774 @@ +#include "pch.h" +#include "tunnel.h" +#include "usrerr.h" +#include +#include +#include + +#include "globalcfg.h" +#include "misc.h" +#include "network.h" + +#pragma comment(lib, "Shlwapi.lib") +#pragma comment(lib, "Winmm.lib") + +static NET_SHARE_MODE g_CurShareMode = ICS_SHARE_MODE; + +NET_SHARE_MODE GetCurrentNetShareMode() { + return g_CurShareMode; +} + +void SetCurrentNetShareMode(NET_SHARE_MODE shareMode) { + g_CurShareMode = shareMode; +} + +int GetWireGuardWorkMode(bool *pIsWorkServer) { + if (pIsWorkServer == nullptr) { + SPDLOG_ERROR(TEXT("Input pIsWorkServer params error")); + return -ERR_INPUT_PARAMS; + } + + *pIsWorkServer = GetGlobalCfgInfo()->isWorkServer; + + return ERR_SUCCESS; +} + +int WireGuardInstallDefaultServerService(bool bInstall) { + TCHAR cfgVal[MAX_PATH]; + + GetPrivateProfileString(CFG_WIREGUARD_SECTION, + CFG_WGCFG_PATH, + TEXT(""), + cfgVal, + MAX_PATH, + GetGlobalCfgInfo()->cfgPath); + + if (lstrlen(cfgVal) > 0) { + if (PathFileExists(cfgVal)) { + int ret; + TCHAR svrName[MAX_PATH]; + + StringCbCopy(svrName, MAX_PATH, cfgVal); + PathStripPath(svrName); + PathRemoveExtension(svrName); + + if (bInstall) { + ret = WireGuardInstallServerService(cfgVal); //CreateWireGuardService(svrName, cfgVal); + } else { + ret = RemoveGuardService(svrName, true); + } + + if (bInstall && ret == ERR_SUCCESS) { + int retry = 10; + do { + ret = WaitNetAdapterConnected(svrName, 1000); + } while (ret != ERR_SUCCESS && retry--); + } + + return ret; + } else { + SPDLOG_ERROR(TEXT("WireGuard configure file [{0}] not found"), cfgVal); + return -ERR_FILE_NOT_EXISTS; + } + } else { + SPDLOG_ERROR(TEXT("Configure [{0}] = {1} not found"), CFG_WGCFG_PATH, cfgVal); + return -ERR_ITEM_UNEXISTS; + } +} + +int WireGuardInstallServerService(const TCHAR *pTunnelCfgPath) { + // 卸载服务 + TCHAR svrName[MAX_PATH]; + int ret; + + StringCbCopy(svrName, MAX_PATH, pTunnelCfgPath); + PathStripPath(svrName); + PathRemoveExtension(svrName); + + if (pTunnelCfgPath == nullptr || lstrlen(pTunnelCfgPath) == 0) { + SPDLOG_ERROR(TEXT("Input pTunnelCfgPath params error")); + return -ERR_INPUT_PARAMS; + } + if (!PathFileExists(pTunnelCfgPath)) { + SPDLOG_ERROR(TEXT("WireGuard configure file {0} unexists."), pTunnelCfgPath); + return -ERR_ITEM_UNEXISTS; + } + + if ((ret = CreateWireGuardService(svrName, pTunnelCfgPath)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Create WireGuard Service Error({0}): {1}, {2} "), ret, svrName, pTunnelCfgPath); + return ret; + } + + return ERR_SUCCESS; +} + +int WireGuardUnInstallServerService(const TCHAR *pTunnelName) { + // 卸载服务 + int ret; + + if (pTunnelName == nullptr || lstrlen(pTunnelName) == 0) { + SPDLOG_ERROR(TEXT("Input pTunnelName params error")); + return -ERR_INPUT_PARAMS; + } + + if ((ret = RemoveGuardService(pTunnelName, true)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Stop WireGuard Service Error: {0}"), ret); + return ret; + } + + return ERR_SUCCESS; +} + +int IsWireGuardServerInstalled(bool *pIsInstalled) { + DWORD dwStatus; + int ret; + + if (pIsInstalled == nullptr) { + SPDLOG_ERROR(TEXT("Input pIsInstalled params error")); + return -ERR_INPUT_PARAMS; + } + + *pIsInstalled = false; + + ret = GetWindowsServiceStatus(TEXT("WireGuard"), &dwStatus); + + if (ret == ERR_SUCCESS) { + switch (dwStatus) { + case SERVICE_CONTINUE_PENDING: + case SERVICE_RUNNING: + case SERVICE_START_PENDING: + *pIsInstalled = true; + break; + default: + *pIsInstalled = false; + break; + } + } + + return ret; +} + +int IsWireGuardServerRunning(const TCHAR *pIfName, bool *pIsRunning) { + return GetWireGuardServiceStatus(pIfName, pIsRunning); +} + +int WireGuardCreateClientConfig(const PWGCLIENT_CONFIG pWgConfig) { + const size_t bufSize = 4096 * sizeof(TCHAR); + const TCHAR cfgFormat[] = TEXT( + "[Interface]\nPrivateKey = %s\nAddress = %s\n\n[Peer]\nPublicKey = %s\nAllowedIPs = %s\nEndpoint = " + "%s\nPersistentKeepalive = 30\n"); + TCHAR cfgPath[MAX_PATH]; + size_t length; + HANDLE hFile; + TCHAR *pBuf; + +#pragma region + if (pWgConfig == nullptr) { + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->Name, 64, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Name error: {0}"), pWgConfig->Name); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->Address, 32, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Address error: {0}"), pWgConfig->Address); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->PrivateKey, 64, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Private key error: {0}"), pWgConfig->PrivateKey); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->SvrPubKey, 64, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Server Public key error: {0}"), pWgConfig->SvrPubKey); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->AllowNet, 256, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Allow Client Network error: {0}"), pWgConfig->AllowNet); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->ServerURL, 256, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Server Network error: {0}"), pWgConfig->ServerURL); + return -ERR_INPUT_PARAMS; + } +#pragma endregion 参数检查 + + pBuf = static_cast(malloc(bufSize)); + + if (pBuf == nullptr) { + SPDLOG_ERROR(TEXT("Malloc {1} bytes memory error: {0}"), GetLastError(), bufSize); + return -ERR_MALLOC_MEMORY; + } + + memset(pBuf, 0, bufSize); + + StringCbPrintf(cfgPath, + MAX_PATH, + "%s\\%s", + GetGlobalCfgInfo()->configDirectory, + GetGlobalCfgInfo()->userCfg.userName); + + // 如果当前用户配置目录不存在则自动创建 + if (!PathFileExists(cfgPath)) { + if (!CreateDirectory(cfgPath, nullptr)) { + SPDLOG_ERROR(TEXT("Create configure directory '{0}' error."), cfgPath); + return -ERR_CREATE_FILE; + } + } + + StringCbPrintf(cfgPath, + MAX_PATH, + "%s\\%s\\%s.conf", + GetGlobalCfgInfo()->configDirectory, + GetGlobalCfgInfo()->userCfg.userName, + pWgConfig->Name); + + hFile = CreateFile(cfgPath, // name of the write + GENERIC_WRITE | GENERIC_READ, // open for writing + FILE_SHARE_READ, // do not share + nullptr, // default security + CREATE_ALWAYS, // create new file only + FILE_ATTRIBUTE_NORMAL, // normal file + nullptr); // no attr. template + + if (hFile == INVALID_HANDLE_VALUE) { + SPDLOG_ERROR(TEXT("CreatFile [{0}] error: {1}"), cfgPath, GetLastError()); + free(pBuf); + return -ERR_OPEN_FILE; + } + + // 保存到配置文件中 + WritePrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WGCFG_PATH, cfgPath, GetGlobalCfgInfo()->cfgPath); + + if (FAILED(StringCbPrintf(pBuf, + bufSize, + cfgFormat, + pWgConfig->PrivateKey, + pWgConfig->Address, + pWgConfig->SvrPubKey, + pWgConfig->AllowNet, + pWgConfig->ServerURL))) { + SPDLOG_ERROR(TEXT("Format string error: {0}"), GetLastError()); + free(pBuf); + ::CloseHandle(hFile); + return -ERR_MEMORY_STR; + } + + if (FAILED(StringCbLength(pBuf, bufSize, &length))) { + SPDLOG_ERROR(TEXT("Get string \'{0}\' length error: {1}"), pBuf, GetLastError()); + free(pBuf); + ::CloseHandle(hFile); + return -ERR_MEMORY_STR; + } + + SPDLOG_DEBUG(TEXT("WG Client Configure:\n{0}"), pBuf); + + if (!WriteFile(hFile, // open file handle + pBuf, // start of data to write + static_cast(length), // number of bytes to write + nullptr, // number of bytes that were written + nullptr)) { + SPDLOG_ERROR(TEXT("WriteFile [{0}] error: {1}"), cfgPath, GetLastError()); + free(pBuf); + ::CloseHandle(hFile); + return -ERR_OPEN_FILE; + } + + StringCbCopy(GetGlobalCfgInfo()->wgClientCfg.wgName, 260, pWgConfig->Name); + StringCbCopy(GetGlobalCfgInfo()->wgClientCfg.wgIpaddr, MAX_IP_LEN, pWgConfig->Address); + StringCbCopy(GetGlobalCfgInfo()->wgClientCfg.wgCfgPath, MAX_PATH, cfgPath); + + ::CloseHandle(hFile); + return ERR_SUCCESS; +} + +int WireGuardCreateServerConfig(const PWGSERVER_CONFIG pWgConfig) { + const size_t bufSize = 4096 * sizeof(TCHAR); + const TCHAR cfgFormat[] = TEXT( + "[Interface]\nAddress = %s\nListenPort = %d\nPrivateKey = %s\n\n[Peer]\nPublicKey = %s\nAllowedIPs = %s\n"); + TCHAR cfgPath[MAX_PATH]; + size_t length; + HANDLE hFile; + TCHAR *pBuf; + +#pragma region + if (pWgConfig == nullptr) { + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->Name, 64, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Name error: {0}"), pWgConfig->Name); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->Address, 32, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Address error: {0}"), pWgConfig->Address); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->PrivateKey, 64, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Private key error: {0}"), pWgConfig->PrivateKey); + return -ERR_INPUT_PARAMS; + } + + if (pWgConfig->ListenPort <= 1024 || pWgConfig->ListenPort >= 65535) { + SPDLOG_ERROR(TEXT("WireGuard Listen port error: {0}, should be in arrange (1024, 65535)"), + pWgConfig->ListenPort); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->CliPubKey, 64, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Client Public key error: {0}"), pWgConfig->CliPubKey); + return -ERR_INPUT_PARAMS; + } + + if (FAILED(StringCbLength(pWgConfig->AllowNet, 256, &length)) || 0 == length) { + SPDLOG_ERROR(TEXT("WireGuard Allow Client Network error: {0}"), pWgConfig->AllowNet); + return -ERR_INPUT_PARAMS; + } +#pragma endregion 参数检查 + + pBuf = static_cast(malloc(bufSize)); + + if (pBuf == nullptr) { + SPDLOG_ERROR(TEXT("Malloc {1} bytes memory error: {0}"), GetLastError(), bufSize); + return -ERR_MALLOC_MEMORY; + } + + memset(pBuf, 0, bufSize); + + StringCbPrintf(cfgPath, + MAX_PATH, + "%s\\%s", + GetGlobalCfgInfo()->configDirectory, + GetGlobalCfgInfo()->userCfg.userName); + + // 如果当前用户配置目录不存在则自动创建 + if (!PathFileExists(cfgPath)) { + if (!CreateDirectory(cfgPath, nullptr)) { + SPDLOG_ERROR(TEXT("Create configure directory '{0}' error."), cfgPath); + return -ERR_CREATE_FILE; + } + } + + StringCbPrintf(cfgPath, + MAX_PATH, + "%s\\%s\\%s.conf", + GetGlobalCfgInfo()->configDirectory, + GetGlobalCfgInfo()->userCfg.userName, + pWgConfig->Name); + + hFile = CreateFile(cfgPath, // name of the write + GENERIC_WRITE | GENERIC_READ, // open for writing + FILE_SHARE_READ, // do not share + nullptr, // default security + CREATE_ALWAYS, // create new file only + FILE_ATTRIBUTE_NORMAL, // normal file + nullptr); // no attr. template + + if (hFile == INVALID_HANDLE_VALUE) { + SPDLOG_ERROR(TEXT("CreatFile [{0}] error: {1}"), cfgPath, GetLastError()); + free(pBuf); + return -ERR_OPEN_FILE; + } + + // 清空文件 + SetFilePointer(hFile, 0, nullptr, FILE_BEGIN); + SetEndOfFile(hFile); + + WritePrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WGCFG_PATH, cfgPath, GetGlobalCfgInfo()->cfgPath); + + if (FAILED(StringCbPrintf(pBuf, + bufSize, + cfgFormat, + pWgConfig->Address, + pWgConfig->ListenPort, + pWgConfig->PrivateKey, + pWgConfig->CliPubKey, + pWgConfig->AllowNet))) { + SPDLOG_ERROR(TEXT("Format string error: {0}"), GetLastError()); + free(pBuf); + ::CloseHandle(hFile); + return -ERR_MEMORY_STR; + } + + if (FAILED(StringCbLength(pBuf, bufSize, &length))) { + SPDLOG_ERROR(TEXT("Get string \'{0}\' length error: {1}"), pBuf, GetLastError()); + free(pBuf); + ::CloseHandle(hFile); + return -ERR_MEMORY_STR; + } + + SPDLOG_DEBUG(TEXT("WG Server Configure:\n{0}"), pBuf); + + if (FALSE == + WriteFile(hFile, // open file handle + pBuf, // start of data to write + static_cast(length), // number of bytes to write + nullptr, // number of bytes that were written + nullptr)) // no overlapped structure) + { + SPDLOG_ERROR(TEXT("WriteFile [{0}] error: {1}"), cfgPath, GetLastError()); + free(pBuf); + ::CloseHandle(hFile); + return -ERR_OPEN_FILE; + } + + ::CloseHandle(hFile); + + StringCbCopy(GetGlobalCfgInfo()->wgServerCfg.wgName, 260, pWgConfig->Name); + StringCbCopy(GetGlobalCfgInfo()->wgServerCfg.wgIpaddr, MAX_IP_LEN, pWgConfig->Address); + StringCbCopy(GetGlobalCfgInfo()->wgServerCfg.wgCfgPath, MAX_PATH, cfgPath); + return ERR_SUCCESS; +} + +#if 0 + +/** + * @brief 创建 WireGuard 密钥对 + * @param[out] pPubKey 公钥缓冲区 + * @param[in] pubkeySize 公钥缓冲区大小(字节数) + * @param[out] pPrivKey 私钥缓冲区 + * @param[in] privKeySize 私钥缓冲区大小(字节数) + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_ITEM_UNEXISTS WireGuard 未配置或未安装 + * - -ERR_CALL_SHELL 调用操作系统命令行工具失败 + * - ERR_SUCCESS 成功 + */ +int GenerateWireguardKeyPairs(TCHAR *pPubKey, int pubkeySize, TCHAR *pPrivKey, int privKeySize) { + + int ret; + DWORD retCode; + TCHAR cmdBuffer[MAX_PATH]; + TCHAR cmdResult[MAX_PATH]; + PSDK_CONFIG pCfg = GetGlobalCfgInfo(); + + // WireGuard 不存在或者未配置目录 + if (!pCfg->wireguardCfg.wgExists || !pCfg->wireguardCfg.wireguardExists) { + return -ERR_ITEM_UNEXISTS; + } + + memset(cmdBuffer, 0, MAX_PATH); + memset(cmdResult, 0, MAX_PATH); + + StringCbPrintf(cmdBuffer, MAX_PATH, TEXT("cmd.exe /C \"%s\" genkey"), pCfg->wireguardCfg.wgPath); + + if ((ret = RunCommand(cmdBuffer, cmdResult, MAX_PATH, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Run command [{0}] error: {1}"), cmdBuffer, ret); + return -ERR_CALL_SHELL; + } + + SPDLOG_DEBUG(TEXT("Run command [{0}] resutl \'{1}\'"), cmdBuffer, cmdResult); + + StringCbCopy(pPrivKey, privKeySize, cmdResult); + memset(cmdBuffer, 0, MAX_PATH); + StringCbPrintf(cmdBuffer, + MAX_PATH, + TEXT("cmd.exe /C echo %s | \"%s\" pubkey"), + cmdResult, + pCfg->wireguardCfg.wgPath); + + memset(cmdResult, 0, MAX_PATH); + if ((ret = RunCommand(cmdBuffer, cmdResult, MAX_PATH, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Run command [{0}] error: {1}"), cmdBuffer, ret); + return -ERR_CALL_SHELL; + } + + StringCbCopy(pPubKey, pubkeySize, cmdResult); + SPDLOG_DEBUG(TEXT("Run command [{0}] resutl \'{1}\'"), cmdBuffer, cmdResult); + + return ERR_SUCCESS; +} + +/** + * @brief 设置 wireguard.exe 程序路径 + * @param[in] pPath wireguard.exe 程序路径 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_ITEM_UNEXISTS 文件不存在 + * - ERR_SUCCESS 成功 + */ +int SetWireguardPath(const TCHAR *pPath) { + if (pPath == nullptr) { + return -ERR_INPUT_PARAMS; + } + + if (PathFileExists(pPath)) { + TCHAR wgPath[MAX_PATH]; + + SPDLOG_DEBUG(TEXT("Used configure file:{0}"), GetGlobalCfgInfo()->cfgPath); + + WritePrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WIREGUARD_PATH, pPath, GetGlobalCfgInfo()->cfgPath); + SPDLOG_DEBUG(TEXT("Save configure: {1} --> {0}"), pPath, CFG_WIREGUARD_PATH); + + StringCbCopy(wgPath, MAX_PATH, pPath); + + if (TCHAR *pIndex = _tcsrchr(wgPath, '\\')) { + *pIndex = 0; + StringCbCat(wgPath, MAX_PATH, "\\wg.exe"); + WritePrivateProfileString(CFG_WIREGUARD_SECTION, CFG_WG_PATH, wgPath, GetGlobalCfgInfo()->cfgPath); + SPDLOG_DEBUG(TEXT("Save configure: {1} --> {0}"), wgPath, CFG_WG_PATH); + } + + return ERR_SUCCESS; + } else { + SPDLOG_ERROR(TEXT("WireGuard not found: {0}"), pPath); + return -ERR_ITEM_UNEXISTS; + } +} + +/** + * @brief 查找 WireGuard 运行环境 + * @param[out] pFullPath wireguard.exe 程序路径 + * @param[in] maxSize pFullPath 缓冲区最大字节数 + * @return 函数执行结果 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_MALLOC_MEMORY 分配内存失败 + * - -ERR_FILE_NOT_EXISTS 文件不存在 + * - ERR_SUCCESS 成功 + */ +int FindWireguardExe(TCHAR *pFullPath, int maxSize) { + TCHAR path[MAX_PATH]; + TCHAR wireguardPath[MAX_PATH]; + DWORD dwRet; + LPSTR pEnvBuf; + TCHAR *token, *p = nullptr; + + GetPrivateProfileString(CFG_WIREGUARD_SECTION, + CFG_WIREGUARD_PATH, + TEXT(""), + wireguardPath, + MAX_PATH, + GetGlobalCfgInfo()->cfgPath); + + if (PathFileExists(wireguardPath)) { + if (pFullPath && maxSize > 0) { + StringCbCopy(pFullPath, maxSize, wireguardPath); + } + + StringCbCopy(GetGlobalCfgInfo()->wireguardCfg.wireguardPath, MAX_PATH, wireguardPath); + GetGlobalCfgInfo()->wireguardCfg.wireguardExists = TRUE; + + SPDLOG_DEBUG(TEXT("Ini found WireGuard at: {0}"), wireguardPath); + + GetPrivateProfileString(CFG_WIREGUARD_SECTION, + CFG_WG_PATH, + TEXT(""), + wireguardPath, + MAX_PATH, + GetGlobalCfgInfo()->cfgPath); + + if (PathFileExists(wireguardPath)) { + StringCbCopy(GetGlobalCfgInfo()->wireguardCfg.wgPath, MAX_PATH, wireguardPath); + GetGlobalCfgInfo()->wireguardCfg.wgExists = TRUE; + SPDLOG_DEBUG(TEXT("Ini found WireGuard Tools at: {0}"), wireguardPath); + } + + return ERR_SUCCESS; + } + + StringCbCopy(wireguardPath, MAX_PATH, GetGlobalCfgInfo()->systemDirectory); + PathStripToRoot(wireguardPath); + StringCbCat(wireguardPath, MAX_PATH, TEXT("Program Files\\WireGuard\\wireguard.exe")); + + if (PathFileExists(wireguardPath)) { + // 保存路径到配置文件 + SetWireguardPath(wireguardPath); + GetGlobalCfgInfo()->wireguardCfg.wireguardExists = TRUE; + + PathRemoveFileSpec(wireguardPath); + StringCbCat(wireguardPath, MAX_PATH, TEXT("\\wg.exe")); + + if (PathFileExists(wireguardPath)) { + GetGlobalCfgInfo()->wireguardCfg.wgExists = TRUE; + } + + return ERR_SUCCESS; + } + + // 从环境变量中查找 + pEnvBuf = static_cast(malloc(WINENVBUF_SIZE)); + if (nullptr == pEnvBuf) { + SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), WINENVBUF_SIZE); + return -ERR_MALLOC_MEMORY; + } + + dwRet = GetEnvironmentVariable(TEXT("path"), pEnvBuf, WINENVBUF_SIZE); + + if (0 == dwRet) { + DWORD dwErr; + dwErr = GetLastError(); + if (ERROR_ENVVAR_NOT_FOUND == dwErr) { + SPDLOG_DEBUG(TEXT("Environment variable path does not exist.")); + free(pEnvBuf); + return -ERR_FILE_NOT_EXISTS; + } + } else if (WINENVBUF_SIZE < dwRet) { + const auto pBuf = static_cast(realloc(pEnvBuf, dwRet * sizeof(CHAR))); + if (nullptr == pBuf) { + SPDLOG_ERROR(TEXT("Malloc {0} bytes memory error"), dwRet * sizeof(CHAR)); + free(pEnvBuf); + return -ERR_MALLOC_MEMORY; + } + + pEnvBuf = pBuf; + dwRet = GetEnvironmentVariable("path", pEnvBuf, dwRet); + if (!dwRet) { + SPDLOG_ERROR(TEXT("GetEnvironmentVariable failed (%d)"), GetLastError()); + free(pEnvBuf); + return -ERR_FILE_NOT_EXISTS; + } + } + + token = strtok_s(pEnvBuf, TEXT(";"), &p); + + while (token != nullptr) { + memset(path, 0, MAX_PATH); + StringCbPrintfA(path, MAX_PATH, TEXT("%s\\wireguard.exe"), token); + + if (PathFileExists(path)) { + if (pFullPath && maxSize > 0) { + StringCbCopy(pFullPath, maxSize, path); + } + + // 保存路径到配置文件 + SetWireguardPath(path); + SPDLOG_DEBUG(TEXT("Path Environment found WireGuard at: {0}"), path); + + StringCbCopy(GetGlobalCfgInfo()->wireguardCfg.wireguardPath, MAX_PATH, wireguardPath); + GetGlobalCfgInfo()->wireguardCfg.wireguardExists = TRUE; + + memset(path, 0, MAX_PATH); + StringCbPrintf(path, MAX_PATH, TEXT("%s\\wg.exe"), token); + + SPDLOG_DEBUG(TEXT("Find WireGuard tools at: {0}"), path); + + if (PathFileExists(path)) { + StringCbCopy(GetGlobalCfgInfo()->wireguardCfg.wgPath, MAX_PATH, path); + GetGlobalCfgInfo()->wireguardCfg.wgExists = TRUE; + + SPDLOG_DEBUG(TEXT("Path Environment found WireGuard tools at: {0}"), path); + } + + //TODO: throw exception by C# call, why?????? + //CloseHandle(hFind); + free(pEnvBuf); + return ERR_SUCCESS; + } + token = strtok_s(nullptr, TEXT(";"), &p); + } + + free(pEnvBuf); + return -ERR_FILE_NOT_EXISTS; +} + +/** + * @brief 安装 Windows NAT 自定义 PowerShell 命令 + * @return 0: 成功, 小于0 失败 @see USER_ERRNO + * - -ERR_INPUT_PARAMS 输入参数错误 + * - -ERR_OPEN_FILE 打开文件失败 + * - -ERR_MEMORY_STR 字符串处理 + * - -ERR_ITEM_UNEXISTS 资源不存在 + * - -ERR_CALL_SHELL 调用系统命令行失败 + * - ERR_SUCCESS 成功 + */ +int InstallWindowsNATCommand() { + TCHAR psCmdPath[MAX_PATH]; + + // 如果已经安装则退出 + if (IsCustomNatPSCmdInstalled()) { + return ERR_SUCCESS; + } + + StringCbPrintf(psCmdPath, + MAX_PATH, + TEXT("%s\\system32\\WindowsPowerShell\\v1.0\\Modules\\wireguard\\wireguard.psm1"), + GetGlobalCfgInfo()->systemDirectory); + + const HANDLE hFile = CreateFile(psCmdPath, + GENERIC_WRITE | GENERIC_READ, + FILE_SHARE_READ, + nullptr, + CREATE_ALWAYS, + FILE_ATTRIBUTE_NORMAL, + nullptr); + + if (hFile == INVALID_HANDLE_VALUE) { + SPDLOG_ERROR("CreatFile [{0}] error: {1}", psCmdPath, GetLastError()); + return -ERR_OPEN_FILE; + } + + const HMODULE hMod = GetModuleHandle(TEXT("NetTunnelSDK.dll")); + + if (nullptr == hMod) { + SPDLOG_ERROR(TEXT("Load NetTunnelSDK.dll module error")); + ::CloseHandle(hFile); + return -ERR_ITEM_UNEXISTS; + } + + const HRSRC hRsrc = FindResource(hMod, MAKEINTRESOURCE(PSCMD_RES_ID), TEXT("TXT")); + if (nullptr == hRsrc) { + SPDLOG_ERROR(TEXT("Donot found resource {0} of type {1}"), PSCMD_RES_ID, TEXT("TXT")); + ::CloseHandle(hFile); + return -ERR_ITEM_UNEXISTS; + } + + const DWORD resSize = SizeofResource(hMod, hRsrc); + + if (resSize == 0) { + SPDLOG_ERROR(TEXT("Resource {0} of type {1} is empty"), PSCMD_RES_ID, TEXT("TXT")); + ::CloseHandle(hFile); + return -ERR_ITEM_UNEXISTS; + } + + const HGLOBAL hGlobal = LoadResource(hMod, hRsrc); + + if (hGlobal == hRsrc) { + SPDLOG_ERROR(TEXT("Load resource {0} of type {1} error"), PSCMD_RES_ID, TEXT("TXT")); + ::CloseHandle(hFile); + return -ERR_ITEM_UNEXISTS; + } + + if (const LPVOID pBuffer = LockResource(hGlobal)) { + if (!WriteFile(hFile, // open file handle + pBuffer, // start of data to write + static_cast(resSize), // number of bytes to write + nullptr, // number of bytes that were written + nullptr)) { + SPDLOG_ERROR("WriteFile [{0}] error: {1}", psCmdPath, GetLastError()); + GlobalUnlock(hGlobal); + ::CloseHandle(hFile); + return -ERR_OPEN_FILE; + } + } + + GlobalUnlock(hGlobal); + ::CloseHandle(hFile); + + return ERR_SUCCESS; +} + +int WireGuardNetConnectionSharingEnable() { + int ret; + DWORD retCode; + TCHAR cmdResult[MAX_PATH] = {}; + TCHAR cmdBuf[] = TEXT("PowerShell -NoProfile -NonInteractive -WindowStyle Hidden -ExecutionPolicy Bypass Invoke-Command -ArgumentList 'wg_cli' -ScriptBlock {param($IFNAME);$netShare = New-Object -ComObject HNetCfg.HNetShare;" + "$privateConnection = $netShare.EnumEveryConnection |? { $netShare.NetConnectionProps.Invoke($_).Name -eq 'wg_cli' };" + "$privateConfig = $netShare.INetSharingConfigurationForINetConnection.Invoke($privateConnection);" + "Write-Output $privateConfig}"); + + if ((ret = RunCommand(cmdBuf, cmdResult, MAX_PATH, &retCode)) != ERR_SUCCESS) { + SPDLOG_ERROR("Run command [{0}] error: {1}", cmdBuf, ret); + return -ERR_CALL_SHELL; + } + + SPDLOG_DEBUG("Run command [{0}] resutl \'{1}\' return {2}", cmdBuf, cmdResult, retCode); + + return ERR_SUCCESS; +} +#endif \ No newline at end of file diff --git a/NetTunnelSDK/user/UserManager.cpp b/NetTunnelSDK/user/UserManager.cpp new file mode 100644 index 0000000..f66e933 --- /dev/null +++ b/NetTunnelSDK/user/UserManager.cpp @@ -0,0 +1,685 @@ +#include "pch.h" + +#include "tunnel.h" +#include "usrerr.h" +#include "globalcfg.h" +#include "httplib.h" +#include "misc.h" +#include "network.h" +#include "protocol.h" +#include "user.h" + +#include +#include + +static HANDLE g_HeartTimerQueue = nullptr; +static LPTUNNEL_HEART_ROUTINE g_lpHeartCb = nullptr; + +int RemoteHeartControl(bool isStart, LPTUNNEL_HEART_ROUTINE lpHeartCbAddress) { + if (isStart && lpHeartCbAddress == nullptr) { + SPDLOG_ERROR(TEXT("Input lpHeartCbAddress params nullptr")); + return -ERR_INPUT_PARAMS; + } + + g_lpHeartCb = lpHeartCbAddress; + + if (isStart) { + if (!g_HeartTimerQueue) { + HANDLE hTimer = nullptr; + // Create the timer queue. + g_HeartTimerQueue = CreateTimerQueue(); + if (nullptr == g_HeartTimerQueue) { + SPDLOG_ERROR(TEXT("CreateTimerQueue failed ({0})"), GetLastError()); + return -ERR_CREATE_TIMER; + } + + // Set a timer to call the timer routine in 10 seconds. + if (!CreateTimerQueueTimer( + &hTimer, + g_HeartTimerQueue, + [](PVOID lpParam, BOOLEAN TimerOrWaitFired) { + int ret; + ProtocolRequest req; + ProtocolResponse rsp; + + ret = ProtolPostMessage(SET_CLIENTHEART_PATH, &req, &rsp, false); + if (g_lpHeartCb && ret) { + g_lpHeartCb(rsp.msgContent.message.c_str(), rsp.timeStamp); + } + }, + nullptr, + 0, + HEART_PERIOD_MS, + WT_EXECUTEDEFAULT)) { + SPDLOG_ERROR(TEXT("CreateTimerQueueTimer failed ({0})"), GetLastError()); + return -ERR_CREATE_TIMER; + } + } + } else { + if (g_HeartTimerQueue) { + if (!DeleteTimerQueue(g_HeartTimerQueue)) { + SPDLOG_ERROR(TEXT("DeleteTimerQueue failed ({0})"), GetLastError()); + g_HeartTimerQueue = nullptr; + return -ERR_DELETE_TIMER; + } + + g_HeartTimerQueue = nullptr; + } + } + + return ERR_SUCCESS; +} + +int RemoteWireGuardControl(bool isStart) { + int ret; + ProtocolRequest req; + ProtocolResponse rsp; + + req.msgContent.isStart = isStart; + + ret = ProtolPostMessage(SET_CLIENTSTART_TUNNEL, &req, &rsp, false); + + if (ret != ERR_SUCCESS) { + return ret; + } + + if (rsp.msgContent.errCode != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Service Response error({0}): {1}"), rsp.msgContent.errCode, rsp.msgContent.errMessage); + return rsp.msgContent.errCode; + } + + return ERR_SUCCESS; +} + +int LocalWireGuardControl(bool isStart, bool setPrivateMode) { + int ret; + bool chkStatus = false; + int ifInetlnetIndex, ifWireGuardIndex; + + // 获取 Intelnet 网络网卡 Index + if ((ret = GetInternetIfIndex(&ifInetlnetIndex)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call GetInternetIfIndex error: {0}"), ret); + return ret; + } + + // 判断先前是否启动过服务 + if ((ret = IsWireGuardServerRunning(GetGlobalCfgInfo()->userCfg.userName, &chkStatus)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call IsWireGuardServerInstalled error: {0}"), ret); + return ret; + } + + // 先停止以前启动的隧道服务 + if (chkStatus) { + if ((ret = WireGuardUnInstallServerService(GetGlobalCfgInfo()->userCfg.userName)) != ERR_SUCCESS) { + // 返回停止服务失败 + SPDLOG_ERROR(TEXT("Call WireGuardUnInstallServerService error: {0}"), ret); + return ret; + } + } + + // 检查 Internet 网络共享状态 + if ((ret = GetNetIntelnetConnectionSharing(ifInetlnetIndex, &chkStatus)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call GetNetIntelnetConnectionSharing error: {0}"), ret); + return ret; + } + + // 关闭 Intelnet 网络连接共享 + if (chkStatus) { + if ((ret = SetNetIntelnetConnectionSharing(ifInetlnetIndex, false, false)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call SetNetIntelnetConnectionSharing error: {0}"), ret); + return ret; + } + } + + if (isStart) { + // 启动服务 + ret = WireGuardInstallDefaultServerService(true); + if (ret != ERR_SUCCESS) { + // 返回启动服务失败 + SPDLOG_ERROR(TEXT("Call WireGuardInstallDefaultServerService error: {0}"), ret); + return ret; + } + + if (GetCurrentNetShareMode() == ICS_SHARE_MODE) { + // 获取 WireGuard 隧道网络网卡 Index + if ((ret = GetInterfaceIfIndexByName(GetGlobalCfgInfo()->userCfg.userName, &ifWireGuardIndex)) != + ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call GetInterfaceIfIndexByName error: {0}"), ret); + return ret; + } + + // 启动 WireGard 网络 ICS 服务为私有网络 + if ((ret = SetNetIntelnetConnectionSharing(ifWireGuardIndex, true, true)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call SetNetIntelnetConnectionSharing error: {0}"), ret); + return ret; + } + + // 启动 Intelnet 网络 ICS 服务为公共网络 + if ((ret = SetNetIntelnetConnectionSharing(ifInetlnetIndex, true, false)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call SetNetIntelnetConnectionSharing error: {0}"), ret); + return ret; + } + + // 校验 ICS 共享状态 + // 检查 WireGuard 网络共享状态 + if ((ret = GetNetIntelnetConnectionSharing(ifWireGuardIndex, &chkStatus)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call GetNetIntelnetConnectionSharing error: {0}"), ret); + return ret; + } + + if (!chkStatus) { + SPDLOG_ERROR(TEXT("WireGuard network ICS error")); + return -ERR_NET_WIREGUARD_ICS; + } + + // 检查 WireGuard 网络共享状态 + if ((ret = GetNetIntelnetConnectionSharing(ifInetlnetIndex, &chkStatus)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call GetNetIntelnetConnectionSharing error: {0}"), ret); + return ret; + } + + if (!chkStatus) { + SPDLOG_ERROR(TEXT("Internet network ICS error")); + return -ERR_NET_WIREGUARD_ICS; + } + + SPDLOG_INFO(TEXT("Net Share Service Work now on ICS mode: {0}"), GetGlobalCfgInfo()->userCfg.userName); + } else if (GetCurrentNetShareMode() == NAT_SHARE_MODE) { + IP_INFO ipInfo; + TCHAR ipNat[MAX_IP_LEN]; + GetIpV4InfoFromCIDR(GetGlobalCfgInfo()->userCfg.cliConfig.cliAddress, &ipInfo); + StringCbPrintf(ipNat, MAX_IP_LEN, TEXT("%s/%d"), ipInfo.hostmax, ipInfo.prefix); + + // 检查 WireGuard 网络共享状态 + if ((ret = SetNATRule(GetGlobalCfgInfo()->userCfg.userName, ipNat)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call GetNetIntelnetConnectionSharing error: {0}"), ret); + return ret; + } + + SPDLOG_INFO(TEXT("Net Share Service Work now on NAT mode: {0}_nat --> {1}"), + GetGlobalCfgInfo()->userCfg.userName, + ipNat); + } else { + SPDLOG_ERROR(TEXT("Not support Net Share Type: {0}"), static_cast(GetCurrentNetShareMode())); + return -ERR_UN_SUPPORT; + } + } else { + if (GetCurrentNetShareMode() == NAT_SHARE_MODE) { + // 检查 WireGuard 网络共享状态 + if ((ret = RemoveNATRule(GetGlobalCfgInfo()->userCfg.userName)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call RemoveNATRule error: {0}"), ret); + return ret; + } + } + SPDLOG_INFO(TEXT("Net Share Service Stoped: {0}"), GetGlobalCfgInfo()->userCfg.userName); + } + + return ERR_SUCCESS; +} + +int RemoteCtrlSvrCfgUserTunnel(int vmId, const TCHAR *pCliNetwork) { + const PUSER_CONFIG pUser = &GetGlobalCfgInfo()->userCfg; + const PUSER_CLIENT_CONFIG pUserCfg = &pUser->cliConfig; + + for (int i = 0; i < pUserCfg->tolVM; i++) { + if (pUserCfg->pVMConfig[i].vmId == vmId) { + IP_INFO ipInfo = {}; + int ret; + ProtocolRequest req; + ProtocolResponse rsp; + + req.msgContent.cliPublicKey = pUserCfg->cliPublicKey; + req.msgContent.cliNetwork = pCliNetwork; + GetIpV4InfoFromCIDR(pUserCfg->cliAddress, &ipInfo); + req.msgContent.cliTunnelAddr = ipInfo.ip; + + GetGlobalCfgInfo()->curConnVmId = vmId; + SPDLOG_DEBUG(TEXT("Current VMID: {0}"), vmId); + + // 连接到服务端控制服务 + InitControlServer(pUserCfg->pVMConfig[i].scgGateWay); + + // 发送本地配置参数到控制服务 + ret = ProtolPostMessage(SET_CLIENTCFG_PATH, &req, &rsp, false); + + if (ret != ERR_SUCCESS) { + return ret; + } + + if (rsp.msgContent.errCode != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Service Response error({0}): {1}"), + rsp.msgContent.errCode, + rsp.msgContent.errMessage); + return rsp.msgContent.errCode; + } + + // 返回成功配置本地参数 + ret = SetTunnelConfigure(pUserCfg->cliPrivateKey, + pUserCfg->pVMConfig[i].svrPublicKey, + pUserCfg->pVMConfig[i].vmNetwork, + pUserCfg->cliAddress, + rsp.msgContent.svrNetwork.c_str(), + pUserCfg->pVMConfig[i].scgTunnelGw); + + if (ret != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("SetTunnelConfigure Error: {0}"), ret); + return ret; + } + + return ERR_SUCCESS; + } + } + + return -ERR_ITEM_UNEXISTS; +} + +int SetTunnelConfigure(const TCHAR *pCliPrivateKey, + const TCHAR *pSvrPublicKey, + const TCHAR *pSvrNetwork, + const TCHAR *pCliNetwork, + const TCHAR *pSvrTunnelAddr, + const TCHAR *pSvrEndPoint) { + int ret; + bool isSvrStart = false; + int ifInetlnetIndex; + IP_INFO tunnelInfo = {}; + IP_INFO svrInfo = {}; + WGCLIENT_CONFIG cliCfg = {}; + +#pragma region + + if (pCliPrivateKey == nullptr || lstrlen(pCliPrivateKey) == 0) { + SPDLOG_ERROR(TEXT("Input pCliPrivateKey error: {0}"), pCliPrivateKey); + return -ERR_INPUT_PARAMS; + } + + if (pSvrPublicKey == nullptr || lstrlen(pSvrPublicKey) == 0) { + SPDLOG_ERROR(TEXT("Input pSvrPublicKey error: {0}"), pSvrPublicKey); + return -ERR_INPUT_PARAMS; + } + + if (pSvrNetwork == nullptr || lstrlen(pSvrNetwork) == 0) { + SPDLOG_ERROR(TEXT("Input pSvrNetwork error: {0}"), pSvrNetwork); + return -ERR_INPUT_PARAMS; + } + + if (pCliNetwork == nullptr || lstrlen(pCliNetwork) == 0) { + SPDLOG_ERROR(TEXT("Input pCliNetwork error: {0}"), pCliNetwork); + return -ERR_INPUT_PARAMS; + } + + if (pSvrTunnelAddr == nullptr || lstrlen(pSvrTunnelAddr) == 0) { + SPDLOG_ERROR(TEXT("Input pSvrTunnelAddr error: {0}"), pSvrTunnelAddr); + return -ERR_INPUT_PARAMS; + } + + if (pSvrEndPoint == nullptr || lstrlen(pSvrEndPoint) == 0) { + SPDLOG_ERROR(TEXT("Input pSvrEndPoint error: {0}"), pSvrEndPoint); + return -ERR_INPUT_PARAMS; + } + +#pragma endregion 参数检查 + + // 判断先前是否启动过服务 + if ((ret = IsWireGuardServerRunning(GetGlobalCfgInfo()->userCfg.userName, &isSvrStart)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("IsWireGuardServerInstalled error: {0}"), ret); + return ret; + } + + // 停止先前隧道网络 + if (isSvrStart) { + SPDLOG_DEBUG(TEXT("WireGuardUnInstallServerService: {0}"), GetGlobalCfgInfo()->userCfg.userName); + if ((ret = WireGuardUnInstallServerService(GetGlobalCfgInfo()->userCfg.userName)) != ERR_SUCCESS) { + // 返回停止服务失败 + SPDLOG_ERROR(TEXT("WireGuardUnInstallServerService error: {0}"), ret); + return ret; + } + } + + // 获取 Intelnet 网络网卡 Index + if ((ret = GetInternetIfIndex(&ifInetlnetIndex)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call GetInternetIfIndex error: {0}"), ret); + return ret; + } + + if (GetCurrentNetShareMode() == ICS_SHARE_MODE) { + // 检查 Internet 网络共享状态 + if ((ret = GetNetIntelnetConnectionSharing(ifInetlnetIndex, &isSvrStart)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call GetNetIntelnetConnectionSharing error: {0}"), ret); + return ret; + } + + // 关闭 Intelnet 网络连接共享 + if (isSvrStart) { + if ((ret = SetNetIntelnetConnectionSharing(ifInetlnetIndex, false, false)) != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Call SetNetIntelnetConnectionSharing error: {0}"), ret); + return ret; + } + } + } + + ret = GetIpV4InfoFromCIDR(pSvrTunnelAddr, &tunnelInfo); + + if (ret != ERR_SUCCESS) { + return ret; + } + + ret = GetIpV4InfoFromCIDR(pSvrNetwork, &svrInfo); + + if (ret != ERR_SUCCESS) { + return ret; + } + + memset(&cliCfg, 0, sizeof(WGCLIENT_CONFIG)); + + StringCbCopy(cliCfg.Name, 64, GetGlobalCfgInfo()->userCfg.userName); + StringCbCopy(cliCfg.PrivateKey, 64, pCliPrivateKey); + StringCbCopy(cliCfg.Address, 32, pCliNetwork); + StringCbCopy(cliCfg.SvrPubKey, 64, pSvrPublicKey); + if (UsedSCGProxy()) { + StringCbPrintf(cliCfg.ServerURL, 256, TEXT("127.0.0.1:%d"), GetGlobalCfgInfo()->scgProxy.proxyPort); + } else { + StringCbCopy(cliCfg.ServerURL, 256, pSvrEndPoint); + } + + StringCbPrintf(cliCfg.AllowNet, + 256, + TEXT("%s/%d,%s/%d"), + tunnelInfo.network, + tunnelInfo.prefix, + svrInfo.network, + svrInfo.prefix); + + ret = WireGuardCreateClientConfig(&cliCfg); + + if (ret != ERR_SUCCESS) { + return ret; + } + + return ERR_SUCCESS; +} + +int GetUserServerConfigure(const TCHAR *pUserName, const TCHAR *pToken, PUSER_SERVER_CONFIG *pSvrCfg) { + int ret; + PUSER_CONFIG pUser = &GetGlobalCfgInfo()->userCfg; + PUSER_SERVER_CONFIG pUserCfg = &pUser->svrConfig; + +#if USER_REAL_PLATFORM + ProtocolRequest req; + ProtocolResponse rsp; +#else + PlatformReqServerCfgParms req; + PlatformRspServerCfgParams rsp; +#endif + + if (pSvrCfg == nullptr) { + SPDLOG_ERROR(TEXT("Input pSvrCfg params error")); + return -ERR_INPUT_PARAMS; + } + + if (pToken == nullptr || lstrlen(pToken) == 0) { + SPDLOG_ERROR(TEXT("Input pToken params error: {0}"), pToken); + return -ERR_INPUT_PARAMS; + } + + if (pUserName && lstrlen(pUserName) > 0) { + memset(pUser->userName, 0, MAX_PATH); + StringCbCopy(pUser->userName, MAX_PATH, pUserName); + } else { + StringCbCopy(pUser->userName, MAX_PATH, TEXT("tunnel_svr")); + } + + StringCbCopy(pUser->userToken, MAX_PATH, pToken); + +#if USER_REAL_PLATFORM + req.msgContent.token = pToken; + req.msgContent.user = pUser->userName; + + ret = ProtolPostMessage(GET_SERVERCFG_PATH, &req, &rsp, true); + + if (ret != ERR_SUCCESS) { + return ret; + } + + pUserCfg->svrListenPort = rsp.msgContent.svrListenPort; + StringCbCopy(pUserCfg->svrPrivateKey, 64, rsp.msgContent.svrPrivateKey.c_str()); + StringCbCopy(pUserCfg->svrAddress, MAX_IP_LEN, rsp.msgContent.svrAddress.c_str()); + +#else + req.vmIp = pToken; + ret = PlatformProtolPostMessage(GET_SERVERCFG_PATH, &req, &rsp); + + if (ret != ERR_SUCCESS) { + return ret; + } + + ret = strtol(rsp.code.c_str(), nullptr, 10); + + if (ret != 0) { + SPDLOG_ERROR(TEXT("Server response error code: {0}"), ret); + return -ERR_HTTP_SERVER_RSP; + } + + pUserCfg->svrListenPort = rsp.data.svrPort; + StringCbCopy(pUserCfg->svrPrivateKey, 64, rsp.data.svrPriKey.c_str()); + StringCbCopy(pUserCfg->svrAddress, MAX_IP_LEN, rsp.data.svrHost.c_str()); +#endif + + *pSvrCfg = pUserCfg; + return ERR_SUCCESS; +} + +int GetUserClientConfigure(const TCHAR *pUserName, const TCHAR *pToken, PUSER_CLIENT_CONFIG *pCliCfg) { + PUSER_CONFIG pUser = &GetGlobalCfgInfo()->userCfg; + PUSER_CLIENT_CONFIG pUserCfg = &pUser->cliConfig; + TCHAR userPath[MAX_PATH]; + int ret; + +#if USER_REAL_PLATFORM + ProtocolRequest req; + ProtocolResponse rsp; +#else + PlatformReqClientCfgParms req; + PlatformRspClientCfgParams rsp; +#endif + + if (pUserName == nullptr || lstrlen(pUserName) == 0) { + SPDLOG_ERROR(TEXT("Input pUserName params error: {0}"), pUserName); + return -ERR_INPUT_PARAMS; + } + + if (pToken == nullptr || lstrlen(pToken) == 0) { + SPDLOG_ERROR(TEXT("Input pToken params error: {0}"), pToken); + return -ERR_INPUT_PARAMS; + } + + if (pCliCfg == nullptr) { + SPDLOG_ERROR(TEXT("Input pCliCfg params error")); + return -ERR_INPUT_PARAMS; + } + + StringCbPrintf(userPath, MAX_PATH, "%s\\%s", GetGlobalCfgInfo()->configDirectory, pUserName); + // 如果配置目录不存在则自动创建 + if (!PathFileExists(userPath)) { + if (!CreateDirectory(userPath, nullptr)) { + SPDLOG_ERROR(TEXT("Create user {1} directory '{0}' error."), userPath, pUserName); + return -ERR_CREATE_FILE; + } + } + + memset(pUser->userName, 0, MAX_PATH); + if (pUserName && lstrlen(pUserName) > 0) { + StringCbCopy(pUser->userName, MAX_PATH, pUserName); + } + StringCbCopy(pUser->userToken, MAX_PATH, pToken); + +#if USER_REAL_PLATFORM + req.msgContent.token = pToken; + req.msgContent.user = pUserName; + + ret = ProtolPostMessage(GET_CLIENTCFG_PATH, &req, &rsp, true); + + if (ret != ERR_SUCCESS) { + return ret; + } + + StringCbCopy(pUserCfg->cliPrivateKey, 64, rsp.msgContent.cliPrivateKey.c_str()); + StringCbCopy(pUserCfg->cliPublicKey, 64, rsp.msgContent.cliPublicKey.c_str()); + StringCbCopy(pUserCfg->cliAddress, MAX_IP_LEN, rsp.msgContent.cliAddress.c_str()); + + if (!rsp.msgContent.vmConfig.empty()) { + PVM_CFG pVm; + unsigned int memSize = sizeof(VM_CFG) * static_cast(rsp.msgContent.vmConfig.size()); + pUserCfg->pVMConfig = static_cast(CoTaskMemAlloc(memSize)); + if (pUserCfg->pVMConfig == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory {0} bytes"), memSize); + return -ERR_MALLOC_MEMORY; + } + + memset(pUserCfg->pVMConfig, 0, memSize); + + pUserCfg->tolVM = static_cast(rsp.msgContent.vmConfig.size()); + pVm = pUserCfg->pVMConfig; + + for (auto vm : rsp.msgContent.vmConfig) { + TCHAR tmpAddr[MAX_PATH]; + pVm->vmId = vm.vmId; + StringCbCopy(pVm->vmName, MAX_PATH, vm.vmName.c_str()); + StringCbCopy(pVm->svrPublicKey, 64, vm.svrPublicKey.c_str()); + StringCbCopy(pVm->vmNetwork, MAX_IP_LEN, vm.vmNetwork.c_str()); + //StringCbCopy(pVm->scgGateWay, MAX_PATH, vm.scgGateway.c_str()); + StringCbPrintf(pVm->scgGateWay, MAX_PATH, TEXT("http://%s"), vm.scgGateway.c_str()); + StringCbCopy(tmpAddr, MAX_PATH, vm.scgGateway.c_str()); + httplib::Client cli(pVm->scgGateWay); + StringCbPrintf(pVm->scgTunnelGw, MAX_PATH, TEXT("%s:%d"), cli.host().c_str(), cli.port() - 1); + pVm++; + } + } +#else + req.userName = pUserName; + req.token = pToken; + + ret = PlatformProtolPostMessage(GET_CLIENTCFG_PATH, &req, &rsp); + + if (ret != ERR_SUCCESS) { + return ret; + } + + ret = strtol(rsp.code.c_str(), nullptr, 10); + + if (ret != 0) { + SPDLOG_ERROR(TEXT("Server response error code: {0}"), ret); + return -ERR_HTTP_SERVER_RSP; + } + + StringCbCopy(pUserCfg->cliPrivateKey, 64, rsp.data.cliPriKey.c_str()); + StringCbCopy(pUserCfg->cliPublicKey, 64, rsp.data.cliPubKey.c_str()); + StringCbCopy(pUserCfg->cliAddress, MAX_IP_LEN, rsp.data.cliHost.c_str()); + + if (!rsp.data.vmInfoList.empty()) { + PVM_CFG pVm; + unsigned int memSize = sizeof(VM_CFG) * static_cast(rsp.data.vmInfoList.size()); + pUserCfg->pVMConfig = static_cast(CoTaskMemAlloc(memSize)); + if (pUserCfg->pVMConfig == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory {0} bytes"), memSize); + return -ERR_MALLOC_MEMORY; + } + + memset(pUserCfg->pVMConfig, 0, memSize); + + pUserCfg->tolVM = static_cast(rsp.data.vmInfoList.size()); + pVm = pUserCfg->pVMConfig; + + for (auto vm : rsp.data.vmInfoList) { + pVm->vmId = vm.vmId; + StringCbCopy(pVm->vmName, MAX_PATH, vm.vmName.c_str()); + StringCbCopy(pVm->svrPublicKey, 64, vm.svrPubKey.c_str()); + StringCbCopy(pVm->vmNetwork, MAX_IP_LEN, vm.vmNetwork.c_str()); + //StringCbCopy(pVm->scgGateWay, MAX_PATH, vm.scgGateway.c_str()); + StringCbPrintf(pVm->scgGateWay, MAX_PATH, TEXT("http://%s:%d"), vm.scgIp.c_str(), vm.scgPort); +#if USED_PORTMAP_TUNNEL + StringCbPrintf(pVm->scgTunnelGw, MAX_PATH, TEXT("%s:%d"), vm.portMapIp.c_str(), vm.portMapPort); +#else + StringCbPrintf(pVm->scgTunnelGw, MAX_PATH, TEXT("%s:%d"), vm.scgIp.c_str(), vm.scgPort - 1); +#endif +#if USED_PORTMAP_TUNNEL +#endif + pVm++; + } + } + +#endif + *pCliCfg = pUserCfg; + return ERR_SUCCESS; +} + +int GetUserConfigFiles(const TCHAR *pUserName, PUSER_CFGFILE *pCfgFile, int *pItems) { + PUSER_CFGFILE pCfg; + FILE_LIST fileList = {nullptr, 0}; + TCHAR fnPath[MAX_PATH] = {}; + TCHAR cfgVal[MAX_PATH]; + bool isSelected = false; + + if (pUserName == nullptr || lstrlen(pUserName) == 0) { + SPDLOG_ERROR(TEXT("Input pUserName params error: {0}"), pUserName); + return -ERR_INPUT_PARAMS; + } + + if (pCfgFile == nullptr) { + SPDLOG_ERROR(TEXT("Input pCfgFile params error")); + return -ERR_INPUT_PARAMS; + } + + if (pItems == nullptr) { + SPDLOG_ERROR(TEXT("Input pItems params error")); + return -ERR_INPUT_PARAMS; + } + + GetPrivateProfileString(CFG_WIREGUARD_SECTION, + CFG_WGCFG_PATH, + TEXT(""), + cfgVal, + MAX_PATH, + GetGlobalCfgInfo()->cfgPath); + + if (PathFileExists(cfgVal)) { + isSelected = true; + } + + StringCbPrintf(fnPath, MAX_PATH, "%s\\%s\\*.conf", GetGlobalCfgInfo()->configDirectory, pUserName); + + int ret = FindFile(fnPath, &fileList, false); + + if (ret != ERR_SUCCESS) { + SPDLOG_ERROR(TEXT("Find WireGuard user {1} configure file error: {0}"), ret, pUserName); + return ret; + } + + pCfg = static_cast(CoTaskMemAlloc(sizeof(USER_CFGFILE) * fileList.nItems)); + + if (pCfg == nullptr) { + SPDLOG_ERROR(TEXT("Error allocating memory {0} bytes"), sizeof(USER_CFGFILE) * fileList.nItems); + return -ERR_SUCCESS; + } + + memset(pCfg, 0, sizeof(USER_CFGFILE) * fileList.nItems); + + *pCfgFile = pCfg; + *pItems = static_cast(fileList.nItems); + + for (unsigned int i = 0; fileList.pFilePath && i < fileList.nItems; i++) { + StringCbCopy(pCfg->CfgPath, MAX_PATH, fileList.pFilePath[i].path); + if (isSelected && StrCmp(pCfg->CfgPath, cfgVal) == 0) { + pCfg->isCurrent = true; + } else { + pCfg->isCurrent = false; + } + pCfg++; + } + + if (fileList.pFilePath) { + HeapFree(GetProcessHeap(), 0, fileList.pFilePath); + } + + return ERR_SUCCESS; +} \ No newline at end of file diff --git a/NetTunnelServerApp/CMakeLists.txt b/NetTunnelServerApp/CMakeLists.txt new file mode 100644 index 0000000..db311f0 --- /dev/null +++ b/NetTunnelServerApp/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required(VERSION 3.22) +project(NetTunnelServerApp) + +set(CMAKE_CXX_STANDARD 23) +add_definitions(-D_UNICODE) + +find_package(spdlog CONFIG REQUIRED) +add_executable(NetTunnelServerApp NetTunnelServerApp.cpp) +target_link_libraries(NetTunnelServerApp PRIVATE spdlog::spdlog) + +SET_TARGET_PROPERTIES(NetTunnelServerApp PROPERTIES LINK_FLAGS "/MANIFESTUAC:\"level='requireAdministrator' uiAccess='false'\"") \ No newline at end of file diff --git a/NetTunnelServerApp/NetTunnelServerApp.cpp b/NetTunnelServerApp/NetTunnelServerApp.cpp new file mode 100644 index 0000000..f3fc34a --- /dev/null +++ b/NetTunnelServerApp/NetTunnelServerApp.cpp @@ -0,0 +1,39 @@ +#include +#include +#include +#include +#include + +#pragma comment(linker, "/MANIFESTUAC:\"level='requireAdministrator' uiAccess='false'\"") + +int _tmain(int wargc, _TCHAR *wargv[]) { + int ret; +// PUSER_SERVER_CONFIG pSvrCfg; +// +// //https://xajhuang.com:9276 +// //http://172.21.40.39:32549 +// if ((ret = TunnelSDKInitEnv(nullptr, "https://112.17.28.201:1443", nullptr, LOG_DEBUG, true)) != ERR_SUCCESS) { +// wprintf(L"Init SCC SDK Error: %d\n", ret); +// return -1; +// } +// +// EnableVerifySignature("sc-winvdisdk-efa9v12xwtz5eppr", "lh5r8sw6m9m416nm"); +// +// if (ERR_SUCCESS != (ret = GetUserServerConfigure("tunnel_svr", "172.21.97.100", &pSvrCfg))) { +// wprintf(L"GetUserServerConfigure Error: %d\n", ret); +// return -2; +// } +// +// if (ERR_SUCCESS != (ret = CreateControlService(pSvrCfg))) { +// wprintf(L"CreateControlService Error: %d\n", ret); +// return -2; +// } + + wprintf(L"Press Key 'X' to exit......\n"); + + do { + ret = _getch(); + } while (ret != 'X' && ret != 'x'); + + return 0; +} diff --git a/NetTunnelSvr/CMakeLists.txt b/NetTunnelSvr/CMakeLists.txt new file mode 100644 index 0000000..0699b50 --- /dev/null +++ b/NetTunnelSvr/CMakeLists.txt @@ -0,0 +1,8 @@ +cmake_minimum_required(VERSION 3.22) +project(NetTunnelSvr) + +set(CMAKE_CXX_STANDARD 23) +add_definitions(-D_UNICODE) + +add_executable(NetTunnelSvr NetTunnelSvr.cpp) +SET_TARGET_PROPERTIES(NetTunnelSvr PROPERTIES LINK_FLAGS "/MANIFESTUAC:\"level='requireAdministrator' uiAccess='false'\"") diff --git a/NetTunnelSvr/NetTunnelSvr.cpp b/NetTunnelSvr/NetTunnelSvr.cpp new file mode 100644 index 0000000..718e823 --- /dev/null +++ b/NetTunnelSvr/NetTunnelSvr.cpp @@ -0,0 +1,106 @@ +#include +#include +#include + +#define WG_TUNNEL_SVR_NAME TEXT("WireGuard_DLL_SVR") + +typedef BOOL(WINAPI WIREGUARD_TUNNEL_SERVICE_FUNC)(_In_z_ LPCWSTR Name); +static WIREGUARD_TUNNEL_SERVICE_FUNC *WireGuardTunnelService; + +static void LogToSystemEventLog(int wErrorType, int wCustumerCode, const TCHAR *szMsg) { + HANDLE hEventSource; + DWORD dwEventIdentifer; + + switch (wErrorType) { + case EVENTLOG_SUCCESS: + case EVENTLOG_AUDIT_SUCCESS: + dwEventIdentifer = 0x00; + break; + case EVENTLOG_INFORMATION_TYPE: + dwEventIdentifer = 0x01; + break; + case EVENTLOG_WARNING_TYPE: + dwEventIdentifer = 0x02; + break; + case EVENTLOG_ERROR_TYPE: + case EVENTLOG_AUDIT_FAILURE: + dwEventIdentifer = 0x03; + break; + default: + dwEventIdentifer = 0; + break; + } + // 移位获得Sev,前面给出的 wErrorType 为 EVENTLOG_ERROR_TYPE,对应着下图 “级别” 一列显示“错误”图标 + dwEventIdentifer <<= 30; + dwEventIdentifer |= static_cast(wCustumerCode); // 前面自定义了Code,对应着下图中 事件ID 20 + + hEventSource = RegisterEventSource(nullptr, WG_TUNNEL_SVR_NAME); + + if (nullptr != hEventSource) { + + LPCTSTR lpszStrings[2] = { + WG_TUNNEL_SVR_NAME, + szMsg}; //要写入日志的信息有两行,分别是 服务名,和前面给出的szMsg,对应着下图“以下是包含在事件中的信息” + + ReportEvent(hEventSource, // event log handle + wErrorType, // event type + 0, // event category + dwEventIdentifer, // event identifier + nullptr, // no security identifier + 2, // size of lpszStrings array + 0, // no binary data + lpszStrings, // array of strings + nullptr); // no binary data + DeregisterEventSource(hEventSource); + } +} + +static HMODULE InitializeTunnelLibrary() { + const HMODULE tunnel = LoadLibraryExW(L"tunnel.dll", nullptr, + LOAD_LIBRARY_SEARCH_APPLICATION_DIR | LOAD_LIBRARY_SEARCH_SYSTEM32); + if (!tunnel) { + TCHAR tMsg[MAX_PATH * sizeof(TCHAR)]; + StringCbPrintf(tMsg, MAX_PATH * sizeof(TCHAR), TEXT("LoadLibraryExW Error: %d\n"), GetLastError()); + LogToSystemEventLog(EVENTLOG_ERROR_TYPE, 0x01, tMsg); + return nullptr; + } + +#define X(Name) ((*(FARPROC *)&(Name) = GetProcAddress(tunnel, #Name)) == nullptr) + if (X(WireGuardTunnelService)) +#undef X + { + const DWORD LastError = GetLastError(); + FreeLibrary(tunnel); + SetLastError(LastError); + return nullptr; + } + return tunnel; +} + +int _tmain(int wargc, _TCHAR *wargv[]) { + TCHAR tMsg[MAX_PATH] = {}; + + if (wargc == 3 && !wcscmp(wargv[1], L"/service")) { + BOOL ret; + const HMODULE hModule = InitializeTunnelLibrary(); + + if (!hModule || !WireGuardTunnelService) { + StringCbPrintf(tMsg, MAX_PATH, TEXT("Init WireGuardTunnelService Service Error: %d\n"), GetLastError()); + LogToSystemEventLog(EVENTLOG_ERROR_TYPE, 0x01, tMsg); + return -1; + } + + ret = WireGuardTunnelService(wargv[2]); + + if (ret) { + StringCbPrintf(tMsg, MAX_PATH, TEXT("Start WireGuardTunnelService Service Successed\n")); + LogToSystemEventLog(EVENTLOG_INFORMATION_TYPE, 0x00, tMsg); + } else { + StringCbPrintf(tMsg, MAX_PATH, TEXT("Start WireGuardTunnelService Service failed: %d\n"), GetLastError()); + LogToSystemEventLog(EVENTLOG_ERROR_TYPE, 0x02, tMsg); + } + + return ret; + } + return 0; +} diff --git a/depends/WinDivert/include/windivert.h b/depends/WinDivert/include/windivert.h new file mode 100644 index 0000000..fc63adf --- /dev/null +++ b/depends/WinDivert/include/windivert.h @@ -0,0 +1,630 @@ +/* + * windivert.h + * (C) 2019, all rights reserved, + * + * This file is part of WinDivert. + * + * WinDivert is free software: you can redistribute it and/or modify it under + * the terms of the GNU Lesser General Public License as published by the + * Free Software Foundation, either version 3 of the License, or (at your + * option) any later version. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public + * License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this program. If not, see . + * + * WinDivert is free software; you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free + * Software Foundation; either version 2 of the License, or (at your option) + * any later version. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., 51 + * Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + */ + +#ifndef __WINDIVERT_H +#define __WINDIVERT_H + +#ifndef WINDIVERT_KERNEL +#include +#endif /* WINDIVERT_KERNEL */ + +#ifndef WINDIVERTEXPORT +#define WINDIVERTEXPORT extern __declspec(dllimport) +#endif /* WINDIVERTEXPORT */ + +#ifdef __MINGW32__ +#define __in +#define __in_opt +#define __out +#define __out_opt +#define __inout +#define __inout_opt +#include +#define INT8 int8_t +#define UINT8 uint8_t +#define INT16 int16_t +#define UINT16 uint16_t +#define INT32 int32_t +#define UINT32 uint32_t +#define INT64 int64_t +#define UINT64 uint64_t +#endif /* __MINGW32__ */ + +#ifdef __cplusplus +extern "C" { +#endif + +/****************************************************************************/ +/* WINDIVERT API */ +/****************************************************************************/ + +/* + * WinDivert layers. + */ +typedef enum +{ + WINDIVERT_LAYER_NETWORK = 0, /* Network layer. */ + WINDIVERT_LAYER_NETWORK_FORWARD = 1,/* Network layer (forwarded packets) */ + WINDIVERT_LAYER_FLOW = 2, /* Flow layer. */ + WINDIVERT_LAYER_SOCKET = 3, /* Socket layer. */ + WINDIVERT_LAYER_REFLECT = 4, /* Reflect layer. */ +} WINDIVERT_LAYER, *PWINDIVERT_LAYER; + +/* + * WinDivert NETWORK and NETWORK_FORWARD layer data. + */ +typedef struct +{ + UINT32 IfIdx; /* Packet's interface index. */ + UINT32 SubIfIdx; /* Packet's sub-interface index. */ +} WINDIVERT_DATA_NETWORK, *PWINDIVERT_DATA_NETWORK; + +/* + * WinDivert FLOW layer data. + */ +typedef struct +{ + UINT64 EndpointId; /* Endpoint ID. */ + UINT64 ParentEndpointId; /* Parent endpoint ID. */ + UINT32 ProcessId; /* Process ID. */ + UINT32 LocalAddr[4]; /* Local address. */ + UINT32 RemoteAddr[4]; /* Remote address. */ + UINT16 LocalPort; /* Local port. */ + UINT16 RemotePort; /* Remote port. */ + UINT8 Protocol; /* Protocol. */ +} WINDIVERT_DATA_FLOW, *PWINDIVERT_DATA_FLOW; + +/* + * WinDivert SOCKET layer data. + */ +typedef struct +{ + UINT64 EndpointId; /* Endpoint ID. */ + UINT64 ParentEndpointId; /* Parent Endpoint ID. */ + UINT32 ProcessId; /* Process ID. */ + UINT32 LocalAddr[4]; /* Local address. */ + UINT32 RemoteAddr[4]; /* Remote address. */ + UINT16 LocalPort; /* Local port. */ + UINT16 RemotePort; /* Remote port. */ + UINT8 Protocol; /* Protocol. */ +} WINDIVERT_DATA_SOCKET, *PWINDIVERT_DATA_SOCKET; + +/* + * WinDivert REFLECTION layer data. + */ +typedef struct +{ + INT64 Timestamp; /* Handle open time. */ + UINT32 ProcessId; /* Handle process ID. */ + WINDIVERT_LAYER Layer; /* Handle layer. */ + UINT64 Flags; /* Handle flags. */ + INT16 Priority; /* Handle priority. */ +} WINDIVERT_DATA_REFLECT, *PWINDIVERT_DATA_REFLECT; + +/* + * WinDivert address. + */ +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable: 4201) +#endif +typedef struct +{ + INT64 Timestamp; /* Packet's timestamp. */ + UINT32 Layer:8; /* Packet's layer. */ + UINT32 Event:8; /* Packet event. */ + UINT32 Sniffed:1; /* Packet was sniffed? */ + UINT32 Outbound:1; /* Packet is outound? */ + UINT32 Loopback:1; /* Packet is loopback? */ + UINT32 Impostor:1; /* Packet is impostor? */ + UINT32 IPv6:1; /* Packet is IPv6? */ + UINT32 IPChecksum:1; /* Packet has valid IPv4 checksum? */ + UINT32 TCPChecksum:1; /* Packet has valid TCP checksum? */ + UINT32 UDPChecksum:1; /* Packet has valid UDP checksum? */ + UINT32 Reserved1:8; + UINT32 Reserved2; + union + { + WINDIVERT_DATA_NETWORK Network; /* Network layer data. */ + WINDIVERT_DATA_FLOW Flow; /* Flow layer data. */ + WINDIVERT_DATA_SOCKET Socket; /* Socket layer data. */ + WINDIVERT_DATA_REFLECT Reflect; /* Reflect layer data. */ + UINT8 Reserved3[64]; + }; +} WINDIVERT_ADDRESS, *PWINDIVERT_ADDRESS; +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +/* + * WinDivert events. + */ +typedef enum +{ + WINDIVERT_EVENT_NETWORK_PACKET = 0, /* Network packet. */ + WINDIVERT_EVENT_FLOW_ESTABLISHED = 1, + /* Flow established. */ + WINDIVERT_EVENT_FLOW_DELETED = 2, /* Flow deleted. */ + WINDIVERT_EVENT_SOCKET_BIND = 3, /* Socket bind. */ + WINDIVERT_EVENT_SOCKET_CONNECT = 4, /* Socket connect. */ + WINDIVERT_EVENT_SOCKET_LISTEN = 5, /* Socket listen. */ + WINDIVERT_EVENT_SOCKET_ACCEPT = 6, /* Socket accept. */ + WINDIVERT_EVENT_SOCKET_CLOSE = 7, /* Socket close. */ + WINDIVERT_EVENT_REFLECT_OPEN = 8, /* WinDivert handle opened. */ + WINDIVERT_EVENT_REFLECT_CLOSE = 9, /* WinDivert handle closed. */ +} WINDIVERT_EVENT, *PWINDIVERT_EVENT; + +/* + * WinDivert flags. + */ +#define WINDIVERT_FLAG_SNIFF 0x0001 +#define WINDIVERT_FLAG_DROP 0x0002 +#define WINDIVERT_FLAG_RECV_ONLY 0x0004 +#define WINDIVERT_FLAG_READ_ONLY WINDIVERT_FLAG_RECV_ONLY +#define WINDIVERT_FLAG_SEND_ONLY 0x0008 +#define WINDIVERT_FLAG_WRITE_ONLY WINDIVERT_FLAG_SEND_ONLY +#define WINDIVERT_FLAG_NO_INSTALL 0x0010 +#define WINDIVERT_FLAG_FRAGMENTS 0x0020 + +/* + * WinDivert parameters. + */ +typedef enum +{ + WINDIVERT_PARAM_QUEUE_LENGTH = 0, /* Packet queue length. */ + WINDIVERT_PARAM_QUEUE_TIME = 1, /* Packet queue time. */ + WINDIVERT_PARAM_QUEUE_SIZE = 2, /* Packet queue size. */ + WINDIVERT_PARAM_VERSION_MAJOR = 3, /* Driver version (major). */ + WINDIVERT_PARAM_VERSION_MINOR = 4, /* Driver version (minor). */ +} WINDIVERT_PARAM, *PWINDIVERT_PARAM; +#define WINDIVERT_PARAM_MAX WINDIVERT_PARAM_VERSION_MINOR + +/* + * WinDivert shutdown parameter. + */ +typedef enum +{ + WINDIVERT_SHUTDOWN_RECV = 0x1, /* Shutdown recv. */ + WINDIVERT_SHUTDOWN_SEND = 0x2, /* Shutdown send. */ + WINDIVERT_SHUTDOWN_BOTH = 0x3, /* Shutdown recv and send. */ +} WINDIVERT_SHUTDOWN, *PWINDIVERT_SHUTDOWN; +#define WINDIVERT_SHUTDOWN_MAX WINDIVERT_SHUTDOWN_BOTH + +#ifndef WINDIVERT_KERNEL + +/* + * Open a WinDivert handle. + */ +WINDIVERTEXPORT HANDLE WinDivertOpen( + __in const char *filter, + __in WINDIVERT_LAYER layer, + __in INT16 priority, + __in UINT64 flags); + +/* + * Receive (read) a packet from a WinDivert handle. + */ +WINDIVERTEXPORT BOOL WinDivertRecv( + __in HANDLE handle, + __out_opt VOID *pPacket, + __in UINT packetLen, + __out_opt UINT *pRecvLen, + __out_opt WINDIVERT_ADDRESS *pAddr); + +/* + * Receive (read) a packet from a WinDivert handle. + */ +WINDIVERTEXPORT BOOL WinDivertRecvEx( + __in HANDLE handle, + __out_opt VOID *pPacket, + __in UINT packetLen, + __out_opt UINT *pRecvLen, + __in UINT64 flags, + __out WINDIVERT_ADDRESS *pAddr, + __inout_opt UINT *pAddrLen, + __inout_opt LPOVERLAPPED lpOverlapped); + +/* + * Send (write/inject) a packet to a WinDivert handle. + */ +WINDIVERTEXPORT BOOL WinDivertSend( + __in HANDLE handle, + __in const VOID *pPacket, + __in UINT packetLen, + __out_opt UINT *pSendLen, + __in const WINDIVERT_ADDRESS *pAddr); + +/* + * Send (write/inject) a packet to a WinDivert handle. + */ +WINDIVERTEXPORT BOOL WinDivertSendEx( + __in HANDLE handle, + __in const VOID *pPacket, + __in UINT packetLen, + __out_opt UINT *pSendLen, + __in UINT64 flags, + __in const WINDIVERT_ADDRESS *pAddr, + __in UINT addrLen, + __inout_opt LPOVERLAPPED lpOverlapped); + +/* + * Shutdown a WinDivert handle. + */ +WINDIVERTEXPORT BOOL WinDivertShutdown( + __in HANDLE handle, + __in WINDIVERT_SHUTDOWN how); + +/* + * Close a WinDivert handle. + */ +WINDIVERTEXPORT BOOL WinDivertClose( + __in HANDLE handle); + +/* + * Set a WinDivert handle parameter. + */ +WINDIVERTEXPORT BOOL WinDivertSetParam( + __in HANDLE handle, + __in WINDIVERT_PARAM param, + __in UINT64 value); + +/* + * Get a WinDivert handle parameter. + */ +WINDIVERTEXPORT BOOL WinDivertGetParam( + __in HANDLE handle, + __in WINDIVERT_PARAM param, + __out UINT64 *pValue); + +#endif /* WINDIVERT_KERNEL */ + +/* + * WinDivert constants. + */ +#define WINDIVERT_PRIORITY_HIGHEST 30000 +#define WINDIVERT_PRIORITY_LOWEST (-WINDIVERT_PRIORITY_HIGHEST) +#define WINDIVERT_PARAM_QUEUE_LENGTH_DEFAULT 4096 +#define WINDIVERT_PARAM_QUEUE_LENGTH_MIN 32 +#define WINDIVERT_PARAM_QUEUE_LENGTH_MAX 16384 +#define WINDIVERT_PARAM_QUEUE_TIME_DEFAULT 2000 /* 2s */ +#define WINDIVERT_PARAM_QUEUE_TIME_MIN 100 /* 100ms */ +#define WINDIVERT_PARAM_QUEUE_TIME_MAX 16000 /* 16s */ +#define WINDIVERT_PARAM_QUEUE_SIZE_DEFAULT 4194304 /* 4MB */ +#define WINDIVERT_PARAM_QUEUE_SIZE_MIN 65535 /* 64KB */ +#define WINDIVERT_PARAM_QUEUE_SIZE_MAX 33554432 /* 32MB */ +#define WINDIVERT_BATCH_MAX 0xFF /* 255 */ +#define WINDIVERT_MTU_MAX (40 + 0xFFFF) + +/****************************************************************************/ +/* WINDIVERT HELPER API */ +/****************************************************************************/ + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable: 4214) +#endif + +/* + * IPv4/IPv6/ICMP/ICMPv6/TCP/UDP header definitions. + */ +typedef struct +{ + UINT8 HdrLength:4; + UINT8 Version:4; + UINT8 TOS; + UINT16 Length; + UINT16 Id; + UINT16 FragOff0; + UINT8 TTL; + UINT8 Protocol; + UINT16 Checksum; + UINT32 SrcAddr; + UINT32 DstAddr; +} WINDIVERT_IPHDR, *PWINDIVERT_IPHDR; + +#define WINDIVERT_IPHDR_GET_FRAGOFF(hdr) \ + (((hdr)->FragOff0) & 0xFF1F) +#define WINDIVERT_IPHDR_GET_MF(hdr) \ + ((((hdr)->FragOff0) & 0x0020) != 0) +#define WINDIVERT_IPHDR_GET_DF(hdr) \ + ((((hdr)->FragOff0) & 0x0040) != 0) +#define WINDIVERT_IPHDR_GET_RESERVED(hdr) \ + ((((hdr)->FragOff0) & 0x0080) != 0) + +#define WINDIVERT_IPHDR_SET_FRAGOFF(hdr, val) \ + do \ + { \ + (hdr)->FragOff0 = (((hdr)->FragOff0) & 0x00E0) | \ + ((val) & 0xFF1F); \ + } \ + while (FALSE) +#define WINDIVERT_IPHDR_SET_MF(hdr, val) \ + do \ + { \ + (hdr)->FragOff0 = (((hdr)->FragOff0) & 0xFFDF) | \ + (((val) & 0x0001) << 5); \ + } \ + while (FALSE) +#define WINDIVERT_IPHDR_SET_DF(hdr, val) \ + do \ + { \ + (hdr)->FragOff0 = (((hdr)->FragOff0) & 0xFFBF) | \ + (((val) & 0x0001) << 6); \ + } \ + while (FALSE) +#define WINDIVERT_IPHDR_SET_RESERVED(hdr, val) \ + do \ + { \ + (hdr)->FragOff0 = (((hdr)->FragOff0) & 0xFF7F) | \ + (((val) & 0x0001) << 7); \ + } \ + while (FALSE) + +typedef struct +{ + UINT8 TrafficClass0:4; + UINT8 Version:4; + UINT8 FlowLabel0:4; + UINT8 TrafficClass1:4; + UINT16 FlowLabel1; + UINT16 Length; + UINT8 NextHdr; + UINT8 HopLimit; + UINT32 SrcAddr[4]; + UINT32 DstAddr[4]; +} WINDIVERT_IPV6HDR, *PWINDIVERT_IPV6HDR; + +#define WINDIVERT_IPV6HDR_GET_TRAFFICCLASS(hdr) \ + ((((hdr)->TrafficClass0) << 4) | ((hdr)->TrafficClass1)) +#define WINDIVERT_IPV6HDR_GET_FLOWLABEL(hdr) \ + ((((UINT32)(hdr)->FlowLabel0) << 16) | ((UINT32)(hdr)->FlowLabel1)) + +#define WINDIVERT_IPV6HDR_SET_TRAFFICCLASS(hdr, val) \ + do \ + { \ + (hdr)->TrafficClass0 = ((UINT8)(val) >> 4); \ + (hdr)->TrafficClass1 = (UINT8)(val); \ + } \ + while (FALSE) +#define WINDIVERT_IPV6HDR_SET_FLOWLABEL(hdr, val) \ + do \ + { \ + (hdr)->FlowLabel0 = (UINT8)((val) >> 16); \ + (hdr)->FlowLabel1 = (UINT16)(val); \ + } \ + while (FALSE) + +typedef struct +{ + UINT8 Type; + UINT8 Code; + UINT16 Checksum; + UINT32 Body; +} WINDIVERT_ICMPHDR, *PWINDIVERT_ICMPHDR; + +typedef struct +{ + UINT8 Type; + UINT8 Code; + UINT16 Checksum; + UINT32 Body; +} WINDIVERT_ICMPV6HDR, *PWINDIVERT_ICMPV6HDR; + +typedef struct +{ + UINT16 SrcPort; + UINT16 DstPort; + UINT32 SeqNum; + UINT32 AckNum; + UINT16 Reserved1:4; + UINT16 HdrLength:4; + UINT16 Fin:1; + UINT16 Syn:1; + UINT16 Rst:1; + UINT16 Psh:1; + UINT16 Ack:1; + UINT16 Urg:1; + UINT16 Reserved2:2; + UINT16 Window; + UINT16 Checksum; + UINT16 UrgPtr; +} WINDIVERT_TCPHDR, *PWINDIVERT_TCPHDR; + +typedef struct +{ + UINT16 SrcPort; + UINT16 DstPort; + UINT16 Length; + UINT16 Checksum; +} WINDIVERT_UDPHDR, *PWINDIVERT_UDPHDR; + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +/* + * Flags for WinDivertHelperCalcChecksums() + */ +#define WINDIVERT_HELPER_NO_IP_CHECKSUM 1 +#define WINDIVERT_HELPER_NO_ICMP_CHECKSUM 2 +#define WINDIVERT_HELPER_NO_ICMPV6_CHECKSUM 4 +#define WINDIVERT_HELPER_NO_TCP_CHECKSUM 8 +#define WINDIVERT_HELPER_NO_UDP_CHECKSUM 16 + +#ifndef WINDIVERT_KERNEL + +/* + * Hash a packet. + */ +WINDIVERTEXPORT UINT64 WinDivertHelperHashPacket( + __in const VOID *pPacket, + __in UINT packetLen, + __in UINT64 seed +#ifdef __cplusplus + = 0 +#endif +); + +/* + * Parse IPv4/IPv6/ICMP/ICMPv6/TCP/UDP headers from a raw packet. + */ +WINDIVERTEXPORT BOOL WinDivertHelperParsePacket( + __in const VOID *pPacket, + __in UINT packetLen, + __out_opt PWINDIVERT_IPHDR *ppIpHdr, + __out_opt PWINDIVERT_IPV6HDR *ppIpv6Hdr, + __out_opt UINT8 *pProtocol, + __out_opt PWINDIVERT_ICMPHDR *ppIcmpHdr, + __out_opt PWINDIVERT_ICMPV6HDR *ppIcmpv6Hdr, + __out_opt PWINDIVERT_TCPHDR *ppTcpHdr, + __out_opt PWINDIVERT_UDPHDR *ppUdpHdr, + __out_opt PVOID *ppData, + __out_opt UINT *pDataLen, + __out_opt PVOID *ppNext, + __out_opt UINT *pNextLen); + +/* + * Parse an IPv4 address. + */ +WINDIVERTEXPORT BOOL WinDivertHelperParseIPv4Address( + __in const char *addrStr, + __out_opt UINT32 *pAddr); + +/* + * Parse an IPv6 address. + */ +WINDIVERTEXPORT BOOL WinDivertHelperParseIPv6Address( + __in const char *addrStr, + __out_opt UINT32 *pAddr); + +/* + * Format an IPv4 address. + */ +WINDIVERTEXPORT BOOL WinDivertHelperFormatIPv4Address( + __in UINT32 addr, + __out char *buffer, + __in UINT bufLen); + +/* + * Format an IPv6 address. + */ +WINDIVERTEXPORT BOOL WinDivertHelperFormatIPv6Address( + __in const UINT32 *pAddr, + __out char *buffer, + __in UINT bufLen); + +/* + * Calculate IPv4/IPv6/ICMP/ICMPv6/TCP/UDP checksums. + */ +WINDIVERTEXPORT BOOL WinDivertHelperCalcChecksums( + __inout VOID *pPacket, + __in UINT packetLen, + __out_opt WINDIVERT_ADDRESS *pAddr, + __in UINT64 flags); + +/* + * Decrement the TTL/HopLimit. + */ +WINDIVERTEXPORT BOOL WinDivertHelperDecrementTTL( + __inout VOID *pPacket, + __in UINT packetLen); + +/* + * Compile the given filter string. + */ +WINDIVERTEXPORT BOOL WinDivertHelperCompileFilter( + __in const char *filter, + __in WINDIVERT_LAYER layer, + __out_opt char *object, + __in UINT objLen, + __out_opt const char **errorStr, + __out_opt UINT *errorPos); + +/* + * Evaluate the given filter string. + */ +WINDIVERTEXPORT BOOL WinDivertHelperEvalFilter( + __in const char *filter, + __in const VOID *pPacket, + __in UINT packetLen, + __in const WINDIVERT_ADDRESS *pAddr); + +/* + * Format the given filter string. + */ +WINDIVERTEXPORT BOOL WinDivertHelperFormatFilter( + __in const char *filter, + __in WINDIVERT_LAYER layer, + __out char *buffer, + __in UINT bufLen); + +/* + * Byte ordering. + */ +WINDIVERTEXPORT UINT16 WinDivertHelperNtohs( + __in UINT16 x); +WINDIVERTEXPORT UINT16 WinDivertHelperHtons( + __in UINT16 x); +WINDIVERTEXPORT UINT32 WinDivertHelperNtohl( + __in UINT32 x); +WINDIVERTEXPORT UINT32 WinDivertHelperHtonl( + __in UINT32 x); +WINDIVERTEXPORT UINT64 WinDivertHelperNtohll( + __in UINT64 x); +WINDIVERTEXPORT UINT64 WinDivertHelperHtonll( + __in UINT64 x); +WINDIVERTEXPORT void WinDivertHelperNtohIPv6Address( + __in const UINT *inAddr, + __out UINT *outAddr); +WINDIVERTEXPORT void WinDivertHelperHtonIPv6Address( + __in const UINT *inAddr, + __out UINT *outAddr); + +/* + * Old names to be removed in the next version. + */ +WINDIVERTEXPORT void WinDivertHelperNtohIpv6Address( + __in const UINT *inAddr, + __out UINT *outAddr); +WINDIVERTEXPORT void WinDivertHelperHtonIpv6Address( + __in const UINT *inAddr, + __out UINT *outAddr); + +#endif /* WINDIVERT_KERNEL */ + +#ifdef __cplusplus +} +#endif + +#endif /* __WINDIVERT_H */ diff --git a/depends/WinDivert/x64/WinDivert.dll b/depends/WinDivert/x64/WinDivert.dll new file mode 100644 index 0000000..50ca874 Binary files /dev/null and b/depends/WinDivert/x64/WinDivert.dll differ diff --git a/depends/WinDivert/x64/WinDivert64.sys b/depends/WinDivert/x64/WinDivert64.sys new file mode 100644 index 0000000..218ccaf Binary files /dev/null and b/depends/WinDivert/x64/WinDivert64.sys differ diff --git a/depends/WinDivert/x86/WinDivert.dll b/depends/WinDivert/x86/WinDivert.dll new file mode 100644 index 0000000..b9602c0 Binary files /dev/null and b/depends/WinDivert/x86/WinDivert.dll differ diff --git a/depends/WinDivert/x86/WinDivert32.sys b/depends/WinDivert/x86/WinDivert32.sys new file mode 100644 index 0000000..d06738c Binary files /dev/null and b/depends/WinDivert/x86/WinDivert32.sys differ diff --git a/depends/WinDivert/x86/WinDivert64.sys b/depends/WinDivert/x86/WinDivert64.sys new file mode 100644 index 0000000..218ccaf Binary files /dev/null and b/depends/WinDivert/x86/WinDivert64.sys differ diff --git a/depends/WireGuardNT/amd64/wireguard.dll b/depends/WireGuardNT/amd64/wireguard.dll new file mode 100644 index 0000000..efc0362 Binary files /dev/null and b/depends/WireGuardNT/amd64/wireguard.dll differ diff --git a/depends/WireGuardNT/arm/wireguard.dll b/depends/WireGuardNT/arm/wireguard.dll new file mode 100644 index 0000000..3977c5d Binary files /dev/null and b/depends/WireGuardNT/arm/wireguard.dll differ diff --git a/depends/WireGuardNT/arm64/wireguard.dll b/depends/WireGuardNT/arm64/wireguard.dll new file mode 100644 index 0000000..d49df5c Binary files /dev/null and b/depends/WireGuardNT/arm64/wireguard.dll differ diff --git a/depends/WireGuardNT/include/wireguard.h b/depends/WireGuardNT/include/wireguard.h new file mode 100644 index 0000000..ce562fa --- /dev/null +++ b/depends/WireGuardNT/include/wireguard.h @@ -0,0 +1,308 @@ +/* SPDX-License-Identifier: GPL-2.0 OR MIT + * + * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef ALIGNED +# if defined(_MSC_VER) +# define ALIGNED(n) __declspec(align(n)) +# elif defined(__GNUC__) +# define ALIGNED(n) __attribute__((aligned(n))) +# else +# error "Unable to define ALIGNED" +# endif +#endif + +/* MinGW is missing this one, unfortunately. */ +#ifndef _Post_maybenull_ +# define _Post_maybenull_ +#endif + +#pragma warning(push) +#pragma warning(disable : 4324) /* structure was padded due to alignment specifier */ + +/** + * A handle representing WireGuard adapter + */ +typedef struct _WIREGUARD_ADAPTER *WIREGUARD_ADAPTER_HANDLE; + +/** + * Creates a new WireGuard adapter. + * + * @param Name The requested name of the adapter. Zero-terminated string of up to MAX_ADAPTER_NAME-1 + * characters. + * + * @param TunnelType Name of the adapter tunnel type. Zero-terminated string of up to MAX_ADAPTER_NAME-1 + * characters. + * + * @param RequestedGUID The GUID of the created network adapter, which then influences NLA generation deterministically. + * If it is set to NULL, the GUID is chosen by the system at random, and hence a new NLA entry is + * created for each new adapter. It is called "requested" GUID because the API it uses is + * completely undocumented, and so there could be minor interesting complications with its usage. + * + * @return If the function succeeds, the return value is the adapter handle. Must be released with + * WireGuardCloseAdapter. If the function fails, the return value is NULL. To get extended error information, call + * GetLastError. + */ +typedef _Must_inspect_result_ +_Return_type_success_(return != NULL) +_Post_maybenull_ +WIREGUARD_ADAPTER_HANDLE(WINAPI WIREGUARD_CREATE_ADAPTER_FUNC) +(_In_z_ LPCWSTR Name, _In_z_ LPCWSTR TunnelType, _In_opt_ const GUID *RequestedGUID); + +/** + * Opens an existing WireGuard adapter. + * + * @param Name The requested name of the adapter. Zero-terminated string of up to MAX_ADAPTER_NAME-1 + * characters. + * + * @return If the function succeeds, the return value is the adapter handle. Must be released with + * WireGuardCloseAdapter. If the function fails, the return value is NULL. To get extended error information, call + * GetLastError. + */ +typedef _Must_inspect_result_ +_Return_type_success_(return != NULL) +_Post_maybenull_ +WIREGUARD_ADAPTER_HANDLE(WINAPI WIREGUARD_OPEN_ADAPTER_FUNC)(_In_z_ LPCWSTR Name); + +/** + * Releases WireGuard adapter resources and, if adapter was created with WireGuardCreateAdapter, removes adapter. + * + * @param Adapter Adapter handle obtained with WireGuardCreateAdapter or WireGuardOpenAdapter. + */ +typedef VOID(WINAPI WIREGUARD_CLOSE_ADAPTER_FUNC)(_In_opt_ WIREGUARD_ADAPTER_HANDLE Adapter); + +/** + * Deletes the WireGuard driver if there are no more adapters in use. + * + * @return If the function succeeds, the return value is nonzero. If the function fails, the return value is zero. To + * get extended error information, call GetLastError. + */ +typedef _Return_type_success_(return != FALSE) +BOOL(WINAPI WIREGUARD_DELETE_DRIVER_FUNC)(VOID); + +/** + * Returns the LUID of the adapter. + * + * @param Adapter Adapter handle obtained with WireGuardCreateAdapter or WireGuardOpenAdapter + * + * @param Luid Pointer to LUID to receive adapter LUID. + */ +typedef VOID(WINAPI WIREGUARD_GET_ADAPTER_LUID_FUNC)(_In_ WIREGUARD_ADAPTER_HANDLE Adapter, _Out_ NET_LUID *Luid); + +/** + * Determines the version of the WireGuard driver currently loaded. + * + * @return If the function succeeds, the return value is the version number. If the function fails, the return value is + * zero. To get extended error information, call GetLastError. Possible errors include the following: + * ERROR_FILE_NOT_FOUND WireGuard not loaded + */ +typedef _Return_type_success_(return != 0) +DWORD(WINAPI WIREGUARD_GET_RUNNING_DRIVER_VERSION_FUNC)(VOID); + +/** + * Determines the level of logging, passed to WIREGUARD_LOGGER_CALLBACK. + */ +typedef enum +{ + WIREGUARD_LOG_INFO, /**< Informational */ + WIREGUARD_LOG_WARN, /**< Warning */ + WIREGUARD_LOG_ERR /**< Error */ +} WIREGUARD_LOGGER_LEVEL; + +/** + * Called by internal logger to report diagnostic messages + * + * @param Level Message level. + * + * @param Timestamp Message timestamp in in 100ns intervals since 1601-01-01 UTC. + * + * @param Message Message text. + */ +typedef VOID(CALLBACK *WIREGUARD_LOGGER_CALLBACK)( + _In_ WIREGUARD_LOGGER_LEVEL Level, + _In_ DWORD64 Timestamp, + _In_z_ LPCWSTR Message); + +/** + * Sets logger callback function. + * + * @param NewLogger Pointer to callback function to use as a new global logger. NewLogger may be called from various + * threads concurrently. Should the logging require serialization, you must handle serialization in + * NewLogger. Set to NULL to disable. + */ +typedef VOID(WINAPI WIREGUARD_SET_LOGGER_FUNC)(_In_ WIREGUARD_LOGGER_CALLBACK NewLogger); + +/** + * Whether and how logs from the driver are collected for the callback function. + */ +typedef enum +{ + WIREGUARD_ADAPTER_LOG_OFF, /**< No logs are generated from the driver. */ + WIREGUARD_ADAPTER_LOG_ON, /**< Logs are generated from the driver. */ + WIREGUARD_ADAPTER_LOG_ON_WITH_PREFIX /**< Logs are generated from the driver, index-prefixed. */ +} WIREGUARD_ADAPTER_LOG_STATE; + +/** + * Sets whether and how the adapter logs to the logger previously set up with WireGuardSetLogger. + * + * @param Adapter Adapter handle obtained with WireGuardCreateAdapter or WireGuardOpenAdapter + * + * @param LogState Adapter logging state. + * + * @return If the function succeeds, the return value is nonzero. If the function fails, the return value is zero. To + * get extended error information, call GetLastError. + */ +typedef _Return_type_success_(return != FALSE) +BOOL(WINAPI WIREGUARD_SET_ADAPTER_LOGGING_FUNC) +(_In_ WIREGUARD_ADAPTER_HANDLE Adapter, _In_ WIREGUARD_ADAPTER_LOG_STATE LogState); + +/** + * Determines the state of the adapter. + */ +typedef enum +{ + WIREGUARD_ADAPTER_STATE_DOWN, /**< Down */ + WIREGUARD_ADAPTER_STATE_UP, /**< Up */ +} WIREGUARD_ADAPTER_STATE; + +/** + * Sets the adapter state of the WireGuard adapter. Note: sockets are owned by the process that sets the state to up. + * + * @param Adapter Adapter handle obtained with WireGuardCreateAdapter or WireGuardOpenAdapter + * + * @param State Adapter state. + * + * @return If the function succeeds, the return value is nonzero. If the function fails, the return value is zero. To + * get extended error information, call GetLastError. + */ +typedef _Return_type_success_(return != FALSE) +BOOL(WINAPI WIREGUARD_SET_ADAPTER_STATE_FUNC) +(_In_ WIREGUARD_ADAPTER_HANDLE Adapter, _In_ WIREGUARD_ADAPTER_STATE State); + +/** + * Gets the adapter state of the WireGuard adapter. + * + * @param Adapter Adapter handle obtained with WireGuardCreateAdapter or WireGuardOpenAdapter + * + * @param State Pointer to adapter state. + * + * @return If the function succeeds, the return value is nonzero. If the function fails, the return value is zero. To + * get extended error information, call GetLastError. + */ +typedef _Must_inspect_result_ +_Return_type_success_(return != FALSE) +BOOL(WINAPI WIREGUARD_GET_ADAPTER_STATE_FUNC) +(_In_ WIREGUARD_ADAPTER_HANDLE Adapter, _Out_ WIREGUARD_ADAPTER_STATE *State); + +#define WIREGUARD_KEY_LENGTH 32 + +typedef struct _WIREGUARD_ALLOWED_IP WIREGUARD_ALLOWED_IP; +struct ALIGNED(8) _WIREGUARD_ALLOWED_IP +{ + union + { + IN_ADDR V4; + IN6_ADDR V6; + } Address; /**< IP address */ + ADDRESS_FAMILY AddressFamily; /**< Address family, either AF_INET or AF_INET6 */ + BYTE Cidr; /**< CIDR of allowed IPs */ +}; + +typedef enum +{ + WIREGUARD_PEER_HAS_PUBLIC_KEY = 1 << 0, /**< The PublicKey field is set */ + WIREGUARD_PEER_HAS_PRESHARED_KEY = 1 << 1, /**< The PresharedKey field is set */ + WIREGUARD_PEER_HAS_PERSISTENT_KEEPALIVE = 1 << 2, /**< The PersistentKeepAlive field is set */ + WIREGUARD_PEER_HAS_ENDPOINT = 1 << 3, /**< The Endpoint field is set */ + WIREGUARD_PEER_REPLACE_ALLOWED_IPS = 1 << 5, /**< Remove all allowed IPs before adding new ones */ + WIREGUARD_PEER_REMOVE = 1 << 6, /**< Remove specified peer */ + WIREGUARD_PEER_UPDATE = 1 << 7 /**< Do not add a new peer */ +} WIREGUARD_PEER_FLAG; + +typedef struct _WIREGUARD_PEER WIREGUARD_PEER; +struct ALIGNED(8) _WIREGUARD_PEER +{ + WIREGUARD_PEER_FLAG Flags; /**< Bitwise combination of flags */ + DWORD Reserved; /**< Reserved; must be zero */ + BYTE PublicKey[WIREGUARD_KEY_LENGTH]; /**< Public key, the peer's primary identifier */ + BYTE PresharedKey[WIREGUARD_KEY_LENGTH]; /**< Preshared key for additional layer of post-quantum resistance */ + WORD PersistentKeepalive; /**< Seconds interval, or 0 to disable */ + SOCKADDR_INET Endpoint; /**< Endpoint, with IP address and UDP port number*/ + DWORD64 TxBytes; /**< Number of bytes transmitted */ + DWORD64 RxBytes; /**< Number of bytes received */ + DWORD64 LastHandshake; /**< Time of the last handshake, in 100ns intervals since 1601-01-01 UTC */ + DWORD AllowedIPsCount; /**< Number of allowed IP structs following this struct */ +}; + +typedef enum +{ + WIREGUARD_INTERFACE_HAS_PUBLIC_KEY = (1 << 0), /**< The PublicKey field is set */ + WIREGUARD_INTERFACE_HAS_PRIVATE_KEY = (1 << 1), /**< The PrivateKey field is set */ + WIREGUARD_INTERFACE_HAS_LISTEN_PORT = (1 << 2), /**< The ListenPort field is set */ + WIREGUARD_INTERFACE_REPLACE_PEERS = (1 << 3) /**< Remove all peers before adding new ones */ +} WIREGUARD_INTERFACE_FLAG; + +typedef struct _WIREGUARD_INTERFACE WIREGUARD_INTERFACE; +struct ALIGNED(8) _WIREGUARD_INTERFACE +{ + WIREGUARD_INTERFACE_FLAG Flags; /**< Bitwise combination of flags */ + WORD ListenPort; /**< Port for UDP listen socket, or 0 to choose randomly */ + BYTE PrivateKey[WIREGUARD_KEY_LENGTH]; /**< Private key of interface */ + BYTE PublicKey[WIREGUARD_KEY_LENGTH]; /**< Corresponding public key of private key */ + DWORD PeersCount; /**< Number of peer structs following this struct */ +}; + +/** + * Sets the configuration of the WireGuard adapter. + * + * @param Adapter Adapter handle obtained with WireGuardCreateAdapter or WireGuardOpenAdapter + * + * @param Config Configuration for the adapter. + * + * @param Bytes Number of bytes in Config allocation. + * + * @return If the function succeeds, the return value is nonzero. If the function fails, the return value is zero. To + * get extended error information, call GetLastError. + */ +typedef _Return_type_success_(return != FALSE) +BOOL(WINAPI WIREGUARD_SET_CONFIGURATION_FUNC) +(_In_ WIREGUARD_ADAPTER_HANDLE Adapter, _In_reads_bytes_(Bytes) const WIREGUARD_INTERFACE *Config, _In_ DWORD Bytes); + +/** + * Gets the configuration of the WireGuard adapter. + * + * @param Adapter Adapter handle obtained with WireGuardCreateAdapter or WireGuardOpenAdapter + * + * @param Config Configuration for the adapter. + * + * @param Bytes Pointer to number of bytes in Config allocation. + * + * @return If the function succeeds, the return value is nonzero. If the function fails, the return value is zero. To + * get extended error information, call GetLastError, which if ERROR_MORE_DATA, Bytes is updated with the + * required size. + */ +typedef _Must_inspect_result_ +_Return_type_success_(return != FALSE) +BOOL(WINAPI WIREGUARD_GET_CONFIGURATION_FUNC) +(_In_ WIREGUARD_ADAPTER_HANDLE Adapter, + _Out_writes_bytes_all_(*Bytes) WIREGUARD_INTERFACE *Config, + _Inout_ DWORD *Bytes); + +#pragma warning(pop) + +#ifdef __cplusplus +} +#endif diff --git a/depends/WireGuardNT/x86/wireguard.dll b/depends/WireGuardNT/x86/wireguard.dll new file mode 100644 index 0000000..ecca024 Binary files /dev/null and b/depends/WireGuardNT/x86/wireguard.dll differ diff --git a/depends/tunnel/amd64/tunnel.dll b/depends/tunnel/amd64/tunnel.dll new file mode 100644 index 0000000..4db0441 Binary files /dev/null and b/depends/tunnel/amd64/tunnel.dll differ diff --git a/depends/tunnel/arm64/tunnel.dll b/depends/tunnel/arm64/tunnel.dll new file mode 100644 index 0000000..96ddae2 Binary files /dev/null and b/depends/tunnel/arm64/tunnel.dll differ diff --git a/depends/tunnel/x86/tunnel.dll b/depends/tunnel/x86/tunnel.dll new file mode 100644 index 0000000..4adcac3 Binary files /dev/null and b/depends/tunnel/x86/tunnel.dll differ diff --git a/scripts/cleansdk.bat b/scripts/cleansdk.bat new file mode 100644 index 0000000..f31d823 --- /dev/null +++ b/scripts/cleansdk.bat @@ -0,0 +1,6 @@ +@echo off +setlocal EnableDelayedExpansion + +set "CurrCD=%~dp0" + +powershell -Command "& {Remove-Item sdk/*}" \ No newline at end of file diff --git a/scripts/gensdk.bat b/scripts/gensdk.bat new file mode 100644 index 0000000..0fc193a --- /dev/null +++ b/scripts/gensdk.bat @@ -0,0 +1,8 @@ +@echo off +setlocal EnableDelayedExpansion + +set "CurrCD=%~dp0" + +for /f %%i in ('dir sdk /b /s') do ( + powershell -Command "& {$fileContent = Get-Content -Path %%i;$newContent = $fileContent -replace 'TCHAR', 'CHAR ';$newContent | Set-Content -Path %%i}" +) \ No newline at end of file diff --git a/vcpkg.json b/vcpkg.json new file mode 100644 index 0000000..6a22514 --- /dev/null +++ b/vcpkg.json @@ -0,0 +1,22 @@ +{ + "name" : "scc", + "version-string" : "1.0.0", + "builtin-baseline" : "662dbb50e63af15baa2909b7eac5b1b87e86a0aa", + "dependencies" : [ { + "name" : "spdlog", + "version>=" : "1.11.0#1", + "$comment" : " find_package(spdlog CONFIG REQUIRED)\n\n target_link_libraries(main PRIVATE spdlog::spdlog)\n" + }, { + "name" : "openssl", + "version>=" : "3.1.1" + }, { + "name" : "magic-enum", + "version>=" : "0.9.1" + }, { + "name" : "cppcodec", + "version>=" : "0.2#4" + }, { + "name" : "rapidjson", + "version>=" : "2023-04-27" + } ] +} \ No newline at end of file