@@ -804,13 +804,15 @@ def __init__(self, config, args = {},k=5, n_ntype=4, n_etype=38, hidden_size=200
804804 self .sep_ie_layers = sep_ie_layers
805805 self .args = args
806806 if sep_ie_layers :
807- self .ie_layers = nn .ModuleList ([layers .MLP (self .sent_dim + concept_dim , ie_dim , self .sent_dim + concept_dim , ie_layer_num , p_fc ) for _ in range (k )])
807+ self .ie_layers = nn .ModuleList ([layers .MLP (self .sent_dim + concept_dim * args . mix_number , ie_dim , self .sent_dimv * args . mix_number + concept_dim * args . mix_number , ie_layer_num , p_fc ) for _ in range (k )])
808808 else :
809- self .ie_layer = layers .MLP (self .sent_dim + concept_dim , ie_dim , self .sent_dim + concept_dim , ie_layer_num , p_fc )
809+ self .ie_layer = layers .MLP (self .sent_dim * args . mix_number + concept_dim * args . mix_number , ie_dim , self .sent_dim * args . mix_number + concept_dim * args . mix_number , ie_layer_num , p_fc )
810810
811811 self .concept_dim = concept_dim
812812 self .num_hidden_layers = config .num_hidden_layers
813813 self .info_exchange = info_exchange
814+ self .use_concept = args .use_concept
815+ self .mix_number = args .mix_number
814816
815817 def forward (self , hidden_states , attention_mask , special_tokens_mask , head_mask , _X , edge_index , edge_type , _node_type , _node_feature_extra , special_nodes_mask , output_attentions = False , output_hidden_states = True ,pool_mask = []):
816818 """
@@ -851,17 +853,57 @@ def forward(self, hidden_states, attention_mask, special_tokens_mask, head_mask,
851853
852854 # Exchange info between LM and GNN hidden states (Modality interaction)
853855 if self .info_exchange == True or (self .info_exchange == "every-other-layer" and (i - self .num_hidden_layers + self .k ) % 2 == 0 ):
856+ # X = _X.view(bs, -1, _X.size(1)) # [bs, max_num_nodes, node_dim]
857+ # context_node_lm_feats = hidden_states[:, 0, :] # [bs, sent_dim]
858+ # context_node_gnn_feats = X[:, 0, :] # [bs, node_dim]
859+ # # [32,1224]
860+ # context_node_feats = torch.cat([context_node_lm_feats, context_node_gnn_feats], dim=1)
861+
862+ # if self.sep_ie_layers:
863+ # context_node_feats = self.ie_layers[gnn_layer_index](context_node_feats)
864+ # else:
865+ # context_node_feats = self.ie_layer(context_node_feats)
866+ # context_node_lm_feats, context_node_gnn_feats = torch.split(context_node_feats, [context_node_lm_feats.size(1), context_node_gnn_feats.size(1)], dim=1)
867+ # hidden_states[:, 0, :] = context_node_lm_feats
868+ # X[:, 0, :] = context_node_gnn_feats
869+ # _X = X.view_as(_X)
854870 X = _X .view (bs , - 1 , _X .size (1 )) # [bs, max_num_nodes, node_dim]
855- context_node_lm_feats = hidden_states [:, 0 , :] # [bs, sent_dim]
856- context_node_gnn_feats = X [:, 0 , :] # [bs, node_dim]
857- context_node_feats = torch .cat ([context_node_lm_feats , context_node_gnn_feats ], dim = 1 )
858- if self .sep_ie_layers :
859- context_node_feats = self .ie_layers [gnn_layer_index ](context_node_feats )
871+ # 全量融合,细粒度融合
872+ if not self .args .all_mix :
873+ context_node_lm_feats = hidden_states [:, :self .mix_number , :] # [bs, sent_dim]
874+ context_node_lm_feats = context_node_lm_feats .view (bs ,1 ,- 1 )
875+ context_node_lm_feats = torch .squeeze (context_node_lm_feats )
876+ context_node_gnn_feats = X [:, :self .mix_number , :] # [bs, node_dim]
877+ context_node_gnn_feats = context_node_gnn_feats .view (bs ,1 ,- 1 )
878+ context_node_gnn_feats = torch .squeeze (context_node_gnn_feats )
879+ context_node_feats = torch .cat ([context_node_lm_feats , context_node_gnn_feats ], dim = 1 )
880+
881+ if self .sep_ie_layers :
882+ context_node_feats = self .ie_layers [gnn_layer_index ](context_node_feats )
883+ else :
884+ context_node_feats = self .ie_layer (context_node_feats )
885+ context_node_lm_feats , context_node_gnn_feats = torch .split (context_node_feats , [context_node_lm_feats .size (1 ), context_node_gnn_feats .size (1 )], dim = 1 )
886+ hidden_states [:, :self .mix_number , :] = context_node_lm_feats .view (bs ,self .mix_number ,- 1 )
887+ X [:, :self .mix_number , :] = context_node_gnn_feats .view (bs ,self .mix_number ,- 1 )
860888 else :
861- context_node_feats = self .ie_layer (context_node_feats )
862- context_node_lm_feats , context_node_gnn_feats = torch .split (context_node_feats , [context_node_lm_feats .size (1 ), context_node_gnn_feats .size (1 )], dim = 1 )
863- hidden_states [:, 0 , :] = context_node_lm_feats
864- X [:, 0 , :] = context_node_gnn_feats
889+ l = hidden_states .shape [1 ]
890+ h = X .shape [1 ]
891+ context_node_lm_feats = hidden_states # [bs, sent_dim]
892+ context_node_lm_feats = context_node_lm_feats .view (bs ,1 ,- 1 )
893+ context_node_lm_feats = torch .squeeze (context_node_lm_feats )
894+ context_node_gnn_feats = X # [bs, node_dim]
895+ context_node_gnn_feats = context_node_lm_feats .view (bs ,1 ,- 1 )
896+ context_node_gnn_feats = torch .squeeze (context_node_lm_feats )
897+
898+ context_node_feats = torch .cat ([context_node_lm_feats , context_node_gnn_feats ], dim = 1 )
899+
900+ if self .sep_ie_layers :
901+ context_node_feats = self .ie_layers [gnn_layer_index ](context_node_feats )
902+ else :
903+ context_node_feats = self .ie_layer (context_node_feats )
904+ context_node_lm_feats , context_node_gnn_feats = torch .split (context_node_feats , [context_node_lm_feats .size (1 ), context_node_gnn_feats .size (1 )], dim = 1 )
905+ hidden_states = context_node_lm_feats .view (bs ,l ,- 1 )
906+ X = context_node_gnn_feats .view (bs ,h ,- 1 )
865907 _X = X .view_as (_X )
866908
867909 # Add last layer
0 commit comments