Skip to content

Commit aa168cc

Browse files
Copilotbcho
andcommitted
Address code review feedback: remove host checks, early return, extract identity binding configuration
Co-authored-by: bcho <[email protected]>
1 parent b693895 commit aa168cc

File tree

2 files changed

+91
-85
lines changed

2 files changed

+91
-85
lines changed

sdk/azidentity/workload_identity.go

Lines changed: 86 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -105,44 +105,9 @@ func NewWorkloadIdentityCredential(options *WorkloadIdentityCredentialOptions) (
105105

106106
w := WorkloadIdentityCredential{file: file, mtx: &sync.RWMutex{}}
107107

108-
// Check for identity binding mode environment variables
109-
kubernetesTokenEndpointStr := os.Getenv(azureKubernetesTokenEndpoint)
110-
kubernetesSNIName := os.Getenv(azureKubernetesSNIName)
111-
kubernetesCAFile := os.Getenv(azureKubernetesCAFile)
112-
113-
// If any of the identity binding environment variables are present, enable identity binding mode
114-
if kubernetesTokenEndpointStr != "" || kubernetesSNIName != "" || kubernetesCAFile != "" {
115-
// All three variables must be present for identity binding mode
116-
if kubernetesTokenEndpointStr == "" || kubernetesSNIName == "" || kubernetesCAFile == "" {
117-
return nil, errors.New("identity binding mode requires all three environment variables: AZURE_KUBERNETES_TOKEN_ENDPOINT, AZURE_KUBERNETES_SNI_NAME, and AZURE_KUBERNETES_CA_FILE")
118-
}
119-
120-
// Parse the Kubernetes token endpoint URL
121-
kubernetesTokenEndpoint, err := url.Parse(kubernetesTokenEndpointStr)
122-
if err != nil {
123-
return nil, fmt.Errorf("failed to parse Kubernetes token endpoint URL: %w", err)
124-
}
125-
126-
w.identityBinding = true
127-
w.kubernetesTokenEndpoint = kubernetesTokenEndpoint
128-
w.kubernetesSNIName = kubernetesSNIName
129-
w.kubernetesCAFile = kubernetesCAFile
130-
131-
// Validate CA file exists and is readable during construction
132-
if _, err := os.Stat(kubernetesCAFile); err != nil {
133-
return nil, fmt.Errorf("failed to read Kubernetes CA file: %w", err)
134-
}
135-
136-
// Validate CA certificate is parseable during construction
137-
caCert, err := os.ReadFile(kubernetesCAFile)
138-
if err != nil {
139-
return nil, fmt.Errorf("failed to read Kubernetes CA file: %w", err)
140-
}
141-
142-
caCertPool := x509.NewCertPool()
143-
if !caCertPool.AppendCertsFromPEM(caCert) {
144-
return nil, errors.New("failed to parse Kubernetes CA certificate")
145-
}
108+
// Configure identity binding if environment variables are present
109+
if err := w.configureIdentityBinding(); err != nil {
110+
return nil, err
146111
}
147112

148113
caco := ClientAssertionCredentialOptions{
@@ -207,24 +172,69 @@ func (w *WorkloadIdentityCredential) getAssertion(context.Context) (string, erro
207172
return w.assertion, nil
208173
}
209174

175+
// configureIdentityBinding configures identity binding mode if the required environment variables are present
176+
func (w *WorkloadIdentityCredential) configureIdentityBinding() error {
177+
// Check for identity binding mode environment variables
178+
kubernetesTokenEndpointStr := os.Getenv(azureKubernetesTokenEndpoint)
179+
kubernetesSNIName := os.Getenv(azureKubernetesSNIName)
180+
kubernetesCAFile := os.Getenv(azureKubernetesCAFile)
181+
182+
// If any of the identity binding environment variables are present, enable identity binding mode
183+
if kubernetesTokenEndpointStr != "" || kubernetesSNIName != "" || kubernetesCAFile != "" {
184+
// All three variables must be present for identity binding mode
185+
if kubernetesTokenEndpointStr == "" || kubernetesSNIName == "" || kubernetesCAFile == "" {
186+
return errors.New("identity binding mode requires all three environment variables: AZURE_KUBERNETES_TOKEN_ENDPOINT, AZURE_KUBERNETES_SNI_NAME, and AZURE_KUBERNETES_CA_FILE")
187+
}
188+
189+
// Parse the Kubernetes token endpoint URL
190+
kubernetesTokenEndpoint, err := url.Parse(kubernetesTokenEndpointStr)
191+
if err != nil {
192+
return fmt.Errorf("failed to parse Kubernetes token endpoint URL: %w", err)
193+
}
194+
195+
w.identityBinding = true
196+
w.kubernetesTokenEndpoint = kubernetesTokenEndpoint
197+
w.kubernetesSNIName = kubernetesSNIName
198+
w.kubernetesCAFile = kubernetesCAFile
199+
200+
// Validate CA file exists and is readable during construction
201+
if _, err := os.Stat(kubernetesCAFile); err != nil {
202+
return fmt.Errorf("failed to read Kubernetes CA file: %w", err)
203+
}
204+
205+
// Validate CA certificate is parseable during construction
206+
caCert, err := os.ReadFile(kubernetesCAFile)
207+
if err != nil {
208+
return fmt.Errorf("failed to read Kubernetes CA file: %w", err)
209+
}
210+
211+
caCertPool := x509.NewCertPool()
212+
if !caCertPool.AppendCertsFromPEM(caCert) {
213+
return errors.New("failed to parse Kubernetes CA certificate")
214+
}
215+
}
216+
217+
return nil
218+
}
219+
210220
// loadKubernetesCA loads and caches the Kubernetes CA certificate
211221
func (w *WorkloadIdentityCredential) loadKubernetesCA() (*x509.CertPool, error) {
212222
w.mtx.RLock()
213223
if w.caExpires.After(time.Now()) && w.caCertPool != nil {
214224
defer w.mtx.RUnlock()
215225
return w.caCertPool, nil
216226
}
217-
227+
218228
// ensure only one goroutine at a time updates the CA cert
219229
w.mtx.RUnlock()
220230
w.mtx.Lock()
221231
defer w.mtx.Unlock()
222-
232+
223233
// double check because another goroutine may have acquired the write lock first and done the update
224234
if now := time.Now(); w.caExpires.After(now) && w.caCertPool != nil {
225235
return w.caCertPool, nil
226236
}
227-
237+
228238
// Load the CA certificate for the Kubernetes endpoint
229239
caCert, err := os.ReadFile(w.kubernetesCAFile)
230240
if err != nil {
@@ -235,12 +245,12 @@ func (w *WorkloadIdentityCredential) loadKubernetesCA() (*x509.CertPool, error)
235245
if !caCertPool.AppendCertsFromPEM(caCert) {
236246
return nil, errors.New("failed to parse Kubernetes CA certificate")
237247
}
238-
248+
239249
// Cache the CA certificate for 10 minutes (same as token assertion)
240250
w.caCert = caCert
241251
w.caCertPool = caCertPool
242252
w.caExpires = time.Now().Add(10 * time.Minute)
243-
253+
244254
return caCertPool, nil
245255
}
246256

@@ -252,42 +262,38 @@ type identityBindingTransport struct {
252262
}
253263

254264
func (t *identityBindingTransport) Do(req *http.Request) (*http.Response, error) {
255-
// Check if this is a token request to the Azure authority host
256-
if strings.HasSuffix(req.URL.Path, "/oauth2/v2.0/token") && (req.URL.Host == "login.microsoftonline.com" ||
257-
req.URL.Host == "login.microsoftonline.us" ||
258-
req.URL.Host == "login.partner.microsoftonline.cn" ||
259-
req.URL.Host == "login.microsoftonline.de") {
260-
// This is a token request, redirect to Kubernetes endpoint
261-
262-
// Load the CA certificate (this will use cached version if still valid)
263-
caCertPool, err := t.credential.loadKubernetesCA()
264-
if err != nil {
265-
return nil, err
266-
}
267-
268-
// Create custom transport with the CA and SNI configuration
269-
transport := &http.Transport{
270-
TLSClientConfig: &tls.Config{
271-
RootCAs: caCertPool,
272-
ServerName: t.kubernetesSNIName,
273-
},
274-
}
275-
276-
// Clone the request to avoid modifying the original
277-
newReq := req.Clone(req.Context())
278-
279-
// Update the URL to point to the Kubernetes endpoint
280-
newReq.URL.Scheme = t.credential.kubernetesTokenEndpoint.Scheme
281-
newReq.URL.Host = t.credential.kubernetesTokenEndpoint.Host
282-
newReq.Host = t.credential.kubernetesTokenEndpoint.Host
283-
284-
// Preserve the original path (contains tenant ID and token endpoint path)
285-
// The path should be something like "/tenant-id/oauth2/v2.0/token"
286-
// Keep the original path to maintain the token request structure
287-
288-
return transport.RoundTrip(newReq)
265+
// Return early if this is not a token request
266+
if !strings.HasSuffix(req.URL.Path, "/oauth2/v2.0/token") {
267+
return http.DefaultTransport.RoundTrip(req)
289268
}
290-
291-
// For non-token requests, use the default transport
292-
return http.DefaultTransport.RoundTrip(req)
269+
270+
// This is a token request, redirect to Kubernetes endpoint
271+
272+
// Load the CA certificate (this will use cached version if still valid)
273+
caCertPool, err := t.credential.loadKubernetesCA()
274+
if err != nil {
275+
return nil, err
276+
}
277+
278+
// Create custom transport with the CA and SNI configuration
279+
transport := &http.Transport{
280+
TLSClientConfig: &tls.Config{
281+
RootCAs: caCertPool,
282+
ServerName: t.kubernetesSNIName,
283+
},
284+
}
285+
286+
// Clone the request to avoid modifying the original
287+
newReq := req.Clone(req.Context())
288+
289+
// Update the URL to point to the Kubernetes endpoint
290+
newReq.URL.Scheme = t.credential.kubernetesTokenEndpoint.Scheme
291+
newReq.URL.Host = t.credential.kubernetesTokenEndpoint.Host
292+
newReq.Host = t.credential.kubernetesTokenEndpoint.Host
293+
294+
// Preserve the original path (contains tenant ID and token endpoint path)
295+
// The path should be something like "/tenant-id/oauth2/v2.0/token"
296+
// Keep the original path to maintain the token request structure
297+
298+
return transport.RoundTrip(newReq)
293299
}

sdk/azidentity/workload_identity_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ kAF/AQQQAAQBQAAkCNADCBADI==
328328
if err != nil {
329329
t.Fatal(err)
330330
}
331-
331+
332332
mockCredential := &WorkloadIdentityCredential{
333333
identityBinding: true,
334334
kubernetesTokenEndpoint: kubernetesURL,
@@ -343,7 +343,7 @@ kAF/AQQQAAQBQAAkCNADCBADI==
343343
kubernetesSNIName: "kubernetes.default.svc",
344344
}
345345

346-
// Test with Azure authority host request (should be redirected)
346+
// Test with token endpoint request (should be redirected)
347347
req := &http.Request{
348348
Method: "POST",
349349
Header: make(http.Header),
@@ -362,7 +362,7 @@ kAF/AQQQAAQBQAAkCNADCBADI==
362362
t.Error("expected an error due to network connectivity")
363363
}
364364

365-
// Test with non-Azure host request (should not be redirected)
365+
// Test with non-token request (should not be redirected)
366366
req2 := &http.Request{
367367
Method: "GET",
368368
Header: make(http.Header),
@@ -405,7 +405,7 @@ func TestWorkloadIdentityCredential_IdentityBinding_Success(t *testing.T) {
405405
if err != nil {
406406
t.Skipf("test certificate not found, run: cd /tmp/test-certs && openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -days 1 -nodes -subj \"/C=US/ST=Test/L=Test/O=Test/OU=Test/CN=kubernetes.default.svc\"")
407407
}
408-
408+
409409
if err := os.WriteFile(caFile, caCert, os.ModePerm); err != nil {
410410
t.Fatalf("failed to write CA file: %v", err)
411411
}
@@ -479,7 +479,7 @@ klmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890abcdefghij
479479
klmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890abcdefghij
480480
klmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890abcdefghij
481481
-----END CERTIFICATE-----`
482-
482+
483483
if err := os.WriteFile(caFile, []byte(caCert), os.ModePerm); err != nil {
484484
t.Fatalf("failed to write CA file: %v", err)
485485
}

0 commit comments

Comments
 (0)