| """ | |
| Visual-CoT: Chain-of-Thought Reasoning Demo on Hugging Face Spaces | |
| Showcasing Visual Chain-of-Thought with Interactive Benchmark Examples | |
| Paper: Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive | |
| Dataset and Benchmark for Chain-of-Thought Reasoning | |
| https://arxiv.org/abs/2403.16999 | |
| """ | |
| import os | |
| import torch | |
| import gradio as gr | |
| from PIL import Image, ImageDraw, ImageFont | |
| import re | |
| import json | |
| import spaces | |
| from pathlib import Path | |
| import requests | |
| from io import BytesIO | |
| from huggingface_hub import login | |
| from llava.constants import ( | |
| IMAGE_TOKEN_INDEX, | |
| DEFAULT_IMAGE_TOKEN, | |
| DEFAULT_IM_START_TOKEN, | |
| DEFAULT_IM_END_TOKEN, | |
| ) | |
| from llava.conversation import conv_templates | |
| from llava.model.builder import load_pretrained_model | |
| from llava.utils import disable_torch_init | |
| from llava.mm_utils import ( | |
| process_images, | |
| tokenizer_image_token, | |
| get_model_name_from_path, | |
| ) | |
# No need for local benchmark loader - using HF datasets directly

# =============================================================================
# Authentication
# =============================================================================
# Login to Hugging Face using token from Spaces secrets
HF_TOKEN = os.environ.get("HF_TOKEN", None)
if HF_TOKEN:
    try:
        login(token=HF_TOKEN, add_to_git_credential=False)
        print("✓ Successfully logged in to Hugging Face")
    except Exception as e:
        print(f"⚠ Warning: Failed to login to Hugging Face: {e}")
        print("  Continuing without authentication...")
else:
    print("ℹ No HF_TOKEN found, continuing without authentication")
# =============================================================================
# Configuration
# =============================================================================
# Available models
AVAILABLE_MODELS = {
    "VisCoT-7B-224 (Fastest)": "deepcs233/VisCoT-7b-224",
    "VisCoT-7B-336 (Balanced)": "deepcs233/VisCoT-7b-336",
    "VisCoT-13B-224 (Better)": "deepcs233/VisCoT-13b-224",
    "VisCoT-13B-336 (Best)": "deepcs233/VisCoT-13b-336",
}

MODEL_PATH = "deepcs233/VisCoT-13b-336"  # Default: best quality
CURRENT_MODEL_NAME = "VisCoT-13B-336 (Best)"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Benchmark datasets from the Visual Chain-of-Thought Reasoning Benchmarks Collection
# https://huggingface.co/collections/tuandunghcmut/visual-chain-of-thought-reasoning-benchmarks
BENCHMARK_DATASETS = {
    "GQA": {
        "path": "lmms-lab/GQA",
        "config": "train_balanced_images",
        "split": "train",
        "description": "Scene graph QA (72K balanced images)",
    },
    "RefCOCO": {
        "path": "lmms-lab/RefCOCO",
        "config": "default",
        "split": "val",
        "description": "Referring expression comprehension (8.8K validation)",
    },
    "RefCOCO+": {
        "path": "lmms-lab/RefCOCOplus",
        "config": "default",
        "split": "val",
        "description": "RefCOCO with no location words (3.8K validation)",
    },
    "RefCOCOg": {
        "path": "lmms-lab/RefCOCOg",
        "config": "default",
        "split": "val",
        "description": "RefCOCO with longer expressions (7.5K validation)",
    },
    "POPE": {
        "path": "lmms-lab/POPE",
        "config": "default",
        "split": "test",
        "description": "Object probing evaluation (9K test)",
    },
    "ScienceQA": {
        "path": "lmms-lab/ScienceQA",
        "config": "ScienceQA-FULL",
        "split": "validation",
        "description": "Science question answering (4.2K validation)",
    },
    "MM-GCoT": {
        "path": "AQUA6/MM-GCoT",
        "config": "train",
        "split": "train",
        "description": "Multi-Modal Graph CoT (63.9K training)",
    },
    "VGR": {
        "path": "BytedanceDouyinContent/VGR",
        "config": "default",
        "split": "train",
        "description": "Visual Grounding & Reasoning (90K training)",
    },
}

print(f"✅ Configured {len(BENCHMARK_DATASETS)} benchmark datasets from HF collection")
# =============================================================================
# Model Loading (Global - bfloat16)
# =============================================================================
print("🔄 Loading Visual-CoT model in bfloat16...")
disable_torch_init()
model_name = get_model_name_from_path(MODEL_PATH)

# Load model globally with bfloat16 precision
tokenizer, model, image_processor, context_len = load_pretrained_model(
    MODEL_PATH,
    None,
    model_name,
    load_8bit=False,
    load_4bit=False,
    device=DEVICE,
)

# Ensure model is in bfloat16
if DEVICE == "cuda":
    model = model.to(dtype=torch.bfloat16)
    print(f"✓ Model loaded in bfloat16 on {DEVICE}")
else:
    print(f"✓ Model loaded on {DEVICE} (CPU mode)")

print(f"✓ Model: {model_name}")
print(f"✓ Context length: {context_len}")
print(f"✓ Device: {DEVICE}")
# =============================================================================
# Model Management Functions
# =============================================================================
def switch_model(model_choice):
    """Switch to a different model"""
    global tokenizer, model, image_processor, context_len, MODEL_PATH, CURRENT_MODEL_NAME

    try:
        new_model_path = AVAILABLE_MODELS[model_choice]
        if new_model_path == MODEL_PATH:
            return f"Already using {model_choice}"

        print(f"\n🔄 Switching to {model_choice}...")
        disable_torch_init()
        model_name = get_model_name_from_path(new_model_path)

        # Load new model
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            new_model_path,
            None,
            model_name,
            load_8bit=False,
            load_4bit=False,
            device=DEVICE,
        )

        # Ensure bfloat16
        if DEVICE == "cuda":
            model = model.to(dtype=torch.bfloat16)

        MODEL_PATH = new_model_path
        CURRENT_MODEL_NAME = model_choice
        print(f"✓ Switched to {model_choice}")
        return f"✓ Successfully switched to {model_choice}\nModel: {model_name}\nDevice: {DEVICE}"
    except Exception as e:
        import traceback
        error_msg = f"❌ Failed to switch model: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return error_msg

# =============================================================================
# Benchmark Loading Functions
# =============================================================================
def load_benchmark_example(dataset_name, index=0):
    """Load an example from an HF benchmark dataset"""
    try:
        from datasets import load_dataset

        dataset_info = BENCHMARK_DATASETS.get(dataset_name)
        if not dataset_info:
            return None, "Dataset not found", "", "", ""

        dataset_path = dataset_info["path"]
        dataset_config = dataset_info.get("config")
        dataset_split = dataset_info.get("split", "train")

        # Load dataset with config and split
        print(f"Loading {dataset_name} from {dataset_path} (config={dataset_config}, split={dataset_split})...")
        if dataset_config and dataset_config != "None":
            dataset = load_dataset(dataset_path, dataset_config, split=dataset_split, streaming=True)
        else:
            dataset = load_dataset(dataset_path, split=dataset_split, streaming=True)

        # Get specific index (for streaming, we need to iterate)
        for i, example in enumerate(dataset):
            if i == index:
                # Extract fields (structure varies by dataset)
                image = example.get("image")
                question = example.get("question", example.get("text", ""))

                # Try to get bounding box in various formats
                bbox = example.get("bbox", example.get("bboxes", ""))
                if isinstance(bbox, list) and bbox:
                    bbox_str = str(bbox)
                else:
                    bbox_str = "No bounding box available"

                answer = example.get("answer", example.get("label", ""))
                status = f"📊 Dataset: {dataset_name} | Example {index + 1}\n{dataset_info['description']}"
                return image, question, bbox_str, answer, status

            # Stop after a few iterations for efficiency
            if i > index + 10:
                break

        return None, "Index out of range", "", "", "Could not find example at this index"
    except Exception as e:
        error_msg = f"Error loading {dataset_name}: {str(e)}"
        print(error_msg)
        import traceback
        traceback.print_exc()
        return None, error_msg, "", "", error_msg


def load_random_benchmark_example(dataset_name):
    """Load a random example from benchmark for inference"""
    import random

    # Use random index between 0-99 for faster loading
    random_index = random.randint(0, 99)
    return load_benchmark_example(dataset_name, random_index)
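
# Illustrative call of the loader above (field names depend on the chosen
# dataset's schema; the values sketched here are hypothetical):
#   image, question, bbox_str, answer, status = load_benchmark_example("GQA", 0)
#   # image    -> PIL image from the dataset's "image" column (or None)
#   # bbox_str -> str(bbox) if a bbox/bboxes field exists, else "No bounding box available"
#   # status   -> "📊 Dataset: GQA | Example 1" plus the dataset description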

# =============================================================================
# Utility Functions
# =============================================================================
def parse_bbox(text):
    """Parse bounding box from model output"""
    pattern1 = r"###\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]"
    pattern2 = r"\[([\d\.]+),\s*([\d\.]+),\s*([\d\.]+),\s*([\d\.]+)\]"

    matches = re.findall(pattern1, text)
    if not matches:
        matches = re.findall(pattern2, text)

    if matches:
        bbox = [float(x) for x in matches[-1]]
        if all(0 <= x <= 1 for x in bbox):
            return bbox
    return None
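
# Illustrative usage of parse_bbox (the input strings below are hypothetical
# model responses, not recorded ones):
#   parse_bbox("The region is at ###[0.12, 0.34, 0.56, 0.78]")
#   -> [0.12, 0.34, 0.56, 0.78]   (normalized [x1, y1, x2, y2])
#   parse_bbox("No coordinates here")
#   -> None                        (caller then falls back to the full image)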

def draw_bounding_box(image, bbox, color="red", width=5):
    """Draw bounding box on image"""
    if bbox is None:
        return image

    img = image.copy()
    draw = ImageDraw.Draw(img)
    img_width, img_height = img.size

    # Convert normalized to pixel coordinates
    x1 = int(bbox[0] * img_width)
    y1 = int(bbox[1] * img_height)
    x2 = int(bbox[2] * img_width)
    y2 = int(bbox[3] * img_height)

    # Draw rectangle
    draw.rectangle([x1, y1, x2, y2], outline=color, width=width)

    # Draw label
    label = f"ROI: [{bbox[0]:.3f}, {bbox[1]:.3f}, {bbox[2]:.3f}, {bbox[3]:.3f}]"
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 14)
    except Exception:
        font = ImageFont.load_default()

    # Text background
    bbox_text = draw.textbbox((x1, y1 - 22), label, font=font)
    draw.rectangle([bbox_text[0] - 2, bbox_text[1] - 2, bbox_text[2] + 2, bbox_text[3] + 2], fill=color)
    draw.text((x1, y1 - 22), label, fill="white", font=font)
    return img

def load_benchmark_examples(dataset_name, num_examples=5):
    """
    Load examples from a local Visual-CoT benchmark JSON file.
    Returns a list of dicts with image, question, gt_bbox, gt_answer, and dataset fields.
    """
    benchmark_file = f"viscot_benchmark/benchmark/{dataset_name}.json"
    if not os.path.exists(benchmark_file):
        return []

    try:
        with open(benchmark_file, 'r') as f:
            data = json.load(f)

        examples = []
        for item in data[:num_examples]:
            # Extract information based on dataset structure
            image_file = item.get('image', '')
            question = item['conversations'][0]['value'].replace('<image>\n', '').split('Please provide')[0].strip()
            gt_bbox_str = item['conversations'][1]['value'] if len(item['conversations']) > 1 else None
            gt_answer = item['conversations'][3]['value'] if len(item['conversations']) > 3 else None

            examples.append({
                'image': image_file,
                'question': question,
                'gt_bbox': gt_bbox_str,
                'gt_answer': gt_answer,
                'dataset': dataset_name
            })
        return examples
    except Exception as e:
        print(f"Error loading {dataset_name}: {e}")
        return []

# =============================================================================
# Main Inference Function (with @spaces.GPU decorator)
# =============================================================================
# Zero GPU allocation for 120 seconds
@spaces.GPU(duration=120)
def generate_viscot_response(image, question, temperature=0.2, max_tokens=512):
| """ | |
| Generate Visual-CoT response with bounding box detection | |
| Args: | |
| image: PIL Image | |
| question: str | |
| temperature: float | |
| max_tokens: int | |
| Returns: | |
| tuple: (bbox_response, final_answer, image_with_bbox, processing_info) | |
| """ | |
| if image is None: | |
| return "❌ Please upload an image!", "", None, "" | |
| if not question.strip(): | |
| return "❌ Please enter a question!", "", None, "" | |
| try: | |
| # Model is already loaded globally - use it directly | |
| # Initialize conversation | |
| conv_mode = "llava_v1" | |
| conv = conv_templates[conv_mode].copy() | |
| # ===================================================================== | |
| # STEP 1: Detect Region of Interest (ROI) | |
| # ===================================================================== | |
| prompt_step1 = ( | |
| f"{DEFAULT_IMAGE_TOKEN}\n{question} " | |
| f"Please provide the bounding box coordinate of the region this question asks about." | |
| ) | |
| conv.append_message(conv.roles[0], prompt_step1) | |
| conv.append_message(conv.roles[1], None) | |
| prompt1 = conv.get_prompt() | |
| # Process image | |
| image_tensor = process_images([image], image_processor, model.config) | |
| if isinstance(image_tensor, list): | |
| image_tensor = [img.to(DEVICE, dtype=torch.bfloat16) for img in image_tensor] | |
| else: | |
| image_tensor = image_tensor.to(DEVICE, dtype=torch.bfloat16) | |
| # Tokenize | |
| input_ids = tokenizer_image_token( | |
| prompt1, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" | |
| ).unsqueeze(0).to(DEVICE) | |
| # Generate bbox | |
| with torch.inference_mode(): | |
| output_ids = model.generate( | |
| input_ids, | |
| images=image_tensor, | |
| do_sample=temperature > 0.001, | |
| temperature=max(temperature, 0.01), | |
| max_new_tokens=128, | |
| use_cache=True, | |
| ) | |
| bbox_response = tokenizer.decode( | |
| output_ids[0, input_ids.shape[1]:], skip_special_tokens=True | |
| ).strip() | |
| # Parse bbox | |
| bbox = parse_bbox(bbox_response) | |
| # ===================================================================== | |
| # STEP 2: Answer Question with ROI Context | |
| # ===================================================================== | |
| conv.messages[-1][-1] = bbox_response | |
| second_question = ( | |
| f"Please answer the question based on the original image and local detail image. {question}" | |
| ) | |
| conv.append_message(conv.roles[0], second_question) | |
| conv.append_message(conv.roles[1], None) | |
| prompt2 = conv.get_prompt() | |
| input_ids = tokenizer_image_token( | |
| prompt2, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" | |
| ).unsqueeze(0).to(DEVICE) | |
| with torch.inference_mode(): | |
| output_ids = model.generate( | |
| input_ids, | |
| images=image_tensor, | |
| do_sample=temperature > 0.001, | |
| temperature=max(temperature, 0.01), | |
| max_new_tokens=max_tokens, | |
| use_cache=True, | |
| ) | |
| final_answer = tokenizer.decode( | |
| output_ids[0, input_ids.shape[1]:], skip_special_tokens=True | |
| ).strip() | |
| # Visualization | |
| image_with_bbox = draw_bounding_box(image, bbox) if bbox else image | |
| # Processing info | |
| processing_info = f"✓ Processed successfully | Bbox: {bbox if bbox else 'Not detected'}" | |
| return bbox_response, final_answer, image_with_bbox, processing_info | |
| except Exception as e: | |
| import traceback | |
| error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}" | |
| return error_msg, "", None, error_msg | |


# =============================================================================
# Gradio Interface
# =============================================================================
def create_demo():
    """Create Gradio interface"""
    # Custom CSS for beautiful UI
    custom_css = """
    .gradio-container {
        font-family: 'Inter', sans-serif;
    }
    .header {
        text-align: center;
        padding: 20px;
        background: linear-gradient(135deg, #1e3a8a 0%, #1e40af 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 20px;
    }
    .info-box {
        background: #f0f7ff;
        border-left: 4px solid #3b82f6;
        padding: 15px;
        border-radius: 5px;
        margin: 10px 0;
    }
    .example-box {
        border: 2px solid #e5e7eb;
        border-radius: 8px;
        padding: 10px;
        margin: 5px 0;
    }
    .metric-card {
        background: white;
        border-radius: 8px;
        padding: 15px;
        box-shadow: 0 1px 3px rgba(0,0,0,0.1);
        margin: 10px 0;
    }
    """

    with gr.Blocks(
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="indigo",
            neutral_hue="slate",
        ),
        css=custom_css,
        title="Visual-CoT Demo",
    ) as demo:
        # Header
        gr.HTML("""
        <div class="header">
            <h1 style="color: white;">🌋 Visual-CoT: Chain-of-Thought Reasoning</h1>
            <p style="font-size: 18px; margin: 10px 0; color: white;">
                Advancing Multi-Modal Language Models with Visual Chain-of-Thought
            </p>
            <p style="font-size: 14px; opacity: 0.9;">
                📄 <a href="https://arxiv.org/abs/2403.16999" style="color: white; text-decoration: underline;">
                    Paper (NeurIPS 2024 Spotlight)
                </a> |
                💻 <a href="https://github.com/deepcs233/Visual-CoT" style="color: white; text-decoration: underline;">
                    GitHub
                </a> |
                🤗 <a href="https://huggingface.co/datasets/deepcs233/Visual-CoT" style="color: white; text-decoration: underline;">
                    Dataset
                </a>
            </p>
        </div>
        """)

        # Introduction
        gr.Markdown("""
        ## 1. Introduction to Visual-CoT

        **Visual Chain-of-Thought (VisCoT)** is a multi-modal language model that enables:

        1. **Region Identification**: Detect key regions in images using bounding boxes
        2. **Step-by-Step Reasoning**: Apply Chain-of-Thought methodology for visual understanding
        3. **Question Answering**: Provide interpretable explanations for visual content

        ### 1.1 Dataset Statistics
        - 438,000 question-answer pairs with bounding box annotations
        - 13 diverse benchmarks (DocVQA, GQA, TextVQA, etc.)
        - Based on the LLaVA-1.5 architecture with a CLIP ViT-L/14 vision encoder
        """)

        # Authentication notice for Zero GPU
        gr.HTML("""
        <div class="info-box">
            <p style="margin: 0; font-size: 14px;">
                <strong>Note:</strong> This Space uses Zero GPU, which requires authentication.
                Please <a href="https://huggingface.co/login" target="_blank">login</a> or
                <a href="https://huggingface.co/join" target="_blank">create a free account</a> if you encounter quota errors.
            </p>
        </div>
        """)

        # Model Selector
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown("### Model Selection")
                model_dropdown = gr.Dropdown(
                    choices=list(AVAILABLE_MODELS.keys()),
                    value=CURRENT_MODEL_NAME,
                    label="Select Model",
                    info="Choose model variant (larger = better quality, slower)",
                )
            with gr.Column(scale=1):
                gr.Markdown("### Current Model Status")
                model_status = gr.Textbox(
                    value=f"Active: {CURRENT_MODEL_NAME}",
                    label="Status",
                    interactive=False,
                )

        model_dropdown.change(
            fn=switch_model,
            inputs=[model_dropdown],
            outputs=[model_status],
        )

        with gr.Tabs():
            # ============================================================
            # Tab 1: Interactive Demo
            # ============================================================
            with gr.Tab("Interactive Demo"):
                gr.Markdown("""
                ### 2. Interactive Demonstration

                **Procedure**:
                1. Upload an image
                2. Enter a question about the image
                3. The model will:
                   - Step 1: Detect region of interest (ROI) and output bounding box
                   - Step 2: Analyze the ROI and generate answer
                """)

                with gr.Row():
                    with gr.Column(scale=1):
                        # Input
                        image_input = gr.Image(
                            type="pil",
                            label="Input Image",
                            height=400,
                        )
                        question_input = gr.Textbox(
                            label="Question",
                            placeholder="Example: What is unusual about this image?",
                            lines=3,
                        )

                        with gr.Accordion("Advanced Parameters", open=False):
                            temperature = gr.Slider(
                                minimum=0.0,
                                maximum=1.0,
                                value=0.2,
                                step=0.05,
                                label="Temperature",
                                info="0 = Deterministic, 1 = Creative",
                            )
                            max_tokens = gr.Slider(
                                minimum=128,
                                maximum=1024,
                                value=512,
                                step=64,
                                label="Maximum Output Tokens",
                            )

                        submit_btn = gr.Button("Run Analysis", variant="primary", size="lg")
                        clear_btn = gr.Button("Clear", size="sm")

                        gr.Markdown("---")
                        gr.Markdown("**Load Random Benchmark Example:**")
                        benchmark_select = gr.Dropdown(
                            choices=list(BENCHMARK_DATASETS.keys()),
                            value="GQA",
                            label="Select Benchmark",
                            scale=1,
                        )
                        load_random_btn = gr.Button("🎲 Load Random Example", variant="secondary")

                    with gr.Column(scale=1):
                        # Output
                        gr.Markdown("### 3. Results")

                        with gr.Group():
                            gr.Markdown("#### 3.1 Step 1: Region Detection")
                            bbox_output = gr.Textbox(
                                label="Detected Bounding Box Coordinates",
                                lines=2,
                                show_copy_button=True,
                            )

                        with gr.Group():
                            gr.Markdown("#### 3.2 Step 2: Answer Generation")
                            answer_output = gr.Textbox(
                                label="Final Answer",
                                lines=6,
                                show_copy_button=True,
                            )

                        with gr.Group():
                            gr.Markdown("#### 3.3 Visualization")
                            image_output = gr.Image(
                                label="Image with Bounding Box Overlay",
                                type="pil",
                                height=350,
                            )

                        info_output = gr.Textbox(
                            label="Processing Info",
                            lines=1,
                            visible=False,
                        )

                # Example questions (21 diverse examples)
                gr.Markdown("### 📋 Try These Example Questions")
                gr.Examples(
                    examples=[
                        # Available images
                        ["examples/extreme_ironing.jpg", "What is unusual about this image?"],
                        ["examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
                        # Visual reasoning examples (upload your own images)
                        [None, "What color is the car in the image?"],
                        [None, "How many people are in this picture?"],
                        [None, "What is the main object in the center of the image?"],
                        [None, "What is the person doing in this photo?"],
                        [None, "What time of day does this appear to be?"],
                        [None, "What is the weather like in this image?"],
                        [None, "What room is this photo taken in?"],
                        [None, "What brand or logo can you see?"],
                        # Text reading examples
                        [None, "What text is written on the sign?"],
                        [None, "What is the price shown in the image?"],
                        [None, "What does the document say?"],
                        [None, "What is the title of this book/poster?"],
                        # Spatial reasoning
                        [None, "What is to the left of the main object?"],
                        [None, "What is on top of the table?"],
                        [None, "Where is the person standing?"],
                        # Scene understanding
                        [None, "What type of place is this?"],
                        [None, "What activity is happening here?"],
                        [None, "What is the overall mood or atmosphere?"],
                        [None, "What can you infer about the context of this image?"],
                    ],
                    inputs=[image_input, question_input],
                    label="Click to load example questions (upload an image for questions without one)",
                    examples_per_page=10,
                )

                # Event handlers
                submit_btn.click(
                    fn=generate_viscot_response,
                    inputs=[image_input, question_input, temperature, max_tokens],
                    outputs=[bbox_output, answer_output, image_output, info_output],
                )
                clear_btn.click(
                    fn=lambda: (None, "", "", "", None, ""),
                    outputs=[image_input, question_input, bbox_output, answer_output, image_output, info_output],
                )
                load_random_btn.click(
                    fn=load_random_benchmark_example,
                    inputs=[benchmark_select],
                    outputs=[image_input, question_input, bbox_output, answer_output, info_output],
                )

            # ============================================================
            # Tab 2: Benchmark Explorer
            # ============================================================
            with gr.Tab("Benchmark Explorer"):
                gr.Markdown("""
                ### Explore Visual-CoT Benchmark Examples

                Load and browse real examples from the Visual-CoT benchmark datasets.
                Each example includes: image, question, ground-truth bounding box, and answer.
                """)

                with gr.Row():
                    with gr.Column(scale=2):
                        dataset_dropdown = gr.Dropdown(
                            choices=list(BENCHMARK_DATASETS.keys()),
                            value="GQA",  # Default to the first configured benchmark
                            label="Select Benchmark Dataset",
                            info=f"Choose from {len(BENCHMARK_DATASETS)} visual reasoning benchmarks",
                        )
                    with gr.Column(scale=1):
                        example_index = gr.Number(
                            value=0,
                            label="Example Index",
                            precision=0,
                            minimum=0,
                        )

                with gr.Row():
                    load_btn = gr.Button("Load Example", variant="primary")
                    prev_btn = gr.Button("◀ Previous")
                    next_btn = gr.Button("Next ▶")

                benchmark_status = gr.Textbox(
                    label="Status",
                    value="Select a dataset and click 'Load Example'",
                    interactive=False,
                )

                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### Image")
                        benchmark_image = gr.Image(
                            label="Input Image",
                            type="pil",
                            height=400,
                        )
                    with gr.Column():
                        gr.Markdown("#### Annotations")
                        benchmark_question = gr.Textbox(
                            label="Question",
                            lines=2,
                            interactive=False,
                        )
                        benchmark_bbox = gr.Textbox(
                            label="Ground Truth Bounding Box",
                            lines=1,
                            interactive=False,
                        )
                        benchmark_answer = gr.Textbox(
                            label="Ground Truth Answer",
                            lines=3,
                            interactive=False,
                        )

                # Dataset information - dynamically generated from BENCHMARK_DATASETS
                dataset_info_md = "---\n\n### Available Benchmark Datasets\n\n"
                for i, (name, info) in enumerate(BENCHMARK_DATASETS.items(), 1):
                    dataset_info_md += f"{i}. **{name}**: {info['description']}\n"
                    dataset_info_md += f"   - Path: `{info['path']}`\n"
                dataset_info_md += f"\n**Total:** {len(BENCHMARK_DATASETS)} benchmarks from the Visual Chain-of-Thought Reasoning Collection\n"
                dataset_info_md += "\n**Source:** [Hugging Face Collection](https://huggingface.co/collections/tuandunghcmut/visual-chain-of-thought-reasoning-benchmarks)"
                gr.Markdown(dataset_info_md)

                # Event handlers
                def load_and_update(dataset_name, index):
                    result = load_benchmark_example(dataset_name, int(index))
                    if len(result) == 5:
                        return result
                    else:
                        # Error case
                        return None, result, "", "", ""

                def increment_index(current_index):
                    return int(current_index) + 1

                def decrement_index(current_index):
                    return max(0, int(current_index) - 1)

                load_btn.click(
                    fn=load_and_update,
                    inputs=[dataset_dropdown, example_index],
                    outputs=[benchmark_image, benchmark_question, benchmark_bbox, benchmark_answer, benchmark_status],
                )
                next_btn.click(
                    fn=increment_index,
                    inputs=[example_index],
                    outputs=[example_index],
                ).then(
                    fn=load_and_update,
                    inputs=[dataset_dropdown, example_index],
                    outputs=[benchmark_image, benchmark_question, benchmark_bbox, benchmark_answer, benchmark_status],
                )
                prev_btn.click(
                    fn=decrement_index,
                    inputs=[example_index],
                    outputs=[example_index],
                ).then(
                    fn=load_and_update,
                    inputs=[dataset_dropdown, example_index],
                    outputs=[benchmark_image, benchmark_question, benchmark_bbox, benchmark_answer, benchmark_status],
                )

            # ============================================================
            # Tab 3: About & Paper
            # ============================================================
            with gr.Tab("About"):
                gr.Markdown("""
                ## Paper Information

                **Title:** Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive Dataset and Benchmark for Chain-of-Thought Reasoning

                **Authors:** Hao Shao, Shengju Qian, Han Xiao, Guanglu Song, Zhuofan Zong, Letian Wang, Yu Liu, Hongsheng Li

                **Conference:** NeurIPS 2024 (Spotlight) 🎉

                **Abstract:**
                We introduce Visual-CoT, a comprehensive dataset and benchmark for evaluating chain-of-thought reasoning
                in multi-modal language models. Our dataset comprises 438K question-answer pairs with intermediate bounding
                box annotations highlighting key regions essential for answering questions. We propose a multi-turn processing
                pipeline that dynamically focuses on visual inputs and provides interpretable reasoning steps.

                ---

                ## Model Architecture

                ### Components

                1. **Vision Encoder**: CLIP ViT-L/14
                   - Input resolution: 224px or 336px
                   - Output: 577 visual tokens (336px) or 257 tokens (224px)
                   - Feature dimension: 1024
                2. **Multi-modal Projector**: 2-layer MLP with GELU
                   - Maps vision features (1024D) to LLM embedding space (4096D)
                   - Trainable parameters: ~8.4M
                3. **Language Model**: Vicuna v1.5 (instruction-tuned LLaMA)
                   - Variants: 7B or 13B parameters
                   - Context length: 2048 tokens
                   - Base: LLaMA architecture

                ### Multi-Turn Processing Pipeline

                ```
                Image + Question
                       ↓
                [Turn 1] ROI Detection
                       → Outputs: Bounding box coordinates [x1, y1, x2, y2]
                       → Purpose: Identify key regions for reasoning
                       ↓
                [Turn 2] Question Answering
                       → Input: Image + Question + Detected bbox
                       → Output: Final answer grounded in visual evidence
                ```
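
                Programmatically, both turns are wrapped by `generate_viscot_response` in this Space's `app.py`.
                A minimal usage sketch (the bounding box shown in the comment is hypothetical, just to illustrate the output shape):

                ```python
                from PIL import Image

                img = Image.open("examples/extreme_ironing.jpg")
                bbox_text, answer, vis, info = generate_viscot_response(
                    img, "What is unusual about this image?", temperature=0.2, max_tokens=512
                )
                print(bbox_text)  # raw Turn-1 text with the normalized ROI, e.g. "[0.32, 0.18, 0.61, 0.74]"
                print(answer)     # Turn-2 answer grounded in the detected region
                # vis is the input image with the ROI drawn; info is a short status string
                ```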

                ---

                ## Training Strategy

                ### Stage 1: Feature Alignment (Pretrain)
                - **Dataset**: 558K LAION-CC-SBU subset with BLIP captions
                - **Objective**: Connect frozen CLIP encoder to frozen LLM
                - **Trainable**: Only the MLP projector (~8.4M params)
                - **Duration**: 3.5 hours (7B) to 5.5 hours (13B) on 8×A100 GPUs
                - **Hyperparameters**:
                  - Batch size: 256
                  - Learning rate: 1e-3
                  - Epochs: 1
                  - Max sequence length: 2048

                ### Stage 2: Visual Instruction Tuning
                - **Dataset Mix**:
                  - 665K multimodal instruction-following (LLaVA-1.5)
                  - 1.4M positional annotation data (Shikra)
                  - 373K Visual-CoT data (ours)
                - **Total**: ~2.4M training instances
                - **Training Details**:
                  - Duration: ~60 hours (7B-224) on 8×A100 GPUs
                  - Batch size: 128
                  - Learning rate: 2e-5 (backbone), 2e-6 (vision encoder)
                  - Epochs: 1
                  - DeepSpeed ZeRO-3 for memory efficiency

                ---

                ## Dataset Construction

                ### Visual-CoT Dataset (438K examples)

                **13 Diverse Benchmarks:**

                1. **Document Understanding** (4 datasets):
                   - DocVQA: Document visual QA
                   - InfographicsVQA: Infographic comprehension
                   - DUDE: Document understanding
                   - SROIE: Scanned receipt information extraction
                2. **Scene Understanding** (3 datasets):
                   - GQA: Scene graph compositional reasoning
                   - Visual7W: Pointing and telling tasks
                   - VSR: Visual spatial reasoning
                3. **Text in Images** (2 datasets):
                   - TextVQA: Reading text in natural images
                   - OCR-VQA: OCR-based question answering
                4. **General VQA** (2 datasets):
                   - Visual Genome: Dense annotations
                   - COCO: Common objects in context
                5. **Specialized** (2 datasets):
                   - CUB: Fine-grained bird classification
                   - Flickr30k: Image captioning & grounding

                **Annotation Details:**
                - Each example includes: image, question, answer, bounding box
                - Bounding boxes highlight key regions essential for reasoning
                - 98K examples have detailed reasoning steps
                - Train/val splits maintained from original benchmarks

                ---

                ## Evaluation & Results

                ### Visual-CoT Benchmark Metrics

                1. **Answer Accuracy**: GPT-3.5-based evaluation
                   - Compares generated answer with ground truth
                   - Accounts for semantic equivalence
                   - Results: 82.7% average accuracy
                2. **Detection Accuracy**: IoU-based bounding box evaluation (see the sketch below)
                   - IoU > 0.5 threshold for correct detection
                   - Results: 75.3% detection accuracy
                   - Validates spatial grounding ability
                3. **Reasoning Quality**: Chain-of-thought coherence
                   - Multi-turn consistency
                   - Interpretability of intermediate steps
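
                Detection accuracy relies on standard box IoU over normalized `[x1, y1, x2, y2]` coordinates.
                A minimal sketch of the metric (not the official evaluation script):

                ```python
                def box_iou(a, b):
                    # a, b: [x1, y1, x2, y2] in normalized coordinates
                    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
                    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
                    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
                    area_a = (a[2] - a[0]) * (a[3] - a[1])
                    area_b = (b[2] - b[0]) * (b[3] - b[1])
                    union = area_a + area_b - inter
                    return inter / union if union > 0 else 0.0

                box_iou([0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.6, 0.6])  # ≈ 0.39
                # A predicted box counts as correct when its IoU with the ground truth exceeds 0.5.
                ```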

                ### Model Comparison

                | Model | Resolution | Params | Answer Acc | Detection Acc |
                |-------|------------|--------|------------|---------------|
                | VisCoT-7B-224 | 224px | 7B | 80.1% | 72.5% |
                | VisCoT-7B-336 | 336px | 7B | 81.8% | 74.2% |
                | VisCoT-13B-224 | 224px | 13B | 81.5% | 73.8% |
                | VisCoT-13B-336 | 336px | 13B | 82.7% | 75.3% |

                **Trade-offs:**
                - Higher resolution → better detail recognition, slower inference
                - Larger model → better reasoning, more memory
                - 336px + 13B = best quality but highest compute cost

                ---

                ## Resources

                - **Paper**: [arXiv:2403.16999](https://arxiv.org/abs/2403.16999)
                - **Code**: [GitHub](https://github.com/deepcs233/Visual-CoT)
                - **Dataset**: [Hugging Face](https://huggingface.co/datasets/deepcs233/Visual-CoT)
                - **Project Page**: [https://hao-shao.com/projects/viscot.html](https://hao-shao.com/projects/viscot.html)
                - **Models**:
                  - [VisCoT-7b-224](https://huggingface.co/deepcs233/VisCoT-7b-224)
                  - [VisCoT-7b-336](https://huggingface.co/deepcs233/VisCoT-7b-336)
                  - [VisCoT-13b-224](https://huggingface.co/deepcs233/VisCoT-13b-224)
                  - [VisCoT-13b-336](https://huggingface.co/deepcs233/VisCoT-13b-336)

                ---

                ## Citation

                If you find our work useful, please cite:

                ```bibtex
                @article{shao2024visual,
                  title={Visual CoT: Unleashing Chain-of-Thought Reasoning in Multi-Modal Language Models},
                  author={Shao, Hao and Qian, Shengju and Xiao, Han and Song, Guanglu and Zong, Zhuofan and Wang, Letian and Liu, Yu and Li, Hongsheng},
                  journal={arXiv preprint arXiv:2403.16999},
                  year={2024}
                }
                ```

                ---

                ## License

                - **Code**: Apache License 2.0
                - **Dataset**: Research use only
                - **Models**: Subject to base LLM license (LLaMA)

                ---

                ## Acknowledgements

                This work is built upon:
                - [LLaVA](https://github.com/haotian-liu/LLaVA) - Base architecture
                - [Shikra](https://github.com/shikras/shikra) - Positional annotations
                - [Vicuna](https://github.com/lm-sys/FastChat) - Language model
                - [CLIP](https://github.com/openai/CLIP) - Vision encoder
                """)

        # Footer
        gr.Markdown("""
        ---
        <div style="text-align: center; color: #666; padding: 20px;">
            <p>Powered by <a href="https://huggingface.co/docs/hub/spaces-zerogpu">Zero GPU</a> on Hugging Face Spaces</p>
        </div>
        """)

    return demo


# =============================================================================
# Launch
# =============================================================================
if __name__ == "__main__":
    demo = create_demo()
    demo.queue(max_size=20)  # Enable queue for Zero GPU
    demo.launch()