
Commit e60e2d4

Updates

* Use infer_type over Schema.from_pandas for arrow type inference, as it can better handle extension types and pd.NA values
* Call __arrow_array__ directly if it is present to exit create_array early in _create_batch
* Add pandas version checks where required for tests
* Add tests covering pd.NA and BooleanDtype conversion

1 parent 04a15f6 commit e60e2d4
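
For context on the first bullet, a minimal sketch (assuming pandas >= 1.0 and pyarrow >= 0.15) of why per-column inference copes with pd.NA: pa.infer_type accepts an explicit null mask, so missing values are skipped during inference rather than tripping it up, which whole-frame inference over an object-cast copy could not guarantee.

    # Masked per-column inference, as used by the new code path below.
    import pandas as pd
    import pyarrow as pa

    s = pd.Series([1, pd.NA, 3], dtype=pd.Int64Dtype())
    print(pa.infer_type(s, mask=s.isna(), from_pandas=True))  # int64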

File tree

3 files changed: +54, -23 lines

python/pyspark/sql/pandas/conversion.py

Lines changed: 4 additions & 4 deletions

@@ -394,11 +394,11 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
 
         # Create the Spark schema from list of names passed in with Arrow types
         if isinstance(schema, (list, tuple)):
-            # Arrow < 0.17.0 cannot handle ExtensionDtype columns when inferring the schema
-            arrow_schema = pa.Schema.from_pandas(pdf.astype('object'), preserve_index=False)
+            inferred_types = [pa.infer_type(s, mask=s.isna(), from_pandas=True)
+                              for s in (pdf[c] for c in pdf)]
             struct = StructType()
-            for name, field in zip(schema, arrow_schema):
-                struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
+            for name, t in zip(schema, inferred_types):
+                struct.add(name, from_arrow_type(t), nullable=True)
             schema = struct
 
         # Determine arrow types to coerce data when creating batches
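
A hedged sketch of this new schema path in isolation, with a hypothetical two-column frame; from_arrow_type is the pyspark.sql.pandas.types helper the diff itself calls (Spark 3.0 layout). Each column's Arrow type is inferred with a null mask, then added to the StructType as nullable.

    import pandas as pd
    import pyarrow as pa
    from pyspark.sql.types import StructType
    from pyspark.sql.pandas.types import from_arrow_type

    pdf = pd.DataFrame({
        "A": pd.array([1, pd.NA], dtype="Int64"),
        "B": pd.array(["x", pd.NA], dtype="string"),
    })
    struct = StructType()
    for name in pdf:
        s = pdf[name]
        t = pa.infer_type(s, mask=s.isna(), from_pandas=True)
        struct.add(name, from_arrow_type(t), nullable=True)
    print(struct)  # A -> LongType, B -> StringType, both nullable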

python/pyspark/sql/pandas/serializers.py

Lines changed: 11 additions & 4 deletions

@@ -141,6 +141,7 @@ def _create_batch(self, series):
         :return: Arrow RecordBatch
         """
         import pandas as pd
+        from pandas.api.types import is_categorical_dtype
         import pyarrow as pa
         from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal
         # Make input conform to [(series1, type1), (series2, type2), ...]
@@ -150,16 +151,22 @@ def _create_batch(self, series):
         series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
 
         def create_array(s, t):
-            # If the series implements __arrow_array__, conversion will fail if a mask is passed
-            mask = s.isnull() if not hasattr(s.values, "__arrow_array__") else None
+            # Create with __arrow_array__ if the series' backing array implements it
+            series_array = getattr(s, 'array', s._values)
+            if hasattr(series_array, "__arrow_array__"):
+                return series_array.__arrow_array__(type=t)
+
             # Ensure timestamp series are in expected form for Spark internal representation
             if t is not None and pa.types.is_timestamp(t):
                 s = _check_series_convert_timestamps_internal(s, self._timezone)
-            elif type(s.dtype) == pd.CategoricalDtype:
+            elif is_categorical_dtype(s.dtype):
                 # Note: This can be removed once minimum pyarrow version is >= 0.16.1
                 s = s.astype(s.dtypes.categories.dtype)
             try:
-                array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
+                mask = s.isnull()
+                # pass _ndarray_values to avoid potential failed type checks from pandas array types
+                array = pa.Array.from_pandas(s._ndarray_values, mask=mask, type=t,
+                                             safe=self._safecheck)
             except pa.ArrowException as e:
                 error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
                             "Array (%s). It can be caused by overflows or other unsafe " + \

python/pyspark/sql/tests/test_arrow.py

Lines changed: 39 additions & 15 deletions

@@ -34,6 +34,8 @@
 if have_pandas:
     import pandas as pd
     from pandas.util.testing import assert_frame_equal
+    from distutils.version import LooseVersion
+    pandas_version = LooseVersion(pd.__version__)
 
 if have_pyarrow:
     import pyarrow as pa
@@ -442,25 +444,47 @@ def test_createDataFrame_with_category_type(self):
         self.assertIsInstance(arrow_first_category_element, str)
         self.assertIsInstance(spark_first_category_element, str)
 
-    def test_createDataFrame_from_string_extension_dtype(self):
-        pdf = pd.DataFrame({u"A": [u"a", u"b", u"c", u"d"]})
-        pdf_ext_dtype = pd.DataFrame({u"A": [u"a", u"b", u"c", u"d"]}, dtype=pd.StringDtype())
+    def _assert_converted_dfs_equal(self, pdf1, pdf2):
+        df1 = self.spark.createDataFrame(pdf1)
+        df2 = self.spark.createDataFrame(pdf2)
+        self.assertEqual(df1.schema, df2.schema)
+        self.assertEqual(df1.collect(), df2.collect())
 
-        df = self.spark.createDataFrame(pdf)
-        df_ext_dtype = self.spark.createDataFrame(pdf_ext_dtype)
-
-        self.assertEqual(df_ext_dtype.schema, df.schema)
-        self.assertEqual(df_ext_dtype.collect(), df.collect())
-
-    def test_createDataFrame_from_integer_extension_dtype(self):
+    @unittest.skipIf(pandas_version < "0.24.0", "pandas < 0.24.0 missing Int64Dtype")
+    def test_createDataFrame_with_pandas_integer_dtype(self):
         pdf = pd.DataFrame({u"A": range(4)})
         pdf_ext_dtype = pd.DataFrame({u"A": range(4)}, dtype=pd.Int64Dtype())
+        self._assert_converted_dfs_equal(pdf, pdf_ext_dtype)
+
+    @unittest.skipIf(pandas_version < "1.0.0",
+                     "pandas < 1.0.0 missing StringDtype and BooleanDtype")
+    def test_createDataFrame_with_pandas_boolean_and_string_dtypes(self):
+        pdf = pd.DataFrame({
+            u"A": pd.Series([0, 1, 2, 3]),
+            u"B": pd.Series([u"a", u"b", u"c", u"d"]),
+            u"C": pd.Series([True, False, True, False]),
+        })
+        pdf_ext_dtype = pd.DataFrame({
+            u"A": pd.Series([0, 1, 2, 3], dtype=pd.Int64Dtype()),
+            u"B": pd.Series([u"a", u"b", u"c", u"d"], dtype=pd.StringDtype()),
+            u"C": pd.Series([True, False, True, False], dtype=pd.BooleanDtype()),
+        })
+        self._assert_converted_dfs_equal(pdf, pdf_ext_dtype)
+
+    @unittest.skipIf(pandas_version < "1.0.0", "pandas < 1.0.0 missing pd.NA")
+    def test_createDataFrame_with_pd_NA_values(self):
+        pdf = pd.DataFrame({
+            u"A": pd.Series([0, pd.NA, 2, 3]),
+            u"B": pd.Series([pd.NA, u"b", u"c", u"d"]),
+            u"C": pd.Series([True, False, pd.NA, False]),
+        })
+        pdf_ext_dtype = pd.DataFrame({
+            u"A": pd.Series([0, pd.NA, 2, 3], dtype=pd.Int64Dtype()),
+            u"B": pd.Series([pd.NA, u"b", u"c", u"d"], dtype=pd.StringDtype()),
+            u"C": pd.Series([True, False, pd.NA, False], dtype=pd.BooleanDtype()),
+        })
+        self._assert_converted_dfs_equal(pdf, pdf_ext_dtype)
 
-        df = self.spark.createDataFrame(pdf)
-        df_ext_dtype = self.spark.createDataFrame(pdf_ext_dtype)
-
-        self.assertEqual(df_ext_dtype.schema, df.schema)
-        self.assertEqual(df_ext_dtype.collect(), df.collect())
 
 
     @unittest.skipIf(
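
And a hypothetical end-to-end check mirroring what the new tests assert: with Arrow conversion enabled, a frame built on pandas nullable dtypes (pd.NA included) should convert the same way as its plain NumPy-dtype twin. The config key assumes Spark 3.0 naming.

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

    pdf = pd.DataFrame({"A": pd.Series([0, pd.NA, 2, 3], dtype=pd.Int64Dtype())})
    spark.createDataFrame(pdf).show()  # the pd.NA row arrives as null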
