Skip to content

Commit f0c7e64

Browse files
Cleanups after merge
First draft in which all test cases pass, after (re)merging with the changes that were split out of this work as part of PR hgrecco#196. Signed-off-by: Michael Tiemann <[email protected]>
1 parent 9f723f9 commit f0c7e64

File tree

2 files changed

+68
-51
lines changed

2 files changed

+68
-51
lines changed

pint_pandas/pint_array.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
# from pint.facets.plain.unit import PlainUnit as _Unit
3434

3535
if HAS_UNCERTAINTIES:
36-
from uncertainties import UFloat, ufloat
36+
from uncertainties import ufloat, UFloat
3737
from uncertainties import unumpy as unp
3838

3939
_ufloat_nan = ufloat(np.nan, 0)
@@ -330,12 +330,6 @@ def __setitem__(self, key, value):
330330
# doing nothing here seems to be ok
331331
return
332332

333-
master_scalar = None
334-
try:
335-
master_scalar = next(i for i in self._data if pd.notna(i))
336-
except StopIteration:
337-
pass
338-
339333
if isinstance(value, _Quantity):
340334
value = value.to(self.units).magnitude
341335
elif is_list_like(value) and len(value) > 0:
@@ -347,6 +341,36 @@ def __setitem__(self, key, value):
347341
key = check_array_indexer(self, key)
348342
# Filter out invalid values for our array type(s)
349343
try:
344+
if HAS_UNCERTAINTIES and is_object_dtype(self._data):
345+
from pandas.api.types import is_scalar, is_numeric_dtype
346+
347+
def value_to_ufloat(value):
348+
if pd.isna(value) or isinstance(value, UFloat):
349+
return value
350+
if is_numeric_dtype(type(value)):
351+
return ufloat(value, 0)
352+
raise ValueError
353+
354+
try:
355+
any_ufloats = next(
356+
True for i in self._data if isinstance(i, UFloat)
357+
)
358+
if any_ufloats:
359+
if is_scalar(key):
360+
if is_list_like(value):
361+
# cannot do many:1 setitem
362+
raise ValueError
363+
# 1:1 setitem
364+
value = value_to_ufloat(value)
365+
elif is_list_like(value):
366+
# many:many setitem
367+
value = [value_to_ufloat(v) for v in value]
368+
else:
369+
# broadcast 1:many
370+
value = value_to_ufloat(value)
371+
except StopIteration:
372+
# If array is full of nothingness, we can put anything inside it
373+
pass
350374
self._data[key] = value
351375
except IndexError as e:
352376
msg = "Mask is wrong length. {}".format(e)
@@ -593,9 +617,7 @@ def _values_for_factorize(self):
593617
if arr.dtype.kind == "O":
594618
if HAS_UNCERTAINTIES and arr.size > 0:
595619
# Canonicalize uncertain NaNs and pd.NA to np.nan
596-
arr = arr.map(
597-
lambda x: self.dtype.na_value if x is pd.NA or unp.isnan(x) else x
598-
)
620+
arr = arr.map(lambda x: np.nan if x is pd.NA or unp.isnan(x) else x)
599621
return np.array(arr, copy=False), self.dtype.na_value
600622
return arr._values_for_factorize()
601623

@@ -627,7 +649,7 @@ def value_counts(self, dropna=True):
627649
nafilt = data.map(lambda x: x is pd.NA or unp.isnan(x))
628650
else:
629651
nafilt = pd.isna(data)
630-
na_value = self.dtype.na_value
652+
na_value_for_index = pd.NA
631653
data = data[~nafilt]
632654
if HAS_UNCERTAINTIES and data.dtype.kind == "O":
633655
# This is a work-around for unhashable UFloats
@@ -643,7 +665,7 @@ def value_counts(self, dropna=True):
643665
array = [data_list.count(item) for item in index]
644666

645667
if not dropna:
646-
index.append(na_value)
668+
index.append(na_value_for_index)
647669
array.append(nafilt.sum())
648670

649671
return Series(np.asarray(array), index=index)

pint_pandas/testsuite/test_pandas_extensiontests.py

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,22 @@
1010

1111
try:
1212
import uncertainties.unumpy as unp
13-
from uncertainties import ufloat, UFloat # noqa: F401
13+
from uncertainties import ufloat, UFloat
14+
from uncertainties.core import AffineScalarFunc # noqa: F401
15+
16+
def AffineScalarFunc__hash__(self):
17+
if not self._linear_part.expanded():
18+
self._linear_part.expand()
19+
combo = tuple(iter(self._linear_part.linear_combo.items()))
20+
if len(combo) > 1 or combo[0][1] != 1.0:
21+
return hash(combo)
22+
# The unique value that comes from a unique variable (which it also hashes to)
23+
return id(combo[0][0])
24+
25+
AffineScalarFunc.__hash__ = AffineScalarFunc__hash__
1426

15-
HAS_UNCERTAINTIES = True
1627
_ufloat_nan = ufloat(np.nan, 0)
28+
HAS_UNCERTAINTIES = True
1729
except ImportError:
1830
unp = np
1931
HAS_UNCERTAINTIES = False
@@ -170,8 +182,8 @@ def dtype():
170182

171183

172184
_base_numeric_dtypes = [float, int]
173-
_all_numeric_dtypes = (
174-
_base_numeric_dtypes + [] if HAS_UNCERTAINTIES else [np.complex128]
185+
_all_numeric_dtypes = _base_numeric_dtypes + (
186+
[] if HAS_UNCERTAINTIES else [np.complex128]
175187
)
176188

177189

@@ -650,23 +662,9 @@ def _get_exception(self, data, op_name):
650662
if op_name in ["__floordiv__", "__rfloordiv__", "__mod__", "__rmod__"]:
651663
return op_name, TypeError
652664
if op_name in ["__pow__", "__rpow__"]:
653-
return DimensionalityError
654-
complex128_dtype = pd.core.dtypes.dtypes.NumpyEADtype("complex128")
655-
if (
656-
(isinstance(obj, pd.Series) and obj.dtype == complex128_dtype)
657-
or (
658-
isinstance(obj, pd.DataFrame)
659-
and any([dtype == complex128_dtype for dtype in obj.dtypes])
660-
)
661-
or (isinstance(other, pd.Series) and other.dtype == complex128_dtype)
662-
or (
663-
isinstance(other, pd.DataFrame)
664-
and any([dtype == complex128_dtype for dtype in other.dtypes])
665-
)
666-
):
667-
if op_name in ["__floordiv__", "__rfloordiv__", "__mod__", "__rmod__"]:
668-
return TypeError
669-
return super()._get_expected_exception(op_name, obj, other)
665+
return op_name, DimensionalityError
666+
667+
return op_name, None
670668

671669
# With Pint 0.21, series and scalar need to have compatible units for
672670
# the arithmetic to work
@@ -717,7 +715,9 @@ def test_divmod(self, data, USE_UNCERTAINTIES):
717715
self._check_divmod_op(1 * ureg.Mm, ops.rdivmod, ser)
718716

719717
@pytest.mark.parametrize("numeric_dtype", _base_numeric_dtypes, indirect=True)
720-
def test_divmod_series_array(self, data, data_for_twos):
718+
def test_divmod_series_array(self, data, data_for_twos, USE_UNCERTAINTIES):
719+
if USE_UNCERTAINTIES:
720+
pytest.skip(reason="uncertainties does not implement divmod")
721721
ser = pd.Series(data)
722722
self._check_divmod_op(ser, divmod, data)
723723

@@ -727,12 +727,6 @@ def test_divmod_series_array(self, data, data_for_twos):
727727
other = pd.Series(other)
728728
self._check_divmod_op(other, ops.rdivmod, ser)
729729

730-
@pytest.mark.parametrize("numeric_dtype", _base_numeric_dtypes, indirect=True)
731-
def test_divmod_series_array(self, data, data_for_twos, USE_UNCERTAINTIES):
732-
if USE_UNCERTAINTIES:
733-
pytest.skip(reason="uncertainties does not implement divmod")
734-
super().test_divmod_series_array(data, data_for_twos)
735-
736730

737731
class TestComparisonOps(base.BaseComparisonOpsTests):
738732
def _compare_other(self, s, data, op_name, other):
@@ -871,16 +865,6 @@ def test_reduce_series(
871865
warnings.simplefilter("ignore", RuntimeWarning)
872866
self.check_reduce(s, op_name, skipna)
873867

874-
@pytest.mark.parametrize("skipna", [True, False])
875-
def test_reduce_series_xx(self, data, all_numeric_reductions, skipna):
876-
op_name = all_numeric_reductions
877-
s = pd.Series(data)
878-
879-
# min/max with empty produce numpy warnings
880-
with warnings.catch_warnings():
881-
warnings.simplefilter("ignore", RuntimeWarning)
882-
self.check_reduce(s, op_name, skipna)
883-
884868

885869
class TestBooleanReduce(base.BaseBooleanReduceTests):
886870
def check_reduce(self, s, op_name, skipna):
@@ -922,7 +906,18 @@ class TestSetitem(base.BaseSetitemTests):
922906
@pytest.mark.parametrize("numeric_dtype", _base_numeric_dtypes, indirect=True)
923907
def test_setitem_scalar_key_sequence_raise(self, data):
924908
# This can be removed when https://github.com/pandas-dev/pandas/pull/54441 is accepted
925-
base.BaseSetitemTests.test_setitem_scalar_key_sequence_raise(self, data)
909+
arr = data[:5].copy()
910+
with pytest.raises((ValueError, TypeError)):
911+
arr[0] = arr[[0, 1]]
912+
913+
def test_setitem_invalid(self, data, invalid_scalar):
914+
# This can be removed when https://github.com/pandas-dev/pandas/pull/54441 is accepted
915+
msg = "" # messages vary by subclass, so we do not test it
916+
with pytest.raises((ValueError, TypeError), match=msg):
917+
data[0] = invalid_scalar
918+
919+
with pytest.raises((ValueError, TypeError), match=msg):
920+
data[:] = invalid_scalar
926921

927922
@pytest.mark.parametrize("numeric_dtype", _base_numeric_dtypes, indirect=True)
928923
def test_setitem_2d_values(self, data):

0 commit comments

Comments
 (0)