Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
package com.nvidia.cuvs;

import java.util.Arrays;
import java.util.BitSet;
import java.util.Objects;
import java.util.function.LongToIntFunction;
Expand All @@ -34,7 +33,7 @@ public class CagraQuery {

private final CagraSearchParams cagraSearchParameters;
private final LongToIntFunction mapping;
private final float[][] queryVectors;
private final CuVSMatrix queryVectors;
private final int topK;
private final BitSet prefilter;
private final int numDocs;
Expand All @@ -53,9 +52,9 @@ public class CagraQuery {
* @param numDocs Total number of dataset vectors; used to align the prefilter correctly
* @param resources CuVSResources instance to use for this query
*/
public CagraQuery(
private CagraQuery(
CagraSearchParams cagraSearchParameters,
float[][] queryVectors,
CuVSMatrix queryVectors,
LongToIntFunction mapping,
int topK,
BitSet prefilter,
Expand All @@ -81,11 +80,9 @@ public CagraSearchParams getCagraSearchParameters() {
}

/**
* Gets the query vector 2D float array.
*
* @return 2D float array
* Gets the query vector matrix.
*/
public float[][] getQueryVectors() {
public CuVSMatrix getQueryVectors() {
return queryVectors;
}

Expand Down Expand Up @@ -137,7 +134,7 @@ public String toString() {
return "CuVSQuery [cagraSearchParameters="
+ cagraSearchParameters
+ ", queryVectors="
+ Arrays.toString(queryVectors)
+ queryVectors.toString()
+ ", mapping="
+ mapping
+ ", topK="
Expand All @@ -151,7 +148,7 @@ public String toString() {
public static class Builder {

private CagraSearchParams cagraSearchParams;
private float[][] queryVectors;
private CuVSMatrix queryVectors;
private LongToIntFunction mapping = SearchResults.IDENTITY_MAPPING;
private int topK = 2;
private BitSet prefilter;
Expand Down Expand Up @@ -186,10 +183,10 @@ public Builder withSearchParams(CagraSearchParams cagraSearchParams) {
/**
* Registers the query vectors to be passed in the search call.
*
* @param queryVectors 2D float query vector array
* @param queryVectors 2D query vector array
* @return an instance of this Builder
*/
public Builder withQueryVectors(float[][] queryVectors) {
public Builder withQueryVectors(CuVSMatrix queryVectors) {
this.queryVectors = queryVectors;
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@

import static com.nvidia.cuvs.internal.CuVSParamsHelper.*;
import static com.nvidia.cuvs.internal.common.CloseableRMMAllocation.allocateRMMSegment;
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT;
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT_BYTE_SIZE;
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT;
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT_BYTE_SIZE;
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.DEVICE_TO_HOST;
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.HOST_TO_DEVICE;
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.INFER_DIRECTION;
import static com.nvidia.cuvs.internal.common.Util.buildMemorySegment;
import static com.nvidia.cuvs.internal.common.Util.checkCuVSError;
import static com.nvidia.cuvs.internal.common.Util.concatenate;
import static com.nvidia.cuvs.internal.common.Util.cudaMemcpy;
import static com.nvidia.cuvs.internal.common.Util.prepareTensor;
import static com.nvidia.cuvs.internal.panama.headers_h.*;

Expand All @@ -35,6 +32,7 @@
import com.nvidia.cuvs.internal.common.CloseableHandle;
import com.nvidia.cuvs.internal.common.CloseableRMMAllocation;
import com.nvidia.cuvs.internal.common.CompositeCloseableHandle;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.*;
import java.io.FileInputStream;
import java.io.InputStream;
Expand Down Expand Up @@ -232,19 +230,18 @@ public SearchResults search(CagraQuery query) throws Throwable {
try (var localArena = Arena.ofConfined()) {
checkNotDestroyed();
int topK = query.getTopK();
long numQueries = query.getQueryVectors().length;
var queryVectors = (CuVSMatrixInternal) query.getQueryVectors();
long numQueries = queryVectors.size();
long numBlocks = topK * numQueries;
int vectorDimension = numQueries > 0 ? query.getQueryVectors()[0].length : 0;

SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_INT);
SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_FLOAT);
SequenceLayout distancesSequenceLayout =
MemoryLayout.sequenceLayout(numBlocks, queryVectors.valueLayout());
MemorySegment neighborsMemorySegment = localArena.allocate(neighborsSequenceLayout);
MemorySegment distancesMemorySegment = localArena.allocate(distancesSequenceLayout);
MemorySegment floatsSeg = buildMemorySegment(localArena, query.getQueryVectors());

final long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension;
final long neighborsBytes = C_INT_BYTE_SIZE * numQueries * topK;
final long distancesBytes = C_FLOAT_BYTE_SIZE * numQueries * topK;
final long distancesBytes = queryVectors.valueLayout().byteSize() * numQueries * topK;
final boolean hasPreFilter = query.getPrefilter() != null;
final BitSet[] prefilters =
hasPreFilter ? new BitSet[] {query.getPrefilter()} : EMPTY_PREFILTER_BITSET;
Expand All @@ -254,80 +251,90 @@ public SearchResults search(CagraQuery query) throws Throwable {

try (var resourcesAccessor = query.getResources().access()) {
var cuvsRes = resourcesAccessor.handle();
var cuvsStream = Util.getStream(cuvsRes);

try (var queriesDP = allocateRMMSegment(cuvsRes, queriesBytes);
try (var deviceQueryVectors =
(CuVSMatrixInternal) queryVectors.toDevice(query.getResources());
var neighborsDP = allocateRMMSegment(cuvsRes, neighborsBytes);
var distancesDP = allocateRMMSegment(cuvsRes, distancesBytes);
var prefilterDP =
hasPreFilter
? allocateRMMSegment(cuvsRes, prefilterBytes)
: CloseableRMMAllocation.EMPTY) {

cudaMemcpy(queriesDP.handle(), floatsSeg, queriesBytes, INFER_DIRECTION);

long[] queriesShape = {numQueries, vectorDimension};
MemorySegment queriesTensor =
prepareTensor(
localArena, queriesDP.handle(), queriesShape, kDLFloat(), 32, kDLCUDA());
var queryTensor = deviceQueryVectors.toTensor(localArena);
long[] neighborsShape = {numQueries, topK};
MemorySegment neighborsTensor =
prepareTensor(
localArena, neighborsDP.handle(), neighborsShape, kDLUInt(), 32, kDLCUDA());
long[] distancesShape = {numQueries, topK};
MemorySegment distancesTensor =
prepareTensor(
localArena, distancesDP.handle(), distancesShape, kDLFloat(), 32, kDLCUDA());

var returnValue = cuvsStreamSync(cuvsRes);
checkCuVSError(returnValue, "cuvsStreamSync");
localArena,
distancesDP.handle(),
distancesShape,
deviceQueryVectors.code(),
deviceQueryVectors.bits(),
kDLCUDA());

// prepare the prefiltering data
MemorySegment prefilterDataMemorySegment = MemorySegment.NULL;
if (hasPreFilter) {
BitSet concatenatedFilters = concatenate(prefilters, query.getNumDocs());
long[] filters = concatenatedFilters.toLongArray();
prefilterDataMemorySegment = buildMemorySegment(localArena, filters);
}

MemorySegment prefilter = cuvsFilter.allocate(localArena);
MemorySegment prefilterTensor;

if (!hasPreFilter) {
cuvsFilter.type(prefilter, 0); // NO_FILTER
cuvsFilter.addr(prefilter, 0);
} else {
BitSet concatenatedFilters = concatenate(prefilters, query.getNumDocs());
long[] filters = concatenatedFilters.toLongArray();
var prefilterDataMemorySegment = buildMemorySegment(localArena, filters);

long[] prefilterShape = {prefilterLen};

cudaMemcpy(
prefilterDP.handle(), prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE);
Util.cudaMemcpyAsync(
prefilterDP.handle(),
prefilterDataMemorySegment,
prefilterBytes,
HOST_TO_DEVICE,
cuvsStream);

prefilterTensor =
MemorySegment prefilterTensor =
prepareTensor(
localArena, prefilterDP.handle(), prefilterShape, kDLUInt(), 32, kDLCUDA());

cuvsFilter.type(prefilter, 1);
cuvsFilter.addr(prefilter, prefilterTensor.address());
}

returnValue = cuvsStreamSync(cuvsRes);
checkCuVSError(returnValue, "cuvsStreamSync");
// TODO: do we need this stream sync here?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need to do this sync here - since cuvsCagraSearch is operating on the same stream

checkCuVSError(cuvsStreamSync(cuvsRes), "cuvsStreamSync");

returnValue =
var returnValue =
cuvsCagraSearch(
cuvsRes,
segmentFromSearchParams(localArena, query.getCagraSearchParameters()),
cagraIndexReference.getMemorySegment(),
queriesTensor,
queryTensor,
neighborsTensor,
distancesTensor,
prefilter);
checkCuVSError(returnValue, "cuvsCagraSearch");

returnValue = cuvsStreamSync(cuvsRes);
checkCuVSError(returnValue, "cuvsStreamSync");

cudaMemcpy(neighborsMemorySegment, neighborsDP.handle(), neighborsBytes, INFER_DIRECTION);
cudaMemcpy(distancesMemorySegment, distancesDP.handle(), distancesBytes, INFER_DIRECTION);
// TODO: we can avoid/defer this using CuVSDeviceMatrix for neighborsDP and distancesDP
// TODO: also, should we use cuvsMatrixCopy instead?
Util.cudaMemcpyAsync(
neighborsMemorySegment,
neighborsDP.handle(),
neighborsBytes,
DEVICE_TO_HOST,
cuvsStream);
Util.cudaMemcpyAsync(
distancesMemorySegment,
distancesDP.handle(),
distancesBytes,
DEVICE_TO_HOST,
cuvsStream);

checkCuVSError(cuvsStreamSync(cuvsRes), "cuvsStreamSync");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,16 +291,6 @@ public ValueLayout valueLayout() {
return deviceMatrix.valueLayout();
}

@Override
public int bits() {
return deviceMatrix.bits();
}

@Override
public int code() {
return 0;
}

@Override
public MemorySegment toTensor(Arena arena) {
return deviceMatrix.toTensor(arena);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ default int bits() {
* DLTensor data type {@code code} for the element type of this matrix
*/
default int code() {
return switch (dataType()) {
return code(dataType());
}

static int code(DataType dataType) {
return switch (dataType) {
case FLOAT -> kDLFloat();
case INT -> kDLInt();
case UINT, BYTE -> kDLUInt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,18 @@ public static void cudaMemcpyAsync(
* Helper to get the CUDA stream associated with a {@link CuVSResources}
*/
public static MemorySegment getStream(CuVSResources resources) {
try (var resourcesAccess = resources.access();
var localArena = Arena.ofConfined()) {
try (var resourcesAccess = resources.access()) {
return getStream(resourcesAccess.handle());
}
}

/**
* Helper to get the CUDA stream associated with a {@link CuVSResources} handle
*/
public static MemorySegment getStream(long resourcesHandle) {
try (var localArena = Arena.ofConfined()) {
var streamPointer = localArena.allocate(cudaStream_t);
checkCuVSError(cuvsStreamGet(resourcesAccess.handle(), streamPointer), "cuvsStreamGet");
checkCuVSError(cuvsStreamGet(resourcesHandle, streamPointer), "cuvsStreamGet");
return streamPointer.get(cudaStream_t, 0);
}
}
Expand Down
Loading