Spaces:
Sleeping
Sleeping
| import os | |
| import gc | |
| import cv2 | |
| import torch | |
| import timm | |
| import numpy as np | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| # --- 引入新的可视化库 --- | |
| # pip install grad-cam | |
| from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad, LayerCAM | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
# --- Configuration ---
SEED = 4421
DROP_RATE = 0.1
LOCAL_CKPT_DIR = "./checkpoints"
DEFAULT_CKPT = "convformer_b36.ai_cls"
# Use CUDA when it is available, otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ImageNet normalization statistics (RGB order); used whenever a CKPT_META
# entry does not supply its own "mean"/"std".
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _seed_everything(seed):
    """Seed torch, NumPy and (when present) every CUDA device with ``seed``."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Seed once at import time for reproducibility.
_seed_everything(SEED)
# Registry of selectable checkpoints, keyed by the name shown in the UI.
# Fields read by the detector:
#   num_classes        - number of classifier outputs
#   head               - "timm_cross_entropy": a plain timm model built from
#                        "backbone_timm_name"; any other value: a
#                        TimmClassifierWithHead built on "backbone"
#   repo_id, filename  - Hugging Face Hub location of the weights
#   labels             - class names, indexed by logit position
#   input_size         - square side length images are resized to
# Optional "mean"/"std" entries override the ImageNet normalization.
CKPT_META = {
    "caformer_b36.old": dict(
        num_classes=4,
        head="v7",
        backbone="caformer_b36.sail_in22k_ft_in1k_384",
        repo_id="telecomadm1145/swin-ai-detection",
        filename="caformer_b36_4class_96.safetensors",
        labels=["non_ai", "ai", "ani_non_ai", "ani_ai"],
        input_size=384,
    ),
    "caformer_b36.v2": dict(
        num_classes=2,
        head="timm_cross_entropy",
        backbone_timm_name="hf-hub:animetimm/caformer_b36.dbv4-full",
        repo_id="telecomadm1145/danbooru-real-vs-ai-caformer-b36-v2",
        filename="pytorch_model.bin",
        labels=["AI", "Non-AI"],
        input_size=384,
    ),
    "convformer_b36.ai_cls": dict(
        num_classes=2,
        head="timm_cross_entropy",
        backbone_timm_name="convformer_b36.sail_in1k",
        repo_id="telecomadm1145/convformer_b36.ai_cls",
        filename="best_checkpoint.pth",
        labels=["AI", "Non-AI"],
        input_size=224,
    ),
    "deepghs/cls-ai-check-1m.caformer_s36.r512": dict(
        num_classes=2,
        head="timm_cross_entropy",
        backbone_timm_name="caformer_s36.sail_in22k_ft_in1k_384",
        repo_id="deepghs/cls-ai-check-1m.caformer_s36.r512",
        filename="model.safetensors",
        labels=["AI", "Non-AI"],
        input_size=512,
    ),
}
# --- Model definition ---
class TimmClassifierWithHead(nn.Module):
    """A timm backbone followed by a small MLP classification head.

    The backbone is created with ``num_classes=0`` and its
    ``num_features`` attribute sizes the head input. The head maps those
    features through a 64-unit bottleneck (Linear -> BatchNorm1d -> GELU) to
    ``num_classes`` logits, with dropout in front of each Linear layer.
    """

    def __init__(self, model_name, num_classes, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        bottleneck = 64
        # Layer order matters: checkpoint keys are "classifier.<index>.*".
        head_layers = [
            nn.Dropout(DROP_RATE),
            nn.Linear(self.backbone.num_features, bottleneck),
            nn.BatchNorm1d(bottleneck),
            nn.GELU(),
            nn.Dropout(DROP_RATE * 0.8),
            nn.Linear(bottleneck, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return classification logits for the image batch ``x``."""
        return self.classifier(self.backbone(x))
# --- Key helper: normalize Transformer/hybrid activation layouts for CAM ---
def _square_side(n):
    """Return ``r`` such that ``r * r == n``, or ``None`` if ``n`` is not a perfect square."""
    if n < 0:
        return None
    root = int(round(n ** 0.5))
    return root if root * root == n else None


def reshape_transform(tensor):
    """Convert a hooked activation into the ``[B, C, H, W]`` layout CAM expects.

    Handles the layouts produced by timm backbones (CaFormer / ConvFormer /
    Swin / ViT-style):

    * 4D tensors are treated as channels-last ``[B, H, W, C]`` and permuted to
      ``[B, C, H, W]`` when the last dimension is larger than both of the two
      before it. This is a heuristic: channel counts are usually much larger
      than spatial sizes. Otherwise the tensor is returned unchanged.
    * 3D token sequences ``[B, N, C]`` are folded back into a square grid. If
      ``N`` is a perfect square, every token is used. If ``N - 1`` is one, the
      first token is treated as a class token and dropped.
    * Tensors of any other rank are returned unchanged.

    Raises:
        ValueError: if a 3D input's token count is neither a perfect square
            nor a perfect square plus one.
    """
    if tensor.ndim == 4:
        h, w, c = tensor.shape[1], tensor.shape[2], tensor.shape[3]
        # e.g. [1, 7, 7, 768] (BHWC) -> [1, 768, 7, 7]; [1, 768, 7, 7] stays as-is.
        if c > h and c > w:
            return tensor.permute(0, 3, 1, 2)
        return tensor

    if tensor.ndim == 3:
        b, num_tokens, c = tensor.shape
        side = _square_side(num_tokens)
        if side is None:
            # Not a square grid on its own: assume one leading class token.
            side = _square_side(num_tokens - 1)
            if side is None:
                raise ValueError(
                    f"reshape_transform: cannot fold {num_tokens} tokens into a square grid"
                )
            tensor = tensor[:, 1:, :]
        return tensor.reshape(b, side, side, c).permute(0, 3, 1, 2)

    return tensor
# --- Core detector logic ---
class AIImageDetector:
    """Classify images as AI-generated or not using a checkpoint from CKPT_META.

    Only one model is held at a time. ``load_model`` returns early when the
    requested checkpoint is already active; otherwise it releases the current
    model before loading the new one. ``predict`` can optionally render a CAM
    heatmap for the predicted class.
    """

    def __init__(self):
        self.model = None  # active torch model, or None
        self.current_ckpt_name = None  # CKPT_META key of the active model
        self.current_meta = None  # CKPT_META entry of the active model
        self.cam_engine = None  # placeholder; CAM objects are built per predict() call

    def _get_target_layers(self, model):
        """Return the layer(s) whose activations the CAM methods should use.

        Tries, in order: the last block of the last stage
        (``stages[-1].blocks[-1]``), the last stage itself, Swin-style
        ``layers[-1].blocks[-1]``, and finally the backbone's last child module.
        """
        backbone = getattr(model, "backbone", model)
        if hasattr(backbone, "stages"):
            # CaFormer / ConvFormer: a sequence of stages, each with blocks.
            last_stage = backbone.stages[-1]
            if hasattr(last_stage, "blocks"):
                return [last_stage.blocks[-1]]
            return [last_stage]
        if hasattr(backbone, "layers"):
            # Swin-style backbones.
            return [backbone.layers[-1].blocks[-1]]
        # Fallback.
        return [list(backbone.children())[-1]]

    @staticmethod
    def _strip_state_dict_prefixes(state_dict, prefixes=("module.", "model.")):
        """Return a copy of ``state_dict`` with wrapper prefixes removed from its keys.

        Prefixes are stripped repeatedly, but only from the START of each key.
        For example, ``"module.model.backbone.w"`` becomes ``"backbone.w"``,
        while ``"stem.submodule.w"`` is left intact.
        """
        def strip(key):
            changed = True
            while changed:
                changed = False
                for prefix in prefixes:
                    if prefix and key.startswith(prefix):
                        key = key[len(prefix):]
                        changed = True
            return key

        return {strip(k): v for k, v in state_dict.items()}

    @staticmethod
    def _read_state_dict(ckpt_file, filename):
        """Load weights from ``ckpt_file`` and unwrap a wrapping checkpoint dict."""
        if filename.endswith(".safetensors"):
            state_dict = load_file(ckpt_file, device=DEVICE)
        else:
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only point CKPT_META at repositories you trust.
            state_dict = torch.load(ckpt_file, map_location=DEVICE)
        # Unwrap checkpoints that store the weights under a container key.
        if "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]
        elif "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        return state_dict

    def _release_model(self):
        """Drop references to the active model and free cached GPU memory.

        Attributes are reset to None rather than deleted. That way a failure
        while loading the next checkpoint leaves the object usable, and the
        next ``load_model`` call simply retries.
        """
        self.model = None
        self.cam_engine = None
        self.current_ckpt_name = None
        self.current_meta = None
        # Collect first so unreachable tensors are freed before the CUDA cache is emptied.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load_model(self, ckpt_name):
        """Make ``ckpt_name`` (a CKPT_META key) the active model on DEVICE.

        Does nothing if that checkpoint is already loaded. Otherwise it releases
        the current model and downloads the weights from the Hugging Face Hub
        into LOCAL_CKPT_DIR (``force_download=False``). It then builds the
        architecture described by the metadata, loads the weights
        non-strictly (printing a warning on key mismatches) and switches the
        model to eval mode.

        Raises:
            KeyError: if ``ckpt_name`` is not a key of CKPT_META.
        """
        if ckpt_name == self.current_ckpt_name and self.model is not None:
            return
        print(f"Loading model: {ckpt_name}...")
        meta = CKPT_META[ckpt_name]

        if self.model is not None:
            self._release_model()

        ckpt_file = hf_hub_download(
            repo_id=meta["repo_id"],
            filename=meta["filename"],
            local_dir=LOCAL_CKPT_DIR,
            force_download=False,
        )

        if meta.get("head") == "timm_cross_entropy":
            model = timm.create_model(meta["backbone_timm_name"], pretrained=False, num_classes=meta["num_classes"])
        else:
            model = TimmClassifierWithHead(meta["backbone"], num_classes=meta["num_classes"], pretrained=False)
        model = model.to(DEVICE)

        state_dict = self._strip_state_dict_prefixes(self._read_state_dict(ckpt_file, meta["filename"]))
        # strict=False tolerates key mismatches. Report them so a partially
        # loaded model (e.g. a randomly initialized head) does not go unnoticed.
        incompatible = model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                f"Warning: {ckpt_name}: {len(incompatible.missing_keys)} missing and "
                f"{len(incompatible.unexpected_keys)} unexpected keys while loading weights"
            )
        model.eval()

        self.model = model
        self.current_ckpt_name = ckpt_name
        self.current_meta = meta
        self.cam_engine = None

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` in RGB mode, compositing any transparency onto white."""
        if image.mode == "RGB":
            return image
        has_alpha = image.mode in ("RGBA", "LA", "PA") or (
            image.mode == "P" and "transparency" in image.info
        )
        if not has_alpha:
            return image.convert("RGB")
        rgba = image.convert("RGBA")
        background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        # Module-level Image.alpha_composite returns a new image. The instance
        # method of the same name works in place and returns None.
        return Image.alpha_composite(background, rgba).convert("RGB")

    def preprocess(self, image: "Image.Image"):
        """Resize and normalize ``image`` for the active checkpoint.

        Returns:
            (img_tensor, img_viz_np):
            - img_tensor: a ``[1, 3, size, size]`` float32 tensor on DEVICE,
              normalized with the checkpoint's "mean"/"std" (ImageNet stats by
              default).
            - img_viz_np: the resized image as a ``[size, size, 3]`` float32
              RGB array in ``[0, 1]``, used for heatmap overlays.
        """
        image = self._to_rgb(image)
        meta = self.current_meta
        size = int(meta.get("input_size", 384))
        image_resized = image.resize((size, size), Image.BICUBIC)

        img_np = np.array(image_resized).astype(np.float32) / 255.0
        mean = np.array(meta.get("mean", IMAGENET_MEAN), dtype=np.float32).reshape(1, 1, 3)
        std = np.array(meta.get("std", IMAGENET_STD), dtype=np.float32).reshape(1, 1, 3)
        # (img_np - mean) allocates a new array, so img_np stays un-normalized for visualization.
        img_norm = (img_np - mean) / std
        img_tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        return img_tensor, img_np

    @staticmethod
    def _select_cam_class(method_name):
        """Map a UI method name to a pytorch-grad-cam class (GradCAM++ by default)."""
        if method_name == "GradCAM":
            return GradCAM
        if method_name == "LayerCAM":
            # LayerCAM tends to highlight fine texture artefacts better than GradCAM.
            return LayerCAM
        return GradCAMPlusPlus

    def predict(self, image, ckpt_name, enable_viz, method_name):
        """Classify ``image`` with checkpoint ``ckpt_name``.

        Returns:
            (result_dict, viz_image): ``result_dict`` maps each label to its
            softmax probability. ``viz_image`` is a PIL heatmap overlay for the
            predicted class, or None when ``enable_viz`` is falsy or the CAM
            step produced no map. Returns ``(None, None)`` when ``image`` is None.
        """
        if image is None:
            return None, None
        self.load_model(ckpt_name)
        input_tensor, img_viz_np = self.preprocess(image)
        labels = self.current_meta["labels"]

        with torch.no_grad():
            logits = self.model(input_tensor)
        probs = F.softmax(logits, dim=1)[0].cpu().numpy()
        result_dict = {labels[i]: float(probs[i]) for i in range(len(labels))}

        if not enable_viz:
            return result_dict, None

        predicted_class = int(np.argmax(probs))
        targets = [ClassifierOutputTarget(predicted_class)]
        cam_cls = self._select_cam_class(method_name)
        target_layers = self._get_target_layers(self.model)

        grayscale_cam = None
        # The context manager releases the hooks CAM registers on target_layers.
        # reshape_transform is required for channels-last / token-sequence
        # activations (CaFormer/ConvFormer/Swin).
        with cam_cls(model=self.model, target_layers=target_layers, reshape_transform=reshape_transform) as cam:
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        if grayscale_cam is None:
            # NOTE(review): the library's __exit__ may suppress some errors;
            # degrade to "no heatmap" instead of crashing.
            return result_dict, None

        visualization = show_cam_on_image(img_viz_np, grayscale_cam, use_rgb=True)
        return result_dict, Image.fromarray(visualization)
# Shared detector instance; it keeps whichever checkpoint was loaded last.
detector = AIImageDetector()


# --- Gradio interface ---
def gr_predict(image, ckpt_name, enable_viz, method):
    """Gradio callback: forward the UI inputs to the shared detector.

    Returns the ``(label -> probability dict, heatmap image or None)`` pair
    produced by ``AIImageDetector.predict``.
    """
    label_probs, heatmap = detector.predict(image, ckpt_name, enable_viz, method)
    return label_probs, heatmap
def launch():
    """Preload the default checkpoint, build the Gradio UI and start the server."""
    detector.load_model(DEFAULT_CKPT)

    checkpoint_names = list(CKPT_META.keys())
    cam_methods = ["LayerCAM", "GradCAM++", "GradCAM"]

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🕵️ AI Generated Image Detection & Explainability")
        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Input Image")
                ckpt_dropdown = gr.Dropdown(checkpoint_names, value=DEFAULT_CKPT, label="Model Checkpoint")
                with gr.Group():
                    viz_toggle = gr.Checkbox(label="Visualize Heatmap", value=True)
                    # CAM method selector.
                    method_dropdown = gr.Dropdown(
                        cam_methods,
                        value="LayerCAM",
                        label="Method (LayerCAM recommended for texture)",
                    )
                analyze_button = gr.Button("Analyze", variant="primary")
            # Right column: prediction outputs.
            with gr.Column(scale=1):
                label_output = gr.Label(num_top_classes=4, label="Prediction Confidence")
                heatmap_output = gr.Image(type="pil", label="Attention / Heatmap")

        analyze_button.click(
            fn=gr_predict,
            inputs=[image_input, ckpt_dropdown, viz_toggle, method_dropdown],
            outputs=[label_output, heatmap_output],
        )

    demo.launch()


if __name__ == "__main__":
    launch()