Skip to content

Commit 268c462

Browse files
CLN/TST: Remove Base Class and all subclasses and fixturize data (#34179)
1 parent c10020f commit 268c462

12 files changed

+2149
-2086
lines changed

pandas/tests/window/common.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,8 @@
1-
from datetime import datetime
2-
31
import numpy as np
4-
from numpy.random import randn
52

6-
from pandas import DataFrame, Series, bdate_range
3+
from pandas import Series
74
import pandas._testing as tm
85

9-
N, K = 100, 10
10-
11-
12-
class Base:
13-
14-
_nan_locs = np.arange(20, 40)
15-
_inf_locs = np.array([])
16-
17-
def _create_data(self):
18-
arr = randn(N)
19-
arr[self._nan_locs] = np.NaN
20-
21-
self.arr = arr
22-
self.rng = bdate_range(datetime(2009, 1, 1), periods=N)
23-
self.series = Series(arr.copy(), index=self.rng)
24-
self.frame = DataFrame(randn(N, K), index=self.rng, columns=np.arange(K))
25-
266

277
def check_pairwise_moment(frame, dispatch, name, **kwargs):
288
def get_result(obj, obj2=None):

pandas/tests/window/conftest.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
from datetime import datetime
2+
13
import numpy as np
4+
from numpy.random import randn
25
import pytest
36

47
import pandas.util._test_decorators as td
58

6-
from pandas import DataFrame, Series, notna
9+
from pandas import DataFrame, Series, bdate_range, notna
710

811

912
@pytest.fixture(params=[True, False])
@@ -242,3 +245,60 @@ def no_nans(x):
242245
def consistency_data(request):
243246
"""Create consistency data"""
244247
return request.param
248+
249+
250+
def _create_arr():
251+
"""Internal function to mock an array."""
252+
arr = randn(100)
253+
locs = np.arange(20, 40)
254+
arr[locs] = np.NaN
255+
return arr
256+
257+
258+
def _create_rng():
259+
"""Internal function to mock date range."""
260+
rng = bdate_range(datetime(2009, 1, 1), periods=100)
261+
return rng
262+
263+
264+
def _create_series():
265+
"""Internal function to mock Series."""
266+
arr = _create_arr()
267+
series = Series(arr.copy(), index=_create_rng())
268+
return series
269+
270+
271+
def _create_frame():
272+
"""Internal function to mock DataFrame."""
273+
rng = _create_rng()
274+
return DataFrame(randn(100, 10), index=rng, columns=np.arange(10))
275+
276+
277+
@pytest.fixture
278+
def nan_locs():
279+
"""Make a range as loc fixture."""
280+
return np.arange(20, 40)
281+
282+
283+
@pytest.fixture
284+
def arr():
285+
"""Make an array as fixture."""
286+
return _create_arr()
287+
288+
289+
@pytest.fixture
290+
def frame():
291+
"""Make mocked frame as fixture."""
292+
return _create_frame()
293+
294+
295+
@pytest.fixture
296+
def series():
297+
"""Make mocked series as fixture."""
298+
return _create_series()
299+
300+
301+
@pytest.fixture(params=[_create_series(), _create_frame()])
302+
def which(request):
303+
"""Turn parametrized which as fixture for series and frame"""
304+
return request.param

pandas/tests/window/moments/test_moments_consistency_ewm.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from pandas import DataFrame, Series, concat
66
from pandas.tests.window.common import (
7-
Base,
87
check_binary_ew,
98
check_binary_ew_min_periods,
109
check_pairwise_moment,
@@ -19,13 +18,9 @@
1918
)
2019

2120

22-
class TestEwmMomentsConsistency(Base):
23-
def setup_method(self, method):
24-
self._create_data()
25-
26-
@pytest.mark.parametrize("func", ["cov", "corr"])
27-
def test_ewm_pairwise_cov_corr(self, func):
28-
check_pairwise_moment(self.frame, "ewm", func, span=10, min_periods=5)
21+
@pytest.mark.parametrize("func", ["cov", "corr"])
22+
def test_ewm_pairwise_cov_corr(func, frame):
23+
check_pairwise_moment(frame, "ewm", func, span=10, min_periods=5)
2924

3025

3126
@pytest.mark.parametrize("name", ["cov", "corr"])

pandas/tests/window/moments/test_moments_consistency_expanding.py

Lines changed: 107 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pandas import DataFrame, Index, MultiIndex, Series, isna, notna
88
import pandas._testing as tm
99
from pandas.tests.window.common import (
10-
Base,
1110
moments_consistency_cov_data,
1211
moments_consistency_is_constant,
1312
moments_consistency_mock_mean,
@@ -18,132 +17,145 @@
1817
)
1918

2019

21-
class TestExpandingMomentsConsistency(Base):
22-
def setup_method(self, method):
23-
self._create_data()
20+
def _check_expanding(
21+
func, static_comp, preserve_nan=True, series=None, frame=None, nan_locs=None
22+
):
2423

25-
def test_expanding_corr(self):
26-
A = self.series.dropna()
27-
B = (A + randn(len(A)))[:-5]
24+
series_result = func(series)
25+
assert isinstance(series_result, Series)
26+
frame_result = func(frame)
27+
assert isinstance(frame_result, DataFrame)
2828

29-
result = A.expanding().corr(B)
29+
result = func(series)
30+
tm.assert_almost_equal(result[10], static_comp(series[:11]))
3031

31-
rolling_result = A.rolling(window=len(A), min_periods=1).corr(B)
32+
if preserve_nan:
33+
assert result.iloc[nan_locs].isna().all()
3234

33-
tm.assert_almost_equal(rolling_result, result)
3435

35-
def test_expanding_count(self):
36-
result = self.series.expanding(min_periods=0).count()
37-
tm.assert_almost_equal(
38-
result, self.series.rolling(window=len(self.series), min_periods=0).count()
39-
)
36+
def _check_expanding_has_min_periods(func, static_comp, has_min_periods):
37+
ser = Series(randn(50))
4038

41-
def test_expanding_quantile(self):
42-
result = self.series.expanding().quantile(0.5)
39+
if has_min_periods:
40+
result = func(ser, min_periods=30)
41+
assert result[:29].isna().all()
42+
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))
4343

44-
rolling_result = self.series.rolling(
45-
window=len(self.series), min_periods=1
46-
).quantile(0.5)
44+
# min_periods is working correctly
45+
result = func(ser, min_periods=15)
46+
assert isna(result.iloc[13])
47+
assert notna(result.iloc[14])
4748

48-
tm.assert_almost_equal(result, rolling_result)
49+
ser2 = Series(randn(20))
50+
result = func(ser2, min_periods=5)
51+
assert isna(result[3])
52+
assert notna(result[4])
4953

50-
def test_expanding_cov(self):
51-
A = self.series
52-
B = (A + randn(len(A)))[:-5]
54+
# min_periods=0
55+
result0 = func(ser, min_periods=0)
56+
result1 = func(ser, min_periods=1)
57+
tm.assert_almost_equal(result0, result1)
58+
else:
59+
result = func(ser)
60+
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))
5361

54-
result = A.expanding().cov(B)
5562

56-
rolling_result = A.rolling(window=len(A), min_periods=1).cov(B)
63+
def test_expanding_corr(series):
64+
A = series.dropna()
65+
B = (A + randn(len(A)))[:-5]
5766

58-
tm.assert_almost_equal(rolling_result, result)
67+
result = A.expanding().corr(B)
5968

60-
def test_expanding_cov_pairwise(self):
61-
result = self.frame.expanding().corr()
69+
rolling_result = A.rolling(window=len(A), min_periods=1).corr(B)
6270

63-
rolling_result = self.frame.rolling(
64-
window=len(self.frame), min_periods=1
65-
).corr()
71+
tm.assert_almost_equal(rolling_result, result)
6672

67-
tm.assert_frame_equal(result, rolling_result)
6873

69-
def test_expanding_corr_pairwise(self):
70-
result = self.frame.expanding().corr()
74+
def test_expanding_count(series):
75+
result = series.expanding(min_periods=0).count()
76+
tm.assert_almost_equal(
77+
result, series.rolling(window=len(series), min_periods=0).count()
78+
)
7179

72-
rolling_result = self.frame.rolling(
73-
window=len(self.frame), min_periods=1
74-
).corr()
75-
tm.assert_frame_equal(result, rolling_result)
7680

77-
@pytest.mark.parametrize("has_min_periods", [True, False])
78-
@pytest.mark.parametrize(
79-
"func,static_comp",
80-
[("sum", np.sum), ("mean", np.mean), ("max", np.max), ("min", np.min)],
81-
ids=["sum", "mean", "max", "min"],
82-
)
83-
def test_expanding_func(self, func, static_comp, has_min_periods):
84-
def expanding_func(x, min_periods=1, center=False, axis=0):
85-
exp = x.expanding(min_periods=min_periods, center=center, axis=axis)
86-
return getattr(exp, func)()
87-
88-
self._check_expanding(expanding_func, static_comp, preserve_nan=False)
89-
self._check_expanding_has_min_periods(
90-
expanding_func, static_comp, has_min_periods
91-
)
81+
def test_expanding_quantile(series):
82+
result = series.expanding().quantile(0.5)
83+
84+
rolling_result = series.rolling(window=len(series), min_periods=1).quantile(0.5)
85+
86+
tm.assert_almost_equal(result, rolling_result)
87+
9288

93-
@pytest.mark.parametrize("has_min_periods", [True, False])
94-
def test_expanding_apply(self, engine_and_raw, has_min_periods):
89+
def test_expanding_cov(series):
90+
A = series
91+
B = (A + randn(len(A)))[:-5]
9592

96-
engine, raw = engine_and_raw
93+
result = A.expanding().cov(B)
9794

98-
def expanding_mean(x, min_periods=1):
95+
rolling_result = A.rolling(window=len(A), min_periods=1).cov(B)
9996

100-
exp = x.expanding(min_periods=min_periods)
101-
result = exp.apply(lambda x: x.mean(), raw=raw, engine=engine)
102-
return result
97+
tm.assert_almost_equal(rolling_result, result)
10398

104-
# TODO(jreback), needed to add preserve_nan=False
105-
# here to make this pass
106-
self._check_expanding(expanding_mean, np.mean, preserve_nan=False)
107-
self._check_expanding_has_min_periods(expanding_mean, np.mean, has_min_periods)
10899

109-
def _check_expanding(self, func, static_comp, preserve_nan=True):
100+
def test_expanding_cov_pairwise(frame):
101+
result = frame.expanding().cov()
110102

111-
series_result = func(self.series)
112-
assert isinstance(series_result, Series)
113-
frame_result = func(self.frame)
114-
assert isinstance(frame_result, DataFrame)
103+
rolling_result = frame.rolling(window=len(frame), min_periods=1).cov()
115104

116-
result = func(self.series)
117-
tm.assert_almost_equal(result[10], static_comp(self.series[:11]))
105+
tm.assert_frame_equal(result, rolling_result)
118106

119-
if preserve_nan:
120-
assert result.iloc[self._nan_locs].isna().all()
121107

122-
def _check_expanding_has_min_periods(self, func, static_comp, has_min_periods):
123-
ser = Series(randn(50))
108+
def test_expanding_corr_pairwise(frame):
109+
result = frame.expanding().corr()
124110

125-
if has_min_periods:
126-
result = func(ser, min_periods=30)
127-
assert result[:29].isna().all()
128-
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))
111+
rolling_result = frame.rolling(window=len(frame), min_periods=1).corr()
112+
tm.assert_frame_equal(result, rolling_result)
129113

130-
# min_periods is working correctly
131-
result = func(ser, min_periods=15)
132-
assert isna(result.iloc[13])
133-
assert notna(result.iloc[14])
134114

135-
ser2 = Series(randn(20))
136-
result = func(ser2, min_periods=5)
137-
assert isna(result[3])
138-
assert notna(result[4])
115+
@pytest.mark.parametrize("has_min_periods", [True, False])
116+
@pytest.mark.parametrize(
117+
"func,static_comp",
118+
[("sum", np.sum), ("mean", np.mean), ("max", np.max), ("min", np.min)],
119+
ids=["sum", "mean", "max", "min"],
120+
)
121+
def test_expanding_func(func, static_comp, has_min_periods, series, frame, nan_locs):
122+
def expanding_func(x, min_periods=1, center=False, axis=0):
123+
exp = x.expanding(min_periods=min_periods, center=center, axis=axis)
124+
return getattr(exp, func)()
125+
126+
_check_expanding(
127+
expanding_func,
128+
static_comp,
129+
preserve_nan=False,
130+
series=series,
131+
frame=frame,
132+
nan_locs=nan_locs,
133+
)
134+
_check_expanding_has_min_periods(expanding_func, static_comp, has_min_periods)
135+
139136

140-
# min_periods=0
141-
result0 = func(ser, min_periods=0)
142-
result1 = func(ser, min_periods=1)
143-
tm.assert_almost_equal(result0, result1)
144-
else:
145-
result = func(ser)
146-
tm.assert_almost_equal(result.iloc[-1], static_comp(ser[:50]))
137+
@pytest.mark.parametrize("has_min_periods", [True, False])
138+
def test_expanding_apply(engine_and_raw, has_min_periods, series, frame, nan_locs):
139+
140+
engine, raw = engine_and_raw
141+
142+
def expanding_mean(x, min_periods=1):
143+
144+
exp = x.expanding(min_periods=min_periods)
145+
result = exp.apply(lambda x: x.mean(), raw=raw, engine=engine)
146+
return result
147+
148+
# TODO(jreback), needed to add preserve_nan=False
149+
# here to make this pass
150+
_check_expanding(
151+
expanding_mean,
152+
np.mean,
153+
preserve_nan=False,
154+
series=series,
155+
frame=frame,
156+
nan_locs=nan_locs,
157+
)
158+
_check_expanding_has_min_periods(expanding_mean, np.mean, has_min_periods)
147159

148160

149161
@pytest.mark.parametrize("min_periods", [0, 1, 2, 3, 4])

0 commit comments

Comments
 (0)