import os
import requests
import json
import time
import random
import base64
import uuid
import threading
from pathlib import Path
from dotenv import load_dotenv
import gradio as gr
import torch
import logging
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoTokenizer, AutoModelForSequenceClassification

load_dotenv()

MODEL_URL = "TostAI/nsfw-text-detection-large"
CLASS_NAMES = {0: "✅ SAFE", 1: "⚠️ QUESTIONABLE", 2: "🚫 UNSAFE"}
tokenizer = AutoTokenizer.from_pretrained(MODEL_URL)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_URL)
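
# Every prompt is screened with this text classifier before any generation
# request is sent; anything not classified as SAFE (class 0) is rejected
# inside generate_video() below.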

class SessionManager:
    _instances = {}
    _lock = threading.Lock()

    @classmethod
    def get_session(cls, session_id):
        with cls._lock:
            if session_id not in cls._instances:
                cls._instances[session_id] = {
                    'count': 0,
                    'history': [],
                    'last_active': time.time()
                }
            return cls._instances[session_id]

    @classmethod
    def cleanup_sessions(cls):
        with cls._lock:
            now = time.time()
            # Drop sessions that have been idle for more than an hour.
            expired = [k for k, v in cls._instances.items() if now - v['last_active'] > 3600]
            for k in expired:
                del cls._instances[k]

class RateLimiter:
    def __init__(self):
        self.clients = {}
        self.lock = threading.Lock()

    def check(self, client_id):
        with self.lock:
            now = time.time()
            if client_id not in self.clients:
                self.clients[client_id] = {'count': 1, 'reset': now + 3600}
                return True
            if now > self.clients[client_id]['reset']:
                self.clients[client_id] = {'count': 1, 'reset': now + 3600}
                return True
            if self.clients[client_id]['count'] >= 10:
                return False
            self.clients[client_id]['count'] += 1
            return True


session_manager = SessionManager()
rate_limiter = RateLimiter()
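# Both objects are module-level singletons shared by every Gradio session;
# callers are identified by the per-session UUID stored in gr.State below.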


def create_error_image(message):
    img = Image.new("RGB", (832, 480), "#ffdddd")
    try:
        font = ImageFont.truetype("arial.ttf", 24)
    except Exception:
        font = ImageFont.load_default()
    draw = ImageDraw.Draw(img)
    text = f"Error: {message[:60]}..." if len(message) > 60 else f"Error: {message}"
    draw.text((50, 200), text, fill="#ff0000", font=font)
    img.save("error.jpg")
    return "error.jpg"


def classify_prompt(prompt):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.argmax(outputs.logits).item()


def image_to_base64(file_path):
    try:
        with open(file_path, "rb") as image_file:
            raw_data = image_file.read()
            encoded = base64.b64encode(raw_data)
            # b64encode already pads its output; this is only a defensive check.
            missing_padding = len(encoded) % 4
            if missing_padding:
                encoded += b'=' * (4 - missing_padding)
            return encoded.decode('utf-8')
    except Exception as e:
        raise ValueError(f"Base64 encoding failed: {str(e)}")


def video_to_base64(file_path):
    """Convert a video file to a Base64 string."""
    try:
        with open(file_path, "rb") as video_file:
            raw_data = video_file.read()
            encoded = base64.b64encode(raw_data)
            # b64encode already pads its output; this is only a defensive check.
            missing_padding = len(encoded) % 4
            if missing_padding:
                encoded += b'=' * (4 - missing_padding)
            return encoded.decode('utf-8')
    except Exception as e:
        raise ValueError(f"Base64 encoding failed: {str(e)}")
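
# generate_video streams status updates back to the UI as it works: it screens
# the prompt, enforces the per-session rate limit, encodes any reference
# images/video to Base64, submits the job to the WaveSpeedAI wan-2.1-14b-vace
# endpoint, then polls the result URL until the task completes or fails.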

def generate_video(
    context_scale,
    enable_safety_checker,
    enable_fast_mode,
    flow_shift,
    guidance_scale,
    images,
    negative_prompt,
    num_inference_steps,
    prompt,
    seed,
    size,
    task,
    video,
    session_id,
):
    safety_level = classify_prompt(prompt)
    if safety_level != 0:
        error_img = create_error_image(CLASS_NAMES[safety_level])
        yield f"❌ Blocked: {CLASS_NAMES[safety_level]}", error_img
        return
    if not rate_limiter.check(session_id):
        error_img = create_error_image("Limit of 10 requests per hour exceeded")
        yield "❌ Rate limit exceeded", error_img
        return
    session = session_manager.get_session(session_id)
    session['last_active'] = time.time()
    session['count'] += 1
    # Read the API key from the environment (loaded via load_dotenv above);
    # the variable name WAVESPEED_API_KEY is an assumption, adjust it to match your .env.
    API_KEY = os.getenv("WAVESPEED_API_KEY")
    if not API_KEY:
        error_img = create_error_image("API key not found")
        yield "❌ Error: Missing API Key", error_img
        return
    try:
        base64_images = []
        if images is not None:  # Skip encoding when no reference images were uploaded
            for img_path in images:
                base64_img = image_to_base64(img_path)
                base64_images.append(base64_img)
    except Exception as e:
        error_img = create_error_image(str(e))
        yield f"❌ Failed to encode images: {str(e)}", error_img
        return
    video_payload = ""
    if video is not None:
        if isinstance(video, (list, tuple)):
            video_payload = video[0] if video else ""
        else:
            video_payload = video
        # Convert the input video file to Base64
        try:
            base64_video = video_to_base64(video_payload)
            video_payload = base64_video
        except Exception as e:
            error_img = create_error_image(str(e))
            yield f"❌ Failed to encode video: {str(e)}", error_img
            return
    payload = {
        "context_scale": context_scale,
        "enable_fast_mode": enable_fast_mode,
        "enable_safety_checker": enable_safety_checker,
        "flow_shift": flow_shift,
        "guidance_scale": guidance_scale,
        "images": base64_images,
        "negative_prompt": negative_prompt,
        "num_inference_steps": num_inference_steps,
        "prompt": prompt,
        "seed": seed if seed != -1 else random.randint(0, 999999),
        "size": size,
        "task": task,
        "video": str(video_payload) if video_payload else "",
    }
    logging.debug(f"API request payload: {json.dumps(payload, indent=2)}")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    try:
        response = requests.post(
            "https://api.wavespeed.ai/api/v2/wavespeed-ai/wan-2.1-14b-vace",
            headers=headers,
            data=json.dumps(payload)
        )
        if response.status_code != 200:
            error_img = create_error_image(response.text)
            yield f"❌ API Error ({response.status_code}): {response.text}", error_img
            return
        request_id = response.json()["data"]["id"]
        yield f"✅ Task submitted (ID: {request_id})", None
    except Exception as e:
        error_img = create_error_image(str(e))
        yield f"❌ Connection Error: {str(e)}", error_img
        return
    result_url = f"https://api.wavespeed.ai/api/v2/predictions/{request_id}/result"
    start_time = time.time()
    while True:
        time.sleep(0.5)
        try:
            response = requests.get(result_url, headers=headers)
            if response.status_code != 200:
                error_img = create_error_image(response.text)
                yield f"❌ Polling error ({response.status_code}): {response.text}", error_img
                return
            data = response.json()["data"]
            status = data["status"]
            if status == "completed":
                elapsed = time.time() - start_time
                video_url = data['outputs'][0]
                session["history"].append(video_url)
                yield (f"🎉 Completed in {elapsed:.1f}s\n"
                       f"Download link: {video_url}"), video_url
                return
            elif status == "failed":
                error_img = create_error_image(data.get('error', 'Unknown error'))
                yield f"❌ Task failed: {data.get('error', 'Unknown error')}", error_img
                return
            else:
                yield f"⏳ Status: {status.capitalize()}...", None
        except Exception as e:
            error_img = create_error_image(str(e))
            yield f"❌ Polling failed: {str(e)}", error_img
            return


def cleanup_task():
    while True:
        session_manager.cleanup_sessions()
        time.sleep(3600)
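
# cleanup_task is started as a daemon thread in __main__ below, so idle
# sessions are purged once an hour without blocking interpreter shutdown.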

with gr.Blocks(
    theme=gr.themes.Soft(),
    css="""
    .video-preview { max-width: 600px !important; }
    .status-box { padding: 10px; border-radius: 5px; margin: 5px; }
    .safe { background: #e8f5e9; border: 1px solid #a5d6a7; }
    .warning { background: #fff3e0; border: 1px solid #ffcc80; }
    .error { background: #ffebee; border: 1px solid #ef9a9a; }
    #centered_button {
        align-self: center !important;
        height: fit-content !important;
        margin-top: 22px !important; /* fine-tune to match the input height */
    }
    """
) as app:
    session_id = gr.State(str(uuid.uuid4()))
    gr.Markdown("# 🌊 Wan-2.1-14B-Vace Run On [WaveSpeedAI](https://wavespeed.ai/)")
    gr.Markdown("""VACE is an all-in-one model for video creation and editing. It covers a range of tasks, including reference-to-video generation (R2V), video-to-video editing (V2V), and masked video-to-video editing (MV2V), which users can compose freely. This lets users explore diverse possibilities and streamlines their workflows, offering capabilities such as Move-Anything, Swap-Anything, Reference-Anything, Expand-Anything, Animate-Anything, and more.""")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                images = gr.File(label="Upload Images", file_count="multiple", file_types=["image"], type="filepath", elem_id="image-uploader",
                                 scale=1)
                video = gr.Video(label="Input Video", format="mp4", sources=["upload"],
                                 scale=1)
            prompt = gr.Textbox(label="Prompt", lines=5, placeholder="Prompt")
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=2)
            with gr.Row():
                size = gr.Dropdown(["832*480", "480*832"], value="832*480", label="Size")
                task = gr.Dropdown(["depth", "pose"], value="depth", label="Task")
            with gr.Row():
                num_inference_steps = gr.Slider(1, 100, value=30, step=1, label="Inference Steps")
                context_scale = gr.Slider(0, 2, value=1, step=0.1, label="Context Scale")
            with gr.Row():
                guidance = gr.Slider(1, 20, value=5, step=0.1, label="Guidance Scale")
                flow_shift = gr.Slider(1, 20, value=16, step=1, label="Flow Shift")
            with gr.Row():
                seed = gr.Number(-1, label="Seed")
                random_seed_btn = gr.Button("Random 🎲 Seed", variant="secondary", elem_id="centered_button")
            with gr.Row():
                enable_safety_checker = gr.Checkbox(True, label="Enable Safety Checker", interactive=True)
                enable_fast_mode = gr.Checkbox(True, label="Fast Mode (to enable it, please visit WaveSpeedAI)", interactive=False)
        with gr.Column(scale=1):
            video_output = gr.Video(label="Video Output", format="mp4", interactive=False, elem_classes=["video-preview"])
            generate_btn = gr.Button("Generate", variant="primary")
            status_output = gr.Textbox(label="Status", interactive=False, lines=4)
    # gr.Examples(
    #     examples=[
    #         [
    #             "The elegant lady carefully selects bags in the boutique, and she shows the charm of a mature woman in a black slim dress with a pearl necklace, as well as her pretty face. Holding a vintage-inspired blue leather half-moon handbag, she is carefully observing its craftsmanship and texture. The interior of the store is a haven of sophistication and luxury. Soft, ambient lighting casts a warm glow over the polished wooden floors",
    #             [
    #                 "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413751234102420_md9ywspl.png",
    #                 "https://d2g64w682n9w0w.cloudfront.net/media/ec44bbf6abac4c25998dd2c4af1a46a7/images/1747413586520964413_7bkgc9ol.png"
    #             ]
    #         ]
    #     ],
    #     inputs=[prompt, images],
    # )
    random_seed_btn.click(
        fn=lambda: random.randint(0, 999999),
        outputs=seed
    )
    generate_btn.click(
        generate_video,
        inputs=[
            context_scale,
            enable_safety_checker,
            enable_fast_mode,
            flow_shift,
            guidance,
            images,
            negative_prompt,
            num_inference_steps,
            prompt,
            seed,
            size,
            task,
            video,
            session_id,
        ],
        outputs=[status_output, video_output]
    )

logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("gradio_app.log"),
        logging.StreamHandler()
    ]
)
gradio_logger = logging.getLogger("gradio")
gradio_logger.setLevel(logging.INFO)

if __name__ == "__main__":
    threading.Thread(target=cleanup_task, daemon=True).start()
    app.queue(max_size=2).launch(
        server_name="0.0.0.0",
        max_threads=10,
        share=False
    )