Skip to content

Commit 6ad38c9

Browse files
committed
refactor: push down the CA loading logic to ib transport object
1 parent aa168cc commit 6ad38c9

File tree

2 files changed

+163
-429
lines changed

2 files changed

+163
-429
lines changed

sdk/azidentity/workload_identity.go

Lines changed: 163 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package azidentity
88

99
import (
10+
"bytes"
1011
"context"
1112
"crypto/tls"
1213
"crypto/x509"
@@ -35,15 +36,6 @@ type WorkloadIdentityCredential struct {
3536
cred *ClientAssertionCredential
3637
expires time.Time
3738
mtx *sync.RWMutex
38-
// identity binding mode fields
39-
identityBinding bool
40-
kubernetesCAFile string
41-
kubernetesSNIName string
42-
kubernetesTokenEndpoint *url.URL
43-
// CA certificate caching fields
44-
caCert []byte
45-
caCertPool *x509.CertPool
46-
caExpires time.Time
4739
}
4840

4941
// WorkloadIdentityCredentialOptions contains optional parameters for WorkloadIdentityCredential.
@@ -104,29 +96,20 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) (
10496
}
10597

10698
w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}}
107-
108-
// Configure identity binding if environment variables are present
109-
if err := w.configureIdentityBinding(); err != nil {
110-
return nil, err
111-
}
112-
113-
caco := ClientAssertionCredentialOptions{
99+
caco := &ClientAssertionCredentialOptions{
114100
AdditionallyAllowedTenants: options.AdditionallyAllowedTenants,
115101
Cache: options.Cache,
116102
ClientOptions: options.ClientOptions,
117103
DisableInstanceDiscovery: options.DisableInstanceDiscovery,
118104
}
119105

120-
// If identity binding mode is enabled, configure a custom HTTP client
121-
if w.identityBinding {
122-
// Override the client options to use our custom transport
123-
caco.ClientOptions.Transport = &identityBindingTransport{
124-
credential: &w,
125-
kubernetesSNIName: w.kubernetesSNIName,
126-
}
106+
// configure identity binding if environment variables are present.
107+
// In identity binding enabled mode, a dedicated transport will be used for proxying token requests to a dedicated endpoint.
108+
if err := w.configureIdentityBinding(caco); err != nil {
109+
return nil, err
127110
}
128111

129-
cred, err := NewClientAssertionCredential(tenantID, clientID, w.getAssertion, &caco)
112+
cred, err := NewClientAssertionCredential(tenantID, clientID, w.getAssertion, caco)
130113
if err != nil {
131114
return nil, err
132115
}
@@ -173,127 +156,198 @@ func (w *WorkloadIdentityCredential) getAssertion(context.Context) (string, erro
173156
}
174157

175158
// configureIdentityBinding configures identity binding mode if the required environment variables are present
176-
func (w *WorkloadIdentityCredential) configureIdentityBinding() error {
177-
// Check for identity binding mode environment variables
159+
func (w *WorkloadIdentityCredential) configureIdentityBinding(caco *ClientAssertionCredentialOptions) error {
160+
// check for identity binding mode environment variables
178161
kubernetesTokenEndpointStr := os.Getenv(azureKubernetesTokenEndpoint)
179162
kubernetesSNIName := os.Getenv(azureKubernetesSNIName)
180163
kubernetesCAFile := os.Getenv(azureKubernetesCAFile)
181164

182-
// If any of the identity binding environment variables are present, enable identity binding mode
183-
if kubernetesTokenEndpointStr != "" || kubernetesSNIName != "" || kubernetesCAFile != "" {
184-
// All three variables must be present for identity binding mode
185-
if kubernetesTokenEndpointStr == "" || kubernetesSNIName == "" || kubernetesCAFile == "" {
186-
return errors.New("identity binding mode requires all three environment variables: AZURE_KUBERNETES_TOKEN_ENDPOINT, AZURE_KUBERNETES_SNI_NAME, and AZURE_KUBERNETES_CA_FILE")
187-
}
165+
if kubernetesTokenEndpointStr == "" && kubernetesSNIName == "" && kubernetesCAFile == "" {
166+
// identity binding is not set
167+
return nil
168+
}
188169

189-
// Parse the Kubernetes token endpoint URL
190-
kubernetesTokenEndpoint, err := url.Parse(kubernetesTokenEndpointStr)
191-
if err != nil {
192-
return fmt.Errorf("failed to parse Kubernetes token endpoint URL: %w", err)
193-
}
170+
// All three variables must be present for identity binding mode
171+
if kubernetesTokenEndpointStr == "" || kubernetesSNIName == "" || kubernetesCAFile == "" {
172+
return errors.New("identity binding mode requires all three environment variables: AZURE_KUBERNETES_TOKEN_ENDPOINT, AZURE_KUBERNETES_SNI_NAME, and AZURE_KUBERNETES_CA_FILE")
173+
}
194174

195-
w.identityBinding = true
196-
w.kubernetesTokenEndpoint = kubernetesTokenEndpoint
197-
w.kubernetesSNIName = kubernetesSNIName
198-
w.kubernetesCAFile = kubernetesCAFile
175+
transporter, err := newIdentityBindingTransport(
176+
kubernetesCAFile, kubernetesSNIName, kubernetesTokenEndpointStr,
177+
caco.Transport,
178+
)
179+
if err != nil {
180+
return err
181+
}
182+
caco.Transport = transporter
183+
return nil
184+
}
199185

200-
// Validate CA file exists and is readable during construction
201-
if _, err := os.Stat(kubernetesCAFile); err != nil {
202-
return fmt.Errorf("failed to read Kubernetes CA file: %w", err)
186+
const (
187+
tokenEndpointSuffix = "/oauth2/v2.0/token"
188+
caReloadInterval = 10 * time.Minute
189+
)
190+
191+
// identityBindingTransport is a custom HTTP transport that redirects token requests
192+
// to the Kubernetes token endpoint when in identity binding mode
193+
type identityBindingTransport struct {
194+
caFile string
195+
sniName string
196+
tokenEndpoint *url.URL
197+
fallbackTransporter policy.Transporter
198+
199+
mtx *sync.RWMutex
200+
201+
nextRead time.Time
202+
currentCA []byte
203+
transport *http.Transport
204+
}
205+
206+
func newIdentityBindingTransport(
207+
caFile, sniName, tokenEndpointStr string,
208+
fallbackTransporter policy.Transporter,
209+
) (*identityBindingTransport, error) {
210+
tokenEndpoint, err := url.Parse(tokenEndpointStr)
211+
if err != nil {
212+
return nil, fmt.Errorf("failed to parse token endpoint URL %q: %w", tokenEndpointStr, err)
213+
}
214+
215+
initialTransport := func() *http.Transport {
216+
// try reusing the user provided transport if available
217+
if fallbackTransporter != nil {
218+
if httpClient, ok := fallbackTransporter.(*http.Client); ok {
219+
if transport, ok := httpClient.Transport.(*http.Transport); ok {
220+
return transport.Clone()
221+
}
222+
}
203223
}
204224

205-
// Validate CA certificate is parseable during construction
206-
caCert, err := os.ReadFile(kubernetesCAFile)
207-
if err != nil {
208-
return fmt.Errorf("failed to read Kubernetes CA file: %w", err)
225+
// if the user did not provide a transport, we use the default one
226+
// FIXME: can we callback to the defaultHTTPClient from azcore/runtime?
227+
if transport, ok := http.DefaultTransport.(*http.Transport); ok {
228+
return transport.Clone()
209229
}
210230

211-
caCertPool := x509.NewCertPool()
212-
if !caCertPool.AppendCertsFromPEM(caCert) {
213-
return errors.New("failed to parse Kubernetes CA certificate")
231+
// this should not happen, but if the user mutates the net/http.DefaultTransport
232+
// to something else, we fall back to a sane default
233+
return &http.Transport{
234+
ForceAttemptHTTP2: true,
235+
MaxIdleConns: 100,
236+
IdleConnTimeout: 90 * time.Second,
237+
TLSHandshakeTimeout: 10 * time.Second,
214238
}
239+
}()
240+
241+
tr := &identityBindingTransport{
242+
caFile: caFile,
243+
sniName: sniName,
244+
tokenEndpoint: tokenEndpoint,
245+
fallbackTransporter: fallbackTransporter,
246+
mtx: &sync.RWMutex{},
247+
transport: initialTransport,
215248
}
216249

217-
return nil
218-
}
219-
220-
// loadKubernetesCA loads and caches the Kubernetes CA certificate
221-
func (w *WorkloadIdentityCredential) loadKubernetesCA() (*x509.CertPool, error) {
222-
w.mtx.RLock()
223-
if w.caExpires.After(time.Now()) && w.caCertPool != nil {
224-
defer w.mtx.RUnlock()
225-
return w.caCertPool, nil
250+
// perform an initial load to surface any issues with the CA file and transport settings.
251+
// Lock is not held here as this is called in the constructor
252+
if err := tr.reloadCA(); err != nil {
253+
return nil, err
226254
}
227255

228-
// ensure only one goroutine at a time updates the CA cert
229-
w.mtx.RUnlock()
230-
w.mtx.Lock()
231-
defer w.mtx.Unlock()
256+
return tr, nil
257+
}
232258

233-
// double check because another goroutine may have acquired the write lock first and done the update
234-
if now := time.Now(); w.caExpires.After(now) && w.caCertPool != nil {
235-
return w.caCertPool, nil
259+
func (i *identityBindingTransport) Do(req *http.Request) (*http.Response, error) {
260+
if !strings.HasSuffix(req.URL.Path, tokenEndpointSuffix) {
261+
// not a token request, fallback to the original transporter
262+
return i.fallbackTransporter.Do(req)
236263
}
237264

238-
// Load the CA certificate for the Kubernetes endpoint
239-
caCert, err := os.ReadFile(w.kubernetesCAFile)
265+
tr, err := i.getTokenTransporter()
240266
if err != nil {
241-
return nil, fmt.Errorf("failed to read Kubernetes CA file: %w", err)
267+
return nil, err
242268
}
243269

244-
caCertPool := x509.NewCertPool()
245-
if !caCertPool.AppendCertsFromPEM(caCert) {
246-
return nil, errors.New("failed to parse Kubernetes CA certificate")
247-
}
270+
newReq := req.Clone(req.Context())
271+
newReq.URL.Scheme = i.tokenEndpoint.Scheme // this will always be https
272+
newReq.URL.Host = i.tokenEndpoint.Host
273+
newReq.Host = i.tokenEndpoint.Host
248274

249-
// Cache the CA certificate for 10 minutes (same as token assertion)
250-
w.caCert = caCert
251-
w.caCertPool = caCertPool
252-
w.caExpires = time.Now().Add(10 * time.Minute)
275+
return tr.RoundTrip(newReq)
276+
}
253277

254-
return caCertPool, nil
278+
func (i *identityBindingTransport) getTokenTransporter() (*http.Transport, error) {
279+
i.mtx.RLock()
280+
if i.nextRead.Before(time.Now()) {
281+
i.mtx.RUnlock()
282+
i.mtx.Lock()
283+
defer i.mtx.Unlock()
284+
// double check on the read time
285+
if now := time.Now(); i.nextRead.Before(now) {
286+
if err := i.reloadCA(); err != nil {
287+
// we return error if any attempt of reloading CA fails
288+
// This should surface in the token calls and we expect the caller to
289+
// have proper error handling / rate limit so we don't fall into deadloop here
290+
// due to scenario like broken CA file.
291+
return nil, err
292+
}
293+
}
294+
} else {
295+
defer i.mtx.RUnlock()
296+
}
297+
return i.transport, nil
255298
}
256299

257-
// identityBindingTransport is a custom HTTP transport that redirects token requests
258-
// to the Kubernetes token endpoint when in identity binding mode
259-
type identityBindingTransport struct {
260-
credential *WorkloadIdentityCredential
261-
kubernetesSNIName string
300+
func (i *identityBindingTransport) createTransportWithCAPool(
301+
fromTransport *http.Transport,
302+
caPool *x509.CertPool,
303+
) *http.Transport {
304+
transport := fromTransport.Clone()
305+
if transport.TLSClientConfig == nil {
306+
transport.TLSClientConfig = &tls.Config{}
307+
}
308+
transport.TLSClientConfig.ServerName = i.sniName
309+
transport.TLSClientConfig.RootCAs = caPool
310+
return transport
262311
}
263312

264-
func (t *identityBindingTransport) Do(req *http.Request) (*http.Response, error) {
265-
// Return early if this is not a token request
266-
if !strings.HasSuffix(req.URL.Path, "/oauth2/v2.0/token") {
267-
return http.DefaultTransport.RoundTrip(req)
313+
// reloadCA attempts to read the latest CA from the CA file and updates the transport if the content has changed.
314+
// If a new CA is discovered, the existing transport will be replaced with a new one that uses the new CA.
315+
// It expects the caller to hold the write lock on i.mtx to ensure thread safety.
316+
func (i *identityBindingTransport) reloadCA() error {
317+
newCA, err := os.ReadFile(i.caFile)
318+
if err != nil {
319+
return fmt.Errorf("read CA file %q: %w", i.caFile, err)
268320
}
269321

270-
// This is a token request, redirect to Kubernetes endpoint
322+
if len(newCA) == 0 {
323+
// the CA file might be in the middle of rotation without the content written.
324+
// We return nil and rely on next check.
325+
return nil
326+
}
271327

272-
// Load the CA certificate (this will use cached version if still valid)
273-
caCertPool, err := t.credential.loadKubernetesCA()
274-
if err != nil {
275-
return nil, err
328+
if bytes.Equal(i.currentCA, newCA) {
329+
// no change in CA content, no need to replace
330+
i.nextRead = time.Now().Add(caReloadInterval)
331+
return nil
276332
}
277333

278-
// Create custom transport with the CA and SNI configuration
279-
transport := &http.Transport{
280-
TLSClientConfig: &tls.Config{
281-
RootCAs: caCertPool,
282-
ServerName: t.kubernetesSNIName,
283-
},
334+
newCAPool := x509.NewCertPool()
335+
if !newCAPool.AppendCertsFromPEM(newCA) {
336+
return fmt.Errorf("parse CA file %q: no valid certificates found", i.caFile)
284337
}
285338

286-
// Clone the request to avoid modifying the original
287-
newReq := req.Clone(req.Context())
339+
newTransport := i.createTransportWithCAPool(i.transport, newCAPool)
340+
oldTransport := i.transport
288341

289-
// Update the URL to point to the Kubernetes endpoint
290-
newReq.URL.Scheme = t.credential.kubernetesTokenEndpoint.Scheme
291-
newReq.URL.Host = t.credential.kubernetesTokenEndpoint.Host
292-
newReq.Host = t.credential.kubernetesTokenEndpoint.Host
342+
i.transport = newTransport
343+
i.currentCA = newCA
344+
i.nextRead = time.Now().Add(caReloadInterval)
293345

294-
// Preserve the original path (contains tenant ID and token endpoint path)
295-
// The path should be something like "/tenant-id/oauth2/v2.0/token"
296-
// Keep the original path to maintain the token request structure
346+
if oldTransport != nil {
347+
// drop any idle connections from previous transport so new requests can be
348+
// moved to the new transport
349+
oldTransport.CloseIdleConnections()
350+
}
297351

298-
return transport.RoundTrip(newReq)
352+
return nil
299353
}

0 commit comments

Comments
 (0)