
Commit a9853bb

moromimay authored and ralphrass committed
[MLOP-635] Rebase Incremental Job/Interval Run branch for test on selected feature sets (#278)
* Add interval branch modifications.
* Add interval_runs notebook.
* Add tests.
* Apply style (black, flake8 and mypy).
* Fix tests.
* Change version to create package dev.
1 parent c23522b commit a9853bb

8 files changed: +382 -7 lines changed


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -46,6 +46,10 @@ Preferably use **Added**, **Changed**, **Removed** and **Fixed** topics in each
 * [BUG] Fix Cassandra Connect Session ([#316](https://github.com/quintoandar/butterfree/pull/316))
 * Fix method to generate agg feature name. ([#326](https://github.com/quintoandar/butterfree/pull/326))
 
+## [1.1.3](https://github.com/quintoandar/butterfree/releases/tag/1.1.3)
+### Added
+* [MLOP-636] Create migration classes ([#282](https://github.com/quintoandar/butterfree/pull/282))
+
 ## [1.1.3](https://github.com/quintoandar/butterfree/releases/tag/1.1.3)
 ### Added
 * [MLOP-599] Apply mypy to ButterFree ([#273](https://github.com/quintoandar/butterfree/pull/273))

butterfree/load/writers/historical_feature_store_writer.py

Lines changed: 47 additions & 1 deletion
@@ -146,7 +146,25 @@ def write(
         """
         dataframe = self._create_partitions(dataframe)
 
-        dataframe = self._apply_transformations(dataframe)
+        partition_df = self._apply_transformations(dataframe)
+
+        if self.debug_mode:
+            dataframe = partition_df
+        else:
+            dataframe = self.check_schema(
+                spark_client, partition_df, feature_set.name, self.database
+            )
+
+        if self.interval_mode:
+            if self.debug_mode:
+                spark_client.create_temporary_view(
+                    dataframe=dataframe,
+                    name=f"historical_feature_store__{feature_set.name}",
+                )
+                return
+
+            self._incremental_mode(feature_set, dataframe, spark_client)
+            return
 
         if self.interval_mode:
             partition_overwrite_mode = spark_client.conn.conf.get(
@@ -191,6 +209,34 @@ def write(
                 **self.db_config.get_options(s3_key),
             )
 
+    def _incremental_mode(
+        self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient
+    ) -> None:
+
+        partition_overwrite_mode = spark_client.conn.conf.get(
+            "spark.sql.sources.partitionOverwriteMode"
+        ).lower()
+
+        if partition_overwrite_mode != "dynamic":
+            raise RuntimeError(
+                "m=load_incremental_table, "
+                "spark.sql.sources.partitionOverwriteMode={}, "
+                "msg=partitionOverwriteMode have to be configured to 'dynamic'".format(
+                    partition_overwrite_mode
+                )
+            )
+
+        s3_key = os.path.join("historical", feature_set.entity, feature_set.name)
+        options = {"path": self.db_config.get_options(s3_key).get("path")}
+
+        spark_client.write_dataframe(
+            dataframe=dataframe,
+            format_=self.db_config.format_,
+            mode=self.db_config.mode,
+            **options,
+            partitionBy=self.PARTITION_BY,
+        )
+
     def _assert_validation_count(
         self, table_name: str, written_count: int, dataframe_count: int
     ) -> None:
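For context, a minimal sketch of how a caller would drive this new interval write path. The writer, config and client names come from this diff; the write_interval helper, its arguments, and the import paths (as used elsewhere in the repository) are assumptions for illustration, not part of the commit.

from butterfree.clients import SparkClient
from butterfree.configs.db import MetastoreConfig
from butterfree.load.writers import HistoricalFeatureStoreWriter


def write_interval(feature_set, dataframe):
    # Hypothetical helper wiring together the pieces shown in this diff.
    spark_client = SparkClient()

    # _incremental_mode raises RuntimeError unless dynamic partition overwrite is set.
    spark_client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")

    writer = HistoricalFeatureStoreWriter(db_config=MetastoreConfig(), interval_mode=True)

    # With debug_mode off, write() validates the schema via check_schema and then
    # delegates to _incremental_mode, overwriting only the affected partitions.
    writer.write(feature_set=feature_set, dataframe=dataframe, spark_client=spark_client)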

butterfree/load/writers/online_feature_store_writer.py

Lines changed: 17 additions & 1 deletion
@@ -7,7 +7,7 @@
 from pyspark.sql.functions import col, row_number
 from pyspark.sql.streaming import StreamingQuery
 
-from butterfree.clients import SparkClient
+from butterfree.clients import CassandraClient, SparkClient
 from butterfree.configs.db import AbstractWriteConfig, CassandraConfig
 from butterfree.constants.columns import TIMESTAMP_COLUMN
 from butterfree.hooks import Hook
@@ -182,6 +182,22 @@ def write(
         """
         table_name = feature_set.entity if self.write_to_entity else feature_set.name
 
+        if not self.debug_mode:
+            config = (
+                self.db_config
+                if self.db_config == CassandraConfig
+                else CassandraConfig()
+            )
+
+            cassandra_client = CassandraClient(
+                host=[config.host],
+                keyspace=config.keyspace,
+                user=config.username,
+                password=config.password,
+            )
+
+            dataframe = self.check_schema(cassandra_client, dataframe, table_name)
+
 if dataframe.isStreaming:
             dataframe = self._apply_transformations(dataframe)
             if self.debug_mode:
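As an illustration of the new non-debug path above, a sketch of how a CassandraClient can be built from the writer's CassandraConfig before check_schema runs. The build_schema_check_client helper is hypothetical; the constructor arguments mirror this diff.

from typing import Optional

from butterfree.clients import CassandraClient
from butterfree.configs.db import CassandraConfig


def build_schema_check_client(config: Optional[CassandraConfig] = None) -> CassandraClient:
    # Fall back to a default CassandraConfig when none is supplied, mirroring
    # the fallback in OnlineFeatureStoreWriter.write().
    config = config or CassandraConfig()
    return CassandraClient(
        host=[config.host],
        keyspace=config.keyspace,
        user=config.username,
        password=config.password,
    )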

tests/integration/butterfree/load/test_sink.py

Lines changed: 5 additions & 1 deletion
@@ -9,7 +9,7 @@
 )
 
 
-def test_sink(input_dataframe, feature_set):
+def test_sink(input_dataframe, feature_set, mocker):
     # arrange
     client = SparkClient()
     client.conn.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
@@ -44,6 +44,10 @@ def test_sink(input_dataframe, feature_set):
     )
     online_writer = OnlineFeatureStoreWriter(db_config=online_config)
 
+    online_writer.check_schema_hook = mocker.stub("check_schema_hook")
+    online_writer.check_schema_hook.run = mocker.stub("run")
+    online_writer.check_schema_hook.run.return_value = feature_set_df
+
     writers = [historical_writer, online_writer]
     sink = Sink(writers)
 
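The stubbing pattern added here can be summarized as follows; stub_schema_check is an illustrative helper, not part of the commit, showing how the writer's check_schema_hook is replaced so the integration test needs no live Cassandra.

def stub_schema_check(online_writer, feature_set_df, mocker):
    # Replace the hook with pytest-mock stubs; run() simply returns the
    # expected dataframe instead of querying Cassandra for the real schema.
    online_writer.check_schema_hook = mocker.stub("check_schema_hook")
    online_writer.check_schema_hook.run = mocker.stub("run")
    online_writer.check_schema_hook.run.return_value = feature_set_df
    return online_writer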

tests/integration/butterfree/pipelines/test_feature_set_pipeline.py

Lines changed: 245 additions & 0 deletions
@@ -4,6 +4,7 @@
 from pyspark.sql import DataFrame
 from pyspark.sql import functions as F
 
+from butterfree.clients import SparkClient
 from butterfree.configs import environment
 from butterfree.configs.db import MetastoreConfig
 from butterfree.constants import DataType
@@ -411,3 +412,247 @@ def test_pipeline_interval_run(
 
         # tear down
         shutil.rmtree("test_folder")
+
+    def test_feature_set_pipeline_with_dates(
+        self,
+        mocked_date_df,
+        spark_session,
+        fixed_windows_output_feature_set_date_dataframe,
+        feature_set_pipeline,
+        mocker,
+    ):
+        # arrange
+        table_reader_table = "b_table"
+        create_temp_view(dataframe=mocked_date_df, name=table_reader_table)
+
+        historical_writer = HistoricalFeatureStoreWriter(debug_mode=True)
+
+        feature_set_pipeline.sink.writers = [historical_writer]
+
+        # act
+        feature_set_pipeline.run(start_date="2016-04-12", end_date="2016-04-13")
+
+        df = spark_session.sql("select * from historical_feature_store__feature_set")
+
+        # assert
+        assert_dataframe_equality(df, fixed_windows_output_feature_set_date_dataframe)
+
+    def test_feature_set_pipeline_with_execution_date(
+        self,
+        mocked_date_df,
+        spark_session,
+        fixed_windows_output_feature_set_date_dataframe,
+        feature_set_pipeline,
+        mocker,
+    ):
+        # arrange
+        table_reader_table = "b_table"
+        create_temp_view(dataframe=mocked_date_df, name=table_reader_table)
+
+        target_df = fixed_windows_output_feature_set_date_dataframe.filter(
+            "timestamp < '2016-04-13'"
+        )
+
+        historical_writer = HistoricalFeatureStoreWriter(debug_mode=True)
+
+        feature_set_pipeline.sink.writers = [historical_writer]
+
+        # act
+        feature_set_pipeline.run_for_date(execution_date="2016-04-12")
+
+        df = spark_session.sql("select * from historical_feature_store__feature_set")
+
+        # assert
+        assert_dataframe_equality(df, target_df)
+
+    def test_pipeline_with_hooks(self, spark_session, mocker):
+        # arrange
+        hook1 = AddHook(value=1)
+
+        spark_session.sql(
+            "select 1 as id, timestamp('2020-01-01') as timestamp, 0 as feature"
+        ).createOrReplaceTempView("test")
+
+        target_df = spark_session.sql(
+            "select 1 as id, timestamp('2020-01-01') as timestamp, 6 as feature, 2020 "
+            "as year, 1 as month, 1 as day"
+        )
+
+        historical_writer = HistoricalFeatureStoreWriter(debug_mode=True)
+
+        test_pipeline = FeatureSetPipeline(
+            source=Source(
+                readers=[TableReader(id="reader", table="test",).add_post_hook(hook1)],
+                query="select * from reader",
+            ).add_post_hook(hook1),
+            feature_set=FeatureSet(
+                name="feature_set",
+                entity="entity",
+                description="description",
+                features=[
+                    Feature(
+                        name="feature",
+                        description="test",
+                        transformation=SQLExpressionTransform(expression="feature + 1"),
+                        dtype=DataType.INTEGER,
+                    ),
+                ],
+                keys=[
+                    KeyFeature(
+                        name="id",
+                        description="The user's Main ID or device ID",
+                        dtype=DataType.INTEGER,
+                    )
+                ],
+                timestamp=TimestampFeature(),
+            )
+            .add_pre_hook(hook1)
+            .add_post_hook(hook1),
+            sink=Sink(writers=[historical_writer],).add_pre_hook(hook1),
+        )
+
+        # act
+        test_pipeline.run()
+        output_df = spark_session.table("historical_feature_store__feature_set")
+
+        # assert
+        output_df.show()
+        assert_dataframe_equality(output_df, target_df)
+
+    def test_pipeline_interval_run(
+        self, mocked_date_df, pipeline_interval_run_target_dfs, spark_session
+    ):
+        """Testing pipeline's idempotent interval run feature.
+        Source data:
+        +-------+---+-------------------+-------------------+
+        |feature| id|                 ts|          timestamp|
+        +-------+---+-------------------+-------------------+
+        |    200|  1|2016-04-11 11:31:11|2016-04-11 11:31:11|
+        |    300|  1|2016-04-12 11:44:12|2016-04-12 11:44:12|
+        |    400|  1|2016-04-13 11:46:24|2016-04-13 11:46:24|
+        |    500|  1|2016-04-14 12:03:21|2016-04-14 12:03:21|
+        +-------+---+-------------------+-------------------+
+        The test executes 3 runs for different time intervals. The input data has 4 data
+        points: 2016-04-11, 2016-04-12, 2016-04-13 and 2016-04-14. The following run
+        specifications are:
+        1) Interval: from 2016-04-11 to 2016-04-13
+            Target table result:
+            +---+-------+---+-----+------+-------------------+----+
+            |day|feature| id|month|run_id|          timestamp|year|
+            +---+-------+---+-----+------+-------------------+----+
+            | 11|    200|  1|    4|     1|2016-04-11 11:31:11|2016|
+            | 12|    300|  1|    4|     1|2016-04-12 11:44:12|2016|
+            | 13|    400|  1|    4|     1|2016-04-13 11:46:24|2016|
+            +---+-------+---+-----+------+-------------------+----+
+        2) Interval: only 2016-04-14.
+            Target table result:
+            +---+-------+---+-----+------+-------------------+----+
+            |day|feature| id|month|run_id|          timestamp|year|
+            +---+-------+---+-----+------+-------------------+----+
+            | 11|    200|  1|    4|     1|2016-04-11 11:31:11|2016|
+            | 12|    300|  1|    4|     1|2016-04-12 11:44:12|2016|
+            | 13|    400|  1|    4|     1|2016-04-13 11:46:24|2016|
+            | 14|    500|  1|    4|     2|2016-04-14 12:03:21|2016|
+            +---+-------+---+-----+------+-------------------+----+
+        3) Interval: only 2016-04-11.
+            Target table result:
+            +---+-------+---+-----+------+-------------------+----+
+            |day|feature| id|month|run_id|          timestamp|year|
+            +---+-------+---+-----+------+-------------------+----+
+            | 11|    200|  1|    4|     3|2016-04-11 11:31:11|2016|
+            | 12|    300|  1|    4|     1|2016-04-12 11:44:12|2016|
+            | 13|    400|  1|    4|     1|2016-04-13 11:46:24|2016|
+            | 14|    500|  1|    4|     2|2016-04-14 12:03:21|2016|
+            +---+-------+---+-----+------+-------------------+----+
+        """
+        # arrange
+        create_temp_view(dataframe=mocked_date_df, name="input_data")
+
+        db = environment.get_variable("FEATURE_STORE_HISTORICAL_DATABASE")
+        path = "test_folder/historical/entity/feature_set"
+
+        spark_session.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
+        spark_session.sql(f"create database if not exists {db}")
+        spark_session.sql(
+            f"create table if not exists {db}.feature_set_interval "
+            f"(id int, timestamp timestamp, feature int, "
+            f"run_id int, year int, month int, day int);"
+        )
+
+        dbconfig = MetastoreConfig()
+        dbconfig.get_options = Mock(
+            return_value={"mode": "overwrite", "format_": "parquet", "path": path}
+        )
+
+        historical_writer = HistoricalFeatureStoreWriter(
+            db_config=dbconfig, interval_mode=True
+        )
+
+        first_run_hook = RunHook(id=1)
+        second_run_hook = RunHook(id=2)
+        third_run_hook = RunHook(id=3)
+
+        (
+            first_run_target_df,
+            second_run_target_df,
+            third_run_target_df,
+        ) = pipeline_interval_run_target_dfs
+
+        test_pipeline = FeatureSetPipeline(
+            source=Source(
+                readers=[
+                    TableReader(id="id", table="input_data",).with_incremental_strategy(
+                        IncrementalStrategy("ts")
+                    ),
+                ],
+                query="select * from id ",
+            ),
+            feature_set=FeatureSet(
+                name="feature_set_interval",
+                entity="entity",
+                description="",
+                keys=[KeyFeature(name="id", description="", dtype=DataType.INTEGER,)],
+                timestamp=TimestampFeature(from_column="ts"),
+                features=[
+                    Feature(name="feature", description="", dtype=DataType.INTEGER),
+                    Feature(name="run_id", description="", dtype=DataType.INTEGER),
+                ],
+            ),
+            sink=Sink([historical_writer],),
+        )
+
+        # act and assert
+        dbconfig.get_path_with_partitions = Mock(
+            return_value=[
+                "test_folder/historical/entity/feature_set/year=2016/month=4/day=11",
+                "test_folder/historical/entity/feature_set/year=2016/month=4/day=12",
+                "test_folder/historical/entity/feature_set/year=2016/month=4/day=13",
+            ]
+        )
+        test_pipeline.feature_set.add_pre_hook(first_run_hook)
+        test_pipeline.run(end_date="2016-04-13", start_date="2016-04-11")
+        first_run_output_df = spark_session.read.parquet(path)
+        assert_dataframe_equality(first_run_output_df, first_run_target_df)
+
+        dbconfig.get_path_with_partitions = Mock(
+            return_value=[
+                "test_folder/historical/entity/feature_set/year=2016/month=4/day=14",
+            ]
+        )
+        test_pipeline.feature_set.add_pre_hook(second_run_hook)
+        test_pipeline.run_for_date("2016-04-14")
+        second_run_output_df = spark_session.read.parquet(path)
+        assert_dataframe_equality(second_run_output_df, second_run_target_df)
+
+        dbconfig.get_path_with_partitions = Mock(
+            return_value=[
+                "test_folder/historical/entity/feature_set/year=2016/month=4/day=11",
+            ]
+        )
+        test_pipeline.feature_set.add_pre_hook(third_run_hook)
+        test_pipeline.run_for_date("2016-04-11")
+        third_run_output_df = spark_session.read.parquet(path)
+        assert_dataframe_equality(third_run_output_df, third_run_target_df)
+
+        # tear down
+        shutil.rmtree("test_folder")
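A compact sketch of the two interval entry points exercised by these tests. Method names and keyword arguments are taken from the diff; the backfill_and_daily_run helper and the `pipeline` argument (any configured FeatureSetPipeline) are illustrative assumptions.

def backfill_and_daily_run(pipeline):
    # Backfill a closed interval: the partitions between the two dates are rewritten.
    pipeline.run(start_date="2016-04-11", end_date="2016-04-13")

    # Single-day run: only that date's partition is overwritten, which is what makes
    # re-running a past date idempotent (see run 3 in the docstring above).
    pipeline.run_for_date(execution_date="2016-04-14")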

0 commit comments