Skip to content

Commit dab8bd7

Browse files
committed
add tagging func
1 parent a5efab6 commit dab8bd7

3 files changed

Lines changed: 250 additions & 33 deletions

File tree

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,7 @@ log_useful/
144144
# GreaseLM running generate
145145
filtered_concept.txt
146146
matcher_res.json
147+
148+
# Test code
149+
*.ipynb
150+

preprocess.py

Lines changed: 55 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from preprocess_utils.conceptnet import extract_english, construct_graph
66
from preprocess_utils.grounding import create_matcher_patterns, ground
77
from preprocess_utils.graph import generate_adj_data_from_grounded_concepts__use_LM
8-
8+
from preprocess_utils.tagging import tag
99
input_paths = {
1010
'csqa': {
1111
'train': './data/csqa/train_rand_split.jsonl',
@@ -46,6 +46,11 @@
4646
'adj-dev': './data/csqa/graph/dev.graph.adj.pk',
4747
'adj-test': './data/csqa/graph/test.graph.adj.pk',
4848
},
49+
'tagged':{
50+
'train': './data/csqa/tagged/train.tagged.jsonl',
51+
'dev': './data/csqa/tagged/dev.tagged.jsonl',
52+
'test': './data/csqa/tagged/test.tagged.jsonl',
53+
},
4954
},
5055
'obqa': {
5156
'statement': {
@@ -61,6 +66,11 @@
6166
'dev': './data/obqa/grounded/dev.grounded.jsonl',
6267
'test': './data/obqa/grounded/test.grounded.jsonl',
6368
},
69+
'tagged':{
70+
'train': './data/obqa/tagged/train.tagged.jsonl',
71+
'dev': './data/obqa/tagged/dev.tagged.jsonl',
72+
'test': './data/obqa/tagged/test.tagged.jsonl',
73+
},
6474
'graph': {
6575
'adj-train': './data/obqa/graph/train.graph.adj.pk',
6676
'adj-dev': './data/obqa/graph/dev.graph.adj.pk',
@@ -81,42 +91,54 @@ def main():
8191
raise NotImplementedError()
8292

8393
routines = {
84-
'common': [
85-
{'func': extract_english, 'args': (input_paths['cpnet']['csv'], output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'])},
86-
{'func': construct_graph, 'args': (output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'],
87-
output_paths['cpnet']['unpruned-graph'], False)},
88-
{'func': construct_graph, 'args': (output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'],
89-
output_paths['cpnet']['pruned-graph'], True)},
90-
{'func': create_matcher_patterns, 'args': (output_paths['cpnet']['vocab'], output_paths['cpnet']['patterns'])},
91-
],
94+
# 'common': [
95+
# {'func': extract_english, 'args': (input_paths['cpnet']['csv'], output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'])},
96+
# {'func': construct_graph, 'args': (output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'],
97+
# output_paths['cpnet']['unpruned-graph'], False)},
98+
# {'func': construct_graph, 'args': (output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'],
99+
# output_paths['cpnet']['pruned-graph'], True)},
100+
# {'func': create_matcher_patterns, 'args': (output_paths['cpnet']['vocab'], output_paths['cpnet']['patterns'])},
101+
# ],
92102
'csqa': [
93-
{'func': convert_to_entailment, 'args': (input_paths['csqa']['train'], output_paths['csqa']['statement']['train'])},
94-
{'func': convert_to_entailment, 'args': (input_paths['csqa']['dev'], output_paths['csqa']['statement']['dev'])},
95-
{'func': convert_to_entailment, 'args': (input_paths['csqa']['test'], output_paths['csqa']['statement']['test'])},
96-
{'func': ground, 'args': (output_paths['csqa']['statement']['train'], output_paths['cpnet']['vocab'],
97-
output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['train'], args.nprocs)},
98-
{'func': ground, 'args': (output_paths['csqa']['statement']['dev'], output_paths['cpnet']['vocab'],
99-
output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['dev'], args.nprocs)},
100-
{'func': ground, 'args': (output_paths['csqa']['statement']['test'], output_paths['cpnet']['vocab'],
101-
output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['test'], args.nprocs)},
102-
{'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['train'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-train'], args.nprocs)},
103-
{'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['dev'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-dev'], args.nprocs)},
104-
{'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['test'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-test'], args.nprocs)},
103+
# {'func': convert_to_entailment, 'args': (input_paths['csqa']['train'], output_paths['csqa']['statement']['train'])},
104+
# {'func': convert_to_entailment, 'args': (input_paths['csqa']['dev'], output_paths['csqa']['statement']['dev'])},
105+
# {'func': convert_to_entailment, 'args': (input_paths['csqa']['test'], output_paths['csqa']['statement']['test'])},
106+
# {'func': ground, 'args': (output_paths['csqa']['statement']['train'], output_paths['cpnet']['vocab'],
107+
# output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['train'], args.nprocs)},
108+
# {'func': ground, 'args': (output_paths['csqa']['statement']['dev'], output_paths['cpnet']['vocab'],
109+
# output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['dev'], args.nprocs)},
110+
# {'func': ground, 'args': (output_paths['csqa']['statement']['test'], output_paths['cpnet']['vocab'],
111+
# output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['test'], args.nprocs)},
112+
{'func': tag, 'args': (output_paths['csqa']['statement']['train'], output_paths['cpnet']['vocab'],
113+
output_paths['cpnet']['patterns'], output_paths['csqa']['tagged']['train'], args.nprocs)},
114+
{'func': tag, 'args': (output_paths['csqa']['statement']['dev'], output_paths['cpnet']['vocab'],
115+
output_paths['cpnet']['patterns'], output_paths['csqa']['tagged']['dev'], args.nprocs)},
116+
{'func': tag, 'args': (output_paths['csqa']['statement']['test'], output_paths['cpnet']['vocab'],
117+
output_paths['cpnet']['patterns'], output_paths['csqa']['tagged']['test'], args.nprocs)},
118+
# {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['train'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-train'], args.nprocs)},
119+
# {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['dev'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-dev'], args.nprocs)},
120+
# {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['test'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-test'], args.nprocs)},
105121
],
106122

107123
'obqa': [
108-
{'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['train'], output_paths['obqa']['statement']['train'], output_paths['obqa']['statement']['train-fairseq'])},
109-
{'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['dev'], output_paths['obqa']['statement']['dev'], output_paths['obqa']['statement']['dev-fairseq'])},
110-
{'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['test'], output_paths['obqa']['statement']['test'], output_paths['obqa']['statement']['test-fairseq'])},
111-
{'func': ground, 'args': (output_paths['obqa']['statement']['train'], output_paths['cpnet']['vocab'],
112-
output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['train'], args.nprocs)},
113-
{'func': ground, 'args': (output_paths['obqa']['statement']['dev'], output_paths['cpnet']['vocab'],
114-
output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['dev'], args.nprocs)},
115-
{'func': ground, 'args': (output_paths['obqa']['statement']['test'], output_paths['cpnet']['vocab'],
116-
output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['test'], args.nprocs)},
117-
{'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['train'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-train'], args.nprocs)},
118-
{'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['dev'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-dev'], args.nprocs)},
119-
{'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['test'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-test'], args.nprocs)},
124+
# {'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['train'], output_paths['obqa']['statement']['train'], output_paths['obqa']['statement']['train-fairseq'])},
125+
# {'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['dev'], output_paths['obqa']['statement']['dev'], output_paths['obqa']['statement']['dev-fairseq'])},
126+
# {'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['test'], output_paths['obqa']['statement']['test'], output_paths['obqa']['statement']['test-fairseq'])},
127+
# {'func': ground, 'args': (output_paths['obqa']['statement']['train'], output_paths['cpnet']['vocab'],
128+
# output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['train'], args.nprocs)},
129+
# {'func': ground, 'args': (output_paths['obqa']['statement']['dev'], output_paths['cpnet']['vocab'],
130+
# output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['dev'], args.nprocs)},
131+
# {'func': ground, 'args': (output_paths['obqa']['statement']['test'], output_paths['cpnet']['vocab'],
132+
# output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['test'], args.nprocs)},
133+
{'func': tag, 'args': (output_paths['obqa']['statement']['train'], output_paths['cpnet']['vocab'],
134+
output_paths['cpnet']['patterns'], output_paths['obqa']['tagged']['train'], args.nprocs)},
135+
{'func': tag, 'args': (output_paths['obqa']['statement']['dev'], output_paths['cpnet']['vocab'],
136+
output_paths['cpnet']['patterns'], output_paths['obqa']['tagged']['dev'], args.nprocs)},
137+
{'func': tag, 'args': (output_paths['obqa']['statement']['test'], output_paths['cpnet']['vocab'],
138+
output_paths['cpnet']['patterns'], output_paths['obqa']['tagged']['test'], args.nprocs)},
139+
# {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['train'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-train'], args.nprocs)},
140+
# {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['dev'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-dev'], args.nprocs)},
141+
# {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['test'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-test'], args.nprocs)},
120142
],
121143
}
122144

preprocess_utils/tagging.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
from multiprocessing import Pool
2+
import spacy
3+
from spacy.matcher import Matcher
4+
from tqdm import tqdm
5+
import nltk
6+
import json
7+
import string
8+
import re
9+
__all__ = ['create_matcher_patterns', 'ground']
10+
11+
12+
# the lemma of it/them/mine/.. is -PRON-
13+
14+
blacklist = set(["-PRON-", "actually", "likely", "possibly", "want",
15+
"make", "my", "someone", "sometimes_people", "sometimes", "would", "want_to",
16+
"one", "something", "sometimes", "everybody", "somebody", "could", "could_be"
17+
])
18+
19+
20+
nltk.download('stopwords', quiet=True)
21+
nltk_stopwords = nltk.corpus.stopwords.words('english')
22+
23+
# CHUNK_SIZE = 1
24+
25+
CPNET_VOCAB = None
26+
PATTERN_PATH = None
27+
nlp = None
28+
matcher = None
29+
30+
31+
def load_cpnet_vocab(cpnet_vocab_path):
32+
with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
33+
cpnet_vocab = [l.strip() for l in fin]
34+
cpnet_vocab = [c.replace("_", " ") for c in cpnet_vocab]
35+
return cpnet_vocab
36+
37+
def lemmatize(nlp, concept):
38+
39+
doc = nlp(concept.replace("_", " "))
40+
lcs = set()
41+
lcs.add("_".join([token.lemma_ for token in doc])) # all lemma
42+
return lcs
43+
44+
def load_matcher(nlp, pattern_path):
45+
with open(pattern_path, "r", encoding="utf8") as fin:
46+
all_patterns = json.load(fin)
47+
matcher = Matcher(nlp.vocab)
48+
# print('get the matcher')
49+
for concept, pattern in tqdm(all_patterns.items()):
50+
matcher.add(concept, [pattern])
51+
return matcher
52+
53+
def get_concept_position(sents, answers,stems,num_processes):
54+
res = []
55+
with Pool(num_processes) as p:
56+
res = list(tqdm(p.imap(tag_qa_pair, zip(sents, answers,stems)), total=len(sents)))
57+
return res
58+
59+
def tag_qa_pair(qa_pair):
60+
61+
# global nlp, matcher
62+
63+
sents,answers,stem = qa_pair
64+
sent_pair,stem_pair,ans_pair = [],[],[]
65+
for s in sents:
66+
pos_pair = tag_concepts_pos(s,nlp,matcher)
67+
sent_pair.append(pos_pair)
68+
for a in answers:
69+
pos_pair = tag_concepts_pos(a,nlp,matcher)
70+
ans_pair.append(pos_pair)
71+
stem_pair = tag_concepts_pos(stem,nlp,matcher)
72+
res = {
73+
'statements':sent_pair,
74+
'answers':ans_pair,
75+
'stem':stem_pair
76+
}
77+
return res
78+
79+
80+
81+
# def tag_concepts_pos(s,nlp,matcher):
82+
# s = s.lower()
83+
# doc = nlp(s)
84+
# matches = matcher(doc)
85+
# pair = set()
86+
# split_pair = set()
87+
# for match_id, start, end in matches:
88+
# span = doc[start:end].text
89+
# pair.add((start,end,span))
90+
# if end-start>1:
91+
# word_list = re.split(' |_',span)
92+
# if len(word_list) != end-start:
93+
# print(start,end,span,word_list)
94+
# return []
95+
# for i in range(end-start):
96+
# split_pair.add((start+i,start+i+1,word_list[i]))
97+
# # print(len(pair),len(split_pair))
98+
# pair= pair-split_pair
99+
# # print(pair)
100+
# return list(pair)
101+
102+
def prune(size,word_list):
103+
if len(word_list) != size: return False
104+
for i in range(size):
105+
if word_list[i] in nltk_stopwords:
106+
return False
107+
return True
108+
109+
110+
def tag_concepts_pos(s,nlp,matcher):
111+
s = s.lower()
112+
doc = nlp(s)
113+
matches = matcher(doc)
114+
pair = set()
115+
split_pair = set()
116+
for match_id, start, end in matches:
117+
span = doc[start:end].text
118+
word_list = span.split()
119+
size = end- start
120+
if prune(size,word_list):
121+
pair.add((start,end,span))
122+
if size >1 :
123+
for i in range(end-start):
124+
split_pair.add((start+i,start+i+1,word_list[i]))
125+
pair= pair-split_pair
126+
return list(pair)
127+
128+
129+
def tag(statement_path, cpnet_vocab_path, pattern_path, output_path, num_processes=1, debug=False):
130+
global PATTERN_PATH, CPNET_VOCAB
131+
if PATTERN_PATH is None:
132+
PATTERN_PATH = pattern_path
133+
CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)
134+
135+
global nlp, matcher
136+
if nlp is None or matcher is None:
137+
nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat'])
138+
nlp.add_pipe('sentencizer')
139+
matcher = load_matcher(nlp, PATTERN_PATH)
140+
141+
sents = []
142+
answers = []
143+
stems = []
144+
with open(statement_path, 'r') as fin:
145+
lines = [line for line in fin]
146+
147+
if debug:
148+
lines = lines[0:3]
149+
print(len(lines))
150+
for line in lines:
151+
sent_line = []
152+
ans_line = []
153+
if line == "":
154+
continue
155+
j = json.loads(line)
156+
for statement in j["statements"]:
157+
sent_line.append(statement["statement"])
158+
159+
for answer in j["question"]["choices"]:
160+
ans = answer['text']
161+
# ans = " ".join(answer['text'].split("_"))
162+
try:
163+
assert all([i != "_" for i in ans])
164+
except Exception:
165+
print(ans)
166+
ans_line.append(ans)
167+
sents.append(sent_line)
168+
answers.append(ans_line)
169+
stems.append(j['question']['stem'])
170+
171+
res = get_concept_position(sents, answers,stems,num_processes)
172+
173+
174+
# check_path(output_path)
175+
with open(output_path, 'w') as fout:
176+
for dic in res:
177+
fout.write(json.dumps(dic) + '\n')
178+
179+
print(f'grounded concepts saved to {output_path}')
180+
print()
181+
182+
if __name__ == "__main__":
183+
# create_matcher_patterns("../data/cpnet/concept.txt", "./matcher_res.txt", True)
184+
# ground("../data/statement/dev.statement.jsonl", "../data/cpnet/concept.txt", "../data/cpnet/matcher_patterns.json", "./ground_res.jsonl", 10, True)
185+
statement_path = '../data/csqa/statement/test.statement.jsonl'
186+
cpnet_vocab_path = '../data/cpnet/concept.txt'
187+
pattern_path = '../data/cpnet/matcher_patterns.json'
188+
output_path = '../data/obqa/tagged/test.tagged.jsonl'
189+
num_processes = 1
190+
debug=True
191+
tag(statement_path, cpnet_vocab_path, pattern_path, output_path, num_processes, debug)

0 commit comments

Comments
 (0)