Skip to content

Commit bffd50e

Browse files
author
AlvaroMarquesAndrade
committed
add get_schema to spark client
1 parent 342e584 commit bffd50e

File tree

2 files changed

+57
-4
lines changed

2 files changed

+57
-4
lines changed

butterfree/clients/spark_client.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,16 @@ def add_table_partitions(
276276

277277
@staticmethod
278278
def _filter_schema(schema: DataFrame) -> List[str]:
279+
"""Returns filtered schema with the desired information.
280+
281+
Attributes:
282+
schema: desired table.
283+
284+
Returns:
285+
A list of strings in the format
286+
['{"column_name": "example1", type: "Spark_type"}', ...]
287+
288+
"""
279289
return (
280290
schema.filter(
281291
~schema.col_name.isin(
@@ -287,12 +297,22 @@ def _filter_schema(schema: DataFrame) -> List[str]:
287297
)
288298

289299
def _convert_schema(self, schema: DataFrame) -> List[Dict[str, str]]:
300+
"""Returns schema with the desired information.
301+
302+
Attributes:
303+
schema: desired table.
304+
305+
Returns:
306+
A list of dictionaries in the format
307+
[{"column_name": "example1", type: "Spark_type"}, ...]
308+
309+
"""
290310
schema_list = self._filter_schema(schema)
291-
schema = []
311+
converted_schema = []
292312
for row in schema_list:
293-
schema.append(json.loads(row))
313+
converted_schema.append(json.loads(row))
294314

295-
return schema
315+
return converted_schema
296316

297317
def get_schema(self, table: str, database: str) -> List[Dict[str, str]]:
298318
"""Returns desired table schema.
@@ -301,7 +321,7 @@ def get_schema(self, table: str, database: str) -> List[Dict[str, str]]:
301321
table: desired table.
302322
303323
Returns:
304-
A list dictionaries in the format
324+
A list of dictionaries in the format
305325
[{"column_name": "example1", type: "Spark_type"}, ...]
306326
307327
"""

tests/unit/butterfree/clients/test_spark_client.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ def create_temp_view(dataframe: DataFrame, name: str) -> None:
1515
dataframe.createOrReplaceTempView(name)
1616

1717

18+
def create_db_and_table(spark, database, table, view):
19+
spark.sql(f"create database if not exists {database}")
20+
spark.sql(f"use {database}")
21+
spark.sql(
22+
f"create table if not exists {database}.{table} " # noqa
23+
f"as select * from {view}" # noqa
24+
)
25+
26+
1827
class TestSparkClient:
1928
def test_conn(self) -> None:
2029
# arrange
@@ -293,3 +302,27 @@ def test_add_invalid_partitions(self, mock_spark_sql: Mock, partition):
293302
# act and assert
294303
with pytest.raises(ValueError):
295304
spark_client.add_table_partitions(partition, "table", "db")
305+
306+
def test_get_schema(
    self, target_df: DataFrame, spark_session: SparkSession
) -> None:
    """get_schema should return the table's (col_name, data_type) pairs."""
    # arrange
    spark_client = SparkClient()
    create_temp_view(dataframe=target_df, name="temp_view")
    create_db_and_table(
        spark=spark_session,
        database="test_db",
        table="test_table",
        view="temp_view",
    )

    expected_schema = [
        {"col_name": "col1", "data_type": "string"},
        {"col_name": "col2", "data_type": "bigint"},
    ]

    # act
    schema = spark_client.get_schema(table="test_table", database="test_db")

    # assert
    # BUG FIX: the original `assert schema, expected_schema` only checked
    # that `schema` was truthy — `expected_schema` was the assert *message*,
    # so the test could never fail on wrong content. Compare explicitly.
    assert schema == expected_schema

0 commit comments

Comments
 (0)