# test.py
# Forked from thunlp/TensorFlow-Summarization.
import tensorflow as tf
import subprocess
import logging
import os
# --- Input locations ---------------------------------------------------------
# Directory that is listed to discover checkpoint files.
model_path = "model/"
# Checkpoint path template; filled with (model name, step number).
model_pattern = model_path + "{}.ckpt-{}"
# Test-set path template; filled with a dataset name.
data_pattern = "data/test.{}.txt"

# --- Output locations --------------------------------------------------------
OUTPUT_DIR = "output/"
# One result file per (dataset, description) pair inside OUTPUT_DIR.
OUTPUT_PATTERN = OUTPUT_DIR + "{dataset}.{description}.txt"

# --- Test matrix -------------------------------------------------------------
# Number of newest checkpoints (highest step first) that get tested.
MAX_KEEP = 1
# `datasets` and `geneos` are consumed pairwise: geneos[i] is the value handed
# to the "--geneos" option when datasets[i] is tested.
datasets = ["giga", "duc2003", "duc2004"]
geneos = [True, False, False]
# Beam sizes to try; each one is passed to the test script as "--batch_size".
beam_searchs = [1, 10]

# Extra command-line options appended to every invocation, each with the
# value True.
test_params = dict.fromkeys(("--decode", "--fast_decode"), True)
def select_checkpoints(filenames, max_keep):
    """Pick the newest checkpoints out of a directory listing.

    Only names of the exact form ``<name>.ckpt-<step>.index`` are considered;
    every other entry is ignored instead of raising.

    Args:
        filenames: Iterable of file names (not paths).
        max_keep: Maximum number of checkpoints to return.

    Returns:
        A list of ``[name, step]`` pairs (both strings), ordered by
        numerically descending step and truncated to ``max_keep`` entries.
    """
    found = []
    for filename in filenames:
        parts = filename.split('.')
        if len(parts) != 3 or parts[2] != "index":
            continue
        prefix, _, step = parts[1].partition('-')
        # The checkpoint path is later rebuilt as "<name>.ckpt-<step>", so a
        # middle token with any other prefix could not be addressed anyway.
        if prefix != "ckpt":
            continue
        try:
            int(step)
        except ValueError:
            # Non-numeric step: cannot be ordered, so it is not a candidate.
            continue
        found.append([parts[0], step])
    # Compare steps as integers so that "1000" ranks above "999".
    found.sort(key=lambda pair: int(pair[1]), reverse=True)
    return found[:max_keep]


def build_command(test_file, beam_size, output_file, geneos_flag, checkpoint,
                  extra_params):
    """Assemble the argv list for one run of ``src/summarization.py``.

    Args:
        test_file: Path passed as ``--test_file``.
        beam_size: Beam size; passed to the script as ``--batch_size``.
        output_file: Path passed as ``--test_output``.
        geneos_flag: Value passed (stringified) as ``--geneos``.
        checkpoint: Checkpoint path passed as ``--checkpoint``.
        extra_params: Mapping of additional option name -> value; each pair
            is appended as ``name str(value)`` in iteration order.

    Returns:
        The command as a list of strings, suitable for ``subprocess.call``.
    """
    command = ["python3", "src/summarization.py",
               "--test_file", test_file,
               "--batch_size", str(beam_size),
               "--test_output", output_file,
               "--geneos", str(geneos_flag),
               "--checkpoint", checkpoint]
    for option, value in extra_params.items():
        command.append(option)
        command.append(str(value))
    return command


def main():
    """Run the test script for every dataset / beam size / kept checkpoint.

    Lists ``model_path``, keeps the ``MAX_KEEP`` newest checkpoints and, for
    each combination of dataset and beam size, launches
    ``src/summarization.py`` in a subprocess. A combination whose output file
    already exists is skipped.
    """
    logging.basicConfig(level=logging.DEBUG,
                        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
                        datefmt='%b %d %H:%M')

    try:
        os.mkdir(OUTPUT_DIR)
    except OSError as err:
        # An already-existing directory is the expected case and stays quiet;
        # any other failure is reported, but the run is still attempted.
        if not os.path.isdir(OUTPUT_DIR):
            logging.warning("Could not create %s: %s", OUTPUT_DIR, err)

    models = select_checkpoints(os.listdir(model_path), MAX_KEEP)
    print(models)

    for name, step in models:
        ckpt = model_pattern.format(name, step)
        logging.info("Test {}. ".format(ckpt))
        for dataset, tag in zip(datasets, geneos):
            test_file = data_pattern.format(dataset)
            for beam_search in beam_searchs:
                logging.info("Test {} with beam_size = {}".format(
                    test_file, beam_search))
                output_file = OUTPUT_PATTERN.format(
                    dataset=dataset,
                    description=str(beam_search) + "_" + str(step))
                if os.path.exists(output_file):
                    logging.info("{} exists, skip testing".format(output_file))
                    continue
                command = build_command(test_file, beam_search, output_file,
                                        tag, ckpt, test_params)
                status = subprocess.call(command)
                if status != 0:
                    # Surface failures instead of silently moving on.
                    logging.warning("Command exited with status %d: %s",
                                    status, " ".join(command))


if __name__ == "__main__":
    main()