3434
3535from .. import types as pg_types
3636
37- is_showoption = lambda x : getattr (x , 'type' , None ) is element .ShowOption .type
38-
3937IDNS = 'py:%s'
4038def ID (s , title = None ):
4139 'generate an id for a client statement or cursor'
@@ -321,7 +319,10 @@ def __init__(self, statement, parameters):
321319 self .database = statement .database
322320 Output .__init__ (self , '' )
323321
324- def _init (self ):
322+ def _init (self ,
323+ null = element .Null .type ,
324+ complete = element .Complete .type ,
325+ ):
325326 expect = self ._expect
326327 self ._xact = self ._ins (
327328 self ._pq_xp_fetchall () + (element .SynchronizeMessage ,)
@@ -330,19 +331,19 @@ def _init(self):
330331 while self ._xact .state != xact .Complete :
331332 self .database ._pq_step ()
332333 for x in self ._xact .messages_received ():
333- if x .type is element . Null . type :
334+ if x .type == null :
334335 self .database ._pq_complete ()
335336 self ._xact = None
336337 return
337- elif x .type is element . Complete . type :
338+ elif x .type == complete :
338339 self ._complete_message = x
339340 self .database ._pq_complete ()
340341 # If this was a select/copy cursor,
341342 # the data messages would have caused an earlier
342343 # return.
343344 self ._xact = None
344345 return
345- elif x .type is expect :
346+ elif x .type == expect :
346347 # no need to step once this is seen
347348 return
348349 elif x .type in (
@@ -387,9 +388,11 @@ class SingleXactCopy(FetchAll):
387388class SingleXactFetch (FetchAll ):
388389 _expect = element .Tuple .type
389390 _process_chunk_ = FetchAll ._process_tuple_chunk_Row
390- def _process_chunk (self , x ):
391+ def _process_chunk (self , x ,
392+ tuple_type = element .Tuple .type
393+ ):
391394 return self ._process_chunk_ ((
392- y for y in x if y .type is element . Tuple . type
395+ y for y in x if y .type == tuple_type
393396 ))
394397
395398class MultiXactStream (Chunks ):
@@ -433,7 +436,7 @@ def __next__(self):
433436 self .database ._pq_complete ()
434437
435438 chunk = [
436- y for y in x .messages_received () if y .type is element .Tuple .type
439+ y for y in x .messages_received () if y .type == element .Tuple .type
437440 ]
438441 if len (chunk ) == self .chunksize :
439442 # there may be more, dispatch the request for the next chunk
@@ -496,7 +499,9 @@ def _which_way(self, direction):
496499 else :
497500 return self .direction
498501
499- def _init (self ):
502+ def _init (self ,
503+ tupledesc = element .TupleDescriptor .type ,
504+ ):
500505 """
501506 Based on the cursor parameters and the current transaction state,
502507 select a cursor strategy for managing the response from the server.
@@ -510,7 +515,7 @@ def _init(self):
510515 self .database ._pq_push (x , self )
511516 self .database ._pq_complete ()
512517 for m in x .messages_received ():
513- if m .type is element . TupleDescriptor . type :
518+ if m .type == tupledesc :
514519 self ._output = m
515520 self ._output_attmap = \
516521 self .database .typio .attribute_map (self ._output )
@@ -545,7 +550,7 @@ def _fetch(self, direction, quantity):
545550 self .database ._pq_push (x , self )
546551 self .database ._pq_complete ()
547552 return self ._process_tuple ((
548- y for y in x .messages_received () if y .type is element .Tuple .type
553+ y for y in x .messages_received () if y .type == element .Tuple .type
549554 ))
550555
551556 def seek (self , offset , whence = 'ABSOLUTE' ):
@@ -979,6 +984,7 @@ def first(self, *parameters):
979984 params ,
980985 self ._output_formats ,
981986 ),
987+ # Get all
982988 element .Execute (b'' , 0xFFFFFFFF ),
983989 element .SynchronizeMessage
984990 ),
@@ -990,8 +996,9 @@ def first(self, *parameters):
990996 if self ._output_io :
991997 ##
992998 # Look for the first tuple.
999+ tuple_type = element .Tuple .type
9931000 for xt in x .messages_received ():
994- if xt .type is element . Tuple . type :
1001+ if xt .type == tuple_type :
9951002 break
9961003 else :
9971004 return None
@@ -1012,8 +1019,14 @@ def first(self, *parameters):
10121019 else :
10131020 ##
10141021 # It doesn't return rows, so return a count.
1022+ ##
1023+ # This loop searches through the received messages
1024+ # for the Complete message which contains the count.
1025+ complete = element .Complete .type
10151026 for cm in x .messages_received ():
1016- if getattr (cm , 'type' , None ) == element .Complete .type :
1027+ # Use getattr because COPY doesn't produce
1028+ # element.Message instances.
1029+ if getattr (cm , 'type' , None ) == complete :
10171030 break
10181031 else :
10191032 # Probably a Null command.
@@ -1800,7 +1813,7 @@ def connect(self):
18001813
18011814 # When ssl is None: SSL negotiation will not occur.
18021815 # When ssl is True: SSL negotiation will occur *and* it must succeed.
1803- # When ssl is False: SSL negotiation will occur but NOSSL is okay .
1816+ # When ssl is False: SSL negotiation will occur but it may fail(NOSSL) .
18041817 if sslmode == 'allow' :
18051818 # first, without ssl, then with. :)
18061819 socket_factories = interlace (
@@ -1838,15 +1851,21 @@ def connect(self):
18381851 sf , self .connector ._startup_parameters ,
18391852 password = self .connector ._password ,
18401853 )
1854+ # Grab the negotiation transaction before
1855+ # connecting as it will be needed later if successful.
18411856 neg = pq .xact
18421857 pq .connect (ssl = ssl , timeout = timeout )
18431858
18441859 didssl = getattr (pq , 'ssl_negotiation' , - 1 )
1860+
1861+ # It successfully connected if pq.xact is None.
18451862 if pq .xact is None :
18461863 self .pq = pq
1847- for x in filter (is_showoption , neg .asyncs ):
1848- self ._receive_async (x )
18491864 self .security = 'ssl' if didssl is True else None
1865+ showoption_type = element .ShowOption .type
1866+ for x in neg .asyncs :
1867+ if x .type == showoption_type :
1868+ self ._receive_async (x )
18501869 # success!
18511870 break
18521871
@@ -2037,20 +2056,24 @@ def _error_lookup(self, om : element.Error,) -> pg_exc.Error:
20372056 err .database = self
20382057 return err
20392058
2040- def _receive_async (self , msg , controller = None ):
2059+ def _receive_async (self , msg , controller = None ,
2060+ showoption = element .ShowOption .type ,
2061+ notice = element .Notice .type ,
2062+ notify = element .Notify .type ,
2063+ ):
20412064 c = controller or getattr (self , '_controller' , self )
2042- if msg .type is element . ShowOption . type :
2065+ if msg .type == showoption :
20432066 if msg .name == b'client_encoding' :
20442067 self .typio .set_encoding (msg .value .decode ('ascii' ))
20452068 self .settings ._notify (msg )
2046- elif msg .type is element . Notice . type :
2069+ elif msg .type == notice :
20472070 src = 'SERVER'
20482071 if type (msg ) is element .ClientNotice :
20492072 src = 'CLIENT'
20502073 m = self ._convert_pq_message (msg , source = src )
20512074 m .creator = c
20522075 m .raise_message ()
2053- elif msg .type is element . Notify . type :
2076+ elif msg .type == notify :
20542077 subs = getattr (self , '_subscriptions' , {})
20552078 for x in subs .get (msg .relation , ()):
20562079 x (self , msg )
0 commit comments