@@ -387,7 +387,6 @@ def _aggregate(
         ]
 
         groupby = self.keys_columns.copy()
-
         if window is not None:
             dataframe = dataframe.withColumn("window", window.get())
             groupby.append("window")
@@ -411,23 +410,19 @@ def _aggregate(
             "keep_rn", functions.row_number().over(partition_window)
         ).filter("keep_rn = 1")
 
-        current_partitions = dataframe.rdd.getNumPartitions()
-        optimal_partitions = num_processors or current_partitions
-
-        if current_partitions != optimal_partitions:
-            dataframe = repartition_df(
-                dataframe,
-                partition_by=groupby,
-                num_processors=optimal_partitions,
-            )
-
+        # repartition to have all rows for each group at the same partition
+        # by doing that, we won't have to shuffle data on grouping by id
+        dataframe = repartition_df(
+            dataframe,
+            partition_by=groupby,
+            num_processors=num_processors,
+        )
         grouped_data = dataframe.groupby(*groupby)
 
-        if self._pivot_column and self._pivot_values:
+        if self._pivot_column:
             grouped_data = grouped_data.pivot(self._pivot_column, self._pivot_values)
 
         aggregated = grouped_data.agg(*aggregations)
-
         return self._with_renamed_columns(aggregated, features, window)
 
     def _with_renamed_columns(
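
Note on this hunk: the diff now always repartitions by the grouping columns before the groupby, and it relaxes the pivot guard so only _pivot_column needs to be set (GroupedData.pivot accepts values=None, in which case Spark computes the distinct pivot values itself). A minimal PySpark sketch of both ideas, assuming repartition_df wraps something equivalent to DataFrame.repartition (the helper's internals are not shown in this diff, and the column names below are made up):

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [("a", "x", 1), ("a", "y", 2), ("b", "x", 3)], ["id", "category", "value"]
    )

    groupby = ["id"]
    num_processors = 4  # hypothetical parallelism hint

    # roughly what repartition_df(df, partition_by=groupby, num_processors=...)
    # is assumed to do: colocate each group's rows before the groupBy below
    df = df.repartition(num_processors, *groupby)

    # pivot with values=None: Spark infers the distinct "category" values, so
    # the pivot no longer requires explicit pivot values to be provided
    pivoted = df.groupBy(*groupby).pivot("category", None).agg(F.sum("value"))
    pivoted.show()
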
@@ -576,12 +571,14 @@ def construct(
 
         pre_hook_df = self.run_pre_hooks(dataframe)
 
-        output_df = pre_hook_df
-        for feature in self.keys + [self.timestamp]:
-            output_df = feature.transform(output_df)
+        output_df = reduce(
+            lambda df, feature: feature.transform(df),
+            self.keys + [self.timestamp],
+            pre_hook_df,
+        )
 
         if self._windows and end_date is not None:
-            # Run aggregations for each window
+            # run aggregations for each window
             agg_list = [
                 self._aggregate(
                     dataframe=output_df,
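
Note on the reduce refactor above: it is behaviourally equivalent to the removed for-loop, threading the DataFrame through each feature's transform in order. A tiny self-contained sketch of that chaining (the _Feature class below is a made-up stand-in, not butterfree's API; a list stands in for the DataFrame just to make the accumulation visible):

    from functools import reduce


    class _Feature:
        # made-up stand-in for a key/timestamp feature with a transform() method
        def __init__(self, name):
            self.name = name

        def transform(self, df):
            # a real feature would return a transformed Spark DataFrame
            return df + [self.name]


    keys, timestamp = [_Feature("id")], _Feature("timestamp")
    pre_hook_df = []

    output_df = reduce(
        lambda df, feature: feature.transform(df),
        keys + [timestamp],
        pre_hook_df,
    )
    assert output_df == ["id", "timestamp"]
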
@@ -601,12 +598,13 @@ def construct(
 
             # keeping this logic to maintain the same behavior for already implemented
             # feature sets
+
             if self._windows[0].slide == "1 day":
                 base_df = self._get_base_dataframe(
                     client=client, dataframe=output_df, end_date=end_date
                 )
 
-                # Left join each aggregation result to our base dataframe
+                # left join each aggregation result to our base dataframe
                 output_df = reduce(
                     lambda left, right: self._dataframe_join(
                         left,
@@ -639,18 +637,12 @@ def construct(
             output_df = output_df.select(*self.columns).replace(  # type: ignore
                 float("nan"), None
             )
-
-        if not output_df.isStreaming and self.deduplicate_rows:
-            output_df = self._filter_duplicated_rows(output_df)
+        if not output_df.isStreaming:
+            if self.deduplicate_rows:
+                output_df = self._filter_duplicated_rows(output_df)
+            if self.eager_evaluation:
+                output_df.cache().count()
 
         post_hook_df = self.run_post_hooks(output_df)
 
-        # Eager evaluation, only if needed and managable
-        if not output_df.isStreaming and self.eager_evaluation:
-            # Small dataframes only
-            if output_df.count() < 1_000_000:
-                post_hook_df.cache().count()
-            else:
-                post_hook_df.cache()  # Cache without materialization for large volumes
-
         return post_hook_df
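
Note on the last hunk: eager evaluation moves back next to deduplication and drops the row-count guard, so any non-streaming result with eager_evaluation enabled is cached and materialized before the post hooks run. A minimal sketch of what cache().count() does in PySpark (toy DataFrame, not the feature set's output):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    output_df = spark.range(10)  # toy stand-in for the constructed feature set

    if not output_df.isStreaming:
        # cache() only marks the plan for caching; the count() action
        # materializes it, so later consumers read from the cached data
        output_df.cache().count()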