Skip to content

Commit 72eb3f1

Browse files
divangDivang SharmaDavid-Engel
authored
Fixed session recovery for Azure SQL DB in redirect mode connected with AAD auth (#2668)
* In case of federated authentication request, do not terminate the connection to resolve Fed auth connection resiliency issue * In place of serverSupportsDNSCaching, routingInfo value is a better choice to take decision on connection termination. If endpoint is routed then no need to terminate the current connection. * Added mock test for connectionReconveryCheck() method --------- Co-authored-by: Divang Sharma <[email protected]> Co-authored-by: David Engel <[email protected]>
1 parent c2a6ba4 commit 72eb3f1

File tree

2 files changed

+73
-9
lines changed

2 files changed

+73
-9
lines changed

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7252,27 +7252,33 @@ final boolean complete(LogonCommand logonCommand, TDSReader tdsReader) throws SQ
72527252

72537253
LogonProcessor logonProcessor = new LogonProcessor(authentication);
72547254
TDSReader tdsReader;
7255+
sessionRecovery.setConnectionRecoveryPossible(false);
72557256
do {
72567257
tdsReader = logonCommand.startResponse();
7257-
sessionRecovery.setConnectionRecoveryPossible(false);
72587258
TDSParser.parse(tdsReader, logonProcessor);
72597259
} while (!logonProcessor.complete(logonCommand, tdsReader));
72607260

7261-
if (sessionRecovery.isReconnectRunning() && !sessionRecovery.isConnectionRecoveryPossible()) {
7261+
connectionReconveryCheck(sessionRecovery.isReconnectRunning(), sessionRecovery.isConnectionRecoveryPossible(),
7262+
routingInfo);
7263+
7264+
if (connectRetryCount > 0 && !sessionRecovery.isReconnectRunning()) {
7265+
sessionRecovery.getSessionStateTable().setOriginalCatalog(sCatalog);
7266+
sessionRecovery.getSessionStateTable().setOriginalCollation(databaseCollation);
7267+
sessionRecovery.getSessionStateTable().setOriginalLanguage(sLanguage);
7268+
}
7269+
}
7270+
7271+
private void connectionReconveryCheck(boolean isReconnectRunning, boolean isConnectionRecoveryPossible,
7272+
ServerPortPlaceHolder routingDetails) throws SQLServerException {
7273+
if (isReconnectRunning && !isConnectionRecoveryPossible && routingDetails == null) {
72627274
if (connectionlogger.isLoggable(Level.WARNING)) {
72637275
connectionlogger.warning(this.toString()
72647276
+ "SessionRecovery feature extension ack was not sent by the server during reconnection.");
72657277
}
72667278
terminate(SQLServerException.DRIVER_ERROR_INVALID_TDS,
72677279
SQLServerException.getErrString("R_crClientNoRecoveryAckFromLogin"));
72687280
}
7269-
if (connectRetryCount > 0 && !sessionRecovery.isReconnectRunning()) {
7270-
sessionRecovery.getSessionStateTable().setOriginalCatalog(sCatalog);
7271-
sessionRecovery.getSessionStateTable().setOriginalCollation(databaseCollation);
7272-
sessionRecovery.getSessionStateTable().setOriginalLanguage(sLanguage);
7273-
}
72747281
}
7275-
72767282
/* --------------- JDBC 3.0 ------------- */
72777283

72787284
/**

src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,21 @@
99
import static org.junit.jupiter.api.Assertions.assertThrows;
1010
import static org.junit.jupiter.api.Assertions.assertTrue;
1111
import static org.junit.jupiter.api.Assertions.fail;
12+
import static org.mockito.ArgumentMatchers.anyInt;
13+
import static org.mockito.ArgumentMatchers.anyString;
14+
import static org.mockito.ArgumentMatchers.eq;
15+
import static org.mockito.Mockito.doNothing;
16+
import static org.mockito.Mockito.doReturn;
17+
import static org.mockito.Mockito.mock;
18+
import static org.mockito.Mockito.never;
19+
import static org.mockito.Mockito.spy;
20+
import static org.mockito.Mockito.times;
21+
import static org.mockito.Mockito.verify;
1222

1323
import java.io.IOException;
1424
import java.io.Reader;
1525
import java.lang.management.ManagementFactory;
26+
import java.lang.reflect.Method;
1627
import java.sql.Clob;
1728
import java.sql.Connection;
1829
import java.sql.DriverManager;
@@ -29,6 +40,7 @@
2940
import java.util.concurrent.ExecutorService;
3041
import java.util.concurrent.Executors;
3142
import java.util.concurrent.Future;
43+
import java.util.logging.Level;
3244
import java.util.logging.Logger;
3345

3446
import javax.sql.ConnectionEvent;
@@ -60,6 +72,9 @@ public class SQLServerConnectionTest extends AbstractTest {
6072

6173
String randomServer = RandomUtil.getIdentifier("Server");
6274

75+
SQLServerConnection mockConnection;
76+
Logger mockLogger;
77+
6378
@BeforeAll
6479
public static void setupTests() throws Exception {
6580
setConnection();
@@ -1528,6 +1543,49 @@ public void testManagedIdentityWithEncryptStrict() {
15281543
} catch (SQLException e) {
15291544
fail("Connection failed: " + e.getMessage());
15301545
}
1531-
}
1546+
}
1547+
1548+
public Method mockedConnectionRecoveryCheck() throws Exception {
1549+
mockConnection = spy(new SQLServerConnection("test"));
1550+
mockLogger = mock(Logger.class);
1551+
doReturn(true).when(mockLogger).isLoggable(Level.WARNING);
1552+
doNothing().when(mockConnection).terminate(anyInt(), anyString());
1553+
1554+
Method method = SQLServerConnection.class.getDeclaredMethod("connectionReconveryCheck", boolean.class,
1555+
boolean.class, ServerPortPlaceHolder.class);
1556+
method.setAccessible(true);
1557+
return method;
1558+
}
1559+
1560+
@Test
1561+
void testConnectionRecoveryCheckThrowsWhenAllConditionsMet() throws Exception {
1562+
Method method = mockedConnectionRecoveryCheck();
1563+
method.invoke(mockConnection, true, false, null);
1564+
verify(mockConnection, times(1)).terminate(eq(SQLServerException.DRIVER_ERROR_INVALID_TDS),
1565+
eq(SQLServerException.getErrString("R_crClientNoRecoveryAckFromLogin")));
1566+
}
1567+
1568+
@Test
1569+
void testConnectionRecoveryCheckDoesNotThrowWhenNotReconnectRunning() throws Exception {
1570+
Method method = mockedConnectionRecoveryCheck();
1571+
method.invoke(mockConnection, false, false, null);
1572+
verify(mockConnection, never()).terminate(anyInt(), anyString());
1573+
}
1574+
1575+
@Test
1576+
void testConnectionRecoveryCheckDoesNotThrowWhenRecoveryPossible() throws Exception {
1577+
Method method = mockedConnectionRecoveryCheck();
1578+
method.invoke(mockConnection, true, true, null);
1579+
verify(mockConnection, never()).terminate(anyInt(), anyString());
1580+
}
1581+
1582+
@Test
1583+
void testConnectionRecoveryCheckDoesNotThrowWhenRoutingDetailsNotNull() throws Exception {
1584+
Method method = mockedConnectionRecoveryCheck();
1585+
ServerPortPlaceHolder routingDetails = mock(ServerPortPlaceHolder.class);
1586+
method.setAccessible(true);
1587+
method.invoke(mockConnection, true, false, routingDetails);
1588+
verify(mockConnection, never()).terminate(anyInt(), anyString());
1589+
}
15321590

15331591
}

0 commit comments

Comments
 (0)