Skip to content

Commit 32375a5

Browse files
authored
fix: Embed Query configuration breaks when switching between DataFrame and SQL (feast-dev#5257)
* fix: Embed Query configuration breaks when switching between DataFrame and SQL Signed-off-by: Blake <[email protected]> * linting fix Signed-off-by: Blake <[email protected]> --------- Signed-off-by: Blake <[email protected]>
1 parent 6770ee6 commit 32375a5

File tree

2 files changed

+97
-12
lines changed
  • sdk/python
    • feast/infra/offline_stores/contrib/postgres_offline_store
    • tests/unit/infra/offline_stores/contrib/postgres_offline_store

2 files changed

+97
-12
lines changed

sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -144,13 +144,13 @@ def query_generator() -> Iterator[str]:
144144
table_name = offline_utils.get_temp_entity_table_name()
145145

146146
# If using CTE and entity_df is a SQL query, we don't need a table
147-
if config.offline_store.entity_select_mode == EntitySelectMode.embed_query:
148-
if isinstance(entity_df, str):
149-
left_table_query_string = entity_df
150-
else:
151-
raise ValueError(
152-
f"Invalid entity select mode: {config.offline_store.entity_select_mode} cannot be used with entity_df as a DataFrame"
153-
)
147+
use_cte = (
148+
isinstance(entity_df, str)
149+
and config.offline_store.entity_select_mode
150+
== EntitySelectMode.embed_query
151+
)
152+
if use_cte:
153+
left_table_query_string = entity_df
154154
else:
155155
left_table_query_string = table_name
156156
_upload_entity_df(config, entity_df, table_name)
@@ -187,7 +187,7 @@ def query_generator() -> Iterator[str]:
187187
entity_df_columns=entity_schema.keys(),
188188
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
189189
full_feature_names=full_feature_names,
190-
entity_select_mode=config.offline_store.entity_select_mode,
190+
use_cte=use_cte,
191191
)
192192
finally:
193193
# Only cleanup if we created a table
@@ -386,7 +386,7 @@ def build_point_in_time_query(
386386
entity_df_columns: KeysView[str],
387387
query_template: str,
388388
full_feature_names: bool = False,
389-
entity_select_mode: EntitySelectMode = EntitySelectMode.temp_table,
389+
use_cte: bool = False,
390390
) -> str:
391391
"""Build point-in-time query between each feature view table and the entity dataframe for PostgreSQL"""
392392
template = Environment(loader=BaseLoader()).from_string(source=query_template)
@@ -414,7 +414,7 @@ def build_point_in_time_query(
414414
"featureviews": feature_view_query_contexts,
415415
"full_feature_names": full_feature_names,
416416
"final_output_feature_names": final_output_feature_names,
417-
"entity_select_mode": entity_select_mode.value,
417+
"use_cte": use_cte,
418418
}
419419

420420
query = template.render(template_context)
@@ -456,7 +456,7 @@ def _get_entity_schema(
456456

457457
MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
458458
WITH
459-
{% if entity_select_mode == "embed_query" %}
459+
{% if use_cte %}
460460
entity_query AS ({{ left_table_query_string }}),
461461
{% endif %}
462462
/*
@@ -479,15 +479,17 @@ def _get_entity_schema(
479479
{% endif %}
480480
{% endfor %}
481481
FROM
482-
{% if entity_select_mode == "embed_query" %}
482+
{% if use_cte %}
483483
entity_query
484484
{% else %}
485485
{{ left_table_query_string }}
486486
{% endif %}
487487
)
488+
488489
{% if featureviews | length > 0 %}
489490
,
490491
{% endif %}
492+
491493
{% for featureview in featureviews %}
492494
493495
"{{ featureview.name }}__entity_dataframe" AS (

sdk/python/tests/unit/infra/offline_stores/contrib/postgres_offline_store/test_postgres.py

Lines changed: 83 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -355,6 +355,89 @@ def test_get_historical_features_entity_select_modes_embed_query(
355355
assert True # If we get here, the SQL is valid
356356

357357

358+
@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
359+
@patch(
360+
"feast.infra.offline_stores.contrib.postgres_offline_store.postgres.df_to_postgres_table"
361+
)
362+
@patch(
363+
"feast.infra.offline_stores.contrib.postgres_offline_store.postgres.get_query_schema"
364+
)
365+
def test_get_historical_features_entity_select_modes_embed_query_with_dataframe(
366+
mock_get_query_schema, mock_df_to_postgres_table, mock_get_conn
367+
):
368+
mock_conn = MagicMock()
369+
mock_get_conn.return_value.__enter__.return_value = mock_conn
370+
371+
# Mock the query schema to return a simple schema
372+
mock_get_query_schema.return_value = {
373+
"event_timestamp": pd.Timestamp,
374+
"driver_id": pd.Int64Dtype(),
375+
}
376+
377+
test_repo_config = RepoConfig(
378+
project="test_project",
379+
registry="test_registry",
380+
provider="local",
381+
offline_store=PostgreSQLOfflineStoreConfig(
382+
type="postgres",
383+
host="localhost",
384+
port=5432,
385+
database="test_db",
386+
db_schema="public",
387+
user="test_user",
388+
password="test_password",
389+
entity_select_mode="embed_query",
390+
),
391+
)
392+
393+
test_data_source = PostgreSQLSource(
394+
name="test_batch_source",
395+
description="test_batch_source",
396+
table="offline_store_database_name.offline_store_table_name",
397+
timestamp_field="event_published_datetime_utc",
398+
)
399+
400+
test_feature_view = FeatureView(
401+
name="test_feature_view",
402+
entities=_mock_entity(),
403+
schema=[
404+
Field(name="feature1", dtype=Float32),
405+
],
406+
source=test_data_source,
407+
)
408+
409+
mock_registry = MagicMock()
410+
mock_registry.get_feature_view.return_value = test_feature_view
411+
412+
# Use a DataFrame even though embed_query mode is used
413+
entity_df = pd.DataFrame(
414+
{"event_timestamp": [datetime(2021, 1, 1)], "driver_id": [1]}
415+
)
416+
417+
retrieval_job = PostgreSQLOfflineStore.get_historical_features(
418+
config=test_repo_config,
419+
feature_views=[test_feature_view],
420+
feature_refs=["test_feature_view:feature1"],
421+
entity_df=entity_df,
422+
registry=mock_registry,
423+
project="test_project",
424+
)
425+
426+
actual_query = retrieval_job.to_sql().strip()
427+
logger.debug("Actual query:\n%s", actual_query)
428+
429+
# Check that the query starts with WITH and contains the expected comment block
430+
assert actual_query.startswith("""WITH
431+
432+
/*
433+
Compute a deterministic hash for the `left_table_query_string` that will be used throughout
434+
all the logic as the field to GROUP BY the data
435+
*/""")
436+
437+
sqlglot.parse(actual_query)
438+
assert True
439+
440+
358441
@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
359442
@patch(
360443
"feast.infra.offline_stores.contrib.postgres_offline_store.postgres.df_to_postgres_table"

0 commit comments

Comments
 (0)