
Commit b69318e

Delta to JSONL conversion script cleanup and bug fix (#868)
* Small test change
* small cleanups
* lint and precommit
* lint and precommit
* comments
* another one
* pr suggestion and use input param not args
1 parent d05c099 commit b69318e

File tree

1 file changed (+70 −40 lines)

scripts/data_prep/convert_delta_to_json.py

Lines changed: 70 additions & 40 deletions
@@ -33,8 +33,8 @@
 from pyspark.sql.dataframe import DataFrame as SparkDataFrame
 from pyspark.sql.types import Row
 
-MINIMUM_DB_CONNECT_DBR_VERSION = '14.1.0'
-MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2.0'
+MINIMUM_DB_CONNECT_DBR_VERSION = '14.1'
+MINIMUM_SQ_CONNECT_DBR_VERSION = '12.2'
 
 log = logging.getLogger(__name__)
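The constants drop their '.0' patch component; this does not change any comparison in the script, because packaging normalizes trailing zeros. A minimal check, assuming the script's version module is packaging.version:

from packaging import version

# Trailing zeros are insignificant under PEP 440 normalization, so the
# shortened constants compare exactly as before.
assert version.parse('14.1') == version.parse('14.1.0')
assert version.parse('12.2') < version.parse('14.1')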

@@ -377,64 +377,61 @@ def fetch(
         cursor.close()
 
 
-def fetch_DT(args: Namespace) -> None:
-    """Fetch UC Delta Table to local as jsonl."""
-    log.info(f'Start .... Convert delta to json')
-
-    obj = urllib.parse.urlparse(args.json_output_folder)
-    if obj.scheme != '':
-        raise ValueError(
-            f'Check the json_output_folder and verify it is a local path!')
-
-    if os.path.exists(args.json_output_folder):
-        if not os.path.isdir(args.json_output_folder) or os.listdir(
-                args.json_output_folder):
-            raise RuntimeError(
-                f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
-            )
-
-    os.makedirs(args.json_output_folder, exist_ok=True)
-
-    if not args.json_output_filename.endswith('.jsonl'):
-        raise ValueError('json_output_filename needs to be a jsonl file')
-
-    log.info(f'Directory {args.json_output_folder} created.')
+def validate_and_get_cluster_info(cluster_id: str,
+                                  databricks_host: str,
+                                  databricks_token: str,
+                                  http_path: Optional[str],
+                                  use_serverless: bool = False) -> tuple:
+    """Validate and get cluster info for running the Delta to JSONL conversion.
 
+    Args:
+        cluster_id (str): cluster id to validate and fetch additional info for
+        databricks_host (str): databricks host name
+        databricks_token (str): databricks auth token
+        http_path (Optional[str]): http path to use for sql connect
+        use_serverless (bool): whether to use serverless or not
+    """
     method = 'dbsql'
     dbsql = None
     sparkSession = None
 
-    if args.use_serverless:
+    if use_serverless:
         method = 'dbconnect'
     else:
         w = WorkspaceClient()
-        res = w.clusters.get(cluster_id=args.cluster_id)
-        runtime_version = res.spark_version.split('-scala')[0].replace(
-            'x-snapshot', '0').replace('x', '0')
+        res = w.clusters.get(cluster_id=cluster_id)
+        if res is None:
+            raise ValueError(
+                f'Cluster id {cluster_id} does not exist. Check cluster id and try again!'
+            )
+        stripped_runtime = re.sub(
+            r'[a-zA-Z]', '',
+            res.spark_version.split('-scala')[0].replace('x-snapshot', ''))
+        runtime_version = re.sub(r'.-+$', '', stripped_runtime)
         if version.parse(runtime_version) < version.parse(
                 MINIMUM_SQ_CONNECT_DBR_VERSION):
             raise ValueError(
                 f'The minium DBR version required is {MINIMUM_SQ_CONNECT_DBR_VERSION} but got {version.parse(runtime_version)}'
             )
 
-        if args.http_path is None and version.parse(
+        if http_path is None and version.parse(
                 runtime_version) >= version.parse(
                     MINIMUM_DB_CONNECT_DBR_VERSION):
             method = 'dbconnect'
 
     if method == 'dbconnect':
         try:
-            if args.use_serverless:
+            if use_serverless:
                 session_id = str(uuid4())
                 sparkSession = DatabricksSession.builder.host(
-                    args.DATABRICKS_HOST).token(args.DATABRICKS_TOKEN).header(
+                    databricks_host).token(databricks_token).header(
                         'x-databricks-session-id', session_id).getOrCreate()
 
             else:
                 sparkSession = DatabricksSession.builder.remote(
-                    host=args.DATABRICKS_HOST,
-                    token=args.DATABRICKS_TOKEN,
-                    cluster_id=args.cluster_id).getOrCreate()
+                    host=databricks_host,
+                    token=databricks_token,
+                    cluster_id=cluster_id).getOrCreate()
 
         except Exception as e:
             raise RuntimeError(
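The rewritten runtime parsing is the likely bug fix: rather than substituting 'x' with '0', it strips alphabetic characters and any trailing separator, so runtime suffixes such as '-gpu-ml' no longer leak into the parsed version. An illustrative run with a made-up but representative spark_version value (not taken from the commit):

import re

spark_version = '14.3.x-gpu-ml-scala2.12'  # assumed sample value
stripped_runtime = re.sub(
    r'[a-zA-Z]', '',
    spark_version.split('-scala')[0].replace('x-snapshot', ''))  # -> '14.3.--'
runtime_version = re.sub(r'.-+$', '', stripped_runtime)          # -> '14.3'
print(runtime_version)  # parses cleanly with version.parse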
@@ -444,15 +441,47 @@ def fetch_DT(args: Namespace) -> None:
         try:
             dbsql = sql.connect(
                 server_hostname=re.compile(r'^https?://').sub(
-                    '', args.DATABRICKS_HOST).strip(
+                    '', databricks_host).strip(
                     ),  # sqlconnect hangs if hostname starts with https
-                http_path=args.http_path,
-                access_token=args.DATABRICKS_TOKEN,
+                http_path=http_path,
+                access_token=databricks_token,
             )
         except Exception as e:
             raise RuntimeError(
                 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
             ) from e
+    return method, dbsql, sparkSession
+
+
+def fetch_DT(args: Namespace) -> None:
+    """Fetch UC Delta Table to local as jsonl."""
+    log.info(f'Start .... Convert delta to json')
+
+    obj = urllib.parse.urlparse(args.json_output_folder)
+    if obj.scheme != '':
+        raise ValueError(
+            f'Check the json_output_folder and verify it is a local path!')
+
+    if os.path.exists(args.json_output_folder):
+        if not os.path.isdir(args.json_output_folder) or os.listdir(
+                args.json_output_folder):
+            raise RuntimeError(
+                f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
+            )
+
+    os.makedirs(args.json_output_folder, exist_ok=True)
+
+    if not args.json_output_filename.endswith('.jsonl'):
+        raise ValueError('json_output_filename needs to be a jsonl file')
+
+    log.info(f'Directory {args.json_output_folder} created.')
+
+    method, dbsql, sparkSession = validate_and_get_cluster_info(
+        cluster_id=args.cluster_id,
+        databricks_host=args.DATABRICKS_HOST,
+        databricks_token=args.DATABRICKS_TOKEN,
+        http_path=args.http_path,
+        use_serverless=args.use_serverless)
 
     fetch(method, args.delta_table_name, args.json_output_folder,
           args.batch_size, args.processes, sparkSession, dbsql)
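For context, a hypothetical direct call to the extracted helper; the host, token, and cluster id below are placeholders (not from the commit), and the call only succeeds against a real Databricks workspace:

# Placeholders only; exactly one of sparkSession / dbsql ends up populated,
# and method reports which connection path ('dbconnect' or 'dbsql') was chosen.
method, dbsql, sparkSession = validate_and_get_cluster_info(
    cluster_id='0000-000000-abcdefgh',
    databricks_host='https://example.cloud.databricks.com',
    databricks_token='dapi-placeholder',
    http_path=None,
    use_serverless=False)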
@@ -494,9 +523,8 @@ def fetch_DT(args: Namespace) -> None:
         help='number of processes allowed to use')
     parser.add_argument(
         '--cluster_id',
-        required=True,
+        required=False,
         type=str,
-        default=None,
         help=
         'cluster id has runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.'
     )
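With required=False, dropping the explicit default=None is safe because argparse already defaults an omitted optional flag to None; a standalone sketch (separate from the script) demonstrating this:

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--cluster_id', required=False, type=str)
print(parser.parse_args([]).cluster_id)  # prints: None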
@@ -513,7 +541,9 @@ def fetch_DT(args: Namespace) -> None:
         required=False,
         type=str,
         default='train-00000-of-00001.jsonl',
-        help='The combined final jsonl that combines all partitioned jsonl')
+        help=
+        'The name of the combined final jsonl that combines all partitioned jsonl'
+    )
     args = parser.parse_args()
 
     from databricks.sdk import WorkspaceClient
