"""Gradio front end for the MobileCLIP zero-shot image classifier.

The module builds a three-tab UI (classification, admin panel, about) around
an ``EndpointHandler`` instance.  Every UI callback forwards a JSON-like
payload to the handler and turns the response into Markdown / table data.
"""

import base64
import io
import json
import os

import gradio as gr
from PIL import Image

from handler import EndpointHandler

# Initialize handler
print("Initializing MobileCLIP handler...")
try:
    handler = EndpointHandler()
    print(f"Handler initialized successfully! Device: {handler.device}")
except Exception as e:
    # Deliberate best effort: the UI still starts and every callback reports
    # "Handler not initialized" instead of the whole app failing to launch.
    print(f"Error initializing handler: {e}")
    handler = None


def classify_image(image, top_k=10):
    """Classify an uploaded image and format the result for the UI.

    Args:
        image: PIL image supplied by the ``gr.Image`` component, or ``None``
            when nothing is uploaded.
        top_k: Number of results requested from the handler (cast to ``int``).

    Returns:
        A ``(markdown_text, rows)`` tuple.  ``rows`` is a list of
        ``[label, score]`` rows for the dataframe, or ``None`` whenever an
        error/notice message is returned instead of results.
    """
    if handler is None:
        return "Error: Handler not initialized", None

    if image is None:
        return "Please upload an image", None

    try:
        # Convert PIL image to base64 (the handler is called with a PNG payload)
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode()

        # Call handler
        result = handler({
            "inputs": {
                "image": img_b64,
                "top_k": int(top_k),
            }
        })

        # A list is treated as success: each item is read for its
        # 'label', 'id' and 'score' keys.
        if isinstance(result, list):
            lines = ["**Top {} Classifications:**\n\n".format(len(result))]
            chart_data = []
            for rank, item in enumerate(result, 1):
                score_pct = item['score'] * 100
                lines.append(
                    f"{rank}. **{item['label']}** (ID: {item['id']}): {score_pct:.2f}%\n"
                )
                # Rows are lists, the documented row format of gr.Dataframe.
                chart_data.append([item['label'], item['score']])
            return "".join(lines), chart_data

        if isinstance(result, dict):
            return f"Error: {result.get('error', 'Unknown error')}", None

        # Neither a result list nor an error dict: report it explicitly
        # rather than leaking an AttributeError message to the user.
        return (
            f"Error: Unexpected handler response of type {type(result).__name__}",
            None,
        )

    except Exception as e:
        # UI boundary: surface any failure as text instead of crashing the app.
        return f"Error: {str(e)}", None


def upsert_labels_admin(admin_token, new_items_json):
    """Send an ``upsert_labels`` operation to the handler.

    Args:
        admin_token: Token forwarded to the handler as ``token``.
        new_items_json: JSON text describing the items to add; an empty value
            is sent as an empty list.

    Returns:
        A status message string for display in the admin panel.
    """
    if handler is None:
        return "Error: Handler not initialized"

    if not admin_token:
        return "Error: Admin token required"

    try:
        # Parse the JSON input
        items = json.loads(new_items_json) if new_items_json else []

        result = handler({
            "inputs": {
                "op": "upsert_labels",
                "token": admin_token,
                "items": items,
            }
        })

        if not isinstance(result, dict):
            return (
                "❌ Error: Unexpected handler response of type "
                f"{type(result).__name__}"
            )

        if result.get("status") == "ok":
            return (
                f"✅ Success! Added {result.get('added', 0)} new labels. "
                f"Current version: {result.get('labels_version', 'unknown')}"
            )
        if result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        return f"❌ Error: {result.get('detail', result.get('error', 'Unknown error'))}"

    except json.JSONDecodeError:
        return "❌ Error: Invalid JSON format"
    except Exception as e:
        return f"❌ Error: {str(e)}"


def reload_labels_admin(admin_token, version):
    """Send a ``reload_labels`` operation to the handler.

    Args:
        admin_token: Token forwarded to the handler as ``token``.
        version: Requested label version.  ``None`` or an empty string falls
            back to ``1``; any other value is cast to ``int`` and passed
            through unchanged (including ``0``, so the handler decides whether
            it is valid).

    Returns:
        A status message string for display in the admin panel.
    """
    if handler is None:
        return "Error: Handler not initialized"

    if not admin_token:
        return "Error: Admin token required"

    try:
        # Only a missing value defaults to 1.  A plain truthiness test would
        # also silently rewrite an explicit 0 into 1.
        if version is None or version == "":
            requested_version = 1
        else:
            requested_version = int(version)

        result = handler({
            "inputs": {
                "op": "reload_labels",
                "token": admin_token,
                "version": requested_version,
            }
        })

        if not isinstance(result, dict):
            return (
                "❌ Error: Unexpected handler response of type "
                f"{type(result).__name__}"
            )

        current = result.get('labels_version', 'unknown')
        if result.get("status") == "ok":
            return f"✅ Labels reloaded successfully! Current version: {current}"
        if result.get("status") == "nochange":
            return f"ℹ️ No change needed. Current version: {current}"
        if result.get("error") == "unauthorized":
            return "❌ Error: Invalid admin token"
        if result.get("error") == "invalid_version":
            return "❌ Error: Invalid version number"
        return f"❌ Error: {result.get('error', 'Unknown error')}"

    except Exception as e:
        return f"❌ Error: {str(e)}"


def get_current_stats():
    """Return a Markdown summary of the handler's current label state.

    Reads ``class_ids``, ``labels_version``, ``device`` and ``class_names``
    from the handler when present, substituting defaults for missing
    attributes.

    Returns:
        Markdown text, or a plain error/notice string.
    """
    if handler is None:
        return "Handler not initialized"

    try:
        num_labels = len(handler.class_ids) if hasattr(handler, 'class_ids') else 0
        version = getattr(handler, 'labels_version', 1)
        device = getattr(handler, 'device', "unknown")

        # Assembled from column-0 lines so the Markdown never depends on the
        # source indentation of a triple-quoted literal.
        stats = (
            "\n"
            "**Current Statistics:**\n"
            f"- Number of labels: {num_labels}\n"
            f"- Labels version: {version}\n"
            f"- Device: {device}\n"
            "- Model: MobileCLIP-B\n"
        )

        class_names = getattr(handler, 'class_names', None)
        if class_names is not None and len(class_names) > 0:
            # str() keeps the join from failing on non-string label entries.
            sample = ', '.join(str(name) for name in class_names[:5])
            stats += f"\n- Sample labels: {sample}"
            if len(class_names) > 5:
                stats += "..."

        return stats
    except Exception as e:
        return f"Error getting stats: {str(e)}"


# Create Gradio interface
print("Creating Gradio interface...")
with gr.Blocks(title="MobileCLIP Image Classifier") as demo:
    gr.Markdown("""
    # 🖼️ MobileCLIP-B Zero-Shot Image Classifier

    Upload an image to classify it using MobileCLIP-B model with dynamic label management.
    """)

    with gr.Tab("🔍 Image Classification"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    type="pil",
                    label="Upload Image"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=10,
                    step=1,
                    label="Number of top results to show"
                )
                classify_btn = gr.Button("🚀 Classify Image", variant="primary")

            with gr.Column():
                output_text = gr.Markdown(label="Classification Results")
                # Simplified bar chart using Dataframe
                output_chart = gr.Dataframe(
                    headers=["Label", "Confidence"],
                    label="Classification Scores",
                    interactive=False
                )

        # Event handler for classification
        classify_btn.click(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart]
        )

        # Also trigger on image upload
        input_image.change(
            fn=classify_image,
            inputs=[input_image, top_k_slider],
            outputs=[output_text, output_chart]
        )

    with gr.Tab("🔧 Admin Panel"):
        gr.Markdown("""
        ### Admin Functions

        **Note:** Requires admin token (set via environment variable `ADMIN_TOKEN`)
        """)

        with gr.Row():
            admin_token_input = gr.Textbox(
                label="Admin Token",
                type="password",
                placeholder="Enter admin token"
            )

        with gr.Accordion("📊 Current Statistics", open=True):
            # Pass the function itself (not its result) so the statistics are
            # recomputed on each page load rather than frozen at build time.
            stats_display = gr.Markdown(value=get_current_stats)
            refresh_stats_btn = gr.Button("🔄 Refresh Stats")
            refresh_stats_btn.click(
                fn=get_current_stats,
                inputs=[],
                outputs=stats_display
            )

        with gr.Accordion("➕ Add New Labels", open=False):
            gr.Markdown("""
            Add new labels by providing JSON array:

            ```json
            [
              {"id": 100, "name": "new_object", "prompt": "a photo of a new_object"},
              {"id": 101, "name": "another_object", "prompt": "a photo of another_object"}
            ]
            ```
            """)
            new_items_input = gr.Code(
                label="New Items JSON",
                language="json",
                lines=5,
                value='[\n {"id": 100, "name": "example", "prompt": "a photo of example"}\n]'
            )
            upsert_btn = gr.Button("➕ Add Labels", variant="primary")
            upsert_output = gr.Markdown()

            upsert_btn.click(
                fn=upsert_labels_admin,
                inputs=[admin_token_input, new_items_input],
                outputs=upsert_output
            )

        with gr.Accordion("🔄 Reload Label Version", open=False):
            gr.Markdown("Reload labels from a specific version stored in the Hub")
            version_input = gr.Number(
                label="Version Number",
                value=1,
                precision=0
            )
            reload_btn = gr.Button("🔄 Reload Version", variant="primary")
            reload_output = gr.Markdown()

            reload_btn.click(
                fn=reload_labels_admin,
                inputs=[admin_token_input, version_input],
                outputs=reload_output
            )

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About MobileCLIP-B Classifier

        This Space provides a web interface for Apple's MobileCLIP-B model, optimized for fast zero-shot image classification.

        ### Features:
        - 🚀 **Fast inference**: < 30ms on GPU
        - 🏷️ **Dynamic labels**: Add/update labels without redeployment
        - 🔄 **Version control**: Track and reload label versions
        - 📊 **Visual results**: Classification scores and confidence

        ### Environment Variables (set in Space Settings):
        - `ADMIN_TOKEN`: Secret token for admin operations
        - `HF_LABEL_REPO`: Hub repository for label storage
        - `HF_WRITE_TOKEN`: Token with write permissions to label repo
        - `HF_READ_TOKEN`: Token with read permissions (optional)

        ### Model Details:
        - **Architecture**: MobileCLIP-B with MobileOne blocks
        - **Text Encoder**: Transformer-based, 77 token context
        - **Image Size**: 224x224
        - **Embedding Dim**: 512

        ### License:
        Model weights are licensed under Apple Sample Code License (ASCL).
        """)

print("Gradio interface created successfully!")

if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch()