selfit-camera commited on
Commit
8721ec2
·
1 Parent(s): 7223d40
Files changed (3) hide show
  1. app.py +21 -15
  2. nfsw.py +52 -0
  3. util.py +39 -1
app.py CHANGED
@@ -1,5 +1,8 @@
1
  import gradio as gr
2
  import threading
 
 
 
3
  from util import process_image_edit, get_country_info_safe
4
  from nfsw import NSFWDetector
5
 
@@ -31,9 +34,9 @@ def edit_image_interface(input_image, prompt, request: gr.Request, progress=gr.P
31
  country_info = get_country_info_safe(client_ip)
32
 
33
  # 检查IP是否因NSFW违规过多而被屏蔽 3
34
- if client_ip in NSFW_Dict and NSFW_Dict[client_ip] >= 3:
35
  print(f"❌ IP blocked due to excessive NSFW violations - IP: {client_ip}({country_info}), violations: {NSFW_Dict[client_ip]}")
36
- return None, f"❌ Your ip {client_ip},your region has been blocked"
37
 
38
  if input_image is None:
39
  return None, "Please upload an image first"
@@ -47,31 +50,34 @@ def edit_image_interface(input_image, prompt, request: gr.Request, progress=gr.P
47
 
48
  # 检查图片是否包含NSFW内容
49
  nsfw_result = None
50
- if nsfw_detector is not None:
51
  try:
52
- nsfw_result = nsfw_detector.predict_label_only(input_image)
 
 
 
53
  if nsfw_result.lower() == "nsfw":
54
  # 记录NSFW违规次数
55
  if client_ip not in NSFW_Dict:
56
  NSFW_Dict[client_ip] = 0
57
  NSFW_Dict[client_ip] += 1
58
  print(f"❌ NSFW image detected - IP: {client_ip}({country_info}), violations: {NSFW_Dict[client_ip]}")
59
- return None, f"❌ Your ip {client_ip},your region has been blocked"
 
60
  except Exception as e:
61
  print(f"⚠️ NSFW检测失败: {e}")
62
  # 检测失败时允许继续处理
63
 
64
- if IP_Dict[client_ip]>8 and country_info.lower() in ["印度", "巴基斯坦"]:
65
  print(f"❌ Content not allowed - IP: {client_ip}({country_info}), count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}")
66
  return None, "❌ Content not allowed. Please modify your prompt"
67
- if IP_Dict[client_ip]>18 and country_info.lower() in ["中国"]:
68
- print(f"❌ Content not allowed - IP: {client_ip}({country_info}), count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}")
69
- return None, "❌ Content not allowed. Please modify your prompt"
70
-
71
- if client_ip.lower() in ["221.194.171.230", "101.126.56.37", "101.126.56.44"]:
72
- print(f"❌ Content not allowed - IP: {client_ip}({country_info}), count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}")
73
- return None, "❌ Content not allowed. Please modify your prompt"
74
-
75
 
76
  result_url = None
77
  status_message = ""
@@ -136,7 +142,7 @@ def create_app():
136
  gr.Markdown("### 📸 Upload Image")
137
  input_image = gr.Image(
138
  label="Select image to edit",
139
- type="filepath",
140
  height=400,
141
  elem_classes=["upload-area"]
142
  )
 
1
  import gradio as gr
2
  import threading
3
+ import os
4
+ import shutil
5
+ import tempfile
6
  from util import process_image_edit, get_country_info_safe
7
  from nfsw import NSFWDetector
8
 
 
34
  country_info = get_country_info_safe(client_ip)
35
 
36
  # 检查IP是否因NSFW违规过多而被屏蔽 3
37
+ if client_ip in NSFW_Dict and NSFW_Dict[client_ip] >= 5:
38
  print(f"❌ IP blocked due to excessive NSFW violations - IP: {client_ip}({country_info}), violations: {NSFW_Dict[client_ip]}")
39
+ return None, f"❌ Your ip {client_ip},your region has been blocked for too much nsfw content"
40
 
41
  if input_image is None:
42
  return None, "Please upload an image first"
 
50
 
51
  # 检查图片是否包含NSFW内容
52
  nsfw_result = None
53
+ if nsfw_detector is not None and input_image is not None:
54
  try:
55
+ # 直接使用PIL Image对象进行检测,避免文件路径问题
56
+ nsfw_result = nsfw_detector.predict_pil_label_only(input_image)
57
+ print(f"🔍 NSFW检测结果: {nsfw_result} - IP: {client_ip}({country_info})")
58
+
59
  if nsfw_result.lower() == "nsfw":
60
  # 记录NSFW违规次数
61
  if client_ip not in NSFW_Dict:
62
  NSFW_Dict[client_ip] = 0
63
  NSFW_Dict[client_ip] += 1
64
  print(f"❌ NSFW image detected - IP: {client_ip}({country_info}), violations: {NSFW_Dict[client_ip]}")
65
+ return None, f"❌ Your ip {client_ip},your region has been blocked for too much nsfw content"
66
+
67
  except Exception as e:
68
  print(f"⚠️ NSFW检测失败: {e}")
69
  # 检测失败时允许继续处理
70
 
71
+ if IP_Dict[client_ip]>10 and country_info.lower() in ["印度", "巴基斯坦"]:
72
  print(f"❌ Content not allowed - IP: {client_ip}({country_info}), count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}")
73
  return None, "❌ Content not allowed. Please modify your prompt"
74
+ # if IP_Dict[client_ip]>18 and country_info.lower() in ["中国"]:
75
+ # print(f"❌ Content not allowed - IP: {client_ip}({country_info}), count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}")
76
+ # return None, "❌ Content not allowed. Please modify your prompt"
77
+ # if client_ip.lower() in ["221.194.171.230", "101.126.56.37", "101.126.56.44"]:
78
+ # print(f"❌ Content not allowed - IP: {client_ip}({country_info}), count: {IP_Dict[client_ip]}, prompt: {prompt.strip()}")
79
+ # return None, "❌ Content not allowed. Please modify your prompt"
80
+
 
81
 
82
  result_url = None
83
  status_message = ""
 
142
  gr.Markdown("### 📸 Upload Image")
143
  input_image = gr.Image(
144
  label="Select image to edit",
145
+ type="pil",
146
  height=400,
147
  elem_classes=["upload-area"]
148
  )
nfsw.py CHANGED
@@ -187,6 +187,58 @@ class NSFWDetector:
187
  """
188
  predicted_label, _ = self.predict(image_path)
189
  return predicted_label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  # --- 使用示例 ---
192
  if __name__ == "__main__":
 
187
  """
188
  predicted_label, _ = self.predict(image_path)
189
  return predicted_label
190
+
191
+ def predict_from_pil(self, pil_image):
192
+ """
193
+ 直接从PIL Image对象进行NSFW检测
194
+
195
+ Args:
196
+ pil_image (PIL.Image): PIL图像对象
197
+
198
+ Returns:
199
+ tuple: (预测标签, 原始图像)
200
+ """
201
+ try:
202
+ # 确保是RGB格式
203
+ if pil_image.mode != "RGB":
204
+ pil_image = pil_image.convert("RGB")
205
+
206
+ # 调整尺寸
207
+ image_resized = pil_image.resize(self.input_size, Image.Resampling.BILINEAR)
208
+
209
+ # 转换为numpy数组并归一化
210
+ image_np = np.array(image_resized, dtype=np.float32) / 255.0
211
+
212
+ # 调整维度顺序 [H, W, C] -> [C, H, W]
213
+ image_np = np.transpose(image_np, (2, 0, 1))
214
+
215
+ # 添加批次维度 [C, H, W] -> [1, C, H, W]
216
+ input_tensor = np.expand_dims(image_np, axis=0).astype(np.float32)
217
+
218
+ # 运行推理
219
+ outputs = self.session.run([self.output_name], {self.input_name: input_tensor})
220
+ predictions = outputs[0]
221
+
222
+ # 后处理结果
223
+ predicted_label = self._postprocess_predictions(predictions)
224
+
225
+ return predicted_label, pil_image
226
+
227
+ except Exception as e:
228
+ raise RuntimeError(f"PIL图像预测失败: {e}")
229
+
230
+ def predict_pil_label_only(self, pil_image):
231
+ """
232
+ 从PIL Image对象只返回预测标签
233
+
234
+ Args:
235
+ pil_image (PIL.Image): PIL图像对象
236
+
237
+ Returns:
238
+ str: 预测的类别标签
239
+ """
240
+ predicted_label, _ = self.predict_from_pil(pil_image)
241
+ return predicted_label
242
 
243
  # --- 使用示例 ---
244
  if __name__ == "__main__":
util.py CHANGED
@@ -11,7 +11,9 @@ import func_timeout
11
  import numpy as np
12
  import gradio as gr
13
  import boto3
 
14
  from botocore.client import Config
 
15
 
16
 
17
  # TOKEN = os.environ['TOKEN']
@@ -248,15 +250,38 @@ def check_task_status(task_id):
248
  return 'error', None, f"请求异常: {str(e)}"
249
 
250
 
251
- def process_image_edit(img_path, prompt, progress_callback=None):
252
  """
253
  处理图片编辑的完整流程
 
 
 
 
 
254
  """
 
255
  try:
256
  # 生成客户端 IP 和时间戳
257
  client_ip = "127.0.0.1" # 默认IP
258
  time_id = int(time.time())
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  if progress_callback:
261
  progress_callback("uploading image...")
262
 
@@ -305,6 +330,19 @@ def process_image_edit(img_path, prompt, progress_callback=None):
305
 
306
  except Exception as e:
307
  return None, f"error occurred during processing: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
 
310
  if __name__ == "__main__":
 
11
  import numpy as np
12
  import gradio as gr
13
  import boto3
14
+ import tempfile
15
  from botocore.client import Config
16
+ from PIL import Image
17
 
18
 
19
  # TOKEN = os.environ['TOKEN']
 
250
  return 'error', None, f"请求异常: {str(e)}"
251
 
252
 
253
+ def process_image_edit(img_input, prompt, progress_callback=None):
254
  """
255
  处理图片编辑的完整流程
256
+
257
+ Args:
258
+ img_input: 可以是文件路径(str)或PIL Image对象
259
+ prompt: 编辑指令
260
+ progress_callback: 进度回调函数
261
  """
262
+ temp_img_path = None
263
  try:
264
  # 生成客户端 IP 和时间戳
265
  client_ip = "127.0.0.1" # 默认IP
266
  time_id = int(time.time())
267
 
268
+ # 处理输入图像 - 支持PIL Image和文件路径
269
+ if hasattr(img_input, 'save'): # PIL Image对象
270
+ # 创建临时文件
271
+ temp_dir = tempfile.mkdtemp()
272
+ temp_img_path = os.path.join(temp_dir, f"temp_img_{time_id}.jpg")
273
+
274
+ # 保存PIL Image为临时文件
275
+ if img_input.mode != 'RGB':
276
+ img_input = img_input.convert('RGB')
277
+ img_input.save(temp_img_path, 'JPEG', quality=95)
278
+
279
+ img_path = temp_img_path
280
+ print(f"💾 PIL Image已保存为临时文件: {temp_img_path}")
281
+ else:
282
+ # 假设是文件路径
283
+ img_path = img_input
284
+
285
  if progress_callback:
286
  progress_callback("uploading image...")
287
 
 
330
 
331
  except Exception as e:
332
  return None, f"error occurred during processing: {str(e)}"
333
+
334
+ finally:
335
+ # 清理临时文件
336
+ if temp_img_path and os.path.exists(temp_img_path):
337
+ try:
338
+ os.remove(temp_img_path)
339
+ # 尝试删除临时目录(如果为空)
340
+ temp_dir = os.path.dirname(temp_img_path)
341
+ if os.path.exists(temp_dir):
342
+ os.rmdir(temp_dir)
343
+ print(f"🗑️ 已清理临时文件: {temp_img_path}")
344
+ except Exception as cleanup_error:
345
+ print(f"⚠️ 清理临时文件失败: {cleanup_error}")
346
 
347
 
348
  if __name__ == "__main__":