Skip to content

Commit af9f404

Browse files
catsillermvlipka
andauthored
feat: added RequireAllEnvironmentAccess middleware and applied to necessary endpoints BED-6228 (#2093)
* added RequireAllEnvironmentAccess middleware and added it to supported endpoints * chore: remove some RequireAllEnvironmentAccess middleware from a few routes * added RequireAllEnvironmentAccess middleware and added it to supported endpoints * chore: update middleware to another endpoint * chore(endpoints): add all env middleware to a couple asset group posts * chore: remove middleware from asset groups endpoint * BED-6828 fix for all_environments not being sent on updating/creating an admin * chore: removed middleware from api endpoints per pr review feedback * chore(prfeedback): added middleware on endpoint per feedback --------- Co-authored-by: Michael Lipka <[email protected]>
1 parent dc67a27 commit af9f404

File tree

5 files changed

+151
-16
lines changed

5 files changed

+151
-16
lines changed

cmd/api/src/api/middleware/etac.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,27 @@ func SupportsETACMiddleware(db database.Database) mux.MiddlewareFunc {
5757
}
5858
}
5959

60+
// RequireAllEnvironmentAccessMiddleware will check if a user's all environments flag is true and return a forbidden response code if set to false
61+
func RequireAllEnvironmentAccessMiddleware(db database.Database) mux.MiddlewareFunc {
62+
return func(next http.Handler) http.Handler {
63+
return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
64+
if etacFlag, err := db.GetFlagByKey(request.Context(), appcfg.FeatureETAC); err != nil {
65+
api.HandleDatabaseError(request, response, err)
66+
} else if !etacFlag.Enabled {
67+
next.ServeHTTP(response, request)
68+
} else if bhCtx := ctx.FromRequest(request); !bhCtx.AuthCtx.Authenticated() {
69+
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusUnauthorized, "not authenticated", request), response)
70+
} else if currentUser, found := auth.GetUserFromAuthCtx(bhCtx.AuthCtx); !found {
71+
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "no associated user found with request", request), response)
72+
} else if currentUser.AllEnvironments {
73+
next.ServeHTTP(response, request)
74+
} else {
75+
api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusForbidden, "user does not have access to this resource", request), response)
76+
}
77+
})
78+
}
79+
}
80+
6081
// getEnvironmentIdFromRequest will pull the environment id from the request's path variables where the environment id can be equal to an objectid, tenantid, or domainsid
6182
func getEnvironmentIdFromRequest(request *http.Request) (string, error) {
6283
if domainSID, hasDomainSID := mux.Vars(request)[api.URIPathVariableDomainID]; hasDomainSID {

cmd/api/src/api/middleware/etac_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,99 @@ func TestSupportsETACMiddleware(t *testing.T) {
186186
})
187187
}
188188
}
189+
190+
func TestRequireAllEnvironmentAccessMiddleware(t *testing.T) {
191+
192+
var (
193+
mockCtrl = gomock.NewController(t)
194+
mockDB = mocks.NewMockDatabase(mockCtrl)
195+
)
196+
197+
defer mockCtrl.Finish()
198+
199+
tests := []struct {
200+
name string
201+
setupMocks func()
202+
bhCtx ctx.Context
203+
expectedCode int
204+
expectNextHit bool
205+
}{
206+
{
207+
name: "Success feature flag disabled",
208+
setupMocks: func() {
209+
mockDB.EXPECT().
210+
GetFlagByKey(gomock.Any(), appcfg.FeatureETAC).
211+
Return(appcfg.FeatureFlag{Enabled: false}, nil)
212+
},
213+
expectedCode: http.StatusOK,
214+
expectNextHit: true,
215+
},
216+
{
217+
name: "Success All Environments enabled",
218+
setupMocks: func() {
219+
mockDB.EXPECT().
220+
GetFlagByKey(gomock.Any(), appcfg.FeatureETAC).
221+
Return(appcfg.FeatureFlag{Enabled: true}, nil)
222+
},
223+
bhCtx: ctx.Context{
224+
AuthCtx: auth.Context{
225+
PermissionOverrides: auth.PermissionOverrides{},
226+
Owner: model.User{
227+
AllEnvironments: true,
228+
EnvironmentTargetedAccessControl: nil,
229+
},
230+
Session: model.UserSession{},
231+
},
232+
},
233+
expectedCode: http.StatusOK,
234+
expectNextHit: true,
235+
},
236+
{
237+
name: "Fail If All Environments is false",
238+
setupMocks: func() {
239+
mockDB.EXPECT().
240+
GetFlagByKey(gomock.Any(), appcfg.FeatureETAC).
241+
Return(appcfg.FeatureFlag{Enabled: true}, nil)
242+
},
243+
bhCtx: ctx.Context{
244+
AuthCtx: auth.Context{
245+
PermissionOverrides: auth.PermissionOverrides{},
246+
Owner: model.User{
247+
AllEnvironments: false,
248+
EnvironmentTargetedAccessControl: nil,
249+
},
250+
Session: model.UserSession{},
251+
},
252+
},
253+
expectedCode: http.StatusForbidden,
254+
expectNextHit: false,
255+
},
256+
}
257+
258+
for _, tt := range tests {
259+
t.Run(tt.name, func(t *testing.T) {
260+
tt.setupMocks()
261+
262+
nextHit := false
263+
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
264+
nextHit = true
265+
w.WriteHeader(http.StatusOK)
266+
})
267+
268+
handler := RequireAllEnvironmentAccessMiddleware(mockDB)(next)
269+
270+
req := httptest.NewRequestWithContext(context.Background(), http.MethodGet, "/test/12345", nil)
271+
req = ctx.SetRequestContext(req, &tt.bhCtx)
272+
req = mux.SetURLVars(req, map[string]string{
273+
api.URIPathVariableObjectID: "12345",
274+
})
275+
276+
rec := httptest.NewRecorder()
277+
278+
handler.ServeHTTP(rec, req)
279+
280+
assert.Equal(t, tt.expectedCode, rec.Code)
281+
assert.Equal(t, tt.expectNextHit, nextHit)
282+
})
283+
}
284+
}

cmd/api/src/api/registration/v2.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,13 @@ func NewV2API(resources v2.Resources, routerInst *router.Router) {
161161
routerInst.GET("/api/v2/asset-groups", resources.ListAssetGroups).RequirePermissions(permissions.GraphDBRead),
162162
routerInst.POST("/api/v2/asset-groups", resources.CreateAssetGroup).RequirePermissions(permissions.GraphDBWrite),
163163
routerInst.GET(fmt.Sprintf("/api/v2/asset-groups/{%s}", api.URIPathVariableAssetGroupID), resources.GetAssetGroup).RequirePermissions(permissions.GraphDBRead),
164-
routerInst.GET(fmt.Sprintf("/api/v2/asset-groups/{%s}/custom-selectors", api.URIPathVariableAssetGroupID), resources.GetAssetGroupCustomMemberCount).RequirePermissions(permissions.GraphDBRead),
164+
routerInst.GET(fmt.Sprintf("/api/v2/asset-groups/{%s}/custom-selectors", api.URIPathVariableAssetGroupID), resources.GetAssetGroupCustomMemberCount).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
165165
routerInst.DELETE(fmt.Sprintf("/api/v2/asset-groups/{%s}", api.URIPathVariableAssetGroupID), resources.DeleteAssetGroup).RequirePermissions(permissions.GraphDBWrite),
166166
routerInst.PUT(fmt.Sprintf("/api/v2/asset-groups/{%s}", api.URIPathVariableAssetGroupID), resources.UpdateAssetGroup).RequirePermissions(permissions.GraphDBWrite),
167167
routerInst.DELETE(fmt.Sprintf("/api/v2/asset-groups/{%s}/selectors/{%s}", api.URIPathVariableAssetGroupID, api.URIPathVariableAssetGroupSelectorID), resources.DeleteAssetGroupSelector).RequirePermissions(permissions.GraphDBWrite),
168-
routerInst.GET(fmt.Sprintf("/api/v2/asset-groups/{%s}/collections", api.URIPathVariableAssetGroupID), resources.ListAssetGroupCollections).RequirePermissions(permissions.GraphDBRead),
169-
routerInst.GET(fmt.Sprintf("/api/v2/asset-groups/{%s}/members", api.URIPathVariableAssetGroupID), resources.ListAssetGroupMembers).RequirePermissions(permissions.GraphDBRead),
170-
routerInst.GET(fmt.Sprintf("/api/v2/asset-groups/{%s}/members/counts", api.URIPathVariableAssetGroupID), resources.ListAssetGroupMemberCountsByKind).RequirePermissions(permissions.GraphDBRead),
168+
routerInst.GET(fmt.Sprintf("/api/v2/asset-groups/{%s}/collections", api.URIPathVariableAssetGroupID), resources.ListAssetGroupCollections).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
169+
routerInst.GET(fmt.Sprintf("/api/v2/asset-groups/{%s}/members", api.URIPathVariableAssetGroupID), resources.ListAssetGroupMembers).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
170+
routerInst.GET(fmt.Sprintf("/api/v2/asset-groups/{%s}/members/counts", api.URIPathVariableAssetGroupID), resources.ListAssetGroupMemberCountsByKind).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
171171
routerInst.PUT(fmt.Sprintf("/api/v2/asset-groups/{%s}/selectors", api.URIPathVariableAssetGroupID), resources.UpdateAssetGroupSelectors).RequirePermissions(permissions.GraphDBWrite),
172172
// DEPRECATED: this has been changed to a PUT endpoint above, and must be removed for API V3
173173
routerInst.POST(fmt.Sprintf("/api/v2/asset-groups/{%s}/selectors", api.URIPathVariableAssetGroupID), resources.UpdateAssetGroupSelectors).RequirePermissions(permissions.GraphDBWrite),
@@ -176,29 +176,29 @@ func NewV2API(resources v2.Resources, routerInst *router.Router) {
176176
// tags
177177
routerInst.GET("/api/v2/asset-group-tags", resources.GetAssetGroupTags).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
178178
routerInst.PATCH(fmt.Sprintf("/api/v2/asset-group-tags/{%s}", api.URIPathVariableAssetGroupTagID), resources.UpdateAssetGroupTag).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBWrite),
179-
routerInst.POST("/api/v2/asset-group-tags/search", resources.SearchAssetGroupTags).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
179+
routerInst.POST("/api/v2/asset-group-tags/search", resources.SearchAssetGroupTags).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
180180
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}", api.URIPathVariableAssetGroupTagID), resources.GetAssetGroupTag).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
181-
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/members", api.URIPathVariableAssetGroupTagID), resources.GetAssetGroupMembersByTag).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
182-
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/members/counts", api.URIPathVariableAssetGroupTagID), resources.GetAssetGroupTagMemberCountsByKind).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
183-
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/members/{%s}", api.URIPathVariableAssetGroupTagID, api.URIPathVariableAssetGroupTagMemberID), resources.GetAssetGroupTagMemberInfo).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
181+
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/members", api.URIPathVariableAssetGroupTagID), resources.GetAssetGroupMembersByTag).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
182+
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/members/counts", api.URIPathVariableAssetGroupTagID), resources.GetAssetGroupTagMemberCountsByKind).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
183+
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/members/{%s}", api.URIPathVariableAssetGroupTagID, api.URIPathVariableAssetGroupTagMemberID), resources.GetAssetGroupTagMemberInfo).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
184184

185185
// selectors
186-
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors", api.URIPathVariableAssetGroupTagID), resources.GetAssetGroupTagSelectors).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
186+
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors", api.URIPathVariableAssetGroupTagID), resources.GetAssetGroupTagSelectors).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
187187
routerInst.POST(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors", api.URIPathVariableAssetGroupTagID), resources.CreateAssetGroupTagSelector).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBWrite),
188-
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors/{%s}", api.URIPathVariableAssetGroupTagID, api.URIPathVariableAssetGroupTagSelectorID), resources.GetAssetGroupTagSelector).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
188+
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors/{%s}", api.URIPathVariableAssetGroupTagID, api.URIPathVariableAssetGroupTagSelectorID), resources.GetAssetGroupTagSelector).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
189189
routerInst.PATCH(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors/{%s}", api.URIPathVariableAssetGroupTagID, api.URIPathVariableAssetGroupTagSelectorID), resources.UpdateAssetGroupTagSelector).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBWrite),
190190
routerInst.DELETE(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors/{%s}", api.URIPathVariableAssetGroupTagID, api.URIPathVariableAssetGroupTagSelectorID), resources.DeleteAssetGroupTagSelector).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBWrite),
191-
routerInst.POST("/api/v2/asset-group-tags/preview-selectors", resources.PreviewSelectors).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
192-
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors/{%s}/members", api.URIPathVariableAssetGroupTagID, api.URIPathVariableAssetGroupTagSelectorID), resources.GetAssetGroupMembersBySelector).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
191+
routerInst.POST("/api/v2/asset-group-tags/preview-selectors", resources.PreviewSelectors).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
192+
routerInst.GET(fmt.Sprintf("/api/v2/asset-group-tags/{%s}/selectors/{%s}/members", api.URIPathVariableAssetGroupTagID, api.URIPathVariableAssetGroupTagSelectorID), resources.GetAssetGroupMembersBySelector).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
193193

194194
// history
195-
routerInst.GET("/api/v2/asset-group-tags-history", resources.GetAssetGroupTagHistory).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
196-
routerInst.POST("/api/v2/asset-group-tags-history", resources.SearchAssetGroupTagHistory).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead),
195+
routerInst.GET("/api/v2/asset-group-tags-history", resources.GetAssetGroupTagHistory).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
196+
routerInst.POST("/api/v2/asset-group-tags-history", resources.SearchAssetGroupTagHistory).CheckFeatureFlag(resources.DB, appcfg.FeatureTierManagement).RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
197197

198198
// QA API
199199
routerInst.GET("/api/v2/completeness", resources.GetDatabaseCompleteness).RequirePermissions(permissions.GraphDBRead),
200200

201-
routerInst.GET("/api/v2/pathfinding", resources.GetPathfindingResult).Queries("start_node", "{start_node}", "end_node", "{end_node}").RequirePermissions(permissions.GraphDBRead),
201+
routerInst.GET("/api/v2/pathfinding", resources.GetPathfindingResult).Queries("start_node", "{start_node}", "end_node", "{end_node}").RequirePermissions(permissions.GraphDBRead).RequireAllEnvironmentAccess(resources.DB),
202202
routerInst.GET("/api/v2/graphs/kinds", resources.ListKinds).RequirePermissions(permissions.GraphDBRead),
203203
routerInst.GET("/api/v2/graphs/source-kinds", resources.ListSourceKinds).RequirePermissions(permissions.GraphDBRead),
204204
routerInst.GET("/api/v2/graphs/shortest-path", resources.GetShortestPath).Queries(params.StartNode.String(), params.StartNode.RouteMatcher(), params.EndNode.String(), params.EndNode.RouteMatcher()).RequirePermissions(permissions.GraphDBRead),

cmd/api/src/api/router/router.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ func (s *Route) SupportsETAC(db database.Database) *Route {
9696
return s
9797
}
9898

99+
func (s *Route) RequireAllEnvironmentAccess(db database.Database) *Route {
100+
s.handler.Use(middleware.RequireAllEnvironmentAccessMiddleware(db))
101+
return s
102+
}
103+
99104
func (s *Route) CheckFeatureFlag(db database.Database, flagKey string) *Route {
100105
s.handler.Use(middleware.FeatureFlagMiddleware(db, flagKey))
101106
return s

cmd/api/src/api/v2/auth/etac.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,20 @@ import (
3939
func handleETACRequest(ctx context.Context, updateUserRequest v2.UpdateUserRequest, roles model.Roles, user *model.User, graphDB queries.Graph) error {
4040
if updateUserRequest.AllEnvironments.Valid || updateUserRequest.EnvironmentTargetedAccessControl != nil {
4141
// Admin / Power Users can only have all_environments set to true
42-
if (roles.Has(model.Role{Name: auth.RoleAdministrator}) || roles.Has(model.Role{Name: auth.RolePowerUser})) &&
42+
if (!hasValidRolesForETAC(roles)) &&
4343
(!updateUserRequest.AllEnvironments.Bool || (updateUserRequest.EnvironmentTargetedAccessControl != nil && len(updateUserRequest.EnvironmentTargetedAccessControl.Environments) > 0)) {
4444
return errors.New(api.ErrorResponseETACInvalidRoles)
4545
}
4646
user.AllEnvironments = updateUserRequest.AllEnvironments.Bool
4747
}
4848

49+
// This will force admins to have valid defaults in the event that a user does not send all_environments or explore_enabled
50+
if !updateUserRequest.AllEnvironments.Valid {
51+
if !hasValidRolesForETAC(roles) {
52+
user.AllEnvironments = true
53+
}
54+
}
55+
4956
if updateUserRequest.EnvironmentTargetedAccessControl == nil || len(updateUserRequest.EnvironmentTargetedAccessControl.Environments) == 0 {
5057
user.EnvironmentTargetedAccessControl = make([]model.EnvironmentTargetedAccessControl, 0)
5158
return nil
@@ -101,3 +108,9 @@ func nodeSetToObjectIDMap(set graph.NodeSet) (map[string]bool, error) {
101108

102109
return objectIDs, nil
103110
}
111+
112+
// hasValidRolesForETAC will check the passed in roles to determine if a user can have ETAC controls applied to theem
113+
// returning true if they may be ETAC controlled and false if they may not
114+
func hasValidRolesForETAC(roles model.Roles) bool {
115+
return !(roles.Has(model.Role{Name: auth.RoleAdministrator}) || roles.Has(model.Role{Name: auth.RolePowerUser}))
116+
}

0 commit comments

Comments
 (0)