Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions src/datasets/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from functools import partial
from itertools import groupby
from typing import TYPE_CHECKING, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
from typing import TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -1307,7 +1307,7 @@ def __init__(self, table: pa.Table, blocks: List[List[TableBlock]]):
if not isinstance(subtable, TableBlock):
raise TypeError(
"The blocks of a ConcatenationTable must be InMemoryTable or MemoryMappedTable objects"
f", but got {subtable}."
f", but got {_short_str(subtable)}."
)

def __getstate__(self):
Expand Down Expand Up @@ -1837,6 +1837,13 @@ def _storage_type(type: pa.DataType) -> pa.DataType:
return type


def _short_str(value: Any) -> str:
out = str(value)
if len(out) > 3000:
out = out[:1500] + "\n...\n" + out[-1500:]
return out


@_wrap_for_chunked_arrays
def array_cast(
array: pa.Array, pa_type: pa.DataType, allow_primitive_to_str: bool = True, allow_decimal_to_str: bool = True
Expand Down Expand Up @@ -1943,18 +1950,18 @@ def array_cast(
if pa.types.is_string(pa_type):
if not allow_primitive_to_str and pa.types.is_primitive(array.type):
raise TypeError(
f"Couldn't cast array of type {array.type} to {pa_type} "
f"Couldn't cast array of type {_short_str(array.type)} to {_short_str(pa_type)} "
f"since allow_primitive_to_str is set to {allow_primitive_to_str} "
)
if not allow_decimal_to_str and pa.types.is_decimal(array.type):
raise TypeError(
f"Couldn't cast array of type {array.type} to {pa_type} "
f"Couldn't cast array of type {_short_str(array.type)} to {_short_str(pa_type)} "
f"and allow_decimal_to_str is set to {allow_decimal_to_str}"
)
if pa.types.is_null(pa_type) and not pa.types.is_null(array.type):
raise TypeError(f"Couldn't cast array of type {array.type} to {pa_type}")
raise TypeError(f"Couldn't cast array of type {_short_str(array.type)} to {_short_str(pa_type)}")
return array.cast(pa_type)
raise TypeError(f"Couldn't cast array of type\n{array.type}\nto\n{pa_type}")
raise TypeError(f"Couldn't cast array of type {_short_str(array.type)} to {_short_str(pa_type)}")


@_wrap_for_chunked_arrays
Expand Down Expand Up @@ -2112,7 +2119,7 @@ def cast_array_to_feature(
allow_primitive_to_str=allow_primitive_to_str,
allow_decimal_to_str=allow_decimal_to_str,
)
raise TypeError(f"Couldn't cast array of type\n{array.type}\nto\n{feature}")
raise TypeError(f"Couldn't cast array of type\n{_short_str(array.type)}\nto\n{_short_str(feature)}")


@_wrap_for_chunked_arrays
Expand Down Expand Up @@ -2180,7 +2187,7 @@ def embed_array_storage(array: pa.Array, feature: "FeatureType"):
return pa.FixedSizeListArray.from_arrays(embedded_array_values, feature.length, mask=array.is_null())
if not isinstance(feature, (Sequence, dict, list, tuple)):
return array
raise TypeError(f"Couldn't embed array of type\n{array.type}\nwith\n{feature}")
raise TypeError(f"Couldn't embed array of type\n{_short_str(array.type)}\nwith\n{_short_str(feature)}")


class CastError(ValueError):
Expand All @@ -2201,11 +2208,11 @@ def details(self):
new_columns = set(self.table_column_names) - set(self.requested_column_names)
missing_columns = set(self.requested_column_names) - set(self.table_column_names)
if new_columns and missing_columns:
return f"there are {len(new_columns)} new columns ({', '.join(new_columns)}) and {len(missing_columns)} missing columns ({', '.join(missing_columns)})."
return f"there are {len(new_columns)} new columns ({_short_str(new_columns)}) and {len(missing_columns)} missing columns ({_short_str(missing_columns)})."
elif new_columns:
return f"there are {len(new_columns)} new columns ({new_columns})"
return f"there are {len(new_columns)} new columns ({_short_str(new_columns)})"
else:
return f"there are {len(missing_columns)} missing columns ({missing_columns})"
return f"there are {len(missing_columns)} missing columns ({_short_str(missing_columns)})"


def cast_table_to_features(table: pa.Table, features: "Features"):
Expand All @@ -2222,7 +2229,7 @@ def cast_table_to_features(table: pa.Table, features: "Features"):
"""
if sorted(table.column_names) != sorted(features):
raise CastError(
f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match",
f"Couldn't cast\n{_short_str(table.schema)}\nto\n{_short_str(features)}\nbecause column names don't match",
table_column_names=table.column_names,
requested_column_names=list(features),
)
Expand All @@ -2247,7 +2254,7 @@ def cast_table_to_schema(table: pa.Table, schema: pa.Schema):
features = Features.from_arrow_schema(schema)
if sorted(table.column_names) != sorted(features):
raise CastError(
f"Couldn't cast\n{table.schema}\nto\n{features}\nbecause column names don't match",
f"Couldn't cast\n{_short_str(table.schema)}\nto\n{_short_str(features)}\nbecause column names don't match",
table_column_names=table.column_names,
requested_column_names=list(features),
)
Expand Down