From 9bc9ab2c8945b37aab402adb0b0c8d008223c79d Mon Sep 17 00:00:00 2001 From: mproszewska <38814059+mproszewska@users.noreply.github.com> Date: Fri, 15 May 2020 16:11:38 +0200 Subject: [PATCH] Backport PR #33983 on branch 1.0.x (BUG: Use args and kwargs in Rolling.apply) --- doc/source/whatsnew/v1.0.4.rst | 1 + pandas/core/window/rolling.py | 2 ++ pandas/tests/window/test_apply.py | 27 ++++++++++++++++++++++++++- 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/doc/source/whatsnew/v1.0.4.rst b/doc/source/whatsnew/v1.0.4.rst index 88b4ba748fd63..95007f4dd1caf 100644 --- a/doc/source/whatsnew/v1.0.4.rst +++ b/doc/source/whatsnew/v1.0.4.rst @@ -25,6 +25,7 @@ Fixed regressions - Fix to preserve the ability to index with the "nearest" method with xarray's CFTimeIndex, an :class:`Index` subclass (`pydata/xarray#3751 `_, :issue:`32905`). - Fix regression in :meth:`DataFrame.describe` raising ``TypeError: unhashable type: 'dict'`` (:issue:`32409`) - Bug in :meth:`DataFrame.replace` casts columns to ``object`` dtype if items in ``to_replace`` not in values (:issue:`32988`) +- Bug in :meth:`GroupBy.rolling.apply` ignores args and kwargs parameters (:issue:`33433`) - .. _whatsnew_104.bug_fixes: diff --git a/pandas/core/window/rolling.py b/pandas/core/window/rolling.py index b7e1779a08562..a01a753e83813 100644 --- a/pandas/core/window/rolling.py +++ b/pandas/core/window/rolling.py @@ -1304,6 +1304,8 @@ def apply( name=func, use_numba_cache=engine == "numba", raw=raw, + args=args, + kwargs=kwargs, ) def _generate_cython_apply_func(self, args, kwargs, raw, offset, func): diff --git a/pandas/tests/window/test_apply.py b/pandas/tests/window/test_apply.py index 7132e64c1191c..f56227b72fc48 100644 --- a/pandas/tests/window/test_apply.py +++ b/pandas/tests/window/test_apply.py @@ -3,7 +3,7 @@ import pandas.util._test_decorators as td -from pandas import DataFrame, Series, Timestamp, date_range +from pandas import DataFrame, Index, MultiIndex, Series, Timestamp, date_range import pandas._testing as tm @@ -138,3 +138,28 @@ def test_invalid_kwargs_nopython(): Series(range(1)).rolling(1).apply( lambda x: x, kwargs={"a": 1}, engine="numba", raw=True ) + + +@pytest.mark.parametrize("args_kwargs", [[None, {"par": 10}], [(10,), None]]) +def test_rolling_apply_args_kwargs(args_kwargs): + # GH 33433 + def foo(x, par): + return np.sum(x + par) + + df = DataFrame({"gr": [1, 1], "a": [1, 2]}) + + idx = Index(["gr", "a"]) + expected = DataFrame([[11.0, 11.0], [11.0, 12.0]], columns=idx) + + result = df.rolling(1).apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1]) + tm.assert_frame_equal(result, expected) + + result = df.rolling(1).apply(foo, args=(10,)) + + midx = MultiIndex.from_tuples([(1, 0), (1, 1)], names=["gr", None]) + expected = Series([11.0, 12.0], index=midx, name="a") + + gb_rolling = df.groupby("gr")["a"].rolling(1) + + result = gb_rolling.apply(foo, args=args_kwargs[0], kwargs=args_kwargs[1]) + tm.assert_series_equal(result, expected)