Skip to content

Commit b9895fd

Browse files
authored
GH - 624: Added dtype arg to read_sql (#649)
* GH - 624: Added dtype arg to reaf_sql * added dtype_backend and test * Update test_io.py * updated the tests * Update test_io.py * corrected the 'dtype_backend; and added tests for it * Update test_io.py
1 parent 3669013 commit b9895fd

File tree

3 files changed

+72
-0
lines changed

3 files changed

+72
-0
lines changed

pandas-stubs/_typing.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class FulldatetimeDict(YearMonthDayDict, total=False):
7777
NpDtype: TypeAlias = str | np.dtype[np.generic] | type[str | complex | bool | object]
7878
Dtype: TypeAlias = ExtensionDtype | NpDtype
7979
DtypeArg: TypeAlias = Dtype | dict[Any, Dtype]
80+
DtypeBackend: TypeAlias = Literal["pyarrow", "numpy_nullable"]
8081
BooleanDtypeArg: TypeAlias = (
8182
# Builtin bool type and its string alias
8283
type[bool] # noqa: Y030

pandas-stubs/io/sql.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ import sqlalchemy.engine
1616
import sqlalchemy.sql.expression
1717
from typing_extensions import TypeAlias
1818

19+
from pandas._libs.lib import NoDefault
1920
from pandas._typing import (
2021
DtypeArg,
22+
DtypeBackend,
2123
npt,
2224
)
2325

@@ -84,6 +86,8 @@ def read_sql(
8486
columns: list[str] = ...,
8587
*,
8688
chunksize: int,
89+
dtype: DtypeArg | None = ...,
90+
dtype_backend: DtypeBackend | NoDefault = ...,
8791
) -> Generator[DataFrame, None, None]: ...
8892
@overload
8993
def read_sql(
@@ -95,6 +99,8 @@ def read_sql(
9599
parse_dates: list[str] | dict[str, str] | dict[str, dict[str, Any]] | None = ...,
96100
columns: list[str] = ...,
97101
chunksize: None = ...,
102+
dtype: DtypeArg | None = ...,
103+
dtype_backend: DtypeBackend | NoDefault = ...,
98104
) -> DataFrame: ...
99105

100106
class PandasSQL(PandasObject):

tests/test_io.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,3 +1204,68 @@ def test_sqlalchemy_text() -> None:
12041204
assert_type(read_sql(sql_select, con=conn), DataFrame),
12051205
DataFrame,
12061206
)
1207+
1208+
1209+
def test_read_sql_dtype() -> None:
1210+
with ensure_clean() as path:
1211+
conn = sqlite3.connect(path)
1212+
df = pd.DataFrame(
1213+
data=[[0, "10/11/12"], [1, "12/11/10"]],
1214+
columns=["int_column", "date_column"],
1215+
)
1216+
check(assert_type(df.to_sql("test_data", con=conn), Union[int, None]), int)
1217+
check(
1218+
assert_type(
1219+
pd.read_sql(
1220+
"SELECT int_column, date_column FROM test_data",
1221+
con=conn,
1222+
dtype=None,
1223+
),
1224+
pd.DataFrame,
1225+
),
1226+
pd.DataFrame,
1227+
)
1228+
check(
1229+
assert_type(
1230+
pd.read_sql(
1231+
"SELECT int_column, date_column FROM test_data",
1232+
con=conn,
1233+
dtype={"int_column": int},
1234+
),
1235+
pd.DataFrame,
1236+
),
1237+
pd.DataFrame,
1238+
)
1239+
check(assert_type(DF.to_sql("test", con=conn), Union[int, None]), int)
1240+
1241+
check(
1242+
assert_type(
1243+
read_sql("select * from test", con=conn, dtype=int),
1244+
pd.DataFrame,
1245+
),
1246+
pd.DataFrame,
1247+
)
1248+
conn.close()
1249+
1250+
1251+
def test_read_sql_dtype_backend() -> None:
1252+
with ensure_clean() as path:
1253+
conn2 = sqlite3.connect(path)
1254+
check(assert_type(DF.to_sql("test", con=conn2), Union[int, None]), int)
1255+
check(
1256+
assert_type(
1257+
read_sql("select * from test", con=conn2, dtype_backend="pyarrow"),
1258+
pd.DataFrame,
1259+
),
1260+
pd.DataFrame,
1261+
)
1262+
check(
1263+
assert_type(
1264+
read_sql(
1265+
"select * from test", con=conn2, dtype_backend="numpy_nullable"
1266+
),
1267+
pd.DataFrame,
1268+
),
1269+
pd.DataFrame,
1270+
)
1271+
conn2.close()

0 commit comments

Comments
 (0)