diff --git a/include/ptrhash/ptrhash.hpp b/include/ptrhash/ptrhash.hpp new file mode 100644 index 0000000..5af1c06 --- /dev/null +++ b/include/ptrhash/ptrhash.hpp @@ -0,0 +1,1479 @@ +#ifndef PTRHASH_PTRHASH_HPP +#define PTRHASH_PTRHASH_HPP + +// Based on PtrHash, a minimal perfect hashing scheme. +// Paper: "PtrHash: Minimal Perfect Hashing at RAM Throughput" +// https://arxiv.org/abs/2502.15539 +// Reference implementation: +// https://github.com/RagnarGrootKoerkamp/PtrHash +// +// Usage: +// std::vector keys = {10, 20, 30}; +// auto hash = ptrhash::PtrHash::build(keys); +// size_t index = hash.index(20); // in [0, hash.n()) for keys used to build the hash. +// +// PtrHash does not store the original keys and cannot prove membership by itself. +// If queries may contain keys outside the build set, keep an index-addressed key +// or fingerprint table and verify the candidate returned by index(): +// +// std::vector keys_by_index(hash.n()); +// for (uint64_t key : keys) { +// keys_by_index[hash.index(key)] = key; +// } +// +// size_t candidate = hash.index(query); +// bool found = keys_by_index[candidate] == query; +// +// The serialized data contains the pilots and remap table required by index(). +// It can be persisted with hash.save(path) and restored with PtrHash::load(path). + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__unix__) || defined(__APPLE__) +#include +#include +#include +#include +#endif + +namespace ptrhash { + +enum class BucketFunction : uint64_t { + Linear = 0, + SquareEps = 1, + CubicEps = 2, +}; + +struct PtrHashParams { + double alpha = 0.99; + double lambda = 3.0; + uint64_t seed = 0x3141592653589793ull; + size_t max_seed_attempts = 256; + uint16_t max_pilot = 255; + // 0 means auto. Set to 1 to force single-core construction. + size_t build_threads = 0; + BucketFunction bucket_function = BucketFunction::Linear; +}; + +namespace detail { + +// Magic identifies the local PtrHash serialization family. The format revision +// is stored separately in kVersion. +constexpr uint8_t kMagic[8] = {'P', 'T', 'R', 'H', 'A', 'S', 'H', '\0'}; +constexpr uint32_t kVersion = 1; +constexpr size_t kHeaderSize = 8 + 4 + 4 + 8 * 9; +constexpr uint64_t kMix = 0x517cc1b727220a95ull; +constexpr uint32_t kRemapU32 = 4; +constexpr uint32_t kBucketFunctionShift = 8; +constexpr uint32_t kKeyHashKindShift = 16; + +enum class KeyHashKind : uint32_t { + Integer = 0, + String = 1, + Hash64 = 2, +}; + +inline uint64_t +mul_high(uint64_t a, uint64_t b) { +#if defined(__SIZEOF_INT128__) + return static_cast((static_cast(a) * b) >> 64); +#else + const uint64_t a_lo = static_cast(a); + const uint64_t a_hi = a >> 32; + const uint64_t b_lo = static_cast(b); + const uint64_t b_hi = b >> 32; + const uint64_t p0 = a_lo * b_lo; + const uint64_t p1 = a_lo * b_hi; + const uint64_t p2 = a_hi * b_lo; + const uint64_t p3 = a_hi * b_hi; + const uint64_t carry = ((p0 >> 32) + static_cast(p1) + static_cast(p2)) >> 32; + return p3 + (p1 >> 32) + (p2 >> 32) + carry; +#endif +} + +inline size_t +fast_reduce(uint64_t d, uint64_t h) { + return static_cast(mul_high(d, h)); +} + +inline uint64_t +fastmod32_multiplier(size_t d) { + return (std::numeric_limits::max() / static_cast(d)) + 1; +} + +inline size_t +fastmod32_reduce(uint64_t d, uint64_t m, uint64_t h) { + const uint64_t lowbits = m * h; +#if defined(__SIZEOF_INT128__) + return static_cast((static_cast(lowbits) * d) >> 64); +#else + return fast_reduce(d, lowbits); +#endif +} + +inline uint64_t +splitmix64(uint64_t x) { + x += 0x9e3779b97f4a7c15ull; + x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ull; + x = (x ^ (x >> 27)) * 0x94d049bb133111ebull; + return x ^ (x >> 31); +} + +inline uint64_t +hash_key(uint64_t key, uint64_t seed) { +#if defined(__SIZEOF_INT128__) + const auto r = static_cast(key ^ seed) * kMix; + const auto low = static_cast(r); + const auto high = static_cast(r >> 64); + return (low ^ high) * kMix; +#else + return splitmix64(key ^ seed); +#endif +} + +inline uint64_t +read_u64_unaligned(const uint8_t* p) { + uint64_t value = 0; +#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + std::memcpy(&value, p, sizeof(value)); +#else + for (int i = 0; i < 8; ++i) { + value |= static_cast(p[i]) << (8 * i); + } +#endif + return value; +} + +inline uint64_t +rotl64(uint64_t x, int r) { + return (x << r) | (x >> (64 - r)); +} + +inline uint64_t +fmix64(uint64_t x) { + x ^= x >> 33; + x *= 0xff51afd7ed558ccdull; + x ^= x >> 33; + x *= 0xc4ceb9fe1a85ec53ull; + x ^= x >> 33; + return x; +} + +inline uint64_t +hash_bytes(std::string_view key, uint64_t seed) { + const auto* p = reinterpret_cast(key.data()); + const size_t len = key.size(); + constexpr uint64_t c1 = 0x87c37b91114253d5ull; + constexpr uint64_t c2 = 0x4cf5ad432745937full; + uint64_t h1 = seed; + uint64_t h2 = seed ^ (static_cast(len) * kMix); + + size_t remaining = len; + while (remaining >= 16) { + uint64_t k1 = read_u64_unaligned(p); + uint64_t k2 = read_u64_unaligned(p + 8); + + k1 *= c1; + k1 = rotl64(k1, 31); + k1 *= c2; + h1 ^= k1; + h1 = rotl64(h1, 27); + h1 += h2; + h1 = h1 * 5 + 0x52dce729; + + k2 *= c2; + k2 = rotl64(k2, 33); + k2 *= c1; + h2 ^= k2; + h2 = rotl64(h2, 31); + h2 += h1; + h2 = h2 * 5 + 0x38495ab5; + + p += 16; + remaining -= 16; + } + + uint64_t k1 = 0; + uint64_t k2 = 0; + const size_t first_tail = std::min(remaining, 8); + for (size_t i = 0; i < first_tail; ++i) { + k1 |= static_cast(p[i]) << (8 * i); + } + for (size_t i = 8; i < remaining; ++i) { + k2 |= static_cast(p[i]) << (8 * (i - 8)); + } + if (k2 != 0) { + k2 *= c2; + k2 = rotl64(k2, 33); + k2 *= c1; + h2 ^= k2; + } + if (k1 != 0) { + k1 *= c1; + k1 = rotl64(k1, 31); + k1 *= c2; + h1 ^= k1; + } + + h1 ^= static_cast(len); + h2 ^= static_cast(len); + h1 += h2; + h2 += h1; + h1 = fmix64(h1); + h2 = fmix64(h2); + h1 += h2; + return h1; +} + +template +inline typename std::enable_if::value && sizeof(Key) <= sizeof(uint64_t), uint64_t>::type +hash_key_for(Key key, uint64_t seed) { + return hash_key(static_cast(key), seed); +} + +inline uint64_t +hash_key_for(std::string_view key, uint64_t seed) { + return hash_bytes(key, seed); +} + +inline uint64_t +hash_pilot(uint64_t pilot, uint64_t seed) { + return kMix * (pilot ^ seed); +} + +inline bool +likely(bool value) { +#if defined(__GNUC__) || defined(__clang__) + return __builtin_expect(value, true); +#else + return value; +#endif +} + +inline uint64_t +bucket_transform(uint64_t x, BucketFunction bucket_function) { + switch (bucket_function) { + case BucketFunction::Linear: + return x; + case BucketFunction::SquareEps: + return mul_high(x, x) / 256 * 255 + x / 256; + case BucketFunction::CubicEps: + return mul_high(mul_high(x, x), (x >> 1) | (1ull << 63)) / 256 * 255 + x / 256; + } + throw std::invalid_argument("unknown bucket function"); +} + +inline void +append_u32(std::vector& out, uint32_t value) { + for (int i = 0; i < 4; ++i) { + out.push_back(static_cast(value >> (8 * i))); + } +} + +inline void +append_u64(std::vector& out, uint64_t value) { + for (int i = 0; i < 8; ++i) { + out.push_back(static_cast(value >> (8 * i))); + } +} + +inline uint32_t +read_u32(const uint8_t* p) { + uint32_t value = 0; + for (int i = 0; i < 4; ++i) { + value |= static_cast(p[i]) << (8 * i); + } + return value; +} + +inline uint64_t +read_u64(const uint8_t* p) { + uint64_t value = 0; + for (int i = 0; i < 8; ++i) { + value |= static_cast(p[i]) << (8 * i); + } + return value; +} + +inline size_t +ceil_division_as_size(double numerator, double denominator) { + if (!std::isfinite(denominator) || !(denominator > 0.0)) { + throw std::invalid_argument("PtrHash parameter must be positive"); + } + if (!std::isfinite(numerator) || numerator < 0.0) { + throw std::overflow_error("PtrHash size overflow"); + } + const double value = numerator / denominator; + if (!std::isfinite(value) || value < 0.0 || !(value < static_cast(std::numeric_limits::max()))) { + throw std::overflow_error("PtrHash size overflow"); + } + const size_t truncated = static_cast(value); + if (static_cast(truncated) == value) { + return truncated; + } + if (truncated == std::numeric_limits::max()) { + throw std::overflow_error("PtrHash size overflow"); + } + return truncated + 1; +} + +inline size_t +checked_multiply_as_size(size_t lhs, size_t rhs, const char* message) { + if (lhs != 0 && rhs > std::numeric_limits::max() / lhs) { + throw std::overflow_error(message); + } + return lhs * rhs; +} + +inline size_t +choose_parts(size_t n, double alpha) { + if (n == 0) { + return 0; + } + const double eps = (1.0 - alpha) / 2.0; + const double x = static_cast(n) * eps * eps / 2.0; + if (!(x > std::exp(1.0))) { + return 1; + } + double target_parts = x / std::log(x); + if (!(target_parts >= 1.0)) { + target_parts = 1.0; + } + const size_t compression_parts = std::max(1, static_cast(std::floor(target_parts))); + const size_t parallel_parts = 1; + return std::max(compression_parts, std::max(1, parallel_parts)); +} + +} // namespace detail + +class PtrHashView { + public: + PtrHashView() = default; + + static PtrHashView + from_bytes(const void* data, size_t size) { + if (data == nullptr && size != 0) { + throw std::invalid_argument("PtrHashView data is null"); + } + const auto* bytes = static_cast(data); + if (size < detail::kHeaderSize) { + throw std::invalid_argument("PtrHash data is truncated"); + } + if (!std::equal(detail::kMagic, detail::kMagic + 8, bytes)) { + throw std::invalid_argument("PtrHash magic mismatch"); + } + if (detail::read_u32(bytes + 8) != detail::kVersion) { + throw std::invalid_argument("unsupported PtrHash version"); + } + + const uint32_t flags = detail::read_u32(bytes + 12); + const uint32_t remap_width = flags & 0xffu; + const auto bucket_function = static_cast((flags >> detail::kBucketFunctionShift) & 0xffu); + const auto key_hash_kind = static_cast((flags >> detail::kKeyHashKindShift) & 0xffu); + if (remap_width != detail::kRemapU32) { + throw std::invalid_argument("unsupported PtrHash remap width"); + } + (void)detail::bucket_transform(0, bucket_function); + switch (key_hash_kind) { + case detail::KeyHashKind::Integer: + case detail::KeyHashKind::String: + case detail::KeyHashKind::Hash64: + break; + default: + throw std::invalid_argument("unsupported PtrHash key hash kind"); + } + + PtrHashView view; + view.data_ = bytes; + view.capacity_ = size; + view.bucket_function_ = bucket_function; + view.key_hash_kind_ = key_hash_kind; + const uint8_t* cursor = bytes + 16; + view.n_ = static_cast(detail::read_u64(cursor)); + cursor += 8; + view.slots_total_ = static_cast(detail::read_u64(cursor)); + cursor += 8; + view.buckets_total_ = static_cast(detail::read_u64(cursor)); + cursor += 8; + view.seed_ = detail::read_u64(cursor); + cursor += 8; + view.pilot_count_ = static_cast(detail::read_u64(cursor)); + cursor += 8; + view.remap_count_ = static_cast(detail::read_u64(cursor)); + cursor += 8; + view.parts_ = static_cast(detail::read_u64(cursor)); + cursor += 8; + view.slots_per_part_ = static_cast(detail::read_u64(cursor)); + cursor += 8; + view.buckets_per_part_ = static_cast(detail::read_u64(cursor)); + view.rem_slots_m_ = view.slots_per_part_ == 0 ? 0 : detail::fastmod32_multiplier(view.slots_per_part_); + for (size_t pilot = 0; pilot < view.pilot_hashes_.size(); ++pilot) { + view.pilot_hashes_[pilot] = detail::hash_pilot(pilot, view.seed_); + } + + if (view.n_ == 0 && + (view.slots_total_ != 0 || view.buckets_total_ != 0 || view.pilot_count_ != 0 || view.remap_count_ != 0 || + view.parts_ != 0 || view.slots_per_part_ != 0 || view.buckets_per_part_ != 0)) { + throw std::invalid_argument("PtrHash empty layout mismatch"); + } + if (view.parts_ == 0 && view.n_ != 0) { + throw std::invalid_argument("PtrHash part count is invalid"); + } + if (view.n_ != 0 && (view.parts_ == 0 || view.slots_per_part_ == 0 || view.buckets_per_part_ == 0 || + view.slots_total_ == 0 || view.buckets_total_ == 0 || view.pilot_count_ == 0)) { + throw std::invalid_argument("PtrHash non-empty layout has zero counts"); + } + if (view.parts_ != 0 && view.slots_per_part_ != 0 && + view.parts_ > std::numeric_limits::max() / view.slots_per_part_) { + throw std::overflow_error("PtrHash slot layout overflow"); + } + if (view.parts_ != 0 && view.buckets_per_part_ != 0 && + view.parts_ > std::numeric_limits::max() / view.buckets_per_part_) { + throw std::overflow_error("PtrHash bucket layout overflow"); + } + if (view.parts_ != 0 && view.slots_total_ != view.parts_ * view.slots_per_part_) { + throw std::invalid_argument("PtrHash slot layout mismatch"); + } + if (view.parts_ != 0 && view.buckets_total_ != view.parts_ * view.buckets_per_part_) { + throw std::invalid_argument("PtrHash bucket layout mismatch"); + } + if (view.pilot_count_ != view.buckets_total_) { + throw std::invalid_argument("PtrHash pilot count mismatch"); + } + if (view.slots_total_ < view.n_) { + throw std::invalid_argument("PtrHash slot count is invalid"); + } + if (view.remap_count_ != view.slots_total_ - view.n_) { + throw std::invalid_argument("PtrHash remap count mismatch"); + } + const size_t pilot_bytes = view.pilot_count_; + if (view.pilot_count_ != 0 && pilot_bytes / sizeof(uint8_t) != view.pilot_count_) { + throw std::overflow_error("PtrHash pilot size overflow"); + } + const size_t remap_bytes = view.remap_count_ * sizeof(uint32_t); + if (view.remap_count_ != 0 && remap_bytes / sizeof(uint32_t) != view.remap_count_) { + throw std::overflow_error("PtrHash remap size overflow"); + } + if (pilot_bytes > std::numeric_limits::max() - detail::kHeaderSize || + remap_bytes > std::numeric_limits::max() - detail::kHeaderSize - pilot_bytes) { + throw std::overflow_error("PtrHash serialized size overflow"); + } + view.serialized_size_ = detail::kHeaderSize + pilot_bytes + remap_bytes; + if (view.serialized_size_ > size) { + throw std::invalid_argument("PtrHash data is truncated"); + } + view.pilots_ = bytes + detail::kHeaderSize; + view.remap_ = view.pilots_ + pilot_bytes; + for (size_t i = 0; i < view.remap_count_; ++i) { + if (detail::read_u32(view.remap_ + i * sizeof(uint32_t)) >= view.n_) { + throw std::invalid_argument("PtrHash remap entry is invalid"); + } + } + return view; + } + + size_t + n() const { + return n_; + } + size_t + max_index() const { + return slots_total_; + } + size_t + bucket_count() const { + return buckets_total_; + } + size_t + serialized_size() const { + return serialized_size_; + } + + size_t + index_no_remap(uint64_t key) const { + require_key_hash_kind(detail::KeyHashKind::Integer); + return index_no_remap_hx(detail::hash_key(key, seed_)); + } + + size_t + index_no_remap(std::string_view key) const { + require_key_hash_kind(detail::KeyHashKind::String); + return index_no_remap_hx(detail::hash_key_for(key, seed_)); + } + + size_t + index_no_remap_hash(uint64_t hash) const { + require_key_hash_kind(detail::KeyHashKind::Hash64); + return index_no_remap_hx(detail::hash_key(hash, seed_)); + } + + size_t + index(uint64_t key) const { + return index_from_slot(index_no_remap(key)); + } + + size_t + index(std::string_view key) const { + return index_from_slot(index_no_remap(key)); + } + + size_t + index_hash(uint64_t hash) const { + return index_from_slot(index_no_remap_hash(hash)); + } + + private: + void + require_key_hash_kind(detail::KeyHashKind expected) const { + if (key_hash_kind_ != expected) { + throw std::invalid_argument("PtrHash key type does not match this data"); + } + } + + size_t + index_no_remap_hx(uint64_t hx) const { + if (n_ == 0) { + throw std::out_of_range("cannot query an empty PtrHash"); + } + const size_t part = detail::fast_reduce(static_cast(parts_), hx); + const size_t bucket = + bucket_function_ == BucketFunction::Linear + ? detail::fast_reduce(static_cast(buckets_total_), hx) + : part * buckets_per_part_ + + detail::fast_reduce( + static_cast(buckets_per_part_), + detail::bucket_transform(detail::splitmix64(hx ^ 0x243f6a8885a308d3ull), bucket_function_)); + const uint64_t pilot = pilots_[bucket]; + const size_t slot_in_part = + detail::fastmod32_reduce(static_cast(slots_per_part_), rem_slots_m_, hx ^ pilot_hashes_[pilot]); + return part * slots_per_part_ + slot_in_part; + } + + size_t + index_from_slot(size_t slot) const { + if (detail::likely(slot < n_)) { + return slot; + } + return static_cast(detail::read_u32(remap_ + (slot - n_) * 4)); + } + + const uint8_t* data_ = nullptr; + const uint8_t* pilots_ = nullptr; + const uint8_t* remap_ = nullptr; + size_t capacity_ = 0; + size_t serialized_size_ = 0; + size_t n_ = 0; + size_t slots_total_ = 0; + size_t buckets_total_ = 0; + size_t pilot_count_ = 0; + size_t remap_count_ = 0; + size_t parts_ = 0; + size_t slots_per_part_ = 0; + size_t buckets_per_part_ = 0; + uint64_t rem_slots_m_ = 0; + uint64_t seed_ = 0; + std::array pilot_hashes_{}; + BucketFunction bucket_function_ = BucketFunction::Linear; + detail::KeyHashKind key_hash_kind_ = detail::KeyHashKind::Integer; +}; + +class PtrHash { + public: + PtrHash() = default; + + PtrHash(const PtrHash& other) : storage_(other.storage_) { + reset_view(); + } + + PtrHash& + operator=(const PtrHash& other) { + if (this != &other) { + storage_ = other.storage_; + reset_view(); + } + return *this; + } + + PtrHash(PtrHash&& other) noexcept : storage_(std::move(other.storage_)) { + reset_view(); + } + + PtrHash& + operator=(PtrHash&& other) noexcept { + if (this != &other) { + storage_ = std::move(other.storage_); + reset_view(); + } + return *this; + } + + template ::value && sizeof(Key) <= sizeof(uint64_t), int>::type = 0> + static PtrHash + build(const std::vector& keys, const PtrHashParams& params = PtrHashParams()) { + return build_impl(keys, params, detail::KeyHashKind::Integer); + } + + static PtrHash + build(const std::vector& keys, const PtrHashParams& params = PtrHashParams()) { + return build_impl(keys, params, detail::KeyHashKind::String); + } + + static PtrHash + build(const std::vector& keys, const PtrHashParams& params = PtrHashParams()) { + return build_impl(keys, params, detail::KeyHashKind::String); + } + + static PtrHash + build_hashes(const std::vector& hashes, const PtrHashParams& params = PtrHashParams()) { + return build_impl(hashes, params, detail::KeyHashKind::Hash64); + } + + private: + using BucketId = uint32_t; + + template + static PtrHash + build_impl(const std::vector& keys, const PtrHashParams& params, detail::KeyHashKind key_hash_kind) { + validate_params(params); + validate_unique(keys); + if (keys.empty()) { + return from_parts(0, 0, 0, 0, 0, 0, params.seed, params.bucket_function, key_hash_kind, {}, {}); + } + + const size_t n = keys.size(); + if (n > static_cast(std::numeric_limits::max())) { + throw std::overflow_error("this compact serialized format supports up to 2^32-1 keys"); + } + const size_t parts = detail::choose_parts(n, params.alpha); + + for (size_t attempt = 0; attempt < params.max_seed_attempts; ++attempt) { + const uint64_t seed = detail::splitmix64(params.seed + attempt); + size_t slots_total = 0; + size_t buckets_total = 0; + size_t slots_per_part = 0; + size_t buckets_per_part = 0; + std::vector pilots; + std::vector remap; + if (try_build(keys, parts, params.alpha, params.lambda, seed, params.max_pilot, params.build_threads, + params.bucket_function, slots_total, buckets_total, slots_per_part, buckets_per_part, pilots, + remap)) { + return from_parts(n, slots_total, buckets_total, parts, slots_per_part, buckets_per_part, seed, + params.bucket_function, key_hash_kind, std::move(pilots), std::move(remap)); + } + } + throw std::runtime_error("unable to construct PtrHash with the requested parameters"); + } + + public: + static PtrHash + deserialize(const void* data, size_t size) { + PtrHashView view = PtrHashView::from_bytes(data, size); + const auto* bytes = static_cast(data); + PtrHash hash; + hash.storage_.assign(bytes, bytes + view.serialized_size()); + hash.reset_view(); + return hash; + } + + static PtrHash + load(const std::string& path) { + std::ifstream in(path, std::ios::binary); + if (!in) { + throw std::runtime_error("failed to open PtrHash file"); + } + std::vector bytes((std::istreambuf_iterator(in)), std::istreambuf_iterator()); + return deserialize(bytes.data(), bytes.size()); + } + + void + save(const std::string& path) const { + std::ofstream out(path, std::ios::binary); + if (!out) { + throw std::runtime_error("failed to create PtrHash file"); + } + out.write(reinterpret_cast(storage_.data()), static_cast(storage_.size())); + if (!out.good()) { + throw std::runtime_error("failed to write PtrHash file"); + } + } + + size_t + n() const { + return view_.n(); + } + size_t + max_index() const { + return view_.max_index(); + } + size_t + index_no_remap(uint64_t key) const { + return view_.index_no_remap(key); + } + size_t + index_no_remap(std::string_view key) const { + return view_.index_no_remap(key); + } + size_t + index_no_remap_hash(uint64_t hash) const { + return view_.index_no_remap_hash(hash); + } + size_t + index(uint64_t key) const { + return view_.index(key); + } + size_t + index(std::string_view key) const { + return view_.index(key); + } + size_t + index_hash(uint64_t hash) const { + return view_.index_hash(hash); + } + const PtrHashView& + view() const { + return view_; + } + const std::vector& + serialize() const { + return storage_; + } + + private: + static void + validate_params(const PtrHashParams& params) { + if (!std::isfinite(params.alpha) || !(params.alpha > 0.0 && params.alpha <= 1.0)) { + throw std::invalid_argument("alpha must be in (0, 1]"); + } + if (!std::isfinite(params.lambda) || !(params.lambda > 0.0)) { + throw std::invalid_argument("lambda must be positive"); + } + if (params.max_pilot > std::numeric_limits::max()) { + throw std::invalid_argument("max_pilot must fit in the u8 serialized pilot format"); + } + } + + template + static void + validate_unique(const std::vector& keys) { + if (std::adjacent_find(keys.begin(), keys.end(), std::greater_equal()) == keys.end()) { + return; + } + std::vector sorted = keys; + std::sort(sorted.begin(), sorted.end()); + if (std::adjacent_find(sorted.begin(), sorted.end()) != sorted.end()) { + throw std::invalid_argument("PtrHash requires unique keys"); + } + } + + template + static bool + try_build(const std::vector& keys, size_t parts, double alpha, double lambda, uint64_t seed, + uint16_t max_pilot, size_t build_threads, BucketFunction bucket_function, size_t& slots_total, + size_t& buckets_total, size_t& slots_per_part, size_t& buckets_per_part, std::vector& pilots, + std::vector& remap) { + const size_t keys_per_part = std::max(1, (keys.size() + parts - 1) / parts); + slots_per_part = std::max(1, detail::ceil_division_as_size(static_cast(keys_per_part), alpha)); + if ((slots_per_part & (slots_per_part - 1)) == 0) { + ++slots_per_part; + } + const size_t bucket_base = detail::ceil_division_as_size(static_cast(keys_per_part), lambda); + if (bucket_base > std::numeric_limits::max() - 3) { + throw std::overflow_error("PtrHash size overflow"); + } + buckets_per_part = std::max(1, bucket_base + 3); + if (buckets_per_part > static_cast(std::numeric_limits::max())) { + throw std::overflow_error("too many buckets per part for compact build state"); + } + slots_total = detail::checked_multiply_as_size(parts, slots_per_part, "PtrHash slot layout overflow"); + buckets_total = detail::checked_multiply_as_size(parts, buckets_per_part, "PtrHash bucket layout overflow"); + const uint64_t rem_slots_m = detail::fastmod32_multiplier(slots_per_part); + std::array pilot_hashes{}; + for (size_t pilot = 0; pilot <= max_pilot; ++pilot) { + pilot_hashes[pilot] = detail::hash_pilot(pilot, seed); + } + + std::vector bucket_starts(buckets_total + 1, 0); + std::vector bucket_hashes(keys.size()); + fill_buckets(keys, parts, buckets_per_part, buckets_total, seed, bucket_function, build_threads, bucket_starts, + bucket_hashes); + + pilots.assign(buckets_total, 0); + std::vector taken(slots_total, 0); + std::atomic next_part{0}; + std::atomic ok{true}; + const size_t thread_count = effective_thread_count(build_threads, parts); + std::vector workers; + workers.reserve(thread_count); + try { + for (size_t thread = 0; thread < thread_count; ++thread) { + workers.emplace_back([&] { + while (ok.load(std::memory_order_relaxed)) { + const size_t part = next_part.fetch_add(1, std::memory_order_relaxed); + if (part >= parts) { + break; + } + if (!build_part(part, buckets_per_part, slots_per_part, rem_slots_m, max_pilot, pilot_hashes, + bucket_hashes, bucket_starts, pilots, taken)) { + ok.store(false, std::memory_order_relaxed); + break; + } + } + }); + } + } catch (...) { + join_workers(workers); + throw; + } + join_workers(workers); + if (!ok.load(std::memory_order_relaxed)) { + return false; + } + + const size_t remap_count = slots_total - keys.size(); + std::vector free_minimal; + free_minimal.reserve(remap_count); + for (size_t i = 0; i < keys.size(); ++i) { + if (!taken[i]) { + free_minimal.push_back(static_cast(i)); + } + } + + remap.assign(remap_count, 0); + size_t free_cursor = 0; + for (size_t slot = keys.size(); slot < slots_total; ++slot) { + if (taken[slot]) { + if (free_cursor >= free_minimal.size()) { + return false; + } + remap[slot - keys.size()] = free_minimal[free_cursor++]; + } + } + return free_cursor == free_minimal.size(); + } + + static void + join_workers(std::vector& workers) noexcept { + for (auto& worker : workers) { + if (worker.joinable()) { + worker.join(); + } + } + } + + static size_t + hardware_threads() { + return std::max(1, static_cast(std::thread::hardware_concurrency())); + } + + static size_t + effective_thread_count(size_t requested, size_t limit) { + const size_t wanted = requested == 0 ? hardware_threads() : requested; + return std::max(1, std::min(wanted, std::max(1, limit))); + } + + static size_t + build_thread_count(size_t n, size_t requested) { + return effective_thread_count(requested, (n + 99999) / 100000); + } + + static size_t + bucket_for_hash(uint64_t hx, size_t parts, size_t buckets_per_part, size_t buckets_total, + BucketFunction bucket_function) { + if (bucket_function == BucketFunction::Linear) { + return detail::fast_reduce(static_cast(buckets_total), hx); + } + return detail::fast_reduce(static_cast(parts), hx) * buckets_per_part + + detail::fast_reduce( + static_cast(buckets_per_part), + detail::bucket_transform(detail::splitmix64(hx ^ 0x243f6a8885a308d3ull), bucket_function)); + } + + template + static void + fill_buckets(const std::vector& keys, size_t parts, size_t buckets_per_part, size_t buckets_total, + uint64_t seed, BucketFunction bucket_function, size_t build_threads, + std::vector& bucket_starts, std::vector& bucket_hashes) { + const size_t thread_count = build_thread_count(keys.size(), build_threads); + if (thread_count == 1) { + for (const auto& key : keys) { + const uint64_t hx = detail::hash_key_for(key, seed); + ++bucket_starts[bucket_for_hash(hx, parts, buckets_per_part, buckets_total, bucket_function) + 1]; + } + for (size_t i = 1; i < bucket_starts.size(); ++i) { + bucket_starts[i] += bucket_starts[i - 1]; + } + std::vector cursor = bucket_starts; + for (const auto& key : keys) { + const uint64_t hx = detail::hash_key_for(key, seed); + const size_t bucket = bucket_for_hash(hx, parts, buckets_per_part, buckets_total, bucket_function); + bucket_hashes[cursor[bucket]++] = hx; + } + return; + } + + std::vector> counts(buckets_total); + std::vector workers; + workers.reserve(thread_count); + try { + for (size_t thread = 0; thread < thread_count; ++thread) { + const size_t begin = keys.size() * thread / thread_count; + const size_t end = keys.size() * (thread + 1) / thread_count; + workers.emplace_back([&, begin, end] { + for (size_t i = begin; i < end; ++i) { + const uint64_t hx = detail::hash_key_for(keys[i], seed); + const size_t bucket = + bucket_for_hash(hx, parts, buckets_per_part, buckets_total, bucket_function); + counts[bucket].fetch_add(1, std::memory_order_relaxed); + } + }); + } + } catch (...) { + join_workers(workers); + throw; + } + join_workers(workers); + + for (size_t i = 0; i < buckets_total; ++i) { + bucket_starts[i + 1] = bucket_starts[i] + counts[i].load(std::memory_order_relaxed); + } + + std::vector> cursor(buckets_total); + for (size_t i = 0; i < buckets_total; ++i) { + cursor[i].store(bucket_starts[i], std::memory_order_relaxed); + } + + workers.clear(); + try { + for (size_t thread = 0; thread < thread_count; ++thread) { + const size_t begin = keys.size() * thread / thread_count; + const size_t end = keys.size() * (thread + 1) / thread_count; + workers.emplace_back([&, begin, end] { + for (size_t i = begin; i < end; ++i) { + const uint64_t hx = detail::hash_key_for(keys[i], seed); + const size_t bucket = + bucket_for_hash(hx, parts, buckets_per_part, buckets_total, bucket_function); + const uint32_t pos = cursor[bucket].fetch_add(1, std::memory_order_relaxed); + bucket_hashes[pos] = hx; + } + }); + } + } catch (...) { + join_workers(workers); + throw; + } + join_workers(workers); + } + + static size_t + slot_in_part_hp(uint64_t hx, uint64_t rem_slots_m, uint64_t pilot_hash, size_t slots_per_part) { + return detail::fastmod32_reduce(static_cast(slots_per_part), rem_slots_m, hx ^ pilot_hash); + } + + static size_t + slot_in_part(uint64_t hx, uint64_t seed, uint64_t rem_slots_m, uint16_t pilot, size_t slots_per_part) { + return slot_in_part_hp(hx, rem_slots_m, detail::hash_pilot(pilot, seed), slots_per_part); + } + + static bool + bucket_slots(const std::vector& hashes, size_t begin, size_t end, uint64_t rem_slots_m, + uint64_t pilot_hash, size_t slots_per_part, std::vector& out) { + out.clear(); + out.reserve(end - begin); + for (size_t i = begin; i < end; ++i) { + const uint64_t hx = hashes[i]; + const size_t slot = slot_in_part_hp(hx, rem_slots_m, pilot_hash, slots_per_part); + if (std::find(out.begin(), out.end(), slot) != out.end()) { + return false; + } + out.push_back(slot); + } + return true; + } + + static bool + bucket_slots_available(const std::vector& hashes, size_t begin, size_t end, uint64_t rem_slots_m, + uint64_t pilot_hash, size_t slots_per_part, const uint8_t* taken_part) { + size_t i = begin; + const size_t unrolled_end = begin + ((end - begin) / 4) * 4; + for (; i < unrolled_end; i += 4) { + const size_t slot0 = slot_in_part_hp(hashes[i], rem_slots_m, pilot_hash, slots_per_part); + const size_t slot1 = slot_in_part_hp(hashes[i + 1], rem_slots_m, pilot_hash, slots_per_part); + const size_t slot2 = slot_in_part_hp(hashes[i + 2], rem_slots_m, pilot_hash, slots_per_part); + const size_t slot3 = slot_in_part_hp(hashes[i + 3], rem_slots_m, pilot_hash, slots_per_part); + if (taken_part[slot0] || taken_part[slot1] || taken_part[slot2] || taken_part[slot3]) { + return false; + } + } + for (; i < end; ++i) { + const size_t slot = slot_in_part_hp(hashes[i], rem_slots_m, pilot_hash, slots_per_part); + if (taken_part[slot]) { + return false; + } + } + return true; + } + + static bool + try_take_bucket_slots(const std::vector& hashes, size_t begin, size_t end, uint64_t rem_slots_m, + uint64_t pilot_hash, size_t slots_per_part, uint8_t* taken_part, std::vector& out) { + out.clear(); + out.reserve(end - begin); + for (size_t i = begin; i < end; ++i) { + const size_t slot = slot_in_part_hp(hashes[i], rem_slots_m, pilot_hash, slots_per_part); + if (taken_part[slot]) { + for (size_t taken_slot : out) { + taken_part[taken_slot] = 0; + } + return false; + } + taken_part[slot] = 1; + out.push_back(slot); + } + return true; + } + + static bool + contains_recent(const std::array& recent, BucketId bucket) { + return std::find(recent.begin(), recent.end(), bucket) != recent.end(); + } + + static bool + build_part(size_t part, size_t buckets_per_part, size_t slots_per_part, uint64_t rem_slots_m, uint16_t max_pilot, + const std::array& pilot_hashes, const std::vector& bucket_hashes, + const std::vector& bucket_starts, std::vector& pilots, std::vector& taken) { + const size_t bucket_offset = part * buckets_per_part; + const size_t slot_offset = part * slots_per_part; + uint8_t* const taken_part = taken.data() + slot_offset; + std::vector order(buckets_per_part); + for (size_t i = 0; i < buckets_per_part; ++i) { + order[i] = static_cast(i); + } + std::stable_sort(order.begin(), order.end(), [&](BucketId a, BucketId b) { + return bucket_starts[bucket_offset + a + 1] - bucket_starts[bucket_offset + a] > + bucket_starts[bucket_offset + b + 1] - bucket_starts[bucket_offset + b]; + }); + + std::vector slot_bucket(slots_per_part, bucket_npos()); + std::vector candidate_slots; + std::vector remove_slots; + std::array recent{}; + + auto bucket_len = [&](BucketId b) { + return static_cast(bucket_starts[bucket_offset + b + 1] - bucket_starts[bucket_offset + b]); + }; + + for (BucketId new_bucket : order) { + if (bucket_len(new_bucket) == 0) { + pilots[bucket_offset + new_bucket] = 0; + continue; + } + + std::priority_queue> stack; + stack.emplace(bucket_len(new_bucket), new_bucket); + recent.fill(bucket_npos()); + size_t recent_idx = 0; + recent[recent_idx] = new_bucket; + size_t evictions = 0; + + while (!stack.empty()) { + const BucketId bucket = stack.top().second; + stack.pop(); + const size_t begin = bucket_starts[bucket_offset + bucket]; + const size_t end = bucket_starts[bucket_offset + bucket + 1]; + + bool placed = false; + for (uint32_t pilot_u32 = 0; pilot_u32 <= max_pilot; ++pilot_u32) { + const auto pilot = static_cast(pilot_u32); + const uint64_t pilot_hash = pilot_hashes[pilot_u32]; + if (!bucket_slots_available(bucket_hashes, begin, end, rem_slots_m, pilot_hash, slots_per_part, + taken_part)) { + continue; + } + if (!try_take_bucket_slots(bucket_hashes, begin, end, rem_slots_m, pilot_hash, slots_per_part, + taken_part, candidate_slots)) { + continue; + } + pilots[bucket_offset + bucket] = static_cast(pilot); + for (size_t slot : candidate_slots) { + slot_bucket[slot] = bucket; + } + placed = true; + break; + } + if (placed) { + continue; + } + + size_t best_score = std::numeric_limits::max(); + uint16_t best_pilot = 0; + bool have_best = false; + for (uint32_t pilot_u32 = 0; pilot_u32 <= max_pilot; ++pilot_u32) { + const auto pilot = static_cast(pilot_u32); + const uint64_t pilot_hash = pilot_hashes[pilot_u32]; + if (!bucket_slots(bucket_hashes, begin, end, rem_slots_m, pilot_hash, slots_per_part, + candidate_slots)) { + continue; + } + size_t score = 0; + bool skip = false; + for (size_t slot : candidate_slots) { + const BucketId other = slot_bucket[slot]; + if (other == bucket_npos()) { + continue; + } + if (contains_recent(recent, other)) { + skip = true; + break; + } + const size_t len = bucket_len(other); + score += len * len; + if (score >= best_score) { + skip = true; + break; + } + } + if (!skip) { + best_score = score; + best_pilot = pilot; + have_best = true; + } + } + if (!have_best) { + return false; + } + + if (!bucket_slots(bucket_hashes, begin, end, rem_slots_m, pilot_hashes[best_pilot], slots_per_part, + candidate_slots)) { + return false; + } + pilots[bucket_offset + bucket] = static_cast(best_pilot); + for (size_t slot : candidate_slots) { + const BucketId other = slot_bucket[slot]; + if (other != bucket_npos() && other != bucket) { + stack.emplace(bucket_len(other), other); + ++evictions; + if (evictions > 10 * slots_per_part) { + return false; + } + const size_t other_begin = bucket_starts[bucket_offset + other]; + const size_t other_end = bucket_starts[bucket_offset + other + 1]; + const auto other_pilot = static_cast(pilots[bucket_offset + other]); + if (!bucket_slots(bucket_hashes, other_begin, other_end, rem_slots_m, pilot_hashes[other_pilot], + slots_per_part, remove_slots)) { + return false; + } + for (size_t remove_slot : remove_slots) { + if (slot_bucket[remove_slot] == other) { + slot_bucket[remove_slot] = bucket_npos(); + taken_part[remove_slot] = false; + } + } + } + slot_bucket[slot] = bucket; + taken_part[slot] = true; + } + + recent_idx = (recent_idx + 1) % recent.size(); + recent[recent_idx] = bucket; + } + } + return true; + } + + static constexpr BucketId + bucket_npos() { + return std::numeric_limits::max(); + } + + static PtrHash + from_parts(size_t n, size_t slots_total, size_t buckets_total, size_t parts, size_t slots_per_part, + size_t buckets_per_part, uint64_t seed, BucketFunction bucket_function, + detail::KeyHashKind key_hash_kind, std::vector pilots, std::vector remap) { + PtrHash hash; + const size_t remap_bytes = + detail::checked_multiply_as_size(remap.size(), sizeof(uint32_t), "PtrHash serialized size overflow"); + if (pilots.size() > std::numeric_limits::max() - detail::kHeaderSize || + remap_bytes > std::numeric_limits::max() - detail::kHeaderSize - pilots.size()) { + throw std::overflow_error("PtrHash serialized size overflow"); + } + hash.storage_.reserve(detail::kHeaderSize + pilots.size() + remap_bytes); + hash.storage_.insert(hash.storage_.end(), detail::kMagic, detail::kMagic + 8); + detail::append_u32(hash.storage_, detail::kVersion); + const uint32_t flags = detail::kRemapU32 | + (static_cast(bucket_function) << detail::kBucketFunctionShift) | + (static_cast(key_hash_kind) << detail::kKeyHashKindShift); + detail::append_u32(hash.storage_, flags); + detail::append_u64(hash.storage_, static_cast(n)); + detail::append_u64(hash.storage_, static_cast(slots_total)); + detail::append_u64(hash.storage_, static_cast(buckets_total)); + detail::append_u64(hash.storage_, seed); + detail::append_u64(hash.storage_, static_cast(pilots.size())); + detail::append_u64(hash.storage_, static_cast(remap.size())); + detail::append_u64(hash.storage_, static_cast(parts)); + detail::append_u64(hash.storage_, static_cast(slots_per_part)); + detail::append_u64(hash.storage_, static_cast(buckets_per_part)); + hash.storage_.insert(hash.storage_.end(), pilots.begin(), pilots.end()); + for (uint32_t value : remap) { + detail::append_u32(hash.storage_, value); + } + hash.reset_view(); + return hash; + } + + void + reset_view() { + if (storage_.empty()) { + view_ = PtrHashView(); + } else { + view_ = PtrHashView::from_bytes(storage_.data(), storage_.size()); + } + } + + std::vector storage_; + PtrHashView view_; +}; + +template +class PtrHashWithHasher { + public: + PtrHashWithHasher() = default; + + PtrHashWithHasher(PtrHash hash, Hasher hasher) : hash_(std::move(hash)), hasher_(std::move(hasher)) { + } + + static PtrHashWithHasher + build(const std::vector& keys, Hasher hasher = Hasher(), const PtrHashParams& params = PtrHashParams()) { + std::vector hashes; + hashes.reserve(keys.size()); + for (const auto& key : keys) { + hashes.push_back(to_u64_hash(hasher(key))); + } + return PtrHashWithHasher(PtrHash::build_hashes(hashes, params), std::move(hasher)); + } + + static PtrHashWithHasher + deserialize(const void* data, size_t size, Hasher hasher = Hasher()) { + return PtrHashWithHasher(PtrHash::deserialize(data, size), std::move(hasher)); + } + + static PtrHashWithHasher + load(const std::string& path, Hasher hasher = Hasher()) { + return PtrHashWithHasher(PtrHash::load(path), std::move(hasher)); + } + + void + save(const std::string& path) const { + hash_.save(path); + } + + size_t + n() const { + return hash_.n(); + } + size_t + max_index() const { + return hash_.max_index(); + } + size_t + index_no_remap(const Key& key) const { + return hash_.index_no_remap_hash(hash_key(key)); + } + size_t + index(const Key& key) const { + return hash_.index_hash(hash_key(key)); + } + const PtrHash& + raw() const { + return hash_; + } + const PtrHashView& + view() const { + return hash_.view(); + } + const std::vector& + serialize() const { + return hash_.serialize(); + } + + private: + template + static uint64_t + to_u64_hash(Value value) { + using Decayed = typename std::decay::type; + static_assert(std::is_integral::value && sizeof(Decayed) <= sizeof(uint64_t), + "PtrHashWithHasher hasher must return an integral value " + "that fits in uint64_t"); + return static_cast(value); + } + + uint64_t + hash_key(const Key& key) const { + return to_u64_hash(hasher_(key)); + } + + PtrHash hash_; + Hasher hasher_; +}; + +class MappedPtrHash { + public: + MappedPtrHash() = default; + MappedPtrHash(const MappedPtrHash&) = delete; + MappedPtrHash& + operator=(const MappedPtrHash&) = delete; + + MappedPtrHash(MappedPtrHash&& other) noexcept { + move_from(std::move(other)); + } + + MappedPtrHash& + operator=(MappedPtrHash&& other) noexcept { + if (this != &other) { + close(); + move_from(std::move(other)); + } + return *this; + } + + ~MappedPtrHash() { + close(); + } + + static MappedPtrHash + open(const std::string& path, size_t offset = 0) { +#if defined(__unix__) || defined(__APPLE__) + const int fd = ::open(path.c_str(), O_RDONLY); + if (fd < 0) { + throw std::runtime_error("failed to open PtrHash mmap file"); + } + struct stat st; + if (::fstat(fd, &st) != 0) { + ::close(fd); + throw std::runtime_error("failed to stat PtrHash mmap file"); + } + if (st.st_size < 0) { + ::close(fd); + throw std::invalid_argument("PtrHash mmap file size is invalid"); + } + const auto file_size = static_cast(st.st_size); + if (file_size > static_cast(std::numeric_limits::max())) { + ::close(fd); + throw std::overflow_error("PtrHash mmap file is too large"); + } + const size_t length = static_cast(file_size); + if (offset > length) { + ::close(fd); + throw std::invalid_argument("PtrHash mmap offset is past end of file"); + } + if (length == 0) { + ::close(fd); + throw std::invalid_argument("PtrHash mmap file is empty"); + } + void* mapping = ::mmap(nullptr, length, PROT_READ, MAP_PRIVATE, fd, 0); + ::close(fd); + if (mapping == MAP_FAILED) { + throw std::runtime_error("failed to mmap PtrHash file"); + } + MappedPtrHash result; + result.mapping_ = mapping; + result.mapping_size_ = length; + try { + result.view_ = PtrHashView::from_bytes(static_cast(mapping) + offset, length - offset); + } catch (...) { + result.close(); + throw; + } + return result; +#else + (void)path; + (void)offset; + throw std::runtime_error("mmap loading is only available on POSIX platforms"); +#endif + } + + const PtrHashView& + view() const { + return view_; + } + size_t + n() const { + return view_.n(); + } + size_t + max_index() const { + return view_.max_index(); + } + size_t + index_no_remap(uint64_t key) const { + return view_.index_no_remap(key); + } + size_t + index_no_remap(std::string_view key) const { + return view_.index_no_remap(key); + } + size_t + index_no_remap_hash(uint64_t hash) const { + return view_.index_no_remap_hash(hash); + } + size_t + index(uint64_t key) const { + return view_.index(key); + } + size_t + index(std::string_view key) const { + return view_.index(key); + } + size_t + index_hash(uint64_t hash) const { + return view_.index_hash(hash); + } + + private: + void + close() noexcept { +#if defined(__unix__) || defined(__APPLE__) + if (mapping_ != nullptr && mapping_size_ != 0) { + ::munmap(mapping_, mapping_size_); + } +#endif + mapping_ = nullptr; + mapping_size_ = 0; + view_ = PtrHashView(); + } + + void + move_from(MappedPtrHash&& other) noexcept { + mapping_ = other.mapping_; + mapping_size_ = other.mapping_size_; + view_ = other.view_; + other.mapping_ = nullptr; + other.mapping_size_ = 0; + other.view_ = PtrHashView(); + } + + void* mapping_ = nullptr; + size_t mapping_size_ = 0; + PtrHashView view_; +}; + +} // namespace ptrhash + +#endif // PTRHASH_PTRHASH_HPP diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 76dacde..f6c7b64 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -19,6 +19,7 @@ set(ALL_TEST_FILES init_gtest.cpp TracerTest.cpp StreamTest.cpp + PtrHashTest.cpp ) add_executable(all_tests diff --git a/test/PtrHashTest.cpp b/test/PtrHashTest.cpp new file mode 100644 index 0000000..37a96b1 --- /dev/null +++ b/test/PtrHashTest.cpp @@ -0,0 +1,1224 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ptrhash/ptrhash.hpp" + +namespace { + +using ptrhash::BucketFunction; +using ptrhash::MappedPtrHash; +using ptrhash::PtrHash; +using ptrhash::PtrHashParams; +using ptrhash::PtrHashView; +using ptrhash::PtrHashWithHasher; + +std::vector +MakeIntegerKeys(size_t n) { + std::vector keys; + keys.reserve(n); + for (uint64_t i = 0; i < n; ++i) { + keys.push_back(i * 11400714819323198485ull + 0x9e3779b97f4a7c15ull); + } + return keys; +} + +std::vector +MakeMultipleKeys(size_t n, uint64_t multiplier) { + std::vector keys; + keys.reserve(n); + for (uint64_t i = 0; i < n; ++i) { + keys.push_back(i * multiplier); + } + return keys; +} + +std::string +MakeTempPath(const std::string& label) { + static std::atomic counter{0}; + const auto now = std::chrono::steady_clock::now().time_since_epoch().count(); + return "/tmp/milvus_common_ptrhash_" + label + "_" + std::to_string(now) + "_" + + std::to_string(counter.fetch_add(1, std::memory_order_relaxed)); +} + +class TempFile { + public: + explicit TempFile(std::string label) : path_(MakeTempPath(label)) { + } + + ~TempFile() { + std::remove(path_.c_str()); + } + + const std::string& + path() const { + return path_; + } + + private: + std::string path_; +}; + +template +size_t +IndexKey(const Hash& hash, uint64_t key) { + return hash.index(key); +} + +template +size_t +IndexNoRemapKey(const Hash& hash, uint64_t key) { + return hash.index_no_remap(key); +} + +template +size_t +IndexKey(const Hash& hash, const std::string& key) { + return hash.index(std::string_view(key.data(), key.size())); +} + +template +size_t +IndexNoRemapKey(const Hash& hash, const std::string& key) { + return hash.index_no_remap(std::string_view(key.data(), key.size())); +} + +template +size_t +IndexKey(const Hash& hash, std::string_view key) { + return hash.index(key); +} + +template +size_t +IndexNoRemapKey(const Hash& hash, std::string_view key) { + return hash.index_no_remap(key); +} + +template +auto +IndexKey(const Hash& hash, const Key& key) -> decltype(hash.index(key)) { + return hash.index(key); +} + +template +void +ExpectMinimalPerfect(const Hash& hash, const std::vector& keys) { + ASSERT_EQ(hash.n(), keys.size()); + std::vector seen(keys.size(), 0); + for (const auto& key : keys) { + const size_t index = IndexKey(hash, key); + ASSERT_LT(index, keys.size()); + EXPECT_FALSE(seen[index]) << "duplicate minimal index " << index; + seen[index] = 1; + EXPECT_LT(IndexNoRemapKey(hash, key), hash.max_index()); + } + for (size_t i = 0; i < seen.size(); ++i) { + EXPECT_TRUE(seen[i]) << "missing minimal index " << i; + } +} + +template +void +ExpectMinimalPerfectHashes(const Hash& hash, const std::vector& hashes) { + ASSERT_EQ(hash.n(), hashes.size()); + std::vector seen(hashes.size(), 0); + for (uint64_t key_hash : hashes) { + const size_t index = hash.index_hash(key_hash); + ASSERT_LT(index, hashes.size()); + EXPECT_FALSE(seen[index]) << "duplicate minimal index " << index; + seen[index] = 1; + EXPECT_LT(hash.index_no_remap_hash(key_hash), hash.max_index()); + } +} + +template +void +ExpectSameQueryResults(const LhsHash& lhs, const RhsHash& rhs, const std::vector& keys) { + for (const auto& key : keys) { + EXPECT_EQ(IndexKey(lhs, key), IndexKey(rhs, key)); + EXPECT_EQ(IndexNoRemapKey(lhs, key), IndexNoRemapKey(rhs, key)); + } +} + +template +void +ExpectSameQueries(const PtrHash& lhs, const PtrHash& rhs, const std::vector& keys) { + ASSERT_EQ(lhs.serialize(), rhs.serialize()); + ExpectSameQueryResults(lhs, rhs, keys); +} + +template +void +ExpectSameHashQueries(const LhsHash& lhs, const RhsHash& rhs, const std::vector& hashes) { + for (uint64_t key_hash : hashes) { + EXPECT_EQ(lhs.index_hash(key_hash), rhs.index_hash(key_hash)); + EXPECT_EQ(lhs.index_no_remap_hash(key_hash), rhs.index_no_remap_hash(key_hash)); + } +} + +template +std::vector +CollectIndexes(const Hash& hash, const std::vector& keys) { + std::vector indexes; + indexes.reserve(keys.size()); + for (const auto& key : keys) { + indexes.push_back(IndexKey(hash, key)); + } + return indexes; +} + +template +std::vector +CollectHashIndexes(const Hash& hash, const std::vector& hashes) { + std::vector indexes; + indexes.reserve(hashes.size()); + for (uint64_t key_hash : hashes) { + indexes.push_back(hash.index_hash(key_hash)); + } + return indexes; +} + +uint64_t +StableFingerprint(const std::vector& bytes) { + uint64_t hash = 1469598103934665603ull; + for (uint8_t byte : bytes) { + hash ^= byte; + hash *= 1099511628211ull; + } + return hash; +} + +template +struct StaticMapTables { + std::vector keys_by_index; + std::vector values_by_index; +}; + +template +StaticMapTables +MakeStaticMapTables(const Hash& hash, const std::vector& keys, const std::vector& values) { + StaticMapTables tables; + tables.keys_by_index.resize(hash.n()); + tables.values_by_index.resize(hash.n()); + for (size_t i = 0; i < keys.size(); ++i) { + const size_t index = IndexKey(hash, keys[i]); + tables.keys_by_index[index] = keys[i]; + tables.values_by_index[index] = values[i]; + } + return tables; +} + +template +std::optional +StaticMapFind(const Hash& hash, const StaticMapTables& tables, const Key& query) { + if (hash.n() == 0) { + return std::nullopt; + } + const size_t index = IndexKey(hash, query); + if (!(tables.keys_by_index[index] == query)) { + return std::nullopt; + } + return tables.values_by_index[index]; +} + +template > +void +ExpectStaticMapMatchesUnorderedMap(const PtrHashType& hash, const std::vector& keys, + const std::vector& values, const std::vector& queries, + UnorderedHasher unordered_hasher = UnorderedHasher()) { + ASSERT_EQ(keys.size(), values.size()); + + std::unordered_map expected(0, unordered_hasher); + for (size_t i = 0; i < keys.size(); ++i) { + expected.emplace(keys[i], values[i]); + } + const auto tables = MakeStaticMapTables(hash, keys, values); + + for (const auto& query : queries) { + const auto actual = StaticMapFind(hash, tables, query); + const auto expected_it = expected.find(query); + if (expected_it == expected.end()) { + EXPECT_FALSE(actual.has_value()); + } else { + ASSERT_TRUE(actual.has_value()); + EXPECT_EQ(*actual, expected_it->second); + } + } +} + +uint32_t +ReadU32(const std::vector& bytes, size_t offset) { + return static_cast(bytes[offset]) | (static_cast(bytes[offset + 1]) << 8) | + (static_cast(bytes[offset + 2]) << 16) | (static_cast(bytes[offset + 3]) << 24); +} + +uint64_t +ReadU64(const std::vector& bytes, size_t offset) { + uint64_t value = 0; + for (size_t i = 0; i < 8; ++i) { + value |= static_cast(bytes[offset + i]) << (8 * i); + } + return value; +} + +void +WriteU32(std::vector& bytes, size_t offset, uint32_t value) { + for (size_t i = 0; i < 4; ++i) { + bytes[offset + i] = static_cast(value >> (8 * i)); + } +} + +void +WriteU64(std::vector& bytes, size_t offset, uint64_t value) { + for (size_t i = 0; i < 8; ++i) { + bytes[offset + i] = static_cast(value >> (8 * i)); + } +} + +PtrHash +Deserialize(const std::vector& bytes) { + return PtrHash::deserialize(bytes.data(), bytes.size()); +} + +size_t +CountMappingsForPath(const std::string& path) { +#if defined(__linux__) + std::ifstream maps("/proc/self/maps"); + size_t count = 0; + std::string line; + while (std::getline(maps, line)) { + if (line.find(path) != std::string::npos) { + ++count; + } + } + return count; +#else + (void)path; + return 0; +#endif +} + +struct IdentityHasher { + uint64_t + operator()(uint64_t key) const { + return key; + } +}; + +struct CustomKey { + uint64_t tenant = 0; + std::string name; + + bool + operator==(const CustomKey& other) const { + return tenant == other.tenant && name == other.name; + } +}; + +struct CustomKeyHasher { + uint64_t + operator()(const CustomKey& key) const { + uint64_t hash = key.tenant * 0x9e3779b97f4a7c15ull; + for (unsigned char byte : key.name) { + hash ^= static_cast(byte); + hash *= 1099511628211ull; + } + return hash; + } +}; + +struct ConstantCustomKeyHasher { + uint64_t + operator()(const CustomKey&) const { + return 7; + } +}; + +std::vector +MakeRandomIntegerKeys(size_t n, uint64_t seed) { + std::mt19937_64 rng(seed); + std::unordered_set seen; + std::vector keys; + keys.reserve(n); + while (keys.size() < n) { + const uint64_t key = rng(); + if (seen.insert(key).second) { + keys.push_back(key); + } + } + return keys; +} + +std::string +MakeRandomString(std::mt19937_64& rng) { + const size_t length = static_cast(rng() % 32); + std::string value; + value.reserve(length); + for (size_t i = 0; i < length; ++i) { + const char c = static_cast('a' + (rng() % 26)); + value.push_back(c); + } + return value; +} + +std::vector +MakeRandomStringKeys(size_t n, uint64_t seed) { + std::mt19937_64 rng(seed); + std::unordered_set seen; + std::vector keys; + keys.reserve(n); + while (keys.size() < n) { + std::string key = MakeRandomString(rng); + key += "#" + std::to_string(rng()); + if (seen.insert(key).second) { + keys.push_back(std::move(key)); + } + } + return keys; +} + +constexpr size_t kSerializedHeaderSize = 88; +constexpr size_t kSerializedNOffset = 16; +constexpr size_t kSerializedSlotsTotalOffset = 24; +constexpr size_t kSerializedBucketsTotalOffset = 32; +constexpr size_t kSerializedPilotCountOffset = 48; +constexpr size_t kSerializedRemapCountOffset = 56; +constexpr size_t kSerializedPartsOffset = 64; +constexpr size_t kSerializedSlotsPerPartOffset = 72; +constexpr size_t kSerializedBucketsPerPartOffset = 80; + +} // namespace + +TEST(PtrHashTest, BuildIntegerKeysReturnsMinimalPermutation) { + std::vector keys = {0, 1, 42, 999, std::numeric_limits::max(), 1ull << 63, (1ull << 63) + 17}; + auto hash = PtrHash::build(keys); + ExpectMinimalPerfect(hash, keys); +} + +TEST(PtrHashTest, BuildRandomSizedIntegerSetsReturnsMinimalPermutation) { + for (size_t n : + {size_t{0}, size_t{1}, size_t{2}, size_t{3}, size_t{4}, size_t{5}, size_t{6}, size_t{7}, size_t{8}, size_t{9}, + size_t{10}, size_t{30}, size_t{100}, size_t{300}, size_t{1000}, size_t{3000}, size_t{10000}, size_t{30000}}) { + auto keys = MakeIntegerKeys(n); + auto hash = PtrHash::build(keys); + ExpectMinimalPerfect(hash, keys); + } +} + +TEST(PtrHashTest, BuildMultipleIntegerSetsReturnsMinimalPermutation) { + for (uint64_t multiplier : {uint64_t{1}, uint64_t{1} << 40, uint64_t{1000000000000}, uint64_t{94143178827}}) { + for (size_t n : + {size_t{0}, size_t{1}, size_t{2}, size_t{3}, size_t{4}, size_t{5}, size_t{6}, size_t{7}, size_t{8}, + size_t{9}, size_t{10}, size_t{30}, size_t{100}, size_t{300}, size_t{1000}, size_t{3000}, size_t{10000}}) { + auto keys = MakeMultipleKeys(n, multiplier); + auto hash = PtrHash::build(keys); + ExpectMinimalPerfect(hash, keys); + } + } +} + +TEST(PtrHashTest, IntegerKeyTypesCompileAndQuery) { + auto check = [](const auto& keys) { + auto hash = PtrHash::build(keys); + ExpectMinimalPerfect(hash, keys); + }; + + check(std::vector{0, 1, std::numeric_limits::max()}); + check(std::vector{0, 1, 257, std::numeric_limits::max()}); + check(std::vector{0, 1, 65537, std::numeric_limits::max()}); + check(std::vector{0, 1, 1ull << 40, std::numeric_limits::max()}); + check(std::vector{0, 1, 4097, static_cast(std::numeric_limits::max())}); + check(std::vector{std::numeric_limits::min(), -1, 0, 1, std::numeric_limits::max()}); + check(std::vector{std::numeric_limits::min(), -1, 0, 1, std::numeric_limits::max()}); + check(std::vector{std::numeric_limits::min(), -1, 0, 1, std::numeric_limits::max()}); + check(std::vector{std::numeric_limits::min(), -1, 0, 1, std::numeric_limits::max()}); +} + +TEST(PtrHashTest, IndexSumMatchesMinimalPermutation) { + for (size_t n : {size_t{2}, size_t{10}, size_t{100}, size_t{1000}, size_t{10000}}) { + auto keys = MakeIntegerKeys(n); + auto hash = PtrHash::build(keys); + size_t sum = 0; + for (uint64_t key : keys) { + sum += hash.index(key); + } + EXPECT_EQ(sum, n * (n - 1) / 2); + } +} + +TEST(PtrHashTest, BuildUint32KeysReturnsMinimalPermutation) { + std::vector keys = {0, 1, 42, 65535, 65536, 1234567890u, std::numeric_limits::max()}; + auto hash = PtrHash::build(keys); + ExpectMinimalPerfect(hash, keys); +} + +TEST(PtrHashTest, BuildInt32KeysReturnsMinimalPermutation) { + std::vector keys = { + 0, 1, -1, 42, -42, std::numeric_limits::min(), std::numeric_limits::max()}; + auto hash = PtrHash::build(keys); + ExpectMinimalPerfect(hash, keys); +} + +TEST(PtrHashTest, BuildStringKeysReturnsMinimalPermutation) { + std::vector keys = {"alpha", "beta", "gamma", "delta", "epsilon", "zeta"}; + auto hash = PtrHash::build(keys); + ExpectMinimalPerfect(hash, keys); +} + +TEST(PtrHashTest, StringKeyCornerCases) { + std::vector keys; + keys.emplace_back(""); + keys.emplace_back("a"); + keys.emplace_back("abc"); + keys.emplace_back("prefix"); + keys.emplace_back("prefix_suffix"); + keys.emplace_back("a\0", 2); + keys.emplace_back("a\0b", 3); + keys.emplace_back("\0leading", 8); + keys.emplace_back(4096, 'x'); + keys.back()[2048] = '\0'; + + auto hash = PtrHash::build(keys); + ExpectMinimalPerfect(hash, keys); +} + +TEST(PtrHashTest, StringViewKeyCornerCases) { + std::vector backing; + backing.emplace_back(""); + backing.emplace_back("short"); + backing.emplace_back("a\0", 2); + backing.emplace_back("a\0b", 3); + backing.emplace_back(1024, 'q'); + backing.back()[511] = '\0'; + + std::vector keys; + keys.reserve(backing.size()); + for (const auto& key : backing) { + keys.emplace_back(key.data(), key.size()); + } + + auto hash = PtrHash::build(keys); + ExpectMinimalPerfect(hash, keys); +} + +TEST(PtrHashTest, BuildPrehashedKeysReturnsMinimalPermutation) { + auto hashes = MakeIntegerKeys(32); + auto hash = PtrHash::build_hashes(hashes); + ExpectMinimalPerfectHashes(hash, hashes); +} + +TEST(PtrHashTest, PtrHashWithHasherUsesExternalHasher) { + auto keys = MakeIntegerKeys(64); + auto hash = PtrHashWithHasher::build(keys, IdentityHasher{}); + ExpectMinimalPerfect(hash, keys); +} + +TEST(PtrHashTest, PtrHashWithHasherSupportsCustomKeyStaticMapPattern) { + std::vector keys = {{1, "alpha"}, {1, "beta"}, {2, "alpha"}, {3, "tenant-three"}, {99, "omega"}}; + std::vector values = {"a1", "b1", "a2", "t3", "o99"}; + std::vector queries = keys; + queries.push_back({1, "missing"}); + queries.push_back({4, "alpha"}); + queries.push_back({99, "omega-x"}); + + auto hash = PtrHashWithHasher::build(keys, CustomKeyHasher{}); + ExpectStaticMapMatchesUnorderedMap(hash, keys, values, queries, CustomKeyHasher{}); +} + +TEST(PtrHashTest, PtrHashWithHasherRejectsExternalHashCollisions) { + std::vector keys = {{1, "alpha"}, {2, "beta"}}; + + std::unordered_map unordered; + unordered.emplace(keys[0], 10); + unordered.emplace(keys[1], 20); + ASSERT_EQ(unordered.size(), 2); + + EXPECT_THROW((PtrHashWithHasher::build(keys, ConstantCustomKeyHasher{})), + std::invalid_argument); +} + +TEST(PtrHashTest, EmptyHashHasZeroSizeAndQueryThrows) { + auto hash = PtrHash::build(std::vector{}); + EXPECT_EQ(hash.n(), 0); + EXPECT_EQ(hash.max_index(), 0); + EXPECT_THROW(hash.index(1), std::out_of_range); + EXPECT_THROW(hash.index_no_remap(1), std::out_of_range); +} + +TEST(PtrHashTest, SingleKeyMapsToZero) { + auto hash = PtrHash::build(std::vector{42}); + EXPECT_EQ(hash.n(), 1); + EXPECT_EQ(hash.index(42), 0); + EXPECT_LT(hash.index_no_remap(42), hash.max_index()); + EXPECT_GT(hash.max_index(), hash.n()); +} + +TEST(PtrHashTest, SupportsBucketFunctions) { + auto keys = MakeIntegerKeys(100); + for (BucketFunction bucket_function : + {BucketFunction::Linear, BucketFunction::SquareEps, BucketFunction::CubicEps}) { + PtrHashParams params; + params.bucket_function = bucket_function; + auto hash = PtrHash::build(keys, params); + ExpectMinimalPerfect(hash, keys); + } +} + +TEST(PtrHashTest, SupportsBuildThreadSettings) { + auto keys = MakeIntegerKeys(120000); + for (size_t build_threads : {size_t{1}, size_t{2}}) { + PtrHashParams params; + params.build_threads = build_threads; + auto hash = PtrHash::build(keys, params); + ExpectMinimalPerfect(hash, keys); + } +} + +TEST(PtrHashTest, RejectsInvalidParams) { + const std::vector keys = {1, 2, 3}; + + PtrHashParams params; + params.alpha = 0.0; + EXPECT_THROW(PtrHash::build(keys, params), std::invalid_argument); + + params = PtrHashParams(); + params.alpha = 1.01; + EXPECT_THROW(PtrHash::build(keys, params), std::invalid_argument); + + params = PtrHashParams(); + params.alpha = 1e-300; + EXPECT_THROW(PtrHash::build(keys, params), std::overflow_error); + + params = PtrHashParams(); + params.lambda = 0.0; + EXPECT_THROW(PtrHash::build(keys, params), std::invalid_argument); + + params = PtrHashParams(); + params.lambda = std::numeric_limits::infinity(); + EXPECT_THROW(PtrHash::build(keys, params), std::invalid_argument); + + params = PtrHashParams(); + params.max_pilot = 256; + EXPECT_THROW(PtrHash::build(keys, params), std::invalid_argument); +} + +TEST(PtrHashTest, RejectsInvalidParamsForEmptyInput) { + const std::vector keys; + + PtrHashParams params; + params.alpha = 0.0; + EXPECT_THROW(PtrHash::build(keys, params), std::invalid_argument); + + params = PtrHashParams(); + params.lambda = 0.0; + EXPECT_THROW(PtrHash::build(keys, params), std::invalid_argument); + + params = PtrHashParams(); + params.max_pilot = 256; + EXPECT_THROW(PtrHash::build(keys, params), std::invalid_argument); +} + +TEST(PtrHashTest, RejectsDuplicateKeys) { + EXPECT_THROW(PtrHash::build(std::vector{1, 2, 1}), std::invalid_argument); + EXPECT_THROW(PtrHash::build(std::vector{"", ""}), std::invalid_argument); + + std::vector backing; + backing.emplace_back("a\0b", 3); + backing.emplace_back("a\0b", 3); + std::vector views; + for (const auto& key : backing) { + views.emplace_back(key.data(), key.size()); + } + EXPECT_THROW(PtrHash::build(views), std::invalid_argument); +} + +TEST(PtrHashTest, RejectsDuplicatePrehashedValues) { + EXPECT_THROW(PtrHash::build_hashes(std::vector{11, 17, 11}), std::invalid_argument); +} + +TEST(PtrHashTest, SerializedHeaderUsesPtrHashMagicAndVersionOne) { + auto hash = PtrHash::build(std::vector{1, 2, 3}); + const auto& bytes = hash.serialize(); + ASSERT_GE(bytes.size(), 12); + + const std::array expected_magic = {'P', 'T', 'R', 'H', 'A', 'S', 'H', '\0'}; + EXPECT_TRUE(std::equal(expected_magic.begin(), expected_magic.end(), bytes.begin())); + EXPECT_EQ(ReadU32(bytes, 8), 1u); +} + +TEST(PtrHashTest, GoldenMappingsStayStableForRepresentativeKeyKinds) { + PtrHashParams params; + params.seed = 0x123456789abcdef0ull; + params.build_threads = 1; + + params.bucket_function = BucketFunction::SquareEps; + std::vector integer_keys = { + 0, 1, 42, 999, std::numeric_limits::max(), 1ull << 63, (1ull << 63) + 17}; + auto integer_hash = PtrHash::build(integer_keys, params); + ExpectMinimalPerfect(integer_hash, integer_keys); + + params.bucket_function = BucketFunction::CubicEps; + std::vector string_keys = { + "", "alpha", "beta", "gamma", "delta", std::string("a\0b", 3), "longer-string-value"}; + auto string_hash = PtrHash::build(string_keys, params); + ExpectMinimalPerfect(string_hash, string_keys); + + params.bucket_function = BucketFunction::Linear; + std::vector prehashed_keys = {11, 0x100000001ull, 0xabcdef1234567890ull, + 99, 123456789, 0xfedcba9876543210ull}; + auto prehashed_hash = PtrHash::build_hashes(prehashed_keys, params); + ExpectMinimalPerfectHashes(prehashed_hash, prehashed_keys); + +#if defined(__SIZEOF_INT128__) + EXPECT_EQ(CollectIndexes(integer_hash, integer_keys), (std::vector{3, 5, 2, 0, 6, 1, 4})); + EXPECT_EQ(StableFingerprint(integer_hash.serialize()), 0xbff1bd6e24232c11ull); + + EXPECT_EQ(CollectIndexes(string_hash, string_keys), (std::vector{5, 2, 3, 4, 6, 1, 0})); + EXPECT_EQ(StableFingerprint(string_hash.serialize()), 0x3cbc60b3f568fd3bull); + + EXPECT_EQ(CollectHashIndexes(prehashed_hash, prehashed_keys), (std::vector{1, 3, 2, 4, 5, 0})); + EXPECT_EQ(StableFingerprint(prehashed_hash.serialize()), 0xd47690f139b34e71ull); +#endif +} + +TEST(PtrHashTest, SerializedBufferUsesExactCapacity) { + auto hash = PtrHash::build(MakeIntegerKeys(1000)); + const auto& bytes = hash.serialize(); + EXPECT_EQ(bytes.size(), bytes.capacity()); +} + +TEST(PtrHashTest, DeserializeRoundTripPreservesQueries) { + auto keys = MakeIntegerKeys(128); + auto hash = PtrHash::build(keys); + auto roundtrip = Deserialize(hash.serialize()); + ExpectSameQueries(hash, roundtrip, keys); +} + +TEST(PtrHashTest, SaveLoadRoundTripPreservesQueries) { + auto keys = MakeIntegerKeys(128); + auto hash = PtrHash::build(keys); + TempFile file("save_load"); + hash.save(file.path()); + + auto loaded = PtrHash::load(file.path()); + ExpectSameQueries(hash, loaded, keys); +} + +TEST(PtrHashTest, SaveLoadRoundTripPreservesStringAndPrehashedQueries) { + std::vector string_keys = {"", "alpha", "beta", std::string("a\0b", 3), "gamma"}; + auto string_hash = PtrHash::build(string_keys); + TempFile string_file("save_load_string"); + string_hash.save(string_file.path()); + + auto loaded_string = PtrHash::load(string_file.path()); + ExpectSameQueries(string_hash, loaded_string, string_keys); + EXPECT_THROW(loaded_string.index(uint64_t{1}), std::invalid_argument); + EXPECT_THROW(loaded_string.index_hash(1), std::invalid_argument); + + std::vector backing = {"view-alpha", "view-beta", std::string("view\0gamma", 10)}; + std::vector string_views; + string_views.reserve(backing.size()); + for (const auto& key : backing) { + string_views.emplace_back(key.data(), key.size()); + } + auto string_view_hash = PtrHash::build(string_views); + TempFile string_view_file("save_load_string_view"); + string_view_hash.save(string_view_file.path()); + + auto loaded_string_view = PtrHash::load(string_view_file.path()); + ExpectSameQueries(string_view_hash, loaded_string_view, string_views); + EXPECT_THROW(loaded_string_view.index(uint64_t{1}), std::invalid_argument); + EXPECT_THROW(loaded_string_view.index_hash(1), std::invalid_argument); + + auto prehashed_keys = MakeIntegerKeys(32); + auto prehashed_hash = PtrHash::build_hashes(prehashed_keys); + TempFile prehashed_file("save_load_prehashed"); + prehashed_hash.save(prehashed_file.path()); + + auto loaded_prehashed = PtrHash::load(prehashed_file.path()); + ASSERT_EQ(prehashed_hash.serialize(), loaded_prehashed.serialize()); + ExpectSameHashQueries(prehashed_hash, loaded_prehashed, prehashed_keys); + EXPECT_THROW(loaded_prehashed.index(uint64_t{prehashed_keys.front()}), std::invalid_argument); + EXPECT_THROW(loaded_prehashed.index(std::string_view("not-a-prehash")), std::invalid_argument); +} + +TEST(PtrHashTest, DeserializeRejectsCorruptInput) { + auto hash = PtrHash::build(std::vector{1, 2, 3, 4}); + auto bytes = hash.serialize(); + ASSERT_GE(bytes.size(), 88); + + auto bad = bytes; + bad[0] = 'X'; + EXPECT_THROW(Deserialize(bad), std::invalid_argument); + + bad = bytes; + WriteU32(bad, 8, 2); + EXPECT_THROW(Deserialize(bad), std::invalid_argument); + + bad = bytes; + WriteU32(bad, 12, 0); + EXPECT_THROW(Deserialize(bad), std::invalid_argument); + + bad = bytes; + WriteU32(bad, 12, 4 | (99u << 8)); + EXPECT_THROW(Deserialize(bad), std::invalid_argument); + + bad = bytes; + WriteU32(bad, 12, 4 | (99u << 16)); + EXPECT_THROW(Deserialize(bad), std::invalid_argument); + + bad = bytes; + WriteU64(bad, 48, 0); + EXPECT_THROW(Deserialize(bad), std::invalid_argument); + + bad.assign(bytes.begin(), bytes.begin() + 10); + EXPECT_THROW(Deserialize(bad), std::invalid_argument); + + bad = bytes; + bad.pop_back(); + EXPECT_THROW(Deserialize(bad), std::invalid_argument); +} + +TEST(PtrHashTest, DeserializeRejectsNullNonEmptyInput) { + EXPECT_THROW(PtrHash::deserialize(nullptr, 1), std::invalid_argument); + EXPECT_THROW(PtrHashView::from_bytes(nullptr, 1), std::invalid_argument); +} + +TEST(PtrHashTest, DeserializeRejectsImpossibleEmptyLayout) { + auto hash = PtrHash::build(std::vector{}); + auto bytes = hash.serialize(); + ASSERT_EQ(bytes.size(), kSerializedHeaderSize); + + WriteU64(bytes, kSerializedPartsOffset, 1); + EXPECT_THROW(Deserialize(bytes), std::invalid_argument); +} + +TEST(PtrHashTest, DeserializeRejectsNonEmptyLayoutWithZeroBuckets) { + auto hash = PtrHash::build(std::vector{}); + auto bytes = hash.serialize(); + ASSERT_EQ(bytes.size(), kSerializedHeaderSize); + + WriteU64(bytes, kSerializedNOffset, 1); + WriteU64(bytes, kSerializedSlotsTotalOffset, 1); + WriteU64(bytes, kSerializedBucketsTotalOffset, 0); + WriteU64(bytes, kSerializedPilotCountOffset, 0); + WriteU64(bytes, kSerializedRemapCountOffset, 0); + WriteU64(bytes, kSerializedPartsOffset, 1); + WriteU64(bytes, kSerializedSlotsPerPartOffset, 1); + WriteU64(bytes, kSerializedBucketsPerPartOffset, 0); + EXPECT_THROW(Deserialize(bytes), std::invalid_argument); +} + +TEST(PtrHashTest, DeserializeRejectsSerializedSizeOverflow) { + auto hash = PtrHash::build(std::vector{}); + auto bytes = hash.serialize(); + ASSERT_EQ(bytes.size(), kSerializedHeaderSize); + + const uint64_t huge = static_cast(std::numeric_limits::max()); + WriteU64(bytes, kSerializedNOffset, 1); + WriteU64(bytes, kSerializedSlotsTotalOffset, 1); + WriteU64(bytes, kSerializedBucketsTotalOffset, huge); + WriteU64(bytes, kSerializedPilotCountOffset, huge); + WriteU64(bytes, kSerializedRemapCountOffset, 0); + WriteU64(bytes, kSerializedPartsOffset, 1); + WriteU64(bytes, kSerializedSlotsPerPartOffset, 1); + WriteU64(bytes, kSerializedBucketsPerPartOffset, huge); + EXPECT_THROW(Deserialize(bytes), std::overflow_error); +} + +TEST(PtrHashTest, MutatedSerializedBytesDoNotCrashOrReturnOutOfRange) { + auto hash = PtrHash::build(MakeIntegerKeys(128)); + const auto& bytes = hash.serialize(); + + for (size_t offset = 0; offset < bytes.size(); ++offset) { + auto mutated = bytes; + mutated[offset] ^= 0x5a; + try { + auto loaded = Deserialize(mutated); + ASSERT_GT(loaded.n(), 0); + EXPECT_LT(loaded.index(uint64_t{42}), loaded.n()); + EXPECT_LT(loaded.index_no_remap(uint64_t{42}), loaded.max_index()); + } catch (const std::invalid_argument&) { + } catch (const std::overflow_error&) { + } + } +} + +TEST(PtrHashTest, DeserializeRejectsOutOfRangeRemapEntries) { + auto hash = PtrHash::build(MakeIntegerKeys(128)); + auto bytes = hash.serialize(); + const uint64_t n = ReadU64(bytes, kSerializedNOffset); + const uint64_t pilot_count = ReadU64(bytes, kSerializedPilotCountOffset); + const uint64_t remap_count = ReadU64(bytes, kSerializedRemapCountOffset); + ASSERT_GT(remap_count, 0); + ASSERT_LE(n, std::numeric_limits::max()); + ASSERT_GE(bytes.size(), kSerializedHeaderSize + pilot_count + sizeof(uint32_t)); + + WriteU32(bytes, kSerializedHeaderSize + static_cast(pilot_count), static_cast(n)); + EXPECT_THROW(Deserialize(bytes), std::invalid_argument); +} + +TEST(PtrHashTest, RepeatedBuildsStayMinimalForSkewedDistributions) { + std::vector> datasets; + datasets.emplace_back(MakeIntegerKeys(4096)); + + std::vector high_bits; + high_bits.reserve(4096); + for (uint64_t i = 0; i < 4096; ++i) { + high_bits.push_back((i << 32) | ((i * 17) & 0xffffu)); + } + datasets.emplace_back(std::move(high_bits)); + + std::vector descending; + descending.reserve(4096); + for (uint64_t i = 0; i < 4096; ++i) { + descending.push_back(std::numeric_limits::max() - i * 8191); + } + datasets.emplace_back(std::move(descending)); + + for (size_t dataset = 0; dataset < datasets.size(); ++dataset) { + for (uint64_t seed : {uint64_t{0}, uint64_t{1}, uint64_t{0x9e3779b97f4a7c15ull}}) { + for (BucketFunction bucket_function : + {BucketFunction::Linear, BucketFunction::SquareEps, BucketFunction::CubicEps}) { + SCOPED_TRACE("dataset=" + std::to_string(dataset) + " seed=" + std::to_string(seed)); + PtrHashParams params; + params.seed = seed; + params.bucket_function = bucket_function; + params.build_threads = 1; + auto hash = PtrHash::build(datasets[dataset], params); + ExpectMinimalPerfect(hash, datasets[dataset]); + } + } + } +} + +TEST(PtrHashTest, RandomizedStaticMapDifferentialMatchesUnorderedMap) { + const std::vector sizes = {0, 1, 7, 64, 257}; + const std::vector bucket_functions = {BucketFunction::Linear, BucketFunction::SquareEps, + BucketFunction::CubicEps}; + const std::vector thread_counts = {1, 2, 4}; + + for (size_t n : sizes) { + for (BucketFunction bucket_function : bucket_functions) { + for (size_t threads : thread_counts) { + SCOPED_TRACE("integer n=" + std::to_string(n) + " bucket=" + + std::to_string(static_cast(bucket_function)) + " threads=" + std::to_string(threads)); + PtrHashParams params; + params.seed = 0x72616e646f6d0000ull + n + threads; + params.bucket_function = bucket_function; + params.build_threads = threads; + + auto keys = MakeRandomIntegerKeys(n, params.seed); + std::vector values; + values.reserve(keys.size()); + for (size_t i = 0; i < keys.size(); ++i) { + values.push_back(keys[i] ^ (0x9e3779b97f4a7c15ull + i)); + } + + std::vector queries = keys; + std::unordered_set key_set(keys.begin(), keys.end()); + for (uint64_t candidate : MakeRandomIntegerKeys(128, params.seed ^ 0x5eed5eedull)) { + if (key_set.find(candidate) == key_set.end()) { + queries.push_back(candidate); + } + } + + auto hash = PtrHash::build(keys, params); + ExpectStaticMapMatchesUnorderedMap(hash, keys, values, queries); + } + } + } + + const std::vector string_sizes = {0, 1, 9, 96}; + for (size_t n : string_sizes) { + for (BucketFunction bucket_function : bucket_functions) { + for (size_t threads : {size_t{1}, size_t{2}}) { + SCOPED_TRACE("string n=" + std::to_string(n) + " bucket=" + + std::to_string(static_cast(bucket_function)) + " threads=" + std::to_string(threads)); + PtrHashParams params; + params.seed = 0x737472696e670000ull + n + threads; + params.bucket_function = bucket_function; + params.build_threads = threads; + + auto keys = MakeRandomStringKeys(n, params.seed); + std::vector values; + values.reserve(keys.size()); + for (size_t i = 0; i < keys.size(); ++i) { + values.push_back((i + 1) * 17); + } + + std::vector queries = keys; + std::unordered_set key_set(keys.begin(), keys.end()); + for (const auto& candidate : MakeRandomStringKeys(128, params.seed ^ 0xabcddcbaull)) { + if (key_set.find(candidate) == key_set.end()) { + queries.push_back(candidate); + } + } + + auto hash = PtrHash::build(keys, params); + ExpectStaticMapMatchesUnorderedMap(hash, keys, values, queries); + } + } + } +} + +TEST(PtrHashTest, ConcurrentQueriesPreserveResults) { + auto keys = MakeIntegerKeys(10000); + auto hash = PtrHash::build(keys); + + std::vector expected; + expected.reserve(keys.size()); + for (uint64_t key : keys) { + expected.push_back(hash.index(key)); + } + + std::atomic ok{true}; + std::vector threads; + for (size_t thread = 0; thread < 4; ++thread) { + threads.emplace_back([&, thread] { + for (size_t round = 0; round < 20; ++round) { + for (size_t i = thread; i < keys.size(); i += 4) { + if (hash.index(keys[i]) != expected[i]) { + ok.store(false, std::memory_order_relaxed); + return; + } + } + } + }); + } + for (auto& thread : threads) { + thread.join(); + } + EXPECT_TRUE(ok.load(std::memory_order_relaxed)); +} + +TEST(PtrHashTest, NonMemberQueriesStayInRange) { + auto integer_hash = PtrHash::build(std::vector{10, 20, 30, 40}); + for (uint64_t key : {uint64_t{0}, uint64_t{1}, uint64_t{999999}, std::numeric_limits::max()}) { + EXPECT_LT(integer_hash.index(key), integer_hash.n()); + EXPECT_LT(integer_hash.index_no_remap(key), integer_hash.max_index()); + } + + auto string_hash = PtrHash::build(std::vector{"alpha", "beta", "gamma"}); + for (std::string_view key : {std::string_view(""), std::string_view("delta"), std::string_view("alpha\0x", 7)}) { + EXPECT_LT(string_hash.index(key), string_hash.n()); + EXPECT_LT(string_hash.index_no_remap(key), string_hash.max_index()); + } + + auto prehashed = PtrHash::build_hashes(std::vector{100, 200, 300}); + for (uint64_t key_hash : {uint64_t{0}, uint64_t{100}, uint64_t{999}}) { + EXPECT_LT(prehashed.index_hash(key_hash), prehashed.n()); + EXPECT_LT(prehashed.index_no_remap_hash(key_hash), prehashed.max_index()); + } +} + +TEST(PtrHashTest, StaticMapPatternChecksKeysBeforeReturningValues) { + std::vector integer_keys = {10, 20, 30, 40, 50}; + std::vector integer_values = {"ten", "twenty", "thirty", "forty", "fifty"}; + std::vector integer_queries = integer_keys; + integer_queries.insert(integer_queries.end(), {0, 1, 21, 9999, std::numeric_limits::max()}); + + auto integer_hash = PtrHash::build(integer_keys); + ExpectStaticMapMatchesUnorderedMap(integer_hash, integer_keys, integer_values, integer_queries); + + std::vector string_keys = {"", "alpha", "beta", "gamma", std::string("a\0b", 3), "prefix"}; + std::vector string_values = {0, 1, 2, 3, 4, 5}; + std::vector string_queries = string_keys; + string_queries.push_back("delta"); + string_queries.push_back(std::string("a\0c", 3)); + string_queries.push_back("prefix_suffix"); + + auto string_hash = PtrHash::build(string_keys); + ExpectStaticMapMatchesUnorderedMap(string_hash, string_keys, string_values, string_queries); +} + +TEST(PtrHashTest, MappedPtrHashOpenPreservesQueries) { + auto keys = MakeIntegerKeys(128); + auto hash = PtrHash::build(keys); + TempFile file("mapped"); + hash.save(file.path()); + +#if defined(__unix__) || defined(__APPLE__) + auto mapped = MappedPtrHash::open(file.path()); + ExpectMinimalPerfect(mapped, keys); +#else + EXPECT_THROW(MappedPtrHash::open(file.path()), std::runtime_error); +#endif +} + +TEST(PtrHashTest, MappedPtrHashOpenPreservesStringAndPrehashedQueries) { +#if defined(__unix__) || defined(__APPLE__) + std::vector string_keys = {"mapped-alpha", "mapped-beta", std::string("mapped\0gamma", 12), ""}; + auto string_hash = PtrHash::build(string_keys); + TempFile string_file("mapped_string"); + string_hash.save(string_file.path()); + + auto mapped_string = MappedPtrHash::open(string_file.path()); + ExpectSameQueryResults(string_hash, mapped_string, string_keys); + EXPECT_THROW(mapped_string.index(uint64_t{1}), std::invalid_argument); + EXPECT_THROW(mapped_string.index_hash(1), std::invalid_argument); + + std::vector backing = {"mapped-view-alpha", "mapped-view-beta", std::string("mapped-view\0x", 13)}; + std::vector string_views; + string_views.reserve(backing.size()); + for (const auto& key : backing) { + string_views.emplace_back(key.data(), key.size()); + } + auto string_view_hash = PtrHash::build(string_views); + TempFile string_view_file("mapped_string_view"); + string_view_hash.save(string_view_file.path()); + + auto mapped_string_view = MappedPtrHash::open(string_view_file.path()); + ExpectSameQueryResults(string_view_hash, mapped_string_view, string_views); + EXPECT_THROW(mapped_string_view.index(uint64_t{1}), std::invalid_argument); + EXPECT_THROW(mapped_string_view.index_hash(1), std::invalid_argument); + + auto prehashed_keys = MakeIntegerKeys(32); + auto prehashed_hash = PtrHash::build_hashes(prehashed_keys); + TempFile prehashed_file("mapped_prehashed"); + prehashed_hash.save(prehashed_file.path()); + + auto mapped_prehashed = MappedPtrHash::open(prehashed_file.path()); + ExpectSameHashQueries(prehashed_hash, mapped_prehashed, prehashed_keys); + EXPECT_THROW(mapped_prehashed.index(uint64_t{prehashed_keys.front()}), std::invalid_argument); + EXPECT_THROW(mapped_prehashed.index(std::string_view("not-a-prehash")), std::invalid_argument); +#else + GTEST_SKIP() << "mmap loading is only available on POSIX platforms"; +#endif +} + +TEST(PtrHashTest, MappedPtrHashOpenWithOffsetAndRejectsBadOffset) { + auto keys = MakeIntegerKeys(128); + auto hash = PtrHash::build(keys); + const auto& bytes = hash.serialize(); + const std::array prefix = {'p', 'r', 'e', 'f', 'i', 'x', ':'}; + TempFile file("mapped_offset"); + + { + std::ofstream out(file.path(), std::ios::binary); + ASSERT_TRUE(out); + out.write(prefix.data(), static_cast(prefix.size())); + out.write(reinterpret_cast(bytes.data()), static_cast(bytes.size())); + ASSERT_TRUE(out.good()); + } + +#if defined(__unix__) || defined(__APPLE__) + auto mapped = MappedPtrHash::open(file.path(), prefix.size()); + ExpectMinimalPerfect(mapped, keys); + EXPECT_THROW(MappedPtrHash::open(file.path(), prefix.size() + bytes.size() + 1), std::invalid_argument); +#else + EXPECT_THROW(MappedPtrHash::open(file.path(), prefix.size()), std::runtime_error); +#endif +} + +TEST(PtrHashTest, MappedPtrHashRejectsEmptyFileBeforeMapping) { + TempFile file("mapped_empty"); + { + std::ofstream out(file.path(), std::ios::binary); + ASSERT_TRUE(out); + } + +#if defined(__unix__) || defined(__APPLE__) + try { + (void)MappedPtrHash::open(file.path()); + FAIL() << "expected empty mmap file to be rejected"; + } catch (const std::invalid_argument& e) { + EXPECT_STREQ("PtrHash mmap file is empty", e.what()); + } +#else + EXPECT_THROW(MappedPtrHash::open(file.path()), std::runtime_error); +#endif +} + +TEST(PtrHashTest, MappedPtrHashRejectsCorruptFileWithoutLeakingMapping) { + TempFile file("mapped_corrupt"); + { + std::ofstream out(file.path(), std::ios::binary); + ASSERT_TRUE(out); + std::string bytes(4096, 'x'); + out.write(bytes.data(), static_cast(bytes.size())); + ASSERT_TRUE(out.good()); + } + +#if defined(__unix__) || defined(__APPLE__) +#if defined(__linux__) + ASSERT_EQ(CountMappingsForPath(file.path()), 0); +#endif + for (size_t i = 0; i < 16; ++i) { + EXPECT_THROW(MappedPtrHash::open(file.path()), std::invalid_argument); + } +#if defined(__linux__) + EXPECT_EQ(CountMappingsForPath(file.path()), 0); +#endif +#else + EXPECT_THROW(MappedPtrHash::open(file.path()), std::runtime_error); +#endif +} + +TEST(PtrHashTest, MappedPtrHashMoveKeepsValidView) { + auto keys = MakeIntegerKeys(128); + auto hash = PtrHash::build(keys); + TempFile file("mapped_move"); + hash.save(file.path()); + +#if defined(__unix__) || defined(__APPLE__) + auto mapped = MappedPtrHash::open(file.path()); + auto moved = std::move(mapped); + ExpectMinimalPerfect(moved, keys); + + MappedPtrHash assigned; + assigned = std::move(moved); + ExpectMinimalPerfect(assigned, keys); +#else + EXPECT_THROW(MappedPtrHash::open(file.path()), std::runtime_error); +#endif +} + +TEST(PtrHashTest, QueryTypeMismatchThrows) { + auto integer_hash = PtrHash::build(std::vector{1, 2, 3}); + EXPECT_THROW(integer_hash.index(std::string_view("1")), std::invalid_argument); + + auto string_hash = PtrHash::build(std::vector{"1", "2", "3"}); + EXPECT_THROW(string_hash.index(uint64_t{1}), std::invalid_argument); + + auto prehashed = PtrHash::build_hashes(std::vector{1, 2, 3}); + EXPECT_THROW(prehashed.index(uint64_t{1}), std::invalid_argument); + EXPECT_THROW(prehashed.index(std::string_view("1")), std::invalid_argument); + EXPECT_NO_THROW((void)prehashed.index_hash(1)); +} + +TEST(PtrHashTest, CopyAndMoveKeepValidView) { + auto keys = MakeIntegerKeys(128); + PtrHash original = PtrHash::build(keys); + + PtrHash copied(original); + ExpectSameQueries(original, copied, keys); + + PtrHash copy_assigned; + copy_assigned = original; + ExpectSameQueries(original, copy_assigned, keys); + + PtrHash moved(std::move(copied)); + ExpectMinimalPerfect(moved, keys); + + PtrHash move_assigned; + move_assigned = std::move(copy_assigned); + ExpectMinimalPerfect(move_assigned, keys); +}