Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| from utils import VideoProcessor, AzureAPI, GoogleAPI, AnthropicAPI, OpenAIAPI | |
| from constraint import SYS_PROMPT, USER_PROMPT | |
| from datasets import load_dataset | |
| import tempfile | |
| import requests | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| import pyarrow.parquet as pq | |
| import hashlib | |
| import os | |
| import csv | |
| import av | |
| # pip install --no-cache-dir huggingface_hub[hf_transfer] | |
def single_download(repo, fname, token, endpoint):
    """Fetch a single file from a Hugging Face *dataset* repo.

    Before downloading, the process environment is updated to request the
    hf_transfer backend with 32 Tokio worker threads.

    Args:
        repo: Dataset repository id passed as ``repo_id``.
        fname: Path of the file inside the repository.
        token: Hugging Face access token.
        endpoint: Hub endpoint URL forwarded to ``hf_hub_download``.

    Returns:
        The local path returned by ``hf_hub_download``.
    """
    # NOTE(review): huggingface_hub may read HF_HUB_ENABLE_HF_TRANSFER once at
    # import time, in which case setting it here has no effect -- confirm for
    # the installed version.
    os.environ.update({"TOKIO_WORKER_THREADS": "32", "HF_HUB_ENABLE_HF_TRANSFER": "1"})
    return hf_hub_download(
        repo_id=repo,
        filename=fname,
        token=token,
        endpoint=endpoint,
        repo_type="dataset",
    )
def load_hf_dataset(dataset_path, auth_token):
    """Load a Hugging Face dataset with ``datasets.load_dataset``.

    Args:
        dataset_path: Dataset name or path understood by ``load_dataset``.
        auth_token: Access token forwarded as ``token``.

    Returns:
        Whatever ``load_dataset`` returns, unchanged.
    """
    return load_dataset(dataset_path, token=auth_token)
def _iter_video_cells(pqfile):
    """Yield the values of the "video" column of a parquet file, one row at a time."""
    parquet = pq.ParquetFile(pqfile)
    # Only the "video" column is used; one-row batches mean only one video
    # at a time is converted to pandas.
    for batch in parquet.iter_batches(1, columns=["video"]):
        yield from batch.to_pandas()["video"]


def _caption_video(video_bytes, work_dir, processor, api, sys_prompt, usr_prompt):
    """Write one encoded video to a temp file in work_dir, caption it, then delete the file."""
    # The processor is given a file path, so the bytes must land on disk first.
    # The ".mp4" suffix mirrors the original temp-file name.
    fd, video_path = tempfile.mkstemp(suffix=".mp4", dir=work_dir)
    try:
        with os.fdopen(fd, "wb") as video_file:
            video_file.write(video_bytes)
        frames = processor._decode(video_path)
        base64_list = processor.to_base64_list(frames)
        return api.get_caption(sys_prompt, usr_prompt, base64_list)
    finally:
        # Remove each video once captioned so disk use does not grow per row.
        os.remove(video_path)


def fast_caption(sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint,
                 video_hf, video_hf_auth, parquet_index, video_od, video_od_auth,
                 video_gd, video_gd_auth, frame_format, frame_limit, max_videos=86):
    """Caption the videos of one Hugging Face parquet shard and write them to a CSV.

    Downloads ``data/<parquet_index zero-padded to 6>.parquet`` from the
    dataset repo ``video_hf``. For every non-empty "video" cell, it decodes
    frames with ``VideoProcessor``, asks ``AzureAPI`` for a caption, and writes
    an ``md5,caption`` row to ``/dev/shm/<index>_gpt4o_caption.csv``.

    Args:
        sys_prompt, usr_prompt: Prompts forwarded to ``api.get_caption``.
        temp, top_p, max_tokens, model, key, endpoint: Azure client settings.
        video_hf: Dataset repository id on the Hugging Face Hub.
        video_hf_auth: Hugging Face access token.
        parquet_index: Shard number; zero-padded to 6 digits.
        video_od, video_od_auth, video_gd, video_gd_auth: OneDrive / Google
            Drive inputs from the UI; not implemented and ignored here.
        frame_format, frame_limit: Forwarded to ``VideoProcessor``.
        max_videos: Stop after this many rows (empty rows count). The default
            of 86 preserves the previously hard-coded stop; ``None``
            processes the whole shard.

    Returns:
        ``(csv_path, progress_text, frame)`` for the Gradio outputs
        ``[csv_link, info, frame]``; ``frame`` is always ``None``. Without a
        Hugging Face source and token, returns
        ``("", "No video source selected.", None)``.
    """
    # Check the source first so a missing source does not truncate an
    # existing CSV for this shard.
    if not (video_hf and video_hf_auth):
        return "", "No video source selected.", None

    progress_info = ['Begin processing Hugging Face dataset.']
    processor = VideoProcessor(frame_format=frame_format, frame_limit=frame_limit)
    api = AzureAPI(key=key, endpoint=endpoint, model=model, temp=temp, top_p=top_p, max_tokens=max_tokens)
    shard = str(parquet_index).zfill(6)
    csv_filename = os.path.join('/dev/shm', shard + '_gpt4o_caption.csv')

    # NOTE(review): huggingface_hub may read HF_HUB_ENABLE_HF_TRANSFER at
    # import time, so setting it here might not enable hf_transfer -- confirm.
    os.environ["TOKIO_WORKER_THREADS"] = "8"
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    # NOTE(review): the parquet copy in /dev/shm is never removed here.
    pqfile = hf_hub_download(
        repo_id=video_hf,
        filename='data/' + shard + '.parquet',
        repo_type="dataset",
        local_dir="/dev/shm",
        token=video_hf_auth,
    )

    with tempfile.TemporaryDirectory() as temp_dir:
        with open(csv_filename, mode='w', newline='', encoding='utf-8') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=['md5', 'caption'])
            writer.writeheader()
            for row_number, binary in enumerate(_iter_video_cells(pqfile), start=1):
                if binary:
                    md5 = hashlib.md5(binary).hexdigest()
                    caption = _caption_video(binary, temp_dir, processor, api, sys_prompt, usr_prompt)
                    writer.writerow({'md5': md5, 'caption': caption})
                if max_videos is not None and row_number >= max_videos:
                    break
    # Always return the 3-tuple Gradio expects, however many rows there were.
    return csv_filename, "\n".join(progress_info), None
# Gradio UI: a wide left column for settings, prompts and results, and a
# narrow right column for provider credentials and the data source. Only the
# Azure inputs and the HF data source are wired into fast_caption.
with gr.Blocks() as Core:
    with gr.Row(variant="panel"):
        with gr.Column(scale=6):
            with gr.Accordion("Debug", open=False):
                info = gr.Textbox(label="Info", interactive=False)
                frame = gr.Image(label="Frame", interactive=False)
            with gr.Accordion("Configuration", open=False):
                with gr.Row():
                    temp = gr.Slider(0, 1, 0.3, step=0.1, label="Temperature")
                    top_p = gr.Slider(0, 1, 0.75, step=0.1, label="Top-P")
                    max_tokens = gr.Slider(512, 4096, 1024, step=1, label="Max Tokens")
                with gr.Row():
                    frame_format = gr.Dropdown(label="Frame Format", value="JPEG", choices=["JPEG", "PNG"], interactive=False)
                    frame_limit = gr.Slider(1, 100, 10, step=1, label="Frame Limits")
            with gr.Tabs():
                with gr.Tab("User"):
                    usr_prompt = gr.Textbox(USER_PROMPT, label="User Prompt", lines=10, max_lines=100, show_copy_button=True)
                with gr.Tab("System"):
                    sys_prompt = gr.Textbox(SYS_PROMPT, label="System Prompt", lines=10, max_lines=100, show_copy_button=True)
            with gr.Tabs():
                with gr.Tab("Azure"):
                    result = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("Google"):
                    result_gg = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("Anthropic"):
                    result_ac = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
                with gr.Tab("OpenAI"):
                    result_oai = gr.Textbox(label="Result", lines=15, max_lines=100, show_copy_button=True, interactive=False)
        with gr.Column(scale=2):
            with gr.Column():
                with gr.Accordion("Model Provider", open=True):
                    with gr.Tabs():
                        with gr.Tab("Azure"):
                            model = gr.Dropdown(label="Model", value="GPT-4o", choices=["GPT-4o", "GPT-4v"], interactive=False)
                            key = gr.Textbox(label="Azure API Key")
                            endpoint = gr.Textbox(label="Azure Endpoint")
                        with gr.Tab("Google"):
                            model_gg = gr.Dropdown(label="Model", value="Gemini-1.5-Flash", choices=["Gemini-1.5-Flash", "Gemini-1.5-Pro"], interactive=False)
                            key_gg = gr.Textbox(label="Gemini API Key")
                            endpoint_gg = gr.Textbox(label="Gemini API Endpoint")
                        with gr.Tab("Anthropic"):
                            model_ac = gr.Dropdown(label="Model", value="Claude-3-Opus", choices=["Claude-3-Opus", "Claude-3-Sonnet"], interactive=False)
                            key_ac = gr.Textbox(label="Anthropic API Key")
                            endpoint_ac = gr.Textbox(label="Anthropic Endpoint")
                        with gr.Tab("OpenAI"):
                            model_oai = gr.Dropdown(label="Model", value="GPT-4o", choices=["GPT-4o", "GPT-4v"], interactive=False)
                            key_oai = gr.Textbox(label="OpenAI API Key")
                            endpoint_oai = gr.Textbox(label="OpenAI Endpoint")
                with gr.Accordion("Data Source", open=True):
                    with gr.Tabs():
                        with gr.Tab("HF"):
                            video_hf = gr.Text(label="Huggingface File Path")
                            video_hf_auth = gr.Text(label="Huggingface Token")
                            parquet_index = gr.Text(label="Parquet Index")
                        with gr.Tab("Onedrive"):
                            # Use label=: a positional string would become the
                            # textbox's pre-filled value instead of its caption.
                            video_od = gr.Text(label="Microsoft Onedrive File Path")
                            video_od_auth = gr.Text(label="Microsoft Onedrive Token")
                        with gr.Tab("Google Drive"):
                            video_gd = gr.Text(label="Google Drive File Path")
                            video_gd_auth = gr.Text(label="Google Drive Access Token")
                caption_button = gr.Button("Caption", variant="primary", size="lg")
                csv_link = gr.File(label="Download CSV", interactive=False)
    # Input order must match fast_caption's positional parameters exactly.
    caption_button.click(
        fast_caption,
        inputs=[sys_prompt, usr_prompt, temp, top_p, max_tokens, model, key, endpoint, video_hf, video_hf_auth, parquet_index, video_od, video_od_auth, video_gd, video_gd_auth, frame_format, frame_limit],
        outputs=[csv_link, info, frame]
    )

if __name__ == "__main__":
    Core.launch()