
Commit 4b94608

feat: Offline store update pull_all_from_table_or_query to make timestamp field optional (#5281)
* Update offline store pull all API date field optional
* Update offline store pull all API date field optional
* update
* update
* update
* fix test
* update source read node
* fix linting
* fix linting
* fix linting
* fix linting
* fix linting
* fix test
* fix test
* fix test
* fix test
* fix test
* fix test
* fix test
* fix test

Signed-off-by: HaoXuAI <[email protected]>
1 parent: 1b291b2 · commit: 4b94608

File tree: 23 files changed, +368 −218 lines changed

sdk/python/feast/infra/common/retrieval_task.py

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from datetime import datetime
-from typing import Union
+from typing import Optional, Union

 import pandas as pd

@@ -15,5 +15,5 @@ class HistoricalRetrievalTask:
     feature_view: Union[BatchFeatureView, StreamFeatureView]
     full_feature_name: bool
     registry: Registry
-    start_time: datetime
-    end_time: datetime
+    start_time: Optional[datetime] = None
+    end_time: Optional[datetime] = None

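The practical effect is that a HistoricalRetrievalTask can now be created without time bounds. The snippet below is a self-contained toy stand-in, not the real dataclass (which also carries the feature view, registry, and other fields shown above); it only mirrors the new defaults.

```python
from dataclasses import dataclass
from datetime import datetime
from typing import Optional


@dataclass
class TaskSketch:
    """Toy stand-in mirroring the changed fields of HistoricalRetrievalTask."""

    full_feature_name: bool
    start_time: Optional[datetime] = None  # now optional: None means "no lower bound"
    end_time: Optional[datetime] = None    # now optional: None means "no upper bound"


unbounded = TaskSketch(full_feature_name=False)
bounded = TaskSketch(
    full_feature_name=False,
    start_time=datetime(2024, 1, 1),
    end_time=datetime(2024, 2, 1),
)
print(unbounded.start_time, bounded.start_time)  # None 2024-01-01 00:00:00
```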
sdk/python/feast/infra/compute_engines/feature_builder.py

Lines changed: 3 additions & 3 deletions
@@ -72,10 +72,10 @@ def _should_validate(self):
     def build(self) -> ExecutionPlan:
         last_node = self.build_source_node()

-        # PIT join entities to the feature data, and perform filtering
-        if isinstance(self.task, HistoricalRetrievalTask):
-            last_node = self.build_join_node(last_node)
+        # Join entity_df with source if needed
+        last_node = self.build_join_node(last_node)

+        # PIT filter, TTL, and user-defined filter
         last_node = self.build_filter_node(last_node)

         if self._should_aggregate():

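With the isinstance check removed, the join step is wired for every task type and the decision to actually join is deferred to the join node, which passes feature data straight through when no entity_df is present (the materialization path). A rough pandas stand-in of that pass-through behaviour, with illustrative column names (the real nodes operate on Arrow tables or Spark DataFrames):

```python
from typing import List, Optional

import pandas as pd


def join_step(
    feature_df: pd.DataFrame,
    entity_df: Optional[pd.DataFrame],
    join_keys: List[str],
) -> pd.DataFrame:
    # Simplified stand-in for the always-built join node.
    if entity_df is None:
        return feature_df  # materialization: nothing to join against, pass through
    return entity_df.merge(feature_df, on=join_keys, how="left")  # retrieval path


features = pd.DataFrame({"driver_id": [1, 2], "conv_rate": [0.9, 0.7]})
entities = pd.DataFrame({"driver_id": [1]})

print(join_step(features, None, ["driver_id"]))      # unchanged feature rows
print(join_step(features, entities, ["driver_id"]))  # only entity 1, with its features
```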
sdk/python/feast/infra/compute_engines/local/feature_builder.py

Lines changed: 0 additions & 23 deletions
@@ -2,7 +2,6 @@

 from feast.infra.common.materialization_job import MaterializationTask
 from feast.infra.common.retrieval_task import HistoricalRetrievalTask
-from feast.infra.compute_engines.dag.plan import ExecutionPlan
 from feast.infra.compute_engines.feature_builder import FeatureBuilder
 from feast.infra.compute_engines.local.backends.base import DataFrameBackend
 from feast.infra.compute_engines.local.nodes import (
@@ -95,25 +94,3 @@ def build_output_nodes(self, input_node):
         node = LocalOutputNode("output")
         node.add_input(input_node)
         self.nodes.append(node)
-
-    def build(self) -> ExecutionPlan:
-        last_node = self.build_source_node()
-
-        if isinstance(self.task, HistoricalRetrievalTask):
-            last_node = self.build_join_node(last_node)
-
-        last_node = self.build_filter_node(last_node)
-
-        if self._should_aggregate():
-            last_node = self.build_aggregation_node(last_node)
-        elif isinstance(self.task, HistoricalRetrievalTask):
-            last_node = self.build_dedup_node(last_node)
-
-        if self._should_transform():
-            last_node = self.build_transformation_node(last_node)
-
-        if self._should_validate():
-            last_node = self.build_validation_node(last_node)
-
-        self.build_output_nodes(last_node)
-        return ExecutionPlan(self.nodes)

sdk/python/feast/infra/compute_engines/local/nodes.py

Lines changed: 33 additions & 6 deletions
@@ -1,8 +1,9 @@
-from datetime import timedelta
+from datetime import datetime, timedelta
 from typing import Optional

 import pyarrow as pa

+from feast.data_source import DataSource
 from feast.infra.compute_engines.dag.context import ExecutionContext
 from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue
 from feast.infra.compute_engines.local.backends.base import DataFrameBackend
@@ -15,14 +16,40 @@


 class LocalSourceReadNode(LocalNode):
-    def __init__(self, name: str, feature_view, task):
+    def __init__(
+        self,
+        name: str,
+        source: DataSource,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
+    ):
         super().__init__(name)
-        self.feature_view = feature_view
-        self.task = task
+        self.source = source
+        self.start_time = start_time
+        self.end_time = end_time

     def execute(self, context: ExecutionContext) -> ArrowTableValue:
-        # TODO : Implement the logic to read from offline store
-        return ArrowTableValue(data=pa.Table.from_pandas(context.entity_df))
+        offline_store = context.offline_store
+        (
+            join_key_columns,
+            feature_name_columns,
+            timestamp_field,
+            created_timestamp_column,
+        ) = context.column_info
+
+        # 📥 Reuse Feast's robust query resolver
+        retrieval_job = offline_store.pull_all_from_table_or_query(
+            config=context.repo_config,
+            data_source=self.source,
+            join_key_columns=join_key_columns,
+            feature_name_columns=feature_name_columns,
+            timestamp_field=timestamp_field,
+            created_timestamp_column=created_timestamp_column,
+            start_date=self.start_time,
+            end_date=self.end_time,
+        )
+        arrow_table = retrieval_job.to_arrow()
+        return ArrowTableValue(data=arrow_table)


 class LocalJoinNode(LocalNode):

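LocalSourceReadNode now delegates the scan to the offline store instead of returning the entity_df. A hedged sketch of that call, wrapped as a standalone helper: the store, config, and data source are assumed to already exist, and with start/end left as None the scan is unbounded in time.

```python
from datetime import datetime
from typing import List, Optional

import pyarrow as pa

from feast.data_source import DataSource
from feast.infra.offline_stores.offline_store import OfflineStore
from feast.repo_config import RepoConfig


def read_source(
    store: OfflineStore,
    config: RepoConfig,
    source: DataSource,
    join_keys: List[str],
    feature_cols: List[str],
    timestamp_field: str,
    created_timestamp_column: Optional[str] = None,
    start: Optional[datetime] = None,
    end: Optional[datetime] = None,
) -> pa.Table:
    """Mirror of what LocalSourceReadNode.execute does after this change."""
    job = store.pull_all_from_table_or_query(
        config=config,
        data_source=source,
        join_key_columns=join_keys,
        feature_name_columns=feature_cols,
        timestamp_field=timestamp_field,
        created_timestamp_column=created_timestamp_column,
        start_date=start,  # None: no lower bound on the scan
        end_date=end,      # None: no upper bound on the scan
    )
    return job.to_arrow()
```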
sdk/python/feast/infra/compute_engines/spark/feature_builder.py

Lines changed: 5 additions & 8 deletions
@@ -9,9 +9,8 @@
     SparkAggregationNode,
     SparkDedupNode,
     SparkFilterNode,
-    SparkHistoricalRetrievalReadNode,
     SparkJoinNode,
-    SparkMaterializationReadNode,
+    SparkReadNode,
     SparkTransformationNode,
     SparkWriteNode,
 )
@@ -27,12 +26,10 @@ def __init__(
         self.spark_session = spark_session

     def build_source_node(self):
-        if isinstance(self.task, MaterializationTask):
-            node = SparkMaterializationReadNode("source", self.task)
-        else:
-            node = SparkHistoricalRetrievalReadNode(
-                "source", self.task, self.spark_session
-            )
+        source = self.feature_view.batch_source
+        start_time = self.task.start_time
+        end_time = self.task.end_time
+        node = SparkReadNode("source", source, start_time, end_time)
         self.nodes.append(node)
         return node

sdk/python/feast/infra/compute_engines/spark/node.py

Lines changed: 30 additions & 90 deletions
@@ -1,13 +1,12 @@
-from datetime import timedelta
+from datetime import datetime, timedelta
 from typing import List, Optional, Union, cast

 from pyspark.sql import DataFrame, SparkSession, Window
 from pyspark.sql import functions as F

 from feast import BatchFeatureView, StreamFeatureView
 from feast.aggregation import Aggregation
-from feast.infra.common.materialization_job import MaterializationTask
-from feast.infra.common.retrieval_task import HistoricalRetrievalTask
+from feast.data_source import DataSource
 from feast.infra.compute_engines.dag.context import ExecutionContext
 from feast.infra.compute_engines.dag.model import DAGFormat
 from feast.infra.compute_engines.dag.node import DAGNode
@@ -23,7 +22,6 @@
 from feast.infra.offline_stores.offline_utils import (
     infer_event_timestamp_from_entity_df,
 )
-from feast.utils import _get_fields_with_aliases

 ENTITY_TS_ALIAS = "__entity_event_timestamp"

@@ -49,18 +47,21 @@ def rename_entity_ts_column(
     return entity_df


-class SparkMaterializationReadNode(DAGNode):
+class SparkReadNode(DAGNode):
     def __init__(
-        self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask]
+        self,
+        name: str,
+        source: DataSource,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
     ):
         super().__init__(name)
-        self.task = task
+        self.source = source
+        self.start_time = start_time
+        self.end_time = end_time

     def execute(self, context: ExecutionContext) -> DAGValue:
         offline_store = context.offline_store
-        start_time = self.task.start_time
-        end_time = self.task.end_time
-
         (
             join_key_columns,
             feature_name_columns,
@@ -69,15 +70,15 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         ) = context.column_info

         # 📥 Reuse Feast's robust query resolver
-        retrieval_job = offline_store.pull_latest_from_table_or_query(
+        retrieval_job = offline_store.pull_all_from_table_or_query(
             config=context.repo_config,
-            data_source=self.task.feature_view.batch_source,
+            data_source=self.source,
             join_key_columns=join_key_columns,
             feature_name_columns=feature_name_columns,
             timestamp_field=timestamp_field,
             created_timestamp_column=created_timestamp_column,
-            start_date=start_time,
-            end_date=end_time,
+            start_date=self.start_time,
+            end_date=self.end_time,
         )
         spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df()

@@ -88,74 +89,8 @@ def execute(self, context: ExecutionContext) -> DAGValue:
                 "source": "feature_view_batch_source",
                 "timestamp_field": timestamp_field,
                 "created_timestamp_column": created_timestamp_column,
-                "start_date": start_time,
-                "end_date": end_time,
-            },
-        )
-
-
-class SparkHistoricalRetrievalReadNode(DAGNode):
-    def __init__(
-        self, name: str, task: HistoricalRetrievalTask, spark_session: SparkSession
-    ):
-        super().__init__(name)
-        self.task = task
-        self.spark_session = spark_session
-
-    def execute(self, context: ExecutionContext) -> DAGValue:
-        """
-        Read data from the offline store on the Spark engine.
-        TODO: Some functionality is duplicated with SparkMaterializationReadNode and spark get_historical_features.
-        Args:
-            context: SparkExecutionContext
-        Returns: DAGValue
-        """
-        fv = self.task.feature_view
-        source = fv.batch_source
-
-        (
-            join_key_columns,
-            feature_name_columns,
-            timestamp_field,
-            created_timestamp_column,
-        ) = context.column_info
-
-        # TODO: Use pull_all_from_table_or_query when it supports not filtering by timestamp
-        # retrieval_job = offline_store.pull_all_from_table_or_query(
-        #     config=context.repo_config,
-        #     data_source=source,
-        #     join_key_columns=join_key_columns,
-        #     feature_name_columns=feature_name_columns,
-        #     timestamp_field=timestamp_field,
-        #     start_date=min_ts,
-        #     end_date=max_ts,
-        # )
-        # spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df()
-
-        columns = join_key_columns + feature_name_columns + [timestamp_field]
-        if created_timestamp_column:
-            columns.append(created_timestamp_column)
-
-        (fields_with_aliases, aliases) = _get_fields_with_aliases(
-            fields=columns,
-            field_mappings=source.field_mapping,
-        )
-        fields_with_alias_string = ", ".join(fields_with_aliases)
-
-        from_expression = source.get_table_query_string()
-
-        query = f"""
-        SELECT {fields_with_alias_string}
-        FROM {from_expression}
-        """
-        spark_df = self.spark_session.sql(query)
-
-        return DAGValue(
-            data=spark_df,
-            format=DAGFormat.SPARK,
-            metadata={
-                "source": "feature_view_batch_source",
-                "timestamp_field": timestamp_field,
+                "start_date": self.start_time,
+                "end_date": self.end_time,
             },
         )

@@ -227,7 +162,12 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         feature_df: DataFrame = feature_value.data

         entity_df = context.entity_df
-        assert entity_df is not None, "entity_df must be set in ExecutionContext"
+        if entity_df is None:
+            return DAGValue(
+                data=feature_df,
+                format=DAGFormat.SPARK,
+                metadata={"joined_on": None},
+            )

         # Get timestamp fields from feature view
         join_keys, feature_cols, ts_col, created_ts_col = context.column_info
@@ -272,13 +212,13 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         if ENTITY_TS_ALIAS in input_df.columns:
             filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS))

-        # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl
-        if self.ttl:
-            ttl_seconds = int(self.ttl.total_seconds())
-            lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr(
-                f"INTERVAL {ttl_seconds} seconds"
-            )
-            filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound)
+            # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl
+            if self.ttl:
+                ttl_seconds = int(self.ttl.total_seconds())
+                lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr(
+                    f"INTERVAL {ttl_seconds} seconds"
+                )
+                filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound)

         # Optional custom filter condition
         if self.filter_condition:

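The TTL filter kept in SparkFilterNode is ordinary PySpark: subtract an interval from the entity event timestamp and keep feature rows inside that window, alongside the point-in-time bound. A minimal, self-contained sketch with made-up column names; it only assumes a local pyspark installation.

```python
from datetime import datetime, timedelta

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").appName("ttl-sketch").getOrCreate()

rows = [
    (datetime(2024, 1, 1, 10, 0), datetime(2024, 1, 1, 12, 0)),   # 2h old -> dropped with 1h TTL
    (datetime(2024, 1, 1, 11, 30), datetime(2024, 1, 1, 12, 0)),  # 30min old -> kept
]
df = spark.createDataFrame(rows, ["feature_ts", "entity_event_ts"])

ttl = timedelta(hours=1)
ttl_seconds = int(ttl.total_seconds())
lower_bound = F.col("entity_event_ts") - F.expr(f"INTERVAL {ttl_seconds} seconds")

filtered = df.filter(
    (F.col("feature_ts") <= F.col("entity_event_ts"))  # point-in-time correctness
    & (F.col("feature_ts") >= lower_bound)              # TTL window
)
filtered.show(truncate=False)
spark.stop()
```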
sdk/python/feast/infra/offline_stores/bigquery.py

Lines changed: 17 additions & 4 deletions
@@ -52,6 +52,7 @@
     BigQuerySource,
     SavedDatasetBigQueryStorage,
 )
+from .offline_utils import get_timestamp_filter_sql

 try:
     from google.api_core import client_info as http_client_info
@@ -188,8 +189,9 @@ def pull_all_from_table_or_query(
         join_key_columns: List[str],
         feature_name_columns: List[str],
         timestamp_field: str,
-        start_date: datetime,
-        end_date: datetime,
+        created_timestamp_column: Optional[str] = None,
+        start_date: Optional[datetime] = None,
+        end_date: Optional[datetime] = None,
     ) -> RetrievalJob:
         assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
         assert isinstance(data_source, BigQuerySource)
@@ -201,15 +203,26 @@ def pull_all_from_table_or_query(
             project=project_id,
             location=config.offline_store.location,
         )
+
+        timestamp_fields = [timestamp_field]
+        if created_timestamp_column:
+            timestamp_fields.append(created_timestamp_column)
         field_string = ", ".join(
             BigQueryOfflineStore._escape_query_columns(join_key_columns)
             + BigQueryOfflineStore._escape_query_columns(feature_name_columns)
-            + [timestamp_field]
+            + timestamp_fields
+        )
+        timestamp_filter = get_timestamp_filter_sql(
+            start_date,
+            end_date,
+            timestamp_field,
+            quote_fields=False,
+            cast_style="timestamp_func",
         )
         query = f"""
             SELECT {field_string}
             FROM {from_expression}
-            WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date}') AND TIMESTAMP('{end_date}')
+            WHERE {timestamp_filter}
        """
        return BigQueryRetrievalJob(
            query=query,

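Because start_date and end_date may now be None, the fixed BETWEEN clause had to give way to a helper (get_timestamp_filter_sql, imported above from offline_utils) that only constrains the bounds actually provided. The stand-in below illustrates the idea only; it is not the real helper, which also supports quoting and cast-style options not shown here.

```python
from datetime import datetime
from typing import Optional


def timestamp_filter_sketch(
    timestamp_field: str,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
) -> str:
    """Simplified stand-in for the idea behind get_timestamp_filter_sql."""
    if start_date and end_date:
        return (
            f"{timestamp_field} BETWEEN TIMESTAMP('{start_date}') "
            f"AND TIMESTAMP('{end_date}')"
        )
    if start_date:
        return f"{timestamp_field} >= TIMESTAMP('{start_date}')"
    if end_date:
        return f"{timestamp_field} <= TIMESTAMP('{end_date}')"
    return "TRUE"  # no time bounds: scan the whole table


print(timestamp_filter_sketch("event_timestamp"))
print(timestamp_filter_sketch("event_timestamp", end_date=datetime(2024, 2, 1)))
```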
0 commit comments