Skip to content

Commit 7d42173

Browse files
committed
Restructure the copyFrom function to take structured input instead of a query
1 parent 0236666 commit 7d42173

File tree

3 files changed

+72
-20
lines changed

3 files changed

+72
-20
lines changed

Sources/PostgresNIO/Connection/PostgresConnection.swift

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -786,25 +786,54 @@ public struct PostgresCopyFromWriter: Sendable {
786786
}
787787
}
788788

789+
public struct CopyFromOptions {
790+
let delimiter: StaticString?
791+
792+
public init(delimiter: StaticString? = nil) {
793+
self.delimiter = delimiter
794+
}
795+
}
796+
797+
private func buildCopyFromQuery(
798+
table: String,
799+
columns: [StaticString]?,
800+
options: CopyFromOptions
801+
) -> PostgresQuery {
802+
var query = "COPY \(table)"
803+
if let columns {
804+
// TODO: Is using `StaticString` sufficient here to prevent against SQL injection attacks or should we try to
805+
// escape the identifiers, essentially re-implementing `PQescapeIdentifier`?
806+
query += "(" + columns.map(\.description).joined(separator: ",") + ")"
807+
}
808+
query += " FROM STDIN"
809+
var queryOptions: [String] = []
810+
if let delimiter = options.delimiter {
811+
queryOptions.append("DELIMITER '\(delimiter)'")
812+
}
813+
if !queryOptions.isEmpty {
814+
query += " WITH "
815+
query += queryOptions.map { "(\($0))" }.joined(separator: ",")
816+
}
817+
return "\(unescaped: query)"
818+
}
819+
789820
extension PostgresConnection {
790821
// TODO: Instead of an arbitrary query, make this a structured data structure.
791822
// TODO: Write doc comment
792823
public func copyFrom(
793-
_ query: PostgresQuery,
794-
writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void,
824+
table: String,
825+
columns: [StaticString]? = nil,
826+
options: CopyFromOptions = CopyFromOptions(),
795827
logger: Logger,
828+
writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void,
796829
file: String = #fileID,
797830
line: Int = #line
798831
) async throws {
799832
var logger = logger
800833
logger[postgresMetadataKey: .connectionID] = "\(self.id)"
801-
guard query.binds.count <= Int(UInt16.max) else {
802-
throw PSQLError(code: .tooManyParameters, query: query)
803-
}
804-
805834
let writer = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<PostgresCopyFromWriter, any Error>) in
806835
let context = ExtendedQueryContext(
807-
copyFromQuery: query,
836+
copyFromQuery: buildCopyFromQuery(table: table, columns: columns, options: options),
808837
triggerCopy: continuation,
809838
logger: logger
810839
)

Tests/IntegrationTests/PSQLIntegrationTests.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,17 +389,17 @@ final class IntegrationTests: XCTestCase {
389389

390390
_ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get()
391391
_ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get()
392-
try await conn.copyFrom("COPY copy_table FROM STDIN", writeData: { writer in
392+
try await conn.copyFrom(table: "copy_table", columns: ["id", "name"], options: CopyFromOptions(delimiter: ","), logger: .psqlTest) { writer in
393393
let records: [(id: Int, name: String)] = [
394394
(1, "Alice"),
395395
(42, "Bob")
396396
]
397397
for record in records {
398398
var buffer = ByteBuffer()
399-
buffer.writeString("\(record.id)\t\(record.name)\n")
399+
buffer.writeString("\(record.id),\(record.name)\n")
400400
try await writer.write(buffer)
401401
}
402-
}, logger: .psqlTest)
402+
}
403403
let rows = try await conn.query("SELECT id, name FROM copy_table").get().rows.map { try $0.decode((Int, String).self) }
404404
guard rows.count == 2 else {
405405
XCTFail("Expected 2 columns, received \(rows.count)")

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -628,9 +628,9 @@ class PostgresConnectionTests: XCTestCase {
628628

629629
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
630630
taskGroup.addTask {
631-
try await connection.copyFrom("COPY copy_table FROM STDIN", writeData: { writer in
631+
try await connection.copyFrom(table: "copy_table", logger: .psqlTest) { writer in
632632
try await writer.write(ByteBuffer(staticString: "1\tAlice\n"))
633-
}, logger: .psqlTest)
633+
}
634634
}
635635

636636
let copyMessage = try await channel.waitForUnpreparedRequest()
@@ -656,9 +656,9 @@ class PostgresConnectionTests: XCTestCase {
656656
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
657657
taskGroup.addTask {
658658
await assertThrowsError(
659-
try await connection.copyFrom("COPY copy_table FROM STDIN", writeData: { writer in
659+
try await connection.copyFrom(table: "copy_table", logger: .psqlTest) { writer in
660660
throw MyError()
661-
}, logger: .psqlTest)
661+
}
662662
) { error in
663663
XCTAssert(error is MyError, "Expected error of type MyError, got \(error)")
664664
}
@@ -691,9 +691,9 @@ class PostgresConnectionTests: XCTestCase {
691691
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
692692
taskGroup.addTask {
693693
await assertThrowsError(
694-
try await connection.copyFrom("COPY copy_table FROM STDIN", writeData: { writer in
694+
try await connection.copyFrom(table: "copy_table", logger: .psqlTest) { writer in
695695
try await writer.write(ByteBuffer(staticString: "1Alice\n"))
696-
}, logger: .psqlTest)
696+
}
697697
) { error in
698698
XCTAssertEqual((error as? PSQLError)?.serverInfo?.underlying.fields[.sqlState], "22P02")
699699
}
@@ -724,11 +724,11 @@ class PostgresConnectionTests: XCTestCase {
724724
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
725725
taskGroup.addTask {
726726
await assertThrowsError(
727-
try await connection.copyFrom("COPY copy_table FROM STDIN", writeData: { writer in
727+
try await connection.copyFrom(table: "copy_table", logger: .psqlTest) { writer in
728728
try await writer.write(ByteBuffer(staticString: "1Alice\n"))
729729
channel.flush()
730730
_ = await XCTWaiter.fulfillment(of: [backendDidSendErrorExpectation])
731-
}, logger: .psqlTest)
731+
}
732732
) { error in
733733
XCTAssertEqual((error as? PSQLError)?.serverInfo?.underlying.fields[.sqlState], "22P02")
734734
}
@@ -764,7 +764,7 @@ class PostgresConnectionTests: XCTestCase {
764764
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
765765
taskGroup.addTask {
766766
await assertThrowsError(
767-
try await connection.copyFrom("COPY copy_table FROM STDIN", writeData: { writer in
767+
try await connection.copyFrom(table: "copy_table", logger: .psqlTest) { writer in
768768
try await writer.write(ByteBuffer(staticString: "1Alice\n"))
769769
channel.flush()
770770
_ = await XCTWaiter.fulfillment(of: [expectation])
@@ -775,7 +775,7 @@ class PostgresConnectionTests: XCTestCase {
775775
XCTAssert(error is PostgresCopyFromWriter.CopyCancellationError, "Received unexpected error: \(error)")
776776
throw error
777777
}
778-
}, logger: .psqlTest)
778+
}
779779
) { error in
780780
XCTAssertEqual((error as? PSQLError)?.serverInfo?.underlying.fields[.sqlState], "22P02")
781781
}
@@ -803,6 +803,29 @@ class PostgresConnectionTests: XCTestCase {
803803
}
804804
}
805805

806+
func testCopyDataWithOptions() async throws {
807+
let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel()
808+
809+
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
810+
taskGroup.addTask {
811+
try await connection.copyFrom(table: "copy_table", columns: ["id", "name"], options: CopyFromOptions(delimiter: ","), logger: .psqlTest) { writer in
812+
try await writer.write(ByteBuffer(staticString: "1,Alice\n"))
813+
}
814+
}
815+
816+
let copyMessage = try await channel.waitForUnpreparedRequest()
817+
XCTAssertEqual(copyMessage.parse.query, "COPY copy_table(id,name) FROM STDIN WITH (DELIMITER ',')")
818+
XCTAssertEqual(copyMessage.bind.parameters, [])
819+
try await channel.sendUnpreparedRequestWithNoParametersBindResponse()
820+
try await channel.sendCopyInResponseForTwoTextualColumns()
821+
let data = try await channel.waitForCopyData()
822+
XCTAssertEqual(String(buffer: data.data), "1,Alice\n")
823+
XCTAssertEqual(data.result, .done)
824+
try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1"))
825+
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
826+
}
827+
}
828+
806829
func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) {
807830
let eventLoop = NIOAsyncTestingEventLoop()
808831
let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in

0 commit comments

Comments
 (0)