scripts/data_prep/convert_delta_to_json.py (35 additions, 26 deletions)
@@ -198,14 +198,14 @@ def run_query(
raise ValueError(f'Unrecognized method: {method}')


-def get_args(signed: List, json_output_path: str, columns: List) -> Iterable:
+def get_args(signed: List, json_output_folder: str, columns: List) -> Iterable:
for i, r in enumerate(signed):
-yield (i, r.url, json_output_path, columns)
+yield (i, r.url, json_output_folder, columns)


def download(ipart: int,
url: str,
-json_output_path: str,
+json_output_folder: str,
columns: Optional[List] = None,
resp_format: str = 'arrow',
compressed: bool = False) -> None:
@@ -214,7 +214,7 @@ def download(ipart: int,
Args:
ipart (int): presigned url id
url (str): presigned url
-json_output_path (str): directory to save the ipart_th segment of dataframe
+json_output_folder (str): directory to save the ipart_th segment of dataframe
columns (list): schema to save to json
resp_format (str): whether to use arrow or json when collect
compressed (bool): if data is compressed before downloading. Need decompress if compressed=True.
@@ -224,7 +224,7 @@ def download(ipart: int,
if resp_format == 'json':
data = resp.json()
pd.DataFrame(data, columns=columns).to_json(os.path.join(
-json_output_path, 'part_' + str(ipart) + '.jsonl'),
+json_output_folder, 'part_' + str(ipart) + '.jsonl'),
orient='records',
lines=True)
return
@@ -242,7 +242,7 @@ def download(ipart: int,

# Convert the PyArrow table into a pandas DataFrame
df = table.to_pandas()
-df.to_json(os.path.join(json_output_path,
+df.to_json(os.path.join(json_output_folder,
'part_' + str(ipart) + '.jsonl'),
orient='records',
lines=True,
@@ -256,7 +256,7 @@ def download_starargs(args: Tuple) -> None:
def fetch_data(method: str, cursor: Optional[Cursor],
sparkSession: Optional[SparkSession], start: int, end: int,
order_by: str, tablename: str, columns_str: str,
-json_output_path: str) -> None:
+json_output_folder: str) -> None:
"""Fetches a specified range of rows from a given table to a json file.

This function executes a SQL query to retrieve a range of rows, determined by 'start' and 'end' indexes,
@@ -271,7 +271,7 @@ def fetch_data(method: str, cursor: Optional[Cursor],
order_by (str): The column name to use for ordering the rows.
tablename (str): The name of the table from which to fetch the data.
columns_str (str): The string representation of the columns to select from the table.
-json_output_path (str): The file path where the resulting JSON file will be saved.
+json_output_folder (str): The file path where the resulting JSON file will be saved.

Returns:
None: The function doesn't return any value, but writes the result to a JSONL file.
@@ -301,15 +301,15 @@ def fetch_data(method: str, cursor: Optional[Cursor],
records = [r.asDict() for r in ans] # pyright: ignore
pdf = pd.DataFrame.from_dict(records)

-pdf.to_json(os.path.join(json_output_path, f'part_{start+1}_{end}.jsonl'),
+pdf.to_json(os.path.join(json_output_folder, f'part_{start+1}_{end}.jsonl'),
orient='records',
lines=True)


def fetch(
method: str,
tablename: str,
-json_output_path: str,
+json_output_folder: str,
batch_size: int = 1 << 30,
processes: int = 1,
sparkSession: Optional[SparkSession] = None,
@@ -320,7 +320,7 @@ def fetch(
Args:
method (str): dbconnect or dbsql
tablename (str): catalog.scheme.tablename on UC
-json_output_path (str): path to write the result json file to
+json_output_folder (str): path to write the result json file to
batch_size (int): number of rows that dbsql fetches each time to avoid OOM
processes (int): max number of processes to use to parallelize the fetch
sparkSession (pyspark.sql.sparksession): spark session
@@ -358,7 +358,7 @@ def fetch(
signed, _, _ = df.collect_cf('arrow') # pyright: ignore
log.info(f'len(signed) = {len(signed)}')

-args = get_args(signed, json_output_path, columns)
+args = get_args(signed, json_output_folder, columns)

# Stopping the SparkSession to avoid spilling connection state into the subprocesses.
sparkSession.stop()
@@ -371,7 +371,7 @@ def fetch(
log.warning(f'batch {start}')
end = min(start + batch_size, nrows)
fetch_data(method, cursor, sparkSession, start, end, order_by,
-tablename, columns_str, json_output_path)
+tablename, columns_str, json_output_folder)

if cursor is not None:
cursor.close()
@@ -381,21 +381,24 @@ def fetch_DT(args: Namespace) -> None:
"""Fetch UC Delta Table to local as jsonl."""
log.info(f'Start .... Convert delta to json')

-obj = urllib.parse.urlparse(args.json_output_path)
+obj = urllib.parse.urlparse(args.json_output_folder)
if obj.scheme != '':
raise ValueError(
-f'Check the json_output_path and verify it is a local path!')
+f'Check the json_output_folder and verify it is a local path!')

-if os.path.exists(args.json_output_path):
-if not os.path.isdir(args.json_output_path) or os.listdir(
-args.json_output_path):
+if os.path.exists(args.json_output_folder):
+if not os.path.isdir(args.json_output_folder) or os.listdir(
+args.json_output_folder):
raise RuntimeError(
-f'A file or a folder {args.json_output_path} already exists and is not empty. Remove it and retry!'
+f'A file or a folder {args.json_output_folder} already exists and is not empty. Remove it and retry!'
)

-os.makedirs(args.json_output_path, exist_ok=True)
+os.makedirs(args.json_output_folder, exist_ok=True)

-log.info(f'Directory {args.json_output_path} created.')
+if not args.json_output_filename.endswith('.jsonl'):
+raise ValueError('json_output_filename needs to be a jsonl file')
+
+log.info(f'Directory {args.json_output_folder} created.')

method = 'dbsql'
dbsql = None
@@ -451,16 +454,16 @@ def fetch_DT(args: Namespace) -> None:
'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
) from e

-fetch(method, args.delta_table_name, args.json_output_path, args.batch_size,
-args.processes, sparkSession, dbsql)
+fetch(method, args.delta_table_name, args.json_output_folder,
+args.batch_size, args.processes, sparkSession, dbsql)

if dbsql is not None:
dbsql.close()

# combine downloaded jsonl into one big jsonl for IFT
iterative_combine_jsons(
-args.json_output_path,
-os.path.join(args.json_output_path, 'combined.jsonl'))
+args.json_output_folder,
+os.path.join(args.json_output_folder, args.json_output_filename))


if __name__ == '__main__':
@@ -471,7 +474,7 @@ def fetch_DT(args: Namespace) -> None:
required=True,
type=str,
help='UC table <catalog>.<schema>.<table name>')
-parser.add_argument('--json_output_path',
+parser.add_argument('--json_output_folder',
required=True,
type=str,
help='Local path to save the converted json')
@@ -505,6 +508,12 @@ def fetch_DT(args: Namespace) -> None:
help=
'Use serverless or not. Make sure the workspace is entitled with serverless'
)
+parser.add_argument(
+'--json_output_filename',
+required=False,
+type=str,
+default='train-00000-of-00001.jsonl',
+help='The combined final jsonl that combines all partitioned jsonl')
args = parser.parse_args()

from databricks.sdk import WorkspaceClient
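For orientation, a minimal, hypothetical sketch of how the renamed and newly added arguments fit together when `fetch_DT` is driven programmatically. The attribute names mirror the argparse flags above and the mock `args` used in the updated tests below; all concrete values are placeholders, not taken from this diff.

```python
# Hypothetical sketch, not part of this diff: attribute names mirror the
# argparse flags above; every value shown is a placeholder.
from argparse import Namespace

args = Namespace(
    delta_table_name='main.my_schema.my_table',  # UC table <catalog>.<schema>.<table>
    json_output_folder='/tmp/delta_json',  # local folder; must not exist or must be empty
    json_output_filename='train-00000-of-00001.jsonl',  # must end in .jsonl
    http_path=None,  # None selects the dbconnect code path
    cluster_id='1234',  # placeholder cluster id
    use_serverless=False,
    batch_size=1 << 30,
    processes=1,
    debug=False,
    DATABRICKS_HOST='<workspace-host>',  # normally set from WorkspaceClient in the __main__ block
    DATABRICKS_TOKEN='<token>',
)

# fetch_DT(args) would write part_*.jsonl files into json_output_folder and then
# combine them into /tmp/delta_json/train-00000-of-00001.jsonl.
```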
tests/a_scripts/data_prep/test_convert_delta_to_json.py (7 additions, 6 deletions)
@@ -27,7 +27,7 @@ def test_stream_delta_to_json(self, mock_workspace_client: Any,

args = MagicMock()
args.delta_table_name = 'test_table'
-args.json_output_path = '/path/to/jsonl'
+args.json_output_folder = '/path/to/jsonl'
args.DATABRICKS_HOST = 'test_host'
args.DATABRICKS_TOKEN = 'test_token'
args.http_path = 'test_path'
@@ -36,6 +36,7 @@ def test_stream_delta_to_json(self, mock_workspace_client: Any,
args.cluster_id = '1234'
args.debug = False
args.use_serverless = False
+args.json_output_filename = 'combined.jsonl'

mock_cluster_get = MagicMock()
mock_cluster_get.return_value = MagicMock(
@@ -154,7 +155,7 @@ def test_dbconnect_called(self, mock_fetch: Any, mock_combine_jsons: Any,
args = MagicMock()

args.delta_table_name = 'test_table'
-args.json_output_path = '/path/to/jsonl'
+args.json_output_folder = '/path/to/jsonl'
# Execute function with http_path=None (should use dbconnect)
args.http_path = None
args.cluster_id = '1234'
@@ -192,7 +193,7 @@ def test_sqlconnect_called_dbr13(self, mock_fetch: Any,
args = MagicMock()

args.delta_table_name = 'test_table'
-args.json_output_path = '/path/to/jsonl'
+args.json_output_folder = '/path/to/jsonl'
# Execute function with http_path=None (should use dbconnect)
args.http_path = 'test_path'
args.cluster_id = '1234'
@@ -225,7 +226,7 @@ def test_sqlconnect_called_dbr14(self, mock_fetch: Any,
args = MagicMock()

args.delta_table_name = 'test_table'
-args.json_output_path = '/path/to/jsonl'
+args.json_output_folder = '/path/to/jsonl'
# Execute function with http_path=None (should use dbconnect)
args.http_path = 'test_path'
args.cluster_id = '1234'
@@ -258,7 +259,7 @@ def test_sqlconnect_called_https(self, mock_fetch: Any,
args = MagicMock()

args.delta_table_name = 'test_table'
-args.json_output_path = '/path/to/jsonl'
+args.json_output_folder = '/path/to/jsonl'
# Execute function with http_path=None (should use dbconnect)
args.http_path = 'test_path'
args.cluster_id = '1234'
@@ -288,7 +289,7 @@ def test_serverless(self, mock_fetch: Any, mock_combine_jsons: Any,
args = MagicMock()

args.delta_table_name = 'test_table'
-args.json_output_path = '/path/to/jsonl'
+args.json_output_folder = '/path/to/jsonl'
# Execute function with http_path=None (should use dbconnect)
args.http_path = 'test_path'
args.cluster_id = '1234'
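The renamed attribute and the new `json_output_filename` also give these mocked tests a natural assertion target. A hedged sketch follows; the helper is illustrative only and is not part of this test file.

```python
# Illustrative helper, not part of the diff: asserts that the combiner receives
# the folder plus the folder/filename join that fetch_DT now constructs.
import os
from unittest.mock import MagicMock


def assert_combined_output(mock_combine_jsons: MagicMock, args: MagicMock) -> None:
    # args.json_output_folder and args.json_output_filename are expected to be
    # plain strings set on the mock, as in the tests above.
    mock_combine_jsons.assert_called_once_with(
        args.json_output_folder,
        os.path.join(args.json_output_folder, args.json_output_filename),
    )
```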