
Commit 4a6a6cb

wesm authored and xhochy committed
ARROW-1359: [C++] Add flavor='spark' option to write_parquet that sanitizes schema field names
I also made the default for `use_deprecated_int96_timestamps` None so that we
can distinguish between unspecified and explicitly False. In the event that
the user passes `flavor='spark'`, this is enabled. Once Spark processes the
int96 deprecation in the future, we can remove this part.

Author: Wes McKinney <[email protected]>

Closes #1076 from wesm/ARROW-1359 and squashes the following commits:

8a60b66 [Wes McKinney] Use composition rather than inheritance
e3fa8ec [Wes McKinney] Add note about spark flavor to Sphinx docs
8159a51 [Wes McKinney] Add flavor='spark' option to write_parquet that sanitizes schema field names, turns on int96 timestamps
1 parent 947ca87 commit 4a6a6cb
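
A minimal standalone sketch of the defaulting logic described in the commit
message (the helper name resolve_int96 is hypothetical, not part of the
library):

def resolve_int96(use_deprecated_int96_timestamps=None, flavor=None):
    # None means "unspecified": let the flavor decide. An explicit
    # True/False from the user always wins over the flavor.
    if use_deprecated_int96_timestamps is None:
        return flavor is not None and 'spark' in flavor
    return use_deprecated_int96_timestamps


assert resolve_int96(flavor='spark') is True
assert resolve_int96(False, flavor='spark') is False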

File tree

5 files changed, +149 −18 lines


python/doc/source/parquet.rst

Lines changed: 7 additions & 0 deletions
@@ -217,6 +217,13 @@ such as those produced by Hive:
     dataset = pq.ParquetDataset('dataset_name/')
     table = dataset.read()
 
+Using with Spark
+----------------
+
+Spark places some constraints on the types of Parquet files it will read. The
+option ``flavor='spark'`` will set these compatibility options automatically
+and also sanitize field names containing characters unsupported by Spark SQL.
+
 Multithreaded Reads
 -------------------
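
A short usage sketch of the documented option (the file name is
hypothetical):

import pyarrow as pa
import pyarrow.parquet as pq

# ' ' and ';' in the field name are not accepted by Spark SQL
table = pa.Table.from_arrays([pa.array([1, 2, 3])], ['my field;'])

# flavor='spark' sanitizes the field name to 'my_field_' and enables
# INT96 timestamps before writing
pq.write_table(table, 'example.parquet', flavor='spark')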

python/pyarrow/parquet.py

Lines changed: 88 additions & 8 deletions
@@ -18,17 +18,17 @@
 import os
 import inspect
 import json
-
+import re
 import six
 
 import numpy as np
 
 from pyarrow.filesystem import FileSystem, LocalFileSystem, S3FSWrapper
 from pyarrow._parquet import (ParquetReader, FileMetaData,  # noqa
-                              RowGroupMetaData, ParquetSchema,
-                              ParquetWriter)
+                              RowGroupMetaData, ParquetSchema)
 import pyarrow._parquet as _parquet  # noqa
 import pyarrow.lib as lib
+import pyarrow as pa
 
 
 # ----------------------------------------------------------------------
@@ -164,6 +164,73 @@ def _get_column_indices(self, column_names, use_pandas_metadata=False):
         return indices
 
 
+_SPARK_DISALLOWED_CHARS = re.compile('[ ,;{}()\n\t=]')
+
+
+def _sanitized_spark_field_name(name):
+    return _SPARK_DISALLOWED_CHARS.sub('_', name)
+
+
+def _sanitize_schema(schema, flavor):
+    if 'spark' in flavor:
+        sanitized_fields = []
+
+        schema_changed = False
+
+        for field in schema:
+            name = field.name
+            sanitized_name = _sanitized_spark_field_name(name)
+
+            if sanitized_name != name:
+                schema_changed = True
+                sanitized_field = pa.field(sanitized_name, field.type,
+                                           field.nullable, field.metadata)
+                sanitized_fields.append(sanitized_field)
+            else:
+                sanitized_fields.append(field)
+        return pa.schema(sanitized_fields), schema_changed
+    else:
+        return schema, False
+
+
+def _sanitize_table(table, new_schema, flavor):
+    # TODO: This will not handle prohibited characters in nested field names
+    if 'spark' in flavor:
+        column_data = [table[i].data for i in range(table.num_columns)]
+        return pa.Table.from_arrays(column_data, schema=new_schema)
+    else:
+        return table
+
+
+class ParquetWriter(object):
+    """
+
+    Parameters
+    ----------
+    where
+    schema
+    flavor : {'spark', ...}
+        Set options for compatibility with a particular reader
+    """
+    def __init__(self, where, schema, flavor=None, **options):
+        self.flavor = flavor
+        if flavor is not None:
+            schema, self.schema_changed = _sanitize_schema(schema, flavor)
+        else:
+            self.schema_changed = False
+
+        self.schema = schema
+        self.writer = _parquet.ParquetWriter(where, schema, **options)
+
+    def write_table(self, table, row_group_size=None):
+        if self.schema_changed:
+            table = _sanitize_table(table, self.schema, self.flavor)
+        self.writer.write_table(table, row_group_size=row_group_size)
+
+    def close(self):
+        self.writer.close()
+
+
 def _get_pandas_index_columns(keyvalues):
     return (json.loads(keyvalues[b'pandas'].decode('utf8'))
             ['index_columns'])
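
To illustrate the substitution performed by `_sanitized_spark_field_name`, a
minimal sketch using the same pattern:

import re

# Same character class as _SPARK_DISALLOWED_CHARS above
disallowed = re.compile('[ ,;{}()\n\t=]')

print(disallowed.sub('_', 'prohib; ,\t{}'))  # prints 'prohib______'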
@@ -787,8 +854,9 @@ def read_pandas(source, columns=None, nthreads=1, metadata=None):
 
 def write_table(table, where, row_group_size=None, version='1.0',
                 use_dictionary=True, compression='snappy',
-                use_deprecated_int96_timestamps=False,
-                coerce_timestamps=None, **kwargs):
+                use_deprecated_int96_timestamps=None,
+                coerce_timestamps=None,
+                flavor=None, **kwargs):
     """
     Write a Table to Parquet format
 
@@ -804,15 +872,26 @@ def write_table(table, where, row_group_size=None, version='1.0',
     use_dictionary : bool or list
         Specify if we should use dictionary encoding in general or only for
         some columns.
-    use_deprecated_int96_timestamps : boolean, default False
-        Write nanosecond resolution timestamps to INT96 Parquet format
+    use_deprecated_int96_timestamps : boolean, default None
+        Write nanosecond resolution timestamps to INT96 Parquet
+        format. Defaults to False unless enabled by the flavor argument
     coerce_timestamps : string, default None
         Cast timestamps to a particular resolution.
         Valid values: {None, 'ms', 'us'}
     compression : str or dict
         Specify the compression codec, either on a general basis or per-column.
+    flavor : {'spark'}, default None
+        Sanitize schema or set other compatibility options for the given reader
     """
     row_group_size = kwargs.get('chunk_size', row_group_size)
+
+    if use_deprecated_int96_timestamps is None:
+        # Use int96 timestamps for Spark
+        if flavor is not None and 'spark' in flavor:
+            use_deprecated_int96_timestamps = True
+        else:
+            use_deprecated_int96_timestamps = False
+
     options = dict(
         use_dictionary=use_dictionary,
         compression=compression,
@@ -822,7 +901,8 @@ def write_table(table, where, row_group_size=None, version='1.0',
 
     writer = None
     try:
-        writer = ParquetWriter(where, table.schema, **options)
+        writer = ParquetWriter(where, table.schema, flavor=flavor,
+                               **options)
         writer.write_table(table, row_group_size=row_group_size)
     except:
         if writer is not None:
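
As a usage sketch of the new composition-based ParquetWriter (the file name
is hypothetical):

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.Table.from_arrays([pa.array([0, 1, 2])], ['x=y'])

# The writer sanitizes 'x=y' to 'x_y' and writes the table with the
# adjusted schema
writer = pq.ParquetWriter('spark_compat.parquet', table.schema,
                          flavor='spark')
writer.write_table(table)
writer.close()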

python/pyarrow/table.pxi

Lines changed: 24 additions & 7 deletions
@@ -758,7 +758,7 @@ cdef class Table:
         return cls.from_arrays(arrays, names=names, metadata=metadata)
 
     @staticmethod
-    def from_arrays(arrays, names=None, dict metadata=None):
+    def from_arrays(arrays, names=None, schema=None, dict metadata=None):
         """
         Construct a Table from Arrow arrays or columns
 
@@ -777,35 +777,52 @@ cdef class Table:
         """
         cdef:
             vector[shared_ptr[CColumn]] columns
-            shared_ptr[CSchema] schema
+            Schema cy_schema
+            shared_ptr[CSchema] c_schema
             shared_ptr[CTable] table
            int i, K = <int> len(arrays)
 
-        _schema_from_arrays(arrays, names, metadata, &schema)
+        if schema is None:
+            _schema_from_arrays(arrays, names, metadata, &c_schema)
+        elif schema is not None:
+            if names is not None:
+                raise ValueError('Cannot pass schema and arrays')
+            cy_schema = schema
+
+            if len(schema) != len(arrays):
+                raise ValueError('Schema and number of arrays unequal')
+
+            c_schema = cy_schema.sp_schema
 
         columns.reserve(K)
 
         for i in range(K):
             if isinstance(arrays[i], Array):
                 columns.push_back(
                     make_shared[CColumn](
-                        schema.get().field(i),
+                        c_schema.get().field(i),
                         (<Array> arrays[i]).sp_array
                     )
                 )
             elif isinstance(arrays[i], ChunkedArray):
                 columns.push_back(
                     make_shared[CColumn](
-                        schema.get().field(i),
+                        c_schema.get().field(i),
                         (<ChunkedArray> arrays[i]).sp_chunked_array
                     )
                 )
             elif isinstance(arrays[i], Column):
-                columns.push_back((<Column> arrays[i]).sp_column)
+                # Make sure schema field and column are consistent
+                columns.push_back(
+                    make_shared[CColumn](
+                        c_schema.get().field(i),
+                        (<Column> arrays[i]).sp_column.get().data()
+                    )
+                )
             else:
                 raise ValueError(type(arrays[i]))
 
-        table.reset(new CTable(schema, columns))
+        table.reset(new CTable(c_schema, columns))
         return pyarrow_wrap_table(table)
 
     @staticmethod
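
A small sketch of the new schema keyword (names and schema are mutually
exclusive):

import pyarrow as pa

arr = pa.array([1, 2, 3])
schema = pa.schema([pa.field('renamed', arr.type)])

# Attach the given schema's fields to the arrays instead of passing names
table = pa.Table.from_arrays([arr], schema=schema)
assert table.schema[0].name == 'renamed'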

python/pyarrow/tests/test_parquet.py

Lines changed: 26 additions & 2 deletions
@@ -17,6 +17,7 @@
 
 from os.path import join as pjoin
 import datetime
+import gc
 import io
 import os
 import json
@@ -562,6 +563,10 @@ def test_date_time_types():
     _check_roundtrip(table, expected=expected, version='2.0',
                      use_deprecated_int96_timestamps=True)
 
+    # Check that setting flavor to 'spark' uses int96 timestamps
+    _check_roundtrip(table, expected=expected, version='2.0',
+                     flavor='spark')
+
     # Unsupported stuff
     def _assert_unsupported(array):
         table = pa.Table.from_arrays([array], ['unsupported'])
@@ -576,6 +581,18 @@ def _assert_unsupported(array):
     _assert_unsupported(a7)
 
 
+@parquet
+def test_sanitized_spark_field_names():
+    a0 = pa.array([0, 1, 2, 3, 4])
+    name = 'prohib; ,\t{}'
+    table = pa.Table.from_arrays([a0], [name])
+
+    result = _roundtrip_table(table, flavor='spark')
+
+    expected_name = 'prohib______'
+    assert result.schema[0].name == expected_name
+
+
 @parquet
 def test_fixed_size_binary():
     t0 = pa.binary(10)
@@ -587,15 +604,19 @@ def test_fixed_size_binary():
     _check_roundtrip(table)
 
 
-def _check_roundtrip(table, expected=None, **params):
+def _roundtrip_table(table, **params):
     buf = io.BytesIO()
     _write_table(table, buf, **params)
     buf.seek(0)
 
+    return _read_table(buf)
+
+
+def _check_roundtrip(table, expected=None, **params):
     if expected is None:
         expected = table
 
-    result = _read_table(buf)
+    result = _roundtrip_table(table, **params)
     assert result.equals(expected)
 
 
@@ -1181,6 +1202,9 @@ def test_write_error_deletes_incomplete_file(tmpdir):
     except pa.ArrowException:
         pass
 
+    # Ensure that the writer object has been destructed; otherwise this
+    # test fails on Windows
+    gc.collect()
     assert not os.path.exists(filename)
 
 
python/pyarrow/types.pxi

Lines changed: 4 additions & 1 deletion
@@ -299,7 +299,6 @@ cdef class Schema:
         return self.schema.num_fields()
 
     def __getitem__(self, int i):
-
         cdef:
             Field result = Field()
             int num_fields = self.schema.num_fields()
@@ -318,6 +317,10 @@
 
         return result
 
+    def __iter__(self):
+        for i in range(len(self)):
+            yield self[i]
+
     def _check_null(self):
         if self.schema == NULL:
             raise ReferenceError(
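
With the new `__iter__`, a schema can be traversed directly; a brief sketch:

import pyarrow as pa

schema = pa.schema([pa.field('a', pa.int32()),
                    pa.field('b', pa.string())])

# Iterating a Schema now yields its Field objects in order
for field in schema:
    print(field.name, field.type)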
