1616
1717package com .nvidia .cuvs .internal ;
1818
19- import static java .lang .foreign .ValueLayout .ADDRESS ;
2019import static com .nvidia .cuvs .internal .common .LinkerHelper .C_FLOAT ;
2120import static com .nvidia .cuvs .internal .common .LinkerHelper .C_FLOAT_BYTE_SIZE ;
2221import static com .nvidia .cuvs .internal .common .LinkerHelper .C_INT ;
2322import static com .nvidia .cuvs .internal .common .LinkerHelper .C_INT_BYTE_SIZE ;
2423import static com .nvidia .cuvs .internal .common .LinkerHelper .C_POINTER ;
25- import static com .nvidia .cuvs .internal .common .LinkerHelper .C_LONG ;
2624import static com .nvidia .cuvs .internal .common .Util .buildMemorySegment ;
2725import static com .nvidia .cuvs .internal .common .Util .checkCuVSError ;
28- import static com .nvidia .cuvs .internal .common .Util .checkCudaError ;
2926import static com .nvidia .cuvs .internal .common .Util .concatenate ;
27+ import static com .nvidia .cuvs .internal .common .Util .cudaMemcpy ;
28+ import static com .nvidia .cuvs .internal .common .Util .CudaMemcpyKind .*;
3029import static com .nvidia .cuvs .internal .common .Util .prepareTensor ;
3130import static com .nvidia .cuvs .internal .panama .headers_h .cuvsCagraBuild ;
3231import static com .nvidia .cuvs .internal .panama .headers_h .cuvsCagraDeserialize ;
4544import static com .nvidia .cuvs .internal .panama .headers_h .cuvsStreamGet ;
4645import static com .nvidia .cuvs .internal .panama .headers_h .cuvsStreamSync ;
4746import static com .nvidia .cuvs .internal .panama .headers_h .omp_set_num_threads ;
48- import static com .nvidia .cuvs .internal .panama .headers_h .cudaMemcpy ;
4947import static com .nvidia .cuvs .internal .panama .headers_h .cudaStream_t ;
5048
5149import java .io .FileInputStream ;
5250import java .io .FileOutputStream ;
5351import java .io .InputStream ;
5452import java .io .OutputStream ;
5553import java .lang .foreign .Arena ;
56- import java .lang .foreign .FunctionDescriptor ;
5754import java .lang .foreign .MemoryLayout ;
5855import java .lang .foreign .MemorySegment ;
5956import java .lang .foreign .SequenceLayout ;
6057import java .lang .foreign .ValueLayout ;
61- import java .lang .invoke .MethodHandle ;
6258import java .nio .file .Files ;
6359import java .nio .file .Path ;
6460import java .util .Objects ;
@@ -287,8 +283,7 @@ public SearchResults search(CagraQuery query) throws Throwable {
287283 MemorySegment prefilterDP = MemorySegment .NULL ;
288284 long prefilterLen = 0 ;
289285
290- returnValue = cudaMemcpy (queriesDP , floatsSeg , queriesBytes , 4 );
291- checkCudaError (returnValue , "cudaMemcpy" );
286+ cudaMemcpy (queriesDP , floatsSeg , queriesBytes , INFER_DIRECTION );
292287
293288 long queriesShape [] = { numQueries , vectorDimension };
294289 MemorySegment queriesTensor = prepareTensor (arena , queriesDP , queriesShape , 2 , 32 , 2 , 2 , 1 );
@@ -329,8 +324,7 @@ public SearchResults search(CagraQuery query) throws Throwable {
329324
330325 prefilterDP = prefilterD .get (C_POINTER , 0 );
331326
332- returnValue = cudaMemcpy (prefilterDP , prefilterDataMemorySegment , prefilterBytes , 1 );
333- checkCudaError (returnValue , "cudaMemcpy" );
327+ cudaMemcpy (prefilterDP , prefilterDataMemorySegment , prefilterBytes , HOST_TO_DEVICE );
334328
335329 prefilterTensor = prepareTensor (arena , prefilterDP , prefilterShape , 1 , 32 , 1 , 2 , 1 );
336330
@@ -348,10 +342,8 @@ public SearchResults search(CagraQuery query) throws Throwable {
348342 returnValue = cuvsStreamSync (cuvsRes );
349343 checkCuVSError (returnValue , "cuvsStreamSync" );
350344
351- returnValue = cudaMemcpy (neighborsMemorySegment , neighborsDP , neighborsBytes , 4 );
352- checkCudaError (returnValue , "cudaMemcpy" );
353- returnValue = cudaMemcpy (distancesMemorySegment , distancesDP , distancesBytes , 4 );
354- checkCudaError (returnValue , "cudaMemcpy" );
345+ cudaMemcpy (neighborsMemorySegment , neighborsDP , neighborsBytes , INFER_DIRECTION );
346+ cudaMemcpy (distancesMemorySegment , distancesDP , distancesBytes , INFER_DIRECTION );
355347
356348 returnValue = cuvsRMMFree (cuvsRes , distancesDP , distancesBytes );
357349 checkCuVSError (returnValue , "cuvsRMMFree" );
0 commit comments