import gradio as gr
from gradio_app.components import (
    get_files_in_ckpts,
    handle_file_upload,
    run_tts_inference,
    run_setup_script,
)
from gradio_app.asr_utils import transcribe_audio
from pathlib import Path

def create_gradio_app():
    """Create Gradio interface for F5-TTS inference with Whisper ASR."""
    # Run setup script to ensure dependencies are installed
    run_setup_script()

    # Function to update reference text based on audio file and Whisper checkbox
    def update_ref_text(audio_file_path, use_whisper):
        if use_whisper and audio_file_path:
            return transcribe_audio(audio_file_path)
        return gr.update()

    def toggle_model_inputs(use_upload):
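        """Show either the bundled-model dropdowns or the custom upload widgets, depending on the checkbox."""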
        return (
            gr.update(visible=not use_upload),
            gr.update(visible=not use_upload),
            gr.update(visible=not use_upload),
            gr.update(visible=use_upload),
            gr.update(visible=use_upload),
            gr.update(visible=use_upload)
        )

    def load_example(ref_audio_path, ref_text, inf_text):
        """Load example inputs and retrieve corresponding infer_audio for output."""
        # Find the matching example folder to get infer_audio
        example_dirs = [
            Path("apps/gradio_app/assets/examples/f5_tts/1"),
            Path("apps/gradio_app/assets/examples/f5_tts/2"),
            Path("apps/gradio_app/assets/examples/f5_tts/3"),
            Path("apps/gradio_app/assets/examples/f5_tts/4")
        ]
        inf_audio_path = None
        for dir_path in example_dirs:
            if dir_path.exists():
                ref_audio = next((f for f in dir_path.glob("refer_audio.*") if f.suffix in [".mp3", ".wav"]), None)
                if ref_audio and str(ref_audio) == ref_audio_path:
                    inf_audio = next((f for f in dir_path.glob("infer_audio.*") if f.suffix in [".mp3", ".wav"]), None)
                    inf_audio_path = str(inf_audio) if inf_audio else None
                    break

        return ref_audio_path, ref_text, inf_text, inf_audio_path

    # Prepare examples for gr.Examples (exclude infer_audio from table)
    example_dirs = [
        Path("apps/gradio_app/assets/examples/f5_tts/1"),
        Path("apps/gradio_app/assets/examples/f5_tts/2"),
        Path("apps/gradio_app/assets/examples/f5_tts/3"),
        Path("apps/gradio_app/assets/examples/f5_tts/4")
    ]
    examples = []
    for dir_path in example_dirs:
        if not dir_path.exists():
            continue
        # Read text files
        ref_text = (dir_path / "refer_text.txt").read_text(encoding="utf-8") if (dir_path / "refer_text.txt").exists() else ""
        inf_text = (dir_path / "infer_text.txt").read_text(encoding="utf-8") if (dir_path / "infer_text.txt").exists() else ""
        # Find audio files (mp3 or wav)
        ref_audio = next((f for f in dir_path.glob("refer_audio.*") if f.suffix in [".mp3", ".wav"]), None)
        examples.append([
            str(ref_audio) if ref_audio else None,
            ref_text,
            inf_text
        ])

    CSS = Path("apps/gradio_app/static/styles.css").read_text(encoding="utf-8")
    with gr.Blocks(css=CSS) as demo:
        gr.Markdown("# F5-TTS Audio Generation")
        gr.Markdown("Generate high-quality audio with a fine-tuned F5-TTS model. Upload reference audio, use Whisper ASR for transcription, enter text, adjust speed, and select or upload model files.")
        
        with gr.Row():
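            # Left column: reference audio and text inputs; right column: generated output and model settings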
            with gr.Column():
                ref_audio = gr.Audio(label="Reference Audio", type="filepath")
                with gr.Group():
                    use_whisper = gr.Checkbox(label="Use Whisper ASR for Transcription", value=False)
                    ref_text = gr.Textbox(
                        label="Reference Text",
                        placeholder="e.g., Sau nhà Ngô, lần lượt các triều Đinh...",
                        lines=1
                    )
                    gen_text = gr.Textbox(
                        label="Generated Text",
                        placeholder="e.g., Nhà Tiền Lê, Lý và Trần đã chống trả...",
                        lines=1
                    )
                generate_btn = gr.Button("Generate Audio")
            with gr.Column():
                output_audio = gr.Audio(label="Generated Audio")
                output_text = gr.Textbox(label="Status", interactive=False)
                with gr.Group():
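                    # Inference settings: playback speed plus model config, checkpoint, and vocab selection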
                    speed = gr.Slider(0.5, 2.0, 1.0, step=0.1, label="Speed")
                    yaml_files = get_files_in_ckpts([".yaml"])
                    ckpt_files = get_files_in_ckpts([".pt", ".safetensors"], include_subdirs=True)
                    vocab_files = get_files_in_ckpts([".txt", ".safetensors"])
                    model_cfg = gr.Dropdown(
                        choices=yaml_files,
                        label="Model Config (*.yaml)",
                        value=yaml_files[0] if yaml_files else None,
                        visible=True
                    )
                    ckpt_file = gr.Dropdown(
                        choices=ckpt_files,
                        label="Checkpoint File (*.pt or *.safetensors)",
                        value=ckpt_files[0] if ckpt_files else None,
                        visible=True
                    )
                    vocab_file = gr.Dropdown(
                        choices=vocab_files,
                        label="Vocab File (*.txt or *.safetensors)",
                        value=vocab_files[0] if vocab_files else None,
                        visible=True
                    )
                    use_upload = gr.Checkbox(label="Upload Custom Model Files", value=False)
                    model_cfg_upload = gr.File(label="Model Config (*.yaml)", file_types=[".yaml"], visible=False)
                    ckpt_file_upload = gr.File(label="Checkpoint File (*.pt or *.safetensors)", file_types=[".pt", ".safetensors"], visible=False)
                    vocab_file_upload = gr.File(label="Vocab File (*.txt or *.safetensors)", file_types=[".txt", ".safetensors"], visible=False)
        
        # Add Examples component after both columns
        gr.Examples(
            examples=examples,
            inputs=[ref_audio, ref_text, gen_text],
            outputs=[ref_audio, ref_text, gen_text, output_audio],  # Keep output_audio to display infer_audio
            fn=load_example,
            label="Example Inputs",
            examples_per_page=4,
            cache_examples=False
        )
        
        ref_audio.change(fn=update_ref_text, inputs=[ref_audio, use_whisper], outputs=ref_text)
        use_whisper.change(fn=update_ref_text, inputs=[ref_audio, use_whisper], outputs=ref_text)
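        # Swap between bundled-model dropdowns and custom upload widgets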
        use_upload.change(
            fn=toggle_model_inputs,
            inputs=[use_upload],
            outputs=[model_cfg, ckpt_file, vocab_file, model_cfg_upload, ckpt_file_upload, vocab_file_upload]
        )
        generate_btn.click(
            fn=run_tts_inference,
            inputs=[ref_audio, ref_text, gen_text, speed, use_upload, model_cfg, ckpt_file, vocab_file],
            outputs=[output_audio, output_text]
        )
    return demo

if __name__ == "__main__":
    demo = create_gradio_app()
    demo.launch(share=True)