Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions UPSTREAM_PR_BODY.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Add serialize/deserialize to the Rust brute force index

## What

Adds `serialize` and `deserialize` to the Rust `brute_force::Index`, wrapping the
existing C entry points `cuvsBruteForceSerialize` / `cuvsBruteForceDeserialize`.
This brings the brute force binding to parity with the CAGRA binding, which already
exposes serialize/deserialize.

- `Index::serialize(&self, res, filename)` writes the index to disk.
- `Index::deserialize(res, filename) -> Result<Index>` loads an index from disk.

Both methods mirror the CAGRA implementation:

- A private `path_to_cstring` helper converts the filesystem path to a `CString`,
returning `Error::InvalidArgument` (instead of panicking) for paths that are not
valid UTF-8 or that contain an interior NUL byte. The path is validated before any
FFI call is made.
- Every FFI call is wrapped in `check_cuvs`.
- `deserialize` constructs the `Index` handle first (via `Index::new()`), so that if
the underlying `cuvsBruteForceDeserialize` call fails, the handle's `Drop` still
runs and releases the C-side index allocation (RAII-safe error path).

The doc comments note that the serialization format may change between cuVS versions,
matching the wording in the C header.

## Notes for reviewers

- **No new bindings were generated.** `cuvsBruteForceSerialize` and
`cuvsBruteForceDeserialize` are already present in `rust/cuvs-sys/src/bindings.rs`
(brute_force is pulled in through `core/all.h`), so this change is purely additive
on the safe Rust wrapper side and touches no generated code.
- **Test helper lifetime detail.** The brute force `Index` keeps a non-owning device
view of its dataset (`_dataset`). The serialize round-trip test deliberately keeps
the host `ndarray` array in the same scope as the index for the duration of the
test, because the device tensor's `shape` pointer borrows that array's dimension
storage. Moving the host array while the index is alive would dangle that pointer
(this is a property of the existing `ManagedTensor` view, not of these new methods).
- **Conflicts with sibling in-flight Rust PRs** (e.g. the IVF-SQ bindings PR #2229
and other Rust binding PRs): if conflicts arise, resolve by merging `main` into this
branch rather than rebasing, per the project's no-rebase contribution guideline.

## Testing

Two new unit tests were added alongside the existing `test_l2`, mirroring the CAGRA
serialize tests:

- `test_brute_force_serialize_deserialize` — builds an index, serializes it, asserts
the output file exists and is non-empty, deserializes it back, and re-verifies that
a self-neighbor search on the **loaded** index returns each query as its own nearest
neighbor.
- `test_brute_force_serialize_rejects_interior_nul` — confirms that a path containing
an interior NUL byte surfaces as `Error::InvalidArgument` rather than panicking.

All brute force tests pass (run single-threaded on a single GPU):

```
cargo test -p cuvs brute_force -- --test-threads=1
test brute_force::tests::test_brute_force_serialize_deserialize ... ok
test brute_force::tests::test_brute_force_serialize_rejects_interior_nul ... ok
test brute_force::tests::test_l2 ... ok
test result: ok. 3 passed; 0 failed; 0 ignored
```
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

`cargo fmt` and `cargo clippy` are clean for the changed file.
140 changes: 139 additions & 1 deletion rust/cuvs/src/brute_force.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
*/
//! Brute Force KNN

use std::ffi::CString;
use std::io::{Write, stderr};
use std::path::Path;

use crate::distance_type::DistanceType;
use crate::dlpack::ManagedTensor;
use crate::error::{Result, check_cuvs};
use crate::error::{Error, Result, check_cuvs};
use crate::resources::Resources;

/// Brute Force KNN Index
Expand All @@ -20,6 +22,17 @@ pub struct Index {
_dataset: Option<ManagedTensor>,
}

/// Convert a filesystem path into a `CString` suitable for the cuVS C API,
/// returning `Error::InvalidArgument` instead of panicking for paths that are
/// not valid UTF-8 or that contain an interior NUL byte.
fn path_to_cstring(path: &Path) -> Result<CString> {
let path_str = path
.to_str()
.ok_or_else(|| Error::InvalidArgument(format!("path is not valid UTF-8: {path:?}")))?;
CString::new(path_str)
.map_err(|e| Error::InvalidArgument(format!("path contains an interior NUL byte: {e}")))
}

impl Index {
/// Builds a new Brute Force KNN Index from the dataset for efficient search.
///
Expand Down Expand Up @@ -87,6 +100,40 @@ impl Index {
))
}
}

/// Save the Brute Force index to file.
///
/// The serialization format can be subject to change, therefore loading an
/// index saved with a previous version of cuVS is not guaranteed to work.
///
/// # Arguments
///
/// * `res` - Resources to use
/// * `filename` - The file path for saving the index
pub fn serialize<P: AsRef<Path>>(&self, res: &Resources, filename: P) -> Result<()> {
let c_filename = path_to_cstring(filename.as_ref())?;
unsafe { check_cuvs(ffi::cuvsBruteForceSerialize(res.0, c_filename.as_ptr(), self.inner)) }
}

/// Load a Brute Force index from file.
///
/// The serialization format can be subject to change, therefore loading an
/// index saved with a previous version of cuVS is not guaranteed to work.
///
/// # Arguments
///
/// * `res` - Resources to use
/// * `filename` - The path of the file that stores the index
pub fn deserialize<P: AsRef<Path>>(res: &Resources, filename: P) -> Result<Index> {
let c_filename = path_to_cstring(filename.as_ref())?;
// Create the Index handle first so that any error path below still runs
// its `Drop` and releases the C-side index allocation.
let index = Index::new()?;
unsafe {
check_cuvs(ffi::cuvsBruteForceDeserialize(res.0, c_filename.as_ptr(), index.inner))?;
}
Ok(index)
}
}

impl Drop for Index {
Expand Down Expand Up @@ -168,4 +215,95 @@ mod tests {
fn test_l2() {
test_bfknn(DistanceType::L2Expanded);
}

const N_DATAPOINTS: usize = 16;
const N_FEATURES: usize = 8;

/// Search the first `n_queries` rows of `dataset` against `index` and assert
/// each query finds itself as the top-1 neighbor.
fn search_and_verify_self_neighbors(
res: &Resources,
index: &Index,
dataset: &ndarray::Array2<f32>,
n_queries: usize,
k: usize,
) {
let queries = dataset.slice(s![0..n_queries, ..]);
let queries = ManagedTensor::from(&queries).to_device(res).unwrap();

let mut neighbors_host = ndarray::Array::<i64, _>::zeros((n_queries, k));
let neighbors = ManagedTensor::from(&neighbors_host).to_device(res).unwrap();

let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
let distances = ManagedTensor::from(&distances_host).to_device(res).unwrap();

index.search(res, &queries, &neighbors, &distances).expect("search failed");

distances.to_host(res, &mut distances_host).unwrap();
neighbors.to_host(res, &mut neighbors_host).unwrap();
res.sync_stream().unwrap();

for i in 0..n_queries {
assert_eq!(
neighbors_host[[i, 0]],
i as i64,
"query {i} should be its own nearest neighbor"
);
}
}

#[test]
fn test_brute_force_serialize_deserialize() {
let res = Resources::new().unwrap();

// Keep `dataset` (the host array) in this scope for the whole test: the
// device dataset view stored inside the index borrows its shape, so the
// host array must not be moved while the index is alive.
let dataset =
ndarray::Array::<f32, _>::random((N_DATAPOINTS, N_FEATURES), Uniform::new(0., 1.0));
let device_dataset = ManagedTensor::from(&dataset).to_device(&res).unwrap();
let index = Index::build(&res, DistanceType::L2Expanded, None, device_dataset)
.expect("failed to build brute force index");
res.sync_stream().unwrap();

let filepath = std::env::temp_dir().join("test_brute_force_index.bin");
index.serialize(&res, &filepath).expect("failed to serialize brute force index");
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

assert!(filepath.exists(), "serialized index file should exist");
assert!(
std::fs::metadata(&filepath).unwrap().len() > 0,
"serialized index file should not be empty"
);

let loaded_index =
Index::deserialize(&res, &filepath).expect("failed to deserialize brute force index");

// The deserialized index should still find each query as its own
// nearest neighbor.
search_and_verify_self_neighbors(&res, &loaded_index, &dataset, 4, 4);

let _ = std::fs::remove_file(&filepath);
}

/// Passing a filename containing an interior NUL byte must surface as an
/// `InvalidArgument` error rather than panicking inside the serializer.
#[test]
fn test_brute_force_serialize_rejects_interior_nul() {
let res = Resources::new().unwrap();

let dataset =
ndarray::Array::<f32, _>::random((N_DATAPOINTS, N_FEATURES), Uniform::new(0., 1.0));
let device_dataset = ManagedTensor::from(&dataset).to_device(&res).unwrap();
let index = Index::build(&res, DistanceType::L2Expanded, None, device_dataset)
.expect("failed to build brute force index");
res.sync_stream().unwrap();

// `PathBuf::from` on Unix preserves arbitrary bytes, so we can embed a
// NUL byte in the path and confirm the helper rejects it.
let bad_path = std::path::PathBuf::from("/tmp/has\0nul.bin");
let err = index
.serialize(&res, &bad_path)
.expect_err("serialize should reject paths with interior NUL");
assert!(matches!(err, Error::InvalidArgument(_)), "expected InvalidArgument, got {err:?}");
}
}
Loading