-
Notifications
You must be signed in to change notification settings - Fork 42
Expand file tree
/
Copy pathmodeling_gnn.py
More file actions
130 lines (105 loc) · 5.49 KB
/
Copy pathmodeling_gnn.py
File metadata and controls
130 lines (105 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import math
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax
from torch_scatter import scatter
def make_one_hot(labels, C):
    """Convert a 1-D tensor of integer class labels into a one-hot matrix.

    Args:
        labels: integer tensor of shape (N,). Every value must lie in
            ``[0, C)``; out-of-range values make ``scatter_`` raise.
        C: number of classes, i.e. the width of the encoding.

    Returns:
        float32 tensor of shape (N, C) on the same device as ``labels``;
        each row holds a single 1 in the column given by its label.
    """
    # scatter_ requires an int64 index with the same rank as the destination.
    index = labels.unsqueeze(1).long()
    # Allocate on the target device directly rather than building the buffer
    # on the host and copying it over.
    one_hot = torch.zeros(labels.size(0), C, dtype=torch.float32, device=labels.device)
    return one_hot.scatter_(1, index, 1)
class GATConvE(MessagePassing):
    """Multi-head attention message-passing layer with typed edges and nodes.

    Every edge is described by a one-hot of its relation type concatenated
    with one-hots of its head and tail node types; ``edge_encoder`` turns that
    descriptor into an ``emb_dim`` vector that feeds both the attention key
    and the message. A self-loop (using the extra relation id ``n_etype``) is
    added for every node before propagation.

    Args:
        emb_dim (int): dimensionality of GNN hidden states (must be even and
            divisible by ``head_count``)
        n_ntype (int): number of node types (e.g. 4)
        n_etype (int): number of edge relation types (e.g. 38)
        edge_encoder: callable mapping ``[*, n_etype + 1 + 2 * n_ntype]`` edge
            descriptors to ``[*, emb_dim]`` embeddings
        head_count (int): number of attention heads
        aggr (str): aggregation scheme handed to ``MessagePassing``
    """

    def __init__(self, emb_dim, n_ntype, n_etype, edge_encoder, head_count=4, aggr="add"):
        super().__init__(aggr=aggr)

        assert emb_dim % 2 == 0
        self.emb_dim = emb_dim

        self.n_ntype = n_ntype
        self.n_etype = n_etype
        self.edge_encoder = edge_encoder

        # For attention
        self.head_count = head_count
        assert emb_dim % head_count == 0
        self.dim_per_head = emb_dim // head_count
        # Key/message inputs are node features (2*emb_dim) + edge embedding
        # (emb_dim); the query sees node features only.
        self.linear_key = nn.Linear(3 * emb_dim, head_count * self.dim_per_head)
        self.linear_msg = nn.Linear(3 * emb_dim, head_count * self.dim_per_head)
        self.linear_query = nn.Linear(2 * emb_dim, head_count * self.dim_per_head)

        # Attention weights of the most recent message() call; consumed and
        # cleared by forward().
        self._alpha = None

        # For final MLP
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, emb_dim),
            torch.nn.BatchNorm1d(emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x, edge_index, edge_type, node_type, node_feature_extra, return_attention_weights=False):
        """Run one round of attention-weighted message passing.

        Args:
            x: [N, emb_dim] node states
            edge_index: [2, E] (row 0 = head/source, row 1 = tail/target)
            edge_type: [E,] relation ids in ``[0, n_etype)``
            node_type: [N,] node type ids in ``[0, n_ntype)``
            node_feature_extra: [N, dim] extra node features concatenated to
                ``x``; the concatenation must be ``2 * emb_dim`` wide (checked
                in ``message``)
            return_attention_weights: when True, also return the self-loop
                augmented edge index and the attention weights

        Returns:
            ``out`` of shape [N, emb_dim], or ``(out, (edge_index, alpha))``
            where ``edge_index`` is [2, E+N] and ``alpha`` is [E+N, heads].
        """
        n_nodes = x.size(0)

        # Prepare edge feature: relation one-hot, with the extra last slot
        # reserved for self-loops.
        edge_vec = make_one_hot(edge_type, self.n_etype + 1)  # [E, n_etype+1]
        self_edge_vec = torch.zeros(n_nodes, self.n_etype + 1, device=edge_vec.device)
        self_edge_vec[:, self.n_etype] = 1  # [N, n_etype+1]

        # Encode node types once and gather per edge endpoint.
        node_vec = make_one_hot(node_type, self.n_ntype)  # [N, n_ntype]
        head_vec = node_vec[edge_index[0]]  # [E, n_ntype]  head=src
        tail_vec = node_vec[edge_index[1]]  # [E, n_ntype]  tail=tgt
        headtail_vec = torch.cat([head_vec, tail_vec], dim=1)  # [E, 2*n_ntype]
        # A self-loop has the same node as head and tail.
        self_headtail_vec = torch.cat([node_vec, node_vec], dim=1)  # [N, 2*n_ntype]

        # Real edges first, self-loops last -- same order as edge_index below.
        edge_vec = torch.cat([edge_vec, self_edge_vec], dim=0)  # [E+N, n_etype+1]
        headtail_vec = torch.cat([headtail_vec, self_headtail_vec], dim=0)  # [E+N, 2*n_ntype]
        edge_embeddings = self.edge_encoder(torch.cat([edge_vec, headtail_vec], dim=1))  # [E+N, emb_dim]

        # Add self loops to edge_index
        loop_index = torch.arange(0, n_nodes, dtype=torch.long, device=edge_index.device)
        loop_index = loop_index.unsqueeze(0).repeat(2, 1)  # [2, N]
        edge_index = torch.cat([edge_index, loop_index], dim=1)  # [2, E+N]

        x = torch.cat([x, node_feature_extra], dim=1)
        x = (x, x)
        aggr_out = self.propagate(edge_index, x=x, edge_attr=edge_embeddings)  # [N, emb_dim]
        out = self.mlp(aggr_out)

        # Take the weights stashed by message() and clear the slot so they are
        # not kept alive or reused by a later call.
        alpha = self._alpha
        self._alpha = None

        if return_attention_weights:
            assert alpha is not None
            return out, (edge_index, alpha)
        return out

    def message(self, edge_index, x_i, x_j, edge_attr):  # i: tgt, j: src
        """Compute per-edge, per-head attention-weighted messages.

        Args:
            edge_index: [2, E] edge index including self-loops
            x_i: [E, 2*emb_dim] target-node features per edge
            x_j: [E, 2*emb_dim] source-node features per edge
            edge_attr: [E, emb_dim] edge embeddings

        Returns:
            [E, emb_dim] messages (heads flattened into the last dimension).
        """
        assert len(edge_attr.size()) == 2
        assert edge_attr.size(1) == self.emb_dim
        assert x_i.size(1) == x_j.size(1) == 2 * self.emb_dim
        assert x_i.size(0) == x_j.size(0) == edge_attr.size(0) == edge_index.size(1)

        heads, head_dim = self.head_count, self.dim_per_head
        key = self.linear_key(torch.cat([x_i, edge_attr], dim=1)).view(-1, heads, head_dim)  # [E, heads, _dim]
        msg = self.linear_msg(torch.cat([x_j, edge_attr], dim=1)).view(-1, heads, head_dim)  # [E, heads, _dim]
        query = self.linear_query(x_j).view(-1, heads, head_dim)  # [E, heads, _dim]

        # Scaled dot-product score per edge and head.
        query = query / math.sqrt(head_dim)
        scores = (query * key).sum(dim=2)  # [E, heads]

        # Normalise over all edges that share a source node.
        src_node_index = edge_index[0]  # [E,]
        alpha = softmax(scores, src_node_index)  # [E, heads]
        # Stash the normalised weights (before degree scaling) for forward().
        self._alpha = alpha

        # Adjust by outgoing degree of src: multiply each weight by the number
        # of edges leaving its source node.
        E = edge_index.size(1)  # n_edges
        N = int(src_node_index.max()) + 1  # n_nodes
        ones = torch.ones(E, dtype=torch.float, device=edge_index.device)
        src_node_edge_count = scatter(ones, src_node_index, dim=0, dim_size=N, reduce='sum')[src_node_index]  # [E,]
        assert len(src_node_edge_count.size()) == 1 and len(src_node_edge_count) == E
        alpha = alpha * src_node_edge_count.unsqueeze(1)  # [E, heads]

        out = msg * alpha.view(-1, heads, 1)  # [E, heads, _dim]
        return out.view(-1, heads * head_dim)  # [E, emb_dim]