Skip to content

Commit e386cda

Browse files
author
James William Pye
committed
Minor refactoring the asyncnotify test.
The problem was that the CopyFail was not occuring in some cases. So make the test avoid potential failure(likely due to OS buffer size differences). Additionally, add a test that shows early breaks being detected by the CopyManager.
1 parent 2152852 commit e386cda

2 files changed

Lines changed: 24 additions & 9 deletions

File tree

postgresql/copyman.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,7 @@ def __init__(self, producer, *receivers):
676676
self.producer = producer
677677
self.transformer = None
678678
self.receivers = ElementSet(receivers)
679+
self._seen_stop_iteration = False
679680
rp = set()
680681
add = rp.add
681682
for x in self.receivers:
@@ -739,7 +740,7 @@ def __exit__(self, typ, val, tb):
739740
except Exception as e:
740741
exit_faults[x] = e
741742

742-
if typ or exit_faults or profail:
743+
if typ or exit_faults or profail or not self._seen_stop_iteration:
743744
raise CopyFail(self,
744745
"could not complete the COPY operation",
745746
receiver_faults = exit_faults or None,
@@ -770,6 +771,7 @@ def _service_producer(self):
770771
nextdata = next(self.producer)
771772
except StopIteration:
772773
# Should be over.
774+
self._seen_stop_iteration = True
773775
raise
774776
except Exception:
775777
raise ProducerFault(self)

postgresql/test/test_copyman.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,19 +247,32 @@ def testAsyncNotify(self):
247247
sqlexec(stdsource)
248248
dst = new()
249249
dst.execute(stddst)
250-
stmt = prepare(srcsql)
251-
sp = Injector(notify, prepare(srcsql), buffer_size = 133)
250+
sp = Injector(notify, prepare(srcsql), buffer_size = 32)
252251
sr = copyman.StatementReceiver(dst.prepare(dstsql))
253252
seen_in_loop = 0
253+
r = []
254254
with copyman.CopyManager(sp, sr) as copy:
255255
for x in copy:
256-
r = list(db.iternotifies(0))
257-
if r:
258-
break
259-
else:
260-
self.fail("didn't pickup notify during copy")
256+
r += list(db.iternotifies(0))
261257
# Got the injected NOTIFY's, right?
262-
self.failUnlessEqual(r, [('channel', 'payload', 1234)])
258+
self.failUnless(r)
259+
# it may have happened multiple times, so adjust accordingly.
260+
self.failUnlessEqual(r, [('channel', 'payload', 1234)]*len(r))
261+
262+
@pg_tmp
263+
def testUnfinishedCopy(self):
264+
sqlexec(stdsource)
265+
dst = new()
266+
dst.execute(stddst)
267+
sp = copyman.StatementProducer(prepare(srcsql), buffer_size = 32)
268+
sr = copyman.StatementReceiver(dst.prepare(dstsql))
269+
try:
270+
with copyman.CopyManager(sp, sr) as copy:
271+
for x in copy:
272+
break
273+
self.fail("did not raise CopyFail")
274+
except copyman.CopyFail:
275+
pass
263276

264277
@pg_tmp
265278
def testRaiseInCopy(self):

0 commit comments

Comments
 (0)