Skip to content

Commit 125bf5e

Browse files
committed
fix fedauth tests
1 parent 4a27a0d commit 125bf5e

File tree

14 files changed

+205
-116
lines changed

14 files changed

+205
-116
lines changed

src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ final InetSocketAddress open(String host, int port, int timeoutMillis, boolean u
735735
/**
736736
* Set TCP keep-alive options for idle connection resiliency
737737
*/
738-
private void setSocketOptions(Socket tcpSocket, TDSChannel channel) throws IOException {
738+
private void setSocketOptions(Socket tcpSocket, TDSChannel channel) {
739739
try {
740740
if (SQLServerDriver.socketSetOptionMethod != null && SQLServerDriver.socketKeepIdleOption != null
741741
&& SQLServerDriver.socketKeepIntervalOption != null) {

src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerEnclaveProvider.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import java.util.ArrayList;
3636
import java.util.Arrays;
3737
import java.util.HashMap;
38-
import java.util.Hashtable;
3938
import java.util.Map;
4039
import java.util.Map.Entry;
4140
import java.util.concurrent.ConcurrentHashMap;

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,7 +1605,7 @@ SQLServerPooledConnection getPooledConnectionParent() {
16051605
return pooledConnectionParent;
16061606
}
16071607

1608-
SQLServerConnection(String parentInfo) throws SQLServerException {
1608+
SQLServerConnection(String parentInfo) {
16091609
int connectionID = nextConnectionID(); // sequential connection id
16101610
traceID = "ConnectionID:" + connectionID;
16111611
loggingClassName += ":" + connectionID;
@@ -6963,7 +6963,7 @@ public java.sql.Struct createStruct(String typeName, Object[] attributes) throws
69636963
return null;
69646964
}
69656965

6966-
String getTrustedServerNameAE() throws SQLServerException {
6966+
String getTrustedServerNameAE() {
69676967
return trustedServerNameAE.toUpperCase();
69686968
}
69696969

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDatabaseMetaData.java

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ private String generateAzureDWSelect(ResultSet rs, Map<Integer, String> columns)
809809
return sb.toString();
810810
}
811811

812-
private String generateAzureDWEmptyRS(Map<Integer, String> columns) throws SQLException {
812+
private String generateAzureDWEmptyRS(Map<Integer, String> columns) {
813813
StringBuilder sb = new StringBuilder("SELECT TOP 0 ");
814814
for (Entry<Integer, String> p : columns.entrySet()) {
815815
sb.append("NULL AS ").append(p.getValue()).append(",");
@@ -950,7 +950,7 @@ public java.sql.ResultSet getBestRowIdentifier(String catalog, String schema, St
950950

951951
@Override
952952
public java.sql.ResultSet getCrossReference(String cat1, String schem1, String tab1, String cat2, String schem2,
953-
String tab2) throws SQLException, SQLTimeoutException {
953+
String tab2) throws SQLException {
954954
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
955955
loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString());
956956
}
@@ -1014,8 +1014,7 @@ public String getDriverVersion() throws SQLServerException {
10141014
}
10151015

10161016
@Override
1017-
public java.sql.ResultSet getExportedKeys(String cat, String schema,
1018-
String table) throws SQLException, SQLTimeoutException {
1017+
public java.sql.ResultSet getExportedKeys(String cat, String schema, String table) throws SQLException {
10191018
return getCrossReference(cat, schema, table, null, null, null);
10201019
}
10211020

@@ -1032,12 +1031,11 @@ public String getIdentifierQuoteString() throws SQLServerException {
10321031
}
10331032

10341033
@Override
1035-
public java.sql.ResultSet getImportedKeys(String cat, String schema,
1036-
String table) throws SQLException, SQLTimeoutException {
1034+
public java.sql.ResultSet getImportedKeys(String cat, String schema, String table) throws SQLException {
10371035
return getCrossReference(null, null, null, cat, schema, table);
10381036
}
10391037

1040-
private ResultSet executeSPFkeys(String[] procParams) throws SQLException, SQLTimeoutException {
1038+
private ResultSet executeSPFkeys(String[] procParams) throws SQLException {
10411039
if (!this.connection.isAzureDW()) {
10421040
String tempTableName = "@jdbc_temp_fkeys_result";
10431041
String sql = "DECLARE " + tempTableName + " table (PKTABLE_QUALIFIER sysname, " + "PKTABLE_OWNER sysname, "
@@ -1219,7 +1217,7 @@ public int getMaxColumnsInTable() throws SQLServerException {
12191217
}
12201218

12211219
@Override
1222-
public int getMaxConnections() throws SQLException, SQLTimeoutException {
1220+
public int getMaxConnections() throws SQLException {
12231221
checkClosed();
12241222
try (SQLServerResultSet rs = getResultSetFromInternalQueries(null,
12251223
"select maximum from sys.configurations where name = 'user connections'")) {
@@ -1439,7 +1437,7 @@ public ResultSet getPseudoColumns(String catalog, String schemaPattern, String t
14391437
}
14401438

14411439
@Override
1442-
public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException {
1440+
public java.sql.ResultSet getSchemas() throws SQLException {
14431441
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
14441442
loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString());
14451443
}
@@ -1448,8 +1446,7 @@ public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException
14481446

14491447
}
14501448

1451-
private java.sql.ResultSet getSchemasInternal(String catalog,
1452-
String schemaPattern) throws SQLException, SQLTimeoutException {
1449+
private java.sql.ResultSet getSchemasInternal(String catalog, String schemaPattern) throws SQLException {
14531450

14541451
String s;
14551452
// The schemas that return null for catalog name, these are prebuilt
@@ -1606,14 +1603,13 @@ public java.sql.ResultSet getTablePrivileges(String catalog, String schema,
16061603
}
16071604

16081605
@Override
1609-
public java.sql.ResultSet getTableTypes() throws SQLException, SQLTimeoutException {
1606+
public java.sql.ResultSet getTableTypes() throws SQLException {
16101607
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
16111608
loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString());
16121609
}
16131610
checkClosed();
16141611
String s = "SELECT 'VIEW' 'TABLE_TYPE' UNION SELECT 'TABLE' UNION SELECT 'SYSTEM TABLE'";
1615-
SQLServerResultSet rs = getResultSetFromInternalQueries(null, s);
1616-
return rs;
1612+
return getResultSetFromInternalQueries(null, s);
16171613
}
16181614

16191615
@Override

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import com.microsoft.aad.msal4j.IntegratedWindowsAuthenticationParameters;
2929
import com.microsoft.aad.msal4j.InteractiveRequestParameters;
3030
import com.microsoft.aad.msal4j.MsalInteractionRequiredException;
31+
import com.microsoft.aad.msal4j.MsalThrottlingException;
3132
import com.microsoft.aad.msal4j.PublicClientApplication;
3233
import com.microsoft.aad.msal4j.SilentParameters;
3334
import com.microsoft.aad.msal4j.SystemBrowserOptions;
@@ -74,7 +75,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str
7475
Thread.currentThread().interrupt();
7576

7677
throw new SQLServerException(e.getMessage(), e);
77-
} catch (ExecutionException e) {
78+
} catch (MsalThrottlingException | ExecutionException e) {
7879
throw getCorrectedException(e, user, authenticationString);
7980
} finally {
8081
executorService.shutdown();
@@ -108,7 +109,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth
108109
Thread.currentThread().interrupt();
109110

110111
throw new SQLServerException(e.getMessage(), e);
111-
} catch (ExecutionException e) {
112+
} catch (MsalThrottlingException | ExecutionException e) {
112113
throw getCorrectedException(e, aadPrincipalID, authenticationString);
113114
} finally {
114115
executorService.shutdown();
@@ -150,7 +151,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut
150151
Thread.currentThread().interrupt();
151152

152153
throw new SQLServerException(e.getMessage(), e);
153-
} catch (ExecutionException e) {
154+
} catch (MsalThrottlingException | ExecutionException e) {
154155
throw getCorrectedException(e, "", authenticationString);
155156
} finally {
156157
executorService.shutdown();
@@ -177,11 +178,13 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu
177178
StringBuilder acc = new StringBuilder();
178179
if (accountsInCache != null) {
179180
for (IAccount account : accountsInCache) {
180-
if (acc.length() != 0) acc.append(", ");
181+
if (acc.length() != 0)
182+
acc.append(", ");
181183
acc.append(account.username());
182184
}
183185
}
184-
logger.fine(logger.toString() + "Accounts in cache = " + acc + ", size = " + (accountsInCache == null ? null : accountsInCache.size()) + ", user = " + user);
186+
logger.fine(logger.toString() + "Accounts in cache = " + acc + ", size = "
187+
+ (accountsInCache == null ? null : accountsInCache.size()) + ", user = " + user);
185188
}
186189
if (null != accountsInCache && !accountsInCache.isEmpty() && null != user && !user.isEmpty()) {
187190
IAccount account = getAccountByUsername(accountsInCache, user);
@@ -228,7 +231,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu
228231
Thread.currentThread().interrupt();
229232

230233
throw new SQLServerException(e.getMessage(), e);
231-
} catch (ExecutionException e) {
234+
} catch (MsalThrottlingException | ExecutionException e) {
232235
throw getCorrectedException(e, user, authenticationString);
233236
} finally {
234237
executorService.shutdown();
@@ -247,8 +250,7 @@ private static IAccount getAccountByUsername(Set<IAccount> accounts, String user
247250
return null;
248251
}
249252

250-
private static SQLServerException getCorrectedException(ExecutionException e, String user,
251-
String authenticationString) {
253+
private static SQLServerException getCorrectedException(Exception e, String user, String authenticationString) {
252254
Object[] msgArgs = {user, authenticationString};
253255

254256
if (null == e.getCause() || null == e.getCause().getMessage()) {

src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,5 +201,6 @@ protected Object[][] getContents() {
201201
{"R_objectNullOrEmpty", "The {0} is null or empty."},
202202
{"R_cekDecryptionFailed", "Failed to decrypt a column encryption key using key store provider: {0}."},
203203
{"R_connectTimedOut", "connect timed out"},
204-
{"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"}};
204+
{"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"},
205+
{"R_failedFedauth", "Failed to acquire fedauth token: "}};
205206
}

src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
import java.util.concurrent.CompletableFuture;
4040
import java.util.concurrent.ExecutorService;
4141
import java.util.concurrent.Executors;
42+
import java.util.concurrent.TimeUnit;
43+
44+
import org.junit.Assert;
4245

4346
import com.microsoft.aad.msal4j.ClientCredentialFactory;
4447
import com.microsoft.aad.msal4j.ClientCredentialParameters;
@@ -91,6 +94,8 @@ public final class TestUtils {
9194
static final int ENGINE_EDITION_FOR_SQL_AZURE_DW = 6;
9295
static final int ENGINE_EDITION_FOR_SQL_AZURE_MI = 8;
9396

97+
public static final int TEST_TOKEN_EXPIRY_SECONDS = 120; // token expiry time in secs
98+
9499
static String applicationKey;
95100
static String applicationClientID;
96101

@@ -108,18 +113,20 @@ public final class TestUtils {
108113
public static boolean expireTokenToggle = false;
109114

110115
public static final SQLServerAccessTokenCallback accessTokenCallback = new SQLServerAccessTokenCallback() {
111-
@Override public SqlAuthenticationToken getAccessToken(String spn, String stsurl) {
116+
@Override
117+
public SqlAuthenticationToken getAccessToken(String spn, String stsurl) {
112118
String scope = spn + "/.default";
113119
Set<String> scopes = new HashSet<>();
114120
scopes.add(scope);
115121

116122
try {
117123
ExecutorService executorService = Executors.newSingleThreadExecutor();
118124
IClientCredential credential = ClientCredentialFactory.createFromSecret(applicationKey);
119-
ConfidentialClientApplication clientApplication = ConfidentialClientApplication.builder(
120-
applicationClientID, credential).executorService(executorService).authority(stsurl).build();
121-
CompletableFuture<IAuthenticationResult> future = clientApplication.acquireToken(
122-
ClientCredentialParameters.builder(scopes).build());
125+
ConfidentialClientApplication clientApplication = ConfidentialClientApplication
126+
.builder(applicationClientID, credential).executorService(executorService).authority(stsurl)
127+
.build();
128+
CompletableFuture<IAuthenticationResult> future = clientApplication
129+
.acquireToken(ClientCredentialParameters.builder(scopes).build());
123130

124131
IAuthenticationResult authenticationResult = future.get();
125132
String accessToken = authenticationResult.accessToken();
@@ -139,6 +146,21 @@ public final class TestUtils {
139146
}
140147
};
141148

149+
public static void setAccessTokenExpiry(Object con, String accessToken) {
150+
Field fedAuthTokenField;
151+
try {
152+
fedAuthTokenField = SQLServerConnection.class.getDeclaredField("fedAuthToken");
153+
fedAuthTokenField.setAccessible(true);
154+
155+
Date newExpiry = new Date(
156+
System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(TEST_TOKEN_EXPIRY_SECONDS));
157+
SqlAuthenticationToken newFedAuthToken = new SqlAuthenticationToken(accessToken, newExpiry);
158+
fedAuthTokenField.set(con, newFedAuthToken);
159+
} catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) {
160+
Assert.fail("Failed to set token expiry: " + e.getMessage());
161+
}
162+
}
163+
142164
private TestUtils() {}
143165

144166
/**
@@ -440,7 +462,8 @@ public static void dropTypeIfExists(String typeName, java.sql.Statement stmt) th
440462
* @throws SQLException
441463
*/
442464
public static void dropUserDefinedTypeIfExists(String typeName, Statement stmt) throws SQLException {
443-
stmt.executeUpdate("IF EXISTS (select * from sys.types where name = '" + escapeSingleQuotes(typeName) + "') DROP TYPE " + typeName);
465+
stmt.executeUpdate("IF EXISTS (select * from sys.types where name = '" + escapeSingleQuotes(typeName)
466+
+ "') DROP TYPE " + typeName);
444467
}
445468

446469
/**
@@ -1029,12 +1052,13 @@ private static java.security.cert.Certificate getCertificate(String certname) th
10291052
}
10301053
}
10311054

1032-
public static String getConnectionID(SQLServerPooledConnection pc) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
1055+
public static String getConnectionID(
1056+
SQLServerPooledConnection pc) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
10331057
Class<?> pooledConnection = Class.forName("com.microsoft.sqlserver.jdbc.SQLServerPooledConnection");
10341058
Class<?> connection = Class.forName("com.microsoft.sqlserver.jdbc.SQLServerConnection");
10351059

10361060
Field physicalConnection = pooledConnection.getDeclaredField("physicalConnection");
1037-
Field traceID = connection.getDeclaredField("traceID");
1061+
Field traceID = connection.getDeclaredField("traceID");
10381062

10391063
physicalConnection.setAccessible(true);
10401064
traceID.setAccessible(true);

src/test/java/com/microsoft/sqlserver/jdbc/fedauth/ConnectionSuspensionTest.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,13 @@ private void testAccessTokenExpiredThenCreateNewStatement(SqlAuthentication auth
8383
TestUtils.dropTableIfExists(charTable, stmt);
8484
}
8585

86+
TestUtils.setAccessTokenExpiry(connection, accessToken);
87+
secondsBeforeExpiration = TestUtils.TEST_TOKEN_EXPIRY_SECONDS;
8688
while (secondsPassed < secondsBeforeExpiration) {
87-
Thread.sleep(TimeUnit.MINUTES.toMillis(5)); // Sleep for 2 minutes
89+
Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // Sleep for 90s
8890

8991
secondsPassed = (System.currentTimeMillis() - start) / 1000;
92+
9093
try (Statement stmt1 = connection.createStatement()) {
9194
testUserName(connection, azureUserName, authentication);
9295

@@ -150,10 +153,12 @@ private void testAccessTokenExpiredThenExecuteUsingSameStatement(
150153
TestUtils.dropTableIfExists(charTable, stmt);
151154
}
152155

156+
TestUtils.setAccessTokenExpiry(connection, accessToken);
157+
secondsBeforeExpiration = TestUtils.TEST_TOKEN_EXPIRY_SECONDS;
153158
while (secondsPassed < secondsBeforeExpiration) {
154-
Thread.sleep(TimeUnit.MINUTES.toMillis(5)); // Sleep for 5 minutes
155-
159+
Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // Sleep for 90s
156160
secondsPassed = (System.currentTimeMillis() - start) / 1000;
161+
157162
testUserName(connection, azureUserName, authentication);
158163
}
159164

0 commit comments

Comments
 (0)