This repository was archived by the owner on Nov 29, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 34
Expand file tree
/
Copy path clone_database.py
More file actions
164 lines (139 loc) · 5.69 KB
/
Copy path clone_database.py
File metadata and controls
164 lines (139 loc) · 5.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""Clone all data from a database to another. Mostly useful for database migration."""
import argparse
import logging.config
import traceback
import pdb
from sqlalchemy import (
Table, Column, String, insert, select, delete, create_engine, update)
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.expression import cast
from sqlalchemy.types import TIMESTAMP, DATETIME
from pyramid.paster import get_appsettings, bootstrap
from assembl.lib.config import set_config
from assembl.lib.sqla import (
configure_engine, get_session_maker, make_session_maker, get_metadata)
from assembl.lib.zmqlib import configure_zmq
from assembl.lib.model_watcher import configure_model_watcher
from assembl.indexing.changes import configure_indexing
# Tables whose rows point at other rows of the same table, mapped to
# (primary-key column, self-referencing column). copy_table uses this to
# insert parents before children; copy_database nulls the self-referencing
# column before deleting the table's rows.
recursive_tables = dict(
    post=("id", "parent_id"),
    preferences=("id", "cascade_id"),
)
# TODO: Make it dependent on engine
# table name -> {column name -> SQL type}: source columns listed here are
# wrapped in a CAST to that type when read (see maybe_cast).
column_casts = dict(
    idea=dict(last_modified=DATETIME),
)
# Tables whose ids are synchronised through a "<table>_idsequence" sequence
# instead of the "<table>_id_seq" one used for other tables (see copy_table).
history_tables = ["idea", "idea_idea_link", "idea_vote"]
def maybe_cast(column):
    """Return `column`, wrapped in a SQL CAST if `column_casts` asks for one.

    The lookup is keyed on the column's table name, then on the column name;
    columns without an entry are returned unchanged.
    """
    per_table = column_casts.get(column.table.name, {})
    target_type = per_table.get(column.name, None)
    if target_type is None:
        return column
    return cast(column, target_type)
def is_virtuoso(session):
    """Tell whether `session` is bound to an engine whose URL starts with 'virtuoso'."""
    prefix = 'virtuoso'
    url_text = str(session.bind.url)
    return url_text[:len(prefix)] == prefix
def set_sequence(session, name, value):
    """Set database sequence `name` to `value`.

    Issues ``SELECT sequence_set('name', value)`` on Virtuoso and
    ``SELECT setval('name', value)`` on any other back-end. `name` and
    `value` are interpolated directly into the SQL text, so they must come
    from trusted metadata, never from user input.
    """
    func_name = "sequence_set" if is_virtuoso(session) else "setval"
    statement = "SELECT {}('{}', {})".format(func_name, name, value)
    session.execute(statement)
def get_sequence(session, name):
    """Return the current value of database sequence `name`.

    On Virtuoso this evaluates ``sequence_set('name', 0, 1)``. On other
    back-ends (PostgreSQL) it reads the sequence relation's ``last_value``.
    `name` is interpolated directly into the SQL text, so it must come from
    trusted metadata.
    """
    if is_virtuoso(session):
        return session.query("sequence_set('%s', 0, 1)" % (name,)).first()[0]
    # currval() only reports a value produced by nextval() *in this session*
    # and raises otherwise. This script never calls nextval() on the session
    # it reads from, so read last_value instead; it reflects the last value
    # allocated by any session.
    return session.execute("SELECT last_value FROM %s" % (name,)).first()[0]
def _insert_parents_first(dest_session, dest_table, rows, idx_name, fkey_name):
    """Insert `rows` in waves so every row's self-referenced parent precedes it."""
    done = set()
    pending = rows
    while pending:
        ready = []
        waiting = []
        for row in pending:
            fkey = row[fkey_name]
            if fkey is None or fkey in done:
                ready.append(row)
            else:
                waiting.append(row)
        if not ready:
            # Every remaining row points at a parent that was never inserted
            # (dangling reference or cycle), so no further progress is
            # possible. An explicit error is used because an `assert` is
            # stripped under `python -O`, letting the loop retry the same rows.
            raise RuntimeError(
                "Cannot order %d row(s) of table %s parents-first: their %s "
                "values reference rows that are missing or form a cycle" % (
                    len(waiting), dest_table.name, fkey_name))
        dest_session.execute(dest_table.insert(), ready)
        done.update(row[idx_name] for row in ready)
        pending = waiting


def _sync_id_sequence(source_session, dest_session, source_table, dest_table):
    """Advance the PostgreSQL id sequence of `dest_table` to the copied ids."""
    idx_col = dest_table.c.get("id", None)
    if idx_col is None or idx_col.foreign_keys:
        # No "id" column, or the id is itself a foreign key (presumably
        # shared with a parent table), so there is no sequence to advance here.
        return
    (max_id,) = source_session.query(
        'max(id) from "%s"' % (source_table.name,)).first()
    if dest_table.name in history_tables:
        # Also consult the source's own sequence, so the destination sequence
        # never restarts below a value the source has already handed out.
        max_id = max(max_id, get_sequence(
            source_session, source_table.fullname + "_idsequence"))
        set_sequence(
            dest_session, dest_table.fullname + "_idsequence", max_id)
    else:
        set_sequence(dest_session, dest_table.fullname + "_id_seq", max_id)


def copy_table(source_session, dest_session, source_table, dest_table):
    """Copy every row of `source_table` (read via `source_session`) into `dest_table`.

    Columns listed in `column_casts` are CAST while being read. Tables listed
    in `recursive_tables` are inserted parents-first. When the destination
    URL starts with 'postgresql', the table's id sequence is then set to the
    largest copied id. Does nothing if the source table is empty.

    :raises RuntimeError: if rows of a self-referencing table cannot be
        ordered parents-first.
    """
    columns = [maybe_cast(c) for c in source_table.c]
    cnames = [c.name for c in source_table.c]
    rows = [dict(zip(cnames, val)) for val in source_session.query(*columns)]
    if not rows:
        return
    if source_table.name in recursive_tables:
        idx_name, fkey_name = recursive_tables[source_table.name]
        _insert_parents_first(
            dest_session, dest_table, rows, idx_name, fkey_name)
    else:
        dest_session.execute(dest_table.insert(), rows)
    if str(dest_session.bind.url).startswith('postgresql'):
        _sync_id_sequence(
            source_session, dest_session, source_table, dest_table)
def engine_from_settings(config, full_config=False):
    """Configure an engine from the 'assembl' app section of `config`.

    :param config: configuration file (paster config URI) to read.
    :param full_config: when true, bootstrap the Pyramid application and
        configure zmq changes, indexing, the model watcher and logging; the
        engine then gets no explicit session maker. When false, only a
        zope-transaction session maker is created and handed to the engine.
    :returns: ``(metadata, session)``: the model metadata, bound to the new
        engine, and a fresh plain SQLAlchemy session on that engine.
    """
    settings = get_appsettings(config, 'assembl')
    set_config(settings, True)
    session_maker = None
    if full_config:
        env = bootstrap(config)
        configure_zmq(settings['changes_socket'], False)
        configure_indexing()
        configure_model_watcher(env['registry'], 'assembl')
        logging.config.fileConfig(config)
    else:
        session_maker = make_session_maker(zope_tr=True)
    # Imported for its side effect; presumably loading the model modules is
    # what registers their tables on the metadata returned below.
    import assembl.models  # noqa: F401
    engine = configure_engine(settings, session_maker=session_maker)
    metadata = get_metadata()
    metadata.bind = engine
    # A plain sessionmaker session: the caller decides when to commit.
    session = sessionmaker(engine)()
    return (metadata, session)
def copy_database(source_config, dest_config):
    """Replace all data of the destination database with the source's data.

    Both databases are configured from their config files (the destination
    with full application bootstrap). Every destination table is emptied,
    dependents first, and then refilled from the source in dependency order.
    The whole wipe-and-copy is committed once at the end.
    """
    dest_metadata, dest_session = engine_from_settings(dest_config, True)
    ordered_tables = dest_metadata.sorted_tables
    source_metadata, source_session = engine_from_settings(
        source_config, False)
    source_tables_by_name = dict(
        (tbl.name, tbl.tometadata(source_metadata, source_metadata.schema))
        for tbl in ordered_tables)

    # Wipe dependents before the tables they reference. Self-references are
    # cleared first so the DELETE does not trip over rows pointing at rows
    # of the same table.
    for tbl in ordered_tables[::-1]:
        self_ref = recursive_tables.get(tbl.name)
        if self_ref is not None:
            dest_session.execute(update(tbl).values(**{self_ref[1]: None}))
        dest_session.execute(delete(tbl))

    # Refill referenced tables before the tables that point at them.
    for tbl in ordered_tables:
        copy_table(
            source_session, dest_session,
            source_tables_by_name[tbl.name], tbl)
    dest_session.commit()
if __name__ == '__main__':
    import sys

    parser = argparse.ArgumentParser(
        description='Copy database. Will wipe the destination database.')
    parser.add_argument(
        "source_config",
        help="""configuration file with source database configuration.""")
    parser.add_argument(
        "dest_config",
        help="""configuration file with target database configuration.""")
    parser.add_argument("--debug", action="store_true", default=False,
                        help="enter pdb on failure")
    args = parser.parse_args()
    # Validate with parser.error instead of assert: asserts are stripped
    # under `python -O`, which would let identical source and destination
    # through to a destructive copy.
    if args.source_config == args.dest_config:
        parser.error("source and destination must be different!")
    try:
        copy_database(args.source_config, args.dest_config)
    except Exception:
        traceback.print_exc()
        if args.debug:
            pdb.post_mortem()
        # Report the failure to the shell instead of exiting with status 0.
        sys.exit(1)