66 changes: 65 additions & 1 deletion butterfree/clients/spark_client.py
@@ -1,5 +1,6 @@
"""SparkClient entity."""

import json
from typing import Any, Dict, List, Optional, Union

from pyspark.sql import DataFrame, DataFrameReader, SparkSession
@@ -216,7 +217,8 @@ def write_table(
**options,
)

def create_temporary_view(self, dataframe: DataFrame, name: str) -> Any:
@staticmethod
def create_temporary_view(dataframe: DataFrame, name: str) -> Any:
"""Create a temporary view from a given dataframe.

Args:
@@ -271,3 +273,65 @@ def add_table_partitions(
)

self.conn.sql(command)

@staticmethod
def _filter_schema(schema: DataFrame) -> List[str]:
"""Returns filtered schema with the desired information.

Attributes:
schema: desired table.

Returns:
A list of strings in the format
['{"column_name": "example1", type: "Spark_type"}', ...]

"""
return (
schema.filter(
~schema.col_name.isin(
["# Partition Information", "# col_name", "year", "month", "day"]
)
)
.toJSON()
.collect()
)

def _convert_schema(self, schema: DataFrame) -> List[Dict[str, str]]:
"""Returns schema with the desired information.

Attributes:
schema: desired table.

Returns:
A list of dictionaries in the format
[{"column_name": "example1", type: "Spark_type"}, ...]

"""
        schema_list = self._filter_schema(schema)
        return [json.loads(row) for row in schema_list]

def get_schema(self, table: str, database: str) -> List[Dict[str, str]]:
"""Returns desired table schema.

Attributes:
table: desired table.

Returns:
A list of dictionaries in the format
[{"column_name": "example1", type: "Spark_type"}, ...]

"""
        query = f"DESCRIBE {database}.{table}"  # noqa

response = self.sql(query)

if not response:
raise RuntimeError(
f"No columns found for table: {table}" f"in database: {database}"
)

return self._convert_schema(response)
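
For context, a minimal usage sketch of the new method (an illustration, not part of this diff; it assumes SparkClient is exported from butterfree.clients, a Hive-enabled SparkSession, and a hypothetical table my_db.my_table):

from butterfree.clients import SparkClient

# Hypothetical table; DESCRIBE yields one (col_name, data_type, comment) row
# per column, and _filter_schema drops the "# Partition Information",
# "# col_name" and year/month/day rows before the JSON conversion.
client = SparkClient()
schema = client.get_schema(table="my_table", database="my_db")
# e.g. [{"col_name": "id", "data_type": "bigint"},
#       {"col_name": "name", "data_type": "string"}]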
33 changes: 33 additions & 0 deletions tests/unit/butterfree/clients/test_spark_client.py
@@ -15,6 +15,15 @@ def create_temp_view(dataframe: DataFrame, name: str) -> None:
dataframe.createOrReplaceTempView(name)


def create_db_and_table(
    spark: SparkSession, database: str, table: str, view: str
) -> None:
spark.sql(f"create database if not exists {database}")
spark.sql(f"use {database}")
spark.sql(
f"create table if not exists {database}.{table} " # noqa
f"as select * from {view}" # noqa
)


class TestSparkClient:
def test_conn(self) -> None:
# arrange
@@ -293,3 +302,27 @@ def test_add_invalid_partitions(self, mock_spark_sql: Mock, partition):
# act and assert
with pytest.raises(ValueError):
spark_client.add_table_partitions(partition, "table", "db")

def test_get_schema(
self, target_df: DataFrame, spark_session: SparkSession
) -> None:
# arrange
spark_client = SparkClient()
create_temp_view(dataframe=target_df, name="temp_view")
create_db_and_table(
spark=spark_session,
database="test_db",
table="test_table",
view="temp_view",
)

expected_schema = [
{"col_name": "col1", "data_type": "string"},
{"col_name": "col2", "data_type": "bigint"},
]

# act
schema = spark_client.get_schema(table="test_table", database="test_db")

# assert
        assert schema == expected_schema
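
For reference, a sketch of the raw DESCRIBE output this test exercises (illustrative; the exact rows can vary across Spark versions):

spark_session.sql("DESCRIBE test_db.test_table").show()
# +--------+---------+-------+
# |col_name|data_type|comment|
# +--------+---------+-------+
# |    col1|   string|   null|
# |    col2|   bigint|   null|
# +--------+---------+-------+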