Skip to content

Commit b9ac90b

Browse files
authored
feat: Support compute engine to use multiple feature views as source (#5482)
* Draft: multi source support Signed-off-by: HaoXuAI <[email protected]> * Draft: multi source support Signed-off-by: HaoXuAI <[email protected]> * Checkpoint Signed-off-by: HaoXuAI <[email protected]> * Checkpoint Signed-off-by: HaoXuAI <[email protected]> * Checkpoint Signed-off-by: HaoXuAI <[email protected]> * fix testing Signed-off-by: HaoXuAI <[email protected]> * fix testing Signed-off-by: HaoXuAI <[email protected]> * fix testing Signed-off-by: HaoXuAI <[email protected]> * fix testing Signed-off-by: HaoXuAI <[email protected]> * fix testing Signed-off-by: HaoXuAI <[email protected]> * fix testing Signed-off-by: HaoXuAI <[email protected]> * fix testing Signed-off-by: HaoXuAI <[email protected]> --------- Signed-off-by: HaoXuAI <[email protected]>
1 parent 0af6e94 commit b9ac90b

File tree

30 files changed

+1254
-498
lines changed

30 files changed

+1254
-498
lines changed

protos/feast/core/FeatureView.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ message FeatureViewSpec {
7979

8080
// Whether these features should be written to the offline store
8181
bool offline = 13;
82+
83+
repeated FeatureViewSpec source_views = 14;
8284
}
8385

8486
message FeatureViewMeta {

sdk/python/feast/batch_feature_view.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class BatchFeatureView(FeatureView):
5353
entities: List[str]
5454
ttl: Optional[timedelta]
5555
source: DataSource
56+
sink_source: Optional[DataSource] = None
5657
schema: List[Field]
5758
entity_columns: List[Field]
5859
features: List[Field]
@@ -65,7 +66,7 @@ class BatchFeatureView(FeatureView):
6566
materialization_intervals: List[Tuple[datetime, datetime]]
6667
udf: Optional[Callable[[Any], Any]]
6768
udf_string: Optional[str]
68-
feature_transformation: Transformation
69+
feature_transformation: Optional[Transformation]
6970
batch_engine: Optional[Field]
7071
aggregations: Optional[List[Aggregation]]
7172

@@ -74,7 +75,8 @@ def __init__(
7475
*,
7576
name: str,
7677
mode: Union[TransformationMode, str] = TransformationMode.PYTHON,
77-
source: DataSource,
78+
source: Union[DataSource, "BatchFeatureView", List["BatchFeatureView"]],
79+
sink_source: Optional[DataSource] = None,
7880
entities: Optional[List[Entity]] = None,
7981
ttl: Optional[timedelta] = None,
8082
tags: Optional[Dict[str, str]] = None,
@@ -83,7 +85,7 @@ def __init__(
8385
description: str = "",
8486
owner: str = "",
8587
schema: Optional[List[Field]] = None,
86-
udf: Optional[Callable[[Any], Any]],
88+
udf: Optional[Callable[[Any], Any]] = None,
8789
udf_string: Optional[str] = "",
8890
feature_transformation: Optional[Transformation] = None,
8991
batch_engine: Optional[Field] = None,
@@ -96,7 +98,7 @@ def __init__(
9698
RuntimeWarning,
9799
)
98100

99-
if (
101+
if isinstance(source, DataSource) and (
100102
type(source).__name__ not in SUPPORTED_BATCH_SOURCES
101103
and source.to_proto().type != DataSourceProto.SourceType.CUSTOM_SOURCE
102104
):
@@ -124,14 +126,13 @@ def __init__(
124126
description=description,
125127
owner=owner,
126128
schema=schema,
127-
source=source,
129+
source=source, # type: ignore[arg-type]
130+
sink_source=sink_source,
128131
)
129132

130-
def get_feature_transformation(self) -> Transformation:
133+
def get_feature_transformation(self) -> Optional[Transformation]:
131134
if not self.udf:
132-
raise ValueError(
133-
"Either a UDF or a feature transformation must be provided for BatchFeatureView"
134-
)
135+
return None
135136
if self.mode in (
136137
TransformationMode.PANDAS,
137138
TransformationMode.PYTHON,

sdk/python/feast/feature_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,7 @@ def apply(
922922
for fv in itertools.chain(
923923
views_to_update, sfvs_to_update, odfvs_with_writes_to_update
924924
):
925-
if isinstance(fv, FeatureView):
925+
if isinstance(fv, FeatureView) and fv.batch_source:
926926
data_sources_set_to_update.add(fv.batch_source)
927927
if hasattr(fv, "stream_source"):
928928
if fv.stream_source:

sdk/python/feast/feature_view.py

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import copy
1515
import warnings
1616
from datetime import datetime, timedelta
17-
from typing import Dict, List, Optional, Tuple, Type
17+
from typing import Dict, List, Optional, Tuple, Type, Union
1818

1919
from google.protobuf.duration_pb2 import Duration
2020
from google.protobuf.message import Message
@@ -90,6 +90,7 @@ class FeatureView(BaseFeatureView):
9090
ttl: Optional[timedelta]
9191
batch_source: DataSource
9292
stream_source: Optional[DataSource]
93+
source_views: Optional[List["FeatureView"]]
9394
entity_columns: List[Field]
9495
features: List[Field]
9596
online: bool
@@ -103,7 +104,8 @@ def __init__(
103104
self,
104105
*,
105106
name: str,
106-
source: DataSource,
107+
source: Union[DataSource, "FeatureView", List["FeatureView"]],
108+
sink_source: Optional[DataSource] = None,
107109
schema: Optional[List[Field]] = None,
108110
entities: Optional[List[Entity]] = None,
109111
ttl: Optional[timedelta] = timedelta(days=0),
@@ -144,22 +146,45 @@ def __init__(
144146
self.ttl = ttl
145147
schema = schema or []
146148

147-
# Initialize data sources.
149+
# Normalize source
150+
self.stream_source = None
151+
self.data_source: Optional[DataSource] = None
152+
self.source_views: List[FeatureView] = []
153+
154+
if isinstance(source, DataSource):
155+
self.data_source = source
156+
elif isinstance(source, FeatureView):
157+
self.source_views = [source]
158+
elif isinstance(source, list) and all(
159+
isinstance(sv, FeatureView) for sv in source
160+
):
161+
self.source_views = source
162+
else:
163+
raise TypeError(
164+
"source must be a DataSource, a FeatureView, or a list of FeatureView."
165+
)
166+
167+
# Set up stream, batch and derived view sources
148168
if (
149-
isinstance(source, PushSource)
150-
or isinstance(source, KafkaSource)
151-
or isinstance(source, KinesisSource)
169+
isinstance(self.data_source, PushSource)
170+
or isinstance(self.data_source, KafkaSource)
171+
or isinstance(self.data_source, KinesisSource)
152172
):
153-
self.stream_source = source
154-
if not source.batch_source:
173+
# Stream source definition
174+
self.stream_source = self.data_source
175+
if not self.data_source.batch_source:
155176
raise ValueError(
156-
f"A batch_source needs to be specified for stream source `{source.name}`"
177+
f"A batch_source needs to be specified for stream source `{self.data_source.name}`"
157178
)
158-
else:
159-
self.batch_source = source.batch_source
179+
self.batch_source = self.data_source.batch_source
180+
elif self.data_source:
181+
# Batch source definition
182+
self.batch_source = self.data_source
160183
else:
161-
self.stream_source = None
162-
self.batch_source = source
184+
# Derived view source definition
185+
if not sink_source:
186+
raise ValueError("Derived FeatureView must specify `sink_source`.")
187+
self.batch_source = sink_source
163188

164189
# Initialize features and entity columns.
165190
features: List[Field] = []
@@ -201,25 +226,26 @@ def __init__(
201226
)
202227

203228
# TODO(felixwang9817): Add more robust validation of features.
204-
cols = [field.name for field in schema]
205-
for col in cols:
206-
if (
207-
self.batch_source.field_mapping is not None
208-
and col in self.batch_source.field_mapping.keys()
209-
):
210-
raise ValueError(
211-
f"The field {col} is mapped to {self.batch_source.field_mapping[col]} for this data source. "
212-
f"Please either remove this field mapping or use {self.batch_source.field_mapping[col]} as the "
213-
f"Entity or Feature name."
214-
)
229+
if self.batch_source is not None:
230+
cols = [field.name for field in schema]
231+
for col in cols:
232+
if (
233+
self.batch_source.field_mapping is not None
234+
and col in self.batch_source.field_mapping.keys()
235+
):
236+
raise ValueError(
237+
f"The field {col} is mapped to {self.batch_source.field_mapping[col]} for this data source. "
238+
f"Please either remove this field mapping or use {self.batch_source.field_mapping[col]} as the "
239+
f"Entity or Feature name."
240+
)
215241

216242
super().__init__(
217243
name=name,
218244
features=features,
219245
description=description,
220246
tags=tags,
221247
owner=owner,
222-
source=source,
248+
source=self.batch_source,
223249
)
224250
self.online = online
225251
self.offline = offline
@@ -348,13 +374,18 @@ def to_proto(self) -> FeatureViewProto:
348374
meta = self.to_proto_meta()
349375
ttl_duration = self.get_ttl_duration()
350376

351-
batch_source_proto = self.batch_source.to_proto()
352-
batch_source_proto.data_source_class_type = f"{self.batch_source.__class__.__module__}.{self.batch_source.__class__.__name__}"
377+
batch_source_proto = None
378+
if self.batch_source:
379+
batch_source_proto = self.batch_source.to_proto()
380+
batch_source_proto.data_source_class_type = f"{self.batch_source.__class__.__module__}.{self.batch_source.__class__.__name__}"
353381

354382
stream_source_proto = None
355383
if self.stream_source:
356384
stream_source_proto = self.stream_source.to_proto()
357385
stream_source_proto.data_source_class_type = f"{self.stream_source.__class__.__module__}.{self.stream_source.__class__.__name__}"
386+
source_view_protos = None
387+
if self.source_views:
388+
source_view_protos = [view.to_proto().spec for view in self.source_views]
358389
spec = FeatureViewSpecProto(
359390
name=self.name,
360391
entities=self.entities,
@@ -368,6 +399,7 @@ def to_proto(self) -> FeatureViewProto:
368399
offline=self.offline,
369400
batch_source=batch_source_proto,
370401
stream_source=stream_source_proto,
402+
source_views=source_view_protos,
371403
)
372404

373405
return FeatureViewProto(spec=spec, meta=meta)
@@ -403,12 +435,21 @@ def from_proto(cls, feature_view_proto: FeatureViewProto):
403435
Returns:
404436
A FeatureViewProto object based on the feature view protobuf.
405437
"""
406-
batch_source = DataSource.from_proto(feature_view_proto.spec.batch_source)
438+
batch_source = (
439+
DataSource.from_proto(feature_view_proto.spec.batch_source)
440+
if feature_view_proto.spec.HasField("batch_source")
441+
else None
442+
)
407443
stream_source = (
408444
DataSource.from_proto(feature_view_proto.spec.stream_source)
409445
if feature_view_proto.spec.HasField("stream_source")
410446
else None
411447
)
448+
source_views = [
449+
FeatureView.from_proto(FeatureViewProto(spec=view_spec, meta=None))
450+
for view_spec in feature_view_proto.spec.source_views
451+
]
452+
412453
feature_view = cls(
413454
name=feature_view_proto.spec.name,
414455
description=feature_view_proto.spec.description,
@@ -421,7 +462,7 @@ def from_proto(cls, feature_view_proto: FeatureViewProto):
421462
if feature_view_proto.spec.ttl.ToNanoseconds() == 0
422463
else feature_view_proto.spec.ttl.ToTimedelta()
423464
),
424-
source=batch_source,
465+
source=batch_source if batch_source else source_views,
425466
)
426467
if stream_source:
427468
feature_view.stream_source = stream_source

sdk/python/feast/inference.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def update_data_sources_with_inferred_event_timestamp_col(
2626
) -> None:
2727
ERROR_MSG_PREFIX = "Unable to infer DataSource timestamp_field"
2828
for data_source in data_sources:
29+
if data_source is None:
30+
continue
2931
if isinstance(data_source, RequestSource):
3032
continue
3133
if isinstance(data_source, PushSource):

sdk/python/feast/infra/compute_engines/algorithms/__init__.py

Whitespace-only changes.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from typing import List, Set
2+
3+
from feast.infra.compute_engines.dag.node import DAGNode
4+
5+
6+
def topo_sort(root: DAGNode) -> List[DAGNode]:
    """Return every node reachable from *root*, dependencies first.

    Thin convenience wrapper over :func:`topo_sort_multiple` for the
    common single-root case.

    Args:
        root: The root DAGNode of the graph to sort.

    Returns:
        Reachable DAGNodes in topological (execution-safe) order, i.e.
        each node appears only after all of its inputs.
    """
    return topo_sort_multiple([root])
17+
18+
19+
def topo_sort_multiple(roots: List[DAGNode]) -> List[DAGNode]:
    """
    Topologically sort a DAG with multiple roots.

    Performs an iterative post-order depth-first traversal: a node is
    emitted only after all of its ``inputs`` (dependencies) have been
    emitted. Nodes are deduplicated by identity (``id``), so shared
    dependencies appear exactly once.

    Args:
        roots: List of root DAGNodes.

    Returns:
        A list of all reachable DAGNodes in execution-safe order
        (dependencies first).
    """
    seen: Set[int] = set()
    result: List[DAGNode] = []

    for start in roots:
        if id(start) in seen:
            continue
        seen.add(id(start))
        # Explicit stack of (node, pending-children iterator) pairs; a node
        # is appended to the result once its iterator is exhausted, which
        # reproduces the post-order of a recursive DFS without recursion.
        stack = [(start, iter(start.inputs))]
        while stack:
            node, pending = stack[-1]
            for child in pending:
                if id(child) not in seen:
                    seen.add(id(child))
                    stack.append((child, iter(child.inputs)))
                    break
            else:
                result.append(node)
                stack.pop()

    return result

sdk/python/feast/infra/compute_engines/base.py

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import List, Optional, Sequence, Union
2+
from typing import List, Sequence, Union
33

44
import pyarrow as pa
55

@@ -12,13 +12,12 @@
1212
MaterializationTask,
1313
)
1414
from feast.infra.common.retrieval_task import HistoricalRetrievalTask
15-
from feast.infra.compute_engines.dag.context import ColumnInfo, ExecutionContext
15+
from feast.infra.compute_engines.dag.context import ExecutionContext
1616
from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob
1717
from feast.infra.online_stores.online_store import OnlineStore
1818
from feast.infra.registry.base_registry import BaseRegistry
1919
from feast.on_demand_feature_view import OnDemandFeatureView
2020
from feast.stream_feature_view import StreamFeatureView
21-
from feast.utils import _get_column_names
2221

2322

2423
class ComputeEngine(ABC):
@@ -124,52 +123,11 @@ def get_execution_context(
124123
if hasattr(task, "entity_df") and task.entity_df is not None:
125124
entity_df = task.entity_df
126125

127-
column_info = self.get_column_info(registry, task)
128126
return ExecutionContext(
129127
project=task.project,
130128
repo_config=self.repo_config,
131129
offline_store=self.offline_store,
132130
online_store=self.online_store,
133131
entity_defs=entity_defs,
134-
column_info=column_info,
135132
entity_df=entity_df,
136133
)
137-
138-
def get_column_info(
139-
self,
140-
registry: BaseRegistry,
141-
task: Union[MaterializationTask, HistoricalRetrievalTask],
142-
) -> ColumnInfo:
143-
entities = []
144-
for entity_name in task.feature_view.entities:
145-
entities.append(registry.get_entity(entity_name, task.project))
146-
147-
join_keys, feature_cols, ts_col, created_ts_col = _get_column_names(
148-
task.feature_view, entities
149-
)
150-
field_mapping = self.get_field_mapping(task.feature_view)
151-
152-
return ColumnInfo(
153-
join_keys=join_keys,
154-
feature_cols=feature_cols,
155-
ts_col=ts_col,
156-
created_ts_col=created_ts_col,
157-
field_mapping=field_mapping,
158-
)
159-
160-
def get_field_mapping(
161-
self, feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView]
162-
) -> Optional[dict]:
163-
"""
164-
Get the field mapping for a feature view.
165-
Args:
166-
feature_view: The feature view to get the field mapping for.
167-
168-
Returns:
169-
A dictionary mapping field names to column names.
170-
"""
171-
if feature_view.stream_source:
172-
return feature_view.stream_source.field_mapping
173-
if feature_view.batch_source:
174-
return feature_view.batch_source.field_mapping
175-
return None

0 commit comments

Comments
 (0)