Skip to content

Commit 7509752

Browse files
committed
ENH: make sure return dtypes for nan funcs are consistent
1 parent 5852e72 commit 7509752

File tree

3 files changed

+54
-42
lines changed

3 files changed

+54
-42
lines changed

pandas/core/nanops.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,10 @@ def nanall(values, axis=None, skipna=True):
244244
@bottleneck_switch(zero_value=0)
245245
def nansum(values, axis=None, skipna=True):
246246
values, mask, dtype, dtype_max = _get_values(values, skipna, 0)
247-
the_sum = values.sum(axis, dtype=dtype_max)
247+
dtype_sum = dtype_max
248+
if is_float_dtype(dtype):
249+
dtype_sum = dtype
250+
the_sum = values.sum(axis, dtype=dtype_sum)
248251
the_sum = _maybe_null_out(the_sum, axis, mask)
249252

250253
return _wrap_results(the_sum, dtype)
@@ -288,7 +291,7 @@ def get_median(x):
288291
return np.nan
289292
return algos.median(_values_from_object(x[mask]))
290293

291-
if values.dtype != np.float64:
294+
if not is_float_dtype(values):
292295
values = values.astype('f8')
293296
values[mask] = np.nan
294297

@@ -317,10 +320,10 @@ def get_median(x):
317320
return _wrap_results(get_median(values) if notempty else np.nan, dtype)
318321

319322

320-
def _get_counts_nanvar(mask, axis, ddof):
321-
count = _get_counts(mask, axis)
323+
def _get_counts_nanvar(mask, axis, ddof, dtype):
324+
count = _get_counts(mask, axis, dtype)
322325

323-
d = count-ddof
326+
d = count - dtype.type(ddof)
324327

325328
# always return NaN, never inf
326329
if np.isscalar(count):
@@ -338,18 +341,18 @@ def _get_counts_nanvar(mask, axis, ddof):
338341
def _nanvar(values, axis=None, skipna=True, ddof=1):
339342
# private nanvar calculator
340343
mask = isnull(values)
341-
if is_any_int_dtype(values):
344+
if not is_float_dtype(values):
342345
values = values.astype('f8')
343346

344-
count, d = _get_counts_nanvar(mask, axis, ddof)
347+
count, d = _get_counts_nanvar(mask, axis, ddof, values.dtype)
345348

346349
if skipna:
347350
values = values.copy()
348351
np.putmask(values, mask, 0)
349352

350353
X = _ensure_numeric(values.sum(axis))
351354
XX = _ensure_numeric((values ** 2).sum(axis))
352-
return np.fabs((XX - X ** 2 / count) / d)
355+
return np.fabs((XX - X * X / count) / d)
353356

354357
@disallow('M8')
355358
@bottleneck_switch(ddof=1)
@@ -375,9 +378,9 @@ def nansem(values, axis=None, skipna=True, ddof=1):
375378
mask = isnull(values)
376379
if not is_floating_dtype(values):
377380
values = values.astype('f8')
378-
count, _ = _get_counts_nanvar(mask, axis, ddof)
381+
count, _ = _get_counts_nanvar(mask, axis, ddof, values.dtype)
379382

380-
return np.sqrt(var)/np.sqrt(count)
383+
return np.sqrt(var) / np.sqrt(count)
381384

382385

383386
@bottleneck_switch()
@@ -467,25 +470,29 @@ def nanargmin(values, axis=None, skipna=True):
467470
def nanskew(values, axis=None, skipna=True):
468471

469472
mask = isnull(values)
470-
if not is_floating_dtype(values):
473+
if not is_float_dtype(values):
471474
values = values.astype('f8')
472-
473-
count = _get_counts(mask, axis)
475+
count = _get_counts(mask, axis)
476+
else:
477+
count = _get_counts(mask, axis, dtype=values.dtype)
474478

475479
if skipna:
476480
values = values.copy()
477481
np.putmask(values, mask, 0)
478482

483+
typ = values.dtype.type
479484
A = values.sum(axis) / count
480-
B = (values ** 2).sum(axis) / count - A ** 2
481-
C = (values ** 3).sum(axis) / count - A ** 3 - 3 * A * B
485+
B = (values ** 2).sum(axis) / count - A ** typ(2)
486+
C = (values ** 3).sum(axis) / count - A ** typ(3) - typ(3) * A * B
482487

483488
# floating point error
484489
B = _zero_out_fperr(B)
485490
C = _zero_out_fperr(C)
486491

487-
result = ((np.sqrt((count ** 2 - count)) * C) /
488-
((count - 2) * np.sqrt(B) ** 3))
492+
# result = ((np.sqrt((count ** 2 - count)) * C) /
493+
# ((count - 2) * np.sqrt(B) ** 3))
494+
result = ((np.sqrt(count * count - count) * C) /
495+
((count - typ(2)) * np.sqrt(B) ** typ(3)))
489496

490497
if isinstance(result, np.ndarray):
491498
result = np.where(B == 0, 0, result)
@@ -502,19 +509,21 @@ def nanskew(values, axis=None, skipna=True):
502509
def nankurt(values, axis=None, skipna=True):
503510

504511
mask = isnull(values)
505-
if not is_floating_dtype(values):
512+
if not is_float_dtype(values):
506513
values = values.astype('f8')
507-
508-
count = _get_counts(mask, axis)
514+
count = _get_counts(mask, axis)
515+
else:
516+
count = _get_counts(mask, axis, dtype=values.dtype)
509517

510518
if skipna:
511519
values = values.copy()
512520
np.putmask(values, mask, 0)
513521

522+
typ = values.dtype.type
514523
A = values.sum(axis) / count
515-
B = (values ** 2).sum(axis) / count - A ** 2
516-
C = (values ** 3).sum(axis) / count - A ** 3 - 3 * A * B
517-
D = (values ** 4).sum(axis) / count - A ** 4 - 6 * B * A * A - 4 * C * A
524+
B = (values ** 2).sum(axis) / count - A ** typ(2)
525+
C = (values ** 3).sum(axis) / count - A ** typ(3) - typ(3) * A * B
526+
D = (values ** 4).sum(axis) / count - A ** typ(4) - typ(6) * B * A * A - typ(4) * C * A
518527

519528
B = _zero_out_fperr(B)
520529
D = _zero_out_fperr(D)
@@ -526,8 +535,8 @@ def nankurt(values, axis=None, skipna=True):
526535
if B == 0:
527536
return 0
528537

529-
result = (((count * count - 1.) * D / (B * B) - 3 * ((count - 1.) ** 2)) /
530-
((count - 2.) * (count - 3.)))
538+
result = (((count * count - typ(1)) * D / (B * B) - typ(3) * ((count - typ(1)) ** typ(2))) /
539+
((count - typ(2)) * (count - typ(3))))
531540

532541
if isinstance(result, np.ndarray):
533542
result = np.where(B == 0, 0, result)
@@ -598,7 +607,7 @@ def _zero_out_fperr(arg):
598607
if isinstance(arg, np.ndarray):
599608
return np.where(np.abs(arg) < 1e-14, 0, arg)
600609
else:
601-
return 0 if np.abs(arg) < 1e-14 else arg
610+
return arg.dtype.type(0) if np.abs(arg) < 1e-14 else arg
602611

603612

604613
@disallow('M8','m8')

pandas/tests/test_nanops.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -340,14 +340,18 @@ def test_nanmean_overflow(self):
340340
self.assertEqual(result, np_result)
341341
self.assertTrue(result.dtype == np.float64)
342342

343-
# check returned dtype
344-
for dtype in [np.int16, np.int32, np.int64, np.float16, np.float32, np.float64]:
343+
def test_returned_dtype(self):
344+
from pandas import Series
345+
for dtype in [np.int16, np.int32, np.int64, np.float32, np.float64, np.float128]:
345346
s = Series(range(10), dtype=dtype)
346-
result = s.mean()
347-
if is_integer_dtype(dtype):
348-
self.assertTrue(result.dtype == np.float64)
349-
else:
350-
self.assertTrue(result.dtype == dtype)
347+
for method in ['mean', 'std', 'var', 'skew', 'kurt']:
348+
result = getattr(s, method)()
349+
if is_integer_dtype(dtype):
350+
self.assertTrue(result.dtype == np.float64,
351+
"return dtype expected from %s is np.float64, got %s instead" % (method, result.dtype))
352+
else:
353+
self.assertTrue(result.dtype == dtype,
354+
"return dtype expected from %s is %s, got %s instead" % (method, dtype, result.dtype))
351355

352356
def test_nanmedian(self):
353357
self.check_funs(nanops.nanmedian, np.median,

pandas/tests/test_series.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,6 @@ def test_nansum_buglet(self):
511511
assert_almost_equal(result, 1)
512512

513513
def test_overflow(self):
514-
515514
# GH 6915
516515
# overflowing on the smaller int dtypes
517516
for dtype in ['int32','int64']:
@@ -534,25 +533,25 @@ def test_overflow(self):
534533
result = s.max()
535534
self.assertEqual(int(result),v[-1])
536535

537-
for dtype in ['float32','float64']:
538-
v = np.arange(5000000,dtype=dtype)
536+
for dtype in ['float32', 'float64']:
537+
v = np.arange(5000000, dtype=dtype)
539538
s = Series(v)
540539

541540
# no bottleneck
542541
result = s.sum(skipna=False)
543-
self.assertTrue(np.allclose(float(result),v.sum(dtype='float64')))
542+
self.assertEqual(result, v.sum(dtype=dtype))
544543
result = s.min(skipna=False)
545-
self.assertTrue(np.allclose(float(result),0.0))
544+
self.assertTrue(np.allclose(float(result), 0.0))
546545
result = s.max(skipna=False)
547-
self.assertTrue(np.allclose(float(result),v[-1]))
546+
self.assertTrue(np.allclose(float(result), v[-1]))
548547

549548
# use bottleneck if available
550549
result = s.sum()
551-
self.assertTrue(np.allclose(float(result),v.sum(dtype='float64')))
550+
self.assertEqual(result, v.sum(dtype=dtype))
552551
result = s.min()
553-
self.assertTrue(np.allclose(float(result),0.0))
552+
self.assertTrue(np.allclose(float(result), 0.0))
554553
result = s.max()
555-
self.assertTrue(np.allclose(float(result),v[-1]))
554+
self.assertTrue(np.allclose(float(result), v[-1]))
556555

557556
class SafeForSparse(object):
558557
pass

0 commit comments

Comments
 (0)