import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
import spaces
import time
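# UI strings below are in Hebrew; English translations appear in nearby comments.
# TITLE (translation): "A Gemma 3 based model for creating silly songs in Hebrew"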
TITLE = " ืžื•ื“ืœ ืžื‘ื•ืกืก ื’ืžื” 3 ืœื™ืฆื™ืจืช ืฉื™ืจื™ื ืžื˜ื•ืคืฉื™ื ื‘ืขื‘ืจื™ืช "
DESCRIPTION= """
ื ื™ืชืŸ ืœื‘ืงืฉ ืฉื™ืจ ืขืœ ื‘ืกื™ืก ื˜ืงืกื˜, ืชืžื•ื ื” ื•ื•ื™ื“ืื•
ื‘ื›ืœ ืคืขื, ื™ื•ื•ืฆืจ ืฉื™ืจ ืฉื•ื ื”, ืื– ืื ืœื ืื”ื‘ืชื, ืืคืฉืจ ืœื ืกื•ืช ืฉื•ื‘ ืขื ืื•ืชื• ื”ืคืจื•ืžืคื˜
[ื”ืžื•ื“ืœ ื–ืžื™ืŸ ืœื”ื•ืจื“ื”](https://huggingface.co/Norod78/gemma-3_4b_hebrew-lyrics-finetune)
ื”ืžื•ื“ืœ ื›ึผื•ึผื™ึทึผื™ืœ ืขืดื™ [ื“ื•ืจื•ืŸ ืื“ืœืจ](https://linktr.ee/Norod78)
"""
# Load the fine-tuned Gemma 3 4B model and its processor
model_4b_name = "Norod78/gemma-3_4b_hebrew-lyrics-finetune"
model_4b = Gemma3ForConditionalGeneration.from_pretrained(
model_4b_name,
device_map="auto",
torch_dtype=torch.bfloat16
).eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)
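# Note: device_map="auto" lets accelerate place the model on the available GPU,
# and bfloat16 halves memory use relative to float32.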
# TODO: add timestamp support later
def extract_video_frames(video_path, num_frames=8):
    """Sample up to num_frames evenly spaced RGB frames from a video file."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    if not cap.isOpened():
        return frames
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total_frames // num_frames, 1)
    for i in range(num_frames):
        # Seek to the next sampling position; reads past the end of the video simply fail.
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes frames as BGR; convert to RGB before wrapping in a PIL image.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames
def format_message(content, files):
    """Build a multimodal message: text split on <image> placeholders, then any remaining files."""
    message_content = []
    if content:
        # Interleave text segments with images wherever an <image> placeholder appears.
        parts = content.split('<image>')
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            if i < len(parts) - 1 and files:
                img = Image.open(files.pop(0))
                message_content.append({"type": "image", "image": img})
    # Append any files that were not consumed by an <image> placeholder.
    for file in files:
        file_path = file if isinstance(file, str) else file.name
        if Path(file_path).suffix.lower() in ['.jpg', '.jpeg', '.png']:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif Path(file_path).suffix.lower() in ['.mp4', '.mov']:
            # Videos are represented as a handful of evenly spaced frames.
            frames = extract_video_frames(file_path)
            for frame in frames:
                message_content.append({"type": "image", "image": frame})
    return message_content
def format_conversation_history(chat_history):
    """Convert Gradio's flat message list into the nested format the processor expects."""
    messages = []
    current_user_content = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if role == "user":
            # Consecutive user entries (text plus uploaded files) are merged into one turn.
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            # Flush any accumulated user content before recording the assistant reply.
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages
@spaces.GPU(duration=120)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
"""
Creates silly song lyrics in Hebrew based on user input and conversation history.
Args:
input_data (dict or str):
- If dict: must include 'text' (str) and optional 'files' (list of image/video file paths).
- If str: treated as plain text input.
chat_history (list of dict):
Sequence of past messages, each with keys 'role' and 'content'.
max_new_tokens (int):
Maximum number of tokens to generate for the response.
system_prompt (str):
Optional system-level instruction to guide the style and content of the response.
temperature (float):
Sampling temperature; higher values yield more diverse outputs.
top_p (float):
Nucleus sampling threshold for cumulative probability selection.
top_k (int):
Limits sampling to the top_k most likely tokens at each step.
repetition_penalty (float):
Penalty factor to discourage the model from repeating the same tokens.
Yields:
str: Streaming chunks of the generated Hebrew song lyrics in real time.
"""
if isinstance(input_data, dict) and "text" in input_data:
text = input_data["text"]
files = input_data.get("files", [])
else:
text = str(input_data)
files = []
new_message_content = format_message(text, files)
new_message = {"role": "user", "content": new_message_content}
system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
processed_history = format_conversation_history(chat_history)
messages = system_message + processed_history
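    # If the history already ends with an open user turn, merge the new content into it.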
if messages and messages[-1]["role"] == "user":
messages[-1]["content"].extend(new_message["content"])
else:
messages.append(new_message)
model = model_4b
processor = processor_4b
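    # The chat template interleaves the text and image content and tokenizes it in one pass.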
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True
).to(model.device)
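    # The streamer decodes tokens incrementally; passing the processor assumes it
    # forwards decode() to its underlying tokenizer, as Gemma 3's processor does.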
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty
)
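    # Run generate() on a background thread so the streamer can yield tokens as they arrive.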
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
    # Accumulate streamed chunks and yield the growing text so Gradio updates in place.
    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        yield "".join(outputs)
chat_interface = gr.ChatInterface(
fn=generate_response,
    chatbot=gr.Chatbot(rtl=True, show_copy_button=True, type="messages"),
additional_inputs=[
gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            # Default system prompt (translation): "You are an Israeli poet, writing songs in Hebrew"
            value="ืืชื” ืžืฉื•ืจืจ ื™ืฉืจืืœื™, ื›ื•ืชื‘ ืฉื™ืจื™ื ื‘ืขื‘ืจื™ืช",
            lines=4,
            # Placeholder (translation): "Change the model's settings"
            placeholder="ืฉื ื” ืืช ื”ื”ื’ื“ืจื•ืช ืฉืœ ื”ืžื•ื“ืœ",
            text_align="right", rtl=True
        ),
gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.2),
gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.4),
gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=30),
gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
],
    examples=[
        # Example prompts (translations): "Please write me a song describing the image",
        # "Describe the image and then write a song about it",
        # "A potato with social anxiety"
        [{"text": "ื›ืชื•ื‘ ืœื™ ื‘ื‘ืงืฉื” ืฉื™ืจ ื”ืžืชืืจ ืืช ื”ืชืžื•ื ื”", "files": ["examples/image1.jpg"]}],
        [{"text": "ืชืืจ ืืช ื”ืชืžื•ื ื” ื•ืื– ื›ืชื•ื‘ ืขืœ ื–ื” ืฉื™ืจ", "files": ["examples/image2.jpg"]}],
        [{"text": "ืชืคื•ื— ืื“ืžื” ืขื ื—ืจื“ื” ื—ื‘ืจืชื™ืช"}]
    ],
    textbox=gr.MultimodalTextbox(
        rtl=True,
        label="ืงืœื˜",  # translation: "Input"
        file_types=["image", "video"],
        file_count="multiple",
        # Placeholder (translation): "Request a song and/or upload an image"
        placeholder="ื‘ืงืฉื• ืฉื™ืจ ื•/ืื• ื”ืขืœื• ืชืžื•ื ื”",
    ),
cache_examples=False,
type="messages",
fill_height=True,
stop_btn="ื”ืคืกืง",
css_paths=["style.css"],
multimodal=True,
title=TITLE,
description=DESCRIPTION,
theme=gr.themes.Soft(),
)
if __name__ == "__main__":
chat_interface.queue(max_size=20).launch()