Skip to content

Commit 944a515

Browse files
committed
Code review update: set mask to None if the array to serialize is an arrow array
1 parent e76d9e3 commit 944a515

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

python/pyspark/sql/pandas/serializers.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,10 @@ def _create_batch(self, series):
160160
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
161161

162162
def create_array(s, t):
163-
mask = s.isnull()
163+
if hasattr(s.values, '__arrow_array__'):
164+
mask = None
165+
else:
166+
mask = s.isnull()
164167
# Ensure timestamp series are in expected form for Spark internal representation
165168
if t is not None and pa.types.is_timestamp(t) and t.tz is not None:
166169
s = _check_series_convert_timestamps_internal(s, self._timezone)
@@ -169,8 +172,6 @@ def create_array(s, t):
169172
elif is_categorical_dtype(s.dtype):
170173
# Note: This can be removed once minimum pyarrow version is >= 0.16.1
171174
s = s.astype(s.dtypes.categories.dtype)
172-
elif t is not None and pa.types.is_string(t):
173-
s = s.astype(str)
174175
try:
175176
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
176177
except ValueError as e:

python/pyspark/sql/tests/test_arrow.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -425,6 +425,13 @@ def test_createDataFrame_with_string_dtype(self):
425425
# Changing that to use a StringArray would be backwards incompatible.
426426
assert_frame_equal(pandas_df, df.toPandas(), check_dtype=False)
427427

428+
def test_createDataFrame_with_int64(self):
429+
# SPARK-34521: spark.createDataFrame should support pandas extension dtypes; this case covers nullable Int64
430+
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
431+
pandas_df = pd.DataFrame({"col": [1, 2, 3, None]}, dtype="Int64")
432+
df = self.spark.createDataFrame(pandas_df)
433+
assert_frame_equal(pandas_df, df.toPandas(), check_dtype=False)
434+
428435
def test_toPandas_with_map_type(self):
429436
pdf = pd.DataFrame({"id": [0, 1, 2, 3],
430437
"m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}]})

0 commit comments

Comments
 (0)