Skip to content

Commit b0611bc

Browse files
Improve sqlite3 types (#7641)
Read through the code in CPython and made the types more precise where possible. Co-authored-by: Sebastian Rittau <[email protected]>
1 parent 2e98c82 commit b0611bc

File tree

2 files changed

+130
-90
lines changed

2 files changed

+130
-90
lines changed

stdlib/sqlite3/dbapi2.pyi

Lines changed: 130 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
1+
import sqlite3
12
import sys
2-
from _typeshed import ReadableBuffer, Self, StrOrBytesPath
3+
from _typeshed import ReadableBuffer, Self, StrOrBytesPath, SupportsLenAndGetItem
4+
from collections.abc import Callable, Generator, Iterable, Iterator, Mapping
35
from datetime import date, datetime, time
46
from types import TracebackType
5-
from typing import Any, Callable, Generator, Iterable, Iterator, Protocol, TypeVar, overload
6-
from typing_extensions import Literal, final
7+
from typing import Any, Generic, Protocol, TypeVar, overload
8+
from typing_extensions import Literal, SupportsIndex, TypeAlias, final
79

810
_T = TypeVar("_T")
9-
_SqliteData = str | bytes | int | float | None
11+
_T_co = TypeVar("_T_co", covariant=True)
12+
_CursorT = TypeVar("_CursorT", bound=Cursor)
13+
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
14+
# Data that is passed through adapters can be of any type accepted by an adapter.
15+
_AdaptedInputData: TypeAlias = _SqliteData | Any
16+
# The Mapping must really be a dict, but making it invariant is too annoying.
17+
_Parameters: TypeAlias = SupportsLenAndGetItem[_AdaptedInputData] | Mapping[str, _AdaptedInputData]
18+
_SqliteOutputData: TypeAlias = str | bytes | int | float | None
19+
_Adapter: TypeAlias = Callable[[_T], _SqliteData]
20+
_Converter: TypeAlias = Callable[[bytes], Any]
1021

1122
paramstyle: str
1223
threadsafety: int
@@ -81,43 +92,39 @@ if sys.version_info >= (3, 7):
8192
SQLITE_SELECT: int
8293
SQLITE_TRANSACTION: int
8394
SQLITE_UPDATE: int
84-
adapters: Any
85-
converters: Any
95+
adapters: dict[tuple[type[Any], type[Any]], _Adapter[Any]]
96+
converters: dict[str, _Converter]
8697
sqlite_version: str
8798
version: str
8899

89-
# TODO: adapt needs to get probed
90-
def adapt(obj, protocol, alternate): ...
100+
# Can take or return anything depending on what's in the registry.
101+
@overload
102+
def adapt(__obj: Any, __proto: Any) -> Any: ...
103+
@overload
104+
def adapt(__obj: Any, __proto: Any, __alt: _T) -> Any | _T: ...
91105
def complete_statement(statement: str) -> bool: ...
92106

93107
if sys.version_info >= (3, 7):
94-
def connect(
95-
database: StrOrBytesPath,
96-
timeout: float = ...,
97-
detect_types: int = ...,
98-
isolation_level: str | None = ...,
99-
check_same_thread: bool = ...,
100-
factory: type[Connection] | None = ...,
101-
cached_statements: int = ...,
102-
uri: bool = ...,
103-
) -> Connection: ...
104-
108+
_DatabaseArg: TypeAlias = StrOrBytesPath
105109
else:
106-
def connect(
107-
database: bytes | str,
108-
timeout: float = ...,
109-
detect_types: int = ...,
110-
isolation_level: str | None = ...,
111-
check_same_thread: bool = ...,
112-
factory: type[Connection] | None = ...,
113-
cached_statements: int = ...,
114-
uri: bool = ...,
115-
) -> Connection: ...
110+
_DatabaseArg: TypeAlias = bytes | str
116111

112+
def connect(
113+
database: _DatabaseArg,
114+
timeout: float = ...,
115+
detect_types: int = ...,
116+
isolation_level: str | None = ...,
117+
check_same_thread: bool = ...,
118+
factory: type[Connection] | None = ...,
119+
cached_statements: int = ...,
120+
uri: bool = ...,
121+
) -> Connection: ...
117122
def enable_callback_tracebacks(__enable: bool) -> None: ...
123+
124+
# takes a pos-or-keyword argument because there is a C wrapper
118125
def enable_shared_cache(enable: int) -> None: ...
119-
def register_adapter(__type: type[_T], __caster: Callable[[_T], int | float | str | bytes]) -> None: ...
120-
def register_converter(__name: str, __converter: Callable[[bytes], Any]) -> None: ...
126+
def register_adapter(__type: type[_T], __caster: _Adapter[_T]) -> None: ...
127+
def register_converter(__name: str, __converter: _Converter) -> None: ...
121128

122129
if sys.version_info < (3, 8):
123130
class Cache:
@@ -126,7 +133,7 @@ if sys.version_info < (3, 8):
126133
def get(self, *args, **kwargs) -> None: ...
127134

128135
class _AggregateProtocol(Protocol):
129-
def step(self, value: int) -> object: ...
136+
def step(self, __value: int) -> object: ...
130137
def finalize(self) -> int: ...
131138

132139
class _SingleParamWindowAggregateClass(Protocol):
@@ -148,22 +155,44 @@ class _WindowAggregateClass(Protocol):
148155
def finalize(self) -> _SqliteData: ...
149156

150157
class Connection:
151-
DataError: Any
152-
DatabaseError: Any
153-
Error: Any
154-
IntegrityError: Any
155-
InterfaceError: Any
156-
InternalError: Any
157-
NotSupportedError: Any
158-
OperationalError: Any
159-
ProgrammingError: Any
160-
Warning: Any
161-
in_transaction: Any
162-
isolation_level: Any
158+
@property
159+
def DataError(self) -> type[sqlite3.DataError]: ...
160+
@property
161+
def DatabaseError(self) -> type[sqlite3.DatabaseError]: ...
162+
@property
163+
def Error(self) -> type[sqlite3.Error]: ...
164+
@property
165+
def IntegrityError(self) -> type[sqlite3.IntegrityError]: ...
166+
@property
167+
def InterfaceError(self) -> type[sqlite3.InterfaceError]: ...
168+
@property
169+
def InternalError(self) -> type[sqlite3.InternalError]: ...
170+
@property
171+
def NotSupportedError(self) -> type[sqlite3.NotSupportedError]: ...
172+
@property
173+
def OperationalError(self) -> type[sqlite3.OperationalError]: ...
174+
@property
175+
def ProgrammingError(self) -> type[sqlite3.ProgrammingError]: ...
176+
@property
177+
def Warning(self) -> type[sqlite3.Warning]: ...
178+
@property
179+
def in_transaction(self) -> bool: ...
180+
isolation_level: str | None # one of '', 'DEFERRED', 'IMMEDIATE' or 'EXCLUSIVE'
181+
@property
182+
def total_changes(self) -> int: ...
163183
row_factory: Any
164184
text_factory: Any
165-
total_changes: Any
166-
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
185+
def __init__(
186+
self,
187+
database: _DatabaseArg,
188+
timeout: float = ...,
189+
detect_types: int = ...,
190+
isolation_level: str | None = ...,
191+
check_same_thread: bool = ...,
192+
factory: type[Connection] | None = ...,
193+
cached_statements: int = ...,
194+
uri: bool = ...,
195+
) -> None: ...
167196
def close(self) -> None: ...
168197
if sys.version_info >= (3, 11):
169198
def blobopen(self, __table: str, __column: str, __row: int, *, readonly: bool = ..., name: str = ...) -> Blob: ...
@@ -187,17 +216,21 @@ class Connection:
187216
self, __name: str, __num_params: int, __aggregate_class: Callable[[], _WindowAggregateClass] | None
188217
) -> None: ...
189218

190-
def create_collation(self, __name: str, __callback: Any) -> None: ...
219+
def create_collation(self, __name: str, __callback: Callable[[str, str], int | SupportsIndex] | None) -> None: ...
191220
if sys.version_info >= (3, 8):
192-
def create_function(self, name: str, narg: int, func: Any, *, deterministic: bool = ...) -> None: ...
221+
def create_function(
222+
self, name: str, narg: int, func: Callable[..., _SqliteData], *, deterministic: bool = ...
223+
) -> None: ...
193224
else:
194-
def create_function(self, name: str, num_params: int, func: Any) -> None: ...
225+
def create_function(self, name: str, num_params: int, func: Callable[..., _SqliteData]) -> None: ...
195226

196-
def cursor(self, cursorClass: type | None = ...) -> Cursor: ...
197-
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Cursor: ...
198-
# TODO: please check in executemany() if seq_of_parameters type is possible like this
199-
def executemany(self, __sql: str, __parameters: Iterable[Iterable[Any]]) -> Cursor: ...
200-
def executescript(self, __sql_script: bytes | str) -> Cursor: ...
227+
@overload
228+
def cursor(self, cursorClass: None = ...) -> Cursor: ...
229+
@overload
230+
def cursor(self, cursorClass: Callable[[], _CursorT]) -> _CursorT: ...
231+
def execute(self, sql: str, parameters: _Parameters = ...) -> Cursor: ...
232+
def executemany(self, __sql: str, __parameters: Iterable[_Parameters]) -> Cursor: ...
233+
def executescript(self, __sql_script: str) -> Cursor: ...
201234
def interrupt(self) -> None: ...
202235
def iterdump(self) -> Generator[str, None, None]: ...
203236
def rollback(self) -> None: ...
@@ -208,8 +241,8 @@ class Connection:
208241
def set_trace_callback(self, trace_callback: Callable[[str], object] | None) -> None: ...
209242
# enable_load_extension and load_extension is not available on python distributions compiled
210243
# without sqlite3 loadable extension support. see footnotes https://docs.python.org/3/library/sqlite3.html#f1
211-
def enable_load_extension(self, enabled: bool) -> None: ...
212-
def load_extension(self, path: str) -> None: ...
244+
def enable_load_extension(self, __enabled: bool) -> None: ...
245+
def load_extension(self, __name: str) -> None: ...
213246
if sys.version_info >= (3, 7):
214247
def backup(
215248
self,
@@ -226,29 +259,32 @@ class Connection:
226259
def serialize(self, *, name: str = ...) -> bytes: ...
227260
def deserialize(self, __data: ReadableBuffer, *, name: str = ...) -> None: ...
228261

229-
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
262+
def __call__(self, __sql: str) -> _Statement: ...
230263
def __enter__(self: Self) -> Self: ...
231264
def __exit__(
232265
self, __type: type[BaseException] | None, __value: BaseException | None, __traceback: TracebackType | None
233266
) -> Literal[False]: ...
234267

235268
class Cursor(Iterator[Any]):
236-
arraysize: Any
237-
connection: Any
238-
description: Any
239-
lastrowid: Any
240-
row_factory: Any
241-
rowcount: int
242-
# TODO: Cursor class accepts exactly 1 argument
243-
# required type is sqlite3.Connection (which is imported as _Connection)
244-
# however, the name of the __init__ variable is unknown
245-
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
269+
arraysize: int
270+
@property
271+
def connection(self) -> Connection: ...
272+
@property
273+
def description(self) -> tuple[tuple[str, None, None, None, None, None, None], ...] | None: ...
274+
@property
275+
def lastrowid(self) -> int | None: ...
276+
row_factory: Callable[[Cursor, Row[Any]], object] | None
277+
@property
278+
def rowcount(self) -> int: ...
279+
def __init__(self, __cursor: Connection) -> None: ...
246280
def close(self) -> None: ...
247-
def execute(self, __sql: str, __parameters: Iterable[Any] = ...) -> Cursor: ...
248-
def executemany(self, __sql: str, __seq_of_parameters: Iterable[Iterable[Any]]) -> Cursor: ...
249-
def executescript(self, __sql_script: bytes | str) -> Cursor: ...
281+
def execute(self: Self, __sql: str, __parameters: _Parameters = ...) -> Self: ...
282+
def executemany(self: Self, __sql: str, __seq_of_parameters: Iterable[_Parameters]) -> Self: ...
283+
def executescript(self, __sql_script: str) -> Cursor: ...
250284
def fetchall(self) -> list[Any]: ...
251285
def fetchmany(self, size: int | None = ...) -> list[Any]: ...
286+
# Returns either a row (as created by the row_factory) or None, but
287+
# putting None in the return annotation causes annoying false positives.
252288
def fetchone(self) -> Any: ...
253289
def setinputsizes(self, __sizes: object) -> None: ... # does nothing
254290
def setoutputsize(self, __size: object, __column: object = ...) -> None: ... # does nothing
@@ -273,28 +309,37 @@ OptimizedUnicode = str
273309

274310
@final
275311
class PrepareProtocol:
276-
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
312+
def __init__(self, *args: object, **kwargs: object) -> None: ...
277313

278314
class ProgrammingError(DatabaseError): ...
279315

280-
class Row:
281-
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
282-
def keys(self): ...
283-
def __eq__(self, __other): ...
284-
def __ge__(self, __other): ...
285-
def __getitem__(self, __index): ...
286-
def __gt__(self, __other): ...
287-
def __hash__(self): ...
288-
def __iter__(self): ...
289-
def __le__(self, __other): ...
290-
def __len__(self): ...
291-
def __lt__(self, __other): ...
292-
def __ne__(self, __other): ...
316+
class Row(Generic[_T_co]):
317+
def __init__(self, __cursor: Cursor, __data: tuple[_T_co, ...]) -> None: ...
318+
def keys(self) -> list[str]: ...
319+
@overload
320+
def __getitem__(self, __index: int | str) -> _T_co: ...
321+
@overload
322+
def __getitem__(self, __index: slice) -> tuple[_T_co, ...]: ...
323+
def __hash__(self) -> int: ...
324+
def __iter__(self) -> Iterator[_T_co]: ...
325+
def __len__(self) -> int: ...
326+
# These return NotImplemented for anything that is not a Row.
327+
def __eq__(self, __other: object) -> bool: ...
328+
def __ge__(self, __other: object) -> bool: ...
329+
def __gt__(self, __other: object) -> bool: ...
330+
def __le__(self, __other: object) -> bool: ...
331+
def __lt__(self, __other: object) -> bool: ...
332+
def __ne__(self, __other: object) -> bool: ...
293333

294-
if sys.version_info < (3, 8):
334+
if sys.version_info >= (3, 8):
335+
@final
336+
class _Statement: ...
337+
338+
else:
295339
@final
296340
class Statement:
297341
def __init__(self, *args, **kwargs): ...
342+
_Statement: TypeAlias = Statement
298343

299344
class Warning(Exception): ...
300345

tests/stubtest_allowlists/win32-py310.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
sqlite3.Connection.enable_load_extension
2-
sqlite3.Connection.load_extension
3-
sqlite3.dbapi2.Connection.enable_load_extension
4-
sqlite3.dbapi2.Connection.load_extension
5-
61
# pathlib methods that exist on Windows, but always raise NotImplementedError,
72
# so are omitted from the stub
83
pathlib.WindowsPath.is_mount

0 commit comments

Comments
 (0)