
Commit f6cd385

benhurdelhey authored and zhengruifeng committed
[SPARK-52943][PYTHON] Enable arrow_cast for all pandas UDF eval types
### What changes were proposed in this pull request?

- This enables `arrow_cast` for all pandas UDFs.
- `arrow_cast=True` provides coherent type-coercion behavior; it is somewhat more lenient for mismatched types.
- `arrow_cast` was originally introduced in #41800, but until now it only applied to a subset of `udf` and `pandas_udf` eval types (see below).
- This should have no performance impact, as the cast is only attempted in a second pass when the pandas-to-Arrow conversion fails.

### Why are the changes needed?

- This aligns `pandas_udf()` behavior with `udf(useArrow=True)` behavior, making PySpark more consistent.

### Does this PR introduce _any_ user-facing change?

Yes, see the updated table in [functions.py](https://github.com/apache/spark/compare/benrobby:enable-arrow-cast). TL;DR: this change is additive; it does not break workloads. It makes some pandas-to-Arrow conversions more lenient. We now support:

- int <-> decimal
- float <-> decimal
- strings containing numbers <-> int, uint, float

Affected UDF types:

- Eval types that already had `arrow_cast` enabled before this PR:
  - `SQL_ARROW_TABLE_UDF`
  - `SQL_ARROW_BATCHED_UDF`
- All pandas_udf eval types adopt `arrow_cast=True` with this PR:
  - `SQL_SCALAR_PANDAS_UDF`
  - `SQL_SCALAR_PANDAS_ITER_UDF`
  - `SQL_GROUPED_MAP_PANDAS_UDF`
  - `SQL_MAP_PANDAS_ITER_UDF`
  - `SQL_GROUPED_AGG_PANDAS_UDF`
  - `SQL_WINDOW_AGG_PANDAS_UDF`
  - `SQL_COGROUPED_MAP_PANDAS_UDF`
  - `SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE`
  - `SQL_TRANSFORM_WITH_STATE_PANDAS_UDF`
- Unaffected:
  - Batched UDFs (`useArrow=False`)
  - All other pure Arrow UDFs (`SQL_SCALAR_ARROW_UDF`, `SQL_SCALAR_ARROW_ITER_UDF`, `SQL_GROUPED_AGG_ARROW_UDF`, `SQL_GROUPED_MAP_ARROW_UDF`, `SQL_COGROUPED_MAP_ARROW_UDF`). For UDFs returning Arrow data directly, the expectation is that users supply exactly the right types.

### How was this patch tested?

- Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51635 from benrobby/enable-arrow-cast.

Authored-by: Ben Hurdelhey <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent ea8b6fd · commit f6cd385

10 files changed: +233 −62 lines
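To make the user-facing change concrete, here is a small, hedged illustration (it assumes an active `SparkSession` bound to `spark`; the UDF name is made up for the example): a scalar pandas UDF that returns numeric strings for a `long` column now succeeds, because the failed direct pandas-to-Arrow conversion is retried with an explicit Arrow cast.

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf

# Hypothetical example: the declared return type is "long", but the UDF
# returns numeric strings. With arrow_cast enabled for SQL_SCALAR_PANDAS_UDF,
# Arrow casts "0" -> 0, "1" -> 1, ... instead of raising a conversion error.
@pandas_udf("long")
def parse_id(s: pd.Series) -> pd.Series:
    return s.astype(str)  # object dtype holding numeric strings

spark.range(3).select(parse_id("id")).show()
# Before this change, SQL_SCALAR_PANDAS_UDF raised
# "Exception thrown when converting pandas.Series" here.
```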

python/pyspark/sql/pandas/functions.py

Lines changed: 20 additions & 20 deletions
Large diffs are not rendered by default.

python/pyspark/sql/pandas/serializers.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1103,6 +1103,7 @@ def __init__(
             safecheck,
             assign_cols_by_name,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+            arrow_cast=True,
         )
         self.pickleSer = CPickleSerializer()
         self.utf8_deserializer = UTF8Deserializer()
@@ -1483,6 +1484,7 @@ def __init__(
             safecheck,
             assign_cols_by_name,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+            arrow_cast=True,
         )
         self.arrow_max_records_per_batch = arrow_max_records_per_batch
         self.key_offsets = None
```
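For context, the fallback that `arrow_cast=True` enables can be sketched as follows. This is a simplified illustration of the pattern, not the actual serializer code; `to_arrow_with_fallback` is a made-up name. The direct conversion is attempted first, which is why the extra cast costs nothing on the happy path.

```python
import pandas as pd
import pyarrow as pa

def to_arrow_with_fallback(
    series: pd.Series, arrow_type: pa.DataType, safecheck: bool = True
) -> pa.Array:
    # First attempt: direct pandas -> Arrow conversion, as before this change.
    try:
        return pa.Array.from_pandas(series, type=arrow_type, safe=safecheck)
    except (pa.ArrowInvalid, pa.ArrowTypeError, ValueError, TypeError):
        # Second attempt (arrow_cast): build an array without a target type,
        # then cast explicitly, which is more lenient for mismatched types.
        return pa.Array.from_pandas(series).cast(arrow_type, safe=safecheck)

# Strings holding numbers can now become integers:
print(to_arrow_with_fallback(pd.Series(["123", "456"]), pa.int64()))
```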

python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py

Lines changed: 1 addition & 1 deletion
Since `"2.0"` now casts successfully to `double` under `arrow_cast`, the incompatible-type fixture switches to a genuinely non-numeric string:

```diff
@@ -262,7 +262,7 @@ def check_apply_in_pandas_returning_incompatible_type(self):
             "`spark.sql.execution.pandas.convertToArrowArraySafely`."
         )
         self._test_merge_error(
-            fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["2.0"]}),
+            fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": ["test_string"]}),
             output_schema="id long, k double",
             errorClass=PythonException,
             error_message_regex=expected,
```

python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py

Lines changed: 46 additions & 1 deletion
Same fixture fix as above (a stringified mean would now cast cleanly to `double`), plus two new tests covering numeric-to-decimal and string-to-numeric casts:

```diff
@@ -371,7 +371,7 @@ def check_apply_in_pandas_returning_incompatible_type(self):
         )
         with self.assertRaisesRegex(PythonException, expected + "\n"):
             self._test_apply_in_pandas(
-                lambda key, pdf: pd.DataFrame([key + (str(pdf.v.mean()),)]),
+                lambda key, pdf: pd.DataFrame([key + ("test_string",)]),
                 output_schema="id long, mean double",
             )
@@ -900,6 +900,51 @@ def _test_apply_in_pandas_returning_empty_dataframe_error(self, empty_df, error)
         with self.assertRaisesRegex(PythonException, error):
             self._test_apply_in_pandas_returning_empty_dataframe(empty_df)

+    def test_arrow_cast_enabled_numeric_to_decimal(self):
+        import numpy as np
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+        df = self.spark.range(2).repartition(1)
+
+        for column in columns:
+            with self.subTest(column=column):
+                v = pdf[column].iloc[:1]
+                schema_str = "id long, value decimal(10,0)"
+
+                @pandas_udf(schema_str, PandasUDFType.GROUPED_MAP)
+                def test(pdf):
+                    return pdf.assign(**{"value": v})
+
+                row = df.groupby("id").apply(test).first()
+                res = row[1]
+                self.assertEqual(res, Decimal("1"))
+
+    def test_arrow_cast_enabled_str_to_numeric(self):
+        df = self.spark.range(2).repartition(1)
+
+        types = ["int", "long", "float", "double"]
+
+        for type_str in types:
+            with self.subTest(type=type_str):
+                schema_str = "id long, value " + type_str
+
+                @pandas_udf(schema_str, PandasUDFType.GROUPED_MAP)
+                def test(pdf):
+                    return pdf.assign(value=pd.Series(["123"]))
+
+                row = df.groupby("id").apply(test).first()
+                self.assertEqual(row[1], 123)
+

 class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
     pass
```

python/pyspark/sql/tests/pandas/test_pandas_map.py

Lines changed: 30 additions & 16 deletions
The old "double to string" TypeError case is dropped because `arrow_cast` now performs that conversion; it is replaced by a "float to int precision loss" case whose outcome depends on `convertToArrowArraySafely`:

```diff
@@ -276,16 +276,17 @@ def test_dataframes_with_incompatible_types(self):
         self.check_dataframes_with_incompatible_types()

     def check_dataframes_with_incompatible_types(self):
-        def func(iterator):
-            for pdf in iterator:
-                yield pdf.assign(id=pdf["id"].apply(str))
-
         for safely in [True, False]:
             with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
                 {"spark.sql.execution.pandas.convertToArrowArraySafely": safely}
             ):
                 # sometimes we see ValueErrors
                 with self.subTest(convert="string to double"):
+
+                    def func(iterator):
+                        for pdf in iterator:
+                            yield pdf.assign(id="test_string")
+
                     expected = (
                         r"ValueError: Exception thrown when converting pandas.Series "
                         r"\(object\) with name 'id' to Arrow Array \(double\)."
@@ -304,18 +305,31 @@ def func(iterator):
                         .collect()
                     )

-                # sometimes we see TypeErrors
-                with self.subTest(convert="double to string"):
-                    with self.assertRaisesRegex(
-                        PythonException,
-                        r"TypeError: Exception thrown when converting pandas.Series "
-                        r"\(float64\) with name 'id' to Arrow Array \(string\).\n",
-                    ):
-                        (
-                            self.spark.range(10, numPartitions=3)
-                            .select(col("id").cast("double"))
-                            .mapInPandas(self.identity_dataframes_iter("id"), "id string")
-                            .collect()
+                with self.subTest(convert="float to int precision loss"):
+
+                    def func(iterator):
+                        for pdf in iterator:
+                            yield pdf.assign(id=pdf["id"] + 0.1)
+
+                    df = (
+                        self.spark.range(10, numPartitions=3)
+                        .select(col("id").cast("double"))
+                        .mapInPandas(func, "id int")
+                    )
+                    if safely:
+                        expected = (
+                            r"ValueError: Exception thrown when converting pandas.Series "
+                            r"\(float64\) with name 'id' to Arrow Array \(int32\)."
+                            " It can be caused by overflows or other "
+                            "unsafe conversions warned by Arrow. Arrow safe type check "
+                            "can be disabled by using SQL config "
+                            "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+                        )
+                        with self.assertRaisesRegex(PythonException, expected + "\n"):
+                            df.collect()
+                    else:
+                        self.assertEqual(
+                            df.collect(), self.spark.range(10, numPartitions=3).collect()
                         )

     def test_empty_iterator(self):
```
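The new precision-loss case shows that the lenient cast still honors `convertToArrowArraySafely`. A hedged walk-through of the two outcomes the test pins down (assumes an active `spark` session; `bump` is an illustrative name):

```python
from pyspark.sql.functions import col

def bump(iterator):
    for pdf in iterator:
        yield pdf.assign(id=pdf["id"] + 0.1)  # float values for an int column

df = (
    spark.range(10, numPartitions=3)
    .select(col("id").cast("double"))
    .mapInPandas(bump, "id int")
)

# convertToArrowArraySafely=true  -> the lossy float64 -> int32 cast raises.
# convertToArrowArraySafely=false -> the cast truncates; collect() yields 0..9.
spark.conf.set("spark.sql.execution.pandas.convertToArrowArraySafely", False)
print(df.collect())
```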

python/pyspark/sql/tests/pandas/test_pandas_udf.py

Lines changed: 14 additions & 21 deletions
```diff
@@ -376,27 +376,20 @@ def high_precision_udf(column):
             values = [1, 2, 3]
             return pd.Series([values[int(val) % len(values)] for val in column])

-        with self.sql_conf(
-            {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
-        ):
-            result = df.withColumn("decimal_val", high_precision_udf("id")).collect()
-            self.assertEqual(len(result), 3)
-            self.assertEqual(result[0]["decimal_val"], Decimal("1.0"))
-            self.assertEqual(result[1]["decimal_val"], Decimal("2.0"))
-            self.assertEqual(result[2]["decimal_val"], Decimal("3.0"))
-
-        with self.sql_conf(
-            {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
-        ):
-            # Also not supported.
-            # This can be fixed by enabling arrow_cast
-            # This is currently not the case for SQL_SCALAR_PANDAS_UDF and
-            # SQL_SCALAR_PANDAS_ITER_UDF.
-            self.assertRaisesRegex(
-                PythonException,
-                "Exception thrown when converting pandas.Series",
-                df.withColumn("decimal_val", high_precision_udf("id")).collect,
-            )
+        for intToDecimalCoercionEnabled in [True, False]:
+            # arrow_cast is enabled by default for SQL_SCALAR_PANDAS_UDF and
+            # SQL_SCALAR_PANDAS_ITER_UDF; Arrow can do this cast safely, so
+            # intToDecimalCoercionEnabled is not required for this case.
+            with self.sql_conf(
+                {
+                    "spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": intToDecimalCoercionEnabled  # noqa: E501
+                }
+            ):
+                result = df.withColumn("decimal_val", high_precision_udf("id")).collect()
+                self.assertEqual(len(result), 3)
+                self.assertEqual(result[0]["decimal_val"], Decimal("1.0"))
+                self.assertEqual(result[1]["decimal_val"], Decimal("2.0"))
+                self.assertEqual(result[2]["decimal_val"], Decimal("3.0"))

     def test_pandas_udf_timestamp_ntz(self):
         # SPARK-36626: Test TimestampNTZ in pandas UDF
```
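In practice this means a scalar pandas UDF can target a decimal column without flipping `intToDecimalCoercionEnabled`. A brief, hedged usage sketch (assumes an active `spark` session; `plus_one` is an illustrative name):

```python
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("decimal(10,0)")
def plus_one(v: pd.Series) -> pd.Series:
    return v + 1  # plain int64 values; Arrow casts them to decimal(10,0)

# Works regardless of
# spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled:
spark.range(3).select(plus_one("id")).show()
```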

python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py

Lines changed: 43 additions & 0 deletions
```diff
@@ -718,6 +718,49 @@ def biased_sum(v, w=None):
             aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s"))
         )

+    def test_arrow_cast_enabled_numeric_to_decimal(self):
+        import numpy as np
+        from decimal import Decimal
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+        df = self.spark.range(2).repartition(1)
+
+        for column in columns:
+            with self.subTest(column=column):
+
+                @pandas_udf("decimal(10,0)", PandasUDFType.GROUPED_AGG)
+                def test(series):
+                    return pdf[column].iloc[0]
+
+                row = df.groupby("id").agg(test(df.id)).first()
+                res = row[1]
+                self.assertEqual(res, Decimal("1"))
+
+    def test_arrow_cast_enabled_str_to_numeric(self):
+        df = self.spark.range(2).repartition(1)
+
+        types = ["int", "long", "float", "double"]
+
+        for type_str in types:
+            with self.subTest(type=type_str):
+
+                @pandas_udf(type_str, PandasUDFType.GROUPED_AGG)
+                def test(series):
+                    return 123
+
+                row = df.groupby("id").agg(test(df.id)).first()
+                self.assertEqual(row[1], 123)
+

 class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
```

python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py

Lines changed: 30 additions & 0 deletions
```diff
@@ -1875,6 +1875,36 @@ def test_udf(a, b=0):
             with self.subTest(with_b=True, query_no=i):
                 assertDataFrameEqual(df, [Row(0), Row(101)])

+    def test_arrow_cast_enabled_numeric_to_decimal(self):
+        import numpy as np
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+        df = self.spark.range(2).repartition(1)
+
+        t = DecimalType(10, 0)
+        for column in columns:
+            with self.subTest(column=column):
+                v = pdf[column].iloc[:1]
+                row = df.select(pandas_udf(lambda _: v, t)(df.id)).first()
+                assert (row[0] == v).all()
+
+    def test_arrow_cast_enabled_str_to_numeric(self):
+        df = self.spark.range(2).repartition(1)
+        for t in [IntegerType(), LongType(), FloatType(), DoubleType()]:
+            with self.subTest(type=t):
+                row = df.select(pandas_udf(lambda _: pd.Series(["123"]), t)(df.id)).first()
+                assert row[0] == 123
+

 class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
```

python/pyspark/sql/tests/pandas/test_pandas_udf_window.py

Lines changed: 45 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@

 import unittest
 from typing import cast
+from decimal import Decimal

 from pyspark.errors import AnalysisException, PythonException
 from pyspark.sql.functions import (
@@ -33,6 +34,13 @@
     PandasUDFType,
 )
 from pyspark.sql.window import Window
+from pyspark.sql.types import (
+    DecimalType,
+    IntegerType,
+    LongType,
+    FloatType,
+    DoubleType,
+)
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -563,6 +571,43 @@ def weighted_mean(**kwargs):
                 )
             ).show()

+    def test_arrow_cast_numeric_to_decimal(self):
+        import numpy as np
+        import pandas as pd
+
+        columns = [
+            "int8",
+            "int16",
+            "int32",
+            "uint8",
+            "uint16",
+            "uint32",
+            "float64",
+        ]
+
+        pdf = pd.DataFrame({key: np.arange(1, 2).astype(key) for key in columns})
+        df = self.data
+        w = self.unbounded_window
+
+        t = DecimalType(10, 0)
+        for column in columns:
+            with self.subTest(column=column):
+                value = pdf[column].iloc[0]
+                mean_udf = pandas_udf(lambda v: value, t, PandasUDFType.GROUPED_AGG)
+                result = df.select(mean_udf(df["v"]).over(w)).first()[0]
+                assert result == Decimal("1.0")
+                assert type(result) == Decimal
+
+    def test_arrow_cast_str_to_numeric(self):
+        df = self.data
+        w = self.unbounded_window
+
+        for t in [IntegerType(), LongType(), FloatType(), DoubleType()]:
+            with self.subTest(type=t):
+                mean_udf = pandas_udf(lambda v: "123", t, PandasUDFType.GROUPED_AGG)
+                result = df.select(mean_udf(df["v"]).over(w)).first()[0]
+                assert result == 123
+

 class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
```

python/pyspark/worker.py

Lines changed: 2 additions & 3 deletions
The `arrow_cast` flag, previously enabled only for `SQL_ARROW_BATCHED_UDF`, is removed; the serializer now always receives `True`:

```diff
@@ -2285,6 +2285,7 @@ def read_udfs(pickleSer, infile, eval_type):
             safecheck,
             _assign_cols_by_name,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+            arrow_cast=True,
         )
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
         arrow_max_records_per_batch = runner_conf.get(
@@ -2374,8 +2375,6 @@ def read_udfs(pickleSer, infile, eval_type):
             "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
         )
         ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
-        # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion
-        arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
         # Arrow-optimized Python UDF takes input types
         input_types = (
             [f.dataType for f in _parse_datatype_json_string(utf8_deserializer.loads(infile))]
@@ -2390,7 +2389,7 @@ def read_udfs(pickleSer, infile, eval_type):
             df_for_struct,
             struct_in_pandas,
             ndarray_as_list,
-            arrow_cast,
+            True,
             input_types,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
```
