File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -73,8 +73,12 @@ def forward(
7373 if self .position_embedding_type == "absolute" :
7474 position_embeddings = self .position_embeddings (position_ids )
7575 embeddings += position_embeddings
76+
77+
7678 embeddings = self .LayerNorm (embeddings )
79+
7780 embeddings = self .dropout (embeddings )
81+
7882 return embeddings
7983
8084 def pool_input_embeds (self ,inputs_embeds ,pool_mask ):
@@ -167,4 +171,9 @@ def pool_input_embeds(inputs_embeds,pool_mask):
167171 pool_mask [0 ][3 :5 ] = 1
168172 pool_mask [- 1 ][2 :5 ] = 1
169173
170- pool_input_embeds (inputs_embeds ,pool_mask )
174+ word_embeddings = nn .Embedding (3 ,2 )
175+ idx = torch .tensor ([0 ,0 ,1 ])
176+ # token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
177+ print (word_embeddings (idx ).shape )
178+
179+ # pool_input_embeds(inputs_embeds,pool_mask)
You can’t perform that action at this time.
0 commit comments