Skip to content

Commit 609433d

Browse files
authored
Merge pull request #2 from twosigma/feature/custom_window_span
Allow rolling API to accept BaseIndexer subclass
2 parents 502db03 + 7e34fa0 commit 609433d

File tree

4 files changed

+115
-3
lines changed

4 files changed

+115
-3
lines changed

pandas/_libs/custom_window.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import abc
2+
from typing import Optional, Sequence, Tuple, Union
3+
4+
import numpy as np
5+
6+
from pandas.tseries.offsets import DateOffset
7+
8+
BeginEnd = Tuple[np.ndarray, np.ndarray]
9+
10+
# TODO: Refactor MockFixedWindowIndexer, FixedWindowIndexer,
11+
# VariableWindowIndexer to also have `get_window_bounds` methods that
12+
# only calculate start & stop
13+
14+
# TODO: Currently, when win_type is specified, it calls a special routine,
15+
# `roll_window`, while None win_type ops dispatch to specific methods.
16+
# Consider consolidating?
17+
18+
19+
class BaseIndexer(abc.ABC):
    """
    Abstract base class for custom rolling-window boundary calculations.

    Users should subclass this class and implement ``get_window_bounds``
    to define a custom method of calculating window bounds.

    Parameters
    ----------
    index
        Stored unchanged as ``self.index``.
    offset : str or DateOffset
        Stored unchanged as ``self.offset``.
    keys : sequence of np.ndarray
        Stored unchanged as ``self.keys``.
    """

    def __init__(self, index, offset, keys):
        # TODO: The alternative is for the `rolling` API to accept
        # index, offset, and keys as keyword arguments
        self.index = index
        self.offset = offset  # type: Union[str, DateOffset]
        self.keys = keys  # type: Sequence[np.ndarray]

    # NOTE: this is an instance method (not a classmethod) so that
    # implementations can read the index/offset/keys stored by __init__.
    @abc.abstractmethod
    def get_window_bounds(
        self,
        win_type: Optional[str] = None,
        min_periods: Optional[int] = None,
        center: Optional[bool] = None,
        closed: Optional[str] = None,
    ) -> "BeginEnd":
        """
        Compute the bounds of a window.

        Must be overridden by subclasses; the base implementation does
        nothing.

        Parameters
        ----------
        win_type : str, default None
            win_type passed from the top level rolling API

        min_periods : int, default None
            min_periods passed from the top level rolling API

        center : bool, default None
            center passed from the top level rolling API

        closed : str, default None
            closed passed from the top level rolling API

        Returns
        -------
        BeginEnd
            A tuple of ndarray[int64]s, indicating the boundaries of each
            window

        """

pandas/core/window.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212

13+
import pandas._libs.custom_window as libwindow_custom
1314
import pandas._libs.window as libwindow
1415
from pandas.compat._optional import import_optional_dependency
1516
from pandas.compat.numpy import function as nv
@@ -481,14 +482,19 @@ class Window(_Window):
481482
482483
Parameters
483484
----------
484-
window : int, or offset
485+
window : int, offset, or BaseIndexer subclass
485486
Size of the moving window. This is the number of observations used for
486487
calculating the statistic. Each window will be a fixed size.
487488
488489
If it's an offset then this will be the time period of each window. Each
489490
window will be variable sized based on the observations included in
490491
the time-period. This is only valid for datetimelike indexes. This is
491492
new in 0.19.0
493+
494+
If a BaseIndexer subclass is passed, calculates the window boundaries
495+
based on the defined ``get_window_bounds`` method. Additional rolling
496+
keyword arguments, namely `min_periods`, `center`, `win_type`, and
497+
`closed` will be passed to `get_window_bounds`.
492498
min_periods : int, default None
493499
Minimum number of observations in window required to have a value
494500
(otherwise result is NA). For a window that is specified by an offset,
@@ -631,7 +637,7 @@ def validate(self):
631637
super().validate()
632638

633639
window = self.window
634-
if isinstance(window, (list, tuple, np.ndarray)):
640+
if isinstance(window, (list, tuple, np.ndarray, libwindow_custom.BaseIndexer)):
635641
pass
636642
elif is_integer(window):
637643
if window <= 0:
@@ -693,6 +699,13 @@ def _pop_args(win_type, arg_names, kwargs):
693699
win_type = _validate_win_type(self.win_type, kwargs)
694700
# GH #15662. `False` makes symmetric window, rather than periodic.
695701
return sig.get_window(win_type, window, False).astype(float)
702+
elif isinstance(window, libwindow_custom.BaseIndexer):
703+
return window.get_window_span(
704+
win_type=self.win_type,
705+
min_periods=self.min_periods,
706+
center=self.center,
707+
closed=self.closed,
708+
)
696709

697710
def _apply_window(self, mean=True, **kwargs):
698711
"""
@@ -1731,7 +1744,8 @@ def validate(self):
17311744
# min_periods must be an integer
17321745
if self.min_periods is None:
17331746
self.min_periods = 1
1734-
1747+
elif isinstance(self.window, libwindow_custom.BaseIndexer):
1748+
pass
17351749
elif not is_integer(self.window):
17361750
raise ValueError("window must be an integer")
17371751
elif self.window < 0:
@@ -2782,6 +2796,8 @@ def _get_center_of_mass(comass, span, halflife, alpha):
27822796

27832797

27842798
def _offset(window, center):
2799+
# TODO: (MATT) If the window is a BaseIndexer subclass,
2800+
# we need to pass in the materialized window
27852801
if not is_integer(window):
27862802
window = len(window)
27872803
offset = (window - 1) / 2.0 if center else 0

pandas/tests/window/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import pytest
22

3+
import pandas._libs.custom_window as libwindow_custom
4+
5+
from pandas import date_range, offsets
6+
37

48
@pytest.fixture(params=[True, False])
59
def raw(request):
@@ -47,3 +51,18 @@ def center(request):
4751
@pytest.fixture(params=[None, 1])
4852
def min_periods(request):
4953
return request.param
54+
55+
56+
@pytest.fixture
def dummy_custom_indexer():
    # Minimal BaseIndexer subclass instance: enough to be recognised as a
    # custom indexer by the rolling API, but it computes no bounds.
    class DummyIndexer(libwindow_custom.BaseIndexer):
        # Stub implementation: accepts any keyword arguments, returns None.
        def get_window_bounds(self, **kwargs):
            return None

    return DummyIndexer(
        index=date_range("2019", freq="D", periods=3),
        offset=offsets.BusinessDay(1),
        keys=["A"],
    )
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from pandas import Series
2+
3+
4+
def test_custom_indexer_validates(
    dummy_custom_indexer, win_types, closed, min_periods, center
):
    # Passing a BaseIndexer subclass as `window` must not raise a
    # validation error. Only the rolling object is constructed here; no
    # aggregation is computed.
    rolling_kwargs = {
        "win_type": win_types,
        "center": center,
        "min_periods": min_periods,
        "closed": closed,
    }
    series = Series(range(10))
    series.rolling(dummy_custom_indexer, **rolling_kwargs)

0 commit comments

Comments
 (0)