
Commit 11cc5d5

fix: performance improvements (#374)
1 parent f6c5db6 commit 11cc5d5

File tree: 6 files changed (+64, -33 lines)


butterfree/_cli/migrate.py

Lines changed: 12 additions & 2 deletions
@@ -4,7 +4,7 @@
 import os
 import pkgutil
 import sys
-from typing import Set
+from typing import Set, Type
 
 import boto3
 import setuptools
@@ -90,8 +90,18 @@ def __fs_objects(path: str) -> Set[FeatureSetPipeline]:
 
             instances.add(value)
 
+    def create_instance(cls: Type[FeatureSetPipeline]) -> FeatureSetPipeline:
+        sig = inspect.signature(cls.__init__)
+        parameters = sig.parameters
+
+        if "run_date" in parameters:
+            run_date = datetime.datetime.today().strftime("%y-%m-%d")
+            return cls(run_date)
+
+        return cls()
+
     logger.info("Creating instances...")
-    return set(value() for value in instances)  # type: ignore
+    return set(create_instance(value) for value in instances)  # type: ignore
 
 
 PATH = typer.Argument(
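
Note: the new create_instance helper decides how to instantiate each discovered pipeline class by inspecting its constructor signature. A minimal standalone sketch of the same idea, using two hypothetical toy classes (LegacyPipeline and DatedPipeline are illustration only, not part of butterfree):

import datetime
import inspect


class LegacyPipeline:
    """Toy pipeline whose constructor takes no arguments."""

    def __init__(self) -> None:
        self.run_date = None


class DatedPipeline:
    """Toy pipeline whose constructor expects a run_date string."""

    def __init__(self, run_date: str) -> None:
        self.run_date = run_date


def create_instance(cls):
    # Look at the constructor's parameters to decide how to call it.
    parameters = inspect.signature(cls.__init__).parameters
    if "run_date" in parameters:
        run_date = datetime.datetime.today().strftime("%y-%m-%d")
        return cls(run_date)
    return cls()


print(create_instance(LegacyPipeline).run_date)  # None
print(create_instance(DatedPipeline).run_date)   # e.g. "24-05-31"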

butterfree/extract/source.py

Lines changed: 10 additions & 4 deletions
@@ -3,6 +3,7 @@
 from typing import List, Optional
 
 from pyspark.sql import DataFrame
+from pyspark.storagelevel import StorageLevel
 
 from butterfree.clients import SparkClient
 from butterfree.extract.readers.reader import Reader
@@ -95,16 +96,21 @@ def construct(
             DataFrame with the query result against all readers.
 
         """
+        # Step 1: Build temporary views for each reader
         for reader in self.readers:
-            reader.build(
-                client=client, start_date=start_date, end_date=end_date
-            )  # create temporary views for each reader
+            reader.build(client=client, start_date=start_date, end_date=end_date)
 
+        # Step 2: Execute SQL query on the combined readers
         dataframe = client.sql(self.query)
 
+        # Step 3: Cache the dataframe if necessary, using memory and disk storage
         if not dataframe.isStreaming and self.eager_evaluation:
-            dataframe.cache().count()
+            # Persist to ensure the DataFrame is stored in mem and disk (if necessary)
+            dataframe.persist(StorageLevel.MEMORY_AND_DISK)
+            # Trigger the cache/persist operation by performing an action
+            dataframe.count()
 
+        # Step 4: Run post-processing hooks on the dataframe
        post_hook_df = self.run_post_hooks(dataframe)
 
        return post_hook_df
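
Note: the eager-evaluation branch now calls persist(StorageLevel.MEMORY_AND_DISK) followed by count() instead of cache().count(), making the storage level explicit and letting partitions that do not fit in memory spill to disk. A minimal sketch of the pattern in isolation (the local SparkSession and toy data are assumptions for illustration):

from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel

spark = SparkSession.builder.master("local[*]").appName("persist-sketch").getOrCreate()
df = spark.range(1_000_000)

# persist() is lazy: it only marks the DataFrame for caching.
df.persist(StorageLevel.MEMORY_AND_DISK)

# An action such as count() runs the job and materializes the cached blocks,
# spilling partitions that do not fit in memory to local disk.
df.count()

# Later actions reuse the cached data instead of recomputing the plan.
df.groupBy((df.id % 10).alias("bucket")).count().show()

# Release the memory/disk blocks once the cached data is no longer needed.
df.unpersist()
spark.stop()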

butterfree/pipelines/feature_set_pipeline.py

Lines changed: 21 additions & 8 deletions
@@ -2,6 +2,8 @@
 
 from typing import List, Optional
 
+from pyspark.storagelevel import StorageLevel
+
 from butterfree.clients import SparkClient
 from butterfree.dataframe_service import repartition_sort_df
 from butterfree.extract import Source
@@ -209,35 +211,46 @@ def run(
             soon. Use only if strictly necessary.
 
         """
+        # Step 1: Construct input dataframe from the source.
         dataframe = self.source.construct(
             client=self.spark_client,
             start_date=self.feature_set.define_start_date(start_date),
             end_date=end_date,
         )
 
+        # Step 2: Repartition and sort if required, avoid if not necessary.
         if partition_by:
             order_by = order_by or partition_by
-            dataframe = repartition_sort_df(
-                dataframe, partition_by, order_by, num_processors
-            )
-
-        dataframe = self.feature_set.construct(
+            current_partitions = dataframe.rdd.getNumPartitions()
+            optimal_partitions = num_processors or current_partitions
+            if current_partitions != optimal_partitions:
+                dataframe = repartition_sort_df(
+                    dataframe, partition_by, order_by, num_processors
+                )
+
+        # Step 3: Construct the feature set dataframe using defined transformations.
+        transformed_dataframe = self.feature_set.construct(
             dataframe=dataframe,
             client=self.spark_client,
             start_date=start_date,
             end_date=end_date,
             num_processors=num_processors,
         )
 
+        if dataframe.storageLevel != StorageLevel.NONE:
+            dataframe.unpersist()  # Clear the data from the cache (disk and memory)
+
+        # Step 4: Load the data into the configured sink.
         self.sink.flush(
-            dataframe=dataframe,
+            dataframe=transformed_dataframe,
             feature_set=self.feature_set,
             spark_client=self.spark_client,
         )
 
-        if not dataframe.isStreaming:
+        # Step 5: Validate the output if not streaming and data volume is reasonable.
+        if not transformed_dataframe.isStreaming:
             self.sink.validate(
-                dataframe=dataframe,
+                dataframe=transformed_dataframe,
                 feature_set=self.feature_set,
                 spark_client=self.spark_client,
             )
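
Note: run() now skips repartition_sort_df when the current partition count already matches the target, avoiding a shuffle that would effectively be a no-op, and it unpersists the cached source dataframe once the feature set has been built. A simplified sketch of the guard, using plain DataFrame.repartition in place of butterfree's repartition_sort_df helper (the local SparkSession and the repartition_if_needed name are assumptions for illustration):

from typing import Optional

from pyspark.sql import DataFrame, SparkSession


def repartition_if_needed(df: DataFrame, num_processors: Optional[int]) -> DataFrame:
    # Repartition only when the target differs from the current layout;
    # repartitioning to the same number of partitions would still shuffle.
    current_partitions = df.rdd.getNumPartitions()
    optimal_partitions = num_processors or current_partitions
    if current_partitions != optimal_partitions:
        return df.repartition(optimal_partitions)
    return df


spark = SparkSession.builder.master("local[4]").getOrCreate()
df = spark.range(100_000)

print(df.rdd.getNumPartitions())                               # e.g. 4
print(repartition_if_needed(df, None).rdd.getNumPartitions())  # unchanged, no shuffle
print(repartition_if_needed(df, 8).rdd.getNumPartitions())     # 8, shuffle happens
spark.stop()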

butterfree/transform/aggregated_feature_set.py

Lines changed: 18 additions & 13 deletions
@@ -387,6 +387,7 @@ def _aggregate(
         ]
 
         groupby = self.keys_columns.copy()
+
         if window is not None:
             dataframe = dataframe.withColumn("window", window.get())
             groupby.append("window")
@@ -410,19 +411,23 @@ def _aggregate(
                 "keep_rn", functions.row_number().over(partition_window)
             ).filter("keep_rn = 1")
 
-        # repartition to have all rows for each group at the same partition
-        # by doing that, we won't have to shuffle data on grouping by id
-        dataframe = repartition_df(
-            dataframe,
-            partition_by=groupby,
-            num_processors=num_processors,
-        )
+        current_partitions = dataframe.rdd.getNumPartitions()
+        optimal_partitions = num_processors or current_partitions
+
+        if current_partitions != optimal_partitions:
+            dataframe = repartition_df(
+                dataframe,
+                partition_by=groupby,
+                num_processors=optimal_partitions,
+            )
+
         grouped_data = dataframe.groupby(*groupby)
 
-        if self._pivot_column:
+        if self._pivot_column and self._pivot_values:
             grouped_data = grouped_data.pivot(self._pivot_column, self._pivot_values)
 
         aggregated = grouped_data.agg(*aggregations)
+
         return self._with_renamed_columns(aggregated, features, window)
 
     def _with_renamed_columns(
@@ -637,12 +642,12 @@ def construct(
         output_df = output_df.select(*self.columns).replace(  # type: ignore
             float("nan"), None
         )
-        if not output_df.isStreaming:
-            if self.deduplicate_rows:
-                output_df = self._filter_duplicated_rows(output_df)
-            if self.eager_evaluation:
-                output_df.cache().count()
+        if not output_df.isStreaming and self.deduplicate_rows:
+            output_df = self._filter_duplicated_rows(output_df)
 
         post_hook_df = self.run_post_hooks(output_df)
 
+        if not output_df.isStreaming and self.eager_evaluation:
+            post_hook_df.cache().count()
+
         return post_hook_df
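
Note: _aggregate now only pivots when explicit pivot values are available; when pivot() is called without a value list, Spark must first run an extra job to collect the distinct values of the pivot column before it can plan the aggregation. A small sketch of the difference (the local SparkSession and toy data are assumptions for illustration):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.createDataFrame(
    [("a", "x", 1), ("a", "y", 2), ("b", "x", 3)],
    ["id", "category", "value"],
)

# Without explicit values, Spark launches an eager job just to discover the
# distinct categories before the pivoted aggregation can run.
implicit = df.groupBy("id").pivot("category").agg(F.sum("value"))

# With explicit values, that extra job is skipped entirely.
explicit = df.groupBy("id").pivot("category", ["x", "y"]).agg(F.sum("value"))

explicit.show()  # id=a -> x=1, y=2; id=b -> x=3, y=null (row order may vary)
spark.stop()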

butterfree/transform/feature_set.py

Lines changed: 2 additions & 5 deletions
@@ -436,11 +436,8 @@ def construct(
             pre_hook_df,
         ).select(*self.columns)
 
-        if not output_df.isStreaming:
-            if self.deduplicate_rows:
-                output_df = self._filter_duplicated_rows(output_df)
-            if self.eager_evaluation:
-                output_df.cache().count()
+        if not output_df.isStreaming and self.deduplicate_rows:
+            output_df = self._filter_duplicated_rows(output_df)
 
         output_df = self.incremental_strategy.filter_with_incremental_strategy(
             dataframe=output_df, start_date=start_date, end_date=end_date

tests/unit/butterfree/transform/test_feature_set.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def test_construct(
             + feature_divide.get_output_columns()
         )
         assert_dataframe_equality(result_df, feature_set_dataframe)
-        assert result_df.is_cached
+        assert not result_df.is_cached
 
     def test_construct_invalid_df(
         self, key_id, timestamp_c, feature_add, feature_divide
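
Note: the flipped assertion reflects that FeatureSet.construct no longer caches its result eagerly (the eager cache was removed there, and in the aggregated variant it now targets the post-hook dataframe instead). A tiny sketch of what a caller can do if it still wants a cached result (the local SparkSession and the result_df name are assumptions for illustration):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()
result_df = spark.range(1_000)

# construct() no longer caches eagerly, so the returned plan starts uncached.
assert not result_df.is_cached

# Callers that relied on eager evaluation can opt in explicitly.
result_df.cache().count()
assert result_df.is_cached
spark.stop()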
