# src/auditor.py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFilter
from scipy import stats
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss

class CounterfactualAnalyzer:
    """Analyze how predictions change with image perturbations."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device

    def patch_perturbation_analysis(self, image, patch_size=16, perturbation_type="blur"):
        """
        Analyze how predictions change when different patches are perturbed.

        Args:
            image: PIL Image
            patch_size: Size of patches to perturb
            perturbation_type: Type of perturbation ('blur', 'noise', 'blackout', 'gray')

        Returns:
            dict: Analysis results with visualizations
        """
        original_probs, _, original_labels = self._predict_image(image)
        original_top_label = original_labels[0]
        original_confidence = original_probs[0]

        # Get image dimensions
        width, height = image.size

        # Create grid of patches
        patches_x = width // patch_size
        patches_y = height // patch_size

        # Store results
        confidence_changes = []
        prediction_changes = []
        patch_heatmap = np.zeros((patches_y, patches_x))

        for i in range(patches_y):
            for j in range(patches_x):
                # Create perturbed image
                perturbed_img = self._perturb_patch(
                    image.copy(), j, i, patch_size, perturbation_type
                )

                # Get prediction on perturbed image
                perturbed_probs, _, perturbed_labels = self._predict_image(perturbed_img)
                perturbed_confidence = perturbed_probs[0]
                perturbed_label = perturbed_labels[0]

                # Calculate changes
                confidence_change = perturbed_confidence - original_confidence
                prediction_change = 1 if perturbed_label != original_top_label else 0

                confidence_changes.append(confidence_change)
                prediction_changes.append(prediction_change)
                patch_heatmap[i, j] = confidence_change

        # Create visualization
        fig = self._create_counterfactual_visualization(
            image,
            patch_heatmap,
            patch_size,
            original_top_label,
            original_confidence,
            confidence_changes,
            prediction_changes,
        )

        return {
            "figure": fig,
            "patch_heatmap": patch_heatmap,
            "avg_confidence_change": np.mean(confidence_changes),
            "prediction_flip_rate": np.mean(prediction_changes),
            "most_sensitive_patch": np.unravel_index(np.argmin(patch_heatmap), patch_heatmap.shape),
        }
    def _perturb_patch(self, image, patch_x, patch_y, patch_size, perturbation_type):
        """Apply perturbation to a specific patch."""
        left = patch_x * patch_size
        upper = patch_y * patch_size
        right = left + patch_size
        lower = upper + patch_size
        patch_box = (left, upper, right, lower)

        if perturbation_type == "blur":
            # Extract patch, blur it, and paste back
            patch = image.crop(patch_box)
            blurred_patch = patch.filter(ImageFilter.GaussianBlur(5))
            image.paste(blurred_patch, patch_box)
        elif perturbation_type == "blackout":
            # Black out the patch
            draw = ImageDraw.Draw(image)
            draw.rectangle(patch_box, fill="black")
        elif perturbation_type == "gray":
            # Convert patch to grayscale
            patch = image.crop(patch_box)
            gray_patch = patch.convert("L").convert("RGB")
            image.paste(gray_patch, patch_box)
        elif perturbation_type == "noise":
            # Add Gaussian noise to the patch; work in a signed dtype so negative
            # noise values do not wrap around in uint8 before clipping
            patch = np.array(image.crop(patch_box)).astype(np.int16)
            noise = np.random.normal(0, 50, patch.shape)
            noisy_patch = np.clip(patch + noise, 0, 255).astype(np.uint8)
            image.paste(Image.fromarray(noisy_patch), patch_box)

        return image
    def _predict_image(self, image):
        """Helper function to get predictions."""
        from predictor import predict_image

        return predict_image(image, self.model, self.processor, top_k=5)

    def _create_counterfactual_visualization(
        self,
        image,
        patch_heatmap,
        patch_size,
        original_label,
        original_confidence,
        confidence_changes,
        prediction_changes,
    ):
        """Create visualization for counterfactual analysis."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        # Original image
        ax1.imshow(image)
        ax1.set_title(
            f"Original Image\nPrediction: {original_label} ({original_confidence:.2%})",
            fontweight="bold",
        )
        ax1.axis("off")

        # Patch sensitivity heatmap
        im = ax2.imshow(patch_heatmap, cmap="RdBu_r", vmin=-0.5, vmax=0.5)
        ax2.set_title(
            "Patch Sensitivity Heatmap\n(Confidence Change When Perturbed)", fontweight="bold"
        )
        ax2.set_xlabel("Patch X")
        ax2.set_ylabel("Patch Y")
        plt.colorbar(im, ax=ax2, label="Confidence Change")

        # Add patch grid to original image
        for i in range(patch_heatmap.shape[0]):
            for j in range(patch_heatmap.shape[1]):
                rect = plt.Rectangle(
                    (j * patch_size, i * patch_size),
                    patch_size,
                    patch_size,
                    linewidth=1,
                    edgecolor="red",
                    facecolor="none",
                    alpha=0.3,
                )
                ax1.add_patch(rect)

        # Confidence change distribution
        ax3.hist(confidence_changes, bins=20, alpha=0.7, color="skyblue")
        ax3.axvline(0, color="red", linestyle="--", label="No Change")
        ax3.set_xlabel("Confidence Change")
        ax3.set_ylabel("Frequency")
        ax3.set_title("Distribution of Confidence Changes", fontweight="bold")
        ax3.legend()
        ax3.grid(alpha=0.3)

        # Prediction flip analysis
        flip_rate = np.mean(prediction_changes)
        ax4.bar(["No Flip", "Flip"], [1 - flip_rate, flip_rate], color=["green", "red"])
        ax4.set_ylabel("Proportion")
        ax4.set_title(f"Prediction Flip Rate: {flip_rate:.2%}", fontweight="bold")
        ax4.grid(alpha=0.3)

        plt.tight_layout()
        return fig
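

# --- Illustrative usage: counterfactual audit (sketch only, never called) ----
# A minimal sketch of how this class might be driven, assuming `model` and
# `processor` form an image-classification pair that the local
# `predictor.predict_image` helper supports; the image path and the 224x224
# resize are placeholders.
def _example_counterfactual_audit(model, processor, image_path="example.jpg"):
    """Run one patch-perturbation audit and print the headline numbers."""
    analyzer = CounterfactualAnalyzer(model, processor)
    image = Image.open(image_path).convert("RGB").resize((224, 224))
    result = analyzer.patch_perturbation_analysis(
        image, patch_size=16, perturbation_type="blackout"
    )
    print(f"Average confidence change: {result['avg_confidence_change']:+.3f}")
    print(f"Prediction flip rate: {result['prediction_flip_rate']:.2%}")
    print(f"Most sensitive patch (row, col): {result['most_sensitive_patch']}")
    return result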

class ConfidenceCalibrationAnalyzer:
    """Analyze model calibration and confidence metrics."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device

    def analyze_calibration(self, test_images, test_labels=None, n_bins=10):
        """
        Analyze model calibration using confidence scores.

        Args:
            test_images: List of PIL Images for testing
            test_labels: Optional true labels for accuracy calculation
            n_bins: Number of bins for calibration curve

        Returns:
            dict: Calibration analysis results
        """
        topk_probs = []
        predictions = []
        max_confidences = []

        # Get predictions and confidences (top-1 score plus full top-k distribution)
        for img in test_images:
            probs, indices, labels = self._predict_image(img)
            max_confidences.append(probs[0])
            predictions.append(labels[0])
            topk_probs.append(probs)

        max_confidences = np.array(max_confidences)

        # Create calibration analysis
        fig = self._create_calibration_visualization(
            max_confidences, topk_probs, test_labels, predictions, n_bins
        )

        # Calculate calibration metrics
        calibration_metrics = self._calculate_calibration_metrics(
            max_confidences, test_labels, predictions
        )

        return {
            "figure": fig,
            "metrics": calibration_metrics,
            "confidence_distribution": max_confidences,
        }
    def _predict_image(self, image):
        """Helper function to get predictions."""
        from predictor import predict_image

        return predict_image(image, self.model, self.processor, top_k=5)
    def _create_calibration_visualization(
        self, confidences, topk_probs, true_labels, predictions, n_bins
    ):
        """Create calibration visualization."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        # Confidence distribution
        ax1.hist(confidences, bins=20, alpha=0.7, color="lightblue", edgecolor="black")
        ax1.set_xlabel("Confidence Score")
        ax1.set_ylabel("Frequency")
        ax1.set_title("Distribution of Confidence Scores", fontweight="bold")
        ax1.axvline(
            np.mean(confidences),
            color="red",
            linestyle="--",
            label=f"Mean: {np.mean(confidences):.3f}",
        )
        ax1.legend()
        ax1.grid(alpha=0.3)

        # Reliability diagram (if true labels available)
        if true_labels is not None:
            # Convert to binary correctness
            correct = np.array([pred == true for pred, true in zip(predictions, true_labels)])

            fraction_of_positives, mean_predicted_prob = calibration_curve(
                correct, confidences, n_bins=n_bins, strategy="uniform"
            )

            ax2.plot(mean_predicted_prob, fraction_of_positives, "s-", label="Model")
            ax2.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
            ax2.set_xlabel("Mean Predicted Probability")
            ax2.set_ylabel("Fraction of Positives")
            ax2.set_title("Reliability Diagram", fontweight="bold")
            ax2.legend()
            ax2.grid(alpha=0.3)

            # Calculate ECE: the bin-size-weighted gap between average
            # confidence and accuracy within each confidence bin
            bin_edges = np.linspace(0, 1, n_bins + 1)
            bin_indices = np.digitize(confidences, bin_edges) - 1
            bin_indices = np.clip(bin_indices, 0, n_bins - 1)

            ece = 0
            for bin_idx in range(n_bins):
                mask = bin_indices == bin_idx
                if np.sum(mask) > 0:
                    bin_conf = np.mean(confidences[mask])
                    bin_acc = np.mean(correct[mask])
                    ece += (np.sum(mask) / len(confidences)) * np.abs(bin_acc - bin_conf)

            ax2.text(
                0.1,
                0.9,
                f"ECE: {ece:.3f}",
                transform=ax2.transAxes,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7),
            )

        # Confidence vs accuracy (if true labels available)
        if true_labels is not None:
            confidence_bins = np.linspace(0, 1, n_bins + 1)
            bin_accuracies = []
            bin_confidences = []

            for i in range(n_bins):
                mask = (confidences >= confidence_bins[i]) & (confidences < confidence_bins[i + 1])
                if np.sum(mask) > 0:
                    bin_acc = np.mean(correct[mask])
                    bin_conf = np.mean(confidences[mask])
                    bin_accuracies.append(bin_acc)
                    bin_confidences.append(bin_conf)

            ax3.plot(bin_confidences, bin_accuracies, "o-", label="Model")
            ax3.plot([0, 1], [0, 1], "k--", label="Ideal")
            ax3.set_xlabel("Average Confidence")
            ax3.set_ylabel("Average Accuracy")
            ax3.set_title("Confidence vs Accuracy", fontweight="bold")
            ax3.legend()
            ax3.grid(alpha=0.3)

        # Top-1 vs rest-of-top-5 confidence gap (uses the full top-k distributions)
        if len(topk_probs) > 0 and len(topk_probs[0]) >= 5:
            top1_conf = [p[0] for p in topk_probs]
            top5_conf = [np.sum(p[:5]) for p in topk_probs]
            # Gap between the top-1 probability and the mean of ranks 2-5
            confidence_gap = [t1 - (t5 - t1) / 4 for t1, t5 in zip(top1_conf, top5_conf)]

            ax4.hist(confidence_gap, bins=20, alpha=0.7, color="lightgreen", edgecolor="black")
            ax4.set_xlabel("Confidence Gap (Top-1 vs Rest)")
            ax4.set_ylabel("Frequency")
            ax4.set_title("Distribution of Confidence Gaps", fontweight="bold")
            ax4.grid(alpha=0.3)

        plt.tight_layout()
        return fig
    def _calculate_calibration_metrics(self, confidences, true_labels, predictions):
        """Calculate calibration metrics."""
        metrics = {
            "mean_confidence": float(np.mean(confidences)),
            "confidence_std": float(np.std(confidences)),
            "overconfident_rate": float(np.mean(confidences > 0.8)),
            "underconfident_rate": float(np.mean(confidences < 0.2)),
        }

        if true_labels is not None:
            correct = np.array([pred == true for pred, true in zip(predictions, true_labels)])
            accuracy = np.mean(correct)
            avg_confidence = np.mean(confidences)

            metrics.update(
                {
                    "accuracy": float(accuracy),
                    "confidence_gap": float(avg_confidence - accuracy),
                    "brier_score": float(brier_score_loss(correct, confidences)),
                }
            )

        return metrics
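

# --- Illustrative usage: calibration audit (sketch only, never called) -------
# Assumes `test_images` is a list of PIL Images and `test_labels` holds the
# matching ground-truth label strings in the same format that
# `predictor.predict_image` returns. A positive `confidence_gap` (mean
# confidence above accuracy) or a large ECE on the reliability diagram points
# to overconfidence.
def _example_calibration_audit(model, processor, test_images, test_labels=None):
    """Run the calibration audit and print the summary metrics."""
    analyzer = ConfidenceCalibrationAnalyzer(model, processor)
    result = analyzer.analyze_calibration(test_images, test_labels, n_bins=10)
    for name, value in result["metrics"].items():
        print(f"{name}: {value:.3f}")
    return result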

class BiasDetector:
    """Detect potential biases in model performance across subgroups."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device

    def analyze_subgroup_performance(self, image_subsets, subset_names, true_labels_subsets=None):
        """
        Analyze performance across different subgroups.

        Args:
            image_subsets: List of image subsets for each subgroup
            subset_names: Names for each subgroup
            true_labels_subsets: Optional true labels for each subset

        Returns:
            dict: Bias analysis results
        """
        subgroup_metrics = {}

        for i, (subset, name) in enumerate(zip(image_subsets, subset_names)):
            confidences = []
            predictions = []

            for img in subset:
                probs, indices, labels = self._predict_image(img)
                confidences.append(probs[0])
                predictions.append(labels[0])

            metrics = {
                "mean_confidence": np.mean(confidences),
                "confidence_std": np.std(confidences),
                "sample_size": len(subset),
            }

            # Calculate accuracy if true labels provided
            if true_labels_subsets is not None and i < len(true_labels_subsets):
                true_labels = true_labels_subsets[i]
                correct = [pred == true for pred, true in zip(predictions, true_labels)]
                metrics["accuracy"] = np.mean(correct)
                metrics["error_rate"] = 1 - metrics["accuracy"]

            subgroup_metrics[name] = metrics

        # Create bias analysis visualization
        fig = self._create_bias_visualization(subgroup_metrics, true_labels_subsets is not None)

        # Calculate fairness metrics
        fairness_metrics = self._calculate_fairness_metrics(subgroup_metrics)

        return {
            "figure": fig,
            "subgroup_metrics": subgroup_metrics,
            "fairness_metrics": fairness_metrics,
        }

    def _predict_image(self, image):
        """Helper function to get predictions."""
        from predictor import predict_image

        return predict_image(image, self.model, self.processor, top_k=5)

    def _create_bias_visualization(self, subgroup_metrics, has_accuracy):
        """Create visualization for bias analysis."""
        if has_accuracy:
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
        else:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        subgroups = list(subgroup_metrics.keys())

        # Confidence by subgroup
        confidences = [metrics["mean_confidence"] for metrics in subgroup_metrics.values()]
        ax1.bar(subgroups, confidences, color="lightblue", alpha=0.7)
        ax1.set_ylabel("Mean Confidence")
        ax1.set_title("Mean Confidence by Subgroup", fontweight="bold")
        ax1.tick_params(axis="x", rotation=45)
        ax1.grid(axis="y", alpha=0.3)

        # Add confidence values on bars
        for i, v in enumerate(confidences):
            ax1.text(i, v + 0.01, f"{v:.3f}", ha="center", va="bottom")

        # Sample sizes
        sample_sizes = [metrics["sample_size"] for metrics in subgroup_metrics.values()]
        ax2.bar(subgroups, sample_sizes, color="lightgreen", alpha=0.7)
        ax2.set_ylabel("Sample Size")
        ax2.set_title("Sample Size by Subgroup", fontweight="bold")
        ax2.tick_params(axis="x", rotation=45)
        ax2.grid(axis="y", alpha=0.3)

        # Add sample size values on bars
        for i, v in enumerate(sample_sizes):
            ax2.text(i, v + max(sample_sizes) * 0.01, f"{v}", ha="center", va="bottom")

        # Accuracy by subgroup (if available)
        if has_accuracy:
            accuracies = [metrics.get("accuracy", 0) for metrics in subgroup_metrics.values()]
            ax3.bar(subgroups, accuracies, color="lightcoral", alpha=0.7)
            ax3.set_ylabel("Accuracy")
            ax3.set_title("Accuracy by Subgroup", fontweight="bold")
            ax3.tick_params(axis="x", rotation=45)
            ax3.grid(axis="y", alpha=0.3)

            # Add accuracy values on bars
            for i, v in enumerate(accuracies):
                ax3.text(i, v + 0.01, f"{v:.3f}", ha="center", va="bottom")

        plt.tight_layout()
        return fig

    def _calculate_fairness_metrics(self, subgroup_metrics):
        """Calculate fairness metrics."""
        fairness_metrics = {}

        # Check if we have accuracy metrics
        has_accuracy = all("accuracy" in metrics for metrics in subgroup_metrics.values())

        if has_accuracy and len(subgroup_metrics) >= 2:
            accuracies = [metrics["accuracy"] for metrics in subgroup_metrics.values()]
            confidences = [metrics["mean_confidence"] for metrics in subgroup_metrics.values()]

            fairness_metrics = {
                "accuracy_range": float(max(accuracies) - min(accuracies)),
                "accuracy_std": float(np.std(accuracies)),
                "confidence_range": float(max(confidences) - min(confidences)),
                "max_accuracy_disparity": float(
                    max(accuracies) / min(accuracies) if min(accuracies) > 0 else float("inf")
                ),
            }

        return fairness_metrics
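

# --- Illustrative usage: subgroup bias audit (sketch only, never called) -----
# Assumes the caller has already split the evaluation set into subgroups (for
# example by lighting condition or capture device); the grouping itself is out
# of scope here, and the subgroup names below are placeholders.
def _example_bias_audit(model, processor, daylight_images, night_images):
    """Compare confidence and sample sizes across two hypothetical subgroups."""
    detector = BiasDetector(model, processor)
    result = detector.analyze_subgroup_performance(
        image_subsets=[daylight_images, night_images],
        subset_names=["daylight", "night"],
    )
    print(result["subgroup_metrics"])
    print(result["fairness_metrics"])  # empty unless true labels are supplied
    return result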

# Convenience function to create all auditors
def create_auditors(model, processor):
    """Create all auditing analyzers."""
    return {
        "counterfactual": CounterfactualAnalyzer(model, processor),
        "calibration": ConfidenceCalibrationAnalyzer(model, processor),
        "bias": BiasDetector(model, processor),
    }
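

# --- End-to-end sketch (illustrative; relies on assumptions) ------------------
# The checkpoint name and image path below are placeholders, and the
# (probs, indices, labels) output shape of the local `predictor.predict_image`
# helper is assumed; any image-classification model/processor pair that
# `predictor` supports should plug in the same way.
if __name__ == "__main__":
    from transformers import AutoImageProcessor, AutoModelForImageClassification

    checkpoint = "google/vit-base-patch16-224"
    model = AutoModelForImageClassification.from_pretrained(checkpoint)
    processor = AutoImageProcessor.from_pretrained(checkpoint)
    model.eval()

    auditors = create_auditors(model, processor)
    image = Image.open("example.jpg").convert("RGB").resize((224, 224))

    # Patch-level counterfactual audit on a single image
    cf = auditors["counterfactual"].patch_perturbation_analysis(image, perturbation_type="blur")
    cf["figure"].savefig("patch_sensitivity.png")

    # Calibration audit over a (here trivially small) set of images
    cal = auditors["calibration"].analyze_calibration([image])
    print(cal["metrics"])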