🎯 ViT Auditing Toolkit
Comprehensive Model Explainability and Validation Dashboard
# app.py import os import sys import time import gradio as gr import matplotlib.pyplot as plt import numpy as np import torch from PIL import Image # Add src to path sys.path.append(os.path.join(os.path.dirname(__file__), "src")) from auditor import create_auditors from explainer import explain_attention, explain_gradcam, explain_gradient_shap from model_loader import SUPPORTED_MODELS, load_model_and_processor from predictor import create_prediction_plot, predict_image from utils import get_top_predictions_dict, preprocess_image # Global variables to cache model and processor model = None processor = None current_model_name = None auditors = None def load_selected_model(model_name): """Load the selected model and cache it globally.""" global model, processor, current_model_name, auditors try: if model is None or current_model_name != model_name: print(f"Loading model: {model_name}") model, processor = load_model_and_processor(model_name) current_model_name = model_name # Initialize auditors auditors = create_auditors(model, processor) print("✅ Model and auditors loaded successfully!") return f"✅ Model loaded: {model_name}" except Exception as e: return f"❌ Error loading model: {str(e)}" def analyze_image_basic(image, model_choice, xai_method, layer_index, head_index): """ Basic explainability analysis - the core function for Tab 1. """ try: # Load model if needed model_status = load_selected_model(SUPPORTED_MODELS[model_choice]) if "❌" in model_status: return None, None, None, model_status # Preprocess image if image is None: return None, None, None, "⚠️ Please upload an image first." processed_image = preprocess_image(image) # Get predictions probs, indices, labels = predict_image(processed_image, model, processor) pred_fig = create_prediction_plot(probs, labels) # Generate explanation based on selected method explanation_fig = None explanation_image = None if xai_method == "Attention Visualization": explanation_fig = explain_attention( model, processor, processed_image, layer_index=layer_index, head_index=head_index ) elif xai_method == "GradCAM": explanation_fig, explanation_image = explain_gradcam(model, processor, processed_image) elif xai_method == "GradientSHAP": explanation_fig = explain_gradient_shap(model, processor, processed_image, n_samples=3) # Convert predictions to dictionary for Gradio Label pred_dict = get_top_predictions_dict(probs, labels) return ( processed_image, pred_fig, explanation_fig, f"✅ Analysis complete! Top prediction: {labels[0]} ({probs[0]:.2%})", ) except Exception as e: error_msg = f"❌ Analysis failed: {str(e)}" print(error_msg) return None, None, None, error_msg def analyze_counterfactual(image, model_choice, patch_size, perturbation_type): """ Counterfactual analysis for Tab 2. """ try: # Load model if needed model_status = load_selected_model(SUPPORTED_MODELS[model_choice]) if "❌" in model_status: return None, None, model_status if image is None: return None, None, "⚠️ Please upload an image first." processed_image = preprocess_image(image) # Perform counterfactual analysis results = auditors["counterfactual"].patch_perturbation_analysis( processed_image, patch_size=patch_size, perturbation_type=perturbation_type ) # Create summary message summary = ( f"🔍 Counterfactual Analysis Complete!\n" f"• Avg confidence change: {results['avg_confidence_change']:.4f}\n" f"• Prediction flip rate: {results['prediction_flip_rate']:.2%}\n" f"• Most sensitive patch: {results['most_sensitive_patch']}" ) return results["figure"], summary except Exception as e: error_msg = f"❌ Counterfactual analysis failed: {str(e)}" print(error_msg) return None, error_msg def analyze_calibration(image, model_choice, n_bins): """ Confidence calibration analysis for Tab 3. """ try: # Load model if needed model_status = load_selected_model(SUPPORTED_MODELS[model_choice]) if "❌" in model_status: return None, None, model_status if image is None: return None, None, "⚠️ Please upload an image first." processed_image = preprocess_image(image) # For demo purposes, create a simple test set from the uploaded image # In a real scenario, you'd use a proper validation set test_images = [processed_image] * 10 # Create multiple copies # Perform calibration analysis results = auditors["calibration"].analyze_calibration(test_images, n_bins=n_bins) # Create summary message metrics = results["metrics"] summary = ( f"📊 Calibration Analysis Complete!\n" f"• Mean confidence: {metrics['mean_confidence']:.3f}\n" f"• Overconfident rate: {metrics['overconfident_rate']:.2%}\n" f"• Underconfident rate: {metrics['underconfident_rate']:.2%}" ) return results["figure"], summary except Exception as e: error_msg = f"❌ Calibration analysis failed: {str(e)}" print(error_msg) return None, error_msg def analyze_bias_detection(image, model_choice): """ Bias detection analysis for Tab 4. """ try: # Load model if needed model_status = load_selected_model(SUPPORTED_MODELS[model_choice]) if "❌" in model_status: return None, None, model_status if image is None: return None, None, "⚠️ Please upload an image first." processed_image = preprocess_image(image) # Create demo subgroups based on the uploaded image # In a real scenario, you'd use predefined subgroups from your dataset subsets = [] subset_names = ["Original", "Brightness+", "Brightness-", "Contrast+"] # Original image subsets.append([processed_image]) # Brightness increased bright_image = processed_image.copy().point(lambda p: min(255, p * 1.5)) subsets.append([bright_image]) # Brightness decreased dark_image = processed_image.copy().point(lambda p: p * 0.7) subsets.append([dark_image]) # Contrast increased contrast_image = processed_image.copy().point(lambda p: 128 + (p - 128) * 1.5) subsets.append([contrast_image]) # Perform bias analysis results = auditors["bias"].analyze_subgroup_performance(subsets, subset_names) # Create summary message subgroup_metrics = results["subgroup_metrics"] summary = f"⚖️ Bias Detection Complete!\nAnalyzed {len(subgroup_metrics)} subgroups:\n" for name, metrics in subgroup_metrics.items(): summary += f"• {name}: confidence={metrics['mean_confidence']:.3f}\n" return results["figure"], summary except Exception as e: error_msg = f"❌ Bias detection failed: {str(e)}" print(error_msg) return None, error_msg def create_demo_image(): """Create a demo image for first-time users.""" # Create a simple demo image with multiple colors img = Image.new("RGB", (224, 224), color=(150, 100, 100)) # Add different colored regions for x in range(50, 150): for y in range(50, 150): img.putpixel((x, y), (100, 200, 100)) # Green square for x in range(160, 200): for y in range(160, 200): img.putpixel((x, y), (100, 100, 200)) # Blue square return img # Minimal CSS for basic styling without breaking functionality custom_css = """ /* Basic styling without interfering with dropdowns */ .gradio-container { background: linear-gradient(135deg, #0f1419 0%, #1a1f2e 50%, #0f1419 100%); font-family: 'Inter', sans-serif; } /* Header styling */ .main-header { background: rgba(99, 102, 241, 0.05); border-radius: 20px; padding: 2.5rem; margin-bottom: 2rem; } /* Button styling */ button.primary { background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%); border: none; color: white; font-weight: 600; padding: 14px 32px; border-radius: 12px; } button.primary:hover { transform: translateY(-2px); box-shadow: 0 6px 24px rgba(99, 102, 241, 0.6); } /* Block styling */ .block { background: rgba(30, 41, 59, 0.4); border-radius: 16px; padding: 1.5rem; border: 1px solid rgba(99, 102, 241, 0.15); } /* Tab styling */ .tab-nav button { background: rgba(30, 41, 59, 0.5); border: 1px solid rgba(99, 102, 241, 0.2); border-radius: 12px; padding: 14px 28px; margin: 0 6px; color: #94a3b8; font-weight: 600; } .tab-nav button.selected { background: linear-gradient(135deg, #6366f1 0%, #8b5cf6 100%); color: white; } """ # Create the Gradio interface with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="ViT Auditing Toolkit") as demo: # Main Header gr.HTML( """
Comprehensive Model Explainability and Validation Dashboard
This interactive dashboard provides comprehensive auditing capabilities for Vision Transformer models, enabling researchers and practitioners to understand, validate, and improve their AI models through multiple explainability techniques.
Understand model predictions with attention maps, GradCAM, and SHAP visualizations
Test prediction robustness by systematically perturbing image regions
Evaluate whether model confidence scores accurately reflect prediction reliability
Identify performance variations across different demographic or data subgroups
Choose a Vision Transformer model from the dropdown and click "Load Model" button
Navigate to any tab and upload an image you want to analyze
Select from 4 tabs: Basic Explainability, Counterfactual Analysis, Confidence Calibration, or Bias Detection
Adjust settings if needed, then click the analysis button to see results and visualizations
💡 Tip: Start with "Basic Explainability" to understand what your model sees, then explore advanced auditing features in other tabs.
Built with ❤️ using Gradio, Transformers, and Captum
© 2024 ViT Auditing Toolkit • For research and educational purposes