Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ You can specify the GPU you want to use in the beginning of the command `CUDA_VI

**TL;DR**. Preprocessing may take a long time; for your convenience, you can download all the processed data [here](https://drive.google.com/drive/folders/1T6B4nou5P3u-6jr0z6e3IkitO8fNVM6f?usp=sharing) into the top-level directory of this repo and run
```
unzip data_preprocessed.zip
unzip data_preprocessed.zip
```

**Add MedQA-USMLE**. Besides the commonsense QA datasets (*CommonsenseQA*, *OpenBookQA*) with the ConceptNet knowledge graph, we added a biomedical QA dataset ([*MedQA-USMLE*](https://github.com/jind11/MedQA)) with a biomedical knowledge graph based on Disease Database and DrugBank. You can download all the data for this from [[here]](https://drive.google.com/file/d/1EqbiNt2ACXVrc9gmoXnzTEo9GJTe9Uor/view?usp=sharing). Unzip it and put the `medqa_usmle` and `ddb` folders inside the `data/` directory.


The resulting file structure should look like this:

```plain
Expand Down Expand Up @@ -85,11 +88,18 @@ Similarly, to train GreaseLM on OpenbookQA, run
CUDA_VISIBLE_DEVICES=0 ./run_greaselm.sh obqa --data_dir data/
```

To train GreaseLM on MedQA-USMLE, run
```
CUDA_VISIBLE_DEVICES=0 ./run_greaselm__medqa_usmle.sh
```

### 4. Pretrained model checkpoints
You can download a pretrained GreaseLM model on CommonsenseQA [here](https://drive.google.com/file/d/1QPwLZFA6AQ-pFfDR6TWLdBAvm3c_HOUr/view?usp=sharing), which achieves an IH-dev acc. of `79.0` and an IH-test acc. of `74.0`.

You can also download a pretrained GreaseLM model on OpenbookQA [here](https://drive.google.com/file/d/1-QqyiQuU9xlN20vwfIaqYQ_uJMP8d7Pv/view?usp=sharing), which achieves a test acc. of `84.8`.

You can also download a pretrained GreaseLM model on MedQA-USMLE [here](https://drive.google.com/file/d/1x5nZEprV0Ht8IWViyz3d07uGLXtNjUN1/view?usp=sharing), which achieves a test acc. of `38.5`.

### 5. Evaluating a pretrained model checkpoint
To evaluate a pretrained GreaseLM model checkpoint on CommonsenseQA, run
```
Expand All @@ -114,4 +124,4 @@ This repo is built upon the following work:
QA-GNN: Question Answering using Language Models and Knowledge Graphs
https://github.com/michiyasunaga/qagnn
```
Many thanks to the authors and developers!
Many thanks to the authors and developers!
33 changes: 23 additions & 10 deletions greaselm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
# Default decoder learning rate, keyed by dataset name.
# NOTE(review): presumably the fallback when no decoder LR is given on the
# command line -- confirm at the point of use.
DECODER_DEFAULT_LR = {
    "csqa": 1e-3,
    "obqa": 3e-4,
    "medqa_usmle": 1e-3,
}

import numpy as np
Expand Down Expand Up @@ -81,8 +82,11 @@ def construct_model(args, kg):
##########################################################

if kg == "cpnet":
n_ntype = 4
n_ntype = 4
n_etype = 38
elif kg == "ddb":
n_ntype = 4
n_etype = 34
else:
raise ValueError("Invalid KG.")
if args.cxt_node_connects_all:
Expand Down Expand Up @@ -178,7 +182,7 @@ def calc_eval_accuracy(eval_set, model, loss_type, loss_func, debug, save_test_p

def train(args, resume, has_test_split, devices, kg):
print("args: {}".format(args))

if resume:
args.save_dir = os.path.dirname(args.resume_checkpoint)
if not args.debug:
Expand Down Expand Up @@ -211,7 +215,14 @@ def train(args, resume, has_test_split, devices, kg):

# Get the names of the loaded LM parameters
loading_info = model.lmgnn.loading_info
loaded_roberta_keys = [k.replace("roberta.", "lmgnn.mp.") for k in loading_info["all_keys"]]
# loaded_roberta_keys = [k.replace("roberta.", "lmgnn.mp.") for k in loading_info["all_keys"]]
def _rename_key(key):
if key.startswith("roberta."):
return key.replace("roberta.", "lmgnn.mp.")
else:
return "lmgnn.mp." + key

loaded_roberta_keys = [_rename_key(k) for k in loading_info["all_keys"]]

# Separate the parameters into loaded and not loaded
loaded_params, not_loaded_params, params_to_freeze, small_lr_params, large_lr_params = sep_params(model, loaded_roberta_keys)
Expand Down Expand Up @@ -316,7 +327,7 @@ def train(args, resume, has_test_split, devices, kg):
model.train()

for qids, labels, *input_data in tqdm(train_dataloader, desc="Batch"):
# labels: [bs]
# labels: [bs]
start_time = time.time()
optimizer.zero_grad()
bs = labels.size(0)
Expand Down Expand Up @@ -387,11 +398,11 @@ def train(args, resume, has_test_split, devices, kg):
if not args.debug:
with open(log_path, 'a') as fout:
fout.write('{:3},{:5},{:7.4f},{:7.4f},{:7.4f},{:7.4f},{:3}\n'.format(epoch_id, global_step, dev_acc, test_acc, best_dev_acc, final_test_acc, best_dev_epoch))

wandb.log({"dev_acc": dev_acc, "dev_loss": dev_total_loss, "best_dev_acc": best_dev_acc, "best_dev_epoch": best_dev_epoch}, step=global_step)
if has_test_split:
wandb.log({"test_acc": test_acc, "test_loss": test_total_loss, "final_test_acc": final_test_acc}, step=global_step)

# Save the model checkpoint
if args.save_model:
model_state_dict = model.state_dict()
Expand Down Expand Up @@ -500,10 +511,12 @@ def main(args):
logging.basicConfig(format='%(asctime)s,%(msecs)d %(levelname)-8s [%(name)s:%(funcName)s():%(lineno)d] %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.WARNING)

has_test_split = True
devices = get_devices(args.cuda)
kg = "cpnet"
if args.dataset == "medqa_usmle":
kg = "ddb"

if not args.use_wandb:
wandb_mode = "disabled"
Expand All @@ -518,7 +531,7 @@ def main(args):
args.wandb_id = wandb_id

args.hf_version = transformers.__version__

with wandb.init(project="KG-LM", config=args, name=args.run_name, resume="allow", id=wandb_id, settings=wandb.Settings(start_method="fork"), mode=wandb_mode):
print(socket.gethostname())
print ("pid:", os.getpid())
Expand All @@ -537,7 +550,7 @@ def main(args):

if __name__ == '__main__':
__spec__ = None

parser = parser_utils.get_parser()
args, _ = parser.parse_known_args()

Expand Down Expand Up @@ -590,4 +603,4 @@ def main(args):
parser.add_argument('--init_range', default=0.02, type=float, help='stddev when initializing with normal distribution')

args = parser.parse_args()
main(args)
main(args)
41 changes: 25 additions & 16 deletions modeling/modeling_greaselm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@
logger = logging.getLogger(__name__)


# Choose the Hugging Face base class that TextKGMessagePassing inherits from
# and whose config class LMGNN loads.
# Environment values are strings, so "0" / "false" would be truthy if tested
# directly; compare against explicit "off" spellings instead. Unset or empty
# keeps the RoBERTa default, and INHERIT_BERT=1 selects BERT as before.
_inherit_bert = os.environ.get('INHERIT_BERT', '').strip().lower()
if _inherit_bert not in ('', '0', 'false', 'no', 'off'):
    ModelClass = modeling_bert.BertModel
else:
    ModelClass = modeling_roberta.RobertaModel

print('ModelClass', ModelClass)


class GreaseLM(nn.Module):

def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=38,
Expand All @@ -31,11 +39,11 @@ def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=
pretrained_concept_emb=None, freeze_ent_emb=True,
init_range=0.02, ie_dim=200, info_exchange=True, ie_layer_num=1, sep_ie_layers=False, layer_id=-1):
super().__init__()
self.lmgnn = LMGNN(args, model_name, k, n_ntype, n_etype,
self.lmgnn = LMGNN(args, model_name, k, n_ntype, n_etype,
n_concept, concept_dim, concept_in_dim, n_attention_head,
fc_dim, n_fc_layer, p_emb, p_gnn, p_fc, pretrained_concept_emb=pretrained_concept_emb, freeze_ent_emb=freeze_ent_emb,
init_range=init_range, ie_dim=ie_dim, info_exchange=info_exchange, ie_layer_num=ie_layer_num, sep_ie_layers=sep_ie_layers, layer_id=layer_id)

def batch_graph(self, edge_index_init, edge_type_init, n_nodes):
"""
edge_index_init: list of (n_examples, ). each entry is torch.tensor(2, E)
Expand All @@ -59,7 +67,7 @@ def forward(self, *inputs, cache_output=False, detail=False):
-> (2, total E)
edge_type: list of (batch_size, num_choice) -> list of (batch_size * num_choice, ); each entry is torch.tensor(E(variable), )
-> (total E, )

returns:
logits: [bs, nc]
"""
Expand All @@ -85,7 +93,7 @@ def forward(self, *inputs, cache_output=False, detail=False):
return logits, attn, concept_ids.view(bs, nc, -1), node_type_ids.view(bs, nc, -1), edge_index_orig, edge_type_orig
# edge_index_orig: list of (batch_size, num_choice). each entry is torch.tensor(2, E)
# edge_type_orig: list of (batch_size, num_choice). each entry is torch.tensor(E, )

def get_fake_inputs(self, device="cuda:0"):
bs = 4
nc = 5
Expand Down Expand Up @@ -129,14 +137,14 @@ def test_GreaseLM(device):

class LMGNN(nn.Module):

def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=38,
def __init__(self, args={}, model_name="roberta-large", k=5, n_ntype=4, n_etype=38,
n_concept=799273, concept_dim=200, concept_in_dim=1024, n_attention_head=2,
fc_dim=200, n_fc_layer=0, p_emb=0.2, p_gnn=0.2, p_fc=0.2,
pretrained_concept_emb=None, freeze_ent_emb=True,
init_range=0.02, ie_dim=200, info_exchange=True, ie_layer_num=1, sep_ie_layers=False, layer_id=-1):
super().__init__()
config, _ = modeling_roberta.RobertaModel.config_class.from_pretrained(
model_name,
config, _ = ModelClass.config_class.from_pretrained(
model_name,
cache_dir=None, return_unused_kwargs=True,
force_download=False,
output_hidden_states=True
Expand Down Expand Up @@ -281,11 +289,12 @@ def test_LMGNN(device):
model.check_outputs(*outputs)


class TextKGMessagePassing(modeling_roberta.RobertaModel):

class TextKGMessagePassing(ModelClass):

def __init__(self, config, args={}, k=5, n_ntype=4, n_etype=38, dropout=0.2, concept_dim=200, ie_dim=200, p_fc=0.2, info_exchange=True, ie_layer_num=1, sep_ie_layers=False):
super().__init__(config=config)

self.n_ntype = n_ntype
self.n_etype = n_etype

Expand Down Expand Up @@ -633,7 +642,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

all_keys = list(state_dict.keys())

# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
Expand Down Expand Up @@ -756,7 +765,7 @@ class RoBERTaGAT(modeling_bert.BertEncoder):

def __init__(self, config, k=5, n_ntype=4, n_etype=38, hidden_size=200, dropout=0.2, concept_dim=200, ie_dim=200, p_fc=0.2, info_exchange=True, ie_layer_num=1, sep_ie_layers=False):
super().__init__(config)

self.k = k
self.edge_encoder = torch.nn.Sequential(torch.nn.Linear(n_etype + 1 + n_ntype * 2, hidden_size), torch.nn.BatchNorm1d(hidden_size), torch.nn.ReLU(), torch.nn.Linear(hidden_size, hidden_size))
self.gnn_layers = nn.ModuleList([modeling_gnn.GATConvE(hidden_size, n_ntype, n_etype, self.edge_encoder) for _ in range(k)])
Expand Down Expand Up @@ -799,14 +808,14 @@ def forward(self, hidden_states, attention_mask, special_tokens_mask, head_mask,

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

if i >= self.num_hidden_layers - self.k:
# GNN
gnn_layer_index = i - self.num_hidden_layers + self.k
_X = self.gnn_layers[gnn_layer_index](_X, edge_index, edge_type, _node_type, _node_feature_extra)
_X = self.activation(_X)
_X = F.dropout(_X, self.dropout_rate, training = self.training)

# Exchange info between LM and GNN hidden states (Modality interaction)
if self.info_exchange == True or (self.info_exchange == "every-other-layer" and (i - self.num_hidden_layers + self.k) % 2 == 0):
X = _X.view(bs, -1, _X.size(1)) # [bs, max_num_nodes, node_dim]
Expand Down Expand Up @@ -861,7 +870,7 @@ def check_outputs(self, outputs, _X):

def test_RoBERTaGAT(device):
config, _ = modeling_roberta.RobertaModel.config_class.from_pretrained(
"roberta-large",
"roberta-large",
cache_dir=None, return_unused_kwargs=True,
force_download=False,
output_hidden_states=True
Expand All @@ -880,11 +889,11 @@ def test_RoBERTaGAT(device):
utils.print_cuda_info()
free_gpus = utils.select_free_gpus()
device = torch.device("cuda:{}".format(free_gpus[0]))

# test_RoBERTaGAT(device)

# test_TextKGMessagePassing(device)

# test_LMGNN(device)

test_GreaseLM(device)
test_GreaseLM(device)
69 changes: 69 additions & 0 deletions run_greaselm__medqa_usmle.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/bin/bash
# Train GreaseLM on MedQA-USMLE with the SapBERT encoder and "ddb" entity embeddings.
#
# Usage: [CUDA_VISIBLE_DEVICES=<gpu>] ./run_greaselm__medqa_usmle.sh
#
# All hyperparameters are fixed below. Command-line arguments are not
# forwarded to greaselm.py.

# Read by modeling/modeling_greaselm.py to build the model on BertModel.
export INHERIT_BERT=1
# Honor a GPU selection made by the caller; default to GPU 0 when unset.
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"
export TOKENIZERS_PARALLELISM=true
dt=$(date '+%Y%m%d_%H%M%S')


dataset="medqa_usmle"
encoder='cambridgeltl/SapBERT-from-PubMedBERT-fulltext'


elr="5e-5"
dlr="1e-3"
bs=128
mbs=2
unfreeze_epoch=0
k=3 #num of gnn layers
gnndim=200

# Existing arguments but changed for GreaseLM
encoder_layer=-1
max_node_num=200
seed=5
lr_schedule=fixed

n_epochs=20
max_epochs_before_stop=10
ie_dim=400

max_seq_len=512
ent_emb=ddb
inhouse=false

# Added for GreaseLM
info_exchange=true
ie_layer_num=1
resume_checkpoint=None
resume_id=None
sep_ie_layers=false
random_ent_emb=false

echo "***** hyperparameters *****"
echo "dataset: $dataset"
echo "enc_name: $encoder"
echo "batch_size: $bs mini_batch_size: $mbs"
echo "learning_rate: elr $elr dlr $dlr"
echo "gnn: dim $gnndim layer $k"
echo "ie_dim: ${ie_dim}, info_exchange: ${info_exchange}"
echo "******************************"

save_dir_pref='runs'
mkdir -p $save_dir_pref

run_name=greaselm__ds_${dataset}__ih_${inhouse}__enc_sapbert__k${k}__sd${seed}__iedim${ie_dim}__unfrz${unfreeze_epoch}__${dt}
log=logs/train_${dataset}__${run_name}.log.txt
# The stdout redirect below fails if the log directory does not exist yet.
mkdir -p "$(dirname "${log}")"

###### Training ######
# ${ent_emb//,/ } is intentionally unquoted: a comma-separated list is split
# into separate --ent_emb values.
python3 -u greaselm.py \
    --dataset $dataset \
    --encoder $encoder -k $k --gnn_dim $gnndim -elr $elr -dlr $dlr -bs $bs --seed $seed -mbs ${mbs} --unfreeze_epoch ${unfreeze_epoch} --encoder_layer=${encoder_layer} -sl ${max_seq_len} --max_node_num ${max_node_num} \
    --n_epochs $n_epochs --max_epochs_before_stop ${max_epochs_before_stop} \
    --save_dir ${save_dir_pref}/${dataset}/${run_name} \
    --run_name ${run_name} \
    --ie_dim ${ie_dim} --info_exchange ${info_exchange} --ie_layer_num ${ie_layer_num} --resume_checkpoint ${resume_checkpoint} --resume_id ${resume_id} --sep_ie_layers ${sep_ie_layers} --random_ent_emb ${random_ent_emb} --ent_emb ${ent_emb//,/ } --lr_schedule ${lr_schedule} -ih ${inhouse} \
    --data_dir data \
    > "${log}"
# echo log: ${log}
Loading