Skip to content

Commit 84d9dea

Browse files
Ensure compatibility with numpy 2.0.0 (#6976)
* Ensure compatibility with numpy 2.0.0 Following the conversion guide copy=False is no longer required and will result in an error: https://numpy.org/devdocs/numpy_2_0_migration_guide.html#adapting-to-changes-in-the-copy-keyword. * Update src/datasets/formatting/formatting.py Co-authored-by: Quentin Lhoest <[email protected]> * Update src/datasets/formatting/formatting.py Co-authored-by: Quentin Lhoest <[email protected]> * make style --------- Co-authored-by: Quentin Lhoest <[email protected]>
1 parent e47a746 commit 84d9dea

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/datasets/formatting/formatting.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,20 @@ def _arrow_array_to_numpy(self, pa_array: pa.Array) -> np.ndarray:
187187
else:
188188
zero_copy_only = _is_zero_copy_only(pa_array.type) and not _is_array_with_nulls(pa_array)
189189
array: List = pa_array.to_numpy(zero_copy_only=zero_copy_only).tolist()
190+
190191
if len(array) > 0:
191192
if any(
192193
(isinstance(x, np.ndarray) and (x.dtype == object or x.shape != array[0].shape))
193194
or (isinstance(x, float) and np.isnan(x))
194195
for x in array
195196
):
197+
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
198+
return np.asarray(array, dtype=object)
196199
return np.array(array, copy=False, dtype=object)
197-
return np.array(array, copy=False)
200+
if np.lib.NumpyVersion(np.__version__) >= "2.0.0b1":
201+
return np.asarray(array)
202+
else:
203+
return np.array(array, copy=False)
198204

199205

200206
class PandasArrowExtractor(BaseArrowExtractor[pd.DataFrame, pd.Series, pd.DataFrame]):

0 commit comments

Comments
 (0)