diff --git a/wish/cpp/CMakeLists.txt b/wish/cpp/CMakeLists.txt index 8062985..12eabb5 100644 --- a/wish/cpp/CMakeLists.txt +++ b/wish/cpp/CMakeLists.txt @@ -12,227 +12,237 @@ # See the License for the specific language governing permissions and # limitations under the License. -cmake_minimum_required(VERSION 3.14) -project(wish_cpp) - - -option(WISH_BUILD_TESTS "Build wish unit tests" ON) -option(WISH_BUILD_BENCHMARKS "Build wish benchmarks" ON) -option(WISH_BUILD_EXAMPLES "Build wish examples" ON) - - -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -include(FetchContent) - -# Exclude tests, benchmarks, sample files, etc. from dependencies. -# Set before any MakeAvailable so no dep accidentally enables testing. -set(BUILD_TESTING OFF CACHE BOOL "" FORCE) - - -set(ABSL_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) -set(ABSL_PROPAGATE_CXX_STD ON) - -# Static-build abseil in order to use WHOLE_ARCHIVE to force linking of log_flags static initializer -set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) - -FetchContent_Declare( - abseil - GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git - GIT_TAG 255c84dadd029fd8ad25c5efb5933e47beaa00c7 # 20260107.1 -) -FetchContent_MakeAvailable(abseil) - -unset(BUILD_SHARED_LIBS CACHE) - - -set(WSLAY_TESTS OFF CACHE BOOL "" FORCE) -set(WSLAY_EXAMPLES OFF CACHE BOOL "" FORCE) - -FetchContent_Declare( - wslay - GIT_REPOSITORY https://github.com/tatsuhiro-t/wslay.git - GIT_TAG 0e7d106ff89ad6638090fd811a9b2e4c5dda8d40 - PATCH_COMMAND git apply --reverse --check ${CMAKE_CURRENT_SOURCE_DIR}/patches/wslay.patch || git apply ${CMAKE_CURRENT_SOURCE_DIR}/patches/wslay.patch -) - - -# Disable OpenSSL and Mbed TLS. We use BoringSSL instead. -set(EVENT__DISABLE_OPENSSL ON CACHE BOOL "" FORCE) -set(EVENT__DISABLE_MBEDTLS ON CACHE BOOL "" FORCE) - -# Build libevent as a static library so executables are self-contained. -set(EVENT__LIBRARY_TYPE STATIC CACHE STRING "" FORCE) - -set(EVENT__DISABLE_BENCHMARKS ON CACHE BOOL "" FORCE) -set(EVENT__DISABLE_TESTS ON CACHE BOOL "" FORCE) -set(EVENT__DISABLE_REGRESS ON CACHE BOOL "" FORCE) -set(EVENT__DISABLE_SAMPLES ON CACHE BOOL "" FORCE) - -FetchContent_Declare( - libevent - GIT_REPOSITORY https://github.com/libevent/libevent.git - GIT_TAG 780acfe8b2495949f0dc3ebd6f18eea2dec605a6 -) - - -FetchContent_Declare( - boringssl - GIT_REPOSITORY https://boringssl.googlesource.com/boringssl - GIT_TAG 0.20260211.0 -) - - -FetchContent_Declare( - picohttpparser - GIT_REPOSITORY https://github.com/h2o/picohttpparser.git - GIT_TAG 539bb9f510f5aefc8efda53aaf822a21ecae2832 # v1.2 -) - - -# These variables are unprefixed, so MakeAvailable is called immediately and -# they are unset afterwards to avoid leaking into other dependencies. - -set(ENABLE_FAILMALLOC OFF CACHE BOOL "" FORCE) - -# Since BUILD_SHARED_LIBS=OFF, we need to explicitly enable the static target (nghttp2_static) -set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) -set(BUILD_STATIC_LIBS ON CACHE BOOL "" FORCE) - -# nghttp2: build only the core library (no TLS apps, no examples). -set(ENABLE_LIB_ONLY ON CACHE BOOL "" FORCE) - -set(ENABLE_DOC OFF CACHE BOOL "" FORCE) - -set(WITH_LIBXML2 OFF CACHE BOOL "" FORCE) -set(WITH_JEMALLOC OFF CACHE BOOL "" FORCE) -set(WITH_MRUBY OFF CACHE BOOL "" FORCE) -set(WITH_NEVERBLEED OFF CACHE BOOL "" FORCE) - -FetchContent_Declare( - nghttp2 - GIT_REPOSITORY https://github.com/nghttp2/nghttp2.git - GIT_TAG 68cb6900fde14c77f0cd7add0e094a862960eb99 # v1.69.0 -) -FetchContent_MakeAvailable(nghttp2) - -# Reset the cached variables to avoid affecting other dependencies. They are not prefixed. -foreach(_var IN ITEMS - ENABLE_LIB_ONLY - ENABLE_DOC - ENABLE_FAILMALLOC - BUILD_SHARED_LIBS - BUILD_STATIC_LIBS - WITH_LIBXML2 - WITH_JEMALLOC - WITH_MRUBY - WITH_NEVERBLEED) - unset(${_var} CACHE) -endforeach() - - -if(WISH_BUILD_BENCHMARKS) - set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "" FORCE) - set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "" FORCE) - set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) - set(BENCHMARK_INSTALL_DOCS OFF CACHE BOOL "" FORCE) - set(BENCHMARK_DOWNLOAD_DEPENDENCIES OFF CACHE BOOL "" FORCE) - FetchContent_Declare( - googlebenchmark - GIT_REPOSITORY https://github.com/google/benchmark.git - GIT_TAG 192ef10025eb2c4cdd392bc502f0c852196baa48 # v1.9.5 - ) - FetchContent_MakeAvailable(googlebenchmark) -endif() - -if(WISH_BUILD_TESTS) - FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG 52eb8108c5bdec04579160ae17225d66034bd723 # v1.17.0 - ) - FetchContent_MakeAvailable(googletest) -endif() - - -FetchContent_MakeAvailable(wslay libevent boringssl picohttpparser) - -include_directories(${wslay_SOURCE_DIR}/lib/includes) -include_directories(${libevent_SOURCE_DIR}/include) -include_directories(${libevent_BINARY_DIR}/include) -include_directories(${boringssl_SOURCE_DIR}/include) -include_directories(${nghttp2_SOURCE_DIR}/lib/includes) -include_directories(${nghttp2_BINARY_DIR}/lib/includes) -include_directories(${picohttpparser_SOURCE_DIR}) -include_directories(src) - -add_library(http1_handshake STATIC - src/handshake.h - src/handshake.cc - ${picohttpparser_SOURCE_DIR}/picohttpparser.c - ${picohttpparser_SOURCE_DIR}/picohttpparser.h -) -target_link_libraries(http1_handshake event absl::strings absl::log) - -# We will compile libevent's bufferevent_openssl.c inside our library -# since we disabled it in libevent's own build. -# Build as a static library: libevent is also static and was not compiled with -# -fPIC, so it cannot be linked into a shared object. -add_library(web_stream STATIC - src/web_stream.h - src/buffer_event_web_stream.cc - src/buffer_event_web_stream.h - ${libevent_SOURCE_DIR}/bufferevent_openssl.c - ${libevent_SOURCE_DIR}/bufferevent_ssl.c - src/plain_server.cc - src/plain_server.h - src/plain_client.cc - src/plain_client.h - src/tls_context.cc - src/tls_context.h - src/tls_server.cc - src/tls_server.h - src/tls_client.cc - src/tls_client.h - src/nghttp2_web_stream.cc - src/nghttp2_web_stream.h - src/h2_server.cc - src/h2_server.h - src/h2_client.cc - src/h2_client.h - src/h2_tls_server.cc - src/h2_tls_server.h - src/h2_tls_client.cc - src/h2_tls_client.h -) -# BoringSSL targets: ssl and crypto. nghttp2 for HTTP/2 framing. -target_link_libraries(web_stream http1_handshake wslay event ssl crypto nghttp2_static absl::strings absl::log) -target_compile_definitions(web_stream PUBLIC EVENT__HAVE_OPENSSL=1) - - -if(WISH_BUILD_EXAMPLES) - add_subdirectory(examples) -endif() - -if(WISH_BUILD_BENCHMARKS) - add_subdirectory(benchmark) -endif() - -if(WISH_BUILD_TESTS) - enable_testing() - - add_executable(buffer_event_web_stream_test src/buffer_event_web_stream_test.cc) - target_link_libraries(buffer_event_web_stream_test web_stream gtest_main gtest) - add_test(NAME buffer_event_web_stream_test COMMAND buffer_event_web_stream_test) - - add_executable(nghttp2_web_stream_test src/nghttp2_web_stream_test.cc) - target_link_libraries(nghttp2_web_stream_test web_stream gtest_main gtest) - add_test(NAME nghttp2_web_stream_test COMMAND nghttp2_web_stream_test) - - add_executable(handshake_test src/handshake_test.cc) - target_link_libraries(handshake_test http1_handshake gtest_main gtest event) - add_test(NAME handshake_test COMMAND handshake_test) -endif() +cmake_minimum_required(VERSION 3.14) +project(wish_cpp) + + +option(WISH_BUILD_TESTS "Build wish unit tests" ON) +option(WISH_BUILD_BENCHMARKS "Build wish benchmarks" ON) +option(WISH_BUILD_EXAMPLES "Build wish examples" ON) + + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +include(FetchContent) + +# Exclude tests, benchmarks, sample files, etc. from dependencies. +# Set before any MakeAvailable so no dep accidentally enables testing. +set(BUILD_TESTING OFF CACHE BOOL "" FORCE) + + +set(ABSL_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) +set(ABSL_PROPAGATE_CXX_STD ON) + +# Static-build abseil in order to use WHOLE_ARCHIVE to force linking of log_flags static initializer +set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) + +FetchContent_Declare( + abseil + GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git + GIT_TAG 255c84dadd029fd8ad25c5efb5933e47beaa00c7 # 20260107.1 +) +FetchContent_MakeAvailable(abseil) + +unset(BUILD_SHARED_LIBS CACHE) + + +set(WSLAY_TESTS OFF CACHE BOOL "" FORCE) +set(WSLAY_EXAMPLES OFF CACHE BOOL "" FORCE) + +FetchContent_Declare( + wslay + GIT_REPOSITORY https://github.com/tatsuhiro-t/wslay.git + GIT_TAG 0e7d106ff89ad6638090fd811a9b2e4c5dda8d40 + PATCH_COMMAND git apply --reverse --check ${CMAKE_CURRENT_SOURCE_DIR}/patches/wslay.patch || git apply ${CMAKE_CURRENT_SOURCE_DIR}/patches/wslay.patch +) + + +# Disable OpenSSL and Mbed TLS. We use BoringSSL instead. +set(EVENT__DISABLE_OPENSSL ON CACHE BOOL "" FORCE) +set(EVENT__DISABLE_MBEDTLS ON CACHE BOOL "" FORCE) + +# Build libevent as a static library so executables are self-contained. +set(EVENT__LIBRARY_TYPE STATIC CACHE STRING "" FORCE) + +set(EVENT__DISABLE_BENCHMARKS ON CACHE BOOL "" FORCE) +set(EVENT__DISABLE_TESTS ON CACHE BOOL "" FORCE) +set(EVENT__DISABLE_REGRESS ON CACHE BOOL "" FORCE) +set(EVENT__DISABLE_SAMPLES ON CACHE BOOL "" FORCE) + +FetchContent_Declare( + libevent + GIT_REPOSITORY https://github.com/libevent/libevent.git + GIT_TAG 780acfe8b2495949f0dc3ebd6f18eea2dec605a6 +) + + +FetchContent_Declare( + boringssl + GIT_REPOSITORY https://boringssl.googlesource.com/boringssl + GIT_TAG 0.20260211.0 +) + + +FetchContent_Declare( + picohttpparser + GIT_REPOSITORY https://github.com/h2o/picohttpparser.git + GIT_TAG 539bb9f510f5aefc8efda53aaf822a21ecae2832 # v1.2 +) + + +# These variables are unprefixed, so MakeAvailable is called immediately and +# they are unset afterwards to avoid leaking into other dependencies. + +set(ENABLE_FAILMALLOC OFF CACHE BOOL "" FORCE) + +# Since BUILD_SHARED_LIBS=OFF, we need to explicitly enable the static target (nghttp2_static) +set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE) +set(BUILD_STATIC_LIBS ON CACHE BOOL "" FORCE) + +# nghttp2: build only the core library (no TLS apps, no examples). +set(ENABLE_LIB_ONLY ON CACHE BOOL "" FORCE) + +set(ENABLE_DOC OFF CACHE BOOL "" FORCE) + +set(WITH_LIBXML2 OFF CACHE BOOL "" FORCE) +set(WITH_JEMALLOC OFF CACHE BOOL "" FORCE) +set(WITH_MRUBY OFF CACHE BOOL "" FORCE) +set(WITH_NEVERBLEED OFF CACHE BOOL "" FORCE) + +FetchContent_Declare( + nghttp2 + GIT_REPOSITORY https://github.com/nghttp2/nghttp2.git + GIT_TAG 68cb6900fde14c77f0cd7add0e094a862960eb99 # v1.69.0 +) +FetchContent_MakeAvailable(nghttp2) + +# Reset the cached variables to avoid affecting other dependencies. They are not prefixed. +foreach(_var IN ITEMS + ENABLE_LIB_ONLY + ENABLE_DOC + ENABLE_FAILMALLOC + BUILD_SHARED_LIBS + BUILD_STATIC_LIBS + WITH_LIBXML2 + WITH_JEMALLOC + WITH_MRUBY + WITH_NEVERBLEED) + unset(${_var} CACHE) +endforeach() + + +if(WISH_BUILD_BENCHMARKS) + set(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "" FORCE) + set(BENCHMARK_ENABLE_GTEST_TESTS OFF CACHE BOOL "" FORCE) + set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) + set(BENCHMARK_INSTALL_DOCS OFF CACHE BOOL "" FORCE) + set(BENCHMARK_DOWNLOAD_DEPENDENCIES OFF CACHE BOOL "" FORCE) + FetchContent_Declare( + googlebenchmark + GIT_REPOSITORY https://github.com/google/benchmark.git + GIT_TAG 192ef10025eb2c4cdd392bc502f0c852196baa48 # v1.9.5 + ) + FetchContent_MakeAvailable(googlebenchmark) +endif() + +if(WISH_BUILD_TESTS) + FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG 52eb8108c5bdec04579160ae17225d66034bd723 # v1.17.0 + ) + FetchContent_MakeAvailable(googletest) +endif() + + +FetchContent_MakeAvailable(wslay libevent boringssl picohttpparser) + +include_directories(${wslay_SOURCE_DIR}/lib/includes) +include_directories(${libevent_SOURCE_DIR}/include) +include_directories(${libevent_BINARY_DIR}/include) +include_directories(${boringssl_SOURCE_DIR}/include) +include_directories(${nghttp2_SOURCE_DIR}/lib/includes) +include_directories(${nghttp2_BINARY_DIR}/lib/includes) +include_directories(${picohttpparser_SOURCE_DIR}) +include_directories(src) + +add_library(http1_handshake STATIC + src/handshake.h + src/handshake.cc + ${picohttpparser_SOURCE_DIR}/picohttpparser.c + ${picohttpparser_SOURCE_DIR}/picohttpparser.h +) +target_link_libraries(http1_handshake event absl::strings absl::log) + +# We will compile libevent's bufferevent_openssl.c inside our library +# since we disabled it in libevent's own build. +# Build as a static library: libevent is also static and was not compiled with +# -fPIC, so it cannot be linked into a shared object. +add_library(web_stream_core STATIC + src/web_stream.h + src/buffer_event_web_stream.cc + src/buffer_event_web_stream.h + ${libevent_SOURCE_DIR}/bufferevent_openssl.c + ${libevent_SOURCE_DIR}/bufferevent_ssl.c + src/tls_context.cc + src/tls_context.h + src/nghttp2_web_stream.cc + src/nghttp2_web_stream.h +) +target_link_libraries(web_stream_core PUBLIC http1_handshake wslay event ssl crypto nghttp2_static absl::strings absl::log) +target_compile_definitions(web_stream_core PUBLIC EVENT__HAVE_OPENSSL=1) + +add_library(web_stream_client STATIC + src/plain_client.cc + src/plain_client.h + src/tls_client.cc + src/tls_client.h + src/h2_client.cc + src/h2_client.h + src/h2_tls_client.cc + src/h2_tls_client.h +) +target_link_libraries(web_stream_client PUBLIC web_stream_core) + +add_library(web_stream_server STATIC + src/plain_server.cc + src/plain_server.h + src/tls_server.cc + src/tls_server.h + src/h2_server.cc + src/h2_server.h + src/h2_tls_server.cc + src/h2_tls_server.h +) +target_link_libraries(web_stream_server PUBLIC web_stream_core) + +add_library(web_stream INTERFACE) +target_link_libraries(web_stream INTERFACE web_stream_client web_stream_server) + + +if(WISH_BUILD_EXAMPLES) + add_subdirectory(examples) +endif() + +if(WISH_BUILD_BENCHMARKS) + add_subdirectory(benchmark) +endif() + +if(WISH_BUILD_TESTS) + enable_testing() + + add_executable(buffer_event_web_stream_test src/buffer_event_web_stream_test.cc) + target_link_libraries(buffer_event_web_stream_test web_stream gtest_main gtest) + add_test(NAME buffer_event_web_stream_test COMMAND buffer_event_web_stream_test) + + add_executable(nghttp2_web_stream_test src/nghttp2_web_stream_test.cc) + target_link_libraries(nghttp2_web_stream_test web_stream gtest_main gtest) + add_test(NAME nghttp2_web_stream_test COMMAND nghttp2_web_stream_test) + + add_executable(handshake_test src/handshake_test.cc) + target_link_libraries(handshake_test http1_handshake gtest_main gtest event) + add_test(NAME handshake_test COMMAND handshake_test) +endif() diff --git a/wish/cpp/examples/h2_hello_client.cc b/wish/cpp/examples/h2_hello_client.cc index 33aa183..8c3868f 100644 --- a/wish/cpp/examples/h2_hello_client.cc +++ b/wish/cpp/examples/h2_hello_client.cc @@ -50,6 +50,12 @@ int main(int argc, char** argv) { return 1; } + client.SetOnError([&client]() { + LOG(ERROR) << "Client error or handshake failed"; + + client.Stop(); + }); + client.SetOnOpen([&client](WebStream* stream) { LOG(INFO) << "OnOpen"; @@ -79,6 +85,12 @@ int main(int argc, char** argv) { client.Stop(); }); + stream->SetOnError([&client]() { + LOG(ERROR) << "Stream error"; + + client.Stop(); + }); + stream->SendText("Hello web-stream text over HTTP/2!"); stream->SendBinary("Hello web-stream binary over HTTP/2!"); stream->SendMetadata("Hello web-stream metadata over HTTP/2!"); diff --git a/wish/cpp/examples/h2_tls_hello_client.cc b/wish/cpp/examples/h2_tls_hello_client.cc index 13602a6..f2692b5 100644 --- a/wish/cpp/examples/h2_tls_hello_client.cc +++ b/wish/cpp/examples/h2_tls_hello_client.cc @@ -72,6 +72,12 @@ int main(int argc, char** argv) { return 1; } + client.SetOnError([&client]() { + LOG(ERROR) << "Client error or handshake failed"; + + client.Stop(); + }); + client.SetOnOpen([&client](WebStream* stream) { LOG(INFO) << "OnOpen"; @@ -101,6 +107,12 @@ int main(int argc, char** argv) { client.Stop(); }); + stream->SetOnError([&client]() { + LOG(ERROR) << "Stream error"; + + client.Stop(); + }); + stream->SendText("Hello web-stream text over HTTP/2+TLS!"); stream->SendBinary("Hello web-stream binary over HTTP/2+TLS!"); stream->SendMetadata("Hello web-stream metadata over HTTP/2+TLS!"); diff --git a/wish/cpp/examples/plain_hello_client.cc b/wish/cpp/examples/plain_hello_client.cc index b005b9e..00d0ad7 100644 --- a/wish/cpp/examples/plain_hello_client.cc +++ b/wish/cpp/examples/plain_hello_client.cc @@ -50,6 +50,12 @@ int main(int argc, char** argv) { return 1; } + client.SetOnError([&client]() { + LOG(ERROR) << "Client error or handshake failed"; + + client.Stop(); + }); + client.SetOnOpen([&client](WebStream* stream) { LOG(INFO) << "OnOpen"; @@ -79,6 +85,12 @@ int main(int argc, char** argv) { client.Stop(); }); + stream->SetOnError([&client]() { + LOG(ERROR) << "Stream error"; + + client.Stop(); + }); + stream->SendText("Hello web-stream text!"); stream->SendBinary("Hello web-stream binary!"); stream->SendMetadata("Hello web-stream metadata!"); diff --git a/wish/cpp/examples/tls_hello_client.cc b/wish/cpp/examples/tls_hello_client.cc index 2f33dca..ac08753 100644 --- a/wish/cpp/examples/tls_hello_client.cc +++ b/wish/cpp/examples/tls_hello_client.cc @@ -12,105 +12,117 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include -#include - -#include - -#include "../src/tls_client.h" -#include "../src/wish_opcodes.h" - -ABSL_FLAG(std::string, host, "127.0.0.1", "Server host address"); -ABSL_FLAG(int, port, 8080, "Server port"); - -ABSL_FLAG(std::string, - ca_cert, - "certs/ca.crt", - "Path to CA certificate file"); -ABSL_FLAG(std::string, - client_cert, - "certs/client.crt", - "Path to client certificate file"); -ABSL_FLAG(std::string, - client_key, - "certs/client.key", - "Path to client private key file"); - -int main(int argc, char** argv) { - absl::ParseCommandLine(argc, argv); - absl::InitializeLog(); - - const std::string host = absl::GetFlag(FLAGS_host); - const int port = absl::GetFlag(FLAGS_port); - - const std::string ca_cert = absl::GetFlag(FLAGS_ca_cert); - const std::string client_cert = absl::GetFlag(FLAGS_client_cert); - const std::string client_key = absl::GetFlag(FLAGS_client_key); - - event_base* base = event_base_new(); - if (!base) { - LOG(ERROR) << "Failed to create event_base"; - - return 1; - } - - { - TlsClient client(base, - host, - port, - ca_cert, - client_cert, - client_key); - - if (!client.Init()) { - LOG(INFO) << "Init() failed"; - - event_base_free(base); - - return 1; - } - - client.SetOnOpen([&client](WebStream* stream) { - LOG(INFO) << "OnOpen"; - - stream->SetOnMessage([](uint8_t opcode, const std::string& msg) { - std::string type; - switch (opcode) { - case WEB_STREAM_OPCODE_TEXT: - type = "TEXT"; - break; - case WEB_STREAM_OPCODE_BINARY: - type = "BINARY"; - break; - case WEB_STREAM_OPCODE_METADATA: - type = "METADATA"; - break; - default: - type = "UNKNOWN(" + std::to_string(opcode) + ")"; - break; - } - - LOG(INFO) << "OnMessage (opcode: " << type << ", message: " << msg << ")"; - }); - - stream->SetOnClose([&client]() { - LOG(INFO) << "OnClose"; - - client.Stop(); - }); - - stream->SendText("Hello web-stream text!"); - stream->SendBinary("Hello web-stream binary!"); - stream->SendMetadata("Hello web-stream metadata!"); - stream->Close(); - }); - - client.Run(); - } - - event_base_free(base); - - return 0; -} +#include +#include +#include +#include + +#include + +#include "../src/tls_client.h" +#include "../src/wish_opcodes.h" + +ABSL_FLAG(std::string, host, "127.0.0.1", "Server host address"); +ABSL_FLAG(int, port, 8080, "Server port"); + +ABSL_FLAG(std::string, + ca_cert, + "certs/ca.crt", + "Path to CA certificate file"); +ABSL_FLAG(std::string, + client_cert, + "certs/client.crt", + "Path to client certificate file"); +ABSL_FLAG(std::string, + client_key, + "certs/client.key", + "Path to client private key file"); + +int main(int argc, char** argv) { + absl::ParseCommandLine(argc, argv); + absl::InitializeLog(); + + const std::string host = absl::GetFlag(FLAGS_host); + const int port = absl::GetFlag(FLAGS_port); + + const std::string ca_cert = absl::GetFlag(FLAGS_ca_cert); + const std::string client_cert = absl::GetFlag(FLAGS_client_cert); + const std::string client_key = absl::GetFlag(FLAGS_client_key); + + event_base* base = event_base_new(); + if (!base) { + LOG(ERROR) << "Failed to create event_base"; + + return 1; + } + + { + TlsClient client(base, + host, + port, + ca_cert, + client_cert, + client_key); + + if (!client.Init()) { + LOG(INFO) << "Init() failed"; + + event_base_free(base); + + return 1; + } + + client.SetOnError([&client]() { + LOG(ERROR) << "Client error or handshake failed"; + + client.Stop(); + }); + + client.SetOnOpen([&client](WebStream* stream) { + LOG(INFO) << "OnOpen"; + + stream->SetOnMessage([](uint8_t opcode, const std::string& msg) { + std::string type; + switch (opcode) { + case WEB_STREAM_OPCODE_TEXT: + type = "TEXT"; + break; + case WEB_STREAM_OPCODE_BINARY: + type = "BINARY"; + break; + case WEB_STREAM_OPCODE_METADATA: + type = "METADATA"; + break; + default: + type = "UNKNOWN(" + std::to_string(opcode) + ")"; + break; + } + + LOG(INFO) << "OnMessage (opcode: " << type << ", message: " << msg << ")"; + }); + + stream->SetOnClose([&client]() { + LOG(INFO) << "OnClose"; + + client.Stop(); + }); + + stream->SetOnError([&client]() { + LOG(ERROR) << "Stream error"; + + client.Stop(); + }); + + stream->SendText("Hello web-stream text!"); + stream->SendBinary("Hello web-stream binary!"); + stream->SendMetadata("Hello web-stream metadata!"); + stream->Close(); + }); + + client.Run(); + } + + event_base_free(base); + + return 0; +} diff --git a/wish/cpp/src/buffer_event_web_stream.cc b/wish/cpp/src/buffer_event_web_stream.cc index 2e8fd89..802acdb 100644 --- a/wish/cpp/src/buffer_event_web_stream.cc +++ b/wish/cpp/src/buffer_event_web_stream.cc @@ -12,497 +12,499 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "buffer_event_web_stream.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -BufferEventWebStream::BufferEventWebStream(bufferevent* bev, - bool is_server) - : bev_(bev), - is_server_(is_server), - ctx_(nullptr), - state_(OPEN) {} - -bool BufferEventWebStream::Init() { - wslay_event_callbacks callbacks = { - WslayRecvCallback, - WslaySendCallback, - WslayGenmaskCallback, - WslayOnFrameRecvStartCallback, - nullptr, // on_frame_recv_chunk_callback - nullptr, // on_frame_recv_end_callback - WslayOnMsgRecvCallback}; - - int rv; - if (is_server_) { - rv = wslay_event_context_server_init(&ctx_, - &callbacks, - this); - } else { - rv = wslay_event_context_client_init(&ctx_, - &callbacks, - this); - } - return rv == 0; -} - -BufferEventWebStream::~BufferEventWebStream() { - wslay_event_context_free(ctx_); - - if (bev_) { - bufferevent_setcb(bev_, - nullptr, - nullptr, - nullptr, - nullptr); - bufferevent_free(bev_); - } -} - -void BufferEventWebStream::SetCleanupCallback(CleanupCallback cb) { - cleanup_cb_ = std::move(cb); -} - -void BufferEventWebStream::Start() { - bufferevent_setcb(bev_, - ReadCallback, - nullptr, - EventCallback, - this); - - int enable_rv = bufferevent_enable(bev_, EV_READ | EV_WRITE); - if (enable_rv != 0) { - VLOG(1) << "bufferevent_enable() failed"; - } - - // If there is already data in the input buffer, process it immediately. - size_t input_len = evbuffer_get_length(bufferevent_get_input(bev_)); - if (input_len > 0) { - ReadCallback(bev_, this); - } -} - -void BufferEventWebStream::SetOnMessage(MessageCallback cb) { on_message_ = cb; } - -void BufferEventWebStream::SetOnClose(CloseCallback cb) { on_close_ = cb; } - -void BufferEventWebStream::SetOnError(ErrorCallback cb) { on_error_ = cb; } - -// ---- Public send methods ---- - -int BufferEventWebStream::SendText(const std::string& msg) { - return SendMessage(WEB_STREAM_OPCODE_TEXT, msg); -} - -int BufferEventWebStream::SendBinary(const std::string& msg) { - return SendMessage(WEB_STREAM_OPCODE_BINARY, msg); -} - -int BufferEventWebStream::SendMetadata(const std::string& msg) { - return SendMessage(WEB_STREAM_OPCODE_METADATA, msg); -} - -int BufferEventWebStream::Close() { - if (close_pending_) { - return -1; - } - close_pending_ = true; - - // Write the terminal zero-length chunk that signals end-of-body to the peer. - static constexpr char kTerminalChunk[] = "0\r\n\r\n"; - int rv = bufferevent_write(bev_, kTerminalChunk, sizeof(kTerminalChunk) - 1); - if (rv != 0) { - VLOG(3) << "bufferevent_write() failed"; - - return -1; - } - - return 0; -} - -// ---- libevent callbacks ---- - -void BufferEventWebStream::ReadCallback(bufferevent* bev, void* ctx) { - BufferEventWebStream* stream = static_cast(ctx); - - for (;;) { - switch (stream->state_) { - case OPEN: { - int rv = wslay_event_recv(stream->ctx_); - if (rv != 0) { - VLOG(2) << "wslay_event_recv() failed: " << rv; - return; - } - - // The inbound terminal chunk was fully consumed by ReadChunkedBytes(). - if (stream->receive_closed_) { - // Check for any extra data received after the terminal chunk - evbuffer* input = bufferevent_get_input(stream->bev_); - size_t extra_len = evbuffer_get_length(input); - if (extra_len > 0) { - VLOG(2) << "Warning: received " << extra_len << " bytes of extra data after stream close."; - - if (evbuffer_drain(input, extra_len) != 0) { - VLOG(2) << "evbuffer_drain failed"; - } - } - - // Keep the read callback as ReadCallback; receive direction is done. - // Do NOT set state_ = CLOSED yet: on_close_() may still call Close() - // to queue the outbound terminal chunk, and we need the output buffer - // to drain before freeing the stream. - bufferevent_setcb(stream->bev_, - ReadCallback, - nullptr, - EventCallback, - stream); - - if (stream->in_message_) { - if (stream->on_error_) { - stream->on_error_(); - } - } else { - if (stream->on_close_) { - stream->on_close_(); - } - } - - if (stream->close_pending_) { - // Close() was called in the callback; the outbound terminal chunk - // (and any pending echo frames) are queued in the output buffer. - // Switch to DRAINING and delete only after the buffer empties. - stream->state_ = DRAINING; - bufferevent_setcb(stream->bev_, - ReadCallback, - DrainCallback, - EventCallback, - stream); - } else { - stream->state_ = CLOSED; - auto cleanup = std::move(stream->cleanup_cb_); - if (cleanup) { - cleanup(stream); - } - } - return; - } - - return; - } - case DRAINING: { - evbuffer* input = bufferevent_get_input(stream->bev_); - size_t len = evbuffer_get_length(input); - if (len > 0) { - VLOG(2) << "Warning: received " << len << " bytes of extra data after stream close."; - - if (evbuffer_drain(input, len) != 0) { - VLOG(2) << "evbuffer_drain failed"; - } - } - return; - } - case CLOSED: - return; - default: - ABSL_UNREACHABLE(); - } - } -} - -void BufferEventWebStream::DrainCallback(bufferevent* bev, - void* ctx) { - BufferEventWebStream* stream = static_cast(ctx); - - // Delete the stream only once all queued outbound data has been sent. - size_t output_len = evbuffer_get_length(bufferevent_get_output(bev)); - if (output_len == 0) { - stream->state_ = CLOSED; - auto cleanup = std::move(stream->cleanup_cb_); - if (cleanup) { - cleanup(stream); - } - } -} - -void BufferEventWebStream::EventCallback(bufferevent* bev, - short what, // NOLINT(runtime/int) - void* ctx) { - if (what & BEV_EVENT_ERROR) { - int err = EVUTIL_SOCKET_ERROR(); - if (err != 0) { - VLOG(2) << "Error on socket: " << evutil_socket_error_to_string(err); - } else { - VLOG(2) << "Error on bufferevent"; - } - } - - if (what & (BEV_EVENT_EOF | BEV_EVENT_ERROR)) { - BufferEventWebStream* stream = static_cast(ctx); - - // If the stream is still in OPEN state, the underlying connection closed - // before we received the clean chunked terminal chunk ("0\r\n\r\n"). - // This is always treated as a premature termination/error. - if (stream->state_ == OPEN) { - if (stream->on_error_) { - stream->on_error_(); - } - } - auto cleanup = std::move(stream->cleanup_cb_); - if (cleanup) { - cleanup(stream); - } - } -} - -// ---- wslay callbacks ---- - -ssize_t BufferEventWebStream::WslayRecvCallback(wslay_event_context* /*ctx*/, - uint8_t* buf, - size_t len, - int /*flags*/, - void* user_data) { - BufferEventWebStream* stream = static_cast(user_data); - - return stream->ReadChunkedBytes(buf, len); -} - -ssize_t BufferEventWebStream::WslaySendCallback(wslay_event_context* ctx, - const uint8_t* data, - size_t len, - int /*flags*/, - void* user_data) { - BufferEventWebStream* stream = static_cast(user_data); - - // Wrap the wslay frame bytes in a single HTTP/1.1 chunk: - // \r\n\r\n - char header[32]; - int header_len = snprintf(header, sizeof(header), "%zx\r\n", len); - if (header_len <= 0) { - wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); - - return -1; - } - - int write_header_rv = bufferevent_write(stream->bev_, - header, - static_cast(header_len)); - if (write_header_rv != 0) { - VLOG(3) << "bufferevent_write() failed"; - - wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); - - return -1; - } - - int write_data_rv = bufferevent_write(stream->bev_, - data, - len); - if (write_data_rv != 0) { - VLOG(3) << "bufferevent_write() failed"; - - wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); - - return -1; - } - - int write_trailer_rv = bufferevent_write(stream->bev_, - "\r\n", - 2); - if (write_trailer_rv != 0) { - VLOG(3) << "bufferevent_write() failed"; - - wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); - - return -1; - } - - return static_cast(len); -} - -int BufferEventWebStream::WslayGenmaskCallback(wslay_event_context* ctx, uint8_t* buf, - size_t len, void* user_data) { - ABSL_UNREACHABLE(); - return 0; -} - -void BufferEventWebStream::WslayOnMsgRecvCallback(wslay_event_context* ctx, - const wslay_event_on_msg_recv_arg* arg, - void* user_data) { - BufferEventWebStream* stream = static_cast(user_data); - - if (!wslay_is_ctrl_frame(arg->opcode)) { - stream->in_message_ = false; - } - - // Consider implementing backpressure. - - if (stream->on_message_) { - std::string msg(reinterpret_cast(arg->msg), arg->msg_length); - stream->on_message_(arg->opcode, msg); - } -} - -void BufferEventWebStream::WslayOnFrameRecvStartCallback(wslay_event_context* ctx, - const wslay_event_on_frame_recv_start_arg* arg, - void* user_data) { - BufferEventWebStream* stream = static_cast(user_data); - - if (!wslay_is_ctrl_frame(arg->opcode)) { - stream->in_message_ = true; - } -} - -// ---- Private helpers ---- - -// Decode Transfer-Encoding: chunked bytes from the inbound bufferevent into -// buf[0..len). Called by WslayRecvCallback on every wslay recv request. -// -// State machine (persists across calls via chunk_state_ / chunk_remaining_): -// -// HEADER — parse "\r\n"; if size == 0 set receive_closed_ then TRAILER -// DATA — copy up to chunk_remaining_ payload bytes into buf -// TRAILER — consume the "\r\n" that follows each chunk body -// -// Returns the number of bytes placed in buf, or -1 (with wslay error set) when -// no bytes are available yet or a protocol error is detected. -ssize_t BufferEventWebStream::ReadChunkedBytes(uint8_t* buf, size_t len) { - evbuffer* input = bufferevent_get_input(bev_); - - for (;;) { - switch (chunk_state_) { - case ChunkState::HEADER: { - // Find the "\r\n" that terminates the chunk-size line. - evbuffer_ptr pos = evbuffer_search(input, "\r\n", 2, nullptr); - if (pos.pos < 0) { - wslay_event_set_error(ctx_, WSLAY_ERR_WOULDBLOCK); - - return -1; // Wait for more data. - } - if (pos.pos == 0) { - VLOG(3) << "ReadChunkedBytes: empty chunk-size line"; - - wslay_event_set_error(ctx_, WSLAY_ERR_CALLBACK_FAILURE); - - return -1; - } - - // Read the entire "\r\n" line into a temporary buffer. - size_t line_len = static_cast(pos.pos) + 2; // include \r\n - std::vector line(line_len + 1, '\0'); - - int remove_rv = evbuffer_remove(input, line.data(), line_len); - if (remove_rv < 0 || static_cast(remove_rv) != line_len) { - VLOG(3) << "ReadChunkedBytes: evbuffer_remove() failed"; - - wslay_event_set_error(ctx_, WSLAY_ERR_CALLBACK_FAILURE); - - return -1; - } - - // Null-terminate at the \r so strtoul sees only hex digits. - line[pos.pos] = '\0'; - - char* end = nullptr; - unsigned long chunk_size = std::strtoul(line.data(), &end, 16); - if (end == line.data()) { - VLOG(3) << "ReadChunkedBytes: malformed chunk size"; - - wslay_event_set_error(ctx_, WSLAY_ERR_CALLBACK_FAILURE); - - return -1; - } - - chunk_remaining_ = static_cast(chunk_size); - - if (chunk_remaining_ == 0) { - // Terminal chunk: signal receive-close after the trailing \r\n. - terminal_chunk_seen_ = true; - - chunk_state_ = ChunkState::TRAILER; - continue; // Consume the TRAILER in the same call. - } - - chunk_state_ = ChunkState::DATA; - continue; - } - - case ChunkState::DATA: { - size_t avail = evbuffer_get_length(input); - if (avail == 0) { - wslay_event_set_error(ctx_, WSLAY_ERR_WOULDBLOCK); - - return -1; - } - - size_t n = std::min({len, chunk_remaining_, avail}); - int rv = evbuffer_remove(input, buf, n); - if (rv < 0) { - VLOG(3) << "ReadChunkedBytes: evbuffer_remove() failed"; - - wslay_event_set_error(ctx_, WSLAY_ERR_CALLBACK_FAILURE); - - return -1; - } - chunk_remaining_ -= n; - if (chunk_remaining_ == 0) { - chunk_state_ = ChunkState::TRAILER; - } - return static_cast(n); - } - - case ChunkState::TRAILER: { - size_t input_len = evbuffer_get_length(input); - if (input_len < 2) { - wslay_event_set_error(ctx_, WSLAY_ERR_WOULDBLOCK); - - return -1; - } - - if (evbuffer_drain(input, 2) != 0) { - VLOG(2) << "evbuffer_drain failed"; - } - chunk_state_ = ChunkState::HEADER; - - if (terminal_chunk_seen_) { - receive_closed_ = true; - // Terminal chunk fully consumed; let ReadCallback close the stream. - wslay_event_set_error(ctx_, WSLAY_ERR_WOULDBLOCK); - - return -1; - } - - continue; // Parse the next chunk header. - } - } - } -} - -int BufferEventWebStream::SendMessage(uint8_t opcode, const std::string& msg) { - if (state_ != OPEN || close_pending_) { - return -1; - } - - wslay_event_msg msg_frame = { - opcode, - reinterpret_cast(msg.c_str()), - msg.length()}; - // Queue msg - int rv = wslay_event_queue_msg(ctx_, &msg_frame); - if (rv != 0) { - return rv; - } - - // Force send - return wslay_event_send(ctx_); -} +#include "buffer_event_web_stream.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +BufferEventWebStream::BufferEventWebStream(bufferevent* bev, + bool is_server) + : bev_(bev), + is_server_(is_server), + ctx_(nullptr), + state_(OPEN) {} + +bool BufferEventWebStream::Init() { + wslay_event_callbacks callbacks = { + WslayRecvCallback, + WslaySendCallback, + WslayGenmaskCallback, + WslayOnFrameRecvStartCallback, + nullptr, // on_frame_recv_chunk_callback + nullptr, // on_frame_recv_end_callback + WslayOnMsgRecvCallback}; + + int rv; + if (is_server_) { + rv = wslay_event_context_server_init(&ctx_, + &callbacks, + this); + } else { + rv = wslay_event_context_client_init(&ctx_, + &callbacks, + this); + } + return rv == 0; +} + +BufferEventWebStream::~BufferEventWebStream() { + wslay_event_context_free(ctx_); + + if (bev_) { + bufferevent_setcb(bev_, + nullptr, + nullptr, + nullptr, + nullptr); + bufferevent_free(bev_); + } +} + +void BufferEventWebStream::SetCleanupCallback(CleanupCallback cb) { + cleanup_cb_ = std::move(cb); +} + +void BufferEventWebStream::Start() { + bufferevent_setcb(bev_, + ReadCallback, + nullptr, + EventCallback, + this); + + int enable_rv = bufferevent_enable(bev_, EV_READ | EV_WRITE); + if (enable_rv != 0) { + VLOG(1) << "bufferevent_enable() failed"; + } + + // If there is already data in the input buffer, process it immediately. + size_t input_len = evbuffer_get_length(bufferevent_get_input(bev_)); + if (input_len > 0) { + ReadCallback(bev_, this); + } +} + +void BufferEventWebStream::SetOnMessage(MessageCallback cb) { on_message_ = cb; } + +void BufferEventWebStream::SetOnClose(CloseCallback cb) { on_close_ = cb; } + +void BufferEventWebStream::SetOnError(ErrorCallback cb) { on_error_ = cb; } + +// ---- Public send methods ---- + +int BufferEventWebStream::SendText(const std::string& msg) { + return SendMessage(WEB_STREAM_OPCODE_TEXT, msg); +} + +int BufferEventWebStream::SendBinary(const std::string& msg) { + return SendMessage(WEB_STREAM_OPCODE_BINARY, msg); +} + +int BufferEventWebStream::SendMetadata(const std::string& msg) { + return SendMessage(WEB_STREAM_OPCODE_METADATA, msg); +} + +int BufferEventWebStream::Close() { + if (close_pending_) { + return -1; + } + close_pending_ = true; + + // Write the terminal zero-length chunk that signals end-of-body to the peer. + static constexpr char kTerminalChunk[] = "0\r\n\r\n"; + int rv = bufferevent_write(bev_, kTerminalChunk, sizeof(kTerminalChunk) - 1); + if (rv != 0) { + VLOG(3) << "bufferevent_write() failed"; + + return -1; + } + + return 0; +} + +// ---- libevent callbacks ---- + +void BufferEventWebStream::ReadCallback(bufferevent* bev, void* ctx) { + BufferEventWebStream* stream = static_cast(ctx); + + for (;;) { + switch (stream->state_) { + case OPEN: { + int rv = wslay_event_recv(stream->ctx_); + if (rv != 0) { + VLOG(2) << "wslay_event_recv() failed: " << rv; + return; + } + + // The inbound terminal chunk was fully consumed by ReadChunkedBytes(). + if (stream->receive_closed_) { + // Check for any extra data received after the terminal chunk + evbuffer* input = bufferevent_get_input(stream->bev_); + size_t extra_len = evbuffer_get_length(input); + if (extra_len > 0) { + VLOG(2) << "Warning: received " << extra_len << " bytes of extra data after stream close."; + + if (evbuffer_drain(input, extra_len) != 0) { + VLOG(2) << "evbuffer_drain failed"; + } + } + + // Keep the read callback as ReadCallback; receive direction is done. + // Do NOT set state_ = CLOSED yet: on_close_() may still call Close() + // to queue the outbound terminal chunk, and we need the output buffer + // to drain before freeing the stream. + bufferevent_setcb(stream->bev_, + ReadCallback, + nullptr, + EventCallback, + stream); + + if (stream->in_message_) { + if (stream->on_error_) { + stream->on_error_(); + } + } else { + if (stream->on_close_) { + stream->on_close_(); + } + } + + if (stream->close_pending_) { + stream->state_ = DRAINING; + bufferevent_setcb(stream->bev_, + ReadCallback, + DrainCallback, + EventCallback, + stream); + + stream->TryDrain(); + } else { + stream->state_ = CLOSED; + auto cleanup = std::move(stream->cleanup_cb_); + if (cleanup) { + cleanup(stream); + } + } + return; + } + + return; + } + case DRAINING: { + evbuffer* input = bufferevent_get_input(stream->bev_); + size_t len = evbuffer_get_length(input); + if (len > 0) { + VLOG(2) << "Warning: received " << len << " bytes of extra data after stream close."; + + if (evbuffer_drain(input, len) != 0) { + VLOG(2) << "evbuffer_drain failed"; + } + } + return; + } + case CLOSED: + return; + default: + ABSL_UNREACHABLE(); + } + } +} + +void BufferEventWebStream::DrainCallback(bufferevent* /*bev*/, + void* ctx) { + BufferEventWebStream* stream = static_cast(ctx); + + stream->TryDrain(); +} + +void BufferEventWebStream::EventCallback(bufferevent* bev, + short what, // NOLINT(runtime/int) + void* ctx) { + if (what & BEV_EVENT_ERROR) { + int err = EVUTIL_SOCKET_ERROR(); + if (err != 0) { + VLOG(2) << "Error on socket: " << evutil_socket_error_to_string(err); + } else { + VLOG(2) << "Error on bufferevent"; + } + } + + if (what & (BEV_EVENT_EOF | BEV_EVENT_ERROR)) { + BufferEventWebStream* stream = static_cast(ctx); + + // If the stream is still in OPEN state, the underlying connection closed + // before we received the clean chunked terminal chunk ("0\r\n\r\n"). + // This is always treated as a premature termination/error. + if (stream->state_ == OPEN) { + if (stream->on_error_) { + stream->on_error_(); + } + } + auto cleanup = std::move(stream->cleanup_cb_); + if (cleanup) { + cleanup(stream); + } + } +} + +// ---- wslay callbacks ---- + +ssize_t BufferEventWebStream::WslayRecvCallback(wslay_event_context* /*ctx*/, + uint8_t* buf, + size_t len, + int /*flags*/, + void* user_data) { + BufferEventWebStream* stream = static_cast(user_data); + + return stream->ReadChunkedBytes(buf, len); +} + +ssize_t BufferEventWebStream::WslaySendCallback(wslay_event_context* ctx, + const uint8_t* data, + size_t len, + int /*flags*/, + void* user_data) { + BufferEventWebStream* stream = static_cast(user_data); + + // Wrap the wslay frame bytes in a single HTTP/1.1 chunk: + // \r\n\r\n + char header[32]; + int header_len = snprintf(header, sizeof(header), "%zx\r\n", len); + if (header_len <= 0) { + wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); + + return -1; + } + + int write_header_rv = bufferevent_write(stream->bev_, + header, + static_cast(header_len)); + if (write_header_rv != 0) { + VLOG(3) << "bufferevent_write() failed"; + + wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); + + return -1; + } + + int write_data_rv = bufferevent_write(stream->bev_, + data, + len); + if (write_data_rv != 0) { + VLOG(3) << "bufferevent_write() failed"; + + wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); + + return -1; + } + + int write_trailer_rv = bufferevent_write(stream->bev_, + "\r\n", + 2); + if (write_trailer_rv != 0) { + VLOG(3) << "bufferevent_write() failed"; + + wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE); + + return -1; + } + + return static_cast(len); +} + +int BufferEventWebStream::WslayGenmaskCallback(wslay_event_context* ctx, uint8_t* buf, + size_t len, void* user_data) { + ABSL_UNREACHABLE(); + return 0; +} + +void BufferEventWebStream::WslayOnMsgRecvCallback(wslay_event_context* ctx, + const wslay_event_on_msg_recv_arg* arg, + void* user_data) { + BufferEventWebStream* stream = static_cast(user_data); + + if (!wslay_is_ctrl_frame(arg->opcode)) { + stream->in_message_ = false; + } + + // Consider implementing backpressure. + + if (stream->on_message_) { + std::string msg(reinterpret_cast(arg->msg), arg->msg_length); + stream->on_message_(arg->opcode, msg); + } +} + +void BufferEventWebStream::WslayOnFrameRecvStartCallback(wslay_event_context* ctx, + const wslay_event_on_frame_recv_start_arg* arg, + void* user_data) { + BufferEventWebStream* stream = static_cast(user_data); + + if (!wslay_is_ctrl_frame(arg->opcode)) { + stream->in_message_ = true; + } +} + +// ---- Private helpers ---- + +// Decode Transfer-Encoding: chunked bytes from the inbound bufferevent into +// buf[0..len). Called by WslayRecvCallback on every wslay recv request. +// +// State machine (persists across calls via chunk_state_ / chunk_remaining_): +// +// HEADER — parse "\r\n"; if size == 0 set receive_closed_ then TRAILER +// DATA — copy up to chunk_remaining_ payload bytes into buf +// TRAILER — consume the "\r\n" that follows each chunk body +// +// Returns the number of bytes placed in buf, or -1 (with wslay error set) when +// no bytes are available yet or a protocol error is detected. +ssize_t BufferEventWebStream::ReadChunkedBytes(uint8_t* buf, size_t len) { + evbuffer* input = bufferevent_get_input(bev_); + + for (;;) { + switch (chunk_state_) { + case ChunkState::HEADER: { + // Find the "\r\n" that terminates the chunk-size line. + evbuffer_ptr pos = evbuffer_search(input, "\r\n", 2, nullptr); + if (pos.pos < 0) { + wslay_event_set_error(ctx_, WSLAY_ERR_WOULDBLOCK); + + return -1; // Wait for more data. + } + if (pos.pos == 0) { + VLOG(3) << "ReadChunkedBytes: empty chunk-size line"; + + wslay_event_set_error(ctx_, WSLAY_ERR_CALLBACK_FAILURE); + + return -1; + } + + // Read the entire "\r\n" line into a temporary buffer. + size_t line_len = static_cast(pos.pos) + 2; // include \r\n + std::vector line(line_len + 1, '\0'); + + int remove_rv = evbuffer_remove(input, line.data(), line_len); + if (remove_rv < 0 || static_cast(remove_rv) != line_len) { + VLOG(3) << "ReadChunkedBytes: evbuffer_remove() failed"; + + wslay_event_set_error(ctx_, WSLAY_ERR_CALLBACK_FAILURE); + + return -1; + } + + // Null-terminate at the \r so strtoul sees only hex digits. + line[pos.pos] = '\0'; + + char* end = nullptr; + unsigned long chunk_size = std::strtoul(line.data(), &end, 16); + if (end == line.data()) { + VLOG(3) << "ReadChunkedBytes: malformed chunk size"; + + wslay_event_set_error(ctx_, WSLAY_ERR_CALLBACK_FAILURE); + + return -1; + } + + chunk_remaining_ = static_cast(chunk_size); + + if (chunk_remaining_ == 0) { + // Terminal chunk: signal receive-close after the trailing \r\n. + terminal_chunk_seen_ = true; + + chunk_state_ = ChunkState::TRAILER; + continue; // Consume the TRAILER in the same call. + } + + chunk_state_ = ChunkState::DATA; + continue; + } + + case ChunkState::DATA: { + size_t avail = evbuffer_get_length(input); + if (avail == 0) { + wslay_event_set_error(ctx_, WSLAY_ERR_WOULDBLOCK); + + return -1; + } + + size_t n = std::min({len, chunk_remaining_, avail}); + int rv = evbuffer_remove(input, buf, n); + if (rv < 0) { + VLOG(3) << "ReadChunkedBytes: evbuffer_remove() failed"; + + wslay_event_set_error(ctx_, WSLAY_ERR_CALLBACK_FAILURE); + + return -1; + } + chunk_remaining_ -= n; + if (chunk_remaining_ == 0) { + chunk_state_ = ChunkState::TRAILER; + } + return static_cast(n); + } + + case ChunkState::TRAILER: { + size_t input_len = evbuffer_get_length(input); + if (input_len < 2) { + wslay_event_set_error(ctx_, WSLAY_ERR_WOULDBLOCK); + + return -1; + } + + if (evbuffer_drain(input, 2) != 0) { + VLOG(2) << "evbuffer_drain failed"; + } + chunk_state_ = ChunkState::HEADER; + + if (terminal_chunk_seen_) { + receive_closed_ = true; + // Terminal chunk fully consumed; let ReadCallback close the stream. + wslay_event_set_error(ctx_, WSLAY_ERR_WOULDBLOCK); + + return -1; + } + + continue; // Parse the next chunk header. + } + } + } +} + +int BufferEventWebStream::SendMessage(uint8_t opcode, const std::string& msg) { + if (state_ != OPEN || close_pending_) { + return -1; + } + + wslay_event_msg msg_frame = { + opcode, + reinterpret_cast(msg.c_str()), + msg.length()}; + // Queue msg + int rv = wslay_event_queue_msg(ctx_, &msg_frame); + if (rv != 0) { + return rv; + } + + // Force send + return wslay_event_send(ctx_); +} + +void BufferEventWebStream::TryDrain() { + size_t output_len = evbuffer_get_length(bufferevent_get_output(bev_)); + if (output_len == 0) { + state_ = CLOSED; + auto cleanup = std::move(cleanup_cb_); + if (cleanup) { + cleanup(this); + } + } +} diff --git a/wish/cpp/src/buffer_event_web_stream.h b/wish/cpp/src/buffer_event_web_stream.h index 14d17c9..325fbd0 100644 --- a/wish/cpp/src/buffer_event_web_stream.h +++ b/wish/cpp/src/buffer_event_web_stream.h @@ -141,7 +141,8 @@ class BufferEventWebStream : public WebStream { const wslay_event_on_frame_recv_start_arg* arg, void* user_data); - int SendMessage(uint8_t opcode, const std::string& msg); + int SendMessage(uint8_t opcode, const std::string& msg); + void TryDrain(); // Decode one batch of Transfer-Encoding: chunked bytes from the inbound // bufferevent into buf[0..len). Mirrors the wslay recv-callback signature: diff --git a/wish/cpp/src/h2_client.cc b/wish/cpp/src/h2_client.cc index 22e8549..1cd3959 100644 --- a/wish/cpp/src/h2_client.cc +++ b/wish/cpp/src/h2_client.cc @@ -117,6 +117,8 @@ bool H2Client::Init() { void H2Client::SetOnOpen(OpenCallback cb) { on_open_ = cb; } +void H2Client::SetOnError(ErrorCallback cb) { on_error_ = cb; } + int H2Client::Run() { return event_base_dispatch(base_); } @@ -225,8 +227,8 @@ void H2Client::EventCallback(bufferevent* bev, sess->h2session = nullptr; } - if (sess->client->Stop() != 0) { - VLOG(2) << "H2Client::Stop failed"; + if (sess->client->on_error_) { + sess->client->on_error_(); } bufferevent_free(bev); @@ -255,8 +257,8 @@ void H2Client::HandleSessionError(Session* sess) { sess->bev = nullptr; } - if (Stop() != 0) { - VLOG(2) << "H2Client::Stop failed"; + if (on_error_) { + on_error_(); } if (session_ == sess) { diff --git a/wish/cpp/src/h2_client.h b/wish/cpp/src/h2_client.h index 163ff67..d3b64e7 100644 --- a/wish/cpp/src/h2_client.h +++ b/wish/cpp/src/h2_client.h @@ -39,6 +39,7 @@ class H2Client { // Called with the live WebStream once the server responds with 200. using OpenCallback = std::function; using CloseCallback = std::function; + using ErrorCallback = std::function; H2Client(event_base* base, const std::string& host, @@ -48,6 +49,7 @@ class H2Client { bool Init(); void SetOnOpen(OpenCallback cb); + void SetOnError(ErrorCallback cb); int Run(); int Stop(); @@ -127,6 +129,7 @@ class H2Client { Session* session_; OpenCallback on_open_; + ErrorCallback on_error_; }; #endif // WISH_CPP_SRC_H2_CLIENT_H_ diff --git a/wish/cpp/src/h2_tls_client.cc b/wish/cpp/src/h2_tls_client.cc index ef8f30f..2faa975 100644 --- a/wish/cpp/src/h2_tls_client.cc +++ b/wish/cpp/src/h2_tls_client.cc @@ -153,6 +153,8 @@ bool H2TlsClient::Init() { void H2TlsClient::SetOnOpen(OpenCallback cb) { on_open_ = cb; } +void H2TlsClient::SetOnError(ErrorCallback cb) { on_error_ = cb; } + int H2TlsClient::Run() { return event_base_dispatch(base_); } @@ -261,8 +263,8 @@ void H2TlsClient::EventCallback(bufferevent* bev, sess->h2session = nullptr; } - if (sess->client->Stop() != 0) { - VLOG(2) << "H2TlsClient::Stop failed"; + if (sess->client->on_error_) { + sess->client->on_error_(); } bufferevent_free(bev); @@ -291,8 +293,8 @@ void H2TlsClient::HandleSessionError(Session* sess) { sess->bev = nullptr; } - if (Stop() != 0) { - VLOG(2) << "H2TlsClient::Stop failed"; + if (on_error_) { + on_error_(); } if (session_ == sess) { diff --git a/wish/cpp/src/h2_tls_client.h b/wish/cpp/src/h2_tls_client.h index 60e1517..832bc46 100644 --- a/wish/cpp/src/h2_tls_client.h +++ b/wish/cpp/src/h2_tls_client.h @@ -35,6 +35,7 @@ class H2TlsClient { public: using OpenCallback = std::function; using CloseCallback = std::function; + using ErrorCallback = std::function; H2TlsClient(event_base* base, const std::string& host, @@ -47,6 +48,7 @@ class H2TlsClient { bool Init(); void SetOnOpen(OpenCallback cb); + void SetOnError(ErrorCallback cb); int Run(); int Stop(); @@ -127,6 +129,7 @@ class H2TlsClient { Session* session_; OpenCallback on_open_; + ErrorCallback on_error_; }; #endif // WISH_CPP_SRC_H2_TLS_CLIENT_H_ diff --git a/wish/cpp/src/plain_client.cc b/wish/cpp/src/plain_client.cc index f914bd6..aa011c3 100644 --- a/wish/cpp/src/plain_client.cc +++ b/wish/cpp/src/plain_client.cc @@ -84,10 +84,17 @@ bool PlainClient::Init() { } stream_ = std::move(s); + stream_->SetCleanupCallback([this](BufferEventWebStream* s) { stream_.reset(); }); + stream_->SetOnError([this]() { + if (on_error_) { + on_error_(); + } + }); + if (on_open_) { on_open_(stream_.get()); } @@ -97,7 +104,12 @@ bool PlainClient::Init() { }, [this]() { VLOG(1) << "Client handshake failed"; + handshake_.reset(); + + if (on_error_) { + on_error_(); + } }); handshake_->Start(); @@ -109,6 +121,10 @@ void PlainClient::SetOnOpen(OpenCallback cb) { on_open_ = cb; } +void PlainClient::SetOnError(ErrorCallback cb) { + on_error_ = cb; +} + int PlainClient::Run() { return event_base_dispatch(base_); } diff --git a/wish/cpp/src/plain_client.h b/wish/cpp/src/plain_client.h index 9662325..da715da 100644 --- a/wish/cpp/src/plain_client.h +++ b/wish/cpp/src/plain_client.h @@ -33,6 +33,7 @@ class PlainClient { using OpenCallback = std::function; using MessageCallback = std::function; using CloseCallback = std::function; + using ErrorCallback = std::function; PlainClient(event_base* base, const std::string& host, @@ -42,6 +43,7 @@ class PlainClient { bool Init(); void SetOnOpen(OpenCallback cb); + void SetOnError(ErrorCallback cb); int Run(); int Stop(); @@ -58,6 +60,7 @@ class PlainClient { std::unique_ptr stream_; OpenCallback on_open_; + ErrorCallback on_error_; }; #endif // WISH_CPP_SRC_PLAIN_CLIENT_H_ diff --git a/wish/cpp/src/tls_client.cc b/wish/cpp/src/tls_client.cc index fd18ed2..530afd7 100644 --- a/wish/cpp/src/tls_client.cc +++ b/wish/cpp/src/tls_client.cc @@ -111,10 +111,17 @@ bool TlsClient::Init() { } stream_ = std::move(s); + stream_->SetCleanupCallback([this](BufferEventWebStream* s) { stream_.reset(); }); + stream_->SetOnError([this]() { + if (on_error_) { + on_error_(); + } + }); + if (on_open_) { on_open_(stream_.get()); } @@ -126,6 +133,10 @@ bool TlsClient::Init() { VLOG(1) << "Client handshake failed"; handshake_.reset(); + + if (on_error_) { + on_error_(); + } }); handshake_->Start(); @@ -137,6 +148,10 @@ void TlsClient::SetOnOpen(OpenCallback cb) { on_open_ = cb; } +void TlsClient::SetOnError(ErrorCallback cb) { + on_error_ = cb; +} + int TlsClient::Run() { return event_base_dispatch(base_); } diff --git a/wish/cpp/src/tls_client.h b/wish/cpp/src/tls_client.h index 82e25bf..10af921 100644 --- a/wish/cpp/src/tls_client.h +++ b/wish/cpp/src/tls_client.h @@ -34,6 +34,7 @@ class TlsClient { using OpenCallback = std::function; using MessageCallback = std::function; using CloseCallback = std::function; + using ErrorCallback = std::function; TlsClient(event_base* base, const std::string& host, @@ -46,6 +47,7 @@ class TlsClient { bool Init(); void SetOnOpen(OpenCallback cb); + void SetOnError(ErrorCallback cb); int Run(); int Stop(); @@ -68,6 +70,7 @@ class TlsClient { std::unique_ptr stream_; OpenCallback on_open_; + ErrorCallback on_error_; }; #endif // WISH_CPP_SRC_TLS_CLIENT_H_ diff --git a/wish/python/CMakeLists.txt b/wish/python/CMakeLists.txt index cffc426..047056e 100644 --- a/wish/python/CMakeLists.txt +++ b/wish/python/CMakeLists.txt @@ -1,27 +1,43 @@ cmake_minimum_required(VERSION 3.14) -project(wish_python) +project(web_stream_python) + set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) + +if (CMAKE_VERSION VERSION_LESS 3.18) + set(DEV_MODULE Development) +else() + set(DEV_MODULE Development.Module) +endif() + +find_package(Python 3.9 COMPONENTS Interpreter ${DEV_MODULE} REQUIRED) + +if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +# Detect the installed nanobind package and import it into CMake +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT) find_package(nanobind CONFIG REQUIRED) -# Force static libevent for the Python extension -set(EVENT__LIBRARY_TYPE STATIC CACHE STRING "" FORCE) # Exclude examples and tests from cpp/CMakeLists.txt -set(BUILD_TESTING OFF CACHE BOOL "" FORCE) set(WISH_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(WISH_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE) set(WISH_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) -add_subdirectory(../cpp wish_cpp_build) +add_subdirectory(../cpp web_stream_cpp_build) + -nanobind_add_module(wish_ext src/wish_ext.cc) +nanobind_add_module(web_stream_ext src/web_stream_ext.cc) -# Provide the include directory so wish_ext.cc can find tls_client.h etc. -target_include_directories(wish_ext PRIVATE ../cpp/src) -target_link_libraries(wish_ext PRIVATE wish_handler event_pthreads) +# Provide the include directory so web_stream_ext.cc can find tls_client.h etc. +target_include_directories(web_stream_ext PRIVATE ../cpp/src) +target_link_libraries(web_stream_ext PRIVATE web_stream event_pthreads) -install(TARGETS wish_ext DESTINATION wish) +install(TARGETS web_stream_ext DESTINATION web_stream) diff --git a/wish/python/README.md b/wish/python/README.md index e09f318..2e2110e 100644 --- a/wish/python/README.md +++ b/wish/python/README.md @@ -1,6 +1,6 @@ -# WiSH Python Bindings +# WebStream Python Bindings -Python bindings for the WiSH protocol (C++). +Python bindings for the web-stream protocol (C++). These bindings provide an asyncio-compatible API and utilize `nanobind` for efficient C++-to-Python interoperability. The `libevent` event loop runs in a background thread, integrated with the Python `asyncio` event loop. @@ -42,10 +42,10 @@ This sub-project is intended to be used with a Python virtual environment (`venv ```python import asyncio -from wish.client import connect +from web_stream.client import connect -async def test_wish_client(): - uri = "wishs://example.com:443" +async def test_web_stream_client(): + uri = "webstreams://example.com:443" # Establish connection async with connect( @@ -56,22 +56,21 @@ async def test_wish_client(): ) as conn: # Send a message - await conn.send("Hello over WiSH!") + await conn.send("Hello over web-stream!") # Receive a message msg = await conn.recv() print("Received:", msg) if __name__ == "__main__": - asyncio.run(test_wish_client()) + asyncio.run(test_web_stream_client()) ``` ## Design Philosophy & Performance To provide maximum throughput without blocking the Python event loop, these bindings rely on several specific architectural choices: -- **Background C++ Event Loop:** The core networking operations and TLS data streaming are handled natively by `libevent`. The `wish/client.py` module uses `loop.run_in_executor()` to spawn this C++ event loop in a background thread. +- **Background C++ Event Loop:** The core networking operations and TLS data streaming are handled natively by `libevent`. The `web_stream/client.py` module uses `loop.run_in_executor()` to spawn this C++ event loop in a background thread. - **GIL Management:** The C++ extension releases the Python GIL while running its `libevent` dispatch loop and when performing heavy network I/O. The Python interpreter only re-acquires the GIL for split seconds to process incoming messages or to initiate send operations. - **Micro-Bindings:** `nanobind` is deliberately chosen for its modern, zero-overhead approach to creating C++ bindings. It has a small binary footprint and fast execution speed when crossing the C++/Python boundary. - **Thread-safe Asyncio Bridging:** When the C++ connection layer has an event (like a successful connection or a received message), it safely bridges back to the Python thread by triggering `loop.call_soon_threadsafe()`. Data is seamlessly funneled into an `asyncio.Queue` so that the Python user processes it using standard, non-blocking asynchronous loops. - diff --git a/wish/python/pyproject.toml b/wish/python/pyproject.toml index 8b76b29..5ef6bda 100644 --- a/wish/python/pyproject.toml +++ b/wish/python/pyproject.toml @@ -3,9 +3,9 @@ requires = ["scikit-build-core", "nanobind"] build-backend = "scikit_build_core.build" [project] -name = "wish" +name = "web_stream" version = "0.1.0" -description = "Python bindings for WiSH Protocol" +description = "Python bindings for web-stream Protocol" readme = "README.md" requires-python = ">=3.8" dependencies = [] @@ -14,5 +14,7 @@ dependencies = [] cmake.version = ">=3.14" ninja.version = ">=1.5" +build-dir = "build/{wheel_tag}" + [tool.scikit-build.cmake.define] CMAKE_CXX_STANDARD = "17" diff --git a/wish/python/src/web_stream_ext.cc b/wish/python/src/web_stream_ext.cc new file mode 100644 index 0000000..17b0d65 --- /dev/null +++ b/wish/python/src/web_stream_ext.cc @@ -0,0 +1,559 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "buffer_event_web_stream.h" +#include "plain_client.h" +#include "tls_client.h" +#include "wish_opcodes.h" + +namespace nb = nanobind; + +// --------------------------------------------------------------------------- +// WebStreamHandlerRef: a shared, nullable handle to WebStream. +// --------------------------------------------------------------------------- + +struct WebStreamHandlerRef { + std::mutex mu; + WebStream* ptr = nullptr; + + int send_text(const std::string& msg) { + std::lock_guard lock(mu); + + if (!ptr) { + throw std::runtime_error("Connection is closed"); + } + return ptr->SendText(msg); + } + + int send_binary(const std::string& msg) { + std::lock_guard lock(mu); + + if (!ptr) { + throw std::runtime_error("Connection is closed"); + } + return ptr->SendBinary(msg); + } + + int close() { + std::lock_guard lock(mu); + + if (!ptr) { + return 0; + } + return ptr->Close(); + } +}; + +// Custom deleter for event_base to manage its lifecycle inside the Py wrappers +struct EventBaseDeleter { + void operator()(event_base* base) const { + if (base) { + event_base_free(base); + } + } +}; + +// --------------------------------------------------------------------------- +// Wrapper structs +// --------------------------------------------------------------------------- + +struct TlsClientPy { + std::unique_ptr base; + + TlsClient client; + + nb::object on_open_cb; + nb::object on_message_cb; + nb::object on_error_cb; + nb::object on_close_cb; + + std::shared_ptr handler_ref; + + // Tracks whether Run() is currently executing. + std::atomic running{false}; + std::mutex stopped_mu; + std::condition_variable stopped_cv; + + std::atomic finalized{false}; + + TlsClientPy(const std::string& ca, + const std::string& cert, + const std::string& key, + const std::string& host, + int port) + : base(event_base_new()), + client(base.get(), + host, + port, + ca, + cert, + key) { + client.SetOnError([this]() { + client.Stop(); + + if (handler_ref) { + std::lock_guard lock(handler_ref->mu); + + handler_ref->ptr = nullptr; + } + + if (on_error_cb) { + nb::gil_scoped_acquire acquire; + + try { + on_error_cb(); + } catch (nb::python_error& e) { + e.restore(); + PyErr_WriteUnraisable(on_error_cb.ptr()); + } + } + }); + } +}; + +struct PlainClientPy { + std::unique_ptr base; + + PlainClient client; + + nb::object on_open_cb; + nb::object on_message_cb; + nb::object on_error_cb; + nb::object on_close_cb; + + std::shared_ptr handler_ref; + + std::atomic running{false}; + std::mutex stopped_mu; + std::condition_variable stopped_cv; + + std::atomic finalized{false}; + + PlainClientPy(const std::string& host, + int port) + : base(event_base_new()), + client(base.get(), + host, + port) { + client.SetOnError([this]() { + client.Stop(); + + if (handler_ref) { + std::lock_guard lock(handler_ref->mu); + + handler_ref->ptr = nullptr; + } + + if (on_error_cb) { + nb::gil_scoped_acquire acquire; + + try { + on_error_cb(); + } catch (nb::python_error& e) { + e.restore(); + PyErr_WriteUnraisable(on_error_cb.ptr()); + } + } + }); + } +}; + +// --------------------------------------------------------------------------- +// tp_traverse / tp_clear / tp_finalize for TlsClientPy +// --------------------------------------------------------------------------- + +static void tls_do_cleanup(TlsClientPy* w) { + if (w->finalized.exchange(true, std::memory_order_acq_rel)) { + return; + } + + w->client.Stop(); + + { + PyThreadState* ts = PyEval_SaveThread(); // release GIL + std::unique_lock lk(w->stopped_mu); + bool stopped = w->stopped_cv.wait_for(lk, + std::chrono::seconds(5), + [w] { return !w->running.load(std::memory_order_acquire); }); + PyEval_RestoreThread(ts); // reacquire GIL + if (!stopped) { + PySys_WriteStderr("web_stream_ext: WARNING: event loop did not stop within timeout\n"); + } + } + + w->client.SetOnOpen({}); + + if (w->handler_ref) { + std::lock_guard lock(w->handler_ref->mu); + w->handler_ref->ptr = nullptr; + } + w->handler_ref.reset(); +} + +static void tls_finalize(PyObject* self) { + TlsClientPy* w = nb::inst_ptr(nb::handle(self)); + tls_do_cleanup(w); +} + +static int tls_traverse(PyObject* self, visitproc visit, void* arg) { + TlsClientPy* w = nb::inst_ptr(nb::handle(self)); + + Py_VISIT(w->on_open_cb.ptr()); + Py_VISIT(w->on_message_cb.ptr()); + Py_VISIT(w->on_error_cb.ptr()); + Py_VISIT(w->on_close_cb.ptr()); + + return 0; +} + +static int tls_clear(PyObject* self) { + TlsClientPy* w = nb::inst_ptr(nb::handle(self)); + + tls_do_cleanup(w); + + w->on_open_cb = nb::object(); + w->on_message_cb = nb::object(); + w->on_error_cb = nb::object(); + w->on_close_cb = nb::object(); + + return 0; +} + +// --------------------------------------------------------------------------- +// tp_traverse / tp_clear / tp_finalize for PlainClientPy +// --------------------------------------------------------------------------- + +static void plain_do_cleanup(PlainClientPy* w) { + if (w->finalized.exchange(true, std::memory_order_acq_rel)) { + return; + } + + w->client.Stop(); + + { + PyThreadState* ts = PyEval_SaveThread(); + std::unique_lock lk(w->stopped_mu); + bool stopped = w->stopped_cv.wait_for(lk, + std::chrono::seconds(5), + [w] { return !w->running.load(std::memory_order_acquire); }); + PyEval_RestoreThread(ts); + if (!stopped) { + PySys_WriteStderr("web_stream_ext: WARNING: event loop did not stop within timeout\n"); + } + } + + w->client.SetOnOpen({}); + + if (w->handler_ref) { + std::lock_guard lock(w->handler_ref->mu); + w->handler_ref->ptr = nullptr; + } + w->handler_ref.reset(); +} + +static void plain_finalize(PyObject* self) { + PlainClientPy* w = nb::inst_ptr(nb::handle(self)); + plain_do_cleanup(w); +} + +static int plain_traverse(PyObject* self, visitproc visit, void* arg) { + PlainClientPy* w = nb::inst_ptr(nb::handle(self)); + + Py_VISIT(w->on_open_cb.ptr()); + Py_VISIT(w->on_message_cb.ptr()); + Py_VISIT(w->on_error_cb.ptr()); + Py_VISIT(w->on_close_cb.ptr()); + + return 0; +} + +static int plain_clear(PyObject* self) { + PlainClientPy* w = nb::inst_ptr(nb::handle(self)); + + plain_do_cleanup(w); + + w->on_open_cb = nb::object(); + w->on_message_cb = nb::object(); + w->on_error_cb = nb::object(); + w->on_close_cb = nb::object(); + + return 0; +} + +// --------------------------------------------------------------------------- + +NB_MODULE(web_stream_ext, m) { +#ifdef _WIN32 + evthread_use_windows_threads(); +#else + evthread_use_pthreads(); +#endif + + nb::class_(m, "BufferEventWebStream") + .def("send_text", &WebStreamHandlerRef::send_text) + .def("send_binary", [](WebStreamHandlerRef& self, nb::object data) { + std::string s; + if (nb::isinstance(data)) { + nb::bytes b = nb::cast(data); + s = std::string(b.c_str(), b.size()); + } else if (nb::isinstance(data)) { + nb::str str_obj = nb::cast(data); + s = str_obj.c_str(); + } else { + throw nb::type_error("send_binary() expects bytes or str"); + } + + std::lock_guard lock(self.mu); + if (!self.ptr) { + throw std::runtime_error("Connection is closed"); + } + return self.ptr->SendBinary(s); + }) + .def("close", &WebStreamHandlerRef::close); + + // ---- TlsClient -------------------------------------------------------- + static PyType_Slot tls_slots[] = { + {Py_tp_traverse, (void*)tls_traverse}, + {Py_tp_clear, (void*)tls_clear}, + {Py_tp_finalize, (void*)tls_finalize}, + {0, nullptr}, + }; + + nb::class_(m, "TlsClient", nb::type_slots(tls_slots)) + .def(nb::init()) + .def("init", [](TlsClientPy& self) -> bool { + if (!self.client.Init()) { + throw std::runtime_error("TlsClient.init() failed"); + } + return true; + }) + .def("set_on_open", [](TlsClientPy& self, nb::object cb) { + if (self.handler_ref) { + std::lock_guard lock(self.handler_ref->mu); + + self.handler_ref->ptr = nullptr; + } + self.handler_ref.reset(); + + self.on_open_cb = cb; + + auto ref = std::make_shared(); + self.handler_ref = ref; + + self.client.SetOnOpen([&self, ref](WebStream* handler) { + { + std::lock_guard lock(ref->mu); + + ref->ptr = handler; + } + + handler->SetOnClose([ref, &self]() { + { + std::lock_guard lock(ref->mu); + + ref->ptr = nullptr; + } + + self.client.Stop(); + + if (self.on_close_cb) { + nb::gil_scoped_acquire acquire; + + try { + self.on_close_cb(); + } catch (nb::python_error& e) { + e.restore(); + PyErr_WriteUnraisable(self.on_close_cb.ptr()); + } + } + }); + + handler->SetOnMessage([&self](uint8_t opcode, const std::string& msg) { + nb::gil_scoped_acquire acquire; + try { + if (opcode == WEB_STREAM_OPCODE_BINARY || opcode == WEB_STREAM_OPCODE_METADATA) { + self.on_message_cb(opcode, nb::bytes(msg.data(), msg.size())); + } else { + self.on_message_cb(opcode, msg); + } + } catch (nb::python_error& e) { + e.restore(); + PyErr_WriteUnraisable(self.on_message_cb.ptr()); + } + }); + + nb::gil_scoped_acquire acquire; + try { + self.on_open_cb(ref); + } catch (nb::python_error& e) { + e.restore(); + PyErr_WriteUnraisable(self.on_open_cb.ptr()); + } + }); + }) + .def("set_on_message", [](TlsClientPy& self, nb::object cb) { + self.on_message_cb = cb; + }) + .def("set_on_error", [](TlsClientPy& self, nb::object cb) { + self.on_error_cb = cb; + }) + .def("set_on_close", [](TlsClientPy& self, nb::object cb) { + self.on_close_cb = cb; + }) + .def("run", [](TlsClientPy& self) { + self.running.store(true, std::memory_order_release); + struct RunGuard { + TlsClientPy& s; + ~RunGuard() noexcept { + { + std::lock_guard lk(s.stopped_mu); + s.running.store(false, std::memory_order_release); + } + s.stopped_cv.notify_all(); + } + } guard{self}; + self.client.Run(); }, nb::call_guard()) + .def("stop", [](TlsClientPy& self) { + self.client.Stop(); + if (self.handler_ref) { + std::lock_guard lock(self.handler_ref->mu); + self.handler_ref->ptr = nullptr; + } }) + .def("__enter__", [](TlsClientPy& self) -> TlsClientPy& { return self; }) + .def("__exit__", [](TlsClientPy& self, nb::object, nb::object, nb::object) { + self.client.Stop(); + if (self.handler_ref) { + std::lock_guard lock(self.handler_ref->mu); + self.handler_ref->ptr = nullptr; + } }); + + // ---- PlainClient ------------------------------------------------------ + static PyType_Slot plain_slots[] = { + {Py_tp_traverse, (void*)plain_traverse}, + {Py_tp_clear, (void*)plain_clear}, + {Py_tp_finalize, (void*)plain_finalize}, + {0, nullptr}, + }; + + nb::class_(m, "PlainClient", nb::type_slots(plain_slots)) + .def(nb::init()) + .def("init", [](PlainClientPy& self) -> bool { + if (!self.client.Init()) { + throw std::runtime_error("PlainClient.init() failed"); + } + return true; + }) + .def("set_on_open", [](PlainClientPy& self, nb::object cb) { + if (self.handler_ref) { + std::lock_guard lock(self.handler_ref->mu); + + self.handler_ref->ptr = nullptr; + } + self.handler_ref.reset(); + + self.on_open_cb = cb; + + auto ref = std::make_shared(); + self.handler_ref = ref; + + self.client.SetOnOpen([&self, ref](WebStream* handler) { + { + std::lock_guard lock(ref->mu); + + ref->ptr = handler; + } + + handler->SetOnClose([ref, &self]() { + { + std::lock_guard lock(ref->mu); + + ref->ptr = nullptr; + } + + self.client.Stop(); + + if (self.on_close_cb) { + nb::gil_scoped_acquire acquire; + + try { + self.on_close_cb(); + } catch (nb::python_error& e) { + e.restore(); + PyErr_WriteUnraisable(self.on_close_cb.ptr()); + } + } + }); + + handler->SetOnMessage([&self](uint8_t opcode, const std::string& msg) { + nb::gil_scoped_acquire acquire; + try { + if (opcode == WEB_STREAM_OPCODE_BINARY || opcode == WEB_STREAM_OPCODE_METADATA) { + self.on_message_cb(opcode, nb::bytes(msg.data(), msg.size())); + } else { + self.on_message_cb(opcode, msg); + } + } catch (nb::python_error& e) { + e.restore(); + PyErr_WriteUnraisable(self.on_message_cb.ptr()); + } + }); + + nb::gil_scoped_acquire acquire; + try { + self.on_open_cb(ref); + } catch (nb::python_error& e) { + e.restore(); + PyErr_WriteUnraisable(self.on_open_cb.ptr()); + } + }); + }) + .def("set_on_message", [](PlainClientPy& self, nb::object cb) { + self.on_message_cb = cb; + }) + .def("set_on_error", [](PlainClientPy& self, nb::object cb) { + self.on_error_cb = cb; + }) + .def("set_on_close", [](PlainClientPy& self, nb::object cb) { + self.on_close_cb = cb; + }) + .def("run", [](PlainClientPy& self) { + self.running.store(true, std::memory_order_release); + struct RunGuard { + PlainClientPy& s; + ~RunGuard() noexcept { + { + std::lock_guard lk(s.stopped_mu); + s.running.store(false, std::memory_order_release); + } + s.stopped_cv.notify_all(); + } + } guard{self}; + self.client.Run(); }, nb::call_guard()) + .def("stop", [](PlainClientPy& self) { + self.client.Stop(); + if (self.handler_ref) { + std::lock_guard lock(self.handler_ref->mu); + self.handler_ref->ptr = nullptr; + } }) + .def("__enter__", [](PlainClientPy& self) -> PlainClientPy& { return self; }) + .def("__exit__", [](PlainClientPy& self, nb::object, nb::object, nb::object) { + self.client.Stop(); + if (self.handler_ref) { + std::lock_guard lock(self.handler_ref->mu); + self.handler_ref->ptr = nullptr; + } }); +} diff --git a/wish/python/src/wish_ext.cc b/wish/python/src/wish_ext.cc deleted file mode 100644 index 92cdf72..0000000 --- a/wish/python/src/wish_ext.cc +++ /dev/null @@ -1,344 +0,0 @@ -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "plain_client.h" -#include "tls_client.h" -#include "wish_handler.h" - -namespace nb = nanobind; - -// --------------------------------------------------------------------------- -// WishHandlerRef: a shared, nullable handle to WishHandler. -// -// The raw WishHandler* lives only as long as the connection is open. -// WishHandler::EventCallback fires on_close_ BEFORE self-deleting; our -// on_close hook nullifies ptr under the mutex so any concurrent call from -// the Python thread via send_text/send_binary sees nullptr and raises -// RuntimeError rather than dereferencing freed memory. -// --------------------------------------------------------------------------- - -struct WishHandlerRef { - std::mutex mu; - WishHandler* ptr = nullptr; - - int send_text(const std::string& msg) { - std::lock_guard lock(mu); - if (!ptr) throw std::runtime_error("Connection is closed"); - return ptr->SendText(msg); - } - - int send_binary(const std::string& msg) { - std::lock_guard lock(mu); - if (!ptr) throw std::runtime_error("Connection is closed"); - return ptr->SendBinary(msg); - } -}; - -// --------------------------------------------------------------------------- -// Wrapper structs -// -// Storing the Python callbacks as direct nb::object members (rather than -// inside a std::function closure) makes them visible to Python's cyclic GC -// via tp_traverse / tp_clear, breaking cycles such as: -// -// WishConnection → TlsClient (Python) → on_open_ lambda -// → nb::object → Python closure → WishConnection -// --------------------------------------------------------------------------- - -struct TlsClientPy { - TlsClient client; - nb::object on_open_cb; - nb::object on_message_cb; - std::shared_ptr handler_ref; - - // Tracks whether Run() is currently executing. Used by tls_clear to - // wait for the event loop to exit before it clears the callbacks. - std::atomic running{false}; - std::mutex stopped_mu; - std::condition_variable stopped_cv; - - // Guards against running cleanup twice (once from tp_finalize, once from - // tp_clear in the GC path). - std::atomic finalized{false}; - - TlsClientPy(const std::string& ca, const std::string& cert, - const std::string& key, const std::string& host, int port) - : client(ca, cert, key, host, port) {}; -}; - -struct PlainClientPy { - PlainClient client; - nb::object on_open_cb; - nb::object on_message_cb; - std::shared_ptr handler_ref; - - std::atomic running{false}; - std::mutex stopped_mu; - std::condition_variable stopped_cv; - - std::atomic finalized{false}; - - PlainClientPy(const std::string& host, int port) - : client(host, port) {} -}; - -// --------------------------------------------------------------------------- -// tp_traverse / tp_clear / tp_finalize for TlsClientPy -// --------------------------------------------------------------------------- - -// tls_do_cleanup: stop the event loop, wait for it to exit, then release all -// C++ callbacks. Safe to call multiple times (idempotent via finalized flag). -static void tls_do_cleanup(TlsClientPy* w) { - if (w->finalized.exchange(true, std::memory_order_acq_rel)) return; - - // 1. Ask the event loop to stop. event_base_loopexit is thread-safe. - w->client.Stop(); - - // 2. Release the GIL and wait for Run() to return. - // - // Without this step there is a data race: - // - Event loop thread reads on_open_ / on_message_ and then blocks - // waiting to acquire the GIL (nb::gil_scoped_acquire). - // - GC thread (holding the GIL) writes those same std::function - // objects via SetOnOpen({}) etc. → UB. - // By releasing the GIL here we let any in-flight callback finish, after - // which event_base_dispatch returns and Run() signals stopped_cv. - { - PyThreadState* ts = PyEval_SaveThread(); // release GIL - std::unique_lock lk(w->stopped_mu); - w->stopped_cv.wait(lk, [w] { return !w->running.load(std::memory_order_acquire); }); - PyEval_RestoreThread(ts); // reacquire GIL - } - - // 3. Event loop has exited; mutations are now single-threaded and safe. - w->client.SetOnOpen({}); - w->client.SetOnClose({}); - w->client.SetOnMessage({}); - // Invalidate the safe handle so Python code can no longer call through it. - if (w->handler_ref) { - std::lock_guard lock(w->handler_ref->mu); - w->handler_ref->ptr = nullptr; - } - w->handler_ref.reset(); -} - -// tp_finalize is called for BOTH the normal refcount destruction path and the -// GC path (before tp_clear / tp_dealloc). Putting the cleanup here ensures -// it runs regardless of whether a reference cycle was involved. -static void tls_finalize(PyObject* self) { - TlsClientPy* w = nb::inst_ptr(nb::handle(self)); - tls_do_cleanup(w); - // Do NOT release on_open_cb / on_message_cb here: they are Python objects - // that tp_traverse must still be able to visit until tp_clear runs. -} - -static int tls_traverse(PyObject* self, visitproc visit, void* arg) { - TlsClientPy* w = nb::inst_ptr(nb::handle(self)); - Py_VISIT(w->on_open_cb.ptr()); - Py_VISIT(w->on_message_cb.ptr()); - return 0; -} - -static int tls_clear(PyObject* self) { - TlsClientPy* w = nb::inst_ptr(nb::handle(self)); - // tls_do_cleanup is idempotent; if tp_finalize already ran this is a no-op. - tls_do_cleanup(w); - // Drop Python object references to break the cycle. - w->on_open_cb = nb::object(); - w->on_message_cb = nb::object(); - return 0; -} - -// --------------------------------------------------------------------------- -// tp_traverse / tp_clear / tp_finalize for PlainClientPy -// --------------------------------------------------------------------------- - -static void plain_do_cleanup(PlainClientPy* w) { - if (w->finalized.exchange(true, std::memory_order_acq_rel)) return; - - w->client.Stop(); - - { - PyThreadState* ts = PyEval_SaveThread(); - std::unique_lock lk(w->stopped_mu); - w->stopped_cv.wait(lk, [w] { return !w->running.load(std::memory_order_acquire); }); - PyEval_RestoreThread(ts); - } - - w->client.SetOnOpen({}); - w->client.SetOnClose({}); - w->client.SetOnMessage({}); - if (w->handler_ref) { - std::lock_guard lock(w->handler_ref->mu); - w->handler_ref->ptr = nullptr; - } - w->handler_ref.reset(); -} - -static void plain_finalize(PyObject* self) { - PlainClientPy* w = nb::inst_ptr(nb::handle(self)); - plain_do_cleanup(w); -} - -static int plain_traverse(PyObject* self, visitproc visit, void* arg) { - PlainClientPy* w = nb::inst_ptr(nb::handle(self)); - Py_VISIT(w->on_open_cb.ptr()); - Py_VISIT(w->on_message_cb.ptr()); - return 0; -} - -static int plain_clear(PyObject* self) { - PlainClientPy* w = nb::inst_ptr(nb::handle(self)); - plain_do_cleanup(w); - w->on_open_cb = nb::object(); - w->on_message_cb = nb::object(); - return 0; -} - -// --------------------------------------------------------------------------- - -NB_MODULE(wish_ext, m) { - // Enable libevent thread-safety -#ifdef _WIN32 - evthread_use_windows_threads(); -#else - evthread_use_pthreads(); -#endif - - nb::class_>(m, "WishHandler") - .def("send_text", &WishHandlerRef::send_text) - .def("send_binary", &WishHandlerRef::send_binary); - - // ---- TlsClient -------------------------------------------------------- - static PyType_Slot tls_slots[] = { - {Py_tp_traverse, (void*)tls_traverse}, - {Py_tp_clear, (void*)tls_clear}, - {Py_tp_finalize, (void*)tls_finalize}, - {0, nullptr}, - }; - - nb::class_(m, "TlsClient", nb::type_slots(tls_slots)) - .def(nb::init()) - .def("init", [](TlsClientPy& self) { - if (!self.client.Init()) { - throw std::runtime_error("TlsClient.init() failed"); - } - }) - .def("set_on_open", [](TlsClientPy& self, nb::object cb) { - self.on_open_cb = cb; - // Create a fresh WishHandlerRef for this connection attempt. - auto ref = std::make_shared(); - self.handler_ref = ref; - // Wire the close notification: nullify ptr before WishHandler is - // deleted so Python cannot reach freed memory. - self.client.SetOnClose([ref]() { - std::lock_guard lock(ref->mu); - ref->ptr = nullptr; - }); - // Capture self by pointer; lifetime is safe because the lambda - // lives inside self.client and is always cleared before self - // is destroyed (either by tp_clear or ~TlsClient). - self.client.SetOnOpen([&self, ref](WishHandler* handler) { - { - std::lock_guard lock(ref->mu); - ref->ptr = handler; - } - nb::gil_scoped_acquire acquire; - try { - self.on_open_cb(ref); - } catch (nb::python_error& e) { - e.restore(); - PyErr_WriteUnraisable(nullptr); - } - }); - }) - .def("set_on_message", [](TlsClientPy& self, nb::object cb) { - self.on_message_cb = cb; - self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) { - nb::gil_scoped_acquire acquire; - try { - self.on_message_cb(opcode, msg); - } catch (nb::python_error& e) { - e.restore(); - PyErr_WriteUnraisable(nullptr); - } - }); - }) - .def("run", [](TlsClientPy& self) { - self.running.store(true, std::memory_order_release); - self.client.Run(); // blocks in event_base_dispatch with GIL released - { - std::lock_guard lk(self.stopped_mu); - self.running.store(false, std::memory_order_release); - } - self.stopped_cv.notify_all(); }, nb::call_guard()) - .def("stop", [](TlsClientPy& self) { self.client.Stop(); }); - - // ---- PlainClient ------------------------------------------------------ - static PyType_Slot plain_slots[] = { - {Py_tp_traverse, (void*)plain_traverse}, - {Py_tp_clear, (void*)plain_clear}, - {Py_tp_finalize, (void*)plain_finalize}, - {0, nullptr}, - }; - - nb::class_(m, "PlainClient", nb::type_slots(plain_slots)) - .def(nb::init()) - .def("init", [](PlainClientPy& self) { - if (!self.client.Init()) { - throw std::runtime_error("PlainClient.init() failed"); - } - }) - .def("set_on_open", [](PlainClientPy& self, nb::object cb) { - self.on_open_cb = cb; - // Create a fresh WishHandlerRef for this connection attempt. - auto ref = std::make_shared(); - self.handler_ref = ref; - self.client.SetOnClose([ref]() { - std::lock_guard lock(ref->mu); - ref->ptr = nullptr; - }); - self.client.SetOnOpen([&self, ref](WishHandler* handler) { - { - std::lock_guard lock(ref->mu); - ref->ptr = handler; - } - nb::gil_scoped_acquire acquire; - try { - self.on_open_cb(ref); - } catch (nb::python_error& e) { - e.restore(); - PyErr_WriteUnraisable(nullptr); - } - }); - }) - .def("set_on_message", [](PlainClientPy& self, nb::object cb) { - self.on_message_cb = cb; - self.client.SetOnMessage([&self](uint8_t opcode, const std::string& msg) { - nb::gil_scoped_acquire acquire; - try { - self.on_message_cb(opcode, msg); - } catch (nb::python_error& e) { - e.restore(); - PyErr_WriteUnraisable(nullptr); - } - }); - }) - .def("run", [](PlainClientPy& self) { - self.running.store(true, std::memory_order_release); - self.client.Run(); - { - std::lock_guard lk(self.stopped_mu); - self.running.store(false, std::memory_order_release); - } - self.stopped_cv.notify_all(); }, nb::call_guard()) - .def("stop", [](PlainClientPy& self) { self.client.Stop(); }); -} diff --git a/wish/python/tests/test_callback_exceptions.py b/wish/python/tests/test_callback_exceptions.py index f070d09..9116463 100644 --- a/wish/python/tests/test_callback_exceptions.py +++ b/wish/python/tests/test_callback_exceptions.py @@ -1,7 +1,7 @@ -"""Tests that Python exceptions raised inside wish_ext callbacks do not +"""Tests that Python exceptions raised inside web_stream_ext callbacks do not propagate through the C libevent stack and crash the process. -Each callback lambda in wish_ext.cc wraps the Python call in +Each callback lambda in web_stream_ext.cc wraps the Python call in try/catch(nb::python_error) and routes the exception through PyErr_WriteUnraisable, which invokes sys.unraisablehook. These tests install a temporary hook to capture that notification and assert on it. @@ -28,18 +28,18 @@ def get_free_port() -> int: return s.getsockname()[1] -def _import_wish_ext(): +def _import_web_stream_ext(): try: - from wish import wish_ext - return wish_ext + from web_stream import web_stream_ext + return web_stream_ext except ImportError: return None -wish_ext = _import_wish_ext() +web_stream_ext = _import_web_stream_ext() -@unittest.skipIf(wish_ext is None, "wish_ext extension module not available - run 'pip install .'") +@unittest.skipIf(web_stream_ext is None, "web_stream_ext extension module not available - run 'pip install .'") @unittest.skipUnless( os.path.exists(SERVER_PLAIN_BIN), f"Plain echo server not found at {SERVER_PLAIN_BIN} - compile the C++ project first", @@ -96,7 +96,7 @@ def hook(info): return captured, lambda: setattr(sys, "unraisablehook", original) def _make_plain_client(self): - client = wish_ext.PlainClient("127.0.0.1", self.port) + client = web_stream_ext.PlainClient("127.0.0.1", self.port) self.assertTrue(client.init(), "PlainClient.init() returned False") return client diff --git a/wish/python/tests/test_client.py b/wish/python/tests/test_client.py index 64f4e2b..b7d47cd 100644 --- a/wish/python/tests/test_client.py +++ b/wish/python/tests/test_client.py @@ -1,13 +1,13 @@ import asyncio -import wish +import web_stream # Certs created for the tls_echo_server example (adjust paths if needed) -CERTS_DIR = "/home/ysnysnysn/net_http/wish/cpp/certs" +CERTS_DIR = "../../cpp/certs" async def main(): - print("Connecting to WiSH server...") - async with wish.connect( - "wish://127.0.0.1:8080", + print("Connecting to WebStream server...") + async with web_stream.connect( + "webstream://127.0.0.1:8080", ca_file=f"{CERTS_DIR}/ca.crt", cert_file=f"{CERTS_DIR}/client.crt", key_file=f"{CERTS_DIR}/client.key", diff --git a/wish/python/tests/test_dangling_pointer.py b/wish/python/tests/test_dangling_pointer.py index 4b3f239..480226a 100644 --- a/wish/python/tests/test_dangling_pointer.py +++ b/wish/python/tests/test_dangling_pointer.py @@ -1,9 +1,9 @@ -"""Tests that accessing a WishHandler after the connection closes does not +"""Tests that accessing a BufferEventWebStream after the connection closes does not cause a use-after-free crash. -After the server closes the connection (EOF), WishHandler::EventCallback calls -on_close_() which nullifies WishHandlerRef::ptr under a mutex. Any subsequent -call through the Python WishHandler object must raise RuntimeError rather than +After the server closes the connection (EOF), BufferEventWebStream::EventCallback calls +on_close_() which nullifies WebStreamHandlerRef::ptr under a mutex. Any subsequent +call through the Python BufferEventWebStream object must raise RuntimeError rather than dereferencing freed memory. """ @@ -27,24 +27,24 @@ def get_free_port() -> int: return s.getsockname()[1] -def _import_wish_ext(): +def _import_web_stream_ext(): try: - from wish import wish_ext # noqa: PLC0415 - return wish_ext + from web_stream import web_stream_ext # noqa: PLC0415 + return web_stream_ext except ImportError: return None -wish_ext = _import_wish_ext() +web_stream_ext = _import_web_stream_ext() -@unittest.skipIf(wish_ext is None, "wish_ext extension module not available – run 'pip install .'") +@unittest.skipIf(web_stream_ext is None, "web_stream_ext extension module not available – run 'pip install .'") @unittest.skipUnless( os.path.exists(SERVER_PLAIN_BIN), f"Plain echo server not found at {SERVER_PLAIN_BIN} – compile the C++ project first", ) class TestDanglingPointer(unittest.TestCase): - """Verify that WishHandler cannot be used after the connection is closed.""" + """Verify that BufferEventWebStream cannot be used after the connection is closed.""" port: int server_proc: subprocess.Popen @@ -74,7 +74,7 @@ def tearDownClass(cls) -> None: # ------------------------------------------------------------------ def _make_plain_client(self): - client = wish_ext.PlainClient("127.0.0.1", self.port) + client = web_stream_ext.PlainClient("127.0.0.1", self.port) self.assertTrue(client.init(), "PlainClient.init() returned False") return client @@ -109,7 +109,7 @@ def on_open(handler): handler = captured_handler[0] # The event loop has exited and stop() was called. The handler's - # WishHandlerRef was invalidated when the connection closed. + # WebStreamHandlerRef was invalidated when the connection closed. # send_text must raise RuntimeError, not segfault. with self.assertRaises(RuntimeError): handler.send_text("should fail") @@ -143,7 +143,7 @@ def on_open(handler): def test_handler_invalidated_before_on_close_returns(self): """By the time the Python on_close callback (if any) runs, the - WishHandlerRef must already be invalidated. + WebStreamHandlerRef must already be invalidated. We verify this indirectly: open a connection, stop the client, and confirm that the handler ref is invalid immediately after stop() diff --git a/wish/python/tests/test_e2e.py b/wish/python/tests/test_e2e.py index 11272cc..8d6341a 100644 --- a/wish/python/tests/test_e2e.py +++ b/wish/python/tests/test_e2e.py @@ -5,14 +5,14 @@ import tempfile import sys import unittest -import wish +import web_stream # Project root based on known directory structure (wish/python/tests/test_client.py) TEST_DIR = os.path.dirname(os.path.abspath(__file__)) PROJECT_ROOT = os.path.abspath(os.path.join(TEST_DIR, "..", "..")) SERVER_BIN = os.path.join(PROJECT_ROOT, "cpp", "build", "examples", "tls_echo_server") -SERVER_PLAIN_BIN = os.path.join(PROJECT_ROOT, "cpp", "build", "examples", "echo_server") +SERVER_PLAIN_BIN = os.path.join(PROJECT_ROOT, "cpp", "build", "examples", "plain_echo_server") CERTS_DIR = os.path.join(PROJECT_ROOT, "cpp", "certs") def get_free_port(): @@ -22,7 +22,7 @@ def get_free_port(): s.close() return port -class TestWishClientE2E(unittest.TestCase): +class TestWebStreamClientE2E(unittest.TestCase): @classmethod def setUpClass(cls): if not os.path.exists(SERVER_BIN): @@ -60,9 +60,9 @@ async def run(): # Allow the server more time to start up to avoid Connection reset await asyncio.sleep(1.0) - uri = f"wishs://127.0.0.1:{self.port}" + uri = f"webstreams://127.0.0.1:{self.port}" - async with wish.connect( + async with web_stream.connect( uri, ca_file=os.path.join(CERTS_DIR, "ca.crt"), cert_file=os.path.join(CERTS_DIR, "client.crt"), @@ -80,7 +80,7 @@ async def run(): loop.run_until_complete(run()) -class TestWishClientPlainE2E(unittest.TestCase): +class TestWebStreamClientPlainE2E(unittest.TestCase): @classmethod def setUpClass(cls): if not os.path.exists(SERVER_PLAIN_BIN): @@ -115,9 +115,9 @@ async def run(): # Allow the server more time to start up to avoid Connection reset await asyncio.sleep(1.0) - uri = f"wish://127.0.0.1:{self.port}" + uri = f"webstream://127.0.0.1:{self.port}" - async with wish.connect(uri) as ws: + async with web_stream.connect(uri) as ws: test_msg = "Hello E2E from Python over plain TCP!" await ws.send(test_msg) @@ -129,6 +129,42 @@ async def run(): asyncio.set_event_loop(loop) loop.run_until_complete(run()) + def test_server_disconnect_detect(self): + port = get_free_port() + cmd = [ + SERVER_PLAIN_BIN, + f"--port={port}", + ] + server_proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + async def run(): + await asyncio.sleep(1.0) + uri = f"webstream://127.0.0.1:{port}" + async with web_stream.connect(uri) as ws: + # Terminate server to force disconnect + server_proc.terminate() + server_proc.wait() + + # The client should detect this disconnect and raise ConnectionError or ConnectionAbortedError on recv() + try: + await asyncio.wait_for(ws.recv(), timeout=5.0) + self.fail("Expected ConnectionError or ConnectionAbortedError, but no exception was raised") + except (ConnectionError, ConnectionAbortedError): + pass + + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(run()) + finally: + server_proc.terminate() + server_proc.wait() + if __name__ == "__main__": result = unittest.main(exit=False).result diff --git a/wish/python/web_stream/__init__.py b/wish/python/web_stream/__init__.py new file mode 100644 index 0000000..4d5a628 --- /dev/null +++ b/wish/python/web_stream/__init__.py @@ -0,0 +1,3 @@ +from .client import connect, WebStreamConnection + +__all__ = ["connect", "WebStreamConnection"] diff --git a/wish/python/web_stream/client.py b/wish/python/web_stream/client.py new file mode 100644 index 0000000..ff9675c --- /dev/null +++ b/wish/python/web_stream/client.py @@ -0,0 +1,125 @@ +import asyncio +import threading +from urllib.parse import urlparse + +from . import web_stream_ext + +class WebStreamConnection: + def __init__(self, host, port, tls, ca_file="", cert_file="", key_file=""): + self._host = host + self._port = port + + if tls: + self._client = web_stream_ext.TlsClient(ca_file, cert_file, key_file, host, port) + else: + self._client = web_stream_ext.PlainClient(host, port) + + self._client.init() # raises RuntimeError on failure + + self._loop = asyncio.get_running_loop() + + self._recv_queue = asyncio.Queue() + self._open_future = self._loop.create_future() + self._thread = None + self._handler = None + + def on_open(handler): + def set_handler(): + self._handler = handler + self._open_future.set_result(True) + self._loop.call_soon_threadsafe(set_handler) + + def on_message(opcode, msg): + self._loop.call_soon_threadsafe(self._recv_queue.put_nowait, (opcode, msg)) + + def on_error(): + def set_error(): + if not self._open_future.done(): + self._open_future.set_exception(ConnectionError("Connection failed or lost")) + else: + self._recv_queue.put_nowait(ConnectionError("Connection lost")) + self._loop.call_soon_threadsafe(set_error) + + def on_close(): + def set_close(): + self._recv_queue.put_nowait(ConnectionAbortedError("Connection closed")) + self._loop.call_soon_threadsafe(set_close) + + self._client.set_on_open(on_open) + self._client.set_on_message(on_message) + self._client.set_on_error(on_error) + self._client.set_on_close(on_close) + + async def connect(self): + # Run the C++ event loop in a background daemon thread. + self._thread = threading.Thread(target=self._client.run, daemon=True) + self._thread.start() + # Wait until the on_open callback fires + await self._open_future + return self + + async def close(self): + """Sends EoF (Close) over the WebStream connection.""" + if self._handler: + self._handler.close() + self._client = None + + async def send(self, data): + """Sends data over the WebStream connection. If data is bytes, sends as binary, else text.""" + if not self._handler: + raise RuntimeError("Connection is not open") + + if isinstance(data, bytes): + self._handler.send_binary(data) + else: + self._handler.send_text(str(data)) + + async def recv(self): + """Receives a message from the WebStream connection.""" + if not self._client: + raise RuntimeError("Connection is closed") + + res = await self._recv_queue.get() + + if isinstance(res, Exception): + raise res + + opcode, msg = res + # You can process opcode here if you want to distinguish text/binary + # We'll just return the message + # In actual implementation: 1=Text, 2=Binary + if opcode == 2: + return msg.encode("utf-8") if isinstance(msg, str) else msg + else: + return msg + +class _ConnectContextManager: + def __init__(self, uri, ca_file="", cert_file="", key_file=""): + parsed = urlparse(uri) + + self.tls = parsed.scheme in ("webstreams", "https") + + self.host = parsed.hostname + self.port = parsed.port or (443 if self.tls else 80) + + self.ca_file = ca_file + self.cert_file = cert_file + self.key_file = key_file + + self.conn = None + + async def __aenter__(self): + self.conn = WebStreamConnection(self.host, + self.port, + self.tls, + self.ca_file, + self.cert_file, + self.key_file) + await self.conn.connect() + return self.conn + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.conn.close() + +def connect(uri, ca_file="", cert_file="", key_file=""): + return _ConnectContextManager(uri, ca_file, cert_file, key_file) diff --git a/wish/python/wish/__init__.py b/wish/python/wish/__init__.py deleted file mode 100644 index f3d60fe..0000000 --- a/wish/python/wish/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .client import connect, WishConnection - -__all__ = ["connect", "WishConnection"] diff --git a/wish/python/wish/client.py b/wish/python/wish/client.py deleted file mode 100644 index 495bd19..0000000 --- a/wish/python/wish/client.py +++ /dev/null @@ -1,90 +0,0 @@ -import asyncio -import threading -from urllib.parse import urlparse - -from . import wish_ext - -class WishConnection: - def __init__(self, host, port, tls, ca_file="", cert_file="", key_file=""): - self._host = host - self._port = port - if tls: - self._client = wish_ext.TlsClient(ca_file, cert_file, key_file, host, port) - else: - self._client = wish_ext.PlainClient(host, port) - self._client.init() # raises RuntimeError on failure - - self._loop = asyncio.get_running_loop() - self._recv_queue = asyncio.Queue() - self._open_future = self._loop.create_future() - self._run_future = None - self._handler = None - - def on_open(handler): - self._handler = handler - self._loop.call_soon_threadsafe(self._open_future.set_result, True) - - def on_message(opcode, msg): - self._loop.call_soon_threadsafe(self._recv_queue.put_nowait, (opcode, msg)) - - self._client.set_on_open(on_open) - self._client.set_on_message(on_message) - - async def connect(self): - # Run the C++ event loop in a background thread. - # Keep the Future so we can await thread completion on close. - self._run_future = self._loop.run_in_executor(None, self._client.run) - # Wait until the on_open callback fires - await self._open_future - return self - - async def close(self): - """Stop the C++ event loop and wait for the background thread to exit.""" - self._client.stop() - if self._run_future is not None: - await self._run_future - self._run_future = None - - async def send(self, data): - """Sends data over the WiSH connection. If data is bytes, sends as binary, else text.""" - if not self._handler: - raise RuntimeError("Connection is not open") - - if isinstance(data, bytes): - self._handler.send_binary(data) - else: - self._handler.send_text(str(data)) - - async def recv(self): - """Receives a message from the WiSH connection.""" - opcode, msg = await self._recv_queue.get() - # You can process opcode here if you want to distinguish text/binary - # We'll just return the message - # In actual implementation: 1=Text, 2=Binary - if opcode == 2: - return msg.encode('utf-8') if isinstance(msg, str) else msg - else: - return msg - -class _ConnectContextManager: - def __init__(self, uri, ca_file="", cert_file="", key_file=""): - parsed = urlparse(uri) - self.tls = parsed.scheme in ("wishs", "wss", "https") - self.host = parsed.hostname - self.port = parsed.port or (443 if self.tls else 80) - self.ca_file = ca_file - self.cert_file = cert_file - self.key_file = key_file - self.conn = None - - async def __aenter__(self): - self.conn = WishConnection(self.host, self.port, self.tls, - self.ca_file, self.cert_file, self.key_file) - await self.conn.connect() - return self.conn - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.conn.close() - -def connect(uri, ca_file="", cert_file="", key_file=""): - return _ConnectContextManager(uri, ca_file, cert_file, key_file)