embedding_util.py (forked from stacklok/codegate)
import os
import warnings

import torch
import torch.nn.functional as ftorch
from torch import Tensor
from transformers import AutoModel, AutoTokenizer

# The transformers library creates this warning internally; it does not
# affect our app and is safe to ignore.
warnings.filterwarnings(action="ignore", category=ResourceWarning)

# We won't have competing threads in this example app
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Initialize tokenizer and model for GTE-base
tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base")
model = AutoModel.from_pretrained("thenlper/gte-base")


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    # Zero out the hidden states of padding tokens, then average the
    # remaining token vectors to get a single fixed-size embedding.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def generate_embeddings(text):
    # Tokenize, truncating to the model's 512-token context window.
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    attention_mask = inputs["attention_mask"]
    embeddings = average_pool(outputs.last_hidden_state, attention_mask)
    # (Optionally) normalize embeddings so that dot products between them
    # equal cosine similarities.
    embeddings = ftorch.normalize(embeddings, p=2, dim=1)
    return embeddings.numpy().tolist()[0]
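
For context, a minimal sketch of how generate_embeddings might be used for semantic similarity; the sample strings and the __main__ guard are illustrative assumptions, not part of the original file. Because the returned vectors are L2-normalized, their dot product is the cosine similarity:

# Hypothetical usage example (not in the original module).
if __name__ == "__main__":
    query_vec = generate_embeddings("What is the capital of France?")
    doc_vec = generate_embeddings("Paris is the capital of France.")
    # Vectors are already unit-length, so the dot product is the cosine
    # similarity; values near 1.0 indicate closely related texts.
    similarity = sum(q * d for q, d in zip(query_vec, doc_vec))
    print(f"Cosine similarity: {similarity:.4f}")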