Skip to content

Commit 5ed1ada

Browse files
ldematteenp1s0
authored andcommitted
[Review][Java] Refactor: extract interface from CuVSMatrixBaseImpl (rapidsai#1361)
This PR extract an internal interface that is used as a base to implement all internal CuVSMatrix types; the interface introduces commonly used field accessors like `memorySegment()` and `toTensor()` that we do not want/could not appear on the public interface (e.g. because they expose or require Panama types or internal types). The new interface is implemented by all concrete matrix types, closing a gap that we had in rapidsai#1328 (which I realized while working on separate PRs like rapidsai#1283) Follow-up of rapidsai#1328 Authors: - Lorenzo Dematté (https://github.com/ldematte) Approvers: - Ben Frederickson (https://github.com/benfred) URL: rapidsai#1361
1 parent e402535 commit 5ed1ada

File tree

9 files changed

+267
-180
lines changed

9 files changed

+267
-180
lines changed

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSDeviceMatrix.java

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -30,67 +30,4 @@ default CuVSHostMatrix toHost() {
3030
toHost(hostMatrix);
3131
return hostMatrix;
3232
}
33-
34-
default CuVSDeviceMatrix toDevice(CuVSResources resources) {
35-
return new CuVSDeviceMatrixDelegate(this);
36-
}
37-
38-
class CuVSDeviceMatrixDelegate implements CuVSDeviceMatrix {
39-
40-
private final CuVSDeviceMatrix deviceMatrix;
41-
42-
private CuVSDeviceMatrixDelegate(CuVSDeviceMatrix deviceMatrix) {
43-
this.deviceMatrix = deviceMatrix;
44-
}
45-
46-
@Override
47-
public long size() {
48-
return deviceMatrix.size();
49-
}
50-
51-
@Override
52-
public long columns() {
53-
return deviceMatrix.columns();
54-
}
55-
56-
@Override
57-
public DataType dataType() {
58-
return deviceMatrix.dataType();
59-
}
60-
61-
@Override
62-
public RowView getRow(long row) {
63-
return deviceMatrix.getRow(row);
64-
}
65-
66-
@Override
67-
public void toArray(int[][] array) {
68-
deviceMatrix.toArray(array);
69-
}
70-
71-
@Override
72-
public void toArray(float[][] array) {
73-
deviceMatrix.toArray(array);
74-
}
75-
76-
@Override
77-
public void toArray(byte[][] array) {
78-
deviceMatrix.toArray(array);
79-
}
80-
81-
@Override
82-
public void toHost(CuVSHostMatrix hostMatrix) {
83-
deviceMatrix.toHost(hostMatrix);
84-
}
85-
86-
@Override
87-
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
88-
this.deviceMatrix.toDevice(deviceMatrix, cuVSResources);
89-
}
90-
91-
@Override
92-
public void close() {
93-
// Do nothing
94-
}
95-
}
9633
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSHostMatrix.java

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -21,76 +21,9 @@
2121
public interface CuVSHostMatrix extends CuVSMatrix {
2222
int get(int row, int col);
2323

24-
default CuVSHostMatrix toHost() {
25-
return new CuVSHostMatrixDelegate(this);
26-
}
27-
2824
default CuVSDeviceMatrix toDevice(CuVSResources resources) {
2925
var deviceMatrix = CuVSMatrix.deviceBuilder(resources, size(), columns(), dataType()).build();
3026
toDevice(deviceMatrix, resources);
3127
return deviceMatrix;
3228
}
33-
34-
class CuVSHostMatrixDelegate implements CuVSHostMatrix {
35-
private final CuVSHostMatrix hostMatrix;
36-
37-
public CuVSHostMatrixDelegate(CuVSHostMatrix cuVSHostMatrix) {
38-
this.hostMatrix = cuVSHostMatrix;
39-
}
40-
41-
@Override
42-
public int get(int row, int col) {
43-
return hostMatrix.get(row, col);
44-
}
45-
46-
@Override
47-
public long size() {
48-
return hostMatrix.size();
49-
}
50-
51-
@Override
52-
public long columns() {
53-
return hostMatrix.columns();
54-
}
55-
56-
@Override
57-
public DataType dataType() {
58-
return hostMatrix.dataType();
59-
}
60-
61-
@Override
62-
public RowView getRow(long row) {
63-
return hostMatrix.getRow(row);
64-
}
65-
66-
@Override
67-
public void toArray(int[][] array) {
68-
hostMatrix.toArray(array);
69-
}
70-
71-
@Override
72-
public void toArray(float[][] array) {
73-
hostMatrix.toArray(array);
74-
}
75-
76-
@Override
77-
public void toArray(byte[][] array) {
78-
hostMatrix.toArray(array);
79-
}
80-
81-
@Override
82-
public void toHost(CuVSHostMatrix hostMatrix) {
83-
this.hostMatrix.toHost(hostMatrix);
84-
}
85-
86-
@Override
87-
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
88-
hostMatrix.toDevice(deviceMatrix, cuVSResources);
89-
}
90-
91-
@Override
92-
public void close() {
93-
// Do nothing
94-
}
95-
}
9629
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ private BruteForceIndexImpl(
7878
Objects.requireNonNull(dataset);
7979
try (dataset) {
8080
this.resources = resources;
81-
assert dataset instanceof CuVSMatrixBaseImpl;
82-
this.bruteForceIndexReference = build((CuVSMatrixBaseImpl) dataset, bruteForceIndexParams);
81+
assert dataset instanceof CuVSMatrixInternal;
82+
this.bruteForceIndexReference = build((CuVSMatrixInternal) dataset, bruteForceIndexParams);
8383
}
8484
}
8585

@@ -124,7 +124,7 @@ public void close() {
124124
* index
125125
*/
126126
private IndexReference build(
127-
CuVSMatrixBaseImpl dataset, BruteForceIndexParams bruteForceIndexParams) {
127+
CuVSMatrixInternal dataset, BruteForceIndexParams bruteForceIndexParams) {
128128
long rows = dataset.size();
129129
long cols = dataset.columns();
130130

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import com.nvidia.cuvs.internal.common.CloseableRMMAllocation;
3737
import com.nvidia.cuvs.internal.common.CompositeCloseableHandle;
3838
import com.nvidia.cuvs.internal.panama.*;
39-
4039
import java.io.FileInputStream;
4140
import java.io.InputStream;
4241
import java.io.OutputStream;
@@ -78,8 +77,8 @@ private CagraIndexImpl(
7877
CagraIndexParams indexParameters, CuVSMatrix dataset, CuVSResources resources) {
7978
Objects.requireNonNull(dataset);
8079
this.resources = resources;
81-
assert dataset instanceof CuVSMatrixBaseImpl;
82-
this.cagraIndexReference = build(indexParameters, (CuVSMatrixBaseImpl) dataset);
80+
assert dataset instanceof CuVSMatrixInternal;
81+
this.cagraIndexReference = build(indexParameters, (CuVSMatrixInternal) dataset);
8382
}
8483

8584
/**
@@ -124,11 +123,11 @@ private CagraIndexImpl(
124123

125124
this.resources = resources;
126125

127-
assert graph instanceof CuVSMatrixBaseImpl;
128-
assert dataset instanceof CuVSMatrixBaseImpl;
126+
assert graph instanceof CuVSMatrixInternal;
127+
assert dataset instanceof CuVSMatrixInternal;
129128

130129
this.cagraIndexReference =
131-
fromGraph(metric, (CuVSMatrixBaseImpl) graph, (CuVSMatrixBaseImpl) dataset);
130+
fromGraph(metric, (CuVSMatrixInternal) graph, (CuVSMatrixInternal) dataset);
132131
}
133132

134133
private void checkNotDestroyed() {
@@ -161,7 +160,7 @@ public void close() {
161160
* @return an instance of {@link IndexReference} that holds the pointer to the
162161
* index
163162
*/
164-
private IndexReference build(CagraIndexParams indexParameters, CuVSMatrixBaseImpl dataset) {
163+
private IndexReference build(CagraIndexParams indexParameters, CuVSMatrixInternal dataset) {
165164
long rows = dataset.size();
166165

167166
try (var indexParams = segmentFromIndexParams(indexParameters);
@@ -410,8 +409,8 @@ public CuVSDeviceMatrix getGraph() {
410409

411410
private IndexReference fromGraph(
412411
CagraIndexParams.CuvsDistanceType metric,
413-
CuVSMatrixBaseImpl graph,
414-
CuVSMatrixBaseImpl dataset) {
412+
CuVSMatrixInternal graph,
413+
CuVSMatrixInternal dataset) {
415414
try (var localArena = Arena.ofConfined()) {
416415
var index = createCagraIndex();
417416
try (var resourcesAccess = resources.access()) {

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSDeviceMatrixImpl.java

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ public void toHost(CuVSHostMatrix hostMatrix) {
224224
throw new IllegalArgumentException("[hostMatrix] must have the same dataType");
225225
}
226226
try (var localArena = Arena.ofConfined()) {
227-
var hostMatrixTensor = ((CuVSHostMatrixImpl) hostMatrix).toTensor(localArena);
227+
var hostMatrixTensor = ((CuVSMatrixInternal) hostMatrix).toTensor(localArena);
228228

229229
try (var resourceAccess = resources.access()) {
230230
var cuvsRes = resourceAccess.handle();
@@ -236,9 +236,14 @@ public void toHost(CuVSHostMatrix hostMatrix) {
236236
}
237237
}
238238

239+
@Override
240+
public CuVSDeviceMatrix toDevice(CuVSResources resources) {
241+
return new CuVSDeviceMatrixDelegate(this);
242+
}
243+
239244
@Override
240245
public void toDevice(CuVSDeviceMatrix targetMatrix, CuVSResources cuVSResources) {
241-
copyMatrix(this, (CuVSMatrixBaseImpl) targetMatrix, cuVSResources);
246+
copyMatrix(this, (CuVSMatrixInternal) targetMatrix, cuVSResources);
242247
}
243248

244249
@Override
@@ -248,4 +253,92 @@ public void close() {
248253
hostBuffer = MemorySegment.NULL;
249254
}
250255
}
256+
257+
private static class CuVSDeviceMatrixDelegate implements CuVSDeviceMatrix, CuVSMatrixInternal {
258+
private final CuVSDeviceMatrixImpl deviceMatrix;
259+
260+
private CuVSDeviceMatrixDelegate(CuVSDeviceMatrixImpl deviceMatrix) {
261+
this.deviceMatrix = deviceMatrix;
262+
}
263+
264+
@Override
265+
public long size() {
266+
return deviceMatrix.size();
267+
}
268+
269+
@Override
270+
public long columns() {
271+
return deviceMatrix.columns();
272+
}
273+
274+
@Override
275+
public DataType dataType() {
276+
return deviceMatrix.dataType();
277+
}
278+
279+
@Override
280+
public RowView getRow(long row) {
281+
return deviceMatrix.getRow(row);
282+
}
283+
284+
@Override
285+
public void toArray(int[][] array) {
286+
deviceMatrix.toArray(array);
287+
}
288+
289+
@Override
290+
public void toArray(float[][] array) {
291+
deviceMatrix.toArray(array);
292+
}
293+
294+
@Override
295+
public void toArray(byte[][] array) {
296+
deviceMatrix.toArray(array);
297+
}
298+
299+
@Override
300+
public void toHost(CuVSHostMatrix hostMatrix) {
301+
deviceMatrix.toHost(hostMatrix);
302+
}
303+
304+
@Override
305+
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
306+
deviceMatrix.toDevice(deviceMatrix, cuVSResources);
307+
}
308+
309+
@Override
310+
public CuVSDeviceMatrix toDevice(CuVSResources cuVSResources) {
311+
return this;
312+
}
313+
314+
@Override
315+
public MemorySegment memorySegment() {
316+
return deviceMatrix.memorySegment();
317+
}
318+
319+
@Override
320+
public ValueLayout valueLayout() {
321+
return deviceMatrix.valueLayout();
322+
}
323+
324+
@Override
325+
public int bits() {
326+
return deviceMatrix.bits();
327+
}
328+
329+
@Override
330+
public int code() {
331+
return 0;
332+
}
333+
334+
@Override
335+
public MemorySegment toTensor(Arena arena) {
336+
return deviceMatrix.toTensor(arena);
337+
}
338+
339+
@Override
340+
public void close() {
341+
// Do nothing
342+
}
343+
}
251344
}

0 commit comments

Comments
 (0)