Skip to content

Commit 1d856f7

Browse files
author
James William Pye
committed
Correct message type comparisons.
Using the 'is' operator only works so long as the message_types data is consistent. It is possible for that consistency to fail in cases where modules are reloaded.
1 parent 422fd07 commit 1d856f7

6 files changed

Lines changed: 82 additions & 42 deletions

File tree

postgresql/documentation/changes.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@ Changes
77
* Correct types.Array binary serialization (8.5 added extra checks)
88
* Add more encoding name mapping entries.
99
* Documentation improvements.
10+
* Correct protocol message type comparison method.
11+
In 0.9.1 and before, types were compared using 'is' because all
12+
message type objects were loaded from a tuple created by the
13+
postgresql.protocol.message_types module. In some environments,
14+
sys.modules is cleared. For pure-Python installations, this is unlikely
15+
to cause a problem, but when the optimized module is available, it will hold
16+
a reference to the "old" version of the message_types data. Subsequently, the
17+
use of the 'is' operator is then broken resulting in superfluous protocol
18+
errors. The fix for 0.9.2 is to use the actual comparison operator.
19+
This does not add any significant overhead, but use of 'is' may return if
20+
some permanent storage method becomes available.
21+
[Reported by Radomir Stevanovic]
22+
1023

1124
0.9.1 released on 2009-08-12
1225
----------------------------

postgresql/driver/pq3.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434

3535
from .. import types as pg_types
3636

37-
is_showoption = lambda x: getattr(x, 'type', None) is element.ShowOption.type
38-
3937
IDNS = 'py:%s'
4038
def ID(s, title = None):
4139
'generate an id for a client statement or cursor'
@@ -321,7 +319,10 @@ def __init__(self, statement, parameters):
321319
self.database = statement.database
322320
Output.__init__(self, '')
323321

324-
def _init(self):
322+
def _init(self,
323+
null = element.Null.type,
324+
complete = element.Complete.type,
325+
):
325326
expect = self._expect
326327
self._xact = self._ins(
327328
self._pq_xp_fetchall() + (element.SynchronizeMessage,)
@@ -330,19 +331,19 @@ def _init(self):
330331
while self._xact.state != xact.Complete:
331332
self.database._pq_step()
332333
for x in self._xact.messages_received():
333-
if x.type is element.Null.type:
334+
if x.type == null:
334335
self.database._pq_complete()
335336
self._xact = None
336337
return
337-
elif x.type is element.Complete.type:
338+
elif x.type == complete:
338339
self._complete_message = x
339340
self.database._pq_complete()
340341
# If this was a select/copy cursor,
341342
# the data messages would have caused an earlier
342343
# return.
343344
self._xact = None
344345
return
345-
elif x.type is expect:
346+
elif x.type == expect:
346347
# no need to step once this is seen
347348
return
348349
elif x.type in (
@@ -387,9 +388,11 @@ class SingleXactCopy(FetchAll):
387388
class SingleXactFetch(FetchAll):
388389
_expect = element.Tuple.type
389390
_process_chunk_ = FetchAll._process_tuple_chunk_Row
390-
def _process_chunk(self, x):
391+
def _process_chunk(self, x,
392+
tuple_type = element.Tuple.type
393+
):
391394
return self._process_chunk_((
392-
y for y in x if y.type is element.Tuple.type
395+
y for y in x if y.type == tuple_type
393396
))
394397

395398
class MultiXactStream(Chunks):
@@ -433,7 +436,7 @@ def __next__(self):
433436
self.database._pq_complete()
434437

435438
chunk = [
436-
y for y in x.messages_received() if y.type is element.Tuple.type
439+
y for y in x.messages_received() if y.type == element.Tuple.type
437440
]
438441
if len(chunk) == self.chunksize:
439442
# there may be more, dispatch the request for the next chunk
@@ -496,7 +499,9 @@ def _which_way(self, direction):
496499
else:
497500
return self.direction
498501

499-
def _init(self):
502+
def _init(self,
503+
tupledesc = element.TupleDescriptor.type,
504+
):
500505
"""
501506
Based on the cursor parameters and the current transaction state,
502507
select a cursor strategy for managing the response from the server.
@@ -510,7 +515,7 @@ def _init(self):
510515
self.database._pq_push(x, self)
511516
self.database._pq_complete()
512517
for m in x.messages_received():
513-
if m.type is element.TupleDescriptor.type:
518+
if m.type == tupledesc:
514519
self._output = m
515520
self._output_attmap = \
516521
self.database.typio.attribute_map(self._output)
@@ -545,7 +550,7 @@ def _fetch(self, direction, quantity):
545550
self.database._pq_push(x, self)
546551
self.database._pq_complete()
547552
return self._process_tuple((
548-
y for y in x.messages_received() if y.type is element.Tuple.type
553+
y for y in x.messages_received() if y.type == element.Tuple.type
549554
))
550555

551556
def seek(self, offset, whence = 'ABSOLUTE'):
@@ -979,6 +984,7 @@ def first(self, *parameters):
979984
params,
980985
self._output_formats,
981986
),
987+
# Get all
982988
element.Execute(b'', 0xFFFFFFFF),
983989
element.SynchronizeMessage
984990
),
@@ -990,8 +996,9 @@ def first(self, *parameters):
990996
if self._output_io:
991997
##
992998
# Look for the first tuple.
999+
tuple_type = element.Tuple.type
9931000
for xt in x.messages_received():
994-
if xt.type is element.Tuple.type:
1001+
if xt.type == tuple_type:
9951002
break
9961003
else:
9971004
return None
@@ -1012,8 +1019,14 @@ def first(self, *parameters):
10121019
else:
10131020
##
10141021
# It doesn't return rows, so return a count.
1022+
##
1023+
# This loop searches through the received messages
1024+
# for the Complete message which contains the count.
1025+
complete = element.Complete.type
10151026
for cm in x.messages_received():
1016-
if getattr(cm, 'type', None) == element.Complete.type:
1027+
# Use getattr because COPY doesn't produce
1028+
# element.Message instances.
1029+
if getattr(cm, 'type', None) == complete:
10171030
break
10181031
else:
10191032
# Probably a Null command.
@@ -1800,7 +1813,7 @@ def connect(self):
18001813

18011814
# When ssl is None: SSL negotiation will not occur.
18021815
# When ssl is True: SSL negotiation will occur *and* it must succeed.
1803-
# When ssl is False: SSL negotiation will occur but NOSSL is okay.
1816+
# When ssl is False: SSL negotiation will occur but it may fail(NOSSL).
18041817
if sslmode == 'allow':
18051818
# first, without ssl, then with. :)
18061819
socket_factories = interlace(
@@ -1838,15 +1851,21 @@ def connect(self):
18381851
sf, self.connector._startup_parameters,
18391852
password = self.connector._password,
18401853
)
1854+
# Grab the negotiation transaction before
1855+
# connecting as it will be needed later if successful.
18411856
neg = pq.xact
18421857
pq.connect(ssl = ssl, timeout = timeout)
18431858

18441859
didssl = getattr(pq, 'ssl_negotiation', -1)
1860+
1861+
# It successfully connected if pq.xact is None.
18451862
if pq.xact is None:
18461863
self.pq = pq
1847-
for x in filter(is_showoption, neg.asyncs):
1848-
self._receive_async(x)
18491864
self.security = 'ssl' if didssl is True else None
1865+
showoption_type = element.ShowOption.type
1866+
for x in neg.asyncs:
1867+
if x.type == showoption_type:
1868+
self._receive_async(x)
18501869
# success!
18511870
break
18521871

@@ -2037,20 +2056,24 @@ def _error_lookup(self, om : element.Error,) -> pg_exc.Error:
20372056
err.database = self
20382057
return err
20392058

2040-
def _receive_async(self, msg, controller = None):
2059+
def _receive_async(self, msg, controller = None,
2060+
showoption = element.ShowOption.type,
2061+
notice = element.Notice.type,
2062+
notify = element.Notify.type,
2063+
):
20412064
c = controller or getattr(self, '_controller', self)
2042-
if msg.type is element.ShowOption.type:
2065+
if msg.type == showoption:
20432066
if msg.name == b'client_encoding':
20442067
self.typio.set_encoding(msg.value.decode('ascii'))
20452068
self.settings._notify(msg)
2046-
elif msg.type is element.Notice.type:
2069+
elif msg.type == notice:
20472070
src = 'SERVER'
20482071
if type(msg) is element.ClientNotice:
20492072
src = 'CLIENT'
20502073
m = self._convert_pq_message(msg, source = src)
20512074
m.creator = c
20522075
m.raise_message()
2053-
elif msg.type is element.Notify.type:
2076+
elif msg.type == notify:
20542077
subs = getattr(self, '_subscriptions', {})
20552078
for x in subs.get(msg.relation, ()):
20562079
x(self, msg)

postgresql/protocol/message_types.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
Data module providing a sequence of bytes objects whose value corresponds to its
77
index in the sequence.
88
9-
This provides resource for buffer objects to use common message type objects,
10-
and the additional privilege of using the `is` operator in all situations
11-
involving message type comparisons.
9+
This provides resource for buffer objects to use common message type objects.
10+
11+
WARNING: It's tempting to use the 'is' operator and in some circumstances that
12+
may be okay. However, it's possible (sys.modules.clear()) for the extension
13+
modules' copy of this to become inconsistent with what protocol.element3 and
14+
protocol.xact3 are using, so it's important to **not** use 'is'.
1215
"""
1316
message_types = tuple([bytes((x,)) for x in range(256)])

postgresql/protocol/pbuffer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,14 @@ def __len__(self):
6969
rpos = self._start
7070
self._strio.seek(self._start)
7171
while True:
72+
# get the message metadata
7273
header = self._strio.read(5)
7374
rpos += 5
7475
if len(header) < 5:
76+
# not enough data for another message
7577
break
78+
# unpack the length from the header
7679
length, = xl_unpack(header)
77-
typ = message_types[header[0]]
7880
rpos += length - 4
7981

8082
if length < 4:
@@ -86,12 +88,12 @@ def __len__(self):
8688
count += 1
8789
return count
8890

89-
def _get_message(self):
91+
def _get_message(self, mtypes = message_types):
9092
header = self._strio.read(5)
9193
if len(header) < 5:
9294
return
9395
length, = xl_unpack(header)
94-
typ = message_types[header[0]]
96+
typ = mtypes[header[0]]
9597

9698
if length < 4:
9799
raise ValueError("invalid message size '%d'" %(length,))

postgresql/protocol/xact3.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def put_messages(self, messages):
134134
try:
135135
for x in messages:
136136
count += 1
137-
if x[0] is element.Error.type:
137+
if x[0] == element.Error.type:
138138
if self.fatal is None:
139139
self.error_message = element.Error.parse(x[1])
140140
self.fatal = True
@@ -177,7 +177,7 @@ def state_machine(self):
177177
"""
178178
x = (yield (self.startup_message,))
179179

180-
if x[0] is not element.Authentication.type:
180+
if x[0] != element.Authentication.type:
181181
self.fatal = True
182182
self.error_message = element.ClientError(
183183
message = \
@@ -231,7 +231,7 @@ def state_machine(self):
231231

232232
# Done authenticating, pick up the killinfo and the ready message.
233233
x = (yield None)
234-
if x[0] is not element.KillInformation.type:
234+
if x[0] != element.KillInformation.type:
235235
self.fatal = True
236236
self.error_message = element.ClientError(
237237
message = \
@@ -246,7 +246,7 @@ def state_machine(self):
246246
self.killinfo = element.KillInformation.parse(x[1])
247247

248248
x = (yield None)
249-
if x[0] is not element.Ready.type:
249+
if x[0] != element.Ready.type:
250250
self.fatal = True
251251
self.error_message = element.ClientError(
252252
message = \
@@ -476,7 +476,7 @@ def standard_put(self, messages):
476476

477477
if path is None:
478478
# No path for message type, could be a protocol error.
479-
if x[0] is element.Error.type:
479+
if x[0] == element.Error.type:
480480
em = element.Error.parse(x[1])
481481
fatal = em['severity'].upper() in (b'FATAL', b'PANIC')
482482
self.error_message = em
@@ -541,7 +541,7 @@ def standard_put(self, messages):
541541
current_step = next_step
542542
else:
543543
current_step = 0
544-
if r.type is element.Ready.type:
544+
if r.type == element.Ready.type:
545545
self.last_ready = r.xact_state
546546
# Done with the current command. Increment the offset, and
547547
# try to process the new command with the remaining data.
@@ -582,11 +582,11 @@ def standard_put(self, messages):
582582
last = processed[-1]
583583
if type(last) is bytes:
584584
self.state = (Receiving, self.put_copydata)
585-
elif last.type is element.CopyToBegin.type:
585+
elif last.type == element.CopyToBegin.type:
586586
self.state = (Receiving, self.put_copydata)
587-
elif last.type is element.Tuple.type:
587+
elif last.type == element.Tuple.type:
588588
self.state = (Receiving, self.put_tupledata)
589-
elif last.type is element.CopyFromBegin.type:
589+
elif last.type == element.CopyFromBegin.type:
590590
self.CopyFailSequence = (self.CopyFailMessage,) + \
591591
self.commands[offset+1:]
592592
self.CopyDoneSequence = (element.CopyDoneMessage,) + \
@@ -601,13 +601,13 @@ def put_copydata(self, messages):
601601
message is received, it reverts the ``state`` attribute back to
602602
`standard_put` to process the message-sequence.
603603
"""
604+
copydata = element.CopyData.type
604605
# "Fail" quickly if the last message is not copy data.
605-
if messages[-1][0] is not element.CopyData.type:
606+
if messages[-1][0] != copydata:
606607
self.state = (Receiving, self.standard_put)
607608
return self.standard_put(messages)
608609

609-
cdt = element.CopyData.type
610-
lines = [x[1] for x in messages if x[0] is cdt]
610+
lines = [x[1] for x in messages if x[0] == copydata]
611611
if len(lines) != len(messages):
612612
self.state = (Receiving, self.standard_put)
613613
return self.standard_put(messages)
@@ -624,13 +624,13 @@ def put_tupledata(self, messages):
624624
"""
625625
# Fallback to `standard_put` quickly if the last
626626
# message is not tuple data.
627-
if messages[-1][0] is not element.Tuple.type:
627+
if messages[-1][0] != element.Tuple.type:
628628
self.state = (Receiving, self.standard_put)
629629
return self.standard_put(messages)
630630

631631
p = element.Tuple.parse
632632
t = element.Tuple.type
633-
tuplemessages = [p(x[1]) for x in messages if x[0] is t]
633+
tuplemessages = [p(x[1]) for x in messages if x[0] == t]
634634
if len(tuplemessages) != len(messages):
635635
self.state = (Receiving, self.standard_put)
636636
return self.standard_put(messages)

postgresql/test/testall.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import warnings
99
from ..installation import Installation
1010

11-
1211
from .test_exceptions import *
1312
from .test_bytea_codec import *
1413
from .test_iri import *

0 commit comments

Comments
 (0)