Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ ci = [
"pytest-mock==1.10.4",
"pytest-env",
"Sphinx>4.0.0,<7",
"sqlglot[rs]>=26.12.1",
"testcontainers==4.9.0",
"python-keycloak==4.2.2",
"pre-commit<3.3.2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,10 @@ def _get_entity_schema(
{% else %}
{{ left_table_query_string }}
{% endif %}
),

)
{% if featureviews | length > 0 %}
,
{% endif %}
{% for featureview in featureviews %}

"{{ featureview.name }}__entity_dataframe" AS (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from unittest.mock import MagicMock, patch

import pandas as pd
import sqlglot

from feast.entity import Entity
from feast.feature_view import FeatureView, Field
from feast.feature_view import FeatureView, FeatureViewProjection, Field
from feast.infra.offline_stores.contrib.postgres_offline_store.postgres import (
PostgreSQLOfflineStore,
PostgreSQLOfflineStoreConfig,
Expand All @@ -14,8 +15,9 @@
PostgreSQLSource,
)
from feast.infra.offline_stores.offline_store import RetrievalJob
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.repo_config import RepoConfig
from feast.types import Float32
from feast.types import Float32, ValueType

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
Expand All @@ -30,15 +32,7 @@ def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_conn):
project="test_project",
registry="test_registry",
provider="local",
offline_store=PostgreSQLOfflineStoreConfig(
type="postgres",
host="localhost",
port=5432,
database="test_db",
db_schema="public",
user="test_user",
password="test_password",
),
offline_store=_mock_offline_store_config(),
)

test_data_source = PostgreSQLSource(
Expand Down Expand Up @@ -100,15 +94,7 @@ def test_pull_latest_from_table_without_nested_timestamp_or_query(mock_get_conn)
project="test_project",
registry="test_registry",
provider="local",
offline_store=PostgreSQLOfflineStoreConfig(
type="postgres",
host="localhost",
port=5432,
database="test_db",
db_schema="public",
user="test_user",
password="test_password",
),
offline_store=_mock_offline_store_config(),
)

test_data_source = PostgreSQLSource(
Expand Down Expand Up @@ -167,15 +153,7 @@ def test_pull_all_from_table_or_query(mock_get_conn):
project="test_project",
registry="test_registry",
provider="local",
offline_store=PostgreSQLOfflineStoreConfig(
type="postgres",
host="localhost",
port=5432,
database="test_db",
db_schema="public",
user="test_user",
password="test_password",
),
offline_store=_mock_offline_store_config(),
)

test_data_source = PostgreSQLSource(
Expand Down Expand Up @@ -239,15 +217,7 @@ def test_get_historical_features_entity_select_modes(
project="test_project",
registry="test_registry",
provider="local",
offline_store=PostgreSQLOfflineStoreConfig(
type="postgres",
host="localhost",
port=5432,
database="test_db",
db_schema="public",
user="test_user",
password="test_password",
),
offline_store=_mock_offline_store_config(),
)

test_data_source = PostgreSQLSource(
Expand All @@ -259,13 +229,7 @@ def test_get_historical_features_entity_select_modes(

test_feature_view = FeatureView(
name="test_feature_view",
entities=[
Entity(
name="driver_id",
join_keys=["driver_id"],
description="Driver ID",
)
],
entities=_mock_entity(),
schema=[
Field(name="feature1", dtype=Float32),
],
Expand Down Expand Up @@ -300,6 +264,9 @@ def test_get_historical_features_entity_select_modes(
all the logic as the field to GROUP BY the data
*/""")

sqlglot.parse(actual_query)
assert True


@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
@patch(
Expand Down Expand Up @@ -345,13 +312,7 @@ def test_get_historical_features_entity_select_modes_embed_query(

test_feature_view = FeatureView(
name="test_feature_view",
entities=[
Entity(
name="driver_id",
join_keys=["driver_id"],
description="Driver ID",
)
],
entities=_mock_entity(),
schema=[
Field(name="feature1", dtype=Float32),
],
Expand Down Expand Up @@ -388,3 +349,103 @@ def test_get_historical_features_entity_select_modes_embed_query(
assert actual_query.startswith("""WITH

entity_query AS (""")

# Verify the SQL is valid by parsing it
sqlglot.parse(actual_query) # This will raise ParseError if SQL is invalid
assert True # If we get here, the SQL is valid


@patch("feast.infra.offline_stores.contrib.postgres_offline_store.postgres._get_conn")
@patch(
"feast.infra.offline_stores.contrib.postgres_offline_store.postgres.df_to_postgres_table"
)
@patch(
"feast.infra.offline_stores.contrib.postgres_offline_store.postgres.get_query_schema"
)
@patch("feast.on_demand_feature_view.OnDemandFeatureView.get_requested_odfvs")
def test_get_historical_features_no_feature_view(
mock_get_requested_odfvs,
mock_get_query_schema,
mock_df_to_postgres_table,
mock_get_conn,
):
mock_conn = MagicMock()
mock_get_conn.return_value.__enter__.return_value = mock_conn

# Create a mock OnDemandFeatureView
mock_odfv = MagicMock(spec=OnDemandFeatureView)
mock_odfv.name = "test_odfv"
mock_odfv.features = [Field(name="feature1", dtype=Float32)]
mock_odfv.projection = FeatureViewProjection(
name="test_odfv",
name_alias="test_odfv",
features=[Field(name="feature1", dtype=Float32)],
desired_features=[],
)
mock_get_requested_odfvs.return_value = [mock_odfv]

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=_mock_offline_store_config(),
)

test_data_source = PostgreSQLSource(
name="test_batch_source",
description="test_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="event_published_datetime_utc",
)

test_feature_view = FeatureView(
name="test_feature_view",
entities=_mock_entity(),
schema=[
Field(name="feature1", dtype=Float32),
],
source=test_data_source,
)

mock_registry = MagicMock()
mock_registry.get_on_demand_feature_view.return_value = test_feature_view
mock_registry.list_on_demand_feature_views.return_value = [mock_odfv]

entity_df = pd.DataFrame(
{"event_timestamp": [datetime(2021, 1, 1)], "driver_id": [1]}
)

retrieval_job = PostgreSQLOfflineStore.get_historical_features(
config=test_repo_config,
feature_views=[],
feature_refs=["test_odfv:feature1"],
entity_df=entity_df,
registry=mock_registry,
project="test_project",
)

sqlglot.parse(retrieval_job.to_sql().strip(), dialect="postgres")
assert True


def _mock_offline_store_config():
return PostgreSQLOfflineStoreConfig(
type="postgres",
host="localhost",
port=5432,
database="test_db",
db_schema="public",
user="test_user",
password="test_password",
)


def _mock_entity():
return [
Entity(
name="driver_id",
join_keys=["driver_id"],
description="Driver ID",
value_type=ValueType.INT64,
)
]
Loading