Skip to content

Commit e7a1da5

Browse files
author
James William Pye
committed
Add the ProducerFault back.
The user needs to be able to identify what caused an exception. Producer faults are usually fatal, but at least the failure will be identifiable.
1 parent 9febc86 commit e7a1da5

2 files changed

Lines changed: 81 additions & 29 deletions

File tree

postgresql/copyman.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,25 @@
2020
default_buffer_size = 1024 * 10
2121

2222
class Fault(Exception):
23+
pass
24+
25+
class ProducerFault(Fault):
26+
"""
27+
Exception raised when the Producer caused an exception.
28+
29+
Normally, Producer faults are fatal.
30+
"""
31+
def __init__(self, manager):
32+
self.manager = manager
33+
34+
def __str__(self):
35+
return "producer raised exception"
36+
37+
class ReceiverFaults(Fault):
2338
"""
24-
Receivers raised exceptions. This happens in cases where a receiver raises
25-
an exception. Faults should be trapped if recovery from an exception is
39+
Exception raised when Receivers cause an exception.
40+
41+
Faults should be trapped if recovery from an exception is
2642
possible, or if the failed receiver is optional to the succes of the
2743
operation.
2844
@@ -46,13 +62,20 @@ class CopyFail(Exception):
4662
4763
The 'reason' attribute is a string indicating why it failed.
4864
49-
The 'faults' attribute is a mapping of receivers to exceptions that were
65+
The 'receiver_faults' attribute is a mapping of receivers to exceptions that were
5066
raised on exit.
67+
68+
The 'producer_fault' attribute specifies if the producer raise an exception
69+
on exit.
5170
"""
52-
def __init__(self, manager, reason = None, faults = None):
71+
def __init__(self, manager, reason = None,
72+
receiver_faults = None,
73+
producer_fault = None,
74+
):
5375
self.manager = manager
5476
self.reason = reason
55-
self.faults = faults or {}
77+
self.receiver_faults = receiver_faults or {}
78+
self.producer_fault = producer_fault
5679

5780
def __str__(self):
5881
return self.reason or 'copy aborted'
@@ -683,27 +706,29 @@ def __exit__(self, typ, val, tb):
683706
# Don't recover on interrupts.
684707
return
685708

686-
# Does nothing if the COPY was successful.
687-
self.producer.realign()
709+
profail = None
688710
try:
689-
##
690-
# If the producer is not aligned to a message boundary,
691-
# it can emit completion data that will put the receivers
692-
# back on track.
693-
# This last service call will move that data onto the receivers.
694-
self._service_producer()
695-
##
696-
# The receivers need to handle any new data in their __exit__.
697-
except StopIteration:
698-
# No re-alignment needed.
699-
pass
711+
# Does nothing if the COPY was successful.
712+
self.producer.realign()
713+
try:
714+
##
715+
# If the producer is not aligned to a message boundary,
716+
# it can emit completion data that will put the receivers
717+
# back on track.
718+
# This last service call will move that data onto the receivers.
719+
self._service_producer()
720+
##
721+
# The receivers need to handle any new data in their __exit__.
722+
except StopIteration:
723+
# No re-alignment needed.
724+
pass
700725

701-
self.producer.__exit__(typ, val, tb)
726+
self.producer.__exit__(typ, val, tb)
727+
except Exception as profail:
728+
pass
702729

703730
# No receivers? It wasn't a success.
704731
if not self.receivers:
705-
if typ is CopyFail:
706-
raise
707732
raise CopyFail(self, "no receivers")
708733

709734
exit_faults = {}
@@ -736,13 +761,15 @@ def _service_producer(self):
736761
# Setup current data.
737762
if not self.receivers:
738763
# No receivers to take the data.
739-
raise CopyFail(self, "no receivers")
764+
raise StopIteration
740765

741766
try:
742767
nextdata = next(self.producer)
743768
except StopIteration:
744769
# Should be over.
745770
raise
771+
except Exception:
772+
raise ProducerFault(self)
746773

747774
self.transformer(nextdata)
748775

@@ -762,7 +789,7 @@ def _service_receivers(self):
762789
# The CopyManager is eager to continue the operation.
763790
for x in faults:
764791
self.receivers.discard(x)
765-
raise Fault(self, faults)
792+
raise ReceiverFaults(self, faults)
766793

767794
# Run the COPY to completion.
768795
def run(self):

postgresql/test/test_copyman.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def testCopyWithFailure(self):
328328
else:
329329
# Done with copy.
330330
break
331-
except copyman.Fault as cf:
331+
except copyman.ReceiverFaults as cf:
332332
if sr2 not in cf.faults:
333333
raise
334334
self.failUnless(done)
@@ -388,17 +388,17 @@ def testNoReceivers(self):
388388
done = False
389389
try:
390390
with copyman.CopyManager(sp, sr1) as copy:
391-
while True:
391+
while not done:
392392
try:
393393
for x in copy:
394394
if not done:
395395
done = True
396396
dst.pq.socket.close()
397397
else:
398398
self.fail("failed to detect dead socket")
399-
except copyman.Fault as cf:
399+
except copyman.ReceiverFaults as cf:
400400
self.failUnless(sr1 in cf.faults)
401-
# Don't reconcile.
401+
# Don't reconcile. Let the manager drop the receiver.
402402
except copyman.CopyFail:
403403
self.failUnless(not bool(copy.receivers))
404404
# Success.
@@ -436,7 +436,7 @@ def failed_write(*args):
436436
else:
437437
# Done with COPY, break out of while copy.receivers.
438438
break
439-
except copyman.Fault as cf:
439+
except copyman.ReceiverFaults as cf:
440440
if isinstance(cf.faults[sr], RecoverableError):
441441
if done is True:
442442
self.fail("failed_write was called twice?")
@@ -485,7 +485,7 @@ def failed_write(*args):
485485
else:
486486
# Done with COPY, break out of while copy.receivers.
487487
break
488-
except copyman.Fault as cf:
488+
except copyman.ReceiverFaults as cf:
489489
self.failUnless(isinstance(cf.faults[sr2], TheCause))
490490
if done is True:
491491
self.fail("failed_write was called twice?")
@@ -509,7 +509,32 @@ def failed_write(*args):
509509
self.failUnlessEqual(sp.count(), stdrowcount)
510510
self.failUnlessEqual(sp.command(), "COPY")
511511

512+
@pg_tmp
513+
def testProducerFailure(self):
514+
sqlexec(stdsource)
515+
dst = new()
516+
dst.execute(stddst)
517+
sp = copyman.StatementProducer(prepare(srcsql))
518+
sr = copyman.StatementReceiver(dst.prepare(dstsql))
519+
done = False
520+
try:
521+
with copyman.CopyManager(sp, sr) as copy:
522+
try:
523+
for x in copy:
524+
if not done:
525+
done = True
526+
db.pq.socket.close()
527+
except copyman.ProducerFault as pf:
528+
self.failUnless(pf.__context__ is not None)
529+
self.fail('expected CopyManager to raise CopyFail')
530+
except copyman.CopyFail as cf:
531+
pass
532+
self.failUnless(done)
533+
self.failUnlessRaises(Exception, sqlexec, 'select 1')
534+
self.failUnlessEqual(dst.prepare(dstcount).first(), 0)
535+
512536
from ..copyman import WireState
537+
513538
class test_WireState(unittest.TestCase):
514539
def testNormal(self):
515540
WS=WireState()

0 commit comments

Comments
 (0)