"""Gradio demo that runs an image through a selectable AI-image classifier.

Checkpoints are described in ``CKPT_META``, downloaded from the Hugging Face
Hub on demand, and cached in module-level globals so that repeated
predictions with the same checkpoint do not reload the weights.
"""

import os
import warnings

import gradio as gr
import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file

SEED = 4421
DROP_RATE = 0.1
LOCAL_CKPT_DIR = "./checkpoints"
DEFAULT_CKPT = "V2"

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(SEED)
np.random.seed(SEED)

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Per-checkpoint configuration, keyed by the name shown in the UI dropdown.
#   head            -- "timm_cross_entropy" builds a plain timm model from
#                      "backbone_timm_name"; any other value builds
#                      TimmClassifierWithHead from "backbone".
#   repo_id/filename -- location of the weights on the Hugging Face Hub.
#   labels          -- class names, in the order of the model's output logits.
#   input_size, mean, std, interpolation, crop_pct -- preprocessing settings
#                      read by preprocess_image().
CKPT_META = {
    "V2.5-old": {
        "num_classes": 4,
        "head": "v7",
        "backbone": "caformer_b36.sail_in22k_ft_in1k_384",
        "repo_id": "telecomadm1145/swin-ai-detection",
        "filename": "caformer_b36_4class_96.safetensors",
        "labels": ["non_ai", "ai", "ani_non_ai", "ani_ai"],
        "input_size": 384,
        "mean": IMAGENET_MEAN,
        "std": IMAGENET_STD,
        "interpolation": "bicubic",
        "crop_pct": 1.0,
    },
    "V2": {
        "num_classes": 2,
        "head": "timm_cross_entropy",
        "backbone_timm_name": "hf-hub:animetimm/caformer_b36.dbv4-full",
        "repo_id": "telecomadm1145/danbooru-real-vs-ai-caformer-b36-v2",
        "filename": "pytorch_model.bin",
        "labels": ["AI Generated", "Non-AI Generated"],
        "input_size": 384,
        "mean": IMAGENET_MEAN,
        "std": IMAGENET_STD,
        "interpolation": "bicubic",
        "crop_pct": 1.0,
    },
    "deepghs-ai-chk-1m-r512": {
        "num_classes": 2,
        "head": "timm_cross_entropy",
        "backbone_timm_name": "caformer_s36.sail_in22k_ft_in1k_384",
        "repo_id": "deepghs/cls-ai-check-1m.caformer_s36.r512",
        "filename": "model.safetensors",
        "labels": ["AI Generated", "Non-AI Generated"],
        "input_size": 512,
        "mean": IMAGENET_MEAN,
        "std": IMAGENET_STD,
        "interpolation": "bicubic",
        "crop_pct": 1.0,
    },
}

# Currently loaded model and the checkpoint name / metadata it was built from.
# Written only by load_model(); read by predict().
model = None
current_ckpt = None
current_meta = None


class TimmClassifierWithHead(nn.Module):
    """A timm backbone (created with ``num_classes=0``) plus a small MLP head."""

    def __init__(self, model_name, num_classes, pretrained=True):
        """Build the backbone and the two-layer classification head.

        Args:
            model_name: timm model identifier passed to ``timm.create_model``.
            num_classes: Size of the final linear layer's output.
            pretrained: Forwarded to ``timm.create_model``.
        """
        super().__init__()
        self.backbone = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0
        )
        self.classifier = nn.Sequential(
            nn.Dropout(DROP_RATE),
            nn.Linear(self.backbone.num_features, 64),
            nn.BatchNorm1d(64),
            nn.GELU(),
            nn.Dropout(DROP_RATE * 0.8),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return the classifier head's output for the backbone features of ``x``."""
        return self.classifier(self.backbone(x))


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` as an RGB image, flattening transparency onto white.

    Images that carry an alpha channel (modes ``RGBA``, ``LA``, ``PA``) or a
    ``transparency`` entry in ``image.info`` are converted to RGBA and
    composited over an opaque white canvas; everything else is converted
    straight to RGB.
    """
    if image.mode not in ("RGB", "RGBA"):
        # "LA"/"PA" store alpha in a band rather than in info["transparency"];
        # converting them directly to RGB would drop that alpha silently.
        has_alpha = image.mode in ("LA", "PA") or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")
    if image.mode == "RGBA":
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image


def _get_pil_interpolation(name: str):
    """Map an interpolation name to the PIL resampling constant.

    ``None``/empty and unrecognised names fall back to bicubic.
    """
    table = {
        "bicubic": Image.BICUBIC,
        "bilinear": Image.BILINEAR,
        "lanczos": Image.LANCZOS,
        "nearest": Image.NEAREST,
    }
    return table.get((name or "bicubic").lower(), Image.BICUBIC)


def preprocess_image(image: Image.Image, meta: dict, device: str):
    """Resize, center-crop and normalise ``image`` into a model input tensor.

    Args:
        image: Input PIL image in any mode.
        meta: Checkpoint metadata; reads ``input_size``, ``crop_pct``,
            ``interpolation``, ``mean`` and ``std`` (each with a default).
        device: Device the resulting tensor is moved to.

    Returns:
        A float tensor of shape ``(1, 3, size, size)`` on ``device``.
    """
    img = pil_ensure_rgb(image)
    size = int(meta.get("input_size", 384))
    crop_pct = float(meta.get("crop_pct", 0.875))
    # The shorter side is resized to size / crop_pct, then a size x size
    # square is cut from the center.
    resize_short = max(1, int(round(size / crop_pct)))
    interp = _get_pil_interpolation(meta.get("interpolation", "bicubic"))

    w, h = img.size
    scale = resize_short / float(min(w, h))
    new_w = int(round(w * scale))
    new_h = int(round(h * scale))
    img = img.resize((new_w, new_h), interp)

    left = (new_w - size) // 2
    top = (new_h - size) // 2
    img = img.crop((left, top, left + size, top + size))

    # HWC uint8 -> CHW float in [0, 1], then per-channel normalisation.
    x = np.asarray(img).astype(np.float32) / 255.0
    x = torch.from_numpy(x).permute(2, 0, 1)
    mean = torch.tensor(meta.get("mean", IMAGENET_MEAN)).view(3, 1, 1)
    std = torch.tensor(meta.get("std", IMAGENET_STD)).view(3, 1, 1)
    x = (x - mean) / std
    return x.unsqueeze(0).to(device)


def _strip_prefix(state_dict, prefixes=("model.", "module.")):
    """Return a copy of ``state_dict`` with one leading prefix removed per key.

    For each key, the first entry of ``prefixes`` that matches is stripped;
    keys that match none are copied unchanged.
    """
    stripped = {}
    for key, value in state_dict.items():
        new_key = key
        for prefix in prefixes:
            if key.startswith(prefix):
                new_key = key[len(prefix):]
                break
        stripped[new_key] = value
    return stripped


def load_model(ckpt_name: str):
    """Load checkpoint ``ckpt_name`` into the module-level ``model``.

    Does nothing when ``ckpt_name`` is ``None`` or already loaded. Otherwise
    downloads the weights, builds the architecture selected by the
    checkpoint's ``head`` entry, loads the weights and switches to eval mode.
    The globals ``model``, ``current_ckpt`` and ``current_meta`` are updated
    only after loading succeeds.

    Raises:
        KeyError: If ``ckpt_name`` is not a key of ``CKPT_META``.
    """
    global model, current_ckpt, current_meta
    if ckpt_name is None or (ckpt_name == current_ckpt and model is not None):
        return

    meta = CKPT_META[ckpt_name]
    ckpt_file = hf_hub_download(
        repo_id=meta["repo_id"],
        filename=meta["filename"],
        local_dir=LOCAL_CKPT_DIR,
        force_download=False,
    )

    if meta["head"] == "timm_cross_entropy":
        model_ = timm.create_model(
            meta["backbone_timm_name"],
            pretrained=False,
            num_classes=meta["num_classes"],
        )
    else:
        model_ = TimmClassifierWithHead(
            meta["backbone"], num_classes=meta["num_classes"], pretrained=False
        )
    model_ = model_.to(device)

    if meta["filename"].endswith(".safetensors"):
        state = load_file(ckpt_file, device=device)
    else:
        # NOTE(security): torch.load unpickles the downloaded file, which can
        # execute arbitrary code. Only list trusted repos in CKPT_META, and
        # consider passing weights_only=True where the torch version allows.
        state = torch.load(ckpt_file, map_location=device)

    # Training checkpoints may wrap the weights under a well-known key.
    if isinstance(state, dict) and (
        "model_state_dict" in state or "state_dict" in state
    ):
        state_dict = state.get("model_state_dict", state.get("state_dict"))
    else:
        state_dict = state

    # Try progressively more lenient loads: as-is, then with wrapper prefixes
    # stripped, then non-strict as a last resort.
    try:
        model_.load_state_dict(state_dict, strict=True)
    except RuntimeError:
        cleaned = _strip_prefix(state_dict, ("model.", "module."))
        try:
            model_.load_state_dict(cleaned, strict=True)
        except RuntimeError:
            result = model_.load_state_dict(cleaned, strict=False)
            # A non-strict load can leave layers at their random
            # initialisation; report it instead of hiding it.
            warnings.warn(
                f"Checkpoint {ckpt_name!r} loaded non-strictly: "
                f"missing keys={list(result.missing_keys)}, "
                f"unexpected keys={list(result.unexpected_keys)}",
                RuntimeWarning,
                stacklevel=2,
            )

    model_.eval()
    model = model_
    current_ckpt = ckpt_name
    current_meta = meta


def predict(image: Image.Image, ckpt_name: str):
    """Classify ``image`` with checkpoint ``ckpt_name``.

    Args:
        image: Input PIL image, or ``None`` when nothing was uploaded.
        ckpt_name: Key of ``CKPT_META``. When ``None``, the currently loaded
            checkpoint is reused, or ``DEFAULT_CKPT`` if none is loaded yet.

    Returns:
        ``None`` if ``image`` is ``None``; otherwise a dict mapping each class
        label to its softmax probability.
    """
    if image is None:
        return None

    # Guarantee a model is loaded even if the selection is empty.
    load_model(ckpt_name or current_ckpt or DEFAULT_CKPT)

    inp = preprocess_image(image, current_meta, device)
    with torch.no_grad():
        logits = model(inp)
    probs = F.softmax(logits, dim=1)[0].cpu().tolist()
    return dict(zip(current_meta["labels"], probs))


def launch():
    """Preload the default checkpoint, build the Gradio UI and start serving."""
    load_model(DEFAULT_CKPT)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# AI generated image detection")
        with gr.Row():
            with gr.Column(scale=1):
                run_btn = gr.Button("Run", variant="primary")
                sel_ckpt = gr.Dropdown(
                    list(CKPT_META.keys()),
                    value=DEFAULT_CKPT,
                    label="Checkpoints",
                )
                in_img = gr.Image(type="pil", label="Image")
            with gr.Column(scale=1):
                out_lbl = gr.Label(num_top_classes=4, label="Prediction")

        run_btn.click(predict, inputs=[in_img, sel_ckpt], outputs=[out_lbl])

        # Offer any images found in ./examples as clickable samples.
        if os.path.exists("examples"):
            example_files = [
                os.path.join("examples", f)
                for f in os.listdir("examples")
                if f.lower().endswith((".png", ".jpg", ".jpeg", ".webp"))
            ]
            if example_files:
                gr.Examples(
                    examples=[[f] for f in example_files],
                    inputs=[in_img],
                    outputs=[out_lbl],
                    fn=predict,
                    cache_examples=False,
                )

    demo.launch()


if __name__ == "__main__":
    launch()