Skip to content

Commit 8b9afec

Browse files
Backport #2364 (#2422)
1 parent 64e6ee5 commit 8b9afec

File tree

5 files changed

+153
-4
lines changed

5 files changed

+153
-4
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
/*
2+
* Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made
3+
* available under the terms of the MIT License. See the LICENSE file in the project root for more information.
4+
*/
5+
package com.microsoft.sqlserver.jdbc;
6+
7+
/**
8+
* This functional interface represents a listener which is called before a reconnect of {@link SQLServerConnection}.
9+
*/
10+
@FunctionalInterface
11+
public interface ReconnectListener {
12+
13+
void beforeReconnect();
14+
15+
}

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,6 +1756,19 @@ SQLServerPooledConnection getPooledConnectionParent() {
17561756
return pooledConnectionParent;
17571757
}
17581758

1759+
/**
1760+
* List of listeners which are called before reconnecting.
1761+
*/
1762+
private List<ReconnectListener> reconnectListeners = new ArrayList<>();
1763+
1764+
public void registerBeforeReconnectListener(ReconnectListener reconnectListener) {
1765+
reconnectListeners.add(reconnectListener);
1766+
}
1767+
1768+
public void removeBeforeReconnectListener(ReconnectListener reconnectListener) {
1769+
reconnectListeners.remove(reconnectListener);
1770+
}
1771+
17591772
SQLServerConnection(String parentInfo) {
17601773
int connectionID = nextConnectionID(); // sequential connection id
17611774
traceID = "ConnectionID:" + connectionID;
@@ -4347,6 +4360,8 @@ boolean executeCommand(TDSCommand newCommand) throws SQLServerException {
43474360
preparedStatementHandleCache.clear();
43484361
}
43494362

4363+
this.reconnectListeners.forEach(ReconnectListener::beforeReconnect);
4364+
43504365
if (loggerResiliency.isLoggable(Level.FINE)) {
43514366
loggerResiliency.fine(toString()
43524367
+ " Idle connection resiliency - starting idle connection resiliency reconnect.");

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,12 @@ private boolean resetPrepStmtHandle(boolean discardCurrentCacheItem) {
224224
*/
225225
private Vector<CryptoMetadata> cryptoMetaBatch = new Vector<>();
226226

227+
/**
228+
* Listener to clear the {@link SQLServerPreparedStatement#prepStmtHandle} and
229+
* {@link SQLServerPreparedStatement#cachedPreparedStatementHandle} before reconnecting.
230+
*/
231+
private ReconnectListener clearPrepStmtHandleOnReconnectListener;
232+
227233
/**
228234
* Constructs a SQLServerPreparedStatement.
229235
*
@@ -244,6 +250,9 @@ private boolean resetPrepStmtHandle(boolean discardCurrentCacheItem) {
244250
SQLServerStatementColumnEncryptionSetting stmtColEncSetting) throws SQLServerException {
245251
super(conn, nRSType, nRSConcur, stmtColEncSetting);
246252

253+
clearPrepStmtHandleOnReconnectListener = this::clearPrepStmtHandle;
254+
connection.registerBeforeReconnectListener(clearPrepStmtHandleOnReconnectListener);
255+
247256
if (null == sql) {
248257
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_NullValue"));
249258
Object[] msgArgs1 = {"Statement SQL"};
@@ -280,6 +289,8 @@ private boolean resetPrepStmtHandle(boolean discardCurrentCacheItem) {
280289
* Closes the prepared statement's prepared handle.
281290
*/
282291
private void closePreparedHandle() {
292+
connection.removeBeforeReconnectListener(clearPrepStmtHandleOnReconnectListener);
293+
283294
if (!hasPreparedStatementHandle())
284295
return;
285296

@@ -3569,4 +3580,12 @@ public void addBatch(String sql) throws SQLServerException {
35693580
Object[] msgArgs = {"addBatch()"};
35703581
throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
35713582
}
3583+
3584+
private void clearPrepStmtHandle() {
3585+
prepStmtHandle = 0;
3586+
cachedPreparedStatementHandle = null;
3587+
if (getStatementLogger().isLoggable(Level.FINER)) {
3588+
getStatementLogger().finer(toString() + " cleared cachedPrepStmtHandle!");
3589+
}
3590+
}
35723591
}

src/test/java/com/microsoft/sqlserver/jdbc/connection/RequestBoundaryMethodsTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,8 @@ private List<String> getVerifiedMethodNames() {
525525
verifiedMethodNames.add("setUseFlexibleCallableStatements");
526526
verifiedMethodNames.add("getCalcBigDecimalPrecision");
527527
verifiedMethodNames.add("setCalcBigDecimalPrecision");
528+
verifiedMethodNames.add("registerBeforeReconnectListener");
529+
verifiedMethodNames.add("removeBeforeReconnectListener");
528530
return verifiedMethodNames;
529531
}
530532
}

src/test/java/com/microsoft/sqlserver/jdbc/resiliency/BasicConnectionTest.java

Lines changed: 102 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,35 @@
66
package com.microsoft.sqlserver.jdbc.resiliency;
77

88
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotEquals;
910
import static org.junit.Assert.assertTrue;
1011
import static org.junit.Assert.fail;
1112

1213
import java.lang.reflect.Field;
1314
import java.sql.Connection;
1415
import java.sql.DriverManager;
16+
import java.sql.PreparedStatement;
1517
import java.sql.ResultSet;
1618
import java.sql.SQLException;
1719
import java.sql.Statement;
20+
import java.util.Collections;
21+
import java.util.LinkedList;
22+
import java.util.List;
1823
import java.util.UUID;
1924

2025
import javax.sql.PooledConnection;
2126

27+
import org.junit.jupiter.api.BeforeAll;
28+
import org.junit.jupiter.api.Tag;
29+
import org.junit.jupiter.api.Test;
30+
2231
import com.microsoft.sqlserver.jdbc.RandomUtil;
2332
import com.microsoft.sqlserver.jdbc.SQLServerConnection;
2433
import com.microsoft.sqlserver.jdbc.SQLServerConnectionPoolDataSource;
2534
import com.microsoft.sqlserver.jdbc.SQLServerPooledConnection;
2635
import com.microsoft.sqlserver.jdbc.SQLServerPreparedStatement;
2736
import com.microsoft.sqlserver.jdbc.TestResource;
2837
import com.microsoft.sqlserver.jdbc.TestUtils;
29-
import org.junit.jupiter.api.BeforeAll;
30-
import org.junit.jupiter.api.Tag;
31-
import org.junit.jupiter.api.Test;
32-
3338
import com.microsoft.sqlserver.testframework.AbstractTest;
3439
import com.microsoft.sqlserver.testframework.Constants;
3540

@@ -367,6 +372,99 @@ public void testPreparedStatementCacheShouldBeCleared() throws SQLException {
367372
}
368373
}
369374

375+
@Test
376+
public void testPreparedStatementHandleOfStatementShouldBeCleared() throws SQLException {
377+
try (SQLServerConnection con = (SQLServerConnection) ResiliencyUtils.getConnection(connectionString)) {
378+
int cacheSize = 2;
379+
String query = String.format("/*testPreparedStatementHandleOfStatementShouldBeCleared%s*/SELECT 1; -- ",
380+
UUID.randomUUID().toString());
381+
382+
// enable caching
383+
con.setDisableStatementPooling(false);
384+
con.setStatementPoolingCacheSize(cacheSize);
385+
con.setServerPreparedStatementDiscardThreshold(cacheSize);
386+
387+
List<SQLServerPreparedStatement> statements = new LinkedList<>();
388+
389+
// add statements to fill cache
390+
for (int i = 0; i < cacheSize + 1; ++i) {
391+
SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) con.prepareStatement(query + i);
392+
pstmt.execute();
393+
pstmt.execute();
394+
pstmt.execute();
395+
pstmt.getMoreResults();
396+
statements.add(pstmt);
397+
}
398+
399+
// handle of the prepared statement should be set
400+
assertNotEquals(0, statements.get(1).getPreparedStatementHandle());
401+
402+
ResiliencyUtils.killConnection(con, connectionString, 1);
403+
404+
// call first statement to trigger reconnect
405+
statements.get(0).execute();
406+
407+
// handle of the other statements should be cleared after reconnect
408+
assertEquals(0, statements.get(1).getPreparedStatementHandle());
409+
}
410+
}
411+
412+
@Test
413+
public void testPreparedStatementShouldNotUseWrongHandleAfterReconnect() throws SQLException {
414+
try (SQLServerConnection con = (SQLServerConnection) ResiliencyUtils.getConnection(connectionString)) {
415+
int cacheSize = 3;
416+
String queryOne = "select * from sys.sysusers where name=?;";
417+
String queryTwo = "select * from sys.sysusers where name=? and uid=?;";
418+
String queryThree = "select * from sys.sysusers where name=? and uid=? and islogin=?";
419+
420+
String parameterOne = "name";
421+
int parameterUid = 0;
422+
int parameterIsLogin = 0;
423+
424+
// enable caching
425+
con.setDisableStatementPooling(false);
426+
con.setStatementPoolingCacheSize(cacheSize);
427+
con.setServerPreparedStatementDiscardThreshold(cacheSize);
428+
429+
List<PreparedStatement> statements = new LinkedList<>();
430+
431+
PreparedStatement ps = con.prepareStatement(queryOne);
432+
ps.setString(1, parameterOne);
433+
statements.add(ps);
434+
435+
ps = con.prepareStatement(queryTwo);
436+
ps.setString(1, parameterOne);
437+
ps.setInt(2, parameterUid);
438+
statements.add(ps);
439+
440+
ps = con.prepareStatement(queryThree);
441+
ps.setString(1, parameterOne);
442+
ps.setInt(2, parameterUid);
443+
ps.setInt(3, parameterIsLogin);
444+
statements.add(ps);
445+
446+
// add new statements to fill cache
447+
for (PreparedStatement preparedStatement : statements) {
448+
preparedStatement.execute();
449+
preparedStatement.execute();
450+
preparedStatement.execute();
451+
preparedStatement.getMoreResults();
452+
}
453+
454+
ResiliencyUtils.killConnection(con, connectionString, 1);
455+
456+
// call statements in reversed order, in order to force the statement to use the wrong handle
457+
// first execute triggers a reconnect
458+
Collections.reverse(statements);
459+
for (PreparedStatement preparedStatement : statements) {
460+
preparedStatement.execute();
461+
preparedStatement.execute();
462+
preparedStatement.execute();
463+
preparedStatement.getMoreResults();
464+
}
465+
}
466+
}
467+
370468
@Test
371469
public void testUnprocessedResponseCountSuccessfulIdleConnectionRecovery() throws SQLException {
372470
try (SQLServerConnection con = (SQLServerConnection) ResiliencyUtils.getConnection(connectionString)) {

0 commit comments

Comments
 (0)