-
Notifications
You must be signed in to change notification settings - Fork 198
KMeans: Reuse Precomputed Norms for Inertia Computation #2258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
04d00e3
22f32a6
fa79838
5371627
4c98eca
5cad5a4
5d48437
7d75605
10274e0
57eae14
7d49d63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1532,14 +1532,16 @@ void transform(raft::resources const& handle, | |
| * @param[out] cost Resulting cluster cost | ||
| * @param[in] sample_weight Optional per-sample weights. | ||
| * [len = n_samples] | ||
| * | ||
| * @param[in] X_norm Optional precomputed L2 norms of X rows [n_samples]. | ||
| * When provided, the internal norm computation is skipped. | ||
| */ | ||
| void cluster_cost( | ||
| const raft::resources& handle, | ||
| raft::device_matrix_view<const float, int> X, | ||
| raft::device_matrix_view<const float, int> centroids, | ||
| raft::host_scalar_view<float> cost, | ||
| std::optional<raft::device_vector_view<const float, int>> sample_weight = std::nullopt); | ||
| raft::device_scalar_view<float> cost, | ||
| std::optional<raft::device_vector_view<const float, int>> sample_weight = std::nullopt, | ||
| std::optional<raft::device_vector_view<const float, int>> X_norm = std::nullopt); | ||
|
Comment on lines
1538
to
+1544
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🗄️ Data Integrity & Integration | 🔴 Critical | 🏗️ Heavy lift CRITICAL: Switching `cost` from `raft::host_scalar_view` to `raft::device_scalar_view` changes the public `cluster_cost` signature and breaks existing callers. Consider keeping the old host-scalar overloads as deprecated shims for at least one release and documenting the migration. As per coding guidelines, “API changes require deprecation warnings.” As per path instructions for `cpp/include/cuvs/**/*`. Also applies to: 1562-1568, 1586-1592, 1610-1616 🤖 Prompt for AI Agents Sources: Coding guidelines, Path instructions |
||
|
|
||
| /** | ||
| * @brief Compute cluster cost | ||
|
|
@@ -1554,13 +1556,16 @@ void cluster_cost( | |
| * @param[out] cost Resulting cluster cost | ||
| * @param[in] sample_weight Optional per-sample weights. | ||
| * [len = n_samples] | ||
| * @param[in] X_norm Optional precomputed L2 norms of X rows [n_samples]. | ||
| * When provided, the internal norm computation is skipped. | ||
| */ | ||
| void cluster_cost( | ||
| const raft::resources& handle, | ||
| raft::device_matrix_view<const double, int> X, | ||
| raft::device_matrix_view<const double, int> centroids, | ||
| raft::host_scalar_view<double> cost, | ||
| std::optional<raft::device_vector_view<const double, int>> sample_weight = std::nullopt); | ||
| raft::device_scalar_view<double> cost, | ||
| std::optional<raft::device_vector_view<const double, int>> sample_weight = std::nullopt, | ||
| std::optional<raft::device_vector_view<const double, int>> X_norm = std::nullopt); | ||
|
|
||
| /** | ||
| * @brief Compute (optionally weighted) cluster cost | ||
|
|
@@ -1575,13 +1580,16 @@ void cluster_cost( | |
| * @param[out] cost Resulting cluster cost | ||
| * @param[in] sample_weight Optional per-sample weights. | ||
| * [len = n_samples] | ||
| * @param[in] X_norm Optional precomputed L2 norms of X rows [n_samples]. | ||
| * When provided, the internal norm computation is skipped. | ||
| */ | ||
| void cluster_cost( | ||
| const raft::resources& handle, | ||
| raft::device_matrix_view<const float, int64_t> X, | ||
| raft::device_matrix_view<const float, int64_t> centroids, | ||
| raft::host_scalar_view<float> cost, | ||
| std::optional<raft::device_vector_view<const float, int64_t>> sample_weight = std::nullopt); | ||
| raft::device_scalar_view<float> cost, | ||
| std::optional<raft::device_vector_view<const float, int64_t>> sample_weight = std::nullopt, | ||
| std::optional<raft::device_vector_view<const float, int64_t>> X_norm = std::nullopt); | ||
|
|
||
| /** | ||
| * @brief Compute (optionally weighted) cluster cost | ||
|
|
@@ -1596,13 +1604,16 @@ void cluster_cost( | |
| * @param[out] cost Resulting cluster cost | ||
| * @param[in] sample_weight Optional per-sample weights. | ||
| * [len = n_samples] | ||
| * @param[in] X_norm Optional precomputed L2 norms of X rows [n_samples]. | ||
| * When provided, the internal norm computation is skipped. | ||
| */ | ||
| void cluster_cost( | ||
| const raft::resources& handle, | ||
| raft::device_matrix_view<const double, int64_t> X, | ||
| raft::device_matrix_view<const double, int64_t> centroids, | ||
| raft::host_scalar_view<double> cost, | ||
| std::optional<raft::device_vector_view<const double, int64_t>> sample_weight = std::nullopt); | ||
| raft::device_scalar_view<double> cost, | ||
| std::optional<raft::device_vector_view<const double, int64_t>> sample_weight = std::nullopt, | ||
| std::optional<raft::device_vector_view<const double, int64_t>> X_norm = std::nullopt); | ||
| /** | ||
| * @} | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -931,7 +931,11 @@ void kmeans_fit( | |
| auto centroids_const = raft::make_device_matrix_view<const DataT, IndexT>( | ||
| cur_centroids_ptr, n_clusters, n_features); | ||
|
|
||
| iter_inertia = DataT{0}; | ||
| auto d_iter_inertia = raft::make_device_scalar<DataT>(handle, DataT{0}); | ||
| auto d_batch_cost = raft::make_device_scalar<DataT>(handle, DataT{0}); | ||
| DataT* p_acc = d_iter_inertia.data_handle(); | ||
| DataT* p_batch = d_batch_cost.data_handle(); | ||
|
|
||
| data_batches.reset(); | ||
| using wt_iter_t = cuvs::spatial::knn::detail::utils::batch_load_iterator_dyn<DataT>; | ||
| std::optional<wt_iter_t> wt_it; | ||
|
|
@@ -956,15 +960,33 @@ void kmeans_fit( | |
| cur_batch_weights(static_cast<IndexT>(data_batch.offset()), wt_data, cur_batch_size); | ||
| } | ||
|
|
||
| DataT batch_cost = DataT{0}; | ||
| cuvs::cluster::kmeans::cluster_cost(handle, | ||
| batch_data_view, | ||
| centroids_const, | ||
| raft::make_host_scalar_view(&batch_cost), | ||
| batch_sw); | ||
| std::optional<raft::device_vector_view<const DataT, IndexT>> batch_xnorm = std::nullopt; | ||
| if (need_compute_norms) { | ||
| if constexpr (data_on_device) { | ||
| batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>( | ||
| L2NormBatch.data_handle() + data_batch.offset(), cur_batch_size); | ||
| } else if (norms_cached) { | ||
| raft::copy(L2NormBatch.data_handle(), | ||
| h_norm_cache.data_handle() + data_batch.offset(), | ||
| cur_batch_size, | ||
| stream); | ||
| batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>( | ||
| L2NormBatch.data_handle(), cur_batch_size); | ||
| } | ||
|
Comment on lines
+963
to
+975
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🎯 Functional Correctness | 🟠 Major | ⚡ Quick win HIGH: Guard the uncached host-norm path for zero-iteration fits. Issue: For host data, this final inertia path always copies from `h_norm_cache`, even when `norms_cached` is false (e.g. a zero-iteration fit). Suggested fix:
if (need_compute_norms) {
if constexpr (data_on_device) {
batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
L2NormBatch.data_handle() + data_batch.offset(), cur_batch_size);
} else {
- raft::copy(L2NormBatch.data_handle(),
- h_norm_cache.data_handle() + data_batch.offset(),
- cur_batch_size,
- stream);
+ if (norms_cached) {
+ raft::copy(L2NormBatch.data_handle(),
+ h_norm_cache.data_handle() + data_batch.offset(),
+ cur_batch_size,
+ stream);
+ } else {
+ compute_batch_norms(data_batch.data(), cur_batch_size);
+ }
batch_xnorm = raft::make_device_vector_view<const DataT, IndexT>(
L2NormBatch.data_handle(), cur_batch_size);
}
}
🤖 Prompt for AI Agents |
||
| } | ||
|
|
||
| cuvs::cluster::kmeans::cluster_cost( | ||
| handle, batch_data_view, centroids_const, d_batch_cost.view(), batch_sw, batch_xnorm); | ||
|
|
||
| iter_inertia += batch_cost; | ||
| raft::linalg::map_offset(handle, | ||
| raft::make_device_vector_view<DataT, int>(p_acc, 1), | ||
| [p_acc, p_batch] __device__(int) { return *p_acc + *p_batch; }); | ||
| } | ||
|
|
||
| raft::copy(handle, | ||
| raft::make_host_scalar_view<DataT>(&iter_inertia), | ||
| raft::make_const_mdspan(d_iter_inertia.view())); | ||
| raft::resource::sync_stream(handle); | ||
| } | ||
|
|
||
| if (iter_inertia < inertia[0]) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1139,8 +1139,12 @@ void build_hierarchical(const raft::resources& handle, | |
| reinterpret_cast<const MathT*>(dataset), n_rows, dim); | ||
| auto centroids_view = | ||
| raft::make_device_matrix_view<const MathT, IdxT>(cluster_centers, n_clusters, dim); | ||
| cuvs::cluster::kmeans::cluster_cost( | ||
| handle, X_view, centroids_view, raft::make_host_scalar_view<MathT>(inertia)); | ||
| auto d_inertia = raft::make_device_scalar<MathT>(handle, MathT{0}); | ||
| cuvs::cluster::kmeans::cluster_cost(handle, X_view, centroids_view, d_inertia.view()); | ||
| raft::copy(handle, | ||
| raft::make_host_scalar_view<MathT>(inertia), | ||
| raft::make_const_mdspan(d_inertia.view())); | ||
| raft::resource::sync_stream(handle, stream); | ||
|
Comment on lines
+1142
to
+1147
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🚀 Performance & Scalability | 🟠 Major | ⚡ Quick win HIGH: final hierarchical inertia still drops the cached dataset norms.
Have you considered forwarding the cached dataset norms as `X_norm` to this `cluster_cost` call? 🤖 Prompt for AI Agents |
||
| } else { | ||
| RAFT_LOG_WARN("Inertia is not computed for non float/double types"); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🎯 Functional Correctness | 🟠 Major | ⚡ Quick win
HIGH: `X_norm` is documented as an L2 norm, but the implementation consumes squared norms.
This parameter is forwarded into the `L2NormX` path in `cpp/src/cluster/kmeans.cuh`, whose contract is `||x||^2`. Documenting it as a plain “L2 norm” invites callers to pass `sqrt(sum(x^2))`, which will silently skew the reported cluster cost.
Please make the public Doxygen explicit that `X_norm` must contain squared row norms. As per coding guidelines, “All public API functions must include complete Doxygen documentation describing parameters, return values, and any side effects.” As per path instructions for `cpp/include/cuvs/**/*`, “For public C++ API headers, additionally check: Doxygen documentation for all public functions/classes.”
Also applies to: 1559-1560, 1583-1584, 1607-1608
🤖 Prompt for AI Agents
Sources: Coding guidelines, Path instructions