"""Gradio front end for the MobileCLIP-B zero-shot image classifier.

The module builds a ``gr.Blocks`` app with three tabs (classification, admin
panel, about) plus a UI-less ``classify_base64`` API endpoint. All model work
is delegated to the project's ``EndpointHandler``, which is called with a
``{"inputs": {...}}`` payload for both classification and admin operations.
"""

import base64
import io
import json
import os

import gradio as gr
from PIL import Image

from handler import EndpointHandler

# Initialize handler
print("Initializing MobileCLIP handler...")
try:
    handler = EndpointHandler()
    print(f"Handler initialized successfully! Device: {handler.device}")
except Exception as e:
    # Keep the app importable so the UI can report the failure instead of
    # crashing at startup; every callback checks for ``handler is None``.
    print(f"Error initializing handler: {e}")
    handler = None


def classify_image(image, top_k=10):
    """Classify an uploaded image and format the result for the public UI.

    Args:
        image: PIL image supplied by the ``gr.Image`` component, or ``None``
            when no image is loaded.
        top_k: Number of results requested from the handler.

    Returns:
        A ``(markdown_text, rows)`` pair. ``rows`` is a list of
        ``[label, score]`` rows for the results table, or ``None`` whenever an
        error/informational message is returned instead.
    """
    if handler is None:
        return "Error: Handler not initialized", None

    if image is None:
        return "Please upload an image", None

    try:
        # The handler takes the image as a base64 string, so re-encode the
        # PIL image as PNG bytes first.
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode()

        result = handler({
            "inputs": {
                "image": img_b64,
                "top_k": int(top_k)
            }
        })

        # A list is the success shape; anything else is treated as an
        # error payload.
        if not isinstance(result, list):
            return f"Error: {result.get('error', 'Unknown error')}", None

        text_parts = ["**Top {} Classifications:**\n\n".format(len(result))]
        chart_data = []
        for rank, item in enumerate(result, 1):
            score_pct = item['score'] * 100
            text_parts.append(
                f"{rank}. **{item['label']}** (ID: {item['id']}): {score_pct:.2f}%\n"
            )
            # Rows are lists (not tuples) to match the list-of-lists shape
            # the Dataframe component documents.
            chart_data.append([item['label'], item['score']])

        return "".join(text_parts), chart_data

    except Exception as e:
        return f"Error: {str(e)}", None


def upsert_labels_admin(admin_token, new_items_json):
    """Add new labels through the handler's ``upsert_labels`` operation.

    Args:
        admin_token: Token forwarded to the handler for authorization.
        new_items_json: JSON text describing the items to add; an empty value
            is sent as an empty list.

    Returns:
        A human-readable status message for display in the admin panel.
    """
    if handler is None:
        return "Error: Handler not initialized"

    if not admin_token:
        return "Error: Admin token required"

    try:
        # Parse the JSON input
        items = json.loads(new_items_json) if new_items_json else []

        result = handler({
            "inputs": {
                "op": "upsert_labels",
                "token": admin_token,
                "items": items
            }
        })

        if result.get("status") == "ok":
            return (
                f"✅ Success! Added {result.get('added', 0)} new labels. "
                f"Current version: {result.get('labels_version', 'unknown')}"
            )
        elif result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        else:
            return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"

    except json.JSONDecodeError:
        return "❌ Error: Invalid JSON format"
    except Exception as e:
        return f"❌ Error: {str(e)}"


def reload_labels_admin(admin_token, version):
    """Reload a label version through the handler's ``reload_labels`` operation.

    Args:
        admin_token: Token forwarded to the handler for authorization.
        version: Version number to load; a falsy value falls back to ``1``.

    Returns:
        A human-readable status message for display in the admin panel.
    """
    if handler is None:
        return "Error: Handler not initialized"

    if not admin_token:
        return "Error: Admin token required"

    try:
        result = handler({
            "inputs": {
                "op": "reload_labels",
                "token": admin_token,
                "version": int(version) if version else 1
            }
        })

        if result.get("status") == "ok":
            return f"✅ Labels reloaded successfully! Current version: {result.get('labels_version', 'unknown')}"
        elif result.get("status") == "nochange":
            return f"ℹ️ No change needed. Current version: {result.get('labels_version', 'unknown')}"
        elif result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        elif result.get("error") == "invalid_version":
            return "❌ Error: Invalid version number"
        else:
            return f"❌ Error: {result.get('error', 'Unknown error')}"

    except Exception as e:
        return f"❌ Error: {str(e)}"


def get_current_stats():
    """Return a Markdown summary of the handler's current label state.

    Reads ``class_ids``, ``labels_version``, ``device`` and ``class_names``
    from the handler when present, and returns an error string if anything
    goes wrong.
    """
    if handler is None:
        return "Handler not initialized"

    try:
        num_labels = len(handler.class_ids) if hasattr(handler, 'class_ids') else 0
        version = getattr(handler, 'labels_version', 1)
        device = handler.device if hasattr(handler, 'device') else "unknown"

        # Build the text line by line so the bullet list stays contiguous and
        # never depends on source-code indentation.
        stat_lines = [
            "**Current Statistics:**",
            f"- Number of labels: {num_labels}",
            f"- Labels version: {version}",
            f"- Device: {device}",
            "- Model: MobileCLIP-B",
        ]

        if hasattr(handler, 'class_names') and len(handler.class_names) > 0:
            sample = f"- Sample labels: {', '.join(handler.class_names[:5])}"
            if len(handler.class_names) > 5:
                sample += "..."
            stat_lines.append(sample)

        return "\n".join(stat_lines) + "\n"
    except Exception as e:
        return f"Error getting stats: {str(e)}"


def get_labels_table():
    """Return the current labels as ``[id, name]`` rows for a ``gr.Dataframe``.

    The return value is always a list of rows. Status and error messages are
    delivered as a single placeholder row rather than a bare string, because
    a string handed to ``gr.Dataframe`` is interpreted as a CSV file path.
    """
    if handler is None:
        return [["", "Handler not initialized"]]

    if not hasattr(handler, 'class_ids') or len(handler.class_ids) == 0:
        return [["", "No labels currently loaded"]]

    try:
        return [
            [int(class_id), class_name]
            for class_id, class_name in zip(handler.class_ids, handler.class_names)
        ]
    except Exception as e:
        return [["", f"Error getting labels: {str(e)}"]]


def remove_labels_admin(admin_token, ids_to_remove_str):
    """Remove labels by ID through the handler's ``remove_labels`` operation.

    Args:
        admin_token: Token forwarded to the handler for authorization.
        ids_to_remove_str: Comma-separated integer IDs, e.g. ``"1001, 1002"``.

    Returns:
        A human-readable status message; on success it lists the names that
        matched the requested IDs before the removal was issued.
    """
    if handler is None:
        return "Error: Handler not initialized"

    if not admin_token:
        return "Error: Admin token required"

    if not ids_to_remove_str or ids_to_remove_str.strip() == "":
        return "❌ Error: Please provide IDs to remove (comma-separated)"

    # Only ID parsing is guarded by ValueError, so an unrelated ValueError
    # raised later is not misreported as a formatting problem.
    try:
        ids_to_remove = [
            int(token.strip())
            for token in ids_to_remove_str.split(',')
            if token.strip()
        ]
    except ValueError:
        return "❌ Error: Invalid ID format. Please provide comma-separated numbers (e.g., 1001,1002,1003)"

    if not ids_to_remove:
        return "❌ Error: No valid IDs provided"

    try:
        # Snapshot the names before calling the handler, since the labels are
        # gone afterwards. setdefault keeps the first name for a repeated ID.
        name_by_id = {}
        if hasattr(handler, 'class_ids') and hasattr(handler, 'class_names'):
            for class_id, class_name in zip(handler.class_ids, handler.class_names):
                name_by_id.setdefault(int(class_id), class_name)

        removed_names = [
            f"{label_id}: {name_by_id[label_id]}"
            for label_id in ids_to_remove
            if label_id in name_by_id
        ]

        result = handler({
            "inputs": {
                "op": "remove_labels",
                "token": admin_token,
                "ids": ids_to_remove
            }
        })

        if result.get("status") == "ok":
            removed_list = "\n".join(removed_names) if removed_names else "None found"
            return (
                f"✅ Success! Removed {result.get('removed', 0)} labels. "
                f"Current version: {result.get('labels_version', 'unknown')}"
                f"\n\nRemoved items:\n{removed_list}"
            )
        elif result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        elif result.get("error") == "no_ids_provided":
            return "❌ Error: No IDs provided"
        else:
            return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"

    except Exception as e:
        return f"❌ Error: {str(e)}"


# Create Gradio interface
print("Creating Gradio interface...")
with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
    gr.Markdown("""
    # 🖼️ MobileCLIP-B Zero-Shot Image Classifier

    Upload an image to classify it using MobileCLIP-B model with dynamic label management.
    """)

    with gr.Tab("🔍 Image Classification"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="pil",
                    label="Upload Image"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Number of top results to show"
                )
                classify_btn = gr.Button("🚀 Classify Image", variant="primary")

            with gr.Column():
                output_text = gr.Markdown(label="Classification Results")
                # Scores are shown in a plain table rather than a chart
                output_chart = gr.Dataframe(
                    headers=["Label", "Confidence"],
                    label="Classification Scores",
                    interactive=False
                )

        # Event handler for classification
        classify_btn.click(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart],
            api_name="classify_image"
        )

        # Also trigger on image upload
        input_image.change(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart],
            api_name="classify_image_1"
        )

    with gr.Tab("🔧 Admin Panel"):
        gr.Markdown("""
        ### Admin Functions
        **Note:** Requires admin token (set via environment variable `ADMIN_TOKEN`)
        """)

        with gr.Row():
            admin_token_input = gr.Textbox(
                label="Admin Token",
                type="password",
                placeholder="Enter admin token"
            )

        with gr.Accordion("📊 Current Statistics", open=True):
            stats_display = gr.Markdown(value=get_current_stats())
            refresh_stats_btn = gr.Button("🔄 Refresh Stats")
            refresh_stats_btn.click(
                fn=get_current_stats,
                inputs=[],
                outputs=stats_display
            )

        with gr.Accordion("➕ Add New Labels", open=False):
            gr.Markdown("""
            Add new labels by providing JSON array:
            ```json
            [
              {"id": 100, "name": "new_object", "prompt": "a photo of a new_object"},
              {"id": 101, "name": "another_object", "prompt": "a photo of another_object"}
            ]
            ```
            """)
            new_items_input = gr.Code(
                label="New Items JSON",
                language="json",
                lines=5,
                value='[\n {"id": 100, "name": "example", "prompt": "a photo of example"}\n]'
            )
            upsert_btn = gr.Button("➕ Add Labels", variant="primary")
            upsert_output = gr.Markdown()

            upsert_btn.click(
                fn=upsert_labels_admin,
                inputs=[admin_token_input, new_items_input],
                outputs=upsert_output,
                api_name="upsert_labels_admin"
            )

        with gr.Accordion("🔄 Reload Label Version", open=False):
            gr.Markdown("Reload labels from a specific version stored in the Hub")
            version_input = gr.Number(
                label="Version Number",
                value=1,
                precision=0
            )
            reload_btn = gr.Button("🔄 Reload Version", variant="primary")
            reload_output = gr.Markdown()

            reload_btn.click(
                fn=reload_labels_admin,
                inputs=[admin_token_input, version_input],
                outputs=reload_output
            )

        with gr.Accordion("🗑️ Remove Labels", open=False):
            gr.Markdown("Remove specific labels by their IDs")

            # Display current labels
            labels_table = gr.Dataframe(
                value=get_labels_table(),
                headers=["ID", "Name"],
                label="Current Labels",
                interactive=False,
                height=300
            )

            refresh_labels_btn = gr.Button("🔄 Refresh Label List", size="sm")
            refresh_labels_btn.click(
                fn=get_labels_table,
                inputs=[],
                outputs=labels_table
            )

            gr.Markdown("Enter IDs to remove (comma-separated):")
            ids_to_remove_input = gr.Textbox(
                label="IDs to Remove",
                placeholder="e.g., 1001, 1002, 1003",
                lines=1
            )

            remove_btn = gr.Button("🗑️ Remove Selected Labels", variant="stop")
            remove_output = gr.Markdown()

            def remove_and_refresh(token, ids):
                """Run the removal, then re-read the label table for display."""
                result = remove_labels_admin(token, ids)
                updated_table = get_labels_table()
                return result, updated_table

            remove_btn.click(
                fn=remove_and_refresh,
                inputs=[admin_token_input, ids_to_remove_input],
                outputs=[remove_output, labels_table]
            )

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About MobileCLIP-B Classifier

        This Space provides a web interface for Apple's MobileCLIP-B model, optimized for fast zero-shot image classification.

        ### Features:
        - 🚀 **Fast inference**: < 30ms on GPU
        - 🏷️ **Dynamic labels**: Add/update labels without redeployment
        - 🔄 **Version control**: Track and reload label versions
        - 📊 **Visual results**: Classification scores and confidence

        ### Environment Variables (set in Space Settings):
        - `ADMIN_TOKEN`: Secret token for admin operations
        - `HF_LABEL_REPO`: Hub repository for label storage
        - `HF_WRITE_TOKEN`: Token with write permissions to label repo
        - `HF_READ_TOKEN`: Token with read permissions (optional)

        ### Model Details:
        - **Architecture**: MobileCLIP-B with MobileOne blocks
        - **Text Encoder**: Transformer-based, 77 token context
        - **Image Size**: 224x224
        - **Embedding Dim**: 512

        ### License:
        Model weights are licensed under Apple Sample Code License (ASCL).
        """)

print("Gradio interface created successfully!")


# Pure API endpoint for base64 classification (no UI)
def classify_base64(image_b64: str, top_k: int = 10):
    """Classify a base64-encoded image and return the handler's raw result.

    API-only endpoint that accepts base64 images directly, so backends can
    call it without going through a file upload.

    Args:
        image_b64: Base64-encoded image, forwarded to the handler unchanged.
        top_k: Number of results requested from the handler.

    Returns:
        Whatever the handler returns, or ``{"error": ...}`` if the handler is
        unavailable or raises.
    """
    if handler is None:
        return {"error": "handler not initialized"}

    try:
        # Call handler directly with base64
        result = handler({
            "inputs": {
                "image": image_b64,
                "top_k": int(top_k)
            }
        })
        return result
    except Exception as e:
        return {"error": str(e)}


# Register the API endpoint (no UI)
with demo:
    gr.Interface(
        fn=classify_base64,
        inputs=[
            gr.Textbox(label="image_b64", visible=False),
            gr.Number(label="top_k", visible=False)
        ],
        outputs=gr.JSON(visible=False),
        api_name="classify_base64",
        visible=False
    )

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch()