Skip to content

Commit 54b0fe5

Browse files
committed
add attention test
1 parent 47ec36f commit 54b0fe5

2 files changed

Lines changed: 1595 additions & 4 deletions

File tree

model/modeling_roberta.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,14 +166,29 @@ def pool_input_embeds(inputs_embeds,pool_mask):
166166
inputs_embeds = torch.randint(1,10,[bs, seq_len,3], dtype=torch.float)
167167

168168
# pool_mask
169-
pool_mask = torch.zeros([bs, seq_len], dtype=torch.long)
169+
pool_mask = torch.zeros([bs, seq_len], dtype=torch.float)
170170
pool_mask[0][0:2] = 1
171171
pool_mask[0][3:5] = 1
172172
pool_mask[-1][2:5] = 1
173173

174-
word_embeddings = nn.Embedding(3,2)
174+
we = nn.Embedding(3,2)
175175
idx = torch.tensor([0,0,1])
176176
# token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
177-
print(word_embeddings(idx).shape)
177+
print(pool_mask)
178178

179-
# pool_input_embeds(inputs_embeds,pool_mask)
179+
# pool_input_embeds(inputs_embeds,pool_mask)
180+
181+
lm = nn.LayerNorm(5, eps=1e-12)
182+
print(lm.state_dict().keys())
183+
print("参数gamma shape: ", lm.state_dict()['weight'])
184+
print("参数beta shape: ", lm.state_dict()['bias'])
185+
186+
dpo = nn.Dropout(0.7)
187+
188+
we = lm(pool_mask)
189+
print(we)
190+
we = dpo(we)
191+
print(we)
192+
193+
194+

0 commit comments

Comments
 (0)