Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions cpp/src/arrow/flight/sql/odbc/odbc_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1005,22 +1005,49 @@ SQLRETURN SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_len
ARROW_LOG(DEBUG) << "SQLExecDirectW called with stmt: " << stmt
<< ", query_text: " << static_cast<const void*>(query_text)
<< ", text_length: " << text_length;
// GH-47711 TODO: Implement SQLExecDirect
return SQL_INVALID_HANDLE;

using ODBC::ODBCStatement;
// The driver is built to handle SELECT statements only.
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
std::string query = ODBC::SqlWcharToString(query_text, text_length);

statement->Prepare(query);
statement->ExecutePrepared();
Comment on lines +1015 to +1016
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we implemented SQLExecDirect to use prepare/execute to use prepared statements to align with the behavior with JDBC


return SQL_SUCCESS;
});
}

SQLRETURN SQLPrepare(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER text_length) {
ARROW_LOG(DEBUG) << "SQLPrepareW called with stmt: " << stmt
<< ", query_text: " << static_cast<const void*>(query_text)
<< ", text_length: " << text_length;
// GH-47712 TODO: Implement SQLPrepare
return SQL_INVALID_HANDLE;

using ODBC::ODBCStatement;
// The driver is built to handle SELECT statements only.
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
std::string query = ODBC::SqlWcharToString(query_text, text_length);

statement->Prepare(query);

return SQL_SUCCESS;
});
}

SQLRETURN SQLExecute(SQLHSTMT stmt) {
ARROW_LOG(DEBUG) << "SQLExecute called with stmt: " << stmt;
// GH-47712 TODO: Implement SQLExecute
return SQL_INVALID_HANDLE;

using ODBC::ODBCStatement;
// The driver is built to handle SELECT statements only.
return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);

statement->ExecutePrepared();

return SQL_SUCCESS;
});
}

SQLRETURN SQLFetch(SQLHSTMT stmt) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ FlightSqlStatement::FlightSqlStatement(const Diagnostics& diagnostics,
call_options_.timeout = TimeoutDuration{-1};
}

FlightSqlStatement::~FlightSqlStatement() {
ClosePreparedStatementIfAny(prepared_statement_, call_options_);
}

bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute,
const Attribute& value) {
switch (attribute) {
Expand Down Expand Up @@ -119,7 +123,6 @@ bool FlightSqlStatement::ExecutePrepared() {

Result<std::shared_ptr<FlightInfo>> result =
prepared_statement_->Execute(call_options_);

ThrowIfNotOK(result.status());

current_result_set_ = std::make_shared<FlightSqlResultSet>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class FlightSqlStatement : public Statement {
FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& sql_client,
FlightClientOptions client_options, FlightCallOptions call_options,
const MetadataSettings& metadata_settings);
~FlightSqlStatement();

bool SetAttribute(StatementAttributeId attribute, const Attribute& value) override;

Expand Down
85 changes: 81 additions & 4 deletions cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,86 @@ class StatementRemoteTest : public FlightSQLODBCRemoteTestBase {};
using TestTypes = ::testing::Types<StatementMockTest, StatementRemoteTest>;
TYPED_TEST_SUITE(StatementTest, TestTypes);

TYPED_TEST(StatementTest, TestSQLExecDirectSimpleQuery) {
std::wstring wsql = L"SELECT 1;";
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());

ASSERT_EQ(SQL_SUCCESS,
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));

// GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we weren't expecting these tests to pass yet? Are there build issues with this not commented out?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this PR is just created as draft, we aren't expecting them to pass. But once the pre-requisite ODBC API PRs get merged in main, I am expecting this PR to pass the CI tests after a rebase. And some calls to unimplemented API (SQLFetch and SQLGetData) are commented out because they will be enabled in a different PR

/*
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));

SQLINTEGER val;

ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
// Verify 1 is returned
EXPECT_EQ(1, val);

ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));

ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
// Invalid cursor state
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
*/
}

TYPED_TEST(StatementTest, TestSQLExecDirectInvalidQuery) {
std::wstring wsql = L"SELECT;";
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());

ASSERT_EQ(SQL_ERROR,
SQLExecDirect(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
// ODBC provides generic error code HY000 to all statement errors
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
}

TYPED_TEST(StatementTest, TestSQLExecuteSimpleQuery) {
std::wstring wsql = L"SELECT 1;";
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());

ASSERT_EQ(SQL_SUCCESS,
SQLPrepare(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));

ASSERT_EQ(SQL_SUCCESS, SQLExecute(this->stmt));

// GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
/*
// Fetch data
ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));

SQLINTEGER val;
ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));

// Verify 1 is returned
EXPECT_EQ(1, val);

ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));

ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
// Invalid cursor state
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
*/
}

TYPED_TEST(StatementTest, TestSQLPrepareInvalidQuery) {
std::wstring wsql = L"SELECT;";
std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());

ASSERT_EQ(SQL_ERROR,
SQLPrepare(this->stmt, &sql0[0], static_cast<SQLINTEGER>(sql0.size())));
// ODBC provides generic error code HY000 to all statement errors
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);

ASSERT_EQ(SQL_ERROR, SQLExecute(this->stmt));
// Verify function sequence error state is returned
VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010);
}

TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) {
SQLWCHAR buf[1024];
SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize();
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
SQLINTEGER output_char_len = 0;
Expand All @@ -58,7 +135,7 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) {

TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsNTSInputString) {
SQLWCHAR buf[1024];
SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize();
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
SQLINTEGER output_char_len = 0;
Expand Down Expand Up @@ -95,7 +172,7 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputStringLength) {
TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsTruncatedString) {
const SQLINTEGER small_buf_size_in_char = 11;
SQLWCHAR small_buf[small_buf_size_in_char];
SQLINTEGER small_buf_char_len = sizeof(small_buf) / ODBC::GetSqlWCharSize();
SQLINTEGER small_buf_char_len = sizeof(small_buf) / GetSqlWCharSize();
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
SQLINTEGER output_char_len = 0;
Expand All @@ -122,7 +199,7 @@ TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsTruncatedString) {

TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsErrorOnBadInputs) {
SQLWCHAR buf[1024];
SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize();
SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
SQLINTEGER output_char_len = 0;
Expand Down
Loading