3232
3333MODEL_NAME_TO_CLASS = {model_name : model_class for model_class , model_name_list in MODEL_CLASS_TO_NAME .items () for model_name in model_name_list }
3434
35+ # Register the SapBERT checkpoint name so that it maps to the 'bert' model class
36+ model_name = 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext'
37+ MODEL_NAME_TO_CLASS [model_name ] = 'bert'
38+
3539GPT_SPECIAL_TOKENS = ['_start_' , '_delimiter_' , '_classify_' ]
3640
3741
3842class MultiGPUSparseAdjDataBatchGenerator (object ):
3943 """A data generator that batches the data and moves them to the corresponding devices."""
40- def __init__ (self , device0 , device1 , batch_size , indexes , qids , labels ,
44+ def __init__ (self , device0 , device1 , batch_size , indexes , qids , labels ,
4145 tensors0 = [], lists0 = [], tensors1 = [], lists1 = [], adj_data = None ):
4246 self .device0 = device0
4347 self .device1 = device1
@@ -220,6 +224,28 @@ def load_resources(self, kg):
220224 self .id2concept = [w .strip () for w in fin ]
221225 self .concept2id = {w : i for i , w in enumerate (self .id2concept )}
222226 self .id2relation = conceptnet .merged_relations
227+ elif kg == "ddb" :
228+ cpnet_vocab_path = "data/ddb/vocab.txt"
229+ with open (cpnet_vocab_path , "r" , encoding = "utf8" ) as fin :
230+ self .id2concept = [w .strip () for w in fin ]
231+ self .concept2id = {w : i for i , w in enumerate (self .id2concept )}
232+ self .id2relation = [
233+ 'belongstothecategoryof' ,
234+ 'isacategory' ,
235+ 'maycause' ,
236+ 'isasubtypeof' ,
237+ 'isariskfactorof' ,
238+ 'isassociatedwith' ,
239+ 'maycontraindicate' ,
240+ 'interactswith' ,
241+ 'belongstothedrugfamilyof' ,
242+ 'child-parent' ,
243+ 'isavectorfor' ,
244+ 'mabeallelicwith' ,
245+ 'seealso' ,
246+ 'isaningradientof' ,
247+ 'mabeindicatedby'
248+ ]
223249 else :
224250 raise ValueError ("Invalid value for kg." )
225251
@@ -406,7 +432,7 @@ def load_sparse_adj_data_with_contextnode(self, adj_pk_path, max_node_num, conce
406432 #node_scores: (n_questions, num_choice, max_node_num)
407433 #adj_lengths: (n_questions, num_choice)
408434 return concept_ids , node_type_ids , node_scores , adj_lengths , special_nodes_mask , (edge_index , edge_type ) #, half_n_rel * 2 + 1
409-
435+
410436
411437def load_gpt_input_tensors (statement_jsonl_path , max_seq_length ):
412438 def _truncate_seq_pair (tokens_a , tokens_b , max_length ):
@@ -526,7 +552,7 @@ def read_examples(input_file):
526552 label = label
527553 ))
528554 return examples
529-
555+
530556 def simple_convert_examples_to_features (examples , label_list , max_seq_length , tokenizer ):
531557 """ Loads a data file into a list of `InputBatch`s
532558 `cls_token_at_end` defines the location of the CLS token:
@@ -577,4 +603,4 @@ def convert_features_to_tensors(features):
577603 features , concepts_by_sents_list = simple_convert_examples_to_features (examples , list (range (len (examples [0 ].endings ))), max_seq_length , tokenizer )
578604 example_ids = [f .example_id for f in features ]
579605 * data_tensors , all_label = convert_features_to_tensors (features )
580- return example_ids , all_label , data_tensors , concepts_by_sents_list
606+ return example_ids , all_label , data_tensors , concepts_by_sents_list
0 commit comments