|
|
import streamlit as st |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
class PromptGenerator:
    """Build image-generation prompts from transcribed text.

    Prompts are assembled heuristically: a short scene description is pulled
    out of the text with regular expressions and combined with randomly chosen
    style, lighting, quality and camera keywords that depend on the detected
    scene type. Finished prompts are memoised in ``prompt_cache``.

    ``load_model`` can load a seq2seq model and tokenizer, but none of the
    prompt-building methods in this class call the model.
    """

    # Ordered (scene_type, trigger substrings). The first scene type that has
    # any trigger occurring in the lower-cased text wins. Matching is by
    # substring on purpose so inflected forms ("shadows", "monsters") count.
    _SCENE_TRIGGERS = (
        ("horror", (
            "dark", "shadow", "fear", "horror", "scary", "afraid", "terror",
            "scream", "blood", "death", "evil", "monster", "ghost",
            "nightmare", "creepy", "spooky", "haunted", "skeleton",
        )),
        ("mystery", (
            "mystery", "detective", "clue", "investigate", "secret",
            "discover", "reveal", "hidden", "puzzle", "solve",
        )),
        ("fantasy", (
            "fantasy", "magic", "wizard", "dragon", "fairy", "enchanted",
            "spell", "mythical", "legend", "ancient", "kingdom",
        )),
        ("scifi", (
            "sci-fi", "future", "space", "alien", "robot", "technology",
            "advanced", "starship", "planet", "galaxy",
        )),
    )

    # Style and lighting keyword pools per scene type ("neutral" is the default).
    _SCENE_STYLES = {
        "horror": {
            "style": (
                "dark atmospheric horror scene",
                "cinematic horror",
                "atmospheric dread",
                "horror movie still",
                "dark gothic scene",
                "eerie lighting",
                "suspenseful moment",
            ),
            "lighting": (
                "dim lighting",
                "shadows",
                "moonlight",
                "eerie glow",
                "dark atmospheric lighting",
            ),
        },
        "mystery": {
            "style": (
                "mysterious scene",
                "film noir style",
                "detective story visual",
                "suspenseful moment",
                "enigmatic scene",
            ),
            "lighting": (
                "moody lighting",
                "dramatic shadows",
                "low key lighting",
                "atmospheric fog",
            ),
        },
        "fantasy": {
            "style": (
                "fantasy scene",
                "magical environment",
                "enchanted setting",
                "mythical landscape",
                "fantasy illustration style",
            ),
            "lighting": (
                "magical glow",
                "ethereal light",
                "golden hour",
                "mystical atmosphere",
            ),
        },
        "scifi": {
            "style": (
                "futuristic scene",
                "sci-fi environment",
                "high-tech setting",
                "advanced technology visual",
                "science fiction concept art",
            ),
            "lighting": (
                "neon lighting",
                "holographic glow",
                "futuristic illumination",
                "technological ambiance",
            ),
        },
        "neutral": {
            "style": (
                "cinematic scene",
                "photorealistic environment",
                "detailed setting",
                "professional photography",
                "movie still",
            ),
            "lighting": (
                "natural lighting",
                "golden hour",
                "soft illumination",
                "dramatic lighting",
            ),
        },
    }

    _QUALITY_KEYWORDS = (
        "highly detailed",
        "8k resolution",
        "photorealistic",
        "detailed textures",
        "professional photography",
    )

    _CAMERA_KEYWORDS = (
        "shallow depth of field",
        "cinematic composition",
        "movie still",
        "professional photography",
    )

    # Words that can follow "the" but do not name anything drawable.
    _NON_VISUAL_WORDS = frozenset((
        "that", "this", "then", "than", "they", "them", "with", "from",
        "were", "when", "what", "which",
    ))

    # Phrases that tend to make image models render lettering. Word
    # boundaries keep them from firing inside unrelated words such as
    # "texture" or "entitled".
    _TEXT_TRIGGER_PATTERNS = (
        r'\btext\s+that\s+says\b',
        r'\bwith\s+text\b',
        r'\bshowing\s+text\b',
        r'\bdisplaying\s+text\b',
        r'\bwith\s+the\s+words\b',
        r'\bcaption(?:s|ed)?\b',
        r'\b(?:sub)?title[sd]?\b',
    )

    def __init__(self):
        """Start with no model loaded and an empty prompt cache."""
        self.model = None
        self.tokenizer = None
        # Maps an MD5 hex digest of (aspect_ratio, transcription) to a prompt.
        self.prompt_cache = {}

    def load_model(self):
        """Load the ``facebook/bart-large-cnn`` model and tokenizer on first use.

        Loading is attempted only while ``self.model`` is ``None``; a failed
        attempt leaves both attributes ``None``, so the next call retries.

        Returns:
            The ``(model, tokenizer)`` pair; both are ``None`` if loading
            failed, in which case a Streamlit warning has been displayed.
        """
        if self.model is None:
            with st.spinner("Loading text-to-prompt model..."):
                try:
                    model_name = "facebook/bart-large-cnn"
                    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                    # NOTE(review): float16 weights on CPU may be slow or
                    # unsupported for some ops depending on the installed
                    # PyTorch version - confirm before running inference.
                    self.model = AutoModelForSeq2SeqLM.from_pretrained(
                        model_name,
                        low_cpu_mem_usage=True,
                        torch_dtype=torch.float16,
                    )
                    self.model = self.model.to("cpu")
                except Exception as e:
                    # Deliberate best-effort: the UI keeps working without
                    # the model, so report the problem and reset the state.
                    st.warning(f"Error loading model: {str(e)}. Using fallback method.")
                    self.model = None
                    self.tokenizer = None

        return self.model, self.tokenizer

    def _detect_scene_type(self, lowered_text):
        """Return the first scene type whose trigger occurs in ``lowered_text``."""
        for scene_type, triggers in self._SCENE_TRIGGERS:
            if any(trigger in lowered_text for trigger in triggers):
                return scene_type
        return "neutral"

    def generate_hyper_realistic_prompt(self, transcription, aspect_ratio="16:9"):
        """Build a cinematic image prompt for one piece of transcribed text.

        The prompt is the extracted scene description followed by one style
        keyword, one lighting keyword, two quality keywords, one camera
        keyword and a fixed "no text" suffix. Keyword choice is random, so
        the result is cached to keep repeated calls stable.

        Args:
            transcription: Text to describe. ``None``, empty or
                whitespace-only input yields ``""``.
            aspect_ratio: Only used as part of the cache key; it does not
                appear in the prompt.

        Returns:
            The cleaned prompt string, or ``""`` for blank input.
        """
        import hashlib
        import random

        if not transcription or not transcription.strip():
            return ""

        # The NUL separator keeps ("x1", "6:9") and ("x", "16:9") from
        # hashing to the same key, which plain concatenation would allow.
        cache_key = hashlib.md5(
            f"{aspect_ratio}\x00{transcription}".encode()
        ).hexdigest()
        if cache_key in self.prompt_cache:
            return self.prompt_cache[cache_key]

        visual_scene = self.extract_visual_elements(transcription)
        keyword_pool = self._SCENE_STYLES[
            self._detect_scene_type(transcription.lower())
        ]

        # Draw order (style, lighting, quality, camera) is kept fixed so a
        # seeded ``random`` produces reproducible prompts.
        selected_style = random.choice(keyword_pool["style"])
        selected_lighting = random.choice(keyword_pool["lighting"])
        selected_quality = ", ".join(random.sample(self._QUALITY_KEYWORDS, 2))
        selected_camera = random.choice(self._CAMERA_KEYWORDS)

        enhanced_prompt = (
            f"{visual_scene}, {selected_style}, {selected_lighting}, "
            f"{selected_quality}, {selected_camera}"
            ", no text, no words, no writing"
        )
        enhanced_prompt = self.clean_prompt_for_image_generation(enhanced_prompt)

        self.prompt_cache[cache_key] = enhanced_prompt
        return enhanced_prompt

    def extract_visual_elements(self, text):
        """Extract up to three visual phrases from ``text``.

        Candidates are collected in priority order: location phrases
        ("near the river"), nouns following "the", then adjacent word pairs.
        Duplicates are dropped while keeping that order, so the result is
        deterministic.

        Args:
            text: Source text; ``None`` or empty input yields ``""``.

        Returns:
            Up to three candidates joined by ``", "``. If nothing matches,
            the stripped text truncated to 100 characters.
        """
        import re

        if not text:
            return ""

        text = text.strip()
        lowered = text.lower()
        visual_elements = []

        # ``\b`` stops "that the house" from being read as "at the house".
        location_hits = re.findall(
            r'\b(inside|outside|in the|at the|near the|by the|on the) ([a-z ]+)',
            lowered,
        )
        for preposition, place in location_hits:
            place = place.strip()
            if len(place) > 3:
                visual_elements.append(f"{preposition} {place}")

        # ``\b`` stops words ending in "the" (e.g. "bathe") from matching.
        for noun in re.findall(r'\bthe ([a-z]+)', lowered):
            if len(noun) > 3 and noun not in self._NON_VISUAL_WORDS:
                visual_elements.append(noun)

        for first_word, second_word in re.findall(r'([a-z]+) ([a-z]+)', lowered):
            if len(first_word) > 2 and len(second_word) > 3:
                visual_elements.append(f"{first_word} {second_word}")

        if visual_elements:
            # dict.fromkeys de-duplicates while preserving insertion order,
            # unlike set(), whose order varies between interpreter runs.
            unique_elements = list(dict.fromkeys(visual_elements))
            return ", ".join(unique_elements[:3])

        # Nothing recognisable: fall back to the (truncated) raw text.
        return text[:100]

    def clean_prompt_for_image_generation(self, prompt):
        """Remove phrasing that tends to make image models render lettering.

        Strips text-related trigger phrases and quoted passages, then tidies
        the whitespace and commas left behind by the removals.

        Args:
            prompt: Prompt to clean; ``None`` or empty input yields ``""``.

        Returns:
            The cleaned prompt with surrounding spaces and commas trimmed.
        """
        import re

        if not prompt:
            return ""

        cleaned = prompt
        for pattern in self._TEXT_TRIGGER_PATTERNS:
            cleaned = re.sub(pattern, '', cleaned, flags=re.IGNORECASE)

        # Drop quoted passages. Single quotes only count when they are not
        # attached to a word, so apostrophes ("don't", "king's") survive.
        cleaned = re.sub(r'"[^"]*"', '', cleaned)
        cleaned = re.sub(r"(?<!\w)'[^']*'(?!\w)", '', cleaned)

        # Repair the gaps the removals leave behind.
        cleaned = re.sub(r'\s+', ' ', cleaned)
        cleaned = re.sub(r'\s+,', ',', cleaned)
        cleaned = re.sub(r',(?:\s*,)+', ',', cleaned)

        return cleaned.strip(' ,')

    def generate_optimized_prompt(self, transcription, aspect_ratio="16:9"):
        """Generate a prompt for a single transcription.

        Thin alias for ``generate_hyper_realistic_prompt``.
        """
        return self.generate_hyper_realistic_prompt(transcription, aspect_ratio)

    def generate_prompts(self, text, num_segments=5, aspect_ratio="16:9"):
        """Split ``text`` into word segments and build one prompt per segment.

        Words are distributed as evenly as possible, with earlier segments
        receiving the remainder, so no trailing words are discarded. Texts
        with fewer words than ``num_segments`` produce one segment per word.

        Args:
            text: Full transcription.
            num_segments: Maximum number of segments; values below 1 yield
                empty results.
            aspect_ratio: Forwarded to ``generate_hyper_realistic_prompt``.

        Returns:
            ``(prompts, segments)`` - two lists of equal length.
        """
        words = text.split() if text else []
        if num_segments < 1 or not words:
            return [], []

        segment_count = min(num_segments, len(words))
        base_size, remainder = divmod(len(words), segment_count)

        segments = []
        start = 0
        for index in range(segment_count):
            # The first ``remainder`` segments take one extra word each.
            end = start + base_size + (1 if index < remainder else 0)
            segments.append(" ".join(words[start:end]))
            start = end

        prompts = [
            self.generate_hyper_realistic_prompt(segment, aspect_ratio)
            for segment in segments
        ]
        return prompts, segments

    def generate_optimized_prompts(self, transcriptions, parallel=False, max_workers=4, aspect_ratio="16:9"):
        """Generate one prompt per transcription, optionally using threads.

        Args:
            transcriptions: Iterable of transcription strings.
            parallel: When true and there is more than one item, the work is
                fanned out over a thread pool.
            max_workers: Thread pool size used when ``parallel`` is true.
            aspect_ratio: Forwarded to ``generate_hyper_realistic_prompt``.

        Returns:
            A list of prompts in the same order as ``transcriptions``.
        """
        import concurrent.futures

        # Materialise once so generators and other unsized iterables work.
        transcriptions = list(transcriptions)

        def build_prompt(transcription):
            return self.generate_hyper_realistic_prompt(transcription, aspect_ratio)

        if parallel and len(transcriptions) > 1:
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                # executor.map preserves input order.
                return list(executor.map(build_prompt, transcriptions))

        return [build_prompt(transcription) for transcription in transcriptions]

    def clear_cache(self):
        """Discard all cached prompts and return ``True``."""
        self.prompt_cache = {}
        return True
|
|
|