Spaces: Running on Zero
| # Copyright (2023) Tsinghua University, Bytedance Ltd. and/or its affiliates | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import gradio as gr | |
| import spaces | |
| from huggingface_hub import snapshot_download | |
# Fetch the packaged SALMONN-7B assets into the current directory; the "./..."
# default paths of the CLI options below presumably point at files from this pack.
snapshot_download(repo_id="fffiloni/SALMONN-7B-PACK", local_dir="./")
| import argparse | |
| from model import SALMONN | |
class ff:
    """Offline stand-in for ``SALMONN`` used to exercise the UI without loading weights.

    Enable it by replacing the ``SALMONN(...)`` construction with ``model = ff()``.
    ``generate`` accepts the keyword arguments ``gradio_answer`` passes (including
    ``max_length``) and returns a one-element list, because the caller reads
    ``llm_message[0]``.
    """

    def generate(self, wav_path, prompt, prompt_pattern=None, num_beams=None,
                 temperature=None, top_p=None, max_length=None, **_ignored):
        """Log the request and return a canned reply wrapped in a list.

        Args:
            wav_path: Path to the audio file (only logged).
            prompt: User question (only logged).
            prompt_pattern: Accepted for signature compatibility; unused.
            num_beams: Beam count (only logged).
            temperature: Sampling temperature (only logged).
            top_p: Nucleus-sampling threshold (only logged).
            max_length: Accepted because the UI handler passes it; unused.
            **_ignored: Any other generation keyword; unused.

        Returns:
            list[str]: A single fixed response string.
        """
        print(f'wav_path: {wav_path}, prompt: {prompt}, temperature: {temperature}, num_beams: {num_beams}, top_p: {top_p}')
        # Wrapped in a list so `result[0]` yields the whole reply, not its first character.
        return ["I'm sorry, but I cannot answer that question as it is not clear what you are asking. Can you please provide more context or clarify your question?"]
# Command-line options. The default paths are relative to the working directory
# that snapshot_download() fills above (presumably with exactly these files).
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cuda:0",
                    help="Device the model is moved to with model.to().")
parser.add_argument("--ckpt_path", type=str, default="./salmonn_7b_v0.pth",
                    help="SALMONN checkpoint passed as `ckpt`.")
parser.add_argument("--whisper_path", type=str, default="./whisper_large_v2",
                    help="Whisper weights directory passed to SALMONN.")
parser.add_argument("--beats_path", type=str, default="./beats/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt",
                    help="BEATs checkpoint passed to SALMONN.")
parser.add_argument("--vicuna_path", type=str, default="./vicuna-7b-v1.5",
                    help="Vicuna weights directory passed to SALMONN.")
parser.add_argument("--low_resource", action='store_true', default=False,
                    help="Forwarded to SALMONN; currently overridden to False below.")
# type=int so a port given on the command line is a number, not a string.
parser.add_argument("--port", type=int, default=9527,
                    help="Server port (not used by the current launch call).")
args = parser.parse_args()
# The hosted HuggingFace A10 7B demo never runs in low-resource mode, whatever
# --low_resource says.
args.low_resource = False
# To exercise the UI without loading any weights, use `model = ff()` instead.
model = SALMONN(
    ckpt=args.ckpt_path,
    vicuna_path=args.vicuna_path,
    whisper_path=args.whisper_path,
    beats_path=args.beats_path,
    lora_alpha=28,
    low_resource=args.low_resource,
)
# Place the model on the configured device and put it in evaluation mode.
model.to(args.device)
model.eval()
def gradio_answer(speech, text_input, num_beams, temperature, top_p):
    """
    Generate a detailed answer based on speech audio input and user text query using the SALMONN model.

    Args:
        speech (str): File path to the uploaded audio file (wav or similar). Required.
        text_input (str): The user's question or prompt regarding the audio.
        num_beams (int): Number of beams for beam search in generation (controls diversity/quality).
        temperature (float): Sampling temperature for text generation (controls randomness).
        top_p (float): Top-p nucleus sampling parameter (controls diversity by cumulative probability).

    Returns:
        str: Generated text response from the SALMONN model that answers or describes the audio based on the prompt.
    """
    # gr.Audio(type='filepath') yields None when nothing was uploaded. Raise gr.Error
    # so the user sees a clear message instead of a model-side crash on a None path.
    if not speech:
        raise gr.Error("Please upload an audio file first.")
    llm_message = model.generate(
        wav_path=speech,
        prompt=text_input,
        # Slider/API values may arrive as floats; beam search needs an integer count.
        num_beams=int(num_beams),
        temperature=temperature,
        top_p=top_p,
        max_length=300,
    )
    # model.generate returns a sequence of candidate texts; show the first one.
    return llm_message[0]
# Stylesheet for the Blocks app: centers the main column and caps its width.
css = """
div#col-container {
margin: 0 auto;
max-width: 840px;
}
"""

# Page heading.
title = (
    '<h1 style="text-align: center;">'
    "SALMONN: Speech Audio Language Music Open Neural Network"
    "</h1>"
)

# Logo banner linking to the SALMONN repository; not rendered while the
# gr.Markdown(image_src) line in the UI stays commented out.
image_src = (
    '<h1 align="center">'
    '<a href="https://github.com/bytedance/SALMONN">'
    '<img src="https://raw.githubusercontent.com/bytedance/SALMONN/main/resource/salmon.png", '
    'alt="SALMONN" border="0" style="margin: 0 auto; height: 200px;" />'
    "</a> </h1>"
)

# Intro text pointing to the 7B model card and the 13B demo.
description = (
    '<h3 style="text-align: center;">'
    "This is a simplified gradio demo for "
    '<a href="https://huggingface.co/tsinghua-ee/SALMONN-7B" target="_blank">SALMONN-7B</a>. '
    "<br />To experience SALMONN-13B, you can go to "
    '<a href="https://bytedance.github.io/SALMONN">https://bytedance.github.io/SALMONN</a>.'
    "<br /> Upload your audio and ask a question!</h3>"
)
# Page layout and event wiring. Gradio places components in the order they are
# created inside each layout context, so statement order here is on-page order.
with gr.Blocks(css=css) as demo:
    # Main container; centered and width-capped by the `div#col-container` rule in `css`.
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        #gr.Markdown(image_src)
        gr.HTML(description)
        with gr.Row():
            with gr.Column():
                # type='filepath' hands gradio_answer a path string for its `speech` argument.
                speech = gr.Audio(label="Audio", type='filepath')
                with gr.Row():
                    text_input = gr.Textbox(label='User question', placeholder='Please upload your audio first', interactive=True)
                    submit_btn = gr.Button("Submit")
                answer = gr.Textbox(label="Salmonn answer")
                # Generation settings forwarded to gradio_answer; collapsed by default.
                with gr.Accordion("Advanced Settings", open=False):
                    num_beams = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=4,
                        step=1,
                        interactive=True,
                        label="beam search numbers",
                    )
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.9,
                        step=0.1,
                        interactive=True,
                        label="top p",
                    )
                    # interactive=False: users cannot move this slider, so its value 1.0 is always sent.
                    temperature = gr.Slider(
                        minimum=0.8,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        interactive=False,
                        label="temperature",
                    )
        # Preset (audio file, question) pairs. No fn/outputs are given, so picking
        # one only fills `speech` and `text_input`; it does not run the model.
        with gr.Row():
            examples = gr.Examples(
                examples = [
                    ["resource/audio_demo/gunshots.wav", "Recognize the speech and give me the transcription."],
                    ["resource/audio_demo/gunshots.wav", "Listen to the speech and translate it into German."],
                    ["resource/audio_demo/gunshots.wav", "Provide the phonetic transcription for the speech."],
                    ["resource/audio_demo/gunshots.wav", "Please describe the audio."],
                    ["resource/audio_demo/gunshots.wav", "Recognize what the speaker says and describe the background audio at the same time."],
                    ["resource/audio_demo/gunshots.wav", "Use your strong reasoning skills to answer the speaker's question in detail based on the background sound."],
                    ["resource/audio_demo/duck.wav", "Please list each event in the audio in order."],
                    ["resource/audio_demo/duck.wav", "Based on the audio, write a story in detail. Your story should be highly related to the audio."],
                    ["resource/audio_demo/duck.wav", "How many speakers did you hear in this audio? Who are they?"],
                    ["resource/audio_demo/excitement.wav", "Describe the emotion of the speaker."],
                    ["resource/audio_demo/mountain.wav", "Please answer the question in detail."],
                    ["resource/audio_demo/jobs.wav", "Give me only three keywords of the text. Explain your reason."],
                    ["resource/audio_demo/2_30.wav", "What is the time mentioned in the speech?"],
                    ["resource/audio_demo/music.wav", "Please describe the music in detail."],
                    ["resource/audio_demo/music.wav", "What is the emotion of the music? Explain the reason in detail."],
                    ["resource/audio_demo/music.wav", "Can you write some lyrics of the song?"],
                    ["resource/audio_demo/music.wav", "Give me a title of the music based on its rhythm and emotion."]
                ],
                inputs=[speech, text_input]
            )
    # Pressing Enter in the question box and clicking Submit run the same handler.
    # show_api=False keeps the textbox route out of the API listing, so only the
    # button's endpoint is advertised (presumably also the one exposed over MCP).
    text_input.submit(
        gradio_answer, [speech, text_input, num_beams, temperature, top_p], [answer], show_api=False
    )
    submit_btn.click(
        gradio_answer, [speech, text_input, num_beams, temperature, top_p], [answer]
    )
# Enable the request queue (at most 20 pending jobs), then start the server with
# SSR off and the MCP server on. The --port CLI option is not applied here.
demo.queue(max_size=20)
demo.launch(share=False, ssr_mode=False, mcp_server=True)