Skip to content

Commit d758ba0

Browse files
author
James William Pye
committed
Refactor the tests to use .temporal.
This reduces the number of clusters created during testall, and makes the per-method setup/teardown explicit.
1 parent c3754f1 commit d758ba0

8 files changed

Lines changed: 889 additions & 534 deletions

File tree

postgresql/cluster.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
##
2-
# copyright 2009, James William Pye
3-
# http://python.projects.postgresql.org
2+
# .cluster - PostgreSQL cluster management
43
##
54
"""
65
Create, control, and destroy PostgreSQL clusters.
@@ -12,7 +11,6 @@
1211
import os
1312
import errno
1413
import time
15-
import io
1614
import subprocess as sp
1715
from tempfile import NamedTemporaryFile
1816

postgresql/installation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
##
2-
# copyright 2009, James William Pye
3-
# http://python.projects.postgresql.org
2+
# .installation
43
##
5-
'pg_config Python interface; provides member based access to pg_config data'
4+
"""
5+
Collect and access PostgreSQL installation information.
6+
"""
67
import sys
78
import os
89
import os.path

postgresql/temporal.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
##
2+
# .temporal - manage the temporary cluster
3+
##
4+
"""
5+
Temporary PostgreSQL cluster for the process.
6+
"""
7+
import os
8+
import atexit
9+
from .cluster import Cluster, ClusterError
10+
from . import installation
11+
from .python.socket import find_available_port
12+
13+
class Temporal(object):
14+
"""
15+
Manages a temporary cluster for the duration of the process.
16+
17+
Instances of this class reference a distinct cluster. These clusters are
18+
transient; they will only exist until the process exits.
19+
20+
Usage::
21+
22+
>>> from postgresql.temporal import pg_tmp
23+
>>> with pg_tmp:
24+
... ps = db.prepare('SELECT 1')
25+
... assert ps.first() == 1
26+
27+
Or `pg_tmp` can decorate a method or function.
28+
"""
29+
cluster_dirname = 'pg_tmp_{0}_{1}'.format
30+
cluster = None
31+
_init_pid_ = None
32+
_local_id_ = 0
33+
builtins_keys = {
34+
'connector',
35+
'db',
36+
'do',
37+
'xact',
38+
'proc',
39+
'settings',
40+
'prepare',
41+
'sqlexec',
42+
'newdb',
43+
}
44+
45+
def __init__(self):
46+
self.builtins_stack = []
47+
self.sandbox_id = 0
48+
self.__class__._local_id_ = self.local_id = (self.__class__._local_id_ + 1)
49+
50+
def __call__(self, callable):
51+
def incontext(*args, **kw):
52+
with self:
53+
return callable(*args, **kw)
54+
return incontext
55+
56+
def reset(self):
57+
# Don't reset if it's not the initializing process.
58+
if os.getpid() == self._init_pid_:
59+
cluster = self.cluster
60+
self.cluster = None
61+
self._init_pid_ = None
62+
if cluster is not None:
63+
cluster.drop()
64+
65+
def init(self,
66+
installation_factory = installation.default,
67+
inshint = {
68+
'hint' : "Try setting the PGINSTALLATION " \
69+
"environment variable to the `pg_config` path"
70+
}
71+
):
72+
if self.cluster is not None:
73+
return
74+
##
75+
# Hasn't been created yet, but doesn't matter.
76+
# On exit, obliterate the cluster directory.
77+
self._init_pid_ = os.getpid()
78+
atexit.register(self.reset)
79+
80+
# [$HOME|.]/.pg_tmpdb_{pid}
81+
self.cluster_path = os.path.join(
82+
os.environ.get('HOME', os.getcwd()),
83+
self.cluster_dirname(self._init_pid_, self.local_id)
84+
)
85+
self.logfile = os.path.join(self.cluster_path, 'logfile')
86+
installation = installation_factory()
87+
if installation is None:
88+
raise ClusterError(
89+
'could not find the default pg_config', details = inshint
90+
)
91+
92+
cluster = Cluster(
93+
installation,
94+
self.cluster_path,
95+
)
96+
97+
# If it exists already, destroy it.
98+
if cluster.initialized():
99+
cluster.drop()
100+
cluster.encoding = 'utf-8'
101+
cluster.init(
102+
user = 'test', # Consistent username.
103+
encoding = cluster.encoding,
104+
logfile = None,
105+
)
106+
107+
# Configure
108+
self.cluster_port = find_available_port()
109+
if self.cluster_port is None:
110+
raise ClusterError(
111+
'could not find a port for the test cluster on localhost',
112+
creator = cluster
113+
)
114+
115+
cluster.settings.update(dict(
116+
port = str(self.cluster_port),
117+
max_connections = '20',
118+
shared_buffers = '200',
119+
listen_addresses = 'localhost',
120+
log_destination = 'stderr',
121+
log_min_messages = 'FATAL',
122+
silent_mode = 'off',
123+
))
124+
cluster.settings.update(dict(
125+
max_prepared_transactions = '10',
126+
))
127+
128+
# Start it up.
129+
cluster.start(logfile = open(self.logfile, 'w'))
130+
cluster.wait_until_started()
131+
132+
# Initialize template1 and the test user database.
133+
c = cluster.connection(
134+
user = 'test', database = 'template1',
135+
)
136+
with c:
137+
c.execute('create database test')
138+
# It's ready.
139+
self.cluster = cluster
140+
141+
def push(self):
142+
c = self.cluster.connection(user = 'test')
143+
extras = []
144+
145+
def newdb(l = extras, c = c, sbid = 'sandbox' + str(self.sandbox_id + 1)):
146+
# Used to create a new connection that will be closed
147+
# when the context stack is popped along with 'db'.
148+
l.append(c.clone())
149+
l[-1].settings['search_path'] = str(sbid) + ',' + l[-1].settings['search_path']
150+
return l[-1]
151+
152+
# The new builtins.
153+
builtins = {
154+
'db' : c,
155+
'prepare' : c.prepare,
156+
'xact' : c.xact,
157+
'sqlexec' : c.execute,
158+
'do' : c.do,
159+
'settings' : c.settings,
160+
'proc' : c.proc,
161+
'connector' : c.connector,
162+
'new' : newdb,
163+
}
164+
if not self.builtins_stack:
165+
# Store any of those set or not set.
166+
current = {
167+
k : __builtins__[k] for k in self.builtins_keys
168+
if k in __builtins__
169+
}
170+
self.builtins_stack.append((current, []))
171+
172+
# Store and push.
173+
self.builtins_stack.append((builtins, extras))
174+
__builtins__.update(builtins)
175+
self.sandbox_id += 1
176+
177+
def pop(self,
178+
interrupt = False,
179+
drop_schema = 'DROP SCHEMA sandbox{0} CASCADE'.format
180+
):
181+
builtins, extras = self.builtins_stack.pop(-1)
182+
self.sandbox_id -= 1
183+
# restore
184+
if len(self.builtins_stack) > 1:
185+
__builtins__.update(self.builtins_stack[-1][0])
186+
else:
187+
previous = self.builtins_stack.pop(0)
188+
for x in self.builtins_keys:
189+
if x in previous:
190+
__builtins__[x] = previous[x]
191+
else:
192+
# Wasn't set before.
193+
__builtins__.pop(x, None)
194+
if not interrupt:
195+
# Interrupt then close. Just in case something is lingering.
196+
builtins['db'].interrupt()
197+
builtins['db'].close()
198+
for x in extras:
199+
x.interrupt()
200+
x.close()
201+
# Interrupted and closed all the other connections.
202+
with builtins['new']() as dropdb:
203+
dropdb.execute(drop_schema(self.sandbox_id+1))
204+
205+
def __enter__(self):
206+
if self.cluster is None:
207+
self.init()
208+
self.push()
209+
try:
210+
db.connect()
211+
db.execute('CREATE SCHEMA sandbox' + str(self.sandbox_id))
212+
db.settings['search_path'] = 'sandbox' + str(self.sandbox_id) + ',' + db.settings['search_path']
213+
except:
214+
self.pop()
215+
raise
216+
217+
def __exit__(self, exc, val, tb):
218+
self.pop(exc and not issubclass(exc, Exception))
219+
220+
#: The process' temporary cluster.
221+
pg_tmp = Temporal()

postgresql/test/test_alock.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
##
2+
# .test.test_alock - test .alock
3+
##
4+
import unittest
5+
import threading
6+
import time
7+
from ..temporal import pg_tmp
8+
from .. import alock
9+
10+
class test_alock(unittest.TestCase):
11+
@pg_tmp
12+
def testALockNoWait(self):
13+
alt = new()
14+
ad = db.prepare(
15+
"select count(*) FROM pg_locks WHERE locktype = 'advisory'"
16+
).first
17+
self.failUnlessEqual(ad(), 0)
18+
with alock.ExclusiveLock(db, (0,0)):
19+
l=alock.ExclusiveLock(alt, (0,0))
20+
# should fail to acquire
21+
self.failUnlessEqual(l.acquire(False), False)
22+
# no alocks should exist now
23+
self.failUnlessEqual(ad(), 0)
24+
25+
@pg_tmp
26+
def testALock(self):
27+
ad = db.prepare(
28+
"select count(*) FROM pg_locks WHERE locktype = 'advisory'"
29+
).first
30+
self.failUnlessEqual(ad(), 0)
31+
# test a variety..
32+
lockids = [
33+
(1,4),
34+
-32532, 0, 2,
35+
(7, -1232),
36+
4, 5, 232142423,
37+
(18,7),
38+
2, (1,4)
39+
]
40+
alt = new()
41+
xal1 = alock.ExclusiveLock(db, *lockids)
42+
xal2 = alock.ExclusiveLock(db, *lockids)
43+
sal1 = alock.ShareLock(db, *lockids)
44+
with sal1:
45+
with xal1, xal2:
46+
self.failUnless(ad() > 0)
47+
for x in lockids:
48+
xl = alock.ExclusiveLock(alt, x)
49+
self.failUnlessEqual(xl.acquire(False), False)
50+
# main has exclusives on these, so this should fail.
51+
xl = alock.ShareLock(alt, *lockids)
52+
self.failUnlessEqual(xl.acquire(False), False)
53+
for x in lockids:
54+
# sal1 still holds
55+
xl = alock.ExclusiveLock(alt, x)
56+
self.failUnlessEqual(xl.acquire(False), False)
57+
# sal1 still holds, but we want a share lock too.
58+
xl = alock.ShareLock(alt, x)
59+
self.failUnlessEqual(xl.acquire(False), True)
60+
xl.release()
61+
# no alocks should exist now
62+
self.failUnlessEqual(ad(), 0)
63+
64+
@pg_tmp
65+
def testPartialALock(self):
66+
# Validates that release is properly cleaning up
67+
ad = db.prepare(
68+
"select count(*) FROM pg_locks WHERE locktype = 'advisory'"
69+
).first
70+
self.failUnlessEqual(ad(), 0)
71+
held = (0,-1234)
72+
wanted = [0, 324, -1232948, 7, held, 1, (2,4), (834,1)]
73+
alt = new()
74+
with alock.ExclusiveLock(db, held):
75+
l=alock.ExclusiveLock(alt, *wanted)
76+
# should fail to acquire, db has held
77+
self.failUnlessEqual(l.acquire(False), False)
78+
# No alocks should exist now.
79+
# This *MUST* occur prior to alt being closed.
80+
# Otherwise, we won't be testing for the recovery
81+
# of a failed non-blocking acquire().
82+
self.failUnlessEqual(ad(), 0)
83+
84+
@pg_tmp
85+
def testALockParameterErrors(self):
86+
self.failUnlessRaises(TypeError, alock.ALock)
87+
l = alock.ExclusiveLock(db)
88+
self.failUnlessRaises(RuntimeError, l.release)
89+
90+
@pg_tmp
91+
def testALockOnClosed(self):
92+
ad = db.prepare(
93+
"select count(*) FROM pg_locks WHERE locktype = 'advisory'"
94+
).first
95+
self.failUnlessEqual(ad(), 0)
96+
held = (0,-1234)
97+
alt = new()
98+
# __exit__ should only touch the count.
99+
with alock.ExclusiveLock(alt, held) as l:
100+
self.failUnlessEqual(ad(), 1)
101+
self.failUnlessEqual(l.locked(), True)
102+
alt.close()
103+
time.sleep(0.005)
104+
self.failUnlessEqual(ad(), 0)
105+
self.failUnlessEqual(l.locked(), False)
106+
107+
if __name__ == '__main__':
108+
unittest.main()

0 commit comments

Comments
 (0)