diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 1ce1bac3b2b7b..d9747c525771b 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -24,6 +24,7 @@ time, ) from io import StringIO +from pathlib import Path import sqlite3 import warnings @@ -72,34 +73,6 @@ SQLALCHEMY_INSTALLED = False SQL_STRINGS = { - "create_iris": { - "sqlite": """CREATE TABLE iris ( - "SepalLength" REAL, - "SepalWidth" REAL, - "PetalLength" REAL, - "PetalWidth" REAL, - "Name" TEXT - )""", - "mysql": """CREATE TABLE iris ( - `SepalLength` DOUBLE, - `SepalWidth` DOUBLE, - `PetalLength` DOUBLE, - `PetalWidth` DOUBLE, - `Name` VARCHAR(200) - )""", - "postgresql": """CREATE TABLE iris ( - "SepalLength" DOUBLE PRECISION, - "SepalWidth" DOUBLE PRECISION, - "PetalLength" DOUBLE PRECISION, - "PetalWidth" DOUBLE PRECISION, - "Name" VARCHAR(200) - )""", - }, - "insert_iris": { - "sqlite": """INSERT INTO iris VALUES(?, ?, ?, ?, ?)""", - "mysql": """INSERT INTO iris VALUES(%s, %s, %s, %s, "%s");""", - "postgresql": """INSERT INTO iris VALUES(%s, %s, %s, %s, %s);""", - }, "create_test_types": { "sqlite": """CREATE TABLE types_test_data ( "TextCol" TEXT, @@ -192,7 +165,7 @@ }, "read_parameters": { "sqlite": "SELECT * FROM iris WHERE Name=? AND SepalLength=?", - "mysql": 'SELECT * FROM iris WHERE `Name`="%s" AND `SepalLength`=%s', + "mysql": "SELECT * FROM iris WHERE `Name`=%s AND `SepalLength`=%s", "postgresql": 'SELECT * FROM iris WHERE "Name"=%s AND "SepalLength"=%s', }, "read_named_parameters": { @@ -201,7 +174,7 @@ """, "mysql": """ SELECT * FROM iris WHERE - `Name`="%(name)s" AND `SepalLength`=%(length)s + `Name`=%(name)s AND `SepalLength`=%(length)s """, "postgresql": """ SELECT * FROM iris WHERE @@ -222,6 +195,73 @@ } +def iris_table_metadata(dialect: str): + from sqlalchemy import ( + REAL, + Column, + Float, + MetaData, + String, + Table, + ) + + dtype = Float if dialect == "postgresql" else REAL + metadata = MetaData() + iris = Table( + "iris", + metadata, + Column("SepalLength", dtype), + Column("SepalWidth", dtype), + Column("PetalLength", dtype), + Column("PetalWidth", dtype), + Column("Name", String(200)), + ) + return iris + + +def create_and_load_iris_sqlite3(conn: sqlite3.Connection, iris_file: Path): + cur = conn.cursor() + stmt = """CREATE TABLE iris ( + "SepalLength" REAL, + "SepalWidth" REAL, + "PetalLength" REAL, + "PetalWidth" REAL, + "Name" TEXT + )""" + cur.execute(stmt) + with iris_file.open(newline=None) as csvfile: + reader = csv.reader(csvfile) + next(reader) + stmt = "INSERT INTO iris VALUES(?, ?, ?, ?, ?)" + cur.executemany(stmt, reader) + + +def create_and_load_iris(conn, iris_file: Path, dialect: str): + from sqlalchemy import insert + from sqlalchemy.engine import Engine + + iris = iris_table_metadata(dialect) + iris.drop(conn, checkfirst=True) + iris.create(bind=conn) + + with iris_file.open(newline=None) as csvfile: + reader = csv.reader(csvfile) + header = next(reader) + params = [{key: value for key, value in zip(header, row)} for row in reader] + stmt = insert(iris).values(params) + if isinstance(conn, Engine): + with conn.connect() as conn: + conn.execute(stmt) + else: + conn.execute(stmt) + + +@pytest.fixture +def iris_path(datapath): + iris_path = datapath("io", "data", "csv", "iris.csv") + return Path(iris_path) + + @pytest.fixture def test_frame1(): columns = ["index", "A", "B", "C", "D"] @@ -341,24 +381,15 @@ def _get_exec(self): else: return self.conn.cursor() - @pytest.fixture(params=[("io", "data", "csv", "iris.csv")]) - def load_iris_data(self, datapath, request): - - iris_csv_file = datapath(*request.param) - + @pytest.fixture + def load_iris_data(self, iris_path): if not hasattr(self, "conn"): self.setup_connect() - self.drop_table("iris") - self._get_exec().execute(SQL_STRINGS["create_iris"][self.flavor]) - - with open(iris_csv_file, newline=None) as iris_csv: - r = csv.reader(iris_csv) - next(r) # skip header row - ins = SQL_STRINGS["insert_iris"][self.flavor] - - for row in r: - self._get_exec().execute(ins, row) + if isinstance(self.conn, sqlite3.Connection): + create_and_load_iris_sqlite3(self.conn, iris_path) + else: + create_and_load_iris(self.conn, iris_path, self.flavor) def _load_iris_view(self): self.drop_table("iris_view") @@ -1248,21 +1279,6 @@ def test_database_uri_string(self, test_frame1): with pytest.raises(ImportError, match="pg8000"): sql.read_sql("select * from table", db_uri) - def _make_iris_table_metadata(self): - sa = sqlalchemy - metadata = sa.MetaData() - iris = sa.Table( - "iris", - metadata, - sa.Column("SepalLength", sa.REAL), - sa.Column("SepalWidth", sa.REAL), - sa.Column("PetalLength", sa.REAL), - sa.Column("PetalWidth", sa.REAL), - sa.Column("Name", sa.TEXT), - ) - - return iris - def test_query_by_text_obj(self): # WIP : GH10846 name_text = sqlalchemy.text("select * from iris where name=:name") @@ -1272,7 +1288,7 @@ def test_query_by_text_obj(self): def test_query_by_select_obj(self): # WIP : GH10846 - iris = self._make_iris_table_metadata() + iris = iris_table_metadata(self.flavor) name_select = sqlalchemy.select([iris]).where( iris.c.Name == sqlalchemy.bindparam("name")