reagvis committed (verified) · Commit b024686 · Parent(s): f405593

Create app.py

Files changed (1):
  1. app.py (+181, -0)
app.py ADDED
@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+from PIL import Image
+import gradio as gr
+import os
+import numpy as np
+from transformers import CLIPModel, CLIPProcessor
+
+class C2P_CLIP(nn.Module):
+    def __init__(self, name='openai/clip-vit-large-patch14', num_classes=1):
+        super(C2P_CLIP, self).__init__()
+        self.model = CLIPModel.from_pretrained(name)
+        del self.model.text_model
+        del self.model.text_projection
+        del self.model.logit_scale
+
+        self.model.vision_model.requires_grad_(False)
+        self.model.visual_projection.requires_grad_(False)
+        self.model.fc = nn.Linear(768, num_classes)
+        torch.nn.init.normal_(self.model.fc.weight.data, 0.0, 0.02)
+
+        # Create processor for image preprocessing
+        self.processor = CLIPProcessor.from_pretrained(name)
+
+    def encode_image(self, img):
+        # Updated to handle different argument expectations
+        vision_outputs = self.model.vision_model(
+            pixel_values=img,
+            # Removed problematic arguments for compatibility
+        )
+        # Check if output is a tuple or an object with hidden states
+        if isinstance(vision_outputs, tuple):
+            pooled_output = vision_outputs[1]  # pooled_output
+        else:
+            # Handle the case where output is an object
+            pooled_output = vision_outputs.pooler_output
+
+        image_features = self.model.visual_projection(pooled_output)
+        return image_features
+
+    def forward(self, img):
+        image_embeds = self.encode_image(img)
+        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
+        return self.model.fc(image_embeds)
+
+# Initialize model with cache directory
+model_path = "model/C2P_CLIP_release_20240901.pth"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+# Create model cache directory if it doesn't exist
+os.makedirs("model", exist_ok=True)
+
+# Download the model if it doesn't exist
+if not os.path.exists(model_path):
+    print("Downloading model for the first time...")
+    model_url = "https://www.now61.com/f/95OefW/C2P_CLIP_release_20240901.zip"
+    try:
+        state_dict = torch.hub.load_state_dict_from_url(
+            model_url, map_location="cpu", progress=True,
+            file_name="C2P_CLIP_release_20240901.pth",
+            check_hash=False, model_dir="model"
+        )
+        # Save model for future use
+        torch.save(state_dict, model_path)
+    except Exception as e:
+        print(f"Error downloading model: {e}")
+
+# Initialize model
+def load_model():
+    print("Loading model...")
+    model = C2P_CLIP(name='openai/clip-vit-large-patch14', num_classes=1)
+
+    try:
+        state_dict = torch.load(model_path, map_location=device)
+        model.load_state_dict(state_dict, strict=False)
+        print("Model loaded successfully!")
+    except Exception as e:
+        print(f"Error loading model: {e}")
+
+    model = model.to(device)
+    model.eval()
+    return model
+
+# Global model instance
+model = load_model()
+processor = CLIPProcessor.from_pretrained('openai/clip-vit-large-patch14')
+
+def analyze_image(image):
+    """Process an image and return deepfake detection results"""
+    if image is None:
+        return None, "Please upload an image.", None
+
+    try:
+        # Ensure image is in RGB mode
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image).convert("RGB")
+        else:
+            image = image.convert("RGB")
+
+        # Process the image
+        inputs = processor(images=image, return_tensors="pt").to(device)
+
+        # Run inference
+        with torch.no_grad():
+            prediction = model(inputs.pixel_values).sigmoid().item()
+
+        # Create visual output
+        # Add a colored border based on the prediction
+        border_color = (255, 0, 0) if prediction > 0.5 else (0, 255, 0)  # Red for fake, green for real
+        border_width = 10
+
+        # Create a new image with border
+        bordered_image = Image.new('RGB', (image.width + 2*border_width, image.height + 2*border_width), border_color)
+        bordered_image.paste(image, (border_width, border_width))
+
+        # Create result text
+        if prediction > 0.5:
+            result = "FAKE (AI-generated or manipulated)"
+            confidence = prediction
+        else:
+            result = "REAL (authentic)"
+            confidence = 1 - prediction
+
+        confidence_text = f"Confidence: {confidence:.4f} ({confidence*100:.2f}%)"
+
+        return bordered_image, result, confidence_text
+
+    except Exception as e:
+        import traceback
+        error_msg = f"Error analyzing image: {str(e)}"
+        traceback.print_exc()
+        return image, "Error", error_msg
+
+# Create Gradio interface
+title = "C2P-CLIP Deepfake Detector"
+description = """
+<p style='text-align: center'>
+<b>C2P-CLIP: Deepfake Detection with Enhanced Generalization</b>
+</p>
+
+This app uses the C2P-CLIP model to detect if an image is real or AI-generated/manipulated.
+
+<b>How to use:</b>
+1. Upload an image or use one of the examples
+2. The model will analyze it and show whether it is likely real or fake
+3. A colored border will be added (green = real, red = fake)
+
+<b>Limitations:</b>
+- The model provides a binary classification (real/fake) without detailed explanation
+- No localization of manipulated regions
+- Performance may vary across different types of manipulations
+"""
+
+# Example images
+examples = [
+    ["examples/real1.jpg"],
+    ["examples/fake1.jpg"],
+]
+
+# Create example directory if it doesn't exist
+os.makedirs("examples", exist_ok=True)
+
+interface = gr.Interface(
+    fn=analyze_image,
+    inputs=gr.Image(type="pil", label="Upload Image"),
+    outputs=[
+        gr.Image(type="pil", label="Analyzed Image"),
+        gr.Textbox(label="Result"),
+        gr.Textbox(label="Confidence"),
+    ],
+    title=title,
+    description=description,
+    examples=examples if all(os.path.exists(ex[0]) for ex in examples) else None,
+    allow_flagging="never",
+    theme=gr.themes.Soft(),
+)
+
+# Launch the app
+if __name__ == "__main__":
+    interface.launch()
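
A quick way to exercise the detector outside the Gradio UI is to call analyze_image directly. This is a minimal sketch, not part of the commit: it assumes the file above is importable as app, that the weight download succeeds when the module is imported, and that an image exists at the hypothetical path below.

from PIL import Image
import app  # importing runs load_model(), which may download the weights

img = Image.open("examples/real1.jpg")  # hypothetical example image
bordered, verdict, confidence_text = app.analyze_image(img)
print(verdict)          # "REAL (authentic)" or "FAKE (AI-generated or manipulated)"
print(confidence_text)  # e.g. "Confidence: 0.9731 (97.31%)"

Here bordered is the input image with the same green (real) or red (fake) border the Space displays.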