diff --git a/wish/cpp/CMakeLists.txt b/wish/cpp/CMakeLists.txt index 7294377..066c9ca 100644 --- a/wish/cpp/CMakeLists.txt +++ b/wish/cpp/CMakeLists.txt @@ -28,7 +28,7 @@ set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) FetchContent_Declare( abseil GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git - GIT_TAG 20260107.1 + GIT_TAG 255c84dadd029fd8ad25c5efb5933e47beaa00c7 # 20260107.1 ) FetchContent_MakeAvailable(abseil) @@ -42,7 +42,7 @@ FetchContent_Declare( wslay GIT_REPOSITORY https://github.com/tatsuhiro-t/wslay.git GIT_TAG 0e7d106ff89ad6638090fd811a9b2e4c5dda8d40 - PATCH_COMMAND git apply --reverse --check ${CMAKE_CURRENT_SOURCE_DIR}/patches/wslay-disable-mask.patch || git apply ${CMAKE_CURRENT_SOURCE_DIR}/patches/wslay-disable-mask.patch + PATCH_COMMAND git apply --reverse --check ${CMAKE_CURRENT_SOURCE_DIR}/patches/wslay.patch || git apply ${CMAKE_CURRENT_SOURCE_DIR}/patches/wslay.patch ) @@ -61,7 +61,7 @@ set(EVENT__DISABLE_SAMPLES ON CACHE BOOL "" FORCE) FetchContent_Declare( libevent GIT_REPOSITORY https://github.com/libevent/libevent.git - GIT_TAG master + GIT_TAG 780acfe8b2495949f0dc3ebd6f18eea2dec605a6 ) @@ -72,42 +72,50 @@ FetchContent_Declare( ) +FetchContent_Declare( + picohttpparser + GIT_REPOSITORY https://github.com/h2o/picohttpparser.git + GIT_TAG 539bb9f510f5aefc8efda53aaf822a21ecae2832 # v1.2 +) + + # These variables are unprefixed, so MakeAvailable is called immediately and # they are unset afterwards to avoid leaking into other dependencies. +set(ENABLE_FAILMALLOC OFF CACHE BOOL "" FORCE) + +# Since BUILD_SHARED_LIBS=OFF, we need to explicitly enable the static target (nghttp2_static) +set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) +set(BUILD_STATIC_LIBS ON CACHE BOOL "" FORCE) + # nghttp2: build only the core library (no TLS apps, no examples). set(ENABLE_LIB_ONLY ON CACHE BOOL "" FORCE) -set(ENABLE_APP OFF CACHE BOOL "" FORCE) -set(ENABLE_HPACK_TOOLS OFF CACHE BOOL "" FORCE) + set(ENABLE_DOC OFF CACHE BOOL "" FORCE) -set(ENABLE_PYTHON_BINDINGS OFF CACHE BOOL "" FORCE) -set(ENABLE_EXAMPLES OFF CACHE BOOL "" FORCE) -set(ENABLE_FAILMALLOC OFF CACHE BOOL "" FORCE) -set(WITH_JEMALLOC OFF CACHE BOOL "" FORCE) -set(WITH_LIBEVENT_OPENSSL OFF CACHE BOOL "" FORCE) -set(WITH_OPENSSL OFF CACHE BOOL "" FORCE) + set(WITH_LIBXML2 OFF CACHE BOOL "" FORCE) -set(WITH_SYSTEMD OFF CACHE BOOL "" FORCE) +set(WITH_JEMALLOC OFF CACHE BOOL "" FORCE) set(WITH_MRUBY OFF CACHE BOOL "" FORCE) set(WITH_NEVERBLEED OFF CACHE BOOL "" FORCE) -# Since BUILD_SHARED_LIBS=OFF, we need to explicitly enable the static target (nghttp2_static) -set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) -set(BUILD_STATIC_LIBS ON CACHE BOOL "" FORCE) - FetchContent_Declare( nghttp2 GIT_REPOSITORY https://github.com/nghttp2/nghttp2.git - GIT_TAG v1.65.0 + GIT_TAG 68cb6900fde14c77f0cd7add0e094a862960eb99 # v1.69.0 ) FetchContent_MakeAvailable(nghttp2) # Reset the cached variables to avoid affecting other dependencies. They are not prefixed. foreach(_var IN ITEMS - ENABLE_LIB_ONLY ENABLE_APP ENABLE_HPACK_TOOLS ENABLE_DOC - ENABLE_PYTHON_BINDINGS ENABLE_EXAMPLES ENABLE_FAILMALLOC - WITH_JEMALLOC WITH_LIBEVENT_OPENSSL WITH_OPENSSL WITH_LIBXML2 - WITH_SYSTEMD WITH_MRUBY WITH_NEVERBLEED BUILD_STATIC_LIBS BUILD_SHARED_LIBS) + ENABLE_LIB_ONLY + ENABLE_DOC + ENABLE_FAILMALLOC + BUILD_SHARED_LIBS + BUILD_STATIC_LIBS + WITH_LIBXML2 + WITH_JEMALLOC + WITH_MRUBY + WITH_NEVERBLEED) unset(${_var} CACHE) endforeach() @@ -121,7 +129,7 @@ if(WISH_BUILD_BENCHMARKS) FetchContent_Declare( googlebenchmark GIT_REPOSITORY https://github.com/google/benchmark.git - GIT_TAG main + GIT_TAG 192ef10025eb2c4cdd392bc502f0c852196baa48 # v1.9.5 ) FetchContent_MakeAvailable(googlebenchmark) endif() @@ -130,14 +138,13 @@ if(WISH_BUILD_TESTS) FetchContent_Declare( googletest GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG v1.17.0 + GIT_TAG 52eb8108c5bdec04579160ae17225d66034bd723 # v1.17.0 ) FetchContent_MakeAvailable(googletest) endif() -FetchContent_MakeAvailable(wslay libevent boringssl) - +FetchContent_MakeAvailable(wslay libevent boringssl picohttpparser) include_directories(${wslay_SOURCE_DIR}/lib/includes) include_directories(${libevent_SOURCE_DIR}/include) @@ -145,18 +152,25 @@ include_directories(${libevent_BINARY_DIR}/include) include_directories(${boringssl_SOURCE_DIR}/include) include_directories(${nghttp2_SOURCE_DIR}/lib/includes) include_directories(${nghttp2_BINARY_DIR}/lib/includes) +include_directories(${picohttpparser_SOURCE_DIR}) include_directories(src) +add_library(http1_handshake STATIC + src/handshake.h + src/handshake.cc + ${picohttpparser_SOURCE_DIR}/picohttpparser.c + ${picohttpparser_SOURCE_DIR}/picohttpparser.h +) +target_link_libraries(http1_handshake event absl::strings absl::log) + # We will compile libevent's bufferevent_openssl.c inside our library # since we disabled it in libevent's own build. # Build as a static library: libevent is also static and was not compiled with # -fPIC, so it cannot be linked into a shared object. -add_library(buffer_event_web_stream STATIC +add_library(web_stream STATIC src/web_stream.h src/buffer_event_web_stream.cc src/buffer_event_web_stream.h - src/handshake.h - src/handshake.cc ${libevent_SOURCE_DIR}/bufferevent_openssl.c ${libevent_SOURCE_DIR}/bufferevent_ssl.c src/plain_server.cc @@ -181,8 +195,8 @@ add_library(buffer_event_web_stream STATIC src/h2_tls_client.h ) # BoringSSL targets: ssl and crypto. nghttp2 for HTTP/2 framing. -target_link_libraries(buffer_event_web_stream wslay event ssl crypto nghttp2::nghttp2 absl::strings absl::log) -target_compile_definitions(buffer_event_web_stream PUBLIC EVENT__HAVE_OPENSSL=1) +target_link_libraries(web_stream http1_handshake wslay event ssl crypto nghttp2_static absl::strings absl::log) +target_compile_definitions(web_stream PUBLIC EVENT__HAVE_OPENSSL=1) if(WISH_BUILD_EXAMPLES) @@ -197,14 +211,14 @@ if(WISH_BUILD_TESTS) enable_testing() add_executable(buffer_event_web_stream_test src/buffer_event_web_stream_test.cc) - target_link_libraries(buffer_event_web_stream_test buffer_event_web_stream gtest_main gtest event wslay) + target_link_libraries(buffer_event_web_stream_test web_stream gtest_main gtest) add_test(NAME buffer_event_web_stream_test COMMAND buffer_event_web_stream_test) add_executable(nghttp2_web_stream_test src/nghttp2_web_stream_test.cc) - target_link_libraries(nghttp2_web_stream_test buffer_event_web_stream gtest_main gtest nghttp2::nghttp2 wslay) + target_link_libraries(nghttp2_web_stream_test web_stream gtest_main gtest) add_test(NAME nghttp2_web_stream_test COMMAND nghttp2_web_stream_test) add_executable(handshake_test src/handshake_test.cc) - target_link_libraries(handshake_test buffer_event_web_stream gtest_main gtest event wslay) + target_link_libraries(handshake_test http1_handshake gtest_main gtest event) add_test(NAME handshake_test COMMAND handshake_test) endif() diff --git a/wish/cpp/benchmark/CMakeLists.txt b/wish/cpp/benchmark/CMakeLists.txt index b9addda..a5dfea4 100644 --- a/wish/cpp/benchmark/CMakeLists.txt +++ b/wish/cpp/benchmark/CMakeLists.txt @@ -2,7 +2,7 @@ add_executable(plain_benchmark_client plain_client.cc ) target_link_libraries(plain_benchmark_client - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log @@ -16,7 +16,7 @@ add_executable(high_qps_benchmark_client high_qps_client.cc ) target_link_libraries(high_qps_benchmark_client - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log @@ -29,7 +29,7 @@ add_executable(tls_benchmark_client tls_client.cc ) target_link_libraries(tls_benchmark_client - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log diff --git a/wish/cpp/examples/CMakeLists.txt b/wish/cpp/examples/CMakeLists.txt index 2eebd21..9e9aeec 100644 --- a/wish/cpp/examples/CMakeLists.txt +++ b/wish/cpp/examples/CMakeLists.txt @@ -2,7 +2,7 @@ add_executable(plain_echo_server plain_echo_server.cc ) target_link_libraries(plain_echo_server - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log @@ -15,7 +15,7 @@ add_executable(plain_hello_client plain_hello_client.cc ) target_link_libraries(plain_hello_client - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log @@ -27,7 +27,7 @@ add_executable(tls_echo_server tls_echo_server.cc ) target_link_libraries(tls_echo_server - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log @@ -39,7 +39,7 @@ add_executable(tls_hello_client tls_hello_client.cc ) target_link_libraries(tls_hello_client - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log @@ -53,7 +53,7 @@ add_executable(h2_echo_server h2_echo_server.cc ) target_link_libraries(h2_echo_server - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log @@ -65,7 +65,7 @@ add_executable(h2_hello_client h2_hello_client.cc ) target_link_libraries(h2_hello_client - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log @@ -77,7 +77,7 @@ add_executable(h2_tls_echo_server h2_tls_echo_server.cc ) target_link_libraries(h2_tls_echo_server - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log @@ -89,7 +89,7 @@ add_executable(h2_tls_hello_client h2_tls_hello_client.cc ) target_link_libraries(h2_tls_hello_client - buffer_event_web_stream + web_stream absl::flags absl::flags_parse absl::log diff --git a/wish/cpp/patches/wslay-disable-mask.patch b/wish/cpp/patches/wslay.patch similarity index 100% rename from wish/cpp/patches/wslay-disable-mask.patch rename to wish/cpp/patches/wslay.patch diff --git a/wish/cpp/src/buffer_event_web_stream.cc b/wish/cpp/src/buffer_event_web_stream.cc index 79aa641..6549614 100644 --- a/wish/cpp/src/buffer_event_web_stream.cc +++ b/wish/cpp/src/buffer_event_web_stream.cc @@ -59,7 +59,8 @@ void BufferEventWebStream::Start() { } // If there is already data in the input buffer, process it immediately. - if (evbuffer_get_length(bufferevent_get_input(bev_)) > 0) { + size_t input_len = evbuffer_get_length(bufferevent_get_input(bev_)); + if (input_len > 0) { ReadCallback(bev_, this); } } @@ -186,7 +187,8 @@ void BufferEventWebStream::DrainCallback(bufferevent* bev, BufferEventWebStream* stream = static_cast(ctx); // Delete the stream only once all queued outbound data has been sent. - if (evbuffer_get_length(bufferevent_get_output(bev)) == 0) { + size_t output_len = evbuffer_get_length(bufferevent_get_output(bev)); + if (output_len == 0) { stream->state_ = CLOSED; delete stream; } @@ -205,9 +207,6 @@ void BufferEventWebStream::EventCallback(bufferevent* bev, } if (what & (BEV_EVENT_EOF | BEV_EVENT_ERROR)) { - // Connection closed - LOG(INFO) << "Connection closed."; - BufferEventWebStream* stream = static_cast(ctx); // If the stream is still in OPEN state, the underlying connection closed @@ -247,14 +246,40 @@ ssize_t BufferEventWebStream::WslaySendCallback(wslay_event_context* ctx, int header_len = snprintf(header, sizeof(header), "%zx\r\n", len); if (header_len <= 0) { wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); + + return -1; + } + + int write_header_rv = bufferevent_write(stream->bev_, + header, + static_cast(header_len)); + if (write_header_rv != 0) { + LOG(ERROR) << "bufferevent_write() failed"; + + wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); + + return -1; + } + + int write_data_rv = bufferevent_write(stream->bev_, + data, + len); + if (write_data_rv != 0) { + LOG(ERROR) << "bufferevent_write() failed"; + + wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); + return -1; } - if (bufferevent_write(stream->bev_, header, static_cast(header_len)) != 0 || - bufferevent_write(stream->bev_, data, len) != 0 || - bufferevent_write(stream->bev_, "\r\n", 2) != 0) { + int write_trailer_rv = bufferevent_write(stream->bev_, + "\r\n", + 2); + if (write_trailer_rv != 0) { LOG(ERROR) << "bufferevent_write() failed"; + wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); + return -1; } @@ -376,7 +401,8 @@ ssize_t BufferEventWebStream::ReadChunkedBytes(uint8_t* buf, size_t len) { } case ChunkState::TRAILER: { - if (evbuffer_get_length(input) < 2) { + size_t input_len = evbuffer_get_length(input); + if (input_len < 2) { wslay_event_set_error(ctx_, WSLAY_ERR_WOULDBLOCK); return -1; } diff --git a/wish/cpp/src/h2_client.cc b/wish/cpp/src/h2_client.cc index ca69a3c..b21646f 100644 --- a/wish/cpp/src/h2_client.cc +++ b/wish/cpp/src/h2_client.cc @@ -93,11 +93,12 @@ bool H2Client::Init() { return false; } - if (bufferevent_socket_connect_hostname(bev, - dns_base_, - AF_INET, - host_.c_str(), - port_) < 0) { + int connect_rv = bufferevent_socket_connect_hostname(bev, + dns_base_, + AF_INET, + host_.c_str(), + port_); + if (connect_rv < 0) { LOG(ERROR) << "bufferevent_socket_connect_hostname() failed"; return false; @@ -109,8 +110,6 @@ bool H2Client::Init() { void H2Client::SetOnOpen(OpenCallback cb) { on_open_ = cb; } void H2Client::Run() { - LOG(INFO) << "Running..."; - event_base_dispatch(base_); } diff --git a/wish/cpp/src/h2_server.cc b/wish/cpp/src/h2_server.cc index 1748e32..15dff6a 100644 --- a/wish/cpp/src/h2_server.cc +++ b/wish/cpp/src/h2_server.cc @@ -60,21 +60,19 @@ bool H2Server::Init() { void H2Server::SetOnStream(StreamCallback cb) { on_stream_ = cb; } void H2Server::Run() { - LOG(INFO) << "H2Server listening on port " << port_ << "..."; - event_base_dispatch(base_); } // ---- libevent listener callbacks ---- -void H2Server::AcceptConnCb(struct evconnlistener* listener, +void H2Server::AcceptConnCb(evconnlistener* listener, evutil_socket_t fd, - struct sockaddr* /*address*/, + sockaddr* /*address*/, int /*socklen*/, void* ctx) { H2Server* server = static_cast(ctx); - struct event_base* base = evconnlistener_get_base(listener); + event_base* base = evconnlistener_get_base(listener); int one = 1; int set_rv = setsockopt(fd, @@ -88,9 +86,9 @@ void H2Server::AcceptConnCb(struct evconnlistener* listener, return; } - struct bufferevent* bev = bufferevent_socket_new(base, - fd, - BEV_OPT_CLOSE_ON_FREE); + bufferevent* bev = bufferevent_socket_new(base, + fd, + BEV_OPT_CLOSE_ON_FREE); if (!bev) { LOG(ERROR) << "H2Server: bufferevent_socket_new() failed"; evutil_closesocket(fd); diff --git a/wish/cpp/src/h2_tls_client.cc b/wish/cpp/src/h2_tls_client.cc index fc13639..f85dca0 100644 --- a/wish/cpp/src/h2_tls_client.cc +++ b/wish/cpp/src/h2_tls_client.cc @@ -2,15 +2,14 @@ #include #include +#include #include +#include #include #include #include -#include -#include - #define H2TC_MAKE_NV(name, value) \ { \ (uint8_t*)(name), (uint8_t*)(value), strlen(name), strlen(value), NGHTTP2_NV_FLAG_NONE} @@ -98,9 +97,13 @@ bool H2TlsClient::Init() { if (!bev) { LOG(ERROR) << "bufferevent_openssl_socket_new() failed"; + SSL_free(ssl); + return false; } + bufferevent_openssl_set_allow_dirty_shutdown(bev, 1); + session_ = new Session; session_->client = this; session_->bev = bev; @@ -123,11 +126,12 @@ bool H2TlsClient::Init() { return false; } - if (bufferevent_socket_connect_hostname(bev, - dns_base_, - AF_INET, - host_.c_str(), - port_) < 0) { + int connect_rv = bufferevent_socket_connect_hostname(bev, + dns_base_, + AF_INET, + host_.c_str(), + port_); + if (connect_rv < 0) { LOG(ERROR) << "bufferevent_socket_connect_hostname() failed"; return false; @@ -139,8 +143,6 @@ bool H2TlsClient::Init() { void H2TlsClient::SetOnOpen(OpenCallback cb) { on_open_ = cb; } void H2TlsClient::Run() { - LOG(INFO) << "Running..."; - event_base_dispatch(base_); } diff --git a/wish/cpp/src/h2_tls_server.cc b/wish/cpp/src/h2_tls_server.cc index a78fea5..091d618 100644 --- a/wish/cpp/src/h2_tls_server.cc +++ b/wish/cpp/src/h2_tls_server.cc @@ -21,7 +21,8 @@ static int AlpnSelectCb(SSL* /*ssl*/, unsigned int inlen, void* /*arg*/) { // nghttp2_select_next_protocol returns 1 if "h2" is found, <=0 otherwise. - if (nghttp2_select_next_protocol(const_cast(out), outlen, in, inlen) <= 0) { + int select_rv = nghttp2_select_next_protocol(const_cast(out), outlen, in, inlen); + if (select_rv <= 0) { return SSL_TLSEXT_ERR_NOACK; } return SSL_TLSEXT_ERR_OK; @@ -96,8 +97,6 @@ bool H2TlsServer::Init() { void H2TlsServer::SetOnStream(StreamCallback cb) { on_stream_ = cb; } void H2TlsServer::Run() { - LOG(INFO) << "H2TlsServer listening on port " << port_ << "..."; - event_base_dispatch(base_); } @@ -122,19 +121,27 @@ void H2TlsServer::AcceptConnCb(evconnlistener* listener, LOG(ERROR) << "H2TlsServer: setsockopt(TCP_NODELAY) failed"; evutil_closesocket(fd); + return; } SSL* ssl = SSL_new(server->tls_ctx_.ssl_ctx()); - bufferevent* bev = bufferevent_openssl_socket_new( - base, fd, ssl, BUFFEREVENT_SSL_ACCEPTING, BEV_OPT_CLOSE_ON_FREE); + bufferevent* bev = bufferevent_openssl_socket_new(base, + fd, + ssl, + BUFFEREVENT_SSL_ACCEPTING, + BEV_OPT_CLOSE_ON_FREE); if (!bev) { LOG(ERROR) << "H2TlsServer: bufferevent_openssl_socket_new() failed"; + SSL_free(ssl); evutil_closesocket(fd); + return; } + bufferevent_openssl_set_allow_dirty_shutdown(bev, 1); + Session* sess = new Session; sess->server = server; sess->bev = bev; diff --git a/wish/cpp/src/handshake.cc b/wish/cpp/src/handshake.cc index e1d5a30..a1ba57a 100644 --- a/wish/cpp/src/handshake.cc +++ b/wish/cpp/src/handshake.cc @@ -4,56 +4,172 @@ #include #include #include +#include +#include #include #include +#include namespace { -bool ReadHttpHeaders(evbuffer* input, std::string* headers_out) { - size_t len = evbuffer_get_length(input); - if (len == 0) { +bool EqualsIgnoreCase(std::string_view a, std::string_view b) { + if (a.size() != b.size()) { return false; } + return std::equal(a.begin(), a.end(), b.begin(), [](char char_a, char char_b) { + return std::tolower(static_cast(char_a)) == std::tolower(static_cast(char_b)); + }); +} - // Search for \r\n\r\n - evbuffer_ptr ptr = evbuffer_search(input, "\r\n\r\n", 4, nullptr); - if (ptr.pos == -1) { - return false; // Not full headers yet +bool CheckHeader(const phr_header* headers, + size_t num_headers, + std::string_view target_name, + std::string_view target_value) { + for (size_t i = 0; i < num_headers; ++i) { + std::string_view name(headers[i].name, headers[i].name_len); + std::string_view value(headers[i].value, headers[i].value_len); + if (EqualsIgnoreCase(name, target_name) && EqualsIgnoreCase(value, target_value)) { + return true; + } } + return false; +} - // Read up to the end of headers - size_t header_len = ptr.pos + 4; - char* headers = new char[header_len + 1]; - evbuffer_remove(input, headers, header_len); - headers[header_len] = '\0'; - *headers_out = std::string(headers); - delete[] headers; +bool HasHeader(const phr_header* headers, + size_t num_headers, + std::string_view target_name) { + for (size_t i = 0; i < num_headers; ++i) { + std::string_view name(headers[i].name, headers[i].name_len); + if (EqualsIgnoreCase(name, target_name)) { + return true; + } + } + return false; +} - return true; +std::string_view GetHeaderValue(const phr_header* headers, + size_t num_headers, + std::string_view target_name) { + for (size_t i = 0; i < num_headers; ++i) { + std::string_view name(headers[i].name, headers[i].name_len); + if (EqualsIgnoreCase(name, target_name)) { + return std::string_view(headers[i].value, headers[i].value_len); + } + } + return {}; } -void SendHttpRequest(bufferevent* bev) { - std::stringstream ss; - ss << "POST / HTTP/1.1\r\n"; - ss << "Host: localhost\r\n"; - ss << "Content-Type: application/web-stream\r\n"; - ss << "Transfer-Encoding: chunked\r\n"; - ss << "\r\n"; - std::string data = ss.str(); - bufferevent_write(bev, data.c_str(), data.length()); +bool CheckHeaderContains(const phr_header* headers, + size_t num_headers, + std::string_view target_name, + std::string_view target_substring) { + std::string sub_lower(target_substring); + std::transform(sub_lower.begin(), sub_lower.end(), sub_lower.begin(), ::tolower); + + for (size_t i = 0; i < num_headers; ++i) { + std::string_view name(headers[i].name, headers[i].name_len); + if (EqualsIgnoreCase(name, target_name)) { + std::string_view value(headers[i].value, headers[i].value_len); + std::string val_lower(value); + std::transform(val_lower.begin(), val_lower.end(), val_lower.begin(), ::tolower); + if (val_lower.find(sub_lower) != std::string::npos) { + return true; + } + } + } + return false; } -void SendHttpResponse(bufferevent* bev, - const std::string& status, - const std::string& content_type) { - std::stringstream ss; - ss << "HTTP/1.1 " << status << "\r\n"; - ss << "Content-Type: " << content_type << "\r\n"; - ss << "Transfer-Encoding: chunked\r\n"; - ss << "\r\n"; // End of headers - std::string data = ss.str(); - bufferevent_write(bev, data.c_str(), data.length()); +bool IsHttpWhitespace(char c) { + return c == ' ' || c == '\t'; +} + +std::vector SplitAndTrim(std::string_view s) { + std::vector result; + size_t start = 0; + while (start < s.size()) { + size_t comma = s.find(',', start); + std::string_view token = (comma == std::string_view::npos) ? s.substr(start) : s.substr(start, comma - start); + // Trim whitespace + while (!token.empty() && IsHttpWhitespace(token.front())) { + token.remove_prefix(1); + } + while (!token.empty() && IsHttpWhitespace(token.back())) { + token.remove_suffix(1); + } + if (!token.empty()) { + result.push_back(token); + } + if (comma == std::string_view::npos) { + break; + } + start = comma + 1; + } + return result; +} + +bool ValidateTransferEncoding(const phr_header* headers, size_t num_headers) { + for (size_t i = 0; i < num_headers; ++i) { + std::string_view name(headers[i].name, headers[i].name_len); + if (EqualsIgnoreCase(name, "transfer-encoding")) { + std::string_view value(headers[i].value, headers[i].value_len); + std::vector tokens = SplitAndTrim(value); + if (tokens.empty()) { + LOG(ERROR) << "Empty Transfer-Encoding header value"; + return false; + } + for (std::string_view token : tokens) { + if (EqualsIgnoreCase(token, "chunked")) { + // Valid chunked token + } else { + LOG(ERROR) << "Unsupported Transfer-Encoding token: " << token; + + return false; + } + } + } + } + return true; +} + +bool ValidateHeaders(const phr_header* headers, size_t num_headers) { + bool has_content_encoding = HasHeader(headers, num_headers, "content-encoding"); + if (has_content_encoding) { + LOG(ERROR) << "Content-Encoding support is not implemented yet"; + return false; + } + + bool has_content_length = HasHeader(headers, num_headers, "content-length"); + if (has_content_length) { + LOG(ERROR) << "Content-Length support is not implemented yet"; + return false; + } + + bool has_connection_close = CheckHeaderContains(headers, num_headers, "connection", "close"); + if (has_connection_close) { + LOG(ERROR) << "Connection: close support is not implemented yet"; + return false; + } + + bool has_connection_upgrade = CheckHeaderContains(headers, num_headers, "connection", "upgrade"); + if (has_connection_upgrade) { + LOG(ERROR) << "Connection: upgrade is not supported by web-stream"; + return false; + } + + bool has_upgrade = HasHeader(headers, num_headers, "upgrade"); + if (has_upgrade) { + LOG(ERROR) << "Upgrade header is not supported by web-stream"; + return false; + } + + bool te_valid = ValidateTransferEncoding(headers, num_headers); + if (!te_valid) { + return false; + } + + return true; } } // namespace @@ -75,11 +191,29 @@ ClientHandshake::~ClientHandshake() { void ClientHandshake::Start() { bufferevent_setcb(bev_, ReadCb, nullptr, EventCb, this); + int enable_rv = bufferevent_enable(bev_, EV_READ | EV_WRITE); if (enable_rv != 0) { - LOG(ERROR) << "bufferevent_enable() failed in ClientHandshake"; + LOG(ERROR) << "bufferevent_enable() failed"; + + InvokeError(); + + return; + } + + std::stringstream ss; + ss << "POST / HTTP/1.1\r\n"; + ss << "Host: localhost\r\n"; + ss << "Content-Type: application/web-stream\r\n"; + ss << "Transfer-Encoding: chunked\r\n"; + ss << "\r\n"; + std::string data = ss.str(); + int write_rv = bufferevent_write(bev_, data.c_str(), data.length()); + if (write_rv != 0) { + LOG(ERROR) << "bufferevent_write() failed"; + + InvokeError(); } - SendHttpRequest(bev_); } void ClientHandshake::ReadCb(bufferevent* bev, void* ctx) { @@ -92,14 +226,74 @@ void ClientHandshake::EventCb(bufferevent* bev, short what, void* ctx) { void ClientHandshake::HandleRead() { evbuffer* input = bufferevent_get_input(bev_); - std::string headers; - if (!ReadHttpHeaders(input, &headers)) { - return; // Wait for more data + + size_t len = evbuffer_get_length(input); + if (len == 0) { + return; + } + + const char* data = reinterpret_cast(evbuffer_pullup(input, -1)); + if (!data) { + ABSL_UNREACHABLE(); + } + + int minor_version; + int status; + const char* msg; + size_t msg_len; + struct phr_header headers[100]; + size_t num_headers = 100; + + int parse_rv = phr_parse_response(data, len, &minor_version, &status, &msg, &msg_len, headers, &num_headers, 0); + if (parse_rv == -1) { + LOG(ERROR) << "Failed to parse client handshake HTTP response"; + + InvokeError(); + + return; + } + if (parse_rv == -2) { + return; // Incomplete headers, wait for more data + } + + if (status != 200) { + LOG(ERROR) << "Bad client handshake response status: " << status; + + InvokeError(); + + return; + } + + if (minor_version < 1) { + LOG(ERROR) << "HTTP version must be at least 1.1, got 1." << minor_version; + + InvokeError(); + + return; + } + + bool has_valid_ct = CheckHeader(headers, num_headers, "content-type", "application/web-stream"); + if (!has_valid_ct) { + LOG(ERROR) << "Client handshake response missing web-stream Content-Type!"; + + InvokeError(); + + return; + } + + bool headers_valid = ValidateHeaders(headers, num_headers); + if (!headers_valid) { + InvokeError(); + + return; } - if (headers.find("200 OK") == std::string::npos) { - LOG(ERROR) << "Bad Handy handshake response: " << headers; - HandleEvent(BEV_EVENT_ERROR); + int drain_rv = evbuffer_drain(input, parse_rv); + if (drain_rv != 0) { + LOG(ERROR) << "evbuffer_drain() failed"; + + InvokeError(); + return; } @@ -116,25 +310,28 @@ void ClientHandshake::HandleRead() { void ClientHandshake::HandleEvent(short what) { if (what & BEV_EVENT_CONNECTED) { - LOG(INFO) << "Client handshake: connected"; return; } - if (what & (BEV_EVENT_EOF | BEV_EVENT_ERROR | BEV_EVENT_TIMEOUT)) { - if (what & BEV_EVENT_ERROR) { - int err = EVUTIL_SOCKET_ERROR(); - if (err != 0) { - LOG(ERROR) << "Error during client handshake: " - << evutil_socket_error_to_string(err); - } else { - LOG(ERROR) << "Error during client handshake"; - } + if (what & BEV_EVENT_ERROR) { + int err = EVUTIL_SOCKET_ERROR(); + if (err != 0) { + LOG(ERROR) << "Error during client handshake: " + << evutil_socket_error_to_string(err); + } else { + LOG(ERROR) << "Error during client handshake"; } + } - auto on_error = std::move(on_error_); - if (on_error) { - on_error(); - } + if (what & (BEV_EVENT_EOF | BEV_EVENT_ERROR | BEV_EVENT_TIMEOUT)) { + InvokeError(); + } +} + +void ClientHandshake::InvokeError() { + auto on_error = std::move(on_error_); + if (on_error) { + on_error(); } } @@ -155,9 +352,14 @@ ServerHandshake::~ServerHandshake() { void ServerHandshake::Start() { bufferevent_setcb(bev_, ReadCb, nullptr, EventCb, this); + int enable_rv = bufferevent_enable(bev_, EV_READ | EV_WRITE); if (enable_rv != 0) { - LOG(ERROR) << "bufferevent_enable() failed in ServerHandshake"; + LOG(ERROR) << "bufferevent_enable() failed"; + + InvokeError(); + + return; } } @@ -171,21 +373,85 @@ void ServerHandshake::EventCb(bufferevent* bev, short what, void* ctx) { void ServerHandshake::HandleRead() { evbuffer* input = bufferevent_get_input(bev_); - std::string headers; - if (!ReadHttpHeaders(input, &headers)) { - return; // Wait for more data + + size_t len = evbuffer_get_length(input); + if (len == 0) { + return; } - // Check for web-stream specific header - if (headers.find("Content-Type: application/web-stream") == std::string::npos && - headers.find("content-type: application/web-stream") == std::string::npos) { + const char* data = reinterpret_cast(evbuffer_pullup(input, -1)); + if (!data) { + ABSL_UNREACHABLE(); + } + + int minor_version; + const char* method; + size_t method_len; + const char* path; + size_t path_len; + struct phr_header headers[100]; + size_t num_headers = 100; + + int parse_rv = phr_parse_request(data, len, &method, &method_len, &path, &path_len, &minor_version, headers, &num_headers, 0); + if (parse_rv == -1) { + LOG(ERROR) << "Failed to parse server handshake HTTP request"; + + InvokeError(); + + return; + } + if (parse_rv == -2) { + return; // Incomplete headers, wait for more data + } + + if (minor_version < 1) { + LOG(ERROR) << "HTTP version must be at least 1.1, got 1." << minor_version; + + InvokeError(); + + return; + } + + bool has_valid_ct = CheckHeader(headers, num_headers, "content-type", "application/web-stream"); + if (!has_valid_ct) { LOG(ERROR) << "Missing web-stream Content-Type!"; - HandleEvent(BEV_EVENT_ERROR); + + InvokeError(); + + return; + } + + bool headers_valid = ValidateHeaders(headers, num_headers); + if (!headers_valid) { + InvokeError(); + + return; + } + + int drain_rv = evbuffer_drain(input, parse_rv); + if (drain_rv != 0) { + LOG(ERROR) << "evbuffer_drain() failed"; + + InvokeError(); + return; } // Send the HTTP 200 response - SendHttpResponse(bev_, "200 OK", "application/web-stream"); + std::stringstream ss; + ss << "HTTP/1.1 200 OK\r\n"; + ss << "Content-Type: application/web-stream\r\n"; + ss << "Transfer-Encoding: chunked\r\n"; + ss << "\r\n"; // End of headers + std::string response_data = ss.str(); + int write_rv = bufferevent_write(bev_, response_data.c_str(), response_data.length()); + if (write_rv != 0) { + LOG(ERROR) << "bufferevent_write() failed"; + + InvokeError(); + + return; + } // Handshake successful. Hand over bufferevent and trigger success callback. bufferevent* bev = bev_; @@ -201,21 +467,25 @@ void ServerHandshake::HandleRead() { } void ServerHandshake::HandleEvent(short what) { - if (what & (BEV_EVENT_EOF | BEV_EVENT_ERROR | BEV_EVENT_TIMEOUT)) { - if (what & BEV_EVENT_ERROR) { - int err = EVUTIL_SOCKET_ERROR(); - if (err != 0) { - LOG(ERROR) << "Error during server handshake: " - << evutil_socket_error_to_string(err); - } else { - LOG(ERROR) << "Error during server handshake"; - } + if (what & BEV_EVENT_ERROR) { + int err = EVUTIL_SOCKET_ERROR(); + if (err != 0) { + LOG(ERROR) << "Error during server handshake: " + << evutil_socket_error_to_string(err); + } else { + LOG(ERROR) << "Error during server handshake"; } + } - auto on_error = std::move(on_error_); - delete this; - if (on_error) { - on_error(); - } + if (what & (BEV_EVENT_EOF | BEV_EVENT_ERROR | BEV_EVENT_TIMEOUT)) { + InvokeError(); + } +} + +void ServerHandshake::InvokeError() { + auto on_error = std::move(on_error_); + delete this; + if (on_error) { + on_error(); } } diff --git a/wish/cpp/src/handshake.h b/wish/cpp/src/handshake.h index 50f30ab..ff73d4b 100644 --- a/wish/cpp/src/handshake.h +++ b/wish/cpp/src/handshake.h @@ -1,8 +1,9 @@ #ifndef WISH_CPP_SRC_HANDSHAKE_H_ #define WISH_CPP_SRC_HANDSHAKE_H_ -#include #include +#include + #include #include #include @@ -23,8 +24,10 @@ class ClientHandshake { void HandleRead(); void HandleEvent(short what); + void InvokeError(); bufferevent* bev_; + OnOpenCallback on_open_; OnErrorCallback on_error_; }; @@ -45,8 +48,10 @@ class ServerHandshake { void HandleRead(); void HandleEvent(short what); + void InvokeError(); bufferevent* bev_; + OnOpenCallback on_open_; OnErrorCallback on_error_; }; diff --git a/wish/cpp/src/handshake_test.cc b/wish/cpp/src/handshake_test.cc index 0130a42..8165ba7 100644 --- a/wish/cpp/src/handshake_test.cc +++ b/wish/cpp/src/handshake_test.cc @@ -319,3 +319,321 @@ TEST_F(HandshakeTest, ServerHandshakeEventError) { EXPECT_FALSE(open_called); EXPECT_TRUE(error_called); } + +TEST_F(HandshakeTest, ClientHandshakeRejectsContentEncoding) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto client = std::make_unique( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + client->Start(); + + int limit = 100; + while (evbuffer_get_length(bufferevent_get_input(pair[1])) == 0 && --limit > 0) { + event_base_loop(base_, EVLOOP_NONBLOCK); + } + std::string req = ReadAllData(pair[1]); (void)req; + + const char* response = "HTTP/1.1 200 OK\r\nContent-Type: application/web-stream\r\nContent-Encoding: gzip\r\nTransfer-Encoding: chunked\r\n\r\n"; + bufferevent_write(pair[1], response, strlen(response)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ClientHandshakeRejectsContentLength) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto client = std::make_unique( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + client->Start(); + + int limit = 100; + while (evbuffer_get_length(bufferevent_get_input(pair[1])) == 0 && --limit > 0) { + event_base_loop(base_, EVLOOP_NONBLOCK); + } + std::string req = ReadAllData(pair[1]); (void)req; + + const char* response = "HTTP/1.1 200 OK\r\nContent-Type: application/web-stream\r\nContent-Length: 10\r\nTransfer-Encoding: chunked\r\n\r\n"; + bufferevent_write(pair[1], response, strlen(response)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ClientHandshakeRejectsConnectionClose) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto client = std::make_unique( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + client->Start(); + + int limit = 100; + while (evbuffer_get_length(bufferevent_get_input(pair[1])) == 0 && --limit > 0) { + event_base_loop(base_, EVLOOP_NONBLOCK); + } + std::string req = ReadAllData(pair[1]); (void)req; + + const char* response = "HTTP/1.1 200 OK\r\nContent-Type: application/web-stream\r\nConnection: close\r\nTransfer-Encoding: chunked\r\n\r\n"; + bufferevent_write(pair[1], response, strlen(response)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ClientHandshakeRejectsNonChunkedTransferEncoding) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto client = std::make_unique( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + client->Start(); + + int limit = 100; + while (evbuffer_get_length(bufferevent_get_input(pair[1])) == 0 && --limit > 0) { + event_base_loop(base_, EVLOOP_NONBLOCK); + } + std::string req = ReadAllData(pair[1]); (void)req; + + const char* response = "HTTP/1.1 200 OK\r\nContent-Type: application/web-stream\r\nTransfer-Encoding: gzip\r\n\r\n"; + bufferevent_write(pair[1], response, strlen(response)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ClientHandshakeRejectsUpgradeHeader) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto client = std::make_unique( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + client->Start(); + + int limit = 100; + while (evbuffer_get_length(bufferevent_get_input(pair[1])) == 0 && --limit > 0) { + event_base_loop(base_, EVLOOP_NONBLOCK); + } + std::string req = ReadAllData(pair[1]); (void)req; + + const char* response = "HTTP/1.1 200 OK\r\nContent-Type: application/web-stream\r\nUpgrade: websocket\r\nTransfer-Encoding: chunked\r\n\r\n"; + bufferevent_write(pair[1], response, strlen(response)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ClientHandshakeRejectsHTTP10) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto client = std::make_unique( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + client->Start(); + + int limit = 100; + while (evbuffer_get_length(bufferevent_get_input(pair[1])) == 0 && --limit > 0) { + event_base_loop(base_, EVLOOP_NONBLOCK); + } + std::string req = ReadAllData(pair[1]); (void)req; + + const char* response = "HTTP/1.0 200 OK\r\nContent-Type: application/web-stream\r\nTransfer-Encoding: chunked\r\n\r\n"; + bufferevent_write(pair[1], response, strlen(response)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ClientHandshakeRejectsMultipleTransferEncodings) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto client = std::make_unique( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + client->Start(); + + int limit = 100; + while (evbuffer_get_length(bufferevent_get_input(pair[1])) == 0 && --limit > 0) { + event_base_loop(base_, EVLOOP_NONBLOCK); + } + std::string req = ReadAllData(pair[1]); (void)req; + + const char* response = "HTTP/1.1 200 OK\r\nContent-Type: application/web-stream\r\nTransfer-Encoding: chunked, gzip\r\n\r\n"; + bufferevent_write(pair[1], response, strlen(response)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ServerHandshakeRejectsContentEncoding) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto* server = new ServerHandshake( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + server->Start(); + + const char* request = "POST / HTTP/1.1\r\nHost: localhost\r\nContent-Type: application/web-stream\r\nContent-Encoding: gzip\r\n\r\n"; + bufferevent_write(pair[1], request, strlen(request)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ServerHandshakeRejectsContentLength) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto* server = new ServerHandshake( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + server->Start(); + + const char* request = "POST / HTTP/1.1\r\nHost: localhost\r\nContent-Type: application/web-stream\r\nContent-Length: 100\r\n\r\n"; + bufferevent_write(pair[1], request, strlen(request)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ServerHandshakeRejectsConnectionClose) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto* server = new ServerHandshake( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + server->Start(); + + const char* request = "POST / HTTP/1.1\r\nHost: localhost\r\nContent-Type: application/web-stream\r\nConnection: close\r\n\r\n"; + bufferevent_write(pair[1], request, strlen(request)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ServerHandshakeRejectsNonChunkedTransferEncoding) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto* server = new ServerHandshake( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + server->Start(); + + const char* request = "POST / HTTP/1.1\r\nHost: localhost\r\nContent-Type: application/web-stream\r\nTransfer-Encoding: gzip\r\n\r\n"; + bufferevent_write(pair[1], request, strlen(request)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} + +TEST_F(HandshakeTest, ServerHandshakeRejectsHTTP10) { + bufferevent* pair[2]; + int rv = bufferevent_pair_new(base_, BEV_OPT_CLOSE_ON_FREE | BEV_OPT_DEFER_CALLBACKS, pair); + ASSERT_EQ(rv, 0); + bufferevent_enable(pair[1], EV_READ | EV_WRITE); + + bool open_called = false; + bool error_called = false; + auto* server = new ServerHandshake( + pair[0], + [&](bufferevent* bev) { open_called = true; bufferevent_free(bev); event_base_loopbreak(base_); }, + [&]() { error_called = true; event_base_loopbreak(base_); }); + server->Start(); + + const char* request = "POST / HTTP/1.0\r\nHost: localhost\r\nContent-Type: application/web-stream\r\n\r\n"; + bufferevent_write(pair[1], request, strlen(request)); + event_base_dispatch(base_); + + EXPECT_FALSE(open_called); + EXPECT_TRUE(error_called); + bufferevent_free(pair[1]); +} diff --git a/wish/cpp/src/plain_client.cc b/wish/cpp/src/plain_client.cc index db5744c..ccbdf1f 100644 --- a/wish/cpp/src/plain_client.cc +++ b/wish/cpp/src/plain_client.cc @@ -46,12 +46,14 @@ bool PlainClient::Init() { return false; } - if (bufferevent_socket_connect_hostname(bev, - dns_base_, - AF_INET, - host_.c_str(), - port_) < 0) { + int connect_rv = bufferevent_socket_connect_hostname(bev, + dns_base_, + AF_INET, + host_.c_str(), + port_); + if (connect_rv < 0) { LOG(ERROR) << "bufferevent_socket_connect_hostname() failed"; + bufferevent_free(bev); return false; @@ -84,8 +86,6 @@ void PlainClient::SetOnOpen(OpenCallback cb) { } void PlainClient::Run() { - LOG(INFO) << "Client running..."; - event_base_dispatch(base_); } diff --git a/wish/cpp/src/plain_server.cc b/wish/cpp/src/plain_server.cc index 7f4e815..3f3ffae 100644 --- a/wish/cpp/src/plain_server.cc +++ b/wish/cpp/src/plain_server.cc @@ -61,8 +61,6 @@ void PlainServer::SetOnStream(StreamCallback cb) { } void PlainServer::Run() { - LOG(INFO) << "Server listening on port " << port_ << "..."; - event_base_dispatch(base_); } @@ -78,11 +76,12 @@ void PlainServer::AcceptConnCb(evconnlistener* listener, PlainServer* server = static_cast(ctx); int one = 1; - if (setsockopt(fd, - IPPROTO_TCP, - TCP_NODELAY, - &one, - sizeof(one)) < 0) { + int set_opt_rv = setsockopt(fd, + IPPROTO_TCP, + TCP_NODELAY, + &one, + sizeof(one)); + if (set_opt_rv < 0) { LOG(ERROR) << "setsockopt(TCP_NODELAY) failed: " << strerror(errno); } diff --git a/wish/cpp/src/tls_client.cc b/wish/cpp/src/tls_client.cc index 69d548b..9db345a 100644 --- a/wish/cpp/src/tls_client.cc +++ b/wish/cpp/src/tls_client.cc @@ -1,7 +1,6 @@ #include "tls_client.h" #include - #include #include @@ -72,14 +71,19 @@ bool TlsClient::Init() { if (!bev) { LOG(ERROR) << "bufferevent_openssl_socket_new() failed"; + SSL_free(ssl); + return false; } - if (bufferevent_socket_connect_hostname(bev, - dns_base_, - AF_INET, - host_.c_str(), - port_) < 0) { + bufferevent_openssl_set_allow_dirty_shutdown(bev, 1); + + int connect_rv = bufferevent_socket_connect_hostname(bev, + dns_base_, + AF_INET, + host_.c_str(), + port_); + if (connect_rv < 0) { LOG(ERROR) << "bufferevent_socket_connect_hostname() failed"; bufferevent_free(bev); @@ -113,8 +117,6 @@ void TlsClient::SetOnOpen(OpenCallback cb) { } void TlsClient::Run() { - LOG(INFO) << "Client running..."; - event_base_dispatch(base_); } diff --git a/wish/cpp/src/tls_context.cc b/wish/cpp/src/tls_context.cc index 47e033a..91cebb7 100644 --- a/wish/cpp/src/tls_context.cc +++ b/wish/cpp/src/tls_context.cc @@ -35,9 +35,10 @@ bool TlsContext::Init(bool is_server) { } if (!ca_file_.empty()) { - if (SSL_CTX_load_verify_locations(ssl_ctx_, - ca_file_.c_str(), - nullptr) != 1) { + int load_rv = SSL_CTX_load_verify_locations(ssl_ctx_, + ca_file_.c_str(), + nullptr); + if (load_rv != 1) { LOG(ERROR) << "Error loading CA file: " << ca_file_; return false; @@ -55,23 +56,26 @@ bool TlsContext::Init(bool is_server) { // Load own certificate and key if (!certificate_file_.empty() && !private_key_file_.empty()) { - if (SSL_CTX_use_certificate_file(ssl_ctx_, - certificate_file_.c_str(), - SSL_FILETYPE_PEM) <= 0) { + int cert_rv = SSL_CTX_use_certificate_file(ssl_ctx_, + certificate_file_.c_str(), + SSL_FILETYPE_PEM); + if (cert_rv <= 0) { LOG(ERROR) << "Error loading certificate file: " << certificate_file_; return false; } - if (SSL_CTX_use_PrivateKey_file(ssl_ctx_, - private_key_file_.c_str(), - SSL_FILETYPE_PEM) <= 0) { + int key_rv = SSL_CTX_use_PrivateKey_file(ssl_ctx_, + private_key_file_.c_str(), + SSL_FILETYPE_PEM); + if (key_rv <= 0) { LOG(ERROR) << "Error loading key file: " << private_key_file_; return false; } - if (!SSL_CTX_check_private_key(ssl_ctx_)) { + int check_rv = SSL_CTX_check_private_key(ssl_ctx_); + if (check_rv != 1) { LOG(ERROR) << "Private key does not match the certificate public key"; return false; diff --git a/wish/cpp/src/tls_server.cc b/wish/cpp/src/tls_server.cc index c7412b9..72235e6 100644 --- a/wish/cpp/src/tls_server.cc +++ b/wish/cpp/src/tls_server.cc @@ -2,16 +2,15 @@ #include #include +#include #include +#include #include #include "buffer_event_web_stream.h" #include "handshake.h" -#include -#include - TlsServer::TlsServer(int port, const std::string& ca_file, const std::string& cert_file, @@ -81,8 +80,6 @@ bool TlsServer::Init() { } void TlsServer::Run() { - LOG(INFO) << "Server listening on port " << port_ << "..."; - event_base_dispatch(base_); } @@ -99,11 +96,12 @@ void TlsServer::AcceptConnCb(evconnlistener* listener, TlsServer* server = static_cast(ctx); int one = 1; - if (setsockopt(fd, - IPPROTO_TCP, - TCP_NODELAY, - &one, - sizeof(one)) < 0) { + int set_opt_rv = setsockopt(fd, + IPPROTO_TCP, + TCP_NODELAY, + &one, + sizeof(one)); + if (set_opt_rv < 0) { LOG(ERROR) << "setsockopt(TCP_NODELAY) failed: " << strerror(errno); } @@ -113,6 +111,16 @@ void TlsServer::AcceptConnCb(evconnlistener* listener, ssl, BUFFEREVENT_SSL_ACCEPTING, BEV_OPT_CLOSE_ON_FREE); + if (!bev) { + LOG(ERROR) << "bufferevent_openssl_socket_new() failed"; + + SSL_free(ssl); + evutil_closesocket(fd); + + return; + } + + bufferevent_openssl_set_allow_dirty_shutdown(bev, 1); auto* handshake = new ServerHandshake( bev,