Skip to content

Commit c4edecd

Browse files
authored
Credential caching (#2415) (#2426)
1 parent e411e0b commit c4edecd

File tree

1 file changed

+104
-16
lines changed

1 file changed

+104
-16
lines changed

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java

Lines changed: 104 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,22 @@
66
package com.microsoft.sqlserver.jdbc;
77

88
import java.security.InvalidKeyException;
9+
import java.security.MessageDigest;
910
import java.security.NoSuchAlgorithmException;
1011
import java.text.MessageFormat;
1112
import java.util.Arrays;
13+
import java.util.HashMap;
1214
import java.util.Optional;
1315
import java.util.Iterator;
1416
import java.util.List;
17+
import java.util.concurrent.locks.Lock;
18+
import java.util.concurrent.locks.ReentrantLock;
1519

1620
import javax.crypto.Mac;
1721
import javax.crypto.spec.SecretKeySpec;
1822

1923
import com.azure.core.credential.AccessToken;
24+
import com.azure.core.credential.TokenCredential;
2025
import com.azure.core.credential.TokenRequestContext;
2126
import com.azure.identity.ManagedIdentityCredential;
2227
import com.azure.identity.ManagedIdentityCredentialBuilder;
@@ -46,6 +51,11 @@ class SQLServerSecurityUtility {
4651
// Environment variable for additionally allowed tenants. The tenantIds are comma delimited
4752
private static final String ADDITIONALLY_ALLOWED_TENANTS = "ADDITIONALLY_ALLOWED_TENANTS";
4853

54+
// Credential Cache for ManagedIdentityCredential and DefaultAzureCredential
55+
private static final HashMap<String, Credential> CREDENTIAL_CACHE = new HashMap<>();
56+
57+
private static final Lock CREDENTIAL_LOCK = new ReentrantLock();
58+
4959
private SQLServerSecurityUtility() {
5060
throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported"));
5161
}
@@ -331,16 +341,35 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer
331341
*/
332342
static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource,
333343
String managedIdentityClientId) throws SQLServerException {
334-
ManagedIdentityCredential mic = null;
335344

336345
if (logger.isLoggable(java.util.logging.Level.FINEST)) {
337346
logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId);
338347
}
339348

340-
if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
341-
mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).build();
342-
} else {
343-
mic = new ManagedIdentityCredentialBuilder().build();
349+
String key = getHashedSecret(
350+
new String[] {managedIdentityClientId, ManagedIdentityCredential.class.getSimpleName()});
351+
ManagedIdentityCredential mic = (ManagedIdentityCredential) getCredentialFromCache(key);
352+
353+
if (null == mic) {
354+
CREDENTIAL_LOCK.lock();
355+
356+
try {
357+
mic = (ManagedIdentityCredential) getCredentialFromCache(key);
358+
if (null == mic) {
359+
ManagedIdentityCredentialBuilder micBuilder = new ManagedIdentityCredentialBuilder();
360+
361+
if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
362+
mic = micBuilder.clientId(managedIdentityClientId).build();
363+
} else {
364+
mic = micBuilder.build();
365+
}
366+
367+
Credential credential = new Credential(mic);
368+
CREDENTIAL_CACHE.put(key, credential);
369+
}
370+
} finally {
371+
CREDENTIAL_LOCK.unlock();
372+
}
344373
}
345374

346375
TokenRequestContext tokenRequestContext = new TokenRequestContext();
@@ -383,22 +412,49 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource,
383412
String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS);
384413
String[] additionallyAllowedTenants = getAdditonallyAllowedTenants();
385414

386-
DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();
387-
DefaultAzureCredential dac = null;
415+
int secretsLength = null == additionallyAllowedTenants ? 3 : additionallyAllowedTenants.length + 3;
416+
String[] secrets = new String[secretsLength];
388417

389-
if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
390-
dacBuilder.managedIdentityClientId(managedIdentityClientId);
418+
if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
419+
System.arraycopy(additionallyAllowedTenants, 0, secrets, 3, additionallyAllowedTenants.length);
391420
}
392421

393-
if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
394-
dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
395-
}
422+
secrets[0] = DefaultAzureCredential.class.getSimpleName();
423+
secrets[1] = managedIdentityClientId;
424+
secrets[2] = intellijKeepassPath;
396425

397-
if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
398-
dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
399-
}
426+
String key = getHashedSecret(secrets);
427+
DefaultAzureCredential dac = (DefaultAzureCredential) getCredentialFromCache(key);
428+
429+
if (null == dac) {
430+
CREDENTIAL_LOCK.lock();
431+
432+
try {
433+
dac = (DefaultAzureCredential) getCredentialFromCache(key);
434+
if (null == dac) {
435+
DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder();
436+
437+
if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) {
438+
dacBuilder.managedIdentityClientId(managedIdentityClientId);
439+
}
440+
441+
if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) {
442+
dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath);
443+
}
444+
445+
if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) {
446+
dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants);
447+
}
448+
449+
dac = dacBuilder.build();
400450

401-
dac = dacBuilder.build();
451+
Credential credential = new Credential(dac);
452+
CREDENTIAL_CACHE.put(key, credential);
453+
}
454+
} finally {
455+
CREDENTIAL_LOCK.unlock();
456+
}
457+
}
402458

403459
TokenRequestContext tokenRequestContext = new TokenRequestContext();
404460
String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource
@@ -430,4 +486,36 @@ private static String[] getAdditonallyAllowedTenants() {
430486

431487
return null;
432488
}
489+
490+
private static TokenCredential getCredentialFromCache(String key) {
491+
Credential credential = CREDENTIAL_CACHE.get(key);
492+
493+
if (null != credential) {
494+
return credential.tokenCredential;
495+
}
496+
497+
return null;
498+
}
499+
500+
private static class Credential {
501+
TokenCredential tokenCredential;
502+
503+
public Credential(TokenCredential tokenCredential) {
504+
this.tokenCredential = tokenCredential;
505+
}
506+
}
507+
508+
private static String getHashedSecret(String[] secrets) throws SQLServerException {
509+
try {
510+
MessageDigest md = MessageDigest.getInstance("SHA-256");
511+
for (String secret : secrets) {
512+
if (null != secret) {
513+
md.update(secret.getBytes(java.nio.charset.StandardCharsets.UTF_16LE));
514+
}
515+
}
516+
return new String(md.digest());
517+
} catch (NoSuchAlgorithmException e) {
518+
throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e);
519+
}
520+
}
433521
}

0 commit comments

Comments
 (0)