
Commit 2213264

Mayara Moromisato (moromimay) authored and committed
[MLOP-635] Rebase Incremental Job/Interval Run branch for test on selected feature sets (#278)
* Add interval branch modifications.
* Add interval_runs notebook.
* Add tests.
* Apply style (black, flake8 and mypy).
* Fix tests.
* Change version to create package dev.
1 parent 4e17b81 commit 2213264

15 files changed (+304, -345 lines)

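The diffs below wire the new interval/incremental behaviour through the writers. As a rough orientation only, a hypothetical snippet of how the touched flags are passed, assuming the writers are exported from butterfree.load.writers as the file paths below suggest:

from butterfree.load.writers import (  # assumed export path
    HistoricalFeatureStoreWriter,
    OnlineFeatureStoreWriter,
)

# interval_mode routes the historical writer through the new
# dynamic-partition-overwrite branch; debug_mode keeps output in temp views.
historical_writer = HistoricalFeatureStoreWriter(interval_mode=True, debug_mode=False)

# The online writer now derives a CassandraConfig itself when none is given.
online_writer = OnlineFeatureStoreWriter(write_to_entity=False)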
CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,10 @@ Preferably use **Added**, **Changed**, **Removed** and **Fixed** topics in each
 * [BUG] Fix Cassandra Connect Session ([#316](https://github.com/quintoandar/butterfree/pull/316))
 * Fix method to generate agg feature name. ([#326](https://github.com/quintoandar/butterfree/pull/326))
 
+## [1.1.3](https://github.com/quintoandar/butterfree/releases/tag/1.1.3)
+### Added
+* [MLOP-636] Create migration classes ([#282](https://github.com/quintoandar/butterfree/pull/282))
+
 ## [1.1.3](https://github.com/quintoandar/butterfree/releases/tag/1.1.3)
 ### Added
 * [MLOP-599] Apply mypy to ButterFree ([#273](https://github.com/quintoandar/butterfree/pull/273))

butterfree/clients/cassandra_client.py

Lines changed: 29 additions & 40 deletions
@@ -3,15 +3,9 @@
 from typing import Dict, List, Optional
 
 from cassandra.auth import PlainTextAuthProvider
-from cassandra.cluster import (
-    EXEC_PROFILE_DEFAULT,
-    Cluster,
-    ExecutionProfile,
-    ResponseFuture,
-    Session,
-)
-from cassandra.policies import DCAwareRoundRobinPolicy
-from cassandra.query import ConsistencyLevel, dict_factory
+from cassandra.cluster import Cluster, ResponseFuture, Session
+from cassandra.policies import RoundRobinPolicy
+from cassandra.query import dict_factory
 from typing_extensions import TypedDict
 
 from butterfree.clients import AbstractClient
@@ -61,36 +55,29 @@ def __init__(
     @property
     def conn(self, *, ssl_path: str = None) -> Session:  # type: ignore
         """Establishes a Cassandra connection."""
-        if not self._session:
-            auth_provider = (
-                PlainTextAuthProvider(username=self.user, password=self.password)
-                if self.user is not None
-                else None
-            )
-            ssl_opts = (
-                {
-                    "ca_certs": ssl_path,
-                    "ssl_version": PROTOCOL_TLSv1,
-                    "cert_reqs": CERT_REQUIRED,
-                }
-                if ssl_path is not None
-                else None
-            )
-
-            execution_profiles = {
-                EXEC_PROFILE_DEFAULT: ExecutionProfile(
-                    load_balancing_policy=DCAwareRoundRobinPolicy(),
-                    consistency_level=ConsistencyLevel.LOCAL_QUORUM,
-                    row_factory=dict_factory,
-                )
+        auth_provider = (
+            PlainTextAuthProvider(username=self.user, password=self.password)
+            if self.user is not None
+            else None
+        )
+        ssl_opts = (
+            {
+                "ca_certs": ssl_path,
+                "ssl_version": PROTOCOL_TLSv1,
+                "cert_reqs": CERT_REQUIRED,
             }
-            cluster = Cluster(
-                contact_points=self.host,
-                auth_provider=auth_provider,
-                ssl_options=ssl_opts,
-                execution_profiles=execution_profiles,
-            )
-            self._session = cluster.connect(self.keyspace)
+            if ssl_path is not None
+            else None
+        )
+
+        cluster = Cluster(
+            contact_points=self.host,
+            auth_provider=auth_provider,
+            ssl_options=ssl_opts,
+            load_balancing_policy=RoundRobinPolicy(),
+        )
+        self._session = cluster.connect(self.keyspace)
+        self._session.row_factory = dict_factory
         return self._session
 
     def sql(self, query: str) -> ResponseFuture:
@@ -100,9 +87,11 @@ def sql(self, query: str) -> ResponseFuture:
             query: desired query.
 
         """
-        return self.conn.execute(query)
+        if not self._session:
+            raise RuntimeError("There's no session available for this query.")
+        return self._session.execute(query)
 
-    def get_schema(self, table: str, database: str = None) -> List[Dict[str, str]]:
+    def get_schema(self, table: str) -> List[Dict[str, str]]:
         """Returns desired table schema.
 
         Attributes:

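Outside the diff, the connection pattern the new conn property relies on can be reproduced directly with the cassandra-driver API. A minimal sketch with placeholder host, keyspace and credentials (none of these values come from this commit):

from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster
from cassandra.policies import RoundRobinPolicy
from cassandra.query import dict_factory

# Placeholder connection settings; substitute real values.
hosts = ["127.0.0.1"]
keyspace = "feature_store"

# PlainTextAuthProvider mirrors the diff; omit it for unauthenticated clusters.
auth_provider = PlainTextAuthProvider(username="cassandra", password="cassandra")

cluster = Cluster(
    contact_points=hosts,
    auth_provider=auth_provider,
    load_balancing_policy=RoundRobinPolicy(),
)
session = cluster.connect(keyspace)

# Rows come back as plain dicts instead of namedtuples.
session.row_factory = dict_factory
rows = session.execute("SELECT release_version FROM system.local")
print(rows.one()["release_version"])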
butterfree/clients/spark_client.py

Lines changed: 3 additions & 67 deletions
@@ -1,6 +1,5 @@
 """SparkClient entity."""
 
-import json
 from typing import Any, Dict, List, Optional, Union
 
 from pyspark.sql import DataFrame, DataFrameReader, SparkSession
@@ -58,7 +57,7 @@ def read(
         """
         if not isinstance(format, str):
             raise ValueError("format needs to be a string with the desired read format")
-        if path and not isinstance(path, (str, list)):
+        if not isinstance(path, (str, list)):
            raise ValueError("path needs to be a string or a list of string")
 
        df_reader: Union[
@@ -67,7 +66,7 @@ def read(
 
         df_reader = df_reader.schema(schema) if schema else df_reader
 
-        return df_reader.format(format).load(path=path, **options)  # type: ignore
+        return df_reader.format(format).load(path, **options)  # type: ignore
 
     def read_table(self, table: str, database: str = None) -> DataFrame:
         """Use the SparkSession.read interface to read a metastore table.
@@ -217,8 +216,7 @@ def write_table(
             **options,
         )
 
-    @staticmethod
-    def create_temporary_view(dataframe: DataFrame, name: str) -> Any:
+    def create_temporary_view(self, dataframe: DataFrame, name: str) -> Any:
         """Create a temporary view from a given dataframe.
 
         Args:
@@ -273,65 +271,3 @@ def add_table_partitions(
         )
 
         self.conn.sql(command)
-
-    @staticmethod
-    def _filter_schema(schema: DataFrame) -> List[str]:
-        """Returns filtered schema with the desired information.
-
-        Attributes:
-            schema: desired table.
-
-        Returns:
-            A list of strings in the format
-            ['{"column_name": "example1", type: "Spark_type"}', ...]
-
-        """
-        return (
-            schema.filter(
-                ~schema.col_name.isin(
-                    ["# Partition Information", "# col_name", "year", "month", "day"]
-                )
-            )
-            .toJSON()
-            .collect()
-        )
-
-    def _convert_schema(self, schema: DataFrame) -> List[Dict[str, str]]:
-        """Returns schema with the desired information.
-
-        Attributes:
-            schema: desired table.
-
-        Returns:
-            A list of dictionaries in the format
-            [{"column_name": "example1", type: "Spark_type"}, ...]
-
-        """
-        schema_list = self._filter_schema(schema)
-        converted_schema = []
-        for row in schema_list:
-            converted_schema.append(json.loads(row))
-
-        return converted_schema
-
-    def get_schema(self, table: str, database: str = None) -> List[Dict[str, str]]:
-        """Returns desired table schema.
-
-        Attributes:
-            table: desired table.
-
-        Returns:
-            A list of dictionaries in the format
-            [{"column_name": "example1", type: "Spark_type"}, ...]
-
-        """
-        query = f"DESCRIBE {database}.{table} "  # noqa
-
-        response = self.sql(query)
-
-        if not response:
-            raise RuntimeError(
-                f"No columns found for table: {table}" f"in database: {database}"
-            )
-
-        return self._convert_schema(response)

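For context on the read change: PySpark's DataFrameReader.load accepts the path positionally, either as a single string or as a list of strings, which is why the path=path keyword (and the guard allowing a missing path) could go. A small sketch with hypothetical parquet locations:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# DataFrameReader.load takes the path positionally: a single string or a list.
single = spark.read.format("parquet").load("s3a://bucket/feature_set/")
several = spark.read.format("parquet").load(
    ["s3a://bucket/feature_set/year=2020/", "s3a://bucket/feature_set/year=2021/"]
)

# Roughly what the (now instance-level) create_temporary_view is expected to do
# for a batch dataframe: register it as a temp view on the same session.
single.createOrReplaceTempView("historical_feature_store__feature_set")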
butterfree/load/writers/historical_feature_store_writer.py

Lines changed: 51 additions & 20 deletions
@@ -1,7 +1,7 @@
 """Holds the Historical Feature Store writer class."""
 
 import os
-from typing import Any
+from typing import Any, Union
 
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.functions import dayofmonth, month, year
@@ -106,17 +106,16 @@ class HistoricalFeatureStoreWriter(Writer):
 
     def __init__(
         self,
-        db_config: AbstractWriteConfig = None,
+        db_config: Union[AbstractWriteConfig, MetastoreConfig] = None,
         database: str = None,
         num_partitions: int = None,
         validation_threshold: float = DEFAULT_VALIDATION_THRESHOLD,
         debug_mode: bool = False,
         interval_mode: bool = False,
         check_schema_hook: Hook = None,
     ):
-        super(HistoricalFeatureStoreWriter, self).__init__(
-            db_config or MetastoreConfig(), debug_mode, interval_mode
-        )
+        super(HistoricalFeatureStoreWriter, self).__init__(debug_mode, interval_mode)
+        self.db_config = db_config or MetastoreConfig()
         self.database = database or environment.get_variable(
             "FEATURE_STORE_HISTORICAL_DATABASE"
         )
@@ -141,20 +140,25 @@ def write(
         """
         dataframe = self._create_partitions(dataframe)
 
-        dataframe = self._apply_transformations(dataframe)
+        partition_df = self._apply_transformations(dataframe)
+
+        if self.debug_mode:
+            dataframe = partition_df
+        else:
+            dataframe = self.check_schema(
+                spark_client, partition_df, feature_set.name, self.database
+            )
 
         if self.interval_mode:
-            partition_overwrite_mode = spark_client.conn.conf.get(
-                "spark.sql.sources.partitionOverwriteMode"
-            ).lower()
-
-            if partition_overwrite_mode != "dynamic":
-                raise RuntimeError(
-                    "m=load_incremental_table, "
-                    "spark.sql.sources.partitionOverwriteMode={}, "
-                    "msg=partitionOverwriteMode have to "
-                    "be configured to 'dynamic'".format(partition_overwrite_mode)
+            if self.debug_mode:
+                spark_client.create_temporary_view(
+                    dataframe=dataframe,
+                    name=f"historical_feature_store__{feature_set.name}",
                 )
+                return
+
+            self._incremental_mode(feature_set, dataframe, spark_client)
+            return
 
         if self.debug_mode:
             spark_client.create_temporary_view(
@@ -173,6 +177,34 @@ def write(
             **self.db_config.get_options(s3_key),
         )
 
+    def _incremental_mode(
+        self, feature_set: FeatureSet, dataframe: DataFrame, spark_client: SparkClient
+    ) -> None:
+
+        partition_overwrite_mode = spark_client.conn.conf.get(
+            "spark.sql.sources.partitionOverwriteMode"
+        ).lower()
+
+        if partition_overwrite_mode != "dynamic":
+            raise RuntimeError(
+                "m=load_incremental_table, "
+                "spark.sql.sources.partitionOverwriteMode={}, "
+                "msg=partitionOverwriteMode have to be configured to 'dynamic'".format(
+                    partition_overwrite_mode
+                )
+            )
+
+        s3_key = os.path.join("historical", feature_set.entity, feature_set.name)
+        options = {"path": self.db_config.get_options(s3_key).get("path")}
+
+        spark_client.write_dataframe(
+            dataframe=dataframe,
+            format_=self.db_config.format_,
+            mode=self.db_config.mode,
+            **options,
+            partitionBy=self.PARTITION_BY,
+        )
+
     def _assert_validation_count(
         self, table_name: str, written_count: int, dataframe_count: int
     ) -> None:
@@ -199,9 +231,10 @@ def validate(
         Raises:
             AssertionError: if count of written data doesn't match count in current
             feature set dataframe.
+
         """
         table_name = (
-            os.path.join("historical", feature_set.entity, feature_set.name)
+            f"{feature_set.name}"
             if self.interval_mode and not self.debug_mode
             else (
                 f"{self.database}.{feature_set.name}"
@@ -213,9 +246,7 @@ def validate(
         written_count = (
             spark_client.read(
                 self.db_config.format_,
-                path=self.db_config.get_path_with_partitions(
-                    table_name, self._create_partitions(dataframe)
-                ),
+                path=self.db_config.get_path_with_partitions(table_name, dataframe),
             ).count()
            if self.interval_mode and not self.debug_mode
            else spark_client.read_table(table_name).count()

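The new _incremental_mode refuses to write unless Spark's dynamic partition overwrite is enabled, so interval runs only replace the partitions present in the incoming dataframe. A minimal sketch of the session config it checks and of an equivalent partitioned overwrite, using placeholder data and a placeholder output path (not part of this commit):

from pyspark.sql import SparkSession

# The config _incremental_mode checks; without it the writer raises RuntimeError.
spark = (
    SparkSession.builder
    .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
    .getOrCreate()
)

# With dynamic mode, an overwrite only replaces the partitions present in the
# incoming dataframe rather than truncating the whole output path.
df = spark.range(3).selectExpr("id", "2021 AS year", "4 AS month", "14 AS day")
(
    df.write.mode("overwrite")
    .partitionBy("year", "month", "day")
    .parquet("/tmp/historical/entity/feature_set")  # placeholder path
)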
butterfree/load/writers/online_feature_store_writer.py

Lines changed: 21 additions & 7 deletions
@@ -7,7 +7,7 @@
 from pyspark.sql.functions import col, row_number
 from pyspark.sql.streaming import StreamingQuery
 
-from butterfree.clients import SparkClient
+from butterfree.clients import CassandraClient, SparkClient
 from butterfree.configs.db import AbstractWriteConfig, CassandraConfig
 from butterfree.constants.columns import TIMESTAMP_COLUMN
 from butterfree.hooks import Hook
@@ -80,18 +80,16 @@ class OnlineFeatureStoreWriter(Writer):
 
     def __init__(
         self,
-        db_config: AbstractWriteConfig = None,
-        database: str = None,
+        db_config: Union[AbstractWriteConfig, CassandraConfig] = None,
         debug_mode: bool = False,
         write_to_entity: bool = False,
         interval_mode: bool = False,
         check_schema_hook: Hook = None,
     ):
-        super(OnlineFeatureStoreWriter, self).__init__(
-            db_config or CassandraConfig(), debug_mode, interval_mode, write_to_entity
-        )
+        super(OnlineFeatureStoreWriter, self).__init__(debug_mode, interval_mode)
+        self.db_config = db_config or CassandraConfig()
+        self.write_to_entity = write_to_entity
         self.check_schema_hook = check_schema_hook
-        self.database = database
 
     @staticmethod
     def filter_latest(dataframe: DataFrame, id_columns: List[Any]) -> DataFrame:
@@ -182,6 +180,22 @@ def write(
         """
         table_name = feature_set.entity if self.write_to_entity else feature_set.name
 
+        if not self.debug_mode:
+            config = (
+                self.db_config
+                if self.db_config == CassandraConfig
+                else CassandraConfig()
+            )
+
+            cassandra_client = CassandraClient(
+                host=[config.host],
+                keyspace=config.keyspace,
+                user=config.username,
+                password=config.password,
+            )
+
+            dataframe = self.check_schema(cassandra_client, dataframe, table_name)
+
         if dataframe.isStreaming:
             dataframe = self._apply_transformations(dataframe)
             if self.debug_mode:

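The schema check added above needs a live Cassandra session, so the writer builds a CassandraClient out of its CassandraConfig before writing. A rough sketch of that wiring, with the attribute names taken from the diff and a placeholder table name:

from butterfree.clients import CassandraClient
from butterfree.configs.db import CassandraConfig

# Assumed: CassandraConfig picks up connection details from environment
# variables when they are not passed explicitly.
config = CassandraConfig()

cassandra_client = CassandraClient(
    host=[config.host],
    keyspace=config.keyspace,
    user=config.username,
    password=config.password,
)

# get_schema (see cassandra_client.py above) now takes only the table name;
# the writer uses a client like this to validate the dataframe schema first.
schema = cassandra_client.get_schema(table="feature_set_entity")  # placeholder table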