|
| 1 | +## |
| 2 | +# .temporal - manage the temporary cluster |
| 3 | +## |
| 4 | +""" |
| 5 | +Temporary PostgreSQL cluster for the process. |
| 6 | +""" |
| 7 | +import os |
| 8 | +import atexit |
| 9 | +from .cluster import Cluster, ClusterError |
| 10 | +from . import installation |
| 11 | +from .python.socket import find_available_port |
| 12 | + |
| 13 | +class Temporal(object): |
| 14 | + """ |
| 15 | + Manages a temporary cluster for the duration of the process. |
| 16 | +
|
| 17 | + Instances of this class reference a distinct cluster. These clusters are |
| 18 | + transient; they will only exist until the process exits. |
| 19 | +
|
| 20 | + Usage:: |
| 21 | +
|
| 22 | + >>> from postgresql.temporal import pg_tmp |
| 23 | + >>> with pg_tmp: |
| 24 | + ... ps = db.prepare('SELECT 1') |
| 25 | + ... assert ps.first() == 1 |
| 26 | + |
| 27 | + Or `pg_tmp` can decorate a method or function. |
| 28 | + """ |
| 29 | + cluster_dirname = 'pg_tmp_{0}_{1}'.format |
| 30 | + cluster = None |
| 31 | + _init_pid_ = None |
| 32 | + _local_id_ = 0 |
| 33 | + builtins_keys = { |
| 34 | + 'connector', |
| 35 | + 'db', |
| 36 | + 'do', |
| 37 | + 'xact', |
| 38 | + 'proc', |
| 39 | + 'settings', |
| 40 | + 'prepare', |
| 41 | + 'sqlexec', |
| 42 | + 'newdb', |
| 43 | + } |
| 44 | + |
| 45 | + def __init__(self): |
| 46 | + self.builtins_stack = [] |
| 47 | + self.sandbox_id = 0 |
| 48 | + self.__class__._local_id_ = self.local_id = (self.__class__._local_id_ + 1) |
| 49 | + |
| 50 | + def __call__(self, callable): |
| 51 | + def incontext(*args, **kw): |
| 52 | + with self: |
| 53 | + return callable(*args, **kw) |
| 54 | + return incontext |
| 55 | + |
| 56 | + def reset(self): |
| 57 | + # Don't reset if it's not the initializing process. |
| 58 | + if os.getpid() == self._init_pid_: |
| 59 | + cluster = self.cluster |
| 60 | + self.cluster = None |
| 61 | + self._init_pid_ = None |
| 62 | + if cluster is not None: |
| 63 | + cluster.drop() |
| 64 | + |
| 65 | + def init(self, |
| 66 | + installation_factory = installation.default, |
| 67 | + inshint = { |
| 68 | + 'hint' : "Try setting the PGINSTALLATION " \ |
| 69 | + "environment variable to the `pg_config` path" |
| 70 | + } |
| 71 | + ): |
| 72 | + if self.cluster is not None: |
| 73 | + return |
| 74 | + ## |
| 75 | + # Hasn't been created yet, but doesn't matter. |
| 76 | + # On exit, obliterate the cluster directory. |
| 77 | + self._init_pid_ = os.getpid() |
| 78 | + atexit.register(self.reset) |
| 79 | + |
| 80 | + # [$HOME|.]/.pg_tmpdb_{pid} |
| 81 | + self.cluster_path = os.path.join( |
| 82 | + os.environ.get('HOME', os.getcwd()), |
| 83 | + self.cluster_dirname(self._init_pid_, self.local_id) |
| 84 | + ) |
| 85 | + self.logfile = os.path.join(self.cluster_path, 'logfile') |
| 86 | + installation = installation_factory() |
| 87 | + if installation is None: |
| 88 | + raise ClusterError( |
| 89 | + 'could not find the default pg_config', details = inshint |
| 90 | + ) |
| 91 | + |
| 92 | + cluster = Cluster( |
| 93 | + installation, |
| 94 | + self.cluster_path, |
| 95 | + ) |
| 96 | + |
| 97 | + # If it exists already, destroy it. |
| 98 | + if cluster.initialized(): |
| 99 | + cluster.drop() |
| 100 | + cluster.encoding = 'utf-8' |
| 101 | + cluster.init( |
| 102 | + user = 'test', # Consistent username. |
| 103 | + encoding = cluster.encoding, |
| 104 | + logfile = None, |
| 105 | + ) |
| 106 | + |
| 107 | + # Configure |
| 108 | + self.cluster_port = find_available_port() |
| 109 | + if self.cluster_port is None: |
| 110 | + raise ClusterError( |
| 111 | + 'could not find a port for the test cluster on localhost', |
| 112 | + creator = cluster |
| 113 | + ) |
| 114 | + |
| 115 | + cluster.settings.update(dict( |
| 116 | + port = str(self.cluster_port), |
| 117 | + max_connections = '20', |
| 118 | + shared_buffers = '200', |
| 119 | + listen_addresses = 'localhost', |
| 120 | + log_destination = 'stderr', |
| 121 | + log_min_messages = 'FATAL', |
| 122 | + silent_mode = 'off', |
| 123 | + )) |
| 124 | + cluster.settings.update(dict( |
| 125 | + max_prepared_transactions = '10', |
| 126 | + )) |
| 127 | + |
| 128 | + # Start it up. |
| 129 | + cluster.start(logfile = open(self.logfile, 'w')) |
| 130 | + cluster.wait_until_started() |
| 131 | + |
| 132 | + # Initialize template1 and the test user database. |
| 133 | + c = cluster.connection( |
| 134 | + user = 'test', database = 'template1', |
| 135 | + ) |
| 136 | + with c: |
| 137 | + c.execute('create database test') |
| 138 | + # It's ready. |
| 139 | + self.cluster = cluster |
| 140 | + |
| 141 | + def push(self): |
| 142 | + c = self.cluster.connection(user = 'test') |
| 143 | + extras = [] |
| 144 | + |
| 145 | + def newdb(l = extras, c = c, sbid = 'sandbox' + str(self.sandbox_id + 1)): |
| 146 | + # Used to create a new connection that will be closed |
| 147 | + # when the context stack is popped along with 'db'. |
| 148 | + l.append(c.clone()) |
| 149 | + l[-1].settings['search_path'] = str(sbid) + ',' + l[-1].settings['search_path'] |
| 150 | + return l[-1] |
| 151 | + |
| 152 | + # The new builtins. |
| 153 | + builtins = { |
| 154 | + 'db' : c, |
| 155 | + 'prepare' : c.prepare, |
| 156 | + 'xact' : c.xact, |
| 157 | + 'sqlexec' : c.execute, |
| 158 | + 'do' : c.do, |
| 159 | + 'settings' : c.settings, |
| 160 | + 'proc' : c.proc, |
| 161 | + 'connector' : c.connector, |
| 162 | + 'new' : newdb, |
| 163 | + } |
| 164 | + if not self.builtins_stack: |
| 165 | + # Store any of those set or not set. |
| 166 | + current = { |
| 167 | + k : __builtins__[k] for k in self.builtins_keys |
| 168 | + if k in __builtins__ |
| 169 | + } |
| 170 | + self.builtins_stack.append((current, [])) |
| 171 | + |
| 172 | + # Store and push. |
| 173 | + self.builtins_stack.append((builtins, extras)) |
| 174 | + __builtins__.update(builtins) |
| 175 | + self.sandbox_id += 1 |
| 176 | + |
| 177 | + def pop(self, |
| 178 | + interrupt = False, |
| 179 | + drop_schema = 'DROP SCHEMA sandbox{0} CASCADE'.format |
| 180 | + ): |
| 181 | + builtins, extras = self.builtins_stack.pop(-1) |
| 182 | + self.sandbox_id -= 1 |
| 183 | + # restore |
| 184 | + if len(self.builtins_stack) > 1: |
| 185 | + __builtins__.update(self.builtins_stack[-1][0]) |
| 186 | + else: |
| 187 | + previous = self.builtins_stack.pop(0) |
| 188 | + for x in self.builtins_keys: |
| 189 | + if x in previous: |
| 190 | + __builtins__[x] = previous[x] |
| 191 | + else: |
| 192 | + # Wasn't set before. |
| 193 | + __builtins__.pop(x, None) |
| 194 | + if not interrupt: |
| 195 | + # Interrupt then close. Just in case something is lingering. |
| 196 | + builtins['db'].interrupt() |
| 197 | + builtins['db'].close() |
| 198 | + for x in extras: |
| 199 | + x.interrupt() |
| 200 | + x.close() |
| 201 | + # Interrupted and closed all the other connections. |
| 202 | + with builtins['new']() as dropdb: |
| 203 | + dropdb.execute(drop_schema(self.sandbox_id+1)) |
| 204 | + |
| 205 | + def __enter__(self): |
| 206 | + if self.cluster is None: |
| 207 | + self.init() |
| 208 | + self.push() |
| 209 | + try: |
| 210 | + db.connect() |
| 211 | + db.execute('CREATE SCHEMA sandbox' + str(self.sandbox_id)) |
| 212 | + db.settings['search_path'] = 'sandbox' + str(self.sandbox_id) + ',' + db.settings['search_path'] |
| 213 | + except: |
| 214 | + self.pop() |
| 215 | + raise |
| 216 | + |
| 217 | + def __exit__(self, exc, val, tb): |
| 218 | + self.pop(exc and not issubclass(exc, Exception)) |
| 219 | + |
| 220 | +#: The process' temporary cluster. |
| 221 | +pg_tmp = Temporal() |
0 commit comments