Skip to content

Commit 4c5461e

Browse files
author
y-p
committed
Merge pull request #3031 from y-p/groupby_bounds
ENH: add bounds-checking preamble to groupby_X cython code
2 parents b80c334 + e20199f commit 4c5461e

File tree

5 files changed

+156
-24
lines changed

5 files changed

+156
-24
lines changed

doc/source/v0.11.0.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ Bug Fixes
318318

319319
- Fixed slow printing of large DataFrames, due to inefficient dtype
320320
reporting (GH2807_)
321+
- Fixed a segfault when using a function as grouper in groupby (GH3035_)
321322
- Fix pretty-printing of infinite data structures (closes GH2978_)
322323
- Fixed exception when plotting timeseries bearing a timezone (closes GH2877_)
323324
- str.contains ignored na argument (GH2806_)
@@ -333,6 +334,7 @@ on GitHub for a complete list.
333334
.. _GH2810: https://github.com/pydata/pandas/issues/2810
334335
.. _GH2837: https://github.com/pydata/pandas/issues/2837
335336
.. _GH2898: https://github.com/pydata/pandas/issues/2898
337+
.. _GH3035: https://github.com/pydata/pandas/issues/3035
336338
.. _GH2978: https://github.com/pydata/pandas/issues/2978
337339
.. _GH2877: https://github.com/pydata/pandas/issues/2877
338340
.. _GH2739: https://github.com/pydata/pandas/issues/2739

pandas/core/groupby.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def _groupby_function(name, alias, npfunc, numeric_only=True,
5757
def f(self):
5858
try:
5959
return self._cython_agg_general(alias, numeric_only=numeric_only)
60+
except AssertionError as e:
61+
raise SpecificationError(str(e))
6062
except Exception:
6163
result = self.aggregate(lambda x: npfunc(x, axis=self.axis))
6264
if _convert:
@@ -348,7 +350,7 @@ def mean(self):
348350
"""
349351
try:
350352
return self._cython_agg_general('mean')
351-
except DataError:
353+
except GroupByError:
352354
raise
353355
except Exception: # pragma: no cover
354356
f = lambda x: x.mean(axis=self.axis)
@@ -362,7 +364,7 @@ def median(self):
362364
"""
363365
try:
364366
return self._cython_agg_general('median')
365-
except DataError:
367+
except GroupByError:
366368
raise
367369
except Exception: # pragma: no cover
368370
f = lambda x: x.median(axis=self.axis)
@@ -462,7 +464,10 @@ def _cython_agg_general(self, how, numeric_only=True):
462464
if numeric_only and not is_numeric:
463465
continue
464466

465-
result, names = self.grouper.aggregate(obj.values, how)
467+
try:
468+
result, names = self.grouper.aggregate(obj.values, how)
469+
except AssertionError as e:
470+
raise GroupByError(str(e))
466471
output[name] = result
467472

468473
if len(output) == 0:
@@ -1200,6 +1205,13 @@ def __init__(self, index, grouper=None, name=None, level=None,
12001205
# no level passed
12011206
if not isinstance(self.grouper, np.ndarray):
12021207
self.grouper = self.index.map(self.grouper)
1208+
if not (hasattr(self.grouper,"__len__") and \
1209+
len(self.grouper) == len(self.index)):
1210+
errmsg = "Grouper result violates len(labels) == len(data)\n"
1211+
errmsg += "result: %s" % com.pprint_thing(self.grouper)
1212+
self.grouper = None # Try for sanity
1213+
raise AssertionError(errmsg)
1214+
12031215

12041216
def __repr__(self):
12051217
return 'Grouping(%s)' % self.name
@@ -1723,9 +1735,10 @@ def _aggregate_multiple_funcs(self, arg):
17231735
grouper=self.grouper)
17241736
results.append(colg.aggregate(arg))
17251737
keys.append(col)
1726-
except (TypeError, DataError):
1738+
except (TypeError, DataError) :
17271739
pass
1728-
1740+
except SpecificationError:
1741+
raise
17291742
result = concat(results, keys=keys, axis=1)
17301743

17311744
return result

pandas/src/generate_code.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,9 @@ def groupby_%(name)s(ndarray[%(c_type)s] index, ndarray labels):
593593
594594
length = len(index)
595595
596+
if not length == len(labels):
597+
raise AssertionError("len(index) != len(labels)")
598+
596599
for i in range(length):
597600
key = util.get_value_1d(labels, i)
598601
@@ -625,6 +628,9 @@ def group_last_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
625628
ndarray[%(dest_type2)s, ndim=2] resx
626629
ndarray[int64_t, ndim=2] nobs
627630
631+
if not len(values) == len(labels):
632+
raise AssertionError("len(index) != len(labels)")
633+
628634
nobs = np.zeros((<object> out).shape, dtype=np.int64)
629635
resx = np.empty_like(out)
630636
@@ -760,6 +766,9 @@ def group_nth_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
760766
ndarray[%(dest_type2)s, ndim=2] resx
761767
ndarray[int64_t, ndim=2] nobs
762768
769+
if not len(values) == len(labels):
770+
raise AssertionError("len(index) != len(labels)")
771+
763772
nobs = np.zeros((<object> out).shape, dtype=np.int64)
764773
resx = np.empty_like(out)
765774
@@ -802,6 +811,9 @@ def group_add_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
802811
%(dest_type2)s val, count
803812
ndarray[%(dest_type2)s, ndim=2] sumx, nobs
804813
814+
if not len(values) == len(labels):
815+
raise AssertionError("len(index) != len(labels)")
816+
805817
nobs = np.zeros_like(out)
806818
sumx = np.zeros_like(out)
807819
@@ -915,6 +927,9 @@ def group_prod_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
915927
%(dest_type2)s val, count
916928
ndarray[%(dest_type2)s, ndim=2] prodx, nobs
917929
930+
if not len(values) == len(labels):
931+
raise AssertionError("len(index) != len(labels)")
932+
918933
nobs = np.zeros_like(out)
919934
prodx = np.ones_like(out)
920935
@@ -1025,6 +1040,9 @@ def group_var_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
10251040
%(dest_type2)s val, ct
10261041
ndarray[%(dest_type2)s, ndim=2] nobs, sumx, sumxx
10271042
1043+
if not len(values) == len(labels):
1044+
raise AssertionError("len(index) != len(labels)")
1045+
10281046
nobs = np.zeros_like(out)
10291047
sumx = np.zeros_like(out)
10301048
sumxx = np.zeros_like(out)
@@ -1220,6 +1238,9 @@ def group_max_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
12201238
%(dest_type2)s val, count
12211239
ndarray[%(dest_type2)s, ndim=2] maxx, nobs
12221240
1241+
if not len(values) == len(labels):
1242+
raise AssertionError("len(index) != len(labels)")
1243+
12231244
nobs = np.zeros_like(out)
12241245
12251246
maxx = np.empty_like(out)
@@ -1342,6 +1363,9 @@ def group_min_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
13421363
%(dest_type2)s val, count
13431364
ndarray[%(dest_type2)s, ndim=2] minx, nobs
13441365
1366+
if not len(values) == len(labels):
1367+
raise AssertionError("len(index) != len(labels)")
1368+
13451369
nobs = np.zeros_like(out)
13461370
13471371
minx = np.empty_like(out)
@@ -1399,6 +1423,9 @@ def group_mean_%(name)s(ndarray[%(dest_type2)s, ndim=2] out,
13991423
%(dest_type2)s val, count
14001424
ndarray[%(dest_type2)s, ndim=2] sumx, nobs
14011425
1426+
if not len(values) == len(labels):
1427+
raise AssertionError("len(index) != len(labels)")
1428+
14021429
nobs = np.zeros_like(out)
14031430
sumx = np.zeros_like(out)
14041431

0 commit comments

Comments
 (0)