33import pickle
44import os
55
6+
67import numpy as np
78import torch
89from tqdm import tqdm
910from transformers import (OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP , BERT_PRETRAINED_CONFIG_ARCHIVE_MAP ,
1011 XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP , ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP )
11- from transformers import (OpenAIGPTTokenizer , BertTokenizer , XLNetTokenizer , RobertaTokenizer )
12+ from transformers import (OpenAIGPTTokenizer , BertTokenizer , XLNetTokenizer , RobertaTokenizer , AutoTokenizer )
1213try :
1314 from transformers import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
1415 from transformers import AlbertTokenizer
1819from preprocess_utils import conceptnet
1920from utils import utils
2021
22+
23+
2124MODEL_CLASS_TO_NAME = {
2225 'gpt' : list (OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP .keys ()),
2326 'bert' : list (BERT_PRETRAINED_CONFIG_ARCHIVE_MAP .keys ()),
@@ -94,7 +97,7 @@ def __init__(self, train_statement_path, train_adj_path,
9497 test_statement_path , test_adj_path ,
9598 batch_size , eval_batch_size , device , model_name , max_node_num = 200 , max_seq_length = 128 ,
9699 is_inhouse = False , inhouse_train_qids_path = None ,
97- subsample = 1.0 , n_train = - 1 , debug = False , cxt_node_connects_all = False , kg = "cpnet" ):
100+ subsample = 1.0 , n_train = - 1 , debug = False , cxt_node_connects_all = False , kg = "cpnet" , emp = False , train_tagged_path = '' , dev_tagged_path = '' , test_tagged_path = '' ):
98101 super ().__init__ ()
99102 self .batch_size = batch_size
100103 self .eval_batch_size = eval_batch_size
@@ -105,13 +108,16 @@ def __init__(self, train_statement_path, train_adj_path,
105108 self .max_node_num = max_node_num
106109 self .debug_sample_size = 32
107110 self .cxt_node_connects_all = cxt_node_connects_all
111+
112+ # emp controls the embedding pooling process
113+ self .emp = emp
108114
109115 self .model_type = MODEL_NAME_TO_CLASS [model_name ]
110116 self .load_resources (kg )
111117
112118 # Load training data
113119 print ('train_statement_path' , train_statement_path )
114- self .train_qids , self .train_labels , self .train_encoder_data , train_concepts_by_sents_list = self .load_input_tensors (train_statement_path , max_seq_length )
120+ self .train_qids , self .train_labels , self .train_encoder_data , train_concepts_by_sents_list = self .load_input_tensors (train_statement_path , max_seq_length , train_tagged_path , emp )
115121
116122 num_choice = self .train_encoder_data [0 ].size (1 )
117123 self .num_choice = num_choice
@@ -123,7 +129,7 @@ def __init__(self, train_statement_path, train_adj_path,
123129 print ("Finish loading training data." )
124130
125131 # Load dev data
126- self .dev_qids , self .dev_labels , self .dev_encoder_data , dev_concepts_by_sents_list = self .load_input_tensors (dev_statement_path , max_seq_length )
132+ self .dev_qids , self .dev_labels , self .dev_encoder_data , dev_concepts_by_sents_list = self .load_input_tensors (dev_statement_path , max_seq_length , dev_tagged_path , emp )
127133 * self .dev_decoder_data , self .dev_adj_data = self .load_sparse_adj_data_with_contextnode (dev_adj_path , max_node_num , dev_concepts_by_sents_list )
128134 if not debug :
129135 assert all (len (self .dev_qids ) == len (self .dev_adj_data [0 ]) == x .size (0 ) for x in [self .dev_labels ] + self .dev_encoder_data + self .dev_decoder_data )
@@ -132,7 +138,7 @@ def __init__(self, train_statement_path, train_adj_path,
132138
133139 # Load test data
134140 if test_statement_path is not None :
135- self .test_qids , self .test_labels , self .test_encoder_data , test_concepts_by_sents_list = self .load_input_tensors (test_statement_path , max_seq_length )
141+ self .test_qids , self .test_labels ,self .test_encoder_data , test_concepts_by_sents_list = self .load_input_tensors (test_statement_path , max_seq_length , test_tagged_path , emp )
136142 * self .test_decoder_data , self .test_adj_data = self .load_sparse_adj_data_with_contextnode (test_adj_path , max_node_num , test_concepts_by_sents_list )
137143 if not debug :
138144 assert all (len (self .test_qids ) == len (self .test_adj_data [0 ]) == x .size (0 ) for x in [self .test_labels ] + self .test_encoder_data + self .test_decoder_data )
@@ -210,11 +216,13 @@ def test(self):
210216
211217 def load_resources (self , kg ):
212218 # Load the tokenizer
213- try :
214- tokenizer_class = {'bert' : BertTokenizer , 'xlnet' : XLNetTokenizer , 'roberta' : RobertaTokenizer , 'albert' : AlbertTokenizer }.get (self .model_type )
215- except :
216- tokenizer_class = {'bert' : BertTokenizer , 'xlnet' : XLNetTokenizer , 'roberta' : RobertaTokenizer }.get (self .model_type )
217- tokenizer = tokenizer_class .from_pretrained (self .model_name )
219+ # try:
220+ # tokenizer_class = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer, 'albert': AlbertTokenizer}.get(self.model_type)
221+ # except:
222+ # tokenizer_class = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer}.get(self.model_type)
223+ # Use AutoTokenizer with use_fast=True; the tagging code relies on encoded_input.words() for word alignment
224+ tokenizer = AutoTokenizer .from_pretrained (self .model_name ,use_fast = True )
225+ #tokenizer = tokenizer_class.from_pretrained(self.model_name)
218226 self .tokenizer = tokenizer
219227
220228 if kg == "cpnet" :
@@ -249,14 +257,21 @@ def load_resources(self, kg):
249257 else :
250258 raise ValueError ("Invalid value for kg." )
251259
252- def load_input_tensors (self , input_jsonl_path , max_seq_length ):
260+ def load_input_tensors (self , input_jsonl_path , max_seq_length , tagged_jsonl_path , emp ):
253261 """Construct input tensors for the LM component of the model."""
254- cache_path = input_jsonl_path + "-sl{}" .format (max_seq_length ) + (("-" + self .model_type ) if self .model_type != "roberta" else "" ) + '.loaded_cache'
262+ if emp :
263+ cache_path = input_jsonl_path + "-sl{}" .format (max_seq_length ) + (("-" + self .model_type ) if self .model_type != "roberta" else "" ) + '-tag' + '.loaded_cache'
264+ else :
265+ cache_path = input_jsonl_path + "-sl{}" .format (max_seq_length ) + (("-" + self .model_type ) if self .model_type != "roberta" else "" ) + '.loaded_cache'
266+
255267 use_cache = True
256268
257269 if use_cache and not os .path .exists (cache_path ):
258270 use_cache = False
259271
272+ #debug
273+ # use_cache = False
274+
260275 if use_cache :
261276 with open (cache_path , 'rb' ) as f :
262277 input_tensors = utils .CPU_Unpickler (f ).load ()
@@ -266,7 +281,7 @@ def load_input_tensors(self, input_jsonl_path, max_seq_length):
266281 elif self .model_type in ('gpt' ,):
267282 input_tensors = load_gpt_input_tensors (input_jsonl_path , max_seq_length )
268283 elif self .model_type in ('bert' , 'xlnet' , 'roberta' , 'albert' ):
269- input_tensors = load_bert_xlnet_roberta_input_tensors (input_jsonl_path , max_seq_length , self .debug , self .tokenizer , self .debug_sample_size )
284+ input_tensors = load_bert_xlnet_roberta_input_tensors (input_jsonl_path , max_seq_length , self .debug , self .tokenizer , self .debug_sample_size , tagged_jsonl_path , emp )
270285
271286 if not self .debug :
272287 utils .save_pickle (input_tensors , cache_path )
@@ -508,7 +523,7 @@ def tokenize_and_encode(tokenizer, obj):
508523 return examples_ids , mc_labels , input_ids , mc_token_ids , lm_labels
509524
510525
511- def load_bert_xlnet_roberta_input_tensors (statement_jsonl_path , max_seq_length , debug , tokenizer , debug_sample_size ):
526+ def load_bert_xlnet_roberta_input_tensors (statement_jsonl_path , max_seq_length , debug , tokenizer , debug_sample_size , tagged_jsonl_path , emp ):
512527 class InputExample (object ):
513528
514529 def __init__ (self , example_id , question , contexts , endings , label = None ):
@@ -528,8 +543,9 @@ def __init__(self, example_id, choices_features, label):
528543 'input_mask' : input_mask ,
529544 'segment_ids' : segment_ids ,
530545 'output_mask' : output_mask ,
546+ 'pool_mask' :pool_mask
531547 }
532- for input_ids , input_mask , segment_ids , output_mask in choices_features
548+ for input_ids , input_mask , segment_ids , output_mask , pool_mask in choices_features
533549 ]
534550 self .label = label
535551
@@ -554,8 +570,52 @@ def read_examples(input_file):
554570 label = label
555571 ))
556572 return examples
573+
def read_tagged_file(input_file):
    """Load the tagged-span annotations stored at ``input_file``.

    The file is parsed as ONE JSON document (``json.load``), not as
    line-delimited JSON, even though callers in this file pass paths ending
    in ``.jsonl``.  The caller indexes the result as
    ``tagged[ex_index]['answers'][ending_idx]`` and
    ``tagged[ex_index]['stem']``, so the document is expected to be a
    sequence of per-example mappings with those keys.

    Args:
        input_file: Path of the JSON file to read.

    Returns:
        The decoded JSON content, unchanged.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Imported locally so this helper does not depend on a module-level
    # ``import json`` that is not visible from this part of the file.
    import json

    # A context manager guarantees the handle is closed (the previous version
    # leaked it); JSON text is UTF-8, so do not rely on the locale default.
    with open(input_file, "r", encoding="utf-8") as f:
        return json.load(f)
586+
587+
def get_pool_mask(encoded_input_words, context, ending, tagged_ending, tagged_context, pool_mask):
    """Mark the token positions that belong to multi-word tagged spans.

    ``encoded_input_words`` is walked left to right.  Positions are first
    matched against the context spans; after the first pair of consecutive
    ``None`` entries they are matched against the ending spans; the walk
    stops at the next pair of consecutive ``None`` entries.

    NOTE(review): the segment switch needs two adjacent ``None`` entries
    (presumably two adjacent special tokens between the segments).  A
    tokenizer that emits a single separator would not switch at the segment
    border -- confirm for every supported model type.

    Args:
        encoded_input_words: Per-token word index, ``None`` for tokens that
            map to no word (as produced by ``encoded_input.words()``).
        context: Unused; kept for interface compatibility.
        ending: Unused; kept for interface compatibility.
        tagged_ending: Iterable of spans ``[start, end, ...]`` (end
            exclusive) for the second segment.
        tagged_context: Iterable of spans ``[start, end, ...]`` (end
            exclusive) for the first segment.
        pool_mask: List of ints updated IN PLACE; matching positions are set
            to 1.

    Returns:
        The same ``pool_mask`` object that was passed in.
    """

    def get_pool_set(items):
        # Only spans covering more than one word are pooled
        # (an item looks like [11, 12, 'afford'] -> width 1 -> ignored).
        pool_set = set()
        for item in items:
            start, end = item[0], item[1]
            if end - start > 1:
                pool_set.update(range(start, end))
        return pool_set

    end_pool_set = get_pool_set(tagged_ending)
    context_pool_set = get_pool_set(tagged_context)

    num_words = len(encoded_input_words)
    # Never index past either sequence.
    limit = min(len(pool_mask), num_words)

    in_context = True
    for i in range(limit):
        word_id = encoded_input_words[i]
        active_set = context_pool_set if in_context else end_pool_set
        if word_id in active_set:
            pool_mask[i] = 1

        # Look ahead only when a next element exists: the previous version
        # raised IndexError when the final position was a special token
        # (e.g. a truncated sequence with no padding).
        at_boundary = (
            word_id is None
            and i + 1 < num_words
            and encoded_input_words[i + 1] is None
        )
        if at_boundary:
            if not in_context:
                break
            in_context = False

    return pool_mask
557617
558- def simple_convert_examples_to_features (examples , label_list , max_seq_length , tokenizer ):
618+ def simple_convert_examples_to_features (examples , label_list , max_seq_length , tokenizer , emp , tagged ):
559619 """ Loads a data file into a list of `InputBatch`s
560620 `cls_token_at_end` define the location of the CLS token:
561621 - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
@@ -572,20 +632,35 @@ def simple_convert_examples_to_features(examples, label_list, max_seq_length, to
572632 choices_features = []
573633 for ending_idx , (context , ending ) in enumerate (zip (example .contexts , example .endings )):
574634 ans = example .question + " " + ending
575-
635+ if emp :
636+ tagged_ending = tagged [ex_index ]['answers' ][ending_idx ]
637+ tagged_context = tagged [ex_index ]['stem' ]
638+
576639 encoded_input = tokenizer (context , ans , padding = "max_length" , truncation = True , max_length = max_seq_length , return_token_type_ids = True , return_special_tokens_mask = True )
640+ # print(encoded_input.words())
577641 input_ids = encoded_input ["input_ids" ]
578642 output_mask = encoded_input ["special_tokens_mask" ]
579643 input_mask = encoded_input ["attention_mask" ]
580644 segment_ids = encoded_input ["token_type_ids" ]
581- # print(context,'\n',ans,'\n',encoded_input["input_ids"])
582- # input()
645+ pool_mask = [0 ]* max_seq_length
646+ if emp :
647+ pool_mask = get_pool_mask (encoded_input .words (),context ,ending ,tagged_ending ,tagged_context ,pool_mask )
648+
649+
583650 assert len (input_ids ) == max_seq_length
584651 assert len (output_mask ) == max_seq_length
585652 assert len (input_mask ) == max_seq_length
586653 assert len (segment_ids ) == max_seq_length
587654
588- choices_features .append ((input_ids , input_mask , segment_ids , output_mask ))
655+ choices_features .append ((input_ids , input_mask , segment_ids , output_mask ,pool_mask ))
656+
657+ # for i in range(max_seq_length):
658+ # if input_ids[i] == 1:
659+ # print(len(context.split())+len(ans.split()),i)
660+ # break
661+
662+ # input()
663+
589664 label = label_map [example .label ]
590665 features .append (InputFeatures (example_id = example .example_id , choices_features = choices_features , label = label ))
591666
@@ -604,8 +679,34 @@ def convert_features_to_tensors(features):
604679 return all_input_ids , all_input_mask , all_segment_ids , all_output_mask , all_label
605680
606681 examples = read_examples (statement_jsonl_path )
607- features , concepts_by_sents_list = simple_convert_examples_to_features (examples , list (range (len (examples [0 ].endings ))), max_seq_length , tokenizer )
682+ tagged = []
683+ if emp :
684+ tagged = read_tagged_file (tagged_jsonl_path )
685+
686+
687+ features , concepts_by_sents_list = simple_convert_examples_to_features (examples , list (range (len (examples [0 ].endings ))), max_seq_length , tokenizer ,emp ,tagged )
608688
609689 example_ids = [f .example_id for f in features ]
610690 * data_tensors , all_label = convert_features_to_tensors (features )
611691 return example_ids , all_label , data_tensors , concepts_by_sents_list
692+
if __name__ == "__main__":
    # Manual smoke test: build the LM input tensors for the OBQA test split
    # with a fast RoBERTa tokenizer and span pooling switched off.
    debug = False
    emp = False
    max_seq_length = 128
    debug_sample_size = 32

    statement_jsonl_path = './data/obqa/statement/test.statement.jsonl'
    tagged_jsonl_path = './data/obqa/tagged/test.tagged.jsonl'

    model_name = 'roberta-large'
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    example_ids, all_label, data_tensors, concepts_by_sents_list = load_bert_xlnet_roberta_input_tensors(
        statement_jsonl_path=statement_jsonl_path,
        max_seq_length=max_seq_length,
        debug=debug,
        tokenizer=tokenizer,
        debug_sample_size=debug_sample_size,
        tagged_jsonl_path=tagged_jsonl_path,
        emp=emp,
    )
712+
0 commit comments