Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/python/feast/api/registry/rest/data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def list_data_sources(
data_sources = response.get("dataSources", [])

result = {
"data_sources": data_sources,
"dataSources": data_sources,
"pagination": response.get("pagination", {}),
}

Expand Down
24 changes: 21 additions & 3 deletions sdk/python/feast/api/registry/rest/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,28 @@ def get_entity(

result = entity

relationships = get_object_relationships(
grpc_handler, "entity", name, project, allow_cache
)
ds_list_req = RegistryServer_pb2.ListDataSourcesRequest(
project=project,
allow_cache=allow_cache,
)
ds_list_resp = grpc_call(grpc_handler.ListDataSources, ds_list_req)
ds_map = {ds["name"]: ds for ds in ds_list_resp.get("dataSources", [])}
data_source_objs = []
seen_ds_names = set()
for rel in relationships:
if rel.get("target", {}).get("type") == "dataSource":
ds_name = rel["target"]["name"]
if ds_name not in seen_ds_names:
ds_obj = ds_map.get(ds_name)
if ds_obj:
data_source_objs.append(ds_obj)
seen_ds_names.add(ds_name)
result["dataSources"] = data_source_objs

if include_relationships:
relationships = get_object_relationships(
grpc_handler, "entity", name, project, allow_cache
)
result["relationships"] = relationships

return result
Expand Down
8 changes: 7 additions & 1 deletion sdk/python/feast/api/registry/rest/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@ def list_features(
sorting=create_grpc_sorting_params(sorting_params),
)
response = grpc_call(grpc_handler.ListFeatures, req)
if "features" not in response:
response["features"] = []
if "pagination" not in response:
response["pagination"] = {}

if include_relationships:
features = response.get("features", [])
relationships = get_relationships_for_objects(
grpc_handler, response["features"], "feature", project, allow_cache
grpc_handler, features, "feature", project, allow_cache
)
response["relationships"] = relationships
return response
Expand Down
76 changes: 71 additions & 5 deletions sdk/python/tests/unit/api/test_api_rest_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,7 @@ def test_feature_services_via_rest(fastapi_test_app):

def test_data_sources_via_rest(fastapi_test_app):
response = fastapi_test_app.get("/data_sources?project=demo_project")
assert response.status_code == 200
assert "data_sources" in response.json()
assert "dataSources" in response.json()
response = fastapi_test_app.get(
"/data_sources/user_profile_source?project=demo_project"
)
Expand Down Expand Up @@ -650,9 +649,9 @@ def test_data_sources_pagination_via_rest(fastapi_test_app_with_multiple_objects
response = client.get("/data_sources?project=demo_project&page=1&limit=2")
assert response.status_code == 200
data = response.json()
assert "data_sources" in data
assert "dataSources" in data
assert "pagination" in data
assert len(data["data_sources"]) == 2
assert len(data["dataSources"]) == 2
assert data["pagination"]["page"] == 1
assert data["pagination"]["limit"] == 2
assert data["pagination"]["totalCount"] == 3
Expand All @@ -669,7 +668,7 @@ def test_data_sources_sorting_via_rest(fastapi_test_app_with_multiple_objects):
)
assert response.status_code == 200
data = response.json()
ds_names = [ds["name"] for ds in data["data_sources"]]
ds_names = [ds["name"] for ds in data["dataSources"]]
assert ds_names == sorted(ds_names)


Expand Down Expand Up @@ -1064,3 +1063,70 @@ def test_lineage_complete_all_via_rest(fastapi_test_app):
assert "dataSources" in project_data["objects"]
assert "featureViews" in project_data["objects"]
assert "featureServices" in project_data["objects"]


def test_invalid_project_name_with_relationships_via_rest(fastapi_test_app):
"""Test REST API response with invalid project name using include_relationships=true.
The API should not throw 500 or any other error when an invalid project name is provided
with include_relationships=true parameter.
"""
response = fastapi_test_app.get(
"/entities?project=invalid_project_name&include_relationships=true"
)
assert response.status_code == 200
data = response.json()
assert "entities" in data
assert isinstance(data["entities"], list)
assert len(data["entities"]) == 0
assert "relationships" in data
assert isinstance(data["relationships"], dict)
assert len(data["relationships"]) == 0

response = fastapi_test_app.get(
"/feature_views?project=invalid_project_name&include_relationships=true"
)
assert response.status_code == 200
data = response.json()
assert "featureViews" in data
assert isinstance(data["featureViews"], list)
assert len(data["featureViews"]) == 0
assert "relationships" in data
assert isinstance(data["relationships"], dict)
assert len(data["relationships"]) == 0

response = fastapi_test_app.get(
"/data_sources?project=invalid_project_name&include_relationships=true"
)
# Should return 200 with empty results, not 500 or other errors
assert response.status_code == 200
data = response.json()
assert "dataSources" in data
assert isinstance(data["dataSources"], list)
assert len(data["dataSources"]) == 0
assert "relationships" in data
assert isinstance(data["relationships"], dict)
assert len(data["relationships"]) == 0

response = fastapi_test_app.get(
"/feature_services?project=invalid_project_name&include_relationships=true"
)
assert response.status_code == 200
data = response.json()
assert "featureServices" in data
assert isinstance(data["featureServices"], list)
assert len(data["featureServices"]) == 0
assert "relationships" in data
assert isinstance(data["relationships"], dict)
assert len(data["relationships"]) == 0

response = fastapi_test_app.get(
"/features?project=invalid_project_name&include_relationships=true"
)
assert response.status_code == 200
data = response.json()
assert "features" in data
assert isinstance(data["features"], list)
assert len(data["features"]) == 0
assert "relationships" in data
assert isinstance(data["relationships"], dict)
assert len(data["relationships"]) == 0
Loading