Skip to content
Merged
1 change: 1 addition & 0 deletions src/main/java/com/microsoft/sqlserver/jdbc/IOBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ final class TDS {
static final byte ADALWORKFLOW_ACTIVEDIRECTORYMANAGEDIDENTITY = 0x03;
static final byte ADALWORKFLOW_ACTIVEDIRECTORYINTERACTIVE = 0x03;
static final byte ADALWORKFLOW_DEFAULTAZURECREDENTIAL = 0x03;
static final byte ADALWORKFLOW_ACCESSTOKENCALLBACK = 0x03;
static final byte ADALWORKFLOW_ACTIVEDIRECTORYSERVICEPRINCIPAL = 0x01; // Using the Password byte as that is the
// closest we have.
static final byte FEDAUTH_INFO_ID_STSURL = 0x01; // FedAuthInfoData is token endpoint URL from which to acquire fed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,13 +411,18 @@ CallableStatement prepareCall(String sql, int nType, int nConcur, int nHold,
/**
* Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens.
* This method will always return 0 and is for backwards compatibility only.
*
* @return Method will always return 0.
*/
@Deprecated
int getMsiTokenCacheTtl();

/**
* Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens.
* This method is a no-op for backwards compatibility only.
*
* @param timeToLive
* Time-to-live is no longer supported.
*/
@Deprecated
void setMsiTokenCacheTtl(int timeToLive);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1216,14 +1216,34 @@ public interface ISQLServerDataSource extends javax.sql.CommonDataSource {
/**
* Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens.
* This method is a no-op for backwards compatibility only.
*
* @param timeToLive
* - Time-to-live is no longer supported.
*/
@Deprecated
void setMsiTokenCacheTtl(int timeToLive);

/**
* Deprecated. Time-to-live is no longer supported for the cached Managed Identity tokens.
* This method will always return 0 and is for backwards compatibility only.
*
* @return Method will always return 0.
*/
@Deprecated
int getMsiTokenCacheTtl();

/**
* Sets the {@link SQLServerAccessTokenCallback} delegate.
*
* @param accessTokenCallback
* Access token callback delegate.
*/
void setAccessTokenCallback(SQLServerAccessTokenCallback accessTokenCallback);

/**
* Returns a {@link SQLServerAccessTokenCallback}, the access token callback delegate.
*
* @return Access token callback delegate.
*/
SQLServerAccessTokenCallback getAccessTokenCallback();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.microsoft.sqlserver.jdbc;

/**
* Provides SqlAuthenticationToken callback to be implemented by client code.
*/
public interface SQLServerAccessTokenCallback {

/**
* For an example of callback usage, look under the project's code samples.
*
* Returns the access token for the authentication request
*
* @param stsurl
* - Security token service URL.
* @param spn
* - Service principal name.
*
* @return Returns a {@link SqlAuthenticationToken}.
*/
SqlAuthenticationToken getAccessToken(String stsurl, String spn);
}
84 changes: 57 additions & 27 deletions src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,7 @@
import java.text.MessageFormat;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
Expand Down Expand Up @@ -162,7 +153,7 @@ public class SQLServerConnection implements ISQLServerConnection, java.io.Serial
private byte[] accessTokenInByte = null;

/** fedAuth token */
private SqlFedAuthToken fedAuthToken = null;
private SqlAuthenticationToken fedAuthToken = null;

/** original hostNameInCertificate */
private String originalHostNameInCertificate = null;
Expand Down Expand Up @@ -537,6 +528,13 @@ class FederatedAuthenticationFeatureExtensionData implements Serializable {
this.authentication = SqlAuthentication.ActiveDirectoryInteractive;
break;
default:
// If authenticationString not specified, check if access token callback was set.
// If access token callback is set, break.
if (null != activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) {
this.authentication = SqlAuthentication.NotSpecified;
break;
}
assert (false);
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_InvalidConnectionSetting"));
Expand Down Expand Up @@ -1669,7 +1667,7 @@ void checkClosed() throws SQLServerException {
* @return true/false
*/
protected boolean needsReconnect() {
return (null != fedAuthToken && Util.checkIfNeedNewAccessToken(this, fedAuthToken.expiresOn));
return (null != fedAuthToken && Util.checkIfNeedNewAccessToken(this, fedAuthToken.getExpiresOn()));
}

/**
Expand Down Expand Up @@ -2423,6 +2421,17 @@ Connection connectInternal(Properties propsIn,
ntlmAuthentication = true;
}

SQLServerAccessTokenCallback callback = (SQLServerAccessTokenCallback) activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString());

if (null != callback && (!activeConnectionProperties
.getProperty(SQLServerDriverStringProperty.USER.toString()).isEmpty()
|| !activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString())
.isEmpty())) {
throw new SQLServerException(
SQLServerException.getErrString("R_AccessTokenCallbackWithUserPassword"), null);
}

sPropKey = SQLServerDriverStringProperty.AUTHENTICATION.toString();
sPropValue = activeConnectionProperties.getProperty(sPropKey);
if (null == sPropValue) {
Expand All @@ -2434,7 +2443,7 @@ Connection connectInternal(Properties propsIn,
&& (!activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString())
.isEmpty())) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_MSIAuthenticationWithPassword"));
SQLServerException.getErrString("R_ManagedIdentityAuthenticationWithPassword"));
throw new SQLServerException(form.format(new Object[] {authenticationString}), null);
}

Expand Down Expand Up @@ -2466,7 +2475,7 @@ Connection connectInternal(Properties propsIn,
&& (!activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString())
.isEmpty())) {
MessageFormat form = new MessageFormat(
SQLServerException.getErrString("R_MSIAuthenticationWithPassword"));
SQLServerException.getErrString("R_ManagedIdentityAuthenticationWithPassword"));
throw new SQLServerException(form.format(new Object[] {authenticationString}), null);
}

Expand Down Expand Up @@ -3464,7 +3473,8 @@ private void executeReconnect(LogonCommand logonCommand) throws SQLServerExcepti
void prelogin(String serverName, int portNumber) throws SQLServerException {
// Build a TDS Pre-Login packet to send to the server.
if ((!authenticationString.equalsIgnoreCase(SqlAuthentication.NotSpecified.toString()))
|| (null != accessTokenInByte)) {
|| (null != accessTokenInByte) || null != activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) {
fedAuthRequiredByUser = true;
}

Expand Down Expand Up @@ -3846,7 +3856,8 @@ void prelogin(String serverName, int portNumber) throws SQLServerException {
// Or AccessToken is not null, mean token based authentication is used.
if (((null != authenticationString)
&& (!authenticationString.equalsIgnoreCase(SqlAuthentication.NotSpecified.toString())))
|| (null != accessTokenInByte)) {
|| (null != accessTokenInByte) || null != activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) {
fedAuthRequiredPreLoginResponse = (preloginResponse[optionOffset] == 1);
}
break;
Expand Down Expand Up @@ -4790,6 +4801,12 @@ int writeFedAuthFeatureRequest(boolean write, /* if false just calculates the le
workflow = TDS.ADALWORKFLOW_ACTIVEDIRECTORYSERVICEPRINCIPAL;
break;
default:
// If not specified, check if access token callback was set. If it is set, break.
if (null != activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) {
workflow = TDS.ADALWORKFLOW_ACCESSTOKENCALLBACK;
break;
}
assert (false); // Unrecognized Authentication type for fedauth ADAL request
break;
}
Expand Down Expand Up @@ -5005,7 +5022,9 @@ private void logon(LogonCommand command) throws SQLServerException {
.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryServicePrincipal.toString())
|| authenticationString
.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryInteractive.toString()))
&& fedAuthRequiredPreLoginResponse)) {
&& fedAuthRequiredPreLoginResponse)
|| null != activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString())) {
federatedAuthenticationInfoRequested = true;
fedAuthFeatureExtensionData = new FederatedAuthenticationFeatureExtensionData(TDS.TDS_FEDAUTH_LIBRARY_ADAL,
authenticationString, fedAuthRequiredPreLoginResponse);
Expand Down Expand Up @@ -5499,9 +5518,9 @@ final class FedAuthTokenCommand extends UninterruptableTDSCommand {
// Always update serialVersionUID when prompted.
private static final long serialVersionUID = 1L;
TDSTokenHandler tdsTokenHandler = null;
SqlFedAuthToken sqlFedAuthToken = null;
SqlAuthenticationToken sqlFedAuthToken = null;

FedAuthTokenCommand(SqlFedAuthToken sqlFedAuthToken, TDSTokenHandler tdsTokenHandler) {
FedAuthTokenCommand(SqlAuthenticationToken sqlFedAuthToken, TDSTokenHandler tdsTokenHandler) {
super("FedAuth");
this.tdsTokenHandler = tdsTokenHandler;
this.sqlFedAuthToken = sqlFedAuthToken;
Expand Down Expand Up @@ -5530,7 +5549,16 @@ void onFedAuthInfo(SqlFedAuthInfo fedAuthInfo, TDSTokenHandler tdsTokenHandler)
assert null != fedAuthInfo;

attemptRefreshTokenLocked = true;
fedAuthToken = getFedAuthToken(fedAuthInfo);

SQLServerAccessTokenCallback callback = (SQLServerAccessTokenCallback) activeConnectionProperties
.get(SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString());

if (authenticationString.equals(SqlAuthentication.NotSpecified.toString()) && null != callback) {
fedAuthToken = callback.getAccessToken(fedAuthInfo.spn, fedAuthInfo.stsurl);
} else {
fedAuthToken = getFedAuthToken(fedAuthInfo);
}

attemptRefreshTokenLocked = false;

// fedAuthToken cannot be null.
Expand All @@ -5540,8 +5568,8 @@ void onFedAuthInfo(SqlFedAuthInfo fedAuthInfo, TDSTokenHandler tdsTokenHandler)
fedAuthCommand.execute(tdsChannel.getWriter(), tdsChannel.getReader(fedAuthCommand));
}

private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLServerException {
SqlFedAuthToken fedAuthToken = null;
private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLServerException {
SqlAuthenticationToken fedAuthToken = null;

// fedAuthInfo should not be null.
assert null != fedAuthInfo;
Expand All @@ -5552,7 +5580,7 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe
int sleepInterval = 100;

if (!msalContextExists()
&& !authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryInteractive.toString())) {
&& !authenticationString.equalsIgnoreCase(SqlAuthentication.ActiveDirectoryIntegrated.toString())) {
MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MSALMissing"));
throw new SQLServerException(form.format(new Object[] {authenticationString}), null, 0, null);
}
Expand Down Expand Up @@ -5613,7 +5641,9 @@ private SqlFedAuthToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throws SQLSe
byte[] accessTokenFromDLL = dllInfo.accessTokenBytes;

String accessToken = new String(accessTokenFromDLL, UTF_16LE);
fedAuthToken = new SqlFedAuthToken(accessToken, dllInfo.expiresIn);
Date now = new Date();
now.setTime(now.getTime() + (dllInfo.expiresIn * 1000));
fedAuthToken = new SqlAuthenticationToken(accessToken, now);

// Break out of the retry loop in successful case.
break;
Expand Down Expand Up @@ -5721,18 +5751,18 @@ private boolean msalContextExists() {
/**
* Send the access token to the server.
*/
private void sendFedAuthToken(FedAuthTokenCommand fedAuthCommand, SqlFedAuthToken fedAuthToken,
private void sendFedAuthToken(FedAuthTokenCommand fedAuthCommand, SqlAuthenticationToken fedAuthToken,
TDSTokenHandler tdsTokenHandler) throws SQLServerException {
assert null != fedAuthToken;
assert null != fedAuthToken.accessToken;
assert null != fedAuthToken.getAccessToken();

if (connectionlogger.isLoggable(Level.FINER)) {
connectionlogger.fine(toString() + " Sending federated authentication token.");
}

TDSWriter tdsWriter = fedAuthCommand.startRequest(TDS.PKT_FEDAUTH_TOKEN_MESSAGE);

byte[] accessToken = fedAuthToken.accessToken.getBytes(UTF_16LE);
byte[] accessToken = fedAuthToken.getAccessToken().getBytes(UTF_16LE);

// Send total length (length of token plus 4 bytes for the token length field)
// If we were sending a nonce, this would include that length as well
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,30 @@ public int getMsiTokenCacheTtl() {
return 0;
}

/**
* Sets the {@link SQLServerAccessTokenCallback} delegate.
*
* @param accessTokenCallback
* Access token callback delegate.
*/
@Override
public void setAccessTokenCallback(SQLServerAccessTokenCallback accessTokenCallback) {
setObjectProperty(connectionProps, SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString(),
accessTokenCallback);
}

/**
* Returns a {@link SQLServerAccessTokenCallback}, the access token callback delegate.
*
* @return Access token callback delegate.
*/
@Override
public SQLServerAccessTokenCallback getAccessTokenCallback() {
return (SQLServerAccessTokenCallback) getObjectProperty(connectionProps,
SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.toString(),
SQLServerDriverObjectProperty.ACCESS_TOKEN_CALLBACK.getDefaultValue());
}

/**
* Sets a property string value.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ enum SqlAuthentication {
SqlPassword,
ActiveDirectoryPassword,
ActiveDirectoryIntegrated,
ActiveDirectoryMSI,
ActiveDirectoryManagedIdentity,
ActiveDirectoryServicePrincipal,
ActiveDirectoryInteractive,
Expand Down Expand Up @@ -395,7 +394,8 @@ static ApplicationIntent valueOfString(String value) throws SQLServerException {


enum SQLServerDriverObjectProperty {
GSS_CREDENTIAL("gsscredential", null);
GSS_CREDENTIAL("gsscredential", null),
ACCESS_TOKEN_CALLBACK("accessTokenCallback", null);

private final String name;
private final String defaultValue;
Expand Down
Loading