|
6 | 6 | package com.microsoft.sqlserver.jdbc; |
7 | 7 |
|
8 | 8 | import java.security.InvalidKeyException; |
| 9 | +import java.security.MessageDigest; |
9 | 10 | import java.security.NoSuchAlgorithmException; |
10 | 11 | import java.text.MessageFormat; |
11 | 12 | import java.util.Arrays; |
| 13 | +import java.util.HashMap; |
12 | 14 | import java.util.Optional; |
13 | 15 | import java.util.Iterator; |
14 | 16 | import java.util.List; |
| 17 | +import java.util.concurrent.locks.Lock; |
| 18 | +import java.util.concurrent.locks.ReentrantLock; |
15 | 19 |
|
16 | 20 | import javax.crypto.Mac; |
17 | 21 | import javax.crypto.spec.SecretKeySpec; |
18 | 22 |
|
19 | 23 | import com.azure.core.credential.AccessToken; |
| 24 | +import com.azure.core.credential.TokenCredential; |
20 | 25 | import com.azure.core.credential.TokenRequestContext; |
21 | 26 | import com.azure.identity.ManagedIdentityCredential; |
22 | 27 | import com.azure.identity.ManagedIdentityCredentialBuilder; |
@@ -46,6 +51,11 @@ class SQLServerSecurityUtility { |
46 | 51 | // Environment variable for additionally allowed tenants. The tenantIds are comma delimited |
47 | 52 | private static final String ADDITIONALLY_ALLOWED_TENANTS = "ADDITIONALLY_ALLOWED_TENANTS"; |
48 | 53 |
|
| 54 | + // Credential Cache for ManagedIdentityCredential and DefaultAzureCredential |
| 55 | + private static final HashMap<String, Credential> CREDENTIAL_CACHE = new HashMap<>(); |
| 56 | + |
| 57 | + private static final Lock CREDENTIAL_LOCK = new ReentrantLock(); |
| 58 | + |
49 | 59 | private SQLServerSecurityUtility() { |
50 | 60 | throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); |
51 | 61 | } |
@@ -331,16 +341,35 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer |
331 | 341 | */ |
332 | 342 | static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, |
333 | 343 | String managedIdentityClientId) throws SQLServerException { |
334 | | - ManagedIdentityCredential mic = null; |
335 | 344 |
|
336 | 345 | if (logger.isLoggable(java.util.logging.Level.FINEST)) { |
337 | 346 | logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId); |
338 | 347 | } |
339 | 348 |
|
340 | | - if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { |
341 | | - mic = new ManagedIdentityCredentialBuilder().clientId(managedIdentityClientId).build(); |
342 | | - } else { |
343 | | - mic = new ManagedIdentityCredentialBuilder().build(); |
| 349 | + String key = getHashedSecret( |
| 350 | + new String[] {managedIdentityClientId, ManagedIdentityCredential.class.getSimpleName()}); |
| 351 | + ManagedIdentityCredential mic = (ManagedIdentityCredential) getCredentialFromCache(key); |
| 352 | + |
| 353 | + if (null == mic) { |
| 354 | + CREDENTIAL_LOCK.lock(); |
| 355 | + |
| 356 | + try { |
| 357 | + mic = (ManagedIdentityCredential) getCredentialFromCache(key); |
| 358 | + if (null == mic) { |
| 359 | + ManagedIdentityCredentialBuilder micBuilder = new ManagedIdentityCredentialBuilder(); |
| 360 | + |
| 361 | + if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { |
| 362 | + mic = micBuilder.clientId(managedIdentityClientId).build(); |
| 363 | + } else { |
| 364 | + mic = micBuilder.build(); |
| 365 | + } |
| 366 | + |
| 367 | + Credential credential = new Credential(mic); |
| 368 | + CREDENTIAL_CACHE.put(key, credential); |
| 369 | + } |
| 370 | + } finally { |
| 371 | + CREDENTIAL_LOCK.unlock(); |
| 372 | + } |
344 | 373 | } |
345 | 374 |
|
346 | 375 | TokenRequestContext tokenRequestContext = new TokenRequestContext(); |
@@ -383,22 +412,49 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, |
383 | 412 | String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS); |
384 | 413 | String[] additionallyAllowedTenants = getAdditonallyAllowedTenants(); |
385 | 414 |
|
386 | | - DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder(); |
387 | | - DefaultAzureCredential dac = null; |
| 415 | + int secretsLength = null == additionallyAllowedTenants ? 3 : additionallyAllowedTenants.length + 3; |
| 416 | + String[] secrets = new String[secretsLength]; |
388 | 417 |
|
389 | | - if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { |
390 | | - dacBuilder.managedIdentityClientId(managedIdentityClientId); |
| 418 | + if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) { |
| 419 | + System.arraycopy(additionallyAllowedTenants, 0, secrets, 3, additionallyAllowedTenants.length); |
391 | 420 | } |
392 | 421 |
|
393 | | - if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) { |
394 | | - dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath); |
395 | | - } |
| 422 | + secrets[0] = DefaultAzureCredential.class.getSimpleName(); |
| 423 | + secrets[1] = managedIdentityClientId; |
| 424 | + secrets[2] = intellijKeepassPath; |
396 | 425 |
|
397 | | - if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) { |
398 | | - dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants); |
399 | | - } |
| 426 | + String key = getHashedSecret(secrets); |
| 427 | + DefaultAzureCredential dac = (DefaultAzureCredential) getCredentialFromCache(key); |
| 428 | + |
| 429 | + if (null == dac) { |
| 430 | + CREDENTIAL_LOCK.lock(); |
| 431 | + |
| 432 | + try { |
| 433 | + dac = (DefaultAzureCredential) getCredentialFromCache(key); |
| 434 | + if (null == dac) { |
| 435 | + DefaultAzureCredentialBuilder dacBuilder = new DefaultAzureCredentialBuilder(); |
| 436 | + |
| 437 | + if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { |
| 438 | + dacBuilder.managedIdentityClientId(managedIdentityClientId); |
| 439 | + } |
| 440 | + |
| 441 | + if (null != intellijKeepassPath && !intellijKeepassPath.isEmpty()) { |
| 442 | + dacBuilder.intelliJKeePassDatabasePath(intellijKeepassPath); |
| 443 | + } |
| 444 | + |
| 445 | + if (null != additionallyAllowedTenants && additionallyAllowedTenants.length != 0) { |
| 446 | + dacBuilder.additionallyAllowedTenants(additionallyAllowedTenants); |
| 447 | + } |
| 448 | + |
| 449 | + dac = dacBuilder.build(); |
400 | 450 |
|
401 | | - dac = dacBuilder.build(); |
| 451 | + Credential credential = new Credential(dac); |
| 452 | + CREDENTIAL_CACHE.put(key, credential); |
| 453 | + } |
| 454 | + } finally { |
| 455 | + CREDENTIAL_LOCK.unlock(); |
| 456 | + } |
| 457 | + } |
402 | 458 |
|
403 | 459 | TokenRequestContext tokenRequestContext = new TokenRequestContext(); |
404 | 460 | String scope = resource.endsWith(SQLServerMSAL4JUtils.SLASH_DEFAULT) ? resource : resource |
@@ -430,4 +486,36 @@ private static String[] getAdditonallyAllowedTenants() { |
430 | 486 |
|
431 | 487 | return null; |
432 | 488 | } |
| 489 | + |
| 490 | + private static TokenCredential getCredentialFromCache(String key) { |
| 491 | + Credential credential = CREDENTIAL_CACHE.get(key); |
| 492 | + |
| 493 | + if (null != credential) { |
| 494 | + return credential.tokenCredential; |
| 495 | + } |
| 496 | + |
| 497 | + return null; |
| 498 | + } |
| 499 | + |
| 500 | + private static class Credential { |
| 501 | + TokenCredential tokenCredential; |
| 502 | + |
| 503 | + public Credential(TokenCredential tokenCredential) { |
| 504 | + this.tokenCredential = tokenCredential; |
| 505 | + } |
| 506 | + } |
| 507 | + |
| 508 | + private static String getHashedSecret(String[] secrets) throws SQLServerException { |
| 509 | + try { |
| 510 | + MessageDigest md = MessageDigest.getInstance("SHA-256"); |
| 511 | + for (String secret : secrets) { |
| 512 | + if (null != secret) { |
| 513 | + md.update(secret.getBytes(java.nio.charset.StandardCharsets.UTF_16LE)); |
| 514 | + } |
| 515 | + } |
| 516 | + return new String(md.digest()); |
| 517 | + } catch (NoSuchAlgorithmException e) { |
| 518 | + throw new SQLServerException(SQLServerException.getErrString("R_NoSHA256Algorithm"), e); |
| 519 | + } |
| 520 | + } |
433 | 521 | } |
0 commit comments