diff --git a/pandas/core/generic.py b/pandas/core/generic.py
index 2dfa505bc4932..1c80e102ea0d6 100644
--- a/pandas/core/generic.py
+++ b/pandas/core/generic.py
@@ -2821,6 +2821,7 @@ def to_sql(
         con,
         schema: str | None = None,
         if_exists: str = "fail",
+        on_row_conflict: str = "fail",
         index: bool_t = True,
         index_label: IndexLabel = None,
         chunksize: int | None = None,
@@ -2854,6 +2855,15 @@ def to_sql(
             * replace: Drop the table before inserting new values.
             * append: Insert new values to the existing table.
 
+        on_row_conflict : {'fail', 'replace', 'ignore'}, default 'fail'
+            How to behave if a row with the same primary key already exists.
+
+            * fail: Let the database raise its integrity error.
+            * replace: Overwrite the existing row with the incoming data.
+            * ignore: Keep the existing data and discard the incoming row.
+
+            .. versionadded:: 1.5.0
+
         index : bool, default True
             Write DataFrame index as a column. Uses `index_label` as the column
             name in the table.
@@ -2898,6 +2908,9 @@ def to_sql(
         ValueError
             When the table already exists and `if_exists` is 'fail' (the
             default).
+        ValueError
+            When `on_row_conflict` is 'replace' or 'ignore' and `if_exists` is
+            not 'append'.
 
         See Also
         --------
@@ -2990,6 +3003,7 @@ def to_sql(
             con,
             schema=schema,
             if_exists=if_exists,
+            on_row_conflict=on_row_conflict,
            index=index,
            index_label=index_label,
            chunksize=chunksize,
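
Before the implementation below, a sketch of how the new keyword is intended to be driven from the public API. The table and data are invented, and an SQLAlchemy engine is assumed since the conflict handling is implemented on the SQLAlchemy path; note that `DataFrame.to_sql` itself returns `None` in this pandas version, so the result is checked by reading back:

```python
import pandas as pd
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")  # in-memory scratch database
with engine.begin() as conn:
    conn.execute(text('CREATE TABLE items ("a" INTEGER PRIMARY KEY, "b" TEXT)'))
    conn.execute(text("INSERT INTO items VALUES (1, 'old')"))

df = pd.DataFrame({"a": [1, 2], "b": ["updated", "brand new"]})

# on_row_conflict is only valid together with if_exists="append":
# row a=1 clashes with the existing primary key and is updated in place,
# row a=2 has a fresh key and is inserted as usual
df.to_sql("items", engine, if_exists="append", on_row_conflict="replace", index=False)

print(pd.read_sql("SELECT * FROM items ORDER BY a", engine))
#    a          b
# 0  1    updated
# 1  2  brand new
```
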
diff --git a/pandas/io/sql.py b/pandas/io/sql.py
index b4432abd1061a..f851a40b0f2f2 100644
--- a/pandas/io/sql.py
+++ b/pandas/io/sql.py
@@ -24,6 +24,7 @@
 import warnings
 
 import numpy as np
+import sqlalchemy
 
 import pandas._libs.lib as lib
 from pandas._typing import (
@@ -604,6 +605,7 @@ def to_sql(
     con,
     schema: str | None = None,
     if_exists: str = "fail",
+    on_row_conflict: str = "fail",
     index: bool = True,
     index_label: IndexLabel = None,
     chunksize: int | None = None,
@@ -632,6 +634,16 @@ def to_sql(
         - fail: If table exists, do nothing.
         - replace: If table exists, drop it, recreate it, and insert data.
         - append: If table exists, insert data. Create if does not exist.
+    on_row_conflict : {'fail', 'replace', 'ignore'}, default 'fail'
+        Determine insertion behavior in case of a primary key clash.
+        - fail: Do nothing to handle primary key clashes; the database
+          will raise an error.
+        - replace: Update existing rows in the database that have primary
+          key clashes, and append the remaining rows with non-conflicting
+          primary keys.
+        - ignore: Discard incoming rows with primary key clashes, and
+          insert only the incoming rows with non-conflicting primary keys.
+
+        .. versionadded:: 1.5.0
+
     index : bool, default True
         Write DataFrame index as a column.
     index_label : str or sequence, optional
@@ -681,9 +693,20 @@ def to_sql(
     `sqlite3 `__ or
     `SQLAlchemy `__
     """  # noqa:E501
+
     if if_exists not in ("fail", "replace", "append"):
         raise ValueError(f"'{if_exists}' is not valid for if_exists")
 
+    if on_row_conflict not in ("fail", "replace", "ignore"):
+        raise ValueError(f"'{on_row_conflict}' is not valid for on_row_conflict")
+
+    if if_exists != "append" and on_row_conflict in {"replace", "ignore"}:
+        # on_row_conflict is only used with append
+        raise ValueError(
+            f"on_row_conflict {on_row_conflict} can only be used with 'append' "
+            "operations"
+        )
+
     pandas_sql = pandasSQL_builder(con, schema=schema)
 
     if isinstance(frame, Series):
@@ -697,6 +720,7 @@ def to_sql(
         frame,
         name,
         if_exists=if_exists,
+        on_row_conflict=on_row_conflict,
         index=index,
         index_label=index_label,
         schema=schema,
@@ -785,6 +809,7 @@ def __init__(
         frame=None,
         index: bool | str | list[str] | None = True,
         if_exists: str = "fail",
+        on_row_conflict: str = "fail",
         prefix: str = "pandas",
         index_label=None,
         schema=None,
@@ -798,6 +823,7 @@ def __init__(
         self.index = self._index_name(index, index_label)
         self.schema = schema
         self.if_exists = if_exists
+        self.on_row_conflict = on_row_conflict
         self.keys = keys
         self.dtype = dtype
 
@@ -838,6 +864,193 @@ def create(self) -> None:
         else:
             self._execute_create()
 
+    def _load_existing_pkeys(
+        self, primary_keys: list[str], primary_key_values: list[tuple]
+    ) -> list[tuple]:
+        """
+        Load existing primary keys from the database.
+
+        Parameters
+        ----------
+        primary_keys : list of str
+            List of primary key column names.
+        primary_key_values : list of tuple
+            Primary key values present in the incoming dataframe.
+
+        Returns
+        -------
+        list of tuple
+            Primary key values in the incoming dataframe which already exist
+            in the database.
+        """
+        from sqlalchemy import (
+            and_,
+            select,
+        )
+
+        cols_to_fetch = [self.table.c[key] for key in primary_keys]
+
+        select_stmt = select(*cols_to_fetch).where(
+            and_(
+                *(
+                    col.in_([key[i] for key in primary_key_values])
+                    for i, col in enumerate(cols_to_fetch)
+                )
+            )
+        )
+        return self.pd_sql.execute(select_stmt).fetchall()
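
Standalone, the statement built by `_load_existing_pkeys` looks like this (hypothetical table mirroring the test fixtures). Note that for a composite key the per-column `IN` filters select a superset of the incoming tuples, e.g. a database row `(1, "name2")` matches even if only `(1, "name1")` and `(2, "name2")` are incoming, so the exact tuple match has to happen afterwards in `_split_incoming_data`:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, and_, select

metadata = MetaData()
tbl = Table(
    "pkey_table_comp",
    metadata,
    Column("a", Integer, primary_key=True),
    Column("b", String(200), primary_key=True),
    Column("c", String(200)),
)

incoming = [(1, "name1"), (2, "name2")]  # pkey tuples from the dataframe
cols = [tbl.c["a"], tbl.c["b"]]
stmt = select(*cols).where(
    and_(*(col.in_([row[i] for row in incoming]) for i, col in enumerate(cols)))
)
print(stmt)
# SELECT pkey_table_comp.a, pkey_table_comp.b
# FROM pkey_table_comp
# WHERE pkey_table_comp.a IN (...) AND pkey_table_comp.b IN (...)
# (parameter placeholders elided)
```
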
+
+    def _split_incoming_data(
+        self, primary_keys: list[str], keys_in_db: list[tuple]
+    ) -> tuple[DataFrame, DataFrame]:
+        """
+        Split incoming dataframe based on whether primary key already exists in db.
+
+        Parameters
+        ----------
+        primary_keys : list of str
+            Primary key columns.
+        keys_in_db : list of tuple
+            Primary key values which already exist in the database table.
+
+        Returns
+        -------
+        tuple of DataFrame, DataFrame
+            DataFrame of rows with duplicate pkey, DataFrame of rows with new pkey.
+        """
+        from pandas.core.indexes.multi import MultiIndex
+
+        in_db = _wrap_result(data=keys_in_db, columns=primary_keys)
+        # Get temporary dataframe so as not to delete values from main df
+        temp = self._get_index_formatted_dataframe()
+        # Create multi-indexes for membership lookup
+        in_db_idx = MultiIndex.from_arrays([in_db[col] for col in primary_keys])
+        tmp_idx = MultiIndex.from_arrays([temp[col] for col in primary_keys])
+        exists_mask = tmp_idx.isin(in_db_idx)
+        return temp.loc[exists_mask], temp.loc[~exists_mask]
+
+    def _generate_update_statements(
+        self,
+        primary_keys: list[str],
+        keys_in_db: list[tuple],
+        rows_to_update: DataFrame,
+    ) -> list[sqlalchemy.sql.dml.Update]:
+        """
+        Generate SQL UPDATE statements for rows with existing primary keys.
+
+        SQL UPDATE statements do not support a multi-statement query, so this
+        method returns a list of individual update queries which need to be
+        executed in one transaction.
+
+        Parameters
+        ----------
+        primary_keys : list of str
+            Primary key columns.
+        keys_in_db : list of tuple
+            Primary key values which already exist in the database table.
+        rows_to_update : DataFrame
+            DataFrame of rows containing data with which to update existing pkeys.
+
+        Returns
+        -------
+        list of sqlalchemy.sql.dml.Update
+            List of update queries.
+        """
+        from sqlalchemy import and_
+
+        new_records = rows_to_update.to_dict(orient="records")
+        pk_cols = [self.table.c[key] for key in primary_keys]
+
+        # TODO: Move this or remove entirely
+        assert len(new_records) == len(
+            keys_in_db
+        ), "Mismatch between new records and existing keys"
+        stmts = []
+        for record in new_records:
+            # match each row on its own primary key values, so the pairing
+            # does not depend on the order in which the database returned keys
+            stmt = (
+                self.table.update()
+                .where(and_(*(col == record[col.name] for col in pk_cols)))
+                .values(record)
+            )
+            stmts.append(stmt)
+        return stmts
+
+    def _on_row_conflict_replace(
+        self,
+    ) -> tuple[DataFrame, list[sqlalchemy.sql.dml.Update]]:
+        """
+        Generate update statements for rows with clashing primary key from database.
+
+        ``on_row_conflict='replace'`` prioritizes incoming data over existing
+        data in the DB. This method splits the incoming dataframe between rows
+        with new and existing primary key values. For existing values, UPDATE
+        statements are generated, while new values are passed on to be
+        inserted as usual.
+
+        Updates are executed in the same transaction as the ensuing data insert.
+
+        Returns
+        -------
+        tuple of DataFrame and list of sqlalchemy.sql.dml.Update
+            DataFrame of rows with new pkey, list of update queries.
+        """
+        # Primary key data
+        pk_cols, pk_values = self._get_primary_key_data()
+        existing_keys = self._load_existing_pkeys(pk_cols, pk_values)
+        existing_data, new_data = self._split_incoming_data(pk_cols, existing_keys)
+        update_stmts = self._generate_update_statements(
+            pk_cols, existing_keys, existing_data
+        )
+        return new_data, update_stmts
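
The split-then-update flow can be exercised on a toy frame without a live database. A sketch, with invented values, of what `_split_incoming_data` and `_generate_update_statements` produce:

```python
import pandas as pd
from sqlalchemy import Column, Integer, MetaData, String, Table, and_

metadata = MetaData()
tbl = Table(
    "items",
    metadata,
    Column("a", Integer, primary_key=True),
    Column("c", String(200)),
)

incoming = pd.DataFrame({"a": [1, 2, 4], "c": ["new1", "new2", "val4"]})
keys_in_db = [(1,), (2,)]  # as returned by the existing-keys SELECT

# membership test on the key tuples, as in _split_incoming_data
in_db_idx = pd.MultiIndex.from_tuples(keys_in_db, names=["a"])
incoming_idx = pd.MultiIndex.from_arrays([incoming["a"]])
mask = incoming_idx.isin(in_db_idx)
to_update, to_insert = incoming.loc[mask], incoming.loc[~mask]

# one UPDATE per clashing row, keyed on the row's own primary key values
stmts = [
    tbl.update().where(and_(tbl.c.a == rec["a"])).values(rec)
    for rec in to_update.to_dict(orient="records")
]
print(len(stmts), len(to_insert))  # 2 1
```
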
+
+    def _on_row_conflict_ignore(self) -> DataFrame:
+        """
+        Split incoming dataframe so that only rows with new primary keys are inserted.
+
+        ``on_row_conflict`` set to 'ignore' prioritizes existing data in the
+        DB. This method identifies incoming records in the primary key columns
+        which correspond to existing primary key constraints in the db table,
+        and filters them out before insertion.
+
+        Returns
+        -------
+        DataFrame
+            DataFrame of rows with new pkey.
+        """
+        pk_cols, pk_values = self._get_primary_key_data()
+        existing_keys = self._load_existing_pkeys(pk_cols, pk_values)
+        _, new_data = self._split_incoming_data(pk_cols, existing_keys)
+        return new_data
+
+    def _get_primary_key_data(self) -> tuple[list[str], list[tuple]]:
+        """
+        Get primary keys from the database and the matching dataframe values.
+
+        Upsert workflows require knowledge of what is already in the database.
+        This method reflects the meta object and gets a list of primary keys;
+        it then returns all values from the incoming dataframe columns whose
+        names match these keys.
+
+        Returns
+        -------
+        primary_keys : list of str
+            Primary key names.
+        primary_key_values : list of tuple
+            DataFrame rows, for the columns corresponding to `primary_keys`.
+        """
+        # reflect MetaData object and assign contents of db to self.table attribute
+        bind = None
+        if not self.pd_sql.meta.is_bound():
+            bind = self.pd_sql.connectable
+        self.pd_sql.meta.reflect(bind=bind, only=[self.name], views=True)
+        self.table = self.pd_sql.get_table(table_name=self.name, schema=self.schema)
+
+        primary_keys = self.table.primary_key.columns.keys()
+
+        # For the time being, this method is defensive and will break if
+        # no pkeys are found. If desired, this default behaviour could be
+        # changed so that in cases where no pkeys are found
+        # it falls back to a normal insert
+        if len(primary_keys) == 0:
+            raise ValueError(f"No primary keys found for table {self.name}")
+
+        temp = self._get_index_formatted_dataframe()
+        primary_key_values = list(zip(*(temp[key] for key in primary_keys)))
+        return primary_keys, primary_key_values
+
     def _execute_insert(self, conn, keys: list[str], data_iter) -> int:
         """
         Execute SQL statement inserting data
@@ -870,6 +1083,30 @@ def _execute_insert_multi(self, conn, keys: list[str], data_iter) -> int:
         result = conn.execute(stmt)
         return result.rowcount
 
+    def _get_index_formatted_dataframe(self) -> DataFrame:
+        """
+        Format the index of the incoming dataframe to align with a database table.
+
+        Copy the original dataframe and check whether the dataframe index is
+        to be added to the database table. If it is, reset the index so that
+        it becomes a normal column; otherwise return the copy unchanged.
+
+        Returns
+        -------
+        DataFrame
+        """
+        # Originally this functionality formed the first step of the
+        # insert_data method. It is useful in other places too, so it lives
+        # here to keep the code DRY.
+        temp = self.frame.copy()
+        if self.index is not None:
+            temp.index.names = self.index
+            try:
+                temp.reset_index(inplace=True)
+            except ValueError as err:
+                raise ValueError(f"duplicate name in index/columns: {err}") from err
+
+        return temp
+
-    def insert_data(self) -> tuple[list[str], list[np.ndarray]]:
-        if self.index is not None:
-            temp = self.frame.copy()
-            temp.index.names = self.index
-            try:
-                temp.reset_index(inplace=True)
-            except ValueError as err:
-                raise ValueError(f"duplicate name in index/columns: {err}") from err
-        else:
-            temp = self.frame
+    def insert_data(
+        self, data: DataFrame | None = None
+    ) -> tuple[list[str], list[np.ndarray]]:
+        # the on_row_conflict paths pass in a pre-filtered, index-formatted
+        # frame; otherwise format self.frame here as before
+        temp = data if data is not None else self._get_index_formatted_dataframe()
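
The reflection step in `_get_primary_key_data` can be tried in isolation. A sketch with an invented table showing how the primary-key column names are discovered:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

engine = create_engine("sqlite://")
metadata = MetaData()
Table(
    "items",
    metadata,
    Column("a", Integer, primary_key=True),
    Column("b", String(200), primary_key=True),
    Column("c", String(200)),
).create(engine)

# load the live table definition back from the database, as
# self.pd_sql.meta.reflect(only=[self.name]) does inside pandas
reflected = MetaData()
reflected.reflect(bind=engine, only=["items"])
tbl = reflected.tables["items"]
print(tbl.primary_key.columns.keys())  # ['a', 'b']
```
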
+ """ + if self.on_row_conflict == "replace": + new_data, update_stmts = self._on_row_conflict_replace() + return self._insert( + data=new_data, + chunksize=chunksize, + method=method, + other_stmts=update_stmts, + ) + elif self.on_row_conflict == "ignore": + new_data = self._on_row_conflict_ignore() + return self._insert(data=new_data, chunksize=chunksize, method=method) + else: + return self._insert(chunksize=chunksize, method=method) + def _insert(self, data=None, chunksize=None, method=None, other_stmts=[]): # set insert method if method is None: exec_insert = self._execute_insert @@ -922,9 +1176,12 @@ def insert( else: raise ValueError(f"Invalid parameter `method`: {method}") - keys, data_list = self.insert_data() + if data is None: + data = self._get_index_formatted_dataframe() - nrows = len(self.frame) + keys, data_list = self.insert_data(data=data) + + nrows = len(data) if nrows == 0: return 0 @@ -937,6 +1194,13 @@ def insert( chunks = (nrows // chunksize) + 1 total_inserted = None with self.pd_sql.run_transaction() as conn: + if len(other_stmts) > 0: + rows_executed = 0 + for stmt in other_stmts: + result = conn.execute(stmt) + rows_executed += result.rowcount + total_inserted = rows_executed + for i in range(chunks): start_i = i * chunksize end_i = min((i + 1) * chunksize, nrows) @@ -1270,6 +1534,7 @@ def to_sql( frame, name, if_exists: str = "fail", + on_row_conflict: str = "fail", index: bool = True, index_label=None, schema=None, @@ -1590,6 +1855,7 @@ def prep_table( frame, name, if_exists="fail", + on_row_conflict="fail", index=True, index_label=None, schema=None, @@ -1625,6 +1891,7 @@ def prep_table( frame=frame, index=index, if_exists=if_exists, + on_row_conflict=on_row_conflict, index_label=index_label, schema=schema, dtype=dtype, @@ -1667,6 +1934,7 @@ def to_sql( frame, name, if_exists: str = "fail", + on_row_conflict: str = "fail", index: bool = True, index_label=None, schema=None, @@ -1688,6 +1956,13 @@ def to_sql( - fail: If table exists, do nothing. - replace: If table exists, drop it, recreate it, and insert data. - append: If table exists, insert data. Create if does not exist. + on_row_conflict : {'fail', 'ignore', 'replace'}, default 'fail' + Determine insertion behavior in case of a primary key clash. + - fail: Do nothing to handle primary key clashes, will raise an Error. + - ignore: Ignore incoming rows with primary key clashes, and + insert only the incoming rows with non-conflicting primary keys + - replace: Update existing rows in database with primary key clashes, + and append the remaining rows with non-conflicting primary keys index : boolean, default True Write DataFrame index as a column. index_label : string or sequence, default None @@ -1730,6 +2005,7 @@ def to_sql( frame=frame, name=name, if_exists=if_exists, + on_row_conflict=on_row_conflict, index=index, index_label=index_label, schema=schema, @@ -2114,6 +2390,7 @@ def to_sql( frame, name, if_exists: str = "fail", + on_row_conflict: str = "fail", index: bool = True, index_label=None, schema=None, @@ -2134,6 +2411,13 @@ def to_sql( fail: If table exists, do nothing. replace: If table exists, drop it, recreate it, and insert data. append: If table exists, insert data. Create if it does not exist. + on_row_conflict : {'fail', 'ignore', 'replace'}, default 'fail' + Determine insertion behavior in case of a primary key clash. + - fail: Do nothing to handle primary key clashes, will raise an Error. 
@@ -1270,6 +1534,7 @@ def to_sql(
         frame,
         name,
         if_exists: str = "fail",
+        on_row_conflict: str = "fail",
         index: bool = True,
         index_label=None,
         schema=None,
@@ -1590,6 +1855,7 @@ def prep_table(
         frame,
         name,
         if_exists="fail",
+        on_row_conflict="fail",
         index=True,
         index_label=None,
         schema=None,
@@ -1625,6 +1891,7 @@ def prep_table(
             frame=frame,
             index=index,
             if_exists=if_exists,
+            on_row_conflict=on_row_conflict,
             index_label=index_label,
             schema=schema,
             dtype=dtype,
@@ -1667,6 +1934,7 @@ def to_sql(
         frame,
         name,
         if_exists: str = "fail",
+        on_row_conflict: str = "fail",
         index: bool = True,
         index_label=None,
         schema=None,
@@ -1688,6 +1956,13 @@ def to_sql(
             - fail: If table exists, do nothing.
             - replace: If table exists, drop it, recreate it, and insert data.
             - append: If table exists, insert data. Create if does not exist.
+        on_row_conflict : {'fail', 'ignore', 'replace'}, default 'fail'
+            Determine insertion behavior in case of a primary key clash.
+            - fail: Do nothing to handle primary key clashes; the database
+              will raise an error.
+            - ignore: Discard incoming rows with primary key clashes, and
+              insert only the incoming rows with non-conflicting primary keys.
+            - replace: Update existing rows in the database that have primary
+              key clashes, and append the remaining rows with non-conflicting
+              primary keys.
         index : boolean, default True
             Write DataFrame index as a column.
         index_label : string or sequence, default None
@@ -1730,6 +2005,7 @@ def to_sql(
             frame=frame,
             name=name,
             if_exists=if_exists,
+            on_row_conflict=on_row_conflict,
             index=index,
             index_label=index_label,
             schema=schema,
@@ -2114,6 +2390,7 @@ def to_sql(
         frame,
         name,
         if_exists: str = "fail",
+        on_row_conflict: str = "fail",
         index: bool = True,
         index_label=None,
         schema=None,
@@ -2134,6 +2411,13 @@ def to_sql(
             fail: If table exists, do nothing.
             replace: If table exists, drop it, recreate it, and insert data.
             append: If table exists, insert data. Create if it does not exist.
+        on_row_conflict : {'fail', 'ignore', 'replace'}, default 'fail'
+            Determine insertion behavior in case of a primary key clash.
+            - fail: Do nothing to handle primary key clashes; the database
+              will raise an error.
+            - ignore: Discard incoming rows with primary key clashes, and
+              insert only the incoming rows with non-conflicting primary keys.
+            - replace: Update existing rows in the database that have primary
+              key clashes, and append the remaining rows with non-conflicting
+              primary keys.
         index : bool, default True
             Write DataFrame index as a column
         index_label : string or sequence, default None
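
For comparison, the 'ignore' behavior can already be approximated on databases with native upsert support by passing a callable to the existing `method` argument. A sketch for SQLite, with an invented `items` table whose primary key column `a` must already exist for the conflict target to resolve:

```python
import pandas as pd
from sqlalchemy import create_engine, text


def insert_on_conflict_nothing(table, conn, keys, data_iter):
    # dialect-specific INSERT ... ON CONFLICT ("a") DO NOTHING
    from sqlalchemy.dialects.sqlite import insert

    data = [dict(zip(keys, row)) for row in data_iter]
    stmt = (
        insert(table.table).values(data).on_conflict_do_nothing(index_elements=["a"])
    )
    return conn.execute(stmt).rowcount


engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.execute(text('CREATE TABLE items ("a" INTEGER PRIMARY KEY, "b" TEXT)'))
    conn.execute(text("INSERT INTO items VALUES (1, 'keep me')"))

df = pd.DataFrame({"a": [1, 2], "b": ["clash", "fresh"]})
df.to_sql(
    "items", engine, if_exists="append", index=False, method=insert_on_conflict_nothing
)
```
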
f"""SELECT c FROM {tbl_name} WHERE A IN (?, ?)""" + cur = conn.cursor() + result = cur.execute(stmt, duplicate_keys) + else: + from sqlalchemy import ( + MetaData, + Table, + select, + ) + from sqlalchemy.engine import Engine + + meta = MetaData() + tbl = Table(tbl_name, meta, autoload_with=conn) + stmt = select([tbl.c.c]).where(tbl.c.a.in_(duplicate_keys)) + + if isinstance(conn, Engine): + with conn.connect() as conn: + result = conn.execute(stmt) + else: + result = conn.execute(stmt) + return sorted(val[0] for val in result.fetchall()) + + @pytest.fixture def iris_path(datapath): iris_path = datapath("io", "data", "csv", "iris.csv") @@ -376,6 +507,18 @@ def test_frame3(): return DataFrame(data, columns=columns) +@pytest.fixture +def pkey_frame(): + columns = ["a", "b", "c"] + data = [ + (1, "name1", "new_val1"), + (2, "name2", "new_val2"), + (4, "name4", "val4"), + (5, "name5", "val5"), + ] + return DataFrame(data, columns=columns) + + @pytest.fixture def mysql_pymysql_engine(iris_path, types_data): sqlalchemy = pytest.importorskip("sqlalchemy") @@ -725,6 +868,28 @@ def load_types_data(self, types_data): else: create_and_load_types(self.conn, types_data, self.flavor) + @pytest.fixture + def load_pkey_data(self): + if not hasattr(self, "conn"): + self.setup_connect() + self.drop_table("pkey_table_single") + self.drop_table("pkey_table_comp") + if isinstance(self.conn, sqlite3.Connection): + create_and_load_pkey_sqlite3(self.conn) + else: + create_and_load_pkey(self.conn) + + def _check_iris_loaded_frame(self, iris_frame): + pytype = iris_frame.dtypes[0].type + row = iris_frame.iloc[0] + + assert issubclass(pytype, np.floating) + tm.equalContents(row.values, [5.1, 3.5, 1.4, 0.2, "Iris-setosa"]) + + def _read_sql_iris(self): + iris_frame = self.pandasSQL.read_query("SELECT * FROM iris") + self._check_iris_loaded_frame(iris_frame) + def _read_sql_iris_parameter(self): query = SQL_STRINGS["read_parameters"][self.flavor] params = ["Iris-setosa", 5.1] @@ -746,6 +911,142 @@ def _to_sql_empty(self, test_frame1): self.drop_table("test_frame1") assert self.pandasSQL.to_sql(test_frame1.iloc[:0], "test_frame1") == 0 + def _to_sql_on_row_conflict_replace(self, method, tbl_name, pkey_frame): + """ + GIVEN: + - Original database table: 3 rows + - new dataframe: 4 rows (2 duplicate keys) + WHEN: + - on row conflict replace + THEN: + - DB table len = 5 + - Conflicting primary keys in DB updated + """ + assert self.pandasSQL.has_table(tbl_name) + assert count_rows(self.conn, tbl_name) == 3 + assert len(pkey_frame.index) == 4 + assert ( + self.pandasSQL.to_sql( + pkey_frame, + tbl_name, + if_exists="append", + on_row_conflict="replace", + index=False, + method=method, + ) + == 4 + ) + assert count_rows(self.conn, tbl_name) == 5 + data_from_db = read_pkeys_from_database(self.conn, tbl_name, [1, 2]) + expected = sorted(["new_val1", "new_val2"]) + assert data_from_db == expected + self.drop_table(tbl_name) + + def _to_sql_on_row_conflict_ignore(self, method, tbl_name, pkey_frame): + """ + GIVEN: + - Original table: 3 rows + - new dataframe: 4 rows (2 duplicate keys) + WHEN: + - on row conflict: ignore + THEN: + - database table len = 5 + - conflicting keys in table not updated + """ + assert self.pandasSQL.has_table(tbl_name) + assert count_rows(self.conn, tbl_name) == 3 + duplicate_keys = [1, 2] + data_from_db_before = read_pkeys_from_database( + self.conn, tbl_name, duplicate_keys + ) + assert ( + self.pandasSQL.to_sql( + pkey_frame, + tbl_name, + if_exists="append", + on_row_conflict="ignore", + 
+
+    def _to_sql_on_row_conflict_with_index(self, method, tbl_name, pkey_frame):
+        """
+        GIVEN:
+        - Original database table: 3 rows
+        - New dataframe: 4 rows (2 duplicate keys), pkey as index
+        WHEN:
+        - inserting new data, writing the index column
+        - on row conflict: replace
+        THEN:
+        - DB table len = 5
+        - Conflicting primary keys in DB updated
+        """
+        assert self.pandasSQL.has_table(tbl_name)
+        assert count_rows(self.conn, tbl_name) == 3
+        if tbl_name == "pkey_table_single":
+            index_pkey_table = pkey_frame.set_index("a")
+        else:
+            index_pkey_table = pkey_frame.set_index(["a", "b"])
+        assert (
+            self.pandasSQL.to_sql(
+                index_pkey_table,
+                tbl_name,
+                if_exists="append",
+                on_row_conflict="replace",
+                index=True,
+                method=method,
+            )
+            == 4
+        )
+        assert count_rows(self.conn, tbl_name) == 5
+        data_from_db = read_pkeys_from_database(self.conn, tbl_name, [1, 2])
+        expected = sorted(["new_val1", "new_val2"])
+        assert data_from_db == expected
+        assert len(pkey_frame.index) == 4
+        self.drop_table(tbl_name)
+
+    def _to_sql_on_row_conflict_with_non_append(
+        self, if_exists, on_row_conflict, pkey_frame
+    ):
+        """
+        GIVEN:
+        - to_sql is called
+        WHEN:
+        - `on_row_conflict` is not 'fail'
+        - `if_exists` is set to a value other than 'append'
+        THEN:
+        - ValueError is raised
+        """
+        assert if_exists != "append"
+        with pytest.raises(
+            ValueError,
+            match=(
+                f"on_row_conflict {on_row_conflict} can only be used with 'append' "
+                "operations"
+            ),
+        ):
+            # Insert new dataframe
+            sql.to_sql(
+                pkey_frame,
+                "some_table",
+                con=self.conn,
+                if_exists=if_exists,
+                on_row_conflict=on_row_conflict,
+                index=False,
+            )
+
     def _to_sql_with_sql_engine(self, test_frame1, engine="auto", **engine_kwargs):
         """`to_sql` with the `engine` param"""
         # mostly copied from this class's `_to_sql()` method
@@ -857,7 +1158,7 @@ def setup_connect(self):
         self.conn = self.connect()
 
     @pytest.fixture(autouse=True)
-    def setup_method(self, load_iris_data, load_types_data):
+    def setup_method(self, load_iris_data, load_types_data, load_pkey_data):
         self.load_test_data_and_sql()
 
     def load_test_data_and_sql(self):
@@ -922,6 +1223,28 @@ def test_to_sql_series(self):
         s2 = sql.read_sql_query("SELECT * FROM test_series", self.conn)
         tm.assert_frame_equal(s.to_frame(), s2)
 
+    def test_to_sql_invalid_on_row_conflict(self, pkey_frame):
+        msg = "'notvalidvalue' is not valid for on_row_conflict"
+        with pytest.raises(ValueError, match=msg):
+            sql.to_sql(
+                pkey_frame,
+                "pkey_frame1",
+                self.conn,
+                if_exists="append",
+                on_row_conflict="notvalidvalue",
+            )
+
+    def test_to_sql_on_row_conflict_non_append(self, pkey_frame):
+        msg = "on_row_conflict replace can only be used with 'append' operations"
+        with pytest.raises(ValueError, match=msg):
+            sql.to_sql(
+                pkey_frame,
+                "pkey_frame1",
+                self.conn,
+                if_exists="replace",
+                on_row_conflict="replace",
+            )
+
     def test_roundtrip(self, test_frame1):
         sql.to_sql(test_frame1, "test_frame_roundtrip", con=self.conn)
         result = sql.read_sql_query("SELECT * FROM test_frame_roundtrip", con=self.conn)
@@ -1486,7 +1809,7 @@ class _EngineToConnMixin:
     """
 
     @pytest.fixture(autouse=True)
-    def setup_method(self, load_iris_data, load_types_data):
+    def setup_method(self, load_iris_data, load_types_data, load_pkey_data):
         super().load_test_data_and_sql()
         engine = self.conn
         conn = engine.connect()
@@ -1616,7 +1939,7 @@ def load_test_data_and_sql(self):
         pass
 
     @pytest.fixture(autouse=True)
-    def setup_method(self, load_iris_data, load_types_data):
+    def setup_method(self, load_iris_data, load_types_data, load_pkey_data):
         pass
 
     @classmethod
@@ -1651,6 +1974,30 @@ def test_read_sql_named_parameter(self):
     def test_to_sql_empty(self, test_frame1):
         self._to_sql_empty(test_frame1)
 
+    @pytest.mark.parametrize("method", [None, "multi"])
+    @pytest.mark.parametrize("tbl_name", ["pkey_table_single", "pkey_table_comp"])
+    def test_to_sql_on_row_conflict_ignore(self, method, tbl_name, pkey_frame):
+        self._to_sql_on_row_conflict_ignore(method, tbl_name, pkey_frame)
+
+    @pytest.mark.parametrize("method", [None, "multi"])
+    @pytest.mark.parametrize("tbl_name", ["pkey_table_single", "pkey_table_comp"])
+    def test_to_sql_on_row_conflict_replace(self, method, tbl_name, pkey_frame):
+        self._to_sql_on_row_conflict_replace(method, tbl_name, pkey_frame)
+
+    @pytest.mark.parametrize("method", [None, "multi"])
+    @pytest.mark.parametrize("tbl_name", ["pkey_table_single", "pkey_table_comp"])
+    def test_to_sql_on_row_conflict_with_index(self, method, tbl_name, pkey_frame):
+        self._to_sql_on_row_conflict_with_index(method, tbl_name, pkey_frame)
+
+    @pytest.mark.parametrize("if_exists", ["fail", "replace"])
+    @pytest.mark.parametrize("on_row_conflict", ["ignore", "replace"])
+    def test_to_sql_on_row_conflict_with_non_append(
+        self, if_exists, on_row_conflict, pkey_frame
+    ):
+        self._to_sql_on_row_conflict_with_non_append(
+            if_exists, on_row_conflict, pkey_frame
+        )
+
     def test_create_table(self):
         from sqlalchemy import inspect
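
Finally, the index-based variant exercised by `_to_sql_on_row_conflict_with_index`, sketched against the public API with a composite key (runnable once this branch is installed; table and data invented):

```python
import pandas as pd
from sqlalchemy import create_engine, text

engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.execute(
        text('CREATE TABLE t ("a" INTEGER, "b" TEXT, "c" TEXT, PRIMARY KEY ("a", "b"))')
    )
    conn.execute(text("INSERT INTO t VALUES (1, 'name1', 'val1'), (2, 'name2', 'val2')"))

# the composite key lives on the index; index=True writes it back as columns
df = pd.DataFrame(
    {"a": [1, 4], "b": ["name1", "name4"], "c": ["new_val1", "val4"]}
).set_index(["a", "b"])

df.to_sql("t", engine, if_exists="append", on_row_conflict="replace", index=True)
print(pd.read_sql("SELECT * FROM t ORDER BY a", engine))
#    a      b         c
# 0  1  name1  new_val1
# 1  2  name2      val2
# 2  4  name4      val4
```
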