import itertools
import json
import pickle
import os
import numpy as np
import torch
from tqdm import tqdm
from transformers import (OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP)
from transformers import (OpenAIGPTTokenizer, BertTokenizer, XLNetTokenizer, RobertaTokenizer)
try:
from transformers import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers import AlbertTokenizer
except:
pass
from preprocess_utils import conceptnet
from utils import utils
# Map each supported model family to the pretrained model names transformers lists for it.
MODEL_CLASS_TO_NAME = {
    'gpt': list(OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
    'bert': list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
    'xlnet': list(XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
    'roberta': list(ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
    'lstm': ['lstm'],
}
try:
    MODEL_CLASS_TO_NAME['albert'] = list(ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())
except NameError:
    # ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP is only defined when the optional
    # ALBERT import at the top of the file succeeded; otherwise skip ALBERT.
    pass
# Reverse lookup: pretrained model name -> model family.
MODEL_NAME_TO_CLASS = {model_name: model_class for model_class, model_name_list in MODEL_CLASS_TO_NAME.items() for model_name in model_name_list}
# Add SapBERT configuration: it is treated as a BERT-family model.
model_name = 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext'
MODEL_NAME_TO_CLASS[model_name] = 'bert'
GPT_SPECIAL_TOKENS = ['_start_', '_delimiter_', '_classify_']
class MultiGPUSparseAdjDataBatchGenerator(object):
    """A data generator that batches the data and moves them to the corresponding devices.

    Iterating yields, for each consecutive chunk of ``batch_size`` entries of
    ``indexes``, the tuple
    ``(qids, labels, *tensors0, *lists0, *tensors1, *lists1, edge_index, edge_type)``.

    Device placement:
        * labels, every ``tensors0`` tensor, ``tensors1[1:]``, every ``lists1``
          entry and the edge data go to ``device1``;
        * ``tensors1[0]`` and every ``lists0`` entry go to ``device0``.
    """

    def __init__(self, device0, device1, batch_size, indexes, qids, labels,
                 tensors0=None, lists0=None, tensors1=None, lists1=None, adj_data=None):
        """
        Args:
            device0, device1: target devices (anything accepted by ``Tensor.to``).
            batch_size: maximum number of examples per yielded batch.
            indexes: 1-D tensor of example indexes, iterated in the given order.
            qids: sequence of question ids, indexed by the entries of ``indexes``.
            labels: tensor indexed by a batch of ``indexes``.
            tensors0, tensors1: lists of tensors indexed by a batch of ``indexes``.
            lists0, lists1: lists of per-example sequences (entries must be
                tensors or nested lists/tuples of tensors).
            adj_data: pair ``(edge_index_all, edge_type_all)`` of per-example
                edge data; required for iteration.
        """
        self.device0 = device0
        self.device1 = device1
        self.batch_size = batch_size
        self.indexes = indexes
        self.qids = qids
        self.labels = labels
        # None defaults so instances never share one mutable default list.
        self.tensors0 = [] if tensors0 is None else tensors0
        self.lists0 = [] if lists0 is None else lists0
        self.tensors1 = [] if tensors1 is None else tensors1
        self.lists1 = [] if lists1 is None else lists1
        self.adj_data = adj_data

    def __len__(self):
        """Number of batches, i.e. ceil(len(indexes) / batch_size)."""
        return (self.indexes.size(0) - 1) // self.batch_size + 1

    def __iter__(self):
        bs = self.batch_size
        n = self.indexes.size(0)
        for a in range(0, n, bs):
            b = min(n, a + bs)
            batch_indexes = self.indexes[a:b]
            batch_qids = [self.qids[idx] for idx in batch_indexes]
            batch_labels = self._to_device(self.labels[batch_indexes], self.device1)
            batch_tensors0 = [self._to_device(x[batch_indexes], self.device1) for x in self.tensors0]
            # The first decoder tensor is placed on device0, the rest on device1.
            # Placing it directly avoids a wasted copy to device1 first, and an
            # empty tensors1 no longer raises IndexError.
            batch_tensors1 = [
                self._to_device(x[batch_indexes], self.device0 if i == 0 else self.device1)
                for i, x in enumerate(self.tensors1)
            ]
            batch_lists0 = [self._to_device([x[i] for i in batch_indexes], self.device0) for x in self.lists0]
            batch_lists1 = [self._to_device([x[i] for i in batch_indexes], self.device1) for x in self.lists1]
            edge_index_all, edge_type_all = self.adj_data
            #edge_index_all: nested list of shape (n_samples, num_choice), where each entry is tensor[2, E]
            #edge_type_all: nested list of shape (n_samples, num_choice), where each entry is tensor[E, ]
            edge_index = self._to_device([edge_index_all[i] for i in batch_indexes], self.device1)
            edge_type = self._to_device([edge_type_all[i] for i in batch_indexes], self.device1)
            yield tuple([batch_qids, batch_labels, *batch_tensors0, *batch_lists0, *batch_tensors1, *batch_lists1, edge_index, edge_type])

    def _to_device(self, obj, device):
        """Recursively move ``obj`` (a tensor or nested list/tuple of tensors) to ``device``.

        Lists and tuples are both returned as lists.
        """
        if isinstance(obj, (tuple, list)):
            return [self._to_device(item, device) for item in obj]
        return obj.to(device)
class GreaseLM_DataLoader(object):
def __init__(self, train_statement_path, train_adj_path,
             dev_statement_path, dev_adj_path,
             test_statement_path, test_adj_path,
             batch_size, eval_batch_size, device, model_name, max_node_num=200, max_seq_length=128,
             is_inhouse=False, inhouse_train_qids_path=None,
             subsample=1.0, n_train=-1, debug=False, cxt_node_connects_all=False, kg="cpnet"):
    """Load the train, dev and (optionally) test splits.

    Args:
        train_statement_path, dev_statement_path, test_statement_path: paths
            given to ``load_input_tensors`` for each split's LM inputs.
            ``test_statement_path`` may be None to skip the test split.
        train_adj_path, dev_adj_path, test_adj_path: paths given to
            ``load_sparse_adj_data_with_contextnode`` for each split's graph inputs.
        batch_size: batch size used by ``train()``.
        eval_batch_size: batch size used by ``train_eval()``, ``dev()`` and ``test()``.
        device: pair unpacked into ``self.device0`` and ``self.device1``.
        model_name: key into ``MODEL_NAME_TO_CLASS``; also used to load the tokenizer.
        max_node_num: maximum number of graph nodes per sample.
        max_seq_length: maximum LM sequence length.
        is_inhouse: if True, split the training set into in-house train/test
            subsets using the qids listed one per line in ``inhouse_train_qids_path``.
        subsample: fraction in (0, 1] of the training set to keep.
        n_train: if >= 0, number of training examples to keep; overrides ``subsample``.
        debug: if True, skip the split-consistency asserts; ``train``/``dev``/``test``
            then iterate only over ``self.debug_sample_size`` examples.
        cxt_node_connects_all: stored on ``self``; consumed by the adjacency loader.
        kg: knowledge-graph name passed to ``load_resources`` ("cpnet" or "ddb").
    """
    super().__init__()
    self.batch_size = batch_size
    self.eval_batch_size = eval_batch_size
    self.device0, self.device1 = device
    self.is_inhouse = is_inhouse
    self.debug = debug
    self.model_name = model_name
    self.max_node_num = max_node_num
    self.debug_sample_size = 32
    self.cxt_node_connects_all = cxt_node_connects_all
    self.model_type = MODEL_NAME_TO_CLASS[model_name]
    self.load_resources(kg)

    # Load training data
    print('train_statement_path', train_statement_path)
    self.train_qids, self.train_labels, self.train_encoder_data, train_concepts_by_sents_list = self.load_input_tensors(train_statement_path, max_seq_length)
    # num_choice must be set before the adjacency loader runs (it reads self.num_choice).
    num_choice = self.train_encoder_data[0].size(1)
    self.num_choice = num_choice
    print('num_choice', num_choice)
    *self.train_decoder_data, self.train_adj_data = self.load_sparse_adj_data_with_contextnode(train_adj_path, max_node_num, train_concepts_by_sents_list)
    if not debug:
        assert all(len(self.train_qids) == len(self.train_adj_data[0]) == x.size(0) for x in [self.train_labels] + self.train_encoder_data + self.train_decoder_data)
    print("Finish loading training data.")

    # Load dev data
    self.dev_qids, self.dev_labels, self.dev_encoder_data, dev_concepts_by_sents_list = self.load_input_tensors(dev_statement_path, max_seq_length)
    *self.dev_decoder_data, self.dev_adj_data = self.load_sparse_adj_data_with_contextnode(dev_adj_path, max_node_num, dev_concepts_by_sents_list)
    if not debug:
        assert all(len(self.dev_qids) == len(self.dev_adj_data[0]) == x.size(0) for x in [self.dev_labels] + self.dev_encoder_data + self.dev_decoder_data)
    print("Finish loading dev data.")

    # Load test data
    if test_statement_path is not None:
        self.test_qids, self.test_labels, self.test_encoder_data, test_concepts_by_sents_list = self.load_input_tensors(test_statement_path, max_seq_length)
        *self.test_decoder_data, self.test_adj_data = self.load_sparse_adj_data_with_contextnode(test_adj_path, max_node_num, test_concepts_by_sents_list)
        if not debug:
            assert all(len(self.test_qids) == len(self.test_adj_data[0]) == x.size(0) for x in [self.test_labels] + self.test_encoder_data + self.test_decoder_data)
        print("Finish loading test data.")

    # If using inhouse split, we split the original training set into an inhouse training set and an inhouse test set.
    if self.is_inhouse:
        with open(inhouse_train_qids_path, 'r') as fin:
            inhouse_qids = set(line.strip() for line in fin)
        self.inhouse_train_indexes = torch.tensor([i for i, qid in enumerate(self.train_qids) if qid in inhouse_qids])
        self.inhouse_test_indexes = torch.tensor([i for i, qid in enumerate(self.train_qids) if qid not in inhouse_qids])

    # Optionally we can subsample the training set.
    assert 0. < subsample <= 1.
    if subsample < 1. or n_train >= 0:
        # An explicit n_train (anything other than the -1 default) overrides subsample.
        if n_train == -1:
            n_train = int(self.train_size() * subsample)
        assert n_train > 0
        if self.is_inhouse:
            self.inhouse_train_indexes = self.inhouse_train_indexes[:n_train]
        else:
            self.train_qids = self.train_qids[:n_train]
            self.train_labels = self.train_labels[:n_train]
            self.train_encoder_data = [x[:n_train] for x in self.train_encoder_data]
            self.train_decoder_data = [x[:n_train] for x in self.train_decoder_data]
            # train_adj_data is an (edge_index, edge_type) pair, so truncate each
            # member. Slicing the pair itself would leave the edge data untouched
            # and make the assert below fail.
            self.train_adj_data = tuple(x[:n_train] for x in self.train_adj_data)
            assert all(len(self.train_qids) == len(self.train_adj_data[0]) == x.size(0) for x in [self.train_labels] + self.train_encoder_data + self.train_decoder_data)
        assert self.train_size() == n_train
def train_size(self):
    """Return how many training examples are in use.

    In in-house mode this is the size of the in-house training subset;
    otherwise it is the number of training qids.
    """
    if not self.is_inhouse:
        return len(self.train_qids)
    return self.inhouse_train_indexes.size(0)
def dev_size(self):
    """Return the number of dev-set qids."""
    dev_qids = self.dev_qids
    return len(dev_qids)
def test_size(self):
if self.is_inhouse:
return self.inhouse_test_indexes.size(0)
else:
return len(self.test_qids) if hasattr(self, 'test_qids') else 0
def train(self):
    """Return a batch generator over the training split.

    The order is shuffled, except in debug mode, which takes the first
    ``debug_sample_size`` indexes in order. In in-house mode only the in-house
    training indexes are shuffled and used.
    """
    if self.debug:
        order = torch.arange(self.debug_sample_size)
    elif self.is_inhouse:
        perm = torch.randperm(self.inhouse_train_indexes.size(0))
        order = self.inhouse_train_indexes[perm]
    else:
        order = torch.randperm(len(self.train_qids))
    return MultiGPUSparseAdjDataBatchGenerator(
        self.device0, self.device1, self.batch_size, order,
        self.train_qids, self.train_labels,
        tensors0=self.train_encoder_data,
        tensors1=self.train_decoder_data,
        adj_data=self.train_adj_data,
    )
def train_eval(self):
    """Return a batch generator over every training example, in order, using ``eval_batch_size``."""
    # NOTE(review): this ignores both is_inhouse and debug, so in in-house mode it
    # also covers the in-house test examples. Confirm this is intended.
    every_index = torch.arange(len(self.train_qids))
    return MultiGPUSparseAdjDataBatchGenerator(
        self.device0, self.device1, self.eval_batch_size, every_index,
        self.train_qids, self.train_labels,
        tensors0=self.train_encoder_data,
        tensors1=self.train_decoder_data,
        adj_data=self.train_adj_data,
    )
def dev(self):
    """Return an in-order batch generator over the dev split.

    Debug mode uses only the first ``debug_sample_size`` examples.
    """
    n_dev = self.debug_sample_size if self.debug else len(self.dev_qids)
    return MultiGPUSparseAdjDataBatchGenerator(
        self.device0, self.device1, self.eval_batch_size, torch.arange(n_dev),
        self.dev_qids, self.dev_labels,
        tensors0=self.dev_encoder_data,
        tensors1=self.dev_decoder_data,
        adj_data=self.dev_adj_data,
    )
def test(self):
    """Return an in-order batch generator over the test examples.

    In in-house mode the examples come from the training data, selected by the
    in-house test indexes. Otherwise they come from the loaded test split.
    Debug mode uses only the first ``debug_sample_size`` indexes.
    """
    if self.is_inhouse:
        indexes = torch.arange(self.debug_sample_size) if self.debug else self.inhouse_test_indexes
        return MultiGPUSparseAdjDataBatchGenerator(
            self.device0, self.device1, self.eval_batch_size, indexes,
            self.train_qids, self.train_labels,
            tensors0=self.train_encoder_data,
            tensors1=self.train_decoder_data,
            adj_data=self.train_adj_data,
        )
    indexes = torch.arange(self.debug_sample_size) if self.debug else torch.arange(len(self.test_qids))
    return MultiGPUSparseAdjDataBatchGenerator(
        self.device0, self.device1, self.eval_batch_size, indexes,
        self.test_qids, self.test_labels,
        tensors0=self.test_encoder_data,
        tensors1=self.test_decoder_data,
        adj_data=self.test_adj_data,
    )
def load_resources(self, kg):
    """Load the tokenizer for ``self.model_type`` and the vocabulary/relations of ``kg``.

    Sets ``self.tokenizer``, ``self.id2concept``, ``self.concept2id`` and
    ``self.id2relation``.

    Args:
        kg: "cpnet" (vocab read from data/cpnet/concept.txt, relations from
            ``conceptnet.merged_relations``) or "ddb" (vocab read from
            data/ddb/vocab.txt, relations from the hard-coded list below).

    Raises:
        ValueError: if ``kg`` is not recognised, or if no tokenizer class is
            known for ``self.model_type``.
    """
    # Validate kg first so an invalid value fails before the tokenizer is loaded.
    if kg == "cpnet":
        vocab_path = "data/cpnet/concept.txt"
    elif kg == "ddb":
        vocab_path = "data/ddb/vocab.txt"
    else:
        raise ValueError("Invalid value for kg.")

    # Load the tokenizer
    tokenizer_classes = {
        'gpt': OpenAIGPTTokenizer,
        'bert': BertTokenizer,
        'xlnet': XLNetTokenizer,
        'roberta': RobertaTokenizer,
    }
    try:
        tokenizer_classes['albert'] = AlbertTokenizer
    except NameError:
        # AlbertTokenizer only exists if the optional ALBERT import succeeded.
        pass
    tokenizer_class = tokenizer_classes.get(self.model_type)
    if tokenizer_class is None:
        raise ValueError("No tokenizer available for model type {!r}.".format(self.model_type))
    self.tokenizer = tokenizer_class.from_pretrained(self.model_name)

    # Load the KG vocabulary: one concept per line; the line index is the concept id.
    with open(vocab_path, "r", encoding="utf8") as fin:
        self.id2concept = [w.strip() for w in fin]
    self.concept2id = {w: i for i, w in enumerate(self.id2concept)}

    if kg == "cpnet":
        self.id2relation = conceptnet.merged_relations
    else:
        # NOTE(review): entries are kept verbatim (including spellings such as
        # 'mabeallelicwith'); they may need to match the preprocessing output,
        # so verify before changing them.
        self.id2relation = [
            'belongstothecategoryof',
            'isacategory',
            'maycause',
            'isasubtypeof',
            'isariskfactorof',
            'isassociatedwith',
            'maycontraindicate',
            'interactswith',
            'belongstothedrugfamilyof',
            'child-parent',
            'isavectorfor',
            'mabeallelicwith',
            'seealso',
            'isaningradientof',
            'mabeindicatedby'
        ]
def load_input_tensors(self, input_jsonl_path, max_seq_length):
    """Construct input tensors for the LM component of the model.

    If a cache file derived from ``input_jsonl_path``, ``max_seq_length`` and
    ``self.model_type`` exists, it is loaded instead. Otherwise the tensors are
    built by the model-family loader and, unless ``self.debug`` is set, written
    to that cache file.

    Returns:
        The loader's (or cache's) result; ``__init__`` unpacks it as
        ``(qids, labels, encoder_data, concepts_by_sents_list)``.

    Raises:
        NotImplementedError: for the 'lstm' model type.
        ValueError: for a model type with no input loader.
    """
    # NOTE(review): the cache key encodes model_type but not model_name, so two
    # checkpoints of the same family (e.g. a generic BERT and SapBERT) share a
    # cache file although their vocabularies may differ. Confirm this is intended.
    cache_path = input_jsonl_path + "-sl{}".format(max_seq_length) + (("-" + self.model_type) if self.model_type != "roberta" else "") + '.loaded_cache'
    if os.path.exists(cache_path):
        # The cache is a pickle written by this loader; only load trusted files.
        with open(cache_path, 'rb') as f:
            return utils.CPU_Unpickler(f).load()

    if self.model_type in ('lstm',):
        raise NotImplementedError
    if self.model_type in ('gpt',):
        input_tensors = load_gpt_input_tensors(input_jsonl_path, max_seq_length)
    elif self.model_type in ('bert', 'xlnet', 'roberta', 'albert'):
        input_tensors = load_bert_xlnet_roberta_input_tensors(input_jsonl_path, max_seq_length, self.debug, self.tokenizer, self.debug_sample_size)
    else:
        # Previously this fell through to an UnboundLocalError on input_tensors.
        raise ValueError("No input loader for model type {!r}.".format(self.model_type))
    if not self.debug:
        utils.save_pickle(input_tensors, cache_path)
    return input_tensors
def load_sparse_adj_data_with_contextnode(self, adj_pk_path, max_node_num, concepts_by_sents_list):
"""Construct input tensors for the GNN component of the model."""
print("Loading sparse adj data...")
cache_path = adj_pk_path + "-nodenum{}".format(max_node_num) + ("-cntsall" if self.cxt_node_connects_all else "") + '.loaded_cache'
use_cache = True
if use_cache and not os.path.exists(cache_path):
use_cache = False
if use_cache:
with open(cache_path, 'rb') as f:
adj_lengths_ori, concept_ids, node_type_ids, node_scores, adj_lengths, edge_index, edge_type, half_n_rel, special_nodes_mask = utils.CPU_Unpickler(f).load()
else:
# Set special nodes and links
context_node = 0
n_special_nodes = 1
cxt2qlinked_rel = 0
cxt2alinked_rel = 1
half_n_rel = len(self.id2relation) + 2
if self.cxt_node_connects_all:
cxt2other_rel = half_n_rel
half_n_rel += 1
adj_concept_pairs = []
with open(adj_pk_path, "rb") as in_file:
try:
while True:
ex = pickle.load(in_file)
if type(ex) == dict:
adj_concept_pairs.append(ex)
elif type(ex) == list:
adj_concept_pairs.extend(ex)
else:
raise TypeError("Invalid type for ex.")
except EOFError:
pass
n_samples = len(adj_concept_pairs) #this is actually n_questions x n_choices
edge_index, edge_type = [], []
adj_lengths = torch.zeros((n_samples,), dtype=torch.long)
concept_ids = torch.full((n_samples, max_node_num), 1, dtype=torch.long)
node_type_ids = torch.full((n_samples, max_node_num), 2, dtype=torch.long) #default 2: "other node"
node_scores = torch.zeros((n_samples, max_node_num, 1), dtype=torch.float)
special_nodes_mask = torch.zeros(n_samples, max_node_num, dtype=torch.bool)
adj_lengths_ori = adj_lengths.clone()
if not concepts_by_sents_list:
concepts_by_sents_list = itertools.repeat(None)
for idx, (_data, cpts_by_sents) in tqdm(enumerate(zip(adj_concept_pairs, concepts_by_sents_list)), total=n_samples, desc='loading adj matrices'):
if self.debug and idx >= self.debug_sample_size * self.num_choice:
break
adj, concepts, qm, am, cid2score = _data['adj'], _data['concepts'], _data['qmask'], _data['amask'], _data['cid2score']
#adj: e.g. <4233x249 (n_nodes*half_n_rels x n_nodes) sparse matrix of type '