@@ -1,13 +1,12 @@
-from datetime import timedelta
+from datetime import datetime, timedelta
 from typing import List, Optional, Union, cast
 
 from pyspark.sql import DataFrame, SparkSession, Window
 from pyspark.sql import functions as F
 
 from feast import BatchFeatureView, StreamFeatureView
 from feast.aggregation import Aggregation
-from feast.infra.common.materialization_job import MaterializationTask
-from feast.infra.common.retrieval_task import HistoricalRetrievalTask
+from feast.data_source import DataSource
 from feast.infra.compute_engines.dag.context import ExecutionContext
 from feast.infra.compute_engines.dag.model import DAGFormat
 from feast.infra.compute_engines.dag.node import DAGNode
@@ -23,7 +22,6 @@
 from feast.infra.offline_stores.offline_utils import (
     infer_event_timestamp_from_entity_df,
 )
-from feast.utils import _get_fields_with_aliases
 
 ENTITY_TS_ALIAS = "__entity_event_timestamp"
 
@@ -49,18 +47,21 @@ def rename_entity_ts_column(
     return entity_df
 
 
-class SparkMaterializationReadNode(DAGNode):
+class SparkReadNode(DAGNode):
     def __init__(
-        self, name: str, task: Union[MaterializationTask, HistoricalRetrievalTask]
+        self,
+        name: str,
+        source: DataSource,
+        start_time: Optional[datetime] = None,
+        end_time: Optional[datetime] = None,
     ):
         super().__init__(name)
-        self.task = task
+        self.source = source
+        self.start_time = start_time
+        self.end_time = end_time
 
     def execute(self, context: ExecutionContext) -> DAGValue:
         offline_store = context.offline_store
-        start_time = self.task.start_time
-        end_time = self.task.end_time
-
         (
             join_key_columns,
             feature_name_columns,
@@ -69,15 +70,15 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         ) = context.column_info
 
         # 📥 Reuse Feast's robust query resolver
-        retrieval_job = offline_store.pull_latest_from_table_or_query(
+        retrieval_job = offline_store.pull_all_from_table_or_query(
             config=context.repo_config,
-            data_source=self.task.feature_view.batch_source,
+            data_source=self.source,
             join_key_columns=join_key_columns,
             feature_name_columns=feature_name_columns,
             timestamp_field=timestamp_field,
             created_timestamp_column=created_timestamp_column,
-            start_date=start_time,
-            end_date=end_time,
+            start_date=self.start_time,
+            end_date=self.end_time,
         )
         spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df()
 
@@ -88,74 +89,8 @@ def execute(self, context: ExecutionContext) -> DAGValue:
                 "source": "feature_view_batch_source",
                 "timestamp_field": timestamp_field,
                 "created_timestamp_column": created_timestamp_column,
-                "start_date": start_time,
-                "end_date": end_time,
-            },
-        )
-
-
-class SparkHistoricalRetrievalReadNode(DAGNode):
-    def __init__(
-        self, name: str, task: HistoricalRetrievalTask, spark_session: SparkSession
-    ):
-        super().__init__(name)
-        self.task = task
-        self.spark_session = spark_session
-
-    def execute(self, context: ExecutionContext) -> DAGValue:
-        """
-        Read data from the offline store on the Spark engine.
-        TODO: Some functionality is duplicated with SparkMaterializationReadNode and spark get_historical_features.
-        Args:
-            context: SparkExecutionContext
-        Returns: DAGValue
-        """
-        fv = self.task.feature_view
-        source = fv.batch_source
-
-        (
-            join_key_columns,
-            feature_name_columns,
-            timestamp_field,
-            created_timestamp_column,
-        ) = context.column_info
-
-        # TODO: Use pull_all_from_table_or_query when it supports not filtering by timestamp
-        # retrieval_job = offline_store.pull_all_from_table_or_query(
-        #     config=context.repo_config,
-        #     data_source=source,
-        #     join_key_columns=join_key_columns,
-        #     feature_name_columns=feature_name_columns,
-        #     timestamp_field=timestamp_field,
-        #     start_date=min_ts,
-        #     end_date=max_ts,
-        # )
-        # spark_df = cast(SparkRetrievalJob, retrieval_job).to_spark_df()
-
-        columns = join_key_columns + feature_name_columns + [timestamp_field]
-        if created_timestamp_column:
-            columns.append(created_timestamp_column)
-
-        (fields_with_aliases, aliases) = _get_fields_with_aliases(
-            fields=columns,
-            field_mappings=source.field_mapping,
-        )
-        fields_with_alias_string = ", ".join(fields_with_aliases)
-
-        from_expression = source.get_table_query_string()
-
-        query = f"""
-            SELECT {fields_with_alias_string}
-            FROM {from_expression}
-        """
-        spark_df = self.spark_session.sql(query)
-
-        return DAGValue(
-            data=spark_df,
-            format=DAGFormat.SPARK,
-            metadata={
-                "source": "feature_view_batch_source",
-                "timestamp_field": timestamp_field,
+                "start_date": self.start_time,
+                "end_date": self.end_time,
             },
         )
 
@@ -227,7 +162,12 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         feature_df: DataFrame = feature_value.data
 
         entity_df = context.entity_df
-        assert entity_df is not None, "entity_df must be set in ExecutionContext"
+        if entity_df is None:
+            return DAGValue(
+                data=feature_df,
+                format=DAGFormat.SPARK,
+                metadata={"joined_on": None},
+            )
 
         # Get timestamp fields from feature view
         join_keys, feature_cols, ts_col, created_ts_col = context.column_info
@@ -272,13 +212,13 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         if ENTITY_TS_ALIAS in input_df.columns:
             filtered_df = filtered_df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS))
 
-        # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl
-        if self.ttl:
-            ttl_seconds = int(self.ttl.total_seconds())
-            lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr(
-                f"INTERVAL {ttl_seconds} seconds"
-            )
-            filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound)
+            # Optional TTL filter: feature.ts >= entity.event_timestamp - ttl
+            if self.ttl:
+                ttl_seconds = int(self.ttl.total_seconds())
+                lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr(
+                    f"INTERVAL {ttl_seconds} seconds"
+                )
+                filtered_df = filtered_df.filter(F.col(ts_col) >= lower_bound)
 
         # Optional custom filter condition
         if self.filter_condition:
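
For orientation, a minimal sketch of how the renamed SparkReadNode could be constructed with an explicit source and time window. The import paths, the SparkSource definition, and the time bounds are assumptions made for the example; the diff above only defines the node itself.

# Illustrative sketch only; module paths and the source definition are assumptions.
from datetime import datetime, timedelta

from feast.infra.compute_engines.spark.nodes import SparkReadNode  # module path assumed
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
    SparkSource,  # any batch DataSource would do here
)

source = SparkSource(
    name="driver_hourly_stats",            # hypothetical source
    path="data/driver_stats.parquet",
    file_format="parquet",
    timestamp_field="event_timestamp",
)

end_time = datetime.utcnow()
read_node = SparkReadNode(
    name="read_driver_hourly_stats",
    source=source,
    start_time=end_time - timedelta(days=7),  # optional; forwarded as start_date
    end_time=end_time,                        # optional; forwarded as end_date
)
# read_node.execute(context) returns a DAGValue whose .data is a Spark DataFrame
# pulled via offline_store.pull_all_from_table_or_query(...).

Because the node now takes a bare DataSource plus optional bounds instead of a MaterializationTask or HistoricalRetrievalTask, the same read path can serve both materialization and historical retrieval, which appears to be what allows SparkHistoricalRetrievalReadNode to be deleted above.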
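Below is a self-contained PySpark reproduction of the point-in-time plus TTL filter from the last hunk; the column names, sample rows, and one-day TTL are assumptions for illustration, not values from this PR.

# Standalone sketch of the point-in-time + TTL filter shown above.
from datetime import datetime, timedelta

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[*]").appName("ttl_filter_sketch").getOrCreate()

ENTITY_TS_ALIAS = "__entity_event_timestamp"
ts_col = "event_timestamp"  # feature timestamp column (assumed name)
ttl = timedelta(days=1)     # assumed TTL

rows = [
    ("d1", datetime(2024, 1, 2, 12, 0), datetime(2024, 1, 3, 0, 0)),   # inside [entity_ts - ttl, entity_ts]
    ("d1", datetime(2023, 12, 30, 0, 0), datetime(2024, 1, 3, 0, 0)),  # older than the TTL window -> dropped
    ("d1", datetime(2024, 1, 4, 0, 0), datetime(2024, 1, 3, 0, 0)),    # after the entity timestamp -> dropped
]
df = spark.createDataFrame(rows, ["driver_id", ts_col, ENTITY_TS_ALIAS])

# Point-in-time correctness: feature.ts <= entity.event_timestamp
filtered = df.filter(F.col(ts_col) <= F.col(ENTITY_TS_ALIAS))

# TTL lower bound: feature.ts >= entity.event_timestamp - ttl
ttl_seconds = int(ttl.total_seconds())
lower_bound = F.col(ENTITY_TS_ALIAS) - F.expr(f"INTERVAL {ttl_seconds} seconds")
filtered = filtered.filter(F.col(ts_col) >= lower_bound)

filtered.show()  # only the first row survives both filters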