Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -35,13 +35,140 @@ from concurrent.futures import ProcessPoolExecutor
|
|
| 35 |
|
| 36 |
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
| 37 |
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
torch.set_default_device('cuda')
|
| 47 |
|
|
@@ -181,113 +308,6 @@ def visualize_attention_hiddenstate(attention_tensor, head=None, start_img_token
|
|
| 181 |
|
| 182 |
return heat_maps, top_5_tokens
|
| 183 |
|
| 184 |
-
def generate_next_token_table_image(model, tokenizer, response, index_focus):
|
| 185 |
-
next_token_table = []
|
| 186 |
-
for layer_index in range(len(response.hidden_states[index_focus])):
|
| 187 |
-
h_out = model.language_model.lm_head(
|
| 188 |
-
model.language_model.model.norm(response.hidden_states[index_focus][layer_index][0])
|
| 189 |
-
)
|
| 190 |
-
h_out = torch.softmax(h_out, -1)
|
| 191 |
-
top_tokens = []
|
| 192 |
-
for token_index in h_out.argsort(descending=True)[0, :3]: # Top 3
|
| 193 |
-
token_str = tokenizer.decode(token_index)
|
| 194 |
-
prob = float(h_out[0, int(token_index)])
|
| 195 |
-
top_tokens.append((token_str, prob))
|
| 196 |
-
next_token_table.append((layer_index, top_tokens))
|
| 197 |
-
next_token_table = next_token_table[::-1]
|
| 198 |
-
|
| 199 |
-
html_rows = ""
|
| 200 |
-
last_layer_index = len(next_token_table) - 1
|
| 201 |
-
|
| 202 |
-
for i, (layer_index, tokens) in enumerate(next_token_table):
|
| 203 |
-
row = f"<tr><td style='font-weight: bold'>Layer {layer_index}</td>"
|
| 204 |
-
|
| 205 |
-
# For the first column (Top 1)
|
| 206 |
-
token_str, prob = tokens[0]
|
| 207 |
-
|
| 208 |
-
# If this is the last layer in the table, make the text blue
|
| 209 |
-
if layer_index == last_layer_index:
|
| 210 |
-
row += f"<td><span style='color: red; font-weight: bold'>{token_str}</span> ({prob:.2%})</td>"
|
| 211 |
-
else:
|
| 212 |
-
row += f"<td><span style='color: blue; font-weight: bold'>{token_str}</span> ({prob:.2%})</td>"
|
| 213 |
-
|
| 214 |
-
# For the other columns, keep normal formatting
|
| 215 |
-
for token_str, prob in tokens[1:]:
|
| 216 |
-
row += f"<td>{token_str} ({prob:.2%})</td>"
|
| 217 |
-
|
| 218 |
-
row += "</tr>"
|
| 219 |
-
html_rows += row
|
| 220 |
-
|
| 221 |
-
html_code = f'''
|
| 222 |
-
<html>
|
| 223 |
-
<head>
|
| 224 |
-
<meta charset="utf-8">
|
| 225 |
-
<style>
|
| 226 |
-
table {{
|
| 227 |
-
font-family: 'Noto Sans';
|
| 228 |
-
font-size: 12px;
|
| 229 |
-
border-collapse: collapse;
|
| 230 |
-
table-layout: fixed;
|
| 231 |
-
width: 100%;
|
| 232 |
-
}}
|
| 233 |
-
th, td {{
|
| 234 |
-
border: 1px solid black;
|
| 235 |
-
padding: 8px;
|
| 236 |
-
width: 150px;
|
| 237 |
-
height: 30px;
|
| 238 |
-
overflow: hidden;
|
| 239 |
-
text-overflow: ellipsis;
|
| 240 |
-
white-space: nowrap;
|
| 241 |
-
text-align: center;
|
| 242 |
-
}}
|
| 243 |
-
th.layer {{
|
| 244 |
-
width: 100px;
|
| 245 |
-
}}
|
| 246 |
-
th.title {{
|
| 247 |
-
font-size: 14px;
|
| 248 |
-
padding: 10px;
|
| 249 |
-
height: auto;
|
| 250 |
-
white-space: normal;
|
| 251 |
-
overflow: visible;
|
| 252 |
-
}}
|
| 253 |
-
</style>
|
| 254 |
-
</head>
|
| 255 |
-
<body style="background-color: white;">
|
| 256 |
-
<table>
|
| 257 |
-
<tr>
|
| 258 |
-
<th colspan="4" class="title">
|
| 259 |
-
Top hidden tokens per layer for the Prediction
|
| 260 |
-
</th>
|
| 261 |
-
</tr>
|
| 262 |
-
<tr>
|
| 263 |
-
<th class="layer">Layer ⬆️</th>
|
| 264 |
-
<th>Top 1</th>
|
| 265 |
-
<th>Top 2</th>
|
| 266 |
-
<th>Top 3</th>
|
| 267 |
-
</tr>
|
| 268 |
-
{html_rows}
|
| 269 |
-
</table>
|
| 270 |
-
</body>
|
| 271 |
-
</html>
|
| 272 |
-
'''
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
with tempfile.TemporaryDirectory() as tmpdir:
|
| 276 |
-
hti = Html2Image(output_path=tmpdir)
|
| 277 |
-
hti.browser_flags = [
|
| 278 |
-
"--headless=new", # ← Dùng chế độ headless mới
|
| 279 |
-
"--disable-gpu", # ← Tắt GPU
|
| 280 |
-
"--disable-software-rasterizer", # ← Tránh dùng fallback GPU software
|
| 281 |
-
"--no-sandbox", # ← Tránh lỗi sandbox đa luồng
|
| 282 |
-
]
|
| 283 |
-
filename = str(uuid.uuid4())+".png"
|
| 284 |
-
# filename = 'next_token_table.png'
|
| 285 |
-
hti.screenshot(html_str=html_code, save_as=filename, size=(500, 1000))
|
| 286 |
-
img_path = os.path.join(tmpdir, filename)
|
| 287 |
-
img_cv2 = cv2.imread(img_path)[:,:,::-1]
|
| 288 |
-
os.remove(img_path)
|
| 289 |
-
return img_cv2
|
| 290 |
-
|
| 291 |
def adjust_overlay(overlay, text_img):
|
| 292 |
h_o, w_o = overlay.shape[:2]
|
| 293 |
h_t, w_t = text_img.shape[:2]
|
|
@@ -313,36 +333,6 @@ def adjust_overlay(overlay, text_img):
|
|
| 313 |
|
| 314 |
return overlay_resized
|
| 315 |
|
| 316 |
-
def generate_text_image_with_html2image(old_text, input_token, new_token, image_width=400, min_height=1000, font_size=16):
|
| 317 |
-
full_text = old_text + f"<span style='color:blue; font-weight:bold'>[{input_token}]</span>"+ "→" + f"<span style='color:red; font-weight:bold'>[{new_token}]</span>"
|
| 318 |
-
|
| 319 |
-
# Thay \n bằng thẻ HTML <br> để xuống dòng
|
| 320 |
-
full_text = full_text.replace('\n', '<br>')
|
| 321 |
-
|
| 322 |
-
html_code = f'''
|
| 323 |
-
<html>
|
| 324 |
-
<head>
|
| 325 |
-
<meta charset="utf-8">
|
| 326 |
-
</head>
|
| 327 |
-
<body style="font-family: 'DejaVu Sans', sans-serif; font-size: {font_size}px; width: {image_width}px; min-height: {min_height}px; padding: 10px; background-color: white; line-height: 1.4;">
|
| 328 |
-
{full_text}
|
| 329 |
-
</body>
|
| 330 |
-
</html>
|
| 331 |
-
'''
|
| 332 |
-
save_path = str(uuid.uuid4())+".png"
|
| 333 |
-
hti = Html2Image(output_path='.')
|
| 334 |
-
hti.browser_flags = [
|
| 335 |
-
"--headless=new", # ← Dùng chế độ headless mới
|
| 336 |
-
"--disable-gpu", # ← Tắt GPU
|
| 337 |
-
"--disable-software-rasterizer", # ← Tránh dùng fallback GPU software
|
| 338 |
-
"--no-sandbox", # ← Tránh lỗi sandbox đa luồng
|
| 339 |
-
]
|
| 340 |
-
hti.screenshot(html_str=html_code, save_as=save_path, size=(image_width, min_height))
|
| 341 |
-
text_img = cv2.imread(save_path)
|
| 342 |
-
text_img = cv2.cvtColor(text_img, cv2.COLOR_BGR2RGB)
|
| 343 |
-
os.remove(save_path)
|
| 344 |
-
return text_img
|
| 345 |
-
|
| 346 |
def extract_next_token_table_data(model, tokenizer, response, index_focus):
|
| 347 |
next_token_table = []
|
| 348 |
for layer_index in range(len(response.hidden_states[index_focus])):
|
|
@@ -359,98 +349,6 @@ def extract_next_token_table_data(model, tokenizer, response, index_focus):
|
|
| 359 |
next_token_table = next_token_table[::-1]
|
| 360 |
return next_token_table
|
| 361 |
|
| 362 |
-
def render_next_token_table_image(table_data, predict_token):
|
| 363 |
-
import tempfile, uuid, os
|
| 364 |
-
from html2image import Html2Image
|
| 365 |
-
import cv2
|
| 366 |
-
|
| 367 |
-
html_rows = ""
|
| 368 |
-
last_layer_index = len(table_data)
|
| 369 |
-
for layer_index, tokens in table_data:
|
| 370 |
-
row = f"<tr><td style='font-weight: bold'>Layer {layer_index+1}</td>"
|
| 371 |
-
|
| 372 |
-
token_str, prob = tokens[0]
|
| 373 |
-
if token_str == predict_token:
|
| 374 |
-
style = "color: red; font-weight: bold"
|
| 375 |
-
else:
|
| 376 |
-
style = "color: blue; font-weight: bold"
|
| 377 |
-
row += f"<td><span style='{style}'>{token_str}</span> ({prob:.2%})</td>"
|
| 378 |
-
|
| 379 |
-
for token_str, prob in tokens[1:]:
|
| 380 |
-
row += f"<td>{token_str} ({prob:.2%})</td>"
|
| 381 |
-
|
| 382 |
-
row += "</tr>"
|
| 383 |
-
html_rows += row
|
| 384 |
-
|
| 385 |
-
html_code = f'''
|
| 386 |
-
<html>
|
| 387 |
-
<head>
|
| 388 |
-
<meta charset="utf-8">
|
| 389 |
-
<style>
|
| 390 |
-
table {{
|
| 391 |
-
font-family: 'Noto Sans';
|
| 392 |
-
font-size: 12px;
|
| 393 |
-
border-collapse: collapse;
|
| 394 |
-
table-layout: fixed;
|
| 395 |
-
width: 100%;
|
| 396 |
-
}}
|
| 397 |
-
th, td {{
|
| 398 |
-
border: 1px solid black;
|
| 399 |
-
padding: 8px;
|
| 400 |
-
width: 150px;
|
| 401 |
-
height: 30px;
|
| 402 |
-
overflow: hidden;
|
| 403 |
-
text-overflow: ellipsis;
|
| 404 |
-
white-space: nowrap;
|
| 405 |
-
text-align: center;
|
| 406 |
-
}}
|
| 407 |
-
th.layer {{
|
| 408 |
-
width: 100px;
|
| 409 |
-
}}
|
| 410 |
-
th.title {{
|
| 411 |
-
font-size: 14px;
|
| 412 |
-
padding: 10px;
|
| 413 |
-
height: auto;
|
| 414 |
-
white-space: normal;
|
| 415 |
-
overflow: visible;
|
| 416 |
-
}}
|
| 417 |
-
</style>
|
| 418 |
-
</head>
|
| 419 |
-
<body style="background-color: white;">
|
| 420 |
-
<table>
|
| 421 |
-
<tr>
|
| 422 |
-
<th colspan="4" class="title">
|
| 423 |
-
Hidden states per Transformer layer (LLM) for Prediction
|
| 424 |
-
</th>
|
| 425 |
-
</tr>
|
| 426 |
-
<tr>
|
| 427 |
-
<th class="layer">Layer ⬆️</th>
|
| 428 |
-
<th>Top 1</th>
|
| 429 |
-
<th>Top 2</th>
|
| 430 |
-
<th>Top 3</th>
|
| 431 |
-
</tr>
|
| 432 |
-
{html_rows}
|
| 433 |
-
</table>
|
| 434 |
-
</body>
|
| 435 |
-
</html>
|
| 436 |
-
'''
|
| 437 |
-
|
| 438 |
-
with tempfile.TemporaryDirectory() as tmpdir:
|
| 439 |
-
hti = Html2Image(output_path=tmpdir)
|
| 440 |
-
hti.browser_flags = [
|
| 441 |
-
"--headless=new",
|
| 442 |
-
"--disable-gpu",
|
| 443 |
-
"--disable-software-rasterizer",
|
| 444 |
-
"--no-sandbox",
|
| 445 |
-
]
|
| 446 |
-
filename = str(uuid.uuid4()) + ".png"
|
| 447 |
-
hti.screenshot(html_str=html_code, save_as=filename, size=(500, 1000))
|
| 448 |
-
img_path = os.path.join(tmpdir, filename)
|
| 449 |
-
img_cv2 = cv2.imread(img_path)[:, :, ::-1]
|
| 450 |
-
os.remove(img_path)
|
| 451 |
-
return img_cv2
|
| 452 |
-
|
| 453 |
-
|
| 454 |
model = AutoModel.from_pretrained(
|
| 455 |
"khang119966/Vintern-1B-v3_5-explainableAI",
|
| 456 |
torch_dtype=torch.bfloat16,
|
|
@@ -460,9 +358,8 @@ model = AutoModel.from_pretrained(
|
|
| 460 |
).eval().cuda()
|
| 461 |
tokenizer = AutoTokenizer.from_pretrained("khang119966/Vintern-1B-v3_5-explainableAI", trust_remote_code=True, use_fast=False)
|
| 462 |
|
| 463 |
-
# Hàm bao để truyền vào multiprocessing
|
| 464 |
def generate_text_img_wrapper(args):
|
| 465 |
-
return
|
| 466 |
|
| 467 |
def generate_hidden_img_wrapper(args):
|
| 468 |
return render_next_token_table_image(*args)
|
|
@@ -568,16 +465,21 @@ def generate_video(image, prompt, max_tokens):
|
|
| 568 |
for frame in visualization_frames:
|
| 569 |
frame = cv2.resize(frame,(visualization_frames[0].shape[1],visualization_frames[0].shape[0]))
|
| 570 |
resized_visualization_frames.append(frame)
|
| 571 |
-
|
| 572 |
# Lưu thành video MP4 bằng imageio
|
| 573 |
imageio.mimsave(
|
| 574 |
-
'
|
| 575 |
resized_visualization_frames, # dạng RGB
|
| 576 |
fps=5
|
| 577 |
)
|
| 578 |
-
|
| 579 |
|
| 580 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
|
| 582 |
with gr.Blocks() as demo:
|
| 583 |
gr.Markdown("""# 🎥 Visualizing How Multimodal Models Think
|
|
|
|
| 35 |
|
| 36 |
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
| 37 |
|
| 38 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 39 |
+
import textwrap
|
| 40 |
+
import uuid
|
| 41 |
+
import os
|
| 42 |
|
| 43 |
+
def _wrap_to_width(paragraph, measure, max_px):
    """Greedily wrap *paragraph* into lines that are at most *max_px* wide.

    Args:
        paragraph: A single paragraph (no newline characters expected).
        measure: Callable returning the rendered pixel width of a string.
        max_px: Maximum allowed pixel width of one line.

    Returns:
        A list of lines. Empty for an empty / whitespace-only paragraph.
        Words wider than *max_px* on their own are broken between characters.
    """
    lines = []
    current = ""
    for word in paragraph.split():
        candidate = f"{current} {word}" if current else word
        if measure(candidate) <= max_px:
            current = candidate
            continue
        if current:
            lines.append(current)
            current = ""
        # The word starts a fresh line; split it if it is too wide by itself
        # (e.g. long runs of text without spaces).
        for ch in word:
            if current and measure(current + ch) > max_px:
                lines.append(current)
                current = ch
            else:
                current += ch
    if current:
        lines.append(current)
    return lines


def generate_text_image_with_pil(old_text, input_token, new_token, image_width=400, min_height=1000, font_size=16):
    """Render the generated text plus a coloured ``[input]→[new]`` marker.

    The previously generated text is drawn in black, wrapped to the pixel
    width of the canvas. The marker is appended directly after the last line
    (or on a new line if it does not fit): the input token in blue, the arrow
    in black and the new token in red.

    The marker is drawn as separate runs instead of being searched for in the
    wrapped text, so tokens containing spaces or newlines keep their colours.
    Newlines and tabs inside tokens are shown as the literal ``\\n`` / ``\\t``.

    Args:
        old_text: Text generated so far; ``\\n`` starts a new paragraph.
        input_token: Token fed to the model at this step.
        new_token: Token predicted at this step.
        image_width: Canvas width in pixels.
        min_height: Canvas height in pixels. The canvas does not grow; text
            past the bottom edge is clipped.
        font_size: Font size passed to ``ImageFont.truetype``.

    Returns:
        ``numpy`` array built from the RGB canvas.

    Raises:
        OSError: If the Noto Sans CJK font file cannot be opened.
    """
    import numpy as np
    from PIL import Image, ImageDraw, ImageFont

    margin = 10
    line_step = font_size + 8
    max_px = image_width - 2 * margin

    img = Image.new('RGB', (image_width, min_height), color='white')
    draw = ImageDraw.Draw(img)

    font_path = "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc"
    font = ImageFont.truetype(font_path, font_size)

    def measure(text):
        return draw.textlength(text, font=font)

    def visible(token):
        # textlength() rejects multi-line strings, and raw control characters
        # would be invisible anyway, so show them as escape sequences.
        return str(token).replace("\n", "\\n").replace("\t", "\\t")

    # Wrap every paragraph separately so manual line breaks are preserved;
    # an empty paragraph still occupies one (blank) line.
    lines = []
    for paragraph in old_text.split('\n'):
        lines.extend(_wrap_to_width(paragraph, measure, max_px) or [""])

    marker_runs = [
        (f"[{visible(input_token)}]", "blue"),
        ("→", "black"),
        (f"[{visible(new_token)}]", "red"),
    ]
    marker_px = sum(measure(text) for text, _ in marker_runs)

    # Move the marker to its own line when it would overflow the last one.
    if lines[-1] and measure(lines[-1]) + marker_px > max_px:
        lines.append("")

    y = margin
    for line in lines:
        draw.text((margin, y), line, fill="black", font=font)
        y += line_step

    # The marker continues right after the text of the last line.
    x = margin + measure(lines[-1])
    y = margin + (len(lines) - 1) * line_step
    for text, color in marker_runs:
        draw.text((x, y), text, fill=color, font=font)
        x += measure(text)

    return np.array(img)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def render_next_token_table_image(table_data, predict_token, image_width=500, row_height=40, font_size=14):
    """Draw the per-layer top-token table as an image.

    Layout: one title row, one header row (``Layer | Top 1 | Top 2 | Top 3``)
    and one row per entry of *table_data*. In the "Top 1" column the token is
    red and bold when it equals *predict_token*, otherwise blue; the other
    token columns are black. Whitespace inside tokens is made visible
    (``\\n``, ``\\s``, ``\\t``) and text that is wider than its cell is
    shortened with an ellipsis.

    Args:
        table_data: Iterable of ``(layer_index, tokens)`` where ``tokens`` is
            a sequence of ``(token_str, prob)`` pairs; only the first three
            pairs are drawn.
        predict_token: Token string to highlight in the "Top 1" column.
        image_width: Width of the table / canvas in pixels.
        row_height: Height of every row in pixels.
        font_size: Font size passed to ``ImageFont.truetype``.

    Returns:
        ``numpy`` array built from the RGB canvas.

    Raises:
        OSError: If the Noto Sans CJK font file cannot be opened.
    """
    # Local imports keep the function self-contained; no module-level
    # ``numpy`` import is visible next to this definition.
    import numpy as np
    from PIL import Image, ImageDraw, ImageFont

    # Multilingual font (adjust the path if the font lives elsewhere).
    font_path = "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc"
    font = ImageFont.truetype(font_path, font_size)

    num_rows = len(table_data) + 2  # +2 for the title row and the header row
    num_cols = 4  # Layer | Top1 | Top2 | Top3
    table_width = image_width
    col_width = table_width // num_cols
    table_height = num_rows * row_height
    pad = 5

    img = Image.new("RGB", (table_width, table_height), "white")
    draw = ImageDraw.Draw(img)

    def draw_box(x0, y0, width):
        # Rectangle corners are inclusive; clamp to the last valid pixel so
        # the right-most and bottom borders are not drawn off-canvas.
        x1 = min(x0 + width, table_width - 1)
        y1 = min(y0 + row_height, table_height - 1)
        draw.rectangle([x0, y0, x1, y1], outline="black")

    def fit_text(text, max_px):
        # Shorten text that would spill into the neighbouring cell.
        if draw.textlength(text, font=font) <= max_px:
            return text
        while text and draw.textlength(text + "…", font=font) > max_px:
            text = text[:-1]
        return text + "…"

    def draw_cell(x, y, text, color="black", bold=False):
        # A 1px stroke in the fill colour emulates a bold face, since only
        # the regular font file is loaded.
        draw.text(
            (x + pad, y + pad),
            text,
            font=font,
            fill=color,
            stroke_width=1 if bold else 0,
        )

    # Title row spanning the whole table.
    draw_box(0, 0, table_width)
    draw_cell(5, 5, "Hidden states per Transformer layer (LLM) for Prediction", bold=True)

    # Column headers.
    headers = ["Layer ⬆️", "Top 1", "Top 2", "Top 3"]
    for col, header in enumerate(headers):
        x0 = col * col_width
        draw_box(x0, row_height, col_width)
        draw_cell(x0, row_height, header, bold=True)

    # One row per layer.
    cell_text_px = col_width - 2 * pad
    for i, (layer_index, tokens) in enumerate(table_data):
        y = (i + 2) * row_height

        draw_box(0, y, col_width)
        draw_cell(0, y, f"Layer {layer_index+1}", bold=True)

        for col in range(1, num_cols):
            x = col * col_width
            draw_box(x, y, col_width)
            if col - 1 >= len(tokens):
                continue  # fewer than three candidates: leave the cell empty

            token_str, prob = tokens[col - 1]
            is_prediction = col == 1 and token_str == predict_token
            if is_prediction:
                color = "red"
            elif col == 1:
                color = "blue"
            else:
                color = "black"

            # Make whitespace-only tokens readable.
            shown = token_str.replace("\n", "\\n").replace(" ", "\\s").replace("\t", "\\t")
            label = fit_text(f"{shown} ({prob:.1%})", cell_text_px)
            draw_cell(x, y, label, color=color, bold=is_prediction)

    return np.array(img)
|
| 171 |
+
|
| 172 |
|
| 173 |
torch.set_default_device('cuda')
|
| 174 |
|
|
|
|
| 308 |
|
| 309 |
return heat_maps, top_5_tokens
|
| 310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
def adjust_overlay(overlay, text_img):
|
| 312 |
h_o, w_o = overlay.shape[:2]
|
| 313 |
h_t, w_t = text_img.shape[:2]
|
|
|
|
| 333 |
|
| 334 |
return overlay_resized
|
| 335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
def extract_next_token_table_data(model, tokenizer, response, index_focus):
|
| 337 |
next_token_table = []
|
| 338 |
for layer_index in range(len(response.hidden_states[index_focus])):
|
|
|
|
| 349 |
next_token_table = next_token_table[::-1]
|
| 350 |
return next_token_table
|
| 351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
model = AutoModel.from_pretrained(
|
| 353 |
"khang119966/Vintern-1B-v3_5-explainableAI",
|
| 354 |
torch_dtype=torch.bfloat16,
|
|
|
|
| 358 |
).eval().cuda()
|
| 359 |
tokenizer = AutoTokenizer.from_pretrained("khang119966/Vintern-1B-v3_5-explainableAI", trust_remote_code=True, use_fast=False)
|
| 360 |
|
|
|
|
| 361 |
def generate_text_img_wrapper(args):
    """Unpack one argument tuple and render the text image at 500x1000.

    Single-argument adapter around ``generate_text_image_with_pil``: *args*
    is expanded into its positional parameters, with ``image_width`` and
    ``min_height`` fixed to 500 and 1000.
    """
    positional = args
    text_image = generate_text_image_with_pil(
        *positional,
        image_width=500,
        min_height=1000,
    )
    return text_image
|
| 363 |
|
| 364 |
def generate_hidden_img_wrapper(args):
    """Unpack one argument tuple and render the next-token table image.

    Single-argument adapter around ``render_next_token_table_image``: *args*
    is expanded into its positional parameters.
    """
    positional = args
    table_image = render_next_token_table_image(*positional)
    return table_image
|
|
|
|
| 465 |
for frame in visualization_frames:
|
| 466 |
frame = cv2.resize(frame,(visualization_frames[0].shape[1],visualization_frames[0].shape[0]))
|
| 467 |
resized_visualization_frames.append(frame)
|
| 468 |
+
|
| 469 |
# Lưu thành video MP4 bằng imageio
|
| 470 |
imageio.mimsave(
|
| 471 |
+
'heatmap_with_music.mp4',
|
| 472 |
resized_visualization_frames, # dạng RGB
|
| 473 |
fps=5
|
| 474 |
)
|
|
|
|
| 475 |
|
| 476 |
+
# Nối video và nhạc
|
| 477 |
+
video = VideoFileClip("heatmap_animation.mp4")
|
| 478 |
+
audio = AudioFileClip("legacy-of-the-century-background-cinematic-music-for-video-46-second-319542.mp3").set_duration(video.duration)
|
| 479 |
+
final = video.set_audio(audio)
|
| 480 |
+
final.write_videofile("heatmap_with_music.mp4", codec="libx264", audio_codec="aac")
|
| 481 |
+
|
| 482 |
+
return "heatmap_with_music.mp4"
|
| 483 |
|
| 484 |
with gr.Blocks() as demo:
|
| 485 |
gr.Markdown("""# 🎥 Visualizing How Multimodal Models Think
|