Skip to content

Commit 99b0057

Browse files
committed
string: Change string.quote_ident to always quote, add quote_ident_if_needed
string.quote_ident() and string.qname() will now always returned a quoted string. Introduce string.quote_ident_if_needed() and string.qname_if_needed() for old behavior quoting only when unsafe characters are present.
1 parent 61c5a9a commit 99b0057

3 files changed

Lines changed: 39 additions & 17 deletions

File tree

postgresql/string.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,19 @@ def escape_ident(text):
2727
'Replace every instance of " with ""'
2828
return text.replace('"', '""')
2929

30+
def needs_quoting(text):
31+
return not (text and not text[0].isdecimal() and text.replace('_', 'a').isalnum())
32+
3033
def quote_ident(text):
34+
"Replace every instance of '"' with '""' *and* place '"' on each end"
35+
return '"' + text.replace('"', '""') + '"'
36+
37+
def quote_ident_if_needed(text):
3138
"""
32-
If needed, replace every instance of '"' with '""' *and* place '"' on each
33-
end. Otherwise, just return the text.
39+
If needed, replace every instance of '"' with '""' *and* place '"' on each end.
40+
Otherwise, just return the text.
3441
"""
35-
if not text or text[0].isdecimal() or (
36-
not text.replace('_', 'a').isalnum()
37-
):
38-
return '"' + text.replace('"', '""') + '"'
39-
return text
42+
return quote_ident(text) if needs_quoting(text) else text
4043

4144
quote_re = re.compile(r"""(?xu)
4245
E'(?:''|\\.|[^'])*(?:'|$) (?# Backslash escapes E'str')
@@ -214,6 +217,9 @@ def qname(*args):
214217
"Quote the identifiers and join them using '.'"
215218
return '.'.join([quote_ident(x) for x in args])
216219

220+
def qname_if_needed(*args):
221+
return '.'.join([quote_ident_if_needed(x) for x in args])
222+
217223
def split_sql(sql, sep = ';'):
218224
"""
219225
Given SQL, safely split using the given separator.

postgresql/test/test_driver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -520,8 +520,8 @@ def testStatementAndCursorMetadata(self):
520520
myudt_oid = db.prepare("select oid from pg_type WHERE typname='myudt'").first()
521521
ps = db.prepare("SELECT $1::text AS my_column1, $2::varchar AS my_column2, $3::public.myudt AS my_column3")
522522
self.assertEqual(tuple(ps.column_names), ('my_column1','my_column2', 'my_column3'))
523-
self.assertEqual(tuple(ps.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING', 'public.myudt'))
524-
self.assertEqual(tuple(ps.sql_parameter_types), ('pg_catalog.text', 'CHARACTER VARYING', 'public.myudt'))
523+
self.assertEqual(tuple(ps.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING', '"public"."myudt"'))
524+
self.assertEqual(tuple(ps.sql_parameter_types), ('pg_catalog.text', 'CHARACTER VARYING', '"public"."myudt"'))
525525
self.assertEqual(tuple(ps.pg_column_types), (
526526
pg_types.TEXTOID, pg_types.VARCHAROID, myudt_oid)
527527
)
@@ -532,7 +532,7 @@ def testStatementAndCursorMetadata(self):
532532
self.assertEqual(tuple(ps.column_types), (str,str,tuple))
533533
c = ps.declare('textdata', 'varchardata', (123,))
534534
self.assertEqual(tuple(c.column_names), ('my_column1','my_column2', 'my_column3'))
535-
self.assertEqual(tuple(c.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING', 'public.myudt'))
535+
self.assertEqual(tuple(c.sql_column_types), ('pg_catalog.text', 'CHARACTER VARYING', '"public"."myudt"'))
536536
self.assertEqual(tuple(c.pg_column_types), (
537537
pg_types.TEXTOID, pg_types.VARCHAROID, myudt_oid
538538
))

postgresql/test/test_string.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def test_qname(self):
197197
for unsplit, split, norm in split_qname_samples:
198198
xsplit = pg_str.split_qname(unsplit)
199199
self.assertEqual(xsplit, split)
200-
self.assertEqual(pg_str.qname(*split), norm)
200+
self.assertEqual(pg_str.qname_if_needed(*split), norm)
201201

202202
self.assertRaises(
203203
ValueError,
@@ -215,6 +215,14 @@ def test_qname(self):
215215
ValueError,
216216
pg_str.split_qname, 'bar".foo"'
217217
)
218+
self.assertRaises(
219+
ValueError,
220+
pg_str.split_qname, '0bar.foo'
221+
)
222+
self.assertRaises(
223+
ValueError,
224+
pg_str.split_qname, 'bar.fo@'
225+
)
218226

219227
def test_quotes(self):
220228
self.assertEqual(
@@ -226,29 +234,37 @@ def test_quotes(self):
226234
"""'\\foo''bar\\'"""
227235
)
228236
self.assertEqual(
229-
pg_str.quote_ident("foo"),
237+
pg_str.quote_ident_if_needed("foo"),
230238
"foo"
231239
)
232240
self.assertEqual(
233-
pg_str.quote_ident("0foo"),
241+
pg_str.quote_ident_if_needed("0foo"),
234242
'"0foo"'
235243
)
236244
self.assertEqual(
237-
pg_str.quote_ident("foo0"),
245+
pg_str.quote_ident_if_needed("foo0"),
238246
'foo0'
239247
)
240248
self.assertEqual(
241-
pg_str.quote_ident("_"),
249+
pg_str.quote_ident_if_needed("_"),
242250
'_'
243251
)
244252
self.assertEqual(
245-
pg_str.quote_ident("_9"),
253+
pg_str.quote_ident_if_needed("_9"),
246254
'_9'
247255
)
248256
self.assertEqual(
249-
pg_str.quote_ident('''\\foo'bar\\'''),
257+
pg_str.quote_ident_if_needed('''\\foo'bar\\'''),
250258
'''"\\foo'bar\\"'''
251259
)
260+
self.assertEqual(
261+
pg_str.quote_ident("spam"),
262+
'"spam"'
263+
)
264+
self.assertEqual(
265+
pg_str.qname("spam", "ham"),
266+
'"spam"."ham"'
267+
)
252268
self.assertEqual(
253269
pg_str.escape_ident('"'),
254270
'""',

0 commit comments

Comments
 (0)