
[Cleanup] Combine Batched and Regular KMeans Impl#2015

Open
tarang-jain wants to merge 53 commits intorapidsai:mainfrom
tarang-jain:combine-batch

Conversation

@tarang-jain
Contributor

@tarang-jain tarang-jain commented Apr 10, 2026

Combine batched and regular k-means implementations

  • Unified the batched (host-data) and regular (device-data) k-means fit into a single kmeans_fit template that works with both host and device mdspans via batch_load_iterator
  • Unified the device and host initialization paths in init_centroids
  • Removed the inertia_check parameter — inertia-based convergence checking now always runs. A zero clustering cost (perfect fit) logs a warning instead of asserting; this is needed because spectral clustering can cause all points to converge on the cluster centroids themselves.
  • Added an init_size parameter to control how many samples are drawn for KMeansPlusPlus initialization. Defaults to n_samples for device data and min(3 * n_clusters, n_samples) for host data
  • Replaced per-iteration centroid raft::copy with std::swap of buffer pointers
  • For the streaming fit, data norms are precomputed once and cached: host norms are cached to a host buffer on the first iteration and copied back on subsequent iterations. process_batch no longer computes norms internally
  • Replaced raw cudaPointerGetAttributes call with raft::memory_type_from_pointer
  • Replaced cub::DeviceReduce::Sum calls with raft::linalg::mapThenSumReduce
  • Guarded weight normalization against overflow: apply (w / wt_sum) * n_samples via a composed op instead of precomputing a scale, so very small wt_sum values don't produce inf
  • Renamed checkWeight to weightSum and made it mdspan-based with an Accessor template: device reduce for device weights, host loop for host weights. Callers apply the scaling themselves
  • Eliminated batch_sums / batch_counts scratch buffers by accumulating directly into centroid_sums / weight_per_cluster via reset_sums=false in reduce_rows_by_key / reduce_cols_by_key, removing two per-batch raft::linalg::add kernels
  • Removed dead update_centroids helpers (both the detail and public template) — no remaining callers after the fit_main consolidation
  • Perf: removed several raft::sync_stream calls and added a CUDA event that records whether the convergence criterion is met; the convergence check now runs on device. Average per-iteration time with the now-mandatory inertia check matches previous benchmarks in which the inertia check was disabled.
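The overflow guard described above can be illustrated with a small standalone sketch (plain C++, no raft; the helper names are hypothetical). Dividing by wt_sum first keeps the intermediate value finite where a precomputed scale overflows float:

```cpp
#include <cassert>
#include <cmath>

// Hypothetical standalone illustration of the guarded weight normalization.
// Precomputing scale = n_samples / wt_sum overflows float for tiny wt_sum;
// composing (w / wt_sum) * n_samples keeps the intermediate finite.
inline float scale_weight_precomputed(float w, float wt_sum, float n_samples) {
  volatile float scale = n_samples / wt_sum;  // e.g. 1e6 / 1e-35 -> inf in float
  return w * scale;                           // inf propagates into the weight
}

inline float scale_weight_composed(float w, float wt_sum, float n_samples) {
  return (w / wt_sum) * n_samples;            // divide first, then rescale
}
```

With w = wt_sum = 1e-35f and n_samples = 1e6f, the precomputed variant yields inf while the composed variant yields exactly 1e6f.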

C Tests

This PR adds C tests for KMeans, which were previously missing. They cover both the old API and the new one (i.e., the breaking change).

Benchmarks:

With mandatory early stopping. The batch size is chosen to fill 90% of available GPU memory (95830 MiB)
HW:
GPU:
NVIDIA H100 NVL (CUDA 13.0)
CPU:

Architecture:             x86_64
  CPU op-mode(s):         32-bit, 64-bit
  Address sizes:          52 bits physical, 57 bits virtual
  Byte Order:             Little Endian
CPU(s):                   256
  On-line CPU(s) list:    0-255
Vendor ID:                AuthenticAMD
  Model name:             AMD EPYC 9554 64-Core Processor 
================================================================================
 SUMMARY
================================================================================
  n_clusters     batch_size  fit_time(s)        inertia   n_iter
----------------------------------------------------------------
      10,000     29,120,352      1584.72     2.8677e+08       30
      20,000     29,120,352      2907.34     2.7368e+08       31
      30,000     29,101,305      4254.43     2.6617e+08       31
      40,000     29,092,704      5836.12     2.6086e+08       32
      50,000     29,083,488      7107.04     2.5680e+08       31

Breaking Change

This PR is a breaking change to the C++ API because the inertia_check param is removed. The breaking changes to the C ABI will be applied in 26.08.

@copy-pr-bot

copy-pr-bot Bot commented Apr 10, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@tarang-jain tarang-jain self-assigned this Apr 10, 2026
@tarang-jain tarang-jain added the improvement (Improves an existing functionality), non-breaking (Introduces a non-breaking change), and cpp labels Apr 10, 2026
@tarang-jain tarang-jain marked this pull request as ready for review April 14, 2026 01:10
@tarang-jain tarang-jain requested review from a team as code owners April 14, 2026 01:10
Comment thread c/include/cuvs/cluster/kmeans.h
Contributor

@viclafargue viclafargue left a comment


Thanks! Here are some comments.


auto minClusterAndDistance = raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(
handle, streaming_batch_size);
auto L2NormBatch = raft::make_device_vector<DataT, IndexT>(handle, streaming_batch_size);
Contributor


params.streaming_batch_size = 0 by default in the data-on-device case, but nothing prevents a user from setting a value. This would allocate an L2NormBatch smaller than n_samples, causing out-of-bounds writes (and later reads) during norm computation.

We should probably guard this with a check:
RAFT_EXPECTS(streaming_batch_size == n_samples || !data_on_device, ...)

Contributor Author

@tarang-jain tarang-jain Apr 21, 2026


I have updated this so that for device arrays, we simply ignore the streaming_batch_size and use the entire dataset always.
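The resolution described here can be sketched in isolation (illustrative names, not the actual cuvs API): for device-resident data the requested streaming batch size is simply ignored in favor of the full dataset, so the norm buffer can never be smaller than n_samples.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical helper mirroring the behavior described above: device data
// always processes the whole dataset; host data honors a positive request,
// clamped to n_samples.
inline std::int64_t effective_batch_size(std::int64_t requested,
                                         std::int64_t n_samples,
                                         bool data_on_device) {
  if (data_on_device) return n_samples;  // streaming_batch_size is ignored
  return requested > 0 ? std::min(requested, n_samples) : n_samples;
}
```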

Comment thread cpp/src/cluster/detail/kmeans.cuh Outdated
Comment on lines +661 to +663
auto init_sample =
raft::make_device_matrix<DataT, IndexT>(handle, init_sample_size, n_features);
raft::matrix::sample_rows(handle, random_state, X, init_sample.view());
Contributor


Suggested change
auto init_sample =
  raft::make_device_matrix<DataT, IndexT>(handle, init_sample_size, n_features);
raft::matrix::sample_rows(handle, random_state, X, init_sample.view());
if (init_sample_size == n_samples && data_on_device) {
  auto init_sample_const = raft::make_device_matrix_view<const DataT, IndexT>(
    X.data_handle(), n_samples, n_features);
  // pass directly to kmeansPlusPlus / initScalableKMeansPlusPlus
} else {
  auto init_sample =
    raft::make_device_matrix<DataT, IndexT>(handle, init_sample_size, n_features);
  raft::matrix::sample_rows(handle, random_state, X, init_sample.view());
  // pass init_sample to kmeansPlusPlus / initScalableKMeansPlusPlus
}

If init_size = 0 in the data-on-device path, we basically double memory use by copying the dataset over. Let's skip this by creating a view on the dataset.

Contributor Author


I completely skipped the sampling for the device path. That is how it was being done earlier. The init size is only used if the data is on host.

Comment on lines +731 to +732
auto batch_workspace = rmm::device_uvector<char>(
current_batch_sz, stream, raft::resource::get_workspace_resource(handle));
Contributor


Every call to process_batch allocates both this workspace and the device scalar below. Both buffers could be allocated outside the process_batch function.

Contributor Author


Moved the workspace buffer allocation outside the process_batch function.

Comment thread cpp/src/cluster/detail/kmeans.cuh Outdated
raft::matrix::sample_rows(handle, random_state, X, centroidsRawData);
} else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) {
IndexT default_init_size =
data_on_device ? n_samples : std::min(static_cast<IndexT>(3 * n_clusters), n_samples);
Contributor


Unlikely to be an actual issue, but n_clusters could be cast before the multiplication to avoid any risk of integer overflow.

Contributor Author


I have gotten rid of the batching for the device path. So when a user sets a batch size for device mdspan, we just set it to n_samples and warn the user. We should definitely not be creating a new buffer just for the init sample if we can accommodate the entire input matrix on device already.
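The cast-before-multiply fix suggested above can be sketched in isolation (a hypothetical helper, assuming a 32-bit type for the cluster count): widening n_clusters before the multiplication means 3 * n_clusters is computed in 64 bits and cannot overflow the narrow type.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical illustration of the cast-before-multiply suggestion: widen
// n_clusters first so the product 3 * n_clusters is exact in 64 bits.
inline std::int64_t default_init_size(std::int32_t n_clusters,
                                      std::int64_t n_samples) {
  std::int64_t scaled = std::int64_t{3} * static_cast<std::int64_t>(n_clusters);
  return std::min(scaled, n_samples);
}
```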

Comment thread cpp/src/cluster/detail/kmeans.cuh Outdated
Comment on lines +876 to +881
DataT curClusteringCost = DataT{0};
raft::copy(&curClusteringCost, clustering_cost.data_handle(), 1, stream);
raft::resource::sync_stream(handle, stream);

if (curClusteringCost == DataT{0}) {
RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids.");
Contributor


Going from ASSERT to RAFT_LOG_WARN may indeed be useful for the spectral clustering case. However, removing the inertia_check option forces the sync at every iteration. Do we truly need to drop this option?

Member


I'm not sure we need the log and an assert might be better?

Early stopping (aka skipping iterations) is ultimately going to be the best way to extract perf here. Whether it's by explicitly computing inertia or just looking at the residuals of the centroids from the prior iteration.

Seems like inertia check / residuals could be done on gpu if we had to in order to avoid syncing so we would only need to sync in the final iteration, right?

Contributor Author

@tarang-jain tarang-jain Apr 20, 2026


Seems like inertia check / residuals could be done on gpu if we had to in order to avoid syncing so we would only need to sync in the final iteration, right?

Until the iteration has completed, the CPU should not start the next iteration. So all the operations on the GPU stream must complete to finish the iteration.

Contributor Author


Seems like inertia check / residuals could be done on gpu if we had to in order to avoid syncing so we would only need to sync in the final iteration, right?

Yes, but this is throwing an error in the spectral clustering case wherein all the points converge on the centroids themselves. This is happening in one of the spectral tests and an assertion here is leading to an error, where instead it should simply return those centroids directly.

Contributor Author


I'm not sure we need the log and an assert might be better?

Therefore, I had to change it to a warning instead of an assertion. Earlier, those spectral tests were skipping the inertia check, which avoided the assertion.

Comment on lines +638 to +646
} else {
std::vector<DataT> h_weights(n_samples);
auto d_view = raft::make_device_vector_view<const DataT, IndexT>(weight_ptr, n_samples);
auto h_view = raft::make_host_vector_view<DataT, IndexT>(h_weights.data(), n_samples);
raft::copy(handle, h_view, d_view);
raft::resource::sync_stream(handle);
for (IndexT i = 0; i < n_samples; ++i) {
wt_sum += h_weights[i];
}
Contributor


In the data-on-device case, since the data is already on device, it would be much faster to sum-reduce with cub::DeviceReduce::Sum or raft::linalg::reduce. The summation would also have better precision, since it is done in a tree fashion (O(log N) error growth).

Contributor Author


When it's device-accessible, I have changed that to a raft::linalg::mapThenSumReduce. I have also removed this function and directly updated checkWeight (renamed to weightSum). We do the scaling after the weight sum is computed.
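The precision point about tree-shaped reductions can be demonstrated on the host with a small sketch (hypothetical helpers; device reductions such as cub::DeviceReduce::Sum achieve the same shape in parallel): once a sequential float accumulator exceeds 2^24, adding 1.0f no longer changes it, while pairwise summation keeps partial sums small.

```cpp
#include <cstddef>
#include <vector>

// Sequential accumulation: rounding error grows O(N); past 2^24 in float,
// adding 1.0f to the running sum is absorbed entirely.
inline float sum_sequential(const std::vector<float>& v) {
  float s = 0.0f;
  for (float x : v) s += x;
  return s;
}

// Pairwise (tree) summation: partial sums stay small, so rounding error
// grows only O(log N).
inline float sum_pairwise(const float* v, std::size_t n) {
  if (n == 1) return v[0];
  std::size_t half = n / 2;
  return sum_pairwise(v, half) + sum_pairwise(v + half, n - half);
}
```

Summing (2^24 + 64) ones, the sequential version saturates at 16777216.0f while the pairwise version returns the exact 16777280.0f.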

@@ -33,9 +33,10 @@ cdef extern from "cuvs/cluster/kmeans.h" nogil:
int batch_samples,
int batch_centroids,
bool inertia_check,
Contributor


Maybe add a comment saying the field is present but deprecated.

Contributor Author


I don't think it's necessary to add a comment here in the .pxd. The C header already has that information, and this file will be updated along with the C headers / src files.

Comment thread cpp/src/cluster/detail/kmeans.cuh Outdated
@tarang-jain tarang-jain requested a review from dantegd April 24, 2026 21:51

for (n_current_iter = 1; n_current_iter <= iter_params.max_iter; ++n_current_iter) {
if (n_current_iter > 1) {
RAFT_CUDA_TRY(cudaEventSynchronize(convergence_event));
Contributor Author


The convergence event is recorded right after the inertia is computed, but we check it here before proceeding to the next iteration so that the operations after the convergence check are not blocked. The use of an event also means that only the operations up until the event need to be completed in order for it to be synchronized.

Comment thread c/include/cuvs/cluster/kmeans.h Outdated
@copy-pr-bot

copy-pr-bot Bot commented Apr 28, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@tarang-jain
Contributor Author

/ok to test b1c034e

@tarang-jain tarang-jain requested a review from a team as a code owner April 29, 2026 17:17
@tarang-jain
Contributor Author

/ok to test 73293cf

Member

@dantegd dantegd left a comment


Just had a minor suggestion, not blocking

Comment on lines 115 to +123
/**
* If true, check inertia during iterations for early convergence.
* Number of samples to randomly draw for the KMeansPlusPlus initialization
* step. A random subset of this size is used for centroid seeding.
* When set to 0 the default depends on the data location:
* - Device data: n_samples (use the full dataset).
* - Host data: min(3 * n_clusters, n_samples).
* Default: 0.
*/
bool inertia_check = false;
int64_t init_size = 0;
Member


I softly agree with this rabbit comment; it would be worth adding a note in the PR description and code. C++ callers using the C++ API directly will fail to compile, not warn. That's a real source-API break, worth a release-notes line at least.

@tarang-jain
Contributor Author

@dantegd I have updated the PR desc. Since this PR is marked as breaking, it will automatically be mentioned in CHANGELOG.md by the bot, right?

cluster_centers, impl->n_lists(), impl->dim());
if (impl->metric() == distance::DistanceType::CosineExpanded) {
raft::linalg::row_normalize<raft::linalg::L2Norm>(handle, centers_const_view, centers_view);
}
Contributor Author


This normalization was lost in #2001. So this PR adds it again.
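For reference, the restored step is an L2 row normalization of the cluster centers (raft::linalg::row_normalize performs this on device); a host-side sketch of just the math:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Host-side sketch of L2 row normalization: each row of a row-major matrix
// is scaled to unit Euclidean norm; zero rows are left untouched.
inline void l2_normalize_rows(std::vector<float>& m,
                              std::size_t n_rows, std::size_t n_cols) {
  for (std::size_t r = 0; r < n_rows; ++r) {
    float sq = 0.0f;
    for (std::size_t c = 0; c < n_cols; ++c) {
      float x = m[r * n_cols + c];
      sq += x * x;
    }
    float inv = sq > 0.0f ? 1.0f / std::sqrt(sq) : 0.0f;
    for (std::size_t c = 0; c < n_cols; ++c) m[r * n_cols + c] *= inv;
  }
}
```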

rapids-bot Bot pushed a commit to rapidsai/cuml that referenced this pull request Apr 30, 2026
Depends on rapidsai/cuvs#2015. Inertia checking is being made mandatory and rapidsai/cuvs#2015 is a breaking change. This PR is needed to prevent compilation failures.

Authors:
  - Tarang Jain (https://github.com/tarang-jain)

Approvers:
  - Jim Crist-Harif (https://github.com/jcrist)
  - Anupam (https://github.com/aamijar)
  - Victor Lafargue (https://github.com/viclafargue)

URL: #8033

Labels

breaking (Introduces a breaking change), cpp, improvement (Improves an existing functionality)


4 participants