Skip to content

Commit d0cfa03

Browse files
authored
ENH: categorical scatter plot (#34293)
1 parent 984def2 commit d0cfa03

File tree

4 files changed

+76
-2
lines changed

4 files changed

+76
-2
lines changed

doc/source/user_guide/visualization.rst

+18
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,9 @@ These can be specified by the ``x`` and ``y`` keywords.
552552
.. ipython:: python
553553
554554
df = pd.DataFrame(np.random.rand(50, 4), columns=["a", "b", "c", "d"])
555+
df["species"] = pd.Categorical(
556+
["setosa"] * 20 + ["versicolor"] * 20 + ["virginica"] * 10
557+
)
555558
556559
@savefig scatter_plot.png
557560
df.plot.scatter(x="a", y="b");
@@ -579,6 +582,21 @@ each point:
579582
df.plot.scatter(x="a", y="b", c="c", s=50);
580583
581584
585+
.. ipython:: python
586+
:suppress:
587+
588+
plt.close("all")
589+
590+
If a categorical column is passed to ``c``, then a discrete colorbar will be produced:
591+
592+
.. versionadded:: 1.3.0
593+
594+
.. ipython:: python
595+
596+
@savefig scatter_plot_categorical.png
597+
df.plot.scatter(x="a", y="b", c="species", cmap="viridis", s=50);
598+
599+
582600
.. ipython:: python
583601
:suppress:
584602

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Other enhancements
5252
- :meth:`DataFrame.apply` can now accept NumPy unary operators as strings, e.g. ``df.apply("sqrt")``, which was already the case for :meth:`Series.apply` (:issue:`39116`)
5353
- :meth:`DataFrame.apply` can now accept non-callable DataFrame properties as strings, e.g. ``df.apply("size")``, which was already the case for :meth:`Series.apply` (:issue:`39116`)
5454
- :meth:`Series.apply` can now accept list-like or dictionary-like arguments that aren't lists or dictionaries, e.g. ``ser.apply(np.array(["sum", "mean"]))``, which was already the case for :meth:`DataFrame.apply` (:issue:`39140`)
55+
- :meth:`DataFrame.plot.scatter` can now accept a categorical column as the argument to ``c`` (:issue:`12380`, :issue:`31357`)
5556
- :meth:`.Styler.set_tooltips` allows on-hover tooltips to be added to styled HTML dataframes.
5657

5758
.. ---------------------------------------------------------------------------

pandas/plotting/_matplotlib/core.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pandas.util._decorators import cache_readonly
1111

1212
from pandas.core.dtypes.common import (
13+
is_categorical_dtype,
1314
is_extension_array_dtype,
1415
is_float,
1516
is_float_dtype,
@@ -388,6 +389,10 @@ def result(self):
388389
return self.axes[0]
389390

390391
def _convert_to_ndarray(self, data):
392+
# GH31357: categorical columns are processed separately
393+
if is_categorical_dtype(data):
394+
return data
395+
391396
# GH32073: cast to float if values contain nulled integers
392397
if (
393398
is_integer_dtype(data.dtype) or is_float_dtype(data.dtype)
@@ -974,7 +979,7 @@ def _plot_colorbar(self, ax: Axes, **kwds):
974979

975980
if mpl_ge_3_0_0():
976981
# The workaround below is no longer necessary.
977-
return
982+
return cbar
978983

979984
points = ax.get_position().get_points()
980985
cbar_points = cbar.ax.get_position().get_points()
@@ -992,6 +997,8 @@ def _plot_colorbar(self, ax: Axes, **kwds):
992997
# print(points[1, 1] - points[0, 1])
993998
# print(cbar_points[1, 1] - cbar_points[0, 1])
994999

1000+
return cbar
1001+
9951002

9961003
class ScatterPlot(PlanePlot):
9971004
_kind = "scatter"
@@ -1014,6 +1021,8 @@ def _make_plot(self):
10141021

10151022
c_is_column = is_hashable(c) and c in self.data.columns
10161023

1024+
color_by_categorical = c_is_column and is_categorical_dtype(self.data[c])
1025+
10171026
# pandas uses colormap, matplotlib uses cmap.
10181027
cmap = self.colormap or "Greys"
10191028
cmap = self.plt.cm.get_cmap(cmap)
@@ -1024,11 +1033,22 @@ def _make_plot(self):
10241033
c_values = self.plt.rcParams["patch.facecolor"]
10251034
elif color is not None:
10261035
c_values = color
1036+
elif color_by_categorical:
1037+
c_values = self.data[c].cat.codes
10271038
elif c_is_column:
10281039
c_values = self.data[c].values
10291040
else:
10301041
c_values = c
10311042

1043+
if color_by_categorical:
1044+
from matplotlib import colors
1045+
1046+
n_cats = len(self.data[c].cat.categories)
1047+
cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)])
1048+
bounds = np.linspace(0, n_cats, n_cats + 1)
1049+
norm = colors.BoundaryNorm(bounds, cmap.N)
1050+
else:
1051+
norm = None
10321052
# plot colorbar if
10331053
# 1. colormap is assigned, and
10341054
# 2.`c` is a column containing only numeric values
@@ -1045,11 +1065,15 @@ def _make_plot(self):
10451065
c=c_values,
10461066
label=label,
10471067
cmap=cmap,
1068+
norm=norm,
10481069
**self.kwds,
10491070
)
10501071
if cb:
10511072
cbar_label = c if c_is_column else ""
1052-
self._plot_colorbar(ax, label=cbar_label)
1073+
cbar = self._plot_colorbar(ax, label=cbar_label)
1074+
if color_by_categorical:
1075+
cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
1076+
cbar.ax.set_yticklabels(self.data[c].cat.categories)
10531077

10541078
if label is not None:
10551079
self._add_legend_handle(scatter, label)

pandas/tests/plotting/frame/test_frame.py

+31
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,37 @@ def test_scatterplot_object_data(self):
696696
_check_plot_works(df.plot.scatter, x="a", y="b")
697697
_check_plot_works(df.plot.scatter, x=0, y=1)
698698

699+
@pytest.mark.parametrize("ordered", [True, False])
@pytest.mark.parametrize(
    "categories",
    (["setosa", "versicolor", "virginica"], ["versicolor", "virginica", "setosa"]),
)
def test_scatterplot_color_by_categorical(self, ordered, categories):
    # Passing a categorical column as ``c`` should yield a discrete colorbar
    # with one band per category, a tick in the middle of each band, and the
    # tick labels following the declared category order (not the order of
    # appearance in the data), for both ordered and unordered categoricals.
    df = DataFrame(
        [[5.1, 3.5], [4.9, 3.0], [7.0, 3.2], [6.4, 3.2], [5.9, 3.0]],
        columns=["length", "width"],
    )
    df["species"] = pd.Categorical(
        ["setosa", "setosa", "virginica", "virginica", "versicolor"],
        ordered=ordered,
        categories=categories,
    )
    ax = df.plot.scatter(x=0, y=1, c="species")
    # Exactly one collection (the scatter itself) is expected on the axes;
    # the tuple-unpacking fails loudly if that is not the case.
    (colorbar_collection,) = ax.collections
    colorbar = colorbar_collection.colorbar

    # Ticks sit at the midpoint of each of the three category bands.
    expected_ticks = np.array([0.5, 1.5, 2.5])
    result_ticks = colorbar.get_ticks()
    tm.assert_numpy_array_equal(result_ticks, expected_ticks)

    # Band edges fall on the integer category codes 0..n_categories.
    expected_boundaries = np.array([0.0, 1.0, 2.0, 3.0])
    result_boundaries = colorbar._boundaries
    tm.assert_numpy_array_equal(result_boundaries, expected_boundaries)

    expected_yticklabels = categories
    result_yticklabels = [i.get_text() for i in colorbar.ax.get_ymajorticklabels()]
    # Compare the full lists: a zip()-based element-wise check stops at the
    # shorter sequence and would pass vacuously if labels were missing.
    assert result_yticklabels == expected_yticklabels
729+
699730
@pytest.mark.parametrize("x, y", [("x", "y"), ("y", "x"), ("y", "y")])
700731
def test_plot_scatter_with_categorical_data(self, x, y):
701732
# after fixing GH 18755, should be able to plot categorical data

0 commit comments

Comments
 (0)