Commit da3bea1

Remove hardcoded combined.jsonl with a flag (#861)
* Remove hardcoded combined.jsonl with a flag
* update
* change output_json_path output_json_folder

Co-authored-by: Xiaohan Zhang <[email protected]>
1 parent fa8f3d9 commit da3bea1

File tree

2 files changed: +42 −32 lines

scripts/data_prep/convert_delta_to_json.py

Lines changed: 35 additions & 26 deletions
@@ -198,14 +198,14 @@ def run_query(
     raise ValueError(f'Unrecognized method: {method}')
 
 
-def get_args(signed: List, json_output_path: str, columns: List) -> Iterable:
+def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable:
     for i, r in enumerate(signed):
-        yield (i, r.url, json_output_path, columns)
+        yield (i, r.url, json_output_folder, columns)
 
 
 def download(ipart: int,
              url: str,
-             json_output_path: str,
+             json_output_folder: str,
              columns: Optional[List] = None,
              resp_format: str = 'arrow',
              compressed: bool = False) -> None:
@@ -214,7 +214,7 @@ def download(ipart: int,
     Args:
         ipart (int): presigned url id
         url (str): presigned url
-        json_output_path (str): directory to save the ipart_th segment of dataframe
+        json_output_folder (str): directory to save the ipart_th segment of dataframe
         columns (list): schema to save to json
         resp_format (str): whether to use arrow or json when collect
         compressed (bool): if data is compressed before downloading. Need decompress if compressed=True.
@@ -224,7 +224,7 @@ def download(ipart: int,
     if resp_format == 'json':
         data = resp.json()
         pd.DataFrame(data, columns=columns).to_json(os.path.join(
-            json_output_path, 'part_' + str(ipart) + '.jsonl'),
+            json_output_folder, 'part_' + str(ipart) + '.jsonl'),
                                                     orient='records',
                                                     lines=True)
         return
@@ -242,7 +242,7 @@ def download(ipart: int,
 
     # Convert the PyArrow table into a pandas DataFrame
     df = table.to_pandas()
-    df.to_json(os.path.join(json_output_path,
+    df.to_json(os.path.join(json_output_folder,
                             'part_' + str(ipart) + '.jsonl'),
                orient='records',
                lines=True,
@@ -256,7 +256,7 @@ def download_starargs(args: Tuple) -> None:
 def fetch_data(method: str, cursor: Optional[Cursor],
                sparkSession: Optional[SparkSession], start: int, end: int,
                order_by: str, tablename: str, columns_str: str,
-               json_output_path: str) -> None:
+               json_output_folder: str) -> None:
     """Fetches a specified range of rows from a given table to a json file.
 
     This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes,
@@ -271,7 +271,7 @@ def fetch_data(method: str, cursor: Optional[Cursor],
         order_by (str): The column name to use for ordering the rows.
         tablename (str): The name of the table from which to fetch the data.
         columns_str (str): The string representation of the columns to select from the table.
-        json_output_path (str): The file path where the resulting JSON file will be saved.
+        json_output_folder (str): The file path where the resulting JSON file will be saved.
 
     Returns:
         None: The function doesn't return any value, but writes the result to a JSONL file.
@@ -301,15 +301,15 @@ def fetch_data(method: str, cursor: Optional[Cursor],
         records = [r.asDict() for r in ans]  # pyright: ignore
         pdf = pd.DataFrame.from_dict(records)
 
-    pdf.to_json(os.path.join(json_output_path, f'part_{start+1}_{end}.jsonl'),
+    pdf.to_json(os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'),
                 orient='records',
                 lines=True)
 
 
 def fetch(
     method: str,
     tablename: str,
-    json_output_path: str,
+    json_output_folder: str,
     batch_size: int = 1 << 30,
     processes: int = 1,
     sparkSession: Optional[SparkSession] = None,
@@ -320,7 +320,7 @@ def fetch(
     Args:
         method (str): dbconnect or dbsql
        tablename (str): catalog.scheme.tablename on UC
-        json_output_path (str): path to write the result json file to
+        json_output_folder (str): path to write the result json file to
         batch_size (int): number of rows that dbsql fetches each time to avoid OOM
         processes (int): max number of processes to use to parallelize the fetch
         sparkSession (pyspark.sql.sparksession): spark session
@@ -358,7 +358,7 @@ def fetch(
         signed, _, _ = df.collect_cf('arrow')  # pyright: ignore
         log.info(f'len(signed) = {len(signed)}')
 
-        args = get_args(signed, json_output_path, columns)
+        args = get_args(signed, json_output_folder, columns)
 
         # Stopping the SparkSession to avoid spilling connection state into the subprocesses.
         sparkSession.stop()
@@ -371,7 +371,7 @@ def fetch(
             log.warning(f'batch {start}')
             end = min(start + batch_size, nrows)
             fetch_data(method, cursor, sparkSession, start, end, order_by,
-                       tablename, columns_str, json_output_path)
+                       tablename, columns_str, json_output_folder)
 
     if cursor is not None:
         cursor.close()
@@ -381,21 +381,24 @@ def fetch_DT(args: Namespace) -> None:
     """Fetch UC Delta Table to local as jsonl."""
     log.info(f'Start .... Convert delta to json')
 
-    obj = urllib.parse.urlparse(args.json_output_path)
+    obj = urllib.parse.urlparse(args.json_output_folder)
     if obj.scheme != '':
         raise ValueError(
-            f'Check the json_output_path and verify it is a local path!')
+            f'Check the json_output_folder and verify it is a local path!')
 
-    if os.path.exists(args.json_output_path):
-        if not os.path.isdir(args.json_output_path) or os.listdir(
-                args.json_output_path):
+    if os.path.exists(args.json_output_folder):
+        if not os.path.isdir(args.json_output_folder) or os.listdir(
+                args.json_output_folder):
             raise RuntimeError(
-                f'A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!'
+                f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
             )
 
-    os.makedirs(args.json_output_path, exist_ok=True)
+    os.makedirs(args.json_output_folder, exist_ok=True)
 
-    log.info(f'Directory {args.json_output_path} created.')
+    if not args.json_output_filename.endswith('.jsonl'):
+        raise ValueError('json_output_filename needs to be a jsonl file')
+
+    log.info(f'Directory {args.json_output_folder} created.')
 
     method = 'dbsql'
     dbsql = None
@@ -451,16 +454,16 @@ def fetch_DT(args: Namespace) -> None:
             'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
         ) from e
 
-    fetch(method, args.delta_table_name, args.json_output_path, args.batch_size,
-          args.processes, sparkSession, dbsql)
+    fetch(method, args.delta_table_name, args.json_output_folder,
+          args.batch_size, args.processes, sparkSession, dbsql)
 
     if dbsql is not None:
         dbsql.close()
 
     # combine downloaded jsonl into one big jsonl for IFT
     iterative_combine_jsons(
-        args.json_output_path,
-        os.path.join(args.json_output_path, 'combined.jsonl'))
+        args.json_output_folder,
+        os.path.join(args.json_output_folder, args.json_output_filename))
 
 
 if __name__ == '__main__':
@@ -471,7 +474,7 @@ def fetch_DT(args: Namespace) -> None:
                         required=True,
                         type=str,
                         help='UC table <catalog>.<schema>.<table name>')
-    parser.add_argument('--json_output_path',
+    parser.add_argument('--json_output_folder',
                         required=True,
                         type=str,
                         help='Local path to save the converted json')
@@ -505,6 +508,12 @@ def fetch_DT(args: Namespace) -> None:
         help=
         'Use serverless or not. Make sure the workspace is entitled with serverless'
     )
+    parser.add_argument(
+        '--json_output_filename',
+        required=False,
+        type=str,
+        default='train-00000-of-00001.jsonl',
+        help='The combined final jsonl that combines all partitioned jsonl')
     args = parser.parse_args()
 
     from databricks.sdk import WorkspaceClient
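
For orientation, here is a minimal sketch of driving the converter with the new flag. It is not part of the commit: the table name, workspace host, token, cluster id, and output paths are hypothetical, the import path is assumed from the repo layout, and the set of fields mirrors what the tests below stub out.

from argparse import Namespace

from scripts.data_prep.convert_delta_to_json import fetch_DT

# All values below are placeholders for illustration only.
args = Namespace(
    delta_table_name='main.my_schema.my_table',          # hypothetical UC table
    json_output_folder='/tmp/delta_json',                 # local folder, must be empty or absent
    json_output_filename='train-00000-of-00001.jsonl',    # new flag; must end in .jsonl
    DATABRICKS_HOST='https://my-workspace.cloud.databricks.com',  # hypothetical
    DATABRICKS_TOKEN='dapi-placeholder',                   # hypothetical
    http_path=None,
    cluster_id='1234',                                     # hypothetical
    use_serverless=False,
    batch_size=1 << 30,
    processes=1,
    debug=False,
)

fetch_DT(args)
# The part_*.jsonl segments are written to /tmp/delta_json and then combined into
# /tmp/delta_json/train-00000-of-00001.jsonl instead of a hardcoded combined.jsonl.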

tests/a_scripts/data_prep/test_convert_delta_to_json.py

Lines changed: 7 additions & 6 deletions
@@ -27,7 +27,7 @@ def test_stream_delta_to_json(self, mock_workspace_client: Any,
 
         args = MagicMock()
         args.delta_table_name = 'test_table'
-        args.json_output_path = '/path/to/jsonl'
+        args.json_output_folder = '/path/to/jsonl'
         args.DATABRICKS_HOST = 'test_host'
         args.DATABRICKS_TOKEN = 'test_token'
         args.http_path = 'test_path'
@@ -36,6 +36,7 @@ def test_stream_delta_to_json(self, mock_workspace_client: Any,
         args.cluster_id = '1234'
         args.debug = False
         args.use_serverless = False
+        args.json_output_filename = 'combined.jsonl'
 
         mock_cluster_get = MagicMock()
         mock_cluster_get.return_value = MagicMock(
@@ -154,7 +155,7 @@ def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any,
         args = MagicMock()
 
         args.delta_table_name = 'test_table'
-        args.json_output_path = '/path/to/jsonl'
+        args.json_output_folder = '/path/to/jsonl'
         # Execute function with http_path=None (should use dbconnect)
         args.http_path = None
         args.cluster_id = '1234'
@@ -192,7 +193,7 @@ def test_sqlconnect_called_dbr13(self, mock_fetch: Any,
         args = MagicMock()
 
         args.delta_table_name = 'test_table'
-        args.json_output_path = '/path/to/jsonl'
+        args.json_output_folder = '/path/to/jsonl'
         # Execute function with http_path=None (should use dbconnect)
         args.http_path = 'test_path'
         args.cluster_id = '1234'
@@ -225,7 +226,7 @@ def test_sqlconnect_called_dbr14(self, mock_fetch: Any,
         args = MagicMock()
 
         args.delta_table_name = 'test_table'
-        args.json_output_path = '/path/to/jsonl'
+        args.json_output_folder = '/path/to/jsonl'
         # Execute function with http_path=None (should use dbconnect)
         args.http_path = 'test_path'
         args.cluster_id = '1234'
@@ -258,7 +259,7 @@ def test_sqlconnect_called_https(self, mock_fetch: Any,
         args = MagicMock()
 
         args.delta_table_name = 'test_table'
-        args.json_output_path = '/path/to/jsonl'
+        args.json_output_folder = '/path/to/jsonl'
         # Execute function with http_path=None (should use dbconnect)
         args.http_path = 'test_path'
         args.cluster_id = '1234'
@@ -288,7 +289,7 @@ def test_serverless(self, mock_fetch: Any, mock_combine_jsons: Any,
         args = MagicMock()
 
         args.delta_table_name = 'test_table'
-        args.json_output_path = '/path/to/jsonl'
+        args.json_output_folder = '/path/to/jsonl'
         # Execute function with http_path=None (should use dbconnect)
         args.http_path = 'test_path'
         args.cluster_id = '1234'
