Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,31 @@
public interface CuVSResources extends AutoCloseable {

/**
* Closes this resources and releases any resources associated with it.
* Provide scoped access to the native resources object.
*/
interface ScopedAccess extends AutoCloseable {
/**
* Gets the opaque CuVSResources handle, to be used whenever we need to pass a cuvsResources_t parameter
*
* @return the CuVSResources handle
*/
long handle();

@Override
void close();
}

/**
* Gets scoped access to the native resources object.
* The native resource object is not thread safe: only a single thread at every time should access
* concurrently the same native resources. Calling this method from multiple thread is OK, but the
* returned {@link ScopedAccess} object must be closed before calling {@code access()} again from a
* different thread.
*/
ScopedAccess access();

/**
* Closes this CuVSResources object and releases any resources associated with it.
*/
@Override
void close();
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class CuVSResourcesImpl implements CuVSResources {

private final Path tempDirectory;
private final long resourceHandle;
private boolean destroyed;
private final ScopedAccess access;

/**
* Constructor that allocates the resources needed for cuVS
Expand All @@ -45,38 +45,35 @@ public CuVSResourcesImpl(Path tempDirectory) {
var resourcesMemorySegment = localArena.allocate(cuvsResources_t);
int returnValue = cuvsResourcesCreate(resourcesMemorySegment);
checkCuVSError(returnValue, "cuvsResourcesCreate");
resourceHandle = resourcesMemorySegment.get(cuvsResources_t, 0);
this.resourceHandle = resourcesMemorySegment.get(cuvsResources_t, 0);
this.access =
new ScopedAccess() {
@Override
public long handle() {
return resourceHandle;
}

@Override
public void close() {}
};
}
}

@Override
public ScopedAccess access() {
return this.access;
}

@Override
public void close() {
synchronized (this) {
checkNotDestroyed();
int returnValue = cuvsResourcesDestroy(resourceHandle);
checkCuVSError(returnValue, "cuvsResourcesDestroy");
destroyed = true;
}
}

@Override
public Path tempDirectory() {
return tempDirectory;
}

private void checkNotDestroyed() {
if (destroyed) {
throw new IllegalStateException("destroyed");
}
}

/**
* Gets the opaque CuVSResources handle, to be used whenever we need to pass a cuvsResources_t parameter
*
* @return the CuVSResources handle
*/
long getHandle() {
checkNotDestroyed();
return resourceHandle;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@
*/
public class HnswIndexImpl implements HnswIndex {

private final CuVSResourcesImpl resources;
private final CuVSResources resources;
private final HnswIndexParams hnswIndexParams;
private final IndexReference hnswIndexReference;

/**
* Constructor for loading the index from an {@link InputStream}
*
* @param inputStream an instance of stream to read the index bytes from
* @param resources an instance of {@link CuVSResourcesImpl}
* @param resources an instance of {@link CuVSResources}
*/
private HnswIndexImpl(
InputStream inputStream, CuVSResourcesImpl resources, HnswIndexParams hnswIndexParams)
InputStream inputStream, CuVSResources resources, HnswIndexParams hnswIndexParams)
throws Throwable {
this.hnswIndexParams = hnswIndexParams;
this.resources = resources;
Expand Down Expand Up @@ -105,8 +105,6 @@ public SearchResults search(HnswQuery query) throws Throwable {
MemorySegment distancesMemorySegment = localArena.allocate(distancesSequenceLayout);
MemorySegment querySeg = buildMemorySegment(localArena, queryVectors);

long cuvsRes = resources.getHandle();

long[] queriesShape = {numQueries, vectorDimension};
MemorySegment queriesTensor =
prepareTensor(localArena, querySeg, queriesShape, 2, 32, 2, 1, 1);
Expand All @@ -117,21 +115,24 @@ public SearchResults search(HnswQuery query) throws Throwable {
MemorySegment distancesTensor =
prepareTensor(localArena, distancesMemorySegment, distancesShape, 2, 32, 2, 1, 1);

int returnValue = cuvsStreamSync(cuvsRes);
checkCuVSError(returnValue, "cuvsStreamSync");

returnValue =
cuvsHnswSearch(
cuvsRes,
segmentFromSearchParams(localArena, query.getHnswSearchParams()),
hnswIndexReference.getMemorySegment(),
queriesTensor,
neighborsTensor,
distancesTensor);
checkCuVSError(returnValue, "cuvsHnswSearch");

returnValue = cuvsStreamSync(cuvsRes);
checkCuVSError(returnValue, "cuvsStreamSync");
try (var resourcesAccessor = resources.access()) {
var cuvsRes = resourcesAccessor.handle();
int returnValue = cuvsStreamSync(cuvsRes);
checkCuVSError(returnValue, "cuvsStreamSync");

returnValue =
cuvsHnswSearch(
cuvsRes,
segmentFromSearchParams(localArena, query.getHnswSearchParams()),
hnswIndexReference.getMemorySegment(),
queriesTensor,
neighborsTensor,
distancesTensor);
checkCuVSError(returnValue, "cuvsHnswSearch");

returnValue = cuvsStreamSync(cuvsRes);
checkCuVSError(returnValue, "cuvsStreamSync");
}

return HnswSearchResults.create(
neighborsSequenceLayout,
Expand Down Expand Up @@ -178,8 +179,6 @@ private IndexReference deserialize(InputStream inputStream) throws Throwable {
inputStream.transferTo(outputStream);
MemorySegment pathSeg = buildMemorySegment(localArena, tmpIndexFile.toString());

long cuvsRes = resources.getHandle();

var indexReference = createHnswIndex();

MemorySegment dtype = DLDataType.allocate(localArena);
Expand All @@ -189,15 +188,18 @@ private IndexReference deserialize(InputStream inputStream) throws Throwable {

cuvsHnswIndex.dtype(indexReference.memorySegment, dtype);

var returnValue =
cuvsHnswDeserialize(
cuvsRes,
segmentFromIndexParams(localArena, hnswIndexParams),
pathSeg,
hnswIndexParams.getVectorDimension(),
0,
indexReference.memorySegment);
checkCuVSError(returnValue, "cuvsHnswDeserialize");
try (var resourcesAccessor = resources.access()) {
var cuvsRes = resourcesAccessor.handle();
var returnValue =
cuvsHnswDeserialize(
cuvsRes,
segmentFromIndexParams(localArena, hnswIndexParams),
pathSeg,
hnswIndexParams.getVectorDimension(),
0,
indexReference.memorySegment);
checkCuVSError(returnValue, "cuvsHnswDeserialize");
}

return indexReference;

Expand Down Expand Up @@ -227,19 +229,15 @@ private static MemorySegment segmentFromSearchParams(Arena arena, HnswSearchPara
}

public static HnswIndex.Builder newBuilder(CuVSResources cuvsResources) {
Objects.requireNonNull(cuvsResources);
if (!(cuvsResources instanceof CuVSResourcesImpl)) {
throw new IllegalArgumentException("Unsupported " + cuvsResources);
}
return new HnswIndexImpl.Builder((CuVSResourcesImpl) cuvsResources);
return new HnswIndexImpl.Builder(Objects.requireNonNull(cuvsResources));
}

/**
* Builder helps configure and create an instance of {@link HnswIndex}.
*/
public static class Builder implements HnswIndex.Builder {

private final CuVSResourcesImpl cuvsResources;
private final CuVSResources cuvsResources;
private InputStream inputStream;
private HnswIndexParams hnswIndexParams;

Expand All @@ -248,7 +246,7 @@ public static class Builder implements HnswIndex.Builder {
*
* @param cuvsResources an instance of {@link CuVSResources}
*/
public Builder(CuVSResourcesImpl cuvsResources) {
public Builder(CuVSResources cuvsResources) {
this.cuvsResources = cuvsResources;
}

Expand Down
Loading