Skip to content

Commit 34032a8

Browse files
authored
[Java] Add CAGRA index graph accessor/build from graph (host memory) (#1216)
This PR leverages the functions introduced by #1086 and the data structures introduced by #1111 to access, copy, and re-create an index to/from a CAGRA graph. Supersedes #1105 Authors: - Lorenzo Dematté (https://github.com/ldematte) - MithunR (https://github.com/mythrocks) Approvers: - MithunR (https://github.com/mythrocks) URL: #1216
1 parent aff1e3b commit 34032a8

File tree

7 files changed

+217
-12
lines changed

7 files changed

+217
-12
lines changed

java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ public interface CagraIndex {
5050
*/
5151
SearchResults search(CagraQuery query) throws Throwable;
5252

53+
/** Returns the CAGRA graph
54+
*
55+
* @return a {@link CuVSMatrix} encapsulating the native int (uint32_t) array used to represent
56+
* the cagra graph
57+
*/
58+
CuVSMatrix getGraph();
59+
5360
/**
5461
* A method to persist a CAGRA index using an instance of {@link OutputStream}
5562
* for writing index bytes.
@@ -208,6 +215,12 @@ interface Builder {
208215
*/
209216
Builder from(InputStream inputStream);
210217

218+
/**
219+
* Sets a CAGRA graph instance to re-create an index from a
220+
* previously built graph.
221+
*/
222+
Builder from(CuVSMatrix graph);
223+
211224
/**
212225
* Sets the dataset vectors for building the {@link CagraIndex}.
213226
*

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSMatrix.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ public interface CuVSMatrix extends AutoCloseable {
3030
enum DataType {
3131
FLOAT,
3232
INT,
33+
UINT,
3334
BYTE
3435
}
3536

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java

Lines changed: 123 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import com.nvidia.cuvs.internal.common.CloseableHandle;
4646
import com.nvidia.cuvs.internal.common.CloseableRMMAllocation;
4747
import com.nvidia.cuvs.internal.common.CompositeCloseableHandle;
48+
import com.nvidia.cuvs.internal.common.Util;
4849
import com.nvidia.cuvs.internal.panama.cuvsCagraCompressionParams;
4950
import com.nvidia.cuvs.internal.panama.cuvsCagraIndexParams;
5051
import com.nvidia.cuvs.internal.panama.cuvsCagraMergeParams;
@@ -122,6 +123,31 @@ private CagraIndexImpl(IndexReference indexReference, CuVSResources resources) {
122123
this.destroyed = false;
123124
}
124125

126+
/**
127+
* Constructor for creating an index from a pre-build CAGRA graph
128+
*
129+
* @param metric the distance type used
130+
* @param graph a previously built CAGRA graph
131+
* @param dataset the dataset used for indexing
132+
* @param resources an instance of {@link CuVSResources}
133+
*/
134+
private CagraIndexImpl(
135+
CagraIndexParams.CuvsDistanceType metric,
136+
CuVSMatrix graph,
137+
CuVSMatrix dataset,
138+
CuVSResources resources) {
139+
Objects.requireNonNull(graph);
140+
Objects.requireNonNull(dataset);
141+
142+
this.resources = resources;
143+
144+
assert graph instanceof CuVSMatrixBaseImpl;
145+
assert dataset instanceof CuVSMatrixBaseImpl;
146+
147+
this.cagraIndexReference =
148+
fromGraph(metric, (CuVSMatrixBaseImpl) graph, (CuVSMatrixBaseImpl) dataset);
149+
}
150+
125151
private void checkNotDestroyed() {
126152
if (destroyed) {
127153
throw new IllegalStateException("destroyed");
@@ -164,9 +190,12 @@ private IndexReference build(CagraIndexParams indexParameters, CuVSMatrixBaseImp
164190
omp_set_num_threads(numWriterThreads);
165191

166192
MemorySegment dataSeg = dataset.memorySegment();
193+
// TODO: type kDLCPU()/kDLCUDA() should be aligned with the CuVSMatrixBaseImpl type (host or
194+
// device?)
167195

168196
long[] datasetShape = {rows, cols};
169-
MemorySegment datasetTensor = prepareTensor(localArena, dataSeg, datasetShape, 2, 32, 2, 1);
197+
MemorySegment datasetTensor =
198+
prepareTensor(localArena, dataSeg, datasetShape, kDLFloat(), 32, kDLCPU(), 1);
170199

171200
var index = createCagraIndex();
172201

@@ -381,6 +410,81 @@ public void serialize(OutputStream outputStream, Path tempFile, int bufferLength
381410
}
382411
}
383412

413+
@Override
414+
public CuVSMatrix getGraph() {
415+
try (var localArena = Arena.ofConfined()) {
416+
var outPtr = localArena.allocate(uint32_t);
417+
checkCuVSError(
418+
cuvsCagraIndexGetGraphDegree(cagraIndexReference.getMemorySegment(), outPtr),
419+
"cuvsCagraIndexGetGraphDegree");
420+
long graphDegree = Util.dereferenceUnsignedInt(outPtr);
421+
422+
checkCuVSError(
423+
cuvsCagraIndexGetSize(cagraIndexReference.getMemorySegment(), outPtr),
424+
"cuvsCagraIndexGetSize");
425+
long size = Util.dereferenceUnsignedInt(outPtr);
426+
427+
// TODO: use a "device" graph + tensor, avoid (defer) copy
428+
var graph = new CuVSHostMatrixArenaImpl(size, graphDegree, CuVSMatrix.DataType.UINT);
429+
var graphHostTensor = graph.toTensor(localArena);
430+
var graphDeviceTensor =
431+
prepareTensor(
432+
localArena,
433+
MemorySegment.NULL,
434+
new long[] {size, graphDegree},
435+
kDLUInt(),
436+
32,
437+
kDLCUDA(),
438+
1);
439+
checkCuVSError(
440+
cuvsCagraIndexGetGraph(cagraIndexReference.getMemorySegment(), graphDeviceTensor),
441+
"cuvsCagraIndexGetGraph");
442+
443+
try (var resourceAccess = resources.access()) {
444+
var cuvsRes = resourceAccess.handle();
445+
checkCuVSError(cuvsStreamSync(cuvsRes), "cuvsStreamSync");
446+
447+
checkCuVSError(
448+
cuvsMatrixCopy(cuvsRes, graphDeviceTensor, graphHostTensor), "cuvsMatrixCopy");
449+
450+
checkCuVSError(cuvsStreamSync(cuvsRes), "cuvsStreamSync");
451+
}
452+
453+
return graph;
454+
}
455+
}
456+
457+
private IndexReference fromGraph(
458+
CagraIndexParams.CuvsDistanceType metric,
459+
CuVSMatrixBaseImpl graph,
460+
CuVSMatrixBaseImpl dataset) {
461+
try (var localArena = Arena.ofConfined()) {
462+
long rows = dataset.size();
463+
long cols = dataset.columns();
464+
465+
var index = createCagraIndex();
466+
try (var resourcesAccess = resources.access()) {
467+
long cuvsRes = resourcesAccess.handle();
468+
469+
long[] datasetShape = {rows, cols};
470+
MemorySegment datasetTensor =
471+
prepareTensor(
472+
localArena, dataset.memorySegment(), datasetShape, kDLFloat(), 32, kDLCPU(), 1);
473+
474+
long[] graphShape = {graph.size(), graph.columns()};
475+
MemorySegment graphTensor =
476+
prepareTensor(
477+
localArena, graph.memorySegment(), graphShape, kDLUInt(), 32, kDLCPU(), 1);
478+
479+
checkCuVSError(
480+
cuvsCagraIndexFromArgs(cuvsRes, metric.value, graphTensor, datasetTensor, index),
481+
"cuvsCagraIndexFromArgs");
482+
}
483+
484+
return new IndexReference(index, dataset);
485+
}
486+
}
487+
384488
@Override
385489
public void serializeToHNSW(OutputStream outputStream) throws Throwable {
386490
Path path =
@@ -678,6 +782,7 @@ public static class Builder implements CagraIndex.Builder {
678782
private CagraIndexParams cagraIndexParams;
679783
private final CuVSResources cuvsResources;
680784
private InputStream inputStream;
785+
private CuVSMatrix graph;
681786

682787
public Builder(CuVSResources cuvsResources) {
683788
this.cuvsResources = cuvsResources;
@@ -689,6 +794,12 @@ public Builder from(InputStream inputStream) {
689794
return this;
690795
}
691796

797+
@Override
798+
public Builder from(CuVSMatrix graph) {
799+
this.graph = graph;
800+
return this;
801+
}
802+
692803
@Override
693804
public Builder withDataset(float[][] vectors) {
694805
this.dataset = CuVSMatrix.ofArray(vectors);
@@ -712,7 +823,17 @@ public CagraIndexImpl build() throws Throwable {
712823
if (inputStream != null) {
713824
return new CagraIndexImpl(inputStream, cuvsResources);
714825
} else {
715-
return new CagraIndexImpl(cagraIndexParams, dataset, cuvsResources);
826+
if (graph != null) {
827+
if (cagraIndexParams == null || dataset == null) {
828+
throw new IllegalArgumentException(
829+
"In order to reconstruct a CAGRA index from a graph, "
830+
+ "you must specify the original dataset and the metric used.");
831+
}
832+
return new CagraIndexImpl(
833+
cagraIndexParams.getCuvsDistanceType(), graph, dataset, cuvsResources);
834+
} else {
835+
return new CagraIndexImpl(cagraIndexParams, dataset, cuvsResources);
836+
}
716837
}
717838
}
718839
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSHostMatrixImpl.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@
1818
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_CHAR;
1919
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT;
2020
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT;
21+
import static com.nvidia.cuvs.internal.common.Util.prepareTensor;
22+
import static com.nvidia.cuvs.internal.panama.headers_h.*;
2123

2224
import com.nvidia.cuvs.CuVSHostMatrix;
2325
import com.nvidia.cuvs.RowView;
24-
import java.lang.foreign.MemoryLayout;
25-
import java.lang.foreign.MemorySegment;
26-
import java.lang.foreign.SequenceLayout;
27-
import java.lang.foreign.ValueLayout;
26+
import java.lang.foreign.*;
2827
import java.lang.invoke.VarHandle;
2928

3029
/**
@@ -61,7 +60,7 @@ protected CuVSHostMatrixImpl(
6160
protected static ValueLayout valueLayoutFromType(DataType dataType) {
6261
return switch (dataType) {
6362
case FLOAT -> C_FLOAT;
64-
case INT -> C_INT;
63+
case INT, UINT -> C_INT;
6564
case BYTE -> C_CHAR;
6665
};
6766
}
@@ -131,6 +130,12 @@ public ValueLayout valueLayout() {
131130
return valueLayout;
132131
}
133132

133+
@Override
134+
public MemorySegment toTensor(Arena arena) {
135+
return prepareTensor(
136+
arena, memorySegment, new long[] {size, columns}, code(), bits(), kDLCPU(), 1);
137+
}
138+
134139
private static class SliceRowView implements RowView {
135140
private final MemorySegment memorySegment;
136141
private final long size;

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSMatrixBaseImpl.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
*/
1616
package com.nvidia.cuvs.internal;
1717

18+
import static com.nvidia.cuvs.internal.panama.headers_h.*;
19+
1820
import com.nvidia.cuvs.CuVSMatrix;
21+
import java.lang.foreign.Arena;
1922
import java.lang.foreign.MemorySegment;
2023

2124
public abstract class CuVSMatrixBaseImpl implements CuVSMatrix {
@@ -45,4 +48,21 @@ public long columns() {
4548
public MemorySegment memorySegment() {
4649
return memorySegment;
4750
}
51+
52+
protected int bits() {
53+
return switch (dataType) {
54+
case FLOAT, INT, UINT -> 32;
55+
case BYTE -> 8;
56+
};
57+
}
58+
59+
protected int code() {
60+
return switch (dataType) {
61+
case FLOAT -> kDLFloat();
62+
case INT -> kDLInt();
63+
case UINT, BYTE -> kDLUInt();
64+
};
65+
}
66+
67+
public abstract MemorySegment toTensor(Arena arena);
4868
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,7 @@
1919
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT;
2020
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT;
2121
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG;
22-
import static com.nvidia.cuvs.internal.panama.headers_h.cudaGetDeviceCount;
23-
import static com.nvidia.cuvs.internal.panama.headers_h.cudaGetDeviceProperties_v2;
24-
import static com.nvidia.cuvs.internal.panama.headers_h.cudaMemGetInfo;
25-
import static com.nvidia.cuvs.internal.panama.headers_h.cudaSetDevice;
26-
import static com.nvidia.cuvs.internal.panama.headers_h.size_t;
22+
import static com.nvidia.cuvs.internal.panama.headers_h.*;
2723

2824
import com.nvidia.cuvs.GPUInfo;
2925
import com.nvidia.cuvs.internal.panama.DLDataType;
@@ -73,6 +69,13 @@ public static void checkCudaError(int value, String caller) {
7369
}
7470
}
7571

72+
private static final long UNSIGNED_INT_MASK = 0xFFFFFFFFL;
73+
74+
public static long dereferenceUnsignedInt(MemorySegment ptr) {
75+
assert ptr.byteSize() == 4;
76+
return ptr.get(uint32_t, 0) & UNSIGNED_INT_MASK;
77+
}
78+
7679
/**
7780
* Java analog to CUDA's cudaMemcpyKind, used for cudaMemcpy() calls.
7881
* @see <a href="https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html">CUDA Runtime API</a>

java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,48 @@ private Path createSerializedIndex(CuVSMatrix dataset) throws Throwable {
290290
}
291291
}
292292

293+
@Test
294+
public void testReconstructIndexFromGraph() throws Throwable {
295+
try (var dataset = CuVSMatrix.ofArray(createSampleData())) {
296+
var queries = createSampleQueries();
297+
List<Map<Integer, Float>> expectedResults = getExpectedResults();
298+
299+
try (CuVSResources resources = CuVSResources.create()) {
300+
var index = indexOnce(dataset, resources);
301+
var graph = index.getGraph();
302+
303+
var reconstructedIndex =
304+
CagraIndex.newBuilder(resources)
305+
.from(graph)
306+
.withDataset(dataset)
307+
.withIndexParams(
308+
new CagraIndexParams.Builder().withMetric(CuvsDistanceType.L2Expanded).build())
309+
.build();
310+
queryAndCompare(
311+
index,
312+
reconstructedIndex,
313+
SearchResults.IDENTITY_MAPPING,
314+
queries,
315+
expectedResults,
316+
resources);
317+
318+
var originalIndexPath = serializeOnce(index);
319+
var reconstructedIndexPath = serializeOnce(reconstructedIndex);
320+
321+
var originalBytes = Files.readAllBytes(originalIndexPath);
322+
var reconstructedBytes = Files.readAllBytes(reconstructedIndexPath);
323+
324+
assertArrayEquals(originalBytes, reconstructedBytes);
325+
326+
index.destroyIndex();
327+
reconstructedIndex.destroyIndex();
328+
329+
Files.deleteIfExists(originalIndexPath);
330+
Files.deleteIfExists(reconstructedIndexPath);
331+
}
332+
}
333+
}
334+
293335
@Test
294336
public void testIndexingAndSearchingFlowWithCustomMappingFunction() throws Throwable {
295337
var dataset = CuVSMatrix.ofArray(createSampleData());

0 commit comments

Comments
 (0)