|
34 | 34 | if have_pandas:
|
35 | 35 | import pandas as pd
|
36 | 36 | from pandas.util.testing import assert_frame_equal
|
| 37 | + from distutils.version import LooseVersion |
| 38 | + pandas_version = LooseVersion(pd.__version__) |
37 | 39 |
|
38 | 40 | if have_pyarrow:
|
39 | 41 | import pyarrow as pa
|
@@ -442,25 +444,47 @@ def test_createDataFrame_with_category_type(self):
|
442 | 444 | self.assertIsInstance(arrow_first_category_element, str)
|
443 | 445 | self.assertIsInstance(spark_first_category_element, str)
|
444 | 446 |
|
445 |
| - def test_createDataFrame_from_string_extension_dtype(self): |
446 |
| - pdf = pd.DataFrame({u"A": [u"a", u"b", u"c", u"d"]}) |
447 |
| - pdf_ext_dtype = pd.DataFrame({u"A": [u"a", u"b", u"c", u"d"]}, dtype=pd.StringDtype()) |
| 447 | + def _assert_converted_dfs_equal(self, pdf1, pdf2): |
| 448 | + df1 = self.spark.createDataFrame(pdf1) |
| 449 | + df2 = self.spark.createDataFrame(pdf2) |
| 450 | + self.assertEqual(df1.schema, df2.schema) |
| 451 | + self.assertEqual(df1.collect(), df2.collect()) |
448 | 452 |
|
449 |
| - df = self.spark.createDataFrame(pdf) |
450 |
| - df_ext_dtype = self.spark.createDataFrame(pdf_ext_dtype) |
451 |
| - |
452 |
| - self.assertEqual(df_ext_dtype.schema, df.schema) |
453 |
| - self.assertEqual(df_ext_dtype.collect(), df.collect()) |
454 |
| - |
455 |
| - def test_createDataFrame_from_integer_extension_dtype(self): |
| 453 | + @unittest.skipIf(pandas_version < "0.24.0", "pandas < 0.24.0 missing Int64Dtype") |
| 454 | + def test_createDataFrame_with_pandas_integer_dtype(self): |
456 | 455 | pdf = pd.DataFrame({u"A": range(4)})
|
457 | 456 | pdf_ext_dtype = pd.DataFrame({u"A": range(4)}, dtype=pd.Int64Dtype())
|
| 457 | + self._assert_converted_dfs_equal(pdf, pdf_ext_dtype) |
| 458 | + |
| 459 | + @unittest.skipIf(pandas_version < "1.0.0", |
| 460 | + "pandas < 1.0.0 missing StringDtype and BooleanDtype") |
| 461 | + def test_createDataFrame_with_pandas_boolean_and_string_dtypes(self): |
| 462 | + pdf = pd.DataFrame({ |
| 463 | + u"A": pd.Series([0, 1, 2, 3]), |
| 464 | + u"B": pd.Series([u"a", u"b", u"c", u"d"]), |
| 465 | + u"C": pd.Series([True, False, True, False]), |
| 466 | + }) |
| 467 | + pdf_ext_dtype = pd.DataFrame({ |
| 468 | + u"A": pd.Series([0, 1, 2, 3], dtype=pd.Int64Dtype()), |
| 469 | + u"B": pd.Series([u"a", u"b", u"c", u"d"], dtype=pd.StringDtype()), |
| 470 | + u"C": pd.Series([True, False, True, False], dtype=pd.BooleanDtype()), |
| 471 | + }) |
| 472 | + self._assert_converted_dfs_equal(pdf, pdf_ext_dtype) |
| 473 | + |
| 474 | + @unittest.skipIf(pandas_version < "1.0.0", "pandas < 1.0.0 missing pd.NA") |
| 475 | + def test_createDataFrame_with_pd_NA_values(self): |
| 476 | + pdf = pd.DataFrame({ |
| 477 | + u"A": pd.Series([0, pd.NA, 2, 3]), |
| 478 | + u"B": pd.Series([pd.NA, u"b", u"c", u"d"]), |
| 479 | + u"C": pd.Series([True, False, pd.NA, False]), |
| 480 | + }) |
| 481 | + pdf_ext_dtype = pd.DataFrame({ |
| 482 | + u"A": pd.Series([0, pd.NA, 2, 3], dtype=pd.Int64Dtype()), |
| 483 | + u"B": pd.Series([pd.NA, u"b", u"c", u"d"], dtype=pd.StringDtype()), |
| 484 | + u"C": pd.Series([True, False, pd.NA, False], dtype=pd.BooleanDtype()), |
| 485 | + }) |
| 486 | + self._assert_converted_dfs_equal(pdf, pdf_ext_dtype) |
458 | 487 |
|
459 |
| - df = self.spark.createDataFrame(pdf) |
460 |
| - df_ext_dtype = self.spark.createDataFrame(pdf_ext_dtype) |
461 |
| - |
462 |
| - self.assertEqual(df_ext_dtype.schema, df.schema) |
463 |
| - self.assertEqual(df_ext_dtype.collect(), df.collect()) |
464 | 488 |
|
465 | 489 |
|
466 | 490 | @unittest.skipIf(
|
|
0 commit comments