Skip to content

Commit 3be5afa

Browse files
committed
added context chaining, cleanup
1 parent 0fef952 commit 3be5afa

File tree

4 files changed

+53
-51
lines changed

4 files changed

+53
-51
lines changed

pkg/ec2provider/ec2provider.go

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ type EC2API interface {
5151
// Get a node name from instance ID
5252
type EC2Provider interface {
5353
GetPrivateDNSName(ctx context.Context, id string) (string, error)
54-
StartEc2DescribeBatchProcessing()
54+
StartEc2DescribeBatchProcessing(ctx context.Context)
5555
}
5656

5757
type ec2PrivateDNSCache struct {
@@ -114,10 +114,6 @@ func newEC2Client(ctx context.Context, roleARN, sourceARN, region string, qps in
114114
stsCfg := cfg
115115
stsCfg.HTTPClient = rateLimitedClient
116116

117-
if err != nil {
118-
logrus.Fatalf("Failed to load AWS config: %v", err)
119-
}
120-
121117
stsClient := sts.NewFromConfig(*applySTSRequestHeaders(&stsCfg, sourceARN))
122118
ap := stscreds.NewAssumeRoleProvider(stsClient, roleARN, func(o *stscreds.AssumeRoleOptions) {
123119
o.Duration = time.Duration(60) * time.Minute
@@ -228,7 +224,7 @@ func (p *ec2ProviderImpl) GetPrivateDNSName(ctx context.Context, id string) (str
228224
return privateDNSName, nil
229225
}
230226

231-
func (p *ec2ProviderImpl) StartEc2DescribeBatchProcessing() {
227+
func (p *ec2ProviderImpl) StartEc2DescribeBatchProcessing(ctx context.Context) {
232228
startTime := time.Now()
233229
var instanceIdList []string
234230
for {
@@ -256,17 +252,17 @@ func (p *ec2ProviderImpl) StartEc2DescribeBatchProcessing() {
256252
startTime = time.Now()
257253
dupInstanceList := make([]string, len(instanceIdList))
258254
copy(dupInstanceList, instanceIdList)
259-
go p.getPrivateDnsAndPublishToCache(dupInstanceList)
255+
go p.getPrivateDnsAndPublishToCache(ctx, dupInstanceList)
260256
instanceIdList = nil
261257
}
262258
}
263259
}
264260

265-
func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string) {
261+
func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(ctx context.Context, instanceIdList []string) {
266262
// Look up instance from EC2 API
267263
logrus.Infof("Making Batch Query to DescribeInstances for %v instances ", len(instanceIdList))
268264
metrics.Get().EC2DescribeInstanceCallCount.Inc()
269-
output, err := p.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{
265+
output, err := p.ec2.DescribeInstances(ctx, &ec2.DescribeInstancesInput{
270266
InstanceIds: instanceIdList,
271267
})
272268
if err != nil {

pkg/ec2provider/ec2provider_test.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func TestGetPrivateDNSName(t *testing.T) {
7070
metrics.InitMetrics(prometheus.NewRegistry())
7171
ec2Provider := newMockedEC2ProviderImpl()
7272
ec2Provider.ec2 = &mockEc2Client{Reservations: prepareSingleInstanceOutput()}
73-
go ec2Provider.StartEc2DescribeBatchProcessing()
73+
go ec2Provider.StartEc2DescribeBatchProcessing(context.TODO())
7474
dns_name, err := ec2Provider.GetPrivateDNSName(context.TODO(), "ec2-1")
7575
if err != nil {
7676
t.Error("There is an error which is not expected when calling ec2 API with setting up mocks")
@@ -103,7 +103,7 @@ func TestGetPrivateDNSNameWithBatching(t *testing.T) {
103103
ec2Provider := newMockedEC2ProviderImpl()
104104
reservations := prepare100InstanceOutput()
105105
ec2Provider.ec2 = &mockEc2Client{Reservations: reservations}
106-
go ec2Provider.StartEc2DescribeBatchProcessing()
106+
go ec2Provider.StartEc2DescribeBatchProcessing(context.TODO())
107107
var wg sync.WaitGroup
108108
for i := 1; i < 101; i++ {
109109
instanceString := "ec2-" + strconv.Itoa(i)
@@ -206,20 +206,24 @@ func TestApplySTSRequestHeaders(t *testing.T) {
206206
{
207207
name: "header with source arn",
208208
args: map[string]string{
209-
headerSourceArn: "arn:aws:eks:us-east-1:123456789012:MyCluster/res1",
209+
headerSourceAccount: "123456789012",
210+
headerSourceArn: "arn:aws:eks:us-east-1:123456789012:MyCluster/res1",
210211
},
211212
want: map[string]string{
212-
headerSourceArn: "arn:aws:eks:us-east-1:123456789012:MyCluster/res1",
213+
headerSourceAccount: "123456789012",
214+
headerSourceArn: "arn:aws:eks:us-east-1:123456789012:MyCluster/res1",
213215
},
214216
wantErr: false,
215217
},
216218
{
217219
name: "header without source arn",
218220
args: map[string]string{
219-
headerSourceArn: "",
221+
headerSourceAccount: "123456789012",
222+
headerSourceArn: "",
220223
},
221224
want: map[string]string{
222-
headerSourceArn: "",
225+
headerSourceAccount: "",
226+
headerSourceArn: "",
223227
},
224228
wantErr: false,
225229
},

pkg/server/server.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,15 @@ func (c *Server) getHandler(ctx context.Context, backendMapper BackendMapper, ec
223223
backendModeConfigInitDone: false,
224224
}
225225

226-
h.HandleFunc("/authenticate", h.authenticateEndpoint)
226+
h.HandleFunc("/authenticate", func(w http.ResponseWriter, r *http.Request) {
227+
h.authenticateEndpoint(ctx, w, r)
228+
})
227229
h.Handle("/metrics", promhttp.Handler())
228230
h.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
229231
fmt.Fprintf(w, "ok")
230232
})
231233
logrus.Infof("Starting the h.ec2Provider.startEc2DescribeBatchProcessing ")
232-
go h.ec2Provider.StartEc2DescribeBatchProcessing()
234+
go h.ec2Provider.StartEc2DescribeBatchProcessing(ctx)
233235
if strings.TrimSpace(c.DynamicBackendModePath) != "" {
234236
fileutil.StartLoadDynamicFile(c.DynamicBackendModePath, h, stopCh)
235237
}
@@ -303,7 +305,7 @@ func (h *handler) isLoggableIdentity(identity *token.Identity) bool {
303305
return true
304306
}
305307

306-
func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request) {
308+
func (h *handler) authenticateEndpoint(ctx context.Context, w http.ResponseWriter, req *http.Request) {
307309
start := time.Now()
308310
log := logrus.WithFields(logrus.Fields{
309311
"path": req.URL.Path,
@@ -372,7 +374,7 @@ func (h *handler) authenticateEndpoint(w http.ResponseWriter, req *http.Request)
372374
log = log.WithField("arn", identity.CanonicalARN)
373375
}
374376

375-
username, groups, err := h.doMapping(identity)
377+
username, groups, err := h.doMapping(ctx, identity)
376378
if err != nil {
377379
metrics.Get().Latency.WithLabelValues(metrics.Unknown).Observe(duration(start))
378380
log.WithError(err).Warn("access denied")
@@ -429,14 +431,14 @@ func ReservedPrefixExists(username string, reservedList []string) bool {
429431
return false
430432
}
431433

432-
func (h *handler) doMapping(identity *token.Identity) (string, []string, error) {
434+
func (h *handler) doMapping(ctx context.Context, identity *token.Identity) (string, []string, error) {
433435
var errs []error
434436

435437
for _, m := range h.backendMapper.mappers {
436438
mapping, err := m.Map(identity)
437439
if err == nil {
438440
// Mapping found, try to render any templates like {{EC2PrivateDNSName}}
439-
username, groups, err := h.renderTemplates(*mapping, identity)
441+
username, groups, err := h.renderTemplates(ctx, *mapping, identity)
440442
if err != nil {
441443
return "", nil, fmt.Errorf("mapper %s renderTemplates error: %v", m.Name(), err)
442444
}
@@ -461,19 +463,19 @@ func (h *handler) doMapping(identity *token.Identity) (string, []string, error)
461463
return "", nil, errutil.ErrNotMapped
462464
}
463465

464-
func (h *handler) renderTemplates(mapping config.IdentityMapping, identity *token.Identity) (string, []string, error) {
466+
func (h *handler) renderTemplates(ctx context.Context, mapping config.IdentityMapping, identity *token.Identity) (string, []string, error) {
465467
var username string
466468
groups := []string{}
467469
var err error
468470

469471
userPattern := mapping.Username
470-
username, err = h.renderTemplate(userPattern, identity)
472+
username, err = h.renderTemplate(ctx, userPattern, identity)
471473
if err != nil {
472474
return "", nil, fmt.Errorf("error rendering username template %q: %s", userPattern, err.Error())
473475
}
474476

475477
for _, groupPattern := range mapping.Groups {
476-
group, err := h.renderTemplate(groupPattern, identity)
478+
group, err := h.renderTemplate(ctx, groupPattern, identity)
477479
if err != nil {
478480
return "", nil, fmt.Errorf("error rendering group template %q: %s", groupPattern, err.Error())
479481
}
@@ -483,13 +485,13 @@ func (h *handler) renderTemplates(mapping config.IdentityMapping, identity *toke
483485
return username, groups, nil
484486
}
485487

486-
func (h *handler) renderTemplate(template string, identity *token.Identity) (string, error) {
488+
func (h *handler) renderTemplate(ctx context.Context, template string, identity *token.Identity) (string, error) {
487489
// Private DNS requires EC2 API call
488490
if strings.Contains(template, "{{EC2PrivateDNSName}}") {
489491
if !instanceIDPattern.MatchString(identity.SessionName) {
490492
return "", fmt.Errorf("SessionName did not contain an instance id")
491493
}
492-
privateDNSName, err := h.ec2Provider.GetPrivateDNSName(context.Background(), identity.SessionName)
494+
privateDNSName, err := h.ec2Provider.GetPrivateDNSName(ctx, identity.SessionName)
493495
if err != nil {
494496
return "", err
495497
}

0 commit comments

Comments
 (0)