From 8c530dda43f46568eda30990698f5aca05900d34 Mon Sep 17 00:00:00 2001 From: Richard Gowers Date: Thu, 12 Oct 2017 13:16:26 +0100 Subject: [PATCH] BUG: adds validation for boolean keywords in DataFrame methods ENH: Adds util._validators.validate_keywords_as_bool decorator --- doc/source/whatsnew/v0.22.0.txt | 2 +- pandas/core/frame.py | 40 +++++++++++++++++++++-------- pandas/tests/frame/test_validate.py | 8 ++++++ pandas/util/_validators.py | 31 ++++++++++++++++++++++ 4 files changed, 70 insertions(+), 11 deletions(-) diff --git a/doc/source/whatsnew/v0.22.0.txt b/doc/source/whatsnew/v0.22.0.txt index ccaa408603333..d177574ad3fdc 100644 --- a/doc/source/whatsnew/v0.22.0.txt +++ b/doc/source/whatsnew/v0.22.0.txt @@ -162,5 +162,5 @@ Other ^^^^^ - Improved error message when attempting to use a Python keyword as an identifier in a numexpr query (:issue:`18221`) -- +- Added checking of boolean kwargs in DataFrame methods (:issue:`16714`) - diff --git a/pandas/core/frame.py b/pandas/core/frame.py index f3137c1edf2af..26e16ece3fca9 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -84,7 +84,7 @@ from pandas.compat.numpy import function as nv from pandas.util._decorators import (Appender, Substitution, rewrite_axis_style_signature) -from pandas.util._validators import (validate_bool_kwarg, +from pandas.util._validators import (validate_keywords_as_bool, validate_axis_style_args) from pandas.core.indexes.period import PeriodIndex @@ -746,6 +746,7 @@ def iterrows(self): s = klass(v, index=columns, name=k) yield k, s + @validate_keywords_as_bool('index') def itertuples(self, index=True, name="Pandas"): """ Iterate over DataFrame rows as namedtuples, with index value as first @@ -1000,6 +1001,7 @@ def to_dict(self, orient='dict', into=dict): else: raise ValueError("orient '%s' not understood" % orient) + @validate_keywords_as_bool('verbose', 'reauth') def to_gbq(self, destination_table, project_id, chunksize=10000, verbose=True, reauth=False, if_exists='fail', private_key=None): """Write a DataFrame to a Google BigQuery table. @@ -1181,6 +1183,7 @@ def from_records(cls, data, index=None, exclude=None, columns=None, return cls(mgr) + @validate_keywords_as_bool('convert_datetime64') def to_records(self, index=True, convert_datetime64=True): """ Convert DataFrame to record array. Index will be put in the @@ -1426,6 +1429,7 @@ def to_panel(self): return self._constructor_expanddim(new_mgr) + @validate_keywords_as_bool('index') def to_csv(self, path_or_buf=None, sep=",", na_rep='', float_format=None, columns=None, header=True, index=True, index_label=None, mode='w', encoding=None, compression=None, quoting=None, @@ -1865,6 +1869,7 @@ def _sizeof_fmt(num, size_qualifier): _sizeof_fmt(mem_usage, size_qualifier)) _put_lines(buf, lines) + @validate_keywords_as_bool('index', 'deep') def memory_usage(self, index=True, deep=False): """Memory usage of DataFrame columns. @@ -2215,6 +2220,7 @@ def _getitem_frame(self, key): raise ValueError('Must pass DataFrame with boolean values only') return self.where(key) + @validate_keywords_as_bool('inplace') def query(self, expr, inplace=False, **kwargs): """Query the columns of a frame with a boolean expression. @@ -2286,7 +2292,6 @@ def query(self, expr, inplace=False, **kwargs): >>> df.query('a > b') >>> df[df.a > df.b] # same result as the previous expression """ - inplace = validate_bool_kwarg(inplace, 'inplace') if not isinstance(expr, compat.string_types): msg = "expr must be a string to be evaluated, {0} given" raise ValueError(msg.format(type(expr))) @@ -2306,6 +2311,7 @@ def query(self, expr, inplace=False, **kwargs): else: return new_data + @validate_keywords_as_bool('inplace') def eval(self, expr, inplace=False, **kwargs): """Evaluate an expression in the context of the calling DataFrame instance. @@ -2352,7 +2358,6 @@ def eval(self, expr, inplace=False, **kwargs): """ from pandas.core.computation.eval import eval as _eval - inplace = validate_bool_kwarg(inplace, 'inplace') resolvers = kwargs.pop('resolvers', None) kwargs['level'] = kwargs.pop('level', 0) + 1 if resolvers is None: @@ -2589,6 +2594,7 @@ def _set_item(self, key, value): if len(self): self._check_setitem_copy() + @validate_keywords_as_bool('allow_duplicates') def insert(self, loc, column, value, allow_duplicates=False): """ Insert column into DataFrame at specified location. @@ -2905,6 +2911,7 @@ def _reindex_multi(self, axes, copy, fill_value): copy=copy, fill_value=fill_value) + @validate_keywords_as_bool('copy') @Appender(_shared_docs['align'] % _shared_doc_kwargs) def align(self, other, join='outer', axis=None, level=None, copy=True, fill_value=None, method=None, limit=None, fill_axis=0, @@ -3037,6 +3044,7 @@ def shift(self, periods=1, freq=None, axis=0): return super(DataFrame, self).shift(periods=periods, freq=freq, axis=axis) + @validate_keywords_as_bool('drop', 'append', 'inplace', 'verify_integrity') def set_index(self, keys, drop=True, append=False, inplace=False, verify_integrity=False): """ @@ -3102,7 +3110,6 @@ def set_index(self, keys, drop=True, append=False, inplace=False, ------- dataframe : DataFrame """ - inplace = validate_bool_kwarg(inplace, 'inplace') if not isinstance(keys, list): keys = [keys] @@ -3164,6 +3171,7 @@ def set_index(self, keys, drop=True, append=False, inplace=False, if not inplace: return frame + @validate_keywords_as_bool('drop', 'inplace') def reset_index(self, level=None, drop=False, inplace=False, col_level=0, col_fill=''): """ @@ -3300,7 +3308,6 @@ class max type lion mammal 80.5 run monkey mammal NaN jump """ - inplace = validate_bool_kwarg(inplace, 'inplace') if inplace: new_obj = self else: @@ -3399,6 +3406,7 @@ def notna(self): def notnull(self): return super(DataFrame, self).notnull() + @validate_keywords_as_bool('inplace') def dropna(self, axis=0, how='any', thresh=None, subset=None, inplace=False): """ @@ -3468,7 +3476,6 @@ def dropna(self, axis=0, how='any', thresh=None, subset=None, 1 3.0 4.0 NaN 1 """ - inplace = validate_bool_kwarg(inplace, 'inplace') if isinstance(axis, (tuple, list)): result = self for ax in axis: @@ -3508,6 +3515,7 @@ def dropna(self, axis=0, how='any', thresh=None, subset=None, else: return result + @validate_keywords_as_bool('inplace') def drop_duplicates(self, subset=None, keep='first', inplace=False): """ Return DataFrame with duplicate rows removed, optionally only @@ -3529,7 +3537,6 @@ def drop_duplicates(self, subset=None, keep='first', inplace=False): ------- deduplicated : DataFrame """ - inplace = validate_bool_kwarg(inplace, 'inplace') duplicated = self.duplicated(subset, keep=keep) if inplace: @@ -3585,10 +3592,10 @@ def f(vals): # ---------------------------------------------------------------------- # Sorting + @validate_keywords_as_bool('ascending', 'inplace') @Appender(_shared_docs['sort_values'] % _shared_doc_kwargs) def sort_values(self, by, axis=0, ascending=True, inplace=False, kind='quicksort', na_position='last'): - inplace = validate_bool_kwarg(inplace, 'inplace') axis = self._get_axis_number(axis) other_axis = 0 if axis == 1 else 1 @@ -3640,15 +3647,14 @@ def sort_values(self, by, axis=0, ascending=True, inplace=False, else: return self._constructor(new_data).__finalize__(self) + @validate_keywords_as_bool('ascending', 'inplace', 'sort_remaining') @Appender(_shared_docs['sort_index'] % _shared_doc_kwargs) def sort_index(self, axis=0, level=None, ascending=True, inplace=False, kind='quicksort', na_position='last', sort_remaining=True, by=None): - # TODO: this can be combined with Series.sort_index impl as # almost identical - inplace = validate_bool_kwarg(inplace, 'inplace') # 10726 if by is not None: warnings.warn("by argument to sort_index is deprecated, " @@ -4019,6 +4025,7 @@ def _flex_compare_frame(self, other, func, str_rep, level, try_cast=True): return self._compare_frame_evaluate(other, func, str_rep, try_cast=try_cast) + @validate_keywords_as_bool('overwrite') def combine(self, other, func, fill_value=None, overwrite=True): """ Add two DataFrame objects and do not propagate NaN values, so if for a @@ -4152,6 +4159,7 @@ def combiner(x, y, needs_i8_conversion=False): return self.combine(other, combiner, overwrite=False) + @validate_keywords_as_bool('overwrite', 'raise_conflict') def update(self, other, join='left', overwrite=True, filter_func=None, raise_conflict=False): """ @@ -4742,6 +4750,7 @@ def aggregate(self, func, axis=0, *args, **kwargs): agg = aggregate + @validate_keywords_as_bool('broadcast', 'raw') def apply(self, func, axis=0, broadcast=False, raw=False, reduce=None, args=(), **kwds): """ @@ -5041,6 +5050,7 @@ def infer(x): # ---------------------------------------------------------------------- # Merging / joining methods + @validate_keywords_as_bool('ignore_index', 'verify_integrity') def append(self, other, ignore_index=False, verify_integrity=False): """ Append rows of `other` to the end of this frame, returning a new @@ -5164,6 +5174,7 @@ def append(self, other, ignore_index=False, verify_integrity=False): return concat(to_concat, ignore_index=ignore_index, verify_integrity=verify_integrity) + @validate_keywords_as_bool('sort') def join(self, other, on=None, how='left', lsuffix='', rsuffix='', sort=False): """ @@ -5524,6 +5535,7 @@ def cov(self, min_periods=None): return self._constructor(baseCov, index=idx, columns=cols) + @validate_keywords_as_bool('drop') def corrwith(self, other, axis=0, drop=False): """ Compute pairwise correlation between rows or columns of two DataFrame @@ -5577,6 +5589,7 @@ def corrwith(self, other, axis=0, drop=False): # ---------------------------------------------------------------------- # ndarray-like stats methods + @validate_keywords_as_bool('numeric_only') def count(self, axis=0, level=None, numeric_only=False): """ Return Series with number of non-NA/null observations over requested @@ -5740,6 +5753,7 @@ def f(x): return Series(result, index=labels) + @validate_keywords_as_bool('dropna') def nunique(self, axis=0, dropna=True): """ Return Series with number of distinct observations over requested @@ -5771,6 +5785,7 @@ def nunique(self, axis=0, dropna=True): """ return self.apply(Series.nunique, axis=axis, dropna=dropna) + @validate_keywords_as_bool('dropna') def idxmin(self, axis=0, skipna=True): """ Return index of first occurrence of minimum over requested axis. @@ -5802,6 +5817,7 @@ def idxmin(self, axis=0, skipna=True): result = [index[i] if i >= 0 else np.nan for i in indices] return Series(result, index=self._get_agg_axis(axis)) + @validate_keywords_as_bool('dropna') def idxmax(self, axis=0, skipna=True): """ Return index of first occurrence of maximum over requested axis. @@ -5842,6 +5858,7 @@ def _get_agg_axis(self, axis_num): else: raise ValueError('Axis must be 0 or 1 (got %r)' % axis_num) + @validate_keywords_as_bool('numeric_only') def mode(self, axis=0, numeric_only=False): """ Gets the mode(s) of each element along the axis selected. Adds a row @@ -5880,6 +5897,7 @@ def f(s): return data.apply(f, axis=axis) + @validate_keywords_as_bool('numeric_only') def quantile(self, q=0.5, axis=0, numeric_only=True, interpolation='linear'): """ @@ -5953,6 +5971,7 @@ def quantile(self, q=0.5, axis=0, numeric_only=True, return result + @validate_keywords_as_bool('copy') def to_timestamp(self, freq=None, how='start', axis=0, copy=True): """ Cast to DatetimeIndex of timestamps, at *beginning* of period @@ -5987,6 +6006,7 @@ def to_timestamp(self, freq=None, how='start', axis=0, copy=True): return self._constructor(new_data) + @validate_keywords_as_bool('copy') def to_period(self, freq=None, axis=0, copy=True): """ Convert DataFrame from DatetimeIndex to PeriodIndex with desired diff --git a/pandas/tests/frame/test_validate.py b/pandas/tests/frame/test_validate.py index 2de0e866f6e70..fbe4505b0dc67 100644 --- a/pandas/tests/frame/test_validate.py +++ b/pandas/tests/frame/test_validate.py @@ -31,3 +31,11 @@ def test_validate_bool_args(self, dataframe, func, inplace): with tm.assert_raises_regex(ValueError, msg): getattr(dataframe, func)(**kwargs) + + @pytest.mark.parametrize('keyword', ('drop', 'append', 'inplace', + 'verify_integrity')) + def test_set_index_validation(self, dataframe, keyword): + msg = 'For argument "{}" expected type bool'.format(keyword) + kw = {keyword: 'yes please'} + with tm.assert_raises_regex(ValueError, msg): + dataframe.set_index('b', **kw) diff --git a/pandas/util/_validators.py b/pandas/util/_validators.py index 728db6af5558b..e18c32494570e 100644 --- a/pandas/util/_validators.py +++ b/pandas/util/_validators.py @@ -2,6 +2,7 @@ Module that contains many useful utilities for validating data or function arguments """ +import functools import warnings from pandas.core.dtypes.common import is_bool @@ -320,3 +321,33 @@ def validate_axis_style_args(data, args, kwargs, arg_name, method_name): msg = "Cannot specify all of '{}', 'index', 'columns'." raise TypeError(msg.format(arg_name)) return out + + +def validate_keywords_as_bool(*keywords): + """For a list of keywords, ensure all are bool + + Usage + ----- + Designed to be used as decorator around methods to check many + keywords at once: + + @validate_keywords_as_bool('inplace', 'append') + def set_index(self, keys, inplace=False, append=False): + etc. + + See Also + -------- + validate_bool_kwargs + + """ + keywords = set(keywords) + + def validate_kwargs(func): + @functools.wraps(func) + def validator(*args, **kwargs): + # only validate present keywords + for kw in keywords.intersection(kwargs.keys()): + validate_bool_kwarg(kwargs[kw], kw) + return func(*args, **kwargs) + return validator + return validate_kwargs