File tree Expand file tree Collapse file tree 2 files changed +30
-8
lines changed
_internal/planner/exchange Expand file tree Collapse file tree 2 files changed +30
-8
lines changed Original file line number Diff line number Diff line change @@ -117,15 +117,21 @@ def execute(
117
117
)
118
118
119
119
num_empty_blocks = output_num_blocks - len (reduce_block_refs )
120
- first_block_schema = reduce_metadata_schema [0 ].schema
121
- if first_block_schema is None :
122
- raise ValueError (
123
- "Cannot split partition on blocks with unknown block format."
124
- )
125
- elif isinstance (first_block_schema , pa .Schema ):
120
+ if len (reduce_metadata_schema ) > 0 :
121
+ first_block_schema = reduce_metadata_schema [0 ].schema
122
+ if isinstance (first_block_schema , pa .Schema ):
123
+ builder = ArrowBlockBuilder ()
124
+ elif isinstance (first_block_schema , PandasBlockSchema ):
125
+ builder = PandasBlockBuilder ()
126
+ else :
127
+ raise ValueError (
128
+ "Cannot split partition on blocks with unknown block schema:"
129
+ f" { first_block_schema } ."
130
+ )
131
+ else :
132
+ # If there are no reduce outputs (empty result), there is no schema to inspect, so default to Arrow format for the empty blocks.
126
133
builder = ArrowBlockBuilder ()
127
- elif isinstance (first_block_schema , PandasBlockSchema ):
128
- builder = PandasBlockBuilder ()
134
+
129
135
empty_block = builder .build ()
130
136
empty_meta_with_schema = BlockMetadataWithSchema .from_block (
131
137
empty_block
Original file line number Diff line number Diff line change @@ -206,6 +206,22 @@ def test_repartition_invalid_inputs(
206
206
)
207
207
208
208
209
+ @pytest .mark .parametrize ("shuffle" , [True , False ])
210
+ def test_repartition_empty_datasets (ray_start_regular_shared_2_cpus , shuffle ):
211
+ # Test repartitioning an empty dataset, both with and without shuffle.
212
+ num_partitions = 5
213
+ ds_empty = ray .data .range (100 ).filter (lambda row : False )
214
+ ds_repartitioned = ds_empty .repartition (num_partitions , shuffle = shuffle )
215
+
216
+ ref_bundles = list (ds_repartitioned .iter_internal_ref_bundles ())
217
+ assert len (ref_bundles ) == num_partitions
218
+ for ref_bundle in ref_bundles :
219
+ assert len (ref_bundle .blocks ) == 1
220
+ metadata = ref_bundle .blocks [0 ][1 ]
221
+ assert metadata .num_rows == 0
222
+ assert metadata .size_bytes == 0
223
+
224
+
209
225
if __name__ == "__main__" :
210
226
import sys
211
227
You can’t perform that action at this time.
0 commit comments