import math

import torch
from torch.autograd import Variable
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from torch_scatter import scatter


def make_one_hot(labels, C):
    '''
    Build a float one-hot matrix from a 1-D tensor of integer class labels.

    labels : integer tensor of shape (N,); each entry is a class index.
    C : int, number of classes (width of the result).

    Returns : float tensor of shape (N, C) on the same device as `labels`,
    holding 1 at column labels[i] of row i and 0 elsewhere.
    '''
    column_index = labels.unsqueeze(1)  # [N, 1], the index layout scatter_ expects
    canvas = torch.zeros(column_index.size(0), C, dtype=torch.float, device=labels.device)
    return Variable(canvas.scatter_(1, column_index.data, 1))


class GATConvE(MessagePassing):
    """
    Multi-head attention message-passing layer whose messages are conditioned
    on an embedding of (edge type, head node type, tail node type).

    Args:
        emb_dim (int): dimensionality of GNN hidden states
        n_ntype (int): number of node types (e.g. 4)
        n_etype (int): number of edge relation types (e.g. 38)
        edge_encoder: callable applied to the concatenated one-hot edge
            features of width (n_etype + 1) + 2 * n_ntype; `message` asserts
            that its output has width emb_dim
        head_count (int): number of attention heads; must divide emb_dim
        aggr (str): aggregation scheme forwarded to MessagePassing
    """

    def __init__(self, emb_dim, n_ntype, n_etype, edge_encoder, head_count=4, aggr="add"):
        super().__init__(aggr=aggr)

        assert emb_dim % 2 == 0
        self.emb_dim = emb_dim
        self.n_ntype = n_ntype
        self.n_etype = n_etype
        self.edge_encoder = edge_encoder

        # Attention projections.
        self.head_count = head_count
        assert emb_dim % head_count == 0
        self.dim_per_head = emb_dim // head_count
        projected_dim = head_count * self.dim_per_head
        # Key and message inputs are [node state | extra node feature | edge
        # embedding] (3 * emb_dim); the query input has no edge embedding.
        self.linear_key = nn.Linear(3 * emb_dim, projected_dim)
        self.linear_msg = nn.Linear(3 * emb_dim, projected_dim)
        self.linear_query = nn.Linear(2 * emb_dim, projected_dim)

        # Attention weights of the latest propagate() call; written by
        # message() and consumed (then cleared) by forward().
        self._alpha = None

        # Final MLP applied to the aggregated messages.
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x, edge_index, edge_type, node_type, node_feature_extra, return_attention_weights=False):
        """
        Run one round of attention-weighted message passing.

        x: [N, emb_dim]
        edge_index: [2, E]
        edge_type: [E,]  -> one-hot edge features [E, n_etype + 1]
        node_type: [N,]  -> one-hot head/tail features [E, 2 * n_ntype]
        node_feature_extra: [N, dim]; concatenated to x, and `message`
            asserts the concatenation has width 2 * emb_dim

        A self loop is appended for every node. Self loops use the extra
        relation slot at index n_etype and the node's own type as both head
        and tail type.

        Returns `out` of shape [N, emb_dim], or `(out, (edge_index, alpha))`
        when return_attention_weights is True, where edge_index includes the
        appended self loops ([2, E + N]) and alpha is [E + N, heads].
        """
        num_nodes = x.size(0)
        relation_width = self.n_etype + 1

        # One-hot features of the real edges: [relation | head type | tail type].
        relation_onehot = make_one_hot(edge_type, relation_width)  # [E, n_etype + 1]
        head_onehot = make_one_hot(node_type[edge_index[0]], self.n_ntype)  # head = src
        tail_onehot = make_one_hot(node_type[edge_index[1]], self.n_ntype)  # tail = tgt
        edge_rows = torch.cat([relation_onehot, head_onehot, tail_onehot], dim=1)  # [E, F]

        # One-hot features of the self loops, laid out the same way.
        loop_relation = torch.zeros(num_nodes, relation_width, device=relation_onehot.device)
        loop_relation[:, self.n_etype] = 1
        own_type_onehot = make_one_hot(node_type, self.n_ntype)  # [N, n_ntype]
        loop_rows = torch.cat([loop_relation, own_type_onehot, own_type_onehot], dim=1)  # [N, F]

        edge_embeddings = self.edge_encoder(torch.cat([edge_rows, loop_rows], dim=0))  # [E + N, emb_dim]

        # Append the self loops to edge_index, in the same order as loop_rows.
        node_ids = torch.arange(num_nodes, dtype=torch.long, device=edge_index.device)
        self_loops = node_ids.unsqueeze(0).repeat(2, 1)  # [2, N]
        edge_index = torch.cat([edge_index, self_loops], dim=1)  # [2, E + N]

        node_states = torch.cat([x, node_feature_extra], dim=1)
        aggr_out = self.propagate(edge_index, x=(node_states, node_states), edge_attr=edge_embeddings)  # [N, emb_dim]
        out = self.mlp(aggr_out)

        # Take the attention weights stashed by message() and clear the slot.
        alpha, self._alpha = self._alpha, None
        if not return_attention_weights:
            return out
        assert alpha is not None
        return out, (edge_index, alpha)

    def message(self, edge_index, x_i, x_j, edge_attr):
        """
        Compute the per-edge messages (i: tgt, j: src).

        Keys come from the target state plus edge embedding; queries and
        message values come from the source state (values also see the edge
        embedding). Scores are softmax-normalised over edges sharing a source
        node, then multiplied by that source node's edge count within
        `edge_index`.

        Returns a tensor of shape [E, emb_dim]. The un-rescaled attention
        weights ([E, heads]) are stored on `self._alpha`.
        """
        assert len(edge_attr.size()) == 2
        assert edge_attr.size(1) == self.emb_dim
        assert x_i.size(1) == x_j.size(1) == 2 * self.emb_dim
        assert x_i.size(0) == x_j.size(0) == edge_attr.size(0) == edge_index.size(1)

        heads = self.head_count
        head_dim = self.dim_per_head

        key = self.linear_key(torch.cat([x_i, edge_attr], dim=1)).view(-1, heads, head_dim)  # [E, heads, _dim]
        msg = self.linear_msg(torch.cat([x_j, edge_attr], dim=1)).view(-1, heads, head_dim)  # [E, heads, _dim]
        query = self.linear_query(x_j).view(-1, heads, head_dim)  # [E, heads, _dim]

        query = query / math.sqrt(head_dim)
        scores = (query * key).sum(dim=2)  # [E, heads]

        src_node_index = edge_index[0]  # [E,]
        alpha = softmax(scores, src_node_index)  # [E, heads], grouped by src node
        self._alpha = alpha

        # Rescale by the number of edges leaving each source node.
        num_edges = edge_index.size(1)
        num_src_slots = int(src_node_index.max()) + 1
        unit_weights = torch.ones(num_edges, dtype=torch.float, device=edge_index.device)
        edges_per_src = scatter(unit_weights, src_node_index, dim=0, dim_size=num_src_slots, reduce='sum')
        src_node_edge_count = edges_per_src[src_node_index]  # [E,]
        assert len(src_node_edge_count.size()) == 1 and len(src_node_edge_count) == num_edges
        alpha = alpha * src_node_edge_count.unsqueeze(1)  # [E, heads]

        weighted = msg * alpha.view(-1, heads, 1)  # [E, heads, _dim]
        return weighted.view(-1, heads * head_dim)  # [E, emb_dim]