# AIDetectV2 / app.py
# Last commit by telecomadm1145: "Update app.py" (d34b195, verified)
import os
import gc
import cv2
import torch
import timm
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
# --- 引入新的可视化库 ---
# pip install grad-cam
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad, LayerCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# --- Configuration ---
SEED = 4421
DROP_RATE = 0.1  # base dropout probability of TimmClassifierWithHead's MLP head
LOCAL_CKPT_DIR = "./checkpoints"  # local_dir passed to hf_hub_download
DEFAULT_CKPT = "convformer_b36.ai_cls"  # CKPT_META key preloaded by launch()

_CUDA_AVAILABLE = torch.cuda.is_available()
DEVICE = "cuda" if _CUDA_AVAILABLE else "cpu"

# Normalisation statistics used when a CKPT_META entry defines no "mean"/"std".
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Seed the torch (CPU), NumPy and, when present, all CUDA RNGs.
torch.manual_seed(SEED)
np.random.seed(SEED)
if _CUDA_AVAILABLE:
    torch.cuda.manual_seed_all(SEED)
# Registry of selectable checkpoints; its keys populate the UI's checkpoint dropdown.
# Per-entry keys read by AIImageDetector:
#   num_classes          - number of classifier outputs the model is built with
#   head                 - "timm_cross_entropy": a plain timm model created from
#                          "backbone_timm_name"; any other value: TimmClassifierWithHead
#                          created from "backbone"
#   repo_id / filename   - Hugging Face Hub location passed to hf_hub_download;
#                          ".safetensors" files are read with safetensors, anything
#                          else with torch.load
#   labels               - display names, indexed in the same order as the logits
#   input_size           - side length of the square the input image is resized to
#   mean / std           - optional normalisation stats (ImageNet values when absent)
CKPT_META = {
    # Custom-head model: TimmClassifierWithHead on top of a CaFormer-B36 backbone.
    "caformer_b36.old": {
        "num_classes": 4,
        "head": "v7",
        "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
        "repo_id": "telecomadm1145/swin-ai-detection",
        "filename": "caformer_b36_4class_96.safetensors",
        "labels": ["non_ai", "ai", "ani_non_ai", "ani_ai"],
        "input_size": 384,
    },
    "caformer_b36.v2": {
        "num_classes": 2,
        "head": "timm_cross_entropy",
        "backbone_timm_name": "hf-hub:animetimm/caformer_b36.dbv4-full",
        "repo_id": "telecomadm1145/danbooru-real-vs-ai-caformer-b36-v2",
        "filename": "pytorch_model.bin",
        "labels": ["AI", "Non-AI"],
        "input_size": 384,
    },
    # Default checkpoint (DEFAULT_CKPT).
    "convformer_b36.ai_cls": {
        "num_classes": 2,
        "head": "timm_cross_entropy",
        "backbone_timm_name": "convformer_b36.sail_in1k",
        "repo_id": "telecomadm1145/convformer_b36.ai_cls",
        "filename": "best_checkpoint.pth",
        "labels": ["AI", "Non-AI"],
        "input_size": 224,
    },
    "deepghs/cls-ai-check-1m.caformer_s36.r512": {
        "num_classes": 2,
        "head": "timm_cross_entropy",
        "backbone_timm_name": "caformer_s36.sail_in22k_ft_in1k_384",
        "repo_id": "deepghs/cls-ai-check-1m.caformer_s36.r512",
        "filename": "model.safetensors",
        "labels": ["AI", "Non-AI"],
        "input_size": 512,
    },
}
# --- Model definitions ---
class TimmClassifierWithHead(nn.Module):
    """A timm backbone (its own classifier removed) followed by a small MLP head.

    Head layout (stored in ``self.classifier`` as an ``nn.Sequential``, so the
    state-dict keys are ``classifier.0`` ... ``classifier.5``):
        Dropout(DROP_RATE) -> Linear(num_features, 64) -> BatchNorm1d(64)
        -> GELU -> Dropout(0.8 * DROP_RATE) -> Linear(64, num_classes)
    """

    def __init__(self, model_name, num_classes, pretrained=True):
        super().__init__()
        # num_classes=0 asks timm for a feature extractor without a classifier.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        feature_dim = self.backbone.num_features
        hidden_dim = 64
        head_layers = [
            nn.Dropout(DROP_RATE),
            nn.Linear(feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU(),
            nn.Dropout(DROP_RATE * 0.8),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the head's logits for the image batch ``x``."""
        return self.classifier(self.backbone(x))
# --- Key helper: bring Transformer / hybrid activations into [B, C, H, W] ---
def _square_side(n):
    """Return ``s`` such that ``s * s == n``, or None if ``n`` is not a positive perfect square."""
    if n <= 0:
        return None
    side = int(round(float(np.sqrt(n))))
    return side if side * side == n else None


def reshape_transform(tensor):
    """
    Convert a target-layer activation (or gradient) to the [B, C, H, W] layout
    that pytorch-grad-cam expects.

    * 4D input: heuristically treated as channels-last [B, H, W, C] when the last
      dimension is larger than both of the middle ones, and permuted to
      [B, C, H, W]; otherwise returned unchanged (assumed to be NCHW already).
    * 3D input [B, N, C] (a token sequence): folded into a square C x s x s map.
      If N is a perfect square all tokens are used; otherwise, if N - 1 is a
      perfect square, the first token is dropped (assumed to be a class token).
    * Any other rank is returned unchanged.

    Raises:
        ValueError: for a 3D input whose token count is neither a perfect square
            nor a perfect square plus one.
    """
    if tensor.ndim == 4:
        h, w, c = tensor.shape[1], tensor.shape[2], tensor.shape[3]
        # Heuristic: channel counts are usually larger than late-stage spatial
        # sizes, e.g. [1, 7, 7, 512] -> [1, 512, 7, 7].
        if c > h and c > w:
            return tensor.permute(0, 3, 1, 2)
        return tensor

    if tensor.ndim == 3:
        b, num_tokens, c = tensor.shape
        # Check the full token count first: an odd count is NOT evidence of a
        # class token (a 7x7 grid has 49 tokens).
        side = _square_side(num_tokens)
        if side is None:
            side = _square_side(num_tokens - 1)
            if side is None:
                raise ValueError(
                    f"reshape_transform: cannot fold {num_tokens} tokens into a square grid"
                )
            tensor = tensor[:, 1:, :]
        return tensor.reshape(b, side, side, c).permute(0, 3, 1, 2)

    return tensor
# --- Core detector ---
class AIImageDetector:
    """Holds one CKPT_META checkpoint at a time and classifies images with it,
    optionally rendering a CAM heatmap for the predicted class.

    Attributes:
        model: the loaded ``nn.Module`` in eval mode, or None.
        current_ckpt_name: CKPT_META key of ``model``, or None.
        current_meta: CKPT_META entry of ``model``, or None.
        cam_engine: kept for backward compatibility; always None (a fresh CAM
            object is created for each ``predict`` call).
    """

    def __init__(self):
        self.model = None
        self.current_ckpt_name = None
        self.current_meta = None
        self.cam_engine = None

    def _get_target_layers(self, model):
        """Return a one-element list with the layer the CAM method should visualise.

        Uses ``model.backbone`` when present (TimmClassifierWithHead), else
        ``model``. Picks the last block of the last ``stages`` entry
        (MetaFormer-style backbones such as CaFormer / ConvFormer), the last
        block of the last ``layers`` entry (Swin-style), or as a fallback the
        backbone's last child module.
        """
        backbone = getattr(model, "backbone", model)
        if hasattr(backbone, "stages"):
            last_stage = backbone.stages[-1]
            if hasattr(last_stage, "blocks"):
                return [last_stage.blocks[-1]]
            return [last_stage]
        if hasattr(backbone, "layers"):
            return [backbone.layers[-1].blocks[-1]]
        return [list(backbone.children())[-1]]

    def _release_model(self):
        """Drop the current model and return cached GPU memory.

        Attributes are reset to None rather than ``del``-ed, so the detector
        stays usable (and reloadable) if the following load fails.
        """
        if self.model is None:
            return
        self.model = None
        self.cam_engine = None
        self.current_ckpt_name = None
        self.current_meta = None
        # Collect first so tensors kept alive only by reference cycles are
        # freed before the CUDA caching allocator is asked to release memory.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _build_model(meta):
        """Instantiate the (not yet loaded) architecture described by ``meta``."""
        if meta.get("head") == "timm_cross_entropy":
            return timm.create_model(
                meta["backbone_timm_name"], pretrained=False, num_classes=meta["num_classes"]
            )
        return TimmClassifierWithHead(
            meta["backbone"], num_classes=meta["num_classes"], pretrained=False
        )

    @staticmethod
    def _read_state_dict(ckpt_file, filename):
        """Read ``ckpt_file`` and return a flat state dict with wrapper prefixes removed."""
        if filename.endswith(".safetensors"):
            state_dict = load_file(ckpt_file, device=DEVICE)
        else:
            # SECURITY NOTE: depending on the torch version's ``weights_only``
            # default, torch.load may unpickle arbitrary objects. Only point
            # CKPT_META at trusted repositories.
            state_dict = torch.load(ckpt_file, map_location=DEVICE)
        # Unwrap training checkpoints that nest the weights under a key.
        if "model_state_dict" in state_dict:
            state_dict = state_dict["model_state_dict"]
        elif "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        # NOTE: str.replace removes "module." / "model." anywhere in a key,
        # not only as a leading prefix.
        return {k.replace("module.", "").replace("model.", ""): v for k, v in state_dict.items()}

    def load_model(self, ckpt_name):
        """Make ``ckpt_name`` (a CKPT_META key) the active model; no-op if already loaded.

        Downloads the checkpoint into LOCAL_CKPT_DIR via hf_hub_download, builds
        the architecture on DEVICE, loads the weights non-strictly (reporting any
        missing / unexpected keys) and switches to eval mode.

        Raises:
            KeyError: if ``ckpt_name`` is not in CKPT_META; the current model is
                kept in that case.
        """
        if ckpt_name == self.current_ckpt_name and self.model is not None:
            return
        print(f"Loading model: {ckpt_name}...")
        meta = CKPT_META[ckpt_name]
        self._release_model()

        ckpt_file = hf_hub_download(
            repo_id=meta["repo_id"],
            filename=meta["filename"],
            local_dir=LOCAL_CKPT_DIR,
            force_download=False,
        )
        model = self._build_model(meta).to(DEVICE)
        state_dict = self._read_state_dict(ckpt_file, meta["filename"])
        incompatible = model.load_state_dict(state_dict, strict=False)
        # strict=False would otherwise hide a checkpoint/architecture mismatch
        # (e.g. a randomly initialised head).
        if incompatible.missing_keys or incompatible.unexpected_keys:
            print(
                f"Warning: {ckpt_name}: {len(incompatible.missing_keys)} missing and "
                f"{len(incompatible.unexpected_keys)} unexpected state_dict keys "
                f"(first missing: {incompatible.missing_keys[:5]}, "
                f"first unexpected: {incompatible.unexpected_keys[:5]})"
            )
        model.eval()

        self.model = model
        self.current_ckpt_name = ckpt_name
        self.current_meta = meta
        self.cam_engine = None

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` in RGB mode, flattening any transparency onto white."""
        if image.mode == "RGB":
            return image
        has_alpha = image.mode in ("RGBA", "LA") or "transparency" in image.info
        if not has_alpha:
            return image.convert("RGB")
        rgba = image.convert("RGBA")
        background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        # Module-level Image.alpha_composite returns a new image; the
        # Image.Image.alpha_composite *method* works in place and returns None.
        return Image.alpha_composite(background, rgba).convert("RGB")

    def preprocess(self, image: "Image.Image"):
        """Prepare ``image`` for the currently loaded model.

        The image is converted to RGB and resized (bicubic, aspect ratio not
        preserved) to ``input_size`` x ``input_size`` from ``current_meta``
        (default 384). It is then scaled to [0, 1] and normalised with the
        entry's mean/std (ImageNet defaults).

        Returns:
            (img_tensor, img_viz_np): float32 tensor of shape [1, 3, S, S] on
            DEVICE, and the resized un-normalised float32 S x S x 3 array in
            [0, 1] used for the heatmap overlay.
        """
        image = self._to_rgb(image)
        meta = self.current_meta
        size = int(meta.get("input_size", 384))
        image_resized = image.resize((size, size), Image.BICUBIC)

        img_np = np.array(image_resized).astype(np.float32) / 255.0
        img_viz_np = img_np.copy()

        mean = np.array(meta.get("mean", IMAGENET_MEAN), dtype=np.float32).reshape(1, 1, 3)
        std = np.array(meta.get("std", IMAGENET_STD), dtype=np.float32).reshape(1, 1, 3)
        img_norm = (img_np - mean) / std
        img_tensor = torch.from_numpy(img_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        return img_tensor, img_viz_np

    def predict(self, image, ckpt_name, enable_viz, method_name):
        """Classify ``image`` with checkpoint ``ckpt_name``.

        Args:
            image: PIL image, or None.
            ckpt_name: CKPT_META key; loaded on demand.
            enable_viz: when truthy, also render a CAM heatmap for the
                predicted class.
            method_name: "LayerCAM" or "GradCAM"; any other value selects
                GradCAM++.

        Returns:
            ``(probabilities, heatmap)``: a dict mapping each label to its
            softmax probability, and a PIL heatmap overlay (or None when
            visualisation is disabled). ``(None, None)`` when ``image`` is None.
        """
        if image is None:
            return None, None
        self.load_model(ckpt_name)
        input_tensor, img_viz_np = self.preprocess(image)
        labels = self.current_meta["labels"]

        with torch.no_grad():
            logits = self.model(input_tensor)
        probs = F.softmax(logits, dim=1)[0].cpu().numpy()
        result_dict = {label: float(p) for label, p in zip(labels, probs)}
        if not enable_viz:
            return result_dict, None

        cam_methods = {"LayerCAM": LayerCAM, "GradCAM": GradCAM}
        cam_cls = cam_methods.get(method_name, GradCAMPlusPlus)
        targets = [ClassifierOutputTarget(int(probs.argmax()))]
        target_layers = self._get_target_layers(self.model)

        grayscale_cam = None
        # The context manager removes the CAM forward/backward hooks from the
        # model when done; reshape_transform adapts CaFormer/ConvFormer outputs.
        with cam_cls(
            model=self.model, target_layers=target_layers, reshape_transform=reshape_transform
        ) as cam:
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        if grayscale_cam is None:
            # NOTE(review): some pytorch-grad-cam releases suppress IndexError
            # in BaseCAM.__exit__; fall back to the prediction without a heatmap.
            return result_dict, None

        visualization = show_cam_on_image(img_viz_np, grayscale_cam, use_rgb=True)
        return result_dict, Image.fromarray(visualization)
# Module-level detector used by the Gradio callback below.
detector = AIImageDetector()


# --- Gradio UI ---
def gr_predict(image, ckpt_name, enable_viz, method):
    """Gradio click handler: forward every argument to ``detector.predict``.

    Returns the ``(label -> probability, heatmap)`` pair consumed by the
    Label and Image output components.
    """
    prediction, heatmap = detector.predict(image, ckpt_name, enable_viz, method)
    return prediction, heatmap
def launch():
    """Preload DEFAULT_CKPT, build the Gradio Blocks UI and start the server."""
    detector.load_model(DEFAULT_CKPT)

    checkpoint_choices = list(CKPT_META.keys())
    cam_choices = ["LayerCAM", "GradCAM++", "GradCAM"]

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## 🕵️ AI Generated Image Detection & Explainability")
        with gr.Row():
            # Left column: input image and controls.
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Input Image")
                checkpoint_dropdown = gr.Dropdown(
                    checkpoint_choices, value=DEFAULT_CKPT, label="Model Checkpoint"
                )
                with gr.Group():
                    heatmap_toggle = gr.Checkbox(label="Visualize Heatmap", value=True)
                    cam_dropdown = gr.Dropdown(
                        cam_choices,
                        value="LayerCAM",
                        label="Method (LayerCAM recommended for texture)",
                    )
                analyze_button = gr.Button("Analyze", variant="primary")
            # Right column: class probabilities and heatmap overlay.
            with gr.Column(scale=1):
                confidence_output = gr.Label(num_top_classes=4, label="Prediction Confidence")
                heatmap_output = gr.Image(type="pil", label="Attention / Heatmap")
        analyze_button.click(
            fn=gr_predict,
            inputs=[image_input, checkpoint_dropdown, heatmap_toggle, cam_dropdown],
            outputs=[confidence_output, heatmap_output],
        )
    demo.launch()


if __name__ == "__main__":
    launch()