Skip to content

Commit f3a0fea

Browse files
committed
typing
1 parent 55e3bff commit f3a0fea

File tree

1 file changed

+68
-29
lines changed

1 file changed

+68
-29
lines changed

pandas/core/missing.py

+68-29
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
Routines for filling missing data.
33
"""
44
from functools import partial
5-
from typing import Any, List, Optional, Set, Union
5+
from typing import Any, Callable, List, Optional, Set, Tuple, Union
66

77
import numpy as np
88

99
from pandas._libs import algos, lib
10-
from pandas._typing import ArrayLike, Axis, DtypeObj
10+
from pandas._typing import ArrayLike, Axis, DtypeObj, IndexLabel, Scalar
1111
from pandas.compat._optional import import_optional_dependency
1212

1313
from pandas.core.dtypes.cast import infer_dtype_from_array
@@ -20,7 +20,9 @@
2020
from pandas.core.dtypes.missing import isna
2121

2222

23-
def mask_missing(arr: ArrayLike, values_to_mask) -> np.ndarray:
23+
def mask_missing(
24+
arr: ArrayLike, values_to_mask: Union[List, Tuple, Scalar]
25+
) -> np.ndarray:
2426
"""
2527
Return a masking array of same size/shape as arr
2628
with entries equaling any member of values_to_mask set to True
@@ -58,7 +60,7 @@ def mask_missing(arr: ArrayLike, values_to_mask) -> np.ndarray:
5860
return mask
5961

6062

61-
def clean_fill_method(method, allow_nearest: bool = False):
63+
def clean_fill_method(method: str, allow_nearest: bool = False) -> Optional[str]:
6264
# asfreq is compat for resampling
6365
if method in [None, "asfreq"]:
6466
return None
@@ -117,7 +119,7 @@ def clean_interp_method(method: str, **kwargs) -> str:
117119
return method
118120

119121

120-
def find_valid_index(values, how: str):
122+
def find_valid_index(values: ArrayLike, how: str) -> Optional[int]:
121123
"""
122124
Retrieves the index of the first valid value.
123125
@@ -218,8 +220,13 @@ def interpolate_1d(
218220

219221
# These are sets of index pointers to invalid values... i.e. {0, 1, etc...
220222
all_nans = set(np.flatnonzero(invalid))
221-
start_nans = set(range(find_valid_index(yvalues, "first")))
222-
end_nans = set(range(1 + find_valid_index(yvalues, "last"), len(valid)))
223+
224+
start_nan_idx = find_valid_index(yvalues, "first")
225+
start_nans = set() if start_nan_idx is None else set(range(start_nan_idx))
226+
227+
end_nan_idx = find_valid_index(yvalues, "last")
228+
end_nans = set() if end_nan_idx is None else set(range(1 + end_nan_idx, len(valid)))
229+
223230
mid_nans = all_nans - start_nans - end_nans
224231

225232
# Like the sets above, preserve_nans contains indices of invalid values,
@@ -406,7 +413,13 @@ def _from_derivatives(xi, yi, x, order=None, der=0, extrapolate=False):
406413
return m(x)
407414

408415

409-
def _akima_interpolate(xi, yi, x, der=0, axis=0):
416+
def _akima_interpolate(
417+
xi: ArrayLike,
418+
yi: ArrayLike,
419+
x: Union[Scalar, ArrayLike],
420+
der: Optional[int] = 0,
421+
axis: Optional[int] = 0,
422+
) -> Union[Scalar, ArrayLike]:
410423
"""
411424
Convenience function for akima interpolation.
412425
xi and yi are arrays of values used to approximate some function f,
@@ -449,7 +462,14 @@ def _akima_interpolate(xi, yi, x, der=0, axis=0):
449462
return P(x, nu=der)
450463

451464

452-
def _cubicspline_interpolate(xi, yi, x, axis=0, bc_type="not-a-knot", extrapolate=None):
465+
def _cubicspline_interpolate(
466+
xi: ArrayLike,
467+
yi: ArrayLike,
468+
x: Union[ArrayLike, Scalar],
469+
axis: Optional[int] = 0,
470+
bc_type: Union[str, Tuple] = "not-a-knot",
471+
extrapolate: Optional[Union[bool, str]] = None,
472+
) -> Union[ArrayLike, Scalar]:
453473
"""
454474
Convenience function for cubic spline data interpolator.
455475
@@ -557,6 +577,8 @@ def _interpolate_with_limit_area(
557577
first = find_valid_index(values, "first")
558578
last = find_valid_index(values, "last")
559579

580+
assert first is not None and last is not None
581+
560582
values = interpolate_2d(
561583
values,
562584
method=method,
@@ -574,12 +596,12 @@ def _interpolate_with_limit_area(
574596

575597

576598
def interpolate_2d(
577-
values,
599+
values: np.ndarray,
578600
method: str = "pad",
579601
axis: Axis = 0,
580602
limit: Optional[int] = None,
581603
limit_area: Optional[str] = None,
582-
):
604+
) -> np.ndarray:
583605
"""
584606
Perform an actual interpolation of values. Values will be made 2-d if
585607
needed; fills inplace and returns the result.
@@ -625,7 +647,10 @@ def interpolate_2d(
625647
raise AssertionError("cannot interpolate on a ndim == 1 with axis != 0")
626648
values = values.reshape(tuple((1,) + values.shape))
627649

628-
method = clean_fill_method(method)
650+
method_cleaned = clean_fill_method(method)
651+
assert isinstance(method_cleaned, str)
652+
method = method_cleaned
653+
629654
tvalues = transf(values)
630655
if method == "pad":
631656
result = _pad_2d(tvalues, limit=limit)
@@ -644,7 +669,9 @@ def interpolate_2d(
644669
return result
645670

646671

647-
def _cast_values_for_fillna(values, dtype: DtypeObj, has_mask: bool):
672+
def _cast_values_for_fillna(
673+
values: ArrayLike, dtype: DtypeObj, has_mask: bool
674+
) -> ArrayLike:
648675
"""
649676
Cast values to a dtype that algos.pad and algos.backfill can handle.
650677
"""
@@ -663,34 +690,41 @@ def _cast_values_for_fillna(values, dtype: DtypeObj, has_mask: bool):
663690
return values
664691

665692

666-
def _fillna_prep(values, mask=None):
693+
def _fillna_prep(
694+
values: np.ndarray, mask: Optional[np.ndarray] = None
695+
) -> Tuple[np.ndarray, np.ndarray]:
667696
# boilerplate for _pad_1d, _backfill_1d, _pad_2d, _backfill_2d
668-
dtype = values.dtype
669697

670698
has_mask = mask is not None
671-
if not has_mask:
672-
# This needs to occur before datetime/timedeltas are cast to int64
673-
mask = isna(values)
674699

675-
values = _cast_values_for_fillna(values, dtype, has_mask)
700+
# This needs to occur before datetime/timedeltas are cast to int64
701+
mask = isna(values) if mask is None else mask
676702

703+
values = _cast_values_for_fillna(values, values.dtype, has_mask)
677704
mask = mask.view(np.uint8)
705+
678706
return values, mask
679707

680708

681-
def _pad_1d(values, limit=None, mask=None):
709+
def _pad_1d(
710+
values: np.ndarray, limit: Optional[int] = None, mask: Optional[np.ndarray] = None
711+
):
682712
values, mask = _fillna_prep(values, mask)
683713
algos.pad_inplace(values, mask, limit=limit)
684714
return values
685715

686716

687-
def _backfill_1d(values, limit=None, mask=None):
717+
def _backfill_1d(
718+
values: np.ndarray, limit: Optional[int] = None, mask: Optional[np.ndarray] = None
719+
):
688720
values, mask = _fillna_prep(values, mask)
689721
algos.backfill_inplace(values, mask, limit=limit)
690722
return values
691723

692724

693-
def _pad_2d(values, limit=None, mask=None):
725+
def _pad_2d(
726+
values: np.ndarray, limit: Optional[int] = None, mask: Optional[np.ndarray] = None
727+
):
694728
values, mask = _fillna_prep(values, mask)
695729

696730
if np.all(values.shape):
@@ -701,7 +735,9 @@ def _pad_2d(values, limit=None, mask=None):
701735
return values
702736

703737

704-
def _backfill_2d(values, limit=None, mask=None):
738+
def _backfill_2d(
739+
values: np.ndarray, limit: Optional[int] = None, mask: Optional[np.ndarray] = None
740+
):
705741
values, mask = _fillna_prep(values, mask)
706742

707743
if np.all(values.shape):
@@ -715,16 +751,19 @@ def _backfill_2d(values, limit=None, mask=None):
715751
_fill_methods = {"pad": _pad_1d, "backfill": _backfill_1d}
716752

717753

718-
def get_fill_func(method):
719-
method = clean_fill_method(method)
720-
return _fill_methods[method]
754+
def get_fill_func(method: str) -> Callable:
755+
method_cleaned = clean_fill_method(method)
756+
assert isinstance(method_cleaned, str)
757+
return _fill_methods[method_cleaned]
721758

722759

723-
def clean_reindex_fill_method(method):
760+
def clean_reindex_fill_method(method: str):
724761
return clean_fill_method(method, allow_nearest=True)
725762

726763

727-
def _interp_limit(invalid, fw_limit, bw_limit):
764+
def _interp_limit(
765+
invalid: np.ndarray, fw_limit: Optional[int], bw_limit: Optional[int]
766+
) -> Set[IndexLabel]:
728767
"""
729768
Get indexers of values that won't be filled
730769
because they exceed the limits.
@@ -789,7 +828,7 @@ def inner(invalid, limit):
789828
return f_idx & b_idx
790829

791830

792-
def _rolling_window(a: np.ndarray, window: int):
831+
def _rolling_window(a: np.ndarray, window: int) -> np.ndarray:
793832
"""
794833
[True, True, False, True, False], 2 ->
795834

0 commit comments

Comments
 (0)