From 09f377e367a0f0f1d4cf1934e4efad761c4b9a60 Mon Sep 17 00:00:00 2001 From: danielballan Date: Thu, 17 Apr 2014 09:44:21 -0400 Subject: [PATCH] API: Stop modifying SQL column and names, and warn when pertinent. --- pandas/io/sql.py | 34 ++++++++++++++++++---------------- pandas/io/tests/test_sql.py | 14 ++++++++++++++ 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index bed4c2da61c59..158ef7b7ed791 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -7,6 +7,7 @@ import warnings import itertools +import re import numpy as np import pandas.core.common as com @@ -38,11 +39,6 @@ def _convert_params(sql, params): return args -def _safe_col_name(col_name): - #TODO: probably want to forbid database reserved names, such as "database" - return col_name.strip().replace(' ', '_') - - def _handle_date_column(col, format=None): if isinstance(format, dict): return to_datetime(col, **format) @@ -587,11 +583,11 @@ def _index_name(self, index, index_label): def _create_table_statement(self): from sqlalchemy import Table, Column - safe_columns = map(_safe_col_name, self.frame.dtypes.index) + columns = list(map(str, self.frame.columns)) column_types = map(self._sqlalchemy_type, self.frame.dtypes) columns = [Column(name, typ) - for name, typ in zip(safe_columns, column_types)] + for name, typ in zip(columns, column_types)] if self.index is not None: for i, idx_label in enumerate(self.index[::-1]): @@ -836,6 +832,11 @@ def _create_sql_schema(self, frame, table_name): } +_SAFE_NAMES_WARNING = ("The spaces in these column names will not be changed. " + "In pandas versions < 0.14, spaces were converted to " + "underscores.") + + class PandasSQLTableLegacy(PandasSQLTable): """Patch the PandasSQLTable for legacy support. Instead of a table variable just use the Create Table @@ -847,19 +848,18 @@ def create(self): self.pd_sql.execute(self.table) def insert_statement(self): - # Replace spaces in DataFrame column names with _. 
- safe_names = [_safe_col_name(n) for n in self.frame.dtypes.index] + names = list(map(str, self.frame.columns)) flv = self.pd_sql.flavor br_l = _SQL_SYMB[flv]['br_l'] # left val quote char br_r = _SQL_SYMB[flv]['br_r'] # right val quote char wld = _SQL_SYMB[flv]['wld'] # wildcard char if self.index is not None: - [safe_names.insert(0, idx) for idx in self.index[::-1]] + [names.insert(0, idx) for idx in self.index[::-1]] - bracketed_names = [br_l + column + br_r for column in safe_names] + bracketed_names = [br_l + column + br_r for column in names] col_names = ','.join(bracketed_names) - wildcards = ','.join([wld] * len(safe_names)) + wildcards = ','.join([wld] * len(names)) insert_statement = 'INSERT INTO %s (%s) VALUES (%s)' % ( self.name, col_names, wildcards) return insert_statement @@ -881,13 +881,15 @@ def insert(self): def _create_table_statement(self): "Return a CREATE TABLE statement to suit the contents of a DataFrame." - # Replace spaces in DataFrame column names with _. - safe_columns = [_safe_col_name(n) for n in self.frame.dtypes.index] + columns = list(map(str, self.frame.columns)) + pat = re.compile('\s+') + if any(map(pat.search, columns)): + warnings.warn(_SAFE_NAMES_WARNING) column_types = [self._sql_type_name(typ) for typ in self.frame.dtypes] if self.index is not None: for i, idx_label in enumerate(self.index[::-1]): - safe_columns.insert(0, idx_label) + columns.insert(0, idx_label) column_types.insert(0, self._sql_type_name(self.frame.index.get_level_values(i).dtype)) flv = self.pd_sql.flavor @@ -898,7 +900,7 @@ def _create_table_statement(self): col_template = br_l + '%s' + br_r + ' %s' columns = ',\n '.join(col_template % - x for x in zip(safe_columns, column_types)) + x for x in zip(columns, column_types)) template = """CREATE TABLE %(name)s ( %(columns)s )""" diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index 9622f9d8790cb..ad3fa57ab48a7 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py 
@@ -550,6 +550,11 @@ def test_to_sql_index_label_multiindex(self): 'test_index_label', self.conn, if_exists='replace', index_label='C') + def test_integer_col_names(self): + df = DataFrame([[1, 2], [3, 4]], columns=[0, 1]) + sql.to_sql(df, "test_frame_integer_col_names", self.conn, + if_exists='replace') + class TestSQLApi(_TestSQLApi): """ @@ -661,10 +666,19 @@ def test_read_sql_delegate(self): self.assertRaises(ValueError, sql.read_sql, 'iris', self.conn, flavor=self.flavor) + def test_safe_names_warning(self): + # GH 6798 + df = DataFrame([[1, 2], [3, 4]], columns=['a', 'b ']) # has a space + # warns on create table with spaces in names + with tm.assert_produces_warning(): + sql.to_sql(df, "test_frame3_legacy", self.conn, + flavor="sqlite", index=False) + #------------------------------------------------------------------------------ #--- Database flavor specific tests + class _TestSQLAlchemy(PandasSQLTest): """ Base class for testing the sqlalchemy backend.