Skip to content

Commit b8d253f

Browse files
authored
Merge branch 'master' into 585
2 parents 69d08f2 + b0611bc commit b8d253f

File tree

2 files changed

+130
-91
lines changed

2 files changed

+130
-91
lines changed

stdlib/sqlite3/dbapi2.pyi

Lines changed: 130 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
1+
import sqlite3
12
import sys
2-
from _typeshed import ReadableBuffer, Self, StrOrBytesPath
3-
from collections.abc import Callable, Generator, Iterable, Iterator
3+
from _typeshed import ReadableBuffer, Self, StrOrBytesPath, SupportsLenAndGetItem
4+
from collections.abc import Callable, Generator, Iterable, Iterator, Mapping
45
from datetime import date, datetime, time
56
from types import TracebackType
6-
from typing import Any, Protocol, TypeVar, overload
7-
from typing_extensions import Literal, final
7+
from typing import Any, Generic, Protocol, TypeVar, overload
8+
from typing_extensions import Literal, SupportsIndex, TypeAlias, final
89

910
_T = TypeVar("_T")
10-
_SqliteData = str | bytes | int | float | None
11+
_T_co = TypeVar("_T_co", covariant=True)
12+
_CursorT = TypeVar("_CursorT", bound=Cursor)
13+
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
14+
# Data that is passed through adapters can be of any type accepted by an adapter.
15+
_AdaptedInputData: TypeAlias = _SqliteData | Any
16+
# The Mapping must really be a dict, but making it invariant is too annoying.
17+
_Parameters: TypeAlias = SupportsLenAndGetItem[_AdaptedInputData] | Mapping[str, _AdaptedInputData]
18+
_SqliteOutputData: TypeAlias = str | bytes | int | float | None
19+
_Adapter: TypeAlias = Callable[[_T], _SqliteData]
20+
_Converter: TypeAlias = Callable[[bytes], Any]
1121

1222
paramstyle: str
1323
threadsafety: int
@@ -82,43 +92,39 @@ if sys.version_info >= (3, 7):
8292
SQLITE_SELECT: int
8393
SQLITE_TRANSACTION: int
8494
SQLITE_UPDATE: int
85-
adapters: Any
86-
converters: Any
95+
adapters: dict[tuple[type[Any], type[Any]], _Adapter[Any]]
96+
converters: dict[str, _Converter]
8797
sqlite_version: str
8898
version: str
8999

90-
# TODO: adapt needs to get probed
91-
def adapt(obj, protocol, alternate): ...
100+
# Can take or return anything depending on what's in the registry.
101+
@overload
102+
def adapt(__obj: Any, __proto: Any) -> Any: ...
103+
@overload
104+
def adapt(__obj: Any, __proto: Any, __alt: _T) -> Any | _T: ...
92105
def complete_statement(statement: str) -> bool: ...
93106

94107
if sys.version_info >= (3, 7):
95-
def connect(
96-
database: StrOrBytesPath,
97-
timeout: float = ...,
98-
detect_types: int = ...,
99-
isolation_level: str | None = ...,
100-
check_same_thread: bool = ...,
101-
factory: type[Connection] | None = ...,
102-
cached_statements: int = ...,
103-
uri: bool = ...,
104-
) -> Connection: ...
105-
108+
_DatabaseArg: TypeAlias = StrOrBytesPath
106109
else:
107-
def connect(
108-
database: bytes | str,
109-
timeout: float = ...,
110-
detect_types: int = ...,
111-
isolation_level: str | None = ...,
112-
check_same_thread: bool = ...,
113-
factory: type[Connection] | None = ...,
114-
cached_statements: int = ...,
115-
uri: bool = ...,
116-
) -> Connection: ...
110+
_DatabaseArg: TypeAlias = bytes | str
117111

112+
def connect(
113+
database: _DatabaseArg,
114+
timeout: float = ...,
115+
detect_types: int = ...,
116+
isolation_level: str | None = ...,
117+
check_same_thread: bool = ...,
118+
factory: type[Connection] | None = ...,
119+
cached_statements: int = ...,
120+
uri: bool = ...,
121+
) -> Connection: ...
118122
def enable_callback_tracebacks(__enable: bool) -> None: ...
123+
124+
# takes a pos-or-keyword argument because there is a C wrapper
119125
def enable_shared_cache(enable: int) -> None: ...
120-
def register_adapter(__type: type[_T], __caster: Callable[[_T], int | float | str | bytes]) -> None: ...
121-
def register_converter(__name: str, __converter: Callable[[bytes], Any]) -> None: ...
126+
def register_adapter(__type: type[_T], __caster: _Adapter[_T]) -> None: ...
127+
def register_converter(__name: str, __converter: _Converter) -> None: ...
122128

123129
if sys.version_info < (3, 8):
124130
class Cache:
@@ -127,7 +133,7 @@ if sys.version_info < (3, 8):
127133
def get(self, *args, **kwargs) -> None: ...
128134

129135
class _AggregateProtocol(Protocol):
130-
def step(self, value: int) -> object: ...
136+
def step(self, __value: int) -> object: ...
131137
def finalize(self) -> int: ...
132138

133139
class _SingleParamWindowAggregateClass(Protocol):
@@ -149,22 +155,44 @@ class _WindowAggregateClass(Protocol):
149155
def finalize(self) -> _SqliteData: ...
150156

151157
class Connection:
152-
DataError: Any
153-
DatabaseError: Any
154-
Error: Any
155-
IntegrityError: Any
156-
InterfaceError: Any
157-
InternalError: Any
158-
NotSupportedError: Any
159-
OperationalError: Any
160-
ProgrammingError: Any
161-
Warning: Any
162-
in_transaction: Any
163-
isolation_level: Any
158+
@property
159+
def DataError(self) -> type[sqlite3.DataError]: ...
160+
@property
161+
def DatabaseError(self) -> type[sqlite3.DatabaseError]: ...
162+
@property
163+
def Error(self) -> type[sqlite3.Error]: ...
164+
@property
165+
def IntegrityError(self) -> type[sqlite3.IntegrityError]: ...
166+
@property
167+
def InterfaceError(self) -> type[sqlite3.InterfaceError]: ...
168+
@property
169+
def InternalError(self) -> type[sqlite3.InternalError]: ...
170+
@property
171+
def NotSupportedError(self) -> type[sqlite3.NotSupportedError]: ...
172+
@property
173+
def OperationalError(self) -> type[sqlite3.OperationalError]: ...
174+
@property
175+
def ProgrammingError(self) -> type[sqlite3.ProgrammingError]: ...
176+
@property
177+
def Warning(self) -> type[sqlite3.Warning]: ...
178+
@property
179+
def in_transaction(self) -> bool: ...
180+
isolation_level: str | None # one of '', 'DEFERRED', 'IMMEDIATE' or 'EXCLUSIVE'
181+
@property
182+
def total_changes(self) -> int: ...
164183
row_factory: Any
165184
text_factory: Any
166-
total_changes: Any
167-
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
185+
def __init__(
186+
self,
187+
database: _DatabaseArg,
188+
timeout: float = ...,
189+
detect_types: int = ...,
190+
isolation_level: str | None = ...,
191+
check_same_thread: bool = ...,
192+
factory: type[Connection] | None = ...,
193+
cached_statements: int = ...,
194+
uri: bool = ...,
195+
) -> None: ...
168196
def close(self) -> None: ...
169197
if sys.version_info >= (3, 11):
170198
def blobopen(self, __table: str, __column: str, __row: int, *, readonly: bool = ..., name: str = ...) -> Blob: ...
@@ -188,17 +216,21 @@ class Connection:
188216
self, __name: str, __num_params: int, __aggregate_class: Callable[[], _WindowAggregateClass] | None
189217
) -> None: ...
190218

191-
def create_collation(self, __name: str, __callback: Any) -> None: ...
219+
def create_collation(self, __name: str, __callback: Callable[[str, str], int | SupportsIndex] | None) -> None: ...
192220
if sys.version_info >= (3, 8):
193-
def create_function(self, name: str, narg: int, func: Any, *, deterministic: bool = ...) -> None: ...
221+
def create_function(
222+
self, name: str, narg: int, func: Callable[..., _SqliteData], *, deterministic: bool = ...
223+
) -> None: ...
194224
else:
195-
def create_function(self, name: str, num_params: int, func: Any) -> None: ...
225+
def create_function(self, name: str, num_params: int, func: Callable[..., _SqliteData]) -> None: ...
196226

197-
def cursor(self, cursorClass: type | None = ...) -> Cursor: ...
198-
def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Cursor: ...
199-
# TODO: please check in executemany() if seq_of_parameters type is possible like this
200-
def executemany(self, __sql: str, __parameters: Iterable[Iterable[Any]]) -> Cursor: ...
201-
def executescript(self, __sql_script: bytes | str) -> Cursor: ...
227+
@overload
228+
def cursor(self, cursorClass: None = ...) -> Cursor: ...
229+
@overload
230+
def cursor(self, cursorClass: Callable[[], _CursorT]) -> _CursorT: ...
231+
def execute(self, sql: str, parameters: _Parameters = ...) -> Cursor: ...
232+
def executemany(self, __sql: str, __parameters: Iterable[_Parameters]) -> Cursor: ...
233+
def executescript(self, __sql_script: str) -> Cursor: ...
202234
def interrupt(self) -> None: ...
203235
def iterdump(self) -> Generator[str, None, None]: ...
204236
def rollback(self) -> None: ...
@@ -209,8 +241,8 @@ class Connection:
209241
def set_trace_callback(self, trace_callback: Callable[[str], object] | None) -> None: ...
210242
# enable_load_extension and load_extension is not available on python distributions compiled
211243
# without sqlite3 loadable extension support. see footnotes https://docs.python.org/3/library/sqlite3.html#f1
212-
def enable_load_extension(self, enabled: bool) -> None: ...
213-
def load_extension(self, path: str) -> None: ...
244+
def enable_load_extension(self, __enabled: bool) -> None: ...
245+
def load_extension(self, __name: str) -> None: ...
214246
if sys.version_info >= (3, 7):
215247
def backup(
216248
self,
@@ -227,29 +259,32 @@ class Connection:
227259
def serialize(self, *, name: str = ...) -> bytes: ...
228260
def deserialize(self, __data: ReadableBuffer, *, name: str = ...) -> None: ...
229261

230-
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
262+
def __call__(self, __sql: str) -> _Statement: ...
231263
def __enter__(self: Self) -> Self: ...
232264
def __exit__(
233265
self, __type: type[BaseException] | None, __value: BaseException | None, __traceback: TracebackType | None
234266
) -> Literal[False]: ...
235267

236268
class Cursor(Iterator[Any]):
237-
arraysize: Any
238-
connection: Any
239-
description: Any
240-
lastrowid: Any
241-
row_factory: Any
242-
rowcount: int
243-
# TODO: Cursor class accepts exactly 1 argument
244-
# required type is sqlite3.Connection (which is imported as _Connection)
245-
# however, the name of the __init__ variable is unknown
246-
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
269+
arraysize: int
270+
@property
271+
def connection(self) -> Connection: ...
272+
@property
273+
def description(self) -> tuple[tuple[str, None, None, None, None, None, None], ...] | None: ...
274+
@property
275+
def lastrowid(self) -> int | None: ...
276+
row_factory: Callable[[Cursor, Row[Any]], object] | None
277+
@property
278+
def rowcount(self) -> int: ...
279+
def __init__(self, __cursor: Connection) -> None: ...
247280
def close(self) -> None: ...
248-
def execute(self, __sql: str, __parameters: Iterable[Any] = ...) -> Cursor: ...
249-
def executemany(self, __sql: str, __seq_of_parameters: Iterable[Iterable[Any]]) -> Cursor: ...
250-
def executescript(self, __sql_script: bytes | str) -> Cursor: ...
281+
def execute(self: Self, __sql: str, __parameters: _Parameters = ...) -> Self: ...
282+
def executemany(self: Self, __sql: str, __seq_of_parameters: Iterable[_Parameters]) -> Self: ...
283+
def executescript(self, __sql_script: str) -> Cursor: ...
251284
def fetchall(self) -> list[Any]: ...
252285
def fetchmany(self, size: int | None = ...) -> list[Any]: ...
286+
# Returns either a row (as created by the row_factory) or None, but
287+
# putting None in the return annotation causes annoying false positives.
253288
def fetchone(self) -> Any: ...
254289
def setinputsizes(self, __sizes: object) -> None: ... # does nothing
255290
def setoutputsize(self, __size: object, __column: object = ...) -> None: ... # does nothing
@@ -274,28 +309,37 @@ OptimizedUnicode = str
274309

275310
@final
276311
class PrepareProtocol:
277-
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
312+
def __init__(self, *args: object, **kwargs: object) -> None: ...
278313

279314
class ProgrammingError(DatabaseError): ...
280315

281-
class Row:
282-
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
283-
def keys(self): ...
284-
def __eq__(self, __other): ...
285-
def __ge__(self, __other): ...
286-
def __getitem__(self, __index): ...
287-
def __gt__(self, __other): ...
288-
def __hash__(self): ...
289-
def __iter__(self): ...
290-
def __le__(self, __other): ...
291-
def __len__(self): ...
292-
def __lt__(self, __other): ...
293-
def __ne__(self, __other): ...
316+
class Row(Generic[_T_co]):
317+
def __init__(self, __cursor: Cursor, __data: tuple[_T_co, ...]) -> None: ...
318+
def keys(self) -> list[str]: ...
319+
@overload
320+
def __getitem__(self, __index: int | str) -> _T_co: ...
321+
@overload
322+
def __getitem__(self, __index: slice) -> tuple[_T_co, ...]: ...
323+
def __hash__(self) -> int: ...
324+
def __iter__(self) -> Iterator[_T_co]: ...
325+
def __len__(self) -> int: ...
326+
# These return NotImplemented for anything that is not a Row.
327+
def __eq__(self, __other: object) -> bool: ...
328+
def __ge__(self, __other: object) -> bool: ...
329+
def __gt__(self, __other: object) -> bool: ...
330+
def __le__(self, __other: object) -> bool: ...
331+
def __lt__(self, __other: object) -> bool: ...
332+
def __ne__(self, __other: object) -> bool: ...
294333

295-
if sys.version_info < (3, 8):
334+
if sys.version_info >= (3, 8):
335+
@final
336+
class _Statement: ...
337+
338+
else:
296339
@final
297340
class Statement:
298341
def __init__(self, *args, **kwargs): ...
342+
_Statement: TypeAlias = Statement
299343

300344
class Warning(Exception): ...
301345

tests/stubtest_allowlists/win32-py310.txt

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
sqlite3.Connection.enable_load_extension
2-
sqlite3.Connection.load_extension
3-
sqlite3.dbapi2.Connection.enable_load_extension
4-
sqlite3.dbapi2.Connection.load_extension
5-
61
# pathlib methods that exist on Windows, but always raise NotImplementedError,
72
# so are omitted from the stub
83
pathlib.WindowsPath.is_mount

0 commit comments

Comments
 (0)