# Video-Fx / prompt_generator.py
# Uploaded by garyuzair ("Upload 6 files"), commit 7bf78f6 (verified).
import concurrent.futures
import random
import re

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
class PromptGenerator:
    """Build image-generation prompts from transcription text.

    Prompts are assembled from keyword heuristics over the transcription plus
    randomly chosen style, lighting, quality and camera phrases. Results are
    cached on the instance, keyed by ``(transcription, aspect_ratio)``.
    """

    # Scene keywords, checked in this order (first match wins, so horror
    # beats mystery, fantasy and sci-fi). A keyword must START a word:
    # "ghosts" or "screamed" still match, but mid-word hits such as
    # "salient" (alien) or "resolve" (solve) do not.
    _SCENE_KEYWORDS = (
        ("horror", (
            "dark", "shadow", "fear", "horror", "scary", "afraid", "terror",
            "scream", "blood", "death", "evil", "monster", "ghost",
            "nightmare", "creepy", "spooky", "haunted", "skeleton",
        )),
        ("mystery", (
            "mystery", "detective", "clue", "investigate", "secret",
            "discover", "reveal", "hidden", "puzzle", "solve",
        )),
        ("fantasy", (
            "fantasy", "magic", "wizard", "dragon", "fairy", "enchanted",
            "spell", "mythical", "legend", "ancient", "kingdom",
        )),
        ("scifi", (
            "sci-fi", "future", "space", "alien", "robot", "technology",
            "advanced", "starship", "planet", "galaxy",
        )),
    )
    _SCENE_PATTERNS = tuple(
        (scene, re.compile(r"\b(?:" + "|".join(map(re.escape, words)) + ")"))
        for scene, words in _SCENE_KEYWORDS
    )

    # Style phrases per scene type ("neutral" is the fallback).
    _STYLE_KEYWORDS = {
        "horror": (
            "dark atmospheric horror scene",
            "cinematic horror",
            "atmospheric dread",
            "horror movie still",
            "dark gothic scene",
            "eerie lighting",
            "suspenseful moment",
        ),
        "mystery": (
            "mysterious scene",
            "film noir style",
            "detective story visual",
            "suspenseful moment",
            "enigmatic scene",
        ),
        "fantasy": (
            "fantasy scene",
            "magical environment",
            "enchanted setting",
            "mythical landscape",
            "fantasy illustration style",
        ),
        "scifi": (
            "futuristic scene",
            "sci-fi environment",
            "high-tech setting",
            "advanced technology visual",
            "science fiction concept art",
        ),
        "neutral": (
            "cinematic scene",
            "photorealistic environment",
            "detailed setting",
            "professional photography",
            "movie still",
        ),
    }

    # Lighting phrases per scene type.
    _LIGHTING_KEYWORDS = {
        "horror": (
            "dim lighting",
            "shadows",
            "moonlight",
            "eerie glow",
            "dark atmospheric lighting",
        ),
        "mystery": (
            "moody lighting",
            "dramatic shadows",
            "low key lighting",
            "atmospheric fog",
        ),
        "fantasy": (
            "magical glow",
            "ethereal light",
            "golden hour",
            "mystical atmosphere",
        ),
        "scifi": (
            "neon lighting",
            "holographic glow",
            "futuristic illumination",
            "technological ambiance",
        ),
        "neutral": (
            "natural lighting",
            "golden hour",
            "soft illumination",
            "dramatic lighting",
        ),
    }

    # Quality phrases intended for Stable-Diffusion-style models; two are drawn.
    _QUALITY_KEYWORDS = (
        "highly detailed",
        "8k resolution",
        "photorealistic",
        "detailed textures",
        "professional photography",
    )

    # Camera phrases; one is drawn.
    _CAMERA_KEYWORDS = (
        "shallow depth of field",
        "cinematic composition",
        "movie still",
        "professional photography",
    )

    # Appended to every prompt to discourage rendered text in the image.
    _ANTI_TEXT_PHRASES = ("no text", "no words", "no writing")

    # Visual-element extraction patterns, applied to lower-cased text.
    # Leading \b stops matches inside words ("that the" -> "at the",
    # "bathe" -> "the"). The place capture runs greedily to the next
    # non-letter, non-space character.
    _LOCATION_RE = re.compile(
        r"\b(inside|outside|in the|at the|near the|by the|on the) ([a-z ]+)"
    )
    _OBJECT_RE = re.compile(r"\bthe ([a-z]+)")
    _WORD_PAIR_RE = re.compile(r"\b([a-z]+) ([a-z]+)")
    _OBJECT_STOPWORDS = frozenset({
        "that", "this", "then", "than", "they", "them",
        "with", "from", "were", "when", "what", "which",
    })

    # Phrases that tend to make image models render text. Word boundaries
    # keep e.g. "with textured" and "entitled" intact.
    _TEXT_CUE_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"\btext\s+that\s+says\b",
            r"\bwith\s+text\b",
            r"\bshowing\s+text\b",
            r"\bdisplaying\s+text\b",
            r"\bwith\s+the\s+words\b",
            r"\bcaptions?\b",
            r"\b(?:sub)?titles?\b",
        )
    )
    # Quoted spans: paired double quotes, or single quotes not glued to a
    # word character (so contractions like "don't" / "it's" survive).
    _QUOTED_RE = re.compile(r"\"[^\"]*\"|(?<!\w)'[^']*'(?!\w)")

    def __init__(self):
        self.model = None  # seq2seq model, set by load_model()
        self.tokenizer = None  # matching tokenizer, set by load_model()
        # Maps (transcription, aspect_ratio) -> generated prompt string.
        self.prompt_cache = {}

    def load_model(self):
        """Load the ``facebook/bart-large-cnn`` tokenizer and model once.

        The spinner and any failure warning are shown through Streamlit. A
        failed load leaves both attributes ``None``, so a later call retries.
        Note: no prompt-generation method in this class uses the model.

        Returns:
            ``(model, tokenizer)``; both are ``None`` if loading failed.
        """
        if self.model is None:
            with st.spinner("Loading text-to-prompt model..."):
                try:
                    model_name = "facebook/bart-large-cnn"
                    # Tokenizer and model are loaded separately.
                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                    self.model = AutoModelForSeq2SeqLM.from_pretrained(
                        model_name,
                        low_cpu_mem_usage=True,
                        torch_dtype=torch.float16,
                    )
                    # NOTE(review): float16 weights are pinned to CPU here;
                    # some torch versions lack half-precision CPU kernels,
                    # so inference may fail — confirm before relying on it.
                    self.model = self.model.to("cpu")
                except Exception as e:
                    # Best-effort: report and fall back to the heuristic path.
                    st.warning(f"Error loading model: {str(e)}. Using fallback method.")
                    self.model = None
                    self.tokenizer = None
        return self.model, self.tokenizer

    @classmethod
    def _detect_scene_type(cls, text):
        """Return the first scene type whose keyword starts a word in *text*."""
        lowered = text.lower()
        for scene, pattern in cls._SCENE_PATTERNS:
            if pattern.search(lowered):
                return scene
        return "neutral"

    def generate_hyper_realistic_prompt(self, transcription, aspect_ratio="16:9"):
        """Build a cinematic image prompt for one transcription segment.

        The prompt is built from keyword heuristics and random phrase picks.
        It does not use the model from ``load_model``. It consists of:

        1. visual elements extracted from the text,
        2. a style and a lighting phrase for the detected scene type,
        3. two quality phrases and one camera phrase,
        4. anti-text phrases.

        Duplicate phrases are dropped, and the result is passed through
        ``clean_prompt_for_image_generation``.

        Args:
            transcription: Segment text; ``None`` or blank yields ``""``.
            aspect_ratio: Only part of the cache key; it does not change
                the prompt text.

        Returns:
            The prompt string. It is cached per ``(transcription,
            aspect_ratio)``, so repeated calls return the same random picks
            until ``clear_cache`` is called.
        """
        if not transcription or not transcription.strip():
            return ""

        # A tuple key cannot collide the way a hash of the concatenated
        # strings could (e.g. "w1" + "16:9" vs "w11" + "6:9").
        cache_key = (transcription, aspect_ratio)
        if cache_key in self.prompt_cache:
            return self.prompt_cache[cache_key]

        visual_scene = self.extract_visual_elements(transcription)
        scene = self._detect_scene_type(transcription)

        # List displays evaluate left to right, so the random draws happen
        # in the order: style, lighting, quality, camera.
        parts = [
            visual_scene,
            random.choice(self._STYLE_KEYWORDS[scene]),
            random.choice(self._LIGHTING_KEYWORDS[scene]),
            *random.sample(self._QUALITY_KEYWORDS, 2),
            random.choice(self._CAMERA_KEYWORDS),
            *self._ANTI_TEXT_PHRASES,
        ]
        # Some phrases exist in several tables ("movie still",
        # "professional photography"); keep the first occurrence only.
        enhanced_prompt = self.clean_prompt_for_image_generation(
            ", ".join(dict.fromkeys(parts))
        )

        self.prompt_cache[cache_key] = enhanced_prompt
        return enhanced_prompt

    def extract_visual_elements(self, text):
        """Pick up to three visual phrases from *text* for a scene description.

        Candidates are collected in priority order from the lower-cased text:

        1. location phrases ("in the forest", "near the river"),
        2. nouns following "the" (longer than 3 letters, minus stopwords),
        3. word pairs (first word > 2 letters, second > 3 letters).

        Duplicates are removed while keeping that order, so the result is
        deterministic. Only ASCII letters are considered.

        Returns:
            Up to three phrases joined by ", ". If nothing was found, the
            stripped original text truncated to 100 characters.
        """
        text = text.strip()
        lowered = text.lower()

        elements = []
        for preposition, place in self._LOCATION_RE.findall(lowered):
            place = place.strip()
            if len(place) > 3:  # skip very short place names
                elements.append(f"{preposition} {place}")
        elements.extend(
            noun
            for noun in self._OBJECT_RE.findall(lowered)
            if len(noun) > 3 and noun not in self._OBJECT_STOPWORDS
        )
        elements.extend(
            f"{first} {second}"
            for first, second in self._WORD_PAIR_RE.findall(lowered)
            if len(first) > 2 and len(second) > 3
        )

        if elements:
            return ", ".join(list(dict.fromkeys(elements))[:3])
        return text[:100]

    def clean_prompt_for_image_generation(self, prompt):
        """Strip text-inducing phrases and quoted spans from *prompt*.

        Removes cues such as "with text", "caption", "title" and
        "subtitles" (whole words, case-insensitive). It also removes quoted
        spans. Afterwards, whitespace inside each comma-separated part is
        collapsed and empty parts are dropped, so no ", ," is left behind.
        """
        cleaned = prompt
        for pattern in self._TEXT_CUE_PATTERNS:
            cleaned = pattern.sub("", cleaned)
        cleaned = self._QUOTED_RE.sub("", cleaned)
        segments = (" ".join(segment.split()) for segment in cleaned.split(","))
        return ", ".join(segment for segment in segments if segment)

    def generate_optimized_prompt(self, transcription, aspect_ratio="16:9"):
        """Return the prompt for one transcription (alias of the main generator)."""
        return self.generate_hyper_realistic_prompt(transcription, aspect_ratio)

    def generate_prompts(self, text, num_segments=5, aspect_ratio="16:9"):
        """Split *text* into segments and build one prompt per segment.

        Words are spread as evenly as possible over
        ``min(num_segments, word_count)`` segments. Every word ends up in
        exactly one segment; the last segments absorb any remainder.

        Returns:
            ``(prompts, segments)``: two lists of equal length, empty for
            blank text.

        Raises:
            ValueError: if ``num_segments`` is less than 1.
        """
        if num_segments < 1:
            raise ValueError("num_segments must be at least 1")

        words = text.split()
        total = len(words)
        count = min(num_segments, total)
        segments = [
            " ".join(words[i * total // count:(i + 1) * total // count])
            for i in range(count)
        ]
        prompts = [
            self.generate_hyper_realistic_prompt(segment, aspect_ratio)
            for segment in segments
        ]
        return prompts, segments

    def generate_optimized_prompts(self, transcriptions, parallel=False, max_workers=4, aspect_ratio="16:9"):
        """Build one prompt per transcription, optionally in a thread pool.

        *transcriptions* may be any iterable; it is materialized first.
        Output order matches input order in both modes. The work is pure
        Python, so threads mainly overlap scheduling rather than add CPU
        parallelism.
        """
        items = list(transcriptions)

        def build(transcription):
            return self.generate_hyper_realistic_prompt(transcription, aspect_ratio)

        if parallel and len(items) > 1:
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                return list(executor.map(build, items))
        return [build(transcription) for transcription in items]

    def clear_cache(self):
        """Drop all cached prompts; always returns ``True``."""
        self.prompt_cache = {}
        return True