From 8401bb0e47ea8147cd59a8def4eb086f587fec5d Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 26 Feb 2021 12:28:50 +0100 Subject: [PATCH 1/3] REF: move logic of 'block manager axis' into the BlockManager --- pandas/core/frame.py | 6 ++---- pandas/core/internals/managers.py | 7 +++++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 2c95e65c70899..3e6bd18de31bc 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -7621,12 +7621,10 @@ def diff(self, periods: int = 1, axis: Axis = 0) -> DataFrame: raise ValueError("periods must be an integer") periods = int(periods) - bm_axis = self._get_block_manager_axis(axis) - - if bm_axis == 0 and periods != 0: + if axis == 1 and periods != 0: return self - self.shift(periods, axis=axis) - new_data = self._mgr.diff(n=periods, axis=bm_axis) + new_data = self._mgr.diff(n=periods, axis=axis) return self._constructor(new_data).__finalize__(self, "diff") # ---------------------------------------------------------------------- diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py index e013a7f680d6f..7d65dbc14ec1e 100644 --- a/pandas/core/internals/managers.py +++ b/pandas/core/internals/managers.py @@ -234,6 +234,12 @@ def shape(self) -> Shape: def ndim(self) -> int: return len(self.axes) + @staticmethod + def _normalize_axis(axis): + # switch axis to follow BlockManager logic + axis = 1 if axis == 0 else 0 + return axis + def set_axis( self, axis: int, new_labels: Index, verify_integrity: bool = True ) -> None: @@ -601,6 +607,7 @@ def putmask(self, mask, new, align: bool = True): ) def diff(self, n: int, axis: int) -> BlockManager: + axis = self._normalize_axis(axis) return self.apply("diff", n=n, axis=axis) def interpolate(self, **kwargs) -> BlockManager: From 12b79cb879fba43b31a4299d495ffa955e490816 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 1 Mar 2021 14:35:56 +0100 Subject: [PATCH 2/3] add where and shift --- pandas/core/generic.py | 8 +++----- pandas/core/internals/array_manager.py | 10 ++++++---- pandas/core/internals/managers.py | 8 +++++--- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 4774045849eb6..399f1f5dfaab1 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -8948,8 +8948,6 @@ def _where( self._info_axis, axis=self._info_axis_number, copy=False ) - block_axis = self._get_block_manager_axis(axis) - if inplace: # we may have different type blocks come out of putmask, so # reconstruct the block manager @@ -8965,7 +8963,7 @@ def _where( cond=cond, align=align, errors=errors, - axis=block_axis, + axis=axis, ) result = self._constructor(new_data) return result.__finalize__(self) @@ -9276,9 +9274,9 @@ def shift( if freq is None: # when freq is None, data is shifted, index is not - block_axis = self._get_block_manager_axis(axis) + axis = self._get_axis_number(axis) new_data = self._mgr.shift( - periods=periods, axis=block_axis, fill_value=fill_value + periods=periods, axis=axis, fill_value=fill_value ) return self._constructor(new_data).__finalize__(self, method="shift") diff --git a/pandas/core/internals/array_manager.py b/pandas/core/internals/array_manager.py index 5001754017dda..add6458528436 100644 --- a/pandas/core/internals/array_manager.py +++ b/pandas/core/internals/array_manager.py @@ -385,7 +385,10 @@ def apply( return type(self)(result_arrays, new_axes) - def apply_with_block(self: T, f, align_keys=None, **kwargs) -> T: + def apply_with_block(self: T, f, align_keys=None, swap_axis=True, **kwargs) -> T: + # switch axis to follow BlockManager logic + if swap_axis and "axis" in kwargs and self.ndim == 2: + kwargs["axis"] = 1 if kwargs["axis"] == 0 else 0 align_keys = align_keys or [] aligned_args = {k: kwargs[k] for k in align_keys} @@ -467,7 +470,6 @@ def putmask(self, mask, new, align: bool = True): ) def diff(self, n: int, axis: int) -> ArrayManager: - axis = self._normalize_axis(axis) if axis == 1: # DataFrame only calls this for n=0, in which case performing it # with axis=0 is equivalent @@ -476,13 +478,13 @@ def diff(self, n: int, axis: int) -> ArrayManager: return self.apply(algos.diff, n=n, axis=axis) def interpolate(self, **kwargs) -> ArrayManager: - return self.apply_with_block("interpolate", **kwargs) + return self.apply_with_block("interpolate", swap_axis=False, **kwargs) def shift(self, periods: int, axis: int, fill_value) -> ArrayManager: if fill_value is lib.no_default: fill_value = None - if axis == 0 and self.ndim == 2: + if axis == 1 and self.ndim == 2: # TODO column-wise shift raise NotImplementedError diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py index 7d65dbc14ec1e..d7bc220dccb0c 100644 --- a/pandas/core/internals/managers.py +++ b/pandas/core/internals/managers.py @@ -234,10 +234,10 @@ def shape(self) -> Shape: def ndim(self) -> int: return len(self.axes) - @staticmethod - def _normalize_axis(axis): + def _normalize_axis(self, axis): # switch axis to follow BlockManager logic - axis = 1 if axis == 0 else 0 + if self.ndim == 2: + axis = 1 if axis == 0 else 0 return axis def set_axis( @@ -573,6 +573,7 @@ def isna(self, func) -> BlockManager: return self.apply("apply", func=func) def where(self, other, cond, align: bool, errors: str, axis: int) -> BlockManager: + axis = self._normalize_axis(axis) if align: align_keys = ["other", "cond"] else: @@ -614,6 +615,7 @@ def interpolate(self, **kwargs) -> BlockManager: return self.apply("interpolate", **kwargs) def shift(self, periods: int, axis: int, fill_value) -> BlockManager: + axis = self._normalize_axis(axis) if fill_value is lib.no_default: fill_value = None From eac4c8c40d71de41c2532c717c1b61b78da58281 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 1 Mar 2021 18:59:51 +0100 Subject: [PATCH 3/3] ensure axis is int --- pandas/core/frame.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 950872d271d02..3b813bc86510b 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -7802,6 +7802,7 @@ def diff(self, periods: int = 1, axis: Axis = 0) -> DataFrame: raise ValueError("periods must be an integer") periods = int(periods) + axis = self._get_axis_number(axis) if axis == 1 and periods != 0: return self - self.shift(periods, axis=axis)