Skip to content

Commit 3692444

Browse files
committed
add tag process in dataloader
1 parent dab8bd7 commit 3692444

6 files changed

Lines changed: 865 additions & 28 deletions

File tree

data_utils.py

Lines changed: 712 additions & 0 deletions
Large diffs are not rendered by default.

greaselm.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@ def load_data(args, devices, kg):
6161
model_name=args.encoder,
6262
max_node_num=args.max_node_num, max_seq_length=args.max_seq_len,
6363
is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
64-
subsample=args.subsample, n_train=args.n_train, debug=args.debug, cxt_node_connects_all=args.cxt_node_connects_all, kg=kg)
64+
subsample=args.subsample, n_train=args.n_train, debug=args.debug, cxt_node_connects_all=args.cxt_node_connects_all, kg=kg,emp=args.emp,
65+
train_tagged_path = args.train_tagged,
66+
dev_tagged_path=args.dev_tagged,
67+
test_tagged_path=args.test_tagged,
68+
)
6569

6670
return dataset
6771

@@ -560,7 +564,8 @@ def main(args):
560564

561565
parser = parser_utils.get_parser()
562566
args, _ = parser.parse_known_args()
563-
567+
# print(args.train_statements,args.train_tagged)
568+
# input()
564569
# General
565570
parser.add_argument('--mode', default='train', choices=['train', 'eval'], help='run training or evaluation')
566571
parser.add_argument('--save_dir', default=f'./saved_models/greaselm/', help='model output directory')
@@ -609,5 +614,10 @@ def main(args):
609614
parser.add_argument('--refreeze_epoch', default=10000, type=int)
610615
parser.add_argument('--init_range', default=0.02, type=float, help='stddev when initializing with normal distribution')
611616

617+
# MyGLM
618+
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false`` or ``1``/``0``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line.

    Raises:
        ValueError: if ``value`` is not a recognised boolean spelling; argparse
            turns this into a regular "invalid value" usage error.
    """
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise ValueError(f'expected a boolean value, got {value!r}')

parser.add_argument('--emp', default=True, type=_str2bool)
619+
612620
args = parser.parse_args()
621+
# print(args.train_statements,args.train_tagged)
622+
613623
main(args)

preprocess_utils/tagging.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,10 @@ def tag(statement_path, cpnet_vocab_path, pattern_path, output_path, num_process
173173

174174
# check_path(output_path)
175175
with open(output_path, 'w') as fout:
176-
for dic in res:
177-
fout.write(json.dumps(dic) + '\n')
176+
# Write the whole result as a single JSON document instead of one JSON object per line
177+
fout.write(json.dumps(res))
178+
# for dic in res:
179+
# fout.write(json.dumps(dic) + '\n')
178180

179181
print(f'grounded concepts saved to {output_path}')
180182
print()
@@ -185,7 +187,7 @@ def tag(statement_path, cpnet_vocab_path, pattern_path, output_path, num_process
185187
statement_path = '../data/csqa/statement/test.statement.jsonl'
186188
cpnet_vocab_path = '../data/cpnet/concept.txt'
187189
pattern_path = '../data/cpnet/matcher_patterns.json'
188-
output_path = '../data/obqa/tagged/test.tagged.jsonl'
190+
output_path = '../data/obqa/tagged/test.jsonl'
189191
num_processes = 1
190192
debug=True
191193
tag(statement_path, cpnet_vocab_path, pattern_path, output_path, num_processes, debug)

run_greaselm.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
export TOKENIZERS_PARALLELISM=true
33
dt=`date '+%Y%m%d_%H%M%S'`
44

5-
65
dataset=$1
76
shift
87
encoder='roberta-large'

utils/data_utils.py

Lines changed: 122 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import pickle
44
import os
55

6+
67
import numpy as np
78
import torch
89
from tqdm import tqdm
910
from transformers import (OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
1011
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP)
11-
from transformers import (OpenAIGPTTokenizer, BertTokenizer, XLNetTokenizer, RobertaTokenizer)
12+
from transformers import (OpenAIGPTTokenizer, BertTokenizer, XLNetTokenizer, RobertaTokenizer,AutoTokenizer)
1213
try:
1314
from transformers import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
1415
from transformers import AlbertTokenizer
@@ -18,6 +19,8 @@
1819
from preprocess_utils import conceptnet
1920
from utils import utils
2021

22+
23+
2124
MODEL_CLASS_TO_NAME = {
2225
'gpt': list(OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
2326
'bert': list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
@@ -94,7 +97,7 @@ def __init__(self, train_statement_path, train_adj_path,
9497
test_statement_path, test_adj_path,
9598
batch_size, eval_batch_size, device, model_name, max_node_num=200, max_seq_length=128,
9699
is_inhouse=False, inhouse_train_qids_path=None,
97-
subsample=1.0, n_train=-1, debug=False, cxt_node_connects_all=False, kg="cpnet"):
100+
subsample=1.0, n_train=-1, debug=False, cxt_node_connects_all=False, kg="cpnet",emp=False,train_tagged_path='',dev_tagged_path = '',test_tagged_path = ''):
98101
super().__init__()
99102
self.batch_size = batch_size
100103
self.eval_batch_size = eval_batch_size
@@ -105,13 +108,16 @@ def __init__(self, train_statement_path, train_adj_path,
105108
self.max_node_num = max_node_num
106109
self.debug_sample_size = 32
107110
self.cxt_node_connects_all = cxt_node_connects_all
111+
112+
# emp controls the embedding pooling process
113+
self.emp = emp
108114

109115
self.model_type = MODEL_NAME_TO_CLASS[model_name]
110116
self.load_resources(kg)
111117

112118
# Load training data
113119
print ('train_statement_path', train_statement_path)
114-
self.train_qids, self.train_labels, self.train_encoder_data, train_concepts_by_sents_list = self.load_input_tensors(train_statement_path, max_seq_length)
120+
self.train_qids, self.train_labels, self.train_encoder_data, train_concepts_by_sents_list = self.load_input_tensors(train_statement_path, max_seq_length,train_tagged_path,emp)
115121

116122
num_choice = self.train_encoder_data[0].size(1)
117123
self.num_choice = num_choice
@@ -123,7 +129,7 @@ def __init__(self, train_statement_path, train_adj_path,
123129
print("Finish loading training data.")
124130

125131
# Load dev data
126-
self.dev_qids, self.dev_labels, self.dev_encoder_data, dev_concepts_by_sents_list = self.load_input_tensors(dev_statement_path, max_seq_length)
132+
self.dev_qids, self.dev_labels, self.dev_encoder_data, dev_concepts_by_sents_list = self.load_input_tensors(dev_statement_path, max_seq_length,dev_tagged_path,emp)
127133
*self.dev_decoder_data, self.dev_adj_data = self.load_sparse_adj_data_with_contextnode(dev_adj_path, max_node_num, dev_concepts_by_sents_list)
128134
if not debug:
129135
assert all(len(self.dev_qids) == len(self.dev_adj_data[0]) == x.size(0) for x in [self.dev_labels] + self.dev_encoder_data + self.dev_decoder_data)
@@ -132,7 +138,7 @@ def __init__(self, train_statement_path, train_adj_path,
132138

133139
# Load test data
134140
if test_statement_path is not None:
135-
self.test_qids, self.test_labels, self.test_encoder_data, test_concepts_by_sents_list = self.load_input_tensors(test_statement_path, max_seq_length)
141+
self.test_qids, self.test_labels,self.test_encoder_data, test_concepts_by_sents_list = self.load_input_tensors(test_statement_path, max_seq_length,test_tagged_path,emp)
136142
*self.test_decoder_data, self.test_adj_data = self.load_sparse_adj_data_with_contextnode(test_adj_path, max_node_num, test_concepts_by_sents_list)
137143
if not debug:
138144
assert all(len(self.test_qids) == len(self.test_adj_data[0]) == x.size(0) for x in [self.test_labels] + self.test_encoder_data + self.test_decoder_data)
@@ -210,11 +216,13 @@ def test(self):
210216

211217
def load_resources(self, kg):
212218
# Load the tokenizer
213-
try:
214-
tokenizer_class = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer, 'albert': AlbertTokenizer}.get(self.model_type)
215-
except:
216-
tokenizer_class = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer}.get(self.model_type)
217-
tokenizer = tokenizer_class.from_pretrained(self.model_name)
219+
# try:
220+
# tokenizer_class = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer, 'albert': AlbertTokenizer}.get(self.model_type)
221+
# except:
222+
# tokenizer_class = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer}.get(self.model_type)
223+
# Use AutoTokenizer (fast tokenizer) so the tagged word spans can be mapped onto tokens
224+
tokenizer = AutoTokenizer.from_pretrained(self.model_name,use_fast =True)
225+
#tokenizer = tokenizer_class.from_pretrained(self.model_name)
218226
self.tokenizer = tokenizer
219227

220228
if kg == "cpnet":
@@ -249,14 +257,21 @@ def load_resources(self, kg):
249257
else:
250258
raise ValueError("Invalid value for kg.")
251259

252-
def load_input_tensors(self, input_jsonl_path, max_seq_length):
260+
def load_input_tensors(self, input_jsonl_path, max_seq_length,tagged_jsonl_path,emp):
253261
"""Construct input tensors for the LM component of the model."""
254-
cache_path = input_jsonl_path + "-sl{}".format(max_seq_length) + (("-" + self.model_type) if self.model_type != "roberta" else "") + '.loaded_cache'
262+
if emp:
263+
cache_path = input_jsonl_path + "-sl{}".format(max_seq_length) + (("-" + self.model_type) if self.model_type != "roberta" else "") + '-tag' + '.loaded_cache'
264+
else:
265+
cache_path = input_jsonl_path + "-sl{}".format(max_seq_length) + (("-" + self.model_type) if self.model_type != "roberta" else "") + '.loaded_cache'
266+
255267
use_cache = True
256268

257269
if use_cache and not os.path.exists(cache_path):
258270
use_cache = False
259271

272+
#debug
273+
# use_cache = False
274+
260275
if use_cache:
261276
with open(cache_path, 'rb') as f:
262277
input_tensors = utils.CPU_Unpickler(f).load()
@@ -266,7 +281,7 @@ def load_input_tensors(self, input_jsonl_path, max_seq_length):
266281
elif self.model_type in ('gpt',):
267282
input_tensors = load_gpt_input_tensors(input_jsonl_path, max_seq_length)
268283
elif self.model_type in ('bert', 'xlnet', 'roberta', 'albert'):
269-
input_tensors = load_bert_xlnet_roberta_input_tensors(input_jsonl_path, max_seq_length, self.debug, self.tokenizer, self.debug_sample_size)
284+
input_tensors = load_bert_xlnet_roberta_input_tensors(input_jsonl_path, max_seq_length, self.debug, self.tokenizer, self.debug_sample_size,tagged_jsonl_path,emp)
270285

271286
if not self.debug:
272287
utils.save_pickle(input_tensors, cache_path)
@@ -508,7 +523,7 @@ def tokenize_and_encode(tokenizer, obj):
508523
return examples_ids, mc_labels, input_ids, mc_token_ids, lm_labels
509524

510525

511-
def load_bert_xlnet_roberta_input_tensors(statement_jsonl_path, max_seq_length, debug, tokenizer, debug_sample_size):
526+
def load_bert_xlnet_roberta_input_tensors(statement_jsonl_path, max_seq_length, debug, tokenizer, debug_sample_size,tagged_jsonl_path,emp):
512527
class InputExample(object):
513528

514529
def __init__(self, example_id, question, contexts, endings, label=None):
@@ -528,8 +543,9 @@ def __init__(self, example_id, choices_features, label):
528543
'input_mask': input_mask,
529544
'segment_ids': segment_ids,
530545
'output_mask': output_mask,
546+
'pool_mask':pool_mask
531547
}
532-
for input_ids, input_mask, segment_ids, output_mask in choices_features
548+
for input_ids, input_mask, segment_ids, output_mask,pool_mask in choices_features
533549
]
534550
self.label = label
535551

@@ -554,8 +570,52 @@ def read_examples(input_file):
554570
label=label
555571
))
556572
return examples
573+
574+
def read_tagged_file(input_file):
    """Load the tagged examples stored as one JSON document.

    Args:
        input_file: Path to the tagged file. The whole file is parsed with
            ``json.load``, so it must contain a single JSON value rather than
            one JSON object per line.

    Returns:
        The decoded JSON content, unchanged.
    """
    # The context manager closes the handle even when parsing fails; the
    # previous bare ``open`` call leaked it.
    with open(input_file, "r", encoding="utf-8") as f:
        return json.load(f)
586+
587+
588+
def get_pool_mask(encoded_input_words, context, ending, tagged_ending, tagged_context, pool_mask):
    """Flag the token positions that belong to multi-word tagged spans.

    The token sequence is scanned left to right. Until the first pair of
    consecutive ``None`` word ids is seen, tokens are matched against the
    spans in ``tagged_context``; after it they are matched against the spans
    in ``tagged_ending``. A second pair of consecutive ``None`` word ids ends
    the scan.

    Args:
        encoded_input_words: Per-token word indices; ``None`` marks a token
            that maps to no word (presumably special/padding tokens as
            returned by a fast tokenizer's ``words()`` -- confirm with caller).
        context: Unused; kept so existing callers keep working.
        ending: Unused; kept so existing callers keep working.
        tagged_ending: Items shaped like ``[start, end, text]`` for the part
            after the first boundary.
        tagged_context: Items shaped like ``[start, end, text]`` for the part
            before the first boundary.
        pool_mask: List of flags, updated in place (matching positions are
            set to 1).

    Returns:
        The same ``pool_mask`` list that was passed in.
    """
    def get_pool_set(items):
        # item example: [11, 12, 'afford'] -> word span [start, end).
        # Only spans covering more than one word are collected.
        pool_set = set()
        for item in items:
            start, end = item[0], item[1]
            if end - start > 1:
                pool_set.update(range(start, end))
        return pool_set

    end_pool_set = get_pool_set(tagged_ending)
    context_pool_set = get_pool_set(tagged_context)
    num_tokens = len(encoded_input_words)

    def is_boundary(i):
        # Two consecutive tokens without a word id. The bounds check prevents
        # an IndexError on the final token when the sequence has no padding
        # (e.g. it was truncated to the maximum length).
        return (encoded_input_words[i] is None
                and i + 1 < num_tokens
                and encoded_input_words[i + 1] is None)

    in_context = True
    for i in range(min(len(pool_mask), num_tokens)):
        active_set = context_pool_set if in_context else end_pool_set
        if encoded_input_words[i] in active_set:
            pool_mask[i] = 1
        if is_boundary(i):
            if not in_context:
                # Second boundary: everything after it is ignored.
                break
            in_context = False
    return pool_mask
557617

558-
def simple_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
618+
def simple_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer,emp,tagged):
559619
""" Loads a data file into a list of `InputBatch`s
560620
`cls_token_at_end` define the location of the CLS token:
561621
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
@@ -572,20 +632,35 @@ def simple_convert_examples_to_features(examples, label_list, max_seq_length, to
572632
choices_features = []
573633
for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
574634
ans = example.question + " " + ending
575-
635+
if emp:
636+
tagged_ending = tagged[ex_index]['answers'][ending_idx]
637+
tagged_context = tagged[ex_index]['stem']
638+
576639
encoded_input = tokenizer(context, ans, padding="max_length", truncation=True, max_length=max_seq_length, return_token_type_ids=True, return_special_tokens_mask=True)
640+
# print(encoded_input.words())
577641
input_ids = encoded_input["input_ids"]
578642
output_mask = encoded_input["special_tokens_mask"]
579643
input_mask = encoded_input["attention_mask"]
580644
segment_ids = encoded_input["token_type_ids"]
581-
# print(context,'\n',ans,'\n',encoded_input["input_ids"])
582-
# input()
645+
pool_mask = [0]*max_seq_length
646+
if emp:
647+
pool_mask = get_pool_mask(encoded_input.words(),context,ending,tagged_ending,tagged_context,pool_mask)
648+
649+
583650
assert len(input_ids) == max_seq_length
584651
assert len(output_mask) == max_seq_length
585652
assert len(input_mask) == max_seq_length
586653
assert len(segment_ids) == max_seq_length
587654

588-
choices_features.append((input_ids, input_mask, segment_ids, output_mask))
655+
choices_features.append((input_ids, input_mask, segment_ids, output_mask,pool_mask))
656+
657+
# for i in range(max_seq_length):
658+
# if input_ids[i] == 1:
659+
# print(len(context.split())+len(ans.split()),i)
660+
# break
661+
662+
# input()
663+
589664
label = label_map[example.label]
590665
features.append(InputFeatures(example_id=example.example_id, choices_features=choices_features, label=label))
591666

@@ -604,8 +679,34 @@ def convert_features_to_tensors(features):
604679
return all_input_ids, all_input_mask, all_segment_ids, all_output_mask, all_label
605680

606681
examples = read_examples(statement_jsonl_path)
607-
features, concepts_by_sents_list = simple_convert_examples_to_features(examples, list(range(len(examples[0].endings))), max_seq_length, tokenizer)
682+
tagged = []
683+
if emp:
684+
tagged = read_tagged_file(tagged_jsonl_path)
685+
686+
687+
features, concepts_by_sents_list = simple_convert_examples_to_features(examples, list(range(len(examples[0].endings))), max_seq_length, tokenizer,emp,tagged)
608688

609689
example_ids = [f.example_id for f in features]
610690
*data_tensors, all_label = convert_features_to_tensors(features)
611691
return example_ids, all_label, data_tensors, concepts_by_sents_list
692+
693+
if __name__ == "__main__":
694+
695+
696+
697+
model_name = 'roberta-large'
698+
tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=True)
699+
# model_name = 'bert-base-uncased'
700+
# tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=True)
701+
702+
703+
statement_jsonl_path = './data/obqa/statement/test.statement.jsonl'
704+
max_seq_length = 128
705+
debug = False
706+
emp = False
707+
708+
debug_sample_size = 32
709+
tagged_jsonl_path = './data/obqa/tagged/test.tagged.jsonl'
710+
711+
example_ids, all_label, data_tensors, concepts_by_sents_list=load_bert_xlnet_roberta_input_tensors(statement_jsonl_path, max_seq_length, debug, tokenizer, debug_sample_size,tagged_jsonl_path,emp)
712+

utils/parser_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ def add_data_arguments(parser):
5353
parser.add_argument('--train_statements', default='{data_dir}/{dataset}/statement/train.statement.jsonl')
5454
parser.add_argument('--dev_statements', default='{data_dir}/{dataset}/statement/dev.statement.jsonl')
5555
parser.add_argument('--test_statements', default='{data_dir}/{dataset}/statement/test.statement.jsonl')
56+
57+
# tagged
58+
parser.add_argument('--train_tagged', default='{data_dir}/{dataset}/tagged/train.tagged.jsonl')
59+
parser.add_argument('--dev_tagged', default='{data_dir}/{dataset}/tagged/dev.tagged.jsonl')
60+
parser.add_argument('--test_tagged', default='{data_dir}/{dataset}/tagged/test.tagged.jsonl')
61+
62+
5663
# preprocessing options
5764
parser.add_argument('-sl', '--max_seq_len', default=100, type=int)
5865
# set dataset defaults
@@ -62,11 +69,16 @@ def add_data_arguments(parser):
6269
inhouse_train_qids=args.inhouse_train_qids.format(dataset=args.dataset))
6370
data_splits = ('train', 'dev') if args.dataset in DATASET_NO_TEST else ('train', 'dev', 'test')
6471
for split in data_splits:
65-
for attribute in ('statements',):
72+
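# 'tagged' must be listed here; otherwise the {data_dir} placeholder in the tagged file paths is never substituted and the path stays as shown below: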
73+
#{data_dir}/{dataset}/tagged/train.tagged.jsonl
74+
for attribute in ('statements','tagged'):
6675
attr_name = f'{split}_{attribute}'
6776
parser.set_defaults(**{attr_name: getattr(args, attr_name).format(dataset=args.dataset, data_dir=args.data_dir)})
6877
if 'test' not in data_splits:
6978
parser.set_defaults(test_statements=None)
79+
# args, _ = parser.parse_known_args()
80+
# print(args.train_statements,args.train_tagged)
81+
# input()
7082

7183

7284
def add_encoder_arguments(parser):
@@ -106,4 +118,5 @@ def get_parser():
106118
add_encoder_arguments(parser)
107119
add_optimization_arguments(parser)
108120
add_additional_arguments(parser)
121+
109122
return parser

0 commit comments

Comments
 (0)