Skip to content

TST: refactor iris table creation in SQL test #42988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 13, 2021
138 changes: 77 additions & 61 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
time,
)
from io import StringIO
from pathlib import Path
import sqlite3
import warnings

Expand Down Expand Up @@ -72,34 +73,6 @@
SQLALCHEMY_INSTALLED = False

SQL_STRINGS = {
"create_iris": {
"sqlite": """CREATE TABLE iris (
"SepalLength" REAL,
"SepalWidth" REAL,
"PetalLength" REAL,
"PetalWidth" REAL,
"Name" TEXT
)""",
"mysql": """CREATE TABLE iris (
`SepalLength` DOUBLE,
`SepalWidth` DOUBLE,
`PetalLength` DOUBLE,
`PetalWidth` DOUBLE,
`Name` VARCHAR(200)
)""",
"postgresql": """CREATE TABLE iris (
"SepalLength" DOUBLE PRECISION,
"SepalWidth" DOUBLE PRECISION,
"PetalLength" DOUBLE PRECISION,
"PetalWidth" DOUBLE PRECISION,
"Name" VARCHAR(200)
)""",
},
"insert_iris": {
"sqlite": """INSERT INTO iris VALUES(?, ?, ?, ?, ?)""",
"mysql": """INSERT INTO iris VALUES(%s, %s, %s, %s, "%s");""",
"postgresql": """INSERT INTO iris VALUES(%s, %s, %s, %s, %s);""",
},
"create_test_types": {
"sqlite": """CREATE TABLE types_test_data (
"TextCol" TEXT,
Expand Down Expand Up @@ -192,7 +165,7 @@
},
"read_parameters": {
"sqlite": "SELECT * FROM iris WHERE Name=? AND SepalLength=?",
"mysql": 'SELECT * FROM iris WHERE `Name`="%s" AND `SepalLength`=%s',
"mysql": "SELECT * FROM iris WHERE `Name`=%s AND `SepalLength`=%s",
"postgresql": 'SELECT * FROM iris WHERE "Name"=%s AND "SepalLength"=%s',
},
"read_named_parameters": {
Expand All @@ -201,7 +174,7 @@
""",
"mysql": """
SELECT * FROM iris WHERE
`Name`="%(name)s" AND `SepalLength`=%(length)s
`Name`=%(name)s AND `SepalLength`=%(length)s
""",
"postgresql": """
SELECT * FROM iris WHERE
Expand All @@ -222,6 +195,73 @@
}


def iris_table_metadata(dialect: str):
from sqlalchemy import (
REAL,
Column,
Float,
MetaData,
String,
Table,
)

dtype = Float if dialect == "postgresql" else REAL
metadata = MetaData()
iris = Table(
"iris",
metadata,
Column("SepalLength", dtype),
Column("SepalWidth", dtype),
Column("PetalLength", dtype),
Column("PetalWidth", dtype),
Column("Name", String(200)),
)
return iris


def create_and_load_iris_sqlite3(conn: sqlite3.Connection, iris_file: Path):
cur = conn.cursor()
stmt = """CREATE TABLE iris (
"SepalLength" REAL,
"SepalWidth" REAL,
"PetalLength" REAL,
"PetalWidth" REAL,
"Name" TEXT
)"""
cur.execute(stmt)
with iris_file.open(newline=None) as csvfile:
reader = csv.reader(csvfile)
next(reader)
stmt = "INSERT INTO iris VALUES(?, ?, ?, ?, ?)"
cur.executemany(stmt, reader)


def create_and_load_iris(conn, iris_file: Path, dialect: str):
from sqlalchemy import insert
from sqlalchemy.engine import Engine

iris = iris_table_metadata(dialect)
iris.drop(conn, checkfirst=True)
iris.create(bind=conn)

with iris_file.open(newline=None) as csvfile:
reader = csv.reader(csvfile)
header = next(reader)
params = [{key: value for key, value in zip(header, row)} for row in reader]
stmt = insert(iris).values(params)
if isinstance(conn, Engine):
with conn.connect() as conn:
conn.execute(stmt)
else:
conn.execute(stmt)


@pytest.fixture
def iris_path(datapath):
iris_path = datapath("io", "data", "csv", "iris.csv")
return Path(iris_path)


@pytest.fixture
def test_frame1():
columns = ["index", "A", "B", "C", "D"]
Expand Down Expand Up @@ -341,24 +381,15 @@ def _get_exec(self):
else:
return self.conn.cursor()

@pytest.fixture(params=[("io", "data", "csv", "iris.csv")])
def load_iris_data(self, datapath, request):

iris_csv_file = datapath(*request.param)

@pytest.fixture
def load_iris_data(self, iris_path):
if not hasattr(self, "conn"):
self.setup_connect()

self.drop_table("iris")
self._get_exec().execute(SQL_STRINGS["create_iris"][self.flavor])

with open(iris_csv_file, newline=None) as iris_csv:
r = csv.reader(iris_csv)
next(r) # skip header row
ins = SQL_STRINGS["insert_iris"][self.flavor]

for row in r:
self._get_exec().execute(ins, row)
if isinstance(self.conn, sqlite3.Connection):
create_and_load_iris_sqlite3(self.conn, iris_path)
else:
create_and_load_iris(self.conn, iris_path, self.flavor)

def _load_iris_view(self):
self.drop_table("iris_view")
Expand Down Expand Up @@ -1248,21 +1279,6 @@ def test_database_uri_string(self, test_frame1):
with pytest.raises(ImportError, match="pg8000"):
sql.read_sql("select * from table", db_uri)

def _make_iris_table_metadata(self):
sa = sqlalchemy
metadata = sa.MetaData()
iris = sa.Table(
"iris",
metadata,
sa.Column("SepalLength", sa.REAL),
sa.Column("SepalWidth", sa.REAL),
sa.Column("PetalLength", sa.REAL),
sa.Column("PetalWidth", sa.REAL),
sa.Column("Name", sa.TEXT),
)

return iris

def test_query_by_text_obj(self):
# WIP : GH10846
name_text = sqlalchemy.text("select * from iris where name=:name")
Expand All @@ -1272,7 +1288,7 @@ def test_query_by_text_obj(self):

def test_query_by_select_obj(self):
# WIP : GH10846
iris = self._make_iris_table_metadata()
iris = iris_table_metadata(self.flavor)

name_select = sqlalchemy.select([iris]).where(
iris.c.Name == sqlalchemy.bindparam("name")
Expand Down