Skip to content

Commit a5efab6

Browse files
committed
init
0 parents  commit a5efab6

27 files changed

Lines changed: 6147 additions & 0 deletions

.gitignore

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
pip-wheel-metadata/
24+
share/python-wheels/
25+
*.egg-info/
26+
.installed.cfg
27+
*.egg
28+
MANIFEST
29+
30+
# PyInstaller
31+
# Usually these files are written by a python script from a template
32+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33+
*.manifest
34+
*.spec
35+
36+
# Installer logs
37+
pip-log.txt
38+
pip-delete-this-directory.txt
39+
40+
# Unit test / coverage reports
41+
htmlcov/
42+
.tox/
43+
.nox/
44+
.coverage
45+
.coverage.*
46+
.cache
47+
nosetests.xml
48+
coverage.xml
49+
*.cover
50+
*.py,cover
51+
.hypothesis/
52+
.pytest_cache/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# IPython
81+
profile_default/
82+
ipython_config.py
83+
84+
# pyenv
85+
.python-version
86+
87+
# pipenv
88+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
90+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
91+
# install all needed dependencies.
92+
#Pipfile.lock
93+
94+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
95+
__pypackages__/
96+
97+
# Celery stuff
98+
celerybeat-schedule
99+
celerybeat.pid
100+
101+
# SageMath parsed files
102+
*.sage.py
103+
104+
# Environments
105+
.env
106+
.venv
107+
env/
108+
venv/
109+
ENV/
110+
env.bak/
111+
venv.bak/
112+
113+
# Spyder project settings
114+
.spyderproject
115+
.spyproject
116+
117+
# Rope project settings
118+
.ropeproject
119+
120+
# mkdocs documentation
121+
/site
122+
123+
# mypy
124+
.mypy_cache/
125+
.dmypy.json
126+
dmypy.json
127+
128+
# Pyre type checker
129+
.pyre/
130+
131+
# Vscode file
132+
.vscode/
133+
134+
# GreaseLM project specific
135+
data/
136+
data_download/
137+
logs/
138+
runs/
139+
*.zip
140+
wandb/
141+
checkpoint/
142+
log_useful/
143+
144+
# GreaseLM running generate
145+
filtered_concept.txt
146+
matcher_res.json

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2021 Xikun Zhang
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# GreaseLM: Graph REASoning Enhanced Language Models for Question Answering
2+
3+
This repo provides the source code & data of our paper [GreaseLM: Graph REASoning Enhanced Language Models for Question Answering](https://arxiv.org/abs/2201.08860) (ICLR 2022 spotlight). If you use any of our code, processed data or pretrained models, please cite:
4+
```bib
5+
@inproceedings{zhang2021greaselm,
6+
title={GreaseLM: Graph REASoning Enhanced Language Models},
7+
author={Zhang, Xikun and Bosselut, Antoine and Yasunaga, Michihiro and Ren, Hongyu and Liang, Percy and Manning, Christopher D and Leskovec, Jure},
8+
booktitle={International Conference on Learning Representations},
9+
year={2021}
10+
}
11+
```
12+
13+
<p align="center">
14+
<img src="./figs/greaselm.png" width="600" title="GreaseLM model architecture" alt="">
15+
</p>
16+
17+
## 1. Dependencies
18+
19+
- [Python](<https://www.python.org/>) == 3.8
20+
- [PyTorch](<https://pytorch.org/get-started/locally/>) == 1.8.0
21+
- [transformers](<https://github.com/huggingface/transformers/tree/v3.4.0>) == 3.4.0
22+
- [torch-geometric](https://pytorch-geometric.readthedocs.io/) == 1.7.0
23+
24+
Run the following commands to create a conda environment (assuming CUDA 10.1):
25+
```bash
26+
conda create -y -n greaselm python=3.8
27+
conda activate greaselm
28+
pip install numpy==1.18.3 tqdm
29+
pip install torch==1.8.0+cu101 torchvision -f https://download.pytorch.org/whl/torch_stable.html
30+
pip install transformers==3.4.0 nltk spacy
31+
pip install wandb
32+
conda install -y -c conda-forge tensorboardx
33+
conda install -y -c conda-forge tensorboard
34+
35+
# for torch-geometric
36+
pip install torch-scatter==2.0.7 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
37+
pip install torch-cluster==1.5.9 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
38+
pip install torch-sparse==0.6.9 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
39+
pip install torch-spline-conv==1.2.1 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
40+
pip install torch-geometric==1.7.0 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
41+
```
42+
43+
44+
## 2. Download data
45+
46+
### Download and preprocess data yourself
47+
**Preprocessing the data yourself may take long, so if you want to directly download preprocessed data, please jump to the next subsection.**
48+
49+
Download the raw ConceptNet, CommonsenseQA, OpenBookQA data by using
50+
```
51+
./download_raw_data.sh
52+
```
53+
54+
You can preprocess these raw data by running
55+
```
56+
CUDA_VISIBLE_DEVICES=0 python preprocess.py -p <num_processes>
57+
```
58+
You can specify the GPU you want to use in the beginning of the command `CUDA_VISIBLE_DEVICES=...`. The script will:
59+
* Setup ConceptNet (e.g., extract English relations from ConceptNet, merge the original 42 relation types into 17 types)
60+
* Convert the QA datasets into .jsonl files (e.g., stored in `data/csqa/statement/`)
61+
* Identify all mentioned concepts in the questions and answers
62+
* Extract subgraphs for each q-a pair
63+
64+
The script to download and preprocess the [MedQA-USMLE](https://github.com/jind11/MedQA) data and the biomedical knowledge graph based on Disease Database and DrugBank is provided in `utils_biomed/`.
65+
66+
### Directly download preprocessed data
67+
For your convenience, if you don't want to preprocess the data yourself, you can download all the preprocessed data [here](https://drive.google.com/drive/folders/1T6B4nou5P3u-6jr0z6e3IkitO8fNVM6f?usp=sharing). Download them into the top-level directory of this repo and unzip them. Move the `medqa_usmle` and `ddb` folders into the `data/` directory.
68+
69+
### Resulting file structure
70+
71+
The resulting file structure should look like this:
72+
73+
```plain
74+
.
75+
├── README.md
76+
├── data/
77+
├── cpnet/ (prerocessed ConceptNet)
78+
├── csqa/
79+
├── train_rand_split.jsonl
80+
├── dev_rand_split.jsonl
81+
├── test_rand_split_no_answers.jsonl
82+
├── statement/ (converted statements)
83+
├── grounded/ (grounded entities)
84+
├── graphs/ (extracted subgraphs)
85+
├── ...
86+
├── obqa/
87+
├── medqa_usmle/
88+
└── ddb/
89+
```
90+
91+
## 3. Training GreaseLM
92+
To train GreaseLM on CommonsenseQA, run
93+
```
94+
CUDA_VISIBLE_DEVICES=0 ./run_greaselm.sh csqa --data_dir data/
95+
```
96+
You can specify up to 2 GPUs you want to use in the beginning of the command `CUDA_VISIBLE_DEVICES=...`.
97+
98+
Similarly, to train GreaseLM on OpenbookQA, run
99+
```
100+
CUDA_VISIBLE_DEVICES=0 ./run_greaselm.sh obqa --data_dir data/
101+
```
102+
103+
To train GreaseLM on MedQA-USMLE, run
104+
```
105+
CUDA_VISIBLE_DEVICES=0 ./run_greaselm__medqa_usmle.sh
106+
```
107+
108+
## 4. Pretrained model checkpoints
109+
You can download a pretrained GreaseLM model on CommonsenseQA [here](https://drive.google.com/file/d/1QPwLZFA6AQ-pFfDR6TWLdBAvm3c_HOUr/view?usp=sharing), which achieves an IH-dev acc. of `79.0` and an IH-test acc. of `74.0`.
110+
111+
You can also download a pretrained GreaseLM model on OpenbookQA [here](https://drive.google.com/file/d/1-QqyiQuU9xlN20vwfIaqYQ_uJMP8d7Pv/view?usp=sharing), which achieves an test acc. of `84.8`.
112+
113+
You can also download a pretrained GreaseLM model on MedQA-USMLE [here](https://drive.google.com/file/d/1j0QxiBiGbv0s9PhseSly6V6uiHWU5IEt/view?usp=sharing), which achieves an test acc. of `38.5`.
114+
115+
## 5. Evaluating a pretrained model checkpoint
116+
To evaluate a pretrained GreaseLM model checkpoint on CommonsenseQA, run
117+
```
118+
CUDA_VISIBLE_DEVICES=0 ./eval_greaselm.sh csqa --data_dir data/ --load_model_path /path/to/checkpoint
119+
```
120+
Again you can specify up to 2 GPUs you want to use in the beginning of the command `CUDA_VISIBLE_DEVICES=...`.
121+
122+
Similarly, to evaluate a pretrained GreaseLM model checkpoint on OpenbookQA, run
123+
```
124+
CUDA_VISIBLE_DEVICES=0 ./eval_greaselm.sh obqa --data_dir data/ --load_model_path /path/to/checkpoint
125+
```
126+
To evaluate a pretrained GreaseLM model checkpoint on MedQA-USMLE, run
127+
```
128+
INHERIT_BERT=1 CUDA_VISIBLE_DEVICES=0 ./eval_greaselm.sh medqa_usmle --data_dir data/ --load_model_path /path/to/checkpoint
129+
```
130+
131+
## 6. Use your own dataset
132+
- Convert your dataset to `{train,dev,test}.statement.jsonl` in .jsonl format (see `data/csqa/statement/train.statement.jsonl`)
133+
- Create a directory in `data/{yourdataset}/` to store the .jsonl files
134+
- Modify `preprocess.py` and perform subgraph extraction for your data
135+
- Modify `utils/parser_utils.py` to support your own dataset
136+
137+
## 7. Acknowledgment
138+
This repo is built upon the following work:
139+
```
140+
QA-GNN: Question Answering using Language Models and Knowledge Graphs
141+
https://github.com/michiyasunaga/qagnn
142+
```
143+
Many thanks to the authors and developers!

create_enviroment.sh

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
conda create -y -n greaselm python=3.8
2+
conda activate greaselm
3+
pip install numpy==1.18.3 tqdm
4+
pip install torch==1.8.0+cu101 torchvision -f https://download.pytorch.org/whl/torch_stable.html
5+
pip install transformers==3.4.0 nltk spacy
6+
pip install wandb
7+
conda install -y -c conda-forge tensorboardx
8+
conda install -y -c conda-forge tensorboard
9+
10+
# for torch-geometric
11+
pip install torch-scatter==2.0.7 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
12+
pip install torch-cluster==1.5.9 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
13+
pip install torch-sparse==0.6.9 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
14+
pip install torch-spline-conv==1.2.1 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
15+
pip install torch-geometric==1.7.0 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
16+
17+

create_enviroment_3090.sh

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# create enciroment for RTX3090
2+
# CUDA version >= 11.1, torch version >=1.7.0
3+
conda create -y -n glm python=3.8
4+
conda activate glm
5+
6+
# should use conda to install pytorch, use pip will get the OSError
7+
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
8+
9+
10+
# install torch-geometric from officiall doc
11+
# https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html
12+
pip install pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
13+
14+
pip install scipy transformers==3.4.0 tensorboardx nltk spacy networkx wandb
15+
16+
17+

download_raw_data.sh

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# download ConceptNet
2+
mkdir -p data/
3+
mkdir -p data/cpnet/
4+
wget -nc -P data/cpnet/ https://s3.amazonaws.com/conceptnet/downloads/2018/edges/conceptnet-assertions-5.6.0.csv.gz
5+
cd data/cpnet/
6+
yes n | gzip -d conceptnet-assertions-5.6.0.csv.gz
7+
# download ConceptNet entity embedding
8+
wget https://csr.s3-us-west-1.amazonaws.com/tzw.ent.npy
9+
cd ../../
10+
11+
12+
13+
14+
# download CommensenseQA dataset
15+
mkdir -p data/csqa/
16+
wget -nc -P data/csqa/ https://s3.amazonaws.com/commensenseqa/train_rand_split.jsonl
17+
wget -nc -P data/csqa/ https://s3.amazonaws.com/commensenseqa/dev_rand_split.jsonl
18+
wget -nc -P data/csqa/ https://s3.amazonaws.com/commensenseqa/test_rand_split_no_answers.jsonl
19+
20+
# create output folders
21+
mkdir -p data/csqa/grounded/
22+
mkdir -p data/csqa/graph/
23+
mkdir -p data/csqa/statement/
24+
25+
26+
27+
# download OpenBookQA dataset
28+
wget -nc -P data/obqa/ https://s3-us-west-2.amazonaws.com/ai2-website/data/OpenBookQA-V1-Sep2018.zip
29+
yes n | unzip data/obqa/OpenBookQA-V1-Sep2018.zip -d data/obqa/
30+
31+
# create output folders
32+
mkdir -p data/obqa/fairseq/official/
33+
mkdir -p data/obqa/grounded/
34+
mkdir -p data/obqa/graph/
35+
mkdir -p data/obqa/statement/

0 commit comments

Comments
 (0)