Skip to content

Commit 3fe4216

Browse files
committed
add mix number para
1 parent 0f6800c commit 3fe4216

4 files changed

Lines changed: 193 additions & 13 deletions

File tree

greaselm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def load_data(args, devices, kg):
6565
subsample=args.subsample, n_train=args.n_train, debug=args.debug, cxt_node_connects_all=args.cxt_node_connects_all, kg=kg,emp=args.emp,
6666
train_tagged_path = args.train_tagged,
6767
dev_tagged_path=args.dev_tagged,
68-
test_tagged_path=args.test_tagged,
68+
test_tagged_path=args.test_tagged
6969
)
7070

7171
# GLM Dataloader base version
@@ -638,7 +638,9 @@ def main(args):
638638
# MyGLM
639639
parser.add_argument('--emp',default=False,type = utils.bool_flag)
640640
parser.add_argument('--is_compress',default=False,type = utils.bool_flag)
641-
641+
parser.add_argument('--use_concept',default=False,type = utils.bool_flag)
642+
parser.add_argument('--mix_number',default=1,type = int)
643+
parser.add_argument('--all_mix',default=False,type = utils.bool_flag)
642644

643645

644646
args = parser.parse_args()

model/modeling_roberta.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ def __init__(self):
88
super().__init__()
99
self.embeddings = RobertaEmbeddings()
1010

11+
# 直接合并向量表示
1112
class RobertaPoolEmbeddings(RobertaEmbeddings):
1213

1314
def __init__(self, config):
@@ -120,7 +121,120 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
120121
)
121122
return position_ids.unsqueeze(0).expand(input_shape)
122123

124+
#添加实体层的embedding
125+
# class RobertaEntityEmbeddings(RobertaEmbeddings):
123126

127+
# def __init__(self, config):
128+
# super().__init__(config = config)
129+
# self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
130+
# self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
131+
# self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
132+
# self.entity_type_embeddings = nn.Embedding(2, config.hidden_size)
133+
134+
# # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
135+
# # any TensorFlow checkpoint file
136+
# self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
137+
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
138+
# # position_ids (1, len position emb) is contiguous in memory and exported when serialized
139+
# self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
140+
# self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
141+
# self.register_buffer(
142+
# "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
143+
# )
144+
145+
# # End copy
146+
# self.padding_idx = config.pad_token_id
147+
# self.position_embeddings = nn.Embedding(
148+
# config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
149+
# )
150+
151+
# def forward(
152+
# self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, pool_mask=None,past_key_values_length=0
153+
# ):
154+
# if position_ids is None:
155+
# if input_ids is not None:
156+
# # Create the position ids from the input token ids. Any padded tokens remain padded.
157+
# position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
158+
# else:
159+
# position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
160+
161+
# if input_ids is not None:
162+
# input_shape = input_ids.size()
163+
# else:
164+
# input_shape = inputs_embeds.size()[:-1]
165+
166+
# seq_length = input_shape[1]
167+
168+
# # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
169+
# # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
170+
# # issue #5664
171+
# if token_type_ids is None:
172+
# if hasattr(self, "token_type_ids"):
173+
# buffered_token_type_ids = self.token_type_ids[:, :seq_length]
174+
# buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
175+
# token_type_ids = buffered_token_type_ids_expanded
176+
# else:
177+
# token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
178+
179+
# if inputs_embeds is None:
180+
# inputs_embeds = self.word_embeddings(input_ids)
181+
# # print(inputs_embeds.shape)
182+
# # print(input_ids.shape)
183+
# # print(pool_mask.shape)
184+
# inputs_embeds = self.pool_input_embeds(inputs_embeds,pool_mask)
185+
# token_type_embeddings = self.token_type_embeddings(token_type_ids)
186+
187+
# embeddings = inputs_embeds + token_type_embeddings
188+
# if self.position_embedding_type == "absolute":
189+
# position_embeddings = self.position_embeddings(position_ids)
190+
# embeddings += position_embeddings
191+
192+
193+
# embeddings = self.LayerNorm(embeddings)
194+
195+
# embeddings = self.dropout(embeddings)
196+
197+
# return embeddings
198+
199+
# def pool_input_embeds(self,inputs_embeds,pool_mask):
200+
# # print(pool_mask)
201+
# # print(inputs_embeds)
202+
# for bs in range(pool_mask.shape[0]):
203+
# row_pool = pool_mask[bs]
204+
# start,end = 0,0
205+
# for i in range(pool_mask.shape[1]-1):
206+
# if row_pool[i]==0 and row_pool[i+1]==1:
207+
# start = i+1
208+
# elif row_pool[i]==1 and i==0:
209+
# start = i
210+
# elif row_pool[i+1] ==1 and i == pool_mask.shape[1]-2:
211+
# end = i+1
212+
# elif row_pool[i]==1 and row_pool[i+1]==0:
213+
# end = i
214+
215+
# if end != 0:
216+
# inputs_embeds[bs][start:end+1] = torch.mean(inputs_embeds[bs][start:end+1],dim=0,keepdim=True)
217+
# # print(start,end+1)
218+
# start,end = 0,0
219+
220+
# # print(inputs_embeds)
221+
# return inputs_embeds
222+
223+
# def create_position_ids_from_inputs_embeds(self, inputs_embeds):
224+
# """
225+
# We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
226+
# Args:
227+
# inputs_embeds: torch.Tensor
228+
# Returns: torch.Tensor
229+
# """
230+
# input_shape = inputs_embeds.size()[:-1]
231+
# sequence_length = input_shape[1]
232+
233+
# position_ids = torch.arange(
234+
# self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
235+
# )
236+
# return position_ids.unsqueeze(0).expand(input_shape)
237+
124238
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
125239
"""
126240
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols

modeling/modeling_greaselm.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -804,13 +804,15 @@ def __init__(self, config, args = {},k=5, n_ntype=4, n_etype=38, hidden_size=200
804804
self.sep_ie_layers = sep_ie_layers
805805
self.args = args
806806
if sep_ie_layers:
807-
self.ie_layers = nn.ModuleList([layers.MLP(self.sent_dim + concept_dim, ie_dim, self.sent_dim + concept_dim, ie_layer_num, p_fc) for _ in range(k)])
807+
self.ie_layers = nn.ModuleList([layers.MLP(self.sent_dim + concept_dim*args.mix_number, ie_dim, self.sent_dimv*args.mix_number +concept_dim*args.mix_number, ie_layer_num, p_fc) for _ in range(k)])
808808
else:
809-
self.ie_layer = layers.MLP(self.sent_dim + concept_dim, ie_dim, self.sent_dim + concept_dim, ie_layer_num, p_fc)
809+
self.ie_layer = layers.MLP(self.sent_dim*args.mix_number + concept_dim*args.mix_number, ie_dim, self.sent_dim*args.mix_number + concept_dim*args.mix_number, ie_layer_num, p_fc)
810810

811811
self.concept_dim = concept_dim
812812
self.num_hidden_layers = config.num_hidden_layers
813813
self.info_exchange = info_exchange
814+
self.use_concept = args.use_concept
815+
self.mix_number = args.mix_number
814816

815817
def forward(self, hidden_states, attention_mask, special_tokens_mask, head_mask, _X, edge_index, edge_type, _node_type, _node_feature_extra, special_nodes_mask, output_attentions=False, output_hidden_states=True,pool_mask=[]):
816818
"""
@@ -851,17 +853,57 @@ def forward(self, hidden_states, attention_mask, special_tokens_mask, head_mask,
851853

852854
# Exchange info between LM and GNN hidden states (Modality interaction)
853855
if self.info_exchange == True or (self.info_exchange == "every-other-layer" and (i - self.num_hidden_layers + self.k) % 2 == 0):
856+
# X = _X.view(bs, -1, _X.size(1)) # [bs, max_num_nodes, node_dim]
857+
# context_node_lm_feats = hidden_states[:, 0, :] # [bs, sent_dim]
858+
# context_node_gnn_feats = X[:, 0, :] # [bs, node_dim]
859+
# # [32,1224]
860+
# context_node_feats = torch.cat([context_node_lm_feats, context_node_gnn_feats], dim=1)
861+
862+
# if self.sep_ie_layers:
863+
# context_node_feats = self.ie_layers[gnn_layer_index](context_node_feats)
864+
# else:
865+
# context_node_feats = self.ie_layer(context_node_feats)
866+
# context_node_lm_feats, context_node_gnn_feats = torch.split(context_node_feats, [context_node_lm_feats.size(1), context_node_gnn_feats.size(1)], dim=1)
867+
# hidden_states[:, 0, :] = context_node_lm_feats
868+
# X[:, 0, :] = context_node_gnn_feats
869+
# _X = X.view_as(_X)
854870
X = _X.view(bs, -1, _X.size(1)) # [bs, max_num_nodes, node_dim]
855-
context_node_lm_feats = hidden_states[:, 0, :] # [bs, sent_dim]
856-
context_node_gnn_feats = X[:, 0, :] # [bs, node_dim]
857-
context_node_feats = torch.cat([context_node_lm_feats, context_node_gnn_feats], dim=1)
858-
if self.sep_ie_layers:
859-
context_node_feats = self.ie_layers[gnn_layer_index](context_node_feats)
871+
# 全量融合,细粒度融合
872+
if not self.args.all_mix:
873+
context_node_lm_feats = hidden_states[:, :self.mix_number, :] # [bs, sent_dim]
874+
context_node_lm_feats = context_node_lm_feats.view(bs,1,-1)
875+
context_node_lm_feats = torch.squeeze(context_node_lm_feats)
876+
context_node_gnn_feats = X[:, :self.mix_number, :] # [bs, node_dim]
877+
context_node_gnn_feats = context_node_gnn_feats.view(bs,1,-1)
878+
context_node_gnn_feats = torch.squeeze(context_node_gnn_feats)
879+
context_node_feats = torch.cat([context_node_lm_feats, context_node_gnn_feats], dim=1)
880+
881+
if self.sep_ie_layers:
882+
context_node_feats = self.ie_layers[gnn_layer_index](context_node_feats)
883+
else:
884+
context_node_feats = self.ie_layer(context_node_feats)
885+
context_node_lm_feats, context_node_gnn_feats = torch.split(context_node_feats, [context_node_lm_feats.size(1), context_node_gnn_feats.size(1)], dim=1)
886+
hidden_states[:, :self.mix_number, :] = context_node_lm_feats.view(bs,self.mix_number,-1)
887+
X[:, :self.mix_number, :] = context_node_gnn_feats.view(bs,self.mix_number,-1)
860888
else:
861-
context_node_feats = self.ie_layer(context_node_feats)
862-
context_node_lm_feats, context_node_gnn_feats = torch.split(context_node_feats, [context_node_lm_feats.size(1), context_node_gnn_feats.size(1)], dim=1)
863-
hidden_states[:, 0, :] = context_node_lm_feats
864-
X[:, 0, :] = context_node_gnn_feats
889+
l = hidden_states.shape[1]
890+
h = X.shape[1]
891+
context_node_lm_feats = hidden_states # [bs, sent_dim]
892+
context_node_lm_feats = context_node_lm_feats.view(bs,1,-1)
893+
context_node_lm_feats = torch.squeeze(context_node_lm_feats)
894+
context_node_gnn_feats = X # [bs, node_dim]
895+
context_node_gnn_feats = context_node_lm_feats.view(bs,1,-1)
896+
context_node_gnn_feats = torch.squeeze(context_node_lm_feats)
897+
898+
context_node_feats = torch.cat([context_node_lm_feats, context_node_gnn_feats], dim=1)
899+
900+
if self.sep_ie_layers:
901+
context_node_feats = self.ie_layers[gnn_layer_index](context_node_feats)
902+
else:
903+
context_node_feats = self.ie_layer(context_node_feats)
904+
context_node_lm_feats, context_node_gnn_feats = torch.split(context_node_feats, [context_node_lm_feats.size(1), context_node_gnn_feats.size(1)], dim=1)
905+
hidden_states = context_node_lm_feats.view(bs,l,-1)
906+
X = context_node_gnn_feats.view(bs,h,-1)
865907
_X = X.view_as(_X)
866908

867909
# Add last layer

test.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,31 @@
1313
./run_greaselm.sh obqa --data_dir data/ --emp True --use_wandb True -k 3
1414
./run_greaselm.sh obqa --data_dir data/ --emp True --use_wandb True -k 7
1515

16+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True -k 3
17+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True -k 7
1618

1719

1820

21+
2023年03月29日
22+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True --mix_number 2
23+
./run_greaselm.sh obqa --data_dir data/ --use_wandb True --all_mix True
24+
25+
2023年04月03日
26+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True --mix_number 2
27+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True --mix_number 3
28+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True --mix_number 5
29+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True --mix_number 10
30+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True --all_mix True (需要重新测试)
31+
32+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True --gnn_dim 100
33+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True --gnn_dim 50
34+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True --gnn_dim 300
35+
./run_greaselm.sh obqa --data_dir data/ --emp False --use_wandb True --gnn_dim 400
36+
37+
38+
01245
39+
40+
1941

2042

2143

0 commit comments

Comments
 (0)