Update README.md
Browse files
README.md
CHANGED
|
@@ -58,16 +58,6 @@ import torch
|
|
| 58 |
|
| 59 |
device = 'cuda:0'
|
| 60 |
|
| 61 |
-
# This function is borrowed from https://huggingface.co/intfloat/e5-mistral-7b-instruct
|
| 62 |
-
def last_token_pool(last_hidden_states, attention_mask):
|
| 63 |
-
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
|
| 64 |
-
if left_padding:
|
| 65 |
-
return last_hidden_states[:, -1]
|
| 66 |
-
else:
|
| 67 |
-
sequence_lengths = attention_mask.sum(dim=1) - 1
|
| 68 |
-
batch_size = last_hidden_states.shape[0]
|
| 69 |
-
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
|
| 70 |
-
|
| 71 |
# Load model, be sure to substitute `model_path` by your model path
|
| 72 |
model_path = '/local/path/to/model'
|
| 73 |
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
|
@@ -86,19 +76,15 @@ query_full = query_instruction + query
|
|
| 86 |
|
| 87 |
# Embed image documents
|
| 88 |
with torch.no_grad():
|
| 89 |
-
|
| 90 |
-
p_reps = last_token_pool(p_outputs.last_hidden_state, p_outputs.attention_mask)
|
| 91 |
|
| 92 |
# Embed text queries
|
| 93 |
with torch.no_grad():
|
| 94 |
-
|
| 95 |
-
q_reps = last_token_pool(q_outputs.last_hidden_state, q_outputs.attention_mask) # [B, d]
|
| 96 |
|
| 97 |
# Calculate similarities
|
| 98 |
scores = torch.matmul(q_reps, p_reps.T)
|
| 99 |
print(scores)
|
| 100 |
-
|
| 101 |
-
# tensor([[0.6506, 4.9630, 3.8614]], device='cuda:0')
|
| 102 |
```
|
| 103 |
|
| 104 |
# Limitations
|
|
|
|
| 58 |
|
| 59 |
device = 'cuda:0'
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
# Load model, be sure to substitute `model_path` by your model path
|
| 62 |
model_path = '/local/path/to/model'
|
| 63 |
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
|
|
|
| 76 |
|
| 77 |
# Embed image documents
|
| 78 |
with torch.no_grad():
|
| 79 |
+
p_reps = model(text=['', '', ''], image=[image_1, image_2, image_3], tokenizer=tokenizer)
|
|
|
|
| 80 |
|
| 81 |
# Embed text queries
|
| 82 |
with torch.no_grad():
|
| 83 |
+
q_reps = model(text=[query_full], image=[None], tokenizer=tokenizer) # [B, s, d]
|
|
|
|
| 84 |
|
| 85 |
# Calculate similarities
|
| 86 |
scores = torch.matmul(q_reps, p_reps.T)
|
| 87 |
print(scores)
|
|
|
|
|
|
|
| 88 |
```
|
| 89 |
|
| 90 |
# Limitations
|