Skip to content

Commit 3246c6e

Browse files
committed
refactor: Declarative implementation
1 parent a7ebd41 commit 3246c6e

12 files changed

+149
-239
lines changed

pom.xml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,6 @@
190190
<configuration>
191191
<includeStale>false</includeStale>
192192
<style>GOOGLE</style>
193-
<formatMain>true</formatMain>
194-
<formatTest>true</formatTest>
195193
<filterModified>false</filterModified>
196194
<skip>false</skip>
197195
<fixImports>true</fixImports>

src/main/java/io/qdrant/spark/Qdrant.java

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,46 +8,40 @@
88
import org.apache.spark.sql.types.StructType;
99
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
1010

11-
/** A class that implements the TableProvider and DataSourceRegister interfaces. */
11+
/** Qdrant datasource for Apache Spark. */
1212
public class Qdrant implements TableProvider, DataSourceRegister {
1313

14-
private final String[] requiredFields = new String[] {"schema", "collection_name", "qdrant_url"};
14+
private static final String[] REQUIRED_FIELDS = {"schema", "collection_name", "qdrant_url"};
1515

16-
/**
17-
* Returns the short name of the data source.
18-
*
19-
* @return The short name of the data source.
20-
*/
16+
/** Returns the short name of the data source. */
2117
@Override
2218
public String shortName() {
2319
return "qdrant";
2420
}
2521

2622
/**
27-
* Infers the schema of the data source based on the provided options.
23+
* Validates and infers the schema from the provided options.
2824
*
29-
* @param options The options used to infer the schema.
30-
* @return The inferred schema.
25+
* @throws IllegalArgumentException if required options are missing.
3126
*/
3227
@Override
3328
public StructType inferSchema(CaseInsensitiveStringMap options) {
34-
for (String fieldName : requiredFields) {
35-
if (!options.containsKey(fieldName)) {
36-
throw new IllegalArgumentException(fieldName.concat(" option is required"));
29+
validateOptions(options);
30+
return (StructType) StructType.fromJson(options.get("schema"));
31+
}
32+
33+
private void validateOptions(CaseInsensitiveStringMap options) {
34+
for (String field : REQUIRED_FIELDS) {
35+
if (!options.containsKey(field)) {
36+
throw new IllegalArgumentException(String.format("%s option is required", field));
3737
}
3838
}
39-
StructType schema = (StructType) StructType.fromJson(options.get("schema"));
40-
41-
return schema;
4239
}
4340

4441
/**
45-
* Returns a table for the data source based on the provided schema, partitioning, and properties.
42+
* Creates a Qdrant table instance with validated options.
4643
*
47-
* @param schema The schema of the table.
48-
* @param partitioning The partitioning of the table.
49-
* @param properties The properties of the table.
50-
* @return The table for the data source.
44+
* @throws IllegalArgumentException if options are invalid.
5145
*/
5246
@Override
5347
public Table getTable(

src/main/java/io/qdrant/spark/QdrantBatchWriter.java

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import org.apache.spark.sql.connector.write.WriterCommitMessage;
77
import org.apache.spark.sql.types.StructType;
88

9-
/** QdrantBatchWriter class implements the BatchWrite interface. */
9+
/** Qdrant batch writer for Apache Spark. */
1010
public class QdrantBatchWriter implements BatchWrite {
1111

1212
private final QdrantOptions options;
@@ -23,13 +23,8 @@ public DataWriterFactory createBatchWriterFactory(PhysicalWriteInfo info) {
2323
}
2424

2525
@Override
26-
public void commit(WriterCommitMessage[] messages) {
27-
// TODO Auto-generated method stub
28-
29-
}
26+
public void commit(WriterCommitMessage[] messages) {}
3027

3128
@Override
32-
public void abort(WriterCommitMessage[] messages) {
33-
// TODO Auto-generated method stub
34-
}
29+
public void abort(WriterCommitMessage[] messages) {}
3530
}
Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,21 @@
11
package io.qdrant.spark;
22

3-
import java.util.Arrays;
43
import java.util.Collections;
5-
import java.util.HashSet;
4+
import java.util.EnumSet;
65
import java.util.Set;
76
import org.apache.spark.sql.connector.catalog.SupportsWrite;
87
import org.apache.spark.sql.connector.catalog.TableCapability;
98
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
109
import org.apache.spark.sql.connector.write.WriteBuilder;
1110
import org.apache.spark.sql.types.StructType;
1211

13-
/** QdrantCluster class implements the SupportsWrite interface. */
12+
/** Qdrant cluster implementation supporting batch writes. */
1413
public class QdrantCluster implements SupportsWrite {
1514

1615
private final StructType schema;
1716
private final QdrantOptions options;
1817

19-
private static final Set<TableCapability> TABLE_CAPABILITY_SET =
20-
Collections.unmodifiableSet(
21-
new HashSet<>(
22-
Arrays.asList(TableCapability.BATCH_WRITE, TableCapability.STREAMING_WRITE)));
18+
private static final Set<TableCapability> CAPABILITIES = EnumSet.of(TableCapability.BATCH_WRITE);
2319

2420
public QdrantCluster(QdrantOptions options, StructType schema) {
2521
this.options = options;
@@ -28,7 +24,7 @@ public QdrantCluster(QdrantOptions options, StructType schema) {
2824

2925
@Override
3026
public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
31-
return new QdrantWriteBuilder(this.options, this.schema);
27+
return new QdrantWriteBuilder(options, schema);
3228
}
3329

3430
@Override
@@ -38,11 +34,11 @@ public String name() {
3834

3935
@Override
4036
public StructType schema() {
41-
return this.schema;
37+
return schema;
4238
}
4339

4440
@Override
4541
public Set<TableCapability> capabilities() {
46-
return TABLE_CAPABILITY_SET;
42+
return Collections.unmodifiableSet(CAPABILITIES);
4743
}
4844
}

src/main/java/io/qdrant/spark/QdrantDataWriter.java

Lines changed: 43 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,88 @@
11
package io.qdrant.spark;
22

3-
import io.qdrant.client.grpc.JsonWithInt.Value;
4-
import io.qdrant.client.grpc.Points.PointId;
53
import io.qdrant.client.grpc.Points.PointStruct;
6-
import io.qdrant.client.grpc.Points.Vectors;
74
import java.io.Serializable;
85
import java.net.URL;
96
import java.util.ArrayList;
10-
import java.util.Map;
7+
import java.util.List;
118
import org.apache.spark.sql.catalyst.InternalRow;
129
import org.apache.spark.sql.connector.write.DataWriter;
1310
import org.apache.spark.sql.connector.write.WriterCommitMessage;
1411
import org.apache.spark.sql.types.StructType;
1512
import org.slf4j.Logger;
1613
import org.slf4j.LoggerFactory;
1714

18-
/** A DataWriter implementation that writes data to Qdrant. */
15+
/** DataWriter implementation for writing data to Qdrant. */
1916
public class QdrantDataWriter implements DataWriter<InternalRow>, Serializable {
17+
18+
private static final Logger LOG = LoggerFactory.getLogger(QdrantDataWriter.class);
19+
2020
private final QdrantOptions options;
2121
private final StructType schema;
22-
private final String qdrantUrl;
23-
private final String apiKey;
24-
private final Logger LOG = LoggerFactory.getLogger(QdrantDataWriter.class);
25-
26-
private final ArrayList<PointStruct> points = new ArrayList<>();
22+
private final List<PointStruct> pointsBuffer = new ArrayList<>();
2723

2824
public QdrantDataWriter(QdrantOptions options, StructType schema) {
2925
this.options = options;
3026
this.schema = schema;
31-
this.qdrantUrl = options.qdrantUrl;
32-
this.apiKey = options.apiKey;
3327
}
3428

3529
@Override
3630
public void write(InternalRow record) {
37-
PointStruct.Builder pointBuilder = PointStruct.newBuilder();
38-
39-
PointId pointId = QdrantPointIdHandler.preparePointId(record, this.schema, this.options);
40-
pointBuilder.setId(pointId);
41-
42-
Vectors vectors = QdrantVectorHandler.prepareVectors(record, this.schema, this.options);
43-
pointBuilder.setVectors(vectors);
44-
45-
Map<String, Value> payload =
46-
QdrantPayloadHandler.preparePayload(record, this.schema, this.options);
47-
pointBuilder.putAllPayload(payload);
48-
49-
this.points.add(pointBuilder.build());
31+
PointStruct point = createPointStruct(record);
32+
pointsBuffer.add(point);
5033

51-
if (this.points.size() >= this.options.batchSize) {
52-
this.write(this.options.retries);
34+
if (pointsBuffer.size() >= options.batchSize) {
35+
writeBatch(options.retries);
5336
}
5437
}
5538

56-
@Override
57-
public WriterCommitMessage commit() {
58-
this.write(this.options.retries);
59-
return new WriterCommitMessage() {
60-
@Override
61-
public String toString() {
62-
return "point committed to Qdrant";
63-
}
64-
};
39+
private PointStruct createPointStruct(InternalRow record) {
40+
PointStruct.Builder pointBuilder = PointStruct.newBuilder();
41+
pointBuilder.setId(QdrantPointIdHandler.preparePointId(record, schema, options));
42+
pointBuilder.setVectors(QdrantVectorHandler.prepareVectors(record, schema, options));
43+
pointBuilder.putAllPayload(QdrantPayloadHandler.preparePayload(record, schema, options));
44+
return pointBuilder.build();
6545
}
6646

67-
public void write(int retries) {
68-
LOG.info(
69-
String.join(
70-
"", "Uploading batch of ", Integer.toString(this.points.size()), " points to Qdrant"));
71-
72-
if (this.points.isEmpty()) {
47+
private void writeBatch(int retries) {
48+
if (pointsBuffer.isEmpty()) {
7349
return;
7450
}
51+
7552
try {
76-
// Instantiate a new QdrantGrpc object to maintain serializability
77-
QdrantGrpc qdrant = new QdrantGrpc(new URL(this.qdrantUrl), this.apiKey);
78-
qdrant.upsert(this.options.collectionName, this.points, this.options.shardKeySelector);
79-
qdrant.close();
80-
this.points.clear();
53+
doWriteBatch();
54+
pointsBuffer.clear();
8155
} catch (Exception e) {
82-
LOG.error(String.join("", "Exception while uploading batch to Qdrant: ", e.getMessage()));
56+
LOG.error("Exception while uploading batch to Qdrant: {}", e.getMessage());
8357
if (retries > 0) {
8458
LOG.info("Retrying upload batch to Qdrant");
85-
write(retries - 1);
59+
writeBatch(retries - 1);
8660
} else {
8761
throw new RuntimeException(e);
8862
}
8963
}
9064
}
9165

66+
private void doWriteBatch() throws Exception {
67+
LOG.info("Uploading batch of {} points to Qdrant", pointsBuffer.size());
68+
69+
// Instantiate QdrantGrpc client for each batch to maintain serializability
70+
QdrantGrpc qdrant = new QdrantGrpc(new URL(options.qdrantUrl), options.apiKey);
71+
qdrant.upsert(options.collectionName, pointsBuffer, options.shardKeySelector);
72+
qdrant.close();
73+
}
74+
75+
@Override
76+
public WriterCommitMessage commit() {
77+
writeBatch(options.retries);
78+
return new WriterCommitMessage() {
79+
@Override
80+
public String toString() {
81+
return "point committed to Qdrant";
82+
}
83+
};
84+
}
85+
9286
@Override
9387
public void abort() {}
9488

src/main/java/io/qdrant/spark/QdrantDataWriterFactory.java

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,35 +4,26 @@
44
import org.apache.spark.sql.connector.write.streaming.StreamingDataWriterFactory;
55
import org.apache.spark.sql.types.StructType;
66

7-
/** Factory class for creating QdrantDataWriter instances for Spark Structured Streaming. */
7+
/** Factory class for creating QdrantDataWriter instances for Spark data sources. */
88
public class QdrantDataWriterFactory implements StreamingDataWriterFactory, DataWriterFactory {
9+
910
private final QdrantOptions options;
1011
private final StructType schema;
1112

12-
/**
13-
* Constructor for QdrantDataWriterFactory.
14-
*
15-
* @param options QdrantOptions instance containing configuration options for Qdrant.
16-
* @param schema StructType instance containing schema information for the data being written.
17-
*/
1813
public QdrantDataWriterFactory(QdrantOptions options, StructType schema) {
1914
this.options = options;
2015
this.schema = schema;
2116
}
2217

2318
@Override
2419
public QdrantDataWriter createWriter(int partitionId, long taskId, long epochId) {
25-
try {
26-
return new QdrantDataWriter(this.options, this.schema);
27-
} catch (Exception e) {
28-
throw new RuntimeException(e);
29-
}
20+
return createWriter(partitionId, taskId);
3021
}
3122

3223
@Override
3324
public QdrantDataWriter createWriter(int partitionId, long taskId) {
3425
try {
35-
return new QdrantDataWriter(this.options, this.schema);
26+
return new QdrantDataWriter(options, schema);
3627
} catch (Exception e) {
3728
throw new RuntimeException(e);
3829
}

src/main/java/io/qdrant/spark/QdrantGrpc.java

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,54 +10,33 @@
1010
import java.net.URL;
1111
import java.util.List;
1212
import java.util.concurrent.ExecutionException;
13-
import javax.annotation.Nullable;
1413

15-
/** A class that provides methods to interact with Qdrant GRPC API. */
14+
/** Client for interacting with the Qdrant GRPC API. */
1615
public class QdrantGrpc implements Serializable {
16+
1717
private final QdrantClient client;
1818

19-
/**
20-
* Constructor for QdrantRest class.
21-
*
22-
* @param url The URL of the Qdrant instance.
23-
* @param apiKey The API key to authenticate with Qdrant.
24-
* @throws MalformedURLException If the URL is invalid.
25-
*/
2619
public QdrantGrpc(URL url, String apiKey) throws MalformedURLException {
27-
2820
String host = url.getHost();
2921
int port = url.getPort() == -1 ? 6334 : url.getPort();
3022
boolean useTls = url.getProtocol().equalsIgnoreCase("https");
31-
32-
this.client =
23+
client =
3324
new QdrantClient(
3425
QdrantGrpcClient.newBuilder(host, port, useTls).withApiKey(apiKey).build());
3526
}
3627

37-
/**
38-
* Uploads a batch of points to a Qdrant collection.
39-
*
40-
* @param collectionName The name of the collection to upload the points to.
41-
* @param points The list of points to upload.
42-
* @param shardKeySelector The shard key selector to use for the upsert.
43-
* @throws InterruptedException If there was an error uploading the batch to Qdrant.
44-
* @throws ExecutionException If there was an error uploading the batch to Qdrant.
45-
*/
4628
public void upsert(
47-
String collectionName, List<PointStruct> points, @Nullable ShardKeySelector shardKeySelector)
29+
String collectionName, List<PointStruct> points, ShardKeySelector shardKeySelector)
4830
throws InterruptedException, ExecutionException {
49-
5031
UpsertPoints.Builder upsertPoints =
5132
UpsertPoints.newBuilder().setCollectionName(collectionName).addAllPoints(points);
52-
5333
if (shardKeySelector != null) {
5434
upsertPoints.setShardKeySelector(shardKeySelector);
5535
}
56-
57-
this.client.upsertAsync(upsertPoints.build()).get();
36+
client.upsertAsync(upsertPoints.build()).get();
5837
}
5938

6039
public void close() {
61-
this.client.close();
40+
client.close();
6241
}
6342
}

0 commit comments

Comments (0)