diff --git a/.clang-tidy b/.clang-tidy index f069654..1e15b84 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,23 +1,24 @@ Checks: ' *, -llvmlibc-*, +-android-*, +-altera-*, +-fuchsia*, -google-readability-todo, --fuchsia-default-arguments-calls, --fuchsia-default-arguments-declarations, -cppcoreguidelines-init-variables, -cppcoreguidelines-pro-type-reinterpret-cast, --android-*, --altera-*, -llvm-namespace-comment, -readability-implicit-bool-conversion, -google-explicit-constructor, --fuchsia-overloaded-operator, +-abseil-string-find-str-contains, +-readability-avoid-return-with-void-value, +-readability-convert-member-functions-to-static, -google-readability-braces-around-statements, -google-readability-namespace-comments, +-hicpp-special-member-functions, -hicpp-braces-around-statements, --hicpp-explicit-conversions, --fuchsia-trailing-return +-hicpp-explicit-conversions ' WarningsAsErrors: 'bugprone-exception-escape' @@ -25,7 +26,7 @@ FormatStyle: 'none' # TODO: Replace with 'file' once we have a proper .clang-for InheritParentConfig: true CheckOptions: misc-include-cleaner.MissingIncludes: 'false' - misc-include-cleaner.IgnoreHeaders: 'CppSockets/OSDetection\.hpp' + misc-include-cleaner.IgnoreHeaders: 'CppSockets/OSDetection.*' bugprone-argument-comment.StrictMode: 1 @@ -34,5 +35,7 @@ CheckOptions: readability-identifier-naming.NamespaceCase: CamelCase - readability-identifier-length.IgnoredVariableNames: "^(fd|nb|ss)$" - readability-identifier-length.IgnoredParameterNames: "^([n]|fd)$" + readability-identifier-length.IgnoredVariableNames: "^(fd|nb|n|ss|ec|is|os|_.*)$" + readability-identifier-length.IgnoredParameterNames: "^(fd|n|is|os|_.*)$" + + cppcoreguidelines-special-member-functions.AllowSoleDefaultDtor: true diff --git a/.github/workflows/cmake-multi-platform.yml b/.github/workflows/cmake-multi-platform.yml index 95b35ff..0ba29c7 100644 --- a/.github/workflows/cmake-multi-platform.yml +++ b/.github/workflows/cmake-multi-platform.yml @@ -95,7 +95,7 @@ jobs: clang-tidy: needs: 'build' runs-on: ubuntu-latest - if: always() && github.event_name == 'pull-request' + if: always() && github.event_name == 'pull_request' steps: - name: Checkout Code diff --git a/.github/workflows/windows-vcpkg/action.yml b/.github/workflows/windows-vcpkg/action.yml index a613609..0feb036 100644 --- a/.github/workflows/windows-vcpkg/action.yml +++ b/.github/workflows/windows-vcpkg/action.yml @@ -27,7 +27,7 @@ runs: shell: powershell run: | echo "CMAKE_TOOLCHAIN_FILE=${env:VCPKG_ROOT}\scripts\buildsystems\vcpkg.cmake" | Out-File -FilePath $env:GITHUB_ENV -Append - vcpkg install --debug + vcpkg install - name: Always Save VCPKG Cache (Windows) if: always() && runner.os == 'Windows' && steps.fetch-vcpkg-cache.outputs.cache-hit != 'true' diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f11c91..5cd8fd3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ ## Author Francois Michaut ## ## Started on Sun Aug 28 19:26:51 2022 Francois Michaut -## Last update Tue Aug 5 19:07:23 2025 Francois Michaut +## Last update Wed Aug 20 17:05:18 2025 Francois Michaut ## ## CMakeLists.txt : CMake to build the CppSockets library ## @@ -21,15 +21,23 @@ project(LibCppSockets VERSION 0.1.0 LANGUAGES C CXX) configure_file(include/CppSockets/Version.hpp.in include/CppSockets/Version.hpp) add_library(cppsockets + source/Tls/Certificate.cpp + source/Tls/Context.cpp + source/Tls/Utils.cpp + source/Tls/Socket.cpp + source/Address.cpp - source/Certificate.cpp source/IPv4.cpp - source/SSL_Utils.cpp source/Socket.cpp source/SocketInit.cpp - source/TlsSocket.cpp ) -target_include_directories(cppsockets PUBLIC $ $) +target_include_directories(cppsockets + PUBLIC + $ + $ + PRIVATE + ${PROJECT_SOURCE_DIR}/private +) find_package(OpenSSL 3.0 COMPONENTS SSL) target_link_libraries(cppsockets OpenSSL::SSL) diff --git a/include/CppSockets/Address.hpp b/include/CppSockets/Address.hpp index 4561c7b..f17f69a 100644 --- a/include/CppSockets/Address.hpp +++ b/include/CppSockets/Address.hpp @@ -4,7 +4,7 @@ ** Author Francois Michaut ** ** Started on Sun Feb 13 17:09:05 2022 Francois Michaut -** Last update Sat Dec 9 08:52:22 2023 Francois Michaut +** Last update Wed Aug 20 12:57:17 2025 Francois Michaut ** ** Address.hpp : Interface to represent network addresses */ @@ -20,43 +20,40 @@ namespace CppSockets { class IAddress { public: - [[nodiscard]] virtual auto getAddress() const -> std::uint32_t = 0; - [[nodiscard]] virtual auto getFamily() const -> int = 0; - [[nodiscard]] virtual auto toString() const -> const std::string & = 0; + virtual ~IAddress() = default; + + [[nodiscard]] virtual auto get_address() const -> std::uint32_t = 0; + [[nodiscard]] virtual auto get_family() const -> int = 0; + [[nodiscard]] virtual auto to_string() const -> const std::string & = 0; }; class IEndpoint { public: - [[nodiscard]] virtual auto getPort() const -> std::uint16_t = 0; - [[nodiscard]] virtual auto getAddr() const -> const IAddress & = 0; - [[nodiscard]] virtual auto toString() const -> const std::string & = 0; + virtual ~IEndpoint() = default; + + [[nodiscard]] virtual auto get_port() const -> std::uint16_t = 0; + [[nodiscard]] virtual auto get_addr() const -> const IAddress & = 0; + [[nodiscard]] virtual auto to_string() const -> const std::string & = 0; protected: - [[nodiscard]] auto makeString() const -> std::string; + [[nodiscard]] auto make_string() const -> std::string; }; template class Endpoint : public IEndpoint { + // TODO: Replace with new C++ requires static_assert(std::is_base_of::value, - "Endpoint address must derive from IAddress" + "Endpoint address must derive from IAddress" ); public: Endpoint(T addr, std::uint16_t port) : - addr(std::move(addr)), port(port), str(makeString()) + addr(std::move(addr)), port(port), str(make_string()) {}; - virtual ~Endpoint() = default; - - [[nodiscard]] auto getPort() const -> std::uint16_t override { - return port; - } - - [[nodiscard]] auto getAddr() const -> const T & override { - return addr; - } + ~Endpoint() override = default; - [[nodiscard]] auto toString() const -> const std::string & override { - return str; - } + [[nodiscard]] auto get_port() const -> std::uint16_t override { return port; } + [[nodiscard]] auto get_addr() const -> const T & override { return addr; } + [[nodiscard]] auto to_string() const -> const std::string & override { return str; } private: T addr; std::uint16_t port; diff --git a/include/CppSockets/IPv4.hpp b/include/CppSockets/IPv4.hpp index 99ccc38..e169948 100644 --- a/include/CppSockets/IPv4.hpp +++ b/include/CppSockets/IPv4.hpp @@ -4,7 +4,7 @@ ** Author Francois Michaut ** ** Started on Sun Feb 13 17:05:02 2022 Francois Michaut -** Last update Sat Nov 11 16:58:19 2023 Francois Michaut +** Last update Wed Aug 20 12:57:35 2025 Francois Michaut ** ** IPv4.hpp : Class used to represent and manipulate IPv4 addresses */ @@ -19,10 +19,9 @@ namespace CppSockets { explicit IPv4(std::uint32_t addr); IPv4(const char *addr); // TODO add support for string. Maybe string_view ? - - [[nodiscard]] auto getAddress() const -> std::uint32_t override; - [[nodiscard]] auto getFamily() const -> int override; - [[nodiscard]] auto toString() const -> const std::string & override; + [[nodiscard]] auto get_address() const -> std::uint32_t override; + [[nodiscard]] auto get_family() const -> int override; + [[nodiscard]] auto to_string() const -> const std::string & override; private: std::uint32_t addr; diff --git a/include/CppSockets/SSL_Utils.hpp b/include/CppSockets/SSL_Utils.hpp deleted file mode 100644 index 3988d9b..0000000 --- a/include/CppSockets/SSL_Utils.hpp +++ /dev/null @@ -1,53 +0,0 @@ -/* -** Project LibCppSockets, 2025 -** -** Author Francois Michaut -** -** Started on Fri Aug 1 09:54:53 2025 Francois Michaut -** Last update Sun Aug 3 23:32:20 2025 Francois Michaut -** -** SSL_Utils.hpp : SSL Utility types -*/ - -#pragma once - -#include - -#include - -#define CPP_SOCKETS_SSL_UTILS_DEFINE_DTOR(TYPE) \ - struct TYPE##_dtor { \ - void operator()(TYPE *ptr) { TYPE##_free(ptr); } \ - }; - -#define CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(TYPE) \ - CPP_SOCKETS_SSL_UTILS_DEFINE_DTOR(TYPE) \ - using TYPE##_ptr = std::unique_ptr; - -namespace CppSockets { - CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(BIO) - CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(SSL_CTX) - CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(SSL) - CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(X509) - CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(X509_NAME) - CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(X509_NAME_ENTRY) - CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(X509_EXTENSION) - CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(EVP_PKEY) - CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(EVP_MD) - CPP_SOCKETS_SSL_UTILS_DEFINE_PTR(EVP_MD_CTX) - - void throw_openssl_error(); - auto check_or_throw_openssl_error(int ret) -> int; - - template - auto check_or_throw_openssl_error(T *ret) -> T * { - if (ret == nullptr) { - throw_openssl_error(); - } - return ret; - } -} - -// Don't leak macros -#undef CPP_SOCKETS_SSL_UTILS_DEFINE_DTOR -#undef CPP_SOCKETS_SSL_UTILS_DEFINE_PTR diff --git a/include/CppSockets/Socket.hpp b/include/CppSockets/Socket.hpp index 7ff7ba8..b6c6c5d 100644 --- a/include/CppSockets/Socket.hpp +++ b/include/CppSockets/Socket.hpp @@ -4,7 +4,7 @@ ** Author Francois Michaut ** ** Started on Sat Jan 15 01:17:42 2022 Francois Michaut -** Last update Tue Aug 5 00:00:48 2025 Francois Michaut +** Last update Wed Aug 20 14:01:21 2025 Francois Michaut ** ** Socket.hpp : Portable C++ socket class */ @@ -12,7 +12,7 @@ #pragma once #include "CppSockets/OSDetection.hpp" -#include "CppSockets/SocketInit.hpp" +#include "CppSockets/internal/SocketInit.hpp" // TODO: move the RawSocketType in CppSockets namespace #ifdef OS_WINDOWS @@ -65,18 +65,13 @@ namespace CppSockets { void set_blocking(bool val); - [[nodiscard]] - auto get_fd() const -> RawSocketType { return m_sockfd; } - [[nodiscard]] - auto get_type() const -> int { return m_type; } - [[nodiscard]] - auto get_domain() const -> int { return m_domain; } - [[nodiscard]] - auto get_protocol() const -> int { return m_protocol; } + [[nodiscard]] auto get_fd() const -> RawSocketType { return m_sockfd; } + [[nodiscard]] auto get_type() const -> int { return m_type; } + [[nodiscard]] auto get_domain() const -> int { return m_domain; } + [[nodiscard]] auto get_protocol() const -> int { return m_protocol; } // TODO: Allow to get Endpoint - [[nodiscard]] - auto connected() const -> bool { return m_is_connected; } + [[nodiscard]] auto connected() const -> bool { return m_is_connected; } static auto get_errno() -> int; static auto strerror(int err) -> char *; diff --git a/include/CppSockets/Certificate.hpp b/include/CppSockets/Tls/Certificate.hpp similarity index 59% rename from include/CppSockets/Certificate.hpp rename to include/CppSockets/Tls/Certificate.hpp index 1539ce7..7303a43 100644 --- a/include/CppSockets/Certificate.hpp +++ b/include/CppSockets/Tls/Certificate.hpp @@ -4,20 +4,21 @@ ** Author Francois Michaut ** ** Started on Fri Aug 1 09:50:33 2025 Francois Michaut -** Last update Tue Aug 5 19:08:13 2025 Francois Michaut +** Last update Wed Aug 20 17:17:44 2025 Francois Michaut ** ** Certificate.hpp : Classes to create and manage Certificates */ #pragma once -#include "CppSockets/SSL_Utils.hpp" +#include "CppSockets/Tls/Utils.hpp" #include #include #include +// TODO: Add weak ptr equivalent instead of `bool own = false` namespace CppSockets { class x509NameEntry; @@ -25,32 +26,30 @@ namespace CppSockets { public: x509Name(); x509Name(X509_NAME_ptr ptr); + x509Name(X509_NAME *ptr, bool own = true); x509Name(const x509Name &other) { *this = other; } x509Name(x509Name &&other) noexcept = default; auto operator=(const x509Name &other) -> x509Name &; auto operator=(x509Name &&other) noexcept -> x509Name & = default; - ~x509Name() = default; + ~x509Name(); - [[nodiscard]] - auto clone() const -> x509Name { return {*this}; } + [[nodiscard]] auto clone() const -> x509Name { return {*this}; } void add_entry(const x509NameEntry &entry, int loc = -1, int set = 0); void add_entry(const std::string &field_name, int type, const std::u8string &data, int loc = -1, int set = 0); void add_entry(const ASN1_OBJECT *obj, int type, const std::u8string &data, int loc = -1, int set = 0); void add_entry(int nid, int type, const std::u8string &data, int loc = -1, int set = 0); - [[nodiscard]] - auto entry_count() const -> int; - [[nodiscard]] - auto get_entry(int loc) const -> x509NameEntry; + [[nodiscard]] auto entry_count() const -> int; + [[nodiscard]] auto get_entry_by_index(int idx) const -> x509NameEntry; + [[nodiscard]] auto get_entry(int nid, int lastpos = -1) const -> x509NameEntry; + [[nodiscard]] auto get_entry(const ASN1_OBJECT *obj, int lastpos = -1) const -> x509NameEntry; auto delete_entry(int loc) -> x509NameEntry; - [[nodiscard]] - auto get_index(int nid, int lastpos = -1) const -> int; - [[nodiscard]] - auto get_index(const ASN1_OBJECT *obj, int lastpos = -1) const -> int; + [[nodiscard]] auto get_index(int nid, int lastpos = -1) const -> int; + [[nodiscard]] auto get_index(const ASN1_OBJECT *obj, int lastpos = -1) const -> int; // TODO ? // X509_NAME_get0_der @@ -63,16 +62,17 @@ namespace CppSockets { // X509_NAME_print_ex // X509_NAME_print_ex_fp - [[nodiscard]] - auto get() const -> X509_NAME * { return m_ptr.get(); } + [[nodiscard]] auto get() const -> X509_NAME * { return m_ptr.get(); } private: X509_NAME_ptr m_ptr; + bool m_own = true; }; class x509NameEntry { public: x509NameEntry(); x509NameEntry(X509_NAME_ENTRY_ptr ptr); + x509NameEntry(X509_NAME_ENTRY *ptr, bool own = true); x509NameEntry(const std::string &name, int type, const std::u8string &data); x509NameEntry(const ASN1_OBJECT *obj, int type, const std::u8string &data); x509NameEntry(int nid, int type, const std::u8string &data); @@ -82,29 +82,27 @@ namespace CppSockets { auto operator=(const x509NameEntry &other) -> x509NameEntry &; auto operator=(x509NameEntry &&other) noexcept -> x509NameEntry & = default; - ~x509NameEntry() = default; + ~x509NameEntry(); - [[nodiscard]] - auto clone() const -> x509NameEntry { return {*this}; } + [[nodiscard]] auto clone() const -> x509NameEntry { return {*this}; } void set_object(const ASN1_OBJECT *obj); void set_data(int type, const std::u8string &data); - [[nodiscard]] - auto get_object() const -> ASN1_OBJECT *; - [[nodiscard]] - auto get_data() const -> ASN1_STRING *; + [[nodiscard]] auto get_object() const -> const ASN1_OBJECT *; + [[nodiscard]] auto get_data() const -> const ASN1_STRING *; - [[nodiscard]] - auto get() const -> X509_NAME_ENTRY * { return m_ptr.get(); } + [[nodiscard]] auto get() const -> X509_NAME_ENTRY * { return m_ptr.get(); } private: X509_NAME_ENTRY_ptr m_ptr; + bool m_own = true; }; class x509Extension { public: x509Extension(); x509Extension(X509_EXTENSION_ptr ptr); + x509Extension(X509_EXTENSION *ptr, bool own = true); x509Extension(int nid, int crit, ASN1_OCTET_STRING *data); x509Extension(const ASN1_OBJECT *obj, int crit, ASN1_OCTET_STRING *data); @@ -113,32 +111,29 @@ namespace CppSockets { auto operator=(const x509Extension &other) -> x509Extension &; auto operator=(x509Extension &&other) noexcept -> x509Extension & = default; - ~x509Extension() = default; + ~x509Extension(); - [[nodiscard]] - auto clone() const -> x509Extension { return {*this}; } + [[nodiscard]] auto clone() const -> x509Extension { return {*this}; } void set_data(ASN1_OCTET_STRING *data); void set_object(const ASN1_OBJECT *obj); void set_critical(bool crit); - [[nodiscard]] - auto get_data() const -> ASN1_OCTET_STRING *; - [[nodiscard]] - auto get_object() const -> ASN1_OBJECT *; - [[nodiscard]] - auto get_critical() const -> bool; + [[nodiscard]] auto get_data() const -> ASN1_OCTET_STRING *; + [[nodiscard]] auto get_object() const -> ASN1_OBJECT *; + [[nodiscard]] auto get_critical() const -> bool; - [[nodiscard]] - auto get() const -> X509_EXTENSION * { return m_ptr.get(); } + [[nodiscard]] auto get() const -> X509_EXTENSION * { return m_ptr.get(); } private: X509_EXTENSION_ptr m_ptr; + bool m_own = true; }; class x509Certificate { public: x509Certificate(); - x509Certificate(X509_ptr x509); + x509Certificate(X509_ptr ptr); + x509Certificate(X509 *ptr, bool own = true); explicit x509Certificate(const std::filesystem::path &pem_file_path); x509Certificate(const x509Certificate &other) { *this = other; } @@ -146,60 +141,55 @@ namespace CppSockets { auto operator=(const x509Certificate &other) -> x509Certificate &; auto operator=(x509Certificate &&other) noexcept -> x509Certificate & = default; - ~x509Certificate() = default; + ~x509Certificate(); - [[nodiscard]] - auto clone() const -> x509Certificate { return {*this}; } + [[nodiscard]] auto clone() const -> x509Certificate { return {*this}; } // TODO: Provide overloads for password protected certs void load(const std::filesystem::path &pem_file_path); void save(const std::filesystem::path &pem_file_path) const; - // TODO: Get methods - // TODO: Provide overloads for hardcoded time void set_not_before(int offset_day, std::int64_t offset_sec, time_t *in_tm = nullptr); void set_not_after(int offset_day, std::int64_t offset_sec, time_t *in_tm = nullptr); + [[nodiscard]] auto get_not_before() const -> const ASN1_TIME *; + [[nodiscard]] auto get_not_after() const -> const ASN1_TIME *; void set_version(std::int64_t version); - [[nodiscard]] - auto get_version() const -> std::int64_t; + [[nodiscard]] auto get_version() const -> std::int64_t; - void set_serial_number(std::int64_t serial_number); void set_serial_number(std::uint64_t serial_number); void set_serial_number(BIGNUM *serial_number); + [[nodiscard]] auto get_serial_number() const -> ASN1_INTEGER *; void set_public_key(const EVP_PKEY_ptr &pkey); + [[nodiscard]] auto get_pubkey() const -> const EVP_PKEY *; void set_subject_name(const x509Name &name); void set_issuer_name(const x509Name &name); void set_self_signed_name(const x509Name &name); + [[nodiscard]] auto get_issuer_name() const -> x509Name; + [[nodiscard]] auto get_subject_name() const -> x509Name; void add_extension(const x509Extension &ext, int loc = -1); auto delete_extension(int loc) -> x509Extension; - [[nodiscard]] - auto extension_count() const -> int; - - [[nodiscard]] - auto get_extension(std::uint32_t loc) const -> x509Extension; - [[nodiscard]] - auto get_extension_by(int nid, int lastpos = - 1) const -> x509Extension; - [[nodiscard]] - auto get_extension_by(const ASN1_OBJECT *obj, int lastpos = -1) const -> x509Extension; - [[nodiscard]] - auto get_extension_by(bool critical, int lastpos = - 1) const -> x509Extension; - - [[nodiscard]] - auto self_signed(bool verify_signature) const -> bool; - [[nodiscard]] - auto verify(const EVP_PKEY_ptr &pkey) const -> bool; + [[nodiscard]] auto extension_count() const -> int; + + [[nodiscard]] auto get_extension(std::uint32_t loc) const -> x509Extension; + [[nodiscard]] auto get_extension_by(int nid, int lastpos = - 1) const -> x509Extension; + [[nodiscard]] auto get_extension_by(const ASN1_OBJECT *obj, int lastpos = -1) const -> x509Extension; + [[nodiscard]] auto get_extension_by(bool critical, int lastpos = - 1) const -> x509Extension; + + [[nodiscard]] auto self_signed(bool verify_signature) const -> bool; + [[nodiscard]] auto verify(const EVP_PKEY_ptr &pkey) const -> bool; + [[nodiscard]] auto verify() const -> bool; void sign(const EVP_PKEY_ptr &pkey, const EVP_MD *digest = EVP_sha256()); - [[nodiscard]] - auto get() const -> X509 * { return m_ptr.get(); } + [[nodiscard]] auto get() const -> X509 * { return m_ptr.get(); } private: X509_ptr m_ptr; + bool m_own = true; }; using Certificate = x509Certificate; diff --git a/include/CppSockets/Tls/Context.hpp b/include/CppSockets/Tls/Context.hpp new file mode 100644 index 0000000..00d0476 --- /dev/null +++ b/include/CppSockets/Tls/Context.hpp @@ -0,0 +1,48 @@ +/* +** Project LibCppSockets, 2025 +** +** Author Francois Michaut +** +** Started on Wed Aug 20 14:13:44 2025 Francois Michaut +** Last update Thu Aug 21 14:14:45 2025 Francois Michaut +** +** Context.hpp : Context for TLS sockets +*/ + +#pragma once + +#include "CppSockets/Tls/Utils.hpp" + +#include + +namespace CppSockets { + class TlsContext { + public: + TlsContext(); + TlsContext(SSL_CTX_ptr ptr); + TlsContext(SSL_CTX *ptr, bool own = true); + + TlsContext(const TlsContext &other) { *this = other; } + TlsContext(TlsContext &&other) noexcept = default; + auto operator=(const TlsContext &other) -> TlsContext &; + auto operator=(TlsContext &&other) noexcept -> TlsContext & = default; + + ~TlsContext(); + + void set_min_proto_version(int version); + + void set_verify(int mode, TlsVerifyCallback callback = {}); + void set_verify_depth(int depth); + + void set_certificate(std::string_view cert_path, std::string_view pkey_path); + [[nodiscard]] auto check_private_key() const -> bool; + + [[nodiscard]] auto get() const -> SSL_CTX * { return m_ptr.get(); } + private: + SSL_CTX_ptr m_ptr; + bool m_own = true; + TlsVerifyCallback m_verify_callback; + + friend struct TlsContext_Accessor; + }; +} diff --git a/include/CppSockets/TlsSocket.hpp b/include/CppSockets/Tls/Socket.hpp similarity index 64% rename from include/CppSockets/TlsSocket.hpp rename to include/CppSockets/Tls/Socket.hpp index c0d322e..229db95 100644 --- a/include/CppSockets/TlsSocket.hpp +++ b/include/CppSockets/Tls/Socket.hpp @@ -4,7 +4,7 @@ ** Author Francois Michaut ** ** Started on Wed Sep 14 20:51:23 2022 Francois Michaut -** Last update Tue Aug 5 00:01:16 2025 Francois Michaut +** Last update Wed Aug 20 23:11:28 2025 Francois Michaut ** ** SecureSocket.hpp : TLS socket wrapper using openssl */ @@ -12,9 +12,9 @@ #pragma once #include "CppSockets/OSDetection.hpp" - -#include "CppSockets/SSL_Utils.hpp" #include "CppSockets/Socket.hpp" +#include "CppSockets/Tls/Context.hpp" +#include "CppSockets/Tls/Utils.hpp" namespace CppSockets { // TODO add more TLS-related functions @@ -23,7 +23,8 @@ namespace CppSockets { TlsSocket() = default; // TODO: Constructor allowing application to reuse SSL_CTX objects // (Maybe even a different TLS_CTX class to manage them ?) - TlsSocket(int domain, int type, int protocol); + TlsSocket(int domain, int type, int protocol, TlsContext ctx = {}); + explicit TlsSocket(Socket &&other, TlsContext ctx = {}); explicit TlsSocket(Socket &&other, SSL_ptr ssl = nullptr); explicit TlsSocket(RawSocketType fd, SSL_ptr ssl = nullptr); ~TlsSocket() noexcept; @@ -35,7 +36,6 @@ namespace CppSockets { auto read(std::size_t len = -1) -> std::string; auto read(char *buff, std::size_t size) -> std::size_t; - auto write(const std::string &buff) -> std::size_t { return this->write(buff.c_str(), buff.size()); } auto write(std::string_view buff) -> std::size_t { return this->write(buff.data(), buff.size()); }; auto write(const char *buff, std::size_t len) -> std::size_t; @@ -43,24 +43,19 @@ namespace CppSockets { void set_certificate(const std::string &cert_path, const std::string &pkey_path); auto connect(const IEndpoint &endpoint) -> int; - auto accept(void *addr_out = nullptr, const SSL_CTX_ptr &ctx = nullptr) -> std::unique_ptr; - auto accept(const SSL_CTX_ptr &ctx) -> std::unique_ptr; + auto accept(void *addr_out, TlsContext ctx) -> std::unique_ptr; + auto accept(void *addr_out = nullptr) -> std::unique_ptr { return accept(addr_out, m_ctx); } + auto accept(TlsContext ctx) -> std::unique_ptr { return accept(nullptr, std::move(ctx)); } - [[nodiscard]] - auto get_ssl_ctx() const -> const SSL_CTX_ptr & { return m_ctx; } - [[nodiscard]] - auto get_ssl() const -> const SSL_ptr & { return m_ssl; } - [[nodiscard]] - auto get_client_cert() const -> const X509_ptr & { return m_peer_cert; } + [[nodiscard]] auto get_ssl_ctx() const -> TlsContext; + [[nodiscard]] auto get_ssl() const -> const SSL_ptr & { return m_ssl; } + [[nodiscard]] auto get_peer_cert() const -> const X509_ptr & { return m_peer_cert; } - [[nodiscard]] - auto tls_strerror(int ret) -> std::string; + [[nodiscard]] auto tls_strerror(int ret) -> std::string; private: - SSL_CTX_ptr m_ctx; + TlsContext m_ctx; SSL_ptr m_ssl; X509_ptr m_peer_cert; - X509_ptr m_cert; - EVP_PKEY_ptr m_pkey; void check_for_error(const std::string &error_msg, int ret); }; diff --git a/include/CppSockets/Tls/Utils.hpp b/include/CppSockets/Tls/Utils.hpp new file mode 100644 index 0000000..b73eaa4 --- /dev/null +++ b/include/CppSockets/Tls/Utils.hpp @@ -0,0 +1,61 @@ +/* +** Project LibCppSockets, 2025 +** +** Author Francois Michaut +** +** Started on Fri Aug 1 09:54:53 2025 Francois Michaut +** Last update Wed Aug 20 16:49:57 2025 Francois Michaut +** +** Utils.hpp : Tls Utility types +*/ + +#pragma once + +#include + +#include +#include + +#define CPP_SOCKETS_TLS_UTILS_DEFINE_DTOR(TYPE, PREFIX) \ + struct TYPE##_dtor { \ + void operator()(TYPE *ptr) { PREFIX##_free(ptr); } \ + }; + +#define CPP_SOCKETS_TLS_UTILS_DEFINE_PTR_CMD_PREFIX(TYPE, PREFIX) \ + CPP_SOCKETS_TLS_UTILS_DEFINE_DTOR(TYPE, PREFIX) \ + using TYPE##_ptr = std::unique_ptr; + +#define CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(TYPE) \ + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR_CMD_PREFIX(TYPE, TYPE) + +namespace CppSockets { + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(BIO) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR_CMD_PREFIX(BIGNUM, BN) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(SSL_CTX) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(SSL) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(X509) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(X509_NAME) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(X509_NAME_ENTRY) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(X509_EXTENSION) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(ASN1_STRING) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(EVP_PKEY) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(EVP_MD) + CPP_SOCKETS_TLS_UTILS_DEFINE_PTR(EVP_MD_CTX) + + void throw_openssl_error(); + auto check_or_throw_openssl_error(int ret) -> int; + + template + auto check_or_throw_openssl_error(T *ret) -> T * { + if (ret == nullptr) { + throw_openssl_error(); + } + return ret; + } + + using TlsVerifyCallback=std::function; +} + +// Don't leak macros +#undef CPP_SOCKETS_TLS_UTILS_DEFINE_DTOR +#undef CPP_SOCKETS_TLS_UTILS_DEFINE_PTR diff --git a/include/CppSockets/SocketInit.hpp b/include/CppSockets/internal/SocketInit.hpp similarity index 86% rename from include/CppSockets/SocketInit.hpp rename to include/CppSockets/internal/SocketInit.hpp index caee1b6..fa30590 100644 --- a/include/CppSockets/SocketInit.hpp +++ b/include/CppSockets/internal/SocketInit.hpp @@ -4,15 +4,14 @@ ** Author Francois Michaut ** ** Started on Sat Jan 15 01:17:42 2022 Francois Michaut -** Last update Tue Nov 14 19:37:59 2023 Francois Michaut +** Last update Wed Aug 20 13:04:47 2025 Francois Michaut ** ** SocketInit.hpp : Socket class automatic initialization and teardown */ // Inspired from https://stackoverflow.com/questions/64753466/how-do-i-automatically-implicitly-create-a-instance-of-a-class-at-program-launch/64754436#64754436 -namespace CppSockets -{ +namespace CppSockets { class SocketInit { struct Cleanup { ~Cleanup(); diff --git a/private/CppSockets/SslMacros.hpp b/private/CppSockets/SslMacros.hpp new file mode 100644 index 0000000..17c9080 --- /dev/null +++ b/private/CppSockets/SslMacros.hpp @@ -0,0 +1,51 @@ +/* +** Project LibCppSockets, 2025 +** +** Author Francois Michaut +** +** Started on Wed Aug 20 16:54:02 2025 Francois Michaut +** Last update Wed Aug 20 18:59:18 2025 Francois Michaut +** +** SslMacros.hpp : Private Macros to define SSL wrappers +*/ + +#define REQUIRED_PTR(ptr, name) \ + if (!ptr) { \ + throw std::runtime_error("Failed to create " name); \ + } + +#define ASSIGNMENT_OPERATOR(type) \ + if (this == &other) { \ + return *this; \ + } \ + \ + type *dup = type##_dup(other.m_ptr.get()); \ + \ + if (dup == nullptr) { \ + throw std::runtime_error("Failed to dup ##type##"); \ + } \ + if (!this->m_own) { \ + (void)this->m_ptr.release(); \ + } \ + this->m_ptr.reset(dup); \ + this->m_own = true; \ + return *this; \ + +#define UP_REF_ASSIGNMENT_OPERATOR(type) \ + if (this == &other) { \ + return *this; \ + } \ + \ + if (!this->m_own) { \ + (void)this->m_ptr.release(); \ + } \ + this->m_ptr.reset(other.m_ptr.get()); \ + this->m_own = false; \ + return *this; \ + +#define MAKE_DESTRUCTOR(klass) \ + klass::~klass() { \ + if (!m_own) { \ + (void)m_ptr.release(); \ + } \ + } diff --git a/source/Address.cpp b/source/Address.cpp index 76a26d1..8f5d680 100644 --- a/source/Address.cpp +++ b/source/Address.cpp @@ -4,7 +4,7 @@ ** Author Francois Michaut ** ** Started on Sun Feb 13 22:03:32 2022 Francois Michaut -** Last update Sun Aug 3 21:59:36 2025 Francois Michaut +** Last update Wed Aug 20 12:58:08 2025 Francois Michaut ** ** Address.cpp : Implementation of generic Address classes & functions */ @@ -12,7 +12,7 @@ #include "CppSockets/Address.hpp" namespace CppSockets { - auto IEndpoint::makeString() const -> std::string { - return this->getAddr().toString() + ":" + std::to_string(this->getPort()); + auto IEndpoint::make_string() const -> std::string { + return this->get_addr().to_string() + ":" + std::to_string(this->get_port()); } } diff --git a/source/IPv4.cpp b/source/IPv4.cpp index 45a0dbf..8930244 100644 --- a/source/IPv4.cpp +++ b/source/IPv4.cpp @@ -4,7 +4,7 @@ ** Author Francois Michaut ** ** Started on Sun Feb 13 18:52:28 2022 Francois Michaut -** Last update Tue Aug 5 01:43:47 2025 Francois Michaut +** Last update Wed Aug 20 12:58:26 2025 Francois Michaut ** ** IPv4.cpp : Implementation of IPv4 class */ @@ -41,15 +41,15 @@ namespace CppSockets { this->addr = address.s_addr; } - auto IPv4::getAddress() const -> std::uint32_t { + auto IPv4::get_address() const -> std::uint32_t { return addr; } - auto IPv4::toString() const -> const std::string & { + auto IPv4::to_string() const -> const std::string & { return str; } - auto IPv4::getFamily() const -> int { + auto IPv4::get_family() const -> int { return AF_INET; } } diff --git a/source/Socket.cpp b/source/Socket.cpp index 49aee5e..074d2e7 100644 --- a/source/Socket.cpp +++ b/source/Socket.cpp @@ -4,7 +4,7 @@ ** Author Francois Michaut ** ** Started on Sat Jan 15 01:27:40 2022 Francois Michaut -** Last update Tue Aug 5 14:46:12 2025 Francois Michaut +** Last update Wed Aug 20 12:59:14 2025 Francois Michaut ** ** Socket.cpp : Protable C++ socket class implementation */ @@ -55,7 +55,8 @@ namespace CppSockets { // TODO: more error handling arround is_connected == false and sockfd == INVALID in IO calls Socket::Socket() : - m_sockfd(INVALID_SOCKET) + m_sockfd(INVALID_SOCKET) // TODO: Currently using this constructor results in an UNSUABLE socket. + // That's bad -> fix it. {} Socket::Socket(int domain, int type, int protocol) : @@ -65,24 +66,20 @@ namespace CppSockets { throw std::runtime_error(std::string("Failed to create socket : ") + std::strerror(errno)); } + Socket::~Socket() { + close(); + } + Socket::Socket(Socket &&other) noexcept : m_sockfd(INVALID_SOCKET) { *this = std::move(other); } - Socket::~Socket() { - close(); - } - auto Socket::operator=(Socket &&other) noexcept -> Socket & { - if (&other == this) - return *this; - this->close(); + std::swap(m_sockfd, other.m_sockfd); - m_sockfd = other.m_sockfd; m_domain = other.m_domain; - other.m_sockfd = INVALID_SOCKET; m_is_connected = other.m_is_connected; return *this; } @@ -190,7 +187,7 @@ namespace CppSockets { auto Socket::bind(const IEndpoint &endpoint) -> int { // TODO: this only works for IPv4. Need to switch getFamily() to handle // IPv6 / AF_UNIX ... - return this->bind(endpoint.getAddr().getAddress(), endpoint.getPort()); + return this->bind(endpoint.get_addr().get_address(), endpoint.get_port()); } auto Socket::bind(std::uint32_t source_addr, uint16_t port) -> int { // NOLINT(readability-make-member-function-const) @@ -215,12 +212,13 @@ namespace CppSockets { struct sockaddr_in addr = {0}; int ret = 0; - addr.sin_addr.s_addr = endpoint.getAddr().getAddress(); - addr.sin_port = htons(endpoint.getPort()); - addr.sin_family = endpoint.getAddr().getFamily(); + addr.sin_addr.s_addr = endpoint.get_addr().get_address(); + addr.sin_port = htons(endpoint.get_port()); + addr.sin_family = endpoint.get_addr().get_family(); + // TODO: If connected close / reconnect ret = ::connect(m_sockfd, reinterpret_cast(&addr), sizeof(addr)); if (ret < 0) - throw std::runtime_error(std::string("Failed to connect socket to ") + endpoint.toString() + " : " + Socket::strerror()); + throw std::runtime_error(std::string("Failed to connect socket to ") + endpoint.to_string() + " : " + Socket::strerror()); m_is_connected = ret == 0; return ret; } diff --git a/source/SocketInit.cpp b/source/SocketInit.cpp index 22b1a2a..1e88909 100644 --- a/source/SocketInit.cpp +++ b/source/SocketInit.cpp @@ -4,13 +4,13 @@ ** Author Francois Michaut ** ** Started on Thu Sep 15 14:24:25 2022 Francois Michaut -** Last update Tue May 9 23:34:46 2023 Francois Michaut +** Last update Wed Aug 20 14:00:59 2025 Francois Michaut ** ** init.cpp : Startup/Cleanup functions implementation */ #include "CppSockets/OSDetection.hpp" -#include "CppSockets/SocketInit.hpp" +#include "CppSockets/internal/SocketInit.hpp" #include #include diff --git a/source/Certificate.cpp b/source/Tls/Certificate.cpp similarity index 81% rename from source/Certificate.cpp rename to source/Tls/Certificate.cpp index 107142b..9b32990 100644 --- a/source/Certificate.cpp +++ b/source/Tls/Certificate.cpp @@ -4,37 +4,21 @@ ** Author Francois Michaut ** ** Started on Sat Aug 2 22:41:35 2025 Francois Michaut -** Last update Tue Aug 5 19:12:12 2025 Francois Michaut +** Last update Wed Aug 20 17:05:49 2025 Francois Michaut ** ** Certificate.cpp : Implementation of classes to create and manage Certificates */ -#include "CppSockets/Certificate.hpp" -#include "CppSockets/SSL_Utils.hpp" +#include "CppSockets/Tls/Certificate.hpp" +#include "CppSockets/Tls/Utils.hpp" + +#include "CppSockets/SslMacros.hpp" #include #include #include #include -#define REQUIRED_PTR(ptr, name) \ - if (!ptr) { \ - throw std::runtime_error("Failed to create " name); \ - } - -#define ASSIGNMENT_OPERATOR(type) \ - if (this == &other) { \ - return *this; \ - } \ - \ - type *dup = type##_dup(other.m_ptr.get()); \ - \ - if (dup == nullptr) { \ - throw std::runtime_error("Failed to dup ##type##"); \ - } \ - this->m_ptr.reset(dup); \ - return *this; \ - namespace { template inline auto numeric_cast(const Src value) -> Dst { @@ -61,6 +45,14 @@ namespace CppSockets { REQUIRED_PTR(m_ptr, "X509_NAME") } + x509Name::x509Name(X509_NAME *ptr, bool own) : + m_ptr(ptr), m_own(own) + { + REQUIRED_PTR(m_ptr, "X509_NAME") + } + + MAKE_DESTRUCTOR(x509Name) + auto x509Name::operator=(const x509Name &other) -> x509Name & { ASSIGNMENT_OPERATOR(X509_NAME) } @@ -93,10 +85,22 @@ namespace CppSockets { return X509_NAME_entry_count(m_ptr.get()); } - auto x509Name::get_entry(int loc) const -> x509NameEntry { - X509_NAME_ENTRY_ptr ptr{check_or_throw_openssl_error(X509_NAME_get_entry(m_ptr.get(), loc))}; + auto x509Name::get_entry_by_index(int idx) const -> x509NameEntry { + X509_NAME_ENTRY *ptr{check_or_throw_openssl_error(X509_NAME_get_entry(m_ptr.get(), idx))}; - return {std::move(ptr)}; + return {ptr, false}; + } + + auto x509Name::get_entry(int nid, int lastpos) const -> x509NameEntry { + int index = get_index(nid, lastpos); + + return get_entry_by_index(index); + } + + auto x509Name::get_entry(const ASN1_OBJECT *obj, int lastpos) const -> x509NameEntry { + int index = get_index(obj, lastpos); + + return get_entry_by_index(index); } auto x509Name::delete_entry(int loc) -> x509NameEntry { @@ -128,6 +132,12 @@ namespace CppSockets { REQUIRED_PTR(m_ptr, "X509_NAME_ENTRY") } + x509NameEntry::x509NameEntry(X509_NAME_ENTRY *ptr, bool own) : + m_ptr(ptr), m_own(own) + { + REQUIRED_PTR(m_ptr, "X509_NAME_ENTRY") + } + x509NameEntry::x509NameEntry(const std::string &name, int type, const std::u8string &data) : m_ptr(X509_NAME_ENTRY_create_by_txt(nullptr, name.c_str(), type, reinterpret_cast(data.c_str()), numeric_cast(data.size()))) { @@ -146,6 +156,8 @@ namespace CppSockets { REQUIRED_PTR(m_ptr, "X509_NAME_ENTRY") } + MAKE_DESTRUCTOR(x509NameEntry) + auto x509NameEntry::operator=(const x509NameEntry &other) -> x509NameEntry & { ASSIGNMENT_OPERATOR(X509_NAME_ENTRY) } @@ -162,11 +174,11 @@ namespace CppSockets { check_or_throw_openssl_error(ret); } - auto x509NameEntry::get_object() const -> ASN1_OBJECT * { + auto x509NameEntry::get_object() const -> const ASN1_OBJECT * { return check_or_throw_openssl_error(X509_NAME_ENTRY_get_object(m_ptr.get())); } - auto x509NameEntry::get_data() const -> ASN1_STRING * { + auto x509NameEntry::get_data() const -> const ASN1_STRING * { return check_or_throw_openssl_error(X509_NAME_ENTRY_get_data(m_ptr.get())); } } @@ -185,6 +197,12 @@ namespace CppSockets { REQUIRED_PTR(m_ptr, "X509_EXTENSION") } + x509Extension::x509Extension(X509_EXTENSION *ptr, bool own) : + m_ptr(ptr), m_own(own) + { + REQUIRED_PTR(m_ptr, "X509_EXTENSION") + } + x509Extension::x509Extension(int nid, int crit, ASN1_OCTET_STRING *data) : m_ptr(X509_EXTENSION_create_by_NID(nullptr, nid, crit, data)) { @@ -197,6 +215,8 @@ namespace CppSockets { REQUIRED_PTR(m_ptr, "X509_EXTENSION") } + MAKE_DESTRUCTOR(x509Extension) + auto x509Extension::operator=(const x509Extension &other) -> x509Extension & { ASSIGNMENT_OPERATOR(X509_EXTENSION) } @@ -234,12 +254,20 @@ namespace CppSockets { REQUIRED_PTR(m_ptr, "X509") } - x509Certificate::x509Certificate(X509_ptr x509) : - m_ptr(std::move(x509)) + x509Certificate::x509Certificate(X509_ptr ptr) : + m_ptr(std::move(ptr)) + { + REQUIRED_PTR(m_ptr, "X509") + } + + x509Certificate::x509Certificate(X509 *ptr, bool own) : + m_ptr(ptr), m_own(own) { REQUIRED_PTR(m_ptr, "X509") } + MAKE_DESTRUCTOR(x509Certificate) + x509Certificate::x509Certificate(const std::filesystem::path &pem_file_path) { load(pem_file_path); } @@ -273,6 +301,30 @@ namespace CppSockets { } } + auto x509Certificate::get_not_before() const -> const ASN1_TIME * { + return X509_get0_notBefore(m_ptr.get()); + } + + auto x509Certificate::get_not_after() const -> const ASN1_TIME * { + return X509_get0_notAfter(m_ptr.get()); + } + + auto x509Certificate::get_serial_number() const -> ASN1_INTEGER * { + return X509_get_serialNumber(m_ptr.get()); + } + + auto x509Certificate::get_pubkey() const -> const EVP_PKEY * { + return check_or_throw_openssl_error(X509_get0_pubkey(m_ptr.get())); + } + + auto x509Certificate::get_issuer_name() const -> x509Name { + return {X509_get_issuer_name(m_ptr.get()), false}; + } + + auto x509Certificate::get_subject_name() const -> x509Name { + return {X509_get_subject_name(m_ptr.get()), false}; + } + void x509Certificate::set_not_before(int offset_day, std::int64_t offset_sec, time_t *in_tm) { ASN1_TIME *not_before = X509_getm_notBefore(m_ptr.get()); @@ -289,7 +341,6 @@ namespace CppSockets { } } - void x509Certificate::set_version(std::int64_t version) { if (!X509_set_version(m_ptr.get(), version)) { throw std::runtime_error("Failed to set version"); @@ -300,14 +351,6 @@ namespace CppSockets { return X509_get_version(m_ptr.get()); } - void x509Certificate::set_serial_number(int64_t serial_number) { - ASN1_INTEGER *ptr = X509_get_serialNumber(m_ptr.get()); - - if (!ASN1_INTEGER_set_int64(ptr, serial_number)) { - throw std::runtime_error("Failed to set serial number"); - } - } - void x509Certificate::set_serial_number(uint64_t serial_number) { ASN1_INTEGER *ptr = X509_get_serialNumber(m_ptr.get()); @@ -367,12 +410,12 @@ namespace CppSockets { } auto x509Certificate::get_extension(std::uint32_t loc) const -> x509Extension { - X509_EXTENSION_ptr ptr {X509_get_ext(m_ptr.get(), numeric_cast(loc))}; + X509_EXTENSION *ptr {X509_get_ext(m_ptr.get(), numeric_cast(loc))}; if (!ptr) { throw std::runtime_error("Failed to get extension"); } - return {std::move(ptr)}; + return {ptr, false}; } auto x509Certificate::get_extension_by(int nid, int lastpos) const -> x509Extension { @@ -408,7 +451,7 @@ namespace CppSockets { if (ret < 0) { throw std::runtime_error("Failed to check certificate self-signed"); } - return ret; + return ret == 1; } auto x509Certificate::verify(const EVP_PKEY_ptr &pkey) const -> bool { @@ -417,7 +460,16 @@ namespace CppSockets { if (ret < 0) { throw std::runtime_error("Failed to check certificate signature"); } - return ret; + return ret == 1; + } + + auto x509Certificate::verify() const -> bool { + auto ret = X509_verify(m_ptr.get(), X509_get0_pubkey(m_ptr.get())); + + if (ret < 0) { + throw std::runtime_error("Failed to check certificate signature"); + } + return ret == 1; } void x509Certificate::sign(const EVP_PKEY_ptr &pkey, const EVP_MD *digest) { diff --git a/source/Tls/Context.cpp b/source/Tls/Context.cpp new file mode 100644 index 0000000..ebd5f76 --- /dev/null +++ b/source/Tls/Context.cpp @@ -0,0 +1,113 @@ +/* +** Project LibCppSockets, 2025 +** +** Author Francois Michaut +** +** Started on Wed Aug 20 14:40:41 2025 Francois Michaut +** Last update Wed Aug 20 18:58:53 2025 Francois Michaut +** +** Context.cpp : Implementation of the Context for TLS sockets +*/ + +#include "CppSockets/Tls/Context.hpp" + +#include "CppSockets/SslMacros.hpp" + +#include +#include +#include + +#include + +namespace CppSockets { + // TODO: Free this ? + const int TLS_CONTEXT_IDX = SSL_CTX_get_ex_new_index(0, (void *)"TlsContent index", nullptr, nullptr, nullptr); + + struct TlsContext_Accessor { + static auto get_function(TlsContext &ctx) -> TlsVerifyCallback & { + return ctx.m_verify_callback; + } + }; + + static auto tls_context_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) -> int { + SSL *ssl = static_cast(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); + SSL_CTX *ssl_ctx = SSL_get_SSL_CTX(ssl); + + auto *tls_ctx = static_cast(SSL_CTX_get_ex_data(ssl_ctx, TLS_CONTEXT_IDX)); + + return CppSockets::TlsContext_Accessor::get_function(*tls_ctx)(preverify_ok, ctx); + } +} + +#define TLS_CONTEXT_CONSTRUCTOR_BODY \ + REQUIRED_PTR(m_ptr, "SSL_CTX"); \ + SSL_CTX_set_ex_data(m_ptr.get(), TLS_CONTEXT_IDX, this); \ + set_min_proto_version(TLS1_1_VERSION); + +namespace CppSockets { + TlsContext::TlsContext() : + m_ptr(SSL_CTX_new(TLS_method())) + { + TLS_CONTEXT_CONSTRUCTOR_BODY; + } + + TlsContext::TlsContext(SSL_CTX_ptr ptr) : + m_ptr(std::move(ptr)) + { + TLS_CONTEXT_CONSTRUCTOR_BODY; + } + + TlsContext::TlsContext(SSL_CTX *ptr, bool own) : + m_ptr(ptr), m_own(own) + { + TLS_CONTEXT_CONSTRUCTOR_BODY; + } + + auto TlsContext::operator=(const TlsContext &other) -> TlsContext & { + UP_REF_ASSIGNMENT_OPERATOR(SSL_CTX) + } + + MAKE_DESTRUCTOR(TlsContext) + + void TlsContext::set_min_proto_version(int version) { + if (!SSL_CTX_set_min_proto_version(m_ptr.get(), version)) { + throw std::runtime_error("Failed to set TlsProtocol Version"); + } + } + + void TlsContext::set_verify(int mode, TlsVerifyCallback callback) { + if (!m_own) { + // We would loose the TlsVerifyCallback if the TlsContext goes out of scope, but the SSL_CTX + // hasn't been freed yet, which would result in a crash if the callback was then requested + throw std::runtime_error("Can't set_verify on a non-owned TlsContext. Use the Raw SSL_CTX methods"); + } + m_verify_callback = std::move(callback); + SSL_verify_cb raw_cb = m_verify_callback ? tls_context_verify_callback : nullptr; + + SSL_CTX_set_verify(m_ptr.get(), mode, raw_cb); + } + + void TlsContext::set_verify_depth(int depth) { + SSL_CTX_set_verify_depth(m_ptr.get(), depth); + } + + void TlsContext::set_certificate(std::string_view cert_path, std::string_view pkey_path) { + BIO_ptr cert(BIO_new_file(cert_path.data(), "r")); + BIO_ptr pkey(BIO_new_file(pkey_path.data(), "r")); + // TODO: handle pkey password: SSL_CTX_set_default_passwd_cb or PEM_read_bio_X509 last 2 args + X509_ptr x509(PEM_read_bio_X509(cert.get(), nullptr, nullptr, nullptr)); + EVP_PKEY_ptr evp_pkey(PEM_read_bio_PrivateKey(pkey.get(), nullptr, nullptr, nullptr)); + + if (SSL_CTX_use_certificate(m_ptr.get(), x509.get()) <= 0) { + throw std::runtime_error("Failed to set certificate"); + } + + if (SSL_CTX_use_PrivateKey(m_ptr.get(), evp_pkey.get()) <= 0 ) { + throw std::runtime_error("Failed to set private key"); + } + } + + auto TlsContext::check_private_key() const -> bool { + return SSL_CTX_check_private_key(m_ptr.get()) == 1; + } +} diff --git a/source/TlsSocket.cpp b/source/Tls/Socket.cpp similarity index 69% rename from source/TlsSocket.cpp rename to source/Tls/Socket.cpp index c49fb83..51db942 100644 --- a/source/TlsSocket.cpp +++ b/source/Tls/Socket.cpp @@ -4,13 +4,15 @@ ** Author Francois Michaut ** ** Started on Wed Sep 14 21:04:42 2022 Francois Michaut -** Last update Tue Aug 5 13:49:19 2025 Francois Michaut +** Last update Wed Aug 20 23:12:24 2025 Francois Michaut ** ** SecureSocket.cpp : TLS socket wrapper implementation */ #include "CppSockets/OSDetection.hpp" -#include "CppSockets/TlsSocket.hpp" +#include "CppSockets/Tls/Context.hpp" +#include "CppSockets/Tls/Socket.hpp" +#include "CppSockets/Tls/Utils.hpp" #include #include @@ -36,15 +38,11 @@ namespace { return ss.str(); } - void init_ssl_socket(SSL *ssl, SSL_CTX *ctx, CppSockets::TlsSocket *socket) { - int success = 1; - - if (!ctx || !ssl || !SSL_set_fd(ssl, socket->get_fd())) { + void init_ssl_socket(SSL *ssl, CppSockets::TlsSocket *socket) { + if (!ssl || !SSL_set_fd(ssl, socket->get_fd())) { throw std::runtime_error(std::string("Failed to initialize TLS socket: ") + socket->tls_strerror(0)); } - success = success && SSL_CTX_set_min_proto_version(ctx, TLS1_VERSION); - success = success && SSL_set_min_proto_version(ssl, TLS1_VERSION); - if (!success) { + if (!SSL_set_min_proto_version(ssl, TLS1_VERSION)) { throw std::runtime_error(std::string("Failed to select TLS version: ") + socket->tls_strerror(0)); } } @@ -54,37 +52,27 @@ namespace { namespace CppSockets { // TODO check if base destroctor is called (need to close the socket if error is raised) // TODO check if needs to call SSL_shutdown in such cases - TlsSocket::TlsSocket(int domain, int type, int protocol) : + TlsSocket::TlsSocket(int domain, int type, int protocol, TlsContext ctx) : CppSockets::Socket(domain, type, protocol), - m_ctx(SSL_CTX_new(TLS_method())), - m_ssl((m_ctx ? SSL_new(m_ctx.get()) : nullptr)), - m_peer_cert(nullptr), - m_cert(nullptr), - m_pkey(nullptr) + m_ctx(std::move(ctx)), m_ssl((SSL_new(m_ctx.get()))), m_peer_cert(nullptr) { - init_ssl_socket(m_ssl.get(), m_ctx.get(), this); + init_ssl_socket(m_ssl.get(), this); } - // TlsSocket::TlsSocket(RawSocketType fd, SSL_ptr ssl) : - // CppSockets::Socket(fd), - // m_ctx((ssl ? SSL_get_SSL_CTX(ssl.get()) : SSL_CTX_new(TLS_method())), SSL_CTX_free), - // m_ssl(ssl ? std::move(ssl) : (m_ctx ? SSL_ptr(SSL_new(m_ctx.get()), SSL_free) : nullptr)), - // m_peer_cert(nullptr, X509_free), - // m_cert(nullptr, X509_free), - // m_pkey(nullptr, EVP_PKEY_free) - // { - // init_ssl_socket(m_ssl.get(), m_ctx.get(), this); - // } + TlsSocket::TlsSocket(Socket &&other, TlsContext ctx) : + CppSockets::Socket(std::move(other)), // TODO: if socket is not connected, at that moment, does it break ? + m_ctx(std::move(ctx)), m_ssl(SSL_ptr(SSL_new(m_ctx.get()))), m_peer_cert(nullptr) + { + init_ssl_socket(m_ssl.get(), this); + } TlsSocket::TlsSocket(Socket &&other, SSL_ptr ssl) : CppSockets::Socket(std::move(other)), // TODO: if socket is not connected, at that moment, does it break ? - m_ctx((ssl ? SSL_get_SSL_CTX(ssl.get()) : SSL_CTX_new(TLS_method()))), - m_ssl(ssl ? std::move(ssl) : (m_ctx ? SSL_ptr(SSL_new(m_ctx.get())) : nullptr)), - m_peer_cert(nullptr), - m_cert(nullptr), - m_pkey(nullptr) + m_ctx({ssl ? SSL_get_SSL_CTX(ssl.get()) : SSL_CTX_new(TLS_method()), !ssl}), + m_ssl(ssl ? std::move(ssl) : SSL_ptr(SSL_new(m_ctx.get()))), + m_peer_cert(nullptr) { - init_ssl_socket(m_ssl.get(), m_ctx.get(), this); + init_ssl_socket(m_ssl.get(), this); } TlsSocket::~TlsSocket() noexcept { @@ -106,16 +94,14 @@ namespace CppSockets { TlsSocket::TlsSocket(TlsSocket &&other) noexcept : Socket(std::move(other)), m_ctx(std::move(other.m_ctx)), - m_ssl(std::move(other.m_ssl)), m_peer_cert(std::move(other.m_peer_cert)), - m_cert(std::move(other.m_cert)), m_pkey(std::move(other.m_pkey)) + m_ssl(std::move(other.m_ssl)), m_peer_cert(std::move(other.m_peer_cert)) {} auto TlsSocket::operator=(TlsSocket &&other) noexcept -> TlsSocket & { + std::swap(m_ssl, other.m_ssl); + m_ctx = std::move(other.m_ctx); - m_ssl = std::move(other.m_ssl); m_peer_cert = std::move(other.m_peer_cert); - m_cert = std::move(other.m_cert); - m_pkey = std::move(other.m_pkey), Socket::operator=(std::move(other)); return *this; @@ -130,21 +116,17 @@ namespace CppSockets { void TlsSocket::set_certificate(const std::string &cert_path, const std::string &pkey_path) { BIO_ptr cert(BIO_new_file(cert_path.c_str(), "r")); BIO_ptr pkey(BIO_new_file(pkey_path.c_str(), "r")); - // TODO: handle pkey password: SSL_CTX_set_default_passwd_cb or PEM_read_bio_X509 last 2 args + // TODO: handle pkey password: SSL_set_default_passwd_cb or PEM_read_bio_X509 last 2 args X509_ptr x509(PEM_read_bio_X509(cert.get(), nullptr, nullptr, nullptr)); EVP_PKEY_ptr evp_pkey(PEM_read_bio_PrivateKey(pkey.get(), nullptr, nullptr, nullptr)); - // TODO: While setting it on the CTX makes sense imo (since accepted sockets will inherit this), an application - // might not want that behavior. Need to provide alertnate ways to set certificate on CTX vs SSL - if (SSL_CTX_use_certificate(m_ctx.get(), x509.get()) <= 0) { + if (SSL_use_certificate(m_ssl.get(), x509.get()) <= 0) { throw std::runtime_error(std::string("Failed to set certificate: ") + TlsSocket::tls_strerror(0)); } - m_cert = std::move(x509); - if (SSL_CTX_use_PrivateKey(m_ctx.get(), evp_pkey.get()) <= 0 ) { + if (SSL_use_PrivateKey(m_ssl.get(), evp_pkey.get()) <= 0 ) { throw std::runtime_error(std::string("Failed to set private key: ") + TlsSocket::tls_strerror(0)); } - m_pkey = std::move(evp_pkey); } // TODO add SSL_get_shutdown checks in read operations @@ -209,40 +191,28 @@ namespace CppSockets { return ret; } - auto TlsSocket::accept(const SSL_CTX_ptr &ctx) -> std::unique_ptr { - return accept(nullptr, ctx); - } - - auto TlsSocket::accept(void *addr_out, const SSL_CTX_ptr &ctx) -> std::unique_ptr { + auto TlsSocket::accept(void *addr_out, TlsContext ctx) -> std::unique_ptr { std::unique_ptr res = Socket::accept(addr_out); std::unique_ptr tls; int ssl_ret = 0; - SSL_CTX *raw_ctx = ctx.get(); if (!res) { return nullptr; } - if (!raw_ctx) { - raw_ctx = m_ctx.get(); - } - - // TODO: Not really sure we should do this. The TlsSockets shouldn't own the CTX. - // But since its already ref-counted, a shared_ptr doesnt make sense... - // Currently required cause the Constructor will get the CTX from the SSL object, and place it in a unique_ptr. - // Which means if will call CTX_free on TlsSocket destroyed. So we need to up_ref so it doesnt free for real. - if (SSL_CTX_up_ref(raw_ctx) == 0) { - throw std::runtime_error("Failed to DUP SSL_CTX: " + TlsSocket::tls_strerror(0)); - } - tls = std::make_unique(std::move(*res.release()), SSL_ptr(SSL_new(raw_ctx))); + tls = std::make_unique(std::move(*res), std::move(ctx)); ssl_ret = SSL_accept(tls->get_ssl().get()); if (ssl_ret <= 0) { throw std::runtime_error("Failed to accept TLS connection: " + TlsSocket::tls_strerror(ssl_ret)); } - tls->m_peer_cert.reset(SSL_get_peer_certificate(m_ssl.get())); + tls->m_peer_cert.reset(SSL_get_peer_certificate(tls->get_ssl().get())); return tls; } + auto TlsSocket::get_ssl_ctx() const -> TlsContext { + return {SSL_get_SSL_CTX(m_ssl.get()), false}; + } + auto TlsSocket::tls_strerror(int ret) -> std::string { int err = SSL_get_error(m_ssl.get(), ret); diff --git a/source/SSL_Utils.cpp b/source/Tls/Utils.cpp similarity index 92% rename from source/SSL_Utils.cpp rename to source/Tls/Utils.cpp index ac42fa9..de69f92 100644 --- a/source/SSL_Utils.cpp +++ b/source/Tls/Utils.cpp @@ -4,12 +4,12 @@ ** Author Francois Michaut ** ** Started on Sun Aug 3 20:36:03 2025 Francois Michaut -** Last update Sun Aug 3 22:12:04 2025 Francois Michaut +** Last update Wed Aug 20 14:12:29 2025 Francois Michaut ** ** SSL_Utils.cpp : SSL Utility implementations */ -#include "CppSockets/SSL_Utils.hpp" +#include "CppSockets/Tls/Utils.hpp" #include