Skip to content

Commit ce5c26e

Browse files
committed
medqa configs
1 parent a194c0d commit ce5c26e

2 files changed

Lines changed: 37 additions & 6 deletions

File tree

utils/data_utils.py

Lines changed: 30 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,12 +32,16 @@
3232

3333
MODEL_NAME_TO_CLASS = {model_name: model_class for model_class, model_name_list in MODEL_CLASS_TO_NAME.items() for model_name in model_name_list}
3434

35+
#Add SapBERT configuration
36+
model_name = 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext'
37+
MODEL_NAME_TO_CLASS[model_name] = 'bert'
38+
3539
GPT_SPECIAL_TOKENS = ['_start_', '_delimiter_', '_classify_']
3640

3741

3842
class MultiGPUSparseAdjDataBatchGenerator(object):
3943
"""A data generator that batches the data and moves them to the corresponding devices."""
40-
def __init__(self, device0, device1, batch_size, indexes, qids, labels,
44+
def __init__(self, device0, device1, batch_size, indexes, qids, labels,
4145
tensors0=[], lists0=[], tensors1=[], lists1=[], adj_data=None):
4246
self.device0 = device0
4347
self.device1 = device1
@@ -220,6 +224,28 @@ def load_resources(self, kg):
220224
self.id2concept = [w.strip() for w in fin]
221225
self.concept2id = {w: i for i, w in enumerate(self.id2concept)}
222226
self.id2relation = conceptnet.merged_relations
227+
elif kg == "ddb":
228+
cpnet_vocab_path = "data/ddb/vocab.txt"
229+
with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
230+
self.id2concept = [w.strip() for w in fin]
231+
self.concept2id = {w: i for i, w in enumerate(self.id2concept)}
232+
self.id2relation = [
233+
'belongstothecategoryof',
234+
'isacategory',
235+
'maycause',
236+
'isasubtypeof',
237+
'isariskfactorof',
238+
'isassociatedwith',
239+
'maycontraindicate',
240+
'interactswith',
241+
'belongstothedrugfamilyof',
242+
'child-parent',
243+
'isavectorfor',
244+
'mabeallelicwith',
245+
'seealso',
246+
'isaningradientof',
247+
'mabeindicatedby'
248+
]
223249
else:
224250
raise ValueError("Invalid value for kg.")
225251

@@ -406,7 +432,7 @@ def load_sparse_adj_data_with_contextnode(self, adj_pk_path, max_node_num, conce
406432
#node_scores: (n_questions, num_choice, max_node_num)
407433
#adj_lengths: (n_questions, num_choice)
408434
return concept_ids, node_type_ids, node_scores, adj_lengths, special_nodes_mask, (edge_index, edge_type) #, half_n_rel * 2 + 1
409-
435+
410436

411437
def load_gpt_input_tensors(statement_jsonl_path, max_seq_length):
412438
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
@@ -526,7 +552,7 @@ def read_examples(input_file):
526552
label=label
527553
))
528554
return examples
529-
555+
530556
def simple_convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
531557
""" Loads a data file into a list of `InputBatch`s
532558
`cls_token_at_end` define the location of the CLS token:
@@ -577,4 +603,4 @@ def convert_features_to_tensors(features):
577603
features, concepts_by_sents_list = simple_convert_examples_to_features(examples, list(range(len(examples[0].endings))), max_seq_length, tokenizer)
578604
example_ids = [f.example_id for f in features]
579605
*data_tensors, all_label = convert_features_to_tensors(features)
580-
return example_ids, all_label, data_tensors, concepts_by_sents_list
606+
return example_ids, all_label, data_tensors, concepts_by_sents_list

utils/parser_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,17 @@
1818
'bert-large-cased': 1e-4,
1919
'roberta-large': 1e-5,
2020
},
21+
'medqa_usmle': {
22+
'cambridgeltl/SapBERT-from-PubMedBERT-fulltext': 5e-5,
23+
},
2124
}
2225

23-
DATASET_LIST = ['csqa', 'obqa']
26+
DATASET_LIST = ['csqa', 'obqa', 'medqa_usmle']
2427

2528
DATASET_SETTING = {
2629
'csqa': 'inhouse',
2730
'obqa': 'official',
31+
'medqa_usmle': 'official',
2832
}
2933

3034
DATASET_NO_TEST = []
@@ -33,12 +37,13 @@
3337
'transe': 'data/cpnet/glove.transe.sgd.ent.npy',
3438
'numberbatch': 'data/cpnet/concept.nb.npy',
3539
'tzw': 'data/cpnet/tzw.ent.npy',
40+
'ddb': 'data/ddb/ent_emb.npy',
3641
}
3742

3843

3944
def add_data_arguments(parser):
4045
# arguments that all datasets share
41-
parser.add_argument('--ent_emb', default=['tzw'], choices=['tzw', "transe", "numberbatch"], nargs='+', help='sources for entity embeddings')
46+
parser.add_argument('--ent_emb', default=['tzw'], nargs='+', help='sources for entity embeddings')
4247
# dataset specific
4348
parser.add_argument('-ds', '--dataset', default='csqa', choices=DATASET_LIST, help='dataset name')
4449
parser.add_argument('--data_dir', default='data', type=str, help='Path to the data directory')

0 commit comments

Comments
 (0)