# src/auditor.py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw, ImageFilter
from scipy import stats
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss


class CounterfactualAnalyzer:
    """Analyze how predictions change with image perturbations."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device

    def patch_perturbation_analysis(self, image, patch_size=16, perturbation_type="blur"):
        """
        Analyze how predictions change when different patches are perturbed.

        Args:
            image: PIL Image
            patch_size: Size of patches to perturb
            perturbation_type: Type of perturbation ('blur', 'noise', 'blackout', 'gray')

        Returns:
            dict: Analysis results with visualizations
        """
        original_probs, _, original_labels = self._predict_image(image)
        original_top_label = original_labels[0]
        original_confidence = original_probs[0]

        # Get image dimensions
        width, height = image.size

        # Create grid of patches
        patches_x = width // patch_size
        patches_y = height // patch_size

        # Store results
        confidence_changes = []
        prediction_changes = []
        patch_heatmap = np.zeros((patches_y, patches_x))

        for i in range(patches_y):
            for j in range(patches_x):
                # Create perturbed image
                perturbed_img = self._perturb_patch(
                    image.copy(), j, i, patch_size, perturbation_type
                )

                # Get prediction on perturbed image
                perturbed_probs, _, perturbed_labels = self._predict_image(perturbed_img)
                perturbed_confidence = perturbed_probs[0]
                perturbed_label = perturbed_labels[0]

                # Calculate changes
                confidence_change = perturbed_confidence - original_confidence
                prediction_change = 1 if perturbed_label != original_top_label else 0

                confidence_changes.append(confidence_change)
                prediction_changes.append(prediction_change)
                patch_heatmap[i, j] = confidence_change

        # Create visualization
        fig = self._create_counterfactual_visualization(
            image,
            patch_heatmap,
            patch_size,
            original_top_label,
            original_confidence,
            confidence_changes,
            prediction_changes,
        )

        return {
            "figure": fig,
            "patch_heatmap": patch_heatmap,
            "avg_confidence_change": np.mean(confidence_changes),
            "prediction_flip_rate": np.mean(prediction_changes),
            "most_sensitive_patch": np.unravel_index(
                np.argmin(patch_heatmap), patch_heatmap.shape
            ),
        }

    def _perturb_patch(self, image, patch_x, patch_y, patch_size, perturbation_type):
        """Apply perturbation to a specific patch."""
        left = patch_x * patch_size
        upper = patch_y * patch_size
        right = left + patch_size
        lower = upper + patch_size
        patch_box = (left, upper, right, lower)

        if perturbation_type == "blur":
            # Extract patch, blur it, and paste back
            patch = image.crop(patch_box)
            blurred_patch = patch.filter(ImageFilter.GaussianBlur(5))
            image.paste(blurred_patch, patch_box)
        elif perturbation_type == "blackout":
            # Black out the patch
            draw = ImageDraw.Draw(image)
            draw.rectangle(patch_box, fill="black")
        elif perturbation_type == "gray":
            # Convert patch to grayscale
            patch = image.crop(patch_box)
            gray_patch = patch.convert("L").convert("RGB")
            image.paste(gray_patch, patch_box)
        elif perturbation_type == "noise":
            # Add Gaussian noise to the patch; work in a wider dtype so negative
            # noise and values above 255 clip correctly instead of wrapping around.
            patch = np.array(image.crop(patch_box)).astype(np.int16)
            noise = np.random.normal(0, 50, patch.shape)
            noisy_patch = np.clip(patch + noise, 0, 255).astype(np.uint8)
            image.paste(Image.fromarray(noisy_patch), patch_box)

        return image

    def _predict_image(self, image):
        """Helper function to get predictions."""
        from predictor import predict_image

        return predict_image(image, self.model, self.processor, top_k=5)

    def _create_counterfactual_visualization(
        self,
        image,
        patch_heatmap,
        patch_size,
        original_label,
        original_confidence,
        confidence_changes,
        prediction_changes,
    ):
        """Create visualization for counterfactual analysis."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        # Original image
        ax1.imshow(image)
        ax1.set_title(
            f"Original Image\nPrediction: {original_label} ({original_confidence:.2%})",
            fontweight="bold",
        )
        ax1.axis("off")

        # Patch sensitivity heatmap
        im = ax2.imshow(patch_heatmap, cmap="RdBu_r", vmin=-0.5, vmax=0.5)
        ax2.set_title(
            "Patch Sensitivity Heatmap\n(Confidence Change When Perturbed)", fontweight="bold"
        )
        ax2.set_xlabel("Patch X")
        ax2.set_ylabel("Patch Y")
        plt.colorbar(im, ax=ax2, label="Confidence Change")

        # Add patch grid to original image
        width, height = image.size
        for i in range(patch_heatmap.shape[0]):
            for j in range(patch_heatmap.shape[1]):
                rect = plt.Rectangle(
                    (j * patch_size, i * patch_size),
                    patch_size,
                    patch_size,
                    linewidth=1,
                    edgecolor="red",
                    facecolor="none",
                    alpha=0.3,
                )
                ax1.add_patch(rect)

        # Confidence change distribution
        ax3.hist(confidence_changes, bins=20, alpha=0.7, color="skyblue")
        ax3.axvline(0, color="red", linestyle="--", label="No Change")
        ax3.set_xlabel("Confidence Change")
        ax3.set_ylabel("Frequency")
        ax3.set_title("Distribution of Confidence Changes", fontweight="bold")
        ax3.legend()
        ax3.grid(alpha=0.3)

        # Prediction flip analysis
        flip_rate = np.mean(prediction_changes)
        ax4.bar(["No Flip", "Flip"], [1 - flip_rate, flip_rate], color=["green", "red"])
        ax4.set_ylabel("Proportion")
        ax4.set_title(f"Prediction Flip Rate: {flip_rate:.2%}", fontweight="bold")
        ax4.grid(alpha=0.3)

        plt.tight_layout()
        return fig

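# Illustrative usage (a sketch, not part of the module's API). It assumes `model` and
# `processor` are whatever objects predictor.predict_image expects, and "example.jpg"
# is a placeholder path:
#
#   analyzer = CounterfactualAnalyzer(model, processor)
#   result = analyzer.patch_perturbation_analysis(
#       Image.open("example.jpg").convert("RGB"),
#       patch_size=32,
#       perturbation_type="blackout",
#   )
#   result["figure"].savefig("patch_sensitivity.png")
#   print("Flip rate:", result["prediction_flip_rate"])
#   print("Most sensitive patch (row, col):", result["most_sensitive_patch"])
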
class ConfidenceCalibrationAnalyzer:
    """Analyze model calibration and confidence metrics."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device

    def analyze_calibration(self, test_images, test_labels=None, n_bins=10):
        """
        Analyze model calibration using confidence scores.

        Args:
            test_images: List of PIL Images for testing
            test_labels: Optional true labels for accuracy calculation
            n_bins: Number of bins for calibration curve

        Returns:
            dict: Calibration analysis results
        """
        confidences = []
        predictions = []
        max_confidences = []

        # Get predictions and confidences
        for img in test_images:
            probs, indices, labels = self._predict_image(img)
            max_confidences.append(probs[0])
            predictions.append(labels[0])
            confidences.append(probs)

        max_confidences = np.array(max_confidences)

        # Create calibration analysis (pass the full top-k probability vectors so the
        # top-1 vs rest panel can be drawn as well)
        fig = self._create_calibration_visualization(
            max_confidences, test_labels, predictions, n_bins, all_probs=confidences
        )

        # Calculate calibration metrics
        calibration_metrics = self._calculate_calibration_metrics(
            max_confidences, test_labels, predictions
        )

        return {
            "figure": fig,
            "metrics": calibration_metrics,
            "confidence_distribution": max_confidences,
        }

    def _predict_image(self, image):
        """Helper function to get predictions."""
        from predictor import predict_image

        return predict_image(image, self.model, self.processor, top_k=5)

    def _create_calibration_visualization(
        self, confidences, true_labels, predictions, n_bins, all_probs=None
    ):
        """Create calibration visualization."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

        # Confidence distribution
        ax1.hist(confidences, bins=20, alpha=0.7, color="lightblue", edgecolor="black")
        ax1.set_xlabel("Confidence Score")
        ax1.set_ylabel("Frequency")
        ax1.set_title("Distribution of Confidence Scores", fontweight="bold")
        ax1.axvline(
            np.mean(confidences),
            color="red",
            linestyle="--",
            label=f"Mean: {np.mean(confidences):.3f}",
        )
        ax1.legend()
        ax1.grid(alpha=0.3)

        # Reliability diagram (if true labels available)
        if true_labels is not None:
            # Convert to binary correctness
            correct = np.array([pred == true for pred, true in zip(predictions, true_labels)])

            fraction_of_positives, mean_predicted_prob = calibration_curve(
                correct, confidences, n_bins=n_bins, strategy="uniform"
            )

            ax2.plot(mean_predicted_prob, fraction_of_positives, "s-", label="Model")
            ax2.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
            ax2.set_xlabel("Mean Predicted Probability")
            ax2.set_ylabel("Fraction of Positives")
            ax2.set_title("Reliability Diagram", fontweight="bold")
            ax2.legend()
            ax2.grid(alpha=0.3)

            # Calculate ECE
            bin_edges = np.linspace(0, 1, n_bins + 1)
            bin_indices = np.digitize(confidences, bin_edges) - 1
            bin_indices = np.clip(bin_indices, 0, n_bins - 1)

            ece = 0
            for bin_idx in range(n_bins):
                mask = bin_indices == bin_idx
                if np.sum(mask) > 0:
                    bin_conf = np.mean(confidences[mask])
                    bin_acc = np.mean(correct[mask])
                    ece += (np.sum(mask) / len(confidences)) * np.abs(bin_acc - bin_conf)

            ax2.text(
                0.1,
                0.9,
                f"ECE: {ece:.3f}",
                transform=ax2.transAxes,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7),
            )

        # Confidence vs accuracy (if true labels available)
        if true_labels is not None:
            confidence_bins = np.linspace(0, 1, n_bins + 1)
            bin_accuracies = []
            bin_confidences = []

            for i in range(n_bins):
                mask = (confidences >= confidence_bins[i]) & (
                    confidences < confidence_bins[i + 1]
                )
                if np.sum(mask) > 0:
                    bin_acc = np.mean(correct[mask])
                    bin_conf = np.mean(confidences[mask])
                    bin_accuracies.append(bin_acc)
                    bin_confidences.append(bin_conf)

            ax3.plot(bin_confidences, bin_accuracies, "o-", label="Model")
            ax3.plot([0, 1], [0, 1], "k--", label="Ideal")
            ax3.set_xlabel("Average Confidence")
            ax3.set_ylabel("Average Accuracy")
            ax3.set_title("Confidence vs Accuracy", fontweight="bold")
            ax3.legend()
            ax3.grid(alpha=0.3)

        # Top-1 vs Top-5 confidence gap (requires the full top-k probability vectors)
        if all_probs is not None and len(all_probs) > 0:
            top1_conf = [c[0] for c in all_probs]
            top5_conf = [np.sum(c[:5]) for c in all_probs]
            # Gap between the top-1 probability and the mean of the remaining top-5 mass
            confidence_gap = [t1 - (t5 - t1) / 4 for t1, t5 in zip(top1_conf, top5_conf)]

            ax4.hist(confidence_gap, bins=20, alpha=0.7, color="lightgreen", edgecolor="black")
            ax4.set_xlabel("Confidence Gap (Top-1 vs Rest)")
            ax4.set_ylabel("Frequency")
            ax4.set_title("Distribution of Confidence Gaps", fontweight="bold")
            ax4.grid(alpha=0.3)

        plt.tight_layout()
        return fig

    def _calculate_calibration_metrics(self, confidences, true_labels, predictions):
        """Calculate calibration metrics."""
        metrics = {
            "mean_confidence": float(np.mean(confidences)),
            "confidence_std": float(np.std(confidences)),
            "overconfident_rate": float(np.mean(confidences > 0.8)),
            "underconfident_rate": float(np.mean(confidences < 0.2)),
        }

        if true_labels is not None:
            correct = np.array([pred == true for pred, true in zip(predictions, true_labels)])
            accuracy = np.mean(correct)
            avg_confidence = np.mean(confidences)

            metrics.update(
                {
                    "accuracy": float(accuracy),
                    "confidence_gap": float(avg_confidence - accuracy),
                    "brier_score": float(brier_score_loss(correct, confidences)),
                }
            )

        return metrics

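# Note on the calibration numbers above: the ECE annotation on the reliability diagram
# is the standard expected calibration error over M equal-width confidence bins,
#
#     ECE = sum_m (|B_m| / N) * |acc(B_m) - conf(B_m)|
#
# i.e. the sample-weighted mean absolute gap between per-bin accuracy and per-bin mean
# confidence. Likewise, a positive `confidence_gap` metric (mean confidence minus
# accuracy) indicates overconfidence, a negative one underconfidence.
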
class BiasDetector:
    """Detect potential biases in model performance across subgroups."""

    def __init__(self, model, processor):
        self.model = model
        self.processor = processor
        self.device = next(model.parameters()).device

    def analyze_subgroup_performance(self, image_subsets, subset_names, true_labels_subsets=None):
        """
        Analyze performance across different subgroups.

        Args:
            image_subsets: List of image subsets for each subgroup
            subset_names: Names for each subgroup
            true_labels_subsets: Optional true labels for each subset

        Returns:
            dict: Bias analysis results
        """
        subgroup_metrics = {}

        for i, (subset, name) in enumerate(zip(image_subsets, subset_names)):
            confidences = []
            predictions = []

            for img in subset:
                probs, indices, labels = self._predict_image(img)
                confidences.append(probs[0])
                predictions.append(labels[0])

            metrics = {
                "mean_confidence": np.mean(confidences),
                "confidence_std": np.std(confidences),
                "sample_size": len(subset),
            }

            # Calculate accuracy if true labels provided
            if true_labels_subsets is not None and i < len(true_labels_subsets):
                true_labels = true_labels_subsets[i]
                correct = [pred == true for pred, true in zip(predictions, true_labels)]
                metrics["accuracy"] = np.mean(correct)
                metrics["error_rate"] = 1 - metrics["accuracy"]

            subgroup_metrics[name] = metrics

        # Create bias analysis visualization
        fig = self._create_bias_visualization(subgroup_metrics, true_labels_subsets is not None)

        # Calculate fairness metrics
        fairness_metrics = self._calculate_fairness_metrics(subgroup_metrics)

        return {
            "figure": fig,
            "subgroup_metrics": subgroup_metrics,
            "fairness_metrics": fairness_metrics,
        }

    def _predict_image(self, image):
        """Helper function to get predictions."""
        from predictor import predict_image

        return predict_image(image, self.model, self.processor, top_k=5)

    def _create_bias_visualization(self, subgroup_metrics, has_accuracy):
        """Create visualization for bias analysis."""
        if has_accuracy:
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))
        else:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

        subgroups = list(subgroup_metrics.keys())

        # Confidence by subgroup
        confidences = [metrics["mean_confidence"] for metrics in subgroup_metrics.values()]
        ax1.bar(subgroups, confidences, color="lightblue", alpha=0.7)
        ax1.set_ylabel("Mean Confidence")
        ax1.set_title("Mean Confidence by Subgroup", fontweight="bold")
        ax1.tick_params(axis="x", rotation=45)
        ax1.grid(axis="y", alpha=0.3)

        # Add confidence values on bars
        for i, v in enumerate(confidences):
            ax1.text(i, v + 0.01, f"{v:.3f}", ha="center", va="bottom")

        # Sample sizes
        sample_sizes = [metrics["sample_size"] for metrics in subgroup_metrics.values()]
        ax2.bar(subgroups, sample_sizes, color="lightgreen", alpha=0.7)
        ax2.set_ylabel("Sample Size")
        ax2.set_title("Sample Size by Subgroup", fontweight="bold")
        ax2.tick_params(axis="x", rotation=45)
        ax2.grid(axis="y", alpha=0.3)

        # Add sample size values on bars
        for i, v in enumerate(sample_sizes):
            ax2.text(i, v + max(sample_sizes) * 0.01, f"{v}", ha="center", va="bottom")

        # Accuracy by subgroup (if available)
        if has_accuracy:
            accuracies = [metrics.get("accuracy", 0) for metrics in subgroup_metrics.values()]
            ax3.bar(subgroups, accuracies, color="lightcoral", alpha=0.7)
            ax3.set_ylabel("Accuracy")
            ax3.set_title("Accuracy by Subgroup", fontweight="bold")
            ax3.tick_params(axis="x", rotation=45)
            ax3.grid(axis="y", alpha=0.3)

            # Add accuracy values on bars
            for i, v in enumerate(accuracies):
                ax3.text(i, v + 0.01, f"{v:.3f}", ha="center", va="bottom")

        plt.tight_layout()
        return fig

    def _calculate_fairness_metrics(self, subgroup_metrics):
        """Calculate fairness metrics."""
        fairness_metrics = {}

        # Check if we have accuracy metrics
        has_accuracy = all("accuracy" in metrics for metrics in subgroup_metrics.values())

        if has_accuracy and len(subgroup_metrics) >= 2:
            accuracies = [metrics["accuracy"] for metrics in subgroup_metrics.values()]
            confidences = [metrics["mean_confidence"] for metrics in subgroup_metrics.values()]

            fairness_metrics = {
                "accuracy_range": float(max(accuracies) - min(accuracies)),
                "accuracy_std": float(np.std(accuracies)),
                "confidence_range": float(max(confidences) - min(confidences)),
                "max_accuracy_disparity": float(
                    max(accuracies) / min(accuracies) if min(accuracies) > 0 else float("inf")
                ),
            }

        return fairness_metrics

# Convenience function to create all auditors
def create_auditors(model, processor):
    """Create all auditing analyzers."""
    return {
        "counterfactual": CounterfactualAnalyzer(model, processor),
        "calibration": ConfidenceCalibrationAnalyzer(model, processor),
        "bias": BiasDetector(model, processor),
    }
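
# Minimal end-to-end sketch (illustrative only). It assumes a generic Hugging Face
# image classifier; the checkpoint name below and the assumption that
# predictor.predict_image accepts this model/processor pair are examples, not part of
# this module. Adapt them to whatever predictor.py actually loads.
if __name__ == "__main__":
    from transformers import AutoImageProcessor, AutoModelForImageClassification

    checkpoint = "google/vit-base-patch16-224"  # placeholder checkpoint
    demo_model = AutoModelForImageClassification.from_pretrained(checkpoint)
    demo_processor = AutoImageProcessor.from_pretrained(checkpoint)
    auditors = create_auditors(demo_model, demo_processor)

    # A synthetic image is enough to exercise the pipeline end to end.
    demo_image = Image.new("RGB", (224, 224), color="gray")

    cf = auditors["counterfactual"].patch_perturbation_analysis(
        demo_image, patch_size=32, perturbation_type="blur"
    )
    print("Average confidence change:", cf["avg_confidence_change"])

    calib = auditors["calibration"].analyze_calibration([demo_image] * 4)
    print("Calibration metrics:", calib["metrics"])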