diff --git a/.github/workflows/ut.yaml b/.github/workflows/ut.yaml index 9ce0650..03e2089 100644 --- a/.github/workflows/ut.yaml +++ b/.github/workflows/ut.yaml @@ -50,8 +50,19 @@ jobs: conan profile detect --force \ && conan remote add default-conan-local2 https://milvus01.jfrog.io/artifactory/api/conan/default-conan-local2 --force - name: Build & Run + env: + KNOWHERE_REQUIRE_IO_URING: "1" run: | - make test BUILD_TYPE=Release CONAN_EXTRA='-o \&:with_asan=True -s compiler.version=12' + make conan BUILD_TYPE=Release CONAN_EXTRA='-o \&:with_ut=True -o \&:with_asan=True -s compiler.version=12 -s compiler.cppstd=20' + sudo apt update && sudo apt install -y liburing-dev + cmake -S . -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=${{ github.workspace }}/build/conan_toolchain.cmake \ + -DWITH_COMMON_UT=ON \ + -DENABLE_SYNCPOINT=ON \ + -Wno-dev + cmake --build build -j$(nproc) + . build/conanrun.sh && ctest --test-dir build --output-on-failure - name: Save Conan Packages uses: actions/cache/save@v4 with: @@ -95,3 +106,50 @@ jobs: with: path: ~/.conan2 key: milvus-common-macos-15-${{ hashFiles('conanfile.py')}} + + ut-no-liburing: + name: UT without liburing on ubuntu-22.04 + runs-on: ubuntu-22.04 + timeout-minutes: 240 + env: + CC: gcc-12 + CXX: g++-12 + strategy: + fail-fast: false + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: recursive + - name: Install Dependency + run: | + sudo apt update && sudo apt install -y cmake libaio-dev g++-12 gcc-12 python3 python3-pip \ + && pip3 install conan==2.25.1 + - name: Restore Conan Packages + uses: actions/cache@v4 + with: + path: ~/.conan2 + key: milvus-common-no-liburing-ubuntu-22.04-${{ hashFiles('conanfile.py')}} + restore-keys: milvus-common-no-liburing-ubuntu-22.04- + - name: Setup Conan + run: | + conan profile detect --force \ + && conan remote add default-conan-local2 https://milvus01.jfrog.io/artifactory/api/conan/default-conan-local2 --force + - name: Build & Run + 
run: | + make conan BUILD_TYPE=Release CONAN_EXTRA='-o \&:with_ut=True -o \&:with_asan=True -s compiler.version=12 -s compiler.cppstd=20' + cmake -S . -B build-no-liburing \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=${{ github.workspace }}/build/conan_toolchain.cmake \ + -DWITH_COMMON_UT=ON \ + -DENABLE_SYNCPOINT=ON \ + -DURING_INCLUDE_DIR=URING_INCLUDE_DIR-NOTFOUND \ + -DURING_LIBRARY=URING_LIBRARY-NOTFOUND \ + -Wno-dev + cmake --build build-no-liburing -j$(nproc) + . build/conanrun.sh && ctest --test-dir build-no-liburing --output-on-failure + - name: Save Conan Packages + uses: actions/cache/save@v4 + with: + path: ~/.conan2 + key: milvus-common-no-liburing-ubuntu-22.04-${{ hashFiles('conanfile.py')}} diff --git a/CMakeLists.txt b/CMakeLists.txt index 4d1588c..807c0f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -115,10 +115,33 @@ else() list(APPEND COMMON_LINKER_LIBS ${LIBAIO_LIBRARY}) endif() +set(MILVUS_COMMON_WITH_IO_URING OFF) +find_path(URING_INCLUDE_DIR liburing.h) +find_library(URING_LIBRARY NAMES uring) +if(URING_INCLUDE_DIR AND URING_LIBRARY) + message(STATUS "liburing found: ${URING_LIBRARY}, enabling io_uring support") + list(APPEND COMMON_LINKER_LIBS ${URING_LIBRARY}) + list(APPEND MILVUS_COMMON_EXTRA_INCLUDE_DIRS ${URING_INCLUDE_DIR}) + set(MILVUS_COMMON_WITH_IO_URING ON) +else() + list(REMOVE_ITEM SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/src/knowhere/uring_context_pool.cc) + message(STATUS "liburing not found, disabling io_uring support (keeping completion reader shim)") +endif() + add_library(milvus-common SHARED ${SRC_FILES}) +target_compile_features(milvus-common PRIVATE cxx_std_20) # Use std::shared_ptr etc. 
instead of opentelemetry::nostd:: equivalents target_compile_definitions(milvus-common PUBLIC OPENTELEMETRY_STL_VERSION=2017) +if(LIBAIO_LIBRARY) + target_compile_definitions(milvus-common PUBLIC MILVUS_COMMON_WITH_LIBAIO) +endif() +if(MILVUS_COMMON_WITH_IO_URING) + target_compile_definitions(milvus-common PUBLIC WITH_IO_URING) +endif() +if(MILVUS_COMMON_EXTRA_INCLUDE_DIRS) + target_include_directories(milvus-common PUBLIC ${MILVUS_COMMON_EXTRA_INCLUDE_DIRS}) +endif() if(APPLE) target_compile_definitions(milvus-common PUBLIC BOOST_STACKTRACE_GNU_SOURCE_NOT_REQUIRED) @@ -141,7 +164,12 @@ target_link_libraries(milvus-common PUBLIC OpenMP::OpenMP_CXX) # thread_pool.cc uses openblas-specific API (openblas_set_num_threads) under OPENBLAS_OS_LINUX if(CMAKE_SYSTEM_NAME STREQUAL "Linux") find_package(OpenBLAS CONFIG REQUIRED) - target_link_libraries(milvus-common PUBLIC OpenBLAS::OpenBLAS) + if(TARGET OpenBLAS::OpenBLAS) + target_link_libraries(milvus-common PUBLIC OpenBLAS::OpenBLAS) + else() + target_include_directories(milvus-common PUBLIC ${OpenBLAS_INCLUDE_DIRS}) + target_link_libraries(milvus-common PUBLIC ${OpenBLAS_LIBRARIES}) + endif() endif() if(WITH_COMMON_UT) diff --git a/include/knowhere/aio_context_pool.h b/include/knowhere/aio_context_pool.h index 08d3863..3acc985 100644 --- a/include/knowhere/aio_context_pool.h +++ b/include/knowhere/aio_context_pool.h @@ -2,11 +2,19 @@ #include +#include +#include #include +#include +#include +#include #include #include +#include +#include #include "log/Log.h" +#include "syncpoint/sync_point.h" constexpr size_t default_max_nr = 65536; constexpr size_t default_max_events = 128; @@ -14,6 +22,12 @@ constexpr size_t default_pool_size = default_max_nr / default_max_events; class AioContextPool { public: + enum class State { + Healthy, + Unusable, + Stopped, + }; + AioContextPool(const AioContextPool&) = delete; AioContextPool& @@ -22,61 +36,334 @@ class AioContextPool { AioContextPool(AioContextPool&&) noexcept = delete; 
AioContextPool& - operator==(AioContextPool&&) noexcept = delete; + operator=(AioContextPool&&) noexcept = delete; size_t max_events_per_ctx() { return max_events_; } - void + size_t + created_context_count() const { + std::scoped_lock lk(ctx_mtx_); + return ctx_bak_.size(); + } + + bool + IsUsable() const { + std::scoped_lock lk(ctx_mtx_); + return state_ == State::Healthy && ctx_bak_.size() == num_ctx_ && !ctx_bak_.empty(); + } + + bool push(io_context_t ctx) { + if (ctx == nullptr) { + LOG_WARN("AioContextPool push gets null context"); + return false; + } + + bool should_destroy = false; + bool released = false; + bool notify_all = false; { std::scoped_lock lk(ctx_mtx_); - ctx_q_.push(ctx); + if (owned_ctxs_.find(ctx) == owned_ctxs_.end()) { + LOG_WARN("AioContextPool rejects unknown context: {}", static_cast(ctx)); + return false; + } + + if (checked_out_ctxs_.find(ctx) == checked_out_ctxs_.end()) { + LOG_WARN("AioContextPool rejects context not checked out: {}", static_cast(ctx)); + return false; + } + if (state_ != State::Healthy) { + RemoveTrackedContextLocked(ctx); + should_destroy = true; + notify_all = true; + } else { + try { + ctx_q_.push(ctx); + checked_out_ctxs_.erase(ctx); + released = true; + } catch (const std::exception& e) { + LOG_ERROR("AioContextPool failed to requeue context {}: {}", static_cast(ctx), e.what()); + RemoveTrackedContextLocked(ctx); + MarkUnusableLocked(); + should_destroy = true; + notify_all = true; + } catch (...) 
{ + LOG_ERROR("AioContextPool failed to requeue context {}: unknown exception", + static_cast(ctx)); + RemoveTrackedContextLocked(ctx); + MarkUnusableLocked(); + should_destroy = true; + notify_all = true; + } + } + } + + if (should_destroy) { + DestroyContextNoThrow(ctx, "releasing"); + if (notify_all) { + ctx_cv_.notify_all(); + } + return false; } + ctx_cv_.notify_one(); + return released; } io_context_t pop() { std::unique_lock lk(ctx_mtx_); - if (stop_) { + ctx_cv_.wait(lk, [this] { return state_ != State::Healthy || !ctx_q_.empty(); }); + if (state_ != State::Healthy) { return nullptr; } - ctx_cv_.wait(lk, [this] { return ctx_q_.size(); }); - if (stop_) { + auto ret = ctx_q_.front(); + try { + const auto inserted = checked_out_ctxs_.insert(ret).second; + if (!inserted) { + LOG_ERROR("AioContextPool detected duplicate checked-out context: {}", static_cast(ret)); + MarkUnusableLocked(); + lk.unlock(); + ctx_cv_.notify_all(); + return nullptr; + } + } catch (const std::exception& e) { + LOG_ERROR("AioContextPool failed to mark context checked out: {}", e.what()); + MarkUnusableLocked(); + lk.unlock(); + ctx_cv_.notify_all(); + return nullptr; + } catch (...) 
{ + LOG_ERROR("AioContextPool failed to mark context checked out: unknown exception"); + MarkUnusableLocked(); + lk.unlock(); + ctx_cv_.notify_all(); return nullptr; } - auto ret = ctx_q_.front(); ctx_q_.pop(); return ret; } + void + Shutdown() { + { + std::scoped_lock lk(ctx_mtx_); + if (state_ == State::Stopped) { + return; + } + state_ = State::Stopped; + } + ctx_cv_.notify_all(); + } + static bool InitGlobalAioPool(size_t num_ctx, size_t max_events); static std::shared_ptr GetGlobalAioPool(); + static bool + InitGlobalAioPoolWithValidation(size_t num_ctx, size_t max_events); + + static std::shared_ptr + GetGlobalAioPoolDirect(); + + static void + ResetGlobalForTest(); + + bool + ResetCheckedOut(io_context_t ctx) { + if (ctx == nullptr) { + LOG_WARN("AioContextPool reset gets null context"); + return false; + } + + { + std::scoped_lock lk(ctx_mtx_); + if (owned_ctxs_.find(ctx) == owned_ctxs_.end()) { + LOG_WARN("AioContextPool rejects reset for unknown context: {}", static_cast(ctx)); + return false; + } + if (checked_out_ctxs_.find(ctx) == checked_out_ctxs_.end()) { + LOG_WARN("AioContextPool rejects reset for context not checked out: {}", static_cast(ctx)); + return false; + } + } + + if (!DestroyContextNoThrow(ctx, "resetting")) { + { + std::scoped_lock lk(ctx_mtx_); + RemoveTrackedContextLocked(ctx); + MarkUnusableLocked(); + } + ctx_cv_.notify_all(); + return false; + } + + io_context_t new_ctx = 0; + int ret = 0; +#ifdef ENABLE_SYNCPOINT + TEST_SYNC_POINT_CALLBACK("AioContextPool::ResetCheckedOut:BeforeSetup", &ret); +#endif + if (ret == 0) { + ret = io_setup(max_events_, &new_ctx); + } + if (ret == 0) { + bool reusable = false; + bool should_destroy_new_ctx = false; + { + std::scoped_lock lk(ctx_mtx_); + auto iter = std::find(ctx_bak_.begin(), ctx_bak_.end(), ctx); + if (state_ == State::Healthy && iter != ctx_bak_.end()) { + owned_ctxs_.erase(ctx); + checked_out_ctxs_.erase(ctx); + try { + const auto inserted = owned_ctxs_.insert(new_ctx).second; + if 
(!inserted) { + LOG_ERROR("AioContextPool replacement context already exists: {}", + static_cast(new_ctx)); + RemoveTrackedContextLocked(ctx); + MarkUnusableLocked(); + should_destroy_new_ctx = true; + } else { + try { + ctx_q_.push(new_ctx); + } catch (...) { + owned_ctxs_.erase(new_ctx); + throw; + } + *iter = new_ctx; + reusable = true; + } + } catch (const std::exception& e) { + LOG_ERROR("AioContextPool failed to install replacement context: {}", e.what()); + RemoveTrackedContextLocked(ctx); + MarkUnusableLocked(); + should_destroy_new_ctx = true; + } catch (...) { + LOG_ERROR("AioContextPool failed to install replacement context: unknown exception"); + RemoveTrackedContextLocked(ctx); + MarkUnusableLocked(); + should_destroy_new_ctx = true; + } + } else { + RemoveTrackedContextLocked(ctx); + should_destroy_new_ctx = true; + } + } + if (reusable) { + ctx_cv_.notify_one(); + return true; + } + if (should_destroy_new_ctx) { + DestroyContextNoThrow(new_ctx, "discarding replacement"); + } + ctx_cv_.notify_all(); + return false; + } + + LOG_ERROR("io_setup failed while resetting AIO context with ret={}, errno={}: {}", ret, -ret, ::strerror(-ret)); + { + std::scoped_lock lk(ctx_mtx_); + RemoveTrackedContextLocked(ctx); + MarkUnusableLocked(); + } + ctx_cv_.notify_all(); + return false; + } + + bool + RetireCheckedOut(io_context_t ctx) { + if (ctx == nullptr) { + LOG_WARN("AioContextPool retire gets null context"); + return false; + } + + { + std::scoped_lock lk(ctx_mtx_); + if (owned_ctxs_.find(ctx) == owned_ctxs_.end()) { + LOG_WARN("AioContextPool rejects retire for unknown context: {}", static_cast(ctx)); + return false; + } + if (checked_out_ctxs_.find(ctx) == checked_out_ctxs_.end()) { + LOG_WARN("AioContextPool rejects retire for context not checked out: {}", static_cast(ctx)); + return false; + } + + RemoveTrackedContextLocked(ctx); + MarkUnusableLocked(); + } + + const bool destroyed = DestroyContextNoThrow(ctx, "retiring"); + ctx_cv_.notify_all(); + return 
destroyed; + } + ~AioContextPool() { - stop_ = true; + Shutdown(); + std::unordered_set checked_out; + { + std::scoped_lock lk(ctx_mtx_); + checked_out = checked_out_ctxs_; + } + if (!checked_out.empty()) { + LOG_WARN("AioContextPool shutdown with {} checked-out contexts still not returned", checked_out.size()); + } for (auto ctx : ctx_bak_) { - io_destroy(ctx); + if (checked_out.find(ctx) == checked_out.end()) { + DestroyContextNoThrow(ctx, "destroying during shutdown"); + } } - ctx_cv_.notify_all(); } private: + void + MarkUnusableLocked() noexcept { + if (state_ == State::Healthy) { + state_ = State::Unusable; + } + } + + void + RemoveTrackedContextLocked(io_context_t ctx) { + checked_out_ctxs_.erase(ctx); + owned_ctxs_.erase(ctx); + auto iter = std::find(ctx_bak_.begin(), ctx_bak_.end(), ctx); + if (iter != ctx_bak_.end()) { + ctx_bak_.erase(iter); + } + } + + bool + DestroyContextNoThrow(io_context_t ctx, const char* action) noexcept { + int ret = io_destroy(ctx); +#ifdef ENABLE_SYNCPOINT + TEST_SYNC_POINT_CALLBACK("AioContextPool::DestroyContext:AfterDestroy", &ret); +#endif + if (ret != 0) { + const int err = ret < 0 ? 
-ret : ret; + LOG_ERROR("io_destroy failed while {} AIO context {} with ret={}, errno={}: {}", action, + static_cast(ctx), ret, err, ::strerror(err)); + return false; + } + return true; + } + std::vector ctx_bak_; std::queue ctx_q_; - std::mutex ctx_mtx_; + std::unordered_set owned_ctxs_; + std::unordered_set checked_out_ctxs_; + mutable std::mutex ctx_mtx_; std::condition_variable ctx_cv_; - bool stop_ = false; + State state_ = State::Healthy; size_t num_ctx_; size_t max_events_; - static size_t global_aio_pool_size; - static size_t global_aio_max_events; + static std::atomic global_aio_pool_size; + static std::atomic global_aio_max_events; static std::mutex global_aio_pool_mut; AioContextPool(size_t num_ctx, size_t max_events) : num_ctx_(num_ctx), max_events_(max_events) { @@ -85,16 +372,22 @@ class AioContextPool { int ret = -1; for (int retry = 0; (ret = io_setup(max_events, &ctx)) != 0 && retry < 5; ++retry) { if (-ret != EAGAIN) { - LOG_ERROR("Unknown error occur in io_setup, errno: %d, %s", -ret, ::strerror(-ret)); + LOG_ERROR("Unknown error occur in io_setup, errno: {}, {}", -ret, ::strerror(-ret)); } } if (ret != 0) { - LOG_ERROR("io_setup() failed; returned %d, errno=%d: %s", ret, -ret, ::strerror(-ret)); + LOG_ERROR("io_setup() failed; returned {}, errno={}: {}", ret, -ret, ::strerror(-ret)); } else { - LOG_DEBUG("allocating ctx: %p", (void*)ctx); + LOG_DEBUG("allocating ctx: {}", static_cast(ctx)); ctx_q_.push(ctx); ctx_bak_.push_back(ctx); + owned_ctxs_.insert(ctx); } } + if (ctx_bak_.size() != num_ctx_) { + state_ = State::Unusable; + LOG_ERROR("AioContextPool initialization failed: created {} of {} requested contexts", ctx_bak_.size(), + num_ctx_); + } } }; diff --git a/include/knowhere/io_completion_reader.h b/include/knowhere/io_completion_reader.h new file mode 100644 index 0000000..324732e --- /dev/null +++ b/include/knowhere/io_completion_reader.h @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include +#include +#include +#include 
+#include +#include + +#include "knowhere/io_context_pool.h" +#include "knowhere/io_span.h" + +template +using IOCompletionReaderSpan = knowhere_compat::span; + +// Worker-local, single-threaded completion reader. Not thread-safe. +class IOCompletionReader { + public: + using RequestId = uint64_t; + + struct Completion { + RequestId request_id = 0; + bool ok = false; + }; + + IOCompletionReader(int fd, std::shared_ptr io_pool); + + IOCompletionReader(const IOCompletionReader&) = delete; + IOCompletionReader& + operator=(const IOCompletionReader&) = delete; + + IOCompletionReader(IOCompletionReader&& other) noexcept; + IOCompletionReader& + operator=(IOCompletionReader&& other) noexcept; + + ~IOCompletionReader(); + + RequestId + Submit(IOCompletionReaderSpan buffers, size_t size, IOCompletionReaderSpan offsets); + + Completion + WaitCompleted(); + + std::vector + PollCompleted(); + + bool + IsReady() const; + + private: + struct RequestState { + size_t remaining = 0; + size_t expected_size = 0; + bool ok = true; + }; + + std::optional + ProcessCqe(struct io_uring_cqe* cqe); + + void + WaitOneCompletion(); + + void + ProcessAvailableCompletions(); + + size_t + PendingOperationCount() const; + + void + DrainOutstandingNoThrow() noexcept; + + bool + DrainOutstandingBlockingNoThrow() noexcept; + + void + DrainOutstanding(); + + void + DrainOutstanding(RequestId request_id); + + bool + TryDrainOutstanding(RequestId request_id) noexcept; + + void + CleanupFailedSubmit(RequestId request_id, size_t prepared, size_t submitted); + + void + FailPendingRequests(RequestId excluded_request_id); + + void + RemoveReadyCompletion(RequestId request_id); + + bool + ResetHandleUring(); + + void + ReleaseHandle(); + + int fd_ = -1; + std::shared_ptr io_pool_; + IOContextHandle handle_; + RequestId next_request_id_ = 1; + std::unordered_map pending_requests_; + std::deque ready_completions_; +}; diff --git a/include/knowhere/io_context_pool.h b/include/knowhere/io_context_pool.h new 
file mode 100644 index 0000000..233674d --- /dev/null +++ b/include/knowhere/io_context_pool.h @@ -0,0 +1,189 @@ +#pragma once + +#include +#include +#include +#include + +#ifdef MILVUS_COMMON_WITH_LIBAIO +#include +#endif + +#ifdef WITH_IO_URING +#include +#endif + +#ifdef MILVUS_COMMON_WITH_LIBAIO +#include "knowhere/aio_context_pool.h" +#endif + +#ifdef WITH_IO_URING +#include "knowhere/uring_context_pool.h" +#endif + +class IOContextPool; + +enum class IOBackend { + UNKNOWN, + IO_URING, + AIO, +}; + +enum class IOContextReleaseDisposition { + Clean, + Dirty, + Retire, +}; + +constexpr size_t default_io_ctx_pool_size = 65536 / 128; + +struct IOContextPoolConfig { +#ifdef MILVUS_COMMON_WITH_LIBAIO + size_t num_ctx = default_pool_size; +#else + size_t num_ctx = default_io_ctx_pool_size; +#endif + size_t max_events = 128; +}; + +// Lifecycle invariants: +// - Healthy backend pools account every context as either available or leased. +// - Dirty/retired contexts never return to the available queue without a successful reset. +// - If reset/replacement cannot restore capacity, backend pools become fail-fast. +// - IOContextHandle is the single lease token; moving or destroying it releases the old lease. 
+struct IOContextHandle { + IOContextHandle() = default; + ~IOContextHandle(); + + IOContextHandle(const IOContextHandle&) = delete; + IOContextHandle& + operator=(const IOContextHandle&) = delete; + + IOContextHandle(IOContextHandle&& other) noexcept; + + IOContextHandle& + operator=(IOContextHandle&& other) noexcept; + + bool + HasContext() const noexcept; + + IOBackend backend = IOBackend::UNKNOWN; +#ifdef WITH_IO_URING + struct io_uring* uring = nullptr; +#endif +#ifdef MILVUS_COMMON_WITH_LIBAIO + io_context_t aio = nullptr; +#endif + + private: + friend class IOContextPool; + + void + ClearNoRelease() noexcept; + + void + ReleaseNoThrow() noexcept; + + std::shared_ptr owner_; +}; + +class IOContextPool : public std::enable_shared_from_this { + public: + IOContextPool(const IOContextPool&) = delete; + IOContextPool& + operator=(const IOContextPool&) = delete; + + static bool + InitGlobal(const IOContextPoolConfig& cfg); + + static std::shared_ptr + GetGlobal(); + + static std::shared_ptr + GetGlobalOrInit(const IOContextPoolConfig& cfg = IOContextPoolConfig{}); + + static void + ResetGlobalForTest(); + + IOBackend + Backend() const; + + std::string + BackendName() const; + + bool + IsInitialized() const; + + size_t + MaxEventsPerCtx() const; + + IOContextHandle + Pop(); + + bool + Push(IOContextHandle&& handle); + + bool + Reset(IOContextHandle&& handle); + + bool + Release(IOContextHandle&& handle, IOContextReleaseDisposition disposition); + +#ifdef WITH_IO_URING + struct io_uring* + PopUring(); + + bool + PushUring(struct io_uring* ring); + + bool + ResetUring(struct io_uring* ring); + + bool + RetireUring(struct io_uring* ring); + + std::shared_ptr + GetUringPoolForLegacy() const; +#endif + +#ifdef MILVUS_COMMON_WITH_LIBAIO + io_context_t + PopAio(); + + bool + PushAio(io_context_t ctx); + + bool + ResetAio(io_context_t ctx); + + bool + RetireAio(io_context_t ctx); + + std::shared_ptr + GetAioPoolForLegacy() const; +#endif + + private: + IOContextPool() = 
default; + +#ifdef WITH_IO_URING + static bool + TryInitUring(const IOContextPoolConfig& cfg, const std::shared_ptr& io_pool); +#endif + +#ifdef MILVUS_COMMON_WITH_LIBAIO + static bool + TryInitAio(const IOContextPoolConfig& cfg, const std::shared_ptr& io_pool); +#endif + + IOBackend backend_ = IOBackend::UNKNOWN; + size_t num_ctx_ = 0; + size_t max_events_per_ctx_ = 0; +#ifdef WITH_IO_URING + std::shared_ptr uring_pool_; +#endif + +#ifdef MILVUS_COMMON_WITH_LIBAIO + std::shared_ptr aio_pool_; +#endif +}; diff --git a/include/knowhere/io_reader.h b/include/knowhere/io_reader.h new file mode 100644 index 0000000..b8f225c --- /dev/null +++ b/include/knowhere/io_reader.h @@ -0,0 +1,45 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include "knowhere/io_context_pool.h" +#include "knowhere/io_span.h" + +template +using IOReaderSpan = knowhere_compat::span; + +class IOReader { + public: + IOReader(); + + explicit IOReader(int fd); + + IOReader(int fd, std::shared_ptr io_pool); + + explicit IOReader(std::shared_ptr io_pool); + + bool + Read(IOReaderSpan buf, size_t size, IOReaderSpan offsets) const; + + std::future + ReadAsync(std::vector&& buffers, size_t size, std::vector&& offsets) const; + + IOBackend + Backend() const; + + std::string + BackendName() const; + + bool + IsReady() const; + + private: + int fd_ = -1; + std::shared_ptr io_pool_; +}; diff --git a/include/knowhere/io_span.h b/include/knowhere/io_span.h new file mode 100644 index 0000000..76e5106 --- /dev/null +++ b/include/knowhere/io_span.h @@ -0,0 +1,47 @@ +#pragma once + +#include + +namespace knowhere_compat { +template +class span { + public: + span(T* data, size_t size) : data_(data), size_(size) { + } + + T& + operator[](size_t idx) const { + return data_[idx]; + } + + size_t + size() const { + return size_; + } + + bool + empty() const { + return size_ == 0; + } + + T* + data() const { + return data_; + } + + T* + begin() const { + return data_; + } + + T* + 
end() const { + return data_ + size_; + } + + private: + T* data_; + size_t size_; +}; +} // namespace knowhere_compat + diff --git a/include/knowhere/uring_context_pool.h b/include/knowhere/uring_context_pool.h new file mode 100644 index 0000000..0203bb5 --- /dev/null +++ b/include/knowhere/uring_context_pool.h @@ -0,0 +1,217 @@ +#pragma once + +#ifdef WITH_IO_URING + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "log/Log.h" + +constexpr size_t default_uring_max_entries = 128; + +class UringContextPool { + public: + enum class State { + Healthy, + Unusable, + Stopped, + }; + + UringContextPool(const UringContextPool&) = delete; + + UringContextPool& + operator=(const UringContextPool&) = delete; + + UringContextPool(UringContextPool&&) noexcept = delete; + + UringContextPool& + operator=(UringContextPool&&) noexcept = delete; + + size_t + max_entries_per_ctx() { + return max_entries_; + } + + bool + IsUsable() const { + std::scoped_lock lk(ring_mtx_); + return state_ == State::Healthy && ring_bak_.size() == num_ctx_ && !ring_bak_.empty(); + } + + size_t + created_context_count() const { + std::scoped_lock lk(ring_mtx_); + return ring_bak_.size(); + } + + bool + push(struct io_uring* ring) { + if (ring == nullptr) { + LOG_WARN("UringContextPool push gets null ring"); + return false; + } + + bool should_destroy = false; + bool released = false; + bool notify_all = false; + { + std::scoped_lock lk(ring_mtx_); + if (owned_rings_.find(ring) == owned_rings_.end()) { + LOG_WARN("UringContextPool rejects unknown ring: {}", static_cast(ring)); + return false; + } + + if (checked_out_rings_.find(ring) == checked_out_rings_.end()) { + LOG_WARN("UringContextPool rejects ring not checked out: {}", static_cast(ring)); + return false; + } + + if (state_ != State::Healthy) { + RemoveTrackedRingLocked(ring); + should_destroy = true; + notify_all = true; + } else { + try { + ring_q_.push(ring); + 
checked_out_rings_.erase(ring); + released = true; + } catch (const std::exception& e) { + LOG_ERROR("UringContextPool failed to requeue ring {}: {}", static_cast(ring), e.what()); + RemoveTrackedRingLocked(ring); + MarkUnusableLocked(); + should_destroy = true; + notify_all = true; + } catch (...) { + LOG_ERROR("UringContextPool failed to requeue ring {}: unknown exception", + static_cast(ring)); + RemoveTrackedRingLocked(ring); + MarkUnusableLocked(); + should_destroy = true; + notify_all = true; + } + } + } + + if (should_destroy) { + DestroyRing(ring); + if (notify_all) { + ring_cv_.notify_all(); + } + return false; + } + + ring_cv_.notify_one(); + return released; + } + + struct io_uring* + pop() { + std::unique_lock lk(ring_mtx_); + ring_cv_.wait(lk, [this] { return state_ != State::Healthy || !ring_q_.empty(); }); + if (state_ != State::Healthy) { + return nullptr; + } + auto ret = ring_q_.front(); + try { + const auto inserted = checked_out_rings_.insert(ret).second; + if (!inserted) { + LOG_ERROR("UringContextPool detected duplicate checked-out ring: {}", static_cast(ret)); + MarkUnusableLocked(); + lk.unlock(); + ring_cv_.notify_all(); + return nullptr; + } + } catch (const std::exception& e) { + LOG_ERROR("UringContextPool failed to mark ring checked out: {}", e.what()); + MarkUnusableLocked(); + lk.unlock(); + ring_cv_.notify_all(); + return nullptr; + } catch (...) 
{ + LOG_ERROR("UringContextPool failed to mark ring checked out: unknown exception"); + MarkUnusableLocked(); + lk.unlock(); + ring_cv_.notify_all(); + return nullptr; + } + ring_q_.pop(); + return ret; + } + + bool + ResetCheckedOut(struct io_uring* ring); + + bool + RetireCheckedOut(struct io_uring* ring); + + void + Shutdown(); + + static bool + InitGlobalUringPool(size_t num_ctx, size_t max_entries); + + static std::shared_ptr + GetGlobalUringPool(); + + static bool + InitGlobalUringPoolWithValidation(size_t num_ctx, size_t max_entries); + + static std::shared_ptr + GetGlobalUringPoolDirect(); + + static void + ResetGlobalForTest(); + + ~UringContextPool(); + + private: + void + MarkUnusableLocked() noexcept { + if (state_ == State::Healthy) { + state_ = State::Unusable; + } + } + + void + RemoveTrackedRingLocked(struct io_uring* ring) { + checked_out_rings_.erase(ring); + owned_rings_.erase(ring); + auto iter = std::find(ring_bak_.begin(), ring_bak_.end(), ring); + if (iter != ring_bak_.end()) { + ring_bak_.erase(iter); + } + } + + static void + DestroyRing(struct io_uring* ring) noexcept { + io_uring_queue_exit(ring); + delete ring; + } + + std::vector ring_bak_; + std::queue ring_q_; + std::unordered_set owned_rings_; + std::unordered_set checked_out_rings_; + mutable std::mutex ring_mtx_; + std::condition_variable ring_cv_; + State state_ = State::Healthy; + size_t num_ctx_; + size_t max_entries_; + + static size_t global_uring_pool_size; + static size_t global_uring_max_entries; + static std::mutex global_uring_pool_mut; + + UringContextPool(size_t num_ctx, size_t max_entries); +}; + +#endif // WITH_IO_URING diff --git a/src/knowhere/aio_context_pool.cc b/src/knowhere/aio_context_pool.cc index db5b418..75d4cfc 100644 --- a/src/knowhere/aio_context_pool.cc +++ b/src/knowhere/aio_context_pool.cc @@ -1,19 +1,28 @@ #include "knowhere/aio_context_pool.h" +#include "knowhere/io_context_pool.h" #include "log/Log.h" -size_t AioContextPool::global_aio_pool_size = 
0; -size_t AioContextPool::global_aio_max_events = 0; +namespace { +std::shared_ptr g_aio_pool; +} + +std::atomic AioContextPool::global_aio_pool_size{0}; +std::atomic AioContextPool::global_aio_max_events{0}; std::mutex AioContextPool::global_aio_pool_mut; bool -AioContextPool::InitGlobalAioPool(size_t num_ctx, size_t max_events) { +AioContextPool::InitGlobalAioPoolWithValidation(size_t num_ctx, size_t max_events) { if (num_ctx <= 0) { LOG_ERROR("num_ctx should be bigger than 0"); return false; } + if (max_events == 0) { + LOG_ERROR("max_events should be bigger than 0"); + return false; + } if (max_events > default_max_events) { - LOG_ERROR("max_events %d should not be larger than %d", max_events, default_max_events); + LOG_ERROR("max_events {} should not be larger than {}", max_events, default_max_events); return false; } if (global_aio_pool_size == 0) { @@ -24,21 +33,54 @@ AioContextPool::InitGlobalAioPool(size_t num_ctx, size_t max_events) { return true; } } - LOG_WARN("Global AioContextPool has already been inialized with context num: %d", global_aio_pool_size); + if (global_aio_pool_size != num_ctx || global_aio_max_events != max_events) { + LOG_ERROR( + "Global AioContextPool already initialized with context num: {}, max_events: {} (requested {}, {})", + global_aio_pool_size.load(), global_aio_max_events.load(), num_ctx, max_events); + return false; + } + LOG_WARN("Global AioContextPool has already been initialized with context num: {}", global_aio_pool_size.load()); return true; } std::shared_ptr -AioContextPool::GetGlobalAioPool() { +AioContextPool::GetGlobalAioPoolDirect() { + std::scoped_lock lk(global_aio_pool_mut); if (global_aio_pool_size == 0) { - std::scoped_lock lk(global_aio_pool_mut); - if (global_aio_pool_size == 0) { - global_aio_pool_size = default_pool_size; - global_aio_max_events = default_max_events; - LOG_WARN("Global AioContextPool has not been inialized yet, init it now with context num: %d", - global_aio_pool_size); - } + 
global_aio_pool_size = default_pool_size; + global_aio_max_events = default_max_events; + LOG_WARN("Global AioContextPool has not been inialized yet, init it now with context num: {}", + global_aio_pool_size.load()); + } + if (g_aio_pool == nullptr) { + g_aio_pool = std::shared_ptr(new AioContextPool(global_aio_pool_size, global_aio_max_events)); } - static auto pool = std::shared_ptr(new AioContextPool(global_aio_pool_size, global_aio_max_events)); - return pool; + return g_aio_pool; +} + +bool +AioContextPool::InitGlobalAioPool(size_t num_ctx, size_t max_events) { + return InitGlobalAioPoolWithValidation(num_ctx, max_events) && GetGlobalAioPoolDirect() != nullptr; +} + +std::shared_ptr +AioContextPool::GetGlobalAioPool() { + auto io_pool = IOContextPool::GetGlobal(); + if (io_pool != nullptr && io_pool->IsInitialized() && io_pool->Backend() == IOBackend::AIO) { + return io_pool->GetAioPoolForLegacy(); + } + + if (io_pool != nullptr && io_pool->IsInitialized()) { + LOG_WARN("Returning independent legacy AIO pool while unified IOContextPool backend is {}", + io_pool->BackendName()); + } + return GetGlobalAioPoolDirect(); +} + +void +AioContextPool::ResetGlobalForTest() { + std::scoped_lock lk(global_aio_pool_mut); + g_aio_pool.reset(); + global_aio_pool_size = 0; + global_aio_max_events = 0; } diff --git a/src/knowhere/io_completion_reader.cc b/src/knowhere/io_completion_reader.cc new file mode 100644 index 0000000..91fae44 --- /dev/null +++ b/src/knowhere/io_completion_reader.cc @@ -0,0 +1,602 @@ +#include "knowhere/io_completion_reader.h" + +#include "io_reader_internal.h" + +#include "syncpoint/sync_point.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "log/Log.h" + +namespace { +constexpr size_t kNumRetries = 10; +constexpr size_t kCleanupPeekLimit = 1024; + +#ifdef WITH_IO_URING +constexpr size_t kDirectIoAlignment = 512; + +bool +IsAlignedForDirectIo(const void* ptr, size_t alignment) { + return 
reinterpret_cast(ptr) % alignment == 0; +} + +void +ValidateDirectIoAlignment(int fd, size_t size, IOCompletionReaderSpan buffers, + IOCompletionReaderSpan offsets) { +#ifdef O_DIRECT + const auto flags = fcntl(fd, F_GETFL); + if (flags < 0 || (flags & O_DIRECT) == 0) { + return; + } + + if (size % kDirectIoAlignment != 0) { + throw std::invalid_argument("O_DIRECT read size must be 512-byte aligned"); + } + for (size_t i = 0; i < buffers.size(); ++i) { + if (!IsAlignedForDirectIo(buffers[i], kDirectIoAlignment)) { + throw std::invalid_argument("O_DIRECT read buffer address must be 512-byte aligned"); + } + if (offsets[i] % kDirectIoAlignment != 0) { + throw std::invalid_argument("O_DIRECT read offset must be 512-byte aligned"); + } + } +#else + (void)fd; + (void)size; + (void)buffers; + (void)offsets; +#endif +} +#endif + +#ifdef WITH_IO_URING +int +SubmitRing(struct io_uring* ring) { +#ifdef ENABLE_SYNCPOINT + int forced_ret = 0; + TEST_SYNC_POINT_CALLBACK("IOCompletionReader::SubmitRing:Before", &forced_ret); + if (forced_ret != 0) { + return forced_ret; + } +#endif + return io_uring_submit(ring); +} +#endif +} // namespace + +IOCompletionReader::IOCompletionReader(int fd, std::shared_ptr io_pool) + : fd_(fd), io_pool_(std::move(io_pool)) { + if (fd_ < 0) { + throw std::invalid_argument("invalid file descriptor"); + } + + if (io_pool_ == nullptr || !io_pool_->IsInitialized()) { + throw std::runtime_error("IOContextPool is not initialized"); + } + +#ifndef WITH_IO_URING + throw std::runtime_error("IOCompletionReader requires io_uring support"); +#else + if (io_pool_->Backend() != IOBackend::IO_URING) { + throw std::runtime_error("IOCompletionReader requires io_uring backend"); + } + + handle_ = io_pool_->Pop(); + if (handle_.backend != IOBackend::IO_URING) { + throw std::runtime_error("failed to acquire io_uring context handle"); + } + if (handle_.uring == nullptr) { + throw std::runtime_error("failed to acquire io_uring context handle"); + } +#endif +} + 
+IOCompletionReader::IOCompletionReader(IOCompletionReader&& other) noexcept + : fd_(other.fd_), + io_pool_(std::move(other.io_pool_)), + handle_(std::move(other.handle_)), + next_request_id_(other.next_request_id_), + pending_requests_(std::move(other.pending_requests_)), + ready_completions_(std::move(other.ready_completions_)) { + other.fd_ = -1; + other.handle_ = IOContextHandle{}; + other.next_request_id_ = 1; +} + +IOCompletionReader& +IOCompletionReader::operator=(IOCompletionReader&& other) noexcept { + if (this == &other) { + return *this; + } + + DrainOutstandingNoThrow(); + ReleaseHandle(); + + fd_ = other.fd_; + io_pool_ = std::move(other.io_pool_); + handle_ = std::move(other.handle_); + next_request_id_ = other.next_request_id_; + pending_requests_ = std::move(other.pending_requests_); + ready_completions_ = std::move(other.ready_completions_); + + other.fd_ = -1; + other.handle_ = IOContextHandle{}; + other.next_request_id_ = 1; + return *this; +} + +IOCompletionReader::~IOCompletionReader() { + DrainOutstandingNoThrow(); + ReleaseHandle(); +} + +IOCompletionReader::RequestId +IOCompletionReader::Submit(IOCompletionReaderSpan buffers, size_t size, + IOCompletionReaderSpan offsets) { +#ifndef WITH_IO_URING + (void)buffers; + (void)size; + (void)offsets; + throw std::runtime_error("IOCompletionReader requires io_uring support"); +#else + if (!IsReady()) { + throw std::runtime_error("IOCompletionReader is not ready"); + } + if (size == 0) { + throw std::invalid_argument("size should be greater than 0"); + } + if (buffers.size() != offsets.size()) { + throw std::invalid_argument("buffers and offsets must have same size"); + } + if (buffers.empty()) { + throw std::invalid_argument("buffers should not be empty"); + } + for (const auto* buffer : buffers) { + if (buffer == nullptr) { + throw std::invalid_argument("buffer pointer should not be null"); + } + } + ValidateDirectIoAlignment(fd_, size, buffers, offsets); + + const auto max_events = 
io_pool_->MaxEventsPerCtx(); + if (max_events == 0) { + throw std::runtime_error("IOCompletionReader has no io_uring event capacity"); + } + if (buffers.size() > max_events) { + throw std::invalid_argument("buffers should not exceed io_uring event capacity"); + } + + ProcessAvailableCompletions(); + if (PendingOperationCount() + buffers.size() > max_events) { + throw std::runtime_error("too many outstanding io_uring requests"); + } + + auto request_id = next_request_id_++; + auto [iter, inserted] = pending_requests_.emplace(request_id, RequestState{}); + if (!inserted) { + throw std::runtime_error("duplicate io_uring request id"); + } + auto& state = iter->second; + state.expected_size = size; + state.ok = true; + + size_t prepared = 0; + while (prepared < buffers.size()) { + auto* sqe = io_uring_get_sqe(handle_.uring); + if (sqe == nullptr) { + const auto flushed = SubmitRing(handle_.uring); + if (flushed < 0) { + CleanupFailedSubmit(request_id, prepared, state.remaining); + throw std::runtime_error("io_uring_submit failed while preparing request"); + } + if (flushed == 0) { + CleanupFailedSubmit(request_id, prepared, state.remaining); + throw std::runtime_error("io_uring_submit made no progress while preparing request"); + } + state.remaining += static_cast(flushed); + continue; + } + + io_uring_prep_read(sqe, fd_, reinterpret_cast(buffers[prepared]), size, offsets[prepared]); + sqe->user_data = request_id; + ++prepared; + } + + while (state.remaining < prepared) { + const auto flushed = SubmitRing(handle_.uring); + if (flushed < 0) { + CleanupFailedSubmit(request_id, prepared, state.remaining); + throw std::runtime_error("io_uring_submit failed"); + } + if (flushed == 0) { + CleanupFailedSubmit(request_id, prepared, state.remaining); + throw std::runtime_error("io_uring_submit made no progress"); + } + state.remaining += static_cast(flushed); + } + + return request_id; +#endif +} + +IOCompletionReader::Completion +IOCompletionReader::WaitCompleted() { +#ifndef 
WITH_IO_URING + throw std::runtime_error("IOCompletionReader requires io_uring support"); +#else + if (!ready_completions_.empty()) { + auto completion = ready_completions_.front(); + ready_completions_.pop_front(); + return completion; + } + if (!IsReady()) { + throw std::runtime_error("IOCompletionReader is not ready"); + } + if (pending_requests_.empty()) { + throw std::runtime_error("IOCompletionReader has no pending requests"); + } + + while (ready_completions_.empty()) { + WaitOneCompletion(); + } + + auto completion = ready_completions_.front(); + ready_completions_.pop_front(); + return completion; +#endif +} + +std::vector +IOCompletionReader::PollCompleted() { +#ifndef WITH_IO_URING + throw std::runtime_error("IOCompletionReader requires io_uring support"); +#else + if (IsReady()) { + ProcessAvailableCompletions(); + } else if (ready_completions_.empty()) { + throw std::runtime_error("IOCompletionReader is not ready"); + } + + std::vector completed; + while (!ready_completions_.empty()) { + completed.push_back(ready_completions_.front()); + ready_completions_.pop_front(); + } + return completed; +#endif +} + +bool +IOCompletionReader::IsReady() const { +#ifdef WITH_IO_URING + return fd_ >= 0 && io_pool_ != nullptr && io_pool_->Backend() == IOBackend::IO_URING && + handle_.backend == IOBackend::IO_URING && handle_.uring != nullptr; +#else + return false; +#endif +} + +std::optional +IOCompletionReader::ProcessCqe(struct io_uring_cqe* cqe) { +#ifdef WITH_IO_URING + const auto request_id = static_cast(cqe->user_data); + auto iter = pending_requests_.find(request_id); + if (iter == pending_requests_.end()) { + return "unknown io_uring completion request id " + std::to_string(request_id) + ", result " + + std::to_string(cqe->res) + ", pending requests " + std::to_string(pending_requests_.size()); + } + + auto& state = iter->second; + if (cqe->res < 0 || static_cast(cqe->res) != state.expected_size) { + state.ok = false; + } + + if (state.remaining > 0) { + 
--state.remaining; + } + + if (state.remaining == 0) { + ready_completions_.push_back({request_id, state.ok}); + pending_requests_.erase(iter); + } + return std::nullopt; +#else + (void)cqe; + return std::nullopt; +#endif +} + +void +IOCompletionReader::WaitOneCompletion() { +#ifdef WITH_IO_URING + size_t retry = 0; + while (true) { + io_uring_cqe* cqe = nullptr; + const auto ret = io_uring_wait_cqe(handle_.uring, &cqe); + if (ret < 0) { + if (-ret == EINTR) { + if (!knowhere_internal::ShouldRetryInterruptedSyscall(retry, kNumRetries)) { + throw std::runtime_error("io_uring_wait_cqe interrupted too many times"); + } + continue; + } + throw std::runtime_error("io_uring_wait_cqe failed"); + } +#ifdef ENABLE_SYNCPOINT + bool force_null_cqe = false; + TEST_SYNC_POINT_CALLBACK("IOCompletionReader::WaitOneCompletion:ForceNullCqe", &force_null_cqe); + if (force_null_cqe) { + cqe = nullptr; + } +#endif + if (cqe == nullptr) { + if (!knowhere_internal::ShouldRetryInterruptedSyscall(retry, kNumRetries)) { + LOG_ERROR("io_uring_wait_cqe returned success with null CQE, pending operations: {}", + PendingOperationCount()); + FailPendingRequests(0); + ResetHandleUring(); + throw std::runtime_error("io_uring_wait_cqe returned null CQE too many times"); + } + continue; + } + + auto error = ProcessCqe(cqe); + io_uring_cqe_seen(handle_.uring, cqe); + if (error) { + LOG_ERROR("IOCompletionReader detected invalid CQE while waiting: {}", *error); + FailPendingRequests(0); + ResetHandleUring(); + throw std::runtime_error(*error); + } + return; + } +#else + throw std::runtime_error("IOCompletionReader requires io_uring support"); +#endif +} + +void +IOCompletionReader::ProcessAvailableCompletions() { +#ifdef WITH_IO_URING +#ifdef ENABLE_SYNCPOINT + bool skip = false; + TEST_SYNC_POINT_CALLBACK("IOCompletionReader::ProcessAvailableCompletions:Skip", &skip); + if (skip) { + return; + } +#endif + size_t retry = 0; + while (true) { + io_uring_cqe* cqe = nullptr; + const auto ret = 
io_uring_peek_cqe(handle_.uring, &cqe); + if (ret == -EAGAIN) { + break; + } + if (ret < 0) { + if (-ret == EINTR) { + if (!knowhere_internal::ShouldRetryInterruptedSyscall(retry, kNumRetries)) { + throw std::runtime_error("io_uring_peek_cqe interrupted too many times"); + } + continue; + } + throw std::runtime_error("io_uring_peek_cqe failed"); + } + if (cqe == nullptr) { + break; + } + + auto error = ProcessCqe(cqe); + io_uring_cqe_seen(handle_.uring, cqe); + if (error) { + LOG_ERROR("IOCompletionReader detected invalid CQE while polling: {}", *error); + FailPendingRequests(0); + ResetHandleUring(); + throw std::runtime_error(*error); + } + retry = 0; + } +#endif +} + +size_t +IOCompletionReader::PendingOperationCount() const { + size_t count = 0; + for (const auto& pending : pending_requests_) { + count += pending.second.remaining; + } + return count; +} + +void +IOCompletionReader::DrainOutstandingNoThrow() noexcept { +#ifdef WITH_IO_URING + const bool drained = DrainOutstandingBlockingNoThrow(); + ready_completions_.clear(); + if (!drained) { + pending_requests_.clear(); + ResetHandleUring(); + } +#endif +} + +bool +IOCompletionReader::DrainOutstandingBlockingNoThrow() noexcept { +#ifdef WITH_IO_URING + if (!IsReady()) { + return pending_requests_.empty(); + } + +#ifdef ENABLE_SYNCPOINT + bool skip_drain = false; + TEST_SYNC_POINT_CALLBACK("IOCompletionReader::DrainOutstandingNoThrow:Skip", &skip_drain); + if (skip_drain) { + return false; + } +#endif + while (!pending_requests_.empty()) { + try { + WaitOneCompletion(); + } catch (const std::exception& e) { + LOG_WARN("IOCompletionReader cleanup failed with exception: {}, pending operations: {}", e.what(), + PendingOperationCount()); + return false; + } catch (...) 
{ + LOG_WARN("IOCompletionReader cleanup failed with unknown exception, pending operations: {}", + PendingOperationCount()); + return false; + } + } + return true; +#else + return true; +#endif +} + +void +IOCompletionReader::DrainOutstanding() { +#ifdef WITH_IO_URING + while (!pending_requests_.empty()) { + WaitCompleted(); + } + ready_completions_.clear(); +#endif +} + +void +IOCompletionReader::DrainOutstanding(RequestId request_id) { +#ifdef WITH_IO_URING + while (pending_requests_.find(request_id) != pending_requests_.end()) { + WaitOneCompletion(); + } +#endif +} + +bool +IOCompletionReader::TryDrainOutstanding(RequestId request_id) noexcept { +#ifdef WITH_IO_URING + size_t retry = 0; + size_t attempts = 0; + while (pending_requests_.find(request_id) != pending_requests_.end() && attempts++ < kCleanupPeekLimit) { + try { + io_uring_cqe* cqe = nullptr; + const auto ret = io_uring_peek_cqe(handle_.uring, &cqe); + if (ret == -EAGAIN) { + return false; + } + if (ret < 0) { + if (-ret == EINTR) { + if (!knowhere_internal::ShouldRetryInterruptedSyscall(retry, kNumRetries)) { + return false; + } + continue; + } + return false; + } + if (cqe == nullptr) { + return false; + } + + auto error = ProcessCqe(cqe); + io_uring_cqe_seen(handle_.uring, cqe); + if (error) { + return false; + } + retry = 0; + } catch (const std::exception& e) { + LOG_WARN("IOCompletionReader request cleanup failed with exception: {}, pending operations: {}", e.what(), + PendingOperationCount()); + return false; + } catch (...) 
{ + LOG_WARN("IOCompletionReader request cleanup failed with unknown exception, pending operations: {}", + PendingOperationCount()); + return false; + } + } + return pending_requests_.find(request_id) == pending_requests_.end(); +#else + (void)request_id; + return false; +#endif +} + +void +IOCompletionReader::CleanupFailedSubmit(RequestId request_id, size_t prepared, size_t submitted) { +#ifdef WITH_IO_URING + if (prepared > submitted) { + if (submitted == 0) { + pending_requests_.erase(request_id); + } + if (!DrainOutstandingBlockingNoThrow()) { + FailPendingRequests(request_id); + } + RemoveReadyCompletion(request_id); + ResetHandleUring(); + return; + } + + if (submitted > 0) { + if (TryDrainOutstanding(request_id)) { + RemoveReadyCompletion(request_id); + } else { + if (!DrainOutstandingBlockingNoThrow()) { + FailPendingRequests(request_id); + } + RemoveReadyCompletion(request_id); + ResetHandleUring(); + } + } else { + pending_requests_.erase(request_id); + } +#else + (void)request_id; + (void)prepared; + (void)submitted; +#endif +} + +void +IOCompletionReader::FailPendingRequests(RequestId excluded_request_id) { + RemoveReadyCompletion(excluded_request_id); + for (const auto& pending : pending_requests_) { + if (pending.first != excluded_request_id) { + ready_completions_.push_back({pending.first, false}); + } + } + pending_requests_.clear(); +} + +void +IOCompletionReader::RemoveReadyCompletion(RequestId request_id) { + auto new_end = std::remove_if(ready_completions_.begin(), ready_completions_.end(), + [request_id](const auto& completion) { return completion.request_id == request_id; }); + ready_completions_.erase(new_end, ready_completions_.end()); +} + +bool +IOCompletionReader::ResetHandleUring() { +#ifdef WITH_IO_URING + if (io_pool_ == nullptr || handle_.backend != IOBackend::IO_URING || handle_.uring == nullptr) { + handle_ = IOContextHandle{}; + return false; + } + + return io_pool_->Reset(std::move(handle_)); +#endif + return false; +} + +void 
+IOCompletionReader::ReleaseHandle() { + if (io_pool_ != nullptr && handle_.backend != IOBackend::UNKNOWN) { + io_pool_->Push(std::move(handle_)); + } + handle_ = IOContextHandle{}; +} diff --git a/src/knowhere/io_context_pool.cc b/src/knowhere/io_context_pool.cc new file mode 100644 index 0000000..8347515 --- /dev/null +++ b/src/knowhere/io_context_pool.cc @@ -0,0 +1,398 @@ +#include "knowhere/io_context_pool.h" + +#include +#include +#include + +#include "log/Log.h" + +namespace { +std::shared_ptr g_io_pool; +std::mutex g_io_pool_mutex; +} // namespace + +IOContextHandle::~IOContextHandle() { + ReleaseNoThrow(); +} + +IOContextHandle::IOContextHandle(IOContextHandle&& other) noexcept { + *this = std::move(other); +} + +IOContextHandle& +IOContextHandle::operator=(IOContextHandle&& other) noexcept { + if (this == &other) { + return *this; + } + ReleaseNoThrow(); + backend = other.backend; +#ifdef WITH_IO_URING + uring = other.uring; + other.uring = nullptr; +#endif +#ifdef MILVUS_COMMON_WITH_LIBAIO + aio = other.aio; + other.aio = nullptr; +#endif + owner_ = std::move(other.owner_); + other.backend = IOBackend::UNKNOWN; + return *this; +} + +bool +IOContextHandle::HasContext() const noexcept { + switch (backend) { +#ifdef WITH_IO_URING + case IOBackend::IO_URING: + return uring != nullptr; +#endif +#ifdef MILVUS_COMMON_WITH_LIBAIO + case IOBackend::AIO: + return aio != nullptr; +#endif + default: + return false; + } +} + +void +IOContextHandle::ClearNoRelease() noexcept { + backend = IOBackend::UNKNOWN; +#ifdef WITH_IO_URING + uring = nullptr; +#endif +#ifdef MILVUS_COMMON_WITH_LIBAIO + aio = nullptr; +#endif + owner_.reset(); +} + +void +IOContextHandle::ReleaseNoThrow() noexcept { + if (!HasContext()) { + ClearNoRelease(); + return; + } + + auto owner = owner_; + if (owner == nullptr) { + LOG_WARN("IOContextHandle drops context without owner for backend {}", static_cast(backend)); + ClearNoRelease(); + return; + } + + try { + owner->Release(std::move(*this), 
IOContextReleaseDisposition::Clean); + } catch (const std::exception& e) { + LOG_ERROR("IOContextHandle failed to release context: {}", e.what()); + ClearNoRelease(); + } catch (...) { + LOG_ERROR("IOContextHandle failed to release context: unknown exception"); + ClearNoRelease(); + } +} + +#ifdef WITH_IO_URING +bool +IOContextPool::TryInitUring(const IOContextPoolConfig& cfg, const std::shared_ptr& io_pool) { + if (!UringContextPool::InitGlobalUringPoolWithValidation(cfg.num_ctx, cfg.max_events)) { + return false; + } + + auto pool = UringContextPool::GetGlobalUringPoolDirect(); + if (pool == nullptr || !pool->IsUsable()) { + LOG_ERROR("Global UringContextPool is unavailable after initialization"); + UringContextPool::ResetGlobalForTest(); + return false; + } + + io_pool->uring_pool_ = pool; + io_pool->backend_ = IOBackend::IO_URING; + io_pool->num_ctx_ = cfg.num_ctx; + io_pool->max_events_per_ctx_ = pool->max_entries_per_ctx(); + return true; +} +#endif + +#ifdef MILVUS_COMMON_WITH_LIBAIO +bool +IOContextPool::TryInitAio(const IOContextPoolConfig& cfg, const std::shared_ptr& io_pool) { + if (!AioContextPool::InitGlobalAioPoolWithValidation(cfg.num_ctx, cfg.max_events)) { + return false; + } + + auto pool = AioContextPool::GetGlobalAioPoolDirect(); + if (pool == nullptr || !pool->IsUsable()) { + LOG_ERROR("Global AioContextPool is unavailable after initialization"); + AioContextPool::ResetGlobalForTest(); + return false; + } + + io_pool->aio_pool_ = pool; + io_pool->backend_ = IOBackend::AIO; + io_pool->num_ctx_ = cfg.num_ctx; + io_pool->max_events_per_ctx_ = pool->max_events_per_ctx(); + return true; +} +#endif + +bool +IOContextPool::InitGlobal(const IOContextPoolConfig& cfg) { + if (cfg.num_ctx == 0) { + LOG_ERROR("num_ctx should be bigger than 0"); + return false; + } + + if (cfg.max_events == 0) { + LOG_ERROR("max_events should be bigger than 0"); + return false; + } + + std::scoped_lock lk(g_io_pool_mutex); + if (g_io_pool != nullptr && 
g_io_pool->IsInitialized()) { + if (cfg.max_events != g_io_pool->MaxEventsPerCtx()) { + LOG_ERROR("Global IOContextPool already initialized with max_events={}, requested={}", + g_io_pool->MaxEventsPerCtx(), cfg.max_events); + return false; + } + if (cfg.num_ctx != g_io_pool->num_ctx_) { + LOG_ERROR("Global IOContextPool already initialized with num_ctx={}, requested={}", g_io_pool->num_ctx_, + cfg.num_ctx); + return false; + } + LOG_WARN("Global IOContextPool has already been initialized with backend: {}", g_io_pool->BackendName()); + return true; + } + + auto io_pool = std::shared_ptr(new IOContextPool()); + +#ifdef WITH_IO_URING + if (TryInitUring(cfg, io_pool)) { + g_io_pool = io_pool; + LOG_INFO("Global IOContextPool initialized with backend io_uring"); + return true; + } +#ifdef MILVUS_COMMON_WITH_LIBAIO + LOG_WARN("io_uring backend initialization failed, fallback to aio backend"); + if (TryInitAio(cfg, io_pool)) { + g_io_pool = io_pool; + LOG_WARN("Global IOContextPool fallback initialized with backend aio"); + return true; + } +#endif +#elif defined(MILVUS_COMMON_WITH_LIBAIO) + if (TryInitAio(cfg, io_pool)) { + g_io_pool = io_pool; + LOG_INFO("Global IOContextPool initialized with backend aio"); + return true; + } +#endif + + LOG_ERROR("Failed to initialize IOContextPool with any backend"); + return false; +} + +std::shared_ptr +IOContextPool::GetGlobal() { + std::scoped_lock lk(g_io_pool_mutex); + return g_io_pool; +} + +std::shared_ptr +IOContextPool::GetGlobalOrInit(const IOContextPoolConfig& cfg) { + if (!InitGlobal(cfg)) { + return nullptr; + } + + std::scoped_lock lk(g_io_pool_mutex); + return g_io_pool; +} + +void +IOContextPool::ResetGlobalForTest() { + std::scoped_lock lk(g_io_pool_mutex); + g_io_pool.reset(); +#ifdef MILVUS_COMMON_WITH_LIBAIO + AioContextPool::ResetGlobalForTest(); +#endif +#ifdef WITH_IO_URING + UringContextPool::ResetGlobalForTest(); +#endif +} + +IOBackend +IOContextPool::Backend() const { + return backend_; +} + +std::string 
+IOContextPool::BackendName() const { + switch (backend_) { + case IOBackend::IO_URING: + return "io_uring"; + case IOBackend::AIO: + return "aio"; + default: + return "unknown"; + } +} + +bool +IOContextPool::IsInitialized() const { + return backend_ != IOBackend::UNKNOWN; +} + +size_t +IOContextPool::MaxEventsPerCtx() const { + return max_events_per_ctx_; +} + +IOContextHandle +IOContextPool::Pop() { + IOContextHandle handle; + handle.backend = backend_; + switch (backend_) { +#ifdef WITH_IO_URING + case IOBackend::IO_URING: + handle.uring = PopUring(); + break; +#endif +#ifdef MILVUS_COMMON_WITH_LIBAIO + case IOBackend::AIO: + handle.aio = PopAio(); + break; +#endif + default: + break; + } + if (handle.HasContext()) { + handle.owner_ = shared_from_this(); + } else { + handle.backend = IOBackend::UNKNOWN; + } + return handle; +} + +bool +IOContextPool::Push(IOContextHandle&& handle) { + return Release(std::move(handle), IOContextReleaseDisposition::Clean); +} + +bool +IOContextPool::Reset(IOContextHandle&& handle) { + return Release(std::move(handle), IOContextReleaseDisposition::Dirty); +} + +bool +IOContextPool::Release(IOContextHandle&& handle, IOContextReleaseDisposition disposition) { + if (!handle.HasContext()) { + handle.ClearNoRelease(); + return true; + } + if (handle.owner_.get() != this) { + LOG_WARN("IOContextPool rejects release for handle owned by a different pool"); + return false; + } + if (handle.backend != backend_) { + LOG_WARN("IOContextPool rejects release for backend {} while active backend is {}", + static_cast(handle.backend), static_cast(backend_)); + handle.ClearNoRelease(); + return false; + } + + bool released = false; + switch (handle.backend) { +#ifdef WITH_IO_URING + case IOBackend::IO_URING: + if (disposition == IOContextReleaseDisposition::Clean) { + released = PushUring(handle.uring); + } else if (disposition == IOContextReleaseDisposition::Dirty) { + released = ResetUring(handle.uring); + } else { + released = 
RetireUring(handle.uring); + } + break; +#endif +#ifdef MILVUS_COMMON_WITH_LIBAIO + case IOBackend::AIO: + if (disposition == IOContextReleaseDisposition::Clean) { + released = PushAio(handle.aio); + } else if (disposition == IOContextReleaseDisposition::Dirty) { + released = ResetAio(handle.aio); + } else { + released = RetireAio(handle.aio); + } + break; +#endif + default: + break; + } + handle.ClearNoRelease(); + return released; +} + +#ifdef WITH_IO_URING +struct io_uring* +IOContextPool::PopUring() { + if (uring_pool_ == nullptr) { + return nullptr; + } + return uring_pool_->pop(); +} + +bool +IOContextPool::PushUring(struct io_uring* ring) { + if (uring_pool_ != nullptr) { + return uring_pool_->push(ring); + } + return false; +} + +bool +IOContextPool::ResetUring(struct io_uring* ring) { + return uring_pool_ != nullptr && uring_pool_->ResetCheckedOut(ring); +} + +bool +IOContextPool::RetireUring(struct io_uring* ring) { + return uring_pool_ != nullptr && uring_pool_->RetireCheckedOut(ring); +} + +std::shared_ptr +IOContextPool::GetUringPoolForLegacy() const { + return uring_pool_; +} +#endif + +#ifdef MILVUS_COMMON_WITH_LIBAIO +io_context_t +IOContextPool::PopAio() { + if (aio_pool_ == nullptr) { + return nullptr; + } + return aio_pool_->pop(); +} + +bool +IOContextPool::PushAio(io_context_t ctx) { + if (aio_pool_ != nullptr) { + return aio_pool_->push(ctx); + } + return false; +} + +bool +IOContextPool::ResetAio(io_context_t ctx) { + return aio_pool_ != nullptr && aio_pool_->ResetCheckedOut(ctx); +} + +bool +IOContextPool::RetireAio(io_context_t ctx) { + return aio_pool_ != nullptr && aio_pool_->RetireCheckedOut(ctx); +} + +std::shared_ptr +IOContextPool::GetAioPoolForLegacy() const { + return aio_pool_; +} +#endif diff --git a/src/knowhere/io_reader.cc b/src/knowhere/io_reader.cc new file mode 100644 index 0000000..5d26ca4 --- /dev/null +++ b/src/knowhere/io_reader.cc @@ -0,0 +1,645 @@ +#include "knowhere/io_reader.h" + +#include "io_reader_internal.h" + 
+#include "syncpoint/sync_point.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "log/Log.h" + +namespace { +std::future +MakeReadyFuture(bool value) { + return std::async(std::launch::deferred, [value] { return value; }); +} + +class IOContextHandleGuard { + public: + IOContextHandleGuard(std::shared_ptr pool, IOContextHandle handle) + : pool_(std::move(pool)), handle_(std::move(handle)) { + } + + IOContextHandleGuard(const IOContextHandleGuard&) = delete; + IOContextHandleGuard& + operator=(const IOContextHandleGuard&) = delete; + + IOContextHandleGuard(IOContextHandleGuard&& other) noexcept + : pool_(std::move(other.pool_)), handle_(std::move(other.handle_)), active_(other.active_) { + other.active_ = false; + } + + ~IOContextHandleGuard() { + Reset(); + } + + IOContextHandle& + Handle() { + return handle_; + } + + void + Reset() { + if (active_ && pool_ != nullptr) { + pool_->Push(std::move(handle_)); + active_ = false; + } + } + +#ifdef MILVUS_COMMON_WITH_LIBAIO + void + ResetAio() { + if (!active_ || pool_ == nullptr || handle_.backend != IOBackend::AIO || handle_.aio == nullptr) { + return; + } + active_ = false; + pool_->Reset(std::move(handle_)); + } +#endif + +#ifdef WITH_IO_URING + void + ResetUring() { + if (!active_ || pool_ == nullptr || handle_.backend != IOBackend::IO_URING || handle_.uring == nullptr) { + return; + } + active_ = false; + pool_->Reset(std::move(handle_)); + } +#endif + + private: + std::shared_ptr pool_; + IOContextHandle handle_; + bool active_ = true; +}; + +constexpr size_t kNumRetries = 10; +constexpr size_t kDirectIoAlignment = 512; + +struct BatchWaitResult { + size_t completed = 0; + bool complete = false; + bool ok = false; +}; + +bool +IsAlignedForDirectIo(const void* ptr, size_t alignment) { + return reinterpret_cast(ptr) % alignment == 0; +} + +void +ValidateDirectIoAlignment(int fd, size_t size, const std::vector& buffers, + const std::vector& offsets) { +#ifdef O_DIRECT 
+ const auto flags = fcntl(fd, F_GETFL); + if (flags < 0 || (flags & O_DIRECT) == 0) { + return; + } + + if (size % kDirectIoAlignment != 0) { + throw std::invalid_argument("O_DIRECT read size must be 512-byte aligned"); + } + for (size_t i = 0; i < buffers.size(); ++i) { + if (!IsAlignedForDirectIo(buffers[i], kDirectIoAlignment)) { + throw std::invalid_argument("O_DIRECT read buffer address must be 512-byte aligned"); + } + if (offsets[i] % kDirectIoAlignment != 0) { + throw std::invalid_argument("O_DIRECT read offset must be 512-byte aligned"); + } + } +#else + (void)fd; + (void)size; + (void)buffers; + (void)offsets; +#endif +} + +#ifdef MILVUS_COMMON_WITH_LIBAIO +size_t +SubmitAioBatch(io_context_t ctx, int fd, const std::vector& buffers, size_t size, + const std::vector& offsets, size_t start, size_t batch, std::vector& cbs) { + cbs.resize(batch); + std::vector cb_ptrs(batch); + for (size_t i = 0; i < batch; ++i) { + const auto idx = start + i; + io_prep_pread(&cbs[i], fd, reinterpret_cast(buffers[idx]), size, offsets[idx]); + cb_ptrs[i] = &cbs[i]; + } + + size_t submitted_total = 0; + size_t retry = 0; + while (submitted_total < batch) { + long submit_count = static_cast(batch - submitted_total); +#ifdef ENABLE_SYNCPOINT + size_t forced_submit_limit = 0; + TEST_SYNC_POINT_CALLBACK("IOReader::SubmitAioBatch:BeforeSubmit", &forced_submit_limit); + if (forced_submit_limit > 0) { + submit_count = static_cast(std::min(static_cast(submit_count), forced_submit_limit)); + } +#endif + const auto submitted = io_submit(ctx, submit_count, cb_ptrs.data() + submitted_total); + if (submitted < 0) { + if (-submitted == EINTR) { + if (!knowhere_internal::ShouldRetryInterruptedSyscall(retry, kNumRetries)) { + break; + } + continue; + } + break; + } + if (submitted == 0) { + if (++retry > kNumRetries) { + break; + } + continue; + } + submitted_total += static_cast(submitted); + retry = 0; + } + return submitted_total; +} + +BatchWaitResult +WaitAioBatch(io_context_t ctx, 
size_t size, const std::vector& cbs, size_t submitted_total) { + if (submitted_total == 0) { + return {0, true, true}; + } + + std::unordered_set expected_cbs; + for (size_t i = 0; i < submitted_total; ++i) { + expected_cbs.insert(&cbs[i]); + } + + std::vector events(submitted_total); + BatchWaitResult result; + result.ok = true; + size_t retry = 0; + while (result.completed < submitted_total) { + size_t max_events = submitted_total - result.completed; +#ifdef ENABLE_SYNCPOINT + size_t forced_max_events = 0; + TEST_SYNC_POINT_CALLBACK("IOReader::WaitAioBatch:BeforeGetEvents", &forced_max_events); + if (forced_max_events > 0) { + max_events = std::min(max_events, forced_max_events); + } +#endif + const auto completed = + io_getevents(ctx, 1, static_cast(max_events), events.data() + result.completed, nullptr); + if (completed < 0) { + if (-completed == EINTR) { + if (!knowhere_internal::ShouldRetryInterruptedSyscall(retry, kNumRetries)) { + break; + } + continue; + } + break; + } + if (completed == 0) { + if (++retry > kNumRetries) { + break; + } + continue; + } + for (size_t i = result.completed; i < result.completed + static_cast(completed); ++i) { + if (expected_cbs.erase(static_cast(events[i].obj)) == 0 || events[i].res < 0 || + static_cast(events[i].res) != size) { + result.ok = false; + } + } + result.completed += static_cast(completed); + retry = 0; + } + result.complete = result.completed == submitted_total; + result.ok = result.ok && result.complete; +#ifdef ENABLE_SYNCPOINT + bool force_bad_cleanup = false; + TEST_SYNC_POINT_CALLBACK("IOReader::WaitAioBatch:BeforeReturn", &force_bad_cleanup); + if (force_bad_cleanup) { + result.complete = false; + result.ok = false; + } +#endif + return result; +} + +class AioReadState { + public: + AioReadState(IOContextHandleGuard guard, size_t size, size_t first_submitted, + std::shared_ptr> first_cbs) + : guard_(std::move(guard)), + size_(size), + first_remaining_(first_submitted), + first_cbs_(std::move(first_cbs)) { + 
} + + AioReadState(const AioReadState&) = delete; + AioReadState& + operator=(const AioReadState&) = delete; + + AioReadState(AioReadState&& other) noexcept + : guard_(std::move(other.guard_)), + size_(other.size_), + first_remaining_(other.first_remaining_), + first_cbs_(std::move(other.first_cbs_)) { + other.first_remaining_ = 0; + } + + ~AioReadState() { + if (first_remaining_ != 0) { + const auto result = CollectFirst(); + if (!result.complete) { + ResetContext(); + } + } + } + + io_context_t + Context() { + return guard_.Handle().aio; + } + + BatchWaitResult + CollectFirst() { + if (first_remaining_ == 0) { + return {0, true, true}; + } + auto result = WaitAioBatch(Context(), size_, *first_cbs_, first_remaining_); + first_remaining_ -= result.completed; + result.complete = first_remaining_ == 0; + result.ok = result.ok && result.complete; + return result; + } + + void + ResetContext() { + first_remaining_ = 0; + guard_.ResetAio(); + } + + private: + IOContextHandleGuard guard_; + size_t size_ = 0; + size_t first_remaining_ = 0; + std::shared_ptr> first_cbs_; +}; + +std::future +ReadAioAsync(int fd, size_t size, std::vector&& buffers, std::vector&& offsets, + std::shared_ptr pool) { + auto handle = pool->Pop(); + if (handle.aio == nullptr) { + return MakeReadyFuture(false); + } + + auto ctx = handle.aio; + IOContextHandleGuard guard(pool, std::move(handle)); + const size_t max_batch = pool->MaxEventsPerCtx(); + if (max_batch == 0) { + return MakeReadyFuture(false); + } + + const size_t first_batch = std::min(max_batch, buffers.size()); + auto first_cbs = std::make_shared>(); + const size_t first_submitted = SubmitAioBatch(ctx, fd, buffers, size, offsets, 0, first_batch, *first_cbs); + if (first_submitted == 0) { + return MakeReadyFuture(false); + } + + return std::async( + std::launch::deferred, + [fd, size, buffers = std::move(buffers), offsets = std::move(offsets), max_batch, first_batch, first_submitted, + state = AioReadState(std::move(guard), size, 
first_submitted, std::move(first_cbs))]() mutable -> bool { + auto ctx = state.Context(); + const auto first_result = state.CollectFirst(); + if (!first_result.complete || !first_result.ok || first_submitted != first_batch) { + if (!first_result.complete || !first_result.ok) { + state.ResetContext(); + } + return false; + } + + size_t processed = first_batch; + while (processed < buffers.size()) { + const size_t batch = std::min(max_batch, buffers.size() - processed); + std::vector cbs; + const size_t submitted = SubmitAioBatch(ctx, fd, buffers, size, offsets, processed, batch, cbs); + if (submitted != batch) { + const auto cleanup = WaitAioBatch(ctx, size, cbs, submitted); + if (!cleanup.complete || !cleanup.ok) { + state.ResetContext(); + } + return false; + } + const auto result = WaitAioBatch(ctx, size, cbs, submitted); + if (!result.complete || !result.ok) { + state.ResetContext(); + return false; + } + processed += batch; + } + return true; + }); +} +#endif + +#ifdef WITH_IO_URING +int +SubmitUring(io_uring* ring) { +#ifdef ENABLE_SYNCPOINT + int forced_ret = 0; + TEST_SYNC_POINT_CALLBACK("IOReader::SubmitUring:Before", &forced_ret); + if (forced_ret != 0) { + return forced_ret; + } +#endif + return io_uring_submit(ring); +} + +size_t +PrepareUringBatch(io_uring* ring, int fd, const std::vector& buffers, size_t size, + const std::vector& offsets, size_t start, size_t max_batch) { + size_t batch = 0; + for (; batch < max_batch && start + batch < buffers.size(); ++batch) { + auto* sqe = io_uring_get_sqe(ring); + if (sqe == nullptr) { + break; + } + const auto idx = start + batch; + io_uring_prep_read(sqe, fd, reinterpret_cast(buffers[idx]), size, offsets[idx]); + sqe->user_data = idx; + } + return batch; +} + +BatchWaitResult +WaitUringBatch(io_uring* ring, size_t size, size_t submitted_total, size_t start) { + BatchWaitResult result; + result.ok = true; + std::unordered_set expected_ids; + for (size_t i = 0; i < submitted_total; ++i) { + 
expected_ids.insert(static_cast(start + i)); + } + size_t retry = 0; + while (result.completed < submitted_total) { + io_uring_cqe* cqe = nullptr; + const auto wait_result = io_uring_wait_cqe(ring, &cqe); + if (wait_result < 0) { + if (-wait_result == EINTR) { + if (!knowhere_internal::ShouldRetryInterruptedSyscall(retry, kNumRetries)) { + break; + } + continue; + } + break; + } + if (cqe == nullptr) { + break; + } + if (expected_ids.erase(cqe->user_data) == 0 || cqe->res < 0 || static_cast(cqe->res) != size) { + result.ok = false; + } + io_uring_cqe_seen(ring, cqe); + ++result.completed; + } + result.complete = result.completed == submitted_total; + result.ok = result.ok && result.complete; + return result; +} + +class UringReadState { + public: + UringReadState(IOContextHandleGuard guard, size_t size, size_t first_prepared, size_t first_submitted) + : guard_(std::move(guard)), + size_(size), + first_prepared_(first_prepared), + first_remaining_(first_submitted) { + } + + UringReadState(const UringReadState&) = delete; + UringReadState& + operator=(const UringReadState&) = delete; + + UringReadState(UringReadState&& other) noexcept + : guard_(std::move(other.guard_)), + size_(other.size_), + first_prepared_(other.first_prepared_), + first_remaining_(other.first_remaining_) { + other.first_prepared_ = 0; + other.first_remaining_ = 0; + } + + ~UringReadState() { + if (first_remaining_ == 0) { + return; + } + + const auto result = CollectFirst(); + if (!result.complete || !result.ok || first_prepared_ != result.completed) { + ResetRing(); + } + } + + io_uring* + Ring() { + return guard_.Handle().uring; + } + + void + ResetRing() { + first_remaining_ = 0; + guard_.ResetUring(); + } + + BatchWaitResult + CollectFirst() { + if (first_remaining_ == 0) { + return {0, true, true}; + } + const auto waiting = first_remaining_; + auto result = WaitUringBatch(Ring(), size_, first_remaining_, 0); + first_remaining_ -= result.completed; + result.complete = first_remaining_ == 0; 
+ result.ok = result.ok && result.complete; + result.completed = waiting - first_remaining_; + return result; + } + + private: + IOContextHandleGuard guard_; + size_t size_ = 0; + size_t first_prepared_ = 0; + size_t first_remaining_ = 0; +}; + +std::future +ReadUringAsync(int fd, size_t size, std::vector&& buffers, std::vector&& offsets, + std::shared_ptr pool) { + auto handle = pool->Pop(); + if (handle.uring == nullptr) { + return MakeReadyFuture(false); + } + + auto* ring = handle.uring; + IOContextHandleGuard guard(pool, std::move(handle)); + const size_t max_batch = pool->MaxEventsPerCtx(); + if (max_batch == 0) { + return MakeReadyFuture(false); + } + const size_t first_batch = PrepareUringBatch(ring, fd, buffers, size, offsets, 0, max_batch); + if (first_batch == 0) { + guard.ResetUring(); + return MakeReadyFuture(false); + } + + const auto submitted = SubmitUring(ring); + if (submitted <= 0) { + guard.ResetUring(); + return MakeReadyFuture(false); + } + const auto first_submitted = static_cast(submitted); + + return std::async( + std::launch::deferred, + [fd, size, buffers = std::move(buffers), offsets = std::move(offsets), first_batch, first_submitted, + max_batch, state = UringReadState(std::move(guard), size, first_batch, first_submitted)]() mutable -> bool { + auto* ring = state.Ring(); + const auto first_result = state.CollectFirst(); + if (!first_result.complete || first_submitted != first_batch) { + state.ResetRing(); + return false; + } + if (!first_result.ok) { + state.ResetRing(); + return false; + } + + size_t processed = first_batch; + while (processed < buffers.size()) { + const size_t batch = PrepareUringBatch(ring, fd, buffers, size, offsets, processed, max_batch); + if (batch == 0) { + state.ResetRing(); + return false; + } + const auto submitted = SubmitUring(ring); + if (submitted <= 0) { + state.ResetRing(); + return false; + } + const auto submitted_count = static_cast(submitted); + const auto result = WaitUringBatch(ring, size, 
submitted_count, processed); + if (!result.complete || submitted_count != batch) { + state.ResetRing(); + return false; + } + if (!result.ok) { + state.ResetRing(); + return false; + } + processed += batch; + } + return true; + }); +} +#endif +} // namespace + +IOReader::IOReader() : io_pool_(IOContextPool::GetGlobal()) { +} + +IOReader::IOReader(int fd) : fd_(fd), io_pool_(IOContextPool::GetGlobal()) { +} + +IOReader::IOReader(int fd, std::shared_ptr io_pool) : fd_(fd), io_pool_(std::move(io_pool)) { +} + +IOReader::IOReader(std::shared_ptr io_pool) : io_pool_(std::move(io_pool)) { +} + +bool +IOReader::Read(IOReaderSpan buf, size_t size, IOReaderSpan offsets) const { + if (buf.size() != offsets.size()) { + throw std::invalid_argument("buffers and offsets must have same size"); + } + + std::vector buffers(buf.size()); + std::vector read_offsets(offsets.size()); + + for (size_t i = 0; i < buf.size(); ++i) { + if (buf[i] == nullptr) { + throw std::invalid_argument("buffer pointer should not be null"); + } + if (offsets[i] < 0) { + throw std::invalid_argument("offset should be non-negative"); + } + buffers[i] = buf[i]; + read_offsets[i] = static_cast(offsets[i]); + } + + return ReadAsync(std::move(buffers), size, std::move(read_offsets)).get(); +} + +std::future +IOReader::ReadAsync(std::vector&& buffers, size_t size, std::vector&& offsets) const { + if (size == 0) { + throw std::invalid_argument("size should be greater than 0"); + } + if (buffers.size() != offsets.size()) { + throw std::invalid_argument("buffers and offsets must have same size"); + } + if (buffers.empty()) { + return MakeReadyFuture(true); + } + if (fd_ < 0) { + throw std::invalid_argument("invalid file descriptor"); + } + + for (const auto* buffer : buffers) { + if (buffer == nullptr) { + throw std::invalid_argument("buffer pointer should not be null"); + } + } + ValidateDirectIoAlignment(fd_, size, buffers, offsets); + + auto pool = io_pool_ ? 
io_pool_ : IOContextPool::GetGlobal(); + if (pool == nullptr || !pool->IsInitialized()) { + throw std::runtime_error("IOContextPool is not initialized"); + } + + switch (pool->Backend()) { +#ifdef MILVUS_COMMON_WITH_LIBAIO + case IOBackend::AIO: + return ReadAioAsync(fd_, size, std::move(buffers), std::move(offsets), std::move(pool)); +#endif +#ifdef WITH_IO_URING + case IOBackend::IO_URING: + return ReadUringAsync(fd_, size, std::move(buffers), std::move(offsets), std::move(pool)); +#endif + default: + return MakeReadyFuture(false); + } +} + +IOBackend +IOReader::Backend() const { + return io_pool_ ? io_pool_->Backend() : IOBackend::UNKNOWN; +} + +std::string +IOReader::BackendName() const { + return io_pool_ ? io_pool_->BackendName() : "unknown"; +} + +bool +IOReader::IsReady() const { + return fd_ >= 0 && io_pool_ != nullptr && io_pool_->IsInitialized(); +} diff --git a/src/knowhere/io_reader_internal.h b/src/knowhere/io_reader_internal.h new file mode 100644 index 0000000..4f199d7 --- /dev/null +++ b/src/knowhere/io_reader_internal.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace knowhere_internal { +inline bool +ShouldRetryInterruptedSyscall(size_t& retry, size_t max_retries) { + return ++retry <= max_retries; +} +} // namespace knowhere_internal diff --git a/src/knowhere/uring_context_pool.cc b/src/knowhere/uring_context_pool.cc new file mode 100644 index 0000000..97414bb --- /dev/null +++ b/src/knowhere/uring_context_pool.cc @@ -0,0 +1,279 @@ +#ifdef WITH_IO_URING + +#include "knowhere/uring_context_pool.h" + +#include +#include +#include + +#include "knowhere/io_context_pool.h" +#include "syncpoint/sync_point.h" + +namespace { +std::shared_ptr g_uring_pool; +} + +size_t UringContextPool::global_uring_pool_size = 0; +size_t UringContextPool::global_uring_max_entries = 0; +std::mutex UringContextPool::global_uring_pool_mut; + +UringContextPool::UringContextPool(size_t num_ctx, size_t max_entries) : num_ctx_(num_ctx), max_entries_(max_entries) { + 
ring_bak_.reserve(num_ctx_); + + for (size_t i = 0; i < num_ctx_; ++i) { + auto* ring = new io_uring(); + std::memset(ring, 0, sizeof(io_uring)); + int ret = 0; +#ifdef ENABLE_SYNCPOINT + TEST_SYNC_POINT_CALLBACK("UringContextPool::Ctor:BeforeInit", &ret); +#endif + if (ret == 0) { + ret = io_uring_queue_init(static_cast(max_entries_), ring, 0); + } + if (ret < 0) { + LOG_ERROR("io_uring_queue_init failed with ret={}, errno={}: {}", ret, -ret, ::strerror(-ret)); + delete ring; + continue; + } + + ring_q_.push(ring); + ring_bak_.push_back(ring); + owned_rings_.insert(ring); + } + + if (ring_bak_.size() != num_ctx_) { + state_ = State::Unusable; + LOG_ERROR("UringContextPool initialization failed: created {} of {} requested contexts", ring_bak_.size(), + num_ctx_); + } +} + +bool +UringContextPool::InitGlobalUringPoolWithValidation(size_t num_ctx, size_t max_entries) { + if (num_ctx == 0) { + LOG_ERROR("num_ctx should be bigger than 0"); + return false; + } + + if (max_entries == 0 || max_entries > default_uring_max_entries) { + LOG_ERROR("max_entries {} should be in range (0, {}]", max_entries, default_uring_max_entries); + return false; + } + + std::scoped_lock lk(global_uring_pool_mut); + if (global_uring_pool_size == 0) { + global_uring_pool_size = num_ctx; + global_uring_max_entries = max_entries; + return true; + } + + if (global_uring_pool_size != num_ctx || global_uring_max_entries != max_entries) { + LOG_ERROR( + "Global UringContextPool already initialized with context num: {}, max_entries: {} (requested {}, {})", + global_uring_pool_size, global_uring_max_entries, num_ctx, max_entries); + return false; + } + + LOG_WARN("Global UringContextPool has already been initialized with context num: {}", global_uring_pool_size); + return true; +} + +bool +UringContextPool::ResetCheckedOut(struct io_uring* ring) { + if (ring == nullptr) { + LOG_WARN("UringContextPool reset gets null ring"); + return false; + } + + { + std::scoped_lock lk(ring_mtx_); + if 
(owned_rings_.find(ring) == owned_rings_.end()) { + LOG_WARN("UringContextPool rejects reset for unknown ring: {}", static_cast(ring)); + return false; + } + if (checked_out_rings_.find(ring) == checked_out_rings_.end()) { + LOG_WARN("UringContextPool rejects reset for ring not checked out: {}", static_cast(ring)); + return false; + } + } + + io_uring_queue_exit(ring); + std::memset(ring, 0, sizeof(io_uring)); + int ret = 0; +#ifdef ENABLE_SYNCPOINT + TEST_SYNC_POINT_CALLBACK("UringContextPool::ResetCheckedOut:BeforeInit", &ret); +#endif + if (ret == 0) { + ret = io_uring_queue_init(static_cast(max_entries_), ring, 0); + } + if (ret == 0) { + bool released = false; + bool should_destroy = false; + { + std::scoped_lock lk(ring_mtx_); + if (state_ == State::Healthy) { + try { + ring_q_.push(ring); + checked_out_rings_.erase(ring); + released = true; + } catch (const std::exception& e) { + LOG_ERROR("UringContextPool failed to requeue reset ring {}: {}", static_cast(ring), + e.what()); + RemoveTrackedRingLocked(ring); + MarkUnusableLocked(); + should_destroy = true; + } catch (...) 
{ + LOG_ERROR("UringContextPool failed to requeue reset ring {}: unknown exception", + static_cast(ring)); + RemoveTrackedRingLocked(ring); + MarkUnusableLocked(); + should_destroy = true; + } + } else { + RemoveTrackedRingLocked(ring); + should_destroy = true; + } + } + if (released) { + ring_cv_.notify_one(); + return true; + } + if (should_destroy) { + DestroyRing(ring); + } + ring_cv_.notify_all(); + return false; + } + + LOG_ERROR("io_uring_queue_init failed while resetting ring with ret={}, errno={}: {}", ret, -ret, ::strerror(-ret)); + { + std::scoped_lock lk(ring_mtx_); + RemoveTrackedRingLocked(ring); + MarkUnusableLocked(); + } + ring_cv_.notify_all(); + delete ring; + return false; +} + +bool +UringContextPool::RetireCheckedOut(struct io_uring* ring) { + if (ring == nullptr) { + LOG_WARN("UringContextPool retire gets null ring"); + return false; + } + + { + std::scoped_lock lk(ring_mtx_); + if (owned_rings_.find(ring) == owned_rings_.end()) { + LOG_WARN("UringContextPool rejects retire for unknown ring: {}", static_cast(ring)); + return false; + } + if (checked_out_rings_.find(ring) == checked_out_rings_.end()) { + LOG_WARN("UringContextPool rejects retire for ring not checked out: {}", static_cast(ring)); + return false; + } + + RemoveTrackedRingLocked(ring); + MarkUnusableLocked(); + } + + DestroyRing(ring); + ring_cv_.notify_all(); + return true; +} + +void +UringContextPool::Shutdown() { + { + std::scoped_lock lk(ring_mtx_); + if (state_ == State::Stopped) { + return; + } + state_ = State::Stopped; + } + ring_cv_.notify_all(); +} + +std::shared_ptr +UringContextPool::GetGlobalUringPoolDirect() { + std::scoped_lock lk(global_uring_pool_mut); + if (global_uring_pool_size == 0) { + IOContextPoolConfig cfg; + global_uring_pool_size = cfg.num_ctx; + global_uring_max_entries = cfg.max_events; + LOG_WARN("Global UringContextPool has not been initialized yet, init it now with context num: {}", + global_uring_pool_size); + } + + if (g_uring_pool == nullptr) { 
+ g_uring_pool = + std::shared_ptr(new UringContextPool(global_uring_pool_size, global_uring_max_entries)); + } + return g_uring_pool; +} + +bool +UringContextPool::InitGlobalUringPool(size_t num_ctx, size_t max_entries) { + IOContextPoolConfig cfg; + cfg.num_ctx = num_ctx; + cfg.max_events = max_entries; + + if (!IOContextPool::InitGlobal(cfg)) { + return false; + } + + auto io_pool = IOContextPool::GetGlobal(); + if (io_pool == nullptr || !io_pool->IsInitialized()) { + return false; + } + if (io_pool->Backend() != IOBackend::IO_URING) { + LOG_ERROR("Global IOContextPool backend is {}, legacy io_uring API is unavailable", io_pool->BackendName()); + return false; + } + return true; +} + +std::shared_ptr +UringContextPool::GetGlobalUringPool() { + auto io_pool = IOContextPool::GetGlobal(); + if (io_pool == nullptr || !io_pool->IsInitialized()) { + return nullptr; + } + if (io_pool->Backend() != IOBackend::IO_URING) { + return nullptr; + } + return io_pool->GetUringPoolForLegacy(); +} + +void +UringContextPool::ResetGlobalForTest() { + std::scoped_lock lk(global_uring_pool_mut); + g_uring_pool.reset(); + global_uring_pool_size = 0; + global_uring_max_entries = 0; +} + +UringContextPool::~UringContextPool() { + Shutdown(); + + std::unordered_set checked_out; + { + std::scoped_lock lk(ring_mtx_); + checked_out = checked_out_rings_; + } + if (!checked_out.empty()) { + LOG_WARN("UringContextPool shutdown with {} checked-out rings still not returned", checked_out.size()); + } + + for (auto* ring : ring_bak_) { + if (checked_out.find(ring) == checked_out.end()) { + DestroyRing(ring); + } + } + ring_bak_.clear(); + owned_rings_.clear(); + checked_out_rings_.clear(); +} + +#endif // WITH_IO_URING diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d23b993..1f070ac 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,6 +18,7 @@ set(ALL_TEST_FILES init_gtest.cpp TracerTest.cpp StreamTest.cpp + IOContextPoolTest.cpp ) add_executable(all_tests diff --git 
a/test/IOContextPoolTest.cpp b/test/IOContextPoolTest.cpp new file mode 100644 index 0000000..4815a85 --- /dev/null +++ b/test/IOContextPoolTest.cpp @@ -0,0 +1,1601 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "knowhere/io_completion_reader.h" +#include "knowhere/io_context_pool.h" +#include "knowhere/io_reader.h" +#include "syncpoint/sync_point.h" +#include "../src/knowhere/io_reader_internal.h" +#ifdef MILVUS_COMMON_WITH_LIBAIO +#include "knowhere/aio_context_pool.h" +#endif +#ifdef WITH_IO_URING +#include "knowhere/uring_context_pool.h" +#endif + +class IOContextPoolTestFixture : public ::testing::Test { + protected: + void + SetUp() override { + IOContextPool::ResetGlobalForTest(); + } + + void + TearDown() override { + IOContextPool::ResetGlobalForTest(); + } +}; + +#ifdef WITH_IO_URING +#ifdef MILVUS_COMMON_WITH_LIBAIO +TEST_F(IOContextPoolTestFixture, InitShouldFallbackToAioWhenUringUnavailable) { + pid_t pid = fork(); + ASSERT_GE(pid, 0); + if (pid == 0) { + struct rlimit lim; + lim.rlim_cur = 3; + lim.rlim_max = 3; + if (setrlimit(RLIMIT_NOFILE, &lim) != 0) { + _exit(20); + } + + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + const bool ok = IOContextPool::InitGlobal(cfg); + if (!ok) { + _exit(21); + } + + auto pool = IOContextPool::GetGlobal(); + if (pool == nullptr || pool->Backend() != IOBackend::AIO) { + _exit(22); + } + _exit(0); + } + + int status = 0; + ASSERT_EQ(waitpid(pid, &status, 0), pid); + ASSERT_TRUE(WIFEXITED(status)); + ASSERT_EQ(WEXITSTATUS(status), 0); +} +#endif + +#if defined(MILVUS_COMMON_WITH_LIBAIO) && defined(ENABLE_SYNCPOINT) +TEST_F(IOContextPoolTestFixture, PartialUringInitShouldFallbackToAio) { + auto* sync_point = milvus::SyncPoint::GetInstance(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + int init_calls = 0; + 
sync_point->SetCallBack("UringContextPool::Ctor:BeforeInit", [&](void* arg) { + if (init_calls++ == 1) { + *static_cast(arg) = -EMFILE; + } + }); + sync_point->EnableProcessing(); + + IOContextPoolConfig cfg; + cfg.num_ctx = 2; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + sync_point->DisableProcessing(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + ASSERT_EQ(pool->Backend(), IOBackend::AIO); +} +#endif +#endif + +TEST_F(IOContextPoolTestFixture, BackendIsSelectedAtInit) { + IOContextPoolConfig cfg; + cfg.num_ctx = 2; + cfg.max_events = 128; + + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + + auto backend = pool->Backend(); +#ifdef WITH_IO_URING + ASSERT_TRUE(backend == IOBackend::IO_URING || backend == IOBackend::AIO); +#else + ASSERT_EQ(backend, IOBackend::AIO); +#endif +} + +#ifdef WITH_IO_URING +TEST_F(IOContextPoolTestFixture, RequiredIoUringBackendShouldBeSelected) { + const char* require_uring = std::getenv("KNOWHERE_REQUIRE_IO_URING"); + if (require_uring == nullptr || std::strcmp(require_uring, "1") != 0) { + GTEST_SKIP() << "io_uring backend is not required by this environment"; + } + + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + ASSERT_EQ(pool->Backend(), IOBackend::IO_URING); +} +#endif + +TEST_F(IOContextPoolTestFixture, InvalidConfigRejected) { + IOContextPoolConfig cfg; + cfg.num_ctx = 0; + cfg.max_events = 128; + + ASSERT_FALSE(IOContextPool::InitGlobal(cfg)); +} + +TEST_F(IOContextPoolTestFixture, GetGlobalShouldNotSelectBackendBeforeExplicitInit) { + ASSERT_EQ(IOContextPool::GetGlobal(), nullptr); +} + +TEST_F(IOContextPoolTestFixture, ReinitWithDifferentConfigShouldFail) { + IOContextPoolConfig cfg; + cfg.num_ctx 
= 2; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + IOContextPoolConfig mismatch = cfg; + mismatch.num_ctx = 4; + + ASSERT_FALSE(IOContextPool::InitGlobal(mismatch)); +} + +TEST_F(IOContextPoolTestFixture, GetGlobalOrInitShouldRejectDifferentConfig) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + IOContextPoolConfig mismatch = cfg; + mismatch.num_ctx = 2; + + ASSERT_EQ(IOContextPool::GetGlobalOrInit(mismatch), nullptr); +} + +TEST_F(IOContextPoolTestFixture, ResetGlobalForTestShouldClearSingletonState) { + IOContextPoolConfig cfg; + cfg.num_ctx = 2; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + IOContextPoolConfig mismatch = cfg; + mismatch.num_ctx = 4; + + ASSERT_FALSE(IOContextPool::InitGlobal(mismatch)); + IOContextPool::ResetGlobalForTest(); + ASSERT_TRUE(IOContextPool::InitGlobal(mismatch)); +} + +#ifdef MILVUS_COMMON_WITH_LIBAIO +TEST_F(IOContextPoolTestFixture, DefaultConfigShouldMatchLegacyAioPoolSize) { + IOContextPoolConfig cfg; + ASSERT_EQ(cfg.num_ctx, default_pool_size); + ASSERT_EQ(cfg.max_events, default_max_events); +} +#else +TEST_F(IOContextPoolTestFixture, DefaultConfigShouldNotUseSingleContext) { + IOContextPoolConfig cfg; + ASSERT_GT(cfg.num_ctx, 1u); + ASSERT_EQ(cfg.max_events, 128u); +} +#endif + +TEST_F(IOContextPoolTestFixture, ReaderCanBeConstructed) { + IOContextPoolConfig cfg; + cfg.num_ctx = 2; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + IOReader reader; + ASSERT_EQ(reader.Backend(), IOContextPool::GetGlobal()->Backend()); +} + +TEST_F(IOContextPoolTestFixture, ReaderShouldNotImplicitlyInitializeGlobalPool) { + ASSERT_EQ(IOContextPool::GetGlobal(), nullptr); + + IOReader reader; + ASSERT_EQ(IOContextPool::GetGlobal(), nullptr); + ASSERT_FALSE(reader.IsReady()); +} + +TEST_F(IOContextPoolTestFixture, UnifiedPopPushShouldUseSelectedBackend) { + IOContextPoolConfig cfg; + 
cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + + auto handle = pool->Pop(); + ASSERT_EQ(handle.backend, pool->Backend()); +#ifdef WITH_IO_URING + if (pool->Backend() == IOBackend::IO_URING) { + ASSERT_EQ(handle.backend, IOBackend::IO_URING); + ASSERT_NE(handle.uring, nullptr); + } +#endif +#if !defined(WITH_IO_URING) && defined(MILVUS_COMMON_WITH_LIBAIO) + ASSERT_EQ(handle.backend, IOBackend::AIO); + ASSERT_NE(handle.aio, nullptr); +#endif + pool->Push(std::move(handle)); + + auto second = pool->Pop(); + ASSERT_EQ(second.backend, pool->Backend()); + pool->Push(std::move(second)); +} + +TEST_F(IOContextPoolTestFixture, HandleDestructorShouldReturnCheckedOutContext) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + + { + auto handle = pool->Pop(); + ASSERT_TRUE(handle.HasContext()); + } + + auto second = std::async(std::launch::async, [&] { return pool->Pop(); }); + ASSERT_EQ(second.wait_for(std::chrono::seconds(1)), std::future_status::ready); + auto handle = second.get(); + ASSERT_TRUE(handle.HasContext()); + pool->Push(std::move(handle)); +} + +TEST_F(IOContextPoolTestFixture, MoveAssignHandleShouldReturnPreviousContext) { + IOContextPoolConfig cfg; + cfg.num_ctx = 2; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + + auto first = pool->Pop(); + auto second = pool->Pop(); + ASSERT_TRUE(first.HasContext()); + ASSERT_TRUE(second.HasContext()); + + second = std::move(first); + ASSERT_TRUE(second.HasContext()); + + auto returned = std::async(std::launch::async, [&] { return pool->Pop(); }); + ASSERT_EQ(returned.wait_for(std::chrono::seconds(1)), std::future_status::ready); + auto returned_handle = returned.get(); + 
ASSERT_TRUE(returned_handle.HasContext()); + + pool->Push(std::move(returned_handle)); + pool->Push(std::move(second)); +} + +TEST_F(IOContextPoolTestFixture, RetireHandleShouldMakePoolFailFast) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + + auto handle = pool->Pop(); + ASSERT_TRUE(handle.HasContext()); + ASSERT_TRUE(pool->Release(std::move(handle), IOContextReleaseDisposition::Retire)); + + auto failed = std::async(std::launch::async, [&] { return pool->Pop(); }); + ASSERT_EQ(failed.wait_for(std::chrono::seconds(1)), std::future_status::ready); + ASSERT_FALSE(failed.get().HasContext()); +} + +TEST_F(IOContextPoolTestFixture, PushShouldRejectHandleFromDifferentBackend) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + + auto handle = pool->Pop(); + ASSERT_EQ(handle.backend, pool->Backend()); + + IOContextHandle mismatched; + mismatched.backend = handle.backend == IOBackend::AIO ? 
IOBackend::IO_URING : IOBackend::AIO; + pool->Push(std::move(mismatched)); + pool->Push(std::move(handle)); + + auto second = pool->Pop(); + ASSERT_EQ(second.backend, pool->Backend()); + pool->Push(std::move(second)); +} + +TEST_F(IOContextPoolTestFixture, IoReaderSpanShouldUseCompatSpanType) { + EXPECT_TRUE((std::is_same_v, knowhere_compat::span>)); +#if defined(__cpp_lib_span) + EXPECT_FALSE((std::is_same_v, std::span>)); +#endif +} + +TEST_F(IOContextPoolTestFixture, InterruptedSyscallRetryShouldHaveBound) { + size_t retry = 0; + + EXPECT_TRUE(knowhere_internal::ShouldRetryInterruptedSyscall(retry, 2)); + EXPECT_EQ(retry, 1u); + EXPECT_TRUE(knowhere_internal::ShouldRetryInterruptedSyscall(retry, 2)); + EXPECT_EQ(retry, 2u); + EXPECT_FALSE(knowhere_internal::ShouldRetryInterruptedSyscall(retry, 2)); + EXPECT_EQ(retry, 3u); +} + +#ifdef ENABLE_SYNCPOINT +TEST_F(IOContextPoolTestFixture, ResetFailureShouldMakePoolFailFast) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + + auto* sync_point = milvus::SyncPoint::GetInstance(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); +#ifdef WITH_IO_URING + sync_point->SetCallBack("UringContextPool::ResetCheckedOut:BeforeInit", [](void* arg) { + *static_cast(arg) = -EMFILE; + }); +#endif +#ifdef MILVUS_COMMON_WITH_LIBAIO + sync_point->SetCallBack("AioContextPool::ResetCheckedOut:BeforeSetup", [](void* arg) { + *static_cast(arg) = -EAGAIN; + }); +#endif + sync_point->EnableProcessing(); + + auto handle = pool->Pop(); + ASSERT_TRUE(handle.HasContext()); + ASSERT_FALSE(pool->Reset(std::move(handle))); + + auto blocked = std::async(std::launch::async, [&] { return pool->Pop(); }); + ASSERT_EQ(blocked.wait_for(std::chrono::seconds(1)), std::future_status::ready); + auto failed = blocked.get(); + ASSERT_FALSE(failed.HasContext()); + + sync_point->DisableProcessing(); + 
sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); +} +#endif + +#if defined(MILVUS_COMMON_WITH_LIBAIO) && defined(ENABLE_SYNCPOINT) +TEST_F(IOContextPoolTestFixture, AioDestroyFailureShouldMakePoolFailFast) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::AIO) { + GTEST_SKIP() << "AIO backend unavailable"; + } + + auto* sync_point = milvus::SyncPoint::GetInstance(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + int destroy_calls = 0; + sync_point->SetCallBack("AioContextPool::DestroyContext:AfterDestroy", [&](void* arg) { + if (destroy_calls++ == 0) { + *static_cast(arg) = -EINVAL; + } + }); + sync_point->EnableProcessing(); + + auto handle = pool->Pop(); + ASSERT_TRUE(handle.HasContext()); + ASSERT_FALSE(pool->Reset(std::move(handle))); + + auto failed = std::async(std::launch::async, [&] { return pool->Pop(); }); + ASSERT_EQ(failed.wait_for(std::chrono::seconds(1)), std::future_status::ready); + ASSERT_FALSE(failed.get().HasContext()); + + sync_point->DisableProcessing(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); +} +#endif + +#ifdef MILVUS_COMMON_WITH_LIBAIO +TEST_F(IOContextPoolTestFixture, LegacyAioInitStillWorksViaUnifiedPath) { + ASSERT_TRUE(AioContextPool::InitGlobalAioPool(2, 64)); + auto p = AioContextPool::GetGlobalAioPool(); + ASSERT_NE(p, nullptr); + ASSERT_EQ(p->max_events_per_ctx(), 64u); + auto io_pool = IOContextPool::GetGlobal(); + if (io_pool != nullptr && io_pool->IsInitialized()) { + ASSERT_NE(io_pool->Backend(), IOBackend::AIO); + } +} + +TEST_F(IOContextPoolTestFixture, LegacyAioValidationReinitMismatchShouldFail) { + ASSERT_TRUE(AioContextPool::InitGlobalAioPoolWithValidation(2, 128)); + ASSERT_FALSE(AioContextPool::InitGlobalAioPoolWithValidation(4, 128)); +} + +TEST_F(IOContextPoolTestFixture, 
LegacyAioValidationShouldRejectZeroMaxEvents) { + ASSERT_FALSE(AioContextPool::InitGlobalAioPoolWithValidation(1, 0)); + ASSERT_TRUE(AioContextPool::InitGlobalAioPoolWithValidation(1, default_max_events)); + + auto pool = AioContextPool::GetGlobalAioPoolDirect(); + ASSERT_NE(pool, nullptr); + ASSERT_EQ(pool->max_events_per_ctx(), default_max_events); +} + +TEST_F(IOContextPoolTestFixture, LegacyAioPopShouldReturnNullAfterShutdown) { + ASSERT_TRUE(AioContextPool::InitGlobalAioPoolWithValidation(1, 128)); + auto pool = AioContextPool::GetGlobalAioPoolDirect(); + ASSERT_NE(pool, nullptr); + + auto first = pool->pop(); + ASSERT_NE(first, nullptr); + + auto blocked = std::async(std::launch::async, [&]() { return pool->pop(); }); + ASSERT_EQ(blocked.wait_for(std::chrono::milliseconds(50)), std::future_status::timeout); + + pool->Shutdown(); + + ASSERT_EQ(blocked.wait_for(std::chrono::seconds(1)), std::future_status::ready); + ASSERT_EQ(blocked.get(), nullptr); +} + +TEST_F(IOContextPoolTestFixture, LegacyAioPoolShouldRejectDoublePush) { + ASSERT_TRUE(AioContextPool::InitGlobalAioPoolWithValidation(1, 128)); + auto pool = AioContextPool::GetGlobalAioPoolDirect(); + ASSERT_NE(pool, nullptr); + + auto first = pool->pop(); + ASSERT_NE(first, nullptr); + pool->push(first); + pool->push(first); + + auto checked_out = pool->pop(); + ASSERT_EQ(checked_out, first); + + auto blocked = std::async(std::launch::async, [&]() { return pool->pop(); }); + ASSERT_EQ(blocked.wait_for(std::chrono::milliseconds(50)), std::future_status::timeout); + + pool->push(checked_out); + ASSERT_EQ(blocked.wait_for(std::chrono::seconds(1)), std::future_status::ready); + auto second = blocked.get(); + ASSERT_EQ(second, first); + pool->push(second); +} +#endif + +#ifdef WITH_IO_URING +TEST_F(IOContextPoolTestFixture, LegacyUringInitStillWorksViaUnifiedPath) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto io_pool = 
IOContextPool::GetGlobal(); + ASSERT_NE(io_pool, nullptr); + if (io_pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + ASSERT_TRUE(UringContextPool::InitGlobalUringPool(1, 128)); + auto p = UringContextPool::GetGlobalUringPool(); + ASSERT_NE(p, nullptr); +} + +TEST_F(IOContextPoolTestFixture, LegacyUringValidationReinitMismatchShouldFail) { + ASSERT_TRUE(UringContextPool::InitGlobalUringPoolWithValidation(1, 64)); + ASSERT_FALSE(UringContextPool::InitGlobalUringPoolWithValidation(2, 64)); +} + +TEST_F(IOContextPoolTestFixture, LegacyUringInitShouldHonorRequestedConfig) { + if (!UringContextPool::InitGlobalUringPool(1, 64)) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + auto io_pool = IOContextPool::GetGlobal(); + ASSERT_NE(io_pool, nullptr); + ASSERT_EQ(io_pool->Backend(), IOBackend::IO_URING); + ASSERT_EQ(io_pool->MaxEventsPerCtx(), 64u); +} + +TEST_F(IOContextPoolTestFixture, LegacyUringDirectDefaultShouldProvideMultipleContexts) { + auto pool = UringContextPool::GetGlobalUringPoolDirect(); + ASSERT_NE(pool, nullptr); + if (!pool->IsUsable()) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + ASSERT_GT(pool->created_context_count(), 1u); + + auto* first = pool->pop(); + ASSERT_NE(first, nullptr); + auto* second = pool->pop(); + ASSERT_NE(second, nullptr); + pool->push(second); + pool->push(first); +} +#endif + +namespace { +constexpr size_t kIOReaderTestBlockSize = 4096; + +struct AlignedBufferDeleter { + void + operator()(std::byte* ptr) const { + std::free(ptr); + } +}; + +using AlignedBuffer = std::unique_ptr; + +AlignedBuffer +AllocateAlignedBuffer(size_t size) { + void* ptr = nullptr; + if (posix_memalign(&ptr, kIOReaderTestBlockSize, size) != 0) { + return nullptr; + } + std::memset(ptr, 0, size); + return AlignedBuffer(static_cast(ptr)); +} + +int +OpenIOReaderTestFile(char* path, bool needs_direct_io) { + int tmp_fd = ::mkstemp(path); + if (tmp_fd < 0) { + return -1; + } + 
::close(tmp_fd); + + int flags = O_CREAT | O_TRUNC | O_RDWR; +#ifdef O_DIRECT + if (needs_direct_io) { + flags |= O_DIRECT; + } +#else + if (needs_direct_io) { + return -1; + } +#endif + return ::open(path, flags, 0644); +} +} // namespace + +TEST_F(IOContextPoolTestFixture, ReadAsyncShouldSubmitFirstBatchBeforeFutureWait) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 1; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + const bool needs_direct_io = pool->Backend() == IOBackend::AIO; + + char path[] = "/tmp/io_reader_eager_submit_XXXXXX"; + int fd = OpenIOReaderTestFile(path, needs_direct_io); + if (fd < 0 && needs_direct_io) { + ::unlink(path); + GTEST_SKIP() << "direct I/O is not available for AIO ReadAsync test"; + } + ASSERT_GE(fd, 0); + + auto content = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto buffer = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(buffer, nullptr); + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0x3}); + + const auto written = ::pwrite(fd, content.get(), kIOReaderTestBlockSize, 0); + if (written < 0 && needs_direct_io && errno == EINVAL) { + ::close(fd); + ::unlink(path); + GTEST_SKIP() << "filesystem does not support direct I/O"; + } + ASSERT_EQ(written, static_cast(kIOReaderTestBlockSize)); + ASSERT_EQ(::fsync(fd), 0); + + auto reader = IOReader(fd); + std::vector buffers{buffer.get()}; + std::vector offsets{0}; + + auto fut = reader.ReadAsync(std::move(buffers), kIOReaderTestBlockSize, std::move(offsets)); + ASSERT_EQ(::close(fd), 0); + fd = -1; + + ASSERT_TRUE(fut.get()); + ASSERT_EQ(std::memcmp(buffer.get(), content.get(), kIOReaderTestBlockSize), 0); + ::unlink(path); +} + +TEST_F(IOContextPoolTestFixture, ReadAsyncShouldReadMultipleBatches) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 1; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + 
auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + const bool needs_direct_io = pool->Backend() == IOBackend::AIO; + + char path[] = "/tmp/io_reader_multi_batch_XXXXXX"; + int fd = OpenIOReaderTestFile(path, needs_direct_io); + if (fd < 0 && needs_direct_io) { + ::unlink(path); + GTEST_SKIP() << "direct I/O is not available for AIO ReadAsync test"; + } + ASSERT_GE(fd, 0); + + constexpr size_t kTotalSize = kIOReaderTestBlockSize * 2; + auto content = AllocateAlignedBuffer(kTotalSize); + auto first = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto second = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(first, nullptr); + ASSERT_NE(second, nullptr); + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0x4}); + std::fill(content.get() + kIOReaderTestBlockSize, content.get() + kTotalSize, std::byte{0x5}); + + const auto written = ::pwrite(fd, content.get(), kTotalSize, 0); + if (written < 0 && needs_direct_io && errno == EINVAL) { + ::close(fd); + ::unlink(path); + GTEST_SKIP() << "filesystem does not support direct I/O"; + } + ASSERT_EQ(written, static_cast(kTotalSize)); + ASSERT_EQ(::fsync(fd), 0); + + auto reader = IOReader(fd); + std::vector buffers{first.get(), second.get()}; + std::vector offsets{0, kIOReaderTestBlockSize}; + + auto fut = reader.ReadAsync(std::move(buffers), kIOReaderTestBlockSize, std::move(offsets)); + ASSERT_TRUE(fut.get()); + ASSERT_EQ(std::memcmp(first.get(), content.get(), kIOReaderTestBlockSize), 0); + ASSERT_EQ(std::memcmp(second.get(), content.get() + kIOReaderTestBlockSize, kIOReaderTestBlockSize), 0); + + ::close(fd); + ::unlink(path); +} + +TEST_F(IOContextPoolTestFixture, ReadAsyncShouldReturnFalseOnShortRead) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + const bool needs_direct_io = pool->Backend() 
== IOBackend::AIO; + + char path[] = "/tmp/io_reader_short_read_XXXXXX"; + int fd = OpenIOReaderTestFile(path, needs_direct_io); + if (fd < 0 && needs_direct_io) { + ::unlink(path); + GTEST_SKIP() << "direct I/O is not available for AIO ReadAsync test"; + } + ASSERT_GE(fd, 0); + + auto content = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto buffer = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(buffer, nullptr); + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0xa}); + + const auto written = ::pwrite(fd, content.get(), kIOReaderTestBlockSize, 0); + if (written < 0 && needs_direct_io && errno == EINVAL) { + ::close(fd); + ::unlink(path); + GTEST_SKIP() << "filesystem does not support direct I/O"; + } + ASSERT_EQ(written, static_cast(kIOReaderTestBlockSize)); + ASSERT_EQ(::fsync(fd), 0); + + IOReader reader(fd, pool); + std::vector buffers{buffer.get()}; + std::vector offsets{kIOReaderTestBlockSize}; + + auto fut = reader.ReadAsync(std::move(buffers), kIOReaderTestBlockSize, std::move(offsets)); + ASSERT_FALSE(fut.get()); + + ::close(fd); + ::unlink(path); +} + +TEST_F(IOContextPoolTestFixture, ReadAsyncShouldRejectMisalignedDirectIoBuffer) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + char path[] = "/tmp/io_reader_direct_alignment_XXXXXX"; + int fd = OpenIOReaderTestFile(path, true); + if (fd < 0 && errno == EINVAL) { + ::unlink(path); + GTEST_SKIP() << "direct I/O is not available for alignment test"; + } + ASSERT_GE(fd, 0); + + auto content = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto buffer = AllocateAlignedBuffer(kIOReaderTestBlockSize + 1); + ASSERT_NE(content, nullptr); + ASSERT_NE(buffer, nullptr); + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0x15}); + ASSERT_EQ(::pwrite(fd, content.get(), kIOReaderTestBlockSize, 0), static_cast(kIOReaderTestBlockSize)); + 
ASSERT_EQ(::fsync(fd), 0); + + IOReader reader(fd, IOContextPool::GetGlobal()); + std::vector buffers{buffer.get() + 1}; + std::vector offsets{0}; + EXPECT_THROW(reader.ReadAsync(std::move(buffers), kIOReaderTestBlockSize, std::move(offsets)), + std::invalid_argument); + + ::close(fd); + ::unlink(path); +} + +#if defined(WITH_IO_URING) && defined(ENABLE_SYNCPOINT) +TEST_F(IOContextPoolTestFixture, ReadAsyncSubmitFailureShouldResetUringHandle) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 1; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_reader_submit_failure_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + constexpr size_t kTotalSize = kIOReaderTestBlockSize * 2; + auto content = AllocateAlignedBuffer(kTotalSize); + auto first = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto second = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto retry = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(first, nullptr); + ASSERT_NE(second, nullptr); + ASSERT_NE(retry, nullptr); + + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0x8}); + std::fill(content.get() + kIOReaderTestBlockSize, content.get() + kTotalSize, std::byte{0x9}); + ASSERT_EQ(::pwrite(fd, content.get(), kTotalSize, 0), static_cast(kTotalSize)); + ASSERT_EQ(::fsync(fd), 0); + + auto* sync_point = milvus::SyncPoint::GetInstance(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + int submit_calls = 0; + sync_point->SetCallBack("IOReader::SubmitUring:Before", [&](void* arg) { + if (submit_calls++ == 1) { + *static_cast(arg) = -EIO; + } + }); + sync_point->EnableProcessing(); + + auto reader = IOReader(fd, pool); + std::vector buffers{first.get(), second.get()}; + std::vector 
offsets{0, kIOReaderTestBlockSize}; + auto failed = reader.ReadAsync(std::move(buffers), kIOReaderTestBlockSize, std::move(offsets)); + ASSERT_FALSE(failed.get()); + + sync_point->DisableProcessing(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + + std::vector retry_buffers{retry.get()}; + std::vector retry_offsets{0}; + auto retried = reader.ReadAsync(std::move(retry_buffers), kIOReaderTestBlockSize, std::move(retry_offsets)); + ASSERT_TRUE(retried.get()); + ASSERT_EQ(std::memcmp(retry.get(), content.get(), kIOReaderTestBlockSize), 0); + + ::close(fd); + ::unlink(path); +} + +TEST_F(IOContextPoolTestFixture, DroppedReadAsyncFutureShouldDrainSubmittedUringIo) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 128; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_reader_dropped_future_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + constexpr size_t kTotalSize = kIOReaderTestBlockSize * 2; + auto content = AllocateAlignedBuffer(kTotalSize); + auto first = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto second = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto retry = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(first, nullptr); + ASSERT_NE(second, nullptr); + ASSERT_NE(retry, nullptr); + + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0xb}); + std::fill(content.get() + kIOReaderTestBlockSize, content.get() + kTotalSize, std::byte{0xc}); + ASSERT_EQ(::pwrite(fd, content.get(), kTotalSize, 0), static_cast(kTotalSize)); + ASSERT_EQ(::fsync(fd), 0); + + auto reader = IOReader(fd, pool); + { + std::vector buffers{first.get(), second.get()}; + std::vector offsets{0, kIOReaderTestBlockSize}; + auto dropped = 
reader.ReadAsync(std::move(buffers), kIOReaderTestBlockSize, std::move(offsets)); + } + + std::vector retry_buffers{retry.get()}; + std::vector retry_offsets{0}; + auto retried = reader.ReadAsync(std::move(retry_buffers), kIOReaderTestBlockSize, std::move(retry_offsets)); + ASSERT_TRUE(retried.get()); + ASSERT_EQ(std::memcmp(retry.get(), content.get(), kIOReaderTestBlockSize), 0); + + ::close(fd); + ::unlink(path); +} +#endif + +#if defined(MILVUS_COMMON_WITH_LIBAIO) && defined(ENABLE_SYNCPOINT) +/* Verifies that dropping a ReadAsync future on the AIO backend still leaves the single pooled context usable for a follow-up read. */ +TEST_F(IOContextPoolTestFixture, DroppedAioReadAsyncFutureShouldDrainContext) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 2; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::AIO) { + GTEST_SKIP() << "AIO backend unavailable"; + } + + char path[] = "/tmp/io_reader_aio_dropped_future_XXXXXX"; + int fd = OpenIOReaderTestFile(path, true); + if (fd < 0 && errno == EINVAL) { + /* OpenIOReaderTestFile() has already created the file via mkstemp(); remove it so a skipped run leaves nothing behind in /tmp. */ + ::unlink(path); + GTEST_SKIP() << "direct I/O is not available for AIO ReadAsync test"; + } + ASSERT_GE(fd, 0); + + auto content = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto first = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto retry = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(first, nullptr); + ASSERT_NE(retry, nullptr); + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0x11}); + ASSERT_EQ(::pwrite(fd, content.get(), kIOReaderTestBlockSize, 0), static_cast(kIOReaderTestBlockSize)); + ASSERT_EQ(::fsync(fd), 0); + + IOReader reader(fd, pool); + { + std::vector buffers{first.get()}; + std::vector offsets{0}; + auto dropped = reader.ReadAsync(std::move(buffers), kIOReaderTestBlockSize, std::move(offsets)); + } + + std::vector retry_buffers{retry.get()}; + std::vector retry_offsets{0}; + auto retried = reader.ReadAsync(std::move(retry_buffers), kIOReaderTestBlockSize, std::move(retry_offsets)); + 
ASSERT_TRUE(retried.get()); + ASSERT_EQ(std::memcmp(retry.get(), content.get(), kIOReaderTestBlockSize), 0); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} + +/* Injects failures into a later AIO batch via sync points and checks that the context reset hook runs and a subsequent read succeeds. */ +TEST_F(IOContextPoolTestFixture, AioLaterBatchCleanupFailureShouldResetContext) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 2; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::AIO) { + GTEST_SKIP() << "AIO backend unavailable"; + } + + char path[] = "/tmp/io_reader_aio_later_cleanup_XXXXXX"; + int fd = OpenIOReaderTestFile(path, true); + if (fd < 0 && errno == EINVAL) { + /* OpenIOReaderTestFile() has already created the file via mkstemp(); remove it so a skipped run leaves nothing behind in /tmp. */ + ::unlink(path); + GTEST_SKIP() << "direct I/O is not available for AIO ReadAsync test"; + } + ASSERT_GE(fd, 0); + + constexpr size_t kBlocks = 4; + constexpr size_t kTotalSize = kIOReaderTestBlockSize * kBlocks; + auto content = AllocateAlignedBuffer(kTotalSize); + ASSERT_NE(content, nullptr); + std::fill(content.get(), content.get() + kTotalSize, std::byte{0x4}); + ASSERT_EQ(::pwrite(fd, content.get(), kTotalSize, 0), static_cast(kTotalSize)); + ASSERT_EQ(::fsync(fd), 0); + + std::vector owned_buffers; + std::vector buffers; + std::vector offsets; + for (size_t i = 0; i < kBlocks; ++i) { + auto buffer = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(buffer, nullptr); + buffers.push_back(buffer.get()); + offsets.push_back(i * kIOReaderTestBlockSize); + owned_buffers.push_back(std::move(buffer)); + } + + auto* sync_point = milvus::SyncPoint::GetInstance(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + int submit_calls = 0; + int wait_calls = 0; + int reset_calls = 0; + sync_point->SetCallBack("IOReader::SubmitAioBatch:BeforeSubmit", [&](void* arg) { + if (submit_calls++ == 1) { + *static_cast(arg) = 1; + } + }); + sync_point->SetCallBack("IOReader::WaitAioBatch:BeforeReturn", [&](void* arg) { + if (wait_calls++ == 1) { + *static_cast(arg) = true; + } + }); + 
sync_point->SetCallBack("AioContextPool::ResetCheckedOut:BeforeSetup", [&](void* arg) { + ++reset_calls; + *static_cast(arg) = 0; + }); + sync_point->EnableProcessing(); + + IOReader reader(fd, pool); + auto failed = reader.ReadAsync(std::move(buffers), kIOReaderTestBlockSize, std::move(offsets)); + ASSERT_FALSE(failed.get()); + ASSERT_GT(reset_calls, 0); + + sync_point->DisableProcessing(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + + auto retry = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(retry, nullptr); + std::vector retry_buffers{retry.get()}; + std::vector retry_offsets{0}; + auto retried = reader.ReadAsync(std::move(retry_buffers), kIOReaderTestBlockSize, std::move(retry_offsets)); + ASSERT_TRUE(retried.get()); + ASSERT_EQ(std::memcmp(retry.get(), content.get(), kIOReaderTestBlockSize), 0); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} + +/* Forces the AIO submit and wait sync points to the value 1 on every call and checks that a 12-block read still completes with correct data. */ +TEST_F(IOContextPoolTestFixture, AioPartialProgressShouldNotExhaustRetryBudget) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 16; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::AIO) { + GTEST_SKIP() << "AIO backend unavailable"; + } + + char path[] = "/tmp/io_reader_aio_partial_progress_XXXXXX"; + int fd = OpenIOReaderTestFile(path, true); + if (fd < 0 && errno == EINVAL) { + /* OpenIOReaderTestFile() has already created the file via mkstemp(); remove it so a skipped run leaves nothing behind in /tmp. */ + ::unlink(path); + GTEST_SKIP() << "direct I/O is not available for AIO ReadAsync test"; + } + ASSERT_GE(fd, 0); + + constexpr size_t kBlocks = 12; + constexpr size_t kTotalSize = kIOReaderTestBlockSize * kBlocks; + auto content = AllocateAlignedBuffer(kTotalSize); + ASSERT_NE(content, nullptr); + for (size_t i = 0; i < kBlocks; ++i) { + std::fill(content.get() + i * kIOReaderTestBlockSize, content.get() + (i + 1) * kIOReaderTestBlockSize, + static_cast(0x20 + i)); + } + ASSERT_EQ(::pwrite(fd, content.get(), kTotalSize, 0), static_cast(kTotalSize)); + ASSERT_EQ(::fsync(fd), 0); + + std::vector 
owned_buffers; + std::vector buffers; + std::vector offsets; + for (size_t i = 0; i < kBlocks; ++i) { + auto buffer = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(buffer, nullptr); + buffers.push_back(buffer.get()); + offsets.push_back(i * kIOReaderTestBlockSize); + owned_buffers.push_back(std::move(buffer)); + } + + auto* sync_point = milvus::SyncPoint::GetInstance(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + sync_point->SetCallBack("IOReader::SubmitAioBatch:BeforeSubmit", [](void* arg) { + *static_cast(arg) = 1; + }); + sync_point->SetCallBack("IOReader::WaitAioBatch:BeforeGetEvents", [](void* arg) { + *static_cast(arg) = 1; + }); + sync_point->EnableProcessing(); + + IOReader reader(fd, pool); + auto result = reader.ReadAsync(std::move(buffers), kIOReaderTestBlockSize, std::move(offsets)); + ASSERT_TRUE(result.get()); + + sync_point->DisableProcessing(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + + for (size_t i = 0; i < kBlocks; ++i) { + ASSERT_EQ(std::memcmp(owned_buffers[i].get(), content.get() + i * kIOReaderTestBlockSize, + kIOReaderTestBlockSize), + 0); + } + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} +#endif + +#ifdef WITH_IO_URING +TEST_F(IOContextPoolTestFixture, CompletionReaderReturnsSubmittedRequestIds) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 8; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_completion_reader_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + constexpr size_t kTotalSize = kIOReaderTestBlockSize * 2; + auto content = AllocateAlignedBuffer(kTotalSize); + auto first = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto second = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + 
ASSERT_NE(first, nullptr); + ASSERT_NE(second, nullptr); + + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0x1}); + std::fill(content.get() + kIOReaderTestBlockSize, content.get() + kTotalSize, std::byte{0x2}); + + ASSERT_EQ(::pwrite(fd, content.get(), kTotalSize, 0), static_cast(kTotalSize)); + ASSERT_EQ(::fsync(fd), 0); + + IOCompletionReader reader(fd, pool); + std::array first_buffers{first.get()}; + std::array first_offsets{0}; + std::array second_buffers{second.get()}; + std::array second_offsets{kIOReaderTestBlockSize}; + + const auto request_1 = reader.Submit(IOCompletionReaderSpan(first_buffers.data(), first_buffers.size()), + kIOReaderTestBlockSize, + IOCompletionReaderSpan(first_offsets.data(), first_offsets.size())); + const auto request_2 = reader.Submit(IOCompletionReaderSpan(second_buffers.data(), second_buffers.size()), + kIOReaderTestBlockSize, + IOCompletionReaderSpan(second_offsets.data(), second_offsets.size())); + ASSERT_NE(request_1, request_2); + + std::vector completions; + completions.push_back(reader.WaitCompleted()); + + auto polled = reader.PollCompleted(); + completions.insert(completions.end(), polled.begin(), polled.end()); + while (completions.size() < 2) { + completions.push_back(reader.WaitCompleted()); + } + + ASSERT_EQ(completions.size(), 2u); + ASSERT_TRUE(std::all_of(completions.begin(), completions.end(), [](const auto& c) { return c.ok; })); + + std::vector ids{completions[0].request_id, completions[1].request_id}; + std::sort(ids.begin(), ids.end()); + ASSERT_EQ(ids[0], std::min(request_1, request_2)); + ASSERT_EQ(ids[1], std::max(request_1, request_2)); + + ASSERT_EQ(std::memcmp(first.get(), content.get(), kIOReaderTestBlockSize), 0); + ASSERT_EQ(std::memcmp(second.get(), content.get() + kIOReaderTestBlockSize, kIOReaderTestBlockSize), 0); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} + +TEST_F(IOContextPoolTestFixture, CompletionReaderWaitsForAllBuffersInRequest) { + IOContextPoolConfig 
cfg; + cfg.num_ctx = 1; + cfg.max_events = 8; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_completion_reader_multi_buffer_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + constexpr size_t kTotalSize = kIOReaderTestBlockSize * 2; + auto content = AllocateAlignedBuffer(kTotalSize); + auto first = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto second = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(first, nullptr); + ASSERT_NE(second, nullptr); + + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0xd}); + std::fill(content.get() + kIOReaderTestBlockSize, content.get() + kTotalSize, std::byte{0xe}); + ASSERT_EQ(::pwrite(fd, content.get(), kTotalSize, 0), static_cast(kTotalSize)); + ASSERT_EQ(::fsync(fd), 0); + + IOCompletionReader reader(fd, pool); + std::array buffers{first.get(), second.get()}; + std::array offsets{0, kIOReaderTestBlockSize}; + + const auto request = reader.Submit(IOCompletionReaderSpan(buffers.data(), buffers.size()), + kIOReaderTestBlockSize, + IOCompletionReaderSpan(offsets.data(), offsets.size())); + auto completion = reader.WaitCompleted(); + ASSERT_EQ(completion.request_id, request); + ASSERT_TRUE(completion.ok); + ASSERT_TRUE(reader.PollCompleted().empty()); + + ASSERT_EQ(std::memcmp(first.get(), content.get(), kIOReaderTestBlockSize), 0); + ASSERT_EQ(std::memcmp(second.get(), content.get() + kIOReaderTestBlockSize, kIOReaderTestBlockSize), 0); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} + +TEST_F(IOContextPoolTestFixture, CompletionReaderReportsShortReadFailure) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 8; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + 
ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_completion_reader_short_read_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + auto content = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto buffer = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(buffer, nullptr); + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0x12}); + ASSERT_EQ(::pwrite(fd, content.get(), kIOReaderTestBlockSize, 0), static_cast(kIOReaderTestBlockSize)); + ASSERT_EQ(::fsync(fd), 0); + + IOCompletionReader reader(fd, pool); + std::array buffers{buffer.get()}; + std::array offsets{kIOReaderTestBlockSize}; + const auto request = reader.Submit(IOCompletionReaderSpan(buffers.data(), buffers.size()), + kIOReaderTestBlockSize, + IOCompletionReaderSpan(offsets.data(), offsets.size())); + auto completion = reader.WaitCompleted(); + ASSERT_EQ(completion.request_id, request); + ASSERT_FALSE(completion.ok); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} + +TEST_F(IOContextPoolTestFixture, CompletionReaderReportsNegativeCqeFailure) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 8; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_completion_reader_negative_cqe_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + auto buffer = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(buffer, nullptr); + + IOCompletionReader reader(fd, pool); + ASSERT_EQ(::close(fd), 0); + fd = -1; + + std::array buffers{buffer.get()}; + std::array offsets{0}; + const auto request = reader.Submit(IOCompletionReaderSpan(buffers.data(), buffers.size()), + 
kIOReaderTestBlockSize, + IOCompletionReaderSpan(offsets.data(), offsets.size())); + auto completion = reader.WaitCompleted(); + ASSERT_EQ(completion.request_id, request); + ASSERT_FALSE(completion.ok); + + ::unlink(path); +} + +TEST_F(IOContextPoolTestFixture, CompletionReaderShouldRejectMisalignedDirectIoBuffer) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 8; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_completion_reader_direct_alignment_XXXXXX"; + int fd = OpenIOReaderTestFile(path, true); + if (fd < 0 && errno == EINVAL) { + ::unlink(path); + GTEST_SKIP() << "direct I/O is not available for alignment test"; + } + ASSERT_GE(fd, 0); + + auto buffer = AllocateAlignedBuffer(kIOReaderTestBlockSize + 1); + ASSERT_NE(buffer, nullptr); + + IOCompletionReader reader(fd, pool); + std::array buffers{buffer.get() + 1}; + std::array offsets{0}; + EXPECT_THROW(reader.Submit(IOCompletionReaderSpan(buffers.data(), buffers.size()), + kIOReaderTestBlockSize, IOCompletionReaderSpan(offsets.data(), offsets.size())), + std::invalid_argument); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} + +TEST_F(IOContextPoolTestFixture, CompletionReaderRejectsBatchLargerThanCapacity) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 1; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_completion_reader_capacity_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + auto first = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto second = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(first, nullptr); + ASSERT_NE(second, 
nullptr); + + IOCompletionReader reader(fd, pool); + std::array buffers{first.get(), second.get()}; + std::array offsets{0, kIOReaderTestBlockSize}; + + EXPECT_THROW(reader.Submit(IOCompletionReaderSpan(buffers.data(), buffers.size()), kIOReaderTestBlockSize, + IOCompletionReaderSpan(offsets.data(), offsets.size())), + std::invalid_argument); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} + +#ifdef ENABLE_SYNCPOINT +TEST_F(IOContextPoolTestFixture, CompletionReaderRejectsOutstandingRequestsBeyondCapacity) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 2; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_completion_reader_outstanding_capacity_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + auto content = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto first = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto second = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto third = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(first, nullptr); + ASSERT_NE(second, nullptr); + ASSERT_NE(third, nullptr); + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0xf}); + ASSERT_EQ(::pwrite(fd, content.get(), kIOReaderTestBlockSize, 0), static_cast(kIOReaderTestBlockSize)); + ASSERT_EQ(::fsync(fd), 0); + + IOCompletionReader reader(fd, pool); + + auto* sync_point = milvus::SyncPoint::GetInstance(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + sync_point->SetCallBack("IOCompletionReader::ProcessAvailableCompletions:Skip", [](void* arg) { + *static_cast(arg) = true; + }); + sync_point->EnableProcessing(); + + std::array offsets{0}; + std::array first_buffers{first.get()}; + std::array second_buffers{second.get()}; + std::array 
third_buffers{third.get()}; + ASSERT_NE(reader.Submit(IOCompletionReaderSpan(first_buffers.data(), first_buffers.size()), + kIOReaderTestBlockSize, IOCompletionReaderSpan(offsets.data(), offsets.size())), + 0u); + ASSERT_NE(reader.Submit(IOCompletionReaderSpan(second_buffers.data(), second_buffers.size()), + kIOReaderTestBlockSize, IOCompletionReaderSpan(offsets.data(), offsets.size())), + 0u); + EXPECT_THROW(reader.Submit(IOCompletionReaderSpan(third_buffers.data(), third_buffers.size()), + kIOReaderTestBlockSize, IOCompletionReaderSpan(offsets.data(), offsets.size())), + std::runtime_error); + + sync_point->DisableProcessing(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} + +/* Submits one read, installs sync-point callbacks, expects the second Submit() to throw std::runtime_error, and then checks that the first request is still returned by WaitCompleted(). Skipped unless the pool backend is IO_URING. */TEST_F(IOContextPoolTestFixture, CompletionReaderSubmitFailureKeepsExistingRequestObservable) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 8; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_completion_reader_submit_failure_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + constexpr size_t kTotalSize = kIOReaderTestBlockSize * 2; + auto content = AllocateAlignedBuffer(kTotalSize); + auto first = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto second = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(first, nullptr); + ASSERT_NE(second, nullptr); + + /* The file holds two blocks: the first filled with 0x6, the second with 0x7. */std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0x6}); + std::fill(content.get() + kIOReaderTestBlockSize, content.get() + kTotalSize, std::byte{0x7}); + + ASSERT_EQ(::pwrite(fd, content.get(), kTotalSize, 0), static_cast(kTotalSize)); + ASSERT_EQ(::fsync(fd), 0); + + IOCompletionReader reader(fd, pool); + std::array first_buffers{first.get()}; + std::array 
first_offsets{0}; + /* The first read (offset 0) is submitted before any sync-point callback is installed. */const auto request_1 = reader.Submit(IOCompletionReaderSpan(first_buffers.data(), first_buffers.size()), + kIOReaderTestBlockSize, + IOCompletionReaderSpan(first_offsets.data(), first_offsets.size())); + + auto* sync_point = milvus::SyncPoint::GetInstance(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + int forced_submits = 0; + int skipped_polls = 0; + /* Each callback below writes through arg on its first invocation only; later invocations leave arg untouched. NOTE(review): how IOCompletionReader consumes these values is not visible here - presumably one completion poll is skipped and one ring submit reports -EIO; confirm against the reader implementation. */sync_point->SetCallBack("IOCompletionReader::ProcessAvailableCompletions:Skip", [&](void* arg) { + if (skipped_polls++ == 0) { + *static_cast(arg) = true; + } + }); + sync_point->SetCallBack("IOCompletionReader::SubmitRing:Before", [&](void* arg) { + if (forced_submits++ == 0) { + *static_cast(arg) = -EIO; + } + }); + sync_point->EnableProcessing(); + + std::array second_buffers{second.get()}; + std::array second_offsets{kIOReaderTestBlockSize}; + EXPECT_THROW(reader.Submit(IOCompletionReaderSpan(second_buffers.data(), second_buffers.size()), + kIOReaderTestBlockSize, + IOCompletionReaderSpan(second_offsets.data(), second_offsets.size())), + std::runtime_error); + + sync_point->DisableProcessing(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + + /* With the callbacks cleared, request_1 must still be reported as a successful completion; the WaitCompleted() call after it is expected to throw std::runtime_error. */auto completion = reader.WaitCompleted(); + EXPECT_EQ(completion.request_id, request_1); + EXPECT_TRUE(completion.ok); + EXPECT_THROW(reader.WaitCompleted(), std::runtime_error); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} + +/* Lets a reader with a submitted but never waited-for request go out of scope while the DrainOutstandingNoThrow:Skip sync point is active, then checks that a new reader on the same fd and pool can submit a read and receive the expected data. Skipped unless the pool backend is IO_URING. */TEST_F(IOContextPoolTestFixture, CompletionReaderDestructorResetKeepsPoolReusable) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 8; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_completion_reader_destructor_reset_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + auto content = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto first = 
AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto retry = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(first, nullptr); + ASSERT_NE(retry, nullptr); + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0x13}); + ASSERT_EQ(::pwrite(fd, content.get(), kIOReaderTestBlockSize, 0), static_cast(kIOReaderTestBlockSize)); + ASSERT_EQ(::fsync(fd), 0); + + auto* sync_point = milvus::SyncPoint::GetInstance(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + /* This callback writes true through arg on every invocation. NOTE(review): presumably this makes the reader skip draining outstanding requests during teardown; confirm against the reader implementation. */sync_point->SetCallBack("IOCompletionReader::DrainOutstandingNoThrow:Skip", [](void* arg) { + *static_cast(arg) = true; + }); + sync_point->EnableProcessing(); + + /* reader leaves scope at the closing brace without WaitCompleted() having been called for the submitted request. */{ + IOCompletionReader reader(fd, pool); + std::array buffers{first.get()}; + std::array offsets{0}; + ASSERT_NE(reader.Submit(IOCompletionReaderSpan(buffers.data(), buffers.size()), + kIOReaderTestBlockSize, IOCompletionReaderSpan(offsets.data(), offsets.size())), + 0u); + } + + sync_point->DisableProcessing(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + + /* Sync-point processing is disabled again here; this read must complete successfully and return the 0x13 pattern written to the file above. */IOCompletionReader retry_reader(fd, pool); + std::array retry_buffers{retry.get()}; + std::array retry_offsets{0}; + const auto request = retry_reader.Submit(IOCompletionReaderSpan(retry_buffers.data(), retry_buffers.size()), + kIOReaderTestBlockSize, + IOCompletionReaderSpan(retry_offsets.data(), retry_offsets.size())); + auto completion = retry_reader.WaitCompleted(); + ASSERT_EQ(completion.request_id, request); + ASSERT_TRUE(completion.ok); + ASSERT_EQ(std::memcmp(retry.get(), content.get(), kIOReaderTestBlockSize), 0); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} + +/* Activates the WaitOneCompletion:ForceNullCqe sync point, expects WaitCompleted() on the first reader to throw std::runtime_error, then checks that a second reader on the same fd and pool still reads the file correctly. Skipped unless the pool backend is IO_URING. */TEST_F(IOContextPoolTestFixture, CompletionReaderNullCqeShouldFailWithoutHanging) { + IOContextPoolConfig cfg; + cfg.num_ctx = 1; + cfg.max_events = 8; + ASSERT_TRUE(IOContextPool::InitGlobal(cfg)); + + auto pool = IOContextPool::GetGlobal(); + ASSERT_NE(pool, nullptr); + if (pool->Backend() != IOBackend::IO_URING) { + 
GTEST_SKIP() << "io_uring backend unavailable"; + } + + char path[] = "/tmp/io_completion_reader_null_cqe_XXXXXX"; + int fd = OpenIOReaderTestFile(path, false); + ASSERT_GE(fd, 0); + + auto content = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto first = AllocateAlignedBuffer(kIOReaderTestBlockSize); + auto retry = AllocateAlignedBuffer(kIOReaderTestBlockSize); + ASSERT_NE(content, nullptr); + ASSERT_NE(first, nullptr); + ASSERT_NE(retry, nullptr); + std::fill(content.get(), content.get() + kIOReaderTestBlockSize, std::byte{0x14}); + ASSERT_EQ(::pwrite(fd, content.get(), kIOReaderTestBlockSize, 0), static_cast(kIOReaderTestBlockSize)); + ASSERT_EQ(::fsync(fd), 0); + + auto* sync_point = milvus::SyncPoint::GetInstance(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + /* This callback writes true through arg on every invocation. NOTE(review): presumably this makes the reader behave as if the wait returned a null CQE; confirm against the reader implementation. */sync_point->SetCallBack("IOCompletionReader::WaitOneCompletion:ForceNullCqe", [](void* arg) { + *static_cast(arg) = true; + }); + sync_point->EnableProcessing(); + + { + IOCompletionReader reader(fd, pool); + std::array buffers{first.get()}; + std::array offsets{0}; + ASSERT_NE(reader.Submit(IOCompletionReaderSpan(buffers.data(), buffers.size()), + kIOReaderTestBlockSize, IOCompletionReaderSpan(offsets.data(), offsets.size())), + 0u); + /* The test name says this call must fail rather than block; the assertion itself only checks that std::runtime_error is thrown. */EXPECT_THROW(reader.WaitCompleted(), std::runtime_error); + } + + sync_point->DisableProcessing(); + sync_point->ClearAllCallBacks(); + sync_point->ClearTrace(); + + /* Sync-point processing is disabled again here; this read must complete successfully and return the 0x14 pattern written to the file above. */IOCompletionReader retry_reader(fd, pool); + std::array retry_buffers{retry.get()}; + std::array retry_offsets{0}; + const auto request = retry_reader.Submit(IOCompletionReaderSpan(retry_buffers.data(), retry_buffers.size()), + kIOReaderTestBlockSize, + IOCompletionReaderSpan(retry_offsets.data(), retry_offsets.size())); + auto completion = retry_reader.WaitCompleted(); + ASSERT_EQ(completion.request_id, request); + ASSERT_TRUE(completion.ok); + ASSERT_EQ(std::memcmp(retry.get(), content.get(), kIOReaderTestBlockSize), 0); + + ASSERT_EQ(::close(fd), 0); + ::unlink(path); +} +#endif
+#endif