Skip to content

Commit 1927652

Browse files
authored
ENH: Styler.apply accept ndarray return with axis=None for consistency (#39393)
1 parent 9705434 commit 1927652

File tree

4 files changed

+78
-71
lines changed

4 files changed

+78
-71
lines changed

doc/source/user_guide/style.ipynb

Lines changed: 26 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,10 @@
140140
"metadata": {},
141141
"outputs": [],
142142
"source": [
143-
"s = df.style.set_table_attributes('class=\"table-cls\"')\n",
144-
"cls = pd.DataFrame(data=[['cls1', None], ['cls3', 'cls2 cls3']], index=[0,2], columns=['A', 'C'])\n",
145-
"s.set_td_classes(cls)"
143+
"css_classes = pd.DataFrame(data=[['cls1', None], ['cls3', 'cls2 cls3']], index=[0,2], columns=['A', 'C'])\n",
144+
"df.style.\\\n",
145+
" set_table_attributes('class=\"table-cls\"').\\\n",
146+
" set_td_classes(css_classes)"
146147
]
147148
},
148149
{
@@ -314,13 +315,10 @@
314315
"outputs": [],
315316
"source": [
316317
"def color_negative_red(val):\n",
317-
" \"\"\"\n",
318-
" Takes a scalar and returns a string with\n",
319-
" the css property `'color: red'` for negative\n",
320-
" strings, black otherwise.\n",
321-
" \"\"\"\n",
322-
" color = 'red' if val < 0 else 'black'\n",
323-
" return 'color: %s' % color"
318+
" \"\"\"Color negative scalars red.\"\"\"\n",
319+
" css = 'color: red;'\n",
320+
" if val < 0: return css\n",
321+
" return None"
324322
]
325323
},
326324
{
@@ -368,11 +366,9 @@
368366
"outputs": [],
369367
"source": [
370368
"def highlight_max(s):\n",
371-
" '''\n",
372-
" highlight the maximum in a Series yellow.\n",
373-
" '''\n",
374-
" is_max = s == s.max()\n",
375-
" return ['background-color: yellow' if v else '' for v in is_max]"
369+
" \"\"\"Highlight the maximum in a Series bold-orange.\"\"\"\n",
370+
" css = 'background-color: orange; font-weight: bold;'\n",
371+
" return np.where(s == np.nanmax(s.values), css, None)"
376372
]
377373
},
378374
{
@@ -384,11 +380,20 @@
384380
"df.style.apply(highlight_max)"
385381
]
386382
},
383+
{
384+
"cell_type": "code",
385+
"execution_count": null,
386+
"metadata": {},
387+
"outputs": [],
388+
"source": [
389+
"df.style.apply(highlight_max, axis=1)"
390+
]
391+
},
387392
{
388393
"cell_type": "markdown",
389394
"metadata": {},
390395
"source": [
391-
"In this case the input is a `Series`, one column at a time.\n",
396+
"In this case the input is a `Series`, one column (or row) at a time.\n",
392397
"Notice that the output shape of `highlight_max` matches the input shape, an array with `len(s)` items."
393398
]
394399
},
@@ -406,8 +411,8 @@
406411
"outputs": [],
407412
"source": [
408413
"def compare_col(s, comparator=None):\n",
409-
" attr = 'background-color: #00BFFF;'\n",
410-
" return np.where(s < comparator, attr, '')"
414+
" css = 'background-color: #00BFFF;'\n",
415+
" return np.where(s < comparator, css, None)"
411416
]
412417
},
413418
{
@@ -442,41 +447,12 @@
442447
"cell_type": "markdown",
443448
"metadata": {},
444449
"source": [
445-
"Above we used `Styler.apply` to pass in each column one at a time.\n",
450+
"Above we used `Styler.apply` to pass in each column (or row) one at a time.\n",
446451
"\n",
447452
"<span style=\"background-color: #DEDEBE\">*Debugging Tip*: If you're having trouble writing your style function, try just passing it into <code style=\"background-color: #DEDEBE\">DataFrame.apply</code>. Internally, <code style=\"background-color: #DEDEBE\">Styler.apply</code> uses <code style=\"background-color: #DEDEBE\">DataFrame.apply</code> so the result should be the same.</span>\n",
448453
"\n",
449454
"What if you wanted to highlight just the maximum value in the entire table?\n",
450-
"Use `.apply(function, axis=None)` to indicate that your function wants the entire table, not one column or row at a time. Let's try that next.\n",
451-
"\n",
452-
"We'll rewrite our `highlight-max` to handle either Series (from `.apply(axis=0 or 1)`) or DataFrames (from `.apply(axis=None)`). We'll also allow the color to be adjustable, to demonstrate that `.apply`, and `.applymap` pass along keyword arguments."
453-
]
454-
},
455-
{
456-
"cell_type": "code",
457-
"execution_count": null,
458-
"metadata": {},
459-
"outputs": [],
460-
"source": [
461-
"def highlight_max(data, color='yellow'):\n",
462-
" '''\n",
463-
" highlight the maximum in a Series or DataFrame\n",
464-
" '''\n",
465-
" attr = 'background-color: {}'.format(color)\n",
466-
" if data.ndim == 1: # Series from .apply(axis=0) or axis=1\n",
467-
" is_max = data == data.max()\n",
468-
" return [attr if v else '' for v in is_max]\n",
469-
" else: # from .apply(axis=None)\n",
470-
" is_max = data == data.max().max()\n",
471-
" return pd.DataFrame(np.where(is_max, attr, ''),\n",
472-
" index=data.index, columns=data.columns)"
473-
]
474-
},
475-
{
476-
"cell_type": "markdown",
477-
"metadata": {},
478-
"source": [
479-
"When using ``Styler.apply(func, axis=None)``, the function must return a DataFrame with the same index and column labels."
455+
"Use `.apply(function, axis=None)` to indicate that your function wants the entire table, not one column or row at a time. In this case the return must be a DataFrame or ndarray of the same shape as the input. Let's try that next. "
480456
]
481457
},
482458
{
@@ -485,7 +461,7 @@
485461
"metadata": {},
486462
"outputs": [],
487463
"source": [
488-
"s = df.style.apply(highlight_max, color='darkorange', axis=None)\n",
464+
"s = df.style.apply(highlight_max, axis=None)\n",
489465
"s"
490466
]
491467
},

doc/source/whatsnew/v1.3.0.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ Other enhancements
5353
- :meth:`DataFrame.apply` can now accept non-callable DataFrame properties as strings, e.g. ``df.apply("size")``, which was already the case for :meth:`Series.apply` (:issue:`39116`)
5454
- :meth:`Series.apply` can now accept list-like or dictionary-like arguments that aren't lists or dictionaries, e.g. ``ser.apply(np.array(["sum", "mean"]))``, which was already the case for :meth:`DataFrame.apply` (:issue:`39140`)
5555
- :meth:`DataFrame.plot.scatter` can now accept a categorical column as the argument to ``c`` (:issue:`12380`, :issue:`31357`)
56-
- :meth:`.Styler.set_tooltips` allows on hover tooltips to be added to styled HTML dataframes (:issue:`35643`)
56+
- :meth:`.Styler.set_tooltips` allows on hover tooltips to be added to styled HTML dataframes (:issue:`35643`, :issue:`21266`, :issue:`39317`)
5757
- :meth:`.Styler.set_tooltips_class` and :meth:`.Styler.set_table_styles` amended to optionally allow certain css-string input arguments (:issue:`39564`)
58+
- :meth:`.Styler.apply` now accepts ndarray function returns for every value of ``axis``, i.e. ``0``, ``1`` or ``None``, for consistency (:issue:`39359`)
5859
- :meth:`Series.loc.__getitem__` and :meth:`Series.loc.__setitem__` with :class:`MultiIndex` now raising helpful error message when indexer has too many dimensions (:issue:`35349`)
5960
- :meth:`pandas.read_stata` and :class:`StataReader` support reading data from compressed files.
6061

pandas/io/formats/style.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -842,25 +842,31 @@ def _apply(
842842
else:
843843
result = func(data, **kwargs)
844844
if not isinstance(result, pd.DataFrame):
845-
raise TypeError(
846-
f"Function {repr(func)} must return a DataFrame when "
847-
f"passed to `Styler.apply` with axis=None"
848-
)
849-
if not (
845+
if not isinstance(result, np.ndarray):
846+
raise TypeError(
847+
f"Function {repr(func)} must return a DataFrame or ndarray "
848+
f"when passed to `Styler.apply` with axis=None"
849+
)
850+
if not (data.shape == result.shape):
851+
raise ValueError(
852+
f"Function {repr(func)} returned ndarray with wrong shape.\n"
853+
f"Result has shape: {result.shape}\n"
854+
f"Expected shape: {data.shape}"
855+
)
856+
result = DataFrame(result, index=data.index, columns=data.columns)
857+
elif not (
850858
result.index.equals(data.index) and result.columns.equals(data.columns)
851859
):
852860
raise ValueError(
853861
f"Result of {repr(func)} must have identical "
854862
f"index and columns as the input"
855863
)
856864

857-
result_shape = result.shape
858-
expected_shape = self.data.loc[subset].shape
859-
if result_shape != expected_shape:
865+
if result.shape != data.shape:
860866
raise ValueError(
861867
f"Function {repr(func)} returned the wrong shape.\n"
862868
f"Result has shape: {result.shape}\n"
863-
f"Expected shape: {expected_shape}"
869+
f"Expected shape: {data.shape}"
864870
)
865871
self._update_ctx(result)
866872
return self
@@ -873,7 +879,7 @@ def apply(
873879
**kwargs,
874880
) -> Styler:
875881
"""
876-
Apply a function column-wise, row-wise, or table-wise.
882+
Apply a CSS-styling function column-wise, row-wise, or table-wise.
877883
878884
Updates the HTML representation with the result.
879885
@@ -883,7 +889,10 @@ def apply(
883889
``func`` should take a Series or DataFrame (depending
884890
on ``axis``), and return an object with the same shape.
885891
Must return a DataFrame with identical index and
886-
column labels when ``axis=None``.
892+
column labels, or an ndarray with the same shape as the input, when ``axis=None``.
893+
894+
.. versionchanged:: 1.3.0
895+
887896
axis : {0 or 'index', 1 or 'columns', None}, default 0
888897
Apply to each column (``axis=0`` or ``'index'``), to each row
889898
(``axis=1`` or ``'columns'``), or to the entire DataFrame at once
@@ -900,22 +909,24 @@ def apply(
900909
901910
Notes
902911
-----
903-
The output shape of ``func`` should match the input, i.e. if
912+
The elements of the output of ``func`` should be CSS styles given as strings or,
913+
if nothing is to be applied to that element, an empty string or ``None``.
914+
The output shape must match the input, i.e. if
904915
``x`` is the input row, column, or table (depending on ``axis``),
905-
then ``func(x).shape == x.shape`` should be true.
916+
then ``func(x).shape == x.shape`` should be ``True``.
906917
907918
This is similar to ``DataFrame.apply``, except that ``axis=None``
908919
applies the function to the entire DataFrame at once,
909920
rather than column-wise or row-wise.
910921
911922
Examples
912923
--------
913-
>>> def highlight_max(x):
914-
... return ['background-color: yellow' if v == x.max() else ''
915-
for v in x]
916-
...
924+
>>> def highlight_max(x, color):
925+
... return np.where(x == np.nanmax(x.values), f"color: {color};", None)
917926
>>> df = pd.DataFrame(np.random.randn(5, 2))
918-
>>> df.style.apply(highlight_max)
927+
>>> df.style.apply(highlight_max, color='red')
928+
>>> df.style.apply(highlight_max, color='blue', axis=1)
929+
>>> df.style.apply(highlight_max, color='green', axis=None)
919930
"""
920931
self._todo.append(
921932
(lambda instance: getattr(instance, "_apply"), (func, axis, subset), kwargs)
@@ -933,7 +944,7 @@ def _applymap(self, func: Callable, subset=None, **kwargs) -> Styler:
933944

934945
def applymap(self, func: Callable, subset=None, **kwargs) -> Styler:
935946
"""
936-
Apply a function elementwise.
947+
Apply a CSS-styling function elementwise.
937948
938949
Updates the HTML representation with the result.
939950
@@ -955,6 +966,18 @@ def applymap(self, func: Callable, subset=None, **kwargs) -> Styler:
955966
--------
956967
Styler.where: Updates the HTML representation with a style which is
957968
selected in accordance with the return value of a function.
969+
970+
Notes
971+
-----
972+
The output of ``func`` should be a CSS style as a string or, if nothing is to be
973+
applied, an empty string or ``None``.
974+
975+
Examples
976+
--------
977+
>>> def color_negative(v, color):
978+
... return f"color: {color};" if v < 0 else None
979+
>>> df = pd.DataFrame(np.random.randn(5, 2))
980+
>>> df.style.applymap(color_negative, color='red')
958981
"""
959982
self._todo.append(
960983
(lambda instance: getattr(instance, "_applymap"), (func, subset), kwargs)

pandas/tests/io/formats/test_style.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1395,12 +1395,19 @@ def test_bad_apply_shape(self):
13951395
with pytest.raises(ValueError, match=msg):
13961396
df.style._apply(lambda x: ["", "", ""], axis=1)
13971397

1398+
msg = "returned ndarray with wrong shape"
1399+
with pytest.raises(ValueError, match=msg):
1400+
df.style._apply(lambda x: np.array([[""], [""]]), axis=None)
1401+
13981402
def test_apply_bad_return(self):
13991403
def f(x):
14001404
return ""
14011405

14021406
df = DataFrame([[1, 2], [3, 4]])
1403-
msg = "must return a DataFrame when passed to `Styler.apply` with axis=None"
1407+
msg = (
1408+
"must return a DataFrame or ndarray when passed to `Styler.apply` "
1409+
"with axis=None"
1410+
)
14041411
with pytest.raises(TypeError, match=msg):
14051412
df.style._apply(f, axis=None)
14061413

0 commit comments

Comments
 (0)