diff --git a/pandas/core/arrays/set.py b/pandas/core/arrays/set.py new file mode 100644 index 0000000000000..d4a307270c74a --- /dev/null +++ b/pandas/core/arrays/set.py @@ -0,0 +1,443 @@ +import sys +import warnings +import copy +import numpy as np + +import operator + +from pandas import Series + +# from pandas._libs.lib import infer_dtype +from pandas.util._decorators import cache_readonly +from pandas.compat import u, range +from pandas.compat import set_function_name + +from pandas.core.dtypes.common import ( + is_integer, is_scalar, is_object_dtype, is_list_like) +from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin +from pandas.core.dtypes.base import ExtensionDtype +from pandas.core.dtypes.dtypes import registry +from pandas.core.dtypes.missing import isna + +from pandas.io.formats.printing import ( + format_object_summary, format_object_attrs, default_pprint) + + +class SetDtype(ExtensionDtype): + """ + An ExtensionDtype to hold sets. + """ + name = 'set' + type = object + na_value = np.nan + + def __hash__(self): + # XXX: this needs to be part of the interface. + return hash(str(self)) + + def __eq__(self, other): + # TODO: test + if isinstance(other, type(self)): + return True + else: + return super(SetDtype, self).__eq__(other) + + @property + def _is_numeric(self): + return False + + def __repr__(self): + return self.name + + @classmethod + def construct_array_type(cls): + """Return the array type associated with this dtype + + Returns + ------- + type + """ + return SetArray + + @classmethod + def construct_from_string(cls, string): + """ + Construction from a string, raise a TypeError if not + possible + """ + if string == cls.name or string is set: + return cls() + raise TypeError("Cannot construct a '{}' from " + "'{}'".format(cls, string)) + +# @classmethod +# def is_dtype(cls, dtype): +# dtype = getattr(dtype, 'dtype', dtype) +# if (isinstance(dtype, compat.string_types) and +# dtype == 'set'): +# return True +# elif isinstance(dtype, cls): +# return True +# return isinstance(dtype, np.dtype) or dtype == 'set' + + +def to_set_array(values): + """ + Infer and return a set array of the values. + + Parameters + ---------- + values : 1D list-like of list-likes + + Returns + ------- + SetArray + + Raises + ------ + TypeError if incompatible types + """ + return SetArray(values, copy=False) + + +def coerce_to_array(values, mask=None, copy=False): + """ + Coerce the input values array to numpy arrays with a mask + + Parameters + ---------- + values : 1D list-like + mask : boolean 1D array, optional + copy : boolean, default False + if True, copy the input + + Returns + ------- + tuple of (values, mask) + """ + + if isinstance(values, SetArray): + values, mask = values._data, values._mask + + if copy: + values = values.copy() + mask = mask.copy() + return values, mask + + values = np.array(values, copy=copy) + if not (is_object_dtype(values) or isna(values).all()): + raise TypeError("{} cannot be converted to a SetDtype".format( + values.dtype)) + + if mask is None: + mask = isna(values) + else: + assert len(mask) == len(values) + + if not values.ndim == 1: + raise TypeError("values must be a 1D list-like") + if not mask.ndim == 1: + raise TypeError("mask must be a 1D list-like") + + if mask.any(): + values = values.copy() + values[mask] = np.nan + + return values, mask + + +class SetArray(ExtensionArray, ExtensionOpsMixin): + """ + We represent a SetArray with 2 numpy arrays + - data: contains a numpy set array of object dtype + - mask: a boolean array holding a mask on the data, False is missing + """ + + @cache_readonly + def dtype(self): + return SetDtype() + + def __init__(self, values, mask=None, dtype=None, copy=False): + """ + Parameters + ---------- + values : 1D list-like / SetArray + mask : 1D list-like, optional + copy : bool, default False + + Returns + ------- + SetArray + """ + self._data, self._mask = coerce_to_array( + values, mask=mask, copy=copy) + + @classmethod + def _from_sequence(cls, scalars, dtype=None, copy=False): + # dtype is ignored + return cls(scalars, copy=copy) + + @classmethod + def _from_factorized(cls, values, original): + return cls(values) + + def __getitem__(self, item): + if is_integer(item): + if self._mask[item]: + return self.dtype.na_value + return self._data[item] + return type(self)(self._data[item], mask=self._mask[item]) + + def _coerce_to_ndarray(self): + """ + coerce to an ndarray of object dtype + """ + data = self._data + data[self._mask] = self._na_value + return data + + def __array__(self): + """ + the array interface, return values + """ + return self._coerce_to_ndarray() + + def __iter__(self): + """Iterate over elements of the array. + + """ + # This needs to be implemented so that pandas recognizes extension + # arrays as list-like. The default implementation makes successive + # calls to ``__getitem__``, which may be slower than necessary. + for i in range(len(self)): + if self._mask[i]: + yield self.dtype.na_value + else: + yield self._data[i] + + def _formatting_values(self): + # type: () -> np.ndarray + return self._coerce_to_ndarray() + + def take(self, indices, allow_fill=False, fill_value=None): + from pandas.core.algorithms import take + + if allow_fill and fill_value is None: + fill_value = self.dtype.na_value + + result = take(self._data, indices, fill_value=fill_value, + allow_fill=allow_fill) + return self._from_sequence(result) + + def copy(self, deep=False): + data, mask = self._data, self._mask + if deep: + data = copy.deepcopy(data) + mask = copy.deepcopy(mask) + else: + data = data.copy() + mask = mask.copy() + return type(self)(data, mask, copy=False) + + def __setitem__(self, key, value): + _is_scalar = is_scalar(value) + if _is_scalar: + value = [value] + value, mask = coerce_to_array(value) + + if _is_scalar: + value = value[0] + mask = mask[0] + + self._data[key] = value + self._mask[key] = mask + + def __len__(self): + return len(self._data) + + def __repr__(self): + """ + Return a string representation for this object. + + Invoked by unicode(df) in py2 only. Yields a Unicode String in both + py2/py3. + """ + klass = self.__class__.__name__ + data = format_object_summary(self, default_pprint, False) + attrs = format_object_attrs(self) + space = " " + + prepr = (u(",%s") % + space).join(u("%s=%s") % (k, v) for k, v in attrs) + + res = u("%s(%s%s)") % (klass, data, prepr) + + return res + + @property + def nbytes(self): + return self._data.nbytes + self._mask.nbytes + + def isna(self): + return self._mask + + @property + def _na_value(self): + return np.nan + + @classmethod + def _concat_same_type(cls, to_concat): + data = np.concatenate([x._data for x in to_concat]) + mask = np.concatenate([x._mask for x in to_concat]) + return cls(data, mask=mask) + + def astype(self, dtype, copy=True, errors='raise', fill_value=None): + """Cast to a NumPy array or SetArray with 'dtype'. + + Parameters + ---------- + dtype : str or dtype + Typecode or data-type to which the array is cast. + copy : bool, default True + Whether to copy the data, even if not necessary. If False, + a copy is made only if the old dtype does not match the + new dtype. + + Returns + ------- + array : ndarray or SetArray + NumPy ndarray or SetArray with 'dtype' for its dtype. + + Raises + ------ + TypeError + if incompatible type with a SetDtype, equivalent of same_kind + casting + """ + + # if we are astyping to an existing IntegerDtype we can fastpath + if isinstance(dtype, SetDtype): + result = self._data.astype(dtype.type, + casting='same_kind', copy=False) + return type(self)(result, mask=self._mask, copy=False) + + # coerce + data = self._coerce_to_ndarray() + return data.astype(dtype, copy=False) + + @property + def _ndarray_values(self): + # type: () -> np.ndarray + """Internal pandas method for lossy conversion to a NumPy ndarray. + + This method is not part of the pandas interface. + + The expectation is that this is cheap to compute, and is primarily + used for interacting with our indexers. + """ + return self._data + + def fillna(self, value=None, method=None, limit=None): + # TODO: method/limit + res = self._data.copy() + res[self._mask] = [value] * self._mask.sum() + return type(self)(res, + mask=np.full_like(res, fill_value=False, dtype=bool), + copy=False) + + def dropna(self): + res = self._data[~self._mask] + return type(self)(res, + mask=np.full_like(res, fill_value=False, dtype=bool), + copy=False) + + def unique(self): + raise NotImplementedError + + def factorize(self): + raise NotImplementedError + + def argsort(self): + raise NotImplementedError + + def value_counts(self, dropna=True): + raise NotImplementedError + + def _values_for_argsort(self): + raise NotImplementedError + + @classmethod + def _create_comparison_method(cls, op): + def cmp_method(self, other): + + op_name = op.__name__ + mask = None + if isinstance(other, SetArray): + other, mask = other._data, other._mask + elif (isinstance(other, Series) + and isinstance(other.values, SetArray)): + other, mask = other.values._data, other.values._mask + elif isinstance(other, set) or (is_scalar(other) and isna(other)): + other = np.array([other] * len(self)) + elif is_list_like(other): + other = np.asarray(other) + if other.ndim > 0 and len(self) != len(other): + raise ValueError('Lengths must match to compare') + + mask = self._mask | mask if mask is not None else self._mask + result = np.full_like(self._data, fill_value=np.nan, dtype='O') + + # numpy will show a DeprecationWarning on invalid elementwise + # comparisons, this will raise in the future + with warnings.catch_warnings(record=True): + with np.errstate(all='ignore'): + result[~mask] = op(self._data[~mask], other[~mask]) + + result[mask] = True if op_name == 'ne' else False + return result.astype('bool') + + name = '__{name}__'.format(name=op.__name__) + return set_function_name(cmp_method, name, cls) + + @classmethod + def _create_arithmetic_method(cls, op): + def arithmetic_method(self, other): + + mask = None + if isinstance(other, SetArray): + other, mask = other._data, other._mask + elif isinstance(other, set) or (is_scalar(other) and isna(other)): + other = np.array([other] * len(self)) + elif is_list_like(other): + other = np.asarray(other) + # cannot use isnan due to numpy/numpy#9009 + mask = np.array([x is np.nan for x in other]) + if other.ndim > 0 and len(self) != len(other): + raise ValueError('Lengths must match to compare') + + mask = self._mask | mask if mask is not None else self._mask + result = np.full_like(self._data, fill_value=np.nan, dtype='O') + + with np.errstate(all='ignore'): + result[~mask] = op(self._data[~mask], other[~mask]) + + return type(self)(result, mask=mask, copy=False) + + name = '__{name}__'.format(name=op.__name__) + + def raiser(self, other): + raise NotImplementedError + if name != '__sub__': + return raiser + return set_function_name(arithmetic_method, name, cls) + + +SetArray._add_comparison_ops() +SetArray._add_arithmetic_ops() +SetArray.__or__ = SetArray._create_arithmetic_method(operator.__or__) +SetArray.__xor__ = SetArray._create_arithmetic_method(operator.__xor__) +SetArray.__and__ = SetArray._create_arithmetic_method(operator.__and__) + +module = sys.modules[__name__] +setattr(module, 'SetDtype', SetDtype) +registry.register(SetDtype) diff --git a/pandas/core/ops.py b/pandas/core/ops.py index dc139a8e14f66..50135d18cc21f 100644 --- a/pandas/core/ops.py +++ b/pandas/core/ops.py @@ -1151,7 +1151,9 @@ def dispatch_to_extension_op(op, left, right): new_right = [new_right] new_right = list(new_right) elif is_extension_array_dtype(right) and type(left) != type(right): - new_right = list(new_right) + new_right = new_right.astype(left.dtype).values + elif is_extension_array_dtype(right): + new_right = right.values else: new_right = right @@ -1482,6 +1484,45 @@ def _bool_method_SERIES(cls, op, special): code duplication. """ + def dispatch_to_extension_op(op, left, right): + """ + Assume that left or right is a Series backed by an ExtensionArray, + apply the operator defined by op. + """ + + # The op calls will raise TypeError if the op is not defined + # on the ExtensionArray + # TODO(jreback) + # we need to listify to avoid ndarray, or non-same-type extension array + # dispatching + + if is_extension_array_dtype(left): + + new_left = left.values + if isinstance(right, np.ndarray): + + # handle numpy scalars, this is a PITA + # TODO(jreback) + new_right = lib.item_from_zerodim(right) + if is_scalar(new_right): + new_right = [new_right] + new_right = list(new_right) + elif is_extension_array_dtype(right) and type(left) != type(right): + new_right = new_right.astype(left.dtype).values + elif is_extension_array_dtype(right): + new_right = right.values + else: + new_right = right + + else: + new_left = left + new_right = right.values._data + + res_values = op(new_left, new_right) + res_name = get_op_result_name(left, right) + + return _construct_result(left, res_values, left.index, res_name) + def na_op(x, y): try: result = op(x, y) @@ -1516,12 +1557,20 @@ def na_op(x, y): def wrapper(self, other): is_self_int_dtype = is_integer_dtype(self.dtype) - self, other = _align_method_SERIES(self, other, align_asobject=True) + align_asobject = not (is_extension_array_dtype(self) or + is_extension_array_dtype(other)) + self, other = _align_method_SERIES(self, other, + align_asobject=align_asobject) if isinstance(other, ABCDataFrame): # Defer to DataFrame implementation; fail early return NotImplemented + elif (is_extension_array_dtype(self) + or is_extension_array_dtype(other)): + # TODO: should this include `not is_scalar(right)`? + return dispatch_to_extension_op(op, self, other) + elif isinstance(other, ABCSeries): name = get_op_result_name(self, other) is_other_int_dtype = is_integer_dtype(other.dtype) diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py index f7bfdb8ec218a..de88e6dfdbef8 100644 --- a/pandas/tests/extension/base/ops.py +++ b/pandas/tests/extension/base/ops.py @@ -69,7 +69,8 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators): # ndarray & other series op_name = all_arithmetic_operators s = pd.Series(data) - self.check_opname(s, op_name, [s.iloc[0]] * len(s), exc=TypeError) + self.check_opname(s, op_name, pd.Series([s.iloc[0]] * len(s)), + exc=TypeError) def test_divmod(self, data): s = pd.Series(data) @@ -108,10 +109,10 @@ def _compare_other(self, s, data, op_name, other): def test_compare_scalar(self, data, all_compare_operators): op_name = all_compare_operators s = pd.Series(data) - self._compare_other(s, data, op_name, 0) + self._compare_other(s, data, op_name, data[0]) def test_compare_array(self, data, all_compare_operators): op_name = all_compare_operators s = pd.Series(data) - other = [0] * len(data) + other = pd.Series([data[0]] * len(data)) self._compare_other(s, data, op_name, other) diff --git a/pandas/tests/extension/set/__init__.py b/pandas/tests/extension/set/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/pandas/tests/extension/set/test_set.py b/pandas/tests/extension/set/test_set.py new file mode 100644 index 0000000000000..428fed4ecccb6 --- /dev/null +++ b/pandas/tests/extension/set/test_set.py @@ -0,0 +1,159 @@ +import numpy as np +import pandas.util.testing as tm +import pytest + +from pandas.tests.extension import base + +from pandas.core.arrays.set import SetDtype, SetArray + + +def make_string_sets(): + s = tm.makeStringSeries() + return s.index.map(set).values + + +def make_int_sets(): + s = tm.makeFloatSeries().astype(str).str.replace(r'\D', '') + return s.map(lambda x: set(map(int, x))).values + + +def make_data(): + return (list(make_string_sets()) + + [np.nan] + + list(make_int_sets()) + + [np.nan, None, set()]) + + +@pytest.fixture +def dtype(): + return SetDtype() + + +@pytest.fixture +def data(): + return SetArray(make_int_sets()) + + +@pytest.fixture +def data_missing(): + return SetArray([np.nan, {1}]) + + +@pytest.fixture +def data_repeated(data): + def gen(count): + for _ in range(count): + yield data + yield gen + + +# @pytest.fixture +# def data_for_sorting(dtype): +# return SetArray(...) + + +# @pytest.fixture +# def data_missing_for_sorting(dtype): +# return SetArray(...) + + +@pytest.fixture +def na_cmp(): + # we are np.nan + return lambda x, y: np.isnan(x) and np.isnan(y) + + +@pytest.fixture +def na_value(): + return np.nan + +# @pytest.fixture +# def data_for_grouping(dtype): +# return SetArray(...) + + +class TestDtype(base.BaseDtypeTests): + + def test_array_type_with_arg(self, data, dtype): + assert dtype.construct_array_type() is SetArray + + +class TestInterface(base.BaseInterfaceTests): + + def test_len(self, data): + assert len(data) == 30 + + +class TestConstructors(base.BaseConstructorsTests): + pass + + +class TestReshaping(base.BaseReshapingTests): + pass + + +class TestGetitem(base.BaseGetitemTests): + + @pytest.mark.skip(reason="Need to think about it.") + def test_take_non_na_fill_value(self, data_missing): + pass + + +class TestSetitem(base.BaseGetitemTests): + pass + + +class TestMissing(base.BaseMissingTests): + + def test_fillna_frame(self, data_missing): + pytest.skip('df.fillna does not dispatch to EA') + + def test_fillna_limit_pad(self): + pytest.skip('TODO') + + def test_fillna_limit_backfill(self): + pytest.skip('TODO') + + def test_fillna_series_method(self): + pytest.skip('TODO') + + def test_fillna_series(self): + pytest.skip('series.fillna does not dispatch to EA') + + +# # most methods (value_counts, unique, factorize) will not be for SetArray +# # rest still buggy +class TestMethods(base.BaseMethodsTests): + pass + + +class TestCasting(base.BaseCastingTests): + pass + + +class TestArithmeticOps(base.BaseArithmeticOpsTests): + + def check_opname(self, s, op_name, other, exc='ignored'): + op = self.get_op_from_name(op_name) + + self._check_op(s, op, other, + None if op_name == '__sub__' else NotImplementedError) + + def test_divmod(self, data): + pytest.skip('Not relevant') + + def test_error(self, data, all_arithmetic_operators): + pytest.skip('TODO') + + +class TestComparisonOps(base.BaseComparisonOpsTests): + + def _compare_other(self, s, data, op_name, other): + op = self.get_op_from_name(op_name) + result = op(s, other) + expected = s.combine(other, op) + self.assert_series_equal(result, expected) + +# # GroupBy won't be implemented for SetArray +# class TestGroupby(base.BaseGroupbyTests): +# pass