Skip to content

Commit a194c0d

Browse files
authored
Merge pull request snap-stanford#1 from michiyasunaga/main
add medqa
2 parents 079acd4 + 596aa50 commit a194c0d

5 files changed

Lines changed: 718 additions & 28 deletions

File tree

README.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,12 @@ You can specify the GPU you want to use in the beginning of the command `CUDA_VI
5353

5454
**TL;DR**. The preprocessing may take a long time; for your convenience, you can download all the processed data [here](https://drive.google.com/drive/folders/1T6B4nou5P3u-6jr0z6e3IkitO8fNVM6f?usp=sharing) into the top-level directory of this repo and run
5555
```
56-
unzip data_preprocessed.zip
56+
unzip data_preprocessed.zip
5757
```
5858

59+
**Add MedQA-USMLE**. Besides the commonsense QA datasets (*CommonsenseQA*, *OpenBookQA*) with the ConceptNet knowledge graph, we added a biomedical QA dataset ([*MedQA-USMLE*](https://github.com/jind11/MedQA)) with a biomedical knowledge graph based on Disease Database and DrugBank. You can download all the data for this from [[here]](https://drive.google.com/file/d/1EqbiNt2ACXVrc9gmoXnzTEo9GJTe9Uor/view?usp=sharing). Unzip it and put the `medqa_usmle` and `ddb` folders inside the `data/` directory.
60+
61+
5962
The resulting file structure should look like this:
6063

6164
```plain
@@ -85,11 +88,18 @@ Similarly, to train GreaseLM on OpenbookQA, run
8588
CUDA_VISIBLE_DEVICES=0 ./run_greaselm.sh obqa --data_dir data/
8689
```
8790

91+
To train GreaseLM on MedQA-USMLE, run
92+
```
93+
CUDA_VISIBLE_DEVICES=0 ./run_greaselm__medqa_usmle.sh
94+
```
95+
8896
### 4. Pretrained model checkpoints
8997
You can download a pretrained GreaseLM model on CommonsenseQA [here](https://drive.google.com/file/d/1QPwLZFA6AQ-pFfDR6TWLdBAvm3c_HOUr/view?usp=sharing), which achieves an IH-dev acc. of `79.0` and an IH-test acc. of `74.0`.
9098

9199
You can also download a pretrained GreaseLM model on OpenbookQA [here](https://drive.google.com/file/d/1-QqyiQuU9xlN20vwfIaqYQ_uJMP8d7Pv/view?usp=sharing), which achieves a test acc. of `84.8`.
92100

101+
You can also download a pretrained GreaseLM model on MedQA-USMLE [here](https://drive.google.com/file/d/1x5nZEprV0Ht8IWViyz3d07uGLXtNjUN1/view?usp=sharing), which achieves a test acc. of `38.5`.
102+
93103
### 5. Evaluating a pretrained model checkpoint
94104
To evaluate a pretrained GreaseLM model checkpoint on CommonsenseQA, run
95105
```
@@ -114,4 +124,4 @@ This repo is built upon the following work:
114124
QA-GNN: Question Answering using Language Models and Knowledge Graphs
115125
https://github.com/michiyasunaga/qagnn
116126
```
117-
Many thanks to the authors and developers!
127+
Many thanks to the authors and developers!

greaselm.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
DECODER_DEFAULT_LR = {
2727
'csqa': 1e-3,
2828
'obqa': 3e-4,
29+
'medqa_usmle': 1e-3,
2930
}
3031

3132
import numpy as np
@@ -81,8 +82,11 @@ def construct_model(args, kg):
8182
##########################################################
8283

8384
if kg == "cpnet":
84-
n_ntype = 4
85+
n_ntype = 4
8586
n_etype = 38
87+
elif kg == "ddb":
88+
n_ntype = 4
89+
n_etype = 34
8690
else:
8791
raise ValueError("Invalid KG.")
8892
if args.cxt_node_connects_all:
@@ -178,7 +182,7 @@ def calc_eval_accuracy(eval_set, model, loss_type, loss_func, debug, save_test_p
178182

179183
def train(args, resume, has_test_split, devices, kg):
180184
print("args: {}".format(args))
181-
185+
182186
if resume:
183187
args.save_dir = os.path.dirname(args.resume_checkpoint)
184188
if not args.debug:
@@ -211,7 +215,14 @@ def train(args, resume, has_test_split, devices, kg):
211215

212216
# Get the names of the loaded LM parameters
213217
loading_info = model.lmgnn.loading_info
214-
loaded_roberta_keys = [k.replace("roberta.", "lmgnn.mp.") for k in loading_info["all_keys"]]
218+
# loaded_roberta_keys = [k.replace("roberta.", "lmgnn.mp.") for k in loading_info["all_keys"]]
219+
def _rename_key(key):
220+
if key.startswith("roberta."):
221+
return key.replace("roberta.", "lmgnn.mp.")
222+
else:
223+
return "lmgnn.mp." + key
224+
225+
loaded_roberta_keys = [_rename_key(k) for k in loading_info["all_keys"]]
215226

216227
# Separate the parameters into loaded and not loaded
217228
loaded_params, not_loaded_params, params_to_freeze, small_lr_params, large_lr_params = sep_params(model, loaded_roberta_keys)
@@ -316,7 +327,7 @@ def train(args, resume, has_test_split, devices, kg):
316327
model.train()
317328

318329
for qids, labels, *input_data in tqdm(train_dataloader, desc="Batch"):
319-
# labels: [bs]
330+
# labels: [bs]
320331
start_time = time.time()
321332
optimizer.zero_grad()
322333
bs = labels.size(0)
@@ -387,11 +398,11 @@ def train(args, resume, has_test_split, devices, kg):
387398
if not args.debug:
388399
with open(log_path, 'a') as fout:
389400
fout.write('{:3},{:5},{:7.4f},{:7.4f},{:7.4f},{:7.4f},{:3}\n'.format(epoch_id, global_step, dev_acc, test_acc, best_dev_acc, final_test_acc, best_dev_epoch))
390-
401+
391402
wandb.log({"dev_acc": dev_acc, "dev_loss": dev_total_loss, "best_dev_acc": best_dev_acc, "best_dev_epoch": best_dev_epoch}, step=global_step)
392403
if has_test_split:
393404
wandb.log({"test_acc": test_acc, "test_loss": test_total_loss, "final_test_acc": final_test_acc}, step=global_step)
394-
405+
395406
# Save the model checkpoint
396407
if args.save_model:
397408
model_state_dict = model.state_dict()
@@ -500,10 +511,12 @@ def main(args):
500511
logging.basicConfig(format='%(asctime)s,%(msecs)d %(levelname)-8s [%(name)s:%(funcName)s():%(lineno)d] %(message)s',
501512
datefmt='%m/%d/%Y %H:%M:%S',
502513
level=logging.WARNING)
503-
514+
504515
has_test_split = True
505516
devices = get_devices(args.cuda)
506517
kg = "cpnet"
518+
if args.dataset == "medqa_usmle":
519+
kg = "ddb"
507520

508521
if not args.use_wandb:
509522
wandb_mode = "disabled"
@@ -518,7 +531,7 @@ def main(args):
518531
args.wandb_id = wandb_id
519532

520533
args.hf_version = transformers.__version__
521-
534+
522535
with wandb.init(project="KG-LM", config=args, name=args.run_name, resume="allow", id=wandb_id, settings=wandb.Settings(start_method="fork"), mode=wandb_mode):
523536
print(socket.gethostname())
524537
print ("pid:", os.getpid())
@@ -537,7 +550,7 @@ def main(args):
537550

538551
if __name__ == '__main__':
539552
__spec__ = None
540-
553+
541554
parser = parser_utils.get_parser()
542555
args, _ = parser.parse_known_args()
543556

@@ -590,4 +603,4 @@ def main(args):
590603
parser.add_argument('--init_range', default=0.02, type=float, help='stddev when initializing with normal distribution')
591604

592605
args = parser.parse_args()
593-
main(args)
606+
main(args)

modeling/modeling_greaselm.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@
2323
logger = logging.getLogger(__name__)
2424

2525

26+
# Pick the pretrained encoder class used as `ModelClass` in this module.
# Environment variables are always strings, so the flag is parsed explicitly:
# a plain truthiness test would treat INHERIT_BERT=0 (the non-empty string "0")
# as enabled. Unset, empty, "0", "false", "no" and "off" all select RoBERTa;
# any other value (e.g. "1") selects BERT.
_inherit_bert_flag = os.environ.get('INHERIT_BERT', '').strip().lower()
if _inherit_bert_flag in ('', '0', 'false', 'no', 'off'):
    ModelClass = modeling_roberta.RobertaModel
else:
    ModelClass = modeling_bert.BertModel

print('ModelClass', ModelClass)
32+
33+
2634
class GreaseLM(nn.Module):
2735

2836
def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=38,
@@ -31,11 +39,11 @@ def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=
3139
pretrained_concept_emb=None, freeze_ent_emb=True,
3240
init_range=0.02, ie_dim=200, info_exchange=True, ie_layer_num=1, sep_ie_layers=False, layer_id=-1):
3341
super().__init__()
34-
self.lmgnn = LMGNN(args, model_name, k, n_ntype, n_etype,
42+
self.lmgnn = LMGNN(args, model_name, k, n_ntype, n_etype,
3543
n_concept, concept_dim, concept_in_dim, n_attention_head,
3644
fc_dim, n_fc_layer, p_emb, p_gnn, p_fc, pretrained_concept_emb=pretrained_concept_emb, freeze_ent_emb=freeze_ent_emb,
3745
init_range=init_range, ie_dim=ie_dim, info_exchange=info_exchange, ie_layer_num=ie_layer_num, sep_ie_layers=sep_ie_layers, layer_id=layer_id)
38-
46+
3947
def batch_graph(self, edge_index_init, edge_type_init, n_nodes):
4048
"""
4149
edge_index_init: list of (n_examples, ). each entry is torch.tensor(2, E)
@@ -59,7 +67,7 @@ def forward(self, *inputs, cache_output=False, detail=False):
5967
-> (2, total E)
6068
edge_type: list of (batch_size, num_choice) -> list of (batch_size * num_choice, ); each entry is torch.tensor(E(variable), )
6169
-> (total E, )
62-
70+
6371
returns:
6472
logits: [bs, nc]
6573
"""
@@ -85,7 +93,7 @@ def forward(self, *inputs, cache_output=False, detail=False):
8593
return logits, attn, concept_ids.view(bs, nc, -1), node_type_ids.view(bs, nc, -1), edge_index_orig, edge_type_orig
8694
# edge_index_orig: list of (batch_size, num_choice). each entry is torch.tensor(2, E)
8795
# edge_type_orig: list of (batch_size, num_choice). each entry is torch.tensor(E, )
88-
96+
8997
def get_fake_inputs(self, device="cuda:0"):
9098
bs = 4
9199
nc = 5
@@ -129,14 +137,14 @@ def test_GreaseLM(device):
129137

130138
class LMGNN(nn.Module):
131139

132-
def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=38,
140+
def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=38,
133141
n_concept=799273, concept_dim=200, concept_in_dim=1024, n_attention_head=2,
134142
fc_dim=200, n_fc_layer=0, p_emb=0.2, p_gnn=0.2, p_fc=0.2,
135143
pretrained_concept_emb=None, freeze_ent_emb=True,
136144
init_range=0.02, ie_dim=200, info_exchange=True, ie_layer_num=1, sep_ie_layers=False, layer_id=-1):
137145
super().__init__()
138-
config, _ = modeling_roberta.RobertaModel.config_class.from_pretrained(
139-
model_name,
146+
config, _ = ModelClass.config_class.from_pretrained(
147+
model_name,
140148
cache_dir=None, return_unused_kwargs=True,
141149
force_download=False,
142150
output_hidden_states=True
@@ -281,11 +289,12 @@ def test_LMGNN(device):
281289
model.check_outputs(*outputs)
282290

283291

284-
class TextKGMessagePassing(modeling_roberta.RobertaModel):
292+
293+
class TextKGMessagePassing(ModelClass):
285294

286295
def __init__(self, config, args={}, k=5, n_ntype=4, n_etype=38, dropout=0.2, concept_dim=200, ie_dim=200, p_fc=0.2, info_exchange=True, ie_layer_num=1, sep_ie_layers=False):
287296
super().__init__(config=config)
288-
297+
289298
self.n_ntype = n_ntype
290299
self.n_etype = n_etype
291300

@@ -633,7 +642,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
633642
state_dict = state_dict.copy()
634643
if metadata is not None:
635644
state_dict._metadata = metadata
636-
645+
637646
all_keys = list(state_dict.keys())
638647

639648
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
@@ -756,7 +765,7 @@ class RoBERTaGAT(modeling_bert.BertEncoder):
756765

757766
def __init__(self, config, k=5, n_ntype=4, n_etype=38, hidden_size=200, dropout=0.2, concept_dim=200, ie_dim=200, p_fc=0.2, info_exchange=True, ie_layer_num=1, sep_ie_layers=False):
758767
super().__init__(config)
759-
768+
760769
self.k = k
761770
self.edge_encoder = torch.nn.Sequential(torch.nn.Linear(n_etype + 1 + n_ntype * 2, hidden_size), torch.nn.BatchNorm1d(hidden_size), torch.nn.ReLU(), torch.nn.Linear(hidden_size, hidden_size))
762771
self.gnn_layers = nn.ModuleList([modeling_gnn.GATConvE(hidden_size, n_ntype, n_etype, self.edge_encoder) for _ in range(k)])
@@ -799,14 +808,14 @@ def forward(self, hidden_states, attention_mask, special_tokens_mask, head_mask,
799808

800809
if output_attentions:
801810
all_attentions = all_attentions + (layer_outputs[1],)
802-
811+
803812
if i >= self.num_hidden_layers - self.k:
804813
# GNN
805814
gnn_layer_index = i - self.num_hidden_layers + self.k
806815
_X = self.gnn_layers[gnn_layer_index](_X, edge_index, edge_type, _node_type, _node_feature_extra)
807816
_X = self.activation(_X)
808817
_X = F.dropout(_X, self.dropout_rate, training = self.training)
809-
818+
810819
# Exchange info between LM and GNN hidden states (Modality interaction)
811820
if self.info_exchange == True or (self.info_exchange == "every-other-layer" and (i - self.num_hidden_layers + self.k) % 2 == 0):
812821
X = _X.view(bs, -1, _X.size(1)) # [bs, max_num_nodes, node_dim]
@@ -861,7 +870,7 @@ def check_outputs(self, outputs, _X):
861870

862871
def test_RoBERTaGAT(device):
863872
config, _ = modeling_roberta.RobertaModel.config_class.from_pretrained(
864-
"roberta-large",
873+
"roberta-large",
865874
cache_dir=None, return_unused_kwargs=True,
866875
force_download=False,
867876
output_hidden_states=True
@@ -880,11 +889,11 @@ def test_RoBERTaGAT(device):
880889
utils.print_cuda_info()
881890
free_gpus = utils.select_free_gpus()
882891
device = torch.device("cuda:{}".format(free_gpus[0]))
883-
892+
884893
# test_RoBERTaGAT(device)
885894

886895
# test_TextKGMessagePassing(device)
887896

888897
# test_LMGNN(device)
889898

890-
test_GreaseLM(device)
899+
test_GreaseLM(device)

run_greaselm__medqa_usmle.sh

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#!/bin/bash
# Train GreaseLM on MedQA-USMLE with the SapBERT encoder and the "ddb" entity embeddings.
#
# Usage:
#   [CUDA_VISIBLE_DEVICES=<gpu>] ./run_greaselm__medqa_usmle.sh [extra greaselm.py args...]
#
# Any extra arguments are appended to the greaselm.py command line after the
# defaults below, so they can override them.

# Read by modeling/modeling_greaselm.py to build the model on top of BertModel.
export INHERIT_BERT=1
# Respect a GPU chosen by the caller; fall back to GPU 0 only when none was given.
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
export TOKENIZERS_PARALLELISM=true
dt=$(date '+%Y%m%d_%H%M%S')


# The dataset is fixed for this script, so no positional argument is consumed.
dataset="medqa_usmle"
encoder='cambridgeltl/SapBERT-from-PubMedBERT-fulltext'


elr="5e-5"
dlr="1e-3"
bs=128
mbs=2
unfreeze_epoch=0
k=3 #num of gnn layers
gnndim=200

# Existing arguments but changed for GreaseLM
encoder_layer=-1
max_node_num=200
seed=5
lr_schedule=fixed

n_epochs=20
max_epochs_before_stop=10
ie_dim=400

max_seq_len=512
ent_emb=ddb
inhouse=false

# Added for GreaseLM
info_exchange=true
ie_layer_num=1
resume_checkpoint=None
resume_id=None
sep_ie_layers=false
random_ent_emb=false

echo "***** hyperparameters *****"
echo "dataset: $dataset"
echo "enc_name: $encoder"
echo "batch_size: $bs mini_batch_size: $mbs"
echo "learning_rate: elr $elr dlr $dlr"
echo "gnn: dim $gnndim layer $k"
echo "ie_dim: ${ie_dim}, info_exchange: ${info_exchange}"
echo "******************************"

save_dir_pref='runs'
mkdir -p "$save_dir_pref"
# The log file below lives in logs/; create it so the redirect cannot fail on a fresh checkout.
mkdir -p logs

run_name=greaselm__ds_${dataset}__ih_${inhouse}__enc_sapbert__k${k}__sd${seed}__iedim${ie_dim}__unfrz${unfreeze_epoch}__${dt}
log=logs/train_${dataset}__${run_name}.log.txt

###### Training ######
# ${ent_emb//,/ } is intentionally unquoted: a comma-separated list is split into separate words.
python3 -u greaselm.py \
    --dataset "$dataset" \
    --encoder "$encoder" -k $k --gnn_dim $gnndim -elr $elr -dlr $dlr -bs $bs --seed $seed -mbs ${mbs} --unfreeze_epoch ${unfreeze_epoch} --encoder_layer=${encoder_layer} -sl ${max_seq_len} --max_node_num ${max_node_num} \
    --n_epochs $n_epochs --max_epochs_before_stop ${max_epochs_before_stop} \
    --save_dir "${save_dir_pref}/${dataset}/${run_name}" \
    --run_name "${run_name}" \
    --ie_dim ${ie_dim} --info_exchange ${info_exchange} --ie_layer_num ${ie_layer_num} --resume_checkpoint ${resume_checkpoint} --resume_id ${resume_id} --sep_ie_layers ${sep_ie_layers} --random_ent_emb ${random_ent_emb} --ent_emb ${ent_emb//,/ } --lr_schedule ${lr_schedule} -ih ${inhouse} \
    --data_dir data \
    "$@" \
    > "${log}"
# echo log: ${log}

0 commit comments

Comments
 (0)