Skip to content

Commit d84e1c1

Browse files
committed
TYP: Fix some typehints for ExtensionDtype
1 parent 3513f59 commit d84e1c1

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

pandas/_typing.py

+3
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@
131131
# to maintain type information across generic functions and parametrization
132132
T = TypeVar("T")
133133

134+
# To parameterize on same ExtensionDtype
135+
E = TypeVar("E", bound=ExtensionDtype)
136+
134137
# used in decorators to preserve the signature of the function it decorates
135138
# see https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
136139
FuncType = Callable[..., Any]

pandas/core/dtypes/base.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from pandas._typing import (
1515
DtypeObj,
16+
E,
1617
type_t,
1718
)
1819
from pandas.errors import AbstractMethodError
@@ -151,7 +152,7 @@ def na_value(self) -> object:
151152
return np.nan
152153

153154
@property
154-
def type(self) -> type[Any]:
155+
def type(self) -> type_t[Any]:
155156
"""
156157
The scalar type for the array, e.g. ``int``
157158
@@ -209,7 +210,7 @@ def construct_array_type(cls) -> type_t[ExtensionArray]:
209210
raise NotImplementedError
210211

211212
@classmethod
212-
def construct_from_string(cls, string: str):
213+
def construct_from_string(cls, string: str) -> ExtensionDtype:
213214
r"""
214215
Construct this type from a string.
215216
@@ -364,7 +365,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
364365
return None
365366

366367

367-
def register_extension_dtype(cls: type[ExtensionDtype]) -> type[ExtensionDtype]:
368+
def register_extension_dtype(cls: type[E]) -> type[E]:
368369
"""
369370
Register an ExtensionType with pandas as class decorator.
370371
@@ -420,7 +421,7 @@ def register(self, dtype: type[ExtensionDtype]) -> None:
420421

421422
self.dtypes.append(dtype)
422423

423-
def find(self, dtype: type[ExtensionDtype] | str) -> type[ExtensionDtype] | None:
424+
def find(self, dtype: type[E] | E | str) -> type[E] | E | ExtensionDtype | None:
424425
"""
425426
Parameters
426427
----------
@@ -431,10 +432,7 @@ def find(self, dtype: type[ExtensionDtype] | str) -> type[ExtensionDtype] | None
431432
return the first matching dtype, otherwise return None
432433
"""
433434
if not isinstance(dtype, str):
434-
dtype_type = dtype
435-
if not isinstance(dtype, type):
436-
dtype_type = type(dtype)
437-
if issubclass(dtype_type, ExtensionDtype):
435+
if isinstance(dtype, (ExtensionDtype, type(ExtensionDtype))):
438436
return dtype
439437

440438
return None

0 commit comments

Comments
 (0)