
Commit cd29d48: Fix tests.

1 parent: 50d0273

7 files changed: +41 -12 lines changed

butterfree/clients/spark_client.py

Lines changed: 4 additions & 4 deletions
@@ -34,7 +34,7 @@ def conn(self) -> SparkSession:
     def read(
         self,
         format: str,
-        path: Optional[str] = None,
+        path: Optional[Union[str, List[str]]] = None,
         schema: Optional[StructType] = None,
         stream: bool = False,
         **options: Any,
@@ -57,16 +57,16 @@ def read(
         """
         if not isinstance(format, str):
             raise ValueError("format needs to be a string with the desired read format")
-        if not isinstance(path, str):
-            raise ValueError("path needs to be a string")
+        if not isinstance(path, (str, list)):
+            raise ValueError("path needs to be a string or a list of string")
 
         df_reader: Union[
             DataStreamReader, DataFrameReader
         ] = self.conn.readStream if stream else self.conn.read
 
         df_reader = df_reader.schema(schema) if schema else df_reader
 
-        return df_reader.format(format).load(path, **options)
+        return df_reader.format(format).load(path, **options)  # type: ignore
 
     def read_table(self, table: str, database: str = None) -> DataFrame:
         """Use the SparkSession.read interface to read a metastore table.

butterfree/configs/db/metastore_config.py

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ def get_path_with_partitions(self, key: str, dataframe: DataFrame) -> List:
         )
         for row in dataframe_values:
             path_list.append(
-                f"s3a://{self.bucket}/{key}/year={row['year']}/"
+                f"{self.file_system}://{self.path}/{key}/year={row['year']}/"
                 f"month={row['month']}/day={row['day']}"
             )
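
The partition path template now takes its scheme and root from the config's file_system and path attributes instead of hardcoding s3a:// and self.bucket, so the same key layout can target other file systems. An illustrative, self-contained sketch of what the new template produces (the values are made up; only the attribute names come from the diff):

    file_system = "s3a"   # from the config; e.g. "file" for local runs
    path = "my-bucket"    # from the config; previously self.bucket
    key = "orders"
    row = {"year": 2021, "month": 4, "day": 12}

    partition_path = (
        f"{file_system}://{path}/{key}/year={row['year']}/"
        f"month={row['month']}/day={row['day']}"
    )
    # -> "s3a://my-bucket/orders/year=2021/month=4/day=12"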

butterfree/extract/readers/reader.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def with_(
         self.transformations.append(new_transformation)
         return self
 
-    def with_incremantal_strategy(
+    def with_incremental_strategy(
         self, incremental_strategy: IncrementalStrategy
     ) -> "Reader":
         """Define the incremental strategy for the Reader.

butterfree/load/writers/online_feature_store_writer.py

Lines changed: 10 additions & 4 deletions
@@ -181,11 +181,17 @@ def write(
         table_name = feature_set.entity if self.write_to_entity else feature_set.name
 
         if not self.debug_mode:
+            config = (
+                self.db_config
+                if self.db_config == CassandraConfig
+                else CassandraConfig()
+            )
+
             cassandra_client = CassandraClient(
-                host=[self.db_config.host],
-                keyspace=self.db_config.keyspace,
-                user=self.db_config.username,
-                password=self.db_config.password,
+                host=[config.host],
+                keyspace=config.keyspace,
+                user=config.username,
+                password=config.password,
             )
 
             dataframe = self.check_schema(cassandra_client, dataframe, table_name)
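
The writer now falls back to a fresh CassandraConfig() whenever its db_config is not a Cassandra config, so the CassandraClient built for the schema check always has host, keyspace, username, and password attributes (a KafkaConfig, for instance, would not). Note that self.db_config == CassandraConfig compares an instance against the class object, which is false for ordinary instances, so the fallback branch is effectively always taken; an isinstance check would express the apparent intent. A sketch of that form, assuming CassandraConfig is importable from butterfree.configs.db:

    from butterfree.configs.db import CassandraConfig

    def effective_cassandra_config(db_config):
        # Hypothetical helper: keep the writer's config when it is already a
        # CassandraConfig, otherwise fall back to a default-constructed one.
        if isinstance(db_config, CassandraConfig):
            return db_config
        return CassandraConfig()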

tests/integration/butterfree/pipelines/test_feature_set_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -163,7 +163,7 @@ def test_feature_set_pipeline(
             ],
             timestamp=TimestampFeature(),
         ),
-        sink=Sink(writers=[historical_writer],),
+        sink=Sink(writers=[historical_writer]),
     )
     test_pipeline.run()
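
The only change here is dropping the trailing comma inside the Sink(...) call; under a formatter like Black, which treats a trailing comma as "magic", the comma would force the call to be exploded across multiple lines, so removing it keeps the call on one line.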

tests/unit/butterfree/load/writers/test_historical_feature_store_writer.py

Lines changed: 7 additions & 0 deletions
@@ -318,8 +318,15 @@ def test_write_with_transform(
         # given
         spark_client = mocker.stub("spark_client")
         spark_client.write_table = mocker.stub("write_table")
+
         writer = HistoricalFeatureStoreWriter().with_(json_transform)
 
+        schema_dataframe = writer._create_partitions(feature_set_dataframe)
+        json_dataframe = writer._apply_transformations(schema_dataframe)
+        writer.check_schema_hook = mocker.stub("check_schema_hook")
+        writer.check_schema_hook.run = mocker.stub("run")
+        writer.check_schema_hook.run.return_value = json_dataframe
+
         # when
         writer.write(
             feature_set=feature_set,
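
The test now stubs the writer's check_schema_hook so that write() does not run a real schema validation; the stub's run() has to return a dataframe shaped like the one write() would pass downstream, hence it is precomputed with the writer's own _create_partitions and _apply_transformations. The same pattern in isolation, as a minimal pytest-mock sketch (SomeWriter and input_df are hypothetical names):

    def test_write_skips_schema_check(mocker, input_df):
        writer = SomeWriter()  # hypothetical writer exposing check_schema_hook

        # Replace the hook with a stub whose run() echoes the expected
        # dataframe, so write() proceeds without a live schema check.
        writer.check_schema_hook = mocker.stub("check_schema_hook")
        writer.check_schema_hook.run = mocker.stub("run")
        writer.check_schema_hook.run.return_value = input_df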

tests/unit/butterfree/load/writers/test_online_feature_store_writer.py

Lines changed: 17 additions & 1 deletion
@@ -151,7 +151,7 @@ def test_write_in_debug_and_stream_mode(self, feature_set, spark_session, mocker
         assert isinstance(handler, StreamingQuery)
 
     @pytest.mark.parametrize("has_checkpoint", [True, False])
-    def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
+    def test_write_stream(self, feature_set, has_checkpoint, monkeypatch, mocker):
         # arrange
         spark_client = SparkClient()
         spark_client.write_stream = Mock()
@@ -174,6 +174,10 @@ def test_write_stream(self, feature_set, has_checkpoint, monkeypatch):
         writer = OnlineFeatureStoreWriter(cassandra_config)
         writer.filter_latest = Mock()
 
+        writer.check_schema_hook = mocker.stub("check_schema_hook")
+        writer.check_schema_hook.run = mocker.stub("run")
+        writer.check_schema_hook.run.return_value = dataframe
+
         # act
         stream_handler = writer.write(feature_set, dataframe, spark_client)
 
@@ -252,6 +256,10 @@ def test_write_with_transform(
         spark_client.write_dataframe = mocker.stub("write_dataframe")
         writer = OnlineFeatureStoreWriter(cassandra_config).with_(json_transform)
 
+        writer.check_schema_hook = mocker.stub("check_schema_hook")
+        writer.check_schema_hook.run = mocker.stub("run")
+        writer.check_schema_hook.run.return_value = feature_set_dataframe
+
         # when
         writer.write(feature_set, feature_set_dataframe, spark_client)
 
@@ -285,6 +293,10 @@ def test_write_with_kafka_config(
         kafka_config = KafkaConfig()
         writer = OnlineFeatureStoreWriter(kafka_config).with_(json_transform)
 
+        writer.check_schema_hook = mocker.stub("check_schema_hook")
+        writer.check_schema_hook.run = mocker.stub("run")
+        writer.check_schema_hook.run.return_value = feature_set_dataframe
+
         # when
         writer.write(feature_set, feature_set_dataframe, spark_client)
 
@@ -308,6 +320,10 @@ def test_write_with_custom_kafka_config(
             json_transform
         )
 
+        custom_writer.check_schema_hook = mocker.stub("check_schema_hook")
+        custom_writer.check_schema_hook.run = mocker.stub("run")
+        custom_writer.check_schema_hook.run.return_value = feature_set_dataframe
+
         # when
         custom_writer.write(feature_set, feature_set_dataframe, spark_client)
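
All four tests apply the same hook-stubbing pattern; test_write_stream also gains mocker in its signature, since the stub comes from the pytest-mock fixture. Because mocker.stub() returns a regular mock, the tests could additionally verify the hook was exercised, e.g.:

    writer.check_schema_hook.run.assert_called_once()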
