-from datetime import timedelta
+from datetime import datetime, timedelta
 from typing import List, Optional, Union, cast

 from pyspark.sql import DataFrame, SparkSession, Window
 from pyspark.sql import functions as F

 from feast import BatchFeatureView, StreamFeatureView
 from feast.aggregation import Aggregation
-from feast.infra.common.materialization_job import MaterializationTask
-from feast.infra.common.retrieval_task import HistoricalRetrievalTask
+from feast.data_source import DataSource
 from feast.infra.compute_engines.dag.context import ExecutionContext
 from feast.infra.compute_engines.dag.model import DAGFormat
 from feast.infra.compute_engines.dag.node import DAGNode
 from feast.infra.offline_stores.offline_utils import (
     infer_event_timestamp_from_entity_df,
 )
-from feast.utils import _get_fields_with_aliases

 ENTITY_TS_ALIAS = "__entity_event_timestamp"

@@ -49,18 +47,21 @@ def rename_entity_ts_column(
     return entity_df


-class SparkMaterializationReadNode(DAGNode):
+class SparkReadNode(DAGNode):
     def __init__(
-        self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask]
+        self,
+        name: str,
+        source: DataSource,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
     ):
         super().__init__(name)
-        self.task = task
+        self.source = source
+        self.start_time = start_time
+        self.end_time = end_time

     def execute(self, context: ExecutionContext) -> DAGValue:
         offline_store = context.offline_store
-        start_time = self.task.start_time
-        end_time = self.task.end_time
-
         (
             join_key_columns,
             feature_name_columns,
@@ -69,15 +70,15 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         ) = context.column_info

         # 📥 Reuse Feast's robust query resolver
-        retrieval_job = offline_store.pull_latest_from_table_or_query(
+        retrieval_job = offline_store.pull_all_from_table_or_query(
             config=context.repo_config,
-            data_source=self.task.feature_view.batch_source,
+            data_source=self.source,
             join_key_columns=join_key_columns,
             feature_name_columns=feature_name_columns,
             timestamp_field=timestamp_field,
             created_timestamp_column=created_timestamp_column,
-            start_date=start_time,
-            end_date=end_time,
+            start_date=self.start_time,
+            end_date=self.end_time,
         )
         spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df()

@@ -88,74 +89,8 @@ def execute(self, context: ExecutionContext) -> DAGValue:
                 "source": "feature_view_batch_source",
                 "timestamp_field": timestamp_field,
                 "created_timestamp_column": created_timestamp_column,
-                "start_date": start_time,
-                "end_date": end_time,
-            },
-        )
-
-
-class SparkHistoricalRetrievalReadNode(DAGNode):
-    def __init__(
-        self, name: str, task: HistoricalRetrievalTask, spark_session: SparkSession
-    ):
-        super().__init__(name)
-        self.task = task
-        self.spark_session = spark_session
-
-    def execute(self, context: ExecutionContext) -> DAGValue:
-        """
-        Read data from the offline store on the Spark engine.
-        TODO: Some functionality is duplicated with SparkMaterializationReadNode and spark get_historical_features.
-        Args:
-            context: SparkExecutionContext
-        Returns: DAGValue
-        """
-        fv = self.task.feature_view
-        source = fv.batch_source
-
-        (
-            join_key_columns,
-            feature_name_columns,
-            timestamp_field,
-            created_timestamp_column,
-        ) = context.column_info
-
-        # TODO: Use pull_all_from_table_or_query when it supports not filtering by timestamp
-        # retrieval_job = offline_store.pull_all_from_table_or_query(
-        #     config=context.repo_config,
-        #     data_source=source,
-        #     join_key_columns=join_key_columns,
-        #     feature_name_columns=feature_name_columns,
-        #     timestamp_field=timestamp_field,
-        #     start_date=min_ts,
-        #     end_date=max_ts,
-        # )
-        # spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df()
-
-        columns = join_key_columns + feature_name_columns + [timestamp_field]
-        if created_timestamp_column:
-            columns.append(created_timestamp_column)
-
-        (fields_with_aliases, aliases) = _get_fields_with_aliases(
-            fields=columns,
-            field_mappings=source.field_mapping,
-        )
-        fields_with_alias_string = ", ".join(fields_with_aliases)
-
-        from_expression = source.get_table_query_string()
-
-        query = f"""
-        SELECT {fields_with_alias_string}
-        FROM {from_expression}
-        """
-        spark_df = self.spark_session.sql(query)
-
-        return DAGValue(
-            data=spark_df,
-            format=DAGFormat.SPARK,
-            metadata={
-                "source": "feature_view_batch_source",
-                "timestamp_field": timestamp_field,
+                "start_date": self.start_time,
+                "end_date": self.end_time,
             },
         )

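With the two read nodes collapsed into a single SparkReadNode, the time window is an explicit constructor argument instead of something pulled off a task object. Below is a minimal sketch of how a compute engine might build the node for materialization versus historical retrieval; build_read_node is a hypothetical helper, not part of Feast.

# Hypothetical helper, not Feast API: illustrates constructing the unified SparkReadNode.
def build_read_node(feature_view, start_time=None, end_time=None) -> SparkReadNode:
    # Materialization passes an explicit start_time/end_time window; historical
    # retrieval can leave both as None so pull_all_from_table_or_query reads the
    # full source and later DAG nodes apply point-in-time filtering.
    return SparkReadNode(
        name=f"read:{feature_view.name}",
        source=feature_view.batch_source,
        start_time=start_time,
        end_time=end_time,
    )
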
@@ -227,7 +162,12 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         feature_df: DataFrame = feature_value.data

         entity_df = context.entity_df
-        assert entity_df is not None, "entity_df must be set in ExecutionContext"
+        if entity_df is None:
+            return DAGValue(
+                data=feature_df,
+                format=DAGFormat.SPARK,
+                metadata={"joined_on": None},
+            )

         # Get timestamp fields from feature view
         join_keys, feature_cols, ts_col, created_ts_col = context.column_info
@@ -272,13 +212,13 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         if ENTITY_TS_ALIAS in input_df.columns:
             filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS))

-        # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl
-        if self.ttl:
-            ttl_seconds = int(self.ttl.total_seconds())
-            lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr(
-                f"INTERVAL {ttl_seconds} seconds"
-            )
-            filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound)
+        # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl
+        if self.ttl:
+            ttl_seconds = int(self.ttl.total_seconds())
+            lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr(
+                f"INTERVAL {ttl_seconds} seconds"
+            )
+            filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound)

         # Optional custom filter condition
         if self.filter_condition:
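
For reference, the TTL bound can be exercised on its own. Below is a minimal sketch mirroring the block above, assuming a one-day timedelta and the "__entity_event_timestamp" alias defined at the top of the file; the event-timestamp column name is illustrative.

from datetime import timedelta

from pyspark.sql import functions as F

ttl = timedelta(days=1)
ttl_seconds = int(ttl.total_seconds())  # 86400
# Keep feature rows no older than the entity timestamp minus the TTL.
lower_bound = F.col("__entity_event_timestamp") - F.expr(f"INTERVAL {ttl_seconds} seconds")
# Applied as: filtered_df = filtered_df.filter(F.col("event_timestamp") >= lower_bound)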