Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ final InetSocketAddress open(String host, int port, int timeoutMillis, boolean u
/**
* Set TCP keep-alive options for idle connection resiliency
*/
private void setSocketOptions(Socket tcpSocket, TDSChannel channel) throws IOException {
private void setSocketOptions(Socket tcpSocket, TDSChannel channel) {
try {
if (SQLServerDriver.socketSetOptionMethod != null && SQLServerDriver.socketKeepIdleOption != null
&& SQLServerDriver.socketKeepIntervalOption != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1605,7 +1605,7 @@ SQLServerPooledConnection getPooledConnectionParent() {
return pooledConnectionParent;
}

SQLServerConnection(String parentInfo) throws SQLServerException {
SQLServerConnection(String parentInfo) {
int connectionID = nextConnectionID(); // sequential connection id
traceID = "ConnectionID:" + connectionID;
loggingClassName += ":" + connectionID;
Expand Down Expand Up @@ -6963,7 +6963,7 @@ public java.sql.Struct createStruct(String typeName, Object[] attributes) throws
return null;
}

String getTrustedServerNameAE() throws SQLServerException {
String getTrustedServerNameAE() {
return trustedServerNameAE.toUpperCase();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ private String generateAzureDWSelect(ResultSet rs, Map<Integer, String> columns)
return sb.toString();
}

private String generateAzureDWEmptyRS(Map<Integer, String> columns) throws SQLException {
private String generateAzureDWEmptyRS(Map<Integer, String> columns) {
StringBuilder sb = new StringBuilder("SELECT TOP 0 ");
for (Entry<Integer, String> p : columns.entrySet()) {
sb.append("NULL AS ").append(p.getValue()).append(",");
Expand Down Expand Up @@ -950,7 +950,7 @@ public java.sql.ResultSet getBestRowIdentifier(String catalog, String schema, St

@Override
public java.sql.ResultSet getCrossReference(String cat1, String schem1, String tab1, String cat2, String schem2,
String tab2) throws SQLException, SQLTimeoutException {
String tab2) throws SQLException {
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString());
}
Expand Down Expand Up @@ -1014,8 +1014,7 @@ public String getDriverVersion() throws SQLServerException {
}

@Override
public java.sql.ResultSet getExportedKeys(String cat, String schema,
String table) throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getExportedKeys(String cat, String schema, String table) throws SQLException {
return getCrossReference(cat, schema, table, null, null, null);
}

Expand All @@ -1032,12 +1031,11 @@ public String getIdentifierQuoteString() throws SQLServerException {
}

@Override
public java.sql.ResultSet getImportedKeys(String cat, String schema,
String table) throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getImportedKeys(String cat, String schema, String table) throws SQLException {
return getCrossReference(null, null, null, cat, schema, table);
}

private ResultSet executeSPFkeys(String[] procParams) throws SQLException, SQLTimeoutException {
private ResultSet executeSPFkeys(String[] procParams) throws SQLException {
if (!this.connection.isAzureDW()) {
String tempTableName = "@jdbc_temp_fkeys_result";
String sql = "DECLARE " + tempTableName + " table (PKTABLE_QUALIFIER sysname, " + "PKTABLE_OWNER sysname, "
Expand Down Expand Up @@ -1219,7 +1217,7 @@ public int getMaxColumnsInTable() throws SQLServerException {
}

@Override
public int getMaxConnections() throws SQLException, SQLTimeoutException {
public int getMaxConnections() throws SQLException {
checkClosed();
try (SQLServerResultSet rs = getResultSetFromInternalQueries(null,
"select maximum from sys.configurations where name = 'user connections'")) {
Expand Down Expand Up @@ -1439,7 +1437,7 @@ public ResultSet getPseudoColumns(String catalog, String schemaPattern, String t
}

@Override
public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getSchemas() throws SQLException {
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString());
}
Expand All @@ -1448,8 +1446,7 @@ public java.sql.ResultSet getSchemas() throws SQLException, SQLTimeoutException

}

private java.sql.ResultSet getSchemasInternal(String catalog,
String schemaPattern) throws SQLException, SQLTimeoutException {
private java.sql.ResultSet getSchemasInternal(String catalog, String schemaPattern) throws SQLException {

String s;
// The schemas that return null for catalog name, these are prebuilt
Expand Down Expand Up @@ -1606,14 +1603,13 @@ public java.sql.ResultSet getTablePrivileges(String catalog, String schema,
}

@Override
public java.sql.ResultSet getTableTypes() throws SQLException, SQLTimeoutException {
public java.sql.ResultSet getTableTypes() throws SQLException {
if (loggerExternal.isLoggable(Level.FINER) && Util.isActivityTraceOn()) {
loggerExternal.finer(toString() + ACTIVITY_ID + ActivityCorrelator.getNext().toString());
}
checkClosed();
String s = "SELECT 'VIEW' 'TABLE_TYPE' UNION SELECT 'TABLE' UNION SELECT 'SYSTEM TABLE'";
SQLServerResultSet rs = getResultSetFromInternalQueries(null, s);
return rs;
return getResultSetFromInternalQueries(null, s);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.microsoft.aad.msal4j.IntegratedWindowsAuthenticationParameters;
import com.microsoft.aad.msal4j.InteractiveRequestParameters;
import com.microsoft.aad.msal4j.MsalInteractionRequiredException;
import com.microsoft.aad.msal4j.MsalThrottlingException;
import com.microsoft.aad.msal4j.PublicClientApplication;
import com.microsoft.aad.msal4j.SilentParameters;
import com.microsoft.aad.msal4j.SystemBrowserOptions;
Expand Down Expand Up @@ -74,7 +75,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, user, authenticationString);
} finally {
executorService.shutdown();
Expand Down Expand Up @@ -108,7 +109,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, aadPrincipalID, authenticationString);
} finally {
executorService.shutdown();
Expand Down Expand Up @@ -150,7 +151,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, "", authenticationString);
} finally {
executorService.shutdown();
Expand All @@ -177,11 +178,14 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu
StringBuilder acc = new StringBuilder();
if (accountsInCache != null) {
for (IAccount account : accountsInCache) {
if (acc.length() != 0) acc.append(", ");
if (acc.length() != 0) {
acc.append(", ");
}
acc.append(account.username());
}
}
logger.fine(logger.toString() + "Accounts in cache = " + acc + ", size = " + (accountsInCache == null ? null : accountsInCache.size()) + ", user = " + user);
logger.fine(logger.toString() + "Accounts in cache = " + acc + ", size = "
+ (accountsInCache == null ? null : accountsInCache.size()) + ", user = " + user);
}
if (null != accountsInCache && !accountsInCache.isEmpty() && null != user && !user.isEmpty()) {
IAccount account = getAccountByUsername(accountsInCache, user);
Expand Down Expand Up @@ -228,7 +232,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu
Thread.currentThread().interrupt();

throw new SQLServerException(e.getMessage(), e);
} catch (ExecutionException e) {
} catch (MsalThrottlingException | ExecutionException e) {
throw getCorrectedException(e, user, authenticationString);
} finally {
executorService.shutdown();
Expand All @@ -247,8 +251,7 @@ private static IAccount getAccountByUsername(Set<IAccount> accounts, String user
return null;
}

private static SQLServerException getCorrectedException(ExecutionException e, String user,
String authenticationString) {
private static SQLServerException getCorrectedException(Exception e, String user, String authenticationString) {
Object[] msgArgs = {user, authenticationString};

if (null == e.getCause() || null == e.getCause().getMessage()) {
Expand Down
3 changes: 2 additions & 1 deletion src/test/java/com/microsoft/sqlserver/jdbc/TestResource.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,5 +201,6 @@ protected Object[][] getContents() {
{"R_objectNullOrEmpty", "The {0} is null or empty."},
{"R_cekDecryptionFailed", "Failed to decrypt a column encryption key using key store provider: {0}."},
{"R_connectTimedOut", "connect timed out"},
{"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"}};
{"R_sessionKilled", "Cannot continue the execution because the session is in the kill state"},
{"R_failedFedauth", "Failed to acquire fedauth token: "}};
}
40 changes: 32 additions & 8 deletions src/test/java/com/microsoft/sqlserver/jdbc/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import org.junit.Assert;

import com.microsoft.aad.msal4j.ClientCredentialFactory;
import com.microsoft.aad.msal4j.ClientCredentialParameters;
Expand Down Expand Up @@ -91,6 +94,8 @@ public final class TestUtils {
static final int ENGINE_EDITION_FOR_SQL_AZURE_DW = 6;
static final int ENGINE_EDITION_FOR_SQL_AZURE_MI = 8;

public static final int TEST_TOKEN_EXPIRY_SECONDS = 120; // token expiry time in secs

static String applicationKey;
static String applicationClientID;

Expand All @@ -108,18 +113,20 @@ public final class TestUtils {
public static boolean expireTokenToggle = false;

public static final SQLServerAccessTokenCallback accessTokenCallback = new SQLServerAccessTokenCallback() {
@Override public SqlAuthenticationToken getAccessToken(String spn, String stsurl) {
@Override
public SqlAuthenticationToken getAccessToken(String spn, String stsurl) {
String scope = spn + "/.default";
Set<String> scopes = new HashSet<>();
scopes.add(scope);

try {
ExecutorService executorService = Executors.newSingleThreadExecutor();
IClientCredential credential = ClientCredentialFactory.createFromSecret(applicationKey);
ConfidentialClientApplication clientApplication = ConfidentialClientApplication.builder(
applicationClientID, credential).executorService(executorService).authority(stsurl).build();
CompletableFuture<IAuthenticationResult> future = clientApplication.acquireToken(
ClientCredentialParameters.builder(scopes).build());
ConfidentialClientApplication clientApplication = ConfidentialClientApplication
.builder(applicationClientID, credential).executorService(executorService).authority(stsurl)
.build();
CompletableFuture<IAuthenticationResult> future = clientApplication
.acquireToken(ClientCredentialParameters.builder(scopes).build());

IAuthenticationResult authenticationResult = future.get();
String accessToken = authenticationResult.accessToken();
Expand All @@ -139,6 +146,21 @@ public final class TestUtils {
}
};

public static void setAccessTokenExpiry(Object con, String accessToken) {
Field fedAuthTokenField;
try {
fedAuthTokenField = SQLServerConnection.class.getDeclaredField("fedAuthToken");
fedAuthTokenField.setAccessible(true);

Date newExpiry = new Date(
System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(TEST_TOKEN_EXPIRY_SECONDS));
SqlAuthenticationToken newFedAuthToken = new SqlAuthenticationToken(accessToken, newExpiry);
fedAuthTokenField.set(con, newFedAuthToken);
} catch (NoSuchFieldException | SecurityException | IllegalArgumentException | IllegalAccessException e) {
Assert.fail("Failed to set token expiry: " + e.getMessage());
}
}

private TestUtils() {}

/**
Expand Down Expand Up @@ -440,7 +462,8 @@ public static void dropTypeIfExists(String typeName, java.sql.Statement stmt) th
* @throws SQLException
*/
public static void dropUserDefinedTypeIfExists(String typeName, Statement stmt) throws SQLException {
stmt.executeUpdate("IF EXISTS (select * from sys.types where name = '" + escapeSingleQuotes(typeName) + "') DROP TYPE " + typeName);
stmt.executeUpdate("IF EXISTS (select * from sys.types where name = '" + escapeSingleQuotes(typeName)
+ "') DROP TYPE " + typeName);
}

/**
Expand Down Expand Up @@ -1029,12 +1052,13 @@ private static java.security.cert.Certificate getCertificate(String certname) th
}
}

public static String getConnectionID(SQLServerPooledConnection pc) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
public static String getConnectionID(
SQLServerPooledConnection pc) throws ClassNotFoundException, NoSuchFieldException, IllegalAccessException {
Class<?> pooledConnection = Class.forName("com.microsoft.sqlserver.jdbc.SQLServerPooledConnection");
Class<?> connection = Class.forName("com.microsoft.sqlserver.jdbc.SQLServerConnection");

Field physicalConnection = pooledConnection.getDeclaredField("physicalConnection");
Field traceID = connection.getDeclaredField("traceID");
Field traceID = connection.getDeclaredField("traceID");

physicalConnection.setAccessible(true);
traceID.setAccessible(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,13 @@ private void testAccessTokenExpiredThenCreateNewStatement(SqlAuthentication auth
TestUtils.dropTableIfExists(charTable, stmt);
}

TestUtils.setAccessTokenExpiry(connection, accessToken);
secondsBeforeExpiration = TestUtils.TEST_TOKEN_EXPIRY_SECONDS;
while (secondsPassed < secondsBeforeExpiration) {
Thread.sleep(TimeUnit.MINUTES.toMillis(5)); // Sleep for 2 minutes
Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // Sleep for 90s

secondsPassed = (System.currentTimeMillis() - start) / 1000;

try (Statement stmt1 = connection.createStatement()) {
testUserName(connection, azureUserName, authentication);

Expand Down Expand Up @@ -150,10 +153,12 @@ private void testAccessTokenExpiredThenExecuteUsingSameStatement(
TestUtils.dropTableIfExists(charTable, stmt);
}

TestUtils.setAccessTokenExpiry(connection, accessToken);
secondsBeforeExpiration = TestUtils.TEST_TOKEN_EXPIRY_SECONDS;
while (secondsPassed < secondsBeforeExpiration) {
Thread.sleep(TimeUnit.MINUTES.toMillis(5)); // Sleep for 5 minutes

Thread.sleep(TimeUnit.SECONDS.toMillis(90)); // Sleep for 90s
secondsPassed = (System.currentTimeMillis() - start) / 1000;

testUserName(connection, azureUserName, authentication);
}

Expand Down
Loading