import torch
import einops
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage
from PIL import Image
import os
import math
from transformers import AutoTokenizer, AutoImageProcessor, VisionEncoderDecoderModel
import gradio as gr
from concurrent.futures import ThreadPoolExecutor

############################## RATIONALE BEHIND ###############################

# Load the model, tokenizer, and image processor for a given checkpoint
def load_model_and_components(model_name):
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    image_processor = AutoImageProcessor.from_pretrained(model_name)
    return model, tokenizer, image_processor

# Preload the listed models in parallel (a second checkpoint can be enabled below)
def preload_models():
    models = {}
    model_names = ["laicsiifes/swin-distilbertimbau"]  # , "laicsiifes/swin-gportuguese-2"]
    with ThreadPoolExecutor() as executor:
        results = executor.map(load_model_and_components, model_names)
        for name, result in zip(model_names, results):
            models[name] = result
    return models

models = preload_models()
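# Loading happens once at import time, so the Gradio callbacks below reuse the
# in-memory weights instead of reloading them on every request; the thread pool
# only helps when more than one checkpoint is listed in model_names.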

# Predefined images for selection
image_folder = "images"
predefined_images = [
    Image.open(os.path.join(image_folder, fname)).convert("RGB")
    for fname in os.listdir(image_folder)
    if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.ppm'))
]

# Convert the uploaded image to RGB; the second return value clears the gallery
def preprocess_image(image):
    if image is None:
        return None, None
    pil_image = image.convert("RGB")
    return pil_image, None

# Generate a caption for the image and collect the cross-attention maps
# produced for each generated token
def get_attn_map(model, image, processor, tokenizer):
    pixel_values = processor(image, return_tensors="pt").pixel_values
    model.eval()
    with torch.no_grad():
        output = model.generate(
            pixel_values=pixel_values,
            return_dict_in_generate=True,
            output_hidden_states=True,
            output_attentions=True,
            max_length=25,
            num_beams=5
        )
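    # Note on structure (per the Transformers documentation for generate outputs):
    # output.cross_attentions is a tuple with one entry per generated token; each entry
    # is itself a tuple over decoder layers of tensors shaped roughly
    # (batch_size * num_beams, num_heads, generated_length, encoder_sequence_length).
    # Taking [-1] below keeps only the last decoder layer.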
    last_layers = [tensor_tuple[-1] for tensor_tuple in output.cross_attentions]
    attention_maps = torch.stack(last_layers, dim=0)
    attention_maps = einops.reduce(
        attention_maps,
        'token batch head sequence (height width) -> token sequence (height width)',
        height=7, width=7,
        reduction='mean'
    )
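    # After averaging over the batch/beam and head dimensions, the shape is
    # (generated_token, step_length, 49); the 49 corresponds to the 7x7 grid of
    # Swin encoder patches (assuming the processor's default 224x224 input size).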
    tokens = output.sequences[0]
    token_texts = tokenizer.convert_ids_to_tokens(tokens)
    valid_token_texts = token_texts[1:]  # drop the decoder start token
    return valid_token_texts, attention_maps, output

# Merge subword tokens of the caption, and their attention maps, into whole words,
# e.g. tokens `sent` and `##ada` yield the word `sentada`
def join_tokens(text_tokens, attention_maps, connect_symbol='##'):
    tokens = text_tokens.copy()
    attn_map = attention_maps.detach().clone()
    i = 0
    while i < len(tokens) and tokens[i] != '[SEP]':
        if tokens[i].startswith(connect_symbol):
            tokens[i] = tokens[i - 1] + tokens[i].replace(connect_symbol, '')
            tokens.pop(i - 1)
            attn_map[i][0] = attn_map[i - 1][0] + attn_map[i][0]
            attn_map = torch.cat((attn_map[:i - 1], attn_map[i:]), dim=0)
            i -= 1
        i += 1
    tokens = tokens[1:i - 1]
    attn_map = attn_map[1:i - 1]
    return tokens, attn_map

# Build the gallery of attention overlays, one blended image per word of the caption
def generate_attention_gallery(image, selected_model):
    if image is None:
        return []
    model, tokenizer, processor = models[selected_model]
    tokens, attention_maps, _ = get_attn_map(model, image, processor, tokenizer)
    joined_tokens, joined_attn_maps = join_tokens(tokens, attention_maps)
    grid_size = int(joined_attn_maps.size(-1) ** 0.5)
    gallery_output = []
    for i, token in enumerate(joined_tokens):
        att_map = joined_attn_maps[i].view(grid_size, grid_size)
        att_map = (att_map - att_map.min()) / (att_map.max() - att_map.min())
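        # Each word's map is min-max normalized on its own, so intensities are
        # comparable within one overlay but not across different words.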
        att_map = att_map.repeat_interleave(32, dim=0).repeat_interleave(32, dim=1)
        att_map_resized = ToPILImage()(
            att_map.unsqueeze(0).repeat(3, 1, 1)
        ).resize(image.size)
        blended = Image.blend(image, att_map_resized, alpha=0.75)
        gallery_output.append((blended, token))
    return gallery_output

################################### PAGE ####################################

# Define UI
with gr.Blocks(theme=gr.themes.Citrus(primary_hue="blue", secondary_hue="orange")) as interface:
    gr.Markdown("""
    # Welcome to the LAICSI-IFES Vision Encoder-Decoder Demo
    ---
    ### Select a pretrained model and upload an image to visualize attention maps.
    """)
    with gr.Row(variant='panel'):
        model_selector = gr.Dropdown(
            choices=list(models.keys()),
            value="laicsiifes/swin-distilbertimbau",
            label="Select Model"
        )
    gr.Markdown("""---\n### Upload or select an image and click 'Generate' to view attention maps.""")
    with gr.Row(variant='panel'):
        with gr.Column():
            image_display = gr.Image(type="pil", label="Image Preview", image_mode="RGB", height=400)
        with gr.Column():
            output_gallery = gr.Gallery(label="Attention Maps", columns=4, rows=3, height=600)
    generate_button = gr.Button("Generate")
    gr.Markdown("""---""")
    with gr.Row(variant='panel'):
        examples = gr.Examples(
            examples=predefined_images,
            fn=preprocess_image,
            inputs=[image_display],
            outputs=[image_display, output_gallery],
            label="Examples"
        )

    # Actions
    # Changing the model clears both the preview and the gallery
    model_selector.change(fn=lambda: (None, []), outputs=[image_display, output_gallery])
    # A new upload is converted to RGB and the stale gallery is cleared
    image_display.upload(fn=preprocess_image, inputs=[image_display], outputs=[image_display, output_gallery])
    # Removing the image also clears the gallery
    image_display.clear(fn=lambda: None, outputs=[output_gallery])
    # Generate the attention overlays for the current image and selected model
    generate_button.click(fn=generate_attention_gallery, inputs=[image_display, model_selector], outputs=output_gallery)

interface.launch(share=False)
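
# To try this demo locally (assuming the script is saved as app.py and the imported
# libraries are installed), run:
#   python app.py
# Gradio then prints a local URL (by default http://127.0.0.1:7860) to open in a browser.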