import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import spaces
import time
# UI strings are in Hebrew; English translations are given in the comments.
TITLE = "מודל מבוסס ג׳מה 3 ליצירת שירים מטופשים בעברית"  # "A Gemma-3-based model for creating silly songs in Hebrew"
DESCRIPTION = """
ניתן לבקש שיר על בסיס טקסט, תמונה או וידאו
בכל פעם, יווצר שיר שונה, אז אם לא אהבתם, אפשר לנסות שוב עם אותו הפרומפט
[המודל זמין להורדה](https://huggingface.co/Norod78/gemma-3_4b_hebrew-lyrics-finetune)
המודל כּוּוַּנן ע״י [דורון אדלר](https://linktr.ee/Norod78)
"""
# DESCRIPTION translation: "You can request a song based on text, an image, or a
# video. A different song is generated each time, so if you did not like the
# result, you can try again with the same prompt. [The model is available for
# download](...). The model was fine-tuned by [Doron Adler](https://linktr.ee/Norod78)."
# model config
model_4b_name = "Norod78/gemma-3_4b_hebrew-lyrics-finetune"
model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)
# TODO: add timestamp-based frame sampling later
def extract_video_frames(video_path, num_frames=8):
    """Sample up to num_frames frames, evenly spaced across the video."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total_frames // num_frames, 1)
    for i in range(num_frames):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; convert to RGB for PIL
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames
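# Illustrative usage of extract_video_frames (hypothetical path, not part of
# the original app):
#   frames = extract_video_frames("examples/clip.mp4", num_frames=8)
#   # -> list of up to 8 RGB PIL.Image frames, evenly spaced; fewer if reads fail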
def format_message(content, files):
    """Build a multimodal content list from the text (with optional <image>
    placeholders) and any attached image/video files."""
    message_content = []
    if content:
        # Split on inline <image> placeholders and interleave attached files
        parts = content.split('<image>')
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            if i < len(parts) - 1 and files:
                img = Image.open(files.pop(0))
                message_content.append({"type": "image", "image": img})
    # Append any remaining files: images directly, videos as sampled frames
    for file in files:
        file_path = file if isinstance(file, str) else file.name
        if Path(file_path).suffix.lower() in ['.jpg', '.jpeg', '.png']:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif Path(file_path).suffix.lower() in ['.mp4', '.mov']:
            frames = extract_video_frames(file_path)
            for frame in frames:
                message_content.append({"type": "image", "image": frame})
    return message_content
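# Illustrative shape of format_message output (hypothetical input, not part of
# the original app):
#   format_message("Write a song about <image> please", ["cat.jpg"])
#   # -> [{"type": "text", "text": "Write a song about"},
#   #     {"type": "image", "image": <PIL.Image>},
#   #     {"type": "text", "text": "please"}]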
def format_conversation_history(chat_history):
    """Convert Gradio's message history into the chat-template format,
    merging consecutive user items into a single user turn."""
    messages = []
    current_user_content = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if role == "user":
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            # Flush any buffered user content before the assistant reply
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages
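# Illustrative shape of format_conversation_history output (hypothetical
# history, not part of the original app):
#   format_conversation_history([
#       {"role": "user", "content": "שלום"},
#       {"role": "assistant", "content": "שיר קצר..."},
#   ])
#   # -> [{"role": "user", "content": [{"type": "text", "text": "שלום"}]},
#   #     {"role": "assistant", "content": [{"type": "text", "text": "שיר קצר..."}]}]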
@spaces.GPU  # required so a ZeroGPU Space allocates a GPU for this call
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    """
    Creates silly song lyrics in Hebrew based on user input and conversation history.

    Args:
        input_data (dict or str):
            - If dict: must include 'text' (str) and optional 'files' (list of image/video file paths).
            - If str: treated as plain text input.
        chat_history (list of dict):
            Sequence of past messages, each with keys 'role' and 'content'.
        max_new_tokens (int):
            Maximum number of tokens to generate for the response.
        system_prompt (str):
            Optional system-level instruction to guide the style and content of the response.
        temperature (float):
            Sampling temperature; higher values yield more diverse outputs.
        top_p (float):
            Nucleus sampling threshold for cumulative probability selection.
        top_k (int):
            Limits sampling to the top_k most likely tokens at each step.
        repetition_penalty (float):
            Penalty factor to discourage the model from repeating the same tokens.

    Yields:
        str: Streaming chunks of the generated Hebrew song lyrics in real time.
    """
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []
    new_message_content = format_message(text, files)
    new_message = {"role": "user", "content": new_message_content}
    system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history
    if messages and messages[-1]["role"] == "user":
        # Merge the new content into the trailing user turn
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)
    model = model_4b
    processor = processor_4b
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Generate on a background thread so tokens can be streamed as they arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    outputs = []
    for chunk in streamer:  # renamed from `text` to avoid shadowing the input text
        outputs.append(chunk)
        yield "".join(outputs)
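# Illustrative direct call (hypothetical arguments, not part of the original
# app); each yielded string is the accumulated text so far:
#   for partial in generate_response("כתוב שיר על חתול", [], 512, "", 0.7, 0.9, 40, 1.1):
#       print(partial)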
chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(rtl=True, show_copy_button=True, type="messages"),
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            # "You are an Israeli poet, writing songs in Hebrew"
            value="אתה משורר ישראלי, כותב שירים בעברית",
            lines=4,
            # "Change the model's settings"
            placeholder="שנה את ההגדרות של המודל",
            text_align="right",
            rtl=True,
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.2),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.4),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=30),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
    ],
    examples=[
        # "Please write me a song describing the image"
        [{"text": "כתוב לי בבקשה שיר המתאר את התמונה", "files": ["examples/image1.jpg"]}],
        # "Describe this image and write a song about it"
        [{"text": "תאר את התמונה הזו וכתוב על זה שיר", "files": ["examples/image2.jpg"]}],
        # "A potato with social anxiety"
        [{"text": "תפוח אדמה עם חרדה חברתית"}]
    ],
    textbox=gr.MultimodalTextbox(
        rtl=True,
        label="קלט",  # "Input"
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="בקשו שיר ו/או העלו תמונה",  # "Request a song and/or upload an image"
    ),
    cache_examples=False,
    type="messages",
    fill_height=True,
    stop_btn="הפסק",  # "Stop"
    css_paths=["style.css"],
    multimodal=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=gr.themes.Soft(),
)
if __name__ == "__main__":
    chat_interface.queue(max_size=20).launch()
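# Local run note (assuming this file is the Space's app.py): `python app.py`
# serves the UI at Gradio's default http://127.0.0.1:7860. On a ZeroGPU Space,
# each call to the @spaces.GPU-decorated generate_response requests a GPU slot.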