diff --git a/bindings/cpp/CMakeLists.txt b/bindings/cpp/CMakeLists.txt index 7869c4822..fff255e09 100644 --- a/bindings/cpp/CMakeLists.txt +++ b/bindings/cpp/CMakeLists.txt @@ -123,7 +123,7 @@ if (SVS_RUNTIME_ENABLE_LVQ_LEANVEC) else() # Links to LTO-enabled static library, requires GCC/G++ 11.2 if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "11.2" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS "11.3") - set(SVS_URL "https://github.com/intel/ScalableVectorSearch/releases/download/v0.3.0/svs-shared-library-0.3.0-lto-ivf.tar.gz" + set(SVS_URL "https://github.com/intel/ScalableVectorSearch/releases/download/nightly/svs-shared-library-lto-nightly-2026-04-28-1226.tar.gz" CACHE STRING "URL to download SVS shared library") else() message(WARNING diff --git a/bindings/cpp/include/svs/runtime/flat_index.h b/bindings/cpp/include/svs/runtime/flat_index.h index 9c635e35d..87410e614 100644 --- a/bindings/cpp/include/svs/runtime/flat_index.h +++ b/bindings/cpp/include/svs/runtime/flat_index.h @@ -44,6 +44,16 @@ struct SVS_RUNTIME_API FlatIndex { virtual Status save(std::ostream& out) const noexcept = 0; static Status load(FlatIndex** index, std::istream& in, MetricType metric) noexcept; + + // Load from a memory-mapped file. + // The file is expected to be in the format produced by save(). + static Status + map_to_file(FlatIndex** index, const char* path, MetricType metric) noexcept; + + // Load from a memory buffer. + // The buffer is expected to be in the format produced by save(). + static Status + map_to_memory(FlatIndex** index, void* data, size_t size, MetricType metric) noexcept; }; } // namespace v0 diff --git a/bindings/cpp/include/svs/runtime/vamana_index.h b/bindings/cpp/include/svs/runtime/vamana_index.h index 988319528..fa6dc8049 100644 --- a/bindings/cpp/include/svs/runtime/vamana_index.h +++ b/bindings/cpp/include/svs/runtime/vamana_index.h @@ -95,6 +95,22 @@ struct SVS_RUNTIME_API VamanaIndex { static Status load( VamanaIndex** index, std::istream& in, MetricType metric, StorageKind storage_kind ) noexcept; + + // Load from a memory-mapped file. + // The file is expected to be in the format produced by save(). + static Status map_to_file( + VamanaIndex** index, const char* path, MetricType metric, StorageKind storage_kind + ) noexcept; + + // Load from a memory buffer. + // The buffer is expected to be in the format produced by save(). + static Status map_to_memory( + VamanaIndex** index, + void* data, + size_t size, + MetricType metric, + StorageKind storage_kind + ) noexcept; }; struct SVS_RUNTIME_API VamanaIndexLeanVec : public VamanaIndex { diff --git a/bindings/cpp/src/flat_index.cpp b/bindings/cpp/src/flat_index.cpp index f2aa22453..bce25b5e4 100644 --- a/bindings/cpp/src/flat_index.cpp +++ b/bindings/cpp/src/flat_index.cpp @@ -115,5 +115,33 @@ Status FlatIndex::load(FlatIndex** index, std::istream& in, MetricType metric) n return Status_Ok; }); } + +Status +FlatIndex::map_to_file(FlatIndex** index, const char* path, MetricType metric) noexcept { + *index = nullptr; + return runtime_error_wrapper([&] { + std::filesystem::path fs_path(path); + auto is = std::make_unique(fs_path); + std::unique_ptr impl{ + FlatIndexImpl::map_to_stream(std::move(is), metric)}; + *index = new FlatIndexManager{std::move(impl)}; + return Status_Ok; + }); +} + +Status FlatIndex::map_to_memory( + FlatIndex** index, void* data, size_t size, MetricType metric +) noexcept { + *index = nullptr; + return runtime_error_wrapper([&] { + auto sp = std::span(reinterpret_cast(data), size); + auto is = std::make_unique(sp); + std::unique_ptr impl{ + FlatIndexImpl::map_to_stream(std::move(is), metric)}; + *index = new FlatIndexManager{std::move(impl)}; + return Status_Ok; + }); +} + } // namespace runtime } // namespace svs diff --git a/bindings/cpp/src/flat_index_impl.h b/bindings/cpp/src/flat_index_impl.h index 02e07f54d..cdfa2200f 100644 --- a/bindings/cpp/src/flat_index_impl.h +++ b/bindings/cpp/src/flat_index_impl.h @@ -22,6 +22,7 @@ #include #include +#include #include #include @@ -111,12 +112,39 @@ class FlatIndexImpl { }); } + static FlatIndexImpl* + map_to_stream(std::unique_ptr&& in, MetricType metric) { + if (!svs::io::is_memory_stream(*in)) { + throw StatusException{ + ErrorCode::INVALID_ARGUMENT, "Provided stream is not a memory stream"}; + } + auto threadpool = default_threadpool(); + using storage_type = svs::runtime::storage:: + StorageType_t>; + + svs::DistanceDispatcher distance_dispatcher(to_svs_distance(metric)); + return distance_dispatcher([&](auto&& distance) { + auto impl = new svs::Flat{svs::Flat::assemble( + *in, std::forward(distance), std::move(threadpool) + )}; + + return new FlatIndexImpl( + std::unique_ptr{impl}, metric, std::move(in) + ); + }); + } + protected: // Constructor used during loading - FlatIndexImpl(std::unique_ptr&& impl, MetricType metric) + FlatIndexImpl( + std::unique_ptr&& impl, + MetricType metric, + std::unique_ptr mapped_stream = nullptr + ) : dim_{impl->dimensions()} , metric_type_{metric} - , impl_{std::move(impl)} {} + , impl_{std::move(impl)} + , mapped_stream_{std::move(mapped_stream)} {} void init_impl(data::ConstSimpleDataView data) { auto threadpool = default_threadpool(); @@ -139,6 +167,8 @@ class FlatIndexImpl { size_t dim_; MetricType metric_type_; std::unique_ptr impl_; + // For memory-mapping, we need to keep the stream alive as long as the index is alive + std::unique_ptr mapped_stream_; }; } // namespace runtime } // namespace svs diff --git a/bindings/cpp/src/vamana_index.cpp b/bindings/cpp/src/vamana_index.cpp index c015dd21e..c0dd3d757 100644 --- a/bindings/cpp/src/vamana_index.cpp +++ b/bindings/cpp/src/vamana_index.cpp @@ -144,6 +144,38 @@ Status VamanaIndex::load( }); } +Status VamanaIndex::map_to_file( + VamanaIndex** index, const char* path, MetricType metric, StorageKind storage_kind +) noexcept { + using Impl = VamanaIndexImpl; + *index = nullptr; + return runtime_error_wrapper([&] { + std::filesystem::path fs_path(path); + auto is = std::make_unique(fs_path); + std::unique_ptr impl{ + Impl::map_to_stream(std::move(is), metric, storage_kind)}; + *index = new VamanaIndexManagerBase{std::move(impl)}; + }); +} + +Status VamanaIndex::map_to_memory( + VamanaIndex** index, + void* data, + size_t size, + MetricType metric, + StorageKind storage_kind +) noexcept { + using Impl = VamanaIndexImpl; + *index = nullptr; + return runtime_error_wrapper([&] { + auto sp = std::span(reinterpret_cast(data), size); + auto is = std::make_unique(sp); + std::unique_ptr impl{ + Impl::map_to_stream(std::move(is), metric, storage_kind)}; + *index = new VamanaIndexManagerBase{std::move(impl)}; + }); +} + #ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC // Specialization to build LeanVec-based Vamana index with specified leanvec dims Status VamanaIndexLeanVec::build( diff --git a/bindings/cpp/src/vamana_index_impl.h b/bindings/cpp/src/vamana_index_impl.h index 45023b1d7..85eba0a46 100644 --- a/bindings/cpp/src/vamana_index_impl.h +++ b/bindings/cpp/src/vamana_index_impl.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -387,14 +388,18 @@ class VamanaIndexImpl { // Constructor used during loading VamanaIndexImpl( - std::unique_ptr&& impl, MetricType metric, StorageKind storage_kind + std::unique_ptr&& impl, + MetricType metric, + StorageKind storage_kind, + std::unique_ptr mapped_stream = nullptr ) : dim_{0} , metric_type_{metric} , storage_kind_{storage_kind} , build_params_{} , default_search_params_{} - , impl_{std::move(impl)} { + , impl_{std::move(impl)} + , mapped_stream_{std::move(mapped_stream)} { if (impl_) { dim_ = impl_->dimensions(); const auto& buffer_config = impl_->get_search_parameters().buffer_config_; @@ -410,11 +415,12 @@ class VamanaIndexImpl { } } - template + template static svs::Vamana* load_impl_t( storage::StorageType&& SVS_UNUSED(tag), std::istream& stream, - MetricType metric + MetricType metric, + Args&&... args ) { if constexpr (!storage::is_supported_storage_kind_v) { throw StatusException( @@ -425,7 +431,10 @@ class VamanaIndexImpl { auto threadpool = default_threadpool(); return new svs::Vamana(svs::Vamana::assemble( - stream, to_svs_distance(metric), std::move(threadpool) + stream, + to_svs_distance(metric), + std::move(threadpool), + std::forward(args)... )); } } @@ -447,6 +456,30 @@ class VamanaIndexImpl { ); } + static VamanaIndexImpl* map_to_stream( + std::unique_ptr&& in, MetricType metric, StorageKind storage_kind + ) { + using map_allocator_type = svs::io::MemoryStreamAllocator; + if (!svs::io::is_memory_stream(*in)) { + throw StatusException{ + ErrorCode::INVALID_ARGUMENT, "Provided stream is not a memory stream"}; + } + return storage::dispatch_storage_kind( + storage_kind, + [&](auto&& tag, std::unique_ptr&& in, MetricType metric) { + using Tag = std::decay_t; + auto impl = load_impl_t( + std::forward(tag), *in, metric, map_allocator_type{*in} + ); + return new VamanaIndexImpl( + std::unique_ptr{impl}, metric, storage_kind, std::move(in) + ); + }, + std::move(in), + metric + ); + } + // Data members protected: size_t dim_; @@ -455,6 +488,7 @@ class VamanaIndexImpl { VamanaIndex::BuildParams build_params_; VamanaIndex::SearchParams default_search_params_; std::unique_ptr impl_; + std::unique_ptr mapped_stream_; }; #ifdef SVS_RUNTIME_HAVE_LVQ_LEANVEC diff --git a/bindings/cpp/tests/runtime_test.cpp b/bindings/cpp/tests/runtime_test.cpp index 201375d3c..274b9403d 100644 --- a/bindings/cpp/tests/runtime_test.cpp +++ b/bindings/cpp/tests/runtime_test.cpp @@ -155,6 +155,81 @@ void write_and_read_index( Index::destroy(loaded); } +// Template function to write and map an index +template +void write_and_map_index( + BuildFunc build_func, + const std::vector& xb, + size_t n, + size_t d, + std::optional storage_kind = std::nullopt, + svs::runtime::v0::MetricType metric = svs::runtime::v0::MetricType::L2 +) { + // Build index + Index* index = nullptr; + svs::runtime::v0::Status status = build_func(&index); + + // Stop here if storage kind is not supported on this platform + if constexpr (std::is_base_of_v) { + if (storage_kind.has_value()) { + if (!Index::check_storage_kind(*storage_kind).ok()) { + CATCH_REQUIRE(!status.ok()); + return; + } + } + } + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(index != nullptr); + + // Add data to index + if constexpr (std::is_same_v || std::is_same_v) { + status = index->add(n, xb.data()); + } else { + std::vector labels(n); + std::iota(labels.begin(), labels.end(), 0); + status = index->add(n, labels.data(), xb.data()); + } + CATCH_REQUIRE(status.ok()); + + svs_test::prepare_temp_directory(); + auto temp_dir = svs_test::temp_directory(); + auto filename = temp_dir / "index_test.bin"; + + // Serialize + std::ofstream out(filename, std::ios::binary); + CATCH_REQUIRE(out.is_open()); + status = index->save(out); + CATCH_REQUIRE(status.ok()); + out.close(); + + // Deserialize + Index* loaded = nullptr; + + if constexpr (std::is_same_v) { + status = Index::map_to_file(&loaded, filename.c_str(), metric); + } else { + CATCH_REQUIRE(storage_kind.has_value()); + status = Index::map_to_file(&loaded, filename.c_str(), metric, *storage_kind); + } + CATCH_REQUIRE(status.ok()); + CATCH_REQUIRE(loaded != nullptr); + + // Test basic functionality of loaded index + const int nq = 5; + const float* xq = xb.data(); + const int k = 10; + + std::vector distances(nq * k); + std::vector result_labels(nq * k); + + status = loaded->search(nq, xq, k, distances.data(), result_labels.data()); + CATCH_REQUIRE(status.ok()); + + // Clean up + Index::destroy(index); + Index::destroy(loaded); +} + // Helper that writes and reads and index of requested size // Reports memory usage UsageInfo run_save_and_load_test( @@ -455,6 +530,16 @@ CATCH_TEST_CASE("FlatIndexWriteAndRead", "[runtime]") { ); } +CATCH_TEST_CASE("FlatIndexWriteAndMap", "[runtime]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::FlatIndex** index) { + return svs::runtime::v0::FlatIndex::build( + index, test_d, svs::runtime::v0::MetricType::L2 + ); + }; + write_and_map_index(build_func, test_data, test_n, test_d); +} + CATCH_TEST_CASE("SearchWithIDFilter", "[runtime]") { const auto& test_data = get_test_data(); // Build index @@ -692,6 +777,23 @@ CATCH_TEST_CASE("WriteAndReadStaticIndexSVS", "[runtime][static_vamana]") { ); } +CATCH_TEST_CASE("WriteAndMapStaticIndexSVS", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndex::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP32, + build_params + ); + }; + write_and_map_index( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::FP32 + ); +} + CATCH_TEST_CASE("WriteAndReadStaticIndexSVSFP16", "[runtime][static_vamana]") { const auto& test_data = get_test_data(); auto build_func = [](svs::runtime::v0::VamanaIndex** index) { @@ -709,6 +811,23 @@ CATCH_TEST_CASE("WriteAndReadStaticIndexSVSFP16", "[runtime][static_vamana]") { ); } +CATCH_TEST_CASE("WriteAndMapStaticIndexSVSFP16", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndex::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::FP16, + build_params + ); + }; + write_and_map_index( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::FP16 + ); +} + CATCH_TEST_CASE("WriteAndReadStaticIndexSVSSQI8", "[runtime][static_vamana]") { const auto& test_data = get_test_data(); auto build_func = [](svs::runtime::v0::VamanaIndex** index) { @@ -726,6 +845,23 @@ CATCH_TEST_CASE("WriteAndReadStaticIndexSVSSQI8", "[runtime][static_vamana]") { ); } +CATCH_TEST_CASE("WriteAndMapStaticIndexSVSSQI8", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndex::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::SQI8, + build_params + ); + }; + write_and_map_index( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::SQI8 + ); +} + CATCH_TEST_CASE("WriteAndReadStaticIndexSVSLVQ4x4", "[runtime][static_vamana]") { const auto& test_data = get_test_data(); auto build_func = [](svs::runtime::v0::VamanaIndex** index) { @@ -743,6 +879,23 @@ CATCH_TEST_CASE("WriteAndReadStaticIndexSVSLVQ4x4", "[runtime][static_vamana]") ); } +CATCH_TEST_CASE("WriteAndMapStaticIndexSVSLVQ4x4", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndex::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::LVQ4x4, + build_params + ); + }; + write_and_map_index( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::LVQ4x4 + ); +} + CATCH_TEST_CASE("WriteAndReadStaticIndexSVSVamanaLeanVec4x4", "[runtime][static_vamana]") { const auto& test_data = get_test_data(); auto build_func = [](svs::runtime::v0::VamanaIndex** index) { @@ -761,6 +914,24 @@ CATCH_TEST_CASE("WriteAndReadStaticIndexSVSVamanaLeanVec4x4", "[runtime][static_ ); } +CATCH_TEST_CASE("WriteAndMapStaticIndexSVSVamanaLeanVec4x4", "[runtime][static_vamana]") { + const auto& test_data = get_test_data(); + auto build_func = [](svs::runtime::v0::VamanaIndex** index) { + svs::runtime::v0::VamanaIndex::BuildParams build_params{64}; + return svs::runtime::v0::VamanaIndexLeanVec::build( + index, + test_d, + svs::runtime::v0::MetricType::L2, + svs::runtime::v0::StorageKind::LeanVec4x4, + 32, + build_params + ); + }; + write_and_map_index( + build_func, test_data, test_n, test_d, svs::runtime::v0::StorageKind::LeanVec4x4 + ); +} + CATCH_TEST_CASE("StaticIndexLeanVecWithTrainingData", "[runtime][static_vamana]") { const auto& test_data = get_test_data(); const size_t leanvec_dims = 32; diff --git a/bindings/cpp/tests/utils.h b/bindings/cpp/tests/utils.h index 8d1bc89f6..2e2ed5784 100644 --- a/bindings/cpp/tests/utils.h +++ b/bindings/cpp/tests/utils.h @@ -30,30 +30,52 @@ namespace svs_test { ///// File System ///// -inline std::filesystem::path temp_directory() { - // Use /tmp for runtime binding tests - return std::filesystem::path("/tmp/svs_runtime_test"); -} +namespace detail { +struct TempDirectory { + std::filesystem::path path; -inline bool cleanup_temp_directory() { - return std::filesystem::remove_all(temp_directory()); -} + explicit TempDirectory(const std::string& prefix) + : path{create_unique_temp_directory(prefix)} {} -inline bool make_temp_directory() { - return std::filesystem::create_directories(temp_directory()); -} + ~TempDirectory() noexcept { + std::error_code ec; + std::filesystem::remove_all(path, ec); + // Ignore errors in cleanup. + } -inline bool prepare_temp_directory() { - cleanup_temp_directory(); - return make_temp_directory(); -} + std::filesystem::path get() const { return path; } + operator const std::filesystem::path&() const { return path; } + + std::filesystem::path operator/(const std::string& subpath) const { + return path / subpath; + } -inline std::filesystem::path prepare_temp_directory_v2() { - cleanup_temp_directory(); - make_temp_directory(); - return temp_directory(); + static std::filesystem::path create_unique_temp_directory(const std::string& prefix) { + namespace fs = std::filesystem; + auto temp_dir = fs::temp_directory_path(); + constexpr int hex_mask = 0xFFFFFF; // 6 hex digits is enough. + // Try up to 10 times to create a unique directory. + for (int i = 0; i < 10; ++i) { + auto random_hex = std::to_string(std::rand() & hex_mask); + auto dir = temp_dir / (prefix + "-" + random_hex); + if (std::filesystem::create_directories(dir)) { + return dir; + } + } + throw std::runtime_error("Could not create a unique temporary directory!"); + } +}; +} // namespace detail + +inline detail::TempDirectory temp_directory() { + // Use /tmp for runtime binding tests + return detail::TempDirectory("svs_runtime_test"); } +inline bool prepare_temp_directory() { return true; } + +inline detail::TempDirectory prepare_temp_directory_v2() { return temp_directory(); } + } // namespace svs_test // Test utility functions diff --git a/include/svs/core/data/io.h b/include/svs/core/data/io.h index 572ffa63a..776a7d06c 100644 --- a/include/svs/core/data/io.h +++ b/include/svs/core/data/io.h @@ -19,6 +19,7 @@ // svs #include "svs/concepts/data.h" #include "svs/core/io.h" +#include "svs/core/io/memstream.h" #include "svs/lib/array.h" #include "svs/lib/exception.h" @@ -46,6 +47,11 @@ struct DefaultWriteAccessor { return file.reader(lib::Type()); } + template + lib::VectorReader vector_reader(const Data& data) const { + return lib::VectorReader(data.dimensions()); + } + template void set(Data& data, size_t i, Span span) const { data.set_datum(i, span); @@ -79,16 +85,13 @@ void populate_impl( } } -template void populate(std::istream& is, Data& data) { - auto accessor = DefaultWriteAccessor(); - +template +void populate(Data& data, WriteAccessor&& accessor, std::istream& is) { size_t num_vectors = data.size(); - size_t dims = data.dimensions(); - auto max_lines = Dynamic; auto nvectors = std::min(num_vectors, max_lines); - auto reader = lib::VectorReader(dims); + auto reader = accessor.vector_reader(data); for (size_t i = 0; i < nvectors; ++i) { reader.read(is); accessor.set(data, i, reader.data()); @@ -194,12 +197,33 @@ lib::lazy_result_t load_dataset(const File& file, const F& la return load_impl(detail::to_native(file), default_accessor, lazy); } +template F> +lib::lazy_result_t load_dataset( + std::istream& is, + WriteAccessor&& accessor, + const F& lazy, + size_t num_vectors, + size_t dims +) { + auto data = lazy(num_vectors, dims); + if constexpr (!is_view_type_v::allocator_type>) { + populate(data, std::forward(accessor), is); + } else { + if (!is_memory_stream(is)) { + throw ANNEXCEPTION("Trying to load a dataset with a view allocator from a " + "non-memory stream. This " + "is not supported since views are compatible only with " + "memory-mapped streams."); + } + } + return data; +} + template F> lib::lazy_result_t load_dataset(std::istream& is, const F& lazy, size_t num_vectors, size_t dims) { - auto data = lazy(num_vectors, dims); - populate(is, data); - return data; + auto accessor = DefaultWriteAccessor(); + return load_dataset(is, accessor, lazy, num_vectors, dims); } // Return whether or not a file is directly loadable via file-extension. diff --git a/include/svs/core/data/simple.h b/include/svs/core/data/simple.h index 6950983de..251f30559 100644 --- a/include/svs/core/data/simple.h +++ b/include/svs/core/data/simple.h @@ -21,11 +21,13 @@ #include "svs/core/allocator.h" #include "svs/core/compact.h" #include "svs/core/data/io.h" +#include "svs/core/io/memstream.h" #include "svs/lib/array.h" #include "svs/lib/boundscheck.h" #include "svs/lib/datatype.h" #include "svs/lib/memory.h" +#include "svs/lib/misc.h" #include "svs/lib/prefetch.h" #include "svs/lib/saveload.h" #include "svs/lib/threads.h" @@ -461,10 +463,7 @@ class SimpleData { /// svs::lib::load_from_disk>("directory"); /// @endcode /// - static SimpleData - load(const lib::LoadTable& table, const allocator_type& allocator = {}) - requires(!is_view) - { + static SimpleData load(const lib::LoadTable& table, const allocator_type& allocator) { return GenericSerializer::load( table, lib::Lazy([&](size_t n_elements, size_t n_dimensions) { return SimpleData(n_elements, n_dimensions, allocator); @@ -472,13 +471,26 @@ class SimpleData { ); } + static SimpleData load(const lib::LoadTable& table) + requires(!is_view) + { + return load(table, allocator_type{}); + } + + static SimpleData load(const lib::LoadTable& SVS_UNUSED(table)) + requires(is_view) + { + throw ANNEXCEPTION( + "Trying to load a SimpleData view without an istream. This is not supported " + "since views are compatible only with in-memory streams." + ); + } + static SimpleData load( const lib::ContextFreeLoadTable& table, std::istream& is, - const allocator_type& allocator = {} - ) - requires(!is_view) - { + const allocator_type& allocator + ) { return GenericSerializer::load( table, is, lib::Lazy([&](size_t n_elements, size_t n_dimensions) { return SimpleData(n_elements, n_dimensions, allocator); @@ -486,6 +498,28 @@ class SimpleData { ); } + static SimpleData load(const lib::ContextFreeLoadTable& table, std::istream& is) + requires(!is_view) + { + return load(table, is, allocator_type{}); + } + + static SimpleData load(const lib::ContextFreeLoadTable& table, std::istream& is) + requires(is_view) + { + static_assert( + std::is_same_v>, + "SimpleData views must use the MemoryStreamAllocator." + ); + if (!io::is_memory_stream(is)) { + throw ANNEXCEPTION( + "Trying to load a SimpleData view from a non-memory stream istream. This " + "is not supported since views are compatible only with in-memory streams." + ); + } + return load(table, is, allocator_type{is}); + } + /// /// @brief Try to automatically load the dataset. /// @@ -642,6 +676,16 @@ template class Blocked { template inline constexpr bool is_blocked_v = false; template inline constexpr bool is_blocked_v> = true; +} // namespace data + +namespace lib::detail { +// Allow rebinding of allocators through the Blocked wrapper. +template struct AllocatorRebinder> { + using type = data::Blocked>; +}; +} // namespace lib::detail + +namespace data { /// /// @brief A specialization of ``SimpleData`` for large-scale dynamic datasets. /// diff --git a/include/svs/core/graph/graph.h b/include/svs/core/graph/graph.h index 89e456e07..05db19b7c 100644 --- a/include/svs/core/graph/graph.h +++ b/include/svs/core/graph/graph.h @@ -416,13 +416,12 @@ class SimpleGraph : public SimpleGraphBase static constexpr SimpleGraph load( - const lib::ContextFreeLoadTable& table, - std::istream& is, - const Alloc& allocator = {} + const lib::ContextFreeLoadTable& table, std::istream& is, AllocArgs&&... alloc_args ) { auto lazy = lib::Lazy([](data_type data) { return SimpleGraph(std::move(data)); }); - return parent_type::load(table, lazy, is, allocator); + return parent_type::load(table, lazy, is, std::forward(alloc_args)...); } static constexpr SimpleGraph @@ -434,8 +433,11 @@ class SimpleGraph : public SimpleGraphBase(is, allocator); + template + static constexpr SimpleGraph load(std::istream& is, AllocArgs&&... alloc_args) { + return lib::load_from_stream( + is, std::forward(alloc_args)... + ); } }; diff --git a/include/svs/core/io/memstream.h b/include/svs/core/io/memstream.h new file mode 100644 index 000000000..40a5f6aeb --- /dev/null +++ b/include/svs/core/io/memstream.h @@ -0,0 +1,580 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "svs/core/allocator.h" +#include "svs/lib/array.h" // just for svs::is_view_type_v specialization + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__cpp_lib_spanstream) && __cpp_lib_spanstream >= 202106L +#include +#define SVS_HAS_STD_SPANSTREAM 1 +#else +#define SVS_HAS_STD_SPANSTREAM 0 +#endif + +namespace svs { +namespace io { + +template > +class basic_mmstreambuf : public std::basic_streambuf { + static_assert( + sizeof(CharT) == 1, "basic_mmstreambuf requires a 1-byte character type." + ); + + public: + using char_type = CharT; + using traits_type = Traits; + using int_type = typename traits_type::int_type; + using pos_type = typename traits_type::pos_type; + using off_type = typename traits_type::off_type; + + explicit basic_mmstreambuf(MMapPtr mapping) + : ptr_{std::move(mapping)} { + if (ptr_) { + auto base_ptr = static_cast(ptr_.base()); + this->setg(base_ptr, ptr_.data(), base_ptr + ptr_.size()); + this->setp(&empty_, &empty_); // disallow writing + } else { + this->setg(&empty_, &empty_, &empty_); // empty buffer + this->setp(&empty_, &empty_); // disallow writing + } + } + + basic_mmstreambuf() + : basic_mmstreambuf(MMapPtr{}) {} + basic_mmstreambuf(const basic_mmstreambuf&) = delete; + basic_mmstreambuf& operator=(const basic_mmstreambuf&) = delete; + basic_mmstreambuf(basic_mmstreambuf&&) = default; + basic_mmstreambuf& operator=(basic_mmstreambuf&&) = default; + + basic_mmstreambuf* open( + const std::filesystem::path& path, + std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out + ) { + std::error_code ec; + auto size = std::filesystem::file_size(path, ec); + if (ec) { + throw ANNEXCEPTION( + "Failed to get file size: {} with system error: {}", + path.string(), + ec.message() + ); + } + if (size == 0) { + throw ANNEXCEPTION("Cannot memory-map empty file: {}", path.string()); + } + auto perm = + (mode & std::ios_base::out) ? MemoryMapper::ReadWrite : MemoryMapper::ReadOnly; + ptr_ = + MemoryMapper{perm, MemoryMapper::MustUseExisting}.mmap(path, lib::Bytes{size}); + if (!ptr_) { + throw ANNEXCEPTION("Failed to memory-map file: {}", path.string()); + } + auto base_ptr = static_cast(ptr_.base()); + this->setg(base_ptr, ptr_.data(), base_ptr + ptr_.size()); + this->setp(&empty_, &empty_); // disallow writing + return this; + } + + basic_mmstreambuf* close() noexcept { + ptr_.unmap(); + this->setg(&empty_, &empty_, &empty_); // empty buffer + this->setp(&empty_, &empty_); // disallow writing + return this; + } + + [[nodiscard]] bool is_open() const noexcept { return static_cast(ptr_); } + + [[nodiscard]] std::size_t size() const noexcept { return ptr_.size(); } + + protected: + int_type underflow() override { + if (this->gptr() == this->egptr()) { + return traits_type::eof(); + } + return traits_type::to_int_type(*this->gptr()); + } + + pos_type seekoff( + off_type off, std::ios_base::seekdir dir, std::ios_base::openmode which + ) override { + if (!(which & std::ios_base::in)) { + return pos_type(off_type(-1)); + } + + const off_type current = static_cast(this->gptr() - this->eback()); + const off_type end = static_cast(this->egptr() - this->eback()); + + off_type target = 0; + switch (dir) { + case std::ios_base::beg: + target = off; + break; + case std::ios_base::cur: + target = current + off; + break; + case std::ios_base::end: + target = end + off; + break; + default: + return pos_type(off_type(-1)); + } + + if (target < 0 || target > end) { + return pos_type(off_type(-1)); + } + + this->setg(this->eback(), this->eback() + target, this->egptr()); + return pos_type(target); + } + + pos_type seekpos(pos_type sp, std::ios_base::openmode which) override { + return seekoff(static_cast(sp), std::ios_base::beg, which); + } + + int_type overflow(int_type) override { + return Traits::eof(); // disallow writing + } + + private: + MMapPtr ptr_; + // A dummy character to use as the put area for the streambuf when the mapping is empty + // or closed. This is necessary to ensure that the put area is always valid, even when + // the mapping is empty or closed. + char_type empty_ = char_type{}; +}; + +template > +class basic_mmstream : public std::basic_istream { + public: + using streambuf_type = basic_mmstreambuf; + + basic_mmstream() + : std::basic_istream(nullptr) { + this->init(&buf_); + } + + explicit basic_mmstream(MMapPtr mapping) + : std::basic_istream(nullptr) + , buf_(std::move(mapping)) { + this->init(&buf_); + } + + explicit basic_mmstream( + const std::filesystem::path& path, + std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out + ) + : basic_mmstream(MMapPtr{}) { + open(path, mode); + } + + void open( + const std::filesystem::path& path, + std::ios_base::openmode mode = std::ios_base::in | std::ios_base::out + ) { + buf_.open(path, mode); + this->clear(); + } + + void close() noexcept { + buf_.close(); + this->setstate(std::ios_base::eofbit); + } + + [[nodiscard]] bool is_open() const noexcept { return buf_.is_open(); } + + [[nodiscard]] std::size_t size() const noexcept { return buf_.size(); } + + [[nodiscard]] streambuf_type* rdbuf() noexcept { return &buf_; } + + private: + streambuf_type buf_; +}; + +using mmstreambuf = basic_mmstreambuf; +using mmstream = basic_mmstream; + +#if SVS_HAS_STD_SPANSTREAM + +template > +using basic_spanbuf = std::basic_spanbuf; + +template > +using basic_ispanstream = std::basic_ispanstream; + +#else + +template > +class basic_spanbuf : public std::basic_streambuf { + static_assert(sizeof(CharT) == 1, "basic_spanbuf requires a 1-byte character type."); + + public: + using char_type = CharT; + using traits_type = Traits; + using int_type = typename traits_type::int_type; + using pos_type = typename traits_type::pos_type; + using off_type = typename traits_type::off_type; + using span_type = std::span; + + basic_spanbuf() + : basic_spanbuf(span_type{}) {} + + explicit basic_spanbuf(span_type s) { span(s); } + + /// Returns the underlying span. + [[nodiscard]] span_type span() const noexcept { return data_; } + + /// Updates the underlying span and resets the read position to the beginning. + void span(span_type s) noexcept { + data_ = s; + if (data_.empty()) { + this->setg(&empty_, &empty_, &empty_); + } else { + auto* begin = data_.data(); + this->setg(begin, begin, begin + data_.size()); + } + this->setp(&empty_, &empty_); // disallow writing + } + + protected: + int_type overflow(int_type) override { + return traits_type::eof(); // disallow writing + } + + std::basic_streambuf* setbuf(char_type* s, std::streamsize n) override { + span(span_type{s, static_cast(n)}); + return this; + } + + pos_type seekoff( + off_type off, std::ios_base::seekdir dir, std::ios_base::openmode which + ) override { + if (!(which & std::ios_base::in)) { + return pos_type(off_type(-1)); + } + + const off_type current = static_cast(this->gptr() - this->eback()); + const off_type end = static_cast(this->egptr() - this->eback()); + + off_type target = 0; + switch (dir) { + case std::ios_base::beg: + target = off; + break; + case std::ios_base::cur: + target = current + off; + break; + case std::ios_base::end: + target = end + off; + break; + default: + return pos_type(off_type(-1)); + } + + if (target < 0 || target > end) { + return pos_type(off_type(-1)); + } + + this->setg(this->eback(), this->eback() + target, this->egptr()); + return pos_type(target); + } + + pos_type seekpos(pos_type sp, std::ios_base::openmode which) override { + return seekoff(static_cast(sp), std::ios_base::beg, which); + } + + private: + span_type data_; + char_type empty_ = char_type{}; +}; + +template > +class basic_ispanstream : public std::basic_istream { + public: + using char_type = CharT; + using traits_type = Traits; + using int_type = typename traits_type::int_type; + using pos_type = typename traits_type::pos_type; + using off_type = typename traits_type::off_type; + using streambuf_type = basic_spanbuf; + using span_type = typename streambuf_type::span_type; + + basic_ispanstream() + : std::basic_istream(nullptr) { + this->init(&buf_); + } + + explicit basic_ispanstream(span_type span) + : std::basic_istream(nullptr) + , buf_(span) { + this->init(&buf_); + } + + span_type span() const noexcept { return buf_.span(); } + void span(span_type s) noexcept { + buf_.span(s); + this->clear(); + } + + [[nodiscard]] streambuf_type* rdbuf() noexcept { return &buf_; } + + private: + streambuf_type buf_; +}; + +#endif + +using spanbuf = basic_spanbuf; +using ispanstream = basic_ispanstream; + +/// Returns true if @p stream is backed entirely by an in-memory buffer. +/// +/// Specifically, returns true when the stream's streambuf is either: +/// - a @c basic_mmstreambuf (memory-mapped file), or +/// - a @c basic_spanbuf (non-owning in-memory span), or +/// - a @c std::basic_stringbuf (std::istringstream / std::stringstream). +template > +[[nodiscard]] bool is_memory_stream(std::basic_istream& stream) noexcept { + auto* buf = stream.rdbuf(); + if (buf == nullptr) { + return false; + } + if (dynamic_cast*>(buf) != nullptr) { + return true; + } + if (dynamic_cast*>(buf) != nullptr) { + return true; + } + if (dynamic_cast*>(buf) != nullptr) { + return true; + } + return false; +} + +namespace detail { + +// A minimal accessor that promotes the protected gptr() method of +// std::basic_streambuf to public visibility. It adds no data members and no +// virtual functions, so the static_cast below is layout-safe (gptr() reads only +// base-class internal pointers). +template +struct StreambufAccessor : std::basic_streambuf { + static CharT* get(std::basic_streambuf* buf) noexcept { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + return static_cast(buf)->gptr(); + } + + static CharT* begin(std::basic_streambuf* buf) noexcept { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + return static_cast(buf)->eback(); + } + + static CharT* end(std::basic_streambuf* buf) noexcept { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + return static_cast(buf)->egptr(); + } +}; + +} // namespace detail + +/// Returns a typed pointer to the current read position of an in-memory stream. +/// +/// Works for: +/// - @c basic_mmstreambuf-backed streams (memory-mapped files via @c basic_mmstream) +/// - @c std::basic_stringbuf-backed streams (@c std::istringstream / @c +/// std::stringstream) +/// +/// @tparam T Element type to interpret the raw bytes at the current position as. +/// @tparam CharT Character type of the stream (must be 1-byte wide). +/// @tparam Traits Character traits of the stream. +/// @param stream The input stream to query. +/// @returns A pointer of type @c T* to the current read position, +/// or @c nullptr if the stream is not in-memory or has no streambuf. +template > +[[nodiscard]] T* current_ptr(std::basic_istream& stream) noexcept { + static_assert(sizeof(CharT) == 1, "current_ptr requires a 1-byte character type."); + if (!is_memory_stream(stream)) { + return nullptr; + } + + auto* buf = stream.rdbuf(); + auto begin = detail::StreambufAccessor::begin(buf); + auto end = detail::StreambufAccessor::end(buf); + if (begin == end) { + return nullptr; + } + auto raw = detail::StreambufAccessor::get(buf); + + // Return nullptr if the current position is misaligned for the requested type T, to + // avoid undefined behavior on dereference. + if (reinterpret_cast(raw) % alignof(T) != 0) { + assert( + false && "current_ptr: current position is misaligned for the requested type T" + ); + return nullptr; + } + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return reinterpret_cast(raw); +} + +/// Returns a typed pointer to the beginning of the read position of an in-memory stream. +/// +/// Works for: +/// - @c basic_mmstreambuf-backed streams (memory-mapped files via @c basic_mmstream) +/// - @c std::basic_stringbuf-backed streams (@c std::istringstream / @c +/// std::stringstream) +/// +/// @tparam T Element type to interpret the raw bytes at the current position as. +/// @tparam CharT Character type of the stream (must be 1-byte wide). +/// @tparam Traits Character traits of the stream. +/// @param stream The input stream to query. +/// @returns A pointer of type @c T* to the beginning of the read position, +/// or @c nullptr if the stream is not in-memory or has no streambuf. +template > +[[nodiscard]] T* begin_ptr(std::basic_istream& stream) noexcept { + static_assert(sizeof(CharT) == 1, "begin_ptr requires a 1-byte character type."); + if (!is_memory_stream(stream)) { + return nullptr; + } + + auto* buf = stream.rdbuf(); + auto begin = detail::StreambufAccessor::begin(buf); + auto end = detail::StreambufAccessor::end(buf); + if (begin == end) { + return nullptr; + } + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return reinterpret_cast(begin); +} + +/// Returns a typed pointer to the end of the read position of an in-memory stream. +/// +/// Works for: +/// - @c basic_mmstreambuf-backed streams (memory-mapped files via @c basic_mmstream) +/// - @c std::basic_stringbuf-backed streams (@c std::istringstream / @c +/// std::stringstream) +/// +/// @tparam T Element type to interpret the raw bytes at the current position as. +/// @tparam CharT Character type of the stream (must be 1-byte wide). +/// @tparam Traits Character traits of the stream. +/// @param stream The input stream to query. +/// @returns A pointer of type @c T* to the end of the read position, +/// or @c nullptr if the stream is not in-memory or has no streambuf. +template > +[[nodiscard]] T* end_ptr(std::basic_istream& stream) noexcept { + static_assert(sizeof(CharT) == 1, "end_ptr requires a 1-byte character type."); + if (!is_memory_stream(stream)) { + return nullptr; + } + + auto* buf = stream.rdbuf(); + auto begin = detail::StreambufAccessor::begin(buf); + auto end = detail::StreambufAccessor::end(buf); + if (begin == end) { + return nullptr; + } + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return reinterpret_cast(end); +} + +/// @brief Memory-stream allocator that allocates memory from in-memory streams. +/// +/// This is used to construct SVS data structures directly on memory-mapped files or +/// in-memory buffers, without needing to copy data out of the stream into separately +/// allocated memory. +/// +/// The allocator does not take ownership of the memory; the caller is responsible for +/// ensuring the memory remains valid for the lifetime of any pointers returned by this +/// allocator. +template > +struct MemoryStreamAllocator { + using value_type = T; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using pointer = T*; + using const_pointer = const T*; + using reference = T&; + using const_reference = const T&; + + using stream_type = std::basic_istream; + + MemoryStreamAllocator() = default; + + MemoryStreamAllocator(stream_type& stream) + : stream_(&stream) { + if (!is_memory_stream(*stream_)) { + throw std::invalid_argument( + "MemoryStreamAllocator requires a memory-backed stream." + ); + } + } + + template + MemoryStreamAllocator(const MemoryStreamAllocator& other) + : stream_(&other.stream()) {} + + [[nodiscard]] pointer allocate(size_type n) { + if (stream_ == nullptr) { + throw std::runtime_error("MemoryStreamAllocator is not properly initialized."); + } + T* current = current_ptr(*stream_); + if (current == nullptr) { + throw std::runtime_error("Failed to obtain current pointer from memory stream." + ); + } + pointer result = current; + + // check for overflow: + auto off = lib::narrow(n * sizeof(T)); + + stream_->seekg(off, std::ios_base::cur); + if (!*stream_) { + throw std::runtime_error("Failed to advance memory stream after allocation."); + } + return result; + } + + void deallocate(pointer, size_type) noexcept { + // No-op since we don't own the memory. + } + + stream_type& stream() const noexcept { return *stream_; } + + private: + stream_type* stream_ = nullptr; +}; + +} // namespace io + +template +inline constexpr bool is_view_type_v> = true; + +} // namespace svs diff --git a/include/svs/lib/array.h b/include/svs/lib/array.h index fef4b3f81..90120e7e5 100644 --- a/include/svs/lib/array.h +++ b/include/svs/lib/array.h @@ -156,6 +156,9 @@ template struct View { template inline constexpr bool is_view_type_v = false; template inline constexpr bool is_view_type_v> = true; +template +concept ViewAllocator = is_view_type_v; + namespace array_impl { // Shared implementations across various DenseArray specializations. @@ -507,7 +510,7 @@ template > class D [[no_unique_address]] Alloc allocator_; }; -template class DenseArray> { +template class DenseArray { private: // N.B.: This is an important assumption for many algorithms of this type. // Don't remove this requirement without careful consideration. @@ -517,6 +520,7 @@ template class DenseArray> { static constexpr bool is_const = std::is_const_v; ///// Allocator Aware + using allocator_type = Alloc; using pointer = T*; using const_pointer = const T*; @@ -533,6 +537,9 @@ template class DenseArray> { using const_span = std::span>; using span = std::span>; + // Get the underlying allocator. + const allocator_type& get_allocator() const { return allocator_; } + /// @brief Return the extent of the span returned for `slice`. static constexpr size_t extent() { return array_impl::extent(); } @@ -647,11 +654,21 @@ template class DenseArray> { ///// explicit DenseArray(Dims dims, pointer ptr) + requires(std::is_same_v>) : pointer_{ptr} - , dims_{std::move(dims)} {} + , dims_{std::move(dims)} + , allocator_{ptr} {} - explicit DenseArray(Dims dims, View view) - : DenseArray{std::move(dims), view.ptr} {} + explicit DenseArray(Dims dims, Alloc allocator) + : pointer_{nullptr} + , dims_{std::move(dims)} + , allocator_{allocator} { + if constexpr (std::is_same_v>) { + pointer_ = allocator_.ptr; + } else { + pointer_ = std::allocator_traits::allocate(allocator_, size()); + } + } // Iterator pointer begin() @@ -682,6 +699,7 @@ template class DenseArray> { private: pointer pointer_{nullptr}; [[no_unique_address]] Dims dims_{}; + [[no_unique_address]] Alloc allocator_; }; template diff --git a/include/svs/lib/misc.h b/include/svs/lib/misc.h index 0eccd07e2..0edda83f8 100644 --- a/include/svs/lib/misc.h +++ b/include/svs/lib/misc.h @@ -40,8 +40,13 @@ namespace svs::lib { struct ZeroInitializer {}; /// @brief Get the full type of the allocator `Alloc` rebound to a value to `To`. +namespace detail { +template struct AllocatorRebinder { + using type = typename std::allocator_traits::template rebind_alloc; +}; +} // namespace detail template -using rebind_allocator_t = typename std::allocator_traits::template rebind_alloc; +using rebind_allocator_t = typename detail::AllocatorRebinder::type; /// @brief Rebind an allocator to a new value type. template diff --git a/include/svs/orchestrators/vamana.h b/include/svs/orchestrators/vamana.h index ad12fd8c3..c4c4422b6 100644 --- a/include/svs/orchestrators/vamana.h +++ b/include/svs/orchestrators/vamana.h @@ -469,7 +469,14 @@ class Vamana : public manager::IndexManager { auto deserializer = svs::lib::detail::Deserializer::build(stream); if (deserializer.is_native()) { auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); - using GraphType = svs::GraphLoader<>::return_type; + + using GraphType = std::conditional_t< + is_view_type_v, + graphs::SimpleGraph< + uint32_t, + lib::rebind_allocator_t>, + GraphLoader<>::return_type>; + if constexpr (std::is_same_v) { auto dispatcher = DistanceDispatcher(distance); return dispatcher([&](auto distance_function) { diff --git a/include/svs/quantization/scalar/scalar.h b/include/svs/quantization/scalar/scalar.h index 5998e76a7..1ac48ca34 100644 --- a/include/svs/quantization/scalar/scalar.h +++ b/include/svs/quantization/scalar/scalar.h @@ -198,7 +198,7 @@ namespace detail { struct MinMaxAccumulator { float min = std::numeric_limits::max(); - float max = std::numeric_limits::min(); + float max = std::numeric_limits::lowest(); void accumulate(float val) { min = std::min(min, val); @@ -369,8 +369,8 @@ class SQDataset { using allocator_type = Alloc; using element_type = T; using data_type = data::SimpleData; - using const_value_type = std::span; - using value_type = const_value_type; + using const_value_type = typename data_type::const_value_type; + using value_type = typename data_type::value_type; // Data wrapped in the library allocator. using lib_alloc_data_type = SQDataset>; @@ -399,6 +399,7 @@ class SQDataset { float get_bias() const { return bias_; } const_value_type get_datum(size_t i) const { return data_.get_datum(i); } + value_type get_datum(size_t i) { return data_.get_datum(i); } std::vector decompress_datum(size_t i) const { auto datum = get_datum(i); @@ -510,22 +511,20 @@ class SQDataset { void save(std::ostream& os) const { data_.save(os); } /// @brief Load dataset from a file. - static SQDataset - load(const lib::LoadTable& table, const allocator_type& allocator = {}) { + template + static SQDataset load(const lib::LoadTable& table, Args&&... args) { return SQDataset{ - SVS_LOAD_MEMBER_AT_(table, data, allocator), + SVS_LOAD_MEMBER_AT_(table, data, std::forward(args)...), lib::load_at(table, "scale"), lib::load_at(table, "bias")}; } /// @brief Load dataset from a stream. - static SQDataset load( - const lib::ContextFreeLoadTable& table, - std::istream& is, - const allocator_type& allocator = {} - ) { + template + static SQDataset + load(const lib::ContextFreeLoadTable& table, std::istream& is, Args&&... args) { return SQDataset{ - SVS_LOAD_MEMBER_AT_(table, data, is, allocator), + SVS_LOAD_MEMBER_AT_(table, data, is, std::forward(args)...), lib::load_at(table, "scale"), lib::load_at(table, "bias")}; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 265b37ef8..8c812d35a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -117,6 +117,7 @@ set(TEST_SOURCES ${TEST_DIR}/svs/core/io/vecs.cpp ${TEST_DIR}/svs/core/io/binary.cpp ${TEST_DIR}/svs/core/io/native.cpp + ${TEST_DIR}/svs/core/io/memstream.cpp ${TEST_DIR}/svs/core/io.cpp ${TEST_DIR}/svs/core/loading.cpp ${TEST_DIR}/svs/core/logging.cpp diff --git a/tests/svs/core/data/simple.cpp b/tests/svs/core/data/simple.cpp index 1371d7314..2e873ef37 100644 --- a/tests/svs/core/data/simple.cpp +++ b/tests/svs/core/data/simple.cpp @@ -23,6 +23,7 @@ // stdlib #include +#include #include // catch2 @@ -214,4 +215,51 @@ CATCH_TEST_CASE("Testing Simple Data", "[core][data]") { svs::data::ConstSimpleDataView(y), svs::ANNException ); } + + CATCH_SECTION("Load SimpleDataView from stringstream") { + auto src = svs::data::SimpleData(5, 4); + svs_test::data::set_sequential(src); + + // Save binary data to a stringstream and seek to the beginning for reading. + auto ss = std::stringstream{}; + src.save(ss); + ss.seekg(0); + + // Build a ContextFreeLoadTable from the separately obtained metadata. + auto meta = src.metadata(); + auto table = svs::lib::ContextFreeLoadTable(meta.get()); + + // Capture the pointer to the current read position inside the stream buffer. + auto* expected_ptr = svs::io::current_ptr(ss); + CATCH_REQUIRE(expected_ptr != nullptr); + + // Load the view — data_ must point into the stringstream's internal buffer. + auto view = svs::data:: + SimpleData>::load( + table, ss + ); + CATCH_REQUIRE(view.data() == expected_ptr); + CATCH_REQUIRE(view.size() == src.size()); + CATCH_REQUIRE(view.dimensions() == src.dimensions()); + CATCH_REQUIRE(svs_test::data::is_sequential(view)); + } + + CATCH_SECTION("Load SimpleDataView throws on non-memory stream") { + auto src = svs::data::SimpleData(3, 2); + auto meta = src.metadata(); + auto table = svs::lib::ContextFreeLoadTable(meta.get()); + + // A std::istringstream backed by a pre-built string is a memory stream; + // use a custom non-memory streambuf to trigger the exception path. + struct NonMemStreamBuf : std::streambuf {}; + NonMemStreamBuf buf; + std::istream non_mem_stream(&buf); + CATCH_REQUIRE_THROWS_AS( + (svs::data::SimpleData< + float, + svs::Dynamic, + svs::io::MemoryStreamAllocator>::load(table, non_mem_stream)), + svs::ANNException + ); + } } diff --git a/tests/svs/core/io/memstream.cpp b/tests/svs/core/io/memstream.cpp new file mode 100644 index 000000000..7f3aa1599 --- /dev/null +++ b/tests/svs/core/io/memstream.cpp @@ -0,0 +1,353 @@ +/* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "svs/core/io/memstream.h" + +#include "tests/utils/utils.h" + +#include "catch2/catch_test_macros.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { +std::filesystem::path write_file(const std::string& name, const std::string& contents) { + auto path = svs_test::prepare_temp_directory_v2() / name; + auto out = std::ofstream(path, std::ios::binary); + out << contents; + out.close(); + return path; +} + +std::filesystem::path create_empty_file(const std::string& name) { + auto path = svs_test::prepare_temp_directory_v2() / name; + std::ofstream(path, std::ios::binary).close(); + return path; +} +} // namespace + +CATCH_TEST_CASE("mmstreambuf reads and seeks", "[core][io][mmap]") { + auto path = write_file("mmstream_data.bin", "0123456789"); + + auto buf = svs::io::mmstreambuf{}; + buf.open(path); + CATCH_REQUIRE(buf.is_open()); + CATCH_REQUIRE(buf.size() == 10); + + CATCH_REQUIRE(buf.sgetc() == '0'); + CATCH_REQUIRE(buf.sbumpc() == '0'); + CATCH_REQUIRE(buf.sgetc() == '1'); + + CATCH_REQUIRE(buf.pubseekoff(4, std::ios_base::beg, std::ios_base::in) == 4); + CATCH_REQUIRE(buf.sgetc() == '4'); + + CATCH_REQUIRE(buf.pubseekoff(-1, std::ios_base::cur, std::ios_base::in) == 3); + CATCH_REQUIRE(buf.sgetc() == '3'); + + CATCH_REQUIRE(buf.pubseekoff(-1, std::ios_base::beg, std::ios_base::in) == -1); + CATCH_REQUIRE(buf.pubseekoff(1, std::ios_base::end, std::ios_base::in) == -1); + + CATCH_REQUIRE(buf.pubseekpos(9, std::ios_base::in) == 9); + CATCH_REQUIRE(buf.sgetc() == '9'); + + CATCH_REQUIRE(buf.pubseekpos(10, std::ios_base::in) == 10); + CATCH_REQUIRE(buf.sgetc() == std::char_traits::eof()); + + buf.close(); + CATCH_REQUIRE(!buf.is_open()); + CATCH_REQUIRE(buf.size() == 0); +} + +CATCH_TEST_CASE("mmstreambuf handles empty files", "[core][io][mmap]") { + auto path = create_empty_file("mmstream_empty.bin"); + + auto buf = svs::io::mmstreambuf{}; + CATCH_REQUIRE_THROWS_AS(buf.open(path), svs::lib::ANNException); + CATCH_REQUIRE(!buf.is_open()); + CATCH_REQUIRE(buf.size() == 0); + CATCH_REQUIRE(buf.sgetc() == std::char_traits::eof()); +} + +CATCH_TEST_CASE("mmstreambuf supports move operations", "[core][io][mmap]") { + auto path = write_file("mmstream_move.bin", "abcdef"); + + auto source = svs::io::mmstreambuf{}; + source.open(path); + CATCH_REQUIRE(source.pubseekpos(2, std::ios_base::in) == 2); + + auto moved = svs::io::mmstreambuf(std::move(source)); + CATCH_REQUIRE(moved.is_open()); + CATCH_REQUIRE(!source.is_open()); + CATCH_REQUIRE(moved.sgetc() == 'c'); + + auto assigned = svs::io::mmstreambuf{}; + assigned = std::move(moved); + CATCH_REQUIRE(assigned.is_open()); + CATCH_REQUIRE(!moved.is_open()); + CATCH_REQUIRE(assigned.sgetc() == 'c'); +} + +CATCH_TEST_CASE("mmstream provides istream interface", "[core][io][mmap]") { + auto path = write_file("mmstream_stream.bin", "line1\nline2\n"); + + auto stream = svs::io::mmstream(path); + CATCH_REQUIRE(stream.is_open()); + CATCH_REQUIRE(stream.size() == 12); + + auto line = std::string{}; + std::getline(stream, line); + CATCH_REQUIRE(line == "line1"); + + stream.seekg(0, std::ios_base::beg); + std::getline(stream, line); + CATCH_REQUIRE(line == "line1"); + + stream.seekg(6, std::ios_base::beg); + std::getline(stream, line); + CATCH_REQUIRE(line == "line2"); + + stream.close(); + CATCH_REQUIRE(!stream.is_open()); + CATCH_REQUIRE(stream.eof()); +} + +CATCH_TEST_CASE("current_ptr pointer semantics", "[core][io][mmap]") { + auto path = write_file("mmstream_ptrs.bin", "ABCDE"); + auto stream = svs::io::mmstream(path); + + auto* b = svs::io::current_ptr(stream); + CATCH_REQUIRE(b != nullptr); + CATCH_REQUIRE(*b == 'A'); + + for (std::size_t i = 0; i < stream.size(); ++i) { + CATCH_REQUIRE( + svs::io::current_ptr(stream) == b + static_cast(i) + ); + stream.ignore(1); + } + + stream.seekg(0, std::ios_base::beg); + CATCH_REQUIRE(svs::io::current_ptr(stream) == b); + + stream.seekg(3, std::ios_base::beg); + CATCH_REQUIRE(svs::io::current_ptr(stream) == b + 3); + + stream.close(); + CATCH_REQUIRE(svs::io::current_ptr(stream) == nullptr); +} + +CATCH_TEST_CASE("mmstream open throws on missing file", "[core][io][mmap]") { + auto missing = svs_test::prepare_temp_directory_v2() / "mmstream_missing.bin"; + + auto stream = svs::io::mmstream{}; + CATCH_REQUIRE_THROWS_AS(stream.open(missing), svs::lib::ANNException); +} + +CATCH_TEST_CASE("is_memory_stream", "[core][io][mmap]") { + // mmstream is an in-memory stream. + auto path = write_file("mmstream_inmem.bin", "hello"); + auto mm = svs::io::mmstream(path); + CATCH_REQUIRE(svs::io::is_memory_stream(mm)); + + // std::istringstream is an in-memory stream. + auto iss = std::istringstream("world"); + CATCH_REQUIRE(svs::io::is_memory_stream(iss)); + + // std::stringstream is an in-memory stream. + auto ss = std::stringstream("test"); + CATCH_REQUIRE(svs::io::is_memory_stream(ss)); + + // svs::io::spanstream is an in-memory stream. + char buffer[] = "span"; + auto span = svs::io::ispanstream(std::span{buffer}); + CATCH_REQUIRE(svs::io::is_memory_stream(span)); + + // std::ifstream is NOT an in-memory stream. + auto ifs = std::ifstream(path); + CATCH_REQUIRE(!svs::io::is_memory_stream(ifs)); +} + +CATCH_TEST_CASE("current_ptr", "[core][io][mmap]") { + // ---- mmstream ---- + // Write 3 floats to a binary file and open it as an mmstream. + const std::array data = {1.0f, 2.0f, 3.0f}; + auto path = svs_test::prepare_temp_directory_v2() / "memstream_ptr.bin"; + { + auto out = std::ofstream(path, std::ios::binary); + out.write(reinterpret_cast(data.data()), sizeof(data)); + } + + auto mm = svs::io::mmstream(path); + auto* p0 = svs::io::current_ptr(mm); + CATCH_REQUIRE(p0 != nullptr); + CATCH_REQUIRE(*p0 == data[0]); + + // Reading one float worth of bytes advances current() by sizeof(float). + mm.ignore(static_cast(sizeof(float))); + auto* p1 = svs::io::current_ptr(mm); + CATCH_REQUIRE(p1 == p0 + 1); + CATCH_REQUIRE(*p1 == data[1]); + + // Seeking back to the start returns the same base pointer. + mm.seekg(0, std::ios_base::beg); + CATCH_REQUIRE(svs::io::current_ptr(mm) == p0); + + // After close the stream is not in-memory anymore: returns nullptr. + mm.close(); + CATCH_REQUIRE(svs::io::current_ptr(mm) == nullptr); + + // ---- std::istringstream ---- + // Build a string with known byte content, then read it back as chars. + const std::string text = "ABCDE"; + auto iss = std::istringstream(text); + auto* cp0 = svs::io::current_ptr(iss); + CATCH_REQUIRE(cp0 != nullptr); + CATCH_REQUIRE(*cp0 == 'A'); + + iss.ignore(2); + auto* cp1 = svs::io::current_ptr(iss); + CATCH_REQUIRE(cp1 == cp0 + 2); + CATCH_REQUIRE(*cp1 == 'C'); + + // ---- std::ifstream — not in-memory: must return nullptr ---- + { + auto ifs = std::ifstream(path, std::ios::binary); + CATCH_REQUIRE(svs::io::current_ptr(ifs) == nullptr); + } + + // ---- empty std::istringstream: in-memory but empty, must return nullptr ---- + { + auto empty_iss = std::istringstream(""); + CATCH_REQUIRE(svs::io::is_memory_stream(empty_iss)); + CATCH_REQUIRE(svs::io::current_ptr(empty_iss) == nullptr); + } +} + +CATCH_TEST_CASE("spanstream current_ptr", "[core][io][mmap]") { + char text[] = + "Hello, world!"; // Note: not a string literal, so we can take its address. + auto iss = svs::io::ispanstream(std::span{text}); + auto* base_ptr = svs::io::current_ptr(iss); + CATCH_REQUIRE(base_ptr != nullptr); + CATCH_REQUIRE(*base_ptr == 'H'); + + for (std::size_t i = 0; i < std::strlen(text); ++i) { + auto* current = svs::io::current_ptr(iss); + auto* expected = text + i; + auto match = (current == expected) && (*current == text[i]); + CATCH_REQUIRE(match); + iss.ignore(1); + } + + // After reading all characters, current_ptr should point to the null terminator. + CATCH_REQUIRE(svs::io::current_ptr(iss) == text + std::strlen(text)); + CATCH_REQUIRE(*svs::io::current_ptr(iss) == '\0'); +} + +CATCH_TEST_CASE("ispanstream span() getter and setter", "[core][io][mmap]") { + char text1[] = "First"; + char text2[] = "Second"; + + auto stream = svs::io::ispanstream(std::span{text1}); + + // Test getter + auto s1 = stream.rdbuf()->span(); + CATCH_REQUIRE(s1.data() == text1); + CATCH_REQUIRE(s1.size() == 6); + + // Test setter with new span + stream.rdbuf()->span(std::span{text2}); + auto s2 = stream.rdbuf()->span(); + CATCH_REQUIRE(s2.data() == text2); + CATCH_REQUIRE(s2.size() == 7); + + // Verify position resets to beginning + CATCH_REQUIRE(svs::io::current_ptr(stream) == text2); + CATCH_REQUIRE(*svs::io::current_ptr(stream) == 'S'); +} + +CATCH_TEST_CASE("ispanstream with empty span", "[core][io][mmap]") { + std::span empty; + auto stream = svs::io::ispanstream(empty); + + CATCH_REQUIRE(stream.rdbuf()->span().empty()); + CATCH_REQUIRE(svs::io::is_memory_stream(stream)); + CATCH_REQUIRE(svs::io::current_ptr(stream) == nullptr); + CATCH_REQUIRE(stream.rdbuf()->span().size() == 0); + + // Setting non-empty span should work + char text[] = "data"; + stream.rdbuf()->span(std::span{text}); + CATCH_REQUIRE(!stream.rdbuf()->span().empty()); + CATCH_REQUIRE(svs::io::current_ptr(stream) == text); +} + +CATCH_TEST_CASE("MemoryStreamAllocator", "[core][io][mmap]") { + // Create a buffer with float data + float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + auto run_allocator_checks = [&](std::istream& stream) { + auto allocator = svs::io::MemoryStreamAllocator(stream); + + // Allocate 3 times, each time 2 floats + auto* p0 = allocator.allocate(2); + auto* p1 = allocator.allocate(2); + auto* p2 = allocator.allocate(2); + + // Verify pointers are contiguous + CATCH_REQUIRE(p1 == p0 + 2); + CATCH_REQUIRE(p2 == p1 + 2); + CATCH_REQUIRE(p2 == p0 + 4); + + // Verify data integrity + CATCH_REQUIRE(p0[0] == data[0]); + CATCH_REQUIRE(p0[1] == data[1]); + CATCH_REQUIRE(p1[0] == data[2]); + CATCH_REQUIRE(p1[1] == data[3]); + CATCH_REQUIRE(p2[0] == data[4]); + CATCH_REQUIRE(p2[1] == data[5]); + }; + + CATCH_SECTION("mmstream") { + auto path = svs_test::prepare_temp_directory_v2() / "allocator_contiguous.bin"; + { + auto out = std::ofstream(path, std::ios::binary); + out.write(reinterpret_cast(data), sizeof(data)); + } + auto stream = svs::io::mmstream(path); + run_allocator_checks(stream); + } + + CATCH_SECTION("ispanstream") { + auto bytes = std::span{reinterpret_cast(data), sizeof(data)}; + auto stream = svs::io::ispanstream(bytes); + run_allocator_checks(stream); + } + + CATCH_SECTION("std::stringstream") { + auto stream = std::stringstream(std::ios::in | std::ios::out | std::ios::binary); + stream.write(reinterpret_cast(data), sizeof(data)); + run_allocator_checks(stream); + } +} diff --git a/tests/svs/index/flat/flat.cpp b/tests/svs/index/flat/flat.cpp index a5e1c801f..c43f1cab3 100644 --- a/tests/svs/index/flat/flat.cpp +++ b/tests/svs/index/flat/flat.cpp @@ -21,6 +21,7 @@ #include "svs/orchestrators/exhaustive.h" // tests +#include "tests/utils/generators.h" #include "tests/utils/test_dataset.h" // catch2 @@ -81,6 +82,9 @@ CATCH_TEST_CASE("Flat Index Save and Load", "[flat][index][saveload]") { // Load test data auto data = Data_t::load(test_dataset::data_svs_file()); auto queries = test_dataset::queries(); + auto data_id_generator = svs_test::make_generator(size_t{0}, data.size() - 1); + auto query_id_generator = + svs_test::make_generator(size_t{0}, queries.size() - 1); // Build index Distance_t dist; @@ -144,4 +148,170 @@ CATCH_TEST_CASE("Flat Index Save and Load", "[flat][index][saveload]") { } } } + + CATCH_SECTION("Load with pointing to in-memory stream buffer") { + // We will load the FlatIndex's data as a SimpleDataView directly from the stream, + // without copying. + using ViewData_t = svs::data::SimpleData< + Data_t::element_type, + svs::Dynamic, + svs::io::MemoryStreamAllocator>; + + // Save the full index to a stringstream. + auto ss = std::stringstream{}; + index.save(ss); + ss.seekg(0); + + // Load the FlatIndex from the stream. + auto loaded_index = svs::Flat::assemble( + ss, dist, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + + // We cannot extract the pointer to the FlatIndex's internal data directly. + // To validate if the loaded Flat index is zero-copy, + // we will load a separate SimpleDataView, modify the view's data and check if it + // reflects in the loaded index's data. Load a SimpleDataView (zero-copy): its data_ + // must point into ss's buffer. We should follow the stream layout written by + // FlatIndex::assemble: + ss.seekg(0); + // First: load deserializer. + auto deserializer = svs::lib::detail::Deserializer::build(ss); + CATCH_REQUIRE(deserializer.is_native()); + // Second: load vectors data + auto view = svs::lib::load_from_stream(ss); + + CATCH_REQUIRE(view.size() == index.size()); + CATCH_REQUIRE(view.dimensions() == index.dimensions()); + // Check if view's data pointer points into the stringstream's internal buffer + // (i.e., zero-copy). + CATCH_REQUIRE(view.data() > svs::io::begin_ptr(ss)); + CATCH_REQUIRE(view.data() < svs::io::end_ptr(ss)); + // Now update the view's data and check if it reflects in the loaded index (since it + // should be zero-copy). For that we will copy a vector from queries into the view's + // data and check if the get_distance() result changes accordingly. + + // Randomly select a data point to modify. + auto data_index = svs_test::generate(data_id_generator); + // Randomly select a query to test against. + auto query_index = svs_test::generate(query_id_generator); + auto original_distance = + loaded_index.get_distance(data_index, queries.get_datum(query_index)); + // Verify that original distance is correct before modification. + CATCH_REQUIRE( + original_distance == Catch::Approx(svs::distance::compute( + dist, + view.get_datum(data_index), + queries.get_datum(query_index) + )) + .epsilon(1e-5) + ); + // Modify the view's data by copying a query vector into it. + view.set_datum(data_index, queries.get_datum(query_index)); + // Now the distance from the modified data point to the query should be zero (or + // very close to zero due to floating point precision), since we copied the query + // vector into the data point. + auto modified_distance = + loaded_index.get_distance(data_index, queries.get_datum(query_index)); + CATCH_REQUIRE(modified_distance == Catch::Approx(0.0).epsilon(1e-5)); + } + + CATCH_SECTION("Load with SimpleDataView pointing to memory mapped file") { + using ViewData_t = svs::data::SimpleData< + Data_t::element_type, + svs::Dynamic, + svs::io::MemoryStreamAllocator>; + + svs::lib::UniqueTempDirectory tempdir{"svs_flat_save"}; + auto index_path = tempdir.get() / "index.bin"; + auto os = std::ofstream{index_path, std::ios::binary}; + index.save(os); + os.close(); + + auto index_is = svs::io::mmstream(index_path); + // Load the FlatIndex from the stream. + auto loaded_index = svs::Flat::assemble( + index_is, dist, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), num_neighbors); + loaded_index.search(loaded_results.view(), queries.cview(), {}); + + // Compare results - should be identical + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < num_neighbors; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + + // We cannot extract the pointer to the FlatIndex's internal data directly. + // To validate if the loaded Flat index is zero-copy, + // we will load a separate SimpleDataView, modify the view's data and check if it + // reflects in the loaded index's data. Load a SimpleDataView (zero-copy): its data_ + // must point into ss's buffer. We should follow the stream layout written by + // FlatIndex::assemble: + auto view_is = svs::io::mmstream(index_path); + // First: load deserializer. + auto deserializer = svs::lib::detail::Deserializer::build(view_is); + CATCH_REQUIRE(deserializer.is_native()); + // Second: load vectors data + auto view = svs::lib::load_from_stream(view_is); + + CATCH_REQUIRE(view.size() == index.size()); + CATCH_REQUIRE(view.dimensions() == index.dimensions()); + // Check if view's data pointer points into the stringstream's internal buffer + // (i.e., zero-copy). + CATCH_REQUIRE(view.data() > svs::io::begin_ptr(view_is)); + CATCH_REQUIRE(view.data() < svs::io::end_ptr(view_is)); + // Now update the view's data and check if it reflects in the loaded index (since it + // should be zero-copy). For that we will copy a vector from queries into the view's + // data and check if the get_distance() result changes accordingly. + + // Randomly select a data point to modify. + auto data_index = svs_test::generate(data_id_generator); + // Randomly select a query to test against. + auto query_index = svs_test::generate(query_id_generator); + auto original_distance = + loaded_index.get_distance(data_index, queries.get_datum(query_index)); + // Verify that original distance is correct before modification. + CATCH_REQUIRE( + original_distance == Catch::Approx(svs::distance::compute( + dist, + view.get_datum(data_index), + queries.get_datum(query_index) + )) + .epsilon(1e-5) + ); + // Modify the view's data by copying a query vector into it. + view.set_datum(data_index, queries.get_datum(query_index)); + // Now the distance from the modified data point to the query should be zero (or + // very close to zero due to floating point precision), since we copied the query + // vector into the data point. + auto modified_distance = + loaded_index.get_distance(data_index, queries.get_datum(query_index)); + CATCH_REQUIRE(modified_distance == Catch::Approx(0.0).epsilon(1e-5)); + } } diff --git a/tests/svs/index/vamana/index.cpp b/tests/svs/index/vamana/index.cpp index bbc535c62..c63bcefe3 100644 --- a/tests/svs/index/vamana/index.cpp +++ b/tests/svs/index/vamana/index.cpp @@ -22,15 +22,18 @@ #include "svs/core/logging.h" // svs +#include "svs/extensions/vamana/scalar.h" #include "svs/index/vamana/build_params.h" #include "svs/lib/preprocessor.h" #include "svs/orchestrators/vamana.h" +#include "svs/quantization/scalar/scalar.h" // catch2 #include "catch2/catch_test_macros.hpp" #include // tests +#include "tests/utils/generators.h" #include "tests/utils/test_dataset.h" #include "tests/utils/utils.h" #include "tests/utils/vamana_reference.h" @@ -187,6 +190,8 @@ CATCH_TEST_CASE("Vamana Index Save and Load", "[vamana][index][saveload]") { const size_t N = 128; using Eltype = float; auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto data_id_generator = svs_test::make_generator(size_t{0}, data.size() - 1); + auto graph = svs::graphs::SimpleGraph(data.size(), 64); svs::distance::DistanceL2 distance_function; uint32_t entry_point = 0; @@ -205,6 +210,9 @@ CATCH_TEST_CASE("Vamana Index Save and Load", "[vamana][index][saveload]") { const size_t NUM_NEIGHBORS = 10; auto queries = test_dataset::queries(); + auto query_id_generator = + svs_test::make_generator(size_t{0}, queries.size() - 1); + auto search_params = svs::index::vamana::VamanaSearchParameters{}; search_params.buffer_config_ = svs::index::vamana::SearchBufferConfig{NUM_NEIGHBORS}; @@ -277,6 +285,194 @@ CATCH_TEST_CASE("Vamana Index Save and Load", "[vamana][index][saveload]") { } } } + + CATCH_SECTION("Load with pointing to in-memory stream buffer") { + // We will load the Vamana index's data as a SimpleDataView directly from the + // stream, without copying. + using ViewData_t = + svs::data::SimpleData>; + using Graph_t = + svs::graphs::SimpleGraph>; + + // Save the full index to a stringstream. + auto ss = std::stringstream{}; + index.save(ss); + + // Load the Vamana index from the stream. + ss.seekg(0); + auto loaded_index = svs::Vamana::assemble( + ss, distance_function, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + loaded_index.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + + // We cannot extract the pointer to the FlatIndex's internal data directly. + // To validate if the loaded Flat index is zero-copy, + // we will load a separate SimpleDataView, modify the view's data and check if it + // reflects in the loaded index's data. Load a SimpleDataView (zero-copy): its data_ + // must point into ss's buffer. We should follow the stream layout written by + // Vamana::assemble: + ss.seekg(0); + // First: load deserializer. + auto deserializer = svs::lib::detail::Deserializer::build(ss); + CATCH_REQUIRE(deserializer.is_native()); + // Following svs::index::vamana::auto_assemble(): + // Second: load config parameters (not strictly necessary to validate the view + // loading, but good to check that we can load the parameters as expected). + auto config_parameters = + svs::lib::load_from_stream(ss); + CATCH_REQUIRE(config_parameters.build_parameters == buildParams); + // Third: load vectors data + auto view = svs::lib::load_from_stream(ss); + CATCH_REQUIRE(view.size() == index.size()); + CATCH_REQUIRE(view.dimensions() == index.dimensions()); + // Fourth: load graph (also not strictly necessary, but good to check that we can + // load the graph as expected). + auto graph = svs::lib::load_from_stream(ss); + CATCH_REQUIRE(graph.n_nodes() == index.size()); + + // Check if view's data pointer points into the stringstream's internal buffer + // (i.e., zero-copy). + CATCH_REQUIRE(view.data() > svs::io::begin_ptr(ss)); + CATCH_REQUIRE(view.data() < svs::io::end_ptr(ss)); + // Now update the view's data and check if it reflects in the loaded index (since it + // should be zero-copy). For that we will copy a vector from queries into the view's + // data and check if the get_distance() result changes accordingly. + + // Randomly select a data point to modify. + auto data_index = svs_test::generate(data_id_generator); + // Randomly select a query to test against. + auto query_index = svs_test::generate(query_id_generator); + auto original_distance = + loaded_index.get_distance(data_index, queries.get_datum(query_index)); + // Verify that original distance is correct before modification. + CATCH_REQUIRE( + original_distance == Catch::Approx(svs::distance::compute( + distance_function, + view.get_datum(data_index), + queries.get_datum(query_index) + )) + .epsilon(1e-5) + ); + // Modify the view's data by copying a query vector into it. + view.set_datum(data_index, queries.get_datum(query_index)); + // Now the distance from the modified data point to the query should be zero (or + // very close to zero due to floating point precision), since we copied the query + // vector into the data point. + auto modified_distance = + loaded_index.get_distance(data_index, queries.get_datum(query_index)); + CATCH_REQUIRE(modified_distance == Catch::Approx(0.0).epsilon(1e-5)); + } + + CATCH_SECTION("Load with SimpleDataView pointing to memory mapped file") { + // We will load the Vamana index's data as a SimpleDataView directly from the + // stream, without copying. + using ViewData_t = + svs::data::SimpleData>; + using Graph_t = + svs::graphs::SimpleGraph>; + + // Save the full index to a file + svs::lib::UniqueTempDirectory tempdir{"svs_flat_save"}; + auto index_path = tempdir.get() / "index.bin"; + auto os = std::ofstream{index_path, std::ios::binary}; + index.save(os); + os.close(); + + auto index_is = svs::io::mmstream(index_path); + + // Load the Vamana index from the stream. + auto loaded_index = svs::Vamana::assemble( + index_is, distance_function, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + loaded_index.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + + // We cannot extract the pointer to the FlatIndex's internal data directly. + // To validate if the loaded Flat index is zero-copy, + // we will load a separate SimpleDataView, modify the view's data and check if it + // reflects in the loaded index's data. Load a SimpleDataView (zero-copy): its data_ + // must point into ss's buffer. We should follow the stream layout written by + // Vamana::assemble: + auto view_is = svs::io::mmstream(index_path); + // First: load deserializer. + auto deserializer = svs::lib::detail::Deserializer::build(view_is); + CATCH_REQUIRE(deserializer.is_native()); + // Following svs::index::vamana::auto_assemble(): + // Second: load config parameters (not strictly necessary to validate the view + // loading, but good to check that we can load the parameters as expected). + auto config_parameters = + svs::lib::load_from_stream(view_is); + CATCH_REQUIRE(config_parameters.build_parameters == buildParams); + // Third: load vectors data + auto view = svs::lib::load_from_stream(view_is); + CATCH_REQUIRE(view.size() == index.size()); + CATCH_REQUIRE(view.dimensions() == index.dimensions()); + // Fourth: load graph (also not strictly necessary, but good to check that we can + // load the graph as expected). + auto graph = svs::lib::load_from_stream(view_is); + CATCH_REQUIRE(graph.n_nodes() == index.size()); + + // Check if view's data pointer points into the stringstream's internal buffer + // (i.e., zero-copy). + CATCH_REQUIRE(view.data() > svs::io::begin_ptr(view_is)); + CATCH_REQUIRE(view.data() < svs::io::end_ptr(view_is)); + // Now update the view's data and check if it reflects in the loaded index (since it + // should be zero-copy). For that we will copy a vector from queries into the view's + // data and check if the get_distance() result changes accordingly. + + // Randomly select a data point to modify. + auto data_index = svs_test::generate(data_id_generator); + // Randomly select a query to test against. + auto query_index = svs_test::generate(query_id_generator); + auto original_distance = + loaded_index.get_distance(data_index, queries.get_datum(query_index)); + // Verify that original distance is correct before modification. + CATCH_REQUIRE( + original_distance == Catch::Approx(svs::distance::compute( + distance_function, + view.get_datum(data_index), + queries.get_datum(query_index) + )) + .epsilon(1e-5) + ); + // Modify the view's data by copying a query vector into it. + view.set_datum(data_index, queries.get_datum(query_index)); + // Now the distance from the modified data point to the query should be zero (or + // very close to zero due to floating point precision), since we copied the query + // vector into the data point. + auto modified_distance = + loaded_index.get_distance(data_index, queries.get_datum(query_index)); + CATCH_REQUIRE(modified_distance == Catch::Approx(0.0).epsilon(1e-5)); + } } CATCH_TEST_CASE("Vamana Index Default Parameters", "[long][parameter][vamana]") { @@ -358,3 +554,128 @@ CATCH_TEST_CASE("Vamana Index Default Parameters", "[long][parameter][vamana]") ); } } + +CATCH_TEST_CASE("Vamana Index Save and Load SQ", "[vamana][index][saveload][scalar]") { + using namespace svs::quantization::scalar; + + const size_t N = 128; + using Eltype = std::int8_t; + using Data_t = SQDataset; + using ViewData_t = + SQDataset>; + using ViewGraph_t = + svs::graphs::SimpleGraph>; + + auto data = Data_t::compress( + svs::data::SimpleData::load(test_dataset::data_svs_file()) + ); + auto data_id_generator = svs_test::make_generator(size_t{0}, data.size() - 1); + + svs::distance::DistanceL2 distance_function; + auto threadpool = svs::threads::DefaultThreadPool(1); + + // Build the VamanaIndex with the test logger + svs::index::vamana::VamanaBuildParameters buildParams(1.2, 64, 10, 20, 10, true); + auto index = svs::Vamana::build( + buildParams, std::move(data), distance_function, std::move(threadpool) + ); + + const size_t NUM_NEIGHBORS = 10; + auto queries = test_dataset::queries(); + auto query_id_generator = + svs_test::make_generator(size_t{0}, queries.size() - 1); + + auto search_params = svs::index::vamana::VamanaSearchParameters{}; + search_params.buffer_config_ = svs::index::vamana::SearchBufferConfig{NUM_NEIGHBORS}; + + auto results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + index.search(results.view(), queries.cview(), search_params); + + CATCH_SECTION("Load with SQ pointing to in-memory stream buffer") { + // We will load the Vamana index's data as a SQDataset 'view' directly from the + // stream, without copying. + // Save the full index to a stringstream. + auto ss = std::stringstream{}; + index.save(ss); + + // Load the Vamana index from the stream. + ss.seekg(0); + auto loaded_index = svs::Vamana::assemble( + ss, distance_function, svs::threads::DefaultThreadPool(1) + ); + + CATCH_REQUIRE(loaded_index.size() == index.size()); + CATCH_REQUIRE(loaded_index.dimensions() == index.dimensions()); + + auto loaded_results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + loaded_index.search(loaded_results.view(), queries.cview(), search_params); + for (size_t q = 0; q < queries.size(); ++q) { + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + CATCH_REQUIRE(loaded_results.index(q, i) == results.index(q, i)); + CATCH_REQUIRE( + loaded_results.distance(q, i) == + Catch::Approx(results.distance(q, i)).epsilon(1e-5) + ); + } + } + + // We cannot extract the pointer to the FlatIndex's internal data directly. + // To validate if the loaded Flat index is zero-copy, + // we will load a separate SimpleDataView, modify the view's data and check if it + // reflects in the loaded index's data. Load a SimpleDataView (zero-copy): its data_ + // must point into ss's buffer. We should follow the stream layout written by + // Vamana::assemble: + ss.seekg(0); + // First: load deserializer. + auto deserializer = svs::lib::detail::Deserializer::build(ss); + CATCH_REQUIRE(deserializer.is_native()); + // Following svs::index::vamana::auto_assemble(): + // Second: load config parameters (not strictly necessary to validate the view + // loading, but good to check that we can load the parameters as expected). + auto config_parameters = + svs::lib::load_from_stream(ss); + CATCH_REQUIRE(config_parameters.build_parameters == buildParams); + // Third: load vectors data + auto view = svs::lib::load_from_stream(ss); + CATCH_REQUIRE(view.size() == index.size()); + CATCH_REQUIRE(view.dimensions() == index.dimensions()); + // Fourth: load graph (also not strictly necessary, but good to check that we can + // load the graph as expected). + auto graph = svs::lib::load_from_stream(ss); + CATCH_REQUIRE(graph.n_nodes() == index.size()); + + // Check if view's data pointer points into the stringstream's internal buffer + // (i.e., zero-copy). + CATCH_REQUIRE(view.get_datum(0).data() > svs::io::begin_ptr(ss)); + CATCH_REQUIRE(view.get_datum(0).data() < svs::io::end_ptr(ss)); + // Now update the view's data and check if it reflects in the loaded index (since it + // should be zero-copy). For that we will copy a vector from queries into the view's + // data and check if the get_distance() result changes accordingly. + + // Randomly select a data point to modify. + auto data_index = svs_test::generate(data_id_generator); + // Randomly select a query to test against. + auto query_index = svs_test::generate(query_id_generator); + auto original_distance = + loaded_index.get_distance(data_index, queries.get_datum(query_index)); + // Verify that original distance is correct before modification. + CATCH_REQUIRE( + original_distance == + Catch::Approx( + svs::index::vamana::extensions::get_distance_ext( + view, distance_function, data_index, queries.get_datum(query_index) + ) + ) + .epsilon(1e-5) + ); + // Modify the view's data by copying a query vector into it. + view.set_datum(data_index, queries.get_datum(query_index)); + // Now the distance from the modified data point to the query should be zero (or + // very close to zero due to floating point precision), since we copied the query + // vector into the data point. + auto modified_distance = + loaded_index.get_distance(data_index, queries.get_datum(query_index)); + CATCH_REQUIRE(modified_distance == Catch::Approx(0.0).epsilon(1e-5)); + } +} diff --git a/tests/svs/quantization/scalar/scalar.cpp b/tests/svs/quantization/scalar/scalar.cpp index 0b8271d9d..ef58a7755 100644 --- a/tests/svs/quantization/scalar/scalar.cpp +++ b/tests/svs/quantization/scalar/scalar.cpp @@ -28,6 +28,9 @@ // catch2 #include "catch2/catch_test_macros.hpp" +// stdlib +#include + namespace scalar = svs::quantization::scalar; template void test_sq_top() { @@ -254,6 +257,44 @@ CATCH_TEST_CASE("Testing SQDataset", "[quantization][scalar]") { test_sq_top(); test_sq_top(); } + + CATCH_SECTION("Load SQDataset view pointed to stringstream") { + using sq_dtype = std::int8_t; + auto src = scalar::SQDataset(5, 4); + svs_test::data::set_sequential(src); + src.set_scale(0.5F); + src.set_bias(-3.0F); + + auto ss = std::stringstream{}; + src.save(ss); + ss.seekg(0); + + auto meta = src.metadata(); + auto table = svs::lib::ContextFreeLoadTable(meta.get()); + + // Capture the pointer to the current read position inside the stream buffer. + auto* expected_ptr = svs::io::current_ptr(ss); + CATCH_REQUIRE(expected_ptr != nullptr); + + auto view = scalar::SQDataset< + sq_dtype, + svs::Dynamic, + svs::io::MemoryStreamAllocator>::load(table, ss); + CATCH_REQUIRE(view.size() == src.size()); + CATCH_REQUIRE(view.dimensions() == src.dimensions()); + CATCH_REQUIRE(view.get_scale() == src.get_scale()); + CATCH_REQUIRE(view.get_bias() == src.get_bias()); + CATCH_REQUIRE(view.get_datum(0).data() == expected_ptr); + + for (size_t i = 0; i < src.size(); ++i) { + auto a = src.get_datum(i); + auto b = view.get_datum(i); + CATCH_REQUIRE(a.size() == b.size()); + for (size_t j = 0; j < a.size(); ++j) { + CATCH_REQUIRE(a[j] == b[j]); + } + } + } } CATCH_TEST_CASE(