import gc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# --- Visualization library ---
# pip install grad-cam
from pytorch_grad_cam import GradCAM, GradCAMPlusPlus, LayerCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# --- Configuration ---
SEED = 4421
DROP_RATE = 0.1
LOCAL_CKPT_DIR = "./checkpoints"
DEFAULT_CKPT = "convformer_b36.ai_cls"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Seed all RNGs for reproducibility
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

CKPT_META = {
    "caformer_b36.old": {
        "num_classes": 4,
        "head": "v7",
        "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
        "repo_id": "telecomadm1145/swin-ai-detection",
        "filename": "caformer_b36_4class_96.safetensors",
        "labels": ["non_ai", "ai", "ani_non_ai", "ani_ai"],
        "input_size": 384,
    },
    "caformer_b36.v2": {
        "num_classes": 2,
        "head": "timm_cross_entropy",
        "backbone_timm_name": "hf-hub:animetimm/caformer_b36.dbv4-full",
        "repo_id": "telecomadm1145/danbooru-real-vs-ai-caformer-b36-v2",
        "filename": "pytorch_model.bin",
        "labels": ["AI", "Non-AI"],
        "input_size": 384,
    },
    "convformer_b36.ai_cls": {
        "num_classes": 2,
        "head": "timm_cross_entropy",
        "backbone_timm_name": "convformer_b36.sail_in1k",
        "repo_id": "telecomadm1145/convformer_b36.ai_cls",
        "filename": "best_checkpoint.pth",
        "labels": ["AI", "Non-AI"],
        "input_size": 224,
    },
    "deepghs/cls-ai-check-1m.caformer_s36.r512": {
        "num_classes": 2,
        "head": "timm_cross_entropy",
        "backbone_timm_name": "caformer_s36.sail_in22k_ft_in1k_384",
        "repo_id": "deepghs/cls-ai-check-1m.caformer_s36.r512",
        "filename": "model.safetensors",
        "labels": ["AI", "Non-AI"],
        "input_size": 512,
    },
}

# --- Model definition ---
class TimmClassifierWithHead(nn.Module):
    """Timm backbone with a small custom MLP classification head."""

    def __init__(self, model_name, num_classes, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.classifier = nn.Sequential(
            nn.Dropout(DROP_RATE),
            nn.Linear(self.backbone.num_features, 64),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Dropout(DROP_RATE * 0.8),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)


# --- Key utility: normalize Transformer/hybrid feature-map layouts ---
def reshape_transform(tensor):
    """
    Normalize intermediate outputs of newer timm models
    (CaFormer/ConvFormer/Swin) for the CAM library.

    If the tensor is [B, H, W, C] (channels-last), permute it to [B, C, H, W].
    """
    if tensor.ndim == 4:
        # Heuristic: the channel count is usually in the hundreds or
        # thousands while H and W are small, so if the last dim dominates,
        # assume a BHWC layout.
        h, w, c = tensor.shape[1], tensor.shape[2], tensor.shape[3]
        # Many timm blocks emit BHWC internally (permuted), e.g.
        # [1, 56, 56, 384] -> needs to become [1, 384, 56, 56].
        if c > h and c > w:
            return tensor.permute(0, 3, 1, 2)

    # 3D tensor (ViT-style output: [B, tokens, C]) -> restore a 2D map.
    if tensor.ndim == 3:
        # The true (h, w) of the feature map is unknown here; as a
        # simplification, assume it is square.
        b, num_tokens, c = tensor.shape
        # Strip the cls_token if present (heuristic: odd token count
        # implies one extra token; drop the first).
        if num_tokens % 2 != 0:
            tensor = tensor[:, 1:, :]
            num_tokens -= 1
        side = int(np.sqrt(num_tokens))
        return tensor.reshape(b, side, side, c).permute(0, 3, 1, 2)

    return tensor
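
# Quick shape sanity check for reshape_transform (a sketch; it is defined but
# never called at runtime). The dummy shapes are illustrative assumptions:
# a ConvFormer-style channels-last block output and a ViT-style token
# sequence with a cls_token. Both should come out as [B, C, H, W].
def _check_reshape_transform():
    bhwc = torch.randn(1, 7, 7, 512)    # channels-last [B, H, W, C]
    assert reshape_transform(bhwc).shape == (1, 512, 7, 7)
    tokens = torch.randn(1, 197, 768)   # ViT-style [B, 1 + 14*14, C]
    assert reshape_transform(tokens).shape == (1, 768, 14, 14)
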
# --- Core logic ---
class AIImageDetector:
    def __init__(self):
        self.model = None
        self.current_ckpt_name = None
        self.current_meta = None
        self.cam_engine = None  # replaces the previous grad_cam attribute

    def _get_target_layers(self, model):
        """Locate a layer suitable for CAM visualization."""
        backbone = getattr(model, 'backbone', model)

        # Prefer the last block of the last stage.
        if hasattr(backbone, 'stages'):
            # CaFormer / ConvFormer / ResNet-style: `stages` is a
            # ModuleList; take the last stage, then its last block if any.
            last_stage = backbone.stages[-1]
            if hasattr(last_stage, 'blocks'):
                return [last_stage.blocks[-1]]
            return [last_stage]

        # Swin-style models expose `layers` instead.
        if hasattr(backbone, 'layers'):
            return [backbone.layers[-1].blocks[-1]]

        # Fallback: the last top-level child module.
        return [list(backbone.children())[-1]]

    def load_model(self, ckpt_name):
        if ckpt_name == self.current_ckpt_name and self.model is not None:
            return
        print(f"Loading model: {ckpt_name}...")
        meta = CKPT_META[ckpt_name]

        # Release the previous model's GPU memory before loading a new one.
        if self.model is not None:
            del self.model
            del self.cam_engine
            torch.cuda.empty_cache()
            gc.collect()

        ckpt_file = hf_hub_download(repo_id=meta["repo_id"], filename=meta["filename"],
                                    local_dir=LOCAL_CKPT_DIR, force_download=False)

        if meta.get("head") == "timm_cross_entropy":
            model = timm.create_model(meta["backbone_timm_name"], pretrained=False,
                                      num_classes=meta["num_classes"])
        else:
            model = TimmClassifierWithHead(meta["backbone"],
                                           num_classes=meta["num_classes"],
                                           pretrained=False)
        model = model.to(DEVICE)

        # Checkpoint loading: unwrap common wrapper keys, then strip
        # DataParallel/module prefixes.
        if meta["filename"].endswith(".safetensors"):
            state_dict = load_file(ckpt_file, device=DEVICE)
        else:
            state_dict = torch.load(ckpt_file, map_location=DEVICE)
            if "model_state_dict" in state_dict:
                state_dict = state_dict["model_state_dict"]
            elif "state_dict" in state_dict:
                state_dict = state_dict["state_dict"]

        state_dict = {k.replace("module.", "").replace("model.", ""): v
                      for k, v in state_dict.items()}
        model.load_state_dict(state_dict, strict=False)
        model.eval()

        self.model = model
        self.current_ckpt_name = ckpt_name
        self.current_meta = meta
        self.cam_engine = None  # lazily (re)built on demand

    def preprocess(self, image: Image.Image):
        # Flatten transparency onto a white background, then force RGB.
        # (The static Image.alpha_composite returns a new image; the
        # in-place method returns None and cannot be chained.)
        if image.mode == "RGBA":
            background = Image.new("RGBA", image.size, (255, 255, 255, 255))
            image = Image.alpha_composite(background, image).convert("RGB")
        elif image.mode != "RGB":
            image = image.convert("RGB")

        meta = self.current_meta
        size = int(meta.get("input_size", 384))
        image_resized = image.resize((size, size), Image.BICUBIC)

        # Normalize and convert to a tensor.
        img_np = np.array(image_resized).astype(np.float32) / 255.0
        # Keep a copy for visualization (float in [0, 1], RGB).
        img_viz_np = img_np.copy()

        mean = np.array(meta.get("mean", IMAGENET_MEAN), dtype=np.float32).reshape(1, 1, 3)
        std = np.array(meta.get("std", IMAGENET_STD), dtype=np.float32).reshape(1, 1, 3)
        img_norm = (img_np - mean) / std

        img_tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        return img_tensor, img_viz_np

    def predict(self, image, ckpt_name, enable_viz, method_name):
        if image is None:
            return None, None

        self.load_model(ckpt_name)
        input_tensor, img_viz_np = self.preprocess(image)
        labels = self.current_meta["labels"]

        if enable_viz:
            target_layers = self._get_target_layers(self.model)

            # Select the CAM method. LayerCAM usually localizes AI artifacts
            # (texture anomalies) better than plain GradCAM.
            cam_classes = {"LayerCAM": LayerCAM, "GradCAM": GradCAM,
                           "GradCAM++": GradCAMPlusPlus}
            cam_cls = cam_classes.get(method_name, GradCAMPlusPlus)

            # Build the CAM engine. Note: reshape_transform is essential for
            # CaFormer/ConvFormer, whose blocks emit channels-last tensors.
            cam = cam_cls(model=self.model, target_layers=target_layers,
                          reshape_transform=reshape_transform)

            # Use the predicted class as the CAM target.
            with torch.no_grad():
                logits = self.model(input_tensor)
                probs = F.softmax(logits, dim=1)
                predicted_class = torch.argmax(probs, dim=1).item()

            # Generate the heatmap.
            targets = [ClassifierOutputTarget(predicted_class)]
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
            grayscale_cam = grayscale_cam[0, :]

            # Overlay the heatmap on the input image.
            visualization = show_cam_on_image(img_viz_np, grayscale_cam, use_rgb=True)
            viz_image = Image.fromarray(visualization)

            result_dict = {labels[i]: float(probs.cpu().numpy()[0][i])
                           for i in range(len(labels))}

            # Hook cleanup (to avoid GPU-memory leaks): older pytorch-grad-cam
            # versions may need a manual cam.del_hooks(); newer releases
            # usually manage hooks automatically (see the context-manager
            # sketch after this class).
            return result_dict, viz_image
        else:
            with torch.no_grad():
                logits = self.model(input_tensor)
                probs = F.softmax(logits, dim=1).cpu().numpy()[0]
            result_dict = {labels[i]: float(probs[i]) for i in range(len(labels))}
            return result_dict, None
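
# Sketch: standalone CAM extraction with deterministic hook cleanup. Recent
# pytorch-grad-cam releases implement the context-manager protocol, so
# wrapping the engine in `with` releases its forward/backward hooks when the
# block exits instead of relying on garbage collection. `image_path` is a
# hypothetical placeholder; everything else reuses the definitions above.
def extract_cam_standalone(image_path, ckpt_name=DEFAULT_CKPT, class_idx=0):
    det = AIImageDetector()
    det.load_model(ckpt_name)
    input_tensor, img_viz_np = det.preprocess(Image.open(image_path))
    target_layers = det._get_target_layers(det.model)
    with LayerCAM(model=det.model, target_layers=target_layers,
                  reshape_transform=reshape_transform) as cam:
        grayscale_cam = cam(input_tensor=input_tensor,
                            targets=[ClassifierOutputTarget(class_idx)])[0, :]
    # Hooks are released here; only the numpy heatmap survives.
    return Image.fromarray(show_cam_on_image(img_viz_np, grayscale_cam, use_rgb=True))
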
detector = AIImageDetector()


# --- Gradio UI ---
def gr_predict(image, ckpt_name, enable_viz, method):
    return detector.predict(image, ckpt_name, enable_viz, method)


def launch():
    detector.load_model(DEFAULT_CKPT)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🕵️ AI Generated Image Detection & Explainability")
        with gr.Row():
            with gr.Column(scale=1):
                in_img = gr.Image(type="pil", label="Input Image")
                sel_ckpt = gr.Dropdown(list(CKPT_META.keys()), value=DEFAULT_CKPT,
                                       label="Model Checkpoint")
                with gr.Group():
                    enable_viz = gr.Checkbox(label="Visualize Heatmap", value=True)
                    # Selector for the CAM backend
                    viz_method = gr.Dropdown(["LayerCAM", "GradCAM++", "GradCAM"],
                                             value="LayerCAM",
                                             label="Method (LayerCAM recommended for texture)")
                run_btn = gr.Button("Analyze", variant="primary")
            with gr.Column(scale=1):
                out_lbl = gr.Label(num_top_classes=4, label="Prediction Confidence")
                out_viz = gr.Image(type="pil", label="Attention / Heatmap")

        run_btn.click(
            fn=gr_predict,
            inputs=[in_img, sel_ckpt, enable_viz, viz_method],
            outputs=[out_lbl, out_viz],
        )

    demo.launch()


if __name__ == "__main__":
    launch()
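
# Headless usage sketch (no UI): the same pipeline can be driven directly
# from Python. "sample.jpg" is a hypothetical path; any RGB/RGBA image works.
# Equivalent to pressing "Analyze" in the Gradio interface:
#
#     scores, heatmap = detector.predict(Image.open("sample.jpg"), DEFAULT_CKPT,
#                                        enable_viz=True, method_name="LayerCAM")
#     print(scores)                  # label -> probability mapping
#     if heatmap is not None:
#         heatmap.save("heatmap.png")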