Skip to content

Switch Product Quantization VQ to mean() when n_centers = 1#2250

Open
lowener wants to merge 3 commits into
NVIDIA:mainfrom
lowener:26.08-pq-mean
Open

Switch Product Quantization VQ to mean() when n_centers = 1#2250
lowener wants to merge 3 commits into
NVIDIA:mainfrom
lowener:26.08-pq-mean

Conversation

@lowener

@lowener lowener commented Jun 22, 2026

Copy link
Copy Markdown
Contributor

When running a PQ preprocessing operation, the VQ option can be used to act as a way to mean-center the dataset.
This proposed change will enable to do that mean-centering operation faster by using a direct call to raft::stats::mean instead of running the expectation-maximization steps.

Signed-off-by: Mickael Ide <mide@nvidia.com>
@lowener lowener requested a review from a team as a code owner June 22, 2026 15:04
@lowener lowener added improvement Improves an existing functionality non-breaking Introduces a non-breaking change C++ labels Jun 22, 2026
@coderabbitai

coderabbitai Bot commented Jun 22, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 84910738-679d-4aae-9293-797f503afad0

📥 Commits

Reviewing files that changed from the base of the PR and between e091067 and 9b5bd2a.

📒 Files selected for processing (1)
  • cpp/src/neighbors/detail/vpq_dataset.cuh
🚧 Files skipped from review as they are similar to previous changes (1)
  • cpp/src/neighbors/detail/vpq_dataset.cuh

📝 Walkthrough

Summary by CodeRabbit

  • Refactor
    • Improved vector quantization (VQ) center training by adding a faster single-center option (computed directly) and refining the multi-center path to use balanced k-means only when needed.
    • Keeps standard multi-center behavior intact while reducing unnecessary computation, improving training efficiency and execution speed.

Walkthrough

In train_vq, a new header dependency is added for raft::stats::mean. The kmeans_in_type alias is moved earlier. Explicit matrix views for vq_centers and the VQ trainset are constructed. A conditional branch handles vq_n_centers == 1 by computing the single center via raft::stats::mean; the general case continues to call cuvs::cluster::kmeans::fit with L2Expanded and params.kmeans_n_iters.

Changes

VQ center training refactor

Layer / File(s) Summary
Single-center fast path and k-means wiring
cpp/src/neighbors/detail/vpq_dataset.cuh
Header include for raft::stats::mean is added. kmeans_in_type is defined at the top of train_vq; dedicated views for vq_centers and the trainset are built; a vq_n_centers == 1 branch calls raft::stats::mean to compute the center, while the else branch calls cuvs::cluster::kmeans::fit with L2Expanded metric and params.kmeans_n_iters.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~5 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main optimization: switching PQ's VQ to use mean() when n_centers equals 1, which directly matches the primary change.
Description check ✅ Passed The description explains the purpose of the change (faster mean-centering using direct mean call instead of EM steps) and is directly related to the code modifications.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Comment @coderabbitai help to get the list of available commands.

Signed-off-by: Mickael Ide <mide@nvidia.com>
Signed-off-by: Mickael Ide <mide@nvidia.com>
@cjnolet

cjnolet commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

this proposed change will enable to do that mean-centering operation faster by using a direct call to raft::stats::mean

Thanks for the PQ @lowener. Can you please provide some benchmarks here to demonstrate the difference in the perf for this change?

@cjnolet cjnolet moved this to In Progress in Unstructured Data Processing Jun 23, 2026
rapids-bot Bot pushed a commit to NVIDIA/raft that referenced this pull request Jul 1, 2026
This PR enables the support of two different data types for stats::mean. It will be used in NVIDIA/cuvs#2250
Add support for half in strided dataset. Enabled thanks to this PR: #1585

Closes #2625

Authors:
  - Micka (https://github.com/lowener)
  - Anupam (https://github.com/aamijar)

Approvers:
  - Anupam (https://github.com/aamijar)
  - Artem M. Chirkin (https://github.com/achirkin)

URL: #3056
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

C++ improvement Improves an existing functionality non-breaking Introduces a non-breaking change

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants