Delta to JSONL conversion script cleanup and bug fix #868

Merged
merged 8 commits on Jan 13, 2024
110 changes: 70 additions & 40 deletions scripts/data_prep/convert_delta_to_json.py
@@ -33,8 +33,8 @@
from pyspark.sql.dataframe import DataFrame as SparkDataFrame
from pyspark.sql.types import Row

MINIMUM_DB_CONNECT_DBR_VERSION = '14.1.0'
MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2.0'
MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'
MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2'

log = logging.getLogger(__name__)
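
A side note on the relaxed version constants above: packaging's PEP 440 comparison pads missing release segments with zeros, so '14.1' compares the same as '14.1.0' and the shorter strings gate the same runtimes. A minimal sketch, assuming the script's version name comes from packaging (the import sits outside this hunk):

# Sketch, not part of the diff: PEP 440 treats '14.1' and '14.1.0' as equal,
# so the shortened minimums behave exactly like the old three-part ones.
from packaging import version

MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'
assert version.parse('14.1') == version.parse('14.1.0')
assert version.parse('14.1') >= version.parse(MINIMUM_DB_CONNECT_DBR_VERSION)  # meets the dbconnect minimum
assert version.parse('12.2') < version.parse(MINIMUM_DB_CONNECT_DBR_VERSION)   # sql-connect-only runtime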

@@ -377,64 +377,61 @@ def fetch(
cursor.close()


def fetch_DT(args: Namespace) -> None:
"""Fetch UC Delta Table to local as jsonl."""
log.info(f'Start .... Convert delta to json')

obj = urllib.parse.urlparse(args.json_output_folder)
if obj.scheme != '':
raise ValueError(
f'Check the json_output_folder and verify it is a local path!')

if os.path.exists(args.json_output_folder):
if not os.path.isdir(args.json_output_folder) or os.listdir(
args.json_output_folder):
raise RuntimeError(
f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
)

os.makedirs(args.json_output_folder, exist_ok=True)

if not args.json_output_filename.endswith('.jsonl'):
raise ValueError('json_output_filename needs to be a jsonl file')

log.info(f'Directory {args.json_output_folder} created.')
def validate_and_get_cluster_info(cluster_id: str,
databricks_host: str,
databricks_token: str,
http_path: Optional[str],
use_serverless: bool = False) -> tuple:
"""Validate and get cluster info for running the Delta to JSONL conversion.

Args:
cluster_id (str): cluster id to validate and fetch additional info for
databricks_host (str): databricks host name
databricks_token (str): databricks auth token
http_path (Optional[str]): http path to use for sql connect
use_serverless (bool): whether to use serverless or not
"""
method = 'dbsql'
dbsql = None
sparkSession = None

if args.use_serverless:
if use_serverless:
method = 'dbconnect'
else:
w = WorkspaceClient()
res = w.clusters.get(cluster_id=args.cluster_id)
runtime_version = res.spark_version.split('-scala')[0].replace(
'x-snapshot', '0').replace('x', '0')
res = w.clusters.get(cluster_id=cluster_id)
if res is None:
raise ValueError(
f'Cluster id {cluster_id} does not exist. Check cluster id and try again!'
)
stripped_runtime = re.sub(
r'[a-zA-Z]', '',
res.spark_version.split('-scala')[0].replace('x-snapshot', ''))
runtime_version = re.sub(r'.-+$', '', stripped_runtime)
if version.parse(runtime_version) < version.parse(
MINIMUM_SQ_CONNECT_DBR_VERSION):
raise ValueError(
f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}'
)

if args.http_path is None and version.parse(
if http_path is None and version.parse(
runtime_version) >= version.parse(
MINIMUM_DB_CONNECT_DBR_VERSION):
method = 'dbconnect'

if method == 'dbconnect':
try:
if args.use_serverless:
if use_serverless:
session_id = str(uuid4())
sparkSession = DatabricksSession.builder.host(
args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header(
databricks_host).token(databricks_token).header(
'x-databricks-session-id', session_id).getOrCreate()

else:
sparkSession = DatabricksSession.builder.remote(
host=args.DATABRICKS_HOST,
token=args.DATABRICKS_TOKEN,
cluster_id=args.cluster_id).getOrCreate()
host=databricks_host,
token=databricks_token,
cluster_id=cluster_id).getOrCreate()

except Exception as e:
raise RuntimeError(
@@ -444,15 +441,47 @@ def fetch_DT(args: Namespace) -> None:
try:
dbsql = sql.connect(
server_hostname=re.compile(r'^https?://').sub(
'', args.DATABRICKS_HOST).strip(
'', databricks_host).strip(
), # sqlconnect hangs if hostname starts with https
http_path=args.http_path,
access_token=args.DATABRICKS_TOKEN,
http_path=http_path,
access_token=databricks_token,
)
except Exception as e:
raise RuntimeError(
'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
) from e
return method, dbsql, sparkSession


def fetch_DT(args: Namespace) -> None:
"""Fetch UC Delta Table to local as jsonl."""
log.info(f'Start .... Convert delta to json')

obj = urllib.parse.urlparse(args.json_output_folder)
if obj.scheme != '':
raise ValueError(
f'Check the json_output_folder and verify it is a local path!')

if os.path.exists(args.json_output_folder):
if not os.path.isdir(args.json_output_folder) or os.listdir(
args.json_output_folder):
raise RuntimeError(
f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
)

os.makedirs(args.json_output_folder, exist_ok=True)

if not args.json_output_filename.endswith('.jsonl'):
raise ValueError('json_output_filename needs to be a jsonl file')

log.info(f'Directory {args.json_output_folder} created.')

method, dbsql, sparkSession = validate_and_get_cluster_info(
cluster_id=args.cluster_id,
databricks_host=args.DATABRICKS_HOST,
databricks_token=args.DATABRICKS_TOKEN,
http_path=args.http_path,
use_serverless=args.use_serverless)

fetch(method, args.delta_table_name, args.json_output_folder,
args.batch_size, args.processes, sparkSession, dbsql)
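
For reviewers skimming the refactor: fetch_DT now handles only output-path validation and the final fetch call, while every connection concern lives in validate_and_get_cluster_info, which returns a (method, dbsql, sparkSession) triple. A minimal sketch of how that triple is consumed, with hypothetical host, token, and cluster id (not part of the PR):

# Sketch only: shows the dispatch on the helper's return value.
method, dbsql, sparkSession = validate_and_get_cluster_info(
    cluster_id='0123-456789-abcdefgh',                       # hypothetical id
    databricks_host='https://example.cloud.databricks.com',  # hypothetical host
    databricks_token='dapi-xxxx',                            # hypothetical token
    http_path=None,        # no http_path: a DBR >= 14.1 cluster selects dbconnect
    use_serverless=False)

if method == 'dbconnect':
    # Spark Connect path: sparkSession is populated, dbsql stays None.
    sparkSession.sql('SELECT 1').show()
else:
    # databricks-sql path: dbsql is an open connection, queried via a cursor.
    cursor = dbsql.cursor()
    cursor.execute('SELECT 1')
    cursor.close()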
@@ -494,9 +523,8 @@ def fetch_DT(args: Namespace) -> None:
help='number of processes allowed to use')
parser.add_argument(
'--cluster_id',
required=True,
required=False,
type=str,
default=None,
help=
'cluster id has runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.'
)
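
Making --cluster_id optional is the user-facing half of the bug fix: with --use_serverless the helper decides on dbconnect before it ever looks the cluster up, so the flag no longer has to be supplied. A sketch of that path with a hypothetical host and token (the helper annotates cluster_id as str, but the value is simply never used when use_serverless is set):

# Sketch, not part of the diff: an omitted --cluster_id parses to None and is
# never dereferenced on the serverless path.
method, dbsql, sparkSession = validate_and_get_cluster_info(
    cluster_id=None,                                         # --cluster_id omitted
    databricks_host='https://example.cloud.databricks.com',  # hypothetical host
    databricks_token='dapi-xxxx',                            # hypothetical token
    http_path=None,
    use_serverless=True)
assert method == 'dbconnect' and dbsql is None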
@@ -513,7 +541,9 @@ def fetch_DT(args: Namespace) -> None:
required=False,
type=str,
default='train-00000-of-00001.jsonl',
help='The combined final jsonl that combines all partitioned jsonl')
help=
'The name of the combined final jsonl that combines all partitioned jsonl'
)
args = parser.parse_args()

from databricks.sdk import WorkspaceClient
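
The WorkspaceClient import lands after parse_args() because DATABRICKS_HOST and DATABRICKS_TOKEN, which fetch_DT reads off args, do not appear as flags in the visible hunks; the collapsed tail of the script presumably fills them in from the SDK's auth config. A rough sketch of that wiring, under that assumption only:

# Assumed wiring (the real tail of the file is collapsed in this diff): resolve
# workspace credentials via the SDK's default auth chain, attach them to the
# parsed args, then hand off to fetch_DT.
w = WorkspaceClient()
args.DATABRICKS_HOST = w.config.host
args.DATABRICKS_TOKEN = w.config.token
fetch_DT(args)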