diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index eca51ed5..7adc8835 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -110,6 +110,10 @@ class object, used to create cursors (keyword only) :param int client_flag: flags to use or 0 (see MySQL docs or constants/CLIENTS.py) + :param bool multi_statements: + If True, enable multi statements for clients >= 4.1. + Defaults to True. + :param str ssl_mode: specify the security settings for connection to the server; see the MySQL documentation for more details @@ -169,11 +173,16 @@ class object, used to create cursors (keyword only) self._binary_prefix = kwargs2.pop("binary_prefix", False) client_flag = kwargs.get("client_flag", 0) + client_version = tuple( [numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]] ) - if client_version >= (4, 1): - client_flag |= CLIENT.MULTI_STATEMENTS + + multi_statements = kwargs2.pop("multi_statements", True) + if multi_statements: + if client_version >= (4, 1): + client_flag |= CLIENT.MULTI_STATEMENTS + if client_version >= (5, 0): client_flag |= CLIENT.MULTI_RESULTS diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 00000000..960de572 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,26 @@ +import pytest + +from MySQLdb._exceptions import ProgrammingError + +from configdb import connection_factory + + +def test_multi_statements_default_true(): + conn = connection_factory() + cursor = conn.cursor() + + cursor.execute("select 17; select 2") + rows = cursor.fetchall() + assert rows == ((17,),) + + +def test_multi_statements_false(): + conn = connection_factory(multi_statements=False) + cursor = conn.cursor() + + with pytest.raises(ProgrammingError): + cursor.execute("select 17; select 2") + + cursor.execute("select 17") + rows = cursor.fetchall() + assert rows == ((17,),)