diff --git a/README.md b/README.md index 98bb187f..1aa34f91 100644 --- a/README.md +++ b/README.md @@ -40,12 +40,16 @@ cd .. #### Run server and client examples + ### WS + ```sh ./build/ws_examples ``` +See also: [Simple-WebSocket-Sample](sample/README.md) + ### WSS Before running the WSS-examples, an RSA private key (server.key) and an SSL certificate (server.crt) must be created. Follow, for instance, the instructions given here (for a self-signed certificate): http://www.akadia.com/services/ssh_test_certificate.html @@ -54,3 +58,5 @@ Then: ``` ./build/wss_examples ``` + + diff --git a/client_ws.hpp b/client_ws.hpp index 3580ea7e..33514028 100644 --- a/client_ws.hpp +++ b/client_ws.hpp @@ -199,11 +199,11 @@ namespace SimpleWeb { size_t num_bytes; if(length > 0xffff) { num_bytes = 8; - send_stream->put(static_cast(127 + 128)); + send_stream->put(static_cast(/*127 + 128*/ 0xFEu)); } else { num_bytes = 2; - send_stream->put(static_cast(126 + 128)); + send_stream->put(static_cast(/*126 + 128*/ 0xFEu)); } for(size_t c = num_bytes - 1; c != static_cast(-1); c--) @@ -216,7 +216,7 @@ namespace SimpleWeb { send_stream->put(static_cast(mask[c])); for(size_t c = 0; c < length; c++) - send_stream->put(message_stream->get() ^ mask[c % 4]); + send_stream->put(static_cast(message_stream->get() ^ mask[c % 4])); auto self = this->shared_from_this(); strand.post([self, send_stream, callback]() { @@ -234,7 +234,7 @@ namespace SimpleWeb { auto send_stream = std::make_shared(); - send_stream->put(status >> 8); + send_stream->put(static_cast(status >> 8)); send_stream->put(status % 256); *send_stream << reason; @@ -333,6 +333,8 @@ namespace SimpleWeb { unsigned short port; std::string path; + std::string protocol = ""; + std::shared_ptr connection; std::mutex connection_mutex; @@ -363,8 +365,8 @@ namespace SimpleWeb { virtual void connect() = 0; - void handshake(const std::shared_ptr &connection) { - connection->read_remote_endpoint_data(); + void handshake(const std::shared_ptr &new_connection) { + new_connection->read_remote_endpoint_data(); auto write_buffer = std::make_shared(); @@ -387,69 +389,71 @@ namespace SimpleWeb { auto nonce_base64 = std::make_shared(Crypto::Base64::encode(nonce)); request << "Sec-WebSocket-Key: " << *nonce_base64 << "\r\n"; request << "Sec-WebSocket-Version: 13\r\n"; + if (protocol != "") + request << "Sec-WebSocket-Protocol: " << protocol << "\r\n"; request << "\r\n"; - connection->message = std::shared_ptr(new Message()); + new_connection->message = std::shared_ptr(new Message()); - connection->set_timeout(config.timeout_request); - asio::async_write(*connection->socket, *write_buffer, [this, connection, write_buffer, nonce_base64](const error_code &ec, size_t /*bytes_transferred*/) { - connection->cancel_timeout(); - auto lock = connection->handler_runner->continue_lock(); + new_connection->set_timeout(config.timeout_request); + asio::async_write(*new_connection->socket, *write_buffer, [this, new_connection, write_buffer, nonce_base64](const error_code &ec, size_t /*bytes_transferred*/) { + new_connection->cancel_timeout(); + auto lock = new_connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { - connection->set_timeout(this->config.timeout_request); - asio::async_read_until(*connection->socket, connection->message->streambuf, "\r\n\r\n", [this, connection, nonce_base64](const error_code &ec, size_t /*bytes_transferred*/) { - connection->cancel_timeout(); - auto lock = connection->handler_runner->continue_lock(); + new_connection->set_timeout(this->config.timeout_request); + asio::async_read_until(*new_connection->socket, new_connection->message->streambuf, "\r\n\r\n", [this, new_connection, nonce_base64](const error_code &ec, size_t /*bytes_transferred*/) { + new_connection->cancel_timeout(); + auto lock = new_connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { - if(!ResponseMessage::parse(*connection->message, connection->http_version, connection->status_code, connection->header) || - connection->status_code != "101 Web Socket Protocol Handshake") { - this->connection_error(connection, make_error_code::make_error_code(errc::protocol_error)); + if(!ResponseMessage::parse(*new_connection->message, new_connection->http_version, new_connection->status_code, new_connection->header) || + new_connection->status_code.substr(0, 3) != "101") { + this->connection_error(new_connection, make_error_code::make_error_code(errc::protocol_error)); return; } - auto header_it = connection->header.find("Sec-WebSocket-Accept"); + auto header_it = new_connection->header.find("Sec-WebSocket-Accept"); static auto ws_magic_string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; - if(header_it != connection->header.end() && + if(header_it != new_connection->header.end() && Crypto::Base64::decode(header_it->second) == Crypto::sha1(*nonce_base64 + ws_magic_string)) { - this->connection_open(connection); - read_message(connection); + this->connection_open(new_connection); + read_message(new_connection); } else - this->connection_error(connection, make_error_code::make_error_code(errc::protocol_error)); + this->connection_error(new_connection, make_error_code::make_error_code(errc::protocol_error)); } else - this->connection_error(connection, ec); + this->connection_error(new_connection, ec); }); } else - this->connection_error(connection, ec); + this->connection_error(new_connection, ec); }); } - void read_message(const std::shared_ptr &connection) { - asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(2), [this, connection](const error_code &ec, size_t bytes_transferred) { - auto lock = connection->handler_runner->continue_lock(); + void read_message(const std::shared_ptr &msg_connection) { + asio::async_read(*msg_connection->socket, msg_connection->message->streambuf, asio::transfer_exactly(2), [this, msg_connection](const error_code &ec, size_t bytes_transferred) { + auto lock = msg_connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { if(bytes_transferred == 0) { // TODO: This might happen on server at least, might also happen here - this->read_message(connection); + this->read_message(msg_connection); return; } std::vector first_bytes; first_bytes.resize(2); - connection->message->read(reinterpret_cast(&first_bytes[0]), 2); + msg_connection->message->read(reinterpret_cast(&first_bytes[0]), 2); - connection->message->fin_rsv_opcode = first_bytes[0]; + msg_connection->message->fin_rsv_opcode = first_bytes[0]; // Close connection if masked message from server (protocol error) if(first_bytes[1] >= 128) { const std::string reason("message from server masked"); - connection->send_close(1002, reason); - this->connection_close(connection, 1002, reason); + msg_connection->send_close(1002, reason); + this->connection_close(msg_connection, 1002, reason); return; } @@ -457,123 +461,123 @@ namespace SimpleWeb { if(length == 126) { // 2 next bytes is the size of content - asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(2), [this, connection](const error_code &ec, size_t /*bytes_transferred*/) { - auto lock = connection->handler_runner->continue_lock(); + asio::async_read(*msg_connection->socket, msg_connection->message->streambuf, asio::transfer_exactly(2), [this, msg_connection](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = msg_connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { std::vector length_bytes; length_bytes.resize(2); - connection->message->read(reinterpret_cast(&length_bytes[0]), 2); + msg_connection->message->read(reinterpret_cast(&length_bytes[0]), 2); size_t length = 0; size_t num_bytes = 2; for(size_t c = 0; c < num_bytes; c++) length += static_cast(length_bytes[c]) << (8 * (num_bytes - 1 - c)); - connection->message->length = length; - this->read_message_content(connection); + msg_connection->message->length = length; + this->read_message_content(msg_connection); } else - this->connection_error(connection, ec); + this->connection_error(msg_connection, ec); }); } else if(length == 127) { // 8 next bytes is the size of content - asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(8), [this, connection](const error_code &ec, size_t /*bytes_transferred*/) { - auto lock = connection->handler_runner->continue_lock(); + asio::async_read(*msg_connection->socket, msg_connection->message->streambuf, asio::transfer_exactly(8), [this, msg_connection](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = msg_connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { std::vector length_bytes; length_bytes.resize(8); - connection->message->read(reinterpret_cast(&length_bytes[0]), 8); + msg_connection->message->read(reinterpret_cast(&length_bytes[0]), 8); size_t length = 0; size_t num_bytes = 8; for(size_t c = 0; c < num_bytes; c++) length += static_cast(length_bytes[c]) << (8 * (num_bytes - 1 - c)); - connection->message->length = length; - this->read_message_content(connection); + msg_connection->message->length = length; + this->read_message_content(msg_connection); } else - this->connection_error(connection, ec); + this->connection_error(msg_connection, ec); }); } else { - connection->message->length = length; - this->read_message_content(connection); + msg_connection->message->length = length; + this->read_message_content(msg_connection); } } else - this->connection_error(connection, ec); + this->connection_error(msg_connection, ec); }); } - void read_message_content(const std::shared_ptr &connection) { - asio::async_read(*connection->socket, connection->message->streambuf, asio::transfer_exactly(connection->message->length), [this, connection](const error_code &ec, size_t /*bytes_transferred*/) { - auto lock = connection->handler_runner->continue_lock(); + void read_message_content(const std::shared_ptr &msg_connection) { + asio::async_read(*msg_connection->socket, msg_connection->message->streambuf, asio::transfer_exactly(msg_connection->message->length), [this, msg_connection](const error_code &ec, size_t /*bytes_transferred*/) { + auto lock = msg_connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { // If connection close - if((connection->message->fin_rsv_opcode & 0x0f) == 8) { + if((msg_connection->message->fin_rsv_opcode & 0x0f) == 8) { int status = 0; - if(connection->message->length >= 2) { - unsigned char byte1 = connection->message->get(); - unsigned char byte2 = connection->message->get(); + if(msg_connection->message->length >= 2) { + unsigned char byte1 = static_cast(msg_connection->message->get()); + unsigned char byte2 = static_cast(msg_connection->message->get()); status = (byte1 << 8) + byte2; } - auto reason = connection->message->string(); - connection->send_close(status, reason); - this->connection_close(connection, status, reason); + auto reason = msg_connection->message->string(); + msg_connection->send_close(status, reason); + this->connection_close(msg_connection, status, reason); return; } // If ping - else if((connection->message->fin_rsv_opcode & 0x0f) == 9) { + else if((msg_connection->message->fin_rsv_opcode & 0x0f) == 9) { // Send pong auto empty_send_stream = std::make_shared(); - connection->send(empty_send_stream, nullptr, connection->message->fin_rsv_opcode + 1); + msg_connection->send(empty_send_stream, nullptr, msg_connection->message->fin_rsv_opcode + 1); } else if(this->on_message) { - connection->cancel_timeout(); - connection->set_timeout(); - this->on_message(connection, connection->message); + msg_connection->cancel_timeout(); + msg_connection->set_timeout(); + this->on_message(msg_connection, msg_connection->message); } // Next message - connection->message = std::shared_ptr(new Message()); - this->read_message(connection); + msg_connection->message = std::shared_ptr(new Message()); + this->read_message(msg_connection); } else - this->connection_error(connection, ec); + this->connection_error(msg_connection, ec); }); } - void connection_open(const std::shared_ptr &connection) const { - connection->cancel_timeout(); - connection->set_timeout(); + void connection_open(const std::shared_ptr &opening_connection) const { + opening_connection->cancel_timeout(); + opening_connection->set_timeout(); if(on_open) - on_open(connection); + on_open(opening_connection); } - void connection_close(const std::shared_ptr &connection, int status, const std::string &reason) const { - connection->cancel_timeout(); - connection->set_timeout(); + void connection_close(const std::shared_ptr &closing_connection, int status, const std::string &reason) const { + closing_connection->cancel_timeout(); + closing_connection->set_timeout(); if(on_close) - on_close(connection, status, reason); + on_close(closing_connection, status, reason); } - void connection_error(const std::shared_ptr &connection, const error_code &ec) const { - connection->cancel_timeout(); - connection->set_timeout(); + void connection_error(const std::shared_ptr &err_connection, const error_code &ec) const { + err_connection->cancel_timeout(); + err_connection->set_timeout(); if(on_error) - on_error(connection, ec); + on_error(err_connection, ec); } }; @@ -587,38 +591,43 @@ namespace SimpleWeb { public: SocketClient(const std::string &server_port_path) noexcept : SocketClientBase::SocketClientBase(server_port_path, 80){}; + void SetProtocol(std::string theProtocol) + { + SocketClientBase::protocol = theProtocol; + } + protected: void connect() override { std::unique_lock lock(connection_mutex); - auto connection = this->connection = std::shared_ptr(new Connection(handler_runner, config.timeout_idle, *io_service)); + auto newConnection = this->connection = std::shared_ptr(new Connection(handler_runner, config.timeout_idle, *io_service)); lock.unlock(); asio::ip::tcp::resolver::query query(host, std::to_string(port)); auto resolver = std::make_shared(*io_service); - connection->set_timeout(config.timeout_request); - resolver->async_resolve(query, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) { - connection->cancel_timeout(); - auto lock = connection->handler_runner->continue_lock(); + newConnection->set_timeout(config.timeout_request); + resolver->async_resolve(query, [this, newConnection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) { + newConnection->cancel_timeout(); + auto lock = newConnection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { - connection->set_timeout(this->config.timeout_request); - asio::async_connect(*connection->socket, it, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator /*it*/) { - connection->cancel_timeout(); - auto lock = connection->handler_runner->continue_lock(); + newConnection->set_timeout(this->config.timeout_request); + asio::async_connect(*newConnection->socket, it, [this, newConnection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator /*it*/) { + newConnection->cancel_timeout(); + auto lock = newConnection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { asio::ip::tcp::no_delay option(true); - connection->socket->set_option(option); + newConnection->socket->set_option(option); - this->handshake(connection); + this->handshake(newConnection); } else - this->connection_error(connection, ec); + this->connection_error(newConnection, ec); }); } else - this->connection_error(connection, ec); + this->connection_error(newConnection, ec); }); } }; diff --git a/client_wss.hpp b/client_wss.hpp index bf27841e..9dbca043 100644 --- a/client_wss.hpp +++ b/client_wss.hpp @@ -43,47 +43,47 @@ namespace SimpleWeb { void connect() override { std::unique_lock connection_lock(connection_mutex); - auto connection = this->connection = std::shared_ptr(new Connection(handler_runner, config.timeout_idle, *io_service, context)); + auto new_connection = this->connection = std::shared_ptr(new Connection(handler_runner, config.timeout_idle, *io_service, context)); connection_lock.unlock(); asio::ip::tcp::resolver::query query(host, std::to_string(port)); auto resolver = std::make_shared(*io_service); - connection->set_timeout(config.timeout_request); - resolver->async_resolve(query, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) { - connection->cancel_timeout(); - auto lock = connection->handler_runner->continue_lock(); + new_connection->set_timeout(config.timeout_request); + resolver->async_resolve(query, [this, new_connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator it) { + new_connection->cancel_timeout(); + auto lock = new_connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { - connection->set_timeout(this->config.timeout_request); - asio::async_connect(connection->socket->lowest_layer(), it, [this, connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator /*it*/) { - connection->cancel_timeout(); - auto lock = connection->handler_runner->continue_lock(); + new_connection->set_timeout(this->config.timeout_request); + asio::async_connect(new_connection->socket->lowest_layer(), it, [this, new_connection, resolver](const error_code &ec, asio::ip::tcp::resolver::iterator /*it*/) { + new_connection->cancel_timeout(); + auto lock = new_connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { asio::ip::tcp::no_delay option(true); - connection->socket->lowest_layer().set_option(option); + new_connection->socket->lowest_layer().set_option(option); - SSL_set_tlsext_host_name(connection->socket->native_handle(), this->host.c_str()); + SSL_set_tlsext_host_name(new_connection->socket->native_handle(), this->host.c_str()); - connection->set_timeout(this->config.timeout_request); - connection->socket->async_handshake(asio::ssl::stream_base::client, [this, connection](const error_code &ec) { - connection->cancel_timeout(); - auto lock = connection->handler_runner->continue_lock(); + new_connection->set_timeout(this->config.timeout_request); + new_connection->socket->async_handshake(asio::ssl::stream_base::client, [this, new_connection](const error_code &ec) { + new_connection->cancel_timeout(); + auto lock = new_connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) - handshake(connection); + handshake(new_connection); else - this->connection_error(connection, ec); + this->connection_error(new_connection, ec); }); } else - this->connection_error(connection, ec); + this->connection_error(new_connection, ec); }); } else - this->connection_error(connection, ec); + this->connection_error(new_connection, ec); }); } }; diff --git a/crypto.hpp b/crypto.hpp index 697d8f3e..70e18870 100644 --- a/crypto.hpp +++ b/crypto.hpp @@ -212,8 +212,8 @@ namespace SimpleWeb { static std::string pbkdf2(const std::string &password, const std::string &salt, int iterations, int key_size) noexcept { std::string key; key.resize(static_cast(key_size)); - PKCS5_PBKDF2_HMAC_SHA1(password.c_str(), password.size(), - reinterpret_cast(salt.c_str()), salt.size(), iterations, + PKCS5_PBKDF2_HMAC_SHA1(password.c_str(), static_cast(password.size()), + reinterpret_cast(salt.c_str()), static_cast(salt.size()), iterations, key_size, reinterpret_cast(&key[0])); return key; } diff --git a/sample/.gitignore b/sample/.gitignore new file mode 100644 index 00000000..567609b1 --- /dev/null +++ b/sample/.gitignore @@ -0,0 +1 @@ +build/ diff --git a/sample/CMakeLists.txt b/sample/CMakeLists.txt new file mode 100644 index 00000000..d8498a22 --- /dev/null +++ b/sample/CMakeLists.txt @@ -0,0 +1,92 @@ +cmake_minimum_required (VERSION 2.8) +project (Simple-WebSocket-Sample) +get_filename_component(SAMPLE_ROOT ${CMAKE_CURRENT_LIST_FILE} DIRECTORY) + +if(MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4 /WX") +elseif(CMAKE_COMPILER_IS_GNUCC OR CMAKE_COMPILER_IS_GNUCXX) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Wextra -Wsign-conversion") +endif() + + + +set(BOOST_COMPONENTS system coroutine context thread) +# Late 2017 TODO: remove the following checks and always use std::regex +if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 4.9) + set(BOOST_COMPONENTS ${BOOST_COMPONENTS} regex) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_BOOST_REGEX") + endif() +endif() + +if (WIN32) + + # prereq: populate $env:BoostRoot and $env:BoostVer to tell cmake where boost is + set(BOOST_ROOT "$ENV{BoostRoot}") + set(BOOST_LIBRARYDIR "$ENV{BoostRoot}/lib") + set(BOOST_INCLUDEDIR "$ENV{BoostRoot}/include/boost-$ENV{BoostVer}") + add_definitions(-DBOOST_ALL_DYN_LINK) + set(BOOST_COMPONENTS ${BOOST_COMPONENTS} regex) + + file(GLOB boostDlls "$ENV{BoostRoot}/lib/*.dll") + foreach(boostDll ${boostDlls}) + file(TO_CMAKE_PATH "${boostDll}" correctedBoostDll) + get_filename_component(boostFileName ${correctedBoostDll} NAME) + if(boostFileName MATCHES ".*boost_(chrono|system|regex|system|thread|coroutine|context).*\\.dll$") + configure_file( "${correctedBoostDll}" "${CMAKE_BINARY_DIR}" COPYONLY) + endif() + endforeach() + + # prereq: populate $env:OpenSSLRoot to tell cmake where OpenSSL is + set(OPENSSL_ROOT_DIR "$ENV{OpenSSLRoot}") + set(OPENSSL_INCLUDE_DIR "$ENV{OpenSSLRoot}/include") + +endif() + +find_package(Boost 1.54.0 COMPONENTS ${BOOST_COMPONENTS} REQUIRED) +include_directories(SYSTEM ${Boost_INCLUDE_DIR}) +include_directories(SYSTEM ${Boost_INCLUDE_DIR}) + +if(APPLE) + set(OPENSSL_ROOT_DIR "/usr/local/opt/openssl") +endif() + +#TODO: add requirement for version 1.0.1g (can it be done in one line?) +find_package(OpenSSL REQUIRED) +include_directories(SYSTEM ${OPENSSL_INCLUDE_DIR}) + +find_package(Threads REQUIRED) + +include_directories(.) + +include_directories("${SAMPLE_ROOT}/..") + + +add_executable(sample_client sample_client.cpp) +set_target_properties(sample_client PROPERTIES LINKER_LANGUAGE CXX) +target_link_libraries(sample_client ${Boost_LIBRARIES}) +target_link_libraries(sample_client ${OPENSSL_CRYPTO_LIBRARY}) +target_link_libraries(sample_client ${CMAKE_THREAD_LIBS_INIT}) + +add_executable(sample_server sample_server.cpp) +set_target_properties(sample_server PROPERTIES LINKER_LANGUAGE CXX) +target_link_libraries(sample_server ${Boost_LIBRARIES}) +target_link_libraries(sample_server ${OPENSSL_LIBRARIES}) +target_link_libraries(sample_server ${CMAKE_THREAD_LIBS_INIT}) + +add_executable(ws_examples ../ws_examples.cpp) +target_link_libraries(ws_examples ${Boost_LIBRARIES}) +target_link_libraries(ws_examples ${OPENSSL_CRYPTO_LIBRARY}) +target_link_libraries(ws_examples ${CMAKE_THREAD_LIBS_INIT}) + +add_executable(wss_examples ../wss_examples.cpp) +target_link_libraries(wss_examples ${Boost_LIBRARIES}) +target_link_libraries(wss_examples ${OPENSSL_LIBRARIES}) +target_link_libraries(wss_examples ${CMAKE_THREAD_LIBS_INIT}) + +if(MSYS) + target_link_libraries(sample_client ws2_32 wsock32) + target_link_libraries(sample_server ws2_32 wsock32) + target_link_libraries(ws_examples ws2_32 wsock32) + target_link_libraries(wss_examples ws2_32 wsock32) +endif() diff --git a/sample/README.md b/sample/README.md new file mode 100644 index 00000000..9c409d2b --- /dev/null +++ b/sample/README.md @@ -0,0 +1,63 @@ +Simple-WebSocket-Sample +======================= + +This project contains two executables which allow the user to control the connections and message flow for Simple-WebSocket-Server. It might be useful for testing interoperability with other websockets implementations. + +### sample_server controls + + s : Start server + t : sTop server + m : send Message to all clients + q : Quit + +### sample_client controls + + s : Set up connection + l : cLose connection + c : stop Client + m : send Message + q : Quit + +## Usage + +Run one server and as many clients as you like. Type the letter for the desired action and hit enter. A typical session might look like this: + +| sample_client | sample_server | Effect | +| :-----------: | :------------:|:-------| +| | **S**tart | The server starts listening for connections | +| **S**tart | | The client connects to the server | +| **M**essage | | The client sends a message to the server (the server will respond with an echo) | +| | **M**essage | The server sends a message to all connected clients (they will not respond) | +| c**L**ose | | The client disconnects with a message | +| | s**T**op | The server stops listening | +| s**T**op | | The client cleans itself up | +| **Q**uit | | The client quits | +| | **Q**uit | The server quits | +## Building + +The sample uses [Simple-WebSocket-Server](../README.md) (duh). You'll need its dependencies installed. + + +#### Windows + +Populate the following environmentla variables: + +| variable | value | +|:--|:--| +| BoostRoot | C:\path\to\Boost | +| BoostVer | 1_62 | +| OpenSSLRoot | C:\path\to\OpenSSL | + +Specify the correct generator in your call to cmake, this example uses 2017 with a 64 bit build: + + mkdir build + cd build + cmake .. -G "Visual Studio 15 2017 Win64" + +Open in your IDE of choice `build/Simple_WebSocket_Sample.sln` and build it. + +#### Linux + + mkdir build + cd build + cmake .. diff --git a/sample/sample_client.cpp b/sample/sample_client.cpp new file mode 100644 index 00000000..2a9b78e2 --- /dev/null +++ b/sample/sample_client.cpp @@ -0,0 +1,106 @@ +#include +#include +#include +#include "boost/thread.hpp" +#include "client_ws.hpp" +typedef SimpleWeb::SocketClient WsClient; + +using namespace std; + +int main(int, char**) +{ + shared_ptr client; + shared_ptr _connection; + boost::thread client_thread; + + cout << "s : Set up connection" << endl + << "l : cLose connection" << endl + << "c : stop Client" << endl + << "m : send Message" << endl + << "q : Quit" << endl; + + string line; + while (line != "q") + { + getline(cin, line); + + if (line == "s") + { + client = std::make_shared("localhost:8081/some/http/resource"); + + + client->SetProtocol("some-protocol"); // optional + + client->on_open = [&](shared_ptr connection) + { + _connection = connection; + cout << "Client Started & Connection " << (size_t)connection.get() << " Opened" << endl << endl; + }; + + client->on_close = [&](shared_ptr connection, int code, const string& reason) + { + _connection = nullptr; + cout << "Closed Connection " << (size_t)connection.get() << "(" << code << ")" << endl << " Reason: " << reason << endl << endl; + }; + + client->on_error = [](shared_ptr connection, const boost::system::error_code& code) + { + cout << "Error in Connection " << (size_t)connection.get() << "(" << code << ")" << endl << " Code: " << code.message() << endl << endl; + }; + + client->on_message = [](shared_ptr connection, shared_ptr message) + { + cout << "Server Message on Connection " << (size_t)connection.get() << endl << " Message: " << message->string() << endl << endl; + }; + + client_thread = boost::thread([&client]() + { + client->start(); + }); + cout << "Connection started" << endl << endl; + } + else if (line == "c") + { + if (client != nullptr) + { + client->stop(); + client = nullptr; + cout << "Stopped Client" << endl << endl; + } + else + { + cout << "Client Already Stopped" << endl << endl; + } + + } + else if (line == "l") + { + if (_connection != nullptr) + { + _connection->send_close(10, "Word to your moms, I came to drop bombs, I got more rhymes than the bible's got psalms.", [](const boost::system::error_code code) + { + cout << "Error on send_close Code: " << code + << " Message: " << code.message() << endl; + }); + cout << "Closed connection " << (size_t)_connection.get() << " with message" << endl << endl; + } + else + { + cout << "Connection already closed" << endl << endl; + } + } + + else if (line == "m") + { + auto msg = std::make_shared(); + *msg << "It's tricky to rock a rhyme to rock a rhyme that's right on time it's tricky!"; + _connection->send(msg); + cout << "Message sent" << endl << endl; + } + } + if (client != nullptr) + { + client->stop(); + client_thread.join(); + } +} diff --git a/sample/sample_server.cpp b/sample/sample_server.cpp new file mode 100644 index 00000000..0203bd40 --- /dev/null +++ b/sample/sample_server.cpp @@ -0,0 +1,109 @@ +#include +#include +#include +#include +#include +#include "server_ws.hpp" +typedef SimpleWeb::SocketServer WsServer; + +using namespace std; + +int main(int, char**) +{ + shared_ptr server; + thread server_thread; + + cout << "s : Start server" << endl + << "t : sTop server" << endl + << "m : send Message to all clients" << endl + << "q : Quit" << endl; + + string line; + while (line != "q") + { + getline(cin, line); + + if (line == "s") + { + server = make_shared(); + server->config.port = 8081; + + auto& tunnel = server->endpoint["/some/http/resource"]; + tunnel.on_message = [&server](shared_ptr connection, shared_ptr message) + { + auto message_str = message->string(); + cout << "Client Message from connection " << (size_t)connection.get() << endl + << " Message: " << message_str << endl << endl; + + auto sendThisBack = make_shared(); + *sendThisBack << "[echo] " << message_str; + connection->send(sendThisBack, [](const boost::system::error_code code) + { + if (code) + { + cout << " Error while responding: " << code << ", error message: " << code.message() << endl << endl; + } + }); + }; + + tunnel.on_open = [](shared_ptr connection) + { + cout << "Opened Connection: " << (size_t)connection.get() << " from: " << connection->remote_endpoint_address << ":" << connection->remote_endpoint_port << endl << endl; + }; + + tunnel.on_close = [](shared_ptr connection, int status, const string& reason) + { + cout << "Closed Connection: " << (size_t)connection.get() << " from: " << connection->remote_endpoint_address << ":" << connection->remote_endpoint_port << endl; + + //See RFC 6455 7.4.1. for status codes + cout << " Code: " << status << " Reason: " << reason << endl << endl; + }; + + tunnel.on_error = [](shared_ptr connection, const boost::system::error_code& code) + { + //See http://www.boost.org/doc/libs/1_55_0/doc/html/boost_asio/reference.html, Error Codes for error code meanings + cout << "Error in connection " << (size_t)connection.get() << endl; + cout << " Code: " << code << ", Message: " << code.message() << endl << endl; + }; + + server_thread = thread([&server]() + { + server->start(); + }); + cout << "Server started" << endl; + } + else if (line == "t") + { + server->stop(); + server_thread.join(); + server = nullptr; + cout << "Server stopped" << endl; + } + else if (line == "m") + { + int i = 0; + for (auto connection : server->get_connections()) + { + i++; + auto msg = make_shared(); + *msg << "This is for the kids whippin' up some home-cook, spittin' 86 bars f'n no hook.."; + cout << "Sending Message..." << endl; + connection->send(msg, [&](const boost::system::error_code code) + { + if (code) + { + cout << "Error while sending to connnection: " << i << " Code: " << code + << " Message: " << code.message() << endl; + } + }); + cout << " ...sent" << endl; + } + } + } + + if (server != nullptr) + { + server->stop(); + server_thread.join(); + } +} diff --git a/server_ws.hpp b/server_ws.hpp index a3ed60d4..2d52fcfd 100644 --- a/server_ws.hpp +++ b/server_ws.hpp @@ -149,10 +149,25 @@ namespace SimpleWeb { static auto ws_magic_string = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; auto sha1 = Crypto::sha1(header_it->second + ws_magic_string); + auto proto_it = header.find("Sec-WebSocket-Key"); + + bool protocolSpecified = false; + std::string protocol; + proto_it = header.find("Sec-WebSocket-Protocol"); + if (proto_it != header.end()) + { + protocolSpecified = true; + protocol = proto_it->second; + } + handshake << "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"; handshake << "Upgrade: websocket\r\n"; handshake << "Connection: Upgrade\r\n"; handshake << "Sec-WebSocket-Accept: " << Crypto::Base64::encode(sha1) << "\r\n"; + if (protocolSpecified) + { + handshake << "Sec-WebSocket-Protocol: " << protocol << "\r\n"; + } handshake << "\r\n"; return true; @@ -264,8 +279,8 @@ namespace SimpleWeb { auto send_stream = std::make_shared(); - send_stream->put(status >> 8); - send_stream->put(status % 256); + send_stream->put(static_cast(status >> 8)); + send_stream->put(static_cast(status % 256)); *send_stream << reason; @@ -371,17 +386,17 @@ namespace SimpleWeb { if(io_service->stopped()) io_service->reset(); - asio::ip::tcp::endpoint endpoint; + asio::ip::tcp::endpoint local_endpoint; if(config.address.size() > 0) - endpoint = asio::ip::tcp::endpoint(asio::ip::address::from_string(config.address), config.port); + local_endpoint = asio::ip::tcp::endpoint(asio::ip::address::from_string(config.address), config.port); else - endpoint = asio::ip::tcp::endpoint(asio::ip::tcp::v4(), config.port); + local_endpoint = asio::ip::tcp::endpoint(asio::ip::tcp::v4(), config.port); if(!acceptor) acceptor = std::unique_ptr(new asio::ip::tcp::acceptor(*io_service)); - acceptor->open(endpoint.protocol()); + acceptor->open(local_endpoint.protocol()); acceptor->set_option(asio::socket_base::reuse_address(config.reuse_address)); - acceptor->bind(endpoint); + acceptor->bind(local_endpoint); acceptor->listen(); accept(); @@ -516,14 +531,14 @@ namespace SimpleWeb { } } - void read_message(const std::shared_ptr &connection, Endpoint &endpoint) const { - asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(2), [this, connection, &endpoint](const error_code &ec, size_t bytes_transferred) { + void read_message(const std::shared_ptr &connection, Endpoint &open_endpoint) const { + asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(2), [this, connection, &open_endpoint](const error_code &ec, size_t bytes_transferred) { auto lock = connection->handler_runner->continue_lock(); if(!lock) return; if(!ec) { if(bytes_transferred == 0) { // TODO: why does this happen sometimes? - read_message(connection, endpoint); + read_message(connection, open_endpoint); return; } std::istream stream(&connection->read_buffer); @@ -538,7 +553,7 @@ namespace SimpleWeb { if(first_bytes[1] < 128) { const std::string reason("message from client not masked"); connection->send_close(1002, reason); - connection_close(connection, endpoint, 1002, reason); + connection_close(connection, open_endpoint, 1002, reason); return; } @@ -546,7 +561,7 @@ namespace SimpleWeb { if(length == 126) { // 2 next bytes is the size of content - asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(2), [this, connection, &endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) { + asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(2), [this, connection, &open_endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) { auto lock = connection->handler_runner->continue_lock(); if(!lock) return; @@ -562,15 +577,15 @@ namespace SimpleWeb { for(size_t c = 0; c < num_bytes; c++) length += static_cast(length_bytes[c]) << (8 * (num_bytes - 1 - c)); - read_message_content(connection, length, endpoint, fin_rsv_opcode); + read_message_content(connection, length, open_endpoint, fin_rsv_opcode); } else - connection_error(connection, endpoint, ec); + connection_error(connection, open_endpoint, ec); }); } else if(length == 127) { // 8 next bytes is the size of content - asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(8), [this, connection, &endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) { + asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(8), [this, connection, &open_endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) { auto lock = connection->handler_runner->continue_lock(); if(!lock) return; @@ -586,22 +601,22 @@ namespace SimpleWeb { for(size_t c = 0; c < num_bytes; c++) length += static_cast(length_bytes[c]) << (8 * (num_bytes - 1 - c)); - read_message_content(connection, length, endpoint, fin_rsv_opcode); + read_message_content(connection, length, open_endpoint, fin_rsv_opcode); } else - connection_error(connection, endpoint, ec); + connection_error(connection, open_endpoint, ec); }); } else - read_message_content(connection, length, endpoint, fin_rsv_opcode); + read_message_content(connection, length, open_endpoint, fin_rsv_opcode); } else - connection_error(connection, endpoint, ec); + connection_error(connection, open_endpoint, ec); }); } - void read_message_content(const std::shared_ptr &connection, size_t length, Endpoint &endpoint, unsigned char fin_rsv_opcode) const { - asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(4 + length), [this, connection, length, &endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) { + void read_message_content(const std::shared_ptr &connection, size_t length, Endpoint &msg_endpoint, unsigned char fin_rsv_opcode) const { + asio::async_read(*connection->socket, connection->read_buffer, asio::transfer_exactly(4 + length), [this, connection, length, &msg_endpoint, fin_rsv_opcode](const error_code &ec, size_t /*bytes_transferred*/) { auto lock = connection->handler_runner->continue_lock(); if(!lock) return; @@ -619,21 +634,21 @@ namespace SimpleWeb { std::ostream message_data_out_stream(&message->streambuf); for(size_t c = 0; c < length; c++) { - message_data_out_stream.put(raw_message_data.get() ^ mask[c % 4]); + message_data_out_stream.put(static_cast(raw_message_data.get() ^ mask[c % 4])); } // If connection close if((fin_rsv_opcode & 0x0f) == 8) { int status = 0; if(length >= 2) { - unsigned char byte1 = message->get(); - unsigned char byte2 = message->get(); + unsigned char byte1 = static_cast(message->get()); + unsigned char byte2 = static_cast(message->get()); status = (byte1 << 8) + byte2; } auto reason = message->string(); connection->send_close(status, reason); - connection_close(connection, endpoint, status, reason); + connection_close(connection, msg_endpoint, status, reason); return; } else { @@ -643,58 +658,58 @@ namespace SimpleWeb { auto empty_send_stream = std::make_shared(); connection->send(empty_send_stream, nullptr, fin_rsv_opcode + 1); } - else if(endpoint.on_message) { + else if(msg_endpoint.on_message) { connection->cancel_timeout(); connection->set_timeout(); - endpoint.on_message(connection, message); + msg_endpoint.on_message(connection, message); } // Next message - read_message(connection, endpoint); + read_message(connection, msg_endpoint); } } else - connection_error(connection, endpoint, ec); + connection_error(connection, msg_endpoint, ec); }); } - void connection_open(const std::shared_ptr &connection, Endpoint &endpoint) const { + void connection_open(const std::shared_ptr &connection, Endpoint &open_endpoint) const { connection->cancel_timeout(); connection->set_timeout(); { - std::unique_lock lock(endpoint.connections_mutex); - endpoint.connections.insert(connection); + std::unique_lock lock(open_endpoint.connections_mutex); + open_endpoint.connections.insert(connection); } - if(endpoint.on_open) - endpoint.on_open(connection); + if(open_endpoint.on_open) + open_endpoint.on_open(connection); } - void connection_close(const std::shared_ptr &connection, Endpoint &endpoint, int status, const std::string &reason) const { + void connection_close(const std::shared_ptr &connection, Endpoint &close_endpoint, int status, const std::string &reason) const { connection->cancel_timeout(); connection->set_timeout(); { - std::unique_lock lock(endpoint.connections_mutex); - endpoint.connections.erase(connection); + std::unique_lock lock(close_endpoint.connections_mutex); + close_endpoint.connections.erase(connection); } - if(endpoint.on_close) - endpoint.on_close(connection, status, reason); + if(close_endpoint.on_close) + close_endpoint.on_close(connection, status, reason); } - void connection_error(const std::shared_ptr &connection, Endpoint &endpoint, const error_code &ec) const { + void connection_error(const std::shared_ptr &connection, Endpoint &err_endpoint, const error_code &ec) const { connection->cancel_timeout(); connection->set_timeout(); { - std::unique_lock lock(endpoint.connections_mutex); - endpoint.connections.erase(connection); + std::unique_lock lock(err_endpoint.connections_mutex); + err_endpoint.connections.erase(connection); } - if(endpoint.on_error) - endpoint.on_error(connection, ec); + if(err_endpoint.on_error) + err_endpoint.on_error(connection, ec); } }; diff --git a/server_wss.hpp b/server_wss.hpp index 6b7a9e79..e15e65be 100644 --- a/server_wss.hpp +++ b/server_wss.hpp @@ -40,7 +40,7 @@ namespace SimpleWeb { session_id_context = std::to_string(config.port) + ':'; session_id_context.append(config.address.rbegin(), config.address.rend()); SSL_CTX_set_session_id_context(context.native_handle(), reinterpret_cast(session_id_context.data()), - std::min(session_id_context.size(), SSL_MAX_SSL_SESSION_ID_LENGTH)); + static_cast(std::min(session_id_context.size(), SSL_MAX_SSL_SESSION_ID_LENGTH))); } SocketServerBase::start(); }