[SPARK-31920][PYTHON] Fix pandas conversion using Arrow with __arrow_array__ columns #28743

Closed · wants to merge 8 commits
20 changes: 16 additions & 4 deletions python/pyspark/sql/pandas/conversion.py
@@ -382,7 +382,8 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):

from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
from pyspark.sql.types import TimestampType
from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type
from pyspark.sql.pandas.types import from_arrow_type, to_arrow_type, \
_try_arrow_array_protocol
from pyspark.sql.pandas.utils import require_minimum_pandas_version, \
require_minimum_pyarrow_version

@@ -394,10 +395,21 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):

# Create the Spark schema from list of names passed in with Arrow types
if isinstance(schema, (list, tuple)):
arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
inferred_types = []
for s in (pdf[c] for c in pdf):
s_array = _try_arrow_array_protocol(s[:0])
if s_array is not None:
t = s_array.type
if isinstance(t, pa.ExtensionType):
t = t.storage_type
else:
t = pa.infer_type(s, mask=s.isna(), from_pandas=True)
inferred_types.append(t)
struct = StructType()
for name, field in zip(schema, arrow_schema):
struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
for name, t in zip(schema, inferred_types):
# nullability is not determined on types inferred by Arrow or
# by the non-Arrow conversion path, so default to nullable
struct.add(name, from_arrow_type(t), nullable=True)
Member: Why don't we follow nullability anymore?

Author: infer_type only returns a type, not a field, which would presumably carry the nullability information. But it appears that the implementation of Schema.from_pandas (link) never actually inferred nullability and always returned the default nullable=True, so this change just follows the existing behaviour of Schema.from_pandas.

Member: Let's add a comment here to explain it?

Author: Sounds good, will update with a comment.

Alternatively, any(s.isna()) could be checked if we wanted to actively infer nullability here. That would change existing behavior, though, and would be inconsistent with the non-Arrow path, which likewise defaults inferred types to nullable:

fields = [StructField(k, _infer_type(v), True) for k, v in items]
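
For illustration only, a minimal sketch (not part of this PR) of what active inference could look like, reusing the names from the hunk above:

# Hypothetical alternative, NOT what this PR does: derive nullability
# from the data instead of defaulting to True.
struct = StructType()
for name, t, s in zip(schema, inferred_types, (pdf[c] for c in pdf)):
    struct.add(name, from_arrow_type(t), nullable=bool(s.isna().any()))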

schema = struct

# Determine arrow types to coerce data when creating batches
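As a side note on the empty-slice trick above: s[:0] produces a zero-length series whose __arrow_array__ result still carries the Arrow type, so the column type can be read without converting any data. A rough standalone sketch, assuming pandas >= 1.0.0 (where StringDtype implements __arrow_array__):

import pandas as pd
import pyarrow as pa

s = pd.Series(["a", None, "c"], dtype=pd.StringDtype())
# the zero-length slice yields an empty Arrow array whose .type reflects
# the extension dtype; no per-value conversion happens
t = s[:0].array.__arrow_array__(type=None).type
assert t == pa.string()

# NumPy-backed columns take the pa.infer_type fallback instead
s2 = pd.Series([1.5, None, 2.5])
assert pa.infer_type(s2, mask=s2.isna(), from_pandas=True) == pa.float64()
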
14 changes: 11 additions & 3 deletions python/pyspark/sql/pandas/serializers.py
@@ -142,7 +142,8 @@ def _create_batch(self, series):
"""
import pandas as pd
import pyarrow as pa
from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal
from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal, \
_try_arrow_array_protocol
from pandas.api.types import is_categorical
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
@@ -151,15 +152,22 @@
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

def create_array(s, t):
mask = s.isnull()
# Create with __arrow_array__ if the series' backing array implements it
array = _try_arrow_array_protocol(s, t)
if array is not None:
return array
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s, self._timezone)
elif is_categorical(s.dtype):
# Note: This can be removed once minimum pyarrow version is >= 0.16.1
s = s.astype(s.dtypes.categories.dtype)
try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
mask = s.isnull()
# pass _ndarray_values to avoid erroneous failed type checks from pandas array types
# that do not implement __arrow_array__ (i.e. pre-1.0.0 IntegerArray)
array = pa.Array.from_pandas(s._ndarray_values, mask=mask, type=t,
safe=self._safecheck)
except pa.ArrowException as e:
error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
"Array (%s). It can be caused by overflows or other unsafe " + \
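A rough standalone sketch of the fallback path above, for illustration (note that _ndarray_values is a pandas-internal attribute, deprecated in pandas 1.0):

import pandas as pd
import pyarrow as pa

s = pd.Series([1.0, None, 3.0])
# a NumPy-backed series has no __arrow_array__, so conversion falls back
# to from_pandas on the raw ndarray, with nulls supplied via the mask
arr = pa.Array.from_pandas(s._ndarray_values, mask=s.isnull(), type=pa.float64())
assert arr.null_count == 1
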
8 changes: 8 additions & 0 deletions python/pyspark/sql/pandas/types.py
@@ -268,3 +268,11 @@ def _check_series_convert_timestamps_tz_local(s, timezone):
:return pandas.Series where if it is a timestamp, has been converted to tz-naive
"""
return _check_series_convert_timestamps_localize(s, timezone, None)


def _try_arrow_array_protocol(s, t=None):
    """
    Convert a pandas Series to a pyarrow.Array via the __arrow_array__
    protocol if the series' backing array implements it; return None otherwise.
    """
    arrow_array = None
    s_array = getattr(s, 'array', s._values)
    if hasattr(s_array, "__arrow_array__"):
        arrow_array = s_array.__arrow_array__(type=t)
    return arrow_array
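
A quick usage sketch of the helper, assuming pandas >= 1.0.0 (where the nullable extension arrays implement __arrow_array__):

import pandas as pd
import pyarrow as pa
from pyspark.sql.pandas.types import _try_arrow_array_protocol

s = pd.Series([1, None, 3], dtype=pd.Int64Dtype())
arr = _try_arrow_array_protocol(s)
# pandas' IntegerArray implements __arrow_array__, so an Arrow array comes back
assert isinstance(arr, pa.Array) and arr.type == pa.int64()

# a plain NumPy-backed series does not implement the protocol, so the
# helper returns None and callers fall back to pa.Array.from_pandas
assert _try_arrow_array_protocol(pd.Series([1.0, 2.0])) is None
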
46 changes: 45 additions & 1 deletion python/pyspark/sql/tests/test_arrow.py
@@ -30,10 +30,13 @@
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
from pyspark.util import _exception_message
from distutils.version import LooseVersion

pandas_version = None
if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal
pandas_version = LooseVersion(pd.__version__)

if have_pyarrow:
import pyarrow as pa
@@ -415,7 +418,7 @@ def run_test(num_records, num_parts, max_records, use_delay=False):
for case in cases:
run_test(*case)

def test_createDateFrame_with_category_type(self):
def test_createDataFrame_with_category_type(self):
pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
pdf["B"] = pdf["A"].astype('category')
category_first_element = dict(enumerate(pdf['B'].cat.categories))[0]
@@ -442,6 +445,47 @@ def test_createDateFrame_with_category_type(self):
self.assertIsInstance(arrow_first_category_element, str)
self.assertIsInstance(spark_first_category_element, str)

def _assert_converted_dfs_equal(self, pdf1, pdf2):
df1 = self.spark.createDataFrame(pdf1)
df2 = self.spark.createDataFrame(pdf2)
self.assertEqual(df1.schema, df2.schema)
self.assertEqual(df1.collect(), df2.collect())

@unittest.skipIf(pandas_version is None or pandas_version < "0.24.0",
                 "pandas < 0.24.0 missing Int64Dtype")
def test_createDataFrame_with_pandas_integer_dtype(self):
pdf = pd.DataFrame({u"A": range(4)})
pdf_ext_dtype = pd.DataFrame({u"A": range(4)}, dtype=pd.Int64Dtype())
self._assert_converted_dfs_equal(pdf, pdf_ext_dtype)

@unittest.skipIf(pandas_version is None or pandas_version < "1.0.0",
                 "pandas < 1.0.0 missing StringDtype and BooleanDtype")
def test_createDataFrame_with_pandas_boolean_and_string_dtypes(self):
pdf = pd.DataFrame({
u"A": pd.Series([0, 1, 2, 3]),
u"B": pd.Series([u"a", u"b", u"c", u"d"]),
u"C": pd.Series([True, False, True, False]),
})
pdf_ext_dtype = pd.DataFrame({
u"A": pd.Series([0, 1, 2, 3], dtype=pd.Int64Dtype()),
u"B": pd.Series([u"a", u"b", u"c", u"d"], dtype=pd.StringDtype()),
u"C": pd.Series([True, False, True, False], dtype=pd.BooleanDtype()),
})
self._assert_converted_dfs_equal(pdf, pdf_ext_dtype)

@unittest.skipIf(pandas_version is None or pandas_version < "1.0.0",
                 "pandas < 1.0.0 missing pd.NA")
def test_createDataFrame_with_pd_NA_values(self):
pdf = pd.DataFrame({
u"A": pd.Series([0, pd.NA, 2, 3]),
u"B": pd.Series([pd.NA, u"b", u"c", u"d"]),
u"C": pd.Series([True, False, pd.NA, False]),
})
pdf_ext_dtype = pd.DataFrame({
u"A": pd.Series([0, pd.NA, 2, 3], dtype=pd.Int64Dtype()),
u"B": pd.Series([pd.NA, u"b", u"c", u"d"], dtype=pd.StringDtype()),
u"C": pd.Series([True, False, pd.NA, False], dtype=pd.BooleanDtype()),
})
self._assert_converted_dfs_equal(pdf, pdf_ext_dtype)


@unittest.skipIf(
not have_pandas or not have_pyarrow,