Spaces:
Runtime error
Runtime error
Kottu
commited on
Update clipGPT.py
Browse files- clipGPT.py +5 -5
clipGPT.py
CHANGED
|
@@ -70,11 +70,11 @@ class ClipGPT2Model(nn.Module):
|
|
| 70 |
def generate_beam(
|
| 71 |
model,
|
| 72 |
tokenizer,
|
|
|
|
|
|
|
|
|
|
| 73 |
beam_size: int = 10,
|
| 74 |
prompt=None,
|
| 75 |
-
embed=None,
|
| 76 |
-
entry_length=76,
|
| 77 |
-
temperature=0.9,
|
| 78 |
stop_token: str = ".",
|
| 79 |
):
|
| 80 |
|
|
@@ -144,7 +144,7 @@ def generate_beam(
|
|
| 144 |
|
| 145 |
|
| 146 |
|
| 147 |
-
def generate_caption_clipgpt(img):
|
| 148 |
|
| 149 |
prefix_length = 10
|
| 150 |
model = ClipGPT2Model(prefix_length)
|
|
@@ -164,7 +164,7 @@ def generate_caption_clipgpt(img):
|
|
| 164 |
with torch.no_grad():
|
| 165 |
prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
|
| 166 |
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
|
| 167 |
-
beam_caption = generate_beam(model, tokenizer, embed=prefix_embed)[0]
|
| 168 |
|
| 169 |
end_time = time.time()
|
| 170 |
print("--- Time taken to generate: %s seconds ---" % (end_time - start_time))
|
|
|
|
| 70 |
def generate_beam(
|
| 71 |
model,
|
| 72 |
tokenizer,
|
| 73 |
+
entry_length,
|
| 74 |
+
temperature,
|
| 75 |
+
embed=None,
|
| 76 |
beam_size: int = 10,
|
| 77 |
prompt=None,
|
|
|
|
|
|
|
|
|
|
| 78 |
stop_token: str = ".",
|
| 79 |
):
|
| 80 |
|
|
|
|
| 144 |
|
| 145 |
|
| 146 |
|
| 147 |
+
def generate_caption_clipgpt(img, entry_length, temperature):
|
| 148 |
|
| 149 |
prefix_length = 10
|
| 150 |
model = ClipGPT2Model(prefix_length)
|
|
|
|
| 164 |
with torch.no_grad():
|
| 165 |
prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
|
| 166 |
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
|
| 167 |
+
beam_caption = generate_beam(model, tokenizer, entry_length, temperature, embed=prefix_embed)[0]
|
| 168 |
|
| 169 |
end_time = time.time()
|
| 170 |
print("--- Time taken to generate: %s seconds ---" % (end_time - start_time))
|