|
7 | 7 | package azidentity
|
8 | 8 |
|
9 | 9 | import (
|
| 10 | + "bytes" |
10 | 11 | "context"
|
11 | 12 | "crypto/tls"
|
12 | 13 | "crypto/x509"
|
@@ -35,15 +36,6 @@ type WorkloadIdentityCredential struct {
|
35 | 36 | cred *ClientAssertionCredential
|
36 | 37 | expires time.Time
|
37 | 38 | mtx *sync.RWMutex
|
38 |
| - // identity binding mode fields |
39 |
| - identityBinding bool |
40 |
| - kubernetesCAFile string |
41 |
| - kubernetesSNIName string |
42 |
| - kubernetesTokenEndpoint *url.URL |
43 |
| - // CA certificate caching fields |
44 |
| - caCert []byte |
45 |
| - caCertPool *x509.CertPool |
46 |
| - caExpires time.Time |
47 | 39 | }
|
48 | 40 |
|
49 | 41 | // WorkloadIdentityCredentialOptions contains optional parameters for WorkloadIdentityCredential.
|
@@ -104,29 +96,20 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) (
|
104 | 96 | }
|
105 | 97 |
|
106 | 98 | w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}}
|
107 |
| - |
108 |
| - // Configure identity binding if environment variables are present |
109 |
| - if err := w.configureIdentityBinding(); err != nil { |
110 |
| - return nil, err |
111 |
| - } |
112 |
| - |
113 |
| - caco := ClientAssertionCredentialOptions{ |
| 99 | + caco := &ClientAssertionCredentialOptions{ |
114 | 100 | AdditionallyAllowedTenants: options.AdditionallyAllowedTenants,
|
115 | 101 | Cache: options.Cache,
|
116 | 102 | ClientOptions: options.ClientOptions,
|
117 | 103 | DisableInstanceDiscovery: options.DisableInstanceDiscovery,
|
118 | 104 | }
|
119 | 105 |
|
120 |
| - // If identity binding mode is enabled, configure a custom HTTP client |
121 |
| - if w.identityBinding { |
122 |
| - // Override the client options to use our custom transport |
123 |
| - caco.ClientOptions.Transport = &identityBindingTransport{ |
124 |
| - credential: &w, |
125 |
| - kubernetesSNIName: w.kubernetesSNIName, |
126 |
| - } |
| 106 | + // configure identity binding if environment variables are present. |
| 107 | + // In identity binding enabled mode, a dedicated transport will be used for proxying token requests to a dedicated endpoint. |
| 108 | + if err := w.configureIdentityBinding(caco); err != nil { |
| 109 | + return nil, err |
127 | 110 | }
|
128 | 111 |
|
129 |
| - cred, err := NewClientAssertionCredential(tenantID, clientID, w.getAssertion, &caco) |
| 112 | + cred, err := NewClientAssertionCredential(tenantID, clientID, w.getAssertion, caco) |
130 | 113 | if err != nil {
|
131 | 114 | return nil, err
|
132 | 115 | }
|
@@ -173,127 +156,198 @@ func (w *WorkloadIdentityCredential) getAssertion(context.Context) (string, erro
|
173 | 156 | }
|
174 | 157 |
|
175 | 158 | // configureIdentityBinding configures identity binding mode if the required environment variables are present
|
176 |
| -func (w *WorkloadIdentityCredential) configureIdentityBinding() error { |
177 |
| - // Check for identity binding mode environment variables |
| 159 | +func (w *WorkloadIdentityCredential) configureIdentityBinding(caco *ClientAssertionCredentialOptions) error { |
| 160 | + // check for identity binding mode environment variables |
178 | 161 | kubernetesTokenEndpointStr := os.Getenv(azureKubernetesTokenEndpoint)
|
179 | 162 | kubernetesSNIName := os.Getenv(azureKubernetesSNIName)
|
180 | 163 | kubernetesCAFile := os.Getenv(azureKubernetesCAFile)
|
181 | 164 |
|
182 |
| - // If any of the identity binding environment variables are present, enable identity binding mode |
183 |
| - if kubernetesTokenEndpointStr != "" || kubernetesSNIName != "" || kubernetesCAFile != "" { |
184 |
| - // All three variables must be present for identity binding mode |
185 |
| - if kubernetesTokenEndpointStr == "" || kubernetesSNIName == "" || kubernetesCAFile == "" { |
186 |
| - return errors.New("identity binding mode requires all three environment variables: AZURE_KUBERNETES_TOKEN_ENDPOINT, AZURE_KUBERNETES_SNI_NAME, and AZURE_KUBERNETES_CA_FILE") |
187 |
| - } |
| 165 | + if kubernetesTokenEndpointStr == "" && kubernetesSNIName == "" && kubernetesCAFile == "" { |
| 166 | + // identity binding is not set |
| 167 | + return nil |
| 168 | + } |
188 | 169 |
|
189 |
| - // Parse the Kubernetes token endpoint URL |
190 |
| - kubernetesTokenEndpoint, err := url.Parse(kubernetesTokenEndpointStr) |
191 |
| - if err != nil { |
192 |
| - return fmt.Errorf("failed to parse Kubernetes token endpoint URL: %w", err) |
193 |
| - } |
| 170 | + // All three variables must be present for identity binding mode |
| 171 | + if kubernetesTokenEndpointStr == "" || kubernetesSNIName == "" || kubernetesCAFile == "" { |
| 172 | + return errors.New("identity binding mode requires all three environment variables: AZURE_KUBERNETES_TOKEN_ENDPOINT, AZURE_KUBERNETES_SNI_NAME, and AZURE_KUBERNETES_CA_FILE") |
| 173 | + } |
194 | 174 |
|
195 |
| - w.identityBinding = true |
196 |
| - w.kubernetesTokenEndpoint = kubernetesTokenEndpoint |
197 |
| - w.kubernetesSNIName = kubernetesSNIName |
198 |
| - w.kubernetesCAFile = kubernetesCAFile |
| 175 | + transporter, err := newIdentityBindingTransport( |
| 176 | + kubernetesCAFile, kubernetesSNIName, kubernetesTokenEndpointStr, |
| 177 | + caco.Transport, |
| 178 | + ) |
| 179 | + if err != nil { |
| 180 | + return err |
| 181 | + } |
| 182 | + caco.Transport = transporter |
| 183 | + return nil |
| 184 | +} |
199 | 185 |
|
200 |
| - // Validate CA file exists and is readable during construction |
201 |
| - if _, err := os.Stat(kubernetesCAFile); err != nil { |
202 |
| - return fmt.Errorf("failed to read Kubernetes CA file: %w", err) |
| 186 | +const ( |
| 187 | + tokenEndpointSuffix = "/oauth2/v2.0/token" |
| 188 | + caReloadInterval = 10 * time.Minute |
| 189 | +) |
| 190 | + |
| 191 | +// identityBindingTransport is a custom HTTP transport that redirects token requests |
| 192 | +// to the Kubernetes token endpoint when in identity binding mode |
| 193 | +type identityBindingTransport struct { |
| 194 | + caFile string |
| 195 | + sniName string |
| 196 | + tokenEndpoint *url.URL |
| 197 | + fallbackTransporter policy.Transporter |
| 198 | + |
| 199 | + mtx *sync.RWMutex |
| 200 | + |
| 201 | + nextRead time.Time |
| 202 | + currentCA []byte |
| 203 | + transport *http.Transport |
| 204 | +} |
| 205 | + |
| 206 | +func newIdentityBindingTransport( |
| 207 | + caFile, sniName, tokenEndpointStr string, |
| 208 | + fallbackTransporter policy.Transporter, |
| 209 | +) (*identityBindingTransport, error) { |
| 210 | + tokenEndpoint, err := url.Parse(tokenEndpointStr) |
| 211 | + if err != nil { |
| 212 | + return nil, fmt.Errorf("failed to parse token endpoint URL %q: %w", tokenEndpointStr, err) |
| 213 | + } |
| 214 | + |
| 215 | + initialTransport := func() *http.Transport { |
| 216 | + // try reusing the user provided transport if available |
| 217 | + if fallbackTransporter != nil { |
| 218 | + if httpClient, ok := fallbackTransporter.(*http.Client); ok { |
| 219 | + if transport, ok := httpClient.Transport.(*http.Transport); ok { |
| 220 | + return transport.Clone() |
| 221 | + } |
| 222 | + } |
203 | 223 | }
|
204 | 224 |
|
205 |
| - // Validate CA certificate is parseable during construction |
206 |
| - caCert, err := os.ReadFile(kubernetesCAFile) |
207 |
| - if err != nil { |
208 |
| - return fmt.Errorf("failed to read Kubernetes CA file: %w", err) |
| 225 | + // if the user did not provide a transport, we use the default one |
| 226 | + // FIXME: can we callback to the defaultHTTPClient from azcore/runtime? |
| 227 | + if transport, ok := http.DefaultTransport.(*http.Transport); ok { |
| 228 | + return transport.Clone() |
209 | 229 | }
|
210 | 230 |
|
211 |
| - caCertPool := x509.NewCertPool() |
212 |
| - if !caCertPool.AppendCertsFromPEM(caCert) { |
213 |
| - return errors.New("failed to parse Kubernetes CA certificate") |
| 231 | + // this should not happen, but if the user mutates the net/http.DefaultTransport |
| 232 | + // to something else, we fall back to a sane default |
| 233 | + return &http.Transport{ |
| 234 | + ForceAttemptHTTP2: true, |
| 235 | + MaxIdleConns: 100, |
| 236 | + IdleConnTimeout: 90 * time.Second, |
| 237 | + TLSHandshakeTimeout: 10 * time.Second, |
214 | 238 | }
|
| 239 | + }() |
| 240 | + |
| 241 | + tr := &identityBindingTransport{ |
| 242 | + caFile: caFile, |
| 243 | + sniName: sniName, |
| 244 | + tokenEndpoint: tokenEndpoint, |
| 245 | + fallbackTransporter: fallbackTransporter, |
| 246 | + mtx: &sync.RWMutex{}, |
| 247 | + transport: initialTransport, |
215 | 248 | }
|
216 | 249 |
|
217 |
| - return nil |
218 |
| -} |
219 |
| - |
220 |
| -// loadKubernetesCA loads and caches the Kubernetes CA certificate |
221 |
| -func (w *WorkloadIdentityCredential) loadKubernetesCA() (*x509.CertPool, error) { |
222 |
| - w.mtx.RLock() |
223 |
| - if w.caExpires.After(time.Now()) && w.caCertPool != nil { |
224 |
| - defer w.mtx.RUnlock() |
225 |
| - return w.caCertPool, nil |
| 250 | + // perform an initial load to surface any issues with the CA file and transport settings. |
| 251 | + // Lock is not held here as this is called in the constructor |
| 252 | + if err := tr.reloadCA(); err != nil { |
| 253 | + return nil, err |
226 | 254 | }
|
227 | 255 |
|
228 |
| - // ensure only one goroutine at a time updates the CA cert |
229 |
| - w.mtx.RUnlock() |
230 |
| - w.mtx.Lock() |
231 |
| - defer w.mtx.Unlock() |
| 256 | + return tr, nil |
| 257 | +} |
232 | 258 |
|
233 |
| - // double check because another goroutine may have acquired the write lock first and done the update |
234 |
| - if now := time.Now(); w.caExpires.After(now) && w.caCertPool != nil { |
235 |
| - return w.caCertPool, nil |
| 259 | +func (i *identityBindingTransport) Do(req *http.Request) (*http.Response, error) { |
| 260 | + if !strings.HasSuffix(req.URL.Path, tokenEndpointSuffix) { |
| 261 | + // not a token request, fallback to the original transporter |
| 262 | + return i.fallbackTransporter.Do(req) |
236 | 263 | }
|
237 | 264 |
|
238 |
| - // Load the CA certificate for the Kubernetes endpoint |
239 |
| - caCert, err := os.ReadFile(w.kubernetesCAFile) |
| 265 | + tr, err := i.getTokenTransporter() |
240 | 266 | if err != nil {
|
241 |
| - return nil, fmt.Errorf("failed to read Kubernetes CA file: %w", err) |
| 267 | + return nil, err |
242 | 268 | }
|
243 | 269 |
|
244 |
| - caCertPool := x509.NewCertPool() |
245 |
| - if !caCertPool.AppendCertsFromPEM(caCert) { |
246 |
| - return nil, errors.New("failed to parse Kubernetes CA certificate") |
247 |
| - } |
| 270 | + newReq := req.Clone(req.Context()) |
| 271 | + newReq.URL.Scheme = i.tokenEndpoint.Scheme // this will always be https |
| 272 | + newReq.URL.Host = i.tokenEndpoint.Host |
| 273 | + newReq.Host = i.tokenEndpoint.Host |
248 | 274 |
|
249 |
| - // Cache the CA certificate for 10 minutes (same as token assertion) |
250 |
| - w.caCert = caCert |
251 |
| - w.caCertPool = caCertPool |
252 |
| - w.caExpires = time.Now().Add(10 * time.Minute) |
| 275 | + return tr.RoundTrip(newReq) |
| 276 | +} |
253 | 277 |
|
254 |
| - return caCertPool, nil |
| 278 | +func (i *identityBindingTransport) getTokenTransporter() (*http.Transport, error) { |
| 279 | + i.mtx.RLock() |
| 280 | + if i.nextRead.Before(time.Now()) { |
| 281 | + i.mtx.RUnlock() |
| 282 | + i.mtx.Lock() |
| 283 | + defer i.mtx.Unlock() |
| 284 | + // double check on the read time |
| 285 | + if now := time.Now(); i.nextRead.Before(now) { |
| 286 | + if err := i.reloadCA(); err != nil { |
| 287 | + // we return error if any attempt of reloading CA fails |
| 288 | + // This should surface in the token calls and we expect the caller to |
| 289 | + // have proper error handling / rate limit so we don't fall into deadloop here |
| 290 | + // due to scenario like broken CA file. |
| 291 | + return nil, err |
| 292 | + } |
| 293 | + } |
| 294 | + } else { |
| 295 | + defer i.mtx.RUnlock() |
| 296 | + } |
| 297 | + return i.transport, nil |
255 | 298 | }
|
256 | 299 |
|
257 |
| -// identityBindingTransport is a custom HTTP transport that redirects token requests |
258 |
| -// to the Kubernetes token endpoint when in identity binding mode |
259 |
| -type identityBindingTransport struct { |
260 |
| - credential *WorkloadIdentityCredential |
261 |
| - kubernetesSNIName string |
| 300 | +func (i *identityBindingTransport) createTransportWithCAPool( |
| 301 | + fromTransport *http.Transport, |
| 302 | + caPool *x509.CertPool, |
| 303 | +) *http.Transport { |
| 304 | + transport := fromTransport.Clone() |
| 305 | + if transport.TLSClientConfig == nil { |
| 306 | + transport.TLSClientConfig = &tls.Config{} |
| 307 | + } |
| 308 | + transport.TLSClientConfig.ServerName = i.sniName |
| 309 | + transport.TLSClientConfig.RootCAs = caPool |
| 310 | + return transport |
262 | 311 | }
|
263 | 312 |
|
264 |
| -func (t *identityBindingTransport) Do(req *http.Request) (*http.Response, error) { |
265 |
| - // Return early if this is not a token request |
266 |
| - if !strings.HasSuffix(req.URL.Path, "/oauth2/v2.0/token") { |
267 |
| - return http.DefaultTransport.RoundTrip(req) |
| 313 | +// reloadCA attempts to read the latest CA from the CA file and updates the transport if the content has changed. |
| 314 | +// If a new CA is discovered, the existing transport will be replaced with a new one that uses the new CA. |
| 315 | +// It expects the caller to hold the write lock on i.mtx to ensure thread safety. |
| 316 | +func (i *identityBindingTransport) reloadCA() error { |
| 317 | + newCA, err := os.ReadFile(i.caFile) |
| 318 | + if err != nil { |
| 319 | + return fmt.Errorf("read CA file %q: %w", i.caFile, err) |
268 | 320 | }
|
269 | 321 |
|
270 |
| - // This is a token request, redirect to Kubernetes endpoint |
| 322 | + if len(newCA) == 0 { |
| 323 | + // the CA file might be in the middle of rotation without the content written. |
| 324 | + // We return nil and rely on next check. |
| 325 | + return nil |
| 326 | + } |
271 | 327 |
|
272 |
| - // Load the CA certificate (this will use cached version if still valid) |
273 |
| - caCertPool, err := t.credential.loadKubernetesCA() |
274 |
| - if err != nil { |
275 |
| - return nil, err |
| 328 | + if bytes.Equal(i.currentCA, newCA) { |
| 329 | + // no change in CA content, no need to replace |
| 330 | + i.nextRead = time.Now().Add(caReloadInterval) |
| 331 | + return nil |
276 | 332 | }
|
277 | 333 |
|
278 |
| - // Create custom transport with the CA and SNI configuration |
279 |
| - transport := &http.Transport{ |
280 |
| - TLSClientConfig: &tls.Config{ |
281 |
| - RootCAs: caCertPool, |
282 |
| - ServerName: t.kubernetesSNIName, |
283 |
| - }, |
| 334 | + newCAPool := x509.NewCertPool() |
| 335 | + if !newCAPool.AppendCertsFromPEM(newCA) { |
| 336 | + return fmt.Errorf("parse CA file %q: no valid certificates found", i.caFile) |
284 | 337 | }
|
285 | 338 |
|
286 |
| - // Clone the request to avoid modifying the original |
287 |
| - newReq := req.Clone(req.Context()) |
| 339 | + newTransport := i.createTransportWithCAPool(i.transport, newCAPool) |
| 340 | + oldTransport := i.transport |
288 | 341 |
|
289 |
| - // Update the URL to point to the Kubernetes endpoint |
290 |
| - newReq.URL.Scheme = t.credential.kubernetesTokenEndpoint.Scheme |
291 |
| - newReq.URL.Host = t.credential.kubernetesTokenEndpoint.Host |
292 |
| - newReq.Host = t.credential.kubernetesTokenEndpoint.Host |
| 342 | + i.transport = newTransport |
| 343 | + i.currentCA = newCA |
| 344 | + i.nextRead = time.Now().Add(caReloadInterval) |
293 | 345 |
|
294 |
| - // Preserve the original path (contains tenant ID and token endpoint path) |
295 |
| - // The path should be something like "/tenant-id/oauth2/v2.0/token" |
296 |
| - // Keep the original path to maintain the token request structure |
| 346 | + if oldTransport != nil { |
| 347 | + // drop any idle connections from previous transport so new requests can be |
| 348 | + // moved to the new transport |
| 349 | + oldTransport.CloseIdleConnections() |
| 350 | + } |
297 | 351 |
|
298 |
| - return transport.RoundTrip(newReq) |
| 352 | + return nil |
299 | 353 | }
|
0 commit comments