import concurrent.futures
import hashlib
import random
import re
from typing import List, Optional, Sequence, Tuple

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Checkpoint loaded by PromptGenerator.load_model.
_MODEL_NAME = "facebook/bart-large-cnn"

# Scene types in priority order: the first type with any trigger word that
# occurs as a substring of the lower-cased transcription wins.
_SCENE_TRIGGERS = (
    ("horror", (
        "dark", "shadow", "fear", "horror", "scary", "afraid", "terror",
        "scream", "blood", "death", "evil", "monster", "ghost", "nightmare",
        "creepy", "spooky", "haunted", "skeleton",
    )),
    ("mystery", (
        "mystery", "detective", "clue", "investigate", "secret", "discover",
        "reveal", "hidden", "puzzle", "solve",
    )),
    ("fantasy", (
        "fantasy", "magic", "wizard", "dragon", "fairy", "enchanted", "spell",
        "mythical", "legend", "ancient", "kingdom",
    )),
    ("scifi", (
        "sci-fi", "future", "space", "alien", "robot", "technology",
        "advanced", "starship", "planet", "galaxy",
    )),
)

# scene type -> (style keywords, lighting keywords)
_SCENE_PALETTES = {
    "horror": (
        [
            "dark atmospheric horror scene", "cinematic horror",
            "atmospheric dread", "horror movie still", "dark gothic scene",
            "eerie lighting", "suspenseful moment",
        ],
        [
            "dim lighting", "shadows", "moonlight", "eerie glow",
            "dark atmospheric lighting",
        ],
    ),
    "mystery": (
        [
            "mysterious scene", "film noir style", "detective story visual",
            "suspenseful moment", "enigmatic scene",
        ],
        [
            "moody lighting", "dramatic shadows", "low key lighting",
            "atmospheric fog",
        ],
    ),
    "fantasy": (
        [
            "fantasy scene", "magical environment", "enchanted setting",
            "mythical landscape", "fantasy illustration style",
        ],
        [
            "magical glow", "ethereal light", "golden hour",
            "mystical atmosphere",
        ],
    ),
    "scifi": (
        [
            "futuristic scene", "sci-fi environment", "high-tech setting",
            "advanced technology visual", "science fiction concept art",
        ],
        [
            "neon lighting", "holographic glow", "futuristic illumination",
            "technological ambiance",
        ],
    ),
    "neutral": (
        [
            "cinematic scene", "photorealistic environment",
            "detailed setting", "professional photography", "movie still",
        ],
        [
            "natural lighting", "golden hour", "soft illumination",
            "dramatic lighting",
        ],
    ),
}

# Quality keywords that work well with Stable Diffusion.
_QUALITY_KEYWORDS = [
    "highly detailed", "8k resolution", "photorealistic",
    "detailed textures", "professional photography",
]

_CAMERA_KEYWORDS = [
    "shallow depth of field", "cinematic composition", "movie still",
    "professional photography",
]

# Patterns used by extract_visual_elements. The leading \b keeps "the" /
# "in the" from matching inside longer words ("breathe", "begin the").
_LOCATION_RE = re.compile(
    r'\b(inside|outside|in the|at the|near the|by the|on the) ([a-z ]+)'
)
_OBJECT_RE = re.compile(r'\bthe ([a-z]+)')
_WORD_PAIR_RE = re.compile(r'([a-z]+) ([a-z]+)')
_NON_VISUAL_WORDS = frozenset({
    "that", "this", "then", "than", "they", "them", "with", "from", "were",
    "when", "what", "which",
})

# Phrases that tend to make an image model render lettering. Applied in
# order; word boundaries prevent mangling unrelated words ("with textures").
_TEXT_CUE_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r'\btext\s+that\s+says\b',
        r'\bwith\s+text\b',
        r'\bshowing\s+text\b',
        r'\bdisplaying\s+text\b',
        r'\bwith\s+the\s+words\b',
        r'\bcaption\w*',
        r'\b(?:sub)?title\w*',
    )
)
_DOUBLE_QUOTED_RE = re.compile(r'"[^"\n]*"')
# A single quote only opens/closes a quotation when it is not glued to a
# letter or digit, so apostrophes in contractions ("don't") are left alone.
_SINGLE_QUOTED_RE = re.compile(r"(?<![A-Za-z0-9])'[^'\n]*'(?![A-Za-z0-9])")
_SPACE_BEFORE_COMMA_RE = re.compile(r'\s+,')
_REPEATED_COMMA_RE = re.compile(r'(?:,\s*){2,}')
_REPEATED_SPACE_RE = re.compile(r'[ \t]{2,}')


class PromptGenerator:
    """Turn transcribed text into image-generation prompts.

    Prompts are assembled from keywords extracted from the transcription plus
    randomly chosen style, lighting, quality and camera keywords. Results are
    memoised in ``prompt_cache``. ``load_model`` can additionally load a BART
    checkpoint, but none of the prompt-building methods in this class use it.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        # md5 hex digest of (aspect ratio, transcription) -> generated prompt
        self.prompt_cache = {}

    def load_model(self):
        """Load the BART model and tokenizer if they are not loaded yet.

        Any exception raised while loading is reported with ``st.warning``
        and leaves both attributes set to ``None``.

        Returns:
            The tuple ``(self.model, self.tokenizer)``; both are ``None``
            when loading failed.
        """
        if self.model is not None:
            return self.model, self.tokenizer

        with st.spinner("Loading text-to-prompt model..."):
            try:
                # Tokenizer and model are loaded separately to avoid device
                # issues.
                self.tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
                # NOTE(review): float16 weights are loaded and then kept on
                # CPU; half-precision CPU inference may be unsupported or
                # slow depending on the torch version - confirm before this
                # model is used for generation.
                self.model = AutoModelForSeq2SeqLM.from_pretrained(
                    _MODEL_NAME,
                    low_cpu_mem_usage=True,
                    torch_dtype=torch.float16,
                )
                # Explicitly move to CPU to avoid meta tensor issues.
                self.model = self.model.to("cpu")
            except Exception as e:
                st.warning(
                    f"Error loading model: {e}. Using fallback method."
                )
                # Callers fall back to the keyword-based approach.
                self.model = None
                self.tokenizer = None

        return self.model, self.tokenizer

    @staticmethod
    def _detect_scene_type(lowered_text: str) -> str:
        """Return the first scene type whose trigger word occurs in the text."""
        for scene_type, triggers in _SCENE_TRIGGERS:
            if any(word in lowered_text for word in triggers):
                return scene_type
        return "neutral"

    def generate_hyper_realistic_prompt(
        self, transcription: Optional[str], aspect_ratio: str = "16:9"
    ) -> str:
        """Build a keyword-based image prompt for one transcription.

        Args:
            transcription: Source text. ``None``, empty or whitespace-only
                input yields ``""`` and is not cached.
            aspect_ratio: Only used as part of the cache key; it does not
                appear in the generated prompt.

        Returns:
            The prompt string. Style keywords are chosen with ``random``, so
            the first result for a given input is cached and returned for
            later calls until ``clear_cache`` is called.
        """
        if not transcription or not transcription.strip():
            return ""

        # The NUL separator keeps ("a1", "6:9") and ("a", "16:9") from
        # hashing to the same key.
        cache_key = hashlib.md5(
            f"{aspect_ratio}\x00{transcription}".encode()
        ).hexdigest()
        cached = self.prompt_cache.get(cache_key)
        if cached is not None:
            return cached

        visual_scene = self.extract_visual_elements(transcription)
        scene_type = self._detect_scene_type(transcription.lower())
        style_keywords, lighting = _SCENE_PALETTES[scene_type]

        selected_style = random.choice(style_keywords)
        selected_lighting = random.choice(lighting)
        selected_quality = ", ".join(random.sample(_QUALITY_KEYWORDS, 2))
        selected_camera = random.choice(_CAMERA_KEYWORDS)

        enhanced_prompt = (
            f"{visual_scene}, {selected_style}, {selected_lighting}, "
            f"{selected_quality}, {selected_camera}"
        )
        # Steer the image model away from rendering lettering.
        enhanced_prompt += ", no text, no words, no writing"
        enhanced_prompt = self.clean_prompt_for_image_generation(
            enhanced_prompt
        )

        self.prompt_cache[cache_key] = enhanced_prompt
        return enhanced_prompt

    def extract_visual_elements(self, text: str) -> str:
        """Extract up to three scene keywords from ``text``.

        Candidates are collected in priority order: location phrases
        ("in the ..."), then nouns following "the", then adjacent word pairs.
        Duplicates are dropped while keeping that order, so the result is
        deterministic for a given input.

        Returns:
            The first three unique candidates joined by ``", "``, or the
            first 100 characters of the stripped text when nothing matched.
        """
        text = text.strip()
        lowered = text.lower()
        visual_elements = []

        for preposition, place in _LOCATION_RE.findall(lowered):
            place = place.strip()
            if len(place) > 3:  # avoid very short words
                visual_elements.append(f"{preposition} {place}")

        visual_elements.extend(
            noun
            for noun in _OBJECT_RE.findall(lowered)
            if len(noun) > 3 and noun not in _NON_VISUAL_WORDS
        )

        visual_elements.extend(
            f"{first} {second}"
            for first, second in _WORD_PAIR_RE.findall(lowered)
            if len(first) > 2 and len(second) > 3
        )

        if visual_elements:
            # dict.fromkeys de-duplicates while preserving insertion order.
            unique_elements = list(dict.fromkeys(visual_elements))
            return ", ".join(unique_elements[:3])

        # Nothing recognisable: fall back to the start of the text.
        return text[:100]

    def clean_prompt_for_image_generation(self, prompt: str) -> str:
        """Remove wording that tends to make image models render text.

        Strips phrases such as "with text" or "caption", drops quoted
        passages (apostrophes inside words are not treated as quotes), and
        tidies the blank separators the removals leave behind.
        """
        cleaned = prompt
        for pattern in _TEXT_CUE_PATTERNS:
            cleaned = pattern.sub('', cleaned)

        # Quoted passages might encourage rendered text.
        cleaned = _DOUBLE_QUOTED_RE.sub('', cleaned)
        cleaned = _SINGLE_QUOTED_RE.sub('', cleaned)

        # Tidy up gaps left by the removals.
        cleaned = _SPACE_BEFORE_COMMA_RE.sub(',', cleaned)
        cleaned = _REPEATED_COMMA_RE.sub(', ', cleaned)
        cleaned = _REPEATED_SPACE_RE.sub(' ', cleaned)
        return cleaned.strip(' ,')

    def generate_optimized_prompt(
        self, transcription: Optional[str], aspect_ratio: str = "16:9"
    ) -> str:
        """Generate a prompt for a single transcription.

        Thin wrapper around ``generate_hyper_realistic_prompt``.
        """
        return self.generate_hyper_realistic_prompt(
            transcription, aspect_ratio
        )

    def generate_prompts(
        self,
        text: Optional[str],
        num_segments: int = 5,
        aspect_ratio: str = "16:9",
    ) -> Tuple[List[str], List[str]]:
        """Split ``text`` into word segments and build a prompt for each.

        Args:
            text: Text to split on whitespace; ``None`` or empty text yields
                two empty lists.
            num_segments: Maximum number of segments to produce (>= 1).
            aspect_ratio: Forwarded to ``generate_hyper_realistic_prompt``.

        Returns:
            ``(prompts, segments)``, two lists of equal length. Every word of
            ``text`` ends up in exactly one segment.

        Raises:
            ValueError: If ``num_segments`` is smaller than 1.
        """
        if num_segments < 1:
            raise ValueError("num_segments must be at least 1")

        words = (text or "").split()
        segment_size = max(1, len(words) // num_segments)
        segments = [
            " ".join(words[start:start + segment_size])
            for start in range(0, len(words), segment_size)
        ]

        # Integer division can leave a short extra segment; fold it into the
        # last kept one instead of dropping the trailing words.
        if len(segments) > num_segments:
            tail = " ".join(segments[num_segments - 1:])
            segments[num_segments - 1:] = [tail]

        prompts = [
            self.generate_hyper_realistic_prompt(segment, aspect_ratio)
            for segment in segments
        ]
        return prompts, segments

    def generate_optimized_prompts(
        self,
        transcriptions: Optional[Sequence[str]],
        parallel: bool = False,
        max_workers: int = 4,
        aspect_ratio: str = "16:9",
    ) -> List[str]:
        """Generate one prompt per transcription, optionally in a thread pool.

        Args:
            transcriptions: Transcribed segments; ``None`` or an empty
                sequence yields ``[]``.
            parallel: Use a ``ThreadPoolExecutor`` when there is more than
                one transcription.
            max_workers: Thread pool size when ``parallel`` is true.
            aspect_ratio: Forwarded to ``generate_hyper_realistic_prompt``.

        Returns:
            Prompts in the same order as ``transcriptions``.
        """
        if not transcriptions:
            return []

        def generate_with_aspect(trans):
            return self.generate_hyper_realistic_prompt(trans, aspect_ratio)

        if parallel and len(transcriptions) > 1:
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=max_workers
            ) as executor:
                # executor.map yields results in input order.
                return list(executor.map(generate_with_aspect, transcriptions))

        return [generate_with_aspect(trans) for trans in transcriptions]

    def clear_cache(self) -> bool:
        """Discard all cached prompts and return ``True``."""
        self.prompt_cache = {}
        return True