@@ -18,6 +18,10 @@
 from feast.infra.compute_engines.dag.node import DAGNode
 from feast.infra.compute_engines.dag.value import DAGValue
 from feast.infra.compute_engines.ray.config import RayComputeEngineConfig
+from feast.infra.compute_engines.ray.utils import (
+    safe_batch_processor,
+    write_to_online_store,
+)
 from feast.infra.compute_engines.utils import create_offline_store_retrieval_job
 from feast.infra.ray_shared_utils import (
     apply_field_mapping,
@@ -149,9 +153,8 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         feature_df = feature_dataset.to_pandas()
         feature_ref = ray.put(feature_df)
 
+        @safe_batch_processor
         def join_with_aggregated_features(batch: pd.DataFrame) -> pd.DataFrame:
-            if batch.empty:
-                return batch
             features = ray.get(feature_ref)
             if join_keys:
                 result = pd.merge(
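The `safe_batch_processor` decorator imported above replaces the inline `if batch.empty` guards and the per-function `try/except` blocks that the rest of this diff deletes. Its implementation is not part of these hunks; the following is only a rough sketch of what it plausibly centralizes, assuming it mirrors exactly the removed logic (the real version in `feast/infra/compute_engines/ray/utils.py` may differ, e.g. re-raising instead of swallowing errors on write paths):

```python
import functools
import logging

import pandas as pd

logger = logging.getLogger(__name__)


def safe_batch_processor(func):
    """Skip empty batches; log failures instead of crashing the Ray task."""

    @functools.wraps(func)
    def wrapper(batch: pd.DataFrame, *args, **kwargs) -> pd.DataFrame:
        if batch.empty:  # the guard each node used to repeat inline
            return batch
        try:
            return func(batch, *args, **kwargs)
        except Exception as e:
            logger.error(f"{func.__name__} failed: {e}")
            return batch  # fall back to the unprocessed batch

    return wrapper
```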
@@ -226,10 +229,9 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         input_value.assert_format(DAGFormat.RAY)
         dataset: Dataset = input_value.data
 
+        @safe_batch_processor
         def apply_filters(batch: pd.DataFrame) -> pd.DataFrame:
             """Apply TTL and custom filters to the batch."""
-            if batch.empty:
-                return batch
 
             filtered_batch = batch.copy()
 
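For context on the body of `apply_filters` (truncated in this hunk): a TTL filter keeps only rows whose event timestamp falls within the view's TTL. A minimal, hypothetical illustration of that check, where the names `timestamp_col` and `ttl` are assumptions rather than taken from this hunk:

```python
import pandas as pd


def apply_ttl(batch: pd.DataFrame, timestamp_col: str, ttl: pd.Timedelta) -> pd.DataFrame:
    # Keep rows newer than now - ttl; assumes a tz-aware UTC timestamp column.
    cutoff = pd.Timestamp.now(tz="UTC") - ttl
    return batch[batch[timestamp_col] >= cutoff]
```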
@@ -447,11 +449,9 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         input_value.assert_format(DAGFormat.RAY)
         dataset: Dataset = input_value.data
 
+        @safe_batch_processor
         def deduplicate_batch(batch: pd.DataFrame) -> pd.DataFrame:
             """Remove duplicates from the batch."""
-            if batch.empty:
-                return batch
-
             # Get deduplication keys
             join_keys = self.column_info.join_keys
             timestamp_col = self.column_info.timestamp_column
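The deduplication logic that follows these keys is truncated in this hunk; the conventional implementation keeps the latest row per entity. A sketch of that core step, assuming the usual sort-then-drop pattern:

```python
import pandas as pd


def latest_per_entity(
    batch: pd.DataFrame, join_keys: list, timestamp_col: str
) -> pd.DataFrame:
    # Sort newest-first, then keep the first (latest) row for each entity key.
    return (
        batch.sort_values(timestamp_col, ascending=False)
        .drop_duplicates(subset=join_keys, keep="first")
        .reset_index(drop=True)
    )
```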
@@ -518,27 +518,21 @@ def execute(self, context: ExecutionContext) -> DAGValue:
         elif callable(self.transformation):
             transformation_serialized = dill.dumps(self.transformation)
 
+            @safe_batch_processor
             def apply_transformation_with_serialized_udf(
                 batch: pd.DataFrame,
             ) -> pd.DataFrame:
                 """Apply the transformation using pre-serialized UDF."""
-                if batch.empty:
-                    return batch
-
-                try:
-                    if transformation_serialized:
-                        transformation_func = dill.loads(transformation_serialized)
-                        transformed_batch = transformation_func(batch)
-                    else:
-                        logger.warning(
-                            "No serialized transformation available, returning original batch"
-                        )
-                        transformed_batch = batch
+                if transformation_serialized:
+                    transformation_func = dill.loads(transformation_serialized)
+                    transformed_batch = transformation_func(batch)
+                else:
+                    logger.warning(
+                        "No serialized transformation available, returning original batch"
+                    )
+                    transformed_batch = batch
 
-                    return transformed_batch
-                except Exception as e:
-                    logger.error(f"Transformation failed: {e}")
-                    return batch
+                return transformed_batch
 
             transformed_dataset = dataset.map_batches(
                 apply_transformation_with_serialized_udf, batch_format="pandas"
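The `dill.dumps`/`dill.loads` round-trip here matters because Ray ships the UDF to worker processes as bytes, and `dill` (unlike stdlib `pickle`) can serialize closures and lambdas. A self-contained illustration of the same pattern against Ray Data (hypothetical data and transform, same API calls):

```python
import dill
import pandas as pd
import ray


def make_transform(multiplier: int):
    # A closure: stdlib pickle cannot serialize this lambda, dill can.
    return lambda df: df.assign(value=df["value"] * multiplier)


payload = dill.dumps(make_transform(2))  # bytes travel to workers, not the closure


def apply_udf(batch: pd.DataFrame) -> pd.DataFrame:
    # Deserialize per batch on the worker, as the node above does.
    return dill.loads(payload)(batch)


ds = ray.data.from_pandas(pd.DataFrame({"value": [1, 2, 3]}))
print(ds.map_batches(apply_udf, batch_format="pandas").to_pandas())
```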
@@ -645,46 +639,36 @@ def execute(self, context: ExecutionContext) -> DAGValue:
             feature_view=self.feature_view, repo_config=context.repo_config
         )
 
+        @safe_batch_processor
         def write_batch_with_serialized_artifacts(batch: pd.DataFrame) -> pd.DataFrame:
             """Write each batch using pre-serialized artifacts."""
-            if batch.empty:
-                return batch
-
-            try:
-                (
-                    feature_view,
-                    online_store,
-                    offline_store,
-                    repo_config,
-                ) = serialized_artifacts.unserialize()
-
-                arrow_table = pa.Table.from_pandas(batch)
-
-                # Write to online store if enabled
-                if getattr(feature_view, "online", False):
-                    # TODO: Implement proper online store writing with correct data format conversion
-                    logger.debug(
-                        "Online store writing not implemented yet for Ray compute engine"
-                    )
-
-                # Write to offline store if enabled
-                if getattr(feature_view, "offline", False):
-                    try:
-                        offline_store.offline_write_batch(
-                            config=repo_config,
-                            feature_view=feature_view,
-                            table=arrow_table,
-                            progress=lambda x: None,
-                        )
-                    except Exception as e:
-                        logger.error(f"Failed to write to offline store: {e}")
-                        raise
+            (
+                feature_view,
+                online_store,
+                offline_store,
+                repo_config,
+            ) = serialized_artifacts.unserialize()
+
+            arrow_table = pa.Table.from_pandas(batch)
+
+            # Write to online store if enabled
+            write_to_online_store(
+                arrow_table=arrow_table,
+                feature_view=feature_view,
+                online_store=online_store,
+                repo_config=repo_config,
+            )
 
-                return batch
+            # Write to offline store if enabled
+            if getattr(feature_view, "offline", False):
+                offline_store.offline_write_batch(
+                    config=repo_config,
+                    feature_view=feature_view,
+                    table=arrow_table,
+                    progress=lambda x: None,
+                )
 
-            except Exception as e:
-                logger.error(f"Write operation failed: {e}")
-                raise
+            return batch
 
         written_dataset = dataset.map_batches(
             write_batch_with_serialized_artifacts, batch_format="pandas"
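The new `write_to_online_store` helper replaces the old TODO/`logger.debug` stub, so online writes now actually happen instead of being logged and skipped. Its body is not part of this diff; the following is a sketch of what it plausibly does, assuming it mirrors Feast's existing `PassthroughProvider.ingest_df` conversion path (the real helper in `feast/infra/compute_engines/ray/utils.py` may differ):

```python
from feast.utils import _convert_arrow_to_proto


def write_to_online_store(arrow_table, feature_view, online_store, repo_config):
    # Respect the same "online" flag the old inline code checked.
    if not getattr(feature_view, "online", False):
        return
    # Map each entity column to its value type, as the ingest path does.
    join_keys = {
        entity.name: entity.dtype.to_value_type()
        for entity in feature_view.entity_columns
    }
    rows = _convert_arrow_to_proto(arrow_table, feature_view, join_keys)
    online_store.online_write_batch(
        config=repo_config,
        table=feature_view,
        data=rows,
        progress=lambda x: None,
    )
```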